Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.backend_unittest.py @ a6c43c02

History | View | Annotate | Download (32.3 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010, 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.backend"""
23

    
24
import mock
25
import os
26
import shutil
27
import tempfile
28
import testutils
29
import unittest
30

    
31
from ganeti import backend
32
from ganeti import constants
33
from ganeti import errors
34
from ganeti import hypervisor
35
from ganeti import netutils
36
from ganeti import objects
37
from ganeti import pathutils
38
from ganeti import utils
39

    
40

    
41
class TestX509Certificates(unittest.TestCase):
42
  def setUp(self):
43
    self.tmpdir = tempfile.mkdtemp()
44

    
45
  def tearDown(self):
46
    shutil.rmtree(self.tmpdir)
47

    
48
  def test(self):
49
    (name, cert_pem) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
50

    
51
    self.assertEqual(utils.ReadFile(os.path.join(self.tmpdir, name,
52
                                                 backend._X509_CERT_FILE)),
53
                     cert_pem)
54
    self.assert_(0 < os.path.getsize(os.path.join(self.tmpdir, name,
55
                                                  backend._X509_KEY_FILE)))
56

    
57
    (name2, cert_pem2) = \
58
      backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
59

    
60
    backend.RemoveX509Certificate(name, cryptodir=self.tmpdir)
61
    backend.RemoveX509Certificate(name2, cryptodir=self.tmpdir)
62

    
63
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [])
64

    
65
  def testNonEmpty(self):
66
    (name, _) = backend.CreateX509Certificate(300, cryptodir=self.tmpdir)
67

    
68
    utils.WriteFile(utils.PathJoin(self.tmpdir, name, "hello-world"),
69
                    data="Hello World")
70

    
71
    self.assertRaises(backend.RPCFail, backend.RemoveX509Certificate,
72
                      name, cryptodir=self.tmpdir)
73

    
74
    self.assertEqual(utils.ListVisibleFiles(self.tmpdir), [name])
75

    
76

    
77
class TestGetCryptoTokens(testutils.GanetiTestCase):
78

    
79
  def setUp(self):
80
    self._get_digest_fn_orig = utils.GetCertificateDigest
81
    self._create_digest_fn_orig = utils.GenerateNewSslCert
82
    self._ssl_digest = "12345"
83
    utils.GetCertificateDigest = mock.Mock(
84
      return_value=self._ssl_digest)
85
    utils.GenerateNewSslCert = mock.Mock()
86

    
87
  def tearDown(self):
88
    utils.GetCertificateDigest = self._get_digest_fn_orig
89
    utils.GenerateNewSslCert = self._create_digest_fn_orig
90

    
91
  def testGetSslToken(self):
92
    result = backend.GetCryptoTokens(
93
      [(constants.CRYPTO_TYPE_SSL_DIGEST, constants.CRYPTO_ACTION_GET, None)])
94
    self.assertTrue((constants.CRYPTO_TYPE_SSL_DIGEST, self._ssl_digest)
95
                    in result)
96

    
97
  def testCreateSslToken(self):
98
    result = backend.GetCryptoTokens(
99
      [(constants.CRYPTO_TYPE_SSL_DIGEST, constants.CRYPTO_ACTION_CREATE,
100
        None)])
101
    self.assertTrue((constants.CRYPTO_TYPE_SSL_DIGEST, self._ssl_digest)
102
                    in result)
103
    self.assertTrue(utils.GenerateNewSslCert.assert_calls().once())
104

    
105
  def testCreateSslTokenDifferentFilename(self):
106
    result = backend.GetCryptoTokens(
107
      [(constants.CRYPTO_TYPE_SSL_DIGEST, constants.CRYPTO_ACTION_CREATE,
108
        {constants.CRYPTO_OPTION_CERT_FILE:
109
          pathutils.NODED_CLIENT_CERT_FILE_TMP})])
110
    self.assertTrue((constants.CRYPTO_TYPE_SSL_DIGEST, self._ssl_digest)
111
                    in result)
112
    self.assertTrue(utils.GenerateNewSslCert.assert_calls().once())
113

    
114
  def testUnknownTokenType(self):
115
    self.assertRaises(errors.ProgrammerError,
116
                      backend.GetCryptoTokens,
117
                      [("pink_bunny", constants.CRYPTO_ACTION_GET, None)])
118

    
119
  def testUnknownAction(self):
120
    self.assertRaises(errors.ProgrammerError,
121
                      backend.GetCryptoTokens,
122
                      [(constants.CRYPTO_TYPE_SSL_DIGEST, "illuminate", None)])
123

    
124

    
125
class TestNodeVerify(testutils.GanetiTestCase):
126

    
127
  def setUp(self):
128
    testutils.GanetiTestCase.setUp(self)
129
    self._mock_hv = None
130

    
131
  def _GetHypervisor(self, hv_name):
132
    self._mock_hv = hypervisor.GetHypervisor(hv_name)
133
    self._mock_hv.ValidateParameters = mock.Mock()
134
    self._mock_hv.Verify = mock.Mock()
135
    return self._mock_hv
136

    
137
  def testMasterIPLocalhost(self):
138
    # this a real functional test, but requires localhost to be reachable
139
    local_data = (netutils.Hostname.GetSysName(),
140
                  constants.IP4_ADDRESS_LOCALHOST)
141
    result = backend.VerifyNode({constants.NV_MASTERIP: local_data},
142
                                None, {}, {}, {})
143
    self.failUnless(constants.NV_MASTERIP in result,
144
                    "Master IP data not returned")
145
    self.failUnless(result[constants.NV_MASTERIP], "Cannot reach localhost")
146

    
147
  def testMasterIPUnreachable(self):
148
    # Network 192.0.2.0/24 is reserved for test/documentation as per
149
    # RFC 5737
150
    bad_data =  ("master.example.com", "192.0.2.1")
151
    # we just test that whatever TcpPing returns, VerifyNode returns too
152
    netutils.TcpPing = lambda a, b, source=None: False
153
    result = backend.VerifyNode({constants.NV_MASTERIP: bad_data},
154
                                None, {}, {}, {})
155
    self.failUnless(constants.NV_MASTERIP in result,
156
                    "Master IP data not returned")
157
    self.failIf(result[constants.NV_MASTERIP],
158
                "Result from netutils.TcpPing corrupted")
159

    
160
  def testVerifyHvparams(self):
161
    test_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
162
    test_what = {constants.NV_HVPARAMS: \
163
        [("mynode", constants.HT_XEN_PVM, test_hvparams)]}
164
    result = {}
165
    backend._VerifyHvparams(test_what, True, result,
166
                            get_hv_fn=self._GetHypervisor)
167
    self._mock_hv.ValidateParameters.assert_called_with(test_hvparams)
168

    
169
  def testVerifyHypervisors(self):
170
    hvname = constants.HT_XEN_PVM
171
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
172
    all_hvparams = {hvname: hvparams}
173
    test_what = {constants.NV_HYPERVISOR: [hvname]}
174
    result = {}
175
    backend._VerifyHypervisors(
176
        test_what, True, result, all_hvparams=all_hvparams,
177
        get_hv_fn=self._GetHypervisor)
178
    self._mock_hv.Verify.assert_called_with(hvparams=hvparams)
179

    
180
  @testutils.patch_object(utils, "VerifyCertificate")
181
  def testVerifyClientCertificateSuccess(self, verif_cert):
182
    # mock the underlying x509 verification because the test cert is expired
183
    verif_cert.return_value = (None, None)
184
    cert_file = testutils.TestDataFilename("cert2.pem")
185
    (errcode, digest) = backend._VerifyClientCertificate(cert_file=cert_file)
186
    self.assertEqual(None, errcode)
187
    self.assertTrue(isinstance(digest, str))
188

    
189
  @testutils.patch_object(utils, "VerifyCertificate")
190
  def testVerifyClientCertificateFailed(self, verif_cert):
191
    expected_errcode = 666
192
    verif_cert.return_value = (expected_errcode,
193
                               "The devil created this certificate.")
194
    cert_file = testutils.TestDataFilename("cert2.pem")
195
    (errcode, digest) = backend._VerifyClientCertificate(cert_file=cert_file)
196
    self.assertEqual(expected_errcode, errcode)
197

    
198
  def testVerifyClientCertificateNoCert(self):
199
    cert_file = testutils.TestDataFilename("cert-that-does-not-exist.pem")
200
    (errcode, digest) = backend._VerifyClientCertificate(cert_file=cert_file)
201
    self.assertEqual(constants.CV_ERROR, errcode)
202

    
203

    
204
def _DefRestrictedCmdOwner():
205
  return (os.getuid(), os.getgid())
206

    
207

    
208
class TestVerifyRestrictedCmdName(unittest.TestCase):
209
  def testAcceptableName(self):
210
    for i in ["foo", "bar", "z1", "000first", "hello-world"]:
211
      for fn in [lambda s: s, lambda s: s.upper(), lambda s: s.title()]:
212
        (status, msg) = backend._VerifyRestrictedCmdName(fn(i))
213
        self.assertTrue(status)
214
        self.assertTrue(msg is None)
215

    
216
  def testEmptyAndSpace(self):
217
    for i in ["", " ", "\t", "\n"]:
218
      (status, msg) = backend._VerifyRestrictedCmdName(i)
219
      self.assertFalse(status)
220
      self.assertEqual(msg, "Missing command name")
221

    
222
  def testNameWithSlashes(self):
223
    for i in ["/", "./foo", "../moo", "some/name"]:
224
      (status, msg) = backend._VerifyRestrictedCmdName(i)
225
      self.assertFalse(status)
226
      self.assertEqual(msg, "Invalid command name")
227

    
228
  def testForbiddenCharacters(self):
229
    for i in ["#", ".", "..", "bash -c ls", "'"]:
230
      (status, msg) = backend._VerifyRestrictedCmdName(i)
231
      self.assertFalse(status)
232
      self.assertEqual(msg, "Command name contains forbidden characters")
233

    
234

    
235
class TestVerifyRestrictedCmdDirectory(unittest.TestCase):
236
  def setUp(self):
237
    self.tmpdir = tempfile.mkdtemp()
238

    
239
  def tearDown(self):
240
    shutil.rmtree(self.tmpdir)
241

    
242
  def testCanNotStat(self):
243
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
244
    self.assertFalse(os.path.exists(tmpname))
245
    (status, msg) = \
246
      backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
247
    self.assertFalse(status)
248
    self.assertTrue(msg.startswith("Can't stat(2) '"))
249

    
250
  def testTooPermissive(self):
251
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
252
    os.mkdir(tmpname)
253

    
254
    for mode in [0777, 0706, 0760, 0722]:
255
      os.chmod(tmpname, mode)
256
      self.assertTrue(os.path.isdir(tmpname))
257
      (status, msg) = \
258
        backend._VerifyRestrictedCmdDirectory(tmpname, _owner=NotImplemented)
259
      self.assertFalse(status)
260
      self.assertTrue(msg.startswith("Permissions on '"))
261

    
262
  def testNoDirectory(self):
263
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
264
    utils.WriteFile(tmpname, data="empty\n")
265
    self.assertTrue(os.path.isfile(tmpname))
266
    (status, msg) = \
267
      backend._VerifyRestrictedCmdDirectory(tmpname,
268
                                            _owner=_DefRestrictedCmdOwner())
269
    self.assertFalse(status)
270
    self.assertTrue(msg.endswith("is not a directory"))
271

    
272
  def testNormal(self):
273
    tmpname = utils.PathJoin(self.tmpdir, "foobar")
274
    os.mkdir(tmpname)
275
    os.chmod(tmpname, 0755)
276
    self.assertTrue(os.path.isdir(tmpname))
277
    (status, msg) = \
278
      backend._VerifyRestrictedCmdDirectory(tmpname,
279
                                            _owner=_DefRestrictedCmdOwner())
280
    self.assertTrue(status)
281
    self.assertTrue(msg is None)
282

    
283

    
284
class TestVerifyRestrictedCmd(unittest.TestCase):
285
  def setUp(self):
286
    self.tmpdir = tempfile.mkdtemp()
287

    
288
  def tearDown(self):
289
    shutil.rmtree(self.tmpdir)
290

    
291
  def testCanNotStat(self):
292
    tmpname = utils.PathJoin(self.tmpdir, "helloworld")
293
    self.assertFalse(os.path.exists(tmpname))
294
    (status, msg) = \
295
      backend._VerifyRestrictedCmd(self.tmpdir, "helloworld",
296
                                   _owner=NotImplemented)
297
    self.assertFalse(status)
298
    self.assertTrue(msg.startswith("Can't stat(2) '"))
299

    
300
  def testNotExecutable(self):
301
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
302
    utils.WriteFile(tmpname, data="empty\n")
303
    (status, msg) = \
304
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
305
                                   _owner=_DefRestrictedCmdOwner())
306
    self.assertFalse(status)
307
    self.assertTrue(msg.startswith("access(2) thinks '"))
308

    
309
  def testExecutable(self):
310
    tmpname = utils.PathJoin(self.tmpdir, "cmdname")
311
    utils.WriteFile(tmpname, data="empty\n", mode=0700)
312
    (status, executable) = \
313
      backend._VerifyRestrictedCmd(self.tmpdir, "cmdname",
314
                                   _owner=_DefRestrictedCmdOwner())
315
    self.assertTrue(status)
316
    self.assertEqual(executable, tmpname)
317

    
318

    
319
class TestPrepareRestrictedCmd(unittest.TestCase):
320
  _TEST_PATH = "/tmp/some/test/path"
321

    
322
  def testDirFails(self):
323
    def fn(path):
324
      self.assertEqual(path, self._TEST_PATH)
325
      return (False, "test error 31420")
326

    
327
    (status, msg) = \
328
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd21152",
329
                                    _verify_dir=fn,
330
                                    _verify_name=NotImplemented,
331
                                    _verify_cmd=NotImplemented)
332
    self.assertFalse(status)
333
    self.assertEqual(msg, "test error 31420")
334

    
335
  def testNameFails(self):
336
    def fn(cmd):
337
      self.assertEqual(cmd, "cmd4617")
338
      return (False, "test error 591")
339

    
340
    (status, msg) = \
341
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd4617",
342
                                    _verify_dir=lambda _: (True, None),
343
                                    _verify_name=fn,
344
                                    _verify_cmd=NotImplemented)
345
    self.assertFalse(status)
346
    self.assertEqual(msg, "test error 591")
347

    
348
  def testCommandFails(self):
349
    def fn(path, cmd):
350
      self.assertEqual(path, self._TEST_PATH)
351
      self.assertEqual(cmd, "cmd17577")
352
      return (False, "test error 25524")
353

    
354
    (status, msg) = \
355
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd17577",
356
                                    _verify_dir=lambda _: (True, None),
357
                                    _verify_name=lambda _: (True, None),
358
                                    _verify_cmd=fn)
359
    self.assertFalse(status)
360
    self.assertEqual(msg, "test error 25524")
361

    
362
  def testSuccess(self):
363
    def fn(path, cmd):
364
      return (True, utils.PathJoin(path, cmd))
365

    
366
    (status, executable) = \
367
      backend._PrepareRestrictedCmd(self._TEST_PATH, "cmd22633",
368
                                    _verify_dir=lambda _: (True, None),
369
                                    _verify_name=lambda _: (True, None),
370
                                    _verify_cmd=fn)
371
    self.assertTrue(status)
372
    self.assertEqual(executable, utils.PathJoin(self._TEST_PATH, "cmd22633"))
373

    
374

    
375
def _SleepForRestrictedCmd(duration):
376
  assert duration > 5
377

    
378

    
379
def _GenericRestrictedCmdError(cmd):
380
  return "Executing command '%s' failed" % cmd
381

    
382

    
383
class TestRunRestrictedCmd(unittest.TestCase):
384
  def setUp(self):
385
    self.tmpdir = tempfile.mkdtemp()
386

    
387
  def tearDown(self):
388
    shutil.rmtree(self.tmpdir)
389

    
390
  def testNonExistantLockDirectory(self):
391
    lockfile = utils.PathJoin(self.tmpdir, "does", "not", "exist")
392
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
393
    self.assertFalse(os.path.exists(lockfile))
394
    self.assertRaises(backend.RPCFail,
395
                      backend.RunRestrictedCmd, "test",
396
                      _lock_timeout=NotImplemented,
397
                      _lock_file=lockfile,
398
                      _path=NotImplemented,
399
                      _sleep_fn=sleep_fn,
400
                      _prepare_fn=NotImplemented,
401
                      _runcmd_fn=NotImplemented,
402
                      _enabled=True)
403
    self.assertEqual(sleep_fn.Count(), 1)
404

    
405
  @staticmethod
406
  def _TryLock(lockfile):
407
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
408

    
409
    result = False
410
    try:
411
      backend.RunRestrictedCmd("test22717",
412
                               _lock_timeout=0.1,
413
                               _lock_file=lockfile,
414
                               _path=NotImplemented,
415
                               _sleep_fn=sleep_fn,
416
                               _prepare_fn=NotImplemented,
417
                               _runcmd_fn=NotImplemented,
418
                               _enabled=True)
419
    except backend.RPCFail, err:
420
      assert str(err) == _GenericRestrictedCmdError("test22717"), \
421
             "Did not fail with generic error message"
422
      result = True
423

    
424
    assert sleep_fn.Count() == 1
425

    
426
    return result
427

    
428
  def testLockHeldByOtherProcess(self):
429
    lockfile = utils.PathJoin(self.tmpdir, "lock")
430

    
431
    lock = utils.FileLock.Open(lockfile)
432
    lock.Exclusive(blocking=True, timeout=1.0)
433
    try:
434
      self.assertTrue(utils.RunInSeparateProcess(self._TryLock, lockfile))
435
    finally:
436
      lock.Close()
437

    
438
  @staticmethod
439
  def _PrepareRaisingException(path, cmd):
440
    assert cmd == "test23122"
441
    raise Exception("test")
442

    
443
  def testPrepareRaisesException(self):
444
    lockfile = utils.PathJoin(self.tmpdir, "lock")
445

    
446
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
447
    prepare_fn = testutils.CallCounter(self._PrepareRaisingException)
448

    
449
    try:
450
      backend.RunRestrictedCmd("test23122",
451
                               _lock_timeout=1.0, _lock_file=lockfile,
452
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
453
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
454
                               _enabled=True)
455
    except backend.RPCFail, err:
456
      self.assertEqual(str(err), _GenericRestrictedCmdError("test23122"))
457
    else:
458
      self.fail("Didn't fail")
459

    
460
    self.assertEqual(sleep_fn.Count(), 1)
461
    self.assertEqual(prepare_fn.Count(), 1)
462

    
463
  @staticmethod
464
  def _PrepareFails(path, cmd):
465
    assert cmd == "test29327"
466
    return ("some error message", None)
467

    
468
  def testPrepareFails(self):
469
    lockfile = utils.PathJoin(self.tmpdir, "lock")
470

    
471
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
472
    prepare_fn = testutils.CallCounter(self._PrepareFails)
473

    
474
    try:
475
      backend.RunRestrictedCmd("test29327",
476
                               _lock_timeout=1.0, _lock_file=lockfile,
477
                               _path=NotImplemented, _runcmd_fn=NotImplemented,
478
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
479
                               _enabled=True)
480
    except backend.RPCFail, err:
481
      self.assertEqual(str(err), _GenericRestrictedCmdError("test29327"))
482
    else:
483
      self.fail("Didn't fail")
484

    
485
    self.assertEqual(sleep_fn.Count(), 1)
486
    self.assertEqual(prepare_fn.Count(), 1)
487

    
488
  @staticmethod
489
  def _SuccessfulPrepare(path, cmd):
490
    return (True, utils.PathJoin(path, cmd))
491

    
492
  def testRunCmdFails(self):
493
    lockfile = utils.PathJoin(self.tmpdir, "lock")
494

    
495
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
496
           postfork_fn=NotImplemented):
497
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test3079")])
498
      self.assertEqual(env, {})
499
      self.assertTrue(reset_env)
500
      self.assertTrue(callable(postfork_fn))
501

    
502
      trylock = utils.FileLock.Open(lockfile)
503
      try:
504
        # See if lockfile is still held
505
        self.assertRaises(EnvironmentError, trylock.Exclusive, blocking=False)
506

    
507
        # Call back to release lock
508
        postfork_fn(NotImplemented)
509

    
510
        # See if lockfile can be acquired
511
        trylock.Exclusive(blocking=False)
512
      finally:
513
        trylock.Close()
514

    
515
      # Simulate a failed command
516
      return utils.RunResult(constants.EXIT_FAILURE, None,
517
                             "stdout", "stderr406328567",
518
                             utils.ShellQuoteArgs(args),
519
                             NotImplemented, NotImplemented)
520

    
521
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
522
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
523
    runcmd_fn = testutils.CallCounter(fn)
524

    
525
    try:
526
      backend.RunRestrictedCmd("test3079",
527
                               _lock_timeout=1.0, _lock_file=lockfile,
528
                               _path=self.tmpdir, _runcmd_fn=runcmd_fn,
529
                               _sleep_fn=sleep_fn, _prepare_fn=prepare_fn,
530
                               _enabled=True)
531
    except backend.RPCFail, err:
532
      self.assertTrue(str(err).startswith("Restricted command 'test3079'"
533
                                          " failed:"))
534
      self.assertTrue("stderr406328567" in str(err),
535
                      msg="Error did not include output")
536
    else:
537
      self.fail("Didn't fail")
538

    
539
    self.assertEqual(sleep_fn.Count(), 0)
540
    self.assertEqual(prepare_fn.Count(), 1)
541
    self.assertEqual(runcmd_fn.Count(), 1)
542

    
543
  def testRunCmdSucceeds(self):
544
    lockfile = utils.PathJoin(self.tmpdir, "lock")
545

    
546
    def fn(args, env=NotImplemented, reset_env=NotImplemented,
547
           postfork_fn=NotImplemented):
548
      self.assertEqual(args, [utils.PathJoin(self.tmpdir, "test5667")])
549
      self.assertEqual(env, {})
550
      self.assertTrue(reset_env)
551

    
552
      # Call back to release lock
553
      postfork_fn(NotImplemented)
554

    
555
      # Simulate a successful command
556
      return utils.RunResult(constants.EXIT_SUCCESS, None, "stdout14463", "",
557
                             utils.ShellQuoteArgs(args),
558
                             NotImplemented, NotImplemented)
559

    
560
    sleep_fn = testutils.CallCounter(_SleepForRestrictedCmd)
561
    prepare_fn = testutils.CallCounter(self._SuccessfulPrepare)
562
    runcmd_fn = testutils.CallCounter(fn)
563

    
564
    result = backend.RunRestrictedCmd("test5667",
565
                                      _lock_timeout=1.0, _lock_file=lockfile,
566
                                      _path=self.tmpdir, _runcmd_fn=runcmd_fn,
567
                                      _sleep_fn=sleep_fn,
568
                                      _prepare_fn=prepare_fn,
569
                                      _enabled=True)
570
    self.assertEqual(result, "stdout14463")
571

    
572
    self.assertEqual(sleep_fn.Count(), 0)
573
    self.assertEqual(prepare_fn.Count(), 1)
574
    self.assertEqual(runcmd_fn.Count(), 1)
575

    
576
  def testCommandsDisabled(self):
577
    try:
578
      backend.RunRestrictedCmd("test",
579
                               _lock_timeout=NotImplemented,
580
                               _lock_file=NotImplemented,
581
                               _path=NotImplemented,
582
                               _sleep_fn=NotImplemented,
583
                               _prepare_fn=NotImplemented,
584
                               _runcmd_fn=NotImplemented,
585
                               _enabled=False)
586
    except backend.RPCFail, err:
587
      self.assertEqual(str(err),
588
                       "Restricted commands disabled at configure time")
589
    else:
590
      self.fail("Did not raise exception")
591

    
592

    
593
class TestSetWatcherPause(unittest.TestCase):
594
  def setUp(self):
595
    self.tmpdir = tempfile.mkdtemp()
596
    self.filename = utils.PathJoin(self.tmpdir, "pause")
597

    
598
  def tearDown(self):
599
    shutil.rmtree(self.tmpdir)
600

    
601
  def testUnsetNonExisting(self):
602
    self.assertFalse(os.path.exists(self.filename))
603
    backend.SetWatcherPause(None, _filename=self.filename)
604
    self.assertFalse(os.path.exists(self.filename))
605

    
606
  def testSetNonNumeric(self):
607
    for i in ["", [], {}, "Hello World", "0", "1.0"]:
608
      self.assertFalse(os.path.exists(self.filename))
609

    
610
      try:
611
        backend.SetWatcherPause(i, _filename=self.filename)
612
      except backend.RPCFail, err:
613
        self.assertEqual(str(err), "Duration must be numeric")
614
      else:
615
        self.fail("Did not raise exception")
616

    
617
      self.assertFalse(os.path.exists(self.filename))
618

    
619
  def testSet(self):
620
    self.assertFalse(os.path.exists(self.filename))
621

    
622
    for i in range(10):
623
      backend.SetWatcherPause(i, _filename=self.filename)
624
      self.assertEqual(utils.ReadFile(self.filename), "%s\n" % i)
625
      self.assertEqual(os.stat(self.filename).st_mode & 0777, 0644)
626

    
627

    
628
class TestGetBlockDevSymlinkPath(unittest.TestCase):
629
  def setUp(self):
630
    self.tmpdir = tempfile.mkdtemp()
631

    
632
  def tearDown(self):
633
    shutil.rmtree(self.tmpdir)
634

    
635
  def _Test(self, name, idx):
636
    self.assertEqual(backend._GetBlockDevSymlinkPath(name, idx,
637
                                                     _dir=self.tmpdir),
638
                     ("%s/%s%s%s" % (self.tmpdir, name,
639
                                     constants.DISK_SEPARATOR, idx)))
640

    
641
  def test(self):
642
    for idx in range(100):
643
      self._Test("inst1.example.com", idx)
644

    
645

    
646
class TestGetInstanceList(unittest.TestCase):
647

    
648
  def setUp(self):
649
    self._test_hv = self._TestHypervisor()
650
    self._test_hv.ListInstances = mock.Mock(
651
      return_value=["instance1", "instance2", "instance3"] )
652

    
653
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
654
    def __init__(self):
655
      hypervisor.hv_base.BaseHypervisor.__init__(self)
656

    
657
  def _GetHypervisor(self, name):
658
    return self._test_hv
659

    
660
  def testHvparams(self):
661
    fake_hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
662
    hvparams = {constants.HT_FAKE: fake_hvparams}
663
    backend.GetInstanceList([constants.HT_FAKE], all_hvparams=hvparams,
664
                            get_hv_fn=self._GetHypervisor)
665
    self._test_hv.ListInstances.assert_called_with(hvparams=fake_hvparams)
666

    
667

    
668
class TestInstanceConsoleInfo(unittest.TestCase):
669

    
670
  def setUp(self):
671
    self._test_hv_a = self._TestHypervisor()
672
    self._test_hv_a.GetInstanceConsole = mock.Mock(
673
      return_value = objects.InstanceConsole(instance="inst", kind="aHy")
674
    )
675
    self._test_hv_b = self._TestHypervisor()
676
    self._test_hv_b.GetInstanceConsole = mock.Mock(
677
      return_value = objects.InstanceConsole(instance="inst", kind="bHy")
678
    )
679

    
680
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
681
    def __init__(self):
682
      hypervisor.hv_base.BaseHypervisor.__init__(self)
683

    
684
  def _GetHypervisor(self, name):
685
    if name == "a":
686
      return self._test_hv_a
687
    else:
688
      return self._test_hv_b
689

    
690
  def testRightHypervisor(self):
691
    dictMaker = lambda hyName: {
692
      "instance":{"hypervisor":hyName},
693
      "node":{},
694
      "group":{},
695
      "hvParams":{},
696
      "beParams":{},
697
    }
698

    
699
    call = {
700
      'i1':dictMaker("a"),
701
      'i2':dictMaker("b"),
702
    }
703

    
704
    res = backend.GetInstanceConsoleInfo(call, get_hv_fn=self._GetHypervisor)
705

    
706
    self.assertTrue(res["i1"]["kind"] == "aHy")
707
    self.assertTrue(res["i2"]["kind"] == "bHy")
708

    
709

    
710
class TestGetHvInfo(unittest.TestCase):
711

    
712
  def setUp(self):
713
    self._test_hv = self._TestHypervisor()
714
    self._test_hv.GetNodeInfo = mock.Mock()
715

    
716
  class _TestHypervisor(hypervisor.hv_base.BaseHypervisor):
717
    def __init__(self):
718
      hypervisor.hv_base.BaseHypervisor.__init__(self)
719

    
720
  def _GetHypervisor(self, name):
721
    return self._test_hv
722

    
723
  def testGetHvInfoAllNone(self):
724
    result = backend._GetHvInfoAll(None)
725
    self.assertTrue(result is None)
726

    
727
  def testGetHvInfoAll(self):
728
    hvname = constants.HT_XEN_PVM
729
    hvparams = {constants.HV_XEN_CMD: constants.XEN_CMD_XL}
730
    hv_specs = [(hvname, hvparams)]
731

    
732
    backend._GetHvInfoAll(hv_specs, self._GetHypervisor)
733
    self._test_hv.GetNodeInfo.assert_called_with(hvparams=hvparams)
734

    
735

    
736
class TestApplyStorageInfoFunction(unittest.TestCase):
737

    
738
  _STORAGE_KEY = "some_key"
739
  _SOME_ARGS = ["some_args"]
740

    
741
  def setUp(self):
742
    self.mock_storage_fn = mock.Mock()
743

    
744
  def testApplyValidStorageType(self):
745
    storage_type = constants.ST_LVM_VG
746
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
747
    backend._STORAGE_TYPE_INFO_FN = {
748
        storage_type: self.mock_storage_fn
749
      }
750

    
751
    backend._ApplyStorageInfoFunction(
752
        storage_type, self._STORAGE_KEY, self._SOME_ARGS)
753

    
754
    self.mock_storage_fn.assert_called_with(self._STORAGE_KEY, self._SOME_ARGS)
755
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
756

    
757
  def testApplyInValidStorageType(self):
758
    storage_type = "invalid_storage_type"
759
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
760
    backend._STORAGE_TYPE_INFO_FN = {}
761

    
762
    self.assertRaises(KeyError, backend._ApplyStorageInfoFunction,
763
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
764
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
765

    
766
  def testApplyNotImplementedStorageType(self):
767
    storage_type = "not_implemented_storage_type"
768
    info_fn_orig = backend._STORAGE_TYPE_INFO_FN
769
    backend._STORAGE_TYPE_INFO_FN = {storage_type: None}
770

    
771
    self.assertRaises(NotImplementedError,
772
                      backend._ApplyStorageInfoFunction,
773
                      storage_type, self._STORAGE_KEY, self._SOME_ARGS)
774
    backend._STORAGE_TYPE_INFO_FN = info_fn_orig
775

    
776

    
777
class TestGetLvmVgSpaceInfo(unittest.TestCase):
778

    
779
  def testValid(self):
780
    path = "somepath"
781
    excl_stor = True
782
    orig_fn = backend._GetVgInfo
783
    backend._GetVgInfo = mock.Mock()
784
    backend._GetLvmVgSpaceInfo(path, [excl_stor])
785
    backend._GetVgInfo.assert_called_with(path, excl_stor)
786
    backend._GetVgInfo = orig_fn
787

    
788
  def testNoExclStorageNotBool(self):
789
    path = "somepath"
790
    excl_stor = "123"
791
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
792
                      path, [excl_stor])
793

    
794
  def testNoExclStorageNotInList(self):
795
    path = "somepath"
796
    excl_stor = "123"
797
    self.assertRaises(errors.ProgrammerError, backend._GetLvmVgSpaceInfo,
798
                      path, excl_stor)
799

    
800
class TestGetLvmPvSpaceInfo(unittest.TestCase):
801

    
802
  def testValid(self):
803
    path = "somepath"
804
    excl_stor = True
805
    orig_fn = backend._GetVgSpindlesInfo
806
    backend._GetVgSpindlesInfo = mock.Mock()
807
    backend._GetLvmPvSpaceInfo(path, [excl_stor])
808
    backend._GetVgSpindlesInfo.assert_called_with(path, excl_stor)
809
    backend._GetVgSpindlesInfo = orig_fn
810

    
811

    
812
class TestCheckStorageParams(unittest.TestCase):
813

    
814
  def testParamsNone(self):
815
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
816
                      None, NotImplemented)
817

    
818
  def testParamsWrongType(self):
819
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
820
                      "string", NotImplemented)
821

    
822
  def testParamsEmpty(self):
823
    backend._CheckStorageParams([], 0)
824

    
825
  def testParamsValidNumber(self):
826
    backend._CheckStorageParams(["a", True], 2)
827

    
828
  def testParamsInvalidNumber(self):
829
    self.assertRaises(errors.ProgrammerError, backend._CheckStorageParams,
830
                      ["b", False], 3)
831

    
832

    
833
class TestGetVgSpindlesInfo(unittest.TestCase):
834

    
835
  def setUp(self):
836
    self.vg_free = 13
837
    self.vg_size = 31
838
    self.mock_fn = mock.Mock(return_value=(self.vg_free, self.vg_size))
839

    
840
  def testValidInput(self):
841
    name = "myvg"
842
    excl_stor = True
843
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
844
    self.mock_fn.assert_called_with(name)
845
    self.assertEqual(name, result["name"])
846
    self.assertEqual(constants.ST_LVM_PV, result["type"])
847
    self.assertEqual(self.vg_free, result["storage_free"])
848
    self.assertEqual(self.vg_size, result["storage_size"])
849

    
850
  def testNoExclStor(self):
851
    name = "myvg"
852
    excl_stor = False
853
    result = backend._GetVgSpindlesInfo(name, excl_stor, info_fn=self.mock_fn)
854
    self.mock_fn.assert_not_called()
855
    self.assertEqual(name, result["name"])
856
    self.assertEqual(constants.ST_LVM_PV, result["type"])
857
    self.assertEqual(0, result["storage_free"])
858
    self.assertEqual(0, result["storage_size"])
859

    
860

    
861
class TestGetVgSpindlesInfo(unittest.TestCase):
862

    
863
  def testValidInput(self):
864
    self.vg_free = 13
865
    self.vg_size = 31
866
    self.mock_fn = mock.Mock(return_value=[(self.vg_free, self.vg_size)])
867
    name = "myvg"
868
    excl_stor = True
869
    result = backend._GetVgInfo(name, excl_stor, info_fn=self.mock_fn)
870
    self.mock_fn.assert_called_with([name], excl_stor)
871
    self.assertEqual(name, result["name"])
872
    self.assertEqual(constants.ST_LVM_VG, result["type"])
873
    self.assertEqual(self.vg_free, result["storage_free"])
874
    self.assertEqual(self.vg_size, result["storage_size"])
875

    
876
  def testNoExclStor(self):
877
    name = "myvg"
878
    excl_stor = True
879
    self.mock_fn = mock.Mock(return_value=None)
880
    result = backend._GetVgInfo(name, excl_stor, info_fn=self.mock_fn)
881
    self.mock_fn.assert_called_with([name], excl_stor)
882
    self.assertEqual(name, result["name"])
883
    self.assertEqual(constants.ST_LVM_VG, result["type"])
884
    self.assertEqual(None, result["storage_free"])
885
    self.assertEqual(None, result["storage_size"])
886

    
887

    
888
class TestGetNodeInfo(unittest.TestCase):
889

    
890
  _SOME_RESULT = None
891

    
892
  def testApplyStorageInfoFunction(self):
893
    orig_fn = backend._ApplyStorageInfoFunction
894
    backend._ApplyStorageInfoFunction = mock.Mock(
895
        return_value=self._SOME_RESULT)
896
    storage_units = [(st, st + "_key", [st + "_params"]) for st in
897
                     constants.STORAGE_TYPES]
898

    
899
    backend.GetNodeInfo(storage_units, None)
900

    
901
    call_args_list = backend._ApplyStorageInfoFunction.call_args_list
902
    self.assertEqual(len(constants.STORAGE_TYPES), len(call_args_list))
903
    for call in call_args_list:
904
      storage_type, storage_key, storage_params = call[0]
905
      self.assertEqual(storage_type + "_key", storage_key)
906
      self.assertEqual([storage_type + "_params"], storage_params)
907
      self.assertTrue(storage_type in constants.STORAGE_TYPES)
908
    backend._ApplyStorageInfoFunction = orig_fn
909

    
910

    
911
class TestSpaceReportingConstants(unittest.TestCase):
912
  """Ensures consistency between STS_REPORT and backend.
913

914
  These tests ensure, that the constant 'STS_REPORT' is consistent
915
  with the implementation of invoking space reporting functions
916
  in backend.py. Once space reporting is available for all types,
917
  the constant can be removed and these tests as well.
918

919
  """
920

    
921
  REPORTING = set(constants.STS_REPORT)
922
  NOT_REPORTING = set(constants.STORAGE_TYPES) - REPORTING
923

    
924
  def testAllReportingTypesHaveAReportingFunction(self):
925
    for storage_type in TestSpaceReportingConstants.REPORTING:
926
      self.assertTrue(backend._STORAGE_TYPE_INFO_FN[storage_type] is not None)
927

    
928
  def testAllNotReportingTypesDontHaveFunction(self):
929
    for storage_type in TestSpaceReportingConstants.NOT_REPORTING:
930
      self.assertEqual(None, backend._STORAGE_TYPE_INFO_FN[storage_type])
931

    
932

    
933
if __name__ == "__main__":
934
  testutils.GanetiTestProgram()