Extract group related logical units from cmdlib
[ganeti-local] / test / py / ganeti.cmdlib_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2008, 2011, 2012, 2013 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the cmdlib module"""
23
24
25 import os
26 import unittest
27 import time
28 import tempfile
29 import shutil
30 import operator
31 import itertools
32 import copy
33
34 from ganeti import constants
35 from ganeti import mcpu
36 from ganeti import cmdlib
37 from ganeti.cmdlib import cluster
38 from ganeti.cmdlib import group
39 from ganeti.cmdlib import common
40 from ganeti import opcodes
41 from ganeti import errors
42 from ganeti import utils
43 from ganeti import luxi
44 from ganeti import ht
45 from ganeti import objects
46 from ganeti import compat
47 from ganeti import rpc
48 from ganeti import locking
49 from ganeti import pathutils
50 from ganeti.masterd import iallocator
51 from ganeti.hypervisor import hv_xen
52
53 import testutils
54 import mocks
55
56
class TestCertVerification(testutils.GanetiTestCase):
  """Tests for cluster._VerifyCertificate.

  """
  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    # Scratch directory, used to build a path that does not exist
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    # setUp chains to the base class, so tearDown has to do the same; the
    # try/finally ensures the base cleanup runs even if the removal fails
    try:
      shutil.rmtree(self.tmpdir)
    finally:
      testutils.GanetiTestCase.tearDown(self)

  def testVerifyCertificate(self):
    # Loading the test certificate must not raise
    cluster._VerifyCertificate(testutils.TestDataFilename("cert1.pem"))

    # A missing file is reported as an error
    nonexist_filename = os.path.join(self.tmpdir, "does-not-exist")

    (errcode, msg) = cluster._VerifyCertificate(nonexist_filename)
    self.assertEqual(errcode, cluster.LUClusterVerifyConfig.ETYPE_ERROR)

    # Try to load non-certificate file
    invalid_cert = testutils.TestDataFilename("bdev-net.txt")
    (errcode, msg) = cluster._VerifyCertificate(invalid_cert)
    self.assertEqual(errcode, cluster.LUClusterVerifyConfig.ETYPE_ERROR)
78
79
class TestOpcodeParams(testutils.GanetiTestCase):
  """Checks that no LU in the dispatch table has old-style attributes.

  """
  def testParamsStructures(self):
    # Pairs of (forbidden attribute, failure message template)
    forbidden = [
      ("_OP_REQP", "LU '%s' has old-style _OP_REQP"),
      ("_OP_DEFS", "LU '%s' has old-style _OP_DEFS"),
      ("_OP_PARAMS", "LU '%s' has old-style _OP_PARAMS"),
      ]

    for key in sorted(mcpu.Processor.DISPATCH_TABLE):
      lu_cls = mcpu.Processor.DISPATCH_TABLE[key]
      lu_name = lu_cls.__name__
      for (attr, template) in forbidden:
        self.failIf(hasattr(lu_cls, attr), msg=(template % lu_name))
91
92
class TestIAllocatorChecks(testutils.GanetiTestCase):
  """Tests for cmdlib._CheckIAllocatorOrNode.

  """
  def testFunction(self):
    class TestLU(object):
      def __init__(self, opcode):
        self.cfg = mocks.FakeConfig()
        self.op = opcode

    class OpTest(opcodes.OpCode):
      OP_PARAMS = [
        ("iallocator", None, ht.NoType, None),
        ("node", None, ht.NoType, None),
        ]

    default_iallocator = mocks.FakeConfig().GetDefaultIAllocator()
    other_iallocator = default_iallocator + "_not"

    op = OpTest()
    lu = TestLU(op)

    def Check():
      return cmdlib._CheckIAllocatorOrNode(lu, "iallocator", "node")

    def Prepare(iallocator, node):
      op.iallocator = iallocator
      op.node = node

    def AssertState(iallocator, node):
      self.assertEqual(lu.op.iallocator, iallocator)
      self.assertEqual(lu.op.node, node)

    # Neither node nor iallocator given
    for empty_node in (None, []):
      Prepare(None, empty_node)
      Check()
      AssertState(default_iallocator, empty_node)

    # Both, iallocator and node given
    for ialloc in ("test", constants.DEFAULT_IALLOCATOR_SHORTCUT):
      Prepare(ialloc, "test")
      self.assertRaises(errors.OpPrereqError, Check)

    # Only iallocator given
    for empty_node in (None, []):
      Prepare(other_iallocator, empty_node)
      Check()
      AssertState(other_iallocator, empty_node)

    # Only node given
    Prepare(None, "node")
    Check()
    AssertState(None, "node")

    # Asked for default iallocator, no node given
    Prepare(constants.DEFAULT_IALLOCATOR_SHORTCUT, None)
    Check()
    AssertState(default_iallocator, None)

    # No node, iallocator or default iallocator
    Prepare(None, None)
    lu.cfg.GetDefaultIAllocator = lambda: None
    self.assertRaises(errors.OpPrereqError, Check)
155
156
class TestLUTestJqueue(unittest.TestCase):
  """Sanity check for the client connect timeout of LUTestJqueue.

  """
  def test(self):
    # The connect timeout has to stay below 75% of the luxi
    # wait-for-job-change timeout
    client_timeout = cmdlib.LUTestJqueue._CLIENT_CONNECT_TIMEOUT
    upper_bound = luxi.WFJC_TIMEOUT * 0.75
    self.assertTrue(client_timeout < upper_bound,
                    msg=("Client timeout too high, might not notice bugs"
                         " in WaitForJobChange"))
163
164
class TestLUQuery(unittest.TestCase):
  """Checks the query implementation table against constants.QR_VIA_OP.

  """
  def test(self):
    via_op = constants.QR_VIA_OP

    # The implementation table has exactly the resources queried via opcode
    self.assertEqual(sorted(cmdlib._QUERY_IMPL.keys()), sorted(via_op))

    for required in (constants.QR_NODE, constants.QR_INSTANCE):
      assert required in via_op

    for resource in via_op:
      self.assert_(cmdlib._GetQueryImplementation(resource))

    # Unknown resource names are rejected
    for unknown in ("", "xyz"):
      self.assertRaises(errors.OpPrereqError, cmdlib._GetQueryImplementation,
                        unknown)
179
180
class TestLUGroupAssignNodes(unittest.TestCase):
  """Tests for LUGroupAssignNodes.CheckAssignmentForSplitInstances.

  """
  def testCheckAssignmentForSplitInstances(self):
    # (node name, group name) pairs
    node_groups = [
      ("n1a", "g1"), ("n1b", "g1"),
      ("n2a", "g2"), ("n2b", "g2"),
      ("n3a", "g3"), ("n3b", "g3"),
      ("n3c", "g3"),
      ]
    node_data = {}
    for (node_name, group_name) in node_groups:
      node_data[node_name] = objects.Node(name=node_name, group=group_name)

    def MakeInstance(name, pnode, snode):
      # Without a secondary node the instance is diskless, otherwise it gets
      # a single DRBD8 disk spanning both nodes
      if snode is not None:
        template = constants.DT_DRBD8
        inst_disks = [objects.Disk(dev_type=constants.LD_DRBD8,
                                   logical_id=[pnode, snode, 1, 17, 17])]
      else:
        template = constants.DT_DISKLESS
        inst_disks = []

      return objects.Instance(name=name, primary_node=pnode,
                              disks=inst_disks, disk_template=template)

    # (instance name, primary node, secondary node) triples
    instance_specs = [
      ("inst1a", "n1a", "n1b"),
      ("inst1b", "n1b", "n1a"),
      ("inst2a", "n2a", "n2b"),
      ("inst3a", "n3a", None),
      ("inst3b", "n3b", "n1b"),
      ("inst3c", "n3b", "n2b"),
      ]
    instance_data = {}
    for (inst_name, pnode, snode) in instance_specs:
      instance_data[inst_name] = MakeInstance(inst_name, pnode, snode)

    check_fn = group.LUGroupAssignNodes.CheckAssignmentForSplitInstances

    # Test first with the existing state.
    (new, prev) = check_fn([], node_data, instance_data)

    self.assertEqual([], new)
    self.assertEqual(set(["inst3b", "inst3c"]), set(prev))

    # And now some changes.
    (new, prev) = check_fn([("n1b", "g3")], node_data, instance_data)

    self.assertEqual(set(["inst1a", "inst1b"]), set(new))
    self.assertEqual(set(["inst3c"]), set(prev))
230
231
class TestClusterVerifySsh(unittest.TestCase):
  """Tests for LUClusterVerifyGroup._SelectSshCheckNodes.

  """
  @staticmethod
  def _MakeNodes(specs):
    """Builds node objects from (name, group, offline) triples.

    """
    return [objects.Node(name=name, group=grp, offline=offline)
            for (name, grp, offline) in specs]

  def testMultipleGroups(self):
    select_fn = cluster.LUClusterVerifyGroup._SelectSshCheckNodes
    mygroupnodes = self._MakeNodes([
      ("node20", "my", False),
      ("node21", "my", False),
      ("node22", "my", False),
      ("node23", "my", False),
      ("node24", "my", False),
      ("node25", "my", False),
      ("node26", "my", True),
      ])
    nodes = self._MakeNodes([
      ("node1", "g1", True),
      ("node2", "g1", False),
      ("node3", "g1", False),
      ("node4", "g1", True),
      ("node5", "g1", False),
      ("node10", "xyz", False),
      ("node11", "xyz", False),
      ("node40", "alloff", True),
      ("node41", "alloff", True),
      ("node50", "aaa", False),
      ]) + mygroupnodes
    assert not utils.FindDuplicates([node.name for node in nodes])

    (online, perhost) = select_fn(mygroupnodes, "my", nodes)
    self.assertEqual(online, ["node%s" % idx for idx in range(20, 26)])
    self.assertEqual(set(perhost.keys()), set(online))

    expected = {
      "node20": ["node10", "node2", "node50"],
      "node21": ["node11", "node3", "node50"],
      "node22": ["node10", "node5", "node50"],
      "node23": ["node11", "node2", "node50"],
      "node24": ["node10", "node3", "node50"],
      "node25": ["node11", "node5", "node50"],
      }
    self.assertEqual(perhost, expected)

  def testSingleGroup(self):
    select_fn = cluster.LUClusterVerifyGroup._SelectSshCheckNodes
    nodes = self._MakeNodes([
      ("node1", "default", True),
      ("node2", "default", False),
      ("node3", "default", False),
      ("node4", "default", True),
      ])
    assert not utils.FindDuplicates([node.name for node in nodes])

    (online, perhost) = select_fn(nodes, "default", nodes)
    self.assertEqual(online, ["node2", "node3"])
    self.assertEqual(set(perhost.keys()), set(online))

    expected = {
      "node2": [],
      "node3": [],
      }
    self.assertEqual(perhost, expected)
289
290
class TestClusterVerifyFiles(unittest.TestCase):
  """Tests for LUClusterVerifyGroup._VerifyFiles.

  """
  @staticmethod
  def _FakeErrorIf(errors, cond, ecode, item, msg, *args, **kwargs):
    """Records an (item, message) pair in C{errors} if C{cond} is true.

    Also asserts that the error code and the item are one of the two
    combinations this test expects.

    """
    assert ((ecode == constants.CV_ENODEFILECHECK and
             ht.TNonEmptyString(item)) or
            (ecode == constants.CV_ECLUSTERFILECHECK and
             item is None))

    if args:
      msg = msg % args

    if cond:
      errors.append((item, msg))

  _VerifyFiles = cluster.LUClusterVerifyGroup._VerifyFiles

  def test(self):
    # Named so that it does not shadow the imported "errors" module
    found_errors = []
    master_name = "master.example.com"
    nodeinfo = [
      objects.Node(name=master_name, offline=False, vm_capable=True),
      objects.Node(name="node2.example.com", offline=False, vm_capable=True),
      objects.Node(name="node3.example.com", master_candidate=True,
                   vm_capable=False),
      objects.Node(name="node4.example.com", offline=False, vm_capable=True),
      objects.Node(name="nodata.example.com", offline=False, vm_capable=True),
      objects.Node(name="offline.example.com", offline=True),
      ]
    files_all = set([
      pathutils.CLUSTER_DOMAIN_SECRET_FILE,
      pathutils.RAPI_CERT_FILE,
      pathutils.RAPI_USERS_FILE,
      ])
    files_opt = set([
      pathutils.RAPI_USERS_FILE,
      hv_xen.XL_CONFIG_FILE,
      pathutils.VNC_PASSWORD_FILE,
      ])
    files_mc = set([
      pathutils.CLUSTER_CONF_FILE,
      ])
    files_vm = set([
      hv_xen.XEND_CONFIG_FILE,
      hv_xen.XL_CONFIG_FILE,
      pathutils.VNC_PASSWORD_FILE,
      ])
    # Per-node RPC results, keyed by node name
    nvinfo = {
      master_name: rpc.RpcResult(data=(True, {
        constants.NV_FILELIST: {
          pathutils.CLUSTER_CONF_FILE: "82314f897f38b35f9dab2f7c6b1593e0",
          pathutils.RAPI_CERT_FILE: "babbce8f387bc082228e544a2146fee4",
          pathutils.CLUSTER_DOMAIN_SECRET_FILE: "cds-47b5b3f19202936bb4",
          hv_xen.XEND_CONFIG_FILE: "b4a8a824ab3cac3d88839a9adeadf310",
          hv_xen.XL_CONFIG_FILE: "77935cee92afd26d162f9e525e3d49b9"
        }})),
      "node2.example.com": rpc.RpcResult(data=(True, {
        constants.NV_FILELIST: {
          pathutils.RAPI_CERT_FILE: "97f0356500e866387f4b84233848cc4a",
          hv_xen.XEND_CONFIG_FILE: "b4a8a824ab3cac3d88839a9adeadf310",
          }
        })),
      "node3.example.com": rpc.RpcResult(data=(True, {
        constants.NV_FILELIST: {
          pathutils.RAPI_CERT_FILE: "97f0356500e866387f4b84233848cc4a",
          pathutils.CLUSTER_DOMAIN_SECRET_FILE: "cds-47b5b3f19202936bb4",
          }
        })),
      "node4.example.com": rpc.RpcResult(data=(True, {
        constants.NV_FILELIST: {
          pathutils.RAPI_CERT_FILE: "97f0356500e866387f4b84233848cc4a",
          pathutils.CLUSTER_CONF_FILE: "conf-a6d4b13e407867f7a7b4f0f232a8f527",
          pathutils.CLUSTER_DOMAIN_SECRET_FILE: "cds-47b5b3f19202936bb4",
          pathutils.RAPI_USERS_FILE: "rapiusers-ea3271e8d810ef3",
          hv_xen.XL_CONFIG_FILE: "77935cee92afd26d162f9e525e3d49b9"
          }
        })),
      "nodata.example.com": rpc.RpcResult(data=(True, {})),
      "offline.example.com": rpc.RpcResult(offline=True),
      }
    # Every node must have exactly one RPC result
    assert set(nvinfo.keys()) == set(map(operator.attrgetter("name"), nodeinfo))

    self._VerifyFiles(compat.partial(self._FakeErrorIf, found_errors),
                      nodeinfo, master_name, nvinfo,
                      (files_all, files_opt, files_mc, files_vm))
    self.assertEqual(sorted(found_errors), sorted([
      (None, ("File %s found with 2 different checksums (variant 1 on"
              " node2.example.com, node3.example.com, node4.example.com;"
              " variant 2 on master.example.com)" % pathutils.RAPI_CERT_FILE)),
      (None, ("File %s is missing from node(s) node2.example.com" %
              pathutils.CLUSTER_DOMAIN_SECRET_FILE)),
      (None, ("File %s should not exist on node(s) node4.example.com" %
              pathutils.CLUSTER_CONF_FILE)),
      (None, ("File %s is missing from node(s) node4.example.com" %
              hv_xen.XEND_CONFIG_FILE)),
      (None, ("File %s is missing from node(s) node3.example.com" %
              pathutils.CLUSTER_CONF_FILE)),
      (None, ("File %s found with 2 different checksums (variant 1 on"
              " master.example.com; variant 2 on node4.example.com)" %
              pathutils.CLUSTER_CONF_FILE)),
      (None, ("File %s is optional, but it must exist on all or no nodes (not"
              " found on master.example.com, node2.example.com,"
              " node3.example.com)" % pathutils.RAPI_USERS_FILE)),
      (None, ("File %s is optional, but it must exist on all or no nodes (not"
              " found on node2.example.com)" % hv_xen.XL_CONFIG_FILE)),
      ("nodata.example.com", "Node did not return file checksum data"),
      ]))
399
400
401 class _FakeLU:
402   def __init__(self, cfg=NotImplemented, proc=NotImplemented,
403                rpc=NotImplemented):
404     self.warning_log = []
405     self.info_log = []
406     self.cfg = cfg
407     self.proc = proc
408     self.rpc = rpc
409
410   def LogWarning(self, text, *args):
411     self.warning_log.append((text, args))
412
413   def LogInfo(self, text, *args):
414     self.info_log.append((text, args))
415
416
class TestLoadNodeEvacResult(unittest.TestCase):
  """Tests for common._LoadNodeEvacResult.

  """
  def testSuccess(self):
    moved_variants = [
      [],
      [("inst20153.example.com", "grp2", ["nodeA4509", "nodeB2912"])],
      ]
    flags = [False, True]

    # Same iteration order as three nested loops over the lists above
    for (moved, early_release, use_nodes) in \
        itertools.product(moved_variants, flags, flags):
      jobs = [
        [opcodes.OpInstanceReplaceDisks().__getstate__()],
        [opcodes.OpInstanceMigrate().__getstate__()],
        ]

      alloc_result = (moved, [], jobs)
      assert iallocator._NEVAC_RESULT(alloc_result)

      lu = _FakeLU()
      result = common._LoadNodeEvacResult(lu, alloc_result,
                                          early_release, use_nodes)

      if moved:
        # The first info message has to mention every moved instance along
        # with either its nodes or its group
        (_, (info_args, )) = lu.info_log.pop(0)
        for (instname, instgroup, instnodes) in moved:
          self.assertTrue(instname in info_args)
          if use_nodes:
            mentioned = instnodes
          else:
            mentioned = [instgroup]
          for item in mentioned:
            self.assertTrue(item in info_args)

      self.assertFalse(lu.info_log)
      self.assertFalse(lu.warning_log)

      for op in itertools.chain(*result):
        if not hasattr(op.__class__, "early_release"):
          self.assertFalse(hasattr(op, "early_release"))
        else:
          self.assertEqual(op.early_release, early_release)

  def testFailed(self):
    failed = [("inst5191.example.com", "errormsg21178")]
    alloc_result = ([], failed, [])
    assert iallocator._NEVAC_RESULT(alloc_result)

    lu = _FakeLU()
    self.assertRaises(errors.OpExecError, common._LoadNodeEvacResult,
                      lu, alloc_result, False, False)
    self.assertFalse(lu.info_log)

    # Exactly one warning, naming the instance and the error message
    (_, (args, )) = lu.warning_log.pop(0)
    for expected in ("inst5191.example.com", "errormsg21178"):
      self.assertTrue(expected in args)
    self.assertFalse(lu.warning_log)
469
470
class TestUpdateAndVerifySubDict(unittest.TestCase):
  """Tests for common._UpdateAndVerifySubDict.

  """
  def setUp(self):
    self.type_check = dict(
      a=constants.VTYPE_INT,
      b=constants.VTYPE_STRING,
      c=constants.VTYPE_BOOL,
      d=constants.VTYPE_STRING,
      )

  def test(self):
    old_test = {
      "foo": dict(d="blubb", a=321),
      "baz": dict(a=678, b="678", c=True),
      }
    test = {
      "foo": dict(a=123, b="123", c=True),
      "bar": dict(a=321, b="321", c=False),
      }

    # Expected outcome of updating old_test with test
    mv = {
      "foo": dict(a=123, b="123", c=True, d="blubb"),
      "bar": dict(a=321, b="321", c=False),
      "baz": dict(a=678, b="678", c=True),
      }

    verified = common._UpdateAndVerifySubDict(old_test, test, self.type_check)
    self.assertEqual(verified, mv)

  def testWrong(self):
    # "a" is checked as an integer, so a string value must be rejected
    test = {
      "foo": dict(a="blubb", b="123", c=True),
      "bar": dict(a=321, b="321", c=False),
      }

    self.assertRaises(errors.TypeEnforcementError,
                      common._UpdateAndVerifySubDict, {}, test,
                      self.type_check)
544
545
class TestHvStateHelper(unittest.TestCase):
  """Tests for common._MergeAndVerifyHvState.

  """
  @staticmethod
  def _State(hv_name):
    """Builds a hypervisor state dict with a single hypervisor entry.

    """
    return {
      hv_name: {
        constants.HVST_MEMORY_TOTAL: 4096,
        },
      }

  def testWithoutOpData(self):
    merged = common._MergeAndVerifyHvState(None, NotImplemented)
    self.assertEqual(merged, None)

  def testWithoutOldData(self):
    new = self._State(constants.HT_XEN_PVM)
    merged = common._MergeAndVerifyHvState(new, None)
    self.assertEqual(merged, new)

  def testWithWrongHv(self):
    new = self._State("i-dont-exist")
    self.assertRaises(errors.OpPrereqError, common._MergeAndVerifyHvState,
                      new, None)
567
class TestDiskStateHelper(unittest.TestCase):
  """Tests for common._MergeAndVerifyDiskState.

  """
  @staticmethod
  def _State(storage_type):
    """Builds a disk state dict with one "xenvg" entry for a storage type.

    """
    return {
      storage_type: {
        "xenvg": {
          constants.DS_DISK_RESERVED: 1024,
          },
        },
      }

  def testWithoutOpData(self):
    merged = common._MergeAndVerifyDiskState(None, NotImplemented)
    self.assertEqual(merged, None)

  def testWithoutOldData(self):
    new = self._State(constants.LD_LV)
    merged = common._MergeAndVerifyDiskState(new, None)
    self.assertEqual(merged, new)

  def testWithWrongStorageType(self):
    new = self._State("i-dont-exist")
    self.assertRaises(errors.OpPrereqError, common._MergeAndVerifyDiskState,
                      new, None)
593
594
class TestComputeMinMaxSpec(unittest.TestCase):
  """Tests for common._ComputeMinMaxSpec.

  """
  def setUp(self):
    max_spec = {
      constants.ISPEC_MEM_SIZE: 512,
      constants.ISPEC_DISK_SIZE: 1024,
      }
    min_spec = {
      constants.ISPEC_MEM_SIZE: 128,
      constants.ISPEC_DISK_COUNT: 1,
      }
    self.ispecs = {
      constants.ISPECS_MAX: max_spec,
      constants.ISPECS_MIN: min_spec,
      }

  def _Compute(self, name, qualifier, value):
    """Calls _ComputeMinMaxSpec with the specs prepared in setUp.

    """
    return common._ComputeMinMaxSpec(name, qualifier, self.ispecs, value)

  def testNoneValue(self):
    ret = self._Compute(constants.ISPEC_MEM_SIZE, None, None)
    self.assertTrue(ret is None)

  def testAutoValue(self):
    ret = self._Compute(constants.ISPEC_MEM_SIZE, None, constants.VALUE_AUTO)
    self.assertTrue(ret is None)

  def testNotDefined(self):
    ret = self._Compute(constants.ISPEC_NIC_COUNT, None, 3)
    self.assertTrue(ret is None)

  def testNoMinDefined(self):
    ret = self._Compute(constants.ISPEC_DISK_SIZE, None, 128)
    self.assertTrue(ret is None)

  def testNoMaxDefined(self):
    ret = self._Compute(constants.ISPEC_DISK_COUNT, None, 16)
    self.assertTrue(ret is None)

  def testOutOfRange(self):
    cases = (
      (constants.ISPEC_MEM_SIZE, 64),
      (constants.ISPEC_MEM_SIZE, 768),
      (constants.ISPEC_DISK_SIZE, 4096),
      (constants.ISPEC_DISK_COUNT, 0),
      )
    for (name, val) in cases:
      # A bound missing from the specs is reported as the value itself
      min_v = self.ispecs[constants.ISPECS_MIN].get(name, val)
      max_v = self.ispecs[constants.ISPECS_MAX].get(name, val)
      details = (name, val, min_v, max_v)
      self.assertEqual(self._Compute(name, None, val),
                       "%s value %s is not in range [%s, %s]" % details)
      self.assertEqual(self._Compute(name, "1", val),
                       "%s/1 value %s is not in range [%s, %s]" % details)

  def test(self):
    cases = (
      (constants.ISPEC_MEM_SIZE, 256),
      (constants.ISPEC_MEM_SIZE, 128),
      (constants.ISPEC_MEM_SIZE, 512),
      (constants.ISPEC_DISK_SIZE, 1024),
      (constants.ISPEC_DISK_SIZE, 0),
      (constants.ISPEC_DISK_COUNT, 1),
      (constants.ISPEC_DISK_COUNT, 5),
      )
    for (name, val) in cases:
      ret = self._Compute(name, None, val)
      self.assertTrue(ret is None)
655
656
def _ValidateComputeMinMaxSpec(name, *_):
  """Asserts that the name is a known instance spec parameter.

  All further arguments are ignored.

  @return: always None

  """
  known_parameters = constants.ISPECS_PARAMETERS
  assert name in known_parameters
  return None
660
661
def _NoDiskComputeMinMaxSpec(name, *_):
  """Returns the name for the disk count parameter, None for all others.

  All further arguments are ignored.

  """
  is_disk_count = (name == constants.ISPEC_DISK_COUNT)
  if not is_disk_count:
    return None
  return name
667
668
669 class _SpecWrapper:
670   def __init__(self, spec):
671     self.spec = spec
672
673   def ComputeMinMaxSpec(self, *args):
674     return self.spec.pop(0)
675
676
class TestComputeIPolicySpecViolation(unittest.TestCase):
  """Tests for common._ComputeIPolicySpecViolation.

  """
  # Minimal policy accepted by _ComputeIPolicySpecViolation()
  _MICRO_IPOL = {
    constants.IPOLICY_DTS: [constants.DT_PLAIN, constants.DT_DISKLESS],
    constants.ISPECS_MINMAX: [NotImplemented],
    }

  def _Violations(self, disk_template, compute_fn):
    """Runs the check on the minimal policy with fixed instance values.

    """
    return common._ComputeIPolicySpecViolation(self._MICRO_IPOL, 1024, 1, 1, 1,
                                               [1024], 1, disk_template,
                                               _compute_fn=compute_fn)

  def test(self):
    ret = self._Violations(constants.DT_PLAIN, _ValidateComputeMinMaxSpec)
    self.assertEqual(ret, [])

  def testDiskFull(self):
    ret = self._Violations(constants.DT_PLAIN, _NoDiskComputeMinMaxSpec)
    self.assertEqual(ret, [constants.ISPEC_DISK_COUNT])

  def testDiskLess(self):
    ret = self._Violations(constants.DT_DISKLESS, _NoDiskComputeMinMaxSpec)
    self.assertEqual(ret, [])

  def testWrongTemplates(self):
    # DRBD8 is not among the disk templates of the minimal policy
    ret = self._Violations(constants.DT_DRBD8, _ValidateComputeMinMaxSpec)
    self.assertEqual(len(ret), 1)
    self.assertTrue("Disk template" in ret[0])

  def testInvalidArguments(self):
    # One disk is announced, but the list of disk sizes is empty
    self.assertRaises(AssertionError, common._ComputeIPolicySpecViolation,
                      self._MICRO_IPOL, 1024, 1, 1, 1, [], 1,
                      constants.DT_PLAIN)

  def testInvalidSpec(self):
    wrapper = _SpecWrapper([None, False, "foo", None, "bar", None])
    ret = self._Violations(constants.DT_PLAIN, wrapper.ComputeMinMaxSpec)
    self.assertEqual(ret, ["foo", "bar"])
    # All prepared results must have been consumed
    self.assertFalse(wrapper.spec)

  def testWithIPolicy(self):
    mem_size = 2048
    cpu_count = 2
    disk_count = 1
    disk_sizes = [512]
    nic_count = 1
    spindle_use = 4
    disk_template = "mytemplate"
    ispec = {
      constants.ISPEC_MEM_SIZE: mem_size,
      constants.ISPEC_CPU_COUNT: cpu_count,
      constants.ISPEC_DISK_COUNT: disk_count,
      constants.ISPEC_DISK_SIZE: disk_sizes[0],
      constants.ISPEC_NIC_COUNT: nic_count,
      constants.ISPEC_SPINDLE_USE: spindle_use,
      }
    ispec_copy = copy.deepcopy(ispec)

    def MinMax(spec):
      # The very same dict object serves as both minimum and maximum
      return {
        constants.ISPECS_MIN: spec,
        constants.ISPECS_MAX: spec,
        }

    def Policy(*specs):
      return {
        constants.ISPECS_MINMAX: [MinMax(spec) for spec in specs],
        constants.IPOLICY_DTS: [disk_template],
        }

    ipolicy1 = Policy(ispec)
    ipolicy2 = Policy(ispec_copy, ispec)
    ipolicy3 = Policy(ispec, ispec_copy)

    def AssertComputeViolation(ipolicy, violations):
      ret = common._ComputeIPolicySpecViolation(ipolicy, mem_size, cpu_count,
                                                disk_count, nic_count,
                                                disk_sizes, spindle_use,
                                                disk_template)
      self.assertEqual(len(ret), violations)

    def AssertAll(expected):
      for (ipolicy, violations) in expected:
        AssertComputeViolation(ipolicy, violations)

    AssertAll([(ipolicy1, 0), (ipolicy2, 0), (ipolicy3, 0)])

    # Moving one parameter of ispec off by one in either direction violates
    # the policy that only has ispec; the untouched copy keeps the other two
    # policies satisfied
    for par in constants.ISPECS_PARAMETERS:
      exact = ispec[par]
      for shifted in (exact + 1, exact - 1):
        ispec[par] = shifted
        AssertAll([(ipolicy1, 1), (ipolicy2, 0), (ipolicy3, 0)])
      ispec[par] = exact # Restore

    ipolicy1[constants.IPOLICY_DTS] = ["another_template"]
    AssertComputeViolation(ipolicy1, 1)
799
800
801 class _StubComputeIPolicySpecViolation:
802   def __init__(self, mem_size, cpu_count, disk_count, nic_count, disk_sizes,
803                spindle_use, disk_template):
804     self.mem_size = mem_size
805     self.cpu_count = cpu_count
806     self.disk_count = disk_count
807     self.nic_count = nic_count
808     self.disk_sizes = disk_sizes
809     self.spindle_use = spindle_use
810     self.disk_template = disk_template
811
812   def __call__(self, _, mem_size, cpu_count, disk_count, nic_count, disk_sizes,
813                spindle_use, disk_template):
814     assert self.mem_size == mem_size
815     assert self.cpu_count == cpu_count
816     assert self.disk_count == disk_count
817     assert self.nic_count == nic_count
818     assert self.disk_sizes == disk_sizes
819     assert self.spindle_use == spindle_use
820     assert self.disk_template == disk_template
821
822     return []
823
824
class _FakeConfigForComputeIPolicyInstanceViolation:
  """Configuration stand-in which only provides the cluster object.

  """
  def __init__(self, be):
    # The given backend parameters become the cluster's "default" set
    cluster_beparams = {"default": be}
    self.cluster = objects.Cluster(beparams=cluster_beparams)

  def GetClusterInfo(self):
    """Returns the cluster object created in the constructor.

    """
    return self.cluster
831
832
class TestComputeIPolicyInstanceViolation(unittest.TestCase):
  """Tests for C{common._ComputeIPolicyInstanceViolation}.

  """
  def test(self):
    beparams = {
      constants.BE_MAXMEM: 2048,
      constants.BE_VCPUS: 2,
      constants.BE_SPINDLE_USE: 4,
      }
    disks = [objects.Disk(size=512)]
    cfg = _FakeConfigForComputeIPolicyInstanceViolation(beparams)

    # The stub asserts it receives exactly these values
    stub = _StubComputeIPolicySpecViolation(2048, 2, 1, 0, [512], 4,
                                            constants.DT_PLAIN)

    # First with backend parameters set on the instance itself, then with an
    # instance that has none; the fake config carries the same values as
    # cluster defaults and the stub expects identical arguments both times
    for inst_beparams in [beparams, {}]:
      instance = objects.Instance(beparams=inst_beparams, disks=disks,
                                  nics=[], disk_template=constants.DT_PLAIN)
      violations = \
        common._ComputeIPolicyInstanceViolation(NotImplemented, instance,
                                                cfg, _compute_fn=stub)
      self.assertEqual(violations, [])
854
855
class TestComputeIPolicyInstanceSpecViolation(unittest.TestCase):
  """Tests for C{cmdlib._ComputeIPolicyInstanceSpecViolation}.

  """
  def test(self):
    template = constants.DT_PLAIN

    # The stub asserts that the values from the spec below reach it unchanged
    stub = _StubComputeIPolicySpecViolation(2048, 2, 1, 0, [512], 1, template)

    ispec = {
      constants.ISPEC_MEM_SIZE: 2048,
      constants.ISPEC_CPU_COUNT: 2,
      constants.ISPEC_DISK_COUNT: 1,
      constants.ISPEC_DISK_SIZE: [512],
      constants.ISPEC_NIC_COUNT: 0,
      constants.ISPEC_SPINDLE_USE: 1,
      }

    violations = cmdlib._ComputeIPolicyInstanceSpecViolation(NotImplemented,
                                                             ispec, template,
                                                             _compute_fn=stub)
    self.assertEqual(violations, [])
872
873
class _CallRecorder:
  """Callable which remembers whether it has been invoked.

  Every call returns the value given to the constructor.

  """
  def __init__(self, return_value=None):
    self.return_value = return_value
    self.called = False

  def __call__(self, *_unused):
    # Positional arguments are accepted, but not inspected
    self.called = True
    return self.return_value
882
883
class TestComputeIPolicyNodeViolation(unittest.TestCase):
  """Tests for C{cmdlib._ComputeIPolicyNodeViolation}.

  """
  def setUp(self):
    self.recorder = _CallRecorder(return_value=[])

  def _Compute(self, group_a, group_b):
    """Runs the tested function with the recorder as compute function.

    """
    return cmdlib._ComputeIPolicyNodeViolation(NotImplemented, NotImplemented,
                                               group_a, group_b,
                                               NotImplemented,
                                               _compute_fn=self.recorder)

  def testSameGroup(self):
    violations = self._Compute("foo", "foo")
    # With identical groups the compute function must not have been invoked
    self.assertFalse(self.recorder.called)
    self.assertEqual(violations, [])

  def testDifferentGroup(self):
    violations = self._Compute("foo", "bar")
    # With differing groups the compute function must have been invoked
    self.assertTrue(self.recorder.called)
    self.assertEqual(violations, [])
901
902
class _FakeConfigForTargetNodeIPolicy:
  """Configuration fake returning the same node info for every lookup.

  """
  def __init__(self, node_info=NotImplemented):
    self._node_info = node_info

  def GetNodeInfo(self, _):
    """Returns the object given to the constructor, whatever the argument.

    """
    return self._node_info
909
910
class TestCheckTargetNodeIPolicy(unittest.TestCase):
  """Tests for C{cmdlib._CheckTargetNodeIPolicy}.

  """
  def setUp(self):
    # Node object handed out by the fake configuration for any node lookup
    cfg_node = objects.Node(group="foo")
    self.lu = _FakeLU(cfg=_FakeConfigForTargetNodeIPolicy(node_info=cfg_node))
    self.instance = objects.Instance(primary_node="blubb")
    self.target_node = objects.Node(group="bar")

  def testNoViolation(self):
    recorder = _CallRecorder(return_value=[])
    cmdlib._CheckTargetNodeIPolicy(self.lu, NotImplemented, self.instance,
                                   self.target_node, NotImplemented,
                                   _compute_fn=recorder)
    self.assertTrue(recorder.called)
    self.assertEqual(self.lu.warning_log, [])

  def testNoIgnore(self):
    recorder = _CallRecorder(return_value=["mem_size not in range"])
    # A reported violation without "ignore" must raise and must not warn
    self.assertRaises(errors.OpPrereqError, cmdlib._CheckTargetNodeIPolicy,
                      self.lu, NotImplemented, self.instance, self.target_node,
                      NotImplemented, _compute_fn=recorder)
    self.assertTrue(recorder.called)
    self.assertEqual(self.lu.warning_log, [])

  def testIgnoreViolation(self):
    recorder = _CallRecorder(return_value=["mem_size not in range"])
    cmdlib._CheckTargetNodeIPolicy(self.lu, NotImplemented, self.instance,
                                   self.target_node, NotImplemented,
                                   ignore=True, _compute_fn=recorder)
    self.assertTrue(recorder.called)
    # With "ignore" set the violation must show up in the warning log instead
    exp_warning = ("Instance does not meet target node group's (bar) instance"
                   " policy: mem_size not in range")
    self.assertEqual(self.lu.warning_log, [(exp_warning, ())])
944
945
class TestApplyContainerMods(unittest.TestCase):
  """Tests for C{cmdlib.PrepareContainerMods} and C{ApplyContainerMods}.

  """
  def testEmptyContainer(self):
    items = []
    changes = []
    cmdlib.ApplyContainerMods("test", items, changes, [], None, None, None)
    self.assertEqual(items, [])
    self.assertEqual(changes, [])

  def testAdd(self):
    items = []
    changes = []

    first_batch = [
      (constants.DDM_ADD, -1, "Hello"),
      (constants.DDM_ADD, -1, "World"),
      (constants.DDM_ADD, 0, "Start"),
      (constants.DDM_ADD, -1, "End"),
      ]
    cmdlib.ApplyContainerMods("test", items, changes,
                              cmdlib.PrepareContainerMods(first_batch, None),
                              None, None, None)
    self.assertEqual(items, ["Start", "Hello", "World", "End"])
    self.assertEqual(changes, [])

    # Second round operates on the already filled container
    second_batch = [
      (constants.DDM_ADD, 0, "zero"),
      (constants.DDM_ADD, 3, "Added"),
      (constants.DDM_ADD, 5, "four"),
      (constants.DDM_ADD, 7, "xyz"),
      ]
    cmdlib.ApplyContainerMods("test", items, changes,
                              cmdlib.PrepareContainerMods(second_batch, None),
                              None, None, None)
    self.assertEqual(items, [
      "zero", "Start", "Hello", "Added", "World", "four", "End", "xyz",
      ])
    self.assertEqual(changes, [])

    # Indices outside of the container must be rejected
    for bad_idx in [-2, len(items) + 1]:
      bad_mods = cmdlib.PrepareContainerMods([
        (constants.DDM_ADD, bad_idx, "error"),
        ], None)
      self.assertRaises(IndexError, cmdlib.ApplyContainerMods,
                        "test", items, None, bad_mods, None, None, None)

  def testRemoveError(self):
    # Nothing can be removed from an empty container
    for bad_idx in [0, 1, 2, 100, -1, -4]:
      bad_mods = cmdlib.PrepareContainerMods([
        (constants.DDM_REMOVE, bad_idx, None),
        ], None)
      self.assertRaises(IndexError, cmdlib.ApplyContainerMods,
                        "test", [], None, bad_mods, None, None, None)

    # Removal request carrying an object as parameter
    with_params = cmdlib.PrepareContainerMods([
      (constants.DDM_REMOVE, 0, object()),
      ], None)
    self.assertRaises(AssertionError, cmdlib.ApplyContainerMods,
                      "test", [""], None, with_params, None, None, None)

  def testAddError(self):
    for bad_idx in list(range(-100, -1)) + [100]:
      bad_mods = cmdlib.PrepareContainerMods([
        (constants.DDM_ADD, bad_idx, None),
        ], None)
      self.assertRaises(IndexError, cmdlib.ApplyContainerMods,
                        "test", [], None, bad_mods, None, None, None)

  def testRemove(self):
    items = ["item 1", "item 2"]
    changes = []
    mods = cmdlib.PrepareContainerMods([
      (constants.DDM_ADD, -1, "aaa"),
      (constants.DDM_REMOVE, -1, None),
      (constants.DDM_ADD, -1, "bbb"),
      ], None)
    cmdlib.ApplyContainerMods("test", items, changes, mods,
                              None, None, None)
    # "aaa" was added and removed again, only the removal is described
    self.assertEqual(items, ["item 1", "item 2", "bbb"])
    self.assertEqual(changes, [("test/2", "remove")])

  def testModify(self):
    items = ["item 1", "item 2"]
    changes = []
    mods = cmdlib.PrepareContainerMods([
      (constants.DDM_MODIFY, -1, "a"),
      (constants.DDM_MODIFY, 0, "b"),
      (constants.DDM_MODIFY, 1, "c"),
      ], None)
    cmdlib.ApplyContainerMods("test", items, changes, mods,
                              None, None, None)
    # Without a modify function neither items nor description change
    self.assertEqual(items, ["item 1", "item 2"])
    self.assertEqual(changes, [])

    for bad_idx in [-2, len(items) + 1]:
      bad_mods = cmdlib.PrepareContainerMods([
        (constants.DDM_MODIFY, bad_idx, "error"),
        ], None)
      self.assertRaises(IndexError, cmdlib.ApplyContainerMods,
                        "test", items, None, bad_mods, None, None, None)

  class _PrivateData:
    """Per-modification private object filled in by the callbacks below.

    """
    def __init__(self):
      self.data = None

  @staticmethod
  def _CreateTestFn(idx, params, private):
    """Creation callback; records the call and returns item plus changes.

    """
    private.data = ("add", idx, params)
    new_item = (100 * idx, params)
    description = [("test/%s" % idx, hex(idx))]
    return (new_item, description)

  @staticmethod
  def _ModifyTestFn(idx, item, params, private):
    """Modification callback; records the call and returns the changes.

    """
    private.data = ("modify", idx, params)
    return [("test/%s" % idx, "modify %s" % params)]

  @staticmethod
  def _RemoveTestFn(idx, item, private):
    """Removal callback; only records the call.

    """
    private.data = ("remove", idx, item)

  def testAddWithCreateFunction(self):
    items = []
    changes = []
    mods = cmdlib.PrepareContainerMods([
      (constants.DDM_ADD, -1, "Hello"),
      (constants.DDM_ADD, -1, "World"),
      (constants.DDM_ADD, 0, "Start"),
      (constants.DDM_ADD, -1, "End"),
      (constants.DDM_REMOVE, 2, None),
      (constants.DDM_MODIFY, -1, "foobar"),
      (constants.DDM_REMOVE, 2, None),
      (constants.DDM_ADD, 1, "More"),
      ], self._PrivateData)
    cmdlib.ApplyContainerMods("test", items, changes, mods,
                              self._CreateTestFn, self._ModifyTestFn,
                              self._RemoveTestFn)

    self.assertEqual(items, [
      (0, "Start"),
      (100, "More"),
      (0, "Hello"),
      ])
    self.assertEqual(changes, [
      ("test/0", "0x0"),
      ("test/1", "0x1"),
      ("test/0", "0x0"),
      ("test/3", "0x3"),
      ("test/2", "remove"),
      ("test/2", "modify foobar"),
      ("test/2", "remove"),
      ("test/1", "0x1"),
      ])

    # Each private object must have been filled by the matching callback
    self.assertTrue(compat.all(op == private.data[0]
                               for (op, _, _, private) in mods))
    recorded = [private.data for (_, _, _, private) in mods]
    self.assertEqual(recorded, [
      ("add", 0, "Hello"),
      ("add", 1, "World"),
      ("add", 0, "Start"),
      ("add", 3, "End"),
      ("remove", 2, (100, "World")),
      ("modify", 2, "foobar"),
      ("remove", 2, (300, "End")),
      ("add", 1, "More"),
      ])
1109
1110
class _FakeConfigForGenDiskTemplate:
  """Configuration fake handing out predictable, counter-based values.

  """
  def __init__(self):
    # One independent counter per kind of generated value
    self._unique_id = itertools.count()
    self._drbd_minor = itertools.count(20)
    self._port = itertools.count(constants.FIRST_DRBD_PORT)
    self._secret = itertools.count()

  def GetVGName(self):
    """Returns a fixed volume group name.

    """
    return "testvg"

  def GenerateUniqueID(self, ec_id):
    """Returns an ID built from C{ec_id} and the next counter value.

    """
    serial = self._unique_id.next()
    return "ec%s-uq%s" % (ec_id, serial)

  def AllocateDRBDMinor(self, nodes, instance):
    """Returns one consecutive minor per entry in C{nodes}.

    """
    minors = []
    for _ in nodes:
      minors.append(self._drbd_minor.next())
    return minors

  def AllocatePort(self):
    """Returns the next port, counting from C{constants.FIRST_DRBD_PORT}.

    """
    return self._port.next()

  def GenerateDRBDSecret(self, ec_id):
    """Returns a secret built from C{ec_id} and the next counter value.

    """
    serial = self._secret.next()
    return "ec%s-secret%s" % (ec_id, serial)

  def GetInstanceInfo(self, _):
    """Returns a fixed placeholder string, whatever the argument.

    """
    return "foobar"
1136
1137
class _FakeProcForGenDiskTemplate:
  """Processor fake which only provides an execution context ID.

  """
  def GetECId(self):
    """Returns the constant ID C{0}.

    """
    return 0
1141
1142
class TestGenerateDiskTemplate(unittest.TestCase):
  """Tests for C{cmdlib._GenerateDiskTemplate}.

  """
  def setUp(self):
    nodegroup = objects.NodeGroup(name="ng")
    nodegroup.UpgradeConfig()

    cfg = _FakeConfigForGenDiskTemplate()
    proc = _FakeProcForGenDiskTemplate()

    self.lu = _FakeLU(cfg=cfg, proc=proc)
    self.nodegroup = nodegroup

  @staticmethod
  def GetDiskParams():
    """Returns a private copy of the default disk parameters.

    """
    return copy.deepcopy(constants.DISK_DT_DEFAULTS)

  def testWrongDiskTemplate(self):
    gdt = cmdlib._GenerateDiskTemplate
    disk_template = "##unknown##"

    assert disk_template not in constants.DISK_TEMPLATES

    self.assertRaises(errors.ProgrammerError, gdt, self.lu, disk_template,
                      "inst26831.example.com", "node30113.example.com", [], [],
                      NotImplemented, NotImplemented, 0, self.lu.LogInfo,
                      self.GetDiskParams())

  def testDiskless(self):
    gdt = cmdlib._GenerateDiskTemplate

    result = gdt(self.lu, constants.DT_DISKLESS, "inst27734.example.com",
                 "node30113.example.com", [], [],
                 NotImplemented, NotImplemented, 0, self.lu.LogInfo,
                 self.GetDiskParams())
    self.assertEqual(result, [])

  def _TestTrivialDisk(self, template, disk_info, base_index, exp_dev_type,
                       file_storage_dir=NotImplemented,
                       file_driver=NotImplemented,
                       req_file_storage=NotImplemented,
                       req_shr_file_storage=NotImplemented):
    """Generates disks for a template without children and checks them.

    @param template: disk template to generate
    @param disk_info: list of disk parameter dictionaries
    @param base_index: index of the first generated disk
    @param exp_dev_type: device type every generated disk must have
    @return: the list of generated disks

    """
    gdt = cmdlib._GenerateDiskTemplate

    # Explicit loop instead of map(); the call is made for its side effects
    for params in disk_info:
      utils.ForceDictType(params, constants.IDISK_PARAMS_TYPES)

    # Check if non-empty list of secondaries is rejected
    self.assertRaises(errors.ProgrammerError, gdt, self.lu,
                      template, "inst25088.example.com",
                      "node185.example.com", ["node323.example.com"], [],
                      NotImplemented, NotImplemented, base_index,
                      self.lu.LogInfo, self.GetDiskParams(),
                      _req_file_storage=req_file_storage,
                      _req_shr_file_storage=req_shr_file_storage)

    result = gdt(self.lu, template, "inst21662.example.com",
                 "node21741.example.com", [],
                 disk_info, file_storage_dir, file_driver, base_index,
                 self.lu.LogInfo, self.GetDiskParams(),
                 _req_file_storage=req_file_storage,
                 _req_shr_file_storage=req_shr_file_storage)

    for (idx, disk) in enumerate(result):
      self.assertTrue(isinstance(disk, objects.Disk))
      self.assertEqual(disk.dev_type, exp_dev_type)
      self.assertEqual(disk.size, disk_info[idx][constants.IDISK_SIZE])
      self.assertEqual(disk.mode, disk_info[idx][constants.IDISK_MODE])
      self.assertTrue(disk.children is None)

    # Names must be correct before and after running the update function
    self._CheckIvNames(result, base_index, base_index + len(disk_info))
    cmdlib._UpdateIvNames(base_index, result)
    self._CheckIvNames(result, base_index, base_index + len(disk_info))

    return result

  def _CheckIvNames(self, disks, base_index, end_index):
    """Checks that the disks are named "disk/N" for consecutive indices.

    """
    self.assertEqual([disk.iv_name for disk in disks],
                     ["disk/%s" % i for i in range(base_index, end_index)])

  def testPlain(self):
    disk_info = [{
      constants.IDISK_SIZE: 1024,
      constants.IDISK_MODE: constants.DISK_RDWR,
      }, {
      constants.IDISK_SIZE: 4096,
      constants.IDISK_VG: "othervg",
      constants.IDISK_MODE: constants.DISK_RDWR,
      }]

    result = self._TestTrivialDisk(constants.DT_PLAIN, disk_info, 3,
                                   constants.LD_LV)

    self.assertEqual([disk.logical_id for disk in result], [
      ("testvg", "ec0-uq0.disk3"),
      ("othervg", "ec0-uq1.disk4"),
      ])

  @staticmethod
  def _AllowFileStorage():
    """File storage check which always succeeds.

    """
    pass

  @staticmethod
  def _ForbidFileStorage():
    """File storage check which always fails.

    """
    raise errors.OpPrereqError("Disallowed in test")

  def testFile(self):
    # A failing storage check must abort generation for both templates
    self.assertRaises(errors.OpPrereqError, self._TestTrivialDisk,
                      constants.DT_FILE, [], 0, NotImplemented,
                      req_file_storage=self._ForbidFileStorage)
    self.assertRaises(errors.OpPrereqError, self._TestTrivialDisk,
                      constants.DT_SHARED_FILE, [], 0, NotImplemented,
                      req_shr_file_storage=self._ForbidFileStorage)

    for disk_template in [constants.DT_FILE, constants.DT_SHARED_FILE]:
      disk_info = [{
        constants.IDISK_SIZE: 80 * 1024,
        constants.IDISK_MODE: constants.DISK_RDONLY,
        }, {
        constants.IDISK_SIZE: 4096,
        constants.IDISK_MODE: constants.DISK_RDWR,
        }, {
        constants.IDISK_SIZE: 6 * 1024,
        constants.IDISK_MODE: constants.DISK_RDWR,
        }]

      result = self._TestTrivialDisk(disk_template, disk_info, 2,
        constants.LD_FILE, file_storage_dir="/tmp",
        file_driver=constants.FD_BLKTAP,
        req_file_storage=self._AllowFileStorage,
        req_shr_file_storage=self._AllowFileStorage)

      self.assertEqual([disk.logical_id for disk in result], [
        (constants.FD_BLKTAP, "/tmp/disk2"),
        (constants.FD_BLKTAP, "/tmp/disk3"),
        (constants.FD_BLKTAP, "/tmp/disk4"),
        ])

  def testBlock(self):
    disk_info = [{
      constants.IDISK_SIZE: 8 * 1024,
      constants.IDISK_MODE: constants.DISK_RDWR,
      constants.IDISK_ADOPT: "/tmp/some/block/dev",
      }]

    result = self._TestTrivialDisk(constants.DT_BLOCK, disk_info, 10,
                                   constants.LD_BLOCKDEV)

    self.assertEqual([disk.logical_id for disk in result], [
      (constants.BLOCKDEV_DRIVER_MANUAL, "/tmp/some/block/dev"),
      ])

  def testRbd(self):
    disk_info = [{
      constants.IDISK_SIZE: 8 * 1024,
      constants.IDISK_MODE: constants.DISK_RDONLY,
      }, {
      constants.IDISK_SIZE: 100 * 1024,
      constants.IDISK_MODE: constants.DISK_RDWR,
      }]

    result = self._TestTrivialDisk(constants.DT_RBD, disk_info, 0,
                                   constants.LD_RBD)

    self.assertEqual([disk.logical_id for disk in result], [
      ("rbd", "ec0-uq0.rbd.disk0"),
      ("rbd", "ec0-uq1.rbd.disk1"),
      ])

  def testDrbd8(self):
    gdt = cmdlib._GenerateDiskTemplate
    drbd8_defaults = constants.DISK_LD_DEFAULTS[constants.LD_DRBD8]
    drbd8_default_metavg = drbd8_defaults[constants.LDP_DEFAULT_METAVG]

    disk_info = [{
      constants.IDISK_SIZE: 1024,
      constants.IDISK_MODE: constants.DISK_RDWR,
      }, {
      constants.IDISK_SIZE: 100 * 1024,
      constants.IDISK_MODE: constants.DISK_RDONLY,
      constants.IDISK_METAVG: "metavg",
      }, {
      constants.IDISK_SIZE: 4096,
      constants.IDISK_MODE: constants.DISK_RDWR,
      constants.IDISK_VG: "vgxyz",
      },
      ]

    # Expected logical IDs of the (data, meta) children of each disk
    exp_logical_ids = [[
      (self.lu.cfg.GetVGName(), "ec0-uq0.disk0_data"),
      (drbd8_default_metavg, "ec0-uq0.disk0_meta"),
      ], [
      (self.lu.cfg.GetVGName(), "ec0-uq1.disk1_data"),
      ("metavg", "ec0-uq1.disk1_meta"),
      ], [
      ("vgxyz", "ec0-uq2.disk2_data"),
      (drbd8_default_metavg, "ec0-uq2.disk2_meta"),
      ]]

    assert len(exp_logical_ids) == len(disk_info)

    # Explicit loop instead of map(); the call is made for its side effects
    for params in disk_info:
      utils.ForceDictType(params, constants.IDISK_PARAMS_TYPES)

    # Check if empty list of secondaries is rejected
    self.assertRaises(errors.ProgrammerError, gdt, self.lu, constants.DT_DRBD8,
                      "inst827.example.com", "node1334.example.com", [],
                      disk_info, NotImplemented, NotImplemented, 0,
                      self.lu.LogInfo, self.GetDiskParams())

    result = gdt(self.lu, constants.DT_DRBD8, "inst827.example.com",
                 "node1334.example.com", ["node12272.example.com"],
                 disk_info, NotImplemented, NotImplemented, 0, self.lu.LogInfo,
                 self.GetDiskParams())

    for (idx, disk) in enumerate(result):
      self.assertTrue(isinstance(disk, objects.Disk))
      self.assertEqual(disk.dev_type, constants.LD_DRBD8)
      self.assertEqual(disk.size, disk_info[idx][constants.IDISK_SIZE])
      self.assertEqual(disk.mode, disk_info[idx][constants.IDISK_MODE])

      for child in disk.children:
        # Verify the child itself (the parent was already checked above)
        self.assertTrue(isinstance(child, objects.Disk))
        self.assertEqual(child.dev_type, constants.LD_LV)
        self.assertTrue(child.children is None)

      self.assertEqual([child.logical_id for child in disk.children],
                       exp_logical_ids[idx])

      self.assertEqual(len(disk.children), 2)
      self.assertEqual(disk.children[0].size, disk.size)
      self.assertEqual(disk.children[1].size, constants.DRBD_META_SIZE)

    # Names must be correct before and after running the update function
    self._CheckIvNames(result, 0, len(disk_info))
    cmdlib._UpdateIvNames(0, result)
    self._CheckIvNames(result, 0, len(disk_info))

    # Ports, minors and secrets come from the fake configuration's counters
    self.assertEqual([drbd.logical_id for drbd in result], [
      ("node1334.example.com", "node12272.example.com",
       constants.FIRST_DRBD_PORT, 20, 21, "ec0-secret0"),
      ("node1334.example.com", "node12272.example.com",
       constants.FIRST_DRBD_PORT + 1, 22, 23, "ec0-secret1"),
      ("node1334.example.com", "node12272.example.com",
       constants.FIRST_DRBD_PORT + 2, 24, 25, "ec0-secret2"),
      ])
1388
1389
class _ConfigForDiskWipe:
  """Configuration fake which only validates C{SetDiskID} calls.

  """
  def __init__(self, exp_node):
    self._exp_node = exp_node

  def SetDiskID(self, device, node):
    """Asserts that a disk object and the expected node were passed.

    """
    # Type of the device first, then the node it is requested for
    assert isinstance(device, objects.Disk)
    assert node == self._exp_node
1397
1398
class _RpcForDiskWipe:
  """RPC fake forwarding pause/resume and wipe calls to callbacks.

  Both calls assert that they are made for the expected node and wrap the
  callback's return value in an L{rpc.RpcResult}.

  """
  def __init__(self, exp_node, pause_cb, wipe_cb):
    self._exp_node = exp_node
    self._pause_cb = pause_cb
    self._wipe_cb = wipe_cb

  def call_blockdev_pause_resume_sync(self, node, disks, pause):
    assert node == self._exp_node
    payload = self._pause_cb(disks, pause)
    return rpc.RpcResult(data=payload)

  def call_blockdev_wipe(self, node, bdev, offset, size):
    assert node == self._exp_node
    payload = self._wipe_cb(bdev, offset, size)
    return rpc.RpcResult(data=payload)
1412
1413
class _DiskPauseTracker:
  """Pause/resume callback recording every request it receives.

  @ivar history: list of C{(logical_id, size, pause)} tuples, one per disk
    and call, in call order

  """
  def __init__(self):
    self.history = []

  def __call__(self, disks_and_instance, pause):
    """Records the request and reports success for every disk.

    @param disks_and_instance: tuple of (list of disks, instance)
    @param pause: flag stored alongside each recorded disk
    @return: C{(True, [True, ...])} with one entry per disk

    """
    # Unpacked explicitly; tuple parameters in the signature were removed
    # from the language by PEP 3113
    (disks, instance) = disks_and_instance

    # Only disks belonging to the instance may be passed
    assert not (set(disks) - set(instance.disks))

    self.history.extend((i.logical_id, i.size, pause)
                        for i in disks)

    return (True, [True] * len(disks))
1425
1426
class _DiskWipeProgressTracker:
  """Wipe callback verifying each chunk and tracking progress per disk.

  @ivar progress: dictionary mapping a disk's logical ID to the offset up to
    which wipe requests were received

  """
  def __init__(self, start_offset):
    self._start_offset = start_offset
    self.progress = {}

  def __call__(self, disk_and_instance, offset, size):
    """Checks a single wipe request and records it.

    @param disk_and_instance: tuple of (disk, instance); only the disk is used
    @param offset: offset of the chunk to wipe
    @param size: size of the chunk to wipe
    @return: C{(True, None)}

    """
    # Unpacked explicitly; tuple parameters in the signature were removed
    # from the language by PEP 3113
    (disk, _) = disk_and_instance

    assert isinstance(offset, (long, int))
    assert isinstance(size, (long, int))

    max_chunk_size = (disk.size / 100.0 * constants.MIN_WIPE_CHUNK_PERCENT)

    # The chunk must lie between the start offset and the end of the disk
    assert offset >= self._start_offset
    assert (offset + size) <= disk.size

    assert size > 0
    assert size <= constants.MAX_WIPE_CHUNK
    assert size <= max_chunk_size

    # Only the very first chunk of a disk may be for a disk not yet seen
    assert offset == self._start_offset or disk.logical_id in self.progress

    # Keep track of progress
    cur_progress = self.progress.setdefault(disk.logical_id, self._start_offset)

    # Chunks must arrive back to back, without gaps or overlaps
    assert cur_progress == offset

    # Record progress
    self.progress[disk.logical_id] += size

    return (True, None)
1456
1457
1458 class TestWipeDisks(unittest.TestCase):
1459   def _FailingPauseCb(self, (disks, _), pause):
1460     self.assertEqual(len(disks), 3)
1461     self.assertTrue(pause)
1462     # Simulate an RPC error
1463     return (False, "error")
1464
1465   def testPauseFailure(self):
1466     node_name = "node1372.example.com"
1467
1468     lu = _FakeLU(rpc=_RpcForDiskWipe(node_name, self._FailingPauseCb,
1469                                      NotImplemented),
1470                  cfg=_ConfigForDiskWipe(node_name))
1471
1472     disks = [
1473       objects.Disk(dev_type=constants.LD_LV),
1474       objects.Disk(dev_type=constants.LD_LV),
1475       objects.Disk(dev_type=constants.LD_LV),
1476       ]
1477
1478     instance = objects.Instance(name="inst21201",
1479                                 primary_node=node_name,
1480                                 disk_template=constants.DT_PLAIN,
1481                                 disks=disks)
1482
1483     self.assertRaises(errors.OpExecError, cmdlib._WipeDisks, lu, instance)
1484
1485   def _FailingWipeCb(self, (disk, _), offset, size):
1486     # This should only ever be called for the first disk
1487     self.assertEqual(disk.logical_id, "disk0")
1488     return (False, None)
1489
1490   def testFailingWipe(self):
1491     node_name = "node13445.example.com"
1492     pt = _DiskPauseTracker()
1493
1494     lu = _FakeLU(rpc=_RpcForDiskWipe(node_name, pt, self._FailingWipeCb),
1495                  cfg=_ConfigForDiskWipe(node_name))
1496
1497     disks = [
1498       objects.Disk(dev_type=constants.LD_LV, logical_id="disk0",
1499                    size=100 * 1024),
1500       objects.Disk(dev_type=constants.LD_LV, logical_id="disk1",
1501                    size=500 * 1024),
1502       objects.Disk(dev_type=constants.LD_LV, logical_id="disk2", size=256),
1503       ]
1504
1505     instance = objects.Instance(name="inst562",
1506                                 primary_node=node_name,
1507                                 disk_template=constants.DT_PLAIN,
1508                                 disks=disks)
1509
1510     try:
1511       cmdlib._WipeDisks(lu, instance)
1512     except errors.OpExecError, err:
1513       self.assertTrue(str(err), "Could not wipe disk 0 at offset 0 ")
1514     else:
1515       self.fail("Did not raise exception")
1516
1517     # Check if all disks were paused and resumed
1518     self.assertEqual(pt.history, [
1519       ("disk0", 100 * 1024, True),
1520       ("disk1", 500 * 1024, True),
1521       ("disk2", 256, True),
1522       ("disk0", 100 * 1024, False),
1523       ("disk1", 500 * 1024, False),
1524       ("disk2", 256, False),
1525       ])
1526
1527   def _PrepareWipeTest(self, start_offset, disks):
1528     node_name = "node-with-offset%s.example.com" % start_offset
1529     pauset = _DiskPauseTracker()
1530     progresst = _DiskWipeProgressTracker(start_offset)
1531
1532     lu = _FakeLU(rpc=_RpcForDiskWipe(node_name, pauset, progresst),
1533                  cfg=_ConfigForDiskWipe(node_name))
1534
1535     instance = objects.Instance(name="inst3560",
1536                                 primary_node=node_name,
1537                                 disk_template=constants.DT_PLAIN,
1538                                 disks=disks)
1539
1540     return (lu, instance, pauset, progresst)
1541
1542   def testNormalWipe(self):
1543     disks = [
1544       objects.Disk(dev_type=constants.LD_LV, logical_id="disk0", size=1024),
1545       objects.Disk(dev_type=constants.LD_LV, logical_id="disk1",
1546                    size=500 * 1024),
1547       objects.Disk(dev_type=constants.LD_LV, logical_id="disk2", size=128),
1548       objects.Disk(dev_type=constants.LD_LV, logical_id="disk3",
1549                    size=constants.MAX_WIPE_CHUNK),
1550       ]
1551
1552     (lu, instance, pauset, progresst) = self._PrepareWipeTest(0, disks)
1553
1554     cmdlib._WipeDisks(lu, instance)
1555
1556     self.assertEqual(pauset.history, [
1557       ("disk0", 1024, True),
1558       ("disk1", 500 * 1024, True),
1559       ("disk2", 128, True),
1560       ("disk3", constants.MAX_WIPE_CHUNK, True),
1561       ("disk0", 1024, False),
1562       ("disk1", 500 * 1024, False),
1563       ("disk2", 128, False),
1564       ("disk3", constants.MAX_WIPE_CHUNK, False),
1565       ])
1566
1567     # Ensure the complete disk has been wiped
1568     self.assertEqual(progresst.progress,
1569                      dict((i.logical_id, i.size) for i in disks))
1570
1571   def testWipeWithStartOffset(self):
1572     for start_offset in [0, 280, 8895, 1563204]:
1573       disks = [
1574         objects.Disk(dev_type=constants.LD_LV, logical_id="disk0",
1575                      size=128),
1576         objects.Disk(dev_type=constants.LD_LV, logical_id="disk1",
1577                      size=start_offset + (100 * 1024)),
1578         ]
1579
1580       (lu, instance, pauset, progresst) = \
1581         self._PrepareWipeTest(start_offset, disks)
1582
1583       # Test start offset with only one disk
1584       cmdlib._WipeDisks(lu, instance,
1585                         disks=[(1, disks[1], start_offset)])
1586
1587       # Only the second disk may have been paused and wiped
1588       self.assertEqual(pauset.history, [
1589         ("disk1", start_offset + (100 * 1024), True),
1590         ("disk1", start_offset + (100 * 1024), False),
1591         ])
1592       self.assertEqual(progresst.progress, {
1593         "disk1": disks[1].size,
1594         })
1595
1596
class TestDiskSizeInBytesToMebibytes(unittest.TestCase):
  """Tests for C{cmdlib._DiskSizeInBytesToMebibytes}.

  """
  def _CheckRoundedUp(self, size, exp_result, exp_warnsize, msg=None):
    """Converts C{size} and checks the result and the single warning.

    """
    lu = _FakeLU()
    result = cmdlib._DiskSizeInBytesToMebibytes(lu, size)
    self.assertEqual(result, exp_result, msg=msg)
    # Exactly one warning, made of a message and a one-element argument tuple
    self.assertEqual(len(lu.warning_log), 1)
    self.assertEqual(len(lu.warning_log[0]), 2)
    (_, (warnsize, )) = lu.warning_log[0]
    self.assertEqual(warnsize, exp_warnsize)

  def testLessThanOneMebibyte(self):
    for num_bytes in [1, 2, 7, 512, 1000, 1023]:
      self._CheckRoundedUp(num_bytes, 1, (1024 * 1024) - num_bytes)

  def testEven(self):
    # Exact multiples are converted without any warning
    for mebibytes in [1, 2, 7, 512, 1000, 1023]:
      lu = _FakeLU()
      result = cmdlib._DiskSizeInBytesToMebibytes(lu, mebibytes * 1024 * 1024)
      self.assertEqual(result, mebibytes)
      self.assertFalse(lu.warning_log)

  def testLargeNumber(self):
    for mebibytes in [1, 2, 7, 512, 1000, 1023, 2724, 12420]:
      for extra in [1, 2, 486, 326, 986, 1023]:
        size = (1024 * 1024 * mebibytes) + extra
        self._CheckRoundedUp(size, mebibytes + 1, (1024 * 1024) - extra,
                             msg="Amount was not rounded up")
1626
1627
class TestCopyLockList(unittest.TestCase):
  """Tests for C{cmdlib._CopyLockList}.

  """
  def test(self):
    # Values which must compare equal to the input after the call
    self.assertEqual(cmdlib._CopyLockList([]), [])
    self.assertEqual(cmdlib._CopyLockList(None), None)
    self.assertEqual(cmdlib._CopyLockList(locking.ALL_SET), locking.ALL_SET)

    # A list of names must come back equal, but as a different object
    names = ["foo", "bar"]
    copied = cmdlib._CopyLockList(names)
    self.assertEqual(names, copied)
    self.assertNotEqual(id(names), id(copied), msg="List was not copied")
1638
1639
class TestCheckOpportunisticLocking(unittest.TestCase):
  """Tests for cmdlib._CheckOpportunisticLocking()"""
  class OpTest(opcodes.OpCode):
    """Opcode declaring only the two parameters used by these tests"""
    OP_PARAMS = [
      opcodes._POpportunisticLocking,
      opcodes._PIAllocFromDesc(""),
      ]

  @classmethod
  def _MakeOp(cls, **kwargs):
    """Builds and validates a test opcode from the given parameters"""
    opcode = cls.OpTest(**kwargs)
    opcode.Validate(True)
    return opcode

  def testMissingAttributes(self):
    """A plain object instead of an opcode must raise AttributeError"""
    self.assertRaises(AttributeError, cmdlib._CheckOpportunisticLocking,
                      object())

  def testDefaults(self):
    """An opcode with default parameter values must be accepted"""
    cmdlib._CheckOpportunisticLocking(self._MakeOp())

  def test(self):
    """Opportunistic locking without an iallocator must be rejected"""
    for alloc_name in [None, "something", "other"]:
      for use_opplock in [False, True]:
        op = self._MakeOp(iallocator=alloc_name,
                          opportunistic_locking=use_opplock)
        if alloc_name or not use_opplock:
          # Every other combination must pass the check
          cmdlib._CheckOpportunisticLocking(op)
        else:
          self.assertRaises(errors.OpPrereqError,
                            cmdlib._CheckOpportunisticLocking, op)
1670
1671
class _OpTestVerifyErrors(opcodes.OpCode):
  """Opcode declaring the parameters used by L{_LuTestVerifyErrors}"""
  OP_PARAMS = [
    opcodes._PDebugSimulateErrors,
    opcodes._PErrorCodes,
    opcodes._PIgnoreErrors,
    ]
1678
1679
class _LuTestVerifyErrors(cluster._VerifyErrors):
  """Test harness around cluster._VerifyErrors

  Feedback messages are collected in C{msglist} instead of being sent
  anywhere, and the C{bad} flag starts out as False.

  """
  def __init__(self, **kwargs):
    """Sets up the harness; keyword arguments become opcode parameters"""
    cluster._VerifyErrors.__init__(self)
    self.bad = False
    self.msglist = []
    self._feedback_fn = self.msglist.append
    self.op = _OpTestVerifyErrors(**kwargs)
    self.op.Validate(True)

  def DispatchCallError(self, which, *args, **kwargs):
    """Reports an error via _Error() or via _ErrorIf(True, ...)

    @param which: if true _Error() is used, otherwise _ErrorIf()

    """
    if not which:
      self._ErrorIf(True, *args, **kwargs)
      return
    self._Error(*args, **kwargs)

  def CallErrorIf(self, c, *args, **kwargs):
    """Forwards the condition and all other arguments to _ErrorIf()"""
    self._ErrorIf(c, *args, **kwargs)
1697
1698
class TestVerifyErrors(unittest.TestCase):
  """Tests for _Error() and _ErrorIf() of cluster._VerifyErrors

  Most tests report errors through both methods on two separate harness
  objects and require them to produce the same flag and messages.

  """
  # Fake cluster-verify error code structures; we use two arbitrary real error
  # codes to pass validation of ignore_errors
  (_, _ERR1ID, _) = constants.CV_ECLUSTERCFG
  _NODESTR = "node"
  _NODENAME = "mynode"
  _ERR1CODE = (_NODESTR, _ERR1ID, "Error one")
  (_, _ERR2ID, _) = constants.CV_ECLUSTERCERT
  _INSTSTR = "instance"
  _INSTNAME = "myinstance"
  _ERR2CODE = (_INSTSTR, _ERR2ID, "Error two")
  # Arguments used to call _Error() or _ErrorIf(); the error code structure
  # is always the first element
  _ERR1ARGS = (_ERR1CODE, _NODENAME, "Error1 is %s", "an error")
  _ERR2ARGS = (_ERR2CODE, _INSTNAME, "Error2 has no argument")
  # Expected error messages
  _ERR1MSG = _ERR1ARGS[2] % _ERR1ARGS[3]
  _ERR2MSG = _ERR2ARGS[2]

  def testNoError(self):
    """A false condition must neither set the flag nor log a message"""
    lu = _LuTestVerifyErrors()
    # _ERR1ARGS already starts with the error code, hence it must not be
    # passed a second time
    lu.CallErrorIf(False, *self._ERR1ARGS)
    self.assertFalse(lu.bad)
    self.assertFalse(lu.msglist)

  def _InitTest(self, **kwargs):
    """Creates the two harness objects compared by L{_CallError}"""
    self.lu1 = _LuTestVerifyErrors(**kwargs)
    self.lu2 = _LuTestVerifyErrors(**kwargs)

  def _CallError(self, *args, **kwargs):
    """Reports the same error through _Error() and _ErrorIf()

    @return: the harness object that used _Error()

    """
    # Check that _Error() and _ErrorIf() produce the same results
    self.lu1.DispatchCallError(True, *args, **kwargs)
    self.lu2.DispatchCallError(False, *args, **kwargs)
    self.assertEqual(self.lu1.bad, self.lu2.bad)
    self.assertEqual(self.lu1.msglist, self.lu2.msglist)
    # Test-specific checks are made on one LU
    return self.lu1

  def _checkMsgCommon(self, logstr, errmsg, itype, item, warning):
    """Checks that a logged line contains all expected fragments

    @param logstr: the logged message
    @param errmsg: the formatted error text
    @param itype: the item type string
    @param item: the item name
    @param warning: whether "WARNING" instead of "ERROR" is expected

    """
    self.assertTrue(errmsg in logstr)
    if warning:
      self.assertTrue("WARNING" in logstr)
    else:
      self.assertTrue("ERROR" in logstr)
    self.assertTrue(itype in logstr)
    self.assertTrue(item in logstr)

  def _checkMsg1(self, logstr, warning=False):
    """Checks a logged line against the first test error"""
    self._checkMsgCommon(logstr, self._ERR1MSG, self._NODESTR,
                         self._NODENAME, warning)

  def _checkMsg2(self, logstr, warning=False):
    """Checks a logged line against the second test error"""
    self._checkMsgCommon(logstr, self._ERR2MSG, self._INSTSTR,
                         self._INSTNAME, warning)

  def testPlain(self):
    """A single error sets the flag and logs one error message"""
    self._InitTest()
    lu = self._CallError(*self._ERR1ARGS)
    self.assertTrue(lu.bad)
    self.assertEqual(len(lu.msglist), 1)
    self._checkMsg1(lu.msglist[0])

  def testMultiple(self):
    """Two errors are logged in the order they were reported"""
    self._InitTest()
    self._CallError(*self._ERR1ARGS)
    lu = self._CallError(*self._ERR2ARGS)
    self.assertTrue(lu.bad)
    self.assertEqual(len(lu.msglist), 2)
    self._checkMsg1(lu.msglist[0])
    self._checkMsg2(lu.msglist[1])

  def testIgnore(self):
    """An ignored error is logged as warning and leaves the flag unset"""
    self._InitTest(ignore_errors=[self._ERR1ID])
    lu = self._CallError(*self._ERR1ARGS)
    self.assertFalse(lu.bad)
    self.assertEqual(len(lu.msglist), 1)
    self._checkMsg1(lu.msglist[0], warning=True)

  def testWarning(self):
    """An explicit warning is logged but leaves the flag unset"""
    self._InitTest()
    lu = self._CallError(*self._ERR1ARGS,
                         code=_LuTestVerifyErrors.ETYPE_WARNING)
    self.assertFalse(lu.bad)
    self.assertEqual(len(lu.msglist), 1)
    self._checkMsg1(lu.msglist[0], warning=True)

  def testWarning2(self):
    """A warning reported after an error must not reset the flag"""
    self._InitTest()
    self._CallError(*self._ERR1ARGS)
    lu = self._CallError(*self._ERR2ARGS,
                         code=_LuTestVerifyErrors.ETYPE_WARNING)
    self.assertTrue(lu.bad)
    self.assertEqual(len(lu.msglist), 2)
    self._checkMsg1(lu.msglist[0])
    self._checkMsg2(lu.msglist[1], warning=True)

  def testDebugSimulate(self):
    """With debug_simulate_errors even a false condition is reported"""
    lu = _LuTestVerifyErrors(debug_simulate_errors=True)
    lu.CallErrorIf(False, *self._ERR1ARGS)
    self.assertTrue(lu.bad)
    self.assertEqual(len(lu.msglist), 1)
    self._checkMsg1(lu.msglist[0])

  def testErrCodes(self):
    """With error_codes the logged message also contains the error ID"""
    self._InitTest(error_codes=True)
    lu = self._CallError(*self._ERR1ARGS)
    self.assertTrue(lu.bad)
    self.assertEqual(len(lu.msglist), 1)
    self._checkMsg1(lu.msglist[0])
    self.assertTrue(self._ERR1ID in lu.msglist[0])
1808
1809
class TestGetUpdatedIPolicy(unittest.TestCase):
  """Tests for common._GetUpdatedIPolicy()"""
  # Cluster-level policy used as the "old" policy; it carries std specs
  _OLD_CLUSTER_POLICY = {
    constants.IPOLICY_VCPU_RATIO: 1.5,
    constants.ISPECS_MINMAX: [
      {
        constants.ISPECS_MIN: {
          constants.ISPEC_MEM_SIZE: 32768,
          constants.ISPEC_CPU_COUNT: 8,
          constants.ISPEC_DISK_COUNT: 1,
          constants.ISPEC_DISK_SIZE: 1024,
          constants.ISPEC_NIC_COUNT: 1,
          constants.ISPEC_SPINDLE_USE: 1,
          },
        constants.ISPECS_MAX: {
          constants.ISPEC_MEM_SIZE: 65536,
          constants.ISPEC_CPU_COUNT: 10,
          constants.ISPEC_DISK_COUNT: 5,
          constants.ISPEC_DISK_SIZE: 1024 * 1024,
          constants.ISPEC_NIC_COUNT: 3,
          constants.ISPEC_SPINDLE_USE: 12,
          },
        },
      constants.ISPECS_MINMAX_DEFAULTS,
      ],
    constants.ISPECS_STD: constants.IPOLICY_DEFAULTS[constants.ISPECS_STD],
    }
  # Group-level policy used as the "old" policy; it has no std specs
  _OLD_GROUP_POLICY = {
    constants.IPOLICY_SPINDLE_RATIO: 2.5,
    constants.ISPECS_MINMAX: [{
      constants.ISPECS_MIN: {
        constants.ISPEC_MEM_SIZE: 128,
        constants.ISPEC_CPU_COUNT: 1,
        constants.ISPEC_DISK_COUNT: 1,
        constants.ISPEC_DISK_SIZE: 1024,
        constants.ISPEC_NIC_COUNT: 1,
        constants.ISPEC_SPINDLE_USE: 1,
        },
      constants.ISPECS_MAX: {
        constants.ISPEC_MEM_SIZE: 32768,
        constants.ISPEC_CPU_COUNT: 8,
        constants.ISPEC_DISK_COUNT: 5,
        constants.ISPEC_DISK_SIZE: 1024 * 1024,
        constants.ISPEC_NIC_COUNT: 3,
        constants.ISPEC_SPINDLE_USE: 12,
        },
      }],
    }

  def _AssertOtherKeysKept(self, old_policy, diff_policy, new_policy):
    """Checks that keys not named in the diff were carried over unchanged

    @param old_policy: the policy before the update
    @param diff_policy: the changes that were applied
    @param new_policy: the policy returned by the update

    """
    for key in old_policy:
      if key not in diff_policy:
        self.assertTrue(key in new_policy)
        self.assertEqual(new_policy[key], old_policy[key])

  def _TestSetSpecs(self, old_policy, isgroup):
    """Checks updating min/max specs and, for clusters, the std specs

    @param old_policy: the policy to start from
    @param isgroup: whether to update as a group policy; std specs are only
      included in the diff when this is false

    """
    diff_minmax = [{
      constants.ISPECS_MIN: {
        constants.ISPEC_MEM_SIZE: 64,
        constants.ISPEC_CPU_COUNT: 1,
        constants.ISPEC_DISK_COUNT: 2,
        constants.ISPEC_DISK_SIZE: 64,
        constants.ISPEC_NIC_COUNT: 1,
        constants.ISPEC_SPINDLE_USE: 1,
        },
      constants.ISPECS_MAX: {
        constants.ISPEC_MEM_SIZE: 16384,
        constants.ISPEC_CPU_COUNT: 10,
        constants.ISPEC_DISK_COUNT: 12,
        constants.ISPEC_DISK_SIZE: 1024,
        constants.ISPEC_NIC_COUNT: 9,
        constants.ISPEC_SPINDLE_USE: 18,
        },
      }]
    diff_std = {
      constants.ISPEC_DISK_COUNT: 10,
      constants.ISPEC_DISK_SIZE: 512,
      }
    diff_policy = {
      constants.ISPECS_MINMAX: diff_minmax
      }
    if not isgroup:
      diff_policy[constants.ISPECS_STD] = diff_std
    new_policy = common._GetUpdatedIPolicy(old_policy, diff_policy,
                                           group_policy=isgroup)

    # The min/max specs are replaced as a whole
    self.assertTrue(constants.ISPECS_MINMAX in new_policy)
    self.assertEqual(new_policy[constants.ISPECS_MINMAX], diff_minmax)
    self._AssertOtherKeysKept(old_policy, diff_policy, new_policy)

    if not isgroup:
      # The std specs are merged per parameter: values from the diff win,
      # all other old values are kept
      new_std = new_policy[constants.ISPECS_STD]
      for key in diff_std:
        self.assertTrue(key in new_std)
        self.assertEqual(new_std[key], diff_std[key])
      old_std = old_policy.get(constants.ISPECS_STD, {})
      for key in old_std:
        self.assertTrue(key in new_std)
        if key not in diff_std:
          self.assertEqual(new_std[key], old_std[key])

  def _TestSet(self, old_policy, diff_policy, isgroup):
    """Checks that all keys of the diff end up in the updated policy

    @param old_policy: the policy to start from
    @param diff_policy: the changes to apply
    @param isgroup: whether to update as a group policy

    """
    new_policy = common._GetUpdatedIPolicy(old_policy, diff_policy,
                                           group_policy=isgroup)
    for key in diff_policy:
      self.assertTrue(key in new_policy)
      self.assertEqual(new_policy[key], diff_policy[key])
    self._AssertOtherKeysKept(old_policy, diff_policy, new_policy)

  def testSet(self):
    """Sets plain parameters and specs on group and cluster policies"""
    diff_policy = {
      constants.IPOLICY_VCPU_RATIO: 3,
      constants.IPOLICY_DTS: [constants.DT_FILE],
      }
    self._TestSet(self._OLD_GROUP_POLICY, diff_policy, True)
    self._TestSetSpecs(self._OLD_GROUP_POLICY, True)
    self._TestSet({}, diff_policy, True)
    self._TestSetSpecs({}, True)
    self._TestSet(self._OLD_CLUSTER_POLICY, diff_policy, False)
    self._TestSetSpecs(self._OLD_CLUSTER_POLICY, False)

  def testUnset(self):
    """Resetting a key to the default removes it from a group policy"""
    old_policy = self._OLD_GROUP_POLICY
    diff_policy = {
      constants.IPOLICY_SPINDLE_RATIO: constants.VALUE_DEFAULT,
      }
    new_policy = common._GetUpdatedIPolicy(old_policy, diff_policy,
                                           group_policy=True)
    for key in diff_policy:
      self.assertFalse(key in new_policy)
    self._AssertOtherKeysKept(old_policy, diff_policy, new_policy)

    # The same diff must be refused when not updating a group policy
    self.assertRaises(errors.OpPrereqError, common._GetUpdatedIPolicy,
                      old_policy, diff_policy, group_policy=False)

  def testUnsetEmpty(self):
    """Resetting any key of an empty group policy leaves it empty"""
    old_policy = {}
    for key in constants.IPOLICY_ALL_KEYS:
      diff_policy = {
        key: constants.VALUE_DEFAULT,
        }
      # The update must be checked for every key, not only for the last one
      new_policy = common._GetUpdatedIPolicy(old_policy, diff_policy,
                                             group_policy=True)
      self.assertEqual(new_policy, old_policy)

  def _TestInvalidKeys(self, old_policy, isgroup):
    """Checks that unknown keys and badly typed spec values are rejected

    @param old_policy: the policy to start from
    @param isgroup: whether to update as a group policy

    """
    INVALID_KEY = "this_key_shouldnt_be_allowed"
    INVALID_DICT = {
      INVALID_KEY: 3,
      }
    # Unknown key at the top level of the policy
    invalid_policy = INVALID_DICT
    self.assertRaises(errors.OpPrereqError, common._GetUpdatedIPolicy,
                      old_policy, invalid_policy, group_policy=isgroup)
    # Unknown key inside the list of min/max specs
    invalid_ispecs = {
      constants.ISPECS_MINMAX: [INVALID_DICT],
      }
    self.assertRaises(errors.TypeEnforcementError, common._GetUpdatedIPolicy,
                      old_policy, invalid_ispecs, group_policy=isgroup)
    if isgroup:
      # Std specs must not be accepted in a group policy
      invalid_for_group = {
        constants.ISPECS_STD: constants.IPOLICY_DEFAULTS[constants.ISPECS_STD],
        }
      self.assertRaises(errors.OpPrereqError, common._GetUpdatedIPolicy,
                        old_policy, invalid_for_group, group_policy=isgroup)
    # Break a copy of known-good min/max specs in one place at a time and
    # restore it after each check
    good_ispecs = self._OLD_CLUSTER_POLICY[constants.ISPECS_MINMAX]
    invalid_ispecs = copy.deepcopy(good_ispecs)
    invalid_policy = {
      constants.ISPECS_MINMAX: invalid_ispecs,
      }
    for minmax in invalid_ispecs:
      for key in constants.ISPECS_MINMAX_KEYS:
        ispec = minmax[key]
        ispec[INVALID_KEY] = None
        self.assertRaises(errors.TypeEnforcementError,
                          common._GetUpdatedIPolicy, old_policy,
                          invalid_policy, group_policy=isgroup)
        del ispec[INVALID_KEY]
        for par in constants.ISPECS_PARAMETERS:
          oldv = ispec[par]
          ispec[par] = "this_is_not_good"
          self.assertRaises(errors.TypeEnforcementError,
                            common._GetUpdatedIPolicy,
                            old_policy, invalid_policy, group_policy=isgroup)
          ispec[par] = oldv
    # This is to make sure that no two errors were present during the tests:
    # with every modification reverted the update must succeed
    common._GetUpdatedIPolicy(old_policy, invalid_policy,
                              group_policy=isgroup)

  def testInvalidKeys(self):
    """Runs the invalid key checks for group and cluster policies"""
    self._TestInvalidKeys(self._OLD_GROUP_POLICY, True)
    self._TestInvalidKeys(self._OLD_CLUSTER_POLICY, False)

  def testInvalidValues(self):
    """A string value for any plain parameter or disk templates is refused"""
    for par in (constants.IPOLICY_PARAMETERS |
                frozenset([constants.IPOLICY_DTS])):
      bad_policy = {
        par: "invalid_value",
        }
      self.assertRaises(errors.OpPrereqError, common._GetUpdatedIPolicy, {},
                        bad_policy, group_policy=True)
2013
# Hand over to the ganeti test program when this file is run as a script
if __name__ == "__main__":
  testutils.GanetiTestProgram()