"""
-import sys
import os
-import sha
import time
import subprocess
import re
import fcntl
import resource
import logging
+import logging.handlers
import signal
+import datetime
+import calendar
from cStringIO import StringIO
+try:
+ from hashlib import sha1
+except ImportError:
+ import sha
+ sha1 = sha.new
+
from ganeti import errors
from ganeti import constants
_locksheld = []
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
-debug = False
debug_locks = False
#: when set to True, L{RunCmd} is disabled
no_fork = False
+_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
+
class RunResult(object):
"""Holds the result of running external programs.
output = property(_GetOutput, None, None, "Return full output")
-def RunCmd(cmd, env=None, output=None, cwd='/'):
+def RunCmd(cmd, env=None, output=None, cwd='/', reset_env=False):
"""Execute a (shell) command.
The command should not read from its standard input, as it will be
closed.
- @type cmd: string or list
+ @type cmd: string or list
@param cmd: Command to run
@type env: dict
@param env: Additional environment
@type cwd: string
@param cwd: if specified, will be used as the working
directory for the command; the default will be /
+ @type reset_env: boolean
+ @param reset_env: whether to reset or keep the default os environment
@rtype: L{RunResult}
@return: RunResult instance
- @raise erors.ProgrammerError: if we call this when forks are disabled
+ @raise errors.ProgrammerError: if we call this when forks are disabled
"""
if no_fork:
shell = True
logging.debug("RunCmd '%s'", strcmd)
- cmd_env = os.environ.copy()
- cmd_env["LC_ALL"] = "C"
+ if not reset_env:
+ cmd_env = os.environ.copy()
+ cmd_env["LC_ALL"] = "C"
+ else:
+ cmd_env = {}
+
if env is not None:
cmd_env.update(env)
- if output is None:
- out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
- else:
- status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
- out = err = ""
+ try:
+ if output is None:
+ out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
+ else:
+ status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
+ out = err = ""
+ except OSError, err:
+ if err.errno == errno.ENOENT:
+ raise errors.OpExecError("Can't execute '%s': not found (%s)" %
+ (strcmd, err))
+ else:
+ raise
if status >= 0:
exitcode = status
return RunResult(exitcode, signal_, out, err, strcmd)
+
def _RunCmdPipe(cmd, env, via_shell, cwd):
"""Run a command and return its output.
return status
+def RunParts(dir_name, env=None, reset_env=False):
+ """Run scripts or programs in a directory
+
+ @type dir_name: string
+ @param dir_name: absolute path to a directory
+ @type env: dict
+ @param env: The environment to use
+ @type reset_env: boolean
+ @param reset_env: whether to reset or keep the default os environment
+ @rtype: list of tuples
+ @return: list of (name, (one of RUNDIR_STATUS), RunResult)
+
+ """
+ rr = []
+
+ try:
+ dir_contents = ListVisibleFiles(dir_name)
+ except OSError, err:
+ logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
+ return rr
+
+ for relname in sorted(dir_contents):
+ fname = PathJoin(dir_name, relname)
+ if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
+ constants.EXT_PLUGIN_MASK.match(relname) is not None):
+ rr.append((relname, constants.RUNPARTS_SKIP, None))
+ else:
+ try:
+ result = RunCmd([fname], env=env, reset_env=reset_env)
+ except Exception, err: # pylint: disable-msg=W0703
+ rr.append((relname, constants.RUNPARTS_ERR, str(err)))
+ else:
+ rr.append((relname, constants.RUNPARTS_RUN, result))
+
+ return rr
+
+
def RemoveFile(filename):
"""Remove a file ignoring some errors.
raise
+def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
+ """Renames a file.
+
+ @type old: string
+ @param old: Original path
+ @type new: string
+ @param new: New path
+ @type mkdir: bool
+ @param mkdir: Whether to create target directory if it doesn't exist
+ @type mkdir_mode: int
+ @param mkdir_mode: Mode for newly created directories
+
+ """
+ try:
+ return os.rename(old, new)
+ except OSError, err:
+ # In at least one use case of this function, the job queue, directory
+ # creation is very rare. Checking for the directory before renaming is not
+ # as efficient.
+ if mkdir and err.errno == errno.ENOENT:
+ # Create directory and try again
+ dirname = os.path.dirname(new)
+ try:
+ os.makedirs(dirname, mode=mkdir_mode)
+ except OSError, err:
+ # Ignore EEXIST. This is only handled in os.makedirs as included in
+ # Python 2.5 and above.
+ if err.errno != errno.EEXIST or not os.path.exists(dirname):
+ raise
+
+ return os.rename(old, new)
+
+ raise
+
+
+def ResetTempfileModule():
+ """Resets the random name generator of the tempfile module.
+
+ This function should be called after C{os.fork} in the child process to
+ ensure it creates a newly seeded random generator. Otherwise it would
+ generate the same random parts as the parent process. If several processes
+ race for the creation of a temporary file, this could lead to one not getting
+ a temporary name.
+
+ """
+ # pylint: disable-msg=W0212
+ if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
+ tempfile._once_lock.acquire()
+ try:
+ # Reset random name generator
+ tempfile._name_sequence = None
+ finally:
+ tempfile._once_lock.release()
+ else:
+ logging.critical("The tempfile module misses at least one of the"
+ " '_once_lock' and '_name_sequence' attributes")
+
+
def _FingerprintFile(filename):
"""Compute the fingerprint of a file.
f = open(filename)
- fp = sha.sha()
+ fp = sha1()
while True:
data = f.read(4096)
if not data:
return ret
-def CheckDict(target, template, logname=None):
- """Ensure a dictionary has a required set of keys.
-
- For the given dictionaries I{target} and I{template}, ensure
- I{target} has all the keys from I{template}. Missing keys are added
- with values from template.
+def ForceDictType(target, key_types, allowed_values=None):
+ """Force the values of a dict to have certain types.
@type target: dict
- @param target: the dictionary to update
- @type template: dict
- @param template: the dictionary holding the default values
- @type logname: str or None
- @param logname: if not None, causes the missing keys to be
- logged with this name
+ @param target: the dict to update
+ @type key_types: dict
+ @param key_types: dict mapping target dict keys to types
+ in constants.ENFORCEABLE_TYPES
+ @type allowed_values: list
+ @keyword allowed_values: list of specially allowed values
"""
- missing = []
- for k in template:
- if k not in target:
- missing.append(k)
- target[k] = template[k]
+ if allowed_values is None:
+ allowed_values = []
+
+ if not isinstance(target, dict):
+ msg = "Expected dictionary, got '%s'" % target
+ raise errors.TypeEnforcementError(msg)
+
+ for key in target:
+ if key not in key_types:
+ msg = "Unknown key '%s'" % key
+ raise errors.TypeEnforcementError(msg)
- if missing and logname:
- logging.warning('%s missing keys %s', logname, ', '.join(missing))
+ if target[key] in allowed_values:
+ continue
+
+ ktype = key_types[key]
+ if ktype not in constants.ENFORCEABLE_TYPES:
+ msg = "'%s' has non-enforceable type %s" % (key, ktype)
+ raise errors.ProgrammerError(msg)
+
+ if ktype == constants.VTYPE_STRING:
+ if not isinstance(target[key], basestring):
+ if isinstance(target[key], bool) and not target[key]:
+ target[key] = ''
+ else:
+ msg = "'%s' (value %s) is not a valid string" % (key, target[key])
+ raise errors.TypeEnforcementError(msg)
+ elif ktype == constants.VTYPE_BOOL:
+ if isinstance(target[key], basestring) and target[key]:
+ if target[key].lower() == constants.VALUE_FALSE:
+ target[key] = False
+ elif target[key].lower() == constants.VALUE_TRUE:
+ target[key] = True
+ else:
+ msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
+ raise errors.TypeEnforcementError(msg)
+ elif target[key]:
+ target[key] = True
+ else:
+ target[key] = False
+ elif ktype == constants.VTYPE_SIZE:
+ try:
+ target[key] = ParseUnit(target[key])
+ except errors.UnitParseError, err:
+ msg = "'%s' (value %s) is not a valid size. error: %s" % \
+ (key, target[key], err)
+ raise errors.TypeEnforcementError(msg)
+ elif ktype == constants.VTYPE_INT:
+ try:
+ target[key] = int(target[key])
+ except (ValueError, TypeError):
+ msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
+ raise errors.TypeEnforcementError(msg)
def IsProcessAlive(pid):
@type pidfile: string
@param pidfile: path to the file containing the pid
@rtype: int
- @return: The process id, if the file exista and contains a valid PID,
+ @return: The process id, if the file exists and contains a valid PID,
otherwise 0
"""
try:
- pf = open(pidfile, 'r')
+ raw_data = ReadFile(pidfile)
except EnvironmentError, err:
if err.errno != errno.ENOENT:
- logging.exception("Can't read pid file?!")
+ logging.exception("Can't read pid file")
return 0
try:
- pid = int(pf.read())
- except ValueError, err:
+ pid = int(raw_data)
+ except (TypeError, ValueError), err:
logging.info("Can't parse pid file contents", exc_info=True)
return 0
return pid
-def MatchNameComponent(key, name_list):
+def MatchNameComponent(key, name_list, case_sensitive=True):
"""Try to match a name against a list.
This function will try to match a name like test1 against a list
this list, I{'test1'} as well as I{'test1.example'} will match, but
not I{'test1.ex'}. A multiple match will be considered as no match
at all (e.g. I{'test1'} against C{['test1.example.com',
- 'test1.example.org']}).
+ 'test1.example.org']}), except when the key fully matches an entry
+ (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
@type key: str
@param key: the name to be searched
@type name_list: list
@param name_list: the list of strings against which to search the key
+ @type case_sensitive: boolean
+ @param case_sensitive: whether to provide a case-sensitive match
@rtype: None or str
@return: None if there is no match I{or} if there are multiple matches,
otherwise the element from the list which matches
"""
- mo = re.compile("^%s(\..*)?$" % re.escape(key))
- names_filtered = [name for name in name_list if mo.match(name) is not None]
- if len(names_filtered) != 1:
- return None
- return names_filtered[0]
+ if key in name_list:
+ return key
+
+ re_flags = 0
+ if not case_sensitive:
+ re_flags |= re.IGNORECASE
+ key = key.upper()
+ mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
+ names_filtered = []
+ string_matches = []
+ for name in name_list:
+ if mo.match(name) is not None:
+ names_filtered.append(name)
+ if not case_sensitive and key == name.upper():
+ string_matches.append(name)
+
+ if len(string_matches) == 1:
+ return string_matches[0]
+ if len(names_filtered) == 1:
+ return names_filtered[0]
+ return None
class HostInfo:
"""Class implementing resolver and hostname functionality
"""
+ _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$")
+
def __init__(self, name=None):
"""Initialize the host name object.
return result
+ @classmethod
+ def NormalizeName(cls, hostname):
+ """Validate and normalize the given hostname.
+
+ @attention: the validation is a bit more relaxed than the standards
+ require; most importantly, we allow underscores in names
+ @raise errors.OpPrereqError: when the name is not valid
+
+ """
+ hostname = hostname.lower()
+ if (not cls._VALID_NAME_RE.match(hostname) or
+ # double-dots, meaning empty label
+ ".." in hostname or
+ # empty initial label
+ hostname.startswith(".")):
+ raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
+ errors.ECODE_INVAL)
+ if hostname.endswith("."):
+ hostname = hostname.rstrip(".")
+ return hostname
+
+
+def GetHostInfo(name=None):
+ """Lookup host name and raise an OpPrereqError for failures"""
+
+ try:
+ return HostInfo(name)
+ except errors.ResolverError, err:
+ raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
+ (err[0], err[2]), errors.ECODE_RESOLVER)
+
def ListVolumeGroups():
"""List volume groups and their size
return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
-def CheckBEParams(beparams):
- """Checks whether the user-supplied be-params are valid,
- and converts them from string format where appropriate.
-
- @type beparams: dict
- @param beparams: new params dict
-
- """
- if beparams:
- for item in beparams:
- if item not in constants.BES_PARAMETERS:
- raise errors.OpPrereqError("Unknown backend parameter %s" % item)
- if item in (constants.BE_MEMORY, constants.BE_VCPUS):
- val = beparams[item]
- if val != constants.VALUE_DEFAULT:
- try:
- val = int(val)
- except ValueError, err:
- raise errors.OpPrereqError("Invalid %s size: %s" % (item, str(err)))
- beparams[item] = val
- if item in (constants.BE_AUTO_BALANCE):
- val = beparams[item]
- if not isinstance(val, bool):
- if val == constants.VALUE_TRUE:
- beparams[item] = True
- elif val == constants.VALUE_FALSE:
- beparams[item] = False
- else:
- raise errors.OpPrereqError("Invalid %s value: %s" % (item, val))
-
-
def NiceSort(name_list):
"""Sort a list of strings based on digit and non-digit groupings.
"""
try:
nv = fn(val)
- except (ValueError, TypeError), err:
+ except (ValueError, TypeError):
nv = val
return nv
@type ip: str
@param ip: the address to be checked
@rtype: a regular expression match object
- @return: a regular epression match object, or None if the
+ @return: a regular expression match object, or None if the
address is not valid
"""
This function will check all arguments in the args list so that they
are valid shell parameters (i.e. they don't contain shell
- metacharaters). If everything is ok, it will return the result of
+ metacharacters). If everything is ok, it will return the result of
template % args.
@type template: str
is always an int in MiB.
"""
- m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', input_string)
+ m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
if not m:
raise errors.UnitParseError("Invalid format")
@param aliases: the list of aliases to add for the hostname
"""
+ # FIXME: use WriteFile + fn rather than duplicating its efforts
# Ensure aliases are unique
aliases = UniqueSequence([hostname] + aliases)[1:]
out.flush()
os.fsync(out)
+ os.chmod(tmpname, 0644)
os.rename(tmpname, file_name)
finally:
f.close()
@param hostname: the hostname to be removed
"""
+ # FIXME: use WriteFile + fn rather than duplicating its efforts
fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
try:
out = os.fdopen(fd, 'w')
out.flush()
os.fsync(out)
+ os.chmod(tmpname, 0644)
os.rename(tmpname, file_name)
finally:
f.close()
RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())
+def TimestampForFilename():
+ """Returns the current time formatted for filenames.
+
+ The format doesn't contain colons as some shells and applications treat
+ them as separators.
+
+ """
+ return time.strftime("%Y-%m-%d_%H_%M_%S")
+
+
def CreateBackup(file_name):
"""Creates a backup of a file.
raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
file_name)
- prefix = '%s.backup-%d.' % (os.path.basename(file_name), int(time.time()))
+ prefix = ("%s.backup-%s." %
+ (os.path.basename(file_name), TimestampForFilename()))
dir_name = os.path.dirname(file_name)
fsrc = open(file_name, 'rb')
(fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
fdst = os.fdopen(fd, 'wb')
try:
+ logging.debug("Backing up %s at %s", file_name, backup_name)
shutil.copyfileobj(fsrc, fdst)
finally:
fdst.close()
@type args: list
@param args: list of arguments to be quoted
@rtype: str
- @return: the quoted arguments concatenaned with spaces
+ @return: the quoted arguments concatenated with spaces
"""
return ' '.join([ShellQuote(i) for i in args])
@type port: int
@param port: the port to connect to
@type timeout: int
- @param timeout: the timeout on the connection attemp
+ @param timeout: the timeout on the connection attempt
@type live_port_needed: boolean
@param live_port_needed: whether a closed port will cause the
function to return failure, as if there was a timeout
if source is not None:
try:
sock.bind((source, 0))
- except socket.error, (errcode, errstring):
+ except socket.error, (errcode, _):
if errcode == errno.EADDRNOTAVAIL:
success = False
success = True
except socket.timeout:
success = False
- except socket.error, (errcode, errstring):
+ except socket.error, (errcode, _):
success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
return success
address.
@type address: string
- @param address: the addres to check
+ @param address: the address to check
@rtype: bool
@return: True if we own the address
@param path: the directory to enumerate
@rtype: list
@return: the list of all files not starting with a dot
+ @raise ProgrammerError: if L{path} is not an absolute and normalized path
"""
+ if not IsNormAbsPath(path):
+ raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
+ " absolute/normalized: '%s'" % path)
files = [i for i in os.listdir(path) if not i.startswith(".")]
files.sort()
return files
@rtype: str
"""
- f = open("/proc/sys/kernel/random/uuid", "r")
- try:
- return f.read(128).rstrip("\n")
- finally:
- f.close()
+ return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
-def GenerateSecret():
+def GenerateSecret(numbytes=20):
"""Generates a random secret.
- This will generate a pseudo-random secret, and return its sha digest
+ This will generate a pseudo-random secret returning a hex string
(so that it can be used where an ASCII string is needed).
+ @param numbytes: the number of bytes which will be represented by the returned
+ string (defaulting to 20, the length of a SHA1 hash)
@rtype: str
- @return: a sha1 hexdigest of a block of 64 random bytes
+ @return: a hex representation of the pseudo-random sequence
+
+ """
+ return os.urandom(numbytes).encode('hex')
+
+
+def EnsureDirs(dirs):
+ """Make required directories, if they don't exist.
+
+ @param dirs: list of tuples (dir_name, dir_mode)
+ @type dirs: list of (string, integer)
"""
- return sha.new(os.urandom(64)).hexdigest()
+ for dir_name, dir_mode in dirs:
+ try:
+ os.mkdir(dir_name, dir_mode)
+ except EnvironmentError, err:
+ if err.errno != errno.EEXIST:
+ raise errors.GenericError("Cannot create needed directory"
+ " '%s': %s" % (dir_name, err))
+ if not os.path.isdir(dir_name):
+ raise errors.GenericError("%s is not a directory" % dir_name)
-def ReadFile(file_name, size=None):
+def ReadFile(file_name, size=-1):
"""Reads a file.
- @type size: None or int
- @param size: Read at most size bytes
+ @type size: int
+ @param size: Read at most size bytes (if negative, entire file)
@rtype: str
- @return: the (possibly partial) conent of the file
+ @return: the (possibly partial) content of the file
"""
f = open(file_name, "r")
try:
- if size is None:
- return f.read()
- else:
- return f.read(size)
+ return f.read(size)
finally:
f.close()
mtime/atime of the file.
If the function doesn't raise an exception, it has succeeded and the
- target file has the new contents. If the file has raised an
+ target file has the new contents. If the function has raised an
exception, an existing target file should be unmodified and the
temporary file should be removed.
@type fn: callable
@param fn: content writing function, called with
file descriptor as parameter
- @type data: sr
+ @type data: str
@param data: contents of the file
@type mode: int
@param mode: file mode
@return: None if the 'close' parameter evaluates to True,
otherwise the file descriptor
- @raise errors.ProgrammerError: if an of the arguments are not valid
+ @raise errors.ProgrammerError: if any of the arguments are not valid
"""
if not os.path.isabs(file_name):
dir_name, base_name = os.path.split(file_name)
fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
+ do_remove = True
# here we need to make sure we remove the temp file, if any error
# leaves it in place
try:
os.utime(new_name, (atime, mtime))
if not dry_run:
os.rename(new_name, file_name)
+ do_remove = False
finally:
if close:
os.close(fd)
result = None
else:
result = fd
- RemoveFile(new_name)
+ if do_remove:
+ RemoveFile(new_name)
return result
return None
-def all(seq, pred=bool):
+def all(seq, pred=bool): # pylint: disable-msg=W0622
"Returns True if pred(x) is True for every element in the iterable"
- for elem in itertools.ifilterfalse(pred, seq):
+ for _ in itertools.ifilterfalse(pred, seq):
return False
return True
-def any(seq, pred=bool):
+def any(seq, pred=bool): # pylint: disable-msg=W0622
"Returns True if pred(x) is True for at least one element in the iterable"
- for elem in itertools.ifilter(pred, seq):
+ for _ in itertools.ifilter(pred, seq):
return True
return False
+def SingleWaitForFdCondition(fdobj, event, timeout):
+ """Waits for a condition to occur on the socket.
+
+ Immediately returns at the first interruption.
+
+ @type fdobj: integer or object supporting a fileno() method
+ @param fdobj: entity to wait for events on
+ @type event: integer
+ @param event: ORed condition (see select module)
+ @type timeout: float or None
+ @param timeout: Timeout in seconds
+ @rtype: int or None
+ @return: None for timeout, otherwise occurred conditions
+
+ """
+ check = (event | select.POLLPRI |
+ select.POLLNVAL | select.POLLHUP | select.POLLERR)
+
+ if timeout is not None:
+ # Poller object expects milliseconds
+ timeout *= 1000
+
+ poller = select.poll()
+ poller.register(fdobj, event)
+ try:
+ # TODO: If the main thread receives a signal and we have no timeout, we
+ # could wait forever. This should check a global "quit" flag or something
+ # every so often.
+ io_events = poller.poll(timeout)
+ except select.error, err:
+ if err[0] != errno.EINTR:
+ raise
+ io_events = []
+ if io_events and io_events[0][1] & check:
+ return io_events[0][1]
+ else:
+ return None
+
+
+class FdConditionWaiterHelper(object):
+ """Retry helper for WaitForFdCondition.
+
+ This class contains the retried and wait functions that make sure
+ WaitForFdCondition can continue waiting until the timeout is actually
+ expired.
+
+ """
+
+ def __init__(self, timeout):
+ self.timeout = timeout
+
+ def Poll(self, fdobj, event):
+ result = SingleWaitForFdCondition(fdobj, event, self.timeout)
+ if result is None:
+ raise RetryAgain()
+ else:
+ return result
+
+ def UpdateTimeout(self, timeout):
+ self.timeout = timeout
+
+
+def WaitForFdCondition(fdobj, event, timeout):
+ """Waits for a condition to occur on the socket.
+
+ Retries until the timeout is expired, even if interrupted.
+
+ @type fdobj: integer or object supporting a fileno() method
+ @param fdobj: entity to wait for events on
+ @type event: integer
+ @param event: ORed condition (see select module)
+ @type timeout: float or None
+ @param timeout: Timeout in seconds
+ @rtype: int or None
+ @return: None for timeout, otherwise occurred conditions
+
+ """
+ if timeout is not None:
+ retrywaiter = FdConditionWaiterHelper(timeout)
+ result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
+ args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
+ else:
+ result = None
+ while result is None:
+ result = SingleWaitForFdCondition(fdobj, event, timeout)
+ return result
+
+
+def partition(seq, pred=bool): # # pylint: disable-msg=W0622
+ "Partition a list in two, based on the given predicate"
+ return (list(itertools.ifilter(pred, seq)),
+ list(itertools.ifilterfalse(pred, seq)))
+
+
def UniqueSequence(seq):
"""Returns a list with unique elements.
Element order is preserved.
@type seq: sequence
- @param seq: the sequence with the source elementes
+ @param seq: the sequence with the source elements
@rtype: list
@return: list of unique elements from seq
return [i for i in seq if i not in seen and not seen.add(i)]
-def IsValidMac(mac):
- """Predicate to check if a MAC address is valid.
+def NormalizeAndValidateMac(mac):
+ """Normalizes and checks if a MAC address is valid.
- Checks wether the supplied MAC address is formally correct, only
- accepts colon separated format.
+ Checks whether the supplied MAC address is formally correct, only
+ accepts colon separated format. Normalizes it to all lower case.
@type mac: str
@param mac: the MAC to be validated
- @rtype: boolean
- @return: True is the MAC seems valid
+ @rtype: str
+ @return: returns the normalized and validated MAC.
+
+ @raise errors.OpPrereqError: If the MAC isn't valid
"""
- mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$")
- return mac_check.match(mac) is not None
+ mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
+ if not mac_check.match(mac):
+ raise errors.OpPrereqError("Invalid MAC address specified: %s" %
+ mac, errors.ECODE_INVAL)
+
+ return mac.lower()
def TestDelay(duration):
"""
if duration < 0:
- return False
+ return False, "Invalid sleep duration"
time.sleep(duration)
- return True
+ return True, None
-def Daemonize(logfile, noclose_fds=None):
- """Daemonize the current process.
+def _CloseFDNoErr(fd, retries=5):
+ """Close a file descriptor ignoring errors.
- This detaches the current process from the controlling terminal and
- runs it in the background as a daemon.
+ @type fd: int
+ @param fd: the file descriptor
+ @type retries: int
+ @param retries: how many retries to make, in case we get any
+ other error than EBADF
+
+ """
+ try:
+ os.close(fd)
+ except OSError, err:
+ if err.errno != errno.EBADF:
+ if retries > 0:
+ _CloseFDNoErr(fd, retries - 1)
+ # else either it's closed already or we're out of retries, so we
+ # ignore this and go on
+
+
+def CloseFDs(noclose_fds=None):
+ """Close file descriptors.
+
+ This closes all file descriptors above 2 (i.e. except
+ stdin/out/err).
- @type logfile: str
- @param logfile: the logfile to which we should redirect stdout/stderr
@type noclose_fds: list or None
@param noclose_fds: if given, it denotes a list of file descriptor
that should not be closed
- @rtype: int
- @returns: the value zero
"""
- UMASK = 077
- WORKDIR = "/"
# Default maximum for the number of available file descriptors.
if 'SC_OPEN_MAX' in os.sysconf_names:
try:
MAXFD = 1024
else:
MAXFD = 1024
+ maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
+ if (maxfd == resource.RLIM_INFINITY):
+ maxfd = MAXFD
+
+ # Iterate through and close all file descriptors (except the standard ones)
+ for fd in range(3, maxfd):
+ if noclose_fds and fd in noclose_fds:
+ continue
+ _CloseFDNoErr(fd)
+
+
+def Daemonize(logfile):
+ """Daemonize the current process.
+
+ This detaches the current process from the controlling terminal and
+ runs it in the background as a daemon.
+
+ @type logfile: str
+ @param logfile: the logfile to which we should redirect stdout/stderr
+ @rtype: int
+ @return: the value zero
+
+ """
+ # pylint: disable-msg=W0212
+ # yes, we really want os._exit
+ UMASK = 077
+ WORKDIR = "/"
# this might fail
pid = os.fork()
os._exit(0) # Exit parent (the first child) of the second child.
else:
os._exit(0) # Exit parent of the first child.
- maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
- if (maxfd == resource.RLIM_INFINITY):
- maxfd = MAXFD
- # Iterate through and close all file descriptors.
- for fd in range(0, maxfd):
- if noclose_fds and fd in noclose_fds:
- continue
- try:
- os.close(fd)
- except OSError: # ERROR, fd wasn't open to begin with (ignored)
- pass
- os.open(logfile, os.O_RDWR|os.O_CREAT|os.O_APPEND, 0600)
- # Duplicate standard input to standard output and standard error.
- os.dup2(0, 1) # standard output (1)
- os.dup2(0, 2) # standard error (2)
+ for fd in range(3):
+ _CloseFDNoErr(fd)
+ i = os.open("/dev/null", os.O_RDONLY) # stdin
+ assert i == 0, "Can't close/reopen stdin"
+ i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
+ assert i == 1, "Can't close/reopen stdout"
+ # Duplicate standard output to standard error.
+ os.dup2(1, 2)
return 0
daemon name
"""
- return os.path.join(constants.RUN_GANETI_DIR, "%s.pid" % name)
+ return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
+
+
+def EnsureDaemon(name):
+ """Check for and start daemon if not alive.
+
+ """
+ result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
+ if result.failed:
+ logging.error("Can't start daemon '%s', failure %s, output: %s",
+ name, result.fail_reason, result.output)
+ return False
+
+ return True
def WritePidFile(name):
@param name: the daemon name used to derive the pidfile name
"""
- pid = os.getpid()
pidfilename = DaemonPidFileName(name)
# TODO: we could check here that the file contains our pid
try:
RemoveFile(pidfilename)
- except:
+ except: # pylint: disable-msg=W0702
pass
if not IsProcessAlive(pid):
return
+
_helper(pid, signal_, waitpid)
+
if timeout <= 0:
return
- # Wait up to $timeout seconds
- end = time.time() + timeout
- wait = 0.01
- while time.time() < end and IsProcessAlive(pid):
+ def _CheckProcess():
+ if not IsProcessAlive(pid):
+ return
+
try:
(result_pid, _) = os.waitpid(pid, os.WNOHANG)
- if result_pid > 0:
- break
except OSError:
- pass
- time.sleep(wait)
- # Make wait time longer for next try
- if wait < 0.1:
- wait *= 1.5
+ raise RetryAgain()
+
+ if result_pid > 0:
+ return
+
+ raise RetryAgain()
+
+ try:
+ # Wait up to $timeout seconds
+ Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
+ except RetryTimeout:
+ pass
if IsProcessAlive(pid):
# Kill process if it's still alive
@return: full path to the object if found, None otherwise
"""
+ # validate the filename mask
+ if constants.EXT_PLUGIN_MASK.match(name) is None:
+ logging.critical("Invalid value passed for external script name: '%s'",
+ name)
+ return None
+
for dir_name in search_path:
+ # FIXME: investigate switch to PathJoin
item_name = os.path.sep.join([dir_name, name])
- if test(item_name):
+ # check the user test and that we're indeed resolving to the given
+ # basename
+ if test(item_name) and os.path.basename(item_name) == name:
return item_name
return None
return float(seconds) + (float(microseconds) * 0.000001)
-def GetNodeDaemonPort():
- """Get the node daemon port for this cluster.
+def GetDaemonPort(daemon_name):
+ """Get the daemon port for this cluster.
Note that this routine does not read a ganeti-specific file, but
instead uses C{socket.getservbyname} to allow pre-customization of
this parameter outside of Ganeti.
+ @type daemon_name: string
+ @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
@rtype: int
"""
+ if daemon_name not in constants.DAEMONS_PORTS:
+ raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)
+
+ (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
try:
- port = socket.getservbyname("ganeti-noded", "tcp")
+ port = socket.getservbyname(daemon_name, proto)
except socket.error:
- port = constants.DEFAULT_NODED_PORT
+ port = default_port
return port
-def SetupLogging(logfile, debug=False, stderr_logging=False, program=""):
+def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
+ multithreaded=False, syslog=constants.SYSLOG_USAGE):
"""Configures the logging module.
@type logfile: str
@param logfile: the filename to which we should log
- @type debug: boolean
- @param debug: whether to enable debug messages too or
+ @type debug: integer
+ @param debug: if greater than zero, enable debug messages, otherwise
only those at C{INFO} and above level
@type stderr_logging: boolean
@param stderr_logging: whether we should also log to the standard error
@type program: str
@param program: the name under which we should log messages
+ @type multithreaded: boolean
+ @param multithreaded: if True, will add the thread name to the log file
+ @type syslog: string
+ @param syslog: one of 'no', 'yes', 'only':
+ - if no, syslog is not used
+ - if yes, syslog is used (in addition to file-logging)
+ - if only, only syslog is used
@raise EnvironmentError: if we can't open the log file and
- stderr logging is disabled
+ syslog/stderr logging is disabled
"""
- fmt = "%(asctime)s: " + program + " "
+ fmt = "%(asctime)s: " + program + " pid=%(process)d"
+ sft = program + "[%(process)d]:"
+ if multithreaded:
+ fmt += "/%(threadName)s"
+ sft += " (%(threadName)s)"
if debug:
- fmt += ("pid=%(process)d/%(threadName)s %(levelname)s"
- " %(module)s:%(lineno)s %(message)s")
- else:
- fmt += "pid=%(process)d %(levelname)s %(message)s"
+ fmt += " %(module)s:%(lineno)s"
+ # no debug info for syslog loggers
+ fmt += " %(levelname)s %(message)s"
+ # yes, we do want the textual level, as remote syslog will probably
+ # lose the error level, and it's easier to grep for it
+ sft += " %(levelname)s %(message)s"
formatter = logging.Formatter(fmt)
+ sys_fmt = logging.Formatter(sft)
root_logger = logging.getLogger("")
root_logger.setLevel(logging.NOTSET)
# Remove all previously setup handlers
for handler in root_logger.handlers:
+ handler.close()
root_logger.removeHandler(handler)
if stderr_logging:
stderr_handler.setLevel(logging.CRITICAL)
root_logger.addHandler(stderr_handler)
- # this can fail, if the logging directories are not setup or we have
- # a permisssion problem; in this case, it's best to log but ignore
- # the error if stderr_logging is True, and if false we re-raise the
- # exception since otherwise we could run but without any logs at all
+ if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
+ facility = logging.handlers.SysLogHandler.LOG_DAEMON
+ syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
+ facility)
+ syslog_handler.setFormatter(sys_fmt)
+ # Never enable debug over syslog
+ syslog_handler.setLevel(logging.INFO)
+ root_logger.addHandler(syslog_handler)
+
+ if syslog != constants.SYSLOG_ONLY:
+ # this can fail, if the logging directories are not setup or we have
+ # a permisssion problem; in this case, it's best to log but ignore
+ # the error if stderr_logging is True, and if false we re-raise the
+ # exception since otherwise we could run but without any logs at all
+ try:
+ logfile_handler = logging.FileHandler(logfile)
+ logfile_handler.setFormatter(formatter)
+ if debug:
+ logfile_handler.setLevel(logging.DEBUG)
+ else:
+ logfile_handler.setLevel(logging.INFO)
+ root_logger.addHandler(logfile_handler)
+ except EnvironmentError:
+ if stderr_logging or syslog == constants.SYSLOG_YES:
+ logging.exception("Failed to enable logging to file '%s'", logfile)
+ else:
+ # we need to re-raise the exception
+ raise
+
+
+def IsNormAbsPath(path):
+  """Tell whether a path is both absolute and already normalized.
+
+  Rejecting non-normalized paths means that inputs such as
+  /dir/../../other/path are not considered valid.
+
+  @param path: the path to check
+  @rtype: boolean
+  @return: True if the path equals its normalized form and is absolute
+
+  """
+  if os.path.normpath(path) != path:
+    return False
+  return os.path.isabs(path)
+
+
+def PathJoin(*args):
+  """Safe-join a list of path components.
+
+  Requirements:
+    - the first argument must be an absolute and normalized path
+    - no component in the path must have backtracking (e.g. /../),
+      since we check for normalization at the end
+    - the joined path must be the first argument itself or lie below it
+
+  @param args: the path components to be joined
+  @rtype: string
+  @return: the joined path
+  @raise ValueError: for invalid paths
+
+  """
+  # ensure we're having at least one path passed in
+  assert args
+  # ensure the first component is an absolute and normalized path name
+  root = args[0]
+  if not IsNormAbsPath(root):
+    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
+  result = os.path.join(*args)
+  # ensure that the whole path is normalized
+  if not IsNormAbsPath(result):
+    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
+  # check that we're still under the original prefix; the comparison is
+  # done on whole path components, as a character-wise common prefix
+  # would accept "/a/bc" as being below "/a/b"
+  if root.endswith(os.sep):
+    subtree = root
+  else:
+    subtree = root + os.sep
+  if not (result == root or result.startswith(subtree)):
+    raise ValueError("Error: path joining resulted in different prefix"
+                     " (%s is not below %s)" % (result, root))
+  return result
+
+
+def TailFile(fname, lines=20):
+ """Return the last lines from a file.
+
+ @note: this function will only read and parse the last 4KB of
+ the file; if the lines are very long, it could be that less
+ than the requested number of lines are returned
+
+ @param fname: the file name
+ @type lines: int
+ @param lines: the (maximum) number of lines to return
+ @rtype: list
+ @return: the trailing rows of the data read, without line terminators
+
+ """
+ fd = open(fname, "r")
try:
- logfile_handler = logging.FileHandler(logfile)
- logfile_handler.setFormatter(formatter)
- if debug:
- logfile_handler.setLevel(logging.DEBUG)
+ # whence=2 seeks relative to the end of the file, so tell() gives its size
+ fd.seek(0, 2)
+ pos = fd.tell()
+ # step back at most 4096 bytes; this may land in the middle of a line,
+ # in which case the first row read is only the tail of that line
+ pos = max(0, pos-4096)
+ fd.seek(pos, 0)
+ raw_data = fd.read()
+ finally:
+ fd.close()
+
+ rows = raw_data.splitlines()
+ # NOTE: with lines=0 the slice is rows[-0:], i.e. every row that was read
+ return rows[-lines:]
+
+
+def _ParseAsn1Generalizedtime(value):
+  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
+
+  @type value: string
+  @param value: ASN1 GENERALIZEDTIME timestamp, ending either in "Z"
+    or in a "+HHMM"/"-HHMM" offset
+  @rtype: int
+  @return: the timestamp as seconds since the Unix epoch
+  @raise ValueError: if the timezone is missing or the date/time part
+    cannot be parsed
+
+  """
+  m = re.match(r"^(\d+)([-+])(\d\d)(\d\d)$", value)
+  if m:
+    # We have an offset; its sign applies to hours and minutes alike
+    asn1time = m.group(1)
+    utcoffset = (60 * int(m.group(3))) + int(m.group(4))
+    if m.group(2) == "-":
+      utcoffset = -utcoffset
+  else:
+    if not value.endswith("Z"):
+      raise ValueError("Missing timezone")
+    asn1time = value[:-1]
+    utcoffset = 0
+
+  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
+
+  # only the first six fields (year to second) are date/time components;
+  # the seventh is the weekday and must not be passed as microseconds
+  tt = (datetime.datetime(*(parsed[:6])) -
+        datetime.timedelta(minutes=utcoffset))
+
+  return calendar.timegm(tt.utctimetuple())
+
+
+def GetX509CertValidity(cert):
+ """Returns the validity period of the certificate.
+
+ @type cert: OpenSSL.crypto.X509
+ @param cert: X509 certificate object
+ @rtype: tuple
+ @return: (not_before, not_after); each element is the value returned by
+ L{_ParseAsn1Generalizedtime} for the respective field, or None if the
+ certificate object lacks the getter or the getter returns None
+
+ """
+ # The get_notBefore and get_notAfter functions are only supported in
+ # pyOpenSSL 0.7 and above.
+ try:
+ get_notbefore_fn = cert.get_notBefore
+ except AttributeError:
+ not_before = None
+ else:
+ not_before_asn1 = get_notbefore_fn()
+
+ if not_before_asn1 is None:
+ not_before = None
else:
- logfile_handler.setLevel(logging.INFO)
- root_logger.addHandler(logfile_handler)
- except EnvironmentError, err:
- if stderr_logging:
- logging.exception("Failed to enable logging to file '%s'", logfile)
+ not_before = _ParseAsn1Generalizedtime(not_before_asn1)
+
+ # same procedure for the end of the validity period
+ try:
+ get_notafter_fn = cert.get_notAfter
+ except AttributeError:
+ not_after = None
+ else:
+ not_after_asn1 = get_notafter_fn()
+
+ if not_after_asn1 is None:
+ not_after = None
else:
- # we need to re-raise the exception
- raise
+ not_after = _ParseAsn1Generalizedtime(not_after_asn1)
+
+ return (not_before, not_after)
+
+
+def SafeEncode(text):
+  """Return a 'safe' version of a source string.
+
+  This function mangles the input string and returns a version that
+  should be safe to display/encode as ASCII. To this end, we first
+  convert it to ASCII using the 'backslashreplace' encoding which
+  should get rid of any non-ASCII chars, and then we process it
+  through a loop copied from the string repr sources in the python; we
+  don't use string_escape anymore since that escapes single quotes and
+  backslashes too, and that is too much; and that escaping is not
+  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
+
+  @type text: str or unicode
+  @param text: input data
+  @rtype: str
+  @return: a safe version of text
+
+  """
+  if isinstance(text, unicode):
+    # only if unicode; if str already, we handle it below
+    text = text.encode('ascii', 'backslashreplace')
+  pieces = []
+  for char in text:
+    c = ord(char)
+    if char == '\t':
+      pieces.append(r'\t')
+    elif char == '\n':
+      pieces.append(r'\n')
+    elif char == '\r':
+      # carriage return is shown as backslash + "r"
+      pieces.append(r'\r')
+    elif c < 32 or c >= 127: # non-printable
+      pieces.append("\\x%02x" % (c & 0xff))
+    else:
+      pieces.append(char)
+  return "".join(pieces)
+
+
+def UnescapeAndSplit(text, sep=","):
+  r"""Split and unescape a string based on a given separator.
+
+  This function splits a string based on a separator where the
+  separator itself can be escaped in order to be part of an
+  element. The escaping rules are (assuming comma being the
+  separator):
+    - a plain , separates the elements
+    - a sequence \\, (double backslash plus comma) is handled as a
+      backslash plus a separator comma
+    - a sequence \, (backslash plus comma) is handled as a
+      non-separator comma
+
+  A trailing unpaired backslash, with nothing left to escape, is kept
+  unchanged.
+
+  @type text: string
+  @param text: the string to split
+  @type sep: string
+  @param sep: the separator
+  @rtype: list
+  @return: a list of strings
+
+  """
+  rlist = []
+  # element ending in an escaped separator, waiting for its continuation
+  pending = None
+  # we split the list by sep (with no escaping at this stage) and then
+  # re-join every piece that ended with an odd number of backslashes
+  # with the piece following it; the merged piece is checked again, so
+  # several escaped separators in a row are handled as well
+  for elem in text.split(sep):
+    if pending is not None:
+      # here the backslashes remain (all), and will be reduced in
+      # the next step
+      elem = pending + sep + elem
+      pending = None
+    num_b = len(elem) - len(elem.rstrip("\\"))
+    if num_b % 2 == 1:
+      pending = elem
+    else:
+      rlist.append(elem)
+  if pending is not None:
+    # the text ended in an unpaired backslash; nothing to merge it with
+    rlist.append(pending)
+  # finally, replace backslash-something with something
+  return [re.sub(r"\\(.)", r"\1", v) for v in rlist]
+
+
+def CommaJoin(names):
+  """Join identifiers into one comma-separated string.
+
+  @param names: set, list or tuple
+  @return: the string forms of all items, separated by ", "
+
+  """
+  formatted = map(str, names)
+  return ", ".join(formatted)
+
+
+def BytesToMebibyte(value):
+  """Convert a byte count into (rounded) mebibytes.
+
+  @type value: int
+  @param value: Value in bytes
+  @rtype: int
+  @return: Value in mebibytes, rounded to the nearest integer
+
+  """
+  bytes_per_mebibyte = 1024.0 * 1024.0
+  mebibytes = round(value / bytes_per_mebibyte, 0)
+  return int(mebibytes)
+
+
+def CalculateDirectorySize(path):
+  """Sum up the sizes of all files below a directory.
+
+  File sizes are taken with C{os.lstat}.
+
+  @type path: string
+  @param path: Path to directory
+  @rtype: int
+  @return: Size in mebibytes
+
+  """
+  total = 0
+
+  for (dirpath, _, filenames) in os.walk(path):
+    total += sum([os.lstat(PathJoin(dirpath, name)).st_size
+                  for name in filenames])
+
+  return BytesToMebibyte(total)
+
+
+def GetFilesystemStats(path):
+  """Returns the total and free space on a filesystem.
+
+  @type path: string
+  @param path: Path on filesystem to be examined
+  @rtype: tuple
+  @return: tuple of (Total space, Free space) in mebibytes
+
+  """
+  fsinfo = os.statvfs(path)
+  fragment_size = fsinfo.f_frsize
+
+  total = BytesToMebibyte(fsinfo.f_blocks * fragment_size)
+  free = BytesToMebibyte(fsinfo.f_bavail * fragment_size)
+  return (total, free)
+
+
+def RunInSeparateProcess(fn, *args):
+ """Runs a function in a separate process.
+
+ The process is forked; the child calls the function and passes its
+ result back to the parent through its exit code.
+
+ Note: Only boolean return values are supported.
+
+ @type fn: callable
+ @param fn: Function to be called
+ @param args: positional arguments passed to C{fn}
+ @rtype: bool
+ @return: Function's result
+ @raise errors.GenericError: if the child was terminated by a signal or
+ exited with a code other than 0 or 1
+
+ """
+ pid = os.fork()
+ if pid == 0:
+ # Child process
+ try:
+ # In case the function uses temporary files
+ ResetTempfileModule()
+
+ # Call function
+ result = int(bool(fn(*args)))
+ assert result in (0, 1)
+ except: # pylint: disable-msg=W0702
+ logging.exception("Error while calling function in separate process")
+ # 0 and 1 are reserved for the return value
+ result = 33
+
+ # the child always ends here, via os._exit
+ os._exit(result) # pylint: disable-msg=W0212
+
+ # Parent process
+
+ # Avoid zombies and check exit code
+ (_, status) = os.waitpid(pid, 0)
+
+ if os.WIFSIGNALED(status):
+ exitcode = None
+ signum = os.WTERMSIG(status)
+ else:
+ exitcode = os.WEXITSTATUS(status)
+ signum = None
+
+ # anything but a plain exit with 0 or 1 means the call itself failed
+ if not (exitcode in (0, 1) and signum is None):
+ raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
+ (exitcode, signum))
+
+ return bool(exitcode)
def LockedMethod(fn):
logging.debug(*args, **kwargs)
def wrapper(self, *args, **kwargs):
+ # pylint: disable-msg=W0212
assert hasattr(self, '_lock')
lock = self._lock
_LockDebug("Waiting for %s", lock)
raise
+def FormatTime(val):
+  """Render a timestamp as a local date and time string.
+
+  @type val: float or None
+  @param val: the timestamp as returned by time.time()
+  @return: a string value or N/A if we don't have a valid timestamp
+
+  """
+  if isinstance(val, (int, float)):
+    # these two codes works on Linux, but they are not guaranteed on all
+    # platforms
+    return time.strftime("%F %T", time.localtime(val))
+  # None and any other non-numeric value end up here
+  return "N/A"
+
+
+def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
+  """Reads the watcher pause file.
+
+  @type filename: string
+  @param filename: Path to watcher pause file
+  @type now: None, float or int
+  @param now: Current time as Unix timestamp
+  @type remove_after: int
+  @param remove_after: Remove watcher pause file after specified amount of
+    seconds past the pause end time
+  @return: the pause end time read from the file, or None if the file
+    does not exist, holds an invalid value or the end time has passed
+
+  """
+  if now is None:
+    now = time.time()
+
+  try:
+    content = ReadFile(filename)
+  except IOError, err:
+    if err.errno != errno.ENOENT:
+      raise
+    # no pause file at all
+    return None
+
+  if content is None:
+    return None
+
+  try:
+    until = int(content)
+  except ValueError:
+    logging.warning(("Watcher pause file (%s) contains invalid value,"
+                     " removing it"), filename)
+    RemoveFile(filename)
+    return None
+
+  # Remove file if it's outdated
+  if now > (until + remove_after):
+    RemoveFile(filename)
+    return None
+
+  # pause is over, but the file is not yet old enough to be removed
+  if now > until:
+    return None
+
+  return until
+
+
+class RetryTimeout(Exception):
+ """Retry loop timed out.
+
+ Raised by L{Retry} when the time computed from its C{timeout} argument
+ has passed and the called function still asks to be retried.
+
+ """
+
+
+class RetryAgain(Exception):
+ """Retry again.
+
+ Raised by the function passed to L{Retry} to request another attempt.
+
+ """
+
+
+class _RetryDelayCalculator(object):
+  """Calculator for increasing delays.
+
+  Every call returns the current delay; the delay for the following call
+  is the current one multiplied by the factor, capped at the limit if
+  one was given.
+
+  """
+  __slots__ = [
+    "_factor",
+    "_limit",
+    "_next",
+    "_start",
+    ]
+
+  def __init__(self, start, factor, limit):
+    """Initializes this class.
+
+    @type start: float
+    @param start: Initial delay
+    @type factor: float
+    @param factor: Factor for delay increase
+    @type limit: float or None
+    @param limit: Upper limit for delay or None for no limit
+
+    """
+    assert start > 0.0
+    assert factor >= 1.0
+    assert limit is None or limit >= 0.0
+
+    self._start = start
+    self._factor = factor
+    self._limit = limit
+
+    self._next = start
+
+  def __call__(self):
+    """Returns current delay and calculates the next one.
+
+    @rtype: float
+    @return: the delay to use now
+
+    """
+    current = self._next
+
+    # Update for next run. Without a limit the delay simply keeps growing;
+    # None must never be handed to min(), as in Python 2 it compares lower
+    # than any number and would replace the delay by None.
+    if self._limit is None:
+      self._next = self._next * self._factor
+    elif self._next < self._limit:
+      self._next = min(self._limit, self._next * self._factor)
+
+    return current
+
+
+#: Special delay to specify whole remaining timeout
+RETRY_REMAINING_TIME = object()
+
+
+def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
+          _time_fn=time.time):
+  """Call a function repeatedly until it succeeds.
+
+  C{fn} is called again and again for as long as it raises L{RetryAgain}.
+  Between two calls a pause, specified by C{delay}, is made. Once more
+  than C{timeout} seconds have passed in total, L{RetryTimeout} is raised.
+
+  C{delay} can be one of the following:
+    - callable returning the delay length as a float
+    - Tuple of (start, factor, limit)
+    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
+      useful when overriding C{wait_fn} to wait for an external event)
+    - A static delay as a number (int or float)
+
+  @type fn: callable
+  @param fn: Function to be called
+  @param delay: Either a callable (returning the delay), a tuple of (start,
+                factor, limit) (see L{_RetryDelayCalculator}),
+                L{RETRY_REMAINING_TIME} or a number (int or float)
+  @type timeout: float
+  @param timeout: Total timeout
+  @param args: positional arguments for C{fn}; None means no arguments
+  @type wait_fn: callable
+  @param wait_fn: Waiting function
+  @return: Return value of function
+  @raise RetryTimeout: if the timeout expired before C{fn} succeeded
+
+  """
+  assert callable(fn)
+  assert callable(wait_fn)
+  assert callable(_time_fn)
+
+  if args is None:
+    args = []
+
+  deadline = _time_fn() + timeout
+
+  if delay is RETRY_REMAINING_TIME:
+    # Always use the remaining time
+    next_delay = None
+  elif callable(delay):
+    # External function to calculate delay
+    next_delay = delay
+  elif isinstance(delay, (tuple, list)):
+    # Increasing delay with optional upper boundary
+    (start, factor, limit) = delay
+    next_delay = _RetryDelayCalculator(start, factor, limit)
+  else:
+    # Static delay
+    next_delay = lambda: delay
+
+  assert next_delay is None or callable(next_delay)
+
+  while True:
+    try:
+      # pylint: disable-msg=W0142
+      return fn(*args)
+    except RetryAgain:
+      pass
+
+    time_left = deadline - _time_fn()
+
+    if time_left < 0.0:
+      raise RetryTimeout()
+
+    assert time_left >= 0.0
+
+    if next_delay is None:
+      wait_fn(time_left)
+      continue
+
+    pause = next_delay()
+    if pause > 0.0:
+      wait_fn(pause)
+
+
class FileLock(object):
"""Utility class for file locks.
"""
- def __init__(self, filename):
+ def __init__(self, fd, filename):
"""Constructor for FileLock.
- This will open the file denoted by the I{filename} argument.
-
+ @type fd: file
+ @param fd: File object
@type filename: str
- @param filename: path to the file to be locked
+ @param filename: Path of the file opened at I{fd}
"""
+ self.fd = fd
self.filename = filename
- self.fd = open(self.filename, "w")
+
+ @classmethod
+ def Open(cls, filename):
+ """Creates and opens a file to be used as a file-based lock.
+
+ @type filename: string
+ @param filename: path to the file to be locked
+ @return: a new instance of C{cls} holding the opened file object and
+ the file name
+
+ """
+ # Using "os.open" is necessary to allow both opening existing file
+ # read/write and creating if not existing. Vanilla "open" will truncate an
+ # existing file -or- allow creating if not existing.
+ return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
+ filename)
def __del__(self):
self.Close()
"""Close the file and release the lock.
"""
- if self.fd:
+ if hasattr(self, "fd") and self.fd:
self.fd.close()
self.fd = None
assert self.fd, "Lock was closed"
assert timeout is None or timeout >= 0, \
"If specified, timeout must be positive"
+ assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"
- if timeout is not None:
- flag |= fcntl.LOCK_NB
- timeout_end = time.time() + timeout
-
- # Blocking doesn't have effect with timeout
- elif not blocking:
+ # When a timeout is used, LOCK_NB must always be set
+ if not (timeout is None and blocking):
flag |= fcntl.LOCK_NB
- timeout_end = None
- retry = True
- while retry:
+ if timeout is None:
+ self._Lock(self.fd, flag, timeout)
+ else:
try:
- fcntl.flock(self.fd, flag)
- retry = False
- except IOError, err:
- if err.errno in (errno.EAGAIN, ):
- if timeout_end is not None and time.time() < timeout_end:
- # Wait before trying again
- time.sleep(max(0.1, min(1.0, timeout)))
- else:
- raise errors.LockError(errmsg)
- else:
- logging.exception("fcntl.flock failed")
- raise
+ Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
+ args=(self.fd, flag, timeout))
+ except RetryTimeout:
+ raise errors.LockError(errmsg)
+
+ @staticmethod
+ def _Lock(fd, flag, timeout):
+ """Makes a single C{fcntl.flock} call on the given file.
+
+ @param fd: file object passed to C{fcntl.flock}
+ @param flag: flags passed to C{fcntl.flock}
+ @param timeout: if not None, an EAGAIN failure is reported as
+ L{RetryAgain} instead of being re-raised
+ @raise RetryAgain: if a timeout is in use and flock failed with EAGAIN
+
+ """
+ try:
+ fcntl.flock(fd, flag)
+ except IOError, err:
+ if timeout is not None and err.errno == errno.EAGAIN:
+ raise RetryAgain()
+
+ # any other failure is logged and passed on unchanged
+ logging.exception("fcntl.flock failed")
+ raise
def Exclusive(self, blocking=False, timeout=None):
"""Locks the file in exclusive mode.
"Failed to unlock %s" % self.filename)
+def SignalHandled(signums):
+  """Signal Handled decoration.
+
+  This decorator installs a signal handler around each call of the
+  target function. The function must accept a 'signal_handlers' keyword
+  argument, which will contain a dict indexed by signal number, with
+  SignalHandler objects as values.
+
+  The decorator can be safely stacked with itself, to handle multiple
+  signals with different handlers.
+
+  @type signums: list
+  @param signums: signals to intercept
+
+  """
+  def wrap(fn):
+    def sig_function(*args, **kwargs):
+      signal_handlers = kwargs.get('signal_handlers')
+      assert signal_handlers is None or \
+             isinstance(signal_handlers, dict), \
+             "Wrong signal_handlers parameter in original function call"
+      if signal_handlers is None:
+        # no dict handed in by the caller (or an outer decorator): start
+        # a new one and pass it on to the function
+        signal_handlers = {}
+        kwargs['signal_handlers'] = signal_handlers
+      sighandler = SignalHandler(signums)
+      try:
+        for sig in signums:
+          signal_handlers[sig] = sighandler
+        return fn(*args, **kwargs)
+      finally:
+        sighandler.Reset()
+    return sig_function
+  return wrap
+
+
class SignalHandler(object):
"""Generic signal handler class.
@param signum: Single signal number or set of signal numbers
"""
- if isinstance(signum, (int, long)):
- self.signum = set([signum])
- else:
- self.signum = set(signum)
-
+ self.signum = set(signum)
self.called = False
self._previous = {}
"""
self.called = False
- def _HandleSignal(self, signum, frame):
+ # we don't care about arguments, but we leave them named for the future
+ def _HandleSignal(self, signum, frame): # pylint: disable-msg=W0613
"""Actual signal handling function.
"""
@type field: str
@param field: the string to match
- @return: either False or a regular expression match object
+ @return: either None or a regular expression match object
"""
for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
return m
- return False
+ return None
def NonMatching(self, items):
"""Returns the list of fields not matching the current set