Merge remote-tracking branch 'origin/stable-2.8'
[ganeti-local] / lib / netutils.py
index 7a1b6d0..ad9b530 100644 (file)
@@ -1,7 +1,7 @@
 #
 #
 
-# Copyright (C) 2010 Google Inc.
+# Copyright (C) 2010, 2011, 2012 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -38,6 +38,7 @@ import logging
 from ganeti import constants
 from ganeti import errors
 from ganeti import utils
+from ganeti import vcluster
 
 # Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
 # struct ucred { pid_t pid; uid_t uid; gid_t gid; };
@@ -166,7 +167,7 @@ class Hostname:
     @param name: hostname or None
 
     """
-    self.name = self.GetNormalizedName(self.GetFqdn(name))
+    self.name = self.GetFqdn(name)
     self.ip = self.GetIP(self.name, family=family)
 
   @classmethod
@@ -176,8 +177,8 @@ class Hostname:
     """
     return cls.GetFqdn()
 
-  @staticmethod
-  def GetFqdn(hostname=None):
+  @classmethod
+  def GetFqdn(cls, hostname=None):
     """Return fqdn.
 
     If hostname is None the system's fqdn is returned.
@@ -189,9 +190,15 @@ class Hostname:
 
     """
     if hostname is None:
-      return socket.getfqdn()
+      virtfqdn = vcluster.GetVirtualHostname()
+      if virtfqdn:
+        result = virtfqdn
+      else:
+        result = socket.getfqdn()
     else:
-      return socket.getfqdn(hostname)
+      result = socket.getfqdn(hostname)
+
+    return cls.GetNormalizedName(result)
 
   @staticmethod
   def GetIP(hostname, family=None):
@@ -224,7 +231,12 @@ class Hostname:
     try:
       return result[0][4][0]
     except IndexError, err:
-      raise errors.ResolverError("Unknown error in getaddrinfo(): %s" % err)
+      # we don't have here an actual error code, it's just that the
+      # data type returned by getaddrinfo is not what we expected;
+      # let's keep the same format in the exception arguments with a
+      # dummy error code
+      raise errors.ResolverError(hostname, 0,
+                                 "Unknown error in getaddrinfo(): %s" % err)
 
   @classmethod
   def GetNormalizedName(cls, hostname):
@@ -269,10 +281,14 @@ def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
       than C{EADDRNOTAVAIL} will be ignored
 
   """
+  logging.debug("Attempting to reach TCP port %s on target %s with a timeout"
+                " of %s seconds", port, target, timeout)
+
   try:
     family = IPAddress.GetAddressFamily(target)
-  except errors.GenericError:
-    return False
+  except errors.IPAddressError, err:
+    raise errors.ProgrammerError("Family of IP address given in parameter"
+                                 " 'target' can't be determined: %s" % err)
 
   sock = socket.socket(family, socket.SOCK_STREAM)
   success = False
@@ -280,8 +296,8 @@ def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
   if source is not None:
     try:
       sock.bind((source, 0))
-    except socket.error, (errcode, _):
-      if errcode == errno.EADDRNOTAVAIL:
+    except socket.error, err:
+      if err[0] == errno.EADDRNOTAVAIL:
         success = False
 
   sock.settimeout(timeout)
@@ -292,8 +308,8 @@ def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
     success = True
   except socket.timeout:
     success = False
-  except socket.error, (errcode, _):
-    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
+  except socket.error, err:
+    success = (not live_port_needed) and (err[0] == errno.ECONNREFUSED)
 
   return success
 
@@ -362,6 +378,20 @@ class IPAddress(object):
       return False
 
   @classmethod
+  def ValidateNetmask(cls, netmask):
+    """Validate a netmask suffix in CIDR notation.
+
+    @type netmask: int
+    @param netmask: netmask suffix to validate
+    @rtype: bool
+    @return: True if valid, False otherwise
+
+    """
+    assert (isinstance(netmask, (int, long)))
+
+    return 0 < netmask <= cls.iplen
+
+  @classmethod
   def Own(cls, address):
     """Check if the current host has the the given IP address.
 
@@ -487,6 +517,36 @@ class IPAddress(object):
 
     raise errors.ProgrammerError("%s is not a valid IP version" % version)
 
+  @staticmethod
+  def GetClassFromIpVersion(version):
+    """Return the IPAddress subclass for the given IP version.
+
+    @type version: int
+    @param version: IP version, one of L{constants.IP4_VERSION} or
+                    L{constants.IP6_VERSION}
+    @return: a subclass of L{netutils.IPAddress}
+    @raise errors.ProgrammerError: for unknowo IP versions
+
+    """
+    if version == constants.IP4_VERSION:
+      return IP4Address
+    elif version == constants.IP6_VERSION:
+      return IP6Address
+
+    raise errors.ProgrammerError("%s is not a valid IP version" % version)
+
+  @staticmethod
+  def GetClassFromIpFamily(family):
+    """Return the IPAddress subclass for the given IP family.
+
+    @param family: IP family (one of C{socket.AF_INET} or C{socket.AF_INET6}
+    @return: a subclass of L{netutils.IPAddress}
+    @raise errors.ProgrammerError: for unknowo IP versions
+
+    """
+    return IPAddress.GetClassFromIpVersion(
+              IPAddress.GetVersionFromAddressFamily(family))
+
   @classmethod
   def IsLoopback(cls, address):
     """Determine whether it is a loopback address.
@@ -583,7 +643,7 @@ class IP6Address(IPAddress):
       twoparts = address.split("::")
       sep = len(twoparts[0].split(":")) + len(twoparts[1].split(":"))
       parts = twoparts[0].split(":")
-      [parts.append("0") for _ in range(8 - sep)]
+      parts.extend(["0"] * (8 - sep))
       parts += twoparts[1].split(":")
     else:
       parts = address.split(":")