Statistics
| Branch: | Tag: | Revision:

root / lib / utils.py @ ac492887

History | View | Annotate | Download (106.3 kB)

1
#
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import os
31
import sys
32
import time
33
import subprocess
34
import re
35
import socket
36
import tempfile
37
import shutil
38
import errno
39
import pwd
40
import itertools
41
import select
42
import fcntl
43
import resource
44
import logging
45
import logging.handlers
46
import signal
47
import OpenSSL
48
import datetime
49
import calendar
50
import hmac
51
import collections
52
import struct
53
import IN
54

    
55
from cStringIO import StringIO
56

    
57
try:
58
  # pylint: disable-msg=F0401
59
  import ctypes
60
except ImportError:
61
  ctypes = None
62

    
63
from ganeti import errors
64
from ganeti import constants
65
from ganeti import compat
66

    
67

    
68
_locksheld = []
69
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
70

    
71
debug_locks = False
72

    
73
#: when set to True, L{RunCmd} is disabled
74
no_fork = False
75

    
76
_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
77

    
78
HEX_CHAR_RE = r"[a-zA-Z0-9]"
79
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
80
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
81
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
82
                             HEX_CHAR_RE, HEX_CHAR_RE),
83
                            re.S | re.I)
84

    
85
_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$")
86

    
87
# Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
88
# struct ucred { pid_t pid; uid_t uid; gid_t gid; };
89
#
90
# The GNU C Library defines gid_t and uid_t to be "unsigned int" and
91
# pid_t to "int".
92
#
93
# IEEE Std 1003.1-2008:
94
# "nlink_t, uid_t, gid_t, and id_t shall be integer types"
95
# "blksize_t, pid_t, and ssize_t shall be signed integer types"
96
_STRUCT_UCRED = "iII"
97
_STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)
98

    
99
# Certificate verification results
100
(CERT_WARNING,
101
 CERT_ERROR) = range(1, 3)
102

    
103
# Flags for mlockall() (from bits/mman.h)
104
_MCL_CURRENT = 1
105
_MCL_FUTURE = 2
106

    
107

    
108
class RunResult(object):
109
  """Holds the result of running external programs.
110

111
  @type exit_code: int
112
  @ivar exit_code: the exit code of the program, or None (if the program
113
      didn't exit())
114
  @type signal: int or None
115
  @ivar signal: the signal that caused the program to finish, or None
116
      (if the program wasn't terminated by a signal)
117
  @type stdout: str
118
  @ivar stdout: the standard output of the program
119
  @type stderr: str
120
  @ivar stderr: the standard error of the program
121
  @type failed: boolean
122
  @ivar failed: True in case the program was
123
      terminated by a signal or exited with a non-zero exit code
124
  @ivar fail_reason: a string detailing the termination reason
125

126
  """
127
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
128
               "failed", "fail_reason", "cmd"]
129

    
130

    
131
  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
132
    self.cmd = cmd
133
    self.exit_code = exit_code
134
    self.signal = signal_
135
    self.stdout = stdout
136
    self.stderr = stderr
137
    self.failed = (signal_ is not None or exit_code != 0)
138

    
139
    if self.signal is not None:
140
      self.fail_reason = "terminated by signal %s" % self.signal
141
    elif self.exit_code is not None:
142
      self.fail_reason = "exited with exit code %s" % self.exit_code
143
    else:
144
      self.fail_reason = "unable to determine termination reason"
145

    
146
    if self.failed:
147
      logging.debug("Command '%s' failed (%s); output: %s",
148
                    self.cmd, self.fail_reason, self.output)
149

    
150
  def _GetOutput(self):
151
    """Returns the combined stdout and stderr for easier usage.
152

153
    """
154
    return self.stdout + self.stderr
155

    
156
  output = property(_GetOutput, None, None, "Return full output")
157

    
158

    
159
def _BuildCmdEnvironment(env, reset):
160
  """Builds the environment for an external program.
161

162
  """
163
  if reset:
164
    cmd_env = {}
165
  else:
166
    cmd_env = os.environ.copy()
167
    cmd_env["LC_ALL"] = "C"
168

    
169
  if env is not None:
170
    cmd_env.update(env)
171

    
172
  return cmd_env
173

    
174

    
175
def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False):
176
  """Execute a (shell) command.
177

178
  The command should not read from its standard input, as it will be
179
  closed.
180

181
  @type cmd: string or list
182
  @param cmd: Command to run
183
  @type env: dict
184
  @param env: Additional environment variables
185
  @type output: str
186
  @param output: if desired, the output of the command can be
187
      saved in a file instead of the RunResult instance; this
188
      parameter denotes the file name (if not None)
189
  @type cwd: string
190
  @param cwd: if specified, will be used as the working
191
      directory for the command; the default will be /
192
  @type reset_env: boolean
193
  @param reset_env: whether to reset or keep the default os environment
194
  @rtype: L{RunResult}
195
  @return: RunResult instance
196
  @raise errors.ProgrammerError: if we call this when forks are disabled
197

198
  """
199
  if no_fork:
200
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")
201

    
202
  if isinstance(cmd, basestring):
203
    strcmd = cmd
204
    shell = True
205
  else:
206
    cmd = [str(val) for val in cmd]
207
    strcmd = ShellQuoteArgs(cmd)
208
    shell = False
209

    
210
  if output:
211
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
212
  else:
213
    logging.debug("RunCmd %s", strcmd)
214

    
215
  cmd_env = _BuildCmdEnvironment(env, reset_env)
216

    
217
  try:
218
    if output is None:
219
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
220
    else:
221
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
222
      out = err = ""
223
  except OSError, err:
224
    if err.errno == errno.ENOENT:
225
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
226
                               (strcmd, err))
227
    else:
228
      raise
229

    
230
  if status >= 0:
231
    exitcode = status
232
    signal_ = None
233
  else:
234
    exitcode = None
235
    signal_ = -status
236

    
237
  return RunResult(exitcode, signal_, out, err, strcmd)
238

    
239

    
240
def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
241
                pidfile=None):
242
  """Start a daemon process after forking twice.
243

244
  @type cmd: string or list
245
  @param cmd: Command to run
246
  @type env: dict
247
  @param env: Additional environment variables
248
  @type cwd: string
249
  @param cwd: Working directory for the program
250
  @type output: string
251
  @param output: Path to file in which to save the output
252
  @type output_fd: int
253
  @param output_fd: File descriptor for output
254
  @type pidfile: string
255
  @param pidfile: Process ID file
256
  @rtype: int
257
  @return: Daemon process ID
258
  @raise errors.ProgrammerError: if we call this when forks are disabled
259

260
  """
261
  if no_fork:
262
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
263
                                 " disabled")
264

    
265
  if output and not (bool(output) ^ (output_fd is not None)):
266
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
267
                                 " specified")
268

    
269
  if isinstance(cmd, basestring):
270
    cmd = ["/bin/sh", "-c", cmd]
271

    
272
  strcmd = ShellQuoteArgs(cmd)
273

    
274
  if output:
275
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
276
  else:
277
    logging.debug("StartDaemon %s", strcmd)
278

    
279
  cmd_env = _BuildCmdEnvironment(env, False)
280

    
281
  # Create pipe for sending PID back
282
  (pidpipe_read, pidpipe_write) = os.pipe()
283
  try:
284
    try:
285
      # Create pipe for sending error messages
286
      (errpipe_read, errpipe_write) = os.pipe()
287
      try:
288
        try:
289
          # First fork
290
          pid = os.fork()
291
          if pid == 0:
292
            try:
293
              # Child process, won't return
294
              _StartDaemonChild(errpipe_read, errpipe_write,
295
                                pidpipe_read, pidpipe_write,
296
                                cmd, cmd_env, cwd,
297
                                output, output_fd, pidfile)
298
            finally:
299
              # Well, maybe child process failed
300
              os._exit(1) # pylint: disable-msg=W0212
301
        finally:
302
          _CloseFDNoErr(errpipe_write)
303

    
304
        # Wait for daemon to be started (or an error message to arrive) and read
305
        # up to 100 KB as an error message
306
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
307
      finally:
308
        _CloseFDNoErr(errpipe_read)
309
    finally:
310
      _CloseFDNoErr(pidpipe_write)
311

    
312
    # Read up to 128 bytes for PID
313
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
314
  finally:
315
    _CloseFDNoErr(pidpipe_read)
316

    
317
  # Try to avoid zombies by waiting for child process
318
  try:
319
    os.waitpid(pid, 0)
320
  except OSError:
321
    pass
322

    
323
  if errormsg:
324
    raise errors.OpExecError("Error when starting daemon process: %r" %
325
                             errormsg)
326

    
327
  try:
328
    return int(pidtext)
329
  except (ValueError, TypeError), err:
330
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
331
                             (pidtext, err))
332

    
333

    
334
def _StartDaemonChild(errpipe_read, errpipe_write,
335
                      pidpipe_read, pidpipe_write,
336
                      args, env, cwd,
337
                      output, fd_output, pidfile):
338
  """Child process for starting daemon.
339

340
  """
341
  try:
342
    # Close parent's side
343
    _CloseFDNoErr(errpipe_read)
344
    _CloseFDNoErr(pidpipe_read)
345

    
346
    # First child process
347
    os.chdir("/")
348
    os.umask(077)
349
    os.setsid()
350

    
351
    # And fork for the second time
352
    pid = os.fork()
353
    if pid != 0:
354
      # Exit first child process
355
      os._exit(0) # pylint: disable-msg=W0212
356

    
357
    # Make sure pipe is closed on execv* (and thereby notifies original process)
358
    SetCloseOnExecFlag(errpipe_write, True)
359

    
360
    # List of file descriptors to be left open
361
    noclose_fds = [errpipe_write]
362

    
363
    # Open PID file
364
    if pidfile:
365
      try:
366
        # TODO: Atomic replace with another locked file instead of writing into
367
        # it after creating
368
        fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)
369

    
370
        # Lock the PID file (and fail if not possible to do so). Any code
371
        # wanting to send a signal to the daemon should try to lock the PID
372
        # file before reading it. If acquiring the lock succeeds, the daemon is
373
        # no longer running and the signal should not be sent.
374
        LockFile(fd_pidfile)
375

    
376
        os.write(fd_pidfile, "%d\n" % os.getpid())
377
      except Exception, err:
378
        raise Exception("Creating and locking PID file failed: %s" % err)
379

    
380
      # Keeping the file open to hold the lock
381
      noclose_fds.append(fd_pidfile)
382

    
383
      SetCloseOnExecFlag(fd_pidfile, False)
384
    else:
385
      fd_pidfile = None
386

    
387
    # Open /dev/null
388
    fd_devnull = os.open(os.devnull, os.O_RDWR)
389

    
390
    assert not output or (bool(output) ^ (fd_output is not None))
391

    
392
    if fd_output is not None:
393
      pass
394
    elif output:
395
      # Open output file
396
      try:
397
        # TODO: Implement flag to set append=yes/no
398
        fd_output = os.open(output, os.O_WRONLY | os.O_CREAT, 0600)
399
      except EnvironmentError, err:
400
        raise Exception("Opening output file failed: %s" % err)
401
    else:
402
      fd_output = fd_devnull
403

    
404
    # Redirect standard I/O
405
    os.dup2(fd_devnull, 0)
406
    os.dup2(fd_output, 1)
407
    os.dup2(fd_output, 2)
408

    
409
    # Send daemon PID to parent
410
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))
411

    
412
    # Close all file descriptors except stdio and error message pipe
413
    CloseFDs(noclose_fds=noclose_fds)
414

    
415
    # Change working directory
416
    os.chdir(cwd)
417

    
418
    if env is None:
419
      os.execvp(args[0], args)
420
    else:
421
      os.execvpe(args[0], args, env)
422
  except: # pylint: disable-msg=W0702
423
    try:
424
      # Report errors to original process
425
      buf = str(sys.exc_info()[1])
426

    
427
      RetryOnSignal(os.write, errpipe_write, buf)
428
    except: # pylint: disable-msg=W0702
429
      # Ignore errors in error handling
430
      pass
431

    
432
  os._exit(1) # pylint: disable-msg=W0212
433

    
434

    
435
def _RunCmdPipe(cmd, env, via_shell, cwd):
436
  """Run a command and return its output.
437

438
  @type  cmd: string or list
439
  @param cmd: Command to run
440
  @type env: dict
441
  @param env: The environment to use
442
  @type via_shell: bool
443
  @param via_shell: if we should run via the shell
444
  @type cwd: string
445
  @param cwd: the working directory for the program
446
  @rtype: tuple
447
  @return: (out, err, status)
448

449
  """
450
  poller = select.poll()
451
  child = subprocess.Popen(cmd, shell=via_shell,
452
                           stderr=subprocess.PIPE,
453
                           stdout=subprocess.PIPE,
454
                           stdin=subprocess.PIPE,
455
                           close_fds=True, env=env,
456
                           cwd=cwd)
457

    
458
  child.stdin.close()
459
  poller.register(child.stdout, select.POLLIN)
460
  poller.register(child.stderr, select.POLLIN)
461
  out = StringIO()
462
  err = StringIO()
463
  fdmap = {
464
    child.stdout.fileno(): (out, child.stdout),
465
    child.stderr.fileno(): (err, child.stderr),
466
    }
467
  for fd in fdmap:
468
    SetNonblockFlag(fd, True)
469

    
470
  while fdmap:
471
    pollresult = RetryOnSignal(poller.poll)
472

    
473
    for fd, event in pollresult:
474
      if event & select.POLLIN or event & select.POLLPRI:
475
        data = fdmap[fd][1].read()
476
        # no data from read signifies EOF (the same as POLLHUP)
477
        if not data:
478
          poller.unregister(fd)
479
          del fdmap[fd]
480
          continue
481
        fdmap[fd][0].write(data)
482
      if (event & select.POLLNVAL or event & select.POLLHUP or
483
          event & select.POLLERR):
484
        poller.unregister(fd)
485
        del fdmap[fd]
486

    
487
  out = out.getvalue()
488
  err = err.getvalue()
489

    
490
  status = child.wait()
491
  return out, err, status
492

    
493

    
494
def _RunCmdFile(cmd, env, via_shell, output, cwd):
495
  """Run a command and save its output to a file.
496

497
  @type  cmd: string or list
498
  @param cmd: Command to run
499
  @type env: dict
500
  @param env: The environment to use
501
  @type via_shell: bool
502
  @param via_shell: if we should run via the shell
503
  @type output: str
504
  @param output: the filename in which to save the output
505
  @type cwd: string
506
  @param cwd: the working directory for the program
507
  @rtype: int
508
  @return: the exit status
509

510
  """
511
  fh = open(output, "a")
512
  try:
513
    child = subprocess.Popen(cmd, shell=via_shell,
514
                             stderr=subprocess.STDOUT,
515
                             stdout=fh,
516
                             stdin=subprocess.PIPE,
517
                             close_fds=True, env=env,
518
                             cwd=cwd)
519

    
520
    child.stdin.close()
521
    status = child.wait()
522
  finally:
523
    fh.close()
524
  return status
525

    
526

    
527
def SetCloseOnExecFlag(fd, enable):
528
  """Sets or unsets the close-on-exec flag on a file descriptor.
529

530
  @type fd: int
531
  @param fd: File descriptor
532
  @type enable: bool
533
  @param enable: Whether to set or unset it.
534

535
  """
536
  flags = fcntl.fcntl(fd, fcntl.F_GETFD)
537

    
538
  if enable:
539
    flags |= fcntl.FD_CLOEXEC
540
  else:
541
    flags &= ~fcntl.FD_CLOEXEC
542

    
543
  fcntl.fcntl(fd, fcntl.F_SETFD, flags)
544

    
545

    
546
def SetNonblockFlag(fd, enable):
547
  """Sets or unsets the O_NONBLOCK flag on on a file descriptor.
548

549
  @type fd: int
550
  @param fd: File descriptor
551
  @type enable: bool
552
  @param enable: Whether to set or unset it
553

554
  """
555
  flags = fcntl.fcntl(fd, fcntl.F_GETFL)
556

    
557
  if enable:
558
    flags |= os.O_NONBLOCK
559
  else:
560
    flags &= ~os.O_NONBLOCK
561

    
562
  fcntl.fcntl(fd, fcntl.F_SETFL, flags)
563

    
564

    
565
def RetryOnSignal(fn, *args, **kwargs):
566
  """Calls a function again if it failed due to EINTR.
567

568
  """
569
  while True:
570
    try:
571
      return fn(*args, **kwargs)
572
    except EnvironmentError, err:
573
      if err.errno != errno.EINTR:
574
        raise
575
    except (socket.error, select.error), err:
576
      # In python 2.6 and above select.error is an IOError, so it's handled
577
      # above, in 2.5 and below it's not, and it's handled here.
578
      if not (err.args and err.args[0] == errno.EINTR):
579
        raise
580

    
581

    
582
def RunParts(dir_name, env=None, reset_env=False):
583
  """Run Scripts or programs in a directory
584

585
  @type dir_name: string
586
  @param dir_name: absolute path to a directory
587
  @type env: dict
588
  @param env: The environment to use
589
  @type reset_env: boolean
590
  @param reset_env: whether to reset or keep the default os environment
591
  @rtype: list of tuples
592
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)
593

594
  """
595
  rr = []
596

    
597
  try:
598
    dir_contents = ListVisibleFiles(dir_name)
599
  except OSError, err:
600
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
601
    return rr
602

    
603
  for relname in sorted(dir_contents):
604
    fname = PathJoin(dir_name, relname)
605
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
606
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
607
      rr.append((relname, constants.RUNPARTS_SKIP, None))
608
    else:
609
      try:
610
        result = RunCmd([fname], env=env, reset_env=reset_env)
611
      except Exception, err: # pylint: disable-msg=W0703
612
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
613
      else:
614
        rr.append((relname, constants.RUNPARTS_RUN, result))
615

    
616
  return rr
617

    
618

    
619
def GetSocketCredentials(sock):
620
  """Returns the credentials of the foreign process connected to a socket.
621

622
  @param sock: Unix socket
623
  @rtype: tuple; (number, number, number)
624
  @return: The PID, UID and GID of the connected foreign process.
625

626
  """
627
  peercred = sock.getsockopt(socket.SOL_SOCKET, IN.SO_PEERCRED,
628
                             _STRUCT_UCRED_SIZE)
629
  return struct.unpack(_STRUCT_UCRED, peercred)
630

    
631

    
632
def RemoveFile(filename):
633
  """Remove a file ignoring some errors.
634

635
  Remove a file, ignoring non-existing ones or directories. Other
636
  errors are passed.
637

638
  @type filename: str
639
  @param filename: the file to be removed
640

641
  """
642
  try:
643
    os.unlink(filename)
644
  except OSError, err:
645
    if err.errno not in (errno.ENOENT, errno.EISDIR):
646
      raise
647

    
648

    
649
def RemoveDir(dirname):
650
  """Remove an empty directory.
651

652
  Remove a directory, ignoring non-existing ones.
653
  Other errors are passed. This includes the case,
654
  where the directory is not empty, so it can't be removed.
655

656
  @type dirname: str
657
  @param dirname: the empty directory to be removed
658

659
  """
660
  try:
661
    os.rmdir(dirname)
662
  except OSError, err:
663
    if err.errno != errno.ENOENT:
664
      raise
665

    
666

    
667
def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
668
  """Renames a file.
669

670
  @type old: string
671
  @param old: Original path
672
  @type new: string
673
  @param new: New path
674
  @type mkdir: bool
675
  @param mkdir: Whether to create target directory if it doesn't exist
676
  @type mkdir_mode: int
677
  @param mkdir_mode: Mode for newly created directories
678

679
  """
680
  try:
681
    return os.rename(old, new)
682
  except OSError, err:
683
    # In at least one use case of this function, the job queue, directory
684
    # creation is very rare. Checking for the directory before renaming is not
685
    # as efficient.
686
    if mkdir and err.errno == errno.ENOENT:
687
      # Create directory and try again
688
      Makedirs(os.path.dirname(new), mode=mkdir_mode)
689

    
690
      return os.rename(old, new)
691

    
692
    raise
693

    
694

    
695
def Makedirs(path, mode=0750):
696
  """Super-mkdir; create a leaf directory and all intermediate ones.
697

698
  This is a wrapper around C{os.makedirs} adding error handling not implemented
699
  before Python 2.5.
700

701
  """
702
  try:
703
    os.makedirs(path, mode)
704
  except OSError, err:
705
    # Ignore EEXIST. This is only handled in os.makedirs as included in
706
    # Python 2.5 and above.
707
    if err.errno != errno.EEXIST or not os.path.exists(path):
708
      raise
709

    
710

    
711
def ResetTempfileModule():
712
  """Resets the random name generator of the tempfile module.
713

714
  This function should be called after C{os.fork} in the child process to
715
  ensure it creates a newly seeded random generator. Otherwise it would
716
  generate the same random parts as the parent process. If several processes
717
  race for the creation of a temporary file, this could lead to one not getting
718
  a temporary name.
719

720
  """
721
  # pylint: disable-msg=W0212
722
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
723
    tempfile._once_lock.acquire()
724
    try:
725
      # Reset random name generator
726
      tempfile._name_sequence = None
727
    finally:
728
      tempfile._once_lock.release()
729
  else:
730
    logging.critical("The tempfile module misses at least one of the"
731
                     " '_once_lock' and '_name_sequence' attributes")
732

    
733

    
734
def _FingerprintFile(filename):
735
  """Compute the fingerprint of a file.
736

737
  If the file does not exist, a None will be returned
738
  instead.
739

740
  @type filename: str
741
  @param filename: the filename to checksum
742
  @rtype: str
743
  @return: the hex digest of the sha checksum of the contents
744
      of the file
745

746
  """
747
  if not (os.path.exists(filename) and os.path.isfile(filename)):
748
    return None
749

    
750
  f = open(filename)
751

    
752
  fp = compat.sha1_hash()
753
  while True:
754
    data = f.read(4096)
755
    if not data:
756
      break
757

    
758
    fp.update(data)
759

    
760
  return fp.hexdigest()
761

    
762

    
763
def FingerprintFiles(files):
764
  """Compute fingerprints for a list of files.
765

766
  @type files: list
767
  @param files: the list of filename to fingerprint
768
  @rtype: dict
769
  @return: a dictionary filename: fingerprint, holding only
770
      existing files
771

772
  """
773
  ret = {}
774

    
775
  for filename in files:
776
    cksum = _FingerprintFile(filename)
777
    if cksum:
778
      ret[filename] = cksum
779

    
780
  return ret
781

    
782

    
783
def ForceDictType(target, key_types, allowed_values=None):
784
  """Force the values of a dict to have certain types.
785

786
  @type target: dict
787
  @param target: the dict to update
788
  @type key_types: dict
789
  @param key_types: dict mapping target dict keys to types
790
                    in constants.ENFORCEABLE_TYPES
791
  @type allowed_values: list
792
  @keyword allowed_values: list of specially allowed values
793

794
  """
795
  if allowed_values is None:
796
    allowed_values = []
797

    
798
  if not isinstance(target, dict):
799
    msg = "Expected dictionary, got '%s'" % target
800
    raise errors.TypeEnforcementError(msg)
801

    
802
  for key in target:
803
    if key not in key_types:
804
      msg = "Unknown key '%s'" % key
805
      raise errors.TypeEnforcementError(msg)
806

    
807
    if target[key] in allowed_values:
808
      continue
809

    
810
    ktype = key_types[key]
811
    if ktype not in constants.ENFORCEABLE_TYPES:
812
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
813
      raise errors.ProgrammerError(msg)
814

    
815
    if ktype == constants.VTYPE_STRING:
816
      if not isinstance(target[key], basestring):
817
        if isinstance(target[key], bool) and not target[key]:
818
          target[key] = ''
819
        else:
820
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
821
          raise errors.TypeEnforcementError(msg)
822
    elif ktype == constants.VTYPE_BOOL:
823
      if isinstance(target[key], basestring) and target[key]:
824
        if target[key].lower() == constants.VALUE_FALSE:
825
          target[key] = False
826
        elif target[key].lower() == constants.VALUE_TRUE:
827
          target[key] = True
828
        else:
829
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
830
          raise errors.TypeEnforcementError(msg)
831
      elif target[key]:
832
        target[key] = True
833
      else:
834
        target[key] = False
835
    elif ktype == constants.VTYPE_SIZE:
836
      try:
837
        target[key] = ParseUnit(target[key])
838
      except errors.UnitParseError, err:
839
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
840
              (key, target[key], err)
841
        raise errors.TypeEnforcementError(msg)
842
    elif ktype == constants.VTYPE_INT:
843
      try:
844
        target[key] = int(target[key])
845
      except (ValueError, TypeError):
846
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
847
        raise errors.TypeEnforcementError(msg)
848

    
849

    
850
def _GetProcStatusPath(pid):
851
  """Returns the path for a PID's proc status file.
852

853
  @type pid: int
854
  @param pid: Process ID
855
  @rtype: string
856

857
  """
858
  return "/proc/%d/status" % pid
859

    
860

    
861
def IsProcessAlive(pid):
862
  """Check if a given pid exists on the system.
863

864
  @note: zombie status is not handled, so zombie processes
865
      will be returned as alive
866
  @type pid: int
867
  @param pid: the process ID to check
868
  @rtype: boolean
869
  @return: True if the process exists
870

871
  """
872
  def _TryStat(name):
873
    try:
874
      os.stat(name)
875
      return True
876
    except EnvironmentError, err:
877
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
878
        return False
879
      elif err.errno == errno.EINVAL:
880
        raise RetryAgain(err)
881
      raise
882

    
883
  assert isinstance(pid, int), "pid must be an integer"
884
  if pid <= 0:
885
    return False
886

    
887
  # /proc in a multiprocessor environment can have strange behaviors.
888
  # Retry the os.stat a few times until we get a good result.
889
  try:
890
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
891
                 args=[_GetProcStatusPath(pid)])
892
  except RetryTimeout, err:
893
    err.RaiseInner()
894

    
895

    
896
def _ParseSigsetT(sigset):
897
  """Parse a rendered sigset_t value.
898

899
  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
900
  function.
901

902
  @type sigset: string
903
  @param sigset: Rendered signal set from /proc/$pid/status
904
  @rtype: set
905
  @return: Set of all enabled signal numbers
906

907
  """
908
  result = set()
909

    
910
  signum = 0
911
  for ch in reversed(sigset):
912
    chv = int(ch, 16)
913

    
914
    # The following could be done in a loop, but it's easier to read and
915
    # understand in the unrolled form
916
    if chv & 1:
917
      result.add(signum + 1)
918
    if chv & 2:
919
      result.add(signum + 2)
920
    if chv & 4:
921
      result.add(signum + 3)
922
    if chv & 8:
923
      result.add(signum + 4)
924

    
925
    signum += 4
926

    
927
  return result
928

    
929

    
930
def _GetProcStatusField(pstatus, field):
931
  """Retrieves a field from the contents of a proc status file.
932

933
  @type pstatus: string
934
  @param pstatus: Contents of /proc/$pid/status
935
  @type field: string
936
  @param field: Name of field whose value should be returned
937
  @rtype: string
938

939
  """
940
  for line in pstatus.splitlines():
941
    parts = line.split(":", 1)
942

    
943
    if len(parts) < 2 or parts[0] != field:
944
      continue
945

    
946
    return parts[1].strip()
947

    
948
  return None
949

    
950

    
951
def IsProcessHandlingSignal(pid, signum, status_path=None):
952
  """Checks whether a process is handling a signal.
953

954
  @type pid: int
955
  @param pid: Process ID
956
  @type signum: int
957
  @param signum: Signal number
958
  @rtype: bool
959

960
  """
961
  if status_path is None:
962
    status_path = _GetProcStatusPath(pid)
963

    
964
  try:
965
    proc_status = ReadFile(status_path)
966
  except EnvironmentError, err:
967
    # In at least one case, reading /proc/$pid/status failed with ESRCH.
968
    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
969
      return False
970
    raise
971

    
972
  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
973
  if sigcgt is None:
974
    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)
975

    
976
  # Now check whether signal is handled
977
  return signum in _ParseSigsetT(sigcgt)
978

    
979

    
980
def ReadPidFile(pidfile):
981
  """Read a pid from a file.
982

983
  @type  pidfile: string
984
  @param pidfile: path to the file containing the pid
985
  @rtype: int
986
  @return: The process id, if the file exists and contains a valid PID,
987
           otherwise 0
988

989
  """
990
  try:
991
    raw_data = ReadOneLineFile(pidfile)
992
  except EnvironmentError, err:
993
    if err.errno != errno.ENOENT:
994
      logging.exception("Can't read pid file")
995
    return 0
996

    
997
  try:
998
    pid = int(raw_data)
999
  except (TypeError, ValueError), err:
1000
    logging.info("Can't parse pid file contents", exc_info=True)
1001
    return 0
1002

    
1003
  return pid
1004

    
1005

    
1006
def ReadLockedPidFile(path):
1007
  """Reads a locked PID file.
1008

1009
  This can be used together with L{StartDaemon}.
1010

1011
  @type path: string
1012
  @param path: Path to PID file
1013
  @return: PID as integer or, if file was unlocked or couldn't be opened, None
1014

1015
  """
1016
  try:
1017
    fd = os.open(path, os.O_RDONLY)
1018
  except EnvironmentError, err:
1019
    if err.errno == errno.ENOENT:
1020
      # PID file doesn't exist
1021
      return None
1022
    raise
1023

    
1024
  try:
1025
    try:
1026
      # Try to acquire lock
1027
      LockFile(fd)
1028
    except errors.LockError:
1029
      # Couldn't lock, daemon is running
1030
      return int(os.read(fd, 100))
1031
  finally:
1032
    os.close(fd)
1033

    
1034
  return None
1035

    
1036

    
1037
def MatchNameComponent(key, name_list, case_sensitive=True):
1038
  """Try to match a name against a list.
1039

1040
  This function will try to match a name like test1 against a list
1041
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
1042
  this list, I{'test1'} as well as I{'test1.example'} will match, but
1043
  not I{'test1.ex'}. A multiple match will be considered as no match
1044
  at all (e.g. I{'test1'} against C{['test1.example.com',
1045
  'test1.example.org']}), except when the key fully matches an entry
1046
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
1047

1048
  @type key: str
1049
  @param key: the name to be searched
1050
  @type name_list: list
1051
  @param name_list: the list of strings against which to search the key
1052
  @type case_sensitive: boolean
1053
  @param case_sensitive: whether to provide a case-sensitive match
1054

1055
  @rtype: None or str
1056
  @return: None if there is no match I{or} if there are multiple matches,
1057
      otherwise the element from the list which matches
1058

1059
  """
1060
  if key in name_list:
1061
    return key
1062

    
1063
  re_flags = 0
1064
  if not case_sensitive:
1065
    re_flags |= re.IGNORECASE
1066
    key = key.upper()
1067
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
1068
  names_filtered = []
1069
  string_matches = []
1070
  for name in name_list:
1071
    if mo.match(name) is not None:
1072
      names_filtered.append(name)
1073
      if not case_sensitive and key == name.upper():
1074
        string_matches.append(name)
1075

    
1076
  if len(string_matches) == 1:
1077
    return string_matches[0]
1078
  if len(names_filtered) == 1:
1079
    return names_filtered[0]
1080
  return None
1081

    
1082

    
1083
class HostInfo:
1084
  """Class implementing resolver and hostname functionality
1085

1086
  """
1087
  _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$")
1088

    
1089
  def __init__(self, name=None):
1090
    """Initialize the host name object.
1091

1092
    If the name argument is not passed, it will use this system's
1093
    name.
1094

1095
    """
1096
    if name is None:
1097
      name = self.SysName()
1098

    
1099
    self.query = name
1100
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
1101
    self.ip = self.ipaddrs[0]
1102

    
1103
  def ShortName(self):
1104
    """Returns the hostname without domain.
1105

1106
    """
1107
    return self.name.split('.')[0]
1108

    
1109
  @staticmethod
1110
  def SysName():
1111
    """Return the current system's name.
1112

1113
    This is simply a wrapper over C{socket.gethostname()}.
1114

1115
    """
1116
    return socket.gethostname()
1117

    
1118
  @staticmethod
1119
  def LookupHostname(hostname):
1120
    """Look up hostname
1121

1122
    @type hostname: str
1123
    @param hostname: hostname to look up
1124

1125
    @rtype: tuple
1126
    @return: a tuple (name, aliases, ipaddrs) as returned by
1127
        C{socket.gethostbyname_ex}
1128
    @raise errors.ResolverError: in case of errors in resolving
1129

1130
    """
1131
    try:
1132
      result = socket.gethostbyname_ex(hostname)
1133
    except (socket.gaierror, socket.herror, socket.error), err:
1134
      # hostname not found in DNS, or other socket exception in the
1135
      # (code, description format)
1136
      raise errors.ResolverError(hostname, err.args[0], err.args[1])
1137

    
1138
    return result
1139

    
1140
  @classmethod
1141
  def NormalizeName(cls, hostname):
1142
    """Validate and normalize the given hostname.
1143

1144
    @attention: the validation is a bit more relaxed than the standards
1145
        require; most importantly, we allow underscores in names
1146
    @raise errors.OpPrereqError: when the name is not valid
1147

1148
    """
1149
    hostname = hostname.lower()
1150
    if (not cls._VALID_NAME_RE.match(hostname) or
1151
        # double-dots, meaning empty label
1152
        ".." in hostname or
1153
        # empty initial label
1154
        hostname.startswith(".")):
1155
      raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
1156
                                 errors.ECODE_INVAL)
1157
    if hostname.endswith("."):
1158
      hostname = hostname.rstrip(".")
1159
    return hostname
1160

    
1161

    
1162
def ValidateServiceName(name):
1163
  """Validate the given service name.
1164

1165
  @type name: number or string
1166
  @param name: Service name or port specification
1167

1168
  """
1169
  try:
1170
    numport = int(name)
1171
  except (ValueError, TypeError):
1172
    # Non-numeric service name
1173
    valid = _VALID_SERVICE_NAME_RE.match(name)
1174
  else:
1175
    # Numeric port (protocols other than TCP or UDP might need adjustments
1176
    # here)
1177
    valid = (numport >= 0 and numport < (1 << 16))
1178

    
1179
  if not valid:
1180
    raise errors.OpPrereqError("Invalid service name '%s'" % name,
1181
                               errors.ECODE_INVAL)
1182

    
1183
  return name
1184

    
1185

    
1186
def GetHostInfo(name=None):
1187
  """Lookup host name and raise an OpPrereqError for failures"""
1188

    
1189
  try:
1190
    return HostInfo(name)
1191
  except errors.ResolverError, err:
1192
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
1193
                               (err[0], err[2]), errors.ECODE_RESOLVER)
1194

    
1195

    
1196
def ListVolumeGroups():
1197
  """List volume groups and their size
1198

1199
  @rtype: dict
1200
  @return:
1201
       Dictionary with keys volume name and values
1202
       the size of the volume
1203

1204
  """
1205
  command = "vgs --noheadings --units m --nosuffix -o name,size"
1206
  result = RunCmd(command)
1207
  retval = {}
1208
  if result.failed:
1209
    return retval
1210

    
1211
  for line in result.stdout.splitlines():
1212
    try:
1213
      name, size = line.split()
1214
      size = int(float(size))
1215
    except (IndexError, ValueError), err:
1216
      logging.error("Invalid output from vgs (%s): %s", err, line)
1217
      continue
1218

    
1219
    retval[name] = size
1220

    
1221
  return retval
1222

    
1223

    
1224
def BridgeExists(bridge):
1225
  """Check whether the given bridge exists in the system
1226

1227
  @type bridge: str
1228
  @param bridge: the bridge name to check
1229
  @rtype: boolean
1230
  @return: True if it does
1231

1232
  """
1233
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
1234

    
1235

    
1236
def NiceSort(name_list):
1237
  """Sort a list of strings based on digit and non-digit groupings.
1238

1239
  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
1240
  will sort the list in the logical order C{['a1', 'a2', 'a10',
1241
  'a11']}.
1242

1243
  The sort algorithm breaks each name in groups of either only-digits
1244
  or no-digits. Only the first eight such groups are considered, and
1245
  after that we just use what's left of the string.
1246

1247
  @type name_list: list
1248
  @param name_list: the names to be sorted
1249
  @rtype: list
1250
  @return: a copy of the name list sorted with our algorithm
1251

1252
  """
1253
  _SORTER_BASE = "(\D+|\d+)"
1254
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
1255
                                                  _SORTER_BASE, _SORTER_BASE,
1256
                                                  _SORTER_BASE, _SORTER_BASE,
1257
                                                  _SORTER_BASE, _SORTER_BASE)
1258
  _SORTER_RE = re.compile(_SORTER_FULL)
1259
  _SORTER_NODIGIT = re.compile("^\D*$")
1260
  def _TryInt(val):
1261
    """Attempts to convert a variable to integer."""
1262
    if val is None or _SORTER_NODIGIT.match(val):
1263
      return val
1264
    rval = int(val)
1265
    return rval
1266

    
1267
  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
1268
             for name in name_list]
1269
  to_sort.sort()
1270
  return [tup[1] for tup in to_sort]
1271

    
1272

    
1273
def TryConvert(fn, val):
1274
  """Try to convert a value ignoring errors.
1275

1276
  This function tries to apply function I{fn} to I{val}. If no
1277
  C{ValueError} or C{TypeError} exceptions are raised, it will return
1278
  the result, else it will return the original value. Any other
1279
  exceptions are propagated to the caller.
1280

1281
  @type fn: callable
1282
  @param fn: function to apply to the value
1283
  @param val: the value to be converted
1284
  @return: The converted value if the conversion was successful,
1285
      otherwise the original value.
1286

1287
  """
1288
  try:
1289
    nv = fn(val)
1290
  except (ValueError, TypeError):
1291
    nv = val
1292
  return nv
1293

    
1294

    
1295
def _GenericIsValidIP(family, ip):
1296
  """Generic internal version of ip validation.
1297

1298
  @type family: int
1299
  @param family: socket.AF_INET | socket.AF_INET6
1300
  @type ip: str
1301
  @param ip: the address to be checked
1302
  @rtype: boolean
1303
  @return: True if ip is valid, False otherwise
1304

1305
  """
1306
  try:
1307
    socket.inet_pton(family, ip)
1308
    return True
1309
  except socket.error:
1310
    return False
1311

    
1312

    
1313
def IsValidIP4(ip):
1314
  """Verifies an IPv4 address.
1315

1316
  This function checks if the given address is a valid IPv4 address.
1317

1318
  @type ip: str
1319
  @param ip: the address to be checked
1320
  @rtype: boolean
1321
  @return: True if ip is valid, False otherwise
1322

1323
  """
1324
  return _GenericIsValidIP(socket.AF_INET, ip)
1325

    
1326

    
1327
def IsValidIP6(ip):
1328
  """Verifies an IPv6 address.
1329

1330
  This function checks if the given address is a valid IPv6 address.
1331

1332
  @type ip: str
1333
  @param ip: the address to be checked
1334
  @rtype: boolean
1335
  @return: True if ip is valid, False otherwise
1336

1337
  """
1338
  return _GenericIsValidIP(socket.AF_INET6, ip)
1339

    
1340

    
1341
def IsValidIP(ip):
1342
  """Verifies an IP address.
1343

1344
  This function checks if the given IP address (both IPv4 and IPv6) is valid.
1345

1346
  @type ip: str
1347
  @param ip: the address to be checked
1348
  @rtype: boolean
1349
  @return: True if ip is valid, False otherwise
1350

1351
  """
1352
  return IsValidIP4(ip) or IsValidIP6(ip)
1353

    
1354

    
1355
def GetAddressFamily(ip):
1356
  """Get the address family of the given address.
1357

1358
  @type ip: str
1359
  @param ip: ip address whose family will be returned
1360
  @rtype: int
1361
  @return: socket.AF_INET or socket.AF_INET6
1362
  @raise errors.GenericError: for invalid addresses
1363

1364
  """
1365
  if IsValidIP6(ip):
1366
    return socket.AF_INET6
1367
  elif IsValidIP4(ip):
1368
    return socket.AF_INET
1369
  else:
1370
    raise errors.GenericError("Address %s not valid" % ip)
1371

    
1372

    
1373
def IsValidShellParam(word):
1374
  """Verifies is the given word is safe from the shell's p.o.v.
1375

1376
  This means that we can pass this to a command via the shell and be
1377
  sure that it doesn't alter the command line and is passed as such to
1378
  the actual command.
1379

1380
  Note that we are overly restrictive here, in order to be on the safe
1381
  side.
1382

1383
  @type word: str
1384
  @param word: the word to check
1385
  @rtype: boolean
1386
  @return: True if the word is 'safe'
1387

1388
  """
1389
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))
1390

    
1391

    
1392
def BuildShellCmd(template, *args):
1393
  """Build a safe shell command line from the given arguments.
1394

1395
  This function will check all arguments in the args list so that they
1396
  are valid shell parameters (i.e. they don't contain shell
1397
  metacharacters). If everything is ok, it will return the result of
1398
  template % args.
1399

1400
  @type template: str
1401
  @param template: the string holding the template for the
1402
      string formatting
1403
  @rtype: str
1404
  @return: the expanded command line
1405

1406
  """
1407
  for word in args:
1408
    if not IsValidShellParam(word):
1409
      raise errors.ProgrammerError("Shell argument '%s' contains"
1410
                                   " invalid characters" % word)
1411
  return template % args
1412

    
1413

    
1414
def FormatUnit(value, units):
1415
  """Formats an incoming number of MiB with the appropriate unit.
1416

1417
  @type value: int
1418
  @param value: integer representing the value in MiB (1048576)
1419
  @type units: char
1420
  @param units: the type of formatting we should do:
1421
      - 'h' for automatic scaling
1422
      - 'm' for MiBs
1423
      - 'g' for GiBs
1424
      - 't' for TiBs
1425
  @rtype: str
1426
  @return: the formatted value (with suffix)
1427

1428
  """
1429
  if units not in ('m', 'g', 't', 'h'):
1430
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))
1431

    
1432
  suffix = ''
1433

    
1434
  if units == 'm' or (units == 'h' and value < 1024):
1435
    if units == 'h':
1436
      suffix = 'M'
1437
    return "%d%s" % (round(value, 0), suffix)
1438

    
1439
  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
1440
    if units == 'h':
1441
      suffix = 'G'
1442
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)
1443

    
1444
  else:
1445
    if units == 'h':
1446
      suffix = 'T'
1447
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)
1448

    
1449

    
1450
def ParseUnit(input_string):
1451
  """Tries to extract number and scale from the given string.
1452

1453
  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
1454
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
1455
  is always an int in MiB.
1456

1457
  """
1458
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
1459
  if not m:
1460
    raise errors.UnitParseError("Invalid format")
1461

    
1462
  value = float(m.groups()[0])
1463

    
1464
  unit = m.groups()[1]
1465
  if unit:
1466
    lcunit = unit.lower()
1467
  else:
1468
    lcunit = 'm'
1469

    
1470
  if lcunit in ('m', 'mb', 'mib'):
1471
    # Value already in MiB
1472
    pass
1473

    
1474
  elif lcunit in ('g', 'gb', 'gib'):
1475
    value *= 1024
1476

    
1477
  elif lcunit in ('t', 'tb', 'tib'):
1478
    value *= 1024 * 1024
1479

    
1480
  else:
1481
    raise errors.UnitParseError("Unknown unit: %s" % unit)
1482

    
1483
  # Make sure we round up
1484
  if int(value) < value:
1485
    value += 1
1486

    
1487
  # Round up to the next multiple of 4
1488
  value = int(value)
1489
  if value % 4:
1490
    value += 4 - value % 4
1491

    
1492
  return value
1493

    
1494

    
1495
def AddAuthorizedKey(file_name, key):
1496
  """Adds an SSH public key to an authorized_keys file.
1497

1498
  @type file_name: str
1499
  @param file_name: path to authorized_keys file
1500
  @type key: str
1501
  @param key: string containing key
1502

1503
  """
1504
  key_fields = key.split()
1505

    
1506
  f = open(file_name, 'a+')
1507
  try:
1508
    nl = True
1509
    for line in f:
1510
      # Ignore whitespace changes
1511
      if line.split() == key_fields:
1512
        break
1513
      nl = line.endswith('\n')
1514
    else:
1515
      if not nl:
1516
        f.write("\n")
1517
      f.write(key.rstrip('\r\n'))
1518
      f.write("\n")
1519
      f.flush()
1520
  finally:
1521
    f.close()
1522

    
1523

    
1524
def RemoveAuthorizedKey(file_name, key):
1525
  """Removes an SSH public key from an authorized_keys file.
1526

1527
  @type file_name: str
1528
  @param file_name: path to authorized_keys file
1529
  @type key: str
1530
  @param key: string containing key
1531

1532
  """
1533
  key_fields = key.split()
1534

    
1535
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1536
  try:
1537
    out = os.fdopen(fd, 'w')
1538
    try:
1539
      f = open(file_name, 'r')
1540
      try:
1541
        for line in f:
1542
          # Ignore whitespace changes while comparing lines
1543
          if line.split() != key_fields:
1544
            out.write(line)
1545

    
1546
        out.flush()
1547
        os.rename(tmpname, file_name)
1548
      finally:
1549
        f.close()
1550
    finally:
1551
      out.close()
1552
  except:
1553
    RemoveFile(tmpname)
1554
    raise
1555

    
1556

    
1557
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
1558
  """Sets the name of an IP address and hostname in /etc/hosts.
1559

1560
  @type file_name: str
1561
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1562
  @type ip: str
1563
  @param ip: the IP address
1564
  @type hostname: str
1565
  @param hostname: the hostname to be added
1566
  @type aliases: list
1567
  @param aliases: the list of aliases to add for the hostname
1568

1569
  """
1570
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1571
  # Ensure aliases are unique
1572
  aliases = UniqueSequence([hostname] + aliases)[1:]
1573

    
1574
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1575
  try:
1576
    out = os.fdopen(fd, 'w')
1577
    try:
1578
      f = open(file_name, 'r')
1579
      try:
1580
        for line in f:
1581
          fields = line.split()
1582
          if fields and not fields[0].startswith('#') and ip == fields[0]:
1583
            continue
1584
          out.write(line)
1585

    
1586
        out.write("%s\t%s" % (ip, hostname))
1587
        if aliases:
1588
          out.write(" %s" % ' '.join(aliases))
1589
        out.write('\n')
1590

    
1591
        out.flush()
1592
        os.fsync(out)
1593
        os.chmod(tmpname, 0644)
1594
        os.rename(tmpname, file_name)
1595
      finally:
1596
        f.close()
1597
    finally:
1598
      out.close()
1599
  except:
1600
    RemoveFile(tmpname)
1601
    raise
1602

    
1603

    
1604
def AddHostToEtcHosts(hostname):
1605
  """Wrapper around SetEtcHostsEntry.
1606

1607
  @type hostname: str
1608
  @param hostname: a hostname that will be resolved and added to
1609
      L{constants.ETC_HOSTS}
1610

1611
  """
1612
  hi = HostInfo(name=hostname)
1613
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])
1614

    
1615

    
1616
def RemoveEtcHostsEntry(file_name, hostname):
1617
  """Removes a hostname from /etc/hosts.
1618

1619
  IP addresses without names are removed from the file.
1620

1621
  @type file_name: str
1622
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1623
  @type hostname: str
1624
  @param hostname: the hostname to be removed
1625

1626
  """
1627
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1628
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1629
  try:
1630
    out = os.fdopen(fd, 'w')
1631
    try:
1632
      f = open(file_name, 'r')
1633
      try:
1634
        for line in f:
1635
          fields = line.split()
1636
          if len(fields) > 1 and not fields[0].startswith('#'):
1637
            names = fields[1:]
1638
            if hostname in names:
1639
              while hostname in names:
1640
                names.remove(hostname)
1641
              if names:
1642
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
1643
              continue
1644

    
1645
          out.write(line)
1646

    
1647
        out.flush()
1648
        os.fsync(out)
1649
        os.chmod(tmpname, 0644)
1650
        os.rename(tmpname, file_name)
1651
      finally:
1652
        f.close()
1653
    finally:
1654
      out.close()
1655
  except:
1656
    RemoveFile(tmpname)
1657
    raise
1658

    
1659

    
1660
def RemoveHostFromEtcHosts(hostname):
1661
  """Wrapper around RemoveEtcHostsEntry.
1662

1663
  @type hostname: str
1664
  @param hostname: hostname that will be resolved and its
1665
      full and shot name will be removed from
1666
      L{constants.ETC_HOSTS}
1667

1668
  """
1669
  hi = HostInfo(name=hostname)
1670
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
1671
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())
1672

    
1673

    
1674
def TimestampForFilename():
1675
  """Returns the current time formatted for filenames.
1676

1677
  The format doesn't contain colons as some shells and applications them as
1678
  separators.
1679

1680
  """
1681
  return time.strftime("%Y-%m-%d_%H_%M_%S")
1682

    
1683

    
1684
def CreateBackup(file_name):
1685
  """Creates a backup of a file.
1686

1687
  @type file_name: str
1688
  @param file_name: file to be backed up
1689
  @rtype: str
1690
  @return: the path to the newly created backup
1691
  @raise errors.ProgrammerError: for invalid file names
1692

1693
  """
1694
  if not os.path.isfile(file_name):
1695
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
1696
                                file_name)
1697

    
1698
  prefix = ("%s.backup-%s." %
1699
            (os.path.basename(file_name), TimestampForFilename()))
1700
  dir_name = os.path.dirname(file_name)
1701

    
1702
  fsrc = open(file_name, 'rb')
1703
  try:
1704
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
1705
    fdst = os.fdopen(fd, 'wb')
1706
    try:
1707
      logging.debug("Backing up %s at %s", file_name, backup_name)
1708
      shutil.copyfileobj(fsrc, fdst)
1709
    finally:
1710
      fdst.close()
1711
  finally:
1712
    fsrc.close()
1713

    
1714
  return backup_name
1715

    
1716

    
1717
def ShellQuote(value):
1718
  """Quotes shell argument according to POSIX.
1719

1720
  @type value: str
1721
  @param value: the argument to be quoted
1722
  @rtype: str
1723
  @return: the quoted value
1724

1725
  """
1726
  if _re_shell_unquoted.match(value):
1727
    return value
1728
  else:
1729
    return "'%s'" % value.replace("'", "'\\''")
1730

    
1731

    
1732
def ShellQuoteArgs(args):
1733
  """Quotes a list of shell arguments.
1734

1735
  @type args: list
1736
  @param args: list of arguments to be quoted
1737
  @rtype: str
1738
  @return: the quoted arguments concatenated with spaces
1739

1740
  """
1741
  return ' '.join([ShellQuote(i) for i in args])
1742

    
1743

    
1744
def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
  """Simple ping implementation using TCP connect(2).

  Check if the given IP is reachable by attempting a TCP connect
  to it.

  @type target: str
  @param target: the IP or hostname to ping
  @type port: int
  @param port: the port to connect to
  @type timeout: int
  @param timeout: the timeout on the connection attempt
  @type live_port_needed: boolean
  @param live_port_needed: whether a closed port will cause the
      function to return failure, as if there was a timeout
  @type source: str or None
  @param source: if specified, will cause the connect to be made
      from this specific source address; failures to bind other
      than C{EADDRNOTAVAIL} will be ignored

  """
  try:
    family = GetAddressFamily(target)
  except errors.GenericError:
    return False

  sock = socket.socket(family, socket.SOCK_STREAM)
  success = False

  if source is not None:
    try:
      sock.bind((source, 0))
    except socket.error, (errcode, _):
      if errcode == errno.EADDRNOTAVAIL:
        success = False

  sock.settimeout(timeout)

  try:
    sock.connect((target, port))
    sock.close()
    success = True
  except socket.timeout:
    success = False
  except socket.error, (errcode, _):
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)

  return success


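# Illustrative usage sketch for TcpPing (the address and port below are
# hypothetical):
#
#   if TcpPing("192.0.2.10", 22, timeout=5, live_port_needed=True):
#     logging.info("SSH port is reachable")

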
def OwnIpAddress(address):
  """Check if the current host has the given IP address.

  This is done by trying to bind the given address. We return True if we
  succeed or False if a socket.error is raised.

  @type address: string
  @param address: the address to check
  @rtype: bool
  @return: True if we own the address

  """
  family = GetAddressFamily(address)
  s = socket.socket(family, socket.SOCK_DGRAM)
  success = False
  try:
    try:
      s.bind((address, 0))
      success = True
    except socket.error:
      success = False
  finally:
    s.close()
  return success


def ListVisibleFiles(path):
  """Returns a list of visible files in a directory.

  @type path: str
  @param path: the directory to enumerate
  @rtype: list
  @return: the list of all files not starting with a dot
  @raise ProgrammerError: if L{path} is not an absolute and normalized path

  """
  if not IsNormAbsPath(path):
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
                                 " absolute/normalized: '%s'" % path)
  files = [i for i in os.listdir(path) if not i.startswith(".")]
  return files


def GetHomeDir(user, default=None):
  """Try to get the homedir of the given user.

  The user can be passed either as a string (denoting the name) or as
  an integer (denoting the user id). If the user is not found, the
  'default' argument is returned, which defaults to None.

  """
  try:
    if isinstance(user, basestring):
      result = pwd.getpwnam(user)
    elif isinstance(user, (int, long)):
      result = pwd.getpwuid(user)
    else:
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
                                   type(user))
  except KeyError:
    return default
  return result.pw_dir


def NewUUID():
  """Returns a random UUID.

  @note: This is a Linux-specific method as it uses the /proc
      filesystem.
  @rtype: str

  """
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")


def GenerateSecret(numbytes=20):
  """Generates a random secret.

  This will generate a pseudo-random secret returning a hex string
  (so that it can be used where an ASCII string is needed).

  @param numbytes: the number of bytes which will be represented by the
      returned string (defaulting to 20, the length of a SHA1 hash)
  @rtype: str
  @return: a hex representation of the pseudo-random sequence

  """
  return os.urandom(numbytes).encode('hex')


def EnsureDirs(dirs):
  """Make required directories, if they don't exist.

  @param dirs: list of tuples (dir_name, dir_mode)
  @type dirs: list of (string, integer)

  """
  for dir_name, dir_mode in dirs:
    try:
      os.mkdir(dir_name, dir_mode)
    except EnvironmentError, err:
      if err.errno != errno.EEXIST:
        raise errors.GenericError("Cannot create needed directory"
                                  " '%s': %s" % (dir_name, err))
    try:
      os.chmod(dir_name, dir_mode)
    except EnvironmentError, err:
      raise errors.GenericError("Cannot change directory permissions on"
                                " '%s': %s" % (dir_name, err))
    if not os.path.isdir(dir_name):
      raise errors.GenericError("%s is not a directory" % dir_name)


def ReadFile(file_name, size=-1):
  """Reads a file.

  @type size: int
  @param size: Read at most size bytes (if negative, entire file)
  @rtype: str
  @return: the (possibly partial) content of the file

  """
  f = open(file_name, "r")
  try:
    return f.read(size)
  finally:
    f.close()


def WriteFile(file_name, fn=None, data=None,
              mode=None, uid=-1, gid=-1,
              atime=None, mtime=None, close=True,
              dry_run=False, backup=False,
              prewrite=None, postwrite=None):
  """(Over)write a file atomically.

  The file_name and either fn (a function taking one argument, the
  file descriptor, and which should write the data to it) or data (the
  contents of the file) must be passed. The other arguments are
  optional and allow setting the file mode, owner and group, and the
  mtime/atime of the file.

  If the function doesn't raise an exception, it has succeeded and the
  target file has the new contents. If the function has raised an
  exception, an existing target file should be unmodified and the
  temporary file should be removed.

  @type file_name: str
  @param file_name: the target filename
  @type fn: callable
  @param fn: content writing function, called with
      file descriptor as parameter
  @type data: str
  @param data: contents of the file
  @type mode: int
  @param mode: file mode
  @type uid: int
  @param uid: the owner of the file
  @type gid: int
  @param gid: the group of the file
  @type atime: int
  @param atime: a custom access time to be set on the file
  @type mtime: int
  @param mtime: a custom modification time to be set on the file
  @type close: boolean
  @param close: whether to close file after writing it
  @type prewrite: callable
  @param prewrite: function to be called before writing content
  @type postwrite: callable
  @param postwrite: function to be called after writing content

  @rtype: None or int
  @return: None if the 'close' parameter evaluates to True,
      otherwise the file descriptor

  @raise errors.ProgrammerError: if any of the arguments are not valid

  """
  if not os.path.isabs(file_name):
    raise errors.ProgrammerError("Path passed to WriteFile is not"
                                 " absolute: '%s'" % file_name)

  if [fn, data].count(None) != 1:
    raise errors.ProgrammerError("fn or data required")

  if [atime, mtime].count(None) == 1:
    raise errors.ProgrammerError("Both atime and mtime must be either"
                                 " set or None")

  if backup and not dry_run and os.path.isfile(file_name):
    CreateBackup(file_name)

  dir_name, base_name = os.path.split(file_name)
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
  do_remove = True
  # here we need to make sure we remove the temp file, if any error
  # leaves it in place
  try:
    if uid != -1 or gid != -1:
      os.chown(new_name, uid, gid)
    if mode:
      os.chmod(new_name, mode)
    if callable(prewrite):
      prewrite(fd)
    if data is not None:
      os.write(fd, data)
    else:
      fn(fd)
    if callable(postwrite):
      postwrite(fd)
    os.fsync(fd)
    if atime is not None and mtime is not None:
      os.utime(new_name, (atime, mtime))
    if not dry_run:
      os.rename(new_name, file_name)
      do_remove = False
  finally:
    if close:
      os.close(fd)
      result = None
    else:
      result = fd
    if do_remove:
      RemoveFile(new_name)

  return result


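# Illustrative usage sketches for WriteFile (paths and contents below are
# hypothetical):
#
#   WriteFile("/tmp/example.txt", data="hello\n", mode=0644)
#   WriteFile("/tmp/example.txt", fn=lambda fd: os.write(fd, "hello\n"),
#             backup=True)

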
def ReadOneLineFile(file_name, strict=False):
  """Return the first non-empty line from a file.

  @type strict: boolean
  @param strict: if True, abort if the file has more than one
      non-empty line

  """
  file_lines = ReadFile(file_name).splitlines()
  full_lines = filter(bool, file_lines)
  if not file_lines or not full_lines:
    raise errors.GenericError("No data in one-liner file %s" % file_name)
  elif strict and len(full_lines) > 1:
    raise errors.GenericError("Too many lines in one-liner file %s" %
                              file_name)
  return full_lines[0]


def FirstFree(seq, base=0):
  """Returns the first non-existing integer from seq.

  The seq argument should be a sorted list of positive integers. The
  first time the index of an element is smaller than the element
  value, the index will be returned.

  The base argument is used to start at a different offset,
  i.e. C{[3, 4, 6]} with I{base=3} will return 5.

  Example: C{[0, 1, 3]} will return I{2}.

  @type seq: sequence
  @param seq: the sequence to be analyzed.
  @type base: int
  @param base: use this value as the base index of the sequence
  @rtype: int
  @return: the first non-used index in the sequence

  """
  for idx, elem in enumerate(seq):
    assert elem >= base, "Passed element is higher than base offset"
    if elem > idx + base:
      # idx is not used
      return idx + base
  return None


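# Illustrative sketch for FirstFree:
#
#   FirstFree([0, 1, 3])          # -> 2
#   FirstFree([3, 4, 6], base=3)  # -> 5
#   FirstFree([0, 1, 2])          # -> None (no gap in the sequence)

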
def SingleWaitForFdCondition(fdobj, event, timeout):
  """Waits for a condition to occur on the socket.

  Immediately returns at the first interruption.

  @type fdobj: integer or object supporting a fileno() method
  @param fdobj: entity to wait for events on
  @type event: integer
  @param event: ORed condition (see select module)
  @type timeout: float or None
  @param timeout: Timeout in seconds
  @rtype: int or None
  @return: None for timeout, otherwise occurred conditions

  """
  check = (event | select.POLLPRI |
           select.POLLNVAL | select.POLLHUP | select.POLLERR)

  if timeout is not None:
    # Poller object expects milliseconds
    timeout *= 1000

  poller = select.poll()
  poller.register(fdobj, event)
  try:
    # TODO: If the main thread receives a signal and we have no timeout, we
    # could wait forever. This should check a global "quit" flag or something
    # every so often.
    io_events = poller.poll(timeout)
  except select.error, err:
    if err[0] != errno.EINTR:
      raise
    io_events = []
  if io_events and io_events[0][1] & check:
    return io_events[0][1]
  else:
    return None


class FdConditionWaiterHelper(object):
  """Retry helper for WaitForFdCondition.

  This class contains the retried and wait functions that make sure
  WaitForFdCondition can continue waiting until the timeout is actually
  expired.

  """

  def __init__(self, timeout):
    self.timeout = timeout

  def Poll(self, fdobj, event):
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
    if result is None:
      raise RetryAgain()
    else:
      return result

  def UpdateTimeout(self, timeout):
    self.timeout = timeout


def WaitForFdCondition(fdobj, event, timeout):
  """Waits for a condition to occur on the socket.

  Retries until the timeout is expired, even if interrupted.

  @type fdobj: integer or object supporting a fileno() method
  @param fdobj: entity to wait for events on
  @type event: integer
  @param event: ORed condition (see select module)
  @type timeout: float or None
  @param timeout: Timeout in seconds
  @rtype: int or None
  @return: None for timeout, otherwise occurred conditions

  """
  if timeout is not None:
    retrywaiter = FdConditionWaiterHelper(timeout)
    try:
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
    except RetryTimeout:
      result = None
  else:
    result = None
    while result is None:
      result = SingleWaitForFdCondition(fdobj, event, timeout)
  return result


def UniqueSequence(seq):
  """Returns a list with unique elements.

  Element order is preserved.

  @type seq: sequence
  @param seq: the sequence with the source elements
  @rtype: list
  @return: list of unique elements from seq

  """
  seen = set()
  return [i for i in seq if i not in seen and not seen.add(i)]


def NormalizeAndValidateMac(mac):
  """Normalizes and checks if a MAC address is valid.

  Checks whether the supplied MAC address is formally correct; only
  the colon-separated format is accepted. The address is normalized
  to lower case.

  @type mac: str
  @param mac: the MAC to be validated
  @rtype: str
  @return: returns the normalized and validated MAC.

  @raise errors.OpPrereqError: If the MAC isn't valid

  """
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
  if not mac_check.match(mac):
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
                               mac, errors.ECODE_INVAL)

  return mac.lower()


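# Illustrative sketch for NormalizeAndValidateMac (addresses are hypothetical):
#
#   NormalizeAndValidateMac("AA:00:56:C8:4F:02")  # -> "aa:00:56:c8:4f:02"
#   NormalizeAndValidateMac("aa-00-56-c8-4f-02")  # raises errors.OpPrereqError

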
def TestDelay(duration):
  """Sleep for a fixed amount of time.

  @type duration: float
  @param duration: the sleep duration
  @rtype: (boolean, str)
  @return: (False, error message) for a negative duration, (True, None)
      otherwise

  """
  if duration < 0:
    return False, "Invalid sleep duration"
  time.sleep(duration)
  return True, None


def _CloseFDNoErr(fd, retries=5):
  """Close a file descriptor ignoring errors.

  @type fd: int
  @param fd: the file descriptor
  @type retries: int
  @param retries: how many retries to make, in case we get any
      other error than EBADF

  """
  try:
    os.close(fd)
  except OSError, err:
    if err.errno != errno.EBADF:
      if retries > 0:
        _CloseFDNoErr(fd, retries - 1)
    # else either it's closed already or we're out of retries, so we
    # ignore this and go on


def CloseFDs(noclose_fds=None):
  """Close file descriptors.

  This closes all file descriptors above 2 (i.e. except
  stdin/out/err).

  @type noclose_fds: list or None
  @param noclose_fds: if given, it denotes a list of file descriptors
      that should not be closed

  """
  # Default maximum for the number of available file descriptors.
  if 'SC_OPEN_MAX' in os.sysconf_names:
    try:
      MAXFD = os.sysconf('SC_OPEN_MAX')
      if MAXFD < 0:
        MAXFD = 1024
    except OSError:
      MAXFD = 1024
  else:
    MAXFD = 1024
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
  if (maxfd == resource.RLIM_INFINITY):
    maxfd = MAXFD

  # Iterate through and close all file descriptors (except the standard ones)
  for fd in range(3, maxfd):
    if noclose_fds and fd in noclose_fds:
      continue
    _CloseFDNoErr(fd)


def Mlockall():
  """Lock the current process' virtual address space into RAM.

  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
  see mlock(2) for more details. This function requires the ctypes
  module.

  """
  if ctypes is None:
    logging.warning("Cannot set memory lock, ctypes module not found")
    return

  libc = ctypes.cdll.LoadLibrary("libc.so.6")
  if libc is None:
    logging.error("Cannot set memory lock, ctypes cannot load libc")
    return

  # Some older versions of the ctypes module don't have built-in functionality
  # to access the errno global variable, where function error codes are stored.
  # By declaring this variable as a pointer to an integer we can then access
  # its value correctly, should the mlockall call fail, in order to see what
  # the actual error code was.
  # pylint: disable-msg=W0212
  libc.__errno_location.restype = ctypes.POINTER(ctypes.c_int)

  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
    # pylint: disable-msg=W0212
    logging.error("Cannot set memory lock: %s",
                  os.strerror(libc.__errno_location().contents.value))
    return

  logging.debug("Memory lock set")


def Daemonize(logfile, run_uid, run_gid):
  """Daemonize the current process.

  This detaches the current process from the controlling terminal and
  runs it in the background as a daemon.

  @type logfile: str
  @param logfile: the logfile to which we should redirect stdout/stderr
  @type run_uid: int
  @param run_uid: Run the child under this uid
  @type run_gid: int
  @param run_gid: Run the child under this gid
  @rtype: int
  @return: the value zero

  """
  # pylint: disable-msg=W0212
  # yes, we really want os._exit
  UMASK = 077
  WORKDIR = "/"

  # this might fail
  pid = os.fork()
  if (pid == 0):  # The first child.
    os.setsid()
    # FIXME: When removing again and moving to start-stop-daemon privilege drop
    #        make sure to check for config permission and bail out when invoked
    #        with wrong user.
    os.setgid(run_gid)
    os.setuid(run_uid)
    # this might fail
    pid = os.fork() # Fork a second child.
    if (pid == 0):  # The second child.
      os.chdir(WORKDIR)
      os.umask(UMASK)
    else:
      # exit() or _exit()?  See below.
      os._exit(0) # Exit parent (the first child) of the second child.
  else:
    os._exit(0) # Exit parent of the first child.

  for fd in range(3):
    _CloseFDNoErr(fd)
  i = os.open("/dev/null", os.O_RDONLY) # stdin
  assert i == 0, "Can't close/reopen stdin"
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
  assert i == 1, "Can't close/reopen stdout"
  # Duplicate standard output to standard error.
  os.dup2(1, 2)
  return 0


def DaemonPidFileName(name):
  """Compute a ganeti pid file absolute path.

  @type name: str
  @param name: the daemon name
  @rtype: str
  @return: the full path to the pidfile corresponding to the given
      daemon name

  """
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)


def EnsureDaemon(name):
  """Check for and start daemon if not alive.

  """
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
  if result.failed:
    logging.error("Can't start daemon '%s', failure %s, output: %s",
                  name, result.fail_reason, result.output)
    return False

  return True


def StopDaemon(name):
  """Stop daemon

  """
  result = RunCmd([constants.DAEMON_UTIL, "stop", name])
  if result.failed:
    logging.error("Can't stop daemon '%s', failure %s, output: %s",
                  name, result.fail_reason, result.output)
    return False

  return True


def WritePidFile(name):
  """Write the current process pidfile.

  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}

  @type name: str
  @param name: the daemon name to use
  @raise errors.GenericError: if the pid file already exists and
      points to a live process

  """
  pid = os.getpid()
  pidfilename = DaemonPidFileName(name)
  if IsProcessAlive(ReadPidFile(pidfilename)):
    raise errors.GenericError("%s contains a live process" % pidfilename)

  WriteFile(pidfilename, data="%d\n" % pid)


def RemovePidFile(name):
  """Remove the current process pidfile.

  Any errors are ignored.

  @type name: str
  @param name: the daemon name used to derive the pidfile name

  """
  pidfilename = DaemonPidFileName(name)
  # TODO: we could check here that the file contains our pid
  try:
    RemoveFile(pidfilename)
  except: # pylint: disable-msg=W0702
    pass


def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
                waitpid=False):
  """Kill a process given by its pid.

  @type pid: int
  @param pid: The PID to terminate.
  @type signal_: int
  @param signal_: The signal to send, by default SIGTERM
  @type timeout: int
  @param timeout: The timeout after which, if the process is still alive,
                  a SIGKILL will be sent. If not positive, no such checking
                  will be done
  @type waitpid: boolean
  @param waitpid: If true, we should waitpid on this process after
      sending signals, since it's our own child and otherwise it
      would remain as zombie

  """
  def _helper(pid, signal_, wait):
    """Simple helper to encapsulate the kill/waitpid sequence"""
    if IgnoreProcessNotFound(os.kill, pid, signal_) and wait:
      try:
        os.waitpid(pid, os.WNOHANG)
      except OSError:
        pass

  if pid <= 0:
    # kill with pid=0 == suicide
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)

  if not IsProcessAlive(pid):
    return

  _helper(pid, signal_, waitpid)

  if timeout <= 0:
    return

  def _CheckProcess():
    if not IsProcessAlive(pid):
      return

    try:
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
    except OSError:
      raise RetryAgain()

    if result_pid > 0:
      return

    raise RetryAgain()

  try:
    # Wait up to $timeout seconds
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
  except RetryTimeout:
    pass

  if IsProcessAlive(pid):
    # Kill process if it's still alive
    _helper(pid, signal.SIGKILL, waitpid)


def FindFile(name, search_path, test=os.path.exists):
  """Look for a filesystem object in a given path.

  This is an abstract method to search for filesystem object (files,
  dirs) under a given search path.

  @type name: str
  @param name: the name to look for
  @type search_path: str
  @param search_path: location to start at
  @type test: callable
  @param test: a function taking one argument that should return True
      if a given object is valid; the default value is
      os.path.exists, causing only existing files to be returned
  @rtype: str or None
  @return: full path to the object if found, None otherwise

  """
  # validate the filename mask
  if constants.EXT_PLUGIN_MASK.match(name) is None:
    logging.critical("Invalid value passed for external script name: '%s'",
                     name)
    return None

  for dir_name in search_path:
    # FIXME: investigate switch to PathJoin
    item_name = os.path.sep.join([dir_name, name])
    # check the user test and that we're indeed resolving to the given
    # basename
    if test(item_name) and os.path.basename(item_name) == name:
      return item_name
  return None


def CheckVolumeGroupSize(vglist, vgname, minsize):
  """Checks if the volume group list is valid.

  The function will check if a given volume group is in the list of
  volume groups and has a minimum size.

  @type vglist: dict
  @param vglist: dictionary of volume group names and their size
  @type vgname: str
  @param vgname: the volume group we should check
  @type minsize: int
  @param minsize: the minimum size we accept
  @rtype: None or str
  @return: None for success, otherwise the error message

  """
  vgsize = vglist.get(vgname, None)
  if vgsize is None:
    return "volume group '%s' missing" % vgname
  elif vgsize < minsize:
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
            (vgname, minsize, vgsize))
  return None


def SplitTime(value):
  """Splits time as floating point number into a tuple.

  @param value: Time in seconds
  @type value: int or float
  @return: Tuple containing (seconds, microseconds)

  """
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)

  assert 0 <= seconds, \
    "Seconds must be larger than or equal to 0, but are %s" % seconds
  assert 0 <= microseconds <= 999999, \
    "Microseconds must be 0-999999, but are %s" % microseconds

  return (int(seconds), int(microseconds))


def MergeTime(timetuple):
  """Merges a tuple into time as a floating point number.

  @param timetuple: Time as tuple, (seconds, microseconds)
  @type timetuple: tuple
  @return: Time as a floating point number expressed in seconds

  """
  (seconds, microseconds) = timetuple

  assert 0 <= seconds, \
    "Seconds must be larger than or equal to 0, but are %s" % seconds
  assert 0 <= microseconds <= 999999, \
    "Microseconds must be 0-999999, but are %s" % microseconds

  return float(seconds) + (float(microseconds) * 0.000001)


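# Illustrative sketch of the SplitTime/MergeTime round-trip:
#
#   SplitTime(1234.5)          # -> (1234, 500000)
#   MergeTime((1234, 500000))  # -> 1234.5

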
def GetDaemonPort(daemon_name):
  """Get the daemon port for this cluster.

  Note that this routine does not read a ganeti-specific file, but
  instead uses C{socket.getservbyname} to allow pre-customization of
  this parameter outside of Ganeti.

  @type daemon_name: string
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
  @rtype: int

  """
  if daemon_name not in constants.DAEMONS_PORTS:
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)

  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
  try:
    port = socket.getservbyname(daemon_name, proto)
  except socket.error:
    port = default_port

  return port


class LogFileHandler(logging.FileHandler):
  """Log handler that doesn't fallback to stderr.

  When an error occurs while writing on the logfile, logging.FileHandler
  tries to log on stderr. This doesn't work in ganeti since stderr is
  redirected to the logfile. This class avoids failures by reporting
  errors to /dev/console.

  """
  def __init__(self, filename, mode="a", encoding=None):
    """Open the specified file and use it as the stream for logging.

    Also open /dev/console to report errors while logging.

    """
    logging.FileHandler.__init__(self, filename, mode, encoding)
    self.console = open(constants.DEV_CONSOLE, "a")

  def handleError(self, record): # pylint: disable-msg=C0103
    """Handle errors which occur during an emit() call.

    Try to handle errors with the FileHandler method; if it fails, write
    to /dev/console.

    """
    try:
      logging.FileHandler.handleError(self, record)
    except Exception: # pylint: disable-msg=W0703
      try:
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
      except Exception: # pylint: disable-msg=W0703
        # Log handler tried everything it could, now just give up
        pass


def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
2640
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
2641
                 console_logging=False):
2642
  """Configures the logging module.
2643

2644
  @type logfile: str
2645
  @param logfile: the filename to which we should log
2646
  @type debug: integer
2647
  @param debug: if greater than zero, enable debug messages, otherwise
2648
      only those at C{INFO} and above level
2649
  @type stderr_logging: boolean
2650
  @param stderr_logging: whether we should also log to the standard error
2651
  @type program: str
2652
  @param program: the name under which we should log messages
2653
  @type multithreaded: boolean
2654
  @param multithreaded: if True, will add the thread name to the log file
2655
  @type syslog: string
2656
  @param syslog: one of 'no', 'yes', 'only':
2657
      - if no, syslog is not used
2658
      - if yes, syslog is used (in addition to file-logging)
2659
      - if only, only syslog is used
2660
  @type console_logging: boolean
2661
  @param console_logging: if True, will use a FileHandler which falls back to
2662
      the system console if logging fails
2663
  @raise EnvironmentError: if we can't open the log file and
2664
      syslog/stderr logging is disabled
2665

2666
  """
2667
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
2668
  sft = program + "[%(process)d]:"
2669
  if multithreaded:
2670
    fmt += "/%(threadName)s"
2671
    sft += " (%(threadName)s)"
2672
  if debug:
2673
    fmt += " %(module)s:%(lineno)s"
2674
    # no debug info for syslog loggers
2675
  fmt += " %(levelname)s %(message)s"
2676
  # yes, we do want the textual level, as remote syslog will probably
2677
  # lose the error level, and it's easier to grep for it
2678
  sft += " %(levelname)s %(message)s"
2679
  formatter = logging.Formatter(fmt)
2680
  sys_fmt = logging.Formatter(sft)
2681

    
2682
  root_logger = logging.getLogger("")
2683
  root_logger.setLevel(logging.NOTSET)
2684

    
2685
  # Remove all previously setup handlers
2686
  for handler in root_logger.handlers:
2687
    handler.close()
2688
    root_logger.removeHandler(handler)
2689

    
2690
  if stderr_logging:
2691
    stderr_handler = logging.StreamHandler()
2692
    stderr_handler.setFormatter(formatter)
2693
    if debug:
2694
      stderr_handler.setLevel(logging.NOTSET)
2695
    else:
2696
      stderr_handler.setLevel(logging.CRITICAL)
2697
    root_logger.addHandler(stderr_handler)
2698

    
2699
  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
2700
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
2701
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
2702
                                                    facility)
2703
    syslog_handler.setFormatter(sys_fmt)
2704
    # Never enable debug over syslog
2705
    syslog_handler.setLevel(logging.INFO)
2706
    root_logger.addHandler(syslog_handler)
2707

    
2708
  if syslog != constants.SYSLOG_ONLY:
2709
    # this can fail, if the logging directories are not setup or we have
2710
    # a permisssion problem; in this case, it's best to log but ignore
2711
    # the error if stderr_logging is True, and if false we re-raise the
2712
    # exception since otherwise we could run but without any logs at all
2713
    try:
2714
      if console_logging:
2715
        logfile_handler = LogFileHandler(logfile)
2716
      else:
2717
        logfile_handler = logging.FileHandler(logfile)
2718
      logfile_handler.setFormatter(formatter)
2719
      if debug:
2720
        logfile_handler.setLevel(logging.DEBUG)
2721
      else:
2722
        logfile_handler.setLevel(logging.INFO)
2723
      root_logger.addHandler(logfile_handler)
2724
    except EnvironmentError:
2725
      if stderr_logging or syslog == constants.SYSLOG_YES:
2726
        logging.exception("Failed to enable logging to file '%s'", logfile)
2727
      else:
2728
        # we need to re-raise the exception
2729
        raise
2730

    
2731

    
2732
def IsNormAbsPath(path):
  """Check whether a path is absolute and also normalized.

  This avoids things like /dir/../../other/path to be valid.

  """
  return os.path.normpath(path) == path and os.path.isabs(path)


def PathJoin(*args):
  """Safe-join a list of path components.

  Requirements:
      - the first argument must be an absolute path
      - no component in the path must have backtracking (e.g. /../),
        since we check for normalization at the end

  @param args: the path components to be joined
  @raise ValueError: for invalid paths

  """
  # ensure we have at least one path passed in
  assert args
  # ensure the first component is an absolute and normalized path name
  root = args[0]
  if not IsNormAbsPath(root):
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
  result = os.path.join(*args)
  # ensure that the whole path is normalized
  if not IsNormAbsPath(result):
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
  # check that we're still under the original prefix
  prefix = os.path.commonprefix([root, result])
  if prefix != root:
    raise ValueError("Error: path joining resulted in different prefix"
                     " (%s != %s)" % (prefix, root))
  return result


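# Illustrative sketch for PathJoin (the paths below are hypothetical):
#
#   PathJoin("/var/lib/ganeti", "config.data")
#   # -> "/var/lib/ganeti/config.data"
#   PathJoin("/var/lib/ganeti", "../etc/passwd")  # raises ValueError

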
def TailFile(fname, lines=20):
  """Return the last lines from a file.

  @note: this function will only read and parse the last 4KB of
      the file; if the lines are very long, it could be that less
      than the requested number of lines are returned

  @param fname: the file name
  @type lines: int
  @param lines: the (maximum) number of lines to return

  """
  fd = open(fname, "r")
  try:
    fd.seek(0, 2)
    pos = fd.tell()
    pos = max(0, pos-4096)
    fd.seek(pos, 0)
    raw_data = fd.read()
  finally:
    fd.close()

  rows = raw_data.splitlines()
  return rows[-lines:]


def FormatTimestampWithTZ(secs):
  """Formats a Unix timestamp with the local timezone.

  """
  return time.strftime("%F %T %Z", time.gmtime(secs))


def _ParseAsn1Generalizedtime(value):
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.

  @type value: string
  @param value: ASN1 GENERALIZEDTIME timestamp

  """
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
  if m:
    # We have an offset
    asn1time = m.group(1)
    hours = int(m.group(2))
    minutes = int(m.group(3))
    utcoffset = (60 * hours) + minutes
  else:
    if not value.endswith("Z"):
      raise ValueError("Missing timezone")
    asn1time = value[:-1]
    utcoffset = 0

  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")

  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)

  return calendar.timegm(tt.utctimetuple())


def GetX509CertValidity(cert):
  """Returns the validity period of the certificate.

  @type cert: OpenSSL.crypto.X509
  @param cert: X509 certificate object

  """
  # The get_notBefore and get_notAfter functions are only supported in
  # pyOpenSSL 0.7 and above.
  try:
    get_notbefore_fn = cert.get_notBefore
  except AttributeError:
    not_before = None
  else:
    not_before_asn1 = get_notbefore_fn()

    if not_before_asn1 is None:
      not_before = None
    else:
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)

  try:
    get_notafter_fn = cert.get_notAfter
  except AttributeError:
    not_after = None
  else:
    not_after_asn1 = get_notafter_fn()

    if not_after_asn1 is None:
      not_after = None
    else:
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)

  return (not_before, not_after)


def _VerifyCertificateInner(expired, not_before, not_after, now,
                            warn_days, error_days):
  """Verifies certificate validity.

  @type expired: bool
  @param expired: Whether pyOpenSSL considers the certificate as expired
  @type not_before: number or None
  @param not_before: Unix timestamp before which certificate is not valid
  @type not_after: number or None
  @param not_after: Unix timestamp after which certificate is invalid
  @type now: number
  @param now: Current time as Unix timestamp
  @type warn_days: number or None
  @param warn_days: How many days before expiration a warning should be
      reported
  @type error_days: number or None
  @param error_days: How many days before expiration an error should be
      reported

  """
  if expired:
    msg = "Certificate is expired"

    if not_before is not None and not_after is not None:
      msg += (" (valid from %s to %s)" %
              (FormatTimestampWithTZ(not_before),
               FormatTimestampWithTZ(not_after)))
    elif not_before is not None:
      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
    elif not_after is not None:
      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)

    return (CERT_ERROR, msg)

  elif not_before is not None and not_before > now:
    return (CERT_WARNING,
            "Certificate not yet valid (valid from %s)" %
            FormatTimestampWithTZ(not_before))

  elif not_after is not None:
    remaining_days = int((not_after - now) / (24 * 3600))

    msg = "Certificate expires in about %d days" % remaining_days

    if error_days is not None and remaining_days <= error_days:
      return (CERT_ERROR, msg)

    if warn_days is not None and remaining_days <= warn_days:
      return (CERT_WARNING, msg)

  return (None, None)


def VerifyX509Certificate(cert, warn_days, error_days):
  """Verifies a certificate for LUVerifyCluster.

  @type cert: OpenSSL.crypto.X509
  @param cert: X509 certificate object
  @type warn_days: number or None
  @param warn_days: How many days before expiration a warning should be
      reported
  @type error_days: number or None
  @param error_days: How many days before expiration an error should be
      reported

  """
  # Depending on the pyOpenSSL version, this can just return (None, None)
  (not_before, not_after) = GetX509CertValidity(cert)

  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
                                 time.time(), warn_days, error_days)


def SignX509Certificate(cert, key, salt):
  """Sign an X509 certificate.

  An RFC822-like signature header is added in front of the certificate.

  @type cert: OpenSSL.crypto.X509
  @param cert: X509 certificate object
  @type key: string
  @param key: Key for HMAC
  @type salt: string
  @param salt: Salt for HMAC
  @rtype: string
  @return: Serialized and signed certificate in PEM format

  """
  if not VALID_X509_SIGNATURE_SALT.match(salt):
    raise errors.GenericError("Invalid salt: %r" % salt)

  # Dumping as PEM here ensures the certificate is in a sane format
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  return ("%s: %s/%s\n\n%s" %
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
           Sha1Hmac(key, cert_pem, salt=salt),
           cert_pem))


def _ExtractX509CertificateSignature(cert_pem):
  """Helper function to extract signature from X509 certificate.

  """
  # Extract signature from original PEM data
  for line in cert_pem.splitlines():
    if line.startswith("---"):
      break

    m = X509_SIGNATURE.match(line.strip())
    if m:
      return (m.group("salt"), m.group("sign"))

  raise errors.GenericError("X509 certificate signature is missing")


def LoadSignedX509Certificate(cert_pem, key):
  """Verifies a signed X509 certificate.

  @type cert_pem: string
  @param cert_pem: Certificate in PEM format and with signature header
  @type key: string
  @param key: Key for HMAC
  @rtype: tuple; (OpenSSL.crypto.X509, string)
  @return: X509 certificate object and salt

  """
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)

  # Load certificate
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)

  # Dump again to ensure it's in a sane format
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
    raise errors.GenericError("X509 certificate signature is invalid")

  return (cert, salt)


def Sha1Hmac(key, text, salt=None):
  """Calculates the HMAC-SHA1 digest of a text.

  HMAC is defined in RFC2104.

  @type key: string
  @param key: Secret key
  @type text: string
  @param text: Text to compute the digest of

  """
  if salt:
    salted_text = salt + text
  else:
    salted_text = text

  return hmac.new(key, salted_text, compat.sha1).hexdigest()


def VerifySha1Hmac(key, text, digest, salt=None):
  """Verifies the HMAC-SHA1 digest of a text.

  HMAC is defined in RFC2104.

  @type key: string
  @param key: Secret key
  @type text: string
  @param text: Text whose digest is being verified
  @type digest: string
  @param digest: Expected digest
  @rtype: bool
  @return: Whether HMAC-SHA1 digest matches

  """
  return digest.lower() == Sha1Hmac(key, text, salt=salt).lower()


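# Illustrative sketch of the HMAC helpers above (key and salt are
# hypothetical):
#
#   digest = Sha1Hmac("secret-key", "some text", salt="abc123")
#   VerifySha1Hmac("secret-key", "some text", digest, salt="abc123")  # -> True

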
def SafeEncode(text):
  """Return a 'safe' version of a source string.

  This function mangles the input string and returns a version that
  should be safe to display/encode as ASCII. To this end, we first
  convert it to ASCII using the 'backslashreplace' encoding which
  should get rid of any non-ASCII chars, and then we process it
  through a loop copied from the string repr sources in Python; we
  don't use string_escape anymore since that escapes single quotes and
  backslashes too, and that is too much; and that escaping is not
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).

  @type text: str or unicode
  @param text: input data
  @rtype: str
  @return: a safe version of text

  """
  if isinstance(text, unicode):
    # only if unicode; if str already, we handle it below
    text = text.encode('ascii', 'backslashreplace')
  resu = ""
  for char in text:
    c = ord(char)
    if char == '\t':
      resu += r'\t'
    elif char == '\n':
      resu += r'\n'
    elif char == '\r':
      resu += r'\r'
    elif c < 32 or c >= 127: # non-printable
      resu += "\\x%02x" % (c & 0xff)
    else:
      resu += char
  return resu


def UnescapeAndSplit(text, sep=","):
3077
  """Split and unescape a string based on a given separator.
3078

3079
  This function splits a string based on a separator where the
3080
  separator itself can be escape in order to be an element of the
3081
  elements. The escaping rules are (assuming coma being the
3082
  separator):
3083
    - a plain , separates the elements
3084
    - a sequence \\\\, (double backslash plus comma) is handled as a
3085
      backslash plus a separator comma
3086
    - a sequence \, (backslash plus comma) is handled as a
3087
      non-separator comma
3088

3089
  @type text: string
3090
  @param text: the string to split
3091
  @type sep: string
3092
  @param text: the separator
3093
  @rtype: string
3094
  @return: a list of strings
3095

3096
  """
3097
  # we split the list by sep (with no escaping at this stage)
3098
  slist = text.split(sep)
3099
  # next, we revisit the elements and if any of them ended with an odd
3100
  # number of backslashes, then we join it with the next
3101
  rlist = []
3102
  while slist:
3103
    e1 = slist.pop(0)
3104
    if e1.endswith("\\"):
3105
      num_b = len(e1) - len(e1.rstrip("\\"))
3106
      if num_b % 2 == 1:
3107
        e2 = slist.pop(0)
3108
        # here the backslashes remain (all), and will be reduced in
3109
        # the next step
3110
        rlist.append(e1 + sep + e2)
3111
        continue
3112
    rlist.append(e1)
3113
  # finally, replace backslash-something with something
3114
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
3115
  return rlist
3116

    
3117

    
3118
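# Illustrative sketch for UnescapeAndSplit:
#
#   UnescapeAndSplit("a,b,c")     # -> ["a", "b", "c"]
#   UnescapeAndSplit(r"a\,b,c")   # -> ["a,b", "c"]   (escaped separator)
#   UnescapeAndSplit(r"a\\,b,c")  # -> ["a\\", "b", "c"]  (escaped backslash)

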
def CommaJoin(names):
  """Nicely join a set of identifiers.

  @param names: set, list or tuple
  @return: a string with the formatted results

  """
  return ", ".join([str(val) for val in names])


def BytesToMebibyte(value):
  """Converts bytes to mebibytes.

  @type value: int
  @param value: Value in bytes
  @rtype: int
  @return: Value in mebibytes

  """
  return int(round(value / (1024.0 * 1024.0), 0))


def CalculateDirectorySize(path):
  """Calculates the size of a directory recursively.

  @type path: string
  @param path: Path to directory
  @rtype: int
  @return: Size in mebibytes

  """
  size = 0

  for (curpath, _, files) in os.walk(path):
    for filename in files:
      st = os.lstat(PathJoin(curpath, filename))
      size += st.st_size

  return BytesToMebibyte(size)


def GetFilesystemStats(path):
  """Returns the total and free space on a filesystem.

  @type path: string
  @param path: Path on filesystem to be examined
  @rtype: tuple
  @return: tuple of (Total space, Free space) in mebibytes

  """
  st = os.statvfs(path)

  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
  return (tsize, fsize)


def RunInSeparateProcess(fn, *args):
  """Runs a function in a separate process.

  Note: Only boolean return values are supported.

  @type fn: callable
  @param fn: Function to be called
  @rtype: bool
  @return: Function's result

  """
  pid = os.fork()
  if pid == 0:
    # Child process
    try:
      # In case the function uses temporary files
      ResetTempfileModule()

      # Call function
      result = int(bool(fn(*args)))
      assert result in (0, 1)
    except: # pylint: disable-msg=W0702
      logging.exception("Error while calling function in separate process")
      # 0 and 1 are reserved for the return value
      result = 33

    os._exit(result) # pylint: disable-msg=W0212

  # Parent process

  # Avoid zombies and check exit code
  (_, status) = os.waitpid(pid, 0)

  if os.WIFSIGNALED(status):
    exitcode = None
    signum = os.WTERMSIG(status)
  else:
    exitcode = os.WEXITSTATUS(status)
    signum = None

  if not (exitcode in (0, 1) and signum is None):
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
                              (exitcode, signum))

  return bool(exitcode)


def IgnoreProcessNotFound(fn, *args, **kwargs):
  """Ignores ESRCH when calling a process-related function.

  ESRCH is raised when a process is not found.

  @rtype: bool
  @return: Whether process was found

  """
  try:
    fn(*args, **kwargs)
  except EnvironmentError, err:
    # Ignore ESRCH
    if err.errno == errno.ESRCH:
      return False
    raise

  return True


def IgnoreSignals(fn, *args, **kwargs):
  """Tries to call a function ignoring failures due to EINTR.

  """
  try:
    return fn(*args, **kwargs)
  except EnvironmentError, err:
    if err.errno == errno.EINTR:
      return None
    else:
      raise
  except (select.error, socket.error), err:
    # In python 2.6 and above select.error is an IOError, so it's handled
    # above, in 2.5 and below it's not, and it's handled here.
    if err.args and err.args[0] == errno.EINTR:
      return None
    else:
      raise


def LockFile(fd):
  """Locks a file using POSIX locks.

  @type fd: int
  @param fd: the file descriptor we need to lock

  """
  try:
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
  except IOError, err:
    if err.errno == errno.EAGAIN:
      raise errors.LockError("File already locked")
    raise


def FormatTime(val):
  """Formats a time value.

  @type val: float or None
  @param val: the timestamp as returned by time.time()
  @return: a string value or N/A if we don't have a valid timestamp

  """
  if val is None or not isinstance(val, (int, float)):
    return "N/A"
  # these two format codes work on Linux, but they are not guaranteed on all
  # platforms
  return time.strftime("%F %T", time.localtime(val))


def FormatSeconds(secs):
  """Formats seconds for easier reading.

  @type secs: number
  @param secs: Number of seconds
  @rtype: string
  @return: Formatted seconds (e.g. "2d 9h 19m 49s")

  """
  parts = []

  secs = round(secs, 0)

  if secs > 0:
    # Negative values would be a bit tricky
    for unit, one in [("d", 24 * 60 * 60), ("h", 60 * 60), ("m", 60)]:
      (complete, secs) = divmod(secs, one)
      if complete or parts:
        parts.append("%d%s" % (complete, unit))

  parts.append("%ds" % secs)

  return " ".join(parts)


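# Illustrative sketch for FormatSeconds:
#
#   FormatSeconds(1)       # -> "1s"
#   FormatSeconds(3730)    # -> "1h 2m 10s"
#   FormatSeconds(207589)  # -> "2d 9h 39m 49s"

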
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
  """Reads the watcher pause file.

  @type filename: string
  @param filename: Path to watcher pause file
  @type now: None, float or int
  @param now: Current time as Unix timestamp
  @type remove_after: int
  @param remove_after: Remove watcher pause file after specified amount of
    seconds past the pause end time

  """
  if now is None:
    now = time.time()

  try:
    value = ReadFile(filename)
  except IOError, err:
    if err.errno != errno.ENOENT:
      raise
    value = None

  if value is not None:
    try:
      value = int(value)
    except ValueError:
      logging.warning(("Watcher pause file (%s) contains invalid value,"
                       " removing it"), filename)
      RemoveFile(filename)
      value = None

    if value is not None:
      # Remove file if it's outdated
      if now > (value + remove_after):
        RemoveFile(filename)
        value = None

      elif now > value:
        value = None

  return value


class RetryTimeout(Exception):
  """Retry loop timed out.

  Any arguments which were passed by the retried function to RetryAgain will
  be preserved in RetryTimeout, if it is raised. If such an argument was an
  exception, the RaiseInner helper method will reraise it.

  """
  def RaiseInner(self):
    if self.args and isinstance(self.args[0], Exception):
      raise self.args[0]
    else:
      raise RetryTimeout(*self.args)


class RetryAgain(Exception):
  """Retry again.

  Any arguments passed to RetryAgain will be preserved, if a timeout occurs,
  as arguments to RetryTimeout. If an exception is passed, the RaiseInner()
  method of the RetryTimeout exception can be used to reraise it.

  """


class _RetryDelayCalculator(object):
  """Calculator for increasing delays.

  """
  __slots__ = [
    "_factor",
    "_limit",
    "_next",
    "_start",
    ]

  def __init__(self, start, factor, limit):
    """Initializes this class.

    @type start: float
    @param start: Initial delay
    @type factor: float
    @param factor: Factor for delay increase
    @type limit: float or None
    @param limit: Upper limit for delay or None for no limit

    """
    assert start > 0.0
    assert factor >= 1.0
    assert limit is None or limit >= 0.0

    self._start = start
    self._factor = factor
    self._limit = limit

    self._next = start

  def __call__(self):
    """Returns current delay and calculates the next one.

    """
    current = self._next

    # Update for next run
    if self._limit is None or self._next < self._limit:
      self._next = min(self._limit, self._next * self._factor)

    return current


#: Special delay to specify whole remaining timeout
RETRY_REMAINING_TIME = object()


def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
3435
          _time_fn=time.time):
3436
  """Call a function repeatedly until it succeeds.
3437

3438
  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
3439
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
3440
  total of C{timeout} seconds, this function throws L{RetryTimeout}.
3441

3442
  C{delay} can be one of the following:
3443
    - callable returning the delay length as a float
3444
    - Tuple of (start, factor, limit)
3445
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
3446
      useful when overriding L{wait_fn} to wait for an external event)
3447
    - A static delay as a number (int or float)
3448

3449
  @type fn: callable
3450
  @param fn: Function to be called
3451
  @param delay: Either a callable (returning the delay), a tuple of (start,
3452
                factor, limit) (see L{_RetryDelayCalculator}),
3453
                L{RETRY_REMAINING_TIME} or a number (int or float)
3454
  @type timeout: float
3455
  @param timeout: Total timeout
3456
  @type wait_fn: callable
3457
  @param wait_fn: Waiting function
3458
  @return: Return value of function
3459

3460
  """
3461
  assert callable(fn)
3462
  assert callable(wait_fn)
3463
  assert callable(_time_fn)
3464

    
3465
  if args is None:
3466
    args = []
3467

    
3468
  end_time = _time_fn() + timeout
3469

    
3470
  if callable(delay):
3471
    # External function to calculate delay
3472
    calc_delay = delay
3473

    
3474
  elif isinstance(delay, (tuple, list)):
3475
    # Increasing delay with optional upper boundary
3476
    (start, factor, limit) = delay
3477
    calc_delay = _RetryDelayCalculator(start, factor, limit)
3478

    
3479
  elif delay is RETRY_REMAINING_TIME:
3480
    # Always use the remaining time
3481
    calc_delay = None
3482

    
3483
  else:
3484
    # Static delay
3485
    calc_delay = lambda: delay
3486

    
3487
  assert calc_delay is None or callable(calc_delay)
3488

    
3489
  while True:
3490
    retry_args = []
3491
    try:
3492
      # pylint: disable-msg=W0142
3493
      return fn(*args)
3494
    except RetryAgain, err:
3495
      retry_args = err.args
3496
    except RetryTimeout:
3497
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
3498
                                   " handle RetryTimeout")
3499

    
3500
    remaining_time = end_time - _time_fn()
3501

    
3502
    if remaining_time < 0.0:
3503
      # pylint: disable-msg=W0142
3504
      raise RetryTimeout(*retry_args)
3505

    
3506
    assert remaining_time >= 0.0
3507

    
3508
    if calc_delay is None:
3509
      wait_fn(remaining_time)
3510
    else:
3511
      current_delay = calc_delay()
3512
      if current_delay > 0.0:
3513
        wait_fn(current_delay)
3514

    
3515

    
3516
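
# Illustrative usage sketch, not part of the original module: how Retry,
# RetryAgain and RetryTimeout are typically combined to poll for a
# condition. The helper name _ExampleWaitForFile and its arguments are
# hypothetical.
def _ExampleWaitForFile(path, timeout=10.0):
  """Waits for C{path} to appear, polling with an increasing delay.

  """
  def _CheckFile():
    if not os.path.exists(path):
      # Arguments given to RetryAgain are preserved in RetryTimeout
      raise RetryAgain(path)
    return path

  try:
    # Delay starts at 0.1s, grows by factor 1.5 and is capped at 2.0 seconds
    return Retry(_CheckFile, (0.1, 1.5, 2.0), timeout)
  except RetryTimeout, err:
    # err.args holds whatever the last RetryAgain carried (here, the path)
    raise errors.GenericError("Timed out waiting for %s" % err.args[0])
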
def GetClosedTempfile(*args, **kwargs):
  """Creates a temporary file and returns its path.

  """
  (fd, path) = tempfile.mkstemp(*args, **kwargs)
  _CloseFDNoErr(fd)
  return path


def GenerateSelfSignedX509Cert(common_name, validity):
  """Generates a self-signed X509 certificate.

  @type common_name: string
  @param common_name: commonName value
  @type validity: int
  @param validity: Validity for certificate in seconds

  """
  # Create private and public key
  key = OpenSSL.crypto.PKey()
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)

  # Create self-signed certificate
  cert = OpenSSL.crypto.X509()
  if common_name:
    cert.get_subject().CN = common_name
  cert.set_serial_number(1)
  cert.gmtime_adj_notBefore(0)
  cert.gmtime_adj_notAfter(validity)
  cert.set_issuer(cert.get_subject())
  cert.set_pubkey(key)
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)

  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)

  return (key_pem, cert_pem)


def GenerateSelfSignedSslCert(filename, validity=(5 * 365)):
  """Legacy function to generate self-signed X509 certificate.

  """
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(None,
                                                   validity * 24 * 60 * 60)

  WriteFile(filename, mode=0400, data=key_pem + cert_pem)
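
# Illustrative usage sketch, not part of the original module: generating a
# throw-away certificate valid for one day and writing key and certificate
# to separate files. The helper name, the common name and the default paths
# are hypothetical.
def _ExampleGenerateCert(key_path="/tmp/example.key",
                         cert_path="/tmp/example.pem"):
  """Creates a short-lived self-signed certificate pair.

  """
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert("example.ganeti.org",
                                                   24 * 60 * 60)
  WriteFile(key_path, mode=0400, data=key_pem)
  WriteFile(cert_path, mode=0644, data=cert_pem)
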
class FileLock(object):
  """Utility class for file locks.

  """
  def __init__(self, fd, filename):
    """Constructor for FileLock.

    @type fd: file
    @param fd: File object
    @type filename: str
    @param filename: Path of the file opened at I{fd}

    """
    self.fd = fd
    self.filename = filename

  @classmethod
  def Open(cls, filename):
    """Creates and opens a file to be used as a file-based lock.

    @type filename: string
    @param filename: path to the file to be locked

    """
    # Using "os.open" is necessary to allow both opening existing file
    # read/write and creating if not existing. Vanilla "open" will truncate an
    # existing file -or- allow creating if not existing.
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
               filename)

  def __del__(self):
    self.Close()

  def Close(self):
    """Close the file and release the lock.

    """
    if hasattr(self, "fd") and self.fd:
      self.fd.close()
      self.fd = None

  def _flock(self, flag, blocking, timeout, errmsg):
    """Wrapper for fcntl.flock.

    @type flag: int
    @param flag: operation flag
    @type blocking: bool
    @param blocking: whether the operation should be done in blocking mode.
    @type timeout: None or float
    @param timeout: for how long the operation should be retried (implies
                    non-blocking mode).
    @type errmsg: string
    @param errmsg: error message in case operation fails.

    """
    assert self.fd, "Lock was closed"
    assert timeout is None or timeout >= 0, \
      "If specified, timeout must be positive"
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"

    # When a timeout is used, LOCK_NB must always be set
    if not (timeout is None and blocking):
      flag |= fcntl.LOCK_NB

    if timeout is None:
      self._Lock(self.fd, flag, timeout)
    else:
      try:
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
              args=(self.fd, flag, timeout))
      except RetryTimeout:
        raise errors.LockError(errmsg)

  @staticmethod
  def _Lock(fd, flag, timeout):
    try:
      fcntl.flock(fd, flag)
    except IOError, err:
      if timeout is not None and err.errno == errno.EAGAIN:
        raise RetryAgain()

      logging.exception("fcntl.flock failed")
      raise

  def Exclusive(self, blocking=False, timeout=None):
    """Locks the file in exclusive mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_EX, blocking, timeout,
                "Failed to lock %s in exclusive mode" % self.filename)

  def Shared(self, blocking=False, timeout=None):
    """Locks the file in shared mode.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_SH, blocking, timeout,
                "Failed to lock %s in shared mode" % self.filename)

  def Unlock(self, blocking=True, timeout=None):
    """Unlocks the file.

    According to C{flock(2)}, unlocking can also be a nonblocking
    operation::

      To make a non-blocking request, include LOCK_NB with any of the above
      operations.

    @type blocking: boolean
    @param blocking: whether to block and wait until we
        can lock the file or return immediately
    @type timeout: int or None
    @param timeout: if not None, the duration to wait for the lock
        (in blocking mode)

    """
    self._flock(fcntl.LOCK_UN, blocking, timeout,
                "Failed to unlock %s" % self.filename)
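
# Illustrative usage sketch, not part of the original module: acquiring an
# exclusive lock on a lock file, waiting up to ten seconds for other
# holders to release it. The helper name and the default path are
# hypothetical.
def _ExampleWithFileLock(lock_path="/var/lock/example.lock"):
  """Runs a critical section under an exclusive file lock.

  """
  lock = FileLock.Open(lock_path)
  try:
    # Raises errors.LockError if the lock cannot be acquired within 10s
    lock.Exclusive(blocking=True, timeout=10)
    # ... critical section goes here ...
  finally:
    lock.Close()
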
class LineSplitter:
  """Splits data chunks into lines separated by newline.

  Instances provide a file-like interface.

  """
  def __init__(self, line_fn, *args):
    """Initializes this class.

    @type line_fn: callable
    @param line_fn: Function called for each line, first parameter is line
    @param args: Extra arguments for L{line_fn}

    """
    assert callable(line_fn)

    if args:
      # Python 2.4 doesn't have functools.partial yet
      self._line_fn = \
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
    else:
      self._line_fn = line_fn

    self._lines = collections.deque()
    self._buffer = ""

  def write(self, data):
    parts = (self._buffer + data).split("\n")
    self._buffer = parts.pop()
    self._lines.extend(parts)

  def flush(self):
    while self._lines:
      self._line_fn(self._lines.popleft().rstrip("\r\n"))

  def close(self):
    self.flush()
    if self._buffer:
      self._line_fn(self._buffer)
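
# Illustrative usage sketch, not part of the original module: feeding
# chunked data to a LineSplitter, which forwards complete lines to a
# callback (logging.debug here). The helper name is hypothetical.
def _ExampleSplitLines(chunks, line_fn=logging.debug):
  """Pushes the given data chunks through a LineSplitter.

  """
  splitter = LineSplitter(line_fn)
  try:
    for chunk in chunks:
      # E.g. ["foo\nba", "r\n"] results in the lines "foo" and "bar"
      splitter.write(chunk)
      splitter.flush()
  finally:
    # Also emits any trailing data without a final newline
    splitter.close()
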
def SignalHandled(signums):
  """Signal Handled decoration.

  This special decorator installs a signal handler and then calls the target
  function. The function must accept a 'signal_handlers' keyword argument,
  which will contain a dict indexed by signal number, with SignalHandler
  objects as values.

  The decorator can be safely stacked with itself, to handle multiple signals
  with different handlers.

  @type signums: list
  @param signums: signals to intercept

  """
  def wrap(fn):
    def sig_function(*args, **kwargs):
      assert 'signal_handlers' not in kwargs or \
             kwargs['signal_handlers'] is None or \
             isinstance(kwargs['signal_handlers'], dict), \
             "Wrong signal_handlers parameter in original function call"
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
        signal_handlers = kwargs['signal_handlers']
      else:
        signal_handlers = {}
        kwargs['signal_handlers'] = signal_handlers
      sighandler = SignalHandler(signums)
      try:
        for sig in signums:
          signal_handlers[sig] = sighandler
        return fn(*args, **kwargs)
      finally:
        sighandler.Reset()
    return sig_function
  return wrap
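
# Illustrative usage sketch, not part of the original module: a function
# decorated with SignalHandled receives the installed handlers through the
# 'signal_handlers' keyword argument and can poll their 'called' flag. The
# helper name is hypothetical.
@SignalHandled([signal.SIGTERM])
def _ExampleLoopUntilSigterm(signal_handlers=None):
  """Sleeps until SIGTERM has been received.

  """
  handler = signal_handlers[signal.SIGTERM]
  while not handler.called:
    time.sleep(1)
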
class SignalWakeupFd(object):
  try:
    # This is only supported in Python 2.5 and above (some distributions
    # backported it to Python 2.4)
    _set_wakeup_fd_fn = signal.set_wakeup_fd
  except AttributeError:
    # Not supported
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
      return -1
  else:
    def _SetWakeupFd(self, fd):
      return self._set_wakeup_fd_fn(fd)

  def __init__(self):
    """Initializes this class.

    """
    (read_fd, write_fd) = os.pipe()

    # Once these succeeded, the file descriptors will be closed automatically.
    # Buffer size 0 is important, otherwise .read() with a specified length
    # might buffer data and the file descriptors won't be marked readable.
    self._read_fh = os.fdopen(read_fd, "r", 0)
    self._write_fh = os.fdopen(write_fd, "w", 0)

    self._previous = self._SetWakeupFd(self._write_fh.fileno())

    # Utility functions
    self.fileno = self._read_fh.fileno
    self.read = self._read_fh.read

  def Reset(self):
    """Restores the previous wakeup file descriptor.

    """
    if hasattr(self, "_previous") and self._previous is not None:
      self._SetWakeupFd(self._previous)
      self._previous = None

  def Notify(self):
    """Notifies the wakeup file descriptor.

    """
    self._write_fh.write("\0")

  def __del__(self):
    """Called before object deletion.

    """
    self.Reset()


class SignalHandler(object):
  """Generic signal handler class.

  It automatically restores the original handler when destroyed or
  when L{Reset} is called. You can either pass your own handler
  function in or query the L{called} attribute to detect whether the
  signal was sent.

  @type signum: list
  @ivar signum: the signals we handle
  @type called: boolean
  @ivar called: tracks whether any of the signals have been raised

  """
  def __init__(self, signum, handler_fn=None, wakeup=None):
    """Constructs a new SignalHandler instance.

    @type signum: int or list of ints
    @param signum: Single signal number or set of signal numbers
    @type handler_fn: callable
    @param handler_fn: Signal handling function
    @type wakeup: L{SignalWakeupFd}
    @param wakeup: Optional wakeup file descriptor notified when a signal is
        received

    """
    assert handler_fn is None or callable(handler_fn)

    self.signum = set(signum)
    self.called = False

    self._handler_fn = handler_fn
    self._wakeup = wakeup

    self._previous = {}
    try:
      for signum in self.signum:
        # Setup handler
        prev_handler = signal.signal(signum, self._HandleSignal)
        try:
          self._previous[signum] = prev_handler
        except:
          # Restore previous handler
          signal.signal(signum, prev_handler)
          raise
    except:
      # Reset all handlers
      self.Reset()
      # Here we have a race condition: a handler may have already been called,
      # but there's not much we can do about it at this point.
      raise

  def __del__(self):
    self.Reset()

  def Reset(self):
    """Restore previous handler.

    This will reset all the signals to their previous handlers.

    """
    for signum, prev_handler in self._previous.items():
      signal.signal(signum, prev_handler)
      # If successful, remove from dict
      del self._previous[signum]

  def Clear(self):
    """Unsets the L{called} flag.

    This function can be used in case a signal may arrive several times.

    """
    self.called = False

  def _HandleSignal(self, signum, frame):
    """Actual signal handling function.

    """
    # This is not nice and not absolutely atomic, but it appears to be the only
    # solution in Python -- there are no atomic types.
    self.called = True

    if self._wakeup:
      # Notify whoever is interested in signals
      self._wakeup.Notify()

    if self._handler_fn:
      self._handler_fn(signum, frame)
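
# Illustrative usage sketch, not part of the original module: installing a
# temporary SIGUSR1 handler, polling its 'called' flag for a while and
# restoring the previous handler afterwards. The helper name is
# hypothetical.
def _ExampleWaitForSigusr1(timeout=60.0):
  """Polls for SIGUSR1 for at most C{timeout} seconds.

  """
  handler = SignalHandler([signal.SIGUSR1])
  try:
    end_time = time.time() + timeout
    while not handler.called and time.time() < end_time:
      time.sleep(0.1)
    return handler.called
  finally:
    handler.Reset()
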
class FieldSet(object):
  """A simple field set.

  Among the features are:
    - checking if a string is among a list of static strings or regex objects
    - checking if a whole list of strings matches
    - returning the matching groups from a regex match

  Internally, all fields are held as regular expression objects.

  """
  def __init__(self, *items):
    self.items = [re.compile("^%s$" % value) for value in items]

  def Extend(self, other_set):
    """Extend the field set with the items from another one"""
    self.items.extend(other_set.items)

  def Matches(self, field):
    """Checks if a field matches the current set

    @type field: str
    @param field: the string to match
    @return: either None or a regular expression match object

    """
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
      return m
    return None

  def NonMatching(self, items):
    """Returns the list of fields not matching the current set

    @type items: list
    @param items: the list of fields to check
    @rtype: list
    @return: list of non-matching fields

    """
    return [val for val in items if not self.Matches(val)]
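
# Illustrative usage sketch, not part of the original module: a field set
# mixing literal names and a regular expression, used to reject unknown
# field names. The helper name and the field names are hypothetical.
def _ExampleCheckFields(requested):
  """Returns the requested fields not covered by a sample field set.

  """
  fields = FieldSet("name", "uuid", r"tag/(.*)")
  # E.g. ["name", "tag/foo", "bogus"] -> ["bogus"]
  return fields.NonMatching(requested)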