Statistics
| Branch: | Tag: | Revision:

root / lib / utils.py @ 6ed0bbce

History | View | Annotate | Download (83.8 kB)

1
#
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Ganeti utility module.
23

24
This module holds functions that can be used in both daemons (all) and
25
the command line scripts.
26

27
"""
28

    
29

    
30
import os
31
import time
32
import subprocess
33
import re
34
import socket
35
import tempfile
36
import shutil
37
import errno
38
import pwd
39
import itertools
40
import select
41
import fcntl
42
import resource
43
import logging
44
import logging.handlers
45
import signal
46
import datetime
47
import calendar
48
import collections
49
import struct
50
import IN
51

    
52
from cStringIO import StringIO
53

    
54
try:
55
  from hashlib import sha1
56
except ImportError:
57
  import sha
58
  sha1 = sha.new
59

    
60
try:
61
  import ctypes
62
except ImportError:
63
  ctypes = None
64

    
65
from ganeti import errors
66
from ganeti import constants
67

    
68

    
69
_locksheld = []
70
_re_shell_unquoted = re.compile('^[-.,=:/_+@A-Za-z0-9]+$')
71

    
72
debug_locks = False
73

    
74
#: when set to True, L{RunCmd} is disabled
75
no_fork = False
76

    
77
_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
78

    
79
# Structure definition for getsockopt(SOL_SOCKET, SO_PEERCRED, ...):
80
# struct ucred { pid_t pid; uid_t uid; gid_t gid; };
81
#
82
# The GNU C Library defines gid_t and uid_t to be "unsigned int" and
83
# pid_t to "int".
84
#
85
# IEEE Std 1003.1-2008:
86
# "nlink_t, uid_t, gid_t, and id_t shall be integer types"
87
# "blksize_t, pid_t, and ssize_t shall be signed integer types"
88
_STRUCT_UCRED = "iII"
89
_STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)
90

    
91
# Flags for mlockall() (from bits/mman.h)
92
_MCL_CURRENT = 1
93
_MCL_FUTURE = 2
94

    
95

    
96
class RunResult(object):
97
  """Holds the result of running external programs.
98

99
  @type exit_code: int
100
  @ivar exit_code: the exit code of the program, or None (if the program
101
      didn't exit())
102
  @type signal: int or None
103
  @ivar signal: the signal that caused the program to finish, or None
104
      (if the program wasn't terminated by a signal)
105
  @type stdout: str
106
  @ivar stdout: the standard output of the program
107
  @type stderr: str
108
  @ivar stderr: the standard error of the program
109
  @type failed: boolean
110
  @ivar failed: True in case the program was
111
      terminated by a signal or exited with a non-zero exit code
112
  @ivar fail_reason: a string detailing the termination reason
113

114
  """
115
  __slots__ = ["exit_code", "signal", "stdout", "stderr",
116
               "failed", "fail_reason", "cmd"]
117

    
118

    
119
  def __init__(self, exit_code, signal_, stdout, stderr, cmd):
120
    self.cmd = cmd
121
    self.exit_code = exit_code
122
    self.signal = signal_
123
    self.stdout = stdout
124
    self.stderr = stderr
125
    self.failed = (signal_ is not None or exit_code != 0)
126

    
127
    if self.signal is not None:
128
      self.fail_reason = "terminated by signal %s" % self.signal
129
    elif self.exit_code is not None:
130
      self.fail_reason = "exited with exit code %s" % self.exit_code
131
    else:
132
      self.fail_reason = "unable to determine termination reason"
133

    
134
    if self.failed:
135
      logging.debug("Command '%s' failed (%s); output: %s",
136
                    self.cmd, self.fail_reason, self.output)
137

    
138
  def _GetOutput(self):
139
    """Returns the combined stdout and stderr for easier usage.
140

141
    """
142
    return self.stdout + self.stderr
143

    
144
  output = property(_GetOutput, None, None, "Return full output")
145

    
146

    
147
def RunCmd(cmd, env=None, output=None, cwd='/', reset_env=False):
148
  """Execute a (shell) command.
149

150
  The command should not read from its standard input, as it will be
151
  closed.
152

153
  @type cmd: string or list
154
  @param cmd: Command to run
155
  @type env: dict
156
  @param env: Additional environment
157
  @type output: str
158
  @param output: if desired, the output of the command can be
159
      saved in a file instead of the RunResult instance; this
160
      parameter denotes the file name (if not None)
161
  @type cwd: string
162
  @param cwd: if specified, will be used as the working
163
      directory for the command; the default will be /
164
  @type reset_env: boolean
165
  @param reset_env: whether to reset or keep the default os environment
166
  @rtype: L{RunResult}
167
  @return: RunResult instance
168
  @raise errors.ProgrammerError: if we call this when forks are disabled
169

170
  """
171
  if no_fork:
172
    raise errors.ProgrammerError("utils.RunCmd() called with fork() disabled")
173

    
174
  if isinstance(cmd, list):
175
    cmd = [str(val) for val in cmd]
176
    strcmd = " ".join(cmd)
177
    shell = False
178
  else:
179
    strcmd = cmd
180
    shell = True
181
  logging.debug("RunCmd '%s'", strcmd)
182

    
183
  if not reset_env:
184
    cmd_env = os.environ.copy()
185
    cmd_env["LC_ALL"] = "C"
186
  else:
187
    cmd_env = {}
188

    
189
  if env is not None:
190
    cmd_env.update(env)
191

    
192
  try:
193
    if output is None:
194
      out, err, status = _RunCmdPipe(cmd, cmd_env, shell, cwd)
195
    else:
196
      status = _RunCmdFile(cmd, cmd_env, shell, output, cwd)
197
      out = err = ""
198
  except OSError, err:
199
    if err.errno == errno.ENOENT:
200
      raise errors.OpExecError("Can't execute '%s': not found (%s)" %
201
                               (strcmd, err))
202
    else:
203
      raise
204

    
205
  if status >= 0:
206
    exitcode = status
207
    signal_ = None
208
  else:
209
    exitcode = None
210
    signal_ = -status
211

    
212
  return RunResult(exitcode, signal_, out, err, strcmd)
213

    
214

    
215
def _RunCmdPipe(cmd, env, via_shell, cwd):
216
  """Run a command and return its output.
217

218
  @type  cmd: string or list
219
  @param cmd: Command to run
220
  @type env: dict
221
  @param env: The environment to use
222
  @type via_shell: bool
223
  @param via_shell: if we should run via the shell
224
  @type cwd: string
225
  @param cwd: the working directory for the program
226
  @rtype: tuple
227
  @return: (out, err, status)
228

229
  """
230
  poller = select.poll()
231
  child = subprocess.Popen(cmd, shell=via_shell,
232
                           stderr=subprocess.PIPE,
233
                           stdout=subprocess.PIPE,
234
                           stdin=subprocess.PIPE,
235
                           close_fds=True, env=env,
236
                           cwd=cwd)
237

    
238
  child.stdin.close()
239
  poller.register(child.stdout, select.POLLIN)
240
  poller.register(child.stderr, select.POLLIN)
241
  out = StringIO()
242
  err = StringIO()
243
  fdmap = {
244
    child.stdout.fileno(): (out, child.stdout),
245
    child.stderr.fileno(): (err, child.stderr),
246
    }
247
  for fd in fdmap:
248
    status = fcntl.fcntl(fd, fcntl.F_GETFL)
249
    fcntl.fcntl(fd, fcntl.F_SETFL, status | os.O_NONBLOCK)
250

    
251
  while fdmap:
252
    try:
253
      pollresult = poller.poll()
254
    except EnvironmentError, eerr:
255
      if eerr.errno == errno.EINTR:
256
        continue
257
      raise
258
    except select.error, serr:
259
      if serr[0] == errno.EINTR:
260
        continue
261
      raise
262

    
263
    for fd, event in pollresult:
264
      if event & select.POLLIN or event & select.POLLPRI:
265
        data = fdmap[fd][1].read()
266
        # no data from read signifies EOF (the same as POLLHUP)
267
        if not data:
268
          poller.unregister(fd)
269
          del fdmap[fd]
270
          continue
271
        fdmap[fd][0].write(data)
272
      if (event & select.POLLNVAL or event & select.POLLHUP or
273
          event & select.POLLERR):
274
        poller.unregister(fd)
275
        del fdmap[fd]
276

    
277
  out = out.getvalue()
278
  err = err.getvalue()
279

    
280
  status = child.wait()
281
  return out, err, status
282

    
283

    
284
def _RunCmdFile(cmd, env, via_shell, output, cwd):
285
  """Run a command and save its output to a file.
286

287
  @type  cmd: string or list
288
  @param cmd: Command to run
289
  @type env: dict
290
  @param env: The environment to use
291
  @type via_shell: bool
292
  @param via_shell: if we should run via the shell
293
  @type output: str
294
  @param output: the filename in which to save the output
295
  @type cwd: string
296
  @param cwd: the working directory for the program
297
  @rtype: int
298
  @return: the exit status
299

300
  """
301
  fh = open(output, "a")
302
  try:
303
    child = subprocess.Popen(cmd, shell=via_shell,
304
                             stderr=subprocess.STDOUT,
305
                             stdout=fh,
306
                             stdin=subprocess.PIPE,
307
                             close_fds=True, env=env,
308
                             cwd=cwd)
309

    
310
    child.stdin.close()
311
    status = child.wait()
312
  finally:
313
    fh.close()
314
  return status
315

    
316

    
317
def RunParts(dir_name, env=None, reset_env=False):
318
  """Run Scripts or programs in a directory
319

320
  @type dir_name: string
321
  @param dir_name: absolute path to a directory
322
  @type env: dict
323
  @param env: The environment to use
324
  @type reset_env: boolean
325
  @param reset_env: whether to reset or keep the default os environment
326
  @rtype: list of tuples
327
  @return: list of (name, (one of RUNDIR_STATUS), RunResult)
328

329
  """
330
  rr = []
331

    
332
  try:
333
    dir_contents = ListVisibleFiles(dir_name)
334
  except OSError, err:
335
    logging.warning("RunParts: skipping %s (cannot list: %s)", dir_name, err)
336
    return rr
337

    
338
  for relname in sorted(dir_contents):
339
    fname = PathJoin(dir_name, relname)
340
    if not (os.path.isfile(fname) and os.access(fname, os.X_OK) and
341
            constants.EXT_PLUGIN_MASK.match(relname) is not None):
342
      rr.append((relname, constants.RUNPARTS_SKIP, None))
343
    else:
344
      try:
345
        result = RunCmd([fname], env=env, reset_env=reset_env)
346
      except Exception, err: # pylint: disable-msg=W0703
347
        rr.append((relname, constants.RUNPARTS_ERR, str(err)))
348
      else:
349
        rr.append((relname, constants.RUNPARTS_RUN, result))
350

    
351
  return rr
352

    
353

    
354
def GetSocketCredentials(sock):
355
  """Returns the credentials of the foreign process connected to a socket.
356

357
  @param sock: Unix socket
358
  @rtype: tuple; (number, number, number)
359
  @return: The PID, UID and GID of the connected foreign process.
360

361
  """
362
  peercred = sock.getsockopt(socket.SOL_SOCKET, IN.SO_PEERCRED,
363
                             _STRUCT_UCRED_SIZE)
364
  return struct.unpack(_STRUCT_UCRED, peercred)
365

    
366

    
367
def RemoveFile(filename):
368
  """Remove a file ignoring some errors.
369

370
  Remove a file, ignoring non-existing ones or directories. Other
371
  errors are passed.
372

373
  @type filename: str
374
  @param filename: the file to be removed
375

376
  """
377
  try:
378
    os.unlink(filename)
379
  except OSError, err:
380
    if err.errno not in (errno.ENOENT, errno.EISDIR):
381
      raise
382

    
383

    
384
def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
385
  """Renames a file.
386

387
  @type old: string
388
  @param old: Original path
389
  @type new: string
390
  @param new: New path
391
  @type mkdir: bool
392
  @param mkdir: Whether to create target directory if it doesn't exist
393
  @type mkdir_mode: int
394
  @param mkdir_mode: Mode for newly created directories
395

396
  """
397
  try:
398
    return os.rename(old, new)
399
  except OSError, err:
400
    # In at least one use case of this function, the job queue, directory
401
    # creation is very rare. Checking for the directory before renaming is not
402
    # as efficient.
403
    if mkdir and err.errno == errno.ENOENT:
404
      # Create directory and try again
405
      Makedirs(os.path.dirname(new), mode=mkdir_mode)
406

    
407
      return os.rename(old, new)
408

    
409
    raise
410

    
411

    
412
def Makedirs(path, mode=0750):
413
  """Super-mkdir; create a leaf directory and all intermediate ones.
414

415
  This is a wrapper around C{os.makedirs} adding error handling not implemented
416
  before Python 2.5.
417

418
  """
419
  try:
420
    os.makedirs(path, mode)
421
  except OSError, err:
422
    # Ignore EEXIST. This is only handled in os.makedirs as included in
423
    # Python 2.5 and above.
424
    if err.errno != errno.EEXIST or not os.path.exists(path):
425
      raise
426

    
427

    
428
def ResetTempfileModule():
429
  """Resets the random name generator of the tempfile module.
430

431
  This function should be called after C{os.fork} in the child process to
432
  ensure it creates a newly seeded random generator. Otherwise it would
433
  generate the same random parts as the parent process. If several processes
434
  race for the creation of a temporary file, this could lead to one not getting
435
  a temporary name.
436

437
  """
438
  # pylint: disable-msg=W0212
439
  if hasattr(tempfile, "_once_lock") and hasattr(tempfile, "_name_sequence"):
440
    tempfile._once_lock.acquire()
441
    try:
442
      # Reset random name generator
443
      tempfile._name_sequence = None
444
    finally:
445
      tempfile._once_lock.release()
446
  else:
447
    logging.critical("The tempfile module misses at least one of the"
448
                     " '_once_lock' and '_name_sequence' attributes")
449

    
450

    
451
def _FingerprintFile(filename):
452
  """Compute the fingerprint of a file.
453

454
  If the file does not exist, a None will be returned
455
  instead.
456

457
  @type filename: str
458
  @param filename: the filename to checksum
459
  @rtype: str
460
  @return: the hex digest of the sha checksum of the contents
461
      of the file
462

463
  """
464
  if not (os.path.exists(filename) and os.path.isfile(filename)):
465
    return None
466

    
467
  f = open(filename)
468

    
469
  fp = sha1()
470
  while True:
471
    data = f.read(4096)
472
    if not data:
473
      break
474

    
475
    fp.update(data)
476

    
477
  return fp.hexdigest()
478

    
479

    
480
def FingerprintFiles(files):
481
  """Compute fingerprints for a list of files.
482

483
  @type files: list
484
  @param files: the list of filename to fingerprint
485
  @rtype: dict
486
  @return: a dictionary filename: fingerprint, holding only
487
      existing files
488

489
  """
490
  ret = {}
491

    
492
  for filename in files:
493
    cksum = _FingerprintFile(filename)
494
    if cksum:
495
      ret[filename] = cksum
496

    
497
  return ret
498

    
499

    
500
def ForceDictType(target, key_types, allowed_values=None):
501
  """Force the values of a dict to have certain types.
502

503
  @type target: dict
504
  @param target: the dict to update
505
  @type key_types: dict
506
  @param key_types: dict mapping target dict keys to types
507
                    in constants.ENFORCEABLE_TYPES
508
  @type allowed_values: list
509
  @keyword allowed_values: list of specially allowed values
510

511
  """
512
  if allowed_values is None:
513
    allowed_values = []
514

    
515
  if not isinstance(target, dict):
516
    msg = "Expected dictionary, got '%s'" % target
517
    raise errors.TypeEnforcementError(msg)
518

    
519
  for key in target:
520
    if key not in key_types:
521
      msg = "Unknown key '%s'" % key
522
      raise errors.TypeEnforcementError(msg)
523

    
524
    if target[key] in allowed_values:
525
      continue
526

    
527
    ktype = key_types[key]
528
    if ktype not in constants.ENFORCEABLE_TYPES:
529
      msg = "'%s' has non-enforceable type %s" % (key, ktype)
530
      raise errors.ProgrammerError(msg)
531

    
532
    if ktype == constants.VTYPE_STRING:
533
      if not isinstance(target[key], basestring):
534
        if isinstance(target[key], bool) and not target[key]:
535
          target[key] = ''
536
        else:
537
          msg = "'%s' (value %s) is not a valid string" % (key, target[key])
538
          raise errors.TypeEnforcementError(msg)
539
    elif ktype == constants.VTYPE_BOOL:
540
      if isinstance(target[key], basestring) and target[key]:
541
        if target[key].lower() == constants.VALUE_FALSE:
542
          target[key] = False
543
        elif target[key].lower() == constants.VALUE_TRUE:
544
          target[key] = True
545
        else:
546
          msg = "'%s' (value %s) is not a valid boolean" % (key, target[key])
547
          raise errors.TypeEnforcementError(msg)
548
      elif target[key]:
549
        target[key] = True
550
      else:
551
        target[key] = False
552
    elif ktype == constants.VTYPE_SIZE:
553
      try:
554
        target[key] = ParseUnit(target[key])
555
      except errors.UnitParseError, err:
556
        msg = "'%s' (value %s) is not a valid size. error: %s" % \
557
              (key, target[key], err)
558
        raise errors.TypeEnforcementError(msg)
559
    elif ktype == constants.VTYPE_INT:
560
      try:
561
        target[key] = int(target[key])
562
      except (ValueError, TypeError):
563
        msg = "'%s' (value %s) is not a valid integer" % (key, target[key])
564
        raise errors.TypeEnforcementError(msg)
565

    
566

    
567
def IsProcessAlive(pid):
568
  """Check if a given pid exists on the system.
569

570
  @note: zombie status is not handled, so zombie processes
571
      will be returned as alive
572
  @type pid: int
573
  @param pid: the process ID to check
574
  @rtype: boolean
575
  @return: True if the process exists
576

577
  """
578
  def _TryStat(name):
579
    try:
580
      os.stat(name)
581
      return True
582
    except EnvironmentError, err:
583
      if err.errno in (errno.ENOENT, errno.ENOTDIR):
584
        return False
585
      elif err.errno == errno.EINVAL:
586
        raise RetryAgain(err)
587
      raise
588

    
589
  assert isinstance(pid, int), "pid must be an integer"
590
  if pid <= 0:
591
    return False
592

    
593
  proc_entry = "/proc/%d/status" % pid
594
  # /proc in a multiprocessor environment can have strange behaviors.
595
  # Retry the os.stat a few times until we get a good result.
596
  try:
597
    return Retry(_TryStat, (0.01, 1.5, 0.1), 0.5, args=[proc_entry])
598
  except RetryTimeout, err:
599
    err.RaiseInner()
600

    
601

    
602
def ReadPidFile(pidfile):
603
  """Read a pid from a file.
604

605
  @type  pidfile: string
606
  @param pidfile: path to the file containing the pid
607
  @rtype: int
608
  @return: The process id, if the file exists and contains a valid PID,
609
           otherwise 0
610

611
  """
612
  try:
613
    raw_data = ReadOneLineFile(pidfile)
614
  except EnvironmentError, err:
615
    if err.errno != errno.ENOENT:
616
      logging.exception("Can't read pid file")
617
    return 0
618

    
619
  try:
620
    pid = int(raw_data)
621
  except (TypeError, ValueError), err:
622
    logging.info("Can't parse pid file contents", exc_info=True)
623
    return 0
624

    
625
  return pid
626

    
627

    
628
def MatchNameComponent(key, name_list, case_sensitive=True):
629
  """Try to match a name against a list.
630

631
  This function will try to match a name like test1 against a list
632
  like C{['test1.example.com', 'test2.example.com', ...]}. Against
633
  this list, I{'test1'} as well as I{'test1.example'} will match, but
634
  not I{'test1.ex'}. A multiple match will be considered as no match
635
  at all (e.g. I{'test1'} against C{['test1.example.com',
636
  'test1.example.org']}), except when the key fully matches an entry
637
  (e.g. I{'test1'} against C{['test1', 'test1.example.com']}).
638

639
  @type key: str
640
  @param key: the name to be searched
641
  @type name_list: list
642
  @param name_list: the list of strings against which to search the key
643
  @type case_sensitive: boolean
644
  @param case_sensitive: whether to provide a case-sensitive match
645

646
  @rtype: None or str
647
  @return: None if there is no match I{or} if there are multiple matches,
648
      otherwise the element from the list which matches
649

650
  """
651
  if key in name_list:
652
    return key
653

    
654
  re_flags = 0
655
  if not case_sensitive:
656
    re_flags |= re.IGNORECASE
657
    key = key.upper()
658
  mo = re.compile("^%s(\..*)?$" % re.escape(key), re_flags)
659
  names_filtered = []
660
  string_matches = []
661
  for name in name_list:
662
    if mo.match(name) is not None:
663
      names_filtered.append(name)
664
      if not case_sensitive and key == name.upper():
665
        string_matches.append(name)
666

    
667
  if len(string_matches) == 1:
668
    return string_matches[0]
669
  if len(names_filtered) == 1:
670
    return names_filtered[0]
671
  return None
672

    
673

    
674
class HostInfo:
675
  """Class implementing resolver and hostname functionality
676

677
  """
678
  _VALID_NAME_RE = re.compile("^[a-z0-9._-]{1,255}$")
679

    
680
  def __init__(self, name=None):
681
    """Initialize the host name object.
682

683
    If the name argument is not passed, it will use this system's
684
    name.
685

686
    """
687
    if name is None:
688
      name = self.SysName()
689

    
690
    self.query = name
691
    self.name, self.aliases, self.ipaddrs = self.LookupHostname(name)
692
    self.ip = self.ipaddrs[0]
693

    
694
  def ShortName(self):
695
    """Returns the hostname without domain.
696

697
    """
698
    return self.name.split('.')[0]
699

    
700
  @staticmethod
701
  def SysName():
702
    """Return the current system's name.
703

704
    This is simply a wrapper over C{socket.gethostname()}.
705

706
    """
707
    return socket.gethostname()
708

    
709
  @staticmethod
710
  def LookupHostname(hostname):
711
    """Look up hostname
712

713
    @type hostname: str
714
    @param hostname: hostname to look up
715

716
    @rtype: tuple
717
    @return: a tuple (name, aliases, ipaddrs) as returned by
718
        C{socket.gethostbyname_ex}
719
    @raise errors.ResolverError: in case of errors in resolving
720

721
    """
722
    try:
723
      result = socket.gethostbyname_ex(hostname)
724
    except socket.gaierror, err:
725
      # hostname not found in DNS
726
      raise errors.ResolverError(hostname, err.args[0], err.args[1])
727

    
728
    return result
729

    
730
  @classmethod
731
  def NormalizeName(cls, hostname):
732
    """Validate and normalize the given hostname.
733

734
    @attention: the validation is a bit more relaxed than the standards
735
        require; most importantly, we allow underscores in names
736
    @raise errors.OpPrereqError: when the name is not valid
737

738
    """
739
    hostname = hostname.lower()
740
    if (not cls._VALID_NAME_RE.match(hostname) or
741
        # double-dots, meaning empty label
742
        ".." in hostname or
743
        # empty initial label
744
        hostname.startswith(".")):
745
      raise errors.OpPrereqError("Invalid hostname '%s'" % hostname,
746
                                 errors.ECODE_INVAL)
747
    if hostname.endswith("."):
748
      hostname = hostname.rstrip(".")
749
    return hostname
750

    
751

    
752
def GetHostInfo(name=None):
753
  """Lookup host name and raise an OpPrereqError for failures"""
754

    
755
  try:
756
    return HostInfo(name)
757
  except errors.ResolverError, err:
758
    raise errors.OpPrereqError("The given name (%s) does not resolve: %s" %
759
                               (err[0], err[2]), errors.ECODE_RESOLVER)
760

    
761

    
762
def ListVolumeGroups():
763
  """List volume groups and their size
764

765
  @rtype: dict
766
  @return:
767
       Dictionary with keys volume name and values
768
       the size of the volume
769

770
  """
771
  command = "vgs --noheadings --units m --nosuffix -o name,size"
772
  result = RunCmd(command)
773
  retval = {}
774
  if result.failed:
775
    return retval
776

    
777
  for line in result.stdout.splitlines():
778
    try:
779
      name, size = line.split()
780
      size = int(float(size))
781
    except (IndexError, ValueError), err:
782
      logging.error("Invalid output from vgs (%s): %s", err, line)
783
      continue
784

    
785
    retval[name] = size
786

    
787
  return retval
788

    
789

    
790
def BridgeExists(bridge):
791
  """Check whether the given bridge exists in the system
792

793
  @type bridge: str
794
  @param bridge: the bridge name to check
795
  @rtype: boolean
796
  @return: True if it does
797

798
  """
799
  return os.path.isdir("/sys/class/net/%s/bridge" % bridge)
800

    
801

    
802
def NiceSort(name_list):
803
  """Sort a list of strings based on digit and non-digit groupings.
804

805
  Given a list of names C{['a1', 'a10', 'a11', 'a2']} this function
806
  will sort the list in the logical order C{['a1', 'a2', 'a10',
807
  'a11']}.
808

809
  The sort algorithm breaks each name in groups of either only-digits
810
  or no-digits. Only the first eight such groups are considered, and
811
  after that we just use what's left of the string.
812

813
  @type name_list: list
814
  @param name_list: the names to be sorted
815
  @rtype: list
816
  @return: a copy of the name list sorted with our algorithm
817

818
  """
819
  _SORTER_BASE = "(\D+|\d+)"
820
  _SORTER_FULL = "^%s%s?%s?%s?%s?%s?%s?%s?.*$" % (_SORTER_BASE, _SORTER_BASE,
821
                                                  _SORTER_BASE, _SORTER_BASE,
822
                                                  _SORTER_BASE, _SORTER_BASE,
823
                                                  _SORTER_BASE, _SORTER_BASE)
824
  _SORTER_RE = re.compile(_SORTER_FULL)
825
  _SORTER_NODIGIT = re.compile("^\D*$")
826
  def _TryInt(val):
827
    """Attempts to convert a variable to integer."""
828
    if val is None or _SORTER_NODIGIT.match(val):
829
      return val
830
    rval = int(val)
831
    return rval
832

    
833
  to_sort = [([_TryInt(grp) for grp in _SORTER_RE.match(name).groups()], name)
834
             for name in name_list]
835
  to_sort.sort()
836
  return [tup[1] for tup in to_sort]
837

    
838

    
839
def TryConvert(fn, val):
840
  """Try to convert a value ignoring errors.
841

842
  This function tries to apply function I{fn} to I{val}. If no
843
  C{ValueError} or C{TypeError} exceptions are raised, it will return
844
  the result, else it will return the original value. Any other
845
  exceptions are propagated to the caller.
846

847
  @type fn: callable
848
  @param fn: function to apply to the value
849
  @param val: the value to be converted
850
  @return: The converted value if the conversion was successful,
851
      otherwise the original value.
852

853
  """
854
  try:
855
    nv = fn(val)
856
  except (ValueError, TypeError):
857
    nv = val
858
  return nv
859

    
860

    
861
def IsValidIP(ip):
862
  """Verifies the syntax of an IPv4 address.
863

864
  This function checks if the IPv4 address passes is valid or not based
865
  on syntax (not IP range, class calculations, etc.).
866

867
  @type ip: str
868
  @param ip: the address to be checked
869
  @rtype: a regular expression match object
870
  @return: a regular expression match object, or None if the
871
      address is not valid
872

873
  """
874
  unit = "(0|[1-9]\d{0,2})"
875
  #TODO: convert and return only boolean
876
  return re.match("^%s\.%s\.%s\.%s$" % (unit, unit, unit, unit), ip)
877

    
878

    
879
def IsValidShellParam(word):
880
  """Verifies is the given word is safe from the shell's p.o.v.
881

882
  This means that we can pass this to a command via the shell and be
883
  sure that it doesn't alter the command line and is passed as such to
884
  the actual command.
885

886
  Note that we are overly restrictive here, in order to be on the safe
887
  side.
888

889
  @type word: str
890
  @param word: the word to check
891
  @rtype: boolean
892
  @return: True if the word is 'safe'
893

894
  """
895
  return bool(re.match("^[-a-zA-Z0-9._+/:%@]+$", word))
896

    
897

    
898
def BuildShellCmd(template, *args):
899
  """Build a safe shell command line from the given arguments.
900

901
  This function will check all arguments in the args list so that they
902
  are valid shell parameters (i.e. they don't contain shell
903
  metacharacters). If everything is ok, it will return the result of
904
  template % args.
905

906
  @type template: str
907
  @param template: the string holding the template for the
908
      string formatting
909
  @rtype: str
910
  @return: the expanded command line
911

912
  """
913
  for word in args:
914
    if not IsValidShellParam(word):
915
      raise errors.ProgrammerError("Shell argument '%s' contains"
916
                                   " invalid characters" % word)
917
  return template % args
918

    
919

    
920
def FormatUnit(value, units):
921
  """Formats an incoming number of MiB with the appropriate unit.
922

923
  @type value: int
924
  @param value: integer representing the value in MiB (1048576)
925
  @type units: char
926
  @param units: the type of formatting we should do:
927
      - 'h' for automatic scaling
928
      - 'm' for MiBs
929
      - 'g' for GiBs
930
      - 't' for TiBs
931
  @rtype: str
932
  @return: the formatted value (with suffix)
933

934
  """
935
  if units not in ('m', 'g', 't', 'h'):
936
    raise errors.ProgrammerError("Invalid unit specified '%s'" % str(units))
937

    
938
  suffix = ''
939

    
940
  if units == 'm' or (units == 'h' and value < 1024):
941
    if units == 'h':
942
      suffix = 'M'
943
    return "%d%s" % (round(value, 0), suffix)
944

    
945
  elif units == 'g' or (units == 'h' and value < (1024 * 1024)):
946
    if units == 'h':
947
      suffix = 'G'
948
    return "%0.1f%s" % (round(float(value) / 1024, 1), suffix)
949

    
950
  else:
951
    if units == 'h':
952
      suffix = 'T'
953
    return "%0.1f%s" % (round(float(value) / 1024 / 1024, 1), suffix)
954

    
955

    
956
def ParseUnit(input_string):
957
  """Tries to extract number and scale from the given string.
958

959
  Input must be in the format C{NUMBER+ [DOT NUMBER+] SPACE*
960
  [UNIT]}. If no unit is specified, it defaults to MiB. Return value
961
  is always an int in MiB.
962

963
  """
964
  m = re.match('^([.\d]+)\s*([a-zA-Z]+)?$', str(input_string))
965
  if not m:
966
    raise errors.UnitParseError("Invalid format")
967

    
968
  value = float(m.groups()[0])
969

    
970
  unit = m.groups()[1]
971
  if unit:
972
    lcunit = unit.lower()
973
  else:
974
    lcunit = 'm'
975

    
976
  if lcunit in ('m', 'mb', 'mib'):
977
    # Value already in MiB
978
    pass
979

    
980
  elif lcunit in ('g', 'gb', 'gib'):
981
    value *= 1024
982

    
983
  elif lcunit in ('t', 'tb', 'tib'):
984
    value *= 1024 * 1024
985

    
986
  else:
987
    raise errors.UnitParseError("Unknown unit: %s" % unit)
988

    
989
  # Make sure we round up
990
  if int(value) < value:
991
    value += 1
992

    
993
  # Round up to the next multiple of 4
994
  value = int(value)
995
  if value % 4:
996
    value += 4 - value % 4
997

    
998
  return value
999

    
1000

    
1001
def AddAuthorizedKey(file_name, key):
1002
  """Adds an SSH public key to an authorized_keys file.
1003

1004
  @type file_name: str
1005
  @param file_name: path to authorized_keys file
1006
  @type key: str
1007
  @param key: string containing key
1008

1009
  """
1010
  key_fields = key.split()
1011

    
1012
  f = open(file_name, 'a+')
1013
  try:
1014
    nl = True
1015
    for line in f:
1016
      # Ignore whitespace changes
1017
      if line.split() == key_fields:
1018
        break
1019
      nl = line.endswith('\n')
1020
    else:
1021
      if not nl:
1022
        f.write("\n")
1023
      f.write(key.rstrip('\r\n'))
1024
      f.write("\n")
1025
      f.flush()
1026
  finally:
1027
    f.close()
1028

    
1029

    
1030
def RemoveAuthorizedKey(file_name, key):
1031
  """Removes an SSH public key from an authorized_keys file.
1032

1033
  @type file_name: str
1034
  @param file_name: path to authorized_keys file
1035
  @type key: str
1036
  @param key: string containing key
1037

1038
  """
1039
  key_fields = key.split()
1040

    
1041
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1042
  try:
1043
    out = os.fdopen(fd, 'w')
1044
    try:
1045
      f = open(file_name, 'r')
1046
      try:
1047
        for line in f:
1048
          # Ignore whitespace changes while comparing lines
1049
          if line.split() != key_fields:
1050
            out.write(line)
1051

    
1052
        out.flush()
1053
        os.rename(tmpname, file_name)
1054
      finally:
1055
        f.close()
1056
    finally:
1057
      out.close()
1058
  except:
1059
    RemoveFile(tmpname)
1060
    raise
1061

    
1062

    
1063
def SetEtcHostsEntry(file_name, ip, hostname, aliases):
1064
  """Sets the name of an IP address and hostname in /etc/hosts.
1065

1066
  @type file_name: str
1067
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1068
  @type ip: str
1069
  @param ip: the IP address
1070
  @type hostname: str
1071
  @param hostname: the hostname to be added
1072
  @type aliases: list
1073
  @param aliases: the list of aliases to add for the hostname
1074

1075
  """
1076
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1077
  # Ensure aliases are unique
1078
  aliases = UniqueSequence([hostname] + aliases)[1:]
1079

    
1080
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1081
  try:
1082
    out = os.fdopen(fd, 'w')
1083
    try:
1084
      f = open(file_name, 'r')
1085
      try:
1086
        for line in f:
1087
          fields = line.split()
1088
          if fields and not fields[0].startswith('#') and ip == fields[0]:
1089
            continue
1090
          out.write(line)
1091

    
1092
        out.write("%s\t%s" % (ip, hostname))
1093
        if aliases:
1094
          out.write(" %s" % ' '.join(aliases))
1095
        out.write('\n')
1096

    
1097
        out.flush()
1098
        os.fsync(out)
1099
        os.chmod(tmpname, 0644)
1100
        os.rename(tmpname, file_name)
1101
      finally:
1102
        f.close()
1103
    finally:
1104
      out.close()
1105
  except:
1106
    RemoveFile(tmpname)
1107
    raise
1108

    
1109

    
1110
def AddHostToEtcHosts(hostname):
1111
  """Wrapper around SetEtcHostsEntry.
1112

1113
  @type hostname: str
1114
  @param hostname: a hostname that will be resolved and added to
1115
      L{constants.ETC_HOSTS}
1116

1117
  """
1118
  hi = HostInfo(name=hostname)
1119
  SetEtcHostsEntry(constants.ETC_HOSTS, hi.ip, hi.name, [hi.ShortName()])
1120

    
1121

    
1122
def RemoveEtcHostsEntry(file_name, hostname):
1123
  """Removes a hostname from /etc/hosts.
1124

1125
  IP addresses without names are removed from the file.
1126

1127
  @type file_name: str
1128
  @param file_name: path to the file to modify (usually C{/etc/hosts})
1129
  @type hostname: str
1130
  @param hostname: the hostname to be removed
1131

1132
  """
1133
  # FIXME: use WriteFile + fn rather than duplicating its efforts
1134
  fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
1135
  try:
1136
    out = os.fdopen(fd, 'w')
1137
    try:
1138
      f = open(file_name, 'r')
1139
      try:
1140
        for line in f:
1141
          fields = line.split()
1142
          if len(fields) > 1 and not fields[0].startswith('#'):
1143
            names = fields[1:]
1144
            if hostname in names:
1145
              while hostname in names:
1146
                names.remove(hostname)
1147
              if names:
1148
                out.write("%s %s\n" % (fields[0], ' '.join(names)))
1149
              continue
1150

    
1151
          out.write(line)
1152

    
1153
        out.flush()
1154
        os.fsync(out)
1155
        os.chmod(tmpname, 0644)
1156
        os.rename(tmpname, file_name)
1157
      finally:
1158
        f.close()
1159
    finally:
1160
      out.close()
1161
  except:
1162
    RemoveFile(tmpname)
1163
    raise
1164

    
1165

    
1166
def RemoveHostFromEtcHosts(hostname):
1167
  """Wrapper around RemoveEtcHostsEntry.
1168

1169
  @type hostname: str
1170
  @param hostname: hostname that will be resolved and its
1171
      full and shot name will be removed from
1172
      L{constants.ETC_HOSTS}
1173

1174
  """
1175
  hi = HostInfo(name=hostname)
1176
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.name)
1177
  RemoveEtcHostsEntry(constants.ETC_HOSTS, hi.ShortName())
1178

    
1179

    
1180
def TimestampForFilename():
1181
  """Returns the current time formatted for filenames.
1182

1183
  The format doesn't contain colons as some shells and applications them as
1184
  separators.
1185

1186
  """
1187
  return time.strftime("%Y-%m-%d_%H_%M_%S")
1188

    
1189

    
1190
def CreateBackup(file_name):
1191
  """Creates a backup of a file.
1192

1193
  @type file_name: str
1194
  @param file_name: file to be backed up
1195
  @rtype: str
1196
  @return: the path to the newly created backup
1197
  @raise errors.ProgrammerError: for invalid file names
1198

1199
  """
1200
  if not os.path.isfile(file_name):
1201
    raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
1202
                                file_name)
1203

    
1204
  prefix = ("%s.backup-%s." %
1205
            (os.path.basename(file_name), TimestampForFilename()))
1206
  dir_name = os.path.dirname(file_name)
1207

    
1208
  fsrc = open(file_name, 'rb')
1209
  try:
1210
    (fd, backup_name) = tempfile.mkstemp(prefix=prefix, dir=dir_name)
1211
    fdst = os.fdopen(fd, 'wb')
1212
    try:
1213
      logging.debug("Backing up %s at %s", file_name, backup_name)
1214
      shutil.copyfileobj(fsrc, fdst)
1215
    finally:
1216
      fdst.close()
1217
  finally:
1218
    fsrc.close()
1219

    
1220
  return backup_name
1221

    
1222

    
1223
def ShellQuote(value):
1224
  """Quotes shell argument according to POSIX.
1225

1226
  @type value: str
1227
  @param value: the argument to be quoted
1228
  @rtype: str
1229
  @return: the quoted value
1230

1231
  """
1232
  if _re_shell_unquoted.match(value):
1233
    return value
1234
  else:
1235
    return "'%s'" % value.replace("'", "'\\''")
1236

    
1237

    
1238
def ShellQuoteArgs(args):
1239
  """Quotes a list of shell arguments.
1240

1241
  @type args: list
1242
  @param args: list of arguments to be quoted
1243
  @rtype: str
1244
  @return: the quoted arguments concatenated with spaces
1245

1246
  """
1247
  return ' '.join([ShellQuote(i) for i in args])
1248

    
1249

    
1250
def TcpPing(target, port, timeout=10, live_port_needed=False, source=None):
1251
  """Simple ping implementation using TCP connect(2).
1252

1253
  Check if the given IP is reachable by doing attempting a TCP connect
1254
  to it.
1255

1256
  @type target: str
1257
  @param target: the IP or hostname to ping
1258
  @type port: int
1259
  @param port: the port to connect to
1260
  @type timeout: int
1261
  @param timeout: the timeout on the connection attempt
1262
  @type live_port_needed: boolean
1263
  @param live_port_needed: whether a closed port will cause the
1264
      function to return failure, as if there was a timeout
1265
  @type source: str or None
1266
  @param source: if specified, will cause the connect to be made
1267
      from this specific source address; failures to bind other
1268
      than C{EADDRNOTAVAIL} will be ignored
1269

1270
  """
1271
  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1272

    
1273
  success = False
1274

    
1275
  if source is not None:
1276
    try:
1277
      sock.bind((source, 0))
1278
    except socket.error, (errcode, _):
1279
      if errcode == errno.EADDRNOTAVAIL:
1280
        success = False
1281

    
1282
  sock.settimeout(timeout)
1283

    
1284
  try:
1285
    sock.connect((target, port))
1286
    sock.close()
1287
    success = True
1288
  except socket.timeout:
1289
    success = False
1290
  except socket.error, (errcode, _):
1291
    success = (not live_port_needed) and (errcode == errno.ECONNREFUSED)
1292

    
1293
  return success
1294

    
1295

    
1296
def OwnIpAddress(address):
1297
  """Check if the current host has the the given IP address.
1298

1299
  Currently this is done by TCP-pinging the address from the loopback
1300
  address.
1301

1302
  @type address: string
1303
  @param address: the address to check
1304
  @rtype: bool
1305
  @return: True if we own the address
1306

1307
  """
1308
  return TcpPing(address, constants.DEFAULT_NODED_PORT,
1309
                 source=constants.LOCALHOST_IP_ADDRESS)
1310

    
1311

    
1312
def ListVisibleFiles(path):
1313
  """Returns a list of visible files in a directory.
1314

1315
  @type path: str
1316
  @param path: the directory to enumerate
1317
  @rtype: list
1318
  @return: the list of all files not starting with a dot
1319
  @raise ProgrammerError: if L{path} is not an absolue and normalized path
1320

1321
  """
1322
  if not IsNormAbsPath(path):
1323
    raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
1324
                                 " absolute/normalized: '%s'" % path)
1325
  files = [i for i in os.listdir(path) if not i.startswith(".")]
1326
  files.sort()
1327
  return files
1328

    
1329

    
1330
def GetHomeDir(user, default=None):
1331
  """Try to get the homedir of the given user.
1332

1333
  The user can be passed either as a string (denoting the name) or as
1334
  an integer (denoting the user id). If the user is not found, the
1335
  'default' argument is returned, which defaults to None.
1336

1337
  """
1338
  try:
1339
    if isinstance(user, basestring):
1340
      result = pwd.getpwnam(user)
1341
    elif isinstance(user, (int, long)):
1342
      result = pwd.getpwuid(user)
1343
    else:
1344
      raise errors.ProgrammerError("Invalid type passed to GetHomeDir (%s)" %
1345
                                   type(user))
1346
  except KeyError:
1347
    return default
1348
  return result.pw_dir
1349

    
1350

    
1351
def NewUUID():
1352
  """Returns a random UUID.
1353

1354
  @note: This is a Linux-specific method as it uses the /proc
1355
      filesystem.
1356
  @rtype: str
1357

1358
  """
1359
  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
1360

    
1361

    
1362
def GenerateSecret(numbytes=20):
1363
  """Generates a random secret.
1364

1365
  This will generate a pseudo-random secret returning an hex string
1366
  (so that it can be used where an ASCII string is needed).
1367

1368
  @param numbytes: the number of bytes which will be represented by the returned
1369
      string (defaulting to 20, the length of a SHA1 hash)
1370
  @rtype: str
1371
  @return: an hex representation of the pseudo-random sequence
1372

1373
  """
1374
  return os.urandom(numbytes).encode('hex')
1375

    
1376

    
1377
def EnsureDirs(dirs):
1378
  """Make required directories, if they don't exist.
1379

1380
  @param dirs: list of tuples (dir_name, dir_mode)
1381
  @type dirs: list of (string, integer)
1382

1383
  """
1384
  for dir_name, dir_mode in dirs:
1385
    try:
1386
      os.mkdir(dir_name, dir_mode)
1387
    except EnvironmentError, err:
1388
      if err.errno != errno.EEXIST:
1389
        raise errors.GenericError("Cannot create needed directory"
1390
                                  " '%s': %s" % (dir_name, err))
1391
    if not os.path.isdir(dir_name):
1392
      raise errors.GenericError("%s is not a directory" % dir_name)
1393

    
1394

    
1395
def ReadFile(file_name, size=-1):
1396
  """Reads a file.
1397

1398
  @type size: int
1399
  @param size: Read at most size bytes (if negative, entire file)
1400
  @rtype: str
1401
  @return: the (possibly partial) content of the file
1402

1403
  """
1404
  f = open(file_name, "r")
1405
  try:
1406
    return f.read(size)
1407
  finally:
1408
    f.close()
1409

    
1410

    
1411
def WriteFile(file_name, fn=None, data=None,
1412
              mode=None, uid=-1, gid=-1,
1413
              atime=None, mtime=None, close=True,
1414
              dry_run=False, backup=False,
1415
              prewrite=None, postwrite=None):
1416
  """(Over)write a file atomically.
1417

1418
  The file_name and either fn (a function taking one argument, the
1419
  file descriptor, and which should write the data to it) or data (the
1420
  contents of the file) must be passed. The other arguments are
1421
  optional and allow setting the file mode, owner and group, and the
1422
  mtime/atime of the file.
1423

1424
  If the function doesn't raise an exception, it has succeeded and the
1425
  target file has the new contents. If the function has raised an
1426
  exception, an existing target file should be unmodified and the
1427
  temporary file should be removed.
1428

1429
  @type file_name: str
1430
  @param file_name: the target filename
1431
  @type fn: callable
1432
  @param fn: content writing function, called with
1433
      file descriptor as parameter
1434
  @type data: str
1435
  @param data: contents of the file
1436
  @type mode: int
1437
  @param mode: file mode
1438
  @type uid: int
1439
  @param uid: the owner of the file
1440
  @type gid: int
1441
  @param gid: the group of the file
1442
  @type atime: int
1443
  @param atime: a custom access time to be set on the file
1444
  @type mtime: int
1445
  @param mtime: a custom modification time to be set on the file
1446
  @type close: boolean
1447
  @param close: whether to close file after writing it
1448
  @type prewrite: callable
1449
  @param prewrite: function to be called before writing content
1450
  @type postwrite: callable
1451
  @param postwrite: function to be called after writing content
1452

1453
  @rtype: None or int
1454
  @return: None if the 'close' parameter evaluates to True,
1455
      otherwise the file descriptor
1456

1457
  @raise errors.ProgrammerError: if any of the arguments are not valid
1458

1459
  """
1460
  if not os.path.isabs(file_name):
1461
    raise errors.ProgrammerError("Path passed to WriteFile is not"
1462
                                 " absolute: '%s'" % file_name)
1463

    
1464
  if [fn, data].count(None) != 1:
1465
    raise errors.ProgrammerError("fn or data required")
1466

    
1467
  if [atime, mtime].count(None) == 1:
1468
    raise errors.ProgrammerError("Both atime and mtime must be either"
1469
                                 " set or None")
1470

    
1471
  if backup and not dry_run and os.path.isfile(file_name):
1472
    CreateBackup(file_name)
1473

    
1474
  dir_name, base_name = os.path.split(file_name)
1475
  fd, new_name = tempfile.mkstemp('.new', base_name, dir_name)
1476
  do_remove = True
1477
  # here we need to make sure we remove the temp file, if any error
1478
  # leaves it in place
1479
  try:
1480
    if uid != -1 or gid != -1:
1481
      os.chown(new_name, uid, gid)
1482
    if mode:
1483
      os.chmod(new_name, mode)
1484
    if callable(prewrite):
1485
      prewrite(fd)
1486
    if data is not None:
1487
      os.write(fd, data)
1488
    else:
1489
      fn(fd)
1490
    if callable(postwrite):
1491
      postwrite(fd)
1492
    os.fsync(fd)
1493
    if atime is not None and mtime is not None:
1494
      os.utime(new_name, (atime, mtime))
1495
    if not dry_run:
1496
      os.rename(new_name, file_name)
1497
      do_remove = False
1498
  finally:
1499
    if close:
1500
      os.close(fd)
1501
      result = None
1502
    else:
1503
      result = fd
1504
    if do_remove:
1505
      RemoveFile(new_name)
1506

    
1507
  return result
1508

    
1509

    
1510
def ReadOneLineFile(file_name, strict=False):
1511
  """Return the first non-empty line from a file.
1512

1513
  @type strict: boolean
1514
  @param strict: if True, abort if the file has more than one
1515
      non-empty line
1516

1517
  """
1518
  file_lines = ReadFile(file_name).splitlines()
1519
  full_lines = filter(bool, file_lines)
1520
  if not file_lines or not full_lines:
1521
    raise errors.GenericError("No data in one-liner file %s" % file_name)
1522
  elif strict and len(full_lines) > 1:
1523
    raise errors.GenericError("Too many lines in one-liner file %s" %
1524
                              file_name)
1525
  return full_lines[0]
1526

    
1527

    
1528
def FirstFree(seq, base=0):
1529
  """Returns the first non-existing integer from seq.
1530

1531
  The seq argument should be a sorted list of positive integers. The
1532
  first time the index of an element is smaller than the element
1533
  value, the index will be returned.
1534

1535
  The base argument is used to start at a different offset,
1536
  i.e. C{[3, 4, 6]} with I{offset=3} will return 5.
1537

1538
  Example: C{[0, 1, 3]} will return I{2}.
1539

1540
  @type seq: sequence
1541
  @param seq: the sequence to be analyzed.
1542
  @type base: int
1543
  @param base: use this value as the base index of the sequence
1544
  @rtype: int
1545
  @return: the first non-used index in the sequence
1546

1547
  """
1548
  for idx, elem in enumerate(seq):
1549
    assert elem >= base, "Passed element is higher than base offset"
1550
    if elem > idx + base:
1551
      # idx is not used
1552
      return idx + base
1553
  return None
1554

    
1555

    
1556
def SingleWaitForFdCondition(fdobj, event, timeout):
1557
  """Waits for a condition to occur on the socket.
1558

1559
  Immediately returns at the first interruption.
1560

1561
  @type fdobj: integer or object supporting a fileno() method
1562
  @param fdobj: entity to wait for events on
1563
  @type event: integer
1564
  @param event: ORed condition (see select module)
1565
  @type timeout: float or None
1566
  @param timeout: Timeout in seconds
1567
  @rtype: int or None
1568
  @return: None for timeout, otherwise occured conditions
1569

1570
  """
1571
  check = (event | select.POLLPRI |
1572
           select.POLLNVAL | select.POLLHUP | select.POLLERR)
1573

    
1574
  if timeout is not None:
1575
    # Poller object expects milliseconds
1576
    timeout *= 1000
1577

    
1578
  poller = select.poll()
1579
  poller.register(fdobj, event)
1580
  try:
1581
    # TODO: If the main thread receives a signal and we have no timeout, we
1582
    # could wait forever. This should check a global "quit" flag or something
1583
    # every so often.
1584
    io_events = poller.poll(timeout)
1585
  except select.error, err:
1586
    if err[0] != errno.EINTR:
1587
      raise
1588
    io_events = []
1589
  if io_events and io_events[0][1] & check:
1590
    return io_events[0][1]
1591
  else:
1592
    return None
1593

    
1594

    
1595
class FdConditionWaiterHelper(object):
1596
  """Retry helper for WaitForFdCondition.
1597

1598
  This class contains the retried and wait functions that make sure
1599
  WaitForFdCondition can continue waiting until the timeout is actually
1600
  expired.
1601

1602
  """
1603

    
1604
  def __init__(self, timeout):
1605
    self.timeout = timeout
1606

    
1607
  def Poll(self, fdobj, event):
1608
    result = SingleWaitForFdCondition(fdobj, event, self.timeout)
1609
    if result is None:
1610
      raise RetryAgain()
1611
    else:
1612
      return result
1613

    
1614
  def UpdateTimeout(self, timeout):
1615
    self.timeout = timeout
1616

    
1617

    
1618
def WaitForFdCondition(fdobj, event, timeout):
1619
  """Waits for a condition to occur on the socket.
1620

1621
  Retries until the timeout is expired, even if interrupted.
1622

1623
  @type fdobj: integer or object supporting a fileno() method
1624
  @param fdobj: entity to wait for events on
1625
  @type event: integer
1626
  @param event: ORed condition (see select module)
1627
  @type timeout: float or None
1628
  @param timeout: Timeout in seconds
1629
  @rtype: int or None
1630
  @return: None for timeout, otherwise occured conditions
1631

1632
  """
1633
  if timeout is not None:
1634
    retrywaiter = FdConditionWaiterHelper(timeout)
1635
    try:
1636
      result = Retry(retrywaiter.Poll, RETRY_REMAINING_TIME, timeout,
1637
                     args=(fdobj, event), wait_fn=retrywaiter.UpdateTimeout)
1638
    except RetryTimeout:
1639
      result = None
1640
  else:
1641
    result = None
1642
    while result is None:
1643
      result = SingleWaitForFdCondition(fdobj, event, timeout)
1644
  return result
1645

    
1646

    
1647
def UniqueSequence(seq):
1648
  """Returns a list with unique elements.
1649

1650
  Element order is preserved.
1651

1652
  @type seq: sequence
1653
  @param seq: the sequence with the source elements
1654
  @rtype: list
1655
  @return: list of unique elements from seq
1656

1657
  """
1658
  seen = set()
1659
  return [i for i in seq if i not in seen and not seen.add(i)]
1660

    
1661

    
1662
def NormalizeAndValidateMac(mac):
1663
  """Normalizes and check if a MAC address is valid.
1664

1665
  Checks whether the supplied MAC address is formally correct, only
1666
  accepts colon separated format. Normalize it to all lower.
1667

1668
  @type mac: str
1669
  @param mac: the MAC to be validated
1670
  @rtype: str
1671
  @return: returns the normalized and validated MAC.
1672

1673
  @raise errors.OpPrereqError: If the MAC isn't valid
1674

1675
  """
1676
  mac_check = re.compile("^([0-9a-f]{2}(:|$)){6}$", re.I)
1677
  if not mac_check.match(mac):
1678
    raise errors.OpPrereqError("Invalid MAC address specified: %s" %
1679
                               mac, errors.ECODE_INVAL)
1680

    
1681
  return mac.lower()
1682

    
1683

    
1684
def TestDelay(duration):
1685
  """Sleep for a fixed amount of time.
1686

1687
  @type duration: float
1688
  @param duration: the sleep duration
1689
  @rtype: boolean
1690
  @return: False for negative value, True otherwise
1691

1692
  """
1693
  if duration < 0:
1694
    return False, "Invalid sleep duration"
1695
  time.sleep(duration)
1696
  return True, None
1697

    
1698

    
1699
def _CloseFDNoErr(fd, retries=5):
1700
  """Close a file descriptor ignoring errors.
1701

1702
  @type fd: int
1703
  @param fd: the file descriptor
1704
  @type retries: int
1705
  @param retries: how many retries to make, in case we get any
1706
      other error than EBADF
1707

1708
  """
1709
  try:
1710
    os.close(fd)
1711
  except OSError, err:
1712
    if err.errno != errno.EBADF:
1713
      if retries > 0:
1714
        _CloseFDNoErr(fd, retries - 1)
1715
    # else either it's closed already or we're out of retries, so we
1716
    # ignore this and go on
1717

    
1718

    
1719
def CloseFDs(noclose_fds=None):
1720
  """Close file descriptors.
1721

1722
  This closes all file descriptors above 2 (i.e. except
1723
  stdin/out/err).
1724

1725
  @type noclose_fds: list or None
1726
  @param noclose_fds: if given, it denotes a list of file descriptor
1727
      that should not be closed
1728

1729
  """
1730
  # Default maximum for the number of available file descriptors.
1731
  if 'SC_OPEN_MAX' in os.sysconf_names:
1732
    try:
1733
      MAXFD = os.sysconf('SC_OPEN_MAX')
1734
      if MAXFD < 0:
1735
        MAXFD = 1024
1736
    except OSError:
1737
      MAXFD = 1024
1738
  else:
1739
    MAXFD = 1024
1740
  maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
1741
  if (maxfd == resource.RLIM_INFINITY):
1742
    maxfd = MAXFD
1743

    
1744
  # Iterate through and close all file descriptors (except the standard ones)
1745
  for fd in range(3, maxfd):
1746
    if noclose_fds and fd in noclose_fds:
1747
      continue
1748
    _CloseFDNoErr(fd)
1749

    
1750

    
1751
def Mlockall():
1752
  """Lock current process' virtual address space into RAM.
1753

1754
  This is equivalent to the C call mlockall(MCL_CURRENT|MCL_FUTURE),
1755
  see mlock(2) for more details. This function requires ctypes module.
1756

1757
  """
1758
  if ctypes is None:
1759
    logging.warning("Cannot set memory lock, ctypes module not found")
1760
    return
1761

    
1762
  libc = ctypes.cdll.LoadLibrary("libc.so.6")
1763
  if libc is None:
1764
    logging.error("Cannot set memory lock, ctypes cannot load libc")
1765
    return
1766

    
1767
  # Some older version of the ctypes module don't have built-in functionality
1768
  # to access the errno global variable, where function error codes are stored.
1769
  # By declaring this variable as a pointer to an integer we can then access
1770
  # its value correctly, should the mlockall call fail, in order to see what
1771
  # the actual error code was.
1772
  libc.__errno_location.restype = ctypes.POINTER(ctypes.c_int)
1773

    
1774
  if libc.mlockall(_MCL_CURRENT | _MCL_FUTURE):
1775
    logging.error("Cannot set memory lock: %s",
1776
                  os.strerror(libc.__errno_location().contents.value))
1777
    return
1778

    
1779
  logging.debug("Memory lock set")
1780

    
1781

    
1782
def Daemonize(logfile):
1783
  """Daemonize the current process.
1784

1785
  This detaches the current process from the controlling terminal and
1786
  runs it in the background as a daemon.
1787

1788
  @type logfile: str
1789
  @param logfile: the logfile to which we should redirect stdout/stderr
1790
  @rtype: int
1791
  @return: the value zero
1792

1793
  """
1794
  # pylint: disable-msg=W0212
1795
  # yes, we really want os._exit
1796
  UMASK = 077
1797
  WORKDIR = "/"
1798

    
1799
  # this might fail
1800
  pid = os.fork()
1801
  if (pid == 0):  # The first child.
1802
    os.setsid()
1803
    # this might fail
1804
    pid = os.fork() # Fork a second child.
1805
    if (pid == 0):  # The second child.
1806
      os.chdir(WORKDIR)
1807
      os.umask(UMASK)
1808
    else:
1809
      # exit() or _exit()?  See below.
1810
      os._exit(0) # Exit parent (the first child) of the second child.
1811
  else:
1812
    os._exit(0) # Exit parent of the first child.
1813

    
1814
  for fd in range(3):
1815
    _CloseFDNoErr(fd)
1816
  i = os.open("/dev/null", os.O_RDONLY) # stdin
1817
  assert i == 0, "Can't close/reopen stdin"
1818
  i = os.open(logfile, os.O_WRONLY|os.O_CREAT|os.O_APPEND, 0600) # stdout
1819
  assert i == 1, "Can't close/reopen stdout"
1820
  # Duplicate standard output to standard error.
1821
  os.dup2(1, 2)
1822
  return 0
1823

    
1824

    
1825
def DaemonPidFileName(name):
1826
  """Compute a ganeti pid file absolute path
1827

1828
  @type name: str
1829
  @param name: the daemon name
1830
  @rtype: str
1831
  @return: the full path to the pidfile corresponding to the given
1832
      daemon name
1833

1834
  """
1835
  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
1836

    
1837

    
1838
def EnsureDaemon(name):
1839
  """Check for and start daemon if not alive.
1840

1841
  """
1842
  result = RunCmd([constants.DAEMON_UTIL, "check-and-start", name])
1843
  if result.failed:
1844
    logging.error("Can't start daemon '%s', failure %s, output: %s",
1845
                  name, result.fail_reason, result.output)
1846
    return False
1847

    
1848
  return True
1849

    
1850

    
1851
def WritePidFile(name):
1852
  """Write the current process pidfile.
1853

1854
  The file will be written to L{constants.RUN_GANETI_DIR}I{/name.pid}
1855

1856
  @type name: str
1857
  @param name: the daemon name to use
1858
  @raise errors.GenericError: if the pid file already exists and
1859
      points to a live process
1860

1861
  """
1862
  pid = os.getpid()
1863
  pidfilename = DaemonPidFileName(name)
1864
  if IsProcessAlive(ReadPidFile(pidfilename)):
1865
    raise errors.GenericError("%s contains a live process" % pidfilename)
1866

    
1867
  WriteFile(pidfilename, data="%d\n" % pid)
1868

    
1869

    
1870
def RemovePidFile(name):
1871
  """Remove the current process pidfile.
1872

1873
  Any errors are ignored.
1874

1875
  @type name: str
1876
  @param name: the daemon name used to derive the pidfile name
1877

1878
  """
1879
  pidfilename = DaemonPidFileName(name)
1880
  # TODO: we could check here that the file contains our pid
1881
  try:
1882
    RemoveFile(pidfilename)
1883
  except: # pylint: disable-msg=W0702
1884
    pass
1885

    
1886

    
1887
def KillProcess(pid, signal_=signal.SIGTERM, timeout=30,
1888
                waitpid=False):
1889
  """Kill a process given by its pid.
1890

1891
  @type pid: int
1892
  @param pid: The PID to terminate.
1893
  @type signal_: int
1894
  @param signal_: The signal to send, by default SIGTERM
1895
  @type timeout: int
1896
  @param timeout: The timeout after which, if the process is still alive,
1897
                  a SIGKILL will be sent. If not positive, no such checking
1898
                  will be done
1899
  @type waitpid: boolean
1900
  @param waitpid: If true, we should waitpid on this process after
1901
      sending signals, since it's our own child and otherwise it
1902
      would remain as zombie
1903

1904
  """
1905
  def _helper(pid, signal_, wait):
1906
    """Simple helper to encapsulate the kill/waitpid sequence"""
1907
    os.kill(pid, signal_)
1908
    if wait:
1909
      try:
1910
        os.waitpid(pid, os.WNOHANG)
1911
      except OSError:
1912
        pass
1913

    
1914
  if pid <= 0:
1915
    # kill with pid=0 == suicide
1916
    raise errors.ProgrammerError("Invalid pid given '%s'" % pid)
1917

    
1918
  if not IsProcessAlive(pid):
1919
    return
1920

    
1921
  _helper(pid, signal_, waitpid)
1922

    
1923
  if timeout <= 0:
1924
    return
1925

    
1926
  def _CheckProcess():
1927
    if not IsProcessAlive(pid):
1928
      return
1929

    
1930
    try:
1931
      (result_pid, _) = os.waitpid(pid, os.WNOHANG)
1932
    except OSError:
1933
      raise RetryAgain()
1934

    
1935
    if result_pid > 0:
1936
      return
1937

    
1938
    raise RetryAgain()
1939

    
1940
  try:
1941
    # Wait up to $timeout seconds
1942
    Retry(_CheckProcess, (0.01, 1.5, 0.1), timeout)
1943
  except RetryTimeout:
1944
    pass
1945

    
1946
  if IsProcessAlive(pid):
1947
    # Kill process if it's still alive
1948
    _helper(pid, signal.SIGKILL, waitpid)
1949

    
1950

    
1951
def FindFile(name, search_path, test=os.path.exists):
1952
  """Look for a filesystem object in a given path.
1953

1954
  This is an abstract method to search for filesystem object (files,
1955
  dirs) under a given search path.
1956

1957
  @type name: str
1958
  @param name: the name to look for
1959
  @type search_path: str
1960
  @param search_path: location to start at
1961
  @type test: callable
1962
  @param test: a function taking one argument that should return True
1963
      if the a given object is valid; the default value is
1964
      os.path.exists, causing only existing files to be returned
1965
  @rtype: str or None
1966
  @return: full path to the object if found, None otherwise
1967

1968
  """
1969
  # validate the filename mask
1970
  if constants.EXT_PLUGIN_MASK.match(name) is None:
1971
    logging.critical("Invalid value passed for external script name: '%s'",
1972
                     name)
1973
    return None
1974

    
1975
  for dir_name in search_path:
1976
    # FIXME: investigate switch to PathJoin
1977
    item_name = os.path.sep.join([dir_name, name])
1978
    # check the user test and that we're indeed resolving to the given
1979
    # basename
1980
    if test(item_name) and os.path.basename(item_name) == name:
1981
      return item_name
1982
  return None
1983

    
1984

    
1985
def CheckVolumeGroupSize(vglist, vgname, minsize):
1986
  """Checks if the volume group list is valid.
1987

1988
  The function will check if a given volume group is in the list of
1989
  volume groups and has a minimum size.
1990

1991
  @type vglist: dict
1992
  @param vglist: dictionary of volume group names and their size
1993
  @type vgname: str
1994
  @param vgname: the volume group we should check
1995
  @type minsize: int
1996
  @param minsize: the minimum size we accept
1997
  @rtype: None or str
1998
  @return: None for success, otherwise the error message
1999

2000
  """
2001
  vgsize = vglist.get(vgname, None)
2002
  if vgsize is None:
2003
    return "volume group '%s' missing" % vgname
2004
  elif vgsize < minsize:
2005
    return ("volume group '%s' too small (%s MiB required, %d MiB found)" %
2006
            (vgname, minsize, vgsize))
2007
  return None
2008

    
2009

    
2010
def SplitTime(value):
2011
  """Splits time as floating point number into a tuple.
2012

2013
  @param value: Time in seconds
2014
  @type value: int or float
2015
  @return: Tuple containing (seconds, microseconds)
2016

2017
  """
2018
  (seconds, microseconds) = divmod(int(value * 1000000), 1000000)
2019

    
2020
  assert 0 <= seconds, \
2021
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2022
  assert 0 <= microseconds <= 999999, \
2023
    "Microseconds must be 0-999999, but are %s" % microseconds
2024

    
2025
  return (int(seconds), int(microseconds))
2026

    
2027

    
2028
def MergeTime(timetuple):
2029
  """Merges a tuple into time as a floating point number.
2030

2031
  @param timetuple: Time as tuple, (seconds, microseconds)
2032
  @type timetuple: tuple
2033
  @return: Time as a floating point number expressed in seconds
2034

2035
  """
2036
  (seconds, microseconds) = timetuple
2037

    
2038
  assert 0 <= seconds, \
2039
    "Seconds must be larger than or equal to 0, but are %s" % seconds
2040
  assert 0 <= microseconds <= 999999, \
2041
    "Microseconds must be 0-999999, but are %s" % microseconds
2042

    
2043
  return float(seconds) + (float(microseconds) * 0.000001)
2044

    
2045

    
2046
def GetDaemonPort(daemon_name):
2047
  """Get the daemon port for this cluster.
2048

2049
  Note that this routine does not read a ganeti-specific file, but
2050
  instead uses C{socket.getservbyname} to allow pre-customization of
2051
  this parameter outside of Ganeti.
2052

2053
  @type daemon_name: string
2054
  @param daemon_name: daemon name (in constants.DAEMONS_PORTS)
2055
  @rtype: int
2056

2057
  """
2058
  if daemon_name not in constants.DAEMONS_PORTS:
2059
    raise errors.ProgrammerError("Unknown daemon: %s" % daemon_name)
2060

    
2061
  (proto, default_port) = constants.DAEMONS_PORTS[daemon_name]
2062
  try:
2063
    port = socket.getservbyname(daemon_name, proto)
2064
  except socket.error:
2065
    port = default_port
2066

    
2067
  return port
2068

    
2069

    
2070
class LogFileHandler(logging.FileHandler):
2071
  """Log handler that doesn't fallback to stderr.
2072

2073
  When an error occurs while writing on the logfile, logging.FileHandler tries
2074
  to log on stderr. This doesn't work in ganeti since stderr is redirected to
2075
  the logfile. This class avoids failures reporting errors to /dev/console.
2076

2077
  """
2078
  def __init__(self, filename, mode="a", encoding=None):
2079
    """Open the specified file and use it as the stream for logging.
2080

2081
    Also open /dev/console to report errors while logging.
2082

2083
    """
2084
    logging.FileHandler.__init__(self, filename, mode, encoding)
2085
    self.console = open(constants.DEV_CONSOLE, "a")
2086

    
2087
  def handleError(self, record):
2088
    """Handle errors which occur during an emit() call.
2089

2090
    Try to handle errors with FileHandler method, if it fails write to
2091
    /dev/console.
2092

2093
    """
2094
    try:
2095
      logging.FileHandler.handleError(self, record)
2096
    except Exception:
2097
      try:
2098
        self.console.write("Cannot log message:\n%s\n" % self.format(record))
2099
      except Exception:
2100
        # Log handler tried everything it could, now just give up
2101
        pass
2102

    
2103

    
2104
def SetupLogging(logfile, debug=0, stderr_logging=False, program="",
2105
                 multithreaded=False, syslog=constants.SYSLOG_USAGE,
2106
                 console_logging=False):
2107
  """Configures the logging module.
2108

2109
  @type logfile: str
2110
  @param logfile: the filename to which we should log
2111
  @type debug: integer
2112
  @param debug: if greater than zero, enable debug messages, otherwise
2113
      only those at C{INFO} and above level
2114
  @type stderr_logging: boolean
2115
  @param stderr_logging: whether we should also log to the standard error
2116
  @type program: str
2117
  @param program: the name under which we should log messages
2118
  @type multithreaded: boolean
2119
  @param multithreaded: if True, will add the thread name to the log file
2120
  @type syslog: string
2121
  @param syslog: one of 'no', 'yes', 'only':
2122
      - if no, syslog is not used
2123
      - if yes, syslog is used (in addition to file-logging)
2124
      - if only, only syslog is used
2125
  @type console_logging: boolean
2126
  @param console_logging: if True, will use a FileHandler which falls back to
2127
      the system console if logging fails
2128
  @raise EnvironmentError: if we can't open the log file and
2129
      syslog/stderr logging is disabled
2130

2131
  """
2132
  fmt = "%(asctime)s: " + program + " pid=%(process)d"
2133
  sft = program + "[%(process)d]:"
2134
  if multithreaded:
2135
    fmt += "/%(threadName)s"
2136
    sft += " (%(threadName)s)"
2137
  if debug:
2138
    fmt += " %(module)s:%(lineno)s"
2139
    # no debug info for syslog loggers
2140
  fmt += " %(levelname)s %(message)s"
2141
  # yes, we do want the textual level, as remote syslog will probably
2142
  # lose the error level, and it's easier to grep for it
2143
  sft += " %(levelname)s %(message)s"
2144
  formatter = logging.Formatter(fmt)
2145
  sys_fmt = logging.Formatter(sft)
2146

    
2147
  root_logger = logging.getLogger("")
2148
  root_logger.setLevel(logging.NOTSET)
2149

    
2150
  # Remove all previously setup handlers
2151
  for handler in root_logger.handlers:
2152
    handler.close()
2153
    root_logger.removeHandler(handler)
2154

    
2155
  if stderr_logging:
2156
    stderr_handler = logging.StreamHandler()
2157
    stderr_handler.setFormatter(formatter)
2158
    if debug:
2159
      stderr_handler.setLevel(logging.NOTSET)
2160
    else:
2161
      stderr_handler.setLevel(logging.CRITICAL)
2162
    root_logger.addHandler(stderr_handler)
2163

    
2164
  if syslog in (constants.SYSLOG_YES, constants.SYSLOG_ONLY):
2165
    facility = logging.handlers.SysLogHandler.LOG_DAEMON
2166
    syslog_handler = logging.handlers.SysLogHandler(constants.SYSLOG_SOCKET,
2167
                                                    facility)
2168
    syslog_handler.setFormatter(sys_fmt)
2169
    # Never enable debug over syslog
2170
    syslog_handler.setLevel(logging.INFO)
2171
    root_logger.addHandler(syslog_handler)
2172

    
2173
  if syslog != constants.SYSLOG_ONLY:
2174
    # this can fail, if the logging directories are not setup or we have
2175
    # a permisssion problem; in this case, it's best to log but ignore
2176
    # the error if stderr_logging is True, and if false we re-raise the
2177
    # exception since otherwise we could run but without any logs at all
2178
    try:
2179
      if console_logging:
2180
        logfile_handler = LogFileHandler(logfile)
2181
      else:
2182
        logfile_handler = logging.FileHandler(logfile)
2183
      logfile_handler.setFormatter(formatter)
2184
      if debug:
2185
        logfile_handler.setLevel(logging.DEBUG)
2186
      else:
2187
        logfile_handler.setLevel(logging.INFO)
2188
      root_logger.addHandler(logfile_handler)
2189
    except EnvironmentError:
2190
      if stderr_logging or syslog == constants.SYSLOG_YES:
2191
        logging.exception("Failed to enable logging to file '%s'", logfile)
2192
      else:
2193
        # we need to re-raise the exception
2194
        raise
2195

    
2196

    
2197
def IsNormAbsPath(path):
2198
  """Check whether a path is absolute and also normalized
2199

2200
  This avoids things like /dir/../../other/path to be valid.
2201

2202
  """
2203
  return os.path.normpath(path) == path and os.path.isabs(path)
2204

    
2205

    
2206
def PathJoin(*args):
2207
  """Safe-join a list of path components.
2208

2209
  Requirements:
2210
      - the first argument must be an absolute path
2211
      - no component in the path must have backtracking (e.g. /../),
2212
        since we check for normalization at the end
2213

2214
  @param args: the path components to be joined
2215
  @raise ValueError: for invalid paths
2216

2217
  """
2218
  # ensure we're having at least one path passed in
2219
  assert args
2220
  # ensure the first component is an absolute and normalized path name
2221
  root = args[0]
2222
  if not IsNormAbsPath(root):
2223
    raise ValueError("Invalid parameter to PathJoin: '%s'" % str(args[0]))
2224
  result = os.path.join(*args)
2225
  # ensure that the whole path is normalized
2226
  if not IsNormAbsPath(result):
2227
    raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
2228
  # check that we're still under the original prefix
2229
  prefix = os.path.commonprefix([root, result])
2230
  if prefix != root:
2231
    raise ValueError("Error: path joining resulted in different prefix"
2232
                     " (%s != %s)" % (prefix, root))
2233
  return result
2234

    
2235

    
2236
def TailFile(fname, lines=20):
2237
  """Return the last lines from a file.
2238

2239
  @note: this function will only read and parse the last 4KB of
2240
      the file; if the lines are very long, it could be that less
2241
      than the requested number of lines are returned
2242

2243
  @param fname: the file name
2244
  @type lines: int
2245
  @param lines: the (maximum) number of lines to return
2246

2247
  """
2248
  fd = open(fname, "r")
2249
  try:
2250
    fd.seek(0, 2)
2251
    pos = fd.tell()
2252
    pos = max(0, pos-4096)
2253
    fd.seek(pos, 0)
2254
    raw_data = fd.read()
2255
  finally:
2256
    fd.close()
2257

    
2258
  rows = raw_data.splitlines()
2259
  return rows[-lines:]
2260

    
2261

    
2262
def _ParseAsn1Generalizedtime(value):
2263
  """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
2264

2265
  @type value: string
2266
  @param value: ASN1 GENERALIZEDTIME timestamp
2267

2268
  """
2269
  m = re.match(r"^(\d+)([-+]\d\d)(\d\d)$", value)
2270
  if m:
2271
    # We have an offset
2272
    asn1time = m.group(1)
2273
    hours = int(m.group(2))
2274
    minutes = int(m.group(3))
2275
    utcoffset = (60 * hours) + minutes
2276
  else:
2277
    if not value.endswith("Z"):
2278
      raise ValueError("Missing timezone")
2279
    asn1time = value[:-1]
2280
    utcoffset = 0
2281

    
2282
  parsed = time.strptime(asn1time, "%Y%m%d%H%M%S")
2283

    
2284
  tt = datetime.datetime(*(parsed[:7])) - datetime.timedelta(minutes=utcoffset)
2285

    
2286
  return calendar.timegm(tt.utctimetuple())
2287

    
2288

    
2289
def GetX509CertValidity(cert):
2290
  """Returns the validity period of the certificate.
2291

2292
  @type cert: OpenSSL.crypto.X509
2293
  @param cert: X509 certificate object
2294

2295
  """
2296
  # The get_notBefore and get_notAfter functions are only supported in
2297
  # pyOpenSSL 0.7 and above.
2298
  try:
2299
    get_notbefore_fn = cert.get_notBefore
2300
  except AttributeError:
2301
    not_before = None
2302
  else:
2303
    not_before_asn1 = get_notbefore_fn()
2304

    
2305
    if not_before_asn1 is None:
2306
      not_before = None
2307
    else:
2308
      not_before = _ParseAsn1Generalizedtime(not_before_asn1)
2309

    
2310
  try:
2311
    get_notafter_fn = cert.get_notAfter
2312
  except AttributeError:
2313
    not_after = None
2314
  else:
2315
    not_after_asn1 = get_notafter_fn()
2316

    
2317
    if not_after_asn1 is None:
2318
      not_after = None
2319
    else:
2320
      not_after = _ParseAsn1Generalizedtime(not_after_asn1)
2321

    
2322
  return (not_before, not_after)
2323

    
2324

    
2325
def SafeEncode(text):
2326
  """Return a 'safe' version of a source string.
2327

2328
  This function mangles the input string and returns a version that
2329
  should be safe to display/encode as ASCII. To this end, we first
2330
  convert it to ASCII using the 'backslashreplace' encoding which
2331
  should get rid of any non-ASCII chars, and then we process it
2332
  through a loop copied from the string repr sources in the python; we
2333
  don't use string_escape anymore since that escape single quotes and
2334
  backslashes too, and that is too much; and that escaping is not
2335
  stable, i.e. string_escape(string_escape(x)) != string_escape(x).
2336

2337
  @type text: str or unicode
2338
  @param text: input data
2339
  @rtype: str
2340
  @return: a safe version of text
2341

2342
  """
2343
  if isinstance(text, unicode):
2344
    # only if unicode; if str already, we handle it below
2345
    text = text.encode('ascii', 'backslashreplace')
2346
  resu = ""
2347
  for char in text:
2348
    c = ord(char)
2349
    if char  == '\t':
2350
      resu += r'\t'
2351
    elif char == '\n':
2352
      resu += r'\n'
2353
    elif char == '\r':
2354
      resu += r'\'r'
2355
    elif c < 32 or c >= 127: # non-printable
2356
      resu += "\\x%02x" % (c & 0xff)
2357
    else:
2358
      resu += char
2359
  return resu
2360

    
2361

    
2362
def UnescapeAndSplit(text, sep=","):
2363
  """Split and unescape a string based on a given separator.
2364

2365
  This function splits a string based on a separator where the
2366
  separator itself can be escape in order to be an element of the
2367
  elements. The escaping rules are (assuming coma being the
2368
  separator):
2369
    - a plain , separates the elements
2370
    - a sequence \\\\, (double backslash plus comma) is handled as a
2371
      backslash plus a separator comma
2372
    - a sequence \, (backslash plus comma) is handled as a
2373
      non-separator comma
2374

2375
  @type text: string
2376
  @param text: the string to split
2377
  @type sep: string
2378
  @param text: the separator
2379
  @rtype: string
2380
  @return: a list of strings
2381

2382
  """
2383
  # we split the list by sep (with no escaping at this stage)
2384
  slist = text.split(sep)
2385
  # next, we revisit the elements and if any of them ended with an odd
2386
  # number of backslashes, then we join it with the next
2387
  rlist = []
2388
  while slist:
2389
    e1 = slist.pop(0)
2390
    if e1.endswith("\\"):
2391
      num_b = len(e1) - len(e1.rstrip("\\"))
2392
      if num_b % 2 == 1:
2393
        e2 = slist.pop(0)
2394
        # here the backslashes remain (all), and will be reduced in
2395
        # the next step
2396
        rlist.append(e1 + sep + e2)
2397
        continue
2398
    rlist.append(e1)
2399
  # finally, replace backslash-something with something
2400
  rlist = [re.sub(r"\\(.)", r"\1", v) for v in rlist]
2401
  return rlist
2402

    
2403

    
2404
def CommaJoin(names):
2405
  """Nicely join a set of identifiers.
2406

2407
  @param names: set, list or tuple
2408
  @return: a string with the formatted results
2409

2410
  """
2411
  return ", ".join([str(val) for val in names])
2412

    
2413

    
2414
def BytesToMebibyte(value):
2415
  """Converts bytes to mebibytes.
2416

2417
  @type value: int
2418
  @param value: Value in bytes
2419
  @rtype: int
2420
  @return: Value in mebibytes
2421

2422
  """
2423
  return int(round(value / (1024.0 * 1024.0), 0))
2424

    
2425

    
2426
def CalculateDirectorySize(path):
2427
  """Calculates the size of a directory recursively.
2428

2429
  @type path: string
2430
  @param path: Path to directory
2431
  @rtype: int
2432
  @return: Size in mebibytes
2433

2434
  """
2435
  size = 0
2436

    
2437
  for (curpath, _, files) in os.walk(path):
2438
    for filename in files:
2439
      st = os.lstat(PathJoin(curpath, filename))
2440
      size += st.st_size
2441

    
2442
  return BytesToMebibyte(size)
2443

    
2444

    
2445
def GetFilesystemStats(path):
2446
  """Returns the total and free space on a filesystem.
2447

2448
  @type path: string
2449
  @param path: Path on filesystem to be examined
2450
  @rtype: int
2451
  @return: tuple of (Total space, Free space) in mebibytes
2452

2453
  """
2454
  st = os.statvfs(path)
2455

    
2456
  fsize = BytesToMebibyte(st.f_bavail * st.f_frsize)
2457
  tsize = BytesToMebibyte(st.f_blocks * st.f_frsize)
2458
  return (tsize, fsize)
2459

    
2460

    
2461
def RunInSeparateProcess(fn, *args):
2462
  """Runs a function in a separate process.
2463

2464
  Note: Only boolean return values are supported.
2465

2466
  @type fn: callable
2467
  @param fn: Function to be called
2468
  @rtype: bool
2469
  @return: Function's result
2470

2471
  """
2472
  pid = os.fork()
2473
  if pid == 0:
2474
    # Child process
2475
    try:
2476
      # In case the function uses temporary files
2477
      ResetTempfileModule()
2478

    
2479
      # Call function
2480
      result = int(bool(fn(*args)))
2481
      assert result in (0, 1)
2482
    except: # pylint: disable-msg=W0702
2483
      logging.exception("Error while calling function in separate process")
2484
      # 0 and 1 are reserved for the return value
2485
      result = 33
2486

    
2487
    os._exit(result) # pylint: disable-msg=W0212
2488

    
2489
  # Parent process
2490

    
2491
  # Avoid zombies and check exit code
2492
  (_, status) = os.waitpid(pid, 0)
2493

    
2494
  if os.WIFSIGNALED(status):
2495
    exitcode = None
2496
    signum = os.WTERMSIG(status)
2497
  else:
2498
    exitcode = os.WEXITSTATUS(status)
2499
    signum = None
2500

    
2501
  if not (exitcode in (0, 1) and signum is None):
2502
    raise errors.GenericError("Child program failed (code=%s, signal=%s)" %
2503
                              (exitcode, signum))
2504

    
2505
  return bool(exitcode)
2506

    
2507

    
2508
def LockedMethod(fn):
2509
  """Synchronized object access decorator.
2510

2511
  This decorator is intended to protect access to an object using the
2512
  object's own lock which is hardcoded to '_lock'.
2513

2514
  """
2515
  def _LockDebug(*args, **kwargs):
2516
    if debug_locks:
2517
      logging.debug(*args, **kwargs)
2518

    
2519
  def wrapper(self, *args, **kwargs):
2520
    # pylint: disable-msg=W0212
2521
    assert hasattr(self, '_lock')
2522
    lock = self._lock
2523
    _LockDebug("Waiting for %s", lock)
2524
    lock.acquire()
2525
    try:
2526
      _LockDebug("Acquired %s", lock)
2527
      result = fn(self, *args, **kwargs)
2528
    finally:
2529
      _LockDebug("Releasing %s", lock)
2530
      lock.release()
2531
      _LockDebug("Released %s", lock)
2532
    return result
2533
  return wrapper
2534

    
2535

    
2536
def LockFile(fd):
2537
  """Locks a file using POSIX locks.
2538

2539
  @type fd: int
2540
  @param fd: the file descriptor we need to lock
2541

2542
  """
2543
  try:
2544
    fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
2545
  except IOError, err:
2546
    if err.errno == errno.EAGAIN:
2547
      raise errors.LockError("File already locked")
2548
    raise
2549

    
2550

    
2551
def FormatTime(val):
2552
  """Formats a time value.
2553

2554
  @type val: float or None
2555
  @param val: the timestamp as returned by time.time()
2556
  @return: a string value or N/A if we don't have a valid timestamp
2557

2558
  """
2559
  if val is None or not isinstance(val, (int, float)):
2560
    return "N/A"
2561
  # these two codes works on Linux, but they are not guaranteed on all
2562
  # platforms
2563
  return time.strftime("%F %T", time.localtime(val))
2564

    
2565

    
2566
def ReadWatcherPauseFile(filename, now=None, remove_after=3600):
2567
  """Reads the watcher pause file.
2568

2569
  @type filename: string
2570
  @param filename: Path to watcher pause file
2571
  @type now: None, float or int
2572
  @param now: Current time as Unix timestamp
2573
  @type remove_after: int
2574
  @param remove_after: Remove watcher pause file after specified amount of
2575
    seconds past the pause end time
2576

2577
  """
2578
  if now is None:
2579
    now = time.time()
2580

    
2581
  try:
2582
    value = ReadFile(filename)
2583
  except IOError, err:
2584
    if err.errno != errno.ENOENT:
2585
      raise
2586
    value = None
2587

    
2588
  if value is not None:
2589
    try:
2590
      value = int(value)
2591
    except ValueError:
2592
      logging.warning(("Watcher pause file (%s) contains invalid value,"
2593
                       " removing it"), filename)
2594
      RemoveFile(filename)
2595
      value = None
2596

    
2597
    if value is not None:
2598
      # Remove file if it's outdated
2599
      if now > (value + remove_after):
2600
        RemoveFile(filename)
2601
        value = None
2602

    
2603
      elif now > value:
2604
        value = None
2605

    
2606
  return value
2607

    
2608

    
2609
class RetryTimeout(Exception):
2610
  """Retry loop timed out.
2611

2612
  Any arguments which was passed by the retried function to RetryAgain will be
2613
  preserved in RetryTimeout, if it is raised. If such argument was an exception
2614
  the RaiseInner helper method will reraise it.
2615

2616
  """
2617
  def RaiseInner(self):
2618
    if self.args and isinstance(self.args[0], Exception):
2619
      raise self.args[0]
2620
    else:
2621
      raise RetryTimeout(*self.args)
2622

    
2623

    
2624
class RetryAgain(Exception):
2625
  """Retry again.
2626

2627
  Any arguments passed to RetryAgain will be preserved, if a timeout occurs, as
2628
  arguments to RetryTimeout. If an exception is passed, the RaiseInner() method
2629
  of the RetryTimeout() method can be used to reraise it.
2630

2631
  """
2632

    
2633

    
2634
class _RetryDelayCalculator(object):
2635
  """Calculator for increasing delays.
2636

2637
  """
2638
  __slots__ = [
2639
    "_factor",
2640
    "_limit",
2641
    "_next",
2642
    "_start",
2643
    ]
2644

    
2645
  def __init__(self, start, factor, limit):
2646
    """Initializes this class.
2647

2648
    @type start: float
2649
    @param start: Initial delay
2650
    @type factor: float
2651
    @param factor: Factor for delay increase
2652
    @type limit: float or None
2653
    @param limit: Upper limit for delay or None for no limit
2654

2655
    """
2656
    assert start > 0.0
2657
    assert factor >= 1.0
2658
    assert limit is None or limit >= 0.0
2659

    
2660
    self._start = start
2661
    self._factor = factor
2662
    self._limit = limit
2663

    
2664
    self._next = start
2665

    
2666
  def __call__(self):
2667
    """Returns current delay and calculates the next one.
2668

2669
    """
2670
    current = self._next
2671

    
2672
    # Update for next run
2673
    if self._limit is None or self._next < self._limit:
2674
      self._next = min(self._limit, self._next * self._factor)
2675

    
2676
    return current
2677

    
2678

    
2679
#: Special delay to specify whole remaining timeout
2680
RETRY_REMAINING_TIME = object()
2681

    
2682

    
2683
def Retry(fn, delay, timeout, args=None, wait_fn=time.sleep,
2684
          _time_fn=time.time):
2685
  """Call a function repeatedly until it succeeds.
2686

2687
  The function C{fn} is called repeatedly until it doesn't throw L{RetryAgain}
2688
  anymore. Between calls a delay, specified by C{delay}, is inserted. After a
2689
  total of C{timeout} seconds, this function throws L{RetryTimeout}.
2690

2691
  C{delay} can be one of the following:
2692
    - callable returning the delay length as a float
2693
    - Tuple of (start, factor, limit)
2694
    - L{RETRY_REMAINING_TIME} to sleep until the timeout expires (this is
2695
      useful when overriding L{wait_fn} to wait for an external event)
2696
    - A static delay as a number (int or float)
2697

2698
  @type fn: callable
2699
  @param fn: Function to be called
2700
  @param delay: Either a callable (returning the delay), a tuple of (start,
2701
                factor, limit) (see L{_RetryDelayCalculator}),
2702
                L{RETRY_REMAINING_TIME} or a number (int or float)
2703
  @type timeout: float
2704
  @param timeout: Total timeout
2705
  @type wait_fn: callable
2706
  @param wait_fn: Waiting function
2707
  @return: Return value of function
2708

2709
  """
2710
  assert callable(fn)
2711
  assert callable(wait_fn)
2712
  assert callable(_time_fn)
2713

    
2714
  if args is None:
2715
    args = []
2716

    
2717
  end_time = _time_fn() + timeout
2718

    
2719
  if callable(delay):
2720
    # External function to calculate delay
2721
    calc_delay = delay
2722

    
2723
  elif isinstance(delay, (tuple, list)):
2724
    # Increasing delay with optional upper boundary
2725
    (start, factor, limit) = delay
2726
    calc_delay = _RetryDelayCalculator(start, factor, limit)
2727

    
2728
  elif delay is RETRY_REMAINING_TIME:
2729
    # Always use the remaining time
2730
    calc_delay = None
2731

    
2732
  else:
2733
    # Static delay
2734
    calc_delay = lambda: delay
2735

    
2736
  assert calc_delay is None or callable(calc_delay)
2737

    
2738
  while True:
2739
    retry_args = []
2740
    try:
2741
      # pylint: disable-msg=W0142
2742
      return fn(*args)
2743
    except RetryAgain, err:
2744
      retry_args = err.args
2745
    except RetryTimeout:
2746
      raise errors.ProgrammerError("Nested retry loop detected that didn't"
2747
                                   " handle RetryTimeout")
2748

    
2749
    remaining_time = end_time - _time_fn()
2750

    
2751
    if remaining_time < 0.0:
2752
      # pylint: disable-msg=W0142
2753
      raise RetryTimeout(*retry_args)
2754

    
2755
    assert remaining_time >= 0.0
2756

    
2757
    if calc_delay is None:
2758
      wait_fn(remaining_time)
2759
    else:
2760
      current_delay = calc_delay()
2761
      if current_delay > 0.0:
2762
        wait_fn(current_delay)
2763

    
2764

    
2765
class FileLock(object):
2766
  """Utility class for file locks.
2767

2768
  """
2769
  def __init__(self, fd, filename):
2770
    """Constructor for FileLock.
2771

2772
    @type fd: file
2773
    @param fd: File object
2774
    @type filename: str
2775
    @param filename: Path of the file opened at I{fd}
2776

2777
    """
2778
    self.fd = fd
2779
    self.filename = filename
2780

    
2781
  @classmethod
2782
  def Open(cls, filename):
2783
    """Creates and opens a file to be used as a file-based lock.
2784

2785
    @type filename: string
2786
    @param filename: path to the file to be locked
2787

2788
    """
2789
    # Using "os.open" is necessary to allow both opening existing file
2790
    # read/write and creating if not existing. Vanilla "open" will truncate an
2791
    # existing file -or- allow creating if not existing.
2792
    return cls(os.fdopen(os.open(filename, os.O_RDWR | os.O_CREAT), "w+"),
2793
               filename)
2794

    
2795
  def __del__(self):
2796
    self.Close()
2797

    
2798
  def Close(self):
2799
    """Close the file and release the lock.
2800

2801
    """
2802
    if hasattr(self, "fd") and self.fd:
2803
      self.fd.close()
2804
      self.fd = None
2805

    
2806
  def _flock(self, flag, blocking, timeout, errmsg):
2807
    """Wrapper for fcntl.flock.
2808

2809
    @type flag: int
2810
    @param flag: operation flag
2811
    @type blocking: bool
2812
    @param blocking: whether the operation should be done in blocking mode.
2813
    @type timeout: None or float
2814
    @param timeout: for how long the operation should be retried (implies
2815
                    non-blocking mode).
2816
    @type errmsg: string
2817
    @param errmsg: error message in case operation fails.
2818

2819
    """
2820
    assert self.fd, "Lock was closed"
2821
    assert timeout is None or timeout >= 0, \
2822
      "If specified, timeout must be positive"
2823
    assert not (flag & fcntl.LOCK_NB), "LOCK_NB must not be set"
2824

    
2825
    # When a timeout is used, LOCK_NB must always be set
2826
    if not (timeout is None and blocking):
2827
      flag |= fcntl.LOCK_NB
2828

    
2829
    if timeout is None:
2830
      self._Lock(self.fd, flag, timeout)
2831
    else:
2832
      try:
2833
        Retry(self._Lock, (0.1, 1.2, 1.0), timeout,
2834
              args=(self.fd, flag, timeout))
2835
      except RetryTimeout:
2836
        raise errors.LockError(errmsg)
2837

    
2838
  @staticmethod
2839
  def _Lock(fd, flag, timeout):
2840
    try:
2841
      fcntl.flock(fd, flag)
2842
    except IOError, err:
2843
      if timeout is not None and err.errno == errno.EAGAIN:
2844
        raise RetryAgain()
2845

    
2846
      logging.exception("fcntl.flock failed")
2847
      raise
2848

    
2849
  def Exclusive(self, blocking=False, timeout=None):
2850
    """Locks the file in exclusive mode.
2851

2852
    @type blocking: boolean
2853
    @param blocking: whether to block and wait until we
2854
        can lock the file or return immediately
2855
    @type timeout: int or None
2856
    @param timeout: if not None, the duration to wait for the lock
2857
        (in blocking mode)
2858

2859
    """
2860
    self._flock(fcntl.LOCK_EX, blocking, timeout,
2861
                "Failed to lock %s in exclusive mode" % self.filename)
2862

    
2863
  def Shared(self, blocking=False, timeout=None):
2864
    """Locks the file in shared mode.
2865

2866
    @type blocking: boolean
2867
    @param blocking: whether to block and wait until we
2868
        can lock the file or return immediately
2869
    @type timeout: int or None
2870
    @param timeout: if not None, the duration to wait for the lock
2871
        (in blocking mode)
2872

2873
    """
2874
    self._flock(fcntl.LOCK_SH, blocking, timeout,
2875
                "Failed to lock %s in shared mode" % self.filename)
2876

    
2877
  def Unlock(self, blocking=True, timeout=None):
2878
    """Unlocks the file.
2879

2880
    According to C{flock(2)}, unlocking can also be a nonblocking
2881
    operation::
2882

2883
      To make a non-blocking request, include LOCK_NB with any of the above
2884
      operations.
2885

2886
    @type blocking: boolean
2887
    @param blocking: whether to block and wait until we
2888
        can lock the file or return immediately
2889
    @type timeout: int or None
2890
    @param timeout: if not None, the duration to wait for the lock
2891
        (in blocking mode)
2892

2893
    """
2894
    self._flock(fcntl.LOCK_UN, blocking, timeout,
2895
                "Failed to unlock %s" % self.filename)
2896

    
2897

    
2898
class LineSplitter:
2899
  """Splits data chunks into lines separated by newline.
2900

2901
  Instances provide a file-like interface.
2902

2903
  """
2904
  def __init__(self, line_fn, *args):
2905
    """Initializes this class.
2906

2907
    @type line_fn: callable
2908
    @param line_fn: Function called for each line, first parameter is line
2909
    @param args: Extra arguments for L{line_fn}
2910

2911
    """
2912
    assert callable(line_fn)
2913

    
2914
    if args:
2915
      # Python 2.4 doesn't have functools.partial yet
2916
      self._line_fn = \
2917
        lambda line: line_fn(line, *args) # pylint: disable-msg=W0142
2918
    else:
2919
      self._line_fn = line_fn
2920

    
2921
    self._lines = collections.deque()
2922
    self._buffer = ""
2923

    
2924
  def write(self, data):
2925
    parts = (self._buffer + data).split("\n")
2926
    self._buffer = parts.pop()
2927
    self._lines.extend(parts)
2928

    
2929
  def flush(self):
2930
    while self._lines:
2931
      self._line_fn(self._lines.popleft().rstrip("\r\n"))
2932

    
2933
  def close(self):
2934
    self.flush()
2935
    if self._buffer:
2936
      self._line_fn(self._buffer)
2937

    
2938

    
2939
def SignalHandled(signums):
2940
  """Signal Handled decoration.
2941

2942
  This special decorator installs a signal handler and then calls the target
2943
  function. The function must accept a 'signal_handlers' keyword argument,
2944
  which will contain a dict indexed by signal number, with SignalHandler
2945
  objects as values.
2946

2947
  The decorator can be safely stacked with iself, to handle multiple signals
2948
  with different handlers.
2949

2950
  @type signums: list
2951
  @param signums: signals to intercept
2952

2953
  """
2954
  def wrap(fn):
2955
    def sig_function(*args, **kwargs):
2956
      assert 'signal_handlers' not in kwargs or \
2957
             kwargs['signal_handlers'] is None or \
2958
             isinstance(kwargs['signal_handlers'], dict), \
2959
             "Wrong signal_handlers parameter in original function call"
2960
      if 'signal_handlers' in kwargs and kwargs['signal_handlers'] is not None:
2961
        signal_handlers = kwargs['signal_handlers']
2962
      else:
2963
        signal_handlers = {}
2964
        kwargs['signal_handlers'] = signal_handlers
2965
      sighandler = SignalHandler(signums)
2966
      try:
2967
        for sig in signums:
2968
          signal_handlers[sig] = sighandler
2969
        return fn(*args, **kwargs)
2970
      finally:
2971
        sighandler.Reset()
2972
    return sig_function
2973
  return wrap
2974

    
2975

    
2976
class SignalHandler(object):
2977
  """Generic signal handler class.
2978

2979
  It automatically restores the original handler when deconstructed or
2980
  when L{Reset} is called. You can either pass your own handler
2981
  function in or query the L{called} attribute to detect whether the
2982
  signal was sent.
2983

2984
  @type signum: list
2985
  @ivar signum: the signals we handle
2986
  @type called: boolean
2987
  @ivar called: tracks whether any of the signals have been raised
2988

2989
  """
2990
  def __init__(self, signum):
2991
    """Constructs a new SignalHandler instance.
2992

2993
    @type signum: int or list of ints
2994
    @param signum: Single signal number or set of signal numbers
2995

2996
    """
2997
    self.signum = set(signum)
2998
    self.called = False
2999

    
3000
    self._previous = {}
3001
    try:
3002
      for signum in self.signum:
3003
        # Setup handler
3004
        prev_handler = signal.signal(signum, self._HandleSignal)
3005
        try:
3006
          self._previous[signum] = prev_handler
3007
        except:
3008
          # Restore previous handler
3009
          signal.signal(signum, prev_handler)
3010
          raise
3011
    except:
3012
      # Reset all handlers
3013
      self.Reset()
3014
      # Here we have a race condition: a handler may have already been called,
3015
      # but there's not much we can do about it at this point.
3016
      raise
3017

    
3018
  def __del__(self):
3019
    self.Reset()
3020

    
3021
  def Reset(self):
3022
    """Restore previous handler.
3023

3024
    This will reset all the signals to their previous handlers.
3025

3026
    """
3027
    for signum, prev_handler in self._previous.items():
3028
      signal.signal(signum, prev_handler)
3029
      # If successful, remove from dict
3030
      del self._previous[signum]
3031

    
3032
  def Clear(self):
3033
    """Unsets the L{called} flag.
3034

3035
    This function can be used in case a signal may arrive several times.
3036

3037
    """
3038
    self.called = False
3039

    
3040
  # we don't care about arguments, but we leave them named for the future
3041
  def _HandleSignal(self, signum, frame): # pylint: disable-msg=W0613
3042
    """Actual signal handling function.
3043

3044
    """
3045
    # This is not nice and not absolutely atomic, but it appears to be the only
3046
    # solution in Python -- there are no atomic types.
3047
    self.called = True
3048

    
3049

    
3050
class FieldSet(object):
3051
  """A simple field set.
3052

3053
  Among the features are:
3054
    - checking if a string is among a list of static string or regex objects
3055
    - checking if a whole list of string matches
3056
    - returning the matching groups from a regex match
3057

3058
  Internally, all fields are held as regular expression objects.
3059

3060
  """
3061
  def __init__(self, *items):
3062
    self.items = [re.compile("^%s$" % value) for value in items]
3063

    
3064
  def Extend(self, other_set):
3065
    """Extend the field set with the items from another one"""
3066
    self.items.extend(other_set.items)
3067

    
3068
  def Matches(self, field):
3069
    """Checks if a field matches the current set
3070

3071
    @type field: str
3072
    @param field: the string to match
3073
    @return: either None or a regular expression match object
3074

3075
    """
3076
    for m in itertools.ifilter(None, (val.match(field) for val in self.items)):
3077
      return m
3078
    return None
3079

    
3080
  def NonMatching(self, items):
3081
    """Returns the list of fields not matching the current set
3082

3083
    @type items: list
3084
    @param items: the list of fields to check
3085
    @rtype: list
3086
    @return: list of non-matching fields
3087

3088
    """
3089
    return [val for val in items if not self.Matches(val)]