Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.backend_unittest.py @ f39a8d14

History | View | Annotate | Download (23.2 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.backend"""
23

    
24
import os
25
import sys
26
import shutil
27
import tempfile
28
import unittest
29
import mock
30

    
31
from ganeti import utils
32
from ganeti import constants
33
from ganeti import backend
34
from ganeti import netutils
35
from ganeti import errors
36
from ganeti import serializer
37
from ganeti import hypervisor
38

    
39
import testutils
40
import mocks
41

    
42

    
43
class TestX509Certificates(unittest.TestCase):
44
  def setUp(self):
45
    self.tmpdir = tempfile.mkdtemp()
46

    
47
  def tearDown(self):
48
    shutil.rmtree(self.tmpdir)
49

    
50
  def test(self):
51
    (name, cert_pem) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
52

    
53
    self.assertEqual(utils.ReadFile(os.path.join(self.tmpdir, name,
54
                                                 backend._X509_CERT_FILE)),
55
                     cert_pem)
56
    self.assert_(0 < os.path.getsize(os.path.join(self.tmpdir, name,
57
                                                  backend._X509_KEY_FILE)))
58

    
59
    (name2, cert_pem2) = \
60
      backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
61

    
62
    backend.RemoveX509Certificate(name, cryptodir=self.tmpdir)
63
    backend.RemoveX509Certificate(name2, cryptodir=self.tmpdir)
64

    
65
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [])
66

    
67
  def testNonEmpty(self):
68
    (name, _) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
69

    
70
    utils.WriteFile(utils.PathJoin(self.tmpdir, name, "hello-world"),
71
                    data="Hello World")
72

    
73
    self.assertRaises(backend.RPCFail, backend.RemoveX509Certificate,
74
                      name, cryptodir=self.tmpdir)
75

    
76
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [name])
77

    
78

    
79
class TestNodeVerify(testutils.GanetiTestCase):
80

    
81
  def setUp(self):
82
    testutils.GanetiTestCase.setUp(self)
83
    self._mock_hv = None
84

    
85
  def _GetHypervisor(self, hv_name):
86
    self._mock_hv = hypervisor.GetHypervisor(hv_name)
87
    self._mock_hv.ValidateParameters = mock.Mock()
88
    self._mock_hv.Verify = mock.Mock()
89
    return self._mock_hv
90

    
91
  def testMasterIPLocalhost(self):
92
    # this a real functional test, but requires localhost to be reachable
93
    local_data = (netutils.Hostname.GetSysName(),
94
                  constants.IP4_ADDRESS_LOCALHOST)
95
    result = backend.VerifyNode({constants.NV_MASTERIP: local_data}, None, {})
96
    self.failUnless(constants.NV_MASTERIP in result,
97
                    "Master IP data not returned")
98
    self.failUnless(result[constants.NV_MASTERIP], "Cannot reach localhost")
99

    
100
  def testMasterIPUnreachable(self):
101
    # Network 192.0.2.0/24 is reserved for test/documentation as per
102
    # RFC 5737
103
    bad_data =  ("master.example.com", "192.0.2.1")
104
    # we just test that whatever TcpPing returns, VerifyNode returns too
105
    netutils.TcpPing = lambda a, b, source=None: False
106
    result = backend.VerifyNode({constants.NV_MASTERIP: bad_data}, None, {})
107
    self.failUnless(constants.NV_MASTERIP in result,
108
                    "Master IP data not returned")
109
    self.failIf(result[constants.NV_MASTERIP],
110
                "Result from netutils.TcpPing corrupted")
111

    
112
  def testVerifyHvparams(self):
113
    test_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
114
    test_what = {constants.NV_HVPARAMS: \
115
        [("mynode", constants.HT_XEN_PVM, test_hvparams)]}
116
    result = {}
117
    backend._VerifyHvparams(test_what, True, result,
118
                            get_hv_fn=self._GetHypervisor)
119
    self._mock_hv.ValidateParameters.assert_called_with(test_hvparams)
120

    
121
  def testVerifyHypervisors(self):
122
    hvname = constants.HT_XEN_PVM
123
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
124
    all_hvparams = {hvname: hvparams}
125
    test_what = {constants.NV_HYPERVISOR: [hvname]}
126
    result = {}
127
    backend._VerifyHypervisors(
128
        test_what, True, result, all_hvparams=all_hvparams,
129
        get_hv_fn=self._GetHypervisor)
130
    self._mock_hv.Verify.assert_called_with(hvparams=hvparams)
131

    
132

    
133
def _DefRestrictedCmdOwner():
134
  return (os.getuid(), os.getgid())
135

    
136

    
137
class TestVerifyRestrictedCmdName(unittest.TestCase):
138
  def testAcceptableName(self):
139
    for i in ["foo", "bar", "z1", "000first", "hello-world"]:
140
      for fn in [lambda s: s, lambda s: s.upper(), lambda s: s.title()]:
141
        (status, msg) = backend._VerifyRestrictedCmdName(fn(i))
142
        self.assertTrue(status)
143
        self.assertTrue(msg is None)
144

    
145
  def testEmptyAndSpace(self):
146
    for i in ["", " ", "\t", "\n"]:
147
      (status, msg) = backend._VerifyRestrictedCmdName(i)
148
      self.assertFalse(status)
149
      self.assertEqual(msg, "Missing command name")
150

    
151
  def testNameWithSlashes(self):
152
    for i in ["/", "./foo", "../moo", "some/name"]:
153
      (status, msg) = backend._VerifyRestrictedCmdName(i)
154
      self.assertFalse(status)
155
      self.assertEqual(msg, "Invalid command name")
156

    
157
  def testForbiddenCharacters(self):
158
    for i in ["#", ".", "..", "bash -c ls", "'"]:
159
      (status, msg) = backend._VerifyRestrictedCmdName(i)
160
      self.assertFalse(status)
161
      self.assertEqual(msg, "Command name contains forbidden characters")
162

    
163

    
164
class TestVerifyRestrictedCmdDirectory(unittest.TestCase):
165
  def setUp(self):
166
    self.tmpdir = tempfile.mkdtemp()
167

    
168
  def tearDown(self):
169
    shutil.rmtree(self.tmpdir)
170

    
171
  def testCanNotStat(self):
172
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
173
    self.assertFalse(os.path.exists(tmpname))
174
    (status, msg) = \
175
      backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
176
    self.assertFalse(status)
177
    self.assertTrue(msg.startswith("Can't stat(2) '"))
178

    
179
  def testTooPermissive(self):
180
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
181
    os.mkdir(tmpname)
182

    
183
    for mode in [0777, 0706, 0760, 0722]:
184
      os.chmod(tmpname, mode)
185
      self.assertTrue(os.path.isdir(tmpname))
186
      (status, msg) = \
187
        backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
188
      self.assertFalse(status)
189
      self.assertTrue(msg.startswith("Permissions on '"))
190

    
191
  def testNoDirectory(self):
192
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
193
    utils.WriteFile(tmpname, data="empty\n")
194
    self.assertTrue(os.path.isfile(tmpname))
195
    (status, msg) = \
196
      backend._VerifyRestrictedCmdDirectory(tmpname,
197
                                            _owner=_DefRestrictedCmdOwner())
198
    self.assertFalse(status)
199
    self.assertTrue(msg.endswith("is not a directory"))
200

    
201
  def testNormal(self):
202
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
203
    os.mkdir(tmpname)
204
    self.assertTrue(os.path.isdir(tmpname))
205
    (status, msg) = \
206
      backend._VerifyRestrictedCmdDirectory(tmpname,
207
                                            _owner=_DefRestrictedCmdOwner())
208
    self.assertTrue(status)
209
    self.assertTrue(msg is None)
210

    
211

    
212
class TestVerifyRestrictedCmd(unittest.TestCase):
213
  def setUp(self):
214
    self.tmpdir = tempfile.mkdtemp()
215

    
216
  def tearDown(self):
217
    shutil.rmtree(self.tmpdir)
218

    
219
  def testCanNotStat(self):
220
    tmpname = utils.PathJoin(self.tmpdir, "helloworld")
221
    self.assertFalse(os.path.exists(tmpname))
222
    (status, msg) = \
223
      backend._VerifyRestrictedCmd(self.tmpdir, "helloworld",
224
                                   _owner=NotImplemented)
225
    self.assertFalse(status)
226
    self.assertTrue(msg.startswith("Can't stat(2) '"))
227

    
228
  def testNotExecutable(self):
229
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
230
    utils.WriteFile(tmpname, data="empty\n")
231
    (status, msg) = \
232
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
233
                                   _owner=_DefRestrictedCmdOwner())
234
    self.assertFalse(status)
235
    self.assertTrue(msg.startswith("access(2) thinks '"))
236

    
237
  def testExecutable(self):
238
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
239
    utils.WriteFile(tmpname, data="empty\n", mode=0700)
240
    (status, executable) = \
241
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
242
                                   _owner=_DefRestrictedCmdOwner())
243
    self.assertTrue(status)
244
    self.assertEqual(executable, tmpname)
245

    
246

    
247
class TestPrepareRestrictedCmd(unittest.TestCase):
248
  _TEST_PATH = "/tmp/some/test/path"
249

    
250
  def testDirFails(self):
251
    def fn(path):
252
      self.assertEqual(path, self._TEST_PATH)
253
      return (False, "test error 31420")
254

    
255
    (status, msg) = \
256
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd21152",
257
                                    _verify_dir=fn,
258
                                    _verify_name=NotImplemented,
259
                                    _verify_cmd=NotImplemented)
260
    self.assertFalse(status)
261
    self.assertEqual(msg, "test error 31420")
262

    
263
  def testNameFails(self):
264
    def fn(cmd):
265
      self.assertEqual(cmd, "cmd4617")
266
      return (False, "test error 591")
267

    
268
    (status, msg) = \
269
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd4617",
270
                                    _verify_dir=lambda _: (True, None),
271
                                    _verify_name=fn,
272
                                    _verify_cmd=NotImplemented)
273
    self.assertFalse(status)
274
    self.assertEqual(msg, "test error 591")
275

    
276
  def testCommandFails(self):
277
    def fn(path, cmd):
278
      self.assertEqual(path, self._TEST_PATH)
279
      self.assertEqual(cmd, "cmd17577")
280
      return (False, "test error 25524")
281

    
282
    (status, msg) = \
283
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd17577",
284
                                    _verify_dir=lambda _: (True, None),
285
                                    _verify_name=lambda _: (True, None),
286
                                    _verify_cmd=fn)
287
    self.assertFalse(status)
288
    self.assertEqual(msg, "test error 25524")
289

    
290
  def testSuccess(self):
291
    def fn(path, cmd):
292
      return (True, utils.PathJoin(path, cmd))
293

    
294
    (status, executable) = \
295
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd22633",
296
                                    _verify_dir=lambda _: (True, None),
297
                                    _verify_name=lambda _: (True, None),
298
                                    _verify_cmd=fn)
299
    self.assertTrue(status)
300
    self.assertEqual(executable, utils.PathJoin(self._TEST_PATH, "cmd22633"))
301

    
302

    
303
def _SleepForRestrictedCmd(duration):
304
  assert duration > 5
305

    
306

    
307
def _GenericRestrictedCmdError(cmd):
308
  return "Executing command '%s' failed" % cmd
309

    
310

    
311
class TestRunRestrictedCmd(unittest.TestCase):
312
  def setUp(self):
313
    self.tmpdir = tempfile.mkdtemp()
314

    
315
  def tearDown(self):
316
    shutil.rmtree(self.tmpdir)
317

    
318
  def testNonExistantLockDirectory(self):
319
    lockfile = utils.PathJoin(self.tmpdir, "does", "not", "exist")
320
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
321
    self.assertFalse(os.path.exists(lockfile))
322
    self.assertRaises(backend.RPCFail,
323
                      backend.RunRestrictedCmd, "test",
324
                      _lock_timeout=NotImplemented,
325
                      _lock_file=lockfile,
326
                      _path=NotImplemented,
327
                      _sleep_fn=sleep_fn,
328
                      _prepare_fn=NotImplemented,
329
                      _runcmd_fn=NotImplemented,
330
                      _enabled=True)
331
    self.assertEqual(sleep_fn.Count(), 1)
332

    
333
  @staticmethod
334
  def _TryLock(lockfile):
335
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
336

    
337
    result = False
338
    try:
339
      backend.RunRestrictedCmd("test22717",
340
                               _lock_timeout=0.1,
341
                               _lock_file=lockfile,
342
                               _path=NotImplemented,
343
                               _sleep_fn=sleep_fn,
344
                               _prepare_fn=NotImplemented,
345
                               _runcmd_fn=NotImplemented,
346
                               _enabled=True)
347
    except backend.RPCFail, err:
348
      assert str(err) == _GenericRestrictedCmdError("test22717"), \
349
             "Did not fail with generic error message"
350
      result = True
351

    
352
    assert sleep_fn.Count() == 1
353

    
354
    return result
355

    
356
  def testLockHeldByOtherProcess(self):
357
    lockfile = utils.PathJoin(self.tmpdir, "lock")
358

    
359
    lock = utils.FileLock.Open(lockfile)
360
    lock.Exclusive(blocking=True, timeout=1.0)
361
    try:
362
      self.assertTrue(utils.RunInSeparateProcess(self._TryLock, lockfile))
363
    finally:
364
      lock.Close()
365

    
366
  @staticmethod
367
  def _PrepareRaisingException(path, cmd):
368
    assert cmd == "test23122"
369
    raise Exception("test")
370

    
371
  def testPrepareRaisesException(self):
372
    lockfile = utils.PathJoin(self.tmpdir, "lock")
373

    
374
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
375
    prepare_fn = testutils.CallCounter(self._PrepareRaisingException)
376

    
377
    try:
378
      backend.RunRestrictedCmd("test23122",
379
                               _lock_timeout=1.0, _lock_file=lockfile,
380
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
381
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
382
                               _enabled=True)
383
    except backend.RPCFail, err:
384
      self.assertEqual(str(err), _GenericRestrictedCmdError("test23122"))
385
    else:
386
      self.fail("Didn't fail")
387

    
388
    self.assertEqual(sleep_fn.Count(), 1)
389
    self.assertEqual(prepare_fn.Count(), 1)
390

    
391
  @staticmethod
392
  def _PrepareFails(path, cmd):
393
    assert cmd == "test29327"
394
    return ("some error message", None)
395

    
396
  def testPrepareFails(self):
397
    lockfile = utils.PathJoin(self.tmpdir, "lock")
398

    
399
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
400
    prepare_fn = testutils.CallCounter(self._PrepareFails)
401

    
402
    try:
403
      backend.RunRestrictedCmd("test29327",
404
                               _lock_timeout=1.0, _lock_file=lockfile,
405
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
406
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
407
                               _enabled=True)
408
    except backend.RPCFail, err:
409
      self.assertEqual(str(err), _GenericRestrictedCmdError("test29327"))
410
    else:
411
      self.fail("Didn't fail")
412

    
413
    self.assertEqual(sleep_fn.Count(), 1)
414
    self.assertEqual(prepare_fn.Count(), 1)
415

    
416
  @staticmethod
417
  def _SuccessfulPrepare(path, cmd):
418
    return (True, utils.PathJoin(path, cmd))
419

    
420
  def testRunCmdFails(self):
421
    lockfile = utils.PathJoin(self.tmpdir, "lock")
422

    
423
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
424
           postfork_fn=NotImplemented):
425
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test3079")])
426
      self.assertEqual(env, {})
427
      self.assertTrue(reset_env)
428
      self.assertTrue(callable(postfork_fn))
429

    
430
      trylock = utils.FileLock.Open(lockfile)
431
      try:
432
        # See if lockfile is still held
433
        self.assertRaises(EnvironmentError, trylock.Exclusive, blocking=False)
434

    
435
        # Call back to release lock
436
        postfork_fn(NotImplemented)
437

    
438
        # See if lockfile can be acquired
439
        trylock.Exclusive(blocking=False)
440
      finally:
441
        trylock.Close()
442

    
443
      # Simulate a failed command
444
      return utils.RunResult(constants.EXIT_FAILURE, None,
445
                             "stdout", "stderr406328567",
446
                             utils.ShellQuoteArgs(args),
447
                             NotImplemented, NotImplemented)
448

    
449
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
450
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
451
    runcmd_fn = testutils.CallCounter(fn)
452

    
453
    try:
454
      backend.RunRestrictedCmd("test3079",
455
                               _lock_timeout=1.0, _lock_file=lockfile,
456
                               _path=self.tmpdir, _runcmd_fn=runcmd_fn,
457
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
458
                               _enabled=True)
459
    except backend.RPCFail, err:
460
      self.assertTrue(str(err).startswith("Restricted command 'test3079'"
461
                                          " failed:"))
462
      self.assertTrue("stderr406328567" in str(err),
463
                      msg="Error did not include output")
464
    else:
465
      self.fail("Didn't fail")
466

    
467
    self.assertEqual(sleep_fn.Count(), 0)
468
    self.assertEqual(prepare_fn.Count(), 1)
469
    self.assertEqual(runcmd_fn.Count(), 1)
470

    
471
  def testRunCmdSucceeds(self):
472
    lockfile = utils.PathJoin(self.tmpdir, "lock")
473

    
474
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
475
           postfork_fn=NotImplemented):
476
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test5667")])
477
      self.assertEqual(env, {})
478
      self.assertTrue(reset_env)
479

    
480
      # Call back to release lock
481
      postfork_fn(NotImplemented)
482

    
483
      # Simulate a successful command
484
      return utils.RunResult(constants.EXIT_SUCCESS, None, "stdout14463", "",
485
                             utils.ShellQuoteArgs(args),
486
                             NotImplemented, NotImplemented)
487

    
488
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
489
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
490
    runcmd_fn = testutils.CallCounter(fn)
491

    
492
    result = backend.RunRestrictedCmd("test5667",
493
                                      _lock_timeout=1.0, _lock_file=lockfile,
494
                                      _path=self.tmpdir, _runcmd_fn=runcmd_fn,
495
                                      _sleep_fn=sleep_fn,
496
                                      _prepare_fn=prepare_fn,
497
                                      _enabled=True)
498
    self.assertEqual(result, "stdout14463")
499

    
500
    self.assertEqual(sleep_fn.Count(), 0)
501
    self.assertEqual(prepare_fn.Count(), 1)
502
    self.assertEqual(runcmd_fn.Count(), 1)
503

    
504
  def testCommandsDisabled(self):
505
    try:
506
      backend.RunRestrictedCmd("test",
507
                               _lock_timeout=NotImplemented,
508
                               _lock_file=NotImplemented,
509
                               _path=NotImplemented,
510
                               _sleep_fn=NotImplemented,
511
                               _prepare_fn=NotImplemented,
512
                               _runcmd_fn=NotImplemented,
513
                               _enabled=False)
514
    except backend.RPCFail, err:
515
      self.assertEqual(str(err),
516
                       "Restricted commands disabled at configure time")
517
    else:
518
      self.fail("Did not raise exception")
519

    
520

    
521
class TestSetWatcherPause(unittest.TestCase):
522
  def setUp(self):
523
    self.tmpdir = tempfile.mkdtemp()
524
    self.filename = utils.PathJoin(self.tmpdir, "pause")
525

    
526
  def tearDown(self):
527
    shutil.rmtree(self.tmpdir)
528

    
529
  def testUnsetNonExisting(self):
530
    self.assertFalse(os.path.exists(self.filename))
531
    backend.SetWatcherPause(None, _filename=self.filename)
532
    self.assertFalse(os.path.exists(self.filename))
533

    
534
  def testSetNonNumeric(self):
535
    for i in ["", [], {}, "Hello World", "0", "1.0"]:
536
      self.assertFalse(os.path.exists(self.filename))
537

    
538
      try:
539
        backend.SetWatcherPause(i, _filename=self.filename)
540
      except backend.RPCFail, err:
541
        self.assertEqual(str(err), "Duration must be numeric")
542
      else:
543
        self.fail("Did not raise exception")
544

    
545
      self.assertFalse(os.path.exists(self.filename))
546

    
547
  def testSet(self):
548
    self.assertFalse(os.path.exists(self.filename))
549

    
550
    for i in range(10):
551
      backend.SetWatcherPause(i, _filename=self.filename)
552
      self.assertEqual(utils.ReadFile(self.filename), "%s\n" % i)
553
      self.assertEqual(os.stat(self.filename).st_mode & 0777, 0644)
554

    
555

    
556
class TestGetBlockDevSymlinkPath(unittest.TestCase):
557
  def setUp(self):
558
    self.tmpdir = tempfile.mkdtemp()
559

    
560
  def tearDown(self):
561
    shutil.rmtree(self.tmpdir)
562

    
563
  def _Test(self, name, idx):
564
    self.assertEqual(backend._GetBlockDevSymlinkPath(name, idx,
565
                                                     _dir=self.tmpdir),
566
                     ("%s/%s%s%s" % (self.tmpdir, name,
567
                                     constants.DISK_SEPARATOR, idx)))
568

    
569
  def test(self):
570
    for idx in range(100):
571
      self._Test("inst1.example.com", idx)
572

    
573

    
574
class TestGetInstanceList(unittest.TestCase):
575

    
576
  def setUp(self):
577
    self._test_hv = self._TestHypervisor()
578
    self._test_hv.ListInstances = mock.Mock(
579
      return_value=["instance1", "instance2", "instance3"] )
580

    
581
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
582
    def __init__(self):
583
      hypervisor.hv_base.BaseHypervisor.__init__(self)
584

    
585
  def _GetHypervisor(self, name):
586
    return self._test_hv
587

    
588
  def testHvparams(self):
589
    fake_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
590
    hvparams = {constants.HT_FAKE: fake_hvparams}
591
    backend.GetInstanceList([constants.HT_FAKE], all_hvparams=hvparams,
592
                            get_hv_fn=self._GetHypervisor)
593
    self._test_hv.ListInstances.assert_called_with(hvparams=fake_hvparams)
594

    
595

    
596
class TestGetHvInfo(unittest.TestCase):
597

    
598
  def setUp(self):
599
    self._test_hv = self._TestHypervisor()
600
    self._test_hv.GetNodeInfo = mock.Mock()
601

    
602
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
603
    def __init__(self):
604
      hypervisor.hv_base.BaseHypervisor.__init__(self)
605

    
606
  def _GetHypervisor(self, name):
607
    return self._test_hv
608

    
609
  def testGetHvInfoAllNone(self):
610
    result = backend._GetHvInfoAll(None)
611
    self.assertTrue(result is None)
612

    
613
  def testGetHvInfoAll(self):
614
    hvname = constants.HT_XEN_PVM
615
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
616
    hv_specs = [(hvname, hvparams)]
617

    
618
    result = backend._GetHvInfoAll(hv_specs, self._GetHypervisor)
619
    self._test_hv.GetNodeInfo.assert_called_with(hvparams=hvparams)
620

    
621

    
622
class TestApplyStorageInfoFunction(unittest.TestCase):
623

    
624
  _STORAGE_KEY = "some_key"
625
  _SOME_ARGS = "some_args"
626

    
627
  def setUp(self):
628
    self.mock_storage_fn = mock.Mock()
629

    
630
  def testApplyValidStorageType(self):
631
    storage_type = constants.ST_LVM_VG
632
    backend._STORAGE_TYPE_INFO_FN = {
633
        storage_type: self.mock_storage_fn
634
      }
635

    
636
    backend._ApplyStorageInfoFunction(
637
        storage_type, self._STORAGE_KEY, self._SOME_ARGS)
638

    
639
    self.mock_storage_fn.assert_called_with(self._STORAGE_KEY, self._SOME_ARGS)
640

    
641
  def testApplyInValidStorageType(self):
642
    storage_type = "invalid_storage_type"
643
    backend._STORAGE_TYPE_INFO_FN = {}
644

    
645
    self.assertRaises(KeyError, backend._ApplyStorageInfoFunction,
646
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
647

    
648
  def testApplyNotImplementedStorageType(self):
649
    storage_type = "not_implemented_storage_type"
650
    backend._STORAGE_TYPE_INFO_FN = {storage_type: None}
651

    
652
    self.assertRaises(NotImplementedError,
653
                      backend._ApplyStorageInfoFunction,
654
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
655

    
656

    
657
class TestGetNodeInfo(unittest.TestCase):
658

    
659
  _SOME_RESULT = None
660

    
661
  def testApplyStorageInfoFunction(self):
662
    excl_storage_flag = False
663
    backend._ApplyStorageInfoFunction = mock.Mock(
664
        return_value=self._SOME_RESULT)
665
    storage_units = [(st, st + "_key") for st in
666
                     constants.VALID_STORAGE_TYPES]
667

    
668
    backend.GetNodeInfo(storage_units, None, excl_storage_flag)
669

    
670
    call_args_list = backend._ApplyStorageInfoFunction.call_args_list
671
    self.assertEqual(len(constants.VALID_STORAGE_TYPES), len(call_args_list))
672
    for call in call_args_list:
673
      storage_type, storage_key, excl_storage = call[0]
674
      self.assertEqual(storage_type + "_key", storage_key)
675
      self.assertTrue(storage_type in constants.VALID_STORAGE_TYPES)
676

    
677

    
678
if __name__ == "__main__":
679
  testutils.GanetiTestProgram()