backend: Check for shared storage also
[ganeti-local] / test / ganeti.utils.algo_unittest.py
index b4e3a64..5d08e2e 100755 (executable)
@@ -26,6 +26,7 @@ import random
 import operator
 
 from ganeti import constants
+from ganeti import compat
 from ganeti.utils import algo
 
 import testutils
@@ -229,6 +230,23 @@ class TestNiceSort(unittest.TestCase):
                       None, ""])
 
 
+class TestInvertDict(unittest.TestCase):
+  """Checks that algo.InvertDict swaps the keys and values of a dict."""
+  def testInvertDict(self):
+    expected = {1: "foo", 2: "bar", 5: "baz"}
+    self.assertEqual(algo.InvertDict({"foo": 1, "bar": 2, "baz": 5}), expected)
+
+
+class TestInsertAtPos(unittest.TestCase):
+  """Checks algo.InsertAtPos at the start, middle and end of a list."""
+  def test(self):
+    base, extra = [1, 5, 6], [2, 3, 4]
+    self.assertEqual(algo.InsertAtPos(base, 1, extra), [1, 2, 3, 4, 5, 6])
+    self.assertEqual(algo.InsertAtPos(base, 0, extra), extra + base)
+    self.assertEqual(algo.InsertAtPos(base, len(base), extra), base + extra)
+    self.assertEqual(algo.InsertAtPos(base, 2, extra), [1, 5, 2, 3, 4, 6])
+
+
 class TimeMock:
   def __init__(self, values):
     self.values = values
@@ -265,5 +283,90 @@ class TestRunningTimeout(unittest.TestCase):
     self.assertRaises(ValueError, algo.RunningTimeout, -1.0, True)
 
 
+class TestJoinDisjointDicts(unittest.TestCase):
+  """Tests for algo.JoinDisjointDicts."""
+  def setUp(self):
+    self.non_empty_dict = {"a": 1, "b": 2}
+    self.empty_dict = {}
+
+  def testWithEmptyDicts(self):
+    empty, filled = self.empty_dict, self.non_empty_dict
+    # Joining with an empty dict must give back the other operand's contents
+    self.assertEqual(empty, algo.JoinDisjointDicts(empty, empty))
+    self.assertEqual(filled, algo.JoinDisjointDicts(empty, filled))
+    self.assertEqual(filled, algo.JoinDisjointDicts(filled, empty))
+
+  def testNonDisjoint(self):
+    # The same dict on both sides shares every key, so the join must fail
+    self.assertRaises(AssertionError, algo.JoinDisjointDicts,
+                      self.non_empty_dict, self.non_empty_dict)
+
+  def testCommonCase(self):
+    first = {"TEST1": 1, "TEST2": 2}
+    second = {"TEST3": 3, "TEST4": 4}
+    expected = dict(first)
+    expected.update(second)
+
+    self.assertEqual(expected, algo.JoinDisjointDicts(first, second))
+    self.assertEqual(expected, algo.JoinDisjointDicts(second, first))
+
+
+class TestSequenceToDict(unittest.TestCase):
+  """Tests for algo.SequenceToDict with default and custom key functions."""
+  def testEmpty(self):
+    self.assertEqual(algo.SequenceToDict([]), {})
+    self.assertEqual(algo.SequenceToDict({}), {})
+
+  def testSimple(self):
+    data = [(i, str(i), "test%s" % i) for i in range(391)]
+    self.assertEqual(algo.SequenceToDict(data),
+      dict((i, (i, str(i), "test%s" % i))
+           for i in range(391)))
+
+  def testCustomKey(self):
+    data = [(i, hex(i), "test%s" % i) for i in range(100)]
+    self.assertEqual(algo.SequenceToDict(data, key=compat.snd),
+      dict((hex(i), (i, hex(i), "test%s" % i))
+           for i in range(100)))
+    # Index the row rather than unpack it in the signature (PEP 3113)
+    self.assertEqual(algo.SequenceToDict(data, key=lambda row: hash(row[2])),
+      dict((hash("test%s" % i), (i, hex(i), "test%s" % i))
+           for i in range(100)))
+
+  def testDuplicate(self):
+    self.assertRaises(ValueError, algo.SequenceToDict, [(0, 0), (0, 0)])
+    self.assertRaises(ValueError, algo.SequenceToDict,
+                      [(i, ) for i in range(200)] + [(10, )])
+
+
+class TestFlatToDict(unittest.TestCase):
+  """Tests for algo.FlatToDict on "/"-separated keys."""
+  def testNormal(self):
+    flat = [
+      ("lv/xenvg", {"foo": "bar", "bar": "baz"}),
+      ("lv/xenfoo", {"foo": "bar", "baz": "blubb"}),
+      ("san/foo", {"ip": "127.0.0.1", "port": 1337}),
+      ("san/blubb/blibb", 54),
+      ]
+    lv_tree = {
+      "xenvg": {"foo": "bar", "bar": "baz"},
+      "xenfoo": {"foo": "bar", "baz": "blubb"},
+      }
+    san_tree = {
+      "foo": {"ip": "127.0.0.1", "port": 1337},
+      "blubb": {"blibb": 54},
+      }
+    nested = {"lv": lv_tree, "san": san_tree}
+    self.assertEqual(algo.FlatToDict(flat), nested)
+
+  def testUnlikeDepth(self):
+    conflicting = [
+      ("san/foo", {"ip": "127.0.0.1", "port": 1337}),
+      ("san/foo/blubb", 23), # Reuses "san/foo" as a parent
+      ("san/blubb/blibb", 54),
+      ]
+    self.assertRaises(AssertionError, algo.FlatToDict, conflicting)
+
+
 if __name__ == "__main__":
   testutils.GanetiTestProgram()