Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.backend_unittest.py @ d8e55568

History | View | Annotate | Download (28 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.backend"""
23

    
24
import mock
25
import os
26
import shutil
27
import tempfile
28
import testutils
29
import unittest
30

    
31
from ganeti import backend
32
from ganeti import constants
33
from ganeti import errors
34
from ganeti import hypervisor
35
from ganeti import netutils
36
from ganeti import utils
37

    
38

    
39
class TestX509Certificates(unittest.TestCase):
40
  def setUp(self):
41
    self.tmpdir = tempfile.mkdtemp()
42

    
43
  def tearDown(self):
44
    shutil.rmtree(self.tmpdir)
45

    
46
  def test(self):
47
    (name, cert_pem) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
48

    
49
    self.assertEqual(utils.ReadFile(os.path.join(self.tmpdir, name,
50
                                                 backend._X509_CERT_FILE)),
51
                     cert_pem)
52
    self.assert_(0 < os.path.getsize(os.path.join(self.tmpdir, name,
53
                                                  backend._X509_KEY_FILE)))
54

    
55
    (name2, cert_pem2) = \
56
      backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
57

    
58
    backend.RemoveX509Certificate(name, cryptodir=self.tmpdir)
59
    backend.RemoveX509Certificate(name2, cryptodir=self.tmpdir)
60

    
61
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [])
62

    
63
  def testNonEmpty(self):
64
    (name, _) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
65

    
66
    utils.WriteFile(utils.PathJoin(self.tmpdir, name, "hello-world"),
67
                    data="Hello World")
68

    
69
    self.assertRaises(backend.RPCFail, backend.RemoveX509Certificate,
70
                      name, cryptodir=self.tmpdir)
71

    
72
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [name])
73

    
74

    
75
class TestNodeVerify(testutils.GanetiTestCase):
76

    
77
  def setUp(self):
78
    testutils.GanetiTestCase.setUp(self)
79
    self._mock_hv = None
80

    
81
  def _GetHypervisor(self, hv_name):
82
    self._mock_hv = hypervisor.GetHypervisor(hv_name)
83
    self._mock_hv.ValidateParameters = mock.Mock()
84
    self._mock_hv.Verify = mock.Mock()
85
    return self._mock_hv
86

    
87
  def testMasterIPLocalhost(self):
88
    # this a real functional test, but requires localhost to be reachable
89
    local_data = (netutils.Hostname.GetSysName(),
90
                  constants.IP4_ADDRESS_LOCALHOST)
91
    result = backend.VerifyNode({constants.NV_MASTERIP: local_data}, None, {})
92
    self.failUnless(constants.NV_MASTERIP in result,
93
                    "Master IP data not returned")
94
    self.failUnless(result[constants.NV_MASTERIP], "Cannot reach localhost")
95

    
96
  def testMasterIPUnreachable(self):
97
    # Network 192.0.2.0/24 is reserved for test/documentation as per
98
    # RFC 5737
99
    bad_data =  ("master.example.com", "192.0.2.1")
100
    # we just test that whatever TcpPing returns, VerifyNode returns too
101
    netutils.TcpPing = lambda a, b, source=None: False
102
    result = backend.VerifyNode({constants.NV_MASTERIP: bad_data}, None, {})
103
    self.failUnless(constants.NV_MASTERIP in result,
104
                    "Master IP data not returned")
105
    self.failIf(result[constants.NV_MASTERIP],
106
                "Result from netutils.TcpPing corrupted")
107

    
108
  def testVerifyHvparams(self):
109
    test_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
110
    test_what = {constants.NV_HVPARAMS: \
111
        [("mynode", constants.HT_XEN_PVM, test_hvparams)]}
112
    result = {}
113
    backend._VerifyHvparams(test_what, True, result,
114
                            get_hv_fn=self._GetHypervisor)
115
    self._mock_hv.ValidateParameters.assert_called_with(test_hvparams)
116

    
117
  def testVerifyHypervisors(self):
118
    hvname = constants.HT_XEN_PVM
119
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
120
    all_hvparams = {hvname: hvparams}
121
    test_what = {constants.NV_HYPERVISOR: [hvname]}
122
    result = {}
123
    backend._VerifyHypervisors(
124
        test_what, True, result, all_hvparams=all_hvparams,
125
        get_hv_fn=self._GetHypervisor)
126
    self._mock_hv.Verify.assert_called_with(hvparams=hvparams)
127

    
128

    
129
def _DefRestrictedCmdOwner():
130
  return (os.getuid(), os.getgid())
131

    
132

    
133
class TestVerifyRestrictedCmdName(unittest.TestCase):
134
  def testAcceptableName(self):
135
    for i in ["foo", "bar", "z1", "000first", "hello-world"]:
136
      for fn in [lambda s: s, lambda s: s.upper(), lambda s: s.title()]:
137
        (status, msg) = backend._VerifyRestrictedCmdName(fn(i))
138
        self.assertTrue(status)
139
        self.assertTrue(msg is None)
140

    
141
  def testEmptyAndSpace(self):
142
    for i in ["", " ", "\t", "\n"]:
143
      (status, msg) = backend._VerifyRestrictedCmdName(i)
144
      self.assertFalse(status)
145
      self.assertEqual(msg, "Missing command name")
146

    
147
  def testNameWithSlashes(self):
148
    for i in ["/", "./foo", "../moo", "some/name"]:
149
      (status, msg) = backend._VerifyRestrictedCmdName(i)
150
      self.assertFalse(status)
151
      self.assertEqual(msg, "Invalid command name")
152

    
153
  def testForbiddenCharacters(self):
154
    for i in ["#", ".", "..", "bash -c ls", "'"]:
155
      (status, msg) = backend._VerifyRestrictedCmdName(i)
156
      self.assertFalse(status)
157
      self.assertEqual(msg, "Command name contains forbidden characters")
158

    
159

    
160
class TestVerifyRestrictedCmdDirectory(unittest.TestCase):
161
  def setUp(self):
162
    self.tmpdir = tempfile.mkdtemp()
163

    
164
  def tearDown(self):
165
    shutil.rmtree(self.tmpdir)
166

    
167
  def testCanNotStat(self):
168
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
169
    self.assertFalse(os.path.exists(tmpname))
170
    (status, msg) = \
171
      backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
172
    self.assertFalse(status)
173
    self.assertTrue(msg.startswith("Can't stat(2) '"))
174

    
175
  def testTooPermissive(self):
176
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
177
    os.mkdir(tmpname)
178

    
179
    for mode in [0777, 0706, 0760, 0722]:
180
      os.chmod(tmpname, mode)
181
      self.assertTrue(os.path.isdir(tmpname))
182
      (status, msg) = \
183
        backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
184
      self.assertFalse(status)
185
      self.assertTrue(msg.startswith("Permissions on '"))
186

    
187
  def testNoDirectory(self):
188
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
189
    utils.WriteFile(tmpname, data="empty\n")
190
    self.assertTrue(os.path.isfile(tmpname))
191
    (status, msg) = \
192
      backend._VerifyRestrictedCmdDirectory(tmpname,
193
                                            _owner=_DefRestrictedCmdOwner())
194
    self.assertFalse(status)
195
    self.assertTrue(msg.endswith("is not a directory"))
196

    
197
  def testNormal(self):
198
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
199
    os.mkdir(tmpname)
200
    self.assertTrue(os.path.isdir(tmpname))
201
    (status, msg) = \
202
      backend._VerifyRestrictedCmdDirectory(tmpname,
203
                                            _owner=_DefRestrictedCmdOwner())
204
    self.assertTrue(status)
205
    self.assertTrue(msg is None)
206

    
207

    
208
class TestVerifyRestrictedCmd(unittest.TestCase):
209
  def setUp(self):
210
    self.tmpdir = tempfile.mkdtemp()
211

    
212
  def tearDown(self):
213
    shutil.rmtree(self.tmpdir)
214

    
215
  def testCanNotStat(self):
216
    tmpname = utils.PathJoin(self.tmpdir, "helloworld")
217
    self.assertFalse(os.path.exists(tmpname))
218
    (status, msg) = \
219
      backend._VerifyRestrictedCmd(self.tmpdir, "helloworld",
220
                                   _owner=NotImplemented)
221
    self.assertFalse(status)
222
    self.assertTrue(msg.startswith("Can't stat(2) '"))
223

    
224
  def testNotExecutable(self):
225
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
226
    utils.WriteFile(tmpname, data="empty\n")
227
    (status, msg) = \
228
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
229
                                   _owner=_DefRestrictedCmdOwner())
230
    self.assertFalse(status)
231
    self.assertTrue(msg.startswith("access(2) thinks '"))
232

    
233
  def testExecutable(self):
234
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
235
    utils.WriteFile(tmpname, data="empty\n", mode=0700)
236
    (status, executable) = \
237
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
238
                                   _owner=_DefRestrictedCmdOwner())
239
    self.assertTrue(status)
240
    self.assertEqual(executable, tmpname)
241

    
242

    
243
class TestPrepareRestrictedCmd(unittest.TestCase):
244
  _TEST_PATH = "/tmp/some/test/path"
245

    
246
  def testDirFails(self):
247
    def fn(path):
248
      self.assertEqual(path, self._TEST_PATH)
249
      return (False, "test error 31420")
250

    
251
    (status, msg) = \
252
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd21152",
253
                                    _verify_dir=fn,
254
                                    _verify_name=NotImplemented,
255
                                    _verify_cmd=NotImplemented)
256
    self.assertFalse(status)
257
    self.assertEqual(msg, "test error 31420")
258

    
259
  def testNameFails(self):
260
    def fn(cmd):
261
      self.assertEqual(cmd, "cmd4617")
262
      return (False, "test error 591")
263

    
264
    (status, msg) = \
265
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd4617",
266
                                    _verify_dir=lambda _: (True, None),
267
                                    _verify_name=fn,
268
                                    _verify_cmd=NotImplemented)
269
    self.assertFalse(status)
270
    self.assertEqual(msg, "test error 591")
271

    
272
  def testCommandFails(self):
273
    def fn(path, cmd):
274
      self.assertEqual(path, self._TEST_PATH)
275
      self.assertEqual(cmd, "cmd17577")
276
      return (False, "test error 25524")
277

    
278
    (status, msg) = \
279
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd17577",
280
                                    _verify_dir=lambda _: (True, None),
281
                                    _verify_name=lambda _: (True, None),
282
                                    _verify_cmd=fn)
283
    self.assertFalse(status)
284
    self.assertEqual(msg, "test error 25524")
285

    
286
  def testSuccess(self):
287
    def fn(path, cmd):
288
      return (True, utils.PathJoin(path, cmd))
289

    
290
    (status, executable) = \
291
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd22633",
292
                                    _verify_dir=lambda _: (True, None),
293
                                    _verify_name=lambda _: (True, None),
294
                                    _verify_cmd=fn)
295
    self.assertTrue(status)
296
    self.assertEqual(executable, utils.PathJoin(self._TEST_PATH, "cmd22633"))
297

    
298

    
299
def _SleepForRestrictedCmd(duration):
300
  assert duration > 5
301

    
302

    
303
def _GenericRestrictedCmdError(cmd):
304
  return "Executing command '%s' failed" % cmd
305

    
306

    
307
class TestRunRestrictedCmd(unittest.TestCase):
308
  def setUp(self):
309
    self.tmpdir = tempfile.mkdtemp()
310

    
311
  def tearDown(self):
312
    shutil.rmtree(self.tmpdir)
313

    
314
  def testNonExistantLockDirectory(self):
315
    lockfile = utils.PathJoin(self.tmpdir, "does", "not", "exist")
316
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
317
    self.assertFalse(os.path.exists(lockfile))
318
    self.assertRaises(backend.RPCFail,
319
                      backend.RunRestrictedCmd, "test",
320
                      _lock_timeout=NotImplemented,
321
                      _lock_file=lockfile,
322
                      _path=NotImplemented,
323
                      _sleep_fn=sleep_fn,
324
                      _prepare_fn=NotImplemented,
325
                      _runcmd_fn=NotImplemented,
326
                      _enabled=True)
327
    self.assertEqual(sleep_fn.Count(), 1)
328

    
329
  @staticmethod
330
  def _TryLock(lockfile):
331
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
332

    
333
    result = False
334
    try:
335
      backend.RunRestrictedCmd("test22717",
336
                               _lock_timeout=0.1,
337
                               _lock_file=lockfile,
338
                               _path=NotImplemented,
339
                               _sleep_fn=sleep_fn,
340
                               _prepare_fn=NotImplemented,
341
                               _runcmd_fn=NotImplemented,
342
                               _enabled=True)
343
    except backend.RPCFail, err:
344
      assert str(err) == _GenericRestrictedCmdError("test22717"), \
345
             "Did not fail with generic error message"
346
      result = True
347

    
348
    assert sleep_fn.Count() == 1
349

    
350
    return result
351

    
352
  def testLockHeldByOtherProcess(self):
353
    lockfile = utils.PathJoin(self.tmpdir, "lock")
354

    
355
    lock = utils.FileLock.Open(lockfile)
356
    lock.Exclusive(blocking=True, timeout=1.0)
357
    try:
358
      self.assertTrue(utils.RunInSeparateProcess(self._TryLock, lockfile))
359
    finally:
360
      lock.Close()
361

    
362
  @staticmethod
363
  def _PrepareRaisingException(path, cmd):
364
    assert cmd == "test23122"
365
    raise Exception("test")
366

    
367
  def testPrepareRaisesException(self):
368
    lockfile = utils.PathJoin(self.tmpdir, "lock")
369

    
370
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
371
    prepare_fn = testutils.CallCounter(self._PrepareRaisingException)
372

    
373
    try:
374
      backend.RunRestrictedCmd("test23122",
375
                               _lock_timeout=1.0, _lock_file=lockfile,
376
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
377
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
378
                               _enabled=True)
379
    except backend.RPCFail, err:
380
      self.assertEqual(str(err), _GenericRestrictedCmdError("test23122"))
381
    else:
382
      self.fail("Didn't fail")
383

    
384
    self.assertEqual(sleep_fn.Count(), 1)
385
    self.assertEqual(prepare_fn.Count(), 1)
386

    
387
  @staticmethod
388
  def _PrepareFails(path, cmd):
389
    assert cmd == "test29327"
390
    return ("some error message", None)
391

    
392
  def testPrepareFails(self):
393
    lockfile = utils.PathJoin(self.tmpdir, "lock")
394

    
395
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
396
    prepare_fn = testutils.CallCounter(self._PrepareFails)
397

    
398
    try:
399
      backend.RunRestrictedCmd("test29327",
400
                               _lock_timeout=1.0, _lock_file=lockfile,
401
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
402
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
403
                               _enabled=True)
404
    except backend.RPCFail, err:
405
      self.assertEqual(str(err), _GenericRestrictedCmdError("test29327"))
406
    else:
407
      self.fail("Didn't fail")
408

    
409
    self.assertEqual(sleep_fn.Count(), 1)
410
    self.assertEqual(prepare_fn.Count(), 1)
411

    
412
  @staticmethod
413
  def _SuccessfulPrepare(path, cmd):
414
    return (True, utils.PathJoin(path, cmd))
415

    
416
  def testRunCmdFails(self):
417
    lockfile = utils.PathJoin(self.tmpdir, "lock")
418

    
419
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
420
           postfork_fn=NotImplemented):
421
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test3079")])
422
      self.assertEqual(env, {})
423
      self.assertTrue(reset_env)
424
      self.assertTrue(callable(postfork_fn))
425

    
426
      trylock = utils.FileLock.Open(lockfile)
427
      try:
428
        # See if lockfile is still held
429
        self.assertRaises(EnvironmentError, trylock.Exclusive, blocking=False)
430

    
431
        # Call back to release lock
432
        postfork_fn(NotImplemented)
433

    
434
        # See if lockfile can be acquired
435
        trylock.Exclusive(blocking=False)
436
      finally:
437
        trylock.Close()
438

    
439
      # Simulate a failed command
440
      return utils.RunResult(constants.EXIT_FAILURE, None,
441
                             "stdout", "stderr406328567",
442
                             utils.ShellQuoteArgs(args),
443
                             NotImplemented, NotImplemented)
444

    
445
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
446
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
447
    runcmd_fn = testutils.CallCounter(fn)
448

    
449
    try:
450
      backend.RunRestrictedCmd("test3079",
451
                               _lock_timeout=1.0, _lock_file=lockfile,
452
                               _path=self.tmpdir, _runcmd_fn=runcmd_fn,
453
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
454
                               _enabled=True)
455
    except backend.RPCFail, err:
456
      self.assertTrue(str(err).startswith("Restricted command 'test3079'"
457
                                          " failed:"))
458
      self.assertTrue("stderr406328567" in str(err),
459
                      msg="Error did not include output")
460
    else:
461
      self.fail("Didn't fail")
462

    
463
    self.assertEqual(sleep_fn.Count(), 0)
464
    self.assertEqual(prepare_fn.Count(), 1)
465
    self.assertEqual(runcmd_fn.Count(), 1)
466

    
467
  def testRunCmdSucceeds(self):
468
    lockfile = utils.PathJoin(self.tmpdir, "lock")
469

    
470
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
471
           postfork_fn=NotImplemented):
472
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test5667")])
473
      self.assertEqual(env, {})
474
      self.assertTrue(reset_env)
475

    
476
      # Call back to release lock
477
      postfork_fn(NotImplemented)
478

    
479
      # Simulate a successful command
480
      return utils.RunResult(constants.EXIT_SUCCESS, None, "stdout14463", "",
481
                             utils.ShellQuoteArgs(args),
482
                             NotImplemented, NotImplemented)
483

    
484
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
485
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
486
    runcmd_fn = testutils.CallCounter(fn)
487

    
488
    result = backend.RunRestrictedCmd("test5667",
489
                                      _lock_timeout=1.0, _lock_file=lockfile,
490
                                      _path=self.tmpdir, _runcmd_fn=runcmd_fn,
491
                                      _sleep_fn=sleep_fn,
492
                                      _prepare_fn=prepare_fn,
493
                                      _enabled=True)
494
    self.assertEqual(result, "stdout14463")
495

    
496
    self.assertEqual(sleep_fn.Count(), 0)
497
    self.assertEqual(prepare_fn.Count(), 1)
498
    self.assertEqual(runcmd_fn.Count(), 1)
499

    
500
  def testCommandsDisabled(self):
501
    try:
502
      backend.RunRestrictedCmd("test",
503
                               _lock_timeout=NotImplemented,
504
                               _lock_file=NotImplemented,
505
                               _path=NotImplemented,
506
                               _sleep_fn=NotImplemented,
507
                               _prepare_fn=NotImplemented,
508
                               _runcmd_fn=NotImplemented,
509
                               _enabled=False)
510
    except backend.RPCFail, err:
511
      self.assertEqual(str(err),
512
                       "Restricted commands disabled at configure time")
513
    else:
514
      self.fail("Did not raise exception")
515

    
516

    
517
class TestSetWatcherPause(unittest.TestCase):
518
  def setUp(self):
519
    self.tmpdir = tempfile.mkdtemp()
520
    self.filename = utils.PathJoin(self.tmpdir, "pause")
521

    
522
  def tearDown(self):
523
    shutil.rmtree(self.tmpdir)
524

    
525
  def testUnsetNonExisting(self):
526
    self.assertFalse(os.path.exists(self.filename))
527
    backend.SetWatcherPause(None, _filename=self.filename)
528
    self.assertFalse(os.path.exists(self.filename))
529

    
530
  def testSetNonNumeric(self):
531
    for i in ["", [], {}, "Hello World", "0", "1.0"]:
532
      self.assertFalse(os.path.exists(self.filename))
533

    
534
      try:
535
        backend.SetWatcherPause(i, _filename=self.filename)
536
      except backend.RPCFail, err:
537
        self.assertEqual(str(err), "Duration must be numeric")
538
      else:
539
        self.fail("Did not raise exception")
540

    
541
      self.assertFalse(os.path.exists(self.filename))
542

    
543
  def testSet(self):
544
    self.assertFalse(os.path.exists(self.filename))
545

    
546
    for i in range(10):
547
      backend.SetWatcherPause(i, _filename=self.filename)
548
      self.assertEqual(utils.ReadFile(self.filename), "%s\n" % i)
549
      self.assertEqual(os.stat(self.filename).st_mode & 0777, 0644)
550

    
551

    
552
class TestGetBlockDevSymlinkPath(unittest.TestCase):
553
  def setUp(self):
554
    self.tmpdir = tempfile.mkdtemp()
555

    
556
  def tearDown(self):
557
    shutil.rmtree(self.tmpdir)
558

    
559
  def _Test(self, name, idx):
560
    self.assertEqual(backend._GetBlockDevSymlinkPath(name, idx,
561
                                                     _dir=self.tmpdir),
562
                     ("%s/%s%s%s" % (self.tmpdir, name,
563
                                     constants.DISK_SEPARATOR, idx)))
564

    
565
  def test(self):
566
    for idx in range(100):
567
      self._Test("inst1.example.com", idx)
568

    
569

    
570
class TestGetInstanceList(unittest.TestCase):
571

    
572
  def setUp(self):
573
    self._test_hv = self._TestHypervisor()
574
    self._test_hv.ListInstances = mock.Mock(
575
      return_value=["instance1", "instance2", "instance3"] )
576

    
577
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
578
    def __init__(self):
579
      hypervisor.hv_base.BaseHypervisor.__init__(self)
580

    
581
  def _GetHypervisor(self, name):
582
    return self._test_hv
583

    
584
  def testHvparams(self):
585
    fake_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
586
    hvparams = {constants.HT_FAKE: fake_hvparams}
587
    backend.GetInstanceList([constants.HT_FAKE], all_hvparams=hvparams,
588
                            get_hv_fn=self._GetHypervisor)
589
    self._test_hv.ListInstances.assert_called_with(hvparams=fake_hvparams)
590

    
591

    
592
class TestGetHvInfo(unittest.TestCase):
593

    
594
  def setUp(self):
595
    self._test_hv = self._TestHypervisor()
596
    self._test_hv.GetNodeInfo = mock.Mock()
597

    
598
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
599
    def __init__(self):
600
      hypervisor.hv_base.BaseHypervisor.__init__(self)
601

    
602
  def _GetHypervisor(self, name):
603
    return self._test_hv
604

    
605
  def testGetHvInfoAllNone(self):
606
    result = backend._GetHvInfoAll(None)
607
    self.assertTrue(result is None)
608

    
609
  def testGetHvInfoAll(self):
610
    hvname = constants.HT_XEN_PVM
611
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
612
    hv_specs = [(hvname, hvparams)]
613

    
614
    backend._GetHvInfoAll(hv_specs, self._GetHypervisor)
615
    self._test_hv.GetNodeInfo.assert_called_with(hvparams=hvparams)
616

    
617

    
618
class TestApplyStorageInfoFunction(unittest.TestCase):
619

    
620
  _STORAGE_KEY = "some_key"
621
  _SOME_ARGS = ["some_args"]
622

    
623
  def setUp(self):
624
    self.mock_storage_fn = mock.Mock()
625

    
626
  def testApplyValidStorageType(self):
627
    storage_type = constants.ST_LVM_VG
628
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
629
    backend._STORAGE_TYPE_INFO_FN = {
630
        storage_type: self.mock_storage_fn
631
      }
632

    
633
    backend._ApplyStorageInfoFunction(
634
        storage_type, self._STORAGE_KEY, self._SOME_ARGS)
635

    
636
    self.mock_storage_fn.assert_called_with(self._STORAGE_KEY, self._SOME_ARGS)
637
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
638

    
639
  def testApplyInValidStorageType(self):
640
    storage_type = "invalid_storage_type"
641
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
642
    backend._STORAGE_TYPE_INFO_FN = {}
643

    
644
    self.assertRaises(KeyError, backend._ApplyStorageInfoFunction,
645
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
646
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
647

    
648
  def testApplyNotImplementedStorageType(self):
649
    storage_type = "not_implemented_storage_type"
650
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
651
    backend._STORAGE_TYPE_INFO_FN = {storage_type: None}
652

    
653
    self.assertRaises(NotImplementedError,
654
                      backend._ApplyStorageInfoFunction,
655
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
656
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
657

    
658

    
659
class TestGetLvmVgSpaceInfo(unittest.TestCase):
660

    
661
  def testValid(self):
662
    path = "somepath"
663
    excl_stor = True
664
    orig_fn = backend._GetVgInfo
665
    backend._GetVgInfo = mock.Mock()
666
    backend._GetLvmVgSpaceInfo(path, [excl_stor])
667
    backend._GetVgInfo.assert_called_with(path, excl_stor)
668
    backend._GetVgInfo = orig_fn
669

    
670
  def testNoExclStorageNotBool(self):
671
    path = "somepath"
672
    excl_stor = "123"
673
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
674
                      path, [excl_stor])
675

    
676
  def testNoExclStorageNotInList(self):
677
    path = "somepath"
678
    excl_stor = "123"
679
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
680
                      path, excl_stor)
681

    
682
class TestGetLvmPvSpaceInfo(unittest.TestCase):
683

    
684
  def testValid(self):
685
    path = "somepath"
686
    excl_stor = True
687
    orig_fn = backend._GetVgSpindlesInfo
688
    backend._GetVgSpindlesInfo = mock.Mock()
689
    backend._GetLvmPvSpaceInfo(path, [excl_stor])
690
    backend._GetVgSpindlesInfo.assert_called_with(path, excl_stor)
691
    backend._GetVgSpindlesInfo = orig_fn
692

    
693

    
694
class TestCheckStorageParams(unittest.TestCase):
695

    
696
  def testParamsNone(self):
697
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
698
                      None, NotImplemented)
699

    
700
  def testParamsWrongType(self):
701
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
702
                      "string", NotImplemented)
703

    
704
  def testParamsEmpty(self):
705
    backend._CheckStorageParams([], 0)
706

    
707
  def testParamsValidNumber(self):
708
    backend._CheckStorageParams(["a", True], 2)
709

    
710
  def testParamsInvalidNumber(self):
711
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
712
                      ["b", False], 3)
713

    
714

    
715
class TestGetVgSpindlesInfo(unittest.TestCase):
716

    
717
  def setUp(self):
718
    self.vg_free = 13
719
    self.vg_size = 31
720
    self.mock_fn = mock.Mock(return_value=(self.vg_free, self.vg_size))
721

    
722
  def testValidInput(self):
723
    name = "myvg"
724
    excl_stor = True
725
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
726
    self.mock_fn.assert_called_with(name)
727
    self.assertEqual(name, result["name"])
728
    self.assertEqual(constants.ST_LVM_PV, result["type"])
729
    self.assertEqual(self.vg_free, result["storage_free"])
730
    self.assertEqual(self.vg_size, result["storage_size"])
731

    
732
  def testNoExclStor(self):
733
    name = "myvg"
734
    excl_stor = False
735
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
736
    self.mock_fn.assert_not_called()
737
    self.assertEqual(name, result["name"])
738
    self.assertEqual(constants.ST_LVM_PV, result["type"])
739
    self.assertEqual(0, result["storage_free"])
740
    self.assertEqual(0, result["storage_size"])
741

    
742

    
743
class TestGetVgSpindlesInfo(unittest.TestCase):
744

    
745
  def testValidInput(self):
746
    self.vg_free = 13
747
    self.vg_size = 31
748
    self.mock_fn = mock.Mock(return_value=[(self.vg_free, self.vg_size)])
749
    name = "myvg"
750
    excl_stor = True
751
    result = backend._GetVgInfo(name, excl_stor, info_fn=self.mock_fn)
752
    self.mock_fn.assert_called_with([name], excl_stor)
753
    self.assertEqual(name, result["name"])
754
    self.assertEqual(constants.ST_LVM_VG, result["type"])
755
    self.assertEqual(self.vg_free, result["storage_free"])
756
    self.assertEqual(self.vg_size, result["storage_size"])
757

    
758
  def testNoExclStor(self):
759
    name = "myvg"
760
    excl_stor = True
761
    self.mock_fn = mock.Mock(return_value=None)
762
    result = backend._GetVgInfo(name, excl_stor, info_fn=self.mock_fn)
763
    self.mock_fn.assert_called_with([name], excl_stor)
764
    self.assertEqual(name, result["name"])
765
    self.assertEqual(constants.ST_LVM_VG, result["type"])
766
    self.assertEqual(None, result["storage_free"])
767
    self.assertEqual(None, result["storage_size"])
768

    
769

    
770
class TestGetNodeInfo(unittest.TestCase):
771

    
772
  _SOME_RESULT = None
773

    
774
  def testApplyStorageInfoFunction(self):
775
    orig_fn = backend._ApplyStorageInfoFunction
776
    backend._ApplyStorageInfoFunction = mock.Mock(
777
        return_value=self._SOME_RESULT)
778
    storage_units = [(st, st + "_key", [st + "_params"]) for st in
779
                     constants.STORAGE_TYPES]
780

    
781
    backend.GetNodeInfo(storage_units, None)
782

    
783
    call_args_list = backend._ApplyStorageInfoFunction.call_args_list
784
    self.assertEqual(len(constants.STORAGE_TYPES), len(call_args_list))
785
    for call in call_args_list:
786
      storage_type, storage_key, storage_params = call[0]
787
      self.assertEqual(storage_type + "_key", storage_key)
788
      self.assertEqual([storage_type + "_params"], storage_params)
789
      self.assertTrue(storage_type in constants.STORAGE_TYPES)
790
    backend._ApplyStorageInfoFunction = orig_fn
791

    
792

    
793
class TestSpaceReportingConstants(unittest.TestCase):
794
  """Ensures consistency between STS_REPORT and backend.
795

796
  These tests ensure, that the constant 'STS_REPORT' is consitent
797
  with the implementation of invoking space reporting functions
798
  in backend.py. Once space reporting is available for all types,
799
  the constant can be removed and these tests as well.
800

801
  """
802
  def testAllReportingTypesHaveAReportingFunction(self):
803
    for storage_type in constants.STS_REPORT:
804
      self.assertTrue(backend._STORAGE_TYPE_INFO_FN[storage_type] is not None)
805

    
806
  def testAllNotReportingTypesDoneHaveFunction(self):
807
    non_reporting_types = set(constants.VALID_STORAGE_TYPES)\
808
        - set(constants.STS_REPORT)
809
    for storage_type in non_reporting_types:
810
      self.assertEqual(None, backend._STORAGE_TYPE_INFO_FN[storage_type])
811

    
812

    
813
if __name__ == "__main__":
814
  testutils.GanetiTestProgram()