Remove the obsolete EvacuateNode OpCode/LU
[ganeti-local] / lib / utils.py
index 59ecaea..b19e5a4 100644 (file)
@@ -55,12 +55,13 @@ import IN
 from cStringIO import StringIO
 
 try:
-  from hashlib import sha1
+  import ctypes
 except ImportError:
-  import sha as sha1
+  ctypes = None
 
 from ganeti import errors
 from ganeti import constants
+from ganeti import compat
 
 
 _locksheld = []
@@ -80,6 +81,8 @@ X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
                              HEX_CHAR_RE, HEX_CHAR_RE),
                             re.S | re.I)
 
+_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$")
+
 # Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
 # struct ucred { pid_t pid; uid_t uid; gid_t gid; };
 #
@@ -92,6 +95,14 @@ X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
 _STRUCT_UCRED = "iII"
 _STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)
 
+# Certificate verification results
+(CERT_WARNING,
+ CERT_ERROR) = range(1, 3)
+
+# Flags for mlockall() (from bits/mman.h)
+_MCL_CURRENT = 1
+_MCL_FUTURE = 2
+
 
 class RunResult(object):
   """Holds the result of running external programs.
@@ -560,7 +571,9 @@ def RetryOnSignal(fn, *args, **kwargs):
     except EnvironmentError, err:
       if err.errno != errno.EINTR:
         raise
-    except select.error, err:
+    except (socket.error, select.error), err:
+      # In python 2.6 and above select.error is an IOError, so it's handled
+      # above, in 2.5 and below it's not, and it's handled here.
       if not (err.args and err.args[0] == errno.EINTR):
         raise
 
@@ -632,6 +645,24 @@ def RemoveFile(filename):
       raise
 
 
+def RemoveDir(dirname):
+  """Remove an empty directory.
+
+  Remove a directory, ignoring non-existing ones.
+  Other errors are passed. This includes the case,
+  where the directory is not empty, so it can't be removed.
+
+  @type dirname: str
+  @param dirname: the empty directory to be removed
+
+  """
+  try:
+    os.rmdir(dirname)
+  except OSError, err:
+    if err.errno != errno.ENOENT:
+      raise
+
+
 def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
   """Renames a file.
 
@@ -717,10 +748,7 @@ def _FingerprintFile(filename):
 
   f = open(filename)
 
-  if callable(sha1):
-    fp = sha1()
-  else:
-    fp = sha1.new()
+  fp = compat.sha1_hash()
   while True:
     data = f.read(4096)
     if not data:
@@ -818,6 +846,17 @@ def ForceDictType(target, key_types, allowed_values=None):
         raise errors.TypeEnforcementError(msg)
 
 
+def _GetProcStatusPath(pid):
+  """Returns the path for a PID's proc status file.
+
+  @type pid: int
+  @param pid: Process ID
+  @rtype: string
+
+  """
+  return "/proc/%d/status" % pid
+
+
 def IsProcessAlive(pid):
   """Check if a given pid exists on the system.
 
@@ -829,17 +868,113 @@ def IsProcessAlive(pid):
   @return: True if the process exists
 
   """
+  def _TryStat(name):
+    try:
+      os.stat(name)
+      return True
+    except EnvironmentError, err:
+      if err.errno in (errno.ENOENT, errno.ENOTDIR):
+        return False
+      elif err.errno == errno.EINVAL:
+        raise RetryAgain(err)
+      raise
+
+  assert isinstance(pid, int), "pid must be an integer"
   if pid <= 0:
     return False
 
+  # /proc in a multiprocessor environment can have strange behaviors.
+  # Retry the os.stat a few times until we get a good result.
+  try:
+    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
+                 args=[_GetProcStatusPath(pid)])
+  except RetryTimeout, err:
+    err.RaiseInner()
+
+
+def _ParseSigsetT(sigset):
+  """Parse a rendered sigset_t value.
+
+  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
+  function.
+
+  @type sigset: string
+  @param sigset: Rendered signal set from /proc/$pid/status
+  @rtype: set
+  @return: Set of all enabled signal numbers
+
+  """
+  result = set()
+
+  signum = 0
+  for ch in reversed(sigset):
+    chv = int(ch, 16)
+
+    # The following could be done in a loop, but it's easier to read and
+    # understand in the unrolled form
+    if chv & 1:
+      result.add(signum + 1)
+    if chv & 2:
+      result.add(signum + 2)
+    if chv & 4:
+      result.add(signum + 3)
+    if chv & 8:
+      result.add(signum + 4)
+
+    signum += 4
+
+  return result
+
+
+def _GetProcStatusField(pstatus, field):
+  """Retrieves a field from the contents of a proc status file.
+
+  @type pstatus: string
+  @param pstatus: Contents of /proc/$pid/status
+  @type field: string
+  @param field: Name of field whose value should be returned
+  @rtype: string
+
+  """
+  for line in pstatus.splitlines():
+    parts = line.split(":", 1)
+
+    if len(parts) < 2 or parts[0] != field:
+      continue
+
+    return parts[1].strip()
+
+  return None
+
+
+def IsProcessHandlingSignal(pid, signum, status_path=None):
+  """Checks whether a process is handling a signal.
+
+  @type pid: int
+  @param pid: Process ID
+  @type signum: int
+  @param signum: Signal number
+  @rtype: bool
+
+  """
+  if status_path is None:
+    status_path = _GetProcStatusPath(pid)
+
   try:
-    os.stat("/proc/%d/status" % pid)
-    return True
+    proc_status = ReadFile(status_path)
   except EnvironmentError, err:
-    if err.errno in (errno.ENOENT, errno.ENOTDIR):
+    # In at least one case, reading /proc/$pid/status failed with ESRCH.
+    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
       return False
     raise
 
+  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
+  if sigcgt is None:
+    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)
+
+  # Now check whether signal is handled
+  return signum in _ParseSigsetT(sigcgt)
+
 
 def ReadPidFile(pidfile):
   """Read a pid from a file.
@@ -852,7 +987,7 @@ def ReadPidFile(pidfile):
 
   """
   try:
-    raw_data = ReadFile(pidfile)
+    raw_data = ReadOneLineFile(pidfile)
   except EnvironmentError, err:
     if err.errno != errno.ENOENT:
       logging.exception("Can't read pid file")
@@ -994,8 +1129,9 @@ class HostInfo:
     """
     try:
       result = socket.gethostbyname_ex(hostname)
-    except socket.gaierror, err:
-      # hostname not found in DNS
+    except (socket.gaierror, socket.herror, socket.error), err:
+      # hostname not found in DNS, or other socket exception in the
+      # (code, description format)
       raise errors.ResolverError(hostname, err.args[0], err.args[1])
 
     return result
@@ -1022,6 +1158,30 @@ class HostInfo:
     return hostname
 
 
+def ValidateServiceName(name):
+  """Validate the given service name.
+
+  @type name: number or string
+  @param name: Service name or port specification
+
+  """
+  try:
+    numport = int(name)
+  except (ValueError, TypeError):
+    # Non-numeric service name
+    valid = _VALID_SERVICE_NAME_RE.match(name)
+  else:
+    # Numeric port (protocols other than TCP or UDP might need adjustments
+    # here)
+    valid = (numport >= 0 and numport < (1 << 16))
+
+  if not valid:
+    raise errors.OpPrereqError("Invalid service name '%s'" % name,
+                               errors.ECODE_INVAL)
+
+  return name
+
+
 def GetHostInfo(name=None):
   """Lookup host name and raise an OpPrereqError for failures"""
 
@@ -1596,7 +1756,6 @@ def ListVisibleFiles(path):
     raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
                                  " absolute/normalized: '%s'" % path)
   files = [i for i in os.listdir(path) if not i.startswith(".")]
-  files.sort()
   return files
 
 
@@ -1661,6 +1820,11 @@ def EnsureDirs(dirs):
       if err.errno != errno.EEXIST:
         raise errors.GenericError("Cannot create needed directory"
                                   " '%s': %s" % (dir_name, err))
+    try:
+      os.chmod(dir_name, dir_mode)
+    except EnvironmentError, err:
+      raise errors.GenericError("Cannot change directory permissions on"
+                                " '%s': %s" % (dir_name, err))
     if not os.path.isdir(dir_name):
       raise errors.GenericError("%s is not a directory" % dir_name)
 
@@ -1780,6 +1944,24 @@ def WriteFile(file_name, fn=None, data=None,
   return result
 
 
+def ReadOneLineFile(file_name, strict=False):
+  """Return the first non-empty line from a file.
+
+  @type strict: boolean
+  @param strict: if True, abort if the file has more than one
+      non-empty line
+
+  """
+  file_lines = ReadFile(file_name).splitlines()
+  full_lines = filter(bool, file_lines)
+  if not file_lines or not full_lines:
+    raise errors.GenericError("No data in one-liner file %s" % file_name)
+  elif strict and len(full_lines) > 1:
+    raise errors.GenericError("Too many lines in one-liner file %s" %
+                              file_name)
+  return full_lines[0]
+
+
 def FirstFree(seq, base=0):
   """Returns the first non-existing integer from seq.
 
@@ -2003,7 +2185,40 @@ def CloseFDs(noclose_fds=None):
     _CloseFDNoErr(fd)
 
 
-def Daemonize(logfile):
+def Mlockall():
+  """Lock current process' virtual address space into RAM.
+
+  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
+  see mlock(2) for more details. This function requires ctypes module.
+
+  """
+  if ctypes is None:
+    logging.warning("Cannot set memory lock, ctypes module not found")
+    return
+
+  libc = ctypes.cdll.LoadLibrary("libc.so.6")
+  if libc is None:
+    logging.error("Cannot set memory lock, ctypes cannot load libc")
+    return
+
+  # Some older version of the ctypes module don't have built-in functionality
+  # to access the errno global variable, where function error codes are stored.
+  # By declaring this variable as a pointer to an integer we can then access
+  # its value correctly, should the mlockall call fail, in order to see what
+  # the actual error code was.
+  # pylint: disable-msg=W0212
+  libc.__errno_location.restype = ctypes.POINTER(ctypes.c_int)
+
+  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
+    # pylint: disable-msg=W0212
+    logging.error("Cannot set memory lock: %s",
+                  os.strerror(libc.__errno_location().contents.value))
+    return
+
+  logging.debug("Memory lock set")
+
+
+def Daemonize(logfile, run_uid, run_gid):
   """Daemonize the current process.
 
   This detaches the current process from the controlling terminal and
@@ -2011,6 +2226,10 @@ def Daemonize(logfile):
 
   @type logfile: str
   @param logfile: the logfile to which we should redirect stdout/stderr
+  @type run_uid: int
+  @param run_uid: Run the child under this uid
+  @type run_gid: int
+  @param run_gid: Run the child under this gid
   @rtype: int
   @return: the value zero
 
@@ -2024,6 +2243,11 @@ def Daemonize(logfile):
   pid = os.fork()
   if (pid == 0):  # The first child.
     os.setsid()
+    # FIXME: When removing again and moving to start-stop-daemon privilege drop
+    #        make sure to check for config permission and bail out when invoked
+    #        with wrong user.
+    os.setgid(run_gid)
+    os.setuid(run_uid)
     # this might fail
     pid = os.fork() # Fork a second child.
     if (pid == 0):  # The second child.
@@ -2072,6 +2296,19 @@ def EnsureDaemon(name):
   return True
 
 
+def StopDaemon(name):
+  """Stop daemon
+
+  """
+  result = RunCmd([constants.DAEMON_UTIL, "stop", name])
+  if result.failed:
+    logging.error("Can't stop daemon '%s', failure %s, output: %s",
+                  name, result.fail_reason, result.output)
+    return False
+
+  return True
+
+
 def WritePidFile(name):
   """Write the current process pidfile.
 
@@ -2128,8 +2365,7 @@ def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
   """
   def _helper(pid, signal_, wait):
     """Simple helper to encapsulate the kill/waitpid sequence"""
-    os.kill(pid, signal_)
-    if wait:
+    if IgnoreProcessNotFound(os.kill, pid, signal_) and wait:
       try:
         os.waitpid(pid, os.WNOHANG)
       except OSError:
@@ -2291,8 +2527,43 @@ def GetDaemonPort(daemon_name):
   return port
 
 
+class LogFileHandler(logging.FileHandler):
+  """Log handler that doesn't fallback to stderr.
+
+  When an error occurs while writing on the logfile, logging.FileHandler tries
+  to log on stderr. This doesn't work in ganeti since stderr is redirected to
+  the logfile. This class avoids failures reporting errors to /dev/console.
+
+  """
+  def __init__(self, filename, mode="a", encoding=None):
+    """Open the specified file and use it as the stream for logging.
+
+    Also open /dev/console to report errors while logging.
+
+    """
+    logging.FileHandler.__init__(self, filename, mode, encoding)
+    self.console = open(constants.DEV_CONSOLE, "a")
+
+  def handleError(self, record): # pylint: disable-msg=C0103
+    """Handle errors which occur during an emit() call.
+
+    Try to handle errors with FileHandler method, if it fails write to
+    /dev/console.
+
+    """
+    try:
+      logging.FileHandler.handleError(self, record)
+    except Exception: # pylint: disable-msg=W0703
+      try:
+        self.console.write("Cannot log message:\n%s\n" % self.format(record))
+      except Exception: # pylint: disable-msg=W0703
+        # Log handler tried everything it could, now just give up
+        pass
+
+
 def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
-                 multithreaded=False, syslog=constants.SYSLOG_USAGE):
+                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
+                 console_logging=False):
   """Configures the logging module.
 
   @type logfile: str
@@ -2311,6 +2582,9 @@ def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
       - if no, syslog is not used
       - if yes, syslog is used (in addition to file-logging)
       - if only, only syslog is used
+  @type console_logging: boolean
+  @param console_logging: if True, will use a FileHandler which falls back to
+      the system console if logging fails
   @raise EnvironmentError: if we can't open the log file and
       syslog/stderr logging is disabled
 
@@ -2362,7 +2636,10 @@ def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
     # the error if stderr_logging is True, and if false we re-raise the
     # exception since otherwise we could run but without any logs at all
     try:
-      logfile_handler = logging.FileHandler(logfile)
+      if console_logging:
+        logfile_handler = LogFileHandler(logfile)
+      else:
+        logfile_handler = logging.FileHandler(logfile)
       logfile_handler.setFormatter(formatter)
       if debug:
         logfile_handler.setLevel(logging.DEBUG)
@@ -2442,6 +2719,13 @@ def TailFile(fname, lines=20):
   return rows[-lines:]
 
 
+def FormatTimestampWithTZ(secs):
+  """Formats a Unix timestamp with the local timezone.
+
+  """
+  return time.strftime("%F %T %Z", time.gmtime(secs))
+
+
 def _ParseAsn1Generalizedtime(value):
   """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
 
@@ -2505,6 +2789,75 @@ def GetX509CertValidity(cert):
   return (not_before, not_after)
 
 
+def _VerifyCertificateInner(expired, not_before, not_after, now,
+                            warn_days, error_days):
+  """Verifies certificate validity.
+
+  @type expired: bool
+  @param expired: Whether pyOpenSSL considers the certificate as expired
+  @type not_before: number or None
+  @param not_before: Unix timestamp before which certificate is not valid
+  @type not_after: number or None
+  @param not_after: Unix timestamp after which certificate is invalid
+  @type now: number
+  @param now: Current time as Unix timestamp
+  @type warn_days: number or None
+  @param warn_days: How many days before expiration a warning should be reported
+  @type error_days: number or None
+  @param error_days: How many days before expiration an error should be reported
+
+  """
+  if expired:
+    msg = "Certificate is expired"
+
+    if not_before is not None and not_after is not None:
+      msg += (" (valid from %s to %s)" %
+              (FormatTimestampWithTZ(not_before),
+               FormatTimestampWithTZ(not_after)))
+    elif not_before is not None:
+      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
+    elif not_after is not None:
+      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)
+
+    return (CERT_ERROR, msg)
+
+  elif not_before is not None and not_before > now:
+    return (CERT_WARNING,
+            "Certificate not yet valid (valid from %s)" %
+            FormatTimestampWithTZ(not_before))
+
+  elif not_after is not None:
+    remaining_days = int((not_after - now) / (24 * 3600))
+
+    msg = "Certificate expires in about %d days" % remaining_days
+
+    if error_days is not None and remaining_days <= error_days:
+      return (CERT_ERROR, msg)
+
+    if warn_days is not None and remaining_days <= warn_days:
+      return (CERT_WARNING, msg)
+
+  return (None, None)
+
+
+def VerifyX509Certificate(cert, warn_days, error_days):
+  """Verifies a certificate for LUVerifyCluster.
+
+  @type cert: OpenSSL.crypto.X509
+  @param cert: X509 certificate object
+  @type warn_days: number or None
+  @param warn_days: How many days before expiration a warning should be reported
+  @type error_days: number or None
+  @param error_days: How many days before expiration an error should be reported
+
+  """
+  # Depending on the pyOpenSSL version, this can just return (None, None)
+  (not_before, not_after) = GetX509CertValidity(cert)
+
+  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
+                                 time.time(), warn_days, error_days)
+
+
 def SignX509Certificate(cert, key, salt):
   """Sign a X509 certificate.
 
@@ -2528,7 +2881,7 @@ def SignX509Certificate(cert, key, salt):
 
   return ("%s: %s/%s\n\n%s" %
           (constants.X509_CERT_SIGNATURE_HEADER, salt,
-           hmac.new(key, salt + cert_pem, sha1).hexdigest(),
+           Sha1Hmac(key, cert_pem, salt=salt),
            cert_pem))
 
 
@@ -2567,12 +2920,47 @@ def LoadSignedX509Certificate(cert_pem, key):
   # Dump again to ensure it's in a sane format
   sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
 
-  if signature != hmac.new(key, salt + sane_pem, sha1).hexdigest():
+  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
     raise errors.GenericError("X509 certificate signature is invalid")
 
   return (cert, salt)
 
 
+def Sha1Hmac(key, text, salt=None):
+  """Calculates the HMAC-SHA1 digest of a text.
+
+  HMAC is defined in RFC2104.
+
+  @type key: string
+  @param key: Secret key
+  @type text: string
+
+  """
+  if salt:
+    salted_text = salt + text
+  else:
+    salted_text = text
+
+  return hmac.new(key, salted_text, compat.sha1).hexdigest()
+
+
+def VerifySha1Hmac(key, text, digest, salt=None):
+  """Verifies the HMAC-SHA1 digest of a text.
+
+  HMAC is defined in RFC2104.
+
+  @type key: string
+  @param key: Secret key
+  @type text: string
+  @type digest: string
+  @param digest: Expected digest
+  @rtype: bool
+  @return: Whether HMAC-SHA1 digest matches
+
+  """
+  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()
+
+
 def SafeEncode(text):
   """Return a 'safe' version of a source string.
 
@@ -2756,6 +3144,46 @@ def RunInSeparateProcess(fn, *args):
   return bool(exitcode)
 
 
+def IgnoreProcessNotFound(fn, *args, **kwargs):
+  """Ignores ESRCH when calling a process-related function.
+
+  ESRCH is raised when a process is not found.
+
+  @rtype: bool
+  @return: Whether process was found
+
+  """
+  try:
+    fn(*args, **kwargs)
+  except EnvironmentError, err:
+    # Ignore ESRCH
+    if err.errno == errno.ESRCH:
+      return False
+    raise
+
+  return True
+
+
+def IgnoreSignals(fn, *args, **kwargs):
+  """Tries to call a function ignoring failures due to EINTR.
+
+  """
+  try:
+    return fn(*args, **kwargs)
+  except EnvironmentError, err:
+    if err.errno == errno.EINTR:
+      return None
+    else:
+      raise
+  except (select.error, socket.error), err:
+    # In python 2.6 and above select.error is an IOError, so it's handled
+    # above, in 2.5 and below it's not, and it's handled here.
+    if err.args and err.args[0] == errno.EINTR:
+      return None
+    else:
+      raise
+
+
 def LockedMethod(fn):
   """Synchronized object access decorator.
 
@@ -2814,6 +3242,31 @@ def FormatTime(val):
   return time.strftime("%F %T", time.localtime(val))
 
 
+def FormatSeconds(secs):
+  """Formats seconds for easier reading.
+
+  @type secs: number
+  @param secs: Number of seconds
+  @rtype: string
+  @return: Formatted seconds (e.g. "2d 9h 19m 49s")
+
+  """
+  parts = []
+
+  secs = round(secs, 0)
+
+  if secs > 0:
+    # Negative values would be a bit tricky
+    for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]:
+      (complete, secs) = divmod(secs, one)
+      if complete or parts:
+        parts.append("%d%s" % (complete, unit))
+
+  parts.append("%ds" % secs)
+
+  return " ".join(parts)
+
+
 def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
   """Reads the watcher pause file.
 
@@ -2860,12 +3313,25 @@ def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
 class RetryTimeout(Exception):
   """Retry loop timed out.
 
+  Any arguments which was passed by the retried function to RetryAgain will be
+  preserved in RetryTimeout, if it is raised. If such argument was an exception
+  the RaiseInner helper method will reraise it.
+
   """
+  def RaiseInner(self):
+    if self.args and isinstance(self.args[0], Exception):
+      raise self.args[0]
+    else:
+      raise RetryTimeout(*self.args)
 
 
 class RetryAgain(Exception):
   """Retry again.
 
+  Any arguments passed to RetryAgain will be preserved, if a timeout occurs, as
+  arguments to RetryTimeout. If an exception is passed, the RaiseInner() method
+  of the RetryTimeout() method can be used to reraise it.
+
   """
 
 
@@ -2974,11 +3440,12 @@ def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
   assert calc_delay is None or callable(calc_delay)
 
   while True:
+    retry_args = []
     try:
       # pylint: disable-msg=W0142
       return fn(*args)
-    except RetryAgain:
-      pass
+    except RetryAgain, err:
+      retry_args = err.args
     except RetryTimeout:
       raise errors.ProgrammerError("Nested retry loop detected that didn't"
                                    " handle RetryTimeout")
@@ -2986,7 +3453,8 @@ def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
     remaining_time = end_time - _time_fn()
 
     if remaining_time < 0.0:
-      raise RetryTimeout()
+      # pylint: disable-msg=W0142
+      raise RetryTimeout(*retry_args)
 
     assert remaining_time >= 0.0