Statistics
| Branch: | Tag: | Revision:

root / lib / utils.py @ 0070a462

History | View | Annotate | Download (102.3 kB)

1
#
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import os
31
import sys
32
import time
33
import subprocess
34
import re
35
import socket
36
import tempfile
37
import shutil
38
import errno
39
import pwd
40
import itertools
41
import select
42
import fcntl
43
import resource
44
import logging
45
import logging.handlers
46
import signal
47
import OpenSSL
48
import datetime
49
import calendar
50
import hmac
51
import collections
52

    
53
from cStringIO import StringIO
54

    
55
try:
56
  # pylint: disable-msg=F0401
57
  import ctypes
58
except ImportError:
59
  ctypes = None
60

    
61
from ganeti import errors
62
from ganeti import constants
63
from ganeti import compat
64

    
65

    
66
_locksheld = []
67
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
68

    
69
debug_locks = False
70

    
71
#: when set to True, L{RunCmd} is disabled
72
no_fork = False
73

    
74
_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
75

    
76
HEX_CHAR_RE = r"[a-zA-Z0-9]"
77
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
78
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
79
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
80
                             HEX_CHAR_RE, HEX_CHAR_RE),
81
                            re.S | re.I)
82

    
83
_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$")
84

    
85
# Certificate verification results
86
(CERT_WARNING,
87
 CERT_ERROR) = range(1, 3)
88

    
89
# Flags for mlockall() (from bits/mman.h)
90
_MCL_CURRENT = 1
91
_MCL_FUTURE = 2
92

    
93

    
94
class RunResult(object):
  """Encapsulates the outcome of running an external program.

  @type exit_code: int
  @ivar exit_code: the exit code of the program, or None (if the program
      didn't exit())
  @type signal: int or None
  @ivar signal: the signal that caused the program to finish, or None
      (if the program wasn't terminated by a signal)
  @type stdout: str
  @ivar stdout: the standard output of the program
  @type stderr: str
  @ivar stderr: the standard error of the program
  @type failed: boolean
  @ivar failed: True in case the program was
      terminated by a signal or exited with a non-zero exit code
  @ivar fail_reason: a string detailing the termination reason

  """
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
               "failed", "fail_reason", "cmd"]

  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
    self.cmd = cmd
    self.exit_code = exit_code
    self.signal = signal_
    self.stdout = stdout
    self.stderr = stderr
    # A command failed if it was killed by a signal or returned non-zero
    self.failed = not (signal_ is None and exit_code == 0)

    if signal_ is not None:
      self.fail_reason = "terminated by signal %s" % signal_
    elif exit_code is not None:
      self.fail_reason = "exited with exit code %s" % exit_code
    else:
      self.fail_reason = "unable to determine termination reason"

    if self.failed:
      logging.debug("Command '%s' failed (%s); output: %s",
                    self.cmd, self.fail_reason, self.output)

  def _GetOutput(self):
    """Returns the combined stdout and stderr for easier usage.

    """
    return self.stdout + self.stderr

  output = property(_GetOutput, None, None, "Return full output")
143

    
144

    
145
def _BuildCmdEnvironment(env, reset):
146
  """Builds the environment for an external program.
147

148
  """
149
  if reset:
150
    cmd_env = {}
151
  else:
152
    cmd_env = os.environ.copy()
153
    cmd_env["LC_ALL"] = "C"
154

    
155
  if env is not None:
156
    cmd_env.update(env)
157

    
158
  return cmd_env
159

    
160

    
161
def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False,
           interactive=False):
  """Execute a (shell) command.

  The command should not read from its standard input, as it will be
  closed.

  @type cmd: string or list
  @param cmd: Command to run; a string is run through the shell, a list
      is executed directly
  @type env: dict
  @param env: Additional environment variables
  @type output: str
  @param output: if desired, the output of the command can be
      saved in a file instead of the RunResult instance; this
      parameter denotes the file name (if not None)
  @type cwd: string
  @param cwd: if specified, will be used as the working
      directory for the command; the default will be /
  @type reset_env: boolean
  @param reset_env: whether to reset or keep the default os environment
  @type interactive: boolean
  @param interactive: whether we pipe stdin, stdout and stderr
                      (default behaviour) or run the command interactively
  @rtype: L{RunResult}
  @return: RunResult instance
  @raise errors.ProgrammerError: if we call this when forks are disabled

  """
  if no_fork:
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")

  # Saving output to a file and interactive mode both need the child's
  # stdout/stderr, so they are mutually exclusive
  if output and interactive:
    raise errors.ProgrammerError("Parameters 'output' and 'interactive' can"
                                 " not be provided at the same time")

  if isinstance(cmd, basestring):
    # A plain string is handed to the shell as-is
    strcmd = cmd
    shell = True
  else:
    # A list is executed directly; stringify elements for safety and build
    # a quoted representation only for logging/error messages
    cmd = [str(val) for val in cmd]
    strcmd = ShellQuoteArgs(cmd)
    shell = False

  if output:
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
  else:
    logging.debug("RunCmd %s", strcmd)

  cmd_env = _BuildCmdEnvironment(env, reset_env)

  try:
    if output is None:
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd, interactive)
    else:
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
      out = err = ""
  except OSError, err:
    if err.errno == errno.ENOENT:
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
                               (strcmd, err))
    else:
      raise

  # Popen-style status: negative values mean termination by that signal,
  # non-negative values are the plain exit code
  if status >= 0:
    exitcode = status
    signal_ = None
  else:
    exitcode = None
    signal_ = -status

  return RunResult(exitcode, signal_, out, err, strcmd)
232

    
233

    
234
def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
                pidfile=None):
  """Start a daemon process after forking twice.

  The child communicates back to this process through two pipes: one
  carrying an error message (empty on success) and one carrying the
  daemon's PID.

  @type cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: Additional environment variables
  @type cwd: string
  @param cwd: Working directory for the program
  @type output: string
  @param output: Path to file in which to save the output
  @type output_fd: int
  @param output_fd: File descriptor for output
  @type pidfile: string
  @param pidfile: Process ID file
  @rtype: int
  @return: Daemon process ID
  @raise errors.ProgrammerError: if we call this when forks are disabled

  """
  if no_fork:
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
                                 " disabled")

  # Reject 'output' and 'output_fd' being given together.
  # NOTE(review): when neither is given this check passes and the child
  # falls back to /dev/null (see _StartDaemonChild)
  if output and not (bool(output) ^ (output_fd is not None)):
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
                                 " specified")

  if isinstance(cmd, basestring):
    cmd = ["/bin/sh", "-c", cmd]

  strcmd = ShellQuoteArgs(cmd)

  if output:
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
  else:
    logging.debug("StartDaemon %s", strcmd)

  cmd_env = _BuildCmdEnvironment(env, False)

  # Create pipe for sending PID back
  (pidpipe_read, pidpipe_write) = os.pipe()
  try:
    try:
      # Create pipe for sending error messages
      (errpipe_read, errpipe_write) = os.pipe()
      try:
        try:
          # First fork
          pid = os.fork()
          if pid == 0:
            try:
              # Child process, won't return
              _StartDaemonChild(errpipe_read, errpipe_write,
                                pidpipe_read, pidpipe_write,
                                cmd, cmd_env, cwd,
                                output, output_fd, pidfile)
            finally:
              # Well, maybe child process failed
              os._exit(1) # pylint: disable-msg=W0212
        finally:
          # Close our copy of the write end; the read below only returns
          # once the child's copy is also closed (on exec or error)
          _CloseFDNoErr(errpipe_write)

        # Wait for daemon to be started (or an error message to arrive) and read
        # up to 100 KB as an error message
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
      finally:
        _CloseFDNoErr(errpipe_read)
    finally:
      _CloseFDNoErr(pidpipe_write)

    # Read up to 128 bytes for PID
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
  finally:
    _CloseFDNoErr(pidpipe_read)

  # Try to avoid zombies by waiting for child process
  try:
    os.waitpid(pid, 0)
  except OSError:
    pass

  # An empty error message means the exec succeeded
  if errormsg:
    raise errors.OpExecError("Error when starting daemon process: %r" %
                             errormsg)

  try:
    return int(pidtext)
  except (ValueError, TypeError), err:
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
                             (pidtext, err))
326

    
327

    
328
def _StartDaemonChild(errpipe_read, errpipe_write,
                      pidpipe_read, pidpipe_write,
                      args, env, cwd,
                      output, fd_output, pidfile):
  """Child process for starting daemon.

  Performs the second fork, detaches from the controlling terminal,
  optionally creates and locks a PID file, redirects standard I/O and
  finally exec()s the daemon. On any failure the exception text is sent
  back to the parent through C{errpipe_write}. This function never
  returns; it either exec()s or calls C{os._exit}.

  @param errpipe_read: parent's end of the error pipe (closed here)
  @param errpipe_write: write end used to report errors to the parent
  @param pidpipe_read: parent's end of the PID pipe (closed here)
  @param pidpipe_write: write end used to send the daemon PID back
  @param args: argument list for the program to execute
  @param env: environment dict for the program, or None to inherit
  @param cwd: working directory for the daemon
  @param output: path of the output file, if any
  @param fd_output: already-open output file descriptor, if any
  @param pidfile: path of the PID file to create and lock, if any

  """
  try:
    # Close parent's side
    _CloseFDNoErr(errpipe_read)
    _CloseFDNoErr(pidpipe_read)

    # First child process
    os.chdir("/")
    os.umask(077)
    os.setsid()

    # And fork for the second time
    pid = os.fork()
    if pid != 0:
      # Exit first child process
      os._exit(0) # pylint: disable-msg=W0212

    # Make sure pipe is closed on execv* (and thereby notifies original process)
    SetCloseOnExecFlag(errpipe_write, True)

    # List of file descriptors to be left open
    noclose_fds = [errpipe_write]

    # Open PID file
    if pidfile:
      try:
        # TODO: Atomic replace with another locked file instead of writing into
        # it after creating
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)

        # Lock the PID file (and fail if not possible to do so). Any code
        # wanting to send a signal to the daemon should try to lock the PID
        # file before reading it. If acquiring the lock succeeds, the daemon is
        # no longer running and the signal should not be sent.
        LockFile(fd_pidfile)

        os.write(fd_pidfile, "%d\n" % os.getpid())
      except Exception, err:
        raise Exception("Creating and locking PID file failed: %s" % err)

      # Keeping the file open to hold the lock
      noclose_fds.append(fd_pidfile)

      SetCloseOnExecFlag(fd_pidfile, False)
    else:
      fd_pidfile = None

    # Open /dev/null
    fd_devnull = os.open(os.devnull, os.O_RDWR)

    assert not output or (bool(output) ^ (fd_output is not None))

    if fd_output is not None:
      pass
    elif output:
      # Open output file
      try:
        # TODO: Implement flag to set append=yes/no
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
      except EnvironmentError, err:
        raise Exception("Opening output file failed: %s" % err)
    else:
      fd_output = fd_devnull

    # Redirect standard I/O
    os.dup2(fd_devnull, 0)
    os.dup2(fd_output, 1)
    os.dup2(fd_output, 2)

    # Send daemon PID to parent
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))

    # Close all file descriptors except stdio and error message pipe
    CloseFDs(noclose_fds=noclose_fds)

    # Change working directory
    os.chdir(cwd)

    if env is None:
      os.execvp(args[0], args)
    else:
      os.execvpe(args[0], args, env)
  except: # pylint: disable-msg=W0702
    try:
      # Report errors to original process
      buf = str(sys.exc_info()[1])

      RetryOnSignal(os.write, errpipe_write, buf)
    except: # pylint: disable-msg=W0702
      # Ignore errors in error handling
      pass

  os._exit(1) # pylint: disable-msg=W0212
427

    
428

    
429
def _RunCmdPipe(cmd, env, via_shell, cwd, interactive):
  """Run a command and return its output.

  @type  cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: The environment to use
  @type via_shell: bool
  @param via_shell: if we should run via the shell
  @type cwd: string
  @param cwd: the working directory for the program
  @type interactive: boolean
  @param interactive: Run command interactive (without piping)
  @rtype: tuple
  @return: (out, err, status)

  """
  poller = select.poll()

  stderr = subprocess.PIPE
  stdout = subprocess.PIPE
  stdin = subprocess.PIPE

  if interactive:
    # Interactive mode inherits the caller's stdin/stdout/stderr; nothing
    # is captured and the returned out/err will be empty strings
    stderr = stdout = stdin = None

  child = subprocess.Popen(cmd, shell=via_shell,
                           stderr=stderr,
                           stdout=stdout,
                           stdin=stdin,
                           close_fds=True, env=env,
                           cwd=cwd)

  out = StringIO()
  err = StringIO()
  if not interactive:
    # The child must not wait for input; close its stdin immediately
    child.stdin.close()
    poller.register(child.stdout, select.POLLIN)
    poller.register(child.stderr, select.POLLIN)
    # Map file descriptor -> (buffer to append to, file object to read from)
    fdmap = {
      child.stdout.fileno(): (out, child.stdout),
      child.stderr.fileno(): (err, child.stderr),
      }
    # Non-blocking reads are required: poll() only says "readable", and a
    # blocking read on one pipe could deadlock while the other fills up
    for fd in fdmap:
      SetNonblockFlag(fd, True)

    while fdmap:
      pollresult = RetryOnSignal(poller.poll)

      for fd, event in pollresult:
        if event & select.POLLIN or event & select.POLLPRI:
          data = fdmap[fd][1].read()
          # no data from read signifies EOF (the same as POLLHUP)
          if not data:
            poller.unregister(fd)
            del fdmap[fd]
            continue
          fdmap[fd][0].write(data)
        if (event & select.POLLNVAL or event & select.POLLHUP or
            event & select.POLLERR):
          poller.unregister(fd)
          del fdmap[fd]

  out = out.getvalue()
  err = err.getvalue()

  # Reap the child; negative status means termination by signal
  status = child.wait()
  return out, err, status
497

    
498

    
499
def _RunCmdFile(cmd, env, via_shell, output, cwd):
500
  """Run a command and save its output to a file.
501

502
  @type  cmd: string or list
503
  @param cmd: Command to run
504
  @type env: dict
505
  @param env: The environment to use
506
  @type via_shell: bool
507
  @param via_shell: if we should run via the shell
508
  @type output: str
509
  @param output: the filename in which to save the output
510
  @type cwd: string
511
  @param cwd: the working directory for the program
512
  @rtype: int
513
  @return: the exit status
514

515
  """
516
  fh = open(output, "a")
517
  try:
518
    child = subprocess.Popen(cmd, shell=via_shell,
519
                             stderr=subprocess.STDOUT,
520
                             stdout=fh,
521
                             stdin=subprocess.PIPE,
522
                             close_fds=True, env=env,
523
                             cwd=cwd)
524

    
525
    child.stdin.close()
526
    status = child.wait()
527
  finally:
528
    fh.close()
529
  return status
530

    
531

    
532
def SetCloseOnExecFlag(fd, enable):
  """Sets or unsets the close-on-exec flag on a file descriptor.

  @type fd: int
  @param fd: File descriptor
  @type enable: bool
  @param enable: Whether to set or unset it.

  """
  # Read-modify-write so that other FD flags are preserved
  old_flags = fcntl.fcntl(fd, fcntl.F_GETFD)

  if enable:
    new_flags = old_flags | fcntl.FD_CLOEXEC
  else:
    new_flags = old_flags & ~fcntl.FD_CLOEXEC

  fcntl.fcntl(fd, fcntl.F_SETFD, new_flags)
549

    
550

    
551
def SetNonblockFlag(fd, enable):
  """Sets or unsets the O_NONBLOCK flag on on a file descriptor.

  @type fd: int
  @param fd: File descriptor
  @type enable: bool
  @param enable: Whether to set or unset it

  """
  # Read-modify-write so that other status flags are preserved
  old_flags = fcntl.fcntl(fd, fcntl.F_GETFL)

  if enable:
    new_flags = old_flags | os.O_NONBLOCK
  else:
    new_flags = old_flags & ~os.O_NONBLOCK

  fcntl.fcntl(fd, fcntl.F_SETFL, new_flags)
568

    
569

    
570
def RetryOnSignal(fn, *args, **kwargs):
  """Calls a function again if it failed due to EINTR.

  @param fn: the callable to invoke
  @param args: positional arguments passed to C{fn}
  @param kwargs: keyword arguments passed to C{fn}
  @return: whatever C{fn} returns
  @raise: any exception raised by C{fn} other than an EINTR failure

  """
  while True:
    try:
      return fn(*args, **kwargs)
    except EnvironmentError, err:
      # Retry only on interrupted system calls; re-raise everything else
      if err.errno != errno.EINTR:
        raise
    except (socket.error, select.error), err:
      # In python 2.6 and above select.error is an IOError, so it's handled
      # above, in 2.5 and below it's not, and it's handled here.
      if not (err.args and err.args[0] == errno.EINTR):
        raise
585

    
586

    
587
def RunParts(dir_name, env=None, reset_env=False):
  """Run Scripts or programs in a directory

  @type dir_name: string
  @param dir_name: absolute path to a directory
  @type env: dict
  @param env: The environment to use
  @type reset_env: boolean
  @param reset_env: whether to reset or keep the default os environment
  @rtype: list of tuples
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)

  """
  rr = []

  try:
    dir_contents = ListVisibleFiles(dir_name)
  except OSError, err:
    # An unreadable/missing directory is not fatal; report nothing run
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
    return rr

  # Entries run in sorted (lexicographic) order, like run-parts(8)
  for relname in sorted(dir_contents):
    fname = PathJoin(dir_name, relname)
    # Skip anything that is not an executable regular file with a name
    # matching the plugin mask
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
      rr.append((relname, constants.RUNPARTS_SKIP, None))
    else:
      try:
        result = RunCmd([fname], env=env, reset_env=reset_env)
      except Exception, err: # pylint: disable-msg=W0703
        # A failure in one script must not prevent running the others
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
      else:
        rr.append((relname, constants.RUNPARTS_RUN, result))

  return rr
622

    
623

    
624
def RemoveFile(filename):
  """Remove a file ignoring some errors.

  Remove a file, ignoring non-existing ones or directories. Other
  errors are passed.

  @type filename: str
  @param filename: the file to be removed

  """
  try:
    os.unlink(filename)
  except OSError, err:
    # ENOENT (already gone) and EISDIR (path is a directory) are
    # deliberately ignored; anything else is re-raised
    if err.errno not in (errno.ENOENT, errno.EISDIR):
      raise
639

    
640

    
641
def RemoveDir(dirname):
  """Remove an empty directory.

  Remove a directory, ignoring non-existing ones.
  Other errors are passed. This includes the case,
  where the directory is not empty, so it can't be removed.

  @type dirname: str
  @param dirname: the empty directory to be removed

  """
  try:
    os.rmdir(dirname)
  except OSError, err:
    # Only a missing directory is ignored; ENOTEMPTY etc. propagate
    if err.errno != errno.ENOENT:
      raise
657

    
658

    
659
def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
  """Renames a file.

  @type old: string
  @param old: Original path
  @type new: string
  @param new: New path
  @type mkdir: bool
  @param mkdir: Whether to create target directory if it doesn't exist
  @type mkdir_mode: int
  @param mkdir_mode: Mode for newly created directories
  @raise OSError: if the rename fails for a reason other than a missing
      target directory (or if it still fails after creating it)

  """
  try:
    return os.rename(old, new)
  except OSError, err:
    # In at least one use case of this function, the job queue, directory
    # creation is very rare. Checking for the directory before renaming is not
    # as efficient.
    if mkdir and err.errno == errno.ENOENT:
      # Create directory and try again
      Makedirs(os.path.dirname(new), mode=mkdir_mode)

      return os.rename(old, new)

    raise
685

    
686

    
687
def Makedirs(path, mode=0750):
  """Super-mkdir; create a leaf directory and all intermediate ones.

  This is a wrapper around C{os.makedirs} adding error handling not implemented
  before Python 2.5.

  @type path: string
  @param path: the directory path to create
  @type mode: int
  @param mode: permission mode for newly created directories

  """
  try:
    os.makedirs(path, mode)
  except OSError, err:
    # Ignore EEXIST. This is only handled in os.makedirs as included in
    # Python 2.5 and above.
    if err.errno != errno.EEXIST or not os.path.exists(path):
      raise
701

    
702

    
703
def ResetTempfileModule():
  """Resets the random name generator of the tempfile module.

  This function should be called after C{os.fork} in the child process to
  ensure it creates a newly seeded random generator. Otherwise it would
  generate the same random parts as the parent process. If several processes
  race for the creation of a temporary file, this could lead to one not getting
  a temporary name.

  """
  # pylint: disable-msg=W0212
  if not (hasattr(tempfile, "_once_lock") and
          hasattr(tempfile, "_name_sequence")):
    logging.critical("The tempfile module misses at least one of the"
                     " '_once_lock' and '_name_sequence' attributes")
    return

  tempfile._once_lock.acquire()
  try:
    # Dropping the cached generator forces tempfile to build (and
    # re-seed) a new one on next use
    tempfile._name_sequence = None
  finally:
    tempfile._once_lock.release()
724

    
725

    
726
def _FingerprintFile(filename):
  """Compute the fingerprint of a file.

  If the file does not exist, a None will be returned
  instead.

  @type filename: str
  @param filename: the filename to checksum
  @rtype: str
  @return: the hex digest of the sha checksum of the contents
      of the file

  """
  # Only regular, existing files can be fingerprinted
  if not (os.path.exists(filename) and os.path.isfile(filename)):
    return None

  f = open(filename)
  try:
    fp = compat.sha1_hash()
    # Hash in fixed-size chunks so arbitrarily large files can be
    # fingerprinted without loading them into memory at once
    while True:
      data = f.read(4096)
      if not data:
        break

      fp.update(data)

    return fp.hexdigest()
  finally:
    # Bug fix: the previous implementation never closed the file,
    # leaking one file descriptor per call
    f.close()
753

    
754

    
755
def FingerprintFiles(files):
  """Compute fingerprints for a list of files.

  @type files: list
  @param files: the list of filename to fingerprint
  @rtype: dict
  @return: a dictionary filename: fingerprint, holding only
      existing files

  """
  fingerprints = {}

  for path in files:
    digest = _FingerprintFile(path)
    # Non-existing files yield a None digest and are left out
    if digest:
      fingerprints[path] = digest

  return fingerprints
773

    
774

    
775
def ForceDictType(target, key_types, allowed_values=None):
  """Force the values of a dict to have certain types.

  The dictionary is modified in place: each value is coerced to the
  type declared for its key, or a L{errors.TypeEnforcementError} is
  raised if coercion is impossible.

  @type target: dict
  @param target: the dict to update
  @type key_types: dict
  @param key_types: dict mapping target dict keys to types
                    in constants.ENFORCEABLE_TYPES
  @type allowed_values: list
  @keyword allowed_values: list of specially allowed values

  """
  if allowed_values is None:
    allowed_values = []

  if not isinstance(target, dict):
    msg = "Expected dictionary, got '%s'" % target
    raise errors.TypeEnforcementError(msg)

  for key in target:
    if key not in key_types:
      msg = "Unknown key '%s'" % key
      raise errors.TypeEnforcementError(msg)

    # Specially allowed values bypass type enforcement entirely
    if target[key] in allowed_values:
      continue

    ktype = key_types[key]
    if ktype not in constants.ENFORCEABLE_TYPES:
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
      raise errors.ProgrammerError(msg)

    if ktype in (constants.VTYPE_STRING, constants.VTYPE_MAYBE_STRING):
      # "maybe string" additionally accepts None
      if target[key] is None and ktype == constants.VTYPE_MAYBE_STRING:
        pass
      elif not isinstance(target[key], basestring):
        # A false boolean is coerced to the empty string
        if isinstance(target[key], bool) and not target[key]:
          target[key] = ''
        else:
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
          raise errors.TypeEnforcementError(msg)
    elif ktype == constants.VTYPE_BOOL:
      # Non-empty strings must spell out a recognized true/false value
      if isinstance(target[key], basestring) and target[key]:
        if target[key].lower() == constants.VALUE_FALSE:
          target[key] = False
        elif target[key].lower() == constants.VALUE_TRUE:
          target[key] = True
        else:
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
          raise errors.TypeEnforcementError(msg)
      elif target[key]:
        target[key] = True
      else:
        target[key] = False
    elif ktype == constants.VTYPE_SIZE:
      try:
        target[key] = ParseUnit(target[key])
      except errors.UnitParseError, err:
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
              (key, target[key], err)
        raise errors.TypeEnforcementError(msg)
    elif ktype == constants.VTYPE_INT:
      try:
        target[key] = int(target[key])
      except (ValueError, TypeError):
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
        raise errors.TypeEnforcementError(msg)
842

    
843

    
844
def _GetProcStatusPath(pid):
845
  """Returns the path for a PID's proc status file.
846

847
  @type pid: int
848
  @param pid: Process ID
849
  @rtype: string
850

851
  """
852
  return "/proc/%d/status" % pid
853

    
854

    
855
def IsProcessAlive(pid):
  """Check if a given pid exists on the system.

  @note: zombie status is not handled, so zombie processes
      will be returned as alive
  @type pid: int
  @param pid: the process ID to check
  @rtype: boolean
  @return: True if the process exists

  """
  def _TryStat(name):
    # Probe for the process by stat'ing its /proc status file
    try:
      os.stat(name)
      return True
    except EnvironmentError, err:
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
        return False
      elif err.errno == errno.EINVAL:
        # Transient /proc inconsistency; ask the retry loop to try again
        raise RetryAgain(err)
      raise

  assert isinstance(pid, int), "pid must be an integer"
  if pid <= 0:
    return False

  # /proc in a multiprocessor environment can have strange behaviors.
  # Retry the os.stat a few times until we get a good result.
  try:
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
                 args=[_GetProcStatusPath(pid)])
  except RetryTimeout, err:
    # Re-raise the last inner exception instead of the timeout wrapper
    err.RaiseInner()
888

    
889

    
890
def _ParseSigsetT(sigset):
891
  """Parse a rendered sigset_t value.
892

893
  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
894
  function.
895

896
  @type sigset: string
897
  @param sigset: Rendered signal set from /proc/$pid/status
898
  @rtype: set
899
  @return: Set of all enabled signal numbers
900

901
  """
902
  result = set()
903

    
904
  signum = 0
905
  for ch in reversed(sigset):
906
    chv = int(ch, 16)
907

    
908
    # The following could be done in a loop, but it's easier to read and
909
    # understand in the unrolled form
910
    if chv & 1:
911
      result.add(signum + 1)
912
    if chv & 2:
913
      result.add(signum + 2)
914
    if chv & 4:
915
      result.add(signum + 3)
916
    if chv & 8:
917
      result.add(signum + 4)
918

    
919
    signum += 4
920

    
921
  return result
922

    
923

    
924
def _GetProcStatusField(pstatus, field):
925
  """Retrieves a field from the contents of a proc status file.
926

927
  @type pstatus: string
928
  @param pstatus: Contents of /proc/$pid/status
929
  @type field: string
930
  @param field: Name of field whose value should be returned
931
  @rtype: string
932

933
  """
934
  for line in pstatus.splitlines():
935
    parts = line.split(":", 1)
936

    
937
    if len(parts) < 2 or parts[0] != field:
938
      continue
939

    
940
    return parts[1].strip()
941

    
942
  return None
943

    
944

    
945
def IsProcessHandlingSignal(pid, signum, status_path=None):
  """Checks whether a process is handling a signal.

  @type pid: int
  @param pid: Process ID
  @type signum: int
  @param signum: Signal number
  @type status_path: string
  @param status_path: path to the proc status file to read; defaults to
      the file for C{pid}
  @rtype: bool

  """
  if status_path is None:
    status_path = _GetProcStatusPath(pid)

  try:
    proc_status = ReadFile(status_path)
  except EnvironmentError, err:
    # In at least one case, reading /proc/$pid/status failed with ESRCH.
    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
      return False
    raise

  # SigCgt is the mask of signals with a handler installed
  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
  if sigcgt is None:
    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)

  # Now check whether signal is handled
  return signum in _ParseSigsetT(sigcgt)
972

    
973

    
974
def ReadPidFile(pidfile):
  """Read a pid from a file.

  @type  pidfile: string
  @param pidfile: path to the file containing the pid
  @rtype: int
  @return: The process id, if the file exists and contains a valid PID,
           otherwise 0

  """
  try:
    raw_data = ReadOneLineFile(pidfile)
  except EnvironmentError, err:
    # A missing PID file simply means "not running"; other read errors
    # are logged but still reported as 0
    if err.errno != errno.ENOENT:
      logging.exception("Can't read pid file")
    return 0

  try:
    pid = int(raw_data)
  except (TypeError, ValueError), err:
    # Corrupt contents are treated the same as a missing file
    logging.info("Can't parse pid file contents", exc_info=True)
    return 0

  return pid
998

    
999

    
1000
def ReadLockedPidFile(path):
1001
  """Reads a locked PID file.
1002

1003
  This can be used together with L{StartDaemon}.
1004

1005
  @type path: string
1006
  @param path: Path to PID file
1007
  @return: PID as integer or, if file was unlocked or couldn't be opened, None
1008

1009
  """
1010
  try:
1011
    fd = os.open(path, os.O_RDONLY)
1012
  except EnvironmentError, err:
1013
    if err.errno == errno.ENOENT:
1014
      # PID file doesn't exist
1015
      return None
1016
    raise
1017

    
1018
  try:
1019
    try:
1020
      # Try to acquire lock
1021
      LockFile(fd)
1022
    except errors.LockError:
1023
      # Couldn't lock, daemon is running
1024
      return int(os.read(fd, 100))
1025
  finally:
1026
    os.close(fd)
1027

    
1028
  return None
1029

    
1030

    
1031
def MatchNameComponent(key, name_list, case_sensitive=True):
  """Try to match a name against a list.

  This function will try to match a name like test1 against a list
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
  this list, I{'test1'} as well as I{'test1.example'} will match, but
  not I{'test1.ex'}. A multiple match will be considered as no match
  at all (e.g. I{'test1'} against C{['test1.example.com',
  'test1.example.org']}), except when the key fully matches an entry
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).

  @type key: str
  @param key: the name to be searched
  @type name_list: list
  @param name_list: the list of strings against which to search the key
  @type case_sensitive: boolean
  @param case_sensitive: whether to provide a case-sensitive match

  @rtype: None or str
  @return: None if there is no match I{or} if there are multiple matches,
      otherwise the element from the list which matches

  """
  # An exact list member always wins
  if key in name_list:
    return key

  re_flags = 0
  if not case_sensitive:
    re_flags |= re.IGNORECASE
    key = key.upper()

  # Match the key itself or the key followed by a dot-separated suffix
  name_re = re.compile(r"^%s(\..*)?$" % re.escape(key), re_flags)

  prefix_matches = []
  full_matches = []
  for name in name_list:
    if name_re.match(name) is not None:
      prefix_matches.append(name)
      if not case_sensitive and key == name.upper():
        full_matches.append(name)

  # A unique case-insensitive full match beats prefix matches
  if len(full_matches) == 1:
    return full_matches[0]
  elif len(prefix_matches) == 1:
    return prefix_matches[0]
  else:
    return None
1075

    
1076

    
1077
def ValidateServiceName(name):
  """Validate the given service name.

  @type name: number or string
  @param name: Service name or port specification

  """
  try:
    numport = int(name)
  except (ValueError, TypeError):
    # Non-numeric service name, check it against the allowed pattern
    valid = _VALID_SERVICE_NAME_RE.match(name) is not None
  else:
    # Numeric port (protocols other than TCP or UDP might need adjustments
    # here)
    valid = 0 <= numport < (1 << 16)

  if not valid:
    raise errors.OpPrereqError("Invalid service name '%s'" % name,
                               errors.ECODE_INVAL)

  return name
1099

    
1100

    
1101
def ListVolumeGroups():
1102
  """List volume groups and their size
1103

1104
  @rtype: dict
1105
  @return:
1106
       Dictionary with keys volume name and values
1107
       the size of the volume
1108

1109
  """
1110
  command = "vgs --noheadings --units m --nosuffix -o name,size"
1111
  result = RunCmd(command)
1112
  retval = {}
1113
  if result.failed:
1114
    return retval
1115

    
1116
  for line in result.stdout.splitlines():
1117
    try:
1118
      name, size = line.split()
1119
      size = int(float(size))
1120
    except (IndexError, ValueError), err:
1121
      logging.error("Invalid output from vgs (%s): %s", err, line)
1122
      continue
1123

    
1124
    retval[name] = size
1125

    
1126
  return retval
1127

    
1128

    
1129
def BridgeExists(bridge):
  """Check whether the given bridge exists in the system.

  @type bridge: str
  @param bridge: the bridge name to check
  @rtype: boolean
  @return: True if it does

  """
  # A Linux bridge device exposes a "bridge" directory under sysfs
  sysfs_path = "/sys/class/net/%s/bridge" % bridge
  return os.path.isdir(sysfs_path)
1139

    
1140

    
1141
def NiceSort(name_list):
  """Sort a list of strings based on digit and non-digit groupings.

  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
  will sort the list in the logical order C{['a1', 'a2', 'a10',
  'a11']}.

  The sort algorithm breaks each name in groups of either only-digits
  or no-digits. Only the first eight such groups are considered, and
  after that we just use what's left of the string.

  @type name_list: list
  @param name_list: the names to be sorted
  @rtype: list
  @return: a copy of the name list sorted with our algorithm

  """
  _SORTER_BASE = "(\D+|\d+)"
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % ((_SORTER_BASE, ) * 8)
  _SORTER_RE = re.compile(_SORTER_FULL)
  _SORTER_NODIGIT = re.compile("^\D*$")

  def _TryInt(val):
    """Attempts to convert a variable to integer."""
    if val is None or _SORTER_NODIGIT.match(val):
      return val
    return int(val)

  # Decorate every name with its group-wise sort key, sort, undecorate
  keyed = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
           for name in name_list]
  keyed.sort()
  return [name for (_, name) in keyed]
1176

    
1177

    
1178
def TryConvert(fn, val):
  """Try to convert a value ignoring errors.

  This function tries to apply function I{fn} to I{val}. If no
  C{ValueError} or C{TypeError} exceptions are raised, it will return
  the result, else it will return the original value. Any other
  exceptions are propagated to the caller.

  @type fn: callable
  @param fn: function to apply to the value
  @param val: the value to be converted
  @return: The converted value if the conversion was successful,
      otherwise the original value.

  """
  try:
    return fn(val)
  except (ValueError, TypeError):
    return val
1198

    
1199

    
1200
def IsValidShellParam(word):
  """Verifies is the given word is safe from the shell's p.o.v.

  This means that we can pass this to a command via the shell and be
  sure that it doesn't alter the command line and is passed as such to
  the actual command.

  Note that we are overly restrictive here, in order to be on the safe
  side.

  @type word: str
  @param word: the word to check
  @rtype: boolean
  @return: True if the word is 'safe'

  """
  # Whitelist of characters that cannot act as shell metacharacters
  return re.match("^[-a-zA-Z0-9._+/:%@]+$", word) is not None
1217

    
1218

    
1219
def BuildShellCmd(template, *args):
  """Build a safe shell command line from the given arguments.

  This function will check all arguments in the args list so that they
  are valid shell parameters (i.e. they don't contain shell
  metacharacters). If everything is ok, it will return the result of
  template % args.

  @type template: str
  @param template: the string holding the template for the
      string formatting
  @rtype: str
  @return: the expanded command line

  """
  # Refuse any argument that could be interpreted by the shell
  for word in args:
    if not IsValidShellParam(word):
      raise errors.ProgrammerError("Shell argument '%s' contains"
                                   " invalid characters" % word)
  return template % args
1239

    
1240

    
1241
def FormatUnit(value, units):
  """Formats an incoming number of MiB with the appropriate unit.

  @type value: int
  @param value: integer representing the value in MiB (1048576)
  @type units: char
  @param units: the type of formatting we should do:
      - 'h' for automatic scaling
      - 'm' for MiBs
      - 'g' for GiBs
      - 't' for TiBs
  @rtype: str
  @return: the formatted value (with suffix)

  """
  if units not in ('m', 'g', 't', 'h'):
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))

  # With automatic scaling ('h') a suffix is appended; fixed units get none
  if units == 'm' or (units == 'h' and value < 1024):
    suffix = 'M' if units == 'h' else ''
    return "%d%s" % (round(value, 0), suffix)
  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
    suffix = 'G' if units == 'h' else ''
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)
  else:
    suffix = 'T' if units == 'h' else ''
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)
1275

    
1276

    
1277
def ParseUnit(input_string):
  """Tries to extract number and scale from the given string.

  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
  is always an int in MiB.

  """
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
  if not m:
    raise errors.UnitParseError("Invalid format")

  value = float(m.group(1))

  unit = m.group(2)
  lcunit = unit.lower() if unit else 'm'

  # Map the unit to a MiB multiplier
  if lcunit in ('m', 'mb', 'mib'):
    scale = 1
  elif lcunit in ('g', 'gb', 'gib'):
    scale = 1024
  elif lcunit in ('t', 'tb', 'tib'):
    scale = 1024 * 1024
  else:
    raise errors.UnitParseError("Unknown unit: %s" % unit)

  value *= scale

  # Make sure we round up fractional MiBs
  if int(value) < value:
    value += 1

  # Round up to the next multiple of 4
  value = int(value)
  if value % 4:
    value += 4 - value % 4

  return value
1320

    
1321

    
1322
def ParseCpuMask(cpu_mask):
1323
  """Parse a CPU mask definition and return the list of CPU IDs.
1324

1325
  CPU mask format: comma-separated list of CPU IDs
1326
  or dash-separated ID ranges
1327
  Example: "0-2,5" -> "0,1,2,5"
1328

1329
  @type cpu_mask: str
1330
  @param cpu_mask: CPU mask definition
1331
  @rtype: list of int
1332
  @return: list of CPU IDs
1333

1334
  """
1335
  if not cpu_mask:
1336
    return []
1337
  cpu_list = []
1338
  for range_def in cpu_mask.split(","):
1339
    boundaries = range_def.split("-")
1340
    n_elements = len(boundaries)
1341
    if n_elements > 2:
1342
      raise errors.ParseError("Invalid CPU ID range definition"
1343
                              " (only one hyphen allowed): %s" % range_def)
1344
    try:
1345
      lower = int(boundaries[0])
1346
    except (ValueError, TypeError), err:
1347
      raise errors.ParseError("Invalid CPU ID value for lower boundary of"
1348
                              " CPU ID range: %s" % str(err))
1349
    try:
1350
      higher = int(boundaries[-1])
1351
    except (ValueError, TypeError), err:
1352
      raise errors.ParseError("Invalid CPU ID value for higher boundary of"
1353
                              " CPU ID range: %s" % str(err))
1354
    if lower > higher:
1355
      raise errors.ParseError("Invalid CPU ID range definition"
1356
                              " (%d > %d): %s" % (lower, higher, range_def))
1357
    cpu_list.extend(range(lower, higher + 1))
1358
  return cpu_list
1359

    
1360

    
1361
def AddAuthorizedKey(file_obj, key):
1362
  """Adds an SSH public key to an authorized_keys file.
1363

1364
  @type file_obj: str or file handle
1365
  @param file_obj: path to authorized_keys file
1366
  @type key: str
1367
  @param key: string containing key
1368

1369
  """
1370
  key_fields = key.split()
1371

    
1372
  if isinstance(file_obj, basestring):
1373
    f = open(file_obj, 'a+')
1374
  else:
1375
    f = file_obj
1376

    
1377
  try:
1378
    nl = True
1379
    for line in f:
1380
      # Ignore whitespace changes
1381
      if line.split() == key_fields:
1382
        break
1383
      nl = line.endswith('\n')
1384
    else:
1385
      if not nl:
1386
        f.write("\n")
1387
      f.write(key.rstrip('\r\n'))
1388
      f.write("\n")
1389
      f.flush()
1390
  finally:
1391
    f.close()
1392

    
1393

    
1394
def RemoveAuthorizedKey(file_name, key):
  """Removes an SSH public key from an authorized_keys file.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  key_fields = key.split()

  # Write the filtered content to a temporary file in the same directory,
  # then atomically replace the original
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    new_file = os.fdopen(fd, 'w')
    try:
      old_file = open(file_name, 'r')
      try:
        # Copy all lines except the one matching the key (ignoring
        # whitespace differences)
        for line in old_file:
          if line.split() != key_fields:
            new_file.write(line)

        new_file.flush()
        os.rename(tmpname, file_name)
      finally:
        old_file.close()
    finally:
      new_file.close()
  except:
    RemoveFile(tmpname)
    raise
1425

    
1426

    
1427
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
1428
  """Sets the name of an IP address and hostname in /etc/hosts.
1429

1430
  @type file_name: str
1431
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1432
  @type ip: str
1433
  @param ip: the IP address
1434
  @type hostname: str
1435
  @param hostname: the hostname to be added
1436
  @type aliases: list
1437
  @param aliases: the list of aliases to add for the hostname
1438

1439
  """
1440
  # Ensure aliases are unique
1441
  aliases = UniqueSequence([hostname] + aliases)[1:]
1442

    
1443
  def _WriteEtcHosts(fd):
1444
    # Duplicating file descriptor because os.fdopen's result will automatically
1445
    # close the descriptor, but we would still like to have its functionality.
1446
    out = os.fdopen(os.dup(fd), "w")
1447
    try:
1448
      for line in ReadFile(file_name).splitlines(True):
1449
        fields = line.split()
1450
        if fields and not fields[0].startswith("#") and ip == fields[0]:
1451
          continue
1452
        out.write(line)
1453

    
1454
      out.write("%s\t%s" % (ip, hostname))
1455
      if aliases:
1456
        out.write(" %s" % " ".join(aliases))
1457
      out.write("\n")
1458
      out.flush()
1459
    finally:
1460
      out.close()
1461

    
1462
  WriteFile(file_name, fn=_WriteEtcHosts, mode=0644)
1463

    
1464

    
1465
def AddHostToEtcHosts(hostname, ip):
  """Wrapper around SetEtcHostsEntry.

  @type hostname: str
  @param hostname: a hostname that will be resolved and added to
      L{constants.ETC_HOSTS}
  @type ip: str
  @param ip: The ip address of the host

  """
  # Register the short name as an alias of the FQDN
  short_name = hostname.split(".")[0]
  SetEtcHostsEntry(constants.ETC_HOSTS, ip, hostname, [short_name])
1476

    
1477

    
1478
def RemoveEtcHostsEntry(file_name, hostname):
1479
  """Removes a hostname from /etc/hosts.
1480

1481
  IP addresses without names are removed from the file.
1482

1483
  @type file_name: str
1484
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1485
  @type hostname: str
1486
  @param hostname: the hostname to be removed
1487

1488
  """
1489
  def _WriteEtcHosts(fd):
1490
    # Duplicating file descriptor because os.fdopen's result will automatically
1491
    # close the descriptor, but we would still like to have its functionality.
1492
    out = os.fdopen(os.dup(fd), "w")
1493
    try:
1494
      for line in ReadFile(file_name).splitlines(True):
1495
        fields = line.split()
1496
        if len(fields) > 1 and not fields[0].startswith("#"):
1497
          names = fields[1:]
1498
          if hostname in names:
1499
            while hostname in names:
1500
              names.remove(hostname)
1501
            if names:
1502
              out.write("%s %s\n" % (fields[0], " ".join(names)))
1503
            continue
1504

    
1505
        out.write(line)
1506

    
1507
      out.flush()
1508
    finally:
1509
      out.close()
1510

    
1511
  WriteFile(file_name, fn=_WriteEtcHosts, mode=0644)
1512

    
1513

    
1514
def RemoveHostFromEtcHosts(hostname):
  """Wrapper around RemoveEtcHostsEntry.

  @type hostname: str
  @param hostname: hostname whose full and short name will be removed
      from L{constants.ETC_HOSTS}

  """
  # Remove both the fully-qualified and the short form of the name
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hostname)
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hostname.split(".")[0])
1525

    
1526

    
1527
def TimestampForFilename():
  """Returns the current time formatted for filenames.

  The format doesn't contain colons as some shells and applications
  treat them as separators.

  @rtype: str

  """
  return time.strftime("%Y-%m-%d_%H_%M_%S")
1535

    
1536

    
1537
def CreateBackup(file_name):
  """Creates a backup of a file.

  @type file_name: str
  @param file_name: file to be backed up
  @rtype: str
  @return: the path to the newly created backup
  @raise errors.ProgrammerError: for invalid file names

  """
  if not os.path.isfile(file_name):
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
                                file_name)

  # Backup lives next to the original, named <file>.backup-<timestamp>.XXXXXX
  prefix = ("%s.backup-%s." %
            (os.path.basename(file_name), TimestampForFilename()))
  dir_name = os.path.dirname(file_name)

  src = open(file_name, 'rb')
  try:
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
    dst = os.fdopen(fd, 'wb')
    try:
      logging.debug("Backing up %s at %s", file_name, backup_name)
      shutil.copyfileobj(src, dst)
    finally:
      dst.close()
  finally:
    src.close()

  return backup_name
1568

    
1569

    
1570
def ShellQuote(value):
  """Quotes shell argument according to POSIX.

  @type value: str
  @param value: the argument to be quoted
  @rtype: str
  @return: the quoted value

  """
  # Strings made only of "safe" characters need no quoting at all
  if re.match('^[-.,=:/_+@A-Za-z0-9]+$', value):
    return value
  # Otherwise wrap in single quotes, escaping embedded single quotes
  return "'%s'" % value.replace("'", "'\\''")
1583

    
1584

    
1585
def ShellQuoteArgs(args):
  """Quotes a list of shell arguments.

  @type args: list
  @param args: list of arguments to be quoted
  @rtype: str
  @return: the quoted arguments concatenated with spaces

  """
  return ' '.join(map(ShellQuote, args))
1595

    
1596

    
1597
class ShellWriter:
  """Helper class to write scripts with indentation.

  """
  # One indentation level
  INDENT_STR = "  "

  def __init__(self, fh):
    """Initializes this class.

    @param fh: file-like object the output is written to

    """
    self._fh = fh
    self._indent = 0

  def IncIndent(self):
    """Increase indentation level by 1.

    """
    self._indent += 1

  def DecIndent(self):
    """Decrease indentation level by 1.

    """
    assert self._indent > 0
    self._indent -= 1

  def Write(self, txt, *args):
    """Write line to output file.

    The line is prefixed with the current indentation and terminated
    with a newline; C{args}, if given, are applied via C{%} formatting.

    """
    assert self._indent >= 0

    self._fh.write(self.INDENT_STR * self._indent)

    if args:
      self._fh.write(txt % args)
    else:
      self._fh.write(txt)

    self._fh.write("\n")
1637

    
1638

    
1639
def ListVisibleFiles(path):
  """Returns a list of visible files in a directory.

  @type path: str
  @param path: the directory to enumerate
  @rtype: list
  @return: the list of all files not starting with a dot
  @raise ProgrammerError: if L{path} is not an absolue and normalized path

  """
  if not IsNormAbsPath(path):
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
                                 " absolute/normalized: '%s'" % path)
  # Dot-files are considered hidden
  return [name for name in os.listdir(path) if not name.startswith(".")]
1654

    
1655

    
1656
def GetHomeDir(user, default=None):
1657
  """Try to get the homedir of the given user.
1658

1659
  The user can be passed either as a string (denoting the name) or as
1660
  an integer (denoting the user id). If the user is not found, the
1661
  'default' argument is returned, which defaults to None.
1662

1663
  """
1664
  try:
1665
    if isinstance(user, basestring):
1666
      result = pwd.getpwnam(user)
1667
    elif isinstance(user, (int, long)):
1668
      result = pwd.getpwuid(user)
1669
    else:
1670
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1671
                                   type(user))
1672
  except KeyError:
1673
    return default
1674
  return result.pw_dir
1675

    
1676

    
1677
def NewUUID():
  """Returns a random UUID.

  @note: This is a Linux-specific method as it uses the /proc
      filesystem.
  @rtype: str

  """
  # The kernel generates a fresh UUID on every read of this file
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1686

    
1687

    
1688
def GenerateSecret(numbytes=20):
1689
  """Generates a random secret.
1690

1691
  This will generate a pseudo-random secret returning an hex string
1692
  (so that it can be used where an ASCII string is needed).
1693

1694
  @param numbytes: the number of bytes which will be represented by the returned
1695
      string (defaulting to 20, the length of a SHA1 hash)
1696
  @rtype: str
1697
  @return: an hex representation of the pseudo-random sequence
1698

1699
  """
1700
  return os.urandom(numbytes).encode('hex')
1701

    
1702

    
1703
def EnsureDirs(dirs):
1704
  """Make required directories, if they don't exist.
1705

1706
  @param dirs: list of tuples (dir_name, dir_mode)
1707
  @type dirs: list of (string, integer)
1708

1709
  """
1710
  for dir_name, dir_mode in dirs:
1711
    try:
1712
      os.mkdir(dir_name, dir_mode)
1713
    except EnvironmentError, err:
1714
      if err.errno != errno.EEXIST:
1715
        raise errors.GenericError("Cannot create needed directory"
1716
                                  " '%s': %s" % (dir_name, err))
1717
    try:
1718
      os.chmod(dir_name, dir_mode)
1719
    except EnvironmentError, err:
1720
      raise errors.GenericError("Cannot change directory permissions on"
1721
                                " '%s': %s" % (dir_name, err))
1722
    if not os.path.isdir(dir_name):
1723
      raise errors.GenericError("%s is not a directory" % dir_name)
1724

    
1725

    
1726
def ReadFile(file_name, size=-1):
  """Reads a file.

  @type file_name: str
  @param file_name: the file to be read
  @type size: int
  @param size: Read at most size bytes (if negative, entire file)
  @rtype: str
  @return: the (possibly partial) content of the file

  """
  f = open(file_name, "r")
  try:
    data = f.read(size)
  finally:
    f.close()
  return data
1740

    
1741

    
1742
def WriteFile(file_name, fn=None, data=None,
1743
              mode=None, uid=-1, gid=-1,
1744
              atime=None, mtime=None, close=True,
1745
              dry_run=False, backup=False,
1746
              prewrite=None, postwrite=None):
1747
  """(Over)write a file atomically.
1748

1749
  The file_name and either fn (a function taking one argument, the
1750
  file descriptor, and which should write the data to it) or data (the
1751
  contents of the file) must be passed. The other arguments are
1752
  optional and allow setting the file mode, owner and group, and the
1753
  mtime/atime of the file.
1754

1755
  If the function doesn't raise an exception, it has succeeded and the
1756
  target file has the new contents. If the function has raised an
1757
  exception, an existing target file should be unmodified and the
1758
  temporary file should be removed.
1759

1760
  @type file_name: str
1761
  @param file_name: the target filename
1762
  @type fn: callable
1763
  @param fn: content writing function, called with
1764
      file descriptor as parameter
1765
  @type data: str
1766
  @param data: contents of the file
1767
  @type mode: int
1768
  @param mode: file mode
1769
  @type uid: int
1770
  @param uid: the owner of the file
1771
  @type gid: int
1772
  @param gid: the group of the file
1773
  @type atime: int
1774
  @param atime: a custom access time to be set on the file
1775
  @type mtime: int
1776
  @param mtime: a custom modification time to be set on the file
1777
  @type close: boolean
1778
  @param close: whether to close file after writing it
1779
  @type prewrite: callable
1780
  @param prewrite: function to be called before writing content
1781
  @type postwrite: callable
1782
  @param postwrite: function to be called after writing content
1783

1784
  @rtype: None or int
1785
  @return: None if the 'close' parameter evaluates to True,
1786
      otherwise the file descriptor
1787

1788
  @raise errors.ProgrammerError: if any of the arguments are not valid
1789

1790
  """
1791
  if not os.path.isabs(file_name):
1792
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1793
                                 " absolute: '%s'" % file_name)
1794

    
1795
  if [fn, data].count(None) != 1:
1796
    raise errors.ProgrammerError("fn or data required")
1797

    
1798
  if [atime, mtime].count(None) == 1:
1799
    raise errors.ProgrammerError("Both atime and mtime must be either"
1800
                                 " set or None")
1801

    
1802
  if backup and not dry_run and os.path.isfile(file_name):
1803
    CreateBackup(file_name)
1804

    
1805
  dir_name, base_name = os.path.split(file_name)
1806
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1807
  do_remove = True
1808
  # here we need to make sure we remove the temp file, if any error
1809
  # leaves it in place
1810
  try:
1811
    if uid != -1 or gid != -1:
1812
      os.chown(new_name, uid, gid)
1813
    if mode:
1814
      os.chmod(new_name, mode)
1815
    if callable(prewrite):
1816
      prewrite(fd)
1817
    if data is not None:
1818
      os.write(fd, data)
1819
    else:
1820
      fn(fd)
1821
    if callable(postwrite):
1822
      postwrite(fd)
1823
    os.fsync(fd)
1824
    if atime is not None and mtime is not None:
1825
      os.utime(new_name, (atime, mtime))
1826
    if not dry_run:
1827
      os.rename(new_name, file_name)
1828
      do_remove = False
1829
  finally:
1830
    if close:
1831
      os.close(fd)
1832
      result = None
1833
    else:
1834
      result = fd
1835
    if do_remove:
1836
      RemoveFile(new_name)
1837

    
1838
  return result
1839

    
1840

    
1841
def ReadOneLineFile(file_name, strict=False):
  """Return the first non-empty line from a file.

  @type strict: boolean
  @param strict: if True, abort if the file has more than one
      non-empty line

  """
  full_lines = [line for line in ReadFile(file_name).splitlines() if line]
  if not full_lines:
    raise errors.GenericError("No data in one-liner file %s" % file_name)
  if strict and len(full_lines) > 1:
    raise errors.GenericError("Too many lines in one-liner file %s" %
                              file_name)
  return full_lines[0]
1857

    
1858

    
1859
def FirstFree(seq, base=0):
  """Returns the first non-existing integer from seq.

  The seq argument should be a sorted list of positive integers. The
  first time the index of an element is smaller than the element
  value, the index will be returned.

  The base argument is used to start at a different offset,
  i.e. C{[3, 4, 6]} with I{offset=3} will return 5.

  Example: C{[0, 1, 3]} will return I{2}.

  @type seq: sequence
  @param seq: the sequence to be analyzed.
  @type base: int
  @param base: use this value as the base index of the sequence
  @rtype: int
  @return: the first non-used index in the sequence

  """
  expected = base
  for elem in seq:
    assert elem >= base, "Passed element is higher than base offset"
    if elem > expected:
      # Found a hole in the sequence
      return expected
    expected += 1
  return None
1885

    
1886

    
1887
def SingleWaitForFdCondition(fdobj, event, timeout):
1888
  """Waits for a condition to occur on the socket.
1889

1890
  Immediately returns at the first interruption.
1891

1892
  @type fdobj: integer or object supporting a fileno() method
1893
  @param fdobj: entity to wait for events on
1894
  @type event: integer
1895
  @param event: ORed condition (see select module)
1896
  @type timeout: float or None
1897
  @param timeout: Timeout in seconds
1898
  @rtype: int or None
1899
  @return: None for timeout, otherwise occured conditions
1900

1901
  """
1902
  check = (event | select.POLLPRI |
1903
           select.POLLNVAL | select.POLLHUP | select.POLLERR)
1904

    
1905
  if timeout is not None:
1906
    # Poller object expects milliseconds
1907
    timeout *= 1000
1908

    
1909
  poller = select.poll()
1910
  poller.register(fdobj, event)
1911
  try:
1912
    # TODO: If the main thread receives a signal and we have no timeout, we
1913
    # could wait forever. This should check a global "quit" flag or something
1914
    # every so often.
1915
    io_events = poller.poll(timeout)
1916
  except select.error, err:
1917
    if err[0] != errno.EINTR:
1918
      raise
1919
    io_events = []
1920
  if io_events and io_events[0][1] & check:
1921
    return io_events[0][1]
1922
  else:
1923
    return None
1924

    
1925

    
1926
class FdConditionWaiterHelper(object):
  """Retry helper for WaitForFdCondition.

  This class contains the retried and wait functions that make sure
  WaitForFdCondition can continue waiting until the timeout is actually
  expired.

  """
  def __init__(self, timeout):
    """Initializes this class.

    @type timeout: float or None
    @param timeout: timeout for each poll attempt

    """
    self.timeout = timeout

  def Poll(self, fdobj, event):
    """Poll once, raising L{RetryAgain} when nothing happened.

    """
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
    if result is None:
      raise RetryAgain()
    return result

  def UpdateTimeout(self, timeout):
    """Sets a new timeout for subsequent poll attempts.

    """
    self.timeout = timeout
1947

    
1948

    
1949
def WaitForFdCondition(fdobj, event, timeout):
  """Waits for a condition to occur on the socket.

  Retries until the timeout is expired, even if interrupted.

  @type fdobj: integer or object supporting a fileno() method
  @param fdobj: entity to wait for events on
  @type event: integer
  @param event: ORed condition (see select module)
  @type timeout: float or None
  @param timeout: Timeout in seconds
  @rtype: int or None
  @return: None for timeout, otherwise occured conditions

  """
  if timeout is None:
    # No timeout: simply restart the wait after every interruption
    result = None
    while result is None:
      result = SingleWaitForFdCondition(fdobj, event, timeout)
  else:
    # Use the generic retry machinery so interruptions don't extend
    # the overall deadline
    retrywaiter = FdConditionWaiterHelper(timeout)
    try:
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
    except RetryTimeout:
      result = None
  return result
1976

    
1977

    
1978
def UniqueSequence(seq):
  """Returns a list with unique elements.

  Element order is preserved.

  @type seq: sequence
  @param seq: the sequence with the source elements
  @rtype: list
  @return: list of unique elements from seq

  """
  observed = set()
  result = []
  for item in seq:
    if item not in observed:
      observed.add(item)
      result.append(item)
  return result
1991

    
1992

    
1993
def NormalizeAndValidateMac(mac):
  """Normalizes and check if a MAC address is valid.

  Only the colon-separated format is accepted; a formally valid
  address is returned lowercased.

  @type mac: str
  @param mac: the MAC to be validated
  @rtype: str
  @return: returns the normalized and validated MAC.

  @raise errors.OpPrereqError: If the MAC isn't valid

  """
  # six colon-separated hex octets, case-insensitive
  if re.match("^([0-9a-f]{2}(:|$)){6}$", mac, re.I) is None:
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
                               mac, errors.ECODE_INVAL)

  return mac.lower()
2013

    
2014

    
2015
def TestDelay(duration):
  """Sleep for a fixed amount of time.

  @type duration: float
  @param duration: the sleep duration
  @rtype: tuple of (bool, str or None)
  @return: (False, error message) for a negative duration, otherwise
      (True, None) after having slept for the requested time

  """
  # negative durations are rejected instead of being passed to
  # time.sleep, which would raise
  if duration < 0:
    return False, "Invalid sleep duration"
  time.sleep(duration)
  return True, None
2028

    
2029

    
2030
def _CloseFDNoErr(fd, retries=5):
2031
  """Close a file descriptor ignoring errors.
2032

2033
  @type fd: int
2034
  @param fd: the file descriptor
2035
  @type retries: int
2036
  @param retries: how many retries to make, in case we get any
2037
      other error than EBADF
2038

2039
  """
2040
  try:
2041
    os.close(fd)
2042
  except OSError, err:
2043
    if err.errno != errno.EBADF:
2044
      if retries > 0:
2045
        _CloseFDNoErr(fd, retries - 1)
2046
    # else either it's closed already or we're out of retries, so we
2047
    # ignore this and go on
2048

    
2049

    
2050
def CloseFDs(noclose_fds=None):
  """Close file descriptors.

  This closes all file descriptors above 2 (i.e. except
  stdin/out/err).

  @type noclose_fds: list or None
  @param noclose_fds: if given, it denotes a list of file descriptor
      that should not be closed

  """
  # Fallback ceiling for the highest file descriptor number, used
  # when the hard RLIMIT_NOFILE limit is reported as unlimited.
  fallback_maxfd = 1024
  if 'SC_OPEN_MAX' in os.sysconf_names:
    try:
      sc_max = os.sysconf('SC_OPEN_MAX')
    except OSError:
      sc_max = -1
    if sc_max >= 0:
      fallback_maxfd = sc_max

  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
  if maxfd == resource.RLIM_INFINITY:
    maxfd = fallback_maxfd

  # Close every descriptor above the standard three, skipping the ones
  # the caller asked us to keep open
  for fd in range(3, maxfd):
    if noclose_fds and fd in noclose_fds:
      continue
    _CloseFDNoErr(fd)
2080

    
2081

    
2082
def Mlockall(_ctypes=ctypes):
  """Lock current process' virtual address space into RAM.

  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
  see mlock(2) for more details. This function requires ctypes module.

  Failures to load libc or to perform the actual lock are only logged,
  not raised; the process keeps running unlocked in that case.

  @raises errors.NoCtypesError: if ctypes module is not found

  """
  if _ctypes is None:
    raise errors.NoCtypesError()

  libc = _ctypes.cdll.LoadLibrary("libc.so.6")
  if libc is None:
    logging.error("Cannot set memory lock, ctypes cannot load libc")
    return

  # Some older version of the ctypes module don't have built-in functionality
  # to access the errno global variable, where function error codes are stored.
  # By declaring this variable as a pointer to an integer we can then access
  # its value correctly, should the mlockall call fail, in order to see what
  # the actual error code was.
  # pylint: disable-msg=W0212
  libc.__errno_location.restype = _ctypes.POINTER(_ctypes.c_int)

  # mlockall returns non-zero on failure
  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
    # pylint: disable-msg=W0212
    logging.error("Cannot set memory lock: %s",
                  os.strerror(libc.__errno_location().contents.value))
    return

  logging.debug("Memory lock set")
2114

    
2115

    
2116
def Daemonize(logfile):
  """Daemonize the current process.

  This detaches the current process from the controlling terminal and
  runs it in the background as a daemon.

  This uses the classic double-fork sequence: the first fork plus
  setsid() detaches us from the controlling terminal, and the second
  fork ensures the daemon can never reacquire one.

  @type logfile: str
  @param logfile: the logfile to which we should redirect stdout/stderr
  @rtype: int
  @return: the value zero

  """
  # pylint: disable-msg=W0212
  # yes, we really want os._exit
  UMASK = 077
  WORKDIR = "/"

  # this might fail
  pid = os.fork()
  if (pid == 0):  # The first child.
    os.setsid()
    # this might fail
    pid = os.fork() # Fork a second child.
    if (pid == 0):  # The second child.
      # chdir to / so the daemon doesn't pin any mount point, and
      # reset umask so created files get predictable permissions
      os.chdir(WORKDIR)
      os.umask(UMASK)
    else:
      # exit() or _exit()?  See below.
      os._exit(0) # Exit parent (the first child) of the second child.
  else:
    os._exit(0) # Exit parent of the first child.

  # Close stdin/stdout/stderr and reopen them: stdin from /dev/null,
  # stdout appending to the logfile, stderr duplicated onto stdout.
  # The asserts rely on open() returning the lowest free descriptor.
  for fd in range(3):
    _CloseFDNoErr(fd)
  i = os.open("/dev/null", os.O_RDONLY) # stdin
  assert i == 0, "Can't close/reopen stdin"
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
  assert i == 1, "Can't close/reopen stdout"
  # Duplicate standard output to standard error.
  os.dup2(1, 2)
  return 0
2157

    
2158

    
2159
def DaemonPidFileName(name):
  """Compute a ganeti pid file absolute path

  @type name: str
  @param name: the daemon name
  @rtype: str
  @return: the full path to the pidfile corresponding to the given
      daemon name

  """
  pidfile = "%s.pid" % name
  return PathJoin(constants.RUN_GANETI_DIR, pidfile)
2170

    
2171

    
2172
def EnsureDaemon(name):
  """Check for and start daemon if not alive.

  """
  status = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
  if not status.failed:
    return True

  logging.error("Can't start daemon '%s', failure %s, output: %s",
                name, status.fail_reason, status.output)
  return False
2183

    
2184

    
2185
def StopDaemon(name):
  """Stop daemon

  """
  status = RunCmd([constants.DAEMON_UTIL, "stop", name])
  if not status.failed:
    return True

  logging.error("Can't stop daemon '%s', failure %s, output: %s",
                name, status.fail_reason, status.output)
  return False
2196

    
2197

    
2198
def WritePidFile(name):
  """Write the current process pidfile.

  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}

  @type name: str
  @param name: the daemon name to use
  @raise errors.GenericError: if the pid file already exists and
      points to a live process

  """
  pidfilename = DaemonPidFileName(name)
  # never overwrite the pidfile of a daemon that is still running
  if IsProcessAlive(ReadPidFile(pidfilename)):
    raise errors.GenericError("%s contains a live process" % pidfilename)

  WriteFile(pidfilename, data="%d\n" % os.getpid())
2215

    
2216

    
2217
def RemovePidFile(name):
  """Remove the current process pidfile.

  Any errors are ignored.

  @type name: str
  @param name: the daemon name used to derive the pidfile name

  """
  pidfilename = DaemonPidFileName(name)
  # TODO: we could check here that the file contains our pid
  # best-effort removal: deliberately swallow any error
  try:
    RemoveFile(pidfilename)
  except: # pylint: disable-msg=W0702
    pass
2232

    
2233

    
2234
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
                waitpid=False):
  """Kill a process given by its pid.

  @type pid: int
  @param pid: The PID to terminate.
  @type signal_: int
  @param signal_: The signal to send, by default SIGTERM
  @type timeout: int
  @param timeout: The timeout after which, if the process is still alive,
                  a SIGKILL will be sent. If not positive, no such checking
                  will be done
  @type waitpid: boolean
  @param waitpid: If true, we should waitpid on this process after
      sending signals, since it's our own child and otherwise it
      would remain as zombie

  """
  def _helper(pid, signal_, wait):
    """Simple helper to encapsulate the kill/waitpid sequence"""
    if IgnoreProcessNotFound(os.kill, pid, signal_) and wait:
      try:
        # reap the child if requested, but never block on it
        os.waitpid(pid, os.WNOHANG)
      except OSError:
        pass

  if pid <= 0:
    # kill with pid=0 == suicide
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)

  if not IsProcessAlive(pid):
    # nothing to do
    return

  # first attempt, with the requested signal
  _helper(pid, signal_, waitpid)

  if timeout <= 0:
    # caller asked for fire-and-forget behaviour
    return

  def _CheckProcess():
    """Returns when the process is gone, raises RetryAgain otherwise."""
    if not IsProcessAlive(pid):
      return

    try:
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
    except OSError:
      # e.g. not our child; keep polling via IsProcessAlive above
      raise RetryAgain()

    if result_pid > 0:
      # the child has been reaped, so it's definitely gone
      return

    raise RetryAgain()

  try:
    # Wait up to $timeout seconds
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
  except RetryTimeout:
    pass

  if IsProcessAlive(pid):
    # Kill process if it's still alive
    _helper(pid, signal.SIGKILL, waitpid)
2295

    
2296

    
2297
def FindFile(name, search_path, test=os.path.exists):
  """Look for a filesystem object in a given path.

  This is an abstract method to search for filesystem object (files,
  dirs) under a given search path.

  @type name: str
  @param name: the name to look for
  @type search_path: sequence of str
  @param search_path: locations to search in
  @type test: callable
  @param test: a function taking one argument that should return True
      if the a given object is valid; the default value is
      os.path.exists, causing only existing files to be returned
  @rtype: str or None
  @return: full path to the object if found, None otherwise

  """
  # refuse names that don't match the allowed plugin name pattern
  if constants.EXT_PLUGIN_MASK.match(name) is None:
    logging.critical("Invalid value passed for external script name: '%s'",
                     name)
    return None

  for directory in search_path:
    # FIXME: investigate switch to PathJoin
    candidate = os.path.sep.join([directory, name])
    # run the user-supplied test first, then double-check that the
    # candidate still resolves to the requested basename
    if test(candidate) and os.path.basename(candidate) == name:
      return candidate

  return None
2329

    
2330

    
2331
def CheckVolumeGroupSize(vglist, vgname, minsize):
  """Checks if the volume group list is valid.

  The function will check if a given volume group is in the list of
  volume groups and has a minimum size.

  @type vglist: dict
  @param vglist: dictionary of volume group names and their size
  @type vgname: str
  @param vgname: the volume group we should check
  @type minsize: int
  @param minsize: the minimum size we accept
  @rtype: None or str
  @return: None for success, otherwise the error message

  """
  vgsize = vglist.get(vgname)
  # guard clauses: missing group, then undersized group
  if vgsize is None:
    return "volume group '%s' missing" % vgname
  if vgsize < minsize:
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
            (vgname, minsize, vgsize))
  return None
2354

    
2355

    
2356
def SplitTime(value):
  """Splits time as floating point number into a tuple.

  @param value: Time in seconds
  @type value: int or float
  @return: Tuple containing (seconds, microseconds)

  """
  total_micro = int(value * 1000000)
  seconds = total_micro // 1000000
  microseconds = total_micro - seconds * 1000000

  assert 0 <= seconds, \
    "Seconds must be larger than or equal to 0, but are %s" % seconds
  assert 0 <= microseconds <= 999999, \
    "Microseconds must be 0-999999, but are %s" % microseconds

  return (int(seconds), int(microseconds))
2372

    
2373

    
2374
def MergeTime(timetuple):
  """Merges a tuple into time as a floating point number.

  @param timetuple: Time as tuple, (seconds, microseconds)
  @type timetuple: tuple
  @return: Time as a floating point number expressed in seconds

  """
  secs, usecs = timetuple

  assert 0 <= secs, \
    "Seconds must be larger than or equal to 0, but are %s" % secs
  assert 0 <= usecs <= 999999, \
    "Microseconds must be 0-999999, but are %s" % usecs

  # multiply (rather than divide) to keep the exact historical
  # floating-point result
  return float(secs) + (float(usecs) * 0.000001)
2390

    
2391

    
2392
class LogFileHandler(logging.FileHandler):
  """Log handler that doesn't fallback to stderr.

  When an error occurs while writing on the logfile, logging.FileHandler tries
  to log on stderr. This doesn't work in ganeti since stderr is redirected to
  the logfile. This class avoids failures reporting errors to /dev/console.

  """
  def __init__(self, filename, mode="a", encoding=None):
    """Open the specified file and use it as the stream for logging.

    Also open /dev/console to report errors while logging.

    """
    logging.FileHandler.__init__(self, filename, mode, encoding)
    # opened eagerly so that a failure shows up at setup time, not in
    # the middle of error handling
    self.console = open(constants.DEV_CONSOLE, "a")

  def handleError(self, record): # pylint: disable-msg=C0103
    """Handle errors which occur during an emit() call.

    Try to handle errors with FileHandler method, if it fails write to
    /dev/console.

    """
    try:
      logging.FileHandler.handleError(self, record)
    except Exception: # pylint: disable-msg=W0703
      try:
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
      except Exception: # pylint: disable-msg=W0703
        # Log handler tried everything it could, now just give up
        pass
2424

    
2425

    
2426
def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
                 console_logging=False):
  """Configures the logging module.

  @type logfile: str
  @param logfile: the filename to which we should log
  @type debug: integer
  @param debug: if greater than zero, enable debug messages, otherwise
      only those at C{INFO} and above level
  @type stderr_logging: boolean
  @param stderr_logging: whether we should also log to the standard error
  @type program: str
  @param program: the name under which we should log messages
  @type multithreaded: boolean
  @param multithreaded: if True, will add the thread name to the log file
  @type syslog: string
  @param syslog: one of 'no', 'yes', 'only':
      - if no, syslog is not used
      - if yes, syslog is used (in addition to file-logging)
      - if only, only syslog is used
  @type console_logging: boolean
  @param console_logging: if True, will use a FileHandler which falls back to
      the system console if logging fails
  @raise EnvironmentError: if we can't open the log file and
      syslog/stderr logging is disabled

  """
  # build the file/stderr format (fmt) and the syslog format (sft)
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
  sft = program + "[%(process)d]:"
  if multithreaded:
    fmt += "/%(threadName)s"
    sft += " (%(threadName)s)"
  if debug:
    fmt += " %(module)s:%(lineno)s"
    # no debug info for syslog loggers
  fmt += " %(levelname)s %(message)s"
  # yes, we do want the textual level, as remote syslog will probably
  # lose the error level, and it's easier to grep for it
  sft += " %(levelname)s %(message)s"
  formatter = logging.Formatter(fmt)
  sys_fmt = logging.Formatter(sft)

  # configure the root logger; per-handler levels do the filtering
  root_logger = logging.getLogger("")
  root_logger.setLevel(logging.NOTSET)

  # Remove all previously setup handlers
  for handler in root_logger.handlers:
    handler.close()
    root_logger.removeHandler(handler)

  if stderr_logging:
    stderr_handler = logging.StreamHandler()
    stderr_handler.setFormatter(formatter)
    if debug:
      stderr_handler.setLevel(logging.NOTSET)
    else:
      stderr_handler.setLevel(logging.CRITICAL)
    root_logger.addHandler(stderr_handler)

  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
                                                    facility)
    syslog_handler.setFormatter(sys_fmt)
    # Never enable debug over syslog
    syslog_handler.setLevel(logging.INFO)
    root_logger.addHandler(syslog_handler)

  if syslog != constants.SYSLOG_ONLY:
    # this can fail, if the logging directories are not setup or we have
    # a permisssion problem; in this case, it's best to log but ignore
    # the error if stderr_logging is True, and if false we re-raise the
    # exception since otherwise we could run but without any logs at all
    try:
      if console_logging:
        logfile_handler = LogFileHandler(logfile)
      else:
        logfile_handler = logging.FileHandler(logfile)
      logfile_handler.setFormatter(formatter)
      if debug:
        logfile_handler.setLevel(logging.DEBUG)
      else:
        logfile_handler.setLevel(logging.INFO)
      root_logger.addHandler(logfile_handler)
    except EnvironmentError:
      if stderr_logging or syslog == constants.SYSLOG_YES:
        logging.exception("Failed to enable logging to file '%s'", logfile)
      else:
        # we need to re-raise the exception
        raise
2517

    
2518

    
2519
def IsNormAbsPath(path):
  """Check whether a path is absolute and also normalized

  This avoids things like /dir/../../other/path to be valid.

  """
  return os.path.isabs(path) and os.path.normpath(path) == path
2526

    
2527

    
2528
def PathJoin(*args):
  """Safe-join a list of path components.

  Requirements:
      - the first argument must be an absolute path
      - no component in the path must have backtracking (e.g. /../),
        since we check for normalization at the end

  @param args: the path components to be joined
  @raise ValueError: for invalid paths

  """
  # at least one component is required
  assert args
  root = args[0]
  # the base component must itself be an absolute, normalized path
  if not IsNormAbsPath(root):
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
  joined = os.path.join(*args)
  # the joined result must also be absolute and normalized, which
  # catches any "../" trickery in the other components
  if not IsNormAbsPath(joined):
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
  # finally, refuse results that escaped from under the base component
  prefix = os.path.commonprefix([root, joined])
  if prefix != root:
    raise ValueError("Error: path joining resulted in different prefix"
                     " (%s != %s)" % (prefix, root))
  return joined
2556

    
2557

    
2558
def TailFile(fname, lines=20):
  """Return the last lines from a file.

  @note: this function will only read and parse the last 4KB of
      the file; if the lines are very long, it could be that less
      than the requested number of lines are returned

  @param fname: the file name
  @type lines: int
  @param lines: the (maximum) number of lines to return

  """
  handle = open(fname, "r")
  try:
    # seek to (at most) 4 kB before the end of the file
    handle.seek(0, 2)
    start = max(0, handle.tell() - 4096)
    handle.seek(start, 0)
    tail_data = handle.read()
  finally:
    handle.close()

  return tail_data.splitlines()[-lines:]
2582

    
2583

    
2584
def FormatTimestampWithTZ(secs):
2585
  """Formats a Unix timestamp with the local timezone.
2586

2587
  """
2588
  return time.strftime("%F %T %Z", time.gmtime(secs))
2589

    
2590

    
2591
def _ParseAsn1Generalizedtime(value):
2592
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
2593

2594
  @type value: string
2595
  @param value: ASN1 GENERALIZEDTIME timestamp
2596

2597
  """
2598
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
2599
  if m:
2600
    # We have an offset
2601
    asn1time = m.group(1)
2602
    hours = int(m.group(2))
2603
    minutes = int(m.group(3))
2604
    utcoffset = (60 * hours) + minutes
2605
  else:
2606
    if not value.endswith("Z"):
2607
      raise ValueError("Missing timezone")
2608
    asn1time = value[:-1]
2609
    utcoffset = 0
2610

    
2611
  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
2612

    
2613
  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)
2614

    
2615
  return calendar.timegm(tt.utctimetuple())
2616

    
2617

    
2618
def GetX509CertValidity(cert):
  """Returns the validity period of the certificate.

  @type cert: OpenSSL.crypto.X509
  @param cert: X509 certificate object

  """
  def _Convert(accessor_fn):
    # calls the accessor and parses its ASN1 result; a None result
    # (no such field in the certificate) is passed through
    asn1_value = accessor_fn()
    if asn1_value is None:
      return None
    return _ParseAsn1Generalizedtime(asn1_value)

  # get_notBefore/get_notAfter only exist in pyOpenSSL 0.7 and above,
  # hence the attribute lookups are guarded
  try:
    notbefore_fn = cert.get_notBefore
  except AttributeError:
    not_before = None
  else:
    not_before = _Convert(notbefore_fn)

  try:
    notafter_fn = cert.get_notAfter
  except AttributeError:
    not_after = None
  else:
    not_after = _Convert(notafter_fn)

  return (not_before, not_after)
2652

    
2653

    
2654
def _VerifyCertificateInner(expired, not_before, not_after, now,
2655
                            warn_days, error_days):
2656
  """Verifies certificate validity.
2657

2658
  @type expired: bool
2659
  @param expired: Whether pyOpenSSL considers the certificate as expired
2660
  @type not_before: number or None
2661
  @param not_before: Unix timestamp before which certificate is not valid
2662
  @type not_after: number or None
2663
  @param not_after: Unix timestamp after which certificate is invalid
2664
  @type now: number
2665
  @param now: Current time as Unix timestamp
2666
  @type warn_days: number or None
2667
  @param warn_days: How many days before expiration a warning should be reported
2668
  @type error_days: number or None
2669
  @param error_days: How many days before expiration an error should be reported
2670

2671
  """
2672
  if expired:
2673
    msg = "Certificate is expired"
2674

    
2675
    if not_before is not None and not_after is not None:
2676
      msg += (" (valid from %s to %s)" %
2677
              (FormatTimestampWithTZ(not_before),
2678
               FormatTimestampWithTZ(not_after)))
2679
    elif not_before is not None:
2680
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
2681
    elif not_after is not None:
2682
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)
2683

    
2684
    return (CERT_ERROR, msg)
2685

    
2686
  elif not_before is not None and not_before > now:
2687
    return (CERT_WARNING,
2688
            "Certificate not yet valid (valid from %s)" %
2689
            FormatTimestampWithTZ(not_before))
2690

    
2691
  elif not_after is not None:
2692
    remaining_days = int((not_after - now) / (24 * 3600))
2693

    
2694
    msg = "Certificate expires in about %d days" % remaining_days
2695

    
2696
    if error_days is not None and remaining_days <= error_days:
2697
      return (CERT_ERROR, msg)
2698

    
2699
    if warn_days is not None and remaining_days <= warn_days:
2700
      return (CERT_WARNING, msg)
2701

    
2702
  return (None, None)
2703

    
2704

    
2705
def VerifyX509Certificate(cert, warn_days, error_days):
  """Verifies a certificate for LUVerifyCluster.

  @type cert: OpenSSL.crypto.X509
  @param cert: X509 certificate object
  @type warn_days: number or None
  @param warn_days: How many days before expiration a warning should be reported
  @type error_days: number or None
  @param error_days: How many days before expiration an error should be reported

  """
  # with old pyOpenSSL versions the validity period can come back as
  # (None, None)
  (valid_from, valid_until) = GetX509CertValidity(cert)
  expired = cert.has_expired()

  return _VerifyCertificateInner(expired, valid_from, valid_until,
                                 time.time(), warn_days, error_days)
2721

    
2722

    
2723
def SignX509Certificate(cert, key, salt):
  """Sign a X509 certificate.

  An RFC822-like signature header is added in front of the certificate.

  @type cert: OpenSSL.crypto.X509
  @param cert: X509 certificate object
  @type key: string
  @param key: Key for HMAC
  @type salt: string
  @param salt: Salt for HMAC
  @rtype: string
  @return: Serialized and signed certificate in PEM format

  """
  if not VALID_X509_SIGNATURE_SALT.match(salt):
    raise errors.GenericError("Invalid salt: %r" % salt)

  # serializing as PEM also guarantees a canonical representation of
  # the certificate before signing
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
  hmac_sig = Sha1Hmac(key, cert_pem, salt=salt)

  return ("%s: %s/%s\n\n%s" %
          (constants.X509_CERT_SIGNATURE_HEADER, salt, hmac_sig, cert_pem))
2748

    
2749

    
2750
def _ExtractX509CertificateSignature(cert_pem):
  """Helper function to extract signature from X509 certificate.

  """
  # the signature header must appear before the PEM payload itself
  for line in cert_pem.splitlines():
    if line.startswith("---"):
      # reached the certificate body without finding a signature
      break

    match = X509_SIGNATURE.match(line.strip())
    if match:
      return (match.group("salt"), match.group("sign"))

  raise errors.GenericError("X509 certificate signature is missing")
2764

    
2765

    
2766
def LoadSignedX509Certificate(cert_pem, key):
  """Verifies a signed X509 certificate.

  @type cert_pem: string
  @param cert_pem: Certificate in PEM format and with signature header
  @type key: string
  @param key: Key for HMAC
  @rtype: tuple; (OpenSSL.crypto.X509, string)
  @return: X509 certificate object and salt

  """
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)

  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)

  # re-dump to obtain the canonical PEM representation the signature
  # was computed over
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
    raise errors.GenericError("X509 certificate signature is invalid")

  return (cert, salt)
2789

    
2790

    
2791
def Sha1Hmac(key, text, salt=None):
  """Calculates the HMAC-SHA1 digest of a text.

  HMAC is defined in RFC2104.

  @type key: string
  @param key: Secret key
  @type text: string
  @param text: data to authenticate; if a salt is given it is
      prepended to the data before hashing

  """
  payload = (salt + text) if salt else text
  return hmac.new(key, payload, compat.sha1).hexdigest()
2807

    
2808

    
2809
def VerifySha1Hmac(key, text, digest, salt=None):
  """Verifies the HMAC-SHA1 digest of a text.

  HMAC is defined in RFC2104.

  @type key: string
  @param key: Secret key
  @type text: string
  @type digest: string
  @param digest: Expected digest
  @rtype: bool
  @return: Whether HMAC-SHA1 digest matches

  """
  expected = Sha1Hmac(key, text, salt=salt)
  return expected.lower() == digest.lower()
2824

    
2825

    
2826
def SafeEncode(text):
  """Return a 'safe' version of a source string.

  This function mangles the input string and returns a version that
  should be safe to display/encode as ASCII. To this end, we first
  convert it to ASCII using the 'backslashreplace' encoding which
  should get rid of any non-ASCII chars, and then we process it
  through a loop copied from the string repr sources in the python; we
  don't use string_escape anymore since that escape single quotes and
  backslashes too, and that is too much; and that escaping is not
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).

  @type text: str or unicode
  @param text: input data
  @rtype: str
  @return: a safe version of text

  """
  if not isinstance(text, str):
    # unicode input (Python 2): get rid of non-ASCII chars first; str
    # input is handled character by character below
    text = text.encode('ascii', 'backslashreplace')
  resu = ""
  for char in text:
    c = ord(char)
    if char  == '\t':
      resu += r'\t'
    elif char == '\n':
      resu += r'\n'
    elif char == '\r':
      # BUG FIX: this used to append r'\'r' (backslash, quote, 'r')
      # instead of the intended two-character escape sequence,
      # inconsistent with the \t and \n branches above
      resu += r'\r'
    elif c < 32 or c >= 127: # non-printable
      resu += "\\x%02x" % (c & 0xff)
    else:
      resu += char
  return resu
2861

    
2862

    
2863
def UnescapeAndSplit(text, sep=","):
  """Split and unescape a string based on a given separator.

  This function splits a string based on a separator where the
  separator itself can be escape in order to be an element of the
  elements. The escaping rules are (assuming coma being the
  separator):
    - a plain , separates the elements
    - a sequence \\\\, (double backslash plus comma) is handled as a
      backslash plus a separator comma
    - a sequence \, (backslash plus comma) is handled as a
      non-separator comma

  @type text: string
  @param text: the string to split
  @type sep: string
  @param sep: the separator
  @rtype: list
  @return: a list of strings

  """
  # first pass: naive split on the separator
  pieces = text.split(sep)
  # second pass: whenever a piece ends in an odd number of
  # backslashes, the separator following it was escaped, so re-join it
  # with the next piece (all backslashes are kept for now)
  merged = []
  idx = 0
  while idx < len(pieces):
    piece = pieces[idx]
    idx += 1
    trailing = len(piece) - len(piece.rstrip("\\"))
    if trailing % 2 == 1:
      piece = piece + sep + pieces[idx]
      idx += 1
    merged.append(piece)
  # third pass: collapse backslash-something into something
  return [re.sub(r"\\(.)", r"\1", piece) for piece in merged]
2903

    
2904

    
2905
def CommaJoin(names):
  """Join a sequence of identifiers with ", ".

  @param names: set, list or tuple of values
  @return: the string-converted elements, comma-joined

  """
  return ", ".join(map(str, names))
2913

    
2914

    
2915
def BytesToMebibyte(value):
  """Converts bytes to mebibytes.

  @type value: int
  @param value: Value in bytes
  @rtype: int
  @return: Value in mebibytes, rounded to the nearest integer

  """
  mebi = value / float(1 << 20)
  return int(round(mebi, 0))
2925

    
2926

    
2927
def CalculateDirectorySize(path):
  """Calculates the size of a directory recursively.

  @type path: string
  @param path: Path to directory
  @rtype: int
  @return: Size in mebibytes

  """
  total = 0

  for (dirpath, _, filenames) in os.walk(path):
    # lstat: symlinks are counted by their own size, not their target's
    total += sum(os.lstat(PathJoin(dirpath, name)).st_size
                 for name in filenames)

  return BytesToMebibyte(total)
2944

    
2945

    
2946
def GetMounts(filename=constants.PROC_MOUNTS):
  """Returns the list of mounted filesystems.

  This function is Linux-specific.

  @param filename: path of mounts file (/proc/mounts by default)
  @rtype: list of tuples
  @return: list of mount entries (device, mountpoint, fstype, options)

  """
  # TODO(iustin): investigate non-Linux options (e.g. via mount output)
  result = []
  for line in ReadFile(filename).splitlines():
    # the remaining fields after "options" (dump/pass) are ignored
    (device, mountpoint, fstype, options, _) = line.split(None, 4)
    result.append((device, mountpoint, fstype, options))
  return result
2964

    
2965

    
2966
def GetFilesystemStats(path):
  """Returns the total and free space on a filesystem.

  @type path: string
  @param path: Path on filesystem to be examined
  @rtype: int
  @return: tuple of (Total space, Free space) in mebibytes

  """
  stats = os.statvfs(path)

  # f_bavail: blocks available to unprivileged users
  total = BytesToMebibyte(stats.f_blocks * stats.f_frsize)
  free = BytesToMebibyte(stats.f_bavail * stats.f_frsize)

  return (total, free)
2980

    
2981

    
2982
def RunInSeparateProcess(fn, *args):
  """Runs a function in a separate process.

  Note: Only boolean return values are supported.

  @type fn: callable
  @param fn: Function to be called
  @rtype: bool
  @return: Function's result
  @raise errors.GenericError: if the child was killed by a signal or
    exited with a code other than 0 or 1

  """
  pid = os.fork()
  if pid == 0:
    # Child process
    try:
      # In case the function uses temporary files
      ResetTempfileModule()

      # Call function
      result = int(bool(fn(*args)))
      assert result in (0, 1)
    except: # pylint: disable-msg=W0702
      logging.exception("Error while calling function in separate process")
      # 0 and 1 are reserved for the return value
      result = 33

    # os._exit on purpose: it terminates immediately without running any
    # cleanup inherited from the parent process
    os._exit(result) # pylint: disable-msg=W0212

  # Parent process

  # Avoid zombies and check exit code
  (_, status) = os.waitpid(pid, 0)

  if os.WIFSIGNALED(status):
    exitcode = None
    signum = os.WTERMSIG(status)
  else:
    exitcode = os.WEXITSTATUS(status)
    signum = None

  # Anything but a clean exit with 0/1 (including the 33 error marker)
  # is reported as a failure
  if not (exitcode in (0, 1) and signum is None):
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
                              (exitcode, signum))

  return bool(exitcode)
3027

    
3028

    
3029
def IgnoreProcessNotFound(fn, *args, **kwargs):
3030
  """Ignores ESRCH when calling a process-related function.
3031

3032
  ESRCH is raised when a process is not found.
3033

3034
  @rtype: bool
3035
  @return: Whether process was found
3036

3037
  """
3038
  try:
3039
    fn(*args, **kwargs)
3040
  except EnvironmentError, err:
3041
    # Ignore ESRCH
3042
    if err.errno == errno.ESRCH:
3043
      return False
3044
    raise
3045

    
3046
  return True
3047

    
3048

    
3049
def IgnoreSignals(fn, *args, **kwargs):
  """Tries to call a function ignoring failures due to EINTR.

  @type fn: callable
  @param fn: Function to call
  @return: the result of C{fn}, or None if the call was interrupted
    by a signal (EINTR)

  """
  try:
    return fn(*args, **kwargs)
  except EnvironmentError, err:
    if err.errno == errno.EINTR:
      return None
    else:
      raise
  except (select.error, socket.error), err:
    # In python 2.6 and above select.error is an IOError, so it's handled
    # above, in 2.5 and below it's not, and it's handled here.
    # These errors carry the errno as args[0] instead of an errno attribute.
    if err.args and err.args[0] == errno.EINTR:
      return None
    else:
      raise
3067

    
3068

    
3069
def LockFile(fd):
  """Locks a file using POSIX locks.

  @type fd: int
  @param fd: the file descriptor we need to lock
  @raise errors.LockError: if the file is already locked by someone else

  """
  try:
    # Exclusive, non-blocking: fail immediately instead of waiting
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
  except IOError, err:
    # EAGAIN means another process holds the lock
    if err.errno == errno.EAGAIN:
      raise errors.LockError("File already locked")
    raise
3082

    
3083

    
3084
def FormatTime(val):
  """Formats a time value.

  @type val: float or None
  @param val: the timestamp as returned by time.time()
  @return: a string value or N/A if we don't have a valid timestamp

  """
  # isinstance also rejects None, covering the missing-timestamp case
  if not isinstance(val, (int, float)):
    return "N/A"
  # these two codes works on Linux, but they are not guaranteed on all
  # platforms
  return time.strftime("%F %T", time.localtime(val))
3097

    
3098

    
3099
def FormatSeconds(secs):
  """Formats seconds for easier reading.

  @type secs: number
  @param secs: Number of seconds
  @rtype: string
  @return: Formatted seconds (e.g. "2d 9h 19m 49s")

  """
  chunks = []

  remaining = round(secs, 0)

  if remaining > 0:
    # Negative values would be a bit tricky
    for (suffix, length) in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]:
      (count, remaining) = divmod(remaining, length)
      # Skip leading zero units, but keep zeros once a unit was emitted
      if count or chunks:
        chunks.append("%d%s" % (count, suffix))

  chunks.append("%ds" % remaining)

  return " ".join(chunks)
3122

    
3123

    
3124
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
  """Reads the watcher pause file.

  @type filename: string
  @param filename: Path to watcher pause file
  @type now: None, float or int
  @param now: Current time as Unix timestamp
  @type remove_after: int
  @param remove_after: Remove watcher pause file after specified amount of
    seconds past the pause end time
  @return: the pause end time as an integer, or None if the file doesn't
    exist, contains an invalid value or the pause has already ended

  """
  if now is None:
    now = time.time()

  try:
    value = ReadFile(filename)
  except IOError, err:
    # A missing file simply means "not paused"
    if err.errno != errno.ENOENT:
      raise
    value = None

  if value is not None:
    try:
      value = int(value)
    except ValueError:
      logging.warning(("Watcher pause file (%s) contains invalid value,"
                       " removing it"), filename)
      RemoveFile(filename)
      value = None

    if value is not None:
      # Remove file if it's outdated
      if now > (value + remove_after):
        RemoveFile(filename)
        value = None

      # Pause end time in the past means the pause is over
      elif now > value:
        value = None

  return value
3165

    
3166

    
3167
class RetryTimeout(Exception):
  """Retry loop timed out.

  Any arguments passed by the retried function to RetryAgain are preserved
  in RetryTimeout when it is raised. If such an argument was an exception,
  the RaiseInner helper method will reraise it.

  """
  def RaiseInner(self):
    """Re-raises a wrapped exception argument, or a fresh RetryTimeout."""
    args = self.args
    if args and isinstance(args[0], Exception):
      raise args[0]
    raise RetryTimeout(*args)
3180

    
3181

    
3182
class RetryAgain(Exception):
  """Retry again.

  Any arguments passed to RetryAgain will be preserved, if a timeout occurs,
  as arguments to RetryTimeout. If an exception is passed, the RaiseInner()
  method of the resulting RetryTimeout can be used to reraise it.

  """

    
3191

    
3192
class _RetryDelayCalculator(object):
3193
  """Calculator for increasing delays.
3194

3195
  """
3196
  __slots__ = [
3197
    "_factor",
3198
    "_limit",
3199
    "_next",
3200
    "_start",
3201
    ]
3202

    
3203
  def __init__(self, start, factor, limit):
3204
    """Initializes this class.
3205

3206
    @type start: float
3207
    @param start: Initial delay
3208
    @type factor: float
3209
    @param factor: Factor for delay increase
3210
    @type limit: float or None
3211
    @param limit: Upper limit for delay or None for no limit
3212

3213
    """
3214
    assert start > 0.0
3215
    assert factor >= 1.0
3216
    assert limit is None or limit >= 0.0
3217

    
3218
    self._start = start
3219
    self._factor = factor
3220
    self._limit = limit
3221

    
3222
    self._next = start
3223

    
3224
  def __call__(self):
3225
    """Returns current delay and calculates the next one.
3226

3227
    """
3228
    current = self._next
3229

    
3230
    # Update for next run
3231
    if self._limit is None or self._next < self._limit:
3232
      self._next = min(self._limit, self._next * self._factor)
3233

    
3234
    return current
3235

    
3236

    
3237
#: Special delay to specify whole remaining timeout
3238
RETRY_REMAINING_TIME = object()
3239

    
3240

    
3241
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
          _time_fn=time.time):
  """Call a function repeatedly until it succeeds.

  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
  total of C{timeout} seconds, this function throws L{RetryTimeout}.

  C{delay} can be one of the following:
    - callable returning the delay length as a float
    - Tuple of (start, factor, limit)
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
      useful when overriding L{wait_fn} to wait for an external event)
    - A static delay as a number (int or float)

  @type fn: callable
  @param fn: Function to be called
  @param delay: Either a callable (returning the delay), a tuple of (start,
                factor, limit) (see L{_RetryDelayCalculator}),
                L{RETRY_REMAINING_TIME} or a number (int or float)
  @type timeout: float
  @param timeout: Total timeout
  @type wait_fn: callable
  @param wait_fn: Waiting function
  @return: Return value of function
  @raise RetryTimeout: when the timeout expires; any arguments of the last
    L{RetryAgain} are preserved in the raised exception

  """
  assert callable(fn)
  assert callable(wait_fn)
  assert callable(_time_fn)

  if args is None:
    args = []

  end_time = _time_fn() + timeout

  if callable(delay):
    # External function to calculate delay
    calc_delay = delay

  elif isinstance(delay, (tuple, list)):
    # Increasing delay with optional upper boundary
    (start, factor, limit) = delay
    calc_delay = _RetryDelayCalculator(start, factor, limit)

  elif delay is RETRY_REMAINING_TIME:
    # Always use the remaining time
    calc_delay = None

  else:
    # Static delay
    calc_delay = lambda: delay

  assert calc_delay is None or callable(calc_delay)

  while True:
    retry_args = []
    try:
      # pylint: disable-msg=W0142
      return fn(*args)
    except RetryAgain, err:
      # Remember the arguments for a possible RetryTimeout later
      retry_args = err.args
    except RetryTimeout:
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
                                   " handle RetryTimeout")

    remaining_time = end_time - _time_fn()

    if remaining_time < 0.0:
      # pylint: disable-msg=W0142
      raise RetryTimeout(*retry_args)

    assert remaining_time >= 0.0

    if calc_delay is None:
      # RETRY_REMAINING_TIME: sleep through whatever time is left
      wait_fn(remaining_time)
    else:
      current_delay = calc_delay()
      if current_delay > 0.0:
        wait_fn(current_delay)

    
3323
def GetClosedTempfile(*args, **kwargs):
  """Creates a temporary file and returns its path.

  All arguments are passed through to L{tempfile.mkstemp}; the returned
  file descriptor is closed immediately.

  @rtype: string
  @return: path of the created temporary file

  """
  fd, path = tempfile.mkstemp(*args, **kwargs)
  _CloseFDNoErr(fd)
  return path
3330

    
3331

    
3332
def GenerateSelfSignedX509Cert(common_name, validity):
  """Generates a self-signed X509 certificate.

  @type common_name: string
  @param common_name: commonName value
  @type validity: int
  @param validity: Validity for certificate in seconds
  @rtype: tuple
  @return: (PEM-encoded private key, PEM-encoded certificate)

  """
  # Create private and public key
  pkey = OpenSSL.crypto.PKey()
  pkey.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)

  # Create self-signed certificate
  cert = OpenSSL.crypto.X509()
  if common_name:
    cert.get_subject().CN = common_name
  cert.set_serial_number(1)
  cert.gmtime_adj_notBefore(0)
  cert.gmtime_adj_notAfter(validity)
  # Self-signed: issuer equals subject
  cert.set_issuer(cert.get_subject())
  cert.set_pubkey(pkey)
  cert.sign(pkey, constants.X509_CERT_SIGN_DIGEST)

  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, pkey)
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                             cert)

  return (key_pem, cert_pem)
3360

    
3361

    
3362
def GenerateSelfSignedSslCert(filename, common_name=constants.X509_CERT_CN,
3363
                              validity=constants.X509_CERT_DEFAULT_VALIDITY):
3364
  """Legacy function to generate self-signed X509 certificate.
3365

3366
  @type filename: str
3367
  @param filename: path to write certificate to
3368
  @type common_name: string
3369
  @param common_name: commonName value
3370
  @type validity: int
3371
  @param validity: validity of certificate in number of days
3372

3373
  """
3374
  # TODO: Investigate using the cluster name instead of X505_CERT_CN for
3375
  # common_name, as cluster-renames are very seldom, and it'd be nice if RAPI
3376
  # and node daemon certificates have the proper Subject/Issuer.
3377
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(common_name,
3378
                                                   validity * 24 * 60 * 60)
3379

    
3380
  WriteFile(filename, mode=0400, data=key_pem + cert_pem)
3381

    
3382

    
3383
class FileLock(object):
  """Utility class for file locks.

  Wraps C{fcntl.flock} on a file object, with optional timeout-based
  retrying via L{Retry}.

  """
  def __init__(self, fd, filename):
    """Constructor for FileLock.

    @type fd: file
    @param fd: File object
    @type filename: str
    @param filename: Path of the file opened at I{fd}

    """
    self.fd = fd
    self.filename = filename

  @classmethod
  def Open(cls, filename):
    """Creates and opens a file to be used as a file-based lock.

    @type filename: string
    @param filename: path to the file to be locked

    """
    # Using "os.open" is necessary to allow both opening existing file
    # read/write and creating if not existing. Vanilla "open" will truncate an
    # existing file -or- allow creating if not existing.
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
               filename)

  def __del__(self):
    """Releases the lock on garbage collection by closing the file."""
    self.Close()

  def Close(self):
    """Close the file and release the lock.

    """
    # hasattr guards against partially-constructed instances
    if hasattr(self, "fd") and self.fd:
      self.fd.close()
      self.fd = None

  def _flock(self, flag, blocking, timeout, errmsg):
    """Wrapper for fcntl.flock.

    @type flag: int
    @param flag: operation flag
    @type blocking: bool
    @param blocking: whether the operation should be done in blocking mode.
    @type timeout: None or float
    @param timeout: for how long the operation should be retried (implies
                    non-blocking mode).
    @type errmsg: string
    @param errmsg: error message in case operation fails.

    """
    assert self.fd, "Lock was closed"
    assert timeout is None or timeout >= 0, \
      "If specified, timeout must be positive"
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        # Retry with increasing delay (0.1s start, factor 1.2, 1s cap)
        # until the overall timeout expires
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    """Performs the actual C{fcntl.flock} call.

    While a timeout-based retry is in progress, C{EAGAIN} is converted
    into L{RetryAgain} so L{Retry} keeps trying.

    """
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)
3514

    
3515

    
3516
class LineSplitter:
  """Splits data chunks into lines separated by newline.

  Instances provide a file-like interface.

  """
  def __init__(self, line_fn, *args):
    """Initializes this class.

    @type line_fn: callable
    @param line_fn: Function called for each line, first parameter is line
    @param args: Extra arguments for L{line_fn}

    """
    assert callable(line_fn)

    if args:
      # Python 2.4 doesn't have functools.partial yet
      self._line_fn = \
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
    else:
      self._line_fn = line_fn

    self._lines = collections.deque()
    self._buffer = ""

  def write(self, data):
    """Buffers C{data}, queueing any complete lines found."""
    pieces = (self._buffer + data).split("\n")
    # The last piece has no trailing newline yet, keep it buffered
    self._buffer = pieces[-1]
    self._lines.extend(pieces[:-1])

  def flush(self):
    """Invokes the line function for every complete queued line."""
    while self._lines:
      line = self._lines.popleft()
      self._line_fn(line.rstrip("\r\n"))

  def close(self):
    """Flushes queued lines and sends any trailing partial data."""
    self.flush()
    if self._buffer:
      self._line_fn(self._buffer)

    
3556

    
3557
def SignalHandled(signums):
  """Signal Handled decoration.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      handlers = kwargs.get('signal_handlers')
      assert handlers is None or isinstance(handlers, dict), \
             "Wrong signal_handlers parameter in original function call"
      if handlers is None:
        # First decorator in the stack creates the shared dict
        handlers = {}
        kwargs['signal_handlers'] = handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        # Always restore the previous handlers
        sighandler.Reset()
    return sig_function
  return wrap
3592

    
3593

    
3594
class SignalWakeupFd(object):
  """Wakeup file descriptor based on C{signal.set_wakeup_fd}.

  Wraps a pipe whose write end is registered via C{signal.set_wakeup_fd}
  (where available); reading from this object's file descriptor lets
  select/poll loops notice signal arrival. On Python versions without
  C{set_wakeup_fd} the registration is a no-op.

  """
  try:
    # This is only supported in Python 2.5 and above (some distributions
    # backported it to Python 2.4)
    _set_wakeup_fd_fn = signal.set_wakeup_fd
  except AttributeError:
    # Not supported
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
      return -1
  else:
    def _SetWakeupFd(self, fd):
      return self._set_wakeup_fd_fn(fd)

  def __init__(self):
    """Initializes this class.

    """
    (read_fd, write_fd) = os.pipe()

    # Once these succeeded, the file descriptors will be closed automatically.
    # Buffer size 0 is important, otherwise .read() with a specified length
    # might buffer data and the file descriptors won't be marked readable.
    self._read_fh = os.fdopen(read_fd, "r", 0)
    self._write_fh = os.fdopen(write_fd, "w", 0)

    # Remember the previously registered fd so it can be restored in Reset
    self._previous = self._SetWakeupFd(self._write_fh.fileno())

    # Utility functions
    self.fileno = self._read_fh.fileno
    self.read = self._read_fh.read

  def Reset(self):
    """Restores the previous wakeup file descriptor.

    """
    if hasattr(self, "_previous") and self._previous is not None:
      self._SetWakeupFd(self._previous)
      self._previous = None

  def Notify(self):
    """Notifies the wakeup file descriptor.

    """
    self._write_fh.write("\0")

  def __del__(self):
    """Called before object deletion.

    """
    self.Reset()

    
3645

    
3646
class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when deconstructed or
  when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: list
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None, wakeup=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function
    @param wakeup: if not None, its C{Notify} method is called whenever a
        handled signal arrives (see L{SignalWakeupFd})

    """
    assert handler_fn is None or callable(handler_fn)

    # NOTE(review): despite the docstring, set() requires an iterable, so a
    # bare int signum would raise TypeError here; callers must pass a
    # sequence of signal numbers
    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn
    self._wakeup = wakeup

    self._previous = {}
    try:
      # The loop variable intentionally shadows the constructor parameter
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    """Restores the original handlers on garbage collection."""
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._wakeup:
      # Notify whoever is interested in signals
      self._wakeup.Notify()

    if self._handler_fn:
      self._handler_fn(signum, frame)

    
3732

    
3733
class FieldSet(object):
  """A simple field set.

  Among the features are:
    - checking if a string is among a list of static string or regex objects
    - checking if a whole list of string matches
    - returning the matching groups from a regex match

  Internally, all fields are held as regular expression objects.

  """
  def __init__(self, *items):
    """Initializes the set, anchoring each item as C{^item$}."""
    self.items = [re.compile("^%s$" % value) for value in items]

  def Extend(self, other_set):
    """Extend the field set with the items from another one"""
    self.items.extend(other_set.items)

  def Matches(self, field):
    """Checks if a field matches the current set

    @type field: str
    @param field: the string to match
    @return: either None or a regular expression match object

    """
    # A direct loop instead of the previous itertools.ifilter construct:
    # equivalent, clearer, and not dependent on the Python-2-only ifilter
    for pattern in self.items:
      m = pattern.match(field)
      if m:
        return m
    return None

  def NonMatching(self, items):
    """Returns the list of fields not matching the current set

    @type items: list
    @param items: the list of fields to check
    @rtype: list
    @return: list of non-matching fields

    """
    return [val for val in items if not self.Matches(val)]