#
#

# Copyright (C) 2006, 2007 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.


"""Ganeti utility module.

This module holds functions that can be used in both daemons (all) and
the command line scripts.

"""


import os
import sys
import time
import subprocess
import re
import socket
import tempfile
import shutil
import errno
import pwd
import itertools
import select
import fcntl
import resource
import logging
import logging.handlers
import signal
import OpenSSL
import datetime
import calendar
import hmac
import collections
import struct
import IN

from cStringIO import StringIO

try:
  # pylint: disable-msg=F0401
  import ctypes
except ImportError:
  ctypes = None

from ganeti import errors
from ganeti import constants
from ganeti import compat


_locksheld = []
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')

debug_locks = False

#: when set to True, L{RunCmd} is disabled
no_fork = False

_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"

HEX_CHAR_RE = r"[a-zA-Z0-9]"
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
                             HEX_CHAR_RE, HEX_CHAR_RE),
                            re.S | re.I)

_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$")

# Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
# struct ucred { pid_t pid; uid_t uid; gid_t gid; };
#
# The GNU C Library defines gid_t and uid_t to be "unsigned int" and
# pid_t to "int".
#
# IEEE Std 1003.1-2008:
# "nlink_t, uid_t, gid_t, and id_t shall be integer types"
# "blksize_t, pid_t, and ssize_t shall be signed integer types"
_STRUCT_UCRED = "iII"
_STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)

# Certificate verification results
(CERT_WARNING,
 CERT_ERROR) = range(1, 3)

# Flags for mlockall() (from bits/mman.h)
_MCL_CURRENT = 1
_MCL_FUTURE = 2


class RunResult(object):
  """Holds the result of running external programs.

  @type exit_code: int
  @ivar exit_code: the exit code of the program, or None (if the program
      didn't exit())
  @type signal: int or None
  @ivar signal: the signal that caused the program to finish, or None
      (if the program wasn't terminated by a signal)
  @type stdout: str
  @ivar stdout: the standard output of the program
  @type stderr: str
  @ivar stderr: the standard error of the program
  @type failed: boolean
  @ivar failed: True in case the program was
      terminated by a signal or exited with a non-zero exit code
  @ivar fail_reason: a string detailing the termination reason

  """
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
               "failed", "fail_reason", "cmd"]


  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
    self.cmd = cmd
    self.exit_code = exit_code
    self.signal = signal_
    self.stdout = stdout
    self.stderr = stderr
    self.failed = (signal_ is not None or exit_code != 0)

    if self.signal is not None:
      self.fail_reason = "terminated by signal %s" % self.signal
    elif self.exit_code is not None:
      self.fail_reason = "exited with exit code %s" % self.exit_code
    else:
      self.fail_reason = "unable to determine termination reason"

    if self.failed:
      logging.debug("Command '%s' failed (%s); output: %s",
                    self.cmd, self.fail_reason, self.output)

  def _GetOutput(self):
    """Returns the combined stdout and stderr for easier usage.

    """
    return self.stdout + self.stderr

  output = property(_GetOutput, None, None, "Return full output")


def _BuildCmdEnvironment(env, reset):
  """Builds the environment for an external program.

  """
  if reset:
    cmd_env = {}
  else:
    cmd_env = os.environ.copy()
    cmd_env["LC_ALL"] = "C"

  if env is not None:
    cmd_env.update(env)

  return cmd_env


def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
  """Execute a (shell) command.

  The command should not read from its standard input, as it will be
  closed.

  @type cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: Additional environment variables
  @type output: str
  @param output: if desired, the output of the command can be
      saved in a file instead of the RunResult instance; this
      parameter denotes the file name (if not None)
  @type cwd: string
  @param cwd: if specified, will be used as the working
      directory for the command; the default will be /
  @type reset_env: boolean
  @param reset_env: whether to reset or keep the default os environment
  @rtype: L{RunResult}
  @return: RunResult instance
  @raise errors.ProgrammerError: if we call this when forks are disabled

  """
  if no_fork:
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")

  if isinstance(cmd, basestring):
    strcmd = cmd
    shell = True
  else:
    cmd = [str(val) for val in cmd]
    strcmd = ShellQuoteArgs(cmd)
    shell = False

  if output:
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
  else:
    logging.debug("RunCmd %s", strcmd)

  cmd_env = _BuildCmdEnvironment(env, reset_env)

  try:
    if output is None:
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
    else:
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
      out = err = ""
  except OSError, err:
    if err.errno == errno.ENOENT:
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
                               (strcmd, err))
    else:
      raise

  if status >= 0:
    exitcode = status
    signal_ = None
  else:
    exitcode = None
    signal_ = -status

  return RunResult(exitcode, signal_, out, err, strcmd)
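
# Illustrative usage sketch (editor's addition, not part of the original
# module); the command below is an arbitrary placeholder:
#
#   result = RunCmd(["/bin/ls", "/tmp"])
#   if result.failed:
#     logging.error("ls failed (%s): %s", result.fail_reason, result.output)
#   else:
#     logging.debug("ls said: %s", result.stdout)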


def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
                pidfile=None):
  """Start a daemon process after forking twice.

  @type cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: Additional environment variables
  @type cwd: string
  @param cwd: Working directory for the program
  @type output: string
  @param output: Path to file in which to save the output
  @type output_fd: int
  @param output_fd: File descriptor for output
  @type pidfile: string
  @param pidfile: Process ID file
  @rtype: int
  @return: Daemon process ID
  @raise errors.ProgrammerError: if we call this when forks are disabled

  """
  if no_fork:
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
                                 " disabled")

  if output and not (bool(output) ^ (output_fd is not None)):
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
                                 " specified")

  if isinstance(cmd, basestring):
    cmd = ["/bin/sh", "-c", cmd]

  strcmd = ShellQuoteArgs(cmd)

  if output:
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
  else:
    logging.debug("StartDaemon %s", strcmd)

  cmd_env = _BuildCmdEnvironment(env, False)

  # Create pipe for sending PID back
  (pidpipe_read, pidpipe_write) = os.pipe()
  try:
    try:
      # Create pipe for sending error messages
      (errpipe_read, errpipe_write) = os.pipe()
      try:
        try:
          # First fork
          pid = os.fork()
          if pid == 0:
            try:
              # Child process, won't return
              _StartDaemonChild(errpipe_read, errpipe_write,
                                pidpipe_read, pidpipe_write,
                                cmd, cmd_env, cwd,
                                output, output_fd, pidfile)
            finally:
              # Well, maybe child process failed
              os._exit(1) # pylint: disable-msg=W0212
        finally:
          _CloseFDNoErr(errpipe_write)

        # Wait for daemon to be started (or an error message to arrive) and read
        # up to 100 KB as an error message
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
      finally:
        _CloseFDNoErr(errpipe_read)
    finally:
      _CloseFDNoErr(pidpipe_write)

    # Read up to 128 bytes for PID
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
  finally:
    _CloseFDNoErr(pidpipe_read)

  # Try to avoid zombies by waiting for child process
  try:
    os.waitpid(pid, 0)
  except OSError:
    pass

  if errormsg:
    raise errors.OpExecError("Error when starting daemon process: %r" %
                             errormsg)

  try:
    return int(pidtext)
  except (ValueError, TypeError), err:
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
                             (pidtext, err))


def _StartDaemonChild(errpipe_read, errpipe_write,
                      pidpipe_read, pidpipe_write,
                      args, env, cwd,
                      output, fd_output, pidfile):
  """Child process for starting daemon.

  """
  try:
    # Close parent's side
    _CloseFDNoErr(errpipe_read)
    _CloseFDNoErr(pidpipe_read)

    # First child process
    os.chdir("/")
    os.umask(077)
    os.setsid()

    # And fork for the second time
    pid = os.fork()
    if pid != 0:
      # Exit first child process
      os._exit(0) # pylint: disable-msg=W0212

    # Make sure pipe is closed on execv* (and thereby notifies original process)
    SetCloseOnExecFlag(errpipe_write, True)

    # List of file descriptors to be left open
    noclose_fds = [errpipe_write]

    # Open PID file
    if pidfile:
      try:
        # TODO: Atomic replace with another locked file instead of writing into
        # it after creating
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)

        # Lock the PID file (and fail if not possible to do so). Any code
        # wanting to send a signal to the daemon should try to lock the PID
        # file before reading it. If acquiring the lock succeeds, the daemon is
        # no longer running and the signal should not be sent.
        LockFile(fd_pidfile)

        os.write(fd_pidfile, "%d\n" % os.getpid())
      except Exception, err:
        raise Exception("Creating and locking PID file failed: %s" % err)

      # Keeping the file open to hold the lock
      noclose_fds.append(fd_pidfile)

      SetCloseOnExecFlag(fd_pidfile, False)
    else:
      fd_pidfile = None

    # Open /dev/null
    fd_devnull = os.open(os.devnull, os.O_RDWR)

    assert not output or (bool(output) ^ (fd_output is not None))

    if fd_output is not None:
      pass
    elif output:
      # Open output file
      try:
        # TODO: Implement flag to set append=yes/no
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
      except EnvironmentError, err:
        raise Exception("Opening output file failed: %s" % err)
    else:
      fd_output = fd_devnull

    # Redirect standard I/O
    os.dup2(fd_devnull, 0)
    os.dup2(fd_output, 1)
    os.dup2(fd_output, 2)

    # Send daemon PID to parent
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))

    # Close all file descriptors except stdio and error message pipe
    CloseFDs(noclose_fds=noclose_fds)

    # Change working directory
    os.chdir(cwd)

    if env is None:
      os.execvp(args[0], args)
    else:
      os.execvpe(args[0], args, env)
  except: # pylint: disable-msg=W0702
    try:
      # Report errors to original process
      buf = str(sys.exc_info()[1])

      RetryOnSignal(os.write, errpipe_write, buf)
    except: # pylint: disable-msg=W0702
      # Ignore errors in error handling
      pass

  os._exit(1) # pylint: disable-msg=W0212


def _RunCmdPipe(cmd, env, via_shell, cwd):
  """Run a command and return its output.

  @type  cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: The environment to use
  @type via_shell: bool
  @param via_shell: if we should run via the shell
  @type cwd: string
  @param cwd: the working directory for the program
  @rtype: tuple
  @return: (out, err, status)

  """
  poller = select.poll()
  child = subprocess.Popen(cmd, shell=via_shell,
                           stderr=subprocess.PIPE,
                           stdout=subprocess.PIPE,
                           stdin=subprocess.PIPE,
                           close_fds=True, env=env,
                           cwd=cwd)

  child.stdin.close()
  poller.register(child.stdout, select.POLLIN)
  poller.register(child.stderr, select.POLLIN)
  out = StringIO()
  err = StringIO()
  fdmap = {
    child.stdout.fileno(): (out, child.stdout),
    child.stderr.fileno(): (err, child.stderr),
    }
  for fd in fdmap:
    SetNonblockFlag(fd, True)

  while fdmap:
    pollresult = RetryOnSignal(poller.poll)

    for fd, event in pollresult:
      if event & select.POLLIN or event & select.POLLPRI:
        data = fdmap[fd][1].read()
        # no data from read signifies EOF (the same as POLLHUP)
        if not data:
          poller.unregister(fd)
          del fdmap[fd]
          continue
        fdmap[fd][0].write(data)
      if (event & select.POLLNVAL or event & select.POLLHUP or
          event & select.POLLERR):
        poller.unregister(fd)
        del fdmap[fd]

  out = out.getvalue()
  err = err.getvalue()

  status = child.wait()
  return out, err, status


def _RunCmdFile(cmd, env, via_shell, output, cwd):
  """Run a command and save its output to a file.

  @type  cmd: string or list
  @param cmd: Command to run
  @type env: dict
  @param env: The environment to use
  @type via_shell: bool
  @param via_shell: if we should run via the shell
  @type output: str
  @param output: the filename in which to save the output
  @type cwd: string
  @param cwd: the working directory for the program
  @rtype: int
  @return: the exit status

  """
  fh = open(output, "a")
  try:
    child = subprocess.Popen(cmd, shell=via_shell,
                             stderr=subprocess.STDOUT,
                             stdout=fh,
                             stdin=subprocess.PIPE,
                             close_fds=True, env=env,
                             cwd=cwd)

    child.stdin.close()
    status = child.wait()
  finally:
    fh.close()
  return status


def SetCloseOnExecFlag(fd, enable):
  """Sets or unsets the close-on-exec flag on a file descriptor.

  @type fd: int
  @param fd: File descriptor
  @type enable: bool
  @param enable: Whether to set or unset it.

  """
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)

  if enable:
    flags |= fcntl.FD_CLOEXEC
  else:
    flags &= ~fcntl.FD_CLOEXEC

  fcntl.fcntl(fd, fcntl.F_SETFD, flags)


def SetNonblockFlag(fd, enable):
  """Sets or unsets the O_NONBLOCK flag on a file descriptor.

  @type fd: int
  @param fd: File descriptor
  @type enable: bool
  @param enable: Whether to set or unset it

  """
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)

  if enable:
    flags |= os.O_NONBLOCK
  else:
    flags &= ~os.O_NONBLOCK

  fcntl.fcntl(fd, fcntl.F_SETFL, flags)


def RetryOnSignal(fn, *args, **kwargs):
  """Calls a function again if it failed due to EINTR.

  """
  while True:
    try:
      return fn(*args, **kwargs)
    except EnvironmentError, err:
      if err.errno != errno.EINTR:
        raise
    except (socket.error, select.error), err:
      # In python 2.6 and above select.error is an IOError, so it's handled
      # above, in 2.5 and below it's not, and it's handled here.
      if not (err.args and err.args[0] == errno.EINTR):
        raise
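
# Illustrative usage sketch (editor's addition): any EINTR-prone call can be
# wrapped, e.g. reading from a pipe file descriptor ('some_fd' is a
# placeholder):
#
#   data = RetryOnSignal(os.read, some_fd, 4096)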


def RunParts(dir_name, env=None, reset_env=False):
  """Run scripts or programs in a directory

  @type dir_name: string
  @param dir_name: absolute path to a directory
  @type env: dict
  @param env: The environment to use
  @type reset_env: boolean
  @param reset_env: whether to reset or keep the default os environment
  @rtype: list of tuples
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)

  """
  rr = []

  try:
    dir_contents = ListVisibleFiles(dir_name)
  except OSError, err:
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
    return rr

  for relname in sorted(dir_contents):
    fname = PathJoin(dir_name, relname)
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
      rr.append((relname, constants.RUNPARTS_SKIP, None))
    else:
      try:
        result = RunCmd([fname], env=env, reset_env=reset_env)
      except Exception, err: # pylint: disable-msg=W0703
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
      else:
        rr.append((relname, constants.RUNPARTS_RUN, result))

  return rr
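
# Illustrative usage sketch (editor's addition); the directory below is a
# placeholder:
#
#   for (name, status, runresult) in RunParts("/etc/ganeti/hooks.d"):
#     if status == constants.RUNPARTS_ERR:
#       logging.error("%s could not be run: %s", name, runresult)
#     elif status == constants.RUNPARTS_RUN and runresult.failed:
#       logging.warning("%s failed: %s", name, runresult.fail_reason)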


def GetSocketCredentials(sock):
  """Returns the credentials of the foreign process connected to a socket.

  @param sock: Unix socket
  @rtype: tuple; (number, number, number)
  @return: The PID, UID and GID of the connected foreign process.

  """
  peercred = sock.getsockopt(socket.SOL_SOCKET, IN.SO_PEERCRED,
                             _STRUCT_UCRED_SIZE)
  return struct.unpack(_STRUCT_UCRED, peercred)


def RemoveFile(filename):
  """Remove a file ignoring some errors.

  Remove a file, ignoring non-existing ones or directories. Other
  errors are raised.

  @type filename: str
  @param filename: the file to be removed

  """
  try:
    os.unlink(filename)
  except OSError, err:
    if err.errno not in (errno.ENOENT, errno.EISDIR):
      raise


def RemoveDir(dirname):
  """Remove an empty directory.

  Remove a directory, ignoring non-existing ones.
  Other errors are raised. This includes the case
  where the directory is not empty, so it can't be removed.

  @type dirname: str
  @param dirname: the empty directory to be removed

  """
  try:
    os.rmdir(dirname)
  except OSError, err:
    if err.errno != errno.ENOENT:
      raise


def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
  """Renames a file.

  @type old: string
  @param old: Original path
  @type new: string
  @param new: New path
  @type mkdir: bool
  @param mkdir: Whether to create target directory if it doesn't exist
  @type mkdir_mode: int
  @param mkdir_mode: Mode for newly created directories

  """
  try:
    return os.rename(old, new)
  except OSError, err:
    # In at least one use case of this function, the job queue, directory
    # creation is very rare. Checking for the directory before renaming is not
    # as efficient.
    if mkdir and err.errno == errno.ENOENT:
      # Create directory and try again
      Makedirs(os.path.dirname(new), mode=mkdir_mode)

      return os.rename(old, new)

    raise


def Makedirs(path, mode=0750):
  """Super-mkdir; create a leaf directory and all intermediate ones.

  This is a wrapper around C{os.makedirs} adding error handling not implemented
  before Python 2.5.

  """
  try:
    os.makedirs(path, mode)
  except OSError, err:
    # Ignore EEXIST. This is only handled in os.makedirs as included in
    # Python 2.5 and above.
    if err.errno != errno.EEXIST or not os.path.exists(path):
      raise


def ResetTempfileModule():
  """Resets the random name generator of the tempfile module.

  This function should be called after C{os.fork} in the child process to
  ensure it creates a newly seeded random generator. Otherwise it would
  generate the same random parts as the parent process. If several processes
  race for the creation of a temporary file, this could lead to one not getting
  a temporary name.

  """
  # pylint: disable-msg=W0212
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
    tempfile._once_lock.acquire()
    try:
      # Reset random name generator
      tempfile._name_sequence = None
    finally:
      tempfile._once_lock.release()
  else:
    logging.critical("The tempfile module misses at least one of the"
                     " '_once_lock' and '_name_sequence' attributes")


def _FingerprintFile(filename):
  """Compute the fingerprint of a file.

  If the file does not exist, a None will be returned
  instead.

  @type filename: str
  @param filename: the filename to checksum
  @rtype: str
  @return: the hex digest of the sha checksum of the contents
      of the file

  """
  if not (os.path.exists(filename) and os.path.isfile(filename)):
    return None

  f = open(filename)

  fp = compat.sha1_hash()
  while True:
    data = f.read(4096)
    if not data:
      break

    fp.update(data)

  return fp.hexdigest()


def FingerprintFiles(files):
  """Compute fingerprints for a list of files.

  @type files: list
  @param files: the list of filenames to fingerprint
  @rtype: dict
  @return: a dictionary filename: fingerprint, holding only
      existing files

  """
  ret = {}

  for filename in files:
    cksum = _FingerprintFile(filename)
    if cksum:
      ret[filename] = cksum

  return ret


def ForceDictType(target, key_types, allowed_values=None):
  """Force the values of a dict to have certain types.

  @type target: dict
  @param target: the dict to update
  @type key_types: dict
  @param key_types: dict mapping target dict keys to types
                    in constants.ENFORCEABLE_TYPES
  @type allowed_values: list
  @keyword allowed_values: list of specially allowed values

  """
  if allowed_values is None:
    allowed_values = []

  if not isinstance(target, dict):
    msg = "Expected dictionary, got '%s'" % target
    raise errors.TypeEnforcementError(msg)

  for key in target:
    if key not in key_types:
      msg = "Unknown key '%s'" % key
      raise errors.TypeEnforcementError(msg)

    if target[key] in allowed_values:
      continue

    ktype = key_types[key]
    if ktype not in constants.ENFORCEABLE_TYPES:
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
      raise errors.ProgrammerError(msg)

    if ktype == constants.VTYPE_STRING:
      if not isinstance(target[key], basestring):
        if isinstance(target[key], bool) and not target[key]:
          target[key] = ''
        else:
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
          raise errors.TypeEnforcementError(msg)
    elif ktype == constants.VTYPE_BOOL:
      if isinstance(target[key], basestring) and target[key]:
        if target[key].lower() == constants.VALUE_FALSE:
          target[key] = False
        elif target[key].lower() == constants.VALUE_TRUE:
          target[key] = True
        else:
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
          raise errors.TypeEnforcementError(msg)
      elif target[key]:
        target[key] = True
      else:
        target[key] = False
    elif ktype == constants.VTYPE_SIZE:
      try:
        target[key] = ParseUnit(target[key])
      except errors.UnitParseError, err:
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
              (key, target[key], err)
        raise errors.TypeEnforcementError(msg)
    elif ktype == constants.VTYPE_INT:
      try:
        target[key] = int(target[key])
      except (ValueError, TypeError):
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
        raise errors.TypeEnforcementError(msg)
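
# Illustrative usage sketch (editor's addition); the keys and expected result
# below are made up for the example:
#
#   data = {"memory": "512M", "auto_balance": "true"}
#   ForceDictType(data, {"memory": constants.VTYPE_SIZE,
#                        "auto_balance": constants.VTYPE_BOOL})
#   # data should now be {"memory": 512, "auto_balance": True}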


def _GetProcStatusPath(pid):
  """Returns the path for a PID's proc status file.

  @type pid: int
  @param pid: Process ID
  @rtype: string

  """
  return "/proc/%d/status" % pid


def IsProcessAlive(pid):
  """Check if a given pid exists on the system.

  @note: zombie status is not handled, so zombie processes
      will be returned as alive
  @type pid: int
  @param pid: the process ID to check
  @rtype: boolean
  @return: True if the process exists

  """
  def _TryStat(name):
    try:
      os.stat(name)
      return True
    except EnvironmentError, err:
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
        return False
      elif err.errno == errno.EINVAL:
        raise RetryAgain(err)
      raise

  assert isinstance(pid, int), "pid must be an integer"
  if pid <= 0:
    return False

  # /proc in a multiprocessor environment can have strange behaviors.
  # Retry the os.stat a few times until we get a good result.
  try:
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
                 args=[_GetProcStatusPath(pid)])
  except RetryTimeout, err:
    err.RaiseInner()
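
# Illustrative usage sketch (editor's addition): together with ReadPidFile
# (defined further below) this is the usual liveness check; the path is a
# placeholder:
#
#   pid = ReadPidFile("/var/run/ganeti/noded.pid")
#   if pid > 0 and IsProcessAlive(pid):
#     logging.info("daemon is running as PID %d", pid)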


def _ParseSigsetT(sigset):
  """Parse a rendered sigset_t value.

  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
  function.

  @type sigset: string
  @param sigset: Rendered signal set from /proc/$pid/status
  @rtype: set
  @return: Set of all enabled signal numbers

  """
  result = set()

  signum = 0
  for ch in reversed(sigset):
    chv = int(ch, 16)

    # The following could be done in a loop, but it's easier to read and
    # understand in the unrolled form
    if chv & 1:
      result.add(signum + 1)
    if chv & 2:
      result.add(signum + 2)
    if chv & 4:
      result.add(signum + 3)
    if chv & 8:
      result.add(signum + 4)

    signum += 4

  return result


def _GetProcStatusField(pstatus, field):
  """Retrieves a field from the contents of a proc status file.

  @type pstatus: string
  @param pstatus: Contents of /proc/$pid/status
  @type field: string
  @param field: Name of field whose value should be returned
  @rtype: string

  """
  for line in pstatus.splitlines():
    parts = line.split(":", 1)

    if len(parts) < 2 or parts[0] != field:
      continue

    return parts[1].strip()

  return None


def IsProcessHandlingSignal(pid, signum, status_path=None):
  """Checks whether a process is handling a signal.

  @type pid: int
  @param pid: Process ID
  @type signum: int
  @param signum: Signal number
  @rtype: bool

  """
  if status_path is None:
    status_path = _GetProcStatusPath(pid)

  try:
    proc_status = ReadFile(status_path)
  except EnvironmentError, err:
    # In at least one case, reading /proc/$pid/status failed with ESRCH.
    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
      return False
    raise

  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
  if sigcgt is None:
    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)

  # Now check whether signal is handled
  return signum in _ParseSigsetT(sigcgt)


def ReadPidFile(pidfile):
  """Read a pid from a file.

  @type  pidfile: string
  @param pidfile: path to the file containing the pid
  @rtype: int
  @return: The process id, if the file exists and contains a valid PID,
           otherwise 0

  """
  try:
    raw_data = ReadOneLineFile(pidfile)
  except EnvironmentError, err:
    if err.errno != errno.ENOENT:
      logging.exception("Can't read pid file")
    return 0

  try:
    pid = int(raw_data)
  except (TypeError, ValueError), err:
    logging.info("Can't parse pid file contents", exc_info=True)
    return 0

  return pid


def ReadLockedPidFile(path):
  """Reads a locked PID file.

  This can be used together with L{StartDaemon}.

  @type path: string
  @param path: Path to PID file
  @return: PID as integer or, if file was unlocked or couldn't be opened, None

  """
  try:
    fd = os.open(path, os.O_RDONLY)
  except EnvironmentError, err:
    if err.errno == errno.ENOENT:
      # PID file doesn't exist
      return None
    raise

  try:
    try:
      # Try to acquire lock
      LockFile(fd)
    except errors.LockError:
      # Couldn't lock, daemon is running
      return int(os.read(fd, 100))
  finally:
    os.close(fd)

  return None


def MatchNameComponent(key, name_list, case_sensitive=True):
  """Try to match a name against a list.

  This function will try to match a name like test1 against a list
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
  this list, I{'test1'} as well as I{'test1.example'} will match, but
  not I{'test1.ex'}. A multiple match will be considered as no match
  at all (e.g. I{'test1'} against C{['test1.example.com',
  'test1.example.org']}), except when the key fully matches an entry
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).

  @type key: str
  @param key: the name to be searched
  @type name_list: list
  @param name_list: the list of strings against which to search the key
  @type case_sensitive: boolean
  @param case_sensitive: whether to provide a case-sensitive match

  @rtype: None or str
  @return: None if there is no match I{or} if there are multiple matches,
      otherwise the element from the list which matches

  """
  if key in name_list:
    return key

  re_flags = 0
  if not case_sensitive:
    re_flags |= re.IGNORECASE
    key = key.upper()
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
  names_filtered = []
  string_matches = []
  for name in name_list:
    if mo.match(name) is not None:
      names_filtered.append(name)
      if not case_sensitive and key == name.upper():
        string_matches.append(name)

  if len(string_matches) == 1:
    return string_matches[0]
  if len(names_filtered) == 1:
    return names_filtered[0]
  return None
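
# Illustrative examples (editor's addition), following the docstring above:
#
#   MatchNameComponent("test1", ["test1.example.com", "test2.example.com"])
#     -> "test1.example.com"
#   MatchNameComponent("test1", ["test1.example.com", "test1.example.org"])
#     -> None (ambiguous match)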


class HostInfo:
  """Class implementing resolver and hostname functionality

  """
  _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$")

  def __init__(self, name=None):
    """Initialize the host name object.

    If the name argument is not passed, it will use this system's
    name.

    """
    if name is None:
      name = self.SysName()

    self.query = name
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
    self.ip = self.ipaddrs[0]

  def ShortName(self):
    """Returns the hostname without domain.

    """
    return self.name.split('.')[0]

  @staticmethod
  def SysName():
    """Return the current system's name.

    This is simply a wrapper over C{socket.gethostname()}.

    """
    return socket.gethostname()

  @staticmethod
  def LookupHostname(hostname):
    """Look up hostname

    @type hostname: str
    @param hostname: hostname to look up

    @rtype: tuple
    @return: a tuple (name, aliases, ipaddrs) as returned by
        C{socket.gethostbyname_ex}
    @raise errors.ResolverError: in case of errors in resolving

    """
    try:
      result = socket.gethostbyname_ex(hostname)
    except (socket.gaierror, socket.herror, socket.error), err:
      # hostname not found in DNS, or other socket exception in the
      # (code, description format)
      raise errors.ResolverError(hostname, err.args[0], err.args[1])

    return result

  @classmethod
  def NormalizeName(cls, hostname):
    """Validate and normalize the given hostname.

    @attention: the validation is a bit more relaxed than the standards
        require; most importantly, we allow underscores in names
    @raise errors.OpPrereqError: when the name is not valid

    """
    hostname = hostname.lower()
    if (not cls._VALID_NAME_RE.match(hostname) or
        # double-dots, meaning empty label
        ".." in hostname or
        # empty initial label
        hostname.startswith(".")):
      raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
                                 errors.ECODE_INVAL)
    if hostname.endswith("."):
      hostname = hostname.rstrip(".")
    return hostname


def ValidateServiceName(name):
  """Validate the given service name.

  @type name: number or string
  @param name: Service name or port specification

  """
  try:
    numport = int(name)
  except (ValueError, TypeError):
    # Non-numeric service name
    valid = _VALID_SERVICE_NAME_RE.match(name)
  else:
    # Numeric port (protocols other than TCP or UDP might need adjustments
    # here)
    valid = (numport >= 0 and numport < (1 << 16))

  if not valid:
    raise errors.OpPrereqError("Invalid service name '%s'" % name,
                               errors.ECODE_INVAL)

  return name


def GetHostInfo(name=None):
  """Lookup host name and raise an OpPrereqError for failures"""

  try:
    return HostInfo(name)
  except errors.ResolverError, err:
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
                               (err[0], err[2]), errors.ECODE_RESOLVER)


def ListVolumeGroups():
  """List volume groups and their size

  @rtype: dict
  @return:
       Dictionary with keys volume name and values
       the size of the volume

  """
  command = "vgs --noheadings --units m --nosuffix -o name,size"
  result = RunCmd(command)
  retval = {}
  if result.failed:
    return retval

  for line in result.stdout.splitlines():
    try:
      name, size = line.split()
      size = int(float(size))
    except (IndexError, ValueError), err:
      logging.error("Invalid output from vgs (%s): %s", err, line)
      continue

    retval[name] = size

  return retval


def BridgeExists(bridge):
  """Check whether the given bridge exists in the system

  @type bridge: str
  @param bridge: the bridge name to check
  @rtype: boolean
  @return: True if it does

  """
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)


def NiceSort(name_list):
  """Sort a list of strings based on digit and non-digit groupings.

  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
  will sort the list in the logical order C{['a1', 'a2', 'a10',
  'a11']}.

  The sort algorithm breaks each name in groups of either only-digits
  or no-digits. Only the first eight such groups are considered, and
  after that we just use what's left of the string.

  @type name_list: list
  @param name_list: the names to be sorted
  @rtype: list
  @return: a copy of the name list sorted with our algorithm

  """
  _SORTER_BASE = "(\D+|\d+)"
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE,
                                                  _SORTER_BASE, _SORTER_BASE)
  _SORTER_RE = re.compile(_SORTER_FULL)
  _SORTER_NODIGIT = re.compile("^\D*$")
  def _TryInt(val):
    """Attempts to convert a variable to integer."""
    if val is None or _SORTER_NODIGIT.match(val):
      return val
    rval = int(val)
    return rval

  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
             for name in name_list]
  to_sort.sort()
  return [tup[1] for tup in to_sort]
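
# Illustrative example (editor's addition), matching the docstring above:
#
#   NiceSort(['a1', 'a10', 'a11', 'a2']) -> ['a1', 'a2', 'a10', 'a11']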


def TryConvert(fn, val):
  """Try to convert a value ignoring errors.

  This function tries to apply function I{fn} to I{val}. If no
  C{ValueError} or C{TypeError} exceptions are raised, it will return
  the result, else it will return the original value. Any other
  exceptions are propagated to the caller.

  @type fn: callable
  @param fn: function to apply to the value
  @param val: the value to be converted
  @return: The converted value if the conversion was successful,
      otherwise the original value.

  """
  try:
    nv = fn(val)
  except (ValueError, TypeError):
    nv = val
  return nv


def _GenericIsValidIP(family, ip):
  """Generic internal version of ip validation.

  @type family: int
  @param family: socket.AF_INET | socket.AF_INET6
  @type ip: str
  @param ip: the address to be checked
  @rtype: boolean
  @return: True if ip is valid, False otherwise

  """
  try:
    socket.inet_pton(family, ip)
    return True
  except socket.error:
    return False


def IsValidIP4(ip):
  """Verifies an IPv4 address.

  This function checks if the given address is a valid IPv4 address.

  @type ip: str
  @param ip: the address to be checked
  @rtype: boolean
  @return: True if ip is valid, False otherwise

  """
  return _GenericIsValidIP(socket.AF_INET, ip)


def IsValidIP6(ip):
  """Verifies an IPv6 address.

  This function checks if the given address is a valid IPv6 address.

  @type ip: str
  @param ip: the address to be checked
  @rtype: boolean
  @return: True if ip is valid, False otherwise

  """
  return _GenericIsValidIP(socket.AF_INET6, ip)


def IsValidIP(ip):
  """Verifies an IP address.

  This function checks if the given IP address (both IPv4 and IPv6) is valid.

  @type ip: str
  @param ip: the address to be checked
  @rtype: boolean
  @return: True if ip is valid, False otherwise

  """
  return IsValidIP4(ip) or IsValidIP6(ip)


def GetAddressFamily(ip):
  """Get the address family of the given address.

  @type ip: str
  @param ip: ip address whose family will be returned
  @rtype: int
  @return: socket.AF_INET or socket.AF_INET6
  @raise errors.GenericError: for invalid addresses

  """
  if IsValidIP6(ip):
    return socket.AF_INET6
  elif IsValidIP4(ip):
    return socket.AF_INET
  else:
    raise errors.GenericError("Address %s not valid" % ip)


def IsValidShellParam(word):
  """Verifies if the given word is safe from the shell's p.o.v.

  This means that we can pass this to a command via the shell and be
  sure that it doesn't alter the command line and is passed as such to
  the actual command.

  Note that we are overly restrictive here, in order to be on the safe
  side.

  @type word: str
  @param word: the word to check
  @rtype: boolean
  @return: True if the word is 'safe'

  """
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))


def BuildShellCmd(template, *args):
  """Build a safe shell command line from the given arguments.

  This function will check all arguments in the args list so that they
  are valid shell parameters (i.e. they don't contain shell
  metacharacters). If everything is ok, it will return the result of
  template % args.

  @type template: str
  @param template: the string holding the template for the
      string formatting
  @rtype: str
  @return: the expanded command line

  """
  for word in args:
    if not IsValidShellParam(word):
      raise errors.ProgrammerError("Shell argument '%s' contains"
                                   " invalid characters" % word)
  return template % args


def FormatUnit(value, units):
  """Formats an incoming number of MiB with the appropriate unit.

  @type value: int
  @param value: integer representing the value in MiB (1048576)
  @type units: char
  @param units: the type of formatting we should do:
      - 'h' for automatic scaling
      - 'm' for MiBs
      - 'g' for GiBs
      - 't' for TiBs
  @rtype: str
  @return: the formatted value (with suffix)

  """
  if units not in ('m', 'g', 't', 'h'):
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))

  suffix = ''

  if units == 'm' or (units == 'h' and value < 1024):
    if units == 'h':
      suffix = 'M'
    return "%d%s" % (round(value, 0), suffix)

  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
    if units == 'h':
      suffix = 'G'
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)

  else:
    if units == 'h':
      suffix = 'T'
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)


def ParseUnit(input_string):
  """Tries to extract number and scale from the given string.

  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
  is always an int in MiB.

  """
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
  if not m:
    raise errors.UnitParseError("Invalid format")

  value = float(m.groups()[0])

  unit = m.groups()[1]
  if unit:
    lcunit = unit.lower()
  else:
    lcunit = 'm'

  if lcunit in ('m', 'mb', 'mib'):
    # Value already in MiB
    pass

  elif lcunit in ('g', 'gb', 'gib'):
    value *= 1024

  elif lcunit in ('t', 'tb', 'tib'):
    value *= 1024 * 1024

  else:
    raise errors.UnitParseError("Unknown unit: %s" % unit)

  # Make sure we round up
  if int(value) < value:
    value += 1

  # Round up to the next multiple of 4
  value = int(value)
  if value % 4:
    value += 4 - value % 4

  return value
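
# Illustrative examples (editor's addition): ParseUnit always returns MiB,
# rounded up to a multiple of 4, and FormatUnit goes the other way:
#
#   ParseUnit("1G")       -> 1024
#   ParseUnit("1.3")      -> 4      (1.3 MiB, rounded up to a multiple of 4)
#   FormatUnit(1024, 'h') -> "1.0G"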


def AddAuthorizedKey(file_name, key):
  """Adds an SSH public key to an authorized_keys file.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  key_fields = key.split()

  f = open(file_name, 'a+')
  try:
    nl = True
    for line in f:
      # Ignore whitespace changes
      if line.split() == key_fields:
        break
      nl = line.endswith('\n')
    else:
      if not nl:
        f.write("\n")
      f.write(key.rstrip('\r\n'))
      f.write("\n")
      f.flush()
  finally:
    f.close()


def RemoveAuthorizedKey(file_name, key):
  """Removes an SSH public key from an authorized_keys file.

  @type file_name: str
  @param file_name: path to authorized_keys file
  @type key: str
  @param key: string containing key

  """
  key_fields = key.split()

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          # Ignore whitespace changes while comparing lines
          if line.split() != key_fields:
            out.write(line)

        out.flush()
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def SetEtcHostsEntry(file_name, ip, hostname, aliases):
  """Sets the name of an IP address and hostname in /etc/hosts.

  @type file_name: str
  @param file_name: path to the file to modify (usually C{/etc/hosts})
  @type ip: str
  @param ip: the IP address
  @type hostname: str
  @param hostname: the hostname to be added
  @type aliases: list
  @param aliases: the list of aliases to add for the hostname

  """
  # FIXME: use WriteFile + fn rather than duplicating its efforts
  # Ensure aliases are unique
  aliases = UniqueSequence([hostname] + aliases)[1:]

  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          fields = line.split()
          if fields and not fields[0].startswith('#') and ip == fields[0]:
            continue
          out.write(line)

        out.write("%s\t%s" % (ip, hostname))
        if aliases:
          out.write(" %s" % ' '.join(aliases))
        out.write('\n')

        out.flush()
        os.fsync(out)
        os.chmod(tmpname, 0644)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def AddHostToEtcHosts(hostname):
  """Wrapper around SetEtcHostsEntry.

  @type hostname: str
  @param hostname: a hostname that will be resolved and added to
      L{constants.ETC_HOSTS}

  """
  hi = HostInfo(name=hostname)
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])


def RemoveEtcHostsEntry(file_name, hostname):
  """Removes a hostname from /etc/hosts.

  IP addresses without names are removed from the file.

  @type file_name: str
  @param file_name: path to the file to modify (usually C{/etc/hosts})
  @type hostname: str
  @param hostname: the hostname to be removed

  """
  # FIXME: use WriteFile + fn rather than duplicating its efforts
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
  try:
    out = os.fdopen(fd, 'w')
    try:
      f = open(file_name, 'r')
      try:
        for line in f:
          fields = line.split()
          if len(fields) > 1 and not fields[0].startswith('#'):
            names = fields[1:]
            if hostname in names:
              while hostname in names:
                names.remove(hostname)
              if names:
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
              continue

          out.write(line)

        out.flush()
        os.fsync(out)
        os.chmod(tmpname, 0644)
        os.rename(tmpname, file_name)
      finally:
        f.close()
    finally:
      out.close()
  except:
    RemoveFile(tmpname)
    raise


def RemoveHostFromEtcHosts(hostname):
  """Wrapper around RemoveEtcHostsEntry.

  @type hostname: str
  @param hostname: hostname that will be resolved and its
      full and short name will be removed from
      L{constants.ETC_HOSTS}

  """
  hi = HostInfo(name=hostname)
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())


def TimestampForFilename():
  """Returns the current time formatted for filenames.

  The format doesn't contain colons as some shells and applications
  treat them as separators.

  """
  return time.strftime("%Y-%m-%d_%H_%M_%S")


def CreateBackup(file_name):
  """Creates a backup of a file.

  @type file_name: str
  @param file_name: file to be backed up
  @rtype: str
  @return: the path to the newly created backup
  @raise errors.ProgrammerError: for invalid file names

  """
  if not os.path.isfile(file_name):
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
                                file_name)

  prefix = ("%s.backup-%s." %
            (os.path.basename(file_name), TimestampForFilename()))
  dir_name = os.path.dirname(file_name)

  fsrc = open(file_name, 'rb')
  try:
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
    fdst = os.fdopen(fd, 'wb')
    try:
      logging.debug("Backing up %s at %s", file_name, backup_name)
      shutil.copyfileobj(fsrc, fdst)
    finally:
      fdst.close()
  finally:
    fsrc.close()

  return backup_name


def ShellQuote(value):
  """Quotes shell argument according to POSIX.

  @type value: str
  @param value: the argument to be quoted
  @rtype: str
  @return: the quoted value

  """
  if _re_shell_unquoted.match(value):
    return value
  else:
    return "'%s'" % value.replace("'", "'\\''")


def ShellQuoteArgs(args):
  """Quotes a list of shell arguments.

  @type args: list
  @param args: list of arguments to be quoted
  @rtype: str
  @return: the quoted arguments concatenated with spaces

  """
  return ' '.join([ShellQuote(i) for i in args])
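
# Illustrative example (editor's addition): arguments containing shell
# metacharacters are wrapped in single quotes, safe ones are left alone:
#
#   ShellQuoteArgs(["echo", "hello world", "/tmp/x"])
#     -> "echo 'hello world' /tmp/x"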


def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
  """Simple ping implementation using TCP connect(2).

  Check if the given IP is reachable by attempting a TCP connect
  to it.

  @type target: str
  @param target: the IP or hostname to ping
  @type port: int
  @param port: the port to connect to
  @type timeout: int
  @param timeout: the timeout on the connection attempt
  @type live_port_needed: boolean
  @param live_port_needed: whether a closed port will cause the
      function to return failure, as if there was a timeout
  @type source: str or None
  @param source: if specified, will cause the connect to be made
      from this specific source address; failures to bind other
      than C{EADDRNOTAVAIL} will be ignored

  """
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

  success = False

  if source is not None:
    try:
      sock.bind((source, 0))
    except socket.error, (errcode, _):
      if errcode == errno.EADDRNOTAVAIL:
        success = False

  sock.settimeout(timeout)

  try:
    sock.connect((target, port))
    sock.close()
    success = True
  except socket.timeout:
    success = False
  except socket.error, (errcode, _):
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)

  return success
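
# Illustrative usage sketch (editor's addition); the address is a placeholder:
#
#   if TcpPing("192.0.2.10", constants.DEFAULT_NODED_PORT, timeout=5,
#              live_port_needed=True):
#     logging.info("node daemon is reachable")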
1788

    
1789

    
1790
def OwnIpAddress(address):
1791
  """Check if the current host has the the given IP address.
1792

1793
  Currently this is done by TCP-pinging the address from the loopback
1794
  address.
1795

1796
  @type address: string
1797
  @param address: the address to check
1798
  @rtype: bool
1799
  @return: True if we own the address
1800

1801
  """
1802
  return TcpPing(address, constants.DEFAULT_NODED_PORT,
1803
                 source=constants.IP4_ADDRESS_LOCALHOST)
1804

    
1805

    
1806
def ListVisibleFiles(path):
1807
  """Returns a list of visible files in a directory.
1808

1809
  @type path: str
1810
  @param path: the directory to enumerate
1811
  @rtype: list
1812
  @return: the list of all files not starting with a dot
1813
  @raise ProgrammerError: if L{path} is not an absolute and normalized path
1814

1815
  """
1816
  if not IsNormAbsPath(path):
1817
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
1818
                                 " absolute/normalized: '%s'" % path)
1819
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1820
  return files
1821

    
1822

    
1823
def GetHomeDir(user, default=None):
1824
  """Try to get the homedir of the given user.
1825

1826
  The user can be passed either as a string (denoting the name) or as
1827
  an integer (denoting the user id). If the user is not found, the
1828
  'default' argument is returned, which defaults to None.
1829

1830
  """
1831
  try:
1832
    if isinstance(user, basestring):
1833
      result = pwd.getpwnam(user)
1834
    elif isinstance(user, (int, long)):
1835
      result = pwd.getpwuid(user)
1836
    else:
1837
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1838
                                   type(user))
1839
  except KeyError:
1840
    return default
1841
  return result.pw_dir
1842

    
1843

    
1844
def NewUUID():
1845
  """Returns a random UUID.
1846

1847
  @note: This is a Linux-specific method as it uses the /proc
1848
      filesystem.
1849
  @rtype: str
1850

1851
  """
1852
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1853

    
1854

    
1855
def GenerateSecret(numbytes=20):
1856
  """Generates a random secret.
1857

1858
  This will generate a pseudo-random secret returning a hex string
1859
  (so that it can be used where an ASCII string is needed).
1860

1861
  @param numbytes: the number of bytes which will be represented by the returned
1862
      string (defaulting to 20, the length of a SHA1 hash)
1863
  @rtype: str
1864
  @return: a hex representation of the pseudo-random sequence
1865

1866
  """
1867
  return os.urandom(numbytes).encode('hex')
1868

    
1869

    
1870
def EnsureDirs(dirs):
1871
  """Make required directories, if they don't exist.
1872

1873
  @param dirs: list of tuples (dir_name, dir_mode)
1874
  @type dirs: list of (string, integer)
1875

1876
  """
1877
  for dir_name, dir_mode in dirs:
1878
    try:
1879
      os.mkdir(dir_name, dir_mode)
1880
    except EnvironmentError, err:
1881
      if err.errno != errno.EEXIST:
1882
        raise errors.GenericError("Cannot create needed directory"
1883
                                  " '%s': %s" % (dir_name, err))
1884
    try:
1885
      os.chmod(dir_name, dir_mode)
1886
    except EnvironmentError, err:
1887
      raise errors.GenericError("Cannot change directory permissions on"
1888
                                " '%s': %s" % (dir_name, err))
1889
    if not os.path.isdir(dir_name):
1890
      raise errors.GenericError("%s is not a directory" % dir_name)
1891

    
1892

    
1893
def ReadFile(file_name, size=-1):
1894
  """Reads a file.
1895

1896
  @type size: int
1897
  @param size: Read at most size bytes (if negative, entire file)
1898
  @rtype: str
1899
  @return: the (possibly partial) content of the file
1900

1901
  """
1902
  f = open(file_name, "r")
1903
  try:
1904
    return f.read(size)
1905
  finally:
1906
    f.close()
1907

    
1908

    
1909
def WriteFile(file_name, fn=None, data=None,
1910
              mode=None, uid=-1, gid=-1,
1911
              atime=None, mtime=None, close=True,
1912
              dry_run=False, backup=False,
1913
              prewrite=None, postwrite=None):
1914
  """(Over)write a file atomically.
1915

1916
  The file_name and either fn (a function taking one argument, the
1917
  file descriptor, and which should write the data to it) or data (the
1918
  contents of the file) must be passed. The other arguments are
1919
  optional and allow setting the file mode, owner and group, and the
1920
  mtime/atime of the file.
1921

1922
  If the function doesn't raise an exception, it has succeeded and the
1923
  target file has the new contents. If the function has raised an
1924
  exception, an existing target file should be unmodified and the
1925
  temporary file should be removed.
1926

1927
  @type file_name: str
1928
  @param file_name: the target filename
1929
  @type fn: callable
1930
  @param fn: content writing function, called with
1931
      file descriptor as parameter
1932
  @type data: str
1933
  @param data: contents of the file
1934
  @type mode: int
1935
  @param mode: file mode
1936
  @type uid: int
1937
  @param uid: the owner of the file
1938
  @type gid: int
1939
  @param gid: the group of the file
1940
  @type atime: int
1941
  @param atime: a custom access time to be set on the file
1942
  @type mtime: int
1943
  @param mtime: a custom modification time to be set on the file
1944
  @type close: boolean
1945
  @param close: whether to close file after writing it
1946
  @type prewrite: callable
1947
  @param prewrite: function to be called before writing content
1948
  @type postwrite: callable
1949
  @param postwrite: function to be called after writing content
1950

1951
  @rtype: None or int
1952
  @return: None if the 'close' parameter evaluates to True,
1953
      otherwise the file descriptor
1954

1955
  @raise errors.ProgrammerError: if any of the arguments are not valid
1956

1957
  """
1958
  if not os.path.isabs(file_name):
1959
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1960
                                 " absolute: '%s'" % file_name)
1961

    
1962
  if [fn, data].count(None) != 1:
1963
    raise errors.ProgrammerError("fn or data required")
1964

    
1965
  if [atime, mtime].count(None) == 1:
1966
    raise errors.ProgrammerError("Both atime and mtime must be either"
1967
                                 " set or None")
1968

    
1969
  if backup and not dry_run and os.path.isfile(file_name):
1970
    CreateBackup(file_name)
1971

    
1972
  dir_name, base_name = os.path.split(file_name)
1973
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1974
  do_remove = True
1975
  # here we need to make sure we remove the temp file, if any error
1976
  # leaves it in place
1977
  try:
1978
    if uid != -1 or gid != -1:
1979
      os.chown(new_name, uid, gid)
1980
    if mode:
1981
      os.chmod(new_name, mode)
1982
    if callable(prewrite):
1983
      prewrite(fd)
1984
    if data is not None:
1985
      os.write(fd, data)
1986
    else:
1987
      fn(fd)
1988
    if callable(postwrite):
1989
      postwrite(fd)
1990
    os.fsync(fd)
1991
    if atime is not None and mtime is not None:
1992
      os.utime(new_name, (atime, mtime))
1993
    if not dry_run:
1994
      os.rename(new_name, file_name)
1995
      do_remove = False
1996
  finally:
1997
    if close:
1998
      os.close(fd)
1999
      result = None
2000
    else:
2001
      result = fd
2002
    if do_remove:
2003
      RemoveFile(new_name)
2004

    
2005
  return result
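# Illustrative usage sketch (not part of the original module): atomically
# replacing a configuration file while keeping a backup of the old content.
# The path and contents below are made-up examples:
#
#   WriteFile("/tmp/example.conf", data="key = value\n",
#             mode=0644, backup=True)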
2006

    
2007

    
2008
def ReadOneLineFile(file_name, strict=False):
2009
  """Return the first non-empty line from a file.
2010

2011
  @type strict: boolean
2012
  @param strict: if True, abort if the file has more than one
2013
      non-empty line
2014

2015
  """
2016
  file_lines = ReadFile(file_name).splitlines()
2017
  full_lines = filter(bool, file_lines)
2018
  if not file_lines or not full_lines:
2019
    raise errors.GenericError("No data in one-liner file %s" % file_name)
2020
  elif strict and len(full_lines) > 1:
2021
    raise errors.GenericError("Too many lines in one-liner file %s" %
2022
                              file_name)
2023
  return full_lines[0]
2024

    
2025

    
2026
def FirstFree(seq, base=0):
2027
  """Returns the first non-existing integer from seq.
2028

2029
  The seq argument should be a sorted list of positive integers. The
2030
  first time the index of an element is smaller than the element
2031
  value, the index will be returned.
2032

2033
  The base argument is used to start at a different offset,
2034
  i.e. C{[3, 4, 6]} with I{base=3} will return 5.
2035

2036
  Example: C{[0, 1, 3]} will return I{2}.
2037

2038
  @type seq: sequence
2039
  @param seq: the sequence to be analyzed.
2040
  @type base: int
2041
  @param base: use this value as the base index of the sequence
2042
  @rtype: int
2043
  @return: the first non-used index in the sequence
2044

2045
  """
2046
  for idx, elem in enumerate(seq):
2047
    assert elem >= base, "Passed element is higher than base offset"
2048
    if elem > idx + base:
2049
      # idx is not used
2050
      return idx + base
2051
  return None
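# Illustrative examples (not part of the original module), matching the
# docstring above:
#
#   FirstFree([0, 1, 3])          # -> 2
#   FirstFree([3, 4, 6], base=3)  # -> 5
#   FirstFree([0, 1, 2])          # -> None (no gap in the sequence)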
2052

    
2053

    
2054
def SingleWaitForFdCondition(fdobj, event, timeout):
2055
  """Waits for a condition to occur on the socket.
2056

2057
  Immediately returns at the first interruption.
2058

2059
  @type fdobj: integer or object supporting a fileno() method
2060
  @param fdobj: entity to wait for events on
2061
  @type event: integer
2062
  @param event: ORed condition (see select module)
2063
  @type timeout: float or None
2064
  @param timeout: Timeout in seconds
2065
  @rtype: int or None
2066
  @return: None for timeout, otherwise occurred conditions
2067

2068
  """
2069
  check = (event | select.POLLPRI |
2070
           select.POLLNVAL | select.POLLHUP | select.POLLERR)
2071

    
2072
  if timeout is not None:
2073
    # Poller object expects milliseconds
2074
    timeout *= 1000
2075

    
2076
  poller = select.poll()
2077
  poller.register(fdobj, event)
2078
  try:
2079
    # TODO: If the main thread receives a signal and we have no timeout, we
2080
    # could wait forever. This should check a global "quit" flag or something
2081
    # every so often.
2082
    io_events = poller.poll(timeout)
2083
  except select.error, err:
2084
    if err[0] != errno.EINTR:
2085
      raise
2086
    io_events = []
2087
  if io_events and io_events[0][1] & check:
2088
    return io_events[0][1]
2089
  else:
2090
    return None
2091

    
2092

    
2093
class FdConditionWaiterHelper(object):
2094
  """Retry helper for WaitForFdCondition.
2095

2096
  This class contains the retry and wait functions that make sure
2097
  WaitForFdCondition can continue waiting until the timeout is actually
2098
  expired.
2099

2100
  """
2101

    
2102
  def __init__(self, timeout):
2103
    self.timeout = timeout
2104

    
2105
  def Poll(self, fdobj, event):
2106
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
2107
    if result is None:
2108
      raise RetryAgain()
2109
    else:
2110
      return result
2111

    
2112
  def UpdateTimeout(self, timeout):
2113
    self.timeout = timeout
2114

    
2115

    
2116
def WaitForFdCondition(fdobj, event, timeout):
2117
  """Waits for a condition to occur on the socket.
2118

2119
  Retries until the timeout is expired, even if interrupted.
2120

2121
  @type fdobj: integer or object supporting a fileno() method
2122
  @param fdobj: entity to wait for events on
2123
  @type event: integer
2124
  @param event: ORed condition (see select module)
2125
  @type timeout: float or None
2126
  @param timeout: Timeout in seconds
2127
  @rtype: int or None
2128
  @return: None for timeout, otherwise occurred conditions
2129

2130
  """
2131
  if timeout is not None:
2132
    retrywaiter = FdConditionWaiterHelper(timeout)
2133
    try:
2134
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
2135
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
2136
    except RetryTimeout:
2137
      result = None
2138
  else:
2139
    result = None
2140
    while result is None:
2141
      result = SingleWaitForFdCondition(fdobj, event, timeout)
2142
  return result
2143

    
2144

    
2145
def UniqueSequence(seq):
2146
  """Returns a list with unique elements.
2147

2148
  Element order is preserved.
2149

2150
  @type seq: sequence
2151
  @param seq: the sequence with the source elements
2152
  @rtype: list
2153
  @return: list of unique elements from seq
2154

2155
  """
2156
  seen = set()
2157
  return [i for i in seq if i not in seen and not seen.add(i)]
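# Illustrative example (not part of the original module): duplicates are
# dropped while the first-seen order is preserved:
#
#   UniqueSequence([2, 1, 2, 3, 1])  # -> [2, 1, 3]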
2158

    
2159

    
2160
def NormalizeAndValidateMac(mac):
2161
  """Normalizes and check if a MAC address is valid.
2162

2163
  Checks whether the supplied MAC address is formally correct, only
2164
  accepts the colon-separated format. Normalizes it to all-lowercase.
2165

2166
  @type mac: str
2167
  @param mac: the MAC to be validated
2168
  @rtype: str
2169
  @return: returns the normalized and validated MAC.
2170

2171
  @raise errors.OpPrereqError: If the MAC isn't valid
2172

2173
  """
2174
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
2175
  if not mac_check.match(mac):
2176
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
2177
                               mac, errors.ECODE_INVAL)
2178

    
2179
  return mac.lower()
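# Illustrative examples (not part of the original module):
#
#   NormalizeAndValidateMac("AA:00:99:CC:B4:75")  # -> "aa:00:99:cc:b4:75"
#   NormalizeAndValidateMac("aa-00-99-cc-b4-75")  # raises errors.OpPrereqError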
2180

    
2181

    
2182
def TestDelay(duration):
2183
  """Sleep for a fixed amount of time.
2184

2185
  @type duration: float
2186
  @param duration: the sleep duration
2187
  @rtype: tuple
2188
  @return: (False, error message) for negative values, (True, None) otherwise
2189

2190
  """
2191
  if duration < 0:
2192
    return False, "Invalid sleep duration"
2193
  time.sleep(duration)
2194
  return True, None
2195

    
2196

    
2197
def _CloseFDNoErr(fd, retries=5):
2198
  """Close a file descriptor ignoring errors.
2199

2200
  @type fd: int
2201
  @param fd: the file descriptor
2202
  @type retries: int
2203
  @param retries: how many retries to make, in case we get any
2204
      other error than EBADF
2205

2206
  """
2207
  try:
2208
    os.close(fd)
2209
  except OSError, err:
2210
    if err.errno != errno.EBADF:
2211
      if retries > 0:
2212
        _CloseFDNoErr(fd, retries - 1)
2213
    # else either it's closed already or we're out of retries, so we
2214
    # ignore this and go on
2215

    
2216

    
2217
def CloseFDs(noclose_fds=None):
2218
  """Close file descriptors.
2219

2220
  This closes all file descriptors above 2 (i.e. except
2221
  stdin/out/err).
2222

2223
  @type noclose_fds: list or None
2224
  @param noclose_fds: if given, it denotes a list of file descriptor
2225
      that should not be closed
2226

2227
  """
2228
  # Default maximum for the number of available file descriptors.
2229
  if 'SC_OPEN_MAX' in os.sysconf_names:
2230
    try:
2231
      MAXFD = os.sysconf('SC_OPEN_MAX')
2232
      if MAXFD < 0:
2233
        MAXFD = 1024
2234
    except OSError:
2235
      MAXFD = 1024
2236
  else:
2237
    MAXFD = 1024
2238
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
2239
  if (maxfd == resource.RLIM_INFINITY):
2240
    maxfd = MAXFD
2241

    
2242
  # Iterate through and close all file descriptors (except the standard ones)
2243
  for fd in range(3, maxfd):
2244
    if noclose_fds and fd in noclose_fds:
2245
      continue
2246
    _CloseFDNoErr(fd)
2247

    
2248

    
2249
def Mlockall():
2250
  """Lock current process' virtual address space into RAM.
2251

2252
  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
2253
  see mlock(2) for more details. This function requires the ctypes module.
2254

2255
  """
2256
  if ctypes is None:
2257
    logging.warning("Cannot set memory lock, ctypes module not found")
2258
    return
2259

    
2260
  libc = ctypes.cdll.LoadLibrary("libc.so.6")
2261
  if libc is None:
2262
    logging.error("Cannot set memory lock, ctypes cannot load libc")
2263
    return
2264

    
2265
  # Some older version of the ctypes module don't have built-in functionality
2266
  # to access the errno global variable, where function error codes are stored.
2267
  # By declaring this variable as a pointer to an integer we can then access
2268
  # its value correctly, should the mlockall call fail, in order to see what
2269
  # the actual error code was.
2270
  # pylint: disable-msg=W0212
2271
  libc.__errno_location.restype = ctypes.POINTER(ctypes.c_int)
2272

    
2273
  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
2274
    # pylint: disable-msg=W0212
2275
    logging.error("Cannot set memory lock: %s",
2276
                  os.strerror(libc.__errno_location().contents.value))
2277
    return
2278

    
2279
  logging.debug("Memory lock set")
2280

    
2281

    
2282
def Daemonize(logfile, run_uid, run_gid):
2283
  """Daemonize the current process.
2284

2285
  This detaches the current process from the controlling terminal and
2286
  runs it in the background as a daemon.
2287

2288
  @type logfile: str
2289
  @param logfile: the logfile to which we should redirect stdout/stderr
2290
  @type run_uid: int
2291
  @param run_uid: Run the child under this uid
2292
  @type run_gid: int
2293
  @param run_gid: Run the child under this gid
2294
  @rtype: int
2295
  @return: the value zero
2296

2297
  """
2298
  # pylint: disable-msg=W0212
2299
  # yes, we really want os._exit
2300
  UMASK = 077
2301
  WORKDIR = "/"
2302

    
2303
  # this might fail
2304
  pid = os.fork()
2305
  if (pid == 0):  # The first child.
2306
    os.setsid()
2307
    # FIXME: When removing again and moving to start-stop-daemon privilege drop
2308
    #        make sure to check for config permission and bail out when invoked
2309
    #        with wrong user.
2310
    os.setgid(run_gid)
2311
    os.setuid(run_uid)
2312
    # this might fail
2313
    pid = os.fork() # Fork a second child.
2314
    if (pid == 0):  # The second child.
2315
      os.chdir(WORKDIR)
2316
      os.umask(UMASK)
2317
    else:
2318
      # exit() or _exit()?  See below.
2319
      os._exit(0) # Exit parent (the first child) of the second child.
2320
  else:
2321
    os._exit(0) # Exit parent of the first child.
2322

    
2323
  for fd in range(3):
2324
    _CloseFDNoErr(fd)
2325
  i = os.open("/dev/null", os.O_RDONLY) # stdin
2326
  assert i == 0, "Can't close/reopen stdin"
2327
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
2328
  assert i == 1, "Can't close/reopen stdout"
2329
  # Duplicate standard output to standard error.
2330
  os.dup2(1, 2)
2331
  return 0
2332

    
2333

    
2334
def DaemonPidFileName(name):
2335
  """Compute a ganeti pid file absolute path
2336

2337
  @type name: str
2338
  @param name: the daemon name
2339
  @rtype: str
2340
  @return: the full path to the pidfile corresponding to the given
2341
      daemon name
2342

2343
  """
2344
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
2345

    
2346

    
2347
def EnsureDaemon(name):
2348
  """Check for and start daemon if not alive.
2349

2350
  """
2351
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
2352
  if result.failed:
2353
    logging.error("Can't start daemon '%s', failure %s, output: %s",
2354
                  name, result.fail_reason, result.output)
2355
    return False
2356

    
2357
  return True
2358

    
2359

    
2360
def StopDaemon(name):
2361
  """Stop daemon
2362

2363
  """
2364
  result = RunCmd([constants.DAEMON_UTIL, "stop", name])
2365
  if result.failed:
2366
    logging.error("Can't stop daemon '%s', failure %s, output: %s",
2367
                  name, result.fail_reason, result.output)
2368
    return False
2369

    
2370
  return True
2371

    
2372

    
2373
def WritePidFile(name):
2374
  """Write the current process pidfile.
2375

2376
  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}
2377

2378
  @type name: str
2379
  @param name: the daemon name to use
2380
  @raise errors.GenericError: if the pid file already exists and
2381
      points to a live process
2382

2383
  """
2384
  pid = os.getpid()
2385
  pidfilename = DaemonPidFileName(name)
2386
  if IsProcessAlive(ReadPidFile(pidfilename)):
2387
    raise errors.GenericError("%s contains a live process" % pidfilename)
2388

    
2389
  WriteFile(pidfilename, data="%d\n" % pid)
2390

    
2391

    
2392
def RemovePidFile(name):
2393
  """Remove the current process pidfile.
2394

2395
  Any errors are ignored.
2396

2397
  @type name: str
2398
  @param name: the daemon name used to derive the pidfile name
2399

2400
  """
2401
  pidfilename = DaemonPidFileName(name)
2402
  # TODO: we could check here that the file contains our pid
2403
  try:
2404
    RemoveFile(pidfilename)
2405
  except: # pylint: disable-msg=W0702
2406
    pass
2407

    
2408

    
2409
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
2410
                waitpid=False):
2411
  """Kill a process given by its pid.
2412

2413
  @type pid: int
2414
  @param pid: The PID to terminate.
2415
  @type signal_: int
2416
  @param signal_: The signal to send, by default SIGTERM
2417
  @type timeout: int
2418
  @param timeout: The timeout after which, if the process is still alive,
2419
                  a SIGKILL will be sent. If not positive, no such checking
2420
                  will be done
2421
  @type waitpid: boolean
2422
  @param waitpid: If true, we should waitpid on this process after
2423
      sending signals, since it's our own child and otherwise it
2424
      would remain as zombie
2425

2426
  """
2427
  def _helper(pid, signal_, wait):
2428
    """Simple helper to encapsulate the kill/waitpid sequence"""
2429
    if IgnoreProcessNotFound(os.kill, pid, signal_) and wait:
2430
      try:
2431
        os.waitpid(pid, os.WNOHANG)
2432
      except OSError:
2433
        pass
2434

    
2435
  if pid <= 0:
2436
    # kill with pid=0 == suicide
2437
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
2438

    
2439
  if not IsProcessAlive(pid):
2440
    return
2441

    
2442
  _helper(pid, signal_, waitpid)
2443

    
2444
  if timeout <= 0:
2445
    return
2446

    
2447
  def _CheckProcess():
2448
    if not IsProcessAlive(pid):
2449
      return
2450

    
2451
    try:
2452
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
2453
    except OSError:
2454
      raise RetryAgain()
2455

    
2456
    if result_pid > 0:
2457
      return
2458

    
2459
    raise RetryAgain()
2460

    
2461
  try:
2462
    # Wait up to $timeout seconds
2463
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
2464
  except RetryTimeout:
2465
    pass
2466

    
2467
  if IsProcessAlive(pid):
2468
    # Kill process if it's still alive
2469
    _helper(pid, signal.SIGKILL, waitpid)
2470

    
2471

    
2472
def FindFile(name, search_path, test=os.path.exists):
2473
  """Look for a filesystem object in a given path.
2474

2475
  This is an abstract method to search for filesystem object (files,
2476
  dirs) under a given search path.
2477

2478
  @type name: str
2479
  @param name: the name to look for
2480
  @type search_path: list of str
2481
  @param search_path: directories in which to search
2482
  @type test: callable
2483
  @param test: a function taking one argument that should return True
2484
      if the a given object is valid; the default value is
2485
      os.path.exists, causing only existing files to be returned
2486
  @rtype: str or None
2487
  @return: full path to the object if found, None otherwise
2488

2489
  """
2490
  # validate the filename mask
2491
  if constants.EXT_PLUGIN_MASK.match(name) is None:
2492
    logging.critical("Invalid value passed for external script name: '%s'",
2493
                     name)
2494
    return None
2495

    
2496
  for dir_name in search_path:
2497
    # FIXME: investigate switch to PathJoin
2498
    item_name = os.path.sep.join([dir_name, name])
2499
    # check the user test and that we're indeed resolving to the given
2500
    # basename
2501
    if test(item_name) and os.path.basename(item_name) == name:
2502
      return item_name
2503
  return None
2504

    
2505

    
2506
def CheckVolumeGroupSize(vglist, vgname, minsize):
2507
  """Checks if the volume group list is valid.
2508

2509
  The function will check if a given volume group is in the list of
2510
  volume groups and has a minimum size.
2511

2512
  @type vglist: dict
2513
  @param vglist: dictionary of volume group names and their size
2514
  @type vgname: str
2515
  @param vgname: the volume group we should check
2516
  @type minsize: int
2517
  @param minsize: the minimum size we accept
2518
  @rtype: None or str
2519
  @return: None for success, otherwise the error message
2520

2521
  """
2522
  vgsize = vglist.get(vgname, None)
2523
  if vgsize is None:
2524
    return "volume group '%s' missing" % vgname
2525
  elif vgsize < minsize:
2526
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
2527
            (vgname, minsize, vgsize))
2528
  return None
2529

    
2530

    
2531
def SplitTime(value):
2532
  """Splits time as floating point number into a tuple.
2533

2534
  @param value: Time in seconds
2535
  @type value: int or float
2536
  @return: Tuple containing (seconds, microseconds)
2537

2538
  """
2539
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
2540

    
2541
  assert 0 <= seconds, \
2542
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2543
  assert 0 <= microseconds <= 999999, \
2544
    "Microseconds must be 0-999999, but are %s" % microseconds
2545

    
2546
  return (int(seconds), int(microseconds))
2547

    
2548

    
2549
def MergeTime(timetuple):
2550
  """Merges a tuple into time as a floating point number.
2551

2552
  @param timetuple: Time as tuple, (seconds, microseconds)
2553
  @type timetuple: tuple
2554
  @return: Time as a floating point number expressed in seconds
2555

2556
  """
2557
  (seconds, microseconds) = timetuple
2558

    
2559
  assert 0 <= seconds, \
2560
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2561
  assert 0 <= microseconds <= 999999, \
2562
    "Microseconds must be 0-999999, but are %s" % microseconds
2563

    
2564
  return float(seconds) + (float(microseconds) * 0.000001)
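# Illustrative examples (not part of the original module): SplitTime and
# MergeTime are inverses of each other, up to microsecond resolution:
#
#   SplitTime(1.5)          # -> (1, 500000)
#   MergeTime((1, 500000))  # -> 1.5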
2565

    
2566

    
2567
def GetDaemonPort(daemon_name):
2568
  """Get the daemon port for this cluster.
2569

2570
  Note that this routine does not read a ganeti-specific file, but
2571
  instead uses C{socket.getservbyname} to allow pre-customization of
2572
  this parameter outside of Ganeti.
2573

2574
  @type daemon_name: string
2575
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
2576
  @rtype: int
2577

2578
  """
2579
  if daemon_name not in constants.DAEMONS_PORTS:
2580
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)
2581

    
2582
  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
2583
  try:
2584
    port = socket.getservbyname(daemon_name, proto)
2585
  except socket.error:
2586
    port = default_port
2587

    
2588
  return port
2589

    
2590

    
2591
class LogFileHandler(logging.FileHandler):
2592
  """Log handler that doesn't fallback to stderr.
2593

2594
  When an error occurs while writing to the logfile, logging.FileHandler tries
2595
  to log to stderr. This doesn't work in Ganeti, since stderr is redirected to
2596
  the logfile. This class avoids such failures by reporting errors to /dev/console.
2597

2598
  """
2599
  def __init__(self, filename, mode="a", encoding=None):
2600
    """Open the specified file and use it as the stream for logging.
2601

2602
    Also open /dev/console to report errors while logging.
2603

2604
    """
2605
    logging.FileHandler.__init__(self, filename, mode, encoding)
2606
    self.console = open(constants.DEV_CONSOLE, "a")
2607

    
2608
  def handleError(self, record): # pylint: disable-msg=C0103
2609
    """Handle errors which occur during an emit() call.
2610

2611
    Try to handle errors with FileHandler method, if it fails write to
2612
    /dev/console.
2613

2614
    """
2615
    try:
2616
      logging.FileHandler.handleError(self, record)
2617
    except Exception: # pylint: disable-msg=W0703
2618
      try:
2619
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
2620
      except Exception: # pylint: disable-msg=W0703
2621
        # Log handler tried everything it could, now just give up
2622
        pass
2623

    
2624

    
2625
def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
2626
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
2627
                 console_logging=False):
2628
  """Configures the logging module.
2629

2630
  @type logfile: str
2631
  @param logfile: the filename to which we should log
2632
  @type debug: integer
2633
  @param debug: if greater than zero, enable debug messages, otherwise
2634
      only those at C{INFO} and above level
2635
  @type stderr_logging: boolean
2636
  @param stderr_logging: whether we should also log to the standard error
2637
  @type program: str
2638
  @param program: the name under which we should log messages
2639
  @type multithreaded: boolean
2640
  @param multithreaded: if True, will add the thread name to the log file
2641
  @type syslog: string
2642
  @param syslog: one of 'no', 'yes', 'only':
2643
      - if no, syslog is not used
2644
      - if yes, syslog is used (in addition to file-logging)
2645
      - if only, only syslog is used
2646
  @type console_logging: boolean
2647
  @param console_logging: if True, will use a FileHandler which falls back to
2648
      the system console if logging fails
2649
  @raise EnvironmentError: if we can't open the log file and
2650
      syslog/stderr logging is disabled
2651

2652
  """
2653
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
2654
  sft = program + "[%(process)d]:"
2655
  if multithreaded:
2656
    fmt += "/%(threadName)s"
2657
    sft += " (%(threadName)s)"
2658
  if debug:
2659
    fmt += " %(module)s:%(lineno)s"
2660
    # no debug info for syslog loggers
2661
  fmt += " %(levelname)s %(message)s"
2662
  # yes, we do want the textual level, as remote syslog will probably
2663
  # lose the error level, and it's easier to grep for it
2664
  sft += " %(levelname)s %(message)s"
2665
  formatter = logging.Formatter(fmt)
2666
  sys_fmt = logging.Formatter(sft)
2667

    
2668
  root_logger = logging.getLogger("")
2669
  root_logger.setLevel(logging.NOTSET)
2670

    
2671
  # Remove all previously setup handlers
2672
  for handler in root_logger.handlers:
2673
    handler.close()
2674
    root_logger.removeHandler(handler)
2675

    
2676
  if stderr_logging:
2677
    stderr_handler = logging.StreamHandler()
2678
    stderr_handler.setFormatter(formatter)
2679
    if debug:
2680
      stderr_handler.setLevel(logging.NOTSET)
2681
    else:
2682
      stderr_handler.setLevel(logging.CRITICAL)
2683
    root_logger.addHandler(stderr_handler)
2684

    
2685
  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
2686
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
2687
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
2688
                                                    facility)
2689
    syslog_handler.setFormatter(sys_fmt)
2690
    # Never enable debug over syslog
2691
    syslog_handler.setLevel(logging.INFO)
2692
    root_logger.addHandler(syslog_handler)
2693

    
2694
  if syslog != constants.SYSLOG_ONLY:
2695
    # this can fail, if the logging directories are not setup or we have
2696
    # a permission problem; in this case, it's best to log but ignore
2697
    # the error if stderr_logging is True, and if false we re-raise the
2698
    # exception since otherwise we could run but without any logs at all
2699
    try:
2700
      if console_logging:
2701
        logfile_handler = LogFileHandler(logfile)
2702
      else:
2703
        logfile_handler = logging.FileHandler(logfile)
2704
      logfile_handler.setFormatter(formatter)
2705
      if debug:
2706
        logfile_handler.setLevel(logging.DEBUG)
2707
      else:
2708
        logfile_handler.setLevel(logging.INFO)
2709
      root_logger.addHandler(logfile_handler)
2710
    except EnvironmentError:
2711
      if stderr_logging or syslog == constants.SYSLOG_YES:
2712
        logging.exception("Failed to enable logging to file '%s'", logfile)
2713
      else:
2714
        # we need to re-raise the exception
2715
        raise
2716

    
2717

    
2718
def IsNormAbsPath(path):
2719
  """Check whether a path is absolute and also normalized
2720

2721
  This avoids things like /dir/../../other/path from being considered valid.
2722

2723
  """
2724
  return os.path.normpath(path) == path and os.path.isabs(path)
2725

    
2726

    
2727
def PathJoin(*args):
2728
  """Safe-join a list of path components.
2729

2730
  Requirements:
2731
      - the first argument must be an absolute path
2732
      - no component in the path must have backtracking (e.g. /../),
2733
        since we check for normalization at the end
2734

2735
  @param args: the path components to be joined
2736
  @raise ValueError: for invalid paths
2737

2738
  """
2739
  # ensure we're having at least one path passed in
2740
  assert args
2741
  # ensure the first component is an absolute and normalized path name
2742
  root = args[0]
2743
  if not IsNormAbsPath(root):
2744
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
2745
  result = os.path.join(*args)
2746
  # ensure that the whole path is normalized
2747
  if not IsNormAbsPath(result):
2748
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
2749
  # check that we're still under the original prefix
2750
  prefix = os.path.commonprefix([root, result])
2751
  if prefix != root:
2752
    raise ValueError("Error: path joining resulted in different prefix"
2753
                     " (%s != %s)" % (prefix, root))
2754
  return result
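# Illustrative examples (not part of the original module):
#
#   PathJoin("/var/lib", "ganeti", "config.data")  # -> "/var/lib/ganeti/config.data"
#   PathJoin("/var/lib", "../etc")                 # raises ValueError (backtracking)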
2755

    
2756

    
2757
def TailFile(fname, lines=20):
2758
  """Return the last lines from a file.
2759

2760
  @note: this function will only read and parse the last 4KB of
2761
      the file; if the lines are very long, it could be that fewer
2762
      than the requested number of lines are returned
2763

2764
  @param fname: the file name
2765
  @type lines: int
2766
  @param lines: the (maximum) number of lines to return
2767

2768
  """
2769
  fd = open(fname, "r")
2770
  try:
2771
    fd.seek(0, 2)
2772
    pos = fd.tell()
2773
    pos = max(0, pos-4096)
2774
    fd.seek(pos, 0)
2775
    raw_data = fd.read()
2776
  finally:
2777
    fd.close()
2778

    
2779
  rows = raw_data.splitlines()
2780
  return rows[-lines:]
2781

    
2782

    
2783
def FormatTimestampWithTZ(secs):
2784
  """Formats a Unix timestamp with the local timezone.
2785

2786
  """
2787
  return time.strftime("%F %T %Z", time.gmtime(secs))
2788

    
2789

    
2790
def _ParseAsn1Generalizedtime(value):
2791
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
2792

2793
  @type value: string
2794
  @param value: ASN1 GENERALIZEDTIME timestamp
2795

2796
  """
2797
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
2798
  if m:
2799
    # We have an offset
2800
    asn1time = m.group(1)
2801
    hours = int(m.group(2))
2802
    minutes = int(m.group(3))
2803
    utcoffset = (60 * hours) + minutes
2804
  else:
2805
    if not value.endswith("Z"):
2806
      raise ValueError("Missing timezone")
2807
    asn1time = value[:-1]
2808
    utcoffset = 0
2809

    
2810
  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
2811

    
2812
  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)
2813

    
2814
  return calendar.timegm(tt.utctimetuple())
2815

    
2816

    
2817
def GetX509CertValidity(cert):
2818
  """Returns the validity period of the certificate.
2819

2820
  @type cert: OpenSSL.crypto.X509
2821
  @param cert: X509 certificate object
2822

2823
  """
2824
  # The get_notBefore and get_notAfter functions are only supported in
2825
  # pyOpenSSL 0.7 and above.
2826
  try:
2827
    get_notbefore_fn = cert.get_notBefore
2828
  except AttributeError:
2829
    not_before = None
2830
  else:
2831
    not_before_asn1 = get_notbefore_fn()
2832

    
2833
    if not_before_asn1 is None:
2834
      not_before = None
2835
    else:
2836
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)
2837

    
2838
  try:
2839
    get_notafter_fn = cert.get_notAfter
2840
  except AttributeError:
2841
    not_after = None
2842
  else:
2843
    not_after_asn1 = get_notafter_fn()
2844

    
2845
    if not_after_asn1 is None:
2846
      not_after = None
2847
    else:
2848
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)
2849

    
2850
  return (not_before, not_after)
2851

    
2852

    
2853
def _VerifyCertificateInner(expired, not_before, not_after, now,
2854
                            warn_days, error_days):
2855
  """Verifies certificate validity.
2856

2857
  @type expired: bool
2858
  @param expired: Whether pyOpenSSL considers the certificate as expired
2859
  @type not_before: number or None
2860
  @param not_before: Unix timestamp before which certificate is not valid
2861
  @type not_after: number or None
2862
  @param not_after: Unix timestamp after which certificate is invalid
2863
  @type now: number
2864
  @param now: Current time as Unix timestamp
2865
  @type warn_days: number or None
2866
  @param warn_days: How many days before expiration a warning should be reported
2867
  @type error_days: number or None
2868
  @param error_days: How many days before expiration an error should be reported
2869

2870
  """
2871
  if expired:
2872
    msg = "Certificate is expired"
2873

    
2874
    if not_before is not None and not_after is not None:
2875
      msg += (" (valid from %s to %s)" %
2876
              (FormatTimestampWithTZ(not_before),
2877
               FormatTimestampWithTZ(not_after)))
2878
    elif not_before is not None:
2879
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
2880
    elif not_after is not None:
2881
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)
2882

    
2883
    return (CERT_ERROR, msg)
2884

    
2885
  elif not_before is not None and not_before > now:
2886
    return (CERT_WARNING,
2887
            "Certificate not yet valid (valid from %s)" %
2888
            FormatTimestampWithTZ(not_before))
2889

    
2890
  elif not_after is not None:
2891
    remaining_days = int((not_after - now) / (24 * 3600))
2892

    
2893
    msg = "Certificate expires in about %d days" % remaining_days
2894

    
2895
    if error_days is not None and remaining_days <= error_days:
2896
      return (CERT_ERROR, msg)
2897

    
2898
    if warn_days is not None and remaining_days <= warn_days:
2899
      return (CERT_WARNING, msg)
2900

    
2901
  return (None, None)
2902

    
2903

    
2904
def VerifyX509Certificate(cert, warn_days, error_days):
2905
  """Verifies a certificate for LUVerifyCluster.
2906

2907
  @type cert: OpenSSL.crypto.X509
2908
  @param cert: X509 certificate object
2909
  @type warn_days: number or None
2910
  @param warn_days: How many days before expiration a warning should be reported
2911
  @type error_days: number or None
2912
  @param error_days: How many days before expiration an error should be reported
2913

2914
  """
2915
  # Depending on the pyOpenSSL version, this can just return (None, None)
2916
  (not_before, not_after) = GetX509CertValidity(cert)
2917

    
2918
  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
2919
                                 time.time(), warn_days, error_days)
2920

    
2921

    
2922
def SignX509Certificate(cert, key, salt):
2923
  """Sign a X509 certificate.
2924

2925
  An RFC822-like signature header is added in front of the certificate.
2926

2927
  @type cert: OpenSSL.crypto.X509
2928
  @param cert: X509 certificate object
2929
  @type key: string
2930
  @param key: Key for HMAC
2931
  @type salt: string
2932
  @param salt: Salt for HMAC
2933
  @rtype: string
2934
  @return: Serialized and signed certificate in PEM format
2935

2936
  """
2937
  if not VALID_X509_SIGNATURE_SALT.match(salt):
2938
    raise errors.GenericError("Invalid salt: %r" % salt)
2939

    
2940
  # Dumping as PEM here ensures the certificate is in a sane format
2941
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2942

    
2943
  return ("%s: %s/%s\n\n%s" %
2944
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
2945
           Sha1Hmac(key, cert_pem, salt=salt),
2946
           cert_pem))
2947

    
2948

    
2949
def _ExtractX509CertificateSignature(cert_pem):
2950
  """Helper function to extract signature from X509 certificate.
2951

2952
  """
2953
  # Extract signature from original PEM data
2954
  for line in cert_pem.splitlines():
2955
    if line.startswith("---"):
2956
      break
2957

    
2958
    m = X509_SIGNATURE.match(line.strip())
2959
    if m:
2960
      return (m.group("salt"), m.group("sign"))
2961

    
2962
  raise errors.GenericError("X509 certificate signature is missing")
2963

    
2964

    
2965
def LoadSignedX509Certificate(cert_pem, key):
2966
  """Verifies a signed X509 certificate.
2967

2968
  @type cert_pem: string
2969
  @param cert_pem: Certificate in PEM format and with signature header
2970
  @type key: string
2971
  @param key: Key for HMAC
2972
  @rtype: tuple; (OpenSSL.crypto.X509, string)
2973
  @return: X509 certificate object and salt
2974

2975
  """
2976
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)
2977

    
2978
  # Load certificate
2979
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
2980

    
2981
  # Dump again to ensure it's in a sane format
2982
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2983

    
2984
  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
2985
    raise errors.GenericError("X509 certificate signature is invalid")
2986

    
2987
  return (cert, salt)
2988

    
2989

    
2990
def Sha1Hmac(key, text, salt=None):
2991
  """Calculates the HMAC-SHA1 digest of a text.
2992

2993
  HMAC is defined in RFC2104.
2994

2995
  @type key: string
2996
  @param key: Secret key
2997
  @type text: string
2998

2999
  """
3000
  if salt:
3001
    salted_text = salt + text
3002
  else:
3003
    salted_text = text
3004

    
3005
  return hmac.new(key, salted_text, compat.sha1).hexdigest()
3006

    
3007

    
3008
def VerifySha1Hmac(key, text, digest, salt=None):
3009
  """Verifies the HMAC-SHA1 digest of a text.
3010

3011
  HMAC is defined in RFC2104.
3012

3013
  @type key: string
3014
  @param key: Secret key
3015
  @type text: string
3016
  @type digest: string
3017
  @param digest: Expected digest
3018
  @rtype: bool
3019
  @return: Whether HMAC-SHA1 digest matches
3020

3021
  """
3022
  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()
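# Illustrative usage sketch (not part of the original module): computing a
# digest and verifying it later. The key, text and salt values are made-up:
#
#   digest = Sha1Hmac("secret-key", "some text", salt="abc123")
#   VerifySha1Hmac("secret-key", "some text", digest, salt="abc123")  # -> True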
3023

    
3024

    
3025
def SafeEncode(text):
3026
  """Return a 'safe' version of a source string.
3027

3028
  This function mangles the input string and returns a version that
3029
  should be safe to display/encode as ASCII. To this end, we first
3030
  convert it to ASCII using the 'backslashreplace' encoding which
3031
  should get rid of any non-ASCII chars, and then we process it
3032
  through a loop copied from the string repr sources in Python; we
3033
  don't use string_escape anymore since that escapes single quotes and
3034
  backslashes too, and that is too much; and that escaping is not
3035
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
3036

3037
  @type text: str or unicode
3038
  @param text: input data
3039
  @rtype: str
3040
  @return: a safe version of text
3041

3042
  """
3043
  if isinstance(text, unicode):
3044
    # only if unicode; if str already, we handle it below
3045
    text = text.encode('ascii', 'backslashreplace')
3046
  resu = ""
3047
  for char in text:
3048
    c = ord(char)
3049
    if char  == '\t':
3050
      resu += r'\t'
3051
    elif char == '\n':
3052
      resu += r'\n'
3053
    elif char == '\r':
3054
      resu += r'\r'
3055
    elif c < 32 or c >= 127: # non-printable
3056
      resu += "\\x%02x" % (c & 0xff)
3057
    else:
3058
      resu += char
3059
  return resu
3060

    
3061

    
3062
def UnescapeAndSplit(text, sep=","):
3063
  """Split and unescape a string based on a given separator.
3064

3065
  This function splits a string based on a separator where the
3066
  separator itself can be escaped in order to be part of one of the
3067
  elements. The escaping rules are (assuming comma being the
3068
  separator):
3069
    - a plain , separates the elements
3070
    - a sequence \\\\, (double backslash plus comma) is handled as a
3071
      backslash plus a separator comma
3072
    - a sequence \, (backslash plus comma) is handled as a
3073
      non-separator comma
3074

3075
  @type text: string
3076
  @param text: the string to split
3077
  @type sep: string
3078
  @param sep: the separator
3079
  @rtype: list
3080
  @return: a list of strings
3081

3082
  """
3083
  # we split the list by sep (with no escaping at this stage)
3084
  slist = text.split(sep)
3085
  # next, we revisit the elements and if any of them ended with an odd
3086
  # number of backslashes, then we join it with the next
3087
  rlist = []
3088
  while slist:
3089
    e1 = slist.pop(0)
3090
    if e1.endswith("\\"):
3091
      num_b = len(e1) - len(e1.rstrip("\\"))
3092
      if num_b % 2 == 1:
3093
        e2 = slist.pop(0)
3094
        # here the backslashes remain (all), and will be reduced in
3095
        # the next step
3096
        rlist.append(e1 + sep + e2)
3097
        continue
3098
    rlist.append(e1)
3099
  # finally, replace backslash-something with something
3100
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
3101
  return rlist
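# Illustrative examples (not part of the original module); note that the
# backslashes below are written as Python string literals:
#
#   UnescapeAndSplit("a,b,c")    # -> ["a", "b", "c"]
#   UnescapeAndSplit("a\\,b,c")  # -> ["a,b", "c"]  (escaped separator)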
3102

    
3103

    
3104
def CommaJoin(names):
3105
  """Nicely join a set of identifiers.
3106

3107
  @param names: set, list or tuple
3108
  @return: a string with the formatted results
3109

3110
  """
3111
  return ", ".join([str(val) for val in names])
3112

    
3113

    
3114
def BytesToMebibyte(value):
3115
  """Converts bytes to mebibytes.
3116

3117
  @type value: int
3118
  @param value: Value in bytes
3119
  @rtype: int
3120
  @return: Value in mebibytes
3121

3122
  """
3123
  return int(round(value / (1024.0 * 1024.0), 0))
3124

    
3125

    
3126
def CalculateDirectorySize(path):
3127
  """Calculates the size of a directory recursively.
3128

3129
  @type path: string
3130
  @param path: Path to directory
3131
  @rtype: int
3132
  @return: Size in mebibytes
3133

3134
  """
3135
  size = 0
3136

    
3137
  for (curpath, _, files) in os.walk(path):
3138
    for filename in files:
3139
      st = os.lstat(PathJoin(curpath, filename))
3140
      size += st.st_size
3141

    
3142
  return BytesToMebibyte(size)
3143

    
3144

    
3145
def GetFilesystemStats(path):
3146
  """Returns the total and free space on a filesystem.
3147

3148
  @type path: string
3149
  @param path: Path on filesystem to be examined
3150
  @rtype: tuple
3151
  @return: tuple of (Total space, Free space) in mebibytes
3152

3153
  """
3154
  st = os.statvfs(path)
3155

    
3156
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
3157
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
3158
  return (tsize, fsize)
3159

    
3160

    
3161
def RunInSeparateProcess(fn, *args):
3162
  """Runs a function in a separate process.
3163

3164
  Note: Only boolean return values are supported.
3165

3166
  @type fn: callable
3167
  @param fn: Function to be called
3168
  @rtype: bool
3169
  @return: Function's result
3170

3171
  """
3172
  pid = os.fork()
3173
  if pid == 0:
3174
    # Child process
3175
    try:
3176
      # In case the function uses temporary files
3177
      ResetTempfileModule()
3178

    
3179
      # Call function
3180
      result = int(bool(fn(*args)))
3181
      assert result in (0, 1)
3182
    except: # pylint: disable-msg=W0702
3183
      logging.exception("Error while calling function in separate process")
3184
      # 0 and 1 are reserved for the return value
3185
      result = 33
3186

    
3187
    os._exit(result) # pylint: disable-msg=W0212
3188

    
3189
  # Parent process
3190

    
3191
  # Avoid zombies and check exit code
3192
  (_, status) = os.waitpid(pid, 0)
3193

    
3194
  if os.WIFSIGNALED(status):
3195
    exitcode = None
3196
    signum = os.WTERMSIG(status)
3197
  else:
3198
    exitcode = os.WEXITSTATUS(status)
3199
    signum = None
3200

    
3201
  if not (exitcode in (0, 1) and signum is None):
3202
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
3203
                              (exitcode, signum))
3204

    
3205
  return bool(exitcode)
3206

    
3207

    
3208
def IgnoreProcessNotFound(fn, *args, **kwargs):
3209
  """Ignores ESRCH when calling a process-related function.
3210

3211
  ESRCH is raised when a process is not found.
3212

3213
  @rtype: bool
3214
  @return: Whether process was found
3215

3216
  """
3217
  try:
3218
    fn(*args, **kwargs)
3219
  except EnvironmentError, err:
3220
    # Ignore ESRCH
3221
    if err.errno == errno.ESRCH:
3222
      return False
3223
    raise
3224

    
3225
  return True
3226

    
3227

    
3228
def IgnoreSignals(fn, *args, **kwargs):
3229
  """Tries to call a function ignoring failures due to EINTR.
3230

3231
  """
3232
  try:
3233
    return fn(*args, **kwargs)
3234
  except EnvironmentError, err:
3235
    if err.errno == errno.EINTR:
3236
      return None
3237
    else:
3238
      raise
3239
  except (select.error, socket.error), err:
3240
    # In python 2.6 and above select.error is an IOError, so it's handled
3241
    # above, in 2.5 and below it's not, and it's handled here.
3242
    if err.args and err.args[0] == errno.EINTR:
3243
      return None
3244
    else:
3245
      raise
3246

    
3247

    
3248
def LockFile(fd):
3249
  """Locks a file using POSIX locks.
3250

3251
  @type fd: int
3252
  @param fd: the file descriptor we need to lock
3253

3254
  """
3255
  try:
3256
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
3257
  except IOError, err:
3258
    if err.errno == errno.EAGAIN:
3259
      raise errors.LockError("File already locked")
3260
    raise
3261

    
3262

    
3263
def FormatTime(val):
3264
  """Formats a time value.
3265

3266
  @type val: float or None
3267
  @param val: the timestamp as returned by time.time()
3268
  @return: a string value or N/A if we don't have a valid timestamp
3269

3270
  """
3271
  if val is None or not isinstance(val, (int, float)):
3272
    return "N/A"
3273
  # these two format codes work on Linux, but they are not guaranteed on all
3274
  # platforms
3275
  return time.strftime("%F %T", time.localtime(val))
3276

    
3277

    
3278
def FormatSeconds(secs):
3279
  """Formats seconds for easier reading.
3280

3281
  @type secs: number
3282
  @param secs: Number of seconds
3283
  @rtype: string
3284
  @return: Formatted seconds (e.g. "2d 9h 19m 49s")
3285

3286
  """
3287
  parts = []
3288

    
3289
  secs = round(secs, 0)
3290

    
3291
  if secs > 0:
3292
    # Negative values would be a bit tricky
3293
    for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]:
3294
      (complete, secs) = divmod(secs, one)
3295
      if complete or parts:
3296
        parts.append("%d%s" % (complete, unit))
3297

    
3298
  parts.append("%ds" % secs)
3299

    
3300
  return " ".join(parts)
3301

    
3302

    
3303
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
3304
  """Reads the watcher pause file.
3305

3306
  @type filename: string
3307
  @param filename: Path to watcher pause file
3308
  @type now: None, float or int
3309
  @param now: Current time as Unix timestamp
3310
  @type remove_after: int
3311
  @param remove_after: Remove watcher pause file after specified amount of
3312
    seconds past the pause end time
3313

3314
  """
3315
  if now is None:
3316
    now = time.time()
3317

    
3318
  try:
3319
    value = ReadFile(filename)
3320
  except IOError, err:
3321
    if err.errno != errno.ENOENT:
3322
      raise
3323
    value = None
3324

    
3325
  if value is not None:
3326
    try:
3327
      value = int(value)
3328
    except ValueError:
3329
      logging.warning(("Watcher pause file (%s) contains invalid value,"
3330
                       " removing it"), filename)
3331
      RemoveFile(filename)
3332
      value = None
3333

    
3334
    if value is not None:
3335
      # Remove file if it's outdated
3336
      if now > (value + remove_after):
3337
        RemoveFile(filename)
3338
        value = None
3339

    
3340
      elif now > value:
3341
        value = None
3342

    
3343
  return value
3344

    
3345

    
3346
class RetryTimeout(Exception):
3347
  """Retry loop timed out.
3348

3349
  Any arguments which were passed by the retried function to RetryAgain will be
3350
  preserved in RetryTimeout, if it is raised. If such an argument was an exception,
3351
  the RaiseInner helper method will reraise it.
3352

3353
  """
3354
  def RaiseInner(self):
3355
    if self.args and isinstance(self.args[0], Exception):
3356
      raise self.args[0]
3357
    else:
3358
      raise RetryTimeout(*self.args)
3359

    
3360

    
3361
class RetryAgain(Exception):
3362
  """Retry again.
3363

3364
  Any arguments passed to RetryAgain will be preserved, if a timeout occurs, as
3365
  arguments to RetryTimeout. If an exception is passed, the RaiseInner() method
3366
  of the RetryTimeout() exception can be used to reraise it.
3367

3368
  """
3369

    
3370

    
3371
class _RetryDelayCalculator(object):
3372
  """Calculator for increasing delays.
3373

3374
  """
3375
  __slots__ = [
3376
    "_factor",
3377
    "_limit",
3378
    "_next",
3379
    "_start",
3380
    ]
3381

    
3382
  def __init__(self, start, factor, limit):
3383
    """Initializes this class.
3384

3385
    @type start: float
3386
    @param start: Initial delay
3387
    @type factor: float
3388
    @param factor: Factor for delay increase
3389
    @type limit: float or None
3390
    @param limit: Upper limit for delay or None for no limit
3391

3392
    """
3393
    assert start > 0.0
3394
    assert factor >= 1.0
3395
    assert limit is None or limit >= 0.0
3396

    
3397
    self._start = start
3398
    self._factor = factor
3399
    self._limit = limit
3400

    
3401
    self._next = start
3402

    
3403
  def __call__(self):
3404
    """Returns current delay and calculates the next one.
3405

3406
    """
3407
    current = self._next
3408

    
3409
    # Update for next run
3410
    if self._limit is None or self._next < self._limit:
3411
      self._next = min(self._limit, self._next * self._factor)
3412

    
3413
    return current
3414

    
3415

    
3416
#: Special delay to specify whole remaining timeout
3417
RETRY_REMAINING_TIME = object()
3418

    
3419

    
3420
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
          _time_fn=time.time):
  """Call a function repeatedly until it succeeds.

  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
  total of C{timeout} seconds, this function throws L{RetryTimeout}.

  C{delay} can be one of the following:
    - callable returning the delay length as a float
    - Tuple of (start, factor, limit)
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
      useful when overriding L{wait_fn} to wait for an external event)
    - A static delay as a number (int or float)

  @type fn: callable
  @param fn: Function to be called
  @param delay: Either a callable (returning the delay), a tuple of (start,
                factor, limit) (see L{_RetryDelayCalculator}),
                L{RETRY_REMAINING_TIME} or a number (int or float)
  @type timeout: float
  @param timeout: Total timeout
  @type wait_fn: callable
  @param wait_fn: Waiting function
  @return: Return value of function

  """
  assert callable(fn)
  assert callable(wait_fn)
  assert callable(_time_fn)

  if args is None:
    args = []

  end_time = _time_fn() + timeout

  if callable(delay):
    # External function to calculate delay
    calc_delay = delay

  elif isinstance(delay, (tuple, list)):
    # Increasing delay with optional upper boundary
    (start, factor, limit) = delay
    calc_delay = _RetryDelayCalculator(start, factor, limit)

  elif delay is RETRY_REMAINING_TIME:
    # Always use the remaining time
    calc_delay = None

  else:
    # Static delay
    calc_delay = lambda: delay

  assert calc_delay is None or callable(calc_delay)

  while True:
    retry_args = []
    try:
      # pylint: disable-msg=W0142
      return fn(*args)
    except RetryAgain, err:
      retry_args = err.args
    except RetryTimeout:
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
                                   " handle RetryTimeout")

    remaining_time = end_time - _time_fn()

    if remaining_time < 0.0:
      # pylint: disable-msg=W0142
      raise RetryTimeout(*retry_args)

    assert remaining_time >= 0.0

    if calc_delay is None:
      wait_fn(remaining_time)
    else:
      current_delay = calc_delay()
      if current_delay > 0.0:
        wait_fn(current_delay)


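# A minimal usage sketch (the helper and the polling parameters are invented
# for illustration, not part of this module): a callback raises RetryAgain
# until its condition holds and Retry() turns that into a bounded wait.
def _ExampleWaitForFile(path):
  """Illustrative sketch only: waits up to 30 seconds for C{path} to appear.

  """
  def _CheckFile():
    if not os.path.exists(path):
      raise RetryAgain()
    return path

  try:
    # Start with 1s between checks, growing by 50% per attempt up to 5s
    return Retry(_CheckFile, (1.0, 1.5, 5.0), 30.0)
  except RetryTimeout, err:
    # If an exception had been passed to RetryAgain, reraise it instead
    err.RaiseInner()

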
def GetClosedTempfile(*args, **kwargs):
  """Creates a temporary file and returns its path.

  """
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
  _CloseFDNoErr(fd)
  return path


def GenerateSelfSignedX509Cert(common_name, validity):
  """Generates a self-signed X509 certificate.

  @type common_name: string
  @param common_name: commonName value
  @type validity: int
  @param validity: Validity for certificate in seconds

  """
  # Create private and public key
  key = OpenSSL.crypto.PKey()
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)

  # Create self-signed certificate
  cert = OpenSSL.crypto.X509()
  if common_name:
    cert.get_subject().CN = common_name
  cert.set_serial_number(1)
  cert.gmtime_adj_notBefore(0)
  cert.gmtime_adj_notAfter(validity)
  cert.set_issuer(cert.get_subject())
  cert.set_pubkey(key)
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)

  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  return (key_pem, cert_pem)


def GenerateSelfSignedSslCert(filename, validity=(5 * 365)):
  """Legacy function to generate a self-signed X509 certificate.

  @type filename: str
  @param filename: path to write the combined key/certificate PEM data to
  @type validity: int
  @param validity: Validity of the certificate in days

  """
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(None,
                                                   validity * 24 * 60 * 60)

  WriteFile(filename, mode=0400, data=key_pem + cert_pem)


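# A short usage sketch (the suffix and validity below are arbitrary choices,
# not Ganeti defaults): write a throw-away key+certificate PEM file, valid for
# one year, into a closed temporary file.
def _ExampleCreateThrowawayCert():
  """Illustrative sketch only: returns the path of a temporary certificate.

  """
  filename = GetClosedTempfile(suffix=".pem")
  GenerateSelfSignedSslCert(filename, validity=365)
  return filename

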
class FileLock(object):
  """Utility class for file locks.

  """
  def __init__(self, fd, filename):
    """Constructor for FileLock.

    @type fd: file
    @param fd: File object
    @type filename: str
    @param filename: Path of the file opened at I{fd}

    """
    self.fd = fd
    self.filename = filename

  @classmethod
  def Open(cls, filename):
    """Creates and opens a file to be used as a file-based lock.

    @type filename: string
    @param filename: path to the file to be locked

    """
    # Using "os.open" is necessary to allow both opening existing file
    # read/write and creating if not existing. Vanilla "open" will truncate an
    # existing file -or- allow creating if not existing.
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
               filename)

  def __del__(self):
    self.Close()

  def Close(self):
    """Close the file and release the lock.

    """
    if hasattr(self, "fd") and self.fd:
      self.fd.close()
      self.fd = None

  def _flock(self, flag, blocking, timeout, errmsg):
    """Wrapper for fcntl.flock.

    @type flag: int
    @param flag: operation flag
    @type blocking: bool
    @param blocking: whether the operation should be done in blocking mode.
    @type timeout: None or float
    @param timeout: for how long the operation should be retried (implies
                    non-blocking mode).
    @type errmsg: string
    @param errmsg: error message in case operation fails.

    """
    assert self.fd, "Lock was closed"
    assert timeout is None or timeout >= 0, \
      "If specified, timeout must be positive"
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)


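# A minimal usage sketch (the lock file path and the 10 second timeout are
# hypothetical): serialize access to a shared resource between processes by
# holding an exclusive lock while a callback runs.
def _ExampleRunLocked(fn, lockfile="/tmp/example.lock"):
  """Illustrative sketch only: runs C{fn} under an exclusive file lock.

  """
  lock = FileLock.Open(lockfile)
  try:
    # Wait up to 10 seconds for the exclusive lock
    lock.Exclusive(blocking=True, timeout=10)
    try:
      return fn()
    finally:
      lock.Unlock()
  finally:
    lock.Close()

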
class LineSplitter:
  """Splits data chunks into lines separated by newline.

  Instances provide a file-like interface.

  """
  def __init__(self, line_fn, *args):
    """Initializes this class.

    @type line_fn: callable
    @param line_fn: Function called for each line, first parameter is line
    @param args: Extra arguments for L{line_fn}

    """
    assert callable(line_fn)

    if args:
      # Python 2.4 doesn't have functools.partial yet
      self._line_fn = \
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
    else:
      self._line_fn = line_fn

    self._lines = collections.deque()
    self._buffer = ""

  def write(self, data):
    parts = (self._buffer + data).split("\n")
    self._buffer = parts.pop()
    self._lines.extend(parts)

  def flush(self):
    while self._lines:
      self._line_fn(self._lines.popleft().rstrip("\r\n"))

  def close(self):
    self.flush()
    if self._buffer:
      self._line_fn(self._buffer)


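# A small usage sketch (the collector list is purely illustrative): feed
# arbitrary data chunks into a LineSplitter and gather the complete lines.
def _ExampleSplitIntoLines(chunks):
  """Illustrative sketch only: returns the lines contained in C{chunks}.

  """
  lines = []
  splitter = LineSplitter(lines.append)
  for chunk in chunks:
    splitter.write(chunk)
  # close() flushes pending lines and passes on any unterminated rest
  splitter.close()
  return lines

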
def SignalHandled(signums):
  """Signal handling decorator.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      assert 'signal_handlers' not in kwargs or \
             kwargs['signal_handlers'] is None or \
             isinstance(kwargs['signal_handlers'], dict), \
             "Wrong signal_handlers parameter in original function call"
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
        signal_handlers = kwargs['signal_handlers']
      else:
        signal_handlers = {}
        kwargs['signal_handlers'] = signal_handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          signal_handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        sighandler.Reset()
    return sig_function
  return wrap


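# A usage sketch for the decorator above (the loop body is invented for
# illustration): the decorated function receives the installed handlers
# through the 'signal_handlers' keyword argument.
@SignalHandled([signal.SIGTERM])
def _ExampleMainLoop(signal_handlers=None):
  """Illustrative sketch only: loops until SIGTERM has been received.

  """
  term_handler = signal_handlers[signal.SIGTERM]
  while not term_handler.called:
    time.sleep(1)

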
class SignalWakeupFd(object):
  """Wrapper around C{signal.set_wakeup_fd}, using a pipe.

  """
  try:
    # This is only supported in Python 2.5 and above (some distributions
    # backported it to Python 2.4)
    _set_wakeup_fd_fn = signal.set_wakeup_fd
  except AttributeError:
    # Not supported
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
      return -1
  else:
    def _SetWakeupFd(self, fd):
      return self._set_wakeup_fd_fn(fd)

  def __init__(self):
    """Initializes this class.

    """
    (read_fd, write_fd) = os.pipe()

    # Once these succeeded, the file descriptors will be closed automatically.
    # Buffer size 0 is important, otherwise .read() with a specified length
    # might buffer data and the file descriptors won't be marked readable.
    self._read_fh = os.fdopen(read_fd, "r", 0)
    self._write_fh = os.fdopen(write_fd, "w", 0)

    self._previous = self._SetWakeupFd(self._write_fh.fileno())

    # Utility functions
    self.fileno = self._read_fh.fileno
    self.read = self._read_fh.read

  def Reset(self):
    """Restores the previous wakeup file descriptor.

    """
    if hasattr(self, "_previous") and self._previous is not None:
      self._SetWakeupFd(self._previous)
      self._previous = None

  def Notify(self):
    """Notifies the wakeup file descriptor.

    """
    self._write_fh.write("\0")

  def __del__(self):
    """Called before object deletion.

    """
    self.Reset()


class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when destructed or
  when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: list
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None, wakeup=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function
    @type wakeup: L{SignalWakeupFd} or None
    @param wakeup: Wakeup file descriptor to notify when a signal arrives

    """
    assert handler_fn is None or callable(handler_fn)

    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn
    self._wakeup = wakeup

    self._previous = {}
    try:
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._wakeup:
      # Notify whoever is interested in signals
      self._wakeup.Notify()

    if self._handler_fn:
      self._handler_fn(signum, frame)


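# A usage sketch combining SignalWakeupFd and SignalHandler (the loop below is
# invented for illustration, not part of Ganeti): the pipe becomes readable
# when a signal arrives, so select() wakes up without busy polling.
def _ExampleWaitForSigterm():
  """Illustrative sketch only: blocks until SIGTERM is received.

  """
  wakeup = SignalWakeupFd()
  handler = SignalHandler([signal.SIGTERM], wakeup=wakeup)
  try:
    while not handler.called:
      # Wake up at least once a minute even if the notification byte is lost
      # (EINTR handling is omitted for brevity)
      (rlist, _, _) = select.select([wakeup.fileno()], [], [], 60)
      if rlist:
        # Drain the notification byte
        wakeup.read(1)
  finally:
    handler.Reset()
    wakeup.Reset()

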
class FieldSet(object):
  """A simple field set.

  Among the features are:
    - checking if a string is among a list of static string or regex objects
    - checking if a whole list of string matches
    - returning the matching groups from a regex match

  Internally, all fields are held as regular expression objects.

  """
  def __init__(self, *items):
    self.items = [re.compile("^%s$" % value) for value in items]

  def Extend(self, other_set):
    """Extend the field set with the items from another one"""
    self.items.extend(other_set.items)

  def Matches(self, field):
    """Checks if a field matches the current set

    @type field: str
    @param field: the string to match
    @return: either None or a regular expression match object

    """
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
      return m
    return None

  def NonMatching(self, items):
    """Returns the list of fields not matching the current set

    @type items: list
    @param items: the list of fields to check
    @rtype: list
    @return: list of non-matching fields

    """
    return [val for val in items if not self.Matches(val)]
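

# A short usage sketch (the field names below are invented, not Ganeti's real
# query fields): static names and regular expressions can be mixed, and
# NonMatching() reports the unknown fields in a request.
def _ExampleCheckRequestedFields(requested):
  """Illustrative sketch only: returns the unsupported field names.

  """
  fields = FieldSet("name", "uuid", r"tag/(.*)")
  # e.g. _ExampleCheckRequestedFields(["name", "tag/foo", "bogus"])
  # would return ["bogus"]
  return fields.NonMatching(requested)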