Statistics
| Branch: | Tag: | Revision:

root / lib / netutils.py @ e7b3ad26

History | View | Annotate | Download (13.1 kB)

1
#
2
#
3

    
4
# Copyright (C) 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti network utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import errno
31
import re
32
import socket
33
import struct
34
import IN
35

    
36
from ganeti import constants
37
from ganeti import errors
38

    
39
# Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
40
# struct ucred { pid_t pid; uid_t uid; gid_t gid; };
41
#
42
# The GNU C Library defines gid_t and uid_t to be "unsigned int" and
43
# pid_t to "int".
44
#
45
# IEEE Std 1003.1-2008:
46
# "nlink_t, uid_t, gid_t, and id_t shall be integer types"
47
# "blksize_t, pid_t, and ssize_t shall be signed integer types"
48
_STRUCT_UCRED = "iII"
49
_STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)
50

    
51

    
52
def GetSocketCredentials(sock):
53
  """Returns the credentials of the foreign process connected to a socket.
54

55
  @param sock: Unix socket
56
  @rtype: tuple; (number, number, number)
57
  @return: The PID, UID and GID of the connected foreign process.
58

59
  """
60
  peercred = sock.getsockopt(socket.SOL_SOCKET, IN.SO_PEERCRED,
61
                             _STRUCT_UCRED_SIZE)
62
  return struct.unpack(_STRUCT_UCRED, peercred)
63

    
64

    
65
def GetHostname(name=None, family=None):
66
  """Returns a Hostname object.
67

68
  @type name: str
69
  @param name: hostname or None
70
  @type family: int
71
  @param family: AF_INET | AF_INET6 | None
72
  @rtype: L{Hostname}
73
  @return: Hostname object
74
  @raise: errors.OpPrereqError
75

76
  """
77
  try:
78
    return Hostname(name=name, family=family)
79
  except errors.ResolverError, err:
80
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
81
                               (err[0], err[2]), errors.ECODE_RESOLVER)
82

    
83

    
84
class Hostname:
85
  """Class implementing resolver and hostname functionality.
86

87
  """
88
  _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$")
89

    
90
  def __init__(self, name=None, family=None):
91
    """Initialize the host name object.
92

93
    If the name argument is None, it will use this system's name.
94

95
    @type family: int
96
    @param family: AF_INET | AF_INET6 | None
97
    @type name: str
98
    @param name: hostname or None
99

100
    """
101
    if name is None:
102
      name = self.GetSysName()
103

    
104
    self.name = self.GetNormalizedName(name)
105
    self.ip = self.GetIP(self.name, family=family)
106

    
107
  @staticmethod
108
  def GetSysName():
109
    """Return the current system's name.
110

111
    This is simply a wrapper over C{socket.gethostname()}.
112

113
    """
114
    return socket.gethostname()
115

    
116
  @staticmethod
117
  def GetIP(hostname, family=None):
118
    """Return IP address of given hostname.
119

120
    Supports both IPv4 and IPv6.
121

122
    @type hostname: str
123
    @param hostname: hostname to look up
124
    @type family: int
125
    @param family: AF_INET | AF_INET6 | None
126
    @rtype: str
127
    @return: IP address
128
    @raise errors.ResolverError: in case of errors in resolving
129

130
    """
131
    try:
132
      if family in (socket.AF_INET, socket.AF_INET6):
133
        result = socket.getaddrinfo(hostname, None, family)
134
      else:
135
        result = socket.getaddrinfo(hostname, None, socket.AF_INET)
136
    except (socket.gaierror, socket.herror, socket.error), err:
137
      # hostname not found in DNS, or other socket exception in the
138
      # (code, description format)
139
      raise errors.ResolverError(hostname, err.args[0], err.args[1])
140

    
141
    # getaddrinfo() returns a list of 5-tupes (family, socktype, proto,
142
    # canonname, sockaddr). We return the first tuple's first address in
143
    # sockaddr
144
    try:
145
      return result[0][4][0]
146
    except IndexError, err:
147
      raise errors.ResolverError("Unknown error in getaddrinfo(): %s" % err)
148

    
149
  @classmethod
150
  def GetNormalizedName(cls, hostname):
151
    """Validate and normalize the given hostname.
152

153
    @attention: the validation is a bit more relaxed than the standards
154
        require; most importantly, we allow underscores in names
155
    @raise errors.OpPrereqError: when the name is not valid
156

157
    """
158
    hostname = hostname.lower()
159
    if (not cls._VALID_NAME_RE.match(hostname) or
160
        # double-dots, meaning empty label
161
        ".." in hostname or
162
        # empty initial label
163
        hostname.startswith(".")):
164
      raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
165
                                 errors.ECODE_INVAL)
166
    if hostname.endswith("."):
167
      hostname = hostname.rstrip(".")
168
    return hostname
169

    
170

    
171
def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
172
  """Simple ping implementation using TCP connect(2).
173

174
  Check if the given IP is reachable by doing attempting a TCP connect
175
  to it.
176

177
  @type target: str
178
  @param target: the IP or hostname to ping
179
  @type port: int
180
  @param port: the port to connect to
181
  @type timeout: int
182
  @param timeout: the timeout on the connection attempt
183
  @type live_port_needed: boolean
184
  @param live_port_needed: whether a closed port will cause the
185
      function to return failure, as if there was a timeout
186
  @type source: str or None
187
  @param source: if specified, will cause the connect to be made
188
      from this specific source address; failures to bind other
189
      than C{EADDRNOTAVAIL} will be ignored
190

191
  """
192
  try:
193
    family = IPAddress.GetAddressFamily(target)
194
  except errors.GenericError:
195
    return False
196

    
197
  sock = socket.socket(family, socket.SOCK_STREAM)
198
  success = False
199

    
200
  if source is not None:
201
    try:
202
      sock.bind((source, 0))
203
    except socket.error, (errcode, _):
204
      if errcode == errno.EADDRNOTAVAIL:
205
        success = False
206

    
207
  sock.settimeout(timeout)
208

    
209
  try:
210
    sock.connect((target, port))
211
    sock.close()
212
    success = True
213
  except socket.timeout:
214
    success = False
215
  except socket.error, (errcode, _):
216
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
217

    
218
  return success
219

    
220

    
221
def GetDaemonPort(daemon_name):
222
  """Get the daemon port for this cluster.
223

224
  Note that this routine does not read a ganeti-specific file, but
225
  instead uses C{socket.getservbyname} to allow pre-customization of
226
  this parameter outside of Ganeti.
227

228
  @type daemon_name: string
229
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
230
  @rtype: int
231

232
  """
233
  if daemon_name not in constants.DAEMONS_PORTS:
234
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)
235

    
236
  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
237
  try:
238
    port = socket.getservbyname(daemon_name, proto)
239
  except socket.error:
240
    port = default_port
241

    
242
  return port
243

    
244

    
245
class IPAddress(object):
246
  """Class that represents an IP address.
247

248
  """
249
  iplen = 0
250
  family = None
251
  loopback_cidr = None
252

    
253
  @staticmethod
254
  def _GetIPIntFromString(address):
255
    """Abstract method to please pylint.
256

257
    """
258
    raise NotImplementedError
259

    
260
  @classmethod
261
  def IsValid(cls, address):
262
    """Validate a IP address.
263

264
    @type address: str
265
    @param address: IP address to be checked
266
    @rtype: bool
267
    @return: True if valid, False otherwise
268

269
    """
270
    if cls.family is None:
271
      try:
272
        family = cls.GetAddressFamily(address)
273
      except errors.IPAddressError:
274
        return False
275
    else:
276
      family = cls.family
277

    
278
    try:
279
      socket.inet_pton(family, address)
280
      return True
281
    except socket.error:
282
      return False
283

    
284
  @classmethod
285
  def Own(cls, address):
286
    """Check if the current host has the the given IP address.
287

288
    This is done by trying to bind the given address. We return True if we
289
    succeed or false if a socket.error is raised.
290

291
    @type address: str
292
    @param address: IP address to be checked
293
    @rtype: bool
294
    @return: True if we own the address, False otherwise
295

296
    """
297
    if cls.family is None:
298
      try:
299
        family = cls.GetAddressFamily(address)
300
      except errors.IPAddressError:
301
        return False
302
    else:
303
      family = cls.family
304

    
305
    s = socket.socket(family, socket.SOCK_DGRAM)
306
    success = False
307
    try:
308
      try:
309
        s.bind((address, 0))
310
        success = True
311
      except socket.error:
312
        success = False
313
    finally:
314
      s.close()
315
    return success
316

    
317
  @classmethod
318
  def InNetwork(cls, cidr, address):
319
    """Determine whether an address is within a network.
320

321
    @type cidr: string
322
    @param cidr: Network in CIDR notation, e.g. '192.0.2.0/24', '2001:db8::/64'
323
    @type address: str
324
    @param address: IP address
325
    @rtype: bool
326
    @return: True if address is in cidr, False otherwise
327

328
    """
329
    address_int = cls._GetIPIntFromString(address)
330
    subnet = cidr.split("/")
331
    assert len(subnet) == 2
332
    try:
333
      prefix = int(subnet[1])
334
    except ValueError:
335
      return False
336

    
337
    assert 0 <= prefix <= cls.iplen
338
    target_int = cls._GetIPIntFromString(subnet[0])
339
    # Convert prefix netmask to integer value of netmask
340
    netmask_int = (2**cls.iplen)-1 ^ ((2**cls.iplen)-1 >> prefix)
341
    # Calculate hostmask
342
    hostmask_int = netmask_int ^ (2**cls.iplen)-1
343
    # Calculate network address by and'ing netmask
344
    network_int = target_int & netmask_int
345
    # Calculate broadcast address by or'ing hostmask
346
    broadcast_int = target_int | hostmask_int
347

    
348
    return network_int <= address_int <= broadcast_int
349

    
350
  @staticmethod
351
  def GetAddressFamily(address):
352
    """Get the address family of the given address.
353

354
    @type address: str
355
    @param address: ip address whose family will be returned
356
    @rtype: int
357
    @return: socket.AF_INET or socket.AF_INET6
358
    @raise errors.GenericError: for invalid addresses
359

360
    """
361
    try:
362
      return IP4Address(address).family
363
    except errors.IPAddressError:
364
      pass
365

    
366
    try:
367
      return IP6Address(address).family
368
    except errors.IPAddressError:
369
      pass
370

    
371
    raise errors.IPAddressError("Invalid address '%s'" % address)
372

    
373
  @classmethod
374
  def IsLoopback(cls, address):
375
    """Determine whether it is a loopback address.
376

377
    @type address: str
378
    @param address: IP address to be checked
379
    @rtype: bool
380
    @return: True if loopback, False otherwise
381

382
    """
383
    try:
384
      return cls.InNetwork(cls.loopback_cidr, address)
385
    except errors.IPAddressError:
386
      return False
387

    
388

    
389
class IP4Address(IPAddress):
390
  """IPv4 address class.
391

392
  """
393
  iplen = 32
394
  family = socket.AF_INET
395
  loopback_cidr = "127.0.0.0/8"
396

    
397
  def __init__(self, address):
398
    """Constructor for IPv4 address.
399

400
    @type address: str
401
    @param address: IP address
402
    @raises errors.IPAddressError: if address invalid
403

404
    """
405
    IPAddress.__init__(self)
406
    if not self.IsValid(address):
407
      raise errors.IPAddressError("IPv4 Address %s invalid" % address)
408

    
409
    self.address = address
410

    
411
  @staticmethod
412
  def _GetIPIntFromString(address):
413
    """Get integer value of IPv4 address.
414

415
    @type address: str
416
    @param: IPv6 address
417
    @rtype: int
418
    @return: integer value of given IP address
419

420
    """
421
    address_int = 0
422
    parts = address.split(".")
423
    assert len(parts) == 4
424
    for part in parts:
425
      address_int = (address_int << 8) | int(part)
426

    
427
    return address_int
428

    
429

    
430
class IP6Address(IPAddress):
431
  """IPv6 address class.
432

433
  """
434
  iplen = 128
435
  family = socket.AF_INET6
436
  loopback_cidr = "::1/128"
437

    
438
  def __init__(self, address):
439
    """Constructor for IPv6 address.
440

441
    @type address: str
442
    @param address: IP address
443
    @raises errors.IPAddressError: if address invalid
444

445
    """
446
    IPAddress.__init__(self)
447
    if not self.IsValid(address):
448
      raise errors.IPAddressError("IPv6 Address [%s] invalid" % address)
449
    self.address = address
450

    
451
  @staticmethod
452
  def _GetIPIntFromString(address):
453
    """Get integer value of IPv6 address.
454

455
    @type address: str
456
    @param: IPv6 address
457
    @rtype: int
458
    @return: integer value of given IP address
459

460
    """
461
    doublecolons = address.count("::")
462
    assert not doublecolons > 1
463
    if doublecolons == 1:
464
      # We have a shorthand address, expand it
465
      parts = []
466
      twoparts = address.split("::")
467
      sep = len(twoparts[0].split(':')) + len(twoparts[1].split(':'))
468
      parts = twoparts[0].split(':')
469
      [parts.append("0") for _ in range(8 - sep)]
470
      parts += twoparts[1].split(':')
471
    else:
472
      parts = address.split(":")
473

    
474
    address_int = 0
475
    for part in parts:
476
      address_int = (address_int << 16) + int(part or '0', 16)
477

    
478
    return address_int
479

    
480

    
481
def FormatAddress(family, address):
482
  """Format a socket address
483

484
  @type family: integer
485
  @param family: socket family (one of socket.AF_*)
486
  @type address: family specific (usually tuple)
487
  @param address: address, as reported by this class
488

489
  """
490
  if family == socket.AF_UNIX and len(address) == 3:
491
    return "pid=%s, uid=%s, gid=%s" % address
492

    
493
  if family in (socket.AF_INET, socket.AF_INET6) and len(address) == 2:
494
    host, port = address
495
    if family == socket.AF_INET6:
496
      res = "[%s]" % host
497
    else:
498
      res = host
499

    
500
    if port is not None:
501
      res += ":%s" % port
502

    
503
    return res
504

    
505
  raise errors.ParameterError(family, address)