root / lib / utils.py @ 858905fb

#
#

# Copyright (C) 2006, 2007 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Ganeti utility module.

This module holds functions that can be used in both daemons (all) and
the command line scripts.

"""


import os
import sys
import time
import subprocess
import re
import socket
import tempfile
import shutil
import errno
import pwd
import itertools
import select
import fcntl
import resource
import logging
import logging.handlers
import signal
import OpenSSL
import datetime
import calendar
import hmac
import collections

from cStringIO import StringIO

try:
  # pylint: disable-msg=F0401
  import ctypes
except ImportError:
  ctypes = None

from ganeti import errors
from ganeti import constants
from ganeti import compat
from ganeti import netutils


_locksheld = []
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')

debug_locks = False

#: when set to True, L{RunCmd} is disabled
no_fork = False

_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"

HEX_CHAR_RE = r"[a-zA-Z0-9]"
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
                             HEX_CHAR_RE, HEX_CHAR_RE),
                            re.S | re.I)

_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$")

# Certificate verification results
(CERT_WARNING,
 CERT_ERROR) = range(1, 3)

# Flags for mlockall() (from bits/mman.h)
_MCL_CURRENT = 1
_MCL_FUTURE = 2


class RunResult(object):
  """Holds the result of running external programs.

  @type exit_code: int
  @ivar exit_code: the exit code of the program, or None (if the program
      didn't exit())
  @type signal: int or None
  @ivar signal: the signal that caused the program to finish, or None
      (if the program wasn't terminated by a signal)
  @type stdout: str
  @ivar stdout: the standard output of the program
  @type stderr: str
  @ivar stderr: the standard error of the program
  @type failed: boolean
  @ivar failed: True in case the program was
      terminated by a signal or exited with a non-zero exit code
  @ivar fail_reason: a string detailing the termination reason

  """
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
               "failed", "fail_reason", "cmd"]


  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
    self.cmd = cmd
    self.exit_code = exit_code
    self.signal = signal_
    self.stdout = stdout
    self.stderr = stderr
    self.failed = (signal_ is not None or exit_code != 0)

    if self.signal is not None:
      self.fail_reason = "terminated by signal %s" % self.signal
    elif self.exit_code is not None:
      self.fail_reason = "exited with exit code %s" % self.exit_code
    else:
      self.fail_reason = "unable to determine termination reason"

    if self.failed:
      logging.debug("Command '%s' failed (%s); output: %s",
                    self.cmd, self.fail_reason, self.output)

  def _GetOutput(self):
    """Returns the combined stdout and stderr for easier usage.

    """
    return self.stdout + self.stderr

  output = property(_GetOutput, None, None, "Return full output")


def _BuildCmdEnvironment(env, reset):
  """Builds the environment for an external program.

  """
  if reset:
    cmd_env = {}
  else:
    cmd_env = os.environ.copy()
    cmd_env["LC_ALL"] = "C"

  if env is not None:
    cmd_env.update(env)

  return cmd_env


def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
  """Execute a (shell) command.

  The command should not read from its standard input, as it will be
  closed.

  @type cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: Additional environment variables
  @type output: str
  @param output: if desired, the output of the command can be
      saved in a file instead of the RunResult instance; this
      parameter denotes the file name (if not None)
  @type cwd: string
  @param cwd: if specified, will be used as the working
      directory for the command; the default will be /
  @type reset_env: boolean
  @param reset_env: whether to reset or keep the default os environment
  @rtype: L{RunResult}
  @return: RunResult instance
  @raise errors.ProgrammerError: if we call this when forks are disabled

  """
  if no_fork:
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")

  if isinstance(cmd, basestring):
    strcmd = cmd
    shell = True
  else:
    cmd = [str(val) for val in cmd]
    strcmd = ShellQuoteArgs(cmd)
    shell = False

  if output:
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
  else:
    logging.debug("RunCmd %s", strcmd)

  cmd_env = _BuildCmdEnvironment(env, reset_env)

  try:
    if output is None:
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
    else:
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
      out = err = ""
  except OSError, err:
    if err.errno == errno.ENOENT:
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
                               (strcmd, err))
    else:
      raise

  if status >= 0:
    exitcode = status
    signal_ = None
  else:
    exitcode = None
    signal_ = -status

  return RunResult(exitcode, signal_, out, err, strcmd)
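
# Illustrative usage only -- not part of the original module. It assumes a
# POSIX system where /bin/ls exists:
#
#   result = RunCmd(["/bin/ls", "/etc"])
#   if result.failed:
#     logging.error("ls failed: %s (output: %s)",
#                   result.fail_reason, result.output)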


def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
                pidfile=None):
  """Start a daemon process after forking twice.

  @type cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: Additional environment variables
  @type cwd: string
  @param cwd: Working directory for the program
  @type output: string
  @param output: Path to file in which to save the output
  @type output_fd: int
  @param output_fd: File descriptor for output
  @type pidfile: string
  @param pidfile: Process ID file
  @rtype: int
  @return: Daemon process ID
  @raise errors.ProgrammerError: if we call this when forks are disabled

  """
  if no_fork:
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
                                 " disabled")

  if output and not (bool(output) ^ (output_fd is not None)):
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
                                 " specified")

  if isinstance(cmd, basestring):
    cmd = ["/bin/sh", "-c", cmd]

  strcmd = ShellQuoteArgs(cmd)

  if output:
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
  else:
    logging.debug("StartDaemon %s", strcmd)

  cmd_env = _BuildCmdEnvironment(env, False)

  # Create pipe for sending PID back
  (pidpipe_read, pidpipe_write) = os.pipe()
  try:
    try:
      # Create pipe for sending error messages
      (errpipe_read, errpipe_write) = os.pipe()
      try:
        try:
          # First fork
          pid = os.fork()
          if pid == 0:
            try:
              # Child process, won't return
              _StartDaemonChild(errpipe_read, errpipe_write,
                                pidpipe_read, pidpipe_write,
                                cmd, cmd_env, cwd,
                                output, output_fd, pidfile)
            finally:
              # Well, maybe child process failed
              os._exit(1) # pylint: disable-msg=W0212
        finally:
          _CloseFDNoErr(errpipe_write)

        # Wait for daemon to be started (or an error message to arrive) and read
        # up to 100 KB as an error message
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
      finally:
        _CloseFDNoErr(errpipe_read)
    finally:
      _CloseFDNoErr(pidpipe_write)

    # Read up to 128 bytes for PID
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
  finally:
    _CloseFDNoErr(pidpipe_read)

  # Try to avoid zombies by waiting for child process
  try:
    os.waitpid(pid, 0)
  except OSError:
    pass

  if errormsg:
    raise errors.OpExecError("Error when starting daemon process: %r" %
                             errormsg)

  try:
    return int(pidtext)
  except (ValueError, TypeError), err:
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
                             (pidtext, err))


def _StartDaemonChild(errpipe_read, errpipe_write,
                      pidpipe_read, pidpipe_write,
                      args, env, cwd,
                      output, fd_output, pidfile):
  """Child process for starting daemon.

  """
  try:
    # Close parent's side
    _CloseFDNoErr(errpipe_read)
    _CloseFDNoErr(pidpipe_read)

    # First child process
    os.chdir("/")
    os.umask(077)
    os.setsid()

    # And fork for the second time
    pid = os.fork()
    if pid != 0:
      # Exit first child process
      os._exit(0) # pylint: disable-msg=W0212

    # Make sure pipe is closed on execv* (and thereby notifies original process)
    SetCloseOnExecFlag(errpipe_write, True)

    # List of file descriptors to be left open
    noclose_fds = [errpipe_write]

    # Open PID file
    if pidfile:
      try:
        # TODO: Atomic replace with another locked file instead of writing into
        # it after creating
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)

        # Lock the PID file (and fail if not possible to do so). Any code
        # wanting to send a signal to the daemon should try to lock the PID
        # file before reading it. If acquiring the lock succeeds, the daemon is
        # no longer running and the signal should not be sent.
        LockFile(fd_pidfile)

        os.write(fd_pidfile, "%d\n" % os.getpid())
      except Exception, err:
        raise Exception("Creating and locking PID file failed: %s" % err)

      # Keeping the file open to hold the lock
      noclose_fds.append(fd_pidfile)

      SetCloseOnExecFlag(fd_pidfile, False)
    else:
      fd_pidfile = None

    # Open /dev/null
    fd_devnull = os.open(os.devnull, os.O_RDWR)

    assert not output or (bool(output) ^ (fd_output is not None))

    if fd_output is not None:
      pass
    elif output:
      # Open output file
      try:
        # TODO: Implement flag to set append=yes/no
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
      except EnvironmentError, err:
        raise Exception("Opening output file failed: %s" % err)
    else:
      fd_output = fd_devnull

    # Redirect standard I/O
    os.dup2(fd_devnull, 0)
    os.dup2(fd_output, 1)
    os.dup2(fd_output, 2)

    # Send daemon PID to parent
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))

    # Close all file descriptors except stdio and error message pipe
    CloseFDs(noclose_fds=noclose_fds)

    # Change working directory
    os.chdir(cwd)

    if env is None:
      os.execvp(args[0], args)
    else:
      os.execvpe(args[0], args, env)
  except: # pylint: disable-msg=W0702
    try:
      # Report errors to original process
      buf = str(sys.exc_info()[1])

      RetryOnSignal(os.write, errpipe_write, buf)
    except: # pylint: disable-msg=W0702
      # Ignore errors in error handling
      pass

  os._exit(1) # pylint: disable-msg=W0212


def _RunCmdPipe(cmd, env, via_shell, cwd):
  """Run a command and return its output.

  @type  cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: The environment to use
  @type via_shell: bool
  @param via_shell: if we should run via the shell
  @type cwd: string
  @param cwd: the working directory for the program
  @rtype: tuple
  @return: (out, err, status)

  """
  poller = select.poll()
  child = subprocess.Popen(cmd, shell=via_shell,
                           stderr=subprocess.PIPE,
                           stdout=subprocess.PIPE,
                           stdin=subprocess.PIPE,
                           close_fds=True, env=env,
                           cwd=cwd)

  child.stdin.close()
  poller.register(child.stdout, select.POLLIN)
  poller.register(child.stderr, select.POLLIN)
  out = StringIO()
  err = StringIO()
  fdmap = {
    child.stdout.fileno(): (out, child.stdout),
    child.stderr.fileno(): (err, child.stderr),
    }
  for fd in fdmap:
    SetNonblockFlag(fd, True)

  while fdmap:
    pollresult = RetryOnSignal(poller.poll)

    for fd, event in pollresult:
      if event & select.POLLIN or event & select.POLLPRI:
        data = fdmap[fd][1].read()
        # no data from read signifies EOF (the same as POLLHUP)
        if not data:
          poller.unregister(fd)
          del fdmap[fd]
          continue
        fdmap[fd][0].write(data)
      if (event & select.POLLNVAL or event & select.POLLHUP or
          event & select.POLLERR):
        poller.unregister(fd)
        del fdmap[fd]

  out = out.getvalue()
  err = err.getvalue()

  status = child.wait()
  return out, err, status


def _RunCmdFile(cmd, env, via_shell, output, cwd):
  """Run a command and save its output to a file.

  @type  cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: The environment to use
  @type via_shell: bool
  @param via_shell: if we should run via the shell
  @type output: str
  @param output: the filename in which to save the output
  @type cwd: string
  @param cwd: the working directory for the program
  @rtype: int
  @return: the exit status

  """
  fh = open(output, "a")
  try:
    child = subprocess.Popen(cmd, shell=via_shell,
                             stderr=subprocess.STDOUT,
                             stdout=fh,
                             stdin=subprocess.PIPE,
                             close_fds=True, env=env,
                             cwd=cwd)

    child.stdin.close()
    status = child.wait()
  finally:
    fh.close()
  return status


def SetCloseOnExecFlag(fd, enable):
  """Sets or unsets the close-on-exec flag on a file descriptor.

  @type fd: int
  @param fd: File descriptor
  @type enable: bool
  @param enable: Whether to set or unset it.

  """
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)

  if enable:
    flags |= fcntl.FD_CLOEXEC
  else:
    flags &= ~fcntl.FD_CLOEXEC

  fcntl.fcntl(fd, fcntl.F_SETFD, flags)


def SetNonblockFlag(fd, enable):
  """Sets or unsets the O_NONBLOCK flag on a file descriptor.

  @type fd: int
  @param fd: File descriptor
  @type enable: bool
  @param enable: Whether to set or unset it

  """
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)

  if enable:
    flags |= os.O_NONBLOCK
  else:
    flags &= ~os.O_NONBLOCK

  fcntl.fcntl(fd, fcntl.F_SETFL, flags)


def RetryOnSignal(fn, *args, **kwargs):
  """Calls a function again if it failed due to EINTR.

  """
  while True:
    try:
      return fn(*args, **kwargs)
    except EnvironmentError, err:
      if err.errno != errno.EINTR:
        raise
    except (socket.error, select.error), err:
      # In python 2.6 and above select.error is an IOError, so it's handled
      # above, in 2.5 and below it's not, and it's handled here.
      if not (err.args and err.args[0] == errno.EINTR):
        raise


def RunParts(dir_name, env=None, reset_env=False):
  """Run scripts or programs in a directory.

  @type dir_name: string
  @param dir_name: absolute path to a directory
  @type env: dict
  @param env: The environment to use
  @type reset_env: boolean
  @param reset_env: whether to reset or keep the default os environment
  @rtype: list of tuples
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)

  """
  rr = []

  try:
    dir_contents = ListVisibleFiles(dir_name)
  except OSError, err:
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
    return rr

  for relname in sorted(dir_contents):
    fname = PathJoin(dir_name, relname)
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
      rr.append((relname, constants.RUNPARTS_SKIP, None))
    else:
      try:
        result = RunCmd([fname], env=env, reset_env=reset_env)
      except Exception, err: # pylint: disable-msg=W0703
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
      else:
        rr.append((relname, constants.RUNPARTS_RUN, result))

  return rr
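
# Illustrative usage only -- not part of the original module; the directory
# path below is hypothetical:
#
#   for (name, status, runresult) in RunParts("/etc/ganeti/hooks.d"):
#     if status == constants.RUNPARTS_ERR:
#       logging.error("script %s failed: %s", name, runresult)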


def RemoveFile(filename):
  """Remove a file ignoring some errors.

  Remove a file, ignoring non-existing ones or directories. Other
  errors are passed.

  @type filename: str
  @param filename: the file to be removed

  """
  try:
    os.unlink(filename)
  except OSError, err:
    if err.errno not in (errno.ENOENT, errno.EISDIR):
      raise


def RemoveDir(dirname):
  """Remove an empty directory.

  Remove a directory, ignoring non-existing ones.
  Other errors are passed. This includes the case
  where the directory is not empty, so it can't be removed.

  @type dirname: str
  @param dirname: the empty directory to be removed

  """
  try:
    os.rmdir(dirname)
  except OSError, err:
    if err.errno != errno.ENOENT:
      raise


def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
  """Renames a file.

  @type old: string
  @param old: Original path
  @type new: string
  @param new: New path
  @type mkdir: bool
  @param mkdir: Whether to create target directory if it doesn't exist
  @type mkdir_mode: int
  @param mkdir_mode: Mode for newly created directories

  """
  try:
    return os.rename(old, new)
  except OSError, err:
    # In at least one use case of this function, the job queue, directory
    # creation is very rare. Checking for the directory before renaming is not
    # as efficient.
    if mkdir and err.errno == errno.ENOENT:
      # Create directory and try again
      Makedirs(os.path.dirname(new), mode=mkdir_mode)

      return os.rename(old, new)

    raise


def Makedirs(path, mode=0750):
  """Super-mkdir; create a leaf directory and all intermediate ones.

  This is a wrapper around C{os.makedirs} adding error handling not implemented
  before Python 2.5.

  """
  try:
    os.makedirs(path, mode)
  except OSError, err:
    # Ignore EEXIST. This is only handled in os.makedirs as included in
    # Python 2.5 and above.
    if err.errno != errno.EEXIST or not os.path.exists(path):
      raise


def ResetTempfileModule():
  """Resets the random name generator of the tempfile module.

  This function should be called after C{os.fork} in the child process to
  ensure it creates a newly seeded random generator. Otherwise it would
  generate the same random parts as the parent process. If several processes
  race for the creation of a temporary file, this could lead to one not getting
  a temporary name.

  """
  # pylint: disable-msg=W0212
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
    tempfile._once_lock.acquire()
    try:
      # Reset random name generator
      tempfile._name_sequence = None
    finally:
      tempfile._once_lock.release()
  else:
    logging.critical("The tempfile module misses at least one of the"
                     " '_once_lock' and '_name_sequence' attributes")


def _FingerprintFile(filename):
  """Compute the fingerprint of a file.

  If the file does not exist, a None will be returned
  instead.

  @type filename: str
  @param filename: the filename to checksum
  @rtype: str
  @return: the hex digest of the sha checksum of the contents
      of the file

  """
  if not (os.path.exists(filename) and os.path.isfile(filename)):
    return None

  f = open(filename)

  fp = compat.sha1_hash()
  while True:
    data = f.read(4096)
    if not data:
      break

    fp.update(data)

  return fp.hexdigest()


def FingerprintFiles(files):
  """Compute fingerprints for a list of files.

  @type files: list
  @param files: the list of filenames to fingerprint
  @rtype: dict
  @return: a dictionary filename: fingerprint, holding only
      existing files

  """
  ret = {}

  for filename in files:
    cksum = _FingerprintFile(filename)
    if cksum:
      ret[filename] = cksum

  return ret


def ForceDictType(target, key_types, allowed_values=None):
  """Force the values of a dict to have certain types.

  @type target: dict
  @param target: the dict to update
  @type key_types: dict
  @param key_types: dict mapping target dict keys to types
                    in constants.ENFORCEABLE_TYPES
  @type allowed_values: list
  @keyword allowed_values: list of specially allowed values

  """
  if allowed_values is None:
    allowed_values = []

  if not isinstance(target, dict):
    msg = "Expected dictionary, got '%s'" % target
    raise errors.TypeEnforcementError(msg)

  for key in target:
    if key not in key_types:
      msg = "Unknown key '%s'" % key
      raise errors.TypeEnforcementError(msg)

    if target[key] in allowed_values:
      continue

    ktype = key_types[key]
    if ktype not in constants.ENFORCEABLE_TYPES:
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
      raise errors.ProgrammerError(msg)

    if ktype == constants.VTYPE_STRING:
      if not isinstance(target[key], basestring):
        if isinstance(target[key], bool) and not target[key]:
          target[key] = ''
        else:
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
          raise errors.TypeEnforcementError(msg)
    elif ktype == constants.VTYPE_BOOL:
      if isinstance(target[key], basestring) and target[key]:
        if target[key].lower() == constants.VALUE_FALSE:
          target[key] = False
        elif target[key].lower() == constants.VALUE_TRUE:
          target[key] = True
        else:
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
          raise errors.TypeEnforcementError(msg)
      elif target[key]:
        target[key] = True
      else:
        target[key] = False
    elif ktype == constants.VTYPE_SIZE:
      try:
        target[key] = ParseUnit(target[key])
      except errors.UnitParseError, err:
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
              (key, target[key], err)
        raise errors.TypeEnforcementError(msg)
    elif ktype == constants.VTYPE_INT:
      try:
        target[key] = int(target[key])
      except (ValueError, TypeError):
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
        raise errors.TypeEnforcementError(msg)


def _GetProcStatusPath(pid):
  """Returns the path for a PID's proc status file.

  @type pid: int
  @param pid: Process ID
  @rtype: string

  """
  return "/proc/%d/status" % pid


def IsProcessAlive(pid):
  """Check if a given pid exists on the system.

  @note: zombie status is not handled, so zombie processes
      will be returned as alive
  @type pid: int
  @param pid: the process ID to check
  @rtype: boolean
  @return: True if the process exists

  """
  def _TryStat(name):
    try:
      os.stat(name)
      return True
    except EnvironmentError, err:
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
        return False
      elif err.errno == errno.EINVAL:
        raise RetryAgain(err)
      raise

  assert isinstance(pid, int), "pid must be an integer"
  if pid <= 0:
    return False

  # /proc in a multiprocessor environment can have strange behaviors.
  # Retry the os.stat a few times until we get a good result.
  try:
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
                 args=[_GetProcStatusPath(pid)])
  except RetryTimeout, err:
    err.RaiseInner()


def _ParseSigsetT(sigset):
  """Parse a rendered sigset_t value.

  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
  function.

  @type sigset: string
  @param sigset: Rendered signal set from /proc/$pid/status
  @rtype: set
  @return: Set of all enabled signal numbers

  """
  result = set()

  signum = 0
  for ch in reversed(sigset):
    chv = int(ch, 16)

    # The following could be done in a loop, but it's easier to read and
    # understand in the unrolled form
    if chv & 1:
      result.add(signum + 1)
    if chv & 2:
      result.add(signum + 2)
    if chv & 4:
      result.add(signum + 3)
    if chv & 8:
      result.add(signum + 4)

    signum += 4

  return result
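
# Worked example (illustrative only, not part of the original module): for a
# rendered sigset such as "0000000000000003", the lowest hex digit has bits
# 1 and 2 set, so _ParseSigsetT("0000000000000003") returns set([1, 2]),
# i.e. signals 1 (SIGHUP) and 2 (SIGINT) are in the set.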


def _GetProcStatusField(pstatus, field):
  """Retrieves a field from the contents of a proc status file.

  @type pstatus: string
  @param pstatus: Contents of /proc/$pid/status
  @type field: string
  @param field: Name of field whose value should be returned
  @rtype: string

  """
  for line in pstatus.splitlines():
    parts = line.split(":", 1)

    if len(parts) < 2 or parts[0] != field:
      continue

    return parts[1].strip()

  return None


def IsProcessHandlingSignal(pid, signum, status_path=None):
  """Checks whether a process is handling a signal.

  @type pid: int
  @param pid: Process ID
  @type signum: int
  @param signum: Signal number
  @rtype: bool

  """
  if status_path is None:
    status_path = _GetProcStatusPath(pid)

  try:
    proc_status = ReadFile(status_path)
  except EnvironmentError, err:
    # In at least one case, reading /proc/$pid/status failed with ESRCH.
    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
      return False
    raise

  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
  if sigcgt is None:
    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)

  # Now check whether signal is handled
  return signum in _ParseSigsetT(sigcgt)

    
953

    
954
def ReadPidFile(pidfile):
955
  """Read a pid from a file.
956

957
  @type  pidfile: string
958
  @param pidfile: path to the file containing the pid
959
  @rtype: int
960
  @return: The process id, if the file exists and contains a valid PID,
961
           otherwise 0
962

963
  """
964
  try:
965
    raw_data = ReadOneLineFile(pidfile)
966
  except EnvironmentError, err:
967
    if err.errno != errno.ENOENT:
968
      logging.exception("Can't read pid file")
969
    return 0
970

    
971
  try:
972
    pid = int(raw_data)
973
  except (TypeError, ValueError), err:
974
    logging.info("Can't parse pid file contents", exc_info=True)
975
    return 0
976

    
977
  return pid
978

    
979

    
980
def ReadLockedPidFile(path):
981
  """Reads a locked PID file.
982

983
  This can be used together with L{StartDaemon}.
984

985
  @type path: string
986
  @param path: Path to PID file
987
  @return: PID as integer or, if file was unlocked or couldn't be opened, None
988

989
  """
990
  try:
991
    fd = os.open(path, os.O_RDONLY)
992
  except EnvironmentError, err:
993
    if err.errno == errno.ENOENT:
994
      # PID file doesn't exist
995
      return None
996
    raise
997

    
998
  try:
999
    try:
1000
      # Try to acquire lock
1001
      LockFile(fd)
1002
    except errors.LockError:
1003
      # Couldn't lock, daemon is running
1004
      return int(os.read(fd, 100))
1005
  finally:
1006
    os.close(fd)
1007

    
1008
  return None
1009

    
1010

    
1011
def MatchNameComponent(key, name_list, case_sensitive=True):
1012
  """Try to match a name against a list.
1013

1014
  This function will try to match a name like test1 against a list
1015
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
1016
  this list, I{'test1'} as well as I{'test1.example'} will match, but
1017
  not I{'test1.ex'}. A multiple match will be considered as no match
1018
  at all (e.g. I{'test1'} against C{['test1.example.com',
1019
  'test1.example.org']}), except when the key fully matches an entry
1020
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
1021

1022
  @type key: str
1023
  @param key: the name to be searched
1024
  @type name_list: list
1025
  @param name_list: the list of strings against which to search the key
1026
  @type case_sensitive: boolean
1027
  @param case_sensitive: whether to provide a case-sensitive match
1028

1029
  @rtype: None or str
1030
  @return: None if there is no match I{or} if there are multiple matches,
1031
      otherwise the element from the list which matches
1032

1033
  """
1034
  if key in name_list:
1035
    return key
1036

    
1037
  re_flags = 0
1038
  if not case_sensitive:
1039
    re_flags |= re.IGNORECASE
1040
    key = key.upper()
1041
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
1042
  names_filtered = []
1043
  string_matches = []
1044
  for name in name_list:
1045
    if mo.match(name) is not None:
1046
      names_filtered.append(name)
1047
      if not case_sensitive and key == name.upper():
1048
        string_matches.append(name)
1049

    
1050
  if len(string_matches) == 1:
1051
    return string_matches[0]
1052
  if len(names_filtered) == 1:
1053
    return names_filtered[0]
1054
  return None
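
# Illustrative examples only (not part of the original module), following the
# docstring above:
#
#   MatchNameComponent("test1", ["test1.example.com", "test2.example.com"])
#     -> "test1.example.com"
#   MatchNameComponent("test1", ["test1.example.com", "test1.example.org"])
#     -> None  (ambiguous match)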


def ValidateServiceName(name):
  """Validate the given service name.

  @type name: number or string
  @param name: Service name or port specification

  """
  try:
    numport = int(name)
  except (ValueError, TypeError):
    # Non-numeric service name
    valid = _VALID_SERVICE_NAME_RE.match(name)
  else:
    # Numeric port (protocols other than TCP or UDP might need adjustments
    # here)
    valid = (numport >= 0 and numport < (1 << 16))

  if not valid:
    raise errors.OpPrereqError("Invalid service name '%s'" % name,
                               errors.ECODE_INVAL)

  return name


def ListVolumeGroups():
  """List volume groups and their size

  @rtype: dict
  @return:
       Dictionary with volume group names as keys and their
       size as values

  """
  command = "vgs --noheadings --units m --nosuffix -o name,size"
  result = RunCmd(command)
  retval = {}
  if result.failed:
    return retval

  for line in result.stdout.splitlines():
    try:
      name, size = line.split()
      size = int(float(size))
    except (IndexError, ValueError), err:
      logging.error("Invalid output from vgs (%s): %s", err, line)
      continue

    retval[name] = size

  return retval


def BridgeExists(bridge):
  """Check whether the given bridge exists in the system

  @type bridge: str
  @param bridge: the bridge name to check
  @rtype: boolean
  @return: True if it does

  """
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)


def NiceSort(name_list):
  """Sort a list of strings based on digit and non-digit groupings.

  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
  will sort the list in the logical order C{['a1', 'a2', 'a10',
  'a11']}.

  The sort algorithm breaks each name in groups of either only-digits
  or no-digits. Only the first eight such groups are considered, and
  after that we just use what's left of the string.

  @type name_list: list
  @param name_list: the names to be sorted
  @rtype: list
  @return: a copy of the name list sorted with our algorithm

  """
  _SORTER_BASE = "(\D+|\d+)"
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE)
  _SORTER_RE = re.compile(_SORTER_FULL)
  _SORTER_NODIGIT = re.compile("^\D*$")
  def _TryInt(val):
    """Attempts to convert a variable to integer."""
    if val is None or _SORTER_NODIGIT.match(val):
      return val
    rval = int(val)
    return rval

  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
             for name in name_list]
  to_sort.sort()
  return [tup[1] for tup in to_sort]
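
# Illustrative example only (not part of the original module), repeating the
# docstring above:
#
#   NiceSort(["a1", "a10", "a11", "a2"])  ->  ["a1", "a2", "a10", "a11"]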


def TryConvert(fn, val):
  """Try to convert a value ignoring errors.

  This function tries to apply function I{fn} to I{val}. If no
  C{ValueError} or C{TypeError} exceptions are raised, it will return
  the result, else it will return the original value. Any other
  exceptions are propagated to the caller.

  @type fn: callable
  @param fn: function to apply to the value
  @param val: the value to be converted
  @return: The converted value if the conversion was successful,
      otherwise the original value.

  """
  try:
    nv = fn(val)
  except (ValueError, TypeError):
    nv = val
  return nv


def IsValidShellParam(word):
  """Verifies if the given word is safe from the shell's p.o.v.

  This means that we can pass this to a command via the shell and be
  sure that it doesn't alter the command line and is passed as such to
  the actual command.

  Note that we are overly restrictive here, in order to be on the safe
  side.

  @type word: str
  @param word: the word to check
  @rtype: boolean
  @return: True if the word is 'safe'

  """
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))


def BuildShellCmd(template, *args):
  """Build a safe shell command line from the given arguments.

  This function will check all arguments in the args list so that they
  are valid shell parameters (i.e. they don't contain shell
  metacharacters). If everything is ok, it will return the result of
  template % args.

  @type template: str
  @param template: the string holding the template for the
      string formatting
  @rtype: str
  @return: the expanded command line

  """
  for word in args:
    if not IsValidShellParam(word):
      raise errors.ProgrammerError("Shell argument '%s' contains"
                                   " invalid characters" % word)
  return template % args


def FormatUnit(value, units):
  """Formats an incoming number of MiB with the appropriate unit.

  @type value: int
  @param value: integer representing the value in MiB (1048576)
  @type units: char
  @param units: the type of formatting we should do:
      - 'h' for automatic scaling
      - 'm' for MiBs
      - 'g' for GiBs
      - 't' for TiBs
  @rtype: str
  @return: the formatted value (with suffix)

  """
  if units not in ('m', 'g', 't', 'h'):
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))

  suffix = ''

  if units == 'm' or (units == 'h' and value < 1024):
    if units == 'h':
      suffix = 'M'
    return "%d%s" % (round(value, 0), suffix)

  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
    if units == 'h':
      suffix = 'G'
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)

  else:
    if units == 'h':
      suffix = 'T'
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)


def ParseUnit(input_string):
  """Tries to extract number and scale from the given string.

  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
  is always an int in MiB.

  """
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
  if not m:
    raise errors.UnitParseError("Invalid format")

  value = float(m.groups()[0])

  unit = m.groups()[1]
  if unit:
    lcunit = unit.lower()
  else:
    lcunit = 'm'

  if lcunit in ('m', 'mb', 'mib'):
    # Value already in MiB
    pass

  elif lcunit in ('g', 'gb', 'gib'):
    value *= 1024

  elif lcunit in ('t', 'tb', 'tib'):
    value *= 1024 * 1024

  else:
    raise errors.UnitParseError("Unknown unit: %s" % unit)

  # Make sure we round up
  if int(value) < value:
    value += 1

  # Round up to the next multiple of 4
  value = int(value)
  if value % 4:
    value += 4 - value % 4

  return value
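
# Worked example (illustrative only, not part of the original module):
#
#   ParseUnit("1.5g")      ->  1536    (1.5 GiB expressed in MiB)
#   FormatUnit(1536, "h")  ->  "1.5G"  (automatic scaling picks GiB)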


def AddAuthorizedKey(file_name, key):
  """Adds an SSH public key to an authorized_keys file.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  key_fields = key.split()

  f = open(file_name, 'a+')
  try:
    nl = True
    for line in f:
      # Ignore whitespace changes
      if line.split() == key_fields:
        break
      nl = line.endswith('\n')
    else:
      if not nl:
        f.write("\n")
      f.write(key.rstrip('\r\n'))
      f.write("\n")
      f.flush()
  finally:
    f.close()


def RemoveAuthorizedKey(file_name, key):
  """Removes an SSH public key from an authorized_keys file.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  key_fields = key.split()

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          # Ignore whitespace changes while comparing lines
          if line.split() != key_fields:
            out.write(line)

        out.flush()
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def SetEtcHostsEntry(file_name, ip, hostname, aliases):
  """Sets the name of an IP address and hostname in /etc/hosts.

  @type file_name: str
  @param file_name: path to the file to modify (usually C{/etc/hosts})
  @type ip: str
  @param ip: the IP address
  @type hostname: str
  @param hostname: the hostname to be added
  @type aliases: list
  @param aliases: the list of aliases to add for the hostname

  """
  # FIXME: use WriteFile + fn rather than duplicating its efforts
  # Ensure aliases are unique
  aliases = UniqueSequence([hostname] + aliases)[1:]

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          fields = line.split()
          if fields and not fields[0].startswith('#') and ip == fields[0]:
            continue
          out.write(line)

        out.write("%s\t%s" % (ip, hostname))
        if aliases:
          out.write(" %s" % ' '.join(aliases))
        out.write('\n')

        out.flush()
        os.fsync(out)
        os.chmod(tmpname, 0644)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def AddHostToEtcHosts(hostname):
  """Wrapper around SetEtcHostsEntry.

  @type hostname: str
  @param hostname: a hostname that will be resolved and added to
      L{constants.ETC_HOSTS}

  """
  hi = netutils.HostInfo(name=hostname)
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])


def RemoveEtcHostsEntry(file_name, hostname):
  """Removes a hostname from /etc/hosts.

  IP addresses without names are removed from the file.

  @type file_name: str
  @param file_name: path to the file to modify (usually C{/etc/hosts})
  @type hostname: str
  @param hostname: the hostname to be removed

  """
  # FIXME: use WriteFile + fn rather than duplicating its efforts
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          fields = line.split()
          if len(fields) > 1 and not fields[0].startswith('#'):
            names = fields[1:]
            if hostname in names:
              while hostname in names:
                names.remove(hostname)
              if names:
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
              continue

          out.write(line)

        out.flush()
        os.fsync(out)
        os.chmod(tmpname, 0644)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def RemoveHostFromEtcHosts(hostname):
  """Wrapper around RemoveEtcHostsEntry.

  @type hostname: str
  @param hostname: hostname that will be resolved and its
      full and short name will be removed from
      L{constants.ETC_HOSTS}

  """
  hi = netutils.HostInfo(name=hostname)
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())


def TimestampForFilename():
  """Returns the current time formatted for filenames.

  The format doesn't contain colons, as some shells and applications treat
  them as separators.

  """
  return time.strftime("%Y-%m-%d_%H_%M_%S")


def CreateBackup(file_name):
  """Creates a backup of a file.

  @type file_name: str
  @param file_name: file to be backed up
  @rtype: str
  @return: the path to the newly created backup
  @raise errors.ProgrammerError: for invalid file names

  """
  if not os.path.isfile(file_name):
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
                                 file_name)

  prefix = ("%s.backup-%s." %
            (os.path.basename(file_name), TimestampForFilename()))
  dir_name = os.path.dirname(file_name)

  fsrc = open(file_name, 'rb')
  try:
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
    fdst = os.fdopen(fd, 'wb')
    try:
      logging.debug("Backing up %s at %s", file_name, backup_name)
      shutil.copyfileobj(fsrc, fdst)
    finally:
      fdst.close()
  finally:
    fsrc.close()

  return backup_name


def ShellQuote(value):
  """Quotes shell argument according to POSIX.

  @type value: str
  @param value: the argument to be quoted
  @rtype: str
  @return: the quoted value

  """
  if _re_shell_unquoted.match(value):
    return value
  else:
    return "'%s'" % value.replace("'", "'\\''")


def ShellQuoteArgs(args):
  """Quotes a list of shell arguments.

  @type args: list
  @param args: list of arguments to be quoted
  @rtype: str
  @return: the quoted arguments concatenated with spaces

  """
  return ' '.join([ShellQuote(i) for i in args])

    
1550

    
1551
class ShellWriter:
1552
  """Helper class to write scripts with indentation.
1553

1554
  """
1555
  INDENT_STR = "  "
1556

    
1557
  def __init__(self, fh):
1558
    """Initializes this class.
1559

1560
    """
1561
    self._fh = fh
1562
    self._indent = 0
1563

    
1564
  def IncIndent(self):
1565
    """Increase indentation level by 1.
1566

1567
    """
1568
    self._indent += 1
1569

    
1570
  def DecIndent(self):
1571
    """Decrease indentation level by 1.
1572

1573
    """
1574
    assert self._indent > 0
1575
    self._indent -= 1
1576

    
1577
  def Write(self, txt, *args):
1578
    """Write line to output file.
1579

1580
    """
1581
    assert self._indent >= 0
1582

    
1583
    self._fh.write(self._indent * self.INDENT_STR)
1584

    
1585
    if args:
1586
      self._fh.write(txt % args)
1587
    else:
1588
      self._fh.write(txt)
1589

    
1590
    self._fh.write("\n")
1591

    
1592

    
1593
def ListVisibleFiles(path):
1594
  """Returns a list of visible files in a directory.
1595

1596
  @type path: str
1597
  @param path: the directory to enumerate
1598
  @rtype: list
1599
  @return: the list of all files not starting with a dot
1600
  @raise ProgrammerError: if L{path} is not an absolue and normalized path
1601

1602
  """
1603
  if not IsNormAbsPath(path):
1604
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
1605
                                 " absolute/normalized: '%s'" % path)
1606
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1607
  return files
1608

    
1609

    
1610
def GetHomeDir(user, default=None):
1611
  """Try to get the homedir of the given user.
1612

1613
  The user can be passed either as a string (denoting the name) or as
1614
  an integer (denoting the user id). If the user is not found, the
1615
  'default' argument is returned, which defaults to None.
1616

1617
  """
1618
  try:
1619
    if isinstance(user, basestring):
1620
      result = pwd.getpwnam(user)
1621
    elif isinstance(user, (int, long)):
1622
      result = pwd.getpwuid(user)
1623
    else:
1624
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1625
                                   type(user))
1626
  except KeyError:
1627
    return default
1628
  return result.pw_dir
1629

    
1630

    
1631
def NewUUID():
1632
  """Returns a random UUID.
1633

1634
  @note: This is a Linux-specific method as it uses the /proc
1635
      filesystem.
1636
  @rtype: str
1637

1638
  """
1639
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1640

    
1641

    
1642
def GenerateSecret(numbytes=20):
1643
  """Generates a random secret.
1644

1645
  This will generate a pseudo-random secret returning an hex string
1646
  (so that it can be used where an ASCII string is needed).
1647

1648
  @param numbytes: the number of bytes which will be represented by the returned
1649
      string (defaulting to 20, the length of a SHA1 hash)
1650
  @rtype: str
1651
  @return: an hex representation of the pseudo-random sequence
1652

1653
  """
1654
  return os.urandom(numbytes).encode('hex')
1655

    
1656

    
1657
def EnsureDirs(dirs):
1658
  """Make required directories, if they don't exist.
1659

1660
  @param dirs: list of tuples (dir_name, dir_mode)
1661
  @type dirs: list of (string, integer)
1662

1663
  """
1664
  for dir_name, dir_mode in dirs:
1665
    try:
1666
      os.mkdir(dir_name, dir_mode)
1667
    except EnvironmentError, err:
1668
      if err.errno != errno.EEXIST:
1669
        raise errors.GenericError("Cannot create needed directory"
1670
                                  " '%s': %s" % (dir_name, err))
1671
    try:
1672
      os.chmod(dir_name, dir_mode)
1673
    except EnvironmentError, err:
1674
      raise errors.GenericError("Cannot change directory permissions on"
1675
                                " '%s': %s" % (dir_name, err))
1676
    if not os.path.isdir(dir_name):
1677
      raise errors.GenericError("%s is not a directory" % dir_name)
1678

    
1679

    
1680
def ReadFile(file_name, size=-1):
1681
  """Reads a file.
1682

1683
  @type size: int
1684
  @param size: Read at most size bytes (if negative, entire file)
1685
  @rtype: str
1686
  @return: the (possibly partial) content of the file
1687

1688
  """
1689
  f = open(file_name, "r")
1690
  try:
1691
    return f.read(size)
1692
  finally:
1693
    f.close()
1694

    
1695

    
1696
def WriteFile(file_name, fn=None, data=None,
1697
              mode=None, uid=-1, gid=-1,
1698
              atime=None, mtime=None, close=True,
1699
              dry_run=False, backup=False,
1700
              prewrite=None, postwrite=None):
1701
  """(Over)write a file atomically.
1702

1703
  The file_name and either fn (a function taking one argument, the
1704
  file descriptor, and which should write the data to it) or data (the
1705
  contents of the file) must be passed. The other arguments are
1706
  optional and allow setting the file mode, owner and group, and the
1707
  mtime/atime of the file.
1708

1709
  If the function doesn't raise an exception, it has succeeded and the
1710
  target file has the new contents. If the function has raised an
1711
  exception, an existing target file should be unmodified and the
1712
  temporary file should be removed.
1713

1714
  @type file_name: str
1715
  @param file_name: the target filename
1716
  @type fn: callable
1717
  @param fn: content writing function, called with
1718
      file descriptor as parameter
1719
  @type data: str
1720
  @param data: contents of the file
1721
  @type mode: int
1722
  @param mode: file mode
1723
  @type uid: int
1724
  @param uid: the owner of the file
1725
  @type gid: int
1726
  @param gid: the group of the file
1727
  @type atime: int
1728
  @param atime: a custom access time to be set on the file
1729
  @type mtime: int
1730
  @param mtime: a custom modification time to be set on the file
1731
  @type close: boolean
1732
  @param close: whether to close file after writing it
1733
  @type prewrite: callable
1734
  @param prewrite: function to be called before writing content
1735
  @type postwrite: callable
1736
  @param postwrite: function to be called after writing content
1737

1738
  @rtype: None or int
1739
  @return: None if the 'close' parameter evaluates to True,
1740
      otherwise the file descriptor
1741

1742
  @raise errors.ProgrammerError: if any of the arguments are not valid
1743

1744
  """
1745
  if not os.path.isabs(file_name):
1746
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1747
                                 " absolute: '%s'" % file_name)
1748

    
1749
  if [fn, data].count(None) != 1:
1750
    raise errors.ProgrammerError("fn or data required")
1751

    
1752
  if [atime, mtime].count(None) == 1:
1753
    raise errors.ProgrammerError("Both atime and mtime must be either"
1754
                                 " set or None")
1755

    
1756
  if backup and not dry_run and os.path.isfile(file_name):
1757
    CreateBackup(file_name)
1758

    
1759
  dir_name, base_name = os.path.split(file_name)
1760
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1761
  do_remove = True
1762
  # here we need to make sure we remove the temp file, if any error
1763
  # leaves it in place
1764
  try:
1765
    if uid != -1 or gid != -1:
1766
      os.chown(new_name, uid, gid)
1767
    if mode:
1768
      os.chmod(new_name, mode)
1769
    if callable(prewrite):
1770
      prewrite(fd)
1771
    if data is not None:
1772
      os.write(fd, data)
1773
    else:
1774
      fn(fd)
1775
    if callable(postwrite):
1776
      postwrite(fd)
1777
    os.fsync(fd)
1778
    if atime is not None and mtime is not None:
1779
      os.utime(new_name, (atime, mtime))
1780
    if not dry_run:
1781
      os.rename(new_name, file_name)
1782
      do_remove = False
1783
  finally:
1784
    if close:
1785
      os.close(fd)
1786
      result = None
1787
    else:
1788
      result = fd
1789
    if do_remove:
1790
      RemoveFile(new_name)
1791

    
1792
  return result
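
# Illustrative usage sketch (not part of the original module; the path and
# contents below are made up):
#
#   WriteFile("/var/lib/ganeti/example.data", data="hello\n", mode=0600,
#             backup=True)
#
# writes to a temporary file in the target directory, fsyncs it and renames
# it over the destination, so readers never see a partially written file.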
1793

    
1794

    
1795
def ReadOneLineFile(file_name, strict=False):
1796
  """Return the first non-empty line from a file.
1797

1798
  @type strict: boolean
1799
  @param strict: if True, abort if the file has more than one
1800
      non-empty line
1801

1802
  """
1803
  file_lines = ReadFile(file_name).splitlines()
1804
  full_lines = filter(bool, file_lines)
1805
  if not file_lines or not full_lines:
1806
    raise errors.GenericError("No data in one-liner file %s" % file_name)
1807
  elif strict and len(full_lines) > 1:
1808
    raise errors.GenericError("Too many lines in one-liner file %s" %
1809
                              file_name)
1810
  return full_lines[0]
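
# Illustrative usage sketch (hypothetical file):
#
#   uuid = ReadOneLineFile("/var/lib/ganeti/example.uuid", strict=True)
#
# returns the single non-empty line without its newline and raises
# errors.GenericError if the file is empty or contains extra non-empty lines.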
1811

    
1812

    
1813
def FirstFree(seq, base=0):
1814
  """Returns the first non-existing integer from seq.
1815

1816
  The seq argument should be a sorted list of positive integers. The
1817
  first time the index of an element is smaller than the element
1818
  value, the index will be returned.
1819

1820
  The base argument is used to start at a different offset,
1821
  i.e. C{[3, 4, 6]} with I{base=3} will return 5.
1822

1823
  Example: C{[0, 1, 3]} will return I{2}.
1824

1825
  @type seq: sequence
1826
  @param seq: the sequence to be analyzed.
1827
  @type base: int
1828
  @param base: use this value as the base index of the sequence
1829
  @rtype: int
1830
  @return: the first non-used index in the sequence
1831

1832
  """
1833
  for idx, elem in enumerate(seq):
1834
    assert elem >= base, "Passed element is higher than base offset"
1835
    if elem > idx + base:
1836
      # idx is not used
1837
      return idx + base
1838
  return None
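
# Examples (illustrative only, not executed here):
#
#   FirstFree([0, 1, 3])          # -> 2
#   FirstFree([0, 1, 2])          # -> None (no hole in the sequence)
#   FirstFree([3, 4, 6], base=3)  # -> 5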
1839

    
1840

    
1841
def SingleWaitForFdCondition(fdobj, event, timeout):
1842
  """Waits for a condition to occur on the socket.
1843

1844
  Immediately returns at the first interruption.
1845

1846
  @type fdobj: integer or object supporting a fileno() method
1847
  @param fdobj: entity to wait for events on
1848
  @type event: integer
1849
  @param event: ORed condition (see select module)
1850
  @type timeout: float or None
1851
  @param timeout: Timeout in seconds
1852
  @rtype: int or None
1853
  @return: None for timeout, otherwise the occurred conditions
1854

1855
  """
1856
  check = (event | select.POLLPRI |
1857
           select.POLLNVAL | select.POLLHUP | select.POLLERR)
1858

    
1859
  if timeout is not None:
1860
    # Poller object expects milliseconds
1861
    timeout *= 1000
1862

    
1863
  poller = select.poll()
1864
  poller.register(fdobj, event)
1865
  try:
1866
    # TODO: If the main thread receives a signal and we have no timeout, we
1867
    # could wait forever. This should check a global "quit" flag or something
1868
    # every so often.
1869
    io_events = poller.poll(timeout)
1870
  except select.error, err:
1871
    if err[0] != errno.EINTR:
1872
      raise
1873
    io_events = []
1874
  if io_events and io_events[0][1] & check:
1875
    return io_events[0][1]
1876
  else:
1877
    return None
1878

    
1879

    
1880
class FdConditionWaiterHelper(object):
1881
  """Retry helper for WaitForFdCondition.
1882

1883
  This class contains the retried and wait functions that make sure
1884
  WaitForFdCondition can continue waiting until the timeout has actually
1885
  expired.
1886

1887
  """
1888

    
1889
  def __init__(self, timeout):
1890
    self.timeout = timeout
1891

    
1892
  def Poll(self, fdobj, event):
1893
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
1894
    if result is None:
1895
      raise RetryAgain()
1896
    else:
1897
      return result
1898

    
1899
  def UpdateTimeout(self, timeout):
1900
    self.timeout = timeout
1901

    
1902

    
1903
def WaitForFdCondition(fdobj, event, timeout):
1904
  """Waits for a condition to occur on the socket.
1905

1906
  Retries until the timeout is expired, even if interrupted.
1907

1908
  @type fdobj: integer or object supporting a fileno() method
1909
  @param fdobj: entity to wait for events on
1910
  @type event: integer
1911
  @param event: ORed condition (see select module)
1912
  @type timeout: float or None
1913
  @param timeout: Timeout in seconds
1914
  @rtype: int or None
1915
  @return: None for timeout, otherwise the occurred conditions
1916

1917
  """
1918
  if timeout is not None:
1919
    retrywaiter = FdConditionWaiterHelper(timeout)
1920
    try:
1921
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
1922
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
1923
    except RetryTimeout:
1924
      result = None
1925
  else:
1926
    result = None
1927
    while result is None:
1928
      result = SingleWaitForFdCondition(fdobj, event, timeout)
1929
  return result
1930

    
1931

    
1932
def UniqueSequence(seq):
1933
  """Returns a list with unique elements.
1934

1935
  Element order is preserved.
1936

1937
  @type seq: sequence
1938
  @param seq: the sequence with the source elements
1939
  @rtype: list
1940
  @return: list of unique elements from seq
1941

1942
  """
1943
  seen = set()
1944
  return [i for i in seq if i not in seen and not seen.add(i)]
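
# Example (illustrative):
#
#   UniqueSequence([2, 1, 2, 3, 1])  # -> [2, 1, 3], first occurrences kept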
1945

    
1946

    
1947
def NormalizeAndValidateMac(mac):
1948
  """Normalizes and check if a MAC address is valid.
1949

1950
  Checks whether the supplied MAC address is formally correct, only
1951
  accepts the colon-separated format. The MAC is normalized to all-lowercase.
1952

1953
  @type mac: str
1954
  @param mac: the MAC to be validated
1955
  @rtype: str
1956
  @return: returns the normalized and validated MAC.
1957

1958
  @raise errors.OpPrereqError: If the MAC isn't valid
1959

1960
  """
1961
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
1962
  if not mac_check.match(mac):
1963
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
1964
                               mac, errors.ECODE_INVAL)
1965

    
1966
  return mac.lower()
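
# Example (illustrative):
#
#   NormalizeAndValidateMac("AA:00:11:22:33:44")  # -> "aa:00:11:22:33:44"
#
# anything not in colon-separated form (e.g. "aa-00-11-22-33-44") raises
# errors.OpPrereqError.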
1967

    
1968

    
1969
def TestDelay(duration):
1970
  """Sleep for a fixed amount of time.
1971

1972
  @type duration: float
1973
  @param duration: the sleep duration
1974
  @rtype: tuple
1975
  @return: (False, error message) for a negative duration,
      (True, None) otherwise
1976

1977
  """
1978
  if duration < 0:
1979
    return False, "Invalid sleep duration"
1980
  time.sleep(duration)
1981
  return True, None
1982

    
1983

    
1984
def _CloseFDNoErr(fd, retries=5):
1985
  """Close a file descriptor ignoring errors.
1986

1987
  @type fd: int
1988
  @param fd: the file descriptor
1989
  @type retries: int
1990
  @param retries: how many retries to make, in case we get any
1991
      other error than EBADF
1992

1993
  """
1994
  try:
1995
    os.close(fd)
1996
  except OSError, err:
1997
    if err.errno != errno.EBADF:
1998
      if retries > 0:
1999
        _CloseFDNoErr(fd, retries - 1)
2000
    # else either it's closed already or we're out of retries, so we
2001
    # ignore this and go on
2002

    
2003

    
2004
def CloseFDs(noclose_fds=None):
2005
  """Close file descriptors.
2006

2007
  This closes all file descriptors above 2 (i.e. except
2008
  stdin/out/err).
2009

2010
  @type noclose_fds: list or None
2011
  @param noclose_fds: if given, it denotes a list of file descriptors
2012
      that should not be closed
2013

2014
  """
2015
  # Default maximum for the number of available file descriptors.
2016
  if 'SC_OPEN_MAX' in os.sysconf_names:
2017
    try:
2018
      MAXFD = os.sysconf('SC_OPEN_MAX')
2019
      if MAXFD < 0:
2020
        MAXFD = 1024
2021
    except OSError:
2022
      MAXFD = 1024
2023
  else:
2024
    MAXFD = 1024
2025
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
2026
  if (maxfd == resource.RLIM_INFINITY):
2027
    maxfd = MAXFD
2028

    
2029
  # Iterate through and close all file descriptors (except the standard ones)
2030
  for fd in range(3, maxfd):
2031
    if noclose_fds and fd in noclose_fds:
2032
      continue
2033
    _CloseFDNoErr(fd)
2034

    
2035

    
2036
def Mlockall(_ctypes=ctypes):
2037
  """Lock current process' virtual address space into RAM.
2038

2039
  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
2040
  see mlock(2) for more details. This function requires ctypes module.
2041

2042
  @raises errors.NoCtypesError: if ctypes module is not found
2043

2044
  """
2045
  if _ctypes is None:
2046
    raise errors.NoCtypesError()
2047

    
2048
  libc = _ctypes.cdll.LoadLibrary("libc.so.6")
2049
  if libc is None:
2050
    logging.error("Cannot set memory lock, ctypes cannot load libc")
2051
    return
2052

    
2053
  # Some older version of the ctypes module don't have built-in functionality
2054
  # to access the errno global variable, where function error codes are stored.
2055
  # By declaring this variable as a pointer to an integer we can then access
2056
  # its value correctly, should the mlockall call fail, in order to see what
2057
  # the actual error code was.
2058
  # pylint: disable-msg=W0212
2059
  libc.__errno_location.restype = _ctypes.POINTER(_ctypes.c_int)
2060

    
2061
  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
2062
    # pylint: disable-msg=W0212
2063
    logging.error("Cannot set memory lock: %s",
2064
                  os.strerror(libc.__errno_location().contents.value))
2065
    return
2066

    
2067
  logging.debug("Memory lock set")
2068

    
2069

    
2070
def Daemonize(logfile, run_uid, run_gid):
2071
  """Daemonize the current process.
2072

2073
  This detaches the current process from the controlling terminal and
2074
  runs it in the background as a daemon.
2075

2076
  @type logfile: str
2077
  @param logfile: the logfile to which we should redirect stdout/stderr
2078
  @type run_uid: int
2079
  @param run_uid: Run the child under this uid
2080
  @type run_gid: int
2081
  @param run_gid: Run the child under this gid
2082
  @rtype: int
2083
  @return: the value zero
2084

2085
  """
2086
  # pylint: disable-msg=W0212
2087
  # yes, we really want os._exit
2088
  UMASK = 077
2089
  WORKDIR = "/"
2090

    
2091
  # this might fail
2092
  pid = os.fork()
2093
  if (pid == 0):  # The first child.
2094
    os.setsid()
2095
    # FIXME: When removing again and moving to start-stop-daemon privilege drop
2096
    #        make sure to check for config permission and bail out when invoked
2097
    #        with wrong user.
2098
    os.setgid(run_gid)
2099
    os.setuid(run_uid)
2100
    # this might fail
2101
    pid = os.fork() # Fork a second child.
2102
    if (pid == 0):  # The second child.
2103
      os.chdir(WORKDIR)
2104
      os.umask(UMASK)
2105
    else:
2106
      # exit() or _exit()?  See below.
2107
      os._exit(0) # Exit parent (the first child) of the second child.
2108
  else:
2109
    os._exit(0) # Exit parent of the first child.
2110

    
2111
  for fd in range(3):
2112
    _CloseFDNoErr(fd)
2113
  i = os.open("/dev/null", os.O_RDONLY) # stdin
2114
  assert i == 0, "Can't close/reopen stdin"
2115
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
2116
  assert i == 1, "Can't close/reopen stdout"
2117
  # Duplicate standard output to standard error.
2118
  os.dup2(1, 2)
2119
  return 0
2120

    
2121

    
2122
def DaemonPidFileName(name):
2123
  """Compute a ganeti pid file absolute path
2124

2125
  @type name: str
2126
  @param name: the daemon name
2127
  @rtype: str
2128
  @return: the full path to the pidfile corresponding to the given
2129
      daemon name
2130

2131
  """
2132
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
2133

    
2134

    
2135
def EnsureDaemon(name):
2136
  """Check for and start daemon if not alive.
2137

2138
  """
2139
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
2140
  if result.failed:
2141
    logging.error("Can't start daemon '%s', failure %s, output: %s",
2142
                  name, result.fail_reason, result.output)
2143
    return False
2144

    
2145
  return True
2146

    
2147

    
2148
def StopDaemon(name):
2149
  """Stop daemon
2150

2151
  """
2152
  result = RunCmd([constants.DAEMON_UTIL, "stop", name])
2153
  if result.failed:
2154
    logging.error("Can't stop daemon '%s', failure %s, output: %s",
2155
                  name, result.fail_reason, result.output)
2156
    return False
2157

    
2158
  return True
2159

    
2160

    
2161
def WritePidFile(name):
2162
  """Write the current process pidfile.
2163

2164
  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}
2165

2166
  @type name: str
2167
  @param name: the daemon name to use
2168
  @raise errors.GenericError: if the pid file already exists and
2169
      points to a live process
2170

2171
  """
2172
  pid = os.getpid()
2173
  pidfilename = DaemonPidFileName(name)
2174
  if IsProcessAlive(ReadPidFile(pidfilename)):
2175
    raise errors.GenericError("%s contains a live process" % pidfilename)
2176

    
2177
  WriteFile(pidfilename, data="%d\n" % pid)
2178

    
2179

    
2180
def RemovePidFile(name):
2181
  """Remove the current process pidfile.
2182

2183
  Any errors are ignored.
2184

2185
  @type name: str
2186
  @param name: the daemon name used to derive the pidfile name
2187

2188
  """
2189
  pidfilename = DaemonPidFileName(name)
2190
  # TODO: we could check here that the file contains our pid
2191
  try:
2192
    RemoveFile(pidfilename)
2193
  except: # pylint: disable-msg=W0702
2194
    pass
2195

    
2196

    
2197
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
2198
                waitpid=False):
2199
  """Kill a process given by its pid.
2200

2201
  @type pid: int
2202
  @param pid: The PID to terminate.
2203
  @type signal_: int
2204
  @param signal_: The signal to send, by default SIGTERM
2205
  @type timeout: int
2206
  @param timeout: The timeout after which, if the process is still alive,
2207
                  a SIGKILL will be sent. If not positive, no such checking
2208
                  will be done
2209
  @type waitpid: boolean
2210
  @param waitpid: If true, we should waitpid on this process after
2211
      sending signals, since it's our own child and otherwise it
2212
      would remain as zombie
2213

2214
  """
2215
  def _helper(pid, signal_, wait):
2216
    """Simple helper to encapsulate the kill/waitpid sequence"""
2217
    if IgnoreProcessNotFound(os.kill, pid, signal_) and wait:
2218
      try:
2219
        os.waitpid(pid, os.WNOHANG)
2220
      except OSError:
2221
        pass
2222

    
2223
  if pid <= 0:
2224
    # kill with pid=0 == suicide
2225
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
2226

    
2227
  if not IsProcessAlive(pid):
2228
    return
2229

    
2230
  _helper(pid, signal_, waitpid)
2231

    
2232
  if timeout <= 0:
2233
    return
2234

    
2235
  def _CheckProcess():
2236
    if not IsProcessAlive(pid):
2237
      return
2238

    
2239
    try:
2240
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
2241
    except OSError:
2242
      raise RetryAgain()
2243

    
2244
    if result_pid > 0:
2245
      return
2246

    
2247
    raise RetryAgain()
2248

    
2249
  try:
2250
    # Wait up to $timeout seconds
2251
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
2252
  except RetryTimeout:
2253
    pass
2254

    
2255
  if IsProcessAlive(pid):
2256
    # Kill process if it's still alive
2257
    _helper(pid, signal.SIGKILL, waitpid)
2258

    
2259

    
2260
def FindFile(name, search_path, test=os.path.exists):
2261
  """Look for a filesystem object in a given path.
2262

2263
  This is an abstract method to search for filesystem object (files,
2264
  dirs) under a given search path.
2265

2266
  @type name: str
2267
  @param name: the name to look for
2268
  @type search_path: str
2269
  @param search_path: location to start at
2270
  @type test: callable
2271
  @param test: a function taking one argument that should return True
2272
      if the a given object is valid; the default value is
2273
      os.path.exists, causing only existing files to be returned
2274
  @rtype: str or None
2275
  @return: full path to the object if found, None otherwise
2276

2277
  """
2278
  # validate the filename mask
2279
  if constants.EXT_PLUGIN_MASK.match(name) is None:
2280
    logging.critical("Invalid value passed for external script name: '%s'",
2281
                     name)
2282
    return None
2283

    
2284
  for dir_name in search_path:
2285
    # FIXME: investigate switch to PathJoin
2286
    item_name = os.path.sep.join([dir_name, name])
2287
    # check the user test and that we're indeed resolving to the given
2288
    # basename
2289
    if test(item_name) and os.path.basename(item_name) == name:
2290
      return item_name
2291
  return None
2292

    
2293

    
2294
def CheckVolumeGroupSize(vglist, vgname, minsize):
2295
  """Checks if the volume group list is valid.
2296

2297
  The function will check if a given volume group is in the list of
2298
  volume groups and has a minimum size.
2299

2300
  @type vglist: dict
2301
  @param vglist: dictionary of volume group names and their size
2302
  @type vgname: str
2303
  @param vgname: the volume group we should check
2304
  @type minsize: int
2305
  @param minsize: the minimum size we accept
2306
  @rtype: None or str
2307
  @return: None for success, otherwise the error message
2308

2309
  """
2310
  vgsize = vglist.get(vgname, None)
2311
  if vgsize is None:
2312
    return "volume group '%s' missing" % vgname
2313
  elif vgsize < minsize:
2314
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
2315
            (vgname, minsize, vgsize))
2316
  return None
2317

    
2318

    
2319
def SplitTime(value):
2320
  """Splits time as floating point number into a tuple.
2321

2322
  @param value: Time in seconds
2323
  @type value: int or float
2324
  @return: Tuple containing (seconds, microseconds)
2325

2326
  """
2327
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
2328

    
2329
  assert 0 <= seconds, \
2330
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2331
  assert 0 <= microseconds <= 999999, \
2332
    "Microseconds must be 0-999999, but are %s" % microseconds
2333

    
2334
  return (int(seconds), int(microseconds))
2335

    
2336

    
2337
def MergeTime(timetuple):
2338
  """Merges a tuple into time as a floating point number.
2339

2340
  @param timetuple: Time as tuple, (seconds, microseconds)
2341
  @type timetuple: tuple
2342
  @return: Time as a floating point number expressed in seconds
2343

2344
  """
2345
  (seconds, microseconds) = timetuple
2346

    
2347
  assert 0 <= seconds, \
2348
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2349
  assert 0 <= microseconds <= 999999, \
2350
    "Microseconds must be 0-999999, but are %s" % microseconds
2351

    
2352
  return float(seconds) + (float(microseconds) * 0.000001)
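
# Round-trip example for SplitTime/MergeTime (illustrative):
#
#   SplitTime(1.5)          # -> (1, 500000)
#   MergeTime((1, 500000))  # -> 1.5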
2353

    
2354

    
2355
class LogFileHandler(logging.FileHandler):
2356
  """Log handler that doesn't fallback to stderr.
2357

2358
  When an error occurs while writing on the logfile, logging.FileHandler tries
2359
  to log on stderr. This doesn't work in ganeti since stderr is redirected to
2360
  the logfile. This class instead reports such errors to /dev/console.
2361

2362
  """
2363
  def __init__(self, filename, mode="a", encoding=None):
2364
    """Open the specified file and use it as the stream for logging.
2365

2366
    Also open /dev/console to report errors while logging.
2367

2368
    """
2369
    logging.FileHandler.__init__(self, filename, mode, encoding)
2370
    self.console = open(constants.DEV_CONSOLE, "a")
2371

    
2372
  def handleError(self, record): # pylint: disable-msg=C0103
2373
    """Handle errors which occur during an emit() call.
2374

2375
    Try to handle errors with FileHandler method, if it fails write to
2376
    /dev/console.
2377

2378
    """
2379
    try:
2380
      logging.FileHandler.handleError(self, record)
2381
    except Exception: # pylint: disable-msg=W0703
2382
      try:
2383
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
2384
      except Exception: # pylint: disable-msg=W0703
2385
        # Log handler tried everything it could, now just give up
2386
        pass
2387

    
2388

    
2389
def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
2390
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
2391
                 console_logging=False):
2392
  """Configures the logging module.
2393

2394
  @type logfile: str
2395
  @param logfile: the filename to which we should log
2396
  @type debug: integer
2397
  @param debug: if greater than zero, enable debug messages, otherwise
2398
      only those at C{INFO} and above level
2399
  @type stderr_logging: boolean
2400
  @param stderr_logging: whether we should also log to the standard error
2401
  @type program: str
2402
  @param program: the name under which we should log messages
2403
  @type multithreaded: boolean
2404
  @param multithreaded: if True, will add the thread name to the log file
2405
  @type syslog: string
2406
  @param syslog: one of 'no', 'yes', 'only':
2407
      - if no, syslog is not used
2408
      - if yes, syslog is used (in addition to file-logging)
2409
      - if only, only syslog is used
2410
  @type console_logging: boolean
2411
  @param console_logging: if True, will use a FileHandler which falls back to
2412
      the system console if logging fails
2413
  @raise EnvironmentError: if we can't open the log file and
2414
      syslog/stderr logging is disabled
2415

2416
  """
2417
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
2418
  sft = program + "[%(process)d]:"
2419
  if multithreaded:
2420
    fmt += "/%(threadName)s"
2421
    sft += " (%(threadName)s)"
2422
  if debug:
2423
    fmt += " %(module)s:%(lineno)s"
2424
    # no debug info for syslog loggers
2425
  fmt += " %(levelname)s %(message)s"
2426
  # yes, we do want the textual level, as remote syslog will probably
2427
  # lose the error level, and it's easier to grep for it
2428
  sft += " %(levelname)s %(message)s"
2429
  formatter = logging.Formatter(fmt)
2430
  sys_fmt = logging.Formatter(sft)
2431

    
2432
  root_logger = logging.getLogger("")
2433
  root_logger.setLevel(logging.NOTSET)
2434

    
2435
  # Remove all previously setup handlers
2436
  for handler in root_logger.handlers:
2437
    handler.close()
2438
    root_logger.removeHandler(handler)
2439

    
2440
  if stderr_logging:
2441
    stderr_handler = logging.StreamHandler()
2442
    stderr_handler.setFormatter(formatter)
2443
    if debug:
2444
      stderr_handler.setLevel(logging.NOTSET)
2445
    else:
2446
      stderr_handler.setLevel(logging.CRITICAL)
2447
    root_logger.addHandler(stderr_handler)
2448

    
2449
  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
2450
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
2451
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
2452
                                                    facility)
2453
    syslog_handler.setFormatter(sys_fmt)
2454
    # Never enable debug over syslog
2455
    syslog_handler.setLevel(logging.INFO)
2456
    root_logger.addHandler(syslog_handler)
2457

    
2458
  if syslog != constants.SYSLOG_ONLY:
2459
    # this can fail, if the logging directories are not setup or we have
2460
    # a permission problem; in this case, it's best to log but ignore
2461
    # the error if stderr_logging is True, and if false we re-raise the
2462
    # exception since otherwise we could run but without any logs at all
2463
    try:
2464
      if console_logging:
2465
        logfile_handler = LogFileHandler(logfile)
2466
      else:
2467
        logfile_handler = logging.FileHandler(logfile)
2468
      logfile_handler.setFormatter(formatter)
2469
      if debug:
2470
        logfile_handler.setLevel(logging.DEBUG)
2471
      else:
2472
        logfile_handler.setLevel(logging.INFO)
2473
      root_logger.addHandler(logfile_handler)
2474
    except EnvironmentError:
2475
      if stderr_logging or syslog == constants.SYSLOG_YES:
2476
        logging.exception("Failed to enable logging to file '%s'", logfile)
2477
      else:
2478
        # we need to re-raise the exception
2479
        raise
2480

    
2481

    
2482
def IsNormAbsPath(path):
2483
  """Check whether a path is absolute and also normalized
2484

2485
  This avoids accepting paths like /dir/../../other/path as valid.
2486

2487
  """
2488
  return os.path.normpath(path) == path and os.path.isabs(path)
2489

    
2490

    
2491
def PathJoin(*args):
2492
  """Safe-join a list of path components.
2493

2494
  Requirements:
2495
      - the first argument must be an absolute path
2496
      - no component in the path must have backtracking (e.g. /../),
2497
        since we check for normalization at the end
2498

2499
  @param args: the path components to be joined
2500
  @raise ValueError: for invalid paths
2501

2502
  """
2503
  # ensure we're having at least one path passed in
2504
  assert args
2505
  # ensure the first component is an absolute and normalized path name
2506
  root = args[0]
2507
  if not IsNormAbsPath(root):
2508
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
2509
  result = os.path.join(*args)
2510
  # ensure that the whole path is normalized
2511
  if not IsNormAbsPath(result):
2512
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
2513
  # check that we're still under the original prefix
2514
  prefix = os.path.commonprefix([root, result])
2515
  if prefix != root:
2516
    raise ValueError("Error: path joining resulted in different prefix"
2517
                     " (%s != %s)" % (prefix, root))
2518
  return result
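
# Examples (illustrative):
#
#   PathJoin("/var/lib/ganeti", "config.data")    # -> "/var/lib/ganeti/config.data"
#   PathJoin("/var/lib/ganeti", "../etc/passwd")  # raises ValueError, since the
#                                                 # joined path is not normalized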
2519

    
2520

    
2521
def TailFile(fname, lines=20):
2522
  """Return the last lines from a file.
2523

2524
  @note: this function will only read and parse the last 4KB of
2525
      the file; if the lines are very long, it could be that less
2526
      than the requested number of lines are returned
2527

2528
  @param fname: the file name
2529
  @type lines: int
2530
  @param lines: the (maximum) number of lines to return
2531

2532
  """
2533
  fd = open(fname, "r")
2534
  try:
2535
    fd.seek(0, 2)
2536
    pos = fd.tell()
2537
    pos = max(0, pos-4096)
2538
    fd.seek(pos, 0)
2539
    raw_data = fd.read()
2540
  finally:
2541
    fd.close()
2542

    
2543
  rows = raw_data.splitlines()
2544
  return rows[-lines:]
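
# Illustrative usage sketch (the log file path is hypothetical):
#
#   last = TailFile("/var/log/ganeti/node-daemon.log", lines=5)
#
# returns at most the last five lines found in the final 4 kB of the file.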
2545

    
2546

    
2547
def FormatTimestampWithTZ(secs):
2548
  """Formats a Unix timestamp with the local timezone.
2549

2550
  """
2551
  return time.strftime("%F %T %Z", time.localtime(secs))
2552

    
2553

    
2554
def _ParseAsn1Generalizedtime(value):
2555
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
2556

2557
  @type value: string
2558
  @param value: ASN1 GENERALIZEDTIME timestamp
2559

2560
  """
2561
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
2562
  if m:
2563
    # We have an offset
2564
    asn1time = m.group(1)
2565
    hours = int(m.group(2))
2566
    minutes = int(m.group(3))
2567
    utcoffset = (60 * hours) + minutes
2568
  else:
2569
    if not value.endswith("Z"):
2570
      raise ValueError("Missing timezone")
2571
    asn1time = value[:-1]
2572
    utcoffset = 0
2573

    
2574
  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
2575

    
2576
  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)
2577

    
2578
  return calendar.timegm(tt.utctimetuple())
2579

    
2580

    
2581
def GetX509CertValidity(cert):
2582
  """Returns the validity period of the certificate.
2583

2584
  @type cert: OpenSSL.crypto.X509
2585
  @param cert: X509 certificate object
2586

2587
  """
2588
  # The get_notBefore and get_notAfter functions are only supported in
2589
  # pyOpenSSL 0.7 and above.
2590
  try:
2591
    get_notbefore_fn = cert.get_notBefore
2592
  except AttributeError:
2593
    not_before = None
2594
  else:
2595
    not_before_asn1 = get_notbefore_fn()
2596

    
2597
    if not_before_asn1 is None:
2598
      not_before = None
2599
    else:
2600
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)
2601

    
2602
  try:
2603
    get_notafter_fn = cert.get_notAfter
2604
  except AttributeError:
2605
    not_after = None
2606
  else:
2607
    not_after_asn1 = get_notafter_fn()
2608

    
2609
    if not_after_asn1 is None:
2610
      not_after = None
2611
    else:
2612
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)
2613

    
2614
  return (not_before, not_after)
2615

    
2616

    
2617
def _VerifyCertificateInner(expired, not_before, not_after, now,
2618
                            warn_days, error_days):
2619
  """Verifies certificate validity.
2620

2621
  @type expired: bool
2622
  @param expired: Whether pyOpenSSL considers the certificate as expired
2623
  @type not_before: number or None
2624
  @param not_before: Unix timestamp before which certificate is not valid
2625
  @type not_after: number or None
2626
  @param not_after: Unix timestamp after which certificate is invalid
2627
  @type now: number
2628
  @param now: Current time as Unix timestamp
2629
  @type warn_days: number or None
2630
  @param warn_days: How many days before expiration a warning should be reported
2631
  @type error_days: number or None
2632
  @param error_days: How many days before expiration an error should be reported
2633

2634
  """
2635
  if expired:
2636
    msg = "Certificate is expired"
2637

    
2638
    if not_before is not None and not_after is not None:
2639
      msg += (" (valid from %s to %s)" %
2640
              (FormatTimestampWithTZ(not_before),
2641
               FormatTimestampWithTZ(not_after)))
2642
    elif not_before is not None:
2643
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
2644
    elif not_after is not None:
2645
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)
2646

    
2647
    return (CERT_ERROR, msg)
2648

    
2649
  elif not_before is not None and not_before > now:
2650
    return (CERT_WARNING,
2651
            "Certificate not yet valid (valid from %s)" %
2652
            FormatTimestampWithTZ(not_before))
2653

    
2654
  elif not_after is not None:
2655
    remaining_days = int((not_after - now) / (24 * 3600))
2656

    
2657
    msg = "Certificate expires in about %d days" % remaining_days
2658

    
2659
    if error_days is not None and remaining_days <= error_days:
2660
      return (CERT_ERROR, msg)
2661

    
2662
    if warn_days is not None and remaining_days <= warn_days:
2663
      return (CERT_WARNING, msg)
2664

    
2665
  return (None, None)
2666

    
2667

    
2668
def VerifyX509Certificate(cert, warn_days, error_days):
2669
  """Verifies a certificate for LUVerifyCluster.
2670

2671
  @type cert: OpenSSL.crypto.X509
2672
  @param cert: X509 certificate object
2673
  @type warn_days: number or None
2674
  @param warn_days: How many days before expiration a warning should be reported
2675
  @type error_days: number or None
2676
  @param error_days: How many days before expiration an error should be reported
2677

2678
  """
2679
  # Depending on the pyOpenSSL version, this can just return (None, None)
2680
  (not_before, not_after) = GetX509CertValidity(cert)
2681

    
2682
  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
2683
                                 time.time(), warn_days, error_days)
2684

    
2685

    
2686
def SignX509Certificate(cert, key, salt):
2687
  """Sign a X509 certificate.
2688

2689
  An RFC822-like signature header is added in front of the certificate.
2690

2691
  @type cert: OpenSSL.crypto.X509
2692
  @param cert: X509 certificate object
2693
  @type key: string
2694
  @param key: Key for HMAC
2695
  @type salt: string
2696
  @param salt: Salt for HMAC
2697
  @rtype: string
2698
  @return: Serialized and signed certificate in PEM format
2699

2700
  """
2701
  if not VALID_X509_SIGNATURE_SALT.match(salt):
2702
    raise errors.GenericError("Invalid salt: %r" % salt)
2703

    
2704
  # Dumping as PEM here ensures the certificate is in a sane format
2705
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2706

    
2707
  return ("%s: %s/%s\n\n%s" %
2708
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
2709
           Sha1Hmac(key, cert_pem, salt=salt),
2710
           cert_pem))
2711

    
2712

    
2713
def _ExtractX509CertificateSignature(cert_pem):
2714
  """Helper function to extract signature from X509 certificate.
2715

2716
  """
2717
  # Extract signature from original PEM data
2718
  for line in cert_pem.splitlines():
2719
    if line.startswith("---"):
2720
      break
2721

    
2722
    m = X509_SIGNATURE.match(line.strip())
2723
    if m:
2724
      return (m.group("salt"), m.group("sign"))
2725

    
2726
  raise errors.GenericError("X509 certificate signature is missing")
2727

    
2728

    
2729
def LoadSignedX509Certificate(cert_pem, key):
2730
  """Verifies a signed X509 certificate.
2731

2732
  @type cert_pem: string
2733
  @param cert_pem: Certificate in PEM format and with signature header
2734
  @type key: string
2735
  @param key: Key for HMAC
2736
  @rtype: tuple; (OpenSSL.crypto.X509, string)
2737
  @return: X509 certificate object and salt
2738

2739
  """
2740
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)
2741

    
2742
  # Load certificate
2743
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
2744

    
2745
  # Dump again to ensure it's in a sane format
2746
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2747

    
2748
  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
2749
    raise errors.GenericError("X509 certificate signature is invalid")
2750

    
2751
  return (cert, salt)
2752

    
2753

    
2754
def Sha1Hmac(key, text, salt=None):
2755
  """Calculates the HMAC-SHA1 digest of a text.
2756

2757
  HMAC is defined in RFC2104.
2758

2759
  @type key: string
2760
  @param key: Secret key
2761
  @type text: string
2762

2763
  """
2764
  if salt:
2765
    salted_text = salt + text
2766
  else:
2767
    salted_text = text
2768

    
2769
  return hmac.new(key, salted_text, compat.sha1).hexdigest()
2770

    
2771

    
2772
def VerifySha1Hmac(key, text, digest, salt=None):
2773
  """Verifies the HMAC-SHA1 digest of a text.
2774

2775
  HMAC is defined in RFC2104.
2776

2777
  @type key: string
2778
  @param key: Secret key
2779
  @type text: string
2780
  @type digest: string
2781
  @param digest: Expected digest
2782
  @rtype: bool
2783
  @return: Whether HMAC-SHA1 digest matches
2784

2785
  """
2786
  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()
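
# Illustrative round trip for Sha1Hmac/VerifySha1Hmac (key, text and salt
# below are made up):
#
#   digest = Sha1Hmac("secret-key", "payload", salt="abc123")
#   VerifySha1Hmac("secret-key", "payload", digest, salt="abc123")  # -> True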
2787

    
2788

    
2789
def SafeEncode(text):
2790
  """Return a 'safe' version of a source string.
2791

2792
  This function mangles the input string and returns a version that
2793
  should be safe to display/encode as ASCII. To this end, we first
2794
  convert it to ASCII using the 'backslashreplace' encoding which
2795
  should get rid of any non-ASCII chars, and then we process it
2796
  through a loop copied from the string repr sources in Python; we
2797
  don't use string_escape anymore since that escapes single quotes and
2798
  backslashes too, and that is too much; and that escaping is not
2799
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
2800

2801
  @type text: str or unicode
2802
  @param text: input data
2803
  @rtype: str
2804
  @return: a safe version of text
2805

2806
  """
2807
  if isinstance(text, unicode):
2808
    # only if unicode; if str already, we handle it below
2809
    text = text.encode('ascii', 'backslashreplace')
2810
  resu = ""
2811
  for char in text:
2812
    c = ord(char)
2813
    if char  == '\t':
2814
      resu += r'\t'
2815
    elif char == '\n':
2816
      resu += r'\n'
2817
    elif char == '\r':
2818
      resu += r'\r'
2819
    elif c < 32 or c >= 127: # non-printable
2820
      resu += "\\x%02x" % (c & 0xff)
2821
    else:
2822
      resu += char
2823
  return resu
2824

    
2825

    
2826
def UnescapeAndSplit(text, sep=","):
2827
  """Split and unescape a string based on a given separator.
2828

2829
  This function splits a string based on a separator where the
2830
  separator itself can be escaped in order to be part of one of the
2831
  elements. The escaping rules are (assuming comma is the
2832
  separator):
2833
    - a plain , separates the elements
2834
    - a sequence \\\\, (double backslash plus comma) is handled as a
2835
      backslash plus a separator comma
2836
    - a sequence \, (backslash plus comma) is handled as a
2837
      non-separator comma
2838

2839
  @type text: string
2840
  @param text: the string to split
2841
  @type sep: string
2842
  @param sep: the separator
2843
  @rtype: list
2844
  @return: a list of strings
2845

2846
  """
2847
  # we split the list by sep (with no escaping at this stage)
2848
  slist = text.split(sep)
2849
  # next, we revisit the elements and if any of them ended with an odd
2850
  # number of backslashes, then we join it with the next
2851
  rlist = []
2852
  while slist:
2853
    e1 = slist.pop(0)
2854
    if e1.endswith("\\"):
2855
      num_b = len(e1) - len(e1.rstrip("\\"))
2856
      if num_b % 2 == 1:
2857
        e2 = slist.pop(0)
2858
        # here the backslashes remain (all), and will be reduced in
2859
        # the next step
2860
        rlist.append(e1 + sep + e2)
2861
        continue
2862
    rlist.append(e1)
2863
  # finally, replace backslash-something with something
2864
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
2865
  return rlist
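
# Examples (illustrative):
#
#   UnescapeAndSplit("a,b,c")     # -> ["a", "b", "c"]
#   UnescapeAndSplit(r"a\,b,c")   # -> ["a,b", "c"]       (escaped separator)
#   UnescapeAndSplit(r"a\\,b,c")  # -> ["a\\", "b", "c"]  (literal backslash)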
2866

    
2867

    
2868
def CommaJoin(names):
2869
  """Nicely join a set of identifiers.
2870

2871
  @param names: set, list or tuple
2872
  @return: a string with the formatted results
2873

2874
  """
2875
  return ", ".join([str(val) for val in names])
2876

    
2877

    
2878
def BytesToMebibyte(value):
2879
  """Converts bytes to mebibytes.
2880

2881
  @type value: int
2882
  @param value: Value in bytes
2883
  @rtype: int
2884
  @return: Value in mebibytes
2885

2886
  """
2887
  return int(round(value / (1024.0 * 1024.0), 0))
2888

    
2889

    
2890
def CalculateDirectorySize(path):
2891
  """Calculates the size of a directory recursively.
2892

2893
  @type path: string
2894
  @param path: Path to directory
2895
  @rtype: int
2896
  @return: Size in mebibytes
2897

2898
  """
2899
  size = 0
2900

    
2901
  for (curpath, _, files) in os.walk(path):
2902
    for filename in files:
2903
      st = os.lstat(PathJoin(curpath, filename))
2904
      size += st.st_size
2905

    
2906
  return BytesToMebibyte(size)
2907

    
2908

    
2909
def GetMounts(filename=constants.PROC_MOUNTS):
2910
  """Returns the list of mounted filesystems.
2911

2912
  This function is Linux-specific.
2913

2914
  @param filename: path of mounts file (/proc/mounts by default)
2915
  @rtype: list of tuples
2916
  @return: list of mount entries (device, mountpoint, fstype, options)
2917

2918
  """
2919
  # TODO(iustin): investigate non-Linux options (e.g. via mount output)
2920
  data = []
2921
  mountlines = ReadFile(filename).splitlines()
2922
  for line in mountlines:
2923
    device, mountpoint, fstype, options, _ = line.split(None, 4)
2924
    data.append((device, mountpoint, fstype, options))
2925

    
2926
  return data
2927

    
2928

    
2929
def GetFilesystemStats(path):
2930
  """Returns the total and free space on a filesystem.
2931

2932
  @type path: string
2933
  @param path: Path on filesystem to be examined
2934
  @rtype: tuple of (int, int)
2935
  @return: tuple of (Total space, Free space) in mebibytes
2936

2937
  """
2938
  st = os.statvfs(path)
2939

    
2940
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
2941
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
2942
  return (tsize, fsize)
2943

    
2944

    
2945
def RunInSeparateProcess(fn, *args):
2946
  """Runs a function in a separate process.
2947

2948
  Note: Only boolean return values are supported.
2949

2950
  @type fn: callable
2951
  @param fn: Function to be called
2952
  @rtype: bool
2953
  @return: Function's result
2954

2955
  """
2956
  pid = os.fork()
2957
  if pid == 0:
2958
    # Child process
2959
    try:
2960
      # In case the function uses temporary files
2961
      ResetTempfileModule()
2962

    
2963
      # Call function
2964
      result = int(bool(fn(*args)))
2965
      assert result in (0, 1)
2966
    except: # pylint: disable-msg=W0702
2967
      logging.exception("Error while calling function in separate process")
2968
      # 0 and 1 are reserved for the return value
2969
      result = 33
2970

    
2971
    os._exit(result) # pylint: disable-msg=W0212
2972

    
2973
  # Parent process
2974

    
2975
  # Avoid zombies and check exit code
2976
  (_, status) = os.waitpid(pid, 0)
2977

    
2978
  if os.WIFSIGNALED(status):
2979
    exitcode = None
2980
    signum = os.WTERMSIG(status)
2981
  else:
2982
    exitcode = os.WEXITSTATUS(status)
2983
    signum = None
2984

    
2985
  if not (exitcode in (0, 1) and signum is None):
2986
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
2987
                              (exitcode, signum))
2988

    
2989
  return bool(exitcode)
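
# Illustrative usage sketch; _ProbeSomething is a hypothetical boolean check
# that might crash or leak state, hence running it in a child process:
#
#   if RunInSeparateProcess(_ProbeSomething, "/dev/example"):
#     logging.info("Probe succeeded")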
2990

    
2991

    
2992
def IgnoreProcessNotFound(fn, *args, **kwargs):
2993
  """Ignores ESRCH when calling a process-related function.
2994

2995
  ESRCH is raised when a process is not found.
2996

2997
  @rtype: bool
2998
  @return: Whether process was found
2999

3000
  """
3001
  try:
3002
    fn(*args, **kwargs)
3003
  except EnvironmentError, err:
3004
    # Ignore ESRCH
3005
    if err.errno == errno.ESRCH:
3006
      return False
3007
    raise
3008

    
3009
  return True
3010

    
3011

    
3012
def IgnoreSignals(fn, *args, **kwargs):
3013
  """Tries to call a function ignoring failures due to EINTR.
3014

3015
  """
3016
  try:
3017
    return fn(*args, **kwargs)
3018
  except EnvironmentError, err:
3019
    if err.errno == errno.EINTR:
3020
      return None
3021
    else:
3022
      raise
3023
  except (select.error, socket.error), err:
3024
    # In python 2.6 and above select.error is an IOError, so it's handled
3025
    # above, in 2.5 and below it's not, and it's handled here.
3026
    if err.args and err.args[0] == errno.EINTR:
3027
      return None
3028
    else:
3029
      raise
3030

    
3031

    
3032
def LockFile(fd):
3033
  """Locks a file using POSIX locks.
3034

3035
  @type fd: int
3036
  @param fd: the file descriptor we need to lock
3037

3038
  """
3039
  try:
3040
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
3041
  except IOError, err:
3042
    if err.errno == errno.EAGAIN:
3043
      raise errors.LockError("File already locked")
3044
    raise
3045

    
3046

    
3047
def FormatTime(val):
3048
  """Formats a time value.
3049

3050
  @type val: float or None
3051
  @param val: the timestamp as returned by time.time()
3052
  @return: a string value or N/A if we don't have a valid timestamp
3053

3054
  """
3055
  if val is None or not isinstance(val, (int, float)):
3056
    return "N/A"
3057
  # these two format codes work on Linux, but they are not guaranteed on all
3058
  # platforms
3059
  return time.strftime("%F %T", time.localtime(val))
3060

    
3061

    
3062
def FormatSeconds(secs):
3063
  """Formats seconds for easier reading.
3064

3065
  @type secs: number
3066
  @param secs: Number of seconds
3067
  @rtype: string
3068
  @return: Formatted seconds (e.g. "2d 9h 19m 49s")
3069

3070
  """
3071
  parts = []
3072

    
3073
  secs = round(secs, 0)
3074

    
3075
  if secs > 0:
3076
    # Negative values would be a bit tricky
3077
    for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]:
3078
      (complete, secs) = divmod(secs, one)
3079
      if complete or parts:
3080
        parts.append("%d%s" % (complete, unit))
3081

    
3082
  parts.append("%ds" % secs)
3083

    
3084
  return " ".join(parts)
3085

    
3086

    
3087
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
3088
  """Reads the watcher pause file.
3089

3090
  @type filename: string
3091
  @param filename: Path to watcher pause file
3092
  @type now: None, float or int
3093
  @param now: Current time as Unix timestamp
3094
  @type remove_after: int
3095
  @param remove_after: Remove watcher pause file after specified amount of
3096
    seconds past the pause end time
3097

3098
  """
3099
  if now is None:
3100
    now = time.time()
3101

    
3102
  try:
3103
    value = ReadFile(filename)
3104
  except IOError, err:
3105
    if err.errno != errno.ENOENT:
3106
      raise
3107
    value = None
3108

    
3109
  if value is not None:
3110
    try:
3111
      value = int(value)
3112
    except ValueError:
3113
      logging.warning(("Watcher pause file (%s) contains invalid value,"
3114
                       " removing it"), filename)
3115
      RemoveFile(filename)
3116
      value = None
3117

    
3118
    if value is not None:
3119
      # Remove file if it's outdated
3120
      if now > (value + remove_after):
3121
        RemoveFile(filename)
3122
        value = None
3123

    
3124
      elif now > value:
3125
        value = None
3126

    
3127
  return value
3128

    
3129

    
3130
class RetryTimeout(Exception):
3131
  """Retry loop timed out.
3132

3133
  Any arguments which were passed by the retried function to RetryAgain will be
3134
  preserved in RetryTimeout, if it is raised. If such an argument was an
3135
  exception, the RaiseInner helper method will re-raise it.
3136

3137
  """
3138
  def RaiseInner(self):
3139
    if self.args and isinstance(self.args[0], Exception):
3140
      raise self.args[0]
3141
    else:
3142
      raise RetryTimeout(*self.args)
3143

    
3144

    
3145
class RetryAgain(Exception):
3146
  """Retry again.
3147

3148
  Any arguments passed to RetryAgain will be preserved, if a timeout occurs, as
3149
  arguments to RetryTimeout. If an exception is passed, the RaiseInner() method
3150
  of the resulting RetryTimeout exception can be used to re-raise it.
3151

3152
  """
3153

    
3154

    
3155
class _RetryDelayCalculator(object):
3156
  """Calculator for increasing delays.
3157

3158
  """
3159
  __slots__ = [
3160
    "_factor",
3161
    "_limit",
3162
    "_next",
3163
    "_start",
3164
    ]
3165

    
3166
  def __init__(self, start, factor, limit):
3167
    """Initializes this class.
3168

3169
    @type start: float
3170
    @param start: Initial delay
3171
    @type factor: float
3172
    @param factor: Factor for delay increase
3173
    @type limit: float or None
3174
    @param limit: Upper limit for delay or None for no limit
3175

3176
    """
3177
    assert start > 0.0
3178
    assert factor >= 1.0
3179
    assert limit is None or limit >= 0.0
3180

    
3181
    self._start = start
3182
    self._factor = factor
3183
    self._limit = limit
3184

    
3185
    self._next = start
3186

    
3187
  def __call__(self):
3188
    """Returns current delay and calculates the next one.
3189

3190
    """
3191
    current = self._next
3192

    
3193
    # Update for next run
3194
    if self._limit is None or self._next < self._limit:
3195
      self._next = min(self._limit, self._next * self._factor)
3196

    
3197
    return current
3198

    
3199

    
3200
#: Special delay to specify whole remaining timeout
3201
RETRY_REMAINING_TIME = object()
3202

    
3203

    
3204
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
3205
          _time_fn=time.time):
3206
  """Call a function repeatedly until it succeeds.
3207

3208
  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
3209
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
3210
  total of C{timeout} seconds, this function throws L{RetryTimeout}.
3211

3212
  C{delay} can be one of the following:
3213
    - callable returning the delay length as a float
3214
    - Tuple of (start, factor, limit)
3215
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
3216
      useful when overriding L{wait_fn} to wait for an external event)
3217
    - A static delay as a number (int or float)
3218

3219
  @type fn: callable
3220
  @param fn: Function to be called
3221
  @param delay: Either a callable (returning the delay), a tuple of (start,
3222
                factor, limit) (see L{_RetryDelayCalculator}),
3223
                L{RETRY_REMAINING_TIME} or a number (int or float)
3224
  @type timeout: float
3225
  @param timeout: Total timeout
3226
  @type wait_fn: callable
3227
  @param wait_fn: Waiting function
3228
  @return: Return value of function
3229

3230
  """
3231
  assert callable(fn)
3232
  assert callable(wait_fn)
3233
  assert callable(_time_fn)
3234

    
3235
  if args is None:
3236
    args = []
3237

    
3238
  end_time = _time_fn() + timeout
3239

    
3240
  if callable(delay):
3241
    # External function to calculate delay
3242
    calc_delay = delay
3243

    
3244
  elif isinstance(delay, (tuple, list)):
3245
    # Increasing delay with optional upper boundary
3246
    (start, factor, limit) = delay
3247
    calc_delay = _RetryDelayCalculator(start, factor, limit)
3248

    
3249
  elif delay is RETRY_REMAINING_TIME:
3250
    # Always use the remaining time
3251
    calc_delay = None
3252

    
3253
  else:
3254
    # Static delay
3255
    calc_delay = lambda: delay
3256

    
3257
  assert calc_delay is None or callable(calc_delay)
3258

    
3259
  while True:
3260
    retry_args = []
3261
    try:
3262
      # pylint: disable-msg=W0142
3263
      return fn(*args)
3264
    except RetryAgain, err:
3265
      retry_args = err.args
3266
    except RetryTimeout:
3267
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
3268
                                   " handle RetryTimeout")
3269

    
3270
    remaining_time = end_time - _time_fn()
3271

    
3272
    if remaining_time < 0.0:
3273
      # pylint: disable-msg=W0142
3274
      raise RetryTimeout(*retry_args)
3275

    
3276
    assert remaining_time >= 0.0
3277

    
3278
    if calc_delay is None:
3279
      wait_fn(remaining_time)
3280
    else:
3281
      current_delay = calc_delay()
3282
      if current_delay > 0.0:
3283
        wait_fn(current_delay)
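
# Illustrative usage sketch for Retry: poll a hypothetical _IsReady() helper,
# starting at 0.1s and backing off by a factor of 1.5 up to 2s between
# attempts, for at most 30 seconds overall:
#
#   def _CheckReady():
#     if not _IsReady():
#       raise RetryAgain()
#
#   try:
#     Retry(_CheckReady, (0.1, 1.5, 2.0), 30.0)
#   except RetryTimeout:
#     logging.error("Still not ready after 30 seconds")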
3284

    
3285

    
3286
def GetClosedTempfile(*args, **kwargs):
3287
  """Creates a temporary file and returns its path.
3288

3289
  """
3290
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
3291
  _CloseFDNoErr(fd)
3292
  return path
3293

    
3294

    
3295
def GenerateSelfSignedX509Cert(common_name, validity):
3296
  """Generates a self-signed X509 certificate.
3297

3298
  @type common_name: string
3299
  @param common_name: commonName value
3300
  @type validity: int
3301
  @param validity: Validity for certificate in seconds
3302

3303
  """
3304
  # Create private and public key
3305
  key = OpenSSL.crypto.PKey()
3306
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)
3307

    
3308
  # Create self-signed certificate
3309
  cert = OpenSSL.crypto.X509()
3310
  if common_name:
3311
    cert.get_subject().CN = common_name
3312
  cert.set_serial_number(1)
3313
  cert.gmtime_adj_notBefore(0)
3314
  cert.gmtime_adj_notAfter(validity)
3315
  cert.set_issuer(cert.get_subject())
3316
  cert.set_pubkey(key)
3317
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)
3318

    
3319
  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
3320
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
3321

    
3322
  return (key_pem, cert_pem)
3323

    
3324

    
3325
def GenerateSelfSignedSslCert(filename, validity=(5 * 365)):
3326
  """Legacy function to generate self-signed X509 certificate.
3327

3328
  """
3329
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(None,
3330
                                                   validity * 24 * 60 * 60)
3331

    
3332
  WriteFile(filename, mode=0400, data=key_pem + cert_pem)
3333

    
3334

    
3335
class FileLock(object):
3336
  """Utility class for file locks.
3337

3338
  """
3339
  def __init__(self, fd, filename):
3340
    """Constructor for FileLock.
3341

3342
    @type fd: file
3343
    @param fd: File object
3344
    @type filename: str
3345
    @param filename: Path of the file opened at I{fd}
3346

3347
    """
3348
    self.fd = fd
3349
    self.filename = filename
3350

    
3351
  @classmethod
3352
  def Open(cls, filename):
3353
    """Creates and opens a file to be used as a file-based lock.
3354

3355
    @type filename: string
3356
    @param filename: path to the file to be locked

    """
    # Using "os.open" is necessary to allow both opening existing file
    # read/write and creating if not existing. Vanilla "open" will truncate an
    # existing file -or- allow creating if not existing.
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
               filename)

  def __del__(self):
    self.Close()

  def Close(self):
    """Close the file and release the lock.

    """
    if hasattr(self, "fd") and self.fd:
      self.fd.close()
      self.fd = None

  def _flock(self, flag, blocking, timeout, errmsg):
    """Wrapper for fcntl.flock.

    @type flag: int
    @param flag: operation flag
    @type blocking: bool
    @param blocking: whether the operation should be done in blocking mode.
    @type timeout: None or float
    @param timeout: for how long the operation should be retried (implies
                    non-blocking mode).
    @type errmsg: string
    @param errmsg: error message in case operation fails.

    """
    assert self.fd, "Lock was closed"
    assert timeout is None or timeout >= 0, \
      "If specified, timeout must be non-negative"
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can unlock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the unlock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)
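
# Illustrative sketch of typical FileLock usage, kept as a comment so the
# module stays import-safe; the lock path below is an arbitrary example, not
# a path used by Ganeti:
#
#   lock = FileLock.Open("/tmp/example.lock")
#   try:
#     # Wait up to 10 seconds for an exclusive lock, then do the work
#     lock.Exclusive(blocking=True, timeout=10)
#     pass  # ... critical section ...
#     lock.Unlock()
#   finally:
#     lock.Close()
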
class LineSplitter:
  """Splits data chunks into lines separated by newline.

  Instances provide a file-like interface.

  """
  def __init__(self, line_fn, *args):
    """Initializes this class.

    @type line_fn: callable
    @param line_fn: Function called for each line, first parameter is line
    @param args: Extra arguments for L{line_fn}

    """
    assert callable(line_fn)

    if args:
      # Python 2.4 doesn't have functools.partial yet
      self._line_fn = \
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
    else:
      self._line_fn = line_fn

    self._lines = collections.deque()
    self._buffer = ""

  def write(self, data):
    parts = (self._buffer + data).split("\n")
    self._buffer = parts.pop()
    self._lines.extend(parts)

  def flush(self):
    while self._lines:
      self._line_fn(self._lines.popleft().rstrip("\r\n"))

  def close(self):
    self.flush()
    if self._buffer:
      self._line_fn(self._buffer)
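
# Illustrative sketch of feeding chunked data through LineSplitter (comment
# only; the "collected" list is just an example sink, not part of the API):
#
#   collected = []
#   splitter = LineSplitter(collected.append)
#   splitter.write("first line\nsecond ")
#   splitter.flush()   # collected == ["first line"]
#   splitter.write("line\n")
#   splitter.close()   # collected == ["first line", "second line"]
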
def SignalHandled(signums):
  """Signal handling decorator.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      assert 'signal_handlers' not in kwargs or \
             kwargs['signal_handlers'] is None or \
             isinstance(kwargs['signal_handlers'], dict), \
             "Wrong signal_handlers parameter in original function call"
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
        signal_handlers = kwargs['signal_handlers']
      else:
        signal_handlers = {}
        kwargs['signal_handlers'] = signal_handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          signal_handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        sighandler.Reset()
    return sig_function
  return wrap
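
# Illustrative sketch of stacking the decorator (comment only; the function
# name and body are made up for the example):
#
#   @SignalHandled([signal.SIGTERM])
#   @SignalHandled([signal.SIGINT])
#   def _Example(signal_handlers=None):
#     assert isinstance(signal_handlers, dict)
#     if signal_handlers[signal.SIGTERM].called:
#       pass  # ... react to SIGTERM ...
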
class SignalWakeupFd(object):
  try:
    # This is only supported in Python 2.5 and above (some distributions
    # backported it to Python 2.4)
    _set_wakeup_fd_fn = signal.set_wakeup_fd
  except AttributeError:
    # Not supported
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
      return -1
  else:
    def _SetWakeupFd(self, fd):
      return self._set_wakeup_fd_fn(fd)

  def __init__(self):
    """Initializes this class.

    """
    (read_fd, write_fd) = os.pipe()

    # Once these succeed, the file descriptors will be closed automatically.
    # Buffer size 0 is important, otherwise .read() with a specified length
    # might buffer data and the file descriptors won't be marked readable.
    self._read_fh = os.fdopen(read_fd, "r", 0)
    self._write_fh = os.fdopen(write_fd, "w", 0)

    self._previous = self._SetWakeupFd(self._write_fh.fileno())

    # Utility functions
    self.fileno = self._read_fh.fileno
    self.read = self._read_fh.read

  def Reset(self):
    """Restores the previous wakeup file descriptor.

    """
    if hasattr(self, "_previous") and self._previous is not None:
      self._SetWakeupFd(self._previous)
      self._previous = None

  def Notify(self):
    """Notifies the wakeup file descriptor.

    """
    self._write_fh.write("\0")

  def __del__(self):
    """Called before object deletion.

    """
    self.Reset()
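
# Illustrative sketch of waiting on the wakeup file descriptor (comment only);
# a signal handler installed elsewhere is assumed to call Notify():
#
#   wakeup = SignalWakeupFd()
#   try:
#     if select.select([wakeup], [], [], 30.0)[0]:
#       wakeup.read(1)  # drain the byte written by Notify()
#   finally:
#     wakeup.Reset()
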
class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when destroyed or
  when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: list
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None, wakeup=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function
    @type wakeup: None or L{SignalWakeupFd}
    @param wakeup: if not None, notified via L{SignalWakeupFd.Notify} whenever
        one of the handled signals arrives

    """
    assert handler_fn is None or callable(handler_fn)

    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn
    self._wakeup = wakeup

    self._previous = {}
    try:
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._wakeup:
      # Notify whoever is interested in signals
      self._wakeup.Notify()

    if self._handler_fn:
      self._handler_fn(signum, frame)
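
# Illustrative sketch of polling for termination signals (comment only;
# _DoWork is a made-up placeholder for the real work loop):
#
#   handler = SignalHandler([signal.SIGTERM, signal.SIGINT])
#   try:
#     while not handler.called:
#       _DoWork()
#   finally:
#     handler.Reset()
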
class FieldSet(object):
  """A simple field set.

  Among the features are:
    - checking if a string is among a list of static strings or regex objects
    - checking if a whole list of strings matches
    - returning the matching groups from a regex match

  Internally, all fields are held as regular expression objects.

  """
  def __init__(self, *items):
    self.items = [re.compile("^%s$" % value) for value in items]

  def Extend(self, other_set):
    """Extend the field set with the items from another one"""
    self.items.extend(other_set.items)

  def Matches(self, field):
    """Checks if a field matches the current set

    @type field: str
    @param field: the string to match
    @return: either None or a regular expression match object

    """
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
      return m
    return None

  def NonMatching(self, items):
    """Returns the list of fields not matching the current set

    @type items: list
    @param items: the list of fields to check
    @rtype: list
    @return: list of non-matching fields

    """
    return [val for val in items if not self.Matches(val)]
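
# Illustrative sketch of FieldSet matching (comment only; the field names are
# arbitrary examples, not a fixed Ganeti field list):
#
#   fields = FieldSet("name", "pnode", "tag/(.*)")
#   fields.Matches("tag/foo").group(1)    # -> "foo"
#   fields.Matches("size")                # -> None
#   fields.NonMatching(["name", "size"])  # -> ["size"]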