Reason trail implementation for "shutdown"
[ganeti-local] / test / py / ganeti.cmdlib_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2008, 2011, 2012, 2013 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the cmdlib module"""
23
24
25 import os
26 import unittest
27 import time
28 import tempfile
29 import shutil
30 import operator
31 import itertools
32 import copy
33
34 from ganeti import constants
35 from ganeti import mcpu
36 from ganeti import cmdlib
37 from ganeti import opcodes
38 from ganeti import errors
39 from ganeti import utils
40 from ganeti import luxi
41 from ganeti import ht
42 from ganeti import objects
43 from ganeti import compat
44 from ganeti import rpc
45 from ganeti import locking
46 from ganeti import pathutils
47 from ganeti.masterd import iallocator
48 from ganeti.hypervisor import hv_xen
49
50 import testutils
51 import mocks
52
53
class TestCertVerification(testutils.GanetiTestCase):
  """Tests for L{cmdlib._VerifyCertificate}."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    # Scratch directory, used to build a path that surely does not exist
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testVerifyCertificate(self):
    # A valid certificate must be processed without raising
    cmdlib._VerifyCertificate(testutils.TestDataFilename("cert1.pem"))

    expected_code = cmdlib.LUClusterVerifyConfig.ETYPE_ERROR

    # A missing file and a file that is not a certificate (in this order)
    # must both be reported with the error code
    bad_files = [
      os.path.join(self.tmpdir, "does-not-exist"),
      testutils.TestDataFilename("bdev-net.txt"),
      ]
    for filename in bad_files:
      (errcode, _) = cmdlib._VerifyCertificate(filename)
      self.assertEqual(errcode, expected_code)
75
76
class TestOpcodeParams(testutils.GanetiTestCase):
  """Checks that no LU still declares parameters the old way."""

  # Attributes used by obsolete parameter declaration mechanisms
  _OLD_STYLE_ATTRS = ["_OP_REQP", "_OP_DEFS", "_OP_PARAMS"]

  def testParamsStructures(self):
    for op in sorted(mcpu.Processor.DISPATCH_TABLE):
      lu = mcpu.Processor.DISPATCH_TABLE[op]
      lu_name = lu.__name__
      for attr in self._OLD_STYLE_ATTRS:
        # assertFalse replaces the deprecated failIf alias
        self.assertFalse(hasattr(lu, attr),
                         msg=("LU '%s' has old-style %s" % (lu_name, attr)))
88
89
class TestIAllocatorChecks(testutils.GanetiTestCase):
  """Tests for L{cmdlib._CheckIAllocatorOrNode}."""

  def testFunction(self):
    class TestLU(object):
      """Minimal LU holding a fake config and the opcode under test."""
      def __init__(self, opcode):
        self.cfg = mocks.FakeConfig()
        self.op = opcode

    class OpTest(opcodes.OpCode):
      OP_PARAMS = [
        ("iallocator", None, ht.NoType, None),
        ("node", None, ht.NoType, None),
        ]

    default_iallocator = mocks.FakeConfig().GetDefaultIAllocator()
    other_iallocator = default_iallocator + "_not"

    op = OpTest()
    lu = TestLU(op)

    def _Run(ialloc_value, node_value):
      """Stores the given values in the opcode and runs the check."""
      op.iallocator = ialloc_value
      op.node = node_value
      cmdlib._CheckIAllocatorOrNode(lu, "iallocator", "node")

    def _RunAndVerify(ialloc_value, node_value, exp_ialloc, exp_node):
      """Runs the check and compares the resulting opcode fields."""
      _Run(ialloc_value, node_value)
      self.assertEqual(lu.op.iallocator, exp_ialloc)
      self.assertEqual(lu.op.node, exp_node)

    # Neither node nor iallocator given: the default iallocator is filled in
    for node_value in (None, []):
      _RunAndVerify(None, node_value, default_iallocator, node_value)

    # Both, iallocator and node given
    for ialloc_value in ("test", constants.DEFAULT_IALLOCATOR_SHORTCUT):
      self.assertRaises(errors.OpPrereqError, _Run, ialloc_value, "test")

    # Only iallocator given: it is kept as-is
    for node_value in (None, []):
      _RunAndVerify(other_iallocator, node_value, other_iallocator,
                    node_value)

    # Only node given
    _RunAndVerify(None, "node", None, "node")

    # Asked for default iallocator, no node given
    _RunAndVerify(constants.DEFAULT_IALLOCATOR_SHORTCUT, None,
                  default_iallocator, None)

    # No node, iallocator or default iallocator
    lu.cfg.GetDefaultIAllocator = lambda: None
    self.assertRaises(errors.OpPrereqError, _Run, None, None)
152
153
class TestLUTestJqueue(unittest.TestCase):
  """Sanity check for the timeouts used by L{cmdlib.LUTestJqueue}."""

  def test(self):
    # The client connect timeout must stay well below the WaitForJobChange
    # timeout (at most 75% of it); assertTrue replaces the deprecated assert_
    self.assertTrue(cmdlib.LUTestJqueue._CLIENT_CONNECT_TIMEOUT <
                    (luxi.WFJC_TIMEOUT * 0.75),
                    msg=("Client timeout too high, might not notice bugs"
                         " in WaitForJobChange"))
160
161
class TestLUQuery(unittest.TestCase):
  """Checks the query implementations registered in cmdlib."""

  def test(self):
    # Exactly the resources queryable via opcodes have an implementation
    self.assertEqual(sorted(cmdlib._QUERY_IMPL.keys()),
                     sorted(constants.QR_VIA_OP))

    # Nodes and instances must be queryable via opcodes; checked with
    # assertTrue rather than bare asserts, which vanish under "python -O"
    self.assertTrue(constants.QR_NODE in constants.QR_VIA_OP)
    self.assertTrue(constants.QR_INSTANCE in constants.QR_VIA_OP)

    for i in constants.QR_VIA_OP:
      self.assertTrue(cmdlib._GetQueryImplementation(i))

    # Unknown resource names are rejected
    for name in ("", "xyz"):
      self.assertRaises(errors.OpPrereqError,
                        cmdlib._GetQueryImplementation, name)
176
177
class TestLUGroupAssignNodes(unittest.TestCase):
  """Tests for L{cmdlib.LUGroupAssignNodes}."""

  def testCheckAssignmentForSplitInstances(self):
    node_groups = [
      ("n1a", "g1"), ("n1b", "g1"),
      ("n2a", "g2"), ("n2b", "g2"),
      ("n3a", "g3"), ("n3b", "g3"), ("n3c", "g3"),
      ]
    node_data = {}
    for (name, group) in node_groups:
      node_data[name] = objects.Node(name=name, group=group)

    def _MakeInstance(name, pnode, snode):
      """Builds a DRBD instance, or a diskless one when C{snode} is None."""
      if snode is not None:
        drbd_disk = objects.Disk(dev_type=constants.LD_DRBD8,
                                 logical_id=[pnode, snode, 1, 17, 17])
        return objects.Instance(name=name, primary_node=pnode,
                                disks=[drbd_disk],
                                disk_template=constants.DT_DRBD8)
      return objects.Instance(name=name, primary_node=pnode, disks=[],
                              disk_template=constants.DT_DISKLESS)

    instance_specs = [
      ("inst1a", "n1a", "n1b"),
      ("inst1b", "n1b", "n1a"),
      ("inst2a", "n2a", "n2b"),
      ("inst3a", "n3a", None),
      ("inst3b", "n3b", "n1b"),
      ("inst3c", "n3b", "n2b"),
      ]
    instance_data = {}
    for (name, pnode, snode) in instance_specs:
      instance_data[name] = _MakeInstance(name, pnode, snode)

    check_fn = cmdlib.LUGroupAssignNodes.CheckAssignmentForSplitInstances

    # Test first with the existing state.
    (new, prev) = check_fn([], node_data, instance_data)
    self.assertEqual([], new)
    self.assertEqual(set(["inst3b", "inst3c"]), set(prev))

    # And now some changes.
    (new, prev) = check_fn([("n1b", "g3")], node_data, instance_data)
    self.assertEqual(set(["inst1a", "inst1b"]), set(new))
    self.assertEqual(set(["inst3c"]), set(prev))
227
228
class TestClusterVerifySsh(unittest.TestCase):
  """Tests for L{cmdlib.LUClusterVerifyGroup._SelectSshCheckNodes}."""

  @staticmethod
  def _MakeNodes(specs):
    """Builds L{objects.Node} objects from (name, group, offline) tuples."""
    return [objects.Node(name=name, group=group, offline=offline)
            for (name, group, offline) in specs]

  def testMultipleGroups(self):
    fn = cmdlib.LUClusterVerifyGroup._SelectSshCheckNodes
    # node20 to node25 are online, node26 is offline
    mygroupnodes = self._MakeNodes(
      [("node%s" % i, "my", False) for i in range(20, 26)] +
      [("node26", "my", True)])
    nodes = self._MakeNodes([
      ("node1", "g1", True),
      ("node2", "g1", False),
      ("node3", "g1", False),
      ("node4", "g1", True),
      ("node5", "g1", False),
      ("node10", "xyz", False),
      ("node11", "xyz", False),
      ("node40", "alloff", True),
      ("node41", "alloff", True),
      ("node50", "aaa", False),
      ]) + mygroupnodes
    assert not utils.FindDuplicates([node.name for node in nodes])

    (online, perhost) = fn(mygroupnodes, "my", nodes)
    self.assertEqual(online, ["node%s" % i for i in range(20, 26)])
    self.assertEqual(set(perhost.keys()), set(online))

    # Every online node of "my" gets one online node of each other group;
    # the all-offline group "alloff" contributes nothing
    self.assertEqual(perhost, {
      "node20": ["node10", "node2", "node50"],
      "node21": ["node11", "node3", "node50"],
      "node22": ["node10", "node5", "node50"],
      "node23": ["node11", "node2", "node50"],
      "node24": ["node10", "node3", "node50"],
      "node25": ["node11", "node5", "node50"],
      })

  def testSingleGroup(self):
    fn = cmdlib.LUClusterVerifyGroup._SelectSshCheckNodes
    nodes = self._MakeNodes([
      ("node1", "default", True),
      ("node2", "default", False),
      ("node3", "default", False),
      ("node4", "default", True),
      ])
    assert not utils.FindDuplicates([node.name for node in nodes])

    (online, perhost) = fn(nodes, "default", nodes)
    self.assertEqual(online, ["node2", "node3"])
    self.assertEqual(set(perhost.keys()), set(online))

    # With a single group no nodes from other groups are selected
    self.assertEqual(perhost, {
      "node2": [],
      "node3": [],
      })
286
287
class TestClusterVerifyFiles(unittest.TestCase):
  """Tests for L{cmdlib.LUClusterVerifyGroup._VerifyFiles}."""

  @staticmethod
  def _FakeErrorIf(errs, cond, ecode, item, msg, *args, **kwargs):
    """Stand-in for the LU's error reporting function.

    @param errs: list collecting C{(item, msg)} tuples; deliberately not
      named C{errors} so that it does not shadow the L{errors} module
    @param cond: an error is recorded only when this is true
    @param ecode: error code; must be C{CV_ENODEFILECHECK} (with a non-empty
      node name as C{item}) or C{CV_ECLUSTERFILECHECK} (with C{item} None)
    @param item: node name for node-level errors, None for cluster-level ones
    @param msg: message, %-formatted with C{args} when any are given

    """
    assert ((ecode == constants.CV_ENODEFILECHECK and
             ht.TNonEmptyString(item)) or
            (ecode == constants.CV_ECLUSTERFILECHECK and
             item is None))

    if args:
      msg = msg % args

    if cond:
      errs.append((item, msg))

  _VerifyFiles = cmdlib.LUClusterVerifyGroup._VerifyFiles

  def test(self):
    # Named so that the errors module is not shadowed
    found_errors = []
    master_name = "master.example.com"
    nodeinfo = [
      objects.Node(name=master_name, offline=False, vm_capable=True),
      objects.Node(name="node2.example.com", offline=False, vm_capable=True),
      objects.Node(name="node3.example.com", master_candidate=True,
                   vm_capable=False),
      objects.Node(name="node4.example.com", offline=False, vm_capable=True),
      objects.Node(name="nodata.example.com", offline=False, vm_capable=True),
      objects.Node(name="offline.example.com", offline=True),
      ]
    files_all = set([
      pathutils.CLUSTER_DOMAIN_SECRET_FILE,
      pathutils.RAPI_CERT_FILE,
      pathutils.RAPI_USERS_FILE,
      ])
    files_opt = set([
      pathutils.RAPI_USERS_FILE,
      hv_xen.XL_CONFIG_FILE,
      pathutils.VNC_PASSWORD_FILE,
      ])
    files_mc = set([
      pathutils.CLUSTER_CONF_FILE,
      ])
    files_vm = set([
      hv_xen.XEND_CONFIG_FILE,
      hv_xen.XL_CONFIG_FILE,
      pathutils.VNC_PASSWORD_FILE,
      ])
    nvinfo = {
      master_name: rpc.RpcResult(data=(True, {
        constants.NV_FILELIST: {
          pathutils.CLUSTER_CONF_FILE: "82314f897f38b35f9dab2f7c6b1593e0",
          pathutils.RAPI_CERT_FILE: "babbce8f387bc082228e544a2146fee4",
          pathutils.CLUSTER_DOMAIN_SECRET_FILE: "cds-47b5b3f19202936bb4",
          hv_xen.XEND_CONFIG_FILE: "b4a8a824ab3cac3d88839a9adeadf310",
          hv_xen.XL_CONFIG_FILE: "77935cee92afd26d162f9e525e3d49b9"
        }})),
      "node2.example.com": rpc.RpcResult(data=(True, {
        constants.NV_FILELIST: {
          pathutils.RAPI_CERT_FILE: "97f0356500e866387f4b84233848cc4a",
          hv_xen.XEND_CONFIG_FILE: "b4a8a824ab3cac3d88839a9adeadf310",
          }
        })),
      "node3.example.com": rpc.RpcResult(data=(True, {
        constants.NV_FILELIST: {
          pathutils.RAPI_CERT_FILE: "97f0356500e866387f4b84233848cc4a",
          pathutils.CLUSTER_DOMAIN_SECRET_FILE: "cds-47b5b3f19202936bb4",
          }
        })),
      "node4.example.com": rpc.RpcResult(data=(True, {
        constants.NV_FILELIST: {
          pathutils.RAPI_CERT_FILE: "97f0356500e866387f4b84233848cc4a",
          pathutils.CLUSTER_CONF_FILE: "conf-a6d4b13e407867f7a7b4f0f232a8f527",
          pathutils.CLUSTER_DOMAIN_SECRET_FILE: "cds-47b5b3f19202936bb4",
          pathutils.RAPI_USERS_FILE: "rapiusers-ea3271e8d810ef3",
          hv_xen.XL_CONFIG_FILE: "77935cee92afd26d162f9e525e3d49b9"
          }
        })),
      "nodata.example.com": rpc.RpcResult(data=(True, {})),
      "offline.example.com": rpc.RpcResult(offline=True),
      }
    assert set(nvinfo.keys()) == set(map(operator.attrgetter("name"), nodeinfo))

    error_fn = compat.partial(self._FakeErrorIf, found_errors)
    self._VerifyFiles(error_fn, nodeinfo, master_name, nvinfo,
                      (files_all, files_opt, files_mc, files_vm))
    self.assertEqual(sorted(found_errors), sorted([
      (None, ("File %s found with 2 different checksums (variant 1 on"
              " node2.example.com, node3.example.com, node4.example.com;"
              " variant 2 on master.example.com)" % pathutils.RAPI_CERT_FILE)),
      (None, ("File %s is missing from node(s) node2.example.com" %
              pathutils.CLUSTER_DOMAIN_SECRET_FILE)),
      (None, ("File %s should not exist on node(s) node4.example.com" %
              pathutils.CLUSTER_CONF_FILE)),
      (None, ("File %s is missing from node(s) node4.example.com" %
              hv_xen.XEND_CONFIG_FILE)),
      (None, ("File %s is missing from node(s) node3.example.com" %
              pathutils.CLUSTER_CONF_FILE)),
      (None, ("File %s found with 2 different checksums (variant 1 on"
              " master.example.com; variant 2 on node4.example.com)" %
              pathutils.CLUSTER_CONF_FILE)),
      (None, ("File %s is optional, but it must exist on all or no nodes (not"
              " found on master.example.com, node2.example.com,"
              " node3.example.com)" % pathutils.RAPI_USERS_FILE)),
      (None, ("File %s is optional, but it must exist on all or no nodes (not"
              " found on node2.example.com)" % hv_xen.XL_CONFIG_FILE)),
      ("nodata.example.com", "Node did not return file checksum data"),
      ]))
396
397
398 class _FakeLU:
399   def __init__(self, cfg=NotImplemented, proc=NotImplemented,
400                rpc=NotImplemented):
401     self.warning_log = []
402     self.info_log = []
403     self.cfg = cfg
404     self.proc = proc
405     self.rpc = rpc
406
407   def LogWarning(self, text, *args):
408     self.warning_log.append((text, args))
409
410   def LogInfo(self, text, *args):
411     self.info_log.append((text, args))
412
413
class TestLoadNodeEvacResult(unittest.TestCase):
  """Tests for L{cmdlib._LoadNodeEvacResult}."""

  def testSuccess(self):
    moved_variants = [
      [],
      [("inst20153.example.com", "grp2", ["nodeA4509", "nodeB2912"])],
      ]
    combinations = itertools.product(moved_variants, [False, True],
                                     [False, True])

    for (moved, early_release, use_nodes) in combinations:
      jobs = [
        [opcodes.OpInstanceReplaceDisks().__getstate__()],
        [opcodes.OpInstanceMigrate().__getstate__()],
        ]

      alloc_result = (moved, [], jobs)
      assert iallocator._NEVAC_RESULT(alloc_result)

      lu = _FakeLU()
      result = cmdlib._LoadNodeEvacResult(lu, alloc_result,
                                          early_release, use_nodes)

      if moved:
        # A single info message must name every moved instance together
        # with its target nodes or group
        (_, (info_args, )) = lu.info_log.pop(0)
        for (instname, instgroup, instnodes) in moved:
          self.assertTrue(instname in info_args)
          if not use_nodes:
            self.assertTrue(instgroup in info_args)
            continue
          for node_name in instnodes:
            self.assertTrue(node_name in info_args)

      self.assertFalse(lu.info_log)
      self.assertFalse(lu.warning_log)

      # early_release is propagated to every opcode supporting it
      for op in itertools.chain(*result):
        if not hasattr(op.__class__, "early_release"):
          self.assertFalse(hasattr(op, "early_release"))
        else:
          self.assertEqual(op.early_release, early_release)

  def testFailed(self):
    failed = [("inst5191.example.com", "errormsg21178")]
    alloc_result = ([], failed, [])
    assert iallocator._NEVAC_RESULT(alloc_result)

    lu = _FakeLU()
    self.assertRaises(errors.OpExecError, cmdlib._LoadNodeEvacResult,
                      lu, alloc_result, False, False)
    self.assertFalse(lu.info_log)

    # Exactly one warning, naming both the instance and the error message
    (_, (warn_args, )) = lu.warning_log.pop(0)
    for expected in ("inst5191.example.com", "errormsg21178"):
      self.assertTrue(expected in warn_args)
    self.assertFalse(lu.warning_log)
466
467
class TestUpdateAndVerifySubDict(unittest.TestCase):
  """Tests for L{cmdlib._UpdateAndVerifySubDict}."""

  def setUp(self):
    self.type_check = dict(a=constants.VTYPE_INT,
                           b=constants.VTYPE_STRING,
                           c=constants.VTYPE_BOOL,
                           d=constants.VTYPE_STRING)

  def test(self):
    old_test = {
      "foo": dict(d="blubb", a=321),
      "baz": dict(a=678, b="678", c=True),
      }
    test = {
      "foo": dict(a=123, b="123", c=True),
      "bar": dict(a=321, b="321", c=False),
      }

    # Expected: "foo" takes the new values and keeps "d" from the old data,
    # "bar" is added and "baz" is carried over unchanged
    mv = {
      "foo": dict(a=123, b="123", c=True, d="blubb"),
      "bar": dict(a=321, b="321", c=False),
      "baz": dict(a=678, b="678", c=True),
      }

    verified = cmdlib._UpdateAndVerifySubDict(old_test, test, self.type_check)
    self.assertEqual(verified, mv)

  def testWrong(self):
    # "a" is declared as VTYPE_INT but given a string
    test = {
      "foo": dict(a="blubb", b="123", c=True),
      "bar": dict(a=321, b="321", c=False),
      }

    self.assertRaises(errors.TypeEnforcementError,
                      cmdlib._UpdateAndVerifySubDict, {}, test, self.type_check)
540
541
class TestHvStateHelper(unittest.TestCase):
  """Tests for L{cmdlib._MergeAndVerifyHvState}."""

  def testWithoutOpData(self):
    # Without new data the result is None
    result = cmdlib._MergeAndVerifyHvState(None, NotImplemented)
    self.assertEqual(result, None)

  def testWithoutOldData(self):
    # Without old data the new state is returned as-is
    new = {constants.HT_XEN_PVM: {constants.HVST_MEMORY_TOTAL: 4096}}
    self.assertEqual(cmdlib._MergeAndVerifyHvState(new, None), new)

  def testWithWrongHv(self):
    new = {"i-dont-exist": {constants.HVST_MEMORY_TOTAL: 4096}}
    self.assertRaises(errors.OpPrereqError,
                      cmdlib._MergeAndVerifyHvState, new, None)
562
class TestDiskStateHelper(unittest.TestCase):
  """Tests for L{cmdlib._MergeAndVerifyDiskState}."""

  def testWithoutOpData(self):
    # Without new data the result is None
    result = cmdlib._MergeAndVerifyDiskState(None, NotImplemented)
    self.assertEqual(result, None)

  def testWithoutOldData(self):
    # Without old data the new state is returned as-is
    new = {constants.LD_LV: {"xenvg": {constants.DS_DISK_RESERVED: 1024}}}
    self.assertEqual(cmdlib._MergeAndVerifyDiskState(new, None), new)

  def testWithWrongStorageType(self):
    new = {"i-dont-exist": {"xenvg": {constants.DS_DISK_RESERVED: 1024}}}
    self.assertRaises(errors.OpPrereqError,
                      cmdlib._MergeAndVerifyDiskState, new, None)
588
589
class TestComputeMinMaxSpec(unittest.TestCase):
  """Tests for L{cmdlib._ComputeMinMaxSpec}."""

  def setUp(self):
    self.ispecs = {
      constants.ISPECS_MAX: {
        constants.ISPEC_MEM_SIZE: 512,
        constants.ISPEC_DISK_SIZE: 1024,
        },
      constants.ISPECS_MIN: {
        constants.ISPEC_MEM_SIZE: 128,
        constants.ISPEC_DISK_COUNT: 1,
        },
      }

  def _Compute(self, name, value, qualifier=None):
    """Runs the function under test against the fixture specs."""
    return cmdlib._ComputeMinMaxSpec(name, qualifier, self.ispecs, value)

  def testNoneValue(self):
    self.assertTrue(self._Compute(constants.ISPEC_MEM_SIZE, None) is None)

  def testAutoValue(self):
    self.assertTrue(self._Compute(constants.ISPEC_MEM_SIZE,
                                  constants.VALUE_AUTO) is None)

  def testNotDefined(self):
    self.assertTrue(self._Compute(constants.ISPEC_NIC_COUNT, 3) is None)

  def testNoMinDefined(self):
    self.assertTrue(self._Compute(constants.ISPEC_DISK_SIZE, 128) is None)

  def testNoMaxDefined(self):
    self.assertTrue(self._Compute(constants.ISPEC_DISK_COUNT, 16) is None)

  def testOutOfRange(self):
    min_specs = self.ispecs[constants.ISPECS_MIN]
    max_specs = self.ispecs[constants.ISPECS_MAX]
    for (name, val) in ((constants.ISPEC_MEM_SIZE, 64),
                        (constants.ISPEC_MEM_SIZE, 768),
                        (constants.ISPEC_DISK_SIZE, 4096),
                        (constants.ISPEC_DISK_COUNT, 0)):
      # In the expected message, a missing bound is shown as the value itself
      bounds = (min_specs.get(name, val), max_specs.get(name, val))
      self.assertEqual(self._Compute(name, val),
                       "%s value %s is not in range [%s, %s]" %
                       ((name, val) + bounds))
      self.assertEqual(self._Compute(name, val, "1"),
                       "%s/1 value %s is not in range [%s, %s]" %
                       ((name, val) + bounds))

  def test(self):
    for (name, val) in ((constants.ISPEC_MEM_SIZE, 256),
                        (constants.ISPEC_MEM_SIZE, 128),
                        (constants.ISPEC_MEM_SIZE, 512),
                        (constants.ISPEC_DISK_SIZE, 1024),
                        (constants.ISPEC_DISK_SIZE, 0),
                        (constants.ISPEC_DISK_COUNT, 1),
                        (constants.ISPEC_DISK_COUNT, 5)):
      self.assertTrue(self._Compute(name, val) is None)
650
651
def _ValidateComputeMinMaxSpec(name, *_):
  """Fake min/max check that never reports a violation.

  Asserts that C{name} is one of C{constants.ISPECS_PARAMETERS}; all other
  arguments are ignored and C{None} is returned.

  """
  assert name in constants.ISPECS_PARAMETERS
655
656
def _NoDiskComputeMinMaxSpec(name, *_):
  """Fake min/max check that only complains about the disk count.

  Returns C{name} itself (as the violation message) for
  C{constants.ISPEC_DISK_COUNT} and C{None} for any other parameter.

  """
  return name if name == constants.ISPEC_DISK_COUNT else None
662
663
664 class _SpecWrapper:
665   def __init__(self, spec):
666     self.spec = spec
667
668   def ComputeMinMaxSpec(self, *args):
669     return self.spec.pop(0)
670
671
class TestComputeIPolicySpecViolation(unittest.TestCase):
  """Tests for L{cmdlib._ComputeIPolicySpecViolation}."""

  # Minimal policy accepted by _ComputeIPolicySpecViolation()
  _MICRO_IPOL = {
    constants.IPOLICY_DTS: [constants.DT_PLAIN, constants.DT_DISKLESS],
    constants.ISPECS_MINMAX: NotImplemented,
    }

  def _Run(self, disk_template, compute_fn):
    """Checks a fixed spec: memory 1024, 1 CPU, one disk of 1024, 1 NIC,
    spindle use 1, using C{compute_fn} for the per-parameter checks."""
    return cmdlib._ComputeIPolicySpecViolation(self._MICRO_IPOL, 1024, 1, 1, 1,
                                               [1024], 1, disk_template,
                                               _compute_fn=compute_fn)

  def test(self):
    ret = self._Run(constants.DT_PLAIN, _ValidateComputeMinMaxSpec)
    self.assertEqual(ret, [])

  def testDiskFull(self):
    ret = self._Run(constants.DT_PLAIN, _NoDiskComputeMinMaxSpec)
    self.assertEqual(ret, [constants.ISPEC_DISK_COUNT])

  def testDiskLess(self):
    # The disk-count violation is expected to be ignored for diskless
    ret = self._Run(constants.DT_DISKLESS, _NoDiskComputeMinMaxSpec)
    self.assertEqual(ret, [])

  def testWrongTemplates(self):
    # DT_DRBD8 is not in the policy's list of allowed disk templates
    ret = self._Run(constants.DT_DRBD8, _ValidateComputeMinMaxSpec)
    self.assertEqual(len(ret), 1)
    self.assertTrue("Disk template" in ret[0])

  def testInvalidArguments(self):
    # Disk count 1 but no disk sizes given
    self.assertRaises(AssertionError, cmdlib._ComputeIPolicySpecViolation,
                      self._MICRO_IPOL, 1024, 1, 1, 1, [], 1,
                      constants.DT_PLAIN)

  def testInvalidSpec(self):
    spec = _SpecWrapper([None, False, "foo", None, "bar", None])
    ret = self._Run(constants.DT_PLAIN, spec.ComputeMinMaxSpec)
    # Only the true-valued results are reported, and all of them are used
    self.assertEqual(ret, ["foo", "bar"])
    self.assertFalse(spec.spec)
721
722
723 class _StubComputeIPolicySpecViolation:
724   def __init__(self, mem_size, cpu_count, disk_count, nic_count, disk_sizes,
725                spindle_use, disk_template):
726     self.mem_size = mem_size
727     self.cpu_count = cpu_count
728     self.disk_count = disk_count
729     self.nic_count = nic_count
730     self.disk_sizes = disk_sizes
731     self.spindle_use = spindle_use
732     self.disk_template = disk_template
733
734   def __call__(self, _, mem_size, cpu_count, disk_count, nic_count, disk_sizes,
735                spindle_use, disk_template):
736     assert self.mem_size == mem_size
737     assert self.cpu_count == cpu_count
738     assert self.disk_count == disk_count
739     assert self.nic_count == nic_count
740     assert self.disk_sizes == disk_sizes
741     assert self.spindle_use == spindle_use
742     assert self.disk_template == disk_template
743
744     return []
745
746
class _FakeConfigForComputeIPolicyInstanceViolation:
  """Fake configuration exposing a cluster with the given default beparams."""

  def __init__(self, be):
    # Only the "default" beparams group is set
    cluster_beparams = {"default": be}
    self.cluster = objects.Cluster(beparams=cluster_beparams)

  def GetClusterInfo(self):
    """Returns the L{objects.Cluster} built by the constructor."""
    return self.cluster
753
754
class TestComputeIPolicyInstanceViolation(unittest.TestCase):
  """Tests for L{cmdlib._ComputeIPolicyInstanceViolation}."""

  def test(self):
    beparams = {
      constants.BE_MAXMEM: 2048,
      constants.BE_VCPUS: 2,
      constants.BE_SPINDLE_USE: 4,
      }
    disks = [objects.Disk(size=512)]
    cfg = _FakeConfigForComputeIPolicyInstanceViolation(beparams)
    # The stub asserts the values derived from the instance; it has no
    # NICs, hence the NIC count of 0
    stub = _StubComputeIPolicySpecViolation(2048, 2, 1, 0, [512], 4,
                                            constants.DT_PLAIN)

    # The second instance has no beparams of its own, so the same values
    # must come from the cluster defaults in cfg
    for inst_beparams in (beparams, {}):
      instance = objects.Instance(beparams=inst_beparams, disks=disks, nics=[],
                                  disk_template=constants.DT_PLAIN)
      ret = cmdlib._ComputeIPolicyInstanceViolation(NotImplemented, instance,
                                                    cfg, _compute_fn=stub)
      self.assertEqual(ret, [])
776
777
class TestComputeIPolicyInstanceSpecViolation(unittest.TestCase):
  """Tests for L{cmdlib._ComputeIPolicyInstanceSpecViolation}."""

  def test(self):
    # The stub asserts that each value is taken from the spec dictionary
    stub = _StubComputeIPolicySpecViolation(2048, 2, 1, 0, [512], 1,
                                            constants.DT_PLAIN)
    ispec = dict([
      (constants.ISPEC_MEM_SIZE, 2048),
      (constants.ISPEC_CPU_COUNT, 2),
      (constants.ISPEC_DISK_COUNT, 1),
      (constants.ISPEC_DISK_SIZE, [512]),
      (constants.ISPEC_NIC_COUNT, 0),
      (constants.ISPEC_SPINDLE_USE, 1),
      ])
    ret = cmdlib._ComputeIPolicyInstanceSpecViolation(NotImplemented, ispec,
                                                      constants.DT_PLAIN,
                                                      _compute_fn=stub)
    self.assertEqual(ret, [])
794
795
796 class _CallRecorder:
797   def __init__(self, return_value=None):
798     self.called = False
799     self.return_value = return_value
800
801   def __call__(self, *args):
802     self.called = True
803     return self.return_value
804
805
class TestComputeIPolicyNodeViolation(unittest.TestCase):
  """Tests for L{cmdlib._ComputeIPolicyNodeViolation}."""

  def setUp(self):
    self.recorder = _CallRecorder(return_value=[])

  def _Run(self, group1, group2):
    """Calls the function under test with the two given group names.

    NOTE(review): these are presumably the current and the target node
    group; confirm against the function's signature.

    """
    return cmdlib._ComputeIPolicyNodeViolation(NotImplemented, NotImplemented,
                                               group1, group2, NotImplemented,
                                               _compute_fn=self.recorder)

  def testSameGroup(self):
    # Identical groups: the computation function must not be invoked
    ret = self._Run("foo", "foo")
    self.assertFalse(self.recorder.called)
    self.assertEqual(ret, [])

  def testDifferentGroup(self):
    ret = self._Run("foo", "bar")
    self.assertTrue(self.recorder.called)
    self.assertEqual(ret, [])
823
824
825 class _FakeConfigForTargetNodeIPolicy:
826   def __init__(self, node_info=NotImplemented):
827     self._node_info = node_info
828
829   def GetNodeInfo(self, _):
830     return self._node_info
831
832
class TestCheckTargetNodeIPolicy(unittest.TestCase):
  """Tests for L{cmdlib._CheckTargetNodeIPolicy}."""

  def setUp(self):
    self.instance = objects.Instance(primary_node="blubb")
    self.target_node = objects.Node(group="bar")
    # Node returned by the fake config; its group ("foo") differs from the
    # target node's group ("bar")
    current_node = objects.Node(group="foo")
    fake_cfg = _FakeConfigForTargetNodeIPolicy(node_info=current_node)
    self.lu = _FakeLU(cfg=fake_cfg)

  def testNoViolation(self):
    recorder = _CallRecorder(return_value=[])
    cmdlib._CheckTargetNodeIPolicy(self.lu, NotImplemented, self.instance,
                                   self.target_node, NotImplemented,
                                   _compute_fn=recorder)
    self.assertTrue(recorder.called)
    self.assertEqual(self.lu.warning_log, [])

  def testNoIgnore(self):
    # A violation without ignore=True is an error and logs no warning
    recorder = _CallRecorder(return_value=["mem_size not in range"])
    self.assertRaises(errors.OpPrereqError, cmdlib._CheckTargetNodeIPolicy,
                      self.lu, NotImplemented, self.instance, self.target_node,
                      NotImplemented, _compute_fn=recorder)
    self.assertTrue(recorder.called)
    self.assertEqual(self.lu.warning_log, [])

  def testIgnoreViolation(self):
    # With ignore=True the violation is only reported as a warning
    recorder = _CallRecorder(return_value=["mem_size not in range"])
    cmdlib._CheckTargetNodeIPolicy(self.lu, NotImplemented, self.instance,
                                   self.target_node, NotImplemented,
                                   ignore=True, _compute_fn=recorder)
    self.assertTrue(recorder.called)
    expected_msg = ("Instance does not meet target node group's (bar) instance"
                    " policy: mem_size not in range")
    self.assertEqual(self.lu.warning_log, [(expected_msg, ())])
866
867
class TestApplyContainerMods(unittest.TestCase):
  """Tests for L{cmdlib.PrepareContainerMods} and L{cmdlib.ApplyContainerMods}.

  Modifications are given as C{(operation, index, params)} tuples. As the
  tests below show, index -1 refers to the end of the container, while the
  out-of-range indices exercised here raise C{IndexError}.

  """
  def testEmptyContainer(self):
    # No modifications: neither the container nor the change list changes
    container = []
    chgdesc = []
    cmdlib.ApplyContainerMods("test", container, chgdesc, [], None, None, None)
    self.assertEqual(container, [])
    self.assertEqual(chgdesc, [])

  def testAdd(self):
    container = []
    chgdesc = []
    # -1 appends; a non-negative index inserts before that position
    mods = cmdlib.PrepareContainerMods([
      (constants.DDM_ADD, -1, "Hello"),
      (constants.DDM_ADD, -1, "World"),
      (constants.DDM_ADD, 0, "Start"),
      (constants.DDM_ADD, -1, "End"),
      ], None)
    cmdlib.ApplyContainerMods("test", container, chgdesc, mods,
                              None, None, None)
    self.assertEqual(container, ["Start", "Hello", "World", "End"])
    # Without a create callback no change descriptions are recorded
    self.assertEqual(chgdesc, [])

    # Indices refer to the container as modified so far; "xyz" at index 7
    # equals the length at that point and is therefore appended
    mods = cmdlib.PrepareContainerMods([
      (constants.DDM_ADD, 0, "zero"),
      (constants.DDM_ADD, 3, "Added"),
      (constants.DDM_ADD, 5, "four"),
      (constants.DDM_ADD, 7, "xyz"),
      ], None)
    cmdlib.ApplyContainerMods("test", container, chgdesc, mods,
                              None, None, None)
    self.assertEqual(container,
                     ["zero", "Start", "Hello", "Added", "World", "four",
                      "End", "xyz"])
    self.assertEqual(chgdesc, [])

    # Negative indices other than -1 and indices past the end are rejected
    for idx in [-2, len(container) + 1]:
      mods = cmdlib.PrepareContainerMods([
        (constants.DDM_ADD, idx, "error"),
        ], None)
      self.assertRaises(IndexError, cmdlib.ApplyContainerMods,
                        "test", container, None, mods, None, None, None)

  def testRemoveError(self):
    # Removing from an empty container fails for every tested index
    for idx in [0, 1, 2, 100, -1, -4]:
      mods = cmdlib.PrepareContainerMods([
        (constants.DDM_REMOVE, idx, None),
        ], None)
      self.assertRaises(IndexError, cmdlib.ApplyContainerMods,
                        "test", [], None, mods, None, None, None)

    # A removal with non-None params triggers an AssertionError
    mods = cmdlib.PrepareContainerMods([
      (constants.DDM_REMOVE, 0, object()),
      ], None)
    self.assertRaises(AssertionError, cmdlib.ApplyContainerMods,
                      "test", [""], None, mods, None, None, None)

  def testAddError(self):
    # range(-100, -1) covers -100 .. -2, i.e. every negative index except -1
    for idx in range(-100, -1) + [100]:
      mods = cmdlib.PrepareContainerMods([
        (constants.DDM_ADD, idx, None),
        ], None)
      self.assertRaises(IndexError, cmdlib.ApplyContainerMods,
                        "test", [], None, mods, None, None, None)

  def testRemove(self):
    container = ["item 1", "item 2"]
    mods = cmdlib.PrepareContainerMods([
      (constants.DDM_ADD, -1, "aaa"),
      (constants.DDM_REMOVE, -1, None),
      (constants.DDM_ADD, -1, "bbb"),
      ], None)
    chgdesc = []
    cmdlib.ApplyContainerMods("test", container, chgdesc, mods,
                              None, None, None)
    self.assertEqual(container, ["item 1", "item 2", "bbb"])
    # Removals are recorded even without a remove callback
    self.assertEqual(chgdesc, [
      ("test/2", "remove"),
      ])

  def testModify(self):
    container = ["item 1", "item 2"]
    mods = cmdlib.PrepareContainerMods([
      (constants.DDM_MODIFY, -1, "a"),
      (constants.DDM_MODIFY, 0, "b"),
      (constants.DDM_MODIFY, 1, "c"),
      ], None)
    chgdesc = []
    cmdlib.ApplyContainerMods("test", container, chgdesc, mods,
                              None, None, None)
    # Without a modify callback the items stay as they are and nothing is
    # recorded
    self.assertEqual(container, ["item 1", "item 2"])
    self.assertEqual(chgdesc, [])

    for idx in [-2, len(container) + 1]:
      mods = cmdlib.PrepareContainerMods([
        (constants.DDM_MODIFY, idx, "error"),
        ], None)
      self.assertRaises(IndexError, cmdlib.ApplyContainerMods,
                        "test", container, None, mods, None, None, None)

  class _PrivateData:
    """Per-modification private data, filled in by the callbacks below."""
    def __init__(self):
      self.data = None

  @staticmethod
  def _CreateTestFn(idx, params, private):
    """Create callback: returns the new item and its change descriptions."""
    private.data = ("add", idx, params)
    return ((100 * idx, params), [
      ("test/%s" % idx, hex(idx)),
      ])

  @staticmethod
  def _ModifyTestFn(idx, item, params, private):
    """Modify callback: returns the change descriptions for the update."""
    private.data = ("modify", idx, params)
    return [
      ("test/%s" % idx, "modify %s" % params),
      ]

  @staticmethod
  def _RemoveTestFn(idx, item, private):
    """Remove callback: only records the call in the private data."""
    private.data = ("remove", idx, item)

  def testAddWithCreateFunction(self):
    container = []
    chgdesc = []
    mods = cmdlib.PrepareContainerMods([
      (constants.DDM_ADD, -1, "Hello"),
      (constants.DDM_ADD, -1, "World"),
      (constants.DDM_ADD, 0, "Start"),
      (constants.DDM_ADD, -1, "End"),
      (constants.DDM_REMOVE, 2, None),
      (constants.DDM_MODIFY, -1, "foobar"),
      (constants.DDM_REMOVE, 2, None),
      (constants.DDM_ADD, 1, "More"),
      ], self._PrivateData)
    cmdlib.ApplyContainerMods("test", container, chgdesc, mods,
      self._CreateTestFn, self._ModifyTestFn, self._RemoveTestFn)
    # Items are (100 * index at creation time, params); 000 is simply 0
    self.assertEqual(container, [
      (000, "Start"),
      (100, "More"),
      (000, "Hello"),
      ])
    self.assertEqual(chgdesc, [
      ("test/0", "0x0"),
      ("test/1", "0x1"),
      ("test/0", "0x0"),
      ("test/3", "0x3"),
      ("test/2", "remove"),
      ("test/2", "modify foobar"),
      ("test/2", "remove"),
      ("test/1", "0x1")
      ])
    # Every modification got its own private object, and the callback that
    # filled it matches the modification's operation
    self.assertTrue(compat.all(op == private.data[0]
                               for (op, _, _, private) in mods))
    # Indices passed to callbacks are absolute (-1 already resolved)
    self.assertEqual([private.data for (op, _, _, private) in mods], [
      ("add", 0, "Hello"),
      ("add", 1, "World"),
      ("add", 0, "Start"),
      ("add", 3, "End"),
      ("remove", 2, (100, "World")),
      ("modify", 2, "foobar"),
      ("remove", 2, (300, "End")),
      ("add", 1, "More"),
      ])
1031
1032
class _FakeConfigForGenDiskTemplate:
  """Configuration stub producing predictable values for disk generation.

  Unique IDs, DRBD minors, ports and secrets each come from their own
  counter, so successive calls yield increasing, reproducible values.

  """
  def __init__(self):
    self._unique_id = itertools.count()
    self._drbd_minor = itertools.count(20)
    self._port = itertools.count(constants.FIRST_DRBD_PORT)
    self._secret = itertools.count()

  def GetVGName(self):
    return "testvg"

  def GenerateUniqueID(self, ec_id):
    return "ec%s-uq%s" % (ec_id, next(self._unique_id))

  def AllocateDRBDMinor(self, nodes, instance):
    # One fresh minor per requested node; the instance is not consulted
    minors = []
    for _ in nodes:
      minors.append(next(self._drbd_minor))
    return minors

  def AllocatePort(self):
    return next(self._port)

  def GenerateDRBDSecret(self, ec_id):
    return "ec%s-secret%s" % (ec_id, next(self._secret))

  def GetInstanceInfo(self, _):
    return "foobar"
1058
1059
1060 class _FakeProcForGenDiskTemplate:
1061   def GetECId(self):
1062     return 0
1063
1064
class TestGenerateDiskTemplate(unittest.TestCase):
  """Tests for L{cmdlib._GenerateDiskTemplate}."""

  def setUp(self):
    nodegroup = objects.NodeGroup(name="ng")
    nodegroup.UpgradeConfig()

    cfg = _FakeConfigForGenDiskTemplate()
    proc = _FakeProcForGenDiskTemplate()

    self.lu = _FakeLU(cfg=cfg, proc=proc)
    self.nodegroup = nodegroup

  @staticmethod
  def GetDiskParams():
    """Returns a private (deep) copy of the default disk parameters."""
    return copy.deepcopy(constants.DISK_DT_DEFAULTS)

  def testWrongDiskTemplate(self):
    gdt = cmdlib._GenerateDiskTemplate
    disk_template = "##unknown##"

    assert disk_template not in constants.DISK_TEMPLATES

    self.assertRaises(errors.ProgrammerError, gdt, self.lu, disk_template,
                      "inst26831.example.com", "node30113.example.com", [], [],
                      NotImplemented, NotImplemented, 0, self.lu.LogInfo,
                      self.GetDiskParams())

  def testDiskless(self):
    gdt = cmdlib._GenerateDiskTemplate

    result = gdt(self.lu, constants.DT_DISKLESS, "inst27734.example.com",
                 "node30113.example.com", [], [],
                 NotImplemented, NotImplemented, 0, self.lu.LogInfo,
                 self.GetDiskParams())
    self.assertEqual(result, [])

  def _TestTrivialDisk(self, template, disk_info, base_index, exp_dev_type,
                       file_storage_dir=NotImplemented,
                       file_driver=NotImplemented,
                       req_file_storage=NotImplemented,
                       req_shr_file_storage=NotImplemented):
    """Generates disks for a template taking no secondaries and checks them.

    @return: the list of generated disks

    """
    gdt = cmdlib._GenerateDiskTemplate

    # Normalize the parameter types in place; a plain loop is used because
    # map() would only be called for its side effects
    for params in disk_info:
      utils.ForceDictType(params, constants.IDISK_PARAMS_TYPES)

    # Check if non-empty list of secondaries is rejected
    self.assertRaises(errors.ProgrammerError, gdt, self.lu,
                      template, "inst25088.example.com",
                      "node185.example.com", ["node323.example.com"], [],
                      NotImplemented, NotImplemented, base_index,
                      self.lu.LogInfo, self.GetDiskParams(),
                      _req_file_storage=req_file_storage,
                      _req_shr_file_storage=req_shr_file_storage)

    result = gdt(self.lu, template, "inst21662.example.com",
                 "node21741.example.com", [],
                 disk_info, file_storage_dir, file_driver, base_index,
                 self.lu.LogInfo, self.GetDiskParams(),
                 _req_file_storage=req_file_storage,
                 _req_shr_file_storage=req_shr_file_storage)

    for (idx, disk) in enumerate(result):
      self.assertTrue(isinstance(disk, objects.Disk))
      self.assertEqual(disk.dev_type, exp_dev_type)
      self.assertEqual(disk.size, disk_info[idx][constants.IDISK_SIZE])
      self.assertEqual(disk.mode, disk_info[idx][constants.IDISK_MODE])
      self.assertTrue(disk.children is None)

    # IV names must be correct both as generated and after an update
    self._CheckIvNames(result, base_index, base_index + len(disk_info))
    cmdlib._UpdateIvNames(base_index, result)
    self._CheckIvNames(result, base_index, base_index + len(disk_info))

    return result

  def _CheckIvNames(self, disks, base_index, end_index):
    """Checks that disks are named C{disk/N} for N in [base, end)."""
    self.assertEqual(map(operator.attrgetter("iv_name"), disks),
                     ["disk/%s" % i for i in range(base_index, end_index)])

  def testPlain(self):
    disk_info = [{
      constants.IDISK_SIZE: 1024,
      constants.IDISK_MODE: constants.DISK_RDWR,
      }, {
      constants.IDISK_SIZE: 4096,
      constants.IDISK_VG: "othervg",
      constants.IDISK_MODE: constants.DISK_RDWR,
      }]

    result = self._TestTrivialDisk(constants.DT_PLAIN, disk_info, 3,
                                   constants.LD_LV)

    self.assertEqual(map(operator.attrgetter("logical_id"), result), [
      ("testvg", "ec0-uq0.disk3"),
      ("othervg", "ec0-uq1.disk4"),
      ])

  @staticmethod
  def _AllowFileStorage():
    pass

  @staticmethod
  def _ForbidFileStorage():
    raise errors.OpPrereqError("Disallowed in test")

  def testFile(self):
    self.assertRaises(errors.OpPrereqError, self._TestTrivialDisk,
                      constants.DT_FILE, [], 0, NotImplemented,
                      req_file_storage=self._ForbidFileStorage)
    self.assertRaises(errors.OpPrereqError, self._TestTrivialDisk,
                      constants.DT_SHARED_FILE, [], 0, NotImplemented,
                      req_shr_file_storage=self._ForbidFileStorage)

    for disk_template in [constants.DT_FILE, constants.DT_SHARED_FILE]:
      disk_info = [{
        constants.IDISK_SIZE: 80 * 1024,
        constants.IDISK_MODE: constants.DISK_RDONLY,
        }, {
        constants.IDISK_SIZE: 4096,
        constants.IDISK_MODE: constants.DISK_RDWR,
        }, {
        constants.IDISK_SIZE: 6 * 1024,
        constants.IDISK_MODE: constants.DISK_RDWR,
        }]

      result = self._TestTrivialDisk(disk_template, disk_info, 2,
        constants.LD_FILE, file_storage_dir="/tmp",
        file_driver=constants.FD_BLKTAP,
        req_file_storage=self._AllowFileStorage,
        req_shr_file_storage=self._AllowFileStorage)

      self.assertEqual(map(operator.attrgetter("logical_id"), result), [
        (constants.FD_BLKTAP, "/tmp/disk2"),
        (constants.FD_BLKTAP, "/tmp/disk3"),
        (constants.FD_BLKTAP, "/tmp/disk4"),
        ])

  def testBlock(self):
    disk_info = [{
      constants.IDISK_SIZE: 8 * 1024,
      constants.IDISK_MODE: constants.DISK_RDWR,
      constants.IDISK_ADOPT: "/tmp/some/block/dev",
      }]

    result = self._TestTrivialDisk(constants.DT_BLOCK, disk_info, 10,
                                   constants.LD_BLOCKDEV)

    self.assertEqual(map(operator.attrgetter("logical_id"), result), [
      (constants.BLOCKDEV_DRIVER_MANUAL, "/tmp/some/block/dev"),
      ])

  def testRbd(self):
    disk_info = [{
      constants.IDISK_SIZE: 8 * 1024,
      constants.IDISK_MODE: constants.DISK_RDONLY,
      }, {
      constants.IDISK_SIZE: 100 * 1024,
      constants.IDISK_MODE: constants.DISK_RDWR,
      }]

    result = self._TestTrivialDisk(constants.DT_RBD, disk_info, 0,
                                   constants.LD_RBD)

    self.assertEqual(map(operator.attrgetter("logical_id"), result), [
      ("rbd", "ec0-uq0.rbd.disk0"),
      ("rbd", "ec0-uq1.rbd.disk1"),
      ])

  def testDrbd8(self):
    gdt = cmdlib._GenerateDiskTemplate
    drbd8_defaults = constants.DISK_LD_DEFAULTS[constants.LD_DRBD8]
    drbd8_default_metavg = drbd8_defaults[constants.LDP_DEFAULT_METAVG]

    disk_info = [{
      constants.IDISK_SIZE: 1024,
      constants.IDISK_MODE: constants.DISK_RDWR,
      }, {
      constants.IDISK_SIZE: 100 * 1024,
      constants.IDISK_MODE: constants.DISK_RDONLY,
      constants.IDISK_METAVG: "metavg",
      }, {
      constants.IDISK_SIZE: 4096,
      constants.IDISK_MODE: constants.DISK_RDWR,
      constants.IDISK_VG: "vgxyz",
      },
      ]

    # Expected (VG, LV name) pairs of the data and metadata children
    exp_logical_ids = [[
      (self.lu.cfg.GetVGName(), "ec0-uq0.disk0_data"),
      (drbd8_default_metavg, "ec0-uq0.disk0_meta"),
      ], [
      (self.lu.cfg.GetVGName(), "ec0-uq1.disk1_data"),
      ("metavg", "ec0-uq1.disk1_meta"),
      ], [
      ("vgxyz", "ec0-uq2.disk2_data"),
      (drbd8_default_metavg, "ec0-uq2.disk2_meta"),
      ]]

    assert len(exp_logical_ids) == len(disk_info)

    # Normalize the parameter types in place (plain loop, see
    # _TestTrivialDisk)
    for params in disk_info:
      utils.ForceDictType(params, constants.IDISK_PARAMS_TYPES)

    # Check if empty list of secondaries is rejected
    self.assertRaises(errors.ProgrammerError, gdt, self.lu, constants.DT_DRBD8,
                      "inst827.example.com", "node1334.example.com", [],
                      disk_info, NotImplemented, NotImplemented, 0,
                      self.lu.LogInfo, self.GetDiskParams())

    result = gdt(self.lu, constants.DT_DRBD8, "inst827.example.com",
                 "node1334.example.com", ["node12272.example.com"],
                 disk_info, NotImplemented, NotImplemented, 0, self.lu.LogInfo,
                 self.GetDiskParams())

    for (idx, disk) in enumerate(result):
      self.assertTrue(isinstance(disk, objects.Disk))
      self.assertEqual(disk.dev_type, constants.LD_DRBD8)
      self.assertEqual(disk.size, disk_info[idx][constants.IDISK_SIZE])
      self.assertEqual(disk.mode, disk_info[idx][constants.IDISK_MODE])

      for child in disk.children:
        # Check the child itself, not its parent DRBD disk
        self.assertTrue(isinstance(child, objects.Disk))
        self.assertEqual(child.dev_type, constants.LD_LV)
        self.assertTrue(child.children is None)

      self.assertEqual(map(operator.attrgetter("logical_id"), disk.children),
                       exp_logical_ids[idx])

      self.assertEqual(len(disk.children), 2)
      self.assertEqual(disk.children[0].size, disk.size)
      self.assertEqual(disk.children[1].size, constants.DRBD_META_SIZE)

    self._CheckIvNames(result, 0, len(disk_info))
    cmdlib._UpdateIvNames(0, result)
    self._CheckIvNames(result, 0, len(disk_info))

    self.assertEqual(map(operator.attrgetter("logical_id"), result), [
      ("node1334.example.com", "node12272.example.com",
       constants.FIRST_DRBD_PORT, 20, 21, "ec0-secret0"),
      ("node1334.example.com", "node12272.example.com",
       constants.FIRST_DRBD_PORT + 1, 22, 23, "ec0-secret1"),
      ("node1334.example.com", "node12272.example.com",
       constants.FIRST_DRBD_PORT + 2, 24, 25, "ec0-secret2"),
      ])
1310
1311
class _ConfigForDiskWipe:
  """Config stub for disk wiping tests.

  C{SetDiskID} does not record anything. It only checks that the device is a
  disk object and that the node is the one given to the constructor.

  """
  def __init__(self, exp_node):
    self._exp_node = exp_node

  def SetDiskID(self, device, node):
    assert isinstance(device, objects.Disk)
    assert node == self._exp_node
1319
1320
class _RpcForDiskWipe:
  """RPC stub for disk wiping tests.

  Both supported calls check the target node, then wrap the return value of
  the corresponding test callback in an L{rpc.RpcResult}.

  """
  def __init__(self, exp_node, pause_cb, wipe_cb):
    self._exp_node = exp_node
    self._pause_cb = pause_cb
    self._wipe_cb = wipe_cb

  def call_blockdev_pause_resume_sync(self, node, disks, pause):
    assert node == self._exp_node
    cb_result = self._pause_cb(disks, pause)
    return rpc.RpcResult(data=cb_result)

  def call_blockdev_wipe(self, node, bdev, offset, size):
    assert node == self._exp_node
    cb_result = self._wipe_cb(bdev, offset, size)
    return rpc.RpcResult(data=cb_result)
1334
1335
1336 class _DiskPauseTracker:
1337   def __init__(self):
1338     self.history = []
1339
1340   def __call__(self, (disks, instance), pause):
1341     assert not (set(disks) - set(instance.disks))
1342
1343     self.history.extend((i.logical_id, i.size, pause)
1344                         for i in disks)
1345
1346     return (True, [True] * len(disks))
1347
1348
class _DiskWipeProgressTracker:
  """Wipe callback validating chunks and tracking progress per disk.

  C{progress} maps each disk's logical ID to the offset up to which it has
  been wiped; the first chunk of every disk must start at the configured
  start offset and subsequent chunks must be contiguous.

  """
  def __init__(self, start_offset):
    self._start_offset = start_offset
    self.progress = {}

  def __call__(self, disk_and_instance, offset, size):
    (disk, _) = disk_and_instance

    assert isinstance(offset, (long, int))
    assert isinstance(size, (long, int))

    # Largest chunk allowed relative to the disk size
    max_chunk_size = (disk.size / 100.0 * constants.MIN_WIPE_CHUNK_PERCENT)

    # The chunk must lie between the start offset and the end of the disk
    assert offset >= self._start_offset
    assert (offset + size) <= disk.size

    assert 0 < size <= constants.MAX_WIPE_CHUNK
    assert size <= max_chunk_size

    # A disk without recorded progress must start at the start offset
    assert offset == self._start_offset or disk.logical_id in self.progress

    done = self.progress.setdefault(disk.logical_id, self._start_offset)

    # Chunks must follow each other without gaps or overlaps
    assert done == offset

    self.progress[disk.logical_id] = done + size

    return (True, None)
1378
1379
1380 class TestWipeDisks(unittest.TestCase):
1381   def _FailingPauseCb(self, (disks, _), pause):
1382     self.assertEqual(len(disks), 3)
1383     self.assertTrue(pause)
1384     # Simulate an RPC error
1385     return (False, "error")
1386
1387   def testPauseFailure(self):
1388     node_name = "node1372.example.com"
1389
1390     lu = _FakeLU(rpc=_RpcForDiskWipe(node_name, self._FailingPauseCb,
1391                                      NotImplemented),
1392                  cfg=_ConfigForDiskWipe(node_name))
1393
1394     disks = [
1395       objects.Disk(dev_type=constants.LD_LV),
1396       objects.Disk(dev_type=constants.LD_LV),
1397       objects.Disk(dev_type=constants.LD_LV),
1398       ]
1399
1400     instance = objects.Instance(name="inst21201",
1401                                 primary_node=node_name,
1402                                 disk_template=constants.DT_PLAIN,
1403                                 disks=disks)
1404
1405     self.assertRaises(errors.OpExecError, cmdlib._WipeDisks, lu, instance)
1406
1407   def _FailingWipeCb(self, (disk, _), offset, size):
1408     # This should only ever be called for the first disk
1409     self.assertEqual(disk.logical_id, "disk0")
1410     return (False, None)
1411
1412   def testFailingWipe(self):
1413     node_name = "node13445.example.com"
1414     pt = _DiskPauseTracker()
1415
1416     lu = _FakeLU(rpc=_RpcForDiskWipe(node_name, pt, self._FailingWipeCb),
1417                  cfg=_ConfigForDiskWipe(node_name))
1418
1419     disks = [
1420       objects.Disk(dev_type=constants.LD_LV, logical_id="disk0",
1421                    size=100 * 1024),
1422       objects.Disk(dev_type=constants.LD_LV, logical_id="disk1",
1423                    size=500 * 1024),
1424       objects.Disk(dev_type=constants.LD_LV, logical_id="disk2", size=256),
1425       ]
1426
1427     instance = objects.Instance(name="inst562",
1428                                 primary_node=node_name,
1429                                 disk_template=constants.DT_PLAIN,
1430                                 disks=disks)
1431
1432     try:
1433       cmdlib._WipeDisks(lu, instance)
1434     except errors.OpExecError, err:
1435       self.assertTrue(str(err), "Could not wipe disk 0 at offset 0 ")
1436     else:
1437       self.fail("Did not raise exception")
1438
1439     # Check if all disks were paused and resumed
1440     self.assertEqual(pt.history, [
1441       ("disk0", 100 * 1024, True),
1442       ("disk1", 500 * 1024, True),
1443       ("disk2", 256, True),
1444       ("disk0", 100 * 1024, False),
1445       ("disk1", 500 * 1024, False),
1446       ("disk2", 256, False),
1447       ])
1448
1449   def _PrepareWipeTest(self, start_offset, disks):
1450     node_name = "node-with-offset%s.example.com" % start_offset
1451     pauset = _DiskPauseTracker()
1452     progresst = _DiskWipeProgressTracker(start_offset)
1453
1454     lu = _FakeLU(rpc=_RpcForDiskWipe(node_name, pauset, progresst),
1455                  cfg=_ConfigForDiskWipe(node_name))
1456
1457     instance = objects.Instance(name="inst3560",
1458                                 primary_node=node_name,
1459                                 disk_template=constants.DT_PLAIN,
1460                                 disks=disks)
1461
1462     return (lu, instance, pauset, progresst)
1463
1464   def testNormalWipe(self):
1465     disks = [
1466       objects.Disk(dev_type=constants.LD_LV, logical_id="disk0", size=1024),
1467       objects.Disk(dev_type=constants.LD_LV, logical_id="disk1",
1468                    size=500 * 1024),
1469       objects.Disk(dev_type=constants.LD_LV, logical_id="disk2", size=128),
1470       objects.Disk(dev_type=constants.LD_LV, logical_id="disk3",
1471                    size=constants.MAX_WIPE_CHUNK),
1472       ]
1473
1474     (lu, instance, pauset, progresst) = self._PrepareWipeTest(0, disks)
1475
1476     cmdlib._WipeDisks(lu, instance)
1477
1478     self.assertEqual(pauset.history, [
1479       ("disk0", 1024, True),
1480       ("disk1", 500 * 1024, True),
1481       ("disk2", 128, True),
1482       ("disk3", constants.MAX_WIPE_CHUNK, True),
1483       ("disk0", 1024, False),
1484       ("disk1", 500 * 1024, False),
1485       ("disk2", 128, False),
1486       ("disk3", constants.MAX_WIPE_CHUNK, False),
1487       ])
1488
1489     # Ensure the complete disk has been wiped
1490     self.assertEqual(progresst.progress,
1491                      dict((i.logical_id, i.size) for i in disks))
1492
1493   def testWipeWithStartOffset(self):
1494     for start_offset in [0, 280, 8895, 1563204]:
1495       disks = [
1496         objects.Disk(dev_type=constants.LD_LV, logical_id="disk0",
1497                      size=128),
1498         objects.Disk(dev_type=constants.LD_LV, logical_id="disk1",
1499                      size=start_offset + (100 * 1024)),
1500         ]
1501
1502       (lu, instance, pauset, progresst) = \
1503         self._PrepareWipeTest(start_offset, disks)
1504
1505       # Test start offset with only one disk
1506       cmdlib._WipeDisks(lu, instance,
1507                         disks=[(1, disks[1], start_offset)])
1508
1509       # Only the second disk may have been paused and wiped
1510       self.assertEqual(pauset.history, [
1511         ("disk1", start_offset + (100 * 1024), True),
1512         ("disk1", start_offset + (100 * 1024), False),
1513         ])
1514       self.assertEqual(progresst.progress, {
1515         "disk1": disks[1].size,
1516         })
1517
1518
class TestDiskSizeInBytesToMebibytes(unittest.TestCase):
  """Tests for L{cmdlib._DiskSizeInBytesToMebibytes}."""

  _MEBIBYTE = 1024 * 1024

  def _CheckRoundingWarning(self, lu, expected_size):
    """Checks that exactly one warning with one size argument was logged."""
    self.assertEqual(len(lu.warning_log), 1)
    self.assertEqual(len(lu.warning_log[0]), 2)
    (_, (warnsize, )) = lu.warning_log[0]
    self.assertEqual(warnsize, expected_size)

  def testLessThanOneMebibyte(self):
    for nbytes in [1, 2, 7, 512, 1000, 1023]:
      lu = _FakeLU()
      result = cmdlib._DiskSizeInBytesToMebibytes(lu, nbytes)
      # Amounts below one MiB are rounded up to one MiB
      self.assertEqual(result, 1)
      self._CheckRoundingWarning(lu, self._MEBIBYTE - nbytes)

  def testEven(self):
    for mib in [1, 2, 7, 512, 1000, 1023]:
      lu = _FakeLU()
      result = cmdlib._DiskSizeInBytesToMebibytes(lu, mib * self._MEBIBYTE)
      # Exact multiples convert without any warning
      self.assertEqual(result, mib)
      self.assertFalse(lu.warning_log)

  def testLargeNumber(self):
    for mib in [1, 2, 7, 512, 1000, 1023, 2724, 12420]:
      for extra in [1, 2, 486, 326, 986, 1023]:
        lu = _FakeLU()
        size = (mib * self._MEBIBYTE) + extra
        result = cmdlib._DiskSizeInBytesToMebibytes(lu, size)
        self.assertEqual(result, mib + 1, msg="Amount was not rounded up")
        self._CheckRoundingWarning(lu, self._MEBIBYTE - extra)
1548
1549
class TestCopyLockList(unittest.TestCase):
  """Tests for L{cmdlib._CopyLockList}."""

  def test(self):
    # The empty list, None and ALL_SET must come back equal to the input
    for value in [[], None, locking.ALL_SET]:
      self.assertEqual(cmdlib._CopyLockList(value), value)

    names = ["foo", "bar"]
    copied = cmdlib._CopyLockList(names)
    self.assertEqual(names, copied)
    # A list of names must be returned as a new object
    self.assertNotEqual(id(names), id(copied), msg="List was not copied")
1560
1561
class TestCheckOpportunisticLocking(unittest.TestCase):
  """Tests for L{cmdlib._CheckOpportunisticLocking}."""

  class OpTest(opcodes.OpCode):
    OP_PARAMS = [
      opcodes._POpportunisticLocking,
      opcodes._PIAllocFromDesc(""),
      ]

  @classmethod
  def _MakeOp(cls, **kwargs):
    """Builds an C{OpTest} with the given parameters and validates it."""
    op = cls.OpTest(**kwargs)
    op.Validate(True)
    return op

  def testMissingAttributes(self):
    self.assertRaises(AttributeError, cmdlib._CheckOpportunisticLocking,
                      object())

  def testDefaults(self):
    op = self._MakeOp()
    cmdlib._CheckOpportunisticLocking(op)

  def test(self):
    # The loop variable is deliberately not named "iallocator": that would
    # shadow the module imported under that name at the top of this file
    for ialloc_name in [None, "something", "other"]:
      for opplock in [False, True]:
        op = self._MakeOp(iallocator=ialloc_name,
                          opportunistic_locking=opplock)
        # Opportunistic locking is rejected when no iallocator is given
        if opplock and not ialloc_name:
          self.assertRaises(errors.OpPrereqError,
                            cmdlib._CheckOpportunisticLocking, op)
        else:
          cmdlib._CheckOpportunisticLocking(op)
1592
1593
class _OpTestVerifyErrors(opcodes.OpCode):
  """Test opcode used as C{op} of L{_LuTestVerifyErrors}.

  It only carries the debug-simulate-errors, error-codes and ignore-errors
  parameters.

  """
  OP_PARAMS = [
    opcodes._PDebugSimulateErrors,
    opcodes._PErrorCodes,
    opcodes._PIgnoreErrors,
    ]
1600
1601
class _LuTestVerifyErrors(cmdlib._VerifyErrors):
  """Minimal LU-like object for exercising L{cmdlib._VerifyErrors}.

  Its feedback function appends every message to C{msglist}.

  """
  def __init__(self, **kwargs):
    cmdlib._VerifyErrors.__init__(self)
    self.op = _OpTestVerifyErrors(**kwargs)
    self.op.Validate(True)
    self.msglist = []
    self._feedback_fn = self.msglist.append
    self.bad = False

  def DispatchCallError(self, which, *args, **kwargs):
    """Reports an error via C{_Error} if C{which} is true, else C{_ErrorIf}."""
    if not which:
      self._ErrorIf(True, *args, **kwargs)
      return
    self._Error(*args, **kwargs)

  def CallErrorIf(self, c, *args, **kwargs):
    """Forwards all arguments unchanged to C{_ErrorIf}."""
    self._ErrorIf(c, *args, **kwargs)
1619
1620
1621 class TestVerifyErrors(unittest.TestCase):
  # Fake cluster-verify error code structures; we use two arbitrary real error
  # codes to pass validation of ignore_errors. Only the ID part of each real
  # code is reused; the fake codes are (item type, ID, text) tuples.
  (_, _ERR1ID, _) = constants.CV_ECLUSTERCFG
  _NODESTR = "node"
  _NODENAME = "mynode"
  _ERR1CODE = (_NODESTR, _ERR1ID, "Error one")
  (_, _ERR2ID, _) = constants.CV_ECLUSTERCERT
  _INSTSTR = "instance"
  _INSTNAME = "myinstance"
  _ERR2CODE = (_INSTSTR, _ERR2ID, "Error two")
  # Arguments used to call _Error() or _ErrorIf(); note that each tuple
  # already starts with the error code
  _ERR1ARGS = (_ERR1CODE, _NODENAME, "Error1 is %s", "an error")
  _ERR2ARGS = (_ERR2CODE, _INSTNAME, "Error2 has no argument")
  # Expected error messages (format string applied to its arguments)
  _ERR1MSG = _ERR1ARGS[2] % _ERR1ARGS[3]
  _ERR2MSG = _ERR2ARGS[2]
1638
1639   def testNoError(self):
1640     lu = _LuTestVerifyErrors()
1641     lu.CallErrorIf(False, self._ERR1CODE, *self._ERR1ARGS)
1642     self.assertFalse(lu.bad)
1643     self.assertFalse(lu.msglist)
1644
1645   def _InitTest(self, **kwargs):
1646     self.lu1 = _LuTestVerifyErrors(**kwargs)
1647     self.lu2 = _LuTestVerifyErrors(**kwargs)
1648
1649   def _CallError(self, *args, **kwargs):
1650     # Check that _Error() and _ErrorIf() produce the same results
1651     self.lu1.DispatchCallError(True, *args, **kwargs)
1652     self.lu2.DispatchCallError(False, *args, **kwargs)
1653     self.assertEqual(self.lu1.bad, self.lu2.bad)
1654     self.assertEqual(self.lu1.msglist, self.lu2.msglist)
1655     # Test-specific checks are made on one LU
1656     return self.lu1
1657
1658   def _checkMsgCommon(self, logstr, errmsg, itype, item, warning):
1659     self.assertTrue(errmsg in logstr)
1660     if warning:
1661       self.assertTrue("WARNING" in logstr)
1662     else:
1663       self.assertTrue("ERROR" in logstr)
1664     self.assertTrue(itype in logstr)
1665     self.assertTrue(item in logstr)
1666
1667   def _checkMsg1(self, logstr, warning=False):
1668     self._checkMsgCommon(logstr, self._ERR1MSG, self._NODESTR,
1669                          self._NODENAME, warning)
1670
1671   def _checkMsg2(self, logstr, warning=False):
1672     self._checkMsgCommon(logstr, self._ERR2MSG, self._INSTSTR,
1673                          self._INSTNAME, warning)
1674
1675   def testPlain(self):
1676     self._InitTest()
1677     lu = self._CallError(*self._ERR1ARGS)
1678     self.assertTrue(lu.bad)
1679     self.assertEqual(len(lu.msglist), 1)
1680     self._checkMsg1(lu.msglist[0])
1681
1682   def testMultiple(self):
1683     self._InitTest()
1684     self._CallError(*self._ERR1ARGS)
1685     lu = self._CallError(*self._ERR2ARGS)
1686     self.assertTrue(lu.bad)
1687     self.assertEqual(len(lu.msglist), 2)
1688     self._checkMsg1(lu.msglist[0])
1689     self._checkMsg2(lu.msglist[1])
1690
1691   def testIgnore(self):
1692     self._InitTest(ignore_errors=[self._ERR1ID])
1693     lu = self._CallError(*self._ERR1ARGS)
1694     self.assertFalse(lu.bad)
1695     self.assertEqual(len(lu.msglist), 1)
1696     self._checkMsg1(lu.msglist[0], warning=True)
1697
1698   def testWarning(self):
1699     self._InitTest()
1700     lu = self._CallError(*self._ERR1ARGS,
1701                          code=_LuTestVerifyErrors.ETYPE_WARNING)
1702     self.assertFalse(lu.bad)
1703     self.assertEqual(len(lu.msglist), 1)
1704     self._checkMsg1(lu.msglist[0], warning=True)
1705
1706   def testWarning2(self):
1707     self._InitTest()
1708     self._CallError(*self._ERR1ARGS)
1709     lu = self._CallError(*self._ERR2ARGS,
1710                          code=_LuTestVerifyErrors.ETYPE_WARNING)
1711     self.assertTrue(lu.bad)
1712     self.assertEqual(len(lu.msglist), 2)
1713     self._checkMsg1(lu.msglist[0])
1714     self._checkMsg2(lu.msglist[1], warning=True)
1715
1716   def testDebugSimulate(self):
1717     lu = _LuTestVerifyErrors(debug_simulate_errors=True)
1718     lu.CallErrorIf(False, *self._ERR1ARGS)
1719     self.assertTrue(lu.bad)
1720     self.assertEqual(len(lu.msglist), 1)
1721     self._checkMsg1(lu.msglist[0])
1722
1723   def testErrCodes(self):
1724     self._InitTest(error_codes=True)
1725     lu = self._CallError(*self._ERR1ARGS)
1726     self.assertTrue(lu.bad)
1727     self.assertEqual(len(lu.msglist), 1)
1728     self._checkMsg1(lu.msglist[0])
1729     self.assertTrue(self._ERR1ID in lu.msglist[0])
1730
1731
class TestGetUpdatedIPolicy(unittest.TestCase):
  """Tests for cmdlib._GetUpdatedIPolicy()"""
  _OLD_CLUSTER_POLICY = {
    constants.IPOLICY_VCPU_RATIO: 1.5,
    constants.ISPECS_MINMAX: constants.ISPECS_MINMAX_DEFAULTS,
    constants.ISPECS_STD: constants.IPOLICY_DEFAULTS[constants.ISPECS_STD],
    }
  _OLD_GROUP_POLICY = {
    constants.IPOLICY_SPINDLE_RATIO: 2.5,
    constants.ISPECS_MINMAX: {
      constants.ISPECS_MIN: {
        constants.ISPEC_MEM_SIZE: 128,
        constants.ISPEC_CPU_COUNT: 1,
        constants.ISPEC_DISK_COUNT: 1,
        constants.ISPEC_DISK_SIZE: 1024,
        constants.ISPEC_NIC_COUNT: 1,
        constants.ISPEC_SPINDLE_USE: 1,
        },
      constants.ISPECS_MAX: {
        constants.ISPEC_MEM_SIZE: 32768,
        constants.ISPEC_CPU_COUNT: 8,
        constants.ISPEC_DISK_COUNT: 5,
        constants.ISPEC_DISK_SIZE: 1024 * 1024,
        constants.ISPEC_NIC_COUNT: 3,
        constants.ISPEC_SPINDLE_USE: 12,
        },
      },
    }

  def _CheckUnchangedKeys(self, old_policy, diff_policy, new_policy):
    """Check that keys not mentioned in the diff were carried over as-is.

    @param old_policy: the policy passed to C{_GetUpdatedIPolicy()}
    @param diff_policy: the changes passed to C{_GetUpdatedIPolicy()}
    @param new_policy: the policy returned by C{_GetUpdatedIPolicy()}

    """
    for key in old_policy:
      if key not in diff_policy:
        self.assertTrue(key in new_policy)
        self.assertEqual(new_policy[key], old_policy[key])

  def _TestSetSpecs(self, old_policy, isgroup):
    """Replace the min/max specs and, for clusters only, some std specs.

    The min/max specs must be replaced as a whole. Std specs must be
    merged with the old ones, so untouched std keys keep their values.

    """
    diff_minmax = {
      constants.ISPECS_MIN: {
        constants.ISPEC_MEM_SIZE: 64,
        constants.ISPEC_CPU_COUNT: 1,
        constants.ISPEC_DISK_COUNT: 2,
        constants.ISPEC_DISK_SIZE: 64,
        constants.ISPEC_NIC_COUNT: 1,
        constants.ISPEC_SPINDLE_USE: 1,
        },
      constants.ISPECS_MAX: {
        constants.ISPEC_MEM_SIZE: 16384,
        constants.ISPEC_CPU_COUNT: 10,
        constants.ISPEC_DISK_COUNT: 12,
        constants.ISPEC_DISK_SIZE: 1024,
        constants.ISPEC_NIC_COUNT: 9,
        constants.ISPEC_SPINDLE_USE: 18,
        },
      }
    diff_std = {
      constants.ISPEC_DISK_COUNT: 10,
      constants.ISPEC_DISK_SIZE: 512,
      }
    diff_policy = {
      constants.ISPECS_MINMAX: diff_minmax
      }
    # Group policies must not carry std specs, so only clusters get them
    if not isgroup:
      diff_policy[constants.ISPECS_STD] = diff_std
    new_policy = cmdlib._GetUpdatedIPolicy(old_policy, diff_policy,
                                           group_policy=isgroup)

    self.assertTrue(constants.ISPECS_MINMAX in new_policy)
    self.assertEqual(new_policy[constants.ISPECS_MINMAX], diff_minmax)
    self._CheckUnchangedKeys(old_policy, diff_policy, new_policy)

    if not isgroup:
      new_std = new_policy[constants.ISPECS_STD]
      for key in diff_std:
        self.assertTrue(key in new_std)
        self.assertEqual(new_std[key], diff_std[key])
      old_std = old_policy.get(constants.ISPECS_STD, {})
      for key in old_std:
        self.assertTrue(key in new_std)
        if key not in diff_std:
          self.assertEqual(new_std[key], old_std[key])

  def _TestSet(self, old_policy, diff_policy, isgroup):
    """Check that changed keys are set and all other keys are kept."""
    new_policy = cmdlib._GetUpdatedIPolicy(old_policy, diff_policy,
                                           group_policy=isgroup)
    for key in diff_policy:
      self.assertTrue(key in new_policy)
      self.assertEqual(new_policy[key], diff_policy[key])
    self._CheckUnchangedKeys(old_policy, diff_policy, new_policy)

  def testSet(self):
    diff_policy = {
      constants.IPOLICY_VCPU_RATIO: 3,
      constants.IPOLICY_DTS: [constants.DT_FILE],
      }
    self._TestSet(self._OLD_GROUP_POLICY, diff_policy, True)
    self._TestSetSpecs(self._OLD_GROUP_POLICY, True)
    self._TestSet({}, diff_policy, True)
    self._TestSetSpecs({}, True)
    self._TestSet(self._OLD_CLUSTER_POLICY, diff_policy, False)
    self._TestSetSpecs(self._OLD_CLUSTER_POLICY, False)

  def testUnset(self):
    old_policy = self._OLD_GROUP_POLICY
    diff_policy = {
      constants.IPOLICY_SPINDLE_RATIO: constants.VALUE_DEFAULT,
      }
    new_policy = cmdlib._GetUpdatedIPolicy(old_policy, diff_policy,
                                           group_policy=True)
    for key in diff_policy:
      self.assertFalse(key in new_policy)
    self._CheckUnchangedKeys(old_policy, diff_policy, new_policy)

    # Resetting to the default is rejected for cluster policies
    self.assertRaises(errors.OpPrereqError, cmdlib._GetUpdatedIPolicy,
                      old_policy, diff_policy, group_policy=False)

  def _TestInvalidKeys(self, old_policy, isgroup):
    """Check that unknown keys and badly-typed spec values are rejected."""
    INVALID_KEY = "this_key_shouldnt_be_allowed"
    INVALID_DICT = {
      INVALID_KEY: 3,
      }
    invalid_policy = INVALID_DICT
    self.assertRaises(errors.OpPrereqError, cmdlib._GetUpdatedIPolicy,
                      old_policy, invalid_policy, group_policy=isgroup)
    invalid_ispecs = {
      constants.ISPECS_MINMAX: INVALID_DICT,
      }
    self.assertRaises(errors.TypeEnforcementError, cmdlib._GetUpdatedIPolicy,
                      old_policy, invalid_ispecs, group_policy=isgroup)
    if isgroup:
      invalid_for_group = {
        constants.ISPECS_STD: constants.IPOLICY_DEFAULTS[constants.ISPECS_STD],
        }
      self.assertRaises(errors.OpPrereqError, cmdlib._GetUpdatedIPolicy,
                        old_policy, invalid_for_group, group_policy=isgroup)
    good_ispecs = self._OLD_CLUSTER_POLICY[constants.ISPECS_MINMAX]
    # Work on a deep copy so the shared defaults are never modified
    invalid_ispecs = copy.deepcopy(good_ispecs)
    invalid_policy = {
      constants.ISPECS_MINMAX: invalid_ispecs,
      }
    for key in constants.ISPECS_MINMAX_KEYS:
      ispec = invalid_ispecs[key]
      ispec[INVALID_KEY] = None
      self.assertRaises(errors.TypeEnforcementError, cmdlib._GetUpdatedIPolicy,
                        old_policy, invalid_policy, group_policy=isgroup)
      del ispec[INVALID_KEY]
      for par in constants.ISPECS_PARAMETERS:
        oldv = ispec[par]
        ispec[par] = "this_is_not_good"
        self.assertRaises(errors.TypeEnforcementError,
                          cmdlib._GetUpdatedIPolicy,
                          old_policy, invalid_policy, group_policy=isgroup)
        ispec[par] = oldv
    # This is to make sure that no two errors were present during the tests
    cmdlib._GetUpdatedIPolicy(old_policy, invalid_policy, group_policy=isgroup)

  def testInvalidKeys(self):
    self._TestInvalidKeys(self._OLD_GROUP_POLICY, True)
    self._TestInvalidKeys(self._OLD_CLUSTER_POLICY, False)

  def testInvalidValues(self):
    for par in (constants.IPOLICY_PARAMETERS |
                frozenset([constants.IPOLICY_DTS])):
      bad_policy = {
        par: "invalid_value",
        }
      self.assertRaises(errors.OpPrereqError, cmdlib._GetUpdatedIPolicy, {},
                        bad_policy, group_policy=True)
1902
# When executed as a script, run the tests via testutils.GanetiTestProgram()
if __name__ == "__main__":
  testutils.GanetiTestProgram()