Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.backend_unittest.py @ b3cc1646

History | View | Annotate | Download (31.1 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.backend"""
23

    
24
import mock
25
import os
26
import shutil
27
import tempfile
28
import testutils
29
import unittest
30

    
31
from ganeti import backend
32
from ganeti import constants
33
from ganeti import errors
34
from ganeti import hypervisor
35
from ganeti import netutils
36
from ganeti import objects
37
from ganeti import pathutils
38
from ganeti import utils
39

    
40

    
41
class TestX509Certificates(unittest.TestCase):
42
  def setUp(self):
43
    self.tmpdir = tempfile.mkdtemp()
44

    
45
  def tearDown(self):
46
    shutil.rmtree(self.tmpdir)
47

    
48
  def test(self):
49
    (name, cert_pem) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
50

    
51
    self.assertEqual(utils.ReadFile(os.path.join(self.tmpdir, name,
52
                                                 backend._X509_CERT_FILE)),
53
                     cert_pem)
54
    self.assert_(0 < os.path.getsize(os.path.join(self.tmpdir, name,
55
                                                  backend._X509_KEY_FILE)))
56

    
57
    (name2, cert_pem2) = \
58
      backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
59

    
60
    backend.RemoveX509Certificate(name, cryptodir=self.tmpdir)
61
    backend.RemoveX509Certificate(name2, cryptodir=self.tmpdir)
62

    
63
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [])
64

    
65
  def testNonEmpty(self):
66
    (name, _) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
67

    
68
    utils.WriteFile(utils.PathJoin(self.tmpdir, name, "hello-world"),
69
                    data="Hello World")
70

    
71
    self.assertRaises(backend.RPCFail, backend.RemoveX509Certificate,
72
                      name, cryptodir=self.tmpdir)
73

    
74
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [name])
75

    
76

    
77
class TestGetCryptoTokens(testutils.GanetiTestCase):
78

    
79
  def setUp(self):
80
    self._get_digest_fn_orig = utils.GetCertificateDigest
81
    self._create_digest_fn_orig = utils.GenerateNewSslCert
82
    self._ssl_digest = "12345"
83
    utils.GetCertificateDigest = mock.Mock(
84
      return_value=self._ssl_digest)
85
    utils.GenerateNewSslCert = mock.Mock()
86

    
87
  def tearDown(self):
88
    utils.GetCertificateDigest = self._get_digest_fn_orig
89
    utils.GenerateNewSslCert = self._create_digest_fn_orig
90

    
91
  def testGetSslToken(self):
92
    result = backend.GetCryptoTokens(
93
      [(constants.CRYPTO_TYPE_SSL_DIGEST, constants.CRYPTO_ACTION_GET, None)])
94
    self.assertTrue((constants.CRYPTO_TYPE_SSL_DIGEST, self._ssl_digest)
95
                    in result)
96

    
97
  def testCreateSslToken(self):
98
    result = backend.GetCryptoTokens(
99
      [(constants.CRYPTO_TYPE_SSL_DIGEST, constants.CRYPTO_ACTION_CREATE,
100
        None)])
101
    self.assertTrue((constants.CRYPTO_TYPE_SSL_DIGEST, self._ssl_digest)
102
                    in result)
103
    self.assertTrue(utils.GenerateNewSslCert.assert_calls().once())
104

    
105
  def testCreateSslTokenDifferentFilename(self):
106
    result = backend.GetCryptoTokens(
107
      [(constants.CRYPTO_TYPE_SSL_DIGEST, constants.CRYPTO_ACTION_CREATE,
108
        {constants.CRYPTO_OPTION_CERT_FILE:
109
          pathutils.NODED_CLIENT_CERT_FILE_TMP})])
110
    self.assertTrue((constants.CRYPTO_TYPE_SSL_DIGEST, self._ssl_digest)
111
                    in result)
112
    self.assertTrue(utils.GenerateNewSslCert.assert_calls().once())
113

    
114
  def testUnknownTokenType(self):
115
    self.assertRaises(errors.ProgrammerError,
116
                      backend.GetCryptoTokens,
117
                      [("pink_bunny", constants.CRYPTO_ACTION_GET, None)])
118

    
119
  def testUnknownAction(self):
120
    self.assertRaises(errors.ProgrammerError,
121
                      backend.GetCryptoTokens,
122
                      [(constants.CRYPTO_TYPE_SSL_DIGEST, "illuminate", None)])
123

    
124

    
125
class TestNodeVerify(testutils.GanetiTestCase):
126

    
127
  def setUp(self):
128
    testutils.GanetiTestCase.setUp(self)
129
    self._mock_hv = None
130

    
131
  def _GetHypervisor(self, hv_name):
132
    self._mock_hv = hypervisor.GetHypervisor(hv_name)
133
    self._mock_hv.ValidateParameters = mock.Mock()
134
    self._mock_hv.Verify = mock.Mock()
135
    return self._mock_hv
136

    
137
  def testMasterIPLocalhost(self):
138
    # this a real functional test, but requires localhost to be reachable
139
    local_data = (netutils.Hostname.GetSysName(),
140
                  constants.IP4_ADDRESS_LOCALHOST)
141
    result = backend.VerifyNode({constants.NV_MASTERIP: local_data},
142
                                None, {}, {}, {})
143
    self.failUnless(constants.NV_MASTERIP in result,
144
                    "Master IP data not returned")
145
    self.failUnless(result[constants.NV_MASTERIP], "Cannot reach localhost")
146

    
147
  def testMasterIPUnreachable(self):
148
    # Network 192.0.2.0/24 is reserved for test/documentation as per
149
    # RFC 5737
150
    bad_data =  ("master.example.com", "192.0.2.1")
151
    # we just test that whatever TcpPing returns, VerifyNode returns too
152
    netutils.TcpPing = lambda a, b, source=None: False
153
    result = backend.VerifyNode({constants.NV_MASTERIP: bad_data},
154
                                None, {}, {}, {})
155
    self.failUnless(constants.NV_MASTERIP in result,
156
                    "Master IP data not returned")
157
    self.failIf(result[constants.NV_MASTERIP],
158
                "Result from netutils.TcpPing corrupted")
159

    
160
  def testVerifyHvparams(self):
161
    test_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
162
    test_what = {constants.NV_HVPARAMS: \
163
        [("mynode", constants.HT_XEN_PVM, test_hvparams)]}
164
    result = {}
165
    backend._VerifyHvparams(test_what, True, result,
166
                            get_hv_fn=self._GetHypervisor)
167
    self._mock_hv.ValidateParameters.assert_called_with(test_hvparams)
168

    
169
  def testVerifyHypervisors(self):
170
    hvname = constants.HT_XEN_PVM
171
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
172
    all_hvparams = {hvname: hvparams}
173
    test_what = {constants.NV_HYPERVISOR: [hvname]}
174
    result = {}
175
    backend._VerifyHypervisors(
176
        test_what, True, result, all_hvparams=all_hvparams,
177
        get_hv_fn=self._GetHypervisor)
178
    self._mock_hv.Verify.assert_called_with(hvparams=hvparams)
179

    
180

    
181
def _DefRestrictedCmdOwner():
182
  return (os.getuid(), os.getgid())
183

    
184

    
185
class TestVerifyRestrictedCmdName(unittest.TestCase):
186
  def testAcceptableName(self):
187
    for i in ["foo", "bar", "z1", "000first", "hello-world"]:
188
      for fn in [lambda s: s, lambda s: s.upper(), lambda s: s.title()]:
189
        (status, msg) = backend._VerifyRestrictedCmdName(fn(i))
190
        self.assertTrue(status)
191
        self.assertTrue(msg is None)
192

    
193
  def testEmptyAndSpace(self):
194
    for i in ["", " ", "\t", "\n"]:
195
      (status, msg) = backend._VerifyRestrictedCmdName(i)
196
      self.assertFalse(status)
197
      self.assertEqual(msg, "Missing command name")
198

    
199
  def testNameWithSlashes(self):
200
    for i in ["/", "./foo", "../moo", "some/name"]:
201
      (status, msg) = backend._VerifyRestrictedCmdName(i)
202
      self.assertFalse(status)
203
      self.assertEqual(msg, "Invalid command name")
204

    
205
  def testForbiddenCharacters(self):
206
    for i in ["#", ".", "..", "bash -c ls", "'"]:
207
      (status, msg) = backend._VerifyRestrictedCmdName(i)
208
      self.assertFalse(status)
209
      self.assertEqual(msg, "Command name contains forbidden characters")
210

    
211

    
212
class TestVerifyRestrictedCmdDirectory(unittest.TestCase):
213
  def setUp(self):
214
    self.tmpdir = tempfile.mkdtemp()
215

    
216
  def tearDown(self):
217
    shutil.rmtree(self.tmpdir)
218

    
219
  def testCanNotStat(self):
220
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
221
    self.assertFalse(os.path.exists(tmpname))
222
    (status, msg) = \
223
      backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
224
    self.assertFalse(status)
225
    self.assertTrue(msg.startswith("Can't stat(2) '"))
226

    
227
  def testTooPermissive(self):
228
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
229
    os.mkdir(tmpname)
230

    
231
    for mode in [0777, 0706, 0760, 0722]:
232
      os.chmod(tmpname, mode)
233
      self.assertTrue(os.path.isdir(tmpname))
234
      (status, msg) = \
235
        backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
236
      self.assertFalse(status)
237
      self.assertTrue(msg.startswith("Permissions on '"))
238

    
239
  def testNoDirectory(self):
240
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
241
    utils.WriteFile(tmpname, data="empty\n")
242
    self.assertTrue(os.path.isfile(tmpname))
243
    (status, msg) = \
244
      backend._VerifyRestrictedCmdDirectory(tmpname,
245
                                            _owner=_DefRestrictedCmdOwner())
246
    self.assertFalse(status)
247
    self.assertTrue(msg.endswith("is not a directory"))
248

    
249
  def testNormal(self):
250
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
251
    os.mkdir(tmpname)
252
    os.chmod(tmpname, 0755)
253
    self.assertTrue(os.path.isdir(tmpname))
254
    (status, msg) = \
255
      backend._VerifyRestrictedCmdDirectory(tmpname,
256
                                            _owner=_DefRestrictedCmdOwner())
257
    self.assertTrue(status)
258
    self.assertTrue(msg is None)
259

    
260

    
261
class TestVerifyRestrictedCmd(unittest.TestCase):
262
  def setUp(self):
263
    self.tmpdir = tempfile.mkdtemp()
264

    
265
  def tearDown(self):
266
    shutil.rmtree(self.tmpdir)
267

    
268
  def testCanNotStat(self):
269
    tmpname = utils.PathJoin(self.tmpdir, "helloworld")
270
    self.assertFalse(os.path.exists(tmpname))
271
    (status, msg) = \
272
      backend._VerifyRestrictedCmd(self.tmpdir, "helloworld",
273
                                   _owner=NotImplemented)
274
    self.assertFalse(status)
275
    self.assertTrue(msg.startswith("Can't stat(2) '"))
276

    
277
  def testNotExecutable(self):
278
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
279
    utils.WriteFile(tmpname, data="empty\n")
280
    (status, msg) = \
281
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
282
                                   _owner=_DefRestrictedCmdOwner())
283
    self.assertFalse(status)
284
    self.assertTrue(msg.startswith("access(2) thinks '"))
285

    
286
  def testExecutable(self):
287
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
288
    utils.WriteFile(tmpname, data="empty\n", mode=0700)
289
    (status, executable) = \
290
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
291
                                   _owner=_DefRestrictedCmdOwner())
292
    self.assertTrue(status)
293
    self.assertEqual(executable, tmpname)
294

    
295

    
296
class TestPrepareRestrictedCmd(unittest.TestCase):
297
  _TEST_PATH = "/tmp/some/test/path"
298

    
299
  def testDirFails(self):
300
    def fn(path):
301
      self.assertEqual(path, self._TEST_PATH)
302
      return (False, "test error 31420")
303

    
304
    (status, msg) = \
305
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd21152",
306
                                    _verify_dir=fn,
307
                                    _verify_name=NotImplemented,
308
                                    _verify_cmd=NotImplemented)
309
    self.assertFalse(status)
310
    self.assertEqual(msg, "test error 31420")
311

    
312
  def testNameFails(self):
313
    def fn(cmd):
314
      self.assertEqual(cmd, "cmd4617")
315
      return (False, "test error 591")
316

    
317
    (status, msg) = \
318
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd4617",
319
                                    _verify_dir=lambda _: (True, None),
320
                                    _verify_name=fn,
321
                                    _verify_cmd=NotImplemented)
322
    self.assertFalse(status)
323
    self.assertEqual(msg, "test error 591")
324

    
325
  def testCommandFails(self):
326
    def fn(path, cmd):
327
      self.assertEqual(path, self._TEST_PATH)
328
      self.assertEqual(cmd, "cmd17577")
329
      return (False, "test error 25524")
330

    
331
    (status, msg) = \
332
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd17577",
333
                                    _verify_dir=lambda _: (True, None),
334
                                    _verify_name=lambda _: (True, None),
335
                                    _verify_cmd=fn)
336
    self.assertFalse(status)
337
    self.assertEqual(msg, "test error 25524")
338

    
339
  def testSuccess(self):
340
    def fn(path, cmd):
341
      return (True, utils.PathJoin(path, cmd))
342

    
343
    (status, executable) = \
344
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd22633",
345
                                    _verify_dir=lambda _: (True, None),
346
                                    _verify_name=lambda _: (True, None),
347
                                    _verify_cmd=fn)
348
    self.assertTrue(status)
349
    self.assertEqual(executable, utils.PathJoin(self._TEST_PATH, "cmd22633"))
350

    
351

    
352
def _SleepForRestrictedCmd(duration):
353
  assert duration > 5
354

    
355

    
356
def _GenericRestrictedCmdError(cmd):
357
  return "Executing command '%s' failed" % cmd
358

    
359

    
360
class TestRunRestrictedCmd(unittest.TestCase):
361
  def setUp(self):
362
    self.tmpdir = tempfile.mkdtemp()
363

    
364
  def tearDown(self):
365
    shutil.rmtree(self.tmpdir)
366

    
367
  def testNonExistantLockDirectory(self):
368
    lockfile = utils.PathJoin(self.tmpdir, "does", "not", "exist")
369
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
370
    self.assertFalse(os.path.exists(lockfile))
371
    self.assertRaises(backend.RPCFail,
372
                      backend.RunRestrictedCmd, "test",
373
                      _lock_timeout=NotImplemented,
374
                      _lock_file=lockfile,
375
                      _path=NotImplemented,
376
                      _sleep_fn=sleep_fn,
377
                      _prepare_fn=NotImplemented,
378
                      _runcmd_fn=NotImplemented,
379
                      _enabled=True)
380
    self.assertEqual(sleep_fn.Count(), 1)
381

    
382
  @staticmethod
383
  def _TryLock(lockfile):
384
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
385

    
386
    result = False
387
    try:
388
      backend.RunRestrictedCmd("test22717",
389
                               _lock_timeout=0.1,
390
                               _lock_file=lockfile,
391
                               _path=NotImplemented,
392
                               _sleep_fn=sleep_fn,
393
                               _prepare_fn=NotImplemented,
394
                               _runcmd_fn=NotImplemented,
395
                               _enabled=True)
396
    except backend.RPCFail, err:
397
      assert str(err) == _GenericRestrictedCmdError("test22717"), \
398
             "Did not fail with generic error message"
399
      result = True
400

    
401
    assert sleep_fn.Count() == 1
402

    
403
    return result
404

    
405
  def testLockHeldByOtherProcess(self):
406
    lockfile = utils.PathJoin(self.tmpdir, "lock")
407

    
408
    lock = utils.FileLock.Open(lockfile)
409
    lock.Exclusive(blocking=True, timeout=1.0)
410
    try:
411
      self.assertTrue(utils.RunInSeparateProcess(self._TryLock, lockfile))
412
    finally:
413
      lock.Close()
414

    
415
  @staticmethod
416
  def _PrepareRaisingException(path, cmd):
417
    assert cmd == "test23122"
418
    raise Exception("test")
419

    
420
  def testPrepareRaisesException(self):
421
    lockfile = utils.PathJoin(self.tmpdir, "lock")
422

    
423
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
424
    prepare_fn = testutils.CallCounter(self._PrepareRaisingException)
425

    
426
    try:
427
      backend.RunRestrictedCmd("test23122",
428
                               _lock_timeout=1.0, _lock_file=lockfile,
429
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
430
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
431
                               _enabled=True)
432
    except backend.RPCFail, err:
433
      self.assertEqual(str(err), _GenericRestrictedCmdError("test23122"))
434
    else:
435
      self.fail("Didn't fail")
436

    
437
    self.assertEqual(sleep_fn.Count(), 1)
438
    self.assertEqual(prepare_fn.Count(), 1)
439

    
440
  @staticmethod
441
  def _PrepareFails(path, cmd):
442
    assert cmd == "test29327"
443
    return ("some error message", None)
444

    
445
  def testPrepareFails(self):
446
    lockfile = utils.PathJoin(self.tmpdir, "lock")
447

    
448
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
449
    prepare_fn = testutils.CallCounter(self._PrepareFails)
450

    
451
    try:
452
      backend.RunRestrictedCmd("test29327",
453
                               _lock_timeout=1.0, _lock_file=lockfile,
454
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
455
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
456
                               _enabled=True)
457
    except backend.RPCFail, err:
458
      self.assertEqual(str(err), _GenericRestrictedCmdError("test29327"))
459
    else:
460
      self.fail("Didn't fail")
461

    
462
    self.assertEqual(sleep_fn.Count(), 1)
463
    self.assertEqual(prepare_fn.Count(), 1)
464

    
465
  @staticmethod
466
  def _SuccessfulPrepare(path, cmd):
467
    return (True, utils.PathJoin(path, cmd))
468

    
469
  def testRunCmdFails(self):
470
    lockfile = utils.PathJoin(self.tmpdir, "lock")
471

    
472
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
473
           postfork_fn=NotImplemented):
474
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test3079")])
475
      self.assertEqual(env, {})
476
      self.assertTrue(reset_env)
477
      self.assertTrue(callable(postfork_fn))
478

    
479
      trylock = utils.FileLock.Open(lockfile)
480
      try:
481
        # See if lockfile is still held
482
        self.assertRaises(EnvironmentError, trylock.Exclusive, blocking=False)
483

    
484
        # Call back to release lock
485
        postfork_fn(NotImplemented)
486

    
487
        # See if lockfile can be acquired
488
        trylock.Exclusive(blocking=False)
489
      finally:
490
        trylock.Close()
491

    
492
      # Simulate a failed command
493
      return utils.RunResult(constants.EXIT_FAILURE, None,
494
                             "stdout", "stderr406328567",
495
                             utils.ShellQuoteArgs(args),
496
                             NotImplemented, NotImplemented)
497

    
498
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
499
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
500
    runcmd_fn = testutils.CallCounter(fn)
501

    
502
    try:
503
      backend.RunRestrictedCmd("test3079",
504
                               _lock_timeout=1.0, _lock_file=lockfile,
505
                               _path=self.tmpdir, _runcmd_fn=runcmd_fn,
506
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
507
                               _enabled=True)
508
    except backend.RPCFail, err:
509
      self.assertTrue(str(err).startswith("Restricted command 'test3079'"
510
                                          " failed:"))
511
      self.assertTrue("stderr406328567" in str(err),
512
                      msg="Error did not include output")
513
    else:
514
      self.fail("Didn't fail")
515

    
516
    self.assertEqual(sleep_fn.Count(), 0)
517
    self.assertEqual(prepare_fn.Count(), 1)
518
    self.assertEqual(runcmd_fn.Count(), 1)
519

    
520
  def testRunCmdSucceeds(self):
521
    lockfile = utils.PathJoin(self.tmpdir, "lock")
522

    
523
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
524
           postfork_fn=NotImplemented):
525
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test5667")])
526
      self.assertEqual(env, {})
527
      self.assertTrue(reset_env)
528

    
529
      # Call back to release lock
530
      postfork_fn(NotImplemented)
531

    
532
      # Simulate a successful command
533
      return utils.RunResult(constants.EXIT_SUCCESS, None, "stdout14463", "",
534
                             utils.ShellQuoteArgs(args),
535
                             NotImplemented, NotImplemented)
536

    
537
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
538
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
539
    runcmd_fn = testutils.CallCounter(fn)
540

    
541
    result = backend.RunRestrictedCmd("test5667",
542
                                      _lock_timeout=1.0, _lock_file=lockfile,
543
                                      _path=self.tmpdir, _runcmd_fn=runcmd_fn,
544
                                      _sleep_fn=sleep_fn,
545
                                      _prepare_fn=prepare_fn,
546
                                      _enabled=True)
547
    self.assertEqual(result, "stdout14463")
548

    
549
    self.assertEqual(sleep_fn.Count(), 0)
550
    self.assertEqual(prepare_fn.Count(), 1)
551
    self.assertEqual(runcmd_fn.Count(), 1)
552

    
553
  def testCommandsDisabled(self):
554
    try:
555
      backend.RunRestrictedCmd("test",
556
                               _lock_timeout=NotImplemented,
557
                               _lock_file=NotImplemented,
558
                               _path=NotImplemented,
559
                               _sleep_fn=NotImplemented,
560
                               _prepare_fn=NotImplemented,
561
                               _runcmd_fn=NotImplemented,
562
                               _enabled=False)
563
    except backend.RPCFail, err:
564
      self.assertEqual(str(err),
565
                       "Restricted commands disabled at configure time")
566
    else:
567
      self.fail("Did not raise exception")
568

    
569

    
570
class TestSetWatcherPause(unittest.TestCase):
571
  def setUp(self):
572
    self.tmpdir = tempfile.mkdtemp()
573
    self.filename = utils.PathJoin(self.tmpdir, "pause")
574

    
575
  def tearDown(self):
576
    shutil.rmtree(self.tmpdir)
577

    
578
  def testUnsetNonExisting(self):
579
    self.assertFalse(os.path.exists(self.filename))
580
    backend.SetWatcherPause(None, _filename=self.filename)
581
    self.assertFalse(os.path.exists(self.filename))
582

    
583
  def testSetNonNumeric(self):
584
    for i in ["", [], {}, "Hello World", "0", "1.0"]:
585
      self.assertFalse(os.path.exists(self.filename))
586

    
587
      try:
588
        backend.SetWatcherPause(i, _filename=self.filename)
589
      except backend.RPCFail, err:
590
        self.assertEqual(str(err), "Duration must be numeric")
591
      else:
592
        self.fail("Did not raise exception")
593

    
594
      self.assertFalse(os.path.exists(self.filename))
595

    
596
  def testSet(self):
597
    self.assertFalse(os.path.exists(self.filename))
598

    
599
    for i in range(10):
600
      backend.SetWatcherPause(i, _filename=self.filename)
601
      self.assertEqual(utils.ReadFile(self.filename), "%s\n" % i)
602
      self.assertEqual(os.stat(self.filename).st_mode & 0777, 0644)
603

    
604

    
605
class TestGetBlockDevSymlinkPath(unittest.TestCase):
606
  def setUp(self):
607
    self.tmpdir = tempfile.mkdtemp()
608

    
609
  def tearDown(self):
610
    shutil.rmtree(self.tmpdir)
611

    
612
  def _Test(self, name, idx):
613
    self.assertEqual(backend._GetBlockDevSymlinkPath(name, idx,
614
                                                     _dir=self.tmpdir),
615
                     ("%s/%s%s%s" % (self.tmpdir, name,
616
                                     constants.DISK_SEPARATOR, idx)))
617

    
618
  def test(self):
619
    for idx in range(100):
620
      self._Test("inst1.example.com", idx)
621

    
622

    
623
class TestGetInstanceList(unittest.TestCase):
624

    
625
  def setUp(self):
626
    self._test_hv = self._TestHypervisor()
627
    self._test_hv.ListInstances = mock.Mock(
628
      return_value=["instance1", "instance2", "instance3"] )
629

    
630
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
631
    def __init__(self):
632
      hypervisor.hv_base.BaseHypervisor.__init__(self)
633

    
634
  def _GetHypervisor(self, name):
635
    return self._test_hv
636

    
637
  def testHvparams(self):
638
    fake_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
639
    hvparams = {constants.HT_FAKE: fake_hvparams}
640
    backend.GetInstanceList([constants.HT_FAKE], all_hvparams=hvparams,
641
                            get_hv_fn=self._GetHypervisor)
642
    self._test_hv.ListInstances.assert_called_with(hvparams=fake_hvparams)
643

    
644

    
645
class TestInstanceConsoleInfo(unittest.TestCase):
646

    
647
  def setUp(self):
648
    self._test_hv_a = self._TestHypervisor()
649
    self._test_hv_a.GetInstanceConsole = mock.Mock(
650
      return_value = objects.InstanceConsole(instance="inst", kind="aHy")
651
    )
652
    self._test_hv_b = self._TestHypervisor()
653
    self._test_hv_b.GetInstanceConsole = mock.Mock(
654
      return_value = objects.InstanceConsole(instance="inst", kind="bHy")
655
    )
656

    
657
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
658
    def __init__(self):
659
      hypervisor.hv_base.BaseHypervisor.__init__(self)
660

    
661
  def _GetHypervisor(self, name):
662
    if name == "a":
663
      return self._test_hv_a
664
    else:
665
      return self._test_hv_b
666

    
667
  def testRightHypervisor(self):
668
    dictMaker = lambda hyName: {
669
      "instance":{"hypervisor":hyName},
670
      "node":{},
671
      "group":{},
672
      "hvParams":{},
673
      "beParams":{},
674
    }
675

    
676
    call = {
677
      'i1':dictMaker("a"),
678
      'i2':dictMaker("b"),
679
    }
680

    
681
    res = backend.GetInstanceConsoleInfo(call, get_hv_fn=self._GetHypervisor)
682

    
683
    self.assertTrue(res["i1"]["kind"] == "aHy")
684
    self.assertTrue(res["i2"]["kind"] == "bHy")
685

    
686

    
687
class TestGetHvInfo(unittest.TestCase):
688

    
689
  def setUp(self):
690
    self._test_hv = self._TestHypervisor()
691
    self._test_hv.GetNodeInfo = mock.Mock()
692

    
693
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
694
    def __init__(self):
695
      hypervisor.hv_base.BaseHypervisor.__init__(self)
696

    
697
  def _GetHypervisor(self, name):
698
    return self._test_hv
699

    
700
  def testGetHvInfoAllNone(self):
701
    result = backend._GetHvInfoAll(None)
702
    self.assertTrue(result is None)
703

    
704
  def testGetHvInfoAll(self):
705
    hvname = constants.HT_XEN_PVM
706
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
707
    hv_specs = [(hvname, hvparams)]
708

    
709
    backend._GetHvInfoAll(hv_specs, self._GetHypervisor)
710
    self._test_hv.GetNodeInfo.assert_called_with(hvparams=hvparams)
711

    
712

    
713
class TestApplyStorageInfoFunction(unittest.TestCase):
714

    
715
  _STORAGE_KEY = "some_key"
716
  _SOME_ARGS = ["some_args"]
717

    
718
  def setUp(self):
719
    self.mock_storage_fn = mock.Mock()
720

    
721
  def testApplyValidStorageType(self):
722
    storage_type = constants.ST_LVM_VG
723
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
724
    backend._STORAGE_TYPE_INFO_FN = {
725
        storage_type: self.mock_storage_fn
726
      }
727

    
728
    backend._ApplyStorageInfoFunction(
729
        storage_type, self._STORAGE_KEY, self._SOME_ARGS)
730

    
731
    self.mock_storage_fn.assert_called_with(self._STORAGE_KEY, self._SOME_ARGS)
732
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
733

    
734
  def testApplyInValidStorageType(self):
735
    storage_type = "invalid_storage_type"
736
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
737
    backend._STORAGE_TYPE_INFO_FN = {}
738

    
739
    self.assertRaises(KeyError, backend._ApplyStorageInfoFunction,
740
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
741
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
742

    
743
  def testApplyNotImplementedStorageType(self):
744
    storage_type = "not_implemented_storage_type"
745
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
746
    backend._STORAGE_TYPE_INFO_FN = {storage_type: None}
747

    
748
    self.assertRaises(NotImplementedError,
749
                      backend._ApplyStorageInfoFunction,
750
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
751
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
752

    
753

    
754
class TestGetLvmVgSpaceInfo(unittest.TestCase):
755

    
756
  def testValid(self):
757
    path = "somepath"
758
    excl_stor = True
759
    orig_fn = backend._GetVgInfo
760
    backend._GetVgInfo = mock.Mock()
761
    backend._GetLvmVgSpaceInfo(path, [excl_stor])
762
    backend._GetVgInfo.assert_called_with(path, excl_stor)
763
    backend._GetVgInfo = orig_fn
764

    
765
  def testNoExclStorageNotBool(self):
766
    path = "somepath"
767
    excl_stor = "123"
768
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
769
                      path, [excl_stor])
770

    
771
  def testNoExclStorageNotInList(self):
772
    path = "somepath"
773
    excl_stor = "123"
774
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
775
                      path, excl_stor)
776

    
777
class TestGetLvmPvSpaceInfo(unittest.TestCase):
778

    
779
  def testValid(self):
780
    path = "somepath"
781
    excl_stor = True
782
    orig_fn = backend._GetVgSpindlesInfo
783
    backend._GetVgSpindlesInfo = mock.Mock()
784
    backend._GetLvmPvSpaceInfo(path, [excl_stor])
785
    backend._GetVgSpindlesInfo.assert_called_with(path, excl_stor)
786
    backend._GetVgSpindlesInfo = orig_fn
787

    
788

    
789
class TestCheckStorageParams(unittest.TestCase):
790

    
791
  def testParamsNone(self):
792
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
793
                      None, NotImplemented)
794

    
795
  def testParamsWrongType(self):
796
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
797
                      "string", NotImplemented)
798

    
799
  def testParamsEmpty(self):
800
    backend._CheckStorageParams([], 0)
801

    
802
  def testParamsValidNumber(self):
803
    backend._CheckStorageParams(["a", True], 2)
804

    
805
  def testParamsInvalidNumber(self):
806
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
807
                      ["b", False], 3)
808

    
809

    
810
class TestGetVgSpindlesInfo(unittest.TestCase):
811

    
812
  def setUp(self):
813
    self.vg_free = 13
814
    self.vg_size = 31
815
    self.mock_fn = mock.Mock(return_value=(self.vg_free, self.vg_size))
816

    
817
  def testValidInput(self):
818
    name = "myvg"
819
    excl_stor = True
820
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
821
    self.mock_fn.assert_called_with(name)
822
    self.assertEqual(name, result["name"])
823
    self.assertEqual(constants.ST_LVM_PV, result["type"])
824
    self.assertEqual(self.vg_free, result["storage_free"])
825
    self.assertEqual(self.vg_size, result["storage_size"])
826

    
827
  def testNoExclStor(self):
828
    name = "myvg"
829
    excl_stor = False
830
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
831
    self.mock_fn.assert_not_called()
832
    self.assertEqual(name, result["name"])
833
    self.assertEqual(constants.ST_LVM_PV, result["type"])
834
    self.assertEqual(0, result["storage_free"])
835
    self.assertEqual(0, result["storage_size"])
836

    
837

    
838
class TestGetVgSpindlesInfo(unittest.TestCase):
839

    
840
  def testValidInput(self):
841
    self.vg_free = 13
842
    self.vg_size = 31
843
    self.mock_fn = mock.Mock(return_value=[(self.vg_free, self.vg_size)])
844
    name = "myvg"
845
    excl_stor = True
846
    result = backend._GetVgInfo(name, excl_stor, info_fn=self.mock_fn)
847
    self.mock_fn.assert_called_with([name], excl_stor)
848
    self.assertEqual(name, result["name"])
849
    self.assertEqual(constants.ST_LVM_VG, result["type"])
850
    self.assertEqual(self.vg_free, result["storage_free"])
851
    self.assertEqual(self.vg_size, result["storage_size"])
852

    
853
  def testNoExclStor(self):
854
    name = "myvg"
855
    excl_stor = True
856
    self.mock_fn = mock.Mock(return_value=None)
857
    result = backend._GetVgInfo(name, excl_stor, info_fn=self.mock_fn)
858
    self.mock_fn.assert_called_with([name], excl_stor)
859
    self.assertEqual(name, result["name"])
860
    self.assertEqual(constants.ST_LVM_VG, result["type"])
861
    self.assertEqual(None, result["storage_free"])
862
    self.assertEqual(None, result["storage_size"])
863

    
864

    
865
class TestGetNodeInfo(unittest.TestCase):
866

    
867
  _SOME_RESULT = None
868

    
869
  def testApplyStorageInfoFunction(self):
870
    orig_fn = backend._ApplyStorageInfoFunction
871
    backend._ApplyStorageInfoFunction = mock.Mock(
872
        return_value=self._SOME_RESULT)
873
    storage_units = [(st, st + "_key", [st + "_params"]) for st in
874
                     constants.STORAGE_TYPES]
875

    
876
    backend.GetNodeInfo(storage_units, None)
877

    
878
    call_args_list = backend._ApplyStorageInfoFunction.call_args_list
879
    self.assertEqual(len(constants.STORAGE_TYPES), len(call_args_list))
880
    for call in call_args_list:
881
      storage_type, storage_key, storage_params = call[0]
882
      self.assertEqual(storage_type + "_key", storage_key)
883
      self.assertEqual([storage_type + "_params"], storage_params)
884
      self.assertTrue(storage_type in constants.STORAGE_TYPES)
885
    backend._ApplyStorageInfoFunction = orig_fn
886

    
887

    
888
class TestSpaceReportingConstants(unittest.TestCase):
889
  """Ensures consistency between STS_REPORT and backend.
890

891
  These tests ensure, that the constant 'STS_REPORT' is consistent
892
  with the implementation of invoking space reporting functions
893
  in backend.py. Once space reporting is available for all types,
894
  the constant can be removed and these tests as well.
895

896
  """
897

    
898
  REPORTING = set(constants.STS_REPORT)
899
  NOT_REPORTING = set(constants.STORAGE_TYPES) - REPORTING
900

    
901
  def testAllReportingTypesHaveAReportingFunction(self):
902
    for storage_type in TestSpaceReportingConstants.REPORTING:
903
      self.assertTrue(backend._STORAGE_TYPE_INFO_FN[storage_type] is not None)
904

    
905
  def testAllNotReportingTypesDontHaveFunction(self):
906
    for storage_type in TestSpaceReportingConstants.NOT_REPORTING:
907
      self.assertEqual(None, backend._STORAGE_TYPE_INFO_FN[storage_type])
908

    
909

    
910
if __name__ == "__main__":
911
  testutils.GanetiTestProgram()