Statistics
| Branch: | Tag: | Revision:

root / lib / utils.py @ 600535f0

History | View | Annotate | Download (102.1 kB)

1
#
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import os
31
import sys
32
import time
33
import subprocess
34
import re
35
import socket
36
import tempfile
37
import shutil
38
import errno
39
import pwd
40
import itertools
41
import select
42
import fcntl
43
import resource
44
import logging
45
import logging.handlers
46
import signal
47
import OpenSSL
48
import datetime
49
import calendar
50
import hmac
51
import collections
52

    
53
from cStringIO import StringIO
54

    
55
try:
56
  # pylint: disable-msg=F0401
57
  import ctypes
58
except ImportError:
59
  ctypes = None
60

    
61
from ganeti import errors
62
from ganeti import constants
63
from ganeti import compat
64
from ganeti import netutils
65

    
66

    
67
_locksheld = []
68
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
69

    
70
debug_locks = False
71

    
72
#: when set to True, L{RunCmd} is disabled
73
no_fork = False
74

    
75
_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
76

    
77
HEX_CHAR_RE = r"[a-zA-Z0-9]"
78
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
79
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
80
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
81
                             HEX_CHAR_RE, HEX_CHAR_RE),
82
                            re.S | re.I)
83

    
84
_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$")
85

    
86
# Certificate verification results
87
(CERT_WARNING,
88
 CERT_ERROR) = range(1, 3)
89

    
90
# Flags for mlockall() (from bits/mman.h)
91
_MCL_CURRENT = 1
92
_MCL_FUTURE = 2
93

    
94

    
95
class RunResult(object):
96
  """Holds the result of running external programs.
97

98
  @type exit_code: int
99
  @ivar exit_code: the exit code of the program, or None (if the program
100
      didn't exit())
101
  @type signal: int or None
102
  @ivar signal: the signal that caused the program to finish, or None
103
      (if the program wasn't terminated by a signal)
104
  @type stdout: str
105
  @ivar stdout: the standard output of the program
106
  @type stderr: str
107
  @ivar stderr: the standard error of the program
108
  @type failed: boolean
109
  @ivar failed: True in case the program was
110
      terminated by a signal or exited with a non-zero exit code
111
  @ivar fail_reason: a string detailing the termination reason
112

113
  """
114
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
115
               "failed", "fail_reason", "cmd"]
116

    
117

    
118
  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
119
    self.cmd = cmd
120
    self.exit_code = exit_code
121
    self.signal = signal_
122
    self.stdout = stdout
123
    self.stderr = stderr
124
    self.failed = (signal_ is not None or exit_code != 0)
125

    
126
    if self.signal is not None:
127
      self.fail_reason = "terminated by signal %s" % self.signal
128
    elif self.exit_code is not None:
129
      self.fail_reason = "exited with exit code %s" % self.exit_code
130
    else:
131
      self.fail_reason = "unable to determine termination reason"
132

    
133
    if self.failed:
134
      logging.debug("Command '%s' failed (%s); output: %s",
135
                    self.cmd, self.fail_reason, self.output)
136

    
137
  def _GetOutput(self):
138
    """Returns the combined stdout and stderr for easier usage.
139

140
    """
141
    return self.stdout + self.stderr
142

    
143
  output = property(_GetOutput, None, None, "Return full output")
144

    
145

    
146
def _BuildCmdEnvironment(env, reset):
147
  """Builds the environment for an external program.
148

149
  """
150
  if reset:
151
    cmd_env = {}
152
  else:
153
    cmd_env = os.environ.copy()
154
    cmd_env["LC_ALL"] = "C"
155

    
156
  if env is not None:
157
    cmd_env.update(env)
158

    
159
  return cmd_env
160

    
161

    
162
def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
163
  """Execute a (shell) command.
164

165
  The command should not read from its standard input, as it will be
166
  closed.
167

168
  @type cmd: string or list
169
  @param cmd: Command to run
170
  @type env: dict
171
  @param env: Additional environment variables
172
  @type output: str
173
  @param output: if desired, the output of the command can be
174
      saved in a file instead of the RunResult instance; this
175
      parameter denotes the file name (if not None)
176
  @type cwd: string
177
  @param cwd: if specified, will be used as the working
178
      directory for the command; the default will be /
179
  @type reset_env: boolean
180
  @param reset_env: whether to reset or keep the default os environment
181
  @rtype: L{RunResult}
182
  @return: RunResult instance
183
  @raise errors.ProgrammerError: if we call this when forks are disabled
184

185
  """
186
  if no_fork:
187
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")
188

    
189
  if isinstance(cmd, basestring):
190
    strcmd = cmd
191
    shell = True
192
  else:
193
    cmd = [str(val) for val in cmd]
194
    strcmd = ShellQuoteArgs(cmd)
195
    shell = False
196

    
197
  if output:
198
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
199
  else:
200
    logging.debug("RunCmd %s", strcmd)
201

    
202
  cmd_env = _BuildCmdEnvironment(env, reset_env)
203

    
204
  try:
205
    if output is None:
206
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
207
    else:
208
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
209
      out = err = ""
210
  except OSError, err:
211
    if err.errno == errno.ENOENT:
212
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
213
                               (strcmd, err))
214
    else:
215
      raise
216

    
217
  if status >= 0:
218
    exitcode = status
219
    signal_ = None
220
  else:
221
    exitcode = None
222
    signal_ = -status
223

    
224
  return RunResult(exitcode, signal_, out, err, strcmd)
225

    
226

    
227
def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
228
                pidfile=None):
229
  """Start a daemon process after forking twice.
230

231
  @type cmd: string or list
232
  @param cmd: Command to run
233
  @type env: dict
234
  @param env: Additional environment variables
235
  @type cwd: string
236
  @param cwd: Working directory for the program
237
  @type output: string
238
  @param output: Path to file in which to save the output
239
  @type output_fd: int
240
  @param output_fd: File descriptor for output
241
  @type pidfile: string
242
  @param pidfile: Process ID file
243
  @rtype: int
244
  @return: Daemon process ID
245
  @raise errors.ProgrammerError: if we call this when forks are disabled
246

247
  """
248
  if no_fork:
249
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
250
                                 " disabled")
251

    
252
  if output and not (bool(output) ^ (output_fd is not None)):
253
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
254
                                 " specified")
255

    
256
  if isinstance(cmd, basestring):
257
    cmd = ["/bin/sh", "-c", cmd]
258

    
259
  strcmd = ShellQuoteArgs(cmd)
260

    
261
  if output:
262
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
263
  else:
264
    logging.debug("StartDaemon %s", strcmd)
265

    
266
  cmd_env = _BuildCmdEnvironment(env, False)
267

    
268
  # Create pipe for sending PID back
269
  (pidpipe_read, pidpipe_write) = os.pipe()
270
  try:
271
    try:
272
      # Create pipe for sending error messages
273
      (errpipe_read, errpipe_write) = os.pipe()
274
      try:
275
        try:
276
          # First fork
277
          pid = os.fork()
278
          if pid == 0:
279
            try:
280
              # Child process, won't return
281
              _StartDaemonChild(errpipe_read, errpipe_write,
282
                                pidpipe_read, pidpipe_write,
283
                                cmd, cmd_env, cwd,
284
                                output, output_fd, pidfile)
285
            finally:
286
              # Well, maybe child process failed
287
              os._exit(1) # pylint: disable-msg=W0212
288
        finally:
289
          _CloseFDNoErr(errpipe_write)
290

    
291
        # Wait for daemon to be started (or an error message to arrive) and read
292
        # up to 100 KB as an error message
293
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
294
      finally:
295
        _CloseFDNoErr(errpipe_read)
296
    finally:
297
      _CloseFDNoErr(pidpipe_write)
298

    
299
    # Read up to 128 bytes for PID
300
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
301
  finally:
302
    _CloseFDNoErr(pidpipe_read)
303

    
304
  # Try to avoid zombies by waiting for child process
305
  try:
306
    os.waitpid(pid, 0)
307
  except OSError:
308
    pass
309

    
310
  if errormsg:
311
    raise errors.OpExecError("Error when starting daemon process: %r" %
312
                             errormsg)
313

    
314
  try:
315
    return int(pidtext)
316
  except (ValueError, TypeError), err:
317
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
318
                             (pidtext, err))
319

    
320

    
321
def _StartDaemonChild(errpipe_read, errpipe_write,
322
                      pidpipe_read, pidpipe_write,
323
                      args, env, cwd,
324
                      output, fd_output, pidfile):
325
  """Child process for starting daemon.
326

327
  """
328
  try:
329
    # Close parent's side
330
    _CloseFDNoErr(errpipe_read)
331
    _CloseFDNoErr(pidpipe_read)
332

    
333
    # First child process
334
    os.chdir("/")
335
    os.umask(077)
336
    os.setsid()
337

    
338
    # And fork for the second time
339
    pid = os.fork()
340
    if pid != 0:
341
      # Exit first child process
342
      os._exit(0) # pylint: disable-msg=W0212
343

    
344
    # Make sure pipe is closed on execv* (and thereby notifies original process)
345
    SetCloseOnExecFlag(errpipe_write, True)
346

    
347
    # List of file descriptors to be left open
348
    noclose_fds = [errpipe_write]
349

    
350
    # Open PID file
351
    if pidfile:
352
      try:
353
        # TODO: Atomic replace with another locked file instead of writing into
354
        # it after creating
355
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)
356

    
357
        # Lock the PID file (and fail if not possible to do so). Any code
358
        # wanting to send a signal to the daemon should try to lock the PID
359
        # file before reading it. If acquiring the lock succeeds, the daemon is
360
        # no longer running and the signal should not be sent.
361
        LockFile(fd_pidfile)
362

    
363
        os.write(fd_pidfile, "%d\n" % os.getpid())
364
      except Exception, err:
365
        raise Exception("Creating and locking PID file failed: %s" % err)
366

    
367
      # Keeping the file open to hold the lock
368
      noclose_fds.append(fd_pidfile)
369

    
370
      SetCloseOnExecFlag(fd_pidfile, False)
371
    else:
372
      fd_pidfile = None
373

    
374
    # Open /dev/null
375
    fd_devnull = os.open(os.devnull, os.O_RDWR)
376

    
377
    assert not output or (bool(output) ^ (fd_output is not None))
378

    
379
    if fd_output is not None:
380
      pass
381
    elif output:
382
      # Open output file
383
      try:
384
        # TODO: Implement flag to set append=yes/no
385
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
386
      except EnvironmentError, err:
387
        raise Exception("Opening output file failed: %s" % err)
388
    else:
389
      fd_output = fd_devnull
390

    
391
    # Redirect standard I/O
392
    os.dup2(fd_devnull, 0)
393
    os.dup2(fd_output, 1)
394
    os.dup2(fd_output, 2)
395

    
396
    # Send daemon PID to parent
397
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))
398

    
399
    # Close all file descriptors except stdio and error message pipe
400
    CloseFDs(noclose_fds=noclose_fds)
401

    
402
    # Change working directory
403
    os.chdir(cwd)
404

    
405
    if env is None:
406
      os.execvp(args[0], args)
407
    else:
408
      os.execvpe(args[0], args, env)
409
  except: # pylint: disable-msg=W0702
410
    try:
411
      # Report errors to original process
412
      buf = str(sys.exc_info()[1])
413

    
414
      RetryOnSignal(os.write, errpipe_write, buf)
415
    except: # pylint: disable-msg=W0702
416
      # Ignore errors in error handling
417
      pass
418

    
419
  os._exit(1) # pylint: disable-msg=W0212
420

    
421

    
422
def _RunCmdPipe(cmd, env, via_shell, cwd):
423
  """Run a command and return its output.
424

425
  @type  cmd: string or list
426
  @param cmd: Command to run
427
  @type env: dict
428
  @param env: The environment to use
429
  @type via_shell: bool
430
  @param via_shell: if we should run via the shell
431
  @type cwd: string
432
  @param cwd: the working directory for the program
433
  @rtype: tuple
434
  @return: (out, err, status)
435

436
  """
437
  poller = select.poll()
438
  child = subprocess.Popen(cmd, shell=via_shell,
439
                           stderr=subprocess.PIPE,
440
                           stdout=subprocess.PIPE,
441
                           stdin=subprocess.PIPE,
442
                           close_fds=True, env=env,
443
                           cwd=cwd)
444

    
445
  child.stdin.close()
446
  poller.register(child.stdout, select.POLLIN)
447
  poller.register(child.stderr, select.POLLIN)
448
  out = StringIO()
449
  err = StringIO()
450
  fdmap = {
451
    child.stdout.fileno(): (out, child.stdout),
452
    child.stderr.fileno(): (err, child.stderr),
453
    }
454
  for fd in fdmap:
455
    SetNonblockFlag(fd, True)
456

    
457
  while fdmap:
458
    pollresult = RetryOnSignal(poller.poll)
459

    
460
    for fd, event in pollresult:
461
      if event & select.POLLIN or event & select.POLLPRI:
462
        data = fdmap[fd][1].read()
463
        # no data from read signifies EOF (the same as POLLHUP)
464
        if not data:
465
          poller.unregister(fd)
466
          del fdmap[fd]
467
          continue
468
        fdmap[fd][0].write(data)
469
      if (event & select.POLLNVAL or event & select.POLLHUP or
470
          event & select.POLLERR):
471
        poller.unregister(fd)
472
        del fdmap[fd]
473

    
474
  out = out.getvalue()
475
  err = err.getvalue()
476

    
477
  status = child.wait()
478
  return out, err, status
479

    
480

    
481
def _RunCmdFile(cmd, env, via_shell, output, cwd):
482
  """Run a command and save its output to a file.
483

484
  @type  cmd: string or list
485
  @param cmd: Command to run
486
  @type env: dict
487
  @param env: The environment to use
488
  @type via_shell: bool
489
  @param via_shell: if we should run via the shell
490
  @type output: str
491
  @param output: the filename in which to save the output
492
  @type cwd: string
493
  @param cwd: the working directory for the program
494
  @rtype: int
495
  @return: the exit status
496

497
  """
498
  fh = open(output, "a")
499
  try:
500
    child = subprocess.Popen(cmd, shell=via_shell,
501
                             stderr=subprocess.STDOUT,
502
                             stdout=fh,
503
                             stdin=subprocess.PIPE,
504
                             close_fds=True, env=env,
505
                             cwd=cwd)
506

    
507
    child.stdin.close()
508
    status = child.wait()
509
  finally:
510
    fh.close()
511
  return status
512

    
513

    
514
def SetCloseOnExecFlag(fd, enable):
515
  """Sets or unsets the close-on-exec flag on a file descriptor.
516

517
  @type fd: int
518
  @param fd: File descriptor
519
  @type enable: bool
520
  @param enable: Whether to set or unset it.
521

522
  """
523
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)
524

    
525
  if enable:
526
    flags |= fcntl.FD_CLOEXEC
527
  else:
528
    flags &= ~fcntl.FD_CLOEXEC
529

    
530
  fcntl.fcntl(fd, fcntl.F_SETFD, flags)
531

    
532

    
533
def SetNonblockFlag(fd, enable):
534
  """Sets or unsets the O_NONBLOCK flag on on a file descriptor.
535

536
  @type fd: int
537
  @param fd: File descriptor
538
  @type enable: bool
539
  @param enable: Whether to set or unset it
540

541
  """
542
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)
543

    
544
  if enable:
545
    flags |= os.O_NONBLOCK
546
  else:
547
    flags &= ~os.O_NONBLOCK
548

    
549
  fcntl.fcntl(fd, fcntl.F_SETFL, flags)
550

    
551

    
552
def RetryOnSignal(fn, *args, **kwargs):
553
  """Calls a function again if it failed due to EINTR.
554

555
  """
556
  while True:
557
    try:
558
      return fn(*args, **kwargs)
559
    except EnvironmentError, err:
560
      if err.errno != errno.EINTR:
561
        raise
562
    except (socket.error, select.error), err:
563
      # In python 2.6 and above select.error is an IOError, so it's handled
564
      # above, in 2.5 and below it's not, and it's handled here.
565
      if not (err.args and err.args[0] == errno.EINTR):
566
        raise
567

    
568

    
569
def RunParts(dir_name, env=None, reset_env=False):
570
  """Run Scripts or programs in a directory
571

572
  @type dir_name: string
573
  @param dir_name: absolute path to a directory
574
  @type env: dict
575
  @param env: The environment to use
576
  @type reset_env: boolean
577
  @param reset_env: whether to reset or keep the default os environment
578
  @rtype: list of tuples
579
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)
580

581
  """
582
  rr = []
583

    
584
  try:
585
    dir_contents = ListVisibleFiles(dir_name)
586
  except OSError, err:
587
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
588
    return rr
589

    
590
  for relname in sorted(dir_contents):
591
    fname = PathJoin(dir_name, relname)
592
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
593
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
594
      rr.append((relname, constants.RUNPARTS_SKIP, None))
595
    else:
596
      try:
597
        result = RunCmd([fname], env=env, reset_env=reset_env)
598
      except Exception, err: # pylint: disable-msg=W0703
599
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
600
      else:
601
        rr.append((relname, constants.RUNPARTS_RUN, result))
602

    
603
  return rr
604

    
605

    
606
def RemoveFile(filename):
607
  """Remove a file ignoring some errors.
608

609
  Remove a file, ignoring non-existing ones or directories. Other
610
  errors are passed.
611

612
  @type filename: str
613
  @param filename: the file to be removed
614

615
  """
616
  try:
617
    os.unlink(filename)
618
  except OSError, err:
619
    if err.errno not in (errno.ENOENT, errno.EISDIR):
620
      raise
621

    
622

    
623
def RemoveDir(dirname):
624
  """Remove an empty directory.
625

626
  Remove a directory, ignoring non-existing ones.
627
  Other errors are passed. This includes the case,
628
  where the directory is not empty, so it can't be removed.
629

630
  @type dirname: str
631
  @param dirname: the empty directory to be removed
632

633
  """
634
  try:
635
    os.rmdir(dirname)
636
  except OSError, err:
637
    if err.errno != errno.ENOENT:
638
      raise
639

    
640

    
641
def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
642
  """Renames a file.
643

644
  @type old: string
645
  @param old: Original path
646
  @type new: string
647
  @param new: New path
648
  @type mkdir: bool
649
  @param mkdir: Whether to create target directory if it doesn't exist
650
  @type mkdir_mode: int
651
  @param mkdir_mode: Mode for newly created directories
652

653
  """
654
  try:
655
    return os.rename(old, new)
656
  except OSError, err:
657
    # In at least one use case of this function, the job queue, directory
658
    # creation is very rare. Checking for the directory before renaming is not
659
    # as efficient.
660
    if mkdir and err.errno == errno.ENOENT:
661
      # Create directory and try again
662
      Makedirs(os.path.dirname(new), mode=mkdir_mode)
663

    
664
      return os.rename(old, new)
665

    
666
    raise
667

    
668

    
669
def Makedirs(path, mode=0750):
670
  """Super-mkdir; create a leaf directory and all intermediate ones.
671

672
  This is a wrapper around C{os.makedirs} adding error handling not implemented
673
  before Python 2.5.
674

675
  """
676
  try:
677
    os.makedirs(path, mode)
678
  except OSError, err:
679
    # Ignore EEXIST. This is only handled in os.makedirs as included in
680
    # Python 2.5 and above.
681
    if err.errno != errno.EEXIST or not os.path.exists(path):
682
      raise
683

    
684

    
685
def ResetTempfileModule():
686
  """Resets the random name generator of the tempfile module.
687

688
  This function should be called after C{os.fork} in the child process to
689
  ensure it creates a newly seeded random generator. Otherwise it would
690
  generate the same random parts as the parent process. If several processes
691
  race for the creation of a temporary file, this could lead to one not getting
692
  a temporary name.
693

694
  """
695
  # pylint: disable-msg=W0212
696
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
697
    tempfile._once_lock.acquire()
698
    try:
699
      # Reset random name generator
700
      tempfile._name_sequence = None
701
    finally:
702
      tempfile._once_lock.release()
703
  else:
704
    logging.critical("The tempfile module misses at least one of the"
705
                     " '_once_lock' and '_name_sequence' attributes")
706

    
707

    
708
def _FingerprintFile(filename):
709
  """Compute the fingerprint of a file.
710

711
  If the file does not exist, a None will be returned
712
  instead.
713

714
  @type filename: str
715
  @param filename: the filename to checksum
716
  @rtype: str
717
  @return: the hex digest of the sha checksum of the contents
718
      of the file
719

720
  """
721
  if not (os.path.exists(filename) and os.path.isfile(filename)):
722
    return None
723

    
724
  f = open(filename)
725

    
726
  fp = compat.sha1_hash()
727
  while True:
728
    data = f.read(4096)
729
    if not data:
730
      break
731

    
732
    fp.update(data)
733

    
734
  return fp.hexdigest()
735

    
736

    
737
def FingerprintFiles(files):
738
  """Compute fingerprints for a list of files.
739

740
  @type files: list
741
  @param files: the list of filename to fingerprint
742
  @rtype: dict
743
  @return: a dictionary filename: fingerprint, holding only
744
      existing files
745

746
  """
747
  ret = {}
748

    
749
  for filename in files:
750
    cksum = _FingerprintFile(filename)
751
    if cksum:
752
      ret[filename] = cksum
753

    
754
  return ret
755

    
756

    
757
def ForceDictType(target, key_types, allowed_values=None):
758
  """Force the values of a dict to have certain types.
759

760
  @type target: dict
761
  @param target: the dict to update
762
  @type key_types: dict
763
  @param key_types: dict mapping target dict keys to types
764
                    in constants.ENFORCEABLE_TYPES
765
  @type allowed_values: list
766
  @keyword allowed_values: list of specially allowed values
767

768
  """
769
  if allowed_values is None:
770
    allowed_values = []
771

    
772
  if not isinstance(target, dict):
773
    msg = "Expected dictionary, got '%s'" % target
774
    raise errors.TypeEnforcementError(msg)
775

    
776
  for key in target:
777
    if key not in key_types:
778
      msg = "Unknown key '%s'" % key
779
      raise errors.TypeEnforcementError(msg)
780

    
781
    if target[key] in allowed_values:
782
      continue
783

    
784
    ktype = key_types[key]
785
    if ktype not in constants.ENFORCEABLE_TYPES:
786
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
787
      raise errors.ProgrammerError(msg)
788

    
789
    if ktype == constants.VTYPE_STRING:
790
      if not isinstance(target[key], basestring):
791
        if isinstance(target[key], bool) and not target[key]:
792
          target[key] = ''
793
        else:
794
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
795
          raise errors.TypeEnforcementError(msg)
796
    elif ktype == constants.VTYPE_BOOL:
797
      if isinstance(target[key], basestring) and target[key]:
798
        if target[key].lower() == constants.VALUE_FALSE:
799
          target[key] = False
800
        elif target[key].lower() == constants.VALUE_TRUE:
801
          target[key] = True
802
        else:
803
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
804
          raise errors.TypeEnforcementError(msg)
805
      elif target[key]:
806
        target[key] = True
807
      else:
808
        target[key] = False
809
    elif ktype == constants.VTYPE_SIZE:
810
      try:
811
        target[key] = ParseUnit(target[key])
812
      except errors.UnitParseError, err:
813
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
814
              (key, target[key], err)
815
        raise errors.TypeEnforcementError(msg)
816
    elif ktype == constants.VTYPE_INT:
817
      try:
818
        target[key] = int(target[key])
819
      except (ValueError, TypeError):
820
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
821
        raise errors.TypeEnforcementError(msg)
822

    
823

    
824
def _GetProcStatusPath(pid):
825
  """Returns the path for a PID's proc status file.
826

827
  @type pid: int
828
  @param pid: Process ID
829
  @rtype: string
830

831
  """
832
  return "/proc/%d/status" % pid
833

    
834

    
835
def IsProcessAlive(pid):
836
  """Check if a given pid exists on the system.
837

838
  @note: zombie status is not handled, so zombie processes
839
      will be returned as alive
840
  @type pid: int
841
  @param pid: the process ID to check
842
  @rtype: boolean
843
  @return: True if the process exists
844

845
  """
846
  def _TryStat(name):
847
    try:
848
      os.stat(name)
849
      return True
850
    except EnvironmentError, err:
851
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
852
        return False
853
      elif err.errno == errno.EINVAL:
854
        raise RetryAgain(err)
855
      raise
856

    
857
  assert isinstance(pid, int), "pid must be an integer"
858
  if pid <= 0:
859
    return False
860

    
861
  # /proc in a multiprocessor environment can have strange behaviors.
862
  # Retry the os.stat a few times until we get a good result.
863
  try:
864
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
865
                 args=[_GetProcStatusPath(pid)])
866
  except RetryTimeout, err:
867
    err.RaiseInner()
868

    
869

    
870
def _ParseSigsetT(sigset):
871
  """Parse a rendered sigset_t value.
872

873
  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
874
  function.
875

876
  @type sigset: string
877
  @param sigset: Rendered signal set from /proc/$pid/status
878
  @rtype: set
879
  @return: Set of all enabled signal numbers
880

881
  """
882
  result = set()
883

    
884
  signum = 0
885
  for ch in reversed(sigset):
886
    chv = int(ch, 16)
887

    
888
    # The following could be done in a loop, but it's easier to read and
889
    # understand in the unrolled form
890
    if chv & 1:
891
      result.add(signum + 1)
892
    if chv & 2:
893
      result.add(signum + 2)
894
    if chv & 4:
895
      result.add(signum + 3)
896
    if chv & 8:
897
      result.add(signum + 4)
898

    
899
    signum += 4
900

    
901
  return result
902

    
903

    
904
def _GetProcStatusField(pstatus, field):
905
  """Retrieves a field from the contents of a proc status file.
906

907
  @type pstatus: string
908
  @param pstatus: Contents of /proc/$pid/status
909
  @type field: string
910
  @param field: Name of field whose value should be returned
911
  @rtype: string
912

913
  """
914
  for line in pstatus.splitlines():
915
    parts = line.split(":", 1)
916

    
917
    if len(parts) < 2 or parts[0] != field:
918
      continue
919

    
920
    return parts[1].strip()
921

    
922
  return None
923

    
924

    
925
def IsProcessHandlingSignal(pid, signum, status_path=None):
926
  """Checks whether a process is handling a signal.
927

928
  @type pid: int
929
  @param pid: Process ID
930
  @type signum: int
931
  @param signum: Signal number
932
  @rtype: bool
933

934
  """
935
  if status_path is None:
936
    status_path = _GetProcStatusPath(pid)
937

    
938
  try:
939
    proc_status = ReadFile(status_path)
940
  except EnvironmentError, err:
941
    # In at least one case, reading /proc/$pid/status failed with ESRCH.
942
    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
943
      return False
944
    raise
945

    
946
  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
947
  if sigcgt is None:
948
    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)
949

    
950
  # Now check whether signal is handled
951
  return signum in _ParseSigsetT(sigcgt)
952

    
953

    
954
def ReadPidFile(pidfile):
955
  """Read a pid from a file.
956

957
  @type  pidfile: string
958
  @param pidfile: path to the file containing the pid
959
  @rtype: int
960
  @return: The process id, if the file exists and contains a valid PID,
961
           otherwise 0
962

963
  """
964
  try:
965
    raw_data = ReadOneLineFile(pidfile)
966
  except EnvironmentError, err:
967
    if err.errno != errno.ENOENT:
968
      logging.exception("Can't read pid file")
969
    return 0
970

    
971
  try:
972
    pid = int(raw_data)
973
  except (TypeError, ValueError), err:
974
    logging.info("Can't parse pid file contents", exc_info=True)
975
    return 0
976

    
977
  return pid
978

    
979

    
980
def ReadLockedPidFile(path):
981
  """Reads a locked PID file.
982

983
  This can be used together with L{StartDaemon}.
984

985
  @type path: string
986
  @param path: Path to PID file
987
  @return: PID as integer or, if file was unlocked or couldn't be opened, None
988

989
  """
990
  try:
991
    fd = os.open(path, os.O_RDONLY)
992
  except EnvironmentError, err:
993
    if err.errno == errno.ENOENT:
994
      # PID file doesn't exist
995
      return None
996
    raise
997

    
998
  try:
999
    try:
1000
      # Try to acquire lock
1001
      LockFile(fd)
1002
    except errors.LockError:
1003
      # Couldn't lock, daemon is running
1004
      return int(os.read(fd, 100))
1005
  finally:
1006
    os.close(fd)
1007

    
1008
  return None
1009

    
1010

    
1011
def MatchNameComponent(key, name_list, case_sensitive=True):
1012
  """Try to match a name against a list.
1013

1014
  This function will try to match a name like test1 against a list
1015
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
1016
  this list, I{'test1'} as well as I{'test1.example'} will match, but
1017
  not I{'test1.ex'}. A multiple match will be considered as no match
1018
  at all (e.g. I{'test1'} against C{['test1.example.com',
1019
  'test1.example.org']}), except when the key fully matches an entry
1020
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
1021

1022
  @type key: str
1023
  @param key: the name to be searched
1024
  @type name_list: list
1025
  @param name_list: the list of strings against which to search the key
1026
  @type case_sensitive: boolean
1027
  @param case_sensitive: whether to provide a case-sensitive match
1028

1029
  @rtype: None or str
1030
  @return: None if there is no match I{or} if there are multiple matches,
1031
      otherwise the element from the list which matches
1032

1033
  """
1034
  if key in name_list:
1035
    return key
1036

    
1037
  re_flags = 0
1038
  if not case_sensitive:
1039
    re_flags |= re.IGNORECASE
1040
    key = key.upper()
1041
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
1042
  names_filtered = []
1043
  string_matches = []
1044
  for name in name_list:
1045
    if mo.match(name) is not None:
1046
      names_filtered.append(name)
1047
      if not case_sensitive and key == name.upper():
1048
        string_matches.append(name)
1049

    
1050
  if len(string_matches) == 1:
1051
    return string_matches[0]
1052
  if len(names_filtered) == 1:
1053
    return names_filtered[0]
1054
  return None
1055

    
1056

    
1057
def ValidateServiceName(name):
1058
  """Validate the given service name.
1059

1060
  @type name: number or string
1061
  @param name: Service name or port specification
1062

1063
  """
1064
  try:
1065
    numport = int(name)
1066
  except (ValueError, TypeError):
1067
    # Non-numeric service name
1068
    valid = _VALID_SERVICE_NAME_RE.match(name)
1069
  else:
1070
    # Numeric port (protocols other than TCP or UDP might need adjustments
1071
    # here)
1072
    valid = (numport >= 0 and numport < (1 << 16))
1073

    
1074
  if not valid:
1075
    raise errors.OpPrereqError("Invalid service name '%s'" % name,
1076
                               errors.ECODE_INVAL)
1077

    
1078
  return name
1079

    
1080

    
1081
def ListVolumeGroups():
1082
  """List volume groups and their size
1083

1084
  @rtype: dict
1085
  @return:
1086
       Dictionary with keys volume name and values
1087
       the size of the volume
1088

1089
  """
1090
  command = "vgs --noheadings --units m --nosuffix -o name,size"
1091
  result = RunCmd(command)
1092
  retval = {}
1093
  if result.failed:
1094
    return retval
1095

    
1096
  for line in result.stdout.splitlines():
1097
    try:
1098
      name, size = line.split()
1099
      size = int(float(size))
1100
    except (IndexError, ValueError), err:
1101
      logging.error("Invalid output from vgs (%s): %s", err, line)
1102
      continue
1103

    
1104
    retval[name] = size
1105

    
1106
  return retval
1107

    
1108

    
1109
def BridgeExists(bridge):
1110
  """Check whether the given bridge exists in the system
1111

1112
  @type bridge: str
1113
  @param bridge: the bridge name to check
1114
  @rtype: boolean
1115
  @return: True if it does
1116

1117
  """
1118
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
1119

    
1120

    
1121
def NiceSort(name_list):
1122
  """Sort a list of strings based on digit and non-digit groupings.
1123

1124
  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
1125
  will sort the list in the logical order C{['a1', 'a2', 'a10',
1126
  'a11']}.
1127

1128
  The sort algorithm breaks each name in groups of either only-digits
1129
  or no-digits. Only the first eight such groups are considered, and
1130
  after that we just use what's left of the string.
1131

1132
  @type name_list: list
1133
  @param name_list: the names to be sorted
1134
  @rtype: list
1135
  @return: a copy of the name list sorted with our algorithm
1136

1137
  """
1138
  _SORTER_BASE = "(\D+|\d+)"
1139
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
1140
                                                  _SORTER_BASE, _SORTER_BASE,
1141
                                                  _SORTER_BASE, _SORTER_BASE,
1142
                                                  _SORTER_BASE, _SORTER_BASE)
1143
  _SORTER_RE = re.compile(_SORTER_FULL)
1144
  _SORTER_NODIGIT = re.compile("^\D*$")
1145
  def _TryInt(val):
1146
    """Attempts to convert a variable to integer."""
1147
    if val is None or _SORTER_NODIGIT.match(val):
1148
      return val
1149
    rval = int(val)
1150
    return rval
1151

    
1152
  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
1153
             for name in name_list]
1154
  to_sort.sort()
1155
  return [tup[1] for tup in to_sort]
1156

    
1157

    
1158
def TryConvert(fn, val):
1159
  """Try to convert a value ignoring errors.
1160

1161
  This function tries to apply function I{fn} to I{val}. If no
1162
  C{ValueError} or C{TypeError} exceptions are raised, it will return
1163
  the result, else it will return the original value. Any other
1164
  exceptions are propagated to the caller.
1165

1166
  @type fn: callable
1167
  @param fn: function to apply to the value
1168
  @param val: the value to be converted
1169
  @return: The converted value if the conversion was successful,
1170
      otherwise the original value.
1171

1172
  """
1173
  try:
1174
    nv = fn(val)
1175
  except (ValueError, TypeError):
1176
    nv = val
1177
  return nv
1178

    
1179

    
1180
def IsValidShellParam(word):
1181
  """Verifies is the given word is safe from the shell's p.o.v.
1182

1183
  This means that we can pass this to a command via the shell and be
1184
  sure that it doesn't alter the command line and is passed as such to
1185
  the actual command.
1186

1187
  Note that we are overly restrictive here, in order to be on the safe
1188
  side.
1189

1190
  @type word: str
1191
  @param word: the word to check
1192
  @rtype: boolean
1193
  @return: True if the word is 'safe'
1194

1195
  """
1196
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))
1197

    
1198

    
1199
def BuildShellCmd(template, *args):
1200
  """Build a safe shell command line from the given arguments.
1201

1202
  This function will check all arguments in the args list so that they
1203
  are valid shell parameters (i.e. they don't contain shell
1204
  metacharacters). If everything is ok, it will return the result of
1205
  template % args.
1206

1207
  @type template: str
1208
  @param template: the string holding the template for the
1209
      string formatting
1210
  @rtype: str
1211
  @return: the expanded command line
1212

1213
  """
1214
  for word in args:
1215
    if not IsValidShellParam(word):
1216
      raise errors.ProgrammerError("Shell argument '%s' contains"
1217
                                   " invalid characters" % word)
1218
  return template % args
1219

    
1220

    
1221
def FormatUnit(value, units):
1222
  """Formats an incoming number of MiB with the appropriate unit.
1223

1224
  @type value: int
1225
  @param value: integer representing the value in MiB (1048576)
1226
  @type units: char
1227
  @param units: the type of formatting we should do:
1228
      - 'h' for automatic scaling
1229
      - 'm' for MiBs
1230
      - 'g' for GiBs
1231
      - 't' for TiBs
1232
  @rtype: str
1233
  @return: the formatted value (with suffix)
1234

1235
  """
1236
  if units not in ('m', 'g', 't', 'h'):
1237
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))
1238

    
1239
  suffix = ''
1240

    
1241
  if units == 'm' or (units == 'h' and value < 1024):
1242
    if units == 'h':
1243
      suffix = 'M'
1244
    return "%d%s" % (round(value, 0), suffix)
1245

    
1246
  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
1247
    if units == 'h':
1248
      suffix = 'G'
1249
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)
1250

    
1251
  else:
1252
    if units == 'h':
1253
      suffix = 'T'
1254
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)
1255

    
1256

    
1257
def ParseUnit(input_string):
1258
  """Tries to extract number and scale from the given string.
1259

1260
  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
1261
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
1262
  is always an int in MiB.
1263

1264
  """
1265
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
1266
  if not m:
1267
    raise errors.UnitParseError("Invalid format")
1268

    
1269
  value = float(m.groups()[0])
1270

    
1271
  unit = m.groups()[1]
1272
  if unit:
1273
    lcunit = unit.lower()
1274
  else:
1275
    lcunit = 'm'
1276

    
1277
  if lcunit in ('m', 'mb', 'mib'):
1278
    # Value already in MiB
1279
    pass
1280

    
1281
  elif lcunit in ('g', 'gb', 'gib'):
1282
    value *= 1024
1283

    
1284
  elif lcunit in ('t', 'tb', 'tib'):
1285
    value *= 1024 * 1024
1286

    
1287
  else:
1288
    raise errors.UnitParseError("Unknown unit: %s" % unit)
1289

    
1290
  # Make sure we round up
1291
  if int(value) < value:
1292
    value += 1
1293

    
1294
  # Round up to the next multiple of 4
1295
  value = int(value)
1296
  if value % 4:
1297
    value += 4 - value % 4
1298

    
1299
  return value
1300

    
1301

    
1302
def ParseCpuMask(cpu_mask):
1303
  """Parse a CPU mask definition and return the list of CPU IDs.
1304

1305
  CPU mask format: comma-separated list of CPU IDs
1306
  or dash-separated ID ranges
1307
  Example: "0-2,5" -> "0,1,2,5"
1308

1309
  @type cpu_mask: str
1310
  @param cpu_mask: CPU mask definition
1311
  @rtype: list of int
1312
  @return: list of CPU IDs
1313

1314
  """
1315
  if not cpu_mask:
1316
    return []
1317
  cpu_list = []
1318
  for range_def in cpu_mask.split(","):
1319
    boundaries = range_def.split("-")
1320
    n_elements = len(boundaries)
1321
    if n_elements > 2:
1322
      raise errors.ParseError("Invalid CPU ID range definition"
1323
                              " (only one hyphen allowed): %s" % range_def)
1324
    try:
1325
      lower = int(boundaries[0])
1326
    except (ValueError, TypeError), err:
1327
      raise errors.ParseError("Invalid CPU ID value for lower boundary of"
1328
                              " CPU ID range: %s" % str(err))
1329
    try:
1330
      higher = int(boundaries[-1])
1331
    except (ValueError, TypeError), err:
1332
      raise errors.ParseError("Invalid CPU ID value for higher boundary of"
1333
                              " CPU ID range: %s" % str(err))
1334
    if lower > higher:
1335
      raise errors.ParseError("Invalid CPU ID range definition"
1336
                              " (%d > %d): %s" % (lower, higher, range_def))
1337
    cpu_list.extend(range(lower, higher + 1))
1338
  return cpu_list
1339

    
1340

    
1341
def AddAuthorizedKey(file_obj, key):
1342
  """Adds an SSH public key to an authorized_keys file.
1343

1344
  @type file_obj: str or file handle
1345
  @param file_obj: path to authorized_keys file
1346
  @type key: str
1347
  @param key: string containing key
1348

1349
  """
1350
  key_fields = key.split()
1351

    
1352
  if isinstance(file_obj, basestring):
1353
    f = open(file_obj, 'a+')
1354
  else:
1355
    f = file_obj
1356

    
1357
  try:
1358
    nl = True
1359
    for line in f:
1360
      # Ignore whitespace changes
1361
      if line.split() == key_fields:
1362
        break
1363
      nl = line.endswith('\n')
1364
    else:
1365
      if not nl:
1366
        f.write("\n")
1367
      f.write(key.rstrip('\r\n'))
1368
      f.write("\n")
1369
      f.flush()
1370
  finally:
1371
    f.close()
1372

    
1373

    
1374
def RemoveAuthorizedKey(file_name, key):
1375
  """Removes an SSH public key from an authorized_keys file.
1376

1377
  @type file_name: str
1378
  @param file_name: path to authorized_keys file
1379
  @type key: str
1380
  @param key: string containing key
1381

1382
  """
1383
  key_fields = key.split()
1384

    
1385
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1386
  try:
1387
    out = os.fdopen(fd, 'w')
1388
    try:
1389
      f = open(file_name, 'r')
1390
      try:
1391
        for line in f:
1392
          # Ignore whitespace changes while comparing lines
1393
          if line.split() != key_fields:
1394
            out.write(line)
1395

    
1396
        out.flush()
1397
        os.rename(tmpname, file_name)
1398
      finally:
1399
        f.close()
1400
    finally:
1401
      out.close()
1402
  except:
1403
    RemoveFile(tmpname)
1404
    raise
1405

    
1406

    
1407
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
1408
  """Sets the name of an IP address and hostname in /etc/hosts.
1409

1410
  @type file_name: str
1411
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1412
  @type ip: str
1413
  @param ip: the IP address
1414
  @type hostname: str
1415
  @param hostname: the hostname to be added
1416
  @type aliases: list
1417
  @param aliases: the list of aliases to add for the hostname
1418

1419
  """
1420
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1421
  # Ensure aliases are unique
1422
  aliases = UniqueSequence([hostname] + aliases)[1:]
1423

    
1424
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1425
  try:
1426
    out = os.fdopen(fd, 'w')
1427
    try:
1428
      f = open(file_name, 'r')
1429
      try:
1430
        for line in f:
1431
          fields = line.split()
1432
          if fields and not fields[0].startswith('#') and ip == fields[0]:
1433
            continue
1434
          out.write(line)
1435

    
1436
        out.write("%s\t%s" % (ip, hostname))
1437
        if aliases:
1438
          out.write(" %s" % ' '.join(aliases))
1439
        out.write('\n')
1440

    
1441
        out.flush()
1442
        os.fsync(out)
1443
        os.chmod(tmpname, 0644)
1444
        os.rename(tmpname, file_name)
1445
      finally:
1446
        f.close()
1447
    finally:
1448
      out.close()
1449
  except:
1450
    RemoveFile(tmpname)
1451
    raise
1452

    
1453

    
1454
def AddHostToEtcHosts(hostname):
1455
  """Wrapper around SetEtcHostsEntry.
1456

1457
  @type hostname: str
1458
  @param hostname: a hostname that will be resolved and added to
1459
      L{constants.ETC_HOSTS}
1460

1461
  """
1462
  hi = netutils.HostInfo(name=hostname)
1463
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])
1464

    
1465

    
1466
def RemoveEtcHostsEntry(file_name, hostname):
1467
  """Removes a hostname from /etc/hosts.
1468

1469
  IP addresses without names are removed from the file.
1470

1471
  @type file_name: str
1472
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1473
  @type hostname: str
1474
  @param hostname: the hostname to be removed
1475

1476
  """
1477
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1478
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1479
  try:
1480
    out = os.fdopen(fd, 'w')
1481
    try:
1482
      f = open(file_name, 'r')
1483
      try:
1484
        for line in f:
1485
          fields = line.split()
1486
          if len(fields) > 1 and not fields[0].startswith('#'):
1487
            names = fields[1:]
1488
            if hostname in names:
1489
              while hostname in names:
1490
                names.remove(hostname)
1491
              if names:
1492
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
1493
              continue
1494

    
1495
          out.write(line)
1496

    
1497
        out.flush()
1498
        os.fsync(out)
1499
        os.chmod(tmpname, 0644)
1500
        os.rename(tmpname, file_name)
1501
      finally:
1502
        f.close()
1503
    finally:
1504
      out.close()
1505
  except:
1506
    RemoveFile(tmpname)
1507
    raise
1508

    
1509

    
1510
def RemoveHostFromEtcHosts(hostname):
1511
  """Wrapper around RemoveEtcHostsEntry.
1512

1513
  @type hostname: str
1514
  @param hostname: hostname that will be resolved and its
1515
      full and shot name will be removed from
1516
      L{constants.ETC_HOSTS}
1517

1518
  """
1519
  hi = netutils.HostInfo(name=hostname)
1520
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
1521
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())
1522

    
1523

    
1524
def TimestampForFilename():
1525
  """Returns the current time formatted for filenames.
1526

1527
  The format doesn't contain colons as some shells and applications them as
1528
  separators.
1529

1530
  """
1531
  return time.strftime("%Y-%m-%d_%H_%M_%S")
1532

    
1533

    
1534
def CreateBackup(file_name):
1535
  """Creates a backup of a file.
1536

1537
  @type file_name: str
1538
  @param file_name: file to be backed up
1539
  @rtype: str
1540
  @return: the path to the newly created backup
1541
  @raise errors.ProgrammerError: for invalid file names
1542

1543
  """
1544
  if not os.path.isfile(file_name):
1545
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
1546
                                file_name)
1547

    
1548
  prefix = ("%s.backup-%s." %
1549
            (os.path.basename(file_name), TimestampForFilename()))
1550
  dir_name = os.path.dirname(file_name)
1551

    
1552
  fsrc = open(file_name, 'rb')
1553
  try:
1554
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
1555
    fdst = os.fdopen(fd, 'wb')
1556
    try:
1557
      logging.debug("Backing up %s at %s", file_name, backup_name)
1558
      shutil.copyfileobj(fsrc, fdst)
1559
    finally:
1560
      fdst.close()
1561
  finally:
1562
    fsrc.close()
1563

    
1564
  return backup_name
1565

    
1566

    
1567
def ShellQuote(value):
1568
  """Quotes shell argument according to POSIX.
1569

1570
  @type value: str
1571
  @param value: the argument to be quoted
1572
  @rtype: str
1573
  @return: the quoted value
1574

1575
  """
1576
  if _re_shell_unquoted.match(value):
1577
    return value
1578
  else:
1579
    return "'%s'" % value.replace("'", "'\\''")
1580

    
1581

    
1582
def ShellQuoteArgs(args):
1583
  """Quotes a list of shell arguments.
1584

1585
  @type args: list
1586
  @param args: list of arguments to be quoted
1587
  @rtype: str
1588
  @return: the quoted arguments concatenated with spaces
1589

1590
  """
1591
  return ' '.join([ShellQuote(i) for i in args])
1592

    
1593

    
1594
class ShellWriter:
1595
  """Helper class to write scripts with indentation.
1596

1597
  """
1598
  INDENT_STR = "  "
1599

    
1600
  def __init__(self, fh):
1601
    """Initializes this class.
1602

1603
    """
1604
    self._fh = fh
1605
    self._indent = 0
1606

    
1607
  def IncIndent(self):
1608
    """Increase indentation level by 1.
1609

1610
    """
1611
    self._indent += 1
1612

    
1613
  def DecIndent(self):
1614
    """Decrease indentation level by 1.
1615

1616
    """
1617
    assert self._indent > 0
1618
    self._indent -= 1
1619

    
1620
  def Write(self, txt, *args):
1621
    """Write line to output file.
1622

1623
    """
1624
    assert self._indent >= 0
1625

    
1626
    self._fh.write(self._indent * self.INDENT_STR)
1627

    
1628
    if args:
1629
      self._fh.write(txt % args)
1630
    else:
1631
      self._fh.write(txt)
1632

    
1633
    self._fh.write("\n")
1634

    
1635

    
1636
def ListVisibleFiles(path):
1637
  """Returns a list of visible files in a directory.
1638

1639
  @type path: str
1640
  @param path: the directory to enumerate
1641
  @rtype: list
1642
  @return: the list of all files not starting with a dot
1643
  @raise ProgrammerError: if L{path} is not an absolue and normalized path
1644

1645
  """
1646
  if not IsNormAbsPath(path):
1647
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
1648
                                 " absolute/normalized: '%s'" % path)
1649
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1650
  return files
1651

    
1652

    
1653
def GetHomeDir(user, default=None):
1654
  """Try to get the homedir of the given user.
1655

1656
  The user can be passed either as a string (denoting the name) or as
1657
  an integer (denoting the user id). If the user is not found, the
1658
  'default' argument is returned, which defaults to None.
1659

1660
  """
1661
  try:
1662
    if isinstance(user, basestring):
1663
      result = pwd.getpwnam(user)
1664
    elif isinstance(user, (int, long)):
1665
      result = pwd.getpwuid(user)
1666
    else:
1667
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1668
                                   type(user))
1669
  except KeyError:
1670
    return default
1671
  return result.pw_dir
1672

    
1673

    
1674
def NewUUID():
1675
  """Returns a random UUID.
1676

1677
  @note: This is a Linux-specific method as it uses the /proc
1678
      filesystem.
1679
  @rtype: str
1680

1681
  """
1682
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1683

    
1684

    
1685
def GenerateSecret(numbytes=20):
1686
  """Generates a random secret.
1687

1688
  This will generate a pseudo-random secret returning an hex string
1689
  (so that it can be used where an ASCII string is needed).
1690

1691
  @param numbytes: the number of bytes which will be represented by the returned
1692
      string (defaulting to 20, the length of a SHA1 hash)
1693
  @rtype: str
1694
  @return: an hex representation of the pseudo-random sequence
1695

1696
  """
1697
  return os.urandom(numbytes).encode('hex')
1698

    
1699

    
1700
def EnsureDirs(dirs):
1701
  """Make required directories, if they don't exist.
1702

1703
  @param dirs: list of tuples (dir_name, dir_mode)
1704
  @type dirs: list of (string, integer)
1705

1706
  """
1707
  for dir_name, dir_mode in dirs:
1708
    try:
1709
      os.mkdir(dir_name, dir_mode)
1710
    except EnvironmentError, err:
1711
      if err.errno != errno.EEXIST:
1712
        raise errors.GenericError("Cannot create needed directory"
1713
                                  " '%s': %s" % (dir_name, err))
1714
    try:
1715
      os.chmod(dir_name, dir_mode)
1716
    except EnvironmentError, err:
1717
      raise errors.GenericError("Cannot change directory permissions on"
1718
                                " '%s': %s" % (dir_name, err))
1719
    if not os.path.isdir(dir_name):
1720
      raise errors.GenericError("%s is not a directory" % dir_name)
1721

    
1722

    
1723
def ReadFile(file_name, size=-1):
1724
  """Reads a file.
1725

1726
  @type size: int
1727
  @param size: Read at most size bytes (if negative, entire file)
1728
  @rtype: str
1729
  @return: the (possibly partial) content of the file
1730

1731
  """
1732
  f = open(file_name, "r")
1733
  try:
1734
    return f.read(size)
1735
  finally:
1736
    f.close()
1737

    
1738

    
1739
def WriteFile(file_name, fn=None, data=None,
1740
              mode=None, uid=-1, gid=-1,
1741
              atime=None, mtime=None, close=True,
1742
              dry_run=False, backup=False,
1743
              prewrite=None, postwrite=None):
1744
  """(Over)write a file atomically.
1745

1746
  The file_name and either fn (a function taking one argument, the
1747
  file descriptor, and which should write the data to it) or data (the
1748
  contents of the file) must be passed. The other arguments are
1749
  optional and allow setting the file mode, owner and group, and the
1750
  mtime/atime of the file.
1751

1752
  If the function doesn't raise an exception, it has succeeded and the
1753
  target file has the new contents. If the function has raised an
1754
  exception, an existing target file should be unmodified and the
1755
  temporary file should be removed.
1756

1757
  @type file_name: str
1758
  @param file_name: the target filename
1759
  @type fn: callable
1760
  @param fn: content writing function, called with
1761
      file descriptor as parameter
1762
  @type data: str
1763
  @param data: contents of the file
1764
  @type mode: int
1765
  @param mode: file mode
1766
  @type uid: int
1767
  @param uid: the owner of the file
1768
  @type gid: int
1769
  @param gid: the group of the file
1770
  @type atime: int
1771
  @param atime: a custom access time to be set on the file
1772
  @type mtime: int
1773
  @param mtime: a custom modification time to be set on the file
1774
  @type close: boolean
1775
  @param close: whether to close file after writing it
1776
  @type prewrite: callable
1777
  @param prewrite: function to be called before writing content
1778
  @type postwrite: callable
1779
  @param postwrite: function to be called after writing content
1780

1781
  @rtype: None or int
1782
  @return: None if the 'close' parameter evaluates to True,
1783
      otherwise the file descriptor
1784

1785
  @raise errors.ProgrammerError: if any of the arguments are not valid
1786

1787
  """
1788
  if not os.path.isabs(file_name):
1789
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1790
                                 " absolute: '%s'" % file_name)
1791

    
1792
  if [fn, data].count(None) != 1:
1793
    raise errors.ProgrammerError("fn or data required")
1794

    
1795
  if [atime, mtime].count(None) == 1:
1796
    raise errors.ProgrammerError("Both atime and mtime must be either"
1797
                                 " set or None")
1798

    
1799
  if backup and not dry_run and os.path.isfile(file_name):
1800
    CreateBackup(file_name)
1801

    
1802
  dir_name, base_name = os.path.split(file_name)
1803
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1804
  do_remove = True
1805
  # here we need to make sure we remove the temp file, if any error
1806
  # leaves it in place
1807
  try:
1808
    if uid != -1 or gid != -1:
1809
      os.chown(new_name, uid, gid)
1810
    if mode:
1811
      os.chmod(new_name, mode)
1812
    if callable(prewrite):
1813
      prewrite(fd)
1814
    if data is not None:
1815
      os.write(fd, data)
1816
    else:
1817
      fn(fd)
1818
    if callable(postwrite):
1819
      postwrite(fd)
1820
    os.fsync(fd)
1821
    if atime is not None and mtime is not None:
1822
      os.utime(new_name, (atime, mtime))
1823
    if not dry_run:
1824
      os.rename(new_name, file_name)
1825
      do_remove = False
1826
  finally:
1827
    if close:
1828
      os.close(fd)
1829
      result = None
1830
    else:
1831
      result = fd
1832
    if do_remove:
1833
      RemoveFile(new_name)
1834

    
1835
  return result
1836

    
1837

    
1838
def ReadOneLineFile(file_name, strict=False):
1839
  """Return the first non-empty line from a file.
1840

1841
  @type strict: boolean
1842
  @param strict: if True, abort if the file has more than one
1843
      non-empty line
1844

1845
  """
1846
  file_lines = ReadFile(file_name).splitlines()
1847
  full_lines = filter(bool, file_lines)
1848
  if not file_lines or not full_lines:
1849
    raise errors.GenericError("No data in one-liner file %s" % file_name)
1850
  elif strict and len(full_lines) > 1:
1851
    raise errors.GenericError("Too many lines in one-liner file %s" %
1852
                              file_name)
1853
  return full_lines[0]
1854

    
1855

    
1856
def FirstFree(seq, base=0):
1857
  """Returns the first non-existing integer from seq.
1858

1859
  The seq argument should be a sorted list of positive integers. The
1860
  first time the index of an element is smaller than the element
1861
  value, the index will be returned.
1862

1863
  The base argument is used to start at a different offset,
1864
  i.e. C{[3, 4, 6]} with I{offset=3} will return 5.
1865

1866
  Example: C{[0, 1, 3]} will return I{2}.
1867

1868
  @type seq: sequence
1869
  @param seq: the sequence to be analyzed.
1870
  @type base: int
1871
  @param base: use this value as the base index of the sequence
1872
  @rtype: int
1873
  @return: the first non-used index in the sequence
1874

1875
  """
1876
  for idx, elem in enumerate(seq):
1877
    assert elem >= base, "Passed element is higher than base offset"
1878
    if elem > idx + base:
1879
      # idx is not used
1880
      return idx + base
1881
  return None
1882

    
1883

    
1884
def SingleWaitForFdCondition(fdobj, event, timeout):
1885
  """Waits for a condition to occur on the socket.
1886

1887
  Immediately returns at the first interruption.
1888

1889
  @type fdobj: integer or object supporting a fileno() method
1890
  @param fdobj: entity to wait for events on
1891
  @type event: integer
1892
  @param event: ORed condition (see select module)
1893
  @type timeout: float or None
1894
  @param timeout: Timeout in seconds
1895
  @rtype: int or None
1896
  @return: None for timeout, otherwise occured conditions
1897

1898
  """
1899
  check = (event | select.POLLPRI |
1900
           select.POLLNVAL | select.POLLHUP | select.POLLERR)
1901

    
1902
  if timeout is not None:
1903
    # Poller object expects milliseconds
1904
    timeout *= 1000
1905

    
1906
  poller = select.poll()
1907
  poller.register(fdobj, event)
1908
  try:
1909
    # TODO: If the main thread receives a signal and we have no timeout, we
1910
    # could wait forever. This should check a global "quit" flag or something
1911
    # every so often.
1912
    io_events = poller.poll(timeout)
1913
  except select.error, err:
1914
    if err[0] != errno.EINTR:
1915
      raise
1916
    io_events = []
1917
  if io_events and io_events[0][1] & check:
1918
    return io_events[0][1]
1919
  else:
1920
    return None
1921

    
1922

    
1923
class FdConditionWaiterHelper(object):
1924
  """Retry helper for WaitForFdCondition.
1925

1926
  This class contains the retried and wait functions that make sure
1927
  WaitForFdCondition can continue waiting until the timeout is actually
1928
  expired.
1929

1930
  """
1931

    
1932
  def __init__(self, timeout):
1933
    self.timeout = timeout
1934

    
1935
  def Poll(self, fdobj, event):
1936
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
1937
    if result is None:
1938
      raise RetryAgain()
1939
    else:
1940
      return result
1941

    
1942
  def UpdateTimeout(self, timeout):
1943
    self.timeout = timeout
1944

    
1945

    
1946
def WaitForFdCondition(fdobj, event, timeout):
1947
  """Waits for a condition to occur on the socket.
1948

1949
  Retries until the timeout is expired, even if interrupted.
1950

1951
  @type fdobj: integer or object supporting a fileno() method
1952
  @param fdobj: entity to wait for events on
1953
  @type event: integer
1954
  @param event: ORed condition (see select module)
1955
  @type timeout: float or None
1956
  @param timeout: Timeout in seconds
1957
  @rtype: int or None
1958
  @return: None for timeout, otherwise occured conditions
1959

1960
  """
1961
  if timeout is not None:
1962
    retrywaiter = FdConditionWaiterHelper(timeout)
1963
    try:
1964
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
1965
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
1966
    except RetryTimeout:
1967
      result = None
1968
  else:
1969
    result = None
1970
    while result is None:
1971
      result = SingleWaitForFdCondition(fdobj, event, timeout)
1972
  return result
1973

    
1974

    
1975
def UniqueSequence(seq):
1976
  """Returns a list with unique elements.
1977

1978
  Element order is preserved.
1979

1980
  @type seq: sequence
1981
  @param seq: the sequence with the source elements
1982
  @rtype: list
1983
  @return: list of unique elements from seq
1984

1985
  """
1986
  seen = set()
1987
  return [i for i in seq if i not in seen and not seen.add(i)]
1988

    
1989

    
1990
def NormalizeAndValidateMac(mac):
1991
  """Normalizes and check if a MAC address is valid.
1992

1993
  Checks whether the supplied MAC address is formally correct, only
1994
  accepts colon separated format. Normalize it to all lower.
1995

1996
  @type mac: str
1997
  @param mac: the MAC to be validated
1998
  @rtype: str
1999
  @return: returns the normalized and validated MAC.
2000

2001
  @raise errors.OpPrereqError: If the MAC isn't valid
2002

2003
  """
2004
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
2005
  if not mac_check.match(mac):
2006
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
2007
                               mac, errors.ECODE_INVAL)
2008

    
2009
  return mac.lower()
2010

    
2011

    
2012
def TestDelay(duration):
2013
  """Sleep for a fixed amount of time.
2014

2015
  @type duration: float
2016
  @param duration: the sleep duration
2017
  @rtype: boolean
2018
  @return: False for negative value, True otherwise
2019

2020
  """
2021
  if duration < 0:
2022
    return False, "Invalid sleep duration"
2023
  time.sleep(duration)
2024
  return True, None
2025

    
2026

    
2027
def _CloseFDNoErr(fd, retries=5):
2028
  """Close a file descriptor ignoring errors.
2029

2030
  @type fd: int
2031
  @param fd: the file descriptor
2032
  @type retries: int
2033
  @param retries: how many retries to make, in case we get any
2034
      other error than EBADF
2035

2036
  """
2037
  try:
2038
    os.close(fd)
2039
  except OSError, err:
2040
    if err.errno != errno.EBADF:
2041
      if retries > 0:
2042
        _CloseFDNoErr(fd, retries - 1)
2043
    # else either it's closed already or we're out of retries, so we
2044
    # ignore this and go on
2045

    
2046

    
2047
def CloseFDs(noclose_fds=None):
2048
  """Close file descriptors.
2049

2050
  This closes all file descriptors above 2 (i.e. except
2051
  stdin/out/err).
2052

2053
  @type noclose_fds: list or None
2054
  @param noclose_fds: if given, it denotes a list of file descriptor
2055
      that should not be closed
2056

2057
  """
2058
  # Default maximum for the number of available file descriptors.
2059
  if 'SC_OPEN_MAX' in os.sysconf_names:
2060
    try:
2061
      MAXFD = os.sysconf('SC_OPEN_MAX')
2062
      if MAXFD < 0:
2063
        MAXFD = 1024
2064
    except OSError:
2065
      MAXFD = 1024
2066
  else:
2067
    MAXFD = 1024
2068
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
2069
  if (maxfd == resource.RLIM_INFINITY):
2070
    maxfd = MAXFD
2071

    
2072
  # Iterate through and close all file descriptors (except the standard ones)
2073
  for fd in range(3, maxfd):
2074
    if noclose_fds and fd in noclose_fds:
2075
      continue
2076
    _CloseFDNoErr(fd)
2077

    
2078

    
2079
def Mlockall(_ctypes=ctypes):
2080
  """Lock current process' virtual address space into RAM.
2081

2082
  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
2083
  see mlock(2) for more details. This function requires ctypes module.
2084

2085
  @raises errors.NoCtypesError: if ctypes module is not found
2086

2087
  """
2088
  if _ctypes is None:
2089
    raise errors.NoCtypesError()
2090

    
2091
  libc = _ctypes.cdll.LoadLibrary("libc.so.6")
2092
  if libc is None:
2093
    logging.error("Cannot set memory lock, ctypes cannot load libc")
2094
    return
2095

    
2096
  # Some older version of the ctypes module don't have built-in functionality
2097
  # to access the errno global variable, where function error codes are stored.
2098
  # By declaring this variable as a pointer to an integer we can then access
2099
  # its value correctly, should the mlockall call fail, in order to see what
2100
  # the actual error code was.
2101
  # pylint: disable-msg=W0212
2102
  libc.__errno_location.restype = _ctypes.POINTER(_ctypes.c_int)
2103

    
2104
  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
2105
    # pylint: disable-msg=W0212
2106
    logging.error("Cannot set memory lock: %s",
2107
                  os.strerror(libc.__errno_location().contents.value))
2108
    return
2109

    
2110
  logging.debug("Memory lock set")
2111

    
2112

    
2113
def Daemonize(logfile, run_uid, run_gid):
2114
  """Daemonize the current process.
2115

2116
  This detaches the current process from the controlling terminal and
2117
  runs it in the background as a daemon.
2118

2119
  @type logfile: str
2120
  @param logfile: the logfile to which we should redirect stdout/stderr
2121
  @type run_uid: int
2122
  @param run_uid: Run the child under this uid
2123
  @type run_gid: int
2124
  @param run_gid: Run the child under this gid
2125
  @rtype: int
2126
  @return: the value zero
2127

2128
  """
2129
  # pylint: disable-msg=W0212
2130
  # yes, we really want os._exit
2131
  UMASK = 077
2132
  WORKDIR = "/"
2133

    
2134
  # this might fail
2135
  pid = os.fork()
2136
  if (pid == 0):  # The first child.
2137
    os.setsid()
2138
    # FIXME: When removing again and moving to start-stop-daemon privilege drop
2139
    #        make sure to check for config permission and bail out when invoked
2140
    #        with wrong user.
2141
    os.setgid(run_gid)
2142
    os.setuid(run_uid)
2143
    # this might fail
2144
    pid = os.fork() # Fork a second child.
2145
    if (pid == 0):  # The second child.
2146
      os.chdir(WORKDIR)
2147
      os.umask(UMASK)
2148
    else:
2149
      # exit() or _exit()?  See below.
2150
      os._exit(0) # Exit parent (the first child) of the second child.
2151
  else:
2152
    os._exit(0) # Exit parent of the first child.
2153

    
2154
  for fd in range(3):
2155
    _CloseFDNoErr(fd)
2156
  i = os.open("/dev/null", os.O_RDONLY) # stdin
2157
  assert i == 0, "Can't close/reopen stdin"
2158
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
2159
  assert i == 1, "Can't close/reopen stdout"
2160
  # Duplicate standard output to standard error.
2161
  os.dup2(1, 2)
2162
  return 0
2163

    
2164

    
2165
def DaemonPidFileName(name):
2166
  """Compute a ganeti pid file absolute path
2167

2168
  @type name: str
2169
  @param name: the daemon name
2170
  @rtype: str
2171
  @return: the full path to the pidfile corresponding to the given
2172
      daemon name
2173

2174
  """
2175
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
2176

    
2177

    
2178
def EnsureDaemon(name):
2179
  """Check for and start daemon if not alive.
2180

2181
  """
2182
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
2183
  if result.failed:
2184
    logging.error("Can't start daemon '%s', failure %s, output: %s",
2185
                  name, result.fail_reason, result.output)
2186
    return False
2187

    
2188
  return True
2189

    
2190

    
2191
def StopDaemon(name):
2192
  """Stop daemon
2193

2194
  """
2195
  result = RunCmd([constants.DAEMON_UTIL, "stop", name])
2196
  if result.failed:
2197
    logging.error("Can't stop daemon '%s', failure %s, output: %s",
2198
                  name, result.fail_reason, result.output)
2199
    return False
2200

    
2201
  return True
2202

    
2203

    
2204
def WritePidFile(name):
2205
  """Write the current process pidfile.
2206

2207
  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}
2208

2209
  @type name: str
2210
  @param name: the daemon name to use
2211
  @raise errors.GenericError: if the pid file already exists and
2212
      points to a live process
2213

2214
  """
2215
  pid = os.getpid()
2216
  pidfilename = DaemonPidFileName(name)
2217
  if IsProcessAlive(ReadPidFile(pidfilename)):
2218
    raise errors.GenericError("%s contains a live process" % pidfilename)
2219

    
2220
  WriteFile(pidfilename, data="%d\n" % pid)
2221

    
2222

    
2223
def RemovePidFile(name):
2224
  """Remove the current process pidfile.
2225

2226
  Any errors are ignored.
2227

2228
  @type name: str
2229
  @param name: the daemon name used to derive the pidfile name
2230

2231
  """
2232
  pidfilename = DaemonPidFileName(name)
2233
  # TODO: we could check here that the file contains our pid
2234
  try:
2235
    RemoveFile(pidfilename)
2236
  except: # pylint: disable-msg=W0702
2237
    pass
2238

    
2239

    
2240
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
2241
                waitpid=False):
2242
  """Kill a process given by its pid.
2243

2244
  @type pid: int
2245
  @param pid: The PID to terminate.
2246
  @type signal_: int
2247
  @param signal_: The signal to send, by default SIGTERM
2248
  @type timeout: int
2249
  @param timeout: The timeout after which, if the process is still alive,
2250
                  a SIGKILL will be sent. If not positive, no such checking
2251
                  will be done
2252
  @type waitpid: boolean
2253
  @param waitpid: If true, we should waitpid on this process after
2254
      sending signals, since it's our own child and otherwise it
2255
      would remain as zombie
2256

2257
  """
2258
  def _helper(pid, signal_, wait):
2259
    """Simple helper to encapsulate the kill/waitpid sequence"""
2260
    if IgnoreProcessNotFound(os.kill, pid, signal_) and wait:
2261
      try:
2262
        os.waitpid(pid, os.WNOHANG)
2263
      except OSError:
2264
        pass
2265

    
2266
  if pid <= 0:
2267
    # kill with pid=0 == suicide
2268
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
2269

    
2270
  if not IsProcessAlive(pid):
2271
    return
2272

    
2273
  _helper(pid, signal_, waitpid)
2274

    
2275
  if timeout <= 0:
2276
    return
2277

    
2278
  def _CheckProcess():
2279
    if not IsProcessAlive(pid):
2280
      return
2281

    
2282
    try:
2283
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
2284
    except OSError:
2285
      raise RetryAgain()
2286

    
2287
    if result_pid > 0:
2288
      return
2289

    
2290
    raise RetryAgain()
2291

    
2292
  try:
2293
    # Wait up to $timeout seconds
2294
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
2295
  except RetryTimeout:
2296
    pass
2297

    
2298
  if IsProcessAlive(pid):
2299
    # Kill process if it's still alive
2300
    _helper(pid, signal.SIGKILL, waitpid)
2301

    
2302

    
2303
def FindFile(name, search_path, test=os.path.exists):
2304
  """Look for a filesystem object in a given path.
2305

2306
  This is an abstract method to search for filesystem object (files,
2307
  dirs) under a given search path.
2308

2309
  @type name: str
2310
  @param name: the name to look for
2311
  @type search_path: str
2312
  @param search_path: location to start at
2313
  @type test: callable
2314
  @param test: a function taking one argument that should return True
2315
      if the a given object is valid; the default value is
2316
      os.path.exists, causing only existing files to be returned
2317
  @rtype: str or None
2318
  @return: full path to the object if found, None otherwise
2319

2320
  """
2321
  # validate the filename mask
2322
  if constants.EXT_PLUGIN_MASK.match(name) is None:
2323
    logging.critical("Invalid value passed for external script name: '%s'",
2324
                     name)
2325
    return None
2326

    
2327
  for dir_name in search_path:
2328
    # FIXME: investigate switch to PathJoin
2329
    item_name = os.path.sep.join([dir_name, name])
2330
    # check the user test and that we're indeed resolving to the given
2331
    # basename
2332
    if test(item_name) and os.path.basename(item_name) == name:
2333
      return item_name
2334
  return None
2335

    
2336

    
2337
def CheckVolumeGroupSize(vglist, vgname, minsize):
2338
  """Checks if the volume group list is valid.
2339

2340
  The function will check if a given volume group is in the list of
2341
  volume groups and has a minimum size.
2342

2343
  @type vglist: dict
2344
  @param vglist: dictionary of volume group names and their size
2345
  @type vgname: str
2346
  @param vgname: the volume group we should check
2347
  @type minsize: int
2348
  @param minsize: the minimum size we accept
2349
  @rtype: None or str
2350
  @return: None for success, otherwise the error message
2351

2352
  """
2353
  vgsize = vglist.get(vgname, None)
2354
  if vgsize is None:
2355
    return "volume group '%s' missing" % vgname
2356
  elif vgsize < minsize:
2357
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
2358
            (vgname, minsize, vgsize))
2359
  return None
2360

    
2361

    
2362
def SplitTime(value):
2363
  """Splits time as floating point number into a tuple.
2364

2365
  @param value: Time in seconds
2366
  @type value: int or float
2367
  @return: Tuple containing (seconds, microseconds)
2368

2369
  """
2370
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
2371

    
2372
  assert 0 <= seconds, \
2373
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2374
  assert 0 <= microseconds <= 999999, \
2375
    "Microseconds must be 0-999999, but are %s" % microseconds
2376

    
2377
  return (int(seconds), int(microseconds))
2378

    
2379

    
2380
def MergeTime(timetuple):
2381
  """Merges a tuple into time as a floating point number.
2382

2383
  @param timetuple: Time as tuple, (seconds, microseconds)
2384
  @type timetuple: tuple
2385
  @return: Time as a floating point number expressed in seconds
2386

2387
  """
2388
  (seconds, microseconds) = timetuple
2389

    
2390
  assert 0 <= seconds, \
2391
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2392
  assert 0 <= microseconds <= 999999, \
2393
    "Microseconds must be 0-999999, but are %s" % microseconds
2394

    
2395
  return float(seconds) + (float(microseconds) * 0.000001)
2396

    
2397

    
2398
class LogFileHandler(logging.FileHandler):
2399
  """Log handler that doesn't fallback to stderr.
2400

2401
  When an error occurs while writing on the logfile, logging.FileHandler tries
2402
  to log on stderr. This doesn't work in ganeti since stderr is redirected to
2403
  the logfile. This class avoids failures reporting errors to /dev/console.
2404

2405
  """
2406
  def __init__(self, filename, mode="a", encoding=None):
2407
    """Open the specified file and use it as the stream for logging.
2408

2409
    Also open /dev/console to report errors while logging.
2410

2411
    """
2412
    logging.FileHandler.__init__(self, filename, mode, encoding)
2413
    self.console = open(constants.DEV_CONSOLE, "a")
2414

    
2415
  def handleError(self, record): # pylint: disable-msg=C0103
2416
    """Handle errors which occur during an emit() call.
2417

2418
    Try to handle errors with FileHandler method, if it fails write to
2419
    /dev/console.
2420

2421
    """
2422
    try:
2423
      logging.FileHandler.handleError(self, record)
2424
    except Exception: # pylint: disable-msg=W0703
2425
      try:
2426
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
2427
      except Exception: # pylint: disable-msg=W0703
2428
        # Log handler tried everything it could, now just give up
2429
        pass
2430

    
2431

    
2432
def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
2433
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
2434
                 console_logging=False):
2435
  """Configures the logging module.
2436

2437
  @type logfile: str
2438
  @param logfile: the filename to which we should log
2439
  @type debug: integer
2440
  @param debug: if greater than zero, enable debug messages, otherwise
2441
      only those at C{INFO} and above level
2442
  @type stderr_logging: boolean
2443
  @param stderr_logging: whether we should also log to the standard error
2444
  @type program: str
2445
  @param program: the name under which we should log messages
2446
  @type multithreaded: boolean
2447
  @param multithreaded: if True, will add the thread name to the log file
2448
  @type syslog: string
2449
  @param syslog: one of 'no', 'yes', 'only':
2450
      - if no, syslog is not used
2451
      - if yes, syslog is used (in addition to file-logging)
2452
      - if only, only syslog is used
2453
  @type console_logging: boolean
2454
  @param console_logging: if True, will use a FileHandler which falls back to
2455
      the system console if logging fails
2456
  @raise EnvironmentError: if we can't open the log file and
2457
      syslog/stderr logging is disabled
2458

2459
  """
2460
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
2461
  sft = program + "[%(process)d]:"
2462
  if multithreaded:
2463
    fmt += "/%(threadName)s"
2464
    sft += " (%(threadName)s)"
2465
  if debug:
2466
    fmt += " %(module)s:%(lineno)s"
2467
    # no debug info for syslog loggers
2468
  fmt += " %(levelname)s %(message)s"
2469
  # yes, we do want the textual level, as remote syslog will probably
2470
  # lose the error level, and it's easier to grep for it
2471
  sft += " %(levelname)s %(message)s"
2472
  formatter = logging.Formatter(fmt)
2473
  sys_fmt = logging.Formatter(sft)
2474

    
2475
  root_logger = logging.getLogger("")
2476
  root_logger.setLevel(logging.NOTSET)
2477

    
2478
  # Remove all previously setup handlers
2479
  for handler in root_logger.handlers:
2480
    handler.close()
2481
    root_logger.removeHandler(handler)
2482

    
2483
  if stderr_logging:
2484
    stderr_handler = logging.StreamHandler()
2485
    stderr_handler.setFormatter(formatter)
2486
    if debug:
2487
      stderr_handler.setLevel(logging.NOTSET)
2488
    else:
2489
      stderr_handler.setLevel(logging.CRITICAL)
2490
    root_logger.addHandler(stderr_handler)
2491

    
2492
  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
2493
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
2494
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
2495
                                                    facility)
2496
    syslog_handler.setFormatter(sys_fmt)
2497
    # Never enable debug over syslog
2498
    syslog_handler.setLevel(logging.INFO)
2499
    root_logger.addHandler(syslog_handler)
2500

    
2501
  if syslog != constants.SYSLOG_ONLY:
2502
    # this can fail, if the logging directories are not setup or we have
2503
    # a permisssion problem; in this case, it's best to log but ignore
2504
    # the error if stderr_logging is True, and if false we re-raise the
2505
    # exception since otherwise we could run but without any logs at all
2506
    try:
2507
      if console_logging:
2508
        logfile_handler = LogFileHandler(logfile)
2509
      else:
2510
        logfile_handler = logging.FileHandler(logfile)
2511
      logfile_handler.setFormatter(formatter)
2512
      if debug:
2513
        logfile_handler.setLevel(logging.DEBUG)
2514
      else:
2515
        logfile_handler.setLevel(logging.INFO)
2516
      root_logger.addHandler(logfile_handler)
2517
    except EnvironmentError:
2518
      if stderr_logging or syslog == constants.SYSLOG_YES:
2519
        logging.exception("Failed to enable logging to file '%s'", logfile)
2520
      else:
2521
        # we need to re-raise the exception
2522
        raise
2523

    
2524

    
2525
def IsNormAbsPath(path):
2526
  """Check whether a path is absolute and also normalized
2527

2528
  This avoids things like /dir/../../other/path to be valid.
2529

2530
  """
2531
  return os.path.normpath(path) == path and os.path.isabs(path)
2532

    
2533

    
2534
def PathJoin(*args):
2535
  """Safe-join a list of path components.
2536

2537
  Requirements:
2538
      - the first argument must be an absolute path
2539
      - no component in the path must have backtracking (e.g. /../),
2540
        since we check for normalization at the end
2541

2542
  @param args: the path components to be joined
2543
  @raise ValueError: for invalid paths
2544

2545
  """
2546
  # ensure we're having at least one path passed in
2547
  assert args
2548
  # ensure the first component is an absolute and normalized path name
2549
  root = args[0]
2550
  if not IsNormAbsPath(root):
2551
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
2552
  result = os.path.join(*args)
2553
  # ensure that the whole path is normalized
2554
  if not IsNormAbsPath(result):
2555
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
2556
  # check that we're still under the original prefix
2557
  prefix = os.path.commonprefix([root, result])
2558
  if prefix != root:
2559
    raise ValueError("Error: path joining resulted in different prefix"
2560
                     " (%s != %s)" % (prefix, root))
2561
  return result
2562

    
2563

    
2564
def TailFile(fname, lines=20):
2565
  """Return the last lines from a file.
2566

2567
  @note: this function will only read and parse the last 4KB of
2568
      the file; if the lines are very long, it could be that less
2569
      than the requested number of lines are returned
2570

2571
  @param fname: the file name
2572
  @type lines: int
2573
  @param lines: the (maximum) number of lines to return
2574

2575
  """
2576
  fd = open(fname, "r")
2577
  try:
2578
    fd.seek(0, 2)
2579
    pos = fd.tell()
2580
    pos = max(0, pos-4096)
2581
    fd.seek(pos, 0)
2582
    raw_data = fd.read()
2583
  finally:
2584
    fd.close()
2585

    
2586
  rows = raw_data.splitlines()
2587
  return rows[-lines:]
2588

    
2589

    
2590
def FormatTimestampWithTZ(secs):
2591
  """Formats a Unix timestamp with the local timezone.
2592

2593
  """
2594
  return time.strftime("%F %T %Z", time.gmtime(secs))
2595

    
2596

    
2597
def _ParseAsn1Generalizedtime(value):
2598
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
2599

2600
  @type value: string
2601
  @param value: ASN1 GENERALIZEDTIME timestamp
2602

2603
  """
2604
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
2605
  if m:
2606
    # We have an offset
2607
    asn1time = m.group(1)
2608
    hours = int(m.group(2))
2609
    minutes = int(m.group(3))
2610
    utcoffset = (60 * hours) + minutes
2611
  else:
2612
    if not value.endswith("Z"):
2613
      raise ValueError("Missing timezone")
2614
    asn1time = value[:-1]
2615
    utcoffset = 0
2616

    
2617
  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
2618

    
2619
  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)
2620

    
2621
  return calendar.timegm(tt.utctimetuple())
2622

    
2623

    
2624
def GetX509CertValidity(cert):
2625
  """Returns the validity period of the certificate.
2626

2627
  @type cert: OpenSSL.crypto.X509
2628
  @param cert: X509 certificate object
2629

2630
  """
2631
  # The get_notBefore and get_notAfter functions are only supported in
2632
  # pyOpenSSL 0.7 and above.
2633
  try:
2634
    get_notbefore_fn = cert.get_notBefore
2635
  except AttributeError:
2636
    not_before = None
2637
  else:
2638
    not_before_asn1 = get_notbefore_fn()
2639

    
2640
    if not_before_asn1 is None:
2641
      not_before = None
2642
    else:
2643
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)
2644

    
2645
  try:
2646
    get_notafter_fn = cert.get_notAfter
2647
  except AttributeError:
2648
    not_after = None
2649
  else:
2650
    not_after_asn1 = get_notafter_fn()
2651

    
2652
    if not_after_asn1 is None:
2653
      not_after = None
2654
    else:
2655
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)
2656

    
2657
  return (not_before, not_after)
2658

    
2659

    
2660
def _VerifyCertificateInner(expired, not_before, not_after, now,
2661
                            warn_days, error_days):
2662
  """Verifies certificate validity.
2663

2664
  @type expired: bool
2665
  @param expired: Whether pyOpenSSL considers the certificate as expired
2666
  @type not_before: number or None
2667
  @param not_before: Unix timestamp before which certificate is not valid
2668
  @type not_after: number or None
2669
  @param not_after: Unix timestamp after which certificate is invalid
2670
  @type now: number
2671
  @param now: Current time as Unix timestamp
2672
  @type warn_days: number or None
2673
  @param warn_days: How many days before expiration a warning should be reported
2674
  @type error_days: number or None
2675
  @param error_days: How many days before expiration an error should be reported
2676

2677
  """
2678
  if expired:
2679
    msg = "Certificate is expired"
2680

    
2681
    if not_before is not None and not_after is not None:
2682
      msg += (" (valid from %s to %s)" %
2683
              (FormatTimestampWithTZ(not_before),
2684
               FormatTimestampWithTZ(not_after)))
2685
    elif not_before is not None:
2686
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
2687
    elif not_after is not None:
2688
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)
2689

    
2690
    return (CERT_ERROR, msg)
2691

    
2692
  elif not_before is not None and not_before > now:
2693
    return (CERT_WARNING,
2694
            "Certificate not yet valid (valid from %s)" %
2695
            FormatTimestampWithTZ(not_before))
2696

    
2697
  elif not_after is not None:
2698
    remaining_days = int((not_after - now) / (24 * 3600))
2699

    
2700
    msg = "Certificate expires in about %d days" % remaining_days
2701

    
2702
    if error_days is not None and remaining_days <= error_days:
2703
      return (CERT_ERROR, msg)
2704

    
2705
    if warn_days is not None and remaining_days <= warn_days:
2706
      return (CERT_WARNING, msg)
2707

    
2708
  return (None, None)
2709

    
2710

    
2711
def VerifyX509Certificate(cert, warn_days, error_days):
2712
  """Verifies a certificate for LUVerifyCluster.
2713

2714
  @type cert: OpenSSL.crypto.X509
2715
  @param cert: X509 certificate object
2716
  @type warn_days: number or None
2717
  @param warn_days: How many days before expiration a warning should be reported
2718
  @type error_days: number or None
2719
  @param error_days: How many days before expiration an error should be reported
2720

2721
  """
2722
  # Depending on the pyOpenSSL version, this can just return (None, None)
2723
  (not_before, not_after) = GetX509CertValidity(cert)
2724

    
2725
  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
2726
                                 time.time(), warn_days, error_days)
2727

    
2728

    
2729
def SignX509Certificate(cert, key, salt):
2730
  """Sign a X509 certificate.
2731

2732
  An RFC822-like signature header is added in front of the certificate.
2733

2734
  @type cert: OpenSSL.crypto.X509
2735
  @param cert: X509 certificate object
2736
  @type key: string
2737
  @param key: Key for HMAC
2738
  @type salt: string
2739
  @param salt: Salt for HMAC
2740
  @rtype: string
2741
  @return: Serialized and signed certificate in PEM format
2742

2743
  """
2744
  if not VALID_X509_SIGNATURE_SALT.match(salt):
2745
    raise errors.GenericError("Invalid salt: %r" % salt)
2746

    
2747
  # Dumping as PEM here ensures the certificate is in a sane format
2748
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2749

    
2750
  return ("%s: %s/%s\n\n%s" %
2751
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
2752
           Sha1Hmac(key, cert_pem, salt=salt),
2753
           cert_pem))
2754

    
2755

    
2756
def _ExtractX509CertificateSignature(cert_pem):
2757
  """Helper function to extract signature from X509 certificate.
2758

2759
  """
2760
  # Extract signature from original PEM data
2761
  for line in cert_pem.splitlines():
2762
    if line.startswith("---"):
2763
      break
2764

    
2765
    m = X509_SIGNATURE.match(line.strip())
2766
    if m:
2767
      return (m.group("salt"), m.group("sign"))
2768

    
2769
  raise errors.GenericError("X509 certificate signature is missing")
2770

    
2771

    
2772
def LoadSignedX509Certificate(cert_pem, key):
2773
  """Verifies a signed X509 certificate.
2774

2775
  @type cert_pem: string
2776
  @param cert_pem: Certificate in PEM format and with signature header
2777
  @type key: string
2778
  @param key: Key for HMAC
2779
  @rtype: tuple; (OpenSSL.crypto.X509, string)
2780
  @return: X509 certificate object and salt
2781

2782
  """
2783
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)
2784

    
2785
  # Load certificate
2786
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
2787

    
2788
  # Dump again to ensure it's in a sane format
2789
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2790

    
2791
  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
2792
    raise errors.GenericError("X509 certificate signature is invalid")
2793

    
2794
  return (cert, salt)
2795

    
2796

    
2797
def Sha1Hmac(key, text, salt=None):
2798
  """Calculates the HMAC-SHA1 digest of a text.
2799

2800
  HMAC is defined in RFC2104.
2801

2802
  @type key: string
2803
  @param key: Secret key
2804
  @type text: string
2805

2806
  """
2807
  if salt:
2808
    salted_text = salt + text
2809
  else:
2810
    salted_text = text
2811

    
2812
  return hmac.new(key, salted_text, compat.sha1).hexdigest()
2813

    
2814

    
2815
def VerifySha1Hmac(key, text, digest, salt=None):
2816
  """Verifies the HMAC-SHA1 digest of a text.
2817

2818
  HMAC is defined in RFC2104.
2819

2820
  @type key: string
2821
  @param key: Secret key
2822
  @type text: string
2823
  @type digest: string
2824
  @param digest: Expected digest
2825
  @rtype: bool
2826
  @return: Whether HMAC-SHA1 digest matches
2827

2828
  """
2829
  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()
2830

    
2831

    
2832
def SafeEncode(text):
2833
  """Return a 'safe' version of a source string.
2834

2835
  This function mangles the input string and returns a version that
2836
  should be safe to display/encode as ASCII. To this end, we first
2837
  convert it to ASCII using the 'backslashreplace' encoding which
2838
  should get rid of any non-ASCII chars, and then we process it
2839
  through a loop copied from the string repr sources in the python; we
2840
  don't use string_escape anymore since that escape single quotes and
2841
  backslashes too, and that is too much; and that escaping is not
2842
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
2843

2844
  @type text: str or unicode
2845
  @param text: input data
2846
  @rtype: str
2847
  @return: a safe version of text
2848

2849
  """
2850
  if isinstance(text, unicode):
2851
    # only if unicode; if str already, we handle it below
2852
    text = text.encode('ascii', 'backslashreplace')
2853
  resu = ""
2854
  for char in text:
2855
    c = ord(char)
2856
    if char  == '\t':
2857
      resu += r'\t'
2858
    elif char == '\n':
2859
      resu += r'\n'
2860
    elif char == '\r':
2861
      resu += r'\'r'
2862
    elif c < 32 or c >= 127: # non-printable
2863
      resu += "\\x%02x" % (c & 0xff)
2864
    else:
2865
      resu += char
2866
  return resu
2867

    
2868

    
2869
def UnescapeAndSplit(text, sep=","):
2870
  """Split and unescape a string based on a given separator.
2871

2872
  This function splits a string based on a separator where the
2873
  separator itself can be escape in order to be an element of the
2874
  elements. The escaping rules are (assuming coma being the
2875
  separator):
2876
    - a plain , separates the elements
2877
    - a sequence \\\\, (double backslash plus comma) is handled as a
2878
      backslash plus a separator comma
2879
    - a sequence \, (backslash plus comma) is handled as a
2880
      non-separator comma
2881

2882
  @type text: string
2883
  @param text: the string to split
2884
  @type sep: string
2885
  @param text: the separator
2886
  @rtype: string
2887
  @return: a list of strings
2888

2889
  """
2890
  # we split the list by sep (with no escaping at this stage)
2891
  slist = text.split(sep)
2892
  # next, we revisit the elements and if any of them ended with an odd
2893
  # number of backslashes, then we join it with the next
2894
  rlist = []
2895
  while slist:
2896
    e1 = slist.pop(0)
2897
    if e1.endswith("\\"):
2898
      num_b = len(e1) - len(e1.rstrip("\\"))
2899
      if num_b % 2 == 1:
2900
        e2 = slist.pop(0)
2901
        # here the backslashes remain (all), and will be reduced in
2902
        # the next step
2903
        rlist.append(e1 + sep + e2)
2904
        continue
2905
    rlist.append(e1)
2906
  # finally, replace backslash-something with something
2907
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
2908
  return rlist
2909

    
2910

    
2911
def CommaJoin(names):
2912
  """Nicely join a set of identifiers.
2913

2914
  @param names: set, list or tuple
2915
  @return: a string with the formatted results
2916

2917
  """
2918
  return ", ".join([str(val) for val in names])
2919

    
2920

    
2921
def BytesToMebibyte(value):
2922
  """Converts bytes to mebibytes.
2923

2924
  @type value: int
2925
  @param value: Value in bytes
2926
  @rtype: int
2927
  @return: Value in mebibytes
2928

2929
  """
2930
  return int(round(value / (1024.0 * 1024.0), 0))
2931

    
2932

    
2933
def CalculateDirectorySize(path):
2934
  """Calculates the size of a directory recursively.
2935

2936
  @type path: string
2937
  @param path: Path to directory
2938
  @rtype: int
2939
  @return: Size in mebibytes
2940

2941
  """
2942
  size = 0
2943

    
2944
  for (curpath, _, files) in os.walk(path):
2945
    for filename in files:
2946
      st = os.lstat(PathJoin(curpath, filename))
2947
      size += st.st_size
2948

    
2949
  return BytesToMebibyte(size)
2950

    
2951

    
2952
def GetMounts(filename=constants.PROC_MOUNTS):
2953
  """Returns the list of mounted filesystems.
2954

2955
  This function is Linux-specific.
2956

2957
  @param filename: path of mounts file (/proc/mounts by default)
2958
  @rtype: list of tuples
2959
  @return: list of mount entries (device, mountpoint, fstype, options)
2960

2961
  """
2962
  # TODO(iustin): investigate non-Linux options (e.g. via mount output)
2963
  data = []
2964
  mountlines = ReadFile(filename).splitlines()
2965
  for line in mountlines:
2966
    device, mountpoint, fstype, options, _ = line.split(None, 4)
2967
    data.append((device, mountpoint, fstype, options))
2968

    
2969
  return data
2970

    
2971

    
2972
def GetFilesystemStats(path):
2973
  """Returns the total and free space on a filesystem.
2974

2975
  @type path: string
2976
  @param path: Path on filesystem to be examined
2977
  @rtype: int
2978
  @return: tuple of (Total space, Free space) in mebibytes
2979

2980
  """
2981
  st = os.statvfs(path)
2982

    
2983
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
2984
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
2985
  return (tsize, fsize)
2986

    
2987

    
2988
def RunInSeparateProcess(fn, *args):
2989
  """Runs a function in a separate process.
2990

2991
  Note: Only boolean return values are supported.
2992

2993
  @type fn: callable
2994
  @param fn: Function to be called
2995
  @rtype: bool
2996
  @return: Function's result
2997

2998
  """
2999
  pid = os.fork()
3000
  if pid == 0:
3001
    # Child process
3002
    try:
3003
      # In case the function uses temporary files
3004
      ResetTempfileModule()
3005

    
3006
      # Call function
3007
      result = int(bool(fn(*args)))
3008
      assert result in (0, 1)
3009
    except: # pylint: disable-msg=W0702
3010
      logging.exception("Error while calling function in separate process")
3011
      # 0 and 1 are reserved for the return value
3012
      result = 33
3013

    
3014
    os._exit(result) # pylint: disable-msg=W0212
3015

    
3016
  # Parent process
3017

    
3018
  # Avoid zombies and check exit code
3019
  (_, status) = os.waitpid(pid, 0)
3020

    
3021
  if os.WIFSIGNALED(status):
3022
    exitcode = None
3023
    signum = os.WTERMSIG(status)
3024
  else:
3025
    exitcode = os.WEXITSTATUS(status)
3026
    signum = None
3027

    
3028
  if not (exitcode in (0, 1) and signum is None):
3029
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
3030
                              (exitcode, signum))
3031

    
3032
  return bool(exitcode)
3033

    
3034

    
3035
def IgnoreProcessNotFound(fn, *args, **kwargs):
3036
  """Ignores ESRCH when calling a process-related function.
3037

3038
  ESRCH is raised when a process is not found.
3039

3040
  @rtype: bool
3041
  @return: Whether process was found
3042

3043
  """
3044
  try:
3045
    fn(*args, **kwargs)
3046
  except EnvironmentError, err:
3047
    # Ignore ESRCH
3048
    if err.errno == errno.ESRCH:
3049
      return False
3050
    raise
3051

    
3052
  return True
3053

    
3054

    
3055
def IgnoreSignals(fn, *args, **kwargs):
3056
  """Tries to call a function ignoring failures due to EINTR.
3057

3058
  """
3059
  try:
3060
    return fn(*args, **kwargs)
3061
  except EnvironmentError, err:
3062
    if err.errno == errno.EINTR:
3063
      return None
3064
    else:
3065
      raise
3066
  except (select.error, socket.error), err:
3067
    # In python 2.6 and above select.error is an IOError, so it's handled
3068
    # above, in 2.5 and below it's not, and it's handled here.
3069
    if err.args and err.args[0] == errno.EINTR:
3070
      return None
3071
    else:
3072
      raise
3073

    
3074

    
3075
def LockFile(fd):
3076
  """Locks a file using POSIX locks.
3077

3078
  @type fd: int
3079
  @param fd: the file descriptor we need to lock
3080

3081
  """
3082
  try:
3083
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
3084
  except IOError, err:
3085
    if err.errno == errno.EAGAIN:
3086
      raise errors.LockError("File already locked")
3087
    raise
3088

    
3089

    
3090
def FormatTime(val):
3091
  """Formats a time value.
3092

3093
  @type val: float or None
3094
  @param val: the timestamp as returned by time.time()
3095
  @return: a string value or N/A if we don't have a valid timestamp
3096

3097
  """
3098
  if val is None or not isinstance(val, (int, float)):
3099
    return "N/A"
3100
  # these two codes works on Linux, but they are not guaranteed on all
3101
  # platforms
3102
  return time.strftime("%F %T", time.localtime(val))
3103

    
3104

    
3105
def FormatSeconds(secs):
3106
  """Formats seconds for easier reading.
3107

3108
  @type secs: number
3109
  @param secs: Number of seconds
3110
  @rtype: string
3111
  @return: Formatted seconds (e.g. "2d 9h 19m 49s")
3112

3113
  """
3114
  parts = []
3115

    
3116
  secs = round(secs, 0)
3117

    
3118
  if secs > 0:
3119
    # Negative values would be a bit tricky
3120
    for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]:
3121
      (complete, secs) = divmod(secs, one)
3122
      if complete or parts:
3123
        parts.append("%d%s" % (complete, unit))
3124

    
3125
  parts.append("%ds" % secs)
3126

    
3127
  return " ".join(parts)
3128

    
3129

    
3130
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
3131
  """Reads the watcher pause file.
3132

3133
  @type filename: string
3134
  @param filename: Path to watcher pause file
3135
  @type now: None, float or int
3136
  @param now: Current time as Unix timestamp
3137
  @type remove_after: int
3138
  @param remove_after: Remove watcher pause file after specified amount of
3139
    seconds past the pause end time
3140

3141
  """
3142
  if now is None:
3143
    now = time.time()
3144

    
3145
  try:
3146
    value = ReadFile(filename)
3147
  except IOError, err:
3148
    if err.errno != errno.ENOENT:
3149
      raise
3150
    value = None
3151

    
3152
  if value is not None:
3153
    try:
3154
      value = int(value)
3155
    except ValueError:
3156
      logging.warning(("Watcher pause file (%s) contains invalid value,"
3157
                       " removing it"), filename)
3158
      RemoveFile(filename)
3159
      value = None
3160

    
3161
    if value is not None:
3162
      # Remove file if it's outdated
3163
      if now > (value + remove_after):
3164
        RemoveFile(filename)
3165
        value = None
3166

    
3167
      elif now > value:
3168
        value = None
3169

    
3170
  return value
3171

    
3172

    
3173
class RetryTimeout(Exception):
3174
  """Retry loop timed out.
3175

3176
  Any arguments which was passed by the retried function to RetryAgain will be
3177
  preserved in RetryTimeout, if it is raised. If such argument was an exception
3178
  the RaiseInner helper method will reraise it.
3179

3180
  """
3181
  def RaiseInner(self):
3182
    if self.args and isinstance(self.args[0], Exception):
3183
      raise self.args[0]
3184
    else:
3185
      raise RetryTimeout(*self.args)
3186

    
3187

    
3188
class RetryAgain(Exception):
3189
  """Retry again.
3190

3191
  Any arguments passed to RetryAgain will be preserved, if a timeout occurs, as
3192
  arguments to RetryTimeout. If an exception is passed, the RaiseInner() method
3193
  of the RetryTimeout() method can be used to reraise it.
3194

3195
  """
3196

    
3197

    
3198
class _RetryDelayCalculator(object):
3199
  """Calculator for increasing delays.
3200

3201
  """
3202
  __slots__ = [
3203
    "_factor",
3204
    "_limit",
3205
    "_next",
3206
    "_start",
3207
    ]
3208

    
3209
  def __init__(self, start, factor, limit):
3210
    """Initializes this class.
3211

3212
    @type start: float
3213
    @param start: Initial delay
3214
    @type factor: float
3215
    @param factor: Factor for delay increase
3216
    @type limit: float or None
3217
    @param limit: Upper limit for delay or None for no limit
3218

3219
    """
3220
    assert start > 0.0
3221
    assert factor >= 1.0
3222
    assert limit is None or limit >= 0.0
3223

    
3224
    self._start = start
3225
    self._factor = factor
3226
    self._limit = limit
3227

    
3228
    self._next = start
3229

    
3230
  def __call__(self):
3231
    """Returns current delay and calculates the next one.
3232

3233
    """
3234
    current = self._next
3235

    
3236
    # Update for next run
3237
    if self._limit is None or self._next < self._limit:
3238
      self._next = min(self._limit, self._next * self._factor)
3239

    
3240
    return current
3241

    
3242

    
3243
#: Special delay to specify whole remaining timeout
3244
RETRY_REMAINING_TIME = object()
3245

    
3246

    
3247
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
3248
          _time_fn=time.time):
3249
  """Call a function repeatedly until it succeeds.
3250

3251
  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
3252
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
3253
  total of C{timeout} seconds, this function throws L{RetryTimeout}.
3254

3255
  C{delay} can be one of the following:
3256
    - callable returning the delay length as a float
3257
    - Tuple of (start, factor, limit)
3258
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
3259
      useful when overriding L{wait_fn} to wait for an external event)
3260
    - A static delay as a number (int or float)
3261

3262
  @type fn: callable
3263
  @param fn: Function to be called
3264
  @param delay: Either a callable (returning the delay), a tuple of (start,
3265
                factor, limit) (see L{_RetryDelayCalculator}),
3266
                L{RETRY_REMAINING_TIME} or a number (int or float)
3267
  @type timeout: float
3268
  @param timeout: Total timeout
3269
  @type wait_fn: callable
3270
  @param wait_fn: Waiting function
3271
  @return: Return value of function
3272

3273
  """
3274
  assert callable(fn)
3275
  assert callable(wait_fn)
3276
  assert callable(_time_fn)
3277

    
3278
  if args is None:
3279
    args = []
3280

    
3281
  end_time = _time_fn() + timeout
3282

    
3283
  if callable(delay):
3284
    # External function to calculate delay
3285
    calc_delay = delay
3286

    
3287
  elif isinstance(delay, (tuple, list)):
3288
    # Increasing delay with optional upper boundary
3289
    (start, factor, limit) = delay
3290
    calc_delay = _RetryDelayCalculator(start, factor, limit)
3291

    
3292
  elif delay is RETRY_REMAINING_TIME:
3293
    # Always use the remaining time
3294
    calc_delay = None
3295

    
3296
  else:
3297
    # Static delay
3298
    calc_delay = lambda: delay
3299

    
3300
  assert calc_delay is None or callable(calc_delay)
3301

    
3302
  while True:
3303
    retry_args = []
3304
    try:
3305
      # pylint: disable-msg=W0142
3306
      return fn(*args)
3307
    except RetryAgain, err:
3308
      retry_args = err.args
3309
    except RetryTimeout:
3310
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
3311
                                   " handle RetryTimeout")
3312

    
3313
    remaining_time = end_time - _time_fn()
3314

    
3315
    if remaining_time < 0.0:
3316
      # pylint: disable-msg=W0142
3317
      raise RetryTimeout(*retry_args)
3318

    
3319
    assert remaining_time >= 0.0
3320

    
3321
    if calc_delay is None:
3322
      wait_fn(remaining_time)
3323
    else:
3324
      current_delay = calc_delay()
3325
      if current_delay > 0.0:
3326
        wait_fn(current_delay)
3327

    
3328

    
3329
def GetClosedTempfile(*args, **kwargs):
3330
  """Creates a temporary file and returns its path.
3331

3332
  """
3333
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
3334
  _CloseFDNoErr(fd)
3335
  return path
3336

    
3337

    
3338
def GenerateSelfSignedX509Cert(common_name, validity):
3339
  """Generates a self-signed X509 certificate.
3340

3341
  @type common_name: string
3342
  @param common_name: commonName value
3343
  @type validity: int
3344
  @param validity: Validity for certificate in seconds
3345

3346
  """
3347
  # Create private and public key
3348
  key = OpenSSL.crypto.PKey()
3349
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)
3350

    
3351
  # Create self-signed certificate
3352
  cert = OpenSSL.crypto.X509()
3353
  if common_name:
3354
    cert.get_subject().CN = common_name
3355
  cert.set_serial_number(1)
3356
  cert.gmtime_adj_notBefore(0)
3357
  cert.gmtime_adj_notAfter(validity)
3358
  cert.set_issuer(cert.get_subject())
3359
  cert.set_pubkey(key)
3360
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)
3361

    
3362
  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
3363
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
3364

    
3365
  return (key_pem, cert_pem)
3366

    
3367

    
3368
def GenerateSelfSignedSslCert(filename, common_name=constants.X509_CERT_CN,
3369
                              validity=constants.X509_CERT_DEFAULT_VALIDITY):
3370
  """Legacy function to generate self-signed X509 certificate.
3371

3372
  @type filename = str
3373
  @param filename = path to write certificate to
3374
  @type common_name: string
3375
  @param common_name: commonName value
3376
  @type validity: int
3377
  @param validity: validity of certificate in number of days
3378

3379
  """
3380
  # TODO: Investigate using the cluster name instead of X505_CERT_CN for
3381
  # common_name, as cluster-renames are very seldom, and it'd be nice if RAPI
3382
  # and node daemon certificates have the proper Subject/Issuer.
3383
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(common_name,
3384
                                                   validity * 24 * 60 * 60)
3385

    
3386
  WriteFile(filename, mode=0400, data=key_pem + cert_pem)
3387

    
3388

    
3389
class FileLock(object):
3390
  """Utility class for file locks.
3391

3392
  """
3393
  def __init__(self, fd, filename):
3394
    """Constructor for FileLock.
3395

3396
    @type fd: file
3397
    @param fd: File object
3398
    @type filename: str
3399
    @param filename: Path of the file opened at I{fd}
3400

3401
    """
3402
    self.fd = fd
3403
    self.filename = filename
3404

    
3405
  @classmethod
3406
  def Open(cls, filename):
3407
    """Creates and opens a file to be used as a file-based lock.
3408

3409
    @type filename: string
3410
    @param filename: path to the file to be locked
3411

3412
    """
3413
    # Using "os.open" is necessary to allow both opening existing file
3414
    # read/write and creating if not existing. Vanilla "open" will truncate an
3415
    # existing file -or- allow creating if not existing.
3416
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
3417
               filename)
3418

    
3419
  def __del__(self):
3420
    self.Close()
3421

    
3422
  def Close(self):
3423
    """Close the file and release the lock.
3424

3425
    """
3426
    if hasattr(self, "fd") and self.fd:
3427
      self.fd.close()
3428
      self.fd = None
3429

    
3430
  def _flock(self, flag, blocking, timeout, errmsg):
3431
    """Wrapper for fcntl.flock.
3432

3433
    @type flag: int
3434
    @param flag: operation flag
3435
    @type blocking: bool
3436
    @param blocking: whether the operation should be done in blocking mode.
3437
    @type timeout: None or float
3438
    @param timeout: for how long the operation should be retried (implies
3439
                    non-blocking mode).
3440
    @type errmsg: string
3441
    @param errmsg: error message in case operation fails.
3442

3443
    """
3444
    assert self.fd, "Lock was closed"
3445
    assert timeout is None or timeout >= 0, \
3446
      "If specified, timeout must be positive"
3447
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"
3448

    
3449
    # When a timeout is used, LOCK_NB must always be set
3450
    if not (timeout is None and blocking):
3451
      flag |= fcntl.LOCK_NB
3452

    
3453
    if timeout is None:
3454
      self._Lock(self.fd, flag, timeout)
3455
    else:
3456
      try:
3457
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
3458
              args=(self.fd, flag, timeout))
3459
      except RetryTimeout:
3460
        raise errors.LockError(errmsg)
3461

    
3462
  @staticmethod
3463
  def _Lock(fd, flag, timeout):
3464
    try:
3465
      fcntl.flock(fd, flag)
3466
    except IOError, err:
3467
      if timeout is not None and err.errno == errno.EAGAIN:
3468
        raise RetryAgain()
3469

    
3470
      logging.exception("fcntl.flock failed")
3471
      raise
3472

    
3473
  def Exclusive(self, blocking=False, timeout=None):
3474
    """Locks the file in exclusive mode.
3475

3476
    @type blocking: boolean
3477
    @param blocking: whether to block and wait until we
3478
        can lock the file or return immediately
3479
    @type timeout: int or None
3480
    @param timeout: if not None, the duration to wait for the lock
3481
        (in blocking mode)
3482

3483
    """
3484
    self._flock(fcntl.LOCK_EX, blocking, timeout,
3485
                "Failed to lock %s in exclusive mode" % self.filename)
3486

    
3487
  def Shared(self, blocking=False, timeout=None):
3488
    """Locks the file in shared mode.
3489

3490
    @type blocking: boolean
3491
    @param blocking: whether to block and wait until we
3492
        can lock the file or return immediately
3493
    @type timeout: int or None
3494
    @param timeout: if not None, the duration to wait for the lock
3495
        (in blocking mode)
3496

3497
    """
3498
    self._flock(fcntl.LOCK_SH, blocking, timeout,
3499
                "Failed to lock %s in shared mode" % self.filename)
3500

    
3501
  def Unlock(self, blocking=True, timeout=None):
3502
    """Unlocks the file.
3503

3504
    According to C{flock(2)}, unlocking can also be a nonblocking
3505
    operation::
3506

3507
      To make a non-blocking request, include LOCK_NB with any of the above
3508
      operations.
3509

3510
    @type blocking: boolean
3511
    @param blocking: whether to block and wait until we
3512
        can lock the file or return immediately
3513
    @type timeout: int or None
3514
    @param timeout: if not None, the duration to wait for the lock
3515
        (in blocking mode)
3516

3517
    """
3518
    self._flock(fcntl.LOCK_UN, blocking, timeout,
3519
                "Failed to unlock %s" % self.filename)
3520

    
3521

    
3522
class LineSplitter:
3523
  """Splits data chunks into lines separated by newline.
3524

3525
  Instances provide a file-like interface.
3526

3527
  """
3528
  def __init__(self, line_fn, *args):
3529
    """Initializes this class.
3530

3531
    @type line_fn: callable
3532
    @param line_fn: Function called for each line, first parameter is line
3533
    @param args: Extra arguments for L{line_fn}
3534

3535
    """
3536
    assert callable(line_fn)
3537

    
3538
    if args:
3539
      # Python 2.4 doesn't have functools.partial yet
3540
      self._line_fn = \
3541
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
3542
    else:
3543
      self._line_fn = line_fn
3544

    
3545
    self._lines = collections.deque()
3546
    self._buffer = ""
3547

    
3548
  def write(self, data):
3549
    parts = (self._buffer + data).split("\n")
3550
    self._buffer = parts.pop()
3551
    self._lines.extend(parts)
3552

    
3553
  def flush(self):
3554
    while self._lines:
3555
      self._line_fn(self._lines.popleft().rstrip("\r\n"))
3556

    
3557
  def close(self):
3558
    self.flush()
3559
    if self._buffer:
3560
      self._line_fn(self._buffer)
3561

    
3562

    
3563
def SignalHandled(signums):
3564
  """Signal Handled decoration.
3565

3566
  This special decorator installs a signal handler and then calls the target
3567
  function. The function must accept a 'signal_handlers' keyword argument,
3568
  which will contain a dict indexed by signal number, with SignalHandler
3569
  objects as values.
3570

3571
  The decorator can be safely stacked with iself, to handle multiple signals
3572
  with different handlers.
3573

3574
  @type signums: list
3575
  @param signums: signals to intercept
3576

3577
  """
3578
  def wrap(fn):
3579
    def sig_function(*args, **kwargs):
3580
      assert 'signal_handlers' not in kwargs or \
3581
             kwargs['signal_handlers'] is None or \
3582
             isinstance(kwargs['signal_handlers'], dict), \
3583
             "Wrong signal_handlers parameter in original function call"
3584
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
3585
        signal_handlers = kwargs['signal_handlers']
3586
      else:
3587
        signal_handlers = {}
3588
        kwargs['signal_handlers'] = signal_handlers
3589
      sighandler = SignalHandler(signums)
3590
      try:
3591
        for sig in signums:
3592
          signal_handlers[sig] = sighandler
3593
        return fn(*args, **kwargs)
3594
      finally:
3595
        sighandler.Reset()
3596
    return sig_function
3597
  return wrap
3598

    
3599

    
3600
class SignalWakeupFd(object):
3601
  try:
3602
    # This is only supported in Python 2.5 and above (some distributions
3603
    # backported it to Python 2.4)
3604
    _set_wakeup_fd_fn = signal.set_wakeup_fd
3605
  except AttributeError:
3606
    # Not supported
3607
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
3608
      return -1
3609
  else:
3610
    def _SetWakeupFd(self, fd):
3611
      return self._set_wakeup_fd_fn(fd)
3612

    
3613
  def __init__(self):
3614
    """Initializes this class.
3615

3616
    """
3617
    (read_fd, write_fd) = os.pipe()
3618

    
3619
    # Once these succeeded, the file descriptors will be closed automatically.
3620
    # Buffer size 0 is important, otherwise .read() with a specified length
3621
    # might buffer data and the file descriptors won't be marked readable.
3622
    self._read_fh = os.fdopen(read_fd, "r", 0)
3623
    self._write_fh = os.fdopen(write_fd, "w", 0)
3624

    
3625
    self._previous = self._SetWakeupFd(self._write_fh.fileno())
3626

    
3627
    # Utility functions
3628
    self.fileno = self._read_fh.fileno
3629
    self.read = self._read_fh.read
3630

    
3631
  def Reset(self):
3632
    """Restores the previous wakeup file descriptor.
3633

3634
    """
3635
    if hasattr(self, "_previous") and self._previous is not None:
3636
      self._SetWakeupFd(self._previous)
3637
      self._previous = None
3638

    
3639
  def Notify(self):
3640
    """Notifies the wakeup file descriptor.
3641

3642
    """
3643
    self._write_fh.write("\0")
3644

    
3645
  def __del__(self):
3646
    """Called before object deletion.
3647

3648
    """
3649
    self.Reset()
3650

    
3651

    
3652
class SignalHandler(object):
3653
  """Generic signal handler class.
3654

3655
  It automatically restores the original handler when deconstructed or
3656
  when L{Reset} is called. You can either pass your own handler
3657
  function in or query the L{called} attribute to detect whether the
3658
  signal was sent.
3659

3660
  @type signum: list
3661
  @ivar signum: the signals we handle
3662
  @type called: boolean
3663
  @ivar called: tracks whether any of the signals have been raised
3664

3665
  """
3666
  def __init__(self, signum, handler_fn=None, wakeup=None):
3667
    """Constructs a new SignalHandler instance.
3668

3669
    @type signum: int or list of ints
3670
    @param signum: Single signal number or set of signal numbers
3671
    @type handler_fn: callable
3672
    @param handler_fn: Signal handling function
3673

3674
    """
3675
    assert handler_fn is None or callable(handler_fn)
3676

    
3677
    self.signum = set(signum)
3678
    self.called = False
3679

    
3680
    self._handler_fn = handler_fn
3681
    self._wakeup = wakeup
3682

    
3683
    self._previous = {}
3684
    try:
3685
      for signum in self.signum:
3686
        # Setup handler
3687
        prev_handler = signal.signal(signum, self._HandleSignal)
3688
        try:
3689
          self._previous[signum] = prev_handler
3690
        except:
3691
          # Restore previous handler
3692
          signal.signal(signum, prev_handler)
3693
          raise
3694
    except:
3695
      # Reset all handlers
3696
      self.Reset()
3697
      # Here we have a race condition: a handler may have already been called,
3698
      # but there's not much we can do about it at this point.
3699
      raise
3700

    
3701
  def __del__(self):
3702
    self.Reset()
3703

    
3704
  def Reset(self):
3705
    """Restore previous handler.
3706

3707
    This will reset all the signals to their previous handlers.
3708

3709
    """
3710
    for signum, prev_handler in self._previous.items():
3711
      signal.signal(signum, prev_handler)
3712
      # If successful, remove from dict
3713
      del self._previous[signum]
3714

    
3715
  def Clear(self):
3716
    """Unsets the L{called} flag.
3717

3718
    This function can be used in case a signal may arrive several times.
3719

3720
    """
3721
    self.called = False
3722

    
3723
  def _HandleSignal(self, signum, frame):
3724
    """Actual signal handling function.
3725

3726
    """
3727
    # This is not nice and not absolutely atomic, but it appears to be the only
3728
    # solution in Python -- there are no atomic types.
3729
    self.called = True
3730

    
3731
    if self._wakeup:
3732
      # Notify whoever is interested in signals
3733
      self._wakeup.Notify()
3734

    
3735
    if self._handler_fn:
3736
      self._handler_fn(signum, frame)
3737

    
3738

    
3739
class FieldSet(object):
3740
  """A simple field set.
3741

3742
  Among the features are:
3743
    - checking if a string is among a list of static string or regex objects
3744
    - checking if a whole list of string matches
3745
    - returning the matching groups from a regex match
3746

3747
  Internally, all fields are held as regular expression objects.
3748

3749
  """
3750
  def __init__(self, *items):
3751
    self.items = [re.compile("^%s$" % value) for value in items]
3752

    
3753
  def Extend(self, other_set):
3754
    """Extend the field set with the items from another one"""
3755
    self.items.extend(other_set.items)
3756

    
3757
  def Matches(self, field):
3758
    """Checks if a field matches the current set
3759

3760
    @type field: str
3761
    @param field: the string to match
3762
    @return: either None or a regular expression match object
3763

3764
    """
3765
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
3766
      return m
3767
    return None
3768

    
3769
  def NonMatching(self, items):
3770
    """Returns the list of fields not matching the current set
3771

3772
    @type items: list
3773
    @param items: the list of fields to check
3774
    @rtype: list
3775
    @return: list of non-matching fields
3776

3777
    """
3778
    return [val for val in items if not self.Matches(val)]