
1
#
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import os
31
import sys
32
import time
33
import subprocess
34
import re
35
import socket
36
import tempfile
37
import shutil
38
import errno
39
import pwd
40
import itertools
41
import select
42
import fcntl
43
import resource
44
import logging
45
import logging.handlers
46
import signal
47
import OpenSSL
48
import datetime
49
import calendar
50
import hmac
51

    
52
from cStringIO import StringIO
53

    
54
try:
55
  from hashlib import sha1
56
except ImportError:
57
  import sha as sha1
58

    
59
from ganeti import errors
60
from ganeti import constants
61

    
62

    
63
_locksheld = []
64
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
65

    
66
debug_locks = False
67

    
68
#: when set to True, L{RunCmd} is disabled
69
no_fork = False
70

    
71
_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
72

    
73
HEX_CHAR_RE = r"[a-zA-Z0-9]"
74
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
75
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
76
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
77
                             HEX_CHAR_RE, HEX_CHAR_RE),
78
                            re.S | re.I)
79

    
80

    
81
class RunResult(object):
82
  """Holds the result of running external programs.
83

84
  @type exit_code: int
85
  @ivar exit_code: the exit code of the program, or None (if the program
86
      didn't exit())
87
  @type signal: int or None
88
  @ivar signal: the signal that caused the program to finish, or None
89
      (if the program wasn't terminated by a signal)
90
  @type stdout: str
91
  @ivar stdout: the standard output of the program
92
  @type stderr: str
93
  @ivar stderr: the standard error of the program
94
  @type failed: boolean
95
  @ivar failed: True in case the program was
96
      terminated by a signal or exited with a non-zero exit code
97
  @ivar fail_reason: a string detailing the termination reason
98

99
  """
100
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
101
               "failed", "fail_reason", "cmd"]
102

    
103

    
104
  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
105
    self.cmd = cmd
106
    self.exit_code = exit_code
107
    self.signal = signal_
108
    self.stdout = stdout
109
    self.stderr = stderr
110
    self.failed = (signal_ is not None or exit_code != 0)
111

    
112
    if self.signal is not None:
113
      self.fail_reason = "terminated by signal %s" % self.signal
114
    elif self.exit_code is not None:
115
      self.fail_reason = "exited with exit code %s" % self.exit_code
116
    else:
117
      self.fail_reason = "unable to determine termination reason"
118

    
119
    if self.failed:
120
      logging.debug("Command '%s' failed (%s); output: %s",
121
                    self.cmd, self.fail_reason, self.output)
122

    
123
  def _GetOutput(self):
124
    """Returns the combined stdout and stderr for easier usage.
125

126
    """
127
    return self.stdout + self.stderr
128

    
129
  output = property(_GetOutput, None, None, "Return full output")
130

    
131

    
132
def _BuildCmdEnvironment(env, reset):
133
  """Builds the environment for an external program.
134

135
  """
136
  if reset:
137
    cmd_env = {}
138
  else:
139
    cmd_env = os.environ.copy()
140
    cmd_env["LC_ALL"] = "C"
141

    
142
  if env is not None:
143
    cmd_env.update(env)
144

    
145
  return cmd_env
146

    
147

    
148
def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
149
  """Execute a (shell) command.
150

151
  The command should not read from its standard input, as it will be
152
  closed.
153

154
  @type cmd: string or list
155
  @param cmd: Command to run
156
  @type env: dict
157
  @param env: Additional environment variables
158
  @type output: str
159
  @param output: if desired, the output of the command can be
160
      saved in a file instead of the RunResult instance; this
161
      parameter denotes the file name (if not None)
162
  @type cwd: string
163
  @param cwd: if specified, will be used as the working
164
      directory for the command; the default will be /
165
  @type reset_env: boolean
166
  @param reset_env: whether to reset or keep the default os environment
167
  @rtype: L{RunResult}
168
  @return: RunResult instance
169
  @raise errors.ProgrammerError: if we call this when forks are disabled
170

171
  """
172
  if no_fork:
173
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")
174

    
175
  if isinstance(cmd, basestring):
176
    strcmd = cmd
177
    shell = True
178
  else:
179
    cmd = [str(val) for val in cmd]
180
    strcmd = ShellQuoteArgs(cmd)
181
    shell = False
182

    
183
  if output:
184
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
185
  else:
186
    logging.debug("RunCmd %s", strcmd)
187

    
188
  cmd_env = _BuildCmdEnvironment(env, reset_env)
189

    
190
  try:
191
    if output is None:
192
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
193
    else:
194
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
195
      out = err = ""
196
  except OSError, err:
197
    if err.errno == errno.ENOENT:
198
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
199
                               (strcmd, err))
200
    else:
201
      raise
202

    
203
  if status >= 0:
204
    exitcode = status
205
    signal_ = None
206
  else:
207
    exitcode = None
208
    signal_ = -status
209

    
210
  return RunResult(exitcode, signal_, out, err, strcmd)
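# Illustrative usage sketch, not part of the original module; the command and
# the error handling below are assumptions, shown only to demonstrate how the
# returned RunResult is typically consumed:
#
#   result = RunCmd(["cat", "/etc/hostname"])
#   if result.failed:
#     logging.error("cat failed (%s): %s", result.fail_reason, result.output)
#   else:
#     hostname = result.stdout.strip()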
211

    
212

    
213
def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
214
                pidfile=None):
215
  """Start a daemon process after forking twice.
216

217
  @type cmd: string or list
218
  @param cmd: Command to run
219
  @type env: dict
220
  @param env: Additional environment variables
221
  @type cwd: string
222
  @param cwd: Working directory for the program
223
  @type output: string
224
  @param output: Path to file in which to save the output
225
  @type output_fd: int
226
  @param output_fd: File descriptor for output
227
  @type pidfile: string
228
  @param pidfile: Process ID file
229
  @rtype: int
230
  @return: Daemon process ID
231
  @raise errors.ProgrammerError: if we call this when forks are disabled
232

233
  """
234
  if no_fork:
235
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
236
                                 " disabled")
237

    
238
  if output and not (bool(output) ^ (output_fd is not None)):
239
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
240
                                 " specified")
241

    
242
  if isinstance(cmd, basestring):
243
    cmd = ["/bin/sh", "-c", cmd]
244

    
245
  strcmd = ShellQuoteArgs(cmd)
246

    
247
  if output:
248
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
249
  else:
250
    logging.debug("StartDaemon %s", strcmd)
251

    
252
  cmd_env = _BuildCmdEnvironment(env, False)
253

    
254
  # Create pipe for sending PID back
255
  (pidpipe_read, pidpipe_write) = os.pipe()
256
  try:
257
    try:
258
      # Create pipe for sending error messages
259
      (errpipe_read, errpipe_write) = os.pipe()
260
      try:
261
        try:
262
          # First fork
263
          pid = os.fork()
264
          if pid == 0:
265
            try:
266
              # Child process, won't return
267
              _StartDaemonChild(errpipe_read, errpipe_write,
268
                                pidpipe_read, pidpipe_write,
269
                                cmd, cmd_env, cwd,
270
                                output, output_fd, pidfile)
271
            finally:
272
              # Well, maybe child process failed
273
              os._exit(1) # pylint: disable-msg=W0212
274
        finally:
275
          _CloseFDNoErr(errpipe_write)
276

    
277
        # Wait for daemon to be started (or an error message to arrive) and read
278
        # up to 100 KB as an error message
279
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
280
      finally:
281
        _CloseFDNoErr(errpipe_read)
282
    finally:
283
      _CloseFDNoErr(pidpipe_write)
284

    
285
    # Read up to 128 bytes for PID
286
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
287
  finally:
288
    _CloseFDNoErr(pidpipe_read)
289

    
290
  # Try to avoid zombies by waiting for child process
291
  try:
292
    os.waitpid(pid, 0)
293
  except OSError:
294
    pass
295

    
296
  if errormsg:
297
    raise errors.OpExecError("Error when starting daemon process: %r" %
298
                             errormsg)
299

    
300
  try:
301
    return int(pidtext)
302
  except (ValueError, TypeError), err:
303
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
304
                             (pidtext, err))
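# Illustrative usage sketch (the daemon name and paths are hypothetical):
# StartDaemon returns the PID of the detached process, which can then be
# logged or checked, e.g. with IsProcessAlive defined further below.
#
#   pid = StartDaemon(["/usr/sbin/exampled", "--foreground"],
#                     output="/var/log/exampled.log",
#                     pidfile="/var/run/exampled.pid")
#   logging.info("exampled started with PID %d", pid)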
305

    
306

    
307
def _StartDaemonChild(errpipe_read, errpipe_write,
308
                      pidpipe_read, pidpipe_write,
309
                      args, env, cwd,
310
                      output, fd_output, pidfile):
311
  """Child process for starting daemon.
312

313
  """
314
  try:
315
    # Close parent's side
316
    _CloseFDNoErr(errpipe_read)
317
    _CloseFDNoErr(pidpipe_read)
318

    
319
    # First child process
320
    os.chdir("/")
321
    os.umask(077)
322
    os.setsid()
323

    
324
    # And fork for the second time
325
    pid = os.fork()
326
    if pid != 0:
327
      # Exit first child process
328
      os._exit(0) # pylint: disable-msg=W0212
329

    
330
    # Make sure pipe is closed on execv* (and thereby notifies original process)
331
    SetCloseOnExecFlag(errpipe_write, True)
332

    
333
    # List of file descriptors to be left open
334
    noclose_fds = [errpipe_write]
335

    
336
    # Open PID file
337
    if pidfile:
338
      try:
339
        # TODO: Atomic replace with another locked file instead of writing into
340
        # it after creating
341
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)
342

    
343
        # Lock the PID file (and fail if not possible to do so). Any code
344
        # wanting to send a signal to the daemon should try to lock the PID
345
        # file before reading it. If acquiring the lock succeeds, the daemon is
346
        # no longer running and the signal should not be sent.
347
        LockFile(fd_pidfile)
348

    
349
        os.write(fd_pidfile, "%d\n" % os.getpid())
350
      except Exception, err:
351
        raise Exception("Creating and locking PID file failed: %s" % err)
352

    
353
      # Keeping the file open to hold the lock
354
      noclose_fds.append(fd_pidfile)
355

    
356
      SetCloseOnExecFlag(fd_pidfile, False)
357
    else:
358
      fd_pidfile = None
359

    
360
    # Open /dev/null
361
    fd_devnull = os.open(os.devnull, os.O_RDWR)
362

    
363
    assert not output or (bool(output) ^ (fd_output is not None))
364

    
365
    if fd_output is not None:
366
      pass
367
    elif output:
368
      # Open output file
369
      try:
370
        # TODO: Implement flag to set append=yes/no
371
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
372
      except EnvironmentError, err:
373
        raise Exception("Opening output file failed: %s" % err)
374
    else:
375
      fd_output = fd_devnull
376

    
377
    # Redirect standard I/O
378
    os.dup2(fd_devnull, 0)
379
    os.dup2(fd_output, 1)
380
    os.dup2(fd_output, 2)
381

    
382
    # Send daemon PID to parent
383
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))
384

    
385
    # Close all file descriptors except stdio and error message pipe
386
    CloseFDs(noclose_fds=noclose_fds)
387

    
388
    # Change working directory
389
    os.chdir(cwd)
390

    
391
    if env is None:
392
      os.execvp(args[0], args)
393
    else:
394
      os.execvpe(args[0], args, env)
395
  except: # pylint: disable-msg=W0702
396
    try:
397
      # Report errors to original process
398
      buf = str(sys.exc_info()[1])
399

    
400
      RetryOnSignal(os.write, errpipe_write, buf)
401
    except: # pylint: disable-msg=W0702
402
      # Ignore errors in error handling
403
      pass
404

    
405
  os._exit(1) # pylint: disable-msg=W0212
406

    
407

    
408
def _RunCmdPipe(cmd, env, via_shell, cwd):
409
  """Run a command and return its output.
410

411
  @type  cmd: string or list
412
  @param cmd: Command to run
413
  @type env: dict
414
  @param env: The environment to use
415
  @type via_shell: bool
416
  @param via_shell: if we should run via the shell
417
  @type cwd: string
418
  @param cwd: the working directory for the program
419
  @rtype: tuple
420
  @return: (out, err, status)
421

422
  """
423
  poller = select.poll()
424
  child = subprocess.Popen(cmd, shell=via_shell,
425
                           stderr=subprocess.PIPE,
426
                           stdout=subprocess.PIPE,
427
                           stdin=subprocess.PIPE,
428
                           close_fds=True, env=env,
429
                           cwd=cwd)
430

    
431
  child.stdin.close()
432
  poller.register(child.stdout, select.POLLIN)
433
  poller.register(child.stderr, select.POLLIN)
434
  out = StringIO()
435
  err = StringIO()
436
  fdmap = {
437
    child.stdout.fileno(): (out, child.stdout),
438
    child.stderr.fileno(): (err, child.stderr),
439
    }
440
  for fd in fdmap:
441
    SetNonblockFlag(fd, True)
442

    
443
  while fdmap:
444
    pollresult = RetryOnSignal(poller.poll)
445

    
446
    for fd, event in pollresult:
447
      if event & select.POLLIN or event & select.POLLPRI:
448
        data = fdmap[fd][1].read()
449
        # no data from read signifies EOF (the same as POLLHUP)
450
        if not data:
451
          poller.unregister(fd)
452
          del fdmap[fd]
453
          continue
454
        fdmap[fd][0].write(data)
455
      if (event & select.POLLNVAL or event & select.POLLHUP or
456
          event & select.POLLERR):
457
        poller.unregister(fd)
458
        del fdmap[fd]
459

    
460
  out = out.getvalue()
461
  err = err.getvalue()
462

    
463
  status = child.wait()
464
  return out, err, status
465

    
466

    
467
def _RunCmdFile(cmd, env, via_shell, output, cwd):
468
  """Run a command and save its output to a file.
469

470
  @type  cmd: string or list
471
  @param cmd: Command to run
472
  @type env: dict
473
  @param env: The environment to use
474
  @type via_shell: bool
475
  @param via_shell: if we should run via the shell
476
  @type output: str
477
  @param output: the filename in which to save the output
478
  @type cwd: string
479
  @param cwd: the working directory for the program
480
  @rtype: int
481
  @return: the exit status
482

483
  """
484
  fh = open(output, "a")
485
  try:
486
    child = subprocess.Popen(cmd, shell=via_shell,
487
                             stderr=subprocess.STDOUT,
488
                             stdout=fh,
489
                             stdin=subprocess.PIPE,
490
                             close_fds=True, env=env,
491
                             cwd=cwd)
492

    
493
    child.stdin.close()
494
    status = child.wait()
495
  finally:
496
    fh.close()
497
  return status
498

    
499

    
500
def SetCloseOnExecFlag(fd, enable):
501
  """Sets or unsets the close-on-exec flag on a file descriptor.
502

503
  @type fd: int
504
  @param fd: File descriptor
505
  @type enable: bool
506
  @param enable: Whether to set or unset it.
507

508
  """
509
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)
510

    
511
  if enable:
512
    flags |= fcntl.FD_CLOEXEC
513
  else:
514
    flags &= ~fcntl.FD_CLOEXEC
515

    
516
  fcntl.fcntl(fd, fcntl.F_SETFD, flags)
517

    
518

    
519
def SetNonblockFlag(fd, enable):
520
  """Sets or unsets the O_NONBLOCK flag on on a file descriptor.
521

522
  @type fd: int
523
  @param fd: File descriptor
524
  @type enable: bool
525
  @param enable: Whether to set or unset it
526

527
  """
528
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)
529

    
530
  if enable:
531
    flags |= os.O_NONBLOCK
532
  else:
533
    flags &= ~os.O_NONBLOCK
534

    
535
  fcntl.fcntl(fd, fcntl.F_SETFL, flags)
536

    
537

    
538
def RetryOnSignal(fn, *args, **kwargs):
539
  """Calls a function again if it failed due to EINTR.
540

541
  """
542
  while True:
543
    try:
544
      return fn(*args, **kwargs)
545
    except EnvironmentError, err:
546
      if err.errno != errno.EINTR:
547
        raise
548
    except select.error, err:
549
      if not (err.args and err.args[0] == errno.EINTR):
550
        raise
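# Illustrative sketch (fd and data are placeholder names): wrapping a raw
# os.write call so that an interrupted system call (EINTR) is retried instead
# of bubbling up:
#
#   written = RetryOnSignal(os.write, fd, data)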
551

    
552

    
553
def RunParts(dir_name, env=None, reset_env=False):
554
  """Run Scripts or programs in a directory
555

556
  @type dir_name: string
557
  @param dir_name: absolute path to a directory
558
  @type env: dict
559
  @param env: The environment to use
560
  @type reset_env: boolean
561
  @param reset_env: whether to reset or keep the default os environment
562
  @rtype: list of tuples
563
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)
564

565
  """
566
  rr = []
567

    
568
  try:
569
    dir_contents = ListVisibleFiles(dir_name)
570
  except OSError, err:
571
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
572
    return rr
573

    
574
  for relname in sorted(dir_contents):
575
    fname = PathJoin(dir_name, relname)
576
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
577
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
578
      rr.append((relname, constants.RUNPARTS_SKIP, None))
579
    else:
580
      try:
581
        result = RunCmd([fname], env=env, reset_env=reset_env)
582
      except Exception, err: # pylint: disable-msg=W0703
583
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
584
      else:
585
        rr.append((relname, constants.RUNPARTS_RUN, result))
586

    
587
  return rr
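# Illustrative sketch (the directory name is an assumption): running every
# eligible script in a directory and logging the ones that ran but failed.
#
#   for (relname, status, runresult) in RunParts("/etc/example/hooks.d"):
#     if status == constants.RUNPARTS_RUN and runresult.failed:
#       logging.warning("Script %s failed: %s", relname, runresult.output)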
588

    
589

    
590
def RemoveFile(filename):
591
  """Remove a file ignoring some errors.
592

593
  Remove a file, ignoring non-existing ones or directories. Other
594
  errors are passed.
595

596
  @type filename: str
597
  @param filename: the file to be removed
598

599
  """
600
  try:
601
    os.unlink(filename)
602
  except OSError, err:
603
    if err.errno not in (errno.ENOENT, errno.EISDIR):
604
      raise
605

    
606

    
607
def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
608
  """Renames a file.
609

610
  @type old: string
611
  @param old: Original path
612
  @type new: string
613
  @param new: New path
614
  @type mkdir: bool
615
  @param mkdir: Whether to create target directory if it doesn't exist
616
  @type mkdir_mode: int
617
  @param mkdir_mode: Mode for newly created directories
618

619
  """
620
  try:
621
    return os.rename(old, new)
622
  except OSError, err:
623
    # In at least one use case of this function, the job queue, directory
624
    # creation is very rare. Checking for the directory before renaming is not
625
    # as efficient.
626
    if mkdir and err.errno == errno.ENOENT:
627
      # Create directory and try again
628
      dirname = os.path.dirname(new)
629
      try:
630
        os.makedirs(dirname, mode=mkdir_mode)
631
      except OSError, err:
632
        # Ignore EEXIST. This is only handled in os.makedirs as included in
633
        # Python 2.5 and above.
634
        if err.errno != errno.EEXIST or not os.path.exists(dirname):
635
          raise
636

    
637
      return os.rename(old, new)
638

    
639
    raise
640

    
641

    
642
def ResetTempfileModule():
643
  """Resets the random name generator of the tempfile module.
644

645
  This function should be called after C{os.fork} in the child process to
646
  ensure it creates a newly seeded random generator. Otherwise it would
647
  generate the same random parts as the parent process. If several processes
648
  race for the creation of a temporary file, this could lead to one not getting
649
  a temporary name.
650

651
  """
652
  # pylint: disable-msg=W0212
653
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
654
    tempfile._once_lock.acquire()
655
    try:
656
      # Reset random name generator
657
      tempfile._name_sequence = None
658
    finally:
659
      tempfile._once_lock.release()
660
  else:
661
    logging.critical("The tempfile module misses at least one of the"
662
                     " '_once_lock' and '_name_sequence' attributes")
663

    
664

    
665
def _FingerprintFile(filename):
666
  """Compute the fingerprint of a file.
667

668
  If the file does not exist, a None will be returned
669
  instead.
670

671
  @type filename: str
672
  @param filename: the filename to checksum
673
  @rtype: str
674
  @return: the hex digest of the sha checksum of the contents
675
      of the file
676

677
  """
678
  if not (os.path.exists(filename) and os.path.isfile(filename)):
679
    return None
680

    
681
  f = open(filename)
682

    
683
  if callable(sha1):
684
    fp = sha1()
685
  else:
686
    fp = sha1.new()
687
  while True:
688
    data = f.read(4096)
689
    if not data:
690
      break
691

    
692
    fp.update(data)
693

    
694
  return fp.hexdigest()
695

    
696

    
697
def FingerprintFiles(files):
698
  """Compute fingerprints for a list of files.
699

700
  @type files: list
701
  @param files: the list of filename to fingerprint
702
  @rtype: dict
703
  @return: a dictionary filename: fingerprint, holding only
704
      existing files
705

706
  """
707
  ret = {}
708

    
709
  for filename in files:
710
    cksum = _FingerprintFile(filename)
711
    if cksum:
712
      ret[filename] = cksum
713

    
714
  return ret
715

    
716

    
717
def ForceDictType(target, key_types, allowed_values=None):
718
  """Force the values of a dict to have certain types.
719

720
  @type target: dict
721
  @param target: the dict to update
722
  @type key_types: dict
723
  @param key_types: dict mapping target dict keys to types
724
                    in constants.ENFORCEABLE_TYPES
725
  @type allowed_values: list
726
  @keyword allowed_values: list of specially allowed values
727

728
  """
729
  if allowed_values is None:
730
    allowed_values = []
731

    
732
  if not isinstance(target, dict):
733
    msg = "Expected dictionary, got '%s'" % target
734
    raise errors.TypeEnforcementError(msg)
735

    
736
  for key in target:
737
    if key not in key_types:
738
      msg = "Unknown key '%s'" % key
739
      raise errors.TypeEnforcementError(msg)
740

    
741
    if target[key] in allowed_values:
742
      continue
743

    
744
    ktype = key_types[key]
745
    if ktype not in constants.ENFORCEABLE_TYPES:
746
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
747
      raise errors.ProgrammerError(msg)
748

    
749
    if ktype == constants.VTYPE_STRING:
750
      if not isinstance(target[key], basestring):
751
        if isinstance(target[key], bool) and not target[key]:
752
          target[key] = ''
753
        else:
754
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
755
          raise errors.TypeEnforcementError(msg)
756
    elif ktype == constants.VTYPE_BOOL:
757
      if isinstance(target[key], basestring) and target[key]:
758
        if target[key].lower() == constants.VALUE_FALSE:
759
          target[key] = False
760
        elif target[key].lower() == constants.VALUE_TRUE:
761
          target[key] = True
762
        else:
763
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
764
          raise errors.TypeEnforcementError(msg)
765
      elif target[key]:
766
        target[key] = True
767
      else:
768
        target[key] = False
769
    elif ktype == constants.VTYPE_SIZE:
770
      try:
771
        target[key] = ParseUnit(target[key])
772
      except errors.UnitParseError, err:
773
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
774
              (key, target[key], err)
775
        raise errors.TypeEnforcementError(msg)
776
    elif ktype == constants.VTYPE_INT:
777
      try:
778
        target[key] = int(target[key])
779
      except (ValueError, TypeError):
780
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
781
        raise errors.TypeEnforcementError(msg)
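# Illustrative sketch (the keys and the exact constants values are
# assumptions): coercing user-supplied parameters in place.
#
#   params = {"size": "512M", "auto_balance": "true"}
#   ForceDictType(params, {"size": constants.VTYPE_SIZE,
#                          "auto_balance": constants.VTYPE_BOOL})
#   # params would now hold something like {"size": 512, "auto_balance": True}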
782

    
783

    
784
def IsProcessAlive(pid):
785
  """Check if a given pid exists on the system.
786

787
  @note: zombie status is not handled, so zombie processes
788
      will be returned as alive
789
  @type pid: int
790
  @param pid: the process ID to check
791
  @rtype: boolean
792
  @return: True if the process exists
793

794
  """
795
  if pid <= 0:
796
    return False
797

    
798
  try:
799
    os.stat("/proc/%d/status" % pid)
800
    return True
801
  except EnvironmentError, err:
802
    if err.errno in (errno.ENOENT, errno.ENOTDIR):
803
      return False
804
    raise
805

    
806

    
807
def ReadPidFile(pidfile):
808
  """Read a pid from a file.
809

810
  @type  pidfile: string
811
  @param pidfile: path to the file containing the pid
812
  @rtype: int
813
  @return: The process id, if the file exists and contains a valid PID,
814
           otherwise 0
815

816
  """
817
  try:
818
    raw_data = ReadFile(pidfile)
819
  except EnvironmentError, err:
820
    if err.errno != errno.ENOENT:
821
      logging.exception("Can't read pid file")
822
    return 0
823

    
824
  try:
825
    pid = int(raw_data)
826
  except (TypeError, ValueError), err:
827
    logging.info("Can't parse pid file contents", exc_info=True)
828
    return 0
829

    
830
  return pid
831

    
832

    
833
def MatchNameComponent(key, name_list, case_sensitive=True):
834
  """Try to match a name against a list.
835

836
  This function will try to match a name like test1 against a list
837
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
838
  this list, I{'test1'} as well as I{'test1.example'} will match, but
839
  not I{'test1.ex'}. A multiple match will be considered as no match
840
  at all (e.g. I{'test1'} against C{['test1.example.com',
841
  'test1.example.org']}), except when the key fully matches an entry
842
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
843

844
  @type key: str
845
  @param key: the name to be searched
846
  @type name_list: list
847
  @param name_list: the list of strings against which to search the key
848
  @type case_sensitive: boolean
849
  @param case_sensitive: whether to provide a case-sensitive match
850

851
  @rtype: None or str
852
  @return: None if there is no match I{or} if there are multiple matches,
853
      otherwise the element from the list which matches
854

855
  """
856
  if key in name_list:
857
    return key
858

    
859
  re_flags = 0
860
  if not case_sensitive:
861
    re_flags |= re.IGNORECASE
862
    key = key.upper()
863
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
864
  names_filtered = []
865
  string_matches = []
866
  for name in name_list:
867
    if mo.match(name) is not None:
868
      names_filtered.append(name)
869
      if not case_sensitive and key == name.upper():
870
        string_matches.append(name)
871

    
872
  if len(string_matches) == 1:
873
    return string_matches[0]
874
  if len(names_filtered) == 1:
875
    return names_filtered[0]
876
  return None
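# Worked examples of the matching rules described in the docstring above
# (illustrative comment only):
#
#   MatchNameComponent("test1", ["test1.example.com", "test2.example.com"])
#     -> "test1.example.com"
#   MatchNameComponent("test1", ["test1.example.com", "test1.example.org"])
#     -> None (multiple matches count as no match)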
877

    
878

    
879
class HostInfo:
880
  """Class implementing resolver and hostname functionality
881

882
  """
883
  _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$")
884

    
885
  def __init__(self, name=None):
886
    """Initialize the host name object.
887

888
    If the name argument is not passed, it will use this system's
889
    name.
890

891
    """
892
    if name is None:
893
      name = self.SysName()
894

    
895
    self.query = name
896
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
897
    self.ip = self.ipaddrs[0]
898

    
899
  def ShortName(self):
900
    """Returns the hostname without domain.
901

902
    """
903
    return self.name.split('.')[0]
904

    
905
  @staticmethod
906
  def SysName():
907
    """Return the current system's name.
908

909
    This is simply a wrapper over C{socket.gethostname()}.
910

911
    """
912
    return socket.gethostname()
913

    
914
  @staticmethod
915
  def LookupHostname(hostname):
916
    """Look up hostname
917

918
    @type hostname: str
919
    @param hostname: hostname to look up
920

921
    @rtype: tuple
922
    @return: a tuple (name, aliases, ipaddrs) as returned by
923
        C{socket.gethostbyname_ex}
924
    @raise errors.ResolverError: in case of errors in resolving
925

926
    """
927
    try:
928
      result = socket.gethostbyname_ex(hostname)
929
    except socket.gaierror, err:
930
      # hostname not found in DNS
931
      raise errors.ResolverError(hostname, err.args[0], err.args[1])
932

    
933
    return result
934

    
935
  @classmethod
936
  def NormalizeName(cls, hostname):
937
    """Validate and normalize the given hostname.
938

939
    @attention: the validation is a bit more relaxed than the standards
940
        require; most importantly, we allow underscores in names
941
    @raise errors.OpPrereqError: when the name is not valid
942

943
    """
944
    hostname = hostname.lower()
945
    if (not cls._VALID_NAME_RE.match(hostname) or
946
        # double-dots, meaning empty label
947
        ".." in hostname or
948
        # empty initial label
949
        hostname.startswith(".")):
950
      raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
951
                                 errors.ECODE_INVAL)
952
    if hostname.endswith("."):
953
      hostname = hostname.rstrip(".")
954
    return hostname
955

    
956

    
957
def GetHostInfo(name=None):
958
  """Lookup host name and raise an OpPrereqError for failures"""
959

    
960
  try:
961
    return HostInfo(name)
962
  except errors.ResolverError, err:
963
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
964
                               (err[0], err[2]), errors.ECODE_RESOLVER)
965

    
966

    
967
def ListVolumeGroups():
968
  """List volume groups and their size
969

970
  @rtype: dict
971
  @return:
972
       Dictionary with volume group names as keys and their
973
       sizes as values
974

975
  """
976
  command = "vgs --noheadings --units m --nosuffix -o name,size"
977
  result = RunCmd(command)
978
  retval = {}
979
  if result.failed:
980
    return retval
981

    
982
  for line in result.stdout.splitlines():
983
    try:
984
      name, size = line.split()
985
      size = int(float(size))
986
    except (IndexError, ValueError), err:
987
      logging.error("Invalid output from vgs (%s): %s", err, line)
988
      continue
989

    
990
    retval[name] = size
991

    
992
  return retval
993

    
994

    
995
def BridgeExists(bridge):
996
  """Check whether the given bridge exists in the system
997

998
  @type bridge: str
999
  @param bridge: the bridge name to check
1000
  @rtype: boolean
1001
  @return: True if it does
1002

1003
  """
1004
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
1005

    
1006

    
1007
def NiceSort(name_list):
1008
  """Sort a list of strings based on digit and non-digit groupings.
1009

1010
  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
1011
  will sort the list in the logical order C{['a1', 'a2', 'a10',
1012
  'a11']}.
1013

1014
  The sort algorithm breaks each name in groups of either only-digits
1015
  or no-digits. Only the first eight such groups are considered, and
1016
  after that we just use what's left of the string.
1017

1018
  @type name_list: list
1019
  @param name_list: the names to be sorted
1020
  @rtype: list
1021
  @return: a copy of the name list sorted with our algorithm
1022

1023
  """
1024
  _SORTER_BASE = "(\D+|\d+)"
1025
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
1026
                                                  _SORTER_BASE, _SORTER_BASE,
1027
                                                  _SORTER_BASE, _SORTER_BASE,
1028
                                                  _SORTER_BASE, _SORTER_BASE)
1029
  _SORTER_RE = re.compile(_SORTER_FULL)
1030
  _SORTER_NODIGIT = re.compile("^\D*$")
1031
  def _TryInt(val):
1032
    """Attempts to convert a variable to integer."""
1033
    if val is None or _SORTER_NODIGIT.match(val):
1034
      return val
1035
    rval = int(val)
1036
    return rval
1037

    
1038
  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
1039
             for name in name_list]
1040
  to_sort.sort()
1041
  return [tup[1] for tup in to_sort]
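# Example from the docstring above (illustrative comment only):
#
#   NiceSort(['a1', 'a10', 'a11', 'a2']) -> ['a1', 'a2', 'a10', 'a11']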
1042

    
1043

    
1044
def TryConvert(fn, val):
1045
  """Try to convert a value ignoring errors.
1046

1047
  This function tries to apply function I{fn} to I{val}. If no
1048
  C{ValueError} or C{TypeError} exceptions are raised, it will return
1049
  the result, else it will return the original value. Any other
1050
  exceptions are propagated to the caller.
1051

1052
  @type fn: callable
1053
  @param fn: function to apply to the value
1054
  @param val: the value to be converted
1055
  @return: The converted value if the conversion was successful,
1056
      otherwise the original value.
1057

1058
  """
1059
  try:
1060
    nv = fn(val)
1061
  except (ValueError, TypeError):
1062
    nv = val
1063
  return nv
1064

    
1065

    
1066
def IsValidIP(ip):
1067
  """Verifies the syntax of an IPv4 address.
1068

1069
  This function checks if the passed IPv4 address is valid or not based
1070
  on syntax (not IP range, class calculations, etc.).
1071

1072
  @type ip: str
1073
  @param ip: the address to be checked
1074
  @rtype: a regular expression match object
1075
  @return: a regular expression match object, or None if the
1076
      address is not valid
1077

1078
  """
1079
  unit = "(0|[1-9]\d{0,2})"
1080
  #TODO: convert and return only boolean
1081
  return re.match("^%s\.%s\.%s\.%s$" % (unit, unit, unit, unit), ip)
1082

    
1083

    
1084
def IsValidShellParam(word):
1085
  """Verifies is the given word is safe from the shell's p.o.v.
1086

1087
  This means that we can pass this to a command via the shell and be
1088
  sure that it doesn't alter the command line and is passed as such to
1089
  the actual command.
1090

1091
  Note that we are overly restrictive here, in order to be on the safe
1092
  side.
1093

1094
  @type word: str
1095
  @param word: the word to check
1096
  @rtype: boolean
1097
  @return: True if the word is 'safe'
1098

1099
  """
1100
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))
1101

    
1102

    
1103
def BuildShellCmd(template, *args):
1104
  """Build a safe shell command line from the given arguments.
1105

1106
  This function will check all arguments in the args list so that they
1107
  are valid shell parameters (i.e. they don't contain shell
1108
  metacharacters). If everything is ok, it will return the result of
1109
  template % args.
1110

1111
  @type template: str
1112
  @param template: the string holding the template for the
1113
      string formatting
1114
  @rtype: str
1115
  @return: the expanded command line
1116

1117
  """
1118
  for word in args:
1119
    if not IsValidShellParam(word):
1120
      raise errors.ProgrammerError("Shell argument '%s' contains"
1121
                                   " invalid characters" % word)
1122
  return template % args
1123

    
1124

    
1125
def FormatUnit(value, units):
1126
  """Formats an incoming number of MiB with the appropriate unit.
1127

1128
  @type value: int
1129
  @param value: integer representing the value in MiB (1048576)
1130
  @type units: char
1131
  @param units: the type of formatting we should do:
1132
      - 'h' for automatic scaling
1133
      - 'm' for MiBs
1134
      - 'g' for GiBs
1135
      - 't' for TiBs
1136
  @rtype: str
1137
  @return: the formatted value (with suffix)
1138

1139
  """
1140
  if units not in ('m', 'g', 't', 'h'):
1141
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))
1142

    
1143
  suffix = ''
1144

    
1145
  if units == 'm' or (units == 'h' and value < 1024):
1146
    if units == 'h':
1147
      suffix = 'M'
1148
    return "%d%s" % (round(value, 0), suffix)
1149

    
1150
  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
1151
    if units == 'h':
1152
      suffix = 'G'
1153
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)
1154

    
1155
  else:
1156
    if units == 'h':
1157
      suffix = 'T'
1158
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)
1159

    
1160

    
1161
def ParseUnit(input_string):
1162
  """Tries to extract number and scale from the given string.
1163

1164
  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
1165
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
1166
  is always an int in MiB.
1167

1168
  """
1169
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
1170
  if not m:
1171
    raise errors.UnitParseError("Invalid format")
1172

    
1173
  value = float(m.groups()[0])
1174

    
1175
  unit = m.groups()[1]
1176
  if unit:
1177
    lcunit = unit.lower()
1178
  else:
1179
    lcunit = 'm'
1180

    
1181
  if lcunit in ('m', 'mb', 'mib'):
1182
    # Value already in MiB
1183
    pass
1184

    
1185
  elif lcunit in ('g', 'gb', 'gib'):
1186
    value *= 1024
1187

    
1188
  elif lcunit in ('t', 'tb', 'tib'):
1189
    value *= 1024 * 1024
1190

    
1191
  else:
1192
    raise errors.UnitParseError("Unknown unit: %s" % unit)
1193

    
1194
  # Make sure we round up
1195
  if int(value) < value:
1196
    value += 1
1197

    
1198
  # Round up to the next multiple of 4
1199
  value = int(value)
1200
  if value % 4:
1201
    value += 4 - value % 4
1202

    
1203
  return value
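# A few worked conversions for FormatUnit/ParseUnit (illustrative comment
# only; the values follow directly from the rules above):
#
#   FormatUnit(1024, 'h') -> "1.0G"
#   ParseUnit("4G")       -> 4096   (MiB, rounded up to a multiple of 4)
#   ParseUnit("0.5g")     -> 512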
1204

    
1205

    
1206
def AddAuthorizedKey(file_name, key):
1207
  """Adds an SSH public key to an authorized_keys file.
1208

1209
  @type file_name: str
1210
  @param file_name: path to authorized_keys file
1211
  @type key: str
1212
  @param key: string containing key
1213

1214
  """
1215
  key_fields = key.split()
1216

    
1217
  f = open(file_name, 'a+')
1218
  try:
1219
    nl = True
1220
    for line in f:
1221
      # Ignore whitespace changes
1222
      if line.split() == key_fields:
1223
        break
1224
      nl = line.endswith('\n')
1225
    else:
1226
      if not nl:
1227
        f.write("\n")
1228
      f.write(key.rstrip('\r\n'))
1229
      f.write("\n")
1230
      f.flush()
1231
  finally:
1232
    f.close()
1233

    
1234

    
1235
def RemoveAuthorizedKey(file_name, key):
1236
  """Removes an SSH public key from an authorized_keys file.
1237

1238
  @type file_name: str
1239
  @param file_name: path to authorized_keys file
1240
  @type key: str
1241
  @param key: string containing key
1242

1243
  """
1244
  key_fields = key.split()
1245

    
1246
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1247
  try:
1248
    out = os.fdopen(fd, 'w')
1249
    try:
1250
      f = open(file_name, 'r')
1251
      try:
1252
        for line in f:
1253
          # Ignore whitespace changes while comparing lines
1254
          if line.split() != key_fields:
1255
            out.write(line)
1256

    
1257
        out.flush()
1258
        os.rename(tmpname, file_name)
1259
      finally:
1260
        f.close()
1261
    finally:
1262
      out.close()
1263
  except:
1264
    RemoveFile(tmpname)
1265
    raise
1266

    
1267

    
1268
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
1269
  """Sets the name of an IP address and hostname in /etc/hosts.
1270

1271
  @type file_name: str
1272
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1273
  @type ip: str
1274
  @param ip: the IP address
1275
  @type hostname: str
1276
  @param hostname: the hostname to be added
1277
  @type aliases: list
1278
  @param aliases: the list of aliases to add for the hostname
1279

1280
  """
1281
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1282
  # Ensure aliases are unique
1283
  aliases = UniqueSequence([hostname] + aliases)[1:]
1284

    
1285
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1286
  try:
1287
    out = os.fdopen(fd, 'w')
1288
    try:
1289
      f = open(file_name, 'r')
1290
      try:
1291
        for line in f:
1292
          fields = line.split()
1293
          if fields and not fields[0].startswith('#') and ip == fields[0]:
1294
            continue
1295
          out.write(line)
1296

    
1297
        out.write("%s\t%s" % (ip, hostname))
1298
        if aliases:
1299
          out.write(" %s" % ' '.join(aliases))
1300
        out.write('\n')
1301

    
1302
        out.flush()
1303
        os.fsync(out)
1304
        os.chmod(tmpname, 0644)
1305
        os.rename(tmpname, file_name)
1306
      finally:
1307
        f.close()
1308
    finally:
1309
      out.close()
1310
  except:
1311
    RemoveFile(tmpname)
1312
    raise
1313

    
1314

    
1315
def AddHostToEtcHosts(hostname):
1316
  """Wrapper around SetEtcHostsEntry.
1317

1318
  @type hostname: str
1319
  @param hostname: a hostname that will be resolved and added to
1320
      L{constants.ETC_HOSTS}
1321

1322
  """
1323
  hi = HostInfo(name=hostname)
1324
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])
1325

    
1326

    
1327
def RemoveEtcHostsEntry(file_name, hostname):
1328
  """Removes a hostname from /etc/hosts.
1329

1330
  IP addresses without names are removed from the file.
1331

1332
  @type file_name: str
1333
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1334
  @type hostname: str
1335
  @param hostname: the hostname to be removed
1336

1337
  """
1338
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1339
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1340
  try:
1341
    out = os.fdopen(fd, 'w')
1342
    try:
1343
      f = open(file_name, 'r')
1344
      try:
1345
        for line in f:
1346
          fields = line.split()
1347
          if len(fields) > 1 and not fields[0].startswith('#'):
1348
            names = fields[1:]
1349
            if hostname in names:
1350
              while hostname in names:
1351
                names.remove(hostname)
1352
              if names:
1353
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
1354
              continue
1355

    
1356
          out.write(line)
1357

    
1358
        out.flush()
1359
        os.fsync(out)
1360
        os.chmod(tmpname, 0644)
1361
        os.rename(tmpname, file_name)
1362
      finally:
1363
        f.close()
1364
    finally:
1365
      out.close()
1366
  except:
1367
    RemoveFile(tmpname)
1368
    raise
1369

    
1370

    
1371
def RemoveHostFromEtcHosts(hostname):
1372
  """Wrapper around RemoveEtcHostsEntry.
1373

1374
  @type hostname: str
1375
  @param hostname: hostname that will be resolved and its
1376
      full and short name will be removed from
1377
      L{constants.ETC_HOSTS}
1378

1379
  """
1380
  hi = HostInfo(name=hostname)
1381
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
1382
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())
1383

    
1384

    
1385
def TimestampForFilename():
1386
  """Returns the current time formatted for filenames.
1387

1388
  The format doesn't contain colons, as some shells and applications use them
1390
  as separators.
1390

1391
  """
1392
  return time.strftime("%Y-%m-%d_%H_%M_%S")
1393

    
1394

    
1395
def CreateBackup(file_name):
1396
  """Creates a backup of a file.
1397

1398
  @type file_name: str
1399
  @param file_name: file to be backed up
1400
  @rtype: str
1401
  @return: the path to the newly created backup
1402
  @raise errors.ProgrammerError: for invalid file names
1403

1404
  """
1405
  if not os.path.isfile(file_name):
1406
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
1407
                                file_name)
1408

    
1409
  prefix = ("%s.backup-%s." %
1410
            (os.path.basename(file_name), TimestampForFilename()))
1411
  dir_name = os.path.dirname(file_name)
1412

    
1413
  fsrc = open(file_name, 'rb')
1414
  try:
1415
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
1416
    fdst = os.fdopen(fd, 'wb')
1417
    try:
1418
      logging.debug("Backing up %s at %s", file_name, backup_name)
1419
      shutil.copyfileobj(fsrc, fdst)
1420
    finally:
1421
      fdst.close()
1422
  finally:
1423
    fsrc.close()
1424

    
1425
  return backup_name
1426

    
1427

    
1428
def ShellQuote(value):
1429
  """Quotes shell argument according to POSIX.
1430

1431
  @type value: str
1432
  @param value: the argument to be quoted
1433
  @rtype: str
1434
  @return: the quoted value
1435

1436
  """
1437
  if _re_shell_unquoted.match(value):
1438
    return value
1439
  else:
1440
    return "'%s'" % value.replace("'", "'\\''")
1441

    
1442

    
1443
def ShellQuoteArgs(args):
1444
  """Quotes a list of shell arguments.
1445

1446
  @type args: list
1447
  @param args: list of arguments to be quoted
1448
  @rtype: str
1449
  @return: the quoted arguments concatenated with spaces
1450

1451
  """
1452
  return ' '.join([ShellQuote(i) for i in args])
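# Worked example (illustrative comment only): arguments containing shell
# metacharacters or whitespace are wrapped in single quotes.
#
#   ShellQuoteArgs(["echo", "hello world"]) -> "echo 'hello world'"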
1453

    
1454

    
1455
def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
1456
  """Simple ping implementation using TCP connect(2).
1457

1458
  Check if the given IP is reachable by attempting a TCP connect
1459
  to it.
1460

1461
  @type target: str
1462
  @param target: the IP or hostname to ping
1463
  @type port: int
1464
  @param port: the port to connect to
1465
  @type timeout: int
1466
  @param timeout: the timeout on the connection attempt
1467
  @type live_port_needed: boolean
1468
  @param live_port_needed: whether a closed port will cause the
1469
      function to return failure, as if there was a timeout
1470
  @type source: str or None
1471
  @param source: if specified, will cause the connect to be made
1472
      from this specific source address; failures to bind other
1473
      than C{EADDRNOTAVAIL} will be ignored
1474

1475
  """
1476
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1477

    
1478
  success = False
1479

    
1480
  if source is not None:
1481
    try:
1482
      sock.bind((source, 0))
1483
    except socket.error, (errcode, _):
1484
      if errcode == errno.EADDRNOTAVAIL:
1485
        success = False
1486

    
1487
  sock.settimeout(timeout)
1488

    
1489
  try:
1490
    sock.connect((target, port))
1491
    sock.close()
1492
    success = True
1493
  except socket.timeout:
1494
    success = False
1495
  except socket.error, (errcode, _):
1496
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
1497

    
1498
  return success
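# Illustrative sketch (the address is an assumption): checking whether a
# node's daemon port answers within five seconds.
#
#   if not TcpPing("192.0.2.10", constants.DEFAULT_NODED_PORT,
#                  timeout=5, live_port_needed=True):
#     logging.warning("Node is not reachable")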
1499

    
1500

    
1501
def OwnIpAddress(address):
1502
  """Check if the current host has the the given IP address.
1503

1504
  Currently this is done by TCP-pinging the address from the loopback
1505
  address.
1506

1507
  @type address: string
1508
  @param address: the address to check
1509
  @rtype: bool
1510
  @return: True if we own the address
1511

1512
  """
1513
  return TcpPing(address, constants.DEFAULT_NODED_PORT,
1514
                 source=constants.LOCALHOST_IP_ADDRESS)
1515

    
1516

    
1517
def ListVisibleFiles(path):
1518
  """Returns a list of visible files in a directory.
1519

1520
  @type path: str
1521
  @param path: the directory to enumerate
1522
  @rtype: list
1523
  @return: the list of all files not starting with a dot
1524
  @raise ProgrammerError: if L{path} is not an absolute and normalized path
1525

1526
  """
1527
  if not IsNormAbsPath(path):
1528
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
1529
                                 " absolute/normalized: '%s'" % path)
1530
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1531
  files.sort()
1532
  return files
1533

    
1534

    
1535
def GetHomeDir(user, default=None):
1536
  """Try to get the homedir of the given user.
1537

1538
  The user can be passed either as a string (denoting the name) or as
1539
  an integer (denoting the user id). If the user is not found, the
1540
  'default' argument is returned, which defaults to None.
1541

1542
  """
1543
  try:
1544
    if isinstance(user, basestring):
1545
      result = pwd.getpwnam(user)
1546
    elif isinstance(user, (int, long)):
1547
      result = pwd.getpwuid(user)
1548
    else:
1549
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1550
                                   type(user))
1551
  except KeyError:
1552
    return default
1553
  return result.pw_dir
1554

    
1555

    
1556
def NewUUID():
1557
  """Returns a random UUID.
1558

1559
  @note: This is a Linux-specific method as it uses the /proc
1560
      filesystem.
1561
  @rtype: str
1562

1563
  """
1564
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1565

    
1566

    
1567
def GenerateSecret(numbytes=20):
1568
  """Generates a random secret.
1569

1570
  This will generate a pseudo-random secret returning a hex string
1571
  (so that it can be used where an ASCII string is needed).
1572

1573
  @param numbytes: the number of bytes which will be represented by the returned
1574
      string (defaulting to 20, the length of a SHA1 hash)
1575
  @rtype: str
1576
  @return: a hex representation of the pseudo-random sequence
1577

1578
  """
1579
  return os.urandom(numbytes).encode('hex')
1580

    
1581

    
1582
def EnsureDirs(dirs):
1583
  """Make required directories, if they don't exist.
1584

1585
  @param dirs: list of tuples (dir_name, dir_mode)
1586
  @type dirs: list of (string, integer)
1587

1588
  """
1589
  for dir_name, dir_mode in dirs:
1590
    try:
1591
      os.mkdir(dir_name, dir_mode)
1592
    except EnvironmentError, err:
1593
      if err.errno != errno.EEXIST:
1594
        raise errors.GenericError("Cannot create needed directory"
1595
                                  " '%s': %s" % (dir_name, err))
1596
    if not os.path.isdir(dir_name):
1597
      raise errors.GenericError("%s is not a directory" % dir_name)
1598

    
1599

    
1600
def ReadFile(file_name, size=-1):
1601
  """Reads a file.
1602

1603
  @type size: int
1604
  @param size: Read at most size bytes (if negative, entire file)
1605
  @rtype: str
1606
  @return: the (possibly partial) content of the file
1607

1608
  """
1609
  f = open(file_name, "r")
1610
  try:
1611
    return f.read(size)
1612
  finally:
1613
    f.close()
1614

    
1615

    
1616
def WriteFile(file_name, fn=None, data=None,
1617
              mode=None, uid=-1, gid=-1,
1618
              atime=None, mtime=None, close=True,
1619
              dry_run=False, backup=False,
1620
              prewrite=None, postwrite=None):
1621
  """(Over)write a file atomically.
1622

1623
  The file_name and either fn (a function taking one argument, the
1624
  file descriptor, and which should write the data to it) or data (the
1625
  contents of the file) must be passed. The other arguments are
1626
  optional and allow setting the file mode, owner and group, and the
1627
  mtime/atime of the file.
1628

1629
  If the function doesn't raise an exception, it has succeeded and the
1630
  target file has the new contents. If the function has raised an
1631
  exception, an existing target file should be unmodified and the
1632
  temporary file should be removed.
1633

1634
  @type file_name: str
1635
  @param file_name: the target filename
1636
  @type fn: callable
1637
  @param fn: content writing function, called with
1638
      file descriptor as parameter
1639
  @type data: str
1640
  @param data: contents of the file
1641
  @type mode: int
1642
  @param mode: file mode
1643
  @type uid: int
1644
  @param uid: the owner of the file
1645
  @type gid: int
1646
  @param gid: the group of the file
1647
  @type atime: int
1648
  @param atime: a custom access time to be set on the file
1649
  @type mtime: int
1650
  @param mtime: a custom modification time to be set on the file
1651
  @type close: boolean
1652
  @param close: whether to close file after writing it
1653
  @type prewrite: callable
1654
  @param prewrite: function to be called before writing content
1655
  @type postwrite: callable
1656
  @param postwrite: function to be called after writing content
1657

1658
  @rtype: None or int
1659
  @return: None if the 'close' parameter evaluates to True,
1660
      otherwise the file descriptor
1661

1662
  @raise errors.ProgrammerError: if any of the arguments are not valid
1663

1664
  """
1665
  if not os.path.isabs(file_name):
1666
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1667
                                 " absolute: '%s'" % file_name)
1668

    
1669
  if [fn, data].count(None) != 1:
1670
    raise errors.ProgrammerError("fn or data required")
1671

    
1672
  if [atime, mtime].count(None) == 1:
1673
    raise errors.ProgrammerError("Both atime and mtime must be either"
1674
                                 " set or None")
1675

    
1676
  if backup and not dry_run and os.path.isfile(file_name):
1677
    CreateBackup(file_name)
1678

    
1679
  dir_name, base_name = os.path.split(file_name)
1680
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1681
  do_remove = True
1682
  # here we need to make sure we remove the temp file, if any error
1683
  # leaves it in place
1684
  try:
1685
    if uid != -1 or gid != -1:
1686
      os.chown(new_name, uid, gid)
1687
    if mode:
1688
      os.chmod(new_name, mode)
1689
    if callable(prewrite):
1690
      prewrite(fd)
1691
    if data is not None:
1692
      os.write(fd, data)
1693
    else:
1694
      fn(fd)
1695
    if callable(postwrite):
1696
      postwrite(fd)
1697
    os.fsync(fd)
1698
    if atime is not None and mtime is not None:
1699
      os.utime(new_name, (atime, mtime))
1700
    if not dry_run:
1701
      os.rename(new_name, file_name)
1702
      do_remove = False
1703
  finally:
1704
    if close:
1705
      os.close(fd)
1706
      result = None
1707
    else:
1708
      result = fd
1709
    if do_remove:
1710
      RemoveFile(new_name)
1711

    
1712
  return result
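# Illustrative sketch (path, mode and contents are assumptions): atomically
# replacing a configuration file while keeping a timestamped backup.
#
#   WriteFile("/etc/example.conf", data="key = value\n", mode=0644,
#             backup=True)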
1713

    
1714

    
1715
def FirstFree(seq, base=0):
1716
  """Returns the first non-existing integer from seq.
1717

1718
  The seq argument should be a sorted list of positive integers. The
1719
  first time the index of an element is smaller than the element
1720
  value, the index will be returned.
1721

1722
  The base argument is used to start at a different offset,
1723
  i.e. C{[3, 4, 6]} with I{offset=3} will return 5.
1724

1725
  Example: C{[0, 1, 3]} will return I{2}.
1726

1727
  @type seq: sequence
1728
  @param seq: the sequence to be analyzed.
1729
  @type base: int
1730
  @param base: use this value as the base index of the sequence
1731
  @rtype: int
1732
  @return: the first non-used index in the sequence
1733

1734
  """
1735
  for idx, elem in enumerate(seq):
1736
    assert elem >= base, "Passed element is higher than base offset"
1737
    if elem > idx + base:
1738
      # idx is not used
1739
      return idx + base
1740
  return None
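# Examples from the docstring above (illustrative comment only):
#
#   FirstFree([0, 1, 3])         -> 2
#   FirstFree([3, 4, 6], base=3) -> 5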


def all(seq, pred=bool): # pylint: disable-msg=W0622
  "Returns True if pred(x) is True for every element in the iterable"
  for _ in itertools.ifilterfalse(pred, seq):
    return False
  return True


def any(seq, pred=bool): # pylint: disable-msg=W0622
  "Returns True if pred(x) is True for at least one element in the iterable"
  for _ in itertools.ifilter(pred, seq):
    return True
  return False


def partition(seq, pred=bool): # pylint: disable-msg=W0622
  "Partition a list in two, based on the given predicate"
  return (list(itertools.ifilter(pred, seq)),
          list(itertools.ifilterfalse(pred, seq)))


def UniqueSequence(seq):
  """Returns a list with unique elements.

  Element order is preserved.

  @type seq: sequence
  @param seq: the sequence with the source elements
  @rtype: list
  @return: list of unique elements from seq

  """
  seen = set()
  return [i for i in seq if i not in seen and not seen.add(i)]
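
# Illustrative example (not part of the module):
#
#   UniqueSequence([1, 2, 1, 3, 2])  # -> [1, 2, 3], first occurrences kept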


def NormalizeAndValidateMac(mac):
  """Normalizes and checks if a MAC address is valid.

  Checks whether the supplied MAC address is formally correct; only
  the colon-separated format is accepted. The address is normalized
  to all-lowercase.

  @type mac: str
  @param mac: the MAC to be validated
  @rtype: str
  @return: returns the normalized and validated MAC.

  @raise errors.OpPrereqError: If the MAC isn't valid

  """
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
  if not mac_check.match(mac):
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
                               mac, errors.ECODE_INVAL)

  return mac.lower()
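
# Illustrative examples (not part of the module):
#
#   NormalizeAndValidateMac("AA:00:56:C8:12:3F")  # -> "aa:00:56:c8:12:3f"
#   NormalizeAndValidateMac("aa-00-56-c8-12-3f")  # raises errors.OpPrereqError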


def TestDelay(duration):
  """Sleep for a fixed amount of time.

  @type duration: float
  @param duration: the sleep duration
  @rtype: tuple
  @return: (False, error message) for a negative duration, (True, None)
      otherwise

  """
  if duration < 0:
    return False, "Invalid sleep duration"
  time.sleep(duration)
  return True, None


def _CloseFDNoErr(fd, retries=5):
  """Close a file descriptor ignoring errors.

  @type fd: int
  @param fd: the file descriptor
  @type retries: int
  @param retries: how many retries to make, in case we get any
      other error than EBADF

  """
  try:
    os.close(fd)
  except OSError, err:
    if err.errno != errno.EBADF:
      if retries > 0:
        _CloseFDNoErr(fd, retries - 1)
    # else either it's closed already or we're out of retries, so we
    # ignore this and go on


def CloseFDs(noclose_fds=None):
  """Close file descriptors.

  This closes all file descriptors above 2 (i.e. except
  stdin/out/err).

  @type noclose_fds: list or None
  @param noclose_fds: if given, it denotes a list of file descriptors
      that should not be closed

  """
  # Default maximum for the number of available file descriptors.
  if 'SC_OPEN_MAX' in os.sysconf_names:
    try:
      MAXFD = os.sysconf('SC_OPEN_MAX')
      if MAXFD < 0:
        MAXFD = 1024
    except OSError:
      MAXFD = 1024
  else:
    MAXFD = 1024
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
  if (maxfd == resource.RLIM_INFINITY):
    maxfd = MAXFD

  # Iterate through and close all file descriptors (except the standard ones)
  for fd in range(3, maxfd):
    if noclose_fds and fd in noclose_fds:
      continue
    _CloseFDNoErr(fd)


def Daemonize(logfile):
  """Daemonize the current process.

  This detaches the current process from the controlling terminal and
  runs it in the background as a daemon.

  @type logfile: str
  @param logfile: the logfile to which we should redirect stdout/stderr
  @rtype: int
  @return: the value zero

  """
  # pylint: disable-msg=W0212
  # yes, we really want os._exit
  UMASK = 077
  WORKDIR = "/"

  # this might fail
  pid = os.fork()
  if (pid == 0):  # The first child.
    os.setsid()
    # this might fail
    pid = os.fork() # Fork a second child.
    if (pid == 0):  # The second child.
      os.chdir(WORKDIR)
      os.umask(UMASK)
    else:
      # exit() or _exit()?  See below.
      os._exit(0) # Exit parent (the first child) of the second child.
  else:
    os._exit(0) # Exit parent of the first child.

  for fd in range(3):
    _CloseFDNoErr(fd)
  i = os.open("/dev/null", os.O_RDONLY) # stdin
  assert i == 0, "Can't close/reopen stdin"
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
  assert i == 1, "Can't close/reopen stdout"
  # Duplicate standard output to standard error.
  os.dup2(1, 2)
  return 0


def DaemonPidFileName(name):
  """Compute a ganeti pid file absolute path.

  @type name: str
  @param name: the daemon name
  @rtype: str
  @return: the full path to the pidfile corresponding to the given
      daemon name

  """
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)


def EnsureDaemon(name):
  """Check for and start daemon if not alive.

  """
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
  if result.failed:
    logging.error("Can't start daemon '%s', failure %s, output: %s",
                  name, result.fail_reason, result.output)
    return False

  return True


def WritePidFile(name):
  """Write the current process pidfile.

  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}

  @type name: str
  @param name: the daemon name to use
  @raise errors.GenericError: if the pid file already exists and
      points to a live process

  """
  pid = os.getpid()
  pidfilename = DaemonPidFileName(name)
  if IsProcessAlive(ReadPidFile(pidfilename)):
    raise errors.GenericError("%s contains a live process" % pidfilename)

  WriteFile(pidfilename, data="%d\n" % pid)


def RemovePidFile(name):
  """Remove the current process pidfile.

  Any errors are ignored.

  @type name: str
  @param name: the daemon name used to derive the pidfile name

  """
  pidfilename = DaemonPidFileName(name)
  # TODO: we could check here that the file contains our pid
  try:
    RemoveFile(pidfilename)
  except: # pylint: disable-msg=W0702
    pass


def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
                waitpid=False):
  """Kill a process given by its pid.

  @type pid: int
  @param pid: The PID to terminate.
  @type signal_: int
  @param signal_: The signal to send, by default SIGTERM
  @type timeout: int
  @param timeout: The timeout after which, if the process is still alive,
                  a SIGKILL will be sent. If not positive, no such checking
                  will be done
  @type waitpid: boolean
  @param waitpid: If true, we should waitpid on this process after
      sending signals, since it's our own child and otherwise it
      would remain as zombie

  """
  def _helper(pid, signal_, wait):
    """Simple helper to encapsulate the kill/waitpid sequence"""
    os.kill(pid, signal_)
    if wait:
      try:
        os.waitpid(pid, os.WNOHANG)
      except OSError:
        pass

  if pid <= 0:
    # kill with pid=0 == suicide
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)

  if not IsProcessAlive(pid):
    return

  _helper(pid, signal_, waitpid)

  if timeout <= 0:
    return

  def _CheckProcess():
    if not IsProcessAlive(pid):
      return

    try:
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
    except OSError:
      raise RetryAgain()

    if result_pid > 0:
      return

    raise RetryAgain()

  try:
    # Wait up to $timeout seconds
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
  except RetryTimeout:
    pass

  if IsProcessAlive(pid):
    # Kill process if it's still alive
    _helper(pid, signal.SIGKILL, waitpid)
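
# Illustrative usage (not part of the module; the pid is hypothetical):
#
#   KillProcess(12345)             # SIGTERM, escalating to SIGKILL on timeout
#   KillProcess(12345, timeout=0)  # send SIGTERM only, don't wait or escalate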


def FindFile(name, search_path, test=os.path.exists):
  """Look for a filesystem object in a given path.

  This is an abstract method to search for a filesystem object (files,
  dirs) under a given search path.

  @type name: str
  @param name: the name to look for
  @type search_path: list of str
  @param search_path: the directories in which to search
  @type test: callable
  @param test: a function taking one argument that should return True
      if a given object is valid; the default value is
      os.path.exists, causing only existing files to be returned
  @rtype: str or None
  @return: full path to the object if found, None otherwise

  """
  # validate the filename mask
  if constants.EXT_PLUGIN_MASK.match(name) is None:
    logging.critical("Invalid value passed for external script name: '%s'",
                     name)
    return None

  for dir_name in search_path:
    # FIXME: investigate switch to PathJoin
    item_name = os.path.sep.join([dir_name, name])
    # check the user test and that we're indeed resolving to the given
    # basename
    if test(item_name) and os.path.basename(item_name) == name:
      return item_name
  return None


def CheckVolumeGroupSize(vglist, vgname, minsize):
  """Checks if the volume group list is valid.

  The function will check if a given volume group is in the list of
  volume groups and has a minimum size.

  @type vglist: dict
  @param vglist: dictionary of volume group names and their size
  @type vgname: str
  @param vgname: the volume group we should check
  @type minsize: int
  @param minsize: the minimum size we accept
  @rtype: None or str
  @return: None for success, otherwise the error message

  """
  vgsize = vglist.get(vgname, None)
  if vgsize is None:
    return "volume group '%s' missing" % vgname
  elif vgsize < minsize:
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
            (vgname, minsize, vgsize))
  return None


def SplitTime(value):
  """Splits time as floating point number into a tuple.

  @param value: Time in seconds
  @type value: int or float
  @return: Tuple containing (seconds, microseconds)

  """
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)

  assert 0 <= seconds, \
    "Seconds must be larger than or equal to 0, but are %s" % seconds
  assert 0 <= microseconds <= 999999, \
    "Microseconds must be 0-999999, but are %s" % microseconds

  return (int(seconds), int(microseconds))


def MergeTime(timetuple):
  """Merges a tuple into time as a floating point number.

  @param timetuple: Time as tuple, (seconds, microseconds)
  @type timetuple: tuple
  @return: Time as a floating point number expressed in seconds

  """
  (seconds, microseconds) = timetuple

  assert 0 <= seconds, \
    "Seconds must be larger than or equal to 0, but are %s" % seconds
  assert 0 <= microseconds <= 999999, \
    "Microseconds must be 0-999999, but are %s" % microseconds

  return float(seconds) + (float(microseconds) * 0.000001)
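
# Illustrative round-trip (not part of the module):
#
#   SplitTime(1.5)           # -> (1, 500000)
#   MergeTime((1, 500000))   # -> 1.5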


def GetDaemonPort(daemon_name):
  """Get the daemon port for this cluster.

  Note that this routine does not read a ganeti-specific file, but
  instead uses C{socket.getservbyname} to allow pre-customization of
  this parameter outside of Ganeti.

  @type daemon_name: string
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
  @rtype: int

  """
  if daemon_name not in constants.DAEMONS_PORTS:
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)

  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
  try:
    port = socket.getservbyname(daemon_name, proto)
  except socket.error:
    port = default_port

  return port


def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
                 multithreaded=False, syslog=constants.SYSLOG_USAGE):
  """Configures the logging module.

  @type logfile: str
  @param logfile: the filename to which we should log
  @type debug: integer
  @param debug: if greater than zero, enable debug messages, otherwise
      only those at C{INFO} and above level
  @type stderr_logging: boolean
  @param stderr_logging: whether we should also log to the standard error
  @type program: str
  @param program: the name under which we should log messages
  @type multithreaded: boolean
  @param multithreaded: if True, will add the thread name to the log file
  @type syslog: string
  @param syslog: one of 'no', 'yes', 'only':
      - if no, syslog is not used
      - if yes, syslog is used (in addition to file-logging)
      - if only, only syslog is used
  @raise EnvironmentError: if we can't open the log file and
      syslog/stderr logging is disabled

  """
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
  sft = program + "[%(process)d]:"
  if multithreaded:
    fmt += "/%(threadName)s"
    sft += " (%(threadName)s)"
  if debug:
    fmt += " %(module)s:%(lineno)s"
    # no debug info for syslog loggers
  fmt += " %(levelname)s %(message)s"
  # yes, we do want the textual level, as remote syslog will probably
  # lose the error level, and it's easier to grep for it
  sft += " %(levelname)s %(message)s"
  formatter = logging.Formatter(fmt)
  sys_fmt = logging.Formatter(sft)

  root_logger = logging.getLogger("")
  root_logger.setLevel(logging.NOTSET)

  # Remove all previously setup handlers
  for handler in root_logger.handlers:
    handler.close()
    root_logger.removeHandler(handler)

  if stderr_logging:
    stderr_handler = logging.StreamHandler()
    stderr_handler.setFormatter(formatter)
    if debug:
      stderr_handler.setLevel(logging.NOTSET)
    else:
      stderr_handler.setLevel(logging.CRITICAL)
    root_logger.addHandler(stderr_handler)

  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
                                                    facility)
    syslog_handler.setFormatter(sys_fmt)
    # Never enable debug over syslog
    syslog_handler.setLevel(logging.INFO)
    root_logger.addHandler(syslog_handler)

  if syslog != constants.SYSLOG_ONLY:
    # this can fail, if the logging directories are not set up or we have
    # a permission problem; in this case, it's best to log but ignore
    # the error if stderr_logging is True, and if False we re-raise the
    # exception since otherwise we could run but without any logs at all
    try:
      logfile_handler = logging.FileHandler(logfile)
      logfile_handler.setFormatter(formatter)
      if debug:
        logfile_handler.setLevel(logging.DEBUG)
      else:
        logfile_handler.setLevel(logging.INFO)
      root_logger.addHandler(logfile_handler)
    except EnvironmentError:
      if stderr_logging or syslog == constants.SYSLOG_YES:
        logging.exception("Failed to enable logging to file '%s'", logfile)
      else:
        # we need to re-raise the exception
        raise


def IsNormAbsPath(path):
  """Check whether a path is absolute and also normalized.

  This avoids things like /dir/../../other/path from being valid.

  """
  return os.path.normpath(path) == path and os.path.isabs(path)


def PathJoin(*args):
  """Safe-join a list of path components.

  Requirements:
      - the first argument must be an absolute path
      - no component in the path must have backtracking (e.g. /../),
        since we check for normalization at the end

  @param args: the path components to be joined
  @raise ValueError: for invalid paths

  """
  # ensure we're having at least one path passed in
  assert args
  # ensure the first component is an absolute and normalized path name
  root = args[0]
  if not IsNormAbsPath(root):
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
  result = os.path.join(*args)
  # ensure that the whole path is normalized
  if not IsNormAbsPath(result):
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
  # check that we're still under the original prefix
  prefix = os.path.commonprefix([root, result])
  if prefix != root:
    raise ValueError("Error: path joining resulted in different prefix"
                     " (%s != %s)" % (prefix, root))
  return result
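
# Illustrative examples (not part of the module):
#
#   PathJoin("/var/lib/ganeti", "config.data")    # -> "/var/lib/ganeti/config.data"
#   PathJoin("/var/lib/ganeti", "../etc/passwd")  # raises ValueError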


def TailFile(fname, lines=20):
  """Return the last lines from a file.

  @note: this function will only read and parse the last 4KB of
      the file; if the lines are very long, it could be that less
      than the requested number of lines are returned

  @param fname: the file name
  @type lines: int
  @param lines: the (maximum) number of lines to return

  """
  fd = open(fname, "r")
  try:
    fd.seek(0, 2)
    pos = fd.tell()
    pos = max(0, pos-4096)
    fd.seek(pos, 0)
    raw_data = fd.read()
  finally:
    fd.close()

  rows = raw_data.splitlines()
  return rows[-lines:]
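
# Illustrative usage (not part of the module; the log path is hypothetical):
#
#   for line in TailFile("/var/log/ganeti/node-daemon.log", lines=10):
#     print line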


def _ParseAsn1Generalizedtime(value):
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.

  @type value: string
  @param value: ASN1 GENERALIZEDTIME timestamp

  """
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
  if m:
    # We have an offset
    asn1time = m.group(1)
    hours = int(m.group(2))
    minutes = int(m.group(3))
    utcoffset = (60 * hours) + minutes
  else:
    if not value.endswith("Z"):
      raise ValueError("Missing timezone")
    asn1time = value[:-1]
    utcoffset = 0

  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")

  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)

  return calendar.timegm(tt.utctimetuple())


def GetX509CertValidity(cert):
  """Returns the validity period of the certificate.

  @type cert: OpenSSL.crypto.X509
  @param cert: X509 certificate object

  """
  # The get_notBefore and get_notAfter functions are only supported in
  # pyOpenSSL 0.7 and above.
  try:
    get_notbefore_fn = cert.get_notBefore
  except AttributeError:
    not_before = None
  else:
    not_before_asn1 = get_notbefore_fn()

    if not_before_asn1 is None:
      not_before = None
    else:
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)

  try:
    get_notafter_fn = cert.get_notAfter
  except AttributeError:
    not_after = None
  else:
    not_after_asn1 = get_notafter_fn()

    if not_after_asn1 is None:
      not_after = None
    else:
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)

  return (not_before, not_after)


def SignX509Certificate(cert, key, salt):
  """Sign an X509 certificate.

  An RFC822-like signature header is added in front of the certificate.

  @type cert: OpenSSL.crypto.X509
  @param cert: X509 certificate object
  @type key: string
  @param key: Key for HMAC
  @type salt: string
  @param salt: Salt for HMAC
  @rtype: string
  @return: Serialized and signed certificate in PEM format

  """
  if not VALID_X509_SIGNATURE_SALT.match(salt):
    raise errors.GenericError("Invalid salt: %r" % salt)

  # Dumping as PEM here ensures the certificate is in a sane format
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  return ("%s: %s/%s\n\n%s" %
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
           hmac.new(key, salt + cert_pem, sha1).hexdigest(),
           cert_pem))


def _ExtractX509CertificateSignature(cert_pem):
  """Helper function to extract signature from X509 certificate.

  """
  # Extract signature from original PEM data
  for line in cert_pem.splitlines():
    if line.startswith("---"):
      break

    m = X509_SIGNATURE.match(line.strip())
    if m:
      return (m.group("salt"), m.group("sign"))

  raise errors.GenericError("X509 certificate signature is missing")


def LoadSignedX509Certificate(cert_pem, key):
  """Verifies a signed X509 certificate.

  @type cert_pem: string
  @param cert_pem: Certificate in PEM format and with signature header
  @type key: string
  @param key: Key for HMAC
  @rtype: tuple; (OpenSSL.crypto.X509, string)
  @return: X509 certificate object and salt

  """
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)

  # Load certificate
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)

  # Dump again to ensure it's in a sane format
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  if signature != hmac.new(key, salt + sane_pem, sha1).hexdigest():
    raise errors.GenericError("X509 certificate signature is invalid")

  return (cert, salt)
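
# Illustrative round-trip (not part of the module; the key and salt values
# are made up, and cert_pem is assumed to hold a PEM-encoded certificate):
#
#   cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
#                                          cert_pem)
#   signed_pem = SignX509Certificate(cert, "hmac-key", "1234abcd")
#   (cert2, salt) = LoadSignedX509Certificate(signed_pem, "hmac-key")
#   # salt == "1234abcd"; a wrong key raises errors.GenericError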


def SafeEncode(text):
  """Return a 'safe' version of a source string.

  This function mangles the input string and returns a version that
  should be safe to display/encode as ASCII. To this end, we first
  convert it to ASCII using the 'backslashreplace' encoding which
  should get rid of any non-ASCII chars, and then we process it
  through a loop copied from the string repr sources in Python; we
  don't use string_escape anymore since that escapes single quotes and
  backslashes too, and that is too much; and that escaping is not
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).

  @type text: str or unicode
  @param text: input data
  @rtype: str
  @return: a safe version of text

  """
  if isinstance(text, unicode):
    # only if unicode; if str already, we handle it below
    text = text.encode('ascii', 'backslashreplace')
  resu = ""
  for char in text:
    c = ord(char)
    if char == '\t':
      resu += r'\t'
    elif char == '\n':
      resu += r'\n'
    elif char == '\r':
      resu += r'\r'
    elif c < 32 or c >= 127: # non-printable
      resu += "\\x%02x" % (c & 0xff)
    else:
      resu += char
  return resu
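
# Illustrative examples (not part of the module):
#
#   SafeEncode("ok\tline\n")   # -> "ok\\tline\\n"
#   SafeEncode(u"caf\xe9")     # -> "caf\\xe9"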


def UnescapeAndSplit(text, sep=","):
  """Split and unescape a string based on a given separator.

  This function splits a string based on a separator where the
  separator itself can be escaped so that it becomes part of an
  element. The escaping rules are (assuming comma is the
  separator):
    - a plain , separates the elements
    - a sequence \\\\, (double backslash plus comma) is handled as a
      backslash plus a separator comma
    - a sequence \, (backslash plus comma) is handled as a
      non-separator comma

  @type text: string
  @param text: the string to split
  @type sep: string
  @param sep: the separator
  @rtype: list
  @return: a list of strings

  """
  # we split the list by sep (with no escaping at this stage)
  slist = text.split(sep)
  # next, we revisit the elements and if any of them ended with an odd
  # number of backslashes, then we join it with the next
  rlist = []
  while slist:
    e1 = slist.pop(0)
    if e1.endswith("\\"):
      num_b = len(e1) - len(e1.rstrip("\\"))
      if num_b % 2 == 1:
        e2 = slist.pop(0)
        # here the backslashes remain (all), and will be reduced in
        # the next step
        rlist.append(e1 + sep + e2)
        continue
    rlist.append(e1)
  # finally, replace backslash-something with something
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
  return rlist
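
# Illustrative examples (not part of the module):
#
#   UnescapeAndSplit("a,b")      # -> ["a", "b"]
#   UnescapeAndSplit(r"a\,b,c")  # -> ["a,b", "c"]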


def CommaJoin(names):
  """Nicely join a set of identifiers.

  @param names: set, list or tuple
  @return: a string with the formatted results

  """
  return ", ".join([str(val) for val in names])


def BytesToMebibyte(value):
  """Converts bytes to mebibytes.

  @type value: int
  @param value: Value in bytes
  @rtype: int
  @return: Value in mebibytes

  """
  return int(round(value / (1024.0 * 1024.0), 0))


def CalculateDirectorySize(path):
  """Calculates the size of a directory recursively.

  @type path: string
  @param path: Path to directory
  @rtype: int
  @return: Size in mebibytes

  """
  size = 0

  for (curpath, _, files) in os.walk(path):
    for filename in files:
      st = os.lstat(PathJoin(curpath, filename))
      size += st.st_size

  return BytesToMebibyte(size)


def GetFilesystemStats(path):
  """Returns the total and free space on a filesystem.

  @type path: string
  @param path: Path on filesystem to be examined
  @rtype: tuple
  @return: tuple of (Total space, Free space) in mebibytes

  """
  st = os.statvfs(path)

  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
  return (tsize, fsize)


def RunInSeparateProcess(fn, *args):
  """Runs a function in a separate process.

  Note: Only boolean return values are supported.

  @type fn: callable
  @param fn: Function to be called
  @rtype: bool
  @return: Function's result

  """
  pid = os.fork()
  if pid == 0:
    # Child process
    try:
      # In case the function uses temporary files
      ResetTempfileModule()

      # Call function
      result = int(bool(fn(*args)))
      assert result in (0, 1)
    except: # pylint: disable-msg=W0702
      logging.exception("Error while calling function in separate process")
      # 0 and 1 are reserved for the return value
      result = 33

    os._exit(result) # pylint: disable-msg=W0212

  # Parent process

  # Avoid zombies and check exit code
  (_, status) = os.waitpid(pid, 0)

  if os.WIFSIGNALED(status):
    exitcode = None
    signum = os.WTERMSIG(status)
  else:
    exitcode = os.WEXITSTATUS(status)
    signum = None

  if not (exitcode in (0, 1) and signum is None):
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
                              (exitcode, signum))

  return bool(exitcode)
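
# Illustrative usage (not part of the module; the check below is hypothetical):
#
#   def _CheckSomething():
#     # a check that must not disturb the parent's state
#     return os.path.exists("/var/lib/ganeti/config.data")
#
#   ok = RunInSeparateProcess(_CheckSomething)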


def LockedMethod(fn):
  """Synchronized object access decorator.

  This decorator is intended to protect access to an object using the
  object's own lock which is hardcoded to '_lock'.

  """
  def _LockDebug(*args, **kwargs):
    if debug_locks:
      logging.debug(*args, **kwargs)

  def wrapper(self, *args, **kwargs):
    # pylint: disable-msg=W0212
    assert hasattr(self, '_lock')
    lock = self._lock
    _LockDebug("Waiting for %s", lock)
    lock.acquire()
    try:
      _LockDebug("Acquired %s", lock)
      result = fn(self, *args, **kwargs)
    finally:
      _LockDebug("Releasing %s", lock)
      lock.release()
      _LockDebug("Released %s", lock)
    return result
  return wrapper


def LockFile(fd):
  """Locks a file using POSIX locks.

  @type fd: int
  @param fd: the file descriptor we need to lock

  """
  try:
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
  except IOError, err:
    if err.errno == errno.EAGAIN:
      raise errors.LockError("File already locked")
    raise


def FormatTime(val):
  """Formats a time value.

  @type val: float or None
  @param val: the timestamp as returned by time.time()
  @return: a string value or N/A if we don't have a valid timestamp

  """
  if val is None or not isinstance(val, (int, float)):
    return "N/A"
  # these two format codes work on Linux, but they are not guaranteed on all
  # platforms
  return time.strftime("%F %T", time.localtime(val))


def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
  """Reads the watcher pause file.

  @type filename: string
  @param filename: Path to watcher pause file
  @type now: None, float or int
  @param now: Current time as Unix timestamp
  @type remove_after: int
  @param remove_after: Remove watcher pause file after specified amount of
    seconds past the pause end time

  """
  if now is None:
    now = time.time()

  try:
    value = ReadFile(filename)
  except IOError, err:
    if err.errno != errno.ENOENT:
      raise
    value = None

  if value is not None:
    try:
      value = int(value)
    except ValueError:
      logging.warning(("Watcher pause file (%s) contains invalid value,"
                       " removing it"), filename)
      RemoveFile(filename)
      value = None

    if value is not None:
      # Remove file if it's outdated
      if now > (value + remove_after):
        RemoveFile(filename)
        value = None

      elif now > value:
        value = None

  return value


class RetryTimeout(Exception):
  """Retry loop timed out.

  """


class RetryAgain(Exception):
  """Retry again.

  """


class _RetryDelayCalculator(object):
  """Calculator for increasing delays.

  """
  __slots__ = [
    "_factor",
    "_limit",
    "_next",
    "_start",
    ]

  def __init__(self, start, factor, limit):
    """Initializes this class.

    @type start: float
    @param start: Initial delay
    @type factor: float
    @param factor: Factor for delay increase
    @type limit: float or None
    @param limit: Upper limit for delay or None for no limit

    """
    assert start > 0.0
    assert factor >= 1.0
    assert limit is None or limit >= 0.0

    self._start = start
    self._factor = factor
    self._limit = limit

    self._next = start

  def __call__(self):
    """Returns current delay and calculates the next one.

    """
    current = self._next

    # Update for next run
    if self._limit is None or self._next < self._limit:
      self._next = self._next * self._factor
      # Cap at the limit, if one was given
      if self._limit is not None:
        self._next = min(self._limit, self._next)

    return current


#: Special delay to specify whole remaining timeout
RETRY_REMAINING_TIME = object()


def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
          _time_fn=time.time):
  """Call a function repeatedly until it succeeds.

  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
  total of C{timeout} seconds, this function throws L{RetryTimeout}.

  C{delay} can be one of the following:
    - callable returning the delay length as a float
    - Tuple of (start, factor, limit)
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
      useful when overriding L{wait_fn} to wait for an external event)
    - A static delay as a number (int or float)

  @type fn: callable
  @param fn: Function to be called
  @param delay: Either a callable (returning the delay), a tuple of (start,
                factor, limit) (see L{_RetryDelayCalculator}),
                L{RETRY_REMAINING_TIME} or a number (int or float)
  @type timeout: float
  @param timeout: Total timeout
  @type wait_fn: callable
  @param wait_fn: Waiting function
  @return: Return value of function

  """
  assert callable(fn)
  assert callable(wait_fn)
  assert callable(_time_fn)

  if args is None:
    args = []

  end_time = _time_fn() + timeout

  if callable(delay):
    # External function to calculate delay
    calc_delay = delay

  elif isinstance(delay, (tuple, list)):
    # Increasing delay with optional upper boundary
    (start, factor, limit) = delay
    calc_delay = _RetryDelayCalculator(start, factor, limit)

  elif delay is RETRY_REMAINING_TIME:
    # Always use the remaining time
    calc_delay = None

  else:
    # Static delay
    calc_delay = lambda: delay

  assert calc_delay is None or callable(calc_delay)

  while True:
    try:
      # pylint: disable-msg=W0142
      return fn(*args)
    except RetryAgain:
      pass

    remaining_time = end_time - _time_fn()

    if remaining_time < 0.0:
      raise RetryTimeout()

    assert remaining_time >= 0.0

    if calc_delay is None:
      wait_fn(remaining_time)
    else:
      current_delay = calc_delay()
      if current_delay > 0.0:
        wait_fn(current_delay)
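
# Illustrative usage (not part of the module; _CheckPidGone and pid are
# hypothetical):
#
#   def _CheckPidGone():
#     if IsProcessAlive(pid):
#       raise RetryAgain()
#
#   try:
#     Retry(_CheckPidGone, (0.1, 1.5, 1.0), timeout=30.0)
#   except RetryTimeout:
#     logging.warning("Process did not go away in time")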


def GetClosedTempfile(*args, **kwargs):
  """Creates a temporary file and returns its path.

  """
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
  _CloseFDNoErr(fd)
  return path


def GenerateSelfSignedX509Cert(common_name, validity):
  """Generates a self-signed X509 certificate.

  @type common_name: string
  @param common_name: commonName value
  @type validity: int
  @param validity: Validity for certificate in seconds

  """
  # Create private and public key
  key = OpenSSL.crypto.PKey()
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)

  # Create self-signed certificate
  cert = OpenSSL.crypto.X509()
  if common_name:
    cert.get_subject().CN = common_name
  cert.set_serial_number(1)
  cert.gmtime_adj_notBefore(0)
  cert.gmtime_adj_notAfter(validity)
  cert.set_issuer(cert.get_subject())
  cert.set_pubkey(key)
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)

  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  return (key_pem, cert_pem)


def GenerateSelfSignedSslCert(filename, validity=(5 * 365)):
  """Legacy function to generate self-signed X509 certificate.

  """
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(None,
                                                   validity * 24 * 60 * 60)

  WriteFile(filename, mode=0400, data=key_pem + cert_pem)


class FileLock(object):
  """Utility class for file locks.

  """
  def __init__(self, fd, filename):
    """Constructor for FileLock.

    @type fd: file
    @param fd: File object
    @type filename: str
    @param filename: Path of the file opened at I{fd}

    """
    self.fd = fd
    self.filename = filename

  @classmethod
  def Open(cls, filename):
    """Creates and opens a file to be used as a file-based lock.

    @type filename: string
    @param filename: path to the file to be locked

    """
    # Using "os.open" is necessary to allow both opening existing file
    # read/write and creating if not existing. Vanilla "open" will truncate an
    # existing file -or- allow creating if not existing.
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
               filename)

  def __del__(self):
    self.Close()

  def Close(self):
    """Close the file and release the lock.

    """
    if hasattr(self, "fd") and self.fd:
      self.fd.close()
      self.fd = None

  def _flock(self, flag, blocking, timeout, errmsg):
    """Wrapper for fcntl.flock.

    @type flag: int
    @param flag: operation flag
    @type blocking: bool
    @param blocking: whether the operation should be done in blocking mode.
    @type timeout: None or float
    @param timeout: for how long the operation should be retried (implies
                    non-blocking mode).
    @type errmsg: string
    @param errmsg: error message in case operation fails.

    """
    assert self.fd, "Lock was closed"
    assert timeout is None or timeout >= 0, \
      "If specified, timeout must be positive"
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)


def SignalHandled(signums):
  """Signal handling decorator.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      assert 'signal_handlers' not in kwargs or \
             kwargs['signal_handlers'] is None or \
             isinstance(kwargs['signal_handlers'], dict), \
             "Wrong signal_handlers parameter in original function call"
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
        signal_handlers = kwargs['signal_handlers']
      else:
        signal_handlers = {}
        kwargs['signal_handlers'] = signal_handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          signal_handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        sighandler.Reset()
    return sig_function
  return wrap
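
# Illustrative usage (not part of the module; _MainLoop and DoSomething are
# hypothetical):
#
#   @SignalHandled([signal.SIGTERM])
#   def _MainLoop(signal_handlers=None):
#     while not signal_handlers[signal.SIGTERM].called:
#       DoSomething()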


class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when destroyed or
  when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: list
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function

    """
    assert handler_fn is None or callable(handler_fn)

    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn

    self._previous = {}
    try:
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._handler_fn:
      self._handler_fn(signum, frame)
3158

    
3159

    
3160
class FieldSet(object):
3161
  """A simple field set.
3162

3163
  Among the features are:
3164
    - checking if a string is among a list of static string or regex objects
3165
    - checking if a whole list of string matches
3166
    - returning the matching groups from a regex match
3167

3168
  Internally, all fields are held as regular expression objects.
3169

3170
  """
3171
  def __init__(self, *items):
3172
    self.items = [re.compile("^%s$" % value) for value in items]
3173

    
3174
  def Extend(self, other_set):
3175
    """Extend the field set with the items from another one"""
3176
    self.items.extend(other_set.items)
3177

    
3178
  def Matches(self, field):
3179
    """Checks if a field matches the current set
3180

3181
    @type field: str
3182
    @param field: the string to match
3183
    @return: either None or a regular expression match object
3184

3185
    """
3186
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
3187
      return m
3188
    return None
3189

    
3190
  def NonMatching(self, items):
3191
    """Returns the list of fields not matching the current set
3192

3193
    @type items: list
3194
    @param items: the list of fields to check
3195
    @rtype: list
3196
    @return: list of non-matching fields
3197

3198
    """
3199
    return [val for val in items if not self.Matches(val)]
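
# Illustrative usage (not part of the module):
#
#   fields = FieldSet("name", "sd[ab]_size")
#   fields.Matches("sda_size")            # -> a regular expression match object
#   fields.NonMatching(["name", "cpu"])   # -> ["cpu"]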