Add a "safe" file wrapper over WriteFile
authorIustin Pop <iustin@google.com>
Fri, 22 Oct 2010 12:29:47 +0000 (14:29 +0200)
committerIustin Pop <iustin@google.com>
Fri, 22 Oct 2010 15:22:45 +0000 (17:22 +0200)
Signed-off-by: Iustin Pop <iustin@google.com>
Reviewed-by: Michael Hanselmann <hansmi@google.com>

lib/utils.py
test/ganeti.utils_unittest.py

index aa84cb1..bdd8610 100644 (file)
@@ -1912,6 +1912,31 @@ def VerifyFileID(fi_disk, fi_ours):
   return (d1, i1) == (d2, i2) and m1 <= m2
 
 
+def SafeWriteFile(file_name, file_id, **kwargs):
+  """Wraper over L{WriteFile} that locks the target file.
+
+  By keeping the target file locked during WriteFile, we ensure that
+  cooperating writers will safely serialise access to the file.
+
+  @type file_name: str
+  @param file_name: the target filename
+  @type file_id: tuple
+  @param file_id: a result from L{GetFileID}
+
+  """
+  fd = os.open(file_name, os.O_RDONLY | os.O_CREAT)
+  try:
+    LockFile(fd)
+    if file_id is not None:
+      disk_id = GetFileID(fd=fd)
+      if not VerifyFileID(disk_id, file_id):
+        raise errors.LockError("Cannot overwrite file %s, it has been modified"
+                               " since last written" % file_name)
+    return WriteFile(file_name, **kwargs)
+  finally:
+    os.close(fd)
+
+
 def ReadOneLineFile(file_name, strict=False):
   """Return the first non-empty line from a file.
 
index 7fc93db..2c46afc 100755 (executable)
@@ -2374,6 +2374,21 @@ class TestFileID(testutils.GanetiTestCase):
     finally:
       os.close(fd)
 
+  def testWriteFile(self):
+    name = self._CreateTempFile()
+    oldi = utils.GetFileID(path=name)
+    mtime = oldi[2]
+    os.utime(name, (mtime + 10, mtime + 10))
+    self.assertRaises(errors.LockError, utils.SafeWriteFile, name,
+                      oldi, data="")
+    os.utime(name, (mtime - 10, mtime - 10))
+    utils.SafeWriteFile(name, oldi, data="")
+    oldi = utils.GetFileID(path=name)
+    mtime = oldi[2]
+    os.utime(name, (mtime + 10, mtime + 10))
+    # this doesn't raise, since we passed None
+    utils.SafeWriteFile(name, None, data="")
+
 
 if __name__ == '__main__':
   testutils.GanetiTestProgram()