
1
#
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import os
31
import sys
32
import time
33
import subprocess
34
import re
35
import socket
36
import tempfile
37
import shutil
38
import errno
39
import pwd
40
import itertools
41
import select
42
import fcntl
43
import resource
44
import logging
45
import logging.handlers
46
import signal
47
import OpenSSL
48
import datetime
49
import calendar
50
import hmac
51
import collections
52

    
53
from cStringIO import StringIO
54

    
55
try:
56
  # pylint: disable-msg=F0401
57
  import ctypes
58
except ImportError:
59
  ctypes = None
60

    
61
from ganeti import errors
62
from ganeti import constants
63
from ganeti import compat
64
from ganeti import netutils
65

    
66

    
67
_locksheld = []
68
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
69

    
70
debug_locks = False
71

    
72
#: when set to True, L{RunCmd} is disabled
73
no_fork = False
74

    
75
_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
76

    
77
HEX_CHAR_RE = r"[a-zA-Z0-9]"
78
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
79
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
80
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
81
                             HEX_CHAR_RE, HEX_CHAR_RE),
82
                            re.S | re.I)
83

    
84
_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$")
85

    
86
# Certificate verification results
87
(CERT_WARNING,
88
 CERT_ERROR) = range(1, 3)
89

    
90
# Flags for mlockall() (from bits/mman.h)
91
_MCL_CURRENT = 1
92
_MCL_FUTURE = 2
93

    
94

    
95
class RunResult(object):
96
  """Holds the result of running external programs.
97

98
  @type exit_code: int
99
  @ivar exit_code: the exit code of the program, or None (if the program
100
      didn't exit())
101
  @type signal: int or None
102
  @ivar signal: the signal that caused the program to finish, or None
103
      (if the program wasn't terminated by a signal)
104
  @type stdout: str
105
  @ivar stdout: the standard output of the program
106
  @type stderr: str
107
  @ivar stderr: the standard error of the program
108
  @type failed: boolean
109
  @ivar failed: True in case the program was
110
      terminated by a signal or exited with a non-zero exit code
111
  @ivar fail_reason: a string detailing the termination reason
112

113
  """
114
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
115
               "failed", "fail_reason", "cmd"]
116

    
117

    
118
  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
119
    self.cmd = cmd
120
    self.exit_code = exit_code
121
    self.signal = signal_
122
    self.stdout = stdout
123
    self.stderr = stderr
124
    self.failed = (signal_ is not None or exit_code != 0)
125

    
126
    if self.signal is not None:
127
      self.fail_reason = "terminated by signal %s" % self.signal
128
    elif self.exit_code is not None:
129
      self.fail_reason = "exited with exit code %s" % self.exit_code
130
    else:
131
      self.fail_reason = "unable to determine termination reason"
132

    
133
    if self.failed:
134
      logging.debug("Command '%s' failed (%s); output: %s",
135
                    self.cmd, self.fail_reason, self.output)
136

    
137
  def _GetOutput(self):
138
    """Returns the combined stdout and stderr for easier usage.
139

140
    """
141
    return self.stdout + self.stderr
142

    
143
  output = property(_GetOutput, None, None, "Return full output")
144

    
145

    
146
def _BuildCmdEnvironment(env, reset):
147
  """Builds the environment for an external program.
148

149
  """
150
  if reset:
151
    cmd_env = {}
152
  else:
153
    cmd_env = os.environ.copy()
154
    cmd_env["LC_ALL"] = "C"
155

    
156
  if env is not None:
157
    cmd_env.update(env)
158

    
159
  return cmd_env
160

    
161

    
162
def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
163
  """Execute a (shell) command.
164

165
  The command should not read from its standard input, as it will be
166
  closed.
167

168
  @type cmd: string or list
169
  @param cmd: Command to run
170
  @type env: dict
171
  @param env: Additional environment variables
172
  @type output: str
173
  @param output: if desired, the output of the command can be
174
      saved in a file instead of the RunResult instance; this
175
      parameter denotes the file name (if not None)
176
  @type cwd: string
177
  @param cwd: if specified, will be used as the working
178
      directory for the command; the default will be /
179
  @type reset_env: boolean
180
  @param reset_env: whether to reset or keep the default os environment
181
  @rtype: L{RunResult}
182
  @return: RunResult instance
183
  @raise errors.ProgrammerError: if we call this when forks are disabled
184

185
  """
186
  if no_fork:
187
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")
188

    
189
  if isinstance(cmd, basestring):
190
    strcmd = cmd
191
    shell = True
192
  else:
193
    cmd = [str(val) for val in cmd]
194
    strcmd = ShellQuoteArgs(cmd)
195
    shell = False
196

    
197
  if output:
198
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
199
  else:
200
    logging.debug("RunCmd %s", strcmd)
201

    
202
  cmd_env = _BuildCmdEnvironment(env, reset_env)
203

    
204
  try:
205
    if output is None:
206
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
207
    else:
208
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
209
      out = err = ""
210
  except OSError, err:
211
    if err.errno == errno.ENOENT:
212
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
213
                               (strcmd, err))
214
    else:
215
      raise
216

    
217
  if status >= 0:
218
    exitcode = status
219
    signal_ = None
220
  else:
221
    exitcode = None
222
    signal_ = -status
223

    
224
  return RunResult(exitcode, signal_, out, err, strcmd)
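
# Illustrative sketch of how a RunResult returned by RunCmd is typically
# consumed; the command and log messages below are arbitrary examples.
def _ExampleRunCmdUsage():
  result = RunCmd(["/bin/ls", "/tmp"])
  if result.failed:
    logging.error("Command %s failed (%s); output: %s",
                  result.cmd, result.fail_reason, result.output)
    return None
  return result.stdout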
225

    
226

    
227
def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
228
                pidfile=None):
229
  """Start a daemon process after forking twice.
230

231
  @type cmd: string or list
232
  @param cmd: Command to run
233
  @type env: dict
234
  @param env: Additional environment variables
235
  @type cwd: string
236
  @param cwd: Working directory for the program
237
  @type output: string
238
  @param output: Path to file in which to save the output
239
  @type output_fd: int
240
  @param output_fd: File descriptor for output
241
  @type pidfile: string
242
  @param pidfile: Process ID file
243
  @rtype: int
244
  @return: Daemon process ID
245
  @raise errors.ProgrammerError: if we call this when forks are disabled
246

247
  """
248
  if no_fork:
249
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
250
                                 " disabled")
251

    
252
  if output and not (bool(output) ^ (output_fd is not None)):
253
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
254
                                 " specified")
255

    
256
  if isinstance(cmd, basestring):
257
    cmd = ["/bin/sh", "-c", cmd]
258

    
259
  strcmd = ShellQuoteArgs(cmd)
260

    
261
  if output:
262
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
263
  else:
264
    logging.debug("StartDaemon %s", strcmd)
265

    
266
  cmd_env = _BuildCmdEnvironment(env, False)
267

    
268
  # Create pipe for sending PID back
269
  (pidpipe_read, pidpipe_write) = os.pipe()
270
  try:
271
    try:
272
      # Create pipe for sending error messages
273
      (errpipe_read, errpipe_write) = os.pipe()
274
      try:
275
        try:
276
          # First fork
277
          pid = os.fork()
278
          if pid == 0:
279
            try:
280
              # Child process, won't return
281
              _StartDaemonChild(errpipe_read, errpipe_write,
282
                                pidpipe_read, pidpipe_write,
283
                                cmd, cmd_env, cwd,
284
                                output, output_fd, pidfile)
285
            finally:
286
              # Well, maybe child process failed
287
              os._exit(1) # pylint: disable-msg=W0212
288
        finally:
289
          _CloseFDNoErr(errpipe_write)
290

    
291
        # Wait for daemon to be started (or an error message to arrive) and read
292
        # up to 100 KB as an error message
293
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
294
      finally:
295
        _CloseFDNoErr(errpipe_read)
296
    finally:
297
      _CloseFDNoErr(pidpipe_write)
298

    
299
    # Read up to 128 bytes for PID
300
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
301
  finally:
302
    _CloseFDNoErr(pidpipe_read)
303

    
304
  # Try to avoid zombies by waiting for child process
305
  try:
306
    os.waitpid(pid, 0)
307
  except OSError:
308
    pass
309

    
310
  if errormsg:
311
    raise errors.OpExecError("Error when starting daemon process: %r" %
312
                             errormsg)
313

    
314
  try:
315
    return int(pidtext)
316
  except (ValueError, TypeError), err:
317
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
318
                             (pidtext, err))
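
# Illustrative sketch of starting a daemon with StartDaemon; the daemon
# command, log file and PID file used here are made-up examples.
def _ExampleStartDaemonUsage():
  pid = StartDaemon(["/usr/sbin/example-daemon", "--foreground"],
                    output="/var/log/example-daemon.log",
                    pidfile="/var/run/example-daemon.pid")
  logging.info("Started example daemon with PID %s", pid)
  return pid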
319

    
320

    
321
def _StartDaemonChild(errpipe_read, errpipe_write,
322
                      pidpipe_read, pidpipe_write,
323
                      args, env, cwd,
324
                      output, fd_output, pidfile):
325
  """Child process for starting daemon.
326

327
  """
328
  try:
329
    # Close parent's side
330
    _CloseFDNoErr(errpipe_read)
331
    _CloseFDNoErr(pidpipe_read)
332

    
333
    # First child process
334
    os.chdir("/")
335
    os.umask(077)
336
    os.setsid()
337

    
338
    # And fork for the second time
339
    pid = os.fork()
340
    if pid != 0:
341
      # Exit first child process
342
      os._exit(0) # pylint: disable-msg=W0212
343

    
344
    # Make sure pipe is closed on execv* (and thereby notifies original process)
345
    SetCloseOnExecFlag(errpipe_write, True)
346

    
347
    # List of file descriptors to be left open
348
    noclose_fds = [errpipe_write]
349

    
350
    # Open PID file
351
    if pidfile:
352
      try:
353
        # TODO: Atomic replace with another locked file instead of writing into
354
        # it after creating
355
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)
356

    
357
        # Lock the PID file (and fail if not possible to do so). Any code
358
        # wanting to send a signal to the daemon should try to lock the PID
359
        # file before reading it. If acquiring the lock succeeds, the daemon is
360
        # no longer running and the signal should not be sent.
361
        LockFile(fd_pidfile)
362

    
363
        os.write(fd_pidfile, "%d\n" % os.getpid())
364
      except Exception, err:
365
        raise Exception("Creating and locking PID file failed: %s" % err)
366

    
367
      # Keeping the file open to hold the lock
368
      noclose_fds.append(fd_pidfile)
369

    
370
      SetCloseOnExecFlag(fd_pidfile, False)
371
    else:
372
      fd_pidfile = None
373

    
374
    # Open /dev/null
375
    fd_devnull = os.open(os.devnull, os.O_RDWR)
376

    
377
    assert not output or (bool(output) ^ (fd_output is not None))
378

    
379
    if fd_output is not None:
380
      pass
381
    elif output:
382
      # Open output file
383
      try:
384
        # TODO: Implement flag to set append=yes/no
385
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
386
      except EnvironmentError, err:
387
        raise Exception("Opening output file failed: %s" % err)
388
    else:
389
      fd_output = fd_devnull
390

    
391
    # Redirect standard I/O
392
    os.dup2(fd_devnull, 0)
393
    os.dup2(fd_output, 1)
394
    os.dup2(fd_output, 2)
395

    
396
    # Send daemon PID to parent
397
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))
398

    
399
    # Close all file descriptors except stdio and error message pipe
400
    CloseFDs(noclose_fds=noclose_fds)
401

    
402
    # Change working directory
403
    os.chdir(cwd)
404

    
405
    if env is None:
406
      os.execvp(args[0], args)
407
    else:
408
      os.execvpe(args[0], args, env)
409
  except: # pylint: disable-msg=W0702
410
    try:
411
      # Report errors to original process
412
      buf = str(sys.exc_info()[1])
413

    
414
      RetryOnSignal(os.write, errpipe_write, buf)
415
    except: # pylint: disable-msg=W0702
416
      # Ignore errors in error handling
417
      pass
418

    
419
  os._exit(1) # pylint: disable-msg=W0212
420

    
421

    
422
def _RunCmdPipe(cmd, env, via_shell, cwd):
423
  """Run a command and return its output.
424

425
  @type  cmd: string or list
426
  @param cmd: Command to run
427
  @type env: dict
428
  @param env: The environment to use
429
  @type via_shell: bool
430
  @param via_shell: if we should run via the shell
431
  @type cwd: string
432
  @param cwd: the working directory for the program
433
  @rtype: tuple
434
  @return: (out, err, status)
435

436
  """
437
  poller = select.poll()
438
  child = subprocess.Popen(cmd, shell=via_shell,
439
                           stderr=subprocess.PIPE,
440
                           stdout=subprocess.PIPE,
441
                           stdin=subprocess.PIPE,
442
                           close_fds=True, env=env,
443
                           cwd=cwd)
444

    
445
  child.stdin.close()
446
  poller.register(child.stdout, select.POLLIN)
447
  poller.register(child.stderr, select.POLLIN)
448
  out = StringIO()
449
  err = StringIO()
450
  fdmap = {
451
    child.stdout.fileno(): (out, child.stdout),
452
    child.stderr.fileno(): (err, child.stderr),
453
    }
454
  for fd in fdmap:
455
    SetNonblockFlag(fd, True)
456

    
457
  while fdmap:
458
    pollresult = RetryOnSignal(poller.poll)
459

    
460
    for fd, event in pollresult:
461
      if event & select.POLLIN or event & select.POLLPRI:
462
        data = fdmap[fd][1].read()
463
        # no data from read signifies EOF (the same as POLLHUP)
464
        if not data:
465
          poller.unregister(fd)
466
          del fdmap[fd]
467
          continue
468
        fdmap[fd][0].write(data)
469
      if (event & select.POLLNVAL or event & select.POLLHUP or
470
          event & select.POLLERR):
471
        poller.unregister(fd)
472
        del fdmap[fd]
473

    
474
  out = out.getvalue()
475
  err = err.getvalue()
476

    
477
  status = child.wait()
478
  return out, err, status
479

    
480

    
481
def _RunCmdFile(cmd, env, via_shell, output, cwd):
482
  """Run a command and save its output to a file.
483

484
  @type  cmd: string or list
485
  @param cmd: Command to run
486
  @type env: dict
487
  @param env: The environment to use
488
  @type via_shell: bool
489
  @param via_shell: if we should run via the shell
490
  @type output: str
491
  @param output: the filename in which to save the output
492
  @type cwd: string
493
  @param cwd: the working directory for the program
494
  @rtype: int
495
  @return: the exit status
496

497
  """
498
  fh = open(output, "a")
499
  try:
500
    child = subprocess.Popen(cmd, shell=via_shell,
501
                             stderr=subprocess.STDOUT,
502
                             stdout=fh,
503
                             stdin=subprocess.PIPE,
504
                             close_fds=True, env=env,
505
                             cwd=cwd)
506

    
507
    child.stdin.close()
508
    status = child.wait()
509
  finally:
510
    fh.close()
511
  return status
512

    
513

    
514
def SetCloseOnExecFlag(fd, enable):
515
  """Sets or unsets the close-on-exec flag on a file descriptor.
516

517
  @type fd: int
518
  @param fd: File descriptor
519
  @type enable: bool
520
  @param enable: Whether to set or unset it.
521

522
  """
523
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)
524

    
525
  if enable:
526
    flags |= fcntl.FD_CLOEXEC
527
  else:
528
    flags &= ~fcntl.FD_CLOEXEC
529

    
530
  fcntl.fcntl(fd, fcntl.F_SETFD, flags)
531

    
532

    
533
def SetNonblockFlag(fd, enable):
534
  """Sets or unsets the O_NONBLOCK flag on on a file descriptor.
535

536
  @type fd: int
537
  @param fd: File descriptor
538
  @type enable: bool
539
  @param enable: Whether to set or unset it
540

541
  """
542
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)
543

    
544
  if enable:
545
    flags |= os.O_NONBLOCK
546
  else:
547
    flags &= ~os.O_NONBLOCK
548

    
549
  fcntl.fcntl(fd, fcntl.F_SETFL, flags)
550

    
551

    
552
def RetryOnSignal(fn, *args, **kwargs):
553
  """Calls a function again if it failed due to EINTR.
554

555
  """
556
  while True:
557
    try:
558
      return fn(*args, **kwargs)
559
    except EnvironmentError, err:
560
      if err.errno != errno.EINTR:
561
        raise
562
    except (socket.error, select.error), err:
563
      # In python 2.6 and above select.error is an IOError, so it's handled
564
      # above, in 2.5 and below it's not, and it's handled here.
565
      if not (err.args and err.args[0] == errno.EINTR):
566
        raise
567

    
568

    
569
def RunParts(dir_name, env=None, reset_env=False):
570
  """Run Scripts or programs in a directory
571

572
  @type dir_name: string
573
  @param dir_name: absolute path to a directory
574
  @type env: dict
575
  @param env: The environment to use
576
  @type reset_env: boolean
577
  @param reset_env: whether to reset or keep the default os environment
578
  @rtype: list of tuples
579
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)
580

581
  """
582
  rr = []
583

    
584
  try:
585
    dir_contents = ListVisibleFiles(dir_name)
586
  except OSError, err:
587
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
588
    return rr
589

    
590
  for relname in sorted(dir_contents):
591
    fname = PathJoin(dir_name, relname)
592
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
593
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
594
      rr.append((relname, constants.RUNPARTS_SKIP, None))
595
    else:
596
      try:
597
        result = RunCmd([fname], env=env, reset_env=reset_env)
598
      except Exception, err: # pylint: disable-msg=W0703
599
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
600
      else:
601
        rr.append((relname, constants.RUNPARTS_RUN, result))
602

    
603
  return rr
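
# Illustrative sketch of consuming RunParts results; the directory name is a
# made-up example.
def _ExampleRunPartsUsage():
  for (relname, status, runresult) in RunParts("/etc/example.d"):
    if status == constants.RUNPARTS_ERR:
      logging.error("Could not run %s: %s", relname, runresult)
    elif status == constants.RUNPARTS_RUN and runresult.failed:
      logging.warning("%s failed: %s", relname, runresult.output)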
604

    
605

    
606
def RemoveFile(filename):
607
  """Remove a file ignoring some errors.
608

609
  Remove a file, ignoring non-existing ones or directories. Other
610
  errors are passed.
611

612
  @type filename: str
613
  @param filename: the file to be removed
614

615
  """
616
  try:
617
    os.unlink(filename)
618
  except OSError, err:
619
    if err.errno not in (errno.ENOENT, errno.EISDIR):
620
      raise
621

    
622

    
623
def RemoveDir(dirname):
624
  """Remove an empty directory.
625

626
  Remove a directory, ignoring non-existing ones.
627
  Other errors are passed. This includes the case,
628
  where the directory is not empty, so it can't be removed.
629

630
  @type dirname: str
631
  @param dirname: the empty directory to be removed
632

633
  """
634
  try:
635
    os.rmdir(dirname)
636
  except OSError, err:
637
    if err.errno != errno.ENOENT:
638
      raise
639

    
640

    
641
def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
642
  """Renames a file.
643

644
  @type old: string
645
  @param old: Original path
646
  @type new: string
647
  @param new: New path
648
  @type mkdir: bool
649
  @param mkdir: Whether to create target directory if it doesn't exist
650
  @type mkdir_mode: int
651
  @param mkdir_mode: Mode for newly created directories
652

653
  """
654
  try:
655
    return os.rename(old, new)
656
  except OSError, err:
657
    # In at least one use case of this function, the job queue, directory
658
    # creation is very rare. Checking for the directory before renaming is not
659
    # as efficient.
660
    if mkdir and err.errno == errno.ENOENT:
661
      # Create directory and try again
662
      Makedirs(os.path.dirname(new), mode=mkdir_mode)
663

    
664
      return os.rename(old, new)
665

    
666
    raise
667

    
668

    
669
def Makedirs(path, mode=0750):
670
  """Super-mkdir; create a leaf directory and all intermediate ones.
671

672
  This is a wrapper around C{os.makedirs} adding error handling not implemented
673
  before Python 2.5.
674

675
  """
676
  try:
677
    os.makedirs(path, mode)
678
  except OSError, err:
679
    # Ignore EEXIST. This is only handled in os.makedirs as included in
680
    # Python 2.5 and above.
681
    if err.errno != errno.EEXIST or not os.path.exists(path):
682
      raise
683

    
684

    
685
def ResetTempfileModule():
686
  """Resets the random name generator of the tempfile module.
687

688
  This function should be called after C{os.fork} in the child process to
689
  ensure it creates a newly seeded random generator. Otherwise it would
690
  generate the same random parts as the parent process. If several processes
691
  race for the creation of a temporary file, this could lead to one not getting
692
  a temporary name.
693

694
  """
695
  # pylint: disable-msg=W0212
696
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
697
    tempfile._once_lock.acquire()
698
    try:
699
      # Reset random name generator
700
      tempfile._name_sequence = None
701
    finally:
702
      tempfile._once_lock.release()
703
  else:
704
    logging.critical("The tempfile module misses at least one of the"
705
                     " '_once_lock' and '_name_sequence' attributes")
706

    
707

    
708
def _FingerprintFile(filename):
709
  """Compute the fingerprint of a file.
710

711
  If the file does not exist, None will be returned
712
  instead.
713

714
  @type filename: str
715
  @param filename: the filename to checksum
716
  @rtype: str
717
  @return: the hex digest of the sha checksum of the contents
718
      of the file
719

720
  """
721
  if not (os.path.exists(filename) and os.path.isfile(filename)):
722
    return None
723

    
724
  f = open(filename)
725

    
726
  fp = compat.sha1_hash()
727
  while True:
728
    data = f.read(4096)
729
    if not data:
730
      break
731

    
732
    fp.update(data)
733

    
734
  return fp.hexdigest()
735

    
736

    
737
def FingerprintFiles(files):
738
  """Compute fingerprints for a list of files.
739

740
  @type files: list
741
  @param files: the list of filename to fingerprint
742
  @rtype: dict
743
  @return: a dictionary filename: fingerprint, holding only
744
      existing files
745

746
  """
747
  ret = {}
748

    
749
  for filename in files:
750
    cksum = _FingerprintFile(filename)
751
    if cksum:
752
      ret[filename] = cksum
753

    
754
  return ret
755

    
756

    
757
def ForceDictType(target, key_types, allowed_values=None):
758
  """Force the values of a dict to have certain types.
759

760
  @type target: dict
761
  @param target: the dict to update
762
  @type key_types: dict
763
  @param key_types: dict mapping target dict keys to types
764
                    in constants.ENFORCEABLE_TYPES
765
  @type allowed_values: list
766
  @keyword allowed_values: list of specially allowed values
767

768
  """
769
  if allowed_values is None:
770
    allowed_values = []
771

    
772
  if not isinstance(target, dict):
773
    msg = "Expected dictionary, got '%s'" % target
774
    raise errors.TypeEnforcementError(msg)
775

    
776
  for key in target:
777
    if key not in key_types:
778
      msg = "Unknown key '%s'" % key
779
      raise errors.TypeEnforcementError(msg)
780

    
781
    if target[key] in allowed_values:
782
      continue
783

    
784
    ktype = key_types[key]
785
    if ktype not in constants.ENFORCEABLE_TYPES:
786
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
787
      raise errors.ProgrammerError(msg)
788

    
789
    if ktype in (constants.VTYPE_STRING, constants.VTYPE_MAYBE_STRING):
790
      if target[key] is None and ktype == constants.VTYPE_MAYBE_STRING:
791
        pass
792
      elif not isinstance(target[key], basestring):
793
        if isinstance(target[key], bool) and not target[key]:
794
          target[key] = ''
795
        else:
796
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
797
          raise errors.TypeEnforcementError(msg)
798
    elif ktype == constants.VTYPE_BOOL:
799
      if isinstance(target[key], basestring) and target[key]:
800
        if target[key].lower() == constants.VALUE_FALSE:
801
          target[key] = False
802
        elif target[key].lower() == constants.VALUE_TRUE:
803
          target[key] = True
804
        else:
805
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
806
          raise errors.TypeEnforcementError(msg)
807
      elif target[key]:
808
        target[key] = True
809
      else:
810
        target[key] = False
811
    elif ktype == constants.VTYPE_SIZE:
812
      try:
813
        target[key] = ParseUnit(target[key])
814
      except errors.UnitParseError, err:
815
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
816
              (key, target[key], err)
817
        raise errors.TypeEnforcementError(msg)
818
    elif ktype == constants.VTYPE_INT:
819
      try:
820
        target[key] = int(target[key])
821
      except (ValueError, TypeError):
822
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
823
        raise errors.TypeEnforcementError(msg)
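
# Illustrative sketch of ForceDictType coercing string values in place; the
# keys are made up, and the boolean conversion assumes constants.VALUE_TRUE
# is the string "true".
def _ExampleForceDictTypeUsage():
  data = {"memory": "128", "auto_balance": "true"}
  ForceDictType(data, {"memory": constants.VTYPE_INT,
                       "auto_balance": constants.VTYPE_BOOL})
  # data is now {"memory": 128, "auto_balance": True}
  return data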
824

    
825

    
826
def _GetProcStatusPath(pid):
827
  """Returns the path for a PID's proc status file.
828

829
  @type pid: int
830
  @param pid: Process ID
831
  @rtype: string
832

833
  """
834
  return "/proc/%d/status" % pid
835

    
836

    
837
def IsProcessAlive(pid):
838
  """Check if a given pid exists on the system.
839

840
  @note: zombie status is not handled, so zombie processes
841
      will be returned as alive
842
  @type pid: int
843
  @param pid: the process ID to check
844
  @rtype: boolean
845
  @return: True if the process exists
846

847
  """
848
  def _TryStat(name):
849
    try:
850
      os.stat(name)
851
      return True
852
    except EnvironmentError, err:
853
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
854
        return False
855
      elif err.errno == errno.EINVAL:
856
        raise RetryAgain(err)
857
      raise
858

    
859
  assert isinstance(pid, int), "pid must be an integer"
860
  if pid <= 0:
861
    return False
862

    
863
  # /proc in a multiprocessor environment can have strange behaviors.
864
  # Retry the os.stat a few times until we get a good result.
865
  try:
866
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
867
                 args=[_GetProcStatusPath(pid)])
868
  except RetryTimeout, err:
869
    err.RaiseInner()
870

    
871

    
872
def _ParseSigsetT(sigset):
873
  """Parse a rendered sigset_t value.
874

875
  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
876
  function.
877

878
  @type sigset: string
879
  @param sigset: Rendered signal set from /proc/$pid/status
880
  @rtype: set
881
  @return: Set of all enabled signal numbers
882

883
  """
884
  result = set()
885

    
886
  signum = 0
887
  for ch in reversed(sigset):
888
    chv = int(ch, 16)
889

    
890
    # The following could be done in a loop, but it's easier to read and
891
    # understand in the unrolled form
892
    if chv & 1:
893
      result.add(signum + 1)
894
    if chv & 2:
895
      result.add(signum + 2)
896
    if chv & 4:
897
      result.add(signum + 3)
898
    if chv & 8:
899
      result.add(signum + 4)
900

    
901
    signum += 4
902

    
903
  return result
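
# Illustrative sketch of _ParseSigsetT on SigCgt-style hex masks as found in
# /proc/$pid/status; bit N of the rendered set corresponds to signal N.
def _ExampleParseSigsetT():
  # Only bit 2 is set, i.e. signal 2 (SIGINT) is caught
  assert _ParseSigsetT("0000000000000002") == set([2])
  # Bits 1 and 15 set: signals 1 (SIGHUP) and 15 (SIGTERM)
  assert _ParseSigsetT("0000000000004001") == set([1, 15])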
904

    
905

    
906
def _GetProcStatusField(pstatus, field):
907
  """Retrieves a field from the contents of a proc status file.
908

909
  @type pstatus: string
910
  @param pstatus: Contents of /proc/$pid/status
911
  @type field: string
912
  @param field: Name of field whose value should be returned
913
  @rtype: string
914

915
  """
916
  for line in pstatus.splitlines():
917
    parts = line.split(":", 1)
918

    
919
    if len(parts) < 2 or parts[0] != field:
920
      continue
921

    
922
    return parts[1].strip()
923

    
924
  return None
925

    
926

    
927
def IsProcessHandlingSignal(pid, signum, status_path=None):
928
  """Checks whether a process is handling a signal.
929

930
  @type pid: int
931
  @param pid: Process ID
932
  @type signum: int
933
  @param signum: Signal number
934
  @rtype: bool
935

936
  """
937
  if status_path is None:
938
    status_path = _GetProcStatusPath(pid)
939

    
940
  try:
941
    proc_status = ReadFile(status_path)
942
  except EnvironmentError, err:
943
    # In at least one case, reading /proc/$pid/status failed with ESRCH.
944
    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
945
      return False
946
    raise
947

    
948
  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
949
  if sigcgt is None:
950
    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)
951

    
952
  # Now check whether signal is handled
953
  return signum in _ParseSigsetT(sigcgt)
954

    
955

    
956
def ReadPidFile(pidfile):
957
  """Read a pid from a file.
958

959
  @type  pidfile: string
960
  @param pidfile: path to the file containing the pid
961
  @rtype: int
962
  @return: The process id, if the file exists and contains a valid PID,
963
           otherwise 0
964

965
  """
966
  try:
967
    raw_data = ReadOneLineFile(pidfile)
968
  except EnvironmentError, err:
969
    if err.errno != errno.ENOENT:
970
      logging.exception("Can't read pid file")
971
    return 0
972

    
973
  try:
974
    pid = int(raw_data)
975
  except (TypeError, ValueError), err:
976
    logging.info("Can't parse pid file contents", exc_info=True)
977
    return 0
978

    
979
  return pid
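
# Illustrative sketch combining ReadPidFile and IsProcessAlive to check
# whether a daemon is still running; the PID file path is a made-up example.
def _ExampleCheckDaemonAlive():
  pid = ReadPidFile("/var/run/example-daemon.pid")
  return pid > 0 and IsProcessAlive(pid)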
980

    
981

    
982
def ReadLockedPidFile(path):
983
  """Reads a locked PID file.
984

985
  This can be used together with L{StartDaemon}.
986

987
  @type path: string
988
  @param path: Path to PID file
989
  @return: PID as integer or, if file was unlocked or couldn't be opened, None
990

991
  """
992
  try:
993
    fd = os.open(path, os.O_RDONLY)
994
  except EnvironmentError, err:
995
    if err.errno == errno.ENOENT:
996
      # PID file doesn't exist
997
      return None
998
    raise
999

    
1000
  try:
1001
    try:
1002
      # Try to acquire lock
1003
      LockFile(fd)
1004
    except errors.LockError:
1005
      # Couldn't lock, daemon is running
1006
      return int(os.read(fd, 100))
1007
  finally:
1008
    os.close(fd)
1009

    
1010
  return None
1011

    
1012

    
1013
def MatchNameComponent(key, name_list, case_sensitive=True):
1014
  """Try to match a name against a list.
1015

1016
  This function will try to match a name like test1 against a list
1017
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
1018
  this list, I{'test1'} as well as I{'test1.example'} will match, but
1019
  not I{'test1.ex'}. A multiple match will be considered as no match
1020
  at all (e.g. I{'test1'} against C{['test1.example.com',
1021
  'test1.example.org']}), except when the key fully matches an entry
1022
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
1023

1024
  @type key: str
1025
  @param key: the name to be searched
1026
  @type name_list: list
1027
  @param name_list: the list of strings against which to search the key
1028
  @type case_sensitive: boolean
1029
  @param case_sensitive: whether to provide a case-sensitive match
1030

1031
  @rtype: None or str
1032
  @return: None if there is no match I{or} if there are multiple matches,
1033
      otherwise the element from the list which matches
1034

1035
  """
1036
  if key in name_list:
1037
    return key
1038

    
1039
  re_flags = 0
1040
  if not case_sensitive:
1041
    re_flags |= re.IGNORECASE
1042
    key = key.upper()
1043
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
1044
  names_filtered = []
1045
  string_matches = []
1046
  for name in name_list:
1047
    if mo.match(name) is not None:
1048
      names_filtered.append(name)
1049
      if not case_sensitive and key == name.upper():
1050
        string_matches.append(name)
1051

    
1052
  if len(string_matches) == 1:
1053
    return string_matches[0]
1054
  if len(names_filtered) == 1:
1055
    return names_filtered[0]
1056
  return None
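
# Illustrative sketch of expanding a shortened name with MatchNameComponent;
# the host names are arbitrary examples.
def _ExampleMatchNameComponent():
  names = ["node1.example.com", "node2.example.com"]
  # Returns "node1.example.com", as exactly one entry matches the prefix
  return MatchNameComponent("node1", names)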
1057

    
1058

    
1059
def ValidateServiceName(name):
1060
  """Validate the given service name.
1061

1062
  @type name: number or string
1063
  @param name: Service name or port specification
1064

1065
  """
1066
  try:
1067
    numport = int(name)
1068
  except (ValueError, TypeError):
1069
    # Non-numeric service name
1070
    valid = _VALID_SERVICE_NAME_RE.match(name)
1071
  else:
1072
    # Numeric port (protocols other than TCP or UDP might need adjustments
1073
    # here)
1074
    valid = (numport >= 0 and numport < (1 << 16))
1075

    
1076
  if not valid:
1077
    raise errors.OpPrereqError("Invalid service name '%s'" % name,
1078
                               errors.ECODE_INVAL)
1079

    
1080
  return name
1081

    
1082

    
1083
def ListVolumeGroups():
1084
  """List volume groups and their size
1085

1086
  @rtype: dict
1087
  @return:
1088
       Dictionary with keys volume name and values
1089
       the size of the volume
1090

1091
  """
1092
  command = "vgs --noheadings --units m --nosuffix -o name,size"
1093
  result = RunCmd(command)
1094
  retval = {}
1095
  if result.failed:
1096
    return retval
1097

    
1098
  for line in result.stdout.splitlines():
1099
    try:
1100
      name, size = line.split()
1101
      size = int(float(size))
1102
    except (IndexError, ValueError), err:
1103
      logging.error("Invalid output from vgs (%s): %s", err, line)
1104
      continue
1105

    
1106
    retval[name] = size
1107

    
1108
  return retval
1109

    
1110

    
1111
def BridgeExists(bridge):
1112
  """Check whether the given bridge exists in the system
1113

1114
  @type bridge: str
1115
  @param bridge: the bridge name to check
1116
  @rtype: boolean
1117
  @return: True if it does
1118

1119
  """
1120
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
1121

    
1122

    
1123
def NiceSort(name_list):
1124
  """Sort a list of strings based on digit and non-digit groupings.
1125

1126
  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
1127
  will sort the list in the logical order C{['a1', 'a2', 'a10',
1128
  'a11']}.
1129

1130
  The sort algorithm breaks each name in groups of either only-digits
1131
  or no-digits. Only the first eight such groups are considered, and
1132
  after that we just use what's left of the string.
1133

1134
  @type name_list: list
1135
  @param name_list: the names to be sorted
1136
  @rtype: list
1137
  @return: a copy of the name list sorted with our algorithm
1138

1139
  """
1140
  _SORTER_BASE = "(\D+|\d+)"
1141
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
1142
                                                  _SORTER_BASE, _SORTER_BASE,
1143
                                                  _SORTER_BASE, _SORTER_BASE,
1144
                                                  _SORTER_BASE, _SORTER_BASE)
1145
  _SORTER_RE = re.compile(_SORTER_FULL)
1146
  _SORTER_NODIGIT = re.compile("^\D*$")
1147
  def _TryInt(val):
1148
    """Attempts to convert a variable to integer."""
1149
    if val is None or _SORTER_NODIGIT.match(val):
1150
      return val
1151
    rval = int(val)
1152
    return rval
1153

    
1154
  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
1155
             for name in name_list]
1156
  to_sort.sort()
1157
  return [tup[1] for tup in to_sort]
1158

    
1159

    
1160
def TryConvert(fn, val):
1161
  """Try to convert a value ignoring errors.
1162

1163
  This function tries to apply function I{fn} to I{val}. If no
1164
  C{ValueError} or C{TypeError} exceptions are raised, it will return
1165
  the result, else it will return the original value. Any other
1166
  exceptions are propagated to the caller.
1167

1168
  @type fn: callable
1169
  @param fn: function to apply to the value
1170
  @param val: the value to be converted
1171
  @return: The converted value if the conversion was successful,
1172
      otherwise the original value.
1173

1174
  """
1175
  try:
1176
    nv = fn(val)
1177
  except (ValueError, TypeError):
1178
    nv = val
1179
  return nv
1180

    
1181

    
1182
def IsValidShellParam(word):
1183
  """Verifies is the given word is safe from the shell's p.o.v.
1184

1185
  This means that we can pass this to a command via the shell and be
1186
  sure that it doesn't alter the command line and is passed as such to
1187
  the actual command.
1188

1189
  Note that we are overly restrictive here, in order to be on the safe
1190
  side.
1191

1192
  @type word: str
1193
  @param word: the word to check
1194
  @rtype: boolean
1195
  @return: True if the word is 'safe'
1196

1197
  """
1198
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))
1199

    
1200

    
1201
def BuildShellCmd(template, *args):
1202
  """Build a safe shell command line from the given arguments.
1203

1204
  This function will check all arguments in the args list so that they
1205
  are valid shell parameters (i.e. they don't contain shell
1206
  metacharacters). If everything is ok, it will return the result of
1207
  template % args.
1208

1209
  @type template: str
1210
  @param template: the string holding the template for the
1211
      string formatting
1212
  @rtype: str
1213
  @return: the expanded command line
1214

1215
  """
1216
  for word in args:
1217
    if not IsValidShellParam(word):
1218
      raise errors.ProgrammerError("Shell argument '%s' contains"
1219
                                   " invalid characters" % word)
1220
  return template % args
1221

    
1222

    
1223
def FormatUnit(value, units):
1224
  """Formats an incoming number of MiB with the appropriate unit.
1225

1226
  @type value: int
1227
  @param value: integer representing the value in MiB (1048576)
1228
  @type units: char
1229
  @param units: the type of formatting we should do:
1230
      - 'h' for automatic scaling
1231
      - 'm' for MiBs
1232
      - 'g' for GiBs
1233
      - 't' for TiBs
1234
  @rtype: str
1235
  @return: the formatted value (with suffix)
1236

1237
  """
1238
  if units not in ('m', 'g', 't', 'h'):
1239
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))
1240

    
1241
  suffix = ''
1242

    
1243
  if units == 'm' or (units == 'h' and value < 1024):
1244
    if units == 'h':
1245
      suffix = 'M'
1246
    return "%d%s" % (round(value, 0), suffix)
1247

    
1248
  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
1249
    if units == 'h':
1250
      suffix = 'G'
1251
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)
1252

    
1253
  else:
1254
    if units == 'h':
1255
      suffix = 'T'
1256
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)
1257

    
1258

    
1259
def ParseUnit(input_string):
1260
  """Tries to extract number and scale from the given string.
1261

1262
  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
1263
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
1264
  is always an int in MiB.
1265

1266
  """
1267
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
1268
  if not m:
1269
    raise errors.UnitParseError("Invalid format")
1270

    
1271
  value = float(m.groups()[0])
1272

    
1273
  unit = m.groups()[1]
1274
  if unit:
1275
    lcunit = unit.lower()
1276
  else:
1277
    lcunit = 'm'
1278

    
1279
  if lcunit in ('m', 'mb', 'mib'):
1280
    # Value already in MiB
1281
    pass
1282

    
1283
  elif lcunit in ('g', 'gb', 'gib'):
1284
    value *= 1024
1285

    
1286
  elif lcunit in ('t', 'tb', 'tib'):
1287
    value *= 1024 * 1024
1288

    
1289
  else:
1290
    raise errors.UnitParseError("Unknown unit: %s" % unit)
1291

    
1292
  # Make sure we round up
1293
  if int(value) < value:
1294
    value += 1
1295

    
1296
  # Round up to the next multiple of 4
1297
  value = int(value)
1298
  if value % 4:
1299
    value += 4 - value % 4
1300

    
1301
  return value
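
# Illustrative sketch of ParseUnit; results are in MiB and rounded up to a
# multiple of four, and the input strings are arbitrary examples.
def _ExampleParseUnit():
  assert ParseUnit("1G") == 1024
  assert ParseUnit("513") == 516  # no unit defaults to MiB, rounded up to 4
  assert ParseUnit("0.5g") == 512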
1302

    
1303

    
1304
def ParseCpuMask(cpu_mask):
1305
  """Parse a CPU mask definition and return the list of CPU IDs.
1306

1307
  CPU mask format: comma-separated list of CPU IDs
1308
  or dash-separated ID ranges
1309
  Example: "0-2,5" -> "0,1,2,5"
1310

1311
  @type cpu_mask: str
1312
  @param cpu_mask: CPU mask definition
1313
  @rtype: list of int
1314
  @return: list of CPU IDs
1315

1316
  """
1317
  if not cpu_mask:
1318
    return []
1319
  cpu_list = []
1320
  for range_def in cpu_mask.split(","):
1321
    boundaries = range_def.split("-")
1322
    n_elements = len(boundaries)
1323
    if n_elements > 2:
1324
      raise errors.ParseError("Invalid CPU ID range definition"
1325
                              " (only one hyphen allowed): %s" % range_def)
1326
    try:
1327
      lower = int(boundaries[0])
1328
    except (ValueError, TypeError), err:
1329
      raise errors.ParseError("Invalid CPU ID value for lower boundary of"
1330
                              " CPU ID range: %s" % str(err))
1331
    try:
1332
      higher = int(boundaries[-1])
1333
    except (ValueError, TypeError), err:
1334
      raise errors.ParseError("Invalid CPU ID value for higher boundary of"
1335
                              " CPU ID range: %s" % str(err))
1336
    if lower > higher:
1337
      raise errors.ParseError("Invalid CPU ID range definition"
1338
                              " (%d > %d): %s" % (lower, higher, range_def))
1339
    cpu_list.extend(range(lower, higher + 1))
1340
  return cpu_list
1341

    
1342

    
1343
def AddAuthorizedKey(file_obj, key):
1344
  """Adds an SSH public key to an authorized_keys file.
1345

1346
  @type file_obj: str or file handle
1347
  @param file_obj: path to authorized_keys file
1348
  @type key: str
1349
  @param key: string containing key
1350

1351
  """
1352
  key_fields = key.split()
1353

    
1354
  if isinstance(file_obj, basestring):
1355
    f = open(file_obj, 'a+')
1356
  else:
1357
    f = file_obj
1358

    
1359
  try:
1360
    nl = True
1361
    for line in f:
1362
      # Ignore whitespace changes
1363
      if line.split() == key_fields:
1364
        break
1365
      nl = line.endswith('\n')
1366
    else:
1367
      if not nl:
1368
        f.write("\n")
1369
      f.write(key.rstrip('\r\n'))
1370
      f.write("\n")
1371
      f.flush()
1372
  finally:
1373
    f.close()
1374

    
1375

    
1376
def RemoveAuthorizedKey(file_name, key):
1377
  """Removes an SSH public key from an authorized_keys file.
1378

1379
  @type file_name: str
1380
  @param file_name: path to authorized_keys file
1381
  @type key: str
1382
  @param key: string containing key
1383

1384
  """
1385
  key_fields = key.split()
1386

    
1387
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1388
  try:
1389
    out = os.fdopen(fd, 'w')
1390
    try:
1391
      f = open(file_name, 'r')
1392
      try:
1393
        for line in f:
1394
          # Ignore whitespace changes while comparing lines
1395
          if line.split() != key_fields:
1396
            out.write(line)
1397

    
1398
        out.flush()
1399
        os.rename(tmpname, file_name)
1400
      finally:
1401
        f.close()
1402
    finally:
1403
      out.close()
1404
  except:
1405
    RemoveFile(tmpname)
1406
    raise
1407

    
1408

    
1409
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
1410
  """Sets the name of an IP address and hostname in /etc/hosts.
1411

1412
  @type file_name: str
1413
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1414
  @type ip: str
1415
  @param ip: the IP address
1416
  @type hostname: str
1417
  @param hostname: the hostname to be added
1418
  @type aliases: list
1419
  @param aliases: the list of aliases to add for the hostname
1420

1421
  """
1422
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1423
  # Ensure aliases are unique
1424
  aliases = UniqueSequence([hostname] + aliases)[1:]
1425

    
1426
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1427
  try:
1428
    out = os.fdopen(fd, 'w')
1429
    try:
1430
      f = open(file_name, 'r')
1431
      try:
1432
        for line in f:
1433
          fields = line.split()
1434
          if fields and not fields[0].startswith('#') and ip == fields[0]:
1435
            continue
1436
          out.write(line)
1437

    
1438
        out.write("%s\t%s" % (ip, hostname))
1439
        if aliases:
1440
          out.write(" %s" % ' '.join(aliases))
1441
        out.write('\n')
1442

    
1443
        out.flush()
1444
        os.fsync(out)
1445
        os.chmod(tmpname, 0644)
1446
        os.rename(tmpname, file_name)
1447
      finally:
1448
        f.close()
1449
    finally:
1450
      out.close()
1451
  except:
1452
    RemoveFile(tmpname)
1453
    raise
1454

    
1455

    
1456
def AddHostToEtcHosts(hostname):
1457
  """Wrapper around SetEtcHostsEntry.
1458

1459
  @type hostname: str
1460
  @param hostname: a hostname that will be resolved and added to
1461
      L{constants.ETC_HOSTS}
1462

1463
  """
1464
  hi = netutils.HostInfo(name=hostname)
1465
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])
1466

    
1467

    
1468
def RemoveEtcHostsEntry(file_name, hostname):
1469
  """Removes a hostname from /etc/hosts.
1470

1471
  IP addresses without names are removed from the file.
1472

1473
  @type file_name: str
1474
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1475
  @type hostname: str
1476
  @param hostname: the hostname to be removed
1477

1478
  """
1479
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1480
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1481
  try:
1482
    out = os.fdopen(fd, 'w')
1483
    try:
1484
      f = open(file_name, 'r')
1485
      try:
1486
        for line in f:
1487
          fields = line.split()
1488
          if len(fields) > 1 and not fields[0].startswith('#'):
1489
            names = fields[1:]
1490
            if hostname in names:
1491
              while hostname in names:
1492
                names.remove(hostname)
1493
              if names:
1494
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
1495
              continue
1496

    
1497
          out.write(line)
1498

    
1499
        out.flush()
1500
        os.fsync(out)
1501
        os.chmod(tmpname, 0644)
1502
        os.rename(tmpname, file_name)
1503
      finally:
1504
        f.close()
1505
    finally:
1506
      out.close()
1507
  except:
1508
    RemoveFile(tmpname)
1509
    raise
1510

    
1511

    
1512
def RemoveHostFromEtcHosts(hostname):
1513
  """Wrapper around RemoveEtcHostsEntry.
1514

1515
  @type hostname: str
1516
  @param hostname: hostname that will be resolved and its
1517
      full and short name will be removed from
1518
      L{constants.ETC_HOSTS}
1519

1520
  """
1521
  hi = netutils.HostInfo(name=hostname)
1522
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
1523
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())
1524

    
1525

    
1526
def TimestampForFilename():
1527
  """Returns the current time formatted for filenames.
1528

1529
  The format doesn't contain colons as some shells and applications use them as
1530
  separators.
1531

1532
  """
1533
  return time.strftime("%Y-%m-%d_%H_%M_%S")
1534

    
1535

    
1536
def CreateBackup(file_name):
1537
  """Creates a backup of a file.
1538

1539
  @type file_name: str
1540
  @param file_name: file to be backed up
1541
  @rtype: str
1542
  @return: the path to the newly created backup
1543
  @raise errors.ProgrammerError: for invalid file names
1544

1545
  """
1546
  if not os.path.isfile(file_name):
1547
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
1548
                                file_name)
1549

    
1550
  prefix = ("%s.backup-%s." %
1551
            (os.path.basename(file_name), TimestampForFilename()))
1552
  dir_name = os.path.dirname(file_name)
1553

    
1554
  fsrc = open(file_name, 'rb')
1555
  try:
1556
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
1557
    fdst = os.fdopen(fd, 'wb')
1558
    try:
1559
      logging.debug("Backing up %s at %s", file_name, backup_name)
1560
      shutil.copyfileobj(fsrc, fdst)
1561
    finally:
1562
      fdst.close()
1563
  finally:
1564
    fsrc.close()
1565

    
1566
  return backup_name
1567

    
1568

    
1569
def ShellQuote(value):
1570
  """Quotes shell argument according to POSIX.
1571

1572
  @type value: str
1573
  @param value: the argument to be quoted
1574
  @rtype: str
1575
  @return: the quoted value
1576

1577
  """
1578
  if _re_shell_unquoted.match(value):
1579
    return value
1580
  else:
1581
    return "'%s'" % value.replace("'", "'\\''")
1582

    
1583

    
1584
def ShellQuoteArgs(args):
1585
  """Quotes a list of shell arguments.
1586

1587
  @type args: list
1588
  @param args: list of arguments to be quoted
1589
  @rtype: str
1590
  @return: the quoted arguments concatenated with spaces
1591

1592
  """
1593
  return ' '.join([ShellQuote(i) for i in args])
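
# Illustrative sketch of ShellQuote and ShellQuoteArgs; the arguments are
# arbitrary examples.
def _ExampleShellQuoteArgs():
  assert ShellQuote("safe-name") == "safe-name"
  assert ShellQuote("with space") == "'with space'"
  assert ShellQuoteArgs(["echo", "it's"]) == "echo 'it'\\''s'"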
1594

    
1595

    
1596
class ShellWriter:
1597
  """Helper class to write scripts with indentation.
1598

1599
  """
1600
  INDENT_STR = "  "
1601

    
1602
  def __init__(self, fh):
1603
    """Initializes this class.
1604

1605
    """
1606
    self._fh = fh
1607
    self._indent = 0
1608

    
1609
  def IncIndent(self):
1610
    """Increase indentation level by 1.
1611

1612
    """
1613
    self._indent += 1
1614

    
1615
  def DecIndent(self):
1616
    """Decrease indentation level by 1.
1617

1618
    """
1619
    assert self._indent > 0
1620
    self._indent -= 1
1621

    
1622
  def Write(self, txt, *args):
1623
    """Write line to output file.
1624

1625
    """
1626
    assert self._indent >= 0
1627

    
1628
    self._fh.write(self._indent * self.INDENT_STR)
1629

    
1630
    if args:
1631
      self._fh.write(txt % args)
1632
    else:
1633
      self._fh.write(txt)
1634

    
1635
    self._fh.write("\n")
1636

    
1637

    
1638
def ListVisibleFiles(path):
1639
  """Returns a list of visible files in a directory.
1640

1641
  @type path: str
1642
  @param path: the directory to enumerate
1643
  @rtype: list
1644
  @return: the list of all files not starting with a dot
1645
  @raise ProgrammerError: if L{path} is not an absolute and normalized path
1646

1647
  """
1648
  if not IsNormAbsPath(path):
1649
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
1650
                                 " absolute/normalized: '%s'" % path)
1651
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1652
  return files
1653

    
1654

    
1655
def GetHomeDir(user, default=None):
1656
  """Try to get the homedir of the given user.
1657

1658
  The user can be passed either as a string (denoting the name) or as
1659
  an integer (denoting the user id). If the user is not found, the
1660
  'default' argument is returned, which defaults to None.
1661

1662
  """
1663
  try:
1664
    if isinstance(user, basestring):
1665
      result = pwd.getpwnam(user)
1666
    elif isinstance(user, (int, long)):
1667
      result = pwd.getpwuid(user)
1668
    else:
1669
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1670
                                   type(user))
1671
  except KeyError:
1672
    return default
1673
  return result.pw_dir
1674

    
1675

    
1676
def NewUUID():
1677
  """Returns a random UUID.
1678

1679
  @note: This is a Linux-specific method as it uses the /proc
1680
      filesystem.
1681
  @rtype: str
1682

1683
  """
1684
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1685

    
1686

    
1687
def GenerateSecret(numbytes=20):
1688
  """Generates a random secret.
1689

1690
  This will generate a pseudo-random secret returning a hex string
1691
  (so that it can be used where an ASCII string is needed).
1692

1693
  @param numbytes: the number of bytes which will be represented by the returned
1694
      string (defaulting to 20, the length of a SHA1 hash)
1695
  @rtype: str
1696
  @return: a hex representation of the pseudo-random sequence
1697

1698
  """
1699
  return os.urandom(numbytes).encode('hex')
1700

    
1701

    
1702
def EnsureDirs(dirs):
1703
  """Make required directories, if they don't exist.
1704

1705
  @param dirs: list of tuples (dir_name, dir_mode)
1706
  @type dirs: list of (string, integer)
1707

1708
  """
1709
  for dir_name, dir_mode in dirs:
1710
    try:
1711
      os.mkdir(dir_name, dir_mode)
1712
    except EnvironmentError, err:
1713
      if err.errno != errno.EEXIST:
1714
        raise errors.GenericError("Cannot create needed directory"
1715
                                  " '%s': %s" % (dir_name, err))
1716
    try:
1717
      os.chmod(dir_name, dir_mode)
1718
    except EnvironmentError, err:
1719
      raise errors.GenericError("Cannot change directory permissions on"
1720
                                " '%s': %s" % (dir_name, err))
1721
    if not os.path.isdir(dir_name):
1722
      raise errors.GenericError("%s is not a directory" % dir_name)
1723

    
1724

    
1725
def ReadFile(file_name, size=-1):
1726
  """Reads a file.
1727

1728
  @type size: int
1729
  @param size: Read at most size bytes (if negative, entire file)
1730
  @rtype: str
1731
  @return: the (possibly partial) content of the file
1732

1733
  """
1734
  f = open(file_name, "r")
1735
  try:
1736
    return f.read(size)
1737
  finally:
1738
    f.close()
1739

    
1740

    
1741
def WriteFile(file_name, fn=None, data=None,
1742
              mode=None, uid=-1, gid=-1,
1743
              atime=None, mtime=None, close=True,
1744
              dry_run=False, backup=False,
1745
              prewrite=None, postwrite=None):
1746
  """(Over)write a file atomically.
1747

1748
  The file_name and either fn (a function taking one argument, the
1749
  file descriptor, and which should write the data to it) or data (the
1750
  contents of the file) must be passed. The other arguments are
1751
  optional and allow setting the file mode, owner and group, and the
1752
  mtime/atime of the file.
1753

1754
  If the function doesn't raise an exception, it has succeeded and the
1755
  target file has the new contents. If the function has raised an
1756
  exception, an existing target file should be unmodified and the
1757
  temporary file should be removed.
1758

1759
  @type file_name: str
1760
  @param file_name: the target filename
1761
  @type fn: callable
1762
  @param fn: content writing function, called with
1763
      file descriptor as parameter
1764
  @type data: str
1765
  @param data: contents of the file
1766
  @type mode: int
1767
  @param mode: file mode
1768
  @type uid: int
1769
  @param uid: the owner of the file
1770
  @type gid: int
1771
  @param gid: the group of the file
1772
  @type atime: int
1773
  @param atime: a custom access time to be set on the file
1774
  @type mtime: int
1775
  @param mtime: a custom modification time to be set on the file
1776
  @type close: boolean
1777
  @param close: whether to close file after writing it
1778
  @type prewrite: callable
1779
  @param prewrite: function to be called before writing content
1780
  @type postwrite: callable
1781
  @param postwrite: function to be called after writing content
1782

1783
  @rtype: None or int
1784
  @return: None if the 'close' parameter evaluates to True,
1785
      otherwise the file descriptor
1786

1787
  @raise errors.ProgrammerError: if any of the arguments are not valid
1788

1789
  """
1790
  if not os.path.isabs(file_name):
1791
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1792
                                 " absolute: '%s'" % file_name)
1793

    
1794
  if [fn, data].count(None) != 1:
1795
    raise errors.ProgrammerError("fn or data required")
1796

    
1797
  if [atime, mtime].count(None) == 1:
1798
    raise errors.ProgrammerError("Both atime and mtime must be either"
1799
                                 " set or None")
1800

    
1801
  if backup and not dry_run and os.path.isfile(file_name):
1802
    CreateBackup(file_name)
1803

    
1804
  dir_name, base_name = os.path.split(file_name)
1805
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1806
  do_remove = True
1807
  # here we need to make sure we remove the temp file, if any error
1808
  # leaves it in place
1809
  try:
1810
    if uid != -1 or gid != -1:
1811
      os.chown(new_name, uid, gid)
1812
    if mode:
1813
      os.chmod(new_name, mode)
1814
    if callable(prewrite):
1815
      prewrite(fd)
1816
    if data is not None:
1817
      os.write(fd, data)
1818
    else:
1819
      fn(fd)
1820
    if callable(postwrite):
1821
      postwrite(fd)
1822
    os.fsync(fd)
1823
    if atime is not None and mtime is not None:
1824
      os.utime(new_name, (atime, mtime))
1825
    if not dry_run:
1826
      os.rename(new_name, file_name)
1827
      do_remove = False
1828
  finally:
1829
    if close:
1830
      os.close(fd)
1831
      result = None
1832
    else:
1833
      result = fd
1834
    if do_remove:
1835
      RemoveFile(new_name)
1836

    
1837
  return result
1838

    
1839

    
1840
def ReadOneLineFile(file_name, strict=False):
1841
  """Return the first non-empty line from a file.
1842

1843
  @type strict: boolean
1844
  @param strict: if True, abort if the file has more than one
1845
      non-empty line
1846

1847
  """
1848
  file_lines = ReadFile(file_name).splitlines()
1849
  full_lines = filter(bool, file_lines)
1850
  if not file_lines or not full_lines:
1851
    raise errors.GenericError("No data in one-liner file %s" % file_name)
1852
  elif strict and len(full_lines) > 1:
1853
    raise errors.GenericError("Too many lines in one-liner file %s" %
1854
                              file_name)
1855
  return full_lines[0]
1856

    
1857

    
1858
def FirstFree(seq, base=0):
  """Returns the first non-existing integer from seq.

  The seq argument should be a sorted list of positive integers. The
  first time the index of an element is smaller than the element
  value, the index will be returned.

  The base argument is used to start at a different offset,
  i.e. C{[3, 4, 6]} with I{base=3} will return 5.

  Example: C{[0, 1, 3]} will return I{2}.

  @type seq: sequence
  @param seq: the sequence to be analyzed.
  @type base: int
  @param base: use this value as the base index of the sequence
  @rtype: int or None
  @return: the first non-used index in the sequence, or None if all
      indices are used

  """
  for idx, elem in enumerate(seq):
    assert elem >= base, "Passed element is lower than base offset"
    if elem > idx + base:
      # idx is not used
      return idx + base
  return None


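# Example usage of FirstFree (illustrative, derived from the behaviour above):
#   FirstFree([0, 1, 3])          # -> 2, the first unused index
#   FirstFree([3, 4, 6], base=3)  # -> 5, counting starts at the base offset
#   FirstFree([0, 1, 2])          # -> None, the sequence has no gap
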
def SingleWaitForFdCondition(fdobj, event, timeout):
1887
  """Waits for a condition to occur on the socket.
1888

1889
  Immediately returns at the first interruption.
1890

1891
  @type fdobj: integer or object supporting a fileno() method
1892
  @param fdobj: entity to wait for events on
1893
  @type event: integer
1894
  @param event: ORed condition (see select module)
1895
  @type timeout: float or None
1896
  @param timeout: Timeout in seconds
1897
  @rtype: int or None
1898
  @return: None for timeout, otherwise the occurred conditions
1899

1900
  """
1901
  check = (event | select.POLLPRI |
1902
           select.POLLNVAL | select.POLLHUP | select.POLLERR)
1903

    
1904
  if timeout is not None:
1905
    # Poller object expects milliseconds
1906
    timeout *= 1000
1907

    
1908
  poller = select.poll()
1909
  poller.register(fdobj, event)
1910
  try:
1911
    # TODO: If the main thread receives a signal and we have no timeout, we
1912
    # could wait forever. This should check a global "quit" flag or something
1913
    # every so often.
1914
    io_events = poller.poll(timeout)
1915
  except select.error, err:
1916
    if err[0] != errno.EINTR:
1917
      raise
1918
    io_events = []
1919
  if io_events and io_events[0][1] & check:
1920
    return io_events[0][1]
1921
  else:
1922
    return None
1923

    
1924

    
1925
class FdConditionWaiterHelper(object):
1926
  """Retry helper for WaitForFdCondition.
1927

1928
  This class contains the retried and wait functions that make sure
1929
  WaitForFdCondition can continue waiting until the timeout is actually
1930
  expired.
1931

1932
  """
1933

    
1934
  def __init__(self, timeout):
1935
    self.timeout = timeout
1936

    
1937
  def Poll(self, fdobj, event):
1938
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
1939
    if result is None:
1940
      raise RetryAgain()
1941
    else:
1942
      return result
1943

    
1944
  def UpdateTimeout(self, timeout):
1945
    self.timeout = timeout
1946

    
1947

    
1948
def WaitForFdCondition(fdobj, event, timeout):
1949
  """Waits for a condition to occur on the socket.
1950

1951
  Retries until the timeout is expired, even if interrupted.
1952

1953
  @type fdobj: integer or object supporting a fileno() method
1954
  @param fdobj: entity to wait for events on
1955
  @type event: integer
1956
  @param event: ORed condition (see select module)
1957
  @type timeout: float or None
1958
  @param timeout: Timeout in seconds
1959
  @rtype: int or None
1960
  @return: None for timeout, otherwise the occurred conditions
1961

1962
  """
1963
  if timeout is not None:
1964
    retrywaiter = FdConditionWaiterHelper(timeout)
1965
    try:
1966
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
1967
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
1968
    except RetryTimeout:
1969
      result = None
1970
  else:
1971
    result = None
1972
    while result is None:
1973
      result = SingleWaitForFdCondition(fdobj, event, timeout)
1974
  return result
1975

    
1976

    
1977
def UniqueSequence(seq):
  """Returns a list with unique elements.

  Element order is preserved.

  @type seq: sequence
  @param seq: the sequence with the source elements
  @rtype: list
  @return: list of unique elements from seq

  """
  seen = set()
  return [i for i in seq if i not in seen and not seen.add(i)]


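# Example usage of UniqueSequence (illustrative): duplicates are dropped
# while the original ordering is kept:
#   UniqueSequence([1, 2, 1, 3, 2])  # -> [1, 2, 3]
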
def NormalizeAndValidateMac(mac):
  """Normalizes a MAC address and checks that it is valid.

  Checks whether the supplied MAC address is formally correct; only
  the colon-separated format is accepted. The address is normalized
  to all-lowercase.

  @type mac: str
  @param mac: the MAC to be validated
  @rtype: str
  @return: the normalized and validated MAC

  @raise errors.OpPrereqError: If the MAC isn't valid

  """
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
  if not mac_check.match(mac):
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
                               mac, errors.ECODE_INVAL)

  return mac.lower()


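# Example usage of NormalizeAndValidateMac (illustrative):
#   NormalizeAndValidateMac("AA:BB:CC:00:11:22")  # -> "aa:bb:cc:00:11:22"
#   NormalizeAndValidateMac("aabbcc001122")       # raises errors.OpPrereqError
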
def TestDelay(duration):
2015
  """Sleep for a fixed amount of time.
2016

2017
  @type duration: float
2018
  @param duration: the sleep duration
2019
  @rtype: tuple
  @return: (False, error message) for a negative duration,
      (True, None) otherwise
2021

2022
  """
2023
  if duration < 0:
2024
    return False, "Invalid sleep duration"
2025
  time.sleep(duration)
2026
  return True, None
2027

    
2028

    
2029
def _CloseFDNoErr(fd, retries=5):
2030
  """Close a file descriptor ignoring errors.
2031

2032
  @type fd: int
2033
  @param fd: the file descriptor
2034
  @type retries: int
2035
  @param retries: how many retries to make, in case we get any
2036
      other error than EBADF
2037

2038
  """
2039
  try:
2040
    os.close(fd)
2041
  except OSError, err:
2042
    if err.errno != errno.EBADF:
2043
      if retries > 0:
2044
        _CloseFDNoErr(fd, retries - 1)
2045
    # else either it's closed already or we're out of retries, so we
2046
    # ignore this and go on
2047

    
2048

    
2049
def CloseFDs(noclose_fds=None):
2050
  """Close file descriptors.
2051

2052
  This closes all file descriptors above 2 (i.e. except
2053
  stdin/out/err).
2054

2055
  @type noclose_fds: list or None
2056
  @param noclose_fds: if given, it denotes a list of file descriptor
2057
      that should not be closed
2058

2059
  """
2060
  # Default maximum for the number of available file descriptors.
2061
  if 'SC_OPEN_MAX' in os.sysconf_names:
2062
    try:
2063
      MAXFD = os.sysconf('SC_OPEN_MAX')
2064
      if MAXFD < 0:
2065
        MAXFD = 1024
2066
    except OSError:
2067
      MAXFD = 1024
2068
  else:
2069
    MAXFD = 1024
2070
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
2071
  if (maxfd == resource.RLIM_INFINITY):
2072
    maxfd = MAXFD
2073

    
2074
  # Iterate through and close all file descriptors (except the standard ones)
2075
  for fd in range(3, maxfd):
2076
    if noclose_fds and fd in noclose_fds:
2077
      continue
2078
    _CloseFDNoErr(fd)
2079

    
2080

    
2081
def Mlockall(_ctypes=ctypes):
2082
  """Lock current process' virtual address space into RAM.
2083

2084
  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
2085
  see mlock(2) for more details. This function requires ctypes module.
2086

2087
  @raises errors.NoCtypesError: if ctypes module is not found
2088

2089
  """
2090
  if _ctypes is None:
2091
    raise errors.NoCtypesError()
2092

    
2093
  libc = _ctypes.cdll.LoadLibrary("libc.so.6")
2094
  if libc is None:
2095
    logging.error("Cannot set memory lock, ctypes cannot load libc")
2096
    return
2097

    
2098
  # Some older version of the ctypes module don't have built-in functionality
2099
  # to access the errno global variable, where function error codes are stored.
2100
  # By declaring this variable as a pointer to an integer we can then access
2101
  # its value correctly, should the mlockall call fail, in order to see what
2102
  # the actual error code was.
2103
  # pylint: disable-msg=W0212
2104
  libc.__errno_location.restype = _ctypes.POINTER(_ctypes.c_int)
2105

    
2106
  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
2107
    # pylint: disable-msg=W0212
2108
    logging.error("Cannot set memory lock: %s",
2109
                  os.strerror(libc.__errno_location().contents.value))
2110
    return
2111

    
2112
  logging.debug("Memory lock set")
2113

    
2114

    
2115
def Daemonize(logfile, run_uid, run_gid):
2116
  """Daemonize the current process.
2117

2118
  This detaches the current process from the controlling terminal and
2119
  runs it in the background as a daemon.
2120

2121
  @type logfile: str
2122
  @param logfile: the logfile to which we should redirect stdout/stderr
2123
  @type run_uid: int
2124
  @param run_uid: Run the child under this uid
2125
  @type run_gid: int
2126
  @param run_gid: Run the child under this gid
2127
  @rtype: int
2128
  @return: the value zero
2129

2130
  """
2131
  # pylint: disable-msg=W0212
2132
  # yes, we really want os._exit
2133
  UMASK = 077
2134
  WORKDIR = "/"
2135

    
2136
  # this might fail
2137
  pid = os.fork()
2138
  if (pid == 0):  # The first child.
2139
    os.setsid()
2140
    # FIXME: When removing again and moving to start-stop-daemon privilege drop
2141
    #        make sure to check for config permission and bail out when invoked
2142
    #        with wrong user.
2143
    os.setgid(run_gid)
2144
    os.setuid(run_uid)
2145
    # this might fail
2146
    pid = os.fork() # Fork a second child.
2147
    if (pid == 0):  # The second child.
2148
      os.chdir(WORKDIR)
2149
      os.umask(UMASK)
2150
    else:
2151
      # exit() or _exit()?  See below.
2152
      os._exit(0) # Exit parent (the first child) of the second child.
2153
  else:
2154
    os._exit(0) # Exit parent of the first child.
2155

    
2156
  for fd in range(3):
2157
    _CloseFDNoErr(fd)
2158
  i = os.open("/dev/null", os.O_RDONLY) # stdin
2159
  assert i == 0, "Can't close/reopen stdin"
2160
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
2161
  assert i == 1, "Can't close/reopen stdout"
2162
  # Duplicate standard output to standard error.
2163
  os.dup2(1, 2)
2164
  return 0
2165

    
2166

    
2167
def DaemonPidFileName(name):
2168
  """Compute a ganeti pid file absolute path
2169

2170
  @type name: str
2171
  @param name: the daemon name
2172
  @rtype: str
2173
  @return: the full path to the pidfile corresponding to the given
2174
      daemon name
2175

2176
  """
2177
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
2178

    
2179

    
2180
def EnsureDaemon(name):
2181
  """Check for and start daemon if not alive.
2182

2183
  """
2184
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
2185
  if result.failed:
2186
    logging.error("Can't start daemon '%s', failure %s, output: %s",
2187
                  name, result.fail_reason, result.output)
2188
    return False
2189

    
2190
  return True
2191

    
2192

    
2193
def StopDaemon(name):
2194
  """Stop daemon
2195

2196
  """
2197
  result = RunCmd([constants.DAEMON_UTIL, "stop", name])
2198
  if result.failed:
2199
    logging.error("Can't stop daemon '%s', failure %s, output: %s",
2200
                  name, result.fail_reason, result.output)
2201
    return False
2202

    
2203
  return True
2204

    
2205

    
2206
def WritePidFile(name):
2207
  """Write the current process pidfile.
2208

2209
  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}
2210

2211
  @type name: str
2212
  @param name: the daemon name to use
2213
  @raise errors.GenericError: if the pid file already exists and
2214
      points to a live process
2215

2216
  """
2217
  pid = os.getpid()
2218
  pidfilename = DaemonPidFileName(name)
2219
  if IsProcessAlive(ReadPidFile(pidfilename)):
2220
    raise errors.GenericError("%s contains a live process" % pidfilename)
2221

    
2222
  WriteFile(pidfilename, data="%d\n" % pid)
2223

    
2224

    
2225
def RemovePidFile(name):
2226
  """Remove the current process pidfile.
2227

2228
  Any errors are ignored.
2229

2230
  @type name: str
2231
  @param name: the daemon name used to derive the pidfile name
2232

2233
  """
2234
  pidfilename = DaemonPidFileName(name)
2235
  # TODO: we could check here that the file contains our pid
2236
  try:
2237
    RemoveFile(pidfilename)
2238
  except: # pylint: disable-msg=W0702
2239
    pass
2240

    
2241

    
2242
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
2243
                waitpid=False):
2244
  """Kill a process given by its pid.
2245

2246
  @type pid: int
2247
  @param pid: The PID to terminate.
2248
  @type signal_: int
2249
  @param signal_: The signal to send, by default SIGTERM
2250
  @type timeout: int
2251
  @param timeout: The timeout after which, if the process is still alive,
2252
                  a SIGKILL will be sent. If not positive, no such checking
2253
                  will be done
2254
  @type waitpid: boolean
2255
  @param waitpid: If true, we should waitpid on this process after
2256
      sending signals, since it's our own child and otherwise it
2257
      would remain as zombie
2258

2259
  """
2260
  def _helper(pid, signal_, wait):
2261
    """Simple helper to encapsulate the kill/waitpid sequence"""
2262
    if IgnoreProcessNotFound(os.kill, pid, signal_) and wait:
2263
      try:
2264
        os.waitpid(pid, os.WNOHANG)
2265
      except OSError:
2266
        pass
2267

    
2268
  if pid <= 0:
2269
    # kill with pid=0 == suicide
2270
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
2271

    
2272
  if not IsProcessAlive(pid):
2273
    return
2274

    
2275
  _helper(pid, signal_, waitpid)
2276

    
2277
  if timeout <= 0:
2278
    return
2279

    
2280
  def _CheckProcess():
2281
    if not IsProcessAlive(pid):
2282
      return
2283

    
2284
    try:
2285
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
2286
    except OSError:
2287
      raise RetryAgain()
2288

    
2289
    if result_pid > 0:
2290
      return
2291

    
2292
    raise RetryAgain()
2293

    
2294
  try:
2295
    # Wait up to $timeout seconds
2296
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
2297
  except RetryTimeout:
2298
    pass
2299

    
2300
  if IsProcessAlive(pid):
2301
    # Kill process if it's still alive
2302
    _helper(pid, signal.SIGKILL, waitpid)
2303

    
2304

    
2305
def FindFile(name, search_path, test=os.path.exists):
2306
  """Look for a filesystem object in a given path.
2307

2308
  This is an abstract method to search for filesystem object (files,
2309
  dirs) under a given search path.
2310

2311
  @type name: str
2312
  @param name: the name to look for
2313
  @type search_path: str
2314
  @param search_path: location to start at
2315
  @type test: callable
2316
  @param test: a function taking one argument that should return True
2317
      if the a given object is valid; the default value is
2318
      os.path.exists, causing only existing files to be returned
2319
  @rtype: str or None
2320
  @return: full path to the object if found, None otherwise
2321

2322
  """
2323
  # validate the filename mask
2324
  if constants.EXT_PLUGIN_MASK.match(name) is None:
2325
    logging.critical("Invalid value passed for external script name: '%s'",
2326
                     name)
2327
    return None
2328

    
2329
  for dir_name in search_path:
2330
    # FIXME: investigate switch to PathJoin
2331
    item_name = os.path.sep.join([dir_name, name])
2332
    # check the user test and that we're indeed resolving to the given
2333
    # basename
2334
    if test(item_name) and os.path.basename(item_name) == name:
2335
      return item_name
2336
  return None
2337

    
2338

    
2339
def CheckVolumeGroupSize(vglist, vgname, minsize):
2340
  """Checks if the volume group list is valid.
2341

2342
  The function will check if a given volume group is in the list of
2343
  volume groups and has a minimum size.
2344

2345
  @type vglist: dict
2346
  @param vglist: dictionary of volume group names and their size
2347
  @type vgname: str
2348
  @param vgname: the volume group we should check
2349
  @type minsize: int
2350
  @param minsize: the minimum size we accept
2351
  @rtype: None or str
2352
  @return: None for success, otherwise the error message
2353

2354
  """
2355
  vgsize = vglist.get(vgname, None)
2356
  if vgsize is None:
2357
    return "volume group '%s' missing" % vgname
2358
  elif vgsize < minsize:
2359
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
2360
            (vgname, minsize, vgsize))
2361
  return None
2362

    
2363

    
2364
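# Example usage of CheckVolumeGroupSize (illustrative):
#   vgs = {"xenvg": 20480}
#   CheckVolumeGroupSize(vgs, "xenvg", 10240)  # -> None (check passed)
#   CheckVolumeGroupSize(vgs, "foo", 10240)    # -> "volume group 'foo' missing"
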
def SplitTime(value):
2365
  """Splits time as floating point number into a tuple.
2366

2367
  @param value: Time in seconds
2368
  @type value: int or float
2369
  @return: Tuple containing (seconds, microseconds)
2370

2371
  """
2372
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
2373

    
2374
  assert 0 <= seconds, \
2375
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2376
  assert 0 <= microseconds <= 999999, \
2377
    "Microseconds must be 0-999999, but are %s" % microseconds
2378

    
2379
  return (int(seconds), int(microseconds))
2380

    
2381

    
2382
def MergeTime(timetuple):
2383
  """Merges a tuple into time as a floating point number.
2384

2385
  @param timetuple: Time as tuple, (seconds, microseconds)
2386
  @type timetuple: tuple
2387
  @return: Time as a floating point number expressed in seconds
2388

2389
  """
2390
  (seconds, microseconds) = timetuple
2391

    
2392
  assert 0 <= seconds, \
2393
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2394
  assert 0 <= microseconds <= 999999, \
2395
    "Microseconds must be 0-999999, but are %s" % microseconds
2396

    
2397
  return float(seconds) + (float(microseconds) * 0.000001)
2398

    
2399

    
2400
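# Example usage of SplitTime/MergeTime (illustrative): the two functions are
# inverses of each other up to microsecond precision:
#   SplitTime(1234.5)          # -> (1234, 500000)
#   MergeTime((1234, 500000))  # -> 1234.5
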
class LogFileHandler(logging.FileHandler):
2401
  """Log handler that doesn't fallback to stderr.
2402

2403
  When an error occurs while writing on the logfile, logging.FileHandler tries
2404
  to log on stderr. This doesn't work in ganeti since stderr is redirected to
2405
  the logfile. This class instead reports such errors to /dev/console.
2406

2407
  """
2408
  def __init__(self, filename, mode="a", encoding=None):
2409
    """Open the specified file and use it as the stream for logging.
2410

2411
    Also open /dev/console to report errors while logging.
2412

2413
    """
2414
    logging.FileHandler.__init__(self, filename, mode, encoding)
2415
    self.console = open(constants.DEV_CONSOLE, "a")
2416

    
2417
  def handleError(self, record): # pylint: disable-msg=C0103
2418
    """Handle errors which occur during an emit() call.
2419

2420
    Try to handle errors with FileHandler method, if it fails write to
2421
    /dev/console.
2422

2423
    """
2424
    try:
2425
      logging.FileHandler.handleError(self, record)
2426
    except Exception: # pylint: disable-msg=W0703
2427
      try:
2428
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
2429
      except Exception: # pylint: disable-msg=W0703
2430
        # Log handler tried everything it could, now just give up
2431
        pass
2432

    
2433

    
2434
def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
2435
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
2436
                 console_logging=False):
2437
  """Configures the logging module.
2438

2439
  @type logfile: str
2440
  @param logfile: the filename to which we should log
2441
  @type debug: integer
2442
  @param debug: if greater than zero, enable debug messages, otherwise
2443
      only those at C{INFO} and above level
2444
  @type stderr_logging: boolean
2445
  @param stderr_logging: whether we should also log to the standard error
2446
  @type program: str
2447
  @param program: the name under which we should log messages
2448
  @type multithreaded: boolean
2449
  @param multithreaded: if True, will add the thread name to the log file
2450
  @type syslog: string
2451
  @param syslog: one of 'no', 'yes', 'only':
2452
      - if no, syslog is not used
2453
      - if yes, syslog is used (in addition to file-logging)
2454
      - if only, only syslog is used
2455
  @type console_logging: boolean
2456
  @param console_logging: if True, will use a FileHandler which falls back to
2457
      the system console if logging fails
2458
  @raise EnvironmentError: if we can't open the log file and
2459
      syslog/stderr logging is disabled
2460

2461
  """
2462
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
2463
  sft = program + "[%(process)d]:"
2464
  if multithreaded:
2465
    fmt += "/%(threadName)s"
2466
    sft += " (%(threadName)s)"
2467
  if debug:
2468
    fmt += " %(module)s:%(lineno)s"
2469
    # no debug info for syslog loggers
2470
  fmt += " %(levelname)s %(message)s"
2471
  # yes, we do want the textual level, as remote syslog will probably
2472
  # lose the error level, and it's easier to grep for it
2473
  sft += " %(levelname)s %(message)s"
2474
  formatter = logging.Formatter(fmt)
2475
  sys_fmt = logging.Formatter(sft)
2476

    
2477
  root_logger = logging.getLogger("")
2478
  root_logger.setLevel(logging.NOTSET)
2479

    
2480
  # Remove all previously setup handlers
2481
  for handler in root_logger.handlers:
2482
    handler.close()
2483
    root_logger.removeHandler(handler)
2484

    
2485
  if stderr_logging:
2486
    stderr_handler = logging.StreamHandler()
2487
    stderr_handler.setFormatter(formatter)
2488
    if debug:
2489
      stderr_handler.setLevel(logging.NOTSET)
2490
    else:
2491
      stderr_handler.setLevel(logging.CRITICAL)
2492
    root_logger.addHandler(stderr_handler)
2493

    
2494
  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
2495
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
2496
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
2497
                                                    facility)
2498
    syslog_handler.setFormatter(sys_fmt)
2499
    # Never enable debug over syslog
2500
    syslog_handler.setLevel(logging.INFO)
2501
    root_logger.addHandler(syslog_handler)
2502

    
2503
  if syslog != constants.SYSLOG_ONLY:
2504
    # this can fail, if the logging directories are not setup or we have
2505
    # a permission problem; in this case, it's best to log but ignore
2506
    # the error if stderr_logging is True, and if false we re-raise the
2507
    # exception since otherwise we could run but without any logs at all
2508
    try:
2509
      if console_logging:
2510
        logfile_handler = LogFileHandler(logfile)
2511
      else:
2512
        logfile_handler = logging.FileHandler(logfile)
2513
      logfile_handler.setFormatter(formatter)
2514
      if debug:
2515
        logfile_handler.setLevel(logging.DEBUG)
2516
      else:
2517
        logfile_handler.setLevel(logging.INFO)
2518
      root_logger.addHandler(logfile_handler)
2519
    except EnvironmentError:
2520
      if stderr_logging or syslog == constants.SYSLOG_YES:
2521
        logging.exception("Failed to enable logging to file '%s'", logfile)
2522
      else:
2523
        # we need to re-raise the exception
2524
        raise
2525

    
2526

    
2527
def IsNormAbsPath(path):
2528
  """Check whether a path is absolute and also normalized
2529

2530
  This avoids things like /dir/../../other/path to be valid.
2531

2532
  """
2533
  return os.path.normpath(path) == path and os.path.isabs(path)
2534

    
2535

    
2536
def PathJoin(*args):
2537
  """Safe-join a list of path components.
2538

2539
  Requirements:
2540
      - the first argument must be an absolute path
2541
      - no component in the path must have backtracking (e.g. /../),
2542
        since we check for normalization at the end
2543

2544
  @param args: the path components to be joined
2545
  @raise ValueError: for invalid paths
2546

2547
  """
2548
  # ensure we're having at least one path passed in
2549
  assert args
2550
  # ensure the first component is an absolute and normalized path name
2551
  root = args[0]
2552
  if not IsNormAbsPath(root):
2553
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
2554
  result = os.path.join(*args)
2555
  # ensure that the whole path is normalized
2556
  if not IsNormAbsPath(result):
2557
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
2558
  # check that we're still under the original prefix
2559
  prefix = os.path.commonprefix([root, result])
2560
  if prefix != root:
2561
    raise ValueError("Error: path joining resulted in different prefix"
2562
                     " (%s != %s)" % (prefix, root))
2563
  return result
2564

    
2565

    
2566
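# Example usage of PathJoin (illustrative):
#   PathJoin("/tmp", "subdir", "file")  # -> "/tmp/subdir/file"
#   PathJoin("/tmp", "../etc/passwd")   # raises ValueError (not normalized)
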
def TailFile(fname, lines=20):
2567
  """Return the last lines from a file.
2568

2569
  @note: this function will only read and parse the last 4KB of
2570
      the file; if the lines are very long, it could be that less
2571
      than the requested number of lines are returned
2572

2573
  @param fname: the file name
2574
  @type lines: int
2575
  @param lines: the (maximum) number of lines to return
2576

2577
  """
2578
  fd = open(fname, "r")
2579
  try:
2580
    fd.seek(0, 2)
2581
    pos = fd.tell()
2582
    pos = max(0, pos-4096)
2583
    fd.seek(pos, 0)
2584
    raw_data = fd.read()
2585
  finally:
2586
    fd.close()
2587

    
2588
  rows = raw_data.splitlines()
2589
  return rows[-lines:]
2590

    
2591

    
2592
def FormatTimestampWithTZ(secs):
2593
  """Formats a Unix timestamp with the local timezone.
2594

2595
  """
2596
  return time.strftime("%F %T %Z", time.gmtime(secs))
2597

    
2598

    
2599
def _ParseAsn1Generalizedtime(value):
2600
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
2601

2602
  @type value: string
2603
  @param value: ASN1 GENERALIZEDTIME timestamp
2604

2605
  """
2606
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
2607
  if m:
2608
    # We have an offset
2609
    asn1time = m.group(1)
2610
    hours = int(m.group(2))
2611
    minutes = int(m.group(3))
2612
    utcoffset = (60 * hours) + minutes
2613
  else:
2614
    if not value.endswith("Z"):
2615
      raise ValueError("Missing timezone")
2616
    asn1time = value[:-1]
2617
    utcoffset = 0
2618

    
2619
  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
2620

    
2621
  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)
2622

    
2623
  return calendar.timegm(tt.utctimetuple())
2624

    
2625

    
2626
def GetX509CertValidity(cert):
2627
  """Returns the validity period of the certificate.
2628

2629
  @type cert: OpenSSL.crypto.X509
2630
  @param cert: X509 certificate object
2631

2632
  """
2633
  # The get_notBefore and get_notAfter functions are only supported in
2634
  # pyOpenSSL 0.7 and above.
2635
  try:
2636
    get_notbefore_fn = cert.get_notBefore
2637
  except AttributeError:
2638
    not_before = None
2639
  else:
2640
    not_before_asn1 = get_notbefore_fn()
2641

    
2642
    if not_before_asn1 is None:
2643
      not_before = None
2644
    else:
2645
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)
2646

    
2647
  try:
2648
    get_notafter_fn = cert.get_notAfter
2649
  except AttributeError:
2650
    not_after = None
2651
  else:
2652
    not_after_asn1 = get_notafter_fn()
2653

    
2654
    if not_after_asn1 is None:
2655
      not_after = None
2656
    else:
2657
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)
2658

    
2659
  return (not_before, not_after)
2660

    
2661

    
2662
def _VerifyCertificateInner(expired, not_before, not_after, now,
2663
                            warn_days, error_days):
2664
  """Verifies certificate validity.
2665

2666
  @type expired: bool
2667
  @param expired: Whether pyOpenSSL considers the certificate as expired
2668
  @type not_before: number or None
2669
  @param not_before: Unix timestamp before which certificate is not valid
2670
  @type not_after: number or None
2671
  @param not_after: Unix timestamp after which certificate is invalid
2672
  @type now: number
2673
  @param now: Current time as Unix timestamp
2674
  @type warn_days: number or None
2675
  @param warn_days: How many days before expiration a warning should be reported
2676
  @type error_days: number or None
2677
  @param error_days: How many days before expiration an error should be reported
2678

2679
  """
2680
  if expired:
2681
    msg = "Certificate is expired"
2682

    
2683
    if not_before is not None and not_after is not None:
2684
      msg += (" (valid from %s to %s)" %
2685
              (FormatTimestampWithTZ(not_before),
2686
               FormatTimestampWithTZ(not_after)))
2687
    elif not_before is not None:
2688
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
2689
    elif not_after is not None:
2690
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)
2691

    
2692
    return (CERT_ERROR, msg)
2693

    
2694
  elif not_before is not None and not_before > now:
2695
    return (CERT_WARNING,
2696
            "Certificate not yet valid (valid from %s)" %
2697
            FormatTimestampWithTZ(not_before))
2698

    
2699
  elif not_after is not None:
2700
    remaining_days = int((not_after - now) / (24 * 3600))
2701

    
2702
    msg = "Certificate expires in about %d days" % remaining_days
2703

    
2704
    if error_days is not None and remaining_days <= error_days:
2705
      return (CERT_ERROR, msg)
2706

    
2707
    if warn_days is not None and remaining_days <= warn_days:
2708
      return (CERT_WARNING, msg)
2709

    
2710
  return (None, None)
2711

    
2712

    
2713
def VerifyX509Certificate(cert, warn_days, error_days):
2714
  """Verifies a certificate for LUVerifyCluster.
2715

2716
  @type cert: OpenSSL.crypto.X509
2717
  @param cert: X509 certificate object
2718
  @type warn_days: number or None
2719
  @param warn_days: How many days before expiration a warning should be reported
2720
  @type error_days: number or None
2721
  @param error_days: How many days before expiration an error should be reported
2722

2723
  """
2724
  # Depending on the pyOpenSSL version, this can just return (None, None)
2725
  (not_before, not_after) = GetX509CertValidity(cert)
2726

    
2727
  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
2728
                                 time.time(), warn_days, error_days)
2729

    
2730

    
2731
def SignX509Certificate(cert, key, salt):
2732
  """Sign a X509 certificate.
2733

2734
  An RFC822-like signature header is added in front of the certificate.
2735

2736
  @type cert: OpenSSL.crypto.X509
2737
  @param cert: X509 certificate object
2738
  @type key: string
2739
  @param key: Key for HMAC
2740
  @type salt: string
2741
  @param salt: Salt for HMAC
2742
  @rtype: string
2743
  @return: Serialized and signed certificate in PEM format
2744

2745
  """
2746
  if not VALID_X509_SIGNATURE_SALT.match(salt):
2747
    raise errors.GenericError("Invalid salt: %r" % salt)
2748

    
2749
  # Dumping as PEM here ensures the certificate is in a sane format
2750
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2751

    
2752
  return ("%s: %s/%s\n\n%s" %
2753
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
2754
           Sha1Hmac(key, cert_pem, salt=salt),
2755
           cert_pem))
2756

    
2757

    
2758
def _ExtractX509CertificateSignature(cert_pem):
2759
  """Helper function to extract signature from X509 certificate.
2760

2761
  """
2762
  # Extract signature from original PEM data
2763
  for line in cert_pem.splitlines():
2764
    if line.startswith("---"):
2765
      break
2766

    
2767
    m = X509_SIGNATURE.match(line.strip())
2768
    if m:
2769
      return (m.group("salt"), m.group("sign"))
2770

    
2771
  raise errors.GenericError("X509 certificate signature is missing")
2772

    
2773

    
2774
def LoadSignedX509Certificate(cert_pem, key):
2775
  """Verifies a signed X509 certificate.
2776

2777
  @type cert_pem: string
2778
  @param cert_pem: Certificate in PEM format and with signature header
2779
  @type key: string
2780
  @param key: Key for HMAC
2781
  @rtype: tuple; (OpenSSL.crypto.X509, string)
2782
  @return: X509 certificate object and salt
2783

2784
  """
2785
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)
2786

    
2787
  # Load certificate
2788
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
2789

    
2790
  # Dump again to ensure it's in a sane format
2791
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2792

    
2793
  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
2794
    raise errors.GenericError("X509 certificate signature is invalid")
2795

    
2796
  return (cert, salt)
2797

    
2798

    
2799
def Sha1Hmac(key, text, salt=None):
2800
  """Calculates the HMAC-SHA1 digest of a text.
2801

2802
  HMAC is defined in RFC2104.
2803

2804
  @type key: string
2805
  @param key: Secret key
2806
  @type text: string
2807

2808
  """
2809
  if salt:
2810
    salted_text = salt + text
2811
  else:
2812
    salted_text = text
2813

    
2814
  return hmac.new(key, salted_text, compat.sha1).hexdigest()
2815

    
2816

    
2817
def VerifySha1Hmac(key, text, digest, salt=None):
2818
  """Verifies the HMAC-SHA1 digest of a text.
2819

2820
  HMAC is defined in RFC2104.
2821

2822
  @type key: string
2823
  @param key: Secret key
2824
  @type text: string
2825
  @type digest: string
2826
  @param digest: Expected digest
2827
  @rtype: bool
2828
  @return: Whether HMAC-SHA1 digest matches
2829

2830
  """
2831
  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()
2832

    
2833

    
2834
def SafeEncode(text):
  """Return a 'safe' version of a source string.

  This function mangles the input string and returns a version that
  should be safe to display/encode as ASCII. To this end, we first
  convert it to ASCII using the 'backslashreplace' encoding, which
  should get rid of any non-ASCII chars, and then we process it
  through a loop copied from the string repr sources in Python; we
  don't use string_escape anymore since that escapes single quotes and
  backslashes too, and that is too much; and that escaping is not
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).

  @type text: str or unicode
  @param text: input data
  @rtype: str
  @return: a safe version of text

  """
  if isinstance(text, unicode):
    # only if unicode; if str already, we handle it below
    text = text.encode('ascii', 'backslashreplace')
  resu = ""
  for char in text:
    c = ord(char)
    if char == '\t':
      resu += r'\t'
    elif char == '\n':
      resu += r'\n'
    elif char == '\r':
      resu += r'\r'
    elif c < 32 or c >= 127: # non-printable
      resu += "\\x%02x" % (c & 0xff)
    else:
      resu += char
  return resu


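# Example usage of SafeEncode (illustrative): control characters are turned
# into printable escape sequences:
#   SafeEncode("a\tb\nc")  # -> r"a\tb\nc"
#   SafeEncode("\x07")     # -> r"\x07"
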
def UnescapeAndSplit(text, sep=","):
  """Split and unescape a string based on a given separator.

  This function splits a string based on a separator where the
  separator itself can be escaped in order to appear inside one of the
  elements. The escaping rules are (assuming comma is the
  separator):
    - a plain , separates the elements
    - a sequence \\\\, (double backslash plus comma) is handled as a
      backslash plus a separator comma
    - a sequence \, (backslash plus comma) is handled as a
      non-separator comma

  @type text: string
  @param text: the string to split
  @type sep: string
  @param sep: the separator
  @rtype: list of strings
  @return: a list of strings

  """
  # we split the list by sep (with no escaping at this stage)
  slist = text.split(sep)
  # next, we revisit the elements and if any of them ended with an odd
  # number of backslashes, then we join it with the next
  rlist = []
  while slist:
    e1 = slist.pop(0)
    if e1.endswith("\\"):
      num_b = len(e1) - len(e1.rstrip("\\"))
      if num_b % 2 == 1:
        e2 = slist.pop(0)
        # here the backslashes remain (all), and will be reduced in
        # the next step
        rlist.append(e1 + sep + e2)
        continue
    rlist.append(e1)
  # finally, replace backslash-something with something
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
  return rlist


def CommaJoin(names):
2914
  """Nicely join a set of identifiers.
2915

2916
  @param names: set, list or tuple
2917
  @return: a string with the formatted results
2918

2919
  """
2920
  return ", ".join([str(val) for val in names])
2921

    
2922

    
2923
def BytesToMebibyte(value):
  """Converts bytes to mebibytes.

  @type value: int
  @param value: Value in bytes
  @rtype: int
  @return: Value in mebibytes

  """
  return int(round(value / (1024.0 * 1024.0), 0))


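# Example usage of BytesToMebibyte (illustrative):
#   BytesToMebibyte(1048576)  # -> 1
#   BytesToMebibyte(1572864)  # -> 2 (1.5 MiB is rounded up)
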
def CalculateDirectorySize(path):
2936
  """Calculates the size of a directory recursively.
2937

2938
  @type path: string
2939
  @param path: Path to directory
2940
  @rtype: int
2941
  @return: Size in mebibytes
2942

2943
  """
2944
  size = 0
2945

    
2946
  for (curpath, _, files) in os.walk(path):
2947
    for filename in files:
2948
      st = os.lstat(PathJoin(curpath, filename))
2949
      size += st.st_size
2950

    
2951
  return BytesToMebibyte(size)
2952

    
2953

    
2954
def GetMounts(filename=constants.PROC_MOUNTS):
2955
  """Returns the list of mounted filesystems.
2956

2957
  This function is Linux-specific.
2958

2959
  @param filename: path of mounts file (/proc/mounts by default)
2960
  @rtype: list of tuples
2961
  @return: list of mount entries (device, mountpoint, fstype, options)
2962

2963
  """
2964
  # TODO(iustin): investigate non-Linux options (e.g. via mount output)
2965
  data = []
2966
  mountlines = ReadFile(filename).splitlines()
2967
  for line in mountlines:
2968
    device, mountpoint, fstype, options, _ = line.split(None, 4)
2969
    data.append((device, mountpoint, fstype, options))
2970

    
2971
  return data
2972

    
2973

    
2974
def GetFilesystemStats(path):
2975
  """Returns the total and free space on a filesystem.
2976

2977
  @type path: string
2978
  @param path: Path on filesystem to be examined
2979
  @rtype: tuple
2980
  @return: tuple of (Total space, Free space) in mebibytes
2981

2982
  """
2983
  st = os.statvfs(path)
2984

    
2985
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
2986
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
2987
  return (tsize, fsize)
2988

    
2989

    
2990
def RunInSeparateProcess(fn, *args):
2991
  """Runs a function in a separate process.
2992

2993
  Note: Only boolean return values are supported.
2994

2995
  @type fn: callable
2996
  @param fn: Function to be called
2997
  @rtype: bool
2998
  @return: Function's result
2999

3000
  """
3001
  pid = os.fork()
3002
  if pid == 0:
3003
    # Child process
3004
    try:
3005
      # In case the function uses temporary files
3006
      ResetTempfileModule()
3007

    
3008
      # Call function
3009
      result = int(bool(fn(*args)))
3010
      assert result in (0, 1)
3011
    except: # pylint: disable-msg=W0702
3012
      logging.exception("Error while calling function in separate process")
3013
      # 0 and 1 are reserved for the return value
3014
      result = 33
3015

    
3016
    os._exit(result) # pylint: disable-msg=W0212
3017

    
3018
  # Parent process
3019

    
3020
  # Avoid zombies and check exit code
3021
  (_, status) = os.waitpid(pid, 0)
3022

    
3023
  if os.WIFSIGNALED(status):
3024
    exitcode = None
3025
    signum = os.WTERMSIG(status)
3026
  else:
3027
    exitcode = os.WEXITSTATUS(status)
3028
    signum = None
3029

    
3030
  if not (exitcode in (0, 1) and signum is None):
3031
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
3032
                              (exitcode, signum))
3033

    
3034
  return bool(exitcode)
3035

    
3036

    
3037
def IgnoreProcessNotFound(fn, *args, **kwargs):
3038
  """Ignores ESRCH when calling a process-related function.
3039

3040
  ESRCH is raised when a process is not found.
3041

3042
  @rtype: bool
3043
  @return: Whether process was found
3044

3045
  """
3046
  try:
3047
    fn(*args, **kwargs)
3048
  except EnvironmentError, err:
3049
    # Ignore ESRCH
3050
    if err.errno == errno.ESRCH:
3051
      return False
3052
    raise
3053

    
3054
  return True
3055

    
3056

    
3057
def IgnoreSignals(fn, *args, **kwargs):
3058
  """Tries to call a function ignoring failures due to EINTR.
3059

3060
  """
3061
  try:
3062
    return fn(*args, **kwargs)
3063
  except EnvironmentError, err:
3064
    if err.errno == errno.EINTR:
3065
      return None
3066
    else:
3067
      raise
3068
  except (select.error, socket.error), err:
3069
    # In python 2.6 and above select.error is an IOError, so it's handled
3070
    # above, in 2.5 and below it's not, and it's handled here.
3071
    if err.args and err.args[0] == errno.EINTR:
3072
      return None
3073
    else:
3074
      raise
3075

    
3076

    
3077
def LockFile(fd):
3078
  """Locks a file using POSIX locks.
3079

3080
  @type fd: int
3081
  @param fd: the file descriptor we need to lock
3082

3083
  """
3084
  try:
3085
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
3086
  except IOError, err:
3087
    if err.errno == errno.EAGAIN:
3088
      raise errors.LockError("File already locked")
3089
    raise
3090

    
3091

    
3092
def FormatTime(val):
3093
  """Formats a time value.
3094

3095
  @type val: float or None
3096
  @param val: the timestamp as returned by time.time()
3097
  @return: a string value or N/A if we don't have a valid timestamp
3098

3099
  """
3100
  if val is None or not isinstance(val, (int, float)):
3101
    return "N/A"
3102
  # these two format codes work on Linux, but they are not guaranteed on all
3103
  # platforms
3104
  return time.strftime("%F %T", time.localtime(val))
3105

    
3106

    
3107
def FormatSeconds(secs):
3108
  """Formats seconds for easier reading.
3109

3110
  @type secs: number
3111
  @param secs: Number of seconds
3112
  @rtype: string
3113
  @return: Formatted seconds (e.g. "2d 9h 19m 49s")
3114

3115
  """
3116
  parts = []
3117

    
3118
  secs = round(secs, 0)
3119

    
3120
  if secs > 0:
3121
    # Negative values would be a bit tricky
3122
    for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]:
3123
      (complete, secs) = divmod(secs, one)
3124
      if complete or parts:
3125
        parts.append("%d%s" % (complete, unit))
3126

    
3127
  parts.append("%ds" % secs)
3128

    
3129
  return " ".join(parts)
3130

    
3131

    
3132
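# Example usage of FormatSeconds (illustrative):
#   FormatSeconds(3725)   # -> "1h 2m 5s"
#   FormatSeconds(90061)  # -> "1d 1h 1m 1s"
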
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
3133
  """Reads the watcher pause file.
3134

3135
  @type filename: string
3136
  @param filename: Path to watcher pause file
3137
  @type now: None, float or int
3138
  @param now: Current time as Unix timestamp
3139
  @type remove_after: int
3140
  @param remove_after: Remove watcher pause file after specified amount of
3141
    seconds past the pause end time
3142

3143
  """
3144
  if now is None:
3145
    now = time.time()
3146

    
3147
  try:
3148
    value = ReadFile(filename)
3149
  except IOError, err:
3150
    if err.errno != errno.ENOENT:
3151
      raise
3152
    value = None
3153

    
3154
  if value is not None:
3155
    try:
3156
      value = int(value)
3157
    except ValueError:
3158
      logging.warning(("Watcher pause file (%s) contains invalid value,"
3159
                       " removing it"), filename)
3160
      RemoveFile(filename)
3161
      value = None
3162

    
3163
    if value is not None:
3164
      # Remove file if it's outdated
3165
      if now > (value + remove_after):
3166
        RemoveFile(filename)
3167
        value = None
3168

    
3169
      elif now > value:
3170
        value = None
3171

    
3172
  return value
3173

    
3174

    
3175
class RetryTimeout(Exception):
3176
  """Retry loop timed out.
3177

3178
  Any arguments which were passed by the retried function to RetryAgain will
  be preserved in RetryTimeout, if it is raised. If such an argument was an
  exception the RaiseInner helper method will reraise it.
3181

3182
  """
3183
  def RaiseInner(self):
3184
    if self.args and isinstance(self.args[0], Exception):
3185
      raise self.args[0]
3186
    else:
3187
      raise RetryTimeout(*self.args)
3188

    
3189

    
3190
class RetryAgain(Exception):
3191
  """Retry again.
3192

3193
  Any arguments passed to RetryAgain will be preserved, if a timeout occurs, as
3194
  arguments to RetryTimeout. If an exception is passed, the RaiseInner() method
3195
  of the RetryTimeout() method can be used to reraise it.
3196

3197
  """
3198

    
3199

    
3200
class _RetryDelayCalculator(object):
3201
  """Calculator for increasing delays.
3202

3203
  """
3204
  __slots__ = [
3205
    "_factor",
3206
    "_limit",
3207
    "_next",
3208
    "_start",
3209
    ]
3210

    
3211
  def __init__(self, start, factor, limit):
3212
    """Initializes this class.
3213

3214
    @type start: float
3215
    @param start: Initial delay
3216
    @type factor: float
3217
    @param factor: Factor for delay increase
3218
    @type limit: float or None
3219
    @param limit: Upper limit for delay or None for no limit
3220

3221
    """
3222
    assert start > 0.0
3223
    assert factor >= 1.0
3224
    assert limit is None or limit >= 0.0
3225

    
3226
    self._start = start
3227
    self._factor = factor
3228
    self._limit = limit
3229

    
3230
    self._next = start
3231

    
3232
  def __call__(self):
3233
    """Returns current delay and calculates the next one.
3234

3235
    """
3236
    current = self._next
3237

    
3238
    # Update for next run
3239
    if self._limit is None or self._next < self._limit:
3240
      self._next = min(self._limit, self._next * self._factor)
3241

    
3242
    return current
3243

    
3244

    
3245
#: Special delay to specify whole remaining timeout
3246
RETRY_REMAINING_TIME = object()
3247

    
3248

    
3249
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
3250
          _time_fn=time.time):
3251
  """Call a function repeatedly until it succeeds.
3252

3253
  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
3254
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
3255
  total of C{timeout} seconds, this function throws L{RetryTimeout}.
3256

3257
  C{delay} can be one of the following:
3258
    - callable returning the delay length as a float
3259
    - Tuple of (start, factor, limit)
3260
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
3261
      useful when overriding L{wait_fn} to wait for an external event)
3262
    - A static delay as a number (int or float)
3263

3264
  @type fn: callable
3265
  @param fn: Function to be called
3266
  @param delay: Either a callable (returning the delay), a tuple of (start,
3267
                factor, limit) (see L{_RetryDelayCalculator}),
3268
                L{RETRY_REMAINING_TIME} or a number (int or float)
3269
  @type timeout: float
3270
  @param timeout: Total timeout
3271
  @type wait_fn: callable
3272
  @param wait_fn: Waiting function
3273
  @return: Return value of function
3274

3275
  """
3276
  assert callable(fn)
3277
  assert callable(wait_fn)
3278
  assert callable(_time_fn)
3279

    
3280
  if args is None:
3281
    args = []
3282

    
3283
  end_time = _time_fn() + timeout
3284

    
3285
  if callable(delay):
3286
    # External function to calculate delay
3287
    calc_delay = delay
3288

    
3289
  elif isinstance(delay, (tuple, list)):
3290
    # Increasing delay with optional upper boundary
3291
    (start, factor, limit) = delay
3292
    calc_delay = _RetryDelayCalculator(start, factor, limit)
3293

    
3294
  elif delay is RETRY_REMAINING_TIME:
3295
    # Always use the remaining time
3296
    calc_delay = None
3297

    
3298
  else:
3299
    # Static delay
3300
    calc_delay = lambda: delay
3301

    
3302
  assert calc_delay is None or callable(calc_delay)
3303

    
3304
  while True:
3305
    retry_args = []
3306
    try:
3307
      # pylint: disable-msg=W0142
3308
      return fn(*args)
3309
    except RetryAgain, err:
3310
      retry_args = err.args
3311
    except RetryTimeout:
3312
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
3313
                                   " handle RetryTimeout")
3314

    
3315
    remaining_time = end_time - _time_fn()
3316

    
3317
    if remaining_time < 0.0:
3318
      # pylint: disable-msg=W0142
3319
      raise RetryTimeout(*retry_args)
3320

    
3321
    assert remaining_time >= 0.0
3322

    
3323
    if calc_delay is None:
3324
      wait_fn(remaining_time)
3325
    else:
3326
      current_delay = calc_delay()
3327
      if current_delay > 0.0:
3328
        wait_fn(current_delay)
3329

    
3330

    
3331
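# Illustrative Retry usage (sketch only): poll a condition with an increasing
# delay of (initial, factor, limit) = (0.1, 1.5, 2.0) seconds for at most 30
# seconds; _IsReady is a hypothetical caller-supplied predicate.
#
#   def _CheckReady():
#     if not _IsReady():
#       raise RetryAgain()
#     return True
#
#   try:
#     Retry(_CheckReady, (0.1, 1.5, 2.0), 30)
#   except RetryTimeout:
#     logging.error("Condition not fulfilled within 30 seconds")
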
def GetClosedTempfile(*args, **kwargs):
3332
  """Creates a temporary file and returns its path.
3333

3334
  """
3335
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
3336
  _CloseFDNoErr(fd)
3337
  return path
3338

    
3339

    
3340
def GenerateSelfSignedX509Cert(common_name, validity):
3341
  """Generates a self-signed X509 certificate.
3342

3343
  @type common_name: string
3344
  @param common_name: commonName value
3345
  @type validity: int
3346
  @param validity: Validity for certificate in seconds
3347

3348
  """
3349
  # Create private and public key
3350
  key = OpenSSL.crypto.PKey()
3351
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)
3352

    
3353
  # Create self-signed certificate
3354
  cert = OpenSSL.crypto.X509()
3355
  if common_name:
3356
    cert.get_subject().CN = common_name
3357
  cert.set_serial_number(1)
3358
  cert.gmtime_adj_notBefore(0)
3359
  cert.gmtime_adj_notAfter(validity)
3360
  cert.set_issuer(cert.get_subject())
3361
  cert.set_pubkey(key)
3362
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)
3363

    
3364
  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
3365
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
3366

    
3367
  return (key_pem, cert_pem)
3368

    
3369

    
3370
def GenerateSelfSignedSslCert(filename, common_name=constants.X509_CERT_CN,
3371
                              validity=constants.X509_CERT_DEFAULT_VALIDITY):
3372
  """Legacy function to generate self-signed X509 certificate.
3373

3374
  @type filename: str
3375
  @param filename: path to write certificate to
3376
  @type common_name: string
3377
  @param common_name: commonName value
3378
  @type validity: int
3379
  @param validity: validity of certificate in number of days
3380

3381
  """
3382
  # TODO: Investigate using the cluster name instead of X505_CERT_CN for
3383
  # common_name, as cluster-renames are very seldom, and it'd be nice if RAPI
3384
  # and node daemon certificates have the proper Subject/Issuer.
3385
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(common_name,
3386
                                                   validity * 24 * 60 * 60)
3387

    
3388
  WriteFile(filename, mode=0400, data=key_pem + cert_pem)
3389

    
3390

    
3391
class FileLock(object):
3392
  """Utility class for file locks.
3393

3394
  """
3395
  def __init__(self, fd, filename):
3396
    """Constructor for FileLock.
3397

3398
    @type fd: file
3399
    @param fd: File object
3400
    @type filename: str
3401
    @param filename: Path of the file opened at I{fd}
3402

3403
    """
3404
    self.fd = fd
3405
    self.filename = filename
3406

    
3407
  @classmethod
3408
  def Open(cls, filename):
3409
    """Creates and opens a file to be used as a file-based lock.
3410

3411
    @type filename: string
3412
    @param filename: path to the file to be locked
3413

3414
    """
3415
    # Using "os.open" is necessary to allow both opening existing file
3416
    # read/write and creating if not existing. Vanilla "open" will truncate an
3417
    # existing file -or- allow creating if not existing.
3418
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
3419
               filename)
3420

    
3421
  def __del__(self):
    self.Close()

  def Close(self):
    """Close the file and release the lock.

    """
    if hasattr(self, "fd") and self.fd:
      self.fd.close()
      self.fd = None

  def _flock(self, flag, blocking, timeout, errmsg):
    """Wrapper for fcntl.flock.

    @type flag: int
    @param flag: operation flag
    @type blocking: bool
    @param blocking: whether the operation should be done in blocking mode.
    @type timeout: None or float
    @param timeout: for how long the operation should be retried (implies
                    non-blocking mode).
    @type errmsg: string
    @param errmsg: error message in case operation fails.

    """
    assert self.fd, "Lock was closed"
    assert timeout is None or timeout >= 0, \
      "If specified, timeout must not be negative"
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)
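

# Illustrative sketch of typical FileLock usage (the helper name, lock file
# path and timeout below are made-up examples, not part of the module):
# acquire an exclusive lock with a timeout, do the work, then release it.
def _ExampleFileLockUsage(path="/tmp/example.lock"):
  """Sketch: exclusive locking with a five-second timeout."""
  lock = FileLock.Open(path)
  try:
    # Raises errors.LockError if the lock cannot be acquired within 5 seconds
    lock.Exclusive(blocking=True, timeout=5.0)
    # ... critical section ...
    lock.Unlock()
  finally:
    # Closing the file descriptor also drops any flock() lock still held
    lock.Close()

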
class LineSplitter:
  """Splits data chunks into lines separated by newline.

  Instances provide a file-like interface.

  """
  def __init__(self, line_fn, *args):
    """Initializes this class.

    @type line_fn: callable
    @param line_fn: Function called for each line, first parameter is line
    @param args: Extra arguments for L{line_fn}

    """
    assert callable(line_fn)

    if args:
      # Python 2.4 doesn't have functools.partial yet
      self._line_fn = \
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
    else:
      self._line_fn = line_fn

    self._lines = collections.deque()
    self._buffer = ""

  def write(self, data):
    parts = (self._buffer + data).split("\n")
    self._buffer = parts.pop()
    self._lines.extend(parts)

  def flush(self):
    while self._lines:
      self._line_fn(self._lines.popleft().rstrip("\r\n"))

  def close(self):
    self.flush()
    if self._buffer:
      self._line_fn(self._buffer)
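

# Illustrative sketch (hypothetical helper, not part of the module):
# LineSplitter buffers partial lines across write() calls and hands each
# complete line to the callback; close() flushes the trailing partial line.
def _ExampleLineSplitterUsage():
  """Sketch: collecting complete lines from arbitrary data chunks."""
  received = []
  splitter = LineSplitter(received.append)
  splitter.write("first line\nsecond ")
  splitter.write("line\npartial")
  splitter.close()
  # received is now ["first line", "second line", "partial"]
  return received

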
def SignalHandled(signums):
  """Signal handling decorator.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      assert 'signal_handlers' not in kwargs or \
             kwargs['signal_handlers'] is None or \
             isinstance(kwargs['signal_handlers'], dict), \
             "Wrong signal_handlers parameter in original function call"
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
        signal_handlers = kwargs['signal_handlers']
      else:
        signal_handlers = {}
        kwargs['signal_handlers'] = signal_handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          signal_handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        sighandler.Reset()
    return sig_function
  return wrap
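

# Illustrative sketch (hypothetical function, not part of the module): the
# decorated function must accept a "signal_handlers" keyword argument, which
# the decorator fills with one SignalHandler per requested signal, so the
# function can poll the "called" flag to see whether e.g. SIGTERM arrived.
@SignalHandled([signal.SIGTERM])
def _ExampleSignalHandledUsage(signal_handlers=None):
  """Sketch: checking inside the decorated function for a received signal."""
  term_handler = signal_handlers[signal.SIGTERM]
  # ... do some work ...
  return term_handler.called

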
class SignalWakeupFd(object):
  """Wrapper for a pipe registered as the signal wakeup file descriptor.

  """
  try:
    # This is only supported in Python 2.5 and above (some distributions
    # backported it to Python 2.4)
    _set_wakeup_fd_fn = signal.set_wakeup_fd
  except AttributeError:
    # Not supported
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
      return -1
  else:
    def _SetWakeupFd(self, fd):
      return self._set_wakeup_fd_fn(fd)

  def __init__(self):
    """Initializes this class.

    """
    (read_fd, write_fd) = os.pipe()

    # Once the fdopen() calls succeed, the file descriptors will be closed
    # automatically together with the file objects. Buffer size 0 is
    # important, otherwise .read() with a specified length might buffer data
    # and the file descriptors won't be marked readable.
    self._read_fh = os.fdopen(read_fd, "r", 0)
    self._write_fh = os.fdopen(write_fd, "w", 0)

    self._previous = self._SetWakeupFd(self._write_fh.fileno())

    # Utility functions
    self.fileno = self._read_fh.fileno
    self.read = self._read_fh.read

  def Reset(self):
    """Restores the previous wakeup file descriptor.

    """
    if hasattr(self, "_previous") and self._previous is not None:
      self._SetWakeupFd(self._previous)
      self._previous = None

  def Notify(self):
    """Notifies the wakeup file descriptor.

    """
    self._write_fh.write("\0")

  def __del__(self):
    """Called before object deletion.

    """
    self.Reset()
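

# Illustrative sketch (not part of the module): a SignalWakeupFd can be
# passed to select() so that a signal delivered while the process is asleep
# interrupts the wait immediately instead of only after the timeout expires.
# The one-second timeout is an arbitrary example value.
def _ExampleSignalWakeupFdUsage(timeout=1.0):
  """Sketch: waiting on the wakeup pipe with select()."""
  wakeup = SignalWakeupFd()
  try:
    (readable, _, _) = select.select([wakeup], [], [], timeout)
    if readable:
      # Drain the byte written by the signal machinery (or by Notify())
      wakeup.read(1)
      return True
    return False
  finally:
    wakeup.Reset()

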
class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when destroyed or
  when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: set
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None, wakeup=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function
    @type wakeup: L{SignalWakeupFd} or None
    @param wakeup: if not None, its Notify() method is called whenever one of
        the handled signals is received

    """
    assert handler_fn is None or callable(handler_fn)

    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn
    self._wakeup = wakeup

    self._previous = {}
    try:
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._wakeup:
      # Notify whoever is interested in signals
      self._wakeup.Notify()

    if self._handler_fn:
      self._handler_fn(signum, frame)
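

# Illustrative sketch (not part of the module): a typical main loop installs
# a SignalHandler for SIGTERM together with a SignalWakeupFd, sleeps in
# select() on the wakeup pipe and leaves the loop once the flag is set; the
# 60-second timeout is an arbitrary example value.
def _ExampleSignalHandlerUsage():
  """Sketch: cooperating SignalHandler and SignalWakeupFd in a loop."""
  wakeup = SignalWakeupFd()
  handler = SignalHandler([signal.SIGTERM], wakeup=wakeup)
  try:
    while not handler.called:
      # The handler notifies the pipe, so select() returns as soon as the
      # signal arrives, even if it was delivered right before the call
      (readable, _, _) = select.select([wakeup], [], [], 60.0)
      if readable:
        wakeup.read(1)
  finally:
    handler.Reset()
    wakeup.Reset()

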
class FieldSet(object):
  """A simple field set.

  Among the features are:
    - checking if a string is among a list of static strings or regex objects
    - checking if a whole list of strings matches
    - returning the matching groups from a regex match

  Internally, all fields are held as regular expression objects.

  """
  def __init__(self, *items):
    self.items = [re.compile("^%s$" % value) for value in items]

  def Extend(self, other_set):
    """Extend the field set with the items from another one"""
    self.items.extend(other_set.items)

  def Matches(self, field):
    """Checks if a field matches the current set.

    @type field: str
    @param field: the string to match
    @return: either None or a regular expression match object

    """
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
      return m
    return None

  def NonMatching(self, items):
    """Returns the list of fields not matching the current set.

    @type items: list
    @param items: the list of fields to check
    @rtype: list
    @return: list of non-matching fields

    """
    return [val for val in items if not self.Matches(val)]
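

# Illustrative sketch (hypothetical field names, not part of the module):
# fields can be plain strings or regular expressions, and Matches() returns
# the match object so that regex groups can be extracted from it.
def _ExampleFieldSetUsage():
  """Sketch: static and parameterized fields in one set."""
  fields = FieldSet("name", "sda_size", r"disk\.size/([0-9]+)")
  match = fields.Matches("disk.size/2")
  index = match.group(1)  # "2"
  unknown = fields.NonMatching(["name", "bogus"])  # ["bogus"]
  return (index, unknown)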