1aaeed7ed394631d6bfab9cc8e2d02a2023c8623
[ganeti-local] / test / py / ganeti.backend_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010, 2013 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.backend"""
23
24 import mock
25 import os
26 import shutil
27 import tempfile
28 import testutils
29 import unittest
30
31 from ganeti import backend
32 from ganeti import constants
33 from ganeti import errors
34 from ganeti import hypervisor
35 from ganeti import netutils
36 from ganeti import utils
37
38
class TestX509Certificates(unittest.TestCase):
  """Tests for backend.CreateX509Certificate and RemoveX509Certificate."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def test(self):
    (name, cert_pem) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)

    # The certificate file must contain exactly the returned PEM data
    self.assertEqual(utils.ReadFile(os.path.join(self.tmpdir, name,
                                                 backend._X509_CERT_FILE)),
                     cert_pem)
    # The private key file must exist and be non-empty
    self.assertTrue(os.path.getsize(os.path.join(self.tmpdir, name,
                                                 backend._X509_KEY_FILE)) > 0)

    (name2, _) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)

    backend.RemoveX509Certificate(name, cryptodir=self.tmpdir)
    backend.RemoveX509Certificate(name2, cryptodir=self.tmpdir)

    # Removing both certificates leaves the crypto directory empty
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [])

  def testNonEmpty(self):
    (name, _) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)

    # An unexpected extra file makes the certificate directory non-removable
    utils.WriteFile(utils.PathJoin(self.tmpdir, name, "hello-world"),
                    data="Hello World")

    self.assertRaises(backend.RPCFail, backend.RemoveX509Certificate,
                      name, cryptodir=self.tmpdir)

    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [name])
73
74
class TestNodeVerify(testutils.GanetiTestCase):
  """Tests for backend.VerifyNode and its hypervisor helpers."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    # Set by _GetHypervisor so tests can inspect the mocked methods
    self._mock_hv = None

  def _GetHypervisor(self, hv_name):
    """Returns a hypervisor whose ValidateParameters/Verify are mocks."""
    self._mock_hv = hypervisor.GetHypervisor(hv_name)
    self._mock_hv.ValidateParameters = mock.Mock()
    self._mock_hv.Verify = mock.Mock()
    return self._mock_hv

  def testMasterIPLocalhost(self):
    # this is a real functional test, but requires localhost to be reachable
    local_data = (netutils.Hostname.GetSysName(),
                  constants.IP4_ADDRESS_LOCALHOST)
    result = backend.VerifyNode({constants.NV_MASTERIP: local_data}, None, {})
    self.assertTrue(constants.NV_MASTERIP in result,
                    "Master IP data not returned")
    self.assertTrue(result[constants.NV_MASTERIP], "Cannot reach localhost")

  def testMasterIPUnreachable(self):
    # Network 192.0.2.0/24 is reserved for test/documentation as per
    # RFC 5737
    bad_data = ("master.example.com", "192.0.2.1")
    # we just test that whatever TcpPing returns, VerifyNode returns too;
    # the module-level function is replaced only for the duration of the
    # call so that other tests keep using the real implementation
    orig_tcp_ping = netutils.TcpPing
    netutils.TcpPing = lambda a, b, source=None: False
    try:
      result = backend.VerifyNode({constants.NV_MASTERIP: bad_data}, None, {})
    finally:
      netutils.TcpPing = orig_tcp_ping
    self.assertTrue(constants.NV_MASTERIP in result,
                    "Master IP data not returned")
    self.assertFalse(result[constants.NV_MASTERIP],
                     "Result from netutils.TcpPing corrupted")

  def testVerifyHvparams(self):
    test_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
    test_what = {constants.NV_HVPARAMS:
                   [("mynode", constants.HT_XEN_PVM, test_hvparams)]}
    result = {}
    backend._VerifyHvparams(test_what, True, result,
                            get_hv_fn=self._GetHypervisor)
    self._mock_hv.ValidateParameters.assert_called_with(test_hvparams)

  def testVerifyHypervisors(self):
    hvname = constants.HT_XEN_PVM
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
    all_hvparams = {hvname: hvparams}
    test_what = {constants.NV_HYPERVISOR: [hvname]}
    result = {}
    backend._VerifyHypervisors(
        test_what, True, result, all_hvparams=all_hvparams,
        get_hv_fn=self._GetHypervisor)
    self._mock_hv.Verify.assert_called_with(hvparams=hvparams)
127
128
129 def _DefRestrictedCmdOwner():
130   return (os.getuid(), os.getgid())
131
132
class TestVerifyRestrictedCmdName(unittest.TestCase):
  """Tests for backend._VerifyRestrictedCmdName."""

  def _AssertRejected(self, names, expected_msg):
    """Checks that every name in C{names} is refused with C{expected_msg}."""
    for name in names:
      (ok, err) = backend._VerifyRestrictedCmdName(name)
      self.assertFalse(ok)
      self.assertEqual(err, expected_msg)

  def testAcceptableName(self):
    # Each accepted name must also be accepted in upper and title case
    variants = [
      lambda s: s,
      lambda s: s.upper(),
      lambda s: s.title(),
      ]
    for name in ["foo", "bar", "z1", "000first", "hello-world"]:
      for variant in variants:
        (ok, err) = backend._VerifyRestrictedCmdName(variant(name))
        self.assertTrue(ok)
        self.assertTrue(err is None)

  def testEmptyAndSpace(self):
    self._AssertRejected(["", " ", "\t", "\n"], "Missing command name")

  def testNameWithSlashes(self):
    self._AssertRejected(["/", "./foo", "../moo", "some/name"],
                         "Invalid command name")

  def testForbiddenCharacters(self):
    self._AssertRejected(["#", ".", "..", "bash -c ls", "'"],
                         "Command name contains forbidden characters")
158
159
160 class TestVerifyRestrictedCmdDirectory(unittest.TestCase):
161   def setUp(self):
162     self.tmpdir = tempfile.mkdtemp()
163
164   def tearDown(self):
165     shutil.rmtree(self.tmpdir)
166
167   def testCanNotStat(self):
168     tmpname = utils.PathJoin(self.tmpdir, "foobar")
169     self.assertFalse(os.path.exists(tmpname))
170     (status, msg) = \
171       backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
172     self.assertFalse(status)
173     self.assertTrue(msg.startswith("Can't stat(2) '"))
174
175   def testTooPermissive(self):
176     tmpname = utils.PathJoin(self.tmpdir, "foobar")
177     os.mkdir(tmpname)
178
179     for mode in [0777, 0706, 0760, 0722]:
180       os.chmod(tmpname, mode)
181       self.assertTrue(os.path.isdir(tmpname))
182       (status, msg) = \
183         backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
184       self.assertFalse(status)
185       self.assertTrue(msg.startswith("Permissions on '"))
186
187   def testNoDirectory(self):
188     tmpname = utils.PathJoin(self.tmpdir, "foobar")
189     utils.WriteFile(tmpname, data="empty\n")
190     self.assertTrue(os.path.isfile(tmpname))
191     (status, msg) = \
192       backend._VerifyRestrictedCmdDirectory(tmpname,
193                                             _owner=_DefRestrictedCmdOwner())
194     self.assertFalse(status)
195     self.assertTrue(msg.endswith("is not a directory"))
196
197   def testNormal(self):
198     tmpname = utils.PathJoin(self.tmpdir, "foobar")
199     os.mkdir(tmpname)
200     self.assertTrue(os.path.isdir(tmpname))
201     (status, msg) = \
202       backend._VerifyRestrictedCmdDirectory(tmpname,
203                                             _owner=_DefRestrictedCmdOwner())
204     self.assertTrue(status)
205     self.assertTrue(msg is None)
206
207
208 class TestVerifyRestrictedCmd(unittest.TestCase):
209   def setUp(self):
210     self.tmpdir = tempfile.mkdtemp()
211
212   def tearDown(self):
213     shutil.rmtree(self.tmpdir)
214
215   def testCanNotStat(self):
216     tmpname = utils.PathJoin(self.tmpdir, "helloworld")
217     self.assertFalse(os.path.exists(tmpname))
218     (status, msg) = \
219       backend._VerifyRestrictedCmd(self.tmpdir, "helloworld",
220                                    _owner=NotImplemented)
221     self.assertFalse(status)
222     self.assertTrue(msg.startswith("Can't stat(2) '"))
223
224   def testNotExecutable(self):
225     tmpname = utils.PathJoin(self.tmpdir, "cmdname")
226     utils.WriteFile(tmpname, data="empty\n")
227     (status, msg) = \
228       backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
229                                    _owner=_DefRestrictedCmdOwner())
230     self.assertFalse(status)
231     self.assertTrue(msg.startswith("access(2) thinks '"))
232
233   def testExecutable(self):
234     tmpname = utils.PathJoin(self.tmpdir, "cmdname")
235     utils.WriteFile(tmpname, data="empty\n", mode=0700)
236     (status, executable) = \
237       backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
238                                    _owner=_DefRestrictedCmdOwner())
239     self.assertTrue(status)
240     self.assertEqual(executable, tmpname)
241
242
class TestPrepareRestrictedCmd(unittest.TestCase):
  """Tests for backend._PrepareRestrictedCmd.

  The directory, name and command verifiers are injected, so each check can
  be made to fail or succeed independently.

  """
  _TEST_PATH = "/tmp/some/test/path"

  @staticmethod
  def _Ok(_):
    """Fake verifier which always succeeds."""
    return (True, None)

  def _Prepare(self, cmd, verify_dir, verify_name, verify_cmd):
    """Calls _PrepareRestrictedCmd on the test path with the given fakes."""
    return backend._PrepareRestrictedCmd(self._TEST_PATH, cmd,
                                         _verify_dir=verify_dir,
                                         _verify_name=verify_name,
                                         _verify_cmd=verify_cmd)

  def testDirFails(self):
    def check_dir(path):
      self.assertEqual(path, self._TEST_PATH)
      return (False, "test error 31420")

    (ok, err) = self._Prepare("cmd21152", check_dir, NotImplemented,
                              NotImplemented)
    self.assertFalse(ok)
    self.assertEqual(err, "test error 31420")

  def testNameFails(self):
    def check_name(cmd):
      self.assertEqual(cmd, "cmd4617")
      return (False, "test error 591")

    (ok, err) = self._Prepare("cmd4617", self._Ok, check_name, NotImplemented)
    self.assertFalse(ok)
    self.assertEqual(err, "test error 591")

  def testCommandFails(self):
    def check_cmd(path, cmd):
      self.assertEqual(path, self._TEST_PATH)
      self.assertEqual(cmd, "cmd17577")
      return (False, "test error 25524")

    (ok, err) = self._Prepare("cmd17577", self._Ok, self._Ok, check_cmd)
    self.assertFalse(ok)
    self.assertEqual(err, "test error 25524")

  def testSuccess(self):
    def check_cmd(path, cmd):
      return (True, utils.PathJoin(path, cmd))

    (ok, executable) = self._Prepare("cmd22633", self._Ok, self._Ok, check_cmd)
    self.assertTrue(ok)
    self.assertEqual(executable, utils.PathJoin(self._TEST_PATH, "cmd22633"))
297
298
299 def _SleepForRestrictedCmd(duration):
300   assert duration > 5
301
302
303 def _GenericRestrictedCmdError(cmd):
304   return "Executing command '%s' failed" % cmd
305
306
307 class TestRunRestrictedCmd(unittest.TestCase):
308   def setUp(self):
309     self.tmpdir = tempfile.mkdtemp()
310
311   def tearDown(self):
312     shutil.rmtree(self.tmpdir)
313
314   def testNonExistantLockDirectory(self):
315     lockfile = utils.PathJoin(self.tmpdir, "does", "not", "exist")
316     sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
317     self.assertFalse(os.path.exists(lockfile))
318     self.assertRaises(backend.RPCFail,
319                       backend.RunRestrictedCmd, "test",
320                       _lock_timeout=NotImplemented,
321                       _lock_file=lockfile,
322                       _path=NotImplemented,
323                       _sleep_fn=sleep_fn,
324                       _prepare_fn=NotImplemented,
325                       _runcmd_fn=NotImplemented,
326                       _enabled=True)
327     self.assertEqual(sleep_fn.Count(), 1)
328
329   @staticmethod
330   def _TryLock(lockfile):
331     sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
332
333     result = False
334     try:
335       backend.RunRestrictedCmd("test22717",
336                                _lock_timeout=0.1,
337                                _lock_file=lockfile,
338                                _path=NotImplemented,
339                                _sleep_fn=sleep_fn,
340                                _prepare_fn=NotImplemented,
341                                _runcmd_fn=NotImplemented,
342                                _enabled=True)
343     except backend.RPCFail, err:
344       assert str(err) == _GenericRestrictedCmdError("test22717"), \
345              "Did not fail with generic error message"
346       result = True
347
348     assert sleep_fn.Count() == 1
349
350     return result
351
352   def testLockHeldByOtherProcess(self):
353     lockfile = utils.PathJoin(self.tmpdir, "lock")
354
355     lock = utils.FileLock.Open(lockfile)
356     lock.Exclusive(blocking=True, timeout=1.0)
357     try:
358       self.assertTrue(utils.RunInSeparateProcess(self._TryLock, lockfile))
359     finally:
360       lock.Close()
361
362   @staticmethod
363   def _PrepareRaisingException(path, cmd):
364     assert cmd == "test23122"
365     raise Exception("test")
366
367   def testPrepareRaisesException(self):
368     lockfile = utils.PathJoin(self.tmpdir, "lock")
369
370     sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
371     prepare_fn = testutils.CallCounter(self._PrepareRaisingException)
372
373     try:
374       backend.RunRestrictedCmd("test23122",
375                                _lock_timeout=1.0, _lock_file=lockfile,
376                                _path=NotImplemented, _runcmd_fn=NotImplemented,
377                                _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
378                                _enabled=True)
379     except backend.RPCFail, err:
380       self.assertEqual(str(err), _GenericRestrictedCmdError("test23122"))
381     else:
382       self.fail("Didn't fail")
383
384     self.assertEqual(sleep_fn.Count(), 1)
385     self.assertEqual(prepare_fn.Count(), 1)
386
387   @staticmethod
388   def _PrepareFails(path, cmd):
389     assert cmd == "test29327"
390     return ("some error message", None)
391
392   def testPrepareFails(self):
393     lockfile = utils.PathJoin(self.tmpdir, "lock")
394
395     sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
396     prepare_fn = testutils.CallCounter(self._PrepareFails)
397
398     try:
399       backend.RunRestrictedCmd("test29327",
400                                _lock_timeout=1.0, _lock_file=lockfile,
401                                _path=NotImplemented, _runcmd_fn=NotImplemented,
402                                _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
403                                _enabled=True)
404     except backend.RPCFail, err:
405       self.assertEqual(str(err), _GenericRestrictedCmdError("test29327"))
406     else:
407       self.fail("Didn't fail")
408
409     self.assertEqual(sleep_fn.Count(), 1)
410     self.assertEqual(prepare_fn.Count(), 1)
411
412   @staticmethod
413   def _SuccessfulPrepare(path, cmd):
414     return (True, utils.PathJoin(path, cmd))
415
416   def testRunCmdFails(self):
417     lockfile = utils.PathJoin(self.tmpdir, "lock")
418
419     def fn(args, env=NotImplemented, reset_env=NotImplemented,
420            postfork_fn=NotImplemented):
421       self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test3079")])
422       self.assertEqual(env, {})
423       self.assertTrue(reset_env)
424       self.assertTrue(callable(postfork_fn))
425
426       trylock = utils.FileLock.Open(lockfile)
427       try:
428         # See if lockfile is still held
429         self.assertRaises(EnvironmentError, trylock.Exclusive, blocking=False)
430
431         # Call back to release lock
432         postfork_fn(NotImplemented)
433
434         # See if lockfile can be acquired
435         trylock.Exclusive(blocking=False)
436       finally:
437         trylock.Close()
438
439       # Simulate a failed command
440       return utils.RunResult(constants.EXIT_FAILURE, None,
441                              "stdout", "stderr406328567",
442                              utils.ShellQuoteArgs(args),
443                              NotImplemented, NotImplemented)
444
445     sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
446     prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
447     runcmd_fn = testutils.CallCounter(fn)
448
449     try:
450       backend.RunRestrictedCmd("test3079",
451                                _lock_timeout=1.0, _lock_file=lockfile,
452                                _path=self.tmpdir, _runcmd_fn=runcmd_fn,
453                                _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
454                                _enabled=True)
455     except backend.RPCFail, err:
456       self.assertTrue(str(err).startswith("Restricted command 'test3079'"
457                                           " failed:"))
458       self.assertTrue("stderr406328567" in str(err),
459                       msg="Error did not include output")
460     else:
461       self.fail("Didn't fail")
462
463     self.assertEqual(sleep_fn.Count(), 0)
464     self.assertEqual(prepare_fn.Count(), 1)
465     self.assertEqual(runcmd_fn.Count(), 1)
466
467   def testRunCmdSucceeds(self):
468     lockfile = utils.PathJoin(self.tmpdir, "lock")
469
470     def fn(args, env=NotImplemented, reset_env=NotImplemented,
471            postfork_fn=NotImplemented):
472       self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test5667")])
473       self.assertEqual(env, {})
474       self.assertTrue(reset_env)
475
476       # Call back to release lock
477       postfork_fn(NotImplemented)
478
479       # Simulate a successful command
480       return utils.RunResult(constants.EXIT_SUCCESS, None, "stdout14463", "",
481                              utils.ShellQuoteArgs(args),
482                              NotImplemented, NotImplemented)
483
484     sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
485     prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
486     runcmd_fn = testutils.CallCounter(fn)
487
488     result = backend.RunRestrictedCmd("test5667",
489                                       _lock_timeout=1.0, _lock_file=lockfile,
490                                       _path=self.tmpdir, _runcmd_fn=runcmd_fn,
491                                       _sleep_fn=sleep_fn,
492                                       _prepare_fn=prepare_fn,
493                                       _enabled=True)
494     self.assertEqual(result, "stdout14463")
495
496     self.assertEqual(sleep_fn.Count(), 0)
497     self.assertEqual(prepare_fn.Count(), 1)
498     self.assertEqual(runcmd_fn.Count(), 1)
499
500   def testCommandsDisabled(self):
501     try:
502       backend.RunRestrictedCmd("test",
503                                _lock_timeout=NotImplemented,
504                                _lock_file=NotImplemented,
505                                _path=NotImplemented,
506                                _sleep_fn=NotImplemented,
507                                _prepare_fn=NotImplemented,
508                                _runcmd_fn=NotImplemented,
509                                _enabled=False)
510     except backend.RPCFail, err:
511       self.assertEqual(str(err),
512                        "Restricted commands disabled at configure time")
513     else:
514       self.fail("Did not raise exception")
515
516
517 class TestSetWatcherPause(unittest.TestCase):
518   def setUp(self):
519     self.tmpdir = tempfile.mkdtemp()
520     self.filename = utils.PathJoin(self.tmpdir, "pause")
521
522   def tearDown(self):
523     shutil.rmtree(self.tmpdir)
524
525   def testUnsetNonExisting(self):
526     self.assertFalse(os.path.exists(self.filename))
527     backend.SetWatcherPause(None, _filename=self.filename)
528     self.assertFalse(os.path.exists(self.filename))
529
530   def testSetNonNumeric(self):
531     for i in ["", [], {}, "Hello World", "0", "1.0"]:
532       self.assertFalse(os.path.exists(self.filename))
533
534       try:
535         backend.SetWatcherPause(i, _filename=self.filename)
536       except backend.RPCFail, err:
537         self.assertEqual(str(err), "Duration must be numeric")
538       else:
539         self.fail("Did not raise exception")
540
541       self.assertFalse(os.path.exists(self.filename))
542
543   def testSet(self):
544     self.assertFalse(os.path.exists(self.filename))
545
546     for i in range(10):
547       backend.SetWatcherPause(i, _filename=self.filename)
548       self.assertEqual(utils.ReadFile(self.filename), "%s\n" % i)
549       self.assertEqual(os.stat(self.filename).st_mode & 0777, 0644)
550
551
class TestGetBlockDevSymlinkPath(unittest.TestCase):
  """Tests for backend._GetBlockDevSymlinkPath."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _Test(self, name, idx):
    """Checks the symlink path built for disk C{idx} of instance C{name}."""
    actual = backend._GetBlockDevSymlinkPath(name, idx, _dir=self.tmpdir)
    expected = "%s/%s%s%s" % (self.tmpdir, name, constants.DISK_SEPARATOR, idx)
    self.assertEqual(actual, expected)

  def test(self):
    for disk_index in range(100):
      self._Test("inst1.example.com", disk_index)
568
569
class TestGetInstanceList(unittest.TestCase):
  """Checks that backend.GetInstanceList passes per-hypervisor hvparams."""

  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
    """Minimal hypervisor whose methods the tests replace with mocks."""
    def __init__(self):
      hypervisor.hv_base.BaseHypervisor.__init__(self)

  def setUp(self):
    self._test_hv = self._TestHypervisor()
    self._test_hv.ListInstances = mock.Mock(
      return_value=["instance1", "instance2", "instance3"])

  def _GetHypervisor(self, name):
    # Every hypervisor name resolves to the same mocked instance
    return self._test_hv

  def testHvparams(self):
    fake_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
    backend.GetInstanceList([constants.HT_FAKE],
                            all_hvparams={constants.HT_FAKE: fake_hvparams},
                            get_hv_fn=self._GetHypervisor)
    self._test_hv.ListInstances.assert_called_with(hvparams=fake_hvparams)
590
591
class TestGetHvInfo(unittest.TestCase):
  """Tests for backend._GetHvInfoAll."""

  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
    """Minimal hypervisor whose methods the tests replace with mocks."""
    def __init__(self):
      hypervisor.hv_base.BaseHypervisor.__init__(self)

  def setUp(self):
    self._test_hv = self._TestHypervisor()
    self._test_hv.GetNodeInfo = mock.Mock()

  def _GetHypervisor(self, name):
    # Every hypervisor name resolves to the same mocked instance
    return self._test_hv

  def testGetHvInfoAllNone(self):
    self.assertTrue(backend._GetHvInfoAll(None) is None)

  def testGetHvInfoAll(self):
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
    backend._GetHvInfoAll([(constants.HT_XEN_PVM, hvparams)],
                          self._GetHypervisor)
    self._test_hv.GetNodeInfo.assert_called_with(hvparams=hvparams)
616
617
class TestApplyStorageInfoFunction(unittest.TestCase):
  """Tests for backend._ApplyStorageInfoFunction.

  The tests replace the module-level C{backend._STORAGE_TYPE_INFO_FN} table.
  The original table is saved in setUp and restored in tearDown, which also
  runs when an assertion fails, so other tests always see the real table.

  """
  _STORAGE_KEY = "some_key"
  _SOME_ARGS = ["some_args"]

  def setUp(self):
    self.mock_storage_fn = mock.Mock()
    self._info_fn_orig = backend._STORAGE_TYPE_INFO_FN

  def tearDown(self):
    backend._STORAGE_TYPE_INFO_FN = self._info_fn_orig

  def testApplyValidStorageType(self):
    storage_type = constants.ST_LVM_VG
    backend._STORAGE_TYPE_INFO_FN = {
      storage_type: self.mock_storage_fn,
      }

    backend._ApplyStorageInfoFunction(
        storage_type, self._STORAGE_KEY, self._SOME_ARGS)

    self.mock_storage_fn.assert_called_with(self._STORAGE_KEY, self._SOME_ARGS)

  def testApplyInvalidStorageType(self):
    storage_type = "invalid_storage_type"
    backend._STORAGE_TYPE_INFO_FN = {}

    self.assertRaises(KeyError, backend._ApplyStorageInfoFunction,
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)

  def testApplyNotImplementedStorageType(self):
    storage_type = "not_implemented_storage_type"
    # A None entry marks a storage type without a reporting function
    backend._STORAGE_TYPE_INFO_FN = {storage_type: None}

    self.assertRaises(NotImplementedError,
                      backend._ApplyStorageInfoFunction,
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
657
658
class TestGetLvmVgSpaceInfo(unittest.TestCase):
  """Tests for backend._GetLvmVgSpaceInfo."""

  def testValid(self):
    path = "somepath"
    excl_stor = True
    vg_info_fn = mock.Mock()
    orig_fn = backend._GetVgInfo
    backend._GetVgInfo = vg_info_fn
    try:
      backend._GetLvmVgSpaceInfo(path, [excl_stor])
    finally:
      # Always restore the real function so later tests are unaffected
      backend._GetVgInfo = orig_fn
    vg_info_fn.assert_called_with(path, excl_stor)

  def testNoExclStorageNotBool(self):
    path = "somepath"
    excl_stor = "123"
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
                      path, [excl_stor])

  def testNoExclStorageNotInList(self):
    path = "somepath"
    excl_stor = "123"
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
                      path, excl_stor)
681
class TestGetLvmPvSpaceInfo(unittest.TestCase):
  """Tests for backend._GetLvmPvSpaceInfo."""

  def testValid(self):
    path = "somepath"
    excl_stor = True
    spindles_fn = mock.Mock()
    orig_fn = backend._GetVgSpindlesInfo
    backend._GetVgSpindlesInfo = spindles_fn
    try:
      backend._GetLvmPvSpaceInfo(path, [excl_stor])
    finally:
      # Always restore the real function so later tests are unaffected
      backend._GetVgSpindlesInfo = orig_fn
    spindles_fn.assert_called_with(path, excl_stor)
692
693
class TestCheckStorageParams(unittest.TestCase):
  """Tests for backend._CheckStorageParams."""

  def _AssertInvalid(self, params, num_params):
    """Checks that the given parameters are rejected with ProgrammerError."""
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
                      params, num_params)

  def testParamsNone(self):
    self._AssertInvalid(None, NotImplemented)

  def testParamsWrongType(self):
    self._AssertInvalid("string", NotImplemented)

  def testParamsEmpty(self):
    backend._CheckStorageParams([], 0)

  def testParamsValidNumber(self):
    backend._CheckStorageParams(["a", True], 2)

  def testParamsInvalidNumber(self):
    self._AssertInvalid(["b", False], 3)
713
714
class TestGetVgSpindlesInfo(unittest.TestCase):
  """Tests for backend._GetVgSpindlesInfo with an injected info function."""

  def setUp(self):
    self.vg_free = 13
    self.vg_size = 31
    self.mock_fn = mock.Mock(return_value=(self.vg_free, self.vg_size))

  def _CheckHeader(self, result, name):
    """Checks the name and storage type entries of a result."""
    self.assertEqual(name, result["name"])
    self.assertEqual(constants.ST_LVM_PV, result["type"])

  def testValidInput(self):
    name = "myvg"
    result = backend._GetVgSpindlesInfo(name, True, info_fn=self.mock_fn)
    self.mock_fn.assert_called_with(name)
    self._CheckHeader(result, name)
    self.assertEqual(self.vg_free, result["storage_free"])
    self.assertEqual(self.vg_size, result["storage_size"])

  def testNoExclStor(self):
    name = "myvg"
    result = backend._GetVgSpindlesInfo(name, False, info_fn=self.mock_fn)
    # Without exclusive storage the info function is not consulted at all
    self.mock_fn.assert_not_called()
    self._CheckHeader(result, name)
    self.assertEqual(0, result["storage_free"])
    self.assertEqual(0, result["storage_size"])
741
742
class TestGetVgInfo(unittest.TestCase):
  """Tests for backend._GetVgInfo with an injected info function.

  This class needs a name distinct from TestGetVgSpindlesInfo. Otherwise it
  rebinds that module-level name and the spindle tests are never collected.

  """
  def testValidInput(self):
    vg_free = 13
    vg_size = 31
    info_fn = mock.Mock(return_value=[(vg_free, vg_size)])
    name = "myvg"
    excl_stor = True
    result = backend._GetVgInfo(name, excl_stor, info_fn=info_fn)
    info_fn.assert_called_with([name], excl_stor)
    self.assertEqual(name, result["name"])
    self.assertEqual(constants.ST_LVM_VG, result["type"])
    self.assertEqual(vg_free, result["storage_free"])
    self.assertEqual(vg_size, result["storage_size"])

  def testNoVgInfo(self):
    # The info function returns None: free and total size become None
    name = "myvg"
    excl_stor = True
    info_fn = mock.Mock(return_value=None)
    result = backend._GetVgInfo(name, excl_stor, info_fn=info_fn)
    info_fn.assert_called_with([name], excl_stor)
    self.assertEqual(name, result["name"])
    self.assertEqual(constants.ST_LVM_VG, result["type"])
    self.assertEqual(None, result["storage_free"])
    self.assertEqual(None, result["storage_size"])
768
769
class TestGetNodeInfo(unittest.TestCase):
  """Tests for backend.GetNodeInfo."""

  _SOME_RESULT = None

  def testApplyStorageInfoFunction(self):
    apply_fn = mock.Mock(return_value=self._SOME_RESULT)
    storage_units = [(st, st + "_key", [st + "_params"])
                     for st in constants.STORAGE_TYPES]

    orig_fn = backend._ApplyStorageInfoFunction
    backend._ApplyStorageInfoFunction = apply_fn
    try:
      backend.GetNodeInfo(storage_units, None)
    finally:
      # Always restore the real function so later tests are unaffected
      backend._ApplyStorageInfoFunction = orig_fn

    # One call per storage unit, each with that unit's key and parameters
    call_args_list = apply_fn.call_args_list
    self.assertEqual(len(constants.STORAGE_TYPES), len(call_args_list))
    for recorded in call_args_list:
      (storage_type, storage_key, storage_params) = recorded[0]
      self.assertEqual(storage_type + "_key", storage_key)
      self.assertEqual([storage_type + "_params"], storage_params)
      self.assertTrue(storage_type in constants.STORAGE_TYPES)
791
792
class TestSpaceReportingConstants(unittest.TestCase):
  """Ensures consistency between STS_REPORT and backend.

  These tests ensure that the constant 'STS_REPORT' is consistent
  with the implementation of invoking space reporting functions
  in backend.py. Once space reporting is available for all types,
  the constant can be removed and these tests as well.

  """
  def testAllReportingTypesHaveAReportingFunction(self):
    for storage_type in constants.STS_REPORT:
      self.assertTrue(backend._STORAGE_TYPE_INFO_FN[storage_type] is not None)

  def testAllNotReportingTypesDontHaveFunction(self):
    non_reporting_types = (set(constants.VALID_STORAGE_TYPES) -
                           set(constants.STS_REPORT))
    for storage_type in non_reporting_types:
      self.assertEqual(None, backend._STORAGE_TYPE_INFO_FN[storage_type])
811
812
if __name__ == "__main__":
  # When executed as a script, hand control to testutils.GanetiTestProgram
  testutils.GanetiTestProgram()