Statistics
| Branch: | Tag: | Revision:

root / lib / utils.py @ 31892b4c

History | View | Annotate | Download (63.3 kB)

1
#
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import os
31
import time
32
import subprocess
33
import re
34
import socket
35
import tempfile
36
import shutil
37
import errno
38
import pwd
39
import itertools
40
import select
41
import fcntl
42
import resource
43
import logging
44
import signal
45
import string
46

    
47
from cStringIO import StringIO
48

    
49
try:
50
  from hashlib import sha1
51
except ImportError:
52
  import sha
53
  sha1 = sha.new
54

    
55
from ganeti import errors
56
from ganeti import constants
57

    
58

    
59
_locksheld = []
60
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
61

    
62
debug_locks = False
63

    
64
#: when set to True, L{RunCmd} is disabled
65
no_fork = False
66

    
67
_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
68

    
69

    
70
class RunResult(object):
71
  """Holds the result of running external programs.
72

73
  @type exit_code: int
74
  @ivar exit_code: the exit code of the program, or None (if the program
75
      didn't exit())
76
  @type signal: int or None
77
  @ivar signal: the signal that caused the program to finish, or None
78
      (if the program wasn't terminated by a signal)
79
  @type stdout: str
80
  @ivar stdout: the standard output of the program
81
  @type stderr: str
82
  @ivar stderr: the standard error of the program
83
  @type failed: boolean
84
  @ivar failed: True in case the program was
85
      terminated by a signal or exited with a non-zero exit code
86
  @ivar fail_reason: a string detailing the termination reason
87

88
  """
89
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
90
               "failed", "fail_reason", "cmd"]
91

    
92

    
93
  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
94
    self.cmd = cmd
95
    self.exit_code = exit_code
96
    self.signal = signal_
97
    self.stdout = stdout
98
    self.stderr = stderr
99
    self.failed = (signal_ is not None or exit_code != 0)
100

    
101
    if self.signal is not None:
102
      self.fail_reason = "terminated by signal %s" % self.signal
103
    elif self.exit_code is not None:
104
      self.fail_reason = "exited with exit code %s" % self.exit_code
105
    else:
106
      self.fail_reason = "unable to determine termination reason"
107

    
108
    if self.failed:
109
      logging.debug("Command '%s' failed (%s); output: %s",
110
                    self.cmd, self.fail_reason, self.output)
111

    
112
  def _GetOutput(self):
113
    """Returns the combined stdout and stderr for easier usage.
114

115
    """
116
    return self.stdout + self.stderr
117

    
118
  output = property(_GetOutput, None, None, "Return full output")
119

    
120

    
121
def RunCmd(cmd, env=None, output=None, cwd='/'):
122
  """Execute a (shell) command.
123

124
  The command should not read from its standard input, as it will be
125
  closed.
126

127
  @type  cmd: string or list
128
  @param cmd: Command to run
129
  @type env: dict
130
  @param env: Additional environment
131
  @type output: str
132
  @param output: if desired, the output of the command can be
133
      saved in a file instead of the RunResult instance; this
134
      parameter denotes the file name (if not None)
135
  @type cwd: string
136
  @param cwd: if specified, will be used as the working
137
      directory for the command; the default will be /
138
  @rtype: L{RunResult}
139
  @return: RunResult instance
140
  @raise errors.ProgrammerError: if we call this when forks are disabled
141

142
  """
143
  if no_fork:
144
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")
145

    
146
  if isinstance(cmd, list):
147
    cmd = [str(val) for val in cmd]
148
    strcmd = " ".join(cmd)
149
    shell = False
150
  else:
151
    strcmd = cmd
152
    shell = True
153
  logging.debug("RunCmd '%s'", strcmd)
154

    
155
  cmd_env = os.environ.copy()
156
  cmd_env["LC_ALL"] = "C"
157
  if env is not None:
158
    cmd_env.update(env)
159

    
160
  try:
161
    if output is None:
162
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
163
    else:
164
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
165
      out = err = ""
166
  except OSError, err:
167
    if err.errno == errno.ENOENT:
168
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
169
                               (strcmd, err))
170
    else:
171
      raise
172

    
173
  if status >= 0:
174
    exitcode = status
175
    signal_ = None
176
  else:
177
    exitcode = None
178
    signal_ = -status
179

    
180
  return RunResult(exitcode, signal_, out, err, strcmd)
181

    
182

    
183
def _RunCmdPipe(cmd, env, via_shell, cwd):
184
  """Run a command and return its output.
185

186
  @type  cmd: string or list
187
  @param cmd: Command to run
188
  @type env: dict
189
  @param env: The environment to use
190
  @type via_shell: bool
191
  @param via_shell: if we should run via the shell
192
  @type cwd: string
193
  @param cwd: the working directory for the program
194
  @rtype: tuple
195
  @return: (out, err, status)
196

197
  """
198
  poller = select.poll()
199
  child = subprocess.Popen(cmd, shell=via_shell,
200
                           stderr=subprocess.PIPE,
201
                           stdout=subprocess.PIPE,
202
                           stdin=subprocess.PIPE,
203
                           close_fds=True, env=env,
204
                           cwd=cwd)
205

    
206
  child.stdin.close()
207
  poller.register(child.stdout, select.POLLIN)
208
  poller.register(child.stderr, select.POLLIN)
209
  out = StringIO()
210
  err = StringIO()
211
  fdmap = {
212
    child.stdout.fileno(): (out, child.stdout),
213
    child.stderr.fileno(): (err, child.stderr),
214
    }
215
  for fd in fdmap:
216
    status = fcntl.fcntl(fd, fcntl.F_GETFL)
217
    fcntl.fcntl(fd, fcntl.F_SETFL, status | os.O_NONBLOCK)
218

    
219
  while fdmap:
220
    try:
221
      pollresult = poller.poll()
222
    except EnvironmentError, eerr:
223
      if eerr.errno == errno.EINTR:
224
        continue
225
      raise
226
    except select.error, serr:
227
      if serr[0] == errno.EINTR:
228
        continue
229
      raise
230

    
231
    for fd, event in pollresult:
232
      if event & select.POLLIN or event & select.POLLPRI:
233
        data = fdmap[fd][1].read()
234
        # no data from read signifies EOF (the same as POLLHUP)
235
        if not data:
236
          poller.unregister(fd)
237
          del fdmap[fd]
238
          continue
239
        fdmap[fd][0].write(data)
240
      if (event & select.POLLNVAL or event & select.POLLHUP or
241
          event & select.POLLERR):
242
        poller.unregister(fd)
243
        del fdmap[fd]
244

    
245
  out = out.getvalue()
246
  err = err.getvalue()
247

    
248
  status = child.wait()
249
  return out, err, status
250

    
251

    
252
def _RunCmdFile(cmd, env, via_shell, output, cwd):
253
  """Run a command and save its output to a file.
254

255
  @type  cmd: string or list
256
  @param cmd: Command to run
257
  @type env: dict
258
  @param env: The environment to use
259
  @type via_shell: bool
260
  @param via_shell: if we should run via the shell
261
  @type output: str
262
  @param output: the filename in which to save the output
263
  @type cwd: string
264
  @param cwd: the working directory for the program
265
  @rtype: int
266
  @return: the exit status
267

268
  """
269
  fh = open(output, "a")
270
  try:
271
    child = subprocess.Popen(cmd, shell=via_shell,
272
                             stderr=subprocess.STDOUT,
273
                             stdout=fh,
274
                             stdin=subprocess.PIPE,
275
                             close_fds=True, env=env,
276
                             cwd=cwd)
277

    
278
    child.stdin.close()
279
    status = child.wait()
280
  finally:
281
    fh.close()
282
  return status
283

    
284

    
285
def RemoveFile(filename):
286
  """Remove a file ignoring some errors.
287

288
  Remove a file, ignoring non-existing ones or directories. Other
289
  errors are passed.
290

291
  @type filename: str
292
  @param filename: the file to be removed
293

294
  """
295
  try:
296
    os.unlink(filename)
297
  except OSError, err:
298
    if err.errno not in (errno.ENOENT, errno.EISDIR):
299
      raise
300

    
301

    
302
def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
303
  """Renames a file.
304

305
  @type old: string
306
  @param old: Original path
307
  @type new: string
308
  @param new: New path
309
  @type mkdir: bool
310
  @param mkdir: Whether to create target directory if it doesn't exist
311
  @type mkdir_mode: int
312
  @param mkdir_mode: Mode for newly created directories
313

314
  """
315
  try:
316
    return os.rename(old, new)
317
  except OSError, err:
318
    # In at least one use case of this function, the job queue, directory
319
    # creation is very rare. Checking for the directory before renaming is not
320
    # as efficient.
321
    if mkdir and err.errno == errno.ENOENT:
322
      # Create directory and try again
323
      os.makedirs(os.path.dirname(new), mkdir_mode)
324
      return os.rename(old, new)
325
    raise
326

    
327

    
328
def _FingerprintFile(filename):
329
  """Compute the fingerprint of a file.
330

331
  If the file does not exist, a None will be returned
332
  instead.
333

334
  @type filename: str
335
  @param filename: the filename to checksum
336
  @rtype: str
337
  @return: the hex digest of the sha checksum of the contents
338
      of the file
339

340
  """
341
  if not (os.path.exists(filename) and os.path.isfile(filename)):
342
    return None
343

    
344
  f = open(filename)
345

    
346
  fp = sha1()
347
  while True:
348
    data = f.read(4096)
349
    if not data:
350
      break
351

    
352
    fp.update(data)
353

    
354
  return fp.hexdigest()
355

    
356

    
357
def FingerprintFiles(files):
358
  """Compute fingerprints for a list of files.
359

360
  @type files: list
361
  @param files: the list of filename to fingerprint
362
  @rtype: dict
363
  @return: a dictionary filename: fingerprint, holding only
364
      existing files
365

366
  """
367
  ret = {}
368

    
369
  for filename in files:
370
    cksum = _FingerprintFile(filename)
371
    if cksum:
372
      ret[filename] = cksum
373

    
374
  return ret
375

    
376

    
377
def ForceDictType(target, key_types, allowed_values=None):
378
  """Force the values of a dict to have certain types.
379

380
  @type target: dict
381
  @param target: the dict to update
382
  @type key_types: dict
383
  @param key_types: dict mapping target dict keys to types
384
                    in constants.ENFORCEABLE_TYPES
385
  @type allowed_values: list
386
  @keyword allowed_values: list of specially allowed values
387

388
  """
389
  if allowed_values is None:
390
    allowed_values = []
391

    
392
  if not isinstance(target, dict):
393
    msg = "Expected dictionary, got '%s'" % target
394
    raise errors.TypeEnforcementError(msg)
395

    
396
  for key in target:
397
    if key not in key_types:
398
      msg = "Unknown key '%s'" % key
399
      raise errors.TypeEnforcementError(msg)
400

    
401
    if target[key] in allowed_values:
402
      continue
403

    
404
    ktype = key_types[key]
405
    if ktype not in constants.ENFORCEABLE_TYPES:
406
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
407
      raise errors.ProgrammerError(msg)
408

    
409
    if ktype == constants.VTYPE_STRING:
410
      if not isinstance(target[key], basestring):
411
        if isinstance(target[key], bool) and not target[key]:
412
          target[key] = ''
413
        else:
414
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
415
          raise errors.TypeEnforcementError(msg)
416
    elif ktype == constants.VTYPE_BOOL:
417
      if isinstance(target[key], basestring) and target[key]:
418
        if target[key].lower() == constants.VALUE_FALSE:
419
          target[key] = False
420
        elif target[key].lower() == constants.VALUE_TRUE:
421
          target[key] = True
422
        else:
423
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
424
          raise errors.TypeEnforcementError(msg)
425
      elif target[key]:
426
        target[key] = True
427
      else:
428
        target[key] = False
429
    elif ktype == constants.VTYPE_SIZE:
430
      try:
431
        target[key] = ParseUnit(target[key])
432
      except errors.UnitParseError, err:
433
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
434
              (key, target[key], err)
435
        raise errors.TypeEnforcementError(msg)
436
    elif ktype == constants.VTYPE_INT:
437
      try:
438
        target[key] = int(target[key])
439
      except (ValueError, TypeError):
440
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
441
        raise errors.TypeEnforcementError(msg)
442

    
443

    
444
def IsProcessAlive(pid):
445
  """Check if a given pid exists on the system.
446

447
  @note: zombie status is not handled, so zombie processes
448
      will be returned as alive
449
  @type pid: int
450
  @param pid: the process ID to check
451
  @rtype: boolean
452
  @return: True if the process exists
453

454
  """
455
  if pid <= 0:
456
    return False
457

    
458
  try:
459
    os.stat("/proc/%d/status" % pid)
460
    return True
461
  except EnvironmentError, err:
462
    if err.errno in (errno.ENOENT, errno.ENOTDIR):
463
      return False
464
    raise
465

    
466

    
467
def ReadPidFile(pidfile):
468
  """Read a pid from a file.
469

470
  @type  pidfile: string
471
  @param pidfile: path to the file containing the pid
472
  @rtype: int
473
  @return: The process id, if the file exists and contains a valid PID,
474
           otherwise 0
475

476
  """
477
  try:
478
    raw_data = ReadFile(pidfile)
479
  except EnvironmentError, err:
480
    if err.errno != errno.ENOENT:
481
      logging.exception("Can't read pid file")
482
    return 0
483

    
484
  try:
485
    pid = int(raw_data)
486
  except ValueError, err:
487
    logging.info("Can't parse pid file contents", exc_info=True)
488
    return 0
489

    
490
  return pid
491

    
492

    
493
def MatchNameComponent(key, name_list, case_sensitive=True):
494
  """Try to match a name against a list.
495

496
  This function will try to match a name like test1 against a list
497
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
498
  this list, I{'test1'} as well as I{'test1.example'} will match, but
499
  not I{'test1.ex'}. A multiple match will be considered as no match
500
  at all (e.g. I{'test1'} against C{['test1.example.com',
501
  'test1.example.org']}), except when the key fully matches an entry
502
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
503

504
  @type key: str
505
  @param key: the name to be searched
506
  @type name_list: list
507
  @param name_list: the list of strings against which to search the key
508
  @type case_sensitive: boolean
509
  @param case_sensitive: whether to provide a case-sensitive match
510

511
  @rtype: None or str
512
  @return: None if there is no match I{or} if there are multiple matches,
513
      otherwise the element from the list which matches
514

515
  """
516
  if key in name_list:
517
    return key
518

    
519
  re_flags = 0
520
  if not case_sensitive:
521
    re_flags |= re.IGNORECASE
522
    key = string.upper(key)
523
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
524
  names_filtered = []
525
  string_matches = []
526
  for name in name_list:
527
    if mo.match(name) is not None:
528
      names_filtered.append(name)
529
      if not case_sensitive and key == string.upper(name):
530
        string_matches.append(name)
531

    
532
  if len(string_matches) == 1:
533
    return string_matches[0]
534
  if len(names_filtered) == 1:
535
    return names_filtered[0]
536
  return None
537

    
538

    
539
class HostInfo:
540
  """Class implementing resolver and hostname functionality
541

542
  """
543
  def __init__(self, name=None):
544
    """Initialize the host name object.
545

546
    If the name argument is not passed, it will use this system's
547
    name.
548

549
    """
550
    if name is None:
551
      name = self.SysName()
552

    
553
    self.query = name
554
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
555
    self.ip = self.ipaddrs[0]
556

    
557
  def ShortName(self):
558
    """Returns the hostname without domain.
559

560
    """
561
    return self.name.split('.')[0]
562

    
563
  @staticmethod
564
  def SysName():
565
    """Return the current system's name.
566

567
    This is simply a wrapper over C{socket.gethostname()}.
568

569
    """
570
    return socket.gethostname()
571

    
572
  @staticmethod
573
  def LookupHostname(hostname):
574
    """Look up hostname
575

576
    @type hostname: str
577
    @param hostname: hostname to look up
578

579
    @rtype: tuple
580
    @return: a tuple (name, aliases, ipaddrs) as returned by
581
        C{socket.gethostbyname_ex}
582
    @raise errors.ResolverError: in case of errors in resolving
583

584
    """
585
    try:
586
      result = socket.gethostbyname_ex(hostname)
587
    except socket.gaierror, err:
588
      # hostname not found in DNS
589
      raise errors.ResolverError(hostname, err.args[0], err.args[1])
590

    
591
    return result
592

    
593

    
594
def ListVolumeGroups():
595
  """List volume groups and their size
596

597
  @rtype: dict
598
  @return:
599
       Dictionary with keys volume name and values
600
       the size of the volume
601

602
  """
603
  command = "vgs --noheadings --units m --nosuffix -o name,size"
604
  result = RunCmd(command)
605
  retval = {}
606
  if result.failed:
607
    return retval
608

    
609
  for line in result.stdout.splitlines():
610
    try:
611
      name, size = line.split()
612
      size = int(float(size))
613
    except (IndexError, ValueError), err:
614
      logging.error("Invalid output from vgs (%s): %s", err, line)
615
      continue
616

    
617
    retval[name] = size
618

    
619
  return retval
620

    
621

    
622
def BridgeExists(bridge):
623
  """Check whether the given bridge exists in the system
624

625
  @type bridge: str
626
  @param bridge: the bridge name to check
627
  @rtype: boolean
628
  @return: True if it does
629

630
  """
631
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
632

    
633

    
634
def NiceSort(name_list):
635
  """Sort a list of strings based on digit and non-digit groupings.
636

637
  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
638
  will sort the list in the logical order C{['a1', 'a2', 'a10',
639
  'a11']}.
640

641
  The sort algorithm breaks each name in groups of either only-digits
642
  or no-digits. Only the first eight such groups are considered, and
643
  after that we just use what's left of the string.
644

645
  @type name_list: list
646
  @param name_list: the names to be sorted
647
  @rtype: list
648
  @return: a copy of the name list sorted with our algorithm
649

650
  """
651
  _SORTER_BASE = "(\D+|\d+)"
652
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
653
                                                  _SORTER_BASE, _SORTER_BASE,
654
                                                  _SORTER_BASE, _SORTER_BASE,
655
                                                  _SORTER_BASE, _SORTER_BASE)
656
  _SORTER_RE = re.compile(_SORTER_FULL)
657
  _SORTER_NODIGIT = re.compile("^\D*$")
658
  def _TryInt(val):
659
    """Attempts to convert a variable to integer."""
660
    if val is None or _SORTER_NODIGIT.match(val):
661
      return val
662
    rval = int(val)
663
    return rval
664

    
665
  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
666
             for name in name_list]
667
  to_sort.sort()
668
  return [tup[1] for tup in to_sort]
669

    
670

    
671
def TryConvert(fn, val):
672
  """Try to convert a value ignoring errors.
673

674
  This function tries to apply function I{fn} to I{val}. If no
675
  C{ValueError} or C{TypeError} exceptions are raised, it will return
676
  the result, else it will return the original value. Any other
677
  exceptions are propagated to the caller.
678

679
  @type fn: callable
680
  @param fn: function to apply to the value
681
  @param val: the value to be converted
682
  @return: The converted value if the conversion was successful,
683
      otherwise the original value.
684

685
  """
686
  try:
687
    nv = fn(val)
688
  except (ValueError, TypeError):
689
    nv = val
690
  return nv
691

    
692

    
693
def IsValidIP(ip):
694
  """Verifies the syntax of an IPv4 address.
695

696
  This function checks if the IPv4 address passes is valid or not based
697
  on syntax (not IP range, class calculations, etc.).
698

699
  @type ip: str
700
  @param ip: the address to be checked
701
  @rtype: a regular expression match object
702
  @return: a regular expression match object, or None if the
703
      address is not valid
704

705
  """
706
  unit = "(0|[1-9]\d{0,2})"
707
  #TODO: convert and return only boolean
708
  return re.match("^%s\.%s\.%s\.%s$" % (unit, unit, unit, unit), ip)
709

    
710

    
711
def IsValidShellParam(word):
712
  """Verifies is the given word is safe from the shell's p.o.v.
713

714
  This means that we can pass this to a command via the shell and be
715
  sure that it doesn't alter the command line and is passed as such to
716
  the actual command.
717

718
  Note that we are overly restrictive here, in order to be on the safe
719
  side.
720

721
  @type word: str
722
  @param word: the word to check
723
  @rtype: boolean
724
  @return: True if the word is 'safe'
725

726
  """
727
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))
728

    
729

    
730
def BuildShellCmd(template, *args):
731
  """Build a safe shell command line from the given arguments.
732

733
  This function will check all arguments in the args list so that they
734
  are valid shell parameters (i.e. they don't contain shell
735
  metacharacters). If everything is ok, it will return the result of
736
  template % args.
737

738
  @type template: str
739
  @param template: the string holding the template for the
740
      string formatting
741
  @rtype: str
742
  @return: the expanded command line
743

744
  """
745
  for word in args:
746
    if not IsValidShellParam(word):
747
      raise errors.ProgrammerError("Shell argument '%s' contains"
748
                                   " invalid characters" % word)
749
  return template % args
750

    
751

    
752
def FormatUnit(value, units):
753
  """Formats an incoming number of MiB with the appropriate unit.
754

755
  @type value: int
756
  @param value: integer representing the value in MiB (1048576)
757
  @type units: char
758
  @param units: the type of formatting we should do:
759
      - 'h' for automatic scaling
760
      - 'm' for MiBs
761
      - 'g' for GiBs
762
      - 't' for TiBs
763
  @rtype: str
764
  @return: the formatted value (with suffix)
765

766
  """
767
  if units not in ('m', 'g', 't', 'h'):
768
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))
769

    
770
  suffix = ''
771

    
772
  if units == 'm' or (units == 'h' and value < 1024):
773
    if units == 'h':
774
      suffix = 'M'
775
    return "%d%s" % (round(value, 0), suffix)
776

    
777
  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
778
    if units == 'h':
779
      suffix = 'G'
780
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)
781

    
782
  else:
783
    if units == 'h':
784
      suffix = 'T'
785
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)
786

    
787

    
788
def ParseUnit(input_string):
789
  """Tries to extract number and scale from the given string.
790

791
  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
792
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
793
  is always an int in MiB.
794

795
  """
796
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
797
  if not m:
798
    raise errors.UnitParseError("Invalid format")
799

    
800
  value = float(m.groups()[0])
801

    
802
  unit = m.groups()[1]
803
  if unit:
804
    lcunit = unit.lower()
805
  else:
806
    lcunit = 'm'
807

    
808
  if lcunit in ('m', 'mb', 'mib'):
809
    # Value already in MiB
810
    pass
811

    
812
  elif lcunit in ('g', 'gb', 'gib'):
813
    value *= 1024
814

    
815
  elif lcunit in ('t', 'tb', 'tib'):
816
    value *= 1024 * 1024
817

    
818
  else:
819
    raise errors.UnitParseError("Unknown unit: %s" % unit)
820

    
821
  # Make sure we round up
822
  if int(value) < value:
823
    value += 1
824

    
825
  # Round up to the next multiple of 4
826
  value = int(value)
827
  if value % 4:
828
    value += 4 - value % 4
829

    
830
  return value
831

    
832

    
833
def AddAuthorizedKey(file_name, key):
834
  """Adds an SSH public key to an authorized_keys file.
835

836
  @type file_name: str
837
  @param file_name: path to authorized_keys file
838
  @type key: str
839
  @param key: string containing key
840

841
  """
842
  key_fields = key.split()
843

    
844
  f = open(file_name, 'a+')
845
  try:
846
    nl = True
847
    for line in f:
848
      # Ignore whitespace changes
849
      if line.split() == key_fields:
850
        break
851
      nl = line.endswith('\n')
852
    else:
853
      if not nl:
854
        f.write("\n")
855
      f.write(key.rstrip('\r\n'))
856
      f.write("\n")
857
      f.flush()
858
  finally:
859
    f.close()
860

    
861

    
862
def RemoveAuthorizedKey(file_name, key):
863
  """Removes an SSH public key from an authorized_keys file.
864

865
  @type file_name: str
866
  @param file_name: path to authorized_keys file
867
  @type key: str
868
  @param key: string containing key
869

870
  """
871
  key_fields = key.split()
872

    
873
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
874
  try:
875
    out = os.fdopen(fd, 'w')
876
    try:
877
      f = open(file_name, 'r')
878
      try:
879
        for line in f:
880
          # Ignore whitespace changes while comparing lines
881
          if line.split() != key_fields:
882
            out.write(line)
883

    
884
        out.flush()
885
        os.rename(tmpname, file_name)
886
      finally:
887
        f.close()
888
    finally:
889
      out.close()
890
  except:
891
    RemoveFile(tmpname)
892
    raise
893

    
894

    
895
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
896
  """Sets the name of an IP address and hostname in /etc/hosts.
897

898
  @type file_name: str
899
  @param file_name: path to the file to modify (usually C{/etc/hosts})
900
  @type ip: str
901
  @param ip: the IP address
902
  @type hostname: str
903
  @param hostname: the hostname to be added
904
  @type aliases: list
905
  @param aliases: the list of aliases to add for the hostname
906

907
  """
908
  # FIXME: use WriteFile + fn rather than duplicating its efforts
909
  # Ensure aliases are unique
910
  aliases = UniqueSequence([hostname] + aliases)[1:]
911

    
912
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
913
  try:
914
    out = os.fdopen(fd, 'w')
915
    try:
916
      f = open(file_name, 'r')
917
      try:
918
        for line in f:
919
          fields = line.split()
920
          if fields and not fields[0].startswith('#') and ip == fields[0]:
921
            continue
922
          out.write(line)
923

    
924
        out.write("%s\t%s" % (ip, hostname))
925
        if aliases:
926
          out.write(" %s" % ' '.join(aliases))
927
        out.write('\n')
928

    
929
        out.flush()
930
        os.fsync(out)
931
        os.chmod(tmpname, 0644)
932
        os.rename(tmpname, file_name)
933
      finally:
934
        f.close()
935
    finally:
936
      out.close()
937
  except:
938
    RemoveFile(tmpname)
939
    raise
940

    
941

    
942
def AddHostToEtcHosts(hostname):
943
  """Wrapper around SetEtcHostsEntry.
944

945
  @type hostname: str
946
  @param hostname: a hostname that will be resolved and added to
947
      L{constants.ETC_HOSTS}
948

949
  """
950
  hi = HostInfo(name=hostname)
951
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])
952

    
953

    
954
def RemoveEtcHostsEntry(file_name, hostname):
955
  """Removes a hostname from /etc/hosts.
956

957
  IP addresses without names are removed from the file.
958

959
  @type file_name: str
960
  @param file_name: path to the file to modify (usually C{/etc/hosts})
961
  @type hostname: str
962
  @param hostname: the hostname to be removed
963

964
  """
965
  # FIXME: use WriteFile + fn rather than duplicating its efforts
966
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
967
  try:
968
    out = os.fdopen(fd, 'w')
969
    try:
970
      f = open(file_name, 'r')
971
      try:
972
        for line in f:
973
          fields = line.split()
974
          if len(fields) > 1 and not fields[0].startswith('#'):
975
            names = fields[1:]
976
            if hostname in names:
977
              while hostname in names:
978
                names.remove(hostname)
979
              if names:
980
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
981
              continue
982

    
983
          out.write(line)
984

    
985
        out.flush()
986
        os.fsync(out)
987
        os.chmod(tmpname, 0644)
988
        os.rename(tmpname, file_name)
989
      finally:
990
        f.close()
991
    finally:
992
      out.close()
993
  except:
994
    RemoveFile(tmpname)
995
    raise
996

    
997

    
998
def RemoveHostFromEtcHosts(hostname):
999
  """Wrapper around RemoveEtcHostsEntry.
1000

1001
  @type hostname: str
1002
  @param hostname: hostname that will be resolved and its
1003
      full and shot name will be removed from
1004
      L{constants.ETC_HOSTS}
1005

1006
  """
1007
  hi = HostInfo(name=hostname)
1008
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
1009
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())
1010

    
1011

    
1012
def CreateBackup(file_name):
1013
  """Creates a backup of a file.
1014

1015
  @type file_name: str
1016
  @param file_name: file to be backed up
1017
  @rtype: str
1018
  @return: the path to the newly created backup
1019
  @raise errors.ProgrammerError: for invalid file names
1020

1021
  """
1022
  if not os.path.isfile(file_name):
1023
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
1024
                                file_name)
1025

    
1026
  prefix = '%s.backup-%d.' % (os.path.basename(file_name), int(time.time()))
1027
  dir_name = os.path.dirname(file_name)
1028

    
1029
  fsrc = open(file_name, 'rb')
1030
  try:
1031
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
1032
    fdst = os.fdopen(fd, 'wb')
1033
    try:
1034
      shutil.copyfileobj(fsrc, fdst)
1035
    finally:
1036
      fdst.close()
1037
  finally:
1038
    fsrc.close()
1039

    
1040
  return backup_name
1041

    
1042

    
1043
def ShellQuote(value):
1044
  """Quotes shell argument according to POSIX.
1045

1046
  @type value: str
1047
  @param value: the argument to be quoted
1048
  @rtype: str
1049
  @return: the quoted value
1050

1051
  """
1052
  if _re_shell_unquoted.match(value):
1053
    return value
1054
  else:
1055
    return "'%s'" % value.replace("'", "'\\''")
1056

    
1057

    
1058
def ShellQuoteArgs(args):
1059
  """Quotes a list of shell arguments.
1060

1061
  @type args: list
1062
  @param args: list of arguments to be quoted
1063
  @rtype: str
1064
  @return: the quoted arguments concatenated with spaces
1065

1066
  """
1067
  return ' '.join([ShellQuote(i) for i in args])
1068

    
1069

    
1070
def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
1071
  """Simple ping implementation using TCP connect(2).
1072

1073
  Check if the given IP is reachable by doing attempting a TCP connect
1074
  to it.
1075

1076
  @type target: str
1077
  @param target: the IP or hostname to ping
1078
  @type port: int
1079
  @param port: the port to connect to
1080
  @type timeout: int
1081
  @param timeout: the timeout on the connection attempt
1082
  @type live_port_needed: boolean
1083
  @param live_port_needed: whether a closed port will cause the
1084
      function to return failure, as if there was a timeout
1085
  @type source: str or None
1086
  @param source: if specified, will cause the connect to be made
1087
      from this specific source address; failures to bind other
1088
      than C{EADDRNOTAVAIL} will be ignored
1089

1090
  """
1091
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1092

    
1093
  success = False
1094

    
1095
  if source is not None:
1096
    try:
1097
      sock.bind((source, 0))
1098
    except socket.error, (errcode, _):
1099
      if errcode == errno.EADDRNOTAVAIL:
1100
        success = False
1101

    
1102
  sock.settimeout(timeout)
1103

    
1104
  try:
1105
    sock.connect((target, port))
1106
    sock.close()
1107
    success = True
1108
  except socket.timeout:
1109
    success = False
1110
  except socket.error, (errcode, errstring):
1111
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
1112

    
1113
  return success
1114

    
1115

    
1116
def OwnIpAddress(address):
1117
  """Check if the current host has the the given IP address.
1118

1119
  Currently this is done by TCP-pinging the address from the loopback
1120
  address.
1121

1122
  @type address: string
1123
  @param address: the address to check
1124
  @rtype: bool
1125
  @return: True if we own the address
1126

1127
  """
1128
  return TcpPing(address, constants.DEFAULT_NODED_PORT,
1129
                 source=constants.LOCALHOST_IP_ADDRESS)
1130

    
1131

    
1132
def ListVisibleFiles(path):
1133
  """Returns a list of visible files in a directory.
1134

1135
  @type path: str
1136
  @param path: the directory to enumerate
1137
  @rtype: list
1138
  @return: the list of all files not starting with a dot
1139

1140
  """
1141
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1142
  files.sort()
1143
  return files
1144

    
1145

    
1146
def GetHomeDir(user, default=None):
1147
  """Try to get the homedir of the given user.
1148

1149
  The user can be passed either as a string (denoting the name) or as
1150
  an integer (denoting the user id). If the user is not found, the
1151
  'default' argument is returned, which defaults to None.
1152

1153
  """
1154
  try:
1155
    if isinstance(user, basestring):
1156
      result = pwd.getpwnam(user)
1157
    elif isinstance(user, (int, long)):
1158
      result = pwd.getpwuid(user)
1159
    else:
1160
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1161
                                   type(user))
1162
  except KeyError:
1163
    return default
1164
  return result.pw_dir
1165

    
1166

    
1167
def NewUUID():
1168
  """Returns a random UUID.
1169

1170
  @note: This is a Linux-specific method as it uses the /proc
1171
      filesystem.
1172
  @rtype: str
1173

1174
  """
1175
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1176

    
1177

    
1178
def GenerateSecret(numbytes=20):
1179
  """Generates a random secret.
1180

1181
  This will generate a pseudo-random secret returning an hex string
1182
  (so that it can be used where an ASCII string is needed).
1183

1184
  @param numbytes: the number of bytes which will be represented by the returned
1185
      string (defaulting to 20, the length of a SHA1 hash)
1186
  @rtype: str
1187
  @return: an hex representation of the pseudo-random sequence
1188

1189
  """
1190
  return os.urandom(numbytes).encode('hex')
1191

    
1192

    
1193
def EnsureDirs(dirs):
1194
  """Make required directories, if they don't exist.
1195

1196
  @param dirs: list of tuples (dir_name, dir_mode)
1197
  @type dirs: list of (string, integer)
1198

1199
  """
1200
  for dir_name, dir_mode in dirs:
1201
    try:
1202
      os.mkdir(dir_name, dir_mode)
1203
    except EnvironmentError, err:
1204
      if err.errno != errno.EEXIST:
1205
        raise errors.GenericError("Cannot create needed directory"
1206
                                  " '%s': %s" % (dir_name, err))
1207
    if not os.path.isdir(dir_name):
1208
      raise errors.GenericError("%s is not a directory" % dir_name)
1209

    
1210

    
1211
def ReadFile(file_name, size=None):
1212
  """Reads a file.
1213

1214
  @type size: None or int
1215
  @param size: Read at most size bytes
1216
  @rtype: str
1217
  @return: the (possibly partial) content of the file
1218

1219
  """
1220
  f = open(file_name, "r")
1221
  try:
1222
    if size is None:
1223
      return f.read()
1224
    else:
1225
      return f.read(size)
1226
  finally:
1227
    f.close()
1228

    
1229

    
1230
def WriteFile(file_name, fn=None, data=None,
1231
              mode=None, uid=-1, gid=-1,
1232
              atime=None, mtime=None, close=True,
1233
              dry_run=False, backup=False,
1234
              prewrite=None, postwrite=None):
1235
  """(Over)write a file atomically.
1236

1237
  The file_name and either fn (a function taking one argument, the
1238
  file descriptor, and which should write the data to it) or data (the
1239
  contents of the file) must be passed. The other arguments are
1240
  optional and allow setting the file mode, owner and group, and the
1241
  mtime/atime of the file.
1242

1243
  If the function doesn't raise an exception, it has succeeded and the
1244
  target file has the new contents. If the function has raised an
1245
  exception, an existing target file should be unmodified and the
1246
  temporary file should be removed.
1247

1248
  @type file_name: str
1249
  @param file_name: the target filename
1250
  @type fn: callable
1251
  @param fn: content writing function, called with
1252
      file descriptor as parameter
1253
  @type data: str
1254
  @param data: contents of the file
1255
  @type mode: int
1256
  @param mode: file mode
1257
  @type uid: int
1258
  @param uid: the owner of the file
1259
  @type gid: int
1260
  @param gid: the group of the file
1261
  @type atime: int
1262
  @param atime: a custom access time to be set on the file
1263
  @type mtime: int
1264
  @param mtime: a custom modification time to be set on the file
1265
  @type close: boolean
1266
  @param close: whether to close file after writing it
1267
  @type prewrite: callable
1268
  @param prewrite: function to be called before writing content
1269
  @type postwrite: callable
1270
  @param postwrite: function to be called after writing content
1271

1272
  @rtype: None or int
1273
  @return: None if the 'close' parameter evaluates to True,
1274
      otherwise the file descriptor
1275

1276
  @raise errors.ProgrammerError: if any of the arguments are not valid
1277

1278
  """
1279
  if not os.path.isabs(file_name):
1280
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1281
                                 " absolute: '%s'" % file_name)
1282

    
1283
  if [fn, data].count(None) != 1:
1284
    raise errors.ProgrammerError("fn or data required")
1285

    
1286
  if [atime, mtime].count(None) == 1:
1287
    raise errors.ProgrammerError("Both atime and mtime must be either"
1288
                                 " set or None")
1289

    
1290
  if backup and not dry_run and os.path.isfile(file_name):
1291
    CreateBackup(file_name)
1292

    
1293
  dir_name, base_name = os.path.split(file_name)
1294
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1295
  do_remove = True
1296
  # here we need to make sure we remove the temp file, if any error
1297
  # leaves it in place
1298
  try:
1299
    if uid != -1 or gid != -1:
1300
      os.chown(new_name, uid, gid)
1301
    if mode:
1302
      os.chmod(new_name, mode)
1303
    if callable(prewrite):
1304
      prewrite(fd)
1305
    if data is not None:
1306
      os.write(fd, data)
1307
    else:
1308
      fn(fd)
1309
    if callable(postwrite):
1310
      postwrite(fd)
1311
    os.fsync(fd)
1312
    if atime is not None and mtime is not None:
1313
      os.utime(new_name, (atime, mtime))
1314
    if not dry_run:
1315
      os.rename(new_name, file_name)
1316
      do_remove = False
1317
  finally:
1318
    if close:
1319
      os.close(fd)
1320
      result = None
1321
    else:
1322
      result = fd
1323
    if do_remove:
1324
      RemoveFile(new_name)
1325

    
1326
  return result
1327

    
1328

    
1329
def FirstFree(seq, base=0):
1330
  """Returns the first non-existing integer from seq.
1331

1332
  The seq argument should be a sorted list of positive integers. The
1333
  first time the index of an element is smaller than the element
1334
  value, the index will be returned.
1335

1336
  The base argument is used to start at a different offset,
1337
  i.e. C{[3, 4, 6]} with I{offset=3} will return 5.
1338

1339
  Example: C{[0, 1, 3]} will return I{2}.
1340

1341
  @type seq: sequence
1342
  @param seq: the sequence to be analyzed.
1343
  @type base: int
1344
  @param base: use this value as the base index of the sequence
1345
  @rtype: int
1346
  @return: the first non-used index in the sequence
1347

1348
  """
1349
  for idx, elem in enumerate(seq):
1350
    assert elem >= base, "Passed element is higher than base offset"
1351
    if elem > idx + base:
1352
      # idx is not used
1353
      return idx + base
1354
  return None
1355

    
1356

    
1357
def all(seq, pred=bool):
1358
  "Returns True if pred(x) is True for every element in the iterable"
1359
  for _ in itertools.ifilterfalse(pred, seq):
1360
    return False
1361
  return True
1362

    
1363

    
1364
def any(seq, pred=bool):
1365
  "Returns True if pred(x) is True for at least one element in the iterable"
1366
  for _ in itertools.ifilter(pred, seq):
1367
    return True
1368
  return False
1369

    
1370

    
1371
def UniqueSequence(seq):
1372
  """Returns a list with unique elements.
1373

1374
  Element order is preserved.
1375

1376
  @type seq: sequence
1377
  @param seq: the sequence with the source elements
1378
  @rtype: list
1379
  @return: list of unique elements from seq
1380

1381
  """
1382
  seen = set()
1383
  return [i for i in seq if i not in seen and not seen.add(i)]
1384

    
1385

    
1386
def IsValidMac(mac):
1387
  """Predicate to check if a MAC address is valid.
1388

1389
  Checks whether the supplied MAC address is formally correct, only
1390
  accepts colon separated format.
1391

1392
  @type mac: str
1393
  @param mac: the MAC to be validated
1394
  @rtype: boolean
1395
  @return: True is the MAC seems valid
1396

1397
  """
1398
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$")
1399
  return mac_check.match(mac) is not None
1400

    
1401

    
1402
def TestDelay(duration):
1403
  """Sleep for a fixed amount of time.
1404

1405
  @type duration: float
1406
  @param duration: the sleep duration
1407
  @rtype: boolean
1408
  @return: False for negative value, True otherwise
1409

1410
  """
1411
  if duration < 0:
1412
    return False, "Invalid sleep duration"
1413
  time.sleep(duration)
1414
  return True, None
1415

    
1416

    
1417
def _CloseFDNoErr(fd, retries=5):
1418
  """Close a file descriptor ignoring errors.
1419

1420
  @type fd: int
1421
  @param fd: the file descriptor
1422
  @type retries: int
1423
  @param retries: how many retries to make, in case we get any
1424
      other error than EBADF
1425

1426
  """
1427
  try:
1428
    os.close(fd)
1429
  except OSError, err:
1430
    if err.errno != errno.EBADF:
1431
      if retries > 0:
1432
        _CloseFDNoErr(fd, retries - 1)
1433
    # else either it's closed already or we're out of retries, so we
1434
    # ignore this and go on
1435

    
1436

    
1437
def CloseFDs(noclose_fds=None):
1438
  """Close file descriptors.
1439

1440
  This closes all file descriptors above 2 (i.e. except
1441
  stdin/out/err).
1442

1443
  @type noclose_fds: list or None
1444
  @param noclose_fds: if given, it denotes a list of file descriptor
1445
      that should not be closed
1446

1447
  """
1448
  # Default maximum for the number of available file descriptors.
1449
  if 'SC_OPEN_MAX' in os.sysconf_names:
1450
    try:
1451
      MAXFD = os.sysconf('SC_OPEN_MAX')
1452
      if MAXFD < 0:
1453
        MAXFD = 1024
1454
    except OSError:
1455
      MAXFD = 1024
1456
  else:
1457
    MAXFD = 1024
1458
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
1459
  if (maxfd == resource.RLIM_INFINITY):
1460
    maxfd = MAXFD
1461

    
1462
  # Iterate through and close all file descriptors (except the standard ones)
1463
  for fd in range(3, maxfd):
1464
    if noclose_fds and fd in noclose_fds:
1465
      continue
1466
    _CloseFDNoErr(fd)
1467

    
1468

    
1469
def Daemonize(logfile):
1470
  """Daemonize the current process.
1471

1472
  This detaches the current process from the controlling terminal and
1473
  runs it in the background as a daemon.
1474

1475
  @type logfile: str
1476
  @param logfile: the logfile to which we should redirect stdout/stderr
1477
  @rtype: int
1478
  @return: the value zero
1479

1480
  """
1481
  UMASK = 077
1482
  WORKDIR = "/"
1483

    
1484
  # this might fail
1485
  pid = os.fork()
1486
  if (pid == 0):  # The first child.
1487
    os.setsid()
1488
    # this might fail
1489
    pid = os.fork() # Fork a second child.
1490
    if (pid == 0):  # The second child.
1491
      os.chdir(WORKDIR)
1492
      os.umask(UMASK)
1493
    else:
1494
      # exit() or _exit()?  See below.
1495
      os._exit(0) # Exit parent (the first child) of the second child.
1496
  else:
1497
    os._exit(0) # Exit parent of the first child.
1498

    
1499
  for fd in range(3):
1500
    _CloseFDNoErr(fd)
1501
  i = os.open("/dev/null", os.O_RDONLY) # stdin
1502
  assert i == 0, "Can't close/reopen stdin"
1503
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
1504
  assert i == 1, "Can't close/reopen stdout"
1505
  # Duplicate standard output to standard error.
1506
  os.dup2(1, 2)
1507
  return 0
1508

    
1509

    
1510
def DaemonPidFileName(name):
1511
  """Compute a ganeti pid file absolute path
1512

1513
  @type name: str
1514
  @param name: the daemon name
1515
  @rtype: str
1516
  @return: the full path to the pidfile corresponding to the given
1517
      daemon name
1518

1519
  """
1520
  return os.path.join(constants.RUN_GANETI_DIR, "%s.pid" % name)
1521

    
1522

    
1523
def WritePidFile(name):
1524
  """Write the current process pidfile.
1525

1526
  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}
1527

1528
  @type name: str
1529
  @param name: the daemon name to use
1530
  @raise errors.GenericError: if the pid file already exists and
1531
      points to a live process
1532

1533
  """
1534
  pid = os.getpid()
1535
  pidfilename = DaemonPidFileName(name)
1536
  if IsProcessAlive(ReadPidFile(pidfilename)):
1537
    raise errors.GenericError("%s contains a live process" % pidfilename)
1538

    
1539
  WriteFile(pidfilename, data="%d\n" % pid)
1540

    
1541

    
1542
def RemovePidFile(name):
1543
  """Remove the current process pidfile.
1544

1545
  Any errors are ignored.
1546

1547
  @type name: str
1548
  @param name: the daemon name used to derive the pidfile name
1549

1550
  """
1551
  pidfilename = DaemonPidFileName(name)
1552
  # TODO: we could check here that the file contains our pid
1553
  try:
1554
    RemoveFile(pidfilename)
1555
  except:
1556
    pass
1557

    
1558

    
1559
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
1560
                waitpid=False):
1561
  """Kill a process given by its pid.
1562

1563
  @type pid: int
1564
  @param pid: The PID to terminate.
1565
  @type signal_: int
1566
  @param signal_: The signal to send, by default SIGTERM
1567
  @type timeout: int
1568
  @param timeout: The timeout after which, if the process is still alive,
1569
                  a SIGKILL will be sent. If not positive, no such checking
1570
                  will be done
1571
  @type waitpid: boolean
1572
  @param waitpid: If true, we should waitpid on this process after
1573
      sending signals, since it's our own child and otherwise it
1574
      would remain as zombie
1575

1576
  """
1577
  def _helper(pid, signal_, wait):
1578
    """Simple helper to encapsulate the kill/waitpid sequence"""
1579
    os.kill(pid, signal_)
1580
    if wait:
1581
      try:
1582
        os.waitpid(pid, os.WNOHANG)
1583
      except OSError:
1584
        pass
1585

    
1586
  if pid <= 0:
1587
    # kill with pid=0 == suicide
1588
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
1589

    
1590
  if not IsProcessAlive(pid):
1591
    return
1592

    
1593
  _helper(pid, signal_, waitpid)
1594

    
1595
  if timeout <= 0:
1596
    return
1597

    
1598
  def _CheckProcess():
1599
    if not IsProcessAlive(pid):
1600
      return
1601

    
1602
    try:
1603
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
1604
    except OSError:
1605
      raise RetryAgain()
1606

    
1607
    if result_pid > 0:
1608
      return
1609

    
1610
    raise RetryAgain()
1611

    
1612
  try:
1613
    # Wait up to $timeout seconds
1614
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
1615
  except RetryTimeout:
1616
    pass
1617

    
1618
  if IsProcessAlive(pid):
1619
    # Kill process if it's still alive
1620
    _helper(pid, signal.SIGKILL, waitpid)
1621

    
1622

    
1623
def FindFile(name, search_path, test=os.path.exists):
1624
  """Look for a filesystem object in a given path.
1625

1626
  This is an abstract method to search for filesystem object (files,
1627
  dirs) under a given search path.
1628

1629
  @type name: str
1630
  @param name: the name to look for
1631
  @type search_path: str
1632
  @param search_path: location to start at
1633
  @type test: callable
1634
  @param test: a function taking one argument that should return True
1635
      if the a given object is valid; the default value is
1636
      os.path.exists, causing only existing files to be returned
1637
  @rtype: str or None
1638
  @return: full path to the object if found, None otherwise
1639

1640
  """
1641
  for dir_name in search_path:
1642
    item_name = os.path.sep.join([dir_name, name])
1643
    if test(item_name):
1644
      return item_name
1645
  return None
1646

    
1647

    
1648
def CheckVolumeGroupSize(vglist, vgname, minsize):
1649
  """Checks if the volume group list is valid.
1650

1651
  The function will check if a given volume group is in the list of
1652
  volume groups and has a minimum size.
1653

1654
  @type vglist: dict
1655
  @param vglist: dictionary of volume group names and their size
1656
  @type vgname: str
1657
  @param vgname: the volume group we should check
1658
  @type minsize: int
1659
  @param minsize: the minimum size we accept
1660
  @rtype: None or str
1661
  @return: None for success, otherwise the error message
1662

1663
  """
1664
  vgsize = vglist.get(vgname, None)
1665
  if vgsize is None:
1666
    return "volume group '%s' missing" % vgname
1667
  elif vgsize < minsize:
1668
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
1669
            (vgname, minsize, vgsize))
1670
  return None
1671

    
1672

    
1673
def SplitTime(value):
1674
  """Splits time as floating point number into a tuple.
1675

1676
  @param value: Time in seconds
1677
  @type value: int or float
1678
  @return: Tuple containing (seconds, microseconds)
1679

1680
  """
1681
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
1682

    
1683
  assert 0 <= seconds, \
1684
    "Seconds must be larger than or equal to 0, but are %s" % seconds
1685
  assert 0 <= microseconds <= 999999, \
1686
    "Microseconds must be 0-999999, but are %s" % microseconds
1687

    
1688
  return (int(seconds), int(microseconds))
1689

    
1690

    
1691
def MergeTime(timetuple):
1692
  """Merges a tuple into time as a floating point number.
1693

1694
  @param timetuple: Time as tuple, (seconds, microseconds)
1695
  @type timetuple: tuple
1696
  @return: Time as a floating point number expressed in seconds
1697

1698
  """
1699
  (seconds, microseconds) = timetuple
1700

    
1701
  assert 0 <= seconds, \
1702
    "Seconds must be larger than or equal to 0, but are %s" % seconds
1703
  assert 0 <= microseconds <= 999999, \
1704
    "Microseconds must be 0-999999, but are %s" % microseconds
1705

    
1706
  return float(seconds) + (float(microseconds) * 0.000001)
1707

    
1708

    
1709
def GetDaemonPort(daemon_name):
1710
  """Get the daemon port for this cluster.
1711

1712
  Note that this routine does not read a ganeti-specific file, but
1713
  instead uses C{socket.getservbyname} to allow pre-customization of
1714
  this parameter outside of Ganeti.
1715

1716
  @type daemon_name: string
1717
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
1718
  @rtype: int
1719

1720
  """
1721
  if daemon_name not in constants.DAEMONS_PORTS:
1722
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)
1723

    
1724
  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
1725
  try:
1726
    port = socket.getservbyname(daemon_name, proto)
1727
  except socket.error:
1728
    port = default_port
1729

    
1730
  return port
1731

    
1732

    
1733
def SetupLogging(logfile, debug=False, stderr_logging=False, program="",
1734
                 multithreaded=False):
1735
  """Configures the logging module.
1736

1737
  @type logfile: str
1738
  @param logfile: the filename to which we should log
1739
  @type debug: boolean
1740
  @param debug: whether to enable debug messages too or
1741
      only those at C{INFO} and above level
1742
  @type stderr_logging: boolean
1743
  @param stderr_logging: whether we should also log to the standard error
1744
  @type program: str
1745
  @param program: the name under which we should log messages
1746
  @type multithreaded: boolean
1747
  @param multithreaded: if True, will add the thread name to the log file
1748
  @raise EnvironmentError: if we can't open the log file and
1749
      stderr logging is disabled
1750

1751
  """
1752
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
1753
  if multithreaded:
1754
    fmt += "/%(threadName)s"
1755
  if debug:
1756
    fmt += " %(module)s:%(lineno)s"
1757
  fmt += " %(levelname)s %(message)s"
1758
  formatter = logging.Formatter(fmt)
1759

    
1760
  root_logger = logging.getLogger("")
1761
  root_logger.setLevel(logging.NOTSET)
1762

    
1763
  # Remove all previously setup handlers
1764
  for handler in root_logger.handlers:
1765
    handler.close()
1766
    root_logger.removeHandler(handler)
1767

    
1768
  if stderr_logging:
1769
    stderr_handler = logging.StreamHandler()
1770
    stderr_handler.setFormatter(formatter)
1771
    if debug:
1772
      stderr_handler.setLevel(logging.NOTSET)
1773
    else:
1774
      stderr_handler.setLevel(logging.CRITICAL)
1775
    root_logger.addHandler(stderr_handler)
1776

    
1777
  # this can fail, if the logging directories are not setup or we have
1778
  # a permisssion problem; in this case, it's best to log but ignore
1779
  # the error if stderr_logging is True, and if false we re-raise the
1780
  # exception since otherwise we could run but without any logs at all
1781
  try:
1782
    logfile_handler = logging.FileHandler(logfile)
1783
    logfile_handler.setFormatter(formatter)
1784
    if debug:
1785
      logfile_handler.setLevel(logging.DEBUG)
1786
    else:
1787
      logfile_handler.setLevel(logging.INFO)
1788
    root_logger.addHandler(logfile_handler)
1789
  except EnvironmentError:
1790
    if stderr_logging:
1791
      logging.exception("Failed to enable logging to file '%s'", logfile)
1792
    else:
1793
      # we need to re-raise the exception
1794
      raise
1795

    
1796

    
1797
def IsNormAbsPath(path):
1798
  """Check whether a path is absolute and also normalized
1799

1800
  This avoids things like /dir/../../other/path to be valid.
1801

1802
  """
1803
  return os.path.normpath(path) == path and os.path.isabs(path)
1804

    
1805

    
1806
def TailFile(fname, lines=20):
1807
  """Return the last lines from a file.
1808

1809
  @note: this function will only read and parse the last 4KB of
1810
      the file; if the lines are very long, it could be that less
1811
      than the requested number of lines are returned
1812

1813
  @param fname: the file name
1814
  @type lines: int
1815
  @param lines: the (maximum) number of lines to return
1816

1817
  """
1818
  fd = open(fname, "r")
1819
  try:
1820
    fd.seek(0, 2)
1821
    pos = fd.tell()
1822
    pos = max(0, pos-4096)
1823
    fd.seek(pos, 0)
1824
    raw_data = fd.read()
1825
  finally:
1826
    fd.close()
1827

    
1828
  rows = raw_data.splitlines()
1829
  return rows[-lines:]
1830

    
1831

    
1832
def SafeEncode(text):
1833
  """Return a 'safe' version of a source string.
1834

1835
  This function mangles the input string and returns a version that
1836
  should be safe to display/encode as ASCII. To this end, we first
1837
  convert it to ASCII using the 'backslashreplace' encoding which
1838
  should get rid of any non-ASCII chars, and then we process it
1839
  through a loop copied from the string repr sources in the python; we
1840
  don't use string_escape anymore since that escape single quotes and
1841
  backslashes too, and that is too much; and that escaping is not
1842
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
1843

1844
  @type text: str or unicode
1845
  @param text: input data
1846
  @rtype: str
1847
  @return: a safe version of text
1848

1849
  """
1850
  if isinstance(text, unicode):
1851
    # only if unicode; if str already, we handle it below
1852
    text = text.encode('ascii', 'backslashreplace')
1853
  resu = ""
1854
  for char in text:
1855
    c = ord(char)
1856
    if char  == '\t':
1857
      resu += r'\t'
1858
    elif char == '\n':
1859
      resu += r'\n'
1860
    elif char == '\r':
1861
      resu += r'\'r'
1862
    elif c < 32 or c >= 127: # non-printable
1863
      resu += "\\x%02x" % (c & 0xff)
1864
    else:
1865
      resu += char
1866
  return resu
1867

    
1868

    
1869
def BytesToMebibyte(value):
1870
  """Converts bytes to mebibytes.
1871

1872
  @type value: int
1873
  @param value: Value in bytes
1874
  @rtype: int
1875
  @return: Value in mebibytes
1876

1877
  """
1878
  return int(round(value / (1024.0 * 1024.0), 0))
1879

    
1880

    
1881
def CalculateDirectorySize(path):
1882
  """Calculates the size of a directory recursively.
1883

1884
  @type path: string
1885
  @param path: Path to directory
1886
  @rtype: int
1887
  @return: Size in mebibytes
1888

1889
  """
1890
  size = 0
1891

    
1892
  for (curpath, _, files) in os.walk(path):
1893
    for filename in files:
1894
      st = os.lstat(os.path.join(curpath, filename))
1895
      size += st.st_size
1896

    
1897
  return BytesToMebibyte(size)
1898

    
1899

    
1900
def GetFilesystemStats(path):
1901
  """Returns the total and free space on a filesystem.
1902

1903
  @type path: string
1904
  @param path: Path on filesystem to be examined
1905
  @rtype: int
1906
  @return: tuple of (Total space, Free space) in mebibytes
1907

1908
  """
1909
  st = os.statvfs(path)
1910

    
1911
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
1912
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
1913
  return (tsize, fsize)
1914

    
1915

    
1916
def LockedMethod(fn):
1917
  """Synchronized object access decorator.
1918

1919
  This decorator is intended to protect access to an object using the
1920
  object's own lock which is hardcoded to '_lock'.
1921

1922
  """
1923
  def _LockDebug(*args, **kwargs):
1924
    if debug_locks:
1925
      logging.debug(*args, **kwargs)
1926

    
1927
  def wrapper(self, *args, **kwargs):
1928
    assert hasattr(self, '_lock')
1929
    lock = self._lock
1930
    _LockDebug("Waiting for %s", lock)
1931
    lock.acquire()
1932
    try:
1933
      _LockDebug("Acquired %s", lock)
1934
      result = fn(self, *args, **kwargs)
1935
    finally:
1936
      _LockDebug("Releasing %s", lock)
1937
      lock.release()
1938
      _LockDebug("Released %s", lock)
1939
    return result
1940
  return wrapper
1941

    
1942

    
1943
def LockFile(fd):
1944
  """Locks a file using POSIX locks.
1945

1946
  @type fd: int
1947
  @param fd: the file descriptor we need to lock
1948

1949
  """
1950
  try:
1951
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
1952
  except IOError, err:
1953
    if err.errno == errno.EAGAIN:
1954
      raise errors.LockError("File already locked")
1955
    raise
1956

    
1957

    
1958
def FormatTime(val):
1959
  """Formats a time value.
1960

1961
  @type val: float or None
1962
  @param val: the timestamp as returned by time.time()
1963
  @return: a string value or N/A if we don't have a valid timestamp
1964

1965
  """
1966
  if val is None or not isinstance(val, (int, float)):
1967
    return "N/A"
1968
  # these two codes works on Linux, but they are not guaranteed on all
1969
  # platforms
1970
  return time.strftime("%F %T", time.localtime(val))
1971

    
1972

    
1973
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
1974
  """Reads the watcher pause file.
1975

1976
  @type filename: string
1977
  @param filename: Path to watcher pause file
1978
  @type now: None, float or int
1979
  @param now: Current time as Unix timestamp
1980
  @type remove_after: int
1981
  @param remove_after: Remove watcher pause file after specified amount of
1982
    seconds past the pause end time
1983

1984
  """
1985
  if now is None:
1986
    now = time.time()
1987

    
1988
  try:
1989
    value = ReadFile(filename)
1990
  except IOError, err:
1991
    if err.errno != errno.ENOENT:
1992
      raise
1993
    value = None
1994

    
1995
  if value is not None:
1996
    try:
1997
      value = int(value)
1998
    except ValueError:
1999
      logging.warning(("Watcher pause file (%s) contains invalid value,"
2000
                       " removing it"), filename)
2001
      RemoveFile(filename)
2002
      value = None
2003

    
2004
    if value is not None:
2005
      # Remove file if it's outdated
2006
      if now > (value + remove_after):
2007
        RemoveFile(filename)
2008
        value = None
2009

    
2010
      elif now > value:
2011
        value = None
2012

    
2013
  return value
2014

    
2015

    
2016
class RetryTimeout(Exception):
2017
  """Retry loop timed out.
2018

2019
  """
2020

    
2021

    
2022
class RetryAgain(Exception):
2023
  """Retry again.
2024

2025
  """
2026

    
2027

    
2028
class _RetryDelayCalculator(object):
2029
  """Calculator for increasing delays.
2030

2031
  """
2032
  __slots__ = [
2033
    "_factor",
2034
    "_limit",
2035
    "_next",
2036
    "_start",
2037
    ]
2038

    
2039
  def __init__(self, start, factor, limit):
2040
    """Initializes this class.
2041

2042
    @type start: float
2043
    @param start: Initial delay
2044
    @type factor: float
2045
    @param factor: Factor for delay increase
2046
    @type limit: float or None
2047
    @param limit: Upper limit for delay or None for no limit
2048

2049
    """
2050
    assert start > 0.0
2051
    assert factor >= 1.0
2052
    assert limit is None or limit >= 0.0
2053

    
2054
    self._start = start
2055
    self._factor = factor
2056
    self._limit = limit
2057

    
2058
    self._next = start
2059

    
2060
  def __call__(self):
2061
    """Returns current delay and calculates the next one.
2062

2063
    """
2064
    current = self._next
2065

    
2066
    # Update for next run
2067
    if self._limit is None or self._next < self._limit:
2068
      self._next = max(self._limit, self._next * self._factor)
2069

    
2070
    return current
2071

    
2072

    
2073
#: Special delay to specify whole remaining timeout
2074
RETRY_REMAINING_TIME = object()
2075

    
2076

    
2077
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
2078
          _time_fn=time.time):
2079
  """Call a function repeatedly until it succeeds.
2080

2081
  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
2082
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
2083
  total of C{timeout} seconds, this function throws L{RetryTimeout}.
2084

2085
  C{delay} can be one of the following:
2086
    - callable returning the delay length as a float
2087
    - Tuple of (start, factor, limit)
2088
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
2089
      useful when overriding L{wait_fn} to wait for an external event)
2090
    - A static delay as a number (int or float)
2091

2092
  @type fn: callable
2093
  @param fn: Function to be called
2094
  @param delay: Either a callable (returning the delay), a tuple of (start,
2095
                factor, limit) (see L{_RetryDelayCalculator}),
2096
                L{RETRY_REMAINING_TIME} or a number (int or float)
2097
  @type timeout: float
2098
  @param timeout: Total timeout
2099
  @type wait_fn: callable
2100
  @param wait_fn: Waiting function
2101
  @return: Return value of function
2102

2103
  """
2104
  assert callable(fn)
2105
  assert callable(wait_fn)
2106
  assert callable(_time_fn)
2107

    
2108
  if args is None:
2109
    args = []
2110

    
2111
  end_time = _time_fn() + timeout
2112

    
2113
  if callable(delay):
2114
    # External function to calculate delay
2115
    calc_delay = delay
2116

    
2117
  elif isinstance(delay, (tuple, list)):
2118
    # Increasing delay with optional upper boundary
2119
    (start, factor, limit) = delay
2120
    calc_delay = _RetryDelayCalculator(start, factor, limit)
2121

    
2122
  elif delay is RETRY_REMAINING_TIME:
2123
    # Always use the remaining time
2124
    calc_delay = None
2125

    
2126
  else:
2127
    # Static delay
2128
    calc_delay = lambda: delay
2129

    
2130
  assert calc_delay is None or callable(calc_delay)
2131

    
2132
  while True:
2133
    try:
2134
      return fn(*args)
2135
    except RetryAgain:
2136
      pass
2137

    
2138
    remaining_time = end_time - _time_fn()
2139

    
2140
    if remaining_time < 0.0:
2141
      raise RetryTimeout()
2142

    
2143
    assert remaining_time >= 0.0
2144

    
2145
    if calc_delay is None:
2146
      wait_fn(remaining_time)
2147
    else:
2148
      current_delay = calc_delay()
2149
      if current_delay > 0.0:
2150
        wait_fn(current_delay)
2151

    
2152

    
2153
class FileLock(object):
2154
  """Utility class for file locks.
2155

2156
  """
2157
  def __init__(self, filename):
2158
    """Constructor for FileLock.
2159

2160
    This will open the file denoted by the I{filename} argument.
2161

2162
    @type filename: str
2163
    @param filename: path to the file to be locked
2164

2165
    """
2166
    self.filename = filename
2167
    self.fd = open(self.filename, "w")
2168

    
2169
  def __del__(self):
2170
    self.Close()
2171

    
2172
  def Close(self):
2173
    """Close the file and release the lock.
2174

2175
    """
2176
    if self.fd:
2177
      self.fd.close()
2178
      self.fd = None
2179

    
2180
  def _flock(self, flag, blocking, timeout, errmsg):
2181
    """Wrapper for fcntl.flock.
2182

2183
    @type flag: int
2184
    @param flag: operation flag
2185
    @type blocking: bool
2186
    @param blocking: whether the operation should be done in blocking mode.
2187
    @type timeout: None or float
2188
    @param timeout: for how long the operation should be retried (implies
2189
                    non-blocking mode).
2190
    @type errmsg: string
2191
    @param errmsg: error message in case operation fails.
2192

2193
    """
2194
    assert self.fd, "Lock was closed"
2195
    assert timeout is None or timeout >= 0, \
2196
      "If specified, timeout must be positive"
2197

    
2198
    if timeout is not None:
2199
      flag |= fcntl.LOCK_NB
2200
      timeout_end = time.time() + timeout
2201

    
2202
    # Blocking doesn't have effect with timeout
2203
    elif not blocking:
2204
      flag |= fcntl.LOCK_NB
2205
      timeout_end = None
2206

    
2207
    # TODO: Convert to utils.Retry
2208

    
2209
    retry = True
2210
    while retry:
2211
      try:
2212
        fcntl.flock(self.fd, flag)
2213
        retry = False
2214
      except IOError, err:
2215
        if err.errno in (errno.EAGAIN, ):
2216
          if timeout_end is not None and time.time() < timeout_end:
2217
            # Wait before trying again
2218
            time.sleep(max(0.1, min(1.0, timeout)))
2219
          else:
2220
            raise errors.LockError(errmsg)
2221
        else:
2222
          logging.exception("fcntl.flock failed")
2223
          raise
2224

    
2225
  def Exclusive(self, blocking=False, timeout=None):
2226
    """Locks the file in exclusive mode.
2227

2228
    @type blocking: boolean
2229
    @param blocking: whether to block and wait until we
2230
        can lock the file or return immediately
2231
    @type timeout: int or None
2232
    @param timeout: if not None, the duration to wait for the lock
2233
        (in blocking mode)
2234

2235
    """
2236
    self._flock(fcntl.LOCK_EX, blocking, timeout,
2237
                "Failed to lock %s in exclusive mode" % self.filename)
2238

    
2239
  def Shared(self, blocking=False, timeout=None):
2240
    """Locks the file in shared mode.
2241

2242
    @type blocking: boolean
2243
    @param blocking: whether to block and wait until we
2244
        can lock the file or return immediately
2245
    @type timeout: int or None
2246
    @param timeout: if not None, the duration to wait for the lock
2247
        (in blocking mode)
2248

2249
    """
2250
    self._flock(fcntl.LOCK_SH, blocking, timeout,
2251
                "Failed to lock %s in shared mode" % self.filename)
2252

    
2253
  def Unlock(self, blocking=True, timeout=None):
2254
    """Unlocks the file.
2255

2256
    According to C{flock(2)}, unlocking can also be a nonblocking
2257
    operation::
2258

2259
      To make a non-blocking request, include LOCK_NB with any of the above
2260
      operations.
2261

2262
    @type blocking: boolean
2263
    @param blocking: whether to block and wait until we
2264
        can lock the file or return immediately
2265
    @type timeout: int or None
2266
    @param timeout: if not None, the duration to wait for the lock
2267
        (in blocking mode)
2268

2269
    """
2270
    self._flock(fcntl.LOCK_UN, blocking, timeout,
2271
                "Failed to unlock %s" % self.filename)
2272

    
2273

    
2274
def SignalHandled(signums):
2275
  """Signal Handled decoration.
2276

2277
  This special decorator installs a signal handler and then calls the target
2278
  function. The function must accept a 'signal_handlers' keyword argument,
2279
  which will contain a dict indexed by signal number, with SignalHandler
2280
  objects as values.
2281

2282
  The decorator can be safely stacked with iself, to handle multiple signals
2283
  with different handlers.
2284

2285
  @type signums: list
2286
  @param signums: signals to intercept
2287

2288
  """
2289
  def wrap(fn):
2290
    def sig_function(*args, **kwargs):
2291
      assert 'signal_handlers' not in kwargs or \
2292
             kwargs['signal_handlers'] is None or \
2293
             isinstance(kwargs['signal_handlers'], dict), \
2294
             "Wrong signal_handlers parameter in original function call"
2295
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
2296
        signal_handlers = kwargs['signal_handlers']
2297
      else:
2298
        signal_handlers = {}
2299
        kwargs['signal_handlers'] = signal_handlers
2300
      sighandler = SignalHandler(signums)
2301
      try:
2302
        for sig in signums:
2303
          signal_handlers[sig] = sighandler
2304
        return fn(*args, **kwargs)
2305
      finally:
2306
        sighandler.Reset()
2307
    return sig_function
2308
  return wrap
2309

    
2310

    
2311
class SignalHandler(object):
2312
  """Generic signal handler class.
2313

2314
  It automatically restores the original handler when deconstructed or
2315
  when L{Reset} is called. You can either pass your own handler
2316
  function in or query the L{called} attribute to detect whether the
2317
  signal was sent.
2318

2319
  @type signum: list
2320
  @ivar signum: the signals we handle
2321
  @type called: boolean
2322
  @ivar called: tracks whether any of the signals have been raised
2323

2324
  """
2325
  def __init__(self, signum):
2326
    """Constructs a new SignalHandler instance.
2327

2328
    @type signum: int or list of ints
2329
    @param signum: Single signal number or set of signal numbers
2330

2331
    """
2332
    self.signum = set(signum)
2333
    self.called = False
2334

    
2335
    self._previous = {}
2336
    try:
2337
      for signum in self.signum:
2338
        # Setup handler
2339
        prev_handler = signal.signal(signum, self._HandleSignal)
2340
        try:
2341
          self._previous[signum] = prev_handler
2342
        except:
2343
          # Restore previous handler
2344
          signal.signal(signum, prev_handler)
2345
          raise
2346
    except:
2347
      # Reset all handlers
2348
      self.Reset()
2349
      # Here we have a race condition: a handler may have already been called,
2350
      # but there's not much we can do about it at this point.
2351
      raise
2352

    
2353
  def __del__(self):
2354
    self.Reset()
2355

    
2356
  def Reset(self):
2357
    """Restore previous handler.
2358

2359
    This will reset all the signals to their previous handlers.
2360

2361
    """
2362
    for signum, prev_handler in self._previous.items():
2363
      signal.signal(signum, prev_handler)
2364
      # If successful, remove from dict
2365
      del self._previous[signum]
2366

    
2367
  def Clear(self):
2368
    """Unsets the L{called} flag.
2369

2370
    This function can be used in case a signal may arrive several times.
2371

2372
    """
2373
    self.called = False
2374

    
2375
  def _HandleSignal(self, signum, frame):
2376
    """Actual signal handling function.
2377

2378
    """
2379
    # This is not nice and not absolutely atomic, but it appears to be the only
2380
    # solution in Python -- there are no atomic types.
2381
    self.called = True
2382

    
2383

    
2384
class FieldSet(object):
2385
  """A simple field set.
2386

2387
  Among the features are:
2388
    - checking if a string is among a list of static string or regex objects
2389
    - checking if a whole list of string matches
2390
    - returning the matching groups from a regex match
2391

2392
  Internally, all fields are held as regular expression objects.
2393

2394
  """
2395
  def __init__(self, *items):
2396
    self.items = [re.compile("^%s$" % value) for value in items]
2397

    
2398
  def Extend(self, other_set):
2399
    """Extend the field set with the items from another one"""
2400
    self.items.extend(other_set.items)
2401

    
2402
  def Matches(self, field):
2403
    """Checks if a field matches the current set
2404

2405
    @type field: str
2406
    @param field: the string to match
2407
    @return: either False or a regular expression match object
2408

2409
    """
2410
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
2411
      return m
2412
    return False
2413

    
2414
  def NonMatching(self, items):
2415
    """Returns the list of fields not matching the current set
2416

2417
    @type items: list
2418
    @param items: the list of fields to check
2419
    @rtype: list
2420
    @return: list of non-matching fields
2421

2422
    """
2423
    return [val for val in items if not self.Matches(val)]