Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.backend_unittest.py @ 33ffda6c

History | View | Annotate | Download (29.2 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.backend"""
23

    
24
import mock
25
import os
26
import shutil
27
import tempfile
28
import testutils
29
import unittest
30

    
31
from ganeti import backend
32
from ganeti import constants
33
from ganeti import errors
34
from ganeti import hypervisor
35
from ganeti import netutils
36
from ganeti import objects
37
from ganeti import utils
38

    
39

    
40
class TestX509Certificates(unittest.TestCase):
41
  def setUp(self):
42
    self.tmpdir = tempfile.mkdtemp()
43

    
44
  def tearDown(self):
45
    shutil.rmtree(self.tmpdir)
46

    
47
  def test(self):
48
    (name, cert_pem) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
49

    
50
    self.assertEqual(utils.ReadFile(os.path.join(self.tmpdir, name,
51
                                                 backend._X509_CERT_FILE)),
52
                     cert_pem)
53
    self.assert_(0 < os.path.getsize(os.path.join(self.tmpdir, name,
54
                                                  backend._X509_KEY_FILE)))
55

    
56
    (name2, cert_pem2) = \
57
      backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
58

    
59
    backend.RemoveX509Certificate(name, cryptodir=self.tmpdir)
60
    backend.RemoveX509Certificate(name2, cryptodir=self.tmpdir)
61

    
62
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [])
63

    
64
  def testNonEmpty(self):
65
    (name, _) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
66

    
67
    utils.WriteFile(utils.PathJoin(self.tmpdir, name, "hello-world"),
68
                    data="Hello World")
69

    
70
    self.assertRaises(backend.RPCFail, backend.RemoveX509Certificate,
71
                      name, cryptodir=self.tmpdir)
72

    
73
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [name])
74

    
75

    
76
class TestNodeVerify(testutils.GanetiTestCase):
77

    
78
  def setUp(self):
79
    testutils.GanetiTestCase.setUp(self)
80
    self._mock_hv = None
81

    
82
  def _GetHypervisor(self, hv_name):
83
    self._mock_hv = hypervisor.GetHypervisor(hv_name)
84
    self._mock_hv.ValidateParameters = mock.Mock()
85
    self._mock_hv.Verify = mock.Mock()
86
    return self._mock_hv
87

    
88
  def testMasterIPLocalhost(self):
89
    # this a real functional test, but requires localhost to be reachable
90
    local_data = (netutils.Hostname.GetSysName(),
91
                  constants.IP4_ADDRESS_LOCALHOST)
92
    result = backend.VerifyNode({constants.NV_MASTERIP: local_data},
93
                                None, {}, {}, {})
94
    self.failUnless(constants.NV_MASTERIP in result,
95
                    "Master IP data not returned")
96
    self.failUnless(result[constants.NV_MASTERIP], "Cannot reach localhost")
97

    
98
  def testMasterIPUnreachable(self):
99
    # Network 192.0.2.0/24 is reserved for test/documentation as per
100
    # RFC 5737
101
    bad_data =  ("master.example.com", "192.0.2.1")
102
    # we just test that whatever TcpPing returns, VerifyNode returns too
103
    netutils.TcpPing = lambda a, b, source=None: False
104
    result = backend.VerifyNode({constants.NV_MASTERIP: bad_data},
105
                                None, {}, {}, {})
106
    self.failUnless(constants.NV_MASTERIP in result,
107
                    "Master IP data not returned")
108
    self.failIf(result[constants.NV_MASTERIP],
109
                "Result from netutils.TcpPing corrupted")
110

    
111
  def testVerifyHvparams(self):
112
    test_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
113
    test_what = {constants.NV_HVPARAMS: \
114
        [("mynode", constants.HT_XEN_PVM, test_hvparams)]}
115
    result = {}
116
    backend._VerifyHvparams(test_what, True, result,
117
                            get_hv_fn=self._GetHypervisor)
118
    self._mock_hv.ValidateParameters.assert_called_with(test_hvparams)
119

    
120
  def testVerifyHypervisors(self):
121
    hvname = constants.HT_XEN_PVM
122
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
123
    all_hvparams = {hvname: hvparams}
124
    test_what = {constants.NV_HYPERVISOR: [hvname]}
125
    result = {}
126
    backend._VerifyHypervisors(
127
        test_what, True, result, all_hvparams=all_hvparams,
128
        get_hv_fn=self._GetHypervisor)
129
    self._mock_hv.Verify.assert_called_with(hvparams=hvparams)
130

    
131

    
132
def _DefRestrictedCmdOwner():
133
  return (os.getuid(), os.getgid())
134

    
135

    
136
class TestVerifyRestrictedCmdName(unittest.TestCase):
137
  def testAcceptableName(self):
138
    for i in ["foo", "bar", "z1", "000first", "hello-world"]:
139
      for fn in [lambda s: s, lambda s: s.upper(), lambda s: s.title()]:
140
        (status, msg) = backend._VerifyRestrictedCmdName(fn(i))
141
        self.assertTrue(status)
142
        self.assertTrue(msg is None)
143

    
144
  def testEmptyAndSpace(self):
145
    for i in ["", " ", "\t", "\n"]:
146
      (status, msg) = backend._VerifyRestrictedCmdName(i)
147
      self.assertFalse(status)
148
      self.assertEqual(msg, "Missing command name")
149

    
150
  def testNameWithSlashes(self):
151
    for i in ["/", "./foo", "../moo", "some/name"]:
152
      (status, msg) = backend._VerifyRestrictedCmdName(i)
153
      self.assertFalse(status)
154
      self.assertEqual(msg, "Invalid command name")
155

    
156
  def testForbiddenCharacters(self):
157
    for i in ["#", ".", "..", "bash -c ls", "'"]:
158
      (status, msg) = backend._VerifyRestrictedCmdName(i)
159
      self.assertFalse(status)
160
      self.assertEqual(msg, "Command name contains forbidden characters")
161

    
162

    
163
class TestVerifyRestrictedCmdDirectory(unittest.TestCase):
164
  def setUp(self):
165
    self.tmpdir = tempfile.mkdtemp()
166

    
167
  def tearDown(self):
168
    shutil.rmtree(self.tmpdir)
169

    
170
  def testCanNotStat(self):
171
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
172
    self.assertFalse(os.path.exists(tmpname))
173
    (status, msg) = \
174
      backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
175
    self.assertFalse(status)
176
    self.assertTrue(msg.startswith("Can't stat(2) '"))
177

    
178
  def testTooPermissive(self):
179
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
180
    os.mkdir(tmpname)
181

    
182
    for mode in [0777, 0706, 0760, 0722]:
183
      os.chmod(tmpname, mode)
184
      self.assertTrue(os.path.isdir(tmpname))
185
      (status, msg) = \
186
        backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
187
      self.assertFalse(status)
188
      self.assertTrue(msg.startswith("Permissions on '"))
189

    
190
  def testNoDirectory(self):
191
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
192
    utils.WriteFile(tmpname, data="empty\n")
193
    self.assertTrue(os.path.isfile(tmpname))
194
    (status, msg) = \
195
      backend._VerifyRestrictedCmdDirectory(tmpname,
196
                                            _owner=_DefRestrictedCmdOwner())
197
    self.assertFalse(status)
198
    self.assertTrue(msg.endswith("is not a directory"))
199

    
200
  def testNormal(self):
201
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
202
    os.mkdir(tmpname)
203
    os.chmod(tmpname, 0755)
204
    self.assertTrue(os.path.isdir(tmpname))
205
    (status, msg) = \
206
      backend._VerifyRestrictedCmdDirectory(tmpname,
207
                                            _owner=_DefRestrictedCmdOwner())
208
    self.assertTrue(status)
209
    self.assertTrue(msg is None)
210

    
211

    
212
class TestVerifyRestrictedCmd(unittest.TestCase):
213
  def setUp(self):
214
    self.tmpdir = tempfile.mkdtemp()
215

    
216
  def tearDown(self):
217
    shutil.rmtree(self.tmpdir)
218

    
219
  def testCanNotStat(self):
220
    tmpname = utils.PathJoin(self.tmpdir, "helloworld")
221
    self.assertFalse(os.path.exists(tmpname))
222
    (status, msg) = \
223
      backend._VerifyRestrictedCmd(self.tmpdir, "helloworld",
224
                                   _owner=NotImplemented)
225
    self.assertFalse(status)
226
    self.assertTrue(msg.startswith("Can't stat(2) '"))
227

    
228
  def testNotExecutable(self):
229
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
230
    utils.WriteFile(tmpname, data="empty\n")
231
    (status, msg) = \
232
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
233
                                   _owner=_DefRestrictedCmdOwner())
234
    self.assertFalse(status)
235
    self.assertTrue(msg.startswith("access(2) thinks '"))
236

    
237
  def testExecutable(self):
238
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
239
    utils.WriteFile(tmpname, data="empty\n", mode=0700)
240
    (status, executable) = \
241
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
242
                                   _owner=_DefRestrictedCmdOwner())
243
    self.assertTrue(status)
244
    self.assertEqual(executable, tmpname)
245

    
246

    
247
class TestPrepareRestrictedCmd(unittest.TestCase):
248
  _TEST_PATH = "/tmp/some/test/path"
249

    
250
  def testDirFails(self):
251
    def fn(path):
252
      self.assertEqual(path, self._TEST_PATH)
253
      return (False, "test error 31420")
254

    
255
    (status, msg) = \
256
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd21152",
257
                                    _verify_dir=fn,
258
                                    _verify_name=NotImplemented,
259
                                    _verify_cmd=NotImplemented)
260
    self.assertFalse(status)
261
    self.assertEqual(msg, "test error 31420")
262

    
263
  def testNameFails(self):
264
    def fn(cmd):
265
      self.assertEqual(cmd, "cmd4617")
266
      return (False, "test error 591")
267

    
268
    (status, msg) = \
269
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd4617",
270
                                    _verify_dir=lambda _: (True, None),
271
                                    _verify_name=fn,
272
                                    _verify_cmd=NotImplemented)
273
    self.assertFalse(status)
274
    self.assertEqual(msg, "test error 591")
275

    
276
  def testCommandFails(self):
277
    def fn(path, cmd):
278
      self.assertEqual(path, self._TEST_PATH)
279
      self.assertEqual(cmd, "cmd17577")
280
      return (False, "test error 25524")
281

    
282
    (status, msg) = \
283
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd17577",
284
                                    _verify_dir=lambda _: (True, None),
285
                                    _verify_name=lambda _: (True, None),
286
                                    _verify_cmd=fn)
287
    self.assertFalse(status)
288
    self.assertEqual(msg, "test error 25524")
289

    
290
  def testSuccess(self):
291
    def fn(path, cmd):
292
      return (True, utils.PathJoin(path, cmd))
293

    
294
    (status, executable) = \
295
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd22633",
296
                                    _verify_dir=lambda _: (True, None),
297
                                    _verify_name=lambda _: (True, None),
298
                                    _verify_cmd=fn)
299
    self.assertTrue(status)
300
    self.assertEqual(executable, utils.PathJoin(self._TEST_PATH, "cmd22633"))
301

    
302

    
303
def _SleepForRestrictedCmd(duration):
304
  assert duration > 5
305

    
306

    
307
def _GenericRestrictedCmdError(cmd):
308
  return "Executing command '%s' failed" % cmd
309

    
310

    
311
class TestRunRestrictedCmd(unittest.TestCase):
312
  def setUp(self):
313
    self.tmpdir = tempfile.mkdtemp()
314

    
315
  def tearDown(self):
316
    shutil.rmtree(self.tmpdir)
317

    
318
  def testNonExistantLockDirectory(self):
319
    lockfile = utils.PathJoin(self.tmpdir, "does", "not", "exist")
320
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
321
    self.assertFalse(os.path.exists(lockfile))
322
    self.assertRaises(backend.RPCFail,
323
                      backend.RunRestrictedCmd, "test",
324
                      _lock_timeout=NotImplemented,
325
                      _lock_file=lockfile,
326
                      _path=NotImplemented,
327
                      _sleep_fn=sleep_fn,
328
                      _prepare_fn=NotImplemented,
329
                      _runcmd_fn=NotImplemented,
330
                      _enabled=True)
331
    self.assertEqual(sleep_fn.Count(), 1)
332

    
333
  @staticmethod
334
  def _TryLock(lockfile):
335
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
336

    
337
    result = False
338
    try:
339
      backend.RunRestrictedCmd("test22717",
340
                               _lock_timeout=0.1,
341
                               _lock_file=lockfile,
342
                               _path=NotImplemented,
343
                               _sleep_fn=sleep_fn,
344
                               _prepare_fn=NotImplemented,
345
                               _runcmd_fn=NotImplemented,
346
                               _enabled=True)
347
    except backend.RPCFail, err:
348
      assert str(err) == _GenericRestrictedCmdError("test22717"), \
349
             "Did not fail with generic error message"
350
      result = True
351

    
352
    assert sleep_fn.Count() == 1
353

    
354
    return result
355

    
356
  def testLockHeldByOtherProcess(self):
357
    lockfile = utils.PathJoin(self.tmpdir, "lock")
358

    
359
    lock = utils.FileLock.Open(lockfile)
360
    lock.Exclusive(blocking=True, timeout=1.0)
361
    try:
362
      self.assertTrue(utils.RunInSeparateProcess(self._TryLock, lockfile))
363
    finally:
364
      lock.Close()
365

    
366
  @staticmethod
367
  def _PrepareRaisingException(path, cmd):
368
    assert cmd == "test23122"
369
    raise Exception("test")
370

    
371
  def testPrepareRaisesException(self):
372
    lockfile = utils.PathJoin(self.tmpdir, "lock")
373

    
374
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
375
    prepare_fn = testutils.CallCounter(self._PrepareRaisingException)
376

    
377
    try:
378
      backend.RunRestrictedCmd("test23122",
379
                               _lock_timeout=1.0, _lock_file=lockfile,
380
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
381
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
382
                               _enabled=True)
383
    except backend.RPCFail, err:
384
      self.assertEqual(str(err), _GenericRestrictedCmdError("test23122"))
385
    else:
386
      self.fail("Didn't fail")
387

    
388
    self.assertEqual(sleep_fn.Count(), 1)
389
    self.assertEqual(prepare_fn.Count(), 1)
390

    
391
  @staticmethod
392
  def _PrepareFails(path, cmd):
393
    assert cmd == "test29327"
394
    return ("some error message", None)
395

    
396
  def testPrepareFails(self):
397
    lockfile = utils.PathJoin(self.tmpdir, "lock")
398

    
399
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
400
    prepare_fn = testutils.CallCounter(self._PrepareFails)
401

    
402
    try:
403
      backend.RunRestrictedCmd("test29327",
404
                               _lock_timeout=1.0, _lock_file=lockfile,
405
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
406
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
407
                               _enabled=True)
408
    except backend.RPCFail, err:
409
      self.assertEqual(str(err), _GenericRestrictedCmdError("test29327"))
410
    else:
411
      self.fail("Didn't fail")
412

    
413
    self.assertEqual(sleep_fn.Count(), 1)
414
    self.assertEqual(prepare_fn.Count(), 1)
415

    
416
  @staticmethod
417
  def _SuccessfulPrepare(path, cmd):
418
    return (True, utils.PathJoin(path, cmd))
419

    
420
  def testRunCmdFails(self):
421
    lockfile = utils.PathJoin(self.tmpdir, "lock")
422

    
423
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
424
           postfork_fn=NotImplemented):
425
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test3079")])
426
      self.assertEqual(env, {})
427
      self.assertTrue(reset_env)
428
      self.assertTrue(callable(postfork_fn))
429

    
430
      trylock = utils.FileLock.Open(lockfile)
431
      try:
432
        # See if lockfile is still held
433
        self.assertRaises(EnvironmentError, trylock.Exclusive, blocking=False)
434

    
435
        # Call back to release lock
436
        postfork_fn(NotImplemented)
437

    
438
        # See if lockfile can be acquired
439
        trylock.Exclusive(blocking=False)
440
      finally:
441
        trylock.Close()
442

    
443
      # Simulate a failed command
444
      return utils.RunResult(constants.EXIT_FAILURE, None,
445
                             "stdout", "stderr406328567",
446
                             utils.ShellQuoteArgs(args),
447
                             NotImplemented, NotImplemented)
448

    
449
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
450
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
451
    runcmd_fn = testutils.CallCounter(fn)
452

    
453
    try:
454
      backend.RunRestrictedCmd("test3079",
455
                               _lock_timeout=1.0, _lock_file=lockfile,
456
                               _path=self.tmpdir, _runcmd_fn=runcmd_fn,
457
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
458
                               _enabled=True)
459
    except backend.RPCFail, err:
460
      self.assertTrue(str(err).startswith("Restricted command 'test3079'"
461
                                          " failed:"))
462
      self.assertTrue("stderr406328567" in str(err),
463
                      msg="Error did not include output")
464
    else:
465
      self.fail("Didn't fail")
466

    
467
    self.assertEqual(sleep_fn.Count(), 0)
468
    self.assertEqual(prepare_fn.Count(), 1)
469
    self.assertEqual(runcmd_fn.Count(), 1)
470

    
471
  def testRunCmdSucceeds(self):
472
    lockfile = utils.PathJoin(self.tmpdir, "lock")
473

    
474
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
475
           postfork_fn=NotImplemented):
476
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test5667")])
477
      self.assertEqual(env, {})
478
      self.assertTrue(reset_env)
479

    
480
      # Call back to release lock
481
      postfork_fn(NotImplemented)
482

    
483
      # Simulate a successful command
484
      return utils.RunResult(constants.EXIT_SUCCESS, None, "stdout14463", "",
485
                             utils.ShellQuoteArgs(args),
486
                             NotImplemented, NotImplemented)
487

    
488
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
489
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
490
    runcmd_fn = testutils.CallCounter(fn)
491

    
492
    result = backend.RunRestrictedCmd("test5667",
493
                                      _lock_timeout=1.0, _lock_file=lockfile,
494
                                      _path=self.tmpdir, _runcmd_fn=runcmd_fn,
495
                                      _sleep_fn=sleep_fn,
496
                                      _prepare_fn=prepare_fn,
497
                                      _enabled=True)
498
    self.assertEqual(result, "stdout14463")
499

    
500
    self.assertEqual(sleep_fn.Count(), 0)
501
    self.assertEqual(prepare_fn.Count(), 1)
502
    self.assertEqual(runcmd_fn.Count(), 1)
503

    
504
  def testCommandsDisabled(self):
505
    try:
506
      backend.RunRestrictedCmd("test",
507
                               _lock_timeout=NotImplemented,
508
                               _lock_file=NotImplemented,
509
                               _path=NotImplemented,
510
                               _sleep_fn=NotImplemented,
511
                               _prepare_fn=NotImplemented,
512
                               _runcmd_fn=NotImplemented,
513
                               _enabled=False)
514
    except backend.RPCFail, err:
515
      self.assertEqual(str(err),
516
                       "Restricted commands disabled at configure time")
517
    else:
518
      self.fail("Did not raise exception")
519

    
520

    
521
class TestSetWatcherPause(unittest.TestCase):
522
  def setUp(self):
523
    self.tmpdir = tempfile.mkdtemp()
524
    self.filename = utils.PathJoin(self.tmpdir, "pause")
525

    
526
  def tearDown(self):
527
    shutil.rmtree(self.tmpdir)
528

    
529
  def testUnsetNonExisting(self):
530
    self.assertFalse(os.path.exists(self.filename))
531
    backend.SetWatcherPause(None, _filename=self.filename)
532
    self.assertFalse(os.path.exists(self.filename))
533

    
534
  def testSetNonNumeric(self):
535
    for i in ["", [], {}, "Hello World", "0", "1.0"]:
536
      self.assertFalse(os.path.exists(self.filename))
537

    
538
      try:
539
        backend.SetWatcherPause(i, _filename=self.filename)
540
      except backend.RPCFail, err:
541
        self.assertEqual(str(err), "Duration must be numeric")
542
      else:
543
        self.fail("Did not raise exception")
544

    
545
      self.assertFalse(os.path.exists(self.filename))
546

    
547
  def testSet(self):
548
    self.assertFalse(os.path.exists(self.filename))
549

    
550
    for i in range(10):
551
      backend.SetWatcherPause(i, _filename=self.filename)
552
      self.assertEqual(utils.ReadFile(self.filename), "%s\n" % i)
553
      self.assertEqual(os.stat(self.filename).st_mode & 0777, 0644)
554

    
555

    
556
class TestGetBlockDevSymlinkPath(unittest.TestCase):
557
  def setUp(self):
558
    self.tmpdir = tempfile.mkdtemp()
559

    
560
  def tearDown(self):
561
    shutil.rmtree(self.tmpdir)
562

    
563
  def _Test(self, name, idx):
564
    self.assertEqual(backend._GetBlockDevSymlinkPath(name, idx,
565
                                                     _dir=self.tmpdir),
566
                     ("%s/%s%s%s" % (self.tmpdir, name,
567
                                     constants.DISK_SEPARATOR, idx)))
568

    
569
  def test(self):
570
    for idx in range(100):
571
      self._Test("inst1.example.com", idx)
572

    
573

    
574
class TestGetInstanceList(unittest.TestCase):
575

    
576
  def setUp(self):
577
    self._test_hv = self._TestHypervisor()
578
    self._test_hv.ListInstances = mock.Mock(
579
      return_value=["instance1", "instance2", "instance3"] )
580

    
581
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
582
    def __init__(self):
583
      hypervisor.hv_base.BaseHypervisor.__init__(self)
584

    
585
  def _GetHypervisor(self, name):
586
    return self._test_hv
587

    
588
  def testHvparams(self):
589
    fake_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
590
    hvparams = {constants.HT_FAKE: fake_hvparams}
591
    backend.GetInstanceList([constants.HT_FAKE], all_hvparams=hvparams,
592
                            get_hv_fn=self._GetHypervisor)
593
    self._test_hv.ListInstances.assert_called_with(hvparams=fake_hvparams)
594

    
595

    
596
class TestInstanceConsoleInfo(unittest.TestCase):
597

    
598
  def setUp(self):
599
    self._test_hv_a = self._TestHypervisor()
600
    self._test_hv_a.GetInstanceConsole = mock.Mock(
601
      return_value = objects.InstanceConsole(instance="inst", kind="aHy")
602
    )
603
    self._test_hv_b = self._TestHypervisor()
604
    self._test_hv_b.GetInstanceConsole = mock.Mock(
605
      return_value = objects.InstanceConsole(instance="inst", kind="bHy")
606
    )
607

    
608
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
609
    def __init__(self):
610
      hypervisor.hv_base.BaseHypervisor.__init__(self)
611

    
612
  def _GetHypervisor(self, name):
613
    if name == "a":
614
      return self._test_hv_a
615
    else:
616
      return self._test_hv_b
617

    
618
  def testRightHypervisor(self):
619
    dictMaker = lambda hyName: {
620
      "instance":{"hypervisor":hyName},
621
      "node":{},
622
      "hvParams":{},
623
      "beParams":{},
624
    }
625

    
626
    call = {
627
      'i1':dictMaker("a"),
628
      'i2':dictMaker("b"),
629
    }
630

    
631
    res = backend.GetInstanceConsoleInfo(call, get_hv_fn=self._GetHypervisor)
632

    
633
    self.assertTrue(res["i1"]["kind"] == "aHy")
634
    self.assertTrue(res["i2"]["kind"] == "bHy")
635

    
636

    
637
class TestGetHvInfo(unittest.TestCase):
638

    
639
  def setUp(self):
640
    self._test_hv = self._TestHypervisor()
641
    self._test_hv.GetNodeInfo = mock.Mock()
642

    
643
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
644
    def __init__(self):
645
      hypervisor.hv_base.BaseHypervisor.__init__(self)
646

    
647
  def _GetHypervisor(self, name):
648
    return self._test_hv
649

    
650
  def testGetHvInfoAllNone(self):
651
    result = backend._GetHvInfoAll(None)
652
    self.assertTrue(result is None)
653

    
654
  def testGetHvInfoAll(self):
655
    hvname = constants.HT_XEN_PVM
656
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
657
    hv_specs = [(hvname, hvparams)]
658

    
659
    backend._GetHvInfoAll(hv_specs, self._GetHypervisor)
660
    self._test_hv.GetNodeInfo.assert_called_with(hvparams=hvparams)
661

    
662

    
663
class TestApplyStorageInfoFunction(unittest.TestCase):
664

    
665
  _STORAGE_KEY = "some_key"
666
  _SOME_ARGS = ["some_args"]
667

    
668
  def setUp(self):
669
    self.mock_storage_fn = mock.Mock()
670

    
671
  def testApplyValidStorageType(self):
672
    storage_type = constants.ST_LVM_VG
673
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
674
    backend._STORAGE_TYPE_INFO_FN = {
675
        storage_type: self.mock_storage_fn
676
      }
677

    
678
    backend._ApplyStorageInfoFunction(
679
        storage_type, self._STORAGE_KEY, self._SOME_ARGS)
680

    
681
    self.mock_storage_fn.assert_called_with(self._STORAGE_KEY, self._SOME_ARGS)
682
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
683

    
684
  def testApplyInValidStorageType(self):
685
    storage_type = "invalid_storage_type"
686
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
687
    backend._STORAGE_TYPE_INFO_FN = {}
688

    
689
    self.assertRaises(KeyError, backend._ApplyStorageInfoFunction,
690
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
691
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
692

    
693
  def testApplyNotImplementedStorageType(self):
694
    storage_type = "not_implemented_storage_type"
695
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
696
    backend._STORAGE_TYPE_INFO_FN = {storage_type: None}
697

    
698
    self.assertRaises(NotImplementedError,
699
                      backend._ApplyStorageInfoFunction,
700
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
701
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
702

    
703

    
704
class TestGetLvmVgSpaceInfo(unittest.TestCase):
705

    
706
  def testValid(self):
707
    path = "somepath"
708
    excl_stor = True
709
    orig_fn = backend._GetVgInfo
710
    backend._GetVgInfo = mock.Mock()
711
    backend._GetLvmVgSpaceInfo(path, [excl_stor])
712
    backend._GetVgInfo.assert_called_with(path, excl_stor)
713
    backend._GetVgInfo = orig_fn
714

    
715
  def testNoExclStorageNotBool(self):
716
    path = "somepath"
717
    excl_stor = "123"
718
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
719
                      path, [excl_stor])
720

    
721
  def testNoExclStorageNotInList(self):
722
    path = "somepath"
723
    excl_stor = "123"
724
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
725
                      path, excl_stor)
726

    
727
class TestGetLvmPvSpaceInfo(unittest.TestCase):
728

    
729
  def testValid(self):
730
    path = "somepath"
731
    excl_stor = True
732
    orig_fn = backend._GetVgSpindlesInfo
733
    backend._GetVgSpindlesInfo = mock.Mock()
734
    backend._GetLvmPvSpaceInfo(path, [excl_stor])
735
    backend._GetVgSpindlesInfo.assert_called_with(path, excl_stor)
736
    backend._GetVgSpindlesInfo = orig_fn
737

    
738

    
739
class TestCheckStorageParams(unittest.TestCase):
740

    
741
  def testParamsNone(self):
742
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
743
                      None, NotImplemented)
744

    
745
  def testParamsWrongType(self):
746
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
747
                      "string", NotImplemented)
748

    
749
  def testParamsEmpty(self):
750
    backend._CheckStorageParams([], 0)
751

    
752
  def testParamsValidNumber(self):
753
    backend._CheckStorageParams(["a", True], 2)
754

    
755
  def testParamsInvalidNumber(self):
756
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
757
                      ["b", False], 3)
758

    
759

    
760
class TestGetVgSpindlesInfo(unittest.TestCase):
761

    
762
  def setUp(self):
763
    self.vg_free = 13
764
    self.vg_size = 31
765
    self.mock_fn = mock.Mock(return_value=(self.vg_free, self.vg_size))
766

    
767
  def testValidInput(self):
768
    name = "myvg"
769
    excl_stor = True
770
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
771
    self.mock_fn.assert_called_with(name)
772
    self.assertEqual(name, result["name"])
773
    self.assertEqual(constants.ST_LVM_PV, result["type"])
774
    self.assertEqual(self.vg_free, result["storage_free"])
775
    self.assertEqual(self.vg_size, result["storage_size"])
776

    
777
  def testNoExclStor(self):
778
    name = "myvg"
779
    excl_stor = False
780
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
781
    self.mock_fn.assert_not_called()
782
    self.assertEqual(name, result["name"])
783
    self.assertEqual(constants.ST_LVM_PV, result["type"])
784
    self.assertEqual(0, result["storage_free"])
785
    self.assertEqual(0, result["storage_size"])
786

    
787

    
788
class TestGetVgSpindlesInfo(unittest.TestCase):
789

    
790
  def testValidInput(self):
791
    self.vg_free = 13
792
    self.vg_size = 31
793
    self.mock_fn = mock.Mock(return_value=[(self.vg_free, self.vg_size)])
794
    name = "myvg"
795
    excl_stor = True
796
    result = backend._GetVgInfo(name, excl_stor, info_fn=self.mock_fn)
797
    self.mock_fn.assert_called_with([name], excl_stor)
798
    self.assertEqual(name, result["name"])
799
    self.assertEqual(constants.ST_LVM_VG, result["type"])
800
    self.assertEqual(self.vg_free, result["storage_free"])
801
    self.assertEqual(self.vg_size, result["storage_size"])
802

    
803
  def testNoExclStor(self):
804
    name = "myvg"
805
    excl_stor = True
806
    self.mock_fn = mock.Mock(return_value=None)
807
    result = backend._GetVgInfo(name, excl_stor, info_fn=self.mock_fn)
808
    self.mock_fn.assert_called_with([name], excl_stor)
809
    self.assertEqual(name, result["name"])
810
    self.assertEqual(constants.ST_LVM_VG, result["type"])
811
    self.assertEqual(None, result["storage_free"])
812
    self.assertEqual(None, result["storage_size"])
813

    
814

    
815
class TestGetNodeInfo(unittest.TestCase):
816

    
817
  _SOME_RESULT = None
818

    
819
  def testApplyStorageInfoFunction(self):
820
    orig_fn = backend._ApplyStorageInfoFunction
821
    backend._ApplyStorageInfoFunction = mock.Mock(
822
        return_value=self._SOME_RESULT)
823
    storage_units = [(st, st + "_key", [st + "_params"]) for st in
824
                     constants.STORAGE_TYPES]
825

    
826
    backend.GetNodeInfo(storage_units, None)
827

    
828
    call_args_list = backend._ApplyStorageInfoFunction.call_args_list
829
    self.assertEqual(len(constants.STORAGE_TYPES), len(call_args_list))
830
    for call in call_args_list:
831
      storage_type, storage_key, storage_params = call[0]
832
      self.assertEqual(storage_type + "_key", storage_key)
833
      self.assertEqual([storage_type + "_params"], storage_params)
834
      self.assertTrue(storage_type in constants.STORAGE_TYPES)
835
    backend._ApplyStorageInfoFunction = orig_fn
836

    
837

    
838
class TestSpaceReportingConstants(unittest.TestCase):
839
  """Ensures consistency between STS_REPORT and backend.
840

841
  These tests ensure, that the constant 'STS_REPORT' is consitent
842
  with the implementation of invoking space reporting functions
843
  in backend.py. Once space reporting is available for all types,
844
  the constant can be removed and these tests as well.
845

846
  """
847
  def testAllReportingTypesHaveAReportingFunction(self):
848
    for storage_type in constants.STS_REPORT:
849
      self.assertTrue(backend._STORAGE_TYPE_INFO_FN[storage_type] is not None)
850

    
851
  def testAllNotReportingTypesDoneHaveFunction(self):
852
    non_reporting_types = set(constants.STORAGE_TYPES)\
853
        - set(constants.STS_REPORT)
854
    for storage_type in non_reporting_types:
855
      self.assertEqual(None, backend._STORAGE_TYPE_INFO_FN[storage_type])
856

    
857

    
858
if __name__ == "__main__":
859
  testutils.GanetiTestProgram()