Don't remove master's hostname from /etc/hosts on cluster destroy.
[ganeti-local] / lib / utils.py
index e9f426f..73d1bbb 100644 (file)
@@ -1,4 +1,4 @@
-#!/usr/bin/python
+#
 #
 
 # Copyright (C) 2006, 2007 Google Inc.
@@ -20,6 +20,7 @@
 
 
 """Ganeti small utilities
+
 """
 
 
@@ -33,10 +34,13 @@ import socket
 import tempfile
 import shutil
 import errno
+import pwd
+import itertools
 
 from ganeti import logger
 from ganeti import errors
 
+
 _locksheld = []
 _re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
 
@@ -397,38 +401,57 @@ def MatchNameComponent(key, name_list):
 
 
 class HostInfo:
-  """Class holding host info as returned by gethostbyname
+  """Class implementing resolver and hostname functionality
 
   """
-  def __init__(self, name, aliases, ipaddrs):
+  def __init__(self, name=None):
     """Initialize the host name object.
 
-    Arguments are the same as returned by socket.gethostbyname_ex()
+    If the name argument is not passed, it will use this system's
+    name.
 
     """
-    self.name = name
-    self.aliases = aliases
-    self.ipaddrs = ipaddrs
+    if name is None:
+      name = self.SysName()
+
+    self.query = name
+    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
     self.ip = self.ipaddrs[0]
 
+  def ShortName(self):
+    """Returns the hostname without domain.
 
-def LookupHostname(hostname):
-  """Look up hostname
+    """
+    return self.name.split('.')[0]
 
-  Args:
-    hostname: hostname to look up, can be also be a non FQDN
+  @staticmethod
+  def SysName():
+    """Return the current system's name.
 
-  Returns:
-    a HostInfo object
+    This is simply a wrapper over socket.gethostname()
 
-  """
-  try:
-    (name, aliases, ipaddrs) = socket.gethostbyname_ex(hostname)
-  except socket.gaierror:
-    # hostname not found in DNS
-    return None
+    """
+    return socket.gethostname()
+
+  @staticmethod
+  def LookupHostname(hostname):
+    """Look up hostname
 
-  return HostInfo(name, aliases, ipaddrs)
+    Args:
+      hostname: hostname to look up
+
+    Returns:
+      a tuple (name, aliases, ipaddrs) as returned by socket.gethostbyname_ex
+      in case of errors in resolving, we raise a ResolverError
+
+    """
+    try:
+      result = socket.gethostbyname_ex(hostname)
+    except socket.gaierror, err:
+      # hostname not found in DNS
+      raise errors.ResolverError(hostname, err.args[0], err.args[1])
+
+    return result
 
 
 def ListVolumeGroups():
@@ -708,21 +731,95 @@ def RemoveAuthorizedKey(file_name, key):
   key_fields = key.split()
 
   fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
-  out = os.fdopen(fd, 'w')
   try:
-    f = open(file_name, 'r')
+    out = os.fdopen(fd, 'w')
     try:
-      for line in f:
-        # Ignore whitespace changes while comparing lines
-        if line.split() != key_fields:
+      f = open(file_name, 'r')
+      try:
+        for line in f:
+          # Ignore whitespace changes while comparing lines
+          if line.split() != key_fields:
+            out.write(line)
+
+        out.flush()
+        os.rename(tmpname, file_name)
+      finally:
+        f.close()
+    finally:
+      out.close()
+  except:
+    RemoveFile(tmpname)
+    raise
+
+
+def SetEtcHostsEntry(file_name, ip, hostname, aliases):
+  """Sets the name of an IP address and hostname in /etc/hosts.
+
+  """
+  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
+  try:
+    out = os.fdopen(fd, 'w')
+    try:
+      f = open(file_name, 'r')
+      try:
+        written = False
+        for line in f:
+          fields = line.split()
+          if fields and not fields[0].startswith('#') and ip == fields[0]:
+            continue
           out.write(line)
 
-      out.flush()
-      os.rename(tmpname, file_name)
+        out.write("%s\t%s" % (ip, hostname))
+        if aliases:
+          out.write(" %s" % ' '.join(aliases))
+        out.write('\n')
+
+        out.flush()
+        os.fsync(out)
+        os.rename(tmpname, file_name)
+      finally:
+        f.close()
     finally:
-      f.close()
-  finally:
-    out.close()
+      out.close()
+  except:
+    RemoveFile(tmpname)
+    raise
+
+
+def RemoveEtcHostsEntry(file_name, hostname):
+  """Removes a hostname from /etc/hosts.
+
+  IP addresses without names are removed from the file.
+  """
+  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
+  try:
+    out = os.fdopen(fd, 'w')
+    try:
+      f = open(file_name, 'r')
+      try:
+        for line in f:
+          fields = line.split()
+          if len(fields) > 1 and not fields[0].startswith('#'):
+            names = fields[1:]
+            if hostname in names:
+              while hostname in names:
+                names.remove(hostname)
+              if names:
+                out.write("%s %s\n" % (fields[0], ' '.join(names)))
+              continue
+
+          out.write(line)
+
+        out.flush()
+        os.fsync(out)
+        os.rename(tmpname, file_name)
+      finally:
+        f.close()
+    finally:
+      out.close()
+  except:
+    RemoveFile(tmpname)
+    raise
 
 
 def CreateBackup(file_name):
@@ -735,10 +832,20 @@ def CreateBackup(file_name):
     raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
                                 file_name)
 
-  # Warning: the following code contains a race condition when we create more
-  # than one backup of the same file in a second.
-  backup_name = file_name + '.backup-%d' % int(time.time())
-  shutil.copyfile(file_name, backup_name)
+  prefix = '%s.backup-%d.' % (os.path.basename(file_name), int(time.time()))
+  dir_name = os.path.dirname(file_name)
+
+  fsrc = open(file_name, 'rb')
+  try:
+    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
+    fdst = os.fdopen(fd, 'wb')
+    try:
+      shutil.copyfileobj(fsrc, fdst)
+    finally:
+      fdst.close()
+  finally:
+    fsrc.close()
+
   return backup_name
 
 
@@ -759,42 +866,8 @@ def ShellQuoteArgs(args):
   return ' '.join([ShellQuote(i) for i in args])
 
 
-def _ParseIpOutput(output):
-  """Parsing code for GetLocalIPAddresses().
-
-  This function is split out, so we can unit test it.
-
-  """
-  re_ip = re.compile('^(\d+\.\d+\.\d+\.\d+)(?:/\d+)$')
-
-  ips = []
-  for line in output.splitlines(False):
-    fields = line.split()
-    if len(line) < 4:
-      continue
-    m = re_ip.match(fields[3])
-    if m:
-      ips.append(m.group(1))
-
-  return ips
-
-
-def GetLocalIPAddresses():
-  """Gets a list of all local IP addresses.
-
-  Should this break one day, a small Python module written in C could
-  use the API call getifaddrs().
 
-  """
-  result = RunCmd(["ip", "-family", "inet", "-oneline", "addr", "show"])
-  if result.failed:
-    raise errors.OpExecError("Command '%s' failed, error: %s,"
-      " output: %s" % (result.cmd, result.fail_reason, result.output))
-
-  return _ParseIpOutput(result.output)
-
-
-def TcpPing(source, target, port, timeout=10, live_port_needed=True):
+def TcpPing(source, target, port, timeout=10, live_port_needed=False):
   """Simple ping implementation using TCP connect(2).
 
   Try to do a TCP connect(2) from the specified source IP to the specified
@@ -825,3 +898,107 @@ def TcpPing(source, target, port, timeout=10, live_port_needed=True):
     success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
 
   return success
+
+
+def ListVisibleFiles(path):
+  """Returns a list of all visible files in a directory.
+
+  """
+  return [i for i in os.listdir(path) if not i.startswith(".")]
+
+
+def GetHomeDir(user, default=None):
+  """Try to get the homedir of the given user.
+
+  The user can be passed either as a string (denoting the name) or as
+  an integer (denoting the user id). If the user is not found, the
+  'default' argument is returned, which defaults to None.
+
+  """
+  try:
+    if isinstance(user, basestring):
+      result = pwd.getpwnam(user)
+    elif isinstance(user, (int, long)):
+      result = pwd.getpwuid(user)
+    else:
+      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
+                                   type(user))
+  except KeyError:
+    return default
+  return result.pw_dir
+
+
+def NewUUID():
+  """Returns a random UUID.
+
+  """
+  f = open("/proc/sys/kernel/random/uuid", "r")
+  try:
+    return f.read(128).rstrip("\n")
+  finally:
+    f.close()
+
+
+def WriteFile(file_name, fn=None, data=None,
+              mode=None, uid=-1, gid=-1,
+              atime=None, mtime=None):
+  """(Over)write a file atomically.
+
+  The file_name and either fn (a function taking one argument, the
+  file descriptor, and which should write the data to it) or data (the
+  contents of the file) must be passed. The other arguments are
+  optional and allow setting the file mode, owner and group, and the
+  mtime/atime of the file.
+
+  If the function doesn't raise an exception, it has succeeded and the
+  target file has the new contents. If the file has raised an
+  exception, an existing target file should be unmodified and the
+  temporary file should be removed.
+
+  """
+  if not os.path.isabs(file_name):
+    raise errors.ProgrammerError("Path passed to WriteFile is not"
+                                 " absolute: '%s'" % file_name)
+
+  if [fn, data].count(None) != 1:
+    raise errors.ProgrammerError("fn or data required")
+
+  if [atime, mtime].count(None) == 1:
+    raise errors.ProgrammerError("Both atime and mtime must be either"
+                                 " set or None")
+
+
+  dir_name, base_name = os.path.split(file_name)
+  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
+  # here we need to make sure we remove the temp file, if any error
+  # leaves it in place
+  try:
+    if uid != -1 or gid != -1:
+      os.chown(new_name, uid, gid)
+    if mode:
+      os.chmod(new_name, mode)
+    if data is not None:
+      os.write(fd, data)
+    else:
+      fn(fd)
+    os.fsync(fd)
+    if atime is not None and mtime is not None:
+      os.utime(new_name, (atime, mtime))
+    os.rename(new_name, file_name)
+  finally:
+    os.close(fd)
+    RemoveFile(new_name)
+
+
+def all(seq, pred=bool):
+  "Returns True if pred(x) is True for every element in the iterable"
+  for elem in itertools.ifilterfalse(pred, seq):
+    return False
+  return True
+
+
+def any(seq, pred=bool):
+  "Returns True if pred(x) is True for at least one element in the iterable"
+  for elem in itertools.ifilter(pred, seq):
+    return True
+  return False