X-Git-Url: https://code.grnet.gr/git/ganeti-local/blobdiff_plain/cea881e554d3bb44d8ea83db6275ceb124915c34..e8d61457f16974cbf0d77479f9d06f4c6345a02e:/lib/utils.py diff --git a/lib/utils.py b/lib/utils.py index be3a080..0256925 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -1,7 +1,7 @@ # # -# Copyright (C) 2006, 2007 Google Inc. +# Copyright (C) 2006, 2007, 2010 Google Inc. # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -28,6 +28,7 @@ the command line scripts. import os +import sys import time import subprocess import re @@ -43,20 +44,23 @@ import resource import logging import logging.handlers import signal +import OpenSSL import datetime import calendar +import hmac import collections from cStringIO import StringIO try: - from hashlib import sha1 + # pylint: disable-msg=F0401 + import ctypes except ImportError: - import sha - sha1 = sha.new + ctypes = None from ganeti import errors from ganeti import constants +from ganeti import compat _locksheld = [] @@ -69,6 +73,23 @@ no_fork = False _RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid" +HEX_CHAR_RE = r"[a-zA-Z0-9]" +VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S) +X509_SIGNATURE = re.compile(r"^%s:\s*(?P%s+)/(?P%s+)$" % + (re.escape(constants.X509_CERT_SIGNATURE_HEADER), + HEX_CHAR_RE, HEX_CHAR_RE), + re.S | re.I) + +_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$") + +# Certificate verification results +(CERT_WARNING, + CERT_ERROR) = range(1, 3) + +# Flags for mlockall() (from bits/mman.h) +_MCL_CURRENT = 1 +_MCL_FUTURE = 2 + class RunResult(object): """Holds the result of running external programs. @@ -121,7 +142,23 @@ class RunResult(object): output = property(_GetOutput, None, None, "Return full output") -def RunCmd(cmd, env=None, output=None, cwd='/', reset_env=False): +def _BuildCmdEnvironment(env, reset): + """Builds the environment for an external program. + + """ + if reset: + cmd_env = {} + else: + cmd_env = os.environ.copy() + cmd_env["LC_ALL"] = "C" + + if env is not None: + cmd_env.update(env) + + return cmd_env + + +def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False): """Execute a (shell) command. The command should not read from its standard input, as it will be @@ -130,7 +167,7 @@ def RunCmd(cmd, env=None, output=None, cwd='/', reset_env=False): @type cmd: string or list @param cmd: Command to run @type env: dict - @param env: Additional environment + @param env: Additional environment variables @type output: str @param output: if desired, the output of the command can be saved in a file instead of the RunResult instance; this @@ -148,23 +185,20 @@ def RunCmd(cmd, env=None, output=None, cwd='/', reset_env=False): if no_fork: raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled") - if isinstance(cmd, list): - cmd = [str(val) for val in cmd] - strcmd = " ".join(cmd) - shell = False - else: + if isinstance(cmd, basestring): strcmd = cmd shell = True - logging.debug("RunCmd '%s'", strcmd) + else: + cmd = [str(val) for val in cmd] + strcmd = ShellQuoteArgs(cmd) + shell = False - if not reset_env: - cmd_env = os.environ.copy() - cmd_env["LC_ALL"] = "C" + if output: + logging.debug("RunCmd %s, output file '%s'", strcmd, output) else: - cmd_env = {} + logging.debug("RunCmd %s", strcmd) - if env is not None: - cmd_env.update(env) + cmd_env = _BuildCmdEnvironment(env, reset_env) try: if output is None: @@ -189,6 +223,201 @@ def RunCmd(cmd, env=None, output=None, cwd='/', reset_env=False): return RunResult(exitcode, signal_, out, err, strcmd) +def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None, + pidfile=None): + """Start a daemon process after forking twice. + + @type cmd: string or list + @param cmd: Command to run + @type env: dict + @param env: Additional environment variables + @type cwd: string + @param cwd: Working directory for the program + @type output: string + @param output: Path to file in which to save the output + @type output_fd: int + @param output_fd: File descriptor for output + @type pidfile: string + @param pidfile: Process ID file + @rtype: int + @return: Daemon process ID + @raise errors.ProgrammerError: if we call this when forks are disabled + + """ + if no_fork: + raise errors.ProgrammerError("utils.StartDaemon() called with fork()" + " disabled") + + if output and not (bool(output) ^ (output_fd is not None)): + raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be" + " specified") + + if isinstance(cmd, basestring): + cmd = ["/bin/sh", "-c", cmd] + + strcmd = ShellQuoteArgs(cmd) + + if output: + logging.debug("StartDaemon %s, output file '%s'", strcmd, output) + else: + logging.debug("StartDaemon %s", strcmd) + + cmd_env = _BuildCmdEnvironment(env, False) + + # Create pipe for sending PID back + (pidpipe_read, pidpipe_write) = os.pipe() + try: + try: + # Create pipe for sending error messages + (errpipe_read, errpipe_write) = os.pipe() + try: + try: + # First fork + pid = os.fork() + if pid == 0: + try: + # Child process, won't return + _StartDaemonChild(errpipe_read, errpipe_write, + pidpipe_read, pidpipe_write, + cmd, cmd_env, cwd, + output, output_fd, pidfile) + finally: + # Well, maybe child process failed + os._exit(1) # pylint: disable-msg=W0212 + finally: + _CloseFDNoErr(errpipe_write) + + # Wait for daemon to be started (or an error message to arrive) and read + # up to 100 KB as an error message + errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024) + finally: + _CloseFDNoErr(errpipe_read) + finally: + _CloseFDNoErr(pidpipe_write) + + # Read up to 128 bytes for PID + pidtext = RetryOnSignal(os.read, pidpipe_read, 128) + finally: + _CloseFDNoErr(pidpipe_read) + + # Try to avoid zombies by waiting for child process + try: + os.waitpid(pid, 0) + except OSError: + pass + + if errormsg: + raise errors.OpExecError("Error when starting daemon process: %r" % + errormsg) + + try: + return int(pidtext) + except (ValueError, TypeError), err: + raise errors.OpExecError("Error while trying to parse PID %r: %s" % + (pidtext, err)) + + +def _StartDaemonChild(errpipe_read, errpipe_write, + pidpipe_read, pidpipe_write, + args, env, cwd, + output, fd_output, pidfile): + """Child process for starting daemon. + + """ + try: + # Close parent's side + _CloseFDNoErr(errpipe_read) + _CloseFDNoErr(pidpipe_read) + + # First child process + os.chdir("/") + os.umask(077) + os.setsid() + + # And fork for the second time + pid = os.fork() + if pid != 0: + # Exit first child process + os._exit(0) # pylint: disable-msg=W0212 + + # Make sure pipe is closed on execv* (and thereby notifies original process) + SetCloseOnExecFlag(errpipe_write, True) + + # List of file descriptors to be left open + noclose_fds = [errpipe_write] + + # Open PID file + if pidfile: + try: + # TODO: Atomic replace with another locked file instead of writing into + # it after creating + fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600) + + # Lock the PID file (and fail if not possible to do so). Any code + # wanting to send a signal to the daemon should try to lock the PID + # file before reading it. If acquiring the lock succeeds, the daemon is + # no longer running and the signal should not be sent. + LockFile(fd_pidfile) + + os.write(fd_pidfile, "%d\n" % os.getpid()) + except Exception, err: + raise Exception("Creating and locking PID file failed: %s" % err) + + # Keeping the file open to hold the lock + noclose_fds.append(fd_pidfile) + + SetCloseOnExecFlag(fd_pidfile, False) + else: + fd_pidfile = None + + # Open /dev/null + fd_devnull = os.open(os.devnull, os.O_RDWR) + + assert not output or (bool(output) ^ (fd_output is not None)) + + if fd_output is not None: + pass + elif output: + # Open output file + try: + # TODO: Implement flag to set append=yes/no + fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600) + except EnvironmentError, err: + raise Exception("Opening output file failed: %s" % err) + else: + fd_output = fd_devnull + + # Redirect standard I/O + os.dup2(fd_devnull, 0) + os.dup2(fd_output, 1) + os.dup2(fd_output, 2) + + # Send daemon PID to parent + RetryOnSignal(os.write, pidpipe_write, str(os.getpid())) + + # Close all file descriptors except stdio and error message pipe + CloseFDs(noclose_fds=noclose_fds) + + # Change working directory + os.chdir(cwd) + + if env is None: + os.execvp(args[0], args) + else: + os.execvpe(args[0], args, env) + except: # pylint: disable-msg=W0702 + try: + # Report errors to original process + buf = str(sys.exc_info()[1]) + + RetryOnSignal(os.write, errpipe_write, buf) + except: # pylint: disable-msg=W0702 + # Ignore errors in error handling + pass + + os._exit(1) # pylint: disable-msg=W0212 + + def _RunCmdPipe(cmd, env, via_shell, cwd): """Run a command and return its output. @@ -222,20 +451,10 @@ def _RunCmdPipe(cmd, env, via_shell, cwd): child.stderr.fileno(): (err, child.stderr), } for fd in fdmap: - status = fcntl.fcntl(fd, fcntl.F_GETFL) - fcntl.fcntl(fd, fcntl.F_SETFL, status | os.O_NONBLOCK) + SetNonblockFlag(fd, True) while fdmap: - try: - pollresult = poller.poll() - except EnvironmentError, eerr: - if eerr.errno == errno.EINTR: - continue - raise - except select.error, serr: - if serr[0] == errno.EINTR: - continue - raise + pollresult = RetryOnSignal(poller.poll) for fd, event in pollresult: if event & select.POLLIN or event & select.POLLPRI: @@ -291,6 +510,61 @@ def _RunCmdFile(cmd, env, via_shell, output, cwd): return status +def SetCloseOnExecFlag(fd, enable): + """Sets or unsets the close-on-exec flag on a file descriptor. + + @type fd: int + @param fd: File descriptor + @type enable: bool + @param enable: Whether to set or unset it. + + """ + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + + if enable: + flags |= fcntl.FD_CLOEXEC + else: + flags &= ~fcntl.FD_CLOEXEC + + fcntl.fcntl(fd, fcntl.F_SETFD, flags) + + +def SetNonblockFlag(fd, enable): + """Sets or unsets the O_NONBLOCK flag on on a file descriptor. + + @type fd: int + @param fd: File descriptor + @type enable: bool + @param enable: Whether to set or unset it + + """ + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + + if enable: + flags |= os.O_NONBLOCK + else: + flags &= ~os.O_NONBLOCK + + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +def RetryOnSignal(fn, *args, **kwargs): + """Calls a function again if it failed due to EINTR. + + """ + while True: + try: + return fn(*args, **kwargs) + except EnvironmentError, err: + if err.errno != errno.EINTR: + raise + except (socket.error, select.error), err: + # In python 2.6 and above select.error is an IOError, so it's handled + # above, in 2.5 and below it's not, and it's handled here. + if not (err.args and err.args[0] == errno.EINTR): + raise + + def RunParts(dir_name, env=None, reset_env=False): """Run Scripts or programs in a directory @@ -345,6 +619,24 @@ def RemoveFile(filename): raise +def RemoveDir(dirname): + """Remove an empty directory. + + Remove a directory, ignoring non-existing ones. + Other errors are passed. This includes the case, + where the directory is not empty, so it can't be removed. + + @type dirname: str + @param dirname: the empty directory to be removed + + """ + try: + os.rmdir(dirname) + except OSError, err: + if err.errno != errno.ENOENT: + raise + + def RenameFile(old, new, mkdir=False, mkdir_mode=0750): """Renames a file. @@ -430,7 +722,7 @@ def _FingerprintFile(filename): f = open(filename) - fp = sha1() + fp = compat.sha1_hash() while True: data = f.read(4096) if not data: @@ -493,8 +785,10 @@ def ForceDictType(target, key_types, allowed_values=None): msg = "'%s' has non-enforceable type %s" % (key, ktype) raise errors.ProgrammerError(msg) - if ktype == constants.VTYPE_STRING: - if not isinstance(target[key], basestring): + if ktype in (constants.VTYPE_STRING, constants.VTYPE_MAYBE_STRING): + if target[key] is None and ktype == constants.VTYPE_MAYBE_STRING: + pass + elif not isinstance(target[key], basestring): if isinstance(target[key], bool) and not target[key]: target[key] = '' else: @@ -528,6 +822,17 @@ def ForceDictType(target, key_types, allowed_values=None): raise errors.TypeEnforcementError(msg) +def _GetProcStatusPath(pid): + """Returns the path for a PID's proc status file. + + @type pid: int + @param pid: Process ID + @rtype: string + + """ + return "/proc/%d/status" % pid + + def IsProcessAlive(pid): """Check if a given pid exists on the system. @@ -539,17 +844,113 @@ def IsProcessAlive(pid): @return: True if the process exists """ + def _TryStat(name): + try: + os.stat(name) + return True + except EnvironmentError, err: + if err.errno in (errno.ENOENT, errno.ENOTDIR): + return False + elif err.errno == errno.EINVAL: + raise RetryAgain(err) + raise + + assert isinstance(pid, int), "pid must be an integer" if pid <= 0: return False + # /proc in a multiprocessor environment can have strange behaviors. + # Retry the os.stat a few times until we get a good result. + try: + return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5, + args=[_GetProcStatusPath(pid)]) + except RetryTimeout, err: + err.RaiseInner() + + +def _ParseSigsetT(sigset): + """Parse a rendered sigset_t value. + + This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t + function. + + @type sigset: string + @param sigset: Rendered signal set from /proc/$pid/status + @rtype: set + @return: Set of all enabled signal numbers + + """ + result = set() + + signum = 0 + for ch in reversed(sigset): + chv = int(ch, 16) + + # The following could be done in a loop, but it's easier to read and + # understand in the unrolled form + if chv & 1: + result.add(signum + 1) + if chv & 2: + result.add(signum + 2) + if chv & 4: + result.add(signum + 3) + if chv & 8: + result.add(signum + 4) + + signum += 4 + + return result + + +def _GetProcStatusField(pstatus, field): + """Retrieves a field from the contents of a proc status file. + + @type pstatus: string + @param pstatus: Contents of /proc/$pid/status + @type field: string + @param field: Name of field whose value should be returned + @rtype: string + + """ + for line in pstatus.splitlines(): + parts = line.split(":", 1) + + if len(parts) < 2 or parts[0] != field: + continue + + return parts[1].strip() + + return None + + +def IsProcessHandlingSignal(pid, signum, status_path=None): + """Checks whether a process is handling a signal. + + @type pid: int + @param pid: Process ID + @type signum: int + @param signum: Signal number + @rtype: bool + + """ + if status_path is None: + status_path = _GetProcStatusPath(pid) + try: - os.stat("/proc/%d/status" % pid) - return True + proc_status = ReadFile(status_path) except EnvironmentError, err: - if err.errno in (errno.ENOENT, errno.ENOTDIR): + # In at least one case, reading /proc/$pid/status failed with ESRCH. + if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH): return False raise + sigcgt = _GetProcStatusField(proc_status, "SigCgt") + if sigcgt is None: + raise RuntimeError("%s is missing 'SigCgt' field" % status_path) + + # Now check whether signal is handled + return signum in _ParseSigsetT(sigcgt) + def ReadPidFile(pidfile): """Read a pid from a file. @@ -562,7 +963,7 @@ def ReadPidFile(pidfile): """ try: - raw_data = ReadFile(pidfile) + raw_data = ReadOneLineFile(pidfile) except EnvironmentError, err: if err.errno != errno.ENOENT: logging.exception("Can't read pid file") @@ -577,6 +978,37 @@ def ReadPidFile(pidfile): return pid +def ReadLockedPidFile(path): + """Reads a locked PID file. + + This can be used together with L{StartDaemon}. + + @type path: string + @param path: Path to PID file + @return: PID as integer or, if file was unlocked or couldn't be opened, None + + """ + try: + fd = os.open(path, os.O_RDONLY) + except EnvironmentError, err: + if err.errno == errno.ENOENT: + # PID file doesn't exist + return None + raise + + try: + try: + # Try to acquire lock + LockFile(fd) + except errors.LockError: + # Couldn't lock, daemon is running + return int(os.read(fd, 100)) + finally: + os.close(fd) + + return None + + def MatchNameComponent(key, name_list, case_sensitive=True): """Try to match a name against a list. @@ -623,92 +1055,28 @@ def MatchNameComponent(key, name_list, case_sensitive=True): return None -class HostInfo: - """Class implementing resolver and hostname functionality - - """ - _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$") - - def __init__(self, name=None): - """Initialize the host name object. - - If the name argument is not passed, it will use this system's - name. - - """ - if name is None: - name = self.SysName() - - self.query = name - self.name, self.aliases, self.ipaddrs = self.LookupHostname(name) - self.ip = self.ipaddrs[0] - - def ShortName(self): - """Returns the hostname without domain. - - """ - return self.name.split('.')[0] - - @staticmethod - def SysName(): - """Return the current system's name. - - This is simply a wrapper over C{socket.gethostname()}. - - """ - return socket.gethostname() - - @staticmethod - def LookupHostname(hostname): - """Look up hostname - - @type hostname: str - @param hostname: hostname to look up +def ValidateServiceName(name): + """Validate the given service name. - @rtype: tuple - @return: a tuple (name, aliases, ipaddrs) as returned by - C{socket.gethostbyname_ex} - @raise errors.ResolverError: in case of errors in resolving + @type name: number or string + @param name: Service name or port specification - """ - try: - result = socket.gethostbyname_ex(hostname) - except socket.gaierror, err: - # hostname not found in DNS - raise errors.ResolverError(hostname, err.args[0], err.args[1]) - - return result - - @classmethod - def NormalizeName(cls, hostname): - """Validate and normalize the given hostname. - - @attention: the validation is a bit more relaxed than the standards - require; most importantly, we allow underscores in names - @raise errors.OpPrereqError: when the name is not valid + """ + try: + numport = int(name) + except (ValueError, TypeError): + # Non-numeric service name + valid = _VALID_SERVICE_NAME_RE.match(name) + else: + # Numeric port (protocols other than TCP or UDP might need adjustments + # here) + valid = (numport >= 0 and numport < (1 << 16)) - """ - hostname = hostname.lower() - if (not cls._VALID_NAME_RE.match(hostname) or - # double-dots, meaning empty label - ".." in hostname or - # empty initial label - hostname.startswith(".")): - raise errors.OpPrereqError("Invalid hostname '%s'" % hostname, - errors.ECODE_INVAL) - if hostname.endswith("."): - hostname = hostname.rstrip(".") - return hostname - - -def GetHostInfo(name=None): - """Lookup host name and raise an OpPrereqError for failures""" + if not valid: + raise errors.OpPrereqError("Invalid service name '%s'" % name, + errors.ECODE_INVAL) - try: - return HostInfo(name) - except errors.ResolverError, err: - raise errors.OpPrereqError("The given name (%s) does not resolve: %s" % - (err[0], err[2]), errors.ECODE_RESOLVER) + return name def ListVolumeGroups(): @@ -810,24 +1178,6 @@ def TryConvert(fn, val): return nv -def IsValidIP(ip): - """Verifies the syntax of an IPv4 address. - - This function checks if the IPv4 address passes is valid or not based - on syntax (not IP range, class calculations, etc.). - - @type ip: str - @param ip: the address to be checked - @rtype: a regular expression match object - @return: a regular expression match object, or None if the - address is not valid - - """ - unit = "(0|[1-9]\d{0,2})" - #TODO: convert and return only boolean - return re.match("^%s\.%s\.%s\.%s$" % (unit, unit, unit, unit), ip) - - def IsValidShellParam(word): """Verifies is the given word is safe from the shell's p.o.v. @@ -950,18 +1300,61 @@ def ParseUnit(input_string): return value -def AddAuthorizedKey(file_name, key): +def ParseCpuMask(cpu_mask): + """Parse a CPU mask definition and return the list of CPU IDs. + + CPU mask format: comma-separated list of CPU IDs + or dash-separated ID ranges + Example: "0-2,5" -> "0,1,2,5" + + @type cpu_mask: str + @param cpu_mask: CPU mask definition + @rtype: list of int + @return: list of CPU IDs + + """ + if not cpu_mask: + return [] + cpu_list = [] + for range_def in cpu_mask.split(","): + boundaries = range_def.split("-") + n_elements = len(boundaries) + if n_elements > 2: + raise errors.ParseError("Invalid CPU ID range definition" + " (only one hyphen allowed): %s" % range_def) + try: + lower = int(boundaries[0]) + except (ValueError, TypeError), err: + raise errors.ParseError("Invalid CPU ID value for lower boundary of" + " CPU ID range: %s" % str(err)) + try: + higher = int(boundaries[-1]) + except (ValueError, TypeError), err: + raise errors.ParseError("Invalid CPU ID value for higher boundary of" + " CPU ID range: %s" % str(err)) + if lower > higher: + raise errors.ParseError("Invalid CPU ID range definition" + " (%d > %d): %s" % (lower, higher, range_def)) + cpu_list.extend(range(lower, higher + 1)) + return cpu_list + + +def AddAuthorizedKey(file_obj, key): """Adds an SSH public key to an authorized_keys file. - @type file_name: str - @param file_name: path to authorized_keys file + @type file_obj: str or file handle + @param file_obj: path to authorized_keys file @type key: str @param key: string containing key """ key_fields = key.split() - f = open(file_name, 'a+') + if isinstance(file_obj, basestring): + f = open(file_obj, 'a+') + else: + f = file_obj + try: nl = True for line in f: @@ -1067,8 +1460,8 @@ def AddHostToEtcHosts(hostname): L{constants.ETC_HOSTS} """ - hi = HostInfo(name=hostname) - SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()]) + SetEtcHostsEntry(constants.ETC_HOSTS, hostname.ip, hostname.name, + [hostname.name.split(".")[0]]) def RemoveEtcHostsEntry(file_name, hostname): @@ -1124,9 +1517,8 @@ def RemoveHostFromEtcHosts(hostname): L{constants.ETC_HOSTS} """ - hi = HostInfo(name=hostname) - RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name) - RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName()) + RemoveEtcHostsEntry(constants.ETC_HOSTS, hostname) + RemoveEtcHostsEntry(constants.ETC_HOSTS, hostname.split(".")[0]) def TimestampForFilename(): @@ -1199,66 +1591,46 @@ def ShellQuoteArgs(args): return ' '.join([ShellQuote(i) for i in args]) -def TcpPing(target, port, timeout=10, live_port_needed=False, source=None): - """Simple ping implementation using TCP connect(2). - - Check if the given IP is reachable by doing attempting a TCP connect - to it. - - @type target: str - @param target: the IP or hostname to ping - @type port: int - @param port: the port to connect to - @type timeout: int - @param timeout: the timeout on the connection attempt - @type live_port_needed: boolean - @param live_port_needed: whether a closed port will cause the - function to return failure, as if there was a timeout - @type source: str or None - @param source: if specified, will cause the connect to be made - from this specific source address; failures to bind other - than C{EADDRNOTAVAIL} will be ignored +class ShellWriter: + """Helper class to write scripts with indentation. """ - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + INDENT_STR = " " - success = False + def __init__(self, fh): + """Initializes this class. - if source is not None: - try: - sock.bind((source, 0)) - except socket.error, (errcode, _): - if errcode == errno.EADDRNOTAVAIL: - success = False + """ + self._fh = fh + self._indent = 0 - sock.settimeout(timeout) + def IncIndent(self): + """Increase indentation level by 1. - try: - sock.connect((target, port)) - sock.close() - success = True - except socket.timeout: - success = False - except socket.error, (errcode, _): - success = (not live_port_needed) and (errcode == errno.ECONNREFUSED) + """ + self._indent += 1 - return success + def DecIndent(self): + """Decrease indentation level by 1. + """ + assert self._indent > 0 + self._indent -= 1 -def OwnIpAddress(address): - """Check if the current host has the the given IP address. + def Write(self, txt, *args): + """Write line to output file. - Currently this is done by TCP-pinging the address from the loopback - address. + """ + assert self._indent >= 0 - @type address: string - @param address: the address to check - @rtype: bool - @return: True if we own the address + self._fh.write(self._indent * self.INDENT_STR) - """ - return TcpPing(address, constants.DEFAULT_NODED_PORT, - source=constants.LOCALHOST_IP_ADDRESS) + if args: + self._fh.write(txt % args) + else: + self._fh.write(txt) + + self._fh.write("\n") def ListVisibleFiles(path): @@ -1275,7 +1647,6 @@ def ListVisibleFiles(path): raise errors.ProgrammerError("Path passed to ListVisibleFiles is not" " absolute/normalized: '%s'" % path) files = [i for i in os.listdir(path) if not i.startswith(".")] - files.sort() return files @@ -1340,6 +1711,11 @@ def EnsureDirs(dirs): if err.errno != errno.EEXIST: raise errors.GenericError("Cannot create needed directory" " '%s': %s" % (dir_name, err)) + try: + os.chmod(dir_name, dir_mode) + except EnvironmentError, err: + raise errors.GenericError("Cannot change directory permissions on" + " '%s': %s" % (dir_name, err)) if not os.path.isdir(dir_name): raise errors.GenericError("%s is not a directory" % dir_name) @@ -1459,6 +1835,24 @@ def WriteFile(file_name, fn=None, data=None, return result +def ReadOneLineFile(file_name, strict=False): + """Return the first non-empty line from a file. + + @type strict: boolean + @param strict: if True, abort if the file has more than one + non-empty line + + """ + file_lines = ReadFile(file_name).splitlines() + full_lines = filter(bool, file_lines) + if not file_lines or not full_lines: + raise errors.GenericError("No data in one-liner file %s" % file_name) + elif strict and len(full_lines) > 1: + raise errors.GenericError("Too many lines in one-liner file %s" % + file_name) + return full_lines[0] + + def FirstFree(seq, base=0): """Returns the first non-existing integer from seq. @@ -1682,7 +2076,41 @@ def CloseFDs(noclose_fds=None): _CloseFDNoErr(fd) -def Daemonize(logfile): +def Mlockall(_ctypes=ctypes): + """Lock current process' virtual address space into RAM. + + This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE), + see mlock(2) for more details. This function requires ctypes module. + + @raises errors.NoCtypesError: if ctypes module is not found + + """ + if _ctypes is None: + raise errors.NoCtypesError() + + libc = _ctypes.cdll.LoadLibrary("libc.so.6") + if libc is None: + logging.error("Cannot set memory lock, ctypes cannot load libc") + return + + # Some older version of the ctypes module don't have built-in functionality + # to access the errno global variable, where function error codes are stored. + # By declaring this variable as a pointer to an integer we can then access + # its value correctly, should the mlockall call fail, in order to see what + # the actual error code was. + # pylint: disable-msg=W0212 + libc.__errno_location.restype = _ctypes.POINTER(_ctypes.c_int) + + if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE): + # pylint: disable-msg=W0212 + logging.error("Cannot set memory lock: %s", + os.strerror(libc.__errno_location().contents.value)) + return + + logging.debug("Memory lock set") + + +def Daemonize(logfile, run_uid, run_gid): """Daemonize the current process. This detaches the current process from the controlling terminal and @@ -1690,6 +2118,10 @@ def Daemonize(logfile): @type logfile: str @param logfile: the logfile to which we should redirect stdout/stderr + @type run_uid: int + @param run_uid: Run the child under this uid + @type run_gid: int + @param run_gid: Run the child under this gid @rtype: int @return: the value zero @@ -1703,6 +2135,11 @@ def Daemonize(logfile): pid = os.fork() if (pid == 0): # The first child. os.setsid() + # FIXME: When removing again and moving to start-stop-daemon privilege drop + # make sure to check for config permission and bail out when invoked + # with wrong user. + os.setgid(run_gid) + os.setuid(run_uid) # this might fail pid = os.fork() # Fork a second child. if (pid == 0): # The second child. @@ -1751,6 +2188,19 @@ def EnsureDaemon(name): return True +def StopDaemon(name): + """Stop daemon + + """ + result = RunCmd([constants.DAEMON_UTIL, "stop", name]) + if result.failed: + logging.error("Can't stop daemon '%s', failure %s, output: %s", + name, result.fail_reason, result.output) + return False + + return True + + def WritePidFile(name): """Write the current process pidfile. @@ -1807,8 +2257,7 @@ def KillProcess(pid, signal_=signal.SIGTERM, timeout=30, """ def _helper(pid, signal_, wait): """Simple helper to encapsulate the kill/waitpid sequence""" - os.kill(pid, signal_) - if wait: + if IgnoreProcessNotFound(os.kill, pid, signal_) and wait: try: os.waitpid(pid, os.WNOHANG) except OSError: @@ -1946,32 +2395,43 @@ def MergeTime(timetuple): return float(seconds) + (float(microseconds) * 0.000001) -def GetDaemonPort(daemon_name): - """Get the daemon port for this cluster. +class LogFileHandler(logging.FileHandler): + """Log handler that doesn't fallback to stderr. - Note that this routine does not read a ganeti-specific file, but - instead uses C{socket.getservbyname} to allow pre-customization of - this parameter outside of Ganeti. - - @type daemon_name: string - @param daemon_name: daemon name (in constants.DAEMONS_PORTS) - @rtype: int + When an error occurs while writing on the logfile, logging.FileHandler tries + to log on stderr. This doesn't work in ganeti since stderr is redirected to + the logfile. This class avoids failures reporting errors to /dev/console. """ - if daemon_name not in constants.DAEMONS_PORTS: - raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name) + def __init__(self, filename, mode="a", encoding=None): + """Open the specified file and use it as the stream for logging. - (proto, default_port) = constants.DAEMONS_PORTS[daemon_name] - try: - port = socket.getservbyname(daemon_name, proto) - except socket.error: - port = default_port + Also open /dev/console to report errors while logging. + + """ + logging.FileHandler.__init__(self, filename, mode, encoding) + self.console = open(constants.DEV_CONSOLE, "a") - return port + def handleError(self, record): # pylint: disable-msg=C0103 + """Handle errors which occur during an emit() call. + + Try to handle errors with FileHandler method, if it fails write to + /dev/console. + + """ + try: + logging.FileHandler.handleError(self, record) + except Exception: # pylint: disable-msg=W0703 + try: + self.console.write("Cannot log message:\n%s\n" % self.format(record)) + except Exception: # pylint: disable-msg=W0703 + # Log handler tried everything it could, now just give up + pass def SetupLogging(logfile, debug=0, stderr_logging=False, program="", - multithreaded=False, syslog=constants.SYSLOG_USAGE): + multithreaded=False, syslog=constants.SYSLOG_USAGE, + console_logging=False): """Configures the logging module. @type logfile: str @@ -1990,6 +2450,9 @@ def SetupLogging(logfile, debug=0, stderr_logging=False, program="", - if no, syslog is not used - if yes, syslog is used (in addition to file-logging) - if only, only syslog is used + @type console_logging: boolean + @param console_logging: if True, will use a FileHandler which falls back to + the system console if logging fails @raise EnvironmentError: if we can't open the log file and syslog/stderr logging is disabled @@ -2041,7 +2504,10 @@ def SetupLogging(logfile, debug=0, stderr_logging=False, program="", # the error if stderr_logging is True, and if false we re-raise the # exception since otherwise we could run but without any logs at all try: - logfile_handler = logging.FileHandler(logfile) + if console_logging: + logfile_handler = LogFileHandler(logfile) + else: + logfile_handler = logging.FileHandler(logfile) logfile_handler.setFormatter(formatter) if debug: logfile_handler.setLevel(logging.DEBUG) @@ -2121,6 +2587,13 @@ def TailFile(fname, lines=20): return rows[-lines:] +def FormatTimestampWithTZ(secs): + """Formats a Unix timestamp with the local timezone. + + """ + return time.strftime("%F %T %Z", time.gmtime(secs)) + + def _ParseAsn1Generalizedtime(value): """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL. @@ -2184,6 +2657,178 @@ def GetX509CertValidity(cert): return (not_before, not_after) +def _VerifyCertificateInner(expired, not_before, not_after, now, + warn_days, error_days): + """Verifies certificate validity. + + @type expired: bool + @param expired: Whether pyOpenSSL considers the certificate as expired + @type not_before: number or None + @param not_before: Unix timestamp before which certificate is not valid + @type not_after: number or None + @param not_after: Unix timestamp after which certificate is invalid + @type now: number + @param now: Current time as Unix timestamp + @type warn_days: number or None + @param warn_days: How many days before expiration a warning should be reported + @type error_days: number or None + @param error_days: How many days before expiration an error should be reported + + """ + if expired: + msg = "Certificate is expired" + + if not_before is not None and not_after is not None: + msg += (" (valid from %s to %s)" % + (FormatTimestampWithTZ(not_before), + FormatTimestampWithTZ(not_after))) + elif not_before is not None: + msg += " (valid from %s)" % FormatTimestampWithTZ(not_before) + elif not_after is not None: + msg += " (valid until %s)" % FormatTimestampWithTZ(not_after) + + return (CERT_ERROR, msg) + + elif not_before is not None and not_before > now: + return (CERT_WARNING, + "Certificate not yet valid (valid from %s)" % + FormatTimestampWithTZ(not_before)) + + elif not_after is not None: + remaining_days = int((not_after - now) / (24 * 3600)) + + msg = "Certificate expires in about %d days" % remaining_days + + if error_days is not None and remaining_days <= error_days: + return (CERT_ERROR, msg) + + if warn_days is not None and remaining_days <= warn_days: + return (CERT_WARNING, msg) + + return (None, None) + + +def VerifyX509Certificate(cert, warn_days, error_days): + """Verifies a certificate for LUVerifyCluster. + + @type cert: OpenSSL.crypto.X509 + @param cert: X509 certificate object + @type warn_days: number or None + @param warn_days: How many days before expiration a warning should be reported + @type error_days: number or None + @param error_days: How many days before expiration an error should be reported + + """ + # Depending on the pyOpenSSL version, this can just return (None, None) + (not_before, not_after) = GetX509CertValidity(cert) + + return _VerifyCertificateInner(cert.has_expired(), not_before, not_after, + time.time(), warn_days, error_days) + + +def SignX509Certificate(cert, key, salt): + """Sign a X509 certificate. + + An RFC822-like signature header is added in front of the certificate. + + @type cert: OpenSSL.crypto.X509 + @param cert: X509 certificate object + @type key: string + @param key: Key for HMAC + @type salt: string + @param salt: Salt for HMAC + @rtype: string + @return: Serialized and signed certificate in PEM format + + """ + if not VALID_X509_SIGNATURE_SALT.match(salt): + raise errors.GenericError("Invalid salt: %r" % salt) + + # Dumping as PEM here ensures the certificate is in a sane format + cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert) + + return ("%s: %s/%s\n\n%s" % + (constants.X509_CERT_SIGNATURE_HEADER, salt, + Sha1Hmac(key, cert_pem, salt=salt), + cert_pem)) + + +def _ExtractX509CertificateSignature(cert_pem): + """Helper function to extract signature from X509 certificate. + + """ + # Extract signature from original PEM data + for line in cert_pem.splitlines(): + if line.startswith("---"): + break + + m = X509_SIGNATURE.match(line.strip()) + if m: + return (m.group("salt"), m.group("sign")) + + raise errors.GenericError("X509 certificate signature is missing") + + +def LoadSignedX509Certificate(cert_pem, key): + """Verifies a signed X509 certificate. + + @type cert_pem: string + @param cert_pem: Certificate in PEM format and with signature header + @type key: string + @param key: Key for HMAC + @rtype: tuple; (OpenSSL.crypto.X509, string) + @return: X509 certificate object and salt + + """ + (salt, signature) = _ExtractX509CertificateSignature(cert_pem) + + # Load certificate + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem) + + # Dump again to ensure it's in a sane format + sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert) + + if not VerifySha1Hmac(key, sane_pem, signature, salt=salt): + raise errors.GenericError("X509 certificate signature is invalid") + + return (cert, salt) + + +def Sha1Hmac(key, text, salt=None): + """Calculates the HMAC-SHA1 digest of a text. + + HMAC is defined in RFC2104. + + @type key: string + @param key: Secret key + @type text: string + + """ + if salt: + salted_text = salt + text + else: + salted_text = text + + return hmac.new(key, salted_text, compat.sha1).hexdigest() + + +def VerifySha1Hmac(key, text, digest, salt=None): + """Verifies the HMAC-SHA1 digest of a text. + + HMAC is defined in RFC2104. + + @type key: string + @param key: Secret key + @type text: string + @type digest: string + @param digest: Expected digest + @rtype: bool + @return: Whether HMAC-SHA1 digest matches + + """ + return digest.lower() == Sha1Hmac(key, text, salt=salt).lower() + + def SafeEncode(text): """Return a 'safe' version of a source string. @@ -2304,6 +2949,26 @@ def CalculateDirectorySize(path): return BytesToMebibyte(size) +def GetMounts(filename=constants.PROC_MOUNTS): + """Returns the list of mounted filesystems. + + This function is Linux-specific. + + @param filename: path of mounts file (/proc/mounts by default) + @rtype: list of tuples + @return: list of mount entries (device, mountpoint, fstype, options) + + """ + # TODO(iustin): investigate non-Linux options (e.g. via mount output) + data = [] + mountlines = ReadFile(filename).splitlines() + for line in mountlines: + device, mountpoint, fstype, options, _ = line.split(None, 4) + data.append((device, mountpoint, fstype, options)) + + return data + + def GetFilesystemStats(path): """Returns the total and free space on a filesystem. @@ -2367,32 +3032,44 @@ def RunInSeparateProcess(fn, *args): return bool(exitcode) -def LockedMethod(fn): - """Synchronized object access decorator. +def IgnoreProcessNotFound(fn, *args, **kwargs): + """Ignores ESRCH when calling a process-related function. + + ESRCH is raised when a process is not found. - This decorator is intended to protect access to an object using the - object's own lock which is hardcoded to '_lock'. + @rtype: bool + @return: Whether process was found """ - def _LockDebug(*args, **kwargs): - if debug_locks: - logging.debug(*args, **kwargs) + try: + fn(*args, **kwargs) + except EnvironmentError, err: + # Ignore ESRCH + if err.errno == errno.ESRCH: + return False + raise - def wrapper(self, *args, **kwargs): - # pylint: disable-msg=W0212 - assert hasattr(self, '_lock') - lock = self._lock - _LockDebug("Waiting for %s", lock) - lock.acquire() - try: - _LockDebug("Acquired %s", lock) - result = fn(self, *args, **kwargs) - finally: - _LockDebug("Releasing %s", lock) - lock.release() - _LockDebug("Released %s", lock) - return result - return wrapper + return True + + +def IgnoreSignals(fn, *args, **kwargs): + """Tries to call a function ignoring failures due to EINTR. + + """ + try: + return fn(*args, **kwargs) + except EnvironmentError, err: + if err.errno == errno.EINTR: + return None + else: + raise + except (select.error, socket.error), err: + # In python 2.6 and above select.error is an IOError, so it's handled + # above, in 2.5 and below it's not, and it's handled here. + if err.args and err.args[0] == errno.EINTR: + return None + else: + raise def LockFile(fd): @@ -2425,6 +3102,31 @@ def FormatTime(val): return time.strftime("%F %T", time.localtime(val)) +def FormatSeconds(secs): + """Formats seconds for easier reading. + + @type secs: number + @param secs: Number of seconds + @rtype: string + @return: Formatted seconds (e.g. "2d 9h 19m 49s") + + """ + parts = [] + + secs = round(secs, 0) + + if secs > 0: + # Negative values would be a bit tricky + for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]: + (complete, secs) = divmod(secs, one) + if complete or parts: + parts.append("%d%s" % (complete, unit)) + + parts.append("%ds" % secs) + + return " ".join(parts) + + def ReadWatcherPauseFile(filename, now=None, remove_after=3600): """Reads the watcher pause file. @@ -2471,12 +3173,25 @@ def ReadWatcherPauseFile(filename, now=None, remove_after=3600): class RetryTimeout(Exception): """Retry loop timed out. + Any arguments which was passed by the retried function to RetryAgain will be + preserved in RetryTimeout, if it is raised. If such argument was an exception + the RaiseInner helper method will reraise it. + """ + def RaiseInner(self): + if self.args and isinstance(self.args[0], Exception): + raise self.args[0] + else: + raise RetryTimeout(*self.args) class RetryAgain(Exception): """Retry again. + Any arguments passed to RetryAgain will be preserved, if a timeout occurs, as + arguments to RetryTimeout. If an exception is passed, the RaiseInner() method + of the RetryTimeout() method can be used to reraise it. + """ @@ -2585,11 +3300,12 @@ def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep, assert calc_delay is None or callable(calc_delay) while True: + retry_args = [] try: # pylint: disable-msg=W0142 return fn(*args) - except RetryAgain: - pass + except RetryAgain, err: + retry_args = err.args except RetryTimeout: raise errors.ProgrammerError("Nested retry loop detected that didn't" " handle RetryTimeout") @@ -2597,7 +3313,8 @@ def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep, remaining_time = end_time - _time_fn() if remaining_time < 0.0: - raise RetryTimeout() + # pylint: disable-msg=W0142 + raise RetryTimeout(*retry_args) assert remaining_time >= 0.0 @@ -2609,6 +3326,66 @@ def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep, wait_fn(current_delay) +def GetClosedTempfile(*args, **kwargs): + """Creates a temporary file and returns its path. + + """ + (fd, path) = tempfile.mkstemp(*args, **kwargs) + _CloseFDNoErr(fd) + return path + + +def GenerateSelfSignedX509Cert(common_name, validity): + """Generates a self-signed X509 certificate. + + @type common_name: string + @param common_name: commonName value + @type validity: int + @param validity: Validity for certificate in seconds + + """ + # Create private and public key + key = OpenSSL.crypto.PKey() + key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS) + + # Create self-signed certificate + cert = OpenSSL.crypto.X509() + if common_name: + cert.get_subject().CN = common_name + cert.set_serial_number(1) + cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notAfter(validity) + cert.set_issuer(cert.get_subject()) + cert.set_pubkey(key) + cert.sign(key, constants.X509_CERT_SIGN_DIGEST) + + key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key) + cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert) + + return (key_pem, cert_pem) + + +def GenerateSelfSignedSslCert(filename, common_name=constants.X509_CERT_CN, + validity=constants.X509_CERT_DEFAULT_VALIDITY): + """Legacy function to generate self-signed X509 certificate. + + @type filename: str + @param filename: path to write certificate to + @type common_name: string + @param common_name: commonName value + @type validity: int + @param validity: validity of certificate in number of days + + """ + # TODO: Investigate using the cluster name instead of X505_CERT_CN for + # common_name, as cluster-renames are very seldom, and it'd be nice if RAPI + # and node daemon certificates have the proper Subject/Issuer. + (key_pem, cert_pem) = GenerateSelfSignedX509Cert(common_name, + validity * 24 * 60 * 60) + + WriteFile(filename, mode=0400, data=key_pem + cert_pem) + + class FileLock(object): """Utility class for file locks. @@ -2820,6 +3597,58 @@ def SignalHandled(signums): return wrap +class SignalWakeupFd(object): + try: + # This is only supported in Python 2.5 and above (some distributions + # backported it to Python 2.4) + _set_wakeup_fd_fn = signal.set_wakeup_fd + except AttributeError: + # Not supported + def _SetWakeupFd(self, _): # pylint: disable-msg=R0201 + return -1 + else: + def _SetWakeupFd(self, fd): + return self._set_wakeup_fd_fn(fd) + + def __init__(self): + """Initializes this class. + + """ + (read_fd, write_fd) = os.pipe() + + # Once these succeeded, the file descriptors will be closed automatically. + # Buffer size 0 is important, otherwise .read() with a specified length + # might buffer data and the file descriptors won't be marked readable. + self._read_fh = os.fdopen(read_fd, "r", 0) + self._write_fh = os.fdopen(write_fd, "w", 0) + + self._previous = self._SetWakeupFd(self._write_fh.fileno()) + + # Utility functions + self.fileno = self._read_fh.fileno + self.read = self._read_fh.read + + def Reset(self): + """Restores the previous wakeup file descriptor. + + """ + if hasattr(self, "_previous") and self._previous is not None: + self._SetWakeupFd(self._previous) + self._previous = None + + def Notify(self): + """Notifies the wakeup file descriptor. + + """ + self._write_fh.write("\0") + + def __del__(self): + """Called before object deletion. + + """ + self.Reset() + + class SignalHandler(object): """Generic signal handler class. @@ -2834,16 +3663,23 @@ class SignalHandler(object): @ivar called: tracks whether any of the signals have been raised """ - def __init__(self, signum): + def __init__(self, signum, handler_fn=None, wakeup=None): """Constructs a new SignalHandler instance. @type signum: int or list of ints @param signum: Single signal number or set of signal numbers + @type handler_fn: callable + @param handler_fn: Signal handling function """ + assert handler_fn is None or callable(handler_fn) + self.signum = set(signum) self.called = False + self._handler_fn = handler_fn + self._wakeup = wakeup + self._previous = {} try: for signum in self.signum: @@ -2884,8 +3720,7 @@ class SignalHandler(object): """ self.called = False - # we don't care about arguments, but we leave them named for the future - def _HandleSignal(self, signum, frame): # pylint: disable-msg=W0613 + def _HandleSignal(self, signum, frame): """Actual signal handling function. """ @@ -2893,6 +3728,13 @@ class SignalHandler(object): # solution in Python -- there are no atomic types. self.called = True + if self._wakeup: + # Notify whoever is interested in signals + self._wakeup.Notify() + + if self._handler_fn: + self._handler_fn(signum, frame) + class FieldSet(object): """A simple field set.