Add unit tests for cmdlib._WipeDisks
[ganeti-local] / test / ganeti.cmdlib_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2008, 2011, 2012 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the cmdlib module"""
23
24
25 import os
26 import unittest
27 import time
28 import tempfile
29 import shutil
30 import operator
31 import itertools
32 import copy
33
34 from ganeti import constants
35 from ganeti import mcpu
36 from ganeti import cmdlib
37 from ganeti import opcodes
38 from ganeti import errors
39 from ganeti import utils
40 from ganeti import luxi
41 from ganeti import ht
42 from ganeti import objects
43 from ganeti import compat
44 from ganeti import rpc
45 from ganeti import pathutils
46 from ganeti.masterd import iallocator
47 from ganeti.hypervisor import hv_xen
48
49 import testutils
50 import mocks
51
52
class TestCertVerification(testutils.GanetiTestCase):
  """Tests for L{cmdlib._VerifyCertificate}."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _AssertVerifyError(self, filename):
    """Checks that verifying C{filename} yields an C{ETYPE_ERROR} code."""
    (errcode, _) = cmdlib._VerifyCertificate(filename)
    self.assertEqual(errcode, cmdlib.LUClusterVerifyConfig.ETYPE_ERROR)

  def testVerifyCertificate(self):
    # A real certificate file must be processed without raising
    cmdlib._VerifyCertificate(self._TestDataFilename("cert1.pem"))

    # A file that does not exist is reported as an error
    self._AssertVerifyError(os.path.join(self.tmpdir, "does-not-exist"))

    # Try to load non-certificate file
    self._AssertVerifyError(self._TestDataFilename("bdev-net.txt"))
74
75
class TestOpcodeParams(testutils.GanetiTestCase):
  """Ensures no logical unit uses the old-style parameter declarations."""

  # Class attributes of the obsolete parameter-declaration mechanisms
  _OLD_STYLE_ATTRS = ["_OP_REQP", "_OP_DEFS", "_OP_PARAMS"]

  def testParamsStructures(self):
    for op in sorted(mcpu.Processor.DISPATCH_TABLE):
      lu = mcpu.Processor.DISPATCH_TABLE[op]
      lu_name = lu.__name__
      for attr in self._OLD_STYLE_ATTRS:
        # assertFalse replaces the deprecated failIf alias (removed in
        # Python 3.12); the message text is unchanged
        self.assertFalse(hasattr(lu, attr),
                         msg=("LU '%s' has old-style %s" % (lu_name, attr)))
87
88
class TestIAllocatorChecks(testutils.GanetiTestCase):
  """Tests for L{cmdlib._CheckIAllocatorOrNode}."""

  def testFunction(self):
    class TestLU(object):
      def __init__(self, opcode):
        self.cfg = mocks.FakeConfig()
        self.op = opcode

    class OpTest(opcodes.OpCode):
      OP_PARAMS = [
        ("iallocator", None, ht.NoType, None),
        ("node", None, ht.NoType, None),
        ]

    default_iallocator = mocks.FakeConfig().GetDefaultIAllocator()
    other_iallocator = default_iallocator + "_not"

    op = OpTest()
    lu = TestLU(op)

    def _SetOpFields(iallocator, node):
      """Sets both opcode slots examined by the check."""
      op.iallocator = iallocator
      op.node = node

    def _RunCheck():
      return cmdlib._CheckIAllocatorOrNode(lu, "iallocator", "node")

    def _AssertOpFields(iallocator, node):
      self.assertEqual(lu.op.iallocator, iallocator)
      self.assertEqual(lu.op.node, node)

    # Neither node nor iallocator given: the default iallocator is filled in
    _SetOpFields(None, None)
    _RunCheck()
    _AssertOpFields(default_iallocator, None)

    # Both, iallocator and node given
    _SetOpFields("test", "test")
    self.assertRaises(errors.OpPrereqError, _RunCheck)

    # Only iallocator given
    _SetOpFields(other_iallocator, None)
    _RunCheck()
    _AssertOpFields(other_iallocator, None)

    # Only node given
    _SetOpFields(None, "node")
    _RunCheck()
    _AssertOpFields(None, "node")

    # No node, iallocator or default iallocator
    _SetOpFields(None, None)
    lu.cfg.GetDefaultIAllocator = lambda: None
    self.assertRaises(errors.OpPrereqError, _RunCheck)
141
142
class TestLUTestJqueue(unittest.TestCase):
  """Sanity check on the timeouts used by L{cmdlib.LUTestJqueue}."""

  def test(self):
    # The client connect timeout must stay clearly below the WaitForJobChange
    # timeout. assertTrue replaces the deprecated assert_ alias (removed in
    # Python 3.12).
    self.assertTrue(cmdlib.LUTestJqueue._CLIENT_CONNECT_TIMEOUT <
                    (luxi.WFJC_TIMEOUT * 0.75),
                    msg=("Client timeout too high, might not notice bugs"
                         " in WaitForJobChange"))
149
150
class TestLUQuery(unittest.TestCase):
  """Tests for the query implementation lookup in L{cmdlib}."""

  def test(self):
    self.assertEqual(sorted(cmdlib._QUERY_IMPL.keys()),
                     sorted(constants.QR_VIA_OP))

    # Test assertions instead of bare "assert" statements, which are
    # stripped when Python runs with -O
    self.assertTrue(constants.QR_NODE in constants.QR_VIA_OP)
    self.assertTrue(constants.QR_INSTANCE in constants.QR_VIA_OP)

    for i in constants.QR_VIA_OP:
      # assertTrue replaces the deprecated assert_ alias
      self.assertTrue(cmdlib._GetQueryImplementation(i))

    # Unknown resource names are rejected
    for name in ["", "xyz"]:
      self.assertRaises(errors.OpPrereqError,
                        cmdlib._GetQueryImplementation, name)
165
166
class TestLUGroupAssignNodes(unittest.TestCase):
  """Tests for L{cmdlib.LUGroupAssignNodes}."""

  def testCheckAssignmentForSplitInstances(self):
    check_fn = cmdlib.LUGroupAssignNodes.CheckAssignmentForSplitInstances

    node_groups = [
      ("n1a", "g1"), ("n1b", "g1"),
      ("n2a", "g2"), ("n2b", "g2"),
      ("n3a", "g3"), ("n3b", "g3"), ("n3c", "g3"),
      ]
    node_data = dict((name, objects.Node(name=name, group=group))
                     for (name, group) in node_groups)

    def Instance(name, pnode, snode):
      """Builds a DRBD8 instance, or a diskless one if C{snode} is None."""
      if snode is not None:
        drbd_disk = objects.Disk(dev_type=constants.LD_DRBD8,
                                 logical_id=[pnode, snode, 1, 17, 17])
        return objects.Instance(name=name, primary_node=pnode,
                                disks=[drbd_disk],
                                disk_template=constants.DT_DRBD8)
      return objects.Instance(name=name, primary_node=pnode, disks=[],
                              disk_template=constants.DT_DISKLESS)

    instance_specs = [
      ("inst1a", "n1a", "n1b"),
      ("inst1b", "n1b", "n1a"),
      ("inst2a", "n2a", "n2b"),
      ("inst3a", "n3a", None),
      ("inst3b", "n3b", "n1b"),
      ("inst3c", "n3b", "n2b"),
      ]
    instance_data = dict((name, Instance(name, pnode, snode))
                         for (name, pnode, snode) in instance_specs)

    # Test first with the existing state.
    (new, prev) = check_fn([], node_data, instance_data)
    self.assertEqual([], new)
    self.assertEqual(set(["inst3b", "inst3c"]), set(prev))

    # And now some changes.
    (new, prev) = check_fn([("n1b", "g3")], node_data, instance_data)
    self.assertEqual(set(["inst1a", "inst1b"]), set(new))
    self.assertEqual(set(["inst3c"]), set(prev))
216
217
class TestClusterVerifySsh(unittest.TestCase):
  """Tests for L{cmdlib.LUClusterVerifyGroup._SelectSshCheckNodes}."""

  @staticmethod
  def _MakeNodes(specs):
    """Builds L{objects.Node} objects from (name, group, offline) tuples."""
    return [objects.Node(name=name, group=group, offline=offline)
            for (name, group, offline) in specs]

  def testMultipleGroups(self):
    fn = cmdlib.LUClusterVerifyGroup._SelectSshCheckNodes
    # node20 to node25 are online, node26 is offline
    mygroupnodes = self._MakeNodes([("node%s" % idx, "my", idx == 26)
                                    for idx in range(20, 27)])
    nodes = self._MakeNodes([
      ("node1", "g1", True),
      ("node2", "g1", False),
      ("node3", "g1", False),
      ("node4", "g1", True),
      ("node5", "g1", False),
      ("node10", "xyz", False),
      ("node11", "xyz", False),
      ("node40", "alloff", True),
      ("node41", "alloff", True),
      ("node50", "aaa", False),
      ]) + mygroupnodes
    assert not utils.FindDuplicates(map(operator.attrgetter("name"), nodes))

    (online, perhost) = fn(mygroupnodes, "my", nodes)
    self.assertEqual(online, ["node%s" % idx for idx in range(20, 26)])
    self.assertEqual(set(perhost.keys()), set(online))

    expected_perhost = {
      "node20": ["node10", "node2", "node50"],
      "node21": ["node11", "node3", "node50"],
      "node22": ["node10", "node5", "node50"],
      "node23": ["node11", "node2", "node50"],
      "node24": ["node10", "node3", "node50"],
      "node25": ["node11", "node5", "node50"],
      }
    self.assertEqual(perhost, expected_perhost)

  def testSingleGroup(self):
    fn = cmdlib.LUClusterVerifyGroup._SelectSshCheckNodes
    nodes = self._MakeNodes([
      ("node1", "default", True),
      ("node2", "default", False),
      ("node3", "default", False),
      ("node4", "default", True),
      ])
    assert not utils.FindDuplicates(map(operator.attrgetter("name"), nodes))

    (online, perhost) = fn(nodes, "default", nodes)
    self.assertEqual(online, ["node2", "node3"])
    self.assertEqual(set(perhost.keys()), set(online))

    # With a single group there are no other groups' nodes to check against
    self.assertEqual(perhost, {"node2": [], "node3": []})
275
276
class TestClusterVerifyFiles(unittest.TestCase):
  """Tests for L{cmdlib.LUClusterVerifyGroup._VerifyFiles}."""

  @staticmethod
  def _FakeErrorIf(errors_list, cond, ecode, item, msg, *args, **kwargs):
    """Fake error reporter collecting errors instead of raising them.

    @param errors_list: list to which C{(item, message)} tuples are appended
      whenever C{cond} is true (deliberately not named C{errors} so that the
      L{errors} module is not shadowed)
    @param cond: whether an error is to be recorded
    @param ecode: error code, checked against C{item}
    @param item: node name for node-level checks, None for cluster-level ones
    @param msg: message, %-formatted with C{args} if any are given

    """
    assert ((ecode == constants.CV_ENODEFILECHECK and
             ht.TNonEmptyString(item)) or
            (ecode == constants.CV_ECLUSTERFILECHECK and
             item is None))

    if args:
      msg = msg % args

    if cond:
      errors_list.append((item, msg))

  _VerifyFiles = cmdlib.LUClusterVerifyGroup._VerifyFiles

  def test(self):
    # Collected (item, message) tuples; not named "errors" so that the
    # errors module stays accessible in this method
    errors_list = []
    master_name = "master.example.com"
    nodeinfo = [
      objects.Node(name=master_name, offline=False, vm_capable=True),
      objects.Node(name="node2.example.com", offline=False, vm_capable=True),
      objects.Node(name="node3.example.com", master_candidate=True,
                   vm_capable=False),
      objects.Node(name="node4.example.com", offline=False, vm_capable=True),
      objects.Node(name="nodata.example.com", offline=False, vm_capable=True),
      objects.Node(name="offline.example.com", offline=True),
      ]
    files_all = set([
      pathutils.CLUSTER_DOMAIN_SECRET_FILE,
      pathutils.RAPI_CERT_FILE,
      pathutils.RAPI_USERS_FILE,
      ])
    files_opt = set([
      pathutils.RAPI_USERS_FILE,
      hv_xen.XL_CONFIG_FILE,
      pathutils.VNC_PASSWORD_FILE,
      ])
    files_mc = set([
      pathutils.CLUSTER_CONF_FILE,
      ])
    files_vm = set([
      hv_xen.XEND_CONFIG_FILE,
      hv_xen.XL_CONFIG_FILE,
      pathutils.VNC_PASSWORD_FILE,
      ])
    nvinfo = {
      master_name: rpc.RpcResult(data=(True, {
        constants.NV_FILELIST: {
          pathutils.CLUSTER_CONF_FILE: "82314f897f38b35f9dab2f7c6b1593e0",
          pathutils.RAPI_CERT_FILE: "babbce8f387bc082228e544a2146fee4",
          pathutils.CLUSTER_DOMAIN_SECRET_FILE: "cds-47b5b3f19202936bb4",
          hv_xen.XEND_CONFIG_FILE: "b4a8a824ab3cac3d88839a9adeadf310",
          hv_xen.XL_CONFIG_FILE: "77935cee92afd26d162f9e525e3d49b9"
        }})),
      "node2.example.com": rpc.RpcResult(data=(True, {
        constants.NV_FILELIST: {
          pathutils.RAPI_CERT_FILE: "97f0356500e866387f4b84233848cc4a",
          hv_xen.XEND_CONFIG_FILE: "b4a8a824ab3cac3d88839a9adeadf310",
          }
        })),
      "node3.example.com": rpc.RpcResult(data=(True, {
        constants.NV_FILELIST: {
          pathutils.RAPI_CERT_FILE: "97f0356500e866387f4b84233848cc4a",
          pathutils.CLUSTER_DOMAIN_SECRET_FILE: "cds-47b5b3f19202936bb4",
          }
        })),
      "node4.example.com": rpc.RpcResult(data=(True, {
        constants.NV_FILELIST: {
          pathutils.RAPI_CERT_FILE: "97f0356500e866387f4b84233848cc4a",
          pathutils.CLUSTER_CONF_FILE: "conf-a6d4b13e407867f7a7b4f0f232a8f527",
          pathutils.CLUSTER_DOMAIN_SECRET_FILE: "cds-47b5b3f19202936bb4",
          pathutils.RAPI_USERS_FILE: "rapiusers-ea3271e8d810ef3",
          hv_xen.XL_CONFIG_FILE: "77935cee92afd26d162f9e525e3d49b9"
          }
        })),
      "nodata.example.com": rpc.RpcResult(data=(True, {})),
      "offline.example.com": rpc.RpcResult(offline=True),
      }
    assert set(nvinfo.keys()) == set(map(operator.attrgetter("name"), nodeinfo))

    self._VerifyFiles(compat.partial(self._FakeErrorIf, errors_list), nodeinfo,
                      master_name, nvinfo,
                      (files_all, files_opt, files_mc, files_vm))
    self.assertEqual(sorted(errors_list), sorted([
      (None, ("File %s found with 2 different checksums (variant 1 on"
              " node2.example.com, node3.example.com, node4.example.com;"
              " variant 2 on master.example.com)" % pathutils.RAPI_CERT_FILE)),
      (None, ("File %s is missing from node(s) node2.example.com" %
              pathutils.CLUSTER_DOMAIN_SECRET_FILE)),
      (None, ("File %s should not exist on node(s) node4.example.com" %
              pathutils.CLUSTER_CONF_FILE)),
      (None, ("File %s is missing from node(s) node4.example.com" %
              hv_xen.XEND_CONFIG_FILE)),
      (None, ("File %s is missing from node(s) node3.example.com" %
              pathutils.CLUSTER_CONF_FILE)),
      (None, ("File %s found with 2 different checksums (variant 1 on"
              " master.example.com; variant 2 on node4.example.com)" %
              pathutils.CLUSTER_CONF_FILE)),
      (None, ("File %s is optional, but it must exist on all or no nodes (not"
              " found on master.example.com, node2.example.com,"
              " node3.example.com)" % pathutils.RAPI_USERS_FILE)),
      (None, ("File %s is optional, but it must exist on all or no nodes (not"
              " found on node2.example.com)" % hv_xen.XL_CONFIG_FILE)),
      ("nodata.example.com", "Node did not return file checksum data"),
      ]))
385
386
387 class _FakeLU:
388   def __init__(self, cfg=NotImplemented, proc=NotImplemented,
389                rpc=NotImplemented):
390     self.warning_log = []
391     self.info_log = []
392     self.cfg = cfg
393     self.proc = proc
394     self.rpc = rpc
395
396   def LogWarning(self, text, *args):
397     self.warning_log.append((text, args))
398
399   def LogInfo(self, text, *args):
400     self.info_log.append((text, args))
401
402
class TestLoadNodeEvacResult(unittest.TestCase):
  """Tests for L{cmdlib._LoadNodeEvacResult}."""

  def testSuccess(self):
    moved_variants = [
      [],
      [("inst20153.example.com", "grp2", ["nodeA4509", "nodeB2912"])],
      ]

    for (moved, early_release, use_nodes) in \
        itertools.product(moved_variants, [False, True], [False, True]):
      jobs = [
        [opcodes.OpInstanceReplaceDisks().__getstate__()],
        [opcodes.OpInstanceMigrate().__getstate__()],
        ]

      alloc_result = (moved, [], jobs)
      assert iallocator._NEVAC_RESULT(alloc_result)

      lu = _FakeLU()
      result = cmdlib._LoadNodeEvacResult(lu, alloc_result,
                                          early_release, use_nodes)

      if moved:
        # The first info message must mention each moved instance and either
        # its new nodes or its new group
        (_, (info_args, )) = lu.info_log.pop(0)
        for (instname, instgroup, instnodes) in moved:
          self.assertTrue(instname in info_args)
          if use_nodes:
            wanted = instnodes
          else:
            wanted = [instgroup]
          for item in wanted:
            self.assertTrue(item in info_args)

      self.assertFalse(lu.info_log)
      self.assertFalse(lu.warning_log)

      for op in itertools.chain(*result):
        if not hasattr(op.__class__, "early_release"):
          self.assertFalse(hasattr(op, "early_release"))
        else:
          self.assertEqual(op.early_release, early_release)

  def testFailed(self):
    failed = [("inst5191.example.com", "errormsg21178")]
    alloc_result = ([], failed, [])
    assert iallocator._NEVAC_RESULT(alloc_result)

    lu = _FakeLU()
    self.assertRaises(errors.OpExecError, cmdlib._LoadNodeEvacResult,
                      lu, alloc_result, False, False)
    self.assertFalse(lu.info_log)
    # Exactly one warning, naming both the instance and its error message
    (_, (args, )) = lu.warning_log.pop(0)
    for wanted in failed[0]:
      self.assertTrue(wanted in args)
    self.assertFalse(lu.warning_log)
455
456
class TestUpdateAndVerifySubDict(unittest.TestCase):
  """Tests for L{cmdlib._UpdateAndVerifySubDict}."""

  def setUp(self):
    self.type_check = {
      "a": constants.VTYPE_INT,
      "b": constants.VTYPE_STRING,
      "c": constants.VTYPE_BOOL,
      "d": constants.VTYPE_STRING,
      }

  def test(self):
    old = {
      "foo": {"d": "blubb", "a": 321},
      "baz": {"a": 678, "b": "678", "c": True},
      }
    update = {
      "foo": {"a": 123, "b": "123", "c": True},
      "bar": {"a": 321, "b": "321", "c": False},
      }
    # "foo" combines old and new keys (new values win), "bar" only exists in
    # the update and "baz" only in the old data
    expected = {
      "foo": {"a": 123, "b": "123", "c": True, "d": "blubb"},
      "bar": {"a": 321, "b": "321", "c": False},
      "baz": {"a": 678, "b": "678", "c": True},
      }

    result = cmdlib._UpdateAndVerifySubDict(old, update, self.type_check)
    self.assertEqual(result, expected)

  def testWrong(self):
    # "foo"/"a" is declared as VTYPE_INT but given a non-numeric string
    invalid = {
      "foo": {"a": "blubb", "b": "123", "c": True},
      "bar": {"a": 321, "b": "321", "c": False},
      }

    self.assertRaises(errors.TypeEnforcementError,
                      cmdlib._UpdateAndVerifySubDict, {}, invalid,
                      self.type_check)
529
530
class TestHvStateHelper(unittest.TestCase):
  """Tests for L{cmdlib._MergeAndVerifyHvState}."""

  def testWithoutOpData(self):
    merged = cmdlib._MergeAndVerifyHvState(None, NotImplemented)
    self.assertEqual(merged, None)

  def testWithoutOldData(self):
    new = {constants.HT_XEN_PVM: {constants.HVST_MEMORY_TOTAL: 4096}}
    merged = cmdlib._MergeAndVerifyHvState(new, None)
    self.assertEqual(merged, new)

  def testWithWrongHv(self):
    new = {"i-dont-exist": {constants.HVST_MEMORY_TOTAL: 4096}}
    self.assertRaises(errors.OpPrereqError, cmdlib._MergeAndVerifyHvState,
                      new, None)
551
class TestDiskStateHelper(unittest.TestCase):
  """Tests for L{cmdlib._MergeAndVerifyDiskState}."""

  @staticmethod
  def _MakeState(storage_type):
    """Builds disk state data with a single "xenvg" entry."""
    return {storage_type: {"xenvg": {constants.DS_DISK_RESERVED: 1024}}}

  def testWithoutOpData(self):
    merged = cmdlib._MergeAndVerifyDiskState(None, NotImplemented)
    self.assertEqual(merged, None)

  def testWithoutOldData(self):
    new = self._MakeState(constants.LD_LV)
    self.assertEqual(cmdlib._MergeAndVerifyDiskState(new, None), new)

  def testWithWrongStorageType(self):
    new = self._MakeState("i-dont-exist")
    self.assertRaises(errors.OpPrereqError, cmdlib._MergeAndVerifyDiskState,
                      new, None)
577
578
class TestComputeMinMaxSpec(unittest.TestCase):
  """Tests for L{cmdlib._ComputeMinMaxSpec}."""

  def setUp(self):
    self.ipolicy = {
      constants.ISPECS_MAX: {
        constants.ISPEC_MEM_SIZE: 512,
        constants.ISPEC_DISK_SIZE: 1024,
        },
      constants.ISPECS_MIN: {
        constants.ISPEC_MEM_SIZE: 128,
        constants.ISPEC_DISK_COUNT: 1,
        },
      }

  def _Compute(self, name, qualifier, value):
    """Runs the function under test against C{self.ipolicy}."""
    return cmdlib._ComputeMinMaxSpec(name, qualifier, self.ipolicy, value)

  def testNoneValue(self):
    self.assertTrue(self._Compute(constants.ISPEC_MEM_SIZE, None,
                                  None) is None)

  def testAutoValue(self):
    self.assertTrue(self._Compute(constants.ISPEC_MEM_SIZE, None,
                                  constants.VALUE_AUTO) is None)

  def testNotDefined(self):
    self.assertTrue(self._Compute(constants.ISPEC_NIC_COUNT, None,
                                  3) is None)

  def testNoMinDefined(self):
    self.assertTrue(self._Compute(constants.ISPEC_DISK_SIZE, None,
                                  128) is None)

  def testNoMaxDefined(self):
    self.assertTrue(self._Compute(constants.ISPEC_DISK_COUNT, None,
                                  16) is None)

  def testOutOfRange(self):
    cases = [
      (constants.ISPEC_MEM_SIZE, 64),
      (constants.ISPEC_MEM_SIZE, 768),
      (constants.ISPEC_DISK_SIZE, 4096),
      (constants.ISPEC_DISK_COUNT, 0),
      ]
    for (name, val) in cases:
      # A bound missing from the policy is reported as the value itself
      min_v = self.ipolicy[constants.ISPECS_MIN].get(name, val)
      max_v = self.ipolicy[constants.ISPECS_MAX].get(name, val)
      # Without a qualifier the message starts with the bare name, with
      # qualifier "1" it starts with "<name>/1"
      for (qualifier, label) in [(None, name), ("1", "%s/1" % name)]:
        self.assertEqual(self._Compute(name, qualifier, val),
                         "%s value %s is not in range [%s, %s]" %
                         (label, val, min_v, max_v))

  def test(self):
    cases = [
      (constants.ISPEC_MEM_SIZE, 256),
      (constants.ISPEC_MEM_SIZE, 128),
      (constants.ISPEC_MEM_SIZE, 512),
      (constants.ISPEC_DISK_SIZE, 1024),
      (constants.ISPEC_DISK_SIZE, 0),
      (constants.ISPEC_DISK_COUNT, 1),
      (constants.ISPEC_DISK_COUNT, 5),
      ]
    for (name, val) in cases:
      self.assertTrue(self._Compute(name, None, val) is None)
639
640
def _ValidateComputeMinMaxSpec(name, *_):
  """Fake C{_compute_fn} that never reports a violation.

  It only asserts that C{name} is one of C{constants.ISPECS_PARAMETERS};
  all other arguments are ignored.

  """
  assert name in constants.ISPECS_PARAMETERS
644
645
646 class _SpecWrapper:
647   def __init__(self, spec):
648     self.spec = spec
649
650   def ComputeMinMaxSpec(self, *args):
651     return self.spec.pop(0)
652
653
class TestComputeIPolicySpecViolation(unittest.TestCase):
  """Tests for L{cmdlib._ComputeIPolicySpecViolation}."""

  def test(self):
    result = cmdlib._ComputeIPolicySpecViolation(
      NotImplemented, 1024, 1, 1, 1, [1024], 1,
      _compute_fn=_ValidateComputeMinMaxSpec)
    self.assertEqual(result, [])

  def testInvalidArguments(self):
    # Disk count is 1 but the disk size list is empty; presumably rejected
    # because the two do not match
    self.assertRaises(AssertionError, cmdlib._ComputeIPolicySpecViolation,
                      NotImplemented, 1024, 1, 1, 1, [], 1)

  def testInvalidSpec(self):
    # Falsy results are dropped, the remaining messages are returned in order
    spec = _SpecWrapper([None, False, "foo", None, "bar", None])
    result = cmdlib._ComputeIPolicySpecViolation(
      NotImplemented, 1024, 1, 1, 1, [1024], 1,
      _compute_fn=spec.ComputeMinMaxSpec)
    self.assertEqual(result, ["foo", "bar"])
    # Every pre-set result must have been consumed
    self.assertFalse(spec.spec)
672
673
674 class _StubComputeIPolicySpecViolation:
675   def __init__(self, mem_size, cpu_count, disk_count, nic_count, disk_sizes,
676                spindle_use):
677     self.mem_size = mem_size
678     self.cpu_count = cpu_count
679     self.disk_count = disk_count
680     self.nic_count = nic_count
681     self.disk_sizes = disk_sizes
682     self.spindle_use = spindle_use
683
684   def __call__(self, _, mem_size, cpu_count, disk_count, nic_count, disk_sizes,
685                spindle_use):
686     assert self.mem_size == mem_size
687     assert self.cpu_count == cpu_count
688     assert self.disk_count == disk_count
689     assert self.nic_count == nic_count
690     assert self.disk_sizes == disk_sizes
691     assert self.spindle_use == spindle_use
692
693     return []
694
695
class TestComputeIPolicyInstanceViolation(unittest.TestCase):
  """Tests for L{cmdlib._ComputeIPolicyInstanceViolation}."""

  def test(self):
    instance = objects.Instance(beparams={
        constants.BE_MAXMEM: 2048,
        constants.BE_VCPUS: 2,
        constants.BE_SPINDLE_USE: 4,
        },
      disks=[objects.Disk(size=512)],
      nics=[])
    # The stub asserts it is called with the instance's values
    stub = _StubComputeIPolicySpecViolation(2048, 2, 1, 0, [512], 4)
    self.assertEqual(
      cmdlib._ComputeIPolicyInstanceViolation(NotImplemented, instance,
                                              _compute_fn=stub),
      [])
709
710
class TestComputeIPolicyInstanceSpecViolation(unittest.TestCase):
  """Tests for L{cmdlib._ComputeIPolicyInstanceSpecViolation}."""

  def test(self):
    ispec = dict([
      (constants.ISPEC_MEM_SIZE, 2048),
      (constants.ISPEC_CPU_COUNT, 2),
      (constants.ISPEC_DISK_COUNT, 1),
      (constants.ISPEC_DISK_SIZE, [512]),
      (constants.ISPEC_NIC_COUNT, 0),
      (constants.ISPEC_SPINDLE_USE, 1),
      ])
    # The stub asserts it is called with the values from the spec
    stub = _StubComputeIPolicySpecViolation(2048, 2, 1, 0, [512], 1)
    self.assertEqual(
      cmdlib._ComputeIPolicyInstanceSpecViolation(NotImplemented, ispec,
                                                  _compute_fn=stub),
      [])
725
726
727 class _CallRecorder:
728   def __init__(self, return_value=None):
729     self.called = False
730     self.return_value = return_value
731
732   def __call__(self, *args):
733     self.called = True
734     return self.return_value
735
736
class TestComputeIPolicyNodeViolation(unittest.TestCase):
  """Tests for L{cmdlib._ComputeIPolicyNodeViolation}."""

  def setUp(self):
    self.recorder = _CallRecorder(return_value=[])

  def _Compute(self, first_group, second_group):
    return cmdlib._ComputeIPolicyNodeViolation(NotImplemented, NotImplemented,
                                               first_group, second_group,
                                               _compute_fn=self.recorder)

  def testSameGroup(self):
    # Identical groups: the policy computation is skipped entirely
    result = self._Compute("foo", "foo")
    self.assertFalse(self.recorder.called)
    self.assertEqual(result, [])

  def testDifferentGroup(self):
    result = self._Compute("foo", "bar")
    self.assertTrue(self.recorder.called)
    self.assertEqual(result, [])
754
755
756 class _FakeConfigForTargetNodeIPolicy:
757   def __init__(self, node_info=NotImplemented):
758     self._node_info = node_info
759
760   def GetNodeInfo(self, _):
761     return self._node_info
762
763
class TestCheckTargetNodeIPolicy(unittest.TestCase):
  """Tests for L{cmdlib._CheckTargetNodeIPolicy}."""

  def setUp(self):
    self.instance = objects.Instance(primary_node="blubb")
    self.target_node = objects.Node(group="bar")
    primary_node = objects.Node(group="foo")
    self.lu = _FakeLU(cfg=_FakeConfigForTargetNodeIPolicy(
      node_info=primary_node))

  def _Check(self, recorder, **kwargs):
    """Runs the check against the fixtures, using C{recorder} as compute fn."""
    return cmdlib._CheckTargetNodeIPolicy(self.lu, NotImplemented,
                                          self.instance, self.target_node,
                                          _compute_fn=recorder, **kwargs)

  def testNoViolation(self):
    recorder = _CallRecorder(return_value=[])
    self._Check(recorder)
    self.assertTrue(recorder.called)
    self.assertEqual(self.lu.warning_log, [])

  def testNoIgnore(self):
    recorder = _CallRecorder(return_value=["mem_size not in range"])
    self.assertRaises(errors.OpPrereqError, self._Check, recorder)
    self.assertTrue(recorder.called)
    self.assertEqual(self.lu.warning_log, [])

  def testIgnoreViolation(self):
    recorder = _CallRecorder(return_value=["mem_size not in range"])
    self._Check(recorder, ignore=True)
    self.assertTrue(recorder.called)
    # With ignore=True the violation is only logged as a warning
    expected_msg = ("Instance does not meet target node group's (bar)"
                    " instance policy: mem_size not in range")
    self.assertEqual(self.lu.warning_log, [(expected_msg, ())])
797
798
class TestApplyContainerMods(unittest.TestCase):
  """Tests for L{cmdlib.PrepareContainerMods}/L{cmdlib.ApplyContainerMods}.

  """
  @staticmethod
  def _Prepare(mods, private_fn=None):
    """Shortcut for L{cmdlib.PrepareContainerMods}."""
    return cmdlib.PrepareContainerMods(mods, private_fn)

  @staticmethod
  def _ApplyPlain(container, chgdesc, mods):
    """Applies C{mods} to C{container} without create/modify/remove callbacks.

    """
    cmdlib.ApplyContainerMods("test", container, chgdesc, mods,
                              None, None, None)

  def testEmptyContainer(self):
    container = []
    chgdesc = []
    self._ApplyPlain(container, chgdesc, [])
    self.assertEqual(container, [])
    self.assertEqual(chgdesc, [])

  def testAdd(self):
    container = []
    chgdesc = []

    first_round = self._Prepare([
      (constants.DDM_ADD, -1, "Hello"),
      (constants.DDM_ADD, -1, "World"),
      (constants.DDM_ADD, 0, "Start"),
      (constants.DDM_ADD, -1, "End"),
      ])
    self._ApplyPlain(container, chgdesc, first_round)
    self.assertEqual(container, ["Start", "Hello", "World", "End"])
    self.assertEqual(chgdesc, [])

    second_round = self._Prepare([
      (constants.DDM_ADD, 0, "zero"),
      (constants.DDM_ADD, 3, "Added"),
      (constants.DDM_ADD, 5, "four"),
      (constants.DDM_ADD, 7, "xyz"),
      ])
    self._ApplyPlain(container, chgdesc, second_round)
    self.assertEqual(container,
                     ["zero", "Start", "Hello", "Added", "World", "four",
                      "End", "xyz"])
    self.assertEqual(chgdesc, [])

    # Indices before the start or past the end are rejected
    for bad_idx in (-2, len(container) + 1):
      bad_mods = self._Prepare([(constants.DDM_ADD, bad_idx, "error")])
      self.assertRaises(IndexError, self._ApplyPlain,
                        container, None, bad_mods)

  def testRemoveError(self):
    # Removing from an empty container fails for any index
    for bad_idx in (0, 1, 2, 100, -1, -4):
      bad_mods = self._Prepare([(constants.DDM_REMOVE, bad_idx, None)])
      self.assertRaises(IndexError, self._ApplyPlain, [], None, bad_mods)

    # A non-None parameter on a removal is expected to trip an assertion
    bad_mods = self._Prepare([(constants.DDM_REMOVE, 0, object())])
    self.assertRaises(AssertionError, self._ApplyPlain,
                      [""], None, bad_mods)

  def testAddError(self):
    for bad_idx in itertools.chain(range(-100, -1), [100]):
      bad_mods = self._Prepare([(constants.DDM_ADD, bad_idx, None)])
      self.assertRaises(IndexError, self._ApplyPlain, [], None, bad_mods)

  def testRemove(self):
    container = ["item 1", "item 2"]
    chgdesc = []
    mods = self._Prepare([
      (constants.DDM_ADD, -1, "aaa"),
      (constants.DDM_REMOVE, -1, None),
      (constants.DDM_ADD, -1, "bbb"),
      ])
    self._ApplyPlain(container, chgdesc, mods)
    self.assertEqual(container, ["item 1", "item 2", "bbb"])
    self.assertEqual(chgdesc, [("test/2", "remove")])

  def testModify(self):
    container = ["item 1", "item 2"]
    chgdesc = []
    mods = self._Prepare([
      (constants.DDM_MODIFY, -1, "a"),
      (constants.DDM_MODIFY, 0, "b"),
      (constants.DDM_MODIFY, 1, "c"),
      ])
    self._ApplyPlain(container, chgdesc, mods)
    # Without a modify callback neither the items nor the change log change
    self.assertEqual(container, ["item 1", "item 2"])
    self.assertEqual(chgdesc, [])

    for bad_idx in (-2, len(container) + 1):
      bad_mods = self._Prepare([(constants.DDM_MODIFY, bad_idx, "error")])
      self.assertRaises(IndexError, self._ApplyPlain,
                        container, None, bad_mods)

  class _PrivateData:
    """Holds whatever the test callbacks store in C{data}."""
    def __init__(self):
      self.data = None

  @staticmethod
  def _CreateTestFn(idx, params, private):
    private.data = ("add", idx, params)
    new_item = (100 * idx, params)
    changes = [("test/%s" % idx, hex(idx))]
    return (new_item, changes)

  @staticmethod
  def _ModifyTestFn(idx, item, params, private):
    private.data = ("modify", idx, params)
    changes = [("test/%s" % idx, "modify %s" % params)]
    return changes

  @staticmethod
  def _RemoveTestFn(idx, item, private):
    private.data = ("remove", idx, item)

  def testAddWithCreateFunction(self):
    container = []
    chgdesc = []
    mods = self._Prepare([
      (constants.DDM_ADD, -1, "Hello"),
      (constants.DDM_ADD, -1, "World"),
      (constants.DDM_ADD, 0, "Start"),
      (constants.DDM_ADD, -1, "End"),
      (constants.DDM_REMOVE, 2, None),
      (constants.DDM_MODIFY, -1, "foobar"),
      (constants.DDM_REMOVE, 2, None),
      (constants.DDM_ADD, 1, "More"),
      ], self._PrivateData)
    cmdlib.ApplyContainerMods("test", container, chgdesc, mods,
                              self._CreateTestFn, self._ModifyTestFn,
                              self._RemoveTestFn)
    self.assertEqual(container, [
      (0, "Start"),
      (100, "More"),
      (0, "Hello"),
      ])
    self.assertEqual(chgdesc, [
      ("test/0", "0x0"),
      ("test/1", "0x1"),
      ("test/0", "0x0"),
      ("test/3", "0x3"),
      ("test/2", "remove"),
      ("test/2", "modify foobar"),
      ("test/2", "remove"),
      ("test/1", "0x1")
      ])

    # Each modification's private data was written by the matching callback
    for (op, _, _, private) in mods:
      self.assertEqual(op, private.data[0])

    self.assertEqual([private.data for (_, _, _, private) in mods], [
      ("add", 0, "Hello"),
      ("add", 1, "World"),
      ("add", 0, "Start"),
      ("add", 3, "End"),
      ("remove", 2, (100, "World")),
      ("modify", 2, "foobar"),
      ("remove", 2, (300, "End")),
      ("add", 1, "More"),
      ])
960       ("add", 1, "More"),
961       ])
962
963
964 class _FakeConfigForGenDiskTemplate:
965   def __init__(self):
966     self._unique_id = itertools.count()
967     self._drbd_minor = itertools.count(20)
968     self._port = itertools.count(constants.FIRST_DRBD_PORT)
969     self._secret = itertools.count()
970
971   def GetVGName(self):
972     return "testvg"
973
974   def GenerateUniqueID(self, ec_id):
975     return "ec%s-uq%s" % (ec_id, self._unique_id.next())
976
977   def AllocateDRBDMinor(self, nodes, instance):
978     return [self._drbd_minor.next()
979             for _ in nodes]
980
981   def AllocatePort(self):
982     return self._port.next()
983
984   def GenerateDRBDSecret(self, ec_id):
985     return "ec%s-secret%s" % (ec_id, self._secret.next())
986
987   def GetInstanceInfo(self, _):
988     return "foobar"
989
990
991 class _FakeProcForGenDiskTemplate:
992   def GetECId(self):
993     return 0
994
995
class TestGenerateDiskTemplate(unittest.TestCase):
  """Tests for L{cmdlib._GenerateDiskTemplate}."""

  def setUp(self):
    nodegroup = objects.NodeGroup(name="ng")
    nodegroup.UpgradeConfig()

    cfg = _FakeConfigForGenDiskTemplate()
    proc = _FakeProcForGenDiskTemplate()

    self.lu = _FakeLU(cfg=cfg, proc=proc)
    self.nodegroup = nodegroup

  @staticmethod
  def GetDiskParams():
    """Returns a deep copy of the default disk parameters."""
    return copy.deepcopy(constants.DISK_DT_DEFAULTS)

  @staticmethod
  def _ForceDiskInfoTypes(disk_info):
    """Runs L{utils.ForceDictType} on every disk parameter dictionary.

    The return value of L{utils.ForceDictType} is not used, so this relies on
    it converting the dictionaries in place.

    """
    for params in disk_info:
      utils.ForceDictType(params, constants.IDISK_PARAMS_TYPES)

  def testWrongDiskTemplate(self):
    gdt = cmdlib._GenerateDiskTemplate
    disk_template = "##unknown##"

    assert disk_template not in constants.DISK_TEMPLATES

    self.assertRaises(errors.ProgrammerError, gdt, self.lu, disk_template,
                      "inst26831.example.com", "node30113.example.com", [], [],
                      NotImplemented, NotImplemented, 0, self.lu.LogInfo,
                      self.GetDiskParams())

  def testDiskless(self):
    gdt = cmdlib._GenerateDiskTemplate

    result = gdt(self.lu, constants.DT_DISKLESS, "inst27734.example.com",
                 "node30113.example.com", [], [],
                 NotImplemented, NotImplemented, 0, self.lu.LogInfo,
                 self.GetDiskParams())
    self.assertEqual(result, [])

  def _TestTrivialDisk(self, template, disk_info, base_index, exp_dev_type,
                       file_storage_dir=NotImplemented,
                       file_driver=NotImplemented,
                       req_file_storage=NotImplemented,
                       req_shr_file_storage=NotImplemented):
    """Generates and checks disks for a template without secondary nodes.

    @return: the list of generated disks

    """
    gdt = cmdlib._GenerateDiskTemplate

    self._ForceDiskInfoTypes(disk_info)

    # Check if non-empty list of secondaries is rejected
    self.assertRaises(errors.ProgrammerError, gdt, self.lu,
                      template, "inst25088.example.com",
                      "node185.example.com", ["node323.example.com"], [],
                      NotImplemented, NotImplemented, base_index,
                      self.lu.LogInfo, self.GetDiskParams(),
                      _req_file_storage=req_file_storage,
                      _req_shr_file_storage=req_shr_file_storage)

    result = gdt(self.lu, template, "inst21662.example.com",
                 "node21741.example.com", [],
                 disk_info, file_storage_dir, file_driver, base_index,
                 self.lu.LogInfo, self.GetDiskParams(),
                 _req_file_storage=req_file_storage,
                 _req_shr_file_storage=req_shr_file_storage)

    for (idx, disk) in enumerate(result):
      self.assertTrue(isinstance(disk, objects.Disk))
      self.assertEqual(disk.dev_type, exp_dev_type)
      self.assertEqual(disk.size, disk_info[idx][constants.IDISK_SIZE])
      self.assertEqual(disk.mode, disk_info[idx][constants.IDISK_MODE])
      self.assertTrue(disk.children is None)

    self._CheckIvNames(result, base_index, base_index + len(disk_info))
    cmdlib._UpdateIvNames(base_index, result)
    self._CheckIvNames(result, base_index, base_index + len(disk_info))

    return result

  def _CheckIvNames(self, disks, base_index, end_index):
    """Checks that disks are named C{disk/<n>} for n in [base, end)."""
    self.assertEqual([disk.iv_name for disk in disks],
                     ["disk/%s" % i for i in range(base_index, end_index)])

  def testPlain(self):
    disk_info = [{
      constants.IDISK_SIZE: 1024,
      constants.IDISK_MODE: constants.DISK_RDWR,
      }, {
      constants.IDISK_SIZE: 4096,
      constants.IDISK_VG: "othervg",
      constants.IDISK_MODE: constants.DISK_RDWR,
      }]

    result = self._TestTrivialDisk(constants.DT_PLAIN, disk_info, 3,
                                   constants.LD_LV)

    self.assertEqual([disk.logical_id for disk in result], [
      ("testvg", "ec0-uq0.disk3"),
      ("othervg", "ec0-uq1.disk4"),
      ])

  @staticmethod
  def _AllowFileStorage():
    pass

  @staticmethod
  def _ForbidFileStorage():
    raise errors.OpPrereqError("Disallowed in test")

  def testFile(self):
    self.assertRaises(errors.OpPrereqError, self._TestTrivialDisk,
                      constants.DT_FILE, [], 0, NotImplemented,
                      req_file_storage=self._ForbidFileStorage)
    self.assertRaises(errors.OpPrereqError, self._TestTrivialDisk,
                      constants.DT_SHARED_FILE, [], 0, NotImplemented,
                      req_shr_file_storage=self._ForbidFileStorage)

    for disk_template in [constants.DT_FILE, constants.DT_SHARED_FILE]:
      disk_info = [{
        constants.IDISK_SIZE: 80 * 1024,
        constants.IDISK_MODE: constants.DISK_RDONLY,
        }, {
        constants.IDISK_SIZE: 4096,
        constants.IDISK_MODE: constants.DISK_RDWR,
        }, {
        constants.IDISK_SIZE: 6 * 1024,
        constants.IDISK_MODE: constants.DISK_RDWR,
        }]

      result = self._TestTrivialDisk(disk_template, disk_info, 2,
        constants.LD_FILE, file_storage_dir="/tmp",
        file_driver=constants.FD_BLKTAP,
        req_file_storage=self._AllowFileStorage,
        req_shr_file_storage=self._AllowFileStorage)

      self.assertEqual([disk.logical_id for disk in result], [
        (constants.FD_BLKTAP, "/tmp/disk2"),
        (constants.FD_BLKTAP, "/tmp/disk3"),
        (constants.FD_BLKTAP, "/tmp/disk4"),
        ])

  def testBlock(self):
    disk_info = [{
      constants.IDISK_SIZE: 8 * 1024,
      constants.IDISK_MODE: constants.DISK_RDWR,
      constants.IDISK_ADOPT: "/tmp/some/block/dev",
      }]

    result = self._TestTrivialDisk(constants.DT_BLOCK, disk_info, 10,
                                   constants.LD_BLOCKDEV)

    self.assertEqual([disk.logical_id for disk in result], [
      (constants.BLOCKDEV_DRIVER_MANUAL, "/tmp/some/block/dev"),
      ])

  def testRbd(self):
    disk_info = [{
      constants.IDISK_SIZE: 8 * 1024,
      constants.IDISK_MODE: constants.DISK_RDONLY,
      }, {
      constants.IDISK_SIZE: 100 * 1024,
      constants.IDISK_MODE: constants.DISK_RDWR,
      }]

    result = self._TestTrivialDisk(constants.DT_RBD, disk_info, 0,
                                   constants.LD_RBD)

    self.assertEqual([disk.logical_id for disk in result], [
      ("rbd", "ec0-uq0.rbd.disk0"),
      ("rbd", "ec0-uq1.rbd.disk1"),
      ])

  def testDrbd8(self):
    gdt = cmdlib._GenerateDiskTemplate
    drbd8_defaults = constants.DISK_LD_DEFAULTS[constants.LD_DRBD8]
    drbd8_default_metavg = drbd8_defaults[constants.LDP_DEFAULT_METAVG]

    disk_info = [{
      constants.IDISK_SIZE: 1024,
      constants.IDISK_MODE: constants.DISK_RDWR,
      }, {
      constants.IDISK_SIZE: 100 * 1024,
      constants.IDISK_MODE: constants.DISK_RDONLY,
      constants.IDISK_METAVG: "metavg",
      }, {
      constants.IDISK_SIZE: 4096,
      constants.IDISK_MODE: constants.DISK_RDWR,
      constants.IDISK_VG: "vgxyz",
      },
      ]

    exp_logical_ids = [[
      (self.lu.cfg.GetVGName(), "ec0-uq0.disk0_data"),
      (drbd8_default_metavg, "ec0-uq0.disk0_meta"),
      ], [
      (self.lu.cfg.GetVGName(), "ec0-uq1.disk1_data"),
      ("metavg", "ec0-uq1.disk1_meta"),
      ], [
      ("vgxyz", "ec0-uq2.disk2_data"),
      (drbd8_default_metavg, "ec0-uq2.disk2_meta"),
      ]]

    assert len(exp_logical_ids) == len(disk_info)

    self._ForceDiskInfoTypes(disk_info)

    # Check if empty list of secondaries is rejected
    self.assertRaises(errors.ProgrammerError, gdt, self.lu, constants.DT_DRBD8,
                      "inst827.example.com", "node1334.example.com", [],
                      disk_info, NotImplemented, NotImplemented, 0,
                      self.lu.LogInfo, self.GetDiskParams())

    result = gdt(self.lu, constants.DT_DRBD8, "inst827.example.com",
                 "node1334.example.com", ["node12272.example.com"],
                 disk_info, NotImplemented, NotImplemented, 0, self.lu.LogInfo,
                 self.GetDiskParams())

    for (idx, disk) in enumerate(result):
      self.assertTrue(isinstance(disk, objects.Disk))
      self.assertEqual(disk.dev_type, constants.LD_DRBD8)
      self.assertEqual(disk.size, disk_info[idx][constants.IDISK_SIZE])
      self.assertEqual(disk.mode, disk_info[idx][constants.IDISK_MODE])

      for child in disk.children:
        # Check the child itself, not the parent DRBD disk again
        self.assertTrue(isinstance(child, objects.Disk))
        self.assertEqual(child.dev_type, constants.LD_LV)
        self.assertTrue(child.children is None)

      self.assertEqual([child.logical_id for child in disk.children],
                       exp_logical_ids[idx])

      self.assertEqual(len(disk.children), 2)
      self.assertEqual(disk.children[0].size, disk.size)
      self.assertEqual(disk.children[1].size, constants.DRBD_META_SIZE)

    self._CheckIvNames(result, 0, len(disk_info))
    cmdlib._UpdateIvNames(0, result)
    self._CheckIvNames(result, 0, len(disk_info))

    self.assertEqual([disk.logical_id for disk in result], [
      ("node1334.example.com", "node12272.example.com",
       constants.FIRST_DRBD_PORT, 20, 21, "ec0-secret0"),
      ("node1334.example.com", "node12272.example.com",
       constants.FIRST_DRBD_PORT + 1, 22, 23, "ec0-secret1"),
      ("node1334.example.com", "node12272.example.com",
       constants.FIRST_DRBD_PORT + 2, 24, 25, "ec0-secret2"),
      ])
1239        constants.FIRST_DRBD_PORT + 2, 24, 25, "ec0-secret2"),
1240       ])
1241
1242
class _ConfigForDiskWipe:
  """Configuration stub for the L{cmdlib._WipeDisks} tests.

  It only validates the arguments passed to L{SetDiskID}.

  """
  _EXPECTED_NODE = "node1.example.com"

  def SetDiskID(self, device, node):
    assert isinstance(device, objects.Disk)
    assert node == self._EXPECTED_NODE
1247
1248
1249 class _RpcForDiskWipe:
1250   def __init__(self, pause_cb, wipe_cb):
1251     self._pause_cb = pause_cb
1252     self._wipe_cb = wipe_cb
1253
1254   def call_blockdev_pause_resume_sync(self, node, disks, pause):
1255     assert node == "node1.example.com"
1256     return rpc.RpcResult(data=self._pause_cb(disks, pause))
1257
1258   def call_blockdev_wipe(self, node, bdev, offset, size):
1259     assert node == "node1.example.com"
1260     return rpc.RpcResult(data=self._wipe_cb(bdev, offset, size))
1261
1262
1263 class _DiskPauseTracker:
1264   def __init__(self):
1265     self.history = []
1266
1267   def __call__(self, (disks, instance), pause):
1268     assert instance.disks == disks
1269
1270     self.history.extend((i.logical_id, i.size, pause)
1271                         for i in disks)
1272
1273     return (True, [True] * len(disks))
1274
1275
1276 class TestWipeDisks(unittest.TestCase):
1277   def testPauseFailure(self):
1278     def _FailPause((disks, _), pause):
1279       self.assertEqual(len(disks), 3)
1280       self.assertTrue(pause)
1281       return (False, "error")
1282
1283     lu = _FakeLU(rpc=_RpcForDiskWipe(_FailPause, NotImplemented),
1284                  cfg=_ConfigForDiskWipe())
1285
1286     disks = [
1287       objects.Disk(dev_type=constants.LD_LV),
1288       objects.Disk(dev_type=constants.LD_LV),
1289       objects.Disk(dev_type=constants.LD_LV),
1290       ]
1291
1292     instance = objects.Instance(name="inst21201",
1293                                 primary_node="node1.example.com",
1294                                 disk_template=constants.DT_PLAIN,
1295                                 disks=disks)
1296
1297     self.assertRaises(errors.OpExecError, cmdlib._WipeDisks, lu, instance)
1298
1299   def testFailingWipe(self):
1300     pt = _DiskPauseTracker()
1301
1302     def _WipeCb((disk, _), offset, size):
1303       assert disk.logical_id == "disk0"
1304       return (False, None)
1305
1306     lu = _FakeLU(rpc=_RpcForDiskWipe(pt, _WipeCb),
1307                  cfg=_ConfigForDiskWipe())
1308
1309     disks = [
1310       objects.Disk(dev_type=constants.LD_LV, logical_id="disk0",
1311                    size=100 * 1024),
1312       objects.Disk(dev_type=constants.LD_LV, logical_id="disk1",
1313                    size=500 * 1024),
1314       objects.Disk(dev_type=constants.LD_LV, logical_id="disk2", size=256),
1315       ]
1316
1317     instance = objects.Instance(name="inst562",
1318                                 primary_node="node1.example.com",
1319                                 disk_template=constants.DT_PLAIN,
1320                                 disks=disks)
1321
1322     try:
1323       cmdlib._WipeDisks(lu, instance)
1324     except errors.OpExecError, err:
1325       self.assertTrue(str(err), "Could not wipe disk 0 at offset 0 ")
1326     else:
1327       self.fail("Did not raise exception")
1328
1329     self.assertEqual(pt.history, [
1330       ("disk0", 100 * 1024, True),
1331       ("disk1", 500 * 1024, True),
1332       ("disk2", 256, True),
1333       ("disk0", 100 * 1024, False),
1334       ("disk1", 500 * 1024, False),
1335       ("disk2", 256, False),
1336       ])
1337
1338   def testNormalWipe(self):
1339     pt = _DiskPauseTracker()
1340
1341     progress = {}
1342
1343     def _WipeCb((disk, _), offset, size):
1344       assert isinstance(offset, (long, int))
1345       assert isinstance(size, (long, int))
1346
1347       max_chunk_size = (disk.size / 100.0 * constants.MIN_WIPE_CHUNK_PERCENT)
1348
1349       self.assertTrue(offset >= 0)
1350       self.assertTrue((offset + size) <= disk.size)
1351
1352       self.assertTrue(size > 0)
1353       self.assertTrue(size <= constants.MAX_WIPE_CHUNK)
1354       self.assertTrue(size <= max_chunk_size)
1355
1356       self.assertTrue(offset == 0 or disk.logical_id in progress)
1357
1358       # Keep track of progress
1359       cur_progress = progress.setdefault(disk.logical_id, 0)
1360       self.assertEqual(cur_progress, offset)
1361
1362       progress[disk.logical_id] += size
1363
1364       return (True, None)
1365
1366     lu = _FakeLU(rpc=_RpcForDiskWipe(pt, _WipeCb),
1367                  cfg=_ConfigForDiskWipe())
1368
1369     disks = [
1370       objects.Disk(dev_type=constants.LD_LV, logical_id="disk0", size=1024),
1371       objects.Disk(dev_type=constants.LD_LV, logical_id="disk1",
1372                    size=500 * 1024),
1373       objects.Disk(dev_type=constants.LD_LV, logical_id="disk2", size=128),
1374       objects.Disk(dev_type=constants.LD_LV, logical_id="disk3",
1375                    size=constants.MAX_WIPE_CHUNK),
1376       ]
1377
1378     instance = objects.Instance(name="inst3560",
1379                                 primary_node="node1.example.com",
1380                                 disk_template=constants.DT_PLAIN,
1381                                 disks=disks)
1382
1383     cmdlib._WipeDisks(lu, instance)
1384
1385     self.assertEqual(pt.history, [
1386       ("disk0", 1024, True),
1387       ("disk1", 500 * 1024, True),
1388       ("disk2", 128, True),
1389       ("disk3", constants.MAX_WIPE_CHUNK, True),
1390       ("disk0", 1024, False),
1391       ("disk1", 500 * 1024, False),
1392       ("disk2", 128, False),
1393       ("disk3", constants.MAX_WIPE_CHUNK, False),
1394       ])
1395
1396     # Ensure the complete disk has been wiped
1397     self.assertEqual(progress, dict((i.logical_id, i.size) for i in disks))
1398
1399
if __name__ == "__main__":
  # Run this module's test cases through Ganeti's test program wrapper
  testutils.GanetiTestProgram()