
root / lib / utils.py @ f8ea4ada


#
#

# Copyright (C) 2006, 2007 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Ganeti utility module.

This module holds functions that can be used in both daemons (all) and
the command line scripts.

"""


import os
import sys
import time
import subprocess
import re
import socket
import tempfile
import shutil
import errno
import pwd
import itertools
import select
import fcntl
import resource
import logging
import logging.handlers
import signal
import OpenSSL
import datetime
import calendar
import hmac
import collections
import struct
import IN

from cStringIO import StringIO

try:
  import ctypes
except ImportError:
  ctypes = None

from ganeti import errors
from ganeti import constants
from ganeti import compat


_locksheld = []
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')

debug_locks = False

#: when set to True, L{RunCmd} is disabled
no_fork = False

_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"

HEX_CHAR_RE = r"[a-zA-Z0-9]"
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
                             HEX_CHAR_RE, HEX_CHAR_RE),
                            re.S | re.I)

# Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
# struct ucred { pid_t pid; uid_t uid; gid_t gid; };
#
# The GNU C Library defines gid_t and uid_t to be "unsigned int" and
# pid_t to "int".
#
# IEEE Std 1003.1-2008:
# "nlink_t, uid_t, gid_t, and id_t shall be integer types"
# "blksize_t, pid_t, and ssize_t shall be signed integer types"
_STRUCT_UCRED = "iII"
_STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)

# Certificate verification results
(CERT_WARNING,
 CERT_ERROR) = range(1, 3)

# Flags for mlockall() (from bits/mman.h)
_MCL_CURRENT = 1
_MCL_FUTURE = 2


class RunResult(object):
  """Holds the result of running external programs.

  @type exit_code: int
  @ivar exit_code: the exit code of the program, or None (if the program
      didn't exit())
  @type signal: int or None
  @ivar signal: the signal that caused the program to finish, or None
      (if the program wasn't terminated by a signal)
  @type stdout: str
  @ivar stdout: the standard output of the program
  @type stderr: str
  @ivar stderr: the standard error of the program
  @type failed: boolean
  @ivar failed: True in case the program was
      terminated by a signal or exited with a non-zero exit code
  @ivar fail_reason: a string detailing the termination reason

  """
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
               "failed", "fail_reason", "cmd"]


  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
    self.cmd = cmd
    self.exit_code = exit_code
    self.signal = signal_
    self.stdout = stdout
    self.stderr = stderr
    self.failed = (signal_ is not None or exit_code != 0)

    if self.signal is not None:
      self.fail_reason = "terminated by signal %s" % self.signal
    elif self.exit_code is not None:
      self.fail_reason = "exited with exit code %s" % self.exit_code
    else:
      self.fail_reason = "unable to determine termination reason"

    if self.failed:
      logging.debug("Command '%s' failed (%s); output: %s",
                    self.cmd, self.fail_reason, self.output)

  def _GetOutput(self):
    """Returns the combined stdout and stderr for easier usage.

    """
    return self.stdout + self.stderr

  output = property(_GetOutput, None, None, "Return full output")


def _BuildCmdEnvironment(env, reset):
  """Builds the environment for an external program.

  """
  if reset:
    cmd_env = {}
  else:
    cmd_env = os.environ.copy()
    cmd_env["LC_ALL"] = "C"

  if env is not None:
    cmd_env.update(env)

  return cmd_env


def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
  """Execute a (shell) command.

  The command should not read from its standard input, as it will be
  closed.

  @type cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: Additional environment variables
  @type output: str
  @param output: if desired, the output of the command can be
      saved in a file instead of the RunResult instance; this
      parameter denotes the file name (if not None)
  @type cwd: string
  @param cwd: if specified, will be used as the working
      directory for the command; the default will be /
  @type reset_env: boolean
  @param reset_env: whether to reset or keep the default os environment
  @rtype: L{RunResult}
  @return: RunResult instance
  @raise errors.ProgrammerError: if we call this when forks are disabled

  """
  if no_fork:
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")

  if isinstance(cmd, basestring):
    strcmd = cmd
    shell = True
  else:
    cmd = [str(val) for val in cmd]
    strcmd = ShellQuoteArgs(cmd)
    shell = False

  if output:
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
  else:
    logging.debug("RunCmd %s", strcmd)

  cmd_env = _BuildCmdEnvironment(env, reset_env)

  try:
    if output is None:
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
    else:
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
      out = err = ""
  except OSError, err:
    if err.errno == errno.ENOENT:
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
                               (strcmd, err))
    else:
      raise

  if status >= 0:
    exitcode = status
    signal_ = None
  else:
    exitcode = None
    signal_ = -status

  return RunResult(exitcode, signal_, out, err, strcmd)

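# Illustrative usage sketch (not part of the original module): how callers
# typically consume the RunResult returned by RunCmd() above. The command and
# the error handling shown here are arbitrary examples.
#
#   result = RunCmd(["/bin/ls", "/etc"])
#   if result.failed:
#     # fail_reason distinguishes non-zero exit codes from signals
#     raise errors.OpExecError("ls failed: %s; output: %s" %
#                              (result.fail_reason, result.output))
#   names = result.stdout.splitlines()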
def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
                pidfile=None):
  """Start a daemon process after forking twice.

  @type cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: Additional environment variables
  @type cwd: string
  @param cwd: Working directory for the program
  @type output: string
  @param output: Path to file in which to save the output
  @type output_fd: int
  @param output_fd: File descriptor for output
  @type pidfile: string
  @param pidfile: Process ID file
  @rtype: int
  @return: Daemon process ID
  @raise errors.ProgrammerError: if we call this when forks are disabled

  """
  if no_fork:
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
                                 " disabled")

  if output and not (bool(output) ^ (output_fd is not None)):
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
                                 " specified")

  if isinstance(cmd, basestring):
    cmd = ["/bin/sh", "-c", cmd]

  strcmd = ShellQuoteArgs(cmd)

  if output:
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
  else:
    logging.debug("StartDaemon %s", strcmd)

  cmd_env = _BuildCmdEnvironment(env, False)

  # Create pipe for sending PID back
  (pidpipe_read, pidpipe_write) = os.pipe()
  try:
    try:
      # Create pipe for sending error messages
      (errpipe_read, errpipe_write) = os.pipe()
      try:
        try:
          # First fork
          pid = os.fork()
          if pid == 0:
            try:
              # Child process, won't return
              _StartDaemonChild(errpipe_read, errpipe_write,
                                pidpipe_read, pidpipe_write,
                                cmd, cmd_env, cwd,
                                output, output_fd, pidfile)
            finally:
              # Well, maybe child process failed
              os._exit(1) # pylint: disable-msg=W0212
        finally:
          _CloseFDNoErr(errpipe_write)

        # Wait for daemon to be started (or an error message to arrive) and read
        # up to 100 KB as an error message
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
      finally:
        _CloseFDNoErr(errpipe_read)
    finally:
      _CloseFDNoErr(pidpipe_write)

    # Read up to 128 bytes for PID
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
  finally:
    _CloseFDNoErr(pidpipe_read)

  # Try to avoid zombies by waiting for child process
  try:
    os.waitpid(pid, 0)
  except OSError:
    pass

  if errormsg:
    raise errors.OpExecError("Error when starting daemon process: %r" %
                             errormsg)

  try:
    return int(pidtext)
  except (ValueError, TypeError), err:
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
                             (pidtext, err))


def _StartDaemonChild(errpipe_read, errpipe_write,
                      pidpipe_read, pidpipe_write,
                      args, env, cwd,
                      output, fd_output, pidfile):
  """Child process for starting daemon.

  """
  try:
    # Close parent's side
    _CloseFDNoErr(errpipe_read)
    _CloseFDNoErr(pidpipe_read)

    # First child process
    os.chdir("/")
    os.umask(077)
    os.setsid()

    # And fork for the second time
    pid = os.fork()
    if pid != 0:
      # Exit first child process
      os._exit(0) # pylint: disable-msg=W0212

    # Make sure pipe is closed on execv* (and thereby notifies original process)
    SetCloseOnExecFlag(errpipe_write, True)

    # List of file descriptors to be left open
    noclose_fds = [errpipe_write]

    # Open PID file
    if pidfile:
      try:
        # TODO: Atomic replace with another locked file instead of writing into
        # it after creating
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)

        # Lock the PID file (and fail if not possible to do so). Any code
        # wanting to send a signal to the daemon should try to lock the PID
        # file before reading it. If acquiring the lock succeeds, the daemon is
        # no longer running and the signal should not be sent.
        LockFile(fd_pidfile)

        os.write(fd_pidfile, "%d\n" % os.getpid())
      except Exception, err:
        raise Exception("Creating and locking PID file failed: %s" % err)

      # Keeping the file open to hold the lock
      noclose_fds.append(fd_pidfile)

      SetCloseOnExecFlag(fd_pidfile, False)
    else:
      fd_pidfile = None

    # Open /dev/null
    fd_devnull = os.open(os.devnull, os.O_RDWR)

    assert not output or (bool(output) ^ (fd_output is not None))

    if fd_output is not None:
      pass
    elif output:
      # Open output file
      try:
        # TODO: Implement flag to set append=yes/no
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
      except EnvironmentError, err:
        raise Exception("Opening output file failed: %s" % err)
    else:
      fd_output = fd_devnull

    # Redirect standard I/O
    os.dup2(fd_devnull, 0)
    os.dup2(fd_output, 1)
    os.dup2(fd_output, 2)

    # Send daemon PID to parent
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))

    # Close all file descriptors except stdio and error message pipe
    CloseFDs(noclose_fds=noclose_fds)

    # Change working directory
    os.chdir(cwd)

    if env is None:
      os.execvp(args[0], args)
    else:
      os.execvpe(args[0], args, env)
  except: # pylint: disable-msg=W0702
    try:
      # Report errors to original process
      buf = str(sys.exc_info()[1])

      RetryOnSignal(os.write, errpipe_write, buf)
    except: # pylint: disable-msg=W0702
      # Ignore errors in error handling
      pass

  os._exit(1) # pylint: disable-msg=W0212


def _RunCmdPipe(cmd, env, via_shell, cwd):
  """Run a command and return its output.

  @type  cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: The environment to use
  @type via_shell: bool
  @param via_shell: if we should run via the shell
  @type cwd: string
  @param cwd: the working directory for the program
  @rtype: tuple
  @return: (out, err, status)

  """
  poller = select.poll()
  child = subprocess.Popen(cmd, shell=via_shell,
                           stderr=subprocess.PIPE,
                           stdout=subprocess.PIPE,
                           stdin=subprocess.PIPE,
                           close_fds=True, env=env,
                           cwd=cwd)

  child.stdin.close()
  poller.register(child.stdout, select.POLLIN)
  poller.register(child.stderr, select.POLLIN)
  out = StringIO()
  err = StringIO()
  fdmap = {
    child.stdout.fileno(): (out, child.stdout),
    child.stderr.fileno(): (err, child.stderr),
    }
  for fd in fdmap:
    SetNonblockFlag(fd, True)

  while fdmap:
    pollresult = RetryOnSignal(poller.poll)

    for fd, event in pollresult:
      if event & select.POLLIN or event & select.POLLPRI:
        data = fdmap[fd][1].read()
        # no data from read signifies EOF (the same as POLLHUP)
        if not data:
          poller.unregister(fd)
          del fdmap[fd]
          continue
        fdmap[fd][0].write(data)
      if (event & select.POLLNVAL or event & select.POLLHUP or
          event & select.POLLERR):
        poller.unregister(fd)
        del fdmap[fd]

  out = out.getvalue()
  err = err.getvalue()

  status = child.wait()
  return out, err, status


def _RunCmdFile(cmd, env, via_shell, output, cwd):
  """Run a command and save its output to a file.

  @type  cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: The environment to use
  @type via_shell: bool
  @param via_shell: if we should run via the shell
  @type output: str
  @param output: the filename in which to save the output
  @type cwd: string
  @param cwd: the working directory for the program
  @rtype: int
  @return: the exit status

  """
  fh = open(output, "a")
  try:
    child = subprocess.Popen(cmd, shell=via_shell,
                             stderr=subprocess.STDOUT,
                             stdout=fh,
                             stdin=subprocess.PIPE,
                             close_fds=True, env=env,
                             cwd=cwd)

    child.stdin.close()
    status = child.wait()
  finally:
    fh.close()
  return status


def SetCloseOnExecFlag(fd, enable):
  """Sets or unsets the close-on-exec flag on a file descriptor.

  @type fd: int
  @param fd: File descriptor
  @type enable: bool
  @param enable: Whether to set or unset it.

  """
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)

  if enable:
    flags |= fcntl.FD_CLOEXEC
  else:
    flags &= ~fcntl.FD_CLOEXEC

  fcntl.fcntl(fd, fcntl.F_SETFD, flags)


def SetNonblockFlag(fd, enable):
  """Sets or unsets the O_NONBLOCK flag on on a file descriptor.

  @type fd: int
  @param fd: File descriptor
  @type enable: bool
  @param enable: Whether to set or unset it

  """
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)

  if enable:
    flags |= os.O_NONBLOCK
  else:
    flags &= ~os.O_NONBLOCK

  fcntl.fcntl(fd, fcntl.F_SETFL, flags)


def RetryOnSignal(fn, *args, **kwargs):
  """Calls a function again if it failed due to EINTR.

  """
  while True:
    try:
      return fn(*args, **kwargs)
    except EnvironmentError, err:
      if err.errno != errno.EINTR:
        raise
    except (socket.error, select.error), err:
      # In python 2.6 and above select.error is an IOError, so it's handled
      # above, in 2.5 and below it's not, and it's handled here.
      if not (err.args and err.args[0] == errno.EINTR):
        raise


def RunParts(dir_name, env=None, reset_env=False):
  """Run Scripts or programs in a directory

  @type dir_name: string
  @param dir_name: absolute path to a directory
  @type env: dict
  @param env: The environment to use
  @type reset_env: boolean
  @param reset_env: whether to reset or keep the default os environment
  @rtype: list of tuples
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)

  """
  rr = []

  try:
    dir_contents = ListVisibleFiles(dir_name)
  except OSError, err:
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
    return rr

  for relname in sorted(dir_contents):
    fname = PathJoin(dir_name, relname)
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
      rr.append((relname, constants.RUNPARTS_SKIP, None))
    else:
      try:
        result = RunCmd([fname], env=env, reset_env=reset_env)
      except Exception, err: # pylint: disable-msg=W0703
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
      else:
        rr.append((relname, constants.RUNPARTS_RUN, result))

  return rr


def GetSocketCredentials(sock):
  """Returns the credentials of the foreign process connected to a socket.

  @param sock: Unix socket
  @rtype: tuple; (number, number, number)
  @return: The PID, UID and GID of the connected foreign process.

  """
  peercred = sock.getsockopt(socket.SOL_SOCKET, IN.SO_PEERCRED,
                             _STRUCT_UCRED_SIZE)
  return struct.unpack(_STRUCT_UCRED, peercred)


def RemoveFile(filename):
  """Remove a file ignoring some errors.

  Remove a file, ignoring non-existing ones or directories. Other
  errors are passed.

  @type filename: str
  @param filename: the file to be removed

  """
  try:
    os.unlink(filename)
  except OSError, err:
    if err.errno not in (errno.ENOENT, errno.EISDIR):
      raise


def RemoveDir(dirname):
  """Remove an empty directory.

  Remove a directory, ignoring non-existing ones.
  Other errors are passed. This includes the case,
  where the directory is not empty, so it can't be removed.

  @type dirname: str
  @param dirname: the empty directory to be removed

  """
  try:
    os.rmdir(dirname)
  except OSError, err:
    if err.errno != errno.ENOENT:
      raise


def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
  """Renames a file.

  @type old: string
  @param old: Original path
  @type new: string
  @param new: New path
  @type mkdir: bool
  @param mkdir: Whether to create target directory if it doesn't exist
  @type mkdir_mode: int
  @param mkdir_mode: Mode for newly created directories

  """
  try:
    return os.rename(old, new)
  except OSError, err:
    # In at least one use case of this function, the job queue, directory
    # creation is very rare. Checking for the directory before renaming is not
    # as efficient.
    if mkdir and err.errno == errno.ENOENT:
      # Create directory and try again
      Makedirs(os.path.dirname(new), mode=mkdir_mode)

      return os.rename(old, new)

    raise


def Makedirs(path, mode=0750):
  """Super-mkdir; create a leaf directory and all intermediate ones.

  This is a wrapper around C{os.makedirs} adding error handling not implemented
  before Python 2.5.

  """
  try:
    os.makedirs(path, mode)
  except OSError, err:
    # Ignore EEXIST. This is only handled in os.makedirs as included in
    # Python 2.5 and above.
    if err.errno != errno.EEXIST or not os.path.exists(path):
      raise


def ResetTempfileModule():
  """Resets the random name generator of the tempfile module.

  This function should be called after C{os.fork} in the child process to
  ensure it creates a newly seeded random generator. Otherwise it would
  generate the same random parts as the parent process. If several processes
  race for the creation of a temporary file, this could lead to one not getting
  a temporary name.

  """
  # pylint: disable-msg=W0212
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
    tempfile._once_lock.acquire()
    try:
      # Reset random name generator
      tempfile._name_sequence = None
    finally:
      tempfile._once_lock.release()
  else:
    logging.critical("The tempfile module misses at least one of the"
                     " '_once_lock' and '_name_sequence' attributes")


def _FingerprintFile(filename):
  """Compute the fingerprint of a file.

  If the file does not exist, a None will be returned
  instead.

  @type filename: str
  @param filename: the filename to checksum
  @rtype: str
  @return: the hex digest of the sha checksum of the contents
      of the file

  """
  if not (os.path.exists(filename) and os.path.isfile(filename)):
    return None

  f = open(filename)

  fp = compat.sha1_hash()
  while True:
    data = f.read(4096)
    if not data:
      break

    fp.update(data)

  return fp.hexdigest()


def FingerprintFiles(files):
  """Compute fingerprints for a list of files.

  @type files: list
  @param files: the list of filename to fingerprint
  @rtype: dict
  @return: a dictionary filename: fingerprint, holding only
      existing files

  """
  ret = {}

  for filename in files:
    cksum = _FingerprintFile(filename)
    if cksum:
      ret[filename] = cksum

  return ret


def ForceDictType(target, key_types, allowed_values=None):
  """Force the values of a dict to have certain types.

  @type target: dict
  @param target: the dict to update
  @type key_types: dict
  @param key_types: dict mapping target dict keys to types
                    in constants.ENFORCEABLE_TYPES
  @type allowed_values: list
  @keyword allowed_values: list of specially allowed values

  """
  if allowed_values is None:
    allowed_values = []

  if not isinstance(target, dict):
    msg = "Expected dictionary, got '%s'" % target
    raise errors.TypeEnforcementError(msg)

  for key in target:
    if key not in key_types:
      msg = "Unknown key '%s'" % key
      raise errors.TypeEnforcementError(msg)

    if target[key] in allowed_values:
      continue

    ktype = key_types[key]
    if ktype not in constants.ENFORCEABLE_TYPES:
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
      raise errors.ProgrammerError(msg)

    if ktype == constants.VTYPE_STRING:
      if not isinstance(target[key], basestring):
        if isinstance(target[key], bool) and not target[key]:
          target[key] = ''
        else:
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
          raise errors.TypeEnforcementError(msg)
    elif ktype == constants.VTYPE_BOOL:
      if isinstance(target[key], basestring) and target[key]:
        if target[key].lower() == constants.VALUE_FALSE:
          target[key] = False
        elif target[key].lower() == constants.VALUE_TRUE:
          target[key] = True
        else:
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
          raise errors.TypeEnforcementError(msg)
      elif target[key]:
        target[key] = True
      else:
        target[key] = False
    elif ktype == constants.VTYPE_SIZE:
      try:
        target[key] = ParseUnit(target[key])
      except errors.UnitParseError, err:
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
              (key, target[key], err)
        raise errors.TypeEnforcementError(msg)
    elif ktype == constants.VTYPE_INT:
      try:
        target[key] = int(target[key])
      except (ValueError, TypeError):
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
        raise errors.TypeEnforcementError(msg)


def _GetProcStatusPath(pid):
  """Returns the path for a PID's proc status file.

  @type pid: int
  @param pid: Process ID
  @rtype: string

  """
  return "/proc/%d/status" % pid


def IsProcessAlive(pid):
  """Check if a given pid exists on the system.

  @note: zombie status is not handled, so zombie processes
      will be returned as alive
  @type pid: int
  @param pid: the process ID to check
  @rtype: boolean
  @return: True if the process exists

  """
  def _TryStat(name):
    try:
      os.stat(name)
      return True
    except EnvironmentError, err:
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
        return False
      elif err.errno == errno.EINVAL:
        raise RetryAgain(err)
      raise

  assert isinstance(pid, int), "pid must be an integer"
  if pid <= 0:
    return False

  # /proc in a multiprocessor environment can have strange behaviors.
  # Retry the os.stat a few times until we get a good result.
  try:
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
                 args=[_GetProcStatusPath(pid)])
  except RetryTimeout, err:
    err.RaiseInner()


def _ParseSigsetT(sigset):
  """Parse a rendered sigset_t value.

  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
  function.

  @type sigset: string
  @param sigset: Rendered signal set from /proc/$pid/status
  @rtype: set
  @return: Set of all enabled signal numbers

  """
  result = set()

  signum = 0
  for ch in reversed(sigset):
    chv = int(ch, 16)

    # The following could be done in a loop, but it's easier to read and
    # understand in the unrolled form
    if chv & 1:
      result.add(signum + 1)
    if chv & 2:
      result.add(signum + 2)
    if chv & 4:
      result.add(signum + 3)
    if chv & 8:
      result.add(signum + 4)

    signum += 4

  return result

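# Illustrative example (not part of the original module): _ParseSigsetT()
# above decodes the hex masks found in /proc/$pid/status fields such as
# SigCgt. On Linux, a process catching SIGINT (2), SIGTERM (15) and
# SIGCHLD (17) would render the mask "0000000000014002", and:
#
#   _ParseSigsetT("0000000000014002") == set([2, 15, 17])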
def _GetProcStatusField(pstatus, field):
  """Retrieves a field from the contents of a proc status file.

  @type pstatus: string
  @param pstatus: Contents of /proc/$pid/status
  @type field: string
  @param field: Name of field whose value should be returned
  @rtype: string

  """
  for line in pstatus.splitlines():
    parts = line.split(":", 1)

    if len(parts) < 2 or parts[0] != field:
      continue

    return parts[1].strip()

  return None


def IsProcessHandlingSignal(pid, signum, status_path=None):
  """Checks whether a process is handling a signal.

  @type pid: int
  @param pid: Process ID
  @type signum: int
  @param signum: Signal number
  @rtype: bool

  """
  if status_path is None:
    status_path = _GetProcStatusPath(pid)

  try:
    proc_status = ReadFile(status_path)
  except EnvironmentError, err:
    # In at least one case, reading /proc/$pid/status failed with ESRCH.
    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
      return False
    raise

  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
  if sigcgt is None:
    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)

  # Now check whether signal is handled
  return signum in _ParseSigsetT(sigcgt)


def ReadPidFile(pidfile):
  """Read a pid from a file.

  @type  pidfile: string
  @param pidfile: path to the file containing the pid
  @rtype: int
  @return: The process id, if the file exists and contains a valid PID,
           otherwise 0

  """
  try:
    raw_data = ReadOneLineFile(pidfile)
  except EnvironmentError, err:
    if err.errno != errno.ENOENT:
      logging.exception("Can't read pid file")
    return 0

  try:
    pid = int(raw_data)
  except (TypeError, ValueError), err:
    logging.info("Can't parse pid file contents", exc_info=True)
    return 0

  return pid


def ReadLockedPidFile(path):
  """Reads a locked PID file.

  This can be used together with L{StartDaemon}.

  @type path: string
  @param path: Path to PID file
  @return: PID as integer or, if file was unlocked or couldn't be opened, None

  """
  try:
    fd = os.open(path, os.O_RDONLY)
  except EnvironmentError, err:
    if err.errno == errno.ENOENT:
      # PID file doesn't exist
      return None
    raise

  try:
    try:
      # Try to acquire lock
      LockFile(fd)
    except errors.LockError:
      # Couldn't lock, daemon is running
      return int(os.read(fd, 100))
  finally:
    os.close(fd)

  return None


def MatchNameComponent(key, name_list, case_sensitive=True):
  """Try to match a name against a list.

  This function will try to match a name like test1 against a list
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
  this list, I{'test1'} as well as I{'test1.example'} will match, but
  not I{'test1.ex'}. A multiple match will be considered as no match
  at all (e.g. I{'test1'} against C{['test1.example.com',
  'test1.example.org']}), except when the key fully matches an entry
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).

  @type key: str
  @param key: the name to be searched
  @type name_list: list
  @param name_list: the list of strings against which to search the key
  @type case_sensitive: boolean
  @param case_sensitive: whether to provide a case-sensitive match

  @rtype: None or str
  @return: None if there is no match I{or} if there are multiple matches,
      otherwise the element from the list which matches

  """
  if key in name_list:
    return key

  re_flags = 0
  if not case_sensitive:
    re_flags |= re.IGNORECASE
    key = key.upper()
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
  names_filtered = []
  string_matches = []
  for name in name_list:
    if mo.match(name) is not None:
      names_filtered.append(name)
      if not case_sensitive and key == name.upper():
        string_matches.append(name)

  if len(string_matches) == 1:
    return string_matches[0]
  if len(names_filtered) == 1:
    return names_filtered[0]
  return None

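# Illustrative example (not part of the original module), restating the
# MatchNameComponent() docstring above as concrete calls:
#
#   MatchNameComponent("test1", ["test1.example.com", "test2.example.com"])
#     -> "test1.example.com"   # single match on a label boundary
#   MatchNameComponent("test1", ["test1.example.com", "test1.example.org"])
#     -> None                  # ambiguous, treated as no match
#   MatchNameComponent("test1", ["test1", "test1.example.com"])
#     -> "test1"               # exact match wins over prefix matches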
class HostInfo:
  """Class implementing resolver and hostname functionality

  """
  _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$")

  def __init__(self, name=None):
    """Initialize the host name object.

    If the name argument is not passed, it will use this system's
    name.

    """
    if name is None:
      name = self.SysName()

    self.query = name
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
    self.ip = self.ipaddrs[0]

  def ShortName(self):
    """Returns the hostname without domain.

    """
    return self.name.split('.')[0]

  @staticmethod
  def SysName():
    """Return the current system's name.

    This is simply a wrapper over C{socket.gethostname()}.

    """
    return socket.gethostname()

  @staticmethod
  def LookupHostname(hostname):
    """Look up hostname

    @type hostname: str
    @param hostname: hostname to look up

    @rtype: tuple
    @return: a tuple (name, aliases, ipaddrs) as returned by
        C{socket.gethostbyname_ex}
    @raise errors.ResolverError: in case of errors in resolving

    """
    try:
      result = socket.gethostbyname_ex(hostname)
    except socket.gaierror, err:
      # hostname not found in DNS
      raise errors.ResolverError(hostname, err.args[0], err.args[1])

    return result

  @classmethod
  def NormalizeName(cls, hostname):
    """Validate and normalize the given hostname.

    @attention: the validation is a bit more relaxed than the standards
        require; most importantly, we allow underscores in names
    @raise errors.OpPrereqError: when the name is not valid

    """
    hostname = hostname.lower()
    if (not cls._VALID_NAME_RE.match(hostname) or
        # double-dots, meaning empty label
        ".." in hostname or
        # empty initial label
        hostname.startswith(".")):
      raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
                                 errors.ECODE_INVAL)
    if hostname.endswith("."):
      hostname = hostname.rstrip(".")
    return hostname


def GetHostInfo(name=None):
  """Lookup host name and raise an OpPrereqError for failures"""

  try:
    return HostInfo(name)
  except errors.ResolverError, err:
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
                               (err[0], err[2]), errors.ECODE_RESOLVER)


def ListVolumeGroups():
  """List volume groups and their size

  @rtype: dict
  @return:
       Dictionary with the volume group names as keys and their
       sizes in MiB as values

  """
  command = "vgs --noheadings --units m --nosuffix -o name,size"
  result = RunCmd(command)
  retval = {}
  if result.failed:
    return retval

  for line in result.stdout.splitlines():
    try:
      name, size = line.split()
      size = int(float(size))
    except (IndexError, ValueError), err:
      logging.error("Invalid output from vgs (%s): %s", err, line)
      continue

    retval[name] = size

  return retval


def BridgeExists(bridge):
  """Check whether the given bridge exists in the system

  @type bridge: str
  @param bridge: the bridge name to check
  @rtype: boolean
  @return: True if it does

  """
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)


def NiceSort(name_list):
  """Sort a list of strings based on digit and non-digit groupings.

  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
  will sort the list in the logical order C{['a1', 'a2', 'a10',
  'a11']}.

  The sort algorithm breaks each name in groups of either only-digits
  or no-digits. Only the first eight such groups are considered, and
  after that we just use what's left of the string.

  @type name_list: list
  @param name_list: the names to be sorted
  @rtype: list
  @return: a copy of the name list sorted with our algorithm

  """
  _SORTER_BASE = "(\D+|\d+)"
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE)
  _SORTER_RE = re.compile(_SORTER_FULL)
  _SORTER_NODIGIT = re.compile("^\D*$")
  def _TryInt(val):
    """Attempts to convert a variable to integer."""
    if val is None or _SORTER_NODIGIT.match(val):
      return val
    rval = int(val)
    return rval

  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
             for name in name_list]
  to_sort.sort()
  return [tup[1] for tup in to_sort]

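# Illustrative example (not part of the original module), taken from the
# NiceSort() docstring above: digit groups are compared numerically instead
# of lexicographically.
#
#   NiceSort(['a1', 'a10', 'a11', 'a2']) == ['a1', 'a2', 'a10', 'a11']
#   sorted(['a1', 'a10', 'a11', 'a2'])   == ['a1', 'a10', 'a11', 'a2']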
def TryConvert(fn, val):
  """Try to convert a value ignoring errors.

  This function tries to apply function I{fn} to I{val}. If no
  C{ValueError} or C{TypeError} exceptions are raised, it will return
  the result, else it will return the original value. Any other
  exceptions are propagated to the caller.

  @type fn: callable
  @param fn: function to apply to the value
  @param val: the value to be converted
  @return: The converted value if the conversion was successful,
      otherwise the original value.

  """
  try:
    nv = fn(val)
  except (ValueError, TypeError):
    nv = val
  return nv


def IsValidIP(ip):
  """Verifies the syntax of an IPv4 address.

  This function checks if the given IPv4 address is valid or not based
  on syntax (not IP range, class calculations, etc.).

  @type ip: str
  @param ip: the address to be checked
  @rtype: a regular expression match object
  @return: a regular expression match object, or None if the
      address is not valid

  """
  unit = "(0|[1-9]\d{0,2})"
  #TODO: convert and return only boolean
  return re.match("^%s\.%s\.%s\.%s$" % (unit, unit, unit, unit), ip)


def IsValidShellParam(word):
  """Verifies is the given word is safe from the shell's p.o.v.

  This means that we can pass this to a command via the shell and be
  sure that it doesn't alter the command line and is passed as such to
  the actual command.

  Note that we are overly restrictive here, in order to be on the safe
  side.

  @type word: str
  @param word: the word to check
  @rtype: boolean
  @return: True if the word is 'safe'

  """
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))


def BuildShellCmd(template, *args):
  """Build a safe shell command line from the given arguments.

  This function will check all arguments in the args list so that they
  are valid shell parameters (i.e. they don't contain shell
  metacharacters). If everything is ok, it will return the result of
  template % args.

  @type template: str
  @param template: the string holding the template for the
      string formatting
  @rtype: str
  @return: the expanded command line

  """
  for word in args:
    if not IsValidShellParam(word):
      raise errors.ProgrammerError("Shell argument '%s' contains"
                                   " invalid characters" % word)
  return template % args


def FormatUnit(value, units):
  """Formats an incoming number of MiB with the appropriate unit.

  @type value: int
  @param value: integer representing the value in MiB (1048576)
  @type units: char
  @param units: the type of formatting we should do:
      - 'h' for automatic scaling
      - 'm' for MiBs
      - 'g' for GiBs
      - 't' for TiBs
  @rtype: str
  @return: the formatted value (with suffix)

  """
  if units not in ('m', 'g', 't', 'h'):
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))

  suffix = ''

  if units == 'm' or (units == 'h' and value < 1024):
    if units == 'h':
      suffix = 'M'
    return "%d%s" % (round(value, 0), suffix)

  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
    if units == 'h':
      suffix = 'G'
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)

  else:
    if units == 'h':
      suffix = 'T'
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)


def ParseUnit(input_string):
  """Tries to extract number and scale from the given string.

  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
  is always an int in MiB.

  """
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
  if not m:
    raise errors.UnitParseError("Invalid format")

  value = float(m.groups()[0])

  unit = m.groups()[1]
  if unit:
    lcunit = unit.lower()
  else:
    lcunit = 'm'

  if lcunit in ('m', 'mb', 'mib'):
    # Value already in MiB
    pass

  elif lcunit in ('g', 'gb', 'gib'):
    value *= 1024

  elif lcunit in ('t', 'tb', 'tib'):
    value *= 1024 * 1024

  else:
    raise errors.UnitParseError("Unknown unit: %s" % unit)

  # Make sure we round up
  if int(value) < value:
    value += 1

  # Round up to the next multiple of 4
  value = int(value)
  if value % 4:
    value += 4 - value % 4

  return value

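# Illustrative examples (not part of the original module) for the two unit
# helpers above; values are in MiB and the arithmetic follows the code:
#
#   ParseUnit("512")      == 512      # no unit defaults to MiB
#   ParseUnit("1.5g")     == 1536     # 1.5 GiB, already a multiple of 4
#   ParseUnit("1.1")      == 4        # rounded up to the next multiple of 4
#   FormatUnit(1536, 'h') == "1.5G"   # automatic scaling
#   FormatUnit(1536, 'm') == "1536"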
def AddAuthorizedKey(file_name, key):
  """Adds an SSH public key to an authorized_keys file.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  key_fields = key.split()

  f = open(file_name, 'a+')
  try:
    nl = True
    for line in f:
      # Ignore whitespace changes
      if line.split() == key_fields:
        break
      nl = line.endswith('\n')
    else:
      if not nl:
        f.write("\n")
      f.write(key.rstrip('\r\n'))
      f.write("\n")
      f.flush()
  finally:
    f.close()


def RemoveAuthorizedKey(file_name, key):
  """Removes an SSH public key from an authorized_keys file.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  key_fields = key.split()

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          # Ignore whitespace changes while comparing lines
          if line.split() != key_fields:
            out.write(line)

        out.flush()
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def SetEtcHostsEntry(file_name, ip, hostname, aliases):
  """Sets the name of an IP address and hostname in /etc/hosts.

  @type file_name: str
  @param file_name: path to the file to modify (usually C{/etc/hosts})
  @type ip: str
  @param ip: the IP address
  @type hostname: str
  @param hostname: the hostname to be added
  @type aliases: list
  @param aliases: the list of aliases to add for the hostname

  """
  # FIXME: use WriteFile + fn rather than duplicating its efforts
  # Ensure aliases are unique
  aliases = UniqueSequence([hostname] + aliases)[1:]

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          fields = line.split()
          if fields and not fields[0].startswith('#') and ip == fields[0]:
            continue
          out.write(line)

        out.write("%s\t%s" % (ip, hostname))
        if aliases:
          out.write(" %s" % ' '.join(aliases))
        out.write('\n')

        out.flush()
        os.fsync(out)
        os.chmod(tmpname, 0644)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def AddHostToEtcHosts(hostname):
  """Wrapper around SetEtcHostsEntry.

  @type hostname: str
  @param hostname: a hostname that will be resolved and added to
      L{constants.ETC_HOSTS}

  """
  hi = HostInfo(name=hostname)
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])


def RemoveEtcHostsEntry(file_name, hostname):
  """Removes a hostname from /etc/hosts.

  IP addresses without names are removed from the file.

  @type file_name: str
  @param file_name: path to the file to modify (usually C{/etc/hosts})
  @type hostname: str
  @param hostname: the hostname to be removed

  """
  # FIXME: use WriteFile + fn rather than duplicating its efforts
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          fields = line.split()
          if len(fields) > 1 and not fields[0].startswith('#'):
            names = fields[1:]
            if hostname in names:
              while hostname in names:
                names.remove(hostname)
              if names:
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
              continue

          out.write(line)

        out.flush()
        os.fsync(out)
        os.chmod(tmpname, 0644)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def RemoveHostFromEtcHosts(hostname):
  """Wrapper around RemoveEtcHostsEntry.

  @type hostname: str
  @param hostname: hostname that will be resolved and its
      full and short name will be removed from
      L{constants.ETC_HOSTS}

  """
  hi = HostInfo(name=hostname)
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())


def TimestampForFilename():
  """Returns the current time formatted for filenames.

  The format doesn't contain colons as some shells and applications treat them
  as separators.

  """
  return time.strftime("%Y-%m-%d_%H_%M_%S")


def CreateBackup(file_name):
  """Creates a backup of a file.

  @type file_name: str
  @param file_name: file to be backed up
  @rtype: str
  @return: the path to the newly created backup
  @raise errors.ProgrammerError: for invalid file names

  """
  if not os.path.isfile(file_name):
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
                                file_name)

  prefix = ("%s.backup-%s." %
            (os.path.basename(file_name), TimestampForFilename()))
  dir_name = os.path.dirname(file_name)

  fsrc = open(file_name, 'rb')
  try:
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
    fdst = os.fdopen(fd, 'wb')
    try:
      logging.debug("Backing up %s at %s", file_name, backup_name)
      shutil.copyfileobj(fsrc, fdst)
    finally:
      fdst.close()
  finally:
    fsrc.close()

  return backup_name


def ShellQuote(value):
  """Quotes shell argument according to POSIX.

  @type value: str
  @param value: the argument to be quoted
  @rtype: str
  @return: the quoted value

  """
  if _re_shell_unquoted.match(value):
    return value
  else:
    return "'%s'" % value.replace("'", "'\\''")


def ShellQuoteArgs(args):
  """Quotes a list of shell arguments.

  @type args: list
  @param args: list of arguments to be quoted
  @rtype: str
  @return: the quoted arguments concatenated with spaces

  """
  return ' '.join([ShellQuote(i) for i in args])

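# Illustrative examples (not part of the original module) for the quoting
# helpers above:
#
#   ShellQuote("plain-arg_1")           -> plain-arg_1     (left untouched)
#   ShellQuote("with space")            -> 'with space'
#   ShellQuote("it's")                  -> 'it'\''s'       (POSIX-safe quoting)
#   ShellQuoteArgs(["ls", "-l", "a b"]) -> ls -l 'a b'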
def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
1657
  """Simple ping implementation using TCP connect(2).
1658

1659
  Check if the given IP is reachable by doing attempting a TCP connect
1660
  to it.
1661

1662
  @type target: str
1663
  @param target: the IP or hostname to ping
1664
  @type port: int
1665
  @param port: the port to connect to
1666
  @type timeout: int
1667
  @param timeout: the timeout on the connection attempt
1668
  @type live_port_needed: boolean
1669
  @param live_port_needed: whether a closed port will cause the
1670
      function to return failure, as if there was a timeout
1671
  @type source: str or None
1672
  @param source: if specified, will cause the connect to be made
1673
      from this specific source address; failures to bind other
1674
      than C{EADDRNOTAVAIL} will be ignored
1675

1676
  """
1677
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1678

    
1679
  success = False
1680

    
1681
  if source is not None:
1682
    try:
1683
      sock.bind((source, 0))
1684
    except socket.error, (errcode, _):
1685
      if errcode == errno.EADDRNOTAVAIL:
1686
        success = False
1687

    
1688
  sock.settimeout(timeout)
1689

    
1690
  try:
1691
    sock.connect((target, port))
1692
    sock.close()
1693
    success = True
1694
  except socket.timeout:
1695
    success = False
1696
  except socket.error, (errcode, _):
1697
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
1698

    
1699
  return success
1700

    
1701

    
1702
def OwnIpAddress(address):
1703
  """Check if the current host has the the given IP address.
1704

1705
  Currently this is done by TCP-pinging the address from the loopback
1706
  address.
1707

1708
  @type address: string
1709
  @param address: the address to check
1710
  @rtype: bool
1711
  @return: True if we own the address
1712

1713
  """
1714
  return TcpPing(address, constants.DEFAULT_NODED_PORT,
1715
                 source=constants.LOCALHOST_IP_ADDRESS)
1716

    
1717

    
1718
def ListVisibleFiles(path):
1719
  """Returns a list of visible files in a directory.
1720

1721
  @type path: str
1722
  @param path: the directory to enumerate
1723
  @rtype: list
1724
  @return: the list of all files not starting with a dot
1725
  @raise ProgrammerError: if L{path} is not an absolue and normalized path
1726

1727
  """
1728
  if not IsNormAbsPath(path):
1729
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
1730
                                 " absolute/normalized: '%s'" % path)
1731
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1732
  files.sort()
1733
  return files
1734

    
1735

    
1736
def GetHomeDir(user, default=None):
1737
  """Try to get the homedir of the given user.
1738

1739
  The user can be passed either as a string (denoting the name) or as
1740
  an integer (denoting the user id). If the user is not found, the
1741
  'default' argument is returned, which defaults to None.
1742

1743
  """
1744
  try:
1745
    if isinstance(user, basestring):
1746
      result = pwd.getpwnam(user)
1747
    elif isinstance(user, (int, long)):
1748
      result = pwd.getpwuid(user)
1749
    else:
1750
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1751
                                   type(user))
1752
  except KeyError:
1753
    return default
1754
  return result.pw_dir
1755

    
1756

    
1757
def NewUUID():
1758
  """Returns a random UUID.
1759

1760
  @note: This is a Linux-specific method as it uses the /proc
1761
      filesystem.
1762
  @rtype: str
1763

1764
  """
1765
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1766

    
1767

    
1768
def GenerateSecret(numbytes=20):
1769
  """Generates a random secret.
1770

1771
  This will generate a pseudo-random secret returning an hex string
1772
  (so that it can be used where an ASCII string is needed).
1773

1774
  @param numbytes: the number of bytes which will be represented by the returned
1775
      string (defaulting to 20, the length of a SHA1 hash)
1776
  @rtype: str
1777
  @return: a hex representation of the pseudo-random sequence
1778

1779
  """
1780
  return os.urandom(numbytes).encode('hex')
1781

    
1782

    
1783
def EnsureDirs(dirs):
1784
  """Make required directories, if they don't exist.
1785

1786
  @param dirs: list of tuples (dir_name, dir_mode)
1787
  @type dirs: list of (string, integer)
1788

1789
  """
1790
  for dir_name, dir_mode in dirs:
1791
    try:
1792
      os.mkdir(dir_name, dir_mode)
1793
    except EnvironmentError, err:
1794
      if err.errno != errno.EEXIST:
1795
        raise errors.GenericError("Cannot create needed directory"
1796
                                  " '%s': %s" % (dir_name, err))
1797
    try:
1798
      os.chmod(dir_name, dir_mode)
1799
    except EnvironmentError, err:
1800
      raise errors.GenericError("Cannot change directory permissions on"
1801
                                " '%s': %s" % (dir_name, err))
1802
    if not os.path.isdir(dir_name):
1803
      raise errors.GenericError("%s is not a directory" % dir_name)
1804

    
1805

    
1806
def ReadFile(file_name, size=-1):
1807
  """Reads a file.
1808

1809
  @type size: int
1810
  @param size: Read at most size bytes (if negative, entire file)
1811
  @rtype: str
1812
  @return: the (possibly partial) content of the file
1813

1814
  """
1815
  f = open(file_name, "r")
1816
  try:
1817
    return f.read(size)
1818
  finally:
1819
    f.close()
1820

    
1821

    
1822
def WriteFile(file_name, fn=None, data=None,
1823
              mode=None, uid=-1, gid=-1,
1824
              atime=None, mtime=None, close=True,
1825
              dry_run=False, backup=False,
1826
              prewrite=None, postwrite=None):
1827
  """(Over)write a file atomically.
1828

1829
  The file_name and either fn (a function taking one argument, the
1830
  file descriptor, and which should write the data to it) or data (the
1831
  contents of the file) must be passed. The other arguments are
1832
  optional and allow setting the file mode, owner and group, and the
1833
  mtime/atime of the file.
1834

1835
  If the function doesn't raise an exception, it has succeeded and the
1836
  target file has the new contents. If the function has raised an
1837
  exception, an existing target file should be unmodified and the
1838
  temporary file should be removed.
1839

1840
  @type file_name: str
1841
  @param file_name: the target filename
1842
  @type fn: callable
1843
  @param fn: content writing function, called with
1844
      file descriptor as parameter
1845
  @type data: str
1846
  @param data: contents of the file
1847
  @type mode: int
1848
  @param mode: file mode
1849
  @type uid: int
1850
  @param uid: the owner of the file
1851
  @type gid: int
1852
  @param gid: the group of the file
1853
  @type atime: int
1854
  @param atime: a custom access time to be set on the file
1855
  @type mtime: int
1856
  @param mtime: a custom modification time to be set on the file
1857
  @type close: boolean
1858
  @param close: whether to close file after writing it
1859
  @type prewrite: callable
1860
  @param prewrite: function to be called before writing content
1861
  @type postwrite: callable
1862
  @param postwrite: function to be called after writing content
1863

1864
  @rtype: None or int
1865
  @return: None if the 'close' parameter evaluates to True,
1866
      otherwise the file descriptor
1867

1868
  @raise errors.ProgrammerError: if any of the arguments are not valid
1869

1870
  """
1871
  if not os.path.isabs(file_name):
1872
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1873
                                 " absolute: '%s'" % file_name)
1874

    
1875
  if [fn, data].count(None) != 1:
1876
    raise errors.ProgrammerError("fn or data required")
1877

    
1878
  if [atime, mtime].count(None) == 1:
1879
    raise errors.ProgrammerError("Both atime and mtime must be either"
1880
                                 " set or None")
1881

    
1882
  if backup and not dry_run and os.path.isfile(file_name):
1883
    CreateBackup(file_name)
1884

    
1885
  dir_name, base_name = os.path.split(file_name)
1886
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1887
  do_remove = True
1888
  # here we need to make sure we remove the temp file, if any error
1889
  # leaves it in place
1890
  try:
1891
    if uid != -1 or gid != -1:
1892
      os.chown(new_name, uid, gid)
1893
    if mode:
1894
      os.chmod(new_name, mode)
1895
    if callable(prewrite):
1896
      prewrite(fd)
1897
    if data is not None:
1898
      os.write(fd, data)
1899
    else:
1900
      fn(fd)
1901
    if callable(postwrite):
1902
      postwrite(fd)
1903
    os.fsync(fd)
1904
    if atime is not None and mtime is not None:
1905
      os.utime(new_name, (atime, mtime))
1906
    if not dry_run:
1907
      os.rename(new_name, file_name)
1908
      do_remove = False
1909
  finally:
1910
    if close:
1911
      os.close(fd)
1912
      result = None
1913
    else:
1914
      result = fd
1915
    if do_remove:
1916
      RemoveFile(new_name)
1917

    
1918
  return result
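# Illustrative usage sketch (not part of the original module; the path and
# contents are hypothetical):
#
#   WriteFile("/tmp/example.conf", data="key = value\n", mode=0644,
#             backup=True)
#
# With the default close=True this returns None; passing close=False would
# instead return the still-open file descriptor of the new file.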
1919

    
1920

    
1921
def ReadOneLineFile(file_name, strict=False):
1922
  """Return the first non-empty line from a file.
1923

1924
  @type strict: boolean
1925
  @param strict: if True, abort if the file has more than one
1926
      non-empty line
1927

1928
  """
1929
  file_lines = ReadFile(file_name).splitlines()
1930
  full_lines = filter(bool, file_lines)
1931
  if not file_lines or not full_lines:
1932
    raise errors.GenericError("No data in one-liner file %s" % file_name)
1933
  elif strict and len(full_lines) > 1:
1934
    raise errors.GenericError("Too many lines in one-liner file %s" %
1935
                              file_name)
1936
  return full_lines[0]
1937

    
1938

    
1939
def FirstFree(seq, base=0):
1940
  """Returns the first non-existing integer from seq.
1941

1942
  The seq argument should be a sorted list of non-negative integers. The
1943
  first time the index of an element is smaller than the element
1944
  value, the index will be returned.
1945

1946
  The base argument is used to start at a different offset,
1947
  i.e. C{[3, 4, 6]} with I{base=3} will return 5.
1948

1949
  Example: C{[0, 1, 3]} will return I{2}.
1950

1951
  @type seq: sequence
1952
  @param seq: the sequence to be analyzed.
1953
  @type base: int
1954
  @param base: use this value as the base index of the sequence
1955
  @rtype: int
1956
  @return: the first non-used index in the sequence
1957

1958
  """
1959
  for idx, elem in enumerate(seq):
1960
    assert elem >= base, "Passed element is higher than base offset"
1961
    if elem > idx + base:
1962
      # idx is not used
1963
      return idx + base
1964
  return None
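# Examples (taken from the docstring above, shown here for quick reference):
#
#   FirstFree([0, 1, 3])          # -> 2
#   FirstFree([3, 4, 6], base=3)  # -> 5
#   FirstFree([0, 1, 2])          # -> None (no gap in the sequence)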
1965

    
1966

    
1967
def SingleWaitForFdCondition(fdobj, event, timeout):
1968
  """Waits for a condition to occur on the socket.
1969

1970
  Immediately returns at the first interruption.
1971

1972
  @type fdobj: integer or object supporting a fileno() method
1973
  @param fdobj: entity to wait for events on
1974
  @type event: integer
1975
  @param event: ORed condition (see select module)
1976
  @type timeout: float or None
1977
  @param timeout: Timeout in seconds
1978
  @rtype: int or None
1979
  @return: None for timeout, otherwise occurred conditions
1980

1981
  """
1982
  check = (event | select.POLLPRI |
1983
           select.POLLNVAL | select.POLLHUP | select.POLLERR)
1984

    
1985
  if timeout is not None:
1986
    # Poller object expects milliseconds
1987
    timeout *= 1000
1988

    
1989
  poller = select.poll()
1990
  poller.register(fdobj, event)
1991
  try:
1992
    # TODO: If the main thread receives a signal and we have no timeout, we
1993
    # could wait forever. This should check a global "quit" flag or something
1994
    # every so often.
1995
    io_events = poller.poll(timeout)
1996
  except select.error, err:
1997
    if err[0] != errno.EINTR:
1998
      raise
1999
    io_events = []
2000
  if io_events and io_events[0][1] & check:
2001
    return io_events[0][1]
2002
  else:
2003
    return None
2004

    
2005

    
2006
class FdConditionWaiterHelper(object):
2007
  """Retry helper for WaitForFdCondition.
2008

2009
  This class contains the retried and wait functions that make sure
2010
  WaitForFdCondition can continue waiting until the timeout is actually
2011
  expired.
2012

2013
  """
2014

    
2015
  def __init__(self, timeout):
2016
    self.timeout = timeout
2017

    
2018
  def Poll(self, fdobj, event):
2019
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
2020
    if result is None:
2021
      raise RetryAgain()
2022
    else:
2023
      return result
2024

    
2025
  def UpdateTimeout(self, timeout):
2026
    self.timeout = timeout
2027

    
2028

    
2029
def WaitForFdCondition(fdobj, event, timeout):
2030
  """Waits for a condition to occur on the socket.
2031

2032
  Retries until the timeout is expired, even if interrupted.
2033

2034
  @type fdobj: integer or object supporting a fileno() method
2035
  @param fdobj: entity to wait for events on
2036
  @type event: integer
2037
  @param event: ORed condition (see select module)
2038
  @type timeout: float or None
2039
  @param timeout: Timeout in seconds
2040
  @rtype: int or None
2041
  @return: None for timeout, otherwise occurred conditions
2042

2043
  """
2044
  if timeout is not None:
2045
    retrywaiter = FdConditionWaiterHelper(timeout)
2046
    try:
2047
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
2048
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
2049
    except RetryTimeout:
2050
      result = None
2051
  else:
2052
    result = None
2053
    while result is None:
2054
      result = SingleWaitForFdCondition(fdobj, event, timeout)
2055
  return result
2056

    
2057

    
2058
def UniqueSequence(seq):
2059
  """Returns a list with unique elements.
2060

2061
  Element order is preserved.
2062

2063
  @type seq: sequence
2064
  @param seq: the sequence with the source elements
2065
  @rtype: list
2066
  @return: list of unique elements from seq
2067

2068
  """
2069
  seen = set()
2070
  return [i for i in seq if i not in seen and not seen.add(i)]
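# Illustrative example (not part of the original module):
#
#   UniqueSequence([1, 2, 2, 3, 1])  # -> [1, 2, 3], first occurrences kept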
2071

    
2072

    
2073
def NormalizeAndValidateMac(mac):
2074
  """Normalizes and check if a MAC address is valid.
2075

2076
  Checks whether the supplied MAC address is formally correct, only
2077
  accepts colon separated format. Normalize it to all lower.
2078

2079
  @type mac: str
2080
  @param mac: the MAC to be validated
2081
  @rtype: str
2082
  @return: returns the normalized and validated MAC.
2083

2084
  @raise errors.OpPrereqError: If the MAC isn't valid
2085

2086
  """
2087
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
2088
  if not mac_check.match(mac):
2089
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
2090
                               mac, errors.ECODE_INVAL)
2091

    
2092
  return mac.lower()
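# Illustrative example (not part of the original module; the MAC is made up):
#
#   NormalizeAndValidateMac("AA:00:BB:11:CC:22")  # -> "aa:00:bb:11:cc:22"
#   NormalizeAndValidateMac("not-a-mac")          # raises errors.OpPrereqError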
2093

    
2094

    
2095
def TestDelay(duration):
2096
  """Sleep for a fixed amount of time.
2097

2098
  @type duration: float
2099
  @param duration: the sleep duration
2100
  @rtype: tuple
2101
  @return: (False, error message) for negative values, (True, None) otherwise
2102

2103
  """
2104
  if duration < 0:
2105
    return False, "Invalid sleep duration"
2106
  time.sleep(duration)
2107
  return True, None
2108

    
2109

    
2110
def _CloseFDNoErr(fd, retries=5):
2111
  """Close a file descriptor ignoring errors.
2112

2113
  @type fd: int
2114
  @param fd: the file descriptor
2115
  @type retries: int
2116
  @param retries: how many retries to make, in case we get any
2117
      other error than EBADF
2118

2119
  """
2120
  try:
2121
    os.close(fd)
2122
  except OSError, err:
2123
    if err.errno != errno.EBADF:
2124
      if retries > 0:
2125
        _CloseFDNoErr(fd, retries - 1)
2126
    # else either it's closed already or we're out of retries, so we
2127
    # ignore this and go on
2128

    
2129

    
2130
def CloseFDs(noclose_fds=None):
2131
  """Close file descriptors.
2132

2133
  This closes all file descriptors above 2 (i.e. except
2134
  stdin/out/err).
2135

2136
  @type noclose_fds: list or None
2137
  @param noclose_fds: if given, it denotes a list of file descriptor
2138
      that should not be closed
2139

2140
  """
2141
  # Default maximum for the number of available file descriptors.
2142
  if 'SC_OPEN_MAX' in os.sysconf_names:
2143
    try:
2144
      MAXFD = os.sysconf('SC_OPEN_MAX')
2145
      if MAXFD < 0:
2146
        MAXFD = 1024
2147
    except OSError:
2148
      MAXFD = 1024
2149
  else:
2150
    MAXFD = 1024
2151
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
2152
  if (maxfd == resource.RLIM_INFINITY):
2153
    maxfd = MAXFD
2154

    
2155
  # Iterate through and close all file descriptors (except the standard ones)
2156
  for fd in range(3, maxfd):
2157
    if noclose_fds and fd in noclose_fds:
2158
      continue
2159
    _CloseFDNoErr(fd)
2160

    
2161

    
2162
def Mlockall():
2163
  """Lock current process' virtual address space into RAM.
2164

2165
  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
2166
  see mlock(2) for more details. This function requires ctypes module.
2167

2168
  """
2169
  if ctypes is None:
2170
    logging.warning("Cannot set memory lock, ctypes module not found")
2171
    return
2172

    
2173
  libc = ctypes.cdll.LoadLibrary("libc.so.6")
2174
  if libc is None:
2175
    logging.error("Cannot set memory lock, ctypes cannot load libc")
2176
    return
2177

    
2178
  # Some older versions of ctypes don't have built-in functionality
2179
  # to access the errno global variable, where function error codes are stored.
2180
  # By declaring this variable as a pointer to an integer we can then access
2181
  # its value correctly, should the mlockall call fail, in order to see what
2182
  # the actual error code was.
2183
  # pylint: disable-msg=W0212
2184
  libc.__errno_location.restype = ctypes.POINTER(ctypes.c_int)
2185

    
2186
  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
2187
    # pylint: disable-msg=W0212
2188
    logging.error("Cannot set memory lock: %s",
2189
                  os.strerror(libc.__errno_location().contents.value))
2190
    return
2191

    
2192
  logging.debug("Memory lock set")
2193

    
2194

    
2195
def Daemonize(logfile):
2196
  """Daemonize the current process.
2197

2198
  This detaches the current process from the controlling terminal and
2199
  runs it in the background as a daemon.
2200

2201
  @type logfile: str
2202
  @param logfile: the logfile to which we should redirect stdout/stderr
2203
  @rtype: int
2204
  @return: the value zero
2205

2206
  """
2207
  # pylint: disable-msg=W0212
2208
  # yes, we really want os._exit
2209
  UMASK = 077
2210
  WORKDIR = "/"
2211

    
2212
  # this might fail
2213
  pid = os.fork()
2214
  if (pid == 0):  # The first child.
2215
    os.setsid()
2216
    # this might fail
2217
    pid = os.fork() # Fork a second child.
2218
    if (pid == 0):  # The second child.
2219
      os.chdir(WORKDIR)
2220
      os.umask(UMASK)
2221
    else:
2222
      # exit() or _exit()?  See below.
2223
      os._exit(0) # Exit parent (the first child) of the second child.
2224
  else:
2225
    os._exit(0) # Exit parent of the first child.
2226

    
2227
  for fd in range(3):
2228
    _CloseFDNoErr(fd)
2229
  i = os.open("/dev/null", os.O_RDONLY) # stdin
2230
  assert i == 0, "Can't close/reopen stdin"
2231
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
2232
  assert i == 1, "Can't close/reopen stdout"
2233
  # Duplicate standard output to standard error.
2234
  os.dup2(1, 2)
2235
  return 0
2236

    
2237

    
2238
def DaemonPidFileName(name):
2239
  """Compute a ganeti pid file absolute path
2240

2241
  @type name: str
2242
  @param name: the daemon name
2243
  @rtype: str
2244
  @return: the full path to the pidfile corresponding to the given
2245
      daemon name
2246

2247
  """
2248
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
2249

    
2250

    
2251
def EnsureDaemon(name):
2252
  """Check for and start daemon if not alive.
2253

2254
  """
2255
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
2256
  if result.failed:
2257
    logging.error("Can't start daemon '%s', failure %s, output: %s",
2258
                  name, result.fail_reason, result.output)
2259
    return False
2260

    
2261
  return True
2262

    
2263

    
2264
def StopDaemon(name):
2265
  """Stop daemon
2266

2267
  """
2268
  result = RunCmd([constants.DAEMON_UTIL, "stop", name])
2269
  if result.failed:
2270
    logging.error("Can't stop daemon '%s', failure %s, output: %s",
2271
                  name, result.fail_reason, result.output)
2272
    return False
2273

    
2274
  return True
2275

    
2276

    
2277
def WritePidFile(name):
2278
  """Write the current process pidfile.
2279

2280
  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}
2281

2282
  @type name: str
2283
  @param name: the daemon name to use
2284
  @raise errors.GenericError: if the pid file already exists and
2285
      points to a live process
2286

2287
  """
2288
  pid = os.getpid()
2289
  pidfilename = DaemonPidFileName(name)
2290
  if IsProcessAlive(ReadPidFile(pidfilename)):
2291
    raise errors.GenericError("%s contains a live process" % pidfilename)
2292

    
2293
  WriteFile(pidfilename, data="%d\n" % pid)
2294

    
2295

    
2296
def RemovePidFile(name):
2297
  """Remove the current process pidfile.
2298

2299
  Any errors are ignored.
2300

2301
  @type name: str
2302
  @param name: the daemon name used to derive the pidfile name
2303

2304
  """
2305
  pidfilename = DaemonPidFileName(name)
2306
  # TODO: we could check here that the file contains our pid
2307
  try:
2308
    RemoveFile(pidfilename)
2309
  except: # pylint: disable-msg=W0702
2310
    pass
2311

    
2312

    
2313
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
2314
                waitpid=False):
2315
  """Kill a process given by its pid.
2316

2317
  @type pid: int
2318
  @param pid: The PID to terminate.
2319
  @type signal_: int
2320
  @param signal_: The signal to send, by default SIGTERM
2321
  @type timeout: int
2322
  @param timeout: The timeout after which, if the process is still alive,
2323
                  a SIGKILL will be sent. If not positive, no such checking
2324
                  will be done
2325
  @type waitpid: boolean
2326
  @param waitpid: If true, we should waitpid on this process after
2327
      sending signals, since it's our own child and otherwise it
2328
      would remain as zombie
2329

2330
  """
2331
  def _helper(pid, signal_, wait):
2332
    """Simple helper to encapsulate the kill/waitpid sequence"""
2333
    os.kill(pid, signal_)
2334
    if wait:
2335
      try:
2336
        os.waitpid(pid, os.WNOHANG)
2337
      except OSError:
2338
        pass
2339

    
2340
  if pid <= 0:
2341
    # kill with pid=0 == suicide
2342
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
2343

    
2344
  if not IsProcessAlive(pid):
2345
    return
2346

    
2347
  _helper(pid, signal_, waitpid)
2348

    
2349
  if timeout <= 0:
2350
    return
2351

    
2352
  def _CheckProcess():
2353
    if not IsProcessAlive(pid):
2354
      return
2355

    
2356
    try:
2357
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
2358
    except OSError:
2359
      raise RetryAgain()
2360

    
2361
    if result_pid > 0:
2362
      return
2363

    
2364
    raise RetryAgain()
2365

    
2366
  try:
2367
    # Wait up to $timeout seconds
2368
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
2369
  except RetryTimeout:
2370
    pass
2371

    
2372
  if IsProcessAlive(pid):
2373
    # Kill process if it's still alive
2374
    _helper(pid, signal.SIGKILL, waitpid)
2375

    
2376

    
2377
def FindFile(name, search_path, test=os.path.exists):
2378
  """Look for a filesystem object in a given path.
2379

2380
  This is an abstract method to search for filesystem objects (files,
2381
  dirs) under a given search path.
2382

2383
  @type name: str
2384
  @param name: the name to look for
2385
  @type search_path: str
2386
  @param search_path: location to start at
2387
  @type test: callable
2388
  @param test: a function taking one argument that should return True
2389
      if a given object is valid; the default value is
2390
      os.path.exists, causing only existing files to be returned
2391
  @rtype: str or None
2392
  @return: full path to the object if found, None otherwise
2393

2394
  """
2395
  # validate the filename mask
2396
  if constants.EXT_PLUGIN_MASK.match(name) is None:
2397
    logging.critical("Invalid value passed for external script name: '%s'",
2398
                     name)
2399
    return None
2400

    
2401
  for dir_name in search_path:
2402
    # FIXME: investigate switch to PathJoin
2403
    item_name = os.path.sep.join([dir_name, name])
2404
    # check the user test and that we're indeed resolving to the given
2405
    # basename
2406
    if test(item_name) and os.path.basename(item_name) == name:
2407
      return item_name
2408
  return None
2409

    
2410

    
2411
def CheckVolumeGroupSize(vglist, vgname, minsize):
2412
  """Checks if the volume group list is valid.
2413

2414
  The function will check if a given volume group is in the list of
2415
  volume groups and has a minimum size.
2416

2417
  @type vglist: dict
2418
  @param vglist: dictionary of volume group names and their size
2419
  @type vgname: str
2420
  @param vgname: the volume group we should check
2421
  @type minsize: int
2422
  @param minsize: the minimum size we accept
2423
  @rtype: None or str
2424
  @return: None for success, otherwise the error message
2425

2426
  """
2427
  vgsize = vglist.get(vgname, None)
2428
  if vgsize is None:
2429
    return "volume group '%s' missing" % vgname
2430
  elif vgsize < minsize:
2431
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
2432
            (vgname, minsize, vgsize))
2433
  return None
2434

    
2435

    
2436
def SplitTime(value):
2437
  """Splits time as floating point number into a tuple.
2438

2439
  @param value: Time in seconds
2440
  @type value: int or float
2441
  @return: Tuple containing (seconds, microseconds)
2442

2443
  """
2444
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
2445

    
2446
  assert 0 <= seconds, \
2447
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2448
  assert 0 <= microseconds <= 999999, \
2449
    "Microseconds must be 0-999999, but are %s" % microseconds
2450

    
2451
  return (int(seconds), int(microseconds))
2452

    
2453

    
2454
def MergeTime(timetuple):
2455
  """Merges a tuple into time as a floating point number.
2456

2457
  @param timetuple: Time as tuple, (seconds, microseconds)
2458
  @type timetuple: tuple
2459
  @return: Time as a floating point number expressed in seconds
2460

2461
  """
2462
  (seconds, microseconds) = timetuple
2463

    
2464
  assert 0 <= seconds, \
2465
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2466
  assert 0 <= microseconds <= 999999, \
2467
    "Microseconds must be 0-999999, but are %s" % microseconds
2468

    
2469
  return float(seconds) + (float(microseconds) * 0.000001)
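# Illustrative round-trip example (not part of the original module):
#
#   SplitTime(1.5)           # -> (1, 500000)
#   MergeTime((1, 500000))   # -> 1.5
#
# Note that SplitTime truncates to whole microseconds, so the round trip is
# only exact up to that resolution.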
2470

    
2471

    
2472
def GetDaemonPort(daemon_name):
2473
  """Get the daemon port for this cluster.
2474

2475
  Note that this routine does not read a ganeti-specific file, but
2476
  instead uses C{socket.getservbyname} to allow pre-customization of
2477
  this parameter outside of Ganeti.
2478

2479
  @type daemon_name: string
2480
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
2481
  @rtype: int
2482

2483
  """
2484
  if daemon_name not in constants.DAEMONS_PORTS:
2485
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)
2486

    
2487
  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
2488
  try:
2489
    port = socket.getservbyname(daemon_name, proto)
2490
  except socket.error:
2491
    port = default_port
2492

    
2493
  return port
2494

    
2495

    
2496
class LogFileHandler(logging.FileHandler):
2497
  """Log handler that doesn't fallback to stderr.
2498

2499
  When an error occurs while writing to the logfile, logging.FileHandler tries
2500
  to log on stderr. This doesn't work in ganeti since stderr is redirected to
2501
  the logfile. This class instead reports such errors to /dev/console.
2502

2503
  """
2504
  def __init__(self, filename, mode="a", encoding=None):
2505
    """Open the specified file and use it as the stream for logging.
2506

2507
    Also open /dev/console to report errors while logging.
2508

2509
    """
2510
    logging.FileHandler.__init__(self, filename, mode, encoding)
2511
    self.console = open(constants.DEV_CONSOLE, "a")
2512

    
2513
  def handleError(self, record): # pylint: disable-msg=C0103
2514
    """Handle errors which occur during an emit() call.
2515

2516
    Try to handle errors with FileHandler method, if it fails write to
2517
    /dev/console.
2518

2519
    """
2520
    try:
2521
      logging.FileHandler.handleError(self, record)
2522
    except Exception: # pylint: disable-msg=W0703
2523
      try:
2524
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
2525
      except Exception: # pylint: disable-msg=W0703
2526
        # Log handler tried everything it could, now just give up
2527
        pass
2528

    
2529

    
2530
def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
2531
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
2532
                 console_logging=False):
2533
  """Configures the logging module.
2534

2535
  @type logfile: str
2536
  @param logfile: the filename to which we should log
2537
  @type debug: integer
2538
  @param debug: if greater than zero, enable debug messages, otherwise
2539
      only those at C{INFO} and above level
2540
  @type stderr_logging: boolean
2541
  @param stderr_logging: whether we should also log to the standard error
2542
  @type program: str
2543
  @param program: the name under which we should log messages
2544
  @type multithreaded: boolean
2545
  @param multithreaded: if True, will add the thread name to the log file
2546
  @type syslog: string
2547
  @param syslog: one of 'no', 'yes', 'only':
2548
      - if no, syslog is not used
2549
      - if yes, syslog is used (in addition to file-logging)
2550
      - if only, only syslog is used
2551
  @type console_logging: boolean
2552
  @param console_logging: if True, will use a FileHandler which falls back to
2553
      the system console if logging fails
2554
  @raise EnvironmentError: if we can't open the log file and
2555
      syslog/stderr logging is disabled
2556

2557
  """
2558
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
2559
  sft = program + "[%(process)d]:"
2560
  if multithreaded:
2561
    fmt += "/%(threadName)s"
2562
    sft += " (%(threadName)s)"
2563
  if debug:
2564
    fmt += " %(module)s:%(lineno)s"
2565
    # no debug info for syslog loggers
2566
  fmt += " %(levelname)s %(message)s"
2567
  # yes, we do want the textual level, as remote syslog will probably
2568
  # lose the error level, and it's easier to grep for it
2569
  sft += " %(levelname)s %(message)s"
2570
  formatter = logging.Formatter(fmt)
2571
  sys_fmt = logging.Formatter(sft)
2572

    
2573
  root_logger = logging.getLogger("")
2574
  root_logger.setLevel(logging.NOTSET)
2575

    
2576
  # Remove all previously setup handlers
2577
  for handler in root_logger.handlers:
2578
    handler.close()
2579
    root_logger.removeHandler(handler)
2580

    
2581
  if stderr_logging:
2582
    stderr_handler = logging.StreamHandler()
2583
    stderr_handler.setFormatter(formatter)
2584
    if debug:
2585
      stderr_handler.setLevel(logging.NOTSET)
2586
    else:
2587
      stderr_handler.setLevel(logging.CRITICAL)
2588
    root_logger.addHandler(stderr_handler)
2589

    
2590
  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
2591
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
2592
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
2593
                                                    facility)
2594
    syslog_handler.setFormatter(sys_fmt)
2595
    # Never enable debug over syslog
2596
    syslog_handler.setLevel(logging.INFO)
2597
    root_logger.addHandler(syslog_handler)
2598

    
2599
  if syslog != constants.SYSLOG_ONLY:
2600
    # this can fail, if the logging directories are not setup or we have
2601
    # a permission problem; in this case, it's best to log but ignore
2602
    # the error if stderr_logging is True, and if false we re-raise the
2603
    # exception since otherwise we could run but without any logs at all
2604
    try:
2605
      if console_logging:
2606
        logfile_handler = LogFileHandler(logfile)
2607
      else:
2608
        logfile_handler = logging.FileHandler(logfile)
2609
      logfile_handler.setFormatter(formatter)
2610
      if debug:
2611
        logfile_handler.setLevel(logging.DEBUG)
2612
      else:
2613
        logfile_handler.setLevel(logging.INFO)
2614
      root_logger.addHandler(logfile_handler)
2615
    except EnvironmentError:
2616
      if stderr_logging or syslog == constants.SYSLOG_YES:
2617
        logging.exception("Failed to enable logging to file '%s'", logfile)
2618
      else:
2619
        # we need to re-raise the exception
2620
        raise
2621

    
2622

    
2623
def IsNormAbsPath(path):
2624
  """Check whether a path is absolute and also normalized
2625

2626
  This avoids accepting paths like /dir/../../other/path as valid.
2627

2628
  """
2629
  return os.path.normpath(path) == path and os.path.isabs(path)
2630

    
2631

    
2632
def PathJoin(*args):
2633
  """Safe-join a list of path components.
2634

2635
  Requirements:
2636
      - the first argument must be an absolute path
2637
      - no component in the path must have backtracking (e.g. /../),
2638
        since we check for normalization at the end
2639

2640
  @param args: the path components to be joined
2641
  @raise ValueError: for invalid paths
2642

2643
  """
2644
  # ensure we're having at least one path passed in
2645
  assert args
2646
  # ensure the first component is an absolute and normalized path name
2647
  root = args[0]
2648
  if not IsNormAbsPath(root):
2649
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
2650
  result = os.path.join(*args)
2651
  # ensure that the whole path is normalized
2652
  if not IsNormAbsPath(result):
2653
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
2654
  # check that we're still under the original prefix
2655
  prefix = os.path.commonprefix([root, result])
2656
  if prefix != root:
2657
    raise ValueError("Error: path joining resulted in different prefix"
2658
                     " (%s != %s)" % (prefix, root))
2659
  return result
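# Illustrative examples (not part of the original module; paths are made up):
#
#   PathJoin("/var/lib/ganeti", "queue", "job-1")
#   # -> "/var/lib/ganeti/queue/job-1"
#
#   PathJoin("/var/lib/ganeti", "../passwd")   # raises ValueError, the result
#                                              # escapes the original prefix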
2660

    
2661

    
2662
def TailFile(fname, lines=20):
2663
  """Return the last lines from a file.
2664

2665
  @note: this function will only read and parse the last 4KB of
2666
      the file; if the lines are very long, it could be that less
2667
      than the requested number of lines are returned
2668

2669
  @param fname: the file name
2670
  @type lines: int
2671
  @param lines: the (maximum) number of lines to return
2672

2673
  """
2674
  fd = open(fname, "r")
2675
  try:
2676
    fd.seek(0, 2)
2677
    pos = fd.tell()
2678
    pos = max(0, pos-4096)
2679
    fd.seek(pos, 0)
2680
    raw_data = fd.read()
2681
  finally:
2682
    fd.close()
2683

    
2684
  rows = raw_data.splitlines()
2685
  return rows[-lines:]
2686

    
2687

    
2688
def FormatTimestampWithTZ(secs):
2689
  """Formats a Unix timestamp with the local timezone.
2690

2691
  """
2692
  return time.strftime("%F %T %Z", time.gmtime(secs))
2693

    
2694

    
2695
def _ParseAsn1Generalizedtime(value):
2696
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
2697

2698
  @type value: string
2699
  @param value: ASN1 GENERALIZEDTIME timestamp
2700

2701
  """
2702
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
2703
  if m:
2704
    # We have an offset
2705
    asn1time = m.group(1)
2706
    hours = int(m.group(2))
2707
    minutes = int(m.group(3))
2708
    utcoffset = (60 * hours) + minutes
2709
  else:
2710
    if not value.endswith("Z"):
2711
      raise ValueError("Missing timezone")
2712
    asn1time = value[:-1]
2713
    utcoffset = 0
2714

    
2715
  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
2716

    
2717
  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)
2718

    
2719
  return calendar.timegm(tt.utctimetuple())
2720

    
2721

    
2722
def GetX509CertValidity(cert):
2723
  """Returns the validity period of the certificate.
2724

2725
  @type cert: OpenSSL.crypto.X509
2726
  @param cert: X509 certificate object
2727

2728
  """
2729
  # The get_notBefore and get_notAfter functions are only supported in
2730
  # pyOpenSSL 0.7 and above.
2731
  try:
2732
    get_notbefore_fn = cert.get_notBefore
2733
  except AttributeError:
2734
    not_before = None
2735
  else:
2736
    not_before_asn1 = get_notbefore_fn()
2737

    
2738
    if not_before_asn1 is None:
2739
      not_before = None
2740
    else:
2741
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)
2742

    
2743
  try:
2744
    get_notafter_fn = cert.get_notAfter
2745
  except AttributeError:
2746
    not_after = None
2747
  else:
2748
    not_after_asn1 = get_notafter_fn()
2749

    
2750
    if not_after_asn1 is None:
2751
      not_after = None
2752
    else:
2753
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)
2754

    
2755
  return (not_before, not_after)
2756

    
2757

    
2758
def _VerifyCertificateInner(expired, not_before, not_after, now,
2759
                            warn_days, error_days):
2760
  """Verifies certificate validity.
2761

2762
  @type expired: bool
2763
  @param expired: Whether pyOpenSSL considers the certificate as expired
2764
  @type not_before: number or None
2765
  @param not_before: Unix timestamp before which certificate is not valid
2766
  @type not_after: number or None
2767
  @param not_after: Unix timestamp after which certificate is invalid
2768
  @type now: number
2769
  @param now: Current time as Unix timestamp
2770
  @type warn_days: number or None
2771
  @param warn_days: How many days before expiration a warning should be reported
2772
  @type error_days: number or None
2773
  @param error_days: How many days before expiration an error should be reported
2774

2775
  """
2776
  if expired:
2777
    msg = "Certificate is expired"
2778

    
2779
    if not_before is not None and not_after is not None:
2780
      msg += (" (valid from %s to %s)" %
2781
              (FormatTimestampWithTZ(not_before),
2782
               FormatTimestampWithTZ(not_after)))
2783
    elif not_before is not None:
2784
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
2785
    elif not_after is not None:
2786
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)
2787

    
2788
    return (CERT_ERROR, msg)
2789

    
2790
  elif not_before is not None and not_before > now:
2791
    return (CERT_WARNING,
2792
            "Certificate not yet valid (valid from %s)" %
2793
            FormatTimestampWithTZ(not_before))
2794

    
2795
  elif not_after is not None:
2796
    remaining_days = int((not_after - now) / (24 * 3600))
2797

    
2798
    msg = "Certificate expires in about %d days" % remaining_days
2799

    
2800
    if error_days is not None and remaining_days <= error_days:
2801
      return (CERT_ERROR, msg)
2802

    
2803
    if warn_days is not None and remaining_days <= warn_days:
2804
      return (CERT_WARNING, msg)
2805

    
2806
  return (None, None)
2807

    
2808

    
2809
def VerifyX509Certificate(cert, warn_days, error_days):
2810
  """Verifies a certificate for LUVerifyCluster.
2811

2812
  @type cert: OpenSSL.crypto.X509
2813
  @param cert: X509 certificate object
2814
  @type warn_days: number or None
2815
  @param warn_days: How many days before expiration a warning should be reported
2816
  @type error_days: number or None
2817
  @param error_days: How many days before expiration an error should be reported
2818

2819
  """
2820
  # Depending on the pyOpenSSL version, this can just return (None, None)
2821
  (not_before, not_after) = GetX509CertValidity(cert)
2822

    
2823
  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
2824
                                 time.time(), warn_days, error_days)
2825

    
2826

    
2827
def SignX509Certificate(cert, key, salt):
2828
  """Sign a X509 certificate.
2829

2830
  An RFC822-like signature header is added in front of the certificate.
2831

2832
  @type cert: OpenSSL.crypto.X509
2833
  @param cert: X509 certificate object
2834
  @type key: string
2835
  @param key: Key for HMAC
2836
  @type salt: string
2837
  @param salt: Salt for HMAC
2838
  @rtype: string
2839
  @return: Serialized and signed certificate in PEM format
2840

2841
  """
2842
  if not VALID_X509_SIGNATURE_SALT.match(salt):
2843
    raise errors.GenericError("Invalid salt: %r" % salt)
2844

    
2845
  # Dumping as PEM here ensures the certificate is in a sane format
2846
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2847

    
2848
  return ("%s: %s/%s\n\n%s" %
2849
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
2850
           Sha1Hmac(key, cert_pem, salt=salt),
2851
           cert_pem))
2852

    
2853

    
2854
def _ExtractX509CertificateSignature(cert_pem):
2855
  """Helper function to extract signature from X509 certificate.
2856

2857
  """
2858
  # Extract signature from original PEM data
2859
  for line in cert_pem.splitlines():
2860
    if line.startswith("---"):
2861
      break
2862

    
2863
    m = X509_SIGNATURE.match(line.strip())
2864
    if m:
2865
      return (m.group("salt"), m.group("sign"))
2866

    
2867
  raise errors.GenericError("X509 certificate signature is missing")
2868

    
2869

    
2870
def LoadSignedX509Certificate(cert_pem, key):
2871
  """Verifies a signed X509 certificate.
2872

2873
  @type cert_pem: string
2874
  @param cert_pem: Certificate in PEM format and with signature header
2875
  @type key: string
2876
  @param key: Key for HMAC
2877
  @rtype: tuple; (OpenSSL.crypto.X509, string)
2878
  @return: X509 certificate object and salt
2879

2880
  """
2881
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)
2882

    
2883
  # Load certificate
2884
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
2885

    
2886
  # Dump again to ensure it's in a sane format
2887
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2888

    
2889
  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
2890
    raise errors.GenericError("X509 certificate signature is invalid")
2891

    
2892
  return (cert, salt)
2893

    
2894

    
2895
def Sha1Hmac(key, text, salt=None):
2896
  """Calculates the HMAC-SHA1 digest of a text.
2897

2898
  HMAC is defined in RFC2104.
2899

2900
  @type key: string
2901
  @param key: Secret key
2902
  @type text: string
2903

2904
  """
2905
  if salt:
2906
    salted_text = salt + text
2907
  else:
2908
    salted_text = text
2909

    
2910
  return hmac.new(key, salted_text, compat.sha1).hexdigest()
2911

    
2912

    
2913
def VerifySha1Hmac(key, text, digest, salt=None):
2914
  """Verifies the HMAC-SHA1 digest of a text.
2915

2916
  HMAC is defined in RFC2104.
2917

2918
  @type key: string
2919
  @param key: Secret key
2920
  @type text: string
2921
  @type digest: string
2922
  @param digest: Expected digest
2923
  @rtype: bool
2924
  @return: Whether HMAC-SHA1 digest matches
2925

2926
  """
2927
  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()
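# Illustrative example (not part of the original module; key and salt values
# are made up):
#
#   digest = Sha1Hmac("my-secret-key", "some text", salt="abc123")
#   assert VerifySha1Hmac("my-secret-key", "some text", digest, salt="abc123")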
2928

    
2929

    
2930
def SafeEncode(text):
2931
  """Return a 'safe' version of a source string.
2932

2933
  This function mangles the input string and returns a version that
2934
  should be safe to display/encode as ASCII. To this end, we first
2935
  convert it to ASCII using the 'backslashreplace' encoding which
2936
  should get rid of any non-ASCII chars, and then we process it
2937
  through a loop copied from the string repr sources in Python; we
2938
  don't use string_escape anymore since that escapes single quotes and
2939
  backslashes too, and that is too much; and that escaping is not
2940
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
2941

2942
  @type text: str or unicode
2943
  @param text: input data
2944
  @rtype: str
2945
  @return: a safe version of text
2946

2947
  """
2948
  if isinstance(text, unicode):
2949
    # only if unicode; if str already, we handle it below
2950
    text = text.encode('ascii', 'backslashreplace')
2951
  resu = ""
2952
  for char in text:
2953
    c = ord(char)
2954
    if char == '\t':
2955
      resu += r'\t'
2956
    elif char == '\n':
2957
      resu += r'\n'
2958
    elif char == '\r':
2959
      resu += r'\r'
2960
    elif c < 32 or c >= 127: # non-printable
2961
      resu += "\\x%02x" % (c & 0xff)
2962
    else:
2963
      resu += char
2964
  return resu
2965

    
2966

    
2967
def UnescapeAndSplit(text, sep=","):
2968
  """Split and unescape a string based on a given separator.
2969

2970
  This function splits a string based on a separator where the
2971
  separator itself can be escaped in order to be part of an
2972
  element. The escaping rules are (assuming comma is the
2973
  separator):
2974
    - a plain , separates the elements
2975
    - a sequence \\\\, (double backslash plus comma) is handled as a
2976
      backslash plus a separator comma
2977
    - a sequence \, (backslash plus comma) is handled as a
2978
      non-separator comma
2979

2980
  @type text: string
2981
  @param text: the string to split
2982
  @type sep: string
2983
  @param sep: the separator
2984
  @rtype: list
2985
  @return: a list of strings
2986

2987
  """
2988
  # we split the list by sep (with no escaping at this stage)
2989
  slist = text.split(sep)
2990
  # next, we revisit the elements and if any of them ended with an odd
2991
  # number of backslashes, then we join it with the next
2992
  rlist = []
2993
  while slist:
2994
    e1 = slist.pop(0)
2995
    if e1.endswith("\\"):
2996
      num_b = len(e1) - len(e1.rstrip("\\"))
2997
      if num_b % 2 == 1:
2998
        e2 = slist.pop(0)
2999
        # here the backslashes remain (all), and will be reduced in
3000
        # the next step
3001
        rlist.append(e1 + sep + e2)
3002
        continue
3003
    rlist.append(e1)
3004
  # finally, replace backslash-something with something
3005
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
3006
  return rlist
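# Illustrative example (not part of the original module):
#
#   UnescapeAndSplit("a,b\\,c,d")   # the middle comma is escaped
#   # -> ["a", "b,c", "d"]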
3007

    
3008

    
3009
def CommaJoin(names):
3010
  """Nicely join a set of identifiers.
3011

3012
  @param names: set, list or tuple
3013
  @return: a string with the formatted results
3014

3015
  """
3016
  return ", ".join([str(val) for val in names])
3017

    
3018

    
3019
def BytesToMebibyte(value):
3020
  """Converts bytes to mebibytes.
3021

3022
  @type value: int
3023
  @param value: Value in bytes
3024
  @rtype: int
3025
  @return: Value in mebibytes
3026

3027
  """
3028
  return int(round(value / (1024.0 * 1024.0), 0))
3029

    
3030

    
3031
def CalculateDirectorySize(path):
3032
  """Calculates the size of a directory recursively.
3033

3034
  @type path: string
3035
  @param path: Path to directory
3036
  @rtype: int
3037
  @return: Size in mebibytes
3038

3039
  """
3040
  size = 0
3041

    
3042
  for (curpath, _, files) in os.walk(path):
3043
    for filename in files:
3044
      st = os.lstat(PathJoin(curpath, filename))
3045
      size += st.st_size
3046

    
3047
  return BytesToMebibyte(size)
3048

    
3049

    
3050
def GetFilesystemStats(path):
3051
  """Returns the total and free space on a filesystem.
3052

3053
  @type path: string
3054
  @param path: Path on filesystem to be examined
3055
  @rtype: int
3056
  @return: tuple of (Total space, Free space) in mebibytes
3057

3058
  """
3059
  st = os.statvfs(path)
3060

    
3061
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
3062
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
3063
  return (tsize, fsize)
3064

    
3065

    
3066
def RunInSeparateProcess(fn, *args):
3067
  """Runs a function in a separate process.
3068

3069
  Note: Only boolean return values are supported.
3070

3071
  @type fn: callable
3072
  @param fn: Function to be called
3073
  @rtype: bool
3074
  @return: Function's result
3075

3076
  """
3077
  pid = os.fork()
3078
  if pid == 0:
3079
    # Child process
3080
    try:
3081
      # In case the function uses temporary files
3082
      ResetTempfileModule()
3083

    
3084
      # Call function
3085
      result = int(bool(fn(*args)))
3086
      assert result in (0, 1)
3087
    except: # pylint: disable-msg=W0702
3088
      logging.exception("Error while calling function in separate process")
3089
      # 0 and 1 are reserved for the return value
3090
      result = 33
3091

    
3092
    os._exit(result) # pylint: disable-msg=W0212
3093

    
3094
  # Parent process
3095

    
3096
  # Avoid zombies and check exit code
3097
  (_, status) = os.waitpid(pid, 0)
3098

    
3099
  if os.WIFSIGNALED(status):
3100
    exitcode = None
3101
    signum = os.WTERMSIG(status)
3102
  else:
3103
    exitcode = os.WEXITSTATUS(status)
3104
    signum = None
3105

    
3106
  if not (exitcode in (0, 1) and signum is None):
3107
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
3108
                              (exitcode, signum))
3109

    
3110
  return bool(exitcode)
3111

    
3112

    
3113
def IgnoreSignals(fn, *args, **kwargs):
3114
  """Tries to call a function ignoring failures due to EINTR.
3115

3116
  """
3117
  try:
3118
    return fn(*args, **kwargs)
3119
  except EnvironmentError, err:
3120
    if err.errno == errno.EINTR:
3121
      return None
3122
    else:
3123
      raise
3124
  except (select.error, socket.error), err:
3125
    # In python 2.6 and above select.error is an IOError, so it's handled
3126
    # above, in 2.5 and below it's not, and it's handled here.
3127
    if err.args and err.args[0] == errno.EINTR:
3128
      return None
3129
    else:
3130
      raise
3131

    
3132

    
3133
def LockedMethod(fn):
3134
  """Synchronized object access decorator.
3135

3136
  This decorator is intended to protect access to an object using the
3137
  object's own lock which is hardcoded to '_lock'.
3138

3139
  """
3140
  def _LockDebug(*args, **kwargs):
3141
    if debug_locks:
3142
      logging.debug(*args, **kwargs)
3143

    
3144
  def wrapper(self, *args, **kwargs):
3145
    # pylint: disable-msg=W0212
3146
    assert hasattr(self, '_lock')
3147
    lock = self._lock
3148
    _LockDebug("Waiting for %s", lock)
3149
    lock.acquire()
3150
    try:
3151
      _LockDebug("Acquired %s", lock)
3152
      result = fn(self, *args, **kwargs)
3153
    finally:
3154
      _LockDebug("Releasing %s", lock)
3155
      lock.release()
3156
      _LockDebug("Released %s", lock)
3157
    return result
3158
  return wrapper
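# Illustrative usage sketch (not part of the original module; the class and
# method names are hypothetical):
#
#   import threading
#
#   class _Counter(object):
#     def __init__(self):
#       self._lock = threading.Lock()
#       self.value = 0
#
#     @LockedMethod
#     def Increment(self):
#       self.value += 1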
3159

    
3160

    
3161
def LockFile(fd):
3162
  """Locks a file using POSIX locks.
3163

3164
  @type fd: int
3165
  @param fd: the file descriptor we need to lock
3166

3167
  """
3168
  try:
3169
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
3170
  except IOError, err:
3171
    if err.errno == errno.EAGAIN:
3172
      raise errors.LockError("File already locked")
3173
    raise
3174

    
3175

    
3176
def FormatTime(val):
3177
  """Formats a time value.
3178

3179
  @type val: float or None
3180
  @param val: the timestamp as returned by time.time()
3181
  @return: a string value or N/A if we don't have a valid timestamp
3182

3183
  """
3184
  if val is None or not isinstance(val, (int, float)):
3185
    return "N/A"
3186
  # these two format codes work on Linux, but they are not guaranteed on all
3187
  # platforms
3188
  return time.strftime("%F %T", time.localtime(val))
3189

    
3190

    
3191
def FormatSeconds(secs):
3192
  """Formats seconds for easier reading.
3193

3194
  @type secs: number
3195
  @param secs: Number of seconds
3196
  @rtype: string
3197
  @return: Formatted seconds (e.g. "2d 9h 19m 49s")
3198

3199
  """
3200
  parts = []
3201

    
3202
  secs = round(secs, 0)
3203

    
3204
  if secs > 0:
3205
    # Negative values would be a bit tricky
3206
    for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]:
3207
      (complete, secs) = divmod(secs, one)
3208
      if complete or parts:
3209
        parts.append("%d%s" % (complete, unit))
3210

    
3211
  parts.append("%ds" % secs)
3212

    
3213
  return " ".join(parts)
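# Example matching the docstring above:
#
#   FormatSeconds(2 * 86400 + 9 * 3600 + 19 * 60 + 49)  # -> "2d 9h 19m 49s"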
3214

    
3215

    
3216
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
3217
  """Reads the watcher pause file.
3218

3219
  @type filename: string
3220
  @param filename: Path to watcher pause file
3221
  @type now: None, float or int
3222
  @param now: Current time as Unix timestamp
3223
  @type remove_after: int
3224
  @param remove_after: Remove watcher pause file after specified amount of
3225
    seconds past the pause end time
3226

3227
  """
3228
  if now is None:
3229
    now = time.time()
3230

    
3231
  try:
3232
    value = ReadFile(filename)
3233
  except IOError, err:
3234
    if err.errno != errno.ENOENT:
3235
      raise
3236
    value = None
3237

    
3238
  if value is not None:
3239
    try:
3240
      value = int(value)
3241
    except ValueError:
3242
      logging.warning(("Watcher pause file (%s) contains invalid value,"
3243
                       " removing it"), filename)
3244
      RemoveFile(filename)
3245
      value = None
3246

    
3247
    if value is not None:
3248
      # Remove file if it's outdated
3249
      if now > (value + remove_after):
3250
        RemoveFile(filename)
3251
        value = None
3252

    
3253
      elif now > value:
3254
        value = None
3255

    
3256
  return value
3257

    
3258

    
3259
class RetryTimeout(Exception):
3260
  """Retry loop timed out.
3261

3262
  Any arguments which were passed by the retried function to RetryAgain will be
3263
  preserved in RetryTimeout, if it is raised. If such an argument was an
3264
  exception, the RaiseInner helper method will re-raise it.
3265

3266
  """
3267
  def RaiseInner(self):
3268
    if self.args and isinstance(self.args[0], Exception):
3269
      raise self.args[0]
3270
    else:
3271
      raise RetryTimeout(*self.args)
3272

    
3273

    
3274
class RetryAgain(Exception):
3275
  """Retry again.
3276

3277
  Any arguments passed to RetryAgain will be preserved, if a timeout occurs, as
3278
  arguments to RetryTimeout. If an exception is passed, the RaiseInner() method
3279
  of the RetryTimeout() exception can be used to re-raise it.
3280

3281
  """
3282

    
3283

    
3284
class _RetryDelayCalculator(object):
3285
  """Calculator for increasing delays.
3286

3287
  """
3288
  __slots__ = [
3289
    "_factor",
3290
    "_limit",
3291
    "_next",
3292
    "_start",
3293
    ]
3294

    
3295
  def __init__(self, start, factor, limit):
3296
    """Initializes this class.
3297

3298
    @type start: float
3299
    @param start: Initial delay
3300
    @type factor: float
3301
    @param factor: Factor for delay increase
3302
    @type limit: float or None
3303
    @param limit: Upper limit for delay or None for no limit
3304

3305
    """
3306
    assert start > 0.0
3307
    assert factor >= 1.0
3308
    assert limit is None or limit >= 0.0
3309

    
3310
    self._start = start
3311
    self._factor = factor
3312
    self._limit = limit
3313

    
3314
    self._next = start
3315

    
3316
  def __call__(self):
3317
    """Returns current delay and calculates the next one.
3318

3319
    """
3320
    current = self._next
3321

    
3322
    # Update for next run
3323
    if self._limit is None or self._next < self._limit:
3324
      self._next = min(self._limit, self._next * self._factor)
3325

    
3326
    return current
3327

    
3328

    
3329
#: Special delay to specify whole remaining timeout
3330
RETRY_REMAINING_TIME = object()
3331

    
3332

    
3333
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
3334
          _time_fn=time.time):
3335
  """Call a function repeatedly until it succeeds.
3336

3337
  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
3338
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
3339
  total of C{timeout} seconds, this function throws L{RetryTimeout}.
3340

3341
  C{delay} can be one of the following:
3342
    - callable returning the delay length as a float
3343
    - Tuple of (start, factor, limit)
3344
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
3345
      useful when overriding L{wait_fn} to wait for an external event)
3346
    - A static delay as a number (int or float)
3347

3348
  @type fn: callable
3349
  @param fn: Function to be called
3350
  @param delay: Either a callable (returning the delay), a tuple of (start,
3351
                factor, limit) (see L{_RetryDelayCalculator}),
3352
                L{RETRY_REMAINING_TIME} or a number (int or float)
3353
  @type timeout: float
3354
  @param timeout: Total timeout
3355
  @type wait_fn: callable
3356
  @param wait_fn: Waiting function
3357
  @return: Return value of function
3358

3359
  """
3360
  assert callable(fn)
3361
  assert callable(wait_fn)
3362
  assert callable(_time_fn)
3363

    
3364
  if args is None:
3365
    args = []
3366

    
3367
  end_time = _time_fn() + timeout
3368

    
3369
  if callable(delay):
3370
    # External function to calculate delay
3371
    calc_delay = delay
3372

    
3373
  elif isinstance(delay, (tuple, list)):
3374
    # Increasing delay with optional upper boundary
3375
    (start, factor, limit) = delay
3376
    calc_delay = _RetryDelayCalculator(start, factor, limit)
3377

    
3378
  elif delay is RETRY_REMAINING_TIME:
3379
    # Always use the remaining time
3380
    calc_delay = None
3381

    
3382
  else:
3383
    # Static delay
3384
    calc_delay = lambda: delay
3385

    
3386
  assert calc_delay is None or callable(calc_delay)
3387

    
3388
  while True:
3389
    retry_args = []
3390
    try:
3391
      # pylint: disable-msg=W0142
3392
      return fn(*args)
3393
    except RetryAgain, err:
3394
      retry_args = err.args
3395
    except RetryTimeout:
3396
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
3397
                                   " handle RetryTimeout")
3398

    
3399
    remaining_time = end_time - _time_fn()
3400

    
3401
    if remaining_time < 0.0:
3402
      # pylint: disable-msg=W0142
3403
      raise RetryTimeout(*retry_args)
3404

    
3405
    assert remaining_time >= 0.0
3406

    
3407
    if calc_delay is None:
3408
      wait_fn(remaining_time)
3409
    else:
3410
      current_delay = calc_delay()
3411
      if current_delay > 0.0:
3412
        wait_fn(current_delay)
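# Illustrative usage sketch (not part of the original module; _IsDone is a
# hypothetical predicate):
#
#   def _CheckDone():
#     if not _IsDone():
#       raise RetryAgain()
#
#   try:
#     # retry with delays growing from 1s by factor 1.5, capped at 10s,
#     # for at most 60 seconds overall
#     Retry(_CheckDone, (1.0, 1.5, 10.0), 60.0)
#   except RetryTimeout:
#     pass  # gave up after the timeout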
3413

    
3414

    
3415
def GetClosedTempfile(*args, **kwargs):
3416
  """Creates a temporary file and returns its path.
3417

3418
  """
3419
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
3420
  _CloseFDNoErr(fd)
3421
  return path


def GenerateSelfSignedX509Cert(common_name, validity):
  """Generates a self-signed X509 certificate.

  @type common_name: string
  @param common_name: commonName value
  @type validity: int
  @param validity: Validity for certificate in seconds

  """
  # Create private and public key
  key = OpenSSL.crypto.PKey()
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)

  # Create self-signed certificate
  cert = OpenSSL.crypto.X509()
  if common_name:
    cert.get_subject().CN = common_name
  cert.set_serial_number(1)
  cert.gmtime_adj_notBefore(0)
  cert.gmtime_adj_notAfter(validity)
  cert.set_issuer(cert.get_subject())
  cert.set_pubkey(key)
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)

  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  return (key_pem, cert_pem)
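

# Illustrative sketch (not part of the original module): generating a
# one-day certificate and loading it back with pyOpenSSL to inspect the
# subject; the common name is a hypothetical example value.
def _ExampleSelfSignedCert():
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert("node1.example.com",
                                                   24 * 60 * 60)
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                         cert_pem)
  assert cert.get_subject().CN == "node1.example.com"
  return (key_pem, cert_pem)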


def GenerateSelfSignedSslCert(filename, validity=(5 * 365)):
  """Legacy function to generate self-signed X509 certificate.

  """
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(None,
                                                   validity * 24 * 60 * 60)

  WriteFile(filename, mode=0400, data=key_pem + cert_pem)


class FileLock(object):
  """Utility class for file locks.

  """
  def __init__(self, fd, filename):
    """Constructor for FileLock.

    @type fd: file
    @param fd: File object
    @type filename: str
    @param filename: Path of the file opened at I{fd}

    """
    self.fd = fd
    self.filename = filename

  @classmethod
  def Open(cls, filename):
    """Creates and opens a file to be used as a file-based lock.

    @type filename: string
    @param filename: path to the file to be locked

    """
    # Using "os.open" is necessary to allow both opening existing file
    # read/write and creating if not existing. Vanilla "open" will truncate an
    # existing file -or- allow creating if not existing.
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
               filename)

  def __del__(self):
    self.Close()

  def Close(self):
    """Close the file and release the lock.

    """
    if hasattr(self, "fd") and self.fd:
      self.fd.close()
      self.fd = None

  def _flock(self, flag, blocking, timeout, errmsg):
    """Wrapper for fcntl.flock.

    @type flag: int
    @param flag: operation flag
    @type blocking: bool
    @param blocking: whether the operation should be done in blocking mode.
    @type timeout: None or float
    @param timeout: for how long the operation should be retried (implies
                    non-blocking mode).
    @type errmsg: string
    @param errmsg: error message in case operation fails.

    """
    assert self.fd, "Lock was closed"
    assert timeout is None or timeout >= 0, \
      "If specified, timeout must be positive"
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)


class LineSplitter:
  """Splits data chunks into lines separated by newline.

  Instances provide a file-like interface.

  """
  def __init__(self, line_fn, *args):
    """Initializes this class.

    @type line_fn: callable
    @param line_fn: Function called for each line, first parameter is line
    @param args: Extra arguments for L{line_fn}

    """
    assert callable(line_fn)

    if args:
      # Python 2.4 doesn't have functools.partial yet
      self._line_fn = \
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
    else:
      self._line_fn = line_fn

    self._lines = collections.deque()
    self._buffer = ""

  def write(self, data):
    parts = (self._buffer + data).split("\n")
    self._buffer = parts.pop()
    self._lines.extend(parts)

  def flush(self):
    while self._lines:
      self._line_fn(self._lines.popleft().rstrip("\r\n"))

  def close(self):
    self.flush()
    if self._buffer:
      self._line_fn(self._buffer)
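

# Illustrative sketch (not part of the original module): collecting complete
# lines from arbitrarily-sized data chunks; the local "collected" list and
# the sample chunks are made up for the example.
def _ExampleLineSplitter():
  collected = []
  splitter = LineSplitter(collected.append)
  try:
    splitter.write("first li")
    splitter.write("ne\nsecond line\nthird")
  finally:
    splitter.close()
  # collected is now ["first line", "second line", "third"]
  return collected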


def SignalHandled(signums):
  """Signal handling decorator.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      assert 'signal_handlers' not in kwargs or \
             kwargs['signal_handlers'] is None or \
             isinstance(kwargs['signal_handlers'], dict), \
             "Wrong signal_handlers parameter in original function call"
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
        signal_handlers = kwargs['signal_handlers']
      else:
        signal_handlers = {}
        kwargs['signal_handlers'] = signal_handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          signal_handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        sighandler.Reset()
    return sig_function
  return wrap
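

# Illustrative sketch (not part of the original module): a function wrapped
# with SignalHandled receives the installed handlers through the
# "signal_handlers" keyword argument and can poll their "called" flags; the
# sleep-based loop below is a hypothetical worker.
@SignalHandled([signal.SIGTERM])
def _ExampleSignalHandledWorker(signal_handlers=None):
  term_handler = signal_handlers[signal.SIGTERM]
  while not term_handler.called:
    # Do one unit of hypothetical work per iteration
    time.sleep(1)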


class SignalWakeupFd(object):
  """Wrapper around C{signal.set_wakeup_fd} using a pipe.

  """
  try:
    # This is only supported in Python 2.5 and above (some distributions
    # backported it to Python 2.4)
    _set_wakeup_fd_fn = signal.set_wakeup_fd
  except AttributeError:
    # Not supported
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
      return -1
  else:
    def _SetWakeupFd(self, fd):
      return self._set_wakeup_fd_fn(fd)

  def __init__(self):
    """Initializes this class.

    """
    (read_fd, write_fd) = os.pipe()

    # Once these succeeded, the file descriptors will be closed automatically.
    # Buffer size 0 is important, otherwise .read() with a specified length
    # might buffer data and the file descriptors won't be marked readable.
    self._read_fh = os.fdopen(read_fd, "r", 0)
    self._write_fh = os.fdopen(write_fd, "w", 0)

    self._previous = self._SetWakeupFd(self._write_fh.fileno())

    # Utility functions
    self.fileno = self._read_fh.fileno
    self.read = self._read_fh.read

  def Reset(self):
    """Restores the previous wakeup file descriptor.

    """
    if hasattr(self, "_previous") and self._previous is not None:
      self._SetWakeupFd(self._previous)
      self._previous = None

  def Notify(self):
    """Notifies the wakeup file descriptor.

    """
    self._write_fh.write("\0")

  def __del__(self):
    """Called before object deletion.

    """
    self.Reset()
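

# Illustrative sketch (not part of the original module): pairing
# SignalWakeupFd with select() so a wait loop wakes up as soon as a signal
# arrives instead of sleeping through the whole timeout; the 30-second
# timeout and the choice of SIGCHLD are hypothetical.
def _ExampleSignalWakeupFd():
  wakeup = SignalWakeupFd()
  handler = SignalHandler([signal.SIGCHLD], wakeup=wakeup)
  try:
    while not handler.called:
      try:
        # Returns early once the handler writes to the wakeup pipe
        select.select([wakeup], [], [], 30)
      except select.error, err:
        if err.args[0] != errno.EINTR:
          raise
  finally:
    handler.Reset()
    wakeup.Reset()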


class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when the instance is
  destroyed or when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: list
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None, wakeup=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function

    """
    assert handler_fn is None or callable(handler_fn)

    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn
    self._wakeup = wakeup

    self._previous = {}
    try:
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._wakeup:
      # Notify whoever is interested in signals
      self._wakeup.Notify()

    if self._handler_fn:
      self._handler_fn(signum, frame)
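

# Illustrative sketch (not part of the original module): using SignalHandler
# directly to watch for SIGHUP around a caller-supplied work function;
# "work_fn" is a hypothetical callable and the reload step is only hinted at
# in a comment.
def _ExampleSignalHandlerUsage(work_fn):
  hup_handler = SignalHandler([signal.SIGHUP])
  try:
    while not hup_handler.called:
      work_fn()
    # A SIGHUP arrived during work_fn(); e.g. reload configuration here.
    # Clear() re-arms the flag so the next SIGHUP can be detected again.
    hup_handler.Clear()
  finally:
    hup_handler.Reset()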


class FieldSet(object):
  """A simple field set.

  Among the features are:
    - checking if a string is among a list of static strings or regex objects
    - checking if a whole list of strings matches
    - returning the matching groups from a regex match

  Internally, all fields are held as regular expression objects.

  """
  def __init__(self, *items):
    self.items = [re.compile("^%s$" % value) for value in items]

  def Extend(self, other_set):
    """Extend the field set with the items from another one"""
    self.items.extend(other_set.items)

  def Matches(self, field):
    """Checks if a field matches the current set

    @type field: str
    @param field: the string to match
    @return: either None or a regular expression match object

    """
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
      return m
    return None

  def NonMatching(self, items):
    """Returns the list of fields not matching the current set

    @type items: list
    @param items: the list of fields to check
    @rtype: list
    @return: list of non-matching fields

    """
    return [val for val in items if not self.Matches(val)]
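

# Illustrative sketch (not part of the original module): a field set mixing
# literal names with a regex, as used when validating user-supplied field
# lists; the field names below are hypothetical.
def _ExampleFieldSetUsage():
  fields = FieldSet("name", "uuid", r"tag/(\d+)")
  assert fields.Matches("tag/7").group(1) == "7"
  assert fields.Matches("size") is None
  assert fields.NonMatching(["name", "size"]) == ["size"]
  return fields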