Statistics
| Branch: | Tag: | Revision:

root / lib / utils / __init__.py @ 9d1b963f

History | View | Annotate | Download (79.7 kB)

1
#
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010, 2011 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import os
31
import sys
32
import time
33
import subprocess
34
import re
35
import socket
36
import tempfile
37
import shutil
38
import errno
39
import pwd
40
import itertools
41
import select
42
import fcntl
43
import resource
44
import logging
45
import signal
46
import OpenSSL
47
import datetime
48
import calendar
49

    
50
from cStringIO import StringIO
51

    
52
from ganeti import errors
53
from ganeti import constants
54
from ganeti import compat
55

    
56
from ganeti.utils.algo import * # pylint: disable-msg=W0401
57
from ganeti.utils.retry import * # pylint: disable-msg=W0401
58
from ganeti.utils.text import * # pylint: disable-msg=W0401
59
from ganeti.utils.mlock import * # pylint: disable-msg=W0401
60
from ganeti.utils.log import * # pylint: disable-msg=W0401
61
from ganeti.utils.hash import * # pylint: disable-msg=W0401
62
from ganeti.utils.wrapper import * # pylint: disable-msg=W0401
63
from ganeti.utils.filelock import * # pylint: disable-msg=W0401
64

    
65

    
66
#: when set to True, L{RunCmd} is disabled
67
_no_fork = False
68

    
69
_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
70

    
71
HEX_CHAR_RE = r"[a-zA-Z0-9]"
72
VALID_X509_SIGNATURE_SALT = re.compile("^%s+$" % HEX_CHAR_RE, re.S)
73
X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
74
                            (re.escape(constants.X509_CERT_SIGNATURE_HEADER),
75
                             HEX_CHAR_RE, HEX_CHAR_RE),
76
                            re.S | re.I)
77

    
78
_VALID_SERVICE_NAME_RE = re.compile("^[-_.a-zA-Z0-9]{1,128}$")
79

    
80
UUID_RE = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
81
                     '[a-f0-9]{4}-[a-f0-9]{12}$')
82

    
83
# Certificate verification results
84
(CERT_WARNING,
85
 CERT_ERROR) = range(1, 3)
86

    
87
(_TIMEOUT_NONE,
88
 _TIMEOUT_TERM,
89
 _TIMEOUT_KILL) = range(3)
90

    
91
#: Shell param checker regexp
92
_SHELLPARAM_REGEX = re.compile(r"^[-a-zA-Z0-9._+/:%@]+$")
93

    
94
#: ASN1 time regexp
95
_ASN1_TIME_REGEX = re.compile(r"^(\d+)([-+]\d\d)(\d\d)$")
96

    
97

    
98
def DisableFork():
99
  """Disables the use of fork(2).
100

101
  """
102
  global _no_fork # pylint: disable-msg=W0603
103

    
104
  _no_fork = True
105

    
106

    
107
class RunResult(object):
108
  """Holds the result of running external programs.
109

110
  @type exit_code: int
111
  @ivar exit_code: the exit code of the program, or None (if the program
112
      didn't exit())
113
  @type signal: int or None
114
  @ivar signal: the signal that caused the program to finish, or None
115
      (if the program wasn't terminated by a signal)
116
  @type stdout: str
117
  @ivar stdout: the standard output of the program
118
  @type stderr: str
119
  @ivar stderr: the standard error of the program
120
  @type failed: boolean
121
  @ivar failed: True in case the program was
122
      terminated by a signal or exited with a non-zero exit code
123
  @ivar fail_reason: a string detailing the termination reason
124

125
  """
126
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
127
               "failed", "fail_reason", "cmd"]
128

    
129

    
130
  def __init__(self, exit_code, signal_, stdout, stderr, cmd, timeout_action,
131
               timeout):
132
    self.cmd = cmd
133
    self.exit_code = exit_code
134
    self.signal = signal_
135
    self.stdout = stdout
136
    self.stderr = stderr
137
    self.failed = (signal_ is not None or exit_code != 0)
138

    
139
    fail_msgs = []
140
    if self.signal is not None:
141
      fail_msgs.append("terminated by signal %s" % self.signal)
142
    elif self.exit_code is not None:
143
      fail_msgs.append("exited with exit code %s" % self.exit_code)
144
    else:
145
      fail_msgs.append("unable to determine termination reason")
146

    
147
    if timeout_action == _TIMEOUT_TERM:
148
      fail_msgs.append("terminated after timeout of %.2f seconds" % timeout)
149
    elif timeout_action == _TIMEOUT_KILL:
150
      fail_msgs.append(("force termination after timeout of %.2f seconds"
151
                        " and linger for another %.2f seconds") %
152
                       (timeout, constants.CHILD_LINGER_TIMEOUT))
153

    
154
    if fail_msgs and self.failed:
155
      self.fail_reason = CommaJoin(fail_msgs)
156

    
157
    if self.failed:
158
      logging.debug("Command '%s' failed (%s); output: %s",
159
                    self.cmd, self.fail_reason, self.output)
160

    
161
  def _GetOutput(self):
162
    """Returns the combined stdout and stderr for easier usage.
163

164
    """
165
    return self.stdout + self.stderr
166

    
167
  output = property(_GetOutput, None, None, "Return full output")
168

    
169

    
170
def _BuildCmdEnvironment(env, reset):
171
  """Builds the environment for an external program.
172

173
  """
174
  if reset:
175
    cmd_env = {}
176
  else:
177
    cmd_env = os.environ.copy()
178
    cmd_env["LC_ALL"] = "C"
179

    
180
  if env is not None:
181
    cmd_env.update(env)
182

    
183
  return cmd_env
184

    
185

    
186
def RunCmd(cmd, env=None, output=None, cwd="/", reset_env=False,
187
           interactive=False, timeout=None):
188
  """Execute a (shell) command.
189

190
  The command should not read from its standard input, as it will be
191
  closed.
192

193
  @type cmd: string or list
194
  @param cmd: Command to run
195
  @type env: dict
196
  @param env: Additional environment variables
197
  @type output: str
198
  @param output: if desired, the output of the command can be
199
      saved in a file instead of the RunResult instance; this
200
      parameter denotes the file name (if not None)
201
  @type cwd: string
202
  @param cwd: if specified, will be used as the working
203
      directory for the command; the default will be /
204
  @type reset_env: boolean
205
  @param reset_env: whether to reset or keep the default os environment
206
  @type interactive: boolean
207
  @param interactive: weather we pipe stdin, stdout and stderr
208
                      (default behaviour) or run the command interactive
209
  @type timeout: int
210
  @param timeout: If not None, timeout in seconds until child process gets
211
                  killed
212
  @rtype: L{RunResult}
213
  @return: RunResult instance
214
  @raise errors.ProgrammerError: if we call this when forks are disabled
215

216
  """
217
  if _no_fork:
218
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")
219

    
220
  if output and interactive:
221
    raise errors.ProgrammerError("Parameters 'output' and 'interactive' can"
222
                                 " not be provided at the same time")
223

    
224
  if isinstance(cmd, basestring):
225
    strcmd = cmd
226
    shell = True
227
  else:
228
    cmd = [str(val) for val in cmd]
229
    strcmd = ShellQuoteArgs(cmd)
230
    shell = False
231

    
232
  if output:
233
    logging.debug("RunCmd %s, output file '%s'", strcmd, output)
234
  else:
235
    logging.debug("RunCmd %s", strcmd)
236

    
237
  cmd_env = _BuildCmdEnvironment(env, reset_env)
238

    
239
  try:
240
    if output is None:
241
      out, err, status, timeout_action = _RunCmdPipe(cmd, cmd_env, shell, cwd,
242
                                                     interactive, timeout)
243
    else:
244
      timeout_action = _TIMEOUT_NONE
245
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
246
      out = err = ""
247
  except OSError, err:
248
    if err.errno == errno.ENOENT:
249
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
250
                               (strcmd, err))
251
    else:
252
      raise
253

    
254
  if status >= 0:
255
    exitcode = status
256
    signal_ = None
257
  else:
258
    exitcode = None
259
    signal_ = -status
260

    
261
  return RunResult(exitcode, signal_, out, err, strcmd, timeout_action, timeout)
262

    
263

    
264
def SetupDaemonEnv(cwd="/", umask=077):
265
  """Setup a daemon's environment.
266

267
  This should be called between the first and second fork, due to
268
  setsid usage.
269

270
  @param cwd: the directory to which to chdir
271
  @param umask: the umask to setup
272

273
  """
274
  os.chdir(cwd)
275
  os.umask(umask)
276
  os.setsid()
277

    
278

    
279
def SetupDaemonFDs(output_file, output_fd):
280
  """Setups up a daemon's file descriptors.
281

282
  @param output_file: if not None, the file to which to redirect
283
      stdout/stderr
284
  @param output_fd: if not None, the file descriptor for stdout/stderr
285

286
  """
287
  # check that at most one is defined
288
  assert [output_file, output_fd].count(None) >= 1
289

    
290
  # Open /dev/null (read-only, only for stdin)
291
  devnull_fd = os.open(os.devnull, os.O_RDONLY)
292

    
293
  if output_fd is not None:
294
    pass
295
  elif output_file is not None:
296
    # Open output file
297
    try:
298
      output_fd = os.open(output_file,
299
                          os.O_WRONLY | os.O_CREAT | os.O_APPEND, 0600)
300
    except EnvironmentError, err:
301
      raise Exception("Opening output file failed: %s" % err)
302
  else:
303
    output_fd = os.open(os.devnull, os.O_WRONLY)
304

    
305
  # Redirect standard I/O
306
  os.dup2(devnull_fd, 0)
307
  os.dup2(output_fd, 1)
308
  os.dup2(output_fd, 2)
309

    
310

    
311
def StartDaemon(cmd, env=None, cwd="/", output=None, output_fd=None,
312
                pidfile=None):
313
  """Start a daemon process after forking twice.
314

315
  @type cmd: string or list
316
  @param cmd: Command to run
317
  @type env: dict
318
  @param env: Additional environment variables
319
  @type cwd: string
320
  @param cwd: Working directory for the program
321
  @type output: string
322
  @param output: Path to file in which to save the output
323
  @type output_fd: int
324
  @param output_fd: File descriptor for output
325
  @type pidfile: string
326
  @param pidfile: Process ID file
327
  @rtype: int
328
  @return: Daemon process ID
329
  @raise errors.ProgrammerError: if we call this when forks are disabled
330

331
  """
332
  if _no_fork:
333
    raise errors.ProgrammerError("utils.StartDaemon() called with fork()"
334
                                 " disabled")
335

    
336
  if output and not (bool(output) ^ (output_fd is not None)):
337
    raise errors.ProgrammerError("Only one of 'output' and 'output_fd' can be"
338
                                 " specified")
339

    
340
  if isinstance(cmd, basestring):
341
    cmd = ["/bin/sh", "-c", cmd]
342

    
343
  strcmd = ShellQuoteArgs(cmd)
344

    
345
  if output:
346
    logging.debug("StartDaemon %s, output file '%s'", strcmd, output)
347
  else:
348
    logging.debug("StartDaemon %s", strcmd)
349

    
350
  cmd_env = _BuildCmdEnvironment(env, False)
351

    
352
  # Create pipe for sending PID back
353
  (pidpipe_read, pidpipe_write) = os.pipe()
354
  try:
355
    try:
356
      # Create pipe for sending error messages
357
      (errpipe_read, errpipe_write) = os.pipe()
358
      try:
359
        try:
360
          # First fork
361
          pid = os.fork()
362
          if pid == 0:
363
            try:
364
              # Child process, won't return
365
              _StartDaemonChild(errpipe_read, errpipe_write,
366
                                pidpipe_read, pidpipe_write,
367
                                cmd, cmd_env, cwd,
368
                                output, output_fd, pidfile)
369
            finally:
370
              # Well, maybe child process failed
371
              os._exit(1) # pylint: disable-msg=W0212
372
        finally:
373
          CloseFdNoError(errpipe_write)
374

    
375
        # Wait for daemon to be started (or an error message to
376
        # arrive) and read up to 100 KB as an error message
377
        errormsg = RetryOnSignal(os.read, errpipe_read, 100 * 1024)
378
      finally:
379
        CloseFdNoError(errpipe_read)
380
    finally:
381
      CloseFdNoError(pidpipe_write)
382

    
383
    # Read up to 128 bytes for PID
384
    pidtext = RetryOnSignal(os.read, pidpipe_read, 128)
385
  finally:
386
    CloseFdNoError(pidpipe_read)
387

    
388
  # Try to avoid zombies by waiting for child process
389
  try:
390
    os.waitpid(pid, 0)
391
  except OSError:
392
    pass
393

    
394
  if errormsg:
395
    raise errors.OpExecError("Error when starting daemon process: %r" %
396
                             errormsg)
397

    
398
  try:
399
    return int(pidtext)
400
  except (ValueError, TypeError), err:
401
    raise errors.OpExecError("Error while trying to parse PID %r: %s" %
402
                             (pidtext, err))
403

    
404

    
405
def _StartDaemonChild(errpipe_read, errpipe_write,
406
                      pidpipe_read, pidpipe_write,
407
                      args, env, cwd,
408
                      output, fd_output, pidfile):
409
  """Child process for starting daemon.
410

411
  """
412
  try:
413
    # Close parent's side
414
    CloseFdNoError(errpipe_read)
415
    CloseFdNoError(pidpipe_read)
416

    
417
    # First child process
418
    SetupDaemonEnv()
419

    
420
    # And fork for the second time
421
    pid = os.fork()
422
    if pid != 0:
423
      # Exit first child process
424
      os._exit(0) # pylint: disable-msg=W0212
425

    
426
    # Make sure pipe is closed on execv* (and thereby notifies
427
    # original process)
428
    SetCloseOnExecFlag(errpipe_write, True)
429

    
430
    # List of file descriptors to be left open
431
    noclose_fds = [errpipe_write]
432

    
433
    # Open PID file
434
    if pidfile:
435
      fd_pidfile = WritePidFile(pidfile)
436

    
437
      # Keeping the file open to hold the lock
438
      noclose_fds.append(fd_pidfile)
439

    
440
      SetCloseOnExecFlag(fd_pidfile, False)
441
    else:
442
      fd_pidfile = None
443

    
444
    SetupDaemonFDs(output, fd_output)
445

    
446
    # Send daemon PID to parent
447
    RetryOnSignal(os.write, pidpipe_write, str(os.getpid()))
448

    
449
    # Close all file descriptors except stdio and error message pipe
450
    CloseFDs(noclose_fds=noclose_fds)
451

    
452
    # Change working directory
453
    os.chdir(cwd)
454

    
455
    if env is None:
456
      os.execvp(args[0], args)
457
    else:
458
      os.execvpe(args[0], args, env)
459
  except: # pylint: disable-msg=W0702
460
    try:
461
      # Report errors to original process
462
      WriteErrorToFD(errpipe_write, str(sys.exc_info()[1]))
463
    except: # pylint: disable-msg=W0702
464
      # Ignore errors in error handling
465
      pass
466

    
467
  os._exit(1) # pylint: disable-msg=W0212
468

    
469

    
470
def WriteErrorToFD(fd, err):
471
  """Possibly write an error message to a fd.
472

473
  @type fd: None or int (file descriptor)
474
  @param fd: if not None, the error will be written to this fd
475
  @param err: string, the error message
476

477
  """
478
  if fd is None:
479
    return
480

    
481
  if not err:
482
    err = "<unknown error>"
483

    
484
  RetryOnSignal(os.write, fd, err)
485

    
486

    
487
def _CheckIfAlive(child):
488
  """Raises L{RetryAgain} if child is still alive.
489

490
  @raises RetryAgain: If child is still alive
491

492
  """
493
  if child.poll() is None:
494
    raise RetryAgain()
495

    
496

    
497
def _WaitForProcess(child, timeout):
498
  """Waits for the child to terminate or until we reach timeout.
499

500
  """
501
  try:
502
    Retry(_CheckIfAlive, (1.0, 1.2, 5.0), max(0, timeout), args=[child])
503
  except RetryTimeout:
504
    pass
505

    
506

    
507
def _RunCmdPipe(cmd, env, via_shell, cwd, interactive, timeout,
508
                _linger_timeout=constants.CHILD_LINGER_TIMEOUT):
509
  """Run a command and return its output.
510

511
  @type  cmd: string or list
512
  @param cmd: Command to run
513
  @type env: dict
514
  @param env: The environment to use
515
  @type via_shell: bool
516
  @param via_shell: if we should run via the shell
517
  @type cwd: string
518
  @param cwd: the working directory for the program
519
  @type interactive: boolean
520
  @param interactive: Run command interactive (without piping)
521
  @type timeout: int
522
  @param timeout: Timeout after the programm gets terminated
523
  @rtype: tuple
524
  @return: (out, err, status)
525

526
  """
527
  poller = select.poll()
528

    
529
  stderr = subprocess.PIPE
530
  stdout = subprocess.PIPE
531
  stdin = subprocess.PIPE
532

    
533
  if interactive:
534
    stderr = stdout = stdin = None
535

    
536
  child = subprocess.Popen(cmd, shell=via_shell,
537
                           stderr=stderr,
538
                           stdout=stdout,
539
                           stdin=stdin,
540
                           close_fds=True, env=env,
541
                           cwd=cwd)
542

    
543
  out = StringIO()
544
  err = StringIO()
545

    
546
  linger_timeout = None
547

    
548
  if timeout is None:
549
    poll_timeout = None
550
  else:
551
    poll_timeout = RunningTimeout(timeout, True).Remaining
552

    
553
  msg_timeout = ("Command %s (%d) run into execution timeout, terminating" %
554
                 (cmd, child.pid))
555
  msg_linger = ("Command %s (%d) run into linger timeout, killing" %
556
                (cmd, child.pid))
557

    
558
  timeout_action = _TIMEOUT_NONE
559

    
560
  if not interactive:
561
    child.stdin.close()
562
    poller.register(child.stdout, select.POLLIN)
563
    poller.register(child.stderr, select.POLLIN)
564
    fdmap = {
565
      child.stdout.fileno(): (out, child.stdout),
566
      child.stderr.fileno(): (err, child.stderr),
567
      }
568
    for fd in fdmap:
569
      SetNonblockFlag(fd, True)
570

    
571
    while fdmap:
572
      if poll_timeout:
573
        pt = poll_timeout() * 1000
574
        if pt < 0:
575
          if linger_timeout is None:
576
            logging.warning(msg_timeout)
577
            if child.poll() is None:
578
              timeout_action = _TIMEOUT_TERM
579
              IgnoreProcessNotFound(os.kill, child.pid, signal.SIGTERM)
580
            linger_timeout = RunningTimeout(_linger_timeout, True).Remaining
581
          pt = linger_timeout() * 1000
582
          if pt < 0:
583
            break
584
      else:
585
        pt = None
586

    
587
      pollresult = RetryOnSignal(poller.poll, pt)
588

    
589
      for fd, event in pollresult:
590
        if event & select.POLLIN or event & select.POLLPRI:
591
          data = fdmap[fd][1].read()
592
          # no data from read signifies EOF (the same as POLLHUP)
593
          if not data:
594
            poller.unregister(fd)
595
            del fdmap[fd]
596
            continue
597
          fdmap[fd][0].write(data)
598
        if (event & select.POLLNVAL or event & select.POLLHUP or
599
            event & select.POLLERR):
600
          poller.unregister(fd)
601
          del fdmap[fd]
602

    
603
  if timeout is not None:
604
    assert callable(poll_timeout)
605

    
606
    # We have no I/O left but it might still run
607
    if child.poll() is None:
608
      _WaitForProcess(child, poll_timeout())
609

    
610
    # Terminate if still alive after timeout
611
    if child.poll() is None:
612
      if linger_timeout is None:
613
        logging.warning(msg_timeout)
614
        timeout_action = _TIMEOUT_TERM
615
        IgnoreProcessNotFound(os.kill, child.pid, signal.SIGTERM)
616
        lt = _linger_timeout
617
      else:
618
        lt = linger_timeout()
619
      _WaitForProcess(child, lt)
620

    
621
    # Okay, still alive after timeout and linger timeout? Kill it!
622
    if child.poll() is None:
623
      timeout_action = _TIMEOUT_KILL
624
      logging.warning(msg_linger)
625
      IgnoreProcessNotFound(os.kill, child.pid, signal.SIGKILL)
626

    
627
  out = out.getvalue()
628
  err = err.getvalue()
629

    
630
  status = child.wait()
631
  return out, err, status, timeout_action
632

    
633

    
634
def _RunCmdFile(cmd, env, via_shell, output, cwd):
635
  """Run a command and save its output to a file.
636

637
  @type  cmd: string or list
638
  @param cmd: Command to run
639
  @type env: dict
640
  @param env: The environment to use
641
  @type via_shell: bool
642
  @param via_shell: if we should run via the shell
643
  @type output: str
644
  @param output: the filename in which to save the output
645
  @type cwd: string
646
  @param cwd: the working directory for the program
647
  @rtype: int
648
  @return: the exit status
649

650
  """
651
  fh = open(output, "a")
652
  try:
653
    child = subprocess.Popen(cmd, shell=via_shell,
654
                             stderr=subprocess.STDOUT,
655
                             stdout=fh,
656
                             stdin=subprocess.PIPE,
657
                             close_fds=True, env=env,
658
                             cwd=cwd)
659

    
660
    child.stdin.close()
661
    status = child.wait()
662
  finally:
663
    fh.close()
664
  return status
665

    
666

    
667
def RunParts(dir_name, env=None, reset_env=False):
668
  """Run Scripts or programs in a directory
669

670
  @type dir_name: string
671
  @param dir_name: absolute path to a directory
672
  @type env: dict
673
  @param env: The environment to use
674
  @type reset_env: boolean
675
  @param reset_env: whether to reset or keep the default os environment
676
  @rtype: list of tuples
677
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)
678

679
  """
680
  rr = []
681

    
682
  try:
683
    dir_contents = ListVisibleFiles(dir_name)
684
  except OSError, err:
685
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
686
    return rr
687

    
688
  for relname in sorted(dir_contents):
689
    fname = PathJoin(dir_name, relname)
690
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
691
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
692
      rr.append((relname, constants.RUNPARTS_SKIP, None))
693
    else:
694
      try:
695
        result = RunCmd([fname], env=env, reset_env=reset_env)
696
      except Exception, err: # pylint: disable-msg=W0703
697
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
698
      else:
699
        rr.append((relname, constants.RUNPARTS_RUN, result))
700

    
701
  return rr
702

    
703

    
704
def RemoveFile(filename):
705
  """Remove a file ignoring some errors.
706

707
  Remove a file, ignoring non-existing ones or directories. Other
708
  errors are passed.
709

710
  @type filename: str
711
  @param filename: the file to be removed
712

713
  """
714
  try:
715
    os.unlink(filename)
716
  except OSError, err:
717
    if err.errno not in (errno.ENOENT, errno.EISDIR):
718
      raise
719

    
720

    
721
def RemoveDir(dirname):
722
  """Remove an empty directory.
723

724
  Remove a directory, ignoring non-existing ones.
725
  Other errors are passed. This includes the case,
726
  where the directory is not empty, so it can't be removed.
727

728
  @type dirname: str
729
  @param dirname: the empty directory to be removed
730

731
  """
732
  try:
733
    os.rmdir(dirname)
734
  except OSError, err:
735
    if err.errno != errno.ENOENT:
736
      raise
737

    
738

    
739
def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
740
  """Renames a file.
741

742
  @type old: string
743
  @param old: Original path
744
  @type new: string
745
  @param new: New path
746
  @type mkdir: bool
747
  @param mkdir: Whether to create target directory if it doesn't exist
748
  @type mkdir_mode: int
749
  @param mkdir_mode: Mode for newly created directories
750

751
  """
752
  try:
753
    return os.rename(old, new)
754
  except OSError, err:
755
    # In at least one use case of this function, the job queue, directory
756
    # creation is very rare. Checking for the directory before renaming is not
757
    # as efficient.
758
    if mkdir and err.errno == errno.ENOENT:
759
      # Create directory and try again
760
      Makedirs(os.path.dirname(new), mode=mkdir_mode)
761

    
762
      return os.rename(old, new)
763

    
764
    raise
765

    
766

    
767
def Makedirs(path, mode=0750):
768
  """Super-mkdir; create a leaf directory and all intermediate ones.
769

770
  This is a wrapper around C{os.makedirs} adding error handling not implemented
771
  before Python 2.5.
772

773
  """
774
  try:
775
    os.makedirs(path, mode)
776
  except OSError, err:
777
    # Ignore EEXIST. This is only handled in os.makedirs as included in
778
    # Python 2.5 and above.
779
    if err.errno != errno.EEXIST or not os.path.exists(path):
780
      raise
781

    
782

    
783
def ResetTempfileModule():
784
  """Resets the random name generator of the tempfile module.
785

786
  This function should be called after C{os.fork} in the child process to
787
  ensure it creates a newly seeded random generator. Otherwise it would
788
  generate the same random parts as the parent process. If several processes
789
  race for the creation of a temporary file, this could lead to one not getting
790
  a temporary name.
791

792
  """
793
  # pylint: disable-msg=W0212
794
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
795
    tempfile._once_lock.acquire()
796
    try:
797
      # Reset random name generator
798
      tempfile._name_sequence = None
799
    finally:
800
      tempfile._once_lock.release()
801
  else:
802
    logging.critical("The tempfile module misses at least one of the"
803
                     " '_once_lock' and '_name_sequence' attributes")
804

    
805

    
806
def ForceDictType(target, key_types, allowed_values=None):
807
  """Force the values of a dict to have certain types.
808

809
  @type target: dict
810
  @param target: the dict to update
811
  @type key_types: dict
812
  @param key_types: dict mapping target dict keys to types
813
                    in constants.ENFORCEABLE_TYPES
814
  @type allowed_values: list
815
  @keyword allowed_values: list of specially allowed values
816

817
  """
818
  if allowed_values is None:
819
    allowed_values = []
820

    
821
  if not isinstance(target, dict):
822
    msg = "Expected dictionary, got '%s'" % target
823
    raise errors.TypeEnforcementError(msg)
824

    
825
  for key in target:
826
    if key not in key_types:
827
      msg = "Unknown key '%s'" % key
828
      raise errors.TypeEnforcementError(msg)
829

    
830
    if target[key] in allowed_values:
831
      continue
832

    
833
    ktype = key_types[key]
834
    if ktype not in constants.ENFORCEABLE_TYPES:
835
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
836
      raise errors.ProgrammerError(msg)
837

    
838
    if ktype in (constants.VTYPE_STRING, constants.VTYPE_MAYBE_STRING):
839
      if target[key] is None and ktype == constants.VTYPE_MAYBE_STRING:
840
        pass
841
      elif not isinstance(target[key], basestring):
842
        if isinstance(target[key], bool) and not target[key]:
843
          target[key] = ''
844
        else:
845
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
846
          raise errors.TypeEnforcementError(msg)
847
    elif ktype == constants.VTYPE_BOOL:
848
      if isinstance(target[key], basestring) and target[key]:
849
        if target[key].lower() == constants.VALUE_FALSE:
850
          target[key] = False
851
        elif target[key].lower() == constants.VALUE_TRUE:
852
          target[key] = True
853
        else:
854
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
855
          raise errors.TypeEnforcementError(msg)
856
      elif target[key]:
857
        target[key] = True
858
      else:
859
        target[key] = False
860
    elif ktype == constants.VTYPE_SIZE:
861
      try:
862
        target[key] = ParseUnit(target[key])
863
      except errors.UnitParseError, err:
864
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
865
              (key, target[key], err)
866
        raise errors.TypeEnforcementError(msg)
867
    elif ktype == constants.VTYPE_INT:
868
      try:
869
        target[key] = int(target[key])
870
      except (ValueError, TypeError):
871
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
872
        raise errors.TypeEnforcementError(msg)
873

    
874

    
875
def _GetProcStatusPath(pid):
876
  """Returns the path for a PID's proc status file.
877

878
  @type pid: int
879
  @param pid: Process ID
880
  @rtype: string
881

882
  """
883
  return "/proc/%d/status" % pid
884

    
885

    
886
def IsProcessAlive(pid):
887
  """Check if a given pid exists on the system.
888

889
  @note: zombie status is not handled, so zombie processes
890
      will be returned as alive
891
  @type pid: int
892
  @param pid: the process ID to check
893
  @rtype: boolean
894
  @return: True if the process exists
895

896
  """
897
  def _TryStat(name):
898
    try:
899
      os.stat(name)
900
      return True
901
    except EnvironmentError, err:
902
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
903
        return False
904
      elif err.errno == errno.EINVAL:
905
        raise RetryAgain(err)
906
      raise
907

    
908
  assert isinstance(pid, int), "pid must be an integer"
909
  if pid <= 0:
910
    return False
911

    
912
  # /proc in a multiprocessor environment can have strange behaviors.
913
  # Retry the os.stat a few times until we get a good result.
914
  try:
915
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5,
916
                 args=[_GetProcStatusPath(pid)])
917
  except RetryTimeout, err:
918
    err.RaiseInner()
919

    
920

    
921
def _ParseSigsetT(sigset):
922
  """Parse a rendered sigset_t value.
923

924
  This is the opposite of the Linux kernel's fs/proc/array.c:render_sigset_t
925
  function.
926

927
  @type sigset: string
928
  @param sigset: Rendered signal set from /proc/$pid/status
929
  @rtype: set
930
  @return: Set of all enabled signal numbers
931

932
  """
933
  result = set()
934

    
935
  signum = 0
936
  for ch in reversed(sigset):
937
    chv = int(ch, 16)
938

    
939
    # The following could be done in a loop, but it's easier to read and
940
    # understand in the unrolled form
941
    if chv & 1:
942
      result.add(signum + 1)
943
    if chv & 2:
944
      result.add(signum + 2)
945
    if chv & 4:
946
      result.add(signum + 3)
947
    if chv & 8:
948
      result.add(signum + 4)
949

    
950
    signum += 4
951

    
952
  return result
953

    
954

    
955
def _GetProcStatusField(pstatus, field):
956
  """Retrieves a field from the contents of a proc status file.
957

958
  @type pstatus: string
959
  @param pstatus: Contents of /proc/$pid/status
960
  @type field: string
961
  @param field: Name of field whose value should be returned
962
  @rtype: string
963

964
  """
965
  for line in pstatus.splitlines():
966
    parts = line.split(":", 1)
967

    
968
    if len(parts) < 2 or parts[0] != field:
969
      continue
970

    
971
    return parts[1].strip()
972

    
973
  return None
974

    
975

    
976
def IsProcessHandlingSignal(pid, signum, status_path=None):
977
  """Checks whether a process is handling a signal.
978

979
  @type pid: int
980
  @param pid: Process ID
981
  @type signum: int
982
  @param signum: Signal number
983
  @rtype: bool
984

985
  """
986
  if status_path is None:
987
    status_path = _GetProcStatusPath(pid)
988

    
989
  try:
990
    proc_status = ReadFile(status_path)
991
  except EnvironmentError, err:
992
    # In at least one case, reading /proc/$pid/status failed with ESRCH.
993
    if err.errno in (errno.ENOENT, errno.ENOTDIR, errno.EINVAL, errno.ESRCH):
994
      return False
995
    raise
996

    
997
  sigcgt = _GetProcStatusField(proc_status, "SigCgt")
998
  if sigcgt is None:
999
    raise RuntimeError("%s is missing 'SigCgt' field" % status_path)
1000

    
1001
  # Now check whether signal is handled
1002
  return signum in _ParseSigsetT(sigcgt)
1003

    
1004

    
1005
def ReadPidFile(pidfile):
1006
  """Read a pid from a file.
1007

1008
  @type  pidfile: string
1009
  @param pidfile: path to the file containing the pid
1010
  @rtype: int
1011
  @return: The process id, if the file exists and contains a valid PID,
1012
           otherwise 0
1013

1014
  """
1015
  try:
1016
    raw_data = ReadOneLineFile(pidfile)
1017
  except EnvironmentError, err:
1018
    if err.errno != errno.ENOENT:
1019
      logging.exception("Can't read pid file")
1020
    return 0
1021

    
1022
  try:
1023
    pid = int(raw_data)
1024
  except (TypeError, ValueError), err:
1025
    logging.info("Can't parse pid file contents", exc_info=True)
1026
    return 0
1027

    
1028
  return pid
1029

    
1030

    
1031
def ReadLockedPidFile(path):
1032
  """Reads a locked PID file.
1033

1034
  This can be used together with L{StartDaemon}.
1035

1036
  @type path: string
1037
  @param path: Path to PID file
1038
  @return: PID as integer or, if file was unlocked or couldn't be opened, None
1039

1040
  """
1041
  try:
1042
    fd = os.open(path, os.O_RDONLY)
1043
  except EnvironmentError, err:
1044
    if err.errno == errno.ENOENT:
1045
      # PID file doesn't exist
1046
      return None
1047
    raise
1048

    
1049
  try:
1050
    try:
1051
      # Try to acquire lock
1052
      LockFile(fd)
1053
    except errors.LockError:
1054
      # Couldn't lock, daemon is running
1055
      return int(os.read(fd, 100))
1056
  finally:
1057
    os.close(fd)
1058

    
1059
  return None
1060

    
1061

    
1062
def ValidateServiceName(name):
1063
  """Validate the given service name.
1064

1065
  @type name: number or string
1066
  @param name: Service name or port specification
1067

1068
  """
1069
  try:
1070
    numport = int(name)
1071
  except (ValueError, TypeError):
1072
    # Non-numeric service name
1073
    valid = _VALID_SERVICE_NAME_RE.match(name)
1074
  else:
1075
    # Numeric port (protocols other than TCP or UDP might need adjustments
1076
    # here)
1077
    valid = (numport >= 0 and numport < (1 << 16))
1078

    
1079
  if not valid:
1080
    raise errors.OpPrereqError("Invalid service name '%s'" % name,
1081
                               errors.ECODE_INVAL)
1082

    
1083
  return name
1084

    
1085

    
1086
def ListVolumeGroups():
1087
  """List volume groups and their size
1088

1089
  @rtype: dict
1090
  @return:
1091
       Dictionary with keys volume name and values
1092
       the size of the volume
1093

1094
  """
1095
  command = "vgs --noheadings --units m --nosuffix -o name,size"
1096
  result = RunCmd(command)
1097
  retval = {}
1098
  if result.failed:
1099
    return retval
1100

    
1101
  for line in result.stdout.splitlines():
1102
    try:
1103
      name, size = line.split()
1104
      size = int(float(size))
1105
    except (IndexError, ValueError), err:
1106
      logging.error("Invalid output from vgs (%s): %s", err, line)
1107
      continue
1108

    
1109
    retval[name] = size
1110

    
1111
  return retval
1112

    
1113

    
1114
def BridgeExists(bridge):
1115
  """Check whether the given bridge exists in the system
1116

1117
  @type bridge: str
1118
  @param bridge: the bridge name to check
1119
  @rtype: boolean
1120
  @return: True if it does
1121

1122
  """
1123
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
1124

    
1125

    
1126
def TryConvert(fn, val):
1127
  """Try to convert a value ignoring errors.
1128

1129
  This function tries to apply function I{fn} to I{val}. If no
1130
  C{ValueError} or C{TypeError} exceptions are raised, it will return
1131
  the result, else it will return the original value. Any other
1132
  exceptions are propagated to the caller.
1133

1134
  @type fn: callable
1135
  @param fn: function to apply to the value
1136
  @param val: the value to be converted
1137
  @return: The converted value if the conversion was successful,
1138
      otherwise the original value.
1139

1140
  """
1141
  try:
1142
    nv = fn(val)
1143
  except (ValueError, TypeError):
1144
    nv = val
1145
  return nv
1146

    
1147

    
1148
def IsValidShellParam(word):
1149
  """Verifies is the given word is safe from the shell's p.o.v.
1150

1151
  This means that we can pass this to a command via the shell and be
1152
  sure that it doesn't alter the command line and is passed as such to
1153
  the actual command.
1154

1155
  Note that we are overly restrictive here, in order to be on the safe
1156
  side.
1157

1158
  @type word: str
1159
  @param word: the word to check
1160
  @rtype: boolean
1161
  @return: True if the word is 'safe'
1162

1163
  """
1164
  return bool(_SHELLPARAM_REGEX.match(word))
1165

    
1166

    
1167
def BuildShellCmd(template, *args):
1168
  """Build a safe shell command line from the given arguments.
1169

1170
  This function will check all arguments in the args list so that they
1171
  are valid shell parameters (i.e. they don't contain shell
1172
  metacharacters). If everything is ok, it will return the result of
1173
  template % args.
1174

1175
  @type template: str
1176
  @param template: the string holding the template for the
1177
      string formatting
1178
  @rtype: str
1179
  @return: the expanded command line
1180

1181
  """
1182
  for word in args:
1183
    if not IsValidShellParam(word):
1184
      raise errors.ProgrammerError("Shell argument '%s' contains"
1185
                                   " invalid characters" % word)
1186
  return template % args
1187

    
1188

    
1189
def ParseCpuMask(cpu_mask):
1190
  """Parse a CPU mask definition and return the list of CPU IDs.
1191

1192
  CPU mask format: comma-separated list of CPU IDs
1193
  or dash-separated ID ranges
1194
  Example: "0-2,5" -> "0,1,2,5"
1195

1196
  @type cpu_mask: str
1197
  @param cpu_mask: CPU mask definition
1198
  @rtype: list of int
1199
  @return: list of CPU IDs
1200

1201
  """
1202
  if not cpu_mask:
1203
    return []
1204
  cpu_list = []
1205
  for range_def in cpu_mask.split(","):
1206
    boundaries = range_def.split("-")
1207
    n_elements = len(boundaries)
1208
    if n_elements > 2:
1209
      raise errors.ParseError("Invalid CPU ID range definition"
1210
                              " (only one hyphen allowed): %s" % range_def)
1211
    try:
1212
      lower = int(boundaries[0])
1213
    except (ValueError, TypeError), err:
1214
      raise errors.ParseError("Invalid CPU ID value for lower boundary of"
1215
                              " CPU ID range: %s" % str(err))
1216
    try:
1217
      higher = int(boundaries[-1])
1218
    except (ValueError, TypeError), err:
1219
      raise errors.ParseError("Invalid CPU ID value for higher boundary of"
1220
                              " CPU ID range: %s" % str(err))
1221
    if lower > higher:
1222
      raise errors.ParseError("Invalid CPU ID range definition"
1223
                              " (%d > %d): %s" % (lower, higher, range_def))
1224
    cpu_list.extend(range(lower, higher + 1))
1225
  return cpu_list
1226

    
1227

    
1228
def AddAuthorizedKey(file_obj, key):
1229
  """Adds an SSH public key to an authorized_keys file.
1230

1231
  @type file_obj: str or file handle
1232
  @param file_obj: path to authorized_keys file
1233
  @type key: str
1234
  @param key: string containing key
1235

1236
  """
1237
  key_fields = key.split()
1238

    
1239
  if isinstance(file_obj, basestring):
1240
    f = open(file_obj, 'a+')
1241
  else:
1242
    f = file_obj
1243

    
1244
  try:
1245
    nl = True
1246
    for line in f:
1247
      # Ignore whitespace changes
1248
      if line.split() == key_fields:
1249
        break
1250
      nl = line.endswith('\n')
1251
    else:
1252
      if not nl:
1253
        f.write("\n")
1254
      f.write(key.rstrip('\r\n'))
1255
      f.write("\n")
1256
      f.flush()
1257
  finally:
1258
    f.close()
1259

    
1260

    
1261
def RemoveAuthorizedKey(file_name, key):
1262
  """Removes an SSH public key from an authorized_keys file.
1263

1264
  @type file_name: str
1265
  @param file_name: path to authorized_keys file
1266
  @type key: str
1267
  @param key: string containing key
1268

1269
  """
1270
  key_fields = key.split()
1271

    
1272
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1273
  try:
1274
    out = os.fdopen(fd, 'w')
1275
    try:
1276
      f = open(file_name, 'r')
1277
      try:
1278
        for line in f:
1279
          # Ignore whitespace changes while comparing lines
1280
          if line.split() != key_fields:
1281
            out.write(line)
1282

    
1283
        out.flush()
1284
        os.rename(tmpname, file_name)
1285
      finally:
1286
        f.close()
1287
    finally:
1288
      out.close()
1289
  except:
1290
    RemoveFile(tmpname)
1291
    raise
1292

    
1293

    
1294
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
1295
  """Sets the name of an IP address and hostname in /etc/hosts.
1296

1297
  @type file_name: str
1298
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1299
  @type ip: str
1300
  @param ip: the IP address
1301
  @type hostname: str
1302
  @param hostname: the hostname to be added
1303
  @type aliases: list
1304
  @param aliases: the list of aliases to add for the hostname
1305

1306
  """
1307
  # Ensure aliases are unique
1308
  aliases = UniqueSequence([hostname] + aliases)[1:]
1309

    
1310
  def _WriteEtcHosts(fd):
1311
    # Duplicating file descriptor because os.fdopen's result will automatically
1312
    # close the descriptor, but we would still like to have its functionality.
1313
    out = os.fdopen(os.dup(fd), "w")
1314
    try:
1315
      for line in ReadFile(file_name).splitlines(True):
1316
        fields = line.split()
1317
        if fields and not fields[0].startswith("#") and ip == fields[0]:
1318
          continue
1319
        out.write(line)
1320

    
1321
      out.write("%s\t%s" % (ip, hostname))
1322
      if aliases:
1323
        out.write(" %s" % " ".join(aliases))
1324
      out.write("\n")
1325
      out.flush()
1326
    finally:
1327
      out.close()
1328

    
1329
  WriteFile(file_name, fn=_WriteEtcHosts, mode=0644)
1330

    
1331

    
1332
def AddHostToEtcHosts(hostname, ip):
1333
  """Wrapper around SetEtcHostsEntry.
1334

1335
  @type hostname: str
1336
  @param hostname: a hostname that will be resolved and added to
1337
      L{constants.ETC_HOSTS}
1338
  @type ip: str
1339
  @param ip: The ip address of the host
1340

1341
  """
1342
  SetEtcHostsEntry(constants.ETC_HOSTS, ip, hostname, [hostname.split(".")[0]])
1343

    
1344

    
1345
def RemoveEtcHostsEntry(file_name, hostname):
1346
  """Removes a hostname from /etc/hosts.
1347

1348
  IP addresses without names are removed from the file.
1349

1350
  @type file_name: str
1351
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1352
  @type hostname: str
1353
  @param hostname: the hostname to be removed
1354

1355
  """
1356
  def _WriteEtcHosts(fd):
1357
    # Duplicating file descriptor because os.fdopen's result will automatically
1358
    # close the descriptor, but we would still like to have its functionality.
1359
    out = os.fdopen(os.dup(fd), "w")
1360
    try:
1361
      for line in ReadFile(file_name).splitlines(True):
1362
        fields = line.split()
1363
        if len(fields) > 1 and not fields[0].startswith("#"):
1364
          names = fields[1:]
1365
          if hostname in names:
1366
            while hostname in names:
1367
              names.remove(hostname)
1368
            if names:
1369
              out.write("%s %s\n" % (fields[0], " ".join(names)))
1370
            continue
1371

    
1372
        out.write(line)
1373

    
1374
      out.flush()
1375
    finally:
1376
      out.close()
1377

    
1378
  WriteFile(file_name, fn=_WriteEtcHosts, mode=0644)
1379

    
1380

    
1381
def RemoveHostFromEtcHosts(hostname):
1382
  """Wrapper around RemoveEtcHostsEntry.
1383

1384
  @type hostname: str
1385
  @param hostname: hostname that will be resolved and its
1386
      full and shot name will be removed from
1387
      L{constants.ETC_HOSTS}
1388

1389
  """
1390
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hostname)
1391
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hostname.split(".")[0])
1392

    
1393

    
1394
def TimestampForFilename():
1395
  """Returns the current time formatted for filenames.
1396

1397
  The format doesn't contain colons as some shells and applications treat them
1398
  as separators. Uses the local timezone.
1399

1400
  """
1401
  return time.strftime("%Y-%m-%d_%H_%M_%S")
1402

    
1403

    
1404
def CreateBackup(file_name):
1405
  """Creates a backup of a file.
1406

1407
  @type file_name: str
1408
  @param file_name: file to be backed up
1409
  @rtype: str
1410
  @return: the path to the newly created backup
1411
  @raise errors.ProgrammerError: for invalid file names
1412

1413
  """
1414
  if not os.path.isfile(file_name):
1415
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
1416
                                file_name)
1417

    
1418
  prefix = ("%s.backup-%s." %
1419
            (os.path.basename(file_name), TimestampForFilename()))
1420
  dir_name = os.path.dirname(file_name)
1421

    
1422
  fsrc = open(file_name, 'rb')
1423
  try:
1424
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
1425
    fdst = os.fdopen(fd, 'wb')
1426
    try:
1427
      logging.debug("Backing up %s at %s", file_name, backup_name)
1428
      shutil.copyfileobj(fsrc, fdst)
1429
    finally:
1430
      fdst.close()
1431
  finally:
1432
    fsrc.close()
1433

    
1434
  return backup_name
1435

    
1436

    
1437
def ListVisibleFiles(path):
1438
  """Returns a list of visible files in a directory.
1439

1440
  @type path: str
1441
  @param path: the directory to enumerate
1442
  @rtype: list
1443
  @return: the list of all files not starting with a dot
1444
  @raise ProgrammerError: if L{path} is not an absolue and normalized path
1445

1446
  """
1447
  if not IsNormAbsPath(path):
1448
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
1449
                                 " absolute/normalized: '%s'" % path)
1450
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1451
  return files
1452

    
1453

    
1454
def GetHomeDir(user, default=None):
1455
  """Try to get the homedir of the given user.
1456

1457
  The user can be passed either as a string (denoting the name) or as
1458
  an integer (denoting the user id). If the user is not found, the
1459
  'default' argument is returned, which defaults to None.
1460

1461
  """
1462
  try:
1463
    if isinstance(user, basestring):
1464
      result = pwd.getpwnam(user)
1465
    elif isinstance(user, (int, long)):
1466
      result = pwd.getpwuid(user)
1467
    else:
1468
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1469
                                   type(user))
1470
  except KeyError:
1471
    return default
1472
  return result.pw_dir
1473

    
1474

    
1475
def NewUUID():
1476
  """Returns a random UUID.
1477

1478
  @note: This is a Linux-specific method as it uses the /proc
1479
      filesystem.
1480
  @rtype: str
1481

1482
  """
1483
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1484

    
1485

    
1486
def EnsureDirs(dirs):
1487
  """Make required directories, if they don't exist.
1488

1489
  @param dirs: list of tuples (dir_name, dir_mode)
1490
  @type dirs: list of (string, integer)
1491

1492
  """
1493
  for dir_name, dir_mode in dirs:
1494
    try:
1495
      os.mkdir(dir_name, dir_mode)
1496
    except EnvironmentError, err:
1497
      if err.errno != errno.EEXIST:
1498
        raise errors.GenericError("Cannot create needed directory"
1499
                                  " '%s': %s" % (dir_name, err))
1500
    try:
1501
      os.chmod(dir_name, dir_mode)
1502
    except EnvironmentError, err:
1503
      raise errors.GenericError("Cannot change directory permissions on"
1504
                                " '%s': %s" % (dir_name, err))
1505
    if not os.path.isdir(dir_name):
1506
      raise errors.GenericError("%s is not a directory" % dir_name)
1507

    
1508

    
1509
def ReadFile(file_name, size=-1):
1510
  """Reads a file.
1511

1512
  @type size: int
1513
  @param size: Read at most size bytes (if negative, entire file)
1514
  @rtype: str
1515
  @return: the (possibly partial) content of the file
1516

1517
  """
1518
  f = open(file_name, "r")
1519
  try:
1520
    return f.read(size)
1521
  finally:
1522
    f.close()
1523

    
1524

    
1525
def WriteFile(file_name, fn=None, data=None,
1526
              mode=None, uid=-1, gid=-1,
1527
              atime=None, mtime=None, close=True,
1528
              dry_run=False, backup=False,
1529
              prewrite=None, postwrite=None):
1530
  """(Over)write a file atomically.
1531

1532
  The file_name and either fn (a function taking one argument, the
1533
  file descriptor, and which should write the data to it) or data (the
1534
  contents of the file) must be passed. The other arguments are
1535
  optional and allow setting the file mode, owner and group, and the
1536
  mtime/atime of the file.
1537

1538
  If the function doesn't raise an exception, it has succeeded and the
1539
  target file has the new contents. If the function has raised an
1540
  exception, an existing target file should be unmodified and the
1541
  temporary file should be removed.
1542

1543
  @type file_name: str
1544
  @param file_name: the target filename
1545
  @type fn: callable
1546
  @param fn: content writing function, called with
1547
      file descriptor as parameter
1548
  @type data: str
1549
  @param data: contents of the file
1550
  @type mode: int
1551
  @param mode: file mode
1552
  @type uid: int
1553
  @param uid: the owner of the file
1554
  @type gid: int
1555
  @param gid: the group of the file
1556
  @type atime: int
1557
  @param atime: a custom access time to be set on the file
1558
  @type mtime: int
1559
  @param mtime: a custom modification time to be set on the file
1560
  @type close: boolean
1561
  @param close: whether to close file after writing it
1562
  @type prewrite: callable
1563
  @param prewrite: function to be called before writing content
1564
  @type postwrite: callable
1565
  @param postwrite: function to be called after writing content
1566

1567
  @rtype: None or int
1568
  @return: None if the 'close' parameter evaluates to True,
1569
      otherwise the file descriptor
1570

1571
  @raise errors.ProgrammerError: if any of the arguments are not valid
1572

1573
  """
1574
  if not os.path.isabs(file_name):
1575
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1576
                                 " absolute: '%s'" % file_name)
1577

    
1578
  if [fn, data].count(None) != 1:
1579
    raise errors.ProgrammerError("fn or data required")
1580

    
1581
  if [atime, mtime].count(None) == 1:
1582
    raise errors.ProgrammerError("Both atime and mtime must be either"
1583
                                 " set or None")
1584

    
1585
  if backup and not dry_run and os.path.isfile(file_name):
1586
    CreateBackup(file_name)
1587

    
1588
  dir_name, base_name = os.path.split(file_name)
1589
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1590
  do_remove = True
1591
  # here we need to make sure we remove the temp file, if any error
1592
  # leaves it in place
1593
  try:
1594
    if uid != -1 or gid != -1:
1595
      os.chown(new_name, uid, gid)
1596
    if mode:
1597
      os.chmod(new_name, mode)
1598
    if callable(prewrite):
1599
      prewrite(fd)
1600
    if data is not None:
1601
      os.write(fd, data)
1602
    else:
1603
      fn(fd)
1604
    if callable(postwrite):
1605
      postwrite(fd)
1606
    os.fsync(fd)
1607
    if atime is not None and mtime is not None:
1608
      os.utime(new_name, (atime, mtime))
1609
    if not dry_run:
1610
      os.rename(new_name, file_name)
1611
      do_remove = False
1612
  finally:
1613
    if close:
1614
      os.close(fd)
1615
      result = None
1616
    else:
1617
      result = fd
1618
    if do_remove:
1619
      RemoveFile(new_name)
1620

    
1621
  return result
1622

    
1623

    
1624
def GetFileID(path=None, fd=None):
1625
  """Returns the file 'id', i.e. the dev/inode and mtime information.
1626

1627
  Either the path to the file or the fd must be given.
1628

1629
  @param path: the file path
1630
  @param fd: a file descriptor
1631
  @return: a tuple of (device number, inode number, mtime)
1632

1633
  """
1634
  if [path, fd].count(None) != 1:
1635
    raise errors.ProgrammerError("One and only one of fd/path must be given")
1636

    
1637
  if fd is None:
1638
    st = os.stat(path)
1639
  else:
1640
    st = os.fstat(fd)
1641

    
1642
  return (st.st_dev, st.st_ino, st.st_mtime)
1643

    
1644

    
1645
def VerifyFileID(fi_disk, fi_ours):
1646
  """Verifies that two file IDs are matching.
1647

1648
  Differences in the inode/device are not accepted, but and older
1649
  timestamp for fi_disk is accepted.
1650

1651
  @param fi_disk: tuple (dev, inode, mtime) representing the actual
1652
      file data
1653
  @param fi_ours: tuple (dev, inode, mtime) representing the last
1654
      written file data
1655
  @rtype: boolean
1656

1657
  """
1658
  (d1, i1, m1) = fi_disk
1659
  (d2, i2, m2) = fi_ours
1660

    
1661
  return (d1, i1) == (d2, i2) and m1 <= m2
1662

    
1663

    
1664
def SafeWriteFile(file_name, file_id, **kwargs):
1665
  """Wraper over L{WriteFile} that locks the target file.
1666

1667
  By keeping the target file locked during WriteFile, we ensure that
1668
  cooperating writers will safely serialise access to the file.
1669

1670
  @type file_name: str
1671
  @param file_name: the target filename
1672
  @type file_id: tuple
1673
  @param file_id: a result from L{GetFileID}
1674

1675
  """
1676
  fd = os.open(file_name, os.O_RDONLY | os.O_CREAT)
1677
  try:
1678
    LockFile(fd)
1679
    if file_id is not None:
1680
      disk_id = GetFileID(fd=fd)
1681
      if not VerifyFileID(disk_id, file_id):
1682
        raise errors.LockError("Cannot overwrite file %s, it has been modified"
1683
                               " since last written" % file_name)
1684
    return WriteFile(file_name, **kwargs)
1685
  finally:
1686
    os.close(fd)
1687

    
1688

    
1689
def ReadOneLineFile(file_name, strict=False):
1690
  """Return the first non-empty line from a file.
1691

1692
  @type strict: boolean
1693
  @param strict: if True, abort if the file has more than one
1694
      non-empty line
1695

1696
  """
1697
  file_lines = ReadFile(file_name).splitlines()
1698
  full_lines = filter(bool, file_lines)
1699
  if not file_lines or not full_lines:
1700
    raise errors.GenericError("No data in one-liner file %s" % file_name)
1701
  elif strict and len(full_lines) > 1:
1702
    raise errors.GenericError("Too many lines in one-liner file %s" %
1703
                              file_name)
1704
  return full_lines[0]
1705

    
1706

    
1707
def FirstFree(seq, base=0):
1708
  """Returns the first non-existing integer from seq.
1709

1710
  The seq argument should be a sorted list of positive integers. The
1711
  first time the index of an element is smaller than the element
1712
  value, the index will be returned.
1713

1714
  The base argument is used to start at a different offset,
1715
  i.e. C{[3, 4, 6]} with I{offset=3} will return 5.
1716

1717
  Example: C{[0, 1, 3]} will return I{2}.
1718

1719
  @type seq: sequence
1720
  @param seq: the sequence to be analyzed.
1721
  @type base: int
1722
  @param base: use this value as the base index of the sequence
1723
  @rtype: int
1724
  @return: the first non-used index in the sequence
1725

1726
  """
1727
  for idx, elem in enumerate(seq):
1728
    assert elem >= base, "Passed element is higher than base offset"
1729
    if elem > idx + base:
1730
      # idx is not used
1731
      return idx + base
1732
  return None
1733

    
1734

    
1735
def SingleWaitForFdCondition(fdobj, event, timeout):
1736
  """Waits for a condition to occur on the socket.
1737

1738
  Immediately returns at the first interruption.
1739

1740
  @type fdobj: integer or object supporting a fileno() method
1741
  @param fdobj: entity to wait for events on
1742
  @type event: integer
1743
  @param event: ORed condition (see select module)
1744
  @type timeout: float or None
1745
  @param timeout: Timeout in seconds
1746
  @rtype: int or None
1747
  @return: None for timeout, otherwise occured conditions
1748

1749
  """
1750
  check = (event | select.POLLPRI |
1751
           select.POLLNVAL | select.POLLHUP | select.POLLERR)
1752

    
1753
  if timeout is not None:
1754
    # Poller object expects milliseconds
1755
    timeout *= 1000
1756

    
1757
  poller = select.poll()
1758
  poller.register(fdobj, event)
1759
  try:
1760
    # TODO: If the main thread receives a signal and we have no timeout, we
1761
    # could wait forever. This should check a global "quit" flag or something
1762
    # every so often.
1763
    io_events = poller.poll(timeout)
1764
  except select.error, err:
1765
    if err[0] != errno.EINTR:
1766
      raise
1767
    io_events = []
1768
  if io_events and io_events[0][1] & check:
1769
    return io_events[0][1]
1770
  else:
1771
    return None
1772

    
1773

    
1774
class FdConditionWaiterHelper(object):
1775
  """Retry helper for WaitForFdCondition.
1776

1777
  This class contains the retried and wait functions that make sure
1778
  WaitForFdCondition can continue waiting until the timeout is actually
1779
  expired.
1780

1781
  """
1782

    
1783
  def __init__(self, timeout):
1784
    self.timeout = timeout
1785

    
1786
  def Poll(self, fdobj, event):
1787
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
1788
    if result is None:
1789
      raise RetryAgain()
1790
    else:
1791
      return result
1792

    
1793
  def UpdateTimeout(self, timeout):
1794
    self.timeout = timeout
1795

    
1796

    
1797
def WaitForFdCondition(fdobj, event, timeout):
1798
  """Waits for a condition to occur on the socket.
1799

1800
  Retries until the timeout is expired, even if interrupted.
1801

1802
  @type fdobj: integer or object supporting a fileno() method
1803
  @param fdobj: entity to wait for events on
1804
  @type event: integer
1805
  @param event: ORed condition (see select module)
1806
  @type timeout: float or None
1807
  @param timeout: Timeout in seconds
1808
  @rtype: int or None
1809
  @return: None for timeout, otherwise occured conditions
1810

1811
  """
1812
  if timeout is not None:
1813
    retrywaiter = FdConditionWaiterHelper(timeout)
1814
    try:
1815
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
1816
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
1817
    except RetryTimeout:
1818
      result = None
1819
  else:
1820
    result = None
1821
    while result is None:
1822
      result = SingleWaitForFdCondition(fdobj, event, timeout)
1823
  return result
1824

    
1825

    
1826
def CloseFDs(noclose_fds=None):
1827
  """Close file descriptors.
1828

1829
  This closes all file descriptors above 2 (i.e. except
1830
  stdin/out/err).
1831

1832
  @type noclose_fds: list or None
1833
  @param noclose_fds: if given, it denotes a list of file descriptor
1834
      that should not be closed
1835

1836
  """
1837
  # Default maximum for the number of available file descriptors.
1838
  if 'SC_OPEN_MAX' in os.sysconf_names:
1839
    try:
1840
      MAXFD = os.sysconf('SC_OPEN_MAX')
1841
      if MAXFD < 0:
1842
        MAXFD = 1024
1843
    except OSError:
1844
      MAXFD = 1024
1845
  else:
1846
    MAXFD = 1024
1847
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
1848
  if (maxfd == resource.RLIM_INFINITY):
1849
    maxfd = MAXFD
1850

    
1851
  # Iterate through and close all file descriptors (except the standard ones)
1852
  for fd in range(3, maxfd):
1853
    if noclose_fds and fd in noclose_fds:
1854
      continue
1855
    CloseFdNoError(fd)
1856

    
1857

    
1858
def Daemonize(logfile):
1859
  """Daemonize the current process.
1860

1861
  This detaches the current process from the controlling terminal and
1862
  runs it in the background as a daemon.
1863

1864
  @type logfile: str
1865
  @param logfile: the logfile to which we should redirect stdout/stderr
1866
  @rtype: int
1867
  @return: the value zero
1868

1869
  """
1870
  # pylint: disable-msg=W0212
1871
  # yes, we really want os._exit
1872

    
1873
  # TODO: do another attempt to merge Daemonize and StartDaemon, or at
1874
  # least abstract the pipe functionality between them
1875

    
1876
  # Create pipe for sending error messages
1877
  (rpipe, wpipe) = os.pipe()
1878

    
1879
  # this might fail
1880
  pid = os.fork()
1881
  if (pid == 0):  # The first child.
1882
    SetupDaemonEnv()
1883

    
1884
    # this might fail
1885
    pid = os.fork() # Fork a second child.
1886
    if (pid == 0):  # The second child.
1887
      CloseFdNoError(rpipe)
1888
    else:
1889
      # exit() or _exit()?  See below.
1890
      os._exit(0) # Exit parent (the first child) of the second child.
1891
  else:
1892
    CloseFdNoError(wpipe)
1893
    # Wait for daemon to be started (or an error message to
1894
    # arrive) and read up to 100 KB as an error message
1895
    errormsg = RetryOnSignal(os.read, rpipe, 100 * 1024)
1896
    if errormsg:
1897
      sys.stderr.write("Error when starting daemon process: %r\n" % errormsg)
1898
      rcode = 1
1899
    else:
1900
      rcode = 0
1901
    os._exit(rcode) # Exit parent of the first child.
1902

    
1903
  SetupDaemonFDs(logfile, None)
1904
  return wpipe
1905

    
1906

    
1907
def DaemonPidFileName(name):
1908
  """Compute a ganeti pid file absolute path
1909

1910
  @type name: str
1911
  @param name: the daemon name
1912
  @rtype: str
1913
  @return: the full path to the pidfile corresponding to the given
1914
      daemon name
1915

1916
  """
1917
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
1918

    
1919

    
1920
def EnsureDaemon(name):
1921
  """Check for and start daemon if not alive.
1922

1923
  """
1924
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
1925
  if result.failed:
1926
    logging.error("Can't start daemon '%s', failure %s, output: %s",
1927
                  name, result.fail_reason, result.output)
1928
    return False
1929

    
1930
  return True
1931

    
1932

    
1933
def StopDaemon(name):
1934
  """Stop daemon
1935

1936
  """
1937
  result = RunCmd([constants.DAEMON_UTIL, "stop", name])
1938
  if result.failed:
1939
    logging.error("Can't stop daemon '%s', failure %s, output: %s",
1940
                  name, result.fail_reason, result.output)
1941
    return False
1942

    
1943
  return True
1944

    
1945

    
1946
def WritePidFile(pidfile):
1947
  """Write the current process pidfile.
1948

1949
  @type pidfile: string
1950
  @param pidfile: the path to the file to be written
1951
  @raise errors.LockError: if the pid file already exists and
1952
      points to a live process
1953
  @rtype: int
1954
  @return: the file descriptor of the lock file; do not close this unless
1955
      you want to unlock the pid file
1956

1957
  """
1958
  # We don't rename nor truncate the file to not drop locks under
1959
  # existing processes
1960
  fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)
1961

    
1962
  # Lock the PID file (and fail if not possible to do so). Any code
1963
  # wanting to send a signal to the daemon should try to lock the PID
1964
  # file before reading it. If acquiring the lock succeeds, the daemon is
1965
  # no longer running and the signal should not be sent.
1966
  LockFile(fd_pidfile)
1967

    
1968
  os.write(fd_pidfile, "%d\n" % os.getpid())
1969

    
1970
  return fd_pidfile
1971

    
1972

    
1973
def RemovePidFile(pidfile):
1974
  """Remove the current process pidfile.
1975

1976
  Any errors are ignored.
1977

1978
  @type pidfile: string
1979
  @param pidfile: Path to the file to be removed
1980

1981
  """
1982
  # TODO: we could check here that the file contains our pid
1983
  try:
1984
    RemoveFile(pidfile)
1985
  except Exception: # pylint: disable-msg=W0703
1986
    pass
1987

    
1988

    
1989
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
1990
                waitpid=False):
1991
  """Kill a process given by its pid.
1992

1993
  @type pid: int
1994
  @param pid: The PID to terminate.
1995
  @type signal_: int
1996
  @param signal_: The signal to send, by default SIGTERM
1997
  @type timeout: int
1998
  @param timeout: The timeout after which, if the process is still alive,
1999
                  a SIGKILL will be sent. If not positive, no such checking
2000
                  will be done
2001
  @type waitpid: boolean
2002
  @param waitpid: If true, we should waitpid on this process after
2003
      sending signals, since it's our own child and otherwise it
2004
      would remain as zombie
2005

2006
  """
2007
  def _helper(pid, signal_, wait):
2008
    """Simple helper to encapsulate the kill/waitpid sequence"""
2009
    if IgnoreProcessNotFound(os.kill, pid, signal_) and wait:
2010
      try:
2011
        os.waitpid(pid, os.WNOHANG)
2012
      except OSError:
2013
        pass
2014

    
2015
  if pid <= 0:
2016
    # kill with pid=0 == suicide
2017
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
2018

    
2019
  if not IsProcessAlive(pid):
2020
    return
2021

    
2022
  _helper(pid, signal_, waitpid)
2023

    
2024
  if timeout <= 0:
2025
    return
2026

    
2027
  def _CheckProcess():
2028
    if not IsProcessAlive(pid):
2029
      return
2030

    
2031
    try:
2032
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
2033
    except OSError:
2034
      raise RetryAgain()
2035

    
2036
    if result_pid > 0:
2037
      return
2038

    
2039
    raise RetryAgain()
2040

    
2041
  try:
2042
    # Wait up to $timeout seconds
2043
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
2044
  except RetryTimeout:
2045
    pass
2046

    
2047
  if IsProcessAlive(pid):
2048
    # Kill process if it's still alive
2049
    _helper(pid, signal.SIGKILL, waitpid)
2050

    
2051

    
2052
def FindFile(name, search_path, test=os.path.exists):
2053
  """Look for a filesystem object in a given path.
2054

2055
  This is an abstract method to search for filesystem object (files,
2056
  dirs) under a given search path.
2057

2058
  @type name: str
2059
  @param name: the name to look for
2060
  @type search_path: str
2061
  @param search_path: location to start at
2062
  @type test: callable
2063
  @param test: a function taking one argument that should return True
2064
      if the a given object is valid; the default value is
2065
      os.path.exists, causing only existing files to be returned
2066
  @rtype: str or None
2067
  @return: full path to the object if found, None otherwise
2068

2069
  """
2070
  # validate the filename mask
2071
  if constants.EXT_PLUGIN_MASK.match(name) is None:
2072
    logging.critical("Invalid value passed for external script name: '%s'",
2073
                     name)
2074
    return None
2075

    
2076
  for dir_name in search_path:
2077
    # FIXME: investigate switch to PathJoin
2078
    item_name = os.path.sep.join([dir_name, name])
2079
    # check the user test and that we're indeed resolving to the given
2080
    # basename
2081
    if test(item_name) and os.path.basename(item_name) == name:
2082
      return item_name
2083
  return None
2084

    
2085

    
2086
def CheckVolumeGroupSize(vglist, vgname, minsize):
2087
  """Checks if the volume group list is valid.
2088

2089
  The function will check if a given volume group is in the list of
2090
  volume groups and has a minimum size.
2091

2092
  @type vglist: dict
2093
  @param vglist: dictionary of volume group names and their size
2094
  @type vgname: str
2095
  @param vgname: the volume group we should check
2096
  @type minsize: int
2097
  @param minsize: the minimum size we accept
2098
  @rtype: None or str
2099
  @return: None for success, otherwise the error message
2100

2101
  """
2102
  vgsize = vglist.get(vgname, None)
2103
  if vgsize is None:
2104
    return "volume group '%s' missing" % vgname
2105
  elif vgsize < minsize:
2106
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
2107
            (vgname, minsize, vgsize))
2108
  return None
2109

    
2110

    
2111
def SplitTime(value):
2112
  """Splits time as floating point number into a tuple.
2113

2114
  @param value: Time in seconds
2115
  @type value: int or float
2116
  @return: Tuple containing (seconds, microseconds)
2117

2118
  """
2119
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
2120

    
2121
  assert 0 <= seconds, \
2122
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2123
  assert 0 <= microseconds <= 999999, \
2124
    "Microseconds must be 0-999999, but are %s" % microseconds
2125

    
2126
  return (int(seconds), int(microseconds))
2127

    
2128

    
2129
def MergeTime(timetuple):
2130
  """Merges a tuple into time as a floating point number.
2131

2132
  @param timetuple: Time as tuple, (seconds, microseconds)
2133
  @type timetuple: tuple
2134
  @return: Time as a floating point number expressed in seconds
2135

2136
  """
2137
  (seconds, microseconds) = timetuple
2138

    
2139
  assert 0 <= seconds, \
2140
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2141
  assert 0 <= microseconds <= 999999, \
2142
    "Microseconds must be 0-999999, but are %s" % microseconds
2143

    
2144
  return float(seconds) + (float(microseconds) * 0.000001)
2145

    
2146

    
2147
def IsNormAbsPath(path):
2148
  """Check whether a path is absolute and also normalized
2149

2150
  This avoids things like /dir/../../other/path to be valid.
2151

2152
  """
2153
  return os.path.normpath(path) == path and os.path.isabs(path)
2154

    
2155

    
2156
def PathJoin(*args):
2157
  """Safe-join a list of path components.
2158

2159
  Requirements:
2160
      - the first argument must be an absolute path
2161
      - no component in the path must have backtracking (e.g. /../),
2162
        since we check for normalization at the end
2163

2164
  @param args: the path components to be joined
2165
  @raise ValueError: for invalid paths
2166

2167
  """
2168
  # ensure we're having at least one path passed in
2169
  assert args
2170
  # ensure the first component is an absolute and normalized path name
2171
  root = args[0]
2172
  if not IsNormAbsPath(root):
2173
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
2174
  result = os.path.join(*args)
2175
  # ensure that the whole path is normalized
2176
  if not IsNormAbsPath(result):
2177
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
2178
  # check that we're still under the original prefix
2179
  prefix = os.path.commonprefix([root, result])
2180
  if prefix != root:
2181
    raise ValueError("Error: path joining resulted in different prefix"
2182
                     " (%s != %s)" % (prefix, root))
2183
  return result
2184

    
2185

    
2186
def TailFile(fname, lines=20):
2187
  """Return the last lines from a file.
2188

2189
  @note: this function will only read and parse the last 4KB of
2190
      the file; if the lines are very long, it could be that less
2191
      than the requested number of lines are returned
2192

2193
  @param fname: the file name
2194
  @type lines: int
2195
  @param lines: the (maximum) number of lines to return
2196

2197
  """
2198
  fd = open(fname, "r")
2199
  try:
2200
    fd.seek(0, 2)
2201
    pos = fd.tell()
2202
    pos = max(0, pos-4096)
2203
    fd.seek(pos, 0)
2204
    raw_data = fd.read()
2205
  finally:
2206
    fd.close()
2207

    
2208
  rows = raw_data.splitlines()
2209
  return rows[-lines:]
2210

    
2211

    
2212
def _ParseAsn1Generalizedtime(value):
2213
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
2214

2215
  @type value: string
2216
  @param value: ASN1 GENERALIZEDTIME timestamp
2217
  @return: Seconds since the Epoch (1970-01-01 00:00:00 UTC)
2218

2219
  """
2220
  m = _ASN1_TIME_REGEX.match(value)
2221
  if m:
2222
    # We have an offset
2223
    asn1time = m.group(1)
2224
    hours = int(m.group(2))
2225
    minutes = int(m.group(3))
2226
    utcoffset = (60 * hours) + minutes
2227
  else:
2228
    if not value.endswith("Z"):
2229
      raise ValueError("Missing timezone")
2230
    asn1time = value[:-1]
2231
    utcoffset = 0
2232

    
2233
  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
2234

    
2235
  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)
2236

    
2237
  return calendar.timegm(tt.utctimetuple())
2238

    
2239

    
2240
def GetX509CertValidity(cert):
2241
  """Returns the validity period of the certificate.
2242

2243
  @type cert: OpenSSL.crypto.X509
2244
  @param cert: X509 certificate object
2245

2246
  """
2247
  # The get_notBefore and get_notAfter functions are only supported in
2248
  # pyOpenSSL 0.7 and above.
2249
  try:
2250
    get_notbefore_fn = cert.get_notBefore
2251
  except AttributeError:
2252
    not_before = None
2253
  else:
2254
    not_before_asn1 = get_notbefore_fn()
2255

    
2256
    if not_before_asn1 is None:
2257
      not_before = None
2258
    else:
2259
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)
2260

    
2261
  try:
2262
    get_notafter_fn = cert.get_notAfter
2263
  except AttributeError:
2264
    not_after = None
2265
  else:
2266
    not_after_asn1 = get_notafter_fn()
2267

    
2268
    if not_after_asn1 is None:
2269
      not_after = None
2270
    else:
2271
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)
2272

    
2273
  return (not_before, not_after)
2274

    
2275

    
2276
def _VerifyCertificateInner(expired, not_before, not_after, now,
2277
                            warn_days, error_days):
2278
  """Verifies certificate validity.
2279

2280
  @type expired: bool
2281
  @param expired: Whether pyOpenSSL considers the certificate as expired
2282
  @type not_before: number or None
2283
  @param not_before: Unix timestamp before which certificate is not valid
2284
  @type not_after: number or None
2285
  @param not_after: Unix timestamp after which certificate is invalid
2286
  @type now: number
2287
  @param now: Current time as Unix timestamp
2288
  @type warn_days: number or None
2289
  @param warn_days: How many days before expiration a warning should be reported
2290
  @type error_days: number or None
2291
  @param error_days: How many days before expiration an error should be reported
2292

2293
  """
2294
  if expired:
2295
    msg = "Certificate is expired"
2296

    
2297
    if not_before is not None and not_after is not None:
2298
      msg += (" (valid from %s to %s)" %
2299
              (FormatTime(not_before), FormatTime(not_after)))
2300
    elif not_before is not None:
2301
      msg += " (valid from %s)" % FormatTime(not_before)
2302
    elif not_after is not None:
2303
      msg += " (valid until %s)" % FormatTime(not_after)
2304

    
2305
    return (CERT_ERROR, msg)
2306

    
2307
  elif not_before is not None and not_before > now:
2308
    return (CERT_WARNING,
2309
            "Certificate not yet valid (valid from %s)" %
2310
            FormatTime(not_before))
2311

    
2312
  elif not_after is not None:
2313
    remaining_days = int((not_after - now) / (24 * 3600))
2314

    
2315
    msg = "Certificate expires in about %d days" % remaining_days
2316

    
2317
    if error_days is not None and remaining_days <= error_days:
2318
      return (CERT_ERROR, msg)
2319

    
2320
    if warn_days is not None and remaining_days <= warn_days:
2321
      return (CERT_WARNING, msg)
2322

    
2323
  return (None, None)
2324

    
2325

    
2326
def VerifyX509Certificate(cert, warn_days, error_days):
2327
  """Verifies a certificate for LUVerifyCluster.
2328

2329
  @type cert: OpenSSL.crypto.X509
2330
  @param cert: X509 certificate object
2331
  @type warn_days: number or None
2332
  @param warn_days: How many days before expiration a warning should be reported
2333
  @type error_days: number or None
2334
  @param error_days: How many days before expiration an error should be reported
2335

2336
  """
2337
  # Depending on the pyOpenSSL version, this can just return (None, None)
2338
  (not_before, not_after) = GetX509CertValidity(cert)
2339

    
2340
  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
2341
                                 time.time(), warn_days, error_days)
2342

    
2343

    
2344
def SignX509Certificate(cert, key, salt):
2345
  """Sign a X509 certificate.
2346

2347
  An RFC822-like signature header is added in front of the certificate.
2348

2349
  @type cert: OpenSSL.crypto.X509
2350
  @param cert: X509 certificate object
2351
  @type key: string
2352
  @param key: Key for HMAC
2353
  @type salt: string
2354
  @param salt: Salt for HMAC
2355
  @rtype: string
2356
  @return: Serialized and signed certificate in PEM format
2357

2358
  """
2359
  if not VALID_X509_SIGNATURE_SALT.match(salt):
2360
    raise errors.GenericError("Invalid salt: %r" % salt)
2361

    
2362
  # Dumping as PEM here ensures the certificate is in a sane format
2363
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2364

    
2365
  return ("%s: %s/%s\n\n%s" %
2366
          (constants.X509_CERT_SIGNATURE_HEADER, salt,
2367
           Sha1Hmac(key, cert_pem, salt=salt),
2368
           cert_pem))
2369

    
2370

    
2371
def _ExtractX509CertificateSignature(cert_pem):
2372
  """Helper function to extract signature from X509 certificate.
2373

2374
  """
2375
  # Extract signature from original PEM data
2376
  for line in cert_pem.splitlines():
2377
    if line.startswith("---"):
2378
      break
2379

    
2380
    m = X509_SIGNATURE.match(line.strip())
2381
    if m:
2382
      return (m.group("salt"), m.group("sign"))
2383

    
2384
  raise errors.GenericError("X509 certificate signature is missing")
2385

    
2386

    
2387
def LoadSignedX509Certificate(cert_pem, key):
2388
  """Verifies a signed X509 certificate.
2389

2390
  @type cert_pem: string
2391
  @param cert_pem: Certificate in PEM format and with signature header
2392
  @type key: string
2393
  @param key: Key for HMAC
2394
  @rtype: tuple; (OpenSSL.crypto.X509, string)
2395
  @return: X509 certificate object and salt
2396

2397
  """
2398
  (salt, signature) = _ExtractX509CertificateSignature(cert_pem)
2399

    
2400
  # Load certificate
2401
  cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
2402

    
2403
  # Dump again to ensure it's in a sane format
2404
  sane_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2405

    
2406
  if not VerifySha1Hmac(key, sane_pem, signature, salt=salt):
2407
    raise errors.GenericError("X509 certificate signature is invalid")
2408

    
2409
  return (cert, salt)
2410

    
2411

    
2412
def FindMatch(data, name):
2413
  """Tries to find an item in a dictionary matching a name.
2414

2415
  Callers have to ensure the data names aren't contradictory (e.g. a regexp
2416
  that matches a string). If the name isn't a direct key, all regular
2417
  expression objects in the dictionary are matched against it.
2418

2419
  @type data: dict
2420
  @param data: Dictionary containing data
2421
  @type name: string
2422
  @param name: Name to look for
2423
  @rtype: tuple; (value in dictionary, matched groups as list)
2424

2425
  """
2426
  if name in data:
2427
    return (data[name], [])
2428

    
2429
  for key, value in data.items():
2430
    # Regex objects
2431
    if hasattr(key, "match"):
2432
      m = key.match(name)
2433
      if m:
2434
        return (value, list(m.groups()))
2435

    
2436
  return None
2437

    
2438

    
2439
def BytesToMebibyte(value):
2440
  """Converts bytes to mebibytes.
2441

2442
  @type value: int
2443
  @param value: Value in bytes
2444
  @rtype: int
2445
  @return: Value in mebibytes
2446

2447
  """
2448
  return int(round(value / (1024.0 * 1024.0), 0))
2449

    
2450

    
2451
def CalculateDirectorySize(path):
2452
  """Calculates the size of a directory recursively.
2453

2454
  @type path: string
2455
  @param path: Path to directory
2456
  @rtype: int
2457
  @return: Size in mebibytes
2458

2459
  """
2460
  size = 0
2461

    
2462
  for (curpath, _, files) in os.walk(path):
2463
    for filename in files:
2464
      st = os.lstat(PathJoin(curpath, filename))
2465
      size += st.st_size
2466

    
2467
  return BytesToMebibyte(size)
2468

    
2469

    
2470
def GetMounts(filename=constants.PROC_MOUNTS):
2471
  """Returns the list of mounted filesystems.
2472

2473
  This function is Linux-specific.
2474

2475
  @param filename: path of mounts file (/proc/mounts by default)
2476
  @rtype: list of tuples
2477
  @return: list of mount entries (device, mountpoint, fstype, options)
2478

2479
  """
2480
  # TODO(iustin): investigate non-Linux options (e.g. via mount output)
2481
  data = []
2482
  mountlines = ReadFile(filename).splitlines()
2483
  for line in mountlines:
2484
    device, mountpoint, fstype, options, _ = line.split(None, 4)
2485
    data.append((device, mountpoint, fstype, options))
2486

    
2487
  return data
2488

    
2489

    
2490
def GetFilesystemStats(path):
2491
  """Returns the total and free space on a filesystem.
2492

2493
  @type path: string
2494
  @param path: Path on filesystem to be examined
2495
  @rtype: int
2496
  @return: tuple of (Total space, Free space) in mebibytes
2497

2498
  """
2499
  st = os.statvfs(path)
2500

    
2501
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
2502
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
2503
  return (tsize, fsize)
2504

    
2505

    
2506
def RunInSeparateProcess(fn, *args):
2507
  """Runs a function in a separate process.
2508

2509
  Note: Only boolean return values are supported.
2510

2511
  @type fn: callable
2512
  @param fn: Function to be called
2513
  @rtype: bool
2514
  @return: Function's result
2515

2516
  """
2517
  pid = os.fork()
2518
  if pid == 0:
2519
    # Child process
2520
    try:
2521
      # In case the function uses temporary files
2522
      ResetTempfileModule()
2523

    
2524
      # Call function
2525
      result = int(bool(fn(*args)))
2526
      assert result in (0, 1)
2527
    except: # pylint: disable-msg=W0702
2528
      logging.exception("Error while calling function in separate process")
2529
      # 0 and 1 are reserved for the return value
2530
      result = 33
2531

    
2532
    os._exit(result) # pylint: disable-msg=W0212
2533

    
2534
  # Parent process
2535

    
2536
  # Avoid zombies and check exit code
2537
  (_, status) = os.waitpid(pid, 0)
2538

    
2539
  if os.WIFSIGNALED(status):
2540
    exitcode = None
2541
    signum = os.WTERMSIG(status)
2542
  else:
2543
    exitcode = os.WEXITSTATUS(status)
2544
    signum = None
2545

    
2546
  if not (exitcode in (0, 1) and signum is None):
2547
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
2548
                              (exitcode, signum))
2549

    
2550
  return bool(exitcode)
2551

    
2552

    
2553
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
2554
  """Reads the watcher pause file.
2555

2556
  @type filename: string
2557
  @param filename: Path to watcher pause file
2558
  @type now: None, float or int
2559
  @param now: Current time as Unix timestamp
2560
  @type remove_after: int
2561
  @param remove_after: Remove watcher pause file after specified amount of
2562
    seconds past the pause end time
2563

2564
  """
2565
  if now is None:
2566
    now = time.time()
2567

    
2568
  try:
2569
    value = ReadFile(filename)
2570
  except IOError, err:
2571
    if err.errno != errno.ENOENT:
2572
      raise
2573
    value = None
2574

    
2575
  if value is not None:
2576
    try:
2577
      value = int(value)
2578
    except ValueError:
2579
      logging.warning(("Watcher pause file (%s) contains invalid value,"
2580
                       " removing it"), filename)
2581
      RemoveFile(filename)
2582
      value = None
2583

    
2584
    if value is not None:
2585
      # Remove file if it's outdated
2586
      if now > (value + remove_after):
2587
        RemoveFile(filename)
2588
        value = None
2589

    
2590
      elif now > value:
2591
        value = None
2592

    
2593
  return value
2594

    
2595

    
2596
def GenerateSelfSignedX509Cert(common_name, validity):
2597
  """Generates a self-signed X509 certificate.
2598

2599
  @type common_name: string
2600
  @param common_name: commonName value
2601
  @type validity: int
2602
  @param validity: Validity for certificate in seconds
2603

2604
  """
2605
  # Create private and public key
2606
  key = OpenSSL.crypto.PKey()
2607
  key.generate_key(OpenSSL.crypto.TYPE_RSA, constants.RSA_KEY_BITS)
2608

    
2609
  # Create self-signed certificate
2610
  cert = OpenSSL.crypto.X509()
2611
  if common_name:
2612
    cert.get_subject().CN = common_name
2613
  cert.set_serial_number(1)
2614
  cert.gmtime_adj_notBefore(0)
2615
  cert.gmtime_adj_notAfter(validity)
2616
  cert.set_issuer(cert.get_subject())
2617
  cert.set_pubkey(key)
2618
  cert.sign(key, constants.X509_CERT_SIGN_DIGEST)
2619

    
2620
  key_pem = OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)
2621
  cert_pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
2622

    
2623
  return (key_pem, cert_pem)
2624

    
2625

    
2626
def GenerateSelfSignedSslCert(filename, common_name=constants.X509_CERT_CN,
2627
                              validity=constants.X509_CERT_DEFAULT_VALIDITY):
2628
  """Legacy function to generate self-signed X509 certificate.
2629

2630
  @type filename: str
2631
  @param filename: path to write certificate to
2632
  @type common_name: string
2633
  @param common_name: commonName value
2634
  @type validity: int
2635
  @param validity: validity of certificate in number of days
2636

2637
  """
2638
  # TODO: Investigate using the cluster name instead of X505_CERT_CN for
2639
  # common_name, as cluster-renames are very seldom, and it'd be nice if RAPI
2640
  # and node daemon certificates have the proper Subject/Issuer.
2641
  (key_pem, cert_pem) = GenerateSelfSignedX509Cert(common_name,
2642
                                                   validity * 24 * 60 * 60)
2643

    
2644
  WriteFile(filename, mode=0400, data=key_pem + cert_pem)
2645

    
2646

    
2647
def SignalHandled(signums):
2648
  """Signal Handled decoration.
2649

2650
  This special decorator installs a signal handler and then calls the target
2651
  function. The function must accept a 'signal_handlers' keyword argument,
2652
  which will contain a dict indexed by signal number, with SignalHandler
2653
  objects as values.
2654

2655
  The decorator can be safely stacked with iself, to handle multiple signals
2656
  with different handlers.
2657

2658
  @type signums: list
2659
  @param signums: signals to intercept
2660

2661
  """
2662
  def wrap(fn):
2663
    def sig_function(*args, **kwargs):
2664
      assert 'signal_handlers' not in kwargs or \
2665
             kwargs['signal_handlers'] is None or \
2666
             isinstance(kwargs['signal_handlers'], dict), \
2667
             "Wrong signal_handlers parameter in original function call"
2668
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
2669
        signal_handlers = kwargs['signal_handlers']
2670
      else:
2671
        signal_handlers = {}
2672
        kwargs['signal_handlers'] = signal_handlers
2673
      sighandler = SignalHandler(signums)
2674
      try:
2675
        for sig in signums:
2676
          signal_handlers[sig] = sighandler
2677
        return fn(*args, **kwargs)
2678
      finally:
2679
        sighandler.Reset()
2680
    return sig_function
2681
  return wrap
2682

    
2683

    
2684
class SignalWakeupFd(object):
2685
  try:
2686
    # This is only supported in Python 2.5 and above (some distributions
2687
    # backported it to Python 2.4)
2688
    _set_wakeup_fd_fn = signal.set_wakeup_fd
2689
  except AttributeError:
2690
    # Not supported
2691
    def _SetWakeupFd(self, _): # pylint: disable-msg=R0201
2692
      return -1
2693
  else:
2694
    def _SetWakeupFd(self, fd):
2695
      return self._set_wakeup_fd_fn(fd)
2696

    
2697
  def __init__(self):
2698
    """Initializes this class.
2699

2700
    """
2701
    (read_fd, write_fd) = os.pipe()
2702

    
2703
    # Once these succeeded, the file descriptors will be closed automatically.
2704
    # Buffer size 0 is important, otherwise .read() with a specified length
2705
    # might buffer data and the file descriptors won't be marked readable.
2706
    self._read_fh = os.fdopen(read_fd, "r", 0)
2707
    self._write_fh = os.fdopen(write_fd, "w", 0)
2708

    
2709
    self._previous = self._SetWakeupFd(self._write_fh.fileno())
2710

    
2711
    # Utility functions
2712
    self.fileno = self._read_fh.fileno
2713
    self.read = self._read_fh.read
2714

    
2715
  def Reset(self):
2716
    """Restores the previous wakeup file descriptor.
2717

2718
    """
2719
    if hasattr(self, "_previous") and self._previous is not None:
2720
      self._SetWakeupFd(self._previous)
2721
      self._previous = None
2722

    
2723
  def Notify(self):
2724
    """Notifies the wakeup file descriptor.
2725

2726
    """
2727
    self._write_fh.write("\0")
2728

    
2729
  def __del__(self):
2730
    """Called before object deletion.
2731

2732
    """
2733
    self.Reset()
2734

    
2735

    
2736
class SignalHandler(object):
2737
  """Generic signal handler class.
2738

2739
  It automatically restores the original handler when deconstructed or
2740
  when L{Reset} is called. You can either pass your own handler
2741
  function in or query the L{called} attribute to detect whether the
2742
  signal was sent.
2743

2744
  @type signum: list
2745
  @ivar signum: the signals we handle
2746
  @type called: boolean
2747
  @ivar called: tracks whether any of the signals have been raised
2748

2749
  """
2750
  def __init__(self, signum, handler_fn=None, wakeup=None):
2751
    """Constructs a new SignalHandler instance.
2752

2753
    @type signum: int or list of ints
2754
    @param signum: Single signal number or set of signal numbers
2755
    @type handler_fn: callable
2756
    @param handler_fn: Signal handling function
2757

2758
    """
2759
    assert handler_fn is None or callable(handler_fn)
2760

    
2761
    self.signum = set(signum)
2762
    self.called = False
2763

    
2764
    self._handler_fn = handler_fn
2765
    self._wakeup = wakeup
2766

    
2767
    self._previous = {}
2768
    try:
2769
      for signum in self.signum:
2770
        # Setup handler
2771
        prev_handler = signal.signal(signum, self._HandleSignal)
2772
        try:
2773
          self._previous[signum] = prev_handler
2774
        except:
2775
          # Restore previous handler
2776
          signal.signal(signum, prev_handler)
2777
          raise
2778
    except:
2779
      # Reset all handlers
2780
      self.Reset()
2781
      # Here we have a race condition: a handler may have already been called,
2782
      # but there's not much we can do about it at this point.
2783
      raise
2784

    
2785
  def __del__(self):
2786
    self.Reset()
2787

    
2788
  def Reset(self):
2789
    """Restore previous handler.
2790

2791
    This will reset all the signals to their previous handlers.
2792

2793
    """
2794
    for signum, prev_handler in self._previous.items():
2795
      signal.signal(signum, prev_handler)
2796
      # If successful, remove from dict
2797
      del self._previous[signum]
2798

    
2799
  def Clear(self):
2800
    """Unsets the L{called} flag.
2801

2802
    This function can be used in case a signal may arrive several times.
2803

2804
    """
2805
    self.called = False
2806

    
2807
  def _HandleSignal(self, signum, frame):
2808
    """Actual signal handling function.
2809

2810
    """
2811
    # This is not nice and not absolutely atomic, but it appears to be the only
2812
    # solution in Python -- there are no atomic types.
2813
    self.called = True
2814

    
2815
    if self._wakeup:
2816
      # Notify whoever is interested in signals
2817
      self._wakeup.Notify()
2818

    
2819
    if self._handler_fn:
2820
      self._handler_fn(signum, frame)
2821

    
2822

    
2823
class FieldSet(object):
2824
  """A simple field set.
2825

2826
  Among the features are:
2827
    - checking if a string is among a list of static string or regex objects
2828
    - checking if a whole list of string matches
2829
    - returning the matching groups from a regex match
2830

2831
  Internally, all fields are held as regular expression objects.
2832

2833
  """
2834
  def __init__(self, *items):
2835
    self.items = [re.compile("^%s$" % value) for value in items]
2836

    
2837
  def Extend(self, other_set):
2838
    """Extend the field set with the items from another one"""
2839
    self.items.extend(other_set.items)
2840

    
2841
  def Matches(self, field):
2842
    """Checks if a field matches the current set
2843

2844
    @type field: str
2845
    @param field: the string to match
2846
    @return: either None or a regular expression match object
2847

2848
    """
2849
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
2850
      return m
2851
    return None
2852

    
2853
  def NonMatching(self, items):
2854
    """Returns the list of fields not matching the current set
2855

2856
    @type items: list
2857
    @param items: the list of fields to check
2858
    @rtype: list
2859
    @return: list of non-matching fields
2860

2861
    """
2862
    return [val for val in items if not self.Matches(val)]
2863

    
2864

    
2865
class RunningTimeout(object):
2866
  """Class to calculate remaining timeout when doing several operations.
2867

2868
  """
2869
  __slots__ = [
2870
    "_allow_negative",
2871
    "_start_time",
2872
    "_time_fn",
2873
    "_timeout",
2874
    ]
2875

    
2876
  def __init__(self, timeout, allow_negative, _time_fn=time.time):
2877
    """Initializes this class.
2878

2879
    @type timeout: float
2880
    @param timeout: Timeout duration
2881
    @type allow_negative: bool
2882
    @param allow_negative: Whether to return values below zero
2883
    @param _time_fn: Time function for unittests
2884

2885
    """
2886
    object.__init__(self)
2887

    
2888
    if timeout is not None and timeout < 0.0:
2889
      raise ValueError("Timeout must not be negative")
2890

    
2891
    self._timeout = timeout
2892
    self._allow_negative = allow_negative
2893
    self._time_fn = _time_fn
2894

    
2895
    self._start_time = None
2896

    
2897
  def Remaining(self):
2898
    """Returns the remaining timeout.
2899

2900
    """
2901
    if self._timeout is None:
2902
      return None
2903

    
2904
    # Get start time on first calculation
2905
    if self._start_time is None:
2906
      self._start_time = self._time_fn()
2907

    
2908
    # Calculate remaining time
2909
    remaining_timeout = self._start_time + self._timeout - self._time_fn()
2910

    
2911
    if not self._allow_negative:
2912
      # Ensure timeout is always >= 0
2913
      return max(0.0, remaining_timeout)
2914

    
2915
    return remaining_timeout