Generate a shared HMAC key at cluster init time
[ganeti-local] / lib / utils.py
index fe0dcb1..aace5ec 100644 (file)
@@ -27,9 +27,7 @@ the command line scripts.
 """
 
 
 """
 
 
-import sys
 import os
 import os
-import sha
 import time
 import subprocess
 import re
 import time
 import subprocess
 import re
@@ -47,6 +45,12 @@ import signal
 
 from cStringIO import StringIO
 
 
 from cStringIO import StringIO
 
+try:
+  from hashlib import sha1
+except ImportError:
+  import sha
+  sha1 = sha.new
+
 from ganeti import errors
 from ganeti import constants
 
 from ganeti import errors
 from ganeti import constants
 
@@ -54,7 +58,6 @@ from ganeti import constants
 _locksheld = []
 _re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
 
 _locksheld = []
 _re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
 
-debug = False
 debug_locks = False
 
 #: when set to True, L{RunCmd} is disabled
 debug_locks = False
 
 #: when set to True, L{RunCmd} is disabled
@@ -131,7 +134,7 @@ def RunCmd(cmd, env=None, output=None, cwd='/'):
       directory for the command; the default will be /
   @rtype: L{RunResult}
   @return: RunResult instance
       directory for the command; the default will be /
   @rtype: L{RunResult}
   @return: RunResult instance
-  @raise erors.ProgrammerError: if we call this when forks are disabled
+  @raise errors.ProgrammerError: if we call this when forks are disabled
 
   """
   if no_fork:
 
   """
   if no_fork:
@@ -151,11 +154,18 @@ def RunCmd(cmd, env=None, output=None, cwd='/'):
   if env is not None:
     cmd_env.update(env)
 
   if env is not None:
     cmd_env.update(env)
 
-  if output is None:
-    out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
-  else:
-    status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
-    out = err = ""
+  try:
+    if output is None:
+      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
+    else:
+      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
+      out = err = ""
+  except OSError, err:
+    if err.errno == errno.ENOENT:
+      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
+                               (strcmd, err))
+    else:
+      raise
 
   if status >= 0:
     exitcode = status
 
   if status >= 0:
     exitcode = status
@@ -330,7 +340,7 @@ def _FingerprintFile(filename):
 
   f = open(filename)
 
 
   f = open(filename)
 
-  fp = sha.sha()
+  fp = sha1()
   while True:
     data = f.read(4096)
     if not data:
   while True:
     data = f.read(4096)
     if not data:
@@ -361,32 +371,6 @@ def FingerprintFiles(files):
   return ret
 
 
   return ret
 
 
-def CheckDict(target, template, logname=None):
-  """Ensure a dictionary has a required set of keys.
-
-  For the given dictionaries I{target} and I{template}, ensure
-  I{target} has all the keys from I{template}. Missing keys are added
-  with values from template.
-
-  @type target: dict
-  @param target: the dictionary to update
-  @type template: dict
-  @param template: the dictionary holding the default values
-  @type logname: str or None
-  @param logname: if not None, causes the missing keys to be
-      logged with this name
-
-  """
-  missing = []
-  for k in template:
-    if k not in target:
-      missing.append(k)
-      target[k] = template[k]
-
-  if missing and logname:
-    logging.warning('%s missing keys %s', logname, ', '.join(missing))
-
-
 def ForceDictType(target, key_types, allowed_values=None):
   """Force the values of a dict to have certain types.
 
 def ForceDictType(target, key_types, allowed_values=None):
   """Force the values of a dict to have certain types.
 
@@ -402,6 +386,10 @@ def ForceDictType(target, key_types, allowed_values=None):
   if allowed_values is None:
     allowed_values = []
 
   if allowed_values is None:
     allowed_values = []
 
+  if not isinstance(target, dict):
+    msg = "Expected dictionary, got '%s'" % target
+    raise errors.TypeEnforcementError(msg)
+
   for key in target:
     if key not in key_types:
       msg = "Unknown key '%s'" % key
   for key in target:
     if key not in key_types:
       msg = "Unknown key '%s'" % key
@@ -410,19 +398,19 @@ def ForceDictType(target, key_types, allowed_values=None):
     if target[key] in allowed_values:
       continue
 
     if target[key] in allowed_values:
       continue
 
-    type = key_types[key]
-    if type not in constants.ENFORCEABLE_TYPES:
-      msg = "'%s' has non-enforceable type %s" % (key, type)
+    ktype = key_types[key]
+    if ktype not in constants.ENFORCEABLE_TYPES:
+      msg = "'%s' has non-enforceable type %s" % (key, ktype)
       raise errors.ProgrammerError(msg)
 
       raise errors.ProgrammerError(msg)
 
-    if type == constants.VTYPE_STRING:
+    if ktype == constants.VTYPE_STRING:
       if not isinstance(target[key], basestring):
         if isinstance(target[key], bool) and not target[key]:
           target[key] = ''
         else:
           msg = "'%s' (value %s) is not a valid string" % (key, target[key])
           raise errors.TypeEnforcementError(msg)
       if not isinstance(target[key], basestring):
         if isinstance(target[key], bool) and not target[key]:
           target[key] = ''
         else:
           msg = "'%s' (value %s) is not a valid string" % (key, target[key])
           raise errors.TypeEnforcementError(msg)
-    elif type == constants.VTYPE_BOOL:
+    elif ktype == constants.VTYPE_BOOL:
       if isinstance(target[key], basestring) and target[key]:
         if target[key].lower() == constants.VALUE_FALSE:
           target[key] = False
       if isinstance(target[key], basestring) and target[key]:
         if target[key].lower() == constants.VALUE_FALSE:
           target[key] = False
@@ -435,14 +423,14 @@ def ForceDictType(target, key_types, allowed_values=None):
         target[key] = True
       else:
         target[key] = False
         target[key] = True
       else:
         target[key] = False
-    elif type == constants.VTYPE_SIZE:
+    elif ktype == constants.VTYPE_SIZE:
       try:
         target[key] = ParseUnit(target[key])
       except errors.UnitParseError, err:
         msg = "'%s' (value %s) is not a valid size. error: %s" % \
               (key, target[key], err)
         raise errors.TypeEnforcementError(msg)
       try:
         target[key] = ParseUnit(target[key])
       except errors.UnitParseError, err:
         msg = "'%s' (value %s) is not a valid size. error: %s" % \
               (key, target[key], err)
         raise errors.TypeEnforcementError(msg)
-    elif type == constants.VTYPE_INT:
+    elif ktype == constants.VTYPE_INT:
       try:
         target[key] = int(target[key])
       except (ValueError, TypeError):
       try:
         target[key] = int(target[key])
       except (ValueError, TypeError):
@@ -675,7 +663,7 @@ def TryConvert(fn, val):
   """
   try:
     nv = fn(val)
   """
   try:
     nv = fn(val)
-  except (ValueError, TypeError), err:
+  except (ValueError, TypeError):
     nv = val
   return nv
 
     nv = val
   return nv
 
@@ -689,7 +677,7 @@ def IsValidIP(ip):
   @type ip: str
   @param ip: the address to be checked
   @rtype: a regular expression match object
   @type ip: str
   @param ip: the address to be checked
   @rtype: a regular expression match object
-  @return: a regular epression match object, or None if the
+  @return: a regular expression match object, or None if the
       address is not valid
 
   """
       address is not valid
 
   """
@@ -722,7 +710,7 @@ def BuildShellCmd(template, *args):
 
   This function will check all arguments in the args list so that they
   are valid shell parameters (i.e. they don't contain shell
 
   This function will check all arguments in the args list so that they
   are valid shell parameters (i.e. they don't contain shell
-  metacharaters). If everything is ok, it will return the result of
+  metacharacters). If everything is ok, it will return the result of
   template % args.
 
   @type template: str
   template % args.
 
   @type template: str
@@ -1051,7 +1039,7 @@ def ShellQuoteArgs(args):
   @type args: list
   @param args: list of arguments to be quoted
   @rtype: str
   @type args: list
   @param args: list of arguments to be quoted
   @rtype: str
-  @return: the quoted arguments concatenaned with spaces
+  @return: the quoted arguments concatenated with spaces
 
   """
   return ' '.join([ShellQuote(i) for i in args])
 
   """
   return ' '.join([ShellQuote(i) for i in args])
@@ -1068,7 +1056,7 @@ def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
   @type port: int
   @param port: the port to connect to
   @type timeout: int
   @type port: int
   @param port: the port to connect to
   @type timeout: int
-  @param timeout: the timeout on the connection attemp
+  @param timeout: the timeout on the connection attempt
   @type live_port_needed: boolean
   @param live_port_needed: whether a closed port will cause the
       function to return failure, as if there was a timeout
   @type live_port_needed: boolean
   @param live_port_needed: whether a closed port will cause the
       function to return failure, as if there was a timeout
@@ -1085,7 +1073,7 @@ def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
   if source is not None:
     try:
       sock.bind((source, 0))
   if source is not None:
     try:
       sock.bind((source, 0))
-    except socket.error, (errcode, errstring):
+    except socket.error, (errcode, _):
       if errcode == errno.EADDRNOTAVAIL:
         success = False
 
       if errcode == errno.EADDRNOTAVAIL:
         success = False
 
@@ -1110,7 +1098,7 @@ def OwnIpAddress(address):
   address.
 
   @type address: string
   address.
 
   @type address: string
-  @param address: the addres to check
+  @param address: the address to check
   @rtype: bool
   @return: True if we own the address
 
   @rtype: bool
   @return: True if we own the address
 
@@ -1179,7 +1167,7 @@ def GenerateSecret():
   @return: a sha1 hexdigest of a block of 64 random bytes
 
   """
   @return: a sha1 hexdigest of a block of 64 random bytes
 
   """
-  return sha.new(os.urandom(64)).hexdigest()
+  return sha1(os.urandom(64)).hexdigest()
 
 
 def EnsureDirs(dirs):
 
 
 def EnsureDirs(dirs):
@@ -1206,7 +1194,7 @@ def ReadFile(file_name, size=None):
   @type size: None or int
   @param size: Read at most size bytes
   @rtype: str
   @type size: None or int
   @param size: Read at most size bytes
   @rtype: str
-  @return: the (possibly partial) conent of the file
+  @return: the (possibly partial) content of the file
 
   """
   f = open(file_name, "r")
 
   """
   f = open(file_name, "r")
@@ -1284,6 +1272,7 @@ def WriteFile(file_name, fn=None, data=None,
 
   dir_name, base_name = os.path.split(file_name)
   fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
 
   dir_name, base_name = os.path.split(file_name)
   fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
+  do_remove = True
   # here we need to make sure we remove the temp file, if any error
   # leaves it in place
   try:
   # here we need to make sure we remove the temp file, if any error
   # leaves it in place
   try:
@@ -1304,13 +1293,15 @@ def WriteFile(file_name, fn=None, data=None,
       os.utime(new_name, (atime, mtime))
     if not dry_run:
       os.rename(new_name, file_name)
       os.utime(new_name, (atime, mtime))
     if not dry_run:
       os.rename(new_name, file_name)
+      do_remove = False
   finally:
     if close:
       os.close(fd)
       result = None
     else:
       result = fd
   finally:
     if close:
       os.close(fd)
       result = None
     else:
       result = fd
-    RemoveFile(new_name)
+    if do_remove:
+      RemoveFile(new_name)
 
   return result
 
 
   return result
 
@@ -1345,14 +1336,14 @@ def FirstFree(seq, base=0):
 
 def all(seq, pred=bool):
   "Returns True if pred(x) is True for every element in the iterable"
 
 def all(seq, pred=bool):
   "Returns True if pred(x) is True for every element in the iterable"
-  for elem in itertools.ifilterfalse(pred, seq):
+  for _ in itertools.ifilterfalse(pred, seq):
     return False
   return True
 
 
 def any(seq, pred=bool):
   "Returns True if pred(x) is True for at least one element in the iterable"
     return False
   return True
 
 
 def any(seq, pred=bool):
   "Returns True if pred(x) is True for at least one element in the iterable"
-  for elem in itertools.ifilter(pred, seq):
+  for _ in itertools.ifilter(pred, seq):
     return True
   return False
 
     return True
   return False
 
@@ -1363,7 +1354,7 @@ def UniqueSequence(seq):
   Element order is preserved.
 
   @type seq: sequence
   Element order is preserved.
 
   @type seq: sequence
-  @param seq: the sequence with the source elementes
+  @param seq: the sequence with the source elements
   @rtype: list
   @return: list of unique elements from seq
 
   @rtype: list
   @return: list of unique elements from seq
 
@@ -1375,7 +1366,7 @@ def UniqueSequence(seq):
 def IsValidMac(mac):
   """Predicate to check if a MAC address is valid.
 
 def IsValidMac(mac):
   """Predicate to check if a MAC address is valid.
 
-  Checks wether the supplied MAC address is formally correct, only
+  Checks whether the supplied MAC address is formally correct, only
   accepts colon separated format.
 
   @type mac: str
   accepts colon separated format.
 
   @type mac: str
@@ -1398,9 +1389,9 @@ def TestDelay(duration):
 
   """
   if duration < 0:
 
   """
   if duration < 0:
-    return False
+    return False, "Invalid sleep duration"
   time.sleep(duration)
   time.sleep(duration)
-  return True
+  return True, None
 
 
 def _CloseFDNoErr(fd, retries=5):
 
 
 def _CloseFDNoErr(fd, retries=5):
@@ -1537,7 +1528,6 @@ def RemovePidFile(name):
   @param name: the daemon name used to derive the pidfile name
 
   """
   @param name: the daemon name used to derive the pidfile name
 
   """
-  pid = os.getpid()
   pidfilename = DaemonPidFileName(name)
   # TODO: we could check here that the file contains our pid
   try:
   pidfilename = DaemonPidFileName(name)
   # TODO: we could check here that the file contains our pid
   try:
@@ -1770,6 +1760,13 @@ def SetupLogging(logfile, debug=False, stderr_logging=False, program="",
       # we need to re-raise the exception
       raise
 
       # we need to re-raise the exception
       raise
 
+def IsNormAbsPath(path):
+  """Check whether a path is absolute and also normalized
+
+  This avoids things like /dir/../../other/path to be valid.
+
+  """
+  return os.path.normpath(path) == path and os.path.isabs(path)
 
 def TailFile(fname, lines=20):
   """Return the last lines from a file.
 
 def TailFile(fname, lines=20):
   """Return the last lines from a file.
@@ -1801,11 +1798,13 @@ def SafeEncode(text):
   """Return a 'safe' version of a source string.
 
   This function mangles the input string and returns a version that
   """Return a 'safe' version of a source string.
 
   This function mangles the input string and returns a version that
-  should be safe to disply/encode as ASCII. To this end, we first
+  should be safe to display/encode as ASCII. To this end, we first
   convert it to ASCII using the 'backslashreplace' encoding which
   convert it to ASCII using the 'backslashreplace' encoding which
-  should get rid of any non-ASCII chars, and then we again encode it
-  via 'string_escape' which converts '\n' into '\\n' so that log
-  messages remain one-line.
+  should get rid of any non-ASCII chars, and then we process it
+  through a loop copied from the string repr sources in the python; we
+  don't use string_escape anymore since that escape single quotes and
+  backslashes too, and that is too much; and that escaping is not
+  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
 
   @type text: str or unicode
   @param text: input data
 
   @type text: str or unicode
   @param text: input data
@@ -1813,9 +1812,33 @@ def SafeEncode(text):
   @return: a safe version of text
 
   """
   @return: a safe version of text
 
   """
-  text = text.encode('ascii', 'backslashreplace')
-  text = text.encode('string_escape')
-  return text
+  if isinstance(text, unicode):
+    # only if unicode; if str already, we handle it below
+    text = text.encode('ascii', 'backslashreplace')
+  resu = ""
+  for char in text:
+    c = ord(char)
+    if char  == '\t':
+      resu += r'\t'
+    elif char == '\n':
+      resu += r'\n'
+    elif char == '\r':
+      resu += r'\'r'
+    elif c < 32 or c >= 127: # non-printable
+      resu += "\\x%02x" % (c & 0xff)
+    else:
+      resu += char
+  return resu
+
+
+def CommaJoin(names):
+  """Nicely join a set of identifiers.
+
+  @param names: set, list or tuple
+  @return: a string with the formatted results
+
+  """
+  return ", ".join(["'%s'" % val for val in names])
 
 
 def LockedMethod(fn):
 
 
 def LockedMethod(fn):