#
#

# Copyright (C) 2006, 2007, 2010 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Ganeti utility module.

This module holds functions that can be used in both daemons (all) and
the command line scripts.

"""


import os
import sys
import time
import subprocess
import re
import socket
import tempfile
import shutil
import errno
import pwd
import itertools
import select
import fcntl
import resource
import logging
import logging.handlers
import signal
import OpenSSL
import datetime
import calendar
import hmac
import collections

from cStringIO import StringIO

try:
  # pylint: disable-msg=F0401
  import ctypes
except ImportError:
  ctypes = None

from ganeti import errors
from ganeti import constants
from ganeti import compat


_locksheld = []
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')

debug_locks = False

#: when set to True, L{RunCmd} is disabled
no_fork = False

_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"

HEX_CHAR_RE = r"[a-zA-Z0-9]"
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
                             HEX_CHAR_RE, HEX_CHAR_RE),
                            re.S | re.I)

_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$")

# Certificate verification results
(CERT_WARNING,
 CERT_ERROR) = range(1, 3)

# Flags for mlockall() (from bits/mman.h)
_MCL_CURRENT = 1
_MCL_FUTURE = 2


class RunResult(object):
  """Holds the result of running external programs.

  @type exit_code: int
  @ivar exit_code: the exit code of the program, or None (if the program
      didn't exit())
  @type signal: int or None
  @ivar signal: the signal that caused the program to finish, or None
      (if the program wasn't terminated by a signal)
  @type stdout: str
  @ivar stdout: the standard output of the program
  @type stderr: str
  @ivar stderr: the standard error of the program
  @type failed: boolean
  @ivar failed: True in case the program was
      terminated by a signal or exited with a non-zero exit code
  @ivar fail_reason: a string detailing the termination reason

  """
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
               "failed", "fail_reason", "cmd"]


  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
    self.cmd = cmd
    self.exit_code = exit_code
    self.signal = signal_
    self.stdout = stdout
    self.stderr = stderr
    self.failed = (signal_ is not None or exit_code != 0)

    if self.signal is not None:
      self.fail_reason = "terminated by signal %s" % self.signal
    elif self.exit_code is not None:
      self.fail_reason = "exited with exit code %s" % self.exit_code
    else:
      self.fail_reason = "unable to determine termination reason"

    if self.failed:
      logging.debug("Command '%s' failed (%s); output: %s",
                    self.cmd, self.fail_reason, self.output)

  def _GetOutput(self):
    """Returns the combined stdout and stderr for easier usage.

    """
    return self.stdout + self.stderr

  output = property(_GetOutput, None, None, "Return full output")


def _BuildCmdEnvironment(env, reset):
  """Builds the environment for an external program.

  """
  if reset:
    cmd_env = {}
  else:
    cmd_env = os.environ.copy()
    cmd_env["LC_ALL"] = "C"

  if env is not None:
    cmd_env.update(env)

  return cmd_env


def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
  """Execute a (shell) command.

  The command should not read from its standard input, as it will be
  closed.

  @type cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: Additional environment variables
  @type output: str
  @param output: if desired, the output of the command can be
      saved in a file instead of the RunResult instance; this
      parameter denotes the file name (if not None)
  @type cwd: string
  @param cwd: if specified, will be used as the working
      directory for the command; the default will be /
  @type reset_env: boolean
  @param reset_env: whether to reset or keep the default os environment
  @rtype: L{RunResult}
  @return: RunResult instance
  @raise errors.ProgrammerError: if we call this when forks are disabled

  """
  if no_fork:
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")

  if isinstance(cmd, basestring):
    strcmd = cmd
    shell = True
  else:
    cmd = [str(val) for val in cmd]
    strcmd = ShellQuoteArgs(cmd)
    shell = False

  if output:
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
  else:
    logging.debug("RunCmd %s", strcmd)

  cmd_env = _BuildCmdEnvironment(env, reset_env)

  try:
    if output is None:
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
    else:
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
      out = err = ""
  except OSError, err:
    if err.errno == errno.ENOENT:
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
                               (strcmd, err))
    else:
      raise

  if status >= 0:
    exitcode = status
    signal_ = None
  else:
    exitcode = None
    signal_ = -status

  return RunResult(exitcode, signal_, out, err, strcmd)
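
# Illustrative usage sketch (not part of the original module): RunCmd accepts
# either a pre-split argument list (run directly) or a single string (run via
# the shell) and always returns a RunResult:
#
#   result = RunCmd(["/bin/ls", "/etc"])
#   if result.failed:
#     logging.error("ls failed: %s (%s)", result.fail_reason, result.output)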


def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
                pidfile=None):
  """Start a daemon process after forking twice.

  @type cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: Additional environment variables
  @type cwd: string
  @param cwd: Working directory for the program
  @type output: string
  @param output: Path to file in which to save the output
  @type output_fd: int
  @param output_fd: File descriptor for output
  @type pidfile: string
  @param pidfile: Process ID file
  @rtype: int
  @return: Daemon process ID
  @raise errors.ProgrammerError: if we call this when forks are disabled

  """
  if no_fork:
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
                                 " disabled")

  if output and not (bool(output) ^ (output_fd is not None)):
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
                                 " specified")

  if isinstance(cmd, basestring):
    cmd = ["/bin/sh", "-c", cmd]

  strcmd = ShellQuoteArgs(cmd)

  if output:
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
  else:
    logging.debug("StartDaemon %s", strcmd)

  cmd_env = _BuildCmdEnvironment(env, False)

  # Create pipe for sending PID back
  (pidpipe_read, pidpipe_write) = os.pipe()
  try:
    try:
      # Create pipe for sending error messages
      (errpipe_read, errpipe_write) = os.pipe()
      try:
        try:
          # First fork
          pid = os.fork()
          if pid == 0:
            try:
              # Child process, won't return
              _StartDaemonChild(errpipe_read, errpipe_write,
                                pidpipe_read, pidpipe_write,
                                cmd, cmd_env, cwd,
                                output, output_fd, pidfile)
            finally:
              # Well, maybe child process failed
              os._exit(1) # pylint: disable-msg=W0212
        finally:
          _CloseFDNoErr(errpipe_write)

        # Wait for daemon to be started (or an error message to arrive) and read
        # up to 100 KB as an error message
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
      finally:
        _CloseFDNoErr(errpipe_read)
    finally:
      _CloseFDNoErr(pidpipe_write)

    # Read up to 128 bytes for PID
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
  finally:
    _CloseFDNoErr(pidpipe_read)

  # Try to avoid zombies by waiting for child process
  try:
    os.waitpid(pid, 0)
  except OSError:
    pass

  if errormsg:
    raise errors.OpExecError("Error when starting daemon process: %r" %
                             errormsg)

  try:
    return int(pidtext)
  except (ValueError, TypeError), err:
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
                             (pidtext, err))
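
# Illustrative usage sketch (not part of the original module); the command and
# paths below are made up:
#
#   pid = StartDaemon(["/usr/sbin/mydaemon", "--foreground"],
#                     output="/var/log/mydaemon.log",
#                     pidfile="/var/run/mydaemon.pid")
#   # 'pid' is the daemon's process ID, already written to the locked pidfile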


def _StartDaemonChild(errpipe_read, errpipe_write,
                      pidpipe_read, pidpipe_write,
                      args, env, cwd,
                      output, fd_output, pidfile):
  """Child process for starting daemon.

  """
  try:
    # Close parent's side
    _CloseFDNoErr(errpipe_read)
    _CloseFDNoErr(pidpipe_read)

    # First child process
    os.chdir("/")
    os.umask(077)
    os.setsid()

    # And fork for the second time
    pid = os.fork()
    if pid != 0:
      # Exit first child process
      os._exit(0) # pylint: disable-msg=W0212

    # Make sure pipe is closed on execv* (and thereby notifies original process)
    SetCloseOnExecFlag(errpipe_write, True)

    # List of file descriptors to be left open
    noclose_fds = [errpipe_write]

    # Open PID file
    if pidfile:
      try:
        # TODO: Atomic replace with another locked file instead of writing into
        # it after creating
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)

        # Lock the PID file (and fail if not possible to do so). Any code
        # wanting to send a signal to the daemon should try to lock the PID
        # file before reading it. If acquiring the lock succeeds, the daemon is
        # no longer running and the signal should not be sent.
        LockFile(fd_pidfile)

        os.write(fd_pidfile, "%d\n" % os.getpid())
      except Exception, err:
        raise Exception("Creating and locking PID file failed: %s" % err)

      # Keeping the file open to hold the lock
      noclose_fds.append(fd_pidfile)

      SetCloseOnExecFlag(fd_pidfile, False)
    else:
      fd_pidfile = None

    # Open /dev/null
    fd_devnull = os.open(os.devnull, os.O_RDWR)

    assert not output or (bool(output) ^ (fd_output is not None))

    if fd_output is not None:
      pass
    elif output:
      # Open output file
      try:
        # TODO: Implement flag to set append=yes/no
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
      except EnvironmentError, err:
        raise Exception("Opening output file failed: %s" % err)
    else:
      fd_output = fd_devnull

    # Redirect standard I/O
    os.dup2(fd_devnull, 0)
    os.dup2(fd_output, 1)
    os.dup2(fd_output, 2)

    # Send daemon PID to parent
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))

    # Close all file descriptors except stdio and error message pipe
    CloseFDs(noclose_fds=noclose_fds)

    # Change working directory
    os.chdir(cwd)

    if env is None:
      os.execvp(args[0], args)
    else:
      os.execvpe(args[0], args, env)
  except: # pylint: disable-msg=W0702
    try:
      # Report errors to original process
      buf = str(sys.exc_info()[1])

      RetryOnSignal(os.write, errpipe_write, buf)
    except: # pylint: disable-msg=W0702
      # Ignore errors in error handling
      pass

  os._exit(1) # pylint: disable-msg=W0212


def _RunCmdPipe(cmd, env, via_shell, cwd):
  """Run a command and return its output.

  @type  cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: The environment to use
  @type via_shell: bool
  @param via_shell: if we should run via the shell
  @type cwd: string
  @param cwd: the working directory for the program
  @rtype: tuple
  @return: (out, err, status)

  """
  poller = select.poll()
  child = subprocess.Popen(cmd, shell=via_shell,
                           stderr=subprocess.PIPE,
                           stdout=subprocess.PIPE,
                           stdin=subprocess.PIPE,
                           close_fds=True, env=env,
                           cwd=cwd)

  child.stdin.close()
  poller.register(child.stdout, select.POLLIN)
  poller.register(child.stderr, select.POLLIN)
  out = StringIO()
  err = StringIO()
  fdmap = {
    child.stdout.fileno(): (out, child.stdout),
    child.stderr.fileno(): (err, child.stderr),
    }
  for fd in fdmap:
    SetNonblockFlag(fd, True)

  while fdmap:
    pollresult = RetryOnSignal(poller.poll)

    for fd, event in pollresult:
      if event & select.POLLIN or event & select.POLLPRI:
        data = fdmap[fd][1].read()
        # no data from read signifies EOF (the same as POLLHUP)
        if not data:
          poller.unregister(fd)
          del fdmap[fd]
          continue
        fdmap[fd][0].write(data)
      if (event & select.POLLNVAL or event & select.POLLHUP or
          event & select.POLLERR):
        poller.unregister(fd)
        del fdmap[fd]

  out = out.getvalue()
  err = err.getvalue()

  status = child.wait()
  return out, err, status


def _RunCmdFile(cmd, env, via_shell, output, cwd):
  """Run a command and save its output to a file.

  @type  cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: The environment to use
  @type via_shell: bool
  @param via_shell: if we should run via the shell
  @type output: str
  @param output: the filename in which to save the output
  @type cwd: string
  @param cwd: the working directory for the program
  @rtype: int
  @return: the exit status

  """
  fh = open(output, "a")
  try:
    child = subprocess.Popen(cmd, shell=via_shell,
                             stderr=subprocess.STDOUT,
                             stdout=fh,
                             stdin=subprocess.PIPE,
                             close_fds=True, env=env,
                             cwd=cwd)

    child.stdin.close()
    status = child.wait()
  finally:
    fh.close()
  return status

    
512

    
513
def SetCloseOnExecFlag(fd, enable):
514
  """Sets or unsets the close-on-exec flag on a file descriptor.
515

516
  @type fd: int
517
  @param fd: File descriptor
518
  @type enable: bool
519
  @param enable: Whether to set or unset it.
520

521
  """
522
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)
523

    
524
  if enable:
525
    flags |= fcntl.FD_CLOEXEC
526
  else:
527
    flags &= ~fcntl.FD_CLOEXEC
528

    
529
  fcntl.fcntl(fd, fcntl.F_SETFD, flags)
530

    
531

    
532
def SetNonblockFlag(fd, enable):
533
  """Sets or unsets the O_NONBLOCK flag on on a file descriptor.
534

535
  @type fd: int
536
  @param fd: File descriptor
537
  @type enable: bool
538
  @param enable: Whether to set or unset it
539

540
  """
541
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)
542

    
543
  if enable:
544
    flags |= os.O_NONBLOCK
545
  else:
546
    flags &= ~os.O_NONBLOCK
547

    
548
  fcntl.fcntl(fd, fcntl.F_SETFL, flags)
549

    
550

    
551
def RetryOnSignal(fn, *args, **kwargs):
552
  """Calls a function again if it failed due to EINTR.
553

554
  """
555
  while True:
556
    try:
557
      return fn(*args, **kwargs)
558
    except EnvironmentError, err:
559
      if err.errno != errno.EINTR:
560
        raise
561
    except (socket.error, select.error), err:
562
      # In python 2.6 and above select.error is an IOError, so it's handled
563
      # above, in 2.5 and below it's not, and it's handled here.
564
      if not (err.args and err.args[0] == errno.EINTR):
565
        raise
566

    
567

    
568
def RunParts(dir_name, env=None, reset_env=False):
569
  """Run Scripts or programs in a directory
570

571
  @type dir_name: string
572
  @param dir_name: absolute path to a directory
573
  @type env: dict
574
  @param env: The environment to use
575
  @type reset_env: boolean
576
  @param reset_env: whether to reset or keep the default os environment
577
  @rtype: list of tuples
578
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)
579

580
  """
581
  rr = []
582

    
583
  try:
584
    dir_contents = ListVisibleFiles(dir_name)
585
  except OSError, err:
586
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
587
    return rr
588

    
589
  for relname in sorted(dir_contents):
590
    fname = PathJoin(dir_name, relname)
591
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
592
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
593
      rr.append((relname, constants.RUNPARTS_SKIP, None))
594
    else:
595
      try:
596
        result = RunCmd([fname], env=env, reset_env=reset_env)
597
      except Exception, err: # pylint: disable-msg=W0703
598
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
599
      else:
600
        rr.append((relname, constants.RUNPARTS_RUN, result))
601

    
602
  return rr
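
# Illustrative sketch (not part of the original module); the directory name is
# an assumed example:
#
#   for (name, status, runresult) in RunParts("/etc/ganeti/hooks.d"):
#     if status == constants.RUNPARTS_ERR:
#       logging.error("%s could not be run: %s", name, runresult)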


def RemoveFile(filename):
  """Remove a file ignoring some errors.

  Remove a file, ignoring non-existing ones or directories. Other
  errors are passed through.

  @type filename: str
  @param filename: the file to be removed

  """
  try:
    os.unlink(filename)
  except OSError, err:
    if err.errno not in (errno.ENOENT, errno.EISDIR):
      raise


def RemoveDir(dirname):
  """Remove an empty directory.

  Remove a directory, ignoring non-existing ones.
  Other errors are passed through. This includes the case
  where the directory is not empty, so it can't be removed.

  @type dirname: str
  @param dirname: the empty directory to be removed

  """
  try:
    os.rmdir(dirname)
  except OSError, err:
    if err.errno != errno.ENOENT:
      raise


def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
  """Renames a file.

  @type old: string
  @param old: Original path
  @type new: string
  @param new: New path
  @type mkdir: bool
  @param mkdir: Whether to create target directory if it doesn't exist
  @type mkdir_mode: int
  @param mkdir_mode: Mode for newly created directories

  """
  try:
    return os.rename(old, new)
  except OSError, err:
    # In at least one use case of this function, the job queue, directory
    # creation is very rare. Checking for the directory before renaming is not
    # as efficient.
    if mkdir and err.errno == errno.ENOENT:
      # Create directory and try again
      Makedirs(os.path.dirname(new), mode=mkdir_mode)

      return os.rename(old, new)

    raise


def Makedirs(path, mode=0750):
  """Super-mkdir; create a leaf directory and all intermediate ones.

  This is a wrapper around C{os.makedirs} adding error handling not implemented
  before Python 2.5.

  """
  try:
    os.makedirs(path, mode)
  except OSError, err:
    # Ignore EEXIST. This is only handled in os.makedirs as included in
    # Python 2.5 and above.
    if err.errno != errno.EEXIST or not os.path.exists(path):
      raise


def ResetTempfileModule():
  """Resets the random name generator of the tempfile module.

  This function should be called after C{os.fork} in the child process to
  ensure it creates a newly seeded random generator. Otherwise it would
  generate the same random parts as the parent process. If several processes
  race for the creation of a temporary file, this could lead to one not getting
  a temporary name.

  """
  # pylint: disable-msg=W0212
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
    tempfile._once_lock.acquire()
    try:
      # Reset random name generator
      tempfile._name_sequence = None
    finally:
      tempfile._once_lock.release()
  else:
    logging.critical("The tempfile module misses at least one of the"
                     " '_once_lock' and '_name_sequence' attributes")


def _FingerprintFile(filename):
  """Compute the fingerprint of a file.

  If the file does not exist, a None will be returned
  instead.

  @type filename: str
  @param filename: the filename to checksum
  @rtype: str
  @return: the hex digest of the sha checksum of the contents
      of the file

  """
  if not (os.path.exists(filename) and os.path.isfile(filename)):
    return None

  f = open(filename)

  fp = compat.sha1_hash()
  while True:
    data = f.read(4096)
    if not data:
      break

    fp.update(data)

  return fp.hexdigest()


def FingerprintFiles(files):
  """Compute fingerprints for a list of files.

  @type files: list
  @param files: the list of filenames to fingerprint
  @rtype: dict
  @return: a dictionary filename: fingerprint, holding only
      existing files

  """
  ret = {}

  for filename in files:
    cksum = _FingerprintFile(filename)
    if cksum:
      ret[filename] = cksum

  return ret


def ForceDictType(target, key_types, allowed_values=None):
  """Force the values of a dict to have certain types.

  @type target: dict
  @param target: the dict to update
  @type key_types: dict
  @param key_types: dict mapping target dict keys to types
                    in constants.ENFORCEABLE_TYPES
  @type allowed_values: list
  @keyword allowed_values: list of specially allowed values

  """
  if allowed_values is None:
    allowed_values = []

  if not isinstance(target, dict):
    msg = "Expected dictionary, got '%s'" % target
    raise errors.TypeEnforcementError(msg)

  for key in target:
    if key not in key_types:
      msg = "Unknown key '%s'" % key
      raise errors.TypeEnforcementError(msg)

    if target[key] in allowed_values:
      continue

    ktype = key_types[key]
    if ktype not in constants.ENFORCEABLE_TYPES:
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
      raise errors.ProgrammerError(msg)

    if ktype in (constants.VTYPE_STRING, constants.VTYPE_MAYBE_STRING):
      if target[key] is None and ktype == constants.VTYPE_MAYBE_STRING:
        pass
      elif not isinstance(target[key], basestring):
        if isinstance(target[key], bool) and not target[key]:
          target[key] = ''
        else:
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
          raise errors.TypeEnforcementError(msg)
    elif ktype == constants.VTYPE_BOOL:
      if isinstance(target[key], basestring) and target[key]:
        if target[key].lower() == constants.VALUE_FALSE:
          target[key] = False
        elif target[key].lower() == constants.VALUE_TRUE:
          target[key] = True
        else:
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
          raise errors.TypeEnforcementError(msg)
      elif target[key]:
        target[key] = True
      else:
        target[key] = False
    elif ktype == constants.VTYPE_SIZE:
      try:
        target[key] = ParseUnit(target[key])
      except errors.UnitParseError, err:
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
              (key, target[key], err)
        raise errors.TypeEnforcementError(msg)
    elif ktype == constants.VTYPE_INT:
      try:
        target[key] = int(target[key])
      except (ValueError, TypeError):
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
        raise errors.TypeEnforcementError(msg)
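
# Illustrative sketch (not part of the original module); the key names below
# are made up:
#
#   params = {"memory": "512M", "auto_balance": "true"}
#   ForceDictType(params, {"memory": constants.VTYPE_SIZE,
#                          "auto_balance": constants.VTYPE_BOOL})
#   # params is now {"memory": 512, "auto_balance": True}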


def _GetProcStatusPath(pid):
  """Returns the path for a PID's proc status file.

  @type pid: int
  @param pid: Process ID
  @rtype: string

  """
  return "/proc/%d/status" % pid


def IsProcessAlive(pid):
  """Check if a given pid exists on the system.

  @note: zombie status is not handled, so zombie processes
      will be returned as alive
  @type pid: int
  @param pid: the process ID to check
  @rtype: boolean
  @return: True if the process exists

  """
  def _TryStat(name):
    try:
      os.stat(name)
      return True
    except EnvironmentError, err:
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
        return False
      elif err.errno == errno.EINVAL:
        raise RetryAgain(err)
      raise

  assert isinstance(pid, int), "pid must be an integer"
  if pid <= 0:
    return False

  # /proc in a multiprocessor environment can have strange behaviors.
  # Retry the os.stat a few times until we get a good result.
  try:
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
                 args=[_GetProcStatusPath(pid)])
  except RetryTimeout, err:
    err.RaiseInner()


def _ParseSigsetT(sigset):
  """Parse a rendered sigset_t value.

  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
  function.

  @type sigset: string
  @param sigset: Rendered signal set from /proc/$pid/status
  @rtype: set
  @return: Set of all enabled signal numbers

  """
  result = set()

  signum = 0
  for ch in reversed(sigset):
    chv = int(ch, 16)

    # The following could be done in a loop, but it's easier to read and
    # understand in the unrolled form
    if chv & 1:
      result.add(signum + 1)
    if chv & 2:
      result.add(signum + 2)
    if chv & 4:
      result.add(signum + 3)
    if chv & 8:
      result.add(signum + 4)

    signum += 4

  return result


def _GetProcStatusField(pstatus, field):
  """Retrieves a field from the contents of a proc status file.

  @type pstatus: string
  @param pstatus: Contents of /proc/$pid/status
  @type field: string
  @param field: Name of field whose value should be returned
  @rtype: string

  """
  for line in pstatus.splitlines():
    parts = line.split(":", 1)

    if len(parts) < 2 or parts[0] != field:
      continue

    return parts[1].strip()

  return None


def IsProcessHandlingSignal(pid, signum, status_path=None):
  """Checks whether a process is handling a signal.

  @type pid: int
  @param pid: Process ID
  @type signum: int
  @param signum: Signal number
  @rtype: bool

  """
  if status_path is None:
    status_path = _GetProcStatusPath(pid)

  try:
    proc_status = ReadFile(status_path)
  except EnvironmentError, err:
    # In at least one case, reading /proc/$pid/status failed with ESRCH.
    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
      return False
    raise

  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
  if sigcgt is None:
    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)

  # Now check whether signal is handled
  return signum in _ParseSigsetT(sigcgt)


def ReadPidFile(pidfile):
  """Read a pid from a file.

  @type  pidfile: string
  @param pidfile: path to the file containing the pid
  @rtype: int
  @return: The process id, if the file exists and contains a valid PID,
           otherwise 0

  """
  try:
    raw_data = ReadOneLineFile(pidfile)
  except EnvironmentError, err:
    if err.errno != errno.ENOENT:
      logging.exception("Can't read pid file")
    return 0

  try:
    pid = int(raw_data)
  except (TypeError, ValueError), err:
    logging.info("Can't parse pid file contents", exc_info=True)
    return 0

  return pid


def ReadLockedPidFile(path):
  """Reads a locked PID file.

  This can be used together with L{StartDaemon}.

  @type path: string
  @param path: Path to PID file
  @return: PID as integer or, if file was unlocked or couldn't be opened, None

  """
  try:
    fd = os.open(path, os.O_RDONLY)
  except EnvironmentError, err:
    if err.errno == errno.ENOENT:
      # PID file doesn't exist
      return None
    raise

  try:
    try:
      # Try to acquire lock
      LockFile(fd)
    except errors.LockError:
      # Couldn't lock, daemon is running
      return int(os.read(fd, 100))
  finally:
    os.close(fd)

  return None


def MatchNameComponent(key, name_list, case_sensitive=True):
  """Try to match a name against a list.

  This function will try to match a name like test1 against a list
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
  this list, I{'test1'} as well as I{'test1.example'} will match, but
  not I{'test1.ex'}. A multiple match will be considered as no match
  at all (e.g. I{'test1'} against C{['test1.example.com',
  'test1.example.org']}), except when the key fully matches an entry
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).

  @type key: str
  @param key: the name to be searched
  @type name_list: list
  @param name_list: the list of strings against which to search the key
  @type case_sensitive: boolean
  @param case_sensitive: whether to provide a case-sensitive match

  @rtype: None or str
  @return: None if there is no match I{or} if there are multiple matches,
      otherwise the element from the list which matches

  """
  if key in name_list:
    return key

  re_flags = 0
  if not case_sensitive:
    re_flags |= re.IGNORECASE
    key = key.upper()
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
  names_filtered = []
  string_matches = []
  for name in name_list:
    if mo.match(name) is not None:
      names_filtered.append(name)
      if not case_sensitive and key == name.upper():
        string_matches.append(name)

  if len(string_matches) == 1:
    return string_matches[0]
  if len(names_filtered) == 1:
    return names_filtered[0]
  return None
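
# Illustrative sketch (not part of the original module):
#
#   MatchNameComponent("node1", ["node1.example.com", "node2.example.com"])
#   # -> "node1.example.com"
#   MatchNameComponent("node", ["node1.example.com", "node2.example.com"])
#   # -> None (no component match; an ambiguous match would also return None)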


def ValidateServiceName(name):
  """Validate the given service name.

  @type name: number or string
  @param name: Service name or port specification

  """
  try:
    numport = int(name)
  except (ValueError, TypeError):
    # Non-numeric service name
    valid = _VALID_SERVICE_NAME_RE.match(name)
  else:
    # Numeric port (protocols other than TCP or UDP might need adjustments
    # here)
    valid = (numport >= 0 and numport < (1 << 16))

  if not valid:
    raise errors.OpPrereqError("Invalid service name '%s'" % name,
                               errors.ECODE_INVAL)

  return name


def ListVolumeGroups():
  """List volume groups and their size.

  @rtype: dict
  @return:
       Dictionary with keys volume name and values
       the size of the volume

  """
  command = "vgs --noheadings --units m --nosuffix -o name,size"
  result = RunCmd(command)
  retval = {}
  if result.failed:
    return retval

  for line in result.stdout.splitlines():
    try:
      name, size = line.split()
      size = int(float(size))
    except (IndexError, ValueError), err:
      logging.error("Invalid output from vgs (%s): %s", err, line)
      continue

    retval[name] = size

  return retval


def BridgeExists(bridge):
  """Check whether the given bridge exists in the system.

  @type bridge: str
  @param bridge: the bridge name to check
  @rtype: boolean
  @return: True if it does

  """
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)


def NiceSort(name_list):
  """Sort a list of strings based on digit and non-digit groupings.

  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
  will sort the list in the logical order C{['a1', 'a2', 'a10',
  'a11']}.

  The sort algorithm breaks each name in groups of either only-digits
  or no-digits. Only the first eight such groups are considered, and
  after that we just use what's left of the string.

  @type name_list: list
  @param name_list: the names to be sorted
  @rtype: list
  @return: a copy of the name list sorted with our algorithm

  """
  _SORTER_BASE = "(\D+|\d+)"
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE)
  _SORTER_RE = re.compile(_SORTER_FULL)
  _SORTER_NODIGIT = re.compile("^\D*$")
  def _TryInt(val):
    """Attempts to convert a variable to integer."""
    if val is None or _SORTER_NODIGIT.match(val):
      return val
    rval = int(val)
    return rval

  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
             for name in name_list]
  to_sort.sort()
  return [tup[1] for tup in to_sort]
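
# Illustrative sketch (not part of the original module):
#
#   NiceSort(["a1", "a10", "a11", "a2"])   # -> ["a1", "a2", "a10", "a11"]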


def TryConvert(fn, val):
  """Try to convert a value ignoring errors.

  This function tries to apply function I{fn} to I{val}. If no
  C{ValueError} or C{TypeError} exceptions are raised, it will return
  the result, else it will return the original value. Any other
  exceptions are propagated to the caller.

  @type fn: callable
  @param fn: function to apply to the value
  @param val: the value to be converted
  @return: The converted value if the conversion was successful,
      otherwise the original value.

  """
  try:
    nv = fn(val)
  except (ValueError, TypeError):
    nv = val
  return nv


def IsValidShellParam(word):
  """Verifies if the given word is safe from the shell's p.o.v.

  This means that we can pass this to a command via the shell and be
  sure that it doesn't alter the command line and is passed as such to
  the actual command.

  Note that we are overly restrictive here, in order to be on the safe
  side.

  @type word: str
  @param word: the word to check
  @rtype: boolean
  @return: True if the word is 'safe'

  """
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))


def BuildShellCmd(template, *args):
  """Build a safe shell command line from the given arguments.

  This function will check all arguments in the args list so that they
  are valid shell parameters (i.e. they don't contain shell
  metacharacters). If everything is ok, it will return the result of
  template % args.

  @type template: str
  @param template: the string holding the template for the
      string formatting
  @rtype: str
  @return: the expanded command line

  """
  for word in args:
    if not IsValidShellParam(word):
      raise errors.ProgrammerError("Shell argument '%s' contains"
                                   " invalid characters" % word)
  return template % args


def FormatUnit(value, units):
  """Formats an incoming number of MiB with the appropriate unit.

  @type value: int
  @param value: integer representing the value in MiB (1048576)
  @type units: char
  @param units: the type of formatting we should do:
      - 'h' for automatic scaling
      - 'm' for MiBs
      - 'g' for GiBs
      - 't' for TiBs
  @rtype: str
  @return: the formatted value (with suffix)

  """
  if units not in ('m', 'g', 't', 'h'):
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))

  suffix = ''

  if units == 'm' or (units == 'h' and value < 1024):
    if units == 'h':
      suffix = 'M'
    return "%d%s" % (round(value, 0), suffix)

  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
    if units == 'h':
      suffix = 'G'
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)

  else:
    if units == 'h':
      suffix = 'T'
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)


def ParseUnit(input_string):
  """Tries to extract number and scale from the given string.

  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
  is always an int in MiB.

  """
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
  if not m:
    raise errors.UnitParseError("Invalid format")

  value = float(m.groups()[0])

  unit = m.groups()[1]
  if unit:
    lcunit = unit.lower()
  else:
    lcunit = 'm'

  if lcunit in ('m', 'mb', 'mib'):
    # Value already in MiB
    pass

  elif lcunit in ('g', 'gb', 'gib'):
    value *= 1024

  elif lcunit in ('t', 'tb', 'tib'):
    value *= 1024 * 1024

  else:
    raise errors.UnitParseError("Unknown unit: %s" % unit)

  # Make sure we round up
  if int(value) < value:
    value += 1

  # Round up to the next multiple of 4
  value = int(value)
  if value % 4:
    value += 4 - value % 4

  return value
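
# Illustrative sketch (not part of the original module); results are MiB,
# rounded up to a multiple of 4:
#
#   ParseUnit("4G")    # -> 4096
#   ParseUnit("42")    # -> 44 (no unit means MiB)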


def ParseCpuMask(cpu_mask):
  """Parse a CPU mask definition and return the list of CPU IDs.

  CPU mask format: comma-separated list of CPU IDs
  or dash-separated ID ranges
  Example: "0-2,5" -> "0,1,2,5"

  @type cpu_mask: str
  @param cpu_mask: CPU mask definition
  @rtype: list of int
  @return: list of CPU IDs

  """
  if not cpu_mask:
    return []
  cpu_list = []
  for range_def in cpu_mask.split(","):
    boundaries = range_def.split("-")
    n_elements = len(boundaries)
    if n_elements > 2:
      raise errors.ParseError("Invalid CPU ID range definition"
                              " (only one hyphen allowed): %s" % range_def)
    try:
      lower = int(boundaries[0])
    except (ValueError, TypeError), err:
      raise errors.ParseError("Invalid CPU ID value for lower boundary of"
                              " CPU ID range: %s" % str(err))
    try:
      higher = int(boundaries[-1])
    except (ValueError, TypeError), err:
      raise errors.ParseError("Invalid CPU ID value for higher boundary of"
                              " CPU ID range: %s" % str(err))
    if lower > higher:
      raise errors.ParseError("Invalid CPU ID range definition"
                              " (%d > %d): %s" % (lower, higher, range_def))
    cpu_list.extend(range(lower, higher + 1))
  return cpu_list
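
# Illustrative sketch (not part of the original module):
#
#   ParseCpuMask("0-2,5")   # -> [0, 1, 2, 5]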


def AddAuthorizedKey(file_obj, key):
  """Adds an SSH public key to an authorized_keys file.

  @type file_obj: str or file handle
  @param file_obj: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  key_fields = key.split()

  if isinstance(file_obj, basestring):
    f = open(file_obj, 'a+')
  else:
    f = file_obj

  try:
    nl = True
    for line in f:
      # Ignore whitespace changes
      if line.split() == key_fields:
        break
      nl = line.endswith('\n')
    else:
      if not nl:
        f.write("\n")
      f.write(key.rstrip('\r\n'))
      f.write("\n")
      f.flush()
  finally:
    f.close()


def RemoveAuthorizedKey(file_name, key):
  """Removes an SSH public key from an authorized_keys file.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  key_fields = key.split()

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          # Ignore whitespace changes while comparing lines
          if line.split() != key_fields:
            out.write(line)

        out.flush()
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def SetEtcHostsEntry(file_name, ip, hostname, aliases):
  """Sets the name of an IP address and hostname in /etc/hosts.

  @type file_name: str
  @param file_name: path to the file to modify (usually C{/etc/hosts})
  @type ip: str
  @param ip: the IP address
  @type hostname: str
  @param hostname: the hostname to be added
  @type aliases: list
  @param aliases: the list of aliases to add for the hostname

  """
  # FIXME: use WriteFile + fn rather than duplicating its efforts
  # Ensure aliases are unique
  aliases = UniqueSequence([hostname] + aliases)[1:]

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          fields = line.split()
          if fields and not fields[0].startswith('#') and ip == fields[0]:
            continue
          out.write(line)

        out.write("%s\t%s" % (ip, hostname))
        if aliases:
          out.write(" %s" % ' '.join(aliases))
        out.write('\n')

        out.flush()
        os.fsync(out)
        os.chmod(tmpname, 0644)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def AddHostToEtcHosts(hostname):
  """Wrapper around SetEtcHostsEntry.

  @type hostname: str
  @param hostname: a hostname that will be resolved and added to
      L{constants.ETC_HOSTS}

  """
  SetEtcHostsEntry(constants.ETC_HOSTS, hostname.ip, hostname.name,
                   [hostname.name.split(".")[0]])


def RemoveEtcHostsEntry(file_name, hostname):
  """Removes a hostname from /etc/hosts.

  IP addresses without names are removed from the file.

  @type file_name: str
  @param file_name: path to the file to modify (usually C{/etc/hosts})
  @type hostname: str
  @param hostname: the hostname to be removed

  """
  # FIXME: use WriteFile + fn rather than duplicating its efforts
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          fields = line.split()
          if len(fields) > 1 and not fields[0].startswith('#'):
            names = fields[1:]
            if hostname in names:
              while hostname in names:
                names.remove(hostname)
              if names:
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
              continue

          out.write(line)

        out.flush()
        os.fsync(out)
        os.chmod(tmpname, 0644)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def RemoveHostFromEtcHosts(hostname):
  """Wrapper around RemoveEtcHostsEntry.

  @type hostname: str
  @param hostname: hostname that will be resolved and its
      full and short name will be removed from
      L{constants.ETC_HOSTS}

  """
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hostname)
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hostname.split(".")[0])


def TimestampForFilename():
  """Returns the current time formatted for filenames.

  The format doesn't contain colons as some shells and applications treat
  them as separators.

  """
  return time.strftime("%Y-%m-%d_%H_%M_%S")


def CreateBackup(file_name):
  """Creates a backup of a file.

  @type file_name: str
  @param file_name: file to be backed up
  @rtype: str
  @return: the path to the newly created backup
  @raise errors.ProgrammerError: for invalid file names

  """
  if not os.path.isfile(file_name):
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
                                file_name)

  prefix = ("%s.backup-%s." %
            (os.path.basename(file_name), TimestampForFilename()))
  dir_name = os.path.dirname(file_name)

  fsrc = open(file_name, 'rb')
  try:
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
    fdst = os.fdopen(fd, 'wb')
    try:
      logging.debug("Backing up %s at %s", file_name, backup_name)
      shutil.copyfileobj(fsrc, fdst)
    finally:
      fdst.close()
  finally:
    fsrc.close()

  return backup_name


def ShellQuote(value):
  """Quotes shell argument according to POSIX.

  @type value: str
  @param value: the argument to be quoted
  @rtype: str
  @return: the quoted value

  """
  if _re_shell_unquoted.match(value):
    return value
  else:
    return "'%s'" % value.replace("'", "'\\''")


def ShellQuoteArgs(args):
  """Quotes a list of shell arguments.

  @type args: list
  @param args: list of arguments to be quoted
  @rtype: str
  @return: the quoted arguments concatenated with spaces

  """
  return ' '.join([ShellQuote(i) for i in args])
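
# Illustrative sketch (not part of the original module):
#
#   ShellQuoteArgs(["echo", "it's here"])
#   # -> echo 'it'\''s here'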
1592

    
1593

    
1594
class ShellWriter:
1595
  """Helper class to write scripts with indentation.
1596

1597
  """
1598
  INDENT_STR = "  "
1599

    
1600
  def __init__(self, fh):
1601
    """Initializes this class.
1602

1603
    """
1604
    self._fh = fh
1605
    self._indent = 0
1606

    
1607
  def IncIndent(self):
1608
    """Increase indentation level by 1.
1609

1610
    """
1611
    self._indent += 1
1612

    
1613
  def DecIndent(self):
1614
    """Decrease indentation level by 1.
1615

1616
    """
1617
    assert self._indent > 0
1618
    self._indent -= 1
1619

    
1620
  def Write(self, txt, *args):
1621
    """Write line to output file.
1622

1623
    """
1624
    assert self._indent >= 0
1625

    
1626
    self._fh.write(self._indent * self.INDENT_STR)
1627

    
1628
    if args:
1629
      self._fh.write(txt % args)
1630
    else:
1631
      self._fh.write(txt)
1632

    
1633
    self._fh.write("\n")
1634

    
1635

    
1636
def ListVisibleFiles(path):
1637
  """Returns a list of visible files in a directory.
1638

1639
  @type path: str
1640
  @param path: the directory to enumerate
1641
  @rtype: list
1642
  @return: the list of all files not starting with a dot
1643
  @raise ProgrammerError: if L{path} is not an absolue and normalized path
1644

1645
  """
1646
  if not IsNormAbsPath(path):
1647
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
1648
                                 " absolute/normalized: '%s'" % path)
1649
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1650
  return files
1651

    
1652

    
1653
def GetHomeDir(user, default=None):
1654
  """Try to get the homedir of the given user.
1655

1656
  The user can be passed either as a string (denoting the name) or as
1657
  an integer (denoting the user id). If the user is not found, the
1658
  'default' argument is returned, which defaults to None.
1659

1660
  """
1661
  try:
1662
    if isinstance(user, basestring):
1663
      result = pwd.getpwnam(user)
1664
    elif isinstance(user, (int, long)):
1665
      result = pwd.getpwuid(user)
1666
    else:
1667
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1668
                                   type(user))
1669
  except KeyError:
1670
    return default
1671
  return result.pw_dir
1672

    
1673

    
1674
def NewUUID():
1675
  """Returns a random UUID.
1676

1677
  @note: This is a Linux-specific method as it uses the /proc
1678
      filesystem.
1679
  @rtype: str
1680

1681
  """
1682
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1683

    
1684

    
1685
def GenerateSecret(numbytes=20):
1686
  """Generates a random secret.
1687

1688
  This will generate a pseudo-random secret returning an hex string
1689
  (so that it can be used where an ASCII string is needed).
1690

1691
  @param numbytes: the number of bytes which will be represented by the returned
1692
      string (defaulting to 20, the length of a SHA1 hash)
1693
  @rtype: str
1694
  @return: an hex representation of the pseudo-random sequence
1695

1696
  """
1697
  return os.urandom(numbytes).encode('hex')
1698

    
1699

    
1700
def EnsureDirs(dirs):
1701
  """Make required directories, if they don't exist.
1702

1703
  @param dirs: list of tuples (dir_name, dir_mode)
1704
  @type dirs: list of (string, integer)
1705

1706
  """
1707
  for dir_name, dir_mode in dirs:
1708
    try:
1709
      os.mkdir(dir_name, dir_mode)
1710
    except EnvironmentError, err:
1711
      if err.errno != errno.EEXIST:
1712
        raise errors.GenericError("Cannot create needed directory"
1713
                                  " '%s': %s" % (dir_name, err))
1714
    try:
1715
      os.chmod(dir_name, dir_mode)
1716
    except EnvironmentError, err:
1717
      raise errors.GenericError("Cannot change directory permissions on"
1718
                                " '%s': %s" % (dir_name, err))
1719
    if not os.path.isdir(dir_name):
1720
      raise errors.GenericError("%s is not a directory" % dir_name)
1721

    
1722

    
1723
def ReadFile(file_name, size=-1):
1724
  """Reads a file.
1725

1726
  @type size: int
1727
  @param size: Read at most size bytes (if negative, entire file)
1728
  @rtype: str
1729
  @return: the (possibly partial) content of the file
1730

1731
  """
1732
  f = open(file_name, "r")
1733
  try:
1734
    return f.read(size)
1735
  finally:
1736
    f.close()
1737

    
1738

    
1739
def WriteFile(file_name, fn=None, data=None,
1740
              mode=None, uid=-1, gid=-1,
1741
              atime=None, mtime=None, close=True,
1742
              dry_run=False, backup=False,
1743
              prewrite=None, postwrite=None):
1744
  """(Over)write a file atomically.
1745

1746
  The file_name and either fn (a function taking one argument, the
1747
  file descriptor, and which should write the data to it) or data (the
1748
  contents of the file) must be passed. The other arguments are
1749
  optional and allow setting the file mode, owner and group, and the
1750
  mtime/atime of the file.
1751

1752
  If the function doesn't raise an exception, it has succeeded and the
1753
  target file has the new contents. If the function has raised an
1754
  exception, an existing target file should be unmodified and the
1755
  temporary file should be removed.
1756

1757
  @type file_name: str
1758
  @param file_name: the target filename
1759
  @type fn: callable
1760
  @param fn: content writing function, called with
1761
      file descriptor as parameter
1762
  @type data: str
1763
  @param data: contents of the file
1764
  @type mode: int
1765
  @param mode: file mode
1766
  @type uid: int
1767
  @param uid: the owner of the file
1768
  @type gid: int
1769
  @param gid: the group of the file
1770
  @type atime: int
1771
  @param atime: a custom access time to be set on the file
1772
  @type mtime: int
1773
  @param mtime: a custom modification time to be set on the file
1774
  @type close: boolean
1775
  @param close: whether to close file after writing it
1776
  @type prewrite: callable
1777
  @param prewrite: function to be called before writing content
1778
  @type postwrite: callable
1779
  @param postwrite: function to be called after writing content
1780

1781
  @rtype: None or int
1782
  @return: None if the 'close' parameter evaluates to True,
1783
      otherwise the file descriptor
1784

1785
  @raise errors.ProgrammerError: if any of the arguments are not valid
1786

1787
  """
1788
  if not os.path.isabs(file_name):
1789
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1790
                                 " absolute: '%s'" % file_name)
1791

    
1792
  if [fn, data].count(None) != 1:
1793
    raise errors.ProgrammerError("fn or data required")
1794

    
1795
  if [atime, mtime].count(None) == 1:
1796
    raise errors.ProgrammerError("Both atime and mtime must be either"
1797
                                 " set or None")
1798

    
1799
  if backup and not dry_run and os.path.isfile(file_name):
1800
    CreateBackup(file_name)
1801

    
1802
  dir_name, base_name = os.path.split(file_name)
1803
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1804
  do_remove = True
1805
  # here we need to make sure we remove the temp file, if any error
1806
  # leaves it in place
1807
  try:
1808
    if uid != -1 or gid != -1:
1809
      os.chown(new_name, uid, gid)
1810
    if mode:
1811
      os.chmod(new_name, mode)
1812
    if callable(prewrite):
1813
      prewrite(fd)
1814
    if data is not None:
1815
      os.write(fd, data)
1816
    else:
1817
      fn(fd)
1818
    if callable(postwrite):
1819
      postwrite(fd)
1820
    os.fsync(fd)
1821
    if atime is not None and mtime is not None:
1822
      os.utime(new_name, (atime, mtime))
1823
    if not dry_run:
1824
      os.rename(new_name, file_name)
1825
      do_remove = False
1826
  finally:
1827
    if close:
1828
      os.close(fd)
1829
      result = None
1830
    else:
1831
      result = fd
1832
    if do_remove:
1833
      RemoveFile(new_name)
1834

    
1835
  return result
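# Usage sketch (illustrative; the path and contents below are hypothetical):
#   WriteFile("/var/lib/myapp/flagfile", data="enabled\n", mode=0644,
#             backup=True)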
1836

    
1837

    
1838
def ReadOneLineFile(file_name, strict=False):
1839
  """Return the first non-empty line from a file.
1840

1841
  @type strict: boolean
1842
  @param strict: if True, abort if the file has more than one
1843
      non-empty line
1844

1845
  """
1846
  file_lines = ReadFile(file_name).splitlines()
1847
  full_lines = filter(bool, file_lines)
1848
  if not file_lines or not full_lines:
1849
    raise errors.GenericError("No data in one-liner file %s" % file_name)
1850
  elif strict and len(full_lines) > 1:
1851
    raise errors.GenericError("Too many lines in one-liner file %s" %
1852
                              file_name)
1853
  return full_lines[0]
1854

    
1855

    
1856
def FirstFree(seq, base=0):
1857
  """Returns the first non-existing integer from seq.
1858

1859
  The seq argument should be a sorted list of positive integers. The
1860
  first time the index of an element is smaller than the element
1861
  value, the index will be returned.
1862

1863
  The base argument is used to start at a different offset,
  i.e. C{[3, 4, 6]} with I{base=3} will return 5.
1865

1866
  Example: C{[0, 1, 3]} will return I{2}.
1867

1868
  @type seq: sequence
1869
  @param seq: the sequence to be analyzed.
1870
  @type base: int
1871
  @param base: use this value as the base index of the sequence
1872
  @rtype: int
1873
  @return: the first non-used index in the sequence
1874

1875
  """
1876
  for idx, elem in enumerate(seq):
1877
    assert elem >= base, "Passed element is higher than base offset"
1878
    if elem > idx + base:
1879
      # idx is not used
1880
      return idx + base
1881
  return None
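# Examples (illustrative, doctest-style):
#   >>> FirstFree([0, 1, 3])
#   2
#   >>> FirstFree([3, 4, 6], base=3)
#   5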
1882

    
1883

    
1884
def SingleWaitForFdCondition(fdobj, event, timeout):
1885
  """Waits for a condition to occur on the socket.
1886

1887
  Immediately returns at the first interruption.
1888

1889
  @type fdobj: integer or object supporting a fileno() method
1890
  @param fdobj: entity to wait for events on
1891
  @type event: integer
1892
  @param event: ORed condition (see select module)
1893
  @type timeout: float or None
1894
  @param timeout: Timeout in seconds
1895
  @rtype: int or None
1896
  @return: None for timeout, otherwise the occurred conditions
1897

1898
  """
1899
  check = (event | select.POLLPRI |
1900
           select.POLLNVAL | select.POLLHUP | select.POLLERR)
1901

    
1902
  if timeout is not None:
1903
    # Poller object expects milliseconds
1904
    timeout *= 1000
1905

    
1906
  poller = select.poll()
1907
  poller.register(fdobj, event)
1908
  try:
1909
    # TODO: If the main thread receives a signal and we have no timeout, we
1910
    # could wait forever. This should check a global "quit" flag or something
1911
    # every so often.
1912
    io_events = poller.poll(timeout)
1913
  except select.error, err:
1914
    if err[0] != errno.EINTR:
1915
      raise
1916
    io_events = []
1917
  if io_events and io_events[0][1] & check:
1918
    return io_events[0][1]
1919
  else:
1920
    return None
1921

    
1922

    
1923
class FdConditionWaiterHelper(object):
1924
  """Retry helper for WaitForFdCondition.
1925

1926
  This class contains the retried and wait functions that make sure
1927
  WaitForFdCondition can continue waiting until the timeout is actually
1928
  expired.
1929

1930
  """
1931

    
1932
  def __init__(self, timeout):
1933
    self.timeout = timeout
1934

    
1935
  def Poll(self, fdobj, event):
1936
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
1937
    if result is None:
1938
      raise RetryAgain()
1939
    else:
1940
      return result
1941

    
1942
  def UpdateTimeout(self, timeout):
1943
    self.timeout = timeout
1944

    
1945

    
1946
def WaitForFdCondition(fdobj, event, timeout):
1947
  """Waits for a condition to occur on the socket.
1948

1949
  Retries until the timeout is expired, even if interrupted.
1950

1951
  @type fdobj: integer or object supporting a fileno() method
1952
  @param fdobj: entity to wait for events on
1953
  @type event: integer
1954
  @param event: ORed condition (see select module)
1955
  @type timeout: float or None
1956
  @param timeout: Timeout in seconds
1957
  @rtype: int or None
1958
  @return: None for timeout, otherwise occured conditions
1959

1960
  """
1961
  if timeout is not None:
1962
    retrywaiter = FdConditionWaiterHelper(timeout)
1963
    try:
1964
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
1965
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
1966
    except RetryTimeout:
1967
      result = None
1968
  else:
1969
    result = None
1970
    while result is None:
1971
      result = SingleWaitForFdCondition(fdobj, event, timeout)
1972
  return result
1973

    
1974

    
1975
def UniqueSequence(seq):
1976
  """Returns a list with unique elements.
1977

1978
  Element order is preserved.
1979

1980
  @type seq: sequence
1981
  @param seq: the sequence with the source elements
1982
  @rtype: list
1983
  @return: list of unique elements from seq
1984

1985
  """
1986
  seen = set()
1987
  return [i for i in seq if i not in seen and not seen.add(i)]
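# Example (illustrative, doctest-style):
#   >>> UniqueSequence([1, 2, 2, 3, 1])
#   [1, 2, 3]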
1988

    
1989

    
1990
def NormalizeAndValidateMac(mac):
1991
  """Normalizes and check if a MAC address is valid.
1992

1993
  Checks whether the supplied MAC address is formally correct, only
1994
  accepts colon separated format. Normalize it to all lower.
1995

1996
  @type mac: str
1997
  @param mac: the MAC to be validated
1998
  @rtype: str
1999
  @return: returns the normalized and validated MAC.
2000

2001
  @raise errors.OpPrereqError: If the MAC isn't valid
2002

2003
  """
2004
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
2005
  if not mac_check.match(mac):
2006
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
2007
                               mac, errors.ECODE_INVAL)
2008

    
2009
  return mac.lower()
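# Example (illustrative, doctest-style):
#   >>> NormalizeAndValidateMac("AA:BB:CC:DD:EE:FF")
#   'aa:bb:cc:dd:ee:ff'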
2010

    
2011

    
2012
def TestDelay(duration):
2013
  """Sleep for a fixed amount of time.
2014

2015
  @type duration: float
2016
  @param duration: the sleep duration
2017
  @rtype: tuple
  @return: (False, error message) for a negative duration,
      (True, None) otherwise
2019

2020
  """
2021
  if duration < 0:
2022
    return False, "Invalid sleep duration"
2023
  time.sleep(duration)
2024
  return True, None
2025

    
2026

    
2027
def _CloseFDNoErr(fd, retries=5):
2028
  """Close a file descriptor ignoring errors.
2029

2030
  @type fd: int
2031
  @param fd: the file descriptor
2032
  @type retries: int
2033
  @param retries: how many retries to make, in case we get any
2034
      other error than EBADF
2035

2036
  """
2037
  try:
2038
    os.close(fd)
2039
  except OSError, err:
2040
    if err.errno != errno.EBADF:
2041
      if retries > 0:
2042
        _CloseFDNoErr(fd, retries - 1)
2043
    # else either it's closed already or we're out of retries, so we
2044
    # ignore this and go on
2045

    
2046

    
2047
def CloseFDs(noclose_fds=None):
2048
  """Close file descriptors.
2049

2050
  This closes all file descriptors above 2 (i.e. except
2051
  stdin/out/err).
2052

2053
  @type noclose_fds: list or None
2054
  @param noclose_fds: if given, it denotes a list of file descriptor
2055
      that should not be closed
2056

2057
  """
2058
  # Default maximum for the number of available file descriptors.
2059
  if 'SC_OPEN_MAX' in os.sysconf_names:
2060
    try:
2061
      MAXFD = os.sysconf('SC_OPEN_MAX')
2062
      if MAXFD < 0:
2063
        MAXFD = 1024
2064
    except OSError:
2065
      MAXFD = 1024
2066
  else:
2067
    MAXFD = 1024
2068
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
2069
  if (maxfd == resource.RLIM_INFINITY):
2070
    maxfd = MAXFD
2071

    
2072
  # Iterate through and close all file descriptors (except the standard ones)
2073
  for fd in range(3, maxfd):
2074
    if noclose_fds and fd in noclose_fds:
2075
      continue
2076
    _CloseFDNoErr(fd)
2077

    
2078

    
2079
def Mlockall(_ctypes=ctypes):
2080
  """Lock current process' virtual address space into RAM.
2081

2082
  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
2083
  see mlock(2) for more details. This function requires ctypes module.
2084

2085
  @raises errors.NoCtypesError: if ctypes module is not found
2086

2087
  """
2088
  if _ctypes is None:
2089
    raise errors.NoCtypesError()
2090

    
2091
  libc = _ctypes.cdll.LoadLibrary("libc.so.6")
2092
  if libc is None:
2093
    logging.error("Cannot set memory lock, ctypes cannot load libc")
2094
    return
2095

    
2096
  # Some older version of the ctypes module don't have built-in functionality
2097
  # to access the errno global variable, where function error codes are stored.
2098
  # By declaring this variable as a pointer to an integer we can then access
2099
  # its value correctly, should the mlockall call fail, in order to see what
2100
  # the actual error code was.
2101
  # pylint: disable-msg=W0212
2102
  libc.__errno_location.restype = _ctypes.POINTER(_ctypes.c_int)
2103

    
2104
  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
2105
    # pylint: disable-msg=W0212
2106
    logging.error("Cannot set memory lock: %s",
2107
                  os.strerror(libc.__errno_location().contents.value))
2108
    return
2109

    
2110
  logging.debug("Memory lock set")
2111

    
2112

    
2113
def Daemonize(logfile, run_uid, run_gid):
2114
  """Daemonize the current process.
2115

2116
  This detaches the current process from the controlling terminal and
2117
  runs it in the background as a daemon.
2118

2119
  @type logfile: str
2120
  @param logfile: the logfile to which we should redirect stdout/stderr
2121
  @type run_uid: int
2122
  @param run_uid: Run the child under this uid
2123
  @type run_gid: int
2124
  @param run_gid: Run the child under this gid
2125
  @rtype: int
2126
  @return: the value zero
2127

2128
  """
2129
  # pylint: disable-msg=W0212
2130
  # yes, we really want os._exit
2131
  UMASK = 077
2132
  WORKDIR = "/"
2133

    
2134
  # this might fail
2135
  pid = os.fork()
2136
  if (pid == 0):  # The first child.
2137
    os.setsid()
2138
    # FIXME: When removing again and moving to start-stop-daemon privilege drop
2139
    #        make sure to check for config permission and bail out when invoked
2140
    #        with wrong user.
2141
    os.setgid(run_gid)
2142
    os.setuid(run_uid)
2143
    # this might fail
2144
    pid = os.fork() # Fork a second child.
2145
    if (pid == 0):  # The second child.
2146
      os.chdir(WORKDIR)
2147
      os.umask(UMASK)
2148
    else:
2149
      # exit() or _exit()?  See below.
2150
      os._exit(0) # Exit parent (the first child) of the second child.
2151
  else:
2152
    os._exit(0) # Exit parent of the first child.
2153

    
2154
  for fd in range(3):
2155
    _CloseFDNoErr(fd)
2156
  i = os.open("/dev/null", os.O_RDONLY) # stdin
2157
  assert i == 0, "Can't close/reopen stdin"
2158
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
2159
  assert i == 1, "Can't close/reopen stdout"
2160
  # Duplicate standard output to standard error.
2161
  os.dup2(1, 2)
2162
  return 0
2163

    
2164

    
2165
def DaemonPidFileName(name):
2166
  """Compute a ganeti pid file absolute path
2167

2168
  @type name: str
2169
  @param name: the daemon name
2170
  @rtype: str
2171
  @return: the full path to the pidfile corresponding to the given
2172
      daemon name
2173

2174
  """
2175
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
2176

    
2177

    
2178
def EnsureDaemon(name):
2179
  """Check for and start daemon if not alive.
2180

2181
  """
2182
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
2183
  if result.failed:
2184
    logging.error("Can't start daemon '%s', failure %s, output: %s",
2185
                  name, result.fail_reason, result.output)
2186
    return False
2187

    
2188
  return True
2189

    
2190

    
2191
def StopDaemon(name):
2192
  """Stop daemon
2193

2194
  """
2195
  result = RunCmd([constants.DAEMON_UTIL, "stop", name])
2196
  if result.failed:
2197
    logging.error("Can't stop daemon '%s', failure %s, output: %s",
2198
                  name, result.fail_reason, result.output)
2199
    return False
2200

    
2201
  return True
2202

    
2203

    
2204
def WritePidFile(name):
2205
  """Write the current process pidfile.
2206

2207
  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}
2208

2209
  @type name: str
2210
  @param name: the daemon name to use
2211
  @raise errors.GenericError: if the pid file already exists and
2212
      points to a live process
2213

2214
  """
2215
  pid = os.getpid()
2216
  pidfilename = DaemonPidFileName(name)
2217
  if IsProcessAlive(ReadPidFile(pidfilename)):
2218
    raise errors.GenericError("%s contains a live process" % pidfilename)
2219

    
2220
  WriteFile(pidfilename, data="%d\n" % pid)
2221

    
2222

    
2223
def RemovePidFile(name):
2224
  """Remove the current process pidfile.
2225

2226
  Any errors are ignored.
2227

2228
  @type name: str
2229
  @param name: the daemon name used to derive the pidfile name
2230

2231
  """
2232
  pidfilename = DaemonPidFileName(name)
2233
  # TODO: we could check here that the file contains our pid
2234
  try:
2235
    RemoveFile(pidfilename)
2236
  except: # pylint: disable-msg=W0702
2237
    pass
2238

    
2239

    
2240
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
2241
                waitpid=False):
2242
  """Kill a process given by its pid.
2243

2244
  @type pid: int
2245
  @param pid: The PID to terminate.
2246
  @type signal_: int
2247
  @param signal_: The signal to send, by default SIGTERM
2248
  @type timeout: int
2249
  @param timeout: The timeout after which, if the process is still alive,
2250
                  a SIGKILL will be sent. If not positive, no such checking
2251
                  will be done
2252
  @type waitpid: boolean
2253
  @param waitpid: If true, we should waitpid on this process after
2254
      sending signals, since it's our own child and otherwise it
2255
      would remain as zombie
2256

2257
  """
2258
  def _helper(pid, signal_, wait):
2259
    """Simple helper to encapsulate the kill/waitpid sequence"""
2260
    if IgnoreProcessNotFound(os.kill, pid, signal_) and wait:
2261
      try:
2262
        os.waitpid(pid, os.WNOHANG)
2263
      except OSError:
2264
        pass
2265

    
2266
  if pid <= 0:
2267
    # kill with pid=0 == suicide
2268
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
2269

    
2270
  if not IsProcessAlive(pid):
2271
    return
2272

    
2273
  _helper(pid, signal_, waitpid)
2274

    
2275
  if timeout <= 0:
2276
    return
2277

    
2278
  def _CheckProcess():
2279
    if not IsProcessAlive(pid):
2280
      return
2281

    
2282
    try:
2283
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
2284
    except OSError:
2285
      raise RetryAgain()
2286

    
2287
    if result_pid > 0:
2288
      return
2289

    
2290
    raise RetryAgain()
2291

    
2292
  try:
2293
    # Wait up to $timeout seconds
2294
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
2295
  except RetryTimeout:
2296
    pass
2297

    
2298
  if IsProcessAlive(pid):
2299
    # Kill process if it's still alive
2300
    _helper(pid, signal.SIGKILL, waitpid)
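# Usage sketch (illustrative; "child_pid" is a hypothetical PID of a child
# process we forked ourselves, hence waitpid=True):
#   KillProcess(child_pid, signal_=signal.SIGTERM, timeout=10, waitpid=True)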
2301

    
2302

    
2303
def FindFile(name, search_path, test=os.path.exists):
2304
  """Look for a filesystem object in a given path.
2305

2306
  This is an abstract method to search for filesystem object (files,
2307
  dirs) under a given search path.
2308

2309
  @type name: str
2310
  @param name: the name to look for
2311
  @type search_path: str
2312
  @param search_path: location to start at
2313
  @type test: callable
2314
  @param test: a function taking one argument that should return True
2315
      if the a given object is valid; the default value is
2316
      os.path.exists, causing only existing files to be returned
2317
  @rtype: str or None
2318
  @return: full path to the object if found, None otherwise
2319

2320
  """
2321
  # validate the filename mask
2322
  if constants.EXT_PLUGIN_MASK.match(name) is None:
2323
    logging.critical("Invalid value passed for external script name: '%s'",
2324
                     name)
2325
    return None
2326

    
2327
  for dir_name in search_path:
2328
    # FIXME: investigate switch to PathJoin
2329
    item_name = os.path.sep.join([dir_name, name])
2330
    # check the user test and that we're indeed resolving to the given
2331
    # basename
2332
    if test(item_name) and os.path.basename(item_name) == name:
2333
      return item_name
2334
  return None
2335

    
2336

    
2337
def CheckVolumeGroupSize(vglist, vgname, minsize):
2338
  """Checks if the volume group list is valid.
2339

2340
  The function will check if a given volume group is in the list of
2341
  volume groups and has a minimum size.
2342

2343
  @type vglist: dict
2344
  @param vglist: dictionary of volume group names and their size
2345
  @type vgname: str
2346
  @param vgname: the volume group we should check
2347
  @type minsize: int
2348
  @param minsize: the minimum size we accept
2349
  @rtype: None or str
2350
  @return: None for success, otherwise the error message
2351

2352
  """
2353
  vgsize = vglist.get(vgname, None)
2354
  if vgsize is None:
2355
    return "volume group '%s' missing" % vgname
2356
  elif vgsize < minsize:
2357
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
2358
            (vgname, minsize, vgsize))
2359
  return None
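# Examples (illustrative; the volume group data below is hypothetical):
#   >>> CheckVolumeGroupSize({"xenvg": 10240}, "xenvg", 20480)
#   "volume group 'xenvg' too small (20480 MiB required, 10240 MiB found)"
#   >>> CheckVolumeGroupSize({"xenvg": 40960}, "xenvg", 20480) is None
#   True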
2360

    
2361

    
2362
def SplitTime(value):
2363
  """Splits time as floating point number into a tuple.
2364

2365
  @param value: Time in seconds
2366
  @type value: int or float
2367
  @return: Tuple containing (seconds, microseconds)
2368

2369
  """
2370
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
2371

    
2372
  assert 0 <= seconds, \
2373
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2374
  assert 0 <= microseconds <= 999999, \
2375
    "Microseconds must be 0-999999, but are %s" % microseconds
2376

    
2377
  return (int(seconds), int(microseconds))
2378

    
2379

    
2380
def MergeTime(timetuple):
2381
  """Merges a tuple into time as a floating point number.
2382

2383
  @param timetuple: Time as tuple, (seconds, microseconds)
2384
  @type timetuple: tuple
2385
  @return: Time as a floating point number expressed in seconds
2386

2387
  """
2388
  (seconds, microseconds) = timetuple
2389

    
2390
  assert 0 <= seconds, \
2391
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2392
  assert 0 <= microseconds <= 999999, \
2393
    "Microseconds must be 0-999999, but are %s" % microseconds
2394

    
2395
  return float(seconds) + (float(microseconds) * 0.000001)
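# Examples (illustrative, doctest-style):
#   >>> SplitTime(1.5)
#   (1, 500000)
#   >>> MergeTime((1, 500000))
#   1.5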
2396

    
2397

    
2398
class LogFileHandler(logging.FileHandler):
2399
  """Log handler that doesn't fallback to stderr.
2400

2401
  When an error occurs while writing to the logfile, logging.FileHandler tries
  to log on stderr. This doesn't work in ganeti since stderr is redirected to
  the logfile. This class avoids such failures by reporting errors to
  /dev/console instead.
2404

2405
  """
2406
  def __init__(self, filename, mode="a", encoding=None):
2407
    """Open the specified file and use it as the stream for logging.
2408

2409
    Also open /dev/console to report errors while logging.
2410

2411
    """
2412
    logging.FileHandler.__init__(self, filename, mode, encoding)
2413
    self.console = open(constants.DEV_CONSOLE, "a")
2414

    
2415
  def handleError(self, record): # pylint: disable-msg=C0103
2416
    """Handle errors which occur during an emit() call.
2417

2418
    Try to handle errors with FileHandler method, if it fails write to
2419
    /dev/console.
2420

2421
    """
2422
    try:
2423
      logging.FileHandler.handleError(self, record)
2424
    except Exception: # pylint: disable-msg=W0703
2425
      try:
2426
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
2427
      except Exception: # pylint: disable-msg=W0703
2428
        # Log handler tried everything it could, now just give up
2429
        pass
2430

    
2431

    
2432
def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
2433
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
2434
                 console_logging=False):
2435
  """Configures the logging module.
2436

2437
  @type logfile: str
2438
  @param logfile: the filename to which we should log
2439
  @type debug: integer
2440
  @param debug: if greater than zero, enable debug messages, otherwise
2441
      only those at C{INFO} and above level
2442
  @type stderr_logging: boolean
2443
  @param stderr_logging: whether we should also log to the standard error
2444
  @type program: str
2445
  @param program: the name under which we should log messages
2446
  @type multithreaded: boolean
2447
  @param multithreaded: if True, will add the thread name to the log file
2448
  @type syslog: string
2449
  @param syslog: one of 'no', 'yes', 'only':
2450
      - if no, syslog is not used
2451
      - if yes, syslog is used (in addition to file-logging)
2452
      - if only, only syslog is used
2453
  @type console_logging: boolean
2454
  @param console_logging: if True, will use a FileHandler which falls back to
2455
      the system console if logging fails
2456
  @raise EnvironmentError: if we can't open the log file and
2457
      syslog/stderr logging is disabled
2458

2459
  """
2460
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
2461
  sft = program + "[%(process)d]:"
2462
  if multithreaded:
2463
    fmt += "/%(threadName)s"
2464
    sft += " (%(threadName)s)"
2465
  if debug:
2466
    fmt += " %(module)s:%(lineno)s"
2467
    # no debug info for syslog loggers
2468
  fmt += " %(levelname)s %(message)s"
2469
  # yes, we do want the textual level, as remote syslog will probably
2470
  # lose the error level, and it's easier to grep for it
2471
  sft += " %(levelname)s %(message)s"
2472
  formatter = logging.Formatter(fmt)
2473
  sys_fmt = logging.Formatter(sft)
2474

    
2475
  root_logger = logging.getLogger("")
2476
  root_logger.setLevel(logging.NOTSET)
2477

    
2478
  # Remove all previously setup handlers
2479
  for handler in root_logger.handlers:
2480
    handler.close()
2481
    root_logger.removeHandler(handler)
2482

    
2483
  if stderr_logging:
2484
    stderr_handler = logging.StreamHandler()
2485
    stderr_handler.setFormatter(formatter)
2486
    if debug:
2487
      stderr_handler.setLevel(logging.NOTSET)
2488
    else:
2489
      stderr_handler.setLevel(logging.CRITICAL)
2490
    root_logger.addHandler(stderr_handler)
2491

    
2492
  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
2493
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
2494
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
2495
                                                    facility)
2496
    syslog_handler.setFormatter(sys_fmt)
2497
    # Never enable debug over syslog
2498
    syslog_handler.setLevel(logging.INFO)
2499
    root_logger.addHandler(syslog_handler)
2500

    
2501
  if syslog != constants.SYSLOG_ONLY:
2502
    # this can fail, if the logging directories are not setup or we have
2503
    # a permisssion problem; in this case, it's best to log but ignore
2504
    # the error if stderr_logging is True, and if false we re-raise the
2505
    # exception since otherwise we could run but without any logs at all
2506
    try:
2507
      if console_logging:
2508
        logfile_handler = LogFileHandler(logfile)
2509
      else:
2510
        logfile_handler = logging.FileHandler(logfile)
2511
      logfile_handler.setFormatter(formatter)
2512
      if debug:
2513
        logfile_handler.setLevel(logging.DEBUG)
2514
      else:
2515
        logfile_handler.setLevel(logging.INFO)
2516
      root_logger.addHandler(logfile_handler)
2517
    except EnvironmentError:
2518
      if stderr_logging or syslog == constants.SYSLOG_YES:
2519
        logging.exception("Failed to enable logging to file '%s'", logfile)
2520
      else:
2521
        # we need to re-raise the exception
2522
        raise
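# Usage sketch (illustrative; the log file path and program name below are
# hypothetical):
#   SetupLogging("/var/log/myapp/daemon.log", debug=1, stderr_logging=True,
#                program="myapp-daemon")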
2523

    
2524

    
2525
def IsNormAbsPath(path):
2526
  """Check whether a path is absolute and also normalized
2527

2528
  This avoids things like /dir/../../other/path to be valid.
2529

2530
  """
2531
  return os.path.normpath(path) == path and os.path.isabs(path)
2532

    
2533

    
2534
def PathJoin(*args):
2535
  """Safe-join a list of path components.
2536

2537
  Requirements:
2538
      - the first argument must be an absolute path
2539
      - no component in the path must have backtracking (e.g. /../),
2540
        since we check for normalization at the end
2541

2542
  @param args: the path components to be joined
2543
  @raise ValueError: for invalid paths
2544

2545
  """
2546
  # ensure we're having at least one path passed in
2547
  assert args
2548
  # ensure the first component is an absolute and normalized path name
2549
  root = args[0]
2550
  if not IsNormAbsPath(root):
2551
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
2552
  result = os.path.join(*args)
2553
  # ensure that the whole path is normalized
2554
  if not IsNormAbsPath(result):
2555
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
2556
  # check that we're still under the original prefix
2557
  prefix = os.path.commonprefix([root, result])
2558
  if prefix != root:
2559
    raise ValueError("Error: path joining resulted in different prefix"
2560
                     " (%s != %s)" % (prefix, root))
2561
  return result
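# Examples (illustrative, doctest-style):
#   >>> PathJoin("/var", "log", "ganeti")
#   '/var/log/ganeti'
#   >>> PathJoin("/var", "../etc")
#   Traceback (most recent call last):
#     ...
#   ValueError: ...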
2562

    
2563

    
2564
def TailFile(fname, lines=20):
2565
  """Return the last lines from a file.
2566

2567
  @note: this function will only read and parse the last 4KB of
2568
      the file; if the lines are very long, it could be that less
2569
      than the requested number of lines are returned
2570

2571
  @param fname: the file name
2572
  @type lines: int
2573
  @param lines: the (maximum) number of lines to return
2574

2575
  """
2576
  fd = open(fname, "r")
2577
  try:
2578
    fd.seek(0, 2)
2579
    pos = fd.tell()
2580
    pos = max(0, pos-4096)
2581
    fd.seek(pos, 0)
2582
    raw_data = fd.read()
2583
  finally:
2584
    fd.close()
2585

    
2586
  rows = raw_data.splitlines()
2587
  return rows[-lines:]
2588

    
2589

    
2590
def FormatTimestampWithTZ(secs):
2591
  """Formats a Unix timestamp with the local timezone.
2592

2593
  """
2594
  return time.strftime("%F %T %Z", time.gmtime(secs))
2595

    
2596

    
2597
def _ParseAsn1Generalizedtime(value):
2598
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
2599

2600
  @type value: string
2601
  @param value: ASN1 GENERALIZEDTIME timestamp
2602

2603
  """
2604
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
2605
  if m:
2606
    # We have an offset
2607
    asn1time = m.group(1)
2608
    hours = int(m.group(2))
2609
    minutes = int(m.group(3))
2610
    utcoffset = (60 * hours) + minutes
2611
  else:
2612
    if not value.endswith("Z"):
2613
      raise ValueError("Missing timezone")
2614
    asn1time = value[:-1]
2615
    utcoffset = 0
2616

    
2617
  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
2618

    
2619
  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)
2620

    
2621
  return calendar.timegm(tt.utctimetuple())
2622

    
2623

    
2624
def GetX509CertValidity(cert):
2625
  """Returns the validity period of the certificate.
2626

2627
  @type cert: OpenSSL.crypto.X509
2628
  @param cert: X509 certificate object
2629

2630
  """
2631
  # The get_notBefore and get_notAfter functions are only supported in
2632
  # pyOpenSSL 0.7 and above.
2633
  try:
2634
    get_notbefore_fn = cert.get_notBefore
2635
  except AttributeError:
2636
    not_before = None
2637
  else:
2638
    not_before_asn1 = get_notbefore_fn()
2639

    
2640
    if not_before_asn1 is None:
2641
      not_before = None
2642
    else:
2643
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)
2644

    
2645
  try:
2646
    get_notafter_fn = cert.get_notAfter
2647
  except AttributeError:
2648
    not_after = None
2649
  else:
2650
    not_after_asn1 = get_notafter_fn()
2651

    
2652
    if not_after_asn1 is None:
2653
      not_after = None
2654
    else:
2655
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)
2656

    
2657
  return (not_before, not_after)
2658

    
2659

    
2660
def _VerifyCertificateInner(expired, not_before, not_after, now,
2661
                            warn_days, error_days):
2662
  """Verifies certificate validity.
2663

2664
  @type expired: bool
2665
  @param expired: Whether pyOpenSSL considers the certificate as expired
2666
  @type not_before: number or None
2667
  @param not_before: Unix timestamp before which certificate is not valid
2668
  @type not_after: number or None
2669
  @param not_after: Unix timestamp after which certificate is invalid
2670
  @type now: number
2671
  @param now: Current time as Unix timestamp
2672
  @type warn_days: number or None
2673
  @param warn_days: How many days before expiration a warning should be reported
2674
  @type error_days: number or None
2675
  @param error_days: How many days before expiration an error should be reported
2676

2677
  """
2678
  if expired:
2679
    msg = "Certificate is expired"
2680

    
2681
    if not_before is not None and not_after is not None:
2682
      msg += (" (valid from %s to %s)" %
2683
              (FormatTimestampWithTZ(not_before),
2684
               FormatTimestampWithTZ(not_after)))
2685
    elif not_before is not None:
2686
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
2687
    elif not_after is not None:
2688
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)
2689

    
2690
    return (CERT_ERROR, msg)
2691

    
2692
  elif not_before is not None and not_before > now:
2693
    return (CERT_WARNING,
2694
            "Certificate not yet valid (valid from %s)" %
2695
            FormatTimestampWithTZ(not_before))
2696

    
2697
  elif not_after is not None:
2698
    remaining_days = int((not_after - now) / (24 * 3600))
2699

    
2700
    msg = "Certificate expires in about %d days" % remaining_days
2701

    
2702
    if error_days is not None and remaining_days <= error_days:
2703
      return (CERT_ERROR, msg)
2704

    
2705
    if warn_days is not None and remaining_days <= warn_days:
2706
      return (CERT_WARNING, msg)
2707

    
2708
  return (None, None)
2709

    
2710

    
2711
def VerifyX509Certificate(cert, warn_days, error_days):
2712
  """Verifies a certificate for LUVerifyCluster.
2713

2714
  @type cert: OpenSSL.crypto.X509
2715
  @param cert: X509 certificate object
2716
  @type warn_days: number or None
2717
  @param warn_days: How many days before expiration a warning should be reported
2718
  @type error_days: number or None
2719
  @param error_days: How many days before expiration an error should be reported
2720

2721
  """
2722
  # Depending on the pyOpenSSL version, this can just return (None, None)
2723
  (not_before, not_after) = GetX509CertValidity(cert)
2724

    
2725
  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
2726
                                 time.time(), warn_days, error_days)
2727

    
2728

    
2729
def SignX509Certificate(cert, key, salt):
2730
  """Sign a X509 certificate.
2731

2732
  An RFC822-like signature header is added in front of the certificate.
2733

2734
  @type cert: OpenSSL.crypto.X509
2735
  @param cert: X509 certificate object
2736
  @type key: string
2737
  @param key: Key for HMAC
2738
  @type salt: string
2739
  @param salt: Salt for HMAC
2740
  @rtype: string
2741
  @return: Serialized and signed certificate in PEM format
2742

2743
  """
2744
  if not VALID_X509_SIGNATURE_SALT.match(salt):
2745
    raise errors.GenericError("Invalid salt: %r" % salt)
2746

    
2747
  # Dumping as PEM here ensures the certificate is in a sane format
2748
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2749

    
2750
  return ("%s: %s/%s\n\n%s" %
2751
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
2752
           Sha1Hmac(key, cert_pem, salt=salt),
2753
           cert_pem))
2754

    
2755

    
2756
def _ExtractX509CertificateSignature(cert_pem):
2757
  """Helper function to extract signature from X509 certificate.
2758

2759
  """
2760
  # Extract signature from original PEM data
2761
  for line in cert_pem.splitlines():
2762
    if line.startswith("---"):
2763
      break
2764

    
2765
    m = X509_SIGNATURE.match(line.strip())
2766
    if m:
2767
      return (m.group("salt"), m.group("sign"))
2768

    
2769
  raise errors.GenericError("X509 certificate signature is missing")
2770

    
2771

    
2772
def LoadSignedX509Certificate(cert_pem, key):
2773
  """Verifies a signed X509 certificate.
2774

2775
  @type cert_pem: string
2776
  @param cert_pem: Certificate in PEM format and with signature header
2777
  @type key: string
2778
  @param key: Key for HMAC
2779
  @rtype: tuple; (OpenSSL.crypto.X509, string)
2780
  @return: X509 certificate object and salt
2781

2782
  """
2783
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)
2784

    
2785
  # Load certificate
2786
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
2787

    
2788
  # Dump again to ensure it's in a sane format
2789
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2790

    
2791
  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
2792
    raise errors.GenericError("X509 certificate signature is invalid")
2793

    
2794
  return (cert, salt)
2795

    
2796

    
2797
def Sha1Hmac(key, text, salt=None):
2798
  """Calculates the HMAC-SHA1 digest of a text.
2799

2800
  HMAC is defined in RFC2104.
2801

2802
  @type key: string
2803
  @param key: Secret key
2804
  @type text: string
  @param text: Text to compute the digest of
2805

2806
  """
2807
  if salt:
2808
    salted_text = salt + text
2809
  else:
2810
    salted_text = text
2811

    
2812
  return hmac.new(key, salted_text, compat.sha1).hexdigest()
2813

    
2814

    
2815
def VerifySha1Hmac(key, text, digest, salt=None):
2816
  """Verifies the HMAC-SHA1 digest of a text.
2817

2818
  HMAC is defined in RFC2104.
2819

2820
  @type key: string
2821
  @param key: Secret key
2822
  @type text: string
  @param text: Text whose digest is to be verified
2823
  @type digest: string
2824
  @param digest: Expected digest
2825
  @rtype: bool
2826
  @return: Whether HMAC-SHA1 digest matches
2827

2828
  """
2829
  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()
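# Example (illustrative; key, text and salt are hypothetical values):
#   >>> digest = Sha1Hmac("secret-key", "payload", salt="abc123")
#   >>> VerifySha1Hmac("secret-key", "payload", digest, salt="abc123")
#   True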
2830

    
2831

    
2832
def SafeEncode(text):
2833
  """Return a 'safe' version of a source string.
2834

2835
  This function mangles the input string and returns a version that
2836
  should be safe to display/encode as ASCII. To this end, we first
2837
  convert it to ASCII using the 'backslashreplace' encoding which
2838
  should get rid of any non-ASCII chars, and then we process it
2839
  through a loop copied from the string repr sources in Python; we
  don't use string_escape anymore since it escapes single quotes and
  backslashes too, which is too much; moreover, that escaping is not
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
2843

2844
  @type text: str or unicode
2845
  @param text: input data
2846
  @rtype: str
2847
  @return: a safe version of text
2848

2849
  """
2850
  if isinstance(text, unicode):
2851
    # only if unicode; if str already, we handle it below
2852
    text = text.encode('ascii', 'backslashreplace')
2853
  resu = ""
2854
  for char in text:
2855
    c = ord(char)
2856
    if char  == '\t':
2857
      resu += r'\t'
2858
    elif char == '\n':
2859
      resu += r'\n'
2860
    elif char == '\r':
2861
      resu += r'\r'
2862
    elif c < 32 or c >= 127: # non-printable
2863
      resu += "\\x%02x" % (c & 0xff)
2864
    else:
2865
      resu += char
2866
  return resu
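# Example (illustrative, doctest-style):
#   >>> SafeEncode("hello\nworld\x07")
#   'hello\\nworld\\x07'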
2867

    
2868

    
2869
def UnescapeAndSplit(text, sep=","):
2870
  """Split and unescape a string based on a given separator.
2871

2872
  This function splits a string based on a separator where the
  separator itself can be escaped in order to become part of one of
  the elements. The escaping rules are (assuming comma is the
  separator):
2876
    - a plain , separates the elements
2877
    - a sequence \\\\, (double backslash plus comma) is handled as a
2878
      backslash plus a separator comma
2879
    - a sequence \, (backslash plus comma) is handled as a
2880
      non-separator comma
2881

2882
  @type text: string
2883
  @param text: the string to split
2884
  @type sep: string
2885
  @param sep: the separator
  @rtype: list
2887
  @return: a list of strings
2888

2889
  """
2890
  # we split the list by sep (with no escaping at this stage)
2891
  slist = text.split(sep)
2892
  # next, we revisit the elements and if any of them ended with an odd
2893
  # number of backslashes, then we join it with the next
2894
  rlist = []
2895
  while slist:
2896
    e1 = slist.pop(0)
2897
    if e1.endswith("\\"):
2898
      num_b = len(e1) - len(e1.rstrip("\\"))
2899
      if num_b % 2 == 1:
2900
        e2 = slist.pop(0)
2901
        # here the backslashes remain (all), and will be reduced in
2902
        # the next step
2903
        rlist.append(e1 + sep + e2)
2904
        continue
2905
    rlist.append(e1)
2906
  # finally, replace backslash-something with something
2907
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
2908
  return rlist
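# Examples (illustrative, doctest-style; note that "a\\,b,c" is the Python
# literal for the six-character string a\,b,c):
#   >>> UnescapeAndSplit("a\\,b,c")
#   ['a,b', 'c']
#   >>> UnescapeAndSplit("a\\\\,b,c")
#   ['a\\', 'b', 'c']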
2909

    
2910

    
2911
def CommaJoin(names):
2912
  """Nicely join a set of identifiers.
2913

2914
  @param names: set, list or tuple
2915
  @return: a string with the formatted results
2916

2917
  """
2918
  return ", ".join([str(val) for val in names])
2919

    
2920

    
2921
def BytesToMebibyte(value):
2922
  """Converts bytes to mebibytes.
2923

2924
  @type value: int
2925
  @param value: Value in bytes
2926
  @rtype: int
2927
  @return: Value in mebibytes
2928

2929
  """
2930
  return int(round(value / (1024.0 * 1024.0), 0))
2931

    
2932

    
2933
def CalculateDirectorySize(path):
2934
  """Calculates the size of a directory recursively.
2935

2936
  @type path: string
2937
  @param path: Path to directory
2938
  @rtype: int
2939
  @return: Size in mebibytes
2940

2941
  """
2942
  size = 0
2943

    
2944
  for (curpath, _, files) in os.walk(path):
2945
    for filename in files:
2946
      st = os.lstat(PathJoin(curpath, filename))
2947
      size += st.st_size
2948

    
2949
  return BytesToMebibyte(size)
2950

    
2951

    
2952
def GetMounts(filename=constants.PROC_MOUNTS):
2953
  """Returns the list of mounted filesystems.
2954

2955
  This function is Linux-specific.
2956

2957
  @param filename: path of mounts file (/proc/mounts by default)
2958
  @rtype: list of tuples
2959
  @return: list of mount entries (device, mountpoint, fstype, options)
2960

2961
  """
2962
  # TODO(iustin): investigate non-Linux options (e.g. via mount output)
2963
  data = []
2964
  mountlines = ReadFile(filename).splitlines()
2965
  for line in mountlines:
2966
    device, mountpoint, fstype, options, _ = line.split(None, 4)
2967
    data.append((device, mountpoint, fstype, options))
2968

    
2969
  return data
2970

    
2971

    
2972
def GetFilesystemStats(path):
2973
  """Returns the total and free space on a filesystem.
2974

2975
  @type path: string
2976
  @param path: Path on filesystem to be examined
2977
  @rtype: tuple
2978
  @return: tuple of (Total space, Free space) in mebibytes
2979

2980
  """
2981
  st = os.statvfs(path)
2982

    
2983
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
2984
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
2985
  return (tsize, fsize)
2986

    
2987

    
2988
def RunInSeparateProcess(fn, *args):
2989
  """Runs a function in a separate process.
2990

2991
  Note: Only boolean return values are supported.
2992

2993
  @type fn: callable
2994
  @param fn: Function to be called
2995
  @rtype: bool
2996
  @return: Function's result
2997

2998
  """
2999
  pid = os.fork()
3000
  if pid == 0:
3001
    # Child process
3002
    try:
3003
      # In case the function uses temporary files
3004
      ResetTempfileModule()
3005

    
3006
      # Call function
3007
      result = int(bool(fn(*args)))
3008
      assert result in (0, 1)
3009
    except: # pylint: disable-msg=W0702
3010
      logging.exception("Error while calling function in separate process")
3011
      # 0 and 1 are reserved for the return value
3012
      result = 33
3013

    
3014
    os._exit(result) # pylint: disable-msg=W0212
3015

    
3016
  # Parent process
3017

    
3018
  # Avoid zombies and check exit code
3019
  (_, status) = os.waitpid(pid, 0)
3020

    
3021
  if os.WIFSIGNALED(status):
3022
    exitcode = None
3023
    signum = os.WTERMSIG(status)
3024
  else:
3025
    exitcode = os.WEXITSTATUS(status)
3026
    signum = None
3027

    
3028
  if not (exitcode in (0, 1) and signum is None):
3029
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
3030
                              (exitcode, signum))
3031

    
3032
  return bool(exitcode)
3033

    
3034

    
3035
def IgnoreProcessNotFound(fn, *args, **kwargs):
3036
  """Ignores ESRCH when calling a process-related function.
3037

3038
  ESRCH is raised when a process is not found.
3039

3040
  @rtype: bool
3041
  @return: Whether process was found
3042

3043
  """
3044
  try:
3045
    fn(*args, **kwargs)
3046
  except EnvironmentError, err:
3047
    # Ignore ESRCH
3048
    if err.errno == errno.ESRCH:
3049
      return False
3050
    raise
3051

    
3052
  return True
3053

    
3054

    
3055
def IgnoreSignals(fn, *args, **kwargs):
3056
  """Tries to call a function ignoring failures due to EINTR.
3057

3058
  """
3059
  try:
3060
    return fn(*args, **kwargs)
3061
  except EnvironmentError, err:
3062
    if err.errno == errno.EINTR:
3063
      return None
3064
    else:
3065
      raise
3066
  except (select.error, socket.error), err:
3067
    # In python 2.6 and above select.error is an IOError, so it's handled
3068
    # above, in 2.5 and below it's not, and it's handled here.
3069
    if err.args and err.args[0] == errno.EINTR:
3070
      return None
3071
    else:
3072
      raise
3073

    
3074

    
3075
def LockFile(fd):
3076
  """Locks a file using POSIX locks.
3077

3078
  @type fd: int
3079
  @param fd: the file descriptor we need to lock
3080

3081
  """
3082
  try:
3083
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
3084
  except IOError, err:
3085
    if err.errno == errno.EAGAIN:
3086
      raise errors.LockError("File already locked")
3087
    raise
3088

    
3089

    
3090
def FormatTime(val):
3091
  """Formats a time value.
3092

3093
  @type val: float or None
3094
  @param val: the timestamp as returned by time.time()
3095
  @return: a string value or N/A if we don't have a valid timestamp
3096

3097
  """
3098
  if val is None or not isinstance(val, (int, float)):
3099
    return "N/A"
3100
  # these two format codes work on Linux, but they are not guaranteed on all
3101
  # platforms
3102
  return time.strftime("%F %T", time.localtime(val))
3103

    
3104

    
3105
def FormatSeconds(secs):
3106
  """Formats seconds for easier reading.
3107

3108
  @type secs: number
3109
  @param secs: Number of seconds
3110
  @rtype: string
3111
  @return: Formatted seconds (e.g. "2d 9h 19m 49s")
3112

3113
  """
3114
  parts = []
3115

    
3116
  secs = round(secs, 0)
3117

    
3118
  if secs > 0:
3119
    # Negative values would be a bit tricky
3120
    for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]:
3121
      (complete, secs) = divmod(secs, one)
3122
      if complete or parts:
3123
        parts.append("%d%s" % (complete, unit))
3124

    
3125
  parts.append("%ds" % secs)
3126

    
3127
  return " ".join(parts)
3128

    
3129

    
3130
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
3131
  """Reads the watcher pause file.
3132

3133
  @type filename: string
3134
  @param filename: Path to watcher pause file
3135
  @type now: None, float or int
3136
  @param now: Current time as Unix timestamp
3137
  @type remove_after: int
3138
  @param remove_after: Remove watcher pause file after specified amount of
3139
    seconds past the pause end time
3140

3141
  """
3142
  if now is None:
3143
    now = time.time()
3144

    
3145
  try:
3146
    value = ReadFile(filename)
3147
  except IOError, err:
3148
    if err.errno != errno.ENOENT:
3149
      raise
3150
    value = None
3151

    
3152
  if value is not None:
3153
    try:
3154
      value = int(value)
3155
    except ValueError:
3156
      logging.warning(("Watcher pause file (%s) contains invalid value,"
3157
                       " removing it"), filename)
3158
      RemoveFile(filename)
3159
      value = None
3160

    
3161
    if value is not None:
3162
      # Remove file if it's outdated
3163
      if now > (value + remove_after):
3164
        RemoveFile(filename)
3165
        value = None
3166

    
3167
      elif now > value:
3168
        value = None
3169

    
3170
  return value
3171

    
3172

    
3173
class RetryTimeout(Exception):
3174
  """Retry loop timed out.
3175

3176
  Any arguments which were passed by the retried function to RetryAgain will be
  preserved in RetryTimeout, if it is raised. If such an argument was an
  exception, the RaiseInner helper method will reraise it.
3179

3180
  """
3181
  def RaiseInner(self):
3182
    if self.args and isinstance(self.args[0], Exception):
3183
      raise self.args[0]
3184
    else:
3185
      raise RetryTimeout(*self.args)
3186

    
3187

    
3188
class RetryAgain(Exception):
3189
  """Retry again.
3190

3191
  Any arguments passed to RetryAgain will be preserved, if a timeout occurs, as
3192
  arguments to RetryTimeout. If an exception is passed, the RaiseInner() method
  of the resulting RetryTimeout exception can be used to reraise it.
3194

3195
  """
3196

    
3197

    
3198
class _RetryDelayCalculator(object):
3199
  """Calculator for increasing delays.
3200

3201
  """
3202
  __slots__ = [
3203
    "_factor",
3204
    "_limit",
3205
    "_next",
3206
    "_start",
3207
    ]
3208

    
3209
  def __init__(self, start, factor, limit):
3210
    """Initializes this class.
3211

3212
    @type start: float
3213
    @param start: Initial delay
3214
    @type factor: float
3215
    @param factor: Factor for delay increase
3216
    @type limit: float or None
3217
    @param limit: Upper limit for delay or None for no limit
3218

3219
    """
3220
    assert start > 0.0
3221
    assert factor >= 1.0
3222
    assert limit is None or limit >= 0.0
3223

    
3224
    self._start = start
3225
    self._factor = factor
3226
    self._limit = limit
3227

    
3228
    self._next = start
3229

    
3230
  def __call__(self):
3231
    """Returns current delay and calculates the next one.
3232

3233
    """
3234
    current = self._next
3235

    
3236
    # Update for next run
3237
    if self._limit is None:
      # No upper limit; avoid min(None, ...), which would yield None here
      self._next *= self._factor
    elif self._next < self._limit:
      self._next = min(self._limit, self._next * self._factor)
3239

    
3240
    return current
3241

    
3242

    
3243
#: Special delay to specify whole remaining timeout
3244
RETRY_REMAINING_TIME = object()
3245

    
3246

    
3247
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
3248
          _time_fn=time.time):
3249
  """Call a function repeatedly until it succeeds.
3250

3251
  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
3252
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
3253
  total of C{timeout} seconds, this function throws L{RetryTimeout}.
3254

3255
  C{delay} can be one of the following:
3256
    - callable returning the delay length as a float
3257
    - Tuple of (start, factor, limit)
3258
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
3259
      useful when overriding L{wait_fn} to wait for an external event)
3260
    - A static delay as a number (int or float)
3261

3262
  @type fn: callable
3263
  @param fn: Function to be called
3264
  @param delay: Either a callable (returning the delay), a tuple of (start,
3265
                factor, limit) (see L{_RetryDelayCalculator}),
3266
                L{RETRY_REMAINING_TIME} or a number (int or float)
3267
  @type timeout: float
3268
  @param timeout: Total timeout
3269
  @type wait_fn: callable
3270
  @param wait_fn: Waiting function
3271
  @return: Return value of function
3272

3273
  """
3274
  assert callable(fn)
3275
  assert callable(wait_fn)
3276
  assert callable(_time_fn)
3277

    
3278
  if args is None:
3279
    args = []
3280

    
3281
  end_time = _time_fn() + timeout
3282

    
3283
  if callable(delay):
3284
    # External function to calculate delay
3285
    calc_delay = delay
3286

    
3287
  elif isinstance(delay, (tuple, list)):
3288
    # Increasing delay with optional upper boundary
3289
    (start, factor, limit) = delay
3290
    calc_delay = _RetryDelayCalculator(start, factor, limit)
3291

    
3292
  elif delay is RETRY_REMAINING_TIME:
3293
    # Always use the remaining time
3294
    calc_delay = None
3295

    
3296
  else:
3297
    # Static delay
3298
    calc_delay = lambda: delay
3299

    
3300
  assert calc_delay is None or callable(calc_delay)
3301

    
3302
  while True:
3303
    retry_args = []
3304
    try:
3305
      # pylint: disable-msg=W0142
3306
      return fn(*args)
3307
    except RetryAgain, err:
3308
      retry_args = err.args
3309
    except RetryTimeout:
3310
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
3311
                                   " handle RetryTimeout")
3312

    
3313
    remaining_time = end_time - _time_fn()
3314

    
3315
    if remaining_time < 0.0:
3316
      # pylint: disable-msg=W0142
3317
      raise RetryTimeout(*retry_args)
3318

    
3319
    assert remaining_time >= 0.0
3320

    
3321
    if calc_delay is None:
3322
      wait_fn(remaining_time)
3323
    else:
3324
      current_delay = calc_delay()
3325
      if current_delay > 0.0:
3326
        wait_fn(current_delay)
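# Usage sketch (illustrative; _IsServiceReady is a hypothetical predicate):
#   def _CheckReady():
#     if not _IsServiceReady():
#       raise RetryAgain()
#   # Poll with a delay starting at 1s and growing by 1.5x up to 5s, for at
#   # most 30 seconds; raises RetryTimeout if the service never becomes ready.
#   Retry(_CheckReady, (1.0, 1.5, 5.0), 30.0)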
3327

    
3328

    
3329
def GetClosedTempfile(*args, **kwargs):
3330
  """Creates a temporary file and returns its path.
3331

3332
  """
3333
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
3334
  _CloseFDNoErr(fd)
3335
  return path
3336

    
3337

    
3338
def GenerateSelfSignedX509Cert(common_name, validity):
3339
  """Generates a self-signed X509 certificate.
3340

3341
  @type common_name: string
3342
  @param common_name: commonName value
3343
  @type validity: int
3344
  @param validity: Validity for certificate in seconds
3345

3346
  """
3347
  # Create private and public key
3348
  key = OpenSSL.crypto.PKey()
3349
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)
3350

    
3351
  # Create self-signed certificate
3352
  cert = OpenSSL.crypto.X509()
3353
  if common_name:
3354
    cert.get_subject().CN = common_name
3355
  cert.set_serial_number(1)
3356
  cert.gmtime_adj_notBefore(0)
3357
  cert.gmtime_adj_notAfter(validity)
3358
  cert.set_issuer(cert.get_subject())
3359
  cert.set_pubkey(key)
3360
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)
3361

    
3362
  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
3363
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
3364

    
3365
  return (key_pem, cert_pem)
3366

    
3367

    
3368
def GenerateSelfSignedSslCert(filename, common_name=constants.X509_CERT_CN,
3369
                              validity=constants.X509_CERT_DEFAULT_VALIDITY):
3370
  """Legacy function to generate self-signed X509 certificate.
3371

3372
  @type filename: str
3373
  @param filename: path to write certificate to
3374
  @type common_name: string
3375
  @param common_name: commonName value
3376
  @type validity: int
3377
  @param validity: validity of certificate in number of days
3378

3379
  """
3380
  # TODO: Investigate using the cluster name instead of X509_CERT_CN for
3381
  # common_name, as cluster-renames are very seldom, and it'd be nice if RAPI
3382
  # and node daemon certificates have the proper Subject/Issuer.
3383
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(common_name,
3384
                                                   validity * 24 * 60 * 60)
3385

    
3386
  WriteFile(filename, mode=0400, data=key_pem + cert_pem)
3387

    
3388

    
3389
class FileLock(object):
3390
  """Utility class for file locks.
3391

3392
  """
3393
  def __init__(self, fd, filename):
3394
    """Constructor for FileLock.
3395

3396
    @type fd: file
3397
    @param fd: File object
3398
    @type filename: str
3399
    @param filename: Path of the file opened at I{fd}
3400

3401
    """
3402
    self.fd = fd
3403
    self.filename = filename
3404

    
3405
  @classmethod
3406
  def Open(cls, filename):
3407
    """Creates and opens a file to be used as a file-based lock.
3408

3409
    @type filename: string
3410
    @param filename: path to the file to be locked
3411

3412
    """
3413
    # Using "os.open" is necessary to allow both opening existing file
3414
    # read/write and creating if not existing. Vanilla "open" will truncate an
3415
    # existing file -or- allow creating if not existing.
3416
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
3417
               filename)
3418

    
3419
  def __del__(self):
3420
    self.Close()
3421

    
3422
  def Close(self):
3423
    """Close the file and release the lock.
3424

3425
    """
3426
    if hasattr(self, "fd") and self.fd:
3427
      self.fd.close()
3428
      self.fd = None
3429

    
3430
  def _flock(self, flag, blocking, timeout, errmsg):
3431
    """Wrapper for fcntl.flock.
3432

3433
    @type flag: int
3434
    @param flag: operation flag
3435
    @type blocking: bool
3436
    @param blocking: whether the operation should be done in blocking mode.
3437
    @type timeout: None or float
3438
    @param timeout: for how long the operation should be retried (implies
3439
                    non-blocking mode).
3440
    @type errmsg: string
3441
    @param errmsg: error message in case operation fails.
3442

3443
    """
3444
    assert self.fd, "Lock was closed"
    assert timeout is None or timeout >= 0, \
      "If specified, timeout must not be negative"
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)


class LineSplitter:
  """Splits data chunks into lines separated by newline.

  Instances provide a file-like interface.

  """
  def __init__(self, line_fn, *args):
    """Initializes this class.

    @type line_fn: callable
    @param line_fn: Function called for each line, first parameter is line
    @param args: Extra arguments for L{line_fn}

    """
    assert callable(line_fn)

    if args:
      # Python 2.4 doesn't have functools.partial yet
      self._line_fn = \
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
    else:
      self._line_fn = line_fn

    self._lines = collections.deque()
    self._buffer = ""

  def write(self, data):
    parts = (self._buffer + data).split("\n")
    self._buffer = parts.pop()
    self._lines.extend(parts)

  def flush(self):
    while self._lines:
      self._line_fn(self._lines.popleft().rstrip("\r\n"))

  def close(self):
    self.flush()
    if self._buffer:
      self._line_fn(self._buffer)
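

# Example sketch (hypothetical, illustrative only): feeding arbitrary data
# chunks to LineSplitter and collecting the complete lines it emits.
def _ExampleLineSplitterUsage():
  """Sketch: split chunked input into lines via the file-like interface."""
  collected = []
  splitter = LineSplitter(collected.append)
  # Chunks may end in the middle of a line; the partial tail is buffered
  splitter.write("first line\r\nsecond ")
  splitter.write("line\nlast line without newline")
  # close() flushes complete lines and emits the remaining buffer, if any
  splitter.close()
  # collected == ["first line", "second line", "last line without newline"]
  return collected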


def SignalHandled(signums):
  """Signal Handled decoration.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      assert 'signal_handlers' not in kwargs or \
             kwargs['signal_handlers'] is None or \
             isinstance(kwargs['signal_handlers'], dict), \
             "Wrong signal_handlers parameter in original function call"
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
        signal_handlers = kwargs['signal_handlers']
      else:
        signal_handlers = {}
        kwargs['signal_handlers'] = signal_handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          signal_handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        sighandler.Reset()
    return sig_function
  return wrap
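

# Example sketch (hypothetical, not used by Ganeti itself): decorating a
# long-running function so it can poll whether SIGTERM was received. The
# decorated function receives the installed handlers via 'signal_handlers'.
@SignalHandled([signal.SIGTERM])
def _ExampleSignalHandledLoop(iterations, signal_handlers=None):
  """Sketch: stop a loop early once SIGTERM has been delivered."""
  handler = signal_handlers[signal.SIGTERM]
  done = 0
  for _ in range(iterations):
    if handler.called:
      break
    # ... one unit of interruptible work ...
    done += 1
  return done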


class SignalWakeupFd(object):
  """Wrapper around C{signal.set_wakeup_fd} using a pipe.

  The write end of the pipe is registered as the wakeup file descriptor,
  while the read end can be polled or read to detect delivered signals.

  """
  try:
    # This is only supported in Python 2.5 and above (some distributions
    # backported it to Python 2.4)
    _set_wakeup_fd_fn = signal.set_wakeup_fd
  except AttributeError:
    # Not supported
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
      return -1
  else:
    def _SetWakeupFd(self, fd):
      return self._set_wakeup_fd_fn(fd)

  def __init__(self):
    """Initializes this class.

    """
    (read_fd, write_fd) = os.pipe()

    # Once these succeeded, the file descriptors will be closed automatically.
    # Buffer size 0 is important, otherwise .read() with a specified length
    # might buffer data and the file descriptors won't be marked readable.
    self._read_fh = os.fdopen(read_fd, "r", 0)
    self._write_fh = os.fdopen(write_fd, "w", 0)

    self._previous = self._SetWakeupFd(self._write_fh.fileno())

    # Utility functions
    self.fileno = self._read_fh.fileno
    self.read = self._read_fh.read

  def Reset(self):
    """Restores the previous wakeup file descriptor.

    """
    if hasattr(self, "_previous") and self._previous is not None:
      self._SetWakeupFd(self._previous)
      self._previous = None

  def Notify(self):
    """Notifies the wakeup file descriptor.

    """
    self._write_fh.write("\0")

  def __del__(self):
    """Called before object deletion.

    """
    self.Reset()
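

# Example sketch (hypothetical, illustrative only): the read end of
# SignalWakeupFd can be watched with select(); here Notify() is used to
# make it readable immediately, so select() does not block.
def _ExampleSignalWakeupNotify():
  """Sketch: the write side wakes up anything polling the read side."""
  wakeup = SignalWakeupFd()
  try:
    wakeup.Notify()
    # Notify() wrote one "\0" byte, so the read end is immediately readable
    (rlist, _, _) = select.select([wakeup.fileno()], [], [], 0)
    assert rlist, "Wakeup byte should be pending"
    return wakeup.read(1)
  finally:
    wakeup.Reset()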


class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when destructed or
  when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: list
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None, wakeup=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function
    @type wakeup: None or L{SignalWakeupFd}
    @param wakeup: if not None, notified via its C{Notify} method whenever
        one of the signals is delivered

    """
    assert handler_fn is None or callable(handler_fn)

    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn
    self._wakeup = wakeup

    self._previous = {}
    try:
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._wakeup:
      # Notify whoever is interested in signals
      self._wakeup.Notify()

    if self._handler_fn:
      self._handler_fn(signum, frame)
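

# Example sketch (hypothetical, not part of the original module): using
# SignalHandler's L{called} flag instead of a custom handler function.
def _ExampleSignalHandlerUsage():
  """Sketch: record whether SIGUSR1 was delivered while working."""
  handler = SignalHandler([signal.SIGUSR1])
  try:
    # ... do some work; a SIGUSR1 arriving meanwhile sets handler.called ...
    if handler.called:
      handler.Clear()  # allow the next delivery to be detected again
      return True
    return False
  finally:
    handler.Reset()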


class FieldSet(object):
  """A simple field set.

  Among the features are:
    - checking if a string is among a list of static string or regex objects
    - checking if a whole list of strings matches
    - returning the matching groups from a regex match

  Internally, all fields are held as regular expression objects.

  """
  def __init__(self, *items):
    self.items = [re.compile("^%s$" % value) for value in items]

  def Extend(self, other_set):
    """Extend the field set with the items from another one"""
    self.items.extend(other_set.items)

  def Matches(self, field):
    """Checks if a field matches the current set

    @type field: str
    @param field: the string to match
    @return: either None or a regular expression match object

    """
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
      return m
    return None

  def NonMatching(self, items):
    """Returns the list of fields not matching the current set

    @type items: list
    @param items: the list of fields to check
    @rtype: list
    @return: list of non-matching fields

    """
    return [val for val in items if not self.Matches(val)]
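

# Example sketch (hypothetical, illustrative only): matching query field
# names, including a parameterized field, against a FieldSet. The field
# names used below are made up for the example.
def _ExampleFieldSetUsage():
  """Sketch: static names plus a regex field with a capture group."""
  fields = FieldSet("name", "size", r"disk\.size/([0-9]+)")
  # Static fields match exactly; regex fields return their match object
  match = fields.Matches("disk.size/2")
  index = match.group(1)  # -> "2"
  unknown = fields.NonMatching(["name", "mode", "disk.size/0"])  # ["mode"]
  return (index, unknown)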