Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.backend_unittest.py @ a9f33339

History | View | Annotate | Download (29.1 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.backend"""
23

    
24
import mock
25
import os
26
import shutil
27
import tempfile
28
import testutils
29
import unittest
30

    
31
from ganeti import backend
32
from ganeti import constants
33
from ganeti import errors
34
from ganeti import hypervisor
35
from ganeti import netutils
36
from ganeti import objects
37
from ganeti import utils
38

    
39

    
40
class TestX509Certificates(unittest.TestCase):
41
  def setUp(self):
42
    self.tmpdir = tempfile.mkdtemp()
43

    
44
  def tearDown(self):
45
    shutil.rmtree(self.tmpdir)
46

    
47
  def test(self):
48
    (name, cert_pem) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
49

    
50
    self.assertEqual(utils.ReadFile(os.path.join(self.tmpdir, name,
51
                                                 backend._X509_CERT_FILE)),
52
                     cert_pem)
53
    self.assert_(0 < os.path.getsize(os.path.join(self.tmpdir, name,
54
                                                  backend._X509_KEY_FILE)))
55

    
56
    (name2, cert_pem2) = \
57
      backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
58

    
59
    backend.RemoveX509Certificate(name, cryptodir=self.tmpdir)
60
    backend.RemoveX509Certificate(name2, cryptodir=self.tmpdir)
61

    
62
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [])
63

    
64
  def testNonEmpty(self):
65
    (name, _) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
66

    
67
    utils.WriteFile(utils.PathJoin(self.tmpdir, name, "hello-world"),
68
                    data="Hello World")
69

    
70
    self.assertRaises(backend.RPCFail, backend.RemoveX509Certificate,
71
                      name, cryptodir=self.tmpdir)
72

    
73
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [name])
74

    
75

    
76
class TestNodeVerify(testutils.GanetiTestCase):
77

    
78
  def setUp(self):
79
    testutils.GanetiTestCase.setUp(self)
80
    self._mock_hv = None
81

    
82
  def _GetHypervisor(self, hv_name):
83
    self._mock_hv = hypervisor.GetHypervisor(hv_name)
84
    self._mock_hv.ValidateParameters = mock.Mock()
85
    self._mock_hv.Verify = mock.Mock()
86
    return self._mock_hv
87

    
88
  def testMasterIPLocalhost(self):
89
    # this a real functional test, but requires localhost to be reachable
90
    local_data = (netutils.Hostname.GetSysName(),
91
                  constants.IP4_ADDRESS_LOCALHOST)
92
    result = backend.VerifyNode({constants.NV_MASTERIP: local_data}, None, {}, {}, {})
93
    self.failUnless(constants.NV_MASTERIP in result,
94
                    "Master IP data not returned")
95
    self.failUnless(result[constants.NV_MASTERIP], "Cannot reach localhost")
96

    
97
  def testMasterIPUnreachable(self):
98
    # Network 192.0.2.0/24 is reserved for test/documentation as per
99
    # RFC 5737
100
    bad_data =  ("master.example.com", "192.0.2.1")
101
    # we just test that whatever TcpPing returns, VerifyNode returns too
102
    netutils.TcpPing = lambda a, b, source=None: False
103
    result = backend.VerifyNode({constants.NV_MASTERIP: bad_data}, None, {}, {}, {})
104
    self.failUnless(constants.NV_MASTERIP in result,
105
                    "Master IP data not returned")
106
    self.failIf(result[constants.NV_MASTERIP],
107
                "Result from netutils.TcpPing corrupted")
108

    
109
  def testVerifyHvparams(self):
110
    test_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
111
    test_what = {constants.NV_HVPARAMS: \
112
        [("mynode", constants.HT_XEN_PVM, test_hvparams)]}
113
    result = {}
114
    backend._VerifyHvparams(test_what, True, result,
115
                            get_hv_fn=self._GetHypervisor)
116
    self._mock_hv.ValidateParameters.assert_called_with(test_hvparams)
117

    
118
  def testVerifyHypervisors(self):
119
    hvname = constants.HT_XEN_PVM
120
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
121
    all_hvparams = {hvname: hvparams}
122
    test_what = {constants.NV_HYPERVISOR: [hvname]}
123
    result = {}
124
    backend._VerifyHypervisors(
125
        test_what, True, result, all_hvparams=all_hvparams,
126
        get_hv_fn=self._GetHypervisor)
127
    self._mock_hv.Verify.assert_called_with(hvparams=hvparams)
128

    
129

    
130
def _DefRestrictedCmdOwner():
131
  return (os.getuid(), os.getgid())
132

    
133

    
134
class TestVerifyRestrictedCmdName(unittest.TestCase):
135
  def testAcceptableName(self):
136
    for i in ["foo", "bar", "z1", "000first", "hello-world"]:
137
      for fn in [lambda s: s, lambda s: s.upper(), lambda s: s.title()]:
138
        (status, msg) = backend._VerifyRestrictedCmdName(fn(i))
139
        self.assertTrue(status)
140
        self.assertTrue(msg is None)
141

    
142
  def testEmptyAndSpace(self):
143
    for i in ["", " ", "\t", "\n"]:
144
      (status, msg) = backend._VerifyRestrictedCmdName(i)
145
      self.assertFalse(status)
146
      self.assertEqual(msg, "Missing command name")
147

    
148
  def testNameWithSlashes(self):
149
    for i in ["/", "./foo", "../moo", "some/name"]:
150
      (status, msg) = backend._VerifyRestrictedCmdName(i)
151
      self.assertFalse(status)
152
      self.assertEqual(msg, "Invalid command name")
153

    
154
  def testForbiddenCharacters(self):
155
    for i in ["#", ".", "..", "bash -c ls", "'"]:
156
      (status, msg) = backend._VerifyRestrictedCmdName(i)
157
      self.assertFalse(status)
158
      self.assertEqual(msg, "Command name contains forbidden characters")
159

    
160

    
161
class TestVerifyRestrictedCmdDirectory(unittest.TestCase):
162
  def setUp(self):
163
    self.tmpdir = tempfile.mkdtemp()
164

    
165
  def tearDown(self):
166
    shutil.rmtree(self.tmpdir)
167

    
168
  def testCanNotStat(self):
169
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
170
    self.assertFalse(os.path.exists(tmpname))
171
    (status, msg) = \
172
      backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
173
    self.assertFalse(status)
174
    self.assertTrue(msg.startswith("Can't stat(2) '"))
175

    
176
  def testTooPermissive(self):
177
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
178
    os.mkdir(tmpname)
179

    
180
    for mode in [0777, 0706, 0760, 0722]:
181
      os.chmod(tmpname, mode)
182
      self.assertTrue(os.path.isdir(tmpname))
183
      (status, msg) = \
184
        backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
185
      self.assertFalse(status)
186
      self.assertTrue(msg.startswith("Permissions on '"))
187

    
188
  def testNoDirectory(self):
189
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
190
    utils.WriteFile(tmpname, data="empty\n")
191
    self.assertTrue(os.path.isfile(tmpname))
192
    (status, msg) = \
193
      backend._VerifyRestrictedCmdDirectory(tmpname,
194
                                            _owner=_DefRestrictedCmdOwner())
195
    self.assertFalse(status)
196
    self.assertTrue(msg.endswith("is not a directory"))
197

    
198
  def testNormal(self):
199
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
200
    os.mkdir(tmpname)
201
    os.chmod(tmpname, 0755)
202
    self.assertTrue(os.path.isdir(tmpname))
203
    (status, msg) = \
204
      backend._VerifyRestrictedCmdDirectory(tmpname,
205
                                            _owner=_DefRestrictedCmdOwner())
206
    self.assertTrue(status)
207
    self.assertTrue(msg is None)
208

    
209

    
210
class TestVerifyRestrictedCmd(unittest.TestCase):
211
  def setUp(self):
212
    self.tmpdir = tempfile.mkdtemp()
213

    
214
  def tearDown(self):
215
    shutil.rmtree(self.tmpdir)
216

    
217
  def testCanNotStat(self):
218
    tmpname = utils.PathJoin(self.tmpdir, "helloworld")
219
    self.assertFalse(os.path.exists(tmpname))
220
    (status, msg) = \
221
      backend._VerifyRestrictedCmd(self.tmpdir, "helloworld",
222
                                   _owner=NotImplemented)
223
    self.assertFalse(status)
224
    self.assertTrue(msg.startswith("Can't stat(2) '"))
225

    
226
  def testNotExecutable(self):
227
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
228
    utils.WriteFile(tmpname, data="empty\n")
229
    (status, msg) = \
230
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
231
                                   _owner=_DefRestrictedCmdOwner())
232
    self.assertFalse(status)
233
    self.assertTrue(msg.startswith("access(2) thinks '"))
234

    
235
  def testExecutable(self):
236
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
237
    utils.WriteFile(tmpname, data="empty\n", mode=0700)
238
    (status, executable) = \
239
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
240
                                   _owner=_DefRestrictedCmdOwner())
241
    self.assertTrue(status)
242
    self.assertEqual(executable, tmpname)
243

    
244

    
245
class TestPrepareRestrictedCmd(unittest.TestCase):
246
  _TEST_PATH = "/tmp/some/test/path"
247

    
248
  def testDirFails(self):
249
    def fn(path):
250
      self.assertEqual(path, self._TEST_PATH)
251
      return (False, "test error 31420")
252

    
253
    (status, msg) = \
254
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd21152",
255
                                    _verify_dir=fn,
256
                                    _verify_name=NotImplemented,
257
                                    _verify_cmd=NotImplemented)
258
    self.assertFalse(status)
259
    self.assertEqual(msg, "test error 31420")
260

    
261
  def testNameFails(self):
262
    def fn(cmd):
263
      self.assertEqual(cmd, "cmd4617")
264
      return (False, "test error 591")
265

    
266
    (status, msg) = \
267
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd4617",
268
                                    _verify_dir=lambda _: (True, None),
269
                                    _verify_name=fn,
270
                                    _verify_cmd=NotImplemented)
271
    self.assertFalse(status)
272
    self.assertEqual(msg, "test error 591")
273

    
274
  def testCommandFails(self):
275
    def fn(path, cmd):
276
      self.assertEqual(path, self._TEST_PATH)
277
      self.assertEqual(cmd, "cmd17577")
278
      return (False, "test error 25524")
279

    
280
    (status, msg) = \
281
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd17577",
282
                                    _verify_dir=lambda _: (True, None),
283
                                    _verify_name=lambda _: (True, None),
284
                                    _verify_cmd=fn)
285
    self.assertFalse(status)
286
    self.assertEqual(msg, "test error 25524")
287

    
288
  def testSuccess(self):
289
    def fn(path, cmd):
290
      return (True, utils.PathJoin(path, cmd))
291

    
292
    (status, executable) = \
293
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd22633",
294
                                    _verify_dir=lambda _: (True, None),
295
                                    _verify_name=lambda _: (True, None),
296
                                    _verify_cmd=fn)
297
    self.assertTrue(status)
298
    self.assertEqual(executable, utils.PathJoin(self._TEST_PATH, "cmd22633"))
299

    
300

    
301
def _SleepForRestrictedCmd(duration):
302
  assert duration > 5
303

    
304

    
305
def _GenericRestrictedCmdError(cmd):
306
  return "Executing command '%s' failed" % cmd
307

    
308

    
309
class TestRunRestrictedCmd(unittest.TestCase):
310
  def setUp(self):
311
    self.tmpdir = tempfile.mkdtemp()
312

    
313
  def tearDown(self):
314
    shutil.rmtree(self.tmpdir)
315

    
316
  def testNonExistantLockDirectory(self):
317
    lockfile = utils.PathJoin(self.tmpdir, "does", "not", "exist")
318
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
319
    self.assertFalse(os.path.exists(lockfile))
320
    self.assertRaises(backend.RPCFail,
321
                      backend.RunRestrictedCmd, "test",
322
                      _lock_timeout=NotImplemented,
323
                      _lock_file=lockfile,
324
                      _path=NotImplemented,
325
                      _sleep_fn=sleep_fn,
326
                      _prepare_fn=NotImplemented,
327
                      _runcmd_fn=NotImplemented,
328
                      _enabled=True)
329
    self.assertEqual(sleep_fn.Count(), 1)
330

    
331
  @staticmethod
332
  def _TryLock(lockfile):
333
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
334

    
335
    result = False
336
    try:
337
      backend.RunRestrictedCmd("test22717",
338
                               _lock_timeout=0.1,
339
                               _lock_file=lockfile,
340
                               _path=NotImplemented,
341
                               _sleep_fn=sleep_fn,
342
                               _prepare_fn=NotImplemented,
343
                               _runcmd_fn=NotImplemented,
344
                               _enabled=True)
345
    except backend.RPCFail, err:
346
      assert str(err) == _GenericRestrictedCmdError("test22717"), \
347
             "Did not fail with generic error message"
348
      result = True
349

    
350
    assert sleep_fn.Count() == 1
351

    
352
    return result
353

    
354
  def testLockHeldByOtherProcess(self):
355
    lockfile = utils.PathJoin(self.tmpdir, "lock")
356

    
357
    lock = utils.FileLock.Open(lockfile)
358
    lock.Exclusive(blocking=True, timeout=1.0)
359
    try:
360
      self.assertTrue(utils.RunInSeparateProcess(self._TryLock, lockfile))
361
    finally:
362
      lock.Close()
363

    
364
  @staticmethod
365
  def _PrepareRaisingException(path, cmd):
366
    assert cmd == "test23122"
367
    raise Exception("test")
368

    
369
  def testPrepareRaisesException(self):
370
    lockfile = utils.PathJoin(self.tmpdir, "lock")
371

    
372
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
373
    prepare_fn = testutils.CallCounter(self._PrepareRaisingException)
374

    
375
    try:
376
      backend.RunRestrictedCmd("test23122",
377
                               _lock_timeout=1.0, _lock_file=lockfile,
378
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
379
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
380
                               _enabled=True)
381
    except backend.RPCFail, err:
382
      self.assertEqual(str(err), _GenericRestrictedCmdError("test23122"))
383
    else:
384
      self.fail("Didn't fail")
385

    
386
    self.assertEqual(sleep_fn.Count(), 1)
387
    self.assertEqual(prepare_fn.Count(), 1)
388

    
389
  @staticmethod
390
  def _PrepareFails(path, cmd):
391
    assert cmd == "test29327"
392
    return ("some error message", None)
393

    
394
  def testPrepareFails(self):
395
    lockfile = utils.PathJoin(self.tmpdir, "lock")
396

    
397
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
398
    prepare_fn = testutils.CallCounter(self._PrepareFails)
399

    
400
    try:
401
      backend.RunRestrictedCmd("test29327",
402
                               _lock_timeout=1.0, _lock_file=lockfile,
403
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
404
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
405
                               _enabled=True)
406
    except backend.RPCFail, err:
407
      self.assertEqual(str(err), _GenericRestrictedCmdError("test29327"))
408
    else:
409
      self.fail("Didn't fail")
410

    
411
    self.assertEqual(sleep_fn.Count(), 1)
412
    self.assertEqual(prepare_fn.Count(), 1)
413

    
414
  @staticmethod
415
  def _SuccessfulPrepare(path, cmd):
416
    return (True, utils.PathJoin(path, cmd))
417

    
418
  def testRunCmdFails(self):
419
    lockfile = utils.PathJoin(self.tmpdir, "lock")
420

    
421
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
422
           postfork_fn=NotImplemented):
423
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test3079")])
424
      self.assertEqual(env, {})
425
      self.assertTrue(reset_env)
426
      self.assertTrue(callable(postfork_fn))
427

    
428
      trylock = utils.FileLock.Open(lockfile)
429
      try:
430
        # See if lockfile is still held
431
        self.assertRaises(EnvironmentError, trylock.Exclusive, blocking=False)
432

    
433
        # Call back to release lock
434
        postfork_fn(NotImplemented)
435

    
436
        # See if lockfile can be acquired
437
        trylock.Exclusive(blocking=False)
438
      finally:
439
        trylock.Close()
440

    
441
      # Simulate a failed command
442
      return utils.RunResult(constants.EXIT_FAILURE, None,
443
                             "stdout", "stderr406328567",
444
                             utils.ShellQuoteArgs(args),
445
                             NotImplemented, NotImplemented)
446

    
447
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
448
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
449
    runcmd_fn = testutils.CallCounter(fn)
450

    
451
    try:
452
      backend.RunRestrictedCmd("test3079",
453
                               _lock_timeout=1.0, _lock_file=lockfile,
454
                               _path=self.tmpdir, _runcmd_fn=runcmd_fn,
455
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
456
                               _enabled=True)
457
    except backend.RPCFail, err:
458
      self.assertTrue(str(err).startswith("Restricted command 'test3079'"
459
                                          " failed:"))
460
      self.assertTrue("stderr406328567" in str(err),
461
                      msg="Error did not include output")
462
    else:
463
      self.fail("Didn't fail")
464

    
465
    self.assertEqual(sleep_fn.Count(), 0)
466
    self.assertEqual(prepare_fn.Count(), 1)
467
    self.assertEqual(runcmd_fn.Count(), 1)
468

    
469
  def testRunCmdSucceeds(self):
470
    lockfile = utils.PathJoin(self.tmpdir, "lock")
471

    
472
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
473
           postfork_fn=NotImplemented):
474
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test5667")])
475
      self.assertEqual(env, {})
476
      self.assertTrue(reset_env)
477

    
478
      # Call back to release lock
479
      postfork_fn(NotImplemented)
480

    
481
      # Simulate a successful command
482
      return utils.RunResult(constants.EXIT_SUCCESS, None, "stdout14463", "",
483
                             utils.ShellQuoteArgs(args),
484
                             NotImplemented, NotImplemented)
485

    
486
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
487
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
488
    runcmd_fn = testutils.CallCounter(fn)
489

    
490
    result = backend.RunRestrictedCmd("test5667",
491
                                      _lock_timeout=1.0, _lock_file=lockfile,
492
                                      _path=self.tmpdir, _runcmd_fn=runcmd_fn,
493
                                      _sleep_fn=sleep_fn,
494
                                      _prepare_fn=prepare_fn,
495
                                      _enabled=True)
496
    self.assertEqual(result, "stdout14463")
497

    
498
    self.assertEqual(sleep_fn.Count(), 0)
499
    self.assertEqual(prepare_fn.Count(), 1)
500
    self.assertEqual(runcmd_fn.Count(), 1)
501

    
502
  def testCommandsDisabled(self):
503
    try:
504
      backend.RunRestrictedCmd("test",
505
                               _lock_timeout=NotImplemented,
506
                               _lock_file=NotImplemented,
507
                               _path=NotImplemented,
508
                               _sleep_fn=NotImplemented,
509
                               _prepare_fn=NotImplemented,
510
                               _runcmd_fn=NotImplemented,
511
                               _enabled=False)
512
    except backend.RPCFail, err:
513
      self.assertEqual(str(err),
514
                       "Restricted commands disabled at configure time")
515
    else:
516
      self.fail("Did not raise exception")
517

    
518

    
519
class TestSetWatcherPause(unittest.TestCase):
520
  def setUp(self):
521
    self.tmpdir = tempfile.mkdtemp()
522
    self.filename = utils.PathJoin(self.tmpdir, "pause")
523

    
524
  def tearDown(self):
525
    shutil.rmtree(self.tmpdir)
526

    
527
  def testUnsetNonExisting(self):
528
    self.assertFalse(os.path.exists(self.filename))
529
    backend.SetWatcherPause(None, _filename=self.filename)
530
    self.assertFalse(os.path.exists(self.filename))
531

    
532
  def testSetNonNumeric(self):
533
    for i in ["", [], {}, "Hello World", "0", "1.0"]:
534
      self.assertFalse(os.path.exists(self.filename))
535

    
536
      try:
537
        backend.SetWatcherPause(i, _filename=self.filename)
538
      except backend.RPCFail, err:
539
        self.assertEqual(str(err), "Duration must be numeric")
540
      else:
541
        self.fail("Did not raise exception")
542

    
543
      self.assertFalse(os.path.exists(self.filename))
544

    
545
  def testSet(self):
546
    self.assertFalse(os.path.exists(self.filename))
547

    
548
    for i in range(10):
549
      backend.SetWatcherPause(i, _filename=self.filename)
550
      self.assertEqual(utils.ReadFile(self.filename), "%s\n" % i)
551
      self.assertEqual(os.stat(self.filename).st_mode & 0777, 0644)
552

    
553

    
554
class TestGetBlockDevSymlinkPath(unittest.TestCase):
555
  def setUp(self):
556
    self.tmpdir = tempfile.mkdtemp()
557

    
558
  def tearDown(self):
559
    shutil.rmtree(self.tmpdir)
560

    
561
  def _Test(self, name, idx):
562
    self.assertEqual(backend._GetBlockDevSymlinkPath(name, idx,
563
                                                     _dir=self.tmpdir),
564
                     ("%s/%s%s%s" % (self.tmpdir, name,
565
                                     constants.DISK_SEPARATOR, idx)))
566

    
567
  def test(self):
568
    for idx in range(100):
569
      self._Test("inst1.example.com", idx)
570

    
571

    
572
class TestGetInstanceList(unittest.TestCase):
573

    
574
  def setUp(self):
575
    self._test_hv = self._TestHypervisor()
576
    self._test_hv.ListInstances = mock.Mock(
577
      return_value=["instance1", "instance2", "instance3"] )
578

    
579
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
580
    def __init__(self):
581
      hypervisor.hv_base.BaseHypervisor.__init__(self)
582

    
583
  def _GetHypervisor(self, name):
584
    return self._test_hv
585

    
586
  def testHvparams(self):
587
    fake_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
588
    hvparams = {constants.HT_FAKE: fake_hvparams}
589
    backend.GetInstanceList([constants.HT_FAKE], all_hvparams=hvparams,
590
                            get_hv_fn=self._GetHypervisor)
591
    self._test_hv.ListInstances.assert_called_with(hvparams=fake_hvparams)
592

    
593

    
594
class TestInstanceConsoleInfo(unittest.TestCase):
595

    
596
  def setUp(self):
597
    self._test_hv_a = self._TestHypervisor()
598
    self._test_hv_a.GetInstanceConsole = mock.Mock(
599
      return_value = objects.InstanceConsole(instance="inst", kind="aHy")
600
    )
601
    self._test_hv_b = self._TestHypervisor()
602
    self._test_hv_b.GetInstanceConsole = mock.Mock(
603
      return_value = objects.InstanceConsole(instance="inst", kind="bHy")
604
    )
605

    
606
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
607
    def __init__(self):
608
      hypervisor.hv_base.BaseHypervisor.__init__(self)
609

    
610
  def _GetHypervisor(self, name):
611
    if name == "a":
612
      return self._test_hv_a
613
    else:
614
      return self._test_hv_b
615

    
616
  def testRightHypervisor(self):
617
    dictMaker = lambda hyName: {
618
      "instance":{"hypervisor":hyName},
619
      "node":{},
620
      "hvParams":{},
621
      "beParams":{},
622
    }
623

    
624
    call = {
625
      'i1':dictMaker("a"),
626
      'i2':dictMaker("b"),
627
    }
628

    
629
    res = backend.GetInstanceConsoleInfo(call, get_hv_fn=self._GetHypervisor)
630

    
631
    self.assertTrue(res["i1"]["kind"] == "aHy")
632
    self.assertTrue(res["i2"]["kind"] == "bHy")
633

    
634

    
635
class TestGetHvInfo(unittest.TestCase):
636

    
637
  def setUp(self):
638
    self._test_hv = self._TestHypervisor()
639
    self._test_hv.GetNodeInfo = mock.Mock()
640

    
641
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
642
    def __init__(self):
643
      hypervisor.hv_base.BaseHypervisor.__init__(self)
644

    
645
  def _GetHypervisor(self, name):
646
    return self._test_hv
647

    
648
  def testGetHvInfoAllNone(self):
649
    result = backend._GetHvInfoAll(None)
650
    self.assertTrue(result is None)
651

    
652
  def testGetHvInfoAll(self):
653
    hvname = constants.HT_XEN_PVM
654
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
655
    hv_specs = [(hvname, hvparams)]
656

    
657
    backend._GetHvInfoAll(hv_specs, self._GetHypervisor)
658
    self._test_hv.GetNodeInfo.assert_called_with(hvparams=hvparams)
659

    
660

    
661
class TestApplyStorageInfoFunction(unittest.TestCase):
662

    
663
  _STORAGE_KEY = "some_key"
664
  _SOME_ARGS = ["some_args"]
665

    
666
  def setUp(self):
667
    self.mock_storage_fn = mock.Mock()
668

    
669
  def testApplyValidStorageType(self):
670
    storage_type = constants.ST_LVM_VG
671
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
672
    backend._STORAGE_TYPE_INFO_FN = {
673
        storage_type: self.mock_storage_fn
674
      }
675

    
676
    backend._ApplyStorageInfoFunction(
677
        storage_type, self._STORAGE_KEY, self._SOME_ARGS)
678

    
679
    self.mock_storage_fn.assert_called_with(self._STORAGE_KEY, self._SOME_ARGS)
680
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
681

    
682
  def testApplyInValidStorageType(self):
683
    storage_type = "invalid_storage_type"
684
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
685
    backend._STORAGE_TYPE_INFO_FN = {}
686

    
687
    self.assertRaises(KeyError, backend._ApplyStorageInfoFunction,
688
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
689
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
690

    
691
  def testApplyNotImplementedStorageType(self):
692
    storage_type = "not_implemented_storage_type"
693
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
694
    backend._STORAGE_TYPE_INFO_FN = {storage_type: None}
695

    
696
    self.assertRaises(NotImplementedError,
697
                      backend._ApplyStorageInfoFunction,
698
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
699
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
700

    
701

    
702
class TestGetLvmVgSpaceInfo(unittest.TestCase):
703

    
704
  def testValid(self):
705
    path = "somepath"
706
    excl_stor = True
707
    orig_fn = backend._GetVgInfo
708
    backend._GetVgInfo = mock.Mock()
709
    backend._GetLvmVgSpaceInfo(path, [excl_stor])
710
    backend._GetVgInfo.assert_called_with(path, excl_stor)
711
    backend._GetVgInfo = orig_fn
712

    
713
  def testNoExclStorageNotBool(self):
714
    path = "somepath"
715
    excl_stor = "123"
716
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
717
                      path, [excl_stor])
718

    
719
  def testNoExclStorageNotInList(self):
720
    path = "somepath"
721
    excl_stor = "123"
722
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
723
                      path, excl_stor)
724

    
725
class TestGetLvmPvSpaceInfo(unittest.TestCase):
726

    
727
  def testValid(self):
728
    path = "somepath"
729
    excl_stor = True
730
    orig_fn = backend._GetVgSpindlesInfo
731
    backend._GetVgSpindlesInfo = mock.Mock()
732
    backend._GetLvmPvSpaceInfo(path, [excl_stor])
733
    backend._GetVgSpindlesInfo.assert_called_with(path, excl_stor)
734
    backend._GetVgSpindlesInfo = orig_fn
735

    
736

    
737
class TestCheckStorageParams(unittest.TestCase):
738

    
739
  def testParamsNone(self):
740
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
741
                      None, NotImplemented)
742

    
743
  def testParamsWrongType(self):
744
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
745
                      "string", NotImplemented)
746

    
747
  def testParamsEmpty(self):
748
    backend._CheckStorageParams([], 0)
749

    
750
  def testParamsValidNumber(self):
751
    backend._CheckStorageParams(["a", True], 2)
752

    
753
  def testParamsInvalidNumber(self):
754
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
755
                      ["b", False], 3)
756

    
757

    
758
class TestGetVgSpindlesInfo(unittest.TestCase):
759

    
760
  def setUp(self):
761
    self.vg_free = 13
762
    self.vg_size = 31
763
    self.mock_fn = mock.Mock(return_value=(self.vg_free, self.vg_size))
764

    
765
  def testValidInput(self):
766
    name = "myvg"
767
    excl_stor = True
768
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
769
    self.mock_fn.assert_called_with(name)
770
    self.assertEqual(name, result["name"])
771
    self.assertEqual(constants.ST_LVM_PV, result["type"])
772
    self.assertEqual(self.vg_free, result["storage_free"])
773
    self.assertEqual(self.vg_size, result["storage_size"])
774

    
775
  def testNoExclStor(self):
776
    name = "myvg"
777
    excl_stor = False
778
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
779
    self.mock_fn.assert_not_called()
780
    self.assertEqual(name, result["name"])
781
    self.assertEqual(constants.ST_LVM_PV, result["type"])
782
    self.assertEqual(0, result["storage_free"])
783
    self.assertEqual(0, result["storage_size"])
784

    
785

    
786
class TestGetVgSpindlesInfo(unittest.TestCase):
787

    
788
  def testValidInput(self):
789
    self.vg_free = 13
790
    self.vg_size = 31
791
    self.mock_fn = mock.Mock(return_value=[(self.vg_free, self.vg_size)])
792
    name = "myvg"
793
    excl_stor = True
794
    result = backend._GetVgInfo(name, excl_stor, info_fn=self.mock_fn)
795
    self.mock_fn.assert_called_with([name], excl_stor)
796
    self.assertEqual(name, result["name"])
797
    self.assertEqual(constants.ST_LVM_VG, result["type"])
798
    self.assertEqual(self.vg_free, result["storage_free"])
799
    self.assertEqual(self.vg_size, result["storage_size"])
800

    
801
  def testNoExclStor(self):
802
    name = "myvg"
803
    excl_stor = True
804
    self.mock_fn = mock.Mock(return_value=None)
805
    result = backend._GetVgInfo(name, excl_stor, info_fn=self.mock_fn)
806
    self.mock_fn.assert_called_with([name], excl_stor)
807
    self.assertEqual(name, result["name"])
808
    self.assertEqual(constants.ST_LVM_VG, result["type"])
809
    self.assertEqual(None, result["storage_free"])
810
    self.assertEqual(None, result["storage_size"])
811

    
812

    
813
class TestGetNodeInfo(unittest.TestCase):
814

    
815
  _SOME_RESULT = None
816

    
817
  def testApplyStorageInfoFunction(self):
818
    orig_fn = backend._ApplyStorageInfoFunction
819
    backend._ApplyStorageInfoFunction = mock.Mock(
820
        return_value=self._SOME_RESULT)
821
    storage_units = [(st, st + "_key", [st + "_params"]) for st in
822
                     constants.STORAGE_TYPES]
823

    
824
    backend.GetNodeInfo(storage_units, None)
825

    
826
    call_args_list = backend._ApplyStorageInfoFunction.call_args_list
827
    self.assertEqual(len(constants.STORAGE_TYPES), len(call_args_list))
828
    for call in call_args_list:
829
      storage_type, storage_key, storage_params = call[0]
830
      self.assertEqual(storage_type + "_key", storage_key)
831
      self.assertEqual([storage_type + "_params"], storage_params)
832
      self.assertTrue(storage_type in constants.STORAGE_TYPES)
833
    backend._ApplyStorageInfoFunction = orig_fn
834

    
835

    
836
class TestSpaceReportingConstants(unittest.TestCase):
837
  """Ensures consistency between STS_REPORT and backend.
838

839
  These tests ensure, that the constant 'STS_REPORT' is consitent
840
  with the implementation of invoking space reporting functions
841
  in backend.py. Once space reporting is available for all types,
842
  the constant can be removed and these tests as well.
843

844
  """
845
  def testAllReportingTypesHaveAReportingFunction(self):
846
    for storage_type in constants.STS_REPORT:
847
      self.assertTrue(backend._STORAGE_TYPE_INFO_FN[storage_type] is not None)
848

    
849
  def testAllNotReportingTypesDoneHaveFunction(self):
850
    non_reporting_types = set(constants.STORAGE_TYPES)\
851
        - set(constants.STS_REPORT)
852
    for storage_type in non_reporting_types:
853
      self.assertEqual(None, backend._STORAGE_TYPE_INFO_FN[storage_type])
854

    
855

    
856
if __name__ == "__main__":
857
  testutils.GanetiTestProgram()