Statistics
| Branch: | Tag: | Revision:

root / lib / utils.py @ 965d0e5b

History | View | Annotate | Download (100.9 kB)

1
#
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import os
31
import sys
32
import time
33
import subprocess
34
import re
35
import socket
36
import tempfile
37
import shutil
38
import errno
39
import pwd
40
import itertools
41
import select
42
import fcntl
43
import resource
44
import logging
45
import logging.handlers
46
import signal
47
import OpenSSL
48
import datetime
49
import calendar
50
import hmac
51
import collections
52
import struct
53
import IN
54

    
55
from cStringIO import StringIO
56

    
57
try:
58
  import ctypes
59
except ImportError:
60
  ctypes = None
61

    
62
from ganeti import errors
63
from ganeti import constants
64
from ganeti import compat
65

    
66

    
67
_locksheld = []
68
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
69

    
70
debug_locks = False
71

    
72
#: when set to True, L{RunCmd} is disabled
73
no_fork = False
74

    
75
_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
76

    
77
HEX_CHAR_RE = r"[a-zA-Z0-9]"
78
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
79
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
80
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
81
                             HEX_CHAR_RE, HEX_CHAR_RE),
82
                            re.S | re.I)
83

    
84
# Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
85
# struct ucred { pid_t pid; uid_t uid; gid_t gid; };
86
#
87
# The GNU C Library defines gid_t and uid_t to be "unsigned int" and
88
# pid_t to "int".
89
#
90
# IEEE Std 1003.1-2008:
91
# "nlink_t, uid_t, gid_t, and id_t shall be integer types"
92
# "blksize_t, pid_t, and ssize_t shall be signed integer types"
93
_STRUCT_UCRED = "iII"
94
_STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)
95

    
96
# Certificate verification results
97
(CERT_WARNING,
98
 CERT_ERROR) = range(1, 3)
99

    
100
# Flags for mlockall() (from bits/mman.h)
101
_MCL_CURRENT = 1
102
_MCL_FUTURE = 2
103

    
104

    
105
class RunResult(object):
106
  """Holds the result of running external programs.
107

108
  @type exit_code: int
109
  @ivar exit_code: the exit code of the program, or None (if the program
110
      didn't exit())
111
  @type signal: int or None
112
  @ivar signal: the signal that caused the program to finish, or None
113
      (if the program wasn't terminated by a signal)
114
  @type stdout: str
115
  @ivar stdout: the standard output of the program
116
  @type stderr: str
117
  @ivar stderr: the standard error of the program
118
  @type failed: boolean
119
  @ivar failed: True in case the program was
120
      terminated by a signal or exited with a non-zero exit code
121
  @ivar fail_reason: a string detailing the termination reason
122

123
  """
124
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
125
               "failed", "fail_reason", "cmd"]
126

    
127

    
128
  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
129
    self.cmd = cmd
130
    self.exit_code = exit_code
131
    self.signal = signal_
132
    self.stdout = stdout
133
    self.stderr = stderr
134
    self.failed = (signal_ is not None or exit_code != 0)
135

    
136
    if self.signal is not None:
137
      self.fail_reason = "terminated by signal %s" % self.signal
138
    elif self.exit_code is not None:
139
      self.fail_reason = "exited with exit code %s" % self.exit_code
140
    else:
141
      self.fail_reason = "unable to determine termination reason"
142

    
143
    if self.failed:
144
      logging.debug("Command '%s' failed (%s); output: %s",
145
                    self.cmd, self.fail_reason, self.output)
146

    
147
  def _GetOutput(self):
148
    """Returns the combined stdout and stderr for easier usage.
149

150
    """
151
    return self.stdout + self.stderr
152

    
153
  output = property(_GetOutput, None, None, "Return full output")
154

    
155

    
156
def _BuildCmdEnvironment(env, reset):
157
  """Builds the environment for an external program.
158

159
  """
160
  if reset:
161
    cmd_env = {}
162
  else:
163
    cmd_env = os.environ.copy()
164
    cmd_env["LC_ALL"] = "C"
165

    
166
  if env is not None:
167
    cmd_env.update(env)
168

    
169
  return cmd_env
170

    
171

    
172
def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
173
  """Execute a (shell) command.
174

175
  The command should not read from its standard input, as it will be
176
  closed.
177

178
  @type cmd: string or list
179
  @param cmd: Command to run
180
  @type env: dict
181
  @param env: Additional environment variables
182
  @type output: str
183
  @param output: if desired, the output of the command can be
184
      saved in a file instead of the RunResult instance; this
185
      parameter denotes the file name (if not None)
186
  @type cwd: string
187
  @param cwd: if specified, will be used as the working
188
      directory for the command; the default will be /
189
  @type reset_env: boolean
190
  @param reset_env: whether to reset or keep the default os environment
191
  @rtype: L{RunResult}
192
  @return: RunResult instance
193
  @raise errors.ProgrammerError: if we call this when forks are disabled
194

195
  """
196
  if no_fork:
197
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")
198

    
199
  if isinstance(cmd, basestring):
200
    strcmd = cmd
201
    shell = True
202
  else:
203
    cmd = [str(val) for val in cmd]
204
    strcmd = ShellQuoteArgs(cmd)
205
    shell = False
206

    
207
  if output:
208
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
209
  else:
210
    logging.debug("RunCmd %s", strcmd)
211

    
212
  cmd_env = _BuildCmdEnvironment(env, reset_env)
213

    
214
  try:
215
    if output is None:
216
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
217
    else:
218
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
219
      out = err = ""
220
  except OSError, err:
221
    if err.errno == errno.ENOENT:
222
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
223
                               (strcmd, err))
224
    else:
225
      raise
226

    
227
  if status >= 0:
228
    exitcode = status
229
    signal_ = None
230
  else:
231
    exitcode = None
232
    signal_ = -status
233

    
234
  return RunResult(exitcode, signal_, out, err, strcmd)
235

    
236

    
237
def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
238
                pidfile=None):
239
  """Start a daemon process after forking twice.
240

241
  @type cmd: string or list
242
  @param cmd: Command to run
243
  @type env: dict
244
  @param env: Additional environment variables
245
  @type cwd: string
246
  @param cwd: Working directory for the program
247
  @type output: string
248
  @param output: Path to file in which to save the output
249
  @type output_fd: int
250
  @param output_fd: File descriptor for output
251
  @type pidfile: string
252
  @param pidfile: Process ID file
253
  @rtype: int
254
  @return: Daemon process ID
255
  @raise errors.ProgrammerError: if we call this when forks are disabled
256

257
  """
258
  if no_fork:
259
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
260
                                 " disabled")
261

    
262
  if output and not (bool(output) ^ (output_fd is not None)):
263
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
264
                                 " specified")
265

    
266
  if isinstance(cmd, basestring):
267
    cmd = ["/bin/sh", "-c", cmd]
268

    
269
  strcmd = ShellQuoteArgs(cmd)
270

    
271
  if output:
272
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
273
  else:
274
    logging.debug("StartDaemon %s", strcmd)
275

    
276
  cmd_env = _BuildCmdEnvironment(env, False)
277

    
278
  # Create pipe for sending PID back
279
  (pidpipe_read, pidpipe_write) = os.pipe()
280
  try:
281
    try:
282
      # Create pipe for sending error messages
283
      (errpipe_read, errpipe_write) = os.pipe()
284
      try:
285
        try:
286
          # First fork
287
          pid = os.fork()
288
          if pid == 0:
289
            try:
290
              # Child process, won't return
291
              _StartDaemonChild(errpipe_read, errpipe_write,
292
                                pidpipe_read, pidpipe_write,
293
                                cmd, cmd_env, cwd,
294
                                output, output_fd, pidfile)
295
            finally:
296
              # Well, maybe child process failed
297
              os._exit(1) # pylint: disable-msg=W0212
298
        finally:
299
          _CloseFDNoErr(errpipe_write)
300

    
301
        # Wait for daemon to be started (or an error message to arrive) and read
302
        # up to 100 KB as an error message
303
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
304
      finally:
305
        _CloseFDNoErr(errpipe_read)
306
    finally:
307
      _CloseFDNoErr(pidpipe_write)
308

    
309
    # Read up to 128 bytes for PID
310
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
311
  finally:
312
    _CloseFDNoErr(pidpipe_read)
313

    
314
  # Try to avoid zombies by waiting for child process
315
  try:
316
    os.waitpid(pid, 0)
317
  except OSError:
318
    pass
319

    
320
  if errormsg:
321
    raise errors.OpExecError("Error when starting daemon process: %r" %
322
                             errormsg)
323

    
324
  try:
325
    return int(pidtext)
326
  except (ValueError, TypeError), err:
327
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
328
                             (pidtext, err))
329

    
330

    
331
def _StartDaemonChild(errpipe_read, errpipe_write,
332
                      pidpipe_read, pidpipe_write,
333
                      args, env, cwd,
334
                      output, fd_output, pidfile):
335
  """Child process for starting daemon.
336

337
  """
338
  try:
339
    # Close parent's side
340
    _CloseFDNoErr(errpipe_read)
341
    _CloseFDNoErr(pidpipe_read)
342

    
343
    # First child process
344
    os.chdir("/")
345
    os.umask(077)
346
    os.setsid()
347

    
348
    # And fork for the second time
349
    pid = os.fork()
350
    if pid != 0:
351
      # Exit first child process
352
      os._exit(0) # pylint: disable-msg=W0212
353

    
354
    # Make sure pipe is closed on execv* (and thereby notifies original process)
355
    SetCloseOnExecFlag(errpipe_write, True)
356

    
357
    # List of file descriptors to be left open
358
    noclose_fds = [errpipe_write]
359

    
360
    # Open PID file
361
    if pidfile:
362
      try:
363
        # TODO: Atomic replace with another locked file instead of writing into
364
        # it after creating
365
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)
366

    
367
        # Lock the PID file (and fail if not possible to do so). Any code
368
        # wanting to send a signal to the daemon should try to lock the PID
369
        # file before reading it. If acquiring the lock succeeds, the daemon is
370
        # no longer running and the signal should not be sent.
371
        LockFile(fd_pidfile)
372

    
373
        os.write(fd_pidfile, "%d\n" % os.getpid())
374
      except Exception, err:
375
        raise Exception("Creating and locking PID file failed: %s" % err)
376

    
377
      # Keeping the file open to hold the lock
378
      noclose_fds.append(fd_pidfile)
379

    
380
      SetCloseOnExecFlag(fd_pidfile, False)
381
    else:
382
      fd_pidfile = None
383

    
384
    # Open /dev/null
385
    fd_devnull = os.open(os.devnull, os.O_RDWR)
386

    
387
    assert not output or (bool(output) ^ (fd_output is not None))
388

    
389
    if fd_output is not None:
390
      pass
391
    elif output:
392
      # Open output file
393
      try:
394
        # TODO: Implement flag to set append=yes/no
395
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
396
      except EnvironmentError, err:
397
        raise Exception("Opening output file failed: %s" % err)
398
    else:
399
      fd_output = fd_devnull
400

    
401
    # Redirect standard I/O
402
    os.dup2(fd_devnull, 0)
403
    os.dup2(fd_output, 1)
404
    os.dup2(fd_output, 2)
405

    
406
    # Send daemon PID to parent
407
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))
408

    
409
    # Close all file descriptors except stdio and error message pipe
410
    CloseFDs(noclose_fds=noclose_fds)
411

    
412
    # Change working directory
413
    os.chdir(cwd)
414

    
415
    if env is None:
416
      os.execvp(args[0], args)
417
    else:
418
      os.execvpe(args[0], args, env)
419
  except: # pylint: disable-msg=W0702
420
    try:
421
      # Report errors to original process
422
      buf = str(sys.exc_info()[1])
423

    
424
      RetryOnSignal(os.write, errpipe_write, buf)
425
    except: # pylint: disable-msg=W0702
426
      # Ignore errors in error handling
427
      pass
428

    
429
  os._exit(1) # pylint: disable-msg=W0212
430

    
431

    
432
def _RunCmdPipe(cmd, env, via_shell, cwd):
433
  """Run a command and return its output.
434

435
  @type  cmd: string or list
436
  @param cmd: Command to run
437
  @type env: dict
438
  @param env: The environment to use
439
  @type via_shell: bool
440
  @param via_shell: if we should run via the shell
441
  @type cwd: string
442
  @param cwd: the working directory for the program
443
  @rtype: tuple
444
  @return: (out, err, status)
445

446
  """
447
  poller = select.poll()
448
  child = subprocess.Popen(cmd, shell=via_shell,
449
                           stderr=subprocess.PIPE,
450
                           stdout=subprocess.PIPE,
451
                           stdin=subprocess.PIPE,
452
                           close_fds=True, env=env,
453
                           cwd=cwd)
454

    
455
  child.stdin.close()
456
  poller.register(child.stdout, select.POLLIN)
457
  poller.register(child.stderr, select.POLLIN)
458
  out = StringIO()
459
  err = StringIO()
460
  fdmap = {
461
    child.stdout.fileno(): (out, child.stdout),
462
    child.stderr.fileno(): (err, child.stderr),
463
    }
464
  for fd in fdmap:
465
    SetNonblockFlag(fd, True)
466

    
467
  while fdmap:
468
    pollresult = RetryOnSignal(poller.poll)
469

    
470
    for fd, event in pollresult:
471
      if event & select.POLLIN or event & select.POLLPRI:
472
        data = fdmap[fd][1].read()
473
        # no data from read signifies EOF (the same as POLLHUP)
474
        if not data:
475
          poller.unregister(fd)
476
          del fdmap[fd]
477
          continue
478
        fdmap[fd][0].write(data)
479
      if (event & select.POLLNVAL or event & select.POLLHUP or
480
          event & select.POLLERR):
481
        poller.unregister(fd)
482
        del fdmap[fd]
483

    
484
  out = out.getvalue()
485
  err = err.getvalue()
486

    
487
  status = child.wait()
488
  return out, err, status
489

    
490

    
491
def _RunCmdFile(cmd, env, via_shell, output, cwd):
492
  """Run a command and save its output to a file.
493

494
  @type  cmd: string or list
495
  @param cmd: Command to run
496
  @type env: dict
497
  @param env: The environment to use
498
  @type via_shell: bool
499
  @param via_shell: if we should run via the shell
500
  @type output: str
501
  @param output: the filename in which to save the output
502
  @type cwd: string
503
  @param cwd: the working directory for the program
504
  @rtype: int
505
  @return: the exit status
506

507
  """
508
  fh = open(output, "a")
509
  try:
510
    child = subprocess.Popen(cmd, shell=via_shell,
511
                             stderr=subprocess.STDOUT,
512
                             stdout=fh,
513
                             stdin=subprocess.PIPE,
514
                             close_fds=True, env=env,
515
                             cwd=cwd)
516

    
517
    child.stdin.close()
518
    status = child.wait()
519
  finally:
520
    fh.close()
521
  return status
522

    
523

    
524
def SetCloseOnExecFlag(fd, enable):
525
  """Sets or unsets the close-on-exec flag on a file descriptor.
526

527
  @type fd: int
528
  @param fd: File descriptor
529
  @type enable: bool
530
  @param enable: Whether to set or unset it.
531

532
  """
533
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)
534

    
535
  if enable:
536
    flags |= fcntl.FD_CLOEXEC
537
  else:
538
    flags &= ~fcntl.FD_CLOEXEC
539

    
540
  fcntl.fcntl(fd, fcntl.F_SETFD, flags)
541

    
542

    
543
def SetNonblockFlag(fd, enable):
544
  """Sets or unsets the O_NONBLOCK flag on on a file descriptor.
545

546
  @type fd: int
547
  @param fd: File descriptor
548
  @type enable: bool
549
  @param enable: Whether to set or unset it
550

551
  """
552
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)
553

    
554
  if enable:
555
    flags |= os.O_NONBLOCK
556
  else:
557
    flags &= ~os.O_NONBLOCK
558

    
559
  fcntl.fcntl(fd, fcntl.F_SETFL, flags)
560

    
561

    
562
def RetryOnSignal(fn, *args, **kwargs):
563
  """Calls a function again if it failed due to EINTR.
564

565
  """
566
  while True:
567
    try:
568
      return fn(*args, **kwargs)
569
    except EnvironmentError, err:
570
      if err.errno != errno.EINTR:
571
        raise
572
    except (socket.error, select.error), err:
573
      # In python 2.6 and above select.error is an IOError, so it's handled
574
      # above, in 2.5 and below it's not, and it's handled here.
575
      if not (err.args and err.args[0] == errno.EINTR):
576
        raise
577

    
578

    
579
def RunParts(dir_name, env=None, reset_env=False):
580
  """Run Scripts or programs in a directory
581

582
  @type dir_name: string
583
  @param dir_name: absolute path to a directory
584
  @type env: dict
585
  @param env: The environment to use
586
  @type reset_env: boolean
587
  @param reset_env: whether to reset or keep the default os environment
588
  @rtype: list of tuples
589
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)
590

591
  """
592
  rr = []
593

    
594
  try:
595
    dir_contents = ListVisibleFiles(dir_name)
596
  except OSError, err:
597
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
598
    return rr
599

    
600
  for relname in sorted(dir_contents):
601
    fname = PathJoin(dir_name, relname)
602
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
603
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
604
      rr.append((relname, constants.RUNPARTS_SKIP, None))
605
    else:
606
      try:
607
        result = RunCmd([fname], env=env, reset_env=reset_env)
608
      except Exception, err: # pylint: disable-msg=W0703
609
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
610
      else:
611
        rr.append((relname, constants.RUNPARTS_RUN, result))
612

    
613
  return rr
614

    
615

    
616
def GetSocketCredentials(sock):
617
  """Returns the credentials of the foreign process connected to a socket.
618

619
  @param sock: Unix socket
620
  @rtype: tuple; (number, number, number)
621
  @return: The PID, UID and GID of the connected foreign process.
622

623
  """
624
  peercred = sock.getsockopt(socket.SOL_SOCKET, IN.SO_PEERCRED,
625
                             _STRUCT_UCRED_SIZE)
626
  return struct.unpack(_STRUCT_UCRED, peercred)
627

    
628

    
629
def RemoveFile(filename):
630
  """Remove a file ignoring some errors.
631

632
  Remove a file, ignoring non-existing ones or directories. Other
633
  errors are passed.
634

635
  @type filename: str
636
  @param filename: the file to be removed
637

638
  """
639
  try:
640
    os.unlink(filename)
641
  except OSError, err:
642
    if err.errno not in (errno.ENOENT, errno.EISDIR):
643
      raise
644

    
645

    
646
def RemoveDir(dirname):
647
  """Remove an empty directory.
648

649
  Remove a directory, ignoring non-existing ones.
650
  Other errors are passed. This includes the case,
651
  where the directory is not empty, so it can't be removed.
652

653
  @type dirname: str
654
  @param dirname: the empty directory to be removed
655

656
  """
657
  try:
658
    os.rmdir(dirname)
659
  except OSError, err:
660
    if err.errno != errno.ENOENT:
661
      raise
662

    
663

    
664
def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
665
  """Renames a file.
666

667
  @type old: string
668
  @param old: Original path
669
  @type new: string
670
  @param new: New path
671
  @type mkdir: bool
672
  @param mkdir: Whether to create target directory if it doesn't exist
673
  @type mkdir_mode: int
674
  @param mkdir_mode: Mode for newly created directories
675

676
  """
677
  try:
678
    return os.rename(old, new)
679
  except OSError, err:
680
    # In at least one use case of this function, the job queue, directory
681
    # creation is very rare. Checking for the directory before renaming is not
682
    # as efficient.
683
    if mkdir and err.errno == errno.ENOENT:
684
      # Create directory and try again
685
      Makedirs(os.path.dirname(new), mode=mkdir_mode)
686

    
687
      return os.rename(old, new)
688

    
689
    raise
690

    
691

    
692
def Makedirs(path, mode=0750):
693
  """Super-mkdir; create a leaf directory and all intermediate ones.
694

695
  This is a wrapper around C{os.makedirs} adding error handling not implemented
696
  before Python 2.5.
697

698
  """
699
  try:
700
    os.makedirs(path, mode)
701
  except OSError, err:
702
    # Ignore EEXIST. This is only handled in os.makedirs as included in
703
    # Python 2.5 and above.
704
    if err.errno != errno.EEXIST or not os.path.exists(path):
705
      raise
706

    
707

    
708
def ResetTempfileModule():
709
  """Resets the random name generator of the tempfile module.
710

711
  This function should be called after C{os.fork} in the child process to
712
  ensure it creates a newly seeded random generator. Otherwise it would
713
  generate the same random parts as the parent process. If several processes
714
  race for the creation of a temporary file, this could lead to one not getting
715
  a temporary name.
716

717
  """
718
  # pylint: disable-msg=W0212
719
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
720
    tempfile._once_lock.acquire()
721
    try:
722
      # Reset random name generator
723
      tempfile._name_sequence = None
724
    finally:
725
      tempfile._once_lock.release()
726
  else:
727
    logging.critical("The tempfile module misses at least one of the"
728
                     " '_once_lock' and '_name_sequence' attributes")
729

    
730

    
731
def _FingerprintFile(filename):
732
  """Compute the fingerprint of a file.
733

734
  If the file does not exist, a None will be returned
735
  instead.
736

737
  @type filename: str
738
  @param filename: the filename to checksum
739
  @rtype: str
740
  @return: the hex digest of the sha checksum of the contents
741
      of the file
742

743
  """
744
  if not (os.path.exists(filename) and os.path.isfile(filename)):
745
    return None
746

    
747
  f = open(filename)
748

    
749
  fp = compat.sha1_hash()
750
  while True:
751
    data = f.read(4096)
752
    if not data:
753
      break
754

    
755
    fp.update(data)
756

    
757
  return fp.hexdigest()
758

    
759

    
760
def FingerprintFiles(files):
761
  """Compute fingerprints for a list of files.
762

763
  @type files: list
764
  @param files: the list of filename to fingerprint
765
  @rtype: dict
766
  @return: a dictionary filename: fingerprint, holding only
767
      existing files
768

769
  """
770
  ret = {}
771

    
772
  for filename in files:
773
    cksum = _FingerprintFile(filename)
774
    if cksum:
775
      ret[filename] = cksum
776

    
777
  return ret
778

    
779

    
780
def ForceDictType(target, key_types, allowed_values=None):
781
  """Force the values of a dict to have certain types.
782

783
  @type target: dict
784
  @param target: the dict to update
785
  @type key_types: dict
786
  @param key_types: dict mapping target dict keys to types
787
                    in constants.ENFORCEABLE_TYPES
788
  @type allowed_values: list
789
  @keyword allowed_values: list of specially allowed values
790

791
  """
792
  if allowed_values is None:
793
    allowed_values = []
794

    
795
  if not isinstance(target, dict):
796
    msg = "Expected dictionary, got '%s'" % target
797
    raise errors.TypeEnforcementError(msg)
798

    
799
  for key in target:
800
    if key not in key_types:
801
      msg = "Unknown key '%s'" % key
802
      raise errors.TypeEnforcementError(msg)
803

    
804
    if target[key] in allowed_values:
805
      continue
806

    
807
    ktype = key_types[key]
808
    if ktype not in constants.ENFORCEABLE_TYPES:
809
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
810
      raise errors.ProgrammerError(msg)
811

    
812
    if ktype == constants.VTYPE_STRING:
813
      if not isinstance(target[key], basestring):
814
        if isinstance(target[key], bool) and not target[key]:
815
          target[key] = ''
816
        else:
817
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
818
          raise errors.TypeEnforcementError(msg)
819
    elif ktype == constants.VTYPE_BOOL:
820
      if isinstance(target[key], basestring) and target[key]:
821
        if target[key].lower() == constants.VALUE_FALSE:
822
          target[key] = False
823
        elif target[key].lower() == constants.VALUE_TRUE:
824
          target[key] = True
825
        else:
826
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
827
          raise errors.TypeEnforcementError(msg)
828
      elif target[key]:
829
        target[key] = True
830
      else:
831
        target[key] = False
832
    elif ktype == constants.VTYPE_SIZE:
833
      try:
834
        target[key] = ParseUnit(target[key])
835
      except errors.UnitParseError, err:
836
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
837
              (key, target[key], err)
838
        raise errors.TypeEnforcementError(msg)
839
    elif ktype == constants.VTYPE_INT:
840
      try:
841
        target[key] = int(target[key])
842
      except (ValueError, TypeError):
843
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
844
        raise errors.TypeEnforcementError(msg)
845

    
846

    
847
def IsProcessAlive(pid):
848
  """Check if a given pid exists on the system.
849

850
  @note: zombie status is not handled, so zombie processes
851
      will be returned as alive
852
  @type pid: int
853
  @param pid: the process ID to check
854
  @rtype: boolean
855
  @return: True if the process exists
856

857
  """
858
  def _TryStat(name):
859
    try:
860
      os.stat(name)
861
      return True
862
    except EnvironmentError, err:
863
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
864
        return False
865
      elif err.errno == errno.EINVAL:
866
        raise RetryAgain(err)
867
      raise
868

    
869
  assert isinstance(pid, int), "pid must be an integer"
870
  if pid <= 0:
871
    return False
872

    
873
  proc_entry = "/proc/%d/status" % pid
874
  # /proc in a multiprocessor environment can have strange behaviors.
875
  # Retry the os.stat a few times until we get a good result.
876
  try:
877
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5, args=[proc_entry])
878
  except RetryTimeout, err:
879
    err.RaiseInner()
880

    
881

    
882
def ReadPidFile(pidfile):
883
  """Read a pid from a file.
884

885
  @type  pidfile: string
886
  @param pidfile: path to the file containing the pid
887
  @rtype: int
888
  @return: The process id, if the file exists and contains a valid PID,
889
           otherwise 0
890

891
  """
892
  try:
893
    raw_data = ReadOneLineFile(pidfile)
894
  except EnvironmentError, err:
895
    if err.errno != errno.ENOENT:
896
      logging.exception("Can't read pid file")
897
    return 0
898

    
899
  try:
900
    pid = int(raw_data)
901
  except (TypeError, ValueError), err:
902
    logging.info("Can't parse pid file contents", exc_info=True)
903
    return 0
904

    
905
  return pid
906

    
907

    
908
def ReadLockedPidFile(path):
909
  """Reads a locked PID file.
910

911
  This can be used together with L{StartDaemon}.
912

913
  @type path: string
914
  @param path: Path to PID file
915
  @return: PID as integer or, if file was unlocked or couldn't be opened, None
916

917
  """
918
  try:
919
    fd = os.open(path, os.O_RDONLY)
920
  except EnvironmentError, err:
921
    if err.errno == errno.ENOENT:
922
      # PID file doesn't exist
923
      return None
924
    raise
925

    
926
  try:
927
    try:
928
      # Try to acquire lock
929
      LockFile(fd)
930
    except errors.LockError:
931
      # Couldn't lock, daemon is running
932
      return int(os.read(fd, 100))
933
  finally:
934
    os.close(fd)
935

    
936
  return None
937

    
938

    
939
def MatchNameComponent(key, name_list, case_sensitive=True):
940
  """Try to match a name against a list.
941

942
  This function will try to match a name like test1 against a list
943
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
944
  this list, I{'test1'} as well as I{'test1.example'} will match, but
945
  not I{'test1.ex'}. A multiple match will be considered as no match
946
  at all (e.g. I{'test1'} against C{['test1.example.com',
947
  'test1.example.org']}), except when the key fully matches an entry
948
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
949

950
  @type key: str
951
  @param key: the name to be searched
952
  @type name_list: list
953
  @param name_list: the list of strings against which to search the key
954
  @type case_sensitive: boolean
955
  @param case_sensitive: whether to provide a case-sensitive match
956

957
  @rtype: None or str
958
  @return: None if there is no match I{or} if there are multiple matches,
959
      otherwise the element from the list which matches
960

961
  """
962
  if key in name_list:
963
    return key
964

    
965
  re_flags = 0
966
  if not case_sensitive:
967
    re_flags |= re.IGNORECASE
968
    key = key.upper()
969
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
970
  names_filtered = []
971
  string_matches = []
972
  for name in name_list:
973
    if mo.match(name) is not None:
974
      names_filtered.append(name)
975
      if not case_sensitive and key == name.upper():
976
        string_matches.append(name)
977

    
978
  if len(string_matches) == 1:
979
    return string_matches[0]
980
  if len(names_filtered) == 1:
981
    return names_filtered[0]
982
  return None
983

    
984

    
985
class HostInfo:
986
  """Class implementing resolver and hostname functionality
987

988
  """
989
  _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$")
990

    
991
  def __init__(self, name=None):
992
    """Initialize the host name object.
993

994
    If the name argument is not passed, it will use this system's
995
    name.
996

997
    """
998
    if name is None:
999
      name = self.SysName()
1000

    
1001
    self.query = name
1002
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
1003
    self.ip = self.ipaddrs[0]
1004

    
1005
  def ShortName(self):
1006
    """Returns the hostname without domain.
1007

1008
    """
1009
    return self.name.split('.')[0]
1010

    
1011
  @staticmethod
1012
  def SysName():
1013
    """Return the current system's name.
1014

1015
    This is simply a wrapper over C{socket.gethostname()}.
1016

1017
    """
1018
    return socket.gethostname()
1019

    
1020
  @staticmethod
1021
  def LookupHostname(hostname):
1022
    """Look up hostname
1023

1024
    @type hostname: str
1025
    @param hostname: hostname to look up
1026

1027
    @rtype: tuple
1028
    @return: a tuple (name, aliases, ipaddrs) as returned by
1029
        C{socket.gethostbyname_ex}
1030
    @raise errors.ResolverError: in case of errors in resolving
1031

1032
    """
1033
    try:
1034
      result = socket.gethostbyname_ex(hostname)
1035
    except socket.gaierror, err:
1036
      # hostname not found in DNS
1037
      raise errors.ResolverError(hostname, err.args[0], err.args[1])
1038

    
1039
    return result
1040

    
1041
  @classmethod
1042
  def NormalizeName(cls, hostname):
1043
    """Validate and normalize the given hostname.
1044

1045
    @attention: the validation is a bit more relaxed than the standards
1046
        require; most importantly, we allow underscores in names
1047
    @raise errors.OpPrereqError: when the name is not valid
1048

1049
    """
1050
    hostname = hostname.lower()
1051
    if (not cls._VALID_NAME_RE.match(hostname) or
1052
        # double-dots, meaning empty label
1053
        ".." in hostname or
1054
        # empty initial label
1055
        hostname.startswith(".")):
1056
      raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
1057
                                 errors.ECODE_INVAL)
1058
    if hostname.endswith("."):
1059
      hostname = hostname.rstrip(".")
1060
    return hostname
1061

    
1062

    
1063
def GetHostInfo(name=None):
1064
  """Lookup host name and raise an OpPrereqError for failures"""
1065

    
1066
  try:
1067
    return HostInfo(name)
1068
  except errors.ResolverError, err:
1069
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
1070
                               (err[0], err[2]), errors.ECODE_RESOLVER)
1071

    
1072

    
1073
def ListVolumeGroups():
1074
  """List volume groups and their size
1075

1076
  @rtype: dict
1077
  @return:
1078
       Dictionary with keys volume name and values
1079
       the size of the volume
1080

1081
  """
1082
  command = "vgs --noheadings --units m --nosuffix -o name,size"
1083
  result = RunCmd(command)
1084
  retval = {}
1085
  if result.failed:
1086
    return retval
1087

    
1088
  for line in result.stdout.splitlines():
1089
    try:
1090
      name, size = line.split()
1091
      size = int(float(size))
1092
    except (IndexError, ValueError), err:
1093
      logging.error("Invalid output from vgs (%s): %s", err, line)
1094
      continue
1095

    
1096
    retval[name] = size
1097

    
1098
  return retval
1099

    
1100

    
1101
def BridgeExists(bridge):
1102
  """Check whether the given bridge exists in the system
1103

1104
  @type bridge: str
1105
  @param bridge: the bridge name to check
1106
  @rtype: boolean
1107
  @return: True if it does
1108

1109
  """
1110
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
1111

    
1112

    
1113
def NiceSort(name_list):
1114
  """Sort a list of strings based on digit and non-digit groupings.
1115

1116
  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
1117
  will sort the list in the logical order C{['a1', 'a2', 'a10',
1118
  'a11']}.
1119

1120
  The sort algorithm breaks each name in groups of either only-digits
1121
  or no-digits. Only the first eight such groups are considered, and
1122
  after that we just use what's left of the string.
1123

1124
  @type name_list: list
1125
  @param name_list: the names to be sorted
1126
  @rtype: list
1127
  @return: a copy of the name list sorted with our algorithm
1128

1129
  """
1130
  _SORTER_BASE = "(\D+|\d+)"
1131
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
1132
                                                  _SORTER_BASE, _SORTER_BASE,
1133
                                                  _SORTER_BASE, _SORTER_BASE,
1134
                                                  _SORTER_BASE, _SORTER_BASE)
1135
  _SORTER_RE = re.compile(_SORTER_FULL)
1136
  _SORTER_NODIGIT = re.compile("^\D*$")
1137
  def _TryInt(val):
1138
    """Attempts to convert a variable to integer."""
1139
    if val is None or _SORTER_NODIGIT.match(val):
1140
      return val
1141
    rval = int(val)
1142
    return rval
1143

    
1144
  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
1145
             for name in name_list]
1146
  to_sort.sort()
1147
  return [tup[1] for tup in to_sort]
1148

    
1149

    
1150
def TryConvert(fn, val):
1151
  """Try to convert a value ignoring errors.
1152

1153
  This function tries to apply function I{fn} to I{val}. If no
1154
  C{ValueError} or C{TypeError} exceptions are raised, it will return
1155
  the result, else it will return the original value. Any other
1156
  exceptions are propagated to the caller.
1157

1158
  @type fn: callable
1159
  @param fn: function to apply to the value
1160
  @param val: the value to be converted
1161
  @return: The converted value if the conversion was successful,
1162
      otherwise the original value.
1163

1164
  """
1165
  try:
1166
    nv = fn(val)
1167
  except (ValueError, TypeError):
1168
    nv = val
1169
  return nv
1170

    
1171

    
1172
def IsValidIP(ip):
1173
  """Verifies the syntax of an IPv4 address.
1174

1175
  This function checks if the IPv4 address passes is valid or not based
1176
  on syntax (not IP range, class calculations, etc.).
1177

1178
  @type ip: str
1179
  @param ip: the address to be checked
1180
  @rtype: a regular expression match object
1181
  @return: a regular expression match object, or None if the
1182
      address is not valid
1183

1184
  """
1185
  unit = "(0|[1-9]\d{0,2})"
1186
  #TODO: convert and return only boolean
1187
  return re.match("^%s\.%s\.%s\.%s$" % (unit, unit, unit, unit), ip)
1188

    
1189

    
1190
def IsValidShellParam(word):
1191
  """Verifies is the given word is safe from the shell's p.o.v.
1192

1193
  This means that we can pass this to a command via the shell and be
1194
  sure that it doesn't alter the command line and is passed as such to
1195
  the actual command.
1196

1197
  Note that we are overly restrictive here, in order to be on the safe
1198
  side.
1199

1200
  @type word: str
1201
  @param word: the word to check
1202
  @rtype: boolean
1203
  @return: True if the word is 'safe'
1204

1205
  """
1206
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))
1207

    
1208

    
1209
def BuildShellCmd(template, *args):
1210
  """Build a safe shell command line from the given arguments.
1211

1212
  This function will check all arguments in the args list so that they
1213
  are valid shell parameters (i.e. they don't contain shell
1214
  metacharacters). If everything is ok, it will return the result of
1215
  template % args.
1216

1217
  @type template: str
1218
  @param template: the string holding the template for the
1219
      string formatting
1220
  @rtype: str
1221
  @return: the expanded command line
1222

1223
  """
1224
  for word in args:
1225
    if not IsValidShellParam(word):
1226
      raise errors.ProgrammerError("Shell argument '%s' contains"
1227
                                   " invalid characters" % word)
1228
  return template % args
1229

    
1230

    
1231
def FormatUnit(value, units):
1232
  """Formats an incoming number of MiB with the appropriate unit.
1233

1234
  @type value: int
1235
  @param value: integer representing the value in MiB (1048576)
1236
  @type units: char
1237
  @param units: the type of formatting we should do:
1238
      - 'h' for automatic scaling
1239
      - 'm' for MiBs
1240
      - 'g' for GiBs
1241
      - 't' for TiBs
1242
  @rtype: str
1243
  @return: the formatted value (with suffix)
1244

1245
  """
1246
  if units not in ('m', 'g', 't', 'h'):
1247
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))
1248

    
1249
  suffix = ''
1250

    
1251
  if units == 'm' or (units == 'h' and value < 1024):
1252
    if units == 'h':
1253
      suffix = 'M'
1254
    return "%d%s" % (round(value, 0), suffix)
1255

    
1256
  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
1257
    if units == 'h':
1258
      suffix = 'G'
1259
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)
1260

    
1261
  else:
1262
    if units == 'h':
1263
      suffix = 'T'
1264
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)
1265

    
1266

    
1267
def ParseUnit(input_string):
1268
  """Tries to extract number and scale from the given string.
1269

1270
  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
1271
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
1272
  is always an int in MiB.
1273

1274
  """
1275
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
1276
  if not m:
1277
    raise errors.UnitParseError("Invalid format")
1278

    
1279
  value = float(m.groups()[0])
1280

    
1281
  unit = m.groups()[1]
1282
  if unit:
1283
    lcunit = unit.lower()
1284
  else:
1285
    lcunit = 'm'
1286

    
1287
  if lcunit in ('m', 'mb', 'mib'):
1288
    # Value already in MiB
1289
    pass
1290

    
1291
  elif lcunit in ('g', 'gb', 'gib'):
1292
    value *= 1024
1293

    
1294
  elif lcunit in ('t', 'tb', 'tib'):
1295
    value *= 1024 * 1024
1296

    
1297
  else:
1298
    raise errors.UnitParseError("Unknown unit: %s" % unit)
1299

    
1300
  # Make sure we round up
1301
  if int(value) < value:
1302
    value += 1
1303

    
1304
  # Round up to the next multiple of 4
1305
  value = int(value)
1306
  if value % 4:
1307
    value += 4 - value % 4
1308

    
1309
  return value
1310

    
1311

    
1312
def AddAuthorizedKey(file_name, key):
1313
  """Adds an SSH public key to an authorized_keys file.
1314

1315
  @type file_name: str
1316
  @param file_name: path to authorized_keys file
1317
  @type key: str
1318
  @param key: string containing key
1319

1320
  """
1321
  key_fields = key.split()
1322

    
1323
  f = open(file_name, 'a+')
1324
  try:
1325
    nl = True
1326
    for line in f:
1327
      # Ignore whitespace changes
1328
      if line.split() == key_fields:
1329
        break
1330
      nl = line.endswith('\n')
1331
    else:
1332
      if not nl:
1333
        f.write("\n")
1334
      f.write(key.rstrip('\r\n'))
1335
      f.write("\n")
1336
      f.flush()
1337
  finally:
1338
    f.close()
1339

    
1340

    
1341
def RemoveAuthorizedKey(file_name, key):
1342
  """Removes an SSH public key from an authorized_keys file.
1343

1344
  @type file_name: str
1345
  @param file_name: path to authorized_keys file
1346
  @type key: str
1347
  @param key: string containing key
1348

1349
  """
1350
  key_fields = key.split()
1351

    
1352
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1353
  try:
1354
    out = os.fdopen(fd, 'w')
1355
    try:
1356
      f = open(file_name, 'r')
1357
      try:
1358
        for line in f:
1359
          # Ignore whitespace changes while comparing lines
1360
          if line.split() != key_fields:
1361
            out.write(line)
1362

    
1363
        out.flush()
1364
        os.rename(tmpname, file_name)
1365
      finally:
1366
        f.close()
1367
    finally:
1368
      out.close()
1369
  except:
1370
    RemoveFile(tmpname)
1371
    raise
1372

    
1373

    
1374
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
1375
  """Sets the name of an IP address and hostname in /etc/hosts.
1376

1377
  @type file_name: str
1378
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1379
  @type ip: str
1380
  @param ip: the IP address
1381
  @type hostname: str
1382
  @param hostname: the hostname to be added
1383
  @type aliases: list
1384
  @param aliases: the list of aliases to add for the hostname
1385

1386
  """
1387
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1388
  # Ensure aliases are unique
1389
  aliases = UniqueSequence([hostname] + aliases)[1:]
1390

    
1391
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1392
  try:
1393
    out = os.fdopen(fd, 'w')
1394
    try:
1395
      f = open(file_name, 'r')
1396
      try:
1397
        for line in f:
1398
          fields = line.split()
1399
          if fields and not fields[0].startswith('#') and ip == fields[0]:
1400
            continue
1401
          out.write(line)
1402

    
1403
        out.write("%s\t%s" % (ip, hostname))
1404
        if aliases:
1405
          out.write(" %s" % ' '.join(aliases))
1406
        out.write('\n')
1407

    
1408
        out.flush()
1409
        os.fsync(out)
1410
        os.chmod(tmpname, 0644)
1411
        os.rename(tmpname, file_name)
1412
      finally:
1413
        f.close()
1414
    finally:
1415
      out.close()
1416
  except:
1417
    RemoveFile(tmpname)
1418
    raise
1419

    
1420

    
1421
def AddHostToEtcHosts(hostname):
1422
  """Wrapper around SetEtcHostsEntry.
1423

1424
  @type hostname: str
1425
  @param hostname: a hostname that will be resolved and added to
1426
      L{constants.ETC_HOSTS}
1427

1428
  """
1429
  hi = HostInfo(name=hostname)
1430
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])
1431

    
1432

    
1433
def RemoveEtcHostsEntry(file_name, hostname):
1434
  """Removes a hostname from /etc/hosts.
1435

1436
  IP addresses without names are removed from the file.
1437

1438
  @type file_name: str
1439
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1440
  @type hostname: str
1441
  @param hostname: the hostname to be removed
1442

1443
  """
1444
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1445
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1446
  try:
1447
    out = os.fdopen(fd, 'w')
1448
    try:
1449
      f = open(file_name, 'r')
1450
      try:
1451
        for line in f:
1452
          fields = line.split()
1453
          if len(fields) > 1 and not fields[0].startswith('#'):
1454
            names = fields[1:]
1455
            if hostname in names:
1456
              while hostname in names:
1457
                names.remove(hostname)
1458
              if names:
1459
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
1460
              continue
1461

    
1462
          out.write(line)
1463

    
1464
        out.flush()
1465
        os.fsync(out)
1466
        os.chmod(tmpname, 0644)
1467
        os.rename(tmpname, file_name)
1468
      finally:
1469
        f.close()
1470
    finally:
1471
      out.close()
1472
  except:
1473
    RemoveFile(tmpname)
1474
    raise
1475

    
1476

    
1477
def RemoveHostFromEtcHosts(hostname):
1478
  """Wrapper around RemoveEtcHostsEntry.
1479

1480
  @type hostname: str
1481
  @param hostname: hostname that will be resolved and its
1482
      full and shot name will be removed from
1483
      L{constants.ETC_HOSTS}
1484

1485
  """
1486
  hi = HostInfo(name=hostname)
1487
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
1488
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())
1489

    
1490

    
1491
def TimestampForFilename():
1492
  """Returns the current time formatted for filenames.
1493

1494
  The format doesn't contain colons as some shells and applications them as
1495
  separators.
1496

1497
  """
1498
  return time.strftime("%Y-%m-%d_%H_%M_%S")
1499

    
1500

    
1501
def CreateBackup(file_name):
1502
  """Creates a backup of a file.
1503

1504
  @type file_name: str
1505
  @param file_name: file to be backed up
1506
  @rtype: str
1507
  @return: the path to the newly created backup
1508
  @raise errors.ProgrammerError: for invalid file names
1509

1510
  """
1511
  if not os.path.isfile(file_name):
1512
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
1513
                                file_name)
1514

    
1515
  prefix = ("%s.backup-%s." %
1516
            (os.path.basename(file_name), TimestampForFilename()))
1517
  dir_name = os.path.dirname(file_name)
1518

    
1519
  fsrc = open(file_name, 'rb')
1520
  try:
1521
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
1522
    fdst = os.fdopen(fd, 'wb')
1523
    try:
1524
      logging.debug("Backing up %s at %s", file_name, backup_name)
1525
      shutil.copyfileobj(fsrc, fdst)
1526
    finally:
1527
      fdst.close()
1528
  finally:
1529
    fsrc.close()
1530

    
1531
  return backup_name
1532

    
1533

    
1534
def ShellQuote(value):
1535
  """Quotes shell argument according to POSIX.
1536

1537
  @type value: str
1538
  @param value: the argument to be quoted
1539
  @rtype: str
1540
  @return: the quoted value
1541

1542
  """
1543
  if _re_shell_unquoted.match(value):
1544
    return value
1545
  else:
1546
    return "'%s'" % value.replace("'", "'\\''")
1547

    
1548

    
1549
def ShellQuoteArgs(args):
1550
  """Quotes a list of shell arguments.
1551

1552
  @type args: list
1553
  @param args: list of arguments to be quoted
1554
  @rtype: str
1555
  @return: the quoted arguments concatenated with spaces
1556

1557
  """
1558
  return ' '.join([ShellQuote(i) for i in args])
1559

    
1560

    
1561
def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
1562
  """Simple ping implementation using TCP connect(2).
1563

1564
  Check if the given IP is reachable by doing attempting a TCP connect
1565
  to it.
1566

1567
  @type target: str
1568
  @param target: the IP or hostname to ping
1569
  @type port: int
1570
  @param port: the port to connect to
1571
  @type timeout: int
1572
  @param timeout: the timeout on the connection attempt
1573
  @type live_port_needed: boolean
1574
  @param live_port_needed: whether a closed port will cause the
1575
      function to return failure, as if there was a timeout
1576
  @type source: str or None
1577
  @param source: if specified, will cause the connect to be made
1578
      from this specific source address; failures to bind other
1579
      than C{EADDRNOTAVAIL} will be ignored
1580

1581
  """
1582
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1583

    
1584
  success = False
1585

    
1586
  if source is not None:
1587
    try:
1588
      sock.bind((source, 0))
1589
    except socket.error, (errcode, _):
1590
      if errcode == errno.EADDRNOTAVAIL:
1591
        success = False
1592

    
1593
  sock.settimeout(timeout)
1594

    
1595
  try:
1596
    sock.connect((target, port))
1597
    sock.close()
1598
    success = True
1599
  except socket.timeout:
1600
    success = False
1601
  except socket.error, (errcode, _):
1602
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
1603

    
1604
  return success
1605

    
1606

    
1607
def OwnIpAddress(address):
1608
  """Check if the current host has the the given IP address.
1609

1610
  Currently this is done by TCP-pinging the address from the loopback
1611
  address.
1612

1613
  @type address: string
1614
  @param address: the address to check
1615
  @rtype: bool
1616
  @return: True if we own the address
1617

1618
  """
1619
  return TcpPing(address, constants.DEFAULT_NODED_PORT,
1620
                 source=constants.LOCALHOST_IP_ADDRESS)
1621

    
1622

    
1623
def ListVisibleFiles(path):
1624
  """Returns a list of visible files in a directory.
1625

1626
  @type path: str
1627
  @param path: the directory to enumerate
1628
  @rtype: list
1629
  @return: the list of all files not starting with a dot
1630
  @raise ProgrammerError: if L{path} is not an absolue and normalized path
1631

1632
  """
1633
  if not IsNormAbsPath(path):
1634
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
1635
                                 " absolute/normalized: '%s'" % path)
1636
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1637
  files.sort()
1638
  return files
1639

    
1640

    
1641
def GetHomeDir(user, default=None):
1642
  """Try to get the homedir of the given user.
1643

1644
  The user can be passed either as a string (denoting the name) or as
1645
  an integer (denoting the user id). If the user is not found, the
1646
  'default' argument is returned, which defaults to None.
1647

1648
  """
1649
  try:
1650
    if isinstance(user, basestring):
1651
      result = pwd.getpwnam(user)
1652
    elif isinstance(user, (int, long)):
1653
      result = pwd.getpwuid(user)
1654
    else:
1655
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1656
                                   type(user))
1657
  except KeyError:
1658
    return default
1659
  return result.pw_dir
1660

    
1661

    
1662
def NewUUID():
1663
  """Returns a random UUID.
1664

1665
  @note: This is a Linux-specific method as it uses the /proc
1666
      filesystem.
1667
  @rtype: str
1668

1669
  """
1670
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1671

    
1672

    
1673
def GenerateSecret(numbytes=20):
1674
  """Generates a random secret.
1675

1676
  This will generate a pseudo-random secret returning an hex string
1677
  (so that it can be used where an ASCII string is needed).
1678

1679
  @param numbytes: the number of bytes which will be represented by the returned
1680
      string (defaulting to 20, the length of a SHA1 hash)
1681
  @rtype: str
1682
  @return: an hex representation of the pseudo-random sequence
1683

1684
  """
1685
  return os.urandom(numbytes).encode('hex')
1686

    
1687

    
1688
def EnsureDirs(dirs):
1689
  """Make required directories, if they don't exist.
1690

1691
  @param dirs: list of tuples (dir_name, dir_mode)
1692
  @type dirs: list of (string, integer)
1693

1694
  """
1695
  for dir_name, dir_mode in dirs:
1696
    try:
1697
      os.mkdir(dir_name, dir_mode)
1698
    except EnvironmentError, err:
1699
      if err.errno != errno.EEXIST:
1700
        raise errors.GenericError("Cannot create needed directory"
1701
                                  " '%s': %s" % (dir_name, err))
1702
    if not os.path.isdir(dir_name):
1703
      raise errors.GenericError("%s is not a directory" % dir_name)
1704

    
1705

    
1706
def ReadFile(file_name, size=-1):
1707
  """Reads a file.
1708

1709
  @type size: int
1710
  @param size: Read at most size bytes (if negative, entire file)
1711
  @rtype: str
1712
  @return: the (possibly partial) content of the file
1713

1714
  """
1715
  f = open(file_name, "r")
1716
  try:
1717
    return f.read(size)
1718
  finally:
1719
    f.close()
1720

    
1721

    
1722
def WriteFile(file_name, fn=None, data=None,
1723
              mode=None, uid=-1, gid=-1,
1724
              atime=None, mtime=None, close=True,
1725
              dry_run=False, backup=False,
1726
              prewrite=None, postwrite=None):
1727
  """(Over)write a file atomically.
1728

1729
  The file_name and either fn (a function taking one argument, the
1730
  file descriptor, and which should write the data to it) or data (the
1731
  contents of the file) must be passed. The other arguments are
1732
  optional and allow setting the file mode, owner and group, and the
1733
  mtime/atime of the file.
1734

1735
  If the function doesn't raise an exception, it has succeeded and the
1736
  target file has the new contents. If the function has raised an
1737
  exception, an existing target file should be unmodified and the
1738
  temporary file should be removed.
1739

1740
  @type file_name: str
1741
  @param file_name: the target filename
1742
  @type fn: callable
1743
  @param fn: content writing function, called with
1744
      file descriptor as parameter
1745
  @type data: str
1746
  @param data: contents of the file
1747
  @type mode: int
1748
  @param mode: file mode
1749
  @type uid: int
1750
  @param uid: the owner of the file
1751
  @type gid: int
1752
  @param gid: the group of the file
1753
  @type atime: int
1754
  @param atime: a custom access time to be set on the file
1755
  @type mtime: int
1756
  @param mtime: a custom modification time to be set on the file
1757
  @type close: boolean
1758
  @param close: whether to close file after writing it
1759
  @type prewrite: callable
1760
  @param prewrite: function to be called before writing content
1761
  @type postwrite: callable
1762
  @param postwrite: function to be called after writing content
1763

1764
  @rtype: None or int
1765
  @return: None if the 'close' parameter evaluates to True,
1766
      otherwise the file descriptor
1767

1768
  @raise errors.ProgrammerError: if any of the arguments are not valid
1769

1770
  """
1771
  if not os.path.isabs(file_name):
1772
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1773
                                 " absolute: '%s'" % file_name)
1774

    
1775
  if [fn, data].count(None) != 1:
1776
    raise errors.ProgrammerError("fn or data required")
1777

    
1778
  if [atime, mtime].count(None) == 1:
1779
    raise errors.ProgrammerError("Both atime and mtime must be either"
1780
                                 " set or None")
1781

    
1782
  if backup and not dry_run and os.path.isfile(file_name):
1783
    CreateBackup(file_name)
1784

    
1785
  dir_name, base_name = os.path.split(file_name)
1786
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1787
  do_remove = True
1788
  # here we need to make sure we remove the temp file, if any error
1789
  # leaves it in place
1790
  try:
1791
    if uid != -1 or gid != -1:
1792
      os.chown(new_name, uid, gid)
1793
    if mode:
1794
      os.chmod(new_name, mode)
1795
    if callable(prewrite):
1796
      prewrite(fd)
1797
    if data is not None:
1798
      os.write(fd, data)
1799
    else:
1800
      fn(fd)
1801
    if callable(postwrite):
1802
      postwrite(fd)
1803
    os.fsync(fd)
1804
    if atime is not None and mtime is not None:
1805
      os.utime(new_name, (atime, mtime))
1806
    if not dry_run:
1807
      os.rename(new_name, file_name)
1808
      do_remove = False
1809
  finally:
1810
    if close:
1811
      os.close(fd)
1812
      result = None
1813
    else:
1814
      result = fd
1815
    if do_remove:
1816
      RemoveFile(new_name)
1817

    
1818
  return result
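

# Illustrative sketch (not part of the original module): two common ways of
# calling WriteFile, using a hypothetical path under the system temp
# directory. Passing "data" writes a fixed string atomically, while passing
# "fn" lets the caller write to the temporary file descriptor itself.
def _ExampleWriteFile():
  """Example usage of WriteFile (illustration only)."""
  path = PathJoin(tempfile.gettempdir(), "writefile-example.txt")
  # Simple case: write a fixed string with restrictive permissions
  WriteFile(path, data="hello\n", mode=0600)
  # Streaming case: the callable receives the raw file descriptor
  WriteFile(path, fn=lambda fd: os.write(fd, "written via fn\n"))
  assert ReadFile(path) == "written via fn\n"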
1819

    
1820

    
1821
def ReadOneLineFile(file_name, strict=False):
1822
  """Return the first non-empty line from a file.
1823

1824
  @type strict: boolean
1825
  @param strict: if True, abort if the file has more than one
1826
      non-empty line
1827

1828
  """
1829
  file_lines = ReadFile(file_name).splitlines()
1830
  full_lines = filter(bool, file_lines)
1831
  if not file_lines or not full_lines:
1832
    raise errors.GenericError("No data in one-liner file %s" % file_name)
1833
  elif strict and len(full_lines) > 1:
1834
    raise errors.GenericError("Too many lines in one-liner file %s" %
1835
                              file_name)
1836
  return full_lines[0]
1837

    
1838

    
1839
def FirstFree(seq, base=0):
1840
  """Returns the first non-existing integer from seq.
1841

1842
  The seq argument should be a sorted list of non-negative integers. The
1843
  first time the index of an element is smaller than the element
1844
  value, the index will be returned.
1845

1846
  The base argument is used to start at a different offset,
1847
  i.e. C{[3, 4, 6]} with I{base=3} will return 5.
1848

1849
  Example: C{[0, 1, 3]} will return I{2}.
1850

1851
  @type seq: sequence
1852
  @param seq: the sequence to be analyzed.
1853
  @type base: int
1854
  @param base: use this value as the base index of the sequence
1855
  @rtype: int
1856
  @return: the first non-used index in the sequence
1857

1858
  """
1859
  for idx, elem in enumerate(seq):
1860
    assert elem >= base, "Passed element is smaller than base offset"
1861
    if elem > idx + base:
1862
      # idx is not used
1863
      return idx + base
1864
  return None
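

# Illustrative sketch (not part of the original module): the behaviour of
# FirstFree on a few small inputs, matching the examples in the docstring.
def _ExampleFirstFree():
  """Example usage of FirstFree (illustration only)."""
  assert FirstFree([0, 1, 3]) == 2          # 2 is the first unused index
  assert FirstFree([3, 4, 6], base=3) == 5  # indices are counted from base
  assert FirstFree([0, 1, 2]) is None       # no hole in the sequence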
1865

    
1866

    
1867
def SingleWaitForFdCondition(fdobj, event, timeout):
1868
  """Waits for a condition to occur on the socket.
1869

1870
  Immediately returns at the first interruption.
1871

1872
  @type fdobj: integer or object supporting a fileno() method
1873
  @param fdobj: entity to wait for events on
1874
  @type event: integer
1875
  @param event: ORed condition (see select module)
1876
  @type timeout: float or None
1877
  @param timeout: Timeout in seconds
1878
  @rtype: int or None
1879
  @return: None for timeout, otherwise occurred conditions
1880

1881
  """
1882
  check = (event | select.POLLPRI |
1883
           select.POLLNVAL | select.POLLHUP | select.POLLERR)
1884

    
1885
  if timeout is not None:
1886
    # Poller object expects milliseconds
1887
    timeout *= 1000
1888

    
1889
  poller = select.poll()
1890
  poller.register(fdobj, event)
1891
  try:
1892
    # TODO: If the main thread receives a signal and we have no timeout, we
1893
    # could wait forever. This should check a global "quit" flag or something
1894
    # every so often.
1895
    io_events = poller.poll(timeout)
1896
  except select.error, err:
1897
    if err[0] != errno.EINTR:
1898
      raise
1899
    io_events = []
1900
  if io_events and io_events[0][1] & check:
1901
    return io_events[0][1]
1902
  else:
1903
    return None
1904

    
1905

    
1906
class FdConditionWaiterHelper(object):
1907
  """Retry helper for WaitForFdCondition.
1908

1909
  This class contains the retried and wait functions that make sure
1910
  WaitForFdCondition can continue waiting until the timeout has actually
1911
  expired.
1912

1913
  """
1914

    
1915
  def __init__(self, timeout):
1916
    self.timeout = timeout
1917

    
1918
  def Poll(self, fdobj, event):
1919
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
1920
    if result is None:
1921
      raise RetryAgain()
1922
    else:
1923
      return result
1924

    
1925
  def UpdateTimeout(self, timeout):
1926
    self.timeout = timeout
1927

    
1928

    
1929
def WaitForFdCondition(fdobj, event, timeout):
1930
  """Waits for a condition to occur on the socket.
1931

1932
  Retries until the timeout is expired, even if interrupted.
1933

1934
  @type fdobj: integer or object supporting a fileno() method
1935
  @param fdobj: entity to wait for events on
1936
  @type event: integer
1937
  @param event: ORed condition (see select module)
1938
  @type timeout: float or None
1939
  @param timeout: Timeout in seconds
1940
  @rtype: int or None
1941
  @return: None for timeout, otherwise occurred conditions
1942

1943
  """
1944
  if timeout is not None:
1945
    retrywaiter = FdConditionWaiterHelper(timeout)
1946
    try:
1947
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
1948
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
1949
    except RetryTimeout:
1950
      result = None
1951
  else:
1952
    result = None
1953
    while result is None:
1954
      result = SingleWaitForFdCondition(fdobj, event, timeout)
1955
  return result
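

# Illustrative sketch (not part of the original module): waiting for a pipe
# to become readable. Data is written to the pipe first, so POLLIN is
# reported immediately; on an empty pipe the call would return None once the
# 0.5 second timeout expires.
def _ExampleWaitForFdCondition():
  """Example usage of WaitForFdCondition (illustration only)."""
  (read_fd, write_fd) = os.pipe()
  try:
    os.write(write_fd, "x")
    cond = WaitForFdCondition(read_fd, select.POLLIN, 0.5)
    assert cond is not None and cond & select.POLLIN
  finally:
    _CloseFDNoErr(read_fd)
    _CloseFDNoErr(write_fd)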
1956

    
1957

    
1958
def UniqueSequence(seq):
1959
  """Returns a list with unique elements.
1960

1961
  Element order is preserved.
1962

1963
  @type seq: sequence
1964
  @param seq: the sequence with the source elements
1965
  @rtype: list
1966
  @return: list of unique elements from seq
1967

1968
  """
1969
  seen = set()
1970
  return [i for i in seq if i not in seen and not seen.add(i)]
1971

    
1972

    
1973
def NormalizeAndValidateMac(mac):
1974
  """Normalizes and checks whether a MAC address is valid.
1975

1976
  Checks whether the supplied MAC address is formally correct, only
1977
  accepts the colon-separated format. The address is normalized to lowercase.
1978

1979
  @type mac: str
1980
  @param mac: the MAC to be validated
1981
  @rtype: str
1982
  @return: returns the normalized and validated MAC.
1983

1984
  @raise errors.OpPrereqError: If the MAC isn't valid
1985

1986
  """
1987
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
1988
  if not mac_check.match(mac):
1989
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
1990
                               mac, errors.ECODE_INVAL)
1991

    
1992
  return mac.lower()
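

# Illustrative sketch (not part of the original module): valid MAC addresses
# are returned lower-cased, anything else raises OpPrereqError.
def _ExampleNormalizeAndValidateMac():
  """Example usage of NormalizeAndValidateMac (illustration only)."""
  assert NormalizeAndValidateMac("AA:00:bb:01:CC:02") == "aa:00:bb:01:cc:02"
  try:
    NormalizeAndValidateMac("aa-00-bb-01-cc-02")  # dashes are not accepted
  except errors.OpPrereqError:
    pass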
1993

    
1994

    
1995
def TestDelay(duration):
1996
  """Sleep for a fixed amount of time.
1997

1998
  @type duration: float
1999
  @param duration: the sleep duration
2000
  @rtype: tuple
2001
  @return: (False, error message) for negative values, (True, None) otherwise
2002

2003
  """
2004
  if duration < 0:
2005
    return False, "Invalid sleep duration"
2006
  time.sleep(duration)
2007
  return True, None
2008

    
2009

    
2010
def _CloseFDNoErr(fd, retries=5):
2011
  """Close a file descriptor ignoring errors.
2012

2013
  @type fd: int
2014
  @param fd: the file descriptor
2015
  @type retries: int
2016
  @param retries: how many retries to make, in case we get any
2017
      other error than EBADF
2018

2019
  """
2020
  try:
2021
    os.close(fd)
2022
  except OSError, err:
2023
    if err.errno != errno.EBADF:
2024
      if retries > 0:
2025
        _CloseFDNoErr(fd, retries - 1)
2026
    # else either it's closed already or we're out of retries, so we
2027
    # ignore this and go on
2028

    
2029

    
2030
def CloseFDs(noclose_fds=None):
2031
  """Close file descriptors.
2032

2033
  This closes all file descriptors above 2 (i.e. except
2034
  stdin/out/err).
2035

2036
  @type noclose_fds: list or None
2037
  @param noclose_fds: if given, it denotes a list of file descriptors
2038
      that should not be closed
2039

2040
  """
2041
  # Default maximum for the number of available file descriptors.
2042
  if 'SC_OPEN_MAX' in os.sysconf_names:
2043
    try:
2044
      MAXFD = os.sysconf('SC_OPEN_MAX')
2045
      if MAXFD < 0:
2046
        MAXFD = 1024
2047
    except OSError:
2048
      MAXFD = 1024
2049
  else:
2050
    MAXFD = 1024
2051
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
2052
  if (maxfd == resource.RLIM_INFINITY):
2053
    maxfd = MAXFD
2054

    
2055
  # Iterate through and close all file descriptors (except the standard ones)
2056
  for fd in range(3, maxfd):
2057
    if noclose_fds and fd in noclose_fds:
2058
      continue
2059
    _CloseFDNoErr(fd)
2060

    
2061

    
2062
def Mlockall():
2063
  """Lock current process' virtual address space into RAM.
2064

2065
  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
2066
  see mlock(2) for more details. This function requires ctypes module.
2067

2068
  """
2069
  if ctypes is None:
2070
    logging.warning("Cannot set memory lock, ctypes module not found")
2071
    return
2072

    
2073
  libc = ctypes.cdll.LoadLibrary("libc.so.6")
2074
  if libc is None:
2075
    logging.error("Cannot set memory lock, ctypes cannot load libc")
2076
    return
2077

    
2078
  # Some older versions of the ctypes module don't have built-in functionality
2079
  # to access the errno global variable, where function error codes are stored.
2080
  # By declaring this variable as a pointer to an integer we can then access
2081
  # its value correctly, should the mlockall call fail, in order to see what
2082
  # the actual error code was.
2083
  # pylint: disable-msg=W0212
2084
  libc.__errno_location.restype = ctypes.POINTER(ctypes.c_int)
2085

    
2086
  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
2087
    # pylint: disable-msg=W0212
2088
    logging.error("Cannot set memory lock: %s",
2089
                  os.strerror(libc.__errno_location().contents.value))
2090
    return
2091

    
2092
  logging.debug("Memory lock set")
2093

    
2094

    
2095
def Daemonize(logfile):
2096
  """Daemonize the current process.
2097

2098
  This detaches the current process from the controlling terminal and
2099
  runs it in the background as a daemon.
2100

2101
  @type logfile: str
2102
  @param logfile: the logfile to which we should redirect stdout/stderr
2103
  @rtype: int
2104
  @return: the value zero
2105

2106
  """
2107
  # pylint: disable-msg=W0212
2108
  # yes, we really want os._exit
2109
  UMASK = 077
2110
  WORKDIR = "/"
2111

    
2112
  # this might fail
2113
  pid = os.fork()
2114
  if (pid == 0):  # The first child.
2115
    os.setsid()
2116
    # this might fail
2117
    pid = os.fork() # Fork a second child.
2118
    if (pid == 0):  # The second child.
2119
      os.chdir(WORKDIR)
2120
      os.umask(UMASK)
2121
    else:
2122
      # exit() or _exit()?  See below.
2123
      os._exit(0) # Exit parent (the first child) of the second child.
2124
  else:
2125
    os._exit(0) # Exit parent of the first child.
2126

    
2127
  for fd in range(3):
2128
    _CloseFDNoErr(fd)
2129
  i = os.open("/dev/null", os.O_RDONLY) # stdin
2130
  assert i == 0, "Can't close/reopen stdin"
2131
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
2132
  assert i == 1, "Can't close/reopen stdout"
2133
  # Duplicate standard output to standard error.
2134
  os.dup2(1, 2)
2135
  return 0
2136

    
2137

    
2138
def DaemonPidFileName(name):
2139
  """Compute a ganeti pid file's absolute path.
2140

2141
  @type name: str
2142
  @param name: the daemon name
2143
  @rtype: str
2144
  @return: the full path to the pidfile corresponding to the given
2145
      daemon name
2146

2147
  """
2148
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
2149

    
2150

    
2151
def EnsureDaemon(name):
2152
  """Check for and start daemon if not alive.
2153

2154
  """
2155
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
2156
  if result.failed:
2157
    logging.error("Can't start daemon '%s', failure %s, output: %s",
2158
                  name, result.fail_reason, result.output)
2159
    return False
2160

    
2161
  return True
2162

    
2163

    
2164
def WritePidFile(name):
2165
  """Write the current process pidfile.
2166

2167
  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}
2168

2169
  @type name: str
2170
  @param name: the daemon name to use
2171
  @raise errors.GenericError: if the pid file already exists and
2172
      points to a live process
2173

2174
  """
2175
  pid = os.getpid()
2176
  pidfilename = DaemonPidFileName(name)
2177
  if IsProcessAlive(ReadPidFile(pidfilename)):
2178
    raise errors.GenericError("%s contains a live process" % pidfilename)
2179

    
2180
  WriteFile(pidfilename, data="%d\n" % pid)
2181

    
2182

    
2183
def RemovePidFile(name):
2184
  """Remove the current process pidfile.
2185

2186
  Any errors are ignored.
2187

2188
  @type name: str
2189
  @param name: the daemon name used to derive the pidfile name
2190

2191
  """
2192
  pidfilename = DaemonPidFileName(name)
2193
  # TODO: we could check here that the file contains our pid
2194
  try:
2195
    RemoveFile(pidfilename)
2196
  except: # pylint: disable-msg=W0702
2197
    pass
2198

    
2199

    
2200
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
2201
                waitpid=False):
2202
  """Kill a process given by its pid.
2203

2204
  @type pid: int
2205
  @param pid: The PID to terminate.
2206
  @type signal_: int
2207
  @param signal_: The signal to send, by default SIGTERM
2208
  @type timeout: int
2209
  @param timeout: The timeout after which, if the process is still alive,
2210
                  a SIGKILL will be sent. If not positive, no such checking
2211
                  will be done
2212
  @type waitpid: boolean
2213
  @param waitpid: If true, we should waitpid on this process after
2214
      sending signals, since it's our own child and otherwise it
2215
      would remain as zombie
2216

2217
  """
2218
  def _helper(pid, signal_, wait):
2219
    """Simple helper to encapsulate the kill/waitpid sequence"""
2220
    os.kill(pid, signal_)
2221
    if wait:
2222
      try:
2223
        os.waitpid(pid, os.WNOHANG)
2224
      except OSError:
2225
        pass
2226

    
2227
  if pid <= 0:
2228
    # kill with pid=0 == suicide
2229
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
2230

    
2231
  if not IsProcessAlive(pid):
2232
    return
2233

    
2234
  _helper(pid, signal_, waitpid)
2235

    
2236
  if timeout <= 0:
2237
    return
2238

    
2239
  def _CheckProcess():
2240
    if not IsProcessAlive(pid):
2241
      return
2242

    
2243
    try:
2244
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
2245
    except OSError:
2246
      raise RetryAgain()
2247

    
2248
    if result_pid > 0:
2249
      return
2250

    
2251
    raise RetryAgain()
2252

    
2253
  try:
2254
    # Wait up to $timeout seconds
2255
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
2256
  except RetryTimeout:
2257
    pass
2258

    
2259
  if IsProcessAlive(pid):
2260
    # Kill process if it's still alive
2261
    _helper(pid, signal.SIGKILL, waitpid)
2262

    
2263

    
2264
def FindFile(name, search_path, test=os.path.exists):
2265
  """Look for a filesystem object in a given path.
2266

2267
  This is an abstract method to search for filesystem object (files,
2268
  dirs) under a given search path.
2269

2270
  @type name: str
2271
  @param name: the name to look for
2272
  @type search_path: str
2273
  @param search_path: location to start at
2274
  @type test: callable
2275
  @param test: a function taking one argument that should return True
2276
      if a given object is valid; the default value is
2277
      os.path.exists, causing only existing files to be returned
2278
  @rtype: str or None
2279
  @return: full path to the object if found, None otherwise
2280

2281
  """
2282
  # validate the filename mask
2283
  if constants.EXT_PLUGIN_MASK.match(name) is None:
2284
    logging.critical("Invalid value passed for external script name: '%s'",
2285
                     name)
2286
    return None
2287

    
2288
  for dir_name in search_path:
2289
    # FIXME: investigate switch to PathJoin
2290
    item_name = os.path.sep.join([dir_name, name])
2291
    # check the user test and that we're indeed resolving to the given
2292
    # basename
2293
    if test(item_name) and os.path.basename(item_name) == name:
2294
      return item_name
2295
  return None
2296

    
2297

    
2298
def CheckVolumeGroupSize(vglist, vgname, minsize):
2299
  """Checks if the volume group list is valid.
2300

2301
  The function will check if a given volume group is in the list of
2302
  volume groups and has a minimum size.
2303

2304
  @type vglist: dict
2305
  @param vglist: dictionary of volume group names and their size
2306
  @type vgname: str
2307
  @param vgname: the volume group we should check
2308
  @type minsize: int
2309
  @param minsize: the minimum size we accept
2310
  @rtype: None or str
2311
  @return: None for success, otherwise the error message
2312

2313
  """
2314
  vgsize = vglist.get(vgname, None)
2315
  if vgsize is None:
2316
    return "volume group '%s' missing" % vgname
2317
  elif vgsize < minsize:
2318
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
2319
            (vgname, minsize, vgsize))
2320
  return None
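

# Illustrative sketch (not part of the original module): the three possible
# outcomes of CheckVolumeGroupSize for a made-up volume group listing
# (sizes in MiB).
def _ExampleCheckVolumeGroupSize():
  """Example usage of CheckVolumeGroupSize (illustration only)."""
  vglist = {"xenvg": 20480}
  assert CheckVolumeGroupSize(vglist, "xenvg", 10240) is None
  assert "too small" in CheckVolumeGroupSize(vglist, "xenvg", 40960)
  assert "missing" in CheckVolumeGroupSize(vglist, "othervg", 10240)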
2321

    
2322

    
2323
def SplitTime(value):
2324
  """Splits time as floating point number into a tuple.
2325

2326
  @param value: Time in seconds
2327
  @type value: int or float
2328
  @return: Tuple containing (seconds, microseconds)
2329

2330
  """
2331
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
2332

    
2333
  assert 0 <= seconds, \
2334
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2335
  assert 0 <= microseconds <= 999999, \
2336
    "Microseconds must be 0-999999, but are %s" % microseconds
2337

    
2338
  return (int(seconds), int(microseconds))
2339

    
2340

    
2341
def MergeTime(timetuple):
2342
  """Merges a tuple into time as a floating point number.
2343

2344
  @param timetuple: Time as tuple, (seconds, microseconds)
2345
  @type timetuple: tuple
2346
  @return: Time as a floating point number expressed in seconds
2347

2348
  """
2349
  (seconds, microseconds) = timetuple
2350

    
2351
  assert 0 <= seconds, \
2352
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2353
  assert 0 <= microseconds <= 999999, \
2354
    "Microseconds must be 0-999999, but are %s" % microseconds
2355

    
2356
  return float(seconds) + (float(microseconds) * 0.000001)
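

# Illustrative sketch (not part of the original module): SplitTime and
# MergeTime are inverses of each other at microsecond resolution.
def _ExampleSplitMergeTime():
  """Example usage of SplitTime/MergeTime (illustration only)."""
  assert SplitTime(1234.5) == (1234, 500000)
  assert abs(MergeTime((1234, 500000)) - 1234.5) < 1e-6
  now = time.time()
  assert abs(MergeTime(SplitTime(now)) - now) < 1e-5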
2357

    
2358

    
2359
def GetDaemonPort(daemon_name):
2360
  """Get the daemon port for this cluster.
2361

2362
  Note that this routine does not read a ganeti-specific file, but
2363
  instead uses C{socket.getservbyname} to allow pre-customization of
2364
  this parameter outside of Ganeti.
2365

2366
  @type daemon_name: string
2367
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
2368
  @rtype: int
2369

2370
  """
2371
  if daemon_name not in constants.DAEMONS_PORTS:
2372
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)
2373

    
2374
  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
2375
  try:
2376
    port = socket.getservbyname(daemon_name, proto)
2377
  except socket.error:
2378
    port = default_port
2379

    
2380
  return port
2381

    
2382

    
2383
class LogFileHandler(logging.FileHandler):
2384
  """Log handler that doesn't fallback to stderr.
2385

2386
  When an error occurs while writing to the logfile, logging.FileHandler tries
2387
  to log on stderr. This doesn't work in ganeti since stderr is redirected to
2388
  the logfile. This class instead reports such errors to /dev/console.
2389

2390
  """
2391
  def __init__(self, filename, mode="a", encoding=None):
2392
    """Open the specified file and use it as the stream for logging.
2393

2394
    Also open /dev/console to report errors while logging.
2395

2396
    """
2397
    logging.FileHandler.__init__(self, filename, mode, encoding)
2398
    self.console = open(constants.DEV_CONSOLE, "a")
2399

    
2400
  def handleError(self, record): # pylint: disable-msg=C0103
2401
    """Handle errors which occur during an emit() call.
2402

2403
    Try to handle errors with FileHandler method, if it fails write to
2404
    /dev/console.
2405

2406
    """
2407
    try:
2408
      logging.FileHandler.handleError(self, record)
2409
    except Exception: # pylint: disable-msg=W0703
2410
      try:
2411
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
2412
      except Exception: # pylint: disable-msg=W0703
2413
        # Log handler tried everything it could, now just give up
2414
        pass
2415

    
2416

    
2417
def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
2418
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
2419
                 console_logging=False):
2420
  """Configures the logging module.
2421

2422
  @type logfile: str
2423
  @param logfile: the filename to which we should log
2424
  @type debug: integer
2425
  @param debug: if greater than zero, enable debug messages, otherwise
2426
      only those at C{INFO} and above level
2427
  @type stderr_logging: boolean
2428
  @param stderr_logging: whether we should also log to the standard error
2429
  @type program: str
2430
  @param program: the name under which we should log messages
2431
  @type multithreaded: boolean
2432
  @param multithreaded: if True, will add the thread name to the log file
2433
  @type syslog: string
2434
  @param syslog: one of 'no', 'yes', 'only':
2435
      - if no, syslog is not used
2436
      - if yes, syslog is used (in addition to file-logging)
2437
      - if only, only syslog is used
2438
  @type console_logging: boolean
2439
  @param console_logging: if True, will use a FileHandler which falls back to
2440
      the system console if logging fails
2441
  @raise EnvironmentError: if we can't open the log file and
2442
      syslog/stderr logging is disabled
2443

2444
  """
2445
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
2446
  sft = program + "[%(process)d]:"
2447
  if multithreaded:
2448
    fmt += "/%(threadName)s"
2449
    sft += " (%(threadName)s)"
2450
  if debug:
2451
    fmt += " %(module)s:%(lineno)s"
2452
    # no debug info for syslog loggers
2453
  fmt += " %(levelname)s %(message)s"
2454
  # yes, we do want the textual level, as remote syslog will probably
2455
  # lose the error level, and it's easier to grep for it
2456
  sft += " %(levelname)s %(message)s"
2457
  formatter = logging.Formatter(fmt)
2458
  sys_fmt = logging.Formatter(sft)
2459

    
2460
  root_logger = logging.getLogger("")
2461
  root_logger.setLevel(logging.NOTSET)
2462

    
2463
  # Remove all previously setup handlers
2464
  for handler in root_logger.handlers:
2465
    handler.close()
2466
    root_logger.removeHandler(handler)
2467

    
2468
  if stderr_logging:
2469
    stderr_handler = logging.StreamHandler()
2470
    stderr_handler.setFormatter(formatter)
2471
    if debug:
2472
      stderr_handler.setLevel(logging.NOTSET)
2473
    else:
2474
      stderr_handler.setLevel(logging.CRITICAL)
2475
    root_logger.addHandler(stderr_handler)
2476

    
2477
  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
2478
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
2479
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
2480
                                                    facility)
2481
    syslog_handler.setFormatter(sys_fmt)
2482
    # Never enable debug over syslog
2483
    syslog_handler.setLevel(logging.INFO)
2484
    root_logger.addHandler(syslog_handler)
2485

    
2486
  if syslog != constants.SYSLOG_ONLY:
2487
    # this can fail, if the logging directories are not setup or we have
2488
    # a permission problem; in this case, it's best to log but ignore
2489
    # the error if stderr_logging is True, and if false we re-raise the
2490
    # exception since otherwise we could run but without any logs at all
2491
    try:
2492
      if console_logging:
2493
        logfile_handler = LogFileHandler(logfile)
2494
      else:
2495
        logfile_handler = logging.FileHandler(logfile)
2496
      logfile_handler.setFormatter(formatter)
2497
      if debug:
2498
        logfile_handler.setLevel(logging.DEBUG)
2499
      else:
2500
        logfile_handler.setLevel(logging.INFO)
2501
      root_logger.addHandler(logfile_handler)
2502
    except EnvironmentError:
2503
      if stderr_logging or syslog == constants.SYSLOG_YES:
2504
        logging.exception("Failed to enable logging to file '%s'", logfile)
2505
      else:
2506
        # we need to re-raise the exception
2507
        raise
2508

    
2509

    
2510
def IsNormAbsPath(path):
2511
  """Check whether a path is absolute and also normalized
2512

2513
  This prevents paths like /dir/../../other/path from being considered valid.
2514

2515
  """
2516
  return os.path.normpath(path) == path and os.path.isabs(path)
2517

    
2518

    
2519
def PathJoin(*args):
2520
  """Safe-join a list of path components.
2521

2522
  Requirements:
2523
      - the first argument must be an absolute path
2524
      - no component in the path must have backtracking (e.g. /../),
2525
        since we check for normalization at the end
2526

2527
  @param args: the path components to be joined
2528
  @raise ValueError: for invalid paths
2529

2530
  """
2531
  # ensure we're having at least one path passed in
2532
  assert args
2533
  # ensure the first component is an absolute and normalized path name
2534
  root = args[0]
2535
  if not IsNormAbsPath(root):
2536
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
2537
  result = os.path.join(*args)
2538
  # ensure that the whole path is normalized
2539
  if not IsNormAbsPath(result):
2540
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
2541
  # check that we're still under the original prefix
2542
  prefix = os.path.commonprefix([root, result])
2543
  if prefix != root:
2544
    raise ValueError("Error: path joining resulted in different prefix"
2545
                     " (%s != %s)" % (prefix, root))
2546
  return result
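

# Illustrative sketch (not part of the original module): PathJoin only
# accepts joins that stay under the (absolute, normalized) first component.
def _ExamplePathJoin():
  """Example usage of PathJoin (illustration only)."""
  assert PathJoin("/var", "log", "ganeti") == "/var/log/ganeti"
  try:
    PathJoin("/var", "../etc")  # escapes the prefix, hence rejected
  except ValueError:
    pass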
2547

    
2548

    
2549
def TailFile(fname, lines=20):
2550
  """Return the last lines from a file.
2551

2552
  @note: this function will only read and parse the last 4KB of
2553
      the file; if the lines are very long, it could be that less
2554
      than the requested number of lines are returned
2555

2556
  @param fname: the file name
2557
  @type lines: int
2558
  @param lines: the (maximum) number of lines to return
2559

2560
  """
2561
  fd = open(fname, "r")
2562
  try:
2563
    fd.seek(0, 2)
2564
    pos = fd.tell()
2565
    pos = max(0, pos-4096)
2566
    fd.seek(pos, 0)
2567
    raw_data = fd.read()
2568
  finally:
2569
    fd.close()
2570

    
2571
  rows = raw_data.splitlines()
2572
  return rows[-lines:]
2573

    
2574

    
2575
def FormatTimestampWithTZ(secs):
2576
  """Formats a Unix timestamp with the local timezone.
2577

2578
  """
2579
  return time.strftime("%F %T %Z", time.gmtime(secs))
2580

    
2581

    
2582
def _ParseAsn1Generalizedtime(value):
2583
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
2584

2585
  @type value: string
2586
  @param value: ASN1 GENERALIZEDTIME timestamp
2587

2588
  """
2589
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
2590
  if m:
2591
    # We have an offset
2592
    asn1time = m.group(1)
2593
    hours = int(m.group(2))
2594
    minutes = int(m.group(3))
2595
    utcoffset = (60 * hours) + minutes
2596
  else:
2597
    if not value.endswith("Z"):
2598
      raise ValueError("Missing timezone")
2599
    asn1time = value[:-1]
2600
    utcoffset = 0
2601

    
2602
  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
2603

    
2604
  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)
2605

    
2606
  return calendar.timegm(tt.utctimetuple())
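

# Illustrative sketch (not part of the original module): parsing a
# "Z"-suffixed (UTC) ASN1 GENERALIZEDTIME value as produced by pyOpenSSL.
def _ExampleParseAsn1Generalizedtime():
  """Example usage of _ParseAsn1Generalizedtime (illustration only)."""
  # 2010-01-01 12:00:00 UTC
  assert _ParseAsn1Generalizedtime("20100101120000Z") == 1262347200
  try:
    _ParseAsn1Generalizedtime("20100101120000")  # no timezone given
  except ValueError:
    pass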
2607

    
2608

    
2609
def GetX509CertValidity(cert):
2610
  """Returns the validity period of the certificate.
2611

2612
  @type cert: OpenSSL.crypto.X509
2613
  @param cert: X509 certificate object
2614

2615
  """
2616
  # The get_notBefore and get_notAfter functions are only supported in
2617
  # pyOpenSSL 0.7 and above.
2618
  try:
2619
    get_notbefore_fn = cert.get_notBefore
2620
  except AttributeError:
2621
    not_before = None
2622
  else:
2623
    not_before_asn1 = get_notbefore_fn()
2624

    
2625
    if not_before_asn1 is None:
2626
      not_before = None
2627
    else:
2628
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)
2629

    
2630
  try:
2631
    get_notafter_fn = cert.get_notAfter
2632
  except AttributeError:
2633
    not_after = None
2634
  else:
2635
    not_after_asn1 = get_notafter_fn()
2636

    
2637
    if not_after_asn1 is None:
2638
      not_after = None
2639
    else:
2640
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)
2641

    
2642
  return (not_before, not_after)
2643

    
2644

    
2645
def _VerifyCertificateInner(expired, not_before, not_after, now,
2646
                            warn_days, error_days):
2647
  """Verifies certificate validity.
2648

2649
  @type expired: bool
2650
  @param expired: Whether pyOpenSSL considers the certificate as expired
2651
  @type not_before: number or None
2652
  @param not_before: Unix timestamp before which certificate is not valid
2653
  @type not_after: number or None
2654
  @param not_after: Unix timestamp after which certificate is invalid
2655
  @type now: number
2656
  @param now: Current time as Unix timestamp
2657
  @type warn_days: number or None
2658
  @param warn_days: How many days before expiration a warning should be reported
2659
  @type error_days: number or None
2660
  @param error_days: How many days before expiration an error should be reported
2661

2662
  """
2663
  if expired:
2664
    msg = "Certificate is expired"
2665

    
2666
    if not_before is not None and not_after is not None:
2667
      msg += (" (valid from %s to %s)" %
2668
              (FormatTimestampWithTZ(not_before),
2669
               FormatTimestampWithTZ(not_after)))
2670
    elif not_before is not None:
2671
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
2672
    elif not_after is not None:
2673
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)
2674

    
2675
    return (CERT_ERROR, msg)
2676

    
2677
  elif not_before is not None and not_before > now:
2678
    return (CERT_WARNING,
2679
            "Certificate not yet valid (valid from %s)" %
2680
            FormatTimestampWithTZ(not_before))
2681

    
2682
  elif not_after is not None:
2683
    remaining_days = int((not_after - now) / (24 * 3600))
2684

    
2685
    msg = "Certificate expires in about %d days" % remaining_days
2686

    
2687
    if error_days is not None and remaining_days <= error_days:
2688
      return (CERT_ERROR, msg)
2689

    
2690
    if warn_days is not None and remaining_days <= warn_days:
2691
      return (CERT_WARNING, msg)
2692

    
2693
  return (None, None)
2694

    
2695

    
2696
def VerifyX509Certificate(cert, warn_days, error_days):
2697
  """Verifies a certificate for LUVerifyCluster.
2698

2699
  @type cert: OpenSSL.crypto.X509
2700
  @param cert: X509 certificate object
2701
  @type warn_days: number or None
2702
  @param warn_days: How many days before expiration a warning should be reported
2703
  @type error_days: number or None
2704
  @param error_days: How many days before expiration an error should be reported
2705

2706
  """
2707
  # Depending on the pyOpenSSL version, this can just return (None, None)
2708
  (not_before, not_after) = GetX509CertValidity(cert)
2709

    
2710
  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
2711
                                 time.time(), warn_days, error_days)
2712

    
2713

    
2714
def SignX509Certificate(cert, key, salt):
2715
  """Sign a X509 certificate.
2716

2717
  An RFC822-like signature header is added in front of the certificate.
2718

2719
  @type cert: OpenSSL.crypto.X509
2720
  @param cert: X509 certificate object
2721
  @type key: string
2722
  @param key: Key for HMAC
2723
  @type salt: string
2724
  @param salt: Salt for HMAC
2725
  @rtype: string
2726
  @return: Serialized and signed certificate in PEM format
2727

2728
  """
2729
  if not VALID_X509_SIGNATURE_SALT.match(salt):
2730
    raise errors.GenericError("Invalid salt: %r" % salt)
2731

    
2732
  # Dumping as PEM here ensures the certificate is in a sane format
2733
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2734

    
2735
  return ("%s: %s/%s\n\n%s" %
2736
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
2737
           Sha1Hmac(key, cert_pem, salt=salt),
2738
           cert_pem))
2739

    
2740

    
2741
def _ExtractX509CertificateSignature(cert_pem):
2742
  """Helper function to extract signature from X509 certificate.
2743

2744
  """
2745
  # Extract signature from original PEM data
2746
  for line in cert_pem.splitlines():
2747
    if line.startswith("---"):
2748
      break
2749

    
2750
    m = X509_SIGNATURE.match(line.strip())
2751
    if m:
2752
      return (m.group("salt"), m.group("sign"))
2753

    
2754
  raise errors.GenericError("X509 certificate signature is missing")
2755

    
2756

    
2757
def LoadSignedX509Certificate(cert_pem, key):
2758
  """Verifies a signed X509 certificate.
2759

2760
  @type cert_pem: string
2761
  @param cert_pem: Certificate in PEM format and with signature header
2762
  @type key: string
2763
  @param key: Key for HMAC
2764
  @rtype: tuple; (OpenSSL.crypto.X509, string)
2765
  @return: X509 certificate object and salt
2766

2767
  """
2768
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)
2769

    
2770
  # Load certificate
2771
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
2772

    
2773
  # Dump again to ensure it's in a sane format
2774
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2775

    
2776
  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
2777
    raise errors.GenericError("X509 certificate signature is invalid")
2778

    
2779
  return (cert, salt)
2780

    
2781

    
2782
def Sha1Hmac(key, text, salt=None):
2783
  """Calculates the HMAC-SHA1 digest of a text.
2784

2785
  HMAC is defined in RFC2104.
2786

2787
  @type key: string
2788
  @param key: Secret key
2789
  @type text: string
2790

2791
  """
2792
  if salt:
2793
    salted_text = salt + text
2794
  else:
2795
    salted_text = text
2796

    
2797
  return hmac.new(key, salted_text, compat.sha1).hexdigest()
2798

    
2799

    
2800
def VerifySha1Hmac(key, text, digest, salt=None):
2801
  """Verifies the HMAC-SHA1 digest of a text.
2802

2803
  HMAC is defined in RFC2104.
2804

2805
  @type key: string
2806
  @param key: Secret key
2807
  @type text: string
2808
  @type digest: string
2809
  @param digest: Expected digest
2810
  @rtype: bool
2811
  @return: Whether HMAC-SHA1 digest matches
2812

2813
  """
2814
  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()
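

# Illustrative sketch (not part of the original module): signing a text with
# a made-up key and salt, then verifying the resulting digest.
def _ExampleSha1Hmac():
  """Example usage of Sha1Hmac/VerifySha1Hmac (illustration only)."""
  key = "example-secret-key"
  digest = Sha1Hmac(key, "some payload", salt="abcdef")
  assert VerifySha1Hmac(key, "some payload", digest, salt="abcdef")
  assert not VerifySha1Hmac(key, "tampered payload", digest, salt="abcdef")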
2815

    
2816

    
2817
def SafeEncode(text):
2818
  """Return a 'safe' version of a source string.
2819

2820
  This function mangles the input string and returns a version that
2821
  should be safe to display/encode as ASCII. To this end, we first
2822
  convert it to ASCII using the 'backslashreplace' encoding which
2823
  should get rid of any non-ASCII chars, and then we process it
2824
  through a loop copied from Python's string repr implementation; we
2825
  don't use string_escape anymore since that escapes single quotes and
2826
  backslashes too, and that is too much; and that escaping is not
2827
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
2828

2829
  @type text: str or unicode
2830
  @param text: input data
2831
  @rtype: str
2832
  @return: a safe version of text
2833

2834
  """
2835
  if isinstance(text, unicode):
2836
    # only if unicode; if str already, we handle it below
2837
    text = text.encode('ascii', 'backslashreplace')
2838
  resu = ""
2839
  for char in text:
2840
    c = ord(char)
2841
    if char  == '\t':
2842
      resu += r'\t'
2843
    elif char == '\n':
2844
      resu += r'\n'
2845
    elif char == '\r':
2846
      resu += r'\r'
2847
    elif c < 32 or c >= 127: # non-printable
2848
      resu += "\\x%02x" % (c & 0xff)
2849
    else:
2850
      resu += char
2851
  return resu
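

# Illustrative sketch (not part of the original module): control characters
# and non-ASCII input are turned into plain-ASCII escape sequences.
def _ExampleSafeEncode():
  """Example usage of SafeEncode (illustration only)."""
  assert SafeEncode("foo\tbar\n") == r"foo\tbar\n"
  assert SafeEncode("bell\x07") == r"bell\x07"
  assert SafeEncode(u"snowman \u2603") == r"snowman \u2603"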
2852

    
2853

    
2854
def UnescapeAndSplit(text, sep=","):
2855
  """Split and unescape a string based on a given separator.
2856

2857
  This function splits a string based on a separator where the
2858
  separator itself can be escaped in order to appear inside an
2859
  element. The escaping rules are (assuming comma is the
2860
  separator):
2861
    - a plain , separates the elements
2862
    - a sequence \\\\, (double backslash plus comma) is handled as a
2863
      backslash plus a separator comma
2864
    - a sequence \, (backslash plus comma) is handled as a
2865
      non-separator comma
2866

2867
  @type text: string
2868
  @param text: the string to split
2869
  @type sep: string
2870
  @param sep: the separator
2871
  @rtype: list
2872
  @return: a list of strings
2873

2874
  """
2875
  # we split the list by sep (with no escaping at this stage)
2876
  slist = text.split(sep)
2877
  # next, we revisit the elements and if any of them ended with an odd
2878
  # number of backslashes, then we join it with the next
2879
  rlist = []
2880
  while slist:
2881
    e1 = slist.pop(0)
2882
    if e1.endswith("\\"):
2883
      num_b = len(e1) - len(e1.rstrip("\\"))
2884
      if num_b % 2 == 1:
2885
        e2 = slist.pop(0)
2886
        # here the backslashes remain (all), and will be reduced in
2887
        # the next step
2888
        rlist.append(e1 + sep + e2)
2889
        continue
2890
    rlist.append(e1)
2891
  # finally, replace backslash-something with something
2892
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
2893
  return rlist
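

# Illustrative sketch (not part of the original module): how escaped and
# unescaped separators are handled, matching the rules in the docstring.
def _ExampleUnescapeAndSplit():
  """Example usage of UnescapeAndSplit (illustration only)."""
  assert UnescapeAndSplit("a,b,c") == ["a", "b", "c"]
  assert UnescapeAndSplit("a\\,b,c") == ["a,b", "c"]         # escaped comma
  assert UnescapeAndSplit("a\\\\,b,c") == ["a\\", "b", "c"]  # escaped backslash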
2894

    
2895

    
2896
def CommaJoin(names):
2897
  """Nicely join a set of identifiers.
2898

2899
  @param names: set, list or tuple
2900
  @return: a string with the formatted results
2901

2902
  """
2903
  return ", ".join([str(val) for val in names])
2904

    
2905

    
2906
def BytesToMebibyte(value):
2907
  """Converts bytes to mebibytes.
2908

2909
  @type value: int
2910
  @param value: Value in bytes
2911
  @rtype: int
2912
  @return: Value in mebibytes
2913

2914
  """
2915
  return int(round(value / (1024.0 * 1024.0), 0))
2916

    
2917

    
2918
def CalculateDirectorySize(path):
2919
  """Calculates the size of a directory recursively.
2920

2921
  @type path: string
2922
  @param path: Path to directory
2923
  @rtype: int
2924
  @return: Size in mebibytes
2925

2926
  """
2927
  size = 0
2928

    
2929
  for (curpath, _, files) in os.walk(path):
2930
    for filename in files:
2931
      st = os.lstat(PathJoin(curpath, filename))
2932
      size += st.st_size
2933

    
2934
  return BytesToMebibyte(size)
2935

    
2936

    
2937
def GetFilesystemStats(path):
2938
  """Returns the total and free space on a filesystem.
2939

2940
  @type path: string
2941
  @param path: Path on filesystem to be examined
2942
  @rtype: int
2943
  @return: tuple of (Total space, Free space) in mebibytes
2944

2945
  """
2946
  st = os.statvfs(path)
2947

    
2948
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
2949
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
2950
  return (tsize, fsize)
2951

    
2952

    
2953
def RunInSeparateProcess(fn, *args):
2954
  """Runs a function in a separate process.
2955

2956
  Note: Only boolean return values are supported.
2957

2958
  @type fn: callable
2959
  @param fn: Function to be called
2960
  @rtype: bool
2961
  @return: Function's result
2962

2963
  """
2964
  pid = os.fork()
2965
  if pid == 0:
2966
    # Child process
2967
    try:
2968
      # In case the function uses temporary files
2969
      ResetTempfileModule()
2970

    
2971
      # Call function
2972
      result = int(bool(fn(*args)))
2973
      assert result in (0, 1)
2974
    except: # pylint: disable-msg=W0702
2975
      logging.exception("Error while calling function in separate process")
2976
      # 0 and 1 are reserved for the return value
2977
      result = 33
2978

    
2979
    os._exit(result) # pylint: disable-msg=W0212
2980

    
2981
  # Parent process
2982

    
2983
  # Avoid zombies and check exit code
2984
  (_, status) = os.waitpid(pid, 0)
2985

    
2986
  if os.WIFSIGNALED(status):
2987
    exitcode = None
2988
    signum = os.WTERMSIG(status)
2989
  else:
2990
    exitcode = os.WEXITSTATUS(status)
2991
    signum = None
2992

    
2993
  if not (exitcode in (0, 1) and signum is None):
2994
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
2995
                              (exitcode, signum))
2996

    
2997
  return bool(exitcode)
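

# Illustrative sketch (not part of the original module): the child's boolean
# result is passed back via its exit code; an exception raised in the child
# surfaces as a GenericError in the parent.
def _ExampleRunInSeparateProcess():
  """Example usage of RunInSeparateProcess (illustration only)."""
  assert RunInSeparateProcess(lambda: True) is True
  assert RunInSeparateProcess(lambda a, b: a > b, 2, 1) is True
  assert RunInSeparateProcess(lambda: []) is False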
2998

    
2999

    
3000
def IgnoreSignals(fn, *args, **kwargs):
3001
  """Tries to call a function ignoring failures due to EINTR.
3002

3003
  """
3004
  try:
3005
    return fn(*args, **kwargs)
3006
  except EnvironmentError, err:
3007
    if err.errno != errno.EINTR:
3008
      raise
3009
  except (select.error, socket.error), err:
3010
    # In python 2.6 and above select.error is an IOError, so it's handled
3011
    # above, in 2.5 and below it's not, and it's handled here.
3012
    if not (err.args and err.args[0] == errno.EINTR):
3013
      raise
3014

    
3015

    
3016
def LockedMethod(fn):
3017
  """Synchronized object access decorator.
3018

3019
  This decorator is intended to protect access to an object using the
3020
  object's own lock which is hardcoded to '_lock'.
3021

3022
  """
3023
  def _LockDebug(*args, **kwargs):
3024
    if debug_locks:
3025
      logging.debug(*args, **kwargs)
3026

    
3027
  def wrapper(self, *args, **kwargs):
3028
    # pylint: disable-msg=W0212
3029
    assert hasattr(self, '_lock')
3030
    lock = self._lock
3031
    _LockDebug("Waiting for %s", lock)
3032
    lock.acquire()
3033
    try:
3034
      _LockDebug("Acquired %s", lock)
3035
      result = fn(self, *args, **kwargs)
3036
    finally:
3037
      _LockDebug("Releasing %s", lock)
3038
      lock.release()
3039
      _LockDebug("Released %s", lock)
3040
    return result
3041
  return wrapper
3042

    
3043

    
3044
def LockFile(fd):
3045
  """Locks a file using POSIX locks.
3046

3047
  @type fd: int
3048
  @param fd: the file descriptor we need to lock
3049

3050
  """
3051
  try:
3052
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
3053
  except IOError, err:
3054
    if err.errno == errno.EAGAIN:
3055
      raise errors.LockError("File already locked")
3056
    raise
3057

    
3058

    
3059
def FormatTime(val):
3060
  """Formats a time value.
3061

3062
  @type val: float or None
3063
  @param val: the timestamp as returned by time.time()
3064
  @return: a string value or N/A if we don't have a valid timestamp
3065

3066
  """
3067
  if val is None or not isinstance(val, (int, float)):
3068
    return "N/A"
3069
  # these two format codes work on Linux, but they are not guaranteed on all
3070
  # platforms
3071
  return time.strftime("%F %T", time.localtime(val))
3072

    
3073

    
3074
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
3075
  """Reads the watcher pause file.
3076

3077
  @type filename: string
3078
  @param filename: Path to watcher pause file
3079
  @type now: None, float or int
3080
  @param now: Current time as Unix timestamp
3081
  @type remove_after: int
3082
  @param remove_after: Remove watcher pause file after specified amount of
3083
    seconds past the pause end time
3084

3085
  """
3086
  if now is None:
3087
    now = time.time()
3088

    
3089
  try:
3090
    value = ReadFile(filename)
3091
  except IOError, err:
3092
    if err.errno != errno.ENOENT:
3093
      raise
3094
    value = None
3095

    
3096
  if value is not None:
3097
    try:
3098
      value = int(value)
3099
    except ValueError:
3100
      logging.warning(("Watcher pause file (%s) contains invalid value,"
3101
                       " removing it"), filename)
3102
      RemoveFile(filename)
3103
      value = None
3104

    
3105
    if value is not None:
3106
      # Remove file if it's outdated
3107
      if now > (value + remove_after):
3108
        RemoveFile(filename)
3109
        value = None
3110

    
3111
      elif now > value:
3112
        value = None
3113

    
3114
  return value
3115

    
3116

    
3117
class RetryTimeout(Exception):
3118
  """Retry loop timed out.
3119

3120
  Any arguments which were passed by the retried function to RetryAgain will be
3121
  preserved in RetryTimeout, if it is raised. If such an argument was an exception,
3122
  the RaiseInner helper method will reraise it.
3123

3124
  """
3125
  def RaiseInner(self):
3126
    if self.args and isinstance(self.args[0], Exception):
3127
      raise self.args[0]
3128
    else:
3129
      raise RetryTimeout(*self.args)
3130

    
3131

    
3132
class RetryAgain(Exception):
3133
  """Retry again.
3134

3135
  Any arguments passed to RetryAgain will be preserved, if a timeout occurs, as
3136
  arguments to RetryTimeout. If an exception is passed, the RaiseInner() method
3137
  of the resulting RetryTimeout exception can be used to reraise it.
3138

3139
  """
3140

    
3141

    
3142
class _RetryDelayCalculator(object):
3143
  """Calculator for increasing delays.
3144

3145
  """
3146
  __slots__ = [
3147
    "_factor",
3148
    "_limit",
3149
    "_next",
3150
    "_start",
3151
    ]
3152

    
3153
  def __init__(self, start, factor, limit):
3154
    """Initializes this class.
3155

3156
    @type start: float
3157
    @param start: Initial delay
3158
    @type factor: float
3159
    @param factor: Factor for delay increase
3160
    @type limit: float or None
3161
    @param limit: Upper limit for delay or None for no limit
3162

3163
    """
3164
    assert start > 0.0
3165
    assert factor >= 1.0
3166
    assert limit is None or limit >= 0.0
3167

    
3168
    self._start = start
3169
    self._factor = factor
3170
    self._limit = limit
3171

    
3172
    self._next = start
3173

    
3174
  def __call__(self):
3175
    """Returns current delay and calculates the next one.
3176

3177
    """
3178
    current = self._next
3179

    
3180
    # Update for next run
3181
    if self._limit is None or self._next < self._limit:
3182
      self._next = min(self._limit, self._next * self._factor)
3183

    
3184
    return current
3185

    
3186

    
3187
#: Special delay to specify whole remaining timeout
3188
RETRY_REMAINING_TIME = object()
3189

    
3190

    
3191
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
3192
          _time_fn=time.time):
3193
  """Call a function repeatedly until it succeeds.
3194

3195
  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
3196
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
3197
  total of C{timeout} seconds, this function throws L{RetryTimeout}.
3198

3199
  C{delay} can be one of the following:
3200
    - callable returning the delay length as a float
3201
    - Tuple of (start, factor, limit)
3202
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
3203
      useful when overriding L{wait_fn} to wait for an external event)
3204
    - A static delay as a number (int or float)
3205

3206
  @type fn: callable
3207
  @param fn: Function to be called
3208
  @param delay: Either a callable (returning the delay), a tuple of (start,
3209
                factor, limit) (see L{_RetryDelayCalculator}),
3210
                L{RETRY_REMAINING_TIME} or a number (int or float)
3211
  @type timeout: float
3212
  @param timeout: Total timeout
3213
  @type wait_fn: callable
3214
  @param wait_fn: Waiting function
3215
  @return: Return value of function
3216

3217
  """
3218
  assert callable(fn)
3219
  assert callable(wait_fn)
3220
  assert callable(_time_fn)
3221

    
3222
  if args is None:
3223
    args = []
3224

    
3225
  end_time = _time_fn() + timeout
3226

    
3227
  if callable(delay):
3228
    # External function to calculate delay
3229
    calc_delay = delay
3230

    
3231
  elif isinstance(delay, (tuple, list)):
3232
    # Increasing delay with optional upper boundary
3233
    (start, factor, limit) = delay
3234
    calc_delay = _RetryDelayCalculator(start, factor, limit)
3235

    
3236
  elif delay is RETRY_REMAINING_TIME:
3237
    # Always use the remaining time
3238
    calc_delay = None
3239

    
3240
  else:
3241
    # Static delay
3242
    calc_delay = lambda: delay
3243

    
3244
  assert calc_delay is None or callable(calc_delay)
3245

    
3246
  while True:
3247
    retry_args = []
3248
    try:
3249
      # pylint: disable-msg=W0142
3250
      return fn(*args)
3251
    except RetryAgain, err:
3252
      retry_args = err.args
3253
    except RetryTimeout:
3254
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
3255
                                   " handle RetryTimeout")
3256

    
3257
    remaining_time = end_time - _time_fn()
3258

    
3259
    if remaining_time < 0.0:
3260
      # pylint: disable-msg=W0142
3261
      raise RetryTimeout(*retry_args)
3262

    
3263
    assert remaining_time >= 0.0
3264

    
3265
    if calc_delay is None:
3266
      wait_fn(remaining_time)
3267
    else:
3268
      current_delay = calc_delay()
3269
      if current_delay > 0.0:
3270
        wait_fn(current_delay)
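

# Illustrative sketch (not part of the original module): polling for an
# externally-created file (the path is purely hypothetical) with an
# increasing delay, for at most 30 seconds.
def _ExampleRetry():
  """Example usage of Retry (illustration only)."""
  def _CheckFile():
    if not os.path.exists("/tmp/some-flag-file"):
      raise RetryAgain()
    return "found"

  try:
    # Start with 0.1s between attempts, grow by factor 1.5, cap at 2 seconds
    return Retry(_CheckFile, (0.1, 1.5, 2.0), 30)
  except RetryTimeout:
    return None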
3271

    
3272

    
3273
def GetClosedTempfile(*args, **kwargs):
3274
  """Creates a temporary file and returns its path.
3275

3276
  """
3277
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
3278
  _CloseFDNoErr(fd)
3279
  return path
3280

    
3281

    
3282
def GenerateSelfSignedX509Cert(common_name, validity):
3283
  """Generates a self-signed X509 certificate.
3284

3285
  @type common_name: string
3286
  @param common_name: commonName value
3287
  @type validity: int
3288
  @param validity: Validity for certificate in seconds
3289

3290
  """
3291
  # Create private and public key
3292
  key = OpenSSL.crypto.PKey()
3293
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)
3294

    
3295
  # Create self-signed certificate
3296
  cert = OpenSSL.crypto.X509()
3297
  if common_name:
3298
    cert.get_subject().CN = common_name
3299
  cert.set_serial_number(1)
3300
  cert.gmtime_adj_notBefore(0)
3301
  cert.gmtime_adj_notAfter(validity)
3302
  cert.set_issuer(cert.get_subject())
3303
  cert.set_pubkey(key)
3304
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)
3305

    
3306
  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
3307
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
3308

    
3309
  return (key_pem, cert_pem)
3310

    
3311

    
3312
def GenerateSelfSignedSslCert(filename, validity=(5 * 365)):
3313
  """Legacy function to generate self-signed X509 certificate.
3314

3315
  """
3316
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(None,
3317
                                                   validity * 24 * 60 * 60)
3318

    
3319
  WriteFile(filename, mode=0400, data=key_pem + cert_pem)
3320

    
3321

    
3322
class FileLock(object):
3323
  """Utility class for file locks.
3324

3325
  """
3326
  def __init__(self, fd, filename):
3327
    """Constructor for FileLock.
3328

3329
    @type fd: file
3330
    @param fd: File object
3331
    @type filename: str
3332
    @param filename: Path of the file opened at I{fd}
3333

3334
    """
3335
    self.fd = fd
3336
    self.filename = filename
3337

    
3338
  @classmethod
3339
  def Open(cls, filename):
3340
    """Creates and opens a file to be used as a file-based lock.
3341

3342
    @type filename: string
3343
    @param filename: path to the file to be locked
3344

3345
    """
3346
    # Using "os.open" is necessary to allow both opening existing file
3347
    # read/write and creating if not existing. Vanilla "open" will truncate an
3348
    # existing file -or- allow creating if not existing.
3349
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
3350
               filename)
3351

    
3352
  def __del__(self):
3353
    self.Close()
3354

    
3355
  def Close(self):
3356
    """Close the file and release the lock.
3357

3358
    """
3359
    if hasattr(self, "fd") and self.fd:
3360
      self.fd.close()
3361
      self.fd = None
3362

    
3363
  def _flock(self, flag, blocking, timeout, errmsg):
3364
    """Wrapper for fcntl.flock.
3365

3366
    @type flag: int
3367
    @param flag: operation flag
3368
    @type blocking: bool
3369
    @param blocking: whether the operation should be done in blocking mode.
3370
    @type timeout: None or float
3371
    @param timeout: for how long the operation should be retried (implies
3372
                    non-blocking mode).
3373
    @type errmsg: string
3374
    @param errmsg: error message in case operation fails.
3375

3376
    """
3377
    assert self.fd, "Lock was closed"
3378
    assert timeout is None or timeout >= 0, \
3379
      "If specified, timeout must be positive"
3380
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)


class LineSplitter:
  """Splits data chunks into lines separated by newline.

  Instances provide a file-like interface.

  """
  def __init__(self, line_fn, *args):
    """Initializes this class.

    @type line_fn: callable
    @param line_fn: Function called for each line, first parameter is line
    @param args: Extra arguments for L{line_fn}

    """
    assert callable(line_fn)

    if args:
      # Python 2.4 doesn't have functools.partial yet
      self._line_fn = \
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
    else:
      self._line_fn = line_fn

    self._lines = collections.deque()
    self._buffer = ""

  def write(self, data):
    parts = (self._buffer + data).split("\n")
    self._buffer = parts.pop()
    self._lines.extend(parts)

  def flush(self):
    while self._lines:
      self._line_fn(self._lines.popleft().rstrip("\r\n"))

  def close(self):
    self.flush()
    if self._buffer:
      self._line_fn(self._buffer)
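
# Illustrative sketch only: LineSplitter buffers written chunks and, on
# flush()/close(), calls the supplied function once per complete line;
# close() also delivers a trailing partial line. The list used as a sink
# here is just for the example.
#
#   received = []
#   splitter = LineSplitter(received.append)
#   splitter.write("first\nsec")
#   splitter.write("ond\nlast")
#   splitter.close()
#   # received == ["first", "second", "last"]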


def SignalHandled(signums):
  """Signal handling decorator.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      assert 'signal_handlers' not in kwargs or \
             kwargs['signal_handlers'] is None or \
             isinstance(kwargs['signal_handlers'], dict), \
             "Wrong signal_handlers parameter in original function call"
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
        signal_handlers = kwargs['signal_handlers']
      else:
        signal_handlers = {}
        kwargs['signal_handlers'] = signal_handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          signal_handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        sighandler.Reset()
    return sig_function
  return wrap
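
# Illustrative sketch only: decorating a function so SIGTERM is intercepted
# for its duration. The decorated function receives the installed handlers
# via its 'signal_handlers' keyword argument and can poll the 'called' flag.
# The function name and loop body are made up for the example.
#
#   @SignalHandled([signal.SIGTERM])
#   def _Serve(signal_handlers=None):
#     while not signal_handlers[signal.SIGTERM].called:
#       pass  # ... main loop body ...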


class SignalWakeupFd(object):
  try:
    # This is only supported in Python 2.5 and above (some distributions
    # backported it to Python 2.4)
    _set_wakeup_fd_fn = signal.set_wakeup_fd
  except AttributeError:
    # Not supported
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
      return -1
  else:
    def _SetWakeupFd(self, fd):
      return self._set_wakeup_fd_fn(fd)

  def __init__(self):
    """Initializes this class.

    """
    (read_fd, write_fd) = os.pipe()

    # Once these succeeded, the file descriptors will be closed automatically.
    # Buffer size 0 is important, otherwise .read() with a specified length
    # might buffer data and the file descriptors won't be marked readable.
    self._read_fh = os.fdopen(read_fd, "r", 0)
    self._write_fh = os.fdopen(write_fd, "w", 0)

    self._previous = self._SetWakeupFd(self._write_fh.fileno())

    # Utility functions
    self.fileno = self._read_fh.fileno
    self.read = self._read_fh.read

  def Reset(self):
    """Restores the previous wakeup file descriptor.

    """
    if hasattr(self, "_previous") and self._previous is not None:
      self._SetWakeupFd(self._previous)
      self._previous = None

  def Notify(self):
    """Notifies the wakeup file descriptor.

    """
    self._write_fh.write("\0")

  def __del__(self):
    """Called before object deletion.

    """
    self.Reset()
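
# Illustrative sketch only: pairing SignalWakeupFd with SignalHandler so a
# select() loop wakes up promptly when SIGCHLD arrives instead of sleeping
# for its full timeout. All names besides the two classes are example-local.
#
#   wakeup = SignalWakeupFd()
#   handler = SignalHandler([signal.SIGCHLD], wakeup=wakeup)
#   try:
#     (readable, _, _) = select.select([wakeup], [], [], 60.0)
#     if readable:
#       wakeup.read(1)  # drain the byte written by Notify()
#   finally:
#     handler.Reset()
#     wakeup.Reset()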


class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when destructed or
  when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: list
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None, wakeup=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function
    @type wakeup: None or L{SignalWakeupFd}
    @param wakeup: Wakeup file descriptor to notify when a signal arrives

    """
    assert handler_fn is None or callable(handler_fn)

    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn
    self._wakeup = wakeup

    self._previous = {}
    try:
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._wakeup:
      # Notify whoever is interested in signals
      self._wakeup.Notify()

    if self._handler_fn:
      self._handler_fn(signum, frame)
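
# Illustrative sketch only: the simplest pattern is to install a handler,
# run until the 'called' flag flips, and restore the previous handler. The
# loop body here is a stand-in for real work.
#
#   handler = SignalHandler([signal.SIGINT, signal.SIGTERM])
#   try:
#     while not handler.called:
#       time.sleep(1)
#   finally:
#     handler.Reset()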


class FieldSet(object):
  """A simple field set.

  Among the features are:
    - checking if a string is among a list of static strings or regex objects
    - checking if a whole list of strings matches
    - returning the matching groups from a regex match

  Internally, all fields are held as regular expression objects.

  """
  def __init__(self, *items):
    self.items = [re.compile("^%s$" % value) for value in items]

  def Extend(self, other_set):
    """Extend the field set with the items from another one"""
    self.items.extend(other_set.items)

  def Matches(self, field):
    """Checks if a field matches the current set

    @type field: str
    @param field: the string to match
    @return: either None or a regular expression match object

    """
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
      return m
    return None

  def NonMatching(self, items):
    """Returns the list of fields not matching the current set

    @type items: list
    @param items: the list of fields to check
    @rtype: list
    @return: list of non-matching fields

    """
    return [val for val in items if not self.Matches(val)]
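
# Illustrative sketch only: a field set mixing literal names and regex
# patterns. The field names are made up for the example.
#
#   fields = FieldSet("name", "uuid", r"nic\.count")
#   fields.Extend(FieldSet(r"disk\.size/([0-9]+)"))
#   m = fields.Matches("disk.size/2")      # match object, m.group(1) == "2"
#   fields.NonMatching(["name", "bogus"])  # -> ["bogus"]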