Fix RPC result handling in _AssembleInstanceDisks
[ganeti-local] / lib / utils.py
index 8a4087a..f887303 100644 (file)
@@ -166,6 +166,7 @@ def RunCmd(cmd, env=None, output=None, cwd='/'):
 
   return RunResult(exitcode, signal_, out, err, strcmd)
 
+
 def _RunCmdPipe(cmd, env, via_shell, cwd):
   """Run a command and return its output.
 
@@ -386,6 +387,69 @@ def CheckDict(target, template, logname=None):
     logging.warning('%s missing keys %s', logname, ', '.join(missing))
 
 
+def ForceDictType(target, key_types, allowed_values=None):
+  """Force the values of a dict to have certain types.
+
+  @type target: dict
+  @param target: the dict to update
+  @type key_types: dict
+  @param key_types: dict mapping target dict keys to types
+                    in constants.ENFORCEABLE_TYPES
+  @type allowed_values: list
+  @keyword allowed_values: list of specially allowed values
+
+  """
+  if allowed_values is None:
+    allowed_values = []
+
+  for key in target:
+    if key not in key_types:
+      msg = "Unknown key '%s'" % key
+      raise errors.TypeEnforcementError(msg)
+
+    if target[key] in allowed_values:
+      continue
+
+    type = key_types[key]
+    if type not in constants.ENFORCEABLE_TYPES:
+      msg = "'%s' has non-enforceable type %s" % (key, type)
+      raise errors.ProgrammerError(msg)
+
+    if type == constants.VTYPE_STRING:
+      if not isinstance(target[key], basestring):
+        if isinstance(target[key], bool) and not target[key]:
+          target[key] = ''
+        else:
+          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
+          raise errors.TypeEnforcementError(msg)
+    elif type == constants.VTYPE_BOOL:
+      if isinstance(target[key], basestring) and target[key]:
+        if target[key].lower() == constants.VALUE_FALSE:
+          target[key] = False
+        elif target[key].lower() == constants.VALUE_TRUE:
+          target[key] = True
+        else:
+          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
+          raise errors.TypeEnforcementError(msg)
+      elif target[key]:
+        target[key] = True
+      else:
+        target[key] = False
+    elif type == constants.VTYPE_SIZE:
+      try:
+        target[key] = ParseUnit(target[key])
+      except errors.UnitParseError, err:
+        msg = "'%s' (value %s) is not a valid size. error: %s" % \
+              (key, target[key], err)
+        raise errors.TypeEnforcementError(msg)
+    elif type == constants.VTYPE_INT:
+      try:
+        target[key] = int(target[key])
+      except (ValueError, TypeError):
+        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
+        raise errors.TypeEnforcementError(msg)
+
+
 def IsProcessAlive(pid):
   """Check if a given pid exists on the system.
 
@@ -415,7 +479,7 @@ def ReadPidFile(pidfile):
   @type  pidfile: string
   @param pidfile: path to the file containing the pid
   @rtype: int
-  @return: The process id, if the file exista and contains a valid PID,
+  @return: The process id, if the file exists and contains a valid PID,
            otherwise 0
 
   """
@@ -557,37 +621,6 @@ def BridgeExists(bridge):
   return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
 
 
-def CheckBEParams(beparams):
-  """Checks whether the user-supplied be-params are valid,
-  and converts them from string format where appropriate.
-
-  @type beparams: dict
-  @param beparams: new params dict
-
-  """
-  if beparams:
-    for item in beparams:
-      if item not in constants.BES_PARAMETERS:
-        raise errors.OpPrereqError("Unknown backend parameter %s" % item)
-      if item in (constants.BE_MEMORY, constants.BE_VCPUS):
-        val = beparams[item]
-        if val != constants.VALUE_DEFAULT:
-          try:
-            val = int(val)
-          except ValueError, err:
-            raise errors.OpPrereqError("Invalid %s size: %s" % (item, str(err)))
-          beparams[item] = val
-      if item in (constants.BE_AUTO_BALANCE):
-        val = beparams[item]
-        if not isinstance(val, bool):
-          if val == constants.VALUE_TRUE:
-            beparams[item] = True
-          elif val == constants.VALUE_FALSE:
-            beparams[item] = False
-          else:
-            raise errors.OpPrereqError("Invalid %s value: %s" % (item, val))
-
-
 def NiceSort(name_list):
   """Sort a list of strings based on digit and non-digit groupings.
 
@@ -750,7 +783,7 @@ def ParseUnit(input_string):
   is always an int in MiB.
 
   """
-  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', input_string)
+  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
   if not m:
     raise errors.UnitParseError("Invalid format")
 
@@ -1652,7 +1685,8 @@ def GetNodeDaemonPort():
   return port
 
 
-def SetupLogging(logfile, debug=False, stderr_logging=False, program=""):
+def SetupLogging(logfile, debug=False, stderr_logging=False, program="",
+                 multithreaded=False):
   """Configures the logging module.
 
   @type logfile: str
@@ -1664,16 +1698,18 @@ def SetupLogging(logfile, debug=False, stderr_logging=False, program=""):
   @param stderr_logging: whether we should also log to the standard error
   @type program: str
   @param program: the name under which we should log messages
+  @type multithreaded: boolean
+  @param multithreaded: if True, will add the thread name to the log file
   @raise EnvironmentError: if we can't open the log file and
       stderr logging is disabled
 
   """
-  fmt = "%(asctime)s: " + program + " "
+  fmt = "%(asctime)s: " + program + " pid=%(process)d"
+  if multithreaded:
+    fmt += "/%(threadName)s"
   if debug:
-    fmt += ("pid=%(process)d/%(threadName)s %(levelname)s"
-           " %(module)s:%(lineno)s %(message)s")
-  else:
-    fmt += "pid=%(process)d %(levelname)s %(message)s"
+    fmt += " %(module)s:%(lineno)s"
+  fmt += " %(levelname)s %(message)s"
   formatter = logging.Formatter(fmt)
 
   root_logger = logging.getLogger("")
@@ -1705,7 +1741,7 @@ def SetupLogging(logfile, debug=False, stderr_logging=False, program=""):
     else:
       logfile_handler.setLevel(logging.INFO)
     root_logger.addHandler(logfile_handler)
-  except EnvironmentError, err:
+  except EnvironmentError:
     if stderr_logging:
       logging.exception("Failed to enable logging to file '%s'", logfile)
     else:
@@ -1739,6 +1775,27 @@ def TailFile(fname, lines=20):
   return rows[-lines:]
 
 
+def SafeEncode(text):
+  """Return a 'safe' version of a source string.
+
+  This function mangles the input string and returns a version that
+  should be safe to disply/encode as ASCII. To this end, we first
+  convert it to ASCII using the 'backslashreplace' encoding which
+  should get rid of any non-ASCII chars, and then we again encode it
+  via 'string_escape' which converts '\n' into '\\n' so that log
+  messages remain one-line.
+
+  @type text: str or unicode
+  @param text: input data
+  @rtype: str
+  @return: a safe version of text
+
+  """
+  text = text.encode('ascii', 'backslashreplace')
+  text = text.encode('string_escape')
+  return text
+
+
 def LockedMethod(fn):
   """Synchronized object access decorator.