X-Git-Url: https://code.grnet.gr/git/ganeti-local/blobdiff_plain/05b35f15d87b1c367e62e97c39cbb0b402b0d9c3..ca4ac9c9a6d0447bab972a9e6d945ec9234cec21:/lib/utils.py diff --git a/lib/utils.py b/lib/utils.py index 4484085..286ef2b 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -1,7 +1,7 @@ # # -# Copyright (C) 2006, 2007 Google Inc. +# Copyright (C) 2006, 2007, 2010 Google Inc. # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -28,6 +28,7 @@ the command line scripts. import os +import sys import time import subprocess import re @@ -43,27 +44,23 @@ import resource import logging import logging.handlers import signal +import OpenSSL import datetime import calendar +import hmac import collections -import struct -import IN from cStringIO import StringIO try: - from hashlib import sha1 -except ImportError: - import sha - sha1 = sha.new - -try: + # pylint: disable-msg=F0401 import ctypes except ImportError: ctypes = None from ganeti import errors from ganeti import constants +from ganeti import compat _locksheld = [] @@ -76,22 +73,45 @@ no_fork = False _RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid" -# Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...): -# struct ucred { pid_t pid; uid_t uid; gid_t gid; }; -# -# The GNU C Library defines gid_t and uid_t to be "unsigned int" and -# pid_t to "int". -# -# IEEE Std 1003.1-2008: -# "nlink_t, uid_t, gid_t, and id_t shall be integer types" -# "blksize_t, pid_t, and ssize_t shall be signed integer types" -_STRUCT_UCRED = "iII" -_STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED) +HEX_CHAR_RE = r"[a-zA-Z0-9]" +VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S) +X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" % + (re.escape(constants.X509_CERT_SIGNATURE_HEADER), + HEX_CHAR_RE, HEX_CHAR_RE), + re.S | re.I) + +_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$") + +UUID_RE = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-' '[a-f0-9]{4}-[a-f0-9]{12}$') + +# Certificate verification results +(CERT_WARNING, + CERT_ERROR) = range(1, 3) # Flags for mlockall() (from bits/mman.h) _MCL_CURRENT = 1 _MCL_FUTURE = 2 +#: MAC checker regexp +_MAC_CHECK = re.compile("^([0-9a-f]{2}:){5}[0-9a-f]{2}$", re.I) + +(_TIMEOUT_NONE, + _TIMEOUT_TERM, + _TIMEOUT_KILL) = range(3) + +#: Shell param checker regexp +_SHELLPARAM_REGEX = re.compile(r"^[-a-zA-Z0-9._+/:%@]+$") + +#: Unit checker regexp +_PARSEUNIT_REGEX = re.compile(r"^([.\d]+)\s*([a-zA-Z]+)?$") + +#: ASN1 time regexp +_ASN1_TIME_REGEX = re.compile(r"^(\d+)([-+]\d\d)(\d\d)$") + +_SORTER_RE = re.compile("^%s(.*)$" % (8 * "(\D+|\d+)?")) +_SORTER_DIGIT = re.compile("^\d+$") + class RunResult(object): """Holds the result of running external programs.
@@ -116,7 +136,8 @@ class RunResult(object): "failed", "fail_reason", "cmd"] - def __init__(self, exit_code, signal_, stdout, stderr, cmd): + def __init__(self, exit_code, signal_, stdout, stderr, cmd, timeout_action, + timeout): self.cmd = cmd self.exit_code = exit_code self.signal = signal_ @@ -124,12 +145,23 @@ class RunResult(object): self.stderr = stderr self.failed = (signal_ is not None or exit_code != 0) + fail_msgs = [] if self.signal is not None: - self.fail_reason = "terminated by signal %s" % self.signal + fail_msgs.append("terminated by signal %s" % self.signal) elif self.exit_code is not None: - self.fail_reason = "exited with exit code %s" % self.exit_code + fail_msgs.append("exited with exit code %s" % self.exit_code) else: - self.fail_reason = "unable to determine termination reason" + fail_msgs.append("unable to determine termination reason") + + if timeout_action == _TIMEOUT_TERM: + fail_msgs.append("terminated after timeout of %.2f seconds" % timeout) + elif timeout_action == _TIMEOUT_KILL: + fail_msgs.append(("force termination after timeout of %.2f seconds" + " and linger for another %.2f seconds") % + (timeout, constants.CHILD_LINGER_TIMEOUT)) + + if fail_msgs and self.failed: + self.fail_reason = CommaJoin(fail_msgs) if self.failed: logging.debug("Command '%s' failed (%s); output: %s", @@ -144,7 +176,24 @@ class RunResult(object): output = property(_GetOutput, None, None, "Return full output") -def RunCmd(cmd, env=None, output=None, cwd='/', reset_env=False): +def _BuildCmdEnvironment(env, reset): + """Builds the environment for an external program. + + """ + if reset: + cmd_env = {} + else: + cmd_env = os.environ.copy() + cmd_env["LC_ALL"] = "C" + + if env is not None: + cmd_env.update(env) + + return cmd_env + + +def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False, + interactive=False, timeout=None): """Execute a (shell) command. 
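An illustrative sketch (not part of the original patch) of how a caller might use the extended RunCmd() signature shown above; the command, the 30-second timeout and the error handling are hypothetical examples:

import os

from ganeti import errors
from ganeti import utils

def _RunWithDeadline():
  # Give the child at most 30 seconds; on expiry it is sent SIGTERM (and
  # SIGKILL after the linger timeout) and fail_reason mentions the timeout.
  result = utils.RunCmd(["/bin/sync"], timeout=30)
  if result.failed:
    raise errors.OpExecError("sync failed: %s" % result.fail_reason)

def _RunInteractively():
  # Interactive mode leaves stdin/stdout/stderr attached to the caller's
  # terminal; it cannot be combined with the "output" file parameter.
  return utils.RunCmd(["/usr/bin/less", "/etc/hosts"], interactive=True)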
The command should not read from its standard input, as it will be @@ -153,7 +202,7 @@ def RunCmd(cmd, env=None, output=None, cwd='/', reset_env=False): @type cmd: string or list @param cmd: Command to run @type env: dict - @param env: Additional environment + @param env: Additional environment variables @type output: str @param output: if desired, the output of the command can be saved in a file instead of the RunResult instance; this @@ -163,6 +212,12 @@ def RunCmd(cmd, env=None, output=None, cwd='/', reset_env=False): directory for the command; the default will be / @type reset_env: boolean @param reset_env: whether to reset or keep the default os environment + @type interactive: boolean + @param interactive: whether we pipe stdin, stdout and stderr + (default behaviour) or run the command interactively + @type timeout: int + @param timeout: If not None, timeout in seconds until child process gets + killed @rtype: L{RunResult} @return: RunResult instance @raise errors.ProgrammerError: if we call this when forks are disabled @@ -171,28 +226,31 @@ def RunCmd(cmd, env=None, output=None, cwd='/', reset_env=False): if no_fork: raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled") - if isinstance(cmd, list): - cmd = [str(val) for val in cmd] - strcmd = " ".join(cmd) - shell = False - else: + if output and interactive: + raise errors.ProgrammerError("Parameters 'output' and 'interactive' can" + " not be provided at the same time") + + if isinstance(cmd, basestring): strcmd = cmd shell = True - logging.debug("RunCmd '%s'", strcmd) + else: + cmd = [str(val) for val in cmd] + strcmd = ShellQuoteArgs(cmd) + shell = False - if not reset_env: - cmd_env = os.environ.copy() - cmd_env["LC_ALL"] = "C" + if output: + logging.debug("RunCmd %s, output file '%s'", strcmd, output) else: - cmd_env = {} + logging.debug("RunCmd %s", strcmd) - if env is not None: - cmd_env.update(env) + cmd_env = _BuildCmdEnvironment(env, reset_env) try: if output is None: - out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd) + out, err, status, timeout_action = _RunCmdPipe(cmd, cmd_env, shell, cwd, + interactive, timeout) else: + timeout_action = _TIMEOUT_NONE status = _RunCmdFile(cmd, cmd_env, shell, output, cwd) out = err = "" except OSError, err: @@ -209,10 +267,254 @@ def RunCmd(cmd, env=None, output=None, cwd='/', reset_env=False): exitcode = None signal_ = -status - return RunResult(exitcode, signal_, out, err, strcmd) + return RunResult(exitcode, signal_, out, err, strcmd, timeout_action, timeout) + + +def SetupDaemonEnv(cwd="/", umask=077): + """Setup a daemon's environment. + + This should be called between the first and second fork, due to + setsid usage. + + @param cwd: the directory to which to chdir + @param umask: the umask to setup + + """ + os.chdir(cwd) + os.umask(umask) + os.setsid() + + +def SetupDaemonFDs(output_file, output_fd): + """Sets up a daemon's file descriptors.
+ + @param output_file: if not None, the file to which to redirect + stdout/stderr + @param output_fd: if not None, the file descriptor for stdout/stderr + + """ + # check that at most one is defined + assert [output_file, output_fd].count(None) >= 1 + + # Open /dev/null (read-only, only for stdin) + devnull_fd = os.open(os.devnull, os.O_RDONLY) + + if output_fd is not None: + pass + elif output_file is not None: + # Open output file + try: + output_fd = os.open(output_file, + os.O_WRONLY | os.O_CREAT | os.O_APPEND, 0600) + except EnvironmentError, err: + raise Exception("Opening output file failed: %s" % err) + else: + output_fd = os.open(os.devnull, os.O_WRONLY) + + # Redirect standard I/O + os.dup2(devnull_fd, 0) + os.dup2(output_fd, 1) + os.dup2(output_fd, 2) + + +def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None, + pidfile=None): + """Start a daemon process after forking twice. + + @type cmd: string or list + @param cmd: Command to run + @type env: dict + @param env: Additional environment variables + @type cwd: string + @param cwd: Working directory for the program + @type output: string + @param output: Path to file in which to save the output + @type output_fd: int + @param output_fd: File descriptor for output + @type pidfile: string + @param pidfile: Process ID file + @rtype: int + @return: Daemon process ID + @raise errors.ProgrammerError: if we call this when forks are disabled + + """ + if no_fork: + raise errors.ProgrammerError("utils.StartDaemon() called with fork()" + " disabled") + + if output and not (bool(output) ^ (output_fd is not None)): + raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be" + " specified") + + if isinstance(cmd, basestring): + cmd = ["/bin/sh", "-c", cmd] + + strcmd = ShellQuoteArgs(cmd) + + if output: + logging.debug("StartDaemon %s, output file '%s'", strcmd, output) + else: + logging.debug("StartDaemon %s", strcmd) + + cmd_env = _BuildCmdEnvironment(env, False) + + # Create pipe for sending PID back + (pidpipe_read, pidpipe_write) = os.pipe() + try: + try: + # Create pipe for sending error messages + (errpipe_read, errpipe_write) = os.pipe() + try: + try: + # First fork + pid = os.fork() + if pid == 0: + try: + # Child process, won't return + _StartDaemonChild(errpipe_read, errpipe_write, + pidpipe_read, pidpipe_write, + cmd, cmd_env, cwd, + output, output_fd, pidfile) + finally: + # Well, maybe child process failed + os._exit(1) # pylint: disable-msg=W0212 + finally: + _CloseFDNoErr(errpipe_write) + + # Wait for daemon to be started (or an error message to + # arrive) and read up to 100 KB as an error message + errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024) + finally: + _CloseFDNoErr(errpipe_read) + finally: + _CloseFDNoErr(pidpipe_write) + + # Read up to 128 bytes for PID + pidtext = RetryOnSignal(os.read, pidpipe_read, 128) + finally: + _CloseFDNoErr(pidpipe_read) + + # Try to avoid zombies by waiting for child process + try: + os.waitpid(pid, 0) + except OSError: + pass + + if errormsg: + raise errors.OpExecError("Error when starting daemon process: %r" % + errormsg) + + try: + return int(pidtext) + except (ValueError, TypeError), err: + raise errors.OpExecError("Error while trying to parse PID %r: %s" % + (pidtext, err)) + +def _StartDaemonChild(errpipe_read, errpipe_write, + pidpipe_read, pidpipe_write, + args, env, cwd, + output, fd_output, pidfile): + """Child process for starting daemon. 
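A sketch of how the StartDaemon() helper introduced above might be called (illustration only, not part of the patch; the daemon path, log file and PID file are made-up examples):

from ganeti import utils

def _StartExampleDaemon():
  # Double-forks, redirects output, writes and locks the PID file, and
  # returns the daemon's PID; failures surface as errors.OpExecError.
  return utils.StartDaemon(["/usr/sbin/example-daemon", "--no-fork"],
                           env={"EXAMPLE_DEBUG": "1"},
                           cwd="/",
                           output="/var/log/example-daemon.log",
                           pidfile="/var/run/example-daemon.pid")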
-def _RunCmdPipe(cmd, env, via_shell, cwd): + """ + try: + # Close parent's side + _CloseFDNoErr(errpipe_read) + _CloseFDNoErr(pidpipe_read) + + # First child process + SetupDaemonEnv() + + # And fork for the second time + pid = os.fork() + if pid != 0: + # Exit first child process + os._exit(0) # pylint: disable-msg=W0212 + + # Make sure pipe is closed on execv* (and thereby notifies + # original process) + SetCloseOnExecFlag(errpipe_write, True) + + # List of file descriptors to be left open + noclose_fds = [errpipe_write] + + # Open PID file + if pidfile: + fd_pidfile = WritePidFile(pidfile) + + # Keeping the file open to hold the lock + noclose_fds.append(fd_pidfile) + + SetCloseOnExecFlag(fd_pidfile, False) + else: + fd_pidfile = None + + SetupDaemonFDs(output, fd_output) + + # Send daemon PID to parent + RetryOnSignal(os.write, pidpipe_write, str(os.getpid())) + + # Close all file descriptors except stdio and error message pipe + CloseFDs(noclose_fds=noclose_fds) + + # Change working directory + os.chdir(cwd) + + if env is None: + os.execvp(args[0], args) + else: + os.execvpe(args[0], args, env) + except: # pylint: disable-msg=W0702 + try: + # Report errors to original process + WriteErrorToFD(errpipe_write, str(sys.exc_info()[1])) + except: # pylint: disable-msg=W0702 + # Ignore errors in error handling + pass + + os._exit(1) # pylint: disable-msg=W0212 + + +def WriteErrorToFD(fd, err): + """Possibly write an error message to a fd. + + @type fd: None or int (file descriptor) + @param fd: if not None, the error will be written to this fd + @param err: string, the error message + + """ + if fd is None: + return + + if not err: + err = "" + + RetryOnSignal(os.write, fd, err) + + +def _CheckIfAlive(child): + """Raises L{RetryAgain} if child is still alive. + + @raises RetryAgain: If child is still alive + + """ + if child.poll() is None: + raise RetryAgain() + + +def _WaitForProcess(child, timeout): + """Waits for the child to terminate or until we reach timeout. + + """ + try: + Retry(_CheckIfAlive, (1.0, 1.2, 5.0), max(0, timeout), args=[child]) + except RetryTimeout: + pass + + +def _RunCmdPipe(cmd, env, via_shell, cwd, interactive, timeout, + _linger_timeout=constants.CHILD_LINGER_TIMEOUT): """Run a command and return its output. 
@type cmd: string or list @@ -223,62 +525,119 @@ def _RunCmdPipe(cmd, env, via_shell, cwd): @param via_shell: if we should run via the shell @type cwd: string @param cwd: the working directory for the program + @type interactive: boolean + @param interactive: Run command interactive (without piping) + @type timeout: int + @param timeout: Timeout after the programm gets terminated @rtype: tuple @return: (out, err, status) """ poller = select.poll() + + stderr = subprocess.PIPE + stdout = subprocess.PIPE + stdin = subprocess.PIPE + + if interactive: + stderr = stdout = stdin = None + child = subprocess.Popen(cmd, shell=via_shell, - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, - stdin=subprocess.PIPE, + stderr=stderr, + stdout=stdout, + stdin=stdin, close_fds=True, env=env, cwd=cwd) - child.stdin.close() - poller.register(child.stdout, select.POLLIN) - poller.register(child.stderr, select.POLLIN) out = StringIO() err = StringIO() - fdmap = { - child.stdout.fileno(): (out, child.stdout), - child.stderr.fileno(): (err, child.stderr), - } - for fd in fdmap: - status = fcntl.fcntl(fd, fcntl.F_GETFL) - fcntl.fcntl(fd, fcntl.F_SETFL, status | os.O_NONBLOCK) - - while fdmap: - try: - pollresult = poller.poll() - except EnvironmentError, eerr: - if eerr.errno == errno.EINTR: - continue - raise - except select.error, serr: - if serr[0] == errno.EINTR: - continue - raise - for fd, event in pollresult: - if event & select.POLLIN or event & select.POLLPRI: - data = fdmap[fd][1].read() - # no data from read signifies EOF (the same as POLLHUP) - if not data: + linger_timeout = None + + if timeout is None: + poll_timeout = None + else: + poll_timeout = RunningTimeout(timeout, True).Remaining + + msg_timeout = ("Command %s (%d) run into execution timeout, terminating" % + (cmd, child.pid)) + msg_linger = ("Command %s (%d) run into linger timeout, killing" % + (cmd, child.pid)) + + timeout_action = _TIMEOUT_NONE + + if not interactive: + child.stdin.close() + poller.register(child.stdout, select.POLLIN) + poller.register(child.stderr, select.POLLIN) + fdmap = { + child.stdout.fileno(): (out, child.stdout), + child.stderr.fileno(): (err, child.stderr), + } + for fd in fdmap: + SetNonblockFlag(fd, True) + + while fdmap: + if poll_timeout: + pt = poll_timeout() * 1000 + if pt < 0: + if linger_timeout is None: + logging.warning(msg_timeout) + if child.poll() is None: + timeout_action = _TIMEOUT_TERM + IgnoreProcessNotFound(os.kill, child.pid, signal.SIGTERM) + linger_timeout = RunningTimeout(_linger_timeout, True).Remaining + pt = linger_timeout() * 1000 + if pt < 0: + break + else: + pt = None + + pollresult = RetryOnSignal(poller.poll, pt) + + for fd, event in pollresult: + if event & select.POLLIN or event & select.POLLPRI: + data = fdmap[fd][1].read() + # no data from read signifies EOF (the same as POLLHUP) + if not data: + poller.unregister(fd) + del fdmap[fd] + continue + fdmap[fd][0].write(data) + if (event & select.POLLNVAL or event & select.POLLHUP or + event & select.POLLERR): poller.unregister(fd) del fdmap[fd] - continue - fdmap[fd][0].write(data) - if (event & select.POLLNVAL or event & select.POLLHUP or - event & select.POLLERR): - poller.unregister(fd) - del fdmap[fd] + + if timeout is not None: + assert callable(poll_timeout) + + # We have no I/O left but it might still run + if child.poll() is None: + _WaitForProcess(child, poll_timeout()) + + # Terminate if still alive after timeout + if child.poll() is None: + if linger_timeout is None: + logging.warning(msg_timeout) + timeout_action = 
_TIMEOUT_TERM + IgnoreProcessNotFound(os.kill, child.pid, signal.SIGTERM) + lt = _linger_timeout + else: + lt = linger_timeout() + _WaitForProcess(child, lt) + + # Okay, still alive after timeout and linger timeout? Kill it! + if child.poll() is None: + timeout_action = _TIMEOUT_KILL + logging.warning(msg_linger) + IgnoreProcessNotFound(os.kill, child.pid, signal.SIGKILL) out = out.getvalue() err = err.getvalue() status = child.wait() - return out, err, status + return out, err, status, timeout_action def _RunCmdFile(cmd, env, via_shell, output, cwd): @@ -314,6 +673,61 @@ def _RunCmdFile(cmd, env, via_shell, output, cwd): return status +def SetCloseOnExecFlag(fd, enable): + """Sets or unsets the close-on-exec flag on a file descriptor. + + @type fd: int + @param fd: File descriptor + @type enable: bool + @param enable: Whether to set or unset it. + + """ + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + + if enable: + flags |= fcntl.FD_CLOEXEC + else: + flags &= ~fcntl.FD_CLOEXEC + + fcntl.fcntl(fd, fcntl.F_SETFD, flags) + + +def SetNonblockFlag(fd, enable): + """Sets or unsets the O_NONBLOCK flag on on a file descriptor. + + @type fd: int + @param fd: File descriptor + @type enable: bool + @param enable: Whether to set or unset it + + """ + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + + if enable: + flags |= os.O_NONBLOCK + else: + flags &= ~os.O_NONBLOCK + + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +def RetryOnSignal(fn, *args, **kwargs): + """Calls a function again if it failed due to EINTR. + + """ + while True: + try: + return fn(*args, **kwargs) + except EnvironmentError, err: + if err.errno != errno.EINTR: + raise + except (socket.error, select.error), err: + # In python 2.6 and above select.error is an IOError, so it's handled + # above, in 2.5 and below it's not, and it's handled here. + if not (err.args and err.args[0] == errno.EINTR): + raise + + def RunParts(dir_name, env=None, reset_env=False): """Run Scripts or programs in a directory @@ -351,19 +765,6 @@ def RunParts(dir_name, env=None, reset_env=False): return rr -def GetSocketCredentials(sock): - """Returns the credentials of the foreign process connected to a socket. - - @param sock: Unix socket - @rtype: tuple; (number, number, number) - @return: The PID, UID and GID of the connected foreign process. - - """ - peercred = sock.getsockopt(socket.SOL_SOCKET, IN.SO_PEERCRED, - _STRUCT_UCRED_SIZE) - return struct.unpack(_STRUCT_UCRED, peercred) - - def RemoveFile(filename): """Remove a file ignoring some errors. @@ -381,6 +782,24 @@ def RemoveFile(filename): raise +def RemoveDir(dirname): + """Remove an empty directory. + + Remove a directory, ignoring non-existing ones. + Other errors are passed. This includes the case, + where the directory is not empty, so it can't be removed. + + @type dirname: str + @param dirname: the empty directory to be removed + + """ + try: + os.rmdir(dirname) + except OSError, err: + if err.errno != errno.ENOENT: + raise + + def RenameFile(old, new, mkdir=False, mkdir_mode=0750): """Renames a file. 
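The RetryOnSignal() helper above simply restarts a call that failed with EINTR, which is why the patch wraps raw os.read/os.write/poll calls with it. A small usage sketch (not part of the patch; the reader function is hypothetical):

import os
from ganeti import utils

def _ReadAll(fd, bufsize=4096):
  """Reads from fd until EOF, restarting reads interrupted by signals."""
  chunks = []
  while True:
    data = utils.RetryOnSignal(os.read, fd, bufsize)
    if not data:
      return "".join(chunks)
    chunks.append(data)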
@@ -466,7 +885,7 @@ def _FingerprintFile(filename): f = open(filename) - fp = sha1() + fp = compat.sha1_hash() while True: data = f.read(4096) if not data: @@ -529,8 +948,10 @@ def ForceDictType(target, key_types, allowed_values=None): msg = "'%s' has non-enforceable type %s" % (key, ktype) raise errors.ProgrammerError(msg) - if ktype == constants.VTYPE_STRING: - if not isinstance(target[key], basestring): + if ktype in (constants.VTYPE_STRING, constants.VTYPE_MAYBE_STRING): + if target[key] is None and ktype == constants.VTYPE_MAYBE_STRING: + pass + elif not isinstance(target[key], basestring): if isinstance(target[key], bool) and not target[key]: target[key] = '' else: @@ -564,6 +985,17 @@ def ForceDictType(target, key_types, allowed_values=None): raise errors.TypeEnforcementError(msg) +def _GetProcStatusPath(pid): + """Returns the path for a PID's proc status file. + + @type pid: int + @param pid: Process ID + @rtype: string + + """ + return "/proc/%d/status" % pid + + def IsProcessAlive(pid): """Check if a given pid exists on the system. @@ -590,15 +1022,99 @@ def IsProcessAlive(pid): if pid <= 0: return False - proc_entry = "/proc/%d/status" % pid # /proc in a multiprocessor environment can have strange behaviors. # Retry the os.stat a few times until we get a good result. try: - return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5, args=[proc_entry]) + return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5, + args=[_GetProcStatusPath(pid)]) except RetryTimeout, err: err.RaiseInner() +def _ParseSigsetT(sigset): + """Parse a rendered sigset_t value. + + This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t + function. + + @type sigset: string + @param sigset: Rendered signal set from /proc/$pid/status + @rtype: set + @return: Set of all enabled signal numbers + + """ + result = set() + + signum = 0 + for ch in reversed(sigset): + chv = int(ch, 16) + + # The following could be done in a loop, but it's easier to read and + # understand in the unrolled form + if chv & 1: + result.add(signum + 1) + if chv & 2: + result.add(signum + 2) + if chv & 4: + result.add(signum + 3) + if chv & 8: + result.add(signum + 4) + + signum += 4 + + return result + + +def _GetProcStatusField(pstatus, field): + """Retrieves a field from the contents of a proc status file. + + @type pstatus: string + @param pstatus: Contents of /proc/$pid/status + @type field: string + @param field: Name of field whose value should be returned + @rtype: string + + """ + for line in pstatus.splitlines(): + parts = line.split(":", 1) + + if len(parts) < 2 or parts[0] != field: + continue + + return parts[1].strip() + + return None + + +def IsProcessHandlingSignal(pid, signum, status_path=None): + """Checks whether a process is handling a signal. + + @type pid: int + @param pid: Process ID + @type signum: int + @param signum: Signal number + @rtype: bool + + """ + if status_path is None: + status_path = _GetProcStatusPath(pid) + + try: + proc_status = ReadFile(status_path) + except EnvironmentError, err: + # In at least one case, reading /proc/$pid/status failed with ESRCH. + if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH): + return False + raise + + sigcgt = _GetProcStatusField(proc_status, "SigCgt") + if sigcgt is None: + raise RuntimeError("%s is missing 'SigCgt' field" % status_path) + + # Now check whether signal is handled + return signum in _ParseSigsetT(sigcgt) + + def ReadPidFile(pidfile): """Read a pid from a file. 
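To illustrate the SigCgt handling above (an editorial sketch, reaching into the private helper purely for demonstration): each hex digit of the kernel-rendered mask covers four signal numbers, least-significant digit first. The mask value below is made up.

import os
import signal
from ganeti import utils

# Bit 1 of the last digit is signal 2 (SIGINT); bit 0 of the third-to-last
# digit is signal 9 (SIGKILL):
assert utils._ParseSigsetT("0000000000000102") == set([2, 9])

# Built on top of that, IsProcessHandlingSignal() reads the SigCgt field of
# /proc/<pid>/status, e.g. for our own process:
utils.IsProcessHandlingSignal(os.getpid(), signal.SIGTERM)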
@@ -625,6 +1141,37 @@ def ReadPidFile(pidfile): return pid +def ReadLockedPidFile(path): + """Reads a locked PID file. + + This can be used together with L{StartDaemon}. + + @type path: string + @param path: Path to PID file + @return: PID as integer or, if file was unlocked or couldn't be opened, None + + """ + try: + fd = os.open(path, os.O_RDONLY) + except EnvironmentError, err: + if err.errno == errno.ENOENT: + # PID file doesn't exist + return None + raise + + try: + try: + # Try to acquire lock + LockFile(fd) + except errors.LockError: + # Couldn't lock, daemon is running + return int(os.read(fd, 100)) + finally: + os.close(fd) + + return None + + def MatchNameComponent(key, name_list, case_sensitive=True): """Try to match a name against a list. @@ -671,92 +1218,28 @@ def MatchNameComponent(key, name_list, case_sensitive=True): return None -class HostInfo: - """Class implementing resolver and hostname functionality - - """ - _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$") - - def __init__(self, name=None): - """Initialize the host name object. +def ValidateServiceName(name): + """Validate the given service name. - If the name argument is not passed, it will use this system's - name. + @type name: number or string + @param name: Service name or port specification - """ - if name is None: - name = self.SysName() - - self.query = name - self.name, self.aliases, self.ipaddrs = self.LookupHostname(name) - self.ip = self.ipaddrs[0] - - def ShortName(self): - """Returns the hostname without domain. - - """ - return self.name.split('.')[0] - - @staticmethod - def SysName(): - """Return the current system's name. - - This is simply a wrapper over C{socket.gethostname()}. - - """ - return socket.gethostname() - - @staticmethod - def LookupHostname(hostname): - """Look up hostname - - @type hostname: str - @param hostname: hostname to look up - - @rtype: tuple - @return: a tuple (name, aliases, ipaddrs) as returned by - C{socket.gethostbyname_ex} - @raise errors.ResolverError: in case of errors in resolving - - """ - try: - result = socket.gethostbyname_ex(hostname) - except socket.gaierror, err: - # hostname not found in DNS - raise errors.ResolverError(hostname, err.args[0], err.args[1]) - - return result - - @classmethod - def NormalizeName(cls, hostname): - """Validate and normalize the given hostname. - - @attention: the validation is a bit more relaxed than the standards - require; most importantly, we allow underscores in names - @raise errors.OpPrereqError: when the name is not valid + """ + try: + numport = int(name) + except (ValueError, TypeError): + # Non-numeric service name + valid = _VALID_SERVICE_NAME_RE.match(name) + else: + # Numeric port (protocols other than TCP or UDP might need adjustments + # here) + valid = (numport >= 0 and numport < (1 << 16)) - """ - hostname = hostname.lower() - if (not cls._VALID_NAME_RE.match(hostname) or - # double-dots, meaning empty label - ".." 
in hostname or - # empty initial label - hostname.startswith(".")): - raise errors.OpPrereqError("Invalid hostname '%s'" % hostname, - errors.ECODE_INVAL) - if hostname.endswith("."): - hostname = hostname.rstrip(".") - return hostname - - -def GetHostInfo(name=None): - """Lookup host name and raise an OpPrereqError for failures""" + if not valid: + raise errors.OpPrereqError("Invalid service name '%s'" % name, + errors.ECODE_INVAL) - try: - return HostInfo(name) - except errors.ResolverError, err: - raise errors.OpPrereqError("The given name (%s) does not resolve: %s" % - (err[0], err[2]), errors.ECODE_RESOLVER) + return name def ListVolumeGroups(): @@ -799,7 +1282,25 @@ def BridgeExists(bridge): return os.path.isdir("/sys/class/net/%s/bridge" % bridge) -def NiceSort(name_list): +def _NiceSortTryInt(val): + """Attempts to convert a string to an integer. + + """ + if val and _SORTER_DIGIT.match(val): + return int(val) + else: + return val + + +def _NiceSortKey(value): + """Extract key for sorting. + + """ + return [_NiceSortTryInt(grp) + for grp in _SORTER_RE.match(value).groups()] + + +def NiceSort(values, key=None): """Sort a list of strings based on digit and non-digit groupings. Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function @@ -810,30 +1311,21 @@ def NiceSort(name_list): or no-digits. Only the first eight such groups are considered, and after that we just use what's left of the string. - @type name_list: list - @param name_list: the names to be sorted + @type values: list + @param values: the names to be sorted + @type key: callable or None + @param key: function of one argument to extract a comparison key from each + list element, must return string @rtype: list @return: a copy of the name list sorted with our algorithm """ - _SORTER_BASE = "(\D+|\d+)" - _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE, - _SORTER_BASE, _SORTER_BASE, - _SORTER_BASE, _SORTER_BASE, - _SORTER_BASE, _SORTER_BASE) - _SORTER_RE = re.compile(_SORTER_FULL) - _SORTER_NODIGIT = re.compile("^\D*$") - def _TryInt(val): - """Attempts to convert a variable to integer.""" - if val is None or _SORTER_NODIGIT.match(val): - return val - rval = int(val) - return rval + if key is None: + keyfunc = _NiceSortKey + else: + keyfunc = lambda value: _NiceSortKey(key(value)) - to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name) - for name in name_list] - to_sort.sort() - return [tup[1] for tup in to_sort] + return sorted(values, key=keyfunc) def TryConvert(fn, val): @@ -858,24 +1350,6 @@ def TryConvert(fn, val): return nv -def IsValidIP(ip): - """Verifies the syntax of an IPv4 address. - - This function checks if the IPv4 address passes is valid or not based - on syntax (not IP range, class calculations, etc.). - - @type ip: str - @param ip: the address to be checked - @rtype: a regular expression match object - @return: a regular expression match object, or None if the - address is not valid - - """ - unit = "(0|[1-9]\d{0,2})" - #TODO: convert and return only boolean - return re.match("^%s\.%s\.%s\.%s$" % (unit, unit, unit, unit), ip) - - def IsValidShellParam(word): """Verifies is the given word is safe from the shell's p.o.v. @@ -892,7 +1366,7 @@ def IsValidShellParam(word): @return: True if the word is 'safe' """ - return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word)) + return bool(_SHELLPARAM_REGEX.match(word)) def BuildShellCmd(template, *args): @@ -961,7 +1435,7 @@ def ParseUnit(input_string): is always an int in MiB. 
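Behaviour of the rewritten NiceSort() and its new "key" argument, shown with made-up sample data (illustration only, not part of the patch):

from ganeti import utils

assert utils.NiceSort(["a1", "a10", "a11", "a2"]) == ["a1", "a2", "a10", "a11"]

# "key" extracts the string to sort by, so arbitrary objects can be ordered
# with the same digit-aware comparison:
pairs = [("node10", 3), ("node2", 1), ("node1", 2)]
assert utils.NiceSort(pairs, key=lambda p: p[0]) == \
  [("node1", 2), ("node2", 1), ("node10", 3)]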
""" - m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string)) + m = _PARSEUNIT_REGEX.match(str(input_string)) if not m: raise errors.UnitParseError("Invalid format") @@ -998,18 +1472,61 @@ def ParseUnit(input_string): return value -def AddAuthorizedKey(file_name, key): +def ParseCpuMask(cpu_mask): + """Parse a CPU mask definition and return the list of CPU IDs. + + CPU mask format: comma-separated list of CPU IDs + or dash-separated ID ranges + Example: "0-2,5" -> "0,1,2,5" + + @type cpu_mask: str + @param cpu_mask: CPU mask definition + @rtype: list of int + @return: list of CPU IDs + + """ + if not cpu_mask: + return [] + cpu_list = [] + for range_def in cpu_mask.split(","): + boundaries = range_def.split("-") + n_elements = len(boundaries) + if n_elements > 2: + raise errors.ParseError("Invalid CPU ID range definition" + " (only one hyphen allowed): %s" % range_def) + try: + lower = int(boundaries[0]) + except (ValueError, TypeError), err: + raise errors.ParseError("Invalid CPU ID value for lower boundary of" + " CPU ID range: %s" % str(err)) + try: + higher = int(boundaries[-1]) + except (ValueError, TypeError), err: + raise errors.ParseError("Invalid CPU ID value for higher boundary of" + " CPU ID range: %s" % str(err)) + if lower > higher: + raise errors.ParseError("Invalid CPU ID range definition" + " (%d > %d): %s" % (lower, higher, range_def)) + cpu_list.extend(range(lower, higher + 1)) + return cpu_list + + +def AddAuthorizedKey(file_obj, key): """Adds an SSH public key to an authorized_keys file. - @type file_name: str - @param file_name: path to authorized_keys file + @type file_obj: str or file handle + @param file_obj: path to authorized_keys file @type key: str @param key: string containing key """ key_fields = key.split() - f = open(file_name, 'a+') + if isinstance(file_obj, basestring): + f = open(file_obj, 'a+') + else: + f = file_obj + try: nl = True for line in f: @@ -1073,50 +1590,42 @@ def SetEtcHostsEntry(file_name, ip, hostname, aliases): @param aliases: the list of aliases to add for the hostname """ - # FIXME: use WriteFile + fn rather than duplicating its efforts # Ensure aliases are unique aliases = UniqueSequence([hostname] + aliases)[1:] - fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name)) - try: - out = os.fdopen(fd, 'w') + def _WriteEtcHosts(fd): + # Duplicating file descriptor because os.fdopen's result will automatically + # close the descriptor, but we would still like to have its functionality. + out = os.fdopen(os.dup(fd), "w") try: - f = open(file_name, 'r') - try: - for line in f: - fields = line.split() - if fields and not fields[0].startswith('#') and ip == fields[0]: - continue - out.write(line) - - out.write("%s\t%s" % (ip, hostname)) - if aliases: - out.write(" %s" % ' '.join(aliases)) - out.write('\n') + for line in ReadFile(file_name).splitlines(True): + fields = line.split() + if fields and not fields[0].startswith("#") and ip == fields[0]: + continue + out.write(line) - out.flush() - os.fsync(out) - os.chmod(tmpname, 0644) - os.rename(tmpname, file_name) - finally: - f.close() + out.write("%s\t%s" % (ip, hostname)) + if aliases: + out.write(" %s" % " ".join(aliases)) + out.write("\n") + out.flush() finally: out.close() - except: - RemoveFile(tmpname) - raise + + WriteFile(file_name, fn=_WriteEtcHosts, mode=0644) -def AddHostToEtcHosts(hostname): +def AddHostToEtcHosts(hostname, ip): """Wrapper around SetEtcHostsEntry. 
@type hostname: str @param hostname: a hostname that will be resolved and added to L{constants.ETC_HOSTS} + @type ip: str + @param ip: The ip address of the host """ - hi = HostInfo(name=hostname) - SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()]) + SetEtcHostsEntry(constants.ETC_HOSTS, ip, hostname, [hostname.split(".")[0]]) def RemoveEtcHostsEntry(file_name, hostname): @@ -1130,37 +1639,29 @@ def RemoveEtcHostsEntry(file_name, hostname): @param hostname: the hostname to be removed """ - # FIXME: use WriteFile + fn rather than duplicating its efforts - fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name)) - try: - out = os.fdopen(fd, 'w') + def _WriteEtcHosts(fd): + # Duplicating file descriptor because os.fdopen's result will automatically + # close the descriptor, but we would still like to have its functionality. + out = os.fdopen(os.dup(fd), "w") try: - f = open(file_name, 'r') - try: - for line in f: - fields = line.split() - if len(fields) > 1 and not fields[0].startswith('#'): - names = fields[1:] - if hostname in names: - while hostname in names: - names.remove(hostname) - if names: - out.write("%s %s\n" % (fields[0], ' '.join(names))) - continue - - out.write(line) + for line in ReadFile(file_name).splitlines(True): + fields = line.split() + if len(fields) > 1 and not fields[0].startswith("#"): + names = fields[1:] + if hostname in names: + while hostname in names: + names.remove(hostname) + if names: + out.write("%s %s\n" % (fields[0], " ".join(names))) + continue - out.flush() - os.fsync(out) - os.chmod(tmpname, 0644) - os.rename(tmpname, file_name) - finally: - f.close() + out.write(line) + + out.flush() finally: out.close() - except: - RemoveFile(tmpname) - raise + + WriteFile(file_name, fn=_WriteEtcHosts, mode=0644) def RemoveHostFromEtcHosts(hostname): @@ -1172,16 +1673,15 @@ def RemoveHostFromEtcHosts(hostname): L{constants.ETC_HOSTS} """ - hi = HostInfo(name=hostname) - RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name) - RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName()) + RemoveEtcHostsEntry(constants.ETC_HOSTS, hostname) + RemoveEtcHostsEntry(constants.ETC_HOSTS, hostname.split(".")[0]) def TimestampForFilename(): """Returns the current time formatted for filenames. - The format doesn't contain colons as some shells and applications them as - separators. + The format doesn't contain colons as some shells and applications treat them + as separators. Uses the local timezone. """ return time.strftime("%Y-%m-%d_%H_%M_%S") @@ -1247,66 +1747,46 @@ def ShellQuoteArgs(args): return ' '.join([ShellQuote(i) for i in args]) -def TcpPing(target, port, timeout=10, live_port_needed=False, source=None): - """Simple ping implementation using TCP connect(2). - - Check if the given IP is reachable by doing attempting a TCP connect - to it. - - @type target: str - @param target: the IP or hostname to ping - @type port: int - @param port: the port to connect to - @type timeout: int - @param timeout: the timeout on the connection attempt - @type live_port_needed: boolean - @param live_port_needed: whether a closed port will cause the - function to return failure, as if there was a timeout - @type source: str or None - @param source: if specified, will cause the connect to be made - from this specific source address; failures to bind other - than C{EADDRNOTAVAIL} will be ignored +class ShellWriter: + """Helper class to write scripts with indentation. 
""" - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + INDENT_STR = " " - success = False + def __init__(self, fh): + """Initializes this class. - if source is not None: - try: - sock.bind((source, 0)) - except socket.error, (errcode, _): - if errcode == errno.EADDRNOTAVAIL: - success = False + """ + self._fh = fh + self._indent = 0 - sock.settimeout(timeout) + def IncIndent(self): + """Increase indentation level by 1. - try: - sock.connect((target, port)) - sock.close() - success = True - except socket.timeout: - success = False - except socket.error, (errcode, _): - success = (not live_port_needed) and (errcode == errno.ECONNREFUSED) + """ + self._indent += 1 - return success + def DecIndent(self): + """Decrease indentation level by 1. + """ + assert self._indent > 0 + self._indent -= 1 -def OwnIpAddress(address): - """Check if the current host has the the given IP address. + def Write(self, txt, *args): + """Write line to output file. - Currently this is done by TCP-pinging the address from the loopback - address. + """ + assert self._indent >= 0 - @type address: string - @param address: the address to check - @rtype: bool - @return: True if we own the address + self._fh.write(self._indent * self.INDENT_STR) - """ - return TcpPing(address, constants.DEFAULT_NODED_PORT, - source=constants.LOCALHOST_IP_ADDRESS) + if args: + self._fh.write(txt % args) + else: + self._fh.write(txt) + + self._fh.write("\n") def ListVisibleFiles(path): @@ -1323,7 +1803,6 @@ def ListVisibleFiles(path): raise errors.ProgrammerError("Path passed to ListVisibleFiles is not" " absolute/normalized: '%s'" % path) files = [i for i in os.listdir(path) if not i.startswith(".")] - files.sort() return files @@ -1388,6 +1867,11 @@ def EnsureDirs(dirs): if err.errno != errno.EEXIST: raise errors.GenericError("Cannot create needed directory" " '%s': %s" % (dir_name, err)) + try: + os.chmod(dir_name, dir_mode) + except EnvironmentError, err: + raise errors.GenericError("Cannot change directory permissions on" + " '%s': %s" % (dir_name, err)) if not os.path.isdir(dir_name): raise errors.GenericError("%s is not a directory" % dir_name) @@ -1507,6 +1991,71 @@ def WriteFile(file_name, fn=None, data=None, return result +def GetFileID(path=None, fd=None): + """Returns the file 'id', i.e. the dev/inode and mtime information. + + Either the path to the file or the fd must be given. + + @param path: the file path + @param fd: a file descriptor + @return: a tuple of (device number, inode number, mtime) + + """ + if [path, fd].count(None) != 1: + raise errors.ProgrammerError("One and only one of fd/path must be given") + + if fd is None: + st = os.stat(path) + else: + st = os.fstat(fd) + + return (st.st_dev, st.st_ino, st.st_mtime) + + +def VerifyFileID(fi_disk, fi_ours): + """Verifies that two file IDs are matching. + + Differences in the inode/device are not accepted, but and older + timestamp for fi_disk is accepted. + + @param fi_disk: tuple (dev, inode, mtime) representing the actual + file data + @param fi_ours: tuple (dev, inode, mtime) representing the last + written file data + @rtype: boolean + + """ + (d1, i1, m1) = fi_disk + (d2, i2, m2) = fi_ours + + return (d1, i1) == (d2, i2) and m1 <= m2 + + +def SafeWriteFile(file_name, file_id, **kwargs): + """Wraper over L{WriteFile} that locks the target file. + + By keeping the target file locked during WriteFile, we ensure that + cooperating writers will safely serialise access to the file. 
+ + @type file_name: str + @param file_name: the target filename + @type file_id: tuple + @param file_id: a result from L{GetFileID} + + """ + fd = os.open(file_name, os.O_RDONLY | os.O_CREAT) + try: + LockFile(fd) + if file_id is not None: + disk_id = GetFileID(fd=fd) + if not VerifyFileID(disk_id, file_id): + raise errors.LockError("Cannot overwrite file %s, it has been modified" + " since last written" % file_name) + return WriteFile(file_name, **kwargs) + finally: + os.close(fd) + + def ReadOneLineFile(file_name, strict=False): """Return the first non-empty line from a file. @@ -1659,6 +2208,29 @@ def UniqueSequence(seq): return [i for i in seq if i not in seen and not seen.add(i)] +def FindDuplicates(seq): + """Identifies duplicates in a list. + + Does not preserve element order. + + @type seq: sequence + @param seq: Sequence with source elements + @rtype: list + @return: List of duplicate elements from seq + + """ + dup = set() + seen = set() + + for item in seq: + if item in seen: + dup.add(item) + else: + seen.add(item) + + return list(dup) + + def NormalizeAndValidateMac(mac): """Normalizes and check if a MAC address is valid. @@ -1673,8 +2245,7 @@ def NormalizeAndValidateMac(mac): @raise errors.OpPrereqError: If the MAC isn't valid """ - mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I) - if not mac_check.match(mac): + if not _MAC_CHECK.match(mac): raise errors.OpPrereqError("Invalid MAC address specified: %s" % mac, errors.ECODE_INVAL) @@ -1748,18 +2319,19 @@ def CloseFDs(noclose_fds=None): _CloseFDNoErr(fd) -def Mlockall(): +def Mlockall(_ctypes=ctypes): """Lock current process' virtual address space into RAM. This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE), see mlock(2) for more details. This function requires ctypes module. + @raises errors.NoCtypesError: if ctypes module is not found + """ - if ctypes is None: - logging.warning("Cannot set memory lock, ctypes module not found") - return + if _ctypes is None: + raise errors.NoCtypesError() - libc = ctypes.cdll.LoadLibrary("libc.so.6") + libc = _ctypes.cdll.LoadLibrary("libc.so.6") if libc is None: logging.error("Cannot set memory lock, ctypes cannot load libc") return @@ -1769,10 +2341,12 @@ def Mlockall(): # By declaring this variable as a pointer to an integer we can then access # its value correctly, should the mlockall call fail, in order to see what # the actual error code was. - libc.__errno_location.restype = ctypes.POINTER(ctypes.c_int) + # pylint: disable-msg=W0212 + libc.__errno_location.restype = _ctypes.POINTER(_ctypes.c_int) if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE): - logging.error("Cannot set memory lock: %s" % + # pylint: disable-msg=W0212 + logging.error("Cannot set memory lock: %s", os.strerror(libc.__errno_location().contents.value)) return @@ -1793,33 +2367,39 @@ def Daemonize(logfile): """ # pylint: disable-msg=W0212 # yes, we really want os._exit - UMASK = 077 - WORKDIR = "/" + + # TODO: do another attempt to merge Daemonize and StartDaemon, or at + # least abstract the pipe functionality between them + + # Create pipe for sending error messages + (rpipe, wpipe) = os.pipe() # this might fail pid = os.fork() if (pid == 0): # The first child. - os.setsid() + SetupDaemonEnv() + # this might fail pid = os.fork() # Fork a second child. if (pid == 0): # The second child. - os.chdir(WORKDIR) - os.umask(UMASK) + _CloseFDNoErr(rpipe) else: # exit() or _exit()? See below. os._exit(0) # Exit parent (the first child) of the second child. 
else: - os._exit(0) # Exit parent of the first child. + _CloseFDNoErr(wpipe) + # Wait for daemon to be started (or an error message to + # arrive) and read up to 100 KB as an error message + errormsg = RetryOnSignal(os.read, rpipe, 100 * 1024) + if errormsg: + sys.stderr.write("Error when starting daemon process: %r\n" % errormsg) + rcode = 1 + else: + rcode = 0 + os._exit(rcode) # Exit parent of the first child. - for fd in range(3): - _CloseFDNoErr(fd) - i = os.open("/dev/null", os.O_RDONLY) # stdin - assert i == 0, "Can't close/reopen stdin" - i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout - assert i == 1, "Can't close/reopen stdout" - # Duplicate standard output to standard error. - os.dup2(1, 2) - return 0 + SetupDaemonFDs(logfile, None) + return wpipe def DaemonPidFileName(name): @@ -1848,23 +2428,44 @@ def EnsureDaemon(name): return True -def WritePidFile(name): - """Write the current process pidfile. +def StopDaemon(name): + """Stop daemon + + """ + result = RunCmd([constants.DAEMON_UTIL, "stop", name]) + if result.failed: + logging.error("Can't stop daemon '%s', failure %s, output: %s", + name, result.fail_reason, result.output) + return False + + return True - The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid} - @type name: str - @param name: the daemon name to use - @raise errors.GenericError: if the pid file already exists and +def WritePidFile(pidfile): + """Write the current process pidfile. + + @type pidfile: sting + @param pidfile: the path to the file to be written + @raise errors.LockError: if the pid file already exists and points to a live process + @rtype: int + @return: the file descriptor of the lock file; do not close this unless + you want to unlock the pid file """ - pid = os.getpid() - pidfilename = DaemonPidFileName(name) - if IsProcessAlive(ReadPidFile(pidfilename)): - raise errors.GenericError("%s contains a live process" % pidfilename) + # We don't rename nor truncate the file to not drop locks under + # existing processes + fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600) - WriteFile(pidfilename, data="%d\n" % pid) + # Lock the PID file (and fail if not possible to do so). Any code + # wanting to send a signal to the daemon should try to lock the PID + # file before reading it. If acquiring the lock succeeds, the daemon is + # no longer running and the signal should not be sent. + LockFile(fd_pidfile) + + os.write(fd_pidfile, "%d\n" % os.getpid()) + + return fd_pidfile def RemovePidFile(name): @@ -1904,8 +2505,7 @@ def KillProcess(pid, signal_=signal.SIGTERM, timeout=30, """ def _helper(pid, signal_, wait): """Simple helper to encapsulate the kill/waitpid sequence""" - os.kill(pid, signal_) - if wait: + if IgnoreProcessNotFound(os.kill, pid, signal_) and wait: try: os.waitpid(pid, os.WNOHANG) except OSError: @@ -2043,30 +2643,6 @@ def MergeTime(timetuple): return float(seconds) + (float(microseconds) * 0.000001) -def GetDaemonPort(daemon_name): - """Get the daemon port for this cluster. - - Note that this routine does not read a ganeti-specific file, but - instead uses C{socket.getservbyname} to allow pre-customization of - this parameter outside of Ganeti. 
- - @type daemon_name: string - @param daemon_name: daemon name (in constants.DAEMONS_PORTS) - @rtype: int - - """ - if daemon_name not in constants.DAEMONS_PORTS: - raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name) - - (proto, default_port) = constants.DAEMONS_PORTS[daemon_name] - try: - port = socket.getservbyname(daemon_name, proto) - except socket.error: - port = default_port - - return port - - class LogFileHandler(logging.FileHandler): """Log handler that doesn't fallback to stderr. @@ -2084,7 +2660,7 @@ class LogFileHandler(logging.FileHandler): logging.FileHandler.__init__(self, filename, mode, encoding) self.console = open(constants.DEV_CONSOLE, "a") - def handleError(self, record): + def handleError(self, record): # pylint: disable-msg=C0103 """Handle errors which occur during an emit() call. Try to handle errors with FileHandler method, if it fails write to @@ -2093,10 +2669,10 @@ class LogFileHandler(logging.FileHandler): """ try: logging.FileHandler.handleError(self, record) - except Exception: + except Exception: # pylint: disable-msg=W0703 try: self.console.write("Cannot log message:\n%s\n" % self.format(record)) - except Exception: + except Exception: # pylint: disable-msg=W0703 # Log handler tried everything it could, now just give up pass @@ -2264,9 +2840,10 @@ def _ParseAsn1Generalizedtime(value): @type value: string @param value: ASN1 GENERALIZEDTIME timestamp + @return: Seconds since the Epoch (1970-01-01 00:00:00 UTC) """ - m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value) + m = _ASN1_TIME_REGEX.match(value) if m: # We have an offset asn1time = m.group(1) @@ -2322,6 +2899,177 @@ def GetX509CertValidity(cert): return (not_before, not_after) +def _VerifyCertificateInner(expired, not_before, not_after, now, + warn_days, error_days): + """Verifies certificate validity. + + @type expired: bool + @param expired: Whether pyOpenSSL considers the certificate as expired + @type not_before: number or None + @param not_before: Unix timestamp before which certificate is not valid + @type not_after: number or None + @param not_after: Unix timestamp after which certificate is invalid + @type now: number + @param now: Current time as Unix timestamp + @type warn_days: number or None + @param warn_days: How many days before expiration a warning should be reported + @type error_days: number or None + @param error_days: How many days before expiration an error should be reported + + """ + if expired: + msg = "Certificate is expired" + + if not_before is not None and not_after is not None: + msg += (" (valid from %s to %s)" % + (FormatTime(not_before), FormatTime(not_after))) + elif not_before is not None: + msg += " (valid from %s)" % FormatTime(not_before) + elif not_after is not None: + msg += " (valid until %s)" % FormatTime(not_after) + + return (CERT_ERROR, msg) + + elif not_before is not None and not_before > now: + return (CERT_WARNING, + "Certificate not yet valid (valid from %s)" % + FormatTime(not_before)) + + elif not_after is not None: + remaining_days = int((not_after - now) / (24 * 3600)) + + msg = "Certificate expires in about %d days" % remaining_days + + if error_days is not None and remaining_days <= error_days: + return (CERT_ERROR, msg) + + if warn_days is not None and remaining_days <= warn_days: + return (CERT_WARNING, msg) + + return (None, None) + + +def VerifyX509Certificate(cert, warn_days, error_days): + """Verifies a certificate for LUVerifyCluster. 
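How a caller might act on the (status, message) pairs produced by the certificate checks introduced here (a hedged sketch, not part of the patch; the certificate path and the day thresholds are examples):

import logging
import OpenSSL

from ganeti import errors
from ganeti import utils

def _CheckCertFile(filename="/var/lib/ganeti/server.pem"):
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                         utils.ReadFile(filename))
  # Warn 30 days before expiry, treat the last 7 days as an error
  (status, msg) = utils.VerifyX509Certificate(cert, 30, 7)
  if status == utils.CERT_ERROR:
    raise errors.GenericError("certificate error: %s" % msg)
  elif status == utils.CERT_WARNING:
    logging.warning("certificate warning: %s", msg)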
+ + @type cert: OpenSSL.crypto.X509 + @param cert: X509 certificate object + @type warn_days: number or None + @param warn_days: How many days before expiration a warning should be reported + @type error_days: number or None + @param error_days: How many days before expiration an error should be reported + + """ + # Depending on the pyOpenSSL version, this can just return (None, None) + (not_before, not_after) = GetX509CertValidity(cert) + + return _VerifyCertificateInner(cert.has_expired(), not_before, not_after, + time.time(), warn_days, error_days) + + +def SignX509Certificate(cert, key, salt): + """Sign a X509 certificate. + + An RFC822-like signature header is added in front of the certificate. + + @type cert: OpenSSL.crypto.X509 + @param cert: X509 certificate object + @type key: string + @param key: Key for HMAC + @type salt: string + @param salt: Salt for HMAC + @rtype: string + @return: Serialized and signed certificate in PEM format + + """ + if not VALID_X509_SIGNATURE_SALT.match(salt): + raise errors.GenericError("Invalid salt: %r" % salt) + + # Dumping as PEM here ensures the certificate is in a sane format + cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert) + + return ("%s: %s/%s\n\n%s" % + (constants.X509_CERT_SIGNATURE_HEADER, salt, + Sha1Hmac(key, cert_pem, salt=salt), + cert_pem)) + + +def _ExtractX509CertificateSignature(cert_pem): + """Helper function to extract signature from X509 certificate. + + """ + # Extract signature from original PEM data + for line in cert_pem.splitlines(): + if line.startswith("---"): + break + + m = X509_SIGNATURE.match(line.strip()) + if m: + return (m.group("salt"), m.group("sign")) + + raise errors.GenericError("X509 certificate signature is missing") + + +def LoadSignedX509Certificate(cert_pem, key): + """Verifies a signed X509 certificate. + + @type cert_pem: string + @param cert_pem: Certificate in PEM format and with signature header + @type key: string + @param key: Key for HMAC + @rtype: tuple; (OpenSSL.crypto.X509, string) + @return: X509 certificate object and salt + + """ + (salt, signature) = _ExtractX509CertificateSignature(cert_pem) + + # Load certificate + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem) + + # Dump again to ensure it's in a sane format + sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert) + + if not VerifySha1Hmac(key, sane_pem, signature, salt=salt): + raise errors.GenericError("X509 certificate signature is invalid") + + return (cert, salt) + + +def Sha1Hmac(key, text, salt=None): + """Calculates the HMAC-SHA1 digest of a text. + + HMAC is defined in RFC2104. + + @type key: string + @param key: Secret key + @type text: string + + """ + if salt: + salted_text = salt + text + else: + salted_text = text + + return hmac.new(key, salted_text, compat.sha1).hexdigest() + + +def VerifySha1Hmac(key, text, digest, salt=None): + """Verifies the HMAC-SHA1 digest of a text. + + HMAC is defined in RFC2104. + + @type key: string + @param key: Secret key + @type text: string + @type digest: string + @param digest: Expected digest + @rtype: bool + @return: Whether HMAC-SHA1 digest matches + + """ + return digest.lower() == Sha1Hmac(key, text, salt=salt).lower() + + def SafeEncode(text): """Return a 'safe' version of a source string. @@ -2411,6 +3159,33 @@ def CommaJoin(names): return ", ".join([str(val) for val in names]) +def FindMatch(data, name): + """Tries to find an item in a dictionary matching a name. 
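The Sha1Hmac()/VerifySha1Hmac() pair above in use, with made-up key, salt and payload (illustration only, not part of the patch):

from ganeti import utils

key = "example-cluster-key"
salt = "abcdef0123456789"
msg = "some payload"

digest = utils.Sha1Hmac(key, msg, salt=salt)
assert utils.VerifySha1Hmac(key, msg, digest, salt=salt)
assert not utils.VerifySha1Hmac(key, msg + "-tampered", digest, salt=salt)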
+ + Callers have to ensure the data names aren't contradictory (e.g. a regexp + that matches a string). If the name isn't a direct key, all regular + expression objects in the dictionary are matched against it. + + @type data: dict + @param data: Dictionary containing data + @type name: string + @param name: Name to look for + @rtype: tuple; (value in dictionary, matched groups as list) + + """ + if name in data: + return (data[name], []) + + for key, value in data.items(): + # Regex objects + if hasattr(key, "match"): + m = key.match(name) + if m: + return (value, list(m.groups())) + + return None + + def BytesToMebibyte(value): """Converts bytes to mebibytes. @@ -2442,6 +3217,26 @@ def CalculateDirectorySize(path): return BytesToMebibyte(size) +def GetMounts(filename=constants.PROC_MOUNTS): + """Returns the list of mounted filesystems. + + This function is Linux-specific. + + @param filename: path of mounts file (/proc/mounts by default) + @rtype: list of tuples + @return: list of mount entries (device, mountpoint, fstype, options) + + """ + # TODO(iustin): investigate non-Linux options (e.g. via mount output) + data = [] + mountlines = ReadFile(filename).splitlines() + for line in mountlines: + device, mountpoint, fstype, options, _ = line.split(None, 4) + data.append((device, mountpoint, fstype, options)) + + return data + + def GetFilesystemStats(path): """Returns the total and free space on a filesystem. @@ -2505,32 +3300,44 @@ def RunInSeparateProcess(fn, *args): return bool(exitcode) -def LockedMethod(fn): - """Synchronized object access decorator. +def IgnoreProcessNotFound(fn, *args, **kwargs): + """Ignores ESRCH when calling a process-related function. + + ESRCH is raised when a process is not found. - This decorator is intended to protect access to an object using the - object's own lock which is hardcoded to '_lock'. + @rtype: bool + @return: Whether process was found """ - def _LockDebug(*args, **kwargs): - if debug_locks: - logging.debug(*args, **kwargs) + try: + fn(*args, **kwargs) + except EnvironmentError, err: + # Ignore ESRCH + if err.errno == errno.ESRCH: + return False + raise - def wrapper(self, *args, **kwargs): - # pylint: disable-msg=W0212 - assert hasattr(self, '_lock') - lock = self._lock - _LockDebug("Waiting for %s", lock) - lock.acquire() - try: - _LockDebug("Acquired %s", lock) - result = fn(self, *args, **kwargs) - finally: - _LockDebug("Releasing %s", lock) - lock.release() - _LockDebug("Released %s", lock) - return result - return wrapper + return True + + +def IgnoreSignals(fn, *args, **kwargs): + """Tries to call a function ignoring failures due to EINTR. + + """ + try: + return fn(*args, **kwargs) + except EnvironmentError, err: + if err.errno == errno.EINTR: + return None + else: + raise + except (select.error, socket.error), err: + # In python 2.6 and above select.error is an IOError, so it's handled + # above, in 2.5 and below it's not, and it's handled here. + if err.args and err.args[0] == errno.EINTR: + return None + else: + raise def LockFile(fd): @@ -2552,7 +3359,8 @@ def FormatTime(val): """Formats a time value. @type val: float or None - @param val: the timestamp as returned by time.time() + @param val: Timestamp as returned by time.time() (seconds since Epoch, + 1970-01-01 00:00:00 UTC) @return: a string value or N/A if we don't have a valid timestamp """ @@ -2563,6 +3371,31 @@ def FormatTime(val): return time.strftime("%F %T", time.localtime(val)) +def FormatSeconds(secs): + """Formats seconds for easier reading. 
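FindMatch(), defined a little earlier in this hunk, first tries the name as a literal key and then matches it against any regular-expression keys. A short example with a made-up lookup table (not part of the patch):

import re
from ganeti import utils

handlers = {
  "memory": "handle-memory",
  re.compile(r"^disk/(\d+)$"): "handle-disk",
  }

assert utils.FindMatch(handlers, "memory") == ("handle-memory", [])
assert utils.FindMatch(handlers, "disk/2") == ("handle-disk", ["2"])
assert utils.FindMatch(handlers, "unknown") is None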
+ + @type secs: number + @param secs: Number of seconds + @rtype: string + @return: Formatted seconds (e.g. "2d 9h 19m 49s") + + """ + parts = [] + + secs = round(secs, 0) + + if secs > 0: + # Negative values would be a bit tricky + for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]: + (complete, secs) = divmod(secs, one) + if complete or parts: + parts.append("%d%s" % (complete, unit)) + + parts.append("%ds" % secs) + + return " ".join(parts) + + def ReadWatcherPauseFile(filename, now=None, remove_after=3600): """Reads the watcher pause file. @@ -2762,6 +3595,66 @@ def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep, wait_fn(current_delay) +def GetClosedTempfile(*args, **kwargs): + """Creates a temporary file and returns its path. + + """ + (fd, path) = tempfile.mkstemp(*args, **kwargs) + _CloseFDNoErr(fd) + return path + + +def GenerateSelfSignedX509Cert(common_name, validity): + """Generates a self-signed X509 certificate. + + @type common_name: string + @param common_name: commonName value + @type validity: int + @param validity: Validity for certificate in seconds + + """ + # Create private and public key + key = OpenSSL.crypto.PKey() + key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS) + + # Create self-signed certificate + cert = OpenSSL.crypto.X509() + if common_name: + cert.get_subject().CN = common_name + cert.set_serial_number(1) + cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notAfter(validity) + cert.set_issuer(cert.get_subject()) + cert.set_pubkey(key) + cert.sign(key, constants.X509_CERT_SIGN_DIGEST) + + key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key) + cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert) + + return (key_pem, cert_pem) + + +def GenerateSelfSignedSslCert(filename, common_name=constants.X509_CERT_CN, + validity=constants.X509_CERT_DEFAULT_VALIDITY): + """Legacy function to generate self-signed X509 certificate. + + @type filename: str + @param filename: path to write certificate to + @type common_name: string + @param common_name: commonName value + @type validity: int + @param validity: validity of certificate in number of days + + """ + # TODO: Investigate using the cluster name instead of X505_CERT_CN for + # common_name, as cluster-renames are very seldom, and it'd be nice if RAPI + # and node daemon certificates have the proper Subject/Issuer. + (key_pem, cert_pem) = GenerateSelfSignedX509Cert(common_name, + validity * 24 * 60 * 60) + + WriteFile(filename, mode=0400, data=key_pem + cert_pem) + + class FileLock(object): """Utility class for file locks. @@ -2973,6 +3866,58 @@ def SignalHandled(signums): return wrap +class SignalWakeupFd(object): + try: + # This is only supported in Python 2.5 and above (some distributions + # backported it to Python 2.4) + _set_wakeup_fd_fn = signal.set_wakeup_fd + except AttributeError: + # Not supported + def _SetWakeupFd(self, _): # pylint: disable-msg=R0201 + return -1 + else: + def _SetWakeupFd(self, fd): + return self._set_wakeup_fd_fn(fd) + + def __init__(self): + """Initializes this class. + + """ + (read_fd, write_fd) = os.pipe() + + # Once these succeeded, the file descriptors will be closed automatically. + # Buffer size 0 is important, otherwise .read() with a specified length + # might buffer data and the file descriptors won't be marked readable. 
+ self._read_fh = os.fdopen(read_fd, "r", 0) + self._write_fh = os.fdopen(write_fd, "w", 0) + + self._previous = self._SetWakeupFd(self._write_fh.fileno()) + + # Utility functions + self.fileno = self._read_fh.fileno + self.read = self._read_fh.read + + def Reset(self): + """Restores the previous wakeup file descriptor. + + """ + if hasattr(self, "_previous") and self._previous is not None: + self._SetWakeupFd(self._previous) + self._previous = None + + def Notify(self): + """Notifies the wakeup file descriptor. + + """ + self._write_fh.write("\0") + + def __del__(self): + """Called before object deletion. + + """ + self.Reset() + + class SignalHandler(object): """Generic signal handler class. @@ -2987,16 +3932,23 @@ class SignalHandler(object): @ivar called: tracks whether any of the signals have been raised """ - def __init__(self, signum): + def __init__(self, signum, handler_fn=None, wakeup=None): """Constructs a new SignalHandler instance. @type signum: int or list of ints @param signum: Single signal number or set of signal numbers + @type handler_fn: callable + @param handler_fn: Signal handling function """ + assert handler_fn is None or callable(handler_fn) + self.signum = set(signum) self.called = False + self._handler_fn = handler_fn + self._wakeup = wakeup + self._previous = {} try: for signum in self.signum: @@ -3037,8 +3989,7 @@ class SignalHandler(object): """ self.called = False - # we don't care about arguments, but we leave them named for the future - def _HandleSignal(self, signum, frame): # pylint: disable-msg=W0613 + def _HandleSignal(self, signum, frame): """Actual signal handling function. """ @@ -3046,6 +3997,13 @@ class SignalHandler(object): # solution in Python -- there are no atomic types. self.called = True + if self._wakeup: + # Notify whoever is interested in signals + self._wakeup.Notify() + + if self._handler_fn: + self._handler_fn(signum, frame) + class FieldSet(object): """A simple field set. @@ -3087,3 +4045,56 @@ class FieldSet(object): """ return [val for val in items if not self.Matches(val)] + + +class RunningTimeout(object): + """Class to calculate remaining timeout when doing several operations. + + """ + __slots__ = [ + "_allow_negative", + "_start_time", + "_time_fn", + "_timeout", + ] + + def __init__(self, timeout, allow_negative, _time_fn=time.time): + """Initializes this class. + + @type timeout: float + @param timeout: Timeout duration + @type allow_negative: bool + @param allow_negative: Whether to return values below zero + @param _time_fn: Time function for unittests + + """ + object.__init__(self) + + if timeout is not None and timeout < 0.0: + raise ValueError("Timeout must not be negative") + + self._timeout = timeout + self._allow_negative = allow_negative + self._time_fn = _time_fn + + self._start_time = None + + def Remaining(self): + """Returns the remaining timeout. + + """ + if self._timeout is None: + return None + + # Get start time on first calculation + if self._start_time is None: + self._start_time = self._time_fn() + + # Calculate remaining time + remaining_timeout = self._start_time + self._timeout - self._time_fn() + + if not self._allow_negative: + # Ensure timeout is always >= 0 + return max(0.0, remaining_timeout) + + return remaining_timeout
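Taken together, the certificate helpers added above form a simple round trip: generate a certificate, sign its PEM dump with a salted HMAC, then verify and load it again on the receiving side. The following is a minimal usage sketch, not part of the patch; it assumes a Ganeti source tree on the import path, and the common name, HMAC key and salt are made-up illustration values.

import OpenSSL

from ganeti import utils

# Throw-away self-signed certificate, valid for one hour
(_, cert_pem) = utils.GenerateSelfSignedX509Cert("node1.example.com", 3600)
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)

hmac_key = "shared-cluster-key"   # illustration value
salt = "1234abcd"                 # must satisfy VALID_X509_SIGNATURE_SALT

# Serialize the certificate with the signature header prepended ...
signed_pem = utils.SignX509Certificate(cert, hmac_key, salt)

# ... and verify the HMAC and strip the header again on the other side
(verified_cert, used_salt) = utils.LoadSignedX509Certificate(signed_pem,
                                                             hmac_key)
assert used_salt == salt

Note that both sides sign and verify the PEM dump rather than the in-memory object; that is why LoadSignedX509Certificate re-dumps the loaded certificate before checking the HMAC, keeping the signature independent of pyOpenSSL's internal representation.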
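FindMatch accepts dictionaries that mix literal keys with compiled regular expressions: exact keys win, otherwise each regexp key is tried and its capture groups are returned alongside the value. A small sketch of the intended call pattern, again assuming a Ganeti tree on the import path; the dispatch table below is invented for illustration.

import re

from ganeti import utils

# Hypothetical dispatch table: plain strings are matched exactly, compiled
# regexps are tried afterwards and their capture groups are returned as well
handlers = {
  "shutdown": "plain",
  re.compile(r"^instance-(\d+)$"): "regexp",
  }

assert utils.FindMatch(handlers, "shutdown") == ("plain", [])
assert utils.FindMatch(handlers, "instance-17") == ("regexp", ["17"])
assert utils.FindMatch(handlers, "something-else") is None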
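The signal-related additions are designed to be used together: SignalHandler records delivery and, when given a wakeup object, calls its Notify() method; SignalWakeupFd's pipe then makes a select() loop return immediately instead of waiting out its timeout; and IgnoreSignals absorbs the EINTR that the interruption causes. The sketch below assumes a Ganeti tree on the import path and relies on SignalHandler.Reset(), the pre-existing method that restores the previous handlers.

import select
import signal

from ganeti import utils

def WaitForSigterm():
  wakeup = utils.SignalWakeupFd()
  handler = utils.SignalHandler([signal.SIGTERM], wakeup=wakeup)
  try:
    while not handler.called:
      # Either the wakeup fd becomes readable (Notify ran from the signal
      # handler) or select() is interrupted; IgnoreSignals turns the
      # resulting EINTR into a plain None return so the loop can re-check
      utils.IgnoreSignals(select.select, [wakeup], [], [], 60.0)
  finally:
    handler.Reset()
    wakeup.Reset()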
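Finally, RunningTimeout exists to spread one overall deadline over several blocking operations: every call to Remaining() reports how much of the original timeout is left. A sketch under the same import assumption; the socket-based helper is hypothetical.

from ganeti import utils

def RecvWithDeadline(sock, count, timeout=30.0):
  """Receives 'count' messages within a single overall deadline."""
  running = utils.RunningTimeout(timeout, False)
  messages = []
  for _ in range(count):
    # Remaining() shrinks on every call; with allow_negative=False it
    # bottoms out at 0.0 instead of going below zero
    sock.settimeout(running.Remaining())
    messages.append(sock.recv(4096))
  return messages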