Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.backend_unittest.py @ ab4b1cf2

History | View | Annotate | Download (32.7 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.backend"""
23

    
24
import mock
25
import os
26
import shutil
27
import tempfile
28
import testutils
29
import unittest
30

    
31
from ganeti import backend
32
from ganeti import constants
33
from ganeti import errors
34
from ganeti import hypervisor
35
from ganeti import netutils
36
from ganeti import objects
37
from ganeti import pathutils
38
from ganeti import utils
39

    
40

    
41
class TestX509Certificates(unittest.TestCase):
42
  def setUp(self):
43
    self.tmpdir = tempfile.mkdtemp()
44

    
45
  def tearDown(self):
46
    shutil.rmtree(self.tmpdir)
47

    
48
  def test(self):
49
    (name, cert_pem) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
50

    
51
    self.assertEqual(utils.ReadFile(os.path.join(self.tmpdir, name,
52
                                                 backend._X509_CERT_FILE)),
53
                     cert_pem)
54
    self.assert_(0 < os.path.getsize(os.path.join(self.tmpdir, name,
55
                                                  backend._X509_KEY_FILE)))
56

    
57
    (name2, cert_pem2) = \
58
      backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
59

    
60
    backend.RemoveX509Certificate(name, cryptodir=self.tmpdir)
61
    backend.RemoveX509Certificate(name2, cryptodir=self.tmpdir)
62

    
63
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [])
64

    
65
  def testNonEmpty(self):
66
    (name, _) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
67

    
68
    utils.WriteFile(utils.PathJoin(self.tmpdir, name, "hello-world"),
69
                    data="Hello World")
70

    
71
    self.assertRaises(backend.RPCFail, backend.RemoveX509Certificate,
72
                      name, cryptodir=self.tmpdir)
73

    
74
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [name])
75

    
76

    
77
class TestGetCryptoTokens(testutils.GanetiTestCase):
78

    
79
  def setUp(self):
80
    self._get_digest_fn_orig = utils.GetCertificateDigest
81
    self._create_digest_fn_orig = utils.GenerateNewSslCert
82
    self._ssl_digest = "12345"
83
    utils.GetCertificateDigest = mock.Mock(
84
      return_value=self._ssl_digest)
85
    utils.GenerateNewSslCert = mock.Mock()
86

    
87
  def tearDown(self):
88
    utils.GetCertificateDigest = self._get_digest_fn_orig
89
    utils.GenerateNewSslCert = self._create_digest_fn_orig
90

    
91
  def testGetSslToken(self):
92
    result = backend.GetCryptoTokens(
93
      [(constants.CRYPTO_TYPE_SSL_DIGEST, constants.CRYPTO_ACTION_GET, None)])
94
    self.assertTrue((constants.CRYPTO_TYPE_SSL_DIGEST, self._ssl_digest)
95
                    in result)
96

    
97
  def testCreateSslToken(self):
98
    result = backend.GetCryptoTokens(
99
      [(constants.CRYPTO_TYPE_SSL_DIGEST, constants.CRYPTO_ACTION_CREATE,
100
        {constants.CRYPTO_OPTION_SERIAL_NO: 42})])
101
    self.assertTrue((constants.CRYPTO_TYPE_SSL_DIGEST, self._ssl_digest)
102
                    in result)
103
    self.assertTrue(utils.GenerateNewSslCert.assert_calls().once())
104

    
105
  def testCreateSslTokenDifferentFilename(self):
106
    result = backend.GetCryptoTokens(
107
      [(constants.CRYPTO_TYPE_SSL_DIGEST, constants.CRYPTO_ACTION_CREATE,
108
        {constants.CRYPTO_OPTION_CERT_FILE:
109
          pathutils.NODED_CLIENT_CERT_FILE_TMP,
110
         constants.CRYPTO_OPTION_SERIAL_NO: 42})])
111
    self.assertTrue((constants.CRYPTO_TYPE_SSL_DIGEST, self._ssl_digest)
112
                    in result)
113
    self.assertTrue(utils.GenerateNewSslCert.assert_calls().once())
114

    
115
  def testCreateSslTokenSerialNo(self):
116
    result = backend.GetCryptoTokens(
117
      [(constants.CRYPTO_TYPE_SSL_DIGEST, constants.CRYPTO_ACTION_CREATE,
118
        {constants.CRYPTO_OPTION_SERIAL_NO: 42})])
119
    self.assertTrue((constants.CRYPTO_TYPE_SSL_DIGEST, self._ssl_digest)
120
                    in result)
121
    self.assertTrue(utils.GenerateNewSslCert.assert_calls().once())
122

    
123
  def testUnknownTokenType(self):
124
    self.assertRaises(errors.ProgrammerError,
125
                      backend.GetCryptoTokens,
126
                      [("pink_bunny", constants.CRYPTO_ACTION_GET, None)])
127

    
128
  def testUnknownAction(self):
129
    self.assertRaises(errors.ProgrammerError,
130
                      backend.GetCryptoTokens,
131
                      [(constants.CRYPTO_TYPE_SSL_DIGEST, "illuminate", None)])
132

    
133

    
134
class TestNodeVerify(testutils.GanetiTestCase):
135

    
136
  def setUp(self):
137
    testutils.GanetiTestCase.setUp(self)
138
    self._mock_hv = None
139

    
140
  def _GetHypervisor(self, hv_name):
141
    self._mock_hv = hypervisor.GetHypervisor(hv_name)
142
    self._mock_hv.ValidateParameters = mock.Mock()
143
    self._mock_hv.Verify = mock.Mock()
144
    return self._mock_hv
145

    
146
  def testMasterIPLocalhost(self):
147
    # this a real functional test, but requires localhost to be reachable
148
    local_data = (netutils.Hostname.GetSysName(),
149
                  constants.IP4_ADDRESS_LOCALHOST)
150
    result = backend.VerifyNode({constants.NV_MASTERIP: local_data},
151
                                None, {}, {}, {})
152
    self.failUnless(constants.NV_MASTERIP in result,
153
                    "Master IP data not returned")
154
    self.failUnless(result[constants.NV_MASTERIP], "Cannot reach localhost")
155

    
156
  def testMasterIPUnreachable(self):
157
    # Network 192.0.2.0/24 is reserved for test/documentation as per
158
    # RFC 5737
159
    bad_data =  ("master.example.com", "192.0.2.1")
160
    # we just test that whatever TcpPing returns, VerifyNode returns too
161
    netutils.TcpPing = lambda a, b, source=None: False
162
    result = backend.VerifyNode({constants.NV_MASTERIP: bad_data},
163
                                None, {}, {}, {})
164
    self.failUnless(constants.NV_MASTERIP in result,
165
                    "Master IP data not returned")
166
    self.failIf(result[constants.NV_MASTERIP],
167
                "Result from netutils.TcpPing corrupted")
168

    
169
  def testVerifyHvparams(self):
170
    test_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
171
    test_what = {constants.NV_HVPARAMS: \
172
        [("mynode", constants.HT_XEN_PVM, test_hvparams)]}
173
    result = {}
174
    backend._VerifyHvparams(test_what, True, result,
175
                            get_hv_fn=self._GetHypervisor)
176
    self._mock_hv.ValidateParameters.assert_called_with(test_hvparams)
177

    
178
  def testVerifyHypervisors(self):
179
    hvname = constants.HT_XEN_PVM
180
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
181
    all_hvparams = {hvname: hvparams}
182
    test_what = {constants.NV_HYPERVISOR: [hvname]}
183
    result = {}
184
    backend._VerifyHypervisors(
185
        test_what, True, result, all_hvparams=all_hvparams,
186
        get_hv_fn=self._GetHypervisor)
187
    self._mock_hv.Verify.assert_called_with(hvparams=hvparams)
188

    
189
  @testutils.patch_object(utils, "VerifyCertificate")
190
  def testVerifyClientCertificateSuccess(self, verif_cert):
191
    # mock the underlying x509 verification because the test cert is expired
192
    verif_cert.return_value = (None, None)
193
    cert_file = testutils.TestDataFilename("cert2.pem")
194
    (errcode, digest) = backend._VerifyClientCertificate(cert_file=cert_file)
195
    self.assertEqual(None, errcode)
196
    self.assertTrue(isinstance(digest, str))
197

    
198
  @testutils.patch_object(utils, "VerifyCertificate")
199
  def testVerifyClientCertificateFailed(self, verif_cert):
200
    expected_errcode = 666
201
    verif_cert.return_value = (expected_errcode,
202
                               "The devil created this certificate.")
203
    cert_file = testutils.TestDataFilename("cert2.pem")
204
    (errcode, digest) = backend._VerifyClientCertificate(cert_file=cert_file)
205
    self.assertEqual(expected_errcode, errcode)
206

    
207
  def testVerifyClientCertificateNoCert(self):
208
    cert_file = testutils.TestDataFilename("cert-that-does-not-exist.pem")
209
    (errcode, digest) = backend._VerifyClientCertificate(cert_file=cert_file)
210
    self.assertEqual(constants.CV_ERROR, errcode)
211

    
212

    
213
def _DefRestrictedCmdOwner():
214
  return (os.getuid(), os.getgid())
215

    
216

    
217
class TestVerifyRestrictedCmdName(unittest.TestCase):
218
  def testAcceptableName(self):
219
    for i in ["foo", "bar", "z1", "000first", "hello-world"]:
220
      for fn in [lambda s: s, lambda s: s.upper(), lambda s: s.title()]:
221
        (status, msg) = backend._VerifyRestrictedCmdName(fn(i))
222
        self.assertTrue(status)
223
        self.assertTrue(msg is None)
224

    
225
  def testEmptyAndSpace(self):
226
    for i in ["", " ", "\t", "\n"]:
227
      (status, msg) = backend._VerifyRestrictedCmdName(i)
228
      self.assertFalse(status)
229
      self.assertEqual(msg, "Missing command name")
230

    
231
  def testNameWithSlashes(self):
232
    for i in ["/", "./foo", "../moo", "some/name"]:
233
      (status, msg) = backend._VerifyRestrictedCmdName(i)
234
      self.assertFalse(status)
235
      self.assertEqual(msg, "Invalid command name")
236

    
237
  def testForbiddenCharacters(self):
238
    for i in ["#", ".", "..", "bash -c ls", "'"]:
239
      (status, msg) = backend._VerifyRestrictedCmdName(i)
240
      self.assertFalse(status)
241
      self.assertEqual(msg, "Command name contains forbidden characters")
242

    
243

    
244
class TestVerifyRestrictedCmdDirectory(unittest.TestCase):
245
  def setUp(self):
246
    self.tmpdir = tempfile.mkdtemp()
247

    
248
  def tearDown(self):
249
    shutil.rmtree(self.tmpdir)
250

    
251
  def testCanNotStat(self):
252
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
253
    self.assertFalse(os.path.exists(tmpname))
254
    (status, msg) = \
255
      backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
256
    self.assertFalse(status)
257
    self.assertTrue(msg.startswith("Can't stat(2) '"))
258

    
259
  def testTooPermissive(self):
260
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
261
    os.mkdir(tmpname)
262

    
263
    for mode in [0777, 0706, 0760, 0722]:
264
      os.chmod(tmpname, mode)
265
      self.assertTrue(os.path.isdir(tmpname))
266
      (status, msg) = \
267
        backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
268
      self.assertFalse(status)
269
      self.assertTrue(msg.startswith("Permissions on '"))
270

    
271
  def testNoDirectory(self):
272
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
273
    utils.WriteFile(tmpname, data="empty\n")
274
    self.assertTrue(os.path.isfile(tmpname))
275
    (status, msg) = \
276
      backend._VerifyRestrictedCmdDirectory(tmpname,
277
                                            _owner=_DefRestrictedCmdOwner())
278
    self.assertFalse(status)
279
    self.assertTrue(msg.endswith("is not a directory"))
280

    
281
  def testNormal(self):
282
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
283
    os.mkdir(tmpname)
284
    os.chmod(tmpname, 0755)
285
    self.assertTrue(os.path.isdir(tmpname))
286
    (status, msg) = \
287
      backend._VerifyRestrictedCmdDirectory(tmpname,
288
                                            _owner=_DefRestrictedCmdOwner())
289
    self.assertTrue(status)
290
    self.assertTrue(msg is None)
291

    
292

    
293
class TestVerifyRestrictedCmd(unittest.TestCase):
294
  def setUp(self):
295
    self.tmpdir = tempfile.mkdtemp()
296

    
297
  def tearDown(self):
298
    shutil.rmtree(self.tmpdir)
299

    
300
  def testCanNotStat(self):
301
    tmpname = utils.PathJoin(self.tmpdir, "helloworld")
302
    self.assertFalse(os.path.exists(tmpname))
303
    (status, msg) = \
304
      backend._VerifyRestrictedCmd(self.tmpdir, "helloworld",
305
                                   _owner=NotImplemented)
306
    self.assertFalse(status)
307
    self.assertTrue(msg.startswith("Can't stat(2) '"))
308

    
309
  def testNotExecutable(self):
310
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
311
    utils.WriteFile(tmpname, data="empty\n")
312
    (status, msg) = \
313
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
314
                                   _owner=_DefRestrictedCmdOwner())
315
    self.assertFalse(status)
316
    self.assertTrue(msg.startswith("access(2) thinks '"))
317

    
318
  def testExecutable(self):
319
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
320
    utils.WriteFile(tmpname, data="empty\n", mode=0700)
321
    (status, executable) = \
322
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
323
                                   _owner=_DefRestrictedCmdOwner())
324
    self.assertTrue(status)
325
    self.assertEqual(executable, tmpname)
326

    
327

    
328
class TestPrepareRestrictedCmd(unittest.TestCase):
329
  _TEST_PATH = "/tmp/some/test/path"
330

    
331
  def testDirFails(self):
332
    def fn(path):
333
      self.assertEqual(path, self._TEST_PATH)
334
      return (False, "test error 31420")
335

    
336
    (status, msg) = \
337
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd21152",
338
                                    _verify_dir=fn,
339
                                    _verify_name=NotImplemented,
340
                                    _verify_cmd=NotImplemented)
341
    self.assertFalse(status)
342
    self.assertEqual(msg, "test error 31420")
343

    
344
  def testNameFails(self):
345
    def fn(cmd):
346
      self.assertEqual(cmd, "cmd4617")
347
      return (False, "test error 591")
348

    
349
    (status, msg) = \
350
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd4617",
351
                                    _verify_dir=lambda _: (True, None),
352
                                    _verify_name=fn,
353
                                    _verify_cmd=NotImplemented)
354
    self.assertFalse(status)
355
    self.assertEqual(msg, "test error 591")
356

    
357
  def testCommandFails(self):
358
    def fn(path, cmd):
359
      self.assertEqual(path, self._TEST_PATH)
360
      self.assertEqual(cmd, "cmd17577")
361
      return (False, "test error 25524")
362

    
363
    (status, msg) = \
364
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd17577",
365
                                    _verify_dir=lambda _: (True, None),
366
                                    _verify_name=lambda _: (True, None),
367
                                    _verify_cmd=fn)
368
    self.assertFalse(status)
369
    self.assertEqual(msg, "test error 25524")
370

    
371
  def testSuccess(self):
372
    def fn(path, cmd):
373
      return (True, utils.PathJoin(path, cmd))
374

    
375
    (status, executable) = \
376
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd22633",
377
                                    _verify_dir=lambda _: (True, None),
378
                                    _verify_name=lambda _: (True, None),
379
                                    _verify_cmd=fn)
380
    self.assertTrue(status)
381
    self.assertEqual(executable, utils.PathJoin(self._TEST_PATH, "cmd22633"))
382

    
383

    
384
def _SleepForRestrictedCmd(duration):
385
  assert duration > 5
386

    
387

    
388
def _GenericRestrictedCmdError(cmd):
389
  return "Executing command '%s' failed" % cmd
390

    
391

    
392
class TestRunRestrictedCmd(unittest.TestCase):
393
  def setUp(self):
394
    self.tmpdir = tempfile.mkdtemp()
395

    
396
  def tearDown(self):
397
    shutil.rmtree(self.tmpdir)
398

    
399
  def testNonExistantLockDirectory(self):
400
    lockfile = utils.PathJoin(self.tmpdir, "does", "not", "exist")
401
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
402
    self.assertFalse(os.path.exists(lockfile))
403
    self.assertRaises(backend.RPCFail,
404
                      backend.RunRestrictedCmd, "test",
405
                      _lock_timeout=NotImplemented,
406
                      _lock_file=lockfile,
407
                      _path=NotImplemented,
408
                      _sleep_fn=sleep_fn,
409
                      _prepare_fn=NotImplemented,
410
                      _runcmd_fn=NotImplemented,
411
                      _enabled=True)
412
    self.assertEqual(sleep_fn.Count(), 1)
413

    
414
  @staticmethod
415
  def _TryLock(lockfile):
416
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
417

    
418
    result = False
419
    try:
420
      backend.RunRestrictedCmd("test22717",
421
                               _lock_timeout=0.1,
422
                               _lock_file=lockfile,
423
                               _path=NotImplemented,
424
                               _sleep_fn=sleep_fn,
425
                               _prepare_fn=NotImplemented,
426
                               _runcmd_fn=NotImplemented,
427
                               _enabled=True)
428
    except backend.RPCFail, err:
429
      assert str(err) == _GenericRestrictedCmdError("test22717"), \
430
             "Did not fail with generic error message"
431
      result = True
432

    
433
    assert sleep_fn.Count() == 1
434

    
435
    return result
436

    
437
  def testLockHeldByOtherProcess(self):
438
    lockfile = utils.PathJoin(self.tmpdir, "lock")
439

    
440
    lock = utils.FileLock.Open(lockfile)
441
    lock.Exclusive(blocking=True, timeout=1.0)
442
    try:
443
      self.assertTrue(utils.RunInSeparateProcess(self._TryLock, lockfile))
444
    finally:
445
      lock.Close()
446

    
447
  @staticmethod
448
  def _PrepareRaisingException(path, cmd):
449
    assert cmd == "test23122"
450
    raise Exception("test")
451

    
452
  def testPrepareRaisesException(self):
453
    lockfile = utils.PathJoin(self.tmpdir, "lock")
454

    
455
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
456
    prepare_fn = testutils.CallCounter(self._PrepareRaisingException)
457

    
458
    try:
459
      backend.RunRestrictedCmd("test23122",
460
                               _lock_timeout=1.0, _lock_file=lockfile,
461
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
462
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
463
                               _enabled=True)
464
    except backend.RPCFail, err:
465
      self.assertEqual(str(err), _GenericRestrictedCmdError("test23122"))
466
    else:
467
      self.fail("Didn't fail")
468

    
469
    self.assertEqual(sleep_fn.Count(), 1)
470
    self.assertEqual(prepare_fn.Count(), 1)
471

    
472
  @staticmethod
473
  def _PrepareFails(path, cmd):
474
    assert cmd == "test29327"
475
    return ("some error message", None)
476

    
477
  def testPrepareFails(self):
478
    lockfile = utils.PathJoin(self.tmpdir, "lock")
479

    
480
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
481
    prepare_fn = testutils.CallCounter(self._PrepareFails)
482

    
483
    try:
484
      backend.RunRestrictedCmd("test29327",
485
                               _lock_timeout=1.0, _lock_file=lockfile,
486
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
487
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
488
                               _enabled=True)
489
    except backend.RPCFail, err:
490
      self.assertEqual(str(err), _GenericRestrictedCmdError("test29327"))
491
    else:
492
      self.fail("Didn't fail")
493

    
494
    self.assertEqual(sleep_fn.Count(), 1)
495
    self.assertEqual(prepare_fn.Count(), 1)
496

    
497
  @staticmethod
498
  def _SuccessfulPrepare(path, cmd):
499
    return (True, utils.PathJoin(path, cmd))
500

    
501
  def testRunCmdFails(self):
502
    lockfile = utils.PathJoin(self.tmpdir, "lock")
503

    
504
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
505
           postfork_fn=NotImplemented):
506
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test3079")])
507
      self.assertEqual(env, {})
508
      self.assertTrue(reset_env)
509
      self.assertTrue(callable(postfork_fn))
510

    
511
      trylock = utils.FileLock.Open(lockfile)
512
      try:
513
        # See if lockfile is still held
514
        self.assertRaises(EnvironmentError, trylock.Exclusive, blocking=False)
515

    
516
        # Call back to release lock
517
        postfork_fn(NotImplemented)
518

    
519
        # See if lockfile can be acquired
520
        trylock.Exclusive(blocking=False)
521
      finally:
522
        trylock.Close()
523

    
524
      # Simulate a failed command
525
      return utils.RunResult(constants.EXIT_FAILURE, None,
526
                             "stdout", "stderr406328567",
527
                             utils.ShellQuoteArgs(args),
528
                             NotImplemented, NotImplemented)
529

    
530
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
531
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
532
    runcmd_fn = testutils.CallCounter(fn)
533

    
534
    try:
535
      backend.RunRestrictedCmd("test3079",
536
                               _lock_timeout=1.0, _lock_file=lockfile,
537
                               _path=self.tmpdir, _runcmd_fn=runcmd_fn,
538
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
539
                               _enabled=True)
540
    except backend.RPCFail, err:
541
      self.assertTrue(str(err).startswith("Restricted command 'test3079'"
542
                                          " failed:"))
543
      self.assertTrue("stderr406328567" in str(err),
544
                      msg="Error did not include output")
545
    else:
546
      self.fail("Didn't fail")
547

    
548
    self.assertEqual(sleep_fn.Count(), 0)
549
    self.assertEqual(prepare_fn.Count(), 1)
550
    self.assertEqual(runcmd_fn.Count(), 1)
551

    
552
  def testRunCmdSucceeds(self):
553
    lockfile = utils.PathJoin(self.tmpdir, "lock")
554

    
555
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
556
           postfork_fn=NotImplemented):
557
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test5667")])
558
      self.assertEqual(env, {})
559
      self.assertTrue(reset_env)
560

    
561
      # Call back to release lock
562
      postfork_fn(NotImplemented)
563

    
564
      # Simulate a successful command
565
      return utils.RunResult(constants.EXIT_SUCCESS, None, "stdout14463", "",
566
                             utils.ShellQuoteArgs(args),
567
                             NotImplemented, NotImplemented)
568

    
569
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
570
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
571
    runcmd_fn = testutils.CallCounter(fn)
572

    
573
    result = backend.RunRestrictedCmd("test5667",
574
                                      _lock_timeout=1.0, _lock_file=lockfile,
575
                                      _path=self.tmpdir, _runcmd_fn=runcmd_fn,
576
                                      _sleep_fn=sleep_fn,
577
                                      _prepare_fn=prepare_fn,
578
                                      _enabled=True)
579
    self.assertEqual(result, "stdout14463")
580

    
581
    self.assertEqual(sleep_fn.Count(), 0)
582
    self.assertEqual(prepare_fn.Count(), 1)
583
    self.assertEqual(runcmd_fn.Count(), 1)
584

    
585
  def testCommandsDisabled(self):
586
    try:
587
      backend.RunRestrictedCmd("test",
588
                               _lock_timeout=NotImplemented,
589
                               _lock_file=NotImplemented,
590
                               _path=NotImplemented,
591
                               _sleep_fn=NotImplemented,
592
                               _prepare_fn=NotImplemented,
593
                               _runcmd_fn=NotImplemented,
594
                               _enabled=False)
595
    except backend.RPCFail, err:
596
      self.assertEqual(str(err),
597
                       "Restricted commands disabled at configure time")
598
    else:
599
      self.fail("Did not raise exception")
600

    
601

    
602
class TestSetWatcherPause(unittest.TestCase):
603
  def setUp(self):
604
    self.tmpdir = tempfile.mkdtemp()
605
    self.filename = utils.PathJoin(self.tmpdir, "pause")
606

    
607
  def tearDown(self):
608
    shutil.rmtree(self.tmpdir)
609

    
610
  def testUnsetNonExisting(self):
611
    self.assertFalse(os.path.exists(self.filename))
612
    backend.SetWatcherPause(None, _filename=self.filename)
613
    self.assertFalse(os.path.exists(self.filename))
614

    
615
  def testSetNonNumeric(self):
616
    for i in ["", [], {}, "Hello World", "0", "1.0"]:
617
      self.assertFalse(os.path.exists(self.filename))
618

    
619
      try:
620
        backend.SetWatcherPause(i, _filename=self.filename)
621
      except backend.RPCFail, err:
622
        self.assertEqual(str(err), "Duration must be numeric")
623
      else:
624
        self.fail("Did not raise exception")
625

    
626
      self.assertFalse(os.path.exists(self.filename))
627

    
628
  def testSet(self):
629
    self.assertFalse(os.path.exists(self.filename))
630

    
631
    for i in range(10):
632
      backend.SetWatcherPause(i, _filename=self.filename)
633
      self.assertEqual(utils.ReadFile(self.filename), "%s\n" % i)
634
      self.assertEqual(os.stat(self.filename).st_mode & 0777, 0644)
635

    
636

    
637
class TestGetBlockDevSymlinkPath(unittest.TestCase):
638
  def setUp(self):
639
    self.tmpdir = tempfile.mkdtemp()
640

    
641
  def tearDown(self):
642
    shutil.rmtree(self.tmpdir)
643

    
644
  def _Test(self, name, idx):
645
    self.assertEqual(backend._GetBlockDevSymlinkPath(name, idx,
646
                                                     _dir=self.tmpdir),
647
                     ("%s/%s%s%s" % (self.tmpdir, name,
648
                                     constants.DISK_SEPARATOR, idx)))
649

    
650
  def test(self):
651
    for idx in range(100):
652
      self._Test("inst1.example.com", idx)
653

    
654

    
655
class TestGetInstanceList(unittest.TestCase):
656

    
657
  def setUp(self):
658
    self._test_hv = self._TestHypervisor()
659
    self._test_hv.ListInstances = mock.Mock(
660
      return_value=["instance1", "instance2", "instance3"] )
661

    
662
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
663
    def __init__(self):
664
      hypervisor.hv_base.BaseHypervisor.__init__(self)
665

    
666
  def _GetHypervisor(self, name):
667
    return self._test_hv
668

    
669
  def testHvparams(self):
670
    fake_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
671
    hvparams = {constants.HT_FAKE: fake_hvparams}
672
    backend.GetInstanceList([constants.HT_FAKE], all_hvparams=hvparams,
673
                            get_hv_fn=self._GetHypervisor)
674
    self._test_hv.ListInstances.assert_called_with(hvparams=fake_hvparams)
675

    
676

    
677
class TestInstanceConsoleInfo(unittest.TestCase):
678

    
679
  def setUp(self):
680
    self._test_hv_a = self._TestHypervisor()
681
    self._test_hv_a.GetInstanceConsole = mock.Mock(
682
      return_value = objects.InstanceConsole(instance="inst", kind="aHy")
683
    )
684
    self._test_hv_b = self._TestHypervisor()
685
    self._test_hv_b.GetInstanceConsole = mock.Mock(
686
      return_value = objects.InstanceConsole(instance="inst", kind="bHy")
687
    )
688

    
689
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
690
    def __init__(self):
691
      hypervisor.hv_base.BaseHypervisor.__init__(self)
692

    
693
  def _GetHypervisor(self, name):
694
    if name == "a":
695
      return self._test_hv_a
696
    else:
697
      return self._test_hv_b
698

    
699
  def testRightHypervisor(self):
700
    dictMaker = lambda hyName: {
701
      "instance":{"hypervisor":hyName},
702
      "node":{},
703
      "group":{},
704
      "hvParams":{},
705
      "beParams":{},
706
    }
707

    
708
    call = {
709
      'i1':dictMaker("a"),
710
      'i2':dictMaker("b"),
711
    }
712

    
713
    res = backend.GetInstanceConsoleInfo(call, get_hv_fn=self._GetHypervisor)
714

    
715
    self.assertTrue(res["i1"]["kind"] == "aHy")
716
    self.assertTrue(res["i2"]["kind"] == "bHy")
717

    
718

    
719
class TestGetHvInfo(unittest.TestCase):
720

    
721
  def setUp(self):
722
    self._test_hv = self._TestHypervisor()
723
    self._test_hv.GetNodeInfo = mock.Mock()
724

    
725
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
726
    def __init__(self):
727
      hypervisor.hv_base.BaseHypervisor.__init__(self)
728

    
729
  def _GetHypervisor(self, name):
730
    return self._test_hv
731

    
732
  def testGetHvInfoAllNone(self):
733
    result = backend._GetHvInfoAll(None)
734
    self.assertTrue(result is None)
735

    
736
  def testGetHvInfoAll(self):
737
    hvname = constants.HT_XEN_PVM
738
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
739
    hv_specs = [(hvname, hvparams)]
740

    
741
    backend._GetHvInfoAll(hv_specs, self._GetHypervisor)
742
    self._test_hv.GetNodeInfo.assert_called_with(hvparams=hvparams)
743

    
744

    
745
class TestApplyStorageInfoFunction(unittest.TestCase):
746

    
747
  _STORAGE_KEY = "some_key"
748
  _SOME_ARGS = ["some_args"]
749

    
750
  def setUp(self):
751
    self.mock_storage_fn = mock.Mock()
752

    
753
  def testApplyValidStorageType(self):
754
    storage_type = constants.ST_LVM_VG
755
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
756
    backend._STORAGE_TYPE_INFO_FN = {
757
        storage_type: self.mock_storage_fn
758
      }
759

    
760
    backend._ApplyStorageInfoFunction(
761
        storage_type, self._STORAGE_KEY, self._SOME_ARGS)
762

    
763
    self.mock_storage_fn.assert_called_with(self._STORAGE_KEY, self._SOME_ARGS)
764
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
765

    
766
  def testApplyInValidStorageType(self):
767
    storage_type = "invalid_storage_type"
768
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
769
    backend._STORAGE_TYPE_INFO_FN = {}
770

    
771
    self.assertRaises(KeyError, backend._ApplyStorageInfoFunction,
772
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
773
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
774

    
775
  def testApplyNotImplementedStorageType(self):
776
    storage_type = "not_implemented_storage_type"
777
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
778
    backend._STORAGE_TYPE_INFO_FN = {storage_type: None}
779

    
780
    self.assertRaises(NotImplementedError,
781
                      backend._ApplyStorageInfoFunction,
782
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
783
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
784

    
785

    
786
class TestGetLvmVgSpaceInfo(unittest.TestCase):
787

    
788
  def testValid(self):
789
    path = "somepath"
790
    excl_stor = True
791
    orig_fn = backend._GetVgInfo
792
    backend._GetVgInfo = mock.Mock()
793
    backend._GetLvmVgSpaceInfo(path, [excl_stor])
794
    backend._GetVgInfo.assert_called_with(path, excl_stor)
795
    backend._GetVgInfo = orig_fn
796

    
797
  def testNoExclStorageNotBool(self):
798
    path = "somepath"
799
    excl_stor = "123"
800
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
801
                      path, [excl_stor])
802

    
803
  def testNoExclStorageNotInList(self):
804
    path = "somepath"
805
    excl_stor = "123"
806
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
807
                      path, excl_stor)
808

    
809
class TestGetLvmPvSpaceInfo(unittest.TestCase):
810

    
811
  def testValid(self):
812
    path = "somepath"
813
    excl_stor = True
814
    orig_fn = backend._GetVgSpindlesInfo
815
    backend._GetVgSpindlesInfo = mock.Mock()
816
    backend._GetLvmPvSpaceInfo(path, [excl_stor])
817
    backend._GetVgSpindlesInfo.assert_called_with(path, excl_stor)
818
    backend._GetVgSpindlesInfo = orig_fn
819

    
820

    
821
class TestCheckStorageParams(unittest.TestCase):
822

    
823
  def testParamsNone(self):
824
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
825
                      None, NotImplemented)
826

    
827
  def testParamsWrongType(self):
828
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
829
                      "string", NotImplemented)
830

    
831
  def testParamsEmpty(self):
832
    backend._CheckStorageParams([], 0)
833

    
834
  def testParamsValidNumber(self):
835
    backend._CheckStorageParams(["a", True], 2)
836

    
837
  def testParamsInvalidNumber(self):
838
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
839
                      ["b", False], 3)
840

    
841

    
842
class TestGetVgSpindlesInfo(unittest.TestCase):
843

    
844
  def setUp(self):
845
    self.vg_free = 13
846
    self.vg_size = 31
847
    self.mock_fn = mock.Mock(return_value=(self.vg_free, self.vg_size))
848

    
849
  def testValidInput(self):
850
    name = "myvg"
851
    excl_stor = True
852
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
853
    self.mock_fn.assert_called_with(name)
854
    self.assertEqual(name, result["name"])
855
    self.assertEqual(constants.ST_LVM_PV, result["type"])
856
    self.assertEqual(self.vg_free, result["storage_free"])
857
    self.assertEqual(self.vg_size, result["storage_size"])
858

    
859
  def testNoExclStor(self):
860
    name = "myvg"
861
    excl_stor = False
862
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
863
    self.mock_fn.assert_not_called()
864
    self.assertEqual(name, result["name"])
865
    self.assertEqual(constants.ST_LVM_PV, result["type"])
866
    self.assertEqual(0, result["storage_free"])
867
    self.assertEqual(0, result["storage_size"])
868

    
869

    
870
class TestGetVgSpindlesInfo(unittest.TestCase):
871

    
872
  def testValidInput(self):
873
    self.vg_free = 13
874
    self.vg_size = 31
875
    self.mock_fn = mock.Mock(return_value=[(self.vg_free, self.vg_size)])
876
    name = "myvg"
877
    excl_stor = True
878
    result = backend._GetVgInfo(name, excl_stor, info_fn=self.mock_fn)
879
    self.mock_fn.assert_called_with([name], excl_stor)
880
    self.assertEqual(name, result["name"])
881
    self.assertEqual(constants.ST_LVM_VG, result["type"])
882
    self.assertEqual(self.vg_free, result["storage_free"])
883
    self.assertEqual(self.vg_size, result["storage_size"])
884

    
885
  def testNoExclStor(self):
886
    name = "myvg"
887
    excl_stor = True
888
    self.mock_fn = mock.Mock(return_value=None)
889
    result = backend._GetVgInfo(name, excl_stor, info_fn=self.mock_fn)
890
    self.mock_fn.assert_called_with([name], excl_stor)
891
    self.assertEqual(name, result["name"])
892
    self.assertEqual(constants.ST_LVM_VG, result["type"])
893
    self.assertEqual(None, result["storage_free"])
894
    self.assertEqual(None, result["storage_size"])
895

    
896

    
897
class TestGetNodeInfo(unittest.TestCase):
898

    
899
  _SOME_RESULT = None
900

    
901
  def testApplyStorageInfoFunction(self):
902
    orig_fn = backend._ApplyStorageInfoFunction
903
    backend._ApplyStorageInfoFunction = mock.Mock(
904
        return_value=self._SOME_RESULT)
905
    storage_units = [(st, st + "_key", [st + "_params"]) for st in
906
                     constants.STORAGE_TYPES]
907

    
908
    backend.GetNodeInfo(storage_units, None)
909

    
910
    call_args_list = backend._ApplyStorageInfoFunction.call_args_list
911
    self.assertEqual(len(constants.STORAGE_TYPES), len(call_args_list))
912
    for call in call_args_list:
913
      storage_type, storage_key, storage_params = call[0]
914
      self.assertEqual(storage_type + "_key", storage_key)
915
      self.assertEqual([storage_type + "_params"], storage_params)
916
      self.assertTrue(storage_type in constants.STORAGE_TYPES)
917
    backend._ApplyStorageInfoFunction = orig_fn
918

    
919

    
920
class TestSpaceReportingConstants(unittest.TestCase):
921
  """Ensures consistency between STS_REPORT and backend.
922

923
  These tests ensure, that the constant 'STS_REPORT' is consistent
924
  with the implementation of invoking space reporting functions
925
  in backend.py. Once space reporting is available for all types,
926
  the constant can be removed and these tests as well.
927

928
  """
929

    
930
  REPORTING = set(constants.STS_REPORT)
931
  NOT_REPORTING = set(constants.STORAGE_TYPES) - REPORTING
932

    
933
  def testAllReportingTypesHaveAReportingFunction(self):
934
    for storage_type in TestSpaceReportingConstants.REPORTING:
935
      self.assertTrue(backend._STORAGE_TYPE_INFO_FN[storage_type] is not None)
936

    
937
  def testAllNotReportingTypesDontHaveFunction(self):
938
    for storage_type in TestSpaceReportingConstants.NOT_REPORTING:
939
      self.assertEqual(None, backend._STORAGE_TYPE_INFO_FN[storage_type])
940

    
941

    
942
if __name__ == "__main__":
943
  testutils.GanetiTestProgram()