Statistics
| Branch: | Tag: | Revision:

root / lib / netutils.py @ 7845b8c8

History | View | Annotate | Download (13 kB)

1
#
2
#
3

    
4
# Copyright (C) 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti network utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import errno
31
import re
32
import socket
33
import struct
34
import IN
35

    
36
from ganeti import constants
37
from ganeti import errors
38

    
39
# Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
40
# struct ucred { pid_t pid; uid_t uid; gid_t gid; };
41
#
42
# The GNU C Library defines gid_t and uid_t to be "unsigned int" and
43
# pid_t to "int".
44
#
45
# IEEE Std 1003.1-2008:
46
# "nlink_t, uid_t, gid_t, and id_t shall be integer types"
47
# "blksize_t, pid_t, and ssize_t shall be signed integer types"
48
_STRUCT_UCRED = "iII"
49
_STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)
50

    
51

    
52
def GetSocketCredentials(sock):
53
  """Returns the credentials of the foreign process connected to a socket.
54

55
  @param sock: Unix socket
56
  @rtype: tuple; (number, number, number)
57
  @return: The PID, UID and GID of the connected foreign process.
58

59
  """
60
  peercred = sock.getsockopt(socket.SOL_SOCKET, IN.SO_PEERCRED,
61
                             _STRUCT_UCRED_SIZE)
62
  return struct.unpack(_STRUCT_UCRED, peercred)
63

    
64

    
65
def GetHostname(name=None, family=None):
66
  """Returns a Hostname object.
67

68
  @type name: str
69
  @param name: hostname or None
70
  @type family: int
71
  @param family: AF_INET | AF_INET6 | None
72
  @rtype: L{Hostname}
73
  @return: Hostname object
74
  @raise: errors.OpPrereqError
75

76
  """
77
  try:
78
    return Hostname(name=name, family=family)
79
  except errors.ResolverError, err:
80
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
81
                               (err[0], err[2]), errors.ECODE_RESOLVER)
82

    
83

    
84
class Hostname:
85
  """Class implementing resolver and hostname functionality.
86

87
  """
88
  def __init__(self, name=None, family=None):
89
    """Initialize the host name object.
90

91
    If the name argument is None, it will use this system's name.
92

93
    @type family: int
94
    @param family: AF_INET | AF_INET6 | None
95
    @type name: str
96
    @param name: hostname or None
97

98
    """
99
    if name is None:
100
      name = self.GetSysName()
101

    
102
    self.name = self.GetNormalizedName(name)
103
    self.ip = self.GetIP(self.name, family=family)
104

    
105
  @staticmethod
106
  def GetSysName():
107
    """Return the current system's name.
108

109
    This is simply a wrapper over C{socket.gethostname()}.
110

111
    """
112
    return socket.gethostname()
113

    
114
  @staticmethod
115
  def GetIP(hostname, family=None):
116
    """Return IP address of given hostname.
117

118
    Supports both IPv4 and IPv6.
119

120
    @type hostname: str
121
    @param hostname: hostname to look up
122
    @type family: int
123
    @param family: AF_INET | AF_INET6 | None
124
    @rtype: str
125
    @return: IP address
126
    @raise errors.ResolverError: in case of errors in resolving
127

128
    """
129
    try:
130
      if family in (socket.AF_INET, socket.AF_INET6):
131
        result = socket.getaddrinfo(hostname, None, family)
132
      else:
133
        result = socket.getaddrinfo(hostname, None, socket.AF_INET)
134
    except (socket.gaierror, socket.herror, socket.error), err:
135
      # hostname not found in DNS, or other socket exception in the
136
      # (code, description format)
137
      raise errors.ResolverError(hostname, err.args[0], err.args[1])
138

    
139
    # getaddrinfo() returns a list of 5-tupes (family, socktype, proto,
140
    # canonname, sockaddr). We return the first tuple's first address in
141
    # sockaddr
142
    return result[0][4][0]
143

    
144
  @staticmethod
145
  def GetNormalizedName(hostname):
146
    """Validate and normalize the given hostname.
147

148
    @attention: the validation is a bit more relaxed than the standards
149
        require; most importantly, we allow underscores in names
150
    @raise errors.OpPrereqError: when the name is not valid
151

152
    """
153
    valid_name_re = re.compile("^[a-z0-9._-]{1,255}$")
154
    hostname = hostname.lower()
155
    if (not valid_name_re.match(hostname) or
156
        # double-dots, meaning empty label
157
        ".." in hostname or
158
        # empty initial label
159
        hostname.startswith(".")):
160
      raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
161
                                 errors.ECODE_INVAL)
162
    if hostname.endswith("."):
163
      hostname = hostname.rstrip(".")
164
    return hostname
165

    
166

    
167
def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
168
  """Simple ping implementation using TCP connect(2).
169

170
  Check if the given IP is reachable by doing attempting a TCP connect
171
  to it.
172

173
  @type target: str
174
  @param target: the IP or hostname to ping
175
  @type port: int
176
  @param port: the port to connect to
177
  @type timeout: int
178
  @param timeout: the timeout on the connection attempt
179
  @type live_port_needed: boolean
180
  @param live_port_needed: whether a closed port will cause the
181
      function to return failure, as if there was a timeout
182
  @type source: str or None
183
  @param source: if specified, will cause the connect to be made
184
      from this specific source address; failures to bind other
185
      than C{EADDRNOTAVAIL} will be ignored
186

187
  """
188
  try:
189
    family = IPAddress.GetAddressFamily(target)
190
  except errors.GenericError:
191
    return False
192

    
193
  sock = socket.socket(family, socket.SOCK_STREAM)
194
  success = False
195

    
196
  if source is not None:
197
    try:
198
      sock.bind((source, 0))
199
    except socket.error, (errcode, _):
200
      if errcode == errno.EADDRNOTAVAIL:
201
        success = False
202

    
203
  sock.settimeout(timeout)
204

    
205
  try:
206
    sock.connect((target, port))
207
    sock.close()
208
    success = True
209
  except socket.timeout:
210
    success = False
211
  except socket.error, (errcode, _):
212
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
213

    
214
  return success
215

    
216

    
217
def GetDaemonPort(daemon_name):
218
  """Get the daemon port for this cluster.
219

220
  Note that this routine does not read a ganeti-specific file, but
221
  instead uses C{socket.getservbyname} to allow pre-customization of
222
  this parameter outside of Ganeti.
223

224
  @type daemon_name: string
225
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
226
  @rtype: int
227

228
  """
229
  if daemon_name not in constants.DAEMONS_PORTS:
230
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)
231

    
232
  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
233
  try:
234
    port = socket.getservbyname(daemon_name, proto)
235
  except socket.error:
236
    port = default_port
237

    
238
  return port
239

    
240

    
241
class IPAddress(object):
242
  """Class that represents an IP address.
243

244
  """
245
  iplen = 0
246
  family = None
247
  loopback_cidr = None
248

    
249
  @staticmethod
250
  def _GetIPIntFromString(address):
251
    """Abstract method to please pylint.
252

253
    """
254
    raise NotImplementedError
255

    
256
  @classmethod
257
  def IsValid(cls, address):
258
    """Validate a IP address.
259

260
    @type address: str
261
    @param address: IP address to be checked
262
    @rtype: bool
263
    @return: True if valid, False otherwise
264

265
    """
266
    if cls.family is None:
267
      try:
268
        family = cls.GetAddressFamily(address)
269
      except errors.IPAddressError:
270
        return False
271
    else:
272
      family = cls.family
273

    
274
    try:
275
      socket.inet_pton(family, address)
276
      return True
277
    except socket.error:
278
      return False
279

    
280
  @classmethod
281
  def Own(cls, address):
282
    """Check if the current host has the the given IP address.
283

284
    This is done by trying to bind the given address. We return True if we
285
    succeed or false if a socket.error is raised.
286

287
    @type address: str
288
    @param address: IP address to be checked
289
    @rtype: bool
290
    @return: True if we own the address, False otherwise
291

292
    """
293
    if cls.family is None:
294
      try:
295
        family = cls.GetAddressFamily(address)
296
      except errors.IPAddressError:
297
        return False
298
    else:
299
      family = cls.family
300

    
301
    s = socket.socket(family, socket.SOCK_DGRAM)
302
    success = False
303
    try:
304
      try:
305
        s.bind((address, 0))
306
        success = True
307
      except socket.error:
308
        success = False
309
    finally:
310
      s.close()
311
    return success
312

    
313
  @classmethod
314
  def InNetwork(cls, cidr, address):
315
    """Determine whether an address is within a network.
316

317
    @type cidr: string
318
    @param cidr: Network in CIDR notation, e.g. '192.0.2.0/24', '2001:db8::/64'
319
    @type address: str
320
    @param address: IP address
321
    @rtype: bool
322
    @return: True if address is in cidr, False otherwise
323

324
    """
325
    address_int = cls._GetIPIntFromString(address)
326
    subnet = cidr.split("/")
327
    assert len(subnet) == 2
328
    try:
329
      prefix = int(subnet[1])
330
    except ValueError:
331
      return False
332

    
333
    assert 0 <= prefix <= cls.iplen
334
    target_int = cls._GetIPIntFromString(subnet[0])
335
    # Convert prefix netmask to integer value of netmask
336
    netmask_int = (2**cls.iplen)-1 ^ ((2**cls.iplen)-1 >> prefix)
337
    # Calculate hostmask
338
    hostmask_int = netmask_int ^ (2**cls.iplen)-1
339
    # Calculate network address by and'ing netmask
340
    network_int = target_int & netmask_int
341
    # Calculate broadcast address by or'ing hostmask
342
    broadcast_int = target_int | hostmask_int
343

    
344
    return network_int <= address_int <= broadcast_int
345

    
346
  @staticmethod
347
  def GetAddressFamily(address):
348
    """Get the address family of the given address.
349

350
    @type address: str
351
    @param address: ip address whose family will be returned
352
    @rtype: int
353
    @return: socket.AF_INET or socket.AF_INET6
354
    @raise errors.GenericError: for invalid addresses
355

356
    """
357
    try:
358
      return IP4Address(address).family
359
    except errors.IPAddressError:
360
      pass
361

    
362
    try:
363
      return IP6Address(address).family
364
    except errors.IPAddressError:
365
      pass
366

    
367
    raise errors.IPAddressError("Invalid address '%s'" % address)
368

    
369
  @classmethod
370
  def IsLoopback(cls, address):
371
    """Determine whether it is a loopback address.
372

373
    @type address: str
374
    @param address: IP address to be checked
375
    @rtype: bool
376
    @return: True if loopback, False otherwise
377

378
    """
379
    try:
380
      return cls.InNetwork(cls.loopback_cidr, address)
381
    except errors.IPAddressError:
382
      return False
383

    
384

    
385
class IP4Address(IPAddress):
386
  """IPv4 address class.
387

388
  """
389
  iplen = 32
390
  family = socket.AF_INET
391
  loopback_cidr = "127.0.0.0/8"
392

    
393
  def __init__(self, address):
394
    """Constructor for IPv4 address.
395

396
    @type address: str
397
    @param address: IP address
398
    @raises errors.IPAddressError: if address invalid
399

400
    """
401
    IPAddress.__init__(self)
402
    if not self.IsValid(address):
403
      raise errors.IPAddressError("IPv4 Address %s invalid" % address)
404

    
405
    self.address = address
406

    
407
  @staticmethod
408
  def _GetIPIntFromString(address):
409
    """Get integer value of IPv4 address.
410

411
    @type address: str
412
    @param: IPv6 address
413
    @rtype: int
414
    @return: integer value of given IP address
415

416
    """
417
    address_int = 0
418
    parts = address.split(".")
419
    assert len(parts) == 4
420
    for part in parts:
421
      address_int = (address_int << 8) | int(part)
422

    
423
    return address_int
424

    
425

    
426
class IP6Address(IPAddress):
427
  """IPv6 address class.
428

429
  """
430
  iplen = 128
431
  family = socket.AF_INET6
432
  loopback_cidr = "::1/128"
433

    
434
  def __init__(self, address):
435
    """Constructor for IPv6 address.
436

437
    @type address: str
438
    @param address: IP address
439
    @raises errors.IPAddressError: if address invalid
440

441
    """
442
    IPAddress.__init__(self)
443
    if not self.IsValid(address):
444
      raise errors.IPAddressError("IPv6 Address [%s] invalid" % address)
445
    self.address = address
446

    
447
  @staticmethod
448
  def _GetIPIntFromString(address):
449
    """Get integer value of IPv6 address.
450

451
    @type address: str
452
    @param: IPv6 address
453
    @rtype: int
454
    @return: integer value of given IP address
455

456
    """
457
    doublecolons = address.count("::")
458
    assert not doublecolons > 1
459
    if doublecolons == 1:
460
      # We have a shorthand address, expand it
461
      parts = []
462
      twoparts = address.split("::")
463
      sep = len(twoparts[0].split(':')) + len(twoparts[1].split(':'))
464
      parts = twoparts[0].split(':')
465
      [parts.append("0") for _ in range(8 - sep)]
466
      parts += twoparts[1].split(':')
467
    else:
468
      parts = address.split(":")
469

    
470
    address_int = 0
471
    for part in parts:
472
      address_int = (address_int << 16) + int(part or '0', 16)
473

    
474
    return address_int
475

    
476

    
477
def FormatAddress(family, address):
478
  """Format a socket address
479

480
  @type family: integer
481
  @param family: socket family (one of socket.AF_*)
482
  @type address: family specific (usually tuple)
483
  @param address: address, as reported by this class
484

485
  """
486
  if family == socket.AF_UNIX and len(address) == 3:
487
    return "pid=%s, uid=%s, gid=%s" % address
488

    
489
  if family in (socket.AF_INET, socket.AF_INET6) and len(address) == 2:
490
    host, port = address
491
    if family == socket.AF_INET6:
492
      res = "[%s]" % host
493
    else:
494
      res = host
495

    
496
    if port is not None:
497
      res += ":%s" % port
498

    
499
    return res
500

    
501
  raise errors.ParameterError(family, address)