X-Git-Url: https://code.grnet.gr/git/ganeti-local/blobdiff_plain/ff5251bc6dd7ef8d62e3397a05209a44039892db..4a34c5cf5664c10a1c06e8865067b429ab0b9c71:/lib/utils.py diff --git a/lib/utils.py b/lib/utils.py index 4de6a1e..aace5ec 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -27,9 +27,7 @@ the command line scripts. """ -import sys import os -import sha import time import subprocess import re @@ -47,6 +45,12 @@ import signal from cStringIO import StringIO +try: + from hashlib import sha1 +except ImportError: + import sha + sha1 = sha.new + from ganeti import errors from ganeti import constants @@ -54,7 +58,6 @@ from ganeti import constants _locksheld = [] _re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$') -debug = False debug_locks = False #: when set to True, L{RunCmd} is disabled @@ -131,7 +134,7 @@ def RunCmd(cmd, env=None, output=None, cwd='/'): directory for the command; the default will be / @rtype: L{RunResult} @return: RunResult instance - @raise erors.ProgrammerError: if we call this when forks are disabled + @raise errors.ProgrammerError: if we call this when forks are disabled """ if no_fork: @@ -151,11 +154,18 @@ def RunCmd(cmd, env=None, output=None, cwd='/'): if env is not None: cmd_env.update(env) - if output is None: - out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd) - else: - status = _RunCmdFile(cmd, cmd_env, shell, output, cwd) - out = err = "" + try: + if output is None: + out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd) + else: + status = _RunCmdFile(cmd, cmd_env, shell, output, cwd) + out = err = "" + except OSError, err: + if err.errno == errno.ENOENT: + raise errors.OpExecError("Can't execute '%s': not found (%s)" % + (strcmd, err)) + else: + raise if status >= 0: exitcode = status @@ -166,6 +176,7 @@ def RunCmd(cmd, env=None, output=None, cwd='/'): return RunResult(exitcode, signal_, out, err, strcmd) + def _RunCmdPipe(cmd, env, via_shell, cwd): """Run a command and return its output. @@ -203,7 +214,18 @@ def _RunCmdPipe(cmd, env, via_shell, cwd): fcntl.fcntl(fd, fcntl.F_SETFL, status | os.O_NONBLOCK) while fdmap: - for fd, event in poller.poll(): + try: + pollresult = poller.poll() + except EnvironmentError, eerr: + if eerr.errno == errno.EINTR: + continue + raise + except select.error, serr: + if serr[0] == errno.EINTR: + continue + raise + + for fd, event in pollresult: if event & select.POLLIN or event & select.POLLPRI: data = fdmap[fd][1].read() # no data from read signifies EOF (the same as POLLHUP) @@ -274,6 +296,32 @@ def RemoveFile(filename): raise +def RenameFile(old, new, mkdir=False, mkdir_mode=0750): + """Renames a file. + + @type old: string + @param old: Original path + @type new: string + @param new: New path + @type mkdir: bool + @param mkdir: Whether to create target directory if it doesn't exist + @type mkdir_mode: int + @param mkdir_mode: Mode for newly created directories + + """ + try: + return os.rename(old, new) + except OSError, err: + # In at least one use case of this function, the job queue, directory + # creation is very rare. Checking for the directory before renaming is not + # as efficient. + if mkdir and err.errno == errno.ENOENT: + # Create directory and try again + os.makedirs(os.path.dirname(new), mkdir_mode) + return os.rename(old, new) + raise + + def _FingerprintFile(filename): """Compute the fingerprint of a file. @@ -292,7 +340,7 @@ def _FingerprintFile(filename): f = open(filename) - fp = sha.sha() + fp = sha1() while True: data = f.read(4096) if not data: @@ -323,37 +371,78 @@ def FingerprintFiles(files): return ret -def CheckDict(target, template, logname=None): - """Ensure a dictionary has a required set of keys. - - For the given dictionaries I{target} and I{template}, ensure - I{target} has all the keys from I{template}. Missing keys are added - with values from template. +def ForceDictType(target, key_types, allowed_values=None): + """Force the values of a dict to have certain types. @type target: dict - @param target: the dictionary to update - @type template: dict - @param template: the dictionary holding the default values - @type logname: str or None - @param logname: if not None, causes the missing keys to be - logged with this name + @param target: the dict to update + @type key_types: dict + @param key_types: dict mapping target dict keys to types + in constants.ENFORCEABLE_TYPES + @type allowed_values: list + @keyword allowed_values: list of specially allowed values """ - missing = [] - for k in template: - if k not in target: - missing.append(k) - target[k] = template[k] + if allowed_values is None: + allowed_values = [] + + if not isinstance(target, dict): + msg = "Expected dictionary, got '%s'" % target + raise errors.TypeEnforcementError(msg) + + for key in target: + if key not in key_types: + msg = "Unknown key '%s'" % key + raise errors.TypeEnforcementError(msg) + + if target[key] in allowed_values: + continue - if missing and logname: - logging.warning('%s missing keys %s', logname, ', '.join(missing)) + ktype = key_types[key] + if ktype not in constants.ENFORCEABLE_TYPES: + msg = "'%s' has non-enforceable type %s" % (key, ktype) + raise errors.ProgrammerError(msg) + + if ktype == constants.VTYPE_STRING: + if not isinstance(target[key], basestring): + if isinstance(target[key], bool) and not target[key]: + target[key] = '' + else: + msg = "'%s' (value %s) is not a valid string" % (key, target[key]) + raise errors.TypeEnforcementError(msg) + elif ktype == constants.VTYPE_BOOL: + if isinstance(target[key], basestring) and target[key]: + if target[key].lower() == constants.VALUE_FALSE: + target[key] = False + elif target[key].lower() == constants.VALUE_TRUE: + target[key] = True + else: + msg = "'%s' (value %s) is not a valid boolean" % (key, target[key]) + raise errors.TypeEnforcementError(msg) + elif target[key]: + target[key] = True + else: + target[key] = False + elif ktype == constants.VTYPE_SIZE: + try: + target[key] = ParseUnit(target[key]) + except errors.UnitParseError, err: + msg = "'%s' (value %s) is not a valid size. error: %s" % \ + (key, target[key], err) + raise errors.TypeEnforcementError(msg) + elif ktype == constants.VTYPE_INT: + try: + target[key] = int(target[key]) + except (ValueError, TypeError): + msg = "'%s' (value %s) is not a valid integer" % (key, target[key]) + raise errors.TypeEnforcementError(msg) def IsProcessAlive(pid): """Check if a given pid exists on the system. - @note: zombie processes treated as not alive, and giving a - pid M{<= 0} causes the function to return False. + @note: zombie status is not handled, so zombie processes + will be returned as alive @type pid: int @param pid: the process ID to check @rtype: boolean @@ -364,22 +453,12 @@ def IsProcessAlive(pid): return False try: - f = open("/proc/%d/status" % pid) - except IOError, err: + os.stat("/proc/%d/status" % pid) + return True + except EnvironmentError, err: if err.errno in (errno.ENOENT, errno.ENOTDIR): return False - - alive = True - try: - data = f.readlines() - if len(data) > 1: - state = data[1].split() - if len(state) > 1 and state[1] == "Z": - alive = False - finally: - f.close() - - return alive + raise def ReadPidFile(pidfile): @@ -388,7 +467,7 @@ def ReadPidFile(pidfile): @type pidfile: string @param pidfile: path to the file containing the pid @rtype: int - @return: The process id, if the file exista and contains a valid PID, + @return: The process id, if the file exists and contains a valid PID, otherwise 0 """ @@ -584,7 +663,7 @@ def TryConvert(fn, val): """ try: nv = fn(val) - except (ValueError, TypeError), err: + except (ValueError, TypeError): nv = val return nv @@ -598,7 +677,7 @@ def IsValidIP(ip): @type ip: str @param ip: the address to be checked @rtype: a regular expression match object - @return: a regular epression match object, or None if the + @return: a regular expression match object, or None if the address is not valid """ @@ -631,7 +710,7 @@ def BuildShellCmd(template, *args): This function will check all arguments in the args list so that they are valid shell parameters (i.e. they don't contain shell - metacharaters). If everything is ok, it will return the result of + metacharacters). If everything is ok, it will return the result of template % args. @type template: str @@ -648,23 +727,40 @@ def BuildShellCmd(template, *args): return template % args -def FormatUnit(value): +def FormatUnit(value, units): """Formats an incoming number of MiB with the appropriate unit. @type value: int @param value: integer representing the value in MiB (1048576) + @type units: char + @param units: the type of formatting we should do: + - 'h' for automatic scaling + - 'm' for MiBs + - 'g' for GiBs + - 't' for TiBs @rtype: str @return: the formatted value (with suffix) """ - if value < 1024: - return "%dM" % round(value, 0) + if units not in ('m', 'g', 't', 'h'): + raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units)) + + suffix = '' + + if units == 'm' or (units == 'h' and value < 1024): + if units == 'h': + suffix = 'M' + return "%d%s" % (round(value, 0), suffix) - elif value < (1024 * 1024): - return "%0.1fG" % round(float(value) / 1024, 1) + elif units == 'g' or (units == 'h' and value < (1024 * 1024)): + if units == 'h': + suffix = 'G' + return "%0.1f%s" % (round(float(value) / 1024, 1), suffix) else: - return "%0.1fT" % round(float(value) / 1024 / 1024, 1) + if units == 'h': + suffix = 'T' + return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix) def ParseUnit(input_string): @@ -675,7 +771,7 @@ def ParseUnit(input_string): is always an int in MiB. """ - m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', input_string) + m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string)) if not m: raise errors.UnitParseError("Invalid format") @@ -787,6 +883,7 @@ def SetEtcHostsEntry(file_name, ip, hostname, aliases): @param aliases: the list of aliases to add for the hostname """ + # FIXME: use WriteFile + fn rather than duplicating its efforts # Ensure aliases are unique aliases = UniqueSequence([hostname] + aliases)[1:] @@ -796,7 +893,6 @@ def SetEtcHostsEntry(file_name, ip, hostname, aliases): try: f = open(file_name, 'r') try: - written = False for line in f: fields = line.split() if fields and not fields[0].startswith('#') and ip == fields[0]: @@ -810,6 +906,7 @@ def SetEtcHostsEntry(file_name, ip, hostname, aliases): out.flush() os.fsync(out) + os.chmod(tmpname, 0644) os.rename(tmpname, file_name) finally: f.close() @@ -843,6 +940,7 @@ def RemoveEtcHostsEntry(file_name, hostname): @param hostname: the hostname to be removed """ + # FIXME: use WriteFile + fn rather than duplicating its efforts fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name)) try: out = os.fdopen(fd, 'w') @@ -864,6 +962,7 @@ def RemoveEtcHostsEntry(file_name, hostname): out.flush() os.fsync(out) + os.chmod(tmpname, 0644) os.rename(tmpname, file_name) finally: f.close() @@ -940,7 +1039,7 @@ def ShellQuoteArgs(args): @type args: list @param args: list of arguments to be quoted @rtype: str - @return: the quoted arguments concatenaned with spaces + @return: the quoted arguments concatenated with spaces """ return ' '.join([ShellQuote(i) for i in args]) @@ -957,7 +1056,7 @@ def TcpPing(target, port, timeout=10, live_port_needed=False, source=None): @type port: int @param port: the port to connect to @type timeout: int - @param timeout: the timeout on the connection attemp + @param timeout: the timeout on the connection attempt @type live_port_needed: boolean @param live_port_needed: whether a closed port will cause the function to return failure, as if there was a timeout @@ -969,12 +1068,12 @@ def TcpPing(target, port, timeout=10, live_port_needed=False, source=None): """ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sucess = False + success = False if source is not None: try: sock.bind((source, 0)) - except socket.error, (errcode, errstring): + except socket.error, (errcode, _): if errcode == errno.EADDRNOTAVAIL: success = False @@ -999,7 +1098,7 @@ def OwnIpAddress(address): address. @type address: string - @param address: the addres to check + @param address: the address to check @rtype: bool @return: True if we own the address @@ -1068,7 +1167,25 @@ def GenerateSecret(): @return: a sha1 hexdigest of a block of 64 random bytes """ - return sha.new(os.urandom(64)).hexdigest() + return sha1(os.urandom(64)).hexdigest() + + +def EnsureDirs(dirs): + """Make required directories, if they don't exist. + + @param dirs: list of tuples (dir_name, dir_mode) + @type dirs: list of (string, integer) + + """ + for dir_name, dir_mode in dirs: + try: + os.mkdir(dir_name, dir_mode) + except EnvironmentError, err: + if err.errno != errno.EEXIST: + raise errors.GenericError("Cannot create needed directory" + " '%s': %s" % (dir_name, err)) + if not os.path.isdir(dir_name): + raise errors.GenericError("%s is not a directory" % dir_name) def ReadFile(file_name, size=None): @@ -1077,7 +1194,7 @@ def ReadFile(file_name, size=None): @type size: None or int @param size: Read at most size bytes @rtype: str - @return: the (possibly partial) conent of the file + @return: the (possibly partial) content of the file """ f = open(file_name, "r") @@ -1104,7 +1221,7 @@ def WriteFile(file_name, fn=None, data=None, mtime/atime of the file. If the function doesn't raise an exception, it has succeeded and the - target file has the new contents. If the file has raised an + target file has the new contents. If the function has raised an exception, an existing target file should be unmodified and the temporary file should be removed. @@ -1113,7 +1230,7 @@ def WriteFile(file_name, fn=None, data=None, @type fn: callable @param fn: content writing function, called with file descriptor as parameter - @type data: sr + @type data: str @param data: contents of the file @type mode: int @param mode: file mode @@ -1136,7 +1253,7 @@ def WriteFile(file_name, fn=None, data=None, @return: None if the 'close' parameter evaluates to True, otherwise the file descriptor - @raise errors.ProgrammerError: if an of the arguments are not valid + @raise errors.ProgrammerError: if any of the arguments are not valid """ if not os.path.isabs(file_name): @@ -1155,6 +1272,7 @@ def WriteFile(file_name, fn=None, data=None, dir_name, base_name = os.path.split(file_name) fd, new_name = tempfile.mkstemp('.new', base_name, dir_name) + do_remove = True # here we need to make sure we remove the temp file, if any error # leaves it in place try: @@ -1175,13 +1293,15 @@ def WriteFile(file_name, fn=None, data=None, os.utime(new_name, (atime, mtime)) if not dry_run: os.rename(new_name, file_name) + do_remove = False finally: if close: os.close(fd) result = None else: result = fd - RemoveFile(new_name) + if do_remove: + RemoveFile(new_name) return result @@ -1216,14 +1336,14 @@ def FirstFree(seq, base=0): def all(seq, pred=bool): "Returns True if pred(x) is True for every element in the iterable" - for elem in itertools.ifilterfalse(pred, seq): + for _ in itertools.ifilterfalse(pred, seq): return False return True def any(seq, pred=bool): "Returns True if pred(x) is True for at least one element in the iterable" - for elem in itertools.ifilter(pred, seq): + for _ in itertools.ifilter(pred, seq): return True return False @@ -1234,7 +1354,7 @@ def UniqueSequence(seq): Element order is preserved. @type seq: sequence - @param seq: the sequence with the source elementes + @param seq: the sequence with the source elements @rtype: list @return: list of unique elements from seq @@ -1246,7 +1366,7 @@ def UniqueSequence(seq): def IsValidMac(mac): """Predicate to check if a MAC address is valid. - Checks wether the supplied MAC address is formally correct, only + Checks whether the supplied MAC address is formally correct, only accepts colon separated format. @type mac: str @@ -1269,28 +1389,42 @@ def TestDelay(duration): """ if duration < 0: - return False + return False, "Invalid sleep duration" time.sleep(duration) - return True + return True, None -def Daemonize(logfile, noclose_fds=None): - """Daemonize the current process. +def _CloseFDNoErr(fd, retries=5): + """Close a file descriptor ignoring errors. - This detaches the current process from the controlling terminal and - runs it in the background as a daemon. + @type fd: int + @param fd: the file descriptor + @type retries: int + @param retries: how many retries to make, in case we get any + other error than EBADF + + """ + try: + os.close(fd) + except OSError, err: + if err.errno != errno.EBADF: + if retries > 0: + _CloseFDNoErr(fd, retries - 1) + # else either it's closed already or we're out of retries, so we + # ignore this and go on + + +def CloseFDs(noclose_fds=None): + """Close file descriptors. + + This closes all file descriptors above 2 (i.e. except + stdin/out/err). - @type logfile: str - @param logfile: the logfile to which we should redirect stdout/stderr @type noclose_fds: list or None @param noclose_fds: if given, it denotes a list of file descriptor that should not be closed - @rtype: int - @returns: the value zero """ - UMASK = 077 - WORKDIR = "/" # Default maximum for the number of available file descriptors. if 'SC_OPEN_MAX' in os.sysconf_names: try: @@ -1301,6 +1435,31 @@ def Daemonize(logfile, noclose_fds=None): MAXFD = 1024 else: MAXFD = 1024 + maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] + if (maxfd == resource.RLIM_INFINITY): + maxfd = MAXFD + + # Iterate through and close all file descriptors (except the standard ones) + for fd in range(3, maxfd): + if noclose_fds and fd in noclose_fds: + continue + _CloseFDNoErr(fd) + + +def Daemonize(logfile): + """Daemonize the current process. + + This detaches the current process from the controlling terminal and + runs it in the background as a daemon. + + @type logfile: str + @param logfile: the logfile to which we should redirect stdout/stderr + @rtype: int + @return: the value zero + + """ + UMASK = 077 + WORKDIR = "/" # this might fail pid = os.fork() @@ -1316,22 +1475,15 @@ def Daemonize(logfile, noclose_fds=None): os._exit(0) # Exit parent (the first child) of the second child. else: os._exit(0) # Exit parent of the first child. - maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] - if (maxfd == resource.RLIM_INFINITY): - maxfd = MAXFD - # Iterate through and close all file descriptors. - for fd in range(0, maxfd): - if noclose_fds and fd in noclose_fds: - continue - try: - os.close(fd) - except OSError: # ERROR, fd wasn't open to begin with (ignored) - pass - os.open(logfile, os.O_RDWR|os.O_CREAT|os.O_APPEND, 0600) - # Duplicate standard input to standard output and standard error. - os.dup2(0, 1) # standard output (1) - os.dup2(0, 2) # standard error (2) + for fd in range(3): + _CloseFDNoErr(fd) + i = os.open("/dev/null", os.O_RDONLY) # stdin + assert i == 0, "Can't close/reopen stdin" + i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout + assert i == 1, "Can't close/reopen stdout" + # Duplicate standard output to standard error. + os.dup2(1, 2) return 0 @@ -1376,7 +1528,6 @@ def RemovePidFile(name): @param name: the daemon name used to derive the pidfile name """ - pid = os.getpid() pidfilename = DaemonPidFileName(name) # TODO: we could check here that the file contains our pid try: @@ -1421,11 +1572,25 @@ def KillProcess(pid, signal_=signal.SIGTERM, timeout=30, _helper(pid, signal_, waitpid) if timeout <= 0: return + + # Wait up to $timeout seconds end = time.time() + timeout + wait = 0.01 while time.time() < end and IsProcessAlive(pid): - time.sleep(0.1) + try: + (result_pid, _) = os.waitpid(pid, os.WNOHANG) + if result_pid > 0: + break + except OSError: + pass + time.sleep(wait) + # Make wait time longer for next try + if wait < 0.1: + wait *= 1.5 + if IsProcessAlive(pid): - _helper(pid, signal.SIGKILL, wait) + # Kill process if it's still alive + _helper(pid, signal.SIGKILL, waitpid) def FindFile(name, search_path, test=os.path.exists): @@ -1532,16 +1697,8 @@ def GetNodeDaemonPort(): return port -def GetNodeDaemonPassword(): - """Get the node password for the cluster. - - @rtype: str - - """ - return ReadFile(constants.CLUSTER_PASSWORD_FILE) - - -def SetupLogging(logfile, debug=False, stderr_logging=False, program=""): +def SetupLogging(logfile, debug=False, stderr_logging=False, program="", + multithreaded=False): """Configures the logging module. @type logfile: str @@ -1553,21 +1710,28 @@ def SetupLogging(logfile, debug=False, stderr_logging=False, program=""): @param stderr_logging: whether we should also log to the standard error @type program: str @param program: the name under which we should log messages + @type multithreaded: boolean + @param multithreaded: if True, will add the thread name to the log file @raise EnvironmentError: if we can't open the log file and stderr logging is disabled """ - fmt = "%(asctime)s: " + program + " " + fmt = "%(asctime)s: " + program + " pid=%(process)d" + if multithreaded: + fmt += "/%(threadName)s" if debug: - fmt += ("pid=%(process)d/%(threadName)s %(levelname)s" - " %(module)s:%(lineno)s %(message)s") - else: - fmt += "pid=%(process)d %(levelname)s %(message)s" + fmt += " %(module)s:%(lineno)s" + fmt += " %(levelname)s %(message)s" formatter = logging.Formatter(fmt) root_logger = logging.getLogger("") root_logger.setLevel(logging.NOTSET) + # Remove all previously setup handlers + for handler in root_logger.handlers: + handler.close() + root_logger.removeHandler(handler) + if stderr_logging: stderr_handler = logging.StreamHandler() stderr_handler.setFormatter(formatter) @@ -1589,13 +1753,93 @@ def SetupLogging(logfile, debug=False, stderr_logging=False, program=""): else: logfile_handler.setLevel(logging.INFO) root_logger.addHandler(logfile_handler) - except EnvironmentError, err: + except EnvironmentError: if stderr_logging: logging.exception("Failed to enable logging to file '%s'", logfile) else: # we need to re-raise the exception raise +def IsNormAbsPath(path): + """Check whether a path is absolute and also normalized + + This avoids things like /dir/../../other/path to be valid. + + """ + return os.path.normpath(path) == path and os.path.isabs(path) + +def TailFile(fname, lines=20): + """Return the last lines from a file. + + @note: this function will only read and parse the last 4KB of + the file; if the lines are very long, it could be that less + than the requested number of lines are returned + + @param fname: the file name + @type lines: int + @param lines: the (maximum) number of lines to return + + """ + fd = open(fname, "r") + try: + fd.seek(0, 2) + pos = fd.tell() + pos = max(0, pos-4096) + fd.seek(pos, 0) + raw_data = fd.read() + finally: + fd.close() + + rows = raw_data.splitlines() + return rows[-lines:] + + +def SafeEncode(text): + """Return a 'safe' version of a source string. + + This function mangles the input string and returns a version that + should be safe to display/encode as ASCII. To this end, we first + convert it to ASCII using the 'backslashreplace' encoding which + should get rid of any non-ASCII chars, and then we process it + through a loop copied from the string repr sources in the python; we + don't use string_escape anymore since that escape single quotes and + backslashes too, and that is too much; and that escaping is not + stable, i.e. string_escape(string_escape(x)) != string_escape(x). + + @type text: str or unicode + @param text: input data + @rtype: str + @return: a safe version of text + + """ + if isinstance(text, unicode): + # only if unicode; if str already, we handle it below + text = text.encode('ascii', 'backslashreplace') + resu = "" + for char in text: + c = ord(char) + if char == '\t': + resu += r'\t' + elif char == '\n': + resu += r'\n' + elif char == '\r': + resu += r'\'r' + elif c < 32 or c >= 127: # non-printable + resu += "\\x%02x" % (c & 0xff) + else: + resu += char + return resu + + +def CommaJoin(names): + """Nicely join a set of identifiers. + + @param names: set, list or tuple + @return: a string with the formatted results + + """ + return ", ".join(["'%s'" % val for val in names]) + def LockedMethod(fn): """Synchronized object access decorator. @@ -1833,3 +2077,45 @@ class SignalHandler(object): # This is not nice and not absolutely atomic, but it appears to be the only # solution in Python -- there are no atomic types. self.called = True + + +class FieldSet(object): + """A simple field set. + + Among the features are: + - checking if a string is among a list of static string or regex objects + - checking if a whole list of string matches + - returning the matching groups from a regex match + + Internally, all fields are held as regular expression objects. + + """ + def __init__(self, *items): + self.items = [re.compile("^%s$" % value) for value in items] + + def Extend(self, other_set): + """Extend the field set with the items from another one""" + self.items.extend(other_set.items) + + def Matches(self, field): + """Checks if a field matches the current set + + @type field: str + @param field: the string to match + @return: either False or a regular expression match object + + """ + for m in itertools.ifilter(None, (val.match(field) for val in self.items)): + return m + return False + + def NonMatching(self, items): + """Returns the list of fields not matching the current set + + @type items: list + @param items: the list of fields to check + @rtype: list + @return: list of non-matching fields + + """ + return [val for val in items if not self.Matches(val)]