Locking related fixes for networks
[ganeti-local] / test / testutils.py
index b4d286e..3fcfbc4 100644 (file)
 """Utilities for unit testing"""
 
 import os
+import sys
 import stat
 import tempfile
 import unittest
+import logging
+import types
 
 from ganeti import utils
 
 
+def GetSourceDir():
+  return os.environ.get("TOP_SRCDIR", ".")
+
+
+def _SetupLogging(verbose):
+  """Set up logging infrastructure.
+
+  """
+  fmt = logging.Formatter("%(asctime)s: %(threadName)s"
+                          " %(levelname)s %(message)s")
+
+  if verbose:
+    handler = logging.StreamHandler()
+  else:
+    handler = logging.FileHandler(os.devnull, "a")
+
+  handler.setLevel(logging.NOTSET)
+  handler.setFormatter(fmt)
+
+  root_logger = logging.getLogger("")
+  root_logger.setLevel(logging.NOTSET)
+  root_logger.addHandler(handler)
+
+
+class GanetiTestProgram(unittest.TestProgram):
+  def runTests(self):
+    """Runs all tests.
+
+    """
+    _SetupLogging("LOGTOSTDERR" in os.environ)
+
+    sys.stderr.write("Running %s\n" % self.progName)
+    sys.stderr.flush()
+
+    # Ensure assertions will be evaluated
+    if not __debug__:
+      raise Exception("Not running in debug mode, assertions would not be"
+                      " evaluated")
+
+    # Check again, this time with a real assertion
+    try:
+      assert False
+    except AssertionError:
+      pass
+    else:
+      raise Exception("Assertion not evaluated")
+
+    # The following piece of code is a backport from Python 2.6. Python 2.4/2.5
+    # only accept class instances as test runners. Being able to pass classes
+    # reduces the amount of code necessary for using a custom test runner.
+    # 2.6 and above should use their own code, however.
+    if (self.testRunner and sys.hexversion < 0x2060000 and
+        isinstance(self.testRunner, (type, types.ClassType))):
+      try:
+        self.testRunner = self.testRunner(verbosity=self.verbosity)
+      except TypeError:
+        # didn't accept the verbosity argument
+        self.testRunner = self.testRunner()
+
+    return unittest.TestProgram.runTests(self)
+
+
 class GanetiTestCase(unittest.TestCase):
   """Helper class for unittesting.
 
@@ -71,6 +136,42 @@ class GanetiTestCase(unittest.TestCase):
     actual_mode = stat.S_IMODE(st.st_mode)
     self.assertEqual(actual_mode, expected_mode)
 
+  def assertFileUid(self, file_name, expected_uid):
+    """Checks that the user id of a file is what we expect.
+
+    @type file_name: str
+    @param file_name: the file whose owner we should check
+    @type expected_uid: int
+    @param expected_uid: the user id we expect
+
+    """
+    st = os.stat(file_name)
+    actual_uid = st.st_uid
+    self.assertEqual(actual_uid, expected_uid)
+
+  def assertFileGid(self, file_name, expected_gid):
+    """Checks that the group id of a file is what we expect.
+
+    @type file_name: str
+    @param file_name: the file whose group we should check
+    @type expected_gid: int
+    @param expected_gid: the group id we expect
+
+    """
+    st = os.stat(file_name)
+    actual_gid = st.st_gid
+    self.assertEqual(actual_gid, expected_gid)
+
+  def assertEqualValues(self, first, second, msg=None):
+    """Checks whether two values are equal.
+
+    Tuples are automatically converted to lists before comparing.
+
+    """
+    return self.assertEqual(UnifyValueType(first),
+                            UnifyValueType(second),
+                            msg=msg)
+
   @staticmethod
   def _TestDataFilename(name):
     """Returns the filename of a given test data file.
@@ -83,8 +184,7 @@ class GanetiTestCase(unittest.TestCase):
         be used in 'make distcheck' rules
 
     """
-    prefix = os.environ.get("TOP_SRCDIR", ".")
-    return "%s/test/data/%s" % (prefix, name)
+    return "%s/test/data/%s" % (GetSourceDir(), name)
 
   @classmethod
   def _ReadTestData(cls, name):
@@ -94,7 +194,6 @@ class GanetiTestCase(unittest.TestCase):
     proper test file name.
 
     """
-
     return utils.ReadFile(cls._TestDataFilename(name))
 
   def _CreateTempFile(self):
@@ -108,3 +207,48 @@ class GanetiTestCase(unittest.TestCase):
     os.close(fh)
     self._temp_files.append(fname)
     return fname
+
+
+def UnifyValueType(data):
+  """Converts all tuples into lists.
+
+  This is useful for unittests where an external library doesn't keep types.
+
+  """
+  if isinstance(data, (tuple, list)):
+    return [UnifyValueType(i) for i in data]
+
+  elif isinstance(data, dict):
+    return dict([(UnifyValueType(key), UnifyValueType(value))
+                 for (key, value) in data.iteritems()])
+
+  return data
+
+
+class CallCounter(object):
+  """Utility class to count number of calls to a function/method.
+
+  """
+  def __init__(self, fn):
+    """Initializes this class.
+
+    @type fn: Callable
+
+    """
+    self._fn = fn
+    self._count = 0
+
+  def __call__(self, *args, **kwargs):
+    """Calls wrapped function with given parameters.
+
+    """
+    self._count += 1
+    return self._fn(*args, **kwargs)
+
+  def Count(self):
+    """Returns number of calls.
+
+    @rtype: number
+
+    """
+    return self._count