#
#

# Copyright (C) 2006, 2007 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Ganeti utility module.

This module holds functions that can be used in both daemons (all) and
the command line scripts.

"""


import os
import sys
import time
import subprocess
import re
import socket
import tempfile
import shutil
import errno
import pwd
import itertools
import select
import fcntl
import resource
import logging
import logging.handlers
import signal
import OpenSSL
import datetime
import calendar
import hmac
import collections
import struct
import IN

from cStringIO import StringIO

try:
  import ctypes
except ImportError:
  ctypes = None

from ganeti import errors
from ganeti import constants
from ganeti import compat


_locksheld = []
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')

debug_locks = False

#: when set to True, L{RunCmd} is disabled
no_fork = False

_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"

HEX_CHAR_RE = r"[a-zA-Z0-9]"
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
                             HEX_CHAR_RE, HEX_CHAR_RE),
                            re.S | re.I)

_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$")

# Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
# struct ucred { pid_t pid; uid_t uid; gid_t gid; };
#
# The GNU C Library defines gid_t and uid_t to be "unsigned int" and
# pid_t to "int".
#
# IEEE Std 1003.1-2008:
# "nlink_t, uid_t, gid_t, and id_t shall be integer types"
# "blksize_t, pid_t, and ssize_t shall be signed integer types"
_STRUCT_UCRED = "iII"
_STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)

# Certificate verification results
(CERT_WARNING,
 CERT_ERROR) = range(1, 3)

# Flags for mlockall() (from bits/mman.h)
_MCL_CURRENT = 1
_MCL_FUTURE = 2


class RunResult(object):
  """Holds the result of running external programs.

  @type exit_code: int
  @ivar exit_code: the exit code of the program, or None (if the program
      didn't exit())
  @type signal: int or None
  @ivar signal: the signal that caused the program to finish, or None
      (if the program wasn't terminated by a signal)
  @type stdout: str
  @ivar stdout: the standard output of the program
  @type stderr: str
  @ivar stderr: the standard error of the program
  @type failed: boolean
  @ivar failed: True in case the program was
      terminated by a signal or exited with a non-zero exit code
  @ivar fail_reason: a string detailing the termination reason

  """
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
               "failed", "fail_reason", "cmd"]


  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
    self.cmd = cmd
    self.exit_code = exit_code
    self.signal = signal_
    self.stdout = stdout
    self.stderr = stderr
    self.failed = (signal_ is not None or exit_code != 0)

    if self.signal is not None:
      self.fail_reason = "terminated by signal %s" % self.signal
    elif self.exit_code is not None:
      self.fail_reason = "exited with exit code %s" % self.exit_code
    else:
      self.fail_reason = "unable to determine termination reason"

    if self.failed:
      logging.debug("Command '%s' failed (%s); output: %s",
                    self.cmd, self.fail_reason, self.output)

  def _GetOutput(self):
    """Returns the combined stdout and stderr for easier usage.

    """
    return self.stdout + self.stderr

  output = property(_GetOutput, None, None, "Return full output")


def _BuildCmdEnvironment(env, reset):
  """Builds the environment for an external program.

  """
  if reset:
    cmd_env = {}
  else:
    cmd_env = os.environ.copy()
    cmd_env["LC_ALL"] = "C"

  if env is not None:
    cmd_env.update(env)

  return cmd_env


def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
  """Execute a (shell) command.

  The command should not read from its standard input, as it will be
  closed.

  @type cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: Additional environment variables
  @type output: str
  @param output: if desired, the output of the command can be
      saved in a file instead of the RunResult instance; this
      parameter denotes the file name (if not None)
  @type cwd: string
  @param cwd: if specified, will be used as the working
      directory for the command; the default will be /
  @type reset_env: boolean
  @param reset_env: whether to reset or keep the default os environment
  @rtype: L{RunResult}
  @return: RunResult instance
  @raise errors.ProgrammerError: if we call this when forks are disabled

  """
  if no_fork:
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")

  if isinstance(cmd, basestring):
    strcmd = cmd
    shell = True
  else:
    cmd = [str(val) for val in cmd]
    strcmd = ShellQuoteArgs(cmd)
    shell = False

  if output:
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
  else:
    logging.debug("RunCmd %s", strcmd)

  cmd_env = _BuildCmdEnvironment(env, reset_env)

  try:
    if output is None:
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
    else:
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
      out = err = ""
  except OSError, err:
    if err.errno == errno.ENOENT:
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
                               (strcmd, err))
    else:
      raise

  if status >= 0:
    exitcode = status
    signal_ = None
  else:
    exitcode = None
    signal_ = -status

  return RunResult(exitcode, signal_, out, err, strcmd)
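
# Illustrative sketch (not part of the original module): typical RunCmd usage.
# The command shown is an arbitrary example.
#
#   result = RunCmd(["/bin/ls", "-l", "/tmp"])
#   if result.failed:
#     logging.error("Command failed (%s): %s", result.fail_reason, result.output)
#   else:
#     logging.debug("Command output: %s", result.stdout)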


def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
                pidfile=None):
  """Start a daemon process after forking twice.

  @type cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: Additional environment variables
  @type cwd: string
  @param cwd: Working directory for the program
  @type output: string
  @param output: Path to file in which to save the output
  @type output_fd: int
  @param output_fd: File descriptor for output
  @type pidfile: string
  @param pidfile: Process ID file
  @rtype: int
  @return: Daemon process ID
  @raise errors.ProgrammerError: if we call this when forks are disabled

  """
  if no_fork:
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
                                 " disabled")

  if output and not (bool(output) ^ (output_fd is not None)):
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
                                 " specified")

  if isinstance(cmd, basestring):
    cmd = ["/bin/sh", "-c", cmd]

  strcmd = ShellQuoteArgs(cmd)

  if output:
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
  else:
    logging.debug("StartDaemon %s", strcmd)

  cmd_env = _BuildCmdEnvironment(env, False)

  # Create pipe for sending PID back
  (pidpipe_read, pidpipe_write) = os.pipe()
  try:
    try:
      # Create pipe for sending error messages
      (errpipe_read, errpipe_write) = os.pipe()
      try:
        try:
          # First fork
          pid = os.fork()
          if pid == 0:
            try:
              # Child process, won't return
              _StartDaemonChild(errpipe_read, errpipe_write,
                                pidpipe_read, pidpipe_write,
                                cmd, cmd_env, cwd,
                                output, output_fd, pidfile)
            finally:
              # Well, maybe child process failed
              os._exit(1) # pylint: disable-msg=W0212
        finally:
          _CloseFDNoErr(errpipe_write)

        # Wait for daemon to be started (or an error message to arrive) and read
        # up to 100 KB as an error message
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
      finally:
        _CloseFDNoErr(errpipe_read)
    finally:
      _CloseFDNoErr(pidpipe_write)

    # Read up to 128 bytes for PID
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
  finally:
    _CloseFDNoErr(pidpipe_read)

  # Try to avoid zombies by waiting for child process
  try:
    os.waitpid(pid, 0)
  except OSError:
    pass

  if errormsg:
    raise errors.OpExecError("Error when starting daemon process: %r" %
                             errormsg)

  try:
    return int(pidtext)
  except (ValueError, TypeError), err:
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
                             (pidtext, err))


def _StartDaemonChild(errpipe_read, errpipe_write,
                      pidpipe_read, pidpipe_write,
                      args, env, cwd,
                      output, fd_output, pidfile):
  """Child process for starting daemon.

  """
  try:
    # Close parent's side
    _CloseFDNoErr(errpipe_read)
    _CloseFDNoErr(pidpipe_read)

    # First child process
    os.chdir("/")
    os.umask(077)
    os.setsid()

    # And fork for the second time
    pid = os.fork()
    if pid != 0:
      # Exit first child process
      os._exit(0) # pylint: disable-msg=W0212

    # Make sure pipe is closed on execv* (and thereby notifies original process)
    SetCloseOnExecFlag(errpipe_write, True)

    # List of file descriptors to be left open
    noclose_fds = [errpipe_write]

    # Open PID file
    if pidfile:
      try:
        # TODO: Atomic replace with another locked file instead of writing into
        # it after creating
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)

        # Lock the PID file (and fail if not possible to do so). Any code
        # wanting to send a signal to the daemon should try to lock the PID
        # file before reading it. If acquiring the lock succeeds, the daemon is
        # no longer running and the signal should not be sent.
        LockFile(fd_pidfile)

        os.write(fd_pidfile, "%d\n" % os.getpid())
      except Exception, err:
        raise Exception("Creating and locking PID file failed: %s" % err)

      # Keeping the file open to hold the lock
      noclose_fds.append(fd_pidfile)

      SetCloseOnExecFlag(fd_pidfile, False)
    else:
      fd_pidfile = None

    # Open /dev/null
    fd_devnull = os.open(os.devnull, os.O_RDWR)

    assert not output or (bool(output) ^ (fd_output is not None))

    if fd_output is not None:
      pass
    elif output:
      # Open output file
      try:
        # TODO: Implement flag to set append=yes/no
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
      except EnvironmentError, err:
        raise Exception("Opening output file failed: %s" % err)
    else:
      fd_output = fd_devnull

    # Redirect standard I/O
    os.dup2(fd_devnull, 0)
    os.dup2(fd_output, 1)
    os.dup2(fd_output, 2)

    # Send daemon PID to parent
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))

    # Close all file descriptors except stdio and error message pipe
    CloseFDs(noclose_fds=noclose_fds)

    # Change working directory
    os.chdir(cwd)

    if env is None:
      os.execvp(args[0], args)
    else:
      os.execvpe(args[0], args, env)
  except: # pylint: disable-msg=W0702
    try:
      # Report errors to original process
      buf = str(sys.exc_info()[1])

      RetryOnSignal(os.write, errpipe_write, buf)
    except: # pylint: disable-msg=W0702
      # Ignore errors in error handling
      pass

  os._exit(1) # pylint: disable-msg=W0212


def _RunCmdPipe(cmd, env, via_shell, cwd):
  """Run a command and return its output.

  @type  cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: The environment to use
  @type via_shell: bool
  @param via_shell: if we should run via the shell
  @type cwd: string
  @param cwd: the working directory for the program
  @rtype: tuple
  @return: (out, err, status)

  """
  poller = select.poll()
  child = subprocess.Popen(cmd, shell=via_shell,
                           stderr=subprocess.PIPE,
                           stdout=subprocess.PIPE,
                           stdin=subprocess.PIPE,
                           close_fds=True, env=env,
                           cwd=cwd)

  child.stdin.close()
  poller.register(child.stdout, select.POLLIN)
  poller.register(child.stderr, select.POLLIN)
  out = StringIO()
  err = StringIO()
  fdmap = {
    child.stdout.fileno(): (out, child.stdout),
    child.stderr.fileno(): (err, child.stderr),
    }
  for fd in fdmap:
    SetNonblockFlag(fd, True)

  while fdmap:
    pollresult = RetryOnSignal(poller.poll)

    for fd, event in pollresult:
      if event & select.POLLIN or event & select.POLLPRI:
        data = fdmap[fd][1].read()
        # no data from read signifies EOF (the same as POLLHUP)
        if not data:
          poller.unregister(fd)
          del fdmap[fd]
          continue
        fdmap[fd][0].write(data)
      if (event & select.POLLNVAL or event & select.POLLHUP or
          event & select.POLLERR):
        poller.unregister(fd)
        del fdmap[fd]

  out = out.getvalue()
  err = err.getvalue()

  status = child.wait()
  return out, err, status


def _RunCmdFile(cmd, env, via_shell, output, cwd):
  """Run a command and save its output to a file.

  @type  cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: The environment to use
  @type via_shell: bool
  @param via_shell: if we should run via the shell
  @type output: str
  @param output: the filename in which to save the output
  @type cwd: string
  @param cwd: the working directory for the program
  @rtype: int
  @return: the exit status

  """
  fh = open(output, "a")
  try:
    child = subprocess.Popen(cmd, shell=via_shell,
                             stderr=subprocess.STDOUT,
                             stdout=fh,
                             stdin=subprocess.PIPE,
                             close_fds=True, env=env,
                             cwd=cwd)

    child.stdin.close()
    status = child.wait()
  finally:
    fh.close()
  return status


def SetCloseOnExecFlag(fd, enable):
  """Sets or unsets the close-on-exec flag on a file descriptor.

  @type fd: int
  @param fd: File descriptor
  @type enable: bool
  @param enable: Whether to set or unset it.

  """
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)

  if enable:
    flags |= fcntl.FD_CLOEXEC
  else:
    flags &= ~fcntl.FD_CLOEXEC

  fcntl.fcntl(fd, fcntl.F_SETFD, flags)


def SetNonblockFlag(fd, enable):
  """Sets or unsets the O_NONBLOCK flag on a file descriptor.

  @type fd: int
  @param fd: File descriptor
  @type enable: bool
  @param enable: Whether to set or unset it

  """
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)

  if enable:
    flags |= os.O_NONBLOCK
  else:
    flags &= ~os.O_NONBLOCK

  fcntl.fcntl(fd, fcntl.F_SETFL, flags)


def RetryOnSignal(fn, *args, **kwargs):
  """Calls a function again if it failed due to EINTR.

  """
  while True:
    try:
      return fn(*args, **kwargs)
    except EnvironmentError, err:
      if err.errno != errno.EINTR:
        raise
    except (socket.error, select.error), err:
      # In python 2.6 and above select.error is an IOError, so it's handled
      # above, in 2.5 and below it's not, and it's handled here.
      if not (err.args and err.args[0] == errno.EINTR):
        raise
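
# Illustrative sketch (not part of the original module): RetryOnSignal simply
# re-invokes the callable when it fails with EINTR, e.g. for a raw os.read on
# an already-open descriptor fd (hypothetical here).
#
#   data = RetryOnSignal(os.read, fd, 4096)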


def RunParts(dir_name, env=None, reset_env=False):
  """Run scripts or programs in a directory

  @type dir_name: string
  @param dir_name: absolute path to a directory
  @type env: dict
  @param env: The environment to use
  @type reset_env: boolean
  @param reset_env: whether to reset or keep the default os environment
  @rtype: list of tuples
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)

  """
  rr = []

  try:
    dir_contents = ListVisibleFiles(dir_name)
  except OSError, err:
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
    return rr

  for relname in sorted(dir_contents):
    fname = PathJoin(dir_name, relname)
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
      rr.append((relname, constants.RUNPARTS_SKIP, None))
    else:
      try:
        result = RunCmd([fname], env=env, reset_env=reset_env)
      except Exception, err: # pylint: disable-msg=W0703
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
      else:
        rr.append((relname, constants.RUNPARTS_RUN, result))

  return rr
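
# Illustrative sketch (not part of the original module): running all executable
# hooks in a directory; the path shown is a hypothetical example.
#
#   for (name, status, result) in RunParts("/etc/example-hooks.d"):
#     if status == constants.RUNPARTS_ERR:
#       logging.error("Hook %s could not be run: %s", name, result)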


def GetSocketCredentials(sock):
  """Returns the credentials of the foreign process connected to a socket.

  @param sock: Unix socket
  @rtype: tuple; (number, number, number)
  @return: The PID, UID and GID of the connected foreign process.

  """
  peercred = sock.getsockopt(socket.SOL_SOCKET, IN.SO_PEERCRED,
                             _STRUCT_UCRED_SIZE)
  return struct.unpack(_STRUCT_UCRED, peercred)


def RemoveFile(filename):
  """Remove a file ignoring some errors.

  Remove a file, ignoring non-existing ones or directories. Other
  errors are passed.

  @type filename: str
  @param filename: the file to be removed

  """
  try:
    os.unlink(filename)
  except OSError, err:
    if err.errno not in (errno.ENOENT, errno.EISDIR):
      raise


def RemoveDir(dirname):
  """Remove an empty directory.

  Remove a directory, ignoring non-existing ones.
  Other errors are passed. This includes the case
  where the directory is not empty, so it can't be removed.

  @type dirname: str
  @param dirname: the empty directory to be removed

  """
  try:
    os.rmdir(dirname)
  except OSError, err:
    if err.errno != errno.ENOENT:
      raise


def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
  """Renames a file.

  @type old: string
  @param old: Original path
  @type new: string
  @param new: New path
  @type mkdir: bool
  @param mkdir: Whether to create target directory if it doesn't exist
  @type mkdir_mode: int
  @param mkdir_mode: Mode for newly created directories

  """
  try:
    return os.rename(old, new)
  except OSError, err:
    # In at least one use case of this function, the job queue, directory
    # creation is very rare. Checking for the directory before renaming is not
    # as efficient.
    if mkdir and err.errno == errno.ENOENT:
      # Create directory and try again
      Makedirs(os.path.dirname(new), mode=mkdir_mode)

      return os.rename(old, new)

    raise


def Makedirs(path, mode=0750):
  """Super-mkdir; create a leaf directory and all intermediate ones.

  This is a wrapper around C{os.makedirs} adding error handling not implemented
  before Python 2.5.

  """
  try:
    os.makedirs(path, mode)
  except OSError, err:
    # Ignore EEXIST. This is only handled in os.makedirs as included in
    # Python 2.5 and above.
    if err.errno != errno.EEXIST or not os.path.exists(path):
      raise


def ResetTempfileModule():
  """Resets the random name generator of the tempfile module.

  This function should be called after C{os.fork} in the child process to
  ensure it creates a newly seeded random generator. Otherwise it would
  generate the same random parts as the parent process. If several processes
  race for the creation of a temporary file, this could lead to one not getting
  a temporary name.

  """
  # pylint: disable-msg=W0212
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
    tempfile._once_lock.acquire()
    try:
      # Reset random name generator
      tempfile._name_sequence = None
    finally:
      tempfile._once_lock.release()
  else:
    logging.critical("The tempfile module misses at least one of the"
                     " '_once_lock' and '_name_sequence' attributes")


def _FingerprintFile(filename):
  """Compute the fingerprint of a file.

  If the file does not exist, None will be returned
  instead.

  @type filename: str
  @param filename: the filename to checksum
  @rtype: str
  @return: the hex digest of the SHA1 checksum of the contents
      of the file

  """
  if not (os.path.exists(filename) and os.path.isfile(filename)):
    return None

  f = open(filename)

  fp = compat.sha1_hash()
  while True:
    data = f.read(4096)
    if not data:
      break

    fp.update(data)

  return fp.hexdigest()


def FingerprintFiles(files):
  """Compute fingerprints for a list of files.

  @type files: list
  @param files: the list of filenames to fingerprint
  @rtype: dict
  @return: a dictionary of filename: fingerprint, holding only
      existing files

  """
  ret = {}

  for filename in files:
    cksum = _FingerprintFile(filename)
    if cksum:
      ret[filename] = cksum

  return ret


def ForceDictType(target, key_types, allowed_values=None):
  """Force the values of a dict to have certain types.

  @type target: dict
  @param target: the dict to update
  @type key_types: dict
  @param key_types: dict mapping target dict keys to types
                    in constants.ENFORCEABLE_TYPES
  @type allowed_values: list
  @keyword allowed_values: list of specially allowed values

  """
  if allowed_values is None:
    allowed_values = []

  if not isinstance(target, dict):
    msg = "Expected dictionary, got '%s'" % target
    raise errors.TypeEnforcementError(msg)

  for key in target:
    if key not in key_types:
      msg = "Unknown key '%s'" % key
      raise errors.TypeEnforcementError(msg)

    if target[key] in allowed_values:
      continue

    ktype = key_types[key]
    if ktype not in constants.ENFORCEABLE_TYPES:
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
      raise errors.ProgrammerError(msg)

    if ktype == constants.VTYPE_STRING:
      if not isinstance(target[key], basestring):
        if isinstance(target[key], bool) and not target[key]:
          target[key] = ''
        else:
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
          raise errors.TypeEnforcementError(msg)
    elif ktype == constants.VTYPE_BOOL:
      if isinstance(target[key], basestring) and target[key]:
        if target[key].lower() == constants.VALUE_FALSE:
          target[key] = False
        elif target[key].lower() == constants.VALUE_TRUE:
          target[key] = True
        else:
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
          raise errors.TypeEnforcementError(msg)
      elif target[key]:
        target[key] = True
      else:
        target[key] = False
    elif ktype == constants.VTYPE_SIZE:
      try:
        target[key] = ParseUnit(target[key])
      except errors.UnitParseError, err:
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
              (key, target[key], err)
        raise errors.TypeEnforcementError(msg)
    elif ktype == constants.VTYPE_INT:
      try:
        target[key] = int(target[key])
      except (ValueError, TypeError):
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
        raise errors.TypeEnforcementError(msg)
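
# Illustrative sketch (not part of the original module): enforcing types on a
# parameter dictionary; the keys and values shown are arbitrary examples.
#
#   params = {"size": "512M", "auto": "true"}
#   ForceDictType(params, {"size": constants.VTYPE_SIZE,
#                          "auto": constants.VTYPE_BOOL})
#   # params is now {"size": 512, "auto": True}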


def _GetProcStatusPath(pid):
  """Returns the path for a PID's proc status file.

  @type pid: int
  @param pid: Process ID
  @rtype: string

  """
  return "/proc/%d/status" % pid


def IsProcessAlive(pid):
  """Check if a given pid exists on the system.

  @note: zombie status is not handled, so zombie processes
      will be returned as alive
  @type pid: int
  @param pid: the process ID to check
  @rtype: boolean
  @return: True if the process exists

  """
  def _TryStat(name):
    try:
      os.stat(name)
      return True
    except EnvironmentError, err:
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
        return False
      elif err.errno == errno.EINVAL:
        raise RetryAgain(err)
      raise

  assert isinstance(pid, int), "pid must be an integer"
  if pid <= 0:
    return False

  # /proc in a multiprocessor environment can have strange behaviors.
  # Retry the os.stat a few times until we get a good result.
  try:
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
                 args=[_GetProcStatusPath(pid)])
  except RetryTimeout, err:
    err.RaiseInner()


def _ParseSigsetT(sigset):
  """Parse a rendered sigset_t value.

  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
  function.

  @type sigset: string
  @param sigset: Rendered signal set from /proc/$pid/status
  @rtype: set
  @return: Set of all enabled signal numbers

  """
  result = set()

  signum = 0
  for ch in reversed(sigset):
    chv = int(ch, 16)

    # The following could be done in a loop, but it's easier to read and
    # understand in the unrolled form
    if chv & 1:
      result.add(signum + 1)
    if chv & 2:
      result.add(signum + 2)
    if chv & 4:
      result.add(signum + 3)
    if chv & 8:
      result.add(signum + 4)

    signum += 4

  return result


def _GetProcStatusField(pstatus, field):
  """Retrieves a field from the contents of a proc status file.

  @type pstatus: string
  @param pstatus: Contents of /proc/$pid/status
  @type field: string
  @param field: Name of field whose value should be returned
  @rtype: string

  """
  for line in pstatus.splitlines():
    parts = line.split(":", 1)

    if len(parts) < 2 or parts[0] != field:
      continue

    return parts[1].strip()

  return None


def IsProcessHandlingSignal(pid, signum, status_path=None):
  """Checks whether a process is handling a signal.

  @type pid: int
  @param pid: Process ID
  @type signum: int
  @param signum: Signal number
  @rtype: bool

  """
  if status_path is None:
    status_path = _GetProcStatusPath(pid)

  try:
    proc_status = ReadFile(status_path)
  except EnvironmentError, err:
    # In at least one case, reading /proc/$pid/status failed with ESRCH.
    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
      return False
    raise

  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
  if sigcgt is None:
    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)

  # Now check whether signal is handled
  return signum in _ParseSigsetT(sigcgt)
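
# Illustrative sketch (not part of the original module): only signalling a
# daemon if it actually catches SIGHUP, based on the SigCgt mask in its proc
# status file; pid is a hypothetical process ID.
#
#   if IsProcessHandlingSignal(pid, signal.SIGHUP):
#     os.kill(pid, signal.SIGHUP)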


def ReadPidFile(pidfile):
  """Read a pid from a file.

  @type  pidfile: string
  @param pidfile: path to the file containing the pid
  @rtype: int
  @return: The process id, if the file exists and contains a valid PID,
           otherwise 0

  """
  try:
    raw_data = ReadOneLineFile(pidfile)
  except EnvironmentError, err:
    if err.errno != errno.ENOENT:
      logging.exception("Can't read pid file")
    return 0

  try:
    pid = int(raw_data)
  except (TypeError, ValueError), err:
    logging.info("Can't parse pid file contents", exc_info=True)
    return 0

  return pid


def ReadLockedPidFile(path):
  """Reads a locked PID file.

  This can be used together with L{StartDaemon}.

  @type path: string
  @param path: Path to PID file
  @return: PID as integer or, if file was unlocked or couldn't be opened, None

  """
  try:
    fd = os.open(path, os.O_RDONLY)
  except EnvironmentError, err:
    if err.errno == errno.ENOENT:
      # PID file doesn't exist
      return None
    raise

  try:
    try:
      # Try to acquire lock
      LockFile(fd)
    except errors.LockError:
      # Couldn't lock, daemon is running
      return int(os.read(fd, 100))
  finally:
    os.close(fd)

  return None


def MatchNameComponent(key, name_list, case_sensitive=True):
  """Try to match a name against a list.

  This function will try to match a name like test1 against a list
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
  this list, I{'test1'} as well as I{'test1.example'} will match, but
  not I{'test1.ex'}. A multiple match will be considered as no match
  at all (e.g. I{'test1'} against C{['test1.example.com',
  'test1.example.org']}), except when the key fully matches an entry
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).

  @type key: str
  @param key: the name to be searched
  @type name_list: list
  @param name_list: the list of strings against which to search the key
  @type case_sensitive: boolean
  @param case_sensitive: whether to provide a case-sensitive match

  @rtype: None or str
  @return: None if there is no match I{or} if there are multiple matches,
      otherwise the element from the list which matches

  """
  if key in name_list:
    return key

  re_flags = 0
  if not case_sensitive:
    re_flags |= re.IGNORECASE
    key = key.upper()
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
  names_filtered = []
  string_matches = []
  for name in name_list:
    if mo.match(name) is not None:
      names_filtered.append(name)
      if not case_sensitive and key == name.upper():
        string_matches.append(name)

  if len(string_matches) == 1:
    return string_matches[0]
  if len(names_filtered) == 1:
    return names_filtered[0]
  return None
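
# Illustrative sketch (not part of the original module): expanding a shortened
# name against a list of known names; the names are arbitrary examples.
#
#   MatchNameComponent("node1", ["node1.example.com", "node2.example.com"])
#   # -> "node1.example.com"
#   MatchNameComponent("node1", ["node1.example.com", "node1.example.org"])
#   # -> None (multiple matches)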


class HostInfo:
  """Class implementing resolver and hostname functionality

  """
  _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$")

  def __init__(self, name=None):
    """Initialize the host name object.

    If the name argument is not passed, it will use this system's
    name.

    """
    if name is None:
      name = self.SysName()

    self.query = name
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
    self.ip = self.ipaddrs[0]

  def ShortName(self):
    """Returns the hostname without domain.

    """
    return self.name.split('.')[0]

  @staticmethod
  def SysName():
    """Return the current system's name.

    This is simply a wrapper over C{socket.gethostname()}.

    """
    return socket.gethostname()

  @staticmethod
  def LookupHostname(hostname):
    """Look up hostname

    @type hostname: str
    @param hostname: hostname to look up

    @rtype: tuple
    @return: a tuple (name, aliases, ipaddrs) as returned by
        C{socket.gethostbyname_ex}
    @raise errors.ResolverError: in case of errors in resolving

    """
    try:
      result = socket.gethostbyname_ex(hostname)
    except socket.gaierror, err:
      # hostname not found in DNS
      raise errors.ResolverError(hostname, err.args[0], err.args[1])

    return result

  @classmethod
  def NormalizeName(cls, hostname):
    """Validate and normalize the given hostname.

    @attention: the validation is a bit more relaxed than the standards
        require; most importantly, we allow underscores in names
    @raise errors.OpPrereqError: when the name is not valid

    """
    hostname = hostname.lower()
    if (not cls._VALID_NAME_RE.match(hostname) or
        # double-dots, meaning empty label
        ".." in hostname or
        # empty initial label
        hostname.startswith(".")):
      raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
                                 errors.ECODE_INVAL)
    if hostname.endswith("."):
      hostname = hostname.rstrip(".")
    return hostname


def ValidateServiceName(name):
  """Validate the given service name.

  @type name: number or string
  @param name: Service name or port specification

  """
  try:
    numport = int(name)
  except (ValueError, TypeError):
    # Non-numeric service name
    valid = _VALID_SERVICE_NAME_RE.match(name)
  else:
    # Numeric port (protocols other than TCP or UDP might need adjustments
    # here)
    valid = (numport >= 0 and numport < (1 << 16))

  if not valid:
    raise errors.OpPrereqError("Invalid service name '%s'" % name,
                               errors.ECODE_INVAL)

  return name


def GetHostInfo(name=None):
  """Lookup host name and raise an OpPrereqError for failures"""

  try:
    return HostInfo(name)
  except errors.ResolverError, err:
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
                               (err[0], err[2]), errors.ECODE_RESOLVER)


def ListVolumeGroups():
  """List volume groups and their sizes.

  @rtype: dict
  @return:
       a dictionary mapping volume group names to their
       size in MiB

  """
  command = "vgs --noheadings --units m --nosuffix -o name,size"
  result = RunCmd(command)
  retval = {}
  if result.failed:
    return retval

  for line in result.stdout.splitlines():
    try:
      name, size = line.split()
      size = int(float(size))
    except (IndexError, ValueError), err:
      logging.error("Invalid output from vgs (%s): %s", err, line)
      continue

    retval[name] = size

  return retval


def BridgeExists(bridge):
  """Check whether the given bridge exists in the system

  @type bridge: str
  @param bridge: the bridge name to check
  @rtype: boolean
  @return: True if it does

  """
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)


def NiceSort(name_list):
  """Sort a list of strings based on digit and non-digit groupings.

  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
  will sort the list in the logical order C{['a1', 'a2', 'a10',
  'a11']}.

  The sort algorithm breaks each name into groups of either only-digits
  or no-digits. Only the first eight such groups are considered, and
  after that we just use what's left of the string.

  @type name_list: list
  @param name_list: the names to be sorted
  @rtype: list
  @return: a copy of the name list sorted with our algorithm

  """
  _SORTER_BASE = "(\D+|\d+)"
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE)
  _SORTER_RE = re.compile(_SORTER_FULL)
  _SORTER_NODIGIT = re.compile("^\D*$")
  def _TryInt(val):
    """Attempts to convert a variable to integer."""
    if val is None or _SORTER_NODIGIT.match(val):
      return val
    rval = int(val)
    return rval

  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
             for name in name_list]
  to_sort.sort()
  return [tup[1] for tup in to_sort]
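
# Illustrative sketch (not part of the original module): natural-order sorting
# of instance-like names.
#
#   NiceSort(["xen10", "xen2", "xen1"])  # -> ["xen1", "xen2", "xen10"]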


def TryConvert(fn, val):
  """Try to convert a value ignoring errors.

  This function tries to apply function I{fn} to I{val}. If no
  C{ValueError} or C{TypeError} exceptions are raised, it will return
  the result, else it will return the original value. Any other
  exceptions are propagated to the caller.

  @type fn: callable
  @param fn: function to apply to the value
  @param val: the value to be converted
  @return: The converted value if the conversion was successful,
      otherwise the original value.

  """
  try:
    nv = fn(val)
  except (ValueError, TypeError):
    nv = val
  return nv


def IsValidIP(ip):
  """Verifies the syntax of an IPv4 address.

  This function checks whether the IPv4 address passed is valid, based
  only on syntax (not IP range, class calculations, etc.).

  @type ip: str
  @param ip: the address to be checked
  @rtype: a regular expression match object
  @return: a regular expression match object, or None if the
      address is not valid

  """
  unit = "(0|[1-9]\d{0,2})"
  #TODO: convert and return only boolean
  return re.match("^%s\.%s\.%s\.%s$" % (unit, unit, unit, unit), ip)


def IsValidShellParam(word):
  """Verifies whether the given word is safe from the shell's point of view.

  This means that we can pass this to a command via the shell and be
  sure that it doesn't alter the command line and is passed as such to
  the actual command.

  Note that we are overly restrictive here, in order to be on the safe
  side.

  @type word: str
  @param word: the word to check
  @rtype: boolean
  @return: True if the word is 'safe'

  """
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))


def BuildShellCmd(template, *args):
  """Build a safe shell command line from the given arguments.

  This function will check all arguments in the args list so that they
  are valid shell parameters (i.e. they don't contain shell
  metacharacters). If everything is ok, it will return the result of
  template % args.

  @type template: str
  @param template: the string holding the template for the
      string formatting
  @rtype: str
  @return: the expanded command line

  """
  for word in args:
    if not IsValidShellParam(word):
      raise errors.ProgrammerError("Shell argument '%s' contains"
                                   " invalid characters" % word)
  return template % args


def FormatUnit(value, units):
  """Formats an incoming number of MiB with the appropriate unit.

  @type value: int
  @param value: integer representing the value in MiB (1 MiB = 1048576 bytes)
  @type units: char
  @param units: the type of formatting we should do:
      - 'h' for automatic scaling
      - 'm' for MiBs
      - 'g' for GiBs
      - 't' for TiBs
  @rtype: str
  @return: the formatted value (with suffix)

  """
  if units not in ('m', 'g', 't', 'h'):
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))

  suffix = ''

  if units == 'm' or (units == 'h' and value < 1024):
    if units == 'h':
      suffix = 'M'
    return "%d%s" % (round(value, 0), suffix)

  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
    if units == 'h':
      suffix = 'G'
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)

  else:
    if units == 'h':
      suffix = 'T'
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)


def ParseUnit(input_string):
  """Tries to extract number and scale from the given string.

  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
  is always an int in MiB.

  """
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
  if not m:
    raise errors.UnitParseError("Invalid format")

  value = float(m.groups()[0])

  unit = m.groups()[1]
  if unit:
    lcunit = unit.lower()
  else:
    lcunit = 'm'

  if lcunit in ('m', 'mb', 'mib'):
    # Value already in MiB
    pass

  elif lcunit in ('g', 'gb', 'gib'):
    value *= 1024

  elif lcunit in ('t', 'tb', 'tib'):
    value *= 1024 * 1024

  else:
    raise errors.UnitParseError("Unknown unit: %s" % unit)

  # Make sure we round up
  if int(value) < value:
    value += 1

  # Round up to the next multiple of 4
  value = int(value)
  if value % 4:
    value += 4 - value % 4

  return value
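
# Illustrative sketch (not part of the original module): ParseUnit and
# FormatUnit both work in MiB.
#
#   ParseUnit("4G")        # -> 4096 (MiB)
#   ParseUnit("1.3")       # -> 4 (rounded up to a multiple of 4 MiB)
#   FormatUnit(4096, 'h')  # -> "4.0G"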


def AddAuthorizedKey(file_name, key):
  """Adds an SSH public key to an authorized_keys file.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  key_fields = key.split()

  f = open(file_name, 'a+')
  try:
    nl = True
    for line in f:
      # Ignore whitespace changes
      if line.split() == key_fields:
        break
      nl = line.endswith('\n')
    else:
      if not nl:
        f.write("\n")
      f.write(key.rstrip('\r\n'))
      f.write("\n")
      f.flush()
  finally:
    f.close()


def RemoveAuthorizedKey(file_name, key):
  """Removes an SSH public key from an authorized_keys file.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  key_fields = key.split()

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          # Ignore whitespace changes while comparing lines
          if line.split() != key_fields:
            out.write(line)

        out.flush()
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def SetEtcHostsEntry(file_name, ip, hostname, aliases):
  """Sets the name of an IP address and hostname in /etc/hosts.

  @type file_name: str
  @param file_name: path to the file to modify (usually C{/etc/hosts})
  @type ip: str
  @param ip: the IP address
  @type hostname: str
  @param hostname: the hostname to be added
  @type aliases: list
  @param aliases: the list of aliases to add for the hostname

  """
  # FIXME: use WriteFile + fn rather than duplicating its efforts
  # Ensure aliases are unique
  aliases = UniqueSequence([hostname] + aliases)[1:]

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          fields = line.split()
          if fields and not fields[0].startswith('#') and ip == fields[0]:
            continue
          out.write(line)

        out.write("%s\t%s" % (ip, hostname))
        if aliases:
          out.write(" %s" % ' '.join(aliases))
        out.write('\n')

        out.flush()
        os.fsync(out)
        os.chmod(tmpname, 0644)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def AddHostToEtcHosts(hostname):
  """Wrapper around SetEtcHostsEntry.

  @type hostname: str
  @param hostname: a hostname that will be resolved and added to
      L{constants.ETC_HOSTS}

  """
  hi = HostInfo(name=hostname)
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])


def RemoveEtcHostsEntry(file_name, hostname):
  """Removes a hostname from /etc/hosts.

  IP addresses without names are removed from the file.

  @type file_name: str
  @param file_name: path to the file to modify (usually C{/etc/hosts})
  @type hostname: str
  @param hostname: the hostname to be removed

  """
  # FIXME: use WriteFile + fn rather than duplicating its efforts
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          fields = line.split()
          if len(fields) > 1 and not fields[0].startswith('#'):
            names = fields[1:]
            if hostname in names:
              while hostname in names:
                names.remove(hostname)
              if names:
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
              continue

          out.write(line)

        out.flush()
        os.fsync(out)
        os.chmod(tmpname, 0644)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def RemoveHostFromEtcHosts(hostname):
  """Wrapper around RemoveEtcHostsEntry.

  @type hostname: str
  @param hostname: hostname that will be resolved and its
      full and short names will be removed from
      L{constants.ETC_HOSTS}

  """
  hi = HostInfo(name=hostname)
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())


def TimestampForFilename():
  """Returns the current time formatted for filenames.

  The format doesn't contain colons, as some shells and applications treat
  them as separators.

  """
  return time.strftime("%Y-%m-%d_%H_%M_%S")


def CreateBackup(file_name):
  """Creates a backup of a file.

  @type file_name: str
  @param file_name: file to be backed up
  @rtype: str
  @return: the path to the newly created backup
  @raise errors.ProgrammerError: for invalid file names

  """
  if not os.path.isfile(file_name):
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
                                file_name)

  prefix = ("%s.backup-%s." %
            (os.path.basename(file_name), TimestampForFilename()))
  dir_name = os.path.dirname(file_name)

  fsrc = open(file_name, 'rb')
  try:
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
    fdst = os.fdopen(fd, 'wb')
    try:
      logging.debug("Backing up %s at %s", file_name, backup_name)
      shutil.copyfileobj(fsrc, fdst)
    finally:
      fdst.close()
  finally:
    fsrc.close()

  return backup_name


def ShellQuote(value):
  """Quotes shell argument according to POSIX.

  @type value: str
  @param value: the argument to be quoted
  @rtype: str
  @return: the quoted value

  """
  if _re_shell_unquoted.match(value):
    return value
  else:
    return "'%s'" % value.replace("'", "'\\''")


def ShellQuoteArgs(args):
  """Quotes a list of shell arguments.

  @type args: list
  @param args: list of arguments to be quoted
  @rtype: str
  @return: the quoted arguments concatenated with spaces

  """
  return ' '.join([ShellQuote(i) for i in args])
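
# Illustrative sketch (not part of the original module): quoting arguments for
# the shell.
#
#   ShellQuote("a b")                                # -> "'a b'"
#   ShellQuoteArgs(["ls", "-l", "dir with spaces"])  # -> "ls -l 'dir with spaces'"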


def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
  """Simple ping implementation using TCP connect(2).

  Check if the given IP is reachable by attempting a TCP connect
  to it.

  @type target: str
  @param target: the IP or hostname to ping
  @type port: int
  @param port: the port to connect to
  @type timeout: int
  @param timeout: the timeout on the connection attempt
  @type live_port_needed: boolean
  @param live_port_needed: whether a closed port will cause the
      function to return failure, as if there was a timeout
  @type source: str or None
  @param source: if specified, will cause the connect to be made
      from this specific source address; failures to bind other
      than C{EADDRNOTAVAIL} will be ignored

  """
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

  success = False

  if source is not None:
    try:
      sock.bind((source, 0))
    except socket.error, (errcode, _):
      if errcode == errno.EADDRNOTAVAIL:
        success = False

  sock.settimeout(timeout)

  try:
    sock.connect((target, port))
    sock.close()
    success = True
  except socket.timeout:
    success = False
  except socket.error, (errcode, _):
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)

  return success
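
# Illustrative sketch (not part of the original module): probing whether a
# remote port answers; the address and port are arbitrary examples.
#
#   if TcpPing("192.0.2.10", 1811, timeout=5, live_port_needed=True):
#     logging.debug("Port is reachable")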


def OwnIpAddress(address):
  """Check if the current host has the given IP address.

  Currently this is done by TCP-pinging the address from the loopback
  address.

  @type address: string
  @param address: the address to check
  @rtype: bool
  @return: True if we own the address

  """
  return TcpPing(address, constants.DEFAULT_NODED_PORT,
                 source=constants.LOCALHOST_IP_ADDRESS)


def ListVisibleFiles(path):
  """Returns a list of visible files in a directory.

  @type path: str
  @param path: the directory to enumerate
  @rtype: list
  @return: the list of all files not starting with a dot
  @raise ProgrammerError: if L{path} is not an absolute and normalized path

  """
  if not IsNormAbsPath(path):
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
                                 " absolute/normalized: '%s'" % path)
  files = [i for i in os.listdir(path) if not i.startswith(".")]
  return files


def GetHomeDir(user, default=None):
  """Try to get the homedir of the given user.

  The user can be passed either as a string (denoting the name) or as
  an integer (denoting the user id). If the user is not found, the
  'default' argument is returned, which defaults to None.

  """
  try:
    if isinstance(user, basestring):
      result = pwd.getpwnam(user)
    elif isinstance(user, (int, long)):
      result = pwd.getpwuid(user)
    else:
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
                                   type(user))
  except KeyError:
    return default
  return result.pw_dir


def NewUUID():
  """Returns a random UUID.

  @note: This is a Linux-specific method as it uses the /proc
      filesystem.
  @rtype: str

  """
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1791

    
1792

    
1793
def GenerateSecret(numbytes=20):
1794
  """Generates a random secret.
1795

1796
  This will generate a pseudo-random secret returning an hex string
1797
  (so that it can be used where an ASCII string is needed).
1798

1799
  @param numbytes: the number of bytes which will be represented by the returned
1800
      string (defaulting to 20, the length of a SHA1 hash)
1801
  @rtype: str
1802
  @return: an hex representation of the pseudo-random sequence
1803

1804
  """
1805
  return os.urandom(numbytes).encode('hex')
1806

    
1807

    
1808
def EnsureDirs(dirs):
1809
  """Make required directories, if they don't exist.
1810

1811
  @param dirs: list of tuples (dir_name, dir_mode)
1812
  @type dirs: list of (string, integer)
1813

1814
  """
1815
  for dir_name, dir_mode in dirs:
1816
    try:
1817
      os.mkdir(dir_name, dir_mode)
1818
    except EnvironmentError, err:
1819
      if err.errno != errno.EEXIST:
1820
        raise errors.GenericError("Cannot create needed directory"
1821
                                  " '%s': %s" % (dir_name, err))
1822
    try:
1823
      os.chmod(dir_name, dir_mode)
1824
    except EnvironmentError, err:
1825
      raise errors.GenericError("Cannot change directory permissions on"
1826
                                " '%s': %s" % (dir_name, err))
1827
    if not os.path.isdir(dir_name):
1828
      raise errors.GenericError("%s is not a directory" % dir_name)
1829

    
1830

    
1831
def ReadFile(file_name, size=-1):
1832
  """Reads a file.
1833

1834
  @type size: int
1835
  @param size: Read at most size bytes (if negative, entire file)
1836
  @rtype: str
1837
  @return: the (possibly partial) content of the file
1838

1839
  """
1840
  f = open(file_name, "r")
1841
  try:
1842
    return f.read(size)
1843
  finally:
1844
    f.close()
1845

    
1846

    
1847
def WriteFile(file_name, fn=None, data=None,
1848
              mode=None, uid=-1, gid=-1,
1849
              atime=None, mtime=None, close=True,
1850
              dry_run=False, backup=False,
1851
              prewrite=None, postwrite=None):
1852
  """(Over)write a file atomically.
1853

1854
  The file_name and either fn (a function taking one argument, the
1855
  file descriptor, and which should write the data to it) or data (the
1856
  contents of the file) must be passed. The other arguments are
1857
  optional and allow setting the file mode, owner and group, and the
1858
  mtime/atime of the file.
1859

1860
  If the function doesn't raise an exception, it has succeeded and the
1861
  target file has the new contents. If the function has raised an
1862
  exception, an existing target file should be unmodified and the
1863
  temporary file should be removed.
1864

1865
  @type file_name: str
1866
  @param file_name: the target filename
1867
  @type fn: callable
1868
  @param fn: content writing function, called with
1869
      file descriptor as parameter
1870
  @type data: str
1871
  @param data: contents of the file
1872
  @type mode: int
1873
  @param mode: file mode
1874
  @type uid: int
1875
  @param uid: the owner of the file
1876
  @type gid: int
1877
  @param gid: the group of the file
1878
  @type atime: int
1879
  @param atime: a custom access time to be set on the file
1880
  @type mtime: int
1881
  @param mtime: a custom modification time to be set on the file
1882
  @type close: boolean
1883
  @param close: whether to close file after writing it
1884
  @type prewrite: callable
1885
  @param prewrite: function to be called before writing content
1886
  @type postwrite: callable
1887
  @param postwrite: function to be called after writing content
1888

1889
  @rtype: None or int
1890
  @return: None if the 'close' parameter evaluates to True,
1891
      otherwise the file descriptor
1892

1893
  @raise errors.ProgrammerError: if any of the arguments are not valid
1894

1895
  """
1896
  if not os.path.isabs(file_name):
1897
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1898
                                 " absolute: '%s'" % file_name)
1899

    
1900
  if [fn, data].count(None) != 1:
1901
    raise errors.ProgrammerError("fn or data required")
1902

    
1903
  if [atime, mtime].count(None) == 1:
1904
    raise errors.ProgrammerError("Both atime and mtime must be either"
1905
                                 " set or None")
1906

    
1907
  if backup and not dry_run and os.path.isfile(file_name):
1908
    CreateBackup(file_name)
1909

    
1910
  dir_name, base_name = os.path.split(file_name)
1911
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1912
  do_remove = True
1913
  # here we need to make sure we remove the temp file, if any error
1914
  # leaves it in place
1915
  try:
1916
    if uid != -1 or gid != -1:
1917
      os.chown(new_name, uid, gid)
1918
    if mode:
1919
      os.chmod(new_name, mode)
1920
    if callable(prewrite):
1921
      prewrite(fd)
1922
    if data is not None:
1923
      os.write(fd, data)
1924
    else:
1925
      fn(fd)
1926
    if callable(postwrite):
1927
      postwrite(fd)
1928
    os.fsync(fd)
1929
    if atime is not None and mtime is not None:
1930
      os.utime(new_name, (atime, mtime))
1931
    if not dry_run:
1932
      os.rename(new_name, file_name)
1933
      do_remove = False
1934
  finally:
1935
    if close:
1936
      os.close(fd)
1937
      result = None
1938
    else:
1939
      result = fd
1940
    if do_remove:
1941
      RemoveFile(new_name)
1942

    
1943
  return result
1944

    
1945

    
1946
def ReadOneLineFile(file_name, strict=False):
1947
  """Return the first non-empty line from a file.
1948

1949
  @type strict: boolean
1950
  @param strict: if True, abort if the file has more than one
1951
      non-empty line
1952

1953
  """
1954
  file_lines = ReadFile(file_name).splitlines()
1955
  full_lines = filter(bool, file_lines)
1956
  if not file_lines or not full_lines:
1957
    raise errors.GenericError("No data in one-liner file %s" % file_name)
1958
  elif strict and len(full_lines) > 1:
1959
    raise errors.GenericError("Too many lines in one-liner file %s" %
1960
                              file_name)
1961
  return full_lines[0]
1962

    
1963

    
1964
def FirstFree(seq, base=0):
1965
  """Returns the first non-existing integer from seq.
1966

1967
  The seq argument should be a sorted list of positive integers. The
1968
  first time the index of an element is smaller than the element
1969
  value, the index will be returned.
1970

1971
  The base argument is used to start at a different offset,
1972
  i.e. C{[3, 4, 6]} with I{offset=3} will return 5.
1973

1974
  Example: C{[0, 1, 3]} will return I{2}.
1975

1976
  @type seq: sequence
1977
  @param seq: the sequence to be analyzed.
1978
  @type base: int
1979
  @param base: use this value as the base index of the sequence
1980
  @rtype: int
1981
  @return: the first non-used index in the sequence
1982

1983
  """
1984
  for idx, elem in enumerate(seq):
1985
    assert elem >= base, "Passed element is higher than base offset"
1986
    if elem > idx + base:
1987
      # idx is not used
1988
      return idx + base
1989
  return None
1990

    
1991

    
1992
def SingleWaitForFdCondition(fdobj, event, timeout):
1993
  """Waits for a condition to occur on the socket.
1994

1995
  Immediately returns at the first interruption.
1996

1997
  @type fdobj: integer or object supporting a fileno() method
1998
  @param fdobj: entity to wait for events on
1999
  @type event: integer
2000
  @param event: ORed condition (see select module)
2001
  @type timeout: float or None
2002
  @param timeout: Timeout in seconds
2003
  @rtype: int or None
2004
  @return: None for timeout, otherwise occured conditions
2005

2006
  """
2007
  check = (event | select.POLLPRI |
2008
           select.POLLNVAL | select.POLLHUP | select.POLLERR)
2009

    
2010
  if timeout is not None:
2011
    # Poller object expects milliseconds
2012
    timeout *= 1000
2013

    
2014
  poller = select.poll()
2015
  poller.register(fdobj, event)
2016
  try:
2017
    # TODO: If the main thread receives a signal and we have no timeout, we
2018
    # could wait forever. This should check a global "quit" flag or something
2019
    # every so often.
2020
    io_events = poller.poll(timeout)
2021
  except select.error, err:
2022
    if err[0] != errno.EINTR:
2023
      raise
2024
    io_events = []
2025
  if io_events and io_events[0][1] & check:
2026
    return io_events[0][1]
2027
  else:
2028
    return None
2029

    
2030

    
2031
class FdConditionWaiterHelper(object):
2032
  """Retry helper for WaitForFdCondition.
2033

2034
  This class contains the retried and wait functions that make sure
2035
  WaitForFdCondition can continue waiting until the timeout is actually
2036
  expired.
2037

2038
  """
2039

    
2040
  def __init__(self, timeout):
2041
    self.timeout = timeout
2042

    
2043
  def Poll(self, fdobj, event):
2044
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
2045
    if result is None:
2046
      raise RetryAgain()
2047
    else:
2048
      return result
2049

    
2050
  def UpdateTimeout(self, timeout):
2051
    self.timeout = timeout
2052

    
2053

    
2054
def WaitForFdCondition(fdobj, event, timeout):
2055
  """Waits for a condition to occur on the socket.
2056

2057
  Retries until the timeout is expired, even if interrupted.
2058

2059
  @type fdobj: integer or object supporting a fileno() method
2060
  @param fdobj: entity to wait for events on
2061
  @type event: integer
2062
  @param event: ORed condition (see select module)
2063
  @type timeout: float or None
2064
  @param timeout: Timeout in seconds
2065
  @rtype: int or None
2066
  @return: None for timeout, otherwise occured conditions
2067

2068
  """
2069
  if timeout is not None:
2070
    retrywaiter = FdConditionWaiterHelper(timeout)
2071
    try:
2072
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
2073
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
2074
    except RetryTimeout:
2075
      result = None
2076
  else:
2077
    result = None
2078
    while result is None:
2079
      result = SingleWaitForFdCondition(fdobj, event, timeout)
2080
  return result
2081

    
2082

    
2083
def UniqueSequence(seq):
2084
  """Returns a list with unique elements.
2085

2086
  Element order is preserved.
2087

2088
  @type seq: sequence
2089
  @param seq: the sequence with the source elements
2090
  @rtype: list
2091
  @return: list of unique elements from seq
2092

2093
  """
2094
  seen = set()
2095
  return [i for i in seq if i not in seen and not seen.add(i)]
2096

    
2097

    
2098
def NormalizeAndValidateMac(mac):
2099
  """Normalizes and check if a MAC address is valid.
2100

2101
  Checks whether the supplied MAC address is formally correct, only
2102
  accepts colon separated format. Normalize it to all lower.
2103

2104
  @type mac: str
2105
  @param mac: the MAC to be validated
2106
  @rtype: str
2107
  @return: returns the normalized and validated MAC.
2108

2109
  @raise errors.OpPrereqError: If the MAC isn't valid
2110

2111
  """
2112
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
2113
  if not mac_check.match(mac):
2114
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
2115
                               mac, errors.ECODE_INVAL)
2116

    
2117
  return mac.lower()
2118

    
2119

    
2120
def TestDelay(duration):
2121
  """Sleep for a fixed amount of time.
2122

2123
  @type duration: float
2124
  @param duration: the sleep duration
2125
  @rtype: boolean
2126
  @return: False for negative value, True otherwise
2127

2128
  """
2129
  if duration < 0:
2130
    return False, "Invalid sleep duration"
2131
  time.sleep(duration)
2132
  return True, None
2133

    
2134

    
2135
def _CloseFDNoErr(fd, retries=5):
2136
  """Close a file descriptor ignoring errors.
2137

2138
  @type fd: int
2139
  @param fd: the file descriptor
2140
  @type retries: int
2141
  @param retries: how many retries to make, in case we get any
2142
      other error than EBADF
2143

2144
  """
2145
  try:
2146
    os.close(fd)
2147
  except OSError, err:
2148
    if err.errno != errno.EBADF:
2149
      if retries > 0:
2150
        _CloseFDNoErr(fd, retries - 1)
2151
    # else either it's closed already or we're out of retries, so we
2152
    # ignore this and go on
2153

    
2154

    
2155
def CloseFDs(noclose_fds=None):
2156
  """Close file descriptors.
2157

2158
  This closes all file descriptors above 2 (i.e. except
2159
  stdin/out/err).
2160

2161
  @type noclose_fds: list or None
2162
  @param noclose_fds: if given, it denotes a list of file descriptor
2163
      that should not be closed
2164

2165
  """
2166
  # Default maximum for the number of available file descriptors.
2167
  if 'SC_OPEN_MAX' in os.sysconf_names:
2168
    try:
2169
      MAXFD = os.sysconf('SC_OPEN_MAX')
2170
      if MAXFD < 0:
2171
        MAXFD = 1024
2172
    except OSError:
2173
      MAXFD = 1024
2174
  else:
2175
    MAXFD = 1024
2176
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
2177
  if (maxfd == resource.RLIM_INFINITY):
2178
    maxfd = MAXFD
2179

    
2180
  # Iterate through and close all file descriptors (except the standard ones)
2181
  for fd in range(3, maxfd):
2182
    if noclose_fds and fd in noclose_fds:
2183
      continue
2184
    _CloseFDNoErr(fd)
2185

    
2186

    
2187
def Mlockall():
2188
  """Lock current process' virtual address space into RAM.
2189

2190
  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
2191
  see mlock(2) for more details. This function requires ctypes module.
2192

2193
  """
2194
  if ctypes is None:
2195
    logging.warning("Cannot set memory lock, ctypes module not found")
2196
    return
2197

    
2198
  libc = ctypes.cdll.LoadLibrary("libc.so.6")
2199
  if libc is None:
2200
    logging.error("Cannot set memory lock, ctypes cannot load libc")
2201
    return
2202

    
2203
  # Some older version of the ctypes module don't have built-in functionality
2204
  # to access the errno global variable, where function error codes are stored.
2205
  # By declaring this variable as a pointer to an integer we can then access
2206
  # its value correctly, should the mlockall call fail, in order to see what
2207
  # the actual error code was.
2208
  # pylint: disable-msg=W0212
2209
  libc.__errno_location.restype = ctypes.POINTER(ctypes.c_int)
2210

    
2211
  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
2212
    # pylint: disable-msg=W0212
2213
    logging.error("Cannot set memory lock: %s",
2214
                  os.strerror(libc.__errno_location().contents.value))
2215
    return
2216

    
2217
  logging.debug("Memory lock set")
2218

    
2219

    
2220
def Daemonize(logfile, run_uid, run_gid):
2221
  """Daemonize the current process.
2222

2223
  This detaches the current process from the controlling terminal and
2224
  runs it in the background as a daemon.
2225

2226
  @type logfile: str
2227
  @param logfile: the logfile to which we should redirect stdout/stderr
2228
  @type run_uid: int
2229
  @param run_uid: Run the child under this uid
2230
  @type run_gid: int
2231
  @param run_gid: Run the child under this gid
2232
  @rtype: int
2233
  @return: the value zero
2234

2235
  """
2236
  # pylint: disable-msg=W0212
2237
  # yes, we really want os._exit
2238
  UMASK = 077
2239
  WORKDIR = "/"
2240

    
2241
  # this might fail
2242
  pid = os.fork()
2243
  if (pid == 0):  # The first child.
2244
    os.setsid()
2245
    # FIXME: When removing again and moving to start-stop-daemon privilege drop
2246
    #        make sure to check for config permission and bail out when invoked
2247
    #        with wrong user.
2248
    os.setgid(run_gid)
2249
    os.setuid(run_uid)
2250
    # this might fail
2251
    pid = os.fork() # Fork a second child.
2252
    if (pid == 0):  # The second child.
2253
      os.chdir(WORKDIR)
2254
      os.umask(UMASK)
2255
    else:
2256
      # exit() or _exit()?  See below.
2257
      os._exit(0) # Exit parent (the first child) of the second child.
2258
  else:
2259
    os._exit(0) # Exit parent of the first child.
2260

    
2261
  for fd in range(3):
2262
    _CloseFDNoErr(fd)
2263
  i = os.open("/dev/null", os.O_RDONLY) # stdin
2264
  assert i == 0, "Can't close/reopen stdin"
2265
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
2266
  assert i == 1, "Can't close/reopen stdout"
2267
  # Duplicate standard output to standard error.
2268
  os.dup2(1, 2)
2269
  return 0
2270

    
2271

    
2272
def DaemonPidFileName(name):
2273
  """Compute a ganeti pid file absolute path
2274

2275
  @type name: str
2276
  @param name: the daemon name
2277
  @rtype: str
2278
  @return: the full path to the pidfile corresponding to the given
2279
      daemon name
2280

2281
  """
2282
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
2283

    
2284

    
2285
def EnsureDaemon(name):
2286
  """Check for and start daemon if not alive.
2287

2288
  """
2289
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
2290
  if result.failed:
2291
    logging.error("Can't start daemon '%s', failure %s, output: %s",
2292
                  name, result.fail_reason, result.output)
2293
    return False
2294

    
2295
  return True
2296

    
2297

    
2298
def StopDaemon(name):
2299
  """Stop daemon
2300

2301
  """
2302
  result = RunCmd([constants.DAEMON_UTIL, "stop", name])
2303
  if result.failed:
2304
    logging.error("Can't stop daemon '%s', failure %s, output: %s",
2305
                  name, result.fail_reason, result.output)
2306
    return False
2307

    
2308
  return True
2309

    
2310

    
2311
def WritePidFile(name):
2312
  """Write the current process pidfile.
2313

2314
  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}
2315

2316
  @type name: str
2317
  @param name: the daemon name to use
2318
  @raise errors.GenericError: if the pid file already exists and
2319
      points to a live process
2320

2321
  """
2322
  pid = os.getpid()
2323
  pidfilename = DaemonPidFileName(name)
2324
  if IsProcessAlive(ReadPidFile(pidfilename)):
2325
    raise errors.GenericError("%s contains a live process" % pidfilename)
2326

    
2327
  WriteFile(pidfilename, data="%d\n" % pid)
2328

    
2329

    
2330
def RemovePidFile(name):
2331
  """Remove the current process pidfile.
2332

2333
  Any errors are ignored.
2334

2335
  @type name: str
2336
  @param name: the daemon name used to derive the pidfile name
2337

2338
  """
2339
  pidfilename = DaemonPidFileName(name)
2340
  # TODO: we could check here that the file contains our pid
2341
  try:
2342
    RemoveFile(pidfilename)
2343
  except: # pylint: disable-msg=W0702
2344
    pass
2345

    
2346

    
2347
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
2348
                waitpid=False):
2349
  """Kill a process given by its pid.
2350

2351
  @type pid: int
2352
  @param pid: The PID to terminate.
2353
  @type signal_: int
2354
  @param signal_: The signal to send, by default SIGTERM
2355
  @type timeout: int
2356
  @param timeout: The timeout after which, if the process is still alive,
2357
                  a SIGKILL will be sent. If not positive, no such checking
2358
                  will be done
2359
  @type waitpid: boolean
2360
  @param waitpid: If true, we should waitpid on this process after
2361
      sending signals, since it's our own child and otherwise it
2362
      would remain as zombie
2363

2364
  """
2365
  def _helper(pid, signal_, wait):
2366
    """Simple helper to encapsulate the kill/waitpid sequence"""
2367
    if IgnoreProcessNotFound(os.kill, pid, signal_) and wait:
2368
      try:
2369
        os.waitpid(pid, os.WNOHANG)
2370
      except OSError:
2371
        pass
2372

    
2373
  if pid <= 0:
2374
    # kill with pid=0 == suicide
2375
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
2376

    
2377
  if not IsProcessAlive(pid):
2378
    return
2379

    
2380
  _helper(pid, signal_, waitpid)
2381

    
2382
  if timeout <= 0:
2383
    return
2384

    
2385
  def _CheckProcess():
2386
    if not IsProcessAlive(pid):
2387
      return
2388

    
2389
    try:
2390
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
2391
    except OSError:
2392
      raise RetryAgain()
2393

    
2394
    if result_pid > 0:
2395
      return
2396

    
2397
    raise RetryAgain()
2398

    
2399
  try:
2400
    # Wait up to $timeout seconds
2401
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
2402
  except RetryTimeout:
2403
    pass
2404

    
2405
  if IsProcessAlive(pid):
2406
    # Kill process if it's still alive
2407
    _helper(pid, signal.SIGKILL, waitpid)
2408

    
2409

    
2410
def FindFile(name, search_path, test=os.path.exists):
2411
  """Look for a filesystem object in a given path.
2412

2413
  This is an abstract method to search for filesystem object (files,
2414
  dirs) under a given search path.
2415

2416
  @type name: str
2417
  @param name: the name to look for
2418
  @type search_path: str
2419
  @param search_path: location to start at
2420
  @type test: callable
2421
  @param test: a function taking one argument that should return True
2422
      if the a given object is valid; the default value is
2423
      os.path.exists, causing only existing files to be returned
2424
  @rtype: str or None
2425
  @return: full path to the object if found, None otherwise
2426

2427
  """
2428
  # validate the filename mask
2429
  if constants.EXT_PLUGIN_MASK.match(name) is None:
2430
    logging.critical("Invalid value passed for external script name: '%s'",
2431
                     name)
2432
    return None
2433

    
2434
  for dir_name in search_path:
2435
    # FIXME: investigate switch to PathJoin
2436
    item_name = os.path.sep.join([dir_name, name])
2437
    # check the user test and that we're indeed resolving to the given
2438
    # basename
2439
    if test(item_name) and os.path.basename(item_name) == name:
2440
      return item_name
2441
  return None
2442

    
2443

    
2444
def CheckVolumeGroupSize(vglist, vgname, minsize):
2445
  """Checks if the volume group list is valid.
2446

2447
  The function will check if a given volume group is in the list of
2448
  volume groups and has a minimum size.
2449

2450
  @type vglist: dict
2451
  @param vglist: dictionary of volume group names and their size
2452
  @type vgname: str
2453
  @param vgname: the volume group we should check
2454
  @type minsize: int
2455
  @param minsize: the minimum size we accept
2456
  @rtype: None or str
2457
  @return: None for success, otherwise the error message
2458

2459
  """
2460
  vgsize = vglist.get(vgname, None)
2461
  if vgsize is None:
2462
    return "volume group '%s' missing" % vgname
2463
  elif vgsize < minsize:
2464
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
2465
            (vgname, minsize, vgsize))
2466
  return None
2467

    
2468

    
2469
def SplitTime(value):
2470
  """Splits time as floating point number into a tuple.
2471

2472
  @param value: Time in seconds
2473
  @type value: int or float
2474
  @return: Tuple containing (seconds, microseconds)
2475

2476
  """
2477
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
2478

    
2479
  assert 0 <= seconds, \
2480
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2481
  assert 0 <= microseconds <= 999999, \
2482
    "Microseconds must be 0-999999, but are %s" % microseconds
2483

    
2484
  return (int(seconds), int(microseconds))
2485

    
2486

    
2487
def MergeTime(timetuple):
2488
  """Merges a tuple into time as a floating point number.
2489

2490
  @param timetuple: Time as tuple, (seconds, microseconds)
2491
  @type timetuple: tuple
2492
  @return: Time as a floating point number expressed in seconds
2493

2494
  """
2495
  (seconds, microseconds) = timetuple
2496

    
2497
  assert 0 <= seconds, \
2498
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2499
  assert 0 <= microseconds <= 999999, \
2500
    "Microseconds must be 0-999999, but are %s" % microseconds
2501

    
2502
  return float(seconds) + (float(microseconds) * 0.000001)
2503

    
2504

    
2505
def GetDaemonPort(daemon_name):
2506
  """Get the daemon port for this cluster.
2507

2508
  Note that this routine does not read a ganeti-specific file, but
2509
  instead uses C{socket.getservbyname} to allow pre-customization of
2510
  this parameter outside of Ganeti.
2511

2512
  @type daemon_name: string
2513
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
2514
  @rtype: int
2515

2516
  """
2517
  if daemon_name not in constants.DAEMONS_PORTS:
2518
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)
2519

    
2520
  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
2521
  try:
2522
    port = socket.getservbyname(daemon_name, proto)
2523
  except socket.error:
2524
    port = default_port
2525

    
2526
  return port
2527

    
2528

    
2529
class LogFileHandler(logging.FileHandler):
2530
  """Log handler that doesn't fallback to stderr.
2531

2532
  When an error occurs while writing on the logfile, logging.FileHandler tries
2533
  to log on stderr. This doesn't work in ganeti since stderr is redirected to
2534
  the logfile. This class avoids failures reporting errors to /dev/console.
2535

2536
  """
2537
  def __init__(self, filename, mode="a", encoding=None):
2538
    """Open the specified file and use it as the stream for logging.
2539

2540
    Also open /dev/console to report errors while logging.
2541

2542
    """
2543
    logging.FileHandler.__init__(self, filename, mode, encoding)
2544
    self.console = open(constants.DEV_CONSOLE, "a")
2545

    
2546
  def handleError(self, record): # pylint: disable-msg=C0103
2547
    """Handle errors which occur during an emit() call.
2548

2549
    Try to handle errors with FileHandler method, if it fails write to
2550
    /dev/console.
2551

2552
    """
2553
    try:
2554
      logging.FileHandler.handleError(self, record)
2555
    except Exception: # pylint: disable-msg=W0703
2556
      try:
2557
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
2558
      except Exception: # pylint: disable-msg=W0703
2559
        # Log handler tried everything it could, now just give up
2560
        pass
2561

    
2562

    
2563
def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
2564
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
2565
                 console_logging=False):
2566
  """Configures the logging module.
2567

2568
  @type logfile: str
2569
  @param logfile: the filename to which we should log
2570
  @type debug: integer
2571
  @param debug: if greater than zero, enable debug messages, otherwise
2572
      only those at C{INFO} and above level
2573
  @type stderr_logging: boolean
2574
  @param stderr_logging: whether we should also log to the standard error
2575
  @type program: str
2576
  @param program: the name under which we should log messages
2577
  @type multithreaded: boolean
2578
  @param multithreaded: if True, will add the thread name to the log file
2579
  @type syslog: string
2580
  @param syslog: one of 'no', 'yes', 'only':
2581
      - if no, syslog is not used
2582
      - if yes, syslog is used (in addition to file-logging)
2583
      - if only, only syslog is used
2584
  @type console_logging: boolean
2585
  @param console_logging: if True, will use a FileHandler which falls back to
2586
      the system console if logging fails
2587
  @raise EnvironmentError: if we can't open the log file and
2588
      syslog/stderr logging is disabled
2589

2590
  """
2591
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
2592
  sft = program + "[%(process)d]:"
2593
  if multithreaded:
2594
    fmt += "/%(threadName)s"
2595
    sft += " (%(threadName)s)"
2596
  if debug:
2597
    fmt += " %(module)s:%(lineno)s"
2598
    # no debug info for syslog loggers
2599
  fmt += " %(levelname)s %(message)s"
2600
  # yes, we do want the textual level, as remote syslog will probably
2601
  # lose the error level, and it's easier to grep for it
2602
  sft += " %(levelname)s %(message)s"
2603
  formatter = logging.Formatter(fmt)
2604
  sys_fmt = logging.Formatter(sft)
2605

    
2606
  root_logger = logging.getLogger("")
2607
  root_logger.setLevel(logging.NOTSET)
2608

    
2609
  # Remove all previously setup handlers
2610
  for handler in root_logger.handlers:
2611
    handler.close()
2612
    root_logger.removeHandler(handler)
2613

    
2614
  if stderr_logging:
2615
    stderr_handler = logging.StreamHandler()
2616
    stderr_handler.setFormatter(formatter)
2617
    if debug:
2618
      stderr_handler.setLevel(logging.NOTSET)
2619
    else:
2620
      stderr_handler.setLevel(logging.CRITICAL)
2621
    root_logger.addHandler(stderr_handler)
2622

    
2623
  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
2624
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
2625
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
2626
                                                    facility)
2627
    syslog_handler.setFormatter(sys_fmt)
2628
    # Never enable debug over syslog
2629
    syslog_handler.setLevel(logging.INFO)
2630
    root_logger.addHandler(syslog_handler)
2631

    
2632
  if syslog != constants.SYSLOG_ONLY:
2633
    # this can fail, if the logging directories are not setup or we have
2634
    # a permisssion problem; in this case, it's best to log but ignore
2635
    # the error if stderr_logging is True, and if false we re-raise the
2636
    # exception since otherwise we could run but without any logs at all
2637
    try:
2638
      if console_logging:
2639
        logfile_handler = LogFileHandler(logfile)
2640
      else:
2641
        logfile_handler = logging.FileHandler(logfile)
2642
      logfile_handler.setFormatter(formatter)
2643
      if debug:
2644
        logfile_handler.setLevel(logging.DEBUG)
2645
      else:
2646
        logfile_handler.setLevel(logging.INFO)
2647
      root_logger.addHandler(logfile_handler)
2648
    except EnvironmentError:
2649
      if stderr_logging or syslog == constants.SYSLOG_YES:
2650
        logging.exception("Failed to enable logging to file '%s'", logfile)
2651
      else:
2652
        # we need to re-raise the exception
2653
        raise
2654

    
2655

    
2656
def IsNormAbsPath(path):
2657
  """Check whether a path is absolute and also normalized
2658

2659
  This avoids things like /dir/../../other/path to be valid.
2660

2661
  """
2662
  return os.path.normpath(path) == path and os.path.isabs(path)
2663

    
2664

    
2665
def PathJoin(*args):
2666
  """Safe-join a list of path components.
2667

2668
  Requirements:
2669
      - the first argument must be an absolute path
2670
      - no component in the path must have backtracking (e.g. /../),
2671
        since we check for normalization at the end
2672

2673
  @param args: the path components to be joined
2674
  @raise ValueError: for invalid paths
2675

2676
  """
2677
  # ensure we're having at least one path passed in
2678
  assert args
2679
  # ensure the first component is an absolute and normalized path name
2680
  root = args[0]
2681
  if not IsNormAbsPath(root):
2682
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
2683
  result = os.path.join(*args)
2684
  # ensure that the whole path is normalized
2685
  if not IsNormAbsPath(result):
2686
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
2687
  # check that we're still under the original prefix
2688
  prefix = os.path.commonprefix([root, result])
2689
  if prefix != root:
2690
    raise ValueError("Error: path joining resulted in different prefix"
2691
                     " (%s != %s)" % (prefix, root))
2692
  return result
2693

    
2694

    
2695
def TailFile(fname, lines=20):
2696
  """Return the last lines from a file.
2697

2698
  @note: this function will only read and parse the last 4KB of
2699
      the file; if the lines are very long, it could be that less
2700
      than the requested number of lines are returned
2701

2702
  @param fname: the file name
2703
  @type lines: int
2704
  @param lines: the (maximum) number of lines to return
2705

2706
  """
2707
  fd = open(fname, "r")
2708
  try:
2709
    fd.seek(0, 2)
2710
    pos = fd.tell()
2711
    pos = max(0, pos-4096)
2712
    fd.seek(pos, 0)
2713
    raw_data = fd.read()
2714
  finally:
2715
    fd.close()
2716

    
2717
  rows = raw_data.splitlines()
2718
  return rows[-lines:]
2719

    
2720

    
2721
def FormatTimestampWithTZ(secs):
2722
  """Formats a Unix timestamp with the local timezone.
2723

2724
  """
2725
  return time.strftime("%F %T %Z", time.gmtime(secs))
2726

    
2727

    
2728
def _ParseAsn1Generalizedtime(value):
2729
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
2730

2731
  @type value: string
2732
  @param value: ASN1 GENERALIZEDTIME timestamp
2733

2734
  """
2735
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
2736
  if m:
2737
    # We have an offset
2738
    asn1time = m.group(1)
2739
    hours = int(m.group(2))
2740
    minutes = int(m.group(3))
2741
    utcoffset = (60 * hours) + minutes
2742
  else:
2743
    if not value.endswith("Z"):
2744
      raise ValueError("Missing timezone")
2745
    asn1time = value[:-1]
2746
    utcoffset = 0
2747

    
2748
  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
2749

    
2750
  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)
2751

    
2752
  return calendar.timegm(tt.utctimetuple())
2753

    
2754

    
2755
def GetX509CertValidity(cert):
2756
  """Returns the validity period of the certificate.
2757

2758
  @type cert: OpenSSL.crypto.X509
2759
  @param cert: X509 certificate object
2760

2761
  """
2762
  # The get_notBefore and get_notAfter functions are only supported in
2763
  # pyOpenSSL 0.7 and above.
2764
  try:
2765
    get_notbefore_fn = cert.get_notBefore
2766
  except AttributeError:
2767
    not_before = None
2768
  else:
2769
    not_before_asn1 = get_notbefore_fn()
2770

    
2771
    if not_before_asn1 is None:
2772
      not_before = None
2773
    else:
2774
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)
2775

    
2776
  try:
2777
    get_notafter_fn = cert.get_notAfter
2778
  except AttributeError:
2779
    not_after = None
2780
  else:
2781
    not_after_asn1 = get_notafter_fn()
2782

    
2783
    if not_after_asn1 is None:
2784
      not_after = None
2785
    else:
2786
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)
2787

    
2788
  return (not_before, not_after)
2789

    
2790

    
2791
def _VerifyCertificateInner(expired, not_before, not_after, now,
2792
                            warn_days, error_days):
2793
  """Verifies certificate validity.
2794

2795
  @type expired: bool
2796
  @param expired: Whether pyOpenSSL considers the certificate as expired
2797
  @type not_before: number or None
2798
  @param not_before: Unix timestamp before which certificate is not valid
2799
  @type not_after: number or None
2800
  @param not_after: Unix timestamp after which certificate is invalid
2801
  @type now: number
2802
  @param now: Current time as Unix timestamp
2803
  @type warn_days: number or None
2804
  @param warn_days: How many days before expiration a warning should be reported
2805
  @type error_days: number or None
2806
  @param error_days: How many days before expiration an error should be reported
2807

2808
  """
2809
  if expired:
2810
    msg = "Certificate is expired"
2811

    
2812
    if not_before is not None and not_after is not None:
2813
      msg += (" (valid from %s to %s)" %
2814
              (FormatTimestampWithTZ(not_before),
2815
               FormatTimestampWithTZ(not_after)))
2816
    elif not_before is not None:
2817
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
2818
    elif not_after is not None:
2819
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)
2820

    
2821
    return (CERT_ERROR, msg)
2822

    
2823
  elif not_before is not None and not_before > now:
2824
    return (CERT_WARNING,
2825
            "Certificate not yet valid (valid from %s)" %
2826
            FormatTimestampWithTZ(not_before))
2827

    
2828
  elif not_after is not None:
2829
    remaining_days = int((not_after - now) / (24 * 3600))
2830

    
2831
    msg = "Certificate expires in about %d days" % remaining_days
2832

    
2833
    if error_days is not None and remaining_days <= error_days:
2834
      return (CERT_ERROR, msg)
2835

    
2836
    if warn_days is not None and remaining_days <= warn_days:
2837
      return (CERT_WARNING, msg)
2838

    
2839
  return (None, None)
2840

    
2841

    
2842
def VerifyX509Certificate(cert, warn_days, error_days):
2843
  """Verifies a certificate for LUVerifyCluster.
2844

2845
  @type cert: OpenSSL.crypto.X509
2846
  @param cert: X509 certificate object
2847
  @type warn_days: number or None
2848
  @param warn_days: How many days before expiration a warning should be reported
2849
  @type error_days: number or None
2850
  @param error_days: How many days before expiration an error should be reported
2851

2852
  """
2853
  # Depending on the pyOpenSSL version, this can just return (None, None)
2854
  (not_before, not_after) = GetX509CertValidity(cert)
2855

    
2856
  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
2857
                                 time.time(), warn_days, error_days)
2858

    
2859

    
2860
def SignX509Certificate(cert, key, salt):
2861
  """Sign a X509 certificate.
2862

2863
  An RFC822-like signature header is added in front of the certificate.
2864

2865
  @type cert: OpenSSL.crypto.X509
2866
  @param cert: X509 certificate object
2867
  @type key: string
2868
  @param key: Key for HMAC
2869
  @type salt: string
2870
  @param salt: Salt for HMAC
2871
  @rtype: string
2872
  @return: Serialized and signed certificate in PEM format
2873

2874
  """
2875
  if not VALID_X509_SIGNATURE_SALT.match(salt):
2876
    raise errors.GenericError("Invalid salt: %r" % salt)
2877

    
2878
  # Dumping as PEM here ensures the certificate is in a sane format
2879
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2880

    
2881
  return ("%s: %s/%s\n\n%s" %
2882
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
2883
           Sha1Hmac(key, cert_pem, salt=salt),
2884
           cert_pem))
2885

    
2886

    
2887
def _ExtractX509CertificateSignature(cert_pem):
2888
  """Helper function to extract signature from X509 certificate.
2889

2890
  """
2891
  # Extract signature from original PEM data
2892
  for line in cert_pem.splitlines():
2893
    if line.startswith("---"):
2894
      break
2895

    
2896
    m = X509_SIGNATURE.match(line.strip())
2897
    if m:
2898
      return (m.group("salt"), m.group("sign"))
2899

    
2900
  raise errors.GenericError("X509 certificate signature is missing")
2901

    
2902

    
2903
def LoadSignedX509Certificate(cert_pem, key):
2904
  """Verifies a signed X509 certificate.
2905

2906
  @type cert_pem: string
2907
  @param cert_pem: Certificate in PEM format and with signature header
2908
  @type key: string
2909
  @param key: Key for HMAC
2910
  @rtype: tuple; (OpenSSL.crypto.X509, string)
2911
  @return: X509 certificate object and salt
2912

2913
  """
2914
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)
2915

    
2916
  # Load certificate
2917
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
2918

    
2919
  # Dump again to ensure it's in a sane format
2920
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2921

    
2922
  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
2923
    raise errors.GenericError("X509 certificate signature is invalid")
2924

    
2925
  return (cert, salt)
2926

    
2927

    
2928
def Sha1Hmac(key, text, salt=None):
2929
  """Calculates the HMAC-SHA1 digest of a text.
2930

2931
  HMAC is defined in RFC2104.
2932

2933
  @type key: string
2934
  @param key: Secret key
2935
  @type text: string
2936

2937
  """
2938
  if salt:
2939
    salted_text = salt + text
2940
  else:
2941
    salted_text = text
2942

    
2943
  return hmac.new(key, salted_text, compat.sha1).hexdigest()
2944

    
2945

    
2946
def VerifySha1Hmac(key, text, digest, salt=None):
2947
  """Verifies the HMAC-SHA1 digest of a text.
2948

2949
  HMAC is defined in RFC2104.
2950

2951
  @type key: string
2952
  @param key: Secret key
2953
  @type text: string
2954
  @type digest: string
2955
  @param digest: Expected digest
2956
  @rtype: bool
2957
  @return: Whether HMAC-SHA1 digest matches
2958

2959
  """
2960
  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()
2961

    
2962

    
2963
def SafeEncode(text):
2964
  """Return a 'safe' version of a source string.
2965

2966
  This function mangles the input string and returns a version that
2967
  should be safe to display/encode as ASCII. To this end, we first
2968
  convert it to ASCII using the 'backslashreplace' encoding which
2969
  should get rid of any non-ASCII chars, and then we process it
2970
  through a loop copied from the string repr sources in the python; we
2971
  don't use string_escape anymore since that escape single quotes and
2972
  backslashes too, and that is too much; and that escaping is not
2973
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
2974

2975
  @type text: str or unicode
2976
  @param text: input data
2977
  @rtype: str
2978
  @return: a safe version of text
2979

2980
  """
2981
  if isinstance(text, unicode):
2982
    # only if unicode; if str already, we handle it below
2983
    text = text.encode('ascii', 'backslashreplace')
2984
  resu = ""
2985
  for char in text:
2986
    c = ord(char)
2987
    if char  == '\t':
2988
      resu += r'\t'
2989
    elif char == '\n':
2990
      resu += r'\n'
2991
    elif char == '\r':
2992
      resu += r'\'r'
2993
    elif c < 32 or c >= 127: # non-printable
2994
      resu += "\\x%02x" % (c & 0xff)
2995
    else:
2996
      resu += char
2997
  return resu
2998

    
2999

    
3000
def UnescapeAndSplit(text, sep=","):
3001
  """Split and unescape a string based on a given separator.
3002

3003
  This function splits a string based on a separator where the
3004
  separator itself can be escape in order to be an element of the
3005
  elements. The escaping rules are (assuming coma being the
3006
  separator):
3007
    - a plain , separates the elements
3008
    - a sequence \\\\, (double backslash plus comma) is handled as a
3009
      backslash plus a separator comma
3010
    - a sequence \, (backslash plus comma) is handled as a
3011
      non-separator comma
3012

3013
  @type text: string
3014
  @param text: the string to split
3015
  @type sep: string
3016
  @param text: the separator
3017
  @rtype: string
3018
  @return: a list of strings
3019

3020
  """
3021
  # we split the list by sep (with no escaping at this stage)
3022
  slist = text.split(sep)
3023
  # next, we revisit the elements and if any of them ended with an odd
3024
  # number of backslashes, then we join it with the next
3025
  rlist = []
3026
  while slist:
3027
    e1 = slist.pop(0)
3028
    if e1.endswith("\\"):
3029
      num_b = len(e1) - len(e1.rstrip("\\"))
3030
      if num_b % 2 == 1:
3031
        e2 = slist.pop(0)
3032
        # here the backslashes remain (all), and will be reduced in
3033
        # the next step
3034
        rlist.append(e1 + sep + e2)
3035
        continue
3036
    rlist.append(e1)
3037
  # finally, replace backslash-something with something
3038
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
3039
  return rlist
3040

    
3041

    
3042
def CommaJoin(names):
3043
  """Nicely join a set of identifiers.
3044

3045
  @param names: set, list or tuple
3046
  @return: a string with the formatted results
3047

3048
  """
3049
  return ", ".join([str(val) for val in names])
3050

    
3051

    
3052
def BytesToMebibyte(value):
3053
  """Converts bytes to mebibytes.
3054

3055
  @type value: int
3056
  @param value: Value in bytes
3057
  @rtype: int
3058
  @return: Value in mebibytes
3059

3060
  """
3061
  return int(round(value / (1024.0 * 1024.0), 0))
3062

    
3063

    
3064
def CalculateDirectorySize(path):
3065
  """Calculates the size of a directory recursively.
3066

3067
  @type path: string
3068
  @param path: Path to directory
3069
  @rtype: int
3070
  @return: Size in mebibytes
3071

3072
  """
3073
  size = 0
3074

    
3075
  for (curpath, _, files) in os.walk(path):
3076
    for filename in files:
3077
      st = os.lstat(PathJoin(curpath, filename))
3078
      size += st.st_size
3079

    
3080
  return BytesToMebibyte(size)
3081

    
3082

    
3083
def GetFilesystemStats(path):
3084
  """Returns the total and free space on a filesystem.
3085

3086
  @type path: string
3087
  @param path: Path on filesystem to be examined
3088
  @rtype: int
3089
  @return: tuple of (Total space, Free space) in mebibytes
3090

3091
  """
3092
  st = os.statvfs(path)
3093

    
3094
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
3095
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
3096
  return (tsize, fsize)
3097

    
3098

    
3099
def RunInSeparateProcess(fn, *args):
3100
  """Runs a function in a separate process.
3101

3102
  Note: Only boolean return values are supported.
3103

3104
  @type fn: callable
3105
  @param fn: Function to be called
3106
  @rtype: bool
3107
  @return: Function's result
3108

3109
  """
3110
  pid = os.fork()
3111
  if pid == 0:
3112
    # Child process
3113
    try:
3114
      # In case the function uses temporary files
3115
      ResetTempfileModule()
3116

    
3117
      # Call function
3118
      result = int(bool(fn(*args)))
3119
      assert result in (0, 1)
3120
    except: # pylint: disable-msg=W0702
3121
      logging.exception("Error while calling function in separate process")
3122
      # 0 and 1 are reserved for the return value
3123
      result = 33
3124

    
3125
    os._exit(result) # pylint: disable-msg=W0212
3126

    
3127
  # Parent process
3128

    
3129
  # Avoid zombies and check exit code
3130
  (_, status) = os.waitpid(pid, 0)
3131

    
3132
  if os.WIFSIGNALED(status):
3133
    exitcode = None
3134
    signum = os.WTERMSIG(status)
3135
  else:
3136
    exitcode = os.WEXITSTATUS(status)
3137
    signum = None
3138

    
3139
  if not (exitcode in (0, 1) and signum is None):
3140
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
3141
                              (exitcode, signum))
3142

    
3143
  return bool(exitcode)
3144

    
3145

    
3146
def IgnoreProcessNotFound(fn, *args, **kwargs):
3147
  """Ignores ESRCH when calling a process-related function.
3148

3149
  ESRCH is raised when a process is not found.
3150

3151
  @rtype: bool
3152
  @return: Whether process was found
3153

3154
  """
3155
  try:
3156
    fn(*args, **kwargs)
3157
  except EnvironmentError, err:
3158
    # Ignore ESRCH
3159
    if err.errno == errno.ESRCH:
3160
      return False
3161
    raise
3162

    
3163
  return True
3164

    
3165

    
3166
def IgnoreSignals(fn, *args, **kwargs):
3167
  """Tries to call a function ignoring failures due to EINTR.
3168

3169
  """
3170
  try:
3171
    return fn(*args, **kwargs)
3172
  except EnvironmentError, err:
3173
    if err.errno == errno.EINTR:
3174
      return None
3175
    else:
3176
      raise
3177
  except (select.error, socket.error), err:
3178
    # In python 2.6 and above select.error is an IOError, so it's handled
3179
    # above, in 2.5 and below it's not, and it's handled here.
3180
    if err.args and err.args[0] == errno.EINTR:
3181
      return None
3182
    else:
3183
      raise
3184

    
3185

    
3186
def LockedMethod(fn):
3187
  """Synchronized object access decorator.
3188

3189
  This decorator is intended to protect access to an object using the
3190
  object's own lock which is hardcoded to '_lock'.
3191

3192
  """
3193
  def _LockDebug(*args, **kwargs):
3194
    if debug_locks:
3195
      logging.debug(*args, **kwargs)
3196

    
3197
  def wrapper(self, *args, **kwargs):
3198
    # pylint: disable-msg=W0212
3199
    assert hasattr(self, '_lock')
3200
    lock = self._lock
3201
    _LockDebug("Waiting for %s", lock)
3202
    lock.acquire()
3203
    try:
3204
      _LockDebug("Acquired %s", lock)
3205
      result = fn(self, *args, **kwargs)
3206
    finally:
3207
      _LockDebug("Releasing %s", lock)
3208
      lock.release()
3209
      _LockDebug("Released %s", lock)
3210
    return result
3211
  return wrapper
3212

    
3213

    
3214
def LockFile(fd):
3215
  """Locks a file using POSIX locks.
3216

3217
  @type fd: int
3218
  @param fd: the file descriptor we need to lock
3219

3220
  """
3221
  try:
3222
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
3223
  except IOError, err:
3224
    if err.errno == errno.EAGAIN:
3225
      raise errors.LockError("File already locked")
3226
    raise
3227

    
3228

    
3229
def FormatTime(val):
3230
  """Formats a time value.
3231

3232
  @type val: float or None
3233
  @param val: the timestamp as returned by time.time()
3234
  @return: a string value or N/A if we don't have a valid timestamp
3235

3236
  """
3237
  if val is None or not isinstance(val, (int, float)):
3238
    return "N/A"
3239
  # these two codes works on Linux, but they are not guaranteed on all
3240
  # platforms
3241
  return time.strftime("%F %T", time.localtime(val))
3242

    
3243

    
3244
def FormatSeconds(secs):
3245
  """Formats seconds for easier reading.
3246

3247
  @type secs: number
3248
  @param secs: Number of seconds
3249
  @rtype: string
3250
  @return: Formatted seconds (e.g. "2d 9h 19m 49s")
3251

3252
  """
3253
  parts = []
3254

    
3255
  secs = round(secs, 0)
3256

    
3257
  if secs > 0:
3258
    # Negative values would be a bit tricky
3259
    for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]:
3260
      (complete, secs) = divmod(secs, one)
3261
      if complete or parts:
3262
        parts.append("%d%s" % (complete, unit))
3263

    
3264
  parts.append("%ds" % secs)
3265

    
3266
  return " ".join(parts)
3267

    
3268

    
3269
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
3270
  """Reads the watcher pause file.
3271

3272
  @type filename: string
3273
  @param filename: Path to watcher pause file
3274
  @type now: None, float or int
3275
  @param now: Current time as Unix timestamp
3276
  @type remove_after: int
3277
  @param remove_after: Remove watcher pause file after specified amount of
3278
    seconds past the pause end time
3279

3280
  """
3281
  if now is None:
3282
    now = time.time()
3283

    
3284
  try:
3285
    value = ReadFile(filename)
3286
  except IOError, err:
3287
    if err.errno != errno.ENOENT:
3288
      raise
3289
    value = None
3290

    
3291
  if value is not None:
3292
    try:
3293
      value = int(value)
3294
    except ValueError:
3295
      logging.warning(("Watcher pause file (%s) contains invalid value,"
3296
                       " removing it"), filename)
3297
      RemoveFile(filename)
3298
      value = None
3299

    
3300
    if value is not None:
3301
      # Remove file if it's outdated
3302
      if now > (value + remove_after):
3303
        RemoveFile(filename)
3304
        value = None
3305

    
3306
      elif now > value:
3307
        value = None
3308

    
3309
  return value
3310

    
3311

    
3312
class RetryTimeout(Exception):
3313
  """Retry loop timed out.
3314

3315
  Any arguments which was passed by the retried function to RetryAgain will be
3316
  preserved in RetryTimeout, if it is raised. If such argument was an exception
3317
  the RaiseInner helper method will reraise it.
3318

3319
  """
3320
  def RaiseInner(self):
3321
    if self.args and isinstance(self.args[0], Exception):
3322
      raise self.args[0]
3323
    else:
3324
      raise RetryTimeout(*self.args)
3325

    
3326

    
3327
class RetryAgain(Exception):
3328
  """Retry again.
3329

3330
  Any arguments passed to RetryAgain will be preserved, if a timeout occurs, as
3331
  arguments to RetryTimeout. If an exception is passed, the RaiseInner() method
3332
  of the RetryTimeout() method can be used to reraise it.
3333

3334
  """
3335

    
3336

    
3337
class _RetryDelayCalculator(object):
3338
  """Calculator for increasing delays.
3339

3340
  """
3341
  __slots__ = [
3342
    "_factor",
3343
    "_limit",
3344
    "_next",
3345
    "_start",
3346
    ]
3347

    
3348
  def __init__(self, start, factor, limit):
3349
    """Initializes this class.
3350

3351
    @type start: float
3352
    @param start: Initial delay
3353
    @type factor: float
3354
    @param factor: Factor for delay increase
3355
    @type limit: float or None
3356
    @param limit: Upper limit for delay or None for no limit
3357

3358
    """
3359
    assert start > 0.0
3360
    assert factor >= 1.0
3361
    assert limit is None or limit >= 0.0
3362

    
3363
    self._start = start
3364
    self._factor = factor
3365
    self._limit = limit
3366

    
3367
    self._next = start
3368

    
3369
  def __call__(self):
3370
    """Returns current delay and calculates the next one.
3371

3372
    """
3373
    current = self._next
3374

    
3375
    # Update for next run
3376
    if self._limit is None or self._next < self._limit:
3377
      self._next = min(self._limit, self._next * self._factor)
3378

    
3379
    return current
3380

    
3381

    
3382
#: Special delay to specify whole remaining timeout
3383
RETRY_REMAINING_TIME = object()
3384

    
3385

    
3386
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
          _time_fn=time.time):
  """Call a function repeatedly until it succeeds.

  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
  total of C{timeout} seconds, this function throws L{RetryTimeout}.

  C{delay} can be one of the following:
    - callable returning the delay length as a float
    - Tuple of (start, factor, limit)
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
      useful when overriding L{wait_fn} to wait for an external event)
    - A static delay as a number (int or float)

  @type fn: callable
  @param fn: Function to be called
  @param delay: Either a callable (returning the delay), a tuple of (start,
                factor, limit) (see L{_RetryDelayCalculator}),
                L{RETRY_REMAINING_TIME} or a number (int or float)
  @type timeout: float
  @param timeout: Total timeout
  @type wait_fn: callable
  @param wait_fn: Waiting function
  @return: Return value of function

  """
  assert callable(fn)
  assert callable(wait_fn)
  assert callable(_time_fn)

  if args is None:
    args = []

  end_time = _time_fn() + timeout

  if callable(delay):
    # External function to calculate delay
    calc_delay = delay

  elif isinstance(delay, (tuple, list)):
    # Increasing delay with optional upper boundary
    (start, factor, limit) = delay
    calc_delay = _RetryDelayCalculator(start, factor, limit)

  elif delay is RETRY_REMAINING_TIME:
    # Always use the remaining time
    calc_delay = None

  else:
    # Static delay
    calc_delay = lambda: delay

  assert calc_delay is None or callable(calc_delay)

  while True:
    retry_args = []
    try:
      # pylint: disable-msg=W0142
      return fn(*args)
    except RetryAgain, err:
      retry_args = err.args
    except RetryTimeout:
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
                                   " handle RetryTimeout")

    remaining_time = end_time - _time_fn()

    if remaining_time < 0.0:
      # pylint: disable-msg=W0142
      raise RetryTimeout(*retry_args)

    assert remaining_time >= 0.0

    if calc_delay is None:
      wait_fn(remaining_time)
    else:
      current_delay = calc_delay()
      if current_delay > 0.0:
        wait_fn(current_delay)


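# Illustrative usage sketch (added for clarity, not part of the original
# code); "_PollJob" and "job" are made-up names, not real helpers:
#
#   def _PollJob(job):
#     if not job.finished:
#       raise RetryAgain()
#     return job.result
#
#   try:
#     # Start with a 1s delay, grow it by 50% per attempt, cap it at 10s and
#     # give up after 60s in total
#     result = Retry(_PollJob, (1.0, 1.5, 10.0), 60.0, args=[job])
#   except RetryTimeout:
#     ...  # handle the expired timeout
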
def GetClosedTempfile(*args, **kwargs):
  """Creates a temporary file and returns its path.

  """
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
  _CloseFDNoErr(fd)
  return path


def GenerateSelfSignedX509Cert(common_name, validity):
  """Generates a self-signed X509 certificate.

  @type common_name: string
  @param common_name: commonName value
  @type validity: int
  @param validity: Validity for certificate in seconds

  """
  # Create private and public key
  key = OpenSSL.crypto.PKey()
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)

  # Create self-signed certificate
  cert = OpenSSL.crypto.X509()
  if common_name:
    cert.get_subject().CN = common_name
  cert.set_serial_number(1)
  cert.gmtime_adj_notBefore(0)
  cert.gmtime_adj_notAfter(validity)
  cert.set_issuer(cert.get_subject())
  cert.set_pubkey(key)
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)

  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  return (key_pem, cert_pem)


def GenerateSelfSignedSslCert(filename, validity=(5 * 365)):
  """Legacy function to generate self-signed X509 certificate.

  @type filename: str
  @param filename: path to write the key and certificate to
  @type validity: int
  @param validity: validity of the certificate in days

  """
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(None,
                                                   validity * 24 * 60 * 60)

  WriteFile(filename, mode=0400, data=key_pem + cert_pem)


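# Illustrative usage sketch (added for clarity, not part of the original
# code); the host name and file path are made up:
#
#   # PEM-encoded key and certificate as strings:
#   (key_pem, cert_pem) = GenerateSelfSignedX509Cert("node1.example.com",
#                                                    365 * 24 * 60 * 60)
#
#   # Legacy helper: writes key and certificate into a single file, valid
#   # for five years by default:
#   GenerateSelfSignedSslCert("/tmp/example.pem")
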
class FileLock(object):
  """Utility class for file locks.

  """
  def __init__(self, fd, filename):
    """Constructor for FileLock.

    @type fd: file
    @param fd: File object
    @type filename: str
    @param filename: Path of the file opened at I{fd}

    """
    self.fd = fd
    self.filename = filename

  @classmethod
  def Open(cls, filename):
    """Creates and opens a file to be used as a file-based lock.

    @type filename: string
    @param filename: path to the file to be locked

    """
    # Using "os.open" is necessary to allow both opening existing file
    # read/write and creating if not existing. Vanilla "open" will truncate an
    # existing file -or- allow creating if not existing.
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
               filename)

  def __del__(self):
    self.Close()

  def Close(self):
    """Close the file and release the lock.

    """
    if hasattr(self, "fd") and self.fd:
      self.fd.close()
      self.fd = None

  def _flock(self, flag, blocking, timeout, errmsg):
    """Wrapper for fcntl.flock.

    @type flag: int
    @param flag: operation flag
    @type blocking: bool
    @param blocking: whether the operation should be done in blocking mode.
    @type timeout: None or float
    @param timeout: for how long the operation should be retried (implies
                    non-blocking mode).
    @type errmsg: string
    @param errmsg: error message in case operation fails.

    """
    assert self.fd, "Lock was closed"
    assert timeout is None or timeout >= 0, \
      "If specified, timeout must be non-negative"
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)


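# Illustrative usage sketch (added for clarity, not part of the original
# code); the lock file path is made up:
#
#   lock = FileLock.Open("/var/lock/example.lock")
#   try:
#     # Wait up to ten seconds for the exclusive lock
#     lock.Exclusive(blocking=True, timeout=10)
#     ...  # critical section
#     lock.Unlock()
#   finally:
#     lock.Close()
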
class LineSplitter:
  """Splits data chunks into lines separated by newline.

  Instances provide a file-like interface.

  """
  def __init__(self, line_fn, *args):
    """Initializes this class.

    @type line_fn: callable
    @param line_fn: Function called for each line, first parameter is line
    @param args: Extra arguments for L{line_fn}

    """
    assert callable(line_fn)

    if args:
      # Python 2.4 doesn't have functools.partial yet
      self._line_fn = \
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
    else:
      self._line_fn = line_fn

    self._lines = collections.deque()
    self._buffer = ""

  def write(self, data):
    parts = (self._buffer + data).split("\n")
    self._buffer = parts.pop()
    self._lines.extend(parts)

  def flush(self):
    while self._lines:
      self._line_fn(self._lines.popleft().rstrip("\r\n"))

  def close(self):
    self.flush()
    if self._buffer:
      self._line_fn(self._buffer)


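# Illustrative usage sketch (added for clarity, not part of the original
# code):
#
#   collected = []
#   splitter = LineSplitter(collected.append)
#   splitter.write("first line\nsecond ")
#   splitter.write("line\n")
#   splitter.close()
#   # collected == ["first line", "second line"]
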
def SignalHandled(signums):
  """Signal handling decorator.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      assert 'signal_handlers' not in kwargs or \
             kwargs['signal_handlers'] is None or \
             isinstance(kwargs['signal_handlers'], dict), \
             "Wrong signal_handlers parameter in original function call"
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
        signal_handlers = kwargs['signal_handlers']
      else:
        signal_handlers = {}
        kwargs['signal_handlers'] = signal_handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          signal_handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        sighandler.Reset()
    return sig_function
  return wrap


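# Illustrative usage sketch (added for clarity, not part of the original
# code); "_ExampleLoop" is a made-up function:
#
#   @SignalHandled([signal.SIGTERM])
#   @SignalHandled([signal.SIGINT])
#   def _ExampleLoop(signal_handlers=None):
#     while not (signal_handlers[signal.SIGTERM].called or
#                signal_handlers[signal.SIGINT].called):
#       ...  # do one unit of work
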
class SignalWakeupFd(object):
  """Wrapper around C{signal.set_wakeup_fd} using a pipe.

  The read end can be waited on (e.g. with C{select}) and becomes readable
  when a signal arrives or L{Notify} is called.

  """
  try:
    # This is only supported in Python 2.5 and above (some distributions
    # backported it to Python 2.4)
    _set_wakeup_fd_fn = signal.set_wakeup_fd
  except AttributeError:
    # Not supported
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
      return -1
  else:
    def _SetWakeupFd(self, fd):
      return self._set_wakeup_fd_fn(fd)

  def __init__(self):
    """Initializes this class.

    """
    (read_fd, write_fd) = os.pipe()

    # Once these succeeded, the file descriptors will be closed automatically.
    # Buffer size 0 is important, otherwise .read() with a specified length
    # might buffer data and the file descriptors won't be marked readable.
    self._read_fh = os.fdopen(read_fd, "r", 0)
    self._write_fh = os.fdopen(write_fd, "w", 0)

    self._previous = self._SetWakeupFd(self._write_fh.fileno())

    # Utility functions
    self.fileno = self._read_fh.fileno
    self.read = self._read_fh.read

  def Reset(self):
    """Restores the previous wakeup file descriptor.

    """
    if hasattr(self, "_previous") and self._previous is not None:
      self._SetWakeupFd(self._previous)
      self._previous = None

  def Notify(self):
    """Notifies the wakeup file descriptor.

    """
    self._write_fh.write("\0")

  def __del__(self):
    """Called before object deletion.

    """
    self.Reset()


class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when destroyed or
  when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: list
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None, wakeup=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function
    @type wakeup: L{SignalWakeupFd} or None
    @param wakeup: If not None, notified when a handled signal arrives

    """
    assert handler_fn is None or callable(handler_fn)

    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn
    self._wakeup = wakeup

    self._previous = {}
    try:
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._wakeup:
      # Notify whoever is interested in signals
      self._wakeup.Notify()

    if self._handler_fn:
      self._handler_fn(signum, frame)


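# Illustrative usage sketch (added for clarity, not part of the original
# code), combining SignalHandler with SignalWakeupFd so a select() loop is
# woken up as soon as SIGCHLD arrives:
#
#   wakeup = SignalWakeupFd()
#   handler = SignalHandler([signal.SIGCHLD], wakeup=wakeup)
#   try:
#     while True:
#       select.select([wakeup], [], [], 30.0)
#       if handler.called:
#         handler.Clear()
#         wakeup.read(1)  # drain the notification byte
#         ...  # reap child processes
#   finally:
#     handler.Reset()
#     wakeup.Reset()
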
class FieldSet(object):
  """A simple field set.

  Among the features are:
    - checking if a string is among a list of static strings or regex objects
    - checking if a whole list of strings matches
    - returning the matching groups from a regex match

  Internally, all fields are held as regular expression objects.

  """
  def __init__(self, *items):
    self.items = [re.compile("^%s$" % value) for value in items]

  def Extend(self, other_set):
    """Extend the field set with the items from another one"""
    self.items.extend(other_set.items)

  def Matches(self, field):
    """Checks if a field matches the current set

    @type field: str
    @param field: the string to match
    @return: either None or a regular expression match object

    """
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
      return m
    return None

  def NonMatching(self, items):
    """Returns the list of fields not matching the current set

    @type items: list
    @param items: the list of fields to check
    @rtype: list
    @return: list of non-matching fields

    """
    return [val for val in items if not self.Matches(val)]
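

# Illustrative usage sketch (added for clarity, not part of the original
# code); the field names are made up:
#
#   fields = FieldSet("name", "ctime", r"nic\.ip/([0-9]+)")
#   fields.Matches("nic.ip/0").group(1)   # returns "0"
#   fields.NonMatching(["name", "size"])  # returns ["size"]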