Avoid absolute path for privileged commands
[ganeti-local] / lib / utils.py
index 105364e..6877dc7 100644 (file)
@@ -41,6 +41,7 @@ import select
 import fcntl
 import resource
 import logging
+import logging.handlers
 import signal
 
 from cStringIO import StringIO
@@ -117,13 +118,13 @@ class RunResult(object):
   output = property(_GetOutput, None, None, "Return full output")
 
 
-def RunCmd(cmd, env=None, output=None, cwd='/'):
+def RunCmd(cmd, env=None, output=None, cwd='/', reset_env=False):
   """Execute a (shell) command.
 
   The command should not read from its standard input, as it will be
   closed.
 
-  @type  cmd: string or list
+  @type cmd: string or list
   @param cmd: Command to run
   @type env: dict
   @param env: Additional environment
@@ -134,6 +135,8 @@ def RunCmd(cmd, env=None, output=None, cwd='/'):
   @type cwd: string
   @param cwd: if specified, will be used as the working
       directory for the command; the default will be /
+  @type reset_env: boolean
+  @param reset_env: whether to reset or keep the default os environment
   @rtype: L{RunResult}
   @return: RunResult instance
   @raise errors.ProgrammerError: if we call this when forks are disabled
@@ -151,8 +154,12 @@ def RunCmd(cmd, env=None, output=None, cwd='/'):
     shell = True
   logging.debug("RunCmd '%s'", strcmd)
 
-  cmd_env = os.environ.copy()
-  cmd_env["LC_ALL"] = "C"
+  if not reset_env:
+    cmd_env = os.environ.copy()
+    cmd_env["LC_ALL"] = "C"
+  else:
+    cmd_env = {}
+
   if env is not None:
     cmd_env.update(env)
 
@@ -281,6 +288,43 @@ def _RunCmdFile(cmd, env, via_shell, output, cwd):
   return status
 
 
+def RunParts(dir_name, env=None, reset_env=False):
+  """Run Scripts or programs in a directory
+
+  @type dir_name: string
+  @param dir_name: absolute path to a directory
+  @type env: dict
+  @param env: The environment to use
+  @type reset_env: boolean
+  @param reset_env: whether to reset or keep the default os environment
+  @rtype: list of tuples
+  @return: list of (name, (one of RUNDIR_STATUS), RunResult)
+
+  """
+  rr = []
+
+  try:
+    dir_contents = ListVisibleFiles(dir_name)
+  except OSError, err:
+    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
+    return rr
+
+  for relname in sorted(dir_contents):
+    fname = os.path.join(dir_name, relname)
+    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
+            constants.EXT_PLUGIN_MASK.match(relname) is not None):
+      rr.append((relname, constants.RUNPARTS_SKIP, None))
+    else:
+      try:
+        result = RunCmd([fname], env=env, reset_env=reset_env)
+      except Exception, err: # pylint: disable-msg=W0703
+        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
+      else:
+        rr.append((relname, constants.RUNPARTS_RUN, result))
+
+  return rr
+
+
 def RemoveFile(filename):
   """Remove a file ignoring some errors.
 
@@ -333,6 +377,29 @@ def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
     raise
 
 
+def ResetTempfileModule():
+  """Resets the random name generator of the tempfile module.
+
+  This function should be called after C{os.fork} in the child process to
+  ensure it creates a newly seeded random generator. Otherwise it would
+  generate the same random parts as the parent process. If several processes
+  race for the creation of a temporary file, this could lead to one not getting
+  a temporary name.
+
+  """
+  # pylint: disable-msg=W0212
+  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
+    tempfile._once_lock.acquire()
+    try:
+      # Reset random name generator
+      tempfile._name_sequence = None
+    finally:
+      tempfile._once_lock.release()
+  else:
+    logging.critical("The tempfile module misses at least one of the"
+                     " '_once_lock' and '_name_sequence' attributes")
+
+
 def _FingerprintFile(filename):
   """Compute the fingerprint of a file.
 
@@ -491,7 +558,7 @@ def ReadPidFile(pidfile):
 
   try:
     pid = int(raw_data)
-  except ValueError, err:
+  except (TypeError, ValueError), err:
     logging.info("Can't parse pid file contents", exc_info=True)
     return 0
 
@@ -1369,14 +1436,14 @@ def FirstFree(seq, base=0):
   return None
 
 
-def all(seq, pred=bool):
+def all(seq, pred=bool): # pylint: disable-msg=W0622
   "Returns True if pred(x) is True for every element in the iterable"
   for _ in itertools.ifilterfalse(pred, seq):
     return False
   return True
 
 
-def any(seq, pred=bool):
+def any(seq, pred=bool): # pylint: disable-msg=W0622
   "Returns True if pred(x) is True for at least one element in the iterable"
   for _ in itertools.ifilter(pred, seq):
     return True
@@ -1398,20 +1465,26 @@ def UniqueSequence(seq):
   return [i for i in seq if i not in seen and not seen.add(i)]
 
 
-def IsValidMac(mac):
-  """Predicate to check if a MAC address is valid.
+def NormalizeAndValidateMac(mac):
+  """Normalizes and check if a MAC address is valid.
 
   Checks whether the supplied MAC address is formally correct, only
-  accepts colon separated format.
+  accepts colon separated format. Normalize it to all lower.
 
   @type mac: str
   @param mac: the MAC to be validated
-  @rtype: boolean
-  @return: True is the MAC seems valid
+  @rtype: str
+  @return: returns the normalized and validated MAC.
+
+  @raise errors.OpPrereqError: If the MAC isn't valid
 
   """
-  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$")
-  return mac_check.match(mac) is not None
+  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
+  if not mac_check.match(mac):
+    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
+                               mac, errors.ECODE_INVAL)
+
+  return mac.lower()
 
 
 def TestDelay(duration):
@@ -1493,6 +1566,8 @@ def Daemonize(logfile):
   @return: the value zero
 
   """
+  # pylint: disable-msg=W0212
+  # yes, we really want os._exit
   UMASK = 077
   WORKDIR = "/"
 
@@ -1535,6 +1610,19 @@ def DaemonPidFileName(name):
   return os.path.join(constants.RUN_GANETI_DIR, "%s.pid" % name)
 
 
+def EnsureDaemon(name):
+  """Check for and start daemon if not alive.
+
+  """
+  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
+  if result.failed:
+    logging.error("Can't start daemon '%s', failure %s, output: %s",
+                  name, result.fail_reason, result.output)
+    return False
+
+  return True
+
+
 def WritePidFile(name):
   """Write the current process pidfile.
 
@@ -1567,7 +1655,7 @@ def RemovePidFile(name):
   # TODO: we could check here that the file contains our pid
   try:
     RemoveFile(pidfilename)
-  except:
+  except: # pylint: disable-msg=W0702
     pass
 
 
@@ -1753,14 +1841,14 @@ def GetDaemonPort(daemon_name):
   return port
 
 
-def SetupLogging(logfile, debug=False, stderr_logging=False, program="",
-                 multithreaded=False):
+def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
+                 multithreaded=False, syslog=constants.SYSLOG_USAGE):
   """Configures the logging module.
 
   @type logfile: str
   @param logfile: the filename to which we should log
-  @type debug: boolean
-  @param debug: whether to enable debug messages too or
+  @type debug: integer
+  @param debug: if greater than zero, enable debug messages, otherwise
       only those at C{INFO} and above level
   @type stderr_logging: boolean
   @param stderr_logging: whether we should also log to the standard error
@@ -1768,17 +1856,29 @@ def SetupLogging(logfile, debug=False, stderr_logging=False, program="",
   @param program: the name under which we should log messages
   @type multithreaded: boolean
   @param multithreaded: if True, will add the thread name to the log file
+  @type syslog: string
+  @param syslog: one of 'no', 'yes', 'only':
+      - if no, syslog is not used
+      - if yes, syslog is used (in addition to file-logging)
+      - if only, only syslog is used
   @raise EnvironmentError: if we can't open the log file and
-      stderr logging is disabled
+      syslog/stderr logging is disabled
 
   """
   fmt = "%(asctime)s: " + program + " pid=%(process)d"
+  sft = program + "[%(process)d]:"
   if multithreaded:
     fmt += "/%(threadName)s"
+    sft += " (%(threadName)s)"
   if debug:
     fmt += " %(module)s:%(lineno)s"
+    # no debug info for syslog loggers
   fmt += " %(levelname)s %(message)s"
+  # yes, we do want the textual level, as remote syslog will probably
+  # lose the error level, and it's easier to grep for it
+  sft += " %(levelname)s %(message)s"
   formatter = logging.Formatter(fmt)
+  sys_fmt = logging.Formatter(sft)
 
   root_logger = logging.getLogger("")
   root_logger.setLevel(logging.NOTSET)
@@ -1797,24 +1897,34 @@ def SetupLogging(logfile, debug=False, stderr_logging=False, program="",
       stderr_handler.setLevel(logging.CRITICAL)
     root_logger.addHandler(stderr_handler)
 
-  # this can fail, if the logging directories are not setup or we have
-  # a permisssion problem; in this case, it's best to log but ignore
-  # the error if stderr_logging is True, and if false we re-raise the
-  # exception since otherwise we could run but without any logs at all
-  try:
-    logfile_handler = logging.FileHandler(logfile)
-    logfile_handler.setFormatter(formatter)
-    if debug:
-      logfile_handler.setLevel(logging.DEBUG)
-    else:
-      logfile_handler.setLevel(logging.INFO)
-    root_logger.addHandler(logfile_handler)
-  except EnvironmentError:
-    if stderr_logging:
-      logging.exception("Failed to enable logging to file '%s'", logfile)
-    else:
-      # we need to re-raise the exception
-      raise
+  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
+    facility = logging.handlers.SysLogHandler.LOG_DAEMON
+    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
+                                                    facility)
+    syslog_handler.setFormatter(sys_fmt)
+    # Never enable debug over syslog
+    syslog_handler.setLevel(logging.INFO)
+    root_logger.addHandler(syslog_handler)
+
+  if syslog != constants.SYSLOG_ONLY:
+    # this can fail, if the logging directories are not setup or we have
+    # a permisssion problem; in this case, it's best to log but ignore
+    # the error if stderr_logging is True, and if false we re-raise the
+    # exception since otherwise we could run but without any logs at all
+    try:
+      logfile_handler = logging.FileHandler(logfile)
+      logfile_handler.setFormatter(formatter)
+      if debug:
+        logfile_handler.setLevel(logging.DEBUG)
+      else:
+        logfile_handler.setLevel(logging.INFO)
+      root_logger.addHandler(logfile_handler)
+    except EnvironmentError:
+      if stderr_logging or syslog == constants.SYSLOG_YES:
+        logging.exception("Failed to enable logging to file '%s'", logfile)
+      else:
+        # we need to re-raise the exception
+        raise
 
 
 def IsNormAbsPath(path):
@@ -1889,6 +1999,48 @@ def SafeEncode(text):
   return resu
 
 
+def UnescapeAndSplit(text, sep=","):
+  """Split and unescape a string based on a given separator.
+
+  This function splits a string based on a separator where the
+  separator itself can be escape in order to be an element of the
+  elements. The escaping rules are (assuming coma being the
+  separator):
+    - a plain , separates the elements
+    - a sequence \\\\, (double backslash plus comma) is handled as a
+      backslash plus a separator comma
+    - a sequence \, (backslash plus comma) is handled as a
+      non-separator comma
+
+  @type text: string
+  @param text: the string to split
+  @type sep: string
+  @param text: the separator
+  @rtype: string
+  @return: a list of strings
+
+  """
+  # we split the list by sep (with no escaping at this stage)
+  slist = text.split(sep)
+  # next, we revisit the elements and if any of them ended with an odd
+  # number of backslashes, then we join it with the next
+  rlist = []
+  while slist:
+    e1 = slist.pop(0)
+    if e1.endswith("\\"):
+      num_b = len(e1) - len(e1.rstrip("\\"))
+      if num_b % 2 == 1:
+        e2 = slist.pop(0)
+        # here the backslashes remain (all), and will be reduced in
+        # the next step
+        rlist.append(e1 + sep + e2)
+        continue
+    rlist.append(e1)
+  # finally, replace backslash-something with something
+  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
+  return rlist
+
+
 def CommaJoin(names):
   """Nicely join a set of identifiers.
 
@@ -1946,6 +2098,53 @@ def GetFilesystemStats(path):
   return (tsize, fsize)
 
 
+def RunInSeparateProcess(fn):
+  """Runs a function in a separate process.
+
+  Note: Only boolean return values are supported.
+
+  @type fn: callable
+  @param fn: Function to be called
+  @rtype: tuple of (int/None, int/None)
+  @return: Exit code and signal number
+
+  """
+  pid = os.fork()
+  if pid == 0:
+    # Child process
+    try:
+      # In case the function uses temporary files
+      ResetTempfileModule()
+
+      # Call function
+      result = int(bool(fn()))
+      assert result in (0, 1)
+    except: # pylint: disable-msg=W0702
+      logging.exception("Error while calling function in separate process")
+      # 0 and 1 are reserved for the return value
+      result = 33
+
+    os._exit(result) # pylint: disable-msg=W0212
+
+  # Parent process
+
+  # Avoid zombies and check exit code
+  (_, status) = os.waitpid(pid, 0)
+
+  if os.WIFSIGNALED(status):
+    exitcode = None
+    signum = os.WTERMSIG(status)
+  else:
+    exitcode = os.WEXITSTATUS(status)
+    signum = None
+
+  if not (exitcode in (0, 1) and signum is None):
+    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
+                              (exitcode, signum))
+
+  return bool(exitcode)
+
+
 def LockedMethod(fn):
   """Synchronized object access decorator.
 
@@ -1958,6 +2157,7 @@ def LockedMethod(fn):
       logging.debug(*args, **kwargs)
 
   def wrapper(self, *args, **kwargs):
+    # pylint: disable-msg=W0212
     assert hasattr(self, '_lock')
     lock = self._lock
     _LockDebug("Waiting for %s", lock)
@@ -2098,7 +2298,7 @@ class _RetryDelayCalculator(object):
 
     # Update for next run
     if self._limit is None or self._next < self._limit:
-      self._next = max(self._limit, self._next * self._factor)
+      self._next = min(self._limit, self._next * self._factor)
 
     return current
 
@@ -2164,6 +2364,7 @@ def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
 
   while True:
     try:
+      # pylint: disable-msg=W0142
       return fn(*args)
     except RetryAgain:
       pass
@@ -2206,7 +2407,7 @@ class FileLock(object):
     """Close the file and release the lock.
 
     """
-    if self.fd:
+    if hasattr(self, "fd") and self.fd:
       self.fd.close()
       self.fd = None
 
@@ -2405,7 +2606,8 @@ class SignalHandler(object):
     """
     self.called = False
 
-  def _HandleSignal(self, signum, frame):
+  # we don't care about arguments, but we leave them named for the future
+  def _HandleSignal(self, signum, frame): # pylint: disable-msg=W0613
     """Actual signal handling function.
 
     """