lib/utils.py @ acd65a16

1
#
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import os
31
import sys
32
import time
33
import subprocess
34
import re
35
import socket
36
import tempfile
37
import shutil
38
import errno
39
import pwd
40
import itertools
41
import select
42
import fcntl
43
import resource
44
import logging
45
import logging.handlers
46
import signal
47
import OpenSSL
48
import datetime
49
import calendar
50
import hmac
51
import collections
52
import struct
53
import IN
54

    
55
from cStringIO import StringIO
56

    
57
try:
58
  import ctypes
59
except ImportError:
60
  ctypes = None
61

    
62
from ganeti import errors
63
from ganeti import constants
64
from ganeti import compat
65

    
66

    
67
_locksheld = []
68
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
69

    
70
debug_locks = False
71

    
72
#: when set to True, L{RunCmd} is disabled
73
no_fork = False
74

    
75
_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
76

    
77
HEX_CHAR_RE = r"[a-zA-Z0-9]"
78
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
79
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
80
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
81
                             HEX_CHAR_RE, HEX_CHAR_RE),
82
                            re.S | re.I)
83

    
84
_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$")
85

    
86
# Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
87
# struct ucred { pid_t pid; uid_t uid; gid_t gid; };
88
#
89
# The GNU C Library defines gid_t and uid_t to be "unsigned int" and
90
# pid_t to "int".
91
#
92
# IEEE Std 1003.1-2008:
93
# "nlink_t, uid_t, gid_t, and id_t shall be integer types"
94
# "blksize_t, pid_t, and ssize_t shall be signed integer types"
95
_STRUCT_UCRED = "iII"
96
_STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)
97

    
98
# Certificate verification results
99
(CERT_WARNING,
100
 CERT_ERROR) = range(1, 3)
101

    
102
# Flags for mlockall() (from bits/mman.h)
103
_MCL_CURRENT = 1
104
_MCL_FUTURE = 2
105

    
106

    
107
class RunResult(object):
108
  """Holds the result of running external programs.
109

110
  @type exit_code: int
111
  @ivar exit_code: the exit code of the program, or None (if the program
112
      didn't exit())
113
  @type signal: int or None
114
  @ivar signal: the signal that caused the program to finish, or None
115
      (if the program wasn't terminated by a signal)
116
  @type stdout: str
117
  @ivar stdout: the standard output of the program
118
  @type stderr: str
119
  @ivar stderr: the standard error of the program
120
  @type failed: boolean
121
  @ivar failed: True in case the program was
122
      terminated by a signal or exited with a non-zero exit code
123
  @ivar fail_reason: a string detailing the termination reason
124

125
  """
126
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
127
               "failed", "fail_reason", "cmd"]
128

    
129

    
130
  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
131
    self.cmd = cmd
132
    self.exit_code = exit_code
133
    self.signal = signal_
134
    self.stdout = stdout
135
    self.stderr = stderr
136
    self.failed = (signal_ is not None or exit_code != 0)
137

    
138
    if self.signal is not None:
139
      self.fail_reason = "terminated by signal %s" % self.signal
140
    elif self.exit_code is not None:
141
      self.fail_reason = "exited with exit code %s" % self.exit_code
142
    else:
143
      self.fail_reason = "unable to determine termination reason"
144

    
145
    if self.failed:
146
      logging.debug("Command '%s' failed (%s); output: %s",
147
                    self.cmd, self.fail_reason, self.output)
148

    
149
  def _GetOutput(self):
150
    """Returns the combined stdout and stderr for easier usage.
151

152
    """
153
    return self.stdout + self.stderr
154

    
155
  output = property(_GetOutput, None, None, "Return full output")
156

    
157

    
158
def _BuildCmdEnvironment(env, reset):
159
  """Builds the environment for an external program.
160

161
  """
162
  if reset:
163
    cmd_env = {}
164
  else:
165
    cmd_env = os.environ.copy()
166
    cmd_env["LC_ALL"] = "C"
167

    
168
  if env is not None:
169
    cmd_env.update(env)
170

    
171
  return cmd_env
172

    
173

    
174
def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
175
  """Execute a (shell) command.
176

177
  The command should not read from its standard input, as it will be
178
  closed.
179

180
  @type cmd: string or list
181
  @param cmd: Command to run
182
  @type env: dict
183
  @param env: Additional environment variables
184
  @type output: str
185
  @param output: if desired, the output of the command can be
186
      saved in a file instead of the RunResult instance; this
187
      parameter denotes the file name (if not None)
188
  @type cwd: string
189
  @param cwd: if specified, will be used as the working
190
      directory for the command; the default will be /
191
  @type reset_env: boolean
192
  @param reset_env: whether to reset or keep the default os environment
193
  @rtype: L{RunResult}
194
  @return: RunResult instance
195
  @raise errors.ProgrammerError: if we call this when forks are disabled
196

197
  """
198
  if no_fork:
199
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")
200

    
201
  if isinstance(cmd, basestring):
202
    strcmd = cmd
203
    shell = True
204
  else:
205
    cmd = [str(val) for val in cmd]
206
    strcmd = ShellQuoteArgs(cmd)
207
    shell = False
208

    
209
  if output:
210
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
211
  else:
212
    logging.debug("RunCmd %s", strcmd)
213

    
214
  cmd_env = _BuildCmdEnvironment(env, reset_env)
215

    
216
  try:
217
    if output is None:
218
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
219
    else:
220
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
221
      out = err = ""
222
  except OSError, err:
223
    if err.errno == errno.ENOENT:
224
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
225
                               (strcmd, err))
226
    else:
227
      raise
228

    
229
  if status >= 0:
230
    exitcode = status
231
    signal_ = None
232
  else:
233
    exitcode = None
234
    signal_ = -status
235

    
236
  return RunResult(exitcode, signal_, out, err, strcmd)
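
# Example usage (illustrative sketch only): run a simple command and inspect
# the RunResult fields documented above.
#
#   result = RunCmd(["/bin/ls", "/tmp"])
#   if result.failed:
#     logging.error("listing failed (%s); output: %s",
#                   result.fail_reason, result.output)
#   else:
#     logging.debug("listing succeeded, stdout: %s", result.stdout)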
237

    
238

    
239
def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
240
                pidfile=None):
241
  """Start a daemon process after forking twice.
242

243
  @type cmd: string or list
244
  @param cmd: Command to run
245
  @type env: dict
246
  @param env: Additional environment variables
247
  @type cwd: string
248
  @param cwd: Working directory for the program
249
  @type output: string
250
  @param output: Path to file in which to save the output
251
  @type output_fd: int
252
  @param output_fd: File descriptor for output
253
  @type pidfile: string
254
  @param pidfile: Process ID file
255
  @rtype: int
256
  @return: Daemon process ID
257
  @raise errors.ProgrammerError: if we call this when forks are disabled
258

259
  """
260
  if no_fork:
261
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
262
                                 " disabled")
263

    
264
  if output and not (bool(output) ^ (output_fd is not None)):
265
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
266
                                 " specified")
267

    
268
  if isinstance(cmd, basestring):
269
    cmd = ["/bin/sh", "-c", cmd]
270

    
271
  strcmd = ShellQuoteArgs(cmd)
272

    
273
  if output:
274
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
275
  else:
276
    logging.debug("StartDaemon %s", strcmd)
277

    
278
  cmd_env = _BuildCmdEnvironment(env, False)
279

    
280
  # Create pipe for sending PID back
281
  (pidpipe_read, pidpipe_write) = os.pipe()
282
  try:
283
    try:
284
      # Create pipe for sending error messages
285
      (errpipe_read, errpipe_write) = os.pipe()
286
      try:
287
        try:
288
          # First fork
289
          pid = os.fork()
290
          if pid == 0:
291
            try:
292
              # Child process, won't return
293
              _StartDaemonChild(errpipe_read, errpipe_write,
294
                                pidpipe_read, pidpipe_write,
295
                                cmd, cmd_env, cwd,
296
                                output, output_fd, pidfile)
297
            finally:
298
              # Well, maybe child process failed
299
              os._exit(1) # pylint: disable-msg=W0212
300
        finally:
301
          _CloseFDNoErr(errpipe_write)
302

    
303
        # Wait for daemon to be started (or an error message to arrive) and read
304
        # up to 100 KB as an error message
305
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
306
      finally:
307
        _CloseFDNoErr(errpipe_read)
308
    finally:
309
      _CloseFDNoErr(pidpipe_write)
310

    
311
    # Read up to 128 bytes for PID
312
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
313
  finally:
314
    _CloseFDNoErr(pidpipe_read)
315

    
316
  # Try to avoid zombies by waiting for child process
317
  try:
318
    os.waitpid(pid, 0)
319
  except OSError:
320
    pass
321

    
322
  if errormsg:
323
    raise errors.OpExecError("Error when starting daemon process: %r" %
324
                             errormsg)
325

    
326
  try:
327
    return int(pidtext)
328
  except (ValueError, TypeError), err:
329
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
330
                             (pidtext, err))
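
# Example usage (illustrative sketch; the daemon path, log file and PID file
# below are hypothetical):
#
#   pid = StartDaemon(["/usr/sbin/mydaemon", "--some-option"],
#                     output="/var/log/mydaemon.log",
#                     pidfile="/var/run/mydaemon.pid")
#
# The returned PID can later be re-read with L{ReadLockedPidFile} or checked
# with L{IsProcessAlive}.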
331

    
332

    
333
def _StartDaemonChild(errpipe_read, errpipe_write,
334
                      pidpipe_read, pidpipe_write,
335
                      args, env, cwd,
336
                      output, fd_output, pidfile):
337
  """Child process for starting daemon.
338

339
  """
340
  try:
341
    # Close parent's side
342
    _CloseFDNoErr(errpipe_read)
343
    _CloseFDNoErr(pidpipe_read)
344

    
345
    # First child process
346
    os.chdir("/")
347
    os.umask(077)
348
    os.setsid()
349

    
350
    # And fork for the second time
351
    pid = os.fork()
352
    if pid != 0:
353
      # Exit first child process
354
      os._exit(0) # pylint: disable-msg=W0212
355

    
356
    # Make sure pipe is closed on execv* (and thereby notifies original process)
357
    SetCloseOnExecFlag(errpipe_write, True)
358

    
359
    # List of file descriptors to be left open
360
    noclose_fds = [errpipe_write]
361

    
362
    # Open PID file
363
    if pidfile:
364
      try:
365
        # TODO: Atomic replace with another locked file instead of writing into
366
        # it after creating
367
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)
368

    
369
        # Lock the PID file (and fail if not possible to do so). Any code
370
        # wanting to send a signal to the daemon should try to lock the PID
371
        # file before reading it. If acquiring the lock succeeds, the daemon is
372
        # no longer running and the signal should not be sent.
373
        LockFile(fd_pidfile)
374

    
375
        os.write(fd_pidfile, "%d\n" % os.getpid())
376
      except Exception, err:
377
        raise Exception("Creating and locking PID file failed: %s" % err)
378

    
379
      # Keeping the file open to hold the lock
380
      noclose_fds.append(fd_pidfile)
381

    
382
      SetCloseOnExecFlag(fd_pidfile, False)
383
    else:
384
      fd_pidfile = None
385

    
386
    # Open /dev/null
387
    fd_devnull = os.open(os.devnull, os.O_RDWR)
388

    
389
    assert not output or (bool(output) ^ (fd_output is not None))
390

    
391
    if fd_output is not None:
392
      pass
393
    elif output:
394
      # Open output file
395
      try:
396
        # TODO: Implement flag to set append=yes/no
397
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
398
      except EnvironmentError, err:
399
        raise Exception("Opening output file failed: %s" % err)
400
    else:
401
      fd_output = fd_devnull
402

    
403
    # Redirect standard I/O
404
    os.dup2(fd_devnull, 0)
405
    os.dup2(fd_output, 1)
406
    os.dup2(fd_output, 2)
407

    
408
    # Send daemon PID to parent
409
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))
410

    
411
    # Close all file descriptors except stdio and error message pipe
412
    CloseFDs(noclose_fds=noclose_fds)
413

    
414
    # Change working directory
415
    os.chdir(cwd)
416

    
417
    if env is None:
418
      os.execvp(args[0], args)
419
    else:
420
      os.execvpe(args[0], args, env)
421
  except: # pylint: disable-msg=W0702
422
    try:
423
      # Report errors to original process
424
      buf = str(sys.exc_info()[1])
425

    
426
      RetryOnSignal(os.write, errpipe_write, buf)
427
    except: # pylint: disable-msg=W0702
428
      # Ignore errors in error handling
429
      pass
430

    
431
  os._exit(1) # pylint: disable-msg=W0212
432

    
433

    
434
def _RunCmdPipe(cmd, env, via_shell, cwd):
435
  """Run a command and return its output.
436

437
  @type  cmd: string or list
438
  @param cmd: Command to run
439
  @type env: dict
440
  @param env: The environment to use
441
  @type via_shell: bool
442
  @param via_shell: if we should run via the shell
443
  @type cwd: string
444
  @param cwd: the working directory for the program
445
  @rtype: tuple
446
  @return: (out, err, status)
447

448
  """
449
  poller = select.poll()
450
  child = subprocess.Popen(cmd, shell=via_shell,
451
                           stderr=subprocess.PIPE,
452
                           stdout=subprocess.PIPE,
453
                           stdin=subprocess.PIPE,
454
                           close_fds=True, env=env,
455
                           cwd=cwd)
456

    
457
  child.stdin.close()
458
  poller.register(child.stdout, select.POLLIN)
459
  poller.register(child.stderr, select.POLLIN)
460
  out = StringIO()
461
  err = StringIO()
462
  fdmap = {
463
    child.stdout.fileno(): (out, child.stdout),
464
    child.stderr.fileno(): (err, child.stderr),
465
    }
466
  for fd in fdmap:
467
    SetNonblockFlag(fd, True)
468

    
469
  while fdmap:
470
    pollresult = RetryOnSignal(poller.poll)
471

    
472
    for fd, event in pollresult:
473
      if event & select.POLLIN or event & select.POLLPRI:
474
        data = fdmap[fd][1].read()
475
        # no data from read signifies EOF (the same as POLLHUP)
476
        if not data:
477
          poller.unregister(fd)
478
          del fdmap[fd]
479
          continue
480
        fdmap[fd][0].write(data)
481
      if (event & select.POLLNVAL or event & select.POLLHUP or
482
          event & select.POLLERR):
483
        poller.unregister(fd)
484
        del fdmap[fd]
485

    
486
  out = out.getvalue()
487
  err = err.getvalue()
488

    
489
  status = child.wait()
490
  return out, err, status
491

    
492

    
493
def _RunCmdFile(cmd, env, via_shell, output, cwd):
494
  """Run a command and save its output to a file.
495

496
  @type  cmd: string or list
497
  @param cmd: Command to run
498
  @type env: dict
499
  @param env: The environment to use
500
  @type via_shell: bool
501
  @param via_shell: if we should run via the shell
502
  @type output: str
503
  @param output: the filename in which to save the output
504
  @type cwd: string
505
  @param cwd: the working directory for the program
506
  @rtype: int
507
  @return: the exit status
508

509
  """
510
  fh = open(output, "a")
511
  try:
512
    child = subprocess.Popen(cmd, shell=via_shell,
513
                             stderr=subprocess.STDOUT,
514
                             stdout=fh,
515
                             stdin=subprocess.PIPE,
516
                             close_fds=True, env=env,
517
                             cwd=cwd)
518

    
519
    child.stdin.close()
520
    status = child.wait()
521
  finally:
522
    fh.close()
523
  return status
524

    
525

    
526
def SetCloseOnExecFlag(fd, enable):
527
  """Sets or unsets the close-on-exec flag on a file descriptor.
528

529
  @type fd: int
530
  @param fd: File descriptor
531
  @type enable: bool
532
  @param enable: Whether to set or unset it.
533

534
  """
535
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)
536

    
537
  if enable:
538
    flags |= fcntl.FD_CLOEXEC
539
  else:
540
    flags &= ~fcntl.FD_CLOEXEC
541

    
542
  fcntl.fcntl(fd, fcntl.F_SETFD, flags)
543

    
544

    
545
def SetNonblockFlag(fd, enable):
546
  """Sets or unsets the O_NONBLOCK flag on on a file descriptor.
547

548
  @type fd: int
549
  @param fd: File descriptor
550
  @type enable: bool
551
  @param enable: Whether to set or unset it
552

553
  """
554
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)
555

    
556
  if enable:
557
    flags |= os.O_NONBLOCK
558
  else:
559
    flags &= ~os.O_NONBLOCK
560

    
561
  fcntl.fcntl(fd, fcntl.F_SETFL, flags)
562

    
563

    
564
def RetryOnSignal(fn, *args, **kwargs):
565
  """Calls a function again if it failed due to EINTR.
566

567
  """
568
  while True:
569
    try:
570
      return fn(*args, **kwargs)
571
    except EnvironmentError, err:
572
      if err.errno != errno.EINTR:
573
        raise
574
    except (socket.error, select.error), err:
575
      # In python 2.6 and above select.error is an IOError, so it's handled
576
      # above, in 2.5 and below it's not, and it's handled here.
577
      if not (err.args and err.args[0] == errno.EINTR):
578
        raise
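
# Example usage (illustrative sketch; "fd" stands for any open file
# descriptor): retry a read that may be interrupted by a signal instead of
# letting EINTR escape as an exception.
#
#   data = RetryOnSignal(os.read, fd, 4096)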
579

    
580

    
581
def RunParts(dir_name, env=None, reset_env=False):
582
  """Run Scripts or programs in a directory
583

584
  @type dir_name: string
585
  @param dir_name: absolute path to a directory
586
  @type env: dict
587
  @param env: The environment to use
588
  @type reset_env: boolean
589
  @param reset_env: whether to reset or keep the default os environment
590
  @rtype: list of tuples
591
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)
592

593
  """
594
  rr = []
595

    
596
  try:
597
    dir_contents = ListVisibleFiles(dir_name)
598
  except OSError, err:
599
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
600
    return rr
601

    
602
  for relname in sorted(dir_contents):
603
    fname = PathJoin(dir_name, relname)
604
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
605
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
606
      rr.append((relname, constants.RUNPARTS_SKIP, None))
607
    else:
608
      try:
609
        result = RunCmd([fname], env=env, reset_env=reset_env)
610
      except Exception, err: # pylint: disable-msg=W0703
611
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
612
      else:
613
        rr.append((relname, constants.RUNPARTS_RUN, result))
614

    
615
  return rr
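
# Example usage (illustrative sketch; the hooks directory is hypothetical):
# run every executable in a directory and log the ones that failed.
#
#   for (relname, status, runresult) in RunParts("/etc/example/hooks.d"):
#     if status == constants.RUNPARTS_ERR:
#       logging.error("%s could not be run: %s", relname, runresult)
#     elif status == constants.RUNPARTS_RUN and runresult.failed:
#       logging.error("%s failed: %s", relname, runresult.fail_reason)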
616

    
617

    
618
def GetSocketCredentials(sock):
619
  """Returns the credentials of the foreign process connected to a socket.
620

621
  @param sock: Unix socket
622
  @rtype: tuple; (number, number, number)
623
  @return: The PID, UID and GID of the connected foreign process.
624

625
  """
626
  peercred = sock.getsockopt(socket.SOL_SOCKET, IN.SO_PEERCRED,
627
                             _STRUCT_UCRED_SIZE)
628
  return struct.unpack(_STRUCT_UCRED, peercred)
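
# Example usage (illustrative sketch; "conn" stands for a connected AF_UNIX
# socket, e.g. as returned by socket.socket.accept()):
#
#   (peer_pid, peer_uid, peer_gid) = GetSocketCredentials(conn)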
629

    
630

    
631
def RemoveFile(filename):
632
  """Remove a file ignoring some errors.
633

634
  Remove a file, ignoring non-existing ones or directories. Other
635
  errors are passed through.
636

637
  @type filename: str
638
  @param filename: the file to be removed
639

640
  """
641
  try:
642
    os.unlink(filename)
643
  except OSError, err:
644
    if err.errno not in (errno.ENOENT, errno.EISDIR):
645
      raise
646

    
647

    
648
def RemoveDir(dirname):
649
  """Remove an empty directory.
650

651
  Remove a directory, ignoring non-existing ones.
652
  Other errors are passed through. This includes the case
653
  where the directory is not empty, so it can't be removed.
654

655
  @type dirname: str
656
  @param dirname: the empty directory to be removed
657

658
  """
659
  try:
660
    os.rmdir(dirname)
661
  except OSError, err:
662
    if err.errno != errno.ENOENT:
663
      raise
664

    
665

    
666
def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
667
  """Renames a file.
668

669
  @type old: string
670
  @param old: Original path
671
  @type new: string
672
  @param new: New path
673
  @type mkdir: bool
674
  @param mkdir: Whether to create target directory if it doesn't exist
675
  @type mkdir_mode: int
676
  @param mkdir_mode: Mode for newly created directories
677

678
  """
679
  try:
680
    return os.rename(old, new)
681
  except OSError, err:
682
    # In at least one use case of this function, the job queue, directory
683
    # creation is very rare. Checking for the directory before renaming is not
684
    # as efficient.
685
    if mkdir and err.errno == errno.ENOENT:
686
      # Create directory and try again
687
      Makedirs(os.path.dirname(new), mode=mkdir_mode)
688

    
689
      return os.rename(old, new)
690

    
691
    raise
692

    
693

    
694
def Makedirs(path, mode=0750):
695
  """Super-mkdir; create a leaf directory and all intermediate ones.
696

697
  This is a wrapper around C{os.makedirs} adding error handling not implemented
698
  before Python 2.5.
699

700
  """
701
  try:
702
    os.makedirs(path, mode)
703
  except OSError, err:
704
    # Ignore EEXIST. This is only handled in os.makedirs as included in
705
    # Python 2.5 and above.
706
    if err.errno != errno.EEXIST or not os.path.exists(path):
707
      raise
708

    
709

    
710
def ResetTempfileModule():
711
  """Resets the random name generator of the tempfile module.
712

713
  This function should be called after C{os.fork} in the child process to
714
  ensure it creates a newly seeded random generator. Otherwise it would
715
  generate the same random parts as the parent process. If several processes
716
  race for the creation of a temporary file, this could lead to one not getting
717
  a temporary name.
718

719
  """
720
  # pylint: disable-msg=W0212
721
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
722
    tempfile._once_lock.acquire()
723
    try:
724
      # Reset random name generator
725
      tempfile._name_sequence = None
726
    finally:
727
      tempfile._once_lock.release()
728
  else:
729
    logging.critical("The tempfile module misses at least one of the"
730
                     " '_once_lock' and '_name_sequence' attributes")
731

    
732

    
733
def _FingerprintFile(filename):
734
  """Compute the fingerprint of a file.
735

736
  If the file does not exist, a None will be returned
737
  instead.
738

739
  @type filename: str
740
  @param filename: the filename to checksum
741
  @rtype: str
742
  @return: the hex digest of the sha checksum of the contents
743
      of the file
744

745
  """
746
  if not (os.path.exists(filename) and os.path.isfile(filename)):
747
    return None
748

    
749
  f = open(filename)
750

    
751
  fp = compat.sha1_hash()
752
  while True:
753
    data = f.read(4096)
754
    if not data:
755
      break
756

    
757
    fp.update(data)
758

    
759
  return fp.hexdigest()
760

    
761

    
762
def FingerprintFiles(files):
763
  """Compute fingerprints for a list of files.
764

765
  @type files: list
766
  @param files: the list of filename to fingerprint
767
  @rtype: dict
768
  @return: a dictionary filename: fingerprint, holding only
769
      existing files
770

771
  """
772
  ret = {}
773

    
774
  for filename in files:
775
    cksum = _FingerprintFile(filename)
776
    if cksum:
777
      ret[filename] = cksum
778

    
779
  return ret
780

    
781

    
782
def ForceDictType(target, key_types, allowed_values=None):
783
  """Force the values of a dict to have certain types.
784

785
  @type target: dict
786
  @param target: the dict to update
787
  @type key_types: dict
788
  @param key_types: dict mapping target dict keys to types
789
                    in constants.ENFORCEABLE_TYPES
790
  @type allowed_values: list
791
  @keyword allowed_values: list of specially allowed values
792

793
  """
794
  if allowed_values is None:
795
    allowed_values = []
796

    
797
  if not isinstance(target, dict):
798
    msg = "Expected dictionary, got '%s'" % target
799
    raise errors.TypeEnforcementError(msg)
800

    
801
  for key in target:
802
    if key not in key_types:
803
      msg = "Unknown key '%s'" % key
804
      raise errors.TypeEnforcementError(msg)
805

    
806
    if target[key] in allowed_values:
807
      continue
808

    
809
    ktype = key_types[key]
810
    if ktype not in constants.ENFORCEABLE_TYPES:
811
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
812
      raise errors.ProgrammerError(msg)
813

    
814
    if ktype == constants.VTYPE_STRING:
815
      if not isinstance(target[key], basestring):
816
        if isinstance(target[key], bool) and not target[key]:
817
          target[key] = ''
818
        else:
819
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
820
          raise errors.TypeEnforcementError(msg)
821
    elif ktype == constants.VTYPE_BOOL:
822
      if isinstance(target[key], basestring) and target[key]:
823
        if target[key].lower() == constants.VALUE_FALSE:
824
          target[key] = False
825
        elif target[key].lower() == constants.VALUE_TRUE:
826
          target[key] = True
827
        else:
828
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
829
          raise errors.TypeEnforcementError(msg)
830
      elif target[key]:
831
        target[key] = True
832
      else:
833
        target[key] = False
834
    elif ktype == constants.VTYPE_SIZE:
835
      try:
836
        target[key] = ParseUnit(target[key])
837
      except errors.UnitParseError, err:
838
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
839
              (key, target[key], err)
840
        raise errors.TypeEnforcementError(msg)
841
    elif ktype == constants.VTYPE_INT:
842
      try:
843
        target[key] = int(target[key])
844
      except (ValueError, TypeError):
845
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
846
        raise errors.TypeEnforcementError(msg)
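
# Example usage (illustrative sketch with hypothetical keys; it assumes
# constants.VALUE_TRUE is the lowercase string "true"): values are coerced
# in place according to the requested constants.VTYPE_* types.
#
#   params = {"memory": "512M", "auto_balance": "true"}
#   ForceDictType(params, {"memory": constants.VTYPE_SIZE,
#                          "auto_balance": constants.VTYPE_BOOL})
#   # afterwards params == {"memory": 512, "auto_balance": True}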
847

    
848

    
849
def _GetProcStatusPath(pid):
850
  """Returns the path for a PID's proc status file.
851

852
  @type pid: int
853
  @param pid: Process ID
854
  @rtype: string
855

856
  """
857
  return "/proc/%d/status" % pid
858

    
859

    
860
def IsProcessAlive(pid):
861
  """Check if a given pid exists on the system.
862

863
  @note: zombie status is not handled, so zombie processes
864
      will be returned as alive
865
  @type pid: int
866
  @param pid: the process ID to check
867
  @rtype: boolean
868
  @return: True if the process exists
869

870
  """
871
  def _TryStat(name):
872
    try:
873
      os.stat(name)
874
      return True
875
    except EnvironmentError, err:
876
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
877
        return False
878
      elif err.errno == errno.EINVAL:
879
        raise RetryAgain(err)
880
      raise
881

    
882
  assert isinstance(pid, int), "pid must be an integer"
883
  if pid <= 0:
884
    return False
885

    
886
  # /proc in a multiprocessor environment can have strange behaviors.
887
  # Retry the os.stat a few times until we get a good result.
888
  try:
889
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
890
                 args=[_GetProcStatusPath(pid)])
891
  except RetryTimeout, err:
892
    err.RaiseInner()
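
# Example usage (illustrative sketch; the PID file path is hypothetical):
# check a previously recorded PID before sending it a signal.
#
#   pid = ReadPidFile("/var/run/mydaemon.pid")
#   if pid > 0 and IsProcessAlive(pid):
#     logging.info("daemon still running as PID %d", pid)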
893

    
894

    
895
def _ParseSigsetT(sigset):
896
  """Parse a rendered sigset_t value.
897

898
  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
899
  function.
900

901
  @type sigset: string
902
  @param sigset: Rendered signal set from /proc/$pid/status
903
  @rtype: set
904
  @return: Set of all enabled signal numbers
905

906
  """
907
  result = set()
908

    
909
  signum = 0
910
  for ch in reversed(sigset):
911
    chv = int(ch, 16)
912

    
913
    # The following could be done in a loop, but it's easier to read and
914
    # understand in the unrolled form
915
    if chv & 1:
916
      result.add(signum + 1)
917
    if chv & 2:
918
      result.add(signum + 2)
919
    if chv & 4:
920
      result.add(signum + 3)
921
    if chv & 8:
922
      result.add(signum + 4)
923

    
924
    signum += 4
925

    
926
  return result
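
# Illustration (hedged sketch): the rendered mask is parsed right to left,
# four signal numbers per hex digit, e.g.
#
#   _ParseSigsetT("0000000000000001") == set([1])   # SIGHUP only
#   _ParseSigsetT("0000000000000100") == set([9])   # SIGKILL only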
927

    
928

    
929
def _GetProcStatusField(pstatus, field):
930
  """Retrieves a field from the contents of a proc status file.
931

932
  @type pstatus: string
933
  @param pstatus: Contents of /proc/$pid/status
934
  @type field: string
935
  @param field: Name of field whose value should be returned
936
  @rtype: string
937

938
  """
939
  for line in pstatus.splitlines():
940
    parts = line.split(":", 1)
941

    
942
    if len(parts) < 2 or parts[0] != field:
943
      continue
944

    
945
    return parts[1].strip()
946

    
947
  return None
948

    
949

    
950
def IsProcessHandlingSignal(pid, signum, status_path=None):
951
  """Checks whether a process is handling a signal.
952

953
  @type pid: int
954
  @param pid: Process ID
955
  @type signum: int
956
  @param signum: Signal number
957
  @rtype: bool
958

959
  """
960
  if status_path is None:
961
    status_path = _GetProcStatusPath(pid)
962

    
963
  try:
964
    proc_status = ReadFile(status_path)
965
  except EnvironmentError, err:
966
    # In at least one case, reading /proc/$pid/status failed with ESRCH.
967
    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
968
      return False
969
    raise
970

    
971
  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
972
  if sigcgt is None:
973
    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)
974

    
975
  # Now check whether signal is handled
976
  return signum in _ParseSigsetT(sigcgt)
977

    
978

    
979
def ReadPidFile(pidfile):
980
  """Read a pid from a file.
981

982
  @type  pidfile: string
983
  @param pidfile: path to the file containing the pid
984
  @rtype: int
985
  @return: The process id, if the file exists and contains a valid PID,
986
           otherwise 0
987

988
  """
989
  try:
990
    raw_data = ReadOneLineFile(pidfile)
991
  except EnvironmentError, err:
992
    if err.errno != errno.ENOENT:
993
      logging.exception("Can't read pid file")
994
    return 0
995

    
996
  try:
997
    pid = int(raw_data)
998
  except (TypeError, ValueError), err:
999
    logging.info("Can't parse pid file contents", exc_info=True)
1000
    return 0
1001

    
1002
  return pid
1003

    
1004

    
1005
def ReadLockedPidFile(path):
1006
  """Reads a locked PID file.
1007

1008
  This can be used together with L{StartDaemon}.
1009

1010
  @type path: string
1011
  @param path: Path to PID file
1012
  @return: PID as integer or, if file was unlocked or couldn't be opened, None
1013

1014
  """
1015
  try:
1016
    fd = os.open(path, os.O_RDONLY)
1017
  except EnvironmentError, err:
1018
    if err.errno == errno.ENOENT:
1019
      # PID file doesn't exist
1020
      return None
1021
    raise
1022

    
1023
  try:
1024
    try:
1025
      # Try to acquire lock
1026
      LockFile(fd)
1027
    except errors.LockError:
1028
      # Couldn't lock, daemon is running
1029
      return int(os.read(fd, 100))
1030
  finally:
1031
    os.close(fd)
1032

    
1033
  return None
1034

    
1035

    
1036
def MatchNameComponent(key, name_list, case_sensitive=True):
1037
  """Try to match a name against a list.
1038

1039
  This function will try to match a name like test1 against a list
1040
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
1041
  this list, I{'test1'} as well as I{'test1.example'} will match, but
1042
  not I{'test1.ex'}. A multiple match will be considered as no match
1043
  at all (e.g. I{'test1'} against C{['test1.example.com',
1044
  'test1.example.org']}), except when the key fully matches an entry
1045
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
1046

1047
  @type key: str
1048
  @param key: the name to be searched
1049
  @type name_list: list
1050
  @param name_list: the list of strings against which to search the key
1051
  @type case_sensitive: boolean
1052
  @param case_sensitive: whether to provide a case-sensitive match
1053

1054
  @rtype: None or str
1055
  @return: None if there is no match I{or} if there are multiple matches,
1056
      otherwise the element from the list which matches
1057

1058
  """
1059
  if key in name_list:
1060
    return key
1061

    
1062
  re_flags = 0
1063
  if not case_sensitive:
1064
    re_flags |= re.IGNORECASE
1065
    key = key.upper()
1066
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
1067
  names_filtered = []
1068
  string_matches = []
1069
  for name in name_list:
1070
    if mo.match(name) is not None:
1071
      names_filtered.append(name)
1072
      if not case_sensitive and key == name.upper():
1073
        string_matches.append(name)
1074

    
1075
  if len(string_matches) == 1:
1076
    return string_matches[0]
1077
  if len(names_filtered) == 1:
1078
    return names_filtered[0]
1079
  return None
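
# Example usage (illustrative sketch mirroring the docstring; the names are
# hypothetical):
#
#   MatchNameComponent("test1", ["test1.example.com", "test2.example.com"])
#   # -> "test1.example.com"
#   MatchNameComponent("test1", ["test1.example.com", "test1.example.org"])
#   # -> None (ambiguous match)
#   MatchNameComponent("test1", ["test1", "test1.example.com"])
#   # -> "test1" (full match wins)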
1080

    
1081

    
1082
class HostInfo:
1083
  """Class implementing resolver and hostname functionality
1084

1085
  """
1086
  _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$")
1087

    
1088
  def __init__(self, name=None):
1089
    """Initialize the host name object.
1090

1091
    If the name argument is not passed, it will use this system's
1092
    name.
1093

1094
    """
1095
    if name is None:
1096
      name = self.SysName()
1097

    
1098
    self.query = name
1099
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
1100
    self.ip = self.ipaddrs[0]
1101

    
1102
  def ShortName(self):
1103
    """Returns the hostname without domain.
1104

1105
    """
1106
    return self.name.split('.')[0]
1107

    
1108
  @staticmethod
1109
  def SysName():
1110
    """Return the current system's name.
1111

1112
    This is simply a wrapper over C{socket.gethostname()}.
1113

1114
    """
1115
    return socket.gethostname()
1116

    
1117
  @staticmethod
1118
  def LookupHostname(hostname):
1119
    """Look up hostname
1120

1121
    @type hostname: str
1122
    @param hostname: hostname to look up
1123

1124
    @rtype: tuple
1125
    @return: a tuple (name, aliases, ipaddrs) as returned by
1126
        C{socket.gethostbyname_ex}
1127
    @raise errors.ResolverError: in case of errors in resolving
1128

1129
    """
1130
    try:
1131
      result = socket.gethostbyname_ex(hostname)
1132
    except socket.gaierror, err:
1133
      # hostname not found in DNS
1134
      raise errors.ResolverError(hostname, err.args[0], err.args[1])
1135

    
1136
    return result
1137

    
1138
  @classmethod
1139
  def NormalizeName(cls, hostname):
1140
    """Validate and normalize the given hostname.
1141

1142
    @attention: the validation is a bit more relaxed than the standards
1143
        require; most importantly, we allow underscores in names
1144
    @raise errors.OpPrereqError: when the name is not valid
1145

1146
    """
1147
    hostname = hostname.lower()
1148
    if (not cls._VALID_NAME_RE.match(hostname) or
1149
        # double-dots, meaning empty label
1150
        ".." in hostname or
1151
        # empty initial label
1152
        hostname.startswith(".")):
1153
      raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
1154
                                 errors.ECODE_INVAL)
1155
    if hostname.endswith("."):
1156
      hostname = hostname.rstrip(".")
1157
    return hostname
1158

    
1159

    
1160
def ValidateServiceName(name):
1161
  """Validate the given service name.
1162

1163
  @type name: number or string
1164
  @param name: Service name or port specification
1165

1166
  """
1167
  try:
1168
    numport = int(name)
1169
  except (ValueError, TypeError):
1170
    # Non-numeric service name
1171
    valid = _VALID_SERVICE_NAME_RE.match(name)
1172
  else:
1173
    # Numeric port (protocols other than TCP or UDP might need adjustments
1174
    # here)
1175
    valid = (numport >= 0 and numport < (1 << 16))
1176

    
1177
  if not valid:
1178
    raise errors.OpPrereqError("Invalid service name '%s'" % name,
1179
                               errors.ECODE_INVAL)
1180

    
1181
  return name
1182

    
1183

    
1184
def GetHostInfo(name=None):
1185
  """Lookup host name and raise an OpPrereqError for failures"""
1186

    
1187
  try:
1188
    return HostInfo(name)
1189
  except errors.ResolverError, err:
1190
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
1191
                               (err[0], err[2]), errors.ECODE_RESOLVER)
1192

    
1193

    
1194
def ListVolumeGroups():
1195
  """List volume groups and their size
1196

1197
  @rtype: dict
1198
  @return:
1199
       Dictionary with the volume group name as key and
1200
       the size of the volume group as value
1201

1202
  """
1203
  command = "vgs --noheadings --units m --nosuffix -o name,size"
1204
  result = RunCmd(command)
1205
  retval = {}
1206
  if result.failed:
1207
    return retval
1208

    
1209
  for line in result.stdout.splitlines():
1210
    try:
1211
      name, size = line.split()
1212
      size = int(float(size))
1213
    except (IndexError, ValueError), err:
1214
      logging.error("Invalid output from vgs (%s): %s", err, line)
1215
      continue
1216

    
1217
    retval[name] = size
1218

    
1219
  return retval
1220

    
1221

    
1222
def BridgeExists(bridge):
1223
  """Check whether the given bridge exists in the system
1224

1225
  @type bridge: str
1226
  @param bridge: the bridge name to check
1227
  @rtype: boolean
1228
  @return: True if it does
1229

1230
  """
1231
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
1232

    
1233

    
1234
def NiceSort(name_list):
1235
  """Sort a list of strings based on digit and non-digit groupings.
1236

1237
  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
1238
  will sort the list in the logical order C{['a1', 'a2', 'a10',
1239
  'a11']}.
1240

1241
  The sort algorithm breaks each name in groups of either only-digits
1242
  or no-digits. Only the first eight such groups are considered, and
1243
  after that we just use what's left of the string.
1244

1245
  @type name_list: list
1246
  @param name_list: the names to be sorted
1247
  @rtype: list
1248
  @return: a copy of the name list sorted with our algorithm
1249

1250
  """
1251
  _SORTER_BASE = "(\D+|\d+)"
1252
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
1253
                                                  _SORTER_BASE, _SORTER_BASE,
1254
                                                  _SORTER_BASE, _SORTER_BASE,
1255
                                                  _SORTER_BASE, _SORTER_BASE)
1256
  _SORTER_RE = re.compile(_SORTER_FULL)
1257
  _SORTER_NODIGIT = re.compile("^\D*$")
1258
  def _TryInt(val):
1259
    """Attempts to convert a variable to integer."""
1260
    if val is None or _SORTER_NODIGIT.match(val):
1261
      return val
1262
    rval = int(val)
1263
    return rval
1264

    
1265
  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
1266
             for name in name_list]
1267
  to_sort.sort()
1268
  return [tup[1] for tup in to_sort]
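
# Example usage (matches the docstring's own example):
#
#   NiceSort(["a1", "a10", "a11", "a2"]) == ["a1", "a2", "a10", "a11"]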
1269

    
1270

    
1271
def TryConvert(fn, val):
1272
  """Try to convert a value ignoring errors.
1273

1274
  This function tries to apply function I{fn} to I{val}. If no
1275
  C{ValueError} or C{TypeError} exceptions are raised, it will return
1276
  the result, else it will return the original value. Any other
1277
  exceptions are propagated to the caller.
1278

1279
  @type fn: callable
1280
  @param fn: function to apply to the value
1281
  @param val: the value to be converted
1282
  @return: The converted value if the conversion was successful,
1283
      otherwise the original value.
1284

1285
  """
1286
  try:
1287
    nv = fn(val)
1288
  except (ValueError, TypeError):
1289
    nv = val
1290
  return nv
1291

    
1292

    
1293
def IsValidIP(ip):
1294
  """Verifies the syntax of an IPv4 address.
1295

1296
  This function checks if the passed IPv4 address is valid or not based
1297
  on syntax (not IP range, class calculations, etc.).
1298

1299
  @type ip: str
1300
  @param ip: the address to be checked
1301
  @rtype: a regular expression match object
1302
  @return: a regular expression match object, or None if the
1303
      address is not valid
1304

1305
  """
1306
  unit = "(0|[1-9]\d{0,2})"
1307
  #TODO: convert and return only boolean
1308
  return re.match("^%s\.%s\.%s\.%s$" % (unit, unit, unit, unit), ip)
1309

    
1310

    
1311
def IsValidShellParam(word):
1312
  """Verifies is the given word is safe from the shell's p.o.v.
1313

1314
  This means that we can pass this to a command via the shell and be
1315
  sure that it doesn't alter the command line and is passed as such to
1316
  the actual command.
1317

1318
  Note that we are overly restrictive here, in order to be on the safe
1319
  side.
1320

1321
  @type word: str
1322
  @param word: the word to check
1323
  @rtype: boolean
1324
  @return: True if the word is 'safe'
1325

1326
  """
1327
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))
1328

    
1329

    
1330
def BuildShellCmd(template, *args):
1331
  """Build a safe shell command line from the given arguments.
1332

1333
  This function will check all arguments in the args list so that they
1334
  are valid shell parameters (i.e. they don't contain shell
1335
  metacharacters). If everything is ok, it will return the result of
1336
  template % args.
1337

1338
  @type template: str
1339
  @param template: the string holding the template for the
1340
      string formatting
1341
  @rtype: str
1342
  @return: the expanded command line
1343

1344
  """
1345
  for word in args:
1346
    if not IsValidShellParam(word):
1347
      raise errors.ProgrammerError("Shell argument '%s' contains"
1348
                                   " invalid characters" % word)
1349
  return template % args
1350

    
1351

    
1352
def FormatUnit(value, units):
1353
  """Formats an incoming number of MiB with the appropriate unit.
1354

1355
  @type value: int
1356
  @param value: integer representing the value in MiB (1048576)
1357
  @type units: char
1358
  @param units: the type of formatting we should do:
1359
      - 'h' for automatic scaling
1360
      - 'm' for MiBs
1361
      - 'g' for GiBs
1362
      - 't' for TiBs
1363
  @rtype: str
1364
  @return: the formatted value (with suffix)
1365

1366
  """
1367
  if units not in ('m', 'g', 't', 'h'):
1368
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))
1369

    
1370
  suffix = ''
1371

    
1372
  if units == 'm' or (units == 'h' and value < 1024):
1373
    if units == 'h':
1374
      suffix = 'M'
1375
    return "%d%s" % (round(value, 0), suffix)
1376

    
1377
  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
1378
    if units == 'h':
1379
      suffix = 'G'
1380
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)
1381

    
1382
  else:
1383
    if units == 'h':
1384
      suffix = 'T'
1385
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)
1386

    
1387

    
1388
def ParseUnit(input_string):
1389
  """Tries to extract number and scale from the given string.
1390

1391
  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
1392
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
1393
  is always an int in MiB.
1394

1395
  """
1396
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
1397
  if not m:
1398
    raise errors.UnitParseError("Invalid format")
1399

    
1400
  value = float(m.groups()[0])
1401

    
1402
  unit = m.groups()[1]
1403
  if unit:
1404
    lcunit = unit.lower()
1405
  else:
1406
    lcunit = 'm'
1407

    
1408
  if lcunit in ('m', 'mb', 'mib'):
1409
    # Value already in MiB
1410
    pass
1411

    
1412
  elif lcunit in ('g', 'gb', 'gib'):
1413
    value *= 1024
1414

    
1415
  elif lcunit in ('t', 'tb', 'tib'):
1416
    value *= 1024 * 1024
1417

    
1418
  else:
1419
    raise errors.UnitParseError("Unknown unit: %s" % unit)
1420

    
1421
  # Make sure we round up
1422
  if int(value) < value:
1423
    value += 1
1424

    
1425
  # Round up to the next multiple of 4
1426
  value = int(value)
1427
  if value % 4:
1428
    value += 4 - value % 4
1429

    
1430
  return value
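
# Example usage (illustrative sketch; results are in MiB and rounded up to a
# multiple of 4 as described above):
#
#   ParseUnit("128")  == 128    # MiB is the default unit
#   ParseUnit("1g")   == 1024
#   ParseUnit("1.3G") == 1332   # 1331.2 MiB rounded up to a multiple of 4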
1431

    
1432

    
1433
def AddAuthorizedKey(file_name, key):
1434
  """Adds an SSH public key to an authorized_keys file.
1435

1436
  @type file_name: str
1437
  @param file_name: path to authorized_keys file
1438
  @type key: str
1439
  @param key: string containing key
1440

1441
  """
1442
  key_fields = key.split()
1443

    
1444
  f = open(file_name, 'a+')
1445
  try:
1446
    nl = True
1447
    for line in f:
1448
      # Ignore whitespace changes
1449
      if line.split() == key_fields:
1450
        break
1451
      nl = line.endswith('\n')
1452
    else:
1453
      if not nl:
1454
        f.write("\n")
1455
      f.write(key.rstrip('\r\n'))
1456
      f.write("\n")
1457
      f.flush()
1458
  finally:
1459
    f.close()
1460

    
1461

    
1462
def RemoveAuthorizedKey(file_name, key):
1463
  """Removes an SSH public key from an authorized_keys file.
1464

1465
  @type file_name: str
1466
  @param file_name: path to authorized_keys file
1467
  @type key: str
1468
  @param key: string containing key
1469

1470
  """
1471
  key_fields = key.split()
1472

    
1473
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1474
  try:
1475
    out = os.fdopen(fd, 'w')
1476
    try:
1477
      f = open(file_name, 'r')
1478
      try:
1479
        for line in f:
1480
          # Ignore whitespace changes while comparing lines
1481
          if line.split() != key_fields:
1482
            out.write(line)
1483

    
1484
        out.flush()
1485
        os.rename(tmpname, file_name)
1486
      finally:
1487
        f.close()
1488
    finally:
1489
      out.close()
1490
  except:
1491
    RemoveFile(tmpname)
1492
    raise
1493

    
1494

    
1495
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
1496
  """Sets the name of an IP address and hostname in /etc/hosts.
1497

1498
  @type file_name: str
1499
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1500
  @type ip: str
1501
  @param ip: the IP address
1502
  @type hostname: str
1503
  @param hostname: the hostname to be added
1504
  @type aliases: list
1505
  @param aliases: the list of aliases to add for the hostname
1506

1507
  """
1508
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1509
  # Ensure aliases are unique
1510
  aliases = UniqueSequence([hostname] + aliases)[1:]
1511

    
1512
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1513
  try:
1514
    out = os.fdopen(fd, 'w')
1515
    try:
1516
      f = open(file_name, 'r')
1517
      try:
1518
        for line in f:
1519
          fields = line.split()
1520
          if fields and not fields[0].startswith('#') and ip == fields[0]:
1521
            continue
1522
          out.write(line)
1523

    
1524
        out.write("%s\t%s" % (ip, hostname))
1525
        if aliases:
1526
          out.write(" %s" % ' '.join(aliases))
1527
        out.write('\n')
1528

    
1529
        out.flush()
1530
        os.fsync(out)
1531
        os.chmod(tmpname, 0644)
1532
        os.rename(tmpname, file_name)
1533
      finally:
1534
        f.close()
1535
    finally:
1536
      out.close()
1537
  except:
1538
    RemoveFile(tmpname)
1539
    raise
1540

    
1541

    
1542
def AddHostToEtcHosts(hostname):
1543
  """Wrapper around SetEtcHostsEntry.
1544

1545
  @type hostname: str
1546
  @param hostname: a hostname that will be resolved and added to
1547
      L{constants.ETC_HOSTS}
1548

1549
  """
1550
  hi = HostInfo(name=hostname)
1551
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])
1552

    
1553

    
1554
def RemoveEtcHostsEntry(file_name, hostname):
1555
  """Removes a hostname from /etc/hosts.
1556

1557
  IP addresses without names are removed from the file.
1558

1559
  @type file_name: str
1560
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1561
  @type hostname: str
1562
  @param hostname: the hostname to be removed
1563

1564
  """
1565
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1566
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1567
  try:
1568
    out = os.fdopen(fd, 'w')
1569
    try:
1570
      f = open(file_name, 'r')
1571
      try:
1572
        for line in f:
1573
          fields = line.split()
1574
          if len(fields) > 1 and not fields[0].startswith('#'):
1575
            names = fields[1:]
1576
            if hostname in names:
1577
              while hostname in names:
1578
                names.remove(hostname)
1579
              if names:
1580
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
1581
              continue
1582

    
1583
          out.write(line)
1584

    
1585
        out.flush()
1586
        os.fsync(out)
1587
        os.chmod(tmpname, 0644)
1588
        os.rename(tmpname, file_name)
1589
      finally:
1590
        f.close()
1591
    finally:
1592
      out.close()
1593
  except:
1594
    RemoveFile(tmpname)
1595
    raise
1596

    
1597

    
1598
def RemoveHostFromEtcHosts(hostname):
1599
  """Wrapper around RemoveEtcHostsEntry.
1600

1601
  @type hostname: str
1602
  @param hostname: hostname that will be resolved and its
1603
      full and short name will be removed from
1604
      L{constants.ETC_HOSTS}
1605

1606
  """
1607
  hi = HostInfo(name=hostname)
1608
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
1609
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())
1610

    
1611

    
1612
def TimestampForFilename():
1613
  """Returns the current time formatted for filenames.
1614

1615
  The format doesn't contain colons, as some shells and applications use them as
1616
  separators.
1617

1618
  """
1619
  return time.strftime("%Y-%m-%d_%H_%M_%S")
1620

    
1621

    
1622
def CreateBackup(file_name):
1623
  """Creates a backup of a file.
1624

1625
  @type file_name: str
1626
  @param file_name: file to be backed up
1627
  @rtype: str
1628
  @return: the path to the newly created backup
1629
  @raise errors.ProgrammerError: for invalid file names
1630

1631
  """
1632
  if not os.path.isfile(file_name):
1633
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
1634
                                file_name)
1635

    
1636
  prefix = ("%s.backup-%s." %
1637
            (os.path.basename(file_name), TimestampForFilename()))
1638
  dir_name = os.path.dirname(file_name)
1639

    
1640
  fsrc = open(file_name, 'rb')
1641
  try:
1642
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
1643
    fdst = os.fdopen(fd, 'wb')
1644
    try:
1645
      logging.debug("Backing up %s at %s", file_name, backup_name)
1646
      shutil.copyfileobj(fsrc, fdst)
1647
    finally:
1648
      fdst.close()
1649
  finally:
1650
    fsrc.close()
1651

    
1652
  return backup_name
1653

    
1654

    
1655
def ShellQuote(value):
1656
  """Quotes shell argument according to POSIX.
1657

1658
  @type value: str
1659
  @param value: the argument to be quoted
1660
  @rtype: str
1661
  @return: the quoted value
1662

1663
  """
1664
  if _re_shell_unquoted.match(value):
1665
    return value
1666
  else:
1667
    return "'%s'" % value.replace("'", "'\\''")
1668

    
1669

    
1670
def ShellQuoteArgs(args):
1671
  """Quotes a list of shell arguments.
1672

1673
  @type args: list
1674
  @param args: list of arguments to be quoted
1675
  @rtype: str
1676
  @return: the quoted arguments concatenated with spaces
1677

1678
  """
1679
  return ' '.join([ShellQuote(i) for i in args])
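
# Example usage (illustrative sketch): arguments that are not already
# shell-safe get single-quoted.
#
#   ShellQuoteArgs(["echo", "hello world"]) == "echo 'hello world'"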
1680

    
1681

    
1682
def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
1683
  """Simple ping implementation using TCP connect(2).
1684

1685
  Check if the given IP is reachable by attempting a TCP connect
1686
  to it.
1687

1688
  @type target: str
1689
  @param target: the IP or hostname to ping
1690
  @type port: int
1691
  @param port: the port to connect to
1692
  @type timeout: int
1693
  @param timeout: the timeout on the connection attempt
1694
  @type live_port_needed: boolean
1695
  @param live_port_needed: whether a closed port will cause the
1696
      function to return failure, as if there was a timeout
1697
  @type source: str or None
1698
  @param source: if specified, will cause the connect to be made
1699
      from this specific source address; failures to bind other
1700
      than C{EADDRNOTAVAIL} will be ignored
1701

1702
  """
1703
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1704

    
1705
  success = False
1706

    
1707
  if source is not None:
1708
    try:
1709
      sock.bind((source, 0))
1710
    except socket.error, (errcode, _):
1711
      if errcode == errno.EADDRNOTAVAIL:
1712
        success = False
1713

    
1714
  sock.settimeout(timeout)
1715

    
1716
  try:
1717
    sock.connect((target, port))
1718
    sock.close()
1719
    success = True
1720
  except socket.timeout:
1721
    success = False
1722
  except socket.error, (errcode, _):
1723
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
1724

    
1725
  return success
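
# Example usage (illustrative sketch; the address is a documentation-only
# IP): check whether a node's daemon port answers, treating a closed port as
# a failure.
#
#   if TcpPing("192.0.2.10", constants.DEFAULT_NODED_PORT,
#              timeout=5, live_port_needed=True):
#     logging.info("node is reachable")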
1726

    
1727

    
1728
def OwnIpAddress(address):
1729
  """Check if the current host has the the given IP address.
1730

1731
  Currently this is done by TCP-pinging the address from the loopback
1732
  address.
1733

1734
  @type address: string
1735
  @param address: the address to check
1736
  @rtype: bool
1737
  @return: True if we own the address
1738

1739
  """
1740
  return TcpPing(address, constants.DEFAULT_NODED_PORT,
1741
                 source=constants.LOCALHOST_IP_ADDRESS)
1742

    
1743

    
1744
def ListVisibleFiles(path, sort=True):
1745
  """Returns a list of visible files in a directory.
1746

1747
  @type path: str
1748
  @param path: the directory to enumerate
1749
  @type sort: boolean
1750
  @param sort: whether to provide a sorted output
1751
  @rtype: list
1752
  @return: the list of all files not starting with a dot
1753
  @raise ProgrammerError: if L{path} is not an absolute and normalized path
1754

1755
  """
1756
  if not IsNormAbsPath(path):
1757
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
1758
                                 " absolute/normalized: '%s'" % path)
1759
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1760
  if sort:
1761
    files.sort()
1762
  return files
1763

    
1764

    
1765


def GetHomeDir(user, default=None):
  """Try to get the homedir of the given user.

  The user can be passed either as a string (denoting the name) or as
  an integer (denoting the user id). If the user is not found, the
  'default' argument is returned, which defaults to None.

  """
  try:
    if isinstance(user, basestring):
      result = pwd.getpwnam(user)
    elif isinstance(user, (int, long)):
      result = pwd.getpwuid(user)
    else:
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
                                   type(user))
  except KeyError:
    return default
  return result.pw_dir


def NewUUID():
  """Returns a random UUID.

  @note: This is a Linux-specific method as it uses the /proc
      filesystem.
  @rtype: str

  """
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")


def GenerateSecret(numbytes=20):
  """Generates a random secret.

  This will generate a pseudo-random secret returning a hex string
  (so that it can be used where an ASCII string is needed).

  @param numbytes: the number of bytes which will be represented by the
      returned string (defaulting to 20, the length of a SHA1 hash)
  @rtype: str
  @return: a hex representation of the pseudo-random sequence

  """
  return os.urandom(numbytes).encode('hex')


def EnsureDirs(dirs):
  """Make required directories, if they don't exist.

  @param dirs: list of tuples (dir_name, dir_mode)
  @type dirs: list of (string, integer)

  """
  for dir_name, dir_mode in dirs:
    try:
      os.mkdir(dir_name, dir_mode)
    except EnvironmentError, err:
      if err.errno != errno.EEXIST:
        raise errors.GenericError("Cannot create needed directory"
                                  " '%s': %s" % (dir_name, err))
    try:
      os.chmod(dir_name, dir_mode)
    except EnvironmentError, err:
      raise errors.GenericError("Cannot change directory permissions on"
                                " '%s': %s" % (dir_name, err))
    if not os.path.isdir(dir_name):
      raise errors.GenericError("%s is not a directory" % dir_name)


def ReadFile(file_name, size=-1):
  """Reads a file.

  @type size: int
  @param size: Read at most size bytes (if negative, entire file)
  @rtype: str
  @return: the (possibly partial) content of the file

  """
  f = open(file_name, "r")
  try:
    return f.read(size)
  finally:
    f.close()


def WriteFile(file_name, fn=None, data=None,
              mode=None, uid=-1, gid=-1,
              atime=None, mtime=None, close=True,
              dry_run=False, backup=False,
              prewrite=None, postwrite=None):
  """(Over)write a file atomically.

  The file_name and either fn (a function taking one argument, the
  file descriptor, and which should write the data to it) or data (the
  contents of the file) must be passed. The other arguments are
  optional and allow setting the file mode, owner and group, and the
  mtime/atime of the file.

  If the function doesn't raise an exception, it has succeeded and the
  target file has the new contents. If the function has raised an
  exception, an existing target file should be unmodified and the
  temporary file should be removed.

  @type file_name: str
  @param file_name: the target filename
  @type fn: callable
  @param fn: content writing function, called with
      file descriptor as parameter
  @type data: str
  @param data: contents of the file
  @type mode: int
  @param mode: file mode
  @type uid: int
  @param uid: the owner of the file
  @type gid: int
  @param gid: the group of the file
  @type atime: int
  @param atime: a custom access time to be set on the file
  @type mtime: int
  @param mtime: a custom modification time to be set on the file
  @type close: boolean
  @param close: whether to close file after writing it
  @type prewrite: callable
  @param prewrite: function to be called before writing content
  @type postwrite: callable
  @param postwrite: function to be called after writing content

  @rtype: None or int
  @return: None if the 'close' parameter evaluates to True,
      otherwise the file descriptor

  @raise errors.ProgrammerError: if any of the arguments are not valid

  """
  if not os.path.isabs(file_name):
    raise errors.ProgrammerError("Path passed to WriteFile is not"
                                 " absolute: '%s'" % file_name)

  if [fn, data].count(None) != 1:
    raise errors.ProgrammerError("fn or data required")

  if [atime, mtime].count(None) == 1:
    raise errors.ProgrammerError("Both atime and mtime must be either"
                                 " set or None")

  if backup and not dry_run and os.path.isfile(file_name):
    CreateBackup(file_name)

  dir_name, base_name = os.path.split(file_name)
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
  do_remove = True
  # here we need to make sure we remove the temp file, if any error
  # leaves it in place
  try:
    if uid != -1 or gid != -1:
      os.chown(new_name, uid, gid)
    if mode:
      os.chmod(new_name, mode)
    if callable(prewrite):
      prewrite(fd)
    if data is not None:
      os.write(fd, data)
    else:
      fn(fd)
    if callable(postwrite):
      postwrite(fd)
    os.fsync(fd)
    if atime is not None and mtime is not None:
      os.utime(new_name, (atime, mtime))
    if not dry_run:
      os.rename(new_name, file_name)
      do_remove = False
  finally:
    if close:
      os.close(fd)
      result = None
    else:
      result = fd
    if do_remove:
      RemoveFile(new_name)

  return result
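
# Illustrative use of WriteFile; the path and contents are made up.
# This atomically replaces the target, setting its mode in the same call:
#
#   WriteFile("/tmp/example-flag", data="enabled\n", mode=0644, backup=True)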


def ReadOneLineFile(file_name, strict=False):
  """Return the first non-empty line from a file.

  @type strict: boolean
  @param strict: if True, abort if the file has more than one
      non-empty line

  """
  file_lines = ReadFile(file_name).splitlines()
  full_lines = filter(bool, file_lines)
  if not file_lines or not full_lines:
    raise errors.GenericError("No data in one-liner file %s" % file_name)
  elif strict and len(full_lines) > 1:
    raise errors.GenericError("Too many lines in one-liner file %s" %
                              file_name)
  return full_lines[0]


def FirstFree(seq, base=0):
  """Returns the first non-existing integer from seq.

  The seq argument should be a sorted list of positive integers. The
  first time the index of an element is smaller than the element
  value, the index will be returned.

  The base argument is used to start at a different offset,
  i.e. C{[3, 4, 6]} with I{base=3} will return 5.

  Example: C{[0, 1, 3]} will return I{2}.

  @type seq: sequence
  @param seq: the sequence to be analyzed.
  @type base: int
  @param base: use this value as the base index of the sequence
  @rtype: int
  @return: the first non-used index in the sequence

  """
  for idx, elem in enumerate(seq):
    assert elem >= base, "Passed element is higher than base offset"
    if elem > idx + base:
      # idx is not used
      return idx + base
  return None
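
# The docstring examples above, restated as calls:
#
#   FirstFree([0, 1, 3])          # -> 2
#   FirstFree([3, 4, 6], base=3)  # -> 5
#   FirstFree([0, 1, 2])          # -> None (no gap in the sequence)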


def SingleWaitForFdCondition(fdobj, event, timeout):
  """Waits for a condition to occur on the socket.

  Immediately returns at the first interruption.

  @type fdobj: integer or object supporting a fileno() method
  @param fdobj: entity to wait for events on
  @type event: integer
  @param event: ORed condition (see select module)
  @type timeout: float or None
  @param timeout: Timeout in seconds
  @rtype: int or None
  @return: None for timeout, otherwise occurred conditions

  """
  check = (event | select.POLLPRI |
           select.POLLNVAL | select.POLLHUP | select.POLLERR)

  if timeout is not None:
    # Poller object expects milliseconds
    timeout *= 1000

  poller = select.poll()
  poller.register(fdobj, event)
  try:
    # TODO: If the main thread receives a signal and we have no timeout, we
    # could wait forever. This should check a global "quit" flag or something
    # every so often.
    io_events = poller.poll(timeout)
  except select.error, err:
    if err[0] != errno.EINTR:
      raise
    io_events = []
  if io_events and io_events[0][1] & check:
    return io_events[0][1]
  else:
    return None


class FdConditionWaiterHelper(object):
  """Retry helper for WaitForFdCondition.

  This class contains the retried function and the wait function that make
  sure WaitForFdCondition can continue waiting until the timeout has
  actually expired.

  """

  def __init__(self, timeout):
    self.timeout = timeout

  def Poll(self, fdobj, event):
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
    if result is None:
      raise RetryAgain()
    else:
      return result

  def UpdateTimeout(self, timeout):
    self.timeout = timeout


def WaitForFdCondition(fdobj, event, timeout):
  """Waits for a condition to occur on the socket.

  Retries until the timeout is expired, even if interrupted.

  @type fdobj: integer or object supporting a fileno() method
  @param fdobj: entity to wait for events on
  @type event: integer
  @param event: ORed condition (see select module)
  @type timeout: float or None
  @param timeout: Timeout in seconds
  @rtype: int or None
  @return: None for timeout, otherwise occurred conditions

  """
  if timeout is not None:
    retrywaiter = FdConditionWaiterHelper(timeout)
    try:
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
    except RetryTimeout:
      result = None
  else:
    result = None
    while result is None:
      result = SingleWaitForFdCondition(fdobj, event, timeout)
  return result


def UniqueSequence(seq):
  """Returns a list with unique elements.

  Element order is preserved.

  @type seq: sequence
  @param seq: the sequence with the source elements
  @rtype: list
  @return: list of unique elements from seq

  """
  seen = set()
  return [i for i in seq if i not in seen and not seen.add(i)]
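
# Example: only the first occurrence of each element is kept:
#
#   UniqueSequence([1, 2, 1, 3, 2])  # -> [1, 2, 3]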


def NormalizeAndValidateMac(mac):
  """Normalizes and checks whether a MAC address is valid.

  Checks whether the supplied MAC address is formally correct; only
  the colon-separated format is accepted. Normalizes it to all
  lowercase.

  @type mac: str
  @param mac: the MAC to be validated
  @rtype: str
  @return: returns the normalized and validated MAC.

  @raise errors.OpPrereqError: If the MAC isn't valid

  """
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
  if not mac_check.match(mac):
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
                               mac, errors.ECODE_INVAL)

  return mac.lower()
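
# Example: a colon-separated MAC is accepted and lower-cased, while any
# other separator is rejected:
#
#   NormalizeAndValidateMac("AA:00:99:0F:EE:01")  # -> "aa:00:99:0f:ee:01"
#   NormalizeAndValidateMac("aa-00-99-0f-ee-01")  # raises errors.OpPrereqError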


def TestDelay(duration):
  """Sleep for a fixed amount of time.

  @type duration: float
  @param duration: the sleep duration
  @rtype: tuple
  @return: (False, error message) for a negative duration,
      (True, None) otherwise

  """
  if duration < 0:
    return False, "Invalid sleep duration"
  time.sleep(duration)
  return True, None


def _CloseFDNoErr(fd, retries=5):
  """Close a file descriptor ignoring errors.

  @type fd: int
  @param fd: the file descriptor
  @type retries: int
  @param retries: how many retries to make, in case we get any
      other error than EBADF

  """
  try:
    os.close(fd)
  except OSError, err:
    if err.errno != errno.EBADF:
      if retries > 0:
        _CloseFDNoErr(fd, retries - 1)
    # else either it's closed already or we're out of retries, so we
    # ignore this and go on


def CloseFDs(noclose_fds=None):
  """Close file descriptors.

  This closes all file descriptors above 2 (i.e. except
  stdin/out/err).

  @type noclose_fds: list or None
  @param noclose_fds: if given, it denotes a list of file descriptor
      that should not be closed

  """
  # Default maximum for the number of available file descriptors.
  if 'SC_OPEN_MAX' in os.sysconf_names:
    try:
      MAXFD = os.sysconf('SC_OPEN_MAX')
      if MAXFD < 0:
        MAXFD = 1024
    except OSError:
      MAXFD = 1024
  else:
    MAXFD = 1024
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
  if (maxfd == resource.RLIM_INFINITY):
    maxfd = MAXFD

  # Iterate through and close all file descriptors (except the standard ones)
  for fd in range(3, maxfd):
    if noclose_fds and fd in noclose_fds:
      continue
    _CloseFDNoErr(fd)


def Mlockall():
  """Lock current process' virtual address space into RAM.

  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
  see mlock(2) for more details. This function requires the ctypes
  module.

  """
  if ctypes is None:
    logging.warning("Cannot set memory lock, ctypes module not found")
    return

  libc = ctypes.cdll.LoadLibrary("libc.so.6")
  if libc is None:
    logging.error("Cannot set memory lock, ctypes cannot load libc")
    return

  # Some older versions of the ctypes module don't have built-in functionality
  # to access the errno global variable, where function error codes are stored.
  # By declaring this variable as a pointer to an integer we can then access
  # its value correctly, should the mlockall call fail, in order to see what
  # the actual error code was.
  # pylint: disable-msg=W0212
  libc.__errno_location.restype = ctypes.POINTER(ctypes.c_int)

  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
    # pylint: disable-msg=W0212
    logging.error("Cannot set memory lock: %s",
                  os.strerror(libc.__errno_location().contents.value))
    return

  logging.debug("Memory lock set")


def Daemonize(logfile, run_uid, run_gid):
  """Daemonize the current process.

  This detaches the current process from the controlling terminal and
  runs it in the background as a daemon.

  @type logfile: str
  @param logfile: the logfile to which we should redirect stdout/stderr
  @type run_uid: int
  @param run_uid: Run the child under this uid
  @type run_gid: int
  @param run_gid: Run the child under this gid
  @rtype: int
  @return: the value zero

  """
  # pylint: disable-msg=W0212
  # yes, we really want os._exit
  UMASK = 077
  WORKDIR = "/"

  # this might fail
  pid = os.fork()
  if (pid == 0):  # The first child.
    os.setsid()
    # FIXME: When removing again and moving to start-stop-daemon privilege drop
    #        make sure to check for config permission and bail out when invoked
    #        with wrong user.
    os.setgid(run_gid)
    os.setuid(run_uid)
    # this might fail
    pid = os.fork() # Fork a second child.
    if (pid == 0):  # The second child.
      os.chdir(WORKDIR)
      os.umask(UMASK)
    else:
      # exit() or _exit()?  See below.
      os._exit(0) # Exit parent (the first child) of the second child.
  else:
    os._exit(0) # Exit parent of the first child.

  for fd in range(3):
    _CloseFDNoErr(fd)
  i = os.open("/dev/null", os.O_RDONLY) # stdin
  assert i == 0, "Can't close/reopen stdin"
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
  assert i == 1, "Can't close/reopen stdout"
  # Duplicate standard output to standard error.
  os.dup2(1, 2)
  return 0


def DaemonPidFileName(name):
  """Compute a ganeti pid file absolute path.

  @type name: str
  @param name: the daemon name
  @rtype: str
  @return: the full path to the pidfile corresponding to the given
      daemon name

  """
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)


def EnsureDaemon(name):
  """Check for and start daemon if not alive.

  """
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
  if result.failed:
    logging.error("Can't start daemon '%s', failure %s, output: %s",
                  name, result.fail_reason, result.output)
    return False

  return True


def StopDaemon(name):
  """Stop a daemon.

  """
  result = RunCmd([constants.DAEMON_UTIL, "stop", name])
  if result.failed:
    logging.error("Can't stop daemon '%s', failure %s, output: %s",
                  name, result.fail_reason, result.output)
    return False

  return True


def WritePidFile(name):
  """Write the current process pidfile.

  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}

  @type name: str
  @param name: the daemon name to use
  @raise errors.GenericError: if the pid file already exists and
      points to a live process

  """
  pid = os.getpid()
  pidfilename = DaemonPidFileName(name)
  if IsProcessAlive(ReadPidFile(pidfilename)):
    raise errors.GenericError("%s contains a live process" % pidfilename)

  WriteFile(pidfilename, data="%d\n" % pid)


def RemovePidFile(name):
  """Remove the current process pidfile.

  Any errors are ignored.

  @type name: str
  @param name: the daemon name used to derive the pidfile name

  """
  pidfilename = DaemonPidFileName(name)
  # TODO: we could check here that the file contains our pid
  try:
    RemoveFile(pidfilename)
  except: # pylint: disable-msg=W0702
    pass


def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
                waitpid=False):
  """Kill a process given by its pid.

  @type pid: int
  @param pid: The PID to terminate.
  @type signal_: int
  @param signal_: The signal to send, by default SIGTERM
  @type timeout: int
  @param timeout: The timeout after which, if the process is still alive,
                  a SIGKILL will be sent. If not positive, no such checking
                  will be done
  @type waitpid: boolean
  @param waitpid: If true, we should waitpid on this process after
      sending signals, since it's our own child and otherwise it
      would remain as zombie

  """
  def _helper(pid, signal_, wait):
    """Simple helper to encapsulate the kill/waitpid sequence"""
    if IgnoreProcessNotFound(os.kill, pid, signal_) and wait:
      try:
        os.waitpid(pid, os.WNOHANG)
      except OSError:
        pass

  if pid <= 0:
    # kill with pid=0 == suicide
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)

  if not IsProcessAlive(pid):
    return

  _helper(pid, signal_, waitpid)

  if timeout <= 0:
    return

  def _CheckProcess():
    if not IsProcessAlive(pid):
      return

    try:
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
    except OSError:
      raise RetryAgain()

    if result_pid > 0:
      return

    raise RetryAgain()

  try:
    # Wait up to $timeout seconds
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
  except RetryTimeout:
    pass

  if IsProcessAlive(pid):
    # Kill process if it's still alive
    _helper(pid, signal.SIGKILL, waitpid)


def FindFile(name, search_path, test=os.path.exists):
  """Look for a filesystem object in a given path.

  This is an abstract method to search for a filesystem object (files,
  dirs) under a given search path.

  @type name: str
  @param name: the name to look for
  @type search_path: list of str
  @param search_path: the directories in which to search
  @type test: callable
  @param test: a function taking one argument that should return True
      if a given object is valid; the default value is
      os.path.exists, causing only existing files to be returned
  @rtype: str or None
  @return: full path to the object if found, None otherwise

  """
  # validate the filename mask
  if constants.EXT_PLUGIN_MASK.match(name) is None:
    logging.critical("Invalid value passed for external script name: '%s'",
                     name)
    return None

  for dir_name in search_path:
    # FIXME: investigate switch to PathJoin
    item_name = os.path.sep.join([dir_name, name])
    # check the user test and that we're indeed resolving to the given
    # basename
    if test(item_name) and os.path.basename(item_name) == name:
      return item_name
  return None


def CheckVolumeGroupSize(vglist, vgname, minsize):
  """Checks if the volume group list is valid.

  The function will check if a given volume group is in the list of
  volume groups and has a minimum size.

  @type vglist: dict
  @param vglist: dictionary of volume group names and their size
  @type vgname: str
  @param vgname: the volume group we should check
  @type minsize: int
  @param minsize: the minimum size we accept
  @rtype: None or str
  @return: None for success, otherwise the error message

  """
  vgsize = vglist.get(vgname, None)
  if vgsize is None:
    return "volume group '%s' missing" % vgname
  elif vgsize < minsize:
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
            (vgname, minsize, vgsize))
  return None


def SplitTime(value):
  """Splits time as floating point number into a tuple.

  @param value: Time in seconds
  @type value: int or float
  @return: Tuple containing (seconds, microseconds)

  """
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)

  assert 0 <= seconds, \
    "Seconds must be larger than or equal to 0, but are %s" % seconds
  assert 0 <= microseconds <= 999999, \
    "Microseconds must be 0-999999, but are %s" % microseconds

  return (int(seconds), int(microseconds))


def MergeTime(timetuple):
  """Merges a tuple into time as a floating point number.

  @param timetuple: Time as tuple, (seconds, microseconds)
  @type timetuple: tuple
  @return: Time as a floating point number expressed in seconds

  """
  (seconds, microseconds) = timetuple

  assert 0 <= seconds, \
    "Seconds must be larger than or equal to 0, but are %s" % seconds
  assert 0 <= microseconds <= 999999, \
    "Microseconds must be 0-999999, but are %s" % microseconds

  return float(seconds) + (float(microseconds) * 0.000001)
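
# SplitTime and MergeTime are inverses of each other (at microsecond
# resolution), e.g.:
#
#   SplitTime(1.5)          # -> (1, 500000)
#   MergeTime((1, 500000))  # -> 1.5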


def GetDaemonPort(daemon_name):
  """Get the daemon port for this cluster.

  Note that this routine does not read a ganeti-specific file, but
  instead uses C{socket.getservbyname} to allow pre-customization of
  this parameter outside of Ganeti.

  @type daemon_name: string
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
  @rtype: int

  """
  if daemon_name not in constants.DAEMONS_PORTS:
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)

  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
  try:
    port = socket.getservbyname(daemon_name, proto)
  except socket.error:
    port = default_port

  return port


class LogFileHandler(logging.FileHandler):
  """Log handler that doesn't fall back to stderr.

  When an error occurs while writing to the logfile, logging.FileHandler
  tries to log on stderr. This doesn't work in ganeti since stderr is
  redirected to the logfile. This class avoids such failures by reporting
  errors to /dev/console.

  """
  def __init__(self, filename, mode="a", encoding=None):
    """Open the specified file and use it as the stream for logging.

    Also open /dev/console to report errors while logging.

    """
    logging.FileHandler.__init__(self, filename, mode, encoding)
    self.console = open(constants.DEV_CONSOLE, "a")

  def handleError(self, record): # pylint: disable-msg=C0103
    """Handle errors which occur during an emit() call.

    Try to handle errors with FileHandler method, if it fails write to
    /dev/console.

    """
    try:
      logging.FileHandler.handleError(self, record)
    except Exception: # pylint: disable-msg=W0703
      try:
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
      except Exception: # pylint: disable-msg=W0703
        # Log handler tried everything it could, now just give up
        pass


def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
                 console_logging=False):
  """Configures the logging module.

  @type logfile: str
  @param logfile: the filename to which we should log
  @type debug: integer
  @param debug: if greater than zero, enable debug messages, otherwise
      only those at C{INFO} and above level
  @type stderr_logging: boolean
  @param stderr_logging: whether we should also log to the standard error
  @type program: str
  @param program: the name under which we should log messages
  @type multithreaded: boolean
  @param multithreaded: if True, will add the thread name to the log file
  @type syslog: string
  @param syslog: one of 'no', 'yes', 'only':
      - if no, syslog is not used
      - if yes, syslog is used (in addition to file-logging)
      - if only, only syslog is used
  @type console_logging: boolean
  @param console_logging: if True, will use a FileHandler which falls back to
      the system console if logging fails
  @raise EnvironmentError: if we can't open the log file and
      syslog/stderr logging is disabled

  """
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
  sft = program + "[%(process)d]:"
  if multithreaded:
    fmt += "/%(threadName)s"
    sft += " (%(threadName)s)"
  if debug:
    fmt += " %(module)s:%(lineno)s"
    # no debug info for syslog loggers
  fmt += " %(levelname)s %(message)s"
  # yes, we do want the textual level, as remote syslog will probably
  # lose the error level, and it's easier to grep for it
  sft += " %(levelname)s %(message)s"
  formatter = logging.Formatter(fmt)
  sys_fmt = logging.Formatter(sft)

  root_logger = logging.getLogger("")
  root_logger.setLevel(logging.NOTSET)

  # Remove all previously setup handlers
  for handler in root_logger.handlers:
    handler.close()
    root_logger.removeHandler(handler)

  if stderr_logging:
    stderr_handler = logging.StreamHandler()
    stderr_handler.setFormatter(formatter)
    if debug:
      stderr_handler.setLevel(logging.NOTSET)
    else:
      stderr_handler.setLevel(logging.CRITICAL)
    root_logger.addHandler(stderr_handler)

  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
                                                    facility)
    syslog_handler.setFormatter(sys_fmt)
    # Never enable debug over syslog
    syslog_handler.setLevel(logging.INFO)
    root_logger.addHandler(syslog_handler)

  if syslog != constants.SYSLOG_ONLY:
    # this can fail, if the logging directories are not setup or we have
    # a permission problem; in this case, it's best to log but ignore
    # the error if stderr_logging is True, and if false we re-raise the
    # exception since otherwise we could run but without any logs at all
    try:
      if console_logging:
        logfile_handler = LogFileHandler(logfile)
      else:
        logfile_handler = logging.FileHandler(logfile)
      logfile_handler.setFormatter(formatter)
      if debug:
        logfile_handler.setLevel(logging.DEBUG)
      else:
        logfile_handler.setLevel(logging.INFO)
      root_logger.addHandler(logfile_handler)
    except EnvironmentError:
      if stderr_logging or syslog == constants.SYSLOG_YES:
        logging.exception("Failed to enable logging to file '%s'", logfile)
      else:
        # we need to re-raise the exception
        raise


def IsNormAbsPath(path):
  """Check whether a path is absolute and also normalized.

  This prevents paths like /dir/../../other/path from being considered
  valid.

  """
  return os.path.normpath(path) == path and os.path.isabs(path)


def PathJoin(*args):
  """Safe-join a list of path components.

  Requirements:
      - the first argument must be an absolute path
      - no component in the path must have backtracking (e.g. /../),
        since we check for normalization at the end

  @param args: the path components to be joined
  @raise ValueError: for invalid paths

  """
  # ensure we're having at least one path passed in
  assert args
  # ensure the first component is an absolute and normalized path name
  root = args[0]
  if not IsNormAbsPath(root):
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
  result = os.path.join(*args)
  # ensure that the whole path is normalized
  if not IsNormAbsPath(result):
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
  # check that we're still under the original prefix
  prefix = os.path.commonprefix([root, result])
  if prefix != root:
    raise ValueError("Error: path joining resulted in different prefix"
                     " (%s != %s)" % (prefix, root))
  return result
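
# Example: PathJoin refuses components that escape the initial prefix
# (the paths below are just illustrations):
#
#   PathJoin("/var/lib", "ganeti", "config.data")
#       # -> "/var/lib/ganeti/config.data"
#   PathJoin("/var/lib", "../etc/passwd")   # raises ValueError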


def TailFile(fname, lines=20):
  """Return the last lines from a file.

  @note: this function will only read and parse the last 4KB of
      the file; if the lines are very long, it could be that less
      than the requested number of lines are returned

  @param fname: the file name
  @type lines: int
  @param lines: the (maximum) number of lines to return

  """
  fd = open(fname, "r")
  try:
    fd.seek(0, 2)
    pos = fd.tell()
    pos = max(0, pos-4096)
    fd.seek(pos, 0)
    raw_data = fd.read()
  finally:
    fd.close()

  rows = raw_data.splitlines()
  return rows[-lines:]


def FormatTimestampWithTZ(secs):
  """Formats a Unix timestamp with the local timezone.

  """
  return time.strftime("%F %T %Z", time.localtime(secs))


def _ParseAsn1Generalizedtime(value):
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.

  @type value: string
  @param value: ASN1 GENERALIZEDTIME timestamp

  """
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
  if m:
    # We have an offset
    asn1time = m.group(1)
    hours = int(m.group(2))
    minutes = int(m.group(3))
    utcoffset = (60 * hours) + minutes
  else:
    if not value.endswith("Z"):
      raise ValueError("Missing timezone")
    asn1time = value[:-1]
    utcoffset = 0

  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")

  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)

  return calendar.timegm(tt.utctimetuple())


def GetX509CertValidity(cert):
  """Returns the validity period of the certificate.

  @type cert: OpenSSL.crypto.X509
  @param cert: X509 certificate object

  """
  # The get_notBefore and get_notAfter functions are only supported in
  # pyOpenSSL 0.7 and above.
  try:
    get_notbefore_fn = cert.get_notBefore
  except AttributeError:
    not_before = None
  else:
    not_before_asn1 = get_notbefore_fn()

    if not_before_asn1 is None:
      not_before = None
    else:
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)

  try:
    get_notafter_fn = cert.get_notAfter
  except AttributeError:
    not_after = None
  else:
    not_after_asn1 = get_notafter_fn()

    if not_after_asn1 is None:
      not_after = None
    else:
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)

  return (not_before, not_after)


def _VerifyCertificateInner(expired, not_before, not_after, now,
                            warn_days, error_days):
  """Verifies certificate validity.

  @type expired: bool
  @param expired: Whether pyOpenSSL considers the certificate as expired
  @type not_before: number or None
  @param not_before: Unix timestamp before which certificate is not valid
  @type not_after: number or None
  @param not_after: Unix timestamp after which certificate is invalid
  @type now: number
  @param now: Current time as Unix timestamp
  @type warn_days: number or None
  @param warn_days: How many days before expiration a warning should be reported
  @type error_days: number or None
  @param error_days: How many days before expiration an error should be reported

  """
  if expired:
    msg = "Certificate is expired"

    if not_before is not None and not_after is not None:
      msg += (" (valid from %s to %s)" %
              (FormatTimestampWithTZ(not_before),
               FormatTimestampWithTZ(not_after)))
    elif not_before is not None:
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
    elif not_after is not None:
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)

    return (CERT_ERROR, msg)

  elif not_before is not None and not_before > now:
    return (CERT_WARNING,
            "Certificate not yet valid (valid from %s)" %
            FormatTimestampWithTZ(not_before))

  elif not_after is not None:
    remaining_days = int((not_after - now) / (24 * 3600))

    msg = "Certificate expires in about %d days" % remaining_days

    if error_days is not None and remaining_days <= error_days:
      return (CERT_ERROR, msg)

    if warn_days is not None and remaining_days <= warn_days:
      return (CERT_WARNING, msg)

  return (None, None)


def VerifyX509Certificate(cert, warn_days, error_days):
  """Verifies a certificate for LUVerifyCluster.

  @type cert: OpenSSL.crypto.X509
  @param cert: X509 certificate object
  @type warn_days: number or None
  @param warn_days: How many days before expiration a warning should be reported
  @type error_days: number or None
  @param error_days: How many days before expiration an error should be reported

  """
  # Depending on the pyOpenSSL version, this can just return (None, None)
  (not_before, not_after) = GetX509CertValidity(cert)

  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
                                 time.time(), warn_days, error_days)


def SignX509Certificate(cert, key, salt):
  """Sign an X509 certificate.

  An RFC822-like signature header is added in front of the certificate.

  @type cert: OpenSSL.crypto.X509
  @param cert: X509 certificate object
  @type key: string
  @param key: Key for HMAC
  @type salt: string
  @param salt: Salt for HMAC
  @rtype: string
  @return: Serialized and signed certificate in PEM format

  """
  if not VALID_X509_SIGNATURE_SALT.match(salt):
    raise errors.GenericError("Invalid salt: %r" % salt)

  # Dumping as PEM here ensures the certificate is in a sane format
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  return ("%s: %s/%s\n\n%s" %
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
           Sha1Hmac(key, cert_pem, salt=salt),
           cert_pem))


def _ExtractX509CertificateSignature(cert_pem):
  """Helper function to extract signature from X509 certificate.

  """
  # Extract signature from original PEM data
  for line in cert_pem.splitlines():
    if line.startswith("---"):
      break

    m = X509_SIGNATURE.match(line.strip())
    if m:
      return (m.group("salt"), m.group("sign"))

  raise errors.GenericError("X509 certificate signature is missing")


def LoadSignedX509Certificate(cert_pem, key):
  """Verifies a signed X509 certificate.

  @type cert_pem: string
  @param cert_pem: Certificate in PEM format and with signature header
  @type key: string
  @param key: Key for HMAC
  @rtype: tuple; (OpenSSL.crypto.X509, string)
  @return: X509 certificate object and salt

  """
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)

  # Load certificate
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)

  # Dump again to ensure it's in a sane format
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
    raise errors.GenericError("X509 certificate signature is invalid")

  return (cert, salt)


def Sha1Hmac(key, text, salt=None):
  """Calculates the HMAC-SHA1 digest of a text.

  HMAC is defined in RFC2104.

  @type key: string
  @param key: Secret key
  @type text: string

  """
  if salt:
    salted_text = salt + text
  else:
    salted_text = text

  return hmac.new(key, salted_text, compat.sha1).hexdigest()


def VerifySha1Hmac(key, text, digest, salt=None):
  """Verifies the HMAC-SHA1 digest of a text.

  HMAC is defined in RFC2104.

  @type key: string
  @param key: Secret key
  @type text: string
  @type digest: string
  @param digest: Expected digest
  @rtype: bool
  @return: Whether HMAC-SHA1 digest matches

  """
  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()
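
# Typical round-trip through the two helpers above; the key and salt values
# are arbitrary examples:
#
#   digest = Sha1Hmac("secret-key", "some text", salt="abc123")
#   VerifySha1Hmac("secret-key", "some text", digest, salt="abc123")  # True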


def SafeEncode(text):
  """Return a 'safe' version of a source string.

  This function mangles the input string and returns a version that
  should be safe to display/encode as ASCII. To this end, we first
  convert it to ASCII using the 'backslashreplace' encoding which
  should get rid of any non-ASCII chars, and then we process it
  through a loop copied from the string repr sources in Python; we
  don't use string_escape anymore since that escapes single quotes and
  backslashes too, and that is too much; and that escaping is not
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).

  @type text: str or unicode
  @param text: input data
  @rtype: str
  @return: a safe version of text

  """
  if isinstance(text, unicode):
    # only if unicode; if str already, we handle it below
    text = text.encode('ascii', 'backslashreplace')
  resu = ""
  for char in text:
    c = ord(char)
    if char == '\t':
      resu += r'\t'
    elif char == '\n':
      resu += r'\n'
    elif char == '\r':
      resu += r'\r'
    elif c < 32 or c >= 127: # non-printable
      resu += "\\x%02x" % (c & 0xff)
    else:
      resu += char
  return resu
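
# Example of the escaping performed by SafeEncode:
#
#   SafeEncode("type\tok\n")   # -> 'type\\tok\\n'
#   SafeEncode("caf\xe9")      # -> 'caf\\xe9'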


def UnescapeAndSplit(text, sep=","):
  """Split and unescape a string based on a given separator.

  This function splits a string based on a separator where the
  separator itself can be escaped in order to be an element of the
  elements. The escaping rules are (assuming comma being the
  separator):
    - a plain , separates the elements
    - a sequence \\\\, (double backslash plus comma) is handled as a
      backslash plus a separator comma
    - a sequence \, (backslash plus comma) is handled as a
      non-separator comma

  @type text: string
  @param text: the string to split
  @type sep: string
  @param sep: the separator
  @rtype: list
  @return: a list of strings

  """
  # we split the list by sep (with no escaping at this stage)
  slist = text.split(sep)
  # next, we revisit the elements and if any of them ended with an odd
  # number of backslashes, then we join it with the next
  rlist = []
  while slist:
    e1 = slist.pop(0)
    if e1.endswith("\\"):
      num_b = len(e1) - len(e1.rstrip("\\"))
      if num_b % 2 == 1:
        e2 = slist.pop(0)
        # here the backslashes remain (all), and will be reduced in
        # the next step
        rlist.append(e1 + sep + e2)
        continue
    rlist.append(e1)
  # finally, replace backslash-something with something
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
  return rlist
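
# Example of the escaping rules described in the docstring:
#
#   UnescapeAndSplit("a,b,c")    # -> ["a", "b", "c"]
#   UnescapeAndSplit(r"a\,b,c")  # -> ["a,b", "c"]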


def CommaJoin(names):
  """Nicely join a set of identifiers.

  @param names: set, list or tuple
  @return: a string with the formatted results

  """
  return ", ".join([str(val) for val in names])


def BytesToMebibyte(value):
  """Converts bytes to mebibytes.

  @type value: int
  @param value: Value in bytes
  @rtype: int
  @return: Value in mebibytes

  """
  return int(round(value / (1024.0 * 1024.0), 0))
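
# Example: the result is rounded to the nearest mebibyte:
#
#   BytesToMebibyte(1024 * 1024)  # -> 1
#   BytesToMebibyte(1536 * 1024)  # -> 2 (1.5 MiB rounds up)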


def CalculateDirectorySize(path):
  """Calculates the size of a directory recursively.

  @type path: string
  @param path: Path to directory
  @rtype: int
  @return: Size in mebibytes

  """
  size = 0

  for (curpath, _, files) in os.walk(path):
    for filename in files:
      st = os.lstat(PathJoin(curpath, filename))
      size += st.st_size

  return BytesToMebibyte(size)


def GetFilesystemStats(path):
  """Returns the total and free space on a filesystem.

  @type path: string
  @param path: Path on filesystem to be examined
  @rtype: tuple
  @return: tuple of (Total space, Free space) in mebibytes

  """
  st = os.statvfs(path)

  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
  return (tsize, fsize)


def RunInSeparateProcess(fn, *args):
  """Runs a function in a separate process.

  Note: Only boolean return values are supported.

  @type fn: callable
  @param fn: Function to be called
  @rtype: bool
  @return: Function's result

  """
  pid = os.fork()
  if pid == 0:
    # Child process
    try:
      # In case the function uses temporary files
      ResetTempfileModule()

      # Call function
      result = int(bool(fn(*args)))
      assert result in (0, 1)
    except: # pylint: disable-msg=W0702
      logging.exception("Error while calling function in separate process")
      # 0 and 1 are reserved for the return value
      result = 33

    os._exit(result) # pylint: disable-msg=W0212

  # Parent process

  # Avoid zombies and check exit code
  (_, status) = os.waitpid(pid, 0)

  if os.WIFSIGNALED(status):
    exitcode = None
    signum = os.WTERMSIG(status)
  else:
    exitcode = os.WEXITSTATUS(status)
    signum = None

  if not (exitcode in (0, 1) and signum is None):
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
                              (exitcode, signum))

  return bool(exitcode)


def IgnoreProcessNotFound(fn, *args, **kwargs):
  """Ignores ESRCH when calling a process-related function.

  ESRCH is raised when a process is not found.

  @rtype: bool
  @return: Whether process was found

  """
  try:
    fn(*args, **kwargs)
  except EnvironmentError, err:
    # Ignore ESRCH
    if err.errno == errno.ESRCH:
      return False
    raise

  return True


def IgnoreSignals(fn, *args, **kwargs):
  """Tries to call a function ignoring failures due to EINTR.

  """
  try:
    return fn(*args, **kwargs)
  except EnvironmentError, err:
    if err.errno == errno.EINTR:
      return None
    else:
      raise
  except (select.error, socket.error), err:
    # In python 2.6 and above select.error is an IOError, so it's handled
    # above, in 2.5 and below it's not, and it's handled here.
    if err.args and err.args[0] == errno.EINTR:
      return None
    else:
      raise


def LockedMethod(fn):
  """Synchronized object access decorator.

  This decorator is intended to protect access to an object using the
  object's own lock which is hardcoded to '_lock'.

  """
  def _LockDebug(*args, **kwargs):
    if debug_locks:
      logging.debug(*args, **kwargs)

  def wrapper(self, *args, **kwargs):
    # pylint: disable-msg=W0212
    assert hasattr(self, '_lock')
    lock = self._lock
    _LockDebug("Waiting for %s", lock)
    lock.acquire()
    try:
      _LockDebug("Acquired %s", lock)
      result = fn(self, *args, **kwargs)
    finally:
      _LockDebug("Releasing %s", lock)
      lock.release()
      _LockDebug("Released %s", lock)
    return result
  return wrapper


def LockFile(fd):
  """Locks a file using POSIX locks.

  @type fd: int
  @param fd: the file descriptor we need to lock

  """
  try:
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
  except IOError, err:
    if err.errno == errno.EAGAIN:
      raise errors.LockError("File already locked")
    raise


def FormatTime(val):
  """Formats a time value.

  @type val: float or None
  @param val: the timestamp as returned by time.time()
  @return: a string value or N/A if we don't have a valid timestamp

  """
  if val is None or not isinstance(val, (int, float)):
    return "N/A"
  # these two format codes work on Linux, but they are not guaranteed on all
  # platforms
  return time.strftime("%F %T", time.localtime(val))


def FormatSeconds(secs):
  """Formats seconds for easier reading.

  @type secs: number
  @param secs: Number of seconds
  @rtype: string
  @return: Formatted seconds (e.g. "2d 9h 19m 49s")

  """
  parts = []

  secs = round(secs, 0)

  if secs > 0:
    # Negative values would be a bit tricky
    for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]:
      (complete, secs) = divmod(secs, one)
      if complete or parts:
        parts.append("%d%s" % (complete, unit))

  parts.append("%ds" % secs)

  return " ".join(parts)
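
# Example outputs of FormatSeconds:
#
#   FormatSeconds(61)    # -> "1m 1s"
#   FormatSeconds(3661)  # -> "1h 1m 1s"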


def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
  """Reads the watcher pause file.

  @type filename: string
  @param filename: Path to watcher pause file
  @type now: None, float or int
  @param now: Current time as Unix timestamp
  @type remove_after: int
  @param remove_after: Remove watcher pause file after specified amount of
    seconds past the pause end time

  """
  if now is None:
    now = time.time()

  try:
    value = ReadFile(filename)
  except IOError, err:
    if err.errno != errno.ENOENT:
      raise
    value = None

  if value is not None:
    try:
      value = int(value)
    except ValueError:
      logging.warning(("Watcher pause file (%s) contains invalid value,"
                       " removing it"), filename)
      RemoveFile(filename)
      value = None

    if value is not None:
      # Remove file if it's outdated
      if now > (value + remove_after):
        RemoveFile(filename)
        value = None

      elif now > value:
        value = None

  return value


class RetryTimeout(Exception):
  """Retry loop timed out.

  Any arguments which were passed by the retried function to RetryAgain will
  be preserved in RetryTimeout, if it is raised. If such an argument was an
  exception, the RaiseInner helper method will reraise it.

  """
  def RaiseInner(self):
    if self.args and isinstance(self.args[0], Exception):
      raise self.args[0]
    else:
      raise RetryTimeout(*self.args)


class RetryAgain(Exception):
  """Retry again.

  Any arguments passed to RetryAgain will be preserved, if a timeout occurs,
  as arguments to RetryTimeout. If an exception is passed, the RaiseInner()
  method of the RetryTimeout exception can be used to reraise it.

  """


class _RetryDelayCalculator(object):
  """Calculator for increasing delays.

  """
  __slots__ = [
    "_factor",
    "_limit",
    "_next",
    "_start",
    ]

  def __init__(self, start, factor, limit):
    """Initializes this class.

    @type start: float
    @param start: Initial delay
    @type factor: float
    @param factor: Factor for delay increase
    @type limit: float or None
    @param limit: Upper limit for delay or None for no limit

    """
    assert start > 0.0
    assert factor >= 1.0
    assert limit is None or limit >= 0.0

    self._start = start
    self._factor = factor
    self._limit = limit

    self._next = start

  def __call__(self):
    """Returns current delay and calculates the next one.

    """
    current = self._next

    # Update for next run
    if self._limit is None or self._next < self._limit:
      self._next = min(self._limit, self._next * self._factor)

    return current


#: Special delay to specify whole remaining timeout
RETRY_REMAINING_TIME = object()
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
          _time_fn=time.time):
  """Call a function repeatedly until it succeeds.

  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
  total of C{timeout} seconds, this function throws L{RetryTimeout}.

  C{delay} can be one of the following:
    - callable returning the delay length as a float
    - Tuple of (start, factor, limit)
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
      useful when overriding L{wait_fn} to wait for an external event)
    - A static delay as a number (int or float)

  @type fn: callable
  @param fn: Function to be called
  @param delay: Either a callable (returning the delay), a tuple of (start,
                factor, limit) (see L{_RetryDelayCalculator}),
                L{RETRY_REMAINING_TIME} or a number (int or float)
  @type timeout: float
  @param timeout: Total timeout
  @type wait_fn: callable
  @param wait_fn: Waiting function
  @return: Return value of function

  """
  assert callable(fn)
  assert callable(wait_fn)
  assert callable(_time_fn)

  if args is None:
    args = []

  end_time = _time_fn() + timeout

  if callable(delay):
    # External function to calculate delay
    calc_delay = delay

  elif isinstance(delay, (tuple, list)):
    # Increasing delay with optional upper boundary
    (start, factor, limit) = delay
    calc_delay = _RetryDelayCalculator(start, factor, limit)

  elif delay is RETRY_REMAINING_TIME:
    # Always use the remaining time
    calc_delay = None

  else:
    # Static delay
    calc_delay = lambda: delay

  assert calc_delay is None or callable(calc_delay)

  while True:
    retry_args = []
    try:
      # pylint: disable-msg=W0142
      return fn(*args)
    except RetryAgain, err:
      retry_args = err.args
    except RetryTimeout:
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
                                   " handle RetryTimeout")

    remaining_time = end_time - _time_fn()

    if remaining_time < 0.0:
      # pylint: disable-msg=W0142
      raise RetryTimeout(*retry_args)

    assert remaining_time >= 0.0

    if calc_delay is None:
      wait_fn(remaining_time)
    else:
      current_delay = calc_delay()
      if current_delay > 0.0:
        wait_fn(current_delay)


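# Illustrative sketch (hypothetical helper, not called anywhere in the
# module): shows the Retry() calling convention, where the callable raises
# RetryAgain to request another attempt and a (start, factor, limit) tuple
# produces an increasing delay between attempts. The helper name and the use
# of a plain list as shared state are examples only.
def _RetryUsageExample():
  """Example only: retry until a counter reaches three.

  """
  attempts = []

  def _CheckDone():
    attempts.append(None)
    if len(attempts) < 3:
      raise RetryAgain()
    return len(attempts)

  # Wait 0.01s initially, growing by factor 1.5 up to 0.1s, for at most 5s
  return Retry(_CheckDone, (0.01, 1.5, 0.1), 5.0)

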
def GetClosedTempfile(*args, **kwargs):
  """Creates a temporary file and returns its path.

  """
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
  _CloseFDNoErr(fd)
  return path


def GenerateSelfSignedX509Cert(common_name, validity):
  """Generates a self-signed X509 certificate.

  @type common_name: string
  @param common_name: commonName value
  @type validity: int
  @param validity: Validity for certificate in seconds

  """
  # Create private and public key
  key = OpenSSL.crypto.PKey()
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)

  # Create self-signed certificate
  cert = OpenSSL.crypto.X509()
  if common_name:
    cert.get_subject().CN = common_name
  cert.set_serial_number(1)
  cert.gmtime_adj_notBefore(0)
  cert.gmtime_adj_notAfter(validity)
  cert.set_issuer(cert.get_subject())
  cert.set_pubkey(key)
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)

  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  return (key_pem, cert_pem)


def GenerateSelfSignedSslCert(filename, validity=(5 * 365)):
  """Legacy function to generate self-signed X509 certificate.

  """
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(None,
                                                   validity * 24 * 60 * 60)

  WriteFile(filename, mode=0400, data=key_pem + cert_pem)


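# Illustrative sketch (hypothetical helper, not called anywhere in the
# module): shows one way a caller might load the PEM data returned by
# GenerateSelfSignedX509Cert and read back the commonName. The helper name,
# the commonName and the one-hour validity are examples only.
def _SelfSignedCertExample():
  """Example only: generate a certificate and return its subject CN.

  """
  (_, cert_pem) = GenerateSelfSignedX509Cert("example.ganeti.org", 3600)
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
  return cert.get_subject().CN

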
class FileLock(object):
  """Utility class for file locks.

  """
  def __init__(self, fd, filename):
    """Constructor for FileLock.

    @type fd: file
    @param fd: File object
    @type filename: str
    @param filename: Path of the file opened at I{fd}

    """
    self.fd = fd
    self.filename = filename

  @classmethod
  def Open(cls, filename):
    """Creates and opens a file to be used as a file-based lock.

    @type filename: string
    @param filename: path to the file to be locked

    """
    # Using "os.open" is necessary to allow both opening existing file
    # read/write and creating if not existing. Vanilla "open" will truncate an
    # existing file -or- allow creating if not existing.
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
               filename)

  def __del__(self):
    self.Close()

  def Close(self):
    """Close the file and release the lock.

    """
    if hasattr(self, "fd") and self.fd:
      self.fd.close()
      self.fd = None

  def _flock(self, flag, blocking, timeout, errmsg):
    """Wrapper for fcntl.flock.

    @type flag: int
    @param flag: operation flag
    @type blocking: bool
    @param blocking: whether the operation should be done in blocking mode.
    @type timeout: None or float
    @param timeout: for how long the operation should be retried (implies
                    non-blocking mode).
    @type errmsg: string
    @param errmsg: error message in case operation fails.

    """
    assert self.fd, "Lock was closed"
    assert timeout is None or timeout >= 0, \
      "If specified, timeout must be positive"
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)


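# Illustrative sketch (hypothetical helper, not called anywhere in the
# module): shows the intended FileLock calling pattern, namely Open() the
# lock file, take the exclusive lock (here waiting up to ten seconds), and
# always Close() to drop it. The helper name and the lock file path are
# examples only.
def _FileLockUsageExample():
  """Example only: acquire and release an exclusive file lock.

  """
  lock = FileLock.Open("/tmp/example.lock")
  try:
    lock.Exclusive(blocking=True, timeout=10.0)
    # ... critical section ...
  finally:
    lock.Close()

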
class LineSplitter:
  """Splits data chunks into lines separated by newline.

  Instances provide a file-like interface.

  """
  def __init__(self, line_fn, *args):
    """Initializes this class.

    @type line_fn: callable
    @param line_fn: Function called for each line, first parameter is line
    @param args: Extra arguments for L{line_fn}

    """
    assert callable(line_fn)

    if args:
      # Python 2.4 doesn't have functools.partial yet
      self._line_fn = \
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
    else:
      self._line_fn = line_fn

    self._lines = collections.deque()
    self._buffer = ""

  def write(self, data):
    parts = (self._buffer + data).split("\n")
    self._buffer = parts.pop()
    self._lines.extend(parts)

  def flush(self):
    while self._lines:
      self._line_fn(self._lines.popleft().rstrip("\r\n"))

  def close(self):
    self.flush()
    if self._buffer:
      self._line_fn(self._buffer)


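# Illustrative sketch (hypothetical helper, not called anywhere in the
# module): shows how LineSplitter reassembles lines from arbitrary chunks;
# complete lines are delivered to the callback while a trailing partial line
# is buffered until more data arrives or close() is called. The helper name
# is an example only.
def _LineSplitterUsageExample():
  """Example only: split two chunks into whole lines.

  """
  lines = []
  splitter = LineSplitter(lines.append)
  splitter.write("first\nsec")
  splitter.write("ond\nlast")
  splitter.close()
  # lines is now ["first", "second", "last"]
  return lines

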
def SignalHandled(signums):
  """Signal Handled decoration.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      assert 'signal_handlers' not in kwargs or \
             kwargs['signal_handlers'] is None or \
             isinstance(kwargs['signal_handlers'], dict), \
             "Wrong signal_handlers parameter in original function call"
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
        signal_handlers = kwargs['signal_handlers']
      else:
        signal_handlers = {}
        kwargs['signal_handlers'] = signal_handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          signal_handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        sighandler.Reset()
    return sig_function
  return wrap


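# Illustrative sketch (hypothetical function, not called anywhere in the
# module): shows the calling convention imposed by SignalHandled, where the
# decorated function receives a 'signal_handlers' dict and can poll the
# SignalHandler stored there for the signals it asked to intercept. The
# function name and body are examples only.
@SignalHandled([signal.SIGTERM])
def _SignalHandledExample(signal_handlers=None):
  """Example only: report whether SIGTERM arrived while running.

  """
  handler = signal_handlers[signal.SIGTERM]
  # ... do some work here ...
  return handler.called

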
class SignalWakeupFd(object):
  try:
    # This is only supported in Python 2.5 and above (some distributions
    # backported it to Python 2.4)
    _set_wakeup_fd_fn = signal.set_wakeup_fd
  except AttributeError:
    # Not supported
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
      return -1
  else:
    def _SetWakeupFd(self, fd):
      return self._set_wakeup_fd_fn(fd)

  def __init__(self):
    """Initializes this class.

    """
    (read_fd, write_fd) = os.pipe()

    # Once these succeeded, the file descriptors will be closed automatically.
    # Buffer size 0 is important, otherwise .read() with a specified length
    # might buffer data and the file descriptors won't be marked readable.
    self._read_fh = os.fdopen(read_fd, "r", 0)
    self._write_fh = os.fdopen(write_fd, "w", 0)

    self._previous = self._SetWakeupFd(self._write_fh.fileno())

    # Utility functions
    self.fileno = self._read_fh.fileno
    self.read = self._read_fh.read

  def Reset(self):
    """Restores the previous wakeup file descriptor.

    """
    if hasattr(self, "_previous") and self._previous is not None:
      self._SetWakeupFd(self._previous)
      self._previous = None

  def Notify(self):
    """Notifies the wakeup file descriptor.

    """
    self._write_fh.write("\0")

  def __del__(self):
    """Called before object deletion.

    """
    self.Reset()


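# Illustrative sketch (hypothetical helper, not called anywhere in the
# module): demonstrates the self-pipe mechanism behind SignalWakeupFd, where
# Notify() (or a signal delivered while the wakeup fd is active) writes a
# byte to the pipe and makes the read end show up as readable in select().
# The helper name and the one-second timeout are examples only.
def _SignalWakeupFdExample():
  """Example only: wait for a notification on the wakeup file descriptor.

  """
  wakeup = SignalWakeupFd()
  try:
    wakeup.Notify()
    (readable, _, _) = select.select([wakeup], [], [], 1.0)
    if readable:
      # Drain the byte that was written by Notify()
      wakeup.read(1)
    return bool(readable)
  finally:
    wakeup.Reset()

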
class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when destructed or
  when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: list
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None, wakeup=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function

    """
    assert handler_fn is None or callable(handler_fn)

    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn
    self._wakeup = wakeup

    self._previous = {}
    try:
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._wakeup:
      # Notify whoever is interested in signals
      self._wakeup.Notify()

    if self._handler_fn:
      self._handler_fn(signum, frame)


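# Illustrative sketch (hypothetical helper, not called anywhere in the
# module): shows the polling pattern suggested by the class docstring,
# namely install a handler for SIGUSR1, check the "called" flag, and always
# Reset() so the previous handler is restored. The helper name is an example
# only.
def _SignalHandlerUsageExample():
  """Example only: check whether SIGUSR1 was delivered.

  """
  handler = SignalHandler([signal.SIGUSR1])
  try:
    # ... do some work; a SIGUSR1 arriving here sets handler.called ...
    if handler.called:
      handler.Clear()
      return True
    return False
  finally:
    handler.Reset()

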
class FieldSet(object):
  """A simple field set.

  Among the features are:
    - checking if a string is among a list of static string or regex objects
    - checking if a whole list of strings matches
    - returning the matching groups from a regex match

  Internally, all fields are held as regular expression objects.

  """
  def __init__(self, *items):
    self.items = [re.compile("^%s$" % value) for value in items]

  def Extend(self, other_set):
    """Extend the field set with the items from another one"""
    self.items.extend(other_set.items)

  def Matches(self, field):
    """Checks if a field matches the current set

    @type field: str
    @param field: the string to match
    @return: either None or a regular expression match object

    """
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
      return m
    return None

  def NonMatching(self, items):
    """Returns the list of fields not matching the current set

    @type items: list
    @param items: the list of fields to check
    @rtype: list
    @return: list of non-matching fields

    """
    return [val for val in items if not self.Matches(val)]
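

# Illustrative sketch (hypothetical helper, not called anywhere in the
# module): shows how FieldSet anchors each item as a full-match regular
# expression, so plain names and patterns can be mixed. The helper name and
# the field names are examples only.
def _FieldSetUsageExample():
  """Example only: match static and regex-based field names.

  """
  fields = FieldSet("name", r"disk\.size/\d+")
  assert fields.Matches("disk.size/0")
  # Unknown fields are reported by NonMatching()
  return fields.NonMatching(["name", "disk.size/1", "bogus"])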