#
#

# Copyright (C) 2006, 2007 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Ganeti utility module.

This module holds functions that can be used in both daemons (all) and
the command line scripts.

"""


import os
import sys
import time
import subprocess
import re
import socket
import tempfile
import shutil
import errno
import pwd
import itertools
import select
import fcntl
import resource
import logging
import logging.handlers
import signal
import OpenSSL
import datetime
import calendar
import hmac
import collections
import struct
import IN

from cStringIO import StringIO

try:
  import ctypes
except ImportError:
  ctypes = None

from ganeti import errors
from ganeti import constants
from ganeti import compat


_locksheld = []
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')

debug_locks = False

#: when set to True, L{RunCmd} is disabled
no_fork = False

_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"

HEX_CHAR_RE = r"[a-zA-Z0-9]"
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
                             HEX_CHAR_RE, HEX_CHAR_RE),
                            re.S | re.I)

# Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
# struct ucred { pid_t pid; uid_t uid; gid_t gid; };
#
# The GNU C Library defines gid_t and uid_t to be "unsigned int" and
# pid_t to "int".
#
# IEEE Std 1003.1-2008:
# "nlink_t, uid_t, gid_t, and id_t shall be integer types"
# "blksize_t, pid_t, and ssize_t shall be signed integer types"
_STRUCT_UCRED = "iII"
_STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)

# Certificate verification results
(CERT_WARNING,
 CERT_ERROR) = range(1, 3)

# Flags for mlockall() (from bits/mman.h)
_MCL_CURRENT = 1
_MCL_FUTURE = 2


class RunResult(object):
  """Holds the result of running external programs.

  @type exit_code: int
  @ivar exit_code: the exit code of the program, or None (if the program
      didn't exit())
  @type signal: int or None
  @ivar signal: the signal that caused the program to finish, or None
      (if the program wasn't terminated by a signal)
  @type stdout: str
  @ivar stdout: the standard output of the program
  @type stderr: str
  @ivar stderr: the standard error of the program
  @type failed: boolean
  @ivar failed: True in case the program was
      terminated by a signal or exited with a non-zero exit code
  @ivar fail_reason: a string detailing the termination reason

  """
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
               "failed", "fail_reason", "cmd"]


  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
    self.cmd = cmd
    self.exit_code = exit_code
    self.signal = signal_
    self.stdout = stdout
    self.stderr = stderr
    self.failed = (signal_ is not None or exit_code != 0)

    if self.signal is not None:
      self.fail_reason = "terminated by signal %s" % self.signal
    elif self.exit_code is not None:
      self.fail_reason = "exited with exit code %s" % self.exit_code
    else:
      self.fail_reason = "unable to determine termination reason"

    if self.failed:
      logging.debug("Command '%s' failed (%s); output: %s",
                    self.cmd, self.fail_reason, self.output)

  def _GetOutput(self):
    """Returns the combined stdout and stderr for easier usage.

    """
    return self.stdout + self.stderr

  output = property(_GetOutput, None, None, "Return full output")


def _BuildCmdEnvironment(env, reset):
  """Builds the environment for an external program.

  """
  if reset:
    cmd_env = {}
  else:
    cmd_env = os.environ.copy()
    cmd_env["LC_ALL"] = "C"

  if env is not None:
    cmd_env.update(env)

  return cmd_env


def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
  """Execute a (shell) command.

  The command should not read from its standard input, as it will be
  closed.

  @type cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: Additional environment variables
  @type output: str
  @param output: if desired, the output of the command can be
      saved in a file instead of the RunResult instance; this
      parameter denotes the file name (if not None)
  @type cwd: string
  @param cwd: if specified, will be used as the working
      directory for the command; the default will be /
  @type reset_env: boolean
  @param reset_env: whether to reset or keep the default os environment
  @rtype: L{RunResult}
  @return: RunResult instance
  @raise errors.ProgrammerError: if we call this when forks are disabled

  """
  if no_fork:
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")

  if isinstance(cmd, basestring):
    strcmd = cmd
    shell = True
  else:
    cmd = [str(val) for val in cmd]
    strcmd = ShellQuoteArgs(cmd)
    shell = False

  if output:
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
  else:
    logging.debug("RunCmd %s", strcmd)

  cmd_env = _BuildCmdEnvironment(env, reset_env)

  try:
    if output is None:
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
    else:
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
      out = err = ""
  except OSError, err:
    if err.errno == errno.ENOENT:
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
                               (strcmd, err))
    else:
      raise

  if status >= 0:
    exitcode = status
    signal_ = None
  else:
    exitcode = None
    signal_ = -status

  return RunResult(exitcode, signal_, out, err, strcmd)
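
# Illustrative use of RunCmd (commented out; the command and path are just
# examples). The returned RunResult collects exit code/signal, stdout/stderr
# and a human-readable failure reason:
#
#   result = RunCmd(["/bin/ls", "/var/lib/ganeti"])
#   if result.failed:
#     logging.error("ls failed (%s): %s", result.fail_reason, result.output)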


def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
                pidfile=None):
  """Start a daemon process after forking twice.

  @type cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: Additional environment variables
  @type cwd: string
  @param cwd: Working directory for the program
  @type output: string
  @param output: Path to file in which to save the output
  @type output_fd: int
  @param output_fd: File descriptor for output
  @type pidfile: string
  @param pidfile: Process ID file
  @rtype: int
  @return: Daemon process ID
  @raise errors.ProgrammerError: if we call this when forks are disabled

  """
  if no_fork:
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
                                 " disabled")

  if output and not (bool(output) ^ (output_fd is not None)):
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
                                 " specified")

  if isinstance(cmd, basestring):
    cmd = ["/bin/sh", "-c", cmd]

  strcmd = ShellQuoteArgs(cmd)

  if output:
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
  else:
    logging.debug("StartDaemon %s", strcmd)

  cmd_env = _BuildCmdEnvironment(env, False)

  # Create pipe for sending PID back
  (pidpipe_read, pidpipe_write) = os.pipe()
  try:
    try:
      # Create pipe for sending error messages
      (errpipe_read, errpipe_write) = os.pipe()
      try:
        try:
          # First fork
          pid = os.fork()
          if pid == 0:
            try:
              # Child process, won't return
              _StartDaemonChild(errpipe_read, errpipe_write,
                                pidpipe_read, pidpipe_write,
                                cmd, cmd_env, cwd,
                                output, output_fd, pidfile)
            finally:
              # Well, maybe child process failed
              os._exit(1) # pylint: disable-msg=W0212
        finally:
          _CloseFDNoErr(errpipe_write)

        # Wait for daemon to be started (or an error message to arrive) and read
        # up to 100 KB as an error message
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
      finally:
        _CloseFDNoErr(errpipe_read)
    finally:
      _CloseFDNoErr(pidpipe_write)

    # Read up to 128 bytes for PID
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
  finally:
    _CloseFDNoErr(pidpipe_read)

  # Try to avoid zombies by waiting for child process
  try:
    os.waitpid(pid, 0)
  except OSError:
    pass

  if errormsg:
    raise errors.OpExecError("Error when starting daemon process: %r" %
                             errormsg)

  try:
    return int(pidtext)
  except (ValueError, TypeError), err:
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
                             (pidtext, err))
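
# Illustrative use of StartDaemon (commented out; the command and paths are
# made up). The call returns once the daemon has forked and reported its PID
# back through the pipe:
#
#   pid = StartDaemon(["/usr/sbin/mydaemon", "--foreground"],
#                     output="/var/log/mydaemon.log",
#                     pidfile="/var/run/mydaemon.pid")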


def _StartDaemonChild(errpipe_read, errpipe_write,
                      pidpipe_read, pidpipe_write,
                      args, env, cwd,
                      output, fd_output, pidfile):
  """Child process for starting daemon.

  """
  try:
    # Close parent's side
    _CloseFDNoErr(errpipe_read)
    _CloseFDNoErr(pidpipe_read)

    # First child process
    os.chdir("/")
    os.umask(077)
    os.setsid()

    # And fork for the second time
    pid = os.fork()
    if pid != 0:
      # Exit first child process
      os._exit(0) # pylint: disable-msg=W0212

    # Make sure pipe is closed on execv* (and thereby notifies original process)
    SetCloseOnExecFlag(errpipe_write, True)

    # List of file descriptors to be left open
    noclose_fds = [errpipe_write]

    # Open PID file
    if pidfile:
      try:
        # TODO: Atomic replace with another locked file instead of writing into
        # it after creating
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)

        # Lock the PID file (and fail if not possible to do so). Any code
        # wanting to send a signal to the daemon should try to lock the PID
        # file before reading it. If acquiring the lock succeeds, the daemon is
        # no longer running and the signal should not be sent.
        LockFile(fd_pidfile)

        os.write(fd_pidfile, "%d\n" % os.getpid())
      except Exception, err:
        raise Exception("Creating and locking PID file failed: %s" % err)

      # Keeping the file open to hold the lock
      noclose_fds.append(fd_pidfile)

      SetCloseOnExecFlag(fd_pidfile, False)
    else:
      fd_pidfile = None

    # Open /dev/null
    fd_devnull = os.open(os.devnull, os.O_RDWR)

    assert not output or (bool(output) ^ (fd_output is not None))

    if fd_output is not None:
      pass
    elif output:
      # Open output file
      try:
        # TODO: Implement flag to set append=yes/no
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
      except EnvironmentError, err:
        raise Exception("Opening output file failed: %s" % err)
    else:
      fd_output = fd_devnull

    # Redirect standard I/O
    os.dup2(fd_devnull, 0)
    os.dup2(fd_output, 1)
    os.dup2(fd_output, 2)

    # Send daemon PID to parent
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))

    # Close all file descriptors except stdio and error message pipe
    CloseFDs(noclose_fds=noclose_fds)

    # Change working directory
    os.chdir(cwd)

    if env is None:
      os.execvp(args[0], args)
    else:
      os.execvpe(args[0], args, env)
  except: # pylint: disable-msg=W0702
    try:
      # Report errors to original process
      buf = str(sys.exc_info()[1])

      RetryOnSignal(os.write, errpipe_write, buf)
    except: # pylint: disable-msg=W0702
      # Ignore errors in error handling
      pass

  os._exit(1) # pylint: disable-msg=W0212


def _RunCmdPipe(cmd, env, via_shell, cwd):
  """Run a command and return its output.

  @type  cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: The environment to use
  @type via_shell: bool
  @param via_shell: if we should run via the shell
  @type cwd: string
  @param cwd: the working directory for the program
  @rtype: tuple
  @return: (out, err, status)

  """
  poller = select.poll()
  child = subprocess.Popen(cmd, shell=via_shell,
                           stderr=subprocess.PIPE,
                           stdout=subprocess.PIPE,
                           stdin=subprocess.PIPE,
                           close_fds=True, env=env,
                           cwd=cwd)

  child.stdin.close()
  poller.register(child.stdout, select.POLLIN)
  poller.register(child.stderr, select.POLLIN)
  out = StringIO()
  err = StringIO()
  fdmap = {
    child.stdout.fileno(): (out, child.stdout),
    child.stderr.fileno(): (err, child.stderr),
    }
  for fd in fdmap:
    SetNonblockFlag(fd, True)

  while fdmap:
    pollresult = RetryOnSignal(poller.poll)

    for fd, event in pollresult:
      if event & select.POLLIN or event & select.POLLPRI:
        data = fdmap[fd][1].read()
        # no data from read signifies EOF (the same as POLLHUP)
        if not data:
          poller.unregister(fd)
          del fdmap[fd]
          continue
        fdmap[fd][0].write(data)
      if (event & select.POLLNVAL or event & select.POLLHUP or
          event & select.POLLERR):
        poller.unregister(fd)
        del fdmap[fd]

  out = out.getvalue()
  err = err.getvalue()

  status = child.wait()
  return out, err, status


def _RunCmdFile(cmd, env, via_shell, output, cwd):
  """Run a command and save its output to a file.

  @type  cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: The environment to use
  @type via_shell: bool
  @param via_shell: if we should run via the shell
  @type output: str
  @param output: the filename in which to save the output
  @type cwd: string
  @param cwd: the working directory for the program
  @rtype: int
  @return: the exit status

  """
  fh = open(output, "a")
  try:
    child = subprocess.Popen(cmd, shell=via_shell,
                             stderr=subprocess.STDOUT,
                             stdout=fh,
                             stdin=subprocess.PIPE,
                             close_fds=True, env=env,
                             cwd=cwd)

    child.stdin.close()
    status = child.wait()
  finally:
    fh.close()
  return status


def SetCloseOnExecFlag(fd, enable):
  """Sets or unsets the close-on-exec flag on a file descriptor.

  @type fd: int
  @param fd: File descriptor
  @type enable: bool
  @param enable: Whether to set or unset it.

  """
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)

  if enable:
    flags |= fcntl.FD_CLOEXEC
  else:
    flags &= ~fcntl.FD_CLOEXEC

  fcntl.fcntl(fd, fcntl.F_SETFD, flags)


def SetNonblockFlag(fd, enable):
  """Sets or unsets the O_NONBLOCK flag on a file descriptor.

  @type fd: int
  @param fd: File descriptor
  @type enable: bool
  @param enable: Whether to set or unset it

  """
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)

  if enable:
    flags |= os.O_NONBLOCK
  else:
    flags &= ~os.O_NONBLOCK

  fcntl.fcntl(fd, fcntl.F_SETFL, flags)


def RetryOnSignal(fn, *args, **kwargs):
  """Calls a function again if it failed due to EINTR.

  """
  while True:
    try:
      return fn(*args, **kwargs)
    except EnvironmentError, err:
      if err.errno != errno.EINTR:
        raise
    except (socket.error, select.error), err:
      # In python 2.6 and above select.error is an IOError, so it's handled
      # above, in 2.5 and below it's not, and it's handled here.
      if not (err.args and err.args[0] == errno.EINTR):
        raise
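
# Illustrative use of RetryOnSignal (commented out; fd is a hypothetical open
# file descriptor). The wrapped call is simply repeated until it completes
# without being interrupted by a signal (EINTR):
#
#   data = RetryOnSignal(os.read, fd, 4096)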


def RunParts(dir_name, env=None, reset_env=False):
  """Run scripts or programs in a directory

  @type dir_name: string
  @param dir_name: absolute path to a directory
  @type env: dict
  @param env: The environment to use
  @type reset_env: boolean
  @param reset_env: whether to reset or keep the default os environment
  @rtype: list of tuples
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)

  """
  rr = []

  try:
    dir_contents = ListVisibleFiles(dir_name)
  except OSError, err:
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
    return rr

  for relname in sorted(dir_contents):
    fname = PathJoin(dir_name, relname)
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
      rr.append((relname, constants.RUNPARTS_SKIP, None))
    else:
      try:
        result = RunCmd([fname], env=env, reset_env=reset_env)
      except Exception, err: # pylint: disable-msg=W0703
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
      else:
        rr.append((relname, constants.RUNPARTS_RUN, result))

  return rr
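
# Illustrative use of RunParts (commented out; the directory is hypothetical).
# Each entry is a (name, status, result) tuple where status is one of the
# constants.RUNPARTS_* values and result is a RunResult for executed scripts:
#
#   for (name, status, result) in RunParts("/etc/ganeti/hooks.d"):
#     if status == constants.RUNPARTS_RUN and result.failed:
#       logging.warning("Script %s failed: %s", name, result.fail_reason)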


def GetSocketCredentials(sock):
  """Returns the credentials of the foreign process connected to a socket.

  @param sock: Unix socket
  @rtype: tuple; (number, number, number)
  @return: The PID, UID and GID of the connected foreign process.

  """
  peercred = sock.getsockopt(socket.SOL_SOCKET, IN.SO_PEERCRED,
                             _STRUCT_UCRED_SIZE)
  return struct.unpack(_STRUCT_UCRED, peercred)
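
# Illustrative use of GetSocketCredentials (commented out; the listening
# socket is hypothetical and must be a local AF_UNIX stream socket):
#
#   conn, _ = listening_socket.accept()
#   (pid, uid, gid) = GetSocketCredentials(conn)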


def RemoveFile(filename):
  """Remove a file ignoring some errors.

  Remove a file, ignoring non-existing ones or directories. Other
  errors are passed.

  @type filename: str
  @param filename: the file to be removed

  """
  try:
    os.unlink(filename)
  except OSError, err:
    if err.errno not in (errno.ENOENT, errno.EISDIR):
      raise


def RemoveDir(dirname):
  """Remove an empty directory.

  Remove a directory, ignoring non-existing ones.
  Other errors are passed. This includes the case,
  where the directory is not empty, so it can't be removed.

  @type dirname: str
  @param dirname: the empty directory to be removed

  """
  try:
    os.rmdir(dirname)
  except OSError, err:
    if err.errno != errno.ENOENT:
      raise


def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
  """Renames a file.

  @type old: string
  @param old: Original path
  @type new: string
  @param new: New path
  @type mkdir: bool
  @param mkdir: Whether to create target directory if it doesn't exist
  @type mkdir_mode: int
  @param mkdir_mode: Mode for newly created directories

  """
  try:
    return os.rename(old, new)
  except OSError, err:
    # In at least one use case of this function, the job queue, directory
    # creation is very rare. Checking for the directory before renaming is not
    # as efficient.
    if mkdir and err.errno == errno.ENOENT:
      # Create directory and try again
      Makedirs(os.path.dirname(new), mode=mkdir_mode)

      return os.rename(old, new)

    raise


def Makedirs(path, mode=0750):
  """Super-mkdir; create a leaf directory and all intermediate ones.

  This is a wrapper around C{os.makedirs} adding error handling not implemented
  before Python 2.5.

  """
  try:
    os.makedirs(path, mode)
  except OSError, err:
    # Ignore EEXIST. This is only handled in os.makedirs as included in
    # Python 2.5 and above.
    if err.errno != errno.EEXIST or not os.path.exists(path):
      raise


def ResetTempfileModule():
  """Resets the random name generator of the tempfile module.

  This function should be called after C{os.fork} in the child process to
  ensure it creates a newly seeded random generator. Otherwise it would
  generate the same random parts as the parent process. If several processes
  race for the creation of a temporary file, this could lead to one not getting
  a temporary name.

  """
  # pylint: disable-msg=W0212
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
    tempfile._once_lock.acquire()
    try:
      # Reset random name generator
      tempfile._name_sequence = None
    finally:
      tempfile._once_lock.release()
  else:
    logging.critical("The tempfile module misses at least one of the"
                     " '_once_lock' and '_name_sequence' attributes")


def _FingerprintFile(filename):
  """Compute the fingerprint of a file.

  If the file does not exist, a None will be returned
  instead.

  @type filename: str
  @param filename: the filename to checksum
  @rtype: str
  @return: the hex digest of the sha checksum of the contents
      of the file

  """
  if not (os.path.exists(filename) and os.path.isfile(filename)):
    return None

  f = open(filename)

  fp = compat.sha1_hash()
  while True:
    data = f.read(4096)
    if not data:
      break

    fp.update(data)

  return fp.hexdigest()


def FingerprintFiles(files):
  """Compute fingerprints for a list of files.

  @type files: list
  @param files: the list of filename to fingerprint
  @rtype: dict
  @return: a dictionary filename: fingerprint, holding only
      existing files

  """
  ret = {}

  for filename in files:
    cksum = _FingerprintFile(filename)
    if cksum:
      ret[filename] = cksum

  return ret


def ForceDictType(target, key_types, allowed_values=None):
  """Force the values of a dict to have certain types.

  @type target: dict
  @param target: the dict to update
  @type key_types: dict
  @param key_types: dict mapping target dict keys to types
                    in constants.ENFORCEABLE_TYPES
  @type allowed_values: list
  @keyword allowed_values: list of specially allowed values

  """
  if allowed_values is None:
    allowed_values = []

  if not isinstance(target, dict):
    msg = "Expected dictionary, got '%s'" % target
    raise errors.TypeEnforcementError(msg)

  for key in target:
    if key not in key_types:
      msg = "Unknown key '%s'" % key
      raise errors.TypeEnforcementError(msg)

    if target[key] in allowed_values:
      continue

    ktype = key_types[key]
    if ktype not in constants.ENFORCEABLE_TYPES:
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
      raise errors.ProgrammerError(msg)

    if ktype == constants.VTYPE_STRING:
      if not isinstance(target[key], basestring):
        if isinstance(target[key], bool) and not target[key]:
          target[key] = ''
        else:
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
          raise errors.TypeEnforcementError(msg)
    elif ktype == constants.VTYPE_BOOL:
      if isinstance(target[key], basestring) and target[key]:
        if target[key].lower() == constants.VALUE_FALSE:
          target[key] = False
        elif target[key].lower() == constants.VALUE_TRUE:
          target[key] = True
        else:
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
          raise errors.TypeEnforcementError(msg)
      elif target[key]:
        target[key] = True
      else:
        target[key] = False
    elif ktype == constants.VTYPE_SIZE:
      try:
        target[key] = ParseUnit(target[key])
      except errors.UnitParseError, err:
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
              (key, target[key], err)
        raise errors.TypeEnforcementError(msg)
    elif ktype == constants.VTYPE_INT:
      try:
        target[key] = int(target[key])
      except (ValueError, TypeError):
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
        raise errors.TypeEnforcementError(msg)
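
# Illustrative use of ForceDictType (commented out; the key names are made up).
# Values are coerced in place, so "memory" ends up as an int in MiB and
# "auto_balance" as a bool:
#
#   opts = {"memory": "512M", "auto_balance": "true"}
#   ForceDictType(opts, {"memory": constants.VTYPE_SIZE,
#                        "auto_balance": constants.VTYPE_BOOL})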


def _GetProcStatusPath(pid):
  """Returns the path for a PID's proc status file.

  @type pid: int
  @param pid: Process ID
  @rtype: string

  """
  return "/proc/%d/status" % pid


def IsProcessAlive(pid):
  """Check if a given pid exists on the system.

  @note: zombie status is not handled, so zombie processes
      will be returned as alive
  @type pid: int
  @param pid: the process ID to check
  @rtype: boolean
  @return: True if the process exists

  """
  def _TryStat(name):
    try:
      os.stat(name)
      return True
    except EnvironmentError, err:
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
        return False
      elif err.errno == errno.EINVAL:
        raise RetryAgain(err)
      raise

  assert isinstance(pid, int), "pid must be an integer"
  if pid <= 0:
    return False

  # /proc in a multiprocessor environment can have strange behaviors.
  # Retry the os.stat a few times until we get a good result.
  try:
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
                 args=[_GetProcStatusPath(pid)])
  except RetryTimeout, err:
    err.RaiseInner()
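
# Illustrative use of IsProcessAlive together with ReadPidFile (commented out;
# the pidfile path is hypothetical):
#
#   pid = ReadPidFile("/var/run/ganeti/some-daemon.pid")
#   if pid and IsProcessAlive(pid):
#     logging.info("Daemon is running as PID %d", pid)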


def _ParseSigsetT(sigset):
  """Parse a rendered sigset_t value.

  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
  function.

  @type sigset: string
  @param sigset: Rendered signal set from /proc/$pid/status
  @rtype: set
  @return: Set of all enabled signal numbers

  """
  result = set()

  signum = 0
  for ch in reversed(sigset):
    chv = int(ch, 16)

    # The following could be done in a loop, but it's easier to read and
    # understand in the unrolled form
    if chv & 1:
      result.add(signum + 1)
    if chv & 2:
      result.add(signum + 2)
    if chv & 4:
      result.add(signum + 3)
    if chv & 8:
      result.add(signum + 4)

    signum += 4

  return result


def _GetProcStatusField(pstatus, field):
  """Retrieves a field from the contents of a proc status file.

  @type pstatus: string
  @param pstatus: Contents of /proc/$pid/status
  @type field: string
  @param field: Name of field whose value should be returned
  @rtype: string

  """
  for line in pstatus.splitlines():
    parts = line.split(":", 1)

    if len(parts) < 2 or parts[0] != field:
      continue

    return parts[1].strip()

  return None


def IsProcessHandlingSignal(pid, signum, status_path=None):
  """Checks whether a process is handling a signal.

  @type pid: int
  @param pid: Process ID
  @type signum: int
  @param signum: Signal number
  @rtype: bool

  """
  if status_path is None:
    status_path = _GetProcStatusPath(pid)

  try:
    proc_status = ReadFile(status_path)
  except EnvironmentError, err:
    # In at least one case, reading /proc/$pid/status failed with ESRCH.
    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
      return False
    raise

  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
  if sigcgt is None:
    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)

  # Now check whether signal is handled
  return signum in _ParseSigsetT(sigcgt)
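
# Illustrative values (commented out). The SigCgt mask is a hexadecimal string
# as found in /proc/$pid/status, with the least significant nibble last, so a
# trailing "1" means signal 1 (SIGHUP) is caught:
#
#   _ParseSigsetT("0000000000000001")  # -> set([1])
#   IsProcessHandlingSignal(os.getpid(), signal.SIGTERM)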


def ReadPidFile(pidfile):
  """Read a pid from a file.

  @type  pidfile: string
  @param pidfile: path to the file containing the pid
  @rtype: int
  @return: The process id, if the file exists and contains a valid PID,
           otherwise 0

  """
  try:
    raw_data = ReadOneLineFile(pidfile)
  except EnvironmentError, err:
    if err.errno != errno.ENOENT:
      logging.exception("Can't read pid file")
    return 0

  try:
    pid = int(raw_data)
  except (TypeError, ValueError), err:
    logging.info("Can't parse pid file contents", exc_info=True)
    return 0

  return pid


def ReadLockedPidFile(path):
  """Reads a locked PID file.

  This can be used together with L{StartDaemon}.

  @type path: string
  @param path: Path to PID file
  @return: PID as integer or, if file was unlocked or couldn't be opened, None

  """
  try:
    fd = os.open(path, os.O_RDONLY)
  except EnvironmentError, err:
    if err.errno == errno.ENOENT:
      # PID file doesn't exist
      return None
    raise

  try:
    try:
      # Try to acquire lock
      LockFile(fd)
    except errors.LockError:
      # Couldn't lock, daemon is running
      return int(os.read(fd, 100))
  finally:
    os.close(fd)

  return None


def MatchNameComponent(key, name_list, case_sensitive=True):
  """Try to match a name against a list.

  This function will try to match a name like test1 against a list
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
  this list, I{'test1'} as well as I{'test1.example'} will match, but
  not I{'test1.ex'}. A multiple match will be considered as no match
  at all (e.g. I{'test1'} against C{['test1.example.com',
  'test1.example.org']}), except when the key fully matches an entry
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).

  @type key: str
  @param key: the name to be searched
  @type name_list: list
  @param name_list: the list of strings against which to search the key
  @type case_sensitive: boolean
  @param case_sensitive: whether to provide a case-sensitive match

  @rtype: None or str
  @return: None if there is no match I{or} if there are multiple matches,
      otherwise the element from the list which matches

  """
  if key in name_list:
    return key

  re_flags = 0
  if not case_sensitive:
    re_flags |= re.IGNORECASE
    key = key.upper()
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
  names_filtered = []
  string_matches = []
  for name in name_list:
    if mo.match(name) is not None:
      names_filtered.append(name)
      if not case_sensitive and key == name.upper():
        string_matches.append(name)

  if len(string_matches) == 1:
    return string_matches[0]
  if len(names_filtered) == 1:
    return names_filtered[0]
  return None
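
# Illustrative behaviour of MatchNameComponent (commented out), matching the
# docstring: a unique per-label prefix match is returned, an ambiguous one
# is not:
#
#   MatchNameComponent("test1", ["test1.example.com", "test2.example.com"])
#   # -> "test1.example.com"
#   MatchNameComponent("test1", ["test1.example.com", "test1.example.org"])
#   # -> None (ambiguous)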


class HostInfo:
  """Class implementing resolver and hostname functionality

  """
  _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$")

  def __init__(self, name=None):
    """Initialize the host name object.

    If the name argument is not passed, it will use this system's
    name.

    """
    if name is None:
      name = self.SysName()

    self.query = name
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
    self.ip = self.ipaddrs[0]

  def ShortName(self):
    """Returns the hostname without domain.

    """
    return self.name.split('.')[0]

  @staticmethod
  def SysName():
    """Return the current system's name.

    This is simply a wrapper over C{socket.gethostname()}.

    """
    return socket.gethostname()

  @staticmethod
  def LookupHostname(hostname):
    """Look up hostname

    @type hostname: str
    @param hostname: hostname to look up

    @rtype: tuple
    @return: a tuple (name, aliases, ipaddrs) as returned by
        C{socket.gethostbyname_ex}
    @raise errors.ResolverError: in case of errors in resolving

    """
    try:
      result = socket.gethostbyname_ex(hostname)
    except socket.gaierror, err:
      # hostname not found in DNS
      raise errors.ResolverError(hostname, err.args[0], err.args[1])

    return result

  @classmethod
  def NormalizeName(cls, hostname):
    """Validate and normalize the given hostname.

    @attention: the validation is a bit more relaxed than the standards
        require; most importantly, we allow underscores in names
    @raise errors.OpPrereqError: when the name is not valid

    """
    hostname = hostname.lower()
    if (not cls._VALID_NAME_RE.match(hostname) or
        # double-dots, meaning empty label
        ".." in hostname or
        # empty initial label
        hostname.startswith(".")):
      raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
                                 errors.ECODE_INVAL)
    if hostname.endswith("."):
      hostname = hostname.rstrip(".")
    return hostname


def GetHostInfo(name=None):
  """Lookup host name and raise an OpPrereqError for failures"""

  try:
    return HostInfo(name)
  except errors.ResolverError, err:
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
                               (err[0], err[2]), errors.ECODE_RESOLVER)


def ListVolumeGroups():
  """List volume groups and their size

  @rtype: dict
  @return:
       Dictionary with keys volume name and values
       the size of the volume

  """
  command = "vgs --noheadings --units m --nosuffix -o name,size"
  result = RunCmd(command)
  retval = {}
  if result.failed:
    return retval

  for line in result.stdout.splitlines():
    try:
      name, size = line.split()
      size = int(float(size))
    except (IndexError, ValueError), err:
      logging.error("Invalid output from vgs (%s): %s", err, line)
      continue

    retval[name] = size

  return retval


def BridgeExists(bridge):
  """Check whether the given bridge exists in the system

  @type bridge: str
  @param bridge: the bridge name to check
  @rtype: boolean
  @return: True if it does

  """
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)


def NiceSort(name_list):
  """Sort a list of strings based on digit and non-digit groupings.

  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
  will sort the list in the logical order C{['a1', 'a2', 'a10',
  'a11']}.

  The sort algorithm breaks each name in groups of either only-digits
  or no-digits. Only the first eight such groups are considered, and
  after that we just use what's left of the string.

  @type name_list: list
  @param name_list: the names to be sorted
  @rtype: list
  @return: a copy of the name list sorted with our algorithm

  """
  _SORTER_BASE = "(\D+|\d+)"
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE)
  _SORTER_RE = re.compile(_SORTER_FULL)
  _SORTER_NODIGIT = re.compile("^\D*$")
  def _TryInt(val):
    """Attempts to convert a variable to integer."""
    if val is None or _SORTER_NODIGIT.match(val):
      return val
    rval = int(val)
    return rval

  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
             for name in name_list]
  to_sort.sort()
  return [tup[1] for tup in to_sort]
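
# Illustrative behaviour of NiceSort (commented out), matching the docstring:
#
#   NiceSort(['a1', 'a10', 'a11', 'a2'])  # -> ['a1', 'a2', 'a10', 'a11']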


def TryConvert(fn, val):
  """Try to convert a value ignoring errors.

  This function tries to apply function I{fn} to I{val}. If no
  C{ValueError} or C{TypeError} exceptions are raised, it will return
  the result, else it will return the original value. Any other
  exceptions are propagated to the caller.

  @type fn: callable
  @param fn: function to apply to the value
  @param val: the value to be converted
  @return: The converted value if the conversion was successful,
      otherwise the original value.

  """
  try:
    nv = fn(val)
  except (ValueError, TypeError):
    nv = val
  return nv


def IsValidIP(ip):
  """Verifies the syntax of an IPv4 address.

  This function checks if the IPv4 address passed is valid or not based
  on syntax (not IP range, class calculations, etc.).

  @type ip: str
  @param ip: the address to be checked
  @rtype: a regular expression match object
  @return: a regular expression match object, or None if the
      address is not valid

  """
  unit = "(0|[1-9]\d{0,2})"
  #TODO: convert and return only boolean
  return re.match("^%s\.%s\.%s\.%s$" % (unit, unit, unit, unit), ip)


def IsValidShellParam(word):
  """Verifies if the given word is safe from the shell's p.o.v.

  This means that we can pass this to a command via the shell and be
  sure that it doesn't alter the command line and is passed as such to
  the actual command.

  Note that we are overly restrictive here, in order to be on the safe
  side.

  @type word: str
  @param word: the word to check
  @rtype: boolean
  @return: True if the word is 'safe'

  """
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))


def BuildShellCmd(template, *args):
  """Build a safe shell command line from the given arguments.

  This function will check all arguments in the args list so that they
  are valid shell parameters (i.e. they don't contain shell
  metacharacters). If everything is ok, it will return the result of
  template % args.

  @type template: str
  @param template: the string holding the template for the
      string formatting
  @rtype: str
  @return: the expanded command line

  """
  for word in args:
    if not IsValidShellParam(word):
      raise errors.ProgrammerError("Shell argument '%s' contains"
                                   " invalid characters" % word)
  return template % args


def FormatUnit(value, units):
  """Formats an incoming number of MiB with the appropriate unit.

  @type value: int
  @param value: integer representing the value in MiB (1048576)
  @type units: char
  @param units: the type of formatting we should do:
      - 'h' for automatic scaling
      - 'm' for MiBs
      - 'g' for GiBs
      - 't' for TiBs
  @rtype: str
  @return: the formatted value (with suffix)

  """
  if units not in ('m', 'g', 't', 'h'):
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))

  suffix = ''

  if units == 'm' or (units == 'h' and value < 1024):
    if units == 'h':
      suffix = 'M'
    return "%d%s" % (round(value, 0), suffix)

  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
    if units == 'h':
      suffix = 'G'
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)

  else:
    if units == 'h':
      suffix = 'T'
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)


def ParseUnit(input_string):
  """Tries to extract number and scale from the given string.

  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
  is always an int in MiB.

  """
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
  if not m:
    raise errors.UnitParseError("Invalid format")

  value = float(m.groups()[0])

  unit = m.groups()[1]
  if unit:
    lcunit = unit.lower()
  else:
    lcunit = 'm'

  if lcunit in ('m', 'mb', 'mib'):
    # Value already in MiB
    pass

  elif lcunit in ('g', 'gb', 'gib'):
    value *= 1024

  elif lcunit in ('t', 'tb', 'tib'):
    value *= 1024 * 1024

  else:
    raise errors.UnitParseError("Unknown unit: %s" % unit)

  # Make sure we round up
  if int(value) < value:
    value += 1

  # Round up to the next multiple of 4
  value = int(value)
  if value % 4:
    value += 4 - value % 4

  return value
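
# Illustrative values (commented out). ParseUnit always returns MiB, rounded
# up to a multiple of 4; FormatUnit goes the other way:
#
#   ParseUnit("1.5G")      # -> 1536
#   ParseUnit("510")       # -> 512 (no unit means MiB, rounded up to 4)
#   FormatUnit(1536, 'h')  # -> "1.5G"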


def AddAuthorizedKey(file_name, key):
  """Adds an SSH public key to an authorized_keys file.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  key_fields = key.split()

  f = open(file_name, 'a+')
  try:
    nl = True
    for line in f:
      # Ignore whitespace changes
      if line.split() == key_fields:
        break
      nl = line.endswith('\n')
    else:
      if not nl:
        f.write("\n")
      f.write(key.rstrip('\r\n'))
      f.write("\n")
      f.flush()
  finally:
    f.close()


def RemoveAuthorizedKey(file_name, key):
  """Removes an SSH public key from an authorized_keys file.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  key_fields = key.split()

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          # Ignore whitespace changes while comparing lines
          if line.split() != key_fields:
            out.write(line)

        out.flush()
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def SetEtcHostsEntry(file_name, ip, hostname, aliases):
  """Sets the name of an IP address and hostname in /etc/hosts.

  @type file_name: str
  @param file_name: path to the file to modify (usually C{/etc/hosts})
  @type ip: str
  @param ip: the IP address
  @type hostname: str
  @param hostname: the hostname to be added
  @type aliases: list
  @param aliases: the list of aliases to add for the hostname

  """
  # FIXME: use WriteFile + fn rather than duplicating its efforts
  # Ensure aliases are unique
  aliases = UniqueSequence([hostname] + aliases)[1:]

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          fields = line.split()
          if fields and not fields[0].startswith('#') and ip == fields[0]:
            continue
          out.write(line)

        out.write("%s\t%s" % (ip, hostname))
        if aliases:
          out.write(" %s" % ' '.join(aliases))
        out.write('\n')

        out.flush()
        os.fsync(out)
        os.chmod(tmpname, 0644)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def AddHostToEtcHosts(hostname):
  """Wrapper around SetEtcHostsEntry.

  @type hostname: str
  @param hostname: a hostname that will be resolved and added to
      L{constants.ETC_HOSTS}

  """
  hi = HostInfo(name=hostname)
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])


def RemoveEtcHostsEntry(file_name, hostname):
  """Removes a hostname from /etc/hosts.

  IP addresses without names are removed from the file.

  @type file_name: str
  @param file_name: path to the file to modify (usually C{/etc/hosts})
  @type hostname: str
  @param hostname: the hostname to be removed

  """
  # FIXME: use WriteFile + fn rather than duplicating its efforts
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          fields = line.split()
          if len(fields) > 1 and not fields[0].startswith('#'):
            names = fields[1:]
            if hostname in names:
              while hostname in names:
                names.remove(hostname)
              if names:
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
              continue

          out.write(line)

        out.flush()
        os.fsync(out)
        os.chmod(tmpname, 0644)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def RemoveHostFromEtcHosts(hostname):
  """Wrapper around RemoveEtcHostsEntry.

  @type hostname: str
  @param hostname: hostname that will be resolved and its
      full and short name will be removed from
      L{constants.ETC_HOSTS}

  """
  hi = HostInfo(name=hostname)
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())


def TimestampForFilename():
  """Returns the current time formatted for filenames.

  The format doesn't contain colons as some shells and applications treat
  them as separators.

  """
  return time.strftime("%Y-%m-%d_%H_%M_%S")


def CreateBackup(file_name):
  """Creates a backup of a file.

  @type file_name: str
  @param file_name: file to be backed up
  @rtype: str
  @return: the path to the newly created backup
  @raise errors.ProgrammerError: for invalid file names

  """
  if not os.path.isfile(file_name):
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
                                file_name)

  prefix = ("%s.backup-%s." %
            (os.path.basename(file_name), TimestampForFilename()))
  dir_name = os.path.dirname(file_name)

  fsrc = open(file_name, 'rb')
  try:
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
    fdst = os.fdopen(fd, 'wb')
    try:
      logging.debug("Backing up %s at %s", file_name, backup_name)
      shutil.copyfileobj(fsrc, fdst)
    finally:
      fdst.close()
  finally:
    fsrc.close()

  return backup_name


def ShellQuote(value):
  """Quotes shell argument according to POSIX.

  @type value: str
  @param value: the argument to be quoted
  @rtype: str
  @return: the quoted value

  """
  if _re_shell_unquoted.match(value):
    return value
  else:
    return "'%s'" % value.replace("'", "'\\''")


def ShellQuoteArgs(args):
  """Quotes a list of shell arguments.

  @type args: list
  @param args: list of arguments to be quoted
  @rtype: str
  @return: the quoted arguments concatenated with spaces

  """
  return ' '.join([ShellQuote(i) for i in args])
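
# Illustrative behaviour of ShellQuote/ShellQuoteArgs (commented out):
#
#   ShellQuote("it's")                    # -> 'it'\''s'
#   ShellQuoteArgs(["echo", "a b", "c"])  # -> "echo 'a b' c"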


def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
  """Simple ping implementation using TCP connect(2).

  Check if the given IP is reachable by attempting a TCP connect
  to it.

  @type target: str
  @param target: the IP or hostname to ping
  @type port: int
  @param port: the port to connect to
  @type timeout: int
  @param timeout: the timeout on the connection attempt
  @type live_port_needed: boolean
  @param live_port_needed: whether a closed port will cause the
      function to return failure, as if there was a timeout
  @type source: str or None
  @param source: if specified, will cause the connect to be made
      from this specific source address; failures to bind other
      than C{EADDRNOTAVAIL} will be ignored

  """
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

  success = False

  if source is not None:
    try:
      sock.bind((source, 0))
    except socket.error, (errcode, _):
      if errcode == errno.EADDRNOTAVAIL:
        success = False

  sock.settimeout(timeout)

  try:
    sock.connect((target, port))
    sock.close()
    success = True
  except socket.timeout:
    success = False
  except socket.error, (errcode, _):
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)

  return success
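
# Illustrative use of TcpPing (commented out; address and port are made up).
# With live_port_needed=False a refused connection still counts as success,
# since it proves the host is reachable:
#
#   TcpPing("192.0.2.10", 22, timeout=5, live_port_needed=True)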


def OwnIpAddress(address):
  """Check if the current host has the given IP address.

  Currently this is done by TCP-pinging the address from the loopback
  address.

  @type address: string
  @param address: the address to check
  @rtype: bool
  @return: True if we own the address

  """
  return TcpPing(address, constants.DEFAULT_NODED_PORT,
                 source=constants.LOCALHOST_IP_ADDRESS)


def ListVisibleFiles(path):
  """Returns a list of visible files in a directory.

  @type path: str
  @param path: the directory to enumerate
  @rtype: list
  @return: the list of all files not starting with a dot
  @raise ProgrammerError: if L{path} is not an absolute and normalized path

  """
  if not IsNormAbsPath(path):
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
                                 " absolute/normalized: '%s'" % path)
  files = [i for i in os.listdir(path) if not i.startswith(".")]
  files.sort()
  return files
1734

    
1735

    
1736
def GetHomeDir(user, default=None):
1737
  """Try to get the homedir of the given user.
1738

1739
  The user can be passed either as a string (denoting the name) or as
1740
  an integer (denoting the user id). If the user is not found, the
1741
  'default' argument is returned, which defaults to None.
1742

1743
  """
1744
  try:
1745
    if isinstance(user, basestring):
1746
      result = pwd.getpwnam(user)
1747
    elif isinstance(user, (int, long)):
1748
      result = pwd.getpwuid(user)
1749
    else:
1750
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1751
                                   type(user))
1752
  except KeyError:
1753
    return default
1754
  return result.pw_dir
1755

    
1756

    
1757
def NewUUID():
1758
  """Returns a random UUID.
1759

1760
  @note: This is a Linux-specific method as it uses the /proc
1761
      filesystem.
1762
  @rtype: str
1763

1764
  """
1765
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1766

    
1767

    
1768
def GenerateSecret(numbytes=20):
1769
  """Generates a random secret.
1770

1771
  This will generate a pseudo-random secret returning an hex string
1772
  (so that it can be used where an ASCII string is needed).
1773

1774
  @param numbytes: the number of bytes which will be represented by the returned
1775
      string (defaulting to 20, the length of a SHA1 hash)
1776
  @rtype: str
1777
  @return: a hex representation of the pseudo-random sequence
1778

1779
  """
1780
  return os.urandom(numbytes).encode('hex')
1781

    
1782

    
1783
def EnsureDirs(dirs):
1784
  """Make required directories, if they don't exist.
1785

1786
  @param dirs: list of tuples (dir_name, dir_mode)
1787
  @type dirs: list of (string, integer)
1788

1789
  """
1790
  for dir_name, dir_mode in dirs:
1791
    try:
1792
      os.mkdir(dir_name, dir_mode)
1793
    except EnvironmentError, err:
1794
      if err.errno != errno.EEXIST:
1795
        raise errors.GenericError("Cannot create needed directory"
1796
                                  " '%s': %s" % (dir_name, err))
1797
    try:
1798
      os.chmod(dir_name, dir_mode)
1799
    except EnvironmentError, err:
1800
      raise errors.GenericError("Cannot change directory permissions on"
1801
                                " '%s': %s" % (dir_name, err))
1802
    if not os.path.isdir(dir_name):
1803
      raise errors.GenericError("%s is not a directory" % dir_name)
1804
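# Illustrative sketch, not part of the original module: EnsureDirs takes
# (path, mode) tuples and creates/fixes each directory. The paths below are
# made-up example values.
def _ExampleEnsureDirsUsage():
  """Hypothetical helper creating two example directories."""
  EnsureDirs([
    ("/var/run/example", 0755),
    ("/var/lib/example", 0700),
    ])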

    
1805

    
1806
def ReadFile(file_name, size=-1):
1807
  """Reads a file.
1808

1809
  @type size: int
1810
  @param size: Read at most size bytes (if negative, entire file)
1811
  @rtype: str
1812
  @return: the (possibly partial) content of the file
1813

1814
  """
1815
  f = open(file_name, "r")
1816
  try:
1817
    return f.read(size)
1818
  finally:
1819
    f.close()
1820

    
1821

    
1822
def WriteFile(file_name, fn=None, data=None,
1823
              mode=None, uid=-1, gid=-1,
1824
              atime=None, mtime=None, close=True,
1825
              dry_run=False, backup=False,
1826
              prewrite=None, postwrite=None):
1827
  """(Over)write a file atomically.
1828

1829
  The file_name and either fn (a function taking one argument, the
1830
  file descriptor, and which should write the data to it) or data (the
1831
  contents of the file) must be passed. The other arguments are
1832
  optional and allow setting the file mode, owner and group, and the
1833
  mtime/atime of the file.
1834

1835
  If the function doesn't raise an exception, it has succeeded and the
1836
  target file has the new contents. If the function has raised an
1837
  exception, an existing target file should be unmodified and the
1838
  temporary file should be removed.
1839

1840
  @type file_name: str
1841
  @param file_name: the target filename
1842
  @type fn: callable
1843
  @param fn: content writing function, called with
1844
      file descriptor as parameter
1845
  @type data: str
1846
  @param data: contents of the file
1847
  @type mode: int
1848
  @param mode: file mode
1849
  @type uid: int
1850
  @param uid: the owner of the file
1851
  @type gid: int
1852
  @param gid: the group of the file
1853
  @type atime: int
1854
  @param atime: a custom access time to be set on the file
1855
  @type mtime: int
1856
  @param mtime: a custom modification time to be set on the file
1857
  @type close: boolean
1858
  @param close: whether to close file after writing it
1859
  @type prewrite: callable
1860
  @param prewrite: function to be called before writing content
1861
  @type postwrite: callable
1862
  @param postwrite: function to be called after writing content
1863

1864
  @rtype: None or int
1865
  @return: None if the 'close' parameter evaluates to True,
1866
      otherwise the file descriptor
1867

1868
  @raise errors.ProgrammerError: if any of the arguments are not valid
1869

1870
  """
1871
  if not os.path.isabs(file_name):
1872
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1873
                                 " absolute: '%s'" % file_name)
1874

    
1875
  if [fn, data].count(None) != 1:
1876
    raise errors.ProgrammerError("fn or data required")
1877

    
1878
  if [atime, mtime].count(None) == 1:
1879
    raise errors.ProgrammerError("Both atime and mtime must be either"
1880
                                 " set or None")
1881

    
1882
  if backup and not dry_run and os.path.isfile(file_name):
1883
    CreateBackup(file_name)
1884

    
1885
  dir_name, base_name = os.path.split(file_name)
1886
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1887
  do_remove = True
1888
  # here we need to make sure we remove the temp file, if any error
1889
  # leaves it in place
1890
  try:
1891
    if uid != -1 or gid != -1:
1892
      os.chown(new_name, uid, gid)
1893
    if mode:
1894
      os.chmod(new_name, mode)
1895
    if callable(prewrite):
1896
      prewrite(fd)
1897
    if data is not None:
1898
      os.write(fd, data)
1899
    else:
1900
      fn(fd)
1901
    if callable(postwrite):
1902
      postwrite(fd)
1903
    os.fsync(fd)
1904
    if atime is not None and mtime is not None:
1905
      os.utime(new_name, (atime, mtime))
1906
    if not dry_run:
1907
      os.rename(new_name, file_name)
1908
      do_remove = False
1909
  finally:
1910
    if close:
1911
      os.close(fd)
1912
      result = None
1913
    else:
1914
      result = fd
1915
    if do_remove:
1916
      RemoveFile(new_name)
1917

    
1918
  return result
1919
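# Illustrative sketch, not part of the original module: a typical WriteFile
# call that atomically replaces a file. The path and contents are made-up
# example values; data, mode and backup are parameters documented above.
def _ExampleWriteFileUsage():
  """Hypothetical helper performing an atomic file replacement."""
  WriteFile("/tmp/example-config", data="key=value\n", mode=0644,
            backup=True)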

    
1920

    
1921
def ReadOneLineFile(file_name, strict=False):
1922
  """Return the first non-empty line from a file.
1923

1924
  @type strict: boolean
1925
  @param strict: if True, abort if the file has more than one
1926
      non-empty line
1927

1928
  """
1929
  file_lines = ReadFile(file_name).splitlines()
1930
  full_lines = filter(bool, file_lines)
1931
  if not file_lines or not full_lines:
1932
    raise errors.GenericError("No data in one-liner file %s" % file_name)
1933
  elif strict and len(full_lines) > 1:
1934
    raise errors.GenericError("Too many lines in one-liner file %s" %
1935
                              file_name)
1936
  return full_lines[0]
1937

    
1938

    
1939
def FirstFree(seq, base=0):
1940
  """Returns the first non-existing integer from seq.
1941

1942
  The seq argument should be a sorted list of positive integers. The
1943
  first time the index of an element is smaller than the element
1944
  value, the index will be returned.
1945

1946
  The base argument is used to start at a different offset,
1947
  i.e. C{[3, 4, 6]} with I{base=3} will return 5.
1948

1949
  Example: C{[0, 1, 3]} will return I{2}.
1950

1951
  @type seq: sequence
1952
  @param seq: the sequence to be analyzed.
1953
  @type base: int
1954
  @param base: use this value as the base index of the sequence
1955
  @rtype: int
1956
  @return: the first non-used index in the sequence
1957

1958
  """
1959
  for idx, elem in enumerate(seq):
1960
    assert elem >= base, "Passed element is lower than the base offset"
1961
    if elem > idx + base:
1962
      # idx is not used
1963
      return idx + base
1964
  return None
1965
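# Illustrative sketch, not part of the original module: the FirstFree
# behaviour from the docstring above, written as assertions.
def _ExampleFirstFreeUsage():
  """Hypothetical helper checking FirstFree on the documented examples."""
  assert FirstFree([0, 1, 3]) == 2          # index 2 is unused
  assert FirstFree([3, 4, 6], base=3) == 5  # first gap above the base
  assert FirstFree([0, 1, 2]) is None       # no gap at all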

    
1966

    
1967
def SingleWaitForFdCondition(fdobj, event, timeout):
1968
  """Waits for a condition to occur on the socket.
1969

1970
  Immediately returns at the first interruption.
1971

1972
  @type fdobj: integer or object supporting a fileno() method
1973
  @param fdobj: entity to wait for events on
1974
  @type event: integer
1975
  @param event: ORed condition (see select module)
1976
  @type timeout: float or None
1977
  @param timeout: Timeout in seconds
1978
  @rtype: int or None
1979
  @return: None for timeout, otherwise the conditions that occurred
1980

1981
  """
1982
  check = (event | select.POLLPRI |
1983
           select.POLLNVAL | select.POLLHUP | select.POLLERR)
1984

    
1985
  if timeout is not None:
1986
    # Poller object expects milliseconds
1987
    timeout *= 1000
1988

    
1989
  poller = select.poll()
1990
  poller.register(fdobj, event)
1991
  try:
1992
    # TODO: If the main thread receives a signal and we have no timeout, we
1993
    # could wait forever. This should check a global "quit" flag or something
1994
    # every so often.
1995
    io_events = poller.poll(timeout)
1996
  except select.error, err:
1997
    if err[0] != errno.EINTR:
1998
      raise
1999
    io_events = []
2000
  if io_events and io_events[0][1] & check:
2001
    return io_events[0][1]
2002
  else:
2003
    return None
2004

    
2005

    
2006
class FdConditionWaiterHelper(object):
2007
  """Retry helper for WaitForFdCondition.
2008

2009
  This class contains the retried and wait functions that make sure
2010
  WaitForFdCondition can continue waiting until the timeout is actually
2011
  expired.
2012

2013
  """
2014

    
2015
  def __init__(self, timeout):
2016
    self.timeout = timeout
2017

    
2018
  def Poll(self, fdobj, event):
2019
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
2020
    if result is None:
2021
      raise RetryAgain()
2022
    else:
2023
      return result
2024

    
2025
  def UpdateTimeout(self, timeout):
2026
    self.timeout = timeout
2027

    
2028

    
2029
def WaitForFdCondition(fdobj, event, timeout):
2030
  """Waits for a condition to occur on the socket.
2031

2032
  Retries until the timeout is expired, even if interrupted.
2033

2034
  @type fdobj: integer or object supporting a fileno() method
2035
  @param fdobj: entity to wait for events on
2036
  @type event: integer
2037
  @param event: ORed condition (see select module)
2038
  @type timeout: float or None
2039
  @param timeout: Timeout in seconds
2040
  @rtype: int or None
2041
  @return: None for timeout, otherwise the conditions that occurred
2042

2043
  """
2044
  if timeout is not None:
2045
    retrywaiter = FdConditionWaiterHelper(timeout)
2046
    try:
2047
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
2048
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
2049
    except RetryTimeout:
2050
      result = None
2051
  else:
2052
    result = None
2053
    while result is None:
2054
      result = SingleWaitForFdCondition(fdobj, event, timeout)
2055
  return result
2056

    
2057

    
2058
def UniqueSequence(seq):
2059
  """Returns a list with unique elements.
2060

2061
  Element order is preserved.
2062

2063
  @type seq: sequence
2064
  @param seq: the sequence with the source elements
2065
  @rtype: list
2066
  @return: list of unique elements from seq
2067

2068
  """
2069
  seen = set()
2070
  return [i for i in seq if i not in seen and not seen.add(i)]
2071

    
2072

    
2073
def NormalizeAndValidateMac(mac):
2074
  """Normalizes and check if a MAC address is valid.
2075

2076
  Checks whether the supplied MAC address is formally correct, only
2077
  accepts the colon-separated format. Normalizes it to all-lowercase.
2078

2079
  @type mac: str
2080
  @param mac: the MAC to be validated
2081
  @rtype: str
2082
  @return: returns the normalized and validated MAC.
2083

2084
  @raise errors.OpPrereqError: If the MAC isn't valid
2085

2086
  """
2087
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
2088
  if not mac_check.match(mac):
2089
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
2090
                               mac, errors.ECODE_INVAL)
2091

    
2092
  return mac.lower()
2093
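# Illustrative sketch, not part of the original module: only the
# colon-separated MAC format is accepted and the result is lower-cased.
def _ExampleMacValidation():
  """Hypothetical helper showing NormalizeAndValidateMac."""
  assert NormalizeAndValidateMac("AA:00:11:22:33:44") == "aa:00:11:22:33:44"
  try:
    NormalizeAndValidateMac("aa-00-11-22-33-44")  # dashes are rejected
  except errors.OpPrereqError:
    pass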

    
2094

    
2095
def TestDelay(duration):
2096
  """Sleep for a fixed amount of time.
2097

2098
  @type duration: float
2099
  @param duration: the sleep duration
2100
  @rtype: tuple
2101
  @return: (False, error message) for a negative duration, (True, None) otherwise
2102

2103
  """
2104
  if duration < 0:
2105
    return False, "Invalid sleep duration"
2106
  time.sleep(duration)
2107
  return True, None
2108

    
2109

    
2110
def _CloseFDNoErr(fd, retries=5):
2111
  """Close a file descriptor ignoring errors.
2112

2113
  @type fd: int
2114
  @param fd: the file descriptor
2115
  @type retries: int
2116
  @param retries: how many retries to make, in case we get any
2117
      other error than EBADF
2118

2119
  """
2120
  try:
2121
    os.close(fd)
2122
  except OSError, err:
2123
    if err.errno != errno.EBADF:
2124
      if retries > 0:
2125
        _CloseFDNoErr(fd, retries - 1)
2126
    # else either it's closed already or we're out of retries, so we
2127
    # ignore this and go on
2128

    
2129

    
2130
def CloseFDs(noclose_fds=None):
2131
  """Close file descriptors.
2132

2133
  This closes all file descriptors above 2 (i.e. except
2134
  stdin/out/err).
2135

2136
  @type noclose_fds: list or None
2137
  @param noclose_fds: if given, it denotes a list of file descriptor
2138
      that should not be closed
2139

2140
  """
2141
  # Default maximum for the number of available file descriptors.
2142
  if 'SC_OPEN_MAX' in os.sysconf_names:
2143
    try:
2144
      MAXFD = os.sysconf('SC_OPEN_MAX')
2145
      if MAXFD < 0:
2146
        MAXFD = 1024
2147
    except OSError:
2148
      MAXFD = 1024
2149
  else:
2150
    MAXFD = 1024
2151
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
2152
  if (maxfd == resource.RLIM_INFINITY):
2153
    maxfd = MAXFD
2154

    
2155
  # Iterate through and close all file descriptors (except the standard ones)
2156
  for fd in range(3, maxfd):
2157
    if noclose_fds and fd in noclose_fds:
2158
      continue
2159
    _CloseFDNoErr(fd)
2160

    
2161

    
2162
def Mlockall():
2163
  """Lock current process' virtual address space into RAM.
2164

2165
  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
2166
  see mlock(2) for more details. This function requires ctypes module.
2167

2168
  """
2169
  if ctypes is None:
2170
    logging.warning("Cannot set memory lock, ctypes module not found")
2171
    return
2172

    
2173
  libc = ctypes.cdll.LoadLibrary("libc.so.6")
2174
  if libc is None:
2175
    logging.error("Cannot set memory lock, ctypes cannot load libc")
2176
    return
2177

    
2178
  # Some older version of the ctypes module don't have built-in functionality
2179
  # to access the errno global variable, where function error codes are stored.
2180
  # By declaring this variable as a pointer to an integer we can then access
2181
  # its value correctly, should the mlockall call fail, in order to see what
2182
  # the actual error code was.
2183
  # pylint: disable-msg=W0212
2184
  libc.__errno_location.restype = ctypes.POINTER(ctypes.c_int)
2185

    
2186
  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
2187
    # pylint: disable-msg=W0212
2188
    logging.error("Cannot set memory lock: %s",
2189
                  os.strerror(libc.__errno_location().contents.value))
2190
    return
2191

    
2192
  logging.debug("Memory lock set")
2193

    
2194

    
2195
def Daemonize(logfile, run_uid, run_gid):
2196
  """Daemonize the current process.
2197

2198
  This detaches the current process from the controlling terminal and
2199
  runs it in the background as a daemon.
2200

2201
  @type logfile: str
2202
  @param logfile: the logfile to which we should redirect stdout/stderr
2203
  @type run_uid: int
2204
  @param run_uid: Run the child under this uid
2205
  @type run_gid: int
2206
  @param run_gid: Run the child under this gid
2207
  @rtype: int
2208
  @return: the value zero
2209

2210
  """
2211
  # pylint: disable-msg=W0212
2212
  # yes, we really want os._exit
2213
  UMASK = 077
2214
  WORKDIR = "/"
2215

    
2216
  # this might fail
2217
  pid = os.fork()
2218
  if (pid == 0):  # The first child.
2219
    os.setsid()
2220
    # FIXME: When removing again and moving to start-stop-daemon privilege drop
2221
    #        make sure to check for config permission and bail out when invoked
2222
    #        with wrong user.
2223
    os.setgid(run_gid)
2224
    os.setuid(run_uid)
2225
    # this might fail
2226
    pid = os.fork() # Fork a second child.
2227
    if (pid == 0):  # The second child.
2228
      os.chdir(WORKDIR)
2229
      os.umask(UMASK)
2230
    else:
2231
      # exit() or _exit()?  See below.
2232
      os._exit(0) # Exit parent (the first child) of the second child.
2233
  else:
2234
    os._exit(0) # Exit parent of the first child.
2235

    
2236
  for fd in range(3):
2237
    _CloseFDNoErr(fd)
2238
  i = os.open("/dev/null", os.O_RDONLY) # stdin
2239
  assert i == 0, "Can't close/reopen stdin"
2240
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
2241
  assert i == 1, "Can't close/reopen stdout"
2242
  # Duplicate standard output to standard error.
2243
  os.dup2(1, 2)
2244
  return 0
2245

    
2246

    
2247
def DaemonPidFileName(name):
2248
  """Compute a ganeti pid file absolute path
2249

2250
  @type name: str
2251
  @param name: the daemon name
2252
  @rtype: str
2253
  @return: the full path to the pidfile corresponding to the given
2254
      daemon name
2255

2256
  """
2257
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
2258

    
2259

    
2260
def EnsureDaemon(name):
2261
  """Check for and start daemon if not alive.
2262

2263
  """
2264
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
2265
  if result.failed:
2266
    logging.error("Can't start daemon '%s', failure %s, output: %s",
2267
                  name, result.fail_reason, result.output)
2268
    return False
2269

    
2270
  return True
2271

    
2272

    
2273
def StopDaemon(name):
2274
  """Stop daemon
2275

2276
  """
2277
  result = RunCmd([constants.DAEMON_UTIL, "stop", name])
2278
  if result.failed:
2279
    logging.error("Can't stop daemon '%s', failure %s, output: %s",
2280
                  name, result.fail_reason, result.output)
2281
    return False
2282

    
2283
  return True
2284

    
2285

    
2286
def WritePidFile(name):
2287
  """Write the current process pidfile.
2288

2289
  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}
2290

2291
  @type name: str
2292
  @param name: the daemon name to use
2293
  @raise errors.GenericError: if the pid file already exists and
2294
      points to a live process
2295

2296
  """
2297
  pid = os.getpid()
2298
  pidfilename = DaemonPidFileName(name)
2299
  if IsProcessAlive(ReadPidFile(pidfilename)):
2300
    raise errors.GenericError("%s contains a live process" % pidfilename)
2301

    
2302
  WriteFile(pidfilename, data="%d\n" % pid)
2303

    
2304

    
2305
def RemovePidFile(name):
2306
  """Remove the current process pidfile.
2307

2308
  Any errors are ignored.
2309

2310
  @type name: str
2311
  @param name: the daemon name used to derive the pidfile name
2312

2313
  """
2314
  pidfilename = DaemonPidFileName(name)
2315
  # TODO: we could check here that the file contains our pid
2316
  try:
2317
    RemoveFile(pidfilename)
2318
  except: # pylint: disable-msg=W0702
2319
    pass
2320

    
2321

    
2322
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
2323
                waitpid=False):
2324
  """Kill a process given by its pid.
2325

2326
  @type pid: int
2327
  @param pid: The PID to terminate.
2328
  @type signal_: int
2329
  @param signal_: The signal to send, by default SIGTERM
2330
  @type timeout: int
2331
  @param timeout: The timeout after which, if the process is still alive,
2332
                  a SIGKILL will be sent. If not positive, no such checking
2333
                  will be done
2334
  @type waitpid: boolean
2335
  @param waitpid: If true, we should waitpid on this process after
2336
      sending signals, since it's our own child and otherwise it
2337
      would remain as zombie
2338

2339
  """
2340
  def _helper(pid, signal_, wait):
2341
    """Simple helper to encapsulate the kill/waitpid sequence"""
2342
    os.kill(pid, signal_)
2343
    if wait:
2344
      try:
2345
        os.waitpid(pid, os.WNOHANG)
2346
      except OSError:
2347
        pass
2348

    
2349
  if pid <= 0:
2350
    # kill with pid=0 == suicide
2351
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
2352

    
2353
  if not IsProcessAlive(pid):
2354
    return
2355

    
2356
  _helper(pid, signal_, waitpid)
2357

    
2358
  if timeout <= 0:
2359
    return
2360

    
2361
  def _CheckProcess():
2362
    if not IsProcessAlive(pid):
2363
      return
2364

    
2365
    try:
2366
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
2367
    except OSError:
2368
      raise RetryAgain()
2369

    
2370
    if result_pid > 0:
2371
      return
2372

    
2373
    raise RetryAgain()
2374

    
2375
  try:
2376
    # Wait up to $timeout seconds
2377
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
2378
  except RetryTimeout:
2379
    pass
2380

    
2381
  if IsProcessAlive(pid):
2382
    # Kill process if it's still alive
2383
    _helper(pid, signal.SIGKILL, waitpid)
2384

    
2385

    
2386
def FindFile(name, search_path, test=os.path.exists):
2387
  """Look for a filesystem object in a given path.
2388

2389
  This is an abstract method to search for filesystem object (files,
2390
  dirs) under a given search path.
2391

2392
  @type name: str
2393
  @param name: the name to look for
2394
  @type search_path: str
2395
  @param search_path: location to start at
2396
  @type test: callable
2397
  @param test: a function taking one argument that should return True
2398
      if the a given object is valid; the default value is
2399
      os.path.exists, causing only existing files to be returned
2400
  @rtype: str or None
2401
  @return: full path to the object if found, None otherwise
2402

2403
  """
2404
  # validate the filename mask
2405
  if constants.EXT_PLUGIN_MASK.match(name) is None:
2406
    logging.critical("Invalid value passed for external script name: '%s'",
2407
                     name)
2408
    return None
2409

    
2410
  for dir_name in search_path:
2411
    # FIXME: investigate switch to PathJoin
2412
    item_name = os.path.sep.join([dir_name, name])
2413
    # check the user test and that we're indeed resolving to the given
2414
    # basename
2415
    if test(item_name) and os.path.basename(item_name) == name:
2416
      return item_name
2417
  return None
2418

    
2419

    
2420
def CheckVolumeGroupSize(vglist, vgname, minsize):
2421
  """Checks if the volume group list is valid.
2422

2423
  The function will check if a given volume group is in the list of
2424
  volume groups and has a minimum size.
2425

2426
  @type vglist: dict
2427
  @param vglist: dictionary of volume group names and their size
2428
  @type vgname: str
2429
  @param vgname: the volume group we should check
2430
  @type minsize: int
2431
  @param minsize: the minimum size we accept
2432
  @rtype: None or str
2433
  @return: None for success, otherwise the error message
2434

2435
  """
2436
  vgsize = vglist.get(vgname, None)
2437
  if vgsize is None:
2438
    return "volume group '%s' missing" % vgname
2439
  elif vgsize < minsize:
2440
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
2441
            (vgname, minsize, vgsize))
2442
  return None
2443

    
2444

    
2445
def SplitTime(value):
2446
  """Splits time as floating point number into a tuple.
2447

2448
  @param value: Time in seconds
2449
  @type value: int or float
2450
  @return: Tuple containing (seconds, microseconds)
2451

2452
  """
2453
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
2454

    
2455
  assert 0 <= seconds, \
2456
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2457
  assert 0 <= microseconds <= 999999, \
2458
    "Microseconds must be 0-999999, but are %s" % microseconds
2459

    
2460
  return (int(seconds), int(microseconds))
2461

    
2462

    
2463
def MergeTime(timetuple):
2464
  """Merges a tuple into time as a floating point number.
2465

2466
  @param timetuple: Time as tuple, (seconds, microseconds)
2467
  @type timetuple: tuple
2468
  @return: Time as a floating point number expressed in seconds
2469

2470
  """
2471
  (seconds, microseconds) = timetuple
2472

    
2473
  assert 0 <= seconds, \
2474
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2475
  assert 0 <= microseconds <= 999999, \
2476
    "Microseconds must be 0-999999, but are %s" % microseconds
2477

    
2478
  return float(seconds) + (float(microseconds) * 0.000001)
2479
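# Illustrative sketch, not part of the original module: SplitTime and
# MergeTime are inverses of each other (up to microsecond resolution).
def _ExampleSplitMergeTime():
  """Hypothetical helper showing the SplitTime/MergeTime round trip."""
  (secs, usecs) = SplitTime(1234.25)   # -> (1234, 250000)
  merged = MergeTime((secs, usecs))    # -> 1234.25
  assert abs(merged - 1234.25) < 1e-6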

    
2480

    
2481
def GetDaemonPort(daemon_name):
2482
  """Get the daemon port for this cluster.
2483

2484
  Note that this routine does not read a ganeti-specific file, but
2485
  instead uses C{socket.getservbyname} to allow pre-customization of
2486
  this parameter outside of Ganeti.
2487

2488
  @type daemon_name: string
2489
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
2490
  @rtype: int
2491

2492
  """
2493
  if daemon_name not in constants.DAEMONS_PORTS:
2494
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)
2495

    
2496
  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
2497
  try:
2498
    port = socket.getservbyname(daemon_name, proto)
2499
  except socket.error:
2500
    port = default_port
2501

    
2502
  return port
2503

    
2504

    
2505
class LogFileHandler(logging.FileHandler):
2506
  """Log handler that doesn't fallback to stderr.
2507

2508
  When an error occurs while writing to the logfile, logging.FileHandler tries
2509
  to log to stderr. This doesn't work in Ganeti since stderr is redirected to
2510
  the logfile. This class avoids such failures by reporting errors to /dev/console.
2511

2512
  """
2513
  def __init__(self, filename, mode="a", encoding=None):
2514
    """Open the specified file and use it as the stream for logging.
2515

2516
    Also open /dev/console to report errors while logging.
2517

2518
    """
2519
    logging.FileHandler.__init__(self, filename, mode, encoding)
2520
    self.console = open(constants.DEV_CONSOLE, "a")
2521

    
2522
  def handleError(self, record): # pylint: disable-msg=C0103
2523
    """Handle errors which occur during an emit() call.
2524

2525
    Try to handle errors with FileHandler method, if it fails write to
2526
    /dev/console.
2527

2528
    """
2529
    try:
2530
      logging.FileHandler.handleError(self, record)
2531
    except Exception: # pylint: disable-msg=W0703
2532
      try:
2533
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
2534
      except Exception: # pylint: disable-msg=W0703
2535
        # Log handler tried everything it could, now just give up
2536
        pass
2537

    
2538

    
2539
def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
2540
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
2541
                 console_logging=False):
2542
  """Configures the logging module.
2543

2544
  @type logfile: str
2545
  @param logfile: the filename to which we should log
2546
  @type debug: integer
2547
  @param debug: if greater than zero, enable debug messages, otherwise
2548
      only those at C{INFO} and above level
2549
  @type stderr_logging: boolean
2550
  @param stderr_logging: whether we should also log to the standard error
2551
  @type program: str
2552
  @param program: the name under which we should log messages
2553
  @type multithreaded: boolean
2554
  @param multithreaded: if True, will add the thread name to the log file
2555
  @type syslog: string
2556
  @param syslog: one of 'no', 'yes', 'only':
2557
      - if no, syslog is not used
2558
      - if yes, syslog is used (in addition to file-logging)
2559
      - if only, only syslog is used
2560
  @type console_logging: boolean
2561
  @param console_logging: if True, will use a FileHandler which falls back to
2562
      the system console if logging fails
2563
  @raise EnvironmentError: if we can't open the log file and
2564
      syslog/stderr logging is disabled
2565

2566
  """
2567
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
2568
  sft = program + "[%(process)d]:"
2569
  if multithreaded:
2570
    fmt += "/%(threadName)s"
2571
    sft += " (%(threadName)s)"
2572
  if debug:
2573
    fmt += " %(module)s:%(lineno)s"
2574
    # no debug info for syslog loggers
2575
  fmt += " %(levelname)s %(message)s"
2576
  # yes, we do want the textual level, as remote syslog will probably
2577
  # lose the error level, and it's easier to grep for it
2578
  sft += " %(levelname)s %(message)s"
2579
  formatter = logging.Formatter(fmt)
2580
  sys_fmt = logging.Formatter(sft)
2581

    
2582
  root_logger = logging.getLogger("")
2583
  root_logger.setLevel(logging.NOTSET)
2584

    
2585
  # Remove all previously setup handlers
2586
  for handler in root_logger.handlers:
2587
    handler.close()
2588
    root_logger.removeHandler(handler)
2589

    
2590
  if stderr_logging:
2591
    stderr_handler = logging.StreamHandler()
2592
    stderr_handler.setFormatter(formatter)
2593
    if debug:
2594
      stderr_handler.setLevel(logging.NOTSET)
2595
    else:
2596
      stderr_handler.setLevel(logging.CRITICAL)
2597
    root_logger.addHandler(stderr_handler)
2598

    
2599
  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
2600
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
2601
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
2602
                                                    facility)
2603
    syslog_handler.setFormatter(sys_fmt)
2604
    # Never enable debug over syslog
2605
    syslog_handler.setLevel(logging.INFO)
2606
    root_logger.addHandler(syslog_handler)
2607

    
2608
  if syslog != constants.SYSLOG_ONLY:
2609
    # this can fail, if the logging directories are not setup or we have
2610
    # a permission problem; in this case, it's best to log but ignore
2611
    # the error if stderr_logging is True, and if false we re-raise the
2612
    # exception since otherwise we could run but without any logs at all
2613
    try:
2614
      if console_logging:
2615
        logfile_handler = LogFileHandler(logfile)
2616
      else:
2617
        logfile_handler = logging.FileHandler(logfile)
2618
      logfile_handler.setFormatter(formatter)
2619
      if debug:
2620
        logfile_handler.setLevel(logging.DEBUG)
2621
      else:
2622
        logfile_handler.setLevel(logging.INFO)
2623
      root_logger.addHandler(logfile_handler)
2624
    except EnvironmentError:
2625
      if stderr_logging or syslog == constants.SYSLOG_YES:
2626
        logging.exception("Failed to enable logging to file '%s'", logfile)
2627
      else:
2628
        # we need to re-raise the exception
2629
        raise
2630

    
2631

    
2632
def IsNormAbsPath(path):
2633
  """Check whether a path is absolute and also normalized
2634

2635
  This avoids things like /dir/../../other/path to be valid.
2636

2637
  """
2638
  return os.path.normpath(path) == path and os.path.isabs(path)
2639

    
2640

    
2641
def PathJoin(*args):
2642
  """Safe-join a list of path components.
2643

2644
  Requirements:
2645
      - the first argument must be an absolute path
2646
      - no component in the path must have backtracking (e.g. /../),
2647
        since we check for normalization at the end
2648

2649
  @param args: the path components to be joined
2650
  @raise ValueError: for invalid paths
2651

2652
  """
2653
  # ensure we're having at least one path passed in
2654
  assert args
2655
  # ensure the first component is an absolute and normalized path name
2656
  root = args[0]
2657
  if not IsNormAbsPath(root):
2658
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
2659
  result = os.path.join(*args)
2660
  # ensure that the whole path is normalized
2661
  if not IsNormAbsPath(result):
2662
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
2663
  # check that we're still under the original prefix
2664
  prefix = os.path.commonprefix([root, result])
2665
  if prefix != root:
2666
    raise ValueError("Error: path joining resulted in different prefix"
2667
                     " (%s != %s)" % (prefix, root))
2668
  return result
2669
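# Illustrative sketch, not part of the original module: PathJoin only
# accepts joins that stay under the original absolute prefix. The paths are
# made-up example values.
def _ExamplePathJoinUsage():
  """Hypothetical helper showing valid and invalid PathJoin calls."""
  path = PathJoin("/var/lib/ganeti", "config.data")  # ok, stays under root
  try:
    PathJoin("/var/lib/ganeti", "../etc/passwd")  # backtracking is rejected
  except ValueError:
    pass
  return path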

    
2670

    
2671
def TailFile(fname, lines=20):
2672
  """Return the last lines from a file.
2673

2674
  @note: this function will only read and parse the last 4KB of
2675
      the file; if the lines are very long, it could be that less
2676
      than the requested number of lines are returned
2677

2678
  @param fname: the file name
2679
  @type lines: int
2680
  @param lines: the (maximum) number of lines to return
2681

2682
  """
2683
  fd = open(fname, "r")
2684
  try:
2685
    fd.seek(0, 2)
2686
    pos = fd.tell()
2687
    pos = max(0, pos-4096)
2688
    fd.seek(pos, 0)
2689
    raw_data = fd.read()
2690
  finally:
2691
    fd.close()
2692

    
2693
  rows = raw_data.splitlines()
2694
  return rows[-lines:]
2695

    
2696

    
2697
def FormatTimestampWithTZ(secs):
2698
  """Formats a Unix timestamp with the local timezone.
2699

2700
  """
2701
  return time.strftime("%F %T %Z", time.gmtime(secs))
2702

    
2703

    
2704
def _ParseAsn1Generalizedtime(value):
2705
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
2706

2707
  @type value: string
2708
  @param value: ASN1 GENERALIZEDTIME timestamp
2709

2710
  """
2711
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
2712
  if m:
2713
    # We have an offset
2714
    asn1time = m.group(1)
2715
    hours = int(m.group(2))
2716
    minutes = int(m.group(3))
2717
    utcoffset = (60 * hours) + minutes
2718
  else:
2719
    if not value.endswith("Z"):
2720
      raise ValueError("Missing timezone")
2721
    asn1time = value[:-1]
2722
    utcoffset = 0
2723

    
2724
  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
2725

    
2726
  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)
2727

    
2728
  return calendar.timegm(tt.utctimetuple())
2729

    
2730

    
2731
def GetX509CertValidity(cert):
2732
  """Returns the validity period of the certificate.
2733

2734
  @type cert: OpenSSL.crypto.X509
2735
  @param cert: X509 certificate object
2736

2737
  """
2738
  # The get_notBefore and get_notAfter functions are only supported in
2739
  # pyOpenSSL 0.7 and above.
2740
  try:
2741
    get_notbefore_fn = cert.get_notBefore
2742
  except AttributeError:
2743
    not_before = None
2744
  else:
2745
    not_before_asn1 = get_notbefore_fn()
2746

    
2747
    if not_before_asn1 is None:
2748
      not_before = None
2749
    else:
2750
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)
2751

    
2752
  try:
2753
    get_notafter_fn = cert.get_notAfter
2754
  except AttributeError:
2755
    not_after = None
2756
  else:
2757
    not_after_asn1 = get_notafter_fn()
2758

    
2759
    if not_after_asn1 is None:
2760
      not_after = None
2761
    else:
2762
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)
2763

    
2764
  return (not_before, not_after)
2765

    
2766

    
2767
def _VerifyCertificateInner(expired, not_before, not_after, now,
2768
                            warn_days, error_days):
2769
  """Verifies certificate validity.
2770

2771
  @type expired: bool
2772
  @param expired: Whether pyOpenSSL considers the certificate as expired
2773
  @type not_before: number or None
2774
  @param not_before: Unix timestamp before which certificate is not valid
2775
  @type not_after: number or None
2776
  @param not_after: Unix timestamp after which certificate is invalid
2777
  @type now: number
2778
  @param now: Current time as Unix timestamp
2779
  @type warn_days: number or None
2780
  @param warn_days: How many days before expiration a warning should be reported
2781
  @type error_days: number or None
2782
  @param error_days: How many days before expiration an error should be reported
2783

2784
  """
2785
  if expired:
2786
    msg = "Certificate is expired"
2787

    
2788
    if not_before is not None and not_after is not None:
2789
      msg += (" (valid from %s to %s)" %
2790
              (FormatTimestampWithTZ(not_before),
2791
               FormatTimestampWithTZ(not_after)))
2792
    elif not_before is not None:
2793
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
2794
    elif not_after is not None:
2795
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)
2796

    
2797
    return (CERT_ERROR, msg)
2798

    
2799
  elif not_before is not None and not_before > now:
2800
    return (CERT_WARNING,
2801
            "Certificate not yet valid (valid from %s)" %
2802
            FormatTimestampWithTZ(not_before))
2803

    
2804
  elif not_after is not None:
2805
    remaining_days = int((not_after - now) / (24 * 3600))
2806

    
2807
    msg = "Certificate expires in about %d days" % remaining_days
2808

    
2809
    if error_days is not None and remaining_days <= error_days:
2810
      return (CERT_ERROR, msg)
2811

    
2812
    if warn_days is not None and remaining_days <= warn_days:
2813
      return (CERT_WARNING, msg)
2814

    
2815
  return (None, None)
2816

    
2817

    
2818
def VerifyX509Certificate(cert, warn_days, error_days):
2819
  """Verifies a certificate for LUVerifyCluster.
2820

2821
  @type cert: OpenSSL.crypto.X509
2822
  @param cert: X509 certificate object
2823
  @type warn_days: number or None
2824
  @param warn_days: How many days before expiration a warning should be reported
2825
  @type error_days: number or None
2826
  @param error_days: How many days before expiration an error should be reported
2827

2828
  """
2829
  # Depending on the pyOpenSSL version, this can just return (None, None)
2830
  (not_before, not_after) = GetX509CertValidity(cert)
2831

    
2832
  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
2833
                                 time.time(), warn_days, error_days)
2834

    
2835

    
2836
def SignX509Certificate(cert, key, salt):
2837
  """Sign a X509 certificate.
2838

2839
  An RFC822-like signature header is added in front of the certificate.
2840

2841
  @type cert: OpenSSL.crypto.X509
2842
  @param cert: X509 certificate object
2843
  @type key: string
2844
  @param key: Key for HMAC
2845
  @type salt: string
2846
  @param salt: Salt for HMAC
2847
  @rtype: string
2848
  @return: Serialized and signed certificate in PEM format
2849

2850
  """
2851
  if not VALID_X509_SIGNATURE_SALT.match(salt):
2852
    raise errors.GenericError("Invalid salt: %r" % salt)
2853

    
2854
  # Dumping as PEM here ensures the certificate is in a sane format
2855
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2856

    
2857
  return ("%s: %s/%s\n\n%s" %
2858
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
2859
           Sha1Hmac(key, cert_pem, salt=salt),
2860
           cert_pem))
2861

    
2862

    
2863
def _ExtractX509CertificateSignature(cert_pem):
2864
  """Helper function to extract signature from X509 certificate.
2865

2866
  """
2867
  # Extract signature from original PEM data
2868
  for line in cert_pem.splitlines():
2869
    if line.startswith("---"):
2870
      break
2871

    
2872
    m = X509_SIGNATURE.match(line.strip())
2873
    if m:
2874
      return (m.group("salt"), m.group("sign"))
2875

    
2876
  raise errors.GenericError("X509 certificate signature is missing")
2877

    
2878

    
2879
def LoadSignedX509Certificate(cert_pem, key):
2880
  """Verifies a signed X509 certificate.
2881

2882
  @type cert_pem: string
2883
  @param cert_pem: Certificate in PEM format and with signature header
2884
  @type key: string
2885
  @param key: Key for HMAC
2886
  @rtype: tuple; (OpenSSL.crypto.X509, string)
2887
  @return: X509 certificate object and salt
2888

2889
  """
2890
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)
2891

    
2892
  # Load certificate
2893
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
2894

    
2895
  # Dump again to ensure it's in a sane format
2896
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2897

    
2898
  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
2899
    raise errors.GenericError("X509 certificate signature is invalid")
2900

    
2901
  return (cert, salt)
2902
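# Illustrative sketch, not part of the original module: signing a
# certificate with an HMAC key and loading it back. The key and salt are
# made-up example values (the salt must match VALID_X509_SIGNATURE_SALT);
# GenerateSelfSignedX509Cert is defined further down in this module.
def _ExampleSignedCertRoundTrip():
  """Hypothetical helper showing the sign/load round trip."""
  (_, cert_pem) = GenerateSelfSignedX509Cert("node.example.com", 3600)
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
  signed_pem = SignX509Certificate(cert, "hmac-key", "1234abcd")
  (loaded_cert, salt) = LoadSignedX509Certificate(signed_pem, "hmac-key")
  assert salt == "1234abcd"
  return loaded_cert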

    
2903

    
2904
def Sha1Hmac(key, text, salt=None):
2905
  """Calculates the HMAC-SHA1 digest of a text.
2906

2907
  HMAC is defined in RFC2104.
2908

2909
  @type key: string
2910
  @param key: Secret key
2911
  @type text: string
2912

2913
  """
2914
  if salt:
2915
    salted_text = salt + text
2916
  else:
2917
    salted_text = text
2918

    
2919
  return hmac.new(key, salted_text, compat.sha1).hexdigest()
2920

    
2921

    
2922
def VerifySha1Hmac(key, text, digest, salt=None):
2923
  """Verifies the HMAC-SHA1 digest of a text.
2924

2925
  HMAC is defined in RFC2104.
2926

2927
  @type key: string
2928
  @param key: Secret key
2929
  @type text: string
2930
  @type digest: string
2931
  @param digest: Expected digest
2932
  @rtype: bool
2933
  @return: Whether HMAC-SHA1 digest matches
2934

2935
  """
2936
  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()
2937
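# Illustrative sketch, not part of the original module: producing and
# checking an HMAC-SHA1 digest with an optional salt. Key, text and salt
# are made-up example values.
def _ExampleSha1HmacUsage():
  """Hypothetical helper showing the Sha1Hmac/VerifySha1Hmac pair."""
  digest = Sha1Hmac("secret-key", "some text", salt="abc123")
  assert VerifySha1Hmac("secret-key", "some text", digest, salt="abc123")
  assert not VerifySha1Hmac("other-key", "some text", digest, salt="abc123")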

    
2938

    
2939
def SafeEncode(text):
2940
  """Return a 'safe' version of a source string.
2941

2942
  This function mangles the input string and returns a version that
2943
  should be safe to display/encode as ASCII. To this end, we first
2944
  convert it to ASCII using the 'backslashreplace' encoding which
2945
  should get rid of any non-ASCII chars, and then we process it
2946
  through a loop copied from the string repr sources in the python; we
2947
  don't use string_escape anymore since that escape single quotes and
2948
  backslashes too, and that is too much; and that escaping is not
2949
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
2950

2951
  @type text: str or unicode
2952
  @param text: input data
2953
  @rtype: str
2954
  @return: a safe version of text
2955

2956
  """
2957
  if isinstance(text, unicode):
2958
    # only if unicode; if str already, we handle it below
2959
    text = text.encode('ascii', 'backslashreplace')
2960
  resu = ""
2961
  for char in text:
2962
    c = ord(char)
2963
    if char  == '\t':
2964
      resu += r'\t'
2965
    elif char == '\n':
2966
      resu += r'\n'
2967
    elif char == '\r':
2968
      resu += r'\r'
2969
    elif c < 32 or c >= 127: # non-printable
2970
      resu += "\\x%02x" % (c & 0xff)
2971
    else:
2972
      resu += char
2973
  return resu
2974

    
2975

    
2976
def UnescapeAndSplit(text, sep=","):
2977
  """Split and unescape a string based on a given separator.
2978

2979
  This function splits a string based on a separator where the
2980
  separator itself can be escaped in order to be part of an
2981
  element. The escaping rules are (assuming comma is the
2982
  separator):
2983
    - a plain , separates the elements
2984
    - a sequence \\\\, (double backslash plus comma) is handled as a
2985
      backslash plus a separator comma
2986
    - a sequence \, (backslash plus comma) is handled as a
2987
      non-separator comma
2988

2989
  @type text: string
2990
  @param text: the string to split
2991
  @type sep: string
2992
  @param sep: the separator
2993
  @rtype: list
2994
  @return: a list of strings
2995

2996
  """
2997
  # we split the list by sep (with no escaping at this stage)
2998
  slist = text.split(sep)
2999
  # next, we revisit the elements and if any of them ended with an odd
3000
  # number of backslashes, then we join it with the next
3001
  rlist = []
3002
  while slist:
3003
    e1 = slist.pop(0)
3004
    if e1.endswith("\\"):
3005
      num_b = len(e1) - len(e1.rstrip("\\"))
3006
      if num_b % 2 == 1:
3007
        e2 = slist.pop(0)
3008
        # here the backslashes remain (all), and will be reduced in
3009
        # the next step
3010
        rlist.append(e1 + sep + e2)
3011
        continue
3012
    rlist.append(e1)
3013
  # finally, replace backslash-something with something
3014
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
3015
  return rlist
3016
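# Illustrative sketch, not part of the original module: the escaping rules
# described above, shown on two small example strings.
def _ExampleUnescapeAndSplit():
  """Hypothetical helper showing UnescapeAndSplit on escaped separators."""
  # a backslash-escaped comma stays inside the element
  assert UnescapeAndSplit("a\\,b,c") == ["a,b", "c"]
  # a double backslash before the comma keeps the backslash and splits
  assert UnescapeAndSplit("a\\\\,b") == ["a\\", "b"]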

    
3017

    
3018
def CommaJoin(names):
3019
  """Nicely join a set of identifiers.
3020

3021
  @param names: set, list or tuple
3022
  @return: a string with the formatted results
3023

3024
  """
3025
  return ", ".join([str(val) for val in names])
3026

    
3027

    
3028
def BytesToMebibyte(value):
3029
  """Converts bytes to mebibytes.
3030

3031
  @type value: int
3032
  @param value: Value in bytes
3033
  @rtype: int
3034
  @return: Value in mebibytes
3035

3036
  """
3037
  return int(round(value / (1024.0 * 1024.0), 0))
3038

    
3039

    
3040
def CalculateDirectorySize(path):
3041
  """Calculates the size of a directory recursively.
3042

3043
  @type path: string
3044
  @param path: Path to directory
3045
  @rtype: int
3046
  @return: Size in mebibytes
3047

3048
  """
3049
  size = 0
3050

    
3051
  for (curpath, _, files) in os.walk(path):
3052
    for filename in files:
3053
      st = os.lstat(PathJoin(curpath, filename))
3054
      size += st.st_size
3055

    
3056
  return BytesToMebibyte(size)
3057

    
3058

    
3059
def GetFilesystemStats(path):
3060
  """Returns the total and free space on a filesystem.
3061

3062
  @type path: string
3063
  @param path: Path on filesystem to be examined
3064
  @rtype: tuple of (int, int)
3065
  @return: tuple of (Total space, Free space) in mebibytes
3066

3067
  """
3068
  st = os.statvfs(path)
3069

    
3070
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
3071
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
3072
  return (tsize, fsize)
3073

    
3074

    
3075
def RunInSeparateProcess(fn, *args):
3076
  """Runs a function in a separate process.
3077

3078
  Note: Only boolean return values are supported.
3079

3080
  @type fn: callable
3081
  @param fn: Function to be called
3082
  @rtype: bool
3083
  @return: Function's result
3084

3085
  """
3086
  pid = os.fork()
3087
  if pid == 0:
3088
    # Child process
3089
    try:
3090
      # In case the function uses temporary files
3091
      ResetTempfileModule()
3092

    
3093
      # Call function
3094
      result = int(bool(fn(*args)))
3095
      assert result in (0, 1)
3096
    except: # pylint: disable-msg=W0702
3097
      logging.exception("Error while calling function in separate process")
3098
      # 0 and 1 are reserved for the return value
3099
      result = 33
3100

    
3101
    os._exit(result) # pylint: disable-msg=W0212
3102

    
3103
  # Parent process
3104

    
3105
  # Avoid zombies and check exit code
3106
  (_, status) = os.waitpid(pid, 0)
3107

    
3108
  if os.WIFSIGNALED(status):
3109
    exitcode = None
3110
    signum = os.WTERMSIG(status)
3111
  else:
3112
    exitcode = os.WEXITSTATUS(status)
3113
    signum = None
3114

    
3115
  if not (exitcode in (0, 1) and signum is None):
3116
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
3117
                              (exitcode, signum))
3118

    
3119
  return bool(exitcode)
3120

    
3121

    
3122
def IgnoreSignals(fn, *args, **kwargs):
3123
  """Tries to call a function ignoring failures due to EINTR.
3124

3125
  """
3126
  try:
3127
    return fn(*args, **kwargs)
3128
  except EnvironmentError, err:
3129
    if err.errno == errno.EINTR:
3130
      return None
3131
    else:
3132
      raise
3133
  except (select.error, socket.error), err:
3134
    # In python 2.6 and above select.error is an IOError, so it's handled
3135
    # above, in 2.5 and below it's not, and it's handled here.
3136
    if err.args and err.args[0] == errno.EINTR:
3137
      return None
3138
    else:
3139
      raise
3140

    
3141

    
3142
def LockedMethod(fn):
3143
  """Synchronized object access decorator.
3144

3145
  This decorator is intended to protect access to an object using the
3146
  object's own lock which is hardcoded to '_lock'.
3147

3148
  """
3149
  def _LockDebug(*args, **kwargs):
3150
    if debug_locks:
3151
      logging.debug(*args, **kwargs)
3152

    
3153
  def wrapper(self, *args, **kwargs):
3154
    # pylint: disable-msg=W0212
3155
    assert hasattr(self, '_lock')
3156
    lock = self._lock
3157
    _LockDebug("Waiting for %s", lock)
3158
    lock.acquire()
3159
    try:
3160
      _LockDebug("Acquired %s", lock)
3161
      result = fn(self, *args, **kwargs)
3162
    finally:
3163
      _LockDebug("Releasing %s", lock)
3164
      lock.release()
3165
      _LockDebug("Released %s", lock)
3166
    return result
3167
  return wrapper
3168

    
3169

    
3170
def LockFile(fd):
3171
  """Locks a file using POSIX locks.
3172

3173
  @type fd: int
3174
  @param fd: the file descriptor we need to lock
3175

3176
  """
3177
  try:
3178
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
3179
  except IOError, err:
3180
    if err.errno == errno.EAGAIN:
3181
      raise errors.LockError("File already locked")
3182
    raise
3183

    
3184

    
3185
def FormatTime(val):
3186
  """Formats a time value.
3187

3188
  @type val: float or None
3189
  @param val: the timestamp as returned by time.time()
3190
  @return: a string value or N/A if we don't have a valid timestamp
3191

3192
  """
3193
  if val is None or not isinstance(val, (int, float)):
3194
    return "N/A"
3195
  # these two format codes work on Linux, but they are not guaranteed on all
3196
  # platforms
3197
  return time.strftime("%F %T", time.localtime(val))
3198

    
3199

    
3200
def FormatSeconds(secs):
3201
  """Formats seconds for easier reading.
3202

3203
  @type secs: number
3204
  @param secs: Number of seconds
3205
  @rtype: string
3206
  @return: Formatted seconds (e.g. "2d 9h 19m 49s")
3207

3208
  """
3209
  parts = []
3210

    
3211
  secs = round(secs, 0)
3212

    
3213
  if secs > 0:
3214
    # Negative values would be a bit tricky
3215
    for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]:
3216
      (complete, secs) = divmod(secs, one)
3217
      if complete or parts:
3218
        parts.append("%d%s" % (complete, unit))
3219

    
3220
  parts.append("%ds" % secs)
3221

    
3222
  return " ".join(parts)
3223
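# Illustrative sketch, not part of the original module: FormatSeconds output
# for two sample durations, including the docstring example.
def _ExampleFormatSeconds():
  """Hypothetical helper showing FormatSeconds on sample values."""
  assert FormatSeconds(65) == "1m 5s"
  assert FormatSeconds(2 * 86400 + 9 * 3600 + 19 * 60 + 49) == "2d 9h 19m 49s"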

    
3224

    
3225
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
3226
  """Reads the watcher pause file.
3227

3228
  @type filename: string
3229
  @param filename: Path to watcher pause file
3230
  @type now: None, float or int
3231
  @param now: Current time as Unix timestamp
3232
  @type remove_after: int
3233
  @param remove_after: Remove watcher pause file after specified amount of
3234
    seconds past the pause end time
3235

3236
  """
3237
  if now is None:
3238
    now = time.time()
3239

    
3240
  try:
3241
    value = ReadFile(filename)
3242
  except IOError, err:
3243
    if err.errno != errno.ENOENT:
3244
      raise
3245
    value = None
3246

    
3247
  if value is not None:
3248
    try:
3249
      value = int(value)
3250
    except ValueError:
3251
      logging.warning(("Watcher pause file (%s) contains invalid value,"
3252
                       " removing it"), filename)
3253
      RemoveFile(filename)
3254
      value = None
3255

    
3256
    if value is not None:
3257
      # Remove file if it's outdated
3258
      if now > (value + remove_after):
3259
        RemoveFile(filename)
3260
        value = None
3261

    
3262
      elif now > value:
3263
        value = None
3264

    
3265
  return value
3266

    
3267

    
3268
class RetryTimeout(Exception):
3269
  """Retry loop timed out.
3270

3271
  Any arguments which were passed by the retried function to RetryAgain will be
3272
  preserved in RetryTimeout, if it is raised. If such an argument was an exception,
3273
  the RaiseInner helper method will re-raise it.
3274

3275
  """
3276
  def RaiseInner(self):
3277
    if self.args and isinstance(self.args[0], Exception):
3278
      raise self.args[0]
3279
    else:
3280
      raise RetryTimeout(*self.args)
3281

    
3282

    
3283
class RetryAgain(Exception):
3284
  """Retry again.
3285

3286
  Any arguments passed to RetryAgain will be preserved, if a timeout occurs, as
3287
  arguments to RetryTimeout. If an exception is passed, the RaiseInner() method
3288
  of the RetryTimeout exception can be used to re-raise it.
3289

3290
  """
3291

    
3292

    
3293
class _RetryDelayCalculator(object):
3294
  """Calculator for increasing delays.
3295

3296
  """
3297
  __slots__ = [
3298
    "_factor",
3299
    "_limit",
3300
    "_next",
3301
    "_start",
3302
    ]
3303

    
3304
  def __init__(self, start, factor, limit):
3305
    """Initializes this class.
3306

3307
    @type start: float
3308
    @param start: Initial delay
3309
    @type factor: float
3310
    @param factor: Factor for delay increase
3311
    @type limit: float or None
3312
    @param limit: Upper limit for delay or None for no limit
3313

3314
    """
3315
    assert start > 0.0
3316
    assert factor >= 1.0
3317
    assert limit is None or limit >= 0.0
3318

    
3319
    self._start = start
3320
    self._factor = factor
3321
    self._limit = limit
3322

    
3323
    self._next = start
3324

    
3325
  def __call__(self):
3326
    """Returns current delay and calculates the next one.
3327

3328
    """
3329
    current = self._next
3330

    
3331
    # Update for next run
3332
    if self._limit is None or self._next < self._limit:
3333
      if self._limit is None:
        # no upper limit; in Python 2, min(None, x) would wrongly return None
        self._next *= self._factor
      else:
        self._next = min(self._limit, self._next * self._factor)
3334

    
3335
    return current
3336

    
3337

    
3338
#: Special delay to specify whole remaining timeout
3339
RETRY_REMAINING_TIME = object()
3340

    
3341

    
3342
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
3343
          _time_fn=time.time):
3344
  """Call a function repeatedly until it succeeds.
3345

3346
  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
3347
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
3348
  total of C{timeout} seconds, this function throws L{RetryTimeout}.
3349

3350
  C{delay} can be one of the following:
3351
    - callable returning the delay length as a float
3352
    - Tuple of (start, factor, limit)
3353
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
3354
      useful when overriding L{wait_fn} to wait for an external event)
3355
    - A static delay as a number (int or float)
3356

3357
  @type fn: callable
3358
  @param fn: Function to be called
3359
  @param delay: Either a callable (returning the delay), a tuple of (start,
3360
                factor, limit) (see L{_RetryDelayCalculator}),
3361
                L{RETRY_REMAINING_TIME} or a number (int or float)
3362
  @type timeout: float
3363
  @param timeout: Total timeout
3364
  @type wait_fn: callable
3365
  @param wait_fn: Waiting function
3366
  @return: Return value of function
3367

3368
  """
3369
  assert callable(fn)
3370
  assert callable(wait_fn)
3371
  assert callable(_time_fn)
3372

    
3373
  if args is None:
3374
    args = []
3375

    
3376
  end_time = _time_fn() + timeout
3377

    
3378
  if callable(delay):
3379
    # External function to calculate delay
3380
    calc_delay = delay
3381

    
3382
  elif isinstance(delay, (tuple, list)):
3383
    # Increasing delay with optional upper boundary
3384
    (start, factor, limit) = delay
3385
    calc_delay = _RetryDelayCalculator(start, factor, limit)
3386

    
3387
  elif delay is RETRY_REMAINING_TIME:
3388
    # Always use the remaining time
3389
    calc_delay = None
3390

    
3391
  else:
3392
    # Static delay
3393
    calc_delay = lambda: delay
3394

    
3395
  assert calc_delay is None or callable(calc_delay)
3396

    
3397
  while True:
3398
    retry_args = []
3399
    try:
3400
      # pylint: disable-msg=W0142
3401
      return fn(*args)
3402
    except RetryAgain, err:
3403
      retry_args = err.args
3404
    except RetryTimeout:
3405
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
3406
                                   " handle RetryTimeout")
3407

    
3408
    remaining_time = end_time - _time_fn()
3409

    
3410
    if remaining_time < 0.0:
3411
      # pylint: disable-msg=W0142
3412
      raise RetryTimeout(*retry_args)
3413

    
3414
    assert remaining_time >= 0.0
3415

    
3416
    if calc_delay is None:
3417
      wait_fn(remaining_time)
3418
    else:
3419
      current_delay = calc_delay()
3420
      if current_delay > 0.0:
3421
        wait_fn(current_delay)
3422
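# Illustrative sketch, not part of the original module: a typical Retry
# call. _ExampleRetryUsage, its inner check function and the file name are
# made-up; the check raises RetryAgain until the file exists or the total
# timeout expires.
def _ExampleRetryUsage():
  """Hypothetical helper polling for a file with an increasing delay."""
  def _CheckFileExists(path):
    if not os.path.exists(path):
      raise RetryAgain()
    return path
  try:
    # start at 0.1s between attempts, grow by 1.5x, cap at 2s, and give up
    # after 30 seconds in total
    return Retry(_CheckFileExists, (0.1, 1.5, 2.0), 30,
                 args=["/tmp/example-flag-file"])
  except RetryTimeout:
    return None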

    
3423

    
3424
def GetClosedTempfile(*args, **kwargs):
3425
  """Creates a temporary file and returns its path.
3426

3427
  """
3428
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
3429
  _CloseFDNoErr(fd)
3430
  return path
3431

    
3432

    
3433
def GenerateSelfSignedX509Cert(common_name, validity):
  """Generates a self-signed X509 certificate.

  @type common_name: string
  @param common_name: commonName value
  @type validity: int
  @param validity: Validity for certificate in seconds

  """
  # Create private and public key
  key = OpenSSL.crypto.PKey()
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)

  # Create self-signed certificate
  cert = OpenSSL.crypto.X509()
  if common_name:
    cert.get_subject().CN = common_name
  cert.set_serial_number(1)
  cert.gmtime_adj_notBefore(0)
  cert.gmtime_adj_notAfter(validity)
  cert.set_issuer(cert.get_subject())
  cert.set_pubkey(key)
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)

  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  return (key_pem, cert_pem)
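

# Illustrative sketch (not part of the module): generate a certificate valid
# for one year and load it back with pyOpenSSL to inspect it; the hostname is
# a hypothetical example value.
#
#   (key_pem, cert_pem) = GenerateSelfSignedX509Cert("node1.example.com",
#                                                    365 * 24 * 60 * 60)
#   cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
#                                          cert_pem)
#   assert cert.get_subject().CN == "node1.example.com"

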
def GenerateSelfSignedSslCert(filename, validity=(5 * 365)):
  """Legacy function to generate self-signed X509 certificate.

  """
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(None,
                                                   validity * 24 * 60 * 60)

  WriteFile(filename, mode=0400, data=key_pem + cert_pem)


class FileLock(object):
  """Utility class for file locks.

  """
  def __init__(self, fd, filename):
    """Constructor for FileLock.

    @type fd: file
    @param fd: File object
    @type filename: str
    @param filename: Path of the file opened at I{fd}

    """
    self.fd = fd
    self.filename = filename

  @classmethod
  def Open(cls, filename):
    """Creates and opens a file to be used as a file-based lock.

    @type filename: string
    @param filename: path to the file to be locked

    """
    # Using "os.open" is necessary to allow both opening existing file
    # read/write and creating if not existing. Vanilla "open" will truncate an
    # existing file -or- allow creating if not existing.
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
               filename)

  def __del__(self):
    self.Close()

  def Close(self):
    """Close the file and release the lock.

    """
    if hasattr(self, "fd") and self.fd:
      self.fd.close()
      self.fd = None

  def _flock(self, flag, blocking, timeout, errmsg):
    """Wrapper for fcntl.flock.

    @type flag: int
    @param flag: operation flag
    @type blocking: bool
    @param blocking: whether the operation should be done in blocking mode.
    @type timeout: None or float
    @param timeout: for how long the operation should be retried (implies
                    non-blocking mode).
    @type errmsg: string
    @param errmsg: error message in case operation fails.

    """
    assert self.fd, "Lock was closed"
    assert timeout is None or timeout >= 0, \
      "If specified, timeout must be positive"
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)
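

# Illustrative usage sketch for L{FileLock} (the path and timeout below are
# arbitrary example values):
#
#   lock = FileLock.Open("/var/lock/example.lock")
#   try:
#     lock.Exclusive(blocking=True, timeout=10)   # wait up to 10 seconds
#     ...                                         # critical section
#     lock.Unlock()
#   finally:
#     lock.Close()

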
class LineSplitter:
  """Splits data chunks into lines separated by newline.

  Instances provide a file-like interface.

  """
  def __init__(self, line_fn, *args):
    """Initializes this class.

    @type line_fn: callable
    @param line_fn: Function called for each line, first parameter is line
    @param args: Extra arguments for L{line_fn}

    """
    assert callable(line_fn)

    if args:
      # Python 2.4 doesn't have functools.partial yet
      self._line_fn = \
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
    else:
      self._line_fn = line_fn

    self._lines = collections.deque()
    self._buffer = ""

  def write(self, data):
    parts = (self._buffer + data).split("\n")
    self._buffer = parts.pop()
    self._lines.extend(parts)

  def flush(self):
    while self._lines:
      self._line_fn(self._lines.popleft().rstrip("\r\n"))

  def close(self):
    self.flush()
    if self._buffer:
      self._line_fn(self._buffer)
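

# Illustrative usage sketch for L{LineSplitter}: data can be written in
# arbitrary chunks and the callback is invoked once per complete line; the
# trailing partial line is delivered by close().
#
#   received = []
#   splitter = LineSplitter(received.append)
#   splitter.write("first\nsec")
#   splitter.write("ond\nthird")
#   splitter.close()
#   # received == ["first", "second", "third"]

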
def SignalHandled(signums):
  """Signal Handled decoration.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      assert 'signal_handlers' not in kwargs or \
             kwargs['signal_handlers'] is None or \
             isinstance(kwargs['signal_handlers'], dict), \
             "Wrong signal_handlers parameter in original function call"
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
        signal_handlers = kwargs['signal_handlers']
      else:
        signal_handlers = {}
        kwargs['signal_handlers'] = signal_handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          signal_handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        sighandler.Reset()
    return sig_function
  return wrap
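

# Illustrative usage sketch for L{SignalHandled} (the decorated function and
# its loop body are hypothetical):
#
#   @SignalHandled([signal.SIGTERM])
#   def _Serve(signal_handlers=None):
#     handler = signal_handlers[signal.SIGTERM]
#     while not handler.called:
#       _ProcessOneRequest()                # hypothetical work function

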
class SignalWakeupFd(object):
  try:
    # This is only supported in Python 2.5 and above (some distributions
    # backported it to Python 2.4)
    _set_wakeup_fd_fn = signal.set_wakeup_fd
  except AttributeError:
    # Not supported
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
      return -1
  else:
    def _SetWakeupFd(self, fd):
      return self._set_wakeup_fd_fn(fd)

  def __init__(self):
    """Initializes this class.

    """
    (read_fd, write_fd) = os.pipe()

    # Once these succeeded, the file descriptors will be closed automatically.
    # Buffer size 0 is important, otherwise .read() with a specified length
    # might buffer data and the file descriptors won't be marked readable.
    self._read_fh = os.fdopen(read_fd, "r", 0)
    self._write_fh = os.fdopen(write_fd, "w", 0)

    self._previous = self._SetWakeupFd(self._write_fh.fileno())

    # Utility functions
    self.fileno = self._read_fh.fileno
    self.read = self._read_fh.read

  def Reset(self):
    """Restores the previous wakeup file descriptor.

    """
    if hasattr(self, "_previous") and self._previous is not None:
      self._SetWakeupFd(self._previous)
      self._previous = None

  def Notify(self):
    """Notifies the wakeup file descriptor.

    """
    self._write_fh.write("\0")

  def __del__(self):
    """Called before object deletion.

    """
    self.Reset()


class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when the object is destroyed
  or when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: list
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None, wakeup=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function

    """
    assert handler_fn is None or callable(handler_fn)

    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn
    self._wakeup = wakeup

    self._previous = {}
    try:
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._wakeup:
      # Notify whoever is interested in signals
      self._wakeup.Notify()

    if self._handler_fn:
      self._handler_fn(signum, frame)
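

# Illustrative sketch combining L{SignalWakeupFd} and L{SignalHandler}: the
# wakeup pipe makes select() return promptly when SIGCHLD arrives (the helper
# called in the loop body is hypothetical).
#
#   wakeup = SignalWakeupFd()
#   handler = SignalHandler([signal.SIGCHLD], wakeup=wakeup)
#   try:
#     while True:
#       (readable, _, _) = select.select([wakeup], [], [], 30.0)
#       if readable:
#         wakeup.read(1)                    # drain the notification byte
#       if handler.called:
#         handler.Clear()
#         _ReapChildren()                   # hypothetical helper
#   finally:
#     handler.Reset()
#     wakeup.Reset()

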
class FieldSet(object):
  """A simple field set.

  Among the features are:
    - checking if a string is among a list of static string or regex objects
    - checking if a whole list of strings matches
    - returning the matching groups from a regex match

  Internally, all fields are held as regular expression objects.

  """
  def __init__(self, *items):
    self.items = [re.compile("^%s$" % value) for value in items]

  def Extend(self, other_set):
    """Extend the field set with the items from another one"""
    self.items.extend(other_set.items)

  def Matches(self, field):
    """Checks if a field matches the current set

    @type field: str
    @param field: the string to match
    @return: either None or a regular expression match object

    """
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
      return m
    return None

  def NonMatching(self, items):
    """Returns the list of fields not matching the current set

    @type items: list
    @param items: the list of fields to check
    @rtype: list
    @return: list of non-matching fields

    """
    return [val for val in items if not self.Matches(val)]
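

# Illustrative usage sketch for L{FieldSet} (the field names are arbitrary
# examples):
#
#   fields = FieldSet("name", "sda_size", r"tag/(\d+)")
#   fields.Matches("tag/12")             # -> match object, group(1) == "12"
#   fields.NonMatching(["name", "foo"])  # -> ["foo"]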