Statistics
| Branch: | Tag: | Revision:

root / lib / utils.py @ 560cbec1

History | View | Annotate | Download (105 kB)

1
#
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import os
31
import sys
32
import time
33
import subprocess
34
import re
35
import socket
36
import tempfile
37
import shutil
38
import errno
39
import pwd
40
import itertools
41
import select
42
import fcntl
43
import resource
44
import logging
45
import logging.handlers
46
import signal
47
import OpenSSL
48
import datetime
49
import calendar
50
import hmac
51
import collections
52
import struct
53
import IN
54

    
55
from cStringIO import StringIO
56

    
57
try:
58
  import ctypes
59
except ImportError:
60
  ctypes = None
61

    
62
from ganeti import errors
63
from ganeti import constants
64
from ganeti import compat
65

    
66

    
67
_locksheld = []
68
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
69

    
70
debug_locks = False
71

    
72
#: when set to True, L{RunCmd} is disabled
73
no_fork = False
74

    
75
_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
76

    
77
HEX_CHAR_RE = r"[a-zA-Z0-9]"
78
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
79
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
80
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
81
                             HEX_CHAR_RE, HEX_CHAR_RE),
82
                            re.S | re.I)
83

    
84
# Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
85
# struct ucred { pid_t pid; uid_t uid; gid_t gid; };
86
#
87
# The GNU C Library defines gid_t and uid_t to be "unsigned int" and
88
# pid_t to "int".
89
#
90
# IEEE Std 1003.1-2008:
91
# "nlink_t, uid_t, gid_t, and id_t shall be integer types"
92
# "blksize_t, pid_t, and ssize_t shall be signed integer types"
93
_STRUCT_UCRED = "iII"
94
_STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)
95

    
96
# Certificate verification results
97
(CERT_WARNING,
98
 CERT_ERROR) = range(1, 3)
99

    
100
# Flags for mlockall() (from bits/mman.h)
101
_MCL_CURRENT = 1
102
_MCL_FUTURE = 2
103

    
104

    
105
class RunResult(object):
106
  """Holds the result of running external programs.
107

108
  @type exit_code: int
109
  @ivar exit_code: the exit code of the program, or None (if the program
110
      didn't exit())
111
  @type signal: int or None
112
  @ivar signal: the signal that caused the program to finish, or None
113
      (if the program wasn't terminated by a signal)
114
  @type stdout: str
115
  @ivar stdout: the standard output of the program
116
  @type stderr: str
117
  @ivar stderr: the standard error of the program
118
  @type failed: boolean
119
  @ivar failed: True in case the program was
120
      terminated by a signal or exited with a non-zero exit code
121
  @ivar fail_reason: a string detailing the termination reason
122

123
  """
124
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
125
               "failed", "fail_reason", "cmd"]
126

    
127

    
128
  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
129
    self.cmd = cmd
130
    self.exit_code = exit_code
131
    self.signal = signal_
132
    self.stdout = stdout
133
    self.stderr = stderr
134
    self.failed = (signal_ is not None or exit_code != 0)
135

    
136
    if self.signal is not None:
137
      self.fail_reason = "terminated by signal %s" % self.signal
138
    elif self.exit_code is not None:
139
      self.fail_reason = "exited with exit code %s" % self.exit_code
140
    else:
141
      self.fail_reason = "unable to determine termination reason"
142

    
143
    if self.failed:
144
      logging.debug("Command '%s' failed (%s); output: %s",
145
                    self.cmd, self.fail_reason, self.output)
146

    
147
  def _GetOutput(self):
148
    """Returns the combined stdout and stderr for easier usage.
149

150
    """
151
    return self.stdout + self.stderr
152

    
153
  output = property(_GetOutput, None, None, "Return full output")
154

    
155

    
156
def _BuildCmdEnvironment(env, reset):
157
  """Builds the environment for an external program.
158

159
  """
160
  if reset:
161
    cmd_env = {}
162
  else:
163
    cmd_env = os.environ.copy()
164
    cmd_env["LC_ALL"] = "C"
165

    
166
  if env is not None:
167
    cmd_env.update(env)
168

    
169
  return cmd_env
170

    
171

    
172
def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
173
  """Execute a (shell) command.
174

175
  The command should not read from its standard input, as it will be
176
  closed.
177

178
  @type cmd: string or list
179
  @param cmd: Command to run
180
  @type env: dict
181
  @param env: Additional environment variables
182
  @type output: str
183
  @param output: if desired, the output of the command can be
184
      saved in a file instead of the RunResult instance; this
185
      parameter denotes the file name (if not None)
186
  @type cwd: string
187
  @param cwd: if specified, will be used as the working
188
      directory for the command; the default will be /
189
  @type reset_env: boolean
190
  @param reset_env: whether to reset or keep the default os environment
191
  @rtype: L{RunResult}
192
  @return: RunResult instance
193
  @raise errors.ProgrammerError: if we call this when forks are disabled
194

195
  """
196
  if no_fork:
197
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")
198

    
199
  if isinstance(cmd, basestring):
200
    strcmd = cmd
201
    shell = True
202
  else:
203
    cmd = [str(val) for val in cmd]
204
    strcmd = ShellQuoteArgs(cmd)
205
    shell = False
206

    
207
  if output:
208
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
209
  else:
210
    logging.debug("RunCmd %s", strcmd)
211

    
212
  cmd_env = _BuildCmdEnvironment(env, reset_env)
213

    
214
  try:
215
    if output is None:
216
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
217
    else:
218
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
219
      out = err = ""
220
  except OSError, err:
221
    if err.errno == errno.ENOENT:
222
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
223
                               (strcmd, err))
224
    else:
225
      raise
226

    
227
  if status >= 0:
228
    exitcode = status
229
    signal_ = None
230
  else:
231
    exitcode = None
232
    signal_ = -status
233

    
234
  return RunResult(exitcode, signal_, out, err, strcmd)
235

    
236

    
237
def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
238
                pidfile=None):
239
  """Start a daemon process after forking twice.
240

241
  @type cmd: string or list
242
  @param cmd: Command to run
243
  @type env: dict
244
  @param env: Additional environment variables
245
  @type cwd: string
246
  @param cwd: Working directory for the program
247
  @type output: string
248
  @param output: Path to file in which to save the output
249
  @type output_fd: int
250
  @param output_fd: File descriptor for output
251
  @type pidfile: string
252
  @param pidfile: Process ID file
253
  @rtype: int
254
  @return: Daemon process ID
255
  @raise errors.ProgrammerError: if we call this when forks are disabled
256

257
  """
258
  if no_fork:
259
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
260
                                 " disabled")
261

    
262
  if output and not (bool(output) ^ (output_fd is not None)):
263
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
264
                                 " specified")
265

    
266
  if isinstance(cmd, basestring):
267
    cmd = ["/bin/sh", "-c", cmd]
268

    
269
  strcmd = ShellQuoteArgs(cmd)
270

    
271
  if output:
272
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
273
  else:
274
    logging.debug("StartDaemon %s", strcmd)
275

    
276
  cmd_env = _BuildCmdEnvironment(env, False)
277

    
278
  # Create pipe for sending PID back
279
  (pidpipe_read, pidpipe_write) = os.pipe()
280
  try:
281
    try:
282
      # Create pipe for sending error messages
283
      (errpipe_read, errpipe_write) = os.pipe()
284
      try:
285
        try:
286
          # First fork
287
          pid = os.fork()
288
          if pid == 0:
289
            try:
290
              # Child process, won't return
291
              _StartDaemonChild(errpipe_read, errpipe_write,
292
                                pidpipe_read, pidpipe_write,
293
                                cmd, cmd_env, cwd,
294
                                output, output_fd, pidfile)
295
            finally:
296
              # Well, maybe child process failed
297
              os._exit(1) # pylint: disable-msg=W0212
298
        finally:
299
          _CloseFDNoErr(errpipe_write)
300

    
301
        # Wait for daemon to be started (or an error message to arrive) and read
302
        # up to 100 KB as an error message
303
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
304
      finally:
305
        _CloseFDNoErr(errpipe_read)
306
    finally:
307
      _CloseFDNoErr(pidpipe_write)
308

    
309
    # Read up to 128 bytes for PID
310
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
311
  finally:
312
    _CloseFDNoErr(pidpipe_read)
313

    
314
  # Try to avoid zombies by waiting for child process
315
  try:
316
    os.waitpid(pid, 0)
317
  except OSError:
318
    pass
319

    
320
  if errormsg:
321
    raise errors.OpExecError("Error when starting daemon process: %r" %
322
                             errormsg)
323

    
324
  try:
325
    return int(pidtext)
326
  except (ValueError, TypeError), err:
327
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
328
                             (pidtext, err))
329

    
330

    
331
def _StartDaemonChild(errpipe_read, errpipe_write,
332
                      pidpipe_read, pidpipe_write,
333
                      args, env, cwd,
334
                      output, fd_output, pidfile):
335
  """Child process for starting daemon.
336

337
  """
338
  try:
339
    # Close parent's side
340
    _CloseFDNoErr(errpipe_read)
341
    _CloseFDNoErr(pidpipe_read)
342

    
343
    # First child process
344
    os.chdir("/")
345
    os.umask(077)
346
    os.setsid()
347

    
348
    # And fork for the second time
349
    pid = os.fork()
350
    if pid != 0:
351
      # Exit first child process
352
      os._exit(0) # pylint: disable-msg=W0212
353

    
354
    # Make sure pipe is closed on execv* (and thereby notifies original process)
355
    SetCloseOnExecFlag(errpipe_write, True)
356

    
357
    # List of file descriptors to be left open
358
    noclose_fds = [errpipe_write]
359

    
360
    # Open PID file
361
    if pidfile:
362
      try:
363
        # TODO: Atomic replace with another locked file instead of writing into
364
        # it after creating
365
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)
366

    
367
        # Lock the PID file (and fail if not possible to do so). Any code
368
        # wanting to send a signal to the daemon should try to lock the PID
369
        # file before reading it. If acquiring the lock succeeds, the daemon is
370
        # no longer running and the signal should not be sent.
371
        LockFile(fd_pidfile)
372

    
373
        os.write(fd_pidfile, "%d\n" % os.getpid())
374
      except Exception, err:
375
        raise Exception("Creating and locking PID file failed: %s" % err)
376

    
377
      # Keeping the file open to hold the lock
378
      noclose_fds.append(fd_pidfile)
379

    
380
      SetCloseOnExecFlag(fd_pidfile, False)
381
    else:
382
      fd_pidfile = None
383

    
384
    # Open /dev/null
385
    fd_devnull = os.open(os.devnull, os.O_RDWR)
386

    
387
    assert not output or (bool(output) ^ (fd_output is not None))
388

    
389
    if fd_output is not None:
390
      pass
391
    elif output:
392
      # Open output file
393
      try:
394
        # TODO: Implement flag to set append=yes/no
395
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
396
      except EnvironmentError, err:
397
        raise Exception("Opening output file failed: %s" % err)
398
    else:
399
      fd_output = fd_devnull
400

    
401
    # Redirect standard I/O
402
    os.dup2(fd_devnull, 0)
403
    os.dup2(fd_output, 1)
404
    os.dup2(fd_output, 2)
405

    
406
    # Send daemon PID to parent
407
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))
408

    
409
    # Close all file descriptors except stdio and error message pipe
410
    CloseFDs(noclose_fds=noclose_fds)
411

    
412
    # Change working directory
413
    os.chdir(cwd)
414

    
415
    if env is None:
416
      os.execvp(args[0], args)
417
    else:
418
      os.execvpe(args[0], args, env)
419
  except: # pylint: disable-msg=W0702
420
    try:
421
      # Report errors to original process
422
      buf = str(sys.exc_info()[1])
423

    
424
      RetryOnSignal(os.write, errpipe_write, buf)
425
    except: # pylint: disable-msg=W0702
426
      # Ignore errors in error handling
427
      pass
428

    
429
  os._exit(1) # pylint: disable-msg=W0212
430

    
431

    
432
def _RunCmdPipe(cmd, env, via_shell, cwd):
433
  """Run a command and return its output.
434

435
  @type  cmd: string or list
436
  @param cmd: Command to run
437
  @type env: dict
438
  @param env: The environment to use
439
  @type via_shell: bool
440
  @param via_shell: if we should run via the shell
441
  @type cwd: string
442
  @param cwd: the working directory for the program
443
  @rtype: tuple
444
  @return: (out, err, status)
445

446
  """
447
  poller = select.poll()
448
  child = subprocess.Popen(cmd, shell=via_shell,
449
                           stderr=subprocess.PIPE,
450
                           stdout=subprocess.PIPE,
451
                           stdin=subprocess.PIPE,
452
                           close_fds=True, env=env,
453
                           cwd=cwd)
454

    
455
  child.stdin.close()
456
  poller.register(child.stdout, select.POLLIN)
457
  poller.register(child.stderr, select.POLLIN)
458
  out = StringIO()
459
  err = StringIO()
460
  fdmap = {
461
    child.stdout.fileno(): (out, child.stdout),
462
    child.stderr.fileno(): (err, child.stderr),
463
    }
464
  for fd in fdmap:
465
    SetNonblockFlag(fd, True)
466

    
467
  while fdmap:
468
    pollresult = RetryOnSignal(poller.poll)
469

    
470
    for fd, event in pollresult:
471
      if event & select.POLLIN or event & select.POLLPRI:
472
        data = fdmap[fd][1].read()
473
        # no data from read signifies EOF (the same as POLLHUP)
474
        if not data:
475
          poller.unregister(fd)
476
          del fdmap[fd]
477
          continue
478
        fdmap[fd][0].write(data)
479
      if (event & select.POLLNVAL or event & select.POLLHUP or
480
          event & select.POLLERR):
481
        poller.unregister(fd)
482
        del fdmap[fd]
483

    
484
  out = out.getvalue()
485
  err = err.getvalue()
486

    
487
  status = child.wait()
488
  return out, err, status
489

    
490

    
491
def _RunCmdFile(cmd, env, via_shell, output, cwd):
492
  """Run a command and save its output to a file.
493

494
  @type  cmd: string or list
495
  @param cmd: Command to run
496
  @type env: dict
497
  @param env: The environment to use
498
  @type via_shell: bool
499
  @param via_shell: if we should run via the shell
500
  @type output: str
501
  @param output: the filename in which to save the output
502
  @type cwd: string
503
  @param cwd: the working directory for the program
504
  @rtype: int
505
  @return: the exit status
506

507
  """
508
  fh = open(output, "a")
509
  try:
510
    child = subprocess.Popen(cmd, shell=via_shell,
511
                             stderr=subprocess.STDOUT,
512
                             stdout=fh,
513
                             stdin=subprocess.PIPE,
514
                             close_fds=True, env=env,
515
                             cwd=cwd)
516

    
517
    child.stdin.close()
518
    status = child.wait()
519
  finally:
520
    fh.close()
521
  return status
522

    
523

    
524
def SetCloseOnExecFlag(fd, enable):
525
  """Sets or unsets the close-on-exec flag on a file descriptor.
526

527
  @type fd: int
528
  @param fd: File descriptor
529
  @type enable: bool
530
  @param enable: Whether to set or unset it.
531

532
  """
533
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)
534

    
535
  if enable:
536
    flags |= fcntl.FD_CLOEXEC
537
  else:
538
    flags &= ~fcntl.FD_CLOEXEC
539

    
540
  fcntl.fcntl(fd, fcntl.F_SETFD, flags)
541

    
542

    
543
def SetNonblockFlag(fd, enable):
544
  """Sets or unsets the O_NONBLOCK flag on on a file descriptor.
545

546
  @type fd: int
547
  @param fd: File descriptor
548
  @type enable: bool
549
  @param enable: Whether to set or unset it
550

551
  """
552
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)
553

    
554
  if enable:
555
    flags |= os.O_NONBLOCK
556
  else:
557
    flags &= ~os.O_NONBLOCK
558

    
559
  fcntl.fcntl(fd, fcntl.F_SETFL, flags)
560

    
561

    
562
def RetryOnSignal(fn, *args, **kwargs):
563
  """Calls a function again if it failed due to EINTR.
564

565
  """
566
  while True:
567
    try:
568
      return fn(*args, **kwargs)
569
    except EnvironmentError, err:
570
      if err.errno != errno.EINTR:
571
        raise
572
    except (socket.error, select.error), err:
573
      # In python 2.6 and above select.error is an IOError, so it's handled
574
      # above, in 2.5 and below it's not, and it's handled here.
575
      if not (err.args and err.args[0] == errno.EINTR):
576
        raise
577

    
578

    
579
def RunParts(dir_name, env=None, reset_env=False):
580
  """Run Scripts or programs in a directory
581

582
  @type dir_name: string
583
  @param dir_name: absolute path to a directory
584
  @type env: dict
585
  @param env: The environment to use
586
  @type reset_env: boolean
587
  @param reset_env: whether to reset or keep the default os environment
588
  @rtype: list of tuples
589
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)
590

591
  """
592
  rr = []
593

    
594
  try:
595
    dir_contents = ListVisibleFiles(dir_name)
596
  except OSError, err:
597
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
598
    return rr
599

    
600
  for relname in sorted(dir_contents):
601
    fname = PathJoin(dir_name, relname)
602
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
603
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
604
      rr.append((relname, constants.RUNPARTS_SKIP, None))
605
    else:
606
      try:
607
        result = RunCmd([fname], env=env, reset_env=reset_env)
608
      except Exception, err: # pylint: disable-msg=W0703
609
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
610
      else:
611
        rr.append((relname, constants.RUNPARTS_RUN, result))
612

    
613
  return rr
614

    
615

    
616
def GetSocketCredentials(sock):
617
  """Returns the credentials of the foreign process connected to a socket.
618

619
  @param sock: Unix socket
620
  @rtype: tuple; (number, number, number)
621
  @return: The PID, UID and GID of the connected foreign process.
622

623
  """
624
  peercred = sock.getsockopt(socket.SOL_SOCKET, IN.SO_PEERCRED,
625
                             _STRUCT_UCRED_SIZE)
626
  return struct.unpack(_STRUCT_UCRED, peercred)
627

    
628

    
629
def RemoveFile(filename):
630
  """Remove a file ignoring some errors.
631

632
  Remove a file, ignoring non-existing ones or directories. Other
633
  errors are passed.
634

635
  @type filename: str
636
  @param filename: the file to be removed
637

638
  """
639
  try:
640
    os.unlink(filename)
641
  except OSError, err:
642
    if err.errno not in (errno.ENOENT, errno.EISDIR):
643
      raise
644

    
645

    
646
def RemoveDir(dirname):
647
  """Remove an empty directory.
648

649
  Remove a directory, ignoring non-existing ones.
650
  Other errors are passed. This includes the case,
651
  where the directory is not empty, so it can't be removed.
652

653
  @type dirname: str
654
  @param dirname: the empty directory to be removed
655

656
  """
657
  try:
658
    os.rmdir(dirname)
659
  except OSError, err:
660
    if err.errno != errno.ENOENT:
661
      raise
662

    
663

    
664
def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
665
  """Renames a file.
666

667
  @type old: string
668
  @param old: Original path
669
  @type new: string
670
  @param new: New path
671
  @type mkdir: bool
672
  @param mkdir: Whether to create target directory if it doesn't exist
673
  @type mkdir_mode: int
674
  @param mkdir_mode: Mode for newly created directories
675

676
  """
677
  try:
678
    return os.rename(old, new)
679
  except OSError, err:
680
    # In at least one use case of this function, the job queue, directory
681
    # creation is very rare. Checking for the directory before renaming is not
682
    # as efficient.
683
    if mkdir and err.errno == errno.ENOENT:
684
      # Create directory and try again
685
      Makedirs(os.path.dirname(new), mode=mkdir_mode)
686

    
687
      return os.rename(old, new)
688

    
689
    raise
690

    
691

    
692
def Makedirs(path, mode=0750):
693
  """Super-mkdir; create a leaf directory and all intermediate ones.
694

695
  This is a wrapper around C{os.makedirs} adding error handling not implemented
696
  before Python 2.5.
697

698
  """
699
  try:
700
    os.makedirs(path, mode)
701
  except OSError, err:
702
    # Ignore EEXIST. This is only handled in os.makedirs as included in
703
    # Python 2.5 and above.
704
    if err.errno != errno.EEXIST or not os.path.exists(path):
705
      raise
706

    
707

    
708
def ResetTempfileModule():
709
  """Resets the random name generator of the tempfile module.
710

711
  This function should be called after C{os.fork} in the child process to
712
  ensure it creates a newly seeded random generator. Otherwise it would
713
  generate the same random parts as the parent process. If several processes
714
  race for the creation of a temporary file, this could lead to one not getting
715
  a temporary name.
716

717
  """
718
  # pylint: disable-msg=W0212
719
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
720
    tempfile._once_lock.acquire()
721
    try:
722
      # Reset random name generator
723
      tempfile._name_sequence = None
724
    finally:
725
      tempfile._once_lock.release()
726
  else:
727
    logging.critical("The tempfile module misses at least one of the"
728
                     " '_once_lock' and '_name_sequence' attributes")
729

    
730

    
731
def _FingerprintFile(filename):
732
  """Compute the fingerprint of a file.
733

734
  If the file does not exist, a None will be returned
735
  instead.
736

737
  @type filename: str
738
  @param filename: the filename to checksum
739
  @rtype: str
740
  @return: the hex digest of the sha checksum of the contents
741
      of the file
742

743
  """
744
  if not (os.path.exists(filename) and os.path.isfile(filename)):
745
    return None
746

    
747
  f = open(filename)
748

    
749
  fp = compat.sha1_hash()
750
  while True:
751
    data = f.read(4096)
752
    if not data:
753
      break
754

    
755
    fp.update(data)
756

    
757
  return fp.hexdigest()
758

    
759

    
760
def FingerprintFiles(files):
761
  """Compute fingerprints for a list of files.
762

763
  @type files: list
764
  @param files: the list of filename to fingerprint
765
  @rtype: dict
766
  @return: a dictionary filename: fingerprint, holding only
767
      existing files
768

769
  """
770
  ret = {}
771

    
772
  for filename in files:
773
    cksum = _FingerprintFile(filename)
774
    if cksum:
775
      ret[filename] = cksum
776

    
777
  return ret
778

    
779

    
780
def ForceDictType(target, key_types, allowed_values=None):
781
  """Force the values of a dict to have certain types.
782

783
  @type target: dict
784
  @param target: the dict to update
785
  @type key_types: dict
786
  @param key_types: dict mapping target dict keys to types
787
                    in constants.ENFORCEABLE_TYPES
788
  @type allowed_values: list
789
  @keyword allowed_values: list of specially allowed values
790

791
  """
792
  if allowed_values is None:
793
    allowed_values = []
794

    
795
  if not isinstance(target, dict):
796
    msg = "Expected dictionary, got '%s'" % target
797
    raise errors.TypeEnforcementError(msg)
798

    
799
  for key in target:
800
    if key not in key_types:
801
      msg = "Unknown key '%s'" % key
802
      raise errors.TypeEnforcementError(msg)
803

    
804
    if target[key] in allowed_values:
805
      continue
806

    
807
    ktype = key_types[key]
808
    if ktype not in constants.ENFORCEABLE_TYPES:
809
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
810
      raise errors.ProgrammerError(msg)
811

    
812
    if ktype == constants.VTYPE_STRING:
813
      if not isinstance(target[key], basestring):
814
        if isinstance(target[key], bool) and not target[key]:
815
          target[key] = ''
816
        else:
817
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
818
          raise errors.TypeEnforcementError(msg)
819
    elif ktype == constants.VTYPE_BOOL:
820
      if isinstance(target[key], basestring) and target[key]:
821
        if target[key].lower() == constants.VALUE_FALSE:
822
          target[key] = False
823
        elif target[key].lower() == constants.VALUE_TRUE:
824
          target[key] = True
825
        else:
826
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
827
          raise errors.TypeEnforcementError(msg)
828
      elif target[key]:
829
        target[key] = True
830
      else:
831
        target[key] = False
832
    elif ktype == constants.VTYPE_SIZE:
833
      try:
834
        target[key] = ParseUnit(target[key])
835
      except errors.UnitParseError, err:
836
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
837
              (key, target[key], err)
838
        raise errors.TypeEnforcementError(msg)
839
    elif ktype == constants.VTYPE_INT:
840
      try:
841
        target[key] = int(target[key])
842
      except (ValueError, TypeError):
843
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
844
        raise errors.TypeEnforcementError(msg)
845

    
846

    
847
def _GetProcStatusPath(pid):
848
  """Returns the path for a PID's proc status file.
849

850
  @type pid: int
851
  @param pid: Process ID
852
  @rtype: string
853

854
  """
855
  return "/proc/%d/status" % pid
856

    
857

    
858
def IsProcessAlive(pid):
859
  """Check if a given pid exists on the system.
860

861
  @note: zombie status is not handled, so zombie processes
862
      will be returned as alive
863
  @type pid: int
864
  @param pid: the process ID to check
865
  @rtype: boolean
866
  @return: True if the process exists
867

868
  """
869
  def _TryStat(name):
870
    try:
871
      os.stat(name)
872
      return True
873
    except EnvironmentError, err:
874
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
875
        return False
876
      elif err.errno == errno.EINVAL:
877
        raise RetryAgain(err)
878
      raise
879

    
880
  assert isinstance(pid, int), "pid must be an integer"
881
  if pid <= 0:
882
    return False
883

    
884
  # /proc in a multiprocessor environment can have strange behaviors.
885
  # Retry the os.stat a few times until we get a good result.
886
  try:
887
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
888
                 args=[_GetProcStatusPath(pid)])
889
  except RetryTimeout, err:
890
    err.RaiseInner()
891

    
892

    
893
def _ParseSigsetT(sigset):
894
  """Parse a rendered sigset_t value.
895

896
  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
897
  function.
898

899
  @type sigset: string
900
  @param sigset: Rendered signal set from /proc/$pid/status
901
  @rtype: set
902
  @return: Set of all enabled signal numbers
903

904
  """
905
  result = set()
906

    
907
  signum = 0
908
  for ch in reversed(sigset):
909
    chv = int(ch, 16)
910

    
911
    # The following could be done in a loop, but it's easier to read and
912
    # understand in the unrolled form
913
    if chv & 1:
914
      result.add(signum + 1)
915
    if chv & 2:
916
      result.add(signum + 2)
917
    if chv & 4:
918
      result.add(signum + 3)
919
    if chv & 8:
920
      result.add(signum + 4)
921

    
922
    signum += 4
923

    
924
  return result
925

    
926

    
927
def _GetProcStatusField(pstatus, field):
928
  """Retrieves a field from the contents of a proc status file.
929

930
  @type pstatus: string
931
  @param pstatus: Contents of /proc/$pid/status
932
  @type field: string
933
  @param field: Name of field whose value should be returned
934
  @rtype: string
935

936
  """
937
  for line in pstatus.splitlines():
938
    parts = line.split(":", 1)
939

    
940
    if len(parts) < 2 or parts[0] != field:
941
      continue
942

    
943
    return parts[1].strip()
944

    
945
  return None
946

    
947

    
948
def IsProcessHandlingSignal(pid, signum, status_path=None):
949
  """Checks whether a process is handling a signal.
950

951
  @type pid: int
952
  @param pid: Process ID
953
  @type signum: int
954
  @param signum: Signal number
955
  @rtype: bool
956

957
  """
958
  if status_path is None:
959
    status_path = _GetProcStatusPath(pid)
960

    
961
  try:
962
    proc_status = ReadFile(status_path)
963
  except EnvironmentError, err:
964
    # In at least one case, reading /proc/$pid/status failed with ESRCH.
965
    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
966
      return False
967
    raise
968

    
969
  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
970
  if sigcgt is None:
971
    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)
972

    
973
  # Now check whether signal is handled
974
  return signum in _ParseSigsetT(sigcgt)
975

    
976

    
977
def ReadPidFile(pidfile):
978
  """Read a pid from a file.
979

980
  @type  pidfile: string
981
  @param pidfile: path to the file containing the pid
982
  @rtype: int
983
  @return: The process id, if the file exists and contains a valid PID,
984
           otherwise 0
985

986
  """
987
  try:
988
    raw_data = ReadOneLineFile(pidfile)
989
  except EnvironmentError, err:
990
    if err.errno != errno.ENOENT:
991
      logging.exception("Can't read pid file")
992
    return 0
993

    
994
  try:
995
    pid = int(raw_data)
996
  except (TypeError, ValueError), err:
997
    logging.info("Can't parse pid file contents", exc_info=True)
998
    return 0
999

    
1000
  return pid
1001

    
1002

    
1003
def ReadLockedPidFile(path):
1004
  """Reads a locked PID file.
1005

1006
  This can be used together with L{StartDaemon}.
1007

1008
  @type path: string
1009
  @param path: Path to PID file
1010
  @return: PID as integer or, if file was unlocked or couldn't be opened, None
1011

1012
  """
1013
  try:
1014
    fd = os.open(path, os.O_RDONLY)
1015
  except EnvironmentError, err:
1016
    if err.errno == errno.ENOENT:
1017
      # PID file doesn't exist
1018
      return None
1019
    raise
1020

    
1021
  try:
1022
    try:
1023
      # Try to acquire lock
1024
      LockFile(fd)
1025
    except errors.LockError:
1026
      # Couldn't lock, daemon is running
1027
      return int(os.read(fd, 100))
1028
  finally:
1029
    os.close(fd)
1030

    
1031
  return None
1032

    
1033

    
1034
def MatchNameComponent(key, name_list, case_sensitive=True):
1035
  """Try to match a name against a list.
1036

1037
  This function will try to match a name like test1 against a list
1038
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
1039
  this list, I{'test1'} as well as I{'test1.example'} will match, but
1040
  not I{'test1.ex'}. A multiple match will be considered as no match
1041
  at all (e.g. I{'test1'} against C{['test1.example.com',
1042
  'test1.example.org']}), except when the key fully matches an entry
1043
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
1044

1045
  @type key: str
1046
  @param key: the name to be searched
1047
  @type name_list: list
1048
  @param name_list: the list of strings against which to search the key
1049
  @type case_sensitive: boolean
1050
  @param case_sensitive: whether to provide a case-sensitive match
1051

1052
  @rtype: None or str
1053
  @return: None if there is no match I{or} if there are multiple matches,
1054
      otherwise the element from the list which matches
1055

1056
  """
1057
  if key in name_list:
1058
    return key
1059

    
1060
  re_flags = 0
1061
  if not case_sensitive:
1062
    re_flags |= re.IGNORECASE
1063
    key = key.upper()
1064
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
1065
  names_filtered = []
1066
  string_matches = []
1067
  for name in name_list:
1068
    if mo.match(name) is not None:
1069
      names_filtered.append(name)
1070
      if not case_sensitive and key == name.upper():
1071
        string_matches.append(name)
1072

    
1073
  if len(string_matches) == 1:
1074
    return string_matches[0]
1075
  if len(names_filtered) == 1:
1076
    return names_filtered[0]
1077
  return None
1078

    
1079

    
1080
class HostInfo:
1081
  """Class implementing resolver and hostname functionality
1082

1083
  """
1084
  _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$")
1085

    
1086
  def __init__(self, name=None):
1087
    """Initialize the host name object.
1088

1089
    If the name argument is not passed, it will use this system's
1090
    name.
1091

1092
    """
1093
    if name is None:
1094
      name = self.SysName()
1095

    
1096
    self.query = name
1097
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
1098
    self.ip = self.ipaddrs[0]
1099

    
1100
  def ShortName(self):
1101
    """Returns the hostname without domain.
1102

1103
    """
1104
    return self.name.split('.')[0]
1105

    
1106
  @staticmethod
1107
  def SysName():
1108
    """Return the current system's name.
1109

1110
    This is simply a wrapper over C{socket.gethostname()}.
1111

1112
    """
1113
    return socket.gethostname()
1114

    
1115
  @staticmethod
1116
  def LookupHostname(hostname):
1117
    """Look up hostname
1118

1119
    @type hostname: str
1120
    @param hostname: hostname to look up
1121

1122
    @rtype: tuple
1123
    @return: a tuple (name, aliases, ipaddrs) as returned by
1124
        C{socket.gethostbyname_ex}
1125
    @raise errors.ResolverError: in case of errors in resolving
1126

1127
    """
1128
    try:
1129
      result = socket.gethostbyname_ex(hostname)
1130
    except socket.gaierror, err:
1131
      # hostname not found in DNS
1132
      raise errors.ResolverError(hostname, err.args[0], err.args[1])
1133

    
1134
    return result
1135

    
1136
  @classmethod
1137
  def NormalizeName(cls, hostname):
1138
    """Validate and normalize the given hostname.
1139

1140
    @attention: the validation is a bit more relaxed than the standards
1141
        require; most importantly, we allow underscores in names
1142
    @raise errors.OpPrereqError: when the name is not valid
1143

1144
    """
1145
    hostname = hostname.lower()
1146
    if (not cls._VALID_NAME_RE.match(hostname) or
1147
        # double-dots, meaning empty label
1148
        ".." in hostname or
1149
        # empty initial label
1150
        hostname.startswith(".")):
1151
      raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
1152
                                 errors.ECODE_INVAL)
1153
    if hostname.endswith("."):
1154
      hostname = hostname.rstrip(".")
1155
    return hostname
1156

    
1157

    
1158
def GetHostInfo(name=None):
1159
  """Lookup host name and raise an OpPrereqError for failures"""
1160

    
1161
  try:
1162
    return HostInfo(name)
1163
  except errors.ResolverError, err:
1164
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
1165
                               (err[0], err[2]), errors.ECODE_RESOLVER)
1166

    
1167

    
1168
def ListVolumeGroups():
1169
  """List volume groups and their size
1170

1171
  @rtype: dict
1172
  @return:
1173
       Dictionary with keys volume name and values
1174
       the size of the volume
1175

1176
  """
1177
  command = "vgs --noheadings --units m --nosuffix -o name,size"
1178
  result = RunCmd(command)
1179
  retval = {}
1180
  if result.failed:
1181
    return retval
1182

    
1183
  for line in result.stdout.splitlines():
1184
    try:
1185
      name, size = line.split()
1186
      size = int(float(size))
1187
    except (IndexError, ValueError), err:
1188
      logging.error("Invalid output from vgs (%s): %s", err, line)
1189
      continue
1190

    
1191
    retval[name] = size
1192

    
1193
  return retval
1194

    
1195

    
1196
def BridgeExists(bridge):
1197
  """Check whether the given bridge exists in the system
1198

1199
  @type bridge: str
1200
  @param bridge: the bridge name to check
1201
  @rtype: boolean
1202
  @return: True if it does
1203

1204
  """
1205
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
1206

    
1207

    
1208
def NiceSort(name_list):
1209
  """Sort a list of strings based on digit and non-digit groupings.
1210

1211
  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
1212
  will sort the list in the logical order C{['a1', 'a2', 'a10',
1213
  'a11']}.
1214

1215
  The sort algorithm breaks each name in groups of either only-digits
1216
  or no-digits. Only the first eight such groups are considered, and
1217
  after that we just use what's left of the string.
1218

1219
  @type name_list: list
1220
  @param name_list: the names to be sorted
1221
  @rtype: list
1222
  @return: a copy of the name list sorted with our algorithm
1223

1224
  """
1225
  _SORTER_BASE = "(\D+|\d+)"
1226
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
1227
                                                  _SORTER_BASE, _SORTER_BASE,
1228
                                                  _SORTER_BASE, _SORTER_BASE,
1229
                                                  _SORTER_BASE, _SORTER_BASE)
1230
  _SORTER_RE = re.compile(_SORTER_FULL)
1231
  _SORTER_NODIGIT = re.compile("^\D*$")
1232
  def _TryInt(val):
1233
    """Attempts to convert a variable to integer."""
1234
    if val is None or _SORTER_NODIGIT.match(val):
1235
      return val
1236
    rval = int(val)
1237
    return rval
1238

    
1239
  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
1240
             for name in name_list]
1241
  to_sort.sort()
1242
  return [tup[1] for tup in to_sort]
1243

    
1244

    
1245
def TryConvert(fn, val):
1246
  """Try to convert a value ignoring errors.
1247

1248
  This function tries to apply function I{fn} to I{val}. If no
1249
  C{ValueError} or C{TypeError} exceptions are raised, it will return
1250
  the result, else it will return the original value. Any other
1251
  exceptions are propagated to the caller.
1252

1253
  @type fn: callable
1254
  @param fn: function to apply to the value
1255
  @param val: the value to be converted
1256
  @return: The converted value if the conversion was successful,
1257
      otherwise the original value.
1258

1259
  """
1260
  try:
1261
    nv = fn(val)
1262
  except (ValueError, TypeError):
1263
    nv = val
1264
  return nv
1265

    
1266

    
1267
def IsValidIP(ip):
1268
  """Verifies the syntax of an IPv4 address.
1269

1270
  This function checks if the IPv4 address passes is valid or not based
1271
  on syntax (not IP range, class calculations, etc.).
1272

1273
  @type ip: str
1274
  @param ip: the address to be checked
1275
  @rtype: a regular expression match object
1276
  @return: a regular expression match object, or None if the
1277
      address is not valid
1278

1279
  """
1280
  unit = "(0|[1-9]\d{0,2})"
1281
  #TODO: convert and return only boolean
1282
  return re.match("^%s\.%s\.%s\.%s$" % (unit, unit, unit, unit), ip)
1283

    
1284

    
1285
def IsValidShellParam(word):
1286
  """Verifies is the given word is safe from the shell's p.o.v.
1287

1288
  This means that we can pass this to a command via the shell and be
1289
  sure that it doesn't alter the command line and is passed as such to
1290
  the actual command.
1291

1292
  Note that we are overly restrictive here, in order to be on the safe
1293
  side.
1294

1295
  @type word: str
1296
  @param word: the word to check
1297
  @rtype: boolean
1298
  @return: True if the word is 'safe'
1299

1300
  """
1301
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))
1302

    
1303

    
1304
def BuildShellCmd(template, *args):
1305
  """Build a safe shell command line from the given arguments.
1306

1307
  This function will check all arguments in the args list so that they
1308
  are valid shell parameters (i.e. they don't contain shell
1309
  metacharacters). If everything is ok, it will return the result of
1310
  template % args.
1311

1312
  @type template: str
1313
  @param template: the string holding the template for the
1314
      string formatting
1315
  @rtype: str
1316
  @return: the expanded command line
1317

1318
  """
1319
  for word in args:
1320
    if not IsValidShellParam(word):
1321
      raise errors.ProgrammerError("Shell argument '%s' contains"
1322
                                   " invalid characters" % word)
1323
  return template % args
1324

    
1325

    
1326
def FormatUnit(value, units):
1327
  """Formats an incoming number of MiB with the appropriate unit.
1328

1329
  @type value: int
1330
  @param value: integer representing the value in MiB (1048576)
1331
  @type units: char
1332
  @param units: the type of formatting we should do:
1333
      - 'h' for automatic scaling
1334
      - 'm' for MiBs
1335
      - 'g' for GiBs
1336
      - 't' for TiBs
1337
  @rtype: str
1338
  @return: the formatted value (with suffix)
1339

1340
  """
1341
  if units not in ('m', 'g', 't', 'h'):
1342
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))
1343

    
1344
  suffix = ''
1345

    
1346
  if units == 'm' or (units == 'h' and value < 1024):
1347
    if units == 'h':
1348
      suffix = 'M'
1349
    return "%d%s" % (round(value, 0), suffix)
1350

    
1351
  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
1352
    if units == 'h':
1353
      suffix = 'G'
1354
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)
1355

    
1356
  else:
1357
    if units == 'h':
1358
      suffix = 'T'
1359
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)
1360

    
1361

    
1362
def ParseUnit(input_string):
1363
  """Tries to extract number and scale from the given string.
1364

1365
  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
1366
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
1367
  is always an int in MiB.
1368

1369
  """
1370
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
1371
  if not m:
1372
    raise errors.UnitParseError("Invalid format")
1373

    
1374
  value = float(m.groups()[0])
1375

    
1376
  unit = m.groups()[1]
1377
  if unit:
1378
    lcunit = unit.lower()
1379
  else:
1380
    lcunit = 'm'
1381

    
1382
  if lcunit in ('m', 'mb', 'mib'):
1383
    # Value already in MiB
1384
    pass
1385

    
1386
  elif lcunit in ('g', 'gb', 'gib'):
1387
    value *= 1024
1388

    
1389
  elif lcunit in ('t', 'tb', 'tib'):
1390
    value *= 1024 * 1024
1391

    
1392
  else:
1393
    raise errors.UnitParseError("Unknown unit: %s" % unit)
1394

    
1395
  # Make sure we round up
1396
  if int(value) < value:
1397
    value += 1
1398

    
1399
  # Round up to the next multiple of 4
1400
  value = int(value)
1401
  if value % 4:
1402
    value += 4 - value % 4
1403

    
1404
  return value
1405

    
1406

    
1407
def AddAuthorizedKey(file_name, key):
1408
  """Adds an SSH public key to an authorized_keys file.
1409

1410
  @type file_name: str
1411
  @param file_name: path to authorized_keys file
1412
  @type key: str
1413
  @param key: string containing key
1414

1415
  """
1416
  key_fields = key.split()
1417

    
1418
  f = open(file_name, 'a+')
1419
  try:
1420
    nl = True
1421
    for line in f:
1422
      # Ignore whitespace changes
1423
      if line.split() == key_fields:
1424
        break
1425
      nl = line.endswith('\n')
1426
    else:
1427
      if not nl:
1428
        f.write("\n")
1429
      f.write(key.rstrip('\r\n'))
1430
      f.write("\n")
1431
      f.flush()
1432
  finally:
1433
    f.close()
1434

    
1435

    
1436
def RemoveAuthorizedKey(file_name, key):
1437
  """Removes an SSH public key from an authorized_keys file.
1438

1439
  @type file_name: str
1440
  @param file_name: path to authorized_keys file
1441
  @type key: str
1442
  @param key: string containing key
1443

1444
  """
1445
  key_fields = key.split()
1446

    
1447
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1448
  try:
1449
    out = os.fdopen(fd, 'w')
1450
    try:
1451
      f = open(file_name, 'r')
1452
      try:
1453
        for line in f:
1454
          # Ignore whitespace changes while comparing lines
1455
          if line.split() != key_fields:
1456
            out.write(line)
1457

    
1458
        out.flush()
1459
        os.rename(tmpname, file_name)
1460
      finally:
1461
        f.close()
1462
    finally:
1463
      out.close()
1464
  except:
1465
    RemoveFile(tmpname)
1466
    raise
1467

    
1468

    
1469
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
1470
  """Sets the name of an IP address and hostname in /etc/hosts.
1471

1472
  @type file_name: str
1473
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1474
  @type ip: str
1475
  @param ip: the IP address
1476
  @type hostname: str
1477
  @param hostname: the hostname to be added
1478
  @type aliases: list
1479
  @param aliases: the list of aliases to add for the hostname
1480

1481
  """
1482
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1483
  # Ensure aliases are unique
1484
  aliases = UniqueSequence([hostname] + aliases)[1:]
1485

    
1486
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1487
  try:
1488
    out = os.fdopen(fd, 'w')
1489
    try:
1490
      f = open(file_name, 'r')
1491
      try:
1492
        for line in f:
1493
          fields = line.split()
1494
          if fields and not fields[0].startswith('#') and ip == fields[0]:
1495
            continue
1496
          out.write(line)
1497

    
1498
        out.write("%s\t%s" % (ip, hostname))
1499
        if aliases:
1500
          out.write(" %s" % ' '.join(aliases))
1501
        out.write('\n')
1502

    
1503
        out.flush()
1504
        os.fsync(out)
1505
        os.chmod(tmpname, 0644)
1506
        os.rename(tmpname, file_name)
1507
      finally:
1508
        f.close()
1509
    finally:
1510
      out.close()
1511
  except:
1512
    RemoveFile(tmpname)
1513
    raise
1514

    
1515

    
1516
def AddHostToEtcHosts(hostname):
1517
  """Wrapper around SetEtcHostsEntry.
1518

1519
  @type hostname: str
1520
  @param hostname: a hostname that will be resolved and added to
1521
      L{constants.ETC_HOSTS}
1522

1523
  """
1524
  hi = HostInfo(name=hostname)
1525
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])
1526

    
1527

    
1528
def RemoveEtcHostsEntry(file_name, hostname):
1529
  """Removes a hostname from /etc/hosts.
1530

1531
  IP addresses without names are removed from the file.
1532

1533
  @type file_name: str
1534
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1535
  @type hostname: str
1536
  @param hostname: the hostname to be removed
1537

1538
  """
1539
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1540
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1541
  try:
1542
    out = os.fdopen(fd, 'w')
1543
    try:
1544
      f = open(file_name, 'r')
1545
      try:
1546
        for line in f:
1547
          fields = line.split()
1548
          if len(fields) > 1 and not fields[0].startswith('#'):
1549
            names = fields[1:]
1550
            if hostname in names:
1551
              while hostname in names:
1552
                names.remove(hostname)
1553
              if names:
1554
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
1555
              continue
1556

    
1557
          out.write(line)
1558

    
1559
        out.flush()
1560
        os.fsync(out)
1561
        os.chmod(tmpname, 0644)
1562
        os.rename(tmpname, file_name)
1563
      finally:
1564
        f.close()
1565
    finally:
1566
      out.close()
1567
  except:
1568
    RemoveFile(tmpname)
1569
    raise
1570

    
1571

    
1572
def RemoveHostFromEtcHosts(hostname):
1573
  """Wrapper around RemoveEtcHostsEntry.
1574

1575
  @type hostname: str
1576
  @param hostname: hostname that will be resolved and its
1577
      full and shot name will be removed from
1578
      L{constants.ETC_HOSTS}
1579

1580
  """
1581
  hi = HostInfo(name=hostname)
1582
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
1583
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())
1584

    
1585

    
1586
def TimestampForFilename():
1587
  """Returns the current time formatted for filenames.
1588

1589
  The format doesn't contain colons as some shells and applications them as
1590
  separators.
1591

1592
  """
1593
  return time.strftime("%Y-%m-%d_%H_%M_%S")
1594

    
1595

    
1596
def CreateBackup(file_name):
1597
  """Creates a backup of a file.
1598

1599
  @type file_name: str
1600
  @param file_name: file to be backed up
1601
  @rtype: str
1602
  @return: the path to the newly created backup
1603
  @raise errors.ProgrammerError: for invalid file names
1604

1605
  """
1606
  if not os.path.isfile(file_name):
1607
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
1608
                                file_name)
1609

    
1610
  prefix = ("%s.backup-%s." %
1611
            (os.path.basename(file_name), TimestampForFilename()))
1612
  dir_name = os.path.dirname(file_name)
1613

    
1614
  fsrc = open(file_name, 'rb')
1615
  try:
1616
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
1617
    fdst = os.fdopen(fd, 'wb')
1618
    try:
1619
      logging.debug("Backing up %s at %s", file_name, backup_name)
1620
      shutil.copyfileobj(fsrc, fdst)
1621
    finally:
1622
      fdst.close()
1623
  finally:
1624
    fsrc.close()
1625

    
1626
  return backup_name
1627

    
1628

    
1629
def ShellQuote(value):
1630
  """Quotes shell argument according to POSIX.
1631

1632
  @type value: str
1633
  @param value: the argument to be quoted
1634
  @rtype: str
1635
  @return: the quoted value
1636

1637
  """
1638
  if _re_shell_unquoted.match(value):
1639
    return value
1640
  else:
1641
    return "'%s'" % value.replace("'", "'\\''")
1642

    
1643

    
1644
def ShellQuoteArgs(args):
1645
  """Quotes a list of shell arguments.
1646

1647
  @type args: list
1648
  @param args: list of arguments to be quoted
1649
  @rtype: str
1650
  @return: the quoted arguments concatenated with spaces
1651

1652
  """
1653
  return ' '.join([ShellQuote(i) for i in args])
1654

    
1655

    
1656
def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
1657
  """Simple ping implementation using TCP connect(2).
1658

1659
  Check if the given IP is reachable by doing attempting a TCP connect
1660
  to it.
1661

1662
  @type target: str
1663
  @param target: the IP or hostname to ping
1664
  @type port: int
1665
  @param port: the port to connect to
1666
  @type timeout: int
1667
  @param timeout: the timeout on the connection attempt
1668
  @type live_port_needed: boolean
1669
  @param live_port_needed: whether a closed port will cause the
1670
      function to return failure, as if there was a timeout
1671
  @type source: str or None
1672
  @param source: if specified, will cause the connect to be made
1673
      from this specific source address; failures to bind other
1674
      than C{EADDRNOTAVAIL} will be ignored
1675

1676
  """
1677
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1678

    
1679
  success = False
1680

    
1681
  if source is not None:
1682
    try:
1683
      sock.bind((source, 0))
1684
    except socket.error, (errcode, _):
1685
      if errcode == errno.EADDRNOTAVAIL:
1686
        success = False
1687

    
1688
  sock.settimeout(timeout)
1689

    
1690
  try:
1691
    sock.connect((target, port))
1692
    sock.close()
1693
    success = True
1694
  except socket.timeout:
1695
    success = False
1696
  except socket.error, (errcode, _):
1697
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
1698

    
1699
  return success
1700

    
1701

    
1702
def OwnIpAddress(address):
1703
  """Check if the current host has the the given IP address.
1704

1705
  Currently this is done by TCP-pinging the address from the loopback
1706
  address.
1707

1708
  @type address: string
1709
  @param address: the address to check
1710
  @rtype: bool
1711
  @return: True if we own the address
1712

1713
  """
1714
  return TcpPing(address, constants.DEFAULT_NODED_PORT,
1715
                 source=constants.LOCALHOST_IP_ADDRESS)
1716

    
1717

    
1718
def ListVisibleFiles(path, sort=True):
1719
  """Returns a list of visible files in a directory.
1720

1721
  @type path: str
1722
  @param path: the directory to enumerate
1723
  @type sort: boolean
1724
  @param sort: whether to provide a sorted output
1725
  @rtype: list
1726
  @return: the list of all files not starting with a dot
1727
  @raise ProgrammerError: if L{path} is not an absolue and normalized path
1728

1729
  """
1730
  if not IsNormAbsPath(path):
1731
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
1732
                                 " absolute/normalized: '%s'" % path)
1733
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1734
  if sort:
1735
    files.sort()
1736
  return files
1737

    
1738

    
1739
def GetHomeDir(user, default=None):
1740
  """Try to get the homedir of the given user.
1741

1742
  The user can be passed either as a string (denoting the name) or as
1743
  an integer (denoting the user id). If the user is not found, the
1744
  'default' argument is returned, which defaults to None.
1745

1746
  """
1747
  try:
1748
    if isinstance(user, basestring):
1749
      result = pwd.getpwnam(user)
1750
    elif isinstance(user, (int, long)):
1751
      result = pwd.getpwuid(user)
1752
    else:
1753
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1754
                                   type(user))
1755
  except KeyError:
1756
    return default
1757
  return result.pw_dir
1758

    
1759

    
1760
def NewUUID():
1761
  """Returns a random UUID.
1762

1763
  @note: This is a Linux-specific method as it uses the /proc
1764
      filesystem.
1765
  @rtype: str
1766

1767
  """
1768
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1769

    
1770

    
1771
def GenerateSecret(numbytes=20):
1772
  """Generates a random secret.
1773

1774
  This will generate a pseudo-random secret returning a hex string
  (so that it can be used where an ASCII string is needed).

  @param numbytes: the number of bytes which will be represented by the
      returned string (defaulting to 20, the length of a SHA1 hash)
  @rtype: str
  @return: a hex representation of the pseudo-random sequence

  """
  return os.urandom(numbytes).encode('hex')
1784

    
1785

    
1786
def EnsureDirs(dirs):
1787
  """Make required directories, if they don't exist.
1788

1789
  @param dirs: list of tuples (dir_name, dir_mode)
1790
  @type dirs: list of (string, integer)
1791

1792
  """
1793
  for dir_name, dir_mode in dirs:
1794
    try:
1795
      os.mkdir(dir_name, dir_mode)
1796
    except EnvironmentError, err:
1797
      if err.errno != errno.EEXIST:
1798
        raise errors.GenericError("Cannot create needed directory"
1799
                                  " '%s': %s" % (dir_name, err))
1800
    try:
1801
      os.chmod(dir_name, dir_mode)
1802
    except EnvironmentError, err:
1803
      raise errors.GenericError("Cannot change directory permissions on"
1804
                                " '%s': %s" % (dir_name, err))
1805
    if not os.path.isdir(dir_name):
1806
      raise errors.GenericError("%s is not a directory" % dir_name)


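# Illustrative EnsureDirs usage sketch (not part of the original module; the
# directory names and modes below are assumptions):
#
#   EnsureDirs([
#     ("/var/run/ganeti", 0755),
#     ("/var/log/ganeti", 0750),
#     ])

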
def ReadFile(file_name, size=-1):
1810
  """Reads a file.
1811

1812
  @type size: int
1813
  @param size: Read at most size bytes (if negative, entire file)
1814
  @rtype: str
1815
  @return: the (possibly partial) content of the file
1816

1817
  """
1818
  f = open(file_name, "r")
1819
  try:
1820
    return f.read(size)
1821
  finally:
1822
    f.close()
1823

    
1824

    
1825
def WriteFile(file_name, fn=None, data=None,
1826
              mode=None, uid=-1, gid=-1,
1827
              atime=None, mtime=None, close=True,
1828
              dry_run=False, backup=False,
1829
              prewrite=None, postwrite=None):
1830
  """(Over)write a file atomically.
1831

1832
  The file_name and either fn (a function taking one argument, the
1833
  file descriptor, and which should write the data to it) or data (the
1834
  contents of the file) must be passed. The other arguments are
1835
  optional and allow setting the file mode, owner and group, and the
1836
  mtime/atime of the file.
1837

1838
  If the function doesn't raise an exception, it has succeeded and the
1839
  target file has the new contents. If the function has raised an
1840
  exception, an existing target file should be unmodified and the
1841
  temporary file should be removed.
1842

1843
  @type file_name: str
1844
  @param file_name: the target filename
1845
  @type fn: callable
1846
  @param fn: content writing function, called with
1847
      file descriptor as parameter
1848
  @type data: str
1849
  @param data: contents of the file
1850
  @type mode: int
1851
  @param mode: file mode
1852
  @type uid: int
1853
  @param uid: the owner of the file
1854
  @type gid: int
1855
  @param gid: the group of the file
1856
  @type atime: int
1857
  @param atime: a custom access time to be set on the file
1858
  @type mtime: int
1859
  @param mtime: a custom modification time to be set on the file
1860
  @type close: boolean
1861
  @param close: whether to close file after writing it
1862
  @type prewrite: callable
1863
  @param prewrite: function to be called before writing content
1864
  @type postwrite: callable
1865
  @param postwrite: function to be called after writing content
1866

1867
  @rtype: None or int
1868
  @return: None if the 'close' parameter evaluates to True,
1869
      otherwise the file descriptor
1870

1871
  @raise errors.ProgrammerError: if any of the arguments are not valid
1872

1873
  """
1874
  if not os.path.isabs(file_name):
1875
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1876
                                 " absolute: '%s'" % file_name)
1877

    
1878
  if [fn, data].count(None) != 1:
1879
    raise errors.ProgrammerError("fn or data required")
1880

    
1881
  if [atime, mtime].count(None) == 1:
1882
    raise errors.ProgrammerError("Both atime and mtime must be either"
1883
                                 " set or None")
1884

    
1885
  if backup and not dry_run and os.path.isfile(file_name):
1886
    CreateBackup(file_name)
1887

    
1888
  dir_name, base_name = os.path.split(file_name)
1889
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1890
  do_remove = True
1891
  # here we need to make sure we remove the temp file, if any error
1892
  # leaves it in place
1893
  try:
1894
    if uid != -1 or gid != -1:
1895
      os.chown(new_name, uid, gid)
1896
    if mode:
1897
      os.chmod(new_name, mode)
1898
    if callable(prewrite):
1899
      prewrite(fd)
1900
    if data is not None:
1901
      os.write(fd, data)
1902
    else:
1903
      fn(fd)
1904
    if callable(postwrite):
1905
      postwrite(fd)
1906
    os.fsync(fd)
1907
    if atime is not None and mtime is not None:
1908
      os.utime(new_name, (atime, mtime))
1909
    if not dry_run:
1910
      os.rename(new_name, file_name)
1911
      do_remove = False
1912
  finally:
1913
    if close:
1914
      os.close(fd)
1915
      result = None
1916
    else:
1917
      result = fd
1918
    if do_remove:
1919
      RemoveFile(new_name)
1920

    
1921
  return result


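# Illustrative WriteFile usage sketches (not part of the original module; the
# file names and contents below are assumptions):
#
#   # Write a fixed string atomically, setting the mode and keeping a backup:
#   WriteFile("/tmp/example.conf", data="key = value\n", mode=0644,
#             backup=True)
#
#   # Alternatively, pass a callable that writes to the file descriptor:
#   WriteFile("/tmp/example.dump",
#             fn=lambda fd: os.write(fd, "generated contents\n"))

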
def ReadOneLineFile(file_name, strict=False):
1925
  """Return the first non-empty line from a file.
1926

1927
  @type strict: boolean
1928
  @param strict: if True, abort if the file has more than one
1929
      non-empty line
1930

1931
  """
1932
  file_lines = ReadFile(file_name).splitlines()
1933
  full_lines = filter(bool, file_lines)
1934
  if not file_lines or not full_lines:
1935
    raise errors.GenericError("No data in one-liner file %s" % file_name)
1936
  elif strict and len(full_lines) > 1:
1937
    raise errors.GenericError("Too many lines in one-liner file %s" %
1938
                              file_name)
1939
  return full_lines[0]
1940

    
1941

    
1942
def FirstFree(seq, base=0):
1943
  """Returns the first non-existing integer from seq.
1944

1945
  The seq argument should be a sorted list of positive integers. The
1946
  first time the index of an element is smaller than the element
1947
  value, the index will be returned.
1948

1949
  The base argument is used to start at a different offset,
1950
  i.e. C{[3, 4, 6]} with I{base=3} will return 5.
1951

1952
  Example: C{[0, 1, 3]} will return I{2}.
1953

1954
  @type seq: sequence
1955
  @param seq: the sequence to be analyzed.
1956
  @type base: int
1957
  @param base: use this value as the base index of the sequence
1958
  @rtype: int
1959
  @return: the first non-used index in the sequence
1960

1961
  """
1962
  for idx, elem in enumerate(seq):
1963
    assert elem >= base, "Passed element is lower than base offset"
1964
    if elem > idx + base:
1965
      # idx is not used
1966
      return idx + base
1967
  return None
1968

    
1969

    
1970
def SingleWaitForFdCondition(fdobj, event, timeout):
1971
  """Waits for a condition to occur on the socket.
1972

1973
  Immediately returns at the first interruption.
1974

1975
  @type fdobj: integer or object supporting a fileno() method
1976
  @param fdobj: entity to wait for events on
1977
  @type event: integer
1978
  @param event: ORed condition (see select module)
1979
  @type timeout: float or None
1980
  @param timeout: Timeout in seconds
1981
  @rtype: int or None
1982
  @return: None for timeout, otherwise the conditions that occurred
1983

1984
  """
1985
  check = (event | select.POLLPRI |
1986
           select.POLLNVAL | select.POLLHUP | select.POLLERR)
1987

    
1988
  if timeout is not None:
1989
    # Poller object expects milliseconds
1990
    timeout *= 1000
1991

    
1992
  poller = select.poll()
1993
  poller.register(fdobj, event)
1994
  try:
1995
    # TODO: If the main thread receives a signal and we have no timeout, we
1996
    # could wait forever. This should check a global "quit" flag or something
1997
    # every so often.
1998
    io_events = poller.poll(timeout)
1999
  except select.error, err:
2000
    if err[0] != errno.EINTR:
2001
      raise
2002
    io_events = []
2003
  if io_events and io_events[0][1] & check:
2004
    return io_events[0][1]
2005
  else:
2006
    return None
2007

    
2008

    
2009
class FdConditionWaiterHelper(object):
2010
  """Retry helper for WaitForFdCondition.
2011

2012
  This class contains the retried and wait functions that make sure
2013
  WaitForFdCondition can continue waiting until the timeout is actually
2014
  expired.
2015

2016
  """
2017

    
2018
  def __init__(self, timeout):
2019
    self.timeout = timeout
2020

    
2021
  def Poll(self, fdobj, event):
2022
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
2023
    if result is None:
2024
      raise RetryAgain()
2025
    else:
2026
      return result
2027

    
2028
  def UpdateTimeout(self, timeout):
2029
    self.timeout = timeout
2030

    
2031

    
2032
def WaitForFdCondition(fdobj, event, timeout):
2033
  """Waits for a condition to occur on the socket.
2034

2035
  Retries until the timeout is expired, even if interrupted.
2036

2037
  @type fdobj: integer or object supporting a fileno() method
2038
  @param fdobj: entity to wait for events on
2039
  @type event: integer
2040
  @param event: ORed condition (see select module)
2041
  @type timeout: float or None
2042
  @param timeout: Timeout in seconds
2043
  @rtype: int or None
2044
  @return: None for timeout, otherwise the conditions that occurred
2045

2046
  """
2047
  if timeout is not None:
2048
    retrywaiter = FdConditionWaiterHelper(timeout)
2049
    try:
2050
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
2051
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
2052
    except RetryTimeout:
2053
      result = None
2054
  else:
2055
    result = None
2056
    while result is None:
2057
      result = SingleWaitForFdCondition(fdobj, event, timeout)
2058
  return result
2059

    
2060

    
2061
def UniqueSequence(seq):
2062
  """Returns a list with unique elements.
2063

2064
  Element order is preserved.
2065

2066
  @type seq: sequence
2067
  @param seq: the sequence with the source elements
2068
  @rtype: list
2069
  @return: list of unique elements from seq
2070

2071
  """
2072
  seen = set()
2073
  return [i for i in seq if i not in seen and not seen.add(i)]
2074

    
2075

    
2076
def NormalizeAndValidateMac(mac):
  """Normalizes and checks if a MAC address is valid.

  Checks whether the supplied MAC address is formally correct; only
  the colon-separated format is accepted. The address is normalized
  to all-lowercase.

  @type mac: str
  @param mac: the MAC to be validated
  @rtype: str
  @return: returns the normalized and validated MAC.

  @raise errors.OpPrereqError: If the MAC isn't valid

  """
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
  if not mac_check.match(mac):
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
                               mac, errors.ECODE_INVAL)

  return mac.lower()


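# Illustrative NormalizeAndValidateMac usage sketch (not part of the original
# module; the MAC literals below are assumptions):
#
#   NormalizeAndValidateMac("AA:00:11:22:33:44")  # -> "aa:00:11:22:33:44"
#   NormalizeAndValidateMac("aa-00-11-22-33-44")  # raises OpPrereqError

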
def TestDelay(duration):
2099
  """Sleep for a fixed amount of time.
2100

2101
  @type duration: float
2102
  @param duration: the sleep duration
2103
  @rtype: (boolean, str)
  @return: (False, error message) for negative value, (True, None) otherwise
2105

2106
  """
2107
  if duration < 0:
2108
    return False, "Invalid sleep duration"
2109
  time.sleep(duration)
2110
  return True, None
2111

    
2112

    
2113
def _CloseFDNoErr(fd, retries=5):
2114
  """Close a file descriptor ignoring errors.
2115

2116
  @type fd: int
2117
  @param fd: the file descriptor
2118
  @type retries: int
2119
  @param retries: how many retries to make, in case we get any
2120
      other error than EBADF
2121

2122
  """
2123
  try:
2124
    os.close(fd)
2125
  except OSError, err:
2126
    if err.errno != errno.EBADF:
2127
      if retries > 0:
2128
        _CloseFDNoErr(fd, retries - 1)
2129
    # else either it's closed already or we're out of retries, so we
2130
    # ignore this and go on
2131

    
2132

    
2133
def CloseFDs(noclose_fds=None):
2134
  """Close file descriptors.
2135

2136
  This closes all file descriptors above 2 (i.e. except
2137
  stdin/out/err).
2138

2139
  @type noclose_fds: list or None
2140
  @param noclose_fds: if given, it denotes a list of file descriptor
2141
      that should not be closed
2142

2143
  """
2144
  # Default maximum for the number of available file descriptors.
2145
  if 'SC_OPEN_MAX' in os.sysconf_names:
2146
    try:
2147
      MAXFD = os.sysconf('SC_OPEN_MAX')
2148
      if MAXFD < 0:
2149
        MAXFD = 1024
2150
    except OSError:
2151
      MAXFD = 1024
2152
  else:
2153
    MAXFD = 1024
2154
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
2155
  if (maxfd == resource.RLIM_INFINITY):
2156
    maxfd = MAXFD
2157

    
2158
  # Iterate through and close all file descriptors (except the standard ones)
2159
  for fd in range(3, maxfd):
2160
    if noclose_fds and fd in noclose_fds:
2161
      continue
2162
    _CloseFDNoErr(fd)
2163

    
2164

    
2165
def Mlockall():
2166
  """Lock current process' virtual address space into RAM.
2167

2168
  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
2169
  see mlock(2) for more details. This function requires ctypes module.
2170

2171
  """
2172
  if ctypes is None:
2173
    logging.warning("Cannot set memory lock, ctypes module not found")
2174
    return
2175

    
2176
  libc = ctypes.cdll.LoadLibrary("libc.so.6")
2177
  if libc is None:
2178
    logging.error("Cannot set memory lock, ctypes cannot load libc")
2179
    return
2180

    
2181
  # Some older version of the ctypes module don't have built-in functionality
2182
  # to access the errno global variable, where function error codes are stored.
2183
  # By declaring this variable as a pointer to an integer we can then access
2184
  # its value correctly, should the mlockall call fail, in order to see what
2185
  # the actual error code was.
2186
  # pylint: disable-msg=W0212
2187
  libc.__errno_location.restype = ctypes.POINTER(ctypes.c_int)
2188

    
2189
  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
2190
    # pylint: disable-msg=W0212
2191
    logging.error("Cannot set memory lock: %s",
2192
                  os.strerror(libc.__errno_location().contents.value))
2193
    return
2194

    
2195
  logging.debug("Memory lock set")
2196

    
2197

    
2198
def Daemonize(logfile, run_uid, run_gid):
2199
  """Daemonize the current process.
2200

2201
  This detaches the current process from the controlling terminal and
2202
  runs it in the background as a daemon.
2203

2204
  @type logfile: str
2205
  @param logfile: the logfile to which we should redirect stdout/stderr
2206
  @type run_uid: int
2207
  @param run_uid: Run the child under this uid
2208
  @type run_gid: int
2209
  @param run_gid: Run the child under this gid
2210
  @rtype: int
2211
  @return: the value zero
2212

2213
  """
2214
  # pylint: disable-msg=W0212
2215
  # yes, we really want os._exit
2216
  UMASK = 077
2217
  WORKDIR = "/"
2218

    
2219
  # this might fail
2220
  pid = os.fork()
2221
  if (pid == 0):  # The first child.
2222
    os.setsid()
2223
    # FIXME: When removing again and moving to start-stop-daemon privilege drop
2224
    #        make sure to check for config permission and bail out when invoked
2225
    #        with wrong user.
2226
    os.setgid(run_gid)
2227
    os.setuid(run_uid)
2228
    # this might fail
2229
    pid = os.fork() # Fork a second child.
2230
    if (pid == 0):  # The second child.
2231
      os.chdir(WORKDIR)
2232
      os.umask(UMASK)
2233
    else:
2234
      # exit() or _exit()?  See below.
2235
      os._exit(0) # Exit parent (the first child) of the second child.
2236
  else:
2237
    os._exit(0) # Exit parent of the first child.
2238

    
2239
  for fd in range(3):
2240
    _CloseFDNoErr(fd)
2241
  i = os.open("/dev/null", os.O_RDONLY) # stdin
2242
  assert i == 0, "Can't close/reopen stdin"
2243
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
2244
  assert i == 1, "Can't close/reopen stdout"
2245
  # Duplicate standard output to standard error.
2246
  os.dup2(1, 2)
2247
  return 0
2248

    
2249

    
2250
def DaemonPidFileName(name):
2251
  """Compute a ganeti pid file absolute path
2252

2253
  @type name: str
2254
  @param name: the daemon name
2255
  @rtype: str
2256
  @return: the full path to the pidfile corresponding to the given
2257
      daemon name
2258

2259
  """
2260
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
2261

    
2262

    
2263
def EnsureDaemon(name):
2264
  """Check for and start daemon if not alive.
2265

2266
  """
2267
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
2268
  if result.failed:
2269
    logging.error("Can't start daemon '%s', failure %s, output: %s",
2270
                  name, result.fail_reason, result.output)
2271
    return False
2272

    
2273
  return True
2274

    
2275

    
2276
def StopDaemon(name):
2277
  """Stop daemon
2278

2279
  """
2280
  result = RunCmd([constants.DAEMON_UTIL, "stop", name])
2281
  if result.failed:
2282
    logging.error("Can't stop daemon '%s', failure %s, output: %s",
2283
                  name, result.fail_reason, result.output)
2284
    return False
2285

    
2286
  return True
2287

    
2288

    
2289
def WritePidFile(name):
2290
  """Write the current process pidfile.
2291

2292
  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}
2293

2294
  @type name: str
2295
  @param name: the daemon name to use
2296
  @raise errors.GenericError: if the pid file already exists and
2297
      points to a live process
2298

2299
  """
2300
  pid = os.getpid()
2301
  pidfilename = DaemonPidFileName(name)
2302
  if IsProcessAlive(ReadPidFile(pidfilename)):
2303
    raise errors.GenericError("%s contains a live process" % pidfilename)
2304

    
2305
  WriteFile(pidfilename, data="%d\n" % pid)
2306

    
2307

    
2308
def RemovePidFile(name):
2309
  """Remove the current process pidfile.
2310

2311
  Any errors are ignored.
2312

2313
  @type name: str
2314
  @param name: the daemon name used to derive the pidfile name
2315

2316
  """
2317
  pidfilename = DaemonPidFileName(name)
2318
  # TODO: we could check here that the file contains our pid
2319
  try:
2320
    RemoveFile(pidfilename)
2321
  except: # pylint: disable-msg=W0702
2322
    pass
2323

    
2324

    
2325
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
2326
                waitpid=False):
2327
  """Kill a process given by its pid.
2328

2329
  @type pid: int
2330
  @param pid: The PID to terminate.
2331
  @type signal_: int
2332
  @param signal_: The signal to send, by default SIGTERM
2333
  @type timeout: int
2334
  @param timeout: The timeout after which, if the process is still alive,
2335
                  a SIGKILL will be sent. If not positive, no such checking
2336
                  will be done
2337
  @type waitpid: boolean
2338
  @param waitpid: If true, we should waitpid on this process after
2339
      sending signals, since it's our own child and otherwise it
2340
      would remain as zombie
2341

2342
  """
2343
  def _helper(pid, signal_, wait):
2344
    """Simple helper to encapsulate the kill/waitpid sequence"""
2345
    if IgnoreProcessNotFound(os.kill, pid, signal_) and wait:
2346
      try:
2347
        os.waitpid(pid, os.WNOHANG)
2348
      except OSError:
2349
        pass
2350

    
2351
  if pid <= 0:
2352
    # kill with pid=0 == suicide
2353
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
2354

    
2355
  if not IsProcessAlive(pid):
2356
    return
2357

    
2358
  _helper(pid, signal_, waitpid)
2359

    
2360
  if timeout <= 0:
2361
    return
2362

    
2363
  def _CheckProcess():
2364
    if not IsProcessAlive(pid):
2365
      return
2366

    
2367
    try:
2368
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
2369
    except OSError:
2370
      raise RetryAgain()
2371

    
2372
    if result_pid > 0:
2373
      return
2374

    
2375
    raise RetryAgain()
2376

    
2377
  try:
2378
    # Wait up to $timeout seconds
2379
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
2380
  except RetryTimeout:
2381
    pass
2382

    
2383
  if IsProcessAlive(pid):
2384
    # Kill process if it's still alive
2385
    _helper(pid, signal.SIGKILL, waitpid)


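# Illustrative KillProcess usage sketch (not part of the original module; the
# child_pid variable below is an assumption):
#
#   # Ask a child process to terminate, escalate to SIGKILL after 10 seconds
#   # and reap it so it does not remain as a zombie:
#   KillProcess(child_pid, signal_=signal.SIGTERM, timeout=10, waitpid=True)

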
def FindFile(name, search_path, test=os.path.exists):
2389
  """Look for a filesystem object in a given path.
2390

2391
  This is an abstract method to search for filesystem object (files,
2392
  dirs) under a given search path.
2393

2394
  @type name: str
2395
  @param name: the name to look for
2396
  @type search_path: str
2397
  @param search_path: location to start at
2398
  @type test: callable
2399
  @param test: a function taking one argument that should return True
2400
      if a given object is valid; the default value is
2401
      os.path.exists, causing only existing files to be returned
2402
  @rtype: str or None
2403
  @return: full path to the object if found, None otherwise
2404

2405
  """
2406
  # validate the filename mask
2407
  if constants.EXT_PLUGIN_MASK.match(name) is None:
2408
    logging.critical("Invalid value passed for external script name: '%s'",
2409
                     name)
2410
    return None
2411

    
2412
  for dir_name in search_path:
2413
    # FIXME: investigate switch to PathJoin
2414
    item_name = os.path.sep.join([dir_name, name])
2415
    # check the user test and that we're indeed resolving to the given
2416
    # basename
2417
    if test(item_name) and os.path.basename(item_name) == name:
2418
      return item_name
2419
  return None
2420

    
2421

    
2422
def CheckVolumeGroupSize(vglist, vgname, minsize):
2423
  """Checks if the volume group list is valid.
2424

2425
  The function will check if a given volume group is in the list of
2426
  volume groups and has a minimum size.
2427

2428
  @type vglist: dict
2429
  @param vglist: dictionary of volume group names and their size
2430
  @type vgname: str
2431
  @param vgname: the volume group we should check
2432
  @type minsize: int
2433
  @param minsize: the minimum size we accept
2434
  @rtype: None or str
2435
  @return: None for success, otherwise the error message
2436

2437
  """
2438
  vgsize = vglist.get(vgname, None)
2439
  if vgsize is None:
2440
    return "volume group '%s' missing" % vgname
2441
  elif vgsize < minsize:
2442
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
2443
            (vgname, minsize, vgsize))
2444
  return None
2445

    
2446

    
2447
def SplitTime(value):
2448
  """Splits time as floating point number into a tuple.
2449

2450
  @param value: Time in seconds
2451
  @type value: int or float
2452
  @return: Tuple containing (seconds, microseconds)
2453

2454
  """
2455
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
2456

    
2457
  assert 0 <= seconds, \
2458
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2459
  assert 0 <= microseconds <= 999999, \
2460
    "Microseconds must be 0-999999, but are %s" % microseconds
2461

    
2462
  return (int(seconds), int(microseconds))
2463

    
2464

    
2465
def MergeTime(timetuple):
2466
  """Merges a tuple into time as a floating point number.
2467

2468
  @param timetuple: Time as tuple, (seconds, microseconds)
2469
  @type timetuple: tuple
2470
  @return: Time as a floating point number expressed in seconds
2471

2472
  """
2473
  (seconds, microseconds) = timetuple
2474

    
2475
  assert 0 <= seconds, \
2476
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2477
  assert 0 <= microseconds <= 999999, \
2478
    "Microseconds must be 0-999999, but are %s" % microseconds
2479

    
2480
  return float(seconds) + (float(microseconds) * 0.000001)


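# Illustrative SplitTime/MergeTime round-trip sketch (not part of the original
# module; the timestamp below is an assumption):
#
#   SplitTime(1234.5)          # -> (1234, 500000)
#   MergeTime((1234, 500000))  # -> 1234.5

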
def GetDaemonPort(daemon_name):
2484
  """Get the daemon port for this cluster.
2485

2486
  Note that this routine does not read a ganeti-specific file, but
2487
  instead uses C{socket.getservbyname} to allow pre-customization of
2488
  this parameter outside of Ganeti.
2489

2490
  @type daemon_name: string
2491
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
2492
  @rtype: int
2493

2494
  """
2495
  if daemon_name not in constants.DAEMONS_PORTS:
2496
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)
2497

    
2498
  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
2499
  try:
2500
    port = socket.getservbyname(daemon_name, proto)
2501
  except socket.error:
2502
    port = default_port
2503

    
2504
  return port
2505

    
2506

    
2507
class LogFileHandler(logging.FileHandler):
2508
  """Log handler that doesn't fallback to stderr.
2509

2510
  When an error occurs while writing on the logfile, logging.FileHandler tries
2511
  to log on stderr. This doesn't work in ganeti since stderr is redirected to
2512
  the logfile. This class avoids that by reporting errors to /dev/console.
2513

2514
  """
2515
  def __init__(self, filename, mode="a", encoding=None):
2516
    """Open the specified file and use it as the stream for logging.
2517

2518
    Also open /dev/console to report errors while logging.
2519

2520
    """
2521
    logging.FileHandler.__init__(self, filename, mode, encoding)
2522
    self.console = open(constants.DEV_CONSOLE, "a")
2523

    
2524
  def handleError(self, record): # pylint: disable-msg=C0103
2525
    """Handle errors which occur during an emit() call.
2526

2527
    Try to handle errors with FileHandler method, if it fails write to
2528
    /dev/console.
2529

2530
    """
2531
    try:
2532
      logging.FileHandler.handleError(self, record)
2533
    except Exception: # pylint: disable-msg=W0703
2534
      try:
2535
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
2536
      except Exception: # pylint: disable-msg=W0703
2537
        # Log handler tried everything it could, now just give up
2538
        pass
2539

    
2540

    
2541
def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
2542
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
2543
                 console_logging=False):
2544
  """Configures the logging module.
2545

2546
  @type logfile: str
2547
  @param logfile: the filename to which we should log
2548
  @type debug: integer
2549
  @param debug: if greater than zero, enable debug messages, otherwise
2550
      only those at C{INFO} and above level
2551
  @type stderr_logging: boolean
2552
  @param stderr_logging: whether we should also log to the standard error
2553
  @type program: str
2554
  @param program: the name under which we should log messages
2555
  @type multithreaded: boolean
2556
  @param multithreaded: if True, will add the thread name to the log file
2557
  @type syslog: string
2558
  @param syslog: one of 'no', 'yes', 'only':
2559
      - if no, syslog is not used
2560
      - if yes, syslog is used (in addition to file-logging)
2561
      - if only, only syslog is used
2562
  @type console_logging: boolean
2563
  @param console_logging: if True, will use a FileHandler which falls back to
2564
      the system console if logging fails
2565
  @raise EnvironmentError: if we can't open the log file and
2566
      syslog/stderr logging is disabled
2567

2568
  """
2569
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
2570
  sft = program + "[%(process)d]:"
2571
  if multithreaded:
2572
    fmt += "/%(threadName)s"
2573
    sft += " (%(threadName)s)"
2574
  if debug:
2575
    fmt += " %(module)s:%(lineno)s"
2576
    # no debug info for syslog loggers
2577
  fmt += " %(levelname)s %(message)s"
2578
  # yes, we do want the textual level, as remote syslog will probably
2579
  # lose the error level, and it's easier to grep for it
2580
  sft += " %(levelname)s %(message)s"
2581
  formatter = logging.Formatter(fmt)
2582
  sys_fmt = logging.Formatter(sft)
2583

    
2584
  root_logger = logging.getLogger("")
2585
  root_logger.setLevel(logging.NOTSET)
2586

    
2587
  # Remove all previously setup handlers
2588
  for handler in root_logger.handlers:
2589
    handler.close()
2590
    root_logger.removeHandler(handler)
2591

    
2592
  if stderr_logging:
2593
    stderr_handler = logging.StreamHandler()
2594
    stderr_handler.setFormatter(formatter)
2595
    if debug:
2596
      stderr_handler.setLevel(logging.NOTSET)
2597
    else:
2598
      stderr_handler.setLevel(logging.CRITICAL)
2599
    root_logger.addHandler(stderr_handler)
2600

    
2601
  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
2602
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
2603
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
2604
                                                    facility)
2605
    syslog_handler.setFormatter(sys_fmt)
2606
    # Never enable debug over syslog
2607
    syslog_handler.setLevel(logging.INFO)
2608
    root_logger.addHandler(syslog_handler)
2609

    
2610
  if syslog != constants.SYSLOG_ONLY:
2611
    # this can fail, if the logging directories are not setup or we have
2612
    # a permission problem; in this case, it's best to log but ignore
2613
    # the error if stderr_logging is True, and if false we re-raise the
2614
    # exception since otherwise we could run but without any logs at all
2615
    try:
2616
      if console_logging:
2617
        logfile_handler = LogFileHandler(logfile)
2618
      else:
2619
        logfile_handler = logging.FileHandler(logfile)
2620
      logfile_handler.setFormatter(formatter)
2621
      if debug:
2622
        logfile_handler.setLevel(logging.DEBUG)
2623
      else:
2624
        logfile_handler.setLevel(logging.INFO)
2625
      root_logger.addHandler(logfile_handler)
2626
    except EnvironmentError:
2627
      if stderr_logging or syslog == constants.SYSLOG_YES:
2628
        logging.exception("Failed to enable logging to file '%s'", logfile)
2629
      else:
2630
        # we need to re-raise the exception
2631
        raise
2632

    
2633

    
2634
def IsNormAbsPath(path):
2635
  """Check whether a path is absolute and also normalized
2636

2637
  This avoids things like /dir/../../other/path to be valid.
2638

2639
  """
2640
  return os.path.normpath(path) == path and os.path.isabs(path)
2641

    
2642

    
2643
def PathJoin(*args):
2644
  """Safe-join a list of path components.
2645

2646
  Requirements:
2647
      - the first argument must be an absolute path
2648
      - no component in the path must have backtracking (e.g. /../),
2649
        since we check for normalization at the end
2650

2651
  @param args: the path components to be joined
2652
  @raise ValueError: for invalid paths
2653

2654
  """
2655
  # ensure we're having at least one path passed in
2656
  assert args
2657
  # ensure the first component is an absolute and normalized path name
2658
  root = args[0]
2659
  if not IsNormAbsPath(root):
2660
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
2661
  result = os.path.join(*args)
2662
  # ensure that the whole path is normalized
2663
  if not IsNormAbsPath(result):
2664
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
2665
  # check that we're still under the original prefix
2666
  prefix = os.path.commonprefix([root, result])
2667
  if prefix != root:
2668
    raise ValueError("Error: path joining resulted in different prefix"
2669
                     " (%s != %s)" % (prefix, root))
2670
  return result


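# Illustrative PathJoin usage sketch (not part of the original module; the
# path components below are assumptions):
#
#   PathJoin("/var/lib/ganeti", "config.data")
#       # -> "/var/lib/ganeti/config.data"
#   PathJoin("/var/lib/ganeti", "../etc")  # raises ValueError (backtracking)
#   PathJoin("relative/path", "file")      # raises ValueError (not absolute)

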
def TailFile(fname, lines=20):
2674
  """Return the last lines from a file.
2675

2676
  @note: this function will only read and parse the last 4KB of
2677
      the file; if the lines are very long, it could be that less
2678
      than the requested number of lines are returned
2679

2680
  @param fname: the file name
2681
  @type lines: int
2682
  @param lines: the (maximum) number of lines to return
2683

2684
  """
2685
  fd = open(fname, "r")
2686
  try:
2687
    fd.seek(0, 2)
2688
    pos = fd.tell()
2689
    pos = max(0, pos-4096)
2690
    fd.seek(pos, 0)
2691
    raw_data = fd.read()
2692
  finally:
2693
    fd.close()
2694

    
2695
  rows = raw_data.splitlines()
2696
  return rows[-lines:]
2697

    
2698

    
2699
def FormatTimestampWithTZ(secs):
2700
  """Formats a Unix timestamp with the local timezone.
2701

2702
  """
2703
  return time.strftime("%F %T %Z", time.localtime(secs))
2704

    
2705

    
2706
def _ParseAsn1Generalizedtime(value):
2707
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
2708

2709
  @type value: string
2710
  @param value: ASN1 GENERALIZEDTIME timestamp
2711

2712
  """
2713
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
2714
  if m:
2715
    # We have an offset
2716
    asn1time = m.group(1)
2717
    hours = int(m.group(2))
2718
    minutes = int(m.group(3))
2719
    utcoffset = (60 * hours) + minutes
2720
  else:
2721
    if not value.endswith("Z"):
2722
      raise ValueError("Missing timezone")
2723
    asn1time = value[:-1]
2724
    utcoffset = 0
2725

    
2726
  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
2727

    
2728
  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)
2729

    
2730
  return calendar.timegm(tt.utctimetuple())
2731

    
2732

    
2733
def GetX509CertValidity(cert):
2734
  """Returns the validity period of the certificate.
2735

2736
  @type cert: OpenSSL.crypto.X509
2737
  @param cert: X509 certificate object
2738

2739
  """
2740
  # The get_notBefore and get_notAfter functions are only supported in
2741
  # pyOpenSSL 0.7 and above.
2742
  try:
2743
    get_notbefore_fn = cert.get_notBefore
2744
  except AttributeError:
2745
    not_before = None
2746
  else:
2747
    not_before_asn1 = get_notbefore_fn()
2748

    
2749
    if not_before_asn1 is None:
2750
      not_before = None
2751
    else:
2752
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)
2753

    
2754
  try:
2755
    get_notafter_fn = cert.get_notAfter
2756
  except AttributeError:
2757
    not_after = None
2758
  else:
2759
    not_after_asn1 = get_notafter_fn()
2760

    
2761
    if not_after_asn1 is None:
2762
      not_after = None
2763
    else:
2764
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)
2765

    
2766
  return (not_before, not_after)
2767

    
2768

    
2769
def _VerifyCertificateInner(expired, not_before, not_after, now,
2770
                            warn_days, error_days):
2771
  """Verifies certificate validity.
2772

2773
  @type expired: bool
2774
  @param expired: Whether pyOpenSSL considers the certificate as expired
2775
  @type not_before: number or None
2776
  @param not_before: Unix timestamp before which certificate is not valid
2777
  @type not_after: number or None
2778
  @param not_after: Unix timestamp after which certificate is invalid
2779
  @type now: number
2780
  @param now: Current time as Unix timestamp
2781
  @type warn_days: number or None
2782
  @param warn_days: How many days before expiration a warning should be reported
2783
  @type error_days: number or None
2784
  @param error_days: How many days before expiration an error should be reported
2785

2786
  """
2787
  if expired:
2788
    msg = "Certificate is expired"
2789

    
2790
    if not_before is not None and not_after is not None:
2791
      msg += (" (valid from %s to %s)" %
2792
              (FormatTimestampWithTZ(not_before),
2793
               FormatTimestampWithTZ(not_after)))
2794
    elif not_before is not None:
2795
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
2796
    elif not_after is not None:
2797
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)
2798

    
2799
    return (CERT_ERROR, msg)
2800

    
2801
  elif not_before is not None and not_before > now:
2802
    return (CERT_WARNING,
2803
            "Certificate not yet valid (valid from %s)" %
2804
            FormatTimestampWithTZ(not_before))
2805

    
2806
  elif not_after is not None:
2807
    remaining_days = int((not_after - now) / (24 * 3600))
2808

    
2809
    msg = "Certificate expires in about %d days" % remaining_days
2810

    
2811
    if error_days is not None and remaining_days <= error_days:
2812
      return (CERT_ERROR, msg)
2813

    
2814
    if warn_days is not None and remaining_days <= warn_days:
2815
      return (CERT_WARNING, msg)
2816

    
2817
  return (None, None)
2818

    
2819

    
2820
def VerifyX509Certificate(cert, warn_days, error_days):
2821
  """Verifies a certificate for LUVerifyCluster.
2822

2823
  @type cert: OpenSSL.crypto.X509
2824
  @param cert: X509 certificate object
2825
  @type warn_days: number or None
2826
  @param warn_days: How many days before expiration a warning should be reported
2827
  @type error_days: number or None
2828
  @param error_days: How many days before expiration an error should be reported
2829

2830
  """
2831
  # Depending on the pyOpenSSL version, this can just return (None, None)
2832
  (not_before, not_after) = GetX509CertValidity(cert)
2833

    
2834
  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
2835
                                 time.time(), warn_days, error_days)
2836

    
2837

    
2838
def SignX509Certificate(cert, key, salt):
2839
  """Sign a X509 certificate.
2840

2841
  An RFC822-like signature header is added in front of the certificate.
2842

2843
  @type cert: OpenSSL.crypto.X509
2844
  @param cert: X509 certificate object
2845
  @type key: string
2846
  @param key: Key for HMAC
2847
  @type salt: string
2848
  @param salt: Salt for HMAC
2849
  @rtype: string
2850
  @return: Serialized and signed certificate in PEM format
2851

2852
  """
2853
  if not VALID_X509_SIGNATURE_SALT.match(salt):
2854
    raise errors.GenericError("Invalid salt: %r" % salt)
2855

    
2856
  # Dumping as PEM here ensures the certificate is in a sane format
2857
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2858

    
2859
  return ("%s: %s/%s\n\n%s" %
2860
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
2861
           Sha1Hmac(key, cert_pem, salt=salt),
2862
           cert_pem))
2863

    
2864

    
2865
def _ExtractX509CertificateSignature(cert_pem):
2866
  """Helper function to extract signature from X509 certificate.
2867

2868
  """
2869
  # Extract signature from original PEM data
2870
  for line in cert_pem.splitlines():
2871
    if line.startswith("---"):
2872
      break
2873

    
2874
    m = X509_SIGNATURE.match(line.strip())
2875
    if m:
2876
      return (m.group("salt"), m.group("sign"))
2877

    
2878
  raise errors.GenericError("X509 certificate signature is missing")
2879

    
2880

    
2881
def LoadSignedX509Certificate(cert_pem, key):
2882
  """Verifies a signed X509 certificate.
2883

2884
  @type cert_pem: string
2885
  @param cert_pem: Certificate in PEM format and with signature header
2886
  @type key: string
2887
  @param key: Key for HMAC
2888
  @rtype: tuple; (OpenSSL.crypto.X509, string)
2889
  @return: X509 certificate object and salt
2890

2891
  """
2892
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)
2893

    
2894
  # Load certificate
2895
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
2896

    
2897
  # Dump again to ensure it's in a sane format
2898
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2899

    
2900
  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
2901
    raise errors.GenericError("X509 certificate signature is invalid")
2902

    
2903
  return (cert, salt)


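# Illustrative X509 signing round-trip sketch (not part of the original
# module; the common name, validity and HMAC key below are assumptions):
#
#   (_, cert_pem) = GenerateSelfSignedX509Cert("node1.example.com", 3600)
#   cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
#                                          cert_pem)
#   signed_pem = SignX509Certificate(cert, "hmac-key", GenerateSecret(8))
#   (cert, salt) = LoadSignedX509Certificate(signed_pem, "hmac-key")

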
def Sha1Hmac(key, text, salt=None):
2907
  """Calculates the HMAC-SHA1 digest of a text.
2908

2909
  HMAC is defined in RFC2104.
2910

2911
  @type key: string
2912
  @param key: Secret key
2913
  @type text: string
2914

2915
  """
2916
  if salt:
2917
    salted_text = salt + text
2918
  else:
2919
    salted_text = text
2920

    
2921
  return hmac.new(key, salted_text, compat.sha1).hexdigest()
2922

    
2923

    
2924
def VerifySha1Hmac(key, text, digest, salt=None):
2925
  """Verifies the HMAC-SHA1 digest of a text.
2926

2927
  HMAC is defined in RFC2104.
2928

2929
  @type key: string
2930
  @param key: Secret key
2931
  @type text: string
2932
  @type digest: string
2933
  @param digest: Expected digest
2934
  @rtype: bool
2935
  @return: Whether HMAC-SHA1 digest matches
2936

2937
  """
2938
  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()


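# Illustrative Sha1Hmac/VerifySha1Hmac usage sketch (not part of the original
# module; the key, message and salt literals below are assumptions):
#
#   digest = Sha1Hmac("secret-key", "message", salt="abcd")
#   VerifySha1Hmac("secret-key", "message", digest, salt="abcd")   # -> True
#   VerifySha1Hmac("secret-key", "tampered", digest, salt="abcd")  # -> False

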
def SafeEncode(text):
2942
  """Return a 'safe' version of a source string.
2943

2944
  This function mangles the input string and returns a version that
2945
  should be safe to display/encode as ASCII. To this end, we first
2946
  convert it to ASCII using the 'backslashreplace' encoding which
2947
  should get rid of any non-ASCII chars, and then we process it
2948
  through a loop copied from the string repr sources in Python; we
  don't use string_escape anymore since that escapes single quotes and
  backslashes too, which is too much, and that escaping is not
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
2952

2953
  @type text: str or unicode
2954
  @param text: input data
2955
  @rtype: str
2956
  @return: a safe version of text
2957

2958
  """
2959
  if isinstance(text, unicode):
2960
    # only if unicode; if str already, we handle it below
2961
    text = text.encode('ascii', 'backslashreplace')
2962
  resu = ""
2963
  for char in text:
2964
    c = ord(char)
2965
    if char  == '\t':
2966
      resu += r'\t'
2967
    elif char == '\n':
2968
      resu += r'\n'
2969
    elif char == '\r':
2970
      resu += r'\r'
2971
    elif c < 32 or c >= 127: # non-printable
2972
      resu += "\\x%02x" % (c & 0xff)
2973
    else:
2974
      resu += char
2975
  return resu
2976

    
2977

    
2978
def UnescapeAndSplit(text, sep=","):
  """Split and unescape a string based on a given separator.

  This function splits a string based on a separator where the
  separator itself can be escaped in order to be an element of the
  elements. The escaping rules are (assuming comma being the
  separator):
    - a plain , separates the elements
    - a sequence \\\\, (double backslash plus comma) is handled as a
      backslash plus a separator comma
    - a sequence \, (backslash plus comma) is handled as a
      non-separator comma

  @type text: string
  @param text: the string to split
  @type sep: string
  @param sep: the separator
  @rtype: list
  @return: a list of strings

  """
  # we split the list by sep (with no escaping at this stage)
  slist = text.split(sep)
  # next, we revisit the elements and if any of them ended with an odd
  # number of backslashes, then we join it with the next
  rlist = []
  while slist:
    e1 = slist.pop(0)
    if e1.endswith("\\"):
      num_b = len(e1) - len(e1.rstrip("\\"))
      if num_b % 2 == 1:
        e2 = slist.pop(0)
        # here the backslashes remain (all), and will be reduced in
        # the next step
        rlist.append(e1 + sep + e2)
        continue
    rlist.append(e1)
  # finally, replace backslash-something with something
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
  return rlist


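# Illustrative UnescapeAndSplit usage sketch (not part of the original module;
# the input strings below are assumptions):
#
#   UnescapeAndSplit("a,b,c")     # -> ["a", "b", "c"]
#   UnescapeAndSplit(r"a\,b,c")   # -> ["a,b", "c"]
#   UnescapeAndSplit(r"a\\,b,c")  # -> ["a\\", "b", "c"]

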
def CommaJoin(names):
3021
  """Nicely join a set of identifiers.
3022

3023
  @param names: set, list or tuple
3024
  @return: a string with the formatted results
3025

3026
  """
3027
  return ", ".join([str(val) for val in names])
3028

    
3029

    
3030
def BytesToMebibyte(value):
3031
  """Converts bytes to mebibytes.
3032

3033
  @type value: int
3034
  @param value: Value in bytes
3035
  @rtype: int
3036
  @return: Value in mebibytes
3037

3038
  """
3039
  return int(round(value / (1024.0 * 1024.0), 0))
3040

    
3041

    
3042
def CalculateDirectorySize(path):
3043
  """Calculates the size of a directory recursively.
3044

3045
  @type path: string
3046
  @param path: Path to directory
3047
  @rtype: int
3048
  @return: Size in mebibytes
3049

3050
  """
3051
  size = 0
3052

    
3053
  for (curpath, _, files) in os.walk(path):
3054
    for filename in files:
3055
      st = os.lstat(PathJoin(curpath, filename))
3056
      size += st.st_size
3057

    
3058
  return BytesToMebibyte(size)
3059

    
3060

    
3061
def GetFilesystemStats(path):
3062
  """Returns the total and free space on a filesystem.
3063

3064
  @type path: string
3065
  @param path: Path on filesystem to be examined
3066
  @rtype: int
3067
  @return: tuple of (Total space, Free space) in mebibytes
3068

3069
  """
3070
  st = os.statvfs(path)
3071

    
3072
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
3073
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
3074
  return (tsize, fsize)
3075

    
3076

    
3077
def RunInSeparateProcess(fn, *args):
3078
  """Runs a function in a separate process.
3079

3080
  Note: Only boolean return values are supported.
3081

3082
  @type fn: callable
3083
  @param fn: Function to be called
3084
  @rtype: bool
3085
  @return: Function's result
3086

3087
  """
3088
  pid = os.fork()
3089
  if pid == 0:
3090
    # Child process
3091
    try:
3092
      # In case the function uses temporary files
3093
      ResetTempfileModule()
3094

    
3095
      # Call function
3096
      result = int(bool(fn(*args)))
3097
      assert result in (0, 1)
3098
    except: # pylint: disable-msg=W0702
3099
      logging.exception("Error while calling function in separate process")
3100
      # 0 and 1 are reserved for the return value
3101
      result = 33
3102

    
3103
    os._exit(result) # pylint: disable-msg=W0212
3104

    
3105
  # Parent process
3106

    
3107
  # Avoid zombies and check exit code
3108
  (_, status) = os.waitpid(pid, 0)
3109

    
3110
  if os.WIFSIGNALED(status):
3111
    exitcode = None
3112
    signum = os.WTERMSIG(status)
3113
  else:
3114
    exitcode = os.WEXITSTATUS(status)
3115
    signum = None
3116

    
3117
  if not (exitcode in (0, 1) and signum is None):
3118
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
3119
                              (exitcode, signum))
3120

    
3121
  return bool(exitcode)
3122

    
3123

    
3124
def IgnoreProcessNotFound(fn, *args, **kwargs):
3125
  """Ignores ESRCH when calling a process-related function.
3126

3127
  ESRCH is raised when a process is not found.
3128

3129
  @rtype: bool
3130
  @return: Whether process was found
3131

3132
  """
3133
  try:
3134
    fn(*args, **kwargs)
3135
  except EnvironmentError, err:
3136
    # Ignore ESRCH
3137
    if err.errno == errno.ESRCH:
3138
      return False
3139
    raise
3140

    
3141
  return True
3142

    
3143

    
3144
def IgnoreSignals(fn, *args, **kwargs):
3145
  """Tries to call a function ignoring failures due to EINTR.
3146

3147
  """
3148
  try:
3149
    return fn(*args, **kwargs)
3150
  except EnvironmentError, err:
3151
    if err.errno == errno.EINTR:
3152
      return None
3153
    else:
3154
      raise
3155
  except (select.error, socket.error), err:
3156
    # In python 2.6 and above select.error is an IOError, so it's handled
3157
    # above, in 2.5 and below it's not, and it's handled here.
3158
    if err.args and err.args[0] == errno.EINTR:
3159
      return None
3160
    else:
3161
      raise
3162

    
3163

    
3164
def LockedMethod(fn):
3165
  """Synchronized object access decorator.
3166

3167
  This decorator is intended to protect access to an object using the
3168
  object's own lock which is hardcoded to '_lock'.
3169

3170
  """
3171
  def _LockDebug(*args, **kwargs):
3172
    if debug_locks:
3173
      logging.debug(*args, **kwargs)
3174

    
3175
  def wrapper(self, *args, **kwargs):
3176
    # pylint: disable-msg=W0212
3177
    assert hasattr(self, '_lock')
3178
    lock = self._lock
3179
    _LockDebug("Waiting for %s", lock)
3180
    lock.acquire()
3181
    try:
3182
      _LockDebug("Acquired %s", lock)
3183
      result = fn(self, *args, **kwargs)
3184
    finally:
3185
      _LockDebug("Releasing %s", lock)
3186
      lock.release()
3187
      _LockDebug("Released %s", lock)
3188
    return result
3189
  return wrapper
3190

    
3191

    
3192
def LockFile(fd):
3193
  """Locks a file using POSIX locks.
3194

3195
  @type fd: int
3196
  @param fd: the file descriptor we need to lock
3197

3198
  """
3199
  try:
3200
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
3201
  except IOError, err:
3202
    if err.errno == errno.EAGAIN:
3203
      raise errors.LockError("File already locked")
3204
    raise
3205

    
3206

    
3207
def FormatTime(val):
3208
  """Formats a time value.
3209

3210
  @type val: float or None
3211
  @param val: the timestamp as returned by time.time()
3212
  @return: a string value or N/A if we don't have a valid timestamp
3213

3214
  """
3215
  if val is None or not isinstance(val, (int, float)):
3216
    return "N/A"
3217
  # these two codes work on Linux, but they are not guaranteed on all
3218
  # platforms
3219
  return time.strftime("%F %T", time.localtime(val))
3220

    
3221

    
3222
def FormatSeconds(secs):
3223
  """Formats seconds for easier reading.
3224

3225
  @type secs: number
3226
  @param secs: Number of seconds
3227
  @rtype: string
3228
  @return: Formatted seconds (e.g. "2d 9h 19m 49s")
3229

3230
  """
3231
  parts = []
3232

    
3233
  secs = round(secs, 0)
3234

    
3235
  if secs > 0:
3236
    # Negative values would be a bit tricky
3237
    for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]:
3238
      (complete, secs) = divmod(secs, one)
3239
      if complete or parts:
3240
        parts.append("%d%s" % (complete, unit))
3241

    
3242
  parts.append("%ds" % secs)
3243

    
3244
  return " ".join(parts)


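# Illustrative FormatSeconds usage sketch (not part of the original module;
# the values below are assumptions):
#
#   FormatSeconds(86500)  # -> "1d 0h 1m 40s"
#   FormatSeconds(12.3)   # -> "12s"

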
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
3248
  """Reads the watcher pause file.
3249

3250
  @type filename: string
3251
  @param filename: Path to watcher pause file
3252
  @type now: None, float or int
3253
  @param now: Current time as Unix timestamp
3254
  @type remove_after: int
3255
  @param remove_after: Remove watcher pause file after specified amount of
3256
    seconds past the pause end time
3257

3258
  """
3259
  if now is None:
3260
    now = time.time()
3261

    
3262
  try:
3263
    value = ReadFile(filename)
3264
  except IOError, err:
3265
    if err.errno != errno.ENOENT:
3266
      raise
3267
    value = None
3268

    
3269
  if value is not None:
3270
    try:
3271
      value = int(value)
3272
    except ValueError:
3273
      logging.warning(("Watcher pause file (%s) contains invalid value,"
3274
                       " removing it"), filename)
3275
      RemoveFile(filename)
3276
      value = None
3277

    
3278
    if value is not None:
3279
      # Remove file if it's outdated
3280
      if now > (value + remove_after):
3281
        RemoveFile(filename)
3282
        value = None
3283

    
3284
      elif now > value:
3285
        value = None
3286

    
3287
  return value
3288

    
3289

    
3290
class RetryTimeout(Exception):
3291
  """Retry loop timed out.
3292

3293
  Any arguments which were passed by the retried function to RetryAgain will be
  preserved in RetryTimeout, if it is raised. If such an argument was an
  exception, the RaiseInner helper method will reraise it.
3296

3297
  """
3298
  def RaiseInner(self):
3299
    if self.args and isinstance(self.args[0], Exception):
3300
      raise self.args[0]
3301
    else:
3302
      raise RetryTimeout(*self.args)
3303

    
3304

    
3305
class RetryAgain(Exception):
3306
  """Retry again.
3307

3308
  Any arguments passed to RetryAgain will be preserved, if a timeout occurs, as
  arguments to RetryTimeout. If an exception is passed, the RaiseInner() method
  of the RetryTimeout exception can be used to reraise it.
3311

3312
  """
3313

    
3314

    
3315
class _RetryDelayCalculator(object):
3316
  """Calculator for increasing delays.
3317

3318
  """
3319
  __slots__ = [
3320
    "_factor",
3321
    "_limit",
3322
    "_next",
3323
    "_start",
3324
    ]
3325

    
3326
  def __init__(self, start, factor, limit):
3327
    """Initializes this class.
3328

3329
    @type start: float
3330
    @param start: Initial delay
3331
    @type factor: float
3332
    @param factor: Factor for delay increase
3333
    @type limit: float or None
3334
    @param limit: Upper limit for delay or None for no limit
3335

3336
    """
3337
    assert start > 0.0
3338
    assert factor >= 1.0
3339
    assert limit is None or limit >= 0.0
3340

    
3341
    self._start = start
3342
    self._factor = factor
3343
    self._limit = limit
3344

    
3345
    self._next = start
3346

    
3347
  def __call__(self):
3348
    """Returns current delay and calculates the next one.
3349

3350
    """
3351
    current = self._next
3352

    
3353
    # Update for next run
3354
    if self._limit is None:
      self._next = self._next * self._factor
    elif self._next < self._limit:
      self._next = min(self._limit, self._next * self._factor)
3356

    
3357
    return current
3358

    
3359

    
3360
#: Special delay to specify whole remaining timeout
3361
RETRY_REMAINING_TIME = object()
3362

    
3363

    
3364
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
3365
          _time_fn=time.time):
3366
  """Call a function repeatedly until it succeeds.
3367

3368
  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
3369
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
3370
  total of C{timeout} seconds, this function throws L{RetryTimeout}.
3371

3372
  C{delay} can be one of the following:
3373
    - callable returning the delay length as a float
3374
    - Tuple of (start, factor, limit)
3375
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
3376
      useful when overriding L{wait_fn} to wait for an external event)
3377
    - A static delay as a number (int or float)
3378

3379
  @type fn: callable
3380
  @param fn: Function to be called
3381
  @param delay: Either a callable (returning the delay), a tuple of (start,
3382
                factor, limit) (see L{_RetryDelayCalculator}),
3383
                L{RETRY_REMAINING_TIME} or a number (int or float)
3384
  @type timeout: float
3385
  @param timeout: Total timeout
3386
  @type wait_fn: callable
3387
  @param wait_fn: Waiting function
3388
  @return: Return value of function
3389

3390
  """
3391
  assert callable(fn)
3392
  assert callable(wait_fn)
3393
  assert callable(_time_fn)
3394

    
3395
  if args is None:
3396
    args = []
3397

    
3398
  end_time = _time_fn() + timeout
3399

    
3400
  if callable(delay):
3401
    # External function to calculate delay
3402
    calc_delay = delay
3403

    
3404
  elif isinstance(delay, (tuple, list)):
3405
    # Increasing delay with optional upper boundary
3406
    (start, factor, limit) = delay
3407
    calc_delay = _RetryDelayCalculator(start, factor, limit)
3408

    
3409
  elif delay is RETRY_REMAINING_TIME:
3410
    # Always use the remaining time
3411
    calc_delay = None
3412

    
3413
  else:
3414
    # Static delay
3415
    calc_delay = lambda: delay
3416

    
3417
  assert calc_delay is None or callable(calc_delay)
3418

    
3419
  while True:
3420
    retry_args = []
3421
    try:
3422
      # pylint: disable-msg=W0142
3423
      return fn(*args)
3424
    except RetryAgain, err:
3425
      retry_args = err.args
3426
    except RetryTimeout:
3427
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
3428
                                   " handle RetryTimeout")
3429

    
3430
    remaining_time = end_time - _time_fn()
3431

    
3432
    if remaining_time < 0.0:
3433
      # pylint: disable-msg=W0142
3434
      raise RetryTimeout(*retry_args)
3435

    
3436
    assert remaining_time >= 0.0
3437

    
3438
    if calc_delay is None:
3439
      wait_fn(remaining_time)
3440
    else:
3441
      current_delay = calc_delay()
3442
      if current_delay > 0.0:
3443
        wait_fn(current_delay)


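# Illustrative Retry usage sketch (not part of the original module; the
# helper function and the delay/timeout values below are assumptions):
#
#   def _CheckJobDone():
#     if not _JobIsDone():  # _JobIsDone is a hypothetical predicate
#       raise RetryAgain()
#     return True
#
#   # Poll every 0.5s, backing off by a factor of 1.5 up to 5s, for at most
#   # 60 seconds in total:
#   Retry(_CheckJobDone, (0.5, 1.5, 5.0), 60)
#
#   # Or poll with a fixed one-second delay:
#   Retry(_CheckJobDone, 1.0, 60)

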
def GetClosedTempfile(*args, **kwargs):
3447
  """Creates a temporary file and returns its path.
3448

3449
  """
3450
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
3451
  _CloseFDNoErr(fd)
3452
  return path
3453

    
3454

    
3455
def GenerateSelfSignedX509Cert(common_name, validity):
3456
  """Generates a self-signed X509 certificate.
3457

3458
  @type common_name: string
3459
  @param common_name: commonName value
3460
  @type validity: int
3461
  @param validity: Validity for certificate in seconds
3462

3463
  """
3464
  # Create private and public key
3465
  key = OpenSSL.crypto.PKey()
3466
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)
3467

    
3468
  # Create self-signed certificate
3469
  cert = OpenSSL.crypto.X509()
3470
  if common_name:
3471
    cert.get_subject().CN = common_name
3472
  cert.set_serial_number(1)
3473
  cert.gmtime_adj_notBefore(0)
3474
  cert.gmtime_adj_notAfter(validity)
3475
  cert.set_issuer(cert.get_subject())
3476
  cert.set_pubkey(key)
3477
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)
3478

    
3479
  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
3480
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
3481

    
3482
  return (key_pem, cert_pem)
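
# Example (illustrative sketch): generate a certificate valid for one year and
# load it back for a quick sanity check; the hostname is made up.  The PEM
# strings can be persisted with WriteFile, as GenerateSelfSignedSslCert below
# does.
#
#   (key_pem, cert_pem) = GenerateSelfSignedX509Cert("node1.example.com",
#                                                    365 * 24 * 60 * 60)
#   cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
#                                          cert_pem)
#   assert not cert.has_expired()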


def GenerateSelfSignedSslCert(filename, validity=(5 * 365)):
  """Legacy function to generate self-signed X509 certificate.

  """
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(None,
                                                   validity * 24 * 60 * 60)

  WriteFile(filename, mode=0400, data=key_pem + cert_pem)


class FileLock(object):
  """Utility class for file locks.

  """
  def __init__(self, fd, filename):
    """Constructor for FileLock.

    @type fd: file
    @param fd: File object
    @type filename: str
    @param filename: Path of the file opened at I{fd}

    """
    self.fd = fd
    self.filename = filename

  @classmethod
  def Open(cls, filename):
    """Creates and opens a file to be used as a file-based lock.

    @type filename: string
    @param filename: path to the file to be locked

    """
    # Using "os.open" is necessary to allow both opening existing file
    # read/write and creating if not existing. Vanilla "open" will truncate an
    # existing file -or- allow creating if not existing.
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
               filename)

  def __del__(self):
    self.Close()

  def Close(self):
    """Close the file and release the lock.

    """
    if hasattr(self, "fd") and self.fd:
      self.fd.close()
      self.fd = None

  def _flock(self, flag, blocking, timeout, errmsg):
    """Wrapper for fcntl.flock.

    @type flag: int
    @param flag: operation flag
    @type blocking: bool
    @param blocking: whether the operation should be done in blocking mode.
    @type timeout: None or float
    @param timeout: for how long the operation should be retried (implies
                    non-blocking mode).
    @type errmsg: string
    @param errmsg: error message in case operation fails.

    """
    assert self.fd, "Lock was closed"
    assert timeout is None or timeout >= 0, \
      "If specified, timeout must not be negative"
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)


class LineSplitter:
  """Splits data chunks into lines separated by newline.

  Instances provide a file-like interface.

  """
  def __init__(self, line_fn, *args):
    """Initializes this class.

    @type line_fn: callable
    @param line_fn: Function called for each line, first parameter is line
    @param args: Extra arguments for L{line_fn}

    """
    assert callable(line_fn)

    if args:
      # Python 2.4 doesn't have functools.partial yet
      self._line_fn = \
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
    else:
      self._line_fn = line_fn

    self._lines = collections.deque()
    self._buffer = ""

  def write(self, data):
    parts = (self._buffer + data).split("\n")
    self._buffer = parts.pop()
    self._lines.extend(parts)

  def flush(self):
    while self._lines:
      self._line_fn(self._lines.popleft().rstrip("\r\n"))

  def close(self):
    self.flush()
    if self._buffer:
      self._line_fn(self._buffer)
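
# Example (illustrative sketch): feed arbitrarily sized chunks and receive
# complete lines through the callback; a trailing partial line is only emitted
# by close().
#
#   received = []
#   splitter = LineSplitter(received.append)
#   splitter.write("first\nsec")
#   splitter.write("ond\nthird")
#   splitter.flush()   # received == ["first", "second"]
#   splitter.close()   # received == ["first", "second", "third"]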


def SignalHandled(signums):
  """Signal handled decorator.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      assert 'signal_handlers' not in kwargs or \
             kwargs['signal_handlers'] is None or \
             isinstance(kwargs['signal_handlers'], dict), \
             "Wrong signal_handlers parameter in original function call"
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
        signal_handlers = kwargs['signal_handlers']
      else:
        signal_handlers = {}
        kwargs['signal_handlers'] = signal_handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          signal_handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        sighandler.Reset()
    return sig_function
  return wrap
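
# Example (illustrative sketch): the decorated function receives the installed
# handlers through the "signal_handlers" keyword argument and can poll their
# "called" flags; _MainLoop and DoSomething are made-up names.
#
#   @SignalHandled([signal.SIGTERM])
#   @SignalHandled([signal.SIGINT])
#   def _MainLoop(signal_handlers=None):
#     while not signal_handlers[signal.SIGTERM].called:
#       DoSomething()
#       if signal_handlers[signal.SIGINT].called:
#         signal_handlers[signal.SIGINT].Clear()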


class SignalWakeupFd(object):
  try:
    # This is only supported in Python 2.5 and above (some distributions
    # backported it to Python 2.4)
    _set_wakeup_fd_fn = signal.set_wakeup_fd
  except AttributeError:
    # Not supported
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
      return -1
  else:
    def _SetWakeupFd(self, fd):
      return self._set_wakeup_fd_fn(fd)

  def __init__(self):
    """Initializes this class.

    """
    (read_fd, write_fd) = os.pipe()

    # Once these succeeded, the file descriptors will be closed automatically.
    # Buffer size 0 is important, otherwise .read() with a specified length
    # might buffer data and the file descriptors won't be marked readable.
    self._read_fh = os.fdopen(read_fd, "r", 0)
    self._write_fh = os.fdopen(write_fd, "w", 0)

    self._previous = self._SetWakeupFd(self._write_fh.fileno())

    # Utility functions
    self.fileno = self._read_fh.fileno
    self.read = self._read_fh.read

  def Reset(self):
    """Restores the previous wakeup file descriptor.

    """
    if hasattr(self, "_previous") and self._previous is not None:
      self._SetWakeupFd(self._previous)
      self._previous = None

  def Notify(self):
    """Notifies the wakeup file descriptor.

    """
    self._write_fh.write("\0")

  def __del__(self):
    """Called before object deletion.

    """
    self.Reset()


class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when destructed or
  when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: list
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None, wakeup=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function
    @type wakeup: L{SignalWakeupFd}
    @param wakeup: Wakeup file descriptor to notify when a signal is caught

    """
    assert handler_fn is None or callable(handler_fn)

    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn
    self._wakeup = wakeup

    self._previous = {}
    try:
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._wakeup:
      # Notify whoever is interested in signals
      self._wakeup.Notify()

    if self._handler_fn:
      self._handler_fn(signum, frame)
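
# Example (illustrative sketch): combine SignalHandler with SignalWakeupFd so
# a select() loop wakes up as soon as SIGCHLD arrives instead of waiting for
# the poll timeout to expire.
#
#   wakeup = SignalWakeupFd()
#   handler = SignalHandler([signal.SIGCHLD], wakeup=wakeup)
#   try:
#     while not handler.called:
#       (readable, _, _) = select.select([wakeup], [], [], 60.0)
#       if readable:
#         wakeup.read(1)
#   finally:
#     handler.Reset()
#     wakeup.Reset()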


class FieldSet(object):
  """A simple field set.

  Among the features are:
    - checking if a string is among a list of static strings or regex objects
    - checking if a whole list of strings matches
    - returning the matching groups from a regex match

  Internally, all fields are held as regular expression objects.

  """
  def __init__(self, *items):
    self.items = [re.compile("^%s$" % value) for value in items]

  def Extend(self, other_set):
    """Extend the field set with the items from another one"""
    self.items.extend(other_set.items)

  def Matches(self, field):
    """Checks if a field matches the current set

    @type field: str
    @param field: the string to match
    @return: either None or a regular expression match object

    """
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
      return m
    return None

  def NonMatching(self, items):
    """Returns the list of fields not matching the current set

    @type items: list
    @param items: the list of fields to check
    @rtype: list
    @return: list of non-matching fields

    """
    return [val for val in items if not self.Matches(val)]
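
# Example (illustrative sketch): fields may be literal names or regular
# expressions, and Matches() returns the match object so capture groups are
# available to the caller; the field names below are made up.
#
#   fields = FieldSet("name", "pcnt", r"disk\.size/([0-9]+)")
#   fields.Matches("pcnt")                    # -> match object
#   fields.Matches("disk.size/2").group(1)    # -> "2"
#   fields.NonMatching(["name", "other"])     # -> ["other"]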