Merge branch 'stable-2.8' into stable-2.9
[ganeti-local] / test / py / ganeti.backend_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010, 2013 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.backend"""
23
24 import mock
25 import os
26 import shutil
27 import tempfile
28 import testutils
29 import unittest
30
31 from ganeti import backend
32 from ganeti import constants
33 from ganeti import errors
34 from ganeti import hypervisor
35 from ganeti import netutils
36 from ganeti import utils
37
38
class TestX509Certificates(unittest.TestCase):
  """Tests for backend.CreateX509Certificate and RemoveX509Certificate."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def test(self):
    (name, cert_pem) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)

    # The returned PEM must match the certificate file written to disk, and a
    # non-empty key file must have been written next to it
    self.assertEqual(utils.ReadFile(os.path.join(self.tmpdir, name,
                                                 backend._X509_CERT_FILE)),
                     cert_pem)
    self.assertTrue(0 < os.path.getsize(os.path.join(self.tmpdir, name,
                                                     backend._X509_KEY_FILE)))

    (name2, _) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)

    backend.RemoveX509Certificate(name, cryptodir=self.tmpdir)
    backend.RemoveX509Certificate(name2, cryptodir=self.tmpdir)

    # Removing both certificates must leave the directory empty
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [])

  def testNonEmpty(self):
    (name, _) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)

    # An unexpected extra file in the certificate directory must make the
    # removal fail and leave the directory in place
    utils.WriteFile(utils.PathJoin(self.tmpdir, name, "hello-world"),
                    data="Hello World")

    self.assertRaises(backend.RPCFail, backend.RemoveX509Certificate,
                      name, cryptodir=self.tmpdir)

    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [name])
73
74
class TestNodeVerify(testutils.GanetiTestCase):
  """Tests for backend.VerifyNode and its hypervisor verification helpers."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    # Populated by _GetHypervisor so tests can inspect the mocked methods
    self._mock_hv = None

  def _GetHypervisor(self, hv_name):
    """Return a real hypervisor object whose checks are replaced by mocks.

    The object is also stored in C{self._mock_hv} for later inspection.

    @param hv_name: name of the hypervisor to instantiate

    """
    self._mock_hv = hypervisor.GetHypervisor(hv_name)
    self._mock_hv.ValidateParameters = mock.Mock()
    self._mock_hv.Verify = mock.Mock()
    return self._mock_hv

  def testMasterIPLocalhost(self):
    # this is a real functional test, but requires localhost to be reachable
    local_data = (netutils.Hostname.GetSysName(),
                  constants.IP4_ADDRESS_LOCALHOST)
    result = backend.VerifyNode({constants.NV_MASTERIP: local_data}, None, {})
    self.assertTrue(constants.NV_MASTERIP in result,
                    "Master IP data not returned")
    self.assertTrue(result[constants.NV_MASTERIP], "Cannot reach localhost")

  def testMasterIPUnreachable(self):
    # Network 192.0.2.0/24 is reserved for test/documentation as per
    # RFC 5737
    bad_data = ("master.example.com", "192.0.2.1")
    # we just test that whatever TcpPing returns, VerifyNode returns too;
    # the module-level function is replaced only for the duration of the
    # call so that other tests keep using the real implementation
    orig_tcp_ping = netutils.TcpPing
    netutils.TcpPing = lambda a, b, source=None: False
    try:
      result = backend.VerifyNode({constants.NV_MASTERIP: bad_data}, None, {})
    finally:
      netutils.TcpPing = orig_tcp_ping
    self.assertTrue(constants.NV_MASTERIP in result,
                    "Master IP data not returned")
    self.assertFalse(result[constants.NV_MASTERIP],
                     "Result from netutils.TcpPing corrupted")

  def testVerifyHvparams(self):
    test_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
    test_what = {constants.NV_HVPARAMS:
        [("mynode", constants.HT_XEN_PVM, test_hvparams)]}
    result = {}
    backend._VerifyHvparams(test_what, True, result,
                            get_hv_fn=self._GetHypervisor)
    self._mock_hv.ValidateParameters.assert_called_with(test_hvparams)

  def testVerifyHypervisors(self):
    hvname = constants.HT_XEN_PVM
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
    all_hvparams = {hvname: hvparams}
    test_what = {constants.NV_HYPERVISOR: [hvname]}
    result = {}
    backend._VerifyHypervisors(
        test_what, True, result, all_hvparams=all_hvparams,
        get_hv_fn=self._GetHypervisor)
    self._mock_hv.Verify.assert_called_with(hvparams=hvparams)
127
128
129 def _DefRestrictedCmdOwner():
130   return (os.getuid(), os.getgid())
131
132
class TestVerifyRestrictedCmdName(unittest.TestCase):
  """Tests for backend._VerifyRestrictedCmdName."""

  def _AssertRejected(self, names, expected_msg):
    """Check that every name in C{names} is refused with C{expected_msg}."""
    for name in names:
      (status, msg) = backend._VerifyRestrictedCmdName(name)
      self.assertFalse(status)
      self.assertEqual(msg, expected_msg)

  def testAcceptableName(self):
    # Each base name is also tried upper-cased and title-cased
    transforms = [
      lambda name: name,
      lambda name: name.upper(),
      lambda name: name.title(),
      ]
    for base in ["foo", "bar", "z1", "000first", "hello-world"]:
      for transform in transforms:
        (status, msg) = backend._VerifyRestrictedCmdName(transform(base))
        self.assertTrue(status)
        self.assertTrue(msg is None)

  def testEmptyAndSpace(self):
    self._AssertRejected(["", " ", "\t", "\n"], "Missing command name")

  def testNameWithSlashes(self):
    self._AssertRejected(["/", "./foo", "../moo", "some/name"],
                         "Invalid command name")

  def testForbiddenCharacters(self):
    self._AssertRejected(["#", ".", "..", "bash -c ls", "'"],
                         "Command name contains forbidden characters")
158
159
class TestVerifyRestrictedCmdDirectory(unittest.TestCase):
  """Tests for backend._VerifyRestrictedCmdDirectory."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _TargetPath(self):
    """Return the path of the directory entry every test works on."""
    return utils.PathJoin(self.tmpdir, "foobar")

  def testCanNotStat(self):
    target = self._TargetPath()
    self.assertFalse(os.path.exists(target))
    status, msg = \
      backend._VerifyRestrictedCmdDirectory(target, _owner=NotImplemented)
    self.assertFalse(status)
    self.assertTrue(msg.startswith("Can't stat(2) '"))

  def testTooPermissive(self):
    target = self._TargetPath()
    os.mkdir(target)

    # Every mode below grants write permission to the group and/or others
    for mode in (0o777, 0o706, 0o760, 0o722):
      os.chmod(target, mode)
      self.assertTrue(os.path.isdir(target))
      status, msg = \
        backend._VerifyRestrictedCmdDirectory(target, _owner=NotImplemented)
      self.assertFalse(status)
      self.assertTrue(msg.startswith("Permissions on '"))

  def testNoDirectory(self):
    target = self._TargetPath()
    utils.WriteFile(target, data="empty\n")
    self.assertTrue(os.path.isfile(target))
    owner = _DefRestrictedCmdOwner()
    status, msg = backend._VerifyRestrictedCmdDirectory(target, _owner=owner)
    self.assertFalse(status)
    self.assertTrue(msg.endswith("is not a directory"))

  def testNormal(self):
    target = self._TargetPath()
    os.mkdir(target)
    os.chmod(target, 0o755)
    self.assertTrue(os.path.isdir(target))
    owner = _DefRestrictedCmdOwner()
    status, msg = backend._VerifyRestrictedCmdDirectory(target, _owner=owner)
    self.assertTrue(status)
    self.assertTrue(msg is None)
207
208
class TestVerifyRestrictedCmd(unittest.TestCase):
  """Tests for backend._VerifyRestrictedCmd."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _Verify(self, cmd, owner):
    """Run the check for C{cmd} inside this test's temporary directory."""
    return backend._VerifyRestrictedCmd(self.tmpdir, cmd, _owner=owner)

  def testCanNotStat(self):
    self.assertFalse(os.path.exists(utils.PathJoin(self.tmpdir, "helloworld")))
    (status, msg) = self._Verify("helloworld", NotImplemented)
    self.assertFalse(status)
    self.assertTrue(msg.startswith("Can't stat(2) '"))

  def testNotExecutable(self):
    # The file is written without an explicit mode
    utils.WriteFile(utils.PathJoin(self.tmpdir, "cmdname"), data="empty\n")
    (status, msg) = self._Verify("cmdname", _DefRestrictedCmdOwner())
    self.assertFalse(status)
    self.assertTrue(msg.startswith("access(2) thinks '"))

  def testExecutable(self):
    cmdpath = utils.PathJoin(self.tmpdir, "cmdname")
    utils.WriteFile(cmdpath, data="empty\n", mode=0o700)
    (status, executable) = self._Verify("cmdname", _DefRestrictedCmdOwner())
    self.assertTrue(status)
    self.assertEqual(executable, cmdpath)
242
243
class TestPrepareRestrictedCmd(unittest.TestCase):
  """Tests for backend._PrepareRestrictedCmd with stubbed verifiers."""

  _TEST_PATH = "/tmp/some/test/path"

  @staticmethod
  def _Accept(_):
    """Verifier stub that always reports success."""
    return (True, None)

  def testDirFails(self):
    def verify_dir(path):
      self.assertEqual(path, self._TEST_PATH)
      return (False, "test error 31420")

    status, msg = backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd21152",
                                                _verify_dir=verify_dir,
                                                _verify_name=NotImplemented,
                                                _verify_cmd=NotImplemented)
    self.assertFalse(status)
    self.assertEqual(msg, "test error 31420")

  def testNameFails(self):
    def verify_name(cmd):
      self.assertEqual(cmd, "cmd4617")
      return (False, "test error 591")

    status, msg = backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd4617",
                                                _verify_dir=self._Accept,
                                                _verify_name=verify_name,
                                                _verify_cmd=NotImplemented)
    self.assertFalse(status)
    self.assertEqual(msg, "test error 591")

  def testCommandFails(self):
    def verify_cmd(path, cmd):
      self.assertEqual(path, self._TEST_PATH)
      self.assertEqual(cmd, "cmd17577")
      return (False, "test error 25524")

    status, msg = backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd17577",
                                                _verify_dir=self._Accept,
                                                _verify_name=self._Accept,
                                                _verify_cmd=verify_cmd)
    self.assertFalse(status)
    self.assertEqual(msg, "test error 25524")

  def testSuccess(self):
    def verify_cmd(path, cmd):
      return (True, utils.PathJoin(path, cmd))

    status, executable = \
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd22633",
                                    _verify_dir=self._Accept,
                                    _verify_name=self._Accept,
                                    _verify_cmd=verify_cmd)
    self.assertTrue(status)
    self.assertEqual(executable, utils.PathJoin(self._TEST_PATH, "cmd22633"))
298
299
300 def _SleepForRestrictedCmd(duration):
301   assert duration > 5
302
303
304 def _GenericRestrictedCmdError(cmd):
305   return "Executing command '%s' failed" % cmd
306
307
class TestRunRestrictedCmd(unittest.TestCase):
  """Tests for backend.RunRestrictedCmd using injected helper functions."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testNonExistantLockDirectory(self):
    lockfile = utils.PathJoin(self.tmpdir, "does", "not", "exist")
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
    self.assertFalse(os.path.exists(lockfile))
    self.assertRaises(backend.RPCFail,
                      backend.RunRestrictedCmd, "test",
                      _lock_timeout=NotImplemented,
                      _lock_file=lockfile,
                      _path=NotImplemented,
                      _sleep_fn=sleep_fn,
                      _prepare_fn=NotImplemented,
                      _runcmd_fn=NotImplemented,
                      _enabled=True)
    self.assertEqual(sleep_fn.Count(), 1)

  @staticmethod
  def _TryLock(lockfile):
    """Attempt to run a command while C{lockfile} may be held elsewhere.

    Intended to be run in a separate process; plain C{assert} statements
    are used because no test case instance is available here.

    @param lockfile: path of the lock file to pass to RunRestrictedCmd
    @return: True if the call failed with the generic error message

    """
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)

    result = False
    try:
      backend.RunRestrictedCmd("test22717",
                               _lock_timeout=0.1,
                               _lock_file=lockfile,
                               _path=NotImplemented,
                               _sleep_fn=sleep_fn,
                               _prepare_fn=NotImplemented,
                               _runcmd_fn=NotImplemented,
                               _enabled=True)
    except backend.RPCFail as err:
      assert str(err) == _GenericRestrictedCmdError("test22717"), \
             "Did not fail with generic error message"
      result = True

    assert sleep_fn.Count() == 1

    return result

  def testLockHeldByOtherProcess(self):
    lockfile = utils.PathJoin(self.tmpdir, "lock")

    lock = utils.FileLock.Open(lockfile)
    lock.Exclusive(blocking=True, timeout=1.0)
    try:
      self.assertTrue(utils.RunInSeparateProcess(self._TryLock, lockfile))
    finally:
      lock.Close()

  @staticmethod
  def _PrepareRaisingException(path, cmd):
    assert cmd == "test23122"
    raise Exception("test")

  def testPrepareRaisesException(self):
    lockfile = utils.PathJoin(self.tmpdir, "lock")

    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
    prepare_fn = testutils.CallCounter(self._PrepareRaisingException)

    try:
      backend.RunRestrictedCmd("test23122",
                               _lock_timeout=1.0, _lock_file=lockfile,
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
                               _enabled=True)
    except backend.RPCFail as err:
      self.assertEqual(str(err), _GenericRestrictedCmdError("test23122"))
    else:
      self.fail("Didn't fail")

    self.assertEqual(sleep_fn.Count(), 1)
    self.assertEqual(prepare_fn.Count(), 1)

  @staticmethod
  def _PrepareFails(path, cmd):
    assert cmd == "test29327"
    # The first element is the status flag (compare _SuccessfulPrepare). It
    # must be False to report a failed preparation; a truthy value such as
    # an error string would signal success instead.
    return (False, "some error message")

  def testPrepareFails(self):
    lockfile = utils.PathJoin(self.tmpdir, "lock")

    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
    prepare_fn = testutils.CallCounter(self._PrepareFails)

    try:
      backend.RunRestrictedCmd("test29327",
                               _lock_timeout=1.0, _lock_file=lockfile,
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
                               _enabled=True)
    except backend.RPCFail as err:
      self.assertEqual(str(err), _GenericRestrictedCmdError("test29327"))
    else:
      self.fail("Didn't fail")

    self.assertEqual(sleep_fn.Count(), 1)
    self.assertEqual(prepare_fn.Count(), 1)

  @staticmethod
  def _SuccessfulPrepare(path, cmd):
    return (True, utils.PathJoin(path, cmd))

  def testRunCmdFails(self):
    lockfile = utils.PathJoin(self.tmpdir, "lock")

    def fn(args, env=NotImplemented, reset_env=NotImplemented,
           postfork_fn=NotImplemented):
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test3079")])
      self.assertEqual(env, {})
      self.assertTrue(reset_env)
      self.assertTrue(callable(postfork_fn))

      trylock = utils.FileLock.Open(lockfile)
      try:
        # See if lockfile is still held
        self.assertRaises(EnvironmentError, trylock.Exclusive, blocking=False)

        # Call back to release lock
        postfork_fn(NotImplemented)

        # See if lockfile can be acquired
        trylock.Exclusive(blocking=False)
      finally:
        trylock.Close()

      # Simulate a failed command
      return utils.RunResult(constants.EXIT_FAILURE, None,
                             "stdout", "stderr406328567",
                             utils.ShellQuoteArgs(args),
                             NotImplemented, NotImplemented)

    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
    runcmd_fn = testutils.CallCounter(fn)

    try:
      backend.RunRestrictedCmd("test3079",
                               _lock_timeout=1.0, _lock_file=lockfile,
                               _path=self.tmpdir, _runcmd_fn=runcmd_fn,
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
                               _enabled=True)
    except backend.RPCFail as err:
      self.assertTrue(str(err).startswith("Restricted command 'test3079'"
                                          " failed:"))
      self.assertTrue("stderr406328567" in str(err),
                      msg="Error did not include output")
    else:
      self.fail("Didn't fail")

    self.assertEqual(sleep_fn.Count(), 0)
    self.assertEqual(prepare_fn.Count(), 1)
    self.assertEqual(runcmd_fn.Count(), 1)

  def testRunCmdSucceeds(self):
    lockfile = utils.PathJoin(self.tmpdir, "lock")

    def fn(args, env=NotImplemented, reset_env=NotImplemented,
           postfork_fn=NotImplemented):
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test5667")])
      self.assertEqual(env, {})
      self.assertTrue(reset_env)

      # Call back to release lock
      postfork_fn(NotImplemented)

      # Simulate a successful command
      return utils.RunResult(constants.EXIT_SUCCESS, None, "stdout14463", "",
                             utils.ShellQuoteArgs(args),
                             NotImplemented, NotImplemented)

    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
    runcmd_fn = testutils.CallCounter(fn)

    result = backend.RunRestrictedCmd("test5667",
                                      _lock_timeout=1.0, _lock_file=lockfile,
                                      _path=self.tmpdir, _runcmd_fn=runcmd_fn,
                                      _sleep_fn=sleep_fn,
                                      _prepare_fn=prepare_fn,
                                      _enabled=True)
    self.assertEqual(result, "stdout14463")

    self.assertEqual(sleep_fn.Count(), 0)
    self.assertEqual(prepare_fn.Count(), 1)
    self.assertEqual(runcmd_fn.Count(), 1)

  def testCommandsDisabled(self):
    try:
      backend.RunRestrictedCmd("test",
                               _lock_timeout=NotImplemented,
                               _lock_file=NotImplemented,
                               _path=NotImplemented,
                               _sleep_fn=NotImplemented,
                               _prepare_fn=NotImplemented,
                               _runcmd_fn=NotImplemented,
                               _enabled=False)
    except backend.RPCFail as err:
      self.assertEqual(str(err),
                       "Restricted commands disabled at configure time")
    else:
      self.fail("Did not raise exception")
516
517
class TestSetWatcherPause(unittest.TestCase):
  """Tests for backend.SetWatcherPause using a temporary pause file."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()
    self.filename = utils.PathJoin(self.tmpdir, "pause")

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _AssertNoPauseFile(self):
    self.assertFalse(os.path.exists(self.filename))

  def testUnsetNonExisting(self):
    self._AssertNoPauseFile()
    backend.SetWatcherPause(None, _filename=self.filename)
    self._AssertNoPauseFile()

  def testSetNonNumeric(self):
    for value in ["", [], {}, "Hello World", "0", "1.0"]:
      self._AssertNoPauseFile()

      try:
        backend.SetWatcherPause(value, _filename=self.filename)
      except backend.RPCFail as err:
        self.assertEqual(str(err), "Duration must be numeric")
      else:
        self.fail("Did not raise exception")

      # A rejected value must not leave a pause file behind
      self._AssertNoPauseFile()

  def testSet(self):
    self._AssertNoPauseFile()

    for duration in range(10):
      backend.SetWatcherPause(duration, _filename=self.filename)
      self.assertEqual(utils.ReadFile(self.filename), "%s\n" % duration)
      permissions = os.stat(self.filename).st_mode & 0o777
      self.assertEqual(permissions, 0o644)
551
552
class TestGetBlockDevSymlinkPath(unittest.TestCase):
  """Tests for backend._GetBlockDevSymlinkPath."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _Test(self, name, idx):
    actual = backend._GetBlockDevSymlinkPath(name, idx, _dir=self.tmpdir)
    # Expected layout: <dir>/<instance name><DISK_SEPARATOR><disk index>
    expected = "%s/%s%s%s" % (self.tmpdir, name,
                              constants.DISK_SEPARATOR, idx)
    self.assertEqual(actual, expected)

  def test(self):
    for idx in range(100):
      self._Test("inst1.example.com", idx)
569
570
class TestGetInstanceList(unittest.TestCase):
  """Tests that backend.GetInstanceList forwards per-hypervisor params."""

  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
    """Minimal hypervisor subclass used as a stand-in."""
    def __init__(self):
      hypervisor.hv_base.BaseHypervisor.__init__(self)

  def setUp(self):
    self._test_hv = self._TestHypervisor()
    self._test_hv.ListInstances = mock.Mock(
      return_value=["instance1", "instance2", "instance3"])

  def _GetHypervisor(self, name):
    # The same mocked object is returned regardless of the requested name
    return self._test_hv

  def testHvparams(self):
    fake_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
    all_hvparams = {constants.HT_FAKE: fake_hvparams}
    backend.GetInstanceList([constants.HT_FAKE], all_hvparams=all_hvparams,
                            get_hv_fn=self._GetHypervisor)
    self._test_hv.ListInstances.assert_called_with(hvparams=fake_hvparams)
591
592
class TestGetHvInfo(unittest.TestCase):
  """Tests for backend._GetHvInfoAll."""

  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
    """Minimal hypervisor subclass used as a stand-in."""
    def __init__(self):
      hypervisor.hv_base.BaseHypervisor.__init__(self)

  def setUp(self):
    self._test_hv = self._TestHypervisor()
    self._test_hv.GetNodeInfo = mock.Mock()

  def _GetHypervisor(self, name):
    # The same mocked object is returned regardless of the requested name
    return self._test_hv

  def testGetHvInfoAllNone(self):
    self.assertTrue(backend._GetHvInfoAll(None) is None)

  def testGetHvInfoAll(self):
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
    hv_specs = [(constants.HT_XEN_PVM, hvparams)]

    backend._GetHvInfoAll(hv_specs, self._GetHypervisor)
    self._test_hv.GetNodeInfo.assert_called_with(hvparams=hvparams)
617
618
class TestApplyStorageInfoFunction(unittest.TestCase):
  """Tests for backend._ApplyStorageInfoFunction.

  Each test replaces the module-global C{backend._STORAGE_TYPE_INFO_FN}
  dispatch table; setUp/tearDown save and restore it so a failing assertion
  cannot leak the fake table into other tests.

  """

  _STORAGE_KEY = "some_key"
  _SOME_ARGS = ["some_args"]

  def setUp(self):
    self.mock_storage_fn = mock.Mock()
    self._info_fn_orig = backend._STORAGE_TYPE_INFO_FN

  def tearDown(self):
    backend._STORAGE_TYPE_INFO_FN = self._info_fn_orig

  def testApplyValidStorageType(self):
    storage_type = constants.ST_LVM_VG
    backend._STORAGE_TYPE_INFO_FN = {
        storage_type: self.mock_storage_fn
      }

    backend._ApplyStorageInfoFunction(
        storage_type, self._STORAGE_KEY, self._SOME_ARGS)

    self.mock_storage_fn.assert_called_with(self._STORAGE_KEY, self._SOME_ARGS)

  def testApplyInValidStorageType(self):
    storage_type = "invalid_storage_type"
    backend._STORAGE_TYPE_INFO_FN = {}

    self.assertRaises(KeyError, backend._ApplyStorageInfoFunction,
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)

  def testApplyNotImplementedStorageType(self):
    storage_type = "not_implemented_storage_type"
    # A None entry marks a storage type without an info function
    backend._STORAGE_TYPE_INFO_FN = {storage_type: None}

    self.assertRaises(NotImplementedError,
                      backend._ApplyStorageInfoFunction,
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
658
659
class TestGetLvmVgSpaceInfo(unittest.TestCase):
  """Tests for backend._GetLvmVgSpaceInfo."""

  def testValid(self):
    path = "somepath"
    excl_stor = True
    vg_info_mock = mock.Mock()
    orig_fn = backend._GetVgInfo
    backend._GetVgInfo = vg_info_mock
    try:
      backend._GetLvmVgSpaceInfo(path, [excl_stor])
    finally:
      # Always restore the real helper so other tests are unaffected
      backend._GetVgInfo = orig_fn
    vg_info_mock.assert_called_with(path, excl_stor)

  def testNoExclStorageNotBool(self):
    path = "somepath"
    excl_stor = "123"
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
                      path, [excl_stor])

  def testNoExclStorageNotInList(self):
    path = "somepath"
    excl_stor = "123"
    # The parameter is passed bare instead of wrapped in a list
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
                      path, excl_stor)
682
class TestGetLvmPvSpaceInfo(unittest.TestCase):
  """Tests for backend._GetLvmPvSpaceInfo."""

  def testValid(self):
    path = "somepath"
    excl_stor = True
    spindles_mock = mock.Mock()
    orig_fn = backend._GetVgSpindlesInfo
    backend._GetVgSpindlesInfo = spindles_mock
    try:
      backend._GetLvmPvSpaceInfo(path, [excl_stor])
    finally:
      # Always restore the real helper so other tests are unaffected
      backend._GetVgSpindlesInfo = orig_fn
    spindles_mock.assert_called_with(path, excl_stor)
693
694
class TestCheckStorageParams(unittest.TestCase):
  """Tests for backend._CheckStorageParams."""

  def _AssertInvalid(self, params, num_params):
    """Check that the given parameters are rejected as a programmer error."""
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
                      params, num_params)

  def testParamsNone(self):
    self._AssertInvalid(None, NotImplemented)

  def testParamsWrongType(self):
    self._AssertInvalid("string", NotImplemented)

  def testParamsEmpty(self):
    # Must not raise
    backend._CheckStorageParams([], 0)

  def testParamsValidNumber(self):
    # Must not raise
    backend._CheckStorageParams(["a", True], 2)

  def testParamsInvalidNumber(self):
    self._AssertInvalid(["b", False], 3)
714
715
class TestGetVgSpindlesInfo(unittest.TestCase):
  """Tests for backend._GetVgSpindlesInfo with a mocked info function."""

  def setUp(self):
    self.vg_free = 13
    self.vg_size = 31
    self.mock_fn = mock.Mock(return_value=(self.vg_free, self.vg_size))

  def testValidInput(self):
    name = "myvg"
    excl_stor = True
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
    self.mock_fn.assert_called_with(name)
    self.assertEqual(name, result["name"])
    self.assertEqual(constants.ST_LVM_PV, result["type"])
    self.assertEqual(self.vg_free, result["storage_free"])
    self.assertEqual(self.vg_size, result["storage_size"])

  def testNoExclStor(self):
    name = "myvg"
    excl_stor = False
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
    # Check the "called" flag directly: Mock.assert_not_called() is missing
    # from older mock releases, where calling it is auto-mocked and
    # silently passes
    self.assertFalse(self.mock_fn.called)
    self.assertEqual(name, result["name"])
    self.assertEqual(constants.ST_LVM_PV, result["type"])
    self.assertEqual(0, result["storage_free"])
    self.assertEqual(0, result["storage_size"])
742
743
class TestGetVgInfo(unittest.TestCase):
  """Tests for backend._GetVgInfo with a mocked info function.

  This class was previously also named TestGetVgSpindlesInfo, which
  shadowed the earlier class of that name so its tests never ran.

  """

  def testValidInput(self):
    vg_free = 13
    vg_size = 31
    mock_fn = mock.Mock(return_value=[(vg_free, vg_size)])
    name = "myvg"
    excl_stor = True
    result = backend._GetVgInfo(name, excl_stor, info_fn=mock_fn)
    mock_fn.assert_called_with([name], excl_stor)
    self.assertEqual(name, result["name"])
    self.assertEqual(constants.ST_LVM_VG, result["type"])
    self.assertEqual(vg_free, result["storage_free"])
    self.assertEqual(vg_size, result["storage_size"])

  def testNoVgInfo(self):
    # The info function returns no data; free/size must be reported as None
    name = "myvg"
    excl_stor = True
    mock_fn = mock.Mock(return_value=None)
    result = backend._GetVgInfo(name, excl_stor, info_fn=mock_fn)
    mock_fn.assert_called_with([name], excl_stor)
    self.assertEqual(name, result["name"])
    self.assertEqual(constants.ST_LVM_VG, result["type"])
    self.assertEqual(None, result["storage_free"])
    self.assertEqual(None, result["storage_size"])
769
770
class TestGetNodeInfo(unittest.TestCase):
  """Tests for backend.GetNodeInfo."""

  _SOME_RESULT = None

  def testApplyStorageInfoFunction(self):
    apply_mock = mock.Mock(return_value=self._SOME_RESULT)
    orig_fn = backend._ApplyStorageInfoFunction
    backend._ApplyStorageInfoFunction = apply_mock
    try:
      storage_units = [(st, st + "_key", [st + "_params"]) for st in
                       constants.STORAGE_TYPES]
      backend.GetNodeInfo(storage_units, None)
    finally:
      # Always restore the real function so other tests are unaffected
      backend._ApplyStorageInfoFunction = orig_fn

    # Exactly one call per storage unit, each with the matching key/params
    call_args_list = apply_mock.call_args_list
    self.assertEqual(len(constants.STORAGE_TYPES), len(call_args_list))
    for call_args in call_args_list:
      storage_type, storage_key, storage_params = call_args[0]
      self.assertEqual(storage_type + "_key", storage_key)
      self.assertEqual([storage_type + "_params"], storage_params)
      self.assertTrue(storage_type in constants.STORAGE_TYPES)
792
793
class TestSpaceReportingConstants(unittest.TestCase):
  """Ensures consistency between STS_REPORT and backend.

  These tests ensure that the constant 'STS_REPORT' is consistent
  with the implementation of invoking space reporting functions
  in backend.py. Once space reporting is available for all types,
  the constant can be removed and these tests as well.

  """
  def testAllReportingTypesHaveAReportingFunction(self):
    for storage_type in constants.STS_REPORT:
      self.assertTrue(backend._STORAGE_TYPE_INFO_FN[storage_type] is not None)

  def testAllNotReportingTypesDontHaveFunction(self):
    non_reporting_types = \
      set(constants.STORAGE_TYPES) - set(constants.STS_REPORT)
    for storage_type in non_reporting_types:
      self.assertEqual(None, backend._STORAGE_TYPE_INFO_FN[storage_type])
812
813
if __name__ == "__main__":
  # Entry point when this file is executed directly rather than imported
  testutils.GanetiTestProgram()