1
#
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import os
31
import sys
32
import time
33
import subprocess
34
import re
35
import socket
36
import tempfile
37
import shutil
38
import errno
39
import pwd
40
import itertools
41
import select
42
import fcntl
43
import resource
44
import logging
45
import logging.handlers
46
import signal
47
import OpenSSL
48
import datetime
49
import calendar
50
import hmac
51
import collections
52
import struct
53
import IN
54

    
55
from cStringIO import StringIO
56

    
57
try:
58
  # pylint: disable-msg=F0401
59
  import ctypes
60
except ImportError:
61
  ctypes = None
62

    
63
from ganeti import errors
64
from ganeti import constants
65
from ganeti import compat
66

    
67

    
68
_locksheld = []
69
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
70

    
71
debug_locks = False
72

    
73
#: when set to True, L{RunCmd} is disabled
74
no_fork = False
75

    
76
_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
77

    
78
HEX_CHAR_RE = r"[a-zA-Z0-9]"
79
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
80
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
81
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
82
                             HEX_CHAR_RE, HEX_CHAR_RE),
83
                            re.S | re.I)
84

    
85
_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$")
86

    
87
# Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
88
# struct ucred { pid_t pid; uid_t uid; gid_t gid; };
89
#
90
# The GNU C Library defines gid_t and uid_t to be "unsigned int" and
91
# pid_t to "int".
92
#
93
# IEEE Std 1003.1-2008:
94
# "nlink_t, uid_t, gid_t, and id_t shall be integer types"
95
# "blksize_t, pid_t, and ssize_t shall be signed integer types"
96
_STRUCT_UCRED = "iII"
97
_STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)
98

    
99
# Certificate verification results
100
(CERT_WARNING,
101
 CERT_ERROR) = range(1, 3)
102

    
103
# Flags for mlockall() (from bits/mman.h)
104
_MCL_CURRENT = 1
105
_MCL_FUTURE = 2
106

    
107

    
108
class RunResult(object):
109
  """Holds the result of running external programs.
110

111
  @type exit_code: int
112
  @ivar exit_code: the exit code of the program, or None (if the program
113
      didn't exit())
114
  @type signal: int or None
115
  @ivar signal: the signal that caused the program to finish, or None
116
      (if the program wasn't terminated by a signal)
117
  @type stdout: str
118
  @ivar stdout: the standard output of the program
119
  @type stderr: str
120
  @ivar stderr: the standard error of the program
121
  @type failed: boolean
122
  @ivar failed: True in case the program was
123
      terminated by a signal or exited with a non-zero exit code
124
  @ivar fail_reason: a string detailing the termination reason
125

126
  """
127
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
128
               "failed", "fail_reason", "cmd"]
129

    
130

    
131
  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
132
    self.cmd = cmd
133
    self.exit_code = exit_code
134
    self.signal = signal_
135
    self.stdout = stdout
136
    self.stderr = stderr
137
    self.failed = (signal_ is not None or exit_code != 0)
138

    
139
    if self.signal is not None:
140
      self.fail_reason = "terminated by signal %s" % self.signal
141
    elif self.exit_code is not None:
142
      self.fail_reason = "exited with exit code %s" % self.exit_code
143
    else:
144
      self.fail_reason = "unable to determine termination reason"
145

    
146
    if self.failed:
147
      logging.debug("Command '%s' failed (%s); output: %s",
148
                    self.cmd, self.fail_reason, self.output)
149

    
150
  def _GetOutput(self):
151
    """Returns the combined stdout and stderr for easier usage.
152

153
    """
154
    return self.stdout + self.stderr
155

    
156
  output = property(_GetOutput, None, None, "Return full output")
157

    
158

    
159
def _BuildCmdEnvironment(env, reset):
160
  """Builds the environment for an external program.
161

162
  """
163
  if reset:
164
    cmd_env = {}
165
  else:
166
    cmd_env = os.environ.copy()
167
    cmd_env["LC_ALL"] = "C"
168

    
169
  if env is not None:
170
    cmd_env.update(env)
171

    
172
  return cmd_env
173

    
174

    
175
def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
176
  """Execute a (shell) command.
177

178
  The command should not read from its standard input, as it will be
179
  closed.
180

181
  @type cmd: string or list
182
  @param cmd: Command to run
183
  @type env: dict
184
  @param env: Additional environment variables
185
  @type output: str
186
  @param output: if desired, the output of the command can be
187
      saved in a file instead of the RunResult instance; this
188
      parameter denotes the file name (if not None)
189
  @type cwd: string
190
  @param cwd: if specified, will be used as the working
191
      directory for the command; the default will be /
192
  @type reset_env: boolean
193
  @param reset_env: whether to reset or keep the default os environment
194
  @rtype: L{RunResult}
195
  @return: RunResult instance
196
  @raise errors.ProgrammerError: if we call this when forks are disabled
197

198
  """
199
  if no_fork:
200
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")
201

    
202
  if isinstance(cmd, basestring):
203
    strcmd = cmd
204
    shell = True
205
  else:
206
    cmd = [str(val) for val in cmd]
207
    strcmd = ShellQuoteArgs(cmd)
208
    shell = False
209

    
210
  if output:
211
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
212
  else:
213
    logging.debug("RunCmd %s", strcmd)
214

    
215
  cmd_env = _BuildCmdEnvironment(env, reset_env)
216

    
217
  try:
218
    if output is None:
219
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
220
    else:
221
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
222
      out = err = ""
223
  except OSError, err:
224
    if err.errno == errno.ENOENT:
225
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
226
                               (strcmd, err))
227
    else:
228
      raise
229

    
230
  if status >= 0:
231
    exitcode = status
232
    signal_ = None
233
  else:
234
    exitcode = None
235
    signal_ = -status
236

    
237
  return RunResult(exitcode, signal_, out, err, strcmd)
238

    
239

    
240
def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
241
                pidfile=None):
242
  """Start a daemon process after forking twice.
243

244
  @type cmd: string or list
245
  @param cmd: Command to run
246
  @type env: dict
247
  @param env: Additional environment variables
248
  @type cwd: string
249
  @param cwd: Working directory for the program
250
  @type output: string
251
  @param output: Path to file in which to save the output
252
  @type output_fd: int
253
  @param output_fd: File descriptor for output
254
  @type pidfile: string
255
  @param pidfile: Process ID file
256
  @rtype: int
257
  @return: Daemon process ID
258
  @raise errors.ProgrammerError: if we call this when forks are disabled
259

260
  """
261
  if no_fork:
262
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
263
                                 " disabled")
264

    
265
  if output and not (bool(output) ^ (output_fd is not None)):
266
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
267
                                 " specified")
268

    
269
  if isinstance(cmd, basestring):
270
    cmd = ["/bin/sh", "-c", cmd]
271

    
272
  strcmd = ShellQuoteArgs(cmd)
273

    
274
  if output:
275
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
276
  else:
277
    logging.debug("StartDaemon %s", strcmd)
278

    
279
  cmd_env = _BuildCmdEnvironment(env, False)
280

    
281
  # Create pipe for sending PID back
282
  (pidpipe_read, pidpipe_write) = os.pipe()
283
  try:
284
    try:
285
      # Create pipe for sending error messages
286
      (errpipe_read, errpipe_write) = os.pipe()
287
      try:
288
        try:
289
          # First fork
290
          pid = os.fork()
291
          if pid == 0:
292
            try:
293
              # Child process, won't return
294
              _StartDaemonChild(errpipe_read, errpipe_write,
295
                                pidpipe_read, pidpipe_write,
296
                                cmd, cmd_env, cwd,
297
                                output, output_fd, pidfile)
298
            finally:
299
              # Well, maybe child process failed
300
              os._exit(1) # pylint: disable-msg=W0212
301
        finally:
302
          _CloseFDNoErr(errpipe_write)
303

    
304
        # Wait for daemon to be started (or an error message to arrive) and read
305
        # up to 100 KB as an error message
306
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
307
      finally:
308
        _CloseFDNoErr(errpipe_read)
309
    finally:
310
      _CloseFDNoErr(pidpipe_write)
311

    
312
    # Read up to 128 bytes for PID
313
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
314
  finally:
315
    _CloseFDNoErr(pidpipe_read)
316

    
317
  # Try to avoid zombies by waiting for child process
318
  try:
319
    os.waitpid(pid, 0)
320
  except OSError:
321
    pass
322

    
323
  if errormsg:
324
    raise errors.OpExecError("Error when starting daemon process: %r" %
325
                             errormsg)
326

    
327
  try:
328
    return int(pidtext)
329
  except (ValueError, TypeError), err:
330
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
331
                             (pidtext, err))
332

    
333

    
334
def _StartDaemonChild(errpipe_read, errpipe_write,
335
                      pidpipe_read, pidpipe_write,
336
                      args, env, cwd,
337
                      output, fd_output, pidfile):
338
  """Child process for starting daemon.
339

340
  """
341
  try:
342
    # Close parent's side
343
    _CloseFDNoErr(errpipe_read)
344
    _CloseFDNoErr(pidpipe_read)
345

    
346
    # First child process
347
    os.chdir("/")
348
    os.umask(077)
349
    os.setsid()
350

    
351
    # And fork for the second time
352
    pid = os.fork()
353
    if pid != 0:
354
      # Exit first child process
355
      os._exit(0) # pylint: disable-msg=W0212
356

    
357
    # Make sure pipe is closed on execv* (and thereby notifies original process)
358
    SetCloseOnExecFlag(errpipe_write, True)
359

    
360
    # List of file descriptors to be left open
361
    noclose_fds = [errpipe_write]
362

    
363
    # Open PID file
364
    if pidfile:
365
      try:
366
        # TODO: Atomic replace with another locked file instead of writing into
367
        # it after creating
368
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)
369

    
370
        # Lock the PID file (and fail if not possible to do so). Any code
371
        # wanting to send a signal to the daemon should try to lock the PID
372
        # file before reading it. If acquiring the lock succeeds, the daemon is
373
        # no longer running and the signal should not be sent.
374
        LockFile(fd_pidfile)
375

    
376
        os.write(fd_pidfile, "%d\n" % os.getpid())
377
      except Exception, err:
378
        raise Exception("Creating and locking PID file failed: %s" % err)
379

    
380
      # Keeping the file open to hold the lock
381
      noclose_fds.append(fd_pidfile)
382

    
383
      SetCloseOnExecFlag(fd_pidfile, False)
384
    else:
385
      fd_pidfile = None
386

    
387
    # Open /dev/null
388
    fd_devnull = os.open(os.devnull, os.O_RDWR)
389

    
390
    assert not output or (bool(output) ^ (fd_output is not None))
391

    
392
    if fd_output is not None:
393
      pass
394
    elif output:
395
      # Open output file
396
      try:
397
        # TODO: Implement flag to set append=yes/no
398
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
399
      except EnvironmentError, err:
400
        raise Exception("Opening output file failed: %s" % err)
401
    else:
402
      fd_output = fd_devnull
403

    
404
    # Redirect standard I/O
405
    os.dup2(fd_devnull, 0)
406
    os.dup2(fd_output, 1)
407
    os.dup2(fd_output, 2)
408

    
409
    # Send daemon PID to parent
410
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))
411

    
412
    # Close all file descriptors except stdio and error message pipe
413
    CloseFDs(noclose_fds=noclose_fds)
414

    
415
    # Change working directory
416
    os.chdir(cwd)
417

    
418
    if env is None:
419
      os.execvp(args[0], args)
420
    else:
421
      os.execvpe(args[0], args, env)
422
  except: # pylint: disable-msg=W0702
423
    try:
424
      # Report errors to original process
425
      buf = str(sys.exc_info()[1])
426

    
427
      RetryOnSignal(os.write, errpipe_write, buf)
428
    except: # pylint: disable-msg=W0702
429
      # Ignore errors in error handling
430
      pass
431

    
432
  os._exit(1) # pylint: disable-msg=W0212
433

    
434

    
435
def _RunCmdPipe(cmd, env, via_shell, cwd):
436
  """Run a command and return its output.
437

438
  @type  cmd: string or list
439
  @param cmd: Command to run
440
  @type env: dict
441
  @param env: The environment to use
442
  @type via_shell: bool
443
  @param via_shell: if we should run via the shell
444
  @type cwd: string
445
  @param cwd: the working directory for the program
446
  @rtype: tuple
447
  @return: (out, err, status)
448

449
  """
450
  poller = select.poll()
451
  child = subprocess.Popen(cmd, shell=via_shell,
452
                           stderr=subprocess.PIPE,
453
                           stdout=subprocess.PIPE,
454
                           stdin=subprocess.PIPE,
455
                           close_fds=True, env=env,
456
                           cwd=cwd)
457

    
458
  child.stdin.close()
459
  poller.register(child.stdout, select.POLLIN)
460
  poller.register(child.stderr, select.POLLIN)
461
  out = StringIO()
462
  err = StringIO()
463
  fdmap = {
464
    child.stdout.fileno(): (out, child.stdout),
465
    child.stderr.fileno(): (err, child.stderr),
466
    }
467
  for fd in fdmap:
468
    SetNonblockFlag(fd, True)
469

    
470
  while fdmap:
471
    pollresult = RetryOnSignal(poller.poll)
472

    
473
    for fd, event in pollresult:
474
      if event & select.POLLIN or event & select.POLLPRI:
475
        data = fdmap[fd][1].read()
476
        # no data from read signifies EOF (the same as POLLHUP)
477
        if not data:
478
          poller.unregister(fd)
479
          del fdmap[fd]
480
          continue
481
        fdmap[fd][0].write(data)
482
      if (event & select.POLLNVAL or event & select.POLLHUP or
483
          event & select.POLLERR):
484
        poller.unregister(fd)
485
        del fdmap[fd]
486

    
487
  out = out.getvalue()
488
  err = err.getvalue()
489

    
490
  status = child.wait()
491
  return out, err, status
492

    
493

    
494
def _RunCmdFile(cmd, env, via_shell, output, cwd):
495
  """Run a command and save its output to a file.
496

497
  @type  cmd: string or list
498
  @param cmd: Command to run
499
  @type env: dict
500
  @param env: The environment to use
501
  @type via_shell: bool
502
  @param via_shell: if we should run via the shell
503
  @type output: str
504
  @param output: the filename in which to save the output
505
  @type cwd: string
506
  @param cwd: the working directory for the program
507
  @rtype: int
508
  @return: the exit status
509

510
  """
511
  fh = open(output, "a")
512
  try:
513
    child = subprocess.Popen(cmd, shell=via_shell,
514
                             stderr=subprocess.STDOUT,
515
                             stdout=fh,
516
                             stdin=subprocess.PIPE,
517
                             close_fds=True, env=env,
518
                             cwd=cwd)
519

    
520
    child.stdin.close()
521
    status = child.wait()
522
  finally:
523
    fh.close()
524
  return status
525

    
526

    
527
def SetCloseOnExecFlag(fd, enable):
528
  """Sets or unsets the close-on-exec flag on a file descriptor.
529

530
  @type fd: int
531
  @param fd: File descriptor
532
  @type enable: bool
533
  @param enable: Whether to set or unset it.
534

535
  """
536
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)
537

    
538
  if enable:
539
    flags |= fcntl.FD_CLOEXEC
540
  else:
541
    flags &= ~fcntl.FD_CLOEXEC
542

    
543
  fcntl.fcntl(fd, fcntl.F_SETFD, flags)
544

    
545

    
546
def SetNonblockFlag(fd, enable):
547
  """Sets or unsets the O_NONBLOCK flag on on a file descriptor.
548

549
  @type fd: int
550
  @param fd: File descriptor
551
  @type enable: bool
552
  @param enable: Whether to set or unset it
553

554
  """
555
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)
556

    
557
  if enable:
558
    flags |= os.O_NONBLOCK
559
  else:
560
    flags &= ~os.O_NONBLOCK
561

    
562
  fcntl.fcntl(fd, fcntl.F_SETFL, flags)
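
# Illustrative usage sketch (not part of the original module): _RunCmdPipe
# above uses this to put the child's stdout/stderr descriptors into
# non-blocking mode before polling them, e.g.:
#
#   SetNonblockFlag(pipe_fd, True)   # pipe_fd is a hypothetical file descriptor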
563

    
564

    
565
def RetryOnSignal(fn, *args, **kwargs):
566
  """Calls a function again if it failed due to EINTR.
567

568
  """
569
  while True:
570
    try:
571
      return fn(*args, **kwargs)
572
    except EnvironmentError, err:
573
      if err.errno != errno.EINTR:
574
        raise
575
    except (socket.error, select.error), err:
576
      # In python 2.6 and above select.error is an IOError, so it's handled
577
      # above, in 2.5 and below it's not, and it's handled here.
578
      if not (err.args and err.args[0] == errno.EINTR):
579
        raise
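
# Illustrative usage sketch (not part of the original module): any system call
# that may be interrupted by a signal (EINTR) can be wrapped, e.g.:
#
#   data = RetryOnSignal(os.read, fd, 4096)   # fd is a hypothetical descriptor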
580

    
581

    
582
def RunParts(dir_name, env=None, reset_env=False):
583
  """Run Scripts or programs in a directory
584

585
  @type dir_name: string
586
  @param dir_name: absolute path to a directory
587
  @type env: dict
588
  @param env: The environment to use
589
  @type reset_env: boolean
590
  @param reset_env: whether to reset or keep the default os environment
591
  @rtype: list of tuples
592
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)
593

594
  """
595
  rr = []
596

    
597
  try:
598
    dir_contents = ListVisibleFiles(dir_name)
599
  except OSError, err:
600
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
601
    return rr
602

    
603
  for relname in sorted(dir_contents):
604
    fname = PathJoin(dir_name, relname)
605
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
606
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
607
      rr.append((relname, constants.RUNPARTS_SKIP, None))
608
    else:
609
      try:
610
        result = RunCmd([fname], env=env, reset_env=reset_env)
611
      except Exception, err: # pylint: disable-msg=W0703
612
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
613
      else:
614
        rr.append((relname, constants.RUNPARTS_RUN, result))
615

    
616
  return rr
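
# Illustrative usage sketch (not part of the original module): each directory
# entry yields a (name, status, runresult) tuple, where runresult is a
# RunResult for executed scripts, an error string, or None for skipped files:
#
#   for (name, status, runresult) in RunParts("/etc/myhooks.d"):
#     if status == constants.RUNPARTS_RUN and runresult.failed:
#       logging.warning("Hook %s failed: %s", name, runresult.output)
#   # "/etc/myhooks.d" is a made-up example path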
617

    
618

    
619
def GetSocketCredentials(sock):
620
  """Returns the credentials of the foreign process connected to a socket.
621

622
  @param sock: Unix socket
623
  @rtype: tuple; (number, number, number)
624
  @return: The PID, UID and GID of the connected foreign process.
625

626
  """
627
  peercred = sock.getsockopt(socket.SOL_SOCKET, IN.SO_PEERCRED,
628
                             _STRUCT_UCRED_SIZE)
629
  return struct.unpack(_STRUCT_UCRED, peercred)
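
# Illustrative usage sketch (not part of the original module): intended for
# AF_UNIX stream sockets, typically on the server side after accept()
# ("server_sock" below is a hypothetical bound and listening AF_UNIX socket):
#
#   (conn, _) = server_sock.accept()
#   (pid, uid, gid) = GetSocketCredentials(conn)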
630

    
631

    
632
def RemoveFile(filename):
633
  """Remove a file ignoring some errors.
634

635
  Remove a file, ignoring non-existing ones or directories. Other
636
  errors are passed.
637

638
  @type filename: str
639
  @param filename: the file to be removed
640

641
  """
642
  try:
643
    os.unlink(filename)
644
  except OSError, err:
645
    if err.errno not in (errno.ENOENT, errno.EISDIR):
646
      raise
647

    
648

    
649
def RemoveDir(dirname):
650
  """Remove an empty directory.
651

652
  Remove a directory, ignoring non-existing ones.
653
  Other errors are passed. This includes the case
654
  where the directory is not empty, so it can't be removed.
655

656
  @type dirname: str
657
  @param dirname: the empty directory to be removed
658

659
  """
660
  try:
661
    os.rmdir(dirname)
662
  except OSError, err:
663
    if err.errno != errno.ENOENT:
664
      raise
665

    
666

    
667
def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
668
  """Renames a file.
669

670
  @type old: string
671
  @param old: Original path
672
  @type new: string
673
  @param new: New path
674
  @type mkdir: bool
675
  @param mkdir: Whether to create target directory if it doesn't exist
676
  @type mkdir_mode: int
677
  @param mkdir_mode: Mode for newly created directories
678

679
  """
680
  try:
681
    return os.rename(old, new)
682
  except OSError, err:
683
    # In at least one use case of this function, the job queue, directory
684
    # creation is very rare. Checking for the directory before renaming is not
685
    # as efficient.
686
    if mkdir and err.errno == errno.ENOENT:
687
      # Create directory and try again
688
      Makedirs(os.path.dirname(new), mode=mkdir_mode)
689

    
690
      return os.rename(old, new)
691

    
692
    raise
693

    
694

    
695
def Makedirs(path, mode=0750):
696
  """Super-mkdir; create a leaf directory and all intermediate ones.
697

698
  This is a wrapper around C{os.makedirs} adding error handling not implemented
699
  before Python 2.5.
700

701
  """
702
  try:
703
    os.makedirs(path, mode)
704
  except OSError, err:
705
    # Ignore EEXIST. This is only handled in os.makedirs as included in
706
    # Python 2.5 and above.
707
    if err.errno != errno.EEXIST or not os.path.exists(path):
708
      raise
709

    
710

    
711
def ResetTempfileModule():
712
  """Resets the random name generator of the tempfile module.
713

714
  This function should be called after C{os.fork} in the child process to
715
  ensure it creates a newly seeded random generator. Otherwise it would
716
  generate the same random parts as the parent process. If several processes
717
  race for the creation of a temporary file, this could lead to one not getting
718
  a temporary name.
719

720
  """
721
  # pylint: disable-msg=W0212
722
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
723
    tempfile._once_lock.acquire()
724
    try:
725
      # Reset random name generator
726
      tempfile._name_sequence = None
727
    finally:
728
      tempfile._once_lock.release()
729
  else:
730
    logging.critical("The tempfile module misses at least one of the"
731
                     " '_once_lock' and '_name_sequence' attributes")
732

    
733

    
734
def _FingerprintFile(filename):
735
  """Compute the fingerprint of a file.
736

737
  If the file does not exist, a None will be returned
738
  instead.
739

740
  @type filename: str
741
  @param filename: the filename to checksum
742
  @rtype: str
743
  @return: the hex digest of the sha checksum of the contents
744
      of the file
745

746
  """
747
  if not (os.path.exists(filename) and os.path.isfile(filename)):
748
    return None
749

    
750
  f = open(filename)
751

    
752
  fp = compat.sha1_hash()
753
  while True:
754
    data = f.read(4096)
755
    if not data:
756
      break
757

    
758
    fp.update(data)
759

    
760
  return fp.hexdigest()
761

    
762

    
763
def FingerprintFiles(files):
764
  """Compute fingerprints for a list of files.
765

766
  @type files: list
767
  @param files: the list of filename to fingerprint
768
  @rtype: dict
769
  @return: a dictionary filename: fingerprint, holding only
770
      existing files
771

772
  """
773
  ret = {}
774

    
775
  for filename in files:
776
    cksum = _FingerprintFile(filename)
777
    if cksum:
778
      ret[filename] = cksum
779

    
780
  return ret
781

    
782

    
783
def ForceDictType(target, key_types, allowed_values=None):
784
  """Force the values of a dict to have certain types.
785

786
  @type target: dict
787
  @param target: the dict to update
788
  @type key_types: dict
789
  @param key_types: dict mapping target dict keys to types
790
                    in constants.ENFORCEABLE_TYPES
791
  @type allowed_values: list
792
  @keyword allowed_values: list of specially allowed values
793

794
  """
795
  if allowed_values is None:
796
    allowed_values = []
797

    
798
  if not isinstance(target, dict):
799
    msg = "Expected dictionary, got '%s'" % target
800
    raise errors.TypeEnforcementError(msg)
801

    
802
  for key in target:
803
    if key not in key_types:
804
      msg = "Unknown key '%s'" % key
805
      raise errors.TypeEnforcementError(msg)
806

    
807
    if target[key] in allowed_values:
808
      continue
809

    
810
    ktype = key_types[key]
811
    if ktype not in constants.ENFORCEABLE_TYPES:
812
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
813
      raise errors.ProgrammerError(msg)
814

    
815
    if ktype == constants.VTYPE_STRING:
816
      if not isinstance(target[key], basestring):
817
        if isinstance(target[key], bool) and not target[key]:
818
          target[key] = ''
819
        else:
820
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
821
          raise errors.TypeEnforcementError(msg)
822
    elif ktype == constants.VTYPE_BOOL:
823
      if isinstance(target[key], basestring) and target[key]:
824
        if target[key].lower() == constants.VALUE_FALSE:
825
          target[key] = False
826
        elif target[key].lower() == constants.VALUE_TRUE:
827
          target[key] = True
828
        else:
829
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
830
          raise errors.TypeEnforcementError(msg)
831
      elif target[key]:
832
        target[key] = True
833
      else:
834
        target[key] = False
835
    elif ktype == constants.VTYPE_SIZE:
836
      try:
837
        target[key] = ParseUnit(target[key])
838
      except errors.UnitParseError, err:
839
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
840
              (key, target[key], err)
841
        raise errors.TypeEnforcementError(msg)
842
    elif ktype == constants.VTYPE_INT:
843
      try:
844
        target[key] = int(target[key])
845
      except (ValueError, TypeError):
846
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
847
        raise errors.TypeEnforcementError(msg)
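
# Illustrative usage sketch (not part of the original module): values are
# converted in place, so a string like "1g" for a VTYPE_SIZE key becomes an
# integer number of MiB, and strings equal to constants.VALUE_TRUE or
# constants.VALUE_FALSE become booleans (key names below are made up):
#
#   data = {"memory": "1g", "auto_balance": "true"}
#   ForceDictType(data, {"memory": constants.VTYPE_SIZE,
#                        "auto_balance": constants.VTYPE_BOOL})
#   # assuming constants.VALUE_TRUE == "true", afterwards:
#   # data == {"memory": 1024, "auto_balance": True}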
848

    
849

    
850
def _GetProcStatusPath(pid):
851
  """Returns the path for a PID's proc status file.
852

853
  @type pid: int
854
  @param pid: Process ID
855
  @rtype: string
856

857
  """
858
  return "/proc/%d/status" % pid
859

    
860

    
861
def IsProcessAlive(pid):
862
  """Check if a given pid exists on the system.
863

864
  @note: zombie status is not handled, so zombie processes
865
      will be returned as alive
866
  @type pid: int
867
  @param pid: the process ID to check
868
  @rtype: boolean
869
  @return: True if the process exists
870

871
  """
872
  def _TryStat(name):
873
    try:
874
      os.stat(name)
875
      return True
876
    except EnvironmentError, err:
877
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
878
        return False
879
      elif err.errno == errno.EINVAL:
880
        raise RetryAgain(err)
881
      raise
882

    
883
  assert isinstance(pid, int), "pid must be an integer"
884
  if pid <= 0:
885
    return False
886

    
887
  # /proc in a multiprocessor environment can have strange behaviors.
888
  # Retry the os.stat a few times until we get a good result.
889
  try:
890
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
891
                 args=[_GetProcStatusPath(pid)])
892
  except RetryTimeout, err:
893
    err.RaiseInner()
894

    
895

    
896
def _ParseSigsetT(sigset):
897
  """Parse a rendered sigset_t value.
898

899
  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
900
  function.
901

902
  @type sigset: string
903
  @param sigset: Rendered signal set from /proc/$pid/status
904
  @rtype: set
905
  @return: Set of all enabled signal numbers
906

907
  """
908
  result = set()
909

    
910
  signum = 0
911
  for ch in reversed(sigset):
912
    chv = int(ch, 16)
913

    
914
    # The following could be done in a loop, but it's easier to read and
915
    # understand in the unrolled form
916
    if chv & 1:
917
      result.add(signum + 1)
918
    if chv & 2:
919
      result.add(signum + 2)
920
    if chv & 4:
921
      result.add(signum + 3)
922
    if chv & 8:
923
      result.add(signum + 4)
924

    
925
    signum += 4
926

    
927
  return result
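
# Illustrative example (not part of the original module): each hex digit covers
# four signal numbers, least-significant digit first, so for instance
# _ParseSigsetT("0003") == set([1, 2]), i.e. SIGHUP and SIGINT.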
928

    
929

    
930
def _GetProcStatusField(pstatus, field):
931
  """Retrieves a field from the contents of a proc status file.
932

933
  @type pstatus: string
934
  @param pstatus: Contents of /proc/$pid/status
935
  @type field: string
936
  @param field: Name of field whose value should be returned
937
  @rtype: string
938

939
  """
940
  for line in pstatus.splitlines():
941
    parts = line.split(":", 1)
942

    
943
    if len(parts) < 2 or parts[0] != field:
944
      continue
945

    
946
    return parts[1].strip()
947

    
948
  return None
949

    
950

    
951
def IsProcessHandlingSignal(pid, signum, status_path=None):
952
  """Checks whether a process is handling a signal.
953

954
  @type pid: int
955
  @param pid: Process ID
956
  @type signum: int
957
  @param signum: Signal number
958
  @rtype: bool
959

960
  """
961
  if status_path is None:
962
    status_path = _GetProcStatusPath(pid)
963

    
964
  try:
965
    proc_status = ReadFile(status_path)
966
  except EnvironmentError, err:
967
    # In at least one case, reading /proc/$pid/status failed with ESRCH.
968
    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
969
      return False
970
    raise
971

    
972
  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
973
  if sigcgt is None:
974
    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)
975

    
976
  # Now check whether signal is handled
977
  return signum in _ParseSigsetT(sigcgt)
978

    
979

    
980
def ReadPidFile(pidfile):
981
  """Read a pid from a file.
982

983
  @type  pidfile: string
984
  @param pidfile: path to the file containing the pid
985
  @rtype: int
986
  @return: The process id, if the file exists and contains a valid PID,
987
           otherwise 0
988

989
  """
990
  try:
991
    raw_data = ReadOneLineFile(pidfile)
992
  except EnvironmentError, err:
993
    if err.errno != errno.ENOENT:
994
      logging.exception("Can't read pid file")
995
    return 0
996

    
997
  try:
998
    pid = int(raw_data)
999
  except (TypeError, ValueError), err:
1000
    logging.info("Can't parse pid file contents", exc_info=True)
1001
    return 0
1002

    
1003
  return pid
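
# Illustrative usage sketch (not part of the original module): the return value
# of 0 for missing or invalid files combines safely with IsProcessAlive above,
# which returns False for non-positive PIDs ("/var/run/mydaemon.pid" is a
# made-up example path):
#
#   pid = ReadPidFile("/var/run/mydaemon.pid")
#   if IsProcessAlive(pid):
#     logging.info("Daemon already running with PID %s", pid)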
1004

    
1005

    
1006
def ReadLockedPidFile(path):
1007
  """Reads a locked PID file.
1008

1009
  This can be used together with L{StartDaemon}.
1010

1011
  @type path: string
1012
  @param path: Path to PID file
1013
  @return: PID as integer or, if file was unlocked or couldn't be opened, None
1014

1015
  """
1016
  try:
1017
    fd = os.open(path, os.O_RDONLY)
1018
  except EnvironmentError, err:
1019
    if err.errno == errno.ENOENT:
1020
      # PID file doesn't exist
1021
      return None
1022
    raise
1023

    
1024
  try:
1025
    try:
1026
      # Try to acquire lock
1027
      LockFile(fd)
1028
    except errors.LockError:
1029
      # Couldn't lock, daemon is running
1030
      return int(os.read(fd, 100))
1031
  finally:
1032
    os.close(fd)
1033

    
1034
  return None
1035

    
1036

    
1037
def MatchNameComponent(key, name_list, case_sensitive=True):
1038
  """Try to match a name against a list.
1039

1040
  This function will try to match a name like test1 against a list
1041
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
1042
  this list, I{'test1'} as well as I{'test1.example'} will match, but
1043
  not I{'test1.ex'}. A multiple match will be considered as no match
1044
  at all (e.g. I{'test1'} against C{['test1.example.com',
1045
  'test1.example.org']}), except when the key fully matches an entry
1046
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
1047

1048
  @type key: str
1049
  @param key: the name to be searched
1050
  @type name_list: list
1051
  @param name_list: the list of strings against which to search the key
1052
  @type case_sensitive: boolean
1053
  @param case_sensitive: whether to provide a case-sensitive match
1054

1055
  @rtype: None or str
1056
  @return: None if there is no match I{or} if there are multiple matches,
1057
      otherwise the element from the list which matches
1058

1059
  """
1060
  if key in name_list:
1061
    return key
1062

    
1063
  re_flags = 0
1064
  if not case_sensitive:
1065
    re_flags |= re.IGNORECASE
1066
    key = key.upper()
1067
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
1068
  names_filtered = []
1069
  string_matches = []
1070
  for name in name_list:
1071
    if mo.match(name) is not None:
1072
      names_filtered.append(name)
1073
      if not case_sensitive and key == name.upper():
1074
        string_matches.append(name)
1075

    
1076
  if len(string_matches) == 1:
1077
    return string_matches[0]
1078
  if len(names_filtered) == 1:
1079
    return names_filtered[0]
1080
  return None
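
# Illustrative examples (not part of the original module), following the
# docstring above:
#
#   MatchNameComponent("test1", ["test1.example.com", "test2.example.com"])
#       -> "test1.example.com"
#   MatchNameComponent("test1", ["test1.example.com", "test1.example.org"])
#       -> None (ambiguous match)
#   MatchNameComponent("test1", ["test1", "test1.example.com"])
#       -> "test1" (exact match wins)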
1081

    
1082

    
1083
class HostInfo:
1084
  """Class implementing resolver and hostname functionality
1085

1086
  """
1087
  _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$")
1088

    
1089
  def __init__(self, name=None):
1090
    """Initialize the host name object.
1091

1092
    If the name argument is not passed, it will use this system's
1093
    name.
1094

1095
    """
1096
    if name is None:
1097
      name = self.SysName()
1098

    
1099
    self.query = name
1100
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
1101
    self.ip = self.ipaddrs[0]
1102

    
1103
  def ShortName(self):
1104
    """Returns the hostname without domain.
1105

1106
    """
1107
    return self.name.split('.')[0]
1108

    
1109
  @staticmethod
1110
  def SysName():
1111
    """Return the current system's name.
1112

1113
    This is simply a wrapper over C{socket.gethostname()}.
1114

1115
    """
1116
    return socket.gethostname()
1117

    
1118
  @staticmethod
1119
  def LookupHostname(hostname):
1120
    """Look up hostname
1121

1122
    @type hostname: str
1123
    @param hostname: hostname to look up
1124

1125
    @rtype: tuple
1126
    @return: a tuple (name, aliases, ipaddrs) as returned by
1127
        C{socket.gethostbyname_ex}
1128
    @raise errors.ResolverError: in case of errors in resolving
1129

1130
    """
1131
    try:
1132
      result = socket.gethostbyname_ex(hostname)
1133
    except (socket.gaierror, socket.herror, socket.error), err:
1134
      # hostname not found in DNS, or other socket exception in the
1135
      # (code, description) format
1136
      raise errors.ResolverError(hostname, err.args[0], err.args[1])
1137

    
1138
    return result
1139

    
1140
  @classmethod
1141
  def NormalizeName(cls, hostname):
1142
    """Validate and normalize the given hostname.
1143

1144
    @attention: the validation is a bit more relaxed than the standards
1145
        require; most importantly, we allow underscores in names
1146
    @raise errors.OpPrereqError: when the name is not valid
1147

1148
    """
1149
    hostname = hostname.lower()
1150
    if (not cls._VALID_NAME_RE.match(hostname) or
1151
        # double-dots, meaning empty label
1152
        ".." in hostname or
1153
        # empty initial label
1154
        hostname.startswith(".")):
1155
      raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
1156
                                 errors.ECODE_INVAL)
1157
    if hostname.endswith("."):
1158
      hostname = hostname.rstrip(".")
1159
    return hostname
1160

    
1161

    
1162
def ValidateServiceName(name):
1163
  """Validate the given service name.
1164

1165
  @type name: number or string
1166
  @param name: Service name or port specification
1167

1168
  """
1169
  try:
1170
    numport = int(name)
1171
  except (ValueError, TypeError):
1172
    # Non-numeric service name
1173
    valid = _VALID_SERVICE_NAME_RE.match(name)
1174
  else:
1175
    # Numeric port (protocols other than TCP or UDP might need adjustments
1176
    # here)
1177
    valid = (numport >= 0 and numport < (1 << 16))
1178

    
1179
  if not valid:
1180
    raise errors.OpPrereqError("Invalid service name '%s'" % name,
1181
                               errors.ECODE_INVAL)
1182

    
1183
  return name
1184

    
1185

    
1186
def GetHostInfo(name=None):
1187
  """Lookup host name and raise an OpPrereqError for failures"""
1188

    
1189
  try:
1190
    return HostInfo(name)
1191
  except errors.ResolverError, err:
1192
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
1193
                               (err[0], err[2]), errors.ECODE_RESOLVER)
1194

    
1195

    
1196
def ListVolumeGroups():
1197
  """List volume groups and their size
1198

1199
  @rtype: dict
1200
  @return:
1201
       Dictionary with keys volume name and values
1202
       the size of the volume
1203

1204
  """
1205
  command = "vgs --noheadings --units m --nosuffix -o name,size"
1206
  result = RunCmd(command)
1207
  retval = {}
1208
  if result.failed:
1209
    return retval
1210

    
1211
  for line in result.stdout.splitlines():
1212
    try:
1213
      name, size = line.split()
1214
      size = int(float(size))
1215
    except (IndexError, ValueError), err:
1216
      logging.error("Invalid output from vgs (%s): %s", err, line)
1217
      continue
1218

    
1219
    retval[name] = size
1220

    
1221
  return retval
1222

    
1223

    
1224
def BridgeExists(bridge):
1225
  """Check whether the given bridge exists in the system
1226

1227
  @type bridge: str
1228
  @param bridge: the bridge name to check
1229
  @rtype: boolean
1230
  @return: True if it does
1231

1232
  """
1233
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
1234

    
1235

    
1236
def NiceSort(name_list):
1237
  """Sort a list of strings based on digit and non-digit groupings.
1238

1239
  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
1240
  will sort the list in the logical order C{['a1', 'a2', 'a10',
1241
  'a11']}.
1242

1243
  The sort algorithm breaks each name in groups of either only-digits
1244
  or no-digits. Only the first eight such groups are considered, and
1245
  after that we just use what's left of the string.
1246

1247
  @type name_list: list
1248
  @param name_list: the names to be sorted
1249
  @rtype: list
1250
  @return: a copy of the name list sorted with our algorithm
1251

1252
  """
1253
  _SORTER_BASE = "(\D+|\d+)"
1254
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
1255
                                                  _SORTER_BASE, _SORTER_BASE,
1256
                                                  _SORTER_BASE, _SORTER_BASE,
1257
                                                  _SORTER_BASE, _SORTER_BASE)
1258
  _SORTER_RE = re.compile(_SORTER_FULL)
1259
  _SORTER_NODIGIT = re.compile("^\D*$")
1260
  def _TryInt(val):
1261
    """Attempts to convert a variable to integer."""
1262
    if val is None or _SORTER_NODIGIT.match(val):
1263
      return val
1264
    rval = int(val)
1265
    return rval
1266

    
1267
  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
1268
             for name in name_list]
1269
  to_sort.sort()
1270
  return [tup[1] for tup in to_sort]
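
# Illustrative example (not part of the original module), from the docstring:
#
#   NiceSort(["a1", "a10", "a11", "a2"]) -> ["a1", "a2", "a10", "a11"]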
1271

    
1272

    
1273
def TryConvert(fn, val):
1274
  """Try to convert a value ignoring errors.
1275

1276
  This function tries to apply function I{fn} to I{val}. If no
1277
  C{ValueError} or C{TypeError} exceptions are raised, it will return
1278
  the result, else it will return the original value. Any other
1279
  exceptions are propagated to the caller.
1280

1281
  @type fn: callable
1282
  @param fn: function to apply to the value
1283
  @param val: the value to be converted
1284
  @return: The converted value if the conversion was successful,
1285
      otherwise the original value.
1286

1287
  """
1288
  try:
1289
    nv = fn(val)
1290
  except (ValueError, TypeError):
1291
    nv = val
1292
  return nv
1293

    
1294

    
1295
def _GenericIsValidIP(family, ip):
1296
  """Generic internal version of ip validation.
1297

1298
  @type family: int
1299
  @param family: socket.AF_INET | socket.AF_INET6
1300
  @type ip: str
1301
  @param ip: the address to be checked
1302
  @rtype: boolean
1303
  @return: True if ip is valid, False otherwise
1304

1305
  """
1306
  try:
1307
    socket.inet_pton(family, ip)
1308
    return True
1309
  except socket.error:
1310
    return False
1311

    
1312

    
1313
def IsValidIP4(ip):
1314
  """Verifies an IPv4 address.
1315

1316
  This function checks if the given address is a valid IPv4 address.
1317

1318
  @type ip: str
1319
  @param ip: the address to be checked
1320
  @rtype: boolean
1321
  @return: True if ip is valid, False otherwise
1322

1323
  """
1324
  return _GenericIsValidIP(socket.AF_INET, ip)
1325

    
1326

    
1327
def IsValidIP6(ip):
1328
  """Verifies an IPv6 address.
1329

1330
  This function checks if the given address is a valid IPv6 address.
1331

1332
  @type ip: str
1333
  @param ip: the address to be checked
1334
  @rtype: boolean
1335
  @return: True if ip is valid, False otherwise
1336

1337
  """
1338
  return _GenericIsValidIP(socket.AF_INET6, ip)
1339

    
1340

    
1341
def IsValidIP(ip):
1342
  """Verifies an IP address.
1343

1344
  This function checks if the given IP address (both IPv4 and IPv6) is valid.
1345

1346
  @type ip: str
1347
  @param ip: the address to be checked
1348
  @rtype: boolean
1349
  @return: True if ip is valid, False otherwise
1350

1351
  """
1352
  return IsValidIP4(ip) or IsValidIP6(ip)
1353

    
1354

    
1355
def GetAddressFamily(ip):
1356
  """Get the address family of the given address.
1357

1358
  @type ip: str
1359
  @param ip: ip address whose family will be returned
1360
  @rtype: int
1361
  @return: socket.AF_INET or socket.AF_INET6
1362
  @raise errors.GenericError: for invalid addresses
1363

1364
  """
1365
  if IsValidIP6(ip):
1366
    return socket.AF_INET6
1367
  elif IsValidIP4(ip):
1368
    return socket.AF_INET
1369
  else:
1370
    raise errors.GenericError("Address %s not valid" % ip)
1371

    
1372

    
1373
def IsValidShellParam(word):
1374
  """Verifies is the given word is safe from the shell's p.o.v.
1375

1376
  This means that we can pass this to a command via the shell and be
1377
  sure that it doesn't alter the command line and is passed as such to
1378
  the actual command.
1379

1380
  Note that we are overly restrictive here, in order to be on the safe
1381
  side.
1382

1383
  @type word: str
1384
  @param word: the word to check
1385
  @rtype: boolean
1386
  @return: True if the word is 'safe'
1387

1388
  """
1389
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))
1390

    
1391

    
1392
def BuildShellCmd(template, *args):
1393
  """Build a safe shell command line from the given arguments.
1394

1395
  This function will check all arguments in the args list so that they
1396
  are valid shell parameters (i.e. they don't contain shell
1397
  metacharacters). If everything is ok, it will return the result of
1398
  template % args.
1399

1400
  @type template: str
1401
  @param template: the string holding the template for the
1402
      string formatting
1403
  @rtype: str
1404
  @return: the expanded command line
1405

1406
  """
1407
  for word in args:
1408
    if not IsValidShellParam(word):
1409
      raise errors.ProgrammerError("Shell argument '%s' contains"
1410
                                   " invalid characters" % word)
1411
  return template % args
1412

    
1413

    
1414
def FormatUnit(value, units):
1415
  """Formats an incoming number of MiB with the appropriate unit.
1416

1417
  @type value: int
1418
  @param value: integer representing the value in MiB (1048576)
1419
  @type units: char
1420
  @param units: the type of formatting we should do:
1421
      - 'h' for automatic scaling
1422
      - 'm' for MiBs
1423
      - 'g' for GiBs
1424
      - 't' for TiBs
1425
  @rtype: str
1426
  @return: the formatted value (with suffix)
1427

1428
  """
1429
  if units not in ('m', 'g', 't', 'h'):
1430
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))
1431

    
1432
  suffix = ''
1433

    
1434
  if units == 'm' or (units == 'h' and value < 1024):
1435
    if units == 'h':
1436
      suffix = 'M'
1437
    return "%d%s" % (round(value, 0), suffix)
1438

    
1439
  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
1440
    if units == 'h':
1441
      suffix = 'G'
1442
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)
1443

    
1444
  else:
1445
    if units == 'h':
1446
      suffix = 'T'
1447
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)
1448

    
1449

    
1450
def ParseUnit(input_string):
1451
  """Tries to extract number and scale from the given string.
1452

1453
  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
1454
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
1455
  is always an int in MiB.
1456

1457
  """
1458
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
1459
  if not m:
1460
    raise errors.UnitParseError("Invalid format")
1461

    
1462
  value = float(m.groups()[0])
1463

    
1464
  unit = m.groups()[1]
1465
  if unit:
1466
    lcunit = unit.lower()
1467
  else:
1468
    lcunit = 'm'
1469

    
1470
  if lcunit in ('m', 'mb', 'mib'):
1471
    # Value already in MiB
1472
    pass
1473

    
1474
  elif lcunit in ('g', 'gb', 'gib'):
1475
    value *= 1024
1476

    
1477
  elif lcunit in ('t', 'tb', 'tib'):
1478
    value *= 1024 * 1024
1479

    
1480
  else:
1481
    raise errors.UnitParseError("Unknown unit: %s" % unit)
1482

    
1483
  # Make sure we round up
1484
  if int(value) < value:
1485
    value += 1
1486

    
1487
  # Round up to the next multiple of 4
1488
  value = int(value)
1489
  if value % 4:
1490
    value += 4 - value % 4
1491

    
1492
  return value
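
# Illustrative examples (not part of the original module): the result is always
# an integer number of MiB, rounded up to the next multiple of 4:
#
#   ParseUnit("1g")  -> 1024
#   ParseUnit("30m") -> 32
#   ParseUnit("30")  -> 32   (no unit defaults to MiB)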
1493

    
1494

    
1495
def AddAuthorizedKey(file_name, key):
1496
  """Adds an SSH public key to an authorized_keys file.
1497

1498
  @type file_name: str
1499
  @param file_name: path to authorized_keys file
1500
  @type key: str
1501
  @param key: string containing key
1502

1503
  """
1504
  key_fields = key.split()
1505

    
1506
  f = open(file_name, 'a+')
1507
  try:
1508
    nl = True
1509
    for line in f:
1510
      # Ignore whitespace changes
1511
      if line.split() == key_fields:
1512
        break
1513
      nl = line.endswith('\n')
1514
    else:
1515
      if not nl:
1516
        f.write("\n")
1517
      f.write(key.rstrip('\r\n'))
1518
      f.write("\n")
1519
      f.flush()
1520
  finally:
1521
    f.close()
1522

    
1523

    
1524
def RemoveAuthorizedKey(file_name, key):
1525
  """Removes an SSH public key from an authorized_keys file.
1526

1527
  @type file_name: str
1528
  @param file_name: path to authorized_keys file
1529
  @type key: str
1530
  @param key: string containing key
1531

1532
  """
1533
  key_fields = key.split()
1534

    
1535
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1536
  try:
1537
    out = os.fdopen(fd, 'w')
1538
    try:
1539
      f = open(file_name, 'r')
1540
      try:
1541
        for line in f:
1542
          # Ignore whitespace changes while comparing lines
1543
          if line.split() != key_fields:
1544
            out.write(line)
1545

    
1546
        out.flush()
1547
        os.rename(tmpname, file_name)
1548
      finally:
1549
        f.close()
1550
    finally:
1551
      out.close()
1552
  except:
1553
    RemoveFile(tmpname)
1554
    raise
1555

    
1556

    
1557
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
1558
  """Sets the name of an IP address and hostname in /etc/hosts.
1559

1560
  @type file_name: str
1561
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1562
  @type ip: str
1563
  @param ip: the IP address
1564
  @type hostname: str
1565
  @param hostname: the hostname to be added
1566
  @type aliases: list
1567
  @param aliases: the list of aliases to add for the hostname
1568

1569
  """
1570
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1571
  # Ensure aliases are unique
1572
  aliases = UniqueSequence([hostname] + aliases)[1:]
1573

    
1574
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1575
  try:
1576
    out = os.fdopen(fd, 'w')
1577
    try:
1578
      f = open(file_name, 'r')
1579
      try:
1580
        for line in f:
1581
          fields = line.split()
1582
          if fields and not fields[0].startswith('#') and ip == fields[0]:
1583
            continue
1584
          out.write(line)
1585

    
1586
        out.write("%s\t%s" % (ip, hostname))
1587
        if aliases:
1588
          out.write(" %s" % ' '.join(aliases))
1589
        out.write('\n')
1590

    
1591
        out.flush()
1592
        os.fsync(out)
1593
        os.chmod(tmpname, 0644)
1594
        os.rename(tmpname, file_name)
1595
      finally:
1596
        f.close()
1597
    finally:
1598
      out.close()
1599
  except:
1600
    RemoveFile(tmpname)
1601
    raise
1602

    
1603

    
1604
def AddHostToEtcHosts(hostname):
1605
  """Wrapper around SetEtcHostsEntry.
1606

1607
  @type hostname: str
1608
  @param hostname: a hostname that will be resolved and added to
1609
      L{constants.ETC_HOSTS}
1610

1611
  """
1612
  hi = HostInfo(name=hostname)
1613
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])
1614

    
1615

    
1616
def RemoveEtcHostsEntry(file_name, hostname):
1617
  """Removes a hostname from /etc/hosts.
1618

1619
  IP addresses without names are removed from the file.
1620

1621
  @type file_name: str
1622
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1623
  @type hostname: str
1624
  @param hostname: the hostname to be removed
1625

1626
  """
1627
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1628
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1629
  try:
1630
    out = os.fdopen(fd, 'w')
1631
    try:
1632
      f = open(file_name, 'r')
1633
      try:
1634
        for line in f:
1635
          fields = line.split()
1636
          if len(fields) > 1 and not fields[0].startswith('#'):
1637
            names = fields[1:]
1638
            if hostname in names:
1639
              while hostname in names:
1640
                names.remove(hostname)
1641
              if names:
1642
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
1643
              continue
1644

    
1645
          out.write(line)
1646

    
1647
        out.flush()
1648
        os.fsync(out)
1649
        os.chmod(tmpname, 0644)
1650
        os.rename(tmpname, file_name)
1651
      finally:
1652
        f.close()
1653
    finally:
1654
      out.close()
1655
  except:
1656
    RemoveFile(tmpname)
1657
    raise
1658

    
1659

    
1660
def RemoveHostFromEtcHosts(hostname):
1661
  """Wrapper around RemoveEtcHostsEntry.
1662

1663
  @type hostname: str
1664
  @param hostname: hostname that will be resolved and its
1665
      full and short name will be removed from
1666
      L{constants.ETC_HOSTS}
1667

1668
  """
1669
  hi = HostInfo(name=hostname)
1670
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
1671
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())
1672

    
1673

    
1674
def TimestampForFilename():
1675
  """Returns the current time formatted for filenames.
1676

1677
  The format doesn't contain colons, as some shells and applications use them as
1678
  separators.
1679

1680
  """
1681
  return time.strftime("%Y-%m-%d_%H_%M_%S")
1682

    
1683

    
1684
def CreateBackup(file_name):
1685
  """Creates a backup of a file.
1686

1687
  @type file_name: str
1688
  @param file_name: file to be backed up
1689
  @rtype: str
1690
  @return: the path to the newly created backup
1691
  @raise errors.ProgrammerError: for invalid file names
1692

1693
  """
1694
  if not os.path.isfile(file_name):
1695
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
1696
                                file_name)
1697

    
1698
  prefix = ("%s.backup-%s." %
1699
            (os.path.basename(file_name), TimestampForFilename()))
1700
  dir_name = os.path.dirname(file_name)
1701

    
1702
  fsrc = open(file_name, 'rb')
1703
  try:
1704
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
1705
    fdst = os.fdopen(fd, 'wb')
1706
    try:
1707
      logging.debug("Backing up %s at %s", file_name, backup_name)
1708
      shutil.copyfileobj(fsrc, fdst)
1709
    finally:
1710
      fdst.close()
1711
  finally:
1712
    fsrc.close()
1713

    
1714
  return backup_name
1715

    
1716

    
1717
def ShellQuote(value):
1718
  """Quotes shell argument according to POSIX.
1719

1720
  @type value: str
1721
  @param value: the argument to be quoted
1722
  @rtype: str
1723
  @return: the quoted value
1724

1725
  """
1726
  if _re_shell_unquoted.match(value):
1727
    return value
1728
  else:
1729
    return "'%s'" % value.replace("'", "'\\''")
1730

    
1731

    
1732
def ShellQuoteArgs(args):
1733
  """Quotes a list of shell arguments.
1734

1735
  @type args: list
1736
  @param args: list of arguments to be quoted
1737
  @rtype: str
1738
  @return: the quoted arguments concatenated with spaces
1739

1740
  """
1741
  return ' '.join([ShellQuote(i) for i in args])
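
# Illustrative examples (not part of the original module): arguments matching
# _re_shell_unquoted are left as-is, anything else gets single-quoted:
#
#   ShellQuote("/bin/ls")            -> "/bin/ls"
#   ShellQuote("two words")          -> "'two words'"
#   ShellQuoteArgs(["echo", "a b"])  -> "echo 'a b'"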
1742

    
1743

    
1744
def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
1745
  """Simple ping implementation using TCP connect(2).
1746

1747
  Check if the given IP is reachable by attempting a TCP connect
1748
  to it.
1749

1750
  @type target: str
1751
  @param target: the IP or hostname to ping
1752
  @type port: int
1753
  @param port: the port to connect to
1754
  @type timeout: int
1755
  @param timeout: the timeout on the connection attempt
1756
  @type live_port_needed: boolean
1757
  @param live_port_needed: whether a closed port will cause the
1758
      function to return failure, as if there was a timeout
1759
  @type source: str or None
1760
  @param source: if specified, will cause the connect to be made
1761
      from this specific source address; failures to bind other
1762
      than C{EADDRNOTAVAIL} will be ignored
1763

1764
  """
1765
  try:
1766
    family = GetAddressFamily(target)
1767
  except errors.GenericError:
1768
    return False
1769

    
1770
  sock = socket.socket(family, socket.SOCK_STREAM)
1771
  success = False
1772

    
1773
  if source is not None:
1774
    try:
1775
      sock.bind((source, 0))
1776
    except socket.error, (errcode, _):
1777
      if errcode == errno.EADDRNOTAVAIL:
1778
        success = False
1779

    
1780
  sock.settimeout(timeout)
1781

    
1782
  try:
1783
    sock.connect((target, port))
1784
    sock.close()
1785
    success = True
1786
  except socket.timeout:
1787
    success = False
1788
  except socket.error, (errcode, _):
1789
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
1790

    
1791
  return success
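# Illustrative usage of TcpPing (added comment, not part of the original
# module; the address and port are hypothetical): TcpPing("192.0.2.10", 22,
# timeout=5, live_port_needed=True) returns True only if something is
# actually listening on the port, while with live_port_needed=False a
# refused connection also counts as success.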


def OwnIpAddress(address):
  """Check if the current host has the given IP address.
1796

1797
  This is done by trying to bind to the given address. We return True if we
  succeed or False if a socket.error is raised.
1799

1800
  @type address: string
1801
  @param address: the address to check
1802
  @rtype: bool
1803
  @return: True if we own the address
1804

1805
  """
1806
  family = GetAddressFamily(address)
1807
  s = socket.socket(family, socket.SOCK_DGRAM)
1808
  success = False
1809
  try:
1810
    s.bind((address, 0))
1811
    success = True
1812
  except socket.error:
1813
    success = False
1814
  finally:
1815
    s.close()
1816
  return success
1817

    
1818

    
1819
def ListVisibleFiles(path):
1820
  """Returns a list of visible files in a directory.
1821

1822
  @type path: str
1823
  @param path: the directory to enumerate
1824
  @rtype: list
1825
  @return: the list of all files not starting with a dot
1826
  @raise errors.ProgrammerError: if L{path} is not an absolute and
      normalized path
1827

1828
  """
1829
  if not IsNormAbsPath(path):
1830
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
1831
                                 " absolute/normalized: '%s'" % path)
1832
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1833
  return files
1834

    
1835

    
1836
def GetHomeDir(user, default=None):
1837
  """Try to get the homedir of the given user.
1838

1839
  The user can be passed either as a string (denoting the name) or as
1840
  an integer (denoting the user id). If the user is not found, the
1841
  'default' argument is returned, which defaults to None.
1842

1843
  """
1844
  try:
1845
    if isinstance(user, basestring):
1846
      result = pwd.getpwnam(user)
1847
    elif isinstance(user, (int, long)):
1848
      result = pwd.getpwuid(user)
1849
    else:
1850
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1851
                                   type(user))
1852
  except KeyError:
1853
    return default
1854
  return result.pw_dir
1855

    
1856

    
1857
def NewUUID():
1858
  """Returns a random UUID.
1859

1860
  @note: This is a Linux-specific method as it uses the /proc
1861
      filesystem.
1862
  @rtype: str
1863

1864
  """
1865
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1866

    
1867

    
1868
def GenerateSecret(numbytes=20):
1869
  """Generates a random secret.
1870

1871
  This will generate a pseudo-random secret returning a hex string
1872
  (so that it can be used where an ASCII string is needed).
1873

1874
  @param numbytes: the number of bytes which will be represented by the returned
1875
      string (defaulting to 20, the length of a SHA1 hash)
1876
  @rtype: str
1877
  @return: a hex representation of the pseudo-random sequence
1878

1879
  """
1880
  return os.urandom(numbytes).encode('hex')
1881

    
1882

    
1883
def EnsureDirs(dirs):
1884
  """Make required directories, if they don't exist.
1885

1886
  @param dirs: list of tuples (dir_name, dir_mode)
1887
  @type dirs: list of (string, integer)
1888

1889
  """
1890
  for dir_name, dir_mode in dirs:
1891
    try:
1892
      os.mkdir(dir_name, dir_mode)
1893
    except EnvironmentError, err:
1894
      if err.errno != errno.EEXIST:
1895
        raise errors.GenericError("Cannot create needed directory"
1896
                                  " '%s': %s" % (dir_name, err))
1897
    try:
1898
      os.chmod(dir_name, dir_mode)
1899
    except EnvironmentError, err:
1900
      raise errors.GenericError("Cannot change directory permissions on"
1901
                                " '%s': %s" % (dir_name, err))
1902
    if not os.path.isdir(dir_name):
1903
      raise errors.GenericError("%s is not a directory" % dir_name)
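# Illustrative usage of EnsureDirs (added comment, not part of the original
# module; paths and modes are hypothetical):
# EnsureDirs([("/var/lib/example", 0755), ("/var/lib/example/data", 0700)])
# creates any missing directory and (re)applies the requested mode to each.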
1904

    
1905

    
1906
def ReadFile(file_name, size=-1):
1907
  """Reads a file.
1908

1909
  @type size: int
1910
  @param size: Read at most size bytes (if negative, entire file)
1911
  @rtype: str
1912
  @return: the (possibly partial) content of the file
1913

1914
  """
1915
  f = open(file_name, "r")
1916
  try:
1917
    return f.read(size)
1918
  finally:
1919
    f.close()
1920

    
1921

    
1922
def WriteFile(file_name, fn=None, data=None,
1923
              mode=None, uid=-1, gid=-1,
1924
              atime=None, mtime=None, close=True,
1925
              dry_run=False, backup=False,
1926
              prewrite=None, postwrite=None):
1927
  """(Over)write a file atomically.
1928

1929
  The file_name and either fn (a function taking one argument, the
1930
  file descriptor, and which should write the data to it) or data (the
1931
  contents of the file) must be passed. The other arguments are
1932
  optional and allow setting the file mode, owner and group, and the
1933
  mtime/atime of the file.
1934

1935
  If the function doesn't raise an exception, it has succeeded and the
1936
  target file has the new contents. If the function has raised an
1937
  exception, an existing target file should be unmodified and the
1938
  temporary file should be removed.
1939

1940
  @type file_name: str
1941
  @param file_name: the target filename
1942
  @type fn: callable
1943
  @param fn: content writing function, called with
1944
      file descriptor as parameter
1945
  @type data: str
1946
  @param data: contents of the file
1947
  @type mode: int
1948
  @param mode: file mode
1949
  @type uid: int
1950
  @param uid: the owner of the file
1951
  @type gid: int
1952
  @param gid: the group of the file
1953
  @type atime: int
1954
  @param atime: a custom access time to be set on the file
1955
  @type mtime: int
1956
  @param mtime: a custom modification time to be set on the file
1957
  @type close: boolean
1958
  @param close: whether to close file after writing it
1959
  @type prewrite: callable
1960
  @param prewrite: function to be called before writing content
1961
  @type postwrite: callable
1962
  @param postwrite: function to be called after writing content
1963

1964
  @rtype: None or int
1965
  @return: None if the 'close' parameter evaluates to True,
1966
      otherwise the file descriptor
1967

1968
  @raise errors.ProgrammerError: if any of the arguments are not valid
1969

1970
  """
1971
  if not os.path.isabs(file_name):
1972
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1973
                                 " absolute: '%s'" % file_name)
1974

    
1975
  if [fn, data].count(None) != 1:
1976
    raise errors.ProgrammerError("fn or data required")
1977

    
1978
  if [atime, mtime].count(None) == 1:
1979
    raise errors.ProgrammerError("Both atime and mtime must be either"
1980
                                 " set or None")
1981

    
1982
  if backup and not dry_run and os.path.isfile(file_name):
1983
    CreateBackup(file_name)
1984

    
1985
  dir_name, base_name = os.path.split(file_name)
1986
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1987
  do_remove = True
1988
  # here we need to make sure we remove the temp file, if any error
1989
  # leaves it in place
1990
  try:
1991
    if uid != -1 or gid != -1:
1992
      os.chown(new_name, uid, gid)
1993
    if mode:
1994
      os.chmod(new_name, mode)
1995
    if callable(prewrite):
1996
      prewrite(fd)
1997
    if data is not None:
1998
      os.write(fd, data)
1999
    else:
2000
      fn(fd)
2001
    if callable(postwrite):
2002
      postwrite(fd)
2003
    os.fsync(fd)
2004
    if atime is not None and mtime is not None:
2005
      os.utime(new_name, (atime, mtime))
2006
    if not dry_run:
2007
      os.rename(new_name, file_name)
2008
      do_remove = False
2009
  finally:
2010
    if close:
2011
      os.close(fd)
2012
      result = None
2013
    else:
2014
      result = fd
2015
    if do_remove:
2016
      RemoveFile(new_name)
2017

    
2018
  return result
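# Illustrative usage of WriteFile (added comment, not part of the original
# module; the path is hypothetical): WriteFile("/tmp/example.conf",
# data="key = value\n", mode=0644, backup=True) writes the data to a
# temporary file in the same directory, fsyncs it and renames it over the
# target, creating a timestamped backup of any previous version first.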
2019

    
2020

    
2021
def ReadOneLineFile(file_name, strict=False):
2022
  """Return the first non-empty line from a file.
2023

2024
  @type strict: boolean
2025
  @param strict: if True, abort if the file has more than one
2026
      non-empty line
2027

2028
  """
2029
  file_lines = ReadFile(file_name).splitlines()
2030
  full_lines = filter(bool, file_lines)
2031
  if not file_lines or not full_lines:
2032
    raise errors.GenericError("No data in one-liner file %s" % file_name)
2033
  elif strict and len(full_lines) > 1:
2034
    raise errors.GenericError("Too many lines in one-liner file %s" %
2035
                              file_name)
2036
  return full_lines[0]
2037

    
2038

    
2039
def FirstFree(seq, base=0):
2040
  """Returns the first non-existing integer from seq.
2041

2042
  The seq argument should be a sorted list of positive integers. The
2043
  first time the index of an element is smaller than the element
2044
  value, the index will be returned.
2045

2046
  The base argument is used to start at a different offset,
2047
  i.e. C{[3, 4, 6]} with I{base=3} will return 5.
2048

2049
  Example: C{[0, 1, 3]} will return I{2}.
2050

2051
  @type seq: sequence
2052
  @param seq: the sequence to be analyzed.
2053
  @type base: int
2054
  @param base: use this value as the base index of the sequence
2055
  @rtype: int
2056
  @return: the first non-used index in the sequence
2057

2058
  """
2059
  for idx, elem in enumerate(seq):
2060
    assert elem >= base, "Passed element is higher than base offset"
2061
    if elem > idx + base:
2062
      # idx is not used
2063
      return idx + base
2064
  return None
2065

    
2066

    
2067
def SingleWaitForFdCondition(fdobj, event, timeout):
2068
  """Waits for a condition to occur on the socket.
2069

2070
  Immediately returns at the first interruption.
2071

2072
  @type fdobj: integer or object supporting a fileno() method
2073
  @param fdobj: entity to wait for events on
2074
  @type event: integer
2075
  @param event: ORed condition (see select module)
2076
  @type timeout: float or None
2077
  @param timeout: Timeout in seconds
2078
  @rtype: int or None
2079
  @return: None for timeout, otherwise the conditions that occurred
2080

2081
  """
2082
  check = (event | select.POLLPRI |
2083
           select.POLLNVAL | select.POLLHUP | select.POLLERR)
2084

    
2085
  if timeout is not None:
2086
    # Poller object expects milliseconds
2087
    timeout *= 1000
2088

    
2089
  poller = select.poll()
2090
  poller.register(fdobj, event)
2091
  try:
2092
    # TODO: If the main thread receives a signal and we have no timeout, we
2093
    # could wait forever. This should check a global "quit" flag or something
2094
    # every so often.
2095
    io_events = poller.poll(timeout)
2096
  except select.error, err:
2097
    if err[0] != errno.EINTR:
2098
      raise
2099
    io_events = []
2100
  if io_events and io_events[0][1] & check:
2101
    return io_events[0][1]
2102
  else:
2103
    return None
2104

    
2105

    
2106
class FdConditionWaiterHelper(object):
2107
  """Retry helper for WaitForFdCondition.
2108

2109
  This class contains the retried and wait functions that make sure
2110
  WaitForFdCondition can continue waiting until the timeout is actually
2111
  expired.
2112

2113
  """
2114

    
2115
  def __init__(self, timeout):
2116
    self.timeout = timeout
2117

    
2118
  def Poll(self, fdobj, event):
2119
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
2120
    if result is None:
2121
      raise RetryAgain()
2122
    else:
2123
      return result
2124

    
2125
  def UpdateTimeout(self, timeout):
2126
    self.timeout = timeout
2127

    
2128

    
2129
def WaitForFdCondition(fdobj, event, timeout):
2130
  """Waits for a condition to occur on the socket.
2131

2132
  Retries until the timeout is expired, even if interrupted.
2133

2134
  @type fdobj: integer or object supporting a fileno() method
2135
  @param fdobj: entity to wait for events on
2136
  @type event: integer
2137
  @param event: ORed condition (see select module)
2138
  @type timeout: float or None
2139
  @param timeout: Timeout in seconds
2140
  @rtype: int or None
2141
  @return: None for timeout, otherwise the conditions that occurred
2142

2143
  """
2144
  if timeout is not None:
2145
    retrywaiter = FdConditionWaiterHelper(timeout)
2146
    try:
2147
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
2148
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
2149
    except RetryTimeout:
2150
      result = None
2151
  else:
2152
    result = None
2153
    while result is None:
2154
      result = SingleWaitForFdCondition(fdobj, event, timeout)
2155
  return result
2156

    
2157

    
2158
def UniqueSequence(seq):
2159
  """Returns a list with unique elements.
2160

2161
  Element order is preserved.
2162

2163
  @type seq: sequence
2164
  @param seq: the sequence with the source elements
2165
  @rtype: list
2166
  @return: list of unique elements from seq
2167

2168
  """
2169
  seen = set()
2170
  return [i for i in seq if i not in seen and not seen.add(i)]
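# Illustrative usage of UniqueSequence (added comment, not part of the
# original module): UniqueSequence([1, 2, 1, 3, 2]) returns [1, 2, 3],
# keeping only the first occurrence of each element.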
2171

    
2172

    
2173
def NormalizeAndValidateMac(mac):
  """Normalizes and checks if a MAC address is valid.
2175

2176
  Checks whether the supplied MAC address is formally correct; only
  the colon-separated format is accepted. Normalizes it to all lowercase.
2178

2179
  @type mac: str
2180
  @param mac: the MAC to be validated
2181
  @rtype: str
2182
  @return: returns the normalized and validated MAC.
2183

2184
  @raise errors.OpPrereqError: If the MAC isn't valid
2185

2186
  """
2187
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
2188
  if not mac_check.match(mac):
2189
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
2190
                               mac, errors.ECODE_INVAL)
2191

    
2192
  return mac.lower()
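# Illustrative usage of NormalizeAndValidateMac (added comment, not part of
# the original module): NormalizeAndValidateMac("AA:00:11:22:33:44") returns
# "aa:00:11:22:33:44", while a dash-separated value such as
# "aa-00-11-22-33-44" raises errors.OpPrereqError.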
2193

    
2194

    
2195
def TestDelay(duration):
2196
  """Sleep for a fixed amount of time.
2197

2198
  @type duration: float
2199
  @param duration: the sleep duration
2200
  @rtype: tuple
  @return: (False, error message) for negative values, (True, None) otherwise
2202

2203
  """
2204
  if duration < 0:
2205
    return False, "Invalid sleep duration"
2206
  time.sleep(duration)
2207
  return True, None
2208

    
2209

    
2210
def _CloseFDNoErr(fd, retries=5):
2211
  """Close a file descriptor ignoring errors.
2212

2213
  @type fd: int
2214
  @param fd: the file descriptor
2215
  @type retries: int
2216
  @param retries: how many retries to make, in case we get any
2217
      other error than EBADF
2218

2219
  """
2220
  try:
2221
    os.close(fd)
2222
  except OSError, err:
2223
    if err.errno != errno.EBADF:
2224
      if retries > 0:
2225
        _CloseFDNoErr(fd, retries - 1)
2226
    # else either it's closed already or we're out of retries, so we
2227
    # ignore this and go on
2228

    
2229

    
2230
def CloseFDs(noclose_fds=None):
2231
  """Close file descriptors.
2232

2233
  This closes all file descriptors above 2 (i.e. except
2234
  stdin/out/err).
2235

2236
  @type noclose_fds: list or None
2237
  @param noclose_fds: if given, it denotes a list of file descriptors
      that should not be closed
2239

2240
  """
2241
  # Default maximum for the number of available file descriptors.
2242
  if 'SC_OPEN_MAX' in os.sysconf_names:
2243
    try:
2244
      MAXFD = os.sysconf('SC_OPEN_MAX')
2245
      if MAXFD < 0:
2246
        MAXFD = 1024
2247
    except OSError:
2248
      MAXFD = 1024
2249
  else:
2250
    MAXFD = 1024
2251
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
2252
  if (maxfd == resource.RLIM_INFINITY):
2253
    maxfd = MAXFD
2254

    
2255
  # Iterate through and close all file descriptors (except the standard ones)
2256
  for fd in range(3, maxfd):
2257
    if noclose_fds and fd in noclose_fds:
2258
      continue
2259
    _CloseFDNoErr(fd)
2260

    
2261

    
2262
def Mlockall():
2263
  """Lock current process' virtual address space into RAM.
2264

2265
  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
2266
  see mlock(2) for more details. This function requires ctypes module.
2267

2268
  """
2269
  if ctypes is None:
2270
    logging.warning("Cannot set memory lock, ctypes module not found")
2271
    return
2272

    
2273
  libc = ctypes.cdll.LoadLibrary("libc.so.6")
2274
  if libc is None:
2275
    logging.error("Cannot set memory lock, ctypes cannot load libc")
2276
    return
2277

    
2278
  # Some older version of the ctypes module don't have built-in functionality
2279
  # to access the errno global variable, where function error codes are stored.
2280
  # By declaring this variable as a pointer to an integer we can then access
2281
  # its value correctly, should the mlockall call fail, in order to see what
2282
  # the actual error code was.
2283
  # pylint: disable-msg=W0212
2284
  libc.__errno_location.restype = ctypes.POINTER(ctypes.c_int)
2285

    
2286
  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
2287
    # pylint: disable-msg=W0212
2288
    logging.error("Cannot set memory lock: %s",
2289
                  os.strerror(libc.__errno_location().contents.value))
2290
    return
2291

    
2292
  logging.debug("Memory lock set")
2293

    
2294

    
2295
def Daemonize(logfile, run_uid, run_gid):
2296
  """Daemonize the current process.
2297

2298
  This detaches the current process from the controlling terminal and
2299
  runs it in the background as a daemon.
2300

2301
  @type logfile: str
2302
  @param logfile: the logfile to which we should redirect stdout/stderr
2303
  @type run_uid: int
2304
  @param run_uid: Run the child under this uid
2305
  @type run_gid: int
2306
  @param run_gid: Run the child under this gid
2307
  @rtype: int
2308
  @return: the value zero
2309

2310
  """
2311
  # pylint: disable-msg=W0212
2312
  # yes, we really want os._exit
2313
  UMASK = 077
2314
  WORKDIR = "/"
2315

    
2316
  # this might fail
2317
  pid = os.fork()
2318
  if (pid == 0):  # The first child.
2319
    os.setsid()
2320
    # FIXME: When removing again and moving to start-stop-daemon privilege drop
2321
    #        make sure to check for config permission and bail out when invoked
2322
    #        with wrong user.
2323
    os.setgid(run_gid)
2324
    os.setuid(run_uid)
2325
    # this might fail
2326
    pid = os.fork() # Fork a second child.
2327
    if (pid == 0):  # The second child.
2328
      os.chdir(WORKDIR)
2329
      os.umask(UMASK)
2330
    else:
2331
      # exit() or _exit()?  See below.
2332
      os._exit(0) # Exit parent (the first child) of the second child.
2333
  else:
2334
    os._exit(0) # Exit parent of the first child.
2335

    
2336
  for fd in range(3):
2337
    _CloseFDNoErr(fd)
2338
  i = os.open("/dev/null", os.O_RDONLY) # stdin
2339
  assert i == 0, "Can't close/reopen stdin"
2340
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
2341
  assert i == 1, "Can't close/reopen stdout"
2342
  # Duplicate standard output to standard error.
2343
  os.dup2(1, 2)
2344
  return 0
2345

    
2346

    
2347
def DaemonPidFileName(name):
  """Compute a ganeti pid file absolute path.
2349

2350
  @type name: str
2351
  @param name: the daemon name
2352
  @rtype: str
2353
  @return: the full path to the pidfile corresponding to the given
2354
      daemon name
2355

2356
  """
2357
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
2358

    
2359

    
2360
def EnsureDaemon(name):
2361
  """Check for and start daemon if not alive.
2362

2363
  """
2364
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
2365
  if result.failed:
2366
    logging.error("Can't start daemon '%s', failure %s, output: %s",
2367
                  name, result.fail_reason, result.output)
2368
    return False
2369

    
2370
  return True
2371

    
2372

    
2373
def StopDaemon(name):
2374
  """Stop daemon
2375

2376
  """
2377
  result = RunCmd([constants.DAEMON_UTIL, "stop", name])
2378
  if result.failed:
2379
    logging.error("Can't stop daemon '%s', failure %s, output: %s",
2380
                  name, result.fail_reason, result.output)
2381
    return False
2382

    
2383
  return True
2384

    
2385

    
2386
def WritePidFile(name):
2387
  """Write the current process pidfile.
2388

2389
  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}
2390

2391
  @type name: str
2392
  @param name: the daemon name to use
2393
  @raise errors.GenericError: if the pid file already exists and
2394
      points to a live process
2395

2396
  """
2397
  pid = os.getpid()
2398
  pidfilename = DaemonPidFileName(name)
2399
  if IsProcessAlive(ReadPidFile(pidfilename)):
2400
    raise errors.GenericError("%s contains a live process" % pidfilename)
2401

    
2402
  WriteFile(pidfilename, data="%d\n" % pid)
2403

    
2404

    
2405
def RemovePidFile(name):
2406
  """Remove the current process pidfile.
2407

2408
  Any errors are ignored.
2409

2410
  @type name: str
2411
  @param name: the daemon name used to derive the pidfile name
2412

2413
  """
2414
  pidfilename = DaemonPidFileName(name)
2415
  # TODO: we could check here that the file contains our pid
2416
  try:
2417
    RemoveFile(pidfilename)
2418
  except: # pylint: disable-msg=W0702
2419
    pass
2420

    
2421

    
2422
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
2423
                waitpid=False):
2424
  """Kill a process given by its pid.
2425

2426
  @type pid: int
2427
  @param pid: The PID to terminate.
2428
  @type signal_: int
2429
  @param signal_: The signal to send, by default SIGTERM
2430
  @type timeout: int
2431
  @param timeout: The timeout after which, if the process is still alive,
2432
                  a SIGKILL will be sent. If not positive, no such checking
2433
                  will be done
2434
  @type waitpid: boolean
2435
  @param waitpid: If true, we should waitpid on this process after
2436
      sending signals, since it's our own child and otherwise it
2437
      would remain as zombie
2438

2439
  """
2440
  def _helper(pid, signal_, wait):
2441
    """Simple helper to encapsulate the kill/waitpid sequence"""
2442
    if IgnoreProcessNotFound(os.kill, pid, signal_) and wait:
2443
      try:
2444
        os.waitpid(pid, os.WNOHANG)
2445
      except OSError:
2446
        pass
2447

    
2448
  if pid <= 0:
2449
    # kill with pid=0 == suicide
2450
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
2451

    
2452
  if not IsProcessAlive(pid):
2453
    return
2454

    
2455
  _helper(pid, signal_, waitpid)
2456

    
2457
  if timeout <= 0:
2458
    return
2459

    
2460
  def _CheckProcess():
2461
    if not IsProcessAlive(pid):
2462
      return
2463

    
2464
    try:
2465
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
2466
    except OSError:
2467
      raise RetryAgain()
2468

    
2469
    if result_pid > 0:
2470
      return
2471

    
2472
    raise RetryAgain()
2473

    
2474
  try:
2475
    # Wait up to $timeout seconds
2476
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
2477
  except RetryTimeout:
2478
    pass
2479

    
2480
  if IsProcessAlive(pid):
2481
    # Kill process if it's still alive
2482
    _helper(pid, signal.SIGKILL, waitpid)
2483

    
2484

    
2485
def FindFile(name, search_path, test=os.path.exists):
2486
  """Look for a filesystem object in a given path.
2487

2488
  This is an abstract method to search for filesystem object (files,
2489
  dirs) under a given search path.
2490

2491
  @type name: str
2492
  @param name: the name to look for
2493
  @type search_path: str
2494
  @param search_path: location to start at
2495
  @type test: callable
2496
  @param test: a function taking one argument that should return True
2497
      if a given object is valid; the default value is
2498
      os.path.exists, causing only existing files to be returned
2499
  @rtype: str or None
2500
  @return: full path to the object if found, None otherwise
2501

2502
  """
2503
  # validate the filename mask
2504
  if constants.EXT_PLUGIN_MASK.match(name) is None:
2505
    logging.critical("Invalid value passed for external script name: '%s'",
2506
                     name)
2507
    return None
2508

    
2509
  for dir_name in search_path:
2510
    # FIXME: investigate switch to PathJoin
2511
    item_name = os.path.sep.join([dir_name, name])
2512
    # check the user test and that we're indeed resolving to the given
2513
    # basename
2514
    if test(item_name) and os.path.basename(item_name) == name:
2515
      return item_name
2516
  return None
2517

    
2518

    
2519
def CheckVolumeGroupSize(vglist, vgname, minsize):
2520
  """Checks if the volume group list is valid.
2521

2522
  The function will check if a given volume group is in the list of
2523
  volume groups and has a minimum size.
2524

2525
  @type vglist: dict
2526
  @param vglist: dictionary of volume group names and their size
2527
  @type vgname: str
2528
  @param vgname: the volume group we should check
2529
  @type minsize: int
2530
  @param minsize: the minimum size we accept
2531
  @rtype: None or str
2532
  @return: None for success, otherwise the error message
2533

2534
  """
2535
  vgsize = vglist.get(vgname, None)
2536
  if vgsize is None:
2537
    return "volume group '%s' missing" % vgname
2538
  elif vgsize < minsize:
2539
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
2540
            (vgname, minsize, vgsize))
2541
  return None
2542

    
2543

    
2544
def SplitTime(value):
2545
  """Splits time as floating point number into a tuple.
2546

2547
  @param value: Time in seconds
2548
  @type value: int or float
2549
  @return: Tuple containing (seconds, microseconds)
2550

2551
  """
2552
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
2553

    
2554
  assert 0 <= seconds, \
2555
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2556
  assert 0 <= microseconds <= 999999, \
2557
    "Microseconds must be 0-999999, but are %s" % microseconds
2558

    
2559
  return (int(seconds), int(microseconds))
2560

    
2561

    
2562
def MergeTime(timetuple):
2563
  """Merges a tuple into time as a floating point number.
2564

2565
  @param timetuple: Time as tuple, (seconds, microseconds)
2566
  @type timetuple: tuple
2567
  @return: Time as a floating point number expressed in seconds
2568

2569
  """
2570
  (seconds, microseconds) = timetuple
2571

    
2572
  assert 0 <= seconds, \
2573
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2574
  assert 0 <= microseconds <= 999999, \
2575
    "Microseconds must be 0-999999, but are %s" % microseconds
2576

    
2577
  return float(seconds) + (float(microseconds) * 0.000001)
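# Illustrative round-trip for SplitTime/MergeTime (added comment, not part
# of the original module): SplitTime(1234.5) returns (1234, 500000) and
# MergeTime((1234, 500000)) returns 1234.5 again.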


def GetDaemonPort(daemon_name):
2581
  """Get the daemon port for this cluster.
2582

2583
  Note that this routine does not read a ganeti-specific file, but
2584
  instead uses C{socket.getservbyname} to allow pre-customization of
2585
  this parameter outside of Ganeti.
2586

2587
  @type daemon_name: string
2588
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
2589
  @rtype: int
2590

2591
  """
2592
  if daemon_name not in constants.DAEMONS_PORTS:
2593
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)
2594

    
2595
  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
2596
  try:
2597
    port = socket.getservbyname(daemon_name, proto)
2598
  except socket.error:
2599
    port = default_port
2600

    
2601
  return port
2602

    
2603

    
2604
class LogFileHandler(logging.FileHandler):
  """Log handler that doesn't fall back to stderr.

  When an error occurs while writing to the logfile, logging.FileHandler
  tries to log to stderr. This doesn't work in ganeti since stderr is
  redirected to the logfile. This class avoids such failures by reporting
  errors to /dev/console instead.
2610

2611
  """
2612
  def __init__(self, filename, mode="a", encoding=None):
2613
    """Open the specified file and use it as the stream for logging.
2614

2615
    Also open /dev/console to report errors while logging.
2616

2617
    """
2618
    logging.FileHandler.__init__(self, filename, mode, encoding)
2619
    self.console = open(constants.DEV_CONSOLE, "a")
2620

    
2621
  def handleError(self, record): # pylint: disable-msg=C0103
2622
    """Handle errors which occur during an emit() call.
2623

2624
    Try to handle errors with FileHandler method, if it fails write to
2625
    /dev/console.
2626

2627
    """
2628
    try:
2629
      logging.FileHandler.handleError(self, record)
2630
    except Exception: # pylint: disable-msg=W0703
2631
      try:
2632
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
2633
      except Exception: # pylint: disable-msg=W0703
2634
        # Log handler tried everything it could, now just give up
2635
        pass
2636

    
2637

    
2638
def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
2639
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
2640
                 console_logging=False):
2641
  """Configures the logging module.
2642

2643
  @type logfile: str
2644
  @param logfile: the filename to which we should log
2645
  @type debug: integer
2646
  @param debug: if greater than zero, enable debug messages, otherwise
2647
      only those at C{INFO} and above level
2648
  @type stderr_logging: boolean
2649
  @param stderr_logging: whether we should also log to the standard error
2650
  @type program: str
2651
  @param program: the name under which we should log messages
2652
  @type multithreaded: boolean
2653
  @param multithreaded: if True, will add the thread name to the log file
2654
  @type syslog: string
2655
  @param syslog: one of 'no', 'yes', 'only':
2656
      - if no, syslog is not used
2657
      - if yes, syslog is used (in addition to file-logging)
2658
      - if only, only syslog is used
2659
  @type console_logging: boolean
2660
  @param console_logging: if True, will use a FileHandler which falls back to
2661
      the system console if logging fails
2662
  @raise EnvironmentError: if we can't open the log file and
2663
      syslog/stderr logging is disabled
2664

2665
  """
2666
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
2667
  sft = program + "[%(process)d]:"
2668
  if multithreaded:
2669
    fmt += "/%(threadName)s"
2670
    sft += " (%(threadName)s)"
2671
  if debug:
2672
    fmt += " %(module)s:%(lineno)s"
2673
    # no debug info for syslog loggers
2674
  fmt += " %(levelname)s %(message)s"
2675
  # yes, we do want the textual level, as remote syslog will probably
2676
  # lose the error level, and it's easier to grep for it
2677
  sft += " %(levelname)s %(message)s"
2678
  formatter = logging.Formatter(fmt)
2679
  sys_fmt = logging.Formatter(sft)
2680

    
2681
  root_logger = logging.getLogger("")
2682
  root_logger.setLevel(logging.NOTSET)
2683

    
2684
  # Remove all previously setup handlers
2685
  for handler in root_logger.handlers:
2686
    handler.close()
2687
    root_logger.removeHandler(handler)
2688

    
2689
  if stderr_logging:
2690
    stderr_handler = logging.StreamHandler()
2691
    stderr_handler.setFormatter(formatter)
2692
    if debug:
2693
      stderr_handler.setLevel(logging.NOTSET)
2694
    else:
2695
      stderr_handler.setLevel(logging.CRITICAL)
2696
    root_logger.addHandler(stderr_handler)
2697

    
2698
  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
2699
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
2700
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
2701
                                                    facility)
2702
    syslog_handler.setFormatter(sys_fmt)
2703
    # Never enable debug over syslog
2704
    syslog_handler.setLevel(logging.INFO)
2705
    root_logger.addHandler(syslog_handler)
2706

    
2707
  if syslog != constants.SYSLOG_ONLY:
2708
    # this can fail, if the logging directories are not setup or we have
2709
    # a permission problem; in this case, it's best to log but ignore
2710
    # the error if stderr_logging is True, and if false we re-raise the
2711
    # exception since otherwise we could run but without any logs at all
2712
    try:
2713
      if console_logging:
2714
        logfile_handler = LogFileHandler(logfile)
2715
      else:
2716
        logfile_handler = logging.FileHandler(logfile)
2717
      logfile_handler.setFormatter(formatter)
2718
      if debug:
2719
        logfile_handler.setLevel(logging.DEBUG)
2720
      else:
2721
        logfile_handler.setLevel(logging.INFO)
2722
      root_logger.addHandler(logfile_handler)
2723
    except EnvironmentError:
2724
      if stderr_logging or syslog == constants.SYSLOG_YES:
2725
        logging.exception("Failed to enable logging to file '%s'", logfile)
2726
      else:
2727
        # we need to re-raise the exception
2728
        raise
2729

    
2730

    
2731
def IsNormAbsPath(path):
2732
  """Check whether a path is absolute and also normalized
2733

2734
  This avoids things like /dir/../../other/path to be valid.
2735

2736
  """
2737
  return os.path.normpath(path) == path and os.path.isabs(path)
2738

    
2739

    
2740
def PathJoin(*args):
2741
  """Safe-join a list of path components.
2742

2743
  Requirements:
2744
      - the first argument must be an absolute path
2745
      - no component in the path must have backtracking (e.g. /../),
2746
        since we check for normalization at the end
2747

2748
  @param args: the path components to be joined
2749
  @raise ValueError: for invalid paths
2750

2751
  """
2752
  # ensure we're having at least one path passed in
2753
  assert args
2754
  # ensure the first component is an absolute and normalized path name
2755
  root = args[0]
2756
  if not IsNormAbsPath(root):
2757
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
2758
  result = os.path.join(*args)
2759
  # ensure that the whole path is normalized
2760
  if not IsNormAbsPath(result):
2761
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
2762
  # check that we're still under the original prefix
2763
  prefix = os.path.commonprefix([root, result])
2764
  if prefix != root:
2765
    raise ValueError("Error: path joining resulted in different prefix"
2766
                     " (%s != %s)" % (prefix, root))
2767
  return result
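# Illustrative usage of PathJoin (added comment, not part of the original
# module): PathJoin("/var/tmp", "a", "b") returns "/var/tmp/a/b", while
# PathJoin("/var/tmp", "../etc") raises ValueError because the joined result
# is no longer a normalized path under the original prefix.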
2768

    
2769

    
2770
def TailFile(fname, lines=20):
2771
  """Return the last lines from a file.
2772

2773
  @note: this function will only read and parse the last 4KB of
2774
      the file; if the lines are very long, it could be that less
2775
      than the requested number of lines are returned
2776

2777
  @param fname: the file name
2778
  @type lines: int
2779
  @param lines: the (maximum) number of lines to return
2780

2781
  """
2782
  fd = open(fname, "r")
2783
  try:
2784
    fd.seek(0, 2)
2785
    pos = fd.tell()
2786
    pos = max(0, pos-4096)
2787
    fd.seek(pos, 0)
2788
    raw_data = fd.read()
2789
  finally:
2790
    fd.close()
2791

    
2792
  rows = raw_data.splitlines()
2793
  return rows[-lines:]
2794

    
2795

    
2796
def FormatTimestampWithTZ(secs):
2797
  """Formats a Unix timestamp with the local timezone.
2798

2799
  """
2800
  return time.strftime("%F %T %Z", time.gmtime(secs))
2801

    
2802

    
2803
def _ParseAsn1Generalizedtime(value):
2804
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
2805

2806
  @type value: string
2807
  @param value: ASN1 GENERALIZEDTIME timestamp
2808

2809
  """
2810
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
2811
  if m:
2812
    # We have an offset
2813
    asn1time = m.group(1)
2814
    hours = int(m.group(2))
2815
    minutes = int(m.group(3))
2816
    utcoffset = (60 * hours) + minutes
2817
  else:
2818
    if not value.endswith("Z"):
2819
      raise ValueError("Missing timezone")
2820
    asn1time = value[:-1]
2821
    utcoffset = 0
2822

    
2823
  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
2824

    
2825
  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)
2826

    
2827
  return calendar.timegm(tt.utctimetuple())
2828

    
2829

    
2830
def GetX509CertValidity(cert):
2831
  """Returns the validity period of the certificate.
2832

2833
  @type cert: OpenSSL.crypto.X509
2834
  @param cert: X509 certificate object
2835

2836
  """
2837
  # The get_notBefore and get_notAfter functions are only supported in
2838
  # pyOpenSSL 0.7 and above.
2839
  try:
2840
    get_notbefore_fn = cert.get_notBefore
2841
  except AttributeError:
2842
    not_before = None
2843
  else:
2844
    not_before_asn1 = get_notbefore_fn()
2845

    
2846
    if not_before_asn1 is None:
2847
      not_before = None
2848
    else:
2849
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)
2850

    
2851
  try:
2852
    get_notafter_fn = cert.get_notAfter
2853
  except AttributeError:
2854
    not_after = None
2855
  else:
2856
    not_after_asn1 = get_notafter_fn()
2857

    
2858
    if not_after_asn1 is None:
2859
      not_after = None
2860
    else:
2861
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)
2862

    
2863
  return (not_before, not_after)
2864

    
2865

    
2866
def _VerifyCertificateInner(expired, not_before, not_after, now,
2867
                            warn_days, error_days):
2868
  """Verifies certificate validity.
2869

2870
  @type expired: bool
2871
  @param expired: Whether pyOpenSSL considers the certificate as expired
2872
  @type not_before: number or None
2873
  @param not_before: Unix timestamp before which certificate is not valid
2874
  @type not_after: number or None
2875
  @param not_after: Unix timestamp after which certificate is invalid
2876
  @type now: number
2877
  @param now: Current time as Unix timestamp
2878
  @type warn_days: number or None
2879
  @param warn_days: How many days before expiration a warning should be reported
2880
  @type error_days: number or None
2881
  @param error_days: How many days before expiration an error should be reported
2882

2883
  """
2884
  if expired:
2885
    msg = "Certificate is expired"
2886

    
2887
    if not_before is not None and not_after is not None:
2888
      msg += (" (valid from %s to %s)" %
2889
              (FormatTimestampWithTZ(not_before),
2890
               FormatTimestampWithTZ(not_after)))
2891
    elif not_before is not None:
2892
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
2893
    elif not_after is not None:
2894
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)
2895

    
2896
    return (CERT_ERROR, msg)
2897

    
2898
  elif not_before is not None and not_before > now:
2899
    return (CERT_WARNING,
2900
            "Certificate not yet valid (valid from %s)" %
2901
            FormatTimestampWithTZ(not_before))
2902

    
2903
  elif not_after is not None:
2904
    remaining_days = int((not_after - now) / (24 * 3600))
2905

    
2906
    msg = "Certificate expires in about %d days" % remaining_days
2907

    
2908
    if error_days is not None and remaining_days <= error_days:
2909
      return (CERT_ERROR, msg)
2910

    
2911
    if warn_days is not None and remaining_days <= warn_days:
2912
      return (CERT_WARNING, msg)
2913

    
2914
  return (None, None)
2915

    
2916

    
2917
def VerifyX509Certificate(cert, warn_days, error_days):
2918
  """Verifies a certificate for LUVerifyCluster.
2919

2920
  @type cert: OpenSSL.crypto.X509
2921
  @param cert: X509 certificate object
2922
  @type warn_days: number or None
2923
  @param warn_days: How many days before expiration a warning should be reported
2924
  @type error_days: number or None
2925
  @param error_days: How many days before expiration an error should be reported
2926

2927
  """
2928
  # Depending on the pyOpenSSL version, this can just return (None, None)
2929
  (not_before, not_after) = GetX509CertValidity(cert)
2930

    
2931
  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
2932
                                 time.time(), warn_days, error_days)
2933

    
2934

    
2935
def SignX509Certificate(cert, key, salt):
2936
  """Sign a X509 certificate.
2937

2938
  An RFC822-like signature header is added in front of the certificate.
2939

2940
  @type cert: OpenSSL.crypto.X509
2941
  @param cert: X509 certificate object
2942
  @type key: string
2943
  @param key: Key for HMAC
2944
  @type salt: string
2945
  @param salt: Salt for HMAC
2946
  @rtype: string
2947
  @return: Serialized and signed certificate in PEM format
2948

2949
  """
2950
  if not VALID_X509_SIGNATURE_SALT.match(salt):
2951
    raise errors.GenericError("Invalid salt: %r" % salt)
2952

    
2953
  # Dumping as PEM here ensures the certificate is in a sane format
2954
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2955

    
2956
  return ("%s: %s/%s\n\n%s" %
2957
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
2958
           Sha1Hmac(key, cert_pem, salt=salt),
2959
           cert_pem))
2960

    
2961

    
2962
def _ExtractX509CertificateSignature(cert_pem):
2963
  """Helper function to extract signature from X509 certificate.
2964

2965
  """
2966
  # Extract signature from original PEM data
2967
  for line in cert_pem.splitlines():
2968
    if line.startswith("---"):
2969
      break
2970

    
2971
    m = X509_SIGNATURE.match(line.strip())
2972
    if m:
2973
      return (m.group("salt"), m.group("sign"))
2974

    
2975
  raise errors.GenericError("X509 certificate signature is missing")
2976

    
2977

    
2978
def LoadSignedX509Certificate(cert_pem, key):
2979
  """Verifies a signed X509 certificate.
2980

2981
  @type cert_pem: string
2982
  @param cert_pem: Certificate in PEM format and with signature header
2983
  @type key: string
2984
  @param key: Key for HMAC
2985
  @rtype: tuple; (OpenSSL.crypto.X509, string)
2986
  @return: X509 certificate object and salt
2987

2988
  """
2989
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)
2990

    
2991
  # Load certificate
2992
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
2993

    
2994
  # Dump again to ensure it's in a sane format
2995
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2996

    
2997
  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
2998
    raise errors.GenericError("X509 certificate signature is invalid")
2999

    
3000
  return (cert, salt)
3001

    
3002

    
3003
def Sha1Hmac(key, text, salt=None):
3004
  """Calculates the HMAC-SHA1 digest of a text.
3005

3006
  HMAC is defined in RFC2104.
3007

3008
  @type key: string
3009
  @param key: Secret key
3010
  @type text: string
3011

3012
  """
3013
  if salt:
3014
    salted_text = salt + text
3015
  else:
3016
    salted_text = text
3017

    
3018
  return hmac.new(key, salted_text, compat.sha1).hexdigest()
3019

    
3020

    
3021
def VerifySha1Hmac(key, text, digest, salt=None):
3022
  """Verifies the HMAC-SHA1 digest of a text.
3023

3024
  HMAC is defined in RFC2104.
3025

3026
  @type key: string
3027
  @param key: Secret key
3028
  @type text: string
  @param text: Text whose digest is being verified
3029
  @type digest: string
3030
  @param digest: Expected digest
3031
  @rtype: bool
3032
  @return: Whether HMAC-SHA1 digest matches
3033

3034
  """
3035
  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()
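# Illustrative usage of Sha1Hmac/VerifySha1Hmac (added comment, not part of
# the original module; key and salt are hypothetical): with
# digest = Sha1Hmac("secret", "payload", salt="0abc"), the call
# VerifySha1Hmac("secret", "payload", digest, salt="0abc") returns True,
# while any change to the key, text or salt makes it return False.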
3036

    
3037

    
3038
def SafeEncode(text):
3039
  """Return a 'safe' version of a source string.
3040

3041
  This function mangles the input string and returns a version that
3042
  should be safe to display/encode as ASCII. To this end, we first
3043
  convert it to ASCII using the 'backslashreplace' encoding which
3044
  should get rid of any non-ASCII chars, and then we process it
3045
  through a loop copied from the string repr sources in the python; we
3046
  don't use string_escape anymore since that escape single quotes and
3047
  backslashes too, and that is too much; and that escaping is not
3048
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
3049

3050
  @type text: str or unicode
3051
  @param text: input data
3052
  @rtype: str
3053
  @return: a safe version of text
3054

3055
  """
3056
  if isinstance(text, unicode):
3057
    # only if unicode; if str already, we handle it below
3058
    text = text.encode('ascii', 'backslashreplace')
3059
  resu = ""
3060
  for char in text:
3061
    c = ord(char)
3062
    if char == '\t':
      resu += r'\t'
    elif char == '\n':
      resu += r'\n'
    elif char == '\r':
      resu += r'\r'
3068
    elif c < 32 or c >= 127: # non-printable
3069
      resu += "\\x%02x" % (c & 0xff)
3070
    else:
3071
      resu += char
3072
  return resu
3073

    
3074

    
3075
def UnescapeAndSplit(text, sep=","):
3076
  """Split and unescape a string based on a given separator.
3077

3078
  This function splits a string based on a separator where the
  separator itself can be escaped in order to be part of one of the
  elements. The escaping rules are (assuming comma is the
  separator):
    - a plain , separates the elements
    - a sequence \\\\, (double backslash plus comma) is handled as a
      backslash plus a separator comma
    - a sequence \, (backslash plus comma) is handled as a
      non-separator comma

  @type text: string
  @param text: the string to split
  @type sep: string
  @param sep: the separator
  @rtype: list
  @return: a list of strings
3094

3095
  """
3096
  # we split the list by sep (with no escaping at this stage)
3097
  slist = text.split(sep)
3098
  # next, we revisit the elements and if any of them ended with an odd
3099
  # number of backslashes, then we join it with the next
3100
  rlist = []
3101
  while slist:
3102
    e1 = slist.pop(0)
3103
    if e1.endswith("\\"):
3104
      num_b = len(e1) - len(e1.rstrip("\\"))
3105
      if num_b % 2 == 1:
3106
        e2 = slist.pop(0)
3107
        # here the backslashes remain (all), and will be reduced in
3108
        # the next step
3109
        rlist.append(e1 + sep + e2)
3110
        continue
3111
    rlist.append(e1)
3112
  # finally, replace backslash-something with something
3113
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
3114
  return rlist
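# Illustrative usage of UnescapeAndSplit (added comment, not part of the
# original module): UnescapeAndSplit("a,b\\,c,d"), i.e. with a backslash
# escaping the second comma, returns ["a", "b,c", "d"]; unescaped commas
# still separate the elements.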


def CommaJoin(names):
3118
  """Nicely join a set of identifiers.
3119

3120
  @param names: set, list or tuple
3121
  @return: a string with the formatted results
3122

3123
  """
3124
  return ", ".join([str(val) for val in names])
3125

    
3126

    
3127
def BytesToMebibyte(value):
3128
  """Converts bytes to mebibytes.
3129

3130
  @type value: int
3131
  @param value: Value in bytes
3132
  @rtype: int
3133
  @return: Value in mebibytes
3134

3135
  """
3136
  return int(round(value / (1024.0 * 1024.0), 0))
3137

    
3138

    
3139
def CalculateDirectorySize(path):
3140
  """Calculates the size of a directory recursively.
3141

3142
  @type path: string
3143
  @param path: Path to directory
3144
  @rtype: int
3145
  @return: Size in mebibytes
3146

3147
  """
3148
  size = 0
3149

    
3150
  for (curpath, _, files) in os.walk(path):
3151
    for filename in files:
3152
      st = os.lstat(PathJoin(curpath, filename))
3153
      size += st.st_size
3154

    
3155
  return BytesToMebibyte(size)
3156

    
3157

    
3158
def GetFilesystemStats(path):
3159
  """Returns the total and free space on a filesystem.
3160

3161
  @type path: string
3162
  @param path: Path on filesystem to be examined
3163
  @rtype: tuple
  @return: tuple of (Total space, Free space) in mebibytes
3165

3166
  """
3167
  st = os.statvfs(path)
3168

    
3169
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
3170
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
3171
  return (tsize, fsize)
3172

    
3173

    
3174
def RunInSeparateProcess(fn, *args):
3175
  """Runs a function in a separate process.
3176

3177
  Note: Only boolean return values are supported.
3178

3179
  @type fn: callable
3180
  @param fn: Function to be called
3181
  @rtype: bool
3182
  @return: Function's result
3183

3184
  """
3185
  pid = os.fork()
3186
  if pid == 0:
3187
    # Child process
3188
    try:
3189
      # In case the function uses temporary files
3190
      ResetTempfileModule()
3191

    
3192
      # Call function
3193
      result = int(bool(fn(*args)))
3194
      assert result in (0, 1)
3195
    except: # pylint: disable-msg=W0702
3196
      logging.exception("Error while calling function in separate process")
3197
      # 0 and 1 are reserved for the return value
3198
      result = 33
3199

    
3200
    os._exit(result) # pylint: disable-msg=W0212
3201

    
3202
  # Parent process
3203

    
3204
  # Avoid zombies and check exit code
3205
  (_, status) = os.waitpid(pid, 0)
3206

    
3207
  if os.WIFSIGNALED(status):
3208
    exitcode = None
3209
    signum = os.WTERMSIG(status)
3210
  else:
3211
    exitcode = os.WEXITSTATUS(status)
3212
    signum = None
3213

    
3214
  if not (exitcode in (0, 1) and signum is None):
3215
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
3216
                              (exitcode, signum))
3217

    
3218
  return bool(exitcode)
3219

    
3220

    
3221
def IgnoreProcessNotFound(fn, *args, **kwargs):
3222
  """Ignores ESRCH when calling a process-related function.
3223

3224
  ESRCH is raised when a process is not found.
3225

3226
  @rtype: bool
3227
  @return: Whether process was found
3228

3229
  """
3230
  try:
3231
    fn(*args, **kwargs)
3232
  except EnvironmentError, err:
3233
    # Ignore ESRCH
3234
    if err.errno == errno.ESRCH:
3235
      return False
3236
    raise
3237

    
3238
  return True
3239

    
3240

    
3241
def IgnoreSignals(fn, *args, **kwargs):
3242
  """Tries to call a function ignoring failures due to EINTR.
3243

3244
  """
3245
  try:
3246
    return fn(*args, **kwargs)
3247
  except EnvironmentError, err:
3248
    if err.errno == errno.EINTR:
3249
      return None
3250
    else:
3251
      raise
3252
  except (select.error, socket.error), err:
3253
    # In python 2.6 and above select.error is an IOError, so it's handled
3254
    # above, in 2.5 and below it's not, and it's handled here.
3255
    if err.args and err.args[0] == errno.EINTR:
3256
      return None
3257
    else:
3258
      raise
3259

    
3260

    
3261
def LockFile(fd):
3262
  """Locks a file using POSIX locks.
3263

3264
  @type fd: int
3265
  @param fd: the file descriptor we need to lock
3266

3267
  """
3268
  try:
3269
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
3270
  except IOError, err:
3271
    if err.errno == errno.EAGAIN:
3272
      raise errors.LockError("File already locked")
3273
    raise
3274

    
3275

    
3276
def FormatTime(val):
3277
  """Formats a time value.
3278

3279
  @type val: float or None
3280
  @param val: the timestamp as returned by time.time()
3281
  @return: a string value or N/A if we don't have a valid timestamp
3282

3283
  """
3284
  if val is None or not isinstance(val, (int, float)):
3285
    return "N/A"
3286
  # these two codes work on Linux, but they are not guaranteed on all
3287
  # platforms
3288
  return time.strftime("%F %T", time.localtime(val))
3289

    
3290

    
3291
def FormatSeconds(secs):
3292
  """Formats seconds for easier reading.
3293

3294
  @type secs: number
3295
  @param secs: Number of seconds
3296
  @rtype: string
3297
  @return: Formatted seconds (e.g. "2d 9h 19m 49s")
3298

3299
  """
3300
  parts = []
3301

    
3302
  secs = round(secs, 0)
3303

    
3304
  if secs > 0:
3305
    # Negative values would be a bit tricky
3306
    for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]:
3307
      (complete, secs) = divmod(secs, one)
3308
      if complete or parts:
3309
        parts.append("%d%s" % (complete, unit))
3310

    
3311
  parts.append("%ds" % secs)
3312

    
3313
  return " ".join(parts)
3314

    
3315

    
3316
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
3317
  """Reads the watcher pause file.
3318

3319
  @type filename: string
3320
  @param filename: Path to watcher pause file
3321
  @type now: None, float or int
3322
  @param now: Current time as Unix timestamp
3323
  @type remove_after: int
3324
  @param remove_after: Remove watcher pause file after specified amount of
3325
    seconds past the pause end time
3326

3327
  """
3328
  if now is None:
3329
    now = time.time()
3330

    
3331
  try:
3332
    value = ReadFile(filename)
3333
  except IOError, err:
3334
    if err.errno != errno.ENOENT:
3335
      raise
3336
    value = None
3337

    
3338
  if value is not None:
3339
    try:
3340
      value = int(value)
3341
    except ValueError:
3342
      logging.warning(("Watcher pause file (%s) contains invalid value,"
3343
                       " removing it"), filename)
3344
      RemoveFile(filename)
3345
      value = None
3346

    
3347
    if value is not None:
3348
      # Remove file if it's outdated
3349
      if now > (value + remove_after):
3350
        RemoveFile(filename)
3351
        value = None
3352

    
3353
      elif now > value:
3354
        value = None
3355

    
3356
  return value
3357

    
3358

    
3359
class RetryTimeout(Exception):
3360
  """Retry loop timed out.
3361

3362
  Any arguments which were passed by the retried function to RetryAgain will
  be preserved in RetryTimeout, if it is raised. If such an argument was an
  exception, the RaiseInner helper method will reraise it.
3365

3366
  """
3367
  def RaiseInner(self):
3368
    if self.args and isinstance(self.args[0], Exception):
3369
      raise self.args[0]
3370
    else:
3371
      raise RetryTimeout(*self.args)
3372

    
3373

    
3374
class RetryAgain(Exception):
3375
  """Retry again.
3376

3377
  Any arguments passed to RetryAgain will be preserved, if a timeout occurs,
  as arguments to RetryTimeout. If an exception is passed, the RaiseInner()
  method of the RetryTimeout() exception can be used to reraise it.
3380

3381
  """
3382

    
3383

    
3384
class _RetryDelayCalculator(object):
3385
  """Calculator for increasing delays.
3386

3387
  """
3388
  __slots__ = [
3389
    "_factor",
3390
    "_limit",
3391
    "_next",
3392
    "_start",
3393
    ]
3394

    
3395
  def __init__(self, start, factor, limit):
3396
    """Initializes this class.
3397

3398
    @type start: float
3399
    @param start: Initial delay
3400
    @type factor: float
3401
    @param factor: Factor for delay increase
3402
    @type limit: float or None
3403
    @param limit: Upper limit for delay or None for no limit
3404

3405
    """
3406
    assert start > 0.0
3407
    assert factor >= 1.0
3408
    assert limit is None or limit >= 0.0
3409

    
3410
    self._start = start
3411
    self._factor = factor
3412
    self._limit = limit
3413

    
3414
    self._next = start
3415

    
3416
  def __call__(self):
3417
    """Returns current delay and calculates the next one.
3418

3419
    """
3420
    current = self._next
3421

    
3422
    # Update for next run
3423
    if self._limit is None or self._next < self._limit:
3424
      self._next = min(self._limit, self._next * self._factor)
3425

    
3426
    return current
3427

    
3428

    
3429
#: Special delay to specify whole remaining timeout
RETRY_REMAINING_TIME = object()


def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
          _time_fn=time.time):
  """Call a function repeatedly until it succeeds.

  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
  total of C{timeout} seconds, this function throws L{RetryTimeout}.

  C{delay} can be one of the following:
    - callable returning the delay length as a float
    - Tuple of (start, factor, limit)
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
      useful when overriding L{wait_fn} to wait for an external event)
    - A static delay as a number (int or float)

  @type fn: callable
  @param fn: Function to be called
  @param delay: Either a callable (returning the delay), a tuple of (start,
                factor, limit) (see L{_RetryDelayCalculator}),
                L{RETRY_REMAINING_TIME} or a number (int or float)
  @type timeout: float
  @param timeout: Total timeout
  @type wait_fn: callable
  @param wait_fn: Waiting function
  @return: Return value of function

  """
  assert callable(fn)
  assert callable(wait_fn)
  assert callable(_time_fn)

  if args is None:
    args = []

  end_time = _time_fn() + timeout

  if callable(delay):
    # External function to calculate delay
    calc_delay = delay

  elif isinstance(delay, (tuple, list)):
    # Increasing delay with optional upper boundary
    (start, factor, limit) = delay
    calc_delay = _RetryDelayCalculator(start, factor, limit)

  elif delay is RETRY_REMAINING_TIME:
    # Always use the remaining time
    calc_delay = None

  else:
    # Static delay
    calc_delay = lambda: delay

  assert calc_delay is None or callable(calc_delay)

  while True:
    retry_args = []
    try:
      # pylint: disable-msg=W0142
      return fn(*args)
    except RetryAgain, err:
      retry_args = err.args
    except RetryTimeout:
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
                                   " handle RetryTimeout")

    remaining_time = end_time - _time_fn()

    if remaining_time < 0.0:
      # pylint: disable-msg=W0142
      raise RetryTimeout(*retry_args)

    assert remaining_time >= 0.0

    if calc_delay is None:
      wait_fn(remaining_time)
    else:
      current_delay = calc_delay()
      if current_delay > 0.0:
        wait_fn(current_delay)
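# Example usage (illustrative sketch; _IsDaemonAnswering is a hypothetical
# predicate, not part of this module):
#
#   def _CheckDaemon(name):
#     if not _IsDaemonAnswering(name):
#       raise RetryAgain()
#     return name
#
#   try:
#     # Retry for up to 30 seconds, starting with a 0.5s delay growing by a
#     # factor of 1.5 per attempt and capped at 5 seconds
#     Retry(_CheckDaemon, (0.5, 1.5, 5.0), 30.0, args=["noded"])
#   except RetryTimeout, err:
#     # Re-raise the exception given to RetryAgain, if one was passed
#     err.RaiseInner()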
def GetClosedTempfile(*args, **kwargs):
  """Creates a temporary file and returns its path.

  """
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
  _CloseFDNoErr(fd)
  return path


def GenerateSelfSignedX509Cert(common_name, validity):
  """Generates a self-signed X509 certificate.

  @type common_name: string
  @param common_name: commonName value
  @type validity: int
  @param validity: Validity for certificate in seconds

  """
  # Create private and public key
  key = OpenSSL.crypto.PKey()
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)

  # Create self-signed certificate
  cert = OpenSSL.crypto.X509()
  if common_name:
    cert.get_subject().CN = common_name
  cert.set_serial_number(1)
  cert.gmtime_adj_notBefore(0)
  cert.gmtime_adj_notAfter(validity)
  cert.set_issuer(cert.get_subject())
  cert.set_pubkey(key)
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)

  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  return (key_pem, cert_pem)


def GenerateSelfSignedSslCert(filename, validity=(5 * 365)):
  """Legacy function to generate self-signed X509 certificate.

  """
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(None,
                                                   validity * 24 * 60 * 60)

  WriteFile(filename, mode=0400, data=key_pem + cert_pem)
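# Example usage (illustrative sketch; the common name and file path are made
# up for the example):
#
#   # Key and certificate as PEM strings, valid for one day
#   (key_pem, cert_pem) = GenerateSelfSignedX509Cert("node1.example.com",
#                                                    24 * 60 * 60)
#
#   # Legacy helper: write key and certificate (default ~5 years validity)
#   # into a single, owner-readable file
#   GenerateSelfSignedSslCert("/tmp/example.pem")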
class FileLock(object):
  """Utility class for file locks.

  """
  def __init__(self, fd, filename):
    """Constructor for FileLock.

    @type fd: file
    @param fd: File object
    @type filename: str
    @param filename: Path of the file opened at I{fd}

    """
    self.fd = fd
    self.filename = filename

  @classmethod
  def Open(cls, filename):
    """Creates and opens a file to be used as a file-based lock.

    @type filename: string
    @param filename: path to the file to be locked

    """
    # Using "os.open" is necessary to allow both opening existing file
    # read/write and creating if not existing. Vanilla "open" will truncate an
    # existing file -or- allow creating if not existing.
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
               filename)

  def __del__(self):
    self.Close()

  def Close(self):
    """Close the file and release the lock.

    """
    if hasattr(self, "fd") and self.fd:
      self.fd.close()
      self.fd = None

  def _flock(self, flag, blocking, timeout, errmsg):
    """Wrapper for fcntl.flock.

    @type flag: int
    @param flag: operation flag
    @type blocking: bool
    @param blocking: whether the operation should be done in blocking mode.
    @type timeout: None or float
    @param timeout: for how long the operation should be retried (implies
                    non-blocking mode).
    @type errmsg: string
    @param errmsg: error message in case operation fails.

    """
    assert self.fd, "Lock was closed"
    assert timeout is None or timeout >= 0, \
      "If specified, timeout must be positive"
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)
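# Example usage (illustrative sketch; the lock file path is made up):
#
#   lock = FileLock.Open("/var/lock/ganeti-example.lock")
#   try:
#     # Wait up to 10 seconds for the exclusive lock
#     lock.Exclusive(blocking=True, timeout=10)
#     try:
#       pass  # ... critical section ...
#     finally:
#       lock.Unlock()
#   finally:
#     lock.Close()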
class LineSplitter:
  """Splits data chunks into lines separated by newline.

  Instances provide a file-like interface.

  """
  def __init__(self, line_fn, *args):
    """Initializes this class.

    @type line_fn: callable
    @param line_fn: Function called for each line, first parameter is line
    @param args: Extra arguments for L{line_fn}

    """
    assert callable(line_fn)

    if args:
      # Python 2.4 doesn't have functools.partial yet
      self._line_fn = \
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
    else:
      self._line_fn = line_fn

    self._lines = collections.deque()
    self._buffer = ""

  def write(self, data):
    parts = (self._buffer + data).split("\n")
    self._buffer = parts.pop()
    self._lines.extend(parts)

  def flush(self):
    while self._lines:
      self._line_fn(self._lines.popleft().rstrip("\r\n"))

  def close(self):
    self.flush()
    if self._buffer:
      self._line_fn(self._buffer)
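# Example usage (illustrative sketch):
#
#   collected = []
#   splitter = LineSplitter(collected.append)
#   splitter.write("first line\nsecond ")
#   splitter.write("line\n")
#   splitter.close()
#   # collected is now ["first line", "second line"]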
def SignalHandled(signums):
  """Signal Handled decoration.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      assert 'signal_handlers' not in kwargs or \
             kwargs['signal_handlers'] is None or \
             isinstance(kwargs['signal_handlers'], dict), \
             "Wrong signal_handlers parameter in original function call"
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
        signal_handlers = kwargs['signal_handlers']
      else:
        signal_handlers = {}
        kwargs['signal_handlers'] = signal_handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          signal_handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        sighandler.Reset()
    return sig_function
  return wrap
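# Example usage (illustrative sketch; _DoOneIteration is hypothetical):
#
#   @SignalHandled([signal.SIGTERM])
#   def _MainLoop(signal_handlers=None):
#     handler = signal_handlers[signal.SIGTERM]
#     while not handler.called:
#       _DoOneIteration()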
class SignalWakeupFd(object):
  try:
    # This is only supported in Python 2.6 and above (some distributions
    # backported it to earlier versions)
    _set_wakeup_fd_fn = signal.set_wakeup_fd
  except AttributeError:
    # Not supported
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
      return -1
  else:
    def _SetWakeupFd(self, fd):
      return self._set_wakeup_fd_fn(fd)

  def __init__(self):
    """Initializes this class.

    """
    (read_fd, write_fd) = os.pipe()

    # Once these succeeded, the file descriptors will be closed automatically.
    # Buffer size 0 is important, otherwise .read() with a specified length
    # might buffer data and the file descriptors won't be marked readable.
    self._read_fh = os.fdopen(read_fd, "r", 0)
    self._write_fh = os.fdopen(write_fd, "w", 0)

    self._previous = self._SetWakeupFd(self._write_fh.fileno())

    # Utility functions
    self.fileno = self._read_fh.fileno
    self.read = self._read_fh.read

  def Reset(self):
    """Restores the previous wakeup file descriptor.

    """
    if hasattr(self, "_previous") and self._previous is not None:
      self._SetWakeupFd(self._previous)
      self._previous = None

  def Notify(self):
    """Notifies the wakeup file descriptor.

    """
    self._write_fh.write("\0")

  def __del__(self):
    """Called before object deletion.

    """
    self.Reset()
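# Example usage (illustrative sketch, combining this class with the
# SignalHandler class defined below):
#
#   wakeup = SignalWakeupFd()
#   handler = SignalHandler([signal.SIGCHLD], wakeup=wakeup)
#   try:
#     # select() returns once a byte is written to the pipe by the signal
#     # handler, even if the signal arrived just before entering select()
#     select.select([wakeup], [], [], 60.0)
#     if handler.called:
#       pass  # ... reap child processes ...
#   finally:
#     handler.Reset()
#     wakeup.Reset()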
class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when the object is deleted
  or when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: list
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None, wakeup=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function
    @type wakeup: L{SignalWakeupFd} or None
    @param wakeup: Object notified (via its C{Notify} method) when a signal
        arrives

    """
    assert handler_fn is None or callable(handler_fn)

    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn
    self._wakeup = wakeup

    self._previous = {}
    try:
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._wakeup:
      # Notify whoever is interested in signals
      self._wakeup.Notify()

    if self._handler_fn:
      self._handler_fn(signum, frame)
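# Example usage (illustrative sketch; _DoWork is hypothetical):
#
#   handler = SignalHandler([signal.SIGINT, signal.SIGTERM])
#   try:
#     while not handler.called:
#       _DoWork()
#   finally:
#     handler.Reset()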
class FieldSet(object):
  """A simple field set.

  Among the features are:
    - checking if a string is among a list of static string or regex objects
    - checking if a whole list of strings matches
    - returning the matching groups from a regex match

  Internally, all fields are held as regular expression objects.

  """
  def __init__(self, *items):
    self.items = [re.compile("^%s$" % value) for value in items]

  def Extend(self, other_set):
    """Extend the field set with the items from another one"""
    self.items.extend(other_set.items)

  def Matches(self, field):
    """Checks if a field matches the current set

    @type field: str
    @param field: the string to match
    @return: either None or a regular expression match object

    """
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
      return m
    return None

  def NonMatching(self, items):
    """Returns the list of fields not matching the current set

    @type items: list
    @param items: the list of fields to check
    @rtype: list
    @return: list of non-matching fields

    """
    return [val for val in items if not self.Matches(val)]
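# Example usage (illustrative sketch; the field names are made up):
#
#   fields = FieldSet("name", "pcnt", r"disk\.size/([0-9]+)")
#   fields.Matches("name")                # -> match object
#   fields.Matches("disk.size/2")         # -> match object, group(1) == "2"
#   fields.NonMatching(["name", "foo"])   # -> ["foo"]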