Convert OsDiagnose to query
author Michael Hanselmann <hansmi@google.com>
Mon, 7 Mar 2011 13:48:46 +0000 (14:48 +0100)
committer Michael Hanselmann <hansmi@google.com>
Mon, 14 Mar 2011 10:50:47 +0000 (11:50 +0100)
Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Iustin Pop <iustin@google.com>

doc/design-query2.rst
lib/cmdlib.py
lib/constants.py
lib/query.py
test/ganeti.query_unittest.py

index 22d02c2..a307e9f 100644 (file)
@@ -100,6 +100,8 @@ items:
   Jobs
 ``lock``
   Locks
+``os``
+  Operating systems
 
 .. _data-query:
 
index f5e4134..b967979 100644 (file)
@@ -3410,37 +3410,28 @@ class LUOobCommand(NoHooksLU):
       raise errors.OpExecError("Check of out-of-band payload failed due to %s" %
                                utils.CommaJoin(errs))
 
+class _OsQuery(_QueryBase):
+  FIELDS = query.OS_FIELDS
 
-
-class LUOsDiagnose(NoHooksLU):
-  """Logical unit for OS diagnose/query.
-
-  """
-  REQ_BGL = False
-  _HID = "hidden"
-  _BLK = "blacklisted"
-  _VLD = "valid"
-  _FIELDS_STATIC = utils.FieldSet()
-  _FIELDS_DYNAMIC = utils.FieldSet("name", _VLD, "node_status", "variants",
-                                   "parameters", "api_versions", _HID, _BLK)
-
-  def CheckArguments(self):
-    if self.op.names:
-      raise errors.OpPrereqError("Selective OS query not supported",
-                                 errors.ECODE_INVAL)
-
-    _CheckOutputFields(static=self._FIELDS_STATIC,
-                       dynamic=self._FIELDS_DYNAMIC,
-                       selected=self.op.output_fields)
-
-  def ExpandNames(self):
-    # Lock all nodes, in shared mode
+  def ExpandNames(self, lu):
+    # Would lock all nodes in shared mode, but locking is currently disabled
     # Temporary removal of locks, should be reverted later
     # TODO: reintroduce locks when they are lighter-weight
-    self.needed_locks = {}
+    lu.needed_locks = {}
     #self.share_locks[locking.LEVEL_NODE] = 1
     #self.needed_locks[locking.LEVEL_NODE] = locking.ALL_SET
 
+    # The following variables interact with _QueryBase._GetNames
+    if self.names:
+      self.wanted = self.names
+    else:
+      self.wanted = locking.ALL_SET
+
+    self.do_locking = self.use_locking
+
+  def DeclareLocks(self, lu, level):
+    pass  # ExpandNames requests no locks, so there is nothing to declare
+
   @staticmethod
   def _DiagnoseByOS(rlist):
     """Remaps a per-node return list into an a per-os per-node dictionary
@@ -3481,71 +3472,87 @@ class LUOsDiagnose(NoHooksLU):
                                         variants, params, api_versions))
     return all_os
 
-  def Exec(self, feedback_fn):
-    """Compute the list of OSes.
+  def _GetQueryData(self, lu):
+    """Computes the list of operating systems and their attributes.
 
     """
+    # Locking is not used
+    assert not (lu.acquired_locks or self.do_locking or self.use_locking)
+
+    # Used further down
+    assert "valid" in self.FIELDS
+    assert "hidden" in self.FIELDS
+    assert "blacklisted" in self.FIELDS
+
     valid_nodes = [node.name
-                   for node in self.cfg.GetAllNodesInfo().values()
+                   for node in lu.cfg.GetAllNodesInfo().values()
                    if not node.offline and node.vm_capable]
-    node_data = self.rpc.call_os_diagnose(valid_nodes)
-    pol = self._DiagnoseByOS(node_data)
-    output = []
-    cluster = self.cfg.GetClusterInfo()
+    pol = self._DiagnoseByOS(lu.rpc.call_os_diagnose(valid_nodes))
+    cluster = lu.cfg.GetClusterInfo()
+
+    # Build list of used field names
+    fields = [fdef.name for fdef in self.query.GetFields()]
+
+    data = {}
+
+    for (os_name, os_data) in pol.items():
+      info = query.OsInfo(name=os_name, valid=True, node_status=os_data,
+                          hidden=(os_name in cluster.hidden_os),
+                          blacklisted=(os_name in cluster.blacklisted_os))
+
+      variants = set()
+      parameters = set()
+      api_versions = set()
 
-    for os_name in utils.NiceSort(pol.keys()):
-      os_data = pol[os_name]
-      row = []
-      valid = True
-      (variants, params, api_versions) = null_state = (set(), set(), set())
       for idx, osl in enumerate(os_data.values()):
-        valid = bool(valid and osl and osl[0][1])
-        if not valid:
-          (variants, params, api_versions) = null_state
+        info.valid = bool(info.valid and osl and osl[0][1])
+        if not info.valid:
           break
-        node_variants, node_params, node_api = osl[0][3:6]
-        if idx == 0: # first entry
-          variants = set(node_variants)
-          params = set(node_params)
-          api_versions = set(node_api)
-        else: # keep consistency
+
+        (node_variants, node_params, node_api) = osl[0][3:6]
+        if idx == 0:
+          # First entry
+          variants.update(node_variants)
+          parameters.update(node_params)
+          api_versions.update(node_api)
+        else:
+          # Filter out inconsistent values
           variants.intersection_update(node_variants)
-          params.intersection_update(node_params)
+          parameters.intersection_update(node_params)
           api_versions.intersection_update(node_api)
 
-      is_hid = os_name in cluster.hidden_os
-      is_blk = os_name in cluster.blacklisted_os
-      if ((self._HID not in self.op.output_fields and is_hid) or
-          (self._BLK not in self.op.output_fields and is_blk) or
-          (self._VLD not in self.op.output_fields and not valid)):
+      # Invalid OSes must not report values gathered from earlier nodes
+      info.variants = list(variants) if info.valid else []
+      info.parameters = list(parameters) if info.valid else []
+      info.api_versions = list(api_versions) if info.valid else []
+      # TODO: Move this to filters provided by the client
+      if (("hidden" not in fields and info.hidden) or
+          ("blacklisted" not in fields and info.blacklisted) or
+          ("valid" not in fields and not info.valid)):
         continue
 
-      for field in self.op.output_fields:
-        if field == "name":
-          val = os_name
-        elif field == self._VLD:
-          val = valid
-        elif field == "node_status":
-          # this is just a copy of the dict
-          val = {}
-          for node_name, nos_list in os_data.items():
-            val[node_name] = nos_list
-        elif field == "variants":
-          val = utils.NiceSort(list(variants))
-        elif field == "parameters":
-          val = list(params)
-        elif field == "api_versions":
-          val = list(api_versions)
-        elif field == self._HID:
-          val = is_hid
-        elif field == self._BLK:
-          val = is_blk
-        else:
-          raise errors.ParameterError(field)
-        row.append(val)
-      output.append(row)
+      data[os_name] = info
 
-    return output
+    # Prepare data in requested order
+    return [data[name] for name in self._GetNames(lu, pol.keys(), None)
+            if name in data]
+
+
+class LUOsDiagnose(NoHooksLU):
+  """Logical unit for OS diagnose/query, implemented using L{_OsQuery}.
+
+  """
+  REQ_BGL = False
+
+  def CheckArguments(self):
+    self.oq = _OsQuery(qlang.MakeSimpleFilter("name", self.op.names),
+                       self.op.output_fields, False)
+
+  def ExpandNames(self):
+    self.oq.ExpandNames(self)
+
+  def Exec(self, feedback_fn):
+    return self.oq.OldStyleQuery(self)
 
 
 class LUNodeRemove(LogicalUnit):
@@ -11643,8 +11650,11 @@ _QUERY_IMPL = {
   constants.QR_INSTANCE: _InstanceQuery,
   constants.QR_NODE: _NodeQuery,
   constants.QR_GROUP: _GroupQuery,
+  constants.QR_OS: _OsQuery,
   }
 
+assert set(_QUERY_IMPL.keys()) == constants.QR_OP_QUERY
+
 
 def _GetQueryImplementation(name):
   """Returns the implemtnation for a query type.
index 44550e5..0560dec 100644 (file)
@@ -1027,9 +1027,10 @@ QR_INSTANCE = "instance"
 QR_NODE = "node"
 QR_LOCK = "lock"
 QR_GROUP = "group"
+QR_OS = "os"
 
 #: List of resources which can be queried using L{opcodes.OpQuery}
-QR_OP_QUERY = frozenset([QR_INSTANCE, QR_NODE, QR_GROUP])
+QR_OP_QUERY = frozenset([QR_INSTANCE, QR_NODE, QR_GROUP, QR_OS])
 
 #: List of resources which can be queried using Local UniX Interface
 QR_OP_LUXI = QR_OP_QUERY.union([
index 28afec1..a462b7b 100644 (file)
@@ -1888,6 +1888,52 @@ def _BuildGroupFields():
   return _PrepareFieldList(fields, [])
 
 
+class OsInfo(objects.ConfigObject):
+  __slots__ = [  # one attribute per field defined in _BuildOsFields
+    "name",
+    "valid",
+    "hidden",
+    "blacklisted",
+    "variants",
+    "api_versions",
+    "parameters",
+    "node_status",
+    ]
+
+
+def _BuildOsFields():
+  """Builds list of fields for operating system queries on L{OsInfo} objects.
+
+  """
+  fields = [
+    (_MakeField("name", "Name", QFT_TEXT, "Operating system name"),
+     None, 0, _GetItemAttr("name")),
+    (_MakeField("valid", "Valid", QFT_BOOL,
+                "Whether operating system definition is valid"),
+     None, 0, _GetItemAttr("valid")),
+    (_MakeField("hidden", "Hidden", QFT_BOOL,
+                "Whether operating system is hidden"),
+     None, 0, _GetItemAttr("hidden")),
+    (_MakeField("blacklisted", "Blacklisted", QFT_BOOL,
+                "Whether operating system is blacklisted"),
+     None, 0, _GetItemAttr("blacklisted")),
+    (_MakeField("variants", "Variants", QFT_OTHER,
+                "Operating system variants"),
+     None, 0, _ConvWrap(utils.NiceSort, _GetItemAttr("variants"))),
+    (_MakeField("api_versions", "ApiVersions", QFT_OTHER,
+                "Operating system API versions"),
+     None, 0, _ConvWrap(sorted, _GetItemAttr("api_versions"))),
+    (_MakeField("parameters", "Parameters", QFT_OTHER,
+                "Operating system parameters"),
+     None, 0, _ConvWrap(utils.NiceSort, _GetItemAttr("parameters"))),
+    (_MakeField("node_status", "NodeStatus", QFT_OTHER,
+                "Status from node"),
+     None, 0, _GetItemAttr("node_status")),
+    ]
+
+  return _PrepareFieldList(fields, [])
+
+
 #: Fields available for node queries
 NODE_FIELDS = _BuildNodeFields()
 
@@ -1900,12 +1946,16 @@ LOCK_FIELDS = _BuildLockFields()
 #: Fields available for node group queries
 GROUP_FIELDS = _BuildGroupFields()
 
+#: Fields available for operating system queries
+OS_FIELDS = _BuildOsFields()
+
 #: All available resources
 ALL_FIELDS = {
   constants.QR_INSTANCE: INSTANCE_FIELDS,
   constants.QR_NODE: NODE_FIELDS,
   constants.QR_LOCK: LOCK_FIELDS,
   constants.QR_GROUP: GROUP_FIELDS,
+  constants.QR_OS: OS_FIELDS,
   }
 
 #: All available field lists
index 0a2691f..d6233a9 100755 (executable)
@@ -941,6 +941,58 @@ class TestGroupQuery(unittest.TestCase):
                       ])
 
 
+class TestOsQuery(unittest.TestCase):
+  def _Create(self, selected):
+    return query.Query(query.OS_FIELDS, selected)
+
+  def test(self):
+    variants = ["v00", "plain", "v3", "var0", "v33", "v20"]
+    api_versions = [10, 0, 15, 5]
+    parameters = ["zpar3", "apar9"]
+
+    assert variants != sorted(variants) and variants != utils.NiceSort(variants)
+    # The "api_versions" field is sorted with sorted(), not utils.NiceSort
+    assert api_versions != sorted(api_versions)
+    assert (parameters != sorted(parameters) and
+            parameters != utils.NiceSort(parameters))
+
+    data = [
+      query.OsInfo(name="debian", valid=False, hidden=False, blacklisted=False,
+                   variants=set(), api_versions=set(), parameters=set(),
+                   node_status={ "some": "status", }),
+      query.OsInfo(name="dos", valid=True, hidden=False, blacklisted=True,
+                   variants=set(variants),
+                   api_versions=set(api_versions),
+                   parameters=set(parameters),
+                   node_status={ "some": "other", "status": None, }),
+      ]
+
+
+    q = self._Create(["name", "valid", "hidden", "blacklisted", "variants",
+                      "api_versions", "parameters", "node_status"])
+    self.assertEqual(q.RequestedData(), set([]))
+    self.assertEqual(q.Query(data),
+                     [[(constants.RS_NORMAL, "debian"),
+                       (constants.RS_NORMAL, False),
+                       (constants.RS_NORMAL, False),
+                       (constants.RS_NORMAL, False),
+                       (constants.RS_NORMAL, []),
+                       (constants.RS_NORMAL, []),
+                       (constants.RS_NORMAL, []),
+                       (constants.RS_NORMAL, {"some": "status"})],
+                      [(constants.RS_NORMAL, "dos"),
+                       (constants.RS_NORMAL, True),
+                       (constants.RS_NORMAL, False),
+                       (constants.RS_NORMAL, True),
+                       (constants.RS_NORMAL,
+                        ["plain", "v00", "v3", "v20", "v33", "var0"]),
+                       (constants.RS_NORMAL, [0, 5, 10, 15]),
+                       (constants.RS_NORMAL, ["apar9", "zpar3"]),
+                       (constants.RS_NORMAL,
+                        { "some": "other", "status": None, })
+                       ]])
+
+
 class TestQueryFields(unittest.TestCase):
   def testAllFields(self):
     for fielddefs in query.ALL_FIELD_LISTS: