X-Git-Url: https://code.grnet.gr/git/ganeti-local/blobdiff_plain/7d88772a2e6e5dc44287053b0e55dfd2b3e0b653..2ee88aeb76a2430ec0c7f86629bf66cfd0b6f564:/lib/utils.py diff --git a/lib/utils.py b/lib/utils.py index 6a6679c..48388ba 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -29,7 +29,6 @@ the command line scripts. import sys import os -import sha import time import subprocess import re @@ -47,6 +46,12 @@ import signal from cStringIO import StringIO +try: + from hashlib import sha1 +except ImportError: + import sha + sha1 = sha.new + from ganeti import errors from ganeti import constants @@ -151,11 +156,18 @@ def RunCmd(cmd, env=None, output=None, cwd='/'): if env is not None: cmd_env.update(env) - if output is None: - out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd) - else: - status = _RunCmdFile(cmd, cmd_env, shell, output, cwd) - out = err = "" + try: + if output is None: + out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd) + else: + status = _RunCmdFile(cmd, cmd_env, shell, output, cwd) + out = err = "" + except OSError, err: + if err.errno == errno.ENOENT: + raise errors.OpExecError("Can't execute '%s': not found (%s)" % + (strcmd, err)) + else: + raise if status >= 0: exitcode = status @@ -166,6 +178,7 @@ def RunCmd(cmd, env=None, output=None, cwd='/'): return RunResult(exitcode, signal_, out, err, strcmd) + def _RunCmdPipe(cmd, env, via_shell, cwd): """Run a command and return its output. @@ -329,7 +342,7 @@ def _FingerprintFile(filename): f = open(filename) - fp = sha.sha() + fp = sha1() while True: data = f.read(4096) if not data: @@ -386,6 +399,69 @@ def CheckDict(target, template, logname=None): logging.warning('%s missing keys %s', logname, ', '.join(missing)) +def ForceDictType(target, key_types, allowed_values=None): + """Force the values of a dict to have certain types. + + @type target: dict + @param target: the dict to update + @type key_types: dict + @param key_types: dict mapping target dict keys to types + in constants.ENFORCEABLE_TYPES + @type allowed_values: list + @keyword allowed_values: list of specially allowed values + + """ + if allowed_values is None: + allowed_values = [] + + for key in target: + if key not in key_types: + msg = "Unknown key '%s'" % key + raise errors.TypeEnforcementError(msg) + + if target[key] in allowed_values: + continue + + type = key_types[key] + if type not in constants.ENFORCEABLE_TYPES: + msg = "'%s' has non-enforceable type %s" % (key, type) + raise errors.ProgrammerError(msg) + + if type == constants.VTYPE_STRING: + if not isinstance(target[key], basestring): + if isinstance(target[key], bool) and not target[key]: + target[key] = '' + else: + msg = "'%s' (value %s) is not a valid string" % (key, target[key]) + raise errors.TypeEnforcementError(msg) + elif type == constants.VTYPE_BOOL: + if isinstance(target[key], basestring) and target[key]: + if target[key].lower() == constants.VALUE_FALSE: + target[key] = False + elif target[key].lower() == constants.VALUE_TRUE: + target[key] = True + else: + msg = "'%s' (value %s) is not a valid boolean" % (key, target[key]) + raise errors.TypeEnforcementError(msg) + elif target[key]: + target[key] = True + else: + target[key] = False + elif type == constants.VTYPE_SIZE: + try: + target[key] = ParseUnit(target[key]) + except errors.UnitParseError, err: + msg = "'%s' (value %s) is not a valid size. error: %s" % \ + (key, target[key], err) + raise errors.TypeEnforcementError(msg) + elif type == constants.VTYPE_INT: + try: + target[key] = int(target[key]) + except (ValueError, TypeError): + msg = "'%s' (value %s) is not a valid integer" % (key, target[key]) + raise errors.TypeEnforcementError(msg) + + def IsProcessAlive(pid): """Check if a given pid exists on the system. @@ -415,7 +491,7 @@ def ReadPidFile(pidfile): @type pidfile: string @param pidfile: path to the file containing the pid @rtype: int - @return: The process id, if the file exista and contains a valid PID, + @return: The process id, if the file exists and contains a valid PID, otherwise 0 """ @@ -557,37 +633,6 @@ def BridgeExists(bridge): return os.path.isdir("/sys/class/net/%s/bridge" % bridge) -def CheckBEParams(beparams): - """Checks whether the user-supplied be-params are valid, - and converts them from string format where appropriate. - - @type beparams: dict - @param beparams: new params dict - - """ - if beparams: - for item in beparams: - if item not in constants.BES_PARAMETERS: - raise errors.OpPrereqError("Unknown backend parameter %s" % item) - if item in (constants.BE_MEMORY, constants.BE_VCPUS): - val = beparams[item] - if val != constants.VALUE_DEFAULT: - try: - val = int(val) - except ValueError, err: - raise errors.OpPrereqError("Invalid %s size: %s" % (item, str(err))) - beparams[item] = val - if item in (constants.BE_AUTO_BALANCE): - val = beparams[item] - if not isinstance(val, bool): - if val == constants.VALUE_TRUE: - beparams[item] = True - elif val == constants.VALUE_FALSE: - beparams[item] = False - else: - raise errors.OpPrereqError("Invalid %s value: %s" % (item, val)) - - def NiceSort(name_list): """Sort a list of strings based on digit and non-digit groupings. @@ -750,7 +795,7 @@ def ParseUnit(input_string): is always an int in MiB. """ - m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', input_string) + m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string)) if not m: raise errors.UnitParseError("Invalid format") @@ -862,6 +907,7 @@ def SetEtcHostsEntry(file_name, ip, hostname, aliases): @param aliases: the list of aliases to add for the hostname """ + # FIXME: use WriteFile + fn rather than duplicating its efforts # Ensure aliases are unique aliases = UniqueSequence([hostname] + aliases)[1:] @@ -884,6 +930,7 @@ def SetEtcHostsEntry(file_name, ip, hostname, aliases): out.flush() os.fsync(out) + os.chmod(tmpname, 0644) os.rename(tmpname, file_name) finally: f.close() @@ -917,6 +964,7 @@ def RemoveEtcHostsEntry(file_name, hostname): @param hostname: the hostname to be removed """ + # FIXME: use WriteFile + fn rather than duplicating its efforts fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name)) try: out = os.fdopen(fd, 'w') @@ -938,6 +986,7 @@ def RemoveEtcHostsEntry(file_name, hostname): out.flush() os.fsync(out) + os.chmod(tmpname, 0644) os.rename(tmpname, file_name) finally: f.close() @@ -1142,7 +1191,25 @@ def GenerateSecret(): @return: a sha1 hexdigest of a block of 64 random bytes """ - return sha.new(os.urandom(64)).hexdigest() + return sha1(os.urandom(64)).hexdigest() + + +def EnsureDirs(dirs): + """Make required directories, if they don't exist. + + @param dirs: list of tuples (dir_name, dir_mode) + @type dirs: list of (string, integer) + + """ + for dir_name, dir_mode in dirs: + try: + os.mkdir(dir_name, dir_mode) + except EnvironmentError, err: + if err.errno != errno.EEXIST: + raise errors.GenericError("Cannot create needed directory" + " '%s': %s" % (dir_name, err)) + if not os.path.isdir(dir_name): + raise errors.GenericError("%s is not a directory" % dir_name) def ReadFile(file_name, size=None): @@ -1178,7 +1245,7 @@ def WriteFile(file_name, fn=None, data=None, mtime/atime of the file. If the function doesn't raise an exception, it has succeeded and the - target file has the new contents. If the file has raised an + target file has the new contents. If the function has raised an exception, an existing target file should be unmodified and the temporary file should be removed. @@ -1187,7 +1254,7 @@ def WriteFile(file_name, fn=None, data=None, @type fn: callable @param fn: content writing function, called with file descriptor as parameter - @type data: sr + @type data: str @param data: contents of the file @type mode: int @param mode: file mode @@ -1210,7 +1277,7 @@ def WriteFile(file_name, fn=None, data=None, @return: None if the 'close' parameter evaluates to True, otherwise the file descriptor - @raise errors.ProgrammerError: if an of the arguments are not valid + @raise errors.ProgrammerError: if any of the arguments are not valid """ if not os.path.isabs(file_name): @@ -1229,6 +1296,7 @@ def WriteFile(file_name, fn=None, data=None, dir_name, base_name = os.path.split(file_name) fd, new_name = tempfile.mkstemp('.new', base_name, dir_name) + do_remove = True # here we need to make sure we remove the temp file, if any error # leaves it in place try: @@ -1249,13 +1317,15 @@ def WriteFile(file_name, fn=None, data=None, os.utime(new_name, (atime, mtime)) if not dry_run: os.rename(new_name, file_name) + do_remove = False finally: if close: os.close(fd) result = None else: result = fd - RemoveFile(new_name) + if do_remove: + RemoveFile(new_name) return result @@ -1409,7 +1479,7 @@ def Daemonize(logfile): @type logfile: str @param logfile: the logfile to which we should redirect stdout/stderr @rtype: int - @returns: the value zero + @return: the value zero """ UMASK = 077 @@ -1652,7 +1722,8 @@ def GetNodeDaemonPort(): return port -def SetupLogging(logfile, debug=False, stderr_logging=False, program=""): +def SetupLogging(logfile, debug=False, stderr_logging=False, program="", + multithreaded=False): """Configures the logging module. @type logfile: str @@ -1664,16 +1735,18 @@ def SetupLogging(logfile, debug=False, stderr_logging=False, program=""): @param stderr_logging: whether we should also log to the standard error @type program: str @param program: the name under which we should log messages + @type multithreaded: boolean + @param multithreaded: if True, will add the thread name to the log file @raise EnvironmentError: if we can't open the log file and stderr logging is disabled """ - fmt = "%(asctime)s: " + program + " " + fmt = "%(asctime)s: " + program + " pid=%(process)d" + if multithreaded: + fmt += "/%(threadName)s" if debug: - fmt += ("pid=%(process)d/%(threadName)s %(levelname)s" - " %(module)s:%(lineno)s %(message)s") - else: - fmt += "pid=%(process)d %(levelname)s %(message)s" + fmt += " %(module)s:%(lineno)s" + fmt += " %(levelname)s %(message)s" formatter = logging.Formatter(fmt) root_logger = logging.getLogger("") @@ -1705,13 +1778,77 @@ def SetupLogging(logfile, debug=False, stderr_logging=False, program=""): else: logfile_handler.setLevel(logging.INFO) root_logger.addHandler(logfile_handler) - except EnvironmentError, err: + except EnvironmentError: if stderr_logging: logging.exception("Failed to enable logging to file '%s'", logfile) else: # we need to re-raise the exception raise +def IsNormAbsPath(path): + """Check whether a path is absolute and also normalized + + This avoids things like /dir/../../other/path to be valid. + + """ + return os.path.normpath(path) == path and os.path.isabs(path) + +def TailFile(fname, lines=20): + """Return the last lines from a file. + + @note: this function will only read and parse the last 4KB of + the file; if the lines are very long, it could be that less + than the requested number of lines are returned + + @param fname: the file name + @type lines: int + @param lines: the (maximum) number of lines to return + + """ + fd = open(fname, "r") + try: + fd.seek(0, 2) + pos = fd.tell() + pos = max(0, pos-4096) + fd.seek(pos, 0) + raw_data = fd.read() + finally: + fd.close() + + rows = raw_data.splitlines() + return rows[-lines:] + + +def SafeEncode(text): + """Return a 'safe' version of a source string. + + This function mangles the input string and returns a version that + should be safe to disply/encode as ASCII. To this end, we first + convert it to ASCII using the 'backslashreplace' encoding which + should get rid of any non-ASCII chars, and then we again encode it + via 'string_escape' which converts '\n' into '\\n' so that log + messages remain one-line. + + @type text: str or unicode + @param text: input data + @rtype: str + @return: a safe version of text + + """ + text = text.encode('ascii', 'backslashreplace') + text = text.encode('string_escape') + return text + + +def CommaJoin(names): + """Nicely join a set of identifiers. + + @param names: set, list or tuple + @return: a string with the formatted results + + """ + return ", ".join(["'%s'" % val for val in names]) + def LockedMethod(fn): """Synchronized object access decorator.