Statistics
| Branch: | Tag: | Revision:

root / lib / utils.py @ 6c881c52

History | View | Annotate | Download (63.6 kB)

1
#
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import os
31
import time
32
import subprocess
33
import re
34
import socket
35
import tempfile
36
import shutil
37
import errno
38
import pwd
39
import itertools
40
import select
41
import fcntl
42
import resource
43
import logging
44
import signal
45

    
46
from cStringIO import StringIO
47

    
48
try:
49
  from hashlib import sha1
50
except ImportError:
51
  import sha
52
  sha1 = sha.new
53

    
54
from ganeti import errors
55
from ganeti import constants
56

    
57

    
58
_locksheld = []
59
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
60

    
61
debug_locks = False
62

    
63
#: when set to True, L{RunCmd} is disabled
64
no_fork = False
65

    
66
_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
67

    
68

    
69
class RunResult(object):
70
  """Holds the result of running external programs.
71

72
  @type exit_code: int
73
  @ivar exit_code: the exit code of the program, or None (if the program
74
      didn't exit())
75
  @type signal: int or None
76
  @ivar signal: the signal that caused the program to finish, or None
77
      (if the program wasn't terminated by a signal)
78
  @type stdout: str
79
  @ivar stdout: the standard output of the program
80
  @type stderr: str
81
  @ivar stderr: the standard error of the program
82
  @type failed: boolean
83
  @ivar failed: True in case the program was
84
      terminated by a signal or exited with a non-zero exit code
85
  @ivar fail_reason: a string detailing the termination reason
86

87
  """
88
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
89
               "failed", "fail_reason", "cmd"]
90

    
91

    
92
  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
93
    self.cmd = cmd
94
    self.exit_code = exit_code
95
    self.signal = signal_
96
    self.stdout = stdout
97
    self.stderr = stderr
98
    self.failed = (signal_ is not None or exit_code != 0)
99

    
100
    if self.signal is not None:
101
      self.fail_reason = "terminated by signal %s" % self.signal
102
    elif self.exit_code is not None:
103
      self.fail_reason = "exited with exit code %s" % self.exit_code
104
    else:
105
      self.fail_reason = "unable to determine termination reason"
106

    
107
    if self.failed:
108
      logging.debug("Command '%s' failed (%s); output: %s",
109
                    self.cmd, self.fail_reason, self.output)
110

    
111
  def _GetOutput(self):
112
    """Returns the combined stdout and stderr for easier usage.
113

114
    """
115
    return self.stdout + self.stderr
116

    
117
  output = property(_GetOutput, None, None, "Return full output")
118

    
119

    
120
def RunCmd(cmd, env=None, output=None, cwd='/'):
121
  """Execute a (shell) command.
122

123
  The command should not read from its standard input, as it will be
124
  closed.
125

126
  @type  cmd: string or list
127
  @param cmd: Command to run
128
  @type env: dict
129
  @param env: Additional environment
130
  @type output: str
131
  @param output: if desired, the output of the command can be
132
      saved in a file instead of the RunResult instance; this
133
      parameter denotes the file name (if not None)
134
  @type cwd: string
135
  @param cwd: if specified, will be used as the working
136
      directory for the command; the default will be /
137
  @rtype: L{RunResult}
138
  @return: RunResult instance
139
  @raise errors.ProgrammerError: if we call this when forks are disabled
140

141
  """
142
  if no_fork:
143
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")
144

    
145
  if isinstance(cmd, list):
146
    cmd = [str(val) for val in cmd]
147
    strcmd = " ".join(cmd)
148
    shell = False
149
  else:
150
    strcmd = cmd
151
    shell = True
152
  logging.debug("RunCmd '%s'", strcmd)
153

    
154
  cmd_env = os.environ.copy()
155
  cmd_env["LC_ALL"] = "C"
156
  if env is not None:
157
    cmd_env.update(env)
158

    
159
  try:
160
    if output is None:
161
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
162
    else:
163
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
164
      out = err = ""
165
  except OSError, err:
166
    if err.errno == errno.ENOENT:
167
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
168
                               (strcmd, err))
169
    else:
170
      raise
171

    
172
  if status >= 0:
173
    exitcode = status
174
    signal_ = None
175
  else:
176
    exitcode = None
177
    signal_ = -status
178

    
179
  return RunResult(exitcode, signal_, out, err, strcmd)
180

    
181

    
182
def _RunCmdPipe(cmd, env, via_shell, cwd):
183
  """Run a command and return its output.
184

185
  @type  cmd: string or list
186
  @param cmd: Command to run
187
  @type env: dict
188
  @param env: The environment to use
189
  @type via_shell: bool
190
  @param via_shell: if we should run via the shell
191
  @type cwd: string
192
  @param cwd: the working directory for the program
193
  @rtype: tuple
194
  @return: (out, err, status)
195

196
  """
197
  poller = select.poll()
198
  child = subprocess.Popen(cmd, shell=via_shell,
199
                           stderr=subprocess.PIPE,
200
                           stdout=subprocess.PIPE,
201
                           stdin=subprocess.PIPE,
202
                           close_fds=True, env=env,
203
                           cwd=cwd)
204

    
205
  child.stdin.close()
206
  poller.register(child.stdout, select.POLLIN)
207
  poller.register(child.stderr, select.POLLIN)
208
  out = StringIO()
209
  err = StringIO()
210
  fdmap = {
211
    child.stdout.fileno(): (out, child.stdout),
212
    child.stderr.fileno(): (err, child.stderr),
213
    }
214
  for fd in fdmap:
215
    status = fcntl.fcntl(fd, fcntl.F_GETFL)
216
    fcntl.fcntl(fd, fcntl.F_SETFL, status | os.O_NONBLOCK)
217

    
218
  while fdmap:
219
    try:
220
      pollresult = poller.poll()
221
    except EnvironmentError, eerr:
222
      if eerr.errno == errno.EINTR:
223
        continue
224
      raise
225
    except select.error, serr:
226
      if serr[0] == errno.EINTR:
227
        continue
228
      raise
229

    
230
    for fd, event in pollresult:
231
      if event & select.POLLIN or event & select.POLLPRI:
232
        data = fdmap[fd][1].read()
233
        # no data from read signifies EOF (the same as POLLHUP)
234
        if not data:
235
          poller.unregister(fd)
236
          del fdmap[fd]
237
          continue
238
        fdmap[fd][0].write(data)
239
      if (event & select.POLLNVAL or event & select.POLLHUP or
240
          event & select.POLLERR):
241
        poller.unregister(fd)
242
        del fdmap[fd]
243

    
244
  out = out.getvalue()
245
  err = err.getvalue()
246

    
247
  status = child.wait()
248
  return out, err, status
249

    
250

    
251
def _RunCmdFile(cmd, env, via_shell, output, cwd):
252
  """Run a command and save its output to a file.
253

254
  @type  cmd: string or list
255
  @param cmd: Command to run
256
  @type env: dict
257
  @param env: The environment to use
258
  @type via_shell: bool
259
  @param via_shell: if we should run via the shell
260
  @type output: str
261
  @param output: the filename in which to save the output
262
  @type cwd: string
263
  @param cwd: the working directory for the program
264
  @rtype: int
265
  @return: the exit status
266

267
  """
268
  fh = open(output, "a")
269
  try:
270
    child = subprocess.Popen(cmd, shell=via_shell,
271
                             stderr=subprocess.STDOUT,
272
                             stdout=fh,
273
                             stdin=subprocess.PIPE,
274
                             close_fds=True, env=env,
275
                             cwd=cwd)
276

    
277
    child.stdin.close()
278
    status = child.wait()
279
  finally:
280
    fh.close()
281
  return status
282

    
283

    
284
def RemoveFile(filename):
285
  """Remove a file ignoring some errors.
286

287
  Remove a file, ignoring non-existing ones or directories. Other
288
  errors are passed.
289

290
  @type filename: str
291
  @param filename: the file to be removed
292

293
  """
294
  try:
295
    os.unlink(filename)
296
  except OSError, err:
297
    if err.errno not in (errno.ENOENT, errno.EISDIR):
298
      raise
299

    
300

    
301
def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
302
  """Renames a file.
303

304
  @type old: string
305
  @param old: Original path
306
  @type new: string
307
  @param new: New path
308
  @type mkdir: bool
309
  @param mkdir: Whether to create target directory if it doesn't exist
310
  @type mkdir_mode: int
311
  @param mkdir_mode: Mode for newly created directories
312

313
  """
314
  try:
315
    return os.rename(old, new)
316
  except OSError, err:
317
    # In at least one use case of this function, the job queue, directory
318
    # creation is very rare. Checking for the directory before renaming is not
319
    # as efficient.
320
    if mkdir and err.errno == errno.ENOENT:
321
      # Create directory and try again
322
      os.makedirs(os.path.dirname(new), mkdir_mode)
323
      return os.rename(old, new)
324
    raise
325

    
326

    
327
def _FingerprintFile(filename):
328
  """Compute the fingerprint of a file.
329

330
  If the file does not exist, a None will be returned
331
  instead.
332

333
  @type filename: str
334
  @param filename: the filename to checksum
335
  @rtype: str
336
  @return: the hex digest of the sha checksum of the contents
337
      of the file
338

339
  """
340
  if not (os.path.exists(filename) and os.path.isfile(filename)):
341
    return None
342

    
343
  f = open(filename)
344

    
345
  fp = sha1()
346
  while True:
347
    data = f.read(4096)
348
    if not data:
349
      break
350

    
351
    fp.update(data)
352

    
353
  return fp.hexdigest()
354

    
355

    
356
def FingerprintFiles(files):
357
  """Compute fingerprints for a list of files.
358

359
  @type files: list
360
  @param files: the list of filename to fingerprint
361
  @rtype: dict
362
  @return: a dictionary filename: fingerprint, holding only
363
      existing files
364

365
  """
366
  ret = {}
367

    
368
  for filename in files:
369
    cksum = _FingerprintFile(filename)
370
    if cksum:
371
      ret[filename] = cksum
372

    
373
  return ret
374

    
375

    
376
def ForceDictType(target, key_types, allowed_values=None):
377
  """Force the values of a dict to have certain types.
378

379
  @type target: dict
380
  @param target: the dict to update
381
  @type key_types: dict
382
  @param key_types: dict mapping target dict keys to types
383
                    in constants.ENFORCEABLE_TYPES
384
  @type allowed_values: list
385
  @keyword allowed_values: list of specially allowed values
386

387
  """
388
  if allowed_values is None:
389
    allowed_values = []
390

    
391
  if not isinstance(target, dict):
392
    msg = "Expected dictionary, got '%s'" % target
393
    raise errors.TypeEnforcementError(msg)
394

    
395
  for key in target:
396
    if key not in key_types:
397
      msg = "Unknown key '%s'" % key
398
      raise errors.TypeEnforcementError(msg)
399

    
400
    if target[key] in allowed_values:
401
      continue
402

    
403
    ktype = key_types[key]
404
    if ktype not in constants.ENFORCEABLE_TYPES:
405
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
406
      raise errors.ProgrammerError(msg)
407

    
408
    if ktype == constants.VTYPE_STRING:
409
      if not isinstance(target[key], basestring):
410
        if isinstance(target[key], bool) and not target[key]:
411
          target[key] = ''
412
        else:
413
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
414
          raise errors.TypeEnforcementError(msg)
415
    elif ktype == constants.VTYPE_BOOL:
416
      if isinstance(target[key], basestring) and target[key]:
417
        if target[key].lower() == constants.VALUE_FALSE:
418
          target[key] = False
419
        elif target[key].lower() == constants.VALUE_TRUE:
420
          target[key] = True
421
        else:
422
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
423
          raise errors.TypeEnforcementError(msg)
424
      elif target[key]:
425
        target[key] = True
426
      else:
427
        target[key] = False
428
    elif ktype == constants.VTYPE_SIZE:
429
      try:
430
        target[key] = ParseUnit(target[key])
431
      except errors.UnitParseError, err:
432
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
433
              (key, target[key], err)
434
        raise errors.TypeEnforcementError(msg)
435
    elif ktype == constants.VTYPE_INT:
436
      try:
437
        target[key] = int(target[key])
438
      except (ValueError, TypeError):
439
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
440
        raise errors.TypeEnforcementError(msg)
441

    
442

    
443
def IsProcessAlive(pid):
444
  """Check if a given pid exists on the system.
445

446
  @note: zombie status is not handled, so zombie processes
447
      will be returned as alive
448
  @type pid: int
449
  @param pid: the process ID to check
450
  @rtype: boolean
451
  @return: True if the process exists
452

453
  """
454
  if pid <= 0:
455
    return False
456

    
457
  try:
458
    os.stat("/proc/%d/status" % pid)
459
    return True
460
  except EnvironmentError, err:
461
    if err.errno in (errno.ENOENT, errno.ENOTDIR):
462
      return False
463
    raise
464

    
465

    
466
def ReadPidFile(pidfile):
467
  """Read a pid from a file.
468

469
  @type  pidfile: string
470
  @param pidfile: path to the file containing the pid
471
  @rtype: int
472
  @return: The process id, if the file exists and contains a valid PID,
473
           otherwise 0
474

475
  """
476
  try:
477
    raw_data = ReadFile(pidfile)
478
  except EnvironmentError, err:
479
    if err.errno != errno.ENOENT:
480
      logging.exception("Can't read pid file")
481
    return 0
482

    
483
  try:
484
    pid = int(raw_data)
485
  except ValueError, err:
486
    logging.info("Can't parse pid file contents", exc_info=True)
487
    return 0
488

    
489
  return pid
490

    
491

    
492
def MatchNameComponent(key, name_list, case_sensitive=True):
493
  """Try to match a name against a list.
494

495
  This function will try to match a name like test1 against a list
496
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
497
  this list, I{'test1'} as well as I{'test1.example'} will match, but
498
  not I{'test1.ex'}. A multiple match will be considered as no match
499
  at all (e.g. I{'test1'} against C{['test1.example.com',
500
  'test1.example.org']}), except when the key fully matches an entry
501
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
502

503
  @type key: str
504
  @param key: the name to be searched
505
  @type name_list: list
506
  @param name_list: the list of strings against which to search the key
507
  @type case_sensitive: boolean
508
  @param case_sensitive: whether to provide a case-sensitive match
509

510
  @rtype: None or str
511
  @return: None if there is no match I{or} if there are multiple matches,
512
      otherwise the element from the list which matches
513

514
  """
515
  if key in name_list:
516
    return key
517

    
518
  re_flags = 0
519
  if not case_sensitive:
520
    re_flags |= re.IGNORECASE
521
    key = key.upper()
522
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
523
  names_filtered = []
524
  string_matches = []
525
  for name in name_list:
526
    if mo.match(name) is not None:
527
      names_filtered.append(name)
528
      if not case_sensitive and key == name.upper():
529
        string_matches.append(name)
530

    
531
  if len(string_matches) == 1:
532
    return string_matches[0]
533
  if len(names_filtered) == 1:
534
    return names_filtered[0]
535
  return None
536

    
537

    
538
class HostInfo:
539
  """Class implementing resolver and hostname functionality
540

541
  """
542
  def __init__(self, name=None):
543
    """Initialize the host name object.
544

545
    If the name argument is not passed, it will use this system's
546
    name.
547

548
    """
549
    if name is None:
550
      name = self.SysName()
551

    
552
    self.query = name
553
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
554
    self.ip = self.ipaddrs[0]
555

    
556
  def ShortName(self):
557
    """Returns the hostname without domain.
558

559
    """
560
    return self.name.split('.')[0]
561

    
562
  @staticmethod
563
  def SysName():
564
    """Return the current system's name.
565

566
    This is simply a wrapper over C{socket.gethostname()}.
567

568
    """
569
    return socket.gethostname()
570

    
571
  @staticmethod
572
  def LookupHostname(hostname):
573
    """Look up hostname
574

575
    @type hostname: str
576
    @param hostname: hostname to look up
577

578
    @rtype: tuple
579
    @return: a tuple (name, aliases, ipaddrs) as returned by
580
        C{socket.gethostbyname_ex}
581
    @raise errors.ResolverError: in case of errors in resolving
582

583
    """
584
    try:
585
      result = socket.gethostbyname_ex(hostname)
586
    except socket.gaierror, err:
587
      # hostname not found in DNS
588
      raise errors.ResolverError(hostname, err.args[0], err.args[1])
589

    
590
    return result
591

    
592

    
593
def GetHostInfo(name=None):
594
  """Lookup host name and raise an OpPrereqError for failures"""
595

    
596
  try:
597
    return HostInfo(name)
598
  except errors.ResolverError, err:
599
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
600
                               (err[0], err[2]), errors.ECODE_RESOLVER)
601

    
602

    
603
def ListVolumeGroups():
604
  """List volume groups and their size
605

606
  @rtype: dict
607
  @return:
608
       Dictionary with keys volume name and values
609
       the size of the volume
610

611
  """
612
  command = "vgs --noheadings --units m --nosuffix -o name,size"
613
  result = RunCmd(command)
614
  retval = {}
615
  if result.failed:
616
    return retval
617

    
618
  for line in result.stdout.splitlines():
619
    try:
620
      name, size = line.split()
621
      size = int(float(size))
622
    except (IndexError, ValueError), err:
623
      logging.error("Invalid output from vgs (%s): %s", err, line)
624
      continue
625

    
626
    retval[name] = size
627

    
628
  return retval
629

    
630

    
631
def BridgeExists(bridge):
632
  """Check whether the given bridge exists in the system
633

634
  @type bridge: str
635
  @param bridge: the bridge name to check
636
  @rtype: boolean
637
  @return: True if it does
638

639
  """
640
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
641

    
642

    
643
def NiceSort(name_list):
644
  """Sort a list of strings based on digit and non-digit groupings.
645

646
  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
647
  will sort the list in the logical order C{['a1', 'a2', 'a10',
648
  'a11']}.
649

650
  The sort algorithm breaks each name in groups of either only-digits
651
  or no-digits. Only the first eight such groups are considered, and
652
  after that we just use what's left of the string.
653

654
  @type name_list: list
655
  @param name_list: the names to be sorted
656
  @rtype: list
657
  @return: a copy of the name list sorted with our algorithm
658

659
  """
660
  _SORTER_BASE = "(\D+|\d+)"
661
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
662
                                                  _SORTER_BASE, _SORTER_BASE,
663
                                                  _SORTER_BASE, _SORTER_BASE,
664
                                                  _SORTER_BASE, _SORTER_BASE)
665
  _SORTER_RE = re.compile(_SORTER_FULL)
666
  _SORTER_NODIGIT = re.compile("^\D*$")
667
  def _TryInt(val):
668
    """Attempts to convert a variable to integer."""
669
    if val is None or _SORTER_NODIGIT.match(val):
670
      return val
671
    rval = int(val)
672
    return rval
673

    
674
  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
675
             for name in name_list]
676
  to_sort.sort()
677
  return [tup[1] for tup in to_sort]
678

    
679

    
680
def TryConvert(fn, val):
681
  """Try to convert a value ignoring errors.
682

683
  This function tries to apply function I{fn} to I{val}. If no
684
  C{ValueError} or C{TypeError} exceptions are raised, it will return
685
  the result, else it will return the original value. Any other
686
  exceptions are propagated to the caller.
687

688
  @type fn: callable
689
  @param fn: function to apply to the value
690
  @param val: the value to be converted
691
  @return: The converted value if the conversion was successful,
692
      otherwise the original value.
693

694
  """
695
  try:
696
    nv = fn(val)
697
  except (ValueError, TypeError):
698
    nv = val
699
  return nv
700

    
701

    
702
def IsValidIP(ip):
703
  """Verifies the syntax of an IPv4 address.
704

705
  This function checks if the IPv4 address passes is valid or not based
706
  on syntax (not IP range, class calculations, etc.).
707

708
  @type ip: str
709
  @param ip: the address to be checked
710
  @rtype: a regular expression match object
711
  @return: a regular expression match object, or None if the
712
      address is not valid
713

714
  """
715
  unit = "(0|[1-9]\d{0,2})"
716
  #TODO: convert and return only boolean
717
  return re.match("^%s\.%s\.%s\.%s$" % (unit, unit, unit, unit), ip)
718

    
719

    
720
def IsValidShellParam(word):
721
  """Verifies is the given word is safe from the shell's p.o.v.
722

723
  This means that we can pass this to a command via the shell and be
724
  sure that it doesn't alter the command line and is passed as such to
725
  the actual command.
726

727
  Note that we are overly restrictive here, in order to be on the safe
728
  side.
729

730
  @type word: str
731
  @param word: the word to check
732
  @rtype: boolean
733
  @return: True if the word is 'safe'
734

735
  """
736
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))
737

    
738

    
739
def BuildShellCmd(template, *args):
740
  """Build a safe shell command line from the given arguments.
741

742
  This function will check all arguments in the args list so that they
743
  are valid shell parameters (i.e. they don't contain shell
744
  metacharacters). If everything is ok, it will return the result of
745
  template % args.
746

747
  @type template: str
748
  @param template: the string holding the template for the
749
      string formatting
750
  @rtype: str
751
  @return: the expanded command line
752

753
  """
754
  for word in args:
755
    if not IsValidShellParam(word):
756
      raise errors.ProgrammerError("Shell argument '%s' contains"
757
                                   " invalid characters" % word)
758
  return template % args
759

    
760

    
761
def FormatUnit(value, units):
762
  """Formats an incoming number of MiB with the appropriate unit.
763

764
  @type value: int
765
  @param value: integer representing the value in MiB (1048576)
766
  @type units: char
767
  @param units: the type of formatting we should do:
768
      - 'h' for automatic scaling
769
      - 'm' for MiBs
770
      - 'g' for GiBs
771
      - 't' for TiBs
772
  @rtype: str
773
  @return: the formatted value (with suffix)
774

775
  """
776
  if units not in ('m', 'g', 't', 'h'):
777
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))
778

    
779
  suffix = ''
780

    
781
  if units == 'm' or (units == 'h' and value < 1024):
782
    if units == 'h':
783
      suffix = 'M'
784
    return "%d%s" % (round(value, 0), suffix)
785

    
786
  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
787
    if units == 'h':
788
      suffix = 'G'
789
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)
790

    
791
  else:
792
    if units == 'h':
793
      suffix = 'T'
794
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)
795

    
796

    
797
def ParseUnit(input_string):
798
  """Tries to extract number and scale from the given string.
799

800
  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
801
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
802
  is always an int in MiB.
803

804
  """
805
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
806
  if not m:
807
    raise errors.UnitParseError("Invalid format")
808

    
809
  value = float(m.groups()[0])
810

    
811
  unit = m.groups()[1]
812
  if unit:
813
    lcunit = unit.lower()
814
  else:
815
    lcunit = 'm'
816

    
817
  if lcunit in ('m', 'mb', 'mib'):
818
    # Value already in MiB
819
    pass
820

    
821
  elif lcunit in ('g', 'gb', 'gib'):
822
    value *= 1024
823

    
824
  elif lcunit in ('t', 'tb', 'tib'):
825
    value *= 1024 * 1024
826

    
827
  else:
828
    raise errors.UnitParseError("Unknown unit: %s" % unit)
829

    
830
  # Make sure we round up
831
  if int(value) < value:
832
    value += 1
833

    
834
  # Round up to the next multiple of 4
835
  value = int(value)
836
  if value % 4:
837
    value += 4 - value % 4
838

    
839
  return value
840

    
841

    
842
def AddAuthorizedKey(file_name, key):
843
  """Adds an SSH public key to an authorized_keys file.
844

845
  @type file_name: str
846
  @param file_name: path to authorized_keys file
847
  @type key: str
848
  @param key: string containing key
849

850
  """
851
  key_fields = key.split()
852

    
853
  f = open(file_name, 'a+')
854
  try:
855
    nl = True
856
    for line in f:
857
      # Ignore whitespace changes
858
      if line.split() == key_fields:
859
        break
860
      nl = line.endswith('\n')
861
    else:
862
      if not nl:
863
        f.write("\n")
864
      f.write(key.rstrip('\r\n'))
865
      f.write("\n")
866
      f.flush()
867
  finally:
868
    f.close()
869

    
870

    
871
def RemoveAuthorizedKey(file_name, key):
872
  """Removes an SSH public key from an authorized_keys file.
873

874
  @type file_name: str
875
  @param file_name: path to authorized_keys file
876
  @type key: str
877
  @param key: string containing key
878

879
  """
880
  key_fields = key.split()
881

    
882
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
883
  try:
884
    out = os.fdopen(fd, 'w')
885
    try:
886
      f = open(file_name, 'r')
887
      try:
888
        for line in f:
889
          # Ignore whitespace changes while comparing lines
890
          if line.split() != key_fields:
891
            out.write(line)
892

    
893
        out.flush()
894
        os.rename(tmpname, file_name)
895
      finally:
896
        f.close()
897
    finally:
898
      out.close()
899
  except:
900
    RemoveFile(tmpname)
901
    raise
902

    
903

    
904
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
905
  """Sets the name of an IP address and hostname in /etc/hosts.
906

907
  @type file_name: str
908
  @param file_name: path to the file to modify (usually C{/etc/hosts})
909
  @type ip: str
910
  @param ip: the IP address
911
  @type hostname: str
912
  @param hostname: the hostname to be added
913
  @type aliases: list
914
  @param aliases: the list of aliases to add for the hostname
915

916
  """
917
  # FIXME: use WriteFile + fn rather than duplicating its efforts
918
  # Ensure aliases are unique
919
  aliases = UniqueSequence([hostname] + aliases)[1:]
920

    
921
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
922
  try:
923
    out = os.fdopen(fd, 'w')
924
    try:
925
      f = open(file_name, 'r')
926
      try:
927
        for line in f:
928
          fields = line.split()
929
          if fields and not fields[0].startswith('#') and ip == fields[0]:
930
            continue
931
          out.write(line)
932

    
933
        out.write("%s\t%s" % (ip, hostname))
934
        if aliases:
935
          out.write(" %s" % ' '.join(aliases))
936
        out.write('\n')
937

    
938
        out.flush()
939
        os.fsync(out)
940
        os.chmod(tmpname, 0644)
941
        os.rename(tmpname, file_name)
942
      finally:
943
        f.close()
944
    finally:
945
      out.close()
946
  except:
947
    RemoveFile(tmpname)
948
    raise
949

    
950

    
951
def AddHostToEtcHosts(hostname):
952
  """Wrapper around SetEtcHostsEntry.
953

954
  @type hostname: str
955
  @param hostname: a hostname that will be resolved and added to
956
      L{constants.ETC_HOSTS}
957

958
  """
959
  hi = HostInfo(name=hostname)
960
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])
961

    
962

    
963
def RemoveEtcHostsEntry(file_name, hostname):
964
  """Removes a hostname from /etc/hosts.
965

966
  IP addresses without names are removed from the file.
967

968
  @type file_name: str
969
  @param file_name: path to the file to modify (usually C{/etc/hosts})
970
  @type hostname: str
971
  @param hostname: the hostname to be removed
972

973
  """
974
  # FIXME: use WriteFile + fn rather than duplicating its efforts
975
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
976
  try:
977
    out = os.fdopen(fd, 'w')
978
    try:
979
      f = open(file_name, 'r')
980
      try:
981
        for line in f:
982
          fields = line.split()
983
          if len(fields) > 1 and not fields[0].startswith('#'):
984
            names = fields[1:]
985
            if hostname in names:
986
              while hostname in names:
987
                names.remove(hostname)
988
              if names:
989
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
990
              continue
991

    
992
          out.write(line)
993

    
994
        out.flush()
995
        os.fsync(out)
996
        os.chmod(tmpname, 0644)
997
        os.rename(tmpname, file_name)
998
      finally:
999
        f.close()
1000
    finally:
1001
      out.close()
1002
  except:
1003
    RemoveFile(tmpname)
1004
    raise
1005

    
1006

    
1007
def RemoveHostFromEtcHosts(hostname):
1008
  """Wrapper around RemoveEtcHostsEntry.
1009

1010
  @type hostname: str
1011
  @param hostname: hostname that will be resolved and its
1012
      full and shot name will be removed from
1013
      L{constants.ETC_HOSTS}
1014

1015
  """
1016
  hi = HostInfo(name=hostname)
1017
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
1018
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())
1019

    
1020

    
1021
def CreateBackup(file_name):
1022
  """Creates a backup of a file.
1023

1024
  @type file_name: str
1025
  @param file_name: file to be backed up
1026
  @rtype: str
1027
  @return: the path to the newly created backup
1028
  @raise errors.ProgrammerError: for invalid file names
1029

1030
  """
1031
  if not os.path.isfile(file_name):
1032
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
1033
                                file_name)
1034

    
1035
  prefix = '%s.backup-%d.' % (os.path.basename(file_name), int(time.time()))
1036
  dir_name = os.path.dirname(file_name)
1037

    
1038
  fsrc = open(file_name, 'rb')
1039
  try:
1040
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
1041
    fdst = os.fdopen(fd, 'wb')
1042
    try:
1043
      shutil.copyfileobj(fsrc, fdst)
1044
    finally:
1045
      fdst.close()
1046
  finally:
1047
    fsrc.close()
1048

    
1049
  return backup_name
1050

    
1051

    
1052
def ShellQuote(value):
1053
  """Quotes shell argument according to POSIX.
1054

1055
  @type value: str
1056
  @param value: the argument to be quoted
1057
  @rtype: str
1058
  @return: the quoted value
1059

1060
  """
1061
  if _re_shell_unquoted.match(value):
1062
    return value
1063
  else:
1064
    return "'%s'" % value.replace("'", "'\\''")
1065

    
1066

    
1067
def ShellQuoteArgs(args):
1068
  """Quotes a list of shell arguments.
1069

1070
  @type args: list
1071
  @param args: list of arguments to be quoted
1072
  @rtype: str
1073
  @return: the quoted arguments concatenated with spaces
1074

1075
  """
1076
  return ' '.join([ShellQuote(i) for i in args])
1077

    
1078

    
1079
def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
1080
  """Simple ping implementation using TCP connect(2).
1081

1082
  Check if the given IP is reachable by doing attempting a TCP connect
1083
  to it.
1084

1085
  @type target: str
1086
  @param target: the IP or hostname to ping
1087
  @type port: int
1088
  @param port: the port to connect to
1089
  @type timeout: int
1090
  @param timeout: the timeout on the connection attempt
1091
  @type live_port_needed: boolean
1092
  @param live_port_needed: whether a closed port will cause the
1093
      function to return failure, as if there was a timeout
1094
  @type source: str or None
1095
  @param source: if specified, will cause the connect to be made
1096
      from this specific source address; failures to bind other
1097
      than C{EADDRNOTAVAIL} will be ignored
1098

1099
  """
1100
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1101

    
1102
  success = False
1103

    
1104
  if source is not None:
1105
    try:
1106
      sock.bind((source, 0))
1107
    except socket.error, (errcode, _):
1108
      if errcode == errno.EADDRNOTAVAIL:
1109
        success = False
1110

    
1111
  sock.settimeout(timeout)
1112

    
1113
  try:
1114
    sock.connect((target, port))
1115
    sock.close()
1116
    success = True
1117
  except socket.timeout:
1118
    success = False
1119
  except socket.error, (errcode, _):
1120
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
1121

    
1122
  return success
1123

    
1124

    
1125
def OwnIpAddress(address):
1126
  """Check if the current host has the the given IP address.
1127

1128
  Currently this is done by TCP-pinging the address from the loopback
1129
  address.
1130

1131
  @type address: string
1132
  @param address: the address to check
1133
  @rtype: bool
1134
  @return: True if we own the address
1135

1136
  """
1137
  return TcpPing(address, constants.DEFAULT_NODED_PORT,
1138
                 source=constants.LOCALHOST_IP_ADDRESS)
1139

    
1140

    
1141
def ListVisibleFiles(path):
1142
  """Returns a list of visible files in a directory.
1143

1144
  @type path: str
1145
  @param path: the directory to enumerate
1146
  @rtype: list
1147
  @return: the list of all files not starting with a dot
1148

1149
  """
1150
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1151
  files.sort()
1152
  return files
1153

    
1154

    
1155
def GetHomeDir(user, default=None):
1156
  """Try to get the homedir of the given user.
1157

1158
  The user can be passed either as a string (denoting the name) or as
1159
  an integer (denoting the user id). If the user is not found, the
1160
  'default' argument is returned, which defaults to None.
1161

1162
  """
1163
  try:
1164
    if isinstance(user, basestring):
1165
      result = pwd.getpwnam(user)
1166
    elif isinstance(user, (int, long)):
1167
      result = pwd.getpwuid(user)
1168
    else:
1169
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1170
                                   type(user))
1171
  except KeyError:
1172
    return default
1173
  return result.pw_dir
1174

    
1175

    
1176
def NewUUID():
1177
  """Returns a random UUID.
1178

1179
  @note: This is a Linux-specific method as it uses the /proc
1180
      filesystem.
1181
  @rtype: str
1182

1183
  """
1184
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1185

    
1186

    
1187
def GenerateSecret(numbytes=20):
1188
  """Generates a random secret.
1189

1190
  This will generate a pseudo-random secret returning an hex string
1191
  (so that it can be used where an ASCII string is needed).
1192

1193
  @param numbytes: the number of bytes which will be represented by the returned
1194
      string (defaulting to 20, the length of a SHA1 hash)
1195
  @rtype: str
1196
  @return: an hex representation of the pseudo-random sequence
1197

1198
  """
1199
  return os.urandom(numbytes).encode('hex')
1200

    
1201

    
1202
def EnsureDirs(dirs):
1203
  """Make required directories, if they don't exist.
1204

1205
  @param dirs: list of tuples (dir_name, dir_mode)
1206
  @type dirs: list of (string, integer)
1207

1208
  """
1209
  for dir_name, dir_mode in dirs:
1210
    try:
1211
      os.mkdir(dir_name, dir_mode)
1212
    except EnvironmentError, err:
1213
      if err.errno != errno.EEXIST:
1214
        raise errors.GenericError("Cannot create needed directory"
1215
                                  " '%s': %s" % (dir_name, err))
1216
    if not os.path.isdir(dir_name):
1217
      raise errors.GenericError("%s is not a directory" % dir_name)
1218

    
1219

    
1220
def ReadFile(file_name, size=None):
1221
  """Reads a file.
1222

1223
  @type size: None or int
1224
  @param size: Read at most size bytes
1225
  @rtype: str
1226
  @return: the (possibly partial) content of the file
1227

1228
  """
1229
  f = open(file_name, "r")
1230
  try:
1231
    if size is None:
1232
      return f.read()
1233
    else:
1234
      return f.read(size)
1235
  finally:
1236
    f.close()
1237

    
1238

    
1239
def WriteFile(file_name, fn=None, data=None,
1240
              mode=None, uid=-1, gid=-1,
1241
              atime=None, mtime=None, close=True,
1242
              dry_run=False, backup=False,
1243
              prewrite=None, postwrite=None):
1244
  """(Over)write a file atomically.
1245

1246
  The file_name and either fn (a function taking one argument, the
1247
  file descriptor, and which should write the data to it) or data (the
1248
  contents of the file) must be passed. The other arguments are
1249
  optional and allow setting the file mode, owner and group, and the
1250
  mtime/atime of the file.
1251

1252
  If the function doesn't raise an exception, it has succeeded and the
1253
  target file has the new contents. If the function has raised an
1254
  exception, an existing target file should be unmodified and the
1255
  temporary file should be removed.
1256

1257
  @type file_name: str
1258
  @param file_name: the target filename
1259
  @type fn: callable
1260
  @param fn: content writing function, called with
1261
      file descriptor as parameter
1262
  @type data: str
1263
  @param data: contents of the file
1264
  @type mode: int
1265
  @param mode: file mode
1266
  @type uid: int
1267
  @param uid: the owner of the file
1268
  @type gid: int
1269
  @param gid: the group of the file
1270
  @type atime: int
1271
  @param atime: a custom access time to be set on the file
1272
  @type mtime: int
1273
  @param mtime: a custom modification time to be set on the file
1274
  @type close: boolean
1275
  @param close: whether to close file after writing it
1276
  @type prewrite: callable
1277
  @param prewrite: function to be called before writing content
1278
  @type postwrite: callable
1279
  @param postwrite: function to be called after writing content
1280

1281
  @rtype: None or int
1282
  @return: None if the 'close' parameter evaluates to True,
1283
      otherwise the file descriptor
1284

1285
  @raise errors.ProgrammerError: if any of the arguments are not valid
1286

1287
  """
1288
  if not os.path.isabs(file_name):
1289
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1290
                                 " absolute: '%s'" % file_name)
1291

    
1292
  if [fn, data].count(None) != 1:
1293
    raise errors.ProgrammerError("fn or data required")
1294

    
1295
  if [atime, mtime].count(None) == 1:
1296
    raise errors.ProgrammerError("Both atime and mtime must be either"
1297
                                 " set or None")
1298

    
1299
  if backup and not dry_run and os.path.isfile(file_name):
1300
    CreateBackup(file_name)
1301

    
1302
  dir_name, base_name = os.path.split(file_name)
1303
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1304
  do_remove = True
1305
  # here we need to make sure we remove the temp file, if any error
1306
  # leaves it in place
1307
  try:
1308
    if uid != -1 or gid != -1:
1309
      os.chown(new_name, uid, gid)
1310
    if mode:
1311
      os.chmod(new_name, mode)
1312
    if callable(prewrite):
1313
      prewrite(fd)
1314
    if data is not None:
1315
      os.write(fd, data)
1316
    else:
1317
      fn(fd)
1318
    if callable(postwrite):
1319
      postwrite(fd)
1320
    os.fsync(fd)
1321
    if atime is not None and mtime is not None:
1322
      os.utime(new_name, (atime, mtime))
1323
    if not dry_run:
1324
      os.rename(new_name, file_name)
1325
      do_remove = False
1326
  finally:
1327
    if close:
1328
      os.close(fd)
1329
      result = None
1330
    else:
1331
      result = fd
1332
    if do_remove:
1333
      RemoveFile(new_name)
1334

    
1335
  return result
1336

    
1337

    
1338
def FirstFree(seq, base=0):
1339
  """Returns the first non-existing integer from seq.
1340

1341
  The seq argument should be a sorted list of positive integers. The
1342
  first time the index of an element is smaller than the element
1343
  value, the index will be returned.
1344

1345
  The base argument is used to start at a different offset,
1346
  i.e. C{[3, 4, 6]} with I{offset=3} will return 5.
1347

1348
  Example: C{[0, 1, 3]} will return I{2}.
1349

1350
  @type seq: sequence
1351
  @param seq: the sequence to be analyzed.
1352
  @type base: int
1353
  @param base: use this value as the base index of the sequence
1354
  @rtype: int
1355
  @return: the first non-used index in the sequence
1356

1357
  """
1358
  for idx, elem in enumerate(seq):
1359
    assert elem >= base, "Passed element is higher than base offset"
1360
    if elem > idx + base:
1361
      # idx is not used
1362
      return idx + base
1363
  return None
1364

    
1365

    
1366
def all(seq, pred=bool):
1367
  "Returns True if pred(x) is True for every element in the iterable"
1368
  for _ in itertools.ifilterfalse(pred, seq):
1369
    return False
1370
  return True
1371

    
1372

    
1373
def any(seq, pred=bool):
1374
  "Returns True if pred(x) is True for at least one element in the iterable"
1375
  for _ in itertools.ifilter(pred, seq):
1376
    return True
1377
  return False
1378

    
1379

    
1380
def UniqueSequence(seq):
1381
  """Returns a list with unique elements.
1382

1383
  Element order is preserved.
1384

1385
  @type seq: sequence
1386
  @param seq: the sequence with the source elements
1387
  @rtype: list
1388
  @return: list of unique elements from seq
1389

1390
  """
1391
  seen = set()
1392
  return [i for i in seq if i not in seen and not seen.add(i)]
1393

    
1394

    
1395
def IsValidMac(mac):
1396
  """Predicate to check if a MAC address is valid.
1397

1398
  Checks whether the supplied MAC address is formally correct, only
1399
  accepts colon separated format.
1400

1401
  @type mac: str
1402
  @param mac: the MAC to be validated
1403
  @rtype: boolean
1404
  @return: True is the MAC seems valid
1405

1406
  """
1407
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$")
1408
  return mac_check.match(mac) is not None
1409

    
1410

    
1411
def TestDelay(duration):
1412
  """Sleep for a fixed amount of time.
1413

1414
  @type duration: float
1415
  @param duration: the sleep duration
1416
  @rtype: boolean
1417
  @return: False for negative value, True otherwise
1418

1419
  """
1420
  if duration < 0:
1421
    return False, "Invalid sleep duration"
1422
  time.sleep(duration)
1423
  return True, None
1424

    
1425

    
1426
def _CloseFDNoErr(fd, retries=5):
1427
  """Close a file descriptor ignoring errors.
1428

1429
  @type fd: int
1430
  @param fd: the file descriptor
1431
  @type retries: int
1432
  @param retries: how many retries to make, in case we get any
1433
      other error than EBADF
1434

1435
  """
1436
  try:
1437
    os.close(fd)
1438
  except OSError, err:
1439
    if err.errno != errno.EBADF:
1440
      if retries > 0:
1441
        _CloseFDNoErr(fd, retries - 1)
1442
    # else either it's closed already or we're out of retries, so we
1443
    # ignore this and go on
1444

    
1445

    
1446
def CloseFDs(noclose_fds=None):
1447
  """Close file descriptors.
1448

1449
  This closes all file descriptors above 2 (i.e. except
1450
  stdin/out/err).
1451

1452
  @type noclose_fds: list or None
1453
  @param noclose_fds: if given, it denotes a list of file descriptor
1454
      that should not be closed
1455

1456
  """
1457
  # Default maximum for the number of available file descriptors.
1458
  if 'SC_OPEN_MAX' in os.sysconf_names:
1459
    try:
1460
      MAXFD = os.sysconf('SC_OPEN_MAX')
1461
      if MAXFD < 0:
1462
        MAXFD = 1024
1463
    except OSError:
1464
      MAXFD = 1024
1465
  else:
1466
    MAXFD = 1024
1467
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
1468
  if (maxfd == resource.RLIM_INFINITY):
1469
    maxfd = MAXFD
1470

    
1471
  # Iterate through and close all file descriptors (except the standard ones)
1472
  for fd in range(3, maxfd):
1473
    if noclose_fds and fd in noclose_fds:
1474
      continue
1475
    _CloseFDNoErr(fd)
1476

    
1477

    
1478
def Daemonize(logfile):
1479
  """Daemonize the current process.
1480

1481
  This detaches the current process from the controlling terminal and
1482
  runs it in the background as a daemon.
1483

1484
  @type logfile: str
1485
  @param logfile: the logfile to which we should redirect stdout/stderr
1486
  @rtype: int
1487
  @return: the value zero
1488

1489
  """
1490
  UMASK = 077
1491
  WORKDIR = "/"
1492

    
1493
  # this might fail
1494
  pid = os.fork()
1495
  if (pid == 0):  # The first child.
1496
    os.setsid()
1497
    # this might fail
1498
    pid = os.fork() # Fork a second child.
1499
    if (pid == 0):  # The second child.
1500
      os.chdir(WORKDIR)
1501
      os.umask(UMASK)
1502
    else:
1503
      # exit() or _exit()?  See below.
1504
      os._exit(0) # Exit parent (the first child) of the second child.
1505
  else:
1506
    os._exit(0) # Exit parent of the first child.
1507

    
1508
  for fd in range(3):
1509
    _CloseFDNoErr(fd)
1510
  i = os.open("/dev/null", os.O_RDONLY) # stdin
1511
  assert i == 0, "Can't close/reopen stdin"
1512
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
1513
  assert i == 1, "Can't close/reopen stdout"
1514
  # Duplicate standard output to standard error.
1515
  os.dup2(1, 2)
1516
  return 0
1517

    
1518

    
1519
def DaemonPidFileName(name):
1520
  """Compute a ganeti pid file absolute path
1521

1522
  @type name: str
1523
  @param name: the daemon name
1524
  @rtype: str
1525
  @return: the full path to the pidfile corresponding to the given
1526
      daemon name
1527

1528
  """
1529
  return os.path.join(constants.RUN_GANETI_DIR, "%s.pid" % name)
1530

    
1531

    
1532
def WritePidFile(name):
1533
  """Write the current process pidfile.
1534

1535
  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}
1536

1537
  @type name: str
1538
  @param name: the daemon name to use
1539
  @raise errors.GenericError: if the pid file already exists and
1540
      points to a live process
1541

1542
  """
1543
  pid = os.getpid()
1544
  pidfilename = DaemonPidFileName(name)
1545
  if IsProcessAlive(ReadPidFile(pidfilename)):
1546
    raise errors.GenericError("%s contains a live process" % pidfilename)
1547

    
1548
  WriteFile(pidfilename, data="%d\n" % pid)
1549

    
1550

    
1551
def RemovePidFile(name):
1552
  """Remove the current process pidfile.
1553

1554
  Any errors are ignored.
1555

1556
  @type name: str
1557
  @param name: the daemon name used to derive the pidfile name
1558

1559
  """
1560
  pidfilename = DaemonPidFileName(name)
1561
  # TODO: we could check here that the file contains our pid
1562
  try:
1563
    RemoveFile(pidfilename)
1564
  except:
1565
    pass
1566

    
1567

    
1568
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
1569
                waitpid=False):
1570
  """Kill a process given by its pid.
1571

1572
  @type pid: int
1573
  @param pid: The PID to terminate.
1574
  @type signal_: int
1575
  @param signal_: The signal to send, by default SIGTERM
1576
  @type timeout: int
1577
  @param timeout: The timeout after which, if the process is still alive,
1578
                  a SIGKILL will be sent. If not positive, no such checking
1579
                  will be done
1580
  @type waitpid: boolean
1581
  @param waitpid: If true, we should waitpid on this process after
1582
      sending signals, since it's our own child and otherwise it
1583
      would remain as zombie
1584

1585
  """
1586
  def _helper(pid, signal_, wait):
1587
    """Simple helper to encapsulate the kill/waitpid sequence"""
1588
    os.kill(pid, signal_)
1589
    if wait:
1590
      try:
1591
        os.waitpid(pid, os.WNOHANG)
1592
      except OSError:
1593
        pass
1594

    
1595
  if pid <= 0:
1596
    # kill with pid=0 == suicide
1597
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
1598

    
1599
  if not IsProcessAlive(pid):
1600
    return
1601

    
1602
  _helper(pid, signal_, waitpid)
1603

    
1604
  if timeout <= 0:
1605
    return
1606

    
1607
  def _CheckProcess():
1608
    if not IsProcessAlive(pid):
1609
      return
1610

    
1611
    try:
1612
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
1613
    except OSError:
1614
      raise RetryAgain()
1615

    
1616
    if result_pid > 0:
1617
      return
1618

    
1619
    raise RetryAgain()
1620

    
1621
  try:
1622
    # Wait up to $timeout seconds
1623
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
1624
  except RetryTimeout:
1625
    pass
1626

    
1627
  if IsProcessAlive(pid):
1628
    # Kill process if it's still alive
1629
    _helper(pid, signal.SIGKILL, waitpid)
1630

    
1631

    
1632
def FindFile(name, search_path, test=os.path.exists):
1633
  """Look for a filesystem object in a given path.
1634

1635
  This is an abstract method to search for filesystem object (files,
1636
  dirs) under a given search path.
1637

1638
  @type name: str
1639
  @param name: the name to look for
1640
  @type search_path: str
1641
  @param search_path: location to start at
1642
  @type test: callable
1643
  @param test: a function taking one argument that should return True
1644
      if the a given object is valid; the default value is
1645
      os.path.exists, causing only existing files to be returned
1646
  @rtype: str or None
1647
  @return: full path to the object if found, None otherwise
1648

1649
  """
1650
  for dir_name in search_path:
1651
    item_name = os.path.sep.join([dir_name, name])
1652
    if test(item_name):
1653
      return item_name
1654
  return None
1655

    
1656

    
1657
def CheckVolumeGroupSize(vglist, vgname, minsize):
1658
  """Checks if the volume group list is valid.
1659

1660
  The function will check if a given volume group is in the list of
1661
  volume groups and has a minimum size.
1662

1663
  @type vglist: dict
1664
  @param vglist: dictionary of volume group names and their size
1665
  @type vgname: str
1666
  @param vgname: the volume group we should check
1667
  @type minsize: int
1668
  @param minsize: the minimum size we accept
1669
  @rtype: None or str
1670
  @return: None for success, otherwise the error message
1671

1672
  """
1673
  vgsize = vglist.get(vgname, None)
1674
  if vgsize is None:
1675
    return "volume group '%s' missing" % vgname
1676
  elif vgsize < minsize:
1677
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
1678
            (vgname, minsize, vgsize))
1679
  return None
1680

    
1681

    
1682
def SplitTime(value):
1683
  """Splits time as floating point number into a tuple.
1684

1685
  @param value: Time in seconds
1686
  @type value: int or float
1687
  @return: Tuple containing (seconds, microseconds)
1688

1689
  """
1690
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
1691

    
1692
  assert 0 <= seconds, \
1693
    "Seconds must be larger than or equal to 0, but are %s" % seconds
1694
  assert 0 <= microseconds <= 999999, \
1695
    "Microseconds must be 0-999999, but are %s" % microseconds
1696

    
1697
  return (int(seconds), int(microseconds))
1698

    
1699

    
1700
def MergeTime(timetuple):
1701
  """Merges a tuple into time as a floating point number.
1702

1703
  @param timetuple: Time as tuple, (seconds, microseconds)
1704
  @type timetuple: tuple
1705
  @return: Time as a floating point number expressed in seconds
1706

1707
  """
1708
  (seconds, microseconds) = timetuple
1709

    
1710
  assert 0 <= seconds, \
1711
    "Seconds must be larger than or equal to 0, but are %s" % seconds
1712
  assert 0 <= microseconds <= 999999, \
1713
    "Microseconds must be 0-999999, but are %s" % microseconds
1714

    
1715
  return float(seconds) + (float(microseconds) * 0.000001)
1716

    
1717

    
1718
def GetDaemonPort(daemon_name):
1719
  """Get the daemon port for this cluster.
1720

1721
  Note that this routine does not read a ganeti-specific file, but
1722
  instead uses C{socket.getservbyname} to allow pre-customization of
1723
  this parameter outside of Ganeti.
1724

1725
  @type daemon_name: string
1726
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
1727
  @rtype: int
1728

1729
  """
1730
  if daemon_name not in constants.DAEMONS_PORTS:
1731
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)
1732

    
1733
  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
1734
  try:
1735
    port = socket.getservbyname(daemon_name, proto)
1736
  except socket.error:
1737
    port = default_port
1738

    
1739
  return port
1740

    
1741

    
1742
def SetupLogging(logfile, debug=False, stderr_logging=False, program="",
1743
                 multithreaded=False):
1744
  """Configures the logging module.
1745

1746
  @type logfile: str
1747
  @param logfile: the filename to which we should log
1748
  @type debug: boolean
1749
  @param debug: whether to enable debug messages too or
1750
      only those at C{INFO} and above level
1751
  @type stderr_logging: boolean
1752
  @param stderr_logging: whether we should also log to the standard error
1753
  @type program: str
1754
  @param program: the name under which we should log messages
1755
  @type multithreaded: boolean
1756
  @param multithreaded: if True, will add the thread name to the log file
1757
  @raise EnvironmentError: if we can't open the log file and
1758
      stderr logging is disabled
1759

1760
  """
1761
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
1762
  if multithreaded:
1763
    fmt += "/%(threadName)s"
1764
  if debug:
1765
    fmt += " %(module)s:%(lineno)s"
1766
  fmt += " %(levelname)s %(message)s"
1767
  formatter = logging.Formatter(fmt)
1768

    
1769
  root_logger = logging.getLogger("")
1770
  root_logger.setLevel(logging.NOTSET)
1771

    
1772
  # Remove all previously setup handlers
1773
  for handler in root_logger.handlers:
1774
    handler.close()
1775
    root_logger.removeHandler(handler)
1776

    
1777
  if stderr_logging:
1778
    stderr_handler = logging.StreamHandler()
1779
    stderr_handler.setFormatter(formatter)
1780
    if debug:
1781
      stderr_handler.setLevel(logging.NOTSET)
1782
    else:
1783
      stderr_handler.setLevel(logging.CRITICAL)
1784
    root_logger.addHandler(stderr_handler)
1785

    
1786
  # this can fail, if the logging directories are not setup or we have
1787
  # a permisssion problem; in this case, it's best to log but ignore
1788
  # the error if stderr_logging is True, and if false we re-raise the
1789
  # exception since otherwise we could run but without any logs at all
1790
  try:
1791
    logfile_handler = logging.FileHandler(logfile)
1792
    logfile_handler.setFormatter(formatter)
1793
    if debug:
1794
      logfile_handler.setLevel(logging.DEBUG)
1795
    else:
1796
      logfile_handler.setLevel(logging.INFO)
1797
    root_logger.addHandler(logfile_handler)
1798
  except EnvironmentError:
1799
    if stderr_logging:
1800
      logging.exception("Failed to enable logging to file '%s'", logfile)
1801
    else:
1802
      # we need to re-raise the exception
1803
      raise
1804

    
1805

    
1806
def IsNormAbsPath(path):
1807
  """Check whether a path is absolute and also normalized
1808

1809
  This avoids things like /dir/../../other/path to be valid.
1810

1811
  """
1812
  return os.path.normpath(path) == path and os.path.isabs(path)
1813

    
1814

    
1815
def TailFile(fname, lines=20):
1816
  """Return the last lines from a file.
1817

1818
  @note: this function will only read and parse the last 4KB of
1819
      the file; if the lines are very long, it could be that less
1820
      than the requested number of lines are returned
1821

1822
  @param fname: the file name
1823
  @type lines: int
1824
  @param lines: the (maximum) number of lines to return
1825

1826
  """
1827
  fd = open(fname, "r")
1828
  try:
1829
    fd.seek(0, 2)
1830
    pos = fd.tell()
1831
    pos = max(0, pos-4096)
1832
    fd.seek(pos, 0)
1833
    raw_data = fd.read()
1834
  finally:
1835
    fd.close()
1836

    
1837
  rows = raw_data.splitlines()
1838
  return rows[-lines:]
1839

    
1840

    
1841
def SafeEncode(text):
1842
  """Return a 'safe' version of a source string.
1843

1844
  This function mangles the input string and returns a version that
1845
  should be safe to display/encode as ASCII. To this end, we first
1846
  convert it to ASCII using the 'backslashreplace' encoding which
1847
  should get rid of any non-ASCII chars, and then we process it
1848
  through a loop copied from the string repr sources in the python; we
1849
  don't use string_escape anymore since that escape single quotes and
1850
  backslashes too, and that is too much; and that escaping is not
1851
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
1852

1853
  @type text: str or unicode
1854
  @param text: input data
1855
  @rtype: str
1856
  @return: a safe version of text
1857

1858
  """
1859
  if isinstance(text, unicode):
1860
    # only if unicode; if str already, we handle it below
1861
    text = text.encode('ascii', 'backslashreplace')
1862
  resu = ""
1863
  for char in text:
1864
    c = ord(char)
1865
    if char  == '\t':
1866
      resu += r'\t'
1867
    elif char == '\n':
1868
      resu += r'\n'
1869
    elif char == '\r':
1870
      resu += r'\'r'
1871
    elif c < 32 or c >= 127: # non-printable
1872
      resu += "\\x%02x" % (c & 0xff)
1873
    else:
1874
      resu += char
1875
  return resu
1876

    
1877

    
1878
def BytesToMebibyte(value):
1879
  """Converts bytes to mebibytes.
1880

1881
  @type value: int
1882
  @param value: Value in bytes
1883
  @rtype: int
1884
  @return: Value in mebibytes
1885

1886
  """
1887
  return int(round(value / (1024.0 * 1024.0), 0))
1888

    
1889

    
1890
def CalculateDirectorySize(path):
1891
  """Calculates the size of a directory recursively.
1892

1893
  @type path: string
1894
  @param path: Path to directory
1895
  @rtype: int
1896
  @return: Size in mebibytes
1897

1898
  """
1899
  size = 0
1900

    
1901
  for (curpath, _, files) in os.walk(path):
1902
    for filename in files:
1903
      st = os.lstat(os.path.join(curpath, filename))
1904
      size += st.st_size
1905

    
1906
  return BytesToMebibyte(size)
1907

    
1908

    
1909
def GetFilesystemStats(path):
1910
  """Returns the total and free space on a filesystem.
1911

1912
  @type path: string
1913
  @param path: Path on filesystem to be examined
1914
  @rtype: int
1915
  @return: tuple of (Total space, Free space) in mebibytes
1916

1917
  """
1918
  st = os.statvfs(path)
1919

    
1920
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
1921
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
1922
  return (tsize, fsize)
1923

    
1924

    
1925
def LockedMethod(fn):
1926
  """Synchronized object access decorator.
1927

1928
  This decorator is intended to protect access to an object using the
1929
  object's own lock which is hardcoded to '_lock'.
1930

1931
  """
1932
  def _LockDebug(*args, **kwargs):
1933
    if debug_locks:
1934
      logging.debug(*args, **kwargs)
1935

    
1936
  def wrapper(self, *args, **kwargs):
1937
    assert hasattr(self, '_lock')
1938
    lock = self._lock
1939
    _LockDebug("Waiting for %s", lock)
1940
    lock.acquire()
1941
    try:
1942
      _LockDebug("Acquired %s", lock)
1943
      result = fn(self, *args, **kwargs)
1944
    finally:
1945
      _LockDebug("Releasing %s", lock)
1946
      lock.release()
1947
      _LockDebug("Released %s", lock)
1948
    return result
1949
  return wrapper
1950

    
1951

    
1952
def LockFile(fd):
1953
  """Locks a file using POSIX locks.
1954

1955
  @type fd: int
1956
  @param fd: the file descriptor we need to lock
1957

1958
  """
1959
  try:
1960
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
1961
  except IOError, err:
1962
    if err.errno == errno.EAGAIN:
1963
      raise errors.LockError("File already locked")
1964
    raise
1965

    
1966

    
1967
def FormatTime(val):
1968
  """Formats a time value.
1969

1970
  @type val: float or None
1971
  @param val: the timestamp as returned by time.time()
1972
  @return: a string value or N/A if we don't have a valid timestamp
1973

1974
  """
1975
  if val is None or not isinstance(val, (int, float)):
1976
    return "N/A"
1977
  # these two codes works on Linux, but they are not guaranteed on all
1978
  # platforms
1979
  return time.strftime("%F %T", time.localtime(val))
1980

    
1981

    
1982
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
1983
  """Reads the watcher pause file.
1984

1985
  @type filename: string
1986
  @param filename: Path to watcher pause file
1987
  @type now: None, float or int
1988
  @param now: Current time as Unix timestamp
1989
  @type remove_after: int
1990
  @param remove_after: Remove watcher pause file after specified amount of
1991
    seconds past the pause end time
1992

1993
  """
1994
  if now is None:
1995
    now = time.time()
1996

    
1997
  try:
1998
    value = ReadFile(filename)
1999
  except IOError, err:
2000
    if err.errno != errno.ENOENT:
2001
      raise
2002
    value = None
2003

    
2004
  if value is not None:
2005
    try:
2006
      value = int(value)
2007
    except ValueError:
2008
      logging.warning(("Watcher pause file (%s) contains invalid value,"
2009
                       " removing it"), filename)
2010
      RemoveFile(filename)
2011
      value = None
2012

    
2013
    if value is not None:
2014
      # Remove file if it's outdated
2015
      if now > (value + remove_after):
2016
        RemoveFile(filename)
2017
        value = None
2018

    
2019
      elif now > value:
2020
        value = None
2021

    
2022
  return value
2023

    
2024

    
2025
class RetryTimeout(Exception):
2026
  """Retry loop timed out.
2027

2028
  """
2029

    
2030

    
2031
class RetryAgain(Exception):
2032
  """Retry again.
2033

2034
  """
2035

    
2036

    
2037
class _RetryDelayCalculator(object):
2038
  """Calculator for increasing delays.
2039

2040
  """
2041
  __slots__ = [
2042
    "_factor",
2043
    "_limit",
2044
    "_next",
2045
    "_start",
2046
    ]
2047

    
2048
  def __init__(self, start, factor, limit):
2049
    """Initializes this class.
2050

2051
    @type start: float
2052
    @param start: Initial delay
2053
    @type factor: float
2054
    @param factor: Factor for delay increase
2055
    @type limit: float or None
2056
    @param limit: Upper limit for delay or None for no limit
2057

2058
    """
2059
    assert start > 0.0
2060
    assert factor >= 1.0
2061
    assert limit is None or limit >= 0.0
2062

    
2063
    self._start = start
2064
    self._factor = factor
2065
    self._limit = limit
2066

    
2067
    self._next = start
2068

    
2069
  def __call__(self):
2070
    """Returns current delay and calculates the next one.
2071

2072
    """
2073
    current = self._next
2074

    
2075
    # Update for next run
2076
    if self._limit is None or self._next < self._limit:
2077
      self._next = max(self._limit, self._next * self._factor)
2078

    
2079
    return current
2080

    
2081

    
2082
#: Special delay to specify whole remaining timeout
2083
RETRY_REMAINING_TIME = object()
2084

    
2085

    
2086
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
2087
          _time_fn=time.time):
2088
  """Call a function repeatedly until it succeeds.
2089

2090
  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
2091
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
2092
  total of C{timeout} seconds, this function throws L{RetryTimeout}.
2093

2094
  C{delay} can be one of the following:
2095
    - callable returning the delay length as a float
2096
    - Tuple of (start, factor, limit)
2097
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
2098
      useful when overriding L{wait_fn} to wait for an external event)
2099
    - A static delay as a number (int or float)
2100

2101
  @type fn: callable
2102
  @param fn: Function to be called
2103
  @param delay: Either a callable (returning the delay), a tuple of (start,
2104
                factor, limit) (see L{_RetryDelayCalculator}),
2105
                L{RETRY_REMAINING_TIME} or a number (int or float)
2106
  @type timeout: float
2107
  @param timeout: Total timeout
2108
  @type wait_fn: callable
2109
  @param wait_fn: Waiting function
2110
  @return: Return value of function
2111

2112
  """
2113
  assert callable(fn)
2114
  assert callable(wait_fn)
2115
  assert callable(_time_fn)
2116

    
2117
  if args is None:
2118
    args = []
2119

    
2120
  end_time = _time_fn() + timeout
2121

    
2122
  if callable(delay):
2123
    # External function to calculate delay
2124
    calc_delay = delay
2125

    
2126
  elif isinstance(delay, (tuple, list)):
2127
    # Increasing delay with optional upper boundary
2128
    (start, factor, limit) = delay
2129
    calc_delay = _RetryDelayCalculator(start, factor, limit)
2130

    
2131
  elif delay is RETRY_REMAINING_TIME:
2132
    # Always use the remaining time
2133
    calc_delay = None
2134

    
2135
  else:
2136
    # Static delay
2137
    calc_delay = lambda: delay
2138

    
2139
  assert calc_delay is None or callable(calc_delay)
2140

    
2141
  while True:
2142
    try:
2143
      return fn(*args)
2144
    except RetryAgain:
2145
      pass
2146

    
2147
    remaining_time = end_time - _time_fn()
2148

    
2149
    if remaining_time < 0.0:
2150
      raise RetryTimeout()
2151

    
2152
    assert remaining_time >= 0.0
2153

    
2154
    if calc_delay is None:
2155
      wait_fn(remaining_time)
2156
    else:
2157
      current_delay = calc_delay()
2158
      if current_delay > 0.0:
2159
        wait_fn(current_delay)
2160

    
2161

    
2162
class FileLock(object):
2163
  """Utility class for file locks.
2164

2165
  """
2166
  def __init__(self, filename):
2167
    """Constructor for FileLock.
2168

2169
    This will open the file denoted by the I{filename} argument.
2170

2171
    @type filename: str
2172
    @param filename: path to the file to be locked
2173

2174
    """
2175
    self.filename = filename
2176
    self.fd = open(self.filename, "w")
2177

    
2178
  def __del__(self):
2179
    self.Close()
2180

    
2181
  def Close(self):
2182
    """Close the file and release the lock.
2183

2184
    """
2185
    if self.fd:
2186
      self.fd.close()
2187
      self.fd = None
2188

    
2189
  def _flock(self, flag, blocking, timeout, errmsg):
2190
    """Wrapper for fcntl.flock.
2191

2192
    @type flag: int
2193
    @param flag: operation flag
2194
    @type blocking: bool
2195
    @param blocking: whether the operation should be done in blocking mode.
2196
    @type timeout: None or float
2197
    @param timeout: for how long the operation should be retried (implies
2198
                    non-blocking mode).
2199
    @type errmsg: string
2200
    @param errmsg: error message in case operation fails.
2201

2202
    """
2203
    assert self.fd, "Lock was closed"
2204
    assert timeout is None or timeout >= 0, \
2205
      "If specified, timeout must be positive"
2206

    
2207
    if timeout is not None:
2208
      flag |= fcntl.LOCK_NB
2209
      timeout_end = time.time() + timeout
2210

    
2211
    # Blocking doesn't have effect with timeout
2212
    elif not blocking:
2213
      flag |= fcntl.LOCK_NB
2214
      timeout_end = None
2215

    
2216
    # TODO: Convert to utils.Retry
2217

    
2218
    retry = True
2219
    while retry:
2220
      try:
2221
        fcntl.flock(self.fd, flag)
2222
        retry = False
2223
      except IOError, err:
2224
        if err.errno in (errno.EAGAIN, ):
2225
          if timeout_end is not None and time.time() < timeout_end:
2226
            # Wait before trying again
2227
            time.sleep(max(0.1, min(1.0, timeout)))
2228
          else:
2229
            raise errors.LockError(errmsg)
2230
        else:
2231
          logging.exception("fcntl.flock failed")
2232
          raise
2233

    
2234
  def Exclusive(self, blocking=False, timeout=None):
2235
    """Locks the file in exclusive mode.
2236

2237
    @type blocking: boolean
2238
    @param blocking: whether to block and wait until we
2239
        can lock the file or return immediately
2240
    @type timeout: int or None
2241
    @param timeout: if not None, the duration to wait for the lock
2242
        (in blocking mode)
2243

2244
    """
2245
    self._flock(fcntl.LOCK_EX, blocking, timeout,
2246
                "Failed to lock %s in exclusive mode" % self.filename)
2247

    
2248
  def Shared(self, blocking=False, timeout=None):
2249
    """Locks the file in shared mode.
2250

2251
    @type blocking: boolean
2252
    @param blocking: whether to block and wait until we
2253
        can lock the file or return immediately
2254
    @type timeout: int or None
2255
    @param timeout: if not None, the duration to wait for the lock
2256
        (in blocking mode)
2257

2258
    """
2259
    self._flock(fcntl.LOCK_SH, blocking, timeout,
2260
                "Failed to lock %s in shared mode" % self.filename)
2261

    
2262
  def Unlock(self, blocking=True, timeout=None):
2263
    """Unlocks the file.
2264

2265
    According to C{flock(2)}, unlocking can also be a nonblocking
2266
    operation::
2267

2268
      To make a non-blocking request, include LOCK_NB with any of the above
2269
      operations.
2270

2271
    @type blocking: boolean
2272
    @param blocking: whether to block and wait until we
2273
        can lock the file or return immediately
2274
    @type timeout: int or None
2275
    @param timeout: if not None, the duration to wait for the lock
2276
        (in blocking mode)
2277

2278
    """
2279
    self._flock(fcntl.LOCK_UN, blocking, timeout,
2280
                "Failed to unlock %s" % self.filename)
2281

    
2282

    
2283
def SignalHandled(signums):
2284
  """Signal Handled decoration.
2285

2286
  This special decorator installs a signal handler and then calls the target
2287
  function. The function must accept a 'signal_handlers' keyword argument,
2288
  which will contain a dict indexed by signal number, with SignalHandler
2289
  objects as values.
2290

2291
  The decorator can be safely stacked with iself, to handle multiple signals
2292
  with different handlers.
2293

2294
  @type signums: list
2295
  @param signums: signals to intercept
2296

2297
  """
2298
  def wrap(fn):
2299
    def sig_function(*args, **kwargs):
2300
      assert 'signal_handlers' not in kwargs or \
2301
             kwargs['signal_handlers'] is None or \
2302
             isinstance(kwargs['signal_handlers'], dict), \
2303
             "Wrong signal_handlers parameter in original function call"
2304
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
2305
        signal_handlers = kwargs['signal_handlers']
2306
      else:
2307
        signal_handlers = {}
2308
        kwargs['signal_handlers'] = signal_handlers
2309
      sighandler = SignalHandler(signums)
2310
      try:
2311
        for sig in signums:
2312
          signal_handlers[sig] = sighandler
2313
        return fn(*args, **kwargs)
2314
      finally:
2315
        sighandler.Reset()
2316
    return sig_function
2317
  return wrap
2318

    
2319

    
2320
class SignalHandler(object):
2321
  """Generic signal handler class.
2322

2323
  It automatically restores the original handler when deconstructed or
2324
  when L{Reset} is called. You can either pass your own handler
2325
  function in or query the L{called} attribute to detect whether the
2326
  signal was sent.
2327

2328
  @type signum: list
2329
  @ivar signum: the signals we handle
2330
  @type called: boolean
2331
  @ivar called: tracks whether any of the signals have been raised
2332

2333
  """
2334
  def __init__(self, signum):
2335
    """Constructs a new SignalHandler instance.
2336

2337
    @type signum: int or list of ints
2338
    @param signum: Single signal number or set of signal numbers
2339

2340
    """
2341
    self.signum = set(signum)
2342
    self.called = False
2343

    
2344
    self._previous = {}
2345
    try:
2346
      for signum in self.signum:
2347
        # Setup handler
2348
        prev_handler = signal.signal(signum, self._HandleSignal)
2349
        try:
2350
          self._previous[signum] = prev_handler
2351
        except:
2352
          # Restore previous handler
2353
          signal.signal(signum, prev_handler)
2354
          raise
2355
    except:
2356
      # Reset all handlers
2357
      self.Reset()
2358
      # Here we have a race condition: a handler may have already been called,
2359
      # but there's not much we can do about it at this point.
2360
      raise
2361

    
2362
  def __del__(self):
2363
    self.Reset()
2364

    
2365
  def Reset(self):
2366
    """Restore previous handler.
2367

2368
    This will reset all the signals to their previous handlers.
2369

2370
    """
2371
    for signum, prev_handler in self._previous.items():
2372
      signal.signal(signum, prev_handler)
2373
      # If successful, remove from dict
2374
      del self._previous[signum]
2375

    
2376
  def Clear(self):
2377
    """Unsets the L{called} flag.
2378

2379
    This function can be used in case a signal may arrive several times.
2380

2381
    """
2382
    self.called = False
2383

    
2384
  def _HandleSignal(self, signum, frame):
2385
    """Actual signal handling function.
2386

2387
    """
2388
    # This is not nice and not absolutely atomic, but it appears to be the only
2389
    # solution in Python -- there are no atomic types.
2390
    self.called = True
2391

    
2392

    
2393
class FieldSet(object):
2394
  """A simple field set.
2395

2396
  Among the features are:
2397
    - checking if a string is among a list of static string or regex objects
2398
    - checking if a whole list of string matches
2399
    - returning the matching groups from a regex match
2400

2401
  Internally, all fields are held as regular expression objects.
2402

2403
  """
2404
  def __init__(self, *items):
2405
    self.items = [re.compile("^%s$" % value) for value in items]
2406

    
2407
  def Extend(self, other_set):
2408
    """Extend the field set with the items from another one"""
2409
    self.items.extend(other_set.items)
2410

    
2411
  def Matches(self, field):
2412
    """Checks if a field matches the current set
2413

2414
    @type field: str
2415
    @param field: the string to match
2416
    @return: either None or a regular expression match object
2417

2418
    """
2419
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
2420
      return m
2421
    return None
2422

    
2423
  def NonMatching(self, items):
2424
    """Returns the list of fields not matching the current set
2425

2426
    @type items: list
2427
    @param items: the list of fields to check
2428
    @rtype: list
2429
    @return: list of non-matching fields
2430

2431
    """
2432
    return [val for val in items if not self.Matches(val)]