ht: Add strict check for dictionaries
author Michael Hanselmann <hansmi@google.com>
Tue, 17 May 2011 14:09:11 +0000 (16:09 +0200)
committer Michael Hanselmann <hansmi@google.com>
Tue, 17 May 2011 15:05:48 +0000 (17:05 +0200)
This allows checking specific dictionary items, unlike TDict
or TDictOf.

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Iustin Pop <iustin@google.com>

lib/ht.py
test/ganeti.ht_unittest.py

index f2fc0b3..a0c787e 100644 (file)
--- a/lib/ht.py
+++ b/lib/ht.py
@@ -292,3 +292,51 @@ def TDictOf(key_type, val_type):
             compat.all(val_type(v) for v in container.values()))
 
   return desc(TAnd(TDict, fn))
+
+
+def _TStrictDictCheck(require_all, exclusive, items, val):
+  """Helper function for L{TStrictDict}.
+
+  """
+  notfound_fn = lambda _: not exclusive
+
+  if require_all and not frozenset(val.keys()).issuperset(items.keys()):
+    # Requires items not found in value
+    return False
+
+  return compat.all(items.get(key, notfound_fn)(value)
+                    for (key, value) in val.items())
+
+
+def TStrictDict(require_all, exclusive, items):
+  """Strict dictionary check with specific keys.
+
+  @type require_all: boolean
+  @param require_all: Whether all keys in L{items} are required
+  @type exclusive: boolean
+  @param exclusive: Whether only keys listed in L{items} should be accepted
+  @type items: dictionary
+  @param items: Mapping from key (string) to verification function
+
+  """
+  descparts = ["Dictionary containing"]
+
+  if exclusive:
+    descparts.append(" none but the")
+
+  if require_all:
+    descparts.append(" required")
+
+  if len(items) == 1:
+    descparts.append(" key ")
+  else:
+    descparts.append(" keys ")
+
+  descparts.append(utils.CommaJoin("\"%s\" (value %s)" % (key, value)
+                                   for (key, value) in items.items()))
+
+  desc = WithDesc("".join(descparts))
+
+  return desc(TAnd(TDict,
+                   compat.partial(_TStrictDictCheck, require_all, exclusive,
+                                  items)))
index 34ae671..1dba316 100755 (executable)
@@ -187,6 +187,49 @@ class TestTypeChecks(unittest.TestCase):
     self.assertFalse(fn({"x": None}))
     self.assertFalse(fn({"": 8234}))
 
+  def testStrictDictRequireAllExclusive(self):
+    fn = ht.TStrictDict(True, True, { "a": ht.TInt, })
+    self.assertFalse(fn(1))
+    self.assertFalse(fn(None))
+    self.assertFalse(fn({}))
+    self.assertFalse(fn({"a": "Hello", }))
+    self.assertFalse(fn({"unknown": 999,}))
+    self.assertFalse(fn({"unknown": None,}))
+
+    self.assertTrue(fn({"a": 123, }))
+    self.assertTrue(fn({"a": -5, }))
+
+    fn = ht.TStrictDict(True, True, { "a": ht.TInt, "x": ht.TString, })
+    self.assertFalse(fn({}))
+    self.assertFalse(fn({"a": -5, }))
+    self.assertTrue(fn({"a": 123, "x": "", }))
+    self.assertFalse(fn({"a": 123, "x": None, }))
+
+  def testStrictDictExclusive(self):
+    fn = ht.TStrictDict(False, True, { "a": ht.TInt, "b": ht.TList, })
+    self.assertTrue(fn({}))
+    self.assertTrue(fn({"a": 123, }))
+    self.assertTrue(fn({"b": range(4), }))
+    self.assertFalse(fn({"b": 123, }))
+
+    self.assertFalse(fn({"foo": {}, }))
+    self.assertFalse(fn({"bar": object(), }))
+
+  def testStrictDictRequireAll(self):
+    fn = ht.TStrictDict(True, False, { "a": ht.TInt, "m": ht.TInt, })
+    self.assertTrue(fn({"a": 1, "m": 2, "bar": object(), }))
+    self.assertFalse(fn({}))
+    self.assertFalse(fn({"a": 1, "bar": object(), }))
+    self.assertFalse(fn({"a": 1, "m": [], "bar": object(), }))
+
+  def testStrictDict(self):
+    fn = ht.TStrictDict(False, False, { "a": ht.TInt, })
+    self.assertTrue(fn({}))
+    self.assertFalse(fn({"a": ""}))
+    self.assertTrue(fn({"a": 11}))
+    self.assertTrue(fn({"other": 11}))
+    self.assertTrue(fn({"other": object()}))
+
 
 if __name__ == "__main__":
   testutils.GanetiTestProgram()