Rename disk_template/storage_type map + cleanup
[ganeti-local] / lib / utils / io.py
index 4c1df6c..808751c 100644 (file)
@@ -1,7 +1,7 @@
 #
 #
 
-# Copyright (C) 2006, 2007, 2010, 2011 Google Inc.
+# Copyright (C) 2006, 2007, 2010, 2011, 2012 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -28,27 +28,84 @@ import shutil
 import tempfile
 import errno
 import time
+import stat
+import grp
+import pwd
 
 from ganeti import errors
 from ganeti import constants
+from ganeti import pathutils
 from ganeti.utils import filelock
 
+#: Directory used by fsck(8) to store recovered data, usually at a file
+#: system's root directory
+_LOST_AND_FOUND = "lost+found"
 
-#: Path generating random UUID
-_RANDOM_UUID_FILE = "/proc/sys/kernel/random/uuid"
+# Possible values for keep_perms in WriteFile()
+KP_NEVER = 0
+KP_ALWAYS = 1
+KP_IF_EXISTS = 2
 
+KEEP_PERMS_VALUES = [
+  KP_NEVER,
+  KP_ALWAYS,
+  KP_IF_EXISTS,
+  ]
 
-def ReadFile(file_name, size=-1):
+
+def ErrnoOrStr(err):
+  """Format an EnvironmentError exception.
+
+  If the L{err} argument has an errno attribute, it will be looked up
+  and converted into a textual C{E...} description. Otherwise the
+  string representation of the error will be returned.
+
+  @type err: L{EnvironmentError}
+  @param err: the exception to format
+
+  """
+  if hasattr(err, "errno"):
+    detail = errno.errorcode[err.errno]
+  else:
+    detail = str(err)
+  return detail
+
+
+class FileStatHelper:
+  """Helper to store file handle's C{fstat}.
+
+  Useful in combination with L{ReadFile}'s C{preread} parameter.
+
+  """
+  def __init__(self):
+    """Initializes this class.
+
+    """
+    self.st = None
+
+  def __call__(self, fh):
+    """Calls C{fstat} on file handle.
+
+    """
+    self.st = os.fstat(fh.fileno())
+
+
+def ReadFile(file_name, size=-1, preread=None):
   """Reads a file.
 
   @type size: int
   @param size: Read at most size bytes (if negative, entire file)
+  @type preread: callable receiving file handle as single parameter
+  @param preread: Function called before file is read
   @rtype: str
   @return: the (possibly partial) content of the file
 
   """
   f = open(file_name, "r")
   try:
+    if preread:
+      preread(f)
+
     return f.read(size)
   finally:
     f.close()
@@ -58,7 +115,7 @@ def WriteFile(file_name, fn=None, data=None,
               mode=None, uid=-1, gid=-1,
               atime=None, mtime=None, close=True,
               dry_run=False, backup=False,
-              prewrite=None, postwrite=None):
+              prewrite=None, postwrite=None, keep_perms=KP_NEVER):
   """(Over)write a file atomically.
 
   The file_name and either fn (a function taking one argument, the
@@ -95,6 +152,14 @@ def WriteFile(file_name, fn=None, data=None,
   @param prewrite: function to be called before writing content
   @type postwrite: callable
   @param postwrite: function to be called after writing content
+  @type keep_perms: members of L{KEEP_PERMS_VALUES}
+  @param keep_perms: if L{KP_NEVER} (default), owner, group, and mode are
+      taken from the other parameters; if L{KP_ALWAYS}, owner, group, and
+      mode are copied from the existing file; if L{KP_IF_EXISTS}, owner,
+      group, and mode are taken from the file, and if the file doesn't
+      exist, they are taken from the other parameters. It is an error to
+      pass L{KP_ALWAYS} when the file doesn't exist or when C{uid}, C{gid},
+      or C{mode} are set to non-default values.
 
   @rtype: None or int
   @return: None if the 'close' parameter evaluates to True,
@@ -114,9 +179,28 @@ def WriteFile(file_name, fn=None, data=None,
     raise errors.ProgrammerError("Both atime and mtime must be either"
                                  " set or None")
 
+  if not keep_perms in KEEP_PERMS_VALUES:
+    raise errors.ProgrammerError("Invalid value for keep_perms: %s" %
+                                 keep_perms)
+  if keep_perms == KP_ALWAYS and (uid != -1 or gid != -1 or mode is not None):
+    raise errors.ProgrammerError("When keep_perms==KP_ALWAYS, 'uid', 'gid',"
+                                 " and 'mode' cannot be set")
+
   if backup and not dry_run and os.path.isfile(file_name):
     CreateBackup(file_name)
 
+  if keep_perms == KP_ALWAYS or keep_perms == KP_IF_EXISTS:
+    # os.stat() raises an exception if the file doesn't exist
+    try:
+      file_stat = os.stat(file_name)
+      mode = stat.S_IMODE(file_stat.st_mode)
+      uid = file_stat.st_uid
+      gid = file_stat.st_gid
+    except OSError:
+      if keep_perms == KP_ALWAYS:
+        raise
+      # else: if keeep_perms == KP_IF_EXISTS it's ok if the file doesn't exist
+
   # Whether temporary file needs to be removed (e.g. if any error occurs)
   do_remove = True
 
@@ -290,9 +374,13 @@ def RemoveDir(dirname):
       raise
 
 
-def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
+def RenameFile(old, new, mkdir=False, mkdir_mode=0750, dir_uid=None,
+               dir_gid=None):
   """Renames a file.
 
+  This just creates the very least directory if it does not exist and C{mkdir}
+  is set to true.
+
   @type old: string
   @param old: Original path
   @type new: string
@@ -301,6 +389,10 @@ def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
   @param mkdir: Whether to create target directory if it doesn't exist
   @type mkdir_mode: int
   @param mkdir_mode: Mode for newly created directories
+  @type dir_uid: int
+  @param dir_uid: The uid for the (if fresh created) dir
+  @type dir_gid: int
+  @param dir_gid: The gid for the (if fresh created) dir
 
   """
   try:
@@ -311,13 +403,89 @@ def RenameFile(old, new, mkdir=False, mkdir_mode=0750):
     # as efficient.
     if mkdir and err.errno == errno.ENOENT:
       # Create directory and try again
-      Makedirs(os.path.dirname(new), mode=mkdir_mode)
+      dir_path = os.path.dirname(new)
+      MakeDirWithPerm(dir_path, mkdir_mode, dir_uid, dir_gid)
 
       return os.rename(old, new)
 
     raise
 
 
+def EnforcePermission(path, mode, uid=None, gid=None, must_exist=True,
+                      _chmod_fn=os.chmod, _chown_fn=os.chown, _stat_fn=os.stat):
+  """Enforces that given path has given permissions.
+
+  @param path: The path to the file
+  @param mode: The mode of the file
+  @param uid: The uid of the owner of this file
+  @param gid: The gid of the owner of this file
+  @param must_exist: Specifies if non-existance of path will be an error
+  @param _chmod_fn: chmod function to use (unittest only)
+  @param _chown_fn: chown function to use (unittest only)
+
+  """
+  logging.debug("Checking %s", path)
+
+  # chown takes -1 if you want to keep one part of the ownership, however
+  # None is Python standard for that. So we remap them here.
+  if uid is None:
+    uid = -1
+  if gid is None:
+    gid = -1
+
+  try:
+    st = _stat_fn(path)
+
+    fmode = stat.S_IMODE(st[stat.ST_MODE])
+    if fmode != mode:
+      logging.debug("Changing mode of %s from %#o to %#o", path, fmode, mode)
+      _chmod_fn(path, mode)
+
+    if max(uid, gid) > -1:
+      fuid = st[stat.ST_UID]
+      fgid = st[stat.ST_GID]
+      if fuid != uid or fgid != gid:
+        logging.debug("Changing owner of %s from UID %s/GID %s to"
+                      " UID %s/GID %s", path, fuid, fgid, uid, gid)
+        _chown_fn(path, uid, gid)
+  except EnvironmentError, err:
+    if err.errno == errno.ENOENT:
+      if must_exist:
+        raise errors.GenericError("Path %s should exist, but does not" % path)
+    else:
+      raise errors.GenericError("Error while changing permissions on %s: %s" %
+                                (path, err))
+
+
+def MakeDirWithPerm(path, mode, uid, gid, _lstat_fn=os.lstat,
+                    _mkdir_fn=os.mkdir, _perm_fn=EnforcePermission):
+  """Enforces that given path is a dir and has given mode, uid and gid set.
+
+  @param path: The path to the file
+  @param mode: The mode of the file
+  @param uid: The uid of the owner of this file
+  @param gid: The gid of the owner of this file
+  @param _lstat_fn: Stat function to use (unittest only)
+  @param _mkdir_fn: mkdir function to use (unittest only)
+  @param _perm_fn: permission setter function to use (unittest only)
+
+  """
+  logging.debug("Checking directory %s", path)
+  try:
+    # We don't want to follow symlinks
+    st = _lstat_fn(path)
+  except EnvironmentError, err:
+    if err.errno != errno.ENOENT:
+      raise errors.GenericError("stat(2) on %s failed: %s" % (path, err))
+    _mkdir_fn(path)
+  else:
+    if not stat.S_ISDIR(st[stat.ST_MODE]):
+      raise errors.GenericError(("Path %s is expected to be a directory, but "
+                                 "isn't") % path)
+
+  _perm_fn(path, mode, uid=uid, gid=gid)
+
+
 def Makedirs(path, mode=0750):
   """Super-mkdir; create a leaf directory and all intermediate ones.
 
@@ -356,7 +524,7 @@ def CreateBackup(file_name):
   """
   if not os.path.isfile(file_name):
     raise errors.ProgrammerError("Can't make a backup of a non-file '%s'" %
-                                file_name)
+                                 file_name)
 
   prefix = ("%s.backup-%s." %
             (os.path.basename(file_name), TimestampForFilename()))
@@ -377,7 +545,7 @@ def CreateBackup(file_name):
   return backup_name
 
 
-def ListVisibleFiles(path):
+def ListVisibleFiles(path, _is_mountpoint=os.path.ismount):
   """Returns a list of visible files in a directory.
 
   @type path: str
@@ -390,8 +558,22 @@ def ListVisibleFiles(path):
   if not IsNormAbsPath(path):
     raise errors.ProgrammerError("Path passed to ListVisibleFiles is not"
                                  " absolute/normalized: '%s'" % path)
-  files = [i for i in os.listdir(path) if not i.startswith(".")]
-  return files
+
+  mountpoint = _is_mountpoint(path)
+
+  def fn(name):
+    """File name filter.
+
+    Ignores files starting with a dot (".") as by Unix convention they're
+    considered hidden. The "lost+found" directory found at the root of some
+    filesystems is also hidden.
+
+    """
+    return not (name.startswith(".") or
+                (mountpoint and name == _LOST_AND_FOUND and
+                 os.path.isdir(os.path.join(path, name))))
+
+  return filter(fn, os.listdir(path))
 
 
 def EnsureDirs(dirs):
@@ -460,6 +642,33 @@ def IsNormAbsPath(path):
   return os.path.normpath(path) == path and os.path.isabs(path)
 
 
+def IsBelowDir(root, other_path):
+  """Check whether a path is below a root dir.
+
+  This works around the nasty byte-byte comparison of commonprefix.
+
+  """
+  if not (os.path.isabs(root) and os.path.isabs(other_path)):
+    raise ValueError("Provided paths '%s' and '%s' are not absolute" %
+                     (root, other_path))
+
+  norm_other = os.path.normpath(other_path)
+
+  if norm_other == os.sep:
+    # The root directory can never be below another path
+    return False
+
+  norm_root = os.path.normpath(root)
+
+  if norm_root == os.sep:
+    # This is the root directory, no need to add another slash
+    prepared_root = norm_root
+  else:
+    prepared_root = "%s%s" % (norm_root, os.sep)
+
+  return os.path.commonprefix([prepared_root, norm_other]) == prepared_root
+
+
 def PathJoin(*args):
   """Safe-join a list of path components.
 
@@ -483,10 +692,9 @@ def PathJoin(*args):
   if not IsNormAbsPath(result):
     raise ValueError("Invalid parameters to PathJoin: '%s'" % str(args))
   # check that we're still under the original prefix
-  prefix = os.path.commonprefix([root, result])
-  if prefix != root:
+  if not IsBelowDir(root, result):
     raise ValueError("Error: path joining resulted in different prefix"
-                     " (%s != %s)" % (prefix, root))
+                     " (%s != %s)" % (result, root))
   return result
 
 
@@ -506,7 +714,7 @@ def TailFile(fname, lines=20):
   try:
     fd.seek(0, 2)
     pos = fd.tell()
-    pos = max(0, pos-4096)
+    pos = max(0, pos - 4096)
     fd.seek(pos, 0)
     raw_data = fd.read()
   finally:
@@ -580,13 +788,24 @@ def ReadPidFile(pidfile):
       logging.exception("Can't read pid file")
     return 0
 
+  return _ParsePidFileContents(raw_data)
+
+
+def _ParsePidFileContents(data):
+  """Tries to extract a process ID from a PID file's content.
+
+  @type data: string
+  @rtype: int
+  @return: Zero if nothing could be read, PID otherwise
+
+  """
   try:
-    pid = int(raw_data)
-  except (TypeError, ValueError), err:
+    pid = int(data)
+  except (TypeError, ValueError):
     logging.info("Can't parse pid file contents", exc_info=True)
     return 0
-
-  return pid
+  else:
+    return pid
 
 
 def ReadLockedPidFile(path):
@@ -620,6 +839,29 @@ def ReadLockedPidFile(path):
   return None
 
 
+def _SplitSshKey(key):
+  """Splits a line for SSH's C{authorized_keys} file.
+
+  If the line has no options (e.g. no C{command="..."}), only the significant
+  parts, the key type and its hash, are used. Otherwise the whole line is used
+  (split at whitespace).
+
+  @type key: string
+  @param key: Key line
+  @rtype: tuple
+
+  """
+  parts = key.split()
+
+  if parts and parts[0] in constants.SSHAK_ALL:
+    # If the key has no options in front of it, we only want the significant
+    # fields
+    return (False, parts[:2])
+  else:
+    # Can't properly split the line, so use everything
+    return (True, parts)
+
+
 def AddAuthorizedKey(file_obj, key):
   """Adds an SSH public key to an authorized_keys file.
 
@@ -629,7 +871,7 @@ def AddAuthorizedKey(file_obj, key):
   @param key: string containing key
 
   """
-  key_fields = key.split()
+  key_fields = _SplitSshKey(key)
 
   if isinstance(file_obj, basestring):
     f = open(file_obj, "a+")
@@ -640,7 +882,7 @@ def AddAuthorizedKey(file_obj, key):
     nl = True
     for line in f:
       # Ignore whitespace changes
-      if line.split() == key_fields:
+      if _SplitSshKey(line) == key_fields:
         break
       nl = line.endswith("\n")
     else:
@@ -662,7 +904,7 @@ def RemoveAuthorizedKey(file_name, key):
   @param key: string containing key
 
   """
-  key_fields = key.split()
+  key_fields = _SplitSshKey(key)
 
   fd, tmpname = tempfile.mkstemp(dir=os.path.dirname(file_name))
   try:
@@ -672,7 +914,7 @@ def RemoveAuthorizedKey(file_name, key):
       try:
         for line in f:
           # Ignore whitespace changes while comparing lines
-          if line.split() != key_fields:
+          if _SplitSshKey(line) != key_fields:
             out.write(line)
 
         out.flush()
@@ -696,7 +938,7 @@ def DaemonPidFileName(name):
       daemon name
 
   """
-  return PathJoin(constants.RUN_GANETI_DIR, "%s.pid" % name)
+  return PathJoin(pathutils.RUN_DIR, "%s.pid" % name)
 
 
 def WritePidFile(pidfile):
@@ -713,13 +955,21 @@ def WritePidFile(pidfile):
   """
   # We don't rename nor truncate the file to not drop locks under
   # existing processes
-  fd_pidfile = os.open(pidfile, os.O_WRONLY | os.O_CREAT, 0600)
+  fd_pidfile = os.open(pidfile, os.O_RDWR | os.O_CREAT, 0600)
 
   # Lock the PID file (and fail if not possible to do so). Any code
   # wanting to send a signal to the daemon should try to lock the PID
   # file before reading it. If acquiring the lock succeeds, the daemon is
   # no longer running and the signal should not be sent.
-  filelock.LockFile(fd_pidfile)
+  try:
+    filelock.LockFile(fd_pidfile)
+  except errors.LockError:
+    msg = ["PID file '%s' is already locked by another process" % pidfile]
+    # Try to read PID file
+    pid = _ParsePidFileContents(os.read(fd_pidfile, 100))
+    if pid > 0:
+      msg.append(", PID read from file is %s" % pid)
+    raise errors.PidFileLockError("".join(msg))
 
   os.write(fd_pidfile, "%d\n" % os.getpid())
 
@@ -777,4 +1027,76 @@ def NewUUID():
   @rtype: str
 
   """
-  return ReadFile(_RANDOM_UUID_FILE, size=128).rstrip("\n")
+  return ReadFile(constants.RANDOM_UUID_FILE, size=128).rstrip("\n")
+
+
+class TemporaryFileManager(object):
+  """Stores the list of files to be deleted and removes them on demand.
+
+  """
+
+  def __init__(self):
+    self._files = []
+
+  def __del__(self):
+    self.Cleanup()
+
+  def Add(self, filename):
+    """Add file to list of files to be deleted.
+
+    @type filename: string
+    @param filename: path to filename to be added
+
+    """
+    self._files.append(filename)
+
+  def Remove(self, filename):
+    """Remove file from list of files to be deleted.
+
+    @type filename: string
+    @param filename: path to filename to be deleted
+
+    """
+    self._files.remove(filename)
+
+  def Cleanup(self):
+    """Delete all files marked for deletion
+
+    """
+    while self._files:
+      RemoveFile(self._files.pop())
+
+
+def IsUserInGroup(uid, gid):
+  """Returns True if the user belongs to the group.
+
+  @type uid: int
+  @param uid: the user id
+  @type gid: int
+  @param gid: the group id
+  @rtype: bool
+
+  """
+  user = pwd.getpwuid(uid)
+  group = grp.getgrgid(gid)
+  return user.pw_gid == gid or user.pw_name in group.gr_mem
+
+
+def CanRead(username, filename):
+  """Returns True if the user can access (read) the file.
+
+  @type username: string
+  @param username: the name of the user
+  @type filename: string
+  @param filename: the name of the file
+  @rtype: bool
+
+  """
+  filestats = os.stat(filename)
+  user = pwd.getpwnam(username)
+  uid = user.pw_uid
+  user_readable = filestats.st_mode & stat.S_IRUSR != 0
+  group_readable = filestats.st_mode & stat.S_IRGRP != 0
+  return ((filestats.st_uid == uid and user_readable)
+          or (filestats.st_uid != uid and
+              IsUserInGroup(uid, filestats.st_gid) and group_readable))