
root / lib / utils.py @ 72087dcd


1
#
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import os
31
import sys
32
import time
33
import subprocess
34
import re
35
import socket
36
import tempfile
37
import shutil
38
import errno
39
import pwd
40
import itertools
41
import select
42
import fcntl
43
import resource
44
import logging
45
import logging.handlers
46
import signal
47
import OpenSSL
48
import datetime
49
import calendar
50
import hmac
51
import collections
52
import struct
53
import IN
54

    
55
from cStringIO import StringIO
56

    
57
try:
58
  from hashlib import sha1
59
except ImportError:
60
  import sha as sha1
61

    
62
try:
63
  import ctypes
64
except ImportError:
65
  ctypes = None
66

    
67
from ganeti import errors
68
from ganeti import constants
69

    
70

    
71
_locksheld = []
72
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
73

    
74
debug_locks = False
75

    
76
#: when set to True, L{RunCmd} is disabled
77
no_fork = False
78

    
79
_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
80

    
81
HEX_CHAR_RE = r"[a-zA-Z0-9]"
82
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
83
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
84
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
85
                             HEX_CHAR_RE, HEX_CHAR_RE),
86
                            re.S | re.I)
87

    
88
# Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
89
# struct ucred { pid_t pid; uid_t uid; gid_t gid; };
90
#
91
# The GNU C Library defines gid_t and uid_t to be "unsigned int" and
92
# pid_t to "int".
93
#
94
# IEEE Std 1003.1-2008:
95
# "nlink_t, uid_t, gid_t, and id_t shall be integer types"
96
# "blksize_t, pid_t, and ssize_t shall be signed integer types"
97
_STRUCT_UCRED = "iII"
98
_STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)
99

    
100
# Certificate verification results
101
(CERT_WARNING,
102
 CERT_ERROR) = range(1, 3)
103

    
104
# Flags for mlockall() (from bits/mman.h)
105
_MCL_CURRENT = 1
106
_MCL_FUTURE = 2
107

    
108

    
109
class RunResult(object):
110
  """Holds the result of running external programs.
111

112
  @type exit_code: int
113
  @ivar exit_code: the exit code of the program, or None (if the program
114
      didn't exit())
115
  @type signal: int or None
116
  @ivar signal: the signal that caused the program to finish, or None
117
      (if the program wasn't terminated by a signal)
118
  @type stdout: str
119
  @ivar stdout: the standard output of the program
120
  @type stderr: str
121
  @ivar stderr: the standard error of the program
122
  @type failed: boolean
123
  @ivar failed: True in case the program was
124
      terminated by a signal or exited with a non-zero exit code
125
  @ivar fail_reason: a string detailing the termination reason
126

127
  """
128
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
129
               "failed", "fail_reason", "cmd"]
130

    
131

    
132
  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
133
    self.cmd = cmd
134
    self.exit_code = exit_code
135
    self.signal = signal_
136
    self.stdout = stdout
137
    self.stderr = stderr
138
    self.failed = (signal_ is not None or exit_code != 0)
139

    
140
    if self.signal is not None:
141
      self.fail_reason = "terminated by signal %s" % self.signal
142
    elif self.exit_code is not None:
143
      self.fail_reason = "exited with exit code %s" % self.exit_code
144
    else:
145
      self.fail_reason = "unable to determine termination reason"
146

    
147
    if self.failed:
148
      logging.debug("Command '%s' failed (%s); output: %s",
149
                    self.cmd, self.fail_reason, self.output)
150

    
151
  def _GetOutput(self):
152
    """Returns the combined stdout and stderr for easier usage.
153

154
    """
155
    return self.stdout + self.stderr
156

    
157
  output = property(_GetOutput, None, None, "Return full output")
158

    
159

    
160
def _BuildCmdEnvironment(env, reset):
161
  """Builds the environment for an external program.
162

163
  """
164
  if reset:
165
    cmd_env = {}
166
  else:
167
    cmd_env = os.environ.copy()
168
    cmd_env["LC_ALL"] = "C"
169

    
170
  if env is not None:
171
    cmd_env.update(env)
172

    
173
  return cmd_env
174

    
175

    
176
def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
177
  """Execute a (shell) command.
178

179
  The command should not read from its standard input, as it will be
180
  closed.
181

182
  @type cmd: string or list
183
  @param cmd: Command to run
184
  @type env: dict
185
  @param env: Additional environment variables
186
  @type output: str
187
  @param output: if desired, the output of the command can be
188
      saved in a file instead of the RunResult instance; this
189
      parameter denotes the file name (if not None)
190
  @type cwd: string
191
  @param cwd: if specified, will be used as the working
192
      directory for the command; the default will be /
193
  @type reset_env: boolean
194
  @param reset_env: whether to reset or keep the default os environment
195
  @rtype: L{RunResult}
196
  @return: RunResult instance
197
  @raise errors.ProgrammerError: if we call this when forks are disabled
198

199
  """
200
  if no_fork:
201
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")
202

    
203
  if isinstance(cmd, basestring):
204
    strcmd = cmd
205
    shell = True
206
  else:
207
    cmd = [str(val) for val in cmd]
208
    strcmd = ShellQuoteArgs(cmd)
209
    shell = False
210

    
211
  if output:
212
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
213
  else:
214
    logging.debug("RunCmd %s", strcmd)
215

    
216
  cmd_env = _BuildCmdEnvironment(env, reset_env)
217

    
218
  try:
219
    if output is None:
220
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
221
    else:
222
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
223
      out = err = ""
224
  except OSError, err:
225
    if err.errno == errno.ENOENT:
226
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
227
                               (strcmd, err))
228
    else:
229
      raise
230

    
231
  if status >= 0:
232
    exitcode = status
233
    signal_ = None
234
  else:
235
    exitcode = None
236
    signal_ = -status
237

    
238
  return RunResult(exitcode, signal_, out, err, strcmd)
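

# Illustrative usage sketch (not part of the original module): running a
# command through RunCmd and inspecting the resulting RunResult.  The
# command and messages below are assumptions for demonstration only.
def _ExampleRunCmd():
  result = RunCmd(["/bin/ls", "-l", "/tmp"])
  if result.failed:
    # fail_reason says whether the command exited non-zero or died on a signal
    logging.error("ls failed (%s): %s", result.fail_reason, result.output)
    return None
  return result.stdout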
239

    
240

    
241
def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
242
                pidfile=None):
243
  """Start a daemon process after forking twice.
244

245
  @type cmd: string or list
246
  @param cmd: Command to run
247
  @type env: dict
248
  @param env: Additional environment variables
249
  @type cwd: string
250
  @param cwd: Working directory for the program
251
  @type output: string
252
  @param output: Path to file in which to save the output
253
  @type output_fd: int
254
  @param output_fd: File descriptor for output
255
  @type pidfile: string
256
  @param pidfile: Process ID file
257
  @rtype: int
258
  @return: Daemon process ID
259
  @raise errors.ProgrammerError: if we call this when forks are disabled
260

261
  """
262
  if no_fork:
263
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
264
                                 " disabled")
265

    
266
  if output and not (bool(output) ^ (output_fd is not None)):
267
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
268
                                 " specified")
269

    
270
  if isinstance(cmd, basestring):
271
    cmd = ["/bin/sh", "-c", cmd]
272

    
273
  strcmd = ShellQuoteArgs(cmd)
274

    
275
  if output:
276
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
277
  else:
278
    logging.debug("StartDaemon %s", strcmd)
279

    
280
  cmd_env = _BuildCmdEnvironment(env, False)
281

    
282
  # Create pipe for sending PID back
283
  (pidpipe_read, pidpipe_write) = os.pipe()
284
  try:
285
    try:
286
      # Create pipe for sending error messages
287
      (errpipe_read, errpipe_write) = os.pipe()
288
      try:
289
        try:
290
          # First fork
291
          pid = os.fork()
292
          if pid == 0:
293
            try:
294
              # Child process, won't return
295
              _StartDaemonChild(errpipe_read, errpipe_write,
296
                                pidpipe_read, pidpipe_write,
297
                                cmd, cmd_env, cwd,
298
                                output, output_fd, pidfile)
299
            finally:
300
              # Well, maybe child process failed
301
              os._exit(1) # pylint: disable-msg=W0212
302
        finally:
303
          _CloseFDNoErr(errpipe_write)
304

    
305
        # Wait for daemon to be started (or an error message to arrive) and read
306
        # up to 100 KB as an error message
307
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
308
      finally:
309
        _CloseFDNoErr(errpipe_read)
310
    finally:
311
      _CloseFDNoErr(pidpipe_write)
312

    
313
    # Read up to 128 bytes for PID
314
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
315
  finally:
316
    _CloseFDNoErr(pidpipe_read)
317

    
318
  # Try to avoid zombies by waiting for child process
319
  try:
320
    os.waitpid(pid, 0)
321
  except OSError:
322
    pass
323

    
324
  if errormsg:
325
    raise errors.OpExecError("Error when starting daemon process: %r" %
326
                             errormsg)
327

    
328
  try:
329
    return int(pidtext)
330
  except (ValueError, TypeError), err:
331
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
332
                             (pidtext, err))
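

# Illustrative usage sketch (not part of the original module): starting a
# background process with its output and PID file recorded.  The paths used
# here are assumptions for demonstration only.
def _ExampleStartDaemon():
  pid = StartDaemon(["/usr/sbin/mydaemon", "--foreground"],
                    output="/var/log/mydaemon.log",
                    pidfile="/var/run/mydaemon.pid")
  logging.info("Daemon started with PID %d", pid)
  return pid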
333

    
334

    
335
def _StartDaemonChild(errpipe_read, errpipe_write,
336
                      pidpipe_read, pidpipe_write,
337
                      args, env, cwd,
338
                      output, fd_output, pidfile):
339
  """Child process for starting daemon.
340

341
  """
342
  try:
343
    # Close parent's side
344
    _CloseFDNoErr(errpipe_read)
345
    _CloseFDNoErr(pidpipe_read)
346

    
347
    # First child process
348
    os.chdir("/")
349
    os.umask(077)
350
    os.setsid()
351

    
352
    # And fork for the second time
353
    pid = os.fork()
354
    if pid != 0:
355
      # Exit first child process
356
      os._exit(0) # pylint: disable-msg=W0212
357

    
358
    # Make sure pipe is closed on execv* (and thereby notifies original process)
359
    SetCloseOnExecFlag(errpipe_write, True)
360

    
361
    # List of file descriptors to be left open
362
    noclose_fds = [errpipe_write]
363

    
364
    # Open PID file
365
    if pidfile:
366
      try:
367
        # TODO: Atomic replace with another locked file instead of writing into
368
        # it after creating
369
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)
370

    
371
        # Lock the PID file (and fail if not possible to do so). Any code
372
        # wanting to send a signal to the daemon should try to lock the PID
373
        # file before reading it. If acquiring the lock succeeds, the daemon is
374
        # no longer running and the signal should not be sent.
375
        LockFile(fd_pidfile)
376

    
377
        os.write(fd_pidfile, "%d\n" % os.getpid())
378
      except Exception, err:
379
        raise Exception("Creating and locking PID file failed: %s" % err)
380

    
381
      # Keeping the file open to hold the lock
382
      noclose_fds.append(fd_pidfile)
383

    
384
      SetCloseOnExecFlag(fd_pidfile, False)
385
    else:
386
      fd_pidfile = None
387

    
388
    # Open /dev/null
389
    fd_devnull = os.open(os.devnull, os.O_RDWR)
390

    
391
    assert not output or (bool(output) ^ (fd_output is not None))
392

    
393
    if fd_output is not None:
394
      pass
395
    elif output:
396
      # Open output file
397
      try:
398
        # TODO: Implement flag to set append=yes/no
399
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
400
      except EnvironmentError, err:
401
        raise Exception("Opening output file failed: %s" % err)
402
    else:
403
      fd_output = fd_devnull
404

    
405
    # Redirect standard I/O
406
    os.dup2(fd_devnull, 0)
407
    os.dup2(fd_output, 1)
408
    os.dup2(fd_output, 2)
409

    
410
    # Send daemon PID to parent
411
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))
412

    
413
    # Close all file descriptors except stdio and error message pipe
414
    CloseFDs(noclose_fds=noclose_fds)
415

    
416
    # Change working directory
417
    os.chdir(cwd)
418

    
419
    if env is None:
420
      os.execvp(args[0], args)
421
    else:
422
      os.execvpe(args[0], args, env)
423
  except: # pylint: disable-msg=W0702
424
    try:
425
      # Report errors to original process
426
      buf = str(sys.exc_info()[1])
427

    
428
      RetryOnSignal(os.write, errpipe_write, buf)
429
    except: # pylint: disable-msg=W0702
430
      # Ignore errors in error handling
431
      pass
432

    
433
  os._exit(1) # pylint: disable-msg=W0212
434

    
435

    
436
def _RunCmdPipe(cmd, env, via_shell, cwd):
437
  """Run a command and return its output.
438

439
  @type  cmd: string or list
440
  @param cmd: Command to run
441
  @type env: dict
442
  @param env: The environment to use
443
  @type via_shell: bool
444
  @param via_shell: if we should run via the shell
445
  @type cwd: string
446
  @param cwd: the working directory for the program
447
  @rtype: tuple
448
  @return: (out, err, status)
449

450
  """
451
  poller = select.poll()
452
  child = subprocess.Popen(cmd, shell=via_shell,
453
                           stderr=subprocess.PIPE,
454
                           stdout=subprocess.PIPE,
455
                           stdin=subprocess.PIPE,
456
                           close_fds=True, env=env,
457
                           cwd=cwd)
458

    
459
  child.stdin.close()
460
  poller.register(child.stdout, select.POLLIN)
461
  poller.register(child.stderr, select.POLLIN)
462
  out = StringIO()
463
  err = StringIO()
464
  fdmap = {
465
    child.stdout.fileno(): (out, child.stdout),
466
    child.stderr.fileno(): (err, child.stderr),
467
    }
468
  for fd in fdmap:
469
    SetNonblockFlag(fd, True)
470

    
471
  while fdmap:
472
    pollresult = RetryOnSignal(poller.poll)
473

    
474
    for fd, event in pollresult:
475
      if event & select.POLLIN or event & select.POLLPRI:
476
        data = fdmap[fd][1].read()
477
        # no data from read signifies EOF (the same as POLLHUP)
478
        if not data:
479
          poller.unregister(fd)
480
          del fdmap[fd]
481
          continue
482
        fdmap[fd][0].write(data)
483
      if (event & select.POLLNVAL or event & select.POLLHUP or
484
          event & select.POLLERR):
485
        poller.unregister(fd)
486
        del fdmap[fd]
487

    
488
  out = out.getvalue()
489
  err = err.getvalue()
490

    
491
  status = child.wait()
492
  return out, err, status
493

    
494

    
495
def _RunCmdFile(cmd, env, via_shell, output, cwd):
496
  """Run a command and save its output to a file.
497

498
  @type  cmd: string or list
499
  @param cmd: Command to run
500
  @type env: dict
501
  @param env: The environment to use
502
  @type via_shell: bool
503
  @param via_shell: if we should run via the shell
504
  @type output: str
505
  @param output: the filename in which to save the output
506
  @type cwd: string
507
  @param cwd: the working directory for the program
508
  @rtype: int
509
  @return: the exit status
510

511
  """
512
  fh = open(output, "a")
513
  try:
514
    child = subprocess.Popen(cmd, shell=via_shell,
515
                             stderr=subprocess.STDOUT,
516
                             stdout=fh,
517
                             stdin=subprocess.PIPE,
518
                             close_fds=True, env=env,
519
                             cwd=cwd)
520

    
521
    child.stdin.close()
522
    status = child.wait()
523
  finally:
524
    fh.close()
525
  return status
526

    
527

    
528
def SetCloseOnExecFlag(fd, enable):
529
  """Sets or unsets the close-on-exec flag on a file descriptor.
530

531
  @type fd: int
532
  @param fd: File descriptor
533
  @type enable: bool
534
  @param enable: Whether to set or unset it.
535

536
  """
537
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)
538

    
539
  if enable:
540
    flags |= fcntl.FD_CLOEXEC
541
  else:
542
    flags &= ~fcntl.FD_CLOEXEC
543

    
544
  fcntl.fcntl(fd, fcntl.F_SETFD, flags)
545

    
546

    
547
def SetNonblockFlag(fd, enable):
548
  """Sets or unsets the O_NONBLOCK flag on on a file descriptor.
549

550
  @type fd: int
551
  @param fd: File descriptor
552
  @type enable: bool
553
  @param enable: Whether to set or unset it
554

555
  """
556
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)
557

    
558
  if enable:
559
    flags |= os.O_NONBLOCK
560
  else:
561
    flags &= ~os.O_NONBLOCK
562

    
563
  fcntl.fcntl(fd, fcntl.F_SETFL, flags)
564

    
565

    
566
def RetryOnSignal(fn, *args, **kwargs):
567
  """Calls a function again if it failed due to EINTR.
568

569
  """
570
  while True:
571
    try:
572
      return fn(*args, **kwargs)
573
    except (EnvironmentError, socket.error), err:
574
      if err.errno != errno.EINTR:
575
        raise
576
    except select.error, err:
577
      if not (err.args and err.args[0] == errno.EINTR):
578
        raise
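

# Illustrative usage sketch (not part of the original module): wrapping a
# system call so that an interrupting signal (EINTR) leads to a transparent
# retry instead of an exception reaching the caller.
def _ExampleRetryOnSignal(fd):
  # os.read is simply re-issued if it gets interrupted by a signal
  return RetryOnSignal(os.read, fd, 4096)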
579

    
580

    
581
def RunParts(dir_name, env=None, reset_env=False):
582
  """Run Scripts or programs in a directory
583

584
  @type dir_name: string
585
  @param dir_name: absolute path to a directory
586
  @type env: dict
587
  @param env: The environment to use
588
  @type reset_env: boolean
589
  @param reset_env: whether to reset or keep the default os environment
590
  @rtype: list of tuples
591
  @return: list of (name, (one of the RUNPARTS_* status values), RunResult)
592

593
  """
594
  rr = []
595

    
596
  try:
597
    dir_contents = ListVisibleFiles(dir_name)
598
  except OSError, err:
599
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
600
    return rr
601

    
602
  for relname in sorted(dir_contents):
603
    fname = PathJoin(dir_name, relname)
604
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
605
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
606
      rr.append((relname, constants.RUNPARTS_SKIP, None))
607
    else:
608
      try:
609
        result = RunCmd([fname], env=env, reset_env=reset_env)
610
      except Exception, err: # pylint: disable-msg=W0703
611
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
612
      else:
613
        rr.append((relname, constants.RUNPARTS_RUN, result))
614

    
615
  return rr
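

# Illustrative usage sketch (not part of the original module): running all
# plugins in a directory and logging what happened to each one.  The
# directory name is an assumption for demonstration only.
def _ExampleRunParts():
  for (relname, status, result) in RunParts("/etc/ganeti/hooks.d"):
    if status == constants.RUNPARTS_SKIP:
      logging.info("Skipped %s (not an executable plugin)", relname)
    elif status == constants.RUNPARTS_ERR:
      logging.error("Could not run %s: %s", relname, result)
    elif result.failed:
      logging.error("%s failed (%s)", relname, result.fail_reason)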
616

    
617

    
618
def GetSocketCredentials(sock):
619
  """Returns the credentials of the foreign process connected to a socket.
620

621
  @param sock: Unix socket
622
  @rtype: tuple; (number, number, number)
623
  @return: The PID, UID and GID of the connected foreign process.
624

625
  """
626
  peercred = sock.getsockopt(socket.SOL_SOCKET, IN.SO_PEERCRED,
627
                             _STRUCT_UCRED_SIZE)
628
  return struct.unpack(_STRUCT_UCRED, peercred)
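

# Illustrative usage sketch (not part of the original module): checking who
# connected to a listening Unix domain socket.  The surrounding socket setup
# is assumed to exist and is not shown here.
def _ExampleGetSocketCredentials(listening_sock):
  (conn, _) = listening_sock.accept()
  (pid, uid, gid) = GetSocketCredentials(conn)
  logging.info("Peer process: pid=%d uid=%d gid=%d", pid, uid, gid)
  return conn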
629

    
630

    
631
def RemoveFile(filename):
632
  """Remove a file ignoring some errors.
633

634
  Remove a file, ignoring non-existing ones or directories. Other
635
  errors are propagated to the caller.
636

637
  @type filename: str
638
  @param filename: the file to be removed
639

640
  """
641
  try:
642
    os.unlink(filename)
643
  except OSError, err:
644
    if err.errno not in (errno.ENOENT, errno.EISDIR):
645
      raise
646

    
647

    
648
def RemoveDir(dirname):
649
  """Remove an empty directory.
650

651
  Remove a directory, ignoring non-existing ones.
652
  Other errors are propagated to the caller. This includes the case
653
  where the directory is not empty, so it can't be removed.
654

655
  @type dirname: str
656
  @param dirname: the empty directory to be removed
657

658
  """
659
  try:
660
    os.rmdir(dirname)
661
  except OSError, err:
662
    if err.errno != errno.ENOENT:
663
      raise
664

    
665

    
666
def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
667
  """Renames a file.
668

669
  @type old: string
670
  @param old: Original path
671
  @type new: string
672
  @param new: New path
673
  @type mkdir: bool
674
  @param mkdir: Whether to create target directory if it doesn't exist
675
  @type mkdir_mode: int
676
  @param mkdir_mode: Mode for newly created directories
677

678
  """
679
  try:
680
    return os.rename(old, new)
681
  except OSError, err:
682
    # In at least one use case of this function, the job queue, directory
683
    # creation is very rare. Checking for the directory before renaming is not
684
    # as efficient.
685
    if mkdir and err.errno == errno.ENOENT:
686
      # Create directory and try again
687
      Makedirs(os.path.dirname(new), mode=mkdir_mode)
688

    
689
      return os.rename(old, new)
690

    
691
    raise
692

    
693

    
694
def Makedirs(path, mode=0750):
695
  """Super-mkdir; create a leaf directory and all intermediate ones.
696

697
  This is a wrapper around C{os.makedirs} adding error handling not implemented
698
  before Python 2.5.
699

700
  """
701
  try:
702
    os.makedirs(path, mode)
703
  except OSError, err:
704
    # Ignore EEXIST. This is only handled in os.makedirs as included in
705
    # Python 2.5 and above.
706
    if err.errno != errno.EEXIST or not os.path.exists(path):
707
      raise
708

    
709

    
710
def ResetTempfileModule():
711
  """Resets the random name generator of the tempfile module.
712

713
  This function should be called after C{os.fork} in the child process to
714
  ensure it creates a newly seeded random generator. Otherwise it would
715
  generate the same random parts as the parent process. If several processes
716
  race for the creation of a temporary file, this could lead to one not getting
717
  a temporary name.
718

719
  """
720
  # pylint: disable-msg=W0212
721
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
722
    tempfile._once_lock.acquire()
723
    try:
724
      # Reset random name generator
725
      tempfile._name_sequence = None
726
    finally:
727
      tempfile._once_lock.release()
728
  else:
729
    logging.critical("The tempfile module misses at least one of the"
730
                     " '_once_lock' and '_name_sequence' attributes")
731

    
732

    
733
def _FingerprintFile(filename):
734
  """Compute the fingerprint of a file.
735

736
  If the file does not exist, None will be returned
737
  instead.
738

739
  @type filename: str
740
  @param filename: the filename to checksum
741
  @rtype: str
742
  @return: the hex digest of the sha checksum of the contents
743
      of the file
744

745
  """
746
  if not (os.path.exists(filename) and os.path.isfile(filename)):
747
    return None
748

    
749
  f = open(filename)
750

    
751
  if callable(sha1):
752
    fp = sha1()
753
  else:
754
    fp = sha1.new()
755
  while True:
756
    data = f.read(4096)
757
    if not data:
758
      break
759

    
760
    fp.update(data)
761

    
762
  return fp.hexdigest()
763

    
764

    
765
def FingerprintFiles(files):
766
  """Compute fingerprints for a list of files.
767

768
  @type files: list
769
  @param files: the list of filename to fingerprint
770
  @rtype: dict
771
  @return: a dictionary filename: fingerprint, holding only
772
      existing files
773

774
  """
775
  ret = {}
776

    
777
  for filename in files:
778
    cksum = _FingerprintFile(filename)
779
    if cksum:
780
      ret[filename] = cksum
781

    
782
  return ret
783

    
784

    
785
def ForceDictType(target, key_types, allowed_values=None):
786
  """Force the values of a dict to have certain types.
787

788
  @type target: dict
789
  @param target: the dict to update
790
  @type key_types: dict
791
  @param key_types: dict mapping target dict keys to types
792
                    in constants.ENFORCEABLE_TYPES
793
  @type allowed_values: list
794
  @keyword allowed_values: list of specially allowed values
795

796
  """
797
  if allowed_values is None:
798
    allowed_values = []
799

    
800
  if not isinstance(target, dict):
801
    msg = "Expected dictionary, got '%s'" % target
802
    raise errors.TypeEnforcementError(msg)
803

    
804
  for key in target:
805
    if key not in key_types:
806
      msg = "Unknown key '%s'" % key
807
      raise errors.TypeEnforcementError(msg)
808

    
809
    if target[key] in allowed_values:
810
      continue
811

    
812
    ktype = key_types[key]
813
    if ktype not in constants.ENFORCEABLE_TYPES:
814
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
815
      raise errors.ProgrammerError(msg)
816

    
817
    if ktype == constants.VTYPE_STRING:
818
      if not isinstance(target[key], basestring):
819
        if isinstance(target[key], bool) and not target[key]:
820
          target[key] = ''
821
        else:
822
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
823
          raise errors.TypeEnforcementError(msg)
824
    elif ktype == constants.VTYPE_BOOL:
825
      if isinstance(target[key], basestring) and target[key]:
826
        if target[key].lower() == constants.VALUE_FALSE:
827
          target[key] = False
828
        elif target[key].lower() == constants.VALUE_TRUE:
829
          target[key] = True
830
        else:
831
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
832
          raise errors.TypeEnforcementError(msg)
833
      elif target[key]:
834
        target[key] = True
835
      else:
836
        target[key] = False
837
    elif ktype == constants.VTYPE_SIZE:
838
      try:
839
        target[key] = ParseUnit(target[key])
840
      except errors.UnitParseError, err:
841
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
842
              (key, target[key], err)
843
        raise errors.TypeEnforcementError(msg)
844
    elif ktype == constants.VTYPE_INT:
845
      try:
846
        target[key] = int(target[key])
847
      except (ValueError, TypeError):
848
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
849
        raise errors.TypeEnforcementError(msg)
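

# Illustrative usage sketch (not part of the original module): normalising a
# parameter dictionary in place.  The key names and type map below are
# assumptions for demonstration only.
def _ExampleForceDictType():
  params = {"memory": "1024", "auto_balance": "false"}
  key_types = {
    "memory": constants.VTYPE_INT,
    "auto_balance": constants.VTYPE_BOOL,
    }
  ForceDictType(params, key_types)
  # after the call the values have been coerced to int and bool respectively
  return params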
850

    
851

    
852
def IsProcessAlive(pid):
853
  """Check if a given pid exists on the system.
854

855
  @note: zombie status is not handled, so zombie processes
856
      will be returned as alive
857
  @type pid: int
858
  @param pid: the process ID to check
859
  @rtype: boolean
860
  @return: True if the process exists
861

862
  """
863
  def _TryStat(name):
864
    try:
865
      os.stat(name)
866
      return True
867
    except EnvironmentError, err:
868
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
869
        return False
870
      elif err.errno == errno.EINVAL:
871
        raise RetryAgain(err)
872
      raise
873

    
874
  assert isinstance(pid, int), "pid must be an integer"
875
  if pid <= 0:
876
    return False
877

    
878
  proc_entry = "/proc/%d/status" % pid
879
  # /proc in a multiprocessor environment can have strange behaviors.
880
  # Retry the os.stat a few times until we get a good result.
881
  try:
882
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5, args=[proc_entry])
883
  except RetryTimeout, err:
884
    err.RaiseInner()
885

    
886

    
887
def ReadPidFile(pidfile):
888
  """Read a pid from a file.
889

890
  @type  pidfile: string
891
  @param pidfile: path to the file containing the pid
892
  @rtype: int
893
  @return: The process id, if the file exists and contains a valid PID,
894
           otherwise 0
895

896
  """
897
  try:
898
    raw_data = ReadOneLineFile(pidfile)
899
  except EnvironmentError, err:
900
    if err.errno != errno.ENOENT:
901
      logging.exception("Can't read pid file")
902
    return 0
903

    
904
  try:
905
    pid = int(raw_data)
906
  except (TypeError, ValueError), err:
907
    logging.info("Can't parse pid file contents", exc_info=True)
908
    return 0
909

    
910
  return pid
911

    
912

    
913
def ReadLockedPidFile(path):
914
  """Reads a locked PID file.
915

916
  This can be used together with L{StartDaemon}.
917

918
  @type path: string
919
  @param path: Path to PID file
920
  @return: PID as integer or, if file was unlocked or couldn't be opened, None
921

922
  """
923
  try:
924
    fd = os.open(path, os.O_RDONLY)
925
  except EnvironmentError, err:
926
    if err.errno == errno.ENOENT:
927
      # PID file doesn't exist
928
      return None
929
    raise
930

    
931
  try:
932
    try:
933
      # Try to acquire lock
934
      LockFile(fd)
935
    except errors.LockError:
936
      # Couldn't lock, daemon is running
937
      return int(os.read(fd, 100))
938
  finally:
939
    os.close(fd)
940

    
941
  return None
942

    
943

    
944
def MatchNameComponent(key, name_list, case_sensitive=True):
945
  """Try to match a name against a list.
946

947
  This function will try to match a name like test1 against a list
948
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
949
  this list, I{'test1'} as well as I{'test1.example'} will match, but
950
  not I{'test1.ex'}. A multiple match will be considered as no match
951
  at all (e.g. I{'test1'} against C{['test1.example.com',
952
  'test1.example.org']}), except when the key fully matches an entry
953
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
954

955
  @type key: str
956
  @param key: the name to be searched
957
  @type name_list: list
958
  @param name_list: the list of strings against which to search the key
959
  @type case_sensitive: boolean
960
  @param case_sensitive: whether to provide a case-sensitive match
961

962
  @rtype: None or str
963
  @return: None if there is no match I{or} if there are multiple matches,
964
      otherwise the element from the list which matches
965

966
  """
967
  if key in name_list:
968
    return key
969

    
970
  re_flags = 0
971
  if not case_sensitive:
972
    re_flags |= re.IGNORECASE
973
    key = key.upper()
974
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
975
  names_filtered = []
976
  string_matches = []
977
  for name in name_list:
978
    if mo.match(name) is not None:
979
      names_filtered.append(name)
980
      if not case_sensitive and key == name.upper():
981
        string_matches.append(name)
982

    
983
  if len(string_matches) == 1:
984
    return string_matches[0]
985
  if len(names_filtered) == 1:
986
    return names_filtered[0]
987
  return None
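

# Illustrative usage sketch (not part of the original module): expanding a
# short node name against a list of fully-qualified names.  The names below
# are assumptions for demonstration only.
def _ExampleMatchNameComponent():
  names = ["node1.example.com", "node2.example.com"]
  # returns "node1.example.com"; an ambiguous or unknown key returns None
  return MatchNameComponent("node1", names)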
988

    
989

    
990
class HostInfo:
991
  """Class implementing resolver and hostname functionality
992

993
  """
994
  _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$")
995

    
996
  def __init__(self, name=None):
997
    """Initialize the host name object.
998

999
    If the name argument is not passed, it will use this system's
1000
    name.
1001

1002
    """
1003
    if name is None:
1004
      name = self.SysName()
1005

    
1006
    self.query = name
1007
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
1008
    self.ip = self.ipaddrs[0]
1009

    
1010
  def ShortName(self):
1011
    """Returns the hostname without domain.
1012

1013
    """
1014
    return self.name.split('.')[0]
1015

    
1016
  @staticmethod
1017
  def SysName():
1018
    """Return the current system's name.
1019

1020
    This is simply a wrapper over C{socket.gethostname()}.
1021

1022
    """
1023
    return socket.gethostname()
1024

    
1025
  @staticmethod
1026
  def LookupHostname(hostname):
1027
    """Look up hostname
1028

1029
    @type hostname: str
1030
    @param hostname: hostname to look up
1031

1032
    @rtype: tuple
1033
    @return: a tuple (name, aliases, ipaddrs) as returned by
1034
        C{socket.gethostbyname_ex}
1035
    @raise errors.ResolverError: in case of errors in resolving
1036

1037
    """
1038
    try:
1039
      result = socket.gethostbyname_ex(hostname)
1040
    except socket.gaierror, err:
1041
      # hostname not found in DNS
1042
      raise errors.ResolverError(hostname, err.args[0], err.args[1])
1043

    
1044
    return result
1045

    
1046
  @classmethod
1047
  def NormalizeName(cls, hostname):
1048
    """Validate and normalize the given hostname.
1049

1050
    @attention: the validation is a bit more relaxed than the standards
1051
        require; most importantly, we allow underscores in names
1052
    @raise errors.OpPrereqError: when the name is not valid
1053

1054
    """
1055
    hostname = hostname.lower()
1056
    if (not cls._VALID_NAME_RE.match(hostname) or
1057
        # double-dots, meaning empty label
1058
        ".." in hostname or
1059
        # empty initial label
1060
        hostname.startswith(".")):
1061
      raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
1062
                                 errors.ECODE_INVAL)
1063
    if hostname.endswith("."):
1064
      hostname = hostname.rstrip(".")
1065
    return hostname
1066

    
1067

    
1068
def GetHostInfo(name=None):
1069
  """Lookup host name and raise an OpPrereqError for failures"""
1070

    
1071
  try:
1072
    return HostInfo(name)
1073
  except errors.ResolverError, err:
1074
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
1075
                               (err[0], err[2]), errors.ECODE_RESOLVER)
1076

    
1077

    
1078
def ListVolumeGroups():
1079
  """List volume groups and their size
1080

1081
  @rtype: dict
1082
  @return:
1083
       Dictionary with keys volume name and values
1084
       the size of the volume
1085

1086
  """
1087
  command = "vgs --noheadings --units m --nosuffix -o name,size"
1088
  result = RunCmd(command)
1089
  retval = {}
1090
  if result.failed:
1091
    return retval
1092

    
1093
  for line in result.stdout.splitlines():
1094
    try:
1095
      name, size = line.split()
1096
      size = int(float(size))
1097
    except (IndexError, ValueError), err:
1098
      logging.error("Invalid output from vgs (%s): %s", err, line)
1099
      continue
1100

    
1101
    retval[name] = size
1102

    
1103
  return retval
1104

    
1105

    
1106
def BridgeExists(bridge):
1107
  """Check whether the given bridge exists in the system
1108

1109
  @type bridge: str
1110
  @param bridge: the bridge name to check
1111
  @rtype: boolean
1112
  @return: True if it does
1113

1114
  """
1115
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
1116

    
1117

    
1118
def NiceSort(name_list):
1119
  """Sort a list of strings based on digit and non-digit groupings.
1120

1121
  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
1122
  will sort the list in the logical order C{['a1', 'a2', 'a10',
1123
  'a11']}.
1124

1125
  The sort algorithm breaks each name in groups of either only-digits
1126
  or no-digits. Only the first eight such groups are considered, and
1127
  after that we just use what's left of the string.
1128

1129
  @type name_list: list
1130
  @param name_list: the names to be sorted
1131
  @rtype: list
1132
  @return: a copy of the name list sorted with our algorithm
1133

1134
  """
1135
  _SORTER_BASE = "(\D+|\d+)"
1136
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
1137
                                                  _SORTER_BASE, _SORTER_BASE,
1138
                                                  _SORTER_BASE, _SORTER_BASE,
1139
                                                  _SORTER_BASE, _SORTER_BASE)
1140
  _SORTER_RE = re.compile(_SORTER_FULL)
1141
  _SORTER_NODIGIT = re.compile("^\D*$")
1142
  def _TryInt(val):
1143
    """Attempts to convert a variable to integer."""
1144
    if val is None or _SORTER_NODIGIT.match(val):
1145
      return val
1146
    rval = int(val)
1147
    return rval
1148

    
1149
  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
1150
             for name in name_list]
1151
  to_sort.sort()
1152
  return [tup[1] for tup in to_sort]
1153

    
1154

    
1155
def TryConvert(fn, val):
1156
  """Try to convert a value ignoring errors.
1157

1158
  This function tries to apply function I{fn} to I{val}. If no
1159
  C{ValueError} or C{TypeError} exceptions are raised, it will return
1160
  the result, else it will return the original value. Any other
1161
  exceptions are propagated to the caller.
1162

1163
  @type fn: callable
1164
  @param fn: function to apply to the value
1165
  @param val: the value to be converted
1166
  @return: The converted value if the conversion was successful,
1167
      otherwise the original value.
1168

1169
  """
1170
  try:
1171
    nv = fn(val)
1172
  except (ValueError, TypeError):
1173
    nv = val
1174
  return nv
1175

    
1176

    
1177
def IsValidIP(ip):
1178
  """Verifies the syntax of an IPv4 address.
1179

1180
  This function checks if the IPv4 address passed is valid or not based
1181
  on syntax (not IP range, class calculations, etc.).
1182

1183
  @type ip: str
1184
  @param ip: the address to be checked
1185
  @rtype: a regular expression match object
1186
  @return: a regular expression match object, or None if the
1187
      address is not valid
1188

1189
  """
1190
  unit = "(0|[1-9]\d{0,2})"
1191
  #TODO: convert and return only boolean
1192
  return re.match("^%s\.%s\.%s\.%s$" % (unit, unit, unit, unit), ip)
1193

    
1194

    
1195
def IsValidShellParam(word):
1196
  """Verifies is the given word is safe from the shell's p.o.v.
1197

1198
  This means that we can pass this to a command via the shell and be
1199
  sure that it doesn't alter the command line and is passed as such to
1200
  the actual command.
1201

1202
  Note that we are overly restrictive here, in order to be on the safe
1203
  side.
1204

1205
  @type word: str
1206
  @param word: the word to check
1207
  @rtype: boolean
1208
  @return: True if the word is 'safe'
1209

1210
  """
1211
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))
1212

    
1213

    
1214
def BuildShellCmd(template, *args):
1215
  """Build a safe shell command line from the given arguments.
1216

1217
  This function will check all arguments in the args list so that they
1218
  are valid shell parameters (i.e. they don't contain shell
1219
  metacharacters). If everything is ok, it will return the result of
1220
  template % args.
1221

1222
  @type template: str
1223
  @param template: the string holding the template for the
1224
      string formatting
1225
  @rtype: str
1226
  @return: the expanded command line
1227

1228
  """
1229
  for word in args:
1230
    if not IsValidShellParam(word):
1231
      raise errors.ProgrammerError("Shell argument '%s' contains"
1232
                                   " invalid characters" % word)
1233
  return template % args
1234

    
1235

    
1236
def FormatUnit(value, units):
1237
  """Formats an incoming number of MiB with the appropriate unit.
1238

1239
  @type value: int
1240
  @param value: integer representing the value in MiB (1048576)
1241
  @type units: char
1242
  @param units: the type of formatting we should do:
1243
      - 'h' for automatic scaling
1244
      - 'm' for MiBs
1245
      - 'g' for GiBs
1246
      - 't' for TiBs
1247
  @rtype: str
1248
  @return: the formatted value (with suffix)
1249

1250
  """
1251
  if units not in ('m', 'g', 't', 'h'):
1252
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))
1253

    
1254
  suffix = ''
1255

    
1256
  if units == 'm' or (units == 'h' and value < 1024):
1257
    if units == 'h':
1258
      suffix = 'M'
1259
    return "%d%s" % (round(value, 0), suffix)
1260

    
1261
  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
1262
    if units == 'h':
1263
      suffix = 'G'
1264
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)
1265

    
1266
  else:
1267
    if units == 'h':
1268
      suffix = 'T'
1269
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)
1270

    
1271

    
1272
def ParseUnit(input_string):
1273
  """Tries to extract number and scale from the given string.
1274

1275
  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
1276
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
1277
  is always an int in MiB.
1278

1279
  """
1280
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
1281
  if not m:
1282
    raise errors.UnitParseError("Invalid format")
1283

    
1284
  value = float(m.groups()[0])
1285

    
1286
  unit = m.groups()[1]
1287
  if unit:
1288
    lcunit = unit.lower()
1289
  else:
1290
    lcunit = 'm'
1291

    
1292
  if lcunit in ('m', 'mb', 'mib'):
1293
    # Value already in MiB
1294
    pass
1295

    
1296
  elif lcunit in ('g', 'gb', 'gib'):
1297
    value *= 1024
1298

    
1299
  elif lcunit in ('t', 'tb', 'tib'):
1300
    value *= 1024 * 1024
1301

    
1302
  else:
1303
    raise errors.UnitParseError("Unknown unit: %s" % unit)
1304

    
1305
  # Make sure we round up
1306
  if int(value) < value:
1307
    value += 1
1308

    
1309
  # Round up to the next multiple of 4
1310
  value = int(value)
1311
  if value % 4:
1312
    value += 4 - value % 4
1313

    
1314
  return value
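

# Illustrative usage sketch (not part of the original module): converting
# between human-readable sizes and the MiB values used internally.
def _ExampleUnits():
  # "1.3G" becomes 1332 MiB (1331.2 rounded up to a multiple of four)
  mebibytes = ParseUnit("1.3G")
  # and 1332 MiB is formatted back as "1.3G" with automatic scaling
  pretty = FormatUnit(mebibytes, "h")
  return (mebibytes, pretty)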
1315

    
1316

    
1317
def AddAuthorizedKey(file_name, key):
1318
  """Adds an SSH public key to an authorized_keys file.
1319

1320
  @type file_name: str
1321
  @param file_name: path to authorized_keys file
1322
  @type key: str
1323
  @param key: string containing key
1324

1325
  """
1326
  key_fields = key.split()
1327

    
1328
  f = open(file_name, 'a+')
1329
  try:
1330
    nl = True
1331
    for line in f:
1332
      # Ignore whitespace changes
1333
      if line.split() == key_fields:
1334
        break
1335
      nl = line.endswith('\n')
1336
    else:
1337
      if not nl:
1338
        f.write("\n")
1339
      f.write(key.rstrip('\r\n'))
1340
      f.write("\n")
1341
      f.flush()
1342
  finally:
1343
    f.close()
1344

    
1345

    
1346
def RemoveAuthorizedKey(file_name, key):
1347
  """Removes an SSH public key from an authorized_keys file.
1348

1349
  @type file_name: str
1350
  @param file_name: path to authorized_keys file
1351
  @type key: str
1352
  @param key: string containing key
1353

1354
  """
1355
  key_fields = key.split()
1356

    
1357
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1358
  try:
1359
    out = os.fdopen(fd, 'w')
1360
    try:
1361
      f = open(file_name, 'r')
1362
      try:
1363
        for line in f:
1364
          # Ignore whitespace changes while comparing lines
1365
          if line.split() != key_fields:
1366
            out.write(line)
1367

    
1368
        out.flush()
1369
        os.rename(tmpname, file_name)
1370
      finally:
1371
        f.close()
1372
    finally:
1373
      out.close()
1374
  except:
1375
    RemoveFile(tmpname)
1376
    raise
1377

    
1378

    
1379
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
1380
  """Sets the name of an IP address and hostname in /etc/hosts.
1381

1382
  @type file_name: str
1383
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1384
  @type ip: str
1385
  @param ip: the IP address
1386
  @type hostname: str
1387
  @param hostname: the hostname to be added
1388
  @type aliases: list
1389
  @param aliases: the list of aliases to add for the hostname
1390

1391
  """
1392
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1393
  # Ensure aliases are unique
1394
  aliases = UniqueSequence([hostname] + aliases)[1:]
1395

    
1396
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1397
  try:
1398
    out = os.fdopen(fd, 'w')
1399
    try:
1400
      f = open(file_name, 'r')
1401
      try:
1402
        for line in f:
1403
          fields = line.split()
1404
          if fields and not fields[0].startswith('#') and ip == fields[0]:
1405
            continue
1406
          out.write(line)
1407

    
1408
        out.write("%s\t%s" % (ip, hostname))
1409
        if aliases:
1410
          out.write(" %s" % ' '.join(aliases))
1411
        out.write('\n')
1412

    
1413
        out.flush()
1414
        os.fsync(out)
1415
        os.chmod(tmpname, 0644)
1416
        os.rename(tmpname, file_name)
1417
      finally:
1418
        f.close()
1419
    finally:
1420
      out.close()
1421
  except:
1422
    RemoveFile(tmpname)
1423
    raise
1424

    
1425

    
1426
def AddHostToEtcHosts(hostname):
1427
  """Wrapper around SetEtcHostsEntry.
1428

1429
  @type hostname: str
1430
  @param hostname: a hostname that will be resolved and added to
1431
      L{constants.ETC_HOSTS}
1432

1433
  """
1434
  hi = HostInfo(name=hostname)
1435
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])
1436

    
1437

    
1438
def RemoveEtcHostsEntry(file_name, hostname):
1439
  """Removes a hostname from /etc/hosts.
1440

1441
  IP addresses without names are removed from the file.
1442

1443
  @type file_name: str
1444
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1445
  @type hostname: str
1446
  @param hostname: the hostname to be removed
1447

1448
  """
1449
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1450
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1451
  try:
1452
    out = os.fdopen(fd, 'w')
1453
    try:
1454
      f = open(file_name, 'r')
1455
      try:
1456
        for line in f:
1457
          fields = line.split()
1458
          if len(fields) > 1 and not fields[0].startswith('#'):
1459
            names = fields[1:]
1460
            if hostname in names:
1461
              while hostname in names:
1462
                names.remove(hostname)
1463
              if names:
1464
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
1465
              continue
1466

    
1467
          out.write(line)
1468

    
1469
        out.flush()
1470
        os.fsync(out)
1471
        os.chmod(tmpname, 0644)
1472
        os.rename(tmpname, file_name)
1473
      finally:
1474
        f.close()
1475
    finally:
1476
      out.close()
1477
  except:
1478
    RemoveFile(tmpname)
1479
    raise
1480

    
1481

    
1482
def RemoveHostFromEtcHosts(hostname):
1483
  """Wrapper around RemoveEtcHostsEntry.
1484

1485
  @type hostname: str
1486
  @param hostname: hostname that will be resolved and its
1487
      full and short name will be removed from
1488
      L{constants.ETC_HOSTS}
1489

1490
  """
1491
  hi = HostInfo(name=hostname)
1492
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
1493
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())
1494

    
1495

    
1496
def TimestampForFilename():
1497
  """Returns the current time formatted for filenames.
1498

1499
  The format doesn't contain colons, as some shells and applications treat
1500
  them as separators.
1501

1502
  """
1503
  return time.strftime("%Y-%m-%d_%H_%M_%S")
1504

    
1505

    
1506
def CreateBackup(file_name):
1507
  """Creates a backup of a file.
1508

1509
  @type file_name: str
1510
  @param file_name: file to be backed up
1511
  @rtype: str
1512
  @return: the path to the newly created backup
1513
  @raise errors.ProgrammerError: for invalid file names
1514

1515
  """
1516
  if not os.path.isfile(file_name):
1517
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
1518
                                file_name)
1519

    
1520
  prefix = ("%s.backup-%s." %
1521
            (os.path.basename(file_name), TimestampForFilename()))
1522
  dir_name = os.path.dirname(file_name)
1523

    
1524
  fsrc = open(file_name, 'rb')
1525
  try:
1526
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
1527
    fdst = os.fdopen(fd, 'wb')
1528
    try:
1529
      logging.debug("Backing up %s at %s", file_name, backup_name)
1530
      shutil.copyfileobj(fsrc, fdst)
1531
    finally:
1532
      fdst.close()
1533
  finally:
1534
    fsrc.close()
1535

    
1536
  return backup_name
1537

    
1538

    
1539
def ShellQuote(value):
1540
  """Quotes shell argument according to POSIX.
1541

1542
  @type value: str
1543
  @param value: the argument to be quoted
1544
  @rtype: str
1545
  @return: the quoted value
1546

1547
  """
1548
  if _re_shell_unquoted.match(value):
1549
    return value
1550
  else:
1551
    return "'%s'" % value.replace("'", "'\\''")
1552

    
1553

    
1554
def ShellQuoteArgs(args):
1555
  """Quotes a list of shell arguments.
1556

1557
  @type args: list
1558
  @param args: list of arguments to be quoted
1559
  @rtype: str
1560
  @return: the quoted arguments concatenated with spaces
1561

1562
  """
1563
  return ' '.join([ShellQuote(i) for i in args])
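

# Illustrative usage sketch (not part of the original module): building a
# safely quoted command line from a list of arguments.
def _ExampleShellQuoteArgs():
  # returns: rm -f '/tmp/some file'
  return ShellQuoteArgs(["rm", "-f", "/tmp/some file"])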
1564

    
1565

    
1566
def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
1567
  """Simple ping implementation using TCP connect(2).
1568

1569
  Check if the given IP is reachable by attempting a TCP connect
1570
  to it.
1571

1572
  @type target: str
1573
  @param target: the IP or hostname to ping
1574
  @type port: int
1575
  @param port: the port to connect to
1576
  @type timeout: int
1577
  @param timeout: the timeout on the connection attempt
1578
  @type live_port_needed: boolean
1579
  @param live_port_needed: whether a closed port will cause the
1580
      function to return failure, as if there was a timeout
1581
  @type source: str or None
1582
  @param source: if specified, will cause the connect to be made
1583
      from this specific source address; failures to bind other
1584
      than C{EADDRNOTAVAIL} will be ignored
1585

1586
  """
1587
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1588

    
1589
  success = False
1590

    
1591
  if source is not None:
1592
    try:
1593
      sock.bind((source, 0))
1594
    except socket.error, (errcode, _):
1595
      if errcode == errno.EADDRNOTAVAIL:
1596
        success = False
1597

    
1598
  sock.settimeout(timeout)
1599

    
1600
  try:
1601
    sock.connect((target, port))
1602
    sock.close()
1603
    success = True
1604
  except socket.timeout:
1605
    success = False
1606
  except socket.error, (errcode, _):
1607
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
1608

    
1609
  return success
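

# Illustrative usage sketch (not part of the original module): checking
# whether a node answers on its daemon port.  The address below is an
# assumption for demonstration only.
def _ExampleTcpPing():
  alive = TcpPing("192.0.2.10", constants.DEFAULT_NODED_PORT,
                  timeout=5, live_port_needed=True)
  if not alive:
    logging.warning("Node 192.0.2.10 is not reachable")
  return alive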
1610

    
1611

    
1612
def OwnIpAddress(address):
1613
  """Check if the current host has the the given IP address.
1614

1615
  Currently this is done by TCP-pinging the address from the loopback
1616
  address.
1617

1618
  @type address: string
1619
  @param address: the address to check
1620
  @rtype: bool
1621
  @return: True if we own the address
1622

1623
  """
1624
  return TcpPing(address, constants.DEFAULT_NODED_PORT,
1625
                 source=constants.LOCALHOST_IP_ADDRESS)
1626

    
1627

    
1628
def ListVisibleFiles(path):
1629
  """Returns a list of visible files in a directory.
1630

1631
  @type path: str
1632
  @param path: the directory to enumerate
1633
  @rtype: list
1634
  @return: the list of all files not starting with a dot
1635
  @raise ProgrammerError: if L{path} is not an absolute and normalized path
1636

1637
  """
1638
  if not IsNormAbsPath(path):
1639
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
1640
                                 " absolute/normalized: '%s'" % path)
1641
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1642
  files.sort()
1643
  return files
1644

    
1645

    
1646
def GetHomeDir(user, default=None):
1647
  """Try to get the homedir of the given user.
1648

1649
  The user can be passed either as a string (denoting the name) or as
1650
  an integer (denoting the user id). If the user is not found, the
1651
  'default' argument is returned, which defaults to None.
1652

1653
  """
1654
  try:
1655
    if isinstance(user, basestring):
1656
      result = pwd.getpwnam(user)
1657
    elif isinstance(user, (int, long)):
1658
      result = pwd.getpwuid(user)
1659
    else:
1660
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1661
                                   type(user))
1662
  except KeyError:
1663
    return default
1664
  return result.pw_dir
1665

    
1666

    
1667
def NewUUID():
1668
  """Returns a random UUID.
1669

1670
  @note: This is a Linux-specific method as it uses the /proc
1671
      filesystem.
1672
  @rtype: str
1673

1674
  """
1675
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1676

    
1677

    
1678
def GenerateSecret(numbytes=20):
1679
  """Generates a random secret.
1680

1681
  This will generate a pseudo-random secret, returning a hex string
1682
  (so that it can be used where an ASCII string is needed).
1683

1684
  @param numbytes: the number of bytes which will be represented by the returned
1685
      string (defaulting to 20, the length of a SHA1 hash)
1686
  @rtype: str
1687
  @return: a hex representation of the pseudo-random sequence
1688

1689
  """
1690
  return os.urandom(numbytes).encode('hex')
1691

    
1692

    
1693
def EnsureDirs(dirs):
1694
  """Make required directories, if they don't exist.
1695

1696
  @param dirs: list of tuples (dir_name, dir_mode)
1697
  @type dirs: list of (string, integer)
1698

1699
  """
1700
  for dir_name, dir_mode in dirs:
1701
    try:
1702
      os.mkdir(dir_name, dir_mode)
1703
    except EnvironmentError, err:
1704
      if err.errno != errno.EEXIST:
1705
        raise errors.GenericError("Cannot create needed directory"
1706
                                  " '%s': %s" % (dir_name, err))
1707
    if not os.path.isdir(dir_name):
1708
      raise errors.GenericError("%s is not a directory" % dir_name)
1709

    
1710

    
1711
def ReadFile(file_name, size=-1):
1712
  """Reads a file.
1713

1714
  @type size: int
1715
  @param size: Read at most size bytes (if negative, entire file)
1716
  @rtype: str
1717
  @return: the (possibly partial) content of the file
1718

1719
  """
1720
  f = open(file_name, "r")
1721
  try:
1722
    return f.read(size)
1723
  finally:
1724
    f.close()
1725

    
1726

    
1727
def WriteFile(file_name, fn=None, data=None,
              mode=None, uid=-1, gid=-1,
              atime=None, mtime=None, close=True,
              dry_run=False, backup=False,
              prewrite=None, postwrite=None):
  """(Over)write a file atomically.

  The file_name and either fn (a function taking one argument, the
  file descriptor, and which should write the data to it) or data (the
  contents of the file) must be passed. The other arguments are
  optional and allow setting the file mode, owner and group, and the
  mtime/atime of the file.

  If the function doesn't raise an exception, it has succeeded and the
  target file has the new contents. If the function has raised an
  exception, an existing target file should be unmodified and the
  temporary file should be removed.

  @type file_name: str
  @param file_name: the target filename
  @type fn: callable
  @param fn: content writing function, called with
      file descriptor as parameter
  @type data: str
  @param data: contents of the file
  @type mode: int
  @param mode: file mode
  @type uid: int
  @param uid: the owner of the file
  @type gid: int
  @param gid: the group of the file
  @type atime: int
  @param atime: a custom access time to be set on the file
  @type mtime: int
  @param mtime: a custom modification time to be set on the file
  @type close: boolean
  @param close: whether to close file after writing it
  @type prewrite: callable
  @param prewrite: function to be called before writing content
  @type postwrite: callable
  @param postwrite: function to be called after writing content

  @rtype: None or int
  @return: None if the 'close' parameter evaluates to True,
      otherwise the file descriptor

  @raise errors.ProgrammerError: if any of the arguments are not valid

  """
  if not os.path.isabs(file_name):
    raise errors.ProgrammerError("Path passed to WriteFile is not"
                                 " absolute: '%s'" % file_name)

  if [fn, data].count(None) != 1:
    raise errors.ProgrammerError("fn or data required")

  if [atime, mtime].count(None) == 1:
    raise errors.ProgrammerError("Both atime and mtime must be either"
                                 " set or None")

  if backup and not dry_run and os.path.isfile(file_name):
    CreateBackup(file_name)

  dir_name, base_name = os.path.split(file_name)
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
  do_remove = True
  # here we need to make sure we remove the temp file, if any error
  # leaves it in place
  try:
    if uid != -1 or gid != -1:
      os.chown(new_name, uid, gid)
    if mode:
      os.chmod(new_name, mode)
    if callable(prewrite):
      prewrite(fd)
    if data is not None:
      os.write(fd, data)
    else:
      fn(fd)
    if callable(postwrite):
      postwrite(fd)
    os.fsync(fd)
    if atime is not None and mtime is not None:
      os.utime(new_name, (atime, mtime))
    if not dry_run:
      os.rename(new_name, file_name)
      do_remove = False
  finally:
    if close:
      os.close(fd)
      result = None
    else:
      result = fd
    if do_remove:
      RemoveFile(new_name)

  return result


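# Illustrative usage of WriteFile (kept as comments so the module's behaviour
# is unchanged; the paths and contents below are made-up examples):
#   WriteFile("/tmp/example.conf", data="key = value\n", mode=0644)
#   WriteFile("/tmp/example.raw", fn=lambda fd: os.write(fd, "payload"))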
def ReadOneLineFile(file_name, strict=False):
  """Return the first non-empty line from a file.

  @type strict: boolean
  @param strict: if True, abort if the file has more than one
      non-empty line

  """
  file_lines = ReadFile(file_name).splitlines()
  full_lines = filter(bool, file_lines)
  if not file_lines or not full_lines:
    raise errors.GenericError("No data in one-liner file %s" % file_name)
  elif strict and len(full_lines) > 1:
    raise errors.GenericError("Too many lines in one-liner file %s" %
                              file_name)
  return full_lines[0]


def FirstFree(seq, base=0):
  """Returns the first non-existing integer from seq.

  The seq argument should be a sorted list of positive integers. The
  first time the index of an element is smaller than the element
  value, the index will be returned.

  The base argument is used to start at a different offset,
  i.e. C{[3, 4, 6]} with I{base=3} will return 5.

  Example: C{[0, 1, 3]} will return I{2}.

  @type seq: sequence
  @param seq: the sequence to be analyzed.
  @type base: int
  @param base: use this value as the base index of the sequence
  @rtype: int
  @return: the first non-used index in the sequence

  """
  for idx, elem in enumerate(seq):
    assert elem >= base, "Passed element is higher than base offset"
    if elem > idx + base:
      # idx is not used
      return idx + base
  return None


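# Illustrative results of FirstFree (kept as comments), derived from the
# docstring above:
#   FirstFree([0, 1, 3])          # -> 2
#   FirstFree([3, 4, 6], base=3)  # -> 5
#   FirstFree([0, 1, 2])          # -> None (no free index in the sequence)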
def SingleWaitForFdCondition(fdobj, event, timeout):
  """Waits for a condition to occur on a file descriptor.

  Immediately returns at the first interruption.

  @type fdobj: integer or object supporting a fileno() method
  @param fdobj: entity to wait for events on
  @type event: integer
  @param event: ORed condition (see select module)
  @type timeout: float or None
  @param timeout: Timeout in seconds
  @rtype: int or None
  @return: None for timeout, otherwise occurred conditions

  """
  check = (event | select.POLLPRI |
           select.POLLNVAL | select.POLLHUP | select.POLLERR)

  if timeout is not None:
    # Poller object expects milliseconds
    timeout *= 1000

  poller = select.poll()
  poller.register(fdobj, event)
  try:
    # TODO: If the main thread receives a signal and we have no timeout, we
    # could wait forever. This should check a global "quit" flag or something
    # every so often.
    io_events = poller.poll(timeout)
  except select.error, err:
    if err[0] != errno.EINTR:
      raise
    io_events = []
  if io_events and io_events[0][1] & check:
    return io_events[0][1]
  else:
    return None


class FdConditionWaiterHelper(object):
  """Retry helper for WaitForFdCondition.

  This class contains the retried and wait functions that make sure
  WaitForFdCondition can continue waiting until the timeout has actually
  expired.

  """

  def __init__(self, timeout):
    self.timeout = timeout

  def Poll(self, fdobj, event):
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
    if result is None:
      raise RetryAgain()
    else:
      return result

  def UpdateTimeout(self, timeout):
    self.timeout = timeout


def WaitForFdCondition(fdobj, event, timeout):
  """Waits for a condition to occur on a file descriptor.

  Retries until the timeout is expired, even if interrupted.

  @type fdobj: integer or object supporting a fileno() method
  @param fdobj: entity to wait for events on
  @type event: integer
  @param event: ORed condition (see select module)
  @type timeout: float or None
  @param timeout: Timeout in seconds
  @rtype: int or None
  @return: None for timeout, otherwise occurred conditions

  """
  if timeout is not None:
    retrywaiter = FdConditionWaiterHelper(timeout)
    try:
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
    except RetryTimeout:
      result = None
  else:
    result = None
    while result is None:
      result = SingleWaitForFdCondition(fdobj, event, timeout)
  return result


def UniqueSequence(seq):
  """Returns a list with unique elements.

  Element order is preserved.

  @type seq: sequence
  @param seq: the sequence with the source elements
  @rtype: list
  @return: list of unique elements from seq

  """
  seen = set()
  return [i for i in seq if i not in seen and not seen.add(i)]


def NormalizeAndValidateMac(mac):
  """Normalizes and checks whether a MAC address is valid.

  Checks whether the supplied MAC address is formally correct; only
  the colon-separated format is accepted. The address is normalized
  to lower case.

  @type mac: str
  @param mac: the MAC to be validated
  @rtype: str
  @return: returns the normalized and validated MAC.

  @raise errors.OpPrereqError: If the MAC isn't valid

  """
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
  if not mac_check.match(mac):
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
                               mac, errors.ECODE_INVAL)

  return mac.lower()


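# Illustrative results of NormalizeAndValidateMac (kept as comments):
#   NormalizeAndValidateMac("AA:00:56:C1:02:9F")  # -> "aa:00:56:c1:02:9f"
#   NormalizeAndValidateMac("aa-00-56-c1-02-9f")  # raises OpPrereqError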
def TestDelay(duration):
  """Sleep for a fixed amount of time.

  @type duration: float
  @param duration: the sleep duration
  @rtype: tuple
  @return: (False, error message) for a negative duration,
      (True, None) otherwise

  """
  if duration < 0:
    return False, "Invalid sleep duration"
  time.sleep(duration)
  return True, None


def _CloseFDNoErr(fd, retries=5):
  """Close a file descriptor ignoring errors.

  @type fd: int
  @param fd: the file descriptor
  @type retries: int
  @param retries: how many retries to make, in case we get any
      other error than EBADF

  """
  try:
    os.close(fd)
  except OSError, err:
    if err.errno != errno.EBADF:
      if retries > 0:
        _CloseFDNoErr(fd, retries - 1)
    # else either it's closed already or we're out of retries, so we
    # ignore this and go on


def CloseFDs(noclose_fds=None):
  """Close file descriptors.

  This closes all file descriptors above 2 (i.e. except
  stdin/out/err).

  @type noclose_fds: list or None
  @param noclose_fds: if given, it denotes a list of file descriptors
      that should not be closed

  """
  # Default maximum for the number of available file descriptors.
  if 'SC_OPEN_MAX' in os.sysconf_names:
    try:
      MAXFD = os.sysconf('SC_OPEN_MAX')
      if MAXFD < 0:
        MAXFD = 1024
    except OSError:
      MAXFD = 1024
  else:
    MAXFD = 1024
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
  if (maxfd == resource.RLIM_INFINITY):
    maxfd = MAXFD

  # Iterate through and close all file descriptors (except the standard ones)
  for fd in range(3, maxfd):
    if noclose_fds and fd in noclose_fds:
      continue
    _CloseFDNoErr(fd)


def Mlockall():
  """Lock current process' virtual address space into RAM.

  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
  see mlock(2) for more details. This function requires the ctypes module.

  """
  if ctypes is None:
    logging.warning("Cannot set memory lock, ctypes module not found")
    return

  libc = ctypes.cdll.LoadLibrary("libc.so.6")
  if libc is None:
    logging.error("Cannot set memory lock, ctypes cannot load libc")
    return

  # Some older versions of the ctypes module don't have built-in functionality
  # to access the errno global variable, where function error codes are stored.
  # By declaring this variable as a pointer to an integer we can then access
  # its value correctly, should the mlockall call fail, in order to see what
  # the actual error code was.
  # pylint: disable-msg=W0212
  libc.__errno_location.restype = ctypes.POINTER(ctypes.c_int)

  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
    # pylint: disable-msg=W0212
    logging.error("Cannot set memory lock: %s",
                  os.strerror(libc.__errno_location().contents.value))
    return

  logging.debug("Memory lock set")


def Daemonize(logfile):
  """Daemonize the current process.

  This detaches the current process from the controlling terminal and
  runs it in the background as a daemon.

  @type logfile: str
  @param logfile: the logfile to which we should redirect stdout/stderr
  @rtype: int
  @return: the value zero

  """
  # pylint: disable-msg=W0212
  # yes, we really want os._exit
  UMASK = 077
  WORKDIR = "/"

  # this might fail
  pid = os.fork()
  if (pid == 0):  # The first child.
    os.setsid()
    # this might fail
    pid = os.fork() # Fork a second child.
    if (pid == 0):  # The second child.
      os.chdir(WORKDIR)
      os.umask(UMASK)
    else:
      # exit() or _exit()?  See below.
      os._exit(0) # Exit parent (the first child) of the second child.
  else:
    os._exit(0) # Exit parent of the first child.

  for fd in range(3):
    _CloseFDNoErr(fd)
  i = os.open("/dev/null", os.O_RDONLY) # stdin
  assert i == 0, "Can't close/reopen stdin"
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
  assert i == 1, "Can't close/reopen stdout"
  # Duplicate standard output to standard error.
  os.dup2(1, 2)
  return 0


def DaemonPidFileName(name):
  """Compute a ganeti pid file absolute path.

  @type name: str
  @param name: the daemon name
  @rtype: str
  @return: the full path to the pidfile corresponding to the given
      daemon name

  """
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)


def EnsureDaemon(name):
  """Check for and start daemon if not alive.

  """
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
  if result.failed:
    logging.error("Can't start daemon '%s', failure %s, output: %s",
                  name, result.fail_reason, result.output)
    return False

  return True


def WritePidFile(name):
  """Write the current process pidfile.

  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}

  @type name: str
  @param name: the daemon name to use
  @raise errors.GenericError: if the pid file already exists and
      points to a live process

  """
  pid = os.getpid()
  pidfilename = DaemonPidFileName(name)
  if IsProcessAlive(ReadPidFile(pidfilename)):
    raise errors.GenericError("%s contains a live process" % pidfilename)

  WriteFile(pidfilename, data="%d\n" % pid)


def RemovePidFile(name):
  """Remove the current process pidfile.

  Any errors are ignored.

  @type name: str
  @param name: the daemon name used to derive the pidfile name

  """
  pidfilename = DaemonPidFileName(name)
  # TODO: we could check here that the file contains our pid
  try:
    RemoveFile(pidfilename)
  except: # pylint: disable-msg=W0702
    pass


def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
                waitpid=False):
  """Kill a process given by its pid.

  @type pid: int
  @param pid: The PID to terminate.
  @type signal_: int
  @param signal_: The signal to send, by default SIGTERM
  @type timeout: int
  @param timeout: The timeout after which, if the process is still alive,
                  a SIGKILL will be sent. If not positive, no such checking
                  will be done
  @type waitpid: boolean
  @param waitpid: If true, we should waitpid on this process after
      sending signals, since it's our own child and otherwise it
      would remain as a zombie

  """
  def _helper(pid, signal_, wait):
    """Simple helper to encapsulate the kill/waitpid sequence"""
    os.kill(pid, signal_)
    if wait:
      try:
        os.waitpid(pid, os.WNOHANG)
      except OSError:
        pass

  if pid <= 0:
    # kill with pid=0 == suicide
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)

  if not IsProcessAlive(pid):
    return

  _helper(pid, signal_, waitpid)

  if timeout <= 0:
    return

  def _CheckProcess():
    if not IsProcessAlive(pid):
      return

    try:
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
    except OSError:
      raise RetryAgain()

    if result_pid > 0:
      return

    raise RetryAgain()

  try:
    # Wait up to $timeout seconds
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
  except RetryTimeout:
    pass

  if IsProcessAlive(pid):
    # Kill process if it's still alive
    _helper(pid, signal.SIGKILL, waitpid)


def FindFile(name, search_path, test=os.path.exists):
  """Look for a filesystem object in a given path.

  This is an abstract method to search for filesystem objects (files,
  dirs) under a given search path.

  @type name: str
  @param name: the name to look for
  @type search_path: sequence
  @param search_path: location to start at
  @type test: callable
  @param test: a function taking one argument that should return True
      if a given object is valid; the default value is
      os.path.exists, causing only existing files to be returned
  @rtype: str or None
  @return: full path to the object if found, None otherwise

  """
  # validate the filename mask
  if constants.EXT_PLUGIN_MASK.match(name) is None:
    logging.critical("Invalid value passed for external script name: '%s'",
                     name)
    return None

  for dir_name in search_path:
    # FIXME: investigate switch to PathJoin
    item_name = os.path.sep.join([dir_name, name])
    # check the user test and that we're indeed resolving to the given
    # basename
    if test(item_name) and os.path.basename(item_name) == name:
      return item_name
  return None


def CheckVolumeGroupSize(vglist, vgname, minsize):
  """Checks if the volume group list is valid.

  The function will check if a given volume group is in the list of
  volume groups and has a minimum size.

  @type vglist: dict
  @param vglist: dictionary of volume group names and their size
  @type vgname: str
  @param vgname: the volume group we should check
  @type minsize: int
  @param minsize: the minimum size we accept
  @rtype: None or str
  @return: None for success, otherwise the error message

  """
  vgsize = vglist.get(vgname, None)
  if vgsize is None:
    return "volume group '%s' missing" % vgname
  elif vgsize < minsize:
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
            (vgname, minsize, vgsize))
  return None


def SplitTime(value):
  """Splits time as floating point number into a tuple.

  @param value: Time in seconds
  @type value: int or float
  @return: Tuple containing (seconds, microseconds)

  """
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)

  assert 0 <= seconds, \
    "Seconds must be larger than or equal to 0, but are %s" % seconds
  assert 0 <= microseconds <= 999999, \
    "Microseconds must be 0-999999, but are %s" % microseconds

  return (int(seconds), int(microseconds))


def MergeTime(timetuple):
  """Merges a tuple into time as a floating point number.

  @param timetuple: Time as tuple, (seconds, microseconds)
  @type timetuple: tuple
  @return: Time as a floating point number expressed in seconds

  """
  (seconds, microseconds) = timetuple

  assert 0 <= seconds, \
    "Seconds must be larger than or equal to 0, but are %s" % seconds
  assert 0 <= microseconds <= 999999, \
    "Microseconds must be 0-999999, but are %s" % microseconds

  return float(seconds) + (float(microseconds) * 0.000001)


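# Illustrative round-trip through SplitTime/MergeTime (kept as comments):
#   SplitTime(1.5)          # -> (1, 500000)
#   MergeTime((1, 500000))  # -> 1.5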
def GetDaemonPort(daemon_name):
  """Get the daemon port for this cluster.

  Note that this routine does not read a ganeti-specific file, but
  instead uses C{socket.getservbyname} to allow pre-customization of
  this parameter outside of Ganeti.

  @type daemon_name: string
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
  @rtype: int

  """
  if daemon_name not in constants.DAEMONS_PORTS:
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)

  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
  try:
    port = socket.getservbyname(daemon_name, proto)
  except socket.error:
    port = default_port

  return port


class LogFileHandler(logging.FileHandler):
  """Log handler that doesn't fall back to stderr.

  When an error occurs while writing on the logfile, logging.FileHandler tries
  to log on stderr. This doesn't work in Ganeti since stderr is redirected to
  the logfile. This class avoids such failures by reporting errors to
  /dev/console instead.

  """
  def __init__(self, filename, mode="a", encoding=None):
    """Open the specified file and use it as the stream for logging.

    Also open /dev/console to report errors while logging.

    """
    logging.FileHandler.__init__(self, filename, mode, encoding)
    self.console = open(constants.DEV_CONSOLE, "a")

  def handleError(self, record): # pylint: disable-msg=C0103
    """Handle errors which occur during an emit() call.

    Try to handle errors with the FileHandler method; if that fails, write to
    /dev/console.

    """
    try:
      logging.FileHandler.handleError(self, record)
    except Exception: # pylint: disable-msg=W0703
      try:
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
      except Exception: # pylint: disable-msg=W0703
        # Log handler tried everything it could, now just give up
        pass


def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
                 console_logging=False):
  """Configures the logging module.

  @type logfile: str
  @param logfile: the filename to which we should log
  @type debug: integer
  @param debug: if greater than zero, enable debug messages, otherwise
      only those at C{INFO} and above level
  @type stderr_logging: boolean
  @param stderr_logging: whether we should also log to the standard error
  @type program: str
  @param program: the name under which we should log messages
  @type multithreaded: boolean
  @param multithreaded: if True, will add the thread name to the log file
  @type syslog: string
  @param syslog: one of 'no', 'yes', 'only':
      - if no, syslog is not used
      - if yes, syslog is used (in addition to file-logging)
      - if only, only syslog is used
  @type console_logging: boolean
  @param console_logging: if True, will use a FileHandler which falls back to
      the system console if logging fails
  @raise EnvironmentError: if we can't open the log file and
      syslog/stderr logging is disabled

  """
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
  sft = program + "[%(process)d]:"
  if multithreaded:
    fmt += "/%(threadName)s"
    sft += " (%(threadName)s)"
  if debug:
    fmt += " %(module)s:%(lineno)s"
    # no debug info for syslog loggers
  fmt += " %(levelname)s %(message)s"
  # yes, we do want the textual level, as remote syslog will probably
  # lose the error level, and it's easier to grep for it
  sft += " %(levelname)s %(message)s"
  formatter = logging.Formatter(fmt)
  sys_fmt = logging.Formatter(sft)

  root_logger = logging.getLogger("")
  root_logger.setLevel(logging.NOTSET)

  # Remove all previously set up handlers
  for handler in root_logger.handlers:
    handler.close()
    root_logger.removeHandler(handler)

  if stderr_logging:
    stderr_handler = logging.StreamHandler()
    stderr_handler.setFormatter(formatter)
    if debug:
      stderr_handler.setLevel(logging.NOTSET)
    else:
      stderr_handler.setLevel(logging.CRITICAL)
    root_logger.addHandler(stderr_handler)

  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
                                                    facility)
    syslog_handler.setFormatter(sys_fmt)
    # Never enable debug over syslog
    syslog_handler.setLevel(logging.INFO)
    root_logger.addHandler(syslog_handler)

  if syslog != constants.SYSLOG_ONLY:
    # this can fail, if the logging directories are not set up or we have
    # a permission problem; in this case, it's best to log but ignore
    # the error if stderr_logging is True, and if false we re-raise the
    # exception since otherwise we could run but without any logs at all
    try:
      if console_logging:
        logfile_handler = LogFileHandler(logfile)
      else:
        logfile_handler = logging.FileHandler(logfile)
      logfile_handler.setFormatter(formatter)
      if debug:
        logfile_handler.setLevel(logging.DEBUG)
      else:
        logfile_handler.setLevel(logging.INFO)
      root_logger.addHandler(logfile_handler)
    except EnvironmentError:
      if stderr_logging or syslog == constants.SYSLOG_YES:
        logging.exception("Failed to enable logging to file '%s'", logfile)
      else:
        # we need to re-raise the exception
        raise


def IsNormAbsPath(path):
  """Check whether a path is absolute and also normalized.

  This prevents things like /dir/../../other/path from being valid.

  """
  return os.path.normpath(path) == path and os.path.isabs(path)


def PathJoin(*args):
  """Safe-join a list of path components.

  Requirements:
      - the first argument must be an absolute path
      - no component in the path must have backtracking (e.g. /../),
        since we check for normalization at the end

  @param args: the path components to be joined
  @raise ValueError: for invalid paths

  """
  # ensure we're having at least one path passed in
  assert args
  # ensure the first component is an absolute and normalized path name
  root = args[0]
  if not IsNormAbsPath(root):
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
  result = os.path.join(*args)
  # ensure that the whole path is normalized
  if not IsNormAbsPath(result):
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
  # check that we're still under the original prefix
  prefix = os.path.commonprefix([root, result])
  if prefix != root:
    raise ValueError("Error: path joining resulted in different prefix"
                     " (%s != %s)" % (prefix, root))
  return result


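# Illustrative behaviour of PathJoin (kept as comments; the paths are
# made-up examples):
#   PathJoin("/var/lib/ganeti", "config.data")   # returns the joined path
#   PathJoin("/var/lib/ganeti", "../etc/passwd") # raises ValueError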
def TailFile(fname, lines=20):
  """Return the last lines from a file.

  @note: this function will only read and parse the last 4KB of
      the file; if the lines are very long, it could be that less
      than the requested number of lines are returned

  @param fname: the file name
  @type lines: int
  @param lines: the (maximum) number of lines to return

  """
  fd = open(fname, "r")
  try:
    fd.seek(0, 2)
    pos = fd.tell()
    pos = max(0, pos-4096)
    fd.seek(pos, 0)
    raw_data = fd.read()
  finally:
    fd.close()

  rows = raw_data.splitlines()
  return rows[-lines:]


def FormatTimestampWithTZ(secs):
  """Formats a Unix timestamp with the local timezone.

  """
  return time.strftime("%F %T %Z", time.gmtime(secs))


def _ParseAsn1Generalizedtime(value):
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.

  @type value: string
  @param value: ASN1 GENERALIZEDTIME timestamp

  """
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
  if m:
    # We have an offset
    asn1time = m.group(1)
    hours = int(m.group(2))
    minutes = int(m.group(3))
    utcoffset = (60 * hours) + minutes
  else:
    if not value.endswith("Z"):
      raise ValueError("Missing timezone")
    asn1time = value[:-1]
    utcoffset = 0

  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")

  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)

  return calendar.timegm(tt.utctimetuple())


def GetX509CertValidity(cert):
  """Returns the validity period of the certificate.

  @type cert: OpenSSL.crypto.X509
  @param cert: X509 certificate object

  """
  # The get_notBefore and get_notAfter functions are only supported in
  # pyOpenSSL 0.7 and above.
  try:
    get_notbefore_fn = cert.get_notBefore
  except AttributeError:
    not_before = None
  else:
    not_before_asn1 = get_notbefore_fn()

    if not_before_asn1 is None:
      not_before = None
    else:
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)

  try:
    get_notafter_fn = cert.get_notAfter
  except AttributeError:
    not_after = None
  else:
    not_after_asn1 = get_notafter_fn()

    if not_after_asn1 is None:
      not_after = None
    else:
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)

  return (not_before, not_after)


def _VerifyCertificateInner(expired, not_before, not_after, now,
                            warn_days, error_days):
  """Verifies certificate validity.

  @type expired: bool
  @param expired: Whether pyOpenSSL considers the certificate as expired
  @type not_before: number or None
  @param not_before: Unix timestamp before which certificate is not valid
  @type not_after: number or None
  @param not_after: Unix timestamp after which certificate is invalid
  @type now: number
  @param now: Current time as Unix timestamp
  @type warn_days: number or None
  @param warn_days: How many days before expiration a warning should be
      reported
  @type error_days: number or None
  @param error_days: How many days before expiration an error should be
      reported

  """
  if expired:
    msg = "Certificate is expired"

    if not_before is not None and not_after is not None:
      msg += (" (valid from %s to %s)" %
              (FormatTimestampWithTZ(not_before),
               FormatTimestampWithTZ(not_after)))
    elif not_before is not None:
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
    elif not_after is not None:
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)

    return (CERT_ERROR, msg)

  elif not_before is not None and not_before > now:
    return (CERT_WARNING,
            "Certificate not yet valid (valid from %s)" %
            FormatTimestampWithTZ(not_before))

  elif not_after is not None:
    remaining_days = int((not_after - now) / (24 * 3600))

    msg = "Certificate expires in about %d days" % remaining_days

    if error_days is not None and remaining_days <= error_days:
      return (CERT_ERROR, msg)

    if warn_days is not None and remaining_days <= warn_days:
      return (CERT_WARNING, msg)

  return (None, None)


def VerifyX509Certificate(cert, warn_days, error_days):
  """Verifies a certificate for LUVerifyCluster.

  @type cert: OpenSSL.crypto.X509
  @param cert: X509 certificate object
  @type warn_days: number or None
  @param warn_days: How many days before expiration a warning should be
      reported
  @type error_days: number or None
  @param error_days: How many days before expiration an error should be
      reported

  """
  # Depending on the pyOpenSSL version, this can just return (None, None)
  (not_before, not_after) = GetX509CertValidity(cert)

  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
                                 time.time(), warn_days, error_days)


def SignX509Certificate(cert, key, salt):
  """Sign an X509 certificate.

  An RFC822-like signature header is added in front of the certificate.

  @type cert: OpenSSL.crypto.X509
  @param cert: X509 certificate object
  @type key: string
  @param key: Key for HMAC
  @type salt: string
  @param salt: Salt for HMAC
  @rtype: string
  @return: Serialized and signed certificate in PEM format

  """
  if not VALID_X509_SIGNATURE_SALT.match(salt):
    raise errors.GenericError("Invalid salt: %r" % salt)

  # Dumping as PEM here ensures the certificate is in a sane format
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  return ("%s: %s/%s\n\n%s" %
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
           Sha1Hmac(key, cert_pem, salt=salt),
           cert_pem))


def _ExtractX509CertificateSignature(cert_pem):
  """Helper function to extract signature from X509 certificate.

  """
  # Extract signature from original PEM data
  for line in cert_pem.splitlines():
    if line.startswith("---"):
      break

    m = X509_SIGNATURE.match(line.strip())
    if m:
      return (m.group("salt"), m.group("sign"))

  raise errors.GenericError("X509 certificate signature is missing")


def LoadSignedX509Certificate(cert_pem, key):
  """Verifies a signed X509 certificate.

  @type cert_pem: string
  @param cert_pem: Certificate in PEM format and with signature header
  @type key: string
  @param key: Key for HMAC
  @rtype: tuple; (OpenSSL.crypto.X509, string)
  @return: X509 certificate object and salt

  """
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)

  # Load certificate
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)

  # Dump again to ensure it's in a sane format
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
    raise errors.GenericError("X509 certificate signature is invalid")

  return (cert, salt)


def Sha1Hmac(key, text, salt=None):
  """Calculates the HMAC-SHA1 digest of a text.

  HMAC is defined in RFC2104.

  @type key: string
  @param key: Secret key
  @type text: string
  @param text: Text to compute the digest for

  """
  if salt:
    salted_text = salt + text
  else:
    salted_text = text

  return hmac.new(key, salted_text, sha1).hexdigest()


def VerifySha1Hmac(key, text, digest, salt=None):
  """Verifies the HMAC-SHA1 digest of a text.

  HMAC is defined in RFC2104.

  @type key: string
  @param key: Secret key
  @type text: string
  @type digest: string
  @param digest: Expected digest
  @rtype: bool
  @return: Whether HMAC-SHA1 digest matches

  """
  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()


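# Illustrative usage of Sha1Hmac/VerifySha1Hmac (kept as comments; the key and
# salt are made up):
#   digest = Sha1Hmac("key", "text", salt="abc")
#   VerifySha1Hmac("key", "text", digest, salt="abc")  # -> True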
def SafeEncode(text):
  """Return a 'safe' version of a source string.

  This function mangles the input string and returns a version that
  should be safe to display/encode as ASCII. To this end, we first
  convert it to ASCII using the 'backslashreplace' encoding which
  should get rid of any non-ASCII chars, and then we process it
  through a loop copied from the string repr sources in Python; we
  don't use string_escape anymore since that escapes single quotes and
  backslashes too, and that is too much; and that escaping is not
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).

  @type text: str or unicode
  @param text: input data
  @rtype: str
  @return: a safe version of text

  """
  if isinstance(text, unicode):
    # only if unicode; if str already, we handle it below
    text = text.encode('ascii', 'backslashreplace')
  resu = ""
  for char in text:
    c = ord(char)
    if char == '\t':
      resu += r'\t'
    elif char == '\n':
      resu += r'\n'
    elif char == '\r':
      resu += r'\r'
    elif c < 32 or c >= 127: # non-printable
      resu += "\\x%02x" % (c & 0xff)
    else:
      resu += char
  return resu


def UnescapeAndSplit(text, sep=","):
  """Split and unescape a string based on a given separator.

  This function splits a string based on a separator where the
  separator itself can be escaped in order to be part of an
  element. The escaping rules are (assuming comma being the
  separator):
    - a plain , separates the elements
    - a sequence \\\\, (double backslash plus comma) is handled as a
      backslash plus a separator comma
    - a sequence \, (backslash plus comma) is handled as a
      non-separator comma

  @type text: string
  @param text: the string to split
  @type sep: string
  @param sep: the separator
  @rtype: list
  @return: a list of strings

  """
  # we split the list by sep (with no escaping at this stage)
  slist = text.split(sep)
  # next, we revisit the elements and if any of them ended with an odd
  # number of backslashes, then we join it with the next
  rlist = []
  while slist:
    e1 = slist.pop(0)
    if e1.endswith("\\"):
      num_b = len(e1) - len(e1.rstrip("\\"))
      if num_b % 2 == 1:
        e2 = slist.pop(0)
        # here the backslashes remain (all), and will be reduced in
        # the next step
        rlist.append(e1 + sep + e2)
        continue
    rlist.append(e1)
  # finally, replace backslash-something with something
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
  return rlist


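# Illustrative results of UnescapeAndSplit (kept as comments), following the
# escaping rules documented above:
#   UnescapeAndSplit("a,b,c")      # -> ["a", "b", "c"]
#   UnescapeAndSplit("a\\,b,c")    # -> ["a,b", "c"]
#   UnescapeAndSplit("a\\\\,b,c")  # -> ["a\\", "b", "c"]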
def CommaJoin(names):
  """Nicely join a set of identifiers.

  @param names: set, list or tuple
  @return: a string with the formatted results

  """
  return ", ".join([str(val) for val in names])


def BytesToMebibyte(value):
  """Converts bytes to mebibytes.

  @type value: int
  @param value: Value in bytes
  @rtype: int
  @return: Value in mebibytes

  """
  return int(round(value / (1024.0 * 1024.0), 0))


def CalculateDirectorySize(path):
  """Calculates the size of a directory recursively.

  @type path: string
  @param path: Path to directory
  @rtype: int
  @return: Size in mebibytes

  """
  size = 0

  for (curpath, _, files) in os.walk(path):
    for filename in files:
      st = os.lstat(PathJoin(curpath, filename))
      size += st.st_size

  return BytesToMebibyte(size)


def GetFilesystemStats(path):
  """Returns the total and free space on a filesystem.

  @type path: string
  @param path: Path on filesystem to be examined
  @rtype: tuple
  @return: tuple of (Total space, Free space) in mebibytes

  """
  st = os.statvfs(path)

  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
  return (tsize, fsize)


def RunInSeparateProcess(fn, *args):
  """Runs a function in a separate process.

  Note: Only boolean return values are supported.

  @type fn: callable
  @param fn: Function to be called
  @rtype: bool
  @return: Function's result

  """
  pid = os.fork()
  if pid == 0:
    # Child process
    try:
      # In case the function uses temporary files
      ResetTempfileModule()

      # Call function
      result = int(bool(fn(*args)))
      assert result in (0, 1)
    except: # pylint: disable-msg=W0702
      logging.exception("Error while calling function in separate process")
      # 0 and 1 are reserved for the return value
      result = 33

    os._exit(result) # pylint: disable-msg=W0212

  # Parent process

  # Avoid zombies and check exit code
  (_, status) = os.waitpid(pid, 0)

  if os.WIFSIGNALED(status):
    exitcode = None
    signum = os.WTERMSIG(status)
  else:
    exitcode = os.WEXITSTATUS(status)
    signum = None

  if not (exitcode in (0, 1) and signum is None):
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
                              (exitcode, signum))

  return bool(exitcode)


def IgnoreSignals(fn, *args, **kwargs):
  """Tries to call a function ignoring failures due to EINTR.

  """
  try:
    return fn(*args, **kwargs)
  except (EnvironmentError, socket.error), err:
    if err.errno != errno.EINTR:
      raise
  except select.error, err:
    if not (err.args and err.args[0] == errno.EINTR):
      raise


def LockedMethod(fn):
  """Synchronized object access decorator.

  This decorator is intended to protect access to an object using the
  object's own lock which is hardcoded to '_lock'.

  """
  def _LockDebug(*args, **kwargs):
    if debug_locks:
      logging.debug(*args, **kwargs)

  def wrapper(self, *args, **kwargs):
    # pylint: disable-msg=W0212
    assert hasattr(self, '_lock')
    lock = self._lock
    _LockDebug("Waiting for %s", lock)
    lock.acquire()
    try:
      _LockDebug("Acquired %s", lock)
      result = fn(self, *args, **kwargs)
    finally:
      _LockDebug("Releasing %s", lock)
      lock.release()
      _LockDebug("Released %s", lock)
    return result
  return wrapper


def LockFile(fd):
  """Locks a file using POSIX locks.

  @type fd: int
  @param fd: the file descriptor we need to lock

  """
  try:
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
  except IOError, err:
    if err.errno == errno.EAGAIN:
      raise errors.LockError("File already locked")
    raise


def FormatTime(val):
  """Formats a time value.

  @type val: float or None
  @param val: the timestamp as returned by time.time()
  @return: a string value or N/A if we don't have a valid timestamp

  """
  if val is None or not isinstance(val, (int, float)):
    return "N/A"
  # these two format codes work on Linux, but they are not guaranteed on all
  # platforms
  return time.strftime("%F %T", time.localtime(val))


def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
  """Reads the watcher pause file.

  @type filename: string
  @param filename: Path to watcher pause file
  @type now: None, float or int
  @param now: Current time as Unix timestamp
  @type remove_after: int
  @param remove_after: Remove watcher pause file after specified amount of
    seconds past the pause end time

  """
  if now is None:
    now = time.time()

  try:
    value = ReadFile(filename)
  except IOError, err:
    if err.errno != errno.ENOENT:
      raise
    value = None

  if value is not None:
    try:
      value = int(value)
    except ValueError:
      logging.warning(("Watcher pause file (%s) contains invalid value,"
                       " removing it"), filename)
      RemoveFile(filename)
      value = None

    if value is not None:
      # Remove file if it's outdated
      if now > (value + remove_after):
        RemoveFile(filename)
        value = None

      elif now > value:
        value = None

  return value


class RetryTimeout(Exception):
  """Retry loop timed out.

  Any arguments which were passed by the retried function to RetryAgain will
  be preserved in RetryTimeout, if it is raised. If such an argument was an
  exception, the RaiseInner helper method will reraise it.

  """
  def RaiseInner(self):
    if self.args and isinstance(self.args[0], Exception):
      raise self.args[0]
    else:
      raise RetryTimeout(*self.args)


class RetryAgain(Exception):
  """Retry again.

  Any arguments passed to RetryAgain will be preserved, if a timeout occurs, as
  arguments to RetryTimeout. If an exception is passed, the RaiseInner() method
  of the resulting RetryTimeout can be used to reraise it.

  """


class _RetryDelayCalculator(object):
  """Calculator for increasing delays.

  """
  __slots__ = [
    "_factor",
    "_limit",
    "_next",
    "_start",
    ]

  def __init__(self, start, factor, limit):
    """Initializes this class.

    @type start: float
    @param start: Initial delay
    @type factor: float
    @param factor: Factor for delay increase
    @type limit: float or None
    @param limit: Upper limit for delay or None for no limit

    """
    assert start > 0.0
    assert factor >= 1.0
    assert limit is None or limit >= 0.0

    self._start = start
    self._factor = factor
    self._limit = limit

    self._next = start

  def __call__(self):
    """Returns current delay and calculates the next one.

    """
    current = self._next

    # Update for next run
    if self._limit is None:
      self._next = self._next * self._factor
    elif self._next < self._limit:
      self._next = min(self._limit, self._next * self._factor)

    return current


#: Special delay to specify whole remaining timeout
RETRY_REMAINING_TIME = object()


def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
          _time_fn=time.time):
  """Call a function repeatedly until it succeeds.

  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
  total of C{timeout} seconds, this function throws L{RetryTimeout}.

  C{delay} can be one of the following:
    - callable returning the delay length as a float
    - Tuple of (start, factor, limit)
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
      useful when overriding L{wait_fn} to wait for an external event)
    - A static delay as a number (int or float)

  @type fn: callable
  @param fn: Function to be called
  @param delay: Either a callable (returning the delay), a tuple of (start,
                factor, limit) (see L{_RetryDelayCalculator}),
                L{RETRY_REMAINING_TIME} or a number (int or float)
  @type timeout: float
  @param timeout: Total timeout
  @type wait_fn: callable
  @param wait_fn: Waiting function
  @return: Return value of function

  """
  assert callable(fn)
  assert callable(wait_fn)
  assert callable(_time_fn)

  if args is None:
    args = []

  end_time = _time_fn() + timeout

  if callable(delay):
    # External function to calculate delay
    calc_delay = delay

  elif isinstance(delay, (tuple, list)):
    # Increasing delay with optional upper boundary
    (start, factor, limit) = delay
    calc_delay = _RetryDelayCalculator(start, factor, limit)

  elif delay is RETRY_REMAINING_TIME:
    # Always use the remaining time
    calc_delay = None

  else:
    # Static delay
    calc_delay = lambda: delay

  assert calc_delay is None or callable(calc_delay)

  while True:
    retry_args = []
    try:
      # pylint: disable-msg=W0142
      return fn(*args)
    except RetryAgain, err:
      retry_args = err.args
    except RetryTimeout:
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
                                   " handle RetryTimeout")

    remaining_time = end_time - _time_fn()

    if remaining_time < 0.0:
      # pylint: disable-msg=W0142
      raise RetryTimeout(*retry_args)

    assert remaining_time >= 0.0

    if calc_delay is None:
      wait_fn(remaining_time)
    else:
      current_delay = calc_delay()
      if current_delay > 0.0:
        wait_fn(current_delay)


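# Illustrative usage of Retry (kept as comments); _CheckSomething stands for a
# hypothetical callable that raises RetryAgain until its condition holds:
#   Retry(_CheckSomething, (0.1, 1.5, 1.0), 10.0)  # growing delay, 10s limit
#   Retry(_CheckSomething, 0.5, 10.0)              # fixed 0.5s delay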
def GetClosedTempfile(*args, **kwargs):
3277
  """Creates a temporary file and returns its path.
3278

3279
  """
3280
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
3281
  _CloseFDNoErr(fd)
3282
  return path
3283

    
3284

    
3285
def GenerateSelfSignedX509Cert(common_name, validity):
3286
  """Generates a self-signed X509 certificate.
3287

3288
  @type common_name: string
3289
  @param common_name: commonName value
3290
  @type validity: int
3291
  @param validity: Validity for certificate in seconds
3292

3293
  """
3294
  # Create private and public key
3295
  key = OpenSSL.crypto.PKey()
3296
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)
3297

    
3298
  # Create self-signed certificate
3299
  cert = OpenSSL.crypto.X509()
3300
  if common_name:
3301
    cert.get_subject().CN = common_name
3302
  cert.set_serial_number(1)
3303
  cert.gmtime_adj_notBefore(0)
3304
  cert.gmtime_adj_notAfter(validity)
3305
  cert.set_issuer(cert.get_subject())
3306
  cert.set_pubkey(key)
3307
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)
3308

    
3309
  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
3310
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
3311

    
3312
  return (key_pem, cert_pem)
3313

    
3314

    
3315
def GenerateSelfSignedSslCert(filename, validity=(5 * 365)):
3316
  """Legacy function to generate self-signed X509 certificate.
3317

3318
  """
3319
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(None,
3320
                                                   validity * 24 * 60 * 60)
3321

    
3322
  WriteFile(filename, mode=0400, data=key_pem + cert_pem)
3323

    
3324

    
3325
class FileLock(object):
3326
  """Utility class for file locks.
3327

3328
  """
3329
  def __init__(self, fd, filename):
3330
    """Constructor for FileLock.
3331

3332
    @type fd: file
    @param fd: File object
    @type filename: str
    @param filename: Path of the file opened at I{fd}

    """
    self.fd = fd
    self.filename = filename

  @classmethod
  def Open(cls, filename):
    """Creates and opens a file to be used as a file-based lock.

    @type filename: string
    @param filename: path to the file to be locked

    """
    # Using "os.open" is necessary to allow both opening existing file
    # read/write and creating if not existing. Vanilla "open" will truncate an
    # existing file -or- allow creating if not existing.
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
               filename)

  def __del__(self):
    self.Close()

  def Close(self):
    """Close the file and release the lock.

    """
    if hasattr(self, "fd") and self.fd:
      self.fd.close()
      self.fd = None

  def _flock(self, flag, blocking, timeout, errmsg):
    """Wrapper for fcntl.flock.

    @type flag: int
    @param flag: operation flag
    @type blocking: bool
    @param blocking: whether the operation should be done in blocking mode.
    @type timeout: None or float
    @param timeout: for how long the operation should be retried (implies
                    non-blocking mode).
    @type errmsg: string
    @param errmsg: error message in case operation fails.

    """
    assert self.fd, "Lock was closed"
    assert timeout is None or timeout >= 0, \
      "If specified, timeout must be positive"
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)

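# Illustrative usage sketch for FileLock (editorial addition, not part of the
# original module); the lock file path and timeout below are arbitrary
# assumptions.
#
#   lock = FileLock.Open("/var/lock/example.lock")
#   try:
#     # Wait up to 10 seconds for the lock, then fail with errors.LockError
#     lock.Exclusive(blocking=True, timeout=10)
#     # ... critical section ...
#     lock.Unlock()
#   finally:
#     lock.Close()
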
class LineSplitter:
  """Splits data chunks into lines separated by newline.

  Instances provide a file-like interface.

  """
  def __init__(self, line_fn, *args):
    """Initializes this class.

    @type line_fn: callable
    @param line_fn: Function called for each line, first parameter is line
    @param args: Extra arguments for L{line_fn}

    """
    assert callable(line_fn)

    if args:
      # Python 2.4 doesn't have functools.partial yet
      self._line_fn = \
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
    else:
      self._line_fn = line_fn

    self._lines = collections.deque()
    self._buffer = ""

  def write(self, data):
    parts = (self._buffer + data).split("\n")
    self._buffer = parts.pop()
    self._lines.extend(parts)

  def flush(self):
    while self._lines:
      self._line_fn(self._lines.popleft().rstrip("\r\n"))

  def close(self):
    self.flush()
    if self._buffer:
      self._line_fn(self._buffer)

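# Illustrative usage sketch for LineSplitter (editorial addition, not part of
# the original module); lines are collected into a plain list purely for
# demonstration.
#
#   received = []
#   splitter = LineSplitter(received.append)
#   splitter.write("first\nsec")
#   splitter.write("ond\nthird")
#   splitter.close()
#   # received == ["first", "second", "third"]
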
def SignalHandled(signums):
  """Signal handling decorator.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      assert 'signal_handlers' not in kwargs or \
             kwargs['signal_handlers'] is None or \
             isinstance(kwargs['signal_handlers'], dict), \
             "Wrong signal_handlers parameter in original function call"
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
        signal_handlers = kwargs['signal_handlers']
      else:
        signal_handlers = {}
        kwargs['signal_handlers'] = signal_handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          signal_handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        sighandler.Reset()
    return sig_function
  return wrap

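# Illustrative usage sketch for SignalHandled (editorial addition, not part of
# the original module); the decorated function is hypothetical.
#
#   @SignalHandled([signal.SIGTERM])
#   @SignalHandled([signal.SIGINT])
#   def _Example(data, signal_handlers=None):
#     assert isinstance(signal_handlers, dict)
#     if signal_handlers[signal.SIGTERM].called:
#       pass # SIGTERM arrived while running
#     return data
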
class SignalWakeupFd(object):
  try:
    # This is only supported in Python 2.5 and above (some distributions
    # backported it to Python 2.4)
    _set_wakeup_fd_fn = signal.set_wakeup_fd
  except AttributeError:
    # Not supported
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
      return -1
  else:
    def _SetWakeupFd(self, fd):
      return self._set_wakeup_fd_fn(fd)

  def __init__(self):
    """Initializes this class.

    """
    (read_fd, write_fd) = os.pipe()

    # Once these succeeded, the file descriptors will be closed automatically.
    # Buffer size 0 is important, otherwise .read() with a specified length
    # might buffer data and the file descriptors won't be marked readable.
    self._read_fh = os.fdopen(read_fd, "r", 0)
    self._write_fh = os.fdopen(write_fd, "w", 0)

    self._previous = self._SetWakeupFd(self._write_fh.fileno())

    # Utility functions
    self.fileno = self._read_fh.fileno
    self.read = self._read_fh.read

  def Reset(self):
    """Restores the previous wakeup file descriptor.

    """
    if hasattr(self, "_previous") and self._previous is not None:
      self._SetWakeupFd(self._previous)
      self._previous = None

  def Notify(self):
    """Notifies the wakeup file descriptor.

    """
    self._write_fh.write("\0")

  def __del__(self):
    """Called before object deletion.

    """
    self.Reset()

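# Illustrative usage sketch for SignalWakeupFd on its own (editorial addition,
# not part of the original module); in practice the object is passed to
# SignalHandler so that a pending select() wakes up when a signal arrives.
#
#   wakeup = SignalWakeupFd()
#   try:
#     wakeup.Notify()
#     (readable, _, _) = select.select([wakeup], [], [], 1.0)
#     assert wakeup in readable
#     wakeup.read(1) # drain the notification byte
#   finally:
#     wakeup.Reset()
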
class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when destructed or
  when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: list
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None, wakeup=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function
    @type wakeup: None or L{SignalWakeupFd}
    @param wakeup: if not None, notified whenever one of the handled signals
        arrives

    """
    assert handler_fn is None or callable(handler_fn)

    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn
    self._wakeup = wakeup

    self._previous = {}
    try:
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._wakeup:
      # Notify whoever is interested in signals
      self._wakeup.Notify()

    if self._handler_fn:
      self._handler_fn(signum, frame)

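# Illustrative usage sketch combining SignalHandler with SignalWakeupFd
# (editorial addition, not part of the original module); signal choice and
# timeout are arbitrary.
#
#   wakeup = SignalWakeupFd()
#   handler = SignalHandler([signal.SIGTERM], wakeup=wakeup)
#   try:
#     (readable, _, _) = select.select([wakeup], [], [], 30.0)
#     if wakeup in readable:
#       wakeup.read(1) # drain the wakeup byte
#     if handler.called:
#       pass # SIGTERM was received; shut down cleanly
#   finally:
#     handler.Reset()
#     wakeup.Reset()
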
class FieldSet(object):
  """A simple field set.

  Among the features are:
    - checking if a string is among a list of static string or regex objects
    - checking if a whole list of strings matches
    - returning the matching groups from a regex match

  Internally, all fields are held as regular expression objects.

  """
  def __init__(self, *items):
    self.items = [re.compile("^%s$" % value) for value in items]

  def Extend(self, other_set):
    """Extend the field set with the items from another one"""
    self.items.extend(other_set.items)

  def Matches(self, field):
    """Checks if a field matches the current set

    @type field: str
    @param field: the string to match
    @return: either None or a regular expression match object

    """
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
      return m
    return None

  def NonMatching(self, items):
    """Returns the list of fields not matching the current set

    @type items: list
    @param items: the list of fields to check
    @rtype: list
    @return: list of non-matching fields

    """
    return [val for val in items if not self.Matches(val)]
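
# Illustrative usage sketch for FieldSet (editorial addition, not part of the
# original module); the field names are invented for the example.
#
#   fields = FieldSet("name", "mode", r"disk\.size/([0-9]+)")
#   assert fields.Matches("name")
#   assert fields.Matches("disk.size/0").group(1) == "0"
#   assert fields.NonMatching(["name", "owner"]) == ["owner"]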