ssconf: Verify file size when reading, add some tests
author Michael Hanselmann <hansmi@google.com>
Wed, 28 Nov 2012 08:23:28 +0000 (09:23 +0100)
committer Michael Hanselmann <hansmi@google.com>
Wed, 28 Nov 2012 11:48:36 +0000 (12:48 +0100)
Until now ssconf would limit the amount read from files to 128 KiB and
silently ignore any data beyond that, effectively truncating larger
files. With this patch a check is added by using fstat(2) on the file
descriptor while it's being read, and oversized files cause an error.

There were no tests for ssconf at all, so some are added.

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Guido Trotter <ultrotter@google.com>

Makefile.am
lib/ssconf.py
test/ganeti.ssconf_unittest.py [new file with mode: 0755]

index afc106c..67475f3 100644 (file)
@@ -958,6 +958,7 @@ python_tests = \
        test/ganeti.runtime_unittest.py \
        test/ganeti.serializer_unittest.py \
        test/ganeti.server.rapi_unittest.py \
+       test/ganeti.ssconf_unittest.py \
        test/ganeti.ssh_unittest.py \
        test/ganeti.storage_unittest.py \
        test/ganeti.tools.ensure_dirs_unittest.py \
index 30c9d78..ad5969c 100644 (file)
@@ -69,6 +69,28 @@ _VALID_KEYS = frozenset([
 _MAX_SIZE = 128 * 1024
 
 
+def ReadSsconfFile(filename):
+  """Reads an ssconf file and verifies its size.
+
+  @type filename: string
+  @param filename: Path to file
+  @rtype: string
+  @return: File contents with trailing newlines stripped
+  @raise RuntimeError: When the file size exceeds L{_MAX_SIZE}
+
+  """
+  statcb = utils.FileStatHelper()
+  # Read at most _MAX_SIZE bytes; the preread hook fstat()s the open file
+  data = utils.ReadFile(filename, size=_MAX_SIZE, preread=statcb)
+  # The bounded read truncates silently, so reject oversized files here
+  if statcb.st.st_size > _MAX_SIZE:
+    msg = ("File '%s' has a size of %s bytes (up to %s allowed)" %
+           (filename, statcb.st.st_size, _MAX_SIZE))
+    raise RuntimeError(msg)
+
+  return data.rstrip("\n")
+
+
 class SimpleStore(object):
   """Interface to static cluster data.
 
@@ -106,15 +128,13 @@ class SimpleStore(object):
     """
     filename = self.KeyToFilename(key)
     try:
-      data = utils.ReadFile(filename, size=_MAX_SIZE)
+      return ReadSsconfFile(filename)
     except EnvironmentError, err:
       if err.errno == errno.ENOENT and default is not None:
         return default
       raise errors.ConfigurationError("Can't read ssconf file %s: %s" %
                                       (filename, str(err)))
 
-    return data.rstrip("\n")
-
   def WriteFiles(self, values):
     """Writes ssconf files used by external scripts.
 
diff --git a/test/ganeti.ssconf_unittest.py b/test/ganeti.ssconf_unittest.py
new file mode 100755 (executable)
index 0000000..86d93be
--- /dev/null
@@ -0,0 +1,147 @@
+#!/usr/bin/python
+#
+
+# Copyright (C) 2012 Google Inc.
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+# General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+
+"""Script for testing ganeti.ssconf"""
+
+import os
+import unittest
+import tempfile
+import shutil
+import errno
+
+from ganeti import utils
+from ganeti import constants
+from ganeti import errors
+from ganeti import ssconf
+
+import testutils
+
+
+class TestReadSsconfFile(unittest.TestCase):
+  """Tests for L{ssconf.ReadSsconfFile}."""
+  def setUp(self):
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    shutil.rmtree(self.tmpdir)
+
+  def testReadDirectory(self):
+    self.assertRaises(EnvironmentError, ssconf.ReadSsconfFile, self.tmpdir)
+
+  def testNonExistentFile(self):
+    testfile = utils.PathJoin(self.tmpdir, "does.not.exist")
+
+    self.assertFalse(os.path.exists(testfile))
+
+    try:
+      ssconf.ReadSsconfFile(testfile)
+    except EnvironmentError, err:
+      self.assertEqual(err.errno, errno.ENOENT)
+    else:
+      self.fail("Exception was not raised")
+
+  def testEmptyFile(self):
+    testfile = utils.PathJoin(self.tmpdir, "empty")
+
+    utils.WriteFile(testfile, data="")
+
+    self.assertEqual(ssconf.ReadSsconfFile(testfile), "")
+
+  def testSingleLine(self):
+    testfile = utils.PathJoin(self.tmpdir, "data")
+
+    for nl in range(0, 10):
+      utils.WriteFile(testfile, data="Hello World" + ("\n" * nl))
+
+      self.assertEqual(ssconf.ReadSsconfFile(testfile),
+                       "Hello World")
+
+  def testExactlyMaxSize(self):
+    testfile = utils.PathJoin(self.tmpdir, "data")
+
+    data = "A" * ssconf._MAX_SIZE
+    utils.WriteFile(testfile, data=data)
+
+    self.assertEqual(os.path.getsize(testfile), ssconf._MAX_SIZE)
+
+    self.assertEqual(ssconf.ReadSsconfFile(testfile),
+                     data)
+
+  def testLargeFile(self):
+    testfile = utils.PathJoin(self.tmpdir, "data")
+
+    for size in [ssconf._MAX_SIZE + 1, ssconf._MAX_SIZE * 2]:
+      utils.WriteFile(testfile, data="A" * size)
+      self.assertTrue(os.path.getsize(testfile) > ssconf._MAX_SIZE)
+      self.assertRaises(RuntimeError, ssconf.ReadSsconfFile, testfile)
+
+
+class TestSimpleStore(unittest.TestCase):
+  """Tests for L{ssconf.SimpleStore} using a temporary config location."""
+  def setUp(self):
+    self.tmpdir = tempfile.mkdtemp()
+    self.sstore = ssconf.SimpleStore(cfg_location=self.tmpdir)
+
+  def tearDown(self):
+    shutil.rmtree(self.tmpdir)
+
+  def testInvalidKey(self):
+    self.assertRaises(errors.ProgrammerError, self.sstore.KeyToFilename,
+                      "not a valid key")
+    self.assertRaises(errors.ProgrammerError, self.sstore._ReadFile,
+                      "not a valid key")
+
+  def testKeyToFilename(self):
+    for key in ssconf._VALID_KEYS:
+      result = self.sstore.KeyToFilename(key)
+      self.assertTrue(utils.IsBelowDir(self.tmpdir, result))
+      self.assertTrue(os.path.basename(result).startswith("ssconf_"))
+
+  def testReadFileNonExistingFile(self):
+    filename = self.sstore.KeyToFilename(constants.SS_CLUSTER_NAME)
+
+    self.assertFalse(os.path.exists(filename))
+    try:
+      self.sstore._ReadFile(constants.SS_CLUSTER_NAME)
+    except errors.ConfigurationError, err:
+      self.assertTrue(str(err).startswith("Can't read ssconf file"))
+    else:
+      self.fail("Exception was not raised")
+
+    for default in ["", "Hello World", 0, 100]:
+      self.assertFalse(os.path.exists(filename))
+      result = self.sstore._ReadFile(constants.SS_CLUSTER_NAME, default=default)
+      self.assertEqual(result, default)
+
+  def testReadFile(self):
+    utils.WriteFile(self.sstore.KeyToFilename(constants.SS_CLUSTER_NAME),
+                    data="cluster.example.com")
+
+    self.assertEqual(self.sstore._ReadFile(constants.SS_CLUSTER_NAME),
+                     "cluster.example.com")
+
+    self.assertEqual(self.sstore._ReadFile(constants.SS_CLUSTER_NAME,
+                                           default="something.example.com"),
+                     "cluster.example.com")
+
+
+if __name__ == "__main__":
+  testutils.GanetiTestProgram()