Refactor QA configuration code
authorMichael Hanselmann <hansmi@google.com>
Mon, 4 Feb 2013 15:06:23 +0000 (16:06 +0100)
committerMichael Hanselmann <hansmi@google.com>
Fri, 8 Feb 2013 13:35:25 +0000 (14:35 +0100)
Ever since its introduction (sometime before commit cec9845 in September
2007), the QA configuration was stored in a dictionary at module-level
in “qa/qa_config.py”. The configuration was loaded, verified and
evaluated using module-level functions. Since then the configuration has
become more complicated and more functionality has been added. This
patch refactors handling the configuration to use a class and provides
unittests.

- The configuration is loaded through a class method which also verifies
  it for consistency
- Wrapper methods are provided in “qa_config” to not change the
  interface
- Unit tests are provided for the new configuration class
- The configuration object is still stored in a module-level variable
  and can be retrieved using “GetConfig” (direct access should be
  avoided so an uninitialized configuration can be detected)

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Helga Velroyen <helgav@google.com>

qa/qa_config.py
test/py/qa.qa_config_unittest.py

index 4d13a84..9b50efb 100644 (file)
@@ -40,49 +40,150 @@ _ENABLED_HV_KEY = "enabled-hypervisors"
 _exclusive_storage = None
 
 
-cfg = {}
+#: QA configuration (L{_QaConfig})
+_config = None
 
 
-def Load(path):
-  """Loads the passed configuration file.
+class _QaConfig(object):
+  def __init__(self, data):
+    """Initializes instances of this class.
 
-  """
-  global cfg # pylint: disable=W0603
+    """
+    self._data = data
+
+  @classmethod
+  def Load(cls, filename):
+    """Loads a configuration file and produces a configuration object.
+
+    @type filename: string
+    @param filename: Path to configuration file
+    @rtype: L{_QaConfig}
+
+    """
+    data = serializer.LoadJson(utils.ReadFile(filename))
+
+    result = cls(data)
+    result.Validate()
+
+    return result
+
+  def Validate(self):
+    """Validates loaded configuration data.
+
+    """
+    if not self.get("nodes"):
+      raise qa_error.Error("Need at least one node")
+
+    if not self.get("instances"):
+      raise qa_error.Error("Need at least one instance")
+
+    if (self.get("disk") is None or
+        self.get("disk-growth") is None or
+        len(self.get("disk")) != len(self.get("disk-growth"))):
+      raise qa_error.Error("Config options 'disk' and 'disk-growth' must exist"
+                           " and have the same number of items")
+
+    check = self.GetInstanceCheckScript()
+    if check:
+      try:
+        os.stat(check)
+      except EnvironmentError, err:
+        raise qa_error.Error("Can't find instance check script '%s': %s" %
+                             (check, err))
+
+    enabled_hv = frozenset(self.GetEnabledHypervisors())
+    if not enabled_hv:
+      raise qa_error.Error("No hypervisor is enabled")
+
+    difference = enabled_hv - constants.HYPER_TYPES
+    if difference:
+      raise qa_error.Error("Unknown hypervisor(s) enabled: %s" %
+                           utils.CommaJoin(difference))
+
+  def __getitem__(self, name):
+    """Returns configuration value.
+
+    @type name: string
+    @param name: Name of configuration entry
+
+    """
+    return self._data[name]
+
+  def get(self, name, default=None):
+    """Returns configuration value.
+
+    @type name: string
+    @param name: Name of configuration entry
+    @param default: Default value
+
+    """
+    return self._data.get(name, default)
 
-  cfg = serializer.LoadJson(utils.ReadFile(path))
+  def GetMasterNode(self):
+    """Returns the default master node for the cluster.
 
-  Validate()
+    """
+    return self["nodes"][0]
+
+  def GetInstanceCheckScript(self):
+    """Returns path to instance check script or C{None}.
+
+    """
+    return self._data.get(_INSTANCE_CHECK_KEY, None)
 
+  def GetEnabledHypervisors(self):
+    """Returns list of enabled hypervisors.
 
-def Validate():
-  if len(cfg["nodes"]) < 1:
-    raise qa_error.Error("Need at least one node")
-  if len(cfg["instances"]) < 1:
-    raise qa_error.Error("Need at least one instance")
-  if len(cfg["disk"]) != len(cfg["disk-growth"]):
-    raise qa_error.Error("Config options 'disk' and 'disk-growth' must have"
-                         " the same number of items")
+    @rtype: list
 
-  check = GetInstanceCheckScript()
-  if check:
+    """
     try:
-      os.stat(check)
-    except EnvironmentError, err:
-      raise qa_error.Error("Can't find instance check script '%s': %s" %
-                           (check, err))
+      value = self._data[_ENABLED_HV_KEY]
+    except KeyError:
+      return [constants.DEFAULT_ENABLED_HYPERVISOR]
+    else:
+      if value is None:
+        return []
+      elif isinstance(value, basestring):
+        # The configuration key ("enabled-hypervisors") implies there can be
+        # multiple values. Multiple hypervisors are comma-separated on the
+        # command line option to "gnt-cluster init", so we need to handle them
+        # equally here.
+        return value.split(",")
+      else:
+        return value
+
+  def GetDefaultHypervisor(self):
+    """Returns the default hypervisor to be used.
+
+    """
+    return self.GetEnabledHypervisors()[0]
+
+
+def Load(path):
+  """Loads the passed configuration file.
+
+  """
+  global _config # pylint: disable=W0603
 
-  enabled_hv = frozenset(GetEnabledHypervisors())
-  if not enabled_hv:
-    raise qa_error.Error("No hypervisor is enabled")
+  _config = _QaConfig.Load(path)
 
-  difference = enabled_hv - constants.HYPER_TYPES
-  if difference:
-    raise qa_error.Error("Unknown hypervisor(s) enabled: %s" %
-                         utils.CommaJoin(difference))
+
+def GetConfig():
+  """Returns the configuration object.
+
+  """
+  if _config is None:
+    raise RuntimeError("Configuration not yet loaded")
+
+  return _config
 
 
 def get(name, default=None):
-  return cfg.get(name, default)
+  """Wrapper for L{_QaConfig.get}.
+
+  """
+  return GetConfig().get(name, default=default)
 
 
 class Either:
@@ -148,10 +249,12 @@ def TestEnabled(tests, _cfg=None):
 
   """
   if _cfg is None:
-    _cfg = cfg
+    cfg = GetConfig()
+  else:
+    cfg = _cfg
 
   # Get settings for all tests
-  cfg_tests = _cfg.get("tests", {})
+  cfg_tests = cfg.get("tests", {})
 
   # Get default setting
   default = cfg_tests.get("default", True)
@@ -160,39 +263,25 @@ def TestEnabled(tests, _cfg=None):
                            tests, compat.all)
 
 
-def GetInstanceCheckScript():
-  """Returns path to instance check script or C{None}.
+def GetInstanceCheckScript(*args):
+  """Wrapper for L{_QaConfig.GetInstanceCheckScript}.
 
   """
-  return cfg.get(_INSTANCE_CHECK_KEY, None)
+  return GetConfig().GetInstanceCheckScript(*args)
 
 
-def GetEnabledHypervisors():
-  """Returns list of enabled hypervisors.
-
-  @rtype: list
+def GetEnabledHypervisors(*args):
+  """Wrapper for L{_QaConfig.GetEnabledHypervisors}.
 
   """
-  try:
-    value = cfg[_ENABLED_HV_KEY]
-  except KeyError:
-    return [constants.DEFAULT_ENABLED_HYPERVISOR]
-  else:
-    if isinstance(value, basestring):
-      # The configuration key ("enabled-hypervisors") implies there can be
-      # multiple values. Multiple hypervisors are comma-separated on the
-      # command line option to "gnt-cluster init", so we need to handle them
-      # equally here.
-      return value.split(",")
-    else:
-      return value
+  return GetConfig().GetEnabledHypervisors(*args)
 
 
-def GetDefaultHypervisor():
-  """Returns the default hypervisor to be used.
+def GetDefaultHypervisor(*args):
+  """Wrapper for L{_QaConfig.GetDefaultHypervisor}.
 
   """
-  return GetEnabledHypervisors()[0]
+  return GetConfig().GetDefaultHypervisor(*args)
 
 
 def GetInstanceNicMac(inst, default=None):
@@ -203,7 +292,10 @@ def GetInstanceNicMac(inst, default=None):
 
 
 def GetMasterNode():
-  return cfg["nodes"][0]
+  """Wrapper for L{_QaConfig.GetMasterNode}.
+
+  """
+  return GetConfig().GetMasterNode()
 
 
 def AcquireInstance():
@@ -212,7 +304,7 @@ def AcquireInstance():
   """
   # Filter out unwanted instances
   tmp_flt = lambda inst: not inst.get("_used", False)
-  instances = filter(tmp_flt, cfg["instances"])
+  instances = filter(tmp_flt, GetConfig()["instances"])
   del tmp_flt
 
   if len(instances) == 0:
@@ -263,7 +355,7 @@ def GetExclusiveStorage():
 
 
 def IsTemplateSupported(templ):
-  """Is the given templated supported by the current configuration?
+  """Is the given disk template supported by the current configuration?
 
   """
   if GetExclusiveStorage():
@@ -277,6 +369,7 @@ def AcquireNode(exclude=None):
 
   """
   master = GetMasterNode()
+  cfg = GetConfig()
 
   # Filter out unwanted nodes
   # TODO: Maybe combine filters
index fd73322..2755719 100755 (executable)
@@ -1,7 +1,7 @@
 #!/usr/bin/python
 #
 
-# Copyright (C) 2012 Google Inc.
+# Copyright (C) 2012, 2013 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 """Script for testing qa.qa_config"""
 
 import unittest
+import tempfile
+import shutil
+import os
+
+from ganeti import utils
+from ganeti import serializer
+from ganeti import constants
+from ganeti import compat
 
 from qa import qa_config
+from qa import qa_error
 
 import testutils
 
@@ -133,5 +142,115 @@ class TestTestEnabled(unittest.TestCase):
         }))
 
 
+class TestQaConfigLoad(unittest.TestCase):
+  def setUp(self):
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    shutil.rmtree(self.tmpdir)
+
+  def testLoadNonExistent(self):
+    filename = utils.PathJoin(self.tmpdir, "does.not.exist")
+    self.assertRaises(EnvironmentError, qa_config._QaConfig.Load, filename)
+
+  @staticmethod
+  def _WriteConfig(filename, data):
+    utils.WriteFile(filename, data=serializer.DumpJson(data))
+
+  def _CheckLoadError(self, filename, data, expected):
+    self._WriteConfig(filename, data)
+
+    try:
+      qa_config._QaConfig.Load(filename)
+    except qa_error.Error, err:
+      self.assertTrue(str(err).startswith(expected))
+    else:
+      self.fail("Exception was not raised")
+
+  def testFailsValidation(self):
+    filename = utils.PathJoin(self.tmpdir, "qa.json")
+    testconfig = {}
+
+    check_fn = compat.partial(self._CheckLoadError, filename, testconfig)
+
+    # No nodes
+    check_fn("Need at least one node")
+
+    testconfig["nodes"] = [
+      {
+        "primary": "xen-test-0",
+        "secondary": "192.0.2.1",
+        },
+      ]
+
+    # No instances
+    check_fn("Need at least one instance")
+
+    testconfig["instances"] = [
+      {
+        "name": "xen-test-inst1",
+        },
+      ]
+
+    # Missing "disk" and "disk-growth"
+    check_fn("Config options 'disk' and 'disk-growth' ")
+
+    testconfig["disk"] = []
+    testconfig["disk-growth"] = testconfig["disk"]
+
+    # Minimal accepted configuration
+    self._WriteConfig(filename, testconfig)
+    result = qa_config._QaConfig.Load(filename)
+    self.assertTrue(result.get("nodes"))
+
+    # Non-existent instance check script
+    testconfig[qa_config._INSTANCE_CHECK_KEY] = \
+      utils.PathJoin(self.tmpdir, "instcheck")
+    check_fn("Can't find instance check script")
+    del testconfig[qa_config._INSTANCE_CHECK_KEY]
+
+    # No enabled hypervisor
+    testconfig[qa_config._ENABLED_HV_KEY] = None
+    check_fn("No hypervisor is enabled")
+
+    # Unknown hypervisor
+    testconfig[qa_config._ENABLED_HV_KEY] = ["#unknownhv#"]
+    check_fn("Unknown hypervisor(s) enabled:")
+
+
+class TestQaConfigWithSampleConfig(unittest.TestCase):
+  """Tests using C{qa-sample.json}.
+
+  This test case serves two purposes:
+
+    - Ensure shipped C{qa-sample.json} file is considered a valid QA
+      configuration
+    - Test some functions of L{qa_config._QaConfig} without having to
+      mock a whole configuration file
+
+  """
+  def setUp(self):
+    filename = "%s/qa/qa-sample.json" % testutils.GetSourceDir()
+
+    self.config = qa_config._QaConfig.Load(filename)
+
+  def testGetEnabledHypervisors(self):
+    self.assertEqual(self.config.GetEnabledHypervisors(),
+                     [constants.DEFAULT_ENABLED_HYPERVISOR])
+
+  def testGetDefaultHypervisor(self):
+    self.assertEqual(self.config.GetDefaultHypervisor(),
+                     constants.DEFAULT_ENABLED_HYPERVISOR)
+
+  def testGetInstanceCheckScript(self):
+    self.assertTrue(self.config.GetInstanceCheckScript() is None)
+
+  def testGetAndGetItem(self):
+    self.assertEqual(self.config["nodes"], self.config.get("nodes"))
+
+  def testGetMasterNode(self):
+    self.assertEqual(self.config.GetMasterNode(), self.config["nodes"][0])
+
+
 if __name__ == "__main__":
   testutils.GanetiTestProgram()