Handle empty patches better
[ganeti-local] / qa / qa_config.py
index 92f9584..4bb3dce 100644 (file)
@@ -32,6 +32,7 @@ from ganeti import compat
 from ganeti import ht
 
 import qa_error
+import qa_logging
 
 
 _INSTANCE_CHECK_KEY = "instance-check"
@@ -40,9 +41,12 @@ _VCLUSTER_MASTER_KEY = "vcluster-master"
 _VCLUSTER_BASEDIR_KEY = "vcluster-basedir"
 _ENABLED_DISK_TEMPLATES_KEY = "enabled-disk-templates"
 
-# The path of an optional JSON Patch file (as per RFC6902) that modifies QA's
+# Constants related to JSON Patch files (as per RFC6902) that modify QA's
 # configuration.
-_PATCH_JSON = os.path.join(os.path.dirname(__file__), "qa-patch.json")
+_QA_BASE_PATH = os.path.dirname(__file__)  # directory containing this module
+_QA_DEFAULT_PATCH = "qa-patch.json"  # relative to _QA_BASE_PATH
+_QA_PATCH_DIR = "patch"  # under _QA_BASE_PATH; holds *.json patches
+_QA_PATCH_ORDER_FILE = "order"  # optional, inside _QA_PATCH_DIR
 
 #: QA configuration (L{_QaConfig})
 _config = None
@@ -254,6 +258,114 @@ class _QaConfig(object):
     #: Cluster-wide run-time value of the exclusive storage flag
     self._exclusive_storage = None
 
+  @staticmethod
+  def LoadPatch(patch_dict, rel_path):
+    """Reads the JSON patch at C{rel_path} into C{patch_dict}, if present.
+
+    @type patch_dict: dict of string to dict
+    @param patch_dict: Mapping of relative patch path to patch content
+    @type rel_path: string
+    @param rel_path: Patch path relative to the QA directory; may be missing
+
+    """
+    full_path = os.path.join(_QA_BASE_PATH, rel_path)
+    try:
+      contents = utils.ReadFile(full_path)
+    except IOError:  # a missing or unreadable patch file is skipped
+      return
+    patch_dict[rel_path] = serializer.LoadJson(contents)
+
+  @staticmethod
+  def LoadPatches():
+    """Finds and loads the default patch and all patch-dir JSON files.
+
+    @rtype: dict of string to dict
+    @return: A dictionary of relative path to patch content.
+
+    """
+    patches = {}
+    _QaConfig.LoadPatch(patches, _QA_DEFAULT_PATCH)
+    patch_dir_path = os.path.join(_QA_BASE_PATH, _QA_PATCH_DIR)
+    if os.path.isdir(patch_dir_path):  # listdir fails on a plain file
+      for filename in os.listdir(patch_dir_path):
+        if filename.endswith(".json"):
+          _QaConfig.LoadPatch(patches, os.path.join(_QA_PATCH_DIR, filename))
+    return patches
+
+  @staticmethod
+  def ApplyPatch(data, patch_module, patches, patch_path):
+    """Applies a single patch and returns the patched configuration.
+
+    @type data: dict (deserialized json)
+    @param data: The QA configuration
+    @type patch_module: module
+    @param patch_module: The json patch module, loaded dynamically
+    @type patches: dict of string to dict
+    @param patches: The dictionary of patch path to content
+    @type patch_path: string
+    @param patch_path: The path to the patch, relative to the QA directory
+    @return: The modified configuration data, which may be a new object
+
+    """
+    patch_content = patches[patch_path]
+    print qa_logging.FormatInfo("Applying patch %s" % patch_path)
+    if not patch_content and patch_path != _QA_DEFAULT_PATCH:
+      print qa_logging.FormatWarning("The patch %s added by the user is empty" %
+                                     patch_path)
+    # apply_patch may return a new object instead of modifying data in place
+    return patch_module.apply_patch(data, patch_content)
+
+  @staticmethod
+  def ApplyPatches(data, patch_module, patches):
+    """Applies any patches present, and returns the modified QA configuration.
+
+    First, patches from the patch directory are applied. They are ordered
+    alphabetically, unless there is an ``order`` file present - any patches
+    listed within are applied in that order, and any remaining ones in
+    alphabetical order again. Finally, the default patch residing in the
+    top-level QA directory is applied. Names in the ``order`` file are
+    relative to the patch directory, one per line; blank lines are ignored.
+
+    @type data: dict (deserialized json)
+    @param data: The QA configuration
+    @type patch_module: module
+    @param patch_module: The json patch module, loaded dynamically
+    @type patches: dict of string to dict
+    @param patches: The dictionary of patch path to content
+
+    @return: The modified configuration data.
+    @raise qa_error.Error: If the ordering file lists an unknown patch
+
+    """
+    ordered_patches = []
+    order_path = os.path.join(_QA_BASE_PATH, _QA_PATCH_DIR,
+                              _QA_PATCH_ORDER_FILE)
+    if os.path.exists(order_path):
+      lines = utils.ReadFile(order_path).splitlines()
+      # Ignore blank lines and stray whitespace around patch names
+      ordered_patches = [os.path.join(_QA_PATCH_DIR, line.strip())
+                         for line in lines if line.strip()]
+
+    # First the ordered patches; every application returns the (possibly
+    # new) configuration, which has to be threaded into the next one
+    for patch in ordered_patches:
+      if patch not in patches:
+        raise qa_error.Error("Patch %s specified in the ordering file does not "
+                             "exist" % patch)
+      data = _QaConfig.ApplyPatch(data, patch_module, patches, patch)
+
+    # Then the other non-default ones, alphabetically
+    for patch in sorted(patches):
+      if patch != _QA_DEFAULT_PATCH and patch not in ordered_patches:
+        data = _QaConfig.ApplyPatch(data, patch_module, patches, patch)
+
+    # Finally the default one
+    if _QA_DEFAULT_PATCH in patches:
+      data = _QaConfig.ApplyPatch(data, patch_module, patches,
+                                  _QA_DEFAULT_PATCH)
+
+    return data
+
   @classmethod
   def Load(cls, filename):
     """Loads a configuration file and produces a configuration object.
@@ -268,16 +380,17 @@ class _QaConfig(object):
     # Patch the document using JSON Patch (RFC6902) in file _PATCH_JSON, if
     # available
     try:
-      patch = serializer.LoadJson(utils.ReadFile(_PATCH_JSON))
-      if patch:
+      patches = _QaConfig.LoadPatches()
+      # Try to use the module only if there is a non-empty patch present
+      if any(patches.values()):
         mod = __import__("jsonpatch", fromlist=[])
-        data = mod.apply_patch(data, patch)
+        data = _QaConfig.ApplyPatches(data, mod, patches)
     except IOError:
       pass
     except ImportError:
-      raise qa_error.Error("If you want to use the QA JSON patching feature,"
-                           " you need to install Python modules"
-                           " 'jsonpatch' and 'jsonpointer'.")
+      raise qa_error.Error("For the QA JSON patching feature to work, you "
+                           "need to install Python modules 'jsonpatch' and "
+                           "'jsonpointer'.")
 
     result = cls(dict(map(_ConvertResources,
                           data.items()))) # pylint: disable=E1103