Merge branch 'devel-2.1'
[ganeti-local] / lib / utils.py
index e876e2c..f08e75c 100644 (file)
@@ -27,7 +27,6 @@ the command line scripts.
 """
 
 
-import sys
 import os
 import time
 import subprocess
@@ -59,12 +58,13 @@ from ganeti import constants
 _locksheld = []
 _re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
 
-debug = False
 debug_locks = False
 
 #: when set to True, L{RunCmd} is disabled
 no_fork = False
 
+_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
+
 
 class RunResult(object):
   """Holds the result of running external programs.
@@ -136,7 +136,7 @@ def RunCmd(cmd, env=None, output=None, cwd='/'):
       directory for the command; the default will be /
   @rtype: L{RunResult}
   @return: RunResult instance
-  @raise erors.ProgrammerError: if we call this when forks are disabled
+  @raise errors.ProgrammerError: if we call this when forks are disabled
 
   """
   if no_fork:
@@ -319,8 +319,17 @@ def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
     # as efficient.
     if mkdir and err.errno == errno.ENOENT:
       # Create directory and try again
-      os.makedirs(os.path.dirname(new), mkdir_mode)
+      dirname = os.path.dirname(new)
+      try:
+        os.makedirs(dirname, mode=mkdir_mode)
+      except OSError, err:
+        # Ignore EEXIST. This is only handled in os.makedirs as included in
+        # Python 2.5 and above.
+        if err.errno != errno.EEXIST or not os.path.exists(dirname):
+          raise
+
       return os.rename(old, new)
+
     raise
 
 
@@ -373,32 +382,6 @@ def FingerprintFiles(files):
   return ret
 
 
-def CheckDict(target, template, logname=None):
-  """Ensure a dictionary has a required set of keys.
-
-  For the given dictionaries I{target} and I{template}, ensure
-  I{target} has all the keys from I{template}. Missing keys are added
-  with values from template.
-
-  @type target: dict
-  @param target: the dictionary to update
-  @type template: dict
-  @param template: the dictionary holding the default values
-  @type logname: str or None
-  @param logname: if not None, causes the missing keys to be
-      logged with this name
-
-  """
-  missing = []
-  for k in template:
-    if k not in target:
-      missing.append(k)
-      target[k] = template[k]
-
-  if missing and logname:
-    logging.warning('%s missing keys %s', logname, ', '.join(missing))
-
-
 def ForceDictType(target, key_types, allowed_values=None):
   """Force the values of a dict to have certain types.
 
@@ -414,6 +397,10 @@ def ForceDictType(target, key_types, allowed_values=None):
   if allowed_values is None:
     allowed_values = []
 
+  if not isinstance(target, dict):
+    msg = "Expected dictionary, got '%s'" % target
+    raise errors.TypeEnforcementError(msg)
+
   for key in target:
     if key not in key_types:
       msg = "Unknown key '%s'" % key
@@ -422,19 +409,19 @@ def ForceDictType(target, key_types, allowed_values=None):
     if target[key] in allowed_values:
       continue
 
-    type = key_types[key]
-    if type not in constants.ENFORCEABLE_TYPES:
-      msg = "'%s' has non-enforceable type %s" % (key, type)
+    ktype = key_types[key]
+    if ktype not in constants.ENFORCEABLE_TYPES:
+      msg = "'%s' has non-enforceable type %s" % (key, ktype)
       raise errors.ProgrammerError(msg)
 
-    if type == constants.VTYPE_STRING:
+    if ktype == constants.VTYPE_STRING:
       if not isinstance(target[key], basestring):
         if isinstance(target[key], bool) and not target[key]:
           target[key] = ''
         else:
           msg = "'%s' (value %s) is not a valid string" % (key, target[key])
           raise errors.TypeEnforcementError(msg)
-    elif type == constants.VTYPE_BOOL:
+    elif ktype == constants.VTYPE_BOOL:
       if isinstance(target[key], basestring) and target[key]:
         if target[key].lower() == constants.VALUE_FALSE:
           target[key] = False
@@ -447,14 +434,14 @@ def ForceDictType(target, key_types, allowed_values=None):
         target[key] = True
       else:
         target[key] = False
-    elif type == constants.VTYPE_SIZE:
+    elif ktype == constants.VTYPE_SIZE:
       try:
         target[key] = ParseUnit(target[key])
       except errors.UnitParseError, err:
         msg = "'%s' (value %s) is not a valid size. error: %s" % \
               (key, target[key], err)
         raise errors.TypeEnforcementError(msg)
-    elif type == constants.VTYPE_INT:
+    elif ktype == constants.VTYPE_INT:
       try:
         target[key] = int(target[key])
       except (ValueError, TypeError):
@@ -496,14 +483,14 @@ def ReadPidFile(pidfile):
 
   """
   try:
-    pf = open(pidfile, 'r')
+    raw_data = ReadFile(pidfile)
   except EnvironmentError, err:
     if err.errno != errno.ENOENT:
-      logging.exception("Can't read pid file?!")
+      logging.exception("Can't read pid file")
     return 0
 
   try:
-    pid = int(pf.read())
+    pid = int(raw_data)
   except ValueError, err:
     logging.info("Can't parse pid file contents", exc_info=True)
     return 0
@@ -511,7 +498,7 @@ def ReadPidFile(pidfile):
   return pid
 
 
-def MatchNameComponent(key, name_list):
+def MatchNameComponent(key, name_list, case_sensitive=True):
   """Try to match a name against a list.
 
   This function will try to match a name like test1 against a list
@@ -519,23 +506,42 @@ def MatchNameComponent(key, name_list):
   this list, I{'test1'} as well as I{'test1.example'} will match, but
   not I{'test1.ex'}. A multiple match will be considered as no match
   at all (e.g. I{'test1'} against C{['test1.example.com',
-  'test1.example.org']}).
+  'test1.example.org']}), except when the key fully matches an entry
+  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
 
   @type key: str
   @param key: the name to be searched
   @type name_list: list
   @param name_list: the list of strings against which to search the key
+  @type case_sensitive: boolean
+  @param case_sensitive: whether to provide a case-sensitive match
 
   @rtype: None or str
   @return: None if there is no match I{or} if there are multiple matches,
       otherwise the element from the list which matches
 
   """
-  mo = re.compile("^%s(\..*)?$" % re.escape(key))
-  names_filtered = [name for name in name_list if mo.match(name) is not None]
-  if len(names_filtered) != 1:
-    return None
-  return names_filtered[0]
+  if key in name_list:
+    return key
+
+  re_flags = 0
+  if not case_sensitive:
+    re_flags |= re.IGNORECASE
+    key = key.upper()
+  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
+  names_filtered = []
+  string_matches = []
+  for name in name_list:
+    if mo.match(name) is not None:
+      names_filtered.append(name)
+      if not case_sensitive and key == name.upper():
+        string_matches.append(name)
+
+  if len(string_matches) == 1:
+    return string_matches[0]
+  if len(names_filtered) == 1:
+    return names_filtered[0]
+  return None
 
 
 class HostInfo:
@@ -593,6 +599,16 @@ class HostInfo:
     return result
 
 
+def GetHostInfo(name=None):
+  """Lookup host name and raise an OpPrereqError for failures"""
+
+  try:
+    return HostInfo(name)
+  except errors.ResolverError, err:
+    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
+                               (err[0], err[2]), errors.ECODE_RESOLVER)
+
+
 def ListVolumeGroups():
   """List volume groups and their size
 
@@ -687,7 +703,7 @@ def TryConvert(fn, val):
   """
   try:
     nv = fn(val)
-  except (ValueError, TypeError), err:
+  except (ValueError, TypeError):
     nv = val
   return nv
 
@@ -701,7 +717,7 @@ def IsValidIP(ip):
   @type ip: str
   @param ip: the address to be checked
   @rtype: a regular expression match object
-  @return: a regular epression match object, or None if the
+  @return: a regular expression match object, or None if the
       address is not valid
 
   """
@@ -734,7 +750,7 @@ def BuildShellCmd(template, *args):
 
   This function will check all arguments in the args list so that they
   are valid shell parameters (i.e. they don't contain shell
-  metacharaters). If everything is ok, it will return the result of
+  metacharacters). If everything is ok, it will return the result of
   template % args.
 
   @type template: str
@@ -1063,7 +1079,7 @@ def ShellQuoteArgs(args):
   @type args: list
   @param args: list of arguments to be quoted
   @rtype: str
-  @return: the quoted arguments concatenaned with spaces
+  @return: the quoted arguments concatenated with spaces
 
   """
   return ' '.join([ShellQuote(i) for i in args])
@@ -1080,7 +1096,7 @@ def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
   @type port: int
   @param port: the port to connect to
   @type timeout: int
-  @param timeout: the timeout on the connection attemp
+  @param timeout: the timeout on the connection attempt
   @type live_port_needed: boolean
   @param live_port_needed: whether a closed port will cause the
       function to return failure, as if there was a timeout
@@ -1097,7 +1113,7 @@ def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
   if source is not None:
     try:
       sock.bind((source, 0))
-    except socket.error, (errcode, errstring):
+    except socket.error, (errcode, _):
       if errcode == errno.EADDRNOTAVAIL:
         success = False
 
@@ -1109,7 +1125,7 @@ def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
     success = True
   except socket.timeout:
     success = False
-  except socket.error, (errcode, errstring):
+  except socket.error, (errcode, _):
     success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
 
   return success
@@ -1122,7 +1138,7 @@ def OwnIpAddress(address):
   address.
 
   @type address: string
-  @param address: the addres to check
+  @param address: the address to check
   @rtype: bool
   @return: True if we own the address
 
@@ -1174,24 +1190,22 @@ def NewUUID():
   @rtype: str
 
   """
-  f = open("/proc/sys/kernel/random/uuid", "r")
-  try:
-    return f.read(128).rstrip("\n")
-  finally:
-    f.close()
+  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
 
 
-def GenerateSecret():
+def GenerateSecret(numbytes=20):
   """Generates a random secret.
 
-  This will generate a pseudo-random secret, and return its sha digest
+  This will generate a pseudo-random secret returning an hex string
   (so that it can be used where an ASCII string is needed).
 
+  @param numbytes: the number of bytes which will be represented by the returned
+      string (defaulting to 20, the length of a SHA1 hash)
   @rtype: str
-  @return: a sha1 hexdigest of a block of 64 random bytes
+  @return: an hex representation of the pseudo-random sequence
 
   """
-  return sha1(os.urandom(64)).hexdigest()
+  return os.urandom(numbytes).encode('hex')
 
 
 def EnsureDirs(dirs):
@@ -1212,21 +1226,18 @@ def EnsureDirs(dirs):
       raise errors.GenericError("%s is not a directory" % dir_name)
 
 
-def ReadFile(file_name, size=None):
+def ReadFile(file_name, size=-1):
   """Reads a file.
 
-  @type size: None or int
-  @param size: Read at most size bytes
+  @type size: int
+  @param size: Read at most size bytes (if negative, entire file)
   @rtype: str
-  @return: the (possibly partial) conent of the file
+  @return: the (possibly partial) content of the file
 
   """
   f = open(file_name, "r")
   try:
-    if size is None:
-      return f.read()
-    else:
-      return f.read(size)
+    return f.read(size)
   finally:
     f.close()
 
@@ -1360,14 +1371,14 @@ def FirstFree(seq, base=0):
 
 def all(seq, pred=bool):
   "Returns True if pred(x) is True for every element in the iterable"
-  for elem in itertools.ifilterfalse(pred, seq):
+  for _ in itertools.ifilterfalse(pred, seq):
     return False
   return True
 
 
 def any(seq, pred=bool):
   "Returns True if pred(x) is True for at least one element in the iterable"
-  for elem in itertools.ifilter(pred, seq):
+  for _ in itertools.ifilter(pred, seq):
     return True
   return False
 
@@ -1378,7 +1389,7 @@ def UniqueSequence(seq):
   Element order is preserved.
 
   @type seq: sequence
-  @param seq: the sequence with the source elementes
+  @param seq: the sequence with the source elements
   @rtype: list
   @return: list of unique elements from seq
 
@@ -1390,7 +1401,7 @@ def UniqueSequence(seq):
 def IsValidMac(mac):
   """Predicate to check if a MAC address is valid.
 
-  Checks wether the supplied MAC address is formally correct, only
+  Checks whether the supplied MAC address is formally correct, only
   accepts colon separated format.
 
   @type mac: str
@@ -1413,9 +1424,9 @@ def TestDelay(duration):
 
   """
   if duration < 0:
-    return False
+    return False, "Invalid sleep duration"
   time.sleep(duration)
-  return True
+  return True, None
 
 
 def _CloseFDNoErr(fd, retries=5):
@@ -1552,7 +1563,6 @@ def RemovePidFile(name):
   @param name: the daemon name used to derive the pidfile name
 
   """
-  pid = os.getpid()
   pidfilename = DaemonPidFileName(name)
   # TODO: we could check here that the file contains our pid
   try:
@@ -1594,24 +1604,31 @@ def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
 
   if not IsProcessAlive(pid):
     return
+
   _helper(pid, signal_, waitpid)
+
   if timeout <= 0:
     return
 
-  # Wait up to $timeout seconds
-  end = time.time() + timeout
-  wait = 0.01
-  while time.time() < end and IsProcessAlive(pid):
+  def _CheckProcess():
+    if not IsProcessAlive(pid):
+      return
+
     try:
       (result_pid, _) = os.waitpid(pid, os.WNOHANG)
-      if result_pid > 0:
-        break
     except OSError:
-      pass
-    time.sleep(wait)
-    # Make wait time longer for next try
-    if wait < 0.1:
-      wait *= 1.5
+      raise RetryAgain()
+
+    if result_pid > 0:
+      return
+
+    raise RetryAgain()
+
+  try:
+    # Wait up to $timeout seconds
+    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
+  except RetryTimeout:
+    pass
 
   if IsProcessAlive(pid):
     # Kill process if it's still alive
@@ -1704,20 +1721,26 @@ def MergeTime(timetuple):
   return float(seconds) + (float(microseconds) * 0.000001)
 
 
-def GetNodeDaemonPort():
-  """Get the node daemon port for this cluster.
+def GetDaemonPort(daemon_name):
+  """Get the daemon port for this cluster.
 
   Note that this routine does not read a ganeti-specific file, but
   instead uses C{socket.getservbyname} to allow pre-customization of
   this parameter outside of Ganeti.
 
+  @type daemon_name: string
+  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
   @rtype: int
 
   """
+  if daemon_name not in constants.DAEMONS_PORTS:
+    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)
+
+  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
   try:
-    port = socket.getservbyname("ganeti-noded", "tcp")
+    port = socket.getservbyname(daemon_name, proto)
   except socket.error:
-    port = constants.DEFAULT_NODED_PORT
+    port = default_port
 
   return port
 
@@ -1786,6 +1809,15 @@ def SetupLogging(logfile, debug=False, stderr_logging=False, program="",
       raise
 
 
+def IsNormAbsPath(path):
+  """Check whether a path is absolute and also normalized
+
+  This avoids things like /dir/../../other/path to be valid.
+
+  """
+  return os.path.normpath(path) == path and os.path.isabs(path)
+
+
 def TailFile(fname, lines=20):
   """Return the last lines from a file.
 
@@ -1816,11 +1848,13 @@ def SafeEncode(text):
   """Return a 'safe' version of a source string.
 
   This function mangles the input string and returns a version that
-  should be safe to disply/encode as ASCII. To this end, we first
+  should be safe to display/encode as ASCII. To this end, we first
   convert it to ASCII using the 'backslashreplace' encoding which
-  should get rid of any non-ASCII chars, and then we again encode it
-  via 'string_escape' which converts '\n' into '\\n' so that log
-  messages remain one-line.
+  should get rid of any non-ASCII chars, and then we process it
+  through a loop copied from the string repr sources in the python; we
+  don't use string_escape anymore since that escape single quotes and
+  backslashes too, and that is too much; and that escaping is not
+  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
 
   @type text: str or unicode
   @param text: input data
@@ -1828,9 +1862,23 @@ def SafeEncode(text):
   @return: a safe version of text
 
   """
-  text = text.encode('ascii', 'backslashreplace')
-  text = text.encode('string_escape')
-  return text
+  if isinstance(text, unicode):
+    # only if unicode; if str already, we handle it below
+    text = text.encode('ascii', 'backslashreplace')
+  resu = ""
+  for char in text:
+    c = ord(char)
+    if char  == '\t':
+      resu += r'\t'
+    elif char == '\n':
+      resu += r'\n'
+    elif char == '\r':
+      resu += r'\'r'
+    elif c < 32 or c >= 127: # non-printable
+      resu += "\\x%02x" % (c & 0xff)
+    else:
+      resu += char
+  return resu
 
 
 def CommaJoin(names):
@@ -1840,7 +1888,54 @@ def CommaJoin(names):
   @return: a string with the formatted results
 
   """
-  return ", ".join(["'%s'" % val for val in names])
+  return ", ".join([str(val) for val in names])
+
+
+def BytesToMebibyte(value):
+  """Converts bytes to mebibytes.
+
+  @type value: int
+  @param value: Value in bytes
+  @rtype: int
+  @return: Value in mebibytes
+
+  """
+  return int(round(value / (1024.0 * 1024.0), 0))
+
+
+def CalculateDirectorySize(path):
+  """Calculates the size of a directory recursively.
+
+  @type path: string
+  @param path: Path to directory
+  @rtype: int
+  @return: Size in mebibytes
+
+  """
+  size = 0
+
+  for (curpath, _, files) in os.walk(path):
+    for filename in files:
+      st = os.lstat(os.path.join(curpath, filename))
+      size += st.st_size
+
+  return BytesToMebibyte(size)
+
+
+def GetFilesystemStats(path):
+  """Returns the total and free space on a filesystem.
+
+  @type path: string
+  @param path: Path on filesystem to be examined
+  @rtype: int
+  @return: tuple of (Total space, Free space) in mebibytes
+
+  """
+  st = os.statvfs(path)
+
+  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
+  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
+  return (tsize, fsize)
 
 
 def LockedMethod(fn):
@@ -1885,6 +1980,201 @@ def LockFile(fd):
     raise
 
 
+def FormatTime(val):
+  """Formats a time value.
+
+  @type val: float or None
+  @param val: the timestamp as returned by time.time()
+  @return: a string value or N/A if we don't have a valid timestamp
+
+  """
+  if val is None or not isinstance(val, (int, float)):
+    return "N/A"
+  # these two codes works on Linux, but they are not guaranteed on all
+  # platforms
+  return time.strftime("%F %T", time.localtime(val))
+
+
+def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
+  """Reads the watcher pause file.
+
+  @type filename: string
+  @param filename: Path to watcher pause file
+  @type now: None, float or int
+  @param now: Current time as Unix timestamp
+  @type remove_after: int
+  @param remove_after: Remove watcher pause file after specified amount of
+    seconds past the pause end time
+
+  """
+  if now is None:
+    now = time.time()
+
+  try:
+    value = ReadFile(filename)
+  except IOError, err:
+    if err.errno != errno.ENOENT:
+      raise
+    value = None
+
+  if value is not None:
+    try:
+      value = int(value)
+    except ValueError:
+      logging.warning(("Watcher pause file (%s) contains invalid value,"
+                       " removing it"), filename)
+      RemoveFile(filename)
+      value = None
+
+    if value is not None:
+      # Remove file if it's outdated
+      if now > (value + remove_after):
+        RemoveFile(filename)
+        value = None
+
+      elif now > value:
+        value = None
+
+  return value
+
+
+class RetryTimeout(Exception):
+  """Retry loop timed out.
+
+  """
+
+
+class RetryAgain(Exception):
+  """Retry again.
+
+  """
+
+
+class _RetryDelayCalculator(object):
+  """Calculator for increasing delays.
+
+  """
+  __slots__ = [
+    "_factor",
+    "_limit",
+    "_next",
+    "_start",
+    ]
+
+  def __init__(self, start, factor, limit):
+    """Initializes this class.
+
+    @type start: float
+    @param start: Initial delay
+    @type factor: float
+    @param factor: Factor for delay increase
+    @type limit: float or None
+    @param limit: Upper limit for delay or None for no limit
+
+    """
+    assert start > 0.0
+    assert factor >= 1.0
+    assert limit is None or limit >= 0.0
+
+    self._start = start
+    self._factor = factor
+    self._limit = limit
+
+    self._next = start
+
+  def __call__(self):
+    """Returns current delay and calculates the next one.
+
+    """
+    current = self._next
+
+    # Update for next run
+    if self._limit is None or self._next < self._limit:
+      self._next = max(self._limit, self._next * self._factor)
+
+    return current
+
+
+#: Special delay to specify whole remaining timeout
+RETRY_REMAINING_TIME = object()
+
+
+def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
+          _time_fn=time.time):
+  """Call a function repeatedly until it succeeds.
+
+  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
+  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
+  total of C{timeout} seconds, this function throws L{RetryTimeout}.
+
+  C{delay} can be one of the following:
+    - callable returning the delay length as a float
+    - Tuple of (start, factor, limit)
+    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
+      useful when overriding L{wait_fn} to wait for an external event)
+    - A static delay as a number (int or float)
+
+  @type fn: callable
+  @param fn: Function to be called
+  @param delay: Either a callable (returning the delay), a tuple of (start,
+                factor, limit) (see L{_RetryDelayCalculator}),
+                L{RETRY_REMAINING_TIME} or a number (int or float)
+  @type timeout: float
+  @param timeout: Total timeout
+  @type wait_fn: callable
+  @param wait_fn: Waiting function
+  @return: Return value of function
+
+  """
+  assert callable(fn)
+  assert callable(wait_fn)
+  assert callable(_time_fn)
+
+  if args is None:
+    args = []
+
+  end_time = _time_fn() + timeout
+
+  if callable(delay):
+    # External function to calculate delay
+    calc_delay = delay
+
+  elif isinstance(delay, (tuple, list)):
+    # Increasing delay with optional upper boundary
+    (start, factor, limit) = delay
+    calc_delay = _RetryDelayCalculator(start, factor, limit)
+
+  elif delay is RETRY_REMAINING_TIME:
+    # Always use the remaining time
+    calc_delay = None
+
+  else:
+    # Static delay
+    calc_delay = lambda: delay
+
+  assert calc_delay is None or callable(calc_delay)
+
+  while True:
+    try:
+      return fn(*args)
+    except RetryAgain:
+      pass
+
+    remaining_time = end_time - _time_fn()
+
+    if remaining_time < 0.0:
+      raise RetryTimeout()
+
+    assert remaining_time >= 0.0
+
+    if calc_delay is None:
+      wait_fn(remaining_time)
+    else:
+      current_delay = calc_delay()
+      if current_delay > 0.0:
+        wait_fn(current_delay)
+
+
 class FileLock(object):
   """Utility class for file locks.
 
@@ -1939,6 +2229,8 @@ class FileLock(object):
       flag |= fcntl.LOCK_NB
       timeout_end = None
 
+    # TODO: Convert to utils.Retry
+
     retry = True
     while retry:
       try:
@@ -2004,6 +2296,43 @@ class FileLock(object):
                 "Failed to unlock %s" % self.filename)
 
 
+def SignalHandled(signums):
+  """Signal Handled decoration.
+
+  This special decorator installs a signal handler and then calls the target
+  function. The function must accept a 'signal_handlers' keyword argument,
+  which will contain a dict indexed by signal number, with SignalHandler
+  objects as values.
+
+  The decorator can be safely stacked with iself, to handle multiple signals
+  with different handlers.
+
+  @type signums: list
+  @param signums: signals to intercept
+
+  """
+  def wrap(fn):
+    def sig_function(*args, **kwargs):
+      assert 'signal_handlers' not in kwargs or \
+             kwargs['signal_handlers'] is None or \
+             isinstance(kwargs['signal_handlers'], dict), \
+             "Wrong signal_handlers parameter in original function call"
+      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
+        signal_handlers = kwargs['signal_handlers']
+      else:
+        signal_handlers = {}
+        kwargs['signal_handlers'] = signal_handlers
+      sighandler = SignalHandler(signums)
+      try:
+        for sig in signums:
+          signal_handlers[sig] = sighandler
+        return fn(*args, **kwargs)
+      finally:
+        sighandler.Reset()
+    return sig_function
+  return wrap
+
+
 class SignalHandler(object):
   """Generic signal handler class.
 
@@ -2025,11 +2354,7 @@ class SignalHandler(object):
     @param signum: Single signal number or set of signal numbers
 
     """
-    if isinstance(signum, (int, long)):
-      self.signum = set([signum])
-    else:
-      self.signum = set(signum)
-
+    self.signum = set(signum)
     self.called = False
 
     self._previous = {}
@@ -2104,12 +2429,12 @@ class FieldSet(object):
 
     @type field: str
     @param field: the string to match
-    @return: either False or a regular expression match object
+    @return: either None or a regular expression match object
 
     """
     for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
       return m
-    return False
+    return None
 
   def NonMatching(self, items):
     """Returns the list of fields not matching the current set