Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.backend_unittest.py @ c42be2c0

History | View | Annotate | Download (29.2 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.backend"""
23

    
24
import mock
25
import os
26
import shutil
27
import tempfile
28
import testutils
29
import unittest
30

    
31
from ganeti import backend
32
from ganeti import constants
33
from ganeti import errors
34
from ganeti import hypervisor
35
from ganeti import netutils
36
from ganeti import objects
37
from ganeti import utils
38

    
39

    
40
class TestX509Certificates(unittest.TestCase):
41
  def setUp(self):
42
    self.tmpdir = tempfile.mkdtemp()
43

    
44
  def tearDown(self):
45
    shutil.rmtree(self.tmpdir)
46

    
47
  def test(self):
48
    (name, cert_pem) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
49

    
50
    self.assertEqual(utils.ReadFile(os.path.join(self.tmpdir, name,
51
                                                 backend._X509_CERT_FILE)),
52
                     cert_pem)
53
    self.assert_(0 < os.path.getsize(os.path.join(self.tmpdir, name,
54
                                                  backend._X509_KEY_FILE)))
55

    
56
    (name2, cert_pem2) = \
57
      backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
58

    
59
    backend.RemoveX509Certificate(name, cryptodir=self.tmpdir)
60
    backend.RemoveX509Certificate(name2, cryptodir=self.tmpdir)
61

    
62
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [])
63

    
64
  def testNonEmpty(self):
65
    (name, _) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
66

    
67
    utils.WriteFile(utils.PathJoin(self.tmpdir, name, "hello-world"),
68
                    data="Hello World")
69

    
70
    self.assertRaises(backend.RPCFail, backend.RemoveX509Certificate,
71
                      name, cryptodir=self.tmpdir)
72

    
73
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [name])
74

    
75

    
76
class TestNodeVerify(testutils.GanetiTestCase):
77

    
78
  def setUp(self):
79
    testutils.GanetiTestCase.setUp(self)
80
    self._mock_hv = None
81

    
82
  def _GetHypervisor(self, hv_name):
83
    self._mock_hv = hypervisor.GetHypervisor(hv_name)
84
    self._mock_hv.ValidateParameters = mock.Mock()
85
    self._mock_hv.Verify = mock.Mock()
86
    return self._mock_hv
87

    
88
  def testMasterIPLocalhost(self):
89
    # this a real functional test, but requires localhost to be reachable
90
    local_data = (netutils.Hostname.GetSysName(),
91
                  constants.IP4_ADDRESS_LOCALHOST)
92
    result = backend.VerifyNode({constants.NV_MASTERIP: local_data},
93
                                None, {}, {}, {})
94
    self.failUnless(constants.NV_MASTERIP in result,
95
                    "Master IP data not returned")
96
    self.failUnless(result[constants.NV_MASTERIP], "Cannot reach localhost")
97

    
98
  def testMasterIPUnreachable(self):
99
    # Network 192.0.2.0/24 is reserved for test/documentation as per
100
    # RFC 5737
101
    bad_data =  ("master.example.com", "192.0.2.1")
102
    # we just test that whatever TcpPing returns, VerifyNode returns too
103
    netutils.TcpPing = lambda a, b, source=None: False
104
    result = backend.VerifyNode({constants.NV_MASTERIP: bad_data},
105
                                None, {}, {}, {})
106
    self.failUnless(constants.NV_MASTERIP in result,
107
                    "Master IP data not returned")
108
    self.failIf(result[constants.NV_MASTERIP],
109
                "Result from netutils.TcpPing corrupted")
110

    
111
  def testVerifyHvparams(self):
112
    test_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
113
    test_what = {constants.NV_HVPARAMS: \
114
        [("mynode", constants.HT_XEN_PVM, test_hvparams)]}
115
    result = {}
116
    backend._VerifyHvparams(test_what, True, result,
117
                            get_hv_fn=self._GetHypervisor)
118
    self._mock_hv.ValidateParameters.assert_called_with(test_hvparams)
119

    
120
  def testVerifyHypervisors(self):
121
    hvname = constants.HT_XEN_PVM
122
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
123
    all_hvparams = {hvname: hvparams}
124
    test_what = {constants.NV_HYPERVISOR: [hvname]}
125
    result = {}
126
    backend._VerifyHypervisors(
127
        test_what, True, result, all_hvparams=all_hvparams,
128
        get_hv_fn=self._GetHypervisor)
129
    self._mock_hv.Verify.assert_called_with(hvparams=hvparams)
130

    
131

    
132
def _DefRestrictedCmdOwner():
133
  return (os.getuid(), os.getgid())
134

    
135

    
136
class TestVerifyRestrictedCmdName(unittest.TestCase):
137
  def testAcceptableName(self):
138
    for i in ["foo", "bar", "z1", "000first", "hello-world"]:
139
      for fn in [lambda s: s, lambda s: s.upper(), lambda s: s.title()]:
140
        (status, msg) = backend._VerifyRestrictedCmdName(fn(i))
141
        self.assertTrue(status)
142
        self.assertTrue(msg is None)
143

    
144
  def testEmptyAndSpace(self):
145
    for i in ["", " ", "\t", "\n"]:
146
      (status, msg) = backend._VerifyRestrictedCmdName(i)
147
      self.assertFalse(status)
148
      self.assertEqual(msg, "Missing command name")
149

    
150
  def testNameWithSlashes(self):
151
    for i in ["/", "./foo", "../moo", "some/name"]:
152
      (status, msg) = backend._VerifyRestrictedCmdName(i)
153
      self.assertFalse(status)
154
      self.assertEqual(msg, "Invalid command name")
155

    
156
  def testForbiddenCharacters(self):
157
    for i in ["#", ".", "..", "bash -c ls", "'"]:
158
      (status, msg) = backend._VerifyRestrictedCmdName(i)
159
      self.assertFalse(status)
160
      self.assertEqual(msg, "Command name contains forbidden characters")
161

    
162

    
163
class TestVerifyRestrictedCmdDirectory(unittest.TestCase):
164
  def setUp(self):
165
    self.tmpdir = tempfile.mkdtemp()
166

    
167
  def tearDown(self):
168
    shutil.rmtree(self.tmpdir)
169

    
170
  def testCanNotStat(self):
171
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
172
    self.assertFalse(os.path.exists(tmpname))
173
    (status, msg) = \
174
      backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
175
    self.assertFalse(status)
176
    self.assertTrue(msg.startswith("Can't stat(2) '"))
177

    
178
  def testTooPermissive(self):
179
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
180
    os.mkdir(tmpname)
181

    
182
    for mode in [0777, 0706, 0760, 0722]:
183
      os.chmod(tmpname, mode)
184
      self.assertTrue(os.path.isdir(tmpname))
185
      (status, msg) = \
186
        backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
187
      self.assertFalse(status)
188
      self.assertTrue(msg.startswith("Permissions on '"))
189

    
190
  def testNoDirectory(self):
191
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
192
    utils.WriteFile(tmpname, data="empty\n")
193
    self.assertTrue(os.path.isfile(tmpname))
194
    (status, msg) = \
195
      backend._VerifyRestrictedCmdDirectory(tmpname,
196
                                            _owner=_DefRestrictedCmdOwner())
197
    self.assertFalse(status)
198
    self.assertTrue(msg.endswith("is not a directory"))
199

    
200
  def testNormal(self):
201
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
202
    os.mkdir(tmpname)
203
    os.chmod(tmpname, 0755)
204
    self.assertTrue(os.path.isdir(tmpname))
205
    (status, msg) = \
206
      backend._VerifyRestrictedCmdDirectory(tmpname,
207
                                            _owner=_DefRestrictedCmdOwner())
208
    self.assertTrue(status)
209
    self.assertTrue(msg is None)
210

    
211

    
212
class TestVerifyRestrictedCmd(unittest.TestCase):
213
  def setUp(self):
214
    self.tmpdir = tempfile.mkdtemp()
215

    
216
  def tearDown(self):
217
    shutil.rmtree(self.tmpdir)
218

    
219
  def testCanNotStat(self):
220
    tmpname = utils.PathJoin(self.tmpdir, "helloworld")
221
    self.assertFalse(os.path.exists(tmpname))
222
    (status, msg) = \
223
      backend._VerifyRestrictedCmd(self.tmpdir, "helloworld",
224
                                   _owner=NotImplemented)
225
    self.assertFalse(status)
226
    self.assertTrue(msg.startswith("Can't stat(2) '"))
227

    
228
  def testNotExecutable(self):
229
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
230
    utils.WriteFile(tmpname, data="empty\n")
231
    (status, msg) = \
232
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
233
                                   _owner=_DefRestrictedCmdOwner())
234
    self.assertFalse(status)
235
    self.assertTrue(msg.startswith("access(2) thinks '"))
236

    
237
  def testExecutable(self):
238
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
239
    utils.WriteFile(tmpname, data="empty\n", mode=0700)
240
    (status, executable) = \
241
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
242
                                   _owner=_DefRestrictedCmdOwner())
243
    self.assertTrue(status)
244
    self.assertEqual(executable, tmpname)
245

    
246

    
247
class TestPrepareRestrictedCmd(unittest.TestCase):
248
  _TEST_PATH = "/tmp/some/test/path"
249

    
250
  def testDirFails(self):
251
    def fn(path):
252
      self.assertEqual(path, self._TEST_PATH)
253
      return (False, "test error 31420")
254

    
255
    (status, msg) = \
256
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd21152",
257
                                    _verify_dir=fn,
258
                                    _verify_name=NotImplemented,
259
                                    _verify_cmd=NotImplemented)
260
    self.assertFalse(status)
261
    self.assertEqual(msg, "test error 31420")
262

    
263
  def testNameFails(self):
264
    def fn(cmd):
265
      self.assertEqual(cmd, "cmd4617")
266
      return (False, "test error 591")
267

    
268
    (status, msg) = \
269
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd4617",
270
                                    _verify_dir=lambda _: (True, None),
271
                                    _verify_name=fn,
272
                                    _verify_cmd=NotImplemented)
273
    self.assertFalse(status)
274
    self.assertEqual(msg, "test error 591")
275

    
276
  def testCommandFails(self):
277
    def fn(path, cmd):
278
      self.assertEqual(path, self._TEST_PATH)
279
      self.assertEqual(cmd, "cmd17577")
280
      return (False, "test error 25524")
281

    
282
    (status, msg) = \
283
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd17577",
284
                                    _verify_dir=lambda _: (True, None),
285
                                    _verify_name=lambda _: (True, None),
286
                                    _verify_cmd=fn)
287
    self.assertFalse(status)
288
    self.assertEqual(msg, "test error 25524")
289

    
290
  def testSuccess(self):
291
    def fn(path, cmd):
292
      return (True, utils.PathJoin(path, cmd))
293

    
294
    (status, executable) = \
295
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd22633",
296
                                    _verify_dir=lambda _: (True, None),
297
                                    _verify_name=lambda _: (True, None),
298
                                    _verify_cmd=fn)
299
    self.assertTrue(status)
300
    self.assertEqual(executable, utils.PathJoin(self._TEST_PATH, "cmd22633"))
301

    
302

    
303
def _SleepForRestrictedCmd(duration):
304
  assert duration > 5
305

    
306

    
307
def _GenericRestrictedCmdError(cmd):
308
  return "Executing command '%s' failed" % cmd
309

    
310

    
311
class TestRunRestrictedCmd(unittest.TestCase):
312
  def setUp(self):
313
    self.tmpdir = tempfile.mkdtemp()
314

    
315
  def tearDown(self):
316
    shutil.rmtree(self.tmpdir)
317

    
318
  def testNonExistantLockDirectory(self):
319
    lockfile = utils.PathJoin(self.tmpdir, "does", "not", "exist")
320
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
321
    self.assertFalse(os.path.exists(lockfile))
322
    self.assertRaises(backend.RPCFail,
323
                      backend.RunRestrictedCmd, "test",
324
                      _lock_timeout=NotImplemented,
325
                      _lock_file=lockfile,
326
                      _path=NotImplemented,
327
                      _sleep_fn=sleep_fn,
328
                      _prepare_fn=NotImplemented,
329
                      _runcmd_fn=NotImplemented,
330
                      _enabled=True)
331
    self.assertEqual(sleep_fn.Count(), 1)
332

    
333
  @staticmethod
334
  def _TryLock(lockfile):
335
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
336

    
337
    result = False
338
    try:
339
      backend.RunRestrictedCmd("test22717",
340
                               _lock_timeout=0.1,
341
                               _lock_file=lockfile,
342
                               _path=NotImplemented,
343
                               _sleep_fn=sleep_fn,
344
                               _prepare_fn=NotImplemented,
345
                               _runcmd_fn=NotImplemented,
346
                               _enabled=True)
347
    except backend.RPCFail, err:
348
      assert str(err) == _GenericRestrictedCmdError("test22717"), \
349
             "Did not fail with generic error message"
350
      result = True
351

    
352
    assert sleep_fn.Count() == 1
353

    
354
    return result
355

    
356
  def testLockHeldByOtherProcess(self):
357
    lockfile = utils.PathJoin(self.tmpdir, "lock")
358

    
359
    lock = utils.FileLock.Open(lockfile)
360
    lock.Exclusive(blocking=True, timeout=1.0)
361
    try:
362
      self.assertTrue(utils.RunInSeparateProcess(self._TryLock, lockfile))
363
    finally:
364
      lock.Close()
365

    
366
  @staticmethod
367
  def _PrepareRaisingException(path, cmd):
368
    assert cmd == "test23122"
369
    raise Exception("test")
370

    
371
  def testPrepareRaisesException(self):
372
    lockfile = utils.PathJoin(self.tmpdir, "lock")
373

    
374
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
375
    prepare_fn = testutils.CallCounter(self._PrepareRaisingException)
376

    
377
    try:
378
      backend.RunRestrictedCmd("test23122",
379
                               _lock_timeout=1.0, _lock_file=lockfile,
380
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
381
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
382
                               _enabled=True)
383
    except backend.RPCFail, err:
384
      self.assertEqual(str(err), _GenericRestrictedCmdError("test23122"))
385
    else:
386
      self.fail("Didn't fail")
387

    
388
    self.assertEqual(sleep_fn.Count(), 1)
389
    self.assertEqual(prepare_fn.Count(), 1)
390

    
391
  @staticmethod
392
  def _PrepareFails(path, cmd):
393
    assert cmd == "test29327"
394
    return ("some error message", None)
395

    
396
  def testPrepareFails(self):
397
    lockfile = utils.PathJoin(self.tmpdir, "lock")
398

    
399
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
400
    prepare_fn = testutils.CallCounter(self._PrepareFails)
401

    
402
    try:
403
      backend.RunRestrictedCmd("test29327",
404
                               _lock_timeout=1.0, _lock_file=lockfile,
405
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
406
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
407
                               _enabled=True)
408
    except backend.RPCFail, err:
409
      self.assertEqual(str(err), _GenericRestrictedCmdError("test29327"))
410
    else:
411
      self.fail("Didn't fail")
412

    
413
    self.assertEqual(sleep_fn.Count(), 1)
414
    self.assertEqual(prepare_fn.Count(), 1)
415

    
416
  @staticmethod
417
  def _SuccessfulPrepare(path, cmd):
418
    return (True, utils.PathJoin(path, cmd))
419

    
420
  def testRunCmdFails(self):
421
    lockfile = utils.PathJoin(self.tmpdir, "lock")
422

    
423
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
424
           postfork_fn=NotImplemented):
425
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test3079")])
426
      self.assertEqual(env, {})
427
      self.assertTrue(reset_env)
428
      self.assertTrue(callable(postfork_fn))
429

    
430
      trylock = utils.FileLock.Open(lockfile)
431
      try:
432
        # See if lockfile is still held
433
        self.assertRaises(EnvironmentError, trylock.Exclusive, blocking=False)
434

    
435
        # Call back to release lock
436
        postfork_fn(NotImplemented)
437

    
438
        # See if lockfile can be acquired
439
        trylock.Exclusive(blocking=False)
440
      finally:
441
        trylock.Close()
442

    
443
      # Simulate a failed command
444
      return utils.RunResult(constants.EXIT_FAILURE, None,
445
                             "stdout", "stderr406328567",
446
                             utils.ShellQuoteArgs(args),
447
                             NotImplemented, NotImplemented)
448

    
449
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
450
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
451
    runcmd_fn = testutils.CallCounter(fn)
452

    
453
    try:
454
      backend.RunRestrictedCmd("test3079",
455
                               _lock_timeout=1.0, _lock_file=lockfile,
456
                               _path=self.tmpdir, _runcmd_fn=runcmd_fn,
457
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
458
                               _enabled=True)
459
    except backend.RPCFail, err:
460
      self.assertTrue(str(err).startswith("Restricted command 'test3079'"
461
                                          " failed:"))
462
      self.assertTrue("stderr406328567" in str(err),
463
                      msg="Error did not include output")
464
    else:
465
      self.fail("Didn't fail")
466

    
467
    self.assertEqual(sleep_fn.Count(), 0)
468
    self.assertEqual(prepare_fn.Count(), 1)
469
    self.assertEqual(runcmd_fn.Count(), 1)
470

    
471
  def testRunCmdSucceeds(self):
472
    lockfile = utils.PathJoin(self.tmpdir, "lock")
473

    
474
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
475
           postfork_fn=NotImplemented):
476
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test5667")])
477
      self.assertEqual(env, {})
478
      self.assertTrue(reset_env)
479

    
480
      # Call back to release lock
481
      postfork_fn(NotImplemented)
482

    
483
      # Simulate a successful command
484
      return utils.RunResult(constants.EXIT_SUCCESS, None, "stdout14463", "",
485
                             utils.ShellQuoteArgs(args),
486
                             NotImplemented, NotImplemented)
487

    
488
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
489
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
490
    runcmd_fn = testutils.CallCounter(fn)
491

    
492
    result = backend.RunRestrictedCmd("test5667",
493
                                      _lock_timeout=1.0, _lock_file=lockfile,
494
                                      _path=self.tmpdir, _runcmd_fn=runcmd_fn,
495
                                      _sleep_fn=sleep_fn,
496
                                      _prepare_fn=prepare_fn,
497
                                      _enabled=True)
498
    self.assertEqual(result, "stdout14463")
499

    
500
    self.assertEqual(sleep_fn.Count(), 0)
501
    self.assertEqual(prepare_fn.Count(), 1)
502
    self.assertEqual(runcmd_fn.Count(), 1)
503

    
504
  def testCommandsDisabled(self):
505
    try:
506
      backend.RunRestrictedCmd("test",
507
                               _lock_timeout=NotImplemented,
508
                               _lock_file=NotImplemented,
509
                               _path=NotImplemented,
510
                               _sleep_fn=NotImplemented,
511
                               _prepare_fn=NotImplemented,
512
                               _runcmd_fn=NotImplemented,
513
                               _enabled=False)
514
    except backend.RPCFail, err:
515
      self.assertEqual(str(err),
516
                       "Restricted commands disabled at configure time")
517
    else:
518
      self.fail("Did not raise exception")
519

    
520

    
521
class TestSetWatcherPause(unittest.TestCase):
522
  def setUp(self):
523
    self.tmpdir = tempfile.mkdtemp()
524
    self.filename = utils.PathJoin(self.tmpdir, "pause")
525

    
526
  def tearDown(self):
527
    shutil.rmtree(self.tmpdir)
528

    
529
  def testUnsetNonExisting(self):
530
    self.assertFalse(os.path.exists(self.filename))
531
    backend.SetWatcherPause(None, _filename=self.filename)
532
    self.assertFalse(os.path.exists(self.filename))
533

    
534
  def testSetNonNumeric(self):
535
    for i in ["", [], {}, "Hello World", "0", "1.0"]:
536
      self.assertFalse(os.path.exists(self.filename))
537

    
538
      try:
539
        backend.SetWatcherPause(i, _filename=self.filename)
540
      except backend.RPCFail, err:
541
        self.assertEqual(str(err), "Duration must be numeric")
542
      else:
543
        self.fail("Did not raise exception")
544

    
545
      self.assertFalse(os.path.exists(self.filename))
546

    
547
  def testSet(self):
548
    self.assertFalse(os.path.exists(self.filename))
549

    
550
    for i in range(10):
551
      backend.SetWatcherPause(i, _filename=self.filename)
552
      self.assertEqual(utils.ReadFile(self.filename), "%s\n" % i)
553
      self.assertEqual(os.stat(self.filename).st_mode & 0777, 0644)
554

    
555

    
556
class TestGetBlockDevSymlinkPath(unittest.TestCase):
557
  def setUp(self):
558
    self.tmpdir = tempfile.mkdtemp()
559

    
560
  def tearDown(self):
561
    shutil.rmtree(self.tmpdir)
562

    
563
  def _Test(self, name, idx):
564
    self.assertEqual(backend._GetBlockDevSymlinkPath(name, idx,
565
                                                     _dir=self.tmpdir),
566
                     ("%s/%s%s%s" % (self.tmpdir, name,
567
                                     constants.DISK_SEPARATOR, idx)))
568

    
569
  def test(self):
570
    for idx in range(100):
571
      self._Test("inst1.example.com", idx)
572

    
573

    
574
class TestGetInstanceList(unittest.TestCase):
575

    
576
  def setUp(self):
577
    self._test_hv = self._TestHypervisor()
578
    self._test_hv.ListInstances = mock.Mock(
579
      return_value=["instance1", "instance2", "instance3"] )
580

    
581
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
582
    def __init__(self):
583
      hypervisor.hv_base.BaseHypervisor.__init__(self)
584

    
585
  def _GetHypervisor(self, name):
586
    return self._test_hv
587

    
588
  def testHvparams(self):
589
    fake_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
590
    hvparams = {constants.HT_FAKE: fake_hvparams}
591
    backend.GetInstanceList([constants.HT_FAKE], all_hvparams=hvparams,
592
                            get_hv_fn=self._GetHypervisor)
593
    self._test_hv.ListInstances.assert_called_with(hvparams=fake_hvparams)
594

    
595

    
596
class TestInstanceConsoleInfo(unittest.TestCase):
597

    
598
  def setUp(self):
599
    self._test_hv_a = self._TestHypervisor()
600
    self._test_hv_a.GetInstanceConsole = mock.Mock(
601
      return_value = objects.InstanceConsole(instance="inst", kind="aHy")
602
    )
603
    self._test_hv_b = self._TestHypervisor()
604
    self._test_hv_b.GetInstanceConsole = mock.Mock(
605
      return_value = objects.InstanceConsole(instance="inst", kind="bHy")
606
    )
607

    
608
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
609
    def __init__(self):
610
      hypervisor.hv_base.BaseHypervisor.__init__(self)
611

    
612
  def _GetHypervisor(self, name):
613
    if name == "a":
614
      return self._test_hv_a
615
    else:
616
      return self._test_hv_b
617

    
618
  def testRightHypervisor(self):
619
    dictMaker = lambda hyName: {
620
      "instance":{"hypervisor":hyName},
621
      "node":{},
622
      "group":{},
623
      "hvParams":{},
624
      "beParams":{},
625
    }
626

    
627
    call = {
628
      'i1':dictMaker("a"),
629
      'i2':dictMaker("b"),
630
    }
631

    
632
    res = backend.GetInstanceConsoleInfo(call, get_hv_fn=self._GetHypervisor)
633

    
634
    self.assertTrue(res["i1"]["kind"] == "aHy")
635
    self.assertTrue(res["i2"]["kind"] == "bHy")
636

    
637

    
638
class TestGetHvInfo(unittest.TestCase):
639

    
640
  def setUp(self):
641
    self._test_hv = self._TestHypervisor()
642
    self._test_hv.GetNodeInfo = mock.Mock()
643

    
644
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
645
    def __init__(self):
646
      hypervisor.hv_base.BaseHypervisor.__init__(self)
647

    
648
  def _GetHypervisor(self, name):
649
    return self._test_hv
650

    
651
  def testGetHvInfoAllNone(self):
652
    result = backend._GetHvInfoAll(None)
653
    self.assertTrue(result is None)
654

    
655
  def testGetHvInfoAll(self):
656
    hvname = constants.HT_XEN_PVM
657
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
658
    hv_specs = [(hvname, hvparams)]
659

    
660
    backend._GetHvInfoAll(hv_specs, self._GetHypervisor)
661
    self._test_hv.GetNodeInfo.assert_called_with(hvparams=hvparams)
662

    
663

    
664
class TestApplyStorageInfoFunction(unittest.TestCase):
665

    
666
  _STORAGE_KEY = "some_key"
667
  _SOME_ARGS = ["some_args"]
668

    
669
  def setUp(self):
670
    self.mock_storage_fn = mock.Mock()
671

    
672
  def testApplyValidStorageType(self):
673
    storage_type = constants.ST_LVM_VG
674
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
675
    backend._STORAGE_TYPE_INFO_FN = {
676
        storage_type: self.mock_storage_fn
677
      }
678

    
679
    backend._ApplyStorageInfoFunction(
680
        storage_type, self._STORAGE_KEY, self._SOME_ARGS)
681

    
682
    self.mock_storage_fn.assert_called_with(self._STORAGE_KEY, self._SOME_ARGS)
683
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
684

    
685
  def testApplyInValidStorageType(self):
686
    storage_type = "invalid_storage_type"
687
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
688
    backend._STORAGE_TYPE_INFO_FN = {}
689

    
690
    self.assertRaises(KeyError, backend._ApplyStorageInfoFunction,
691
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
692
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
693

    
694
  def testApplyNotImplementedStorageType(self):
695
    storage_type = "not_implemented_storage_type"
696
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
697
    backend._STORAGE_TYPE_INFO_FN = {storage_type: None}
698

    
699
    self.assertRaises(NotImplementedError,
700
                      backend._ApplyStorageInfoFunction,
701
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
702
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
703

    
704

    
705
class TestGetLvmVgSpaceInfo(unittest.TestCase):
706

    
707
  def testValid(self):
708
    path = "somepath"
709
    excl_stor = True
710
    orig_fn = backend._GetVgInfo
711
    backend._GetVgInfo = mock.Mock()
712
    backend._GetLvmVgSpaceInfo(path, [excl_stor])
713
    backend._GetVgInfo.assert_called_with(path, excl_stor)
714
    backend._GetVgInfo = orig_fn
715

    
716
  def testNoExclStorageNotBool(self):
717
    path = "somepath"
718
    excl_stor = "123"
719
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
720
                      path, [excl_stor])
721

    
722
  def testNoExclStorageNotInList(self):
723
    path = "somepath"
724
    excl_stor = "123"
725
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
726
                      path, excl_stor)
727

    
728
class TestGetLvmPvSpaceInfo(unittest.TestCase):
729

    
730
  def testValid(self):
731
    path = "somepath"
732
    excl_stor = True
733
    orig_fn = backend._GetVgSpindlesInfo
734
    backend._GetVgSpindlesInfo = mock.Mock()
735
    backend._GetLvmPvSpaceInfo(path, [excl_stor])
736
    backend._GetVgSpindlesInfo.assert_called_with(path, excl_stor)
737
    backend._GetVgSpindlesInfo = orig_fn
738

    
739

    
740
class TestCheckStorageParams(unittest.TestCase):
741

    
742
  def testParamsNone(self):
743
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
744
                      None, NotImplemented)
745

    
746
  def testParamsWrongType(self):
747
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
748
                      "string", NotImplemented)
749

    
750
  def testParamsEmpty(self):
751
    backend._CheckStorageParams([], 0)
752

    
753
  def testParamsValidNumber(self):
754
    backend._CheckStorageParams(["a", True], 2)
755

    
756
  def testParamsInvalidNumber(self):
757
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
758
                      ["b", False], 3)
759

    
760

    
761
class TestGetVgSpindlesInfo(unittest.TestCase):
762

    
763
  def setUp(self):
764
    self.vg_free = 13
765
    self.vg_size = 31
766
    self.mock_fn = mock.Mock(return_value=(self.vg_free, self.vg_size))
767

    
768
  def testValidInput(self):
769
    name = "myvg"
770
    excl_stor = True
771
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
772
    self.mock_fn.assert_called_with(name)
773
    self.assertEqual(name, result["name"])
774
    self.assertEqual(constants.ST_LVM_PV, result["type"])
775
    self.assertEqual(self.vg_free, result["storage_free"])
776
    self.assertEqual(self.vg_size, result["storage_size"])
777

    
778
  def testNoExclStor(self):
779
    name = "myvg"
780
    excl_stor = False
781
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
782
    self.mock_fn.assert_not_called()
783
    self.assertEqual(name, result["name"])
784
    self.assertEqual(constants.ST_LVM_PV, result["type"])
785
    self.assertEqual(0, result["storage_free"])
786
    self.assertEqual(0, result["storage_size"])
787

    
788

    
789
class TestGetVgSpindlesInfo(unittest.TestCase):
790

    
791
  def testValidInput(self):
792
    self.vg_free = 13
793
    self.vg_size = 31
794
    self.mock_fn = mock.Mock(return_value=[(self.vg_free, self.vg_size)])
795
    name = "myvg"
796
    excl_stor = True
797
    result = backend._GetVgInfo(name, excl_stor, info_fn=self.mock_fn)
798
    self.mock_fn.assert_called_with([name], excl_stor)
799
    self.assertEqual(name, result["name"])
800
    self.assertEqual(constants.ST_LVM_VG, result["type"])
801
    self.assertEqual(self.vg_free, result["storage_free"])
802
    self.assertEqual(self.vg_size, result["storage_size"])
803

    
804
  def testNoExclStor(self):
805
    name = "myvg"
806
    excl_stor = True
807
    self.mock_fn = mock.Mock(return_value=None)
808
    result = backend._GetVgInfo(name, excl_stor, info_fn=self.mock_fn)
809
    self.mock_fn.assert_called_with([name], excl_stor)
810
    self.assertEqual(name, result["name"])
811
    self.assertEqual(constants.ST_LVM_VG, result["type"])
812
    self.assertEqual(None, result["storage_free"])
813
    self.assertEqual(None, result["storage_size"])
814

    
815

    
816
class TestGetNodeInfo(unittest.TestCase):
817

    
818
  _SOME_RESULT = None
819

    
820
  def testApplyStorageInfoFunction(self):
821
    orig_fn = backend._ApplyStorageInfoFunction
822
    backend._ApplyStorageInfoFunction = mock.Mock(
823
        return_value=self._SOME_RESULT)
824
    storage_units = [(st, st + "_key", [st + "_params"]) for st in
825
                     constants.STORAGE_TYPES]
826

    
827
    backend.GetNodeInfo(storage_units, None)
828

    
829
    call_args_list = backend._ApplyStorageInfoFunction.call_args_list
830
    self.assertEqual(len(constants.STORAGE_TYPES), len(call_args_list))
831
    for call in call_args_list:
832
      storage_type, storage_key, storage_params = call[0]
833
      self.assertEqual(storage_type + "_key", storage_key)
834
      self.assertEqual([storage_type + "_params"], storage_params)
835
      self.assertTrue(storage_type in constants.STORAGE_TYPES)
836
    backend._ApplyStorageInfoFunction = orig_fn
837

    
838

    
839
class TestSpaceReportingConstants(unittest.TestCase):
840
  """Ensures consistency between STS_REPORT and backend.
841

842
  These tests ensure, that the constant 'STS_REPORT' is consitent
843
  with the implementation of invoking space reporting functions
844
  in backend.py. Once space reporting is available for all types,
845
  the constant can be removed and these tests as well.
846

847
  """
848
  def testAllReportingTypesHaveAReportingFunction(self):
849
    for storage_type in constants.STS_REPORT:
850
      self.assertTrue(backend._STORAGE_TYPE_INFO_FN[storage_type] is not None)
851

    
852
  def testAllNotReportingTypesDoneHaveFunction(self):
853
    non_reporting_types = set(constants.STORAGE_TYPES)\
854
        - set(constants.STS_REPORT)
855
    for storage_type in non_reporting_types:
856
      self.assertEqual(None, backend._STORAGE_TYPE_INFO_FN[storage_type])
857

    
858

    
859
if __name__ == "__main__":
860
  testutils.GanetiTestProgram()