Handle empty patches better
[ganeti-local] / qa / qa_config.py
index e8f2789..4bb3dce 100644 (file)
@@ -32,6 +32,7 @@ from ganeti import compat
 from ganeti import ht
 
 import qa_error
+import qa_logging
 
 
 _INSTANCE_CHECK_KEY = "instance-check"
@@ -40,6 +41,13 @@ _VCLUSTER_MASTER_KEY = "vcluster-master"
 _VCLUSTER_BASEDIR_KEY = "vcluster-basedir"
 _ENABLED_DISK_TEMPLATES_KEY = "enabled-disk-templates"
 
+# The constants related to JSON patching (as per RFC6902) that modifies QA's
+# configuration.
+_QA_BASE_PATH = os.path.dirname(__file__)
+_QA_DEFAULT_PATCH = "qa-patch.json"
+_QA_PATCH_DIR = "patch"
+_QA_PATCH_ORDER_FILE = "order"
+
 #: QA configuration (L{_QaConfig})
 _config = None
 
@@ -250,6 +258,114 @@ class _QaConfig(object):
     #: Cluster-wide run-time value of the exclusive storage flag
     self._exclusive_storage = None
 
+  @staticmethod
+  def LoadPatch(patch_dict, rel_path):
+    """Reads one patch file and registers its content, if it can be read.
+
+    The file is looked up relative to the QA base directory. If reading it
+    raises an C{IOError}, the error is ignored and C{patch_dict} is left
+    untouched.
+
+    @type patch_dict: dict of string to dict
+    @param patch_dict: Mapping of relative patch path to patch content; the
+                       loaded patch is stored in it under C{rel_path}
+    @type rel_path: string
+    @param rel_path: Path of the patch relative to the QA directory; the file
+                     might or might not exist
+
+    """
+    abs_path = os.path.join(_QA_BASE_PATH, rel_path)
+    try:
+      content = serializer.LoadJson(utils.ReadFile(abs_path))
+    except IOError:
+      # The patch might not exist, which is not treated as an error
+      return
+    patch_dict[rel_path] = content
+
+  @staticmethod
+  def LoadPatches():
+    """Collects every patch the QA knows how to find.
+
+    The default patch in the QA base directory is tried first, followed by
+    all files ending in C{.json} inside the patch directory, if that
+    directory exists.
+
+    @rtype: dict of string to dict
+    @return: A dictionary of relative path to patch content.
+
+    """
+    loaded = {}
+    _QaConfig.LoadPatch(loaded, _QA_DEFAULT_PATCH)
+
+    patch_dir = os.path.join(_QA_BASE_PATH, _QA_PATCH_DIR)
+    if not os.path.exists(patch_dir):
+      return loaded
+
+    for name in os.listdir(patch_dir):
+      if not name.endswith(".json"):
+        continue
+      # Keys are kept relative to the QA directory
+      _QaConfig.LoadPatch(loaded, os.path.join(_QA_PATCH_DIR, name))
+
+    return loaded
+
+  @staticmethod
+  def ApplyPatch(data, patch_module, patches, patch_path):
+    """Applies a single patch and returns the patched configuration.
+
+    An informational message is printed for every patch. A warning is
+    printed as well if the patch is empty and is not the default patch.
+
+    @type data: dict (deserialized json)
+    @param data: The QA configuration; it is patched in place
+    @type patch_module: module
+    @param patch_module: The json patch module, loaded dynamically
+    @type patches: dict of string to dict
+    @param patches: The dictionary of patch path to content
+    @type patch_path: string
+    @param patch_path: The path to the patch, relative to the QA directory
+
+    @rtype: dict (deserialized json)
+    @return: The modified configuration data.
+    @raise KeyError: If C{patch_path} is not a key of C{patches}
+
+    """
+    patch_content = patches[patch_path]
+    print qa_logging.FormatInfo("Applying patch %s" % patch_path)
+    if not patch_content and patch_path != _QA_DEFAULT_PATCH:
+      print qa_logging.FormatWarning("The patch %s added by the user is empty" %
+                                     patch_path)
+    # Without in_place=True the patch is applied to a copy, and the result
+    # used to be dropped here. Patch in place so that a caller ignoring the
+    # return value still sees the change, and return the result as documented.
+    return patch_module.apply_patch(data, patch_content, in_place=True)
+
+  @staticmethod
+  def ApplyPatches(data, patch_module, patches):
+    """Applies any patches present, and returns the modified QA configuration.
+
+    First, patches from the patch directory are applied. They are ordered
+    alphabetically, unless there is an ``order`` file present - any patches
+    listed within are applied in that order, and any remaining ones in
+    alphabetical order again. Finally, the default patch residing in the
+    top-level QA directory is applied.
+
+    The ordering file is checked completely before the first patch is
+    applied, so a wrong entry does not leave the configuration half patched.
+
+    @type data: dict (deserialized json)
+    @param data: The QA configuration
+    @type patch_module: module
+    @param patch_module: The json patch module, loaded dynamically
+    @type patches: dict of string to dict
+    @param patches: The dictionary of patch path to content
+
+    @return: The modified configuration data.
+    @raise qa_error.Error: If the ordering file lists a patch that is not
+                           present in C{patches}
+
+    """
+    order_path = os.path.join(_QA_BASE_PATH, _QA_PATCH_DIR,
+                              _QA_PATCH_ORDER_FILE)
+    ordered_patches = []
+    if os.path.exists(order_path):
+      # Read through utils.ReadFile, as done for the other files here, instead
+      # of leaving an open file object behind. Empty lines are dropped and the
+      # patch directory is prepended so the names match the keys of "patches".
+      ordered_patches = [os.path.join(_QA_PATCH_DIR, line)
+                         for line in utils.ReadFile(order_path).splitlines()
+                         if line]
+
+    for patch in ordered_patches:
+      if patch not in patches:
+        raise qa_error.Error("Patch %s specified in the ordering file does not "
+                             "exist" % patch)
+
+    # First the ordered patches
+    sequence = list(ordered_patches)
+
+    # Then the other non-default ones
+    listed = frozenset(ordered_patches)
+    sequence.extend(patch for patch in sorted(patches)
+                    if patch != _QA_DEFAULT_PATCH and patch not in listed)
+
+    # Finally the default one
+    if _QA_DEFAULT_PATCH in patches:
+      sequence.append(_QA_DEFAULT_PATCH)
+
+    for patch in sequence:
+      patched = _QaConfig.ApplyPatch(data, patch_module, patches, patch)
+      # Carry the patched document forward when one is handed back; otherwise
+      # keep working on the object we already have
+      if patched is not None:
+        data = patched
+
+    return data
+
   @classmethod
   def Load(cls, filename):
     """Loads a configuration file and produces a configuration object.
@@ -261,6 +377,21 @@ class _QaConfig(object):
     """
     data = serializer.LoadJson(utils.ReadFile(filename))
 
+    # Patch the document using the JSON Patch (RFC6902) files found via
+    # _QA_DEFAULT_PATCH and _QA_PATCH_DIR, if any are available
+    try:
+      patches = _QaConfig.LoadPatches()
+      # Try to use the module only if there is a non-empty patch present
+      if any(patches.values()):
+        mod = __import__("jsonpatch", fromlist=[])
+        data = _QaConfig.ApplyPatches(data, mod, patches)
+    except IOError:
+      pass
+    except ImportError:
+      raise qa_error.Error("For the QA JSON patching feature to work, you "
+                           "need to install Python modules 'jsonpatch' and "
+                           "'jsonpointer'.")
+
     result = cls(dict(map(_ConvertResources,
                           data.items()))) # pylint: disable=E1103
     result.Validate()
@@ -280,12 +411,14 @@ class _QaConfig(object):
     if not self.get("instances"):
       raise qa_error.Error("Need at least one instance")
 
-    if (self.get("disk") is None or
-        self.get("disk-growth") is None or
-        len(self.get("disk")) != len(self.get("disk-growth"))):
-      raise qa_error.Error("Config options 'disk' and 'disk-growth' must exist"
-                           " and have the same number of items")
-
+    disks = self.GetDiskOptions()
+    if disks is None:
+      raise qa_error.Error("Config option 'disks' must exist")
+    else:
+      for d in disks:
+        if d.get("size") is None or d.get("growth") is None:
+          raise qa_error.Error("Config options `size` and `growth` must exist"
+                               " for all `disks` items")
     check = self.GetInstanceCheckScript()
     if check:
       try:
@@ -412,8 +545,9 @@ class _QaConfig(object):
     """Is the given disk template supported by the current configuration?
 
     """
-    return (not self.GetExclusiveStorage() or
-            templ in constants.DTS_EXCL_STORAGE)
+    enabled = templ in self.GetEnabledDiskTemplates()
+    return enabled and (not self.GetExclusiveStorage() or
+                        templ in constants.DTS_EXCL_STORAGE)
 
   def GetVclusterSettings(self):
     """Returns settings for virtual cluster.
@@ -424,6 +558,32 @@ class _QaConfig(object):
 
     return (master, basedir)
 
+  def GetDiskOptions(self):
+    """Returns the disk options configured for the instances.
+
+    The 'disks' parameter is returned as-is when the configuration data has
+    one. Otherwise the list is built from the legacy 'disk' and 'disk-growth'
+    parameters, pairing each size with the growth at the same position.
+
+    @return: The 'disks' value, or a list of dicts with the keys "size" and
+             "growth" built from the legacy parameters, or C{None} if neither
+             legacy parameter has a non-empty value
+    @raise qa_error.Error: If only one of the legacy parameters is given or
+                           they differ in length
+
+    """
+    try:
+      return self._data["disks"]
+    except KeyError:
+      pass
+
+    # Legacy interface
+    legacy_sizes = self._data.get("disk")
+    legacy_growths = self._data.get("disk-growth")
+    if not (legacy_sizes or legacy_growths):
+      return None
+
+    mismatch = (legacy_sizes is None or legacy_growths is None or
+                len(legacy_sizes) != len(legacy_growths))
+    if mismatch:
+      raise qa_error.Error("Config options 'disk' and 'disk-growth' must"
+                           " exist and have the same number of items")
+
+    return [{"size": size, "growth": growth}
+            for (size, growth) in zip(legacy_sizes, legacy_growths)]
+
 
 def Load(path):
   """Loads the passed configuration file.
@@ -608,7 +768,7 @@ def GetExclusiveStorage():
 
 
 def IsTemplateSupported(templ):
-  """Wrapper for L{_QaConfig.GetExclusiveStorage}.
+  """Wrapper for L{_QaConfig.IsTemplateSupported}.
 
+  Forwards C{templ} to the configuration object returned by L{GetConfig} and
+  returns its answer.
+
   """
   return GetConfig().IsTemplateSupported(templ)
@@ -717,3 +877,10 @@ def NoVirtualCluster():
 
   """
   return not UseVirtualCluster()
+
+
+def GetDiskOptions():
+  """Wrapper for L{_QaConfig.GetDiskOptions}.
+
+  Calls the method on the configuration object returned by L{GetConfig} and
+  passes its result through.
+
+  """
+  config = GetConfig()
+  return config.GetDiskOptions()