Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ 7b4baeb1

History | View | Annotate | Download (68.7 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import distutils.version
25
import errno
26
import fcntl
27
import glob
28
import os
29
import os.path
30
import re
31
import shutil
32
import signal
33
import socket
34
import stat
35
import string
36
import tempfile
37
import time
38
import unittest
39
import warnings
40
import OpenSSL
41
import random
42
import operator
43

    
44
import testutils
45
from ganeti import constants
46
from ganeti import compat
47
from ganeti import utils
48
from ganeti import errors
49
from ganeti.utils import RunCmd, RemoveFile, \
50
     ListVisibleFiles, FirstFree, \
51
     TailFile, RunParts, PathJoin, \
52
     ReadOneLineFile, SetEtcHostsEntry, RemoveEtcHostsEntry
53

    
54

    
55
class TestIsProcessAlive(unittest.TestCase):
56
  """Testing case for IsProcessAlive"""
57

    
58
  def testExists(self):
59
    mypid = os.getpid()
60
    self.assert_(utils.IsProcessAlive(mypid), "can't find myself running")
61

    
62
  def testNotExisting(self):
63
    pid_non_existing = os.fork()
64
    if pid_non_existing == 0:
65
      os._exit(0)
66
    elif pid_non_existing < 0:
67
      raise SystemError("can't fork")
68
    os.waitpid(pid_non_existing, 0)
69
    self.assertFalse(utils.IsProcessAlive(pid_non_existing),
70
                     "nonexisting process detected")
71

    
72

    
73
class TestGetProcStatusPath(unittest.TestCase):
74
  def test(self):
75
    self.assert_("/1234/" in utils._GetProcStatusPath(1234))
76
    self.assertNotEqual(utils._GetProcStatusPath(1),
77
                        utils._GetProcStatusPath(2))
78

    
79

    
80
class TestIsProcessHandlingSignal(unittest.TestCase):
81
  def setUp(self):
82
    self.tmpdir = tempfile.mkdtemp()
83

    
84
  def tearDown(self):
85
    shutil.rmtree(self.tmpdir)
86

    
87
  def testParseSigsetT(self):
88
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
89
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
90
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
91
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
92
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
93
                     set([2, 10, 32, 33]))
94
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
95
                     set([2, 32, 33]))
96
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
97
                     set([2, 28, 32, 33]))
98
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
99
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
100
                          24, 25, 26, 28, 31]))
101
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))
102

    
103
  def testGetProcStatusField(self):
104
    for field in ["SigCgt", "Name", "FDSize"]:
105
      for value in ["", "0", "cat", "  1234 KB"]:
106
        pstatus = "\n".join([
107
          "VmPeak: 999 kB",
108
          "%s: %s" % (field, value),
109
          "TracerPid: 0",
110
          ])
111
        result = utils._GetProcStatusField(pstatus, field)
112
        self.assertEqual(result, value.strip())
113

    
114
  def test(self):
115
    sp = PathJoin(self.tmpdir, "status")
116

    
117
    utils.WriteFile(sp, data="\n".join([
118
      "Name:   bash",
119
      "State:  S (sleeping)",
120
      "SleepAVG:       98%",
121
      "Pid:    22250",
122
      "PPid:   10858",
123
      "TracerPid:      0",
124
      "SigBlk: 0000000000010000",
125
      "SigIgn: 0000000000384004",
126
      "SigCgt: 000000004b813efb",
127
      "CapEff: 0000000000000000",
128
      ]))
129

    
130
    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
131

    
132
  def testNoSigCgt(self):
133
    sp = PathJoin(self.tmpdir, "status")
134

    
135
    utils.WriteFile(sp, data="\n".join([
136
      "Name:   bash",
137
      ]))
138

    
139
    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
140
                      1234, 10, status_path=sp)
141

    
142
  def testNoSuchFile(self):
143
    sp = PathJoin(self.tmpdir, "notexist")
144

    
145
    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
146

    
147
  @staticmethod
148
  def _TestRealProcess():
149
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
150
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
151
      raise Exception("SIGUSR1 is handled when it should not be")
152

    
153
    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
154
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
155
      raise Exception("SIGUSR1 is not handled when it should be")
156

    
157
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
158
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
159
      raise Exception("SIGUSR1 is not handled when it should be")
160

    
161
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
162
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
163
      raise Exception("SIGUSR1 is handled when it should not be")
164

    
165
    return True
166

    
167
  def testRealProcess(self):
168
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
169

    
170

    
171
class TestPidFileFunctions(unittest.TestCase):
172
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
173

    
174
  def setUp(self):
175
    self.dir = tempfile.mkdtemp()
176
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
177

    
178
  def testPidFileFunctions(self):
179
    pid_file = self.f_dpn('test')
180
    fd = utils.WritePidFile(self.f_dpn('test'))
181
    self.failUnless(os.path.exists(pid_file),
182
                    "PID file should have been created")
183
    read_pid = utils.ReadPidFile(pid_file)
184
    self.failUnlessEqual(read_pid, os.getpid())
185
    self.failUnless(utils.IsProcessAlive(read_pid))
186
    self.failUnlessRaises(errors.LockError, utils.WritePidFile,
187
                          self.f_dpn('test'))
188
    os.close(fd)
189
    utils.RemovePidFile(self.f_dpn("test"))
190
    self.failIf(os.path.exists(pid_file),
191
                "PID file should not exist anymore")
192
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
193
                         "ReadPidFile should return 0 for missing pid file")
194
    fh = open(pid_file, "w")
195
    fh.write("blah\n")
196
    fh.close()
197
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
198
                         "ReadPidFile should return 0 for invalid pid file")
199
    # but now, even with the file existing, we should be able to lock it
200
    fd = utils.WritePidFile(self.f_dpn('test'))
201
    os.close(fd)
202
    utils.RemovePidFile(self.f_dpn("test"))
203
    self.failIf(os.path.exists(pid_file),
204
                "PID file should not exist anymore")
205

    
206
  def testKill(self):
207
    pid_file = self.f_dpn('child')
208
    r_fd, w_fd = os.pipe()
209
    new_pid = os.fork()
210
    if new_pid == 0: #child
211
      utils.WritePidFile(self.f_dpn('child'))
212
      os.write(w_fd, 'a')
213
      signal.pause()
214
      os._exit(0)
215
      return
216
    # else we are in the parent
217
    # wait until the child has written the pid file
218
    os.read(r_fd, 1)
219
    read_pid = utils.ReadPidFile(pid_file)
220
    self.failUnlessEqual(read_pid, new_pid)
221
    self.failUnless(utils.IsProcessAlive(new_pid))
222
    utils.KillProcess(new_pid, waitpid=True)
223
    self.failIf(utils.IsProcessAlive(new_pid))
224
    utils.RemovePidFile(self.f_dpn('child'))
225
    self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)
226

    
227
  def tearDown(self):
228
    for name in os.listdir(self.dir):
229
      os.unlink(os.path.join(self.dir, name))
230
    os.rmdir(self.dir)
231

    
232

    
233
class TestRunCmd(testutils.GanetiTestCase):
234
  """Testing case for the RunCmd function"""
235

    
236
  def setUp(self):
237
    testutils.GanetiTestCase.setUp(self)
238
    self.magic = time.ctime() + " ganeti test"
239
    self.fname = self._CreateTempFile()
240
    self.fifo_tmpdir = tempfile.mkdtemp()
241
    self.fifo_file = os.path.join(self.fifo_tmpdir, "ganeti_test_fifo")
242
    os.mkfifo(self.fifo_file)
243

    
244
  def tearDown(self):
245
    shutil.rmtree(self.fifo_tmpdir)
246
    testutils.GanetiTestCase.tearDown(self)
247

    
248
  def testOk(self):
249
    """Test successful exit code"""
250
    result = RunCmd("/bin/sh -c 'exit 0'")
251
    self.assertEqual(result.exit_code, 0)
252
    self.assertEqual(result.output, "")
253

    
254
  def testFail(self):
255
    """Test fail exit code"""
256
    result = RunCmd("/bin/sh -c 'exit 1'")
257
    self.assertEqual(result.exit_code, 1)
258
    self.assertEqual(result.output, "")
259

    
260
  def testStdout(self):
261
    """Test standard output"""
262
    cmd = 'echo -n "%s"' % self.magic
263
    result = RunCmd("/bin/sh -c '%s'" % cmd)
264
    self.assertEqual(result.stdout, self.magic)
265
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
266
    self.assertEqual(result.output, "")
267
    self.assertFileContent(self.fname, self.magic)
268

    
269
  def testStderr(self):
270
    """Test standard error"""
271
    cmd = 'echo -n "%s"' % self.magic
272
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
273
    self.assertEqual(result.stderr, self.magic)
274
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
275
    self.assertEqual(result.output, "")
276
    self.assertFileContent(self.fname, self.magic)
277

    
278
  def testCombined(self):
279
    """Test combined output"""
280
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
281
    expected = "A" + self.magic + "B" + self.magic
282
    result = RunCmd("/bin/sh -c '%s'" % cmd)
283
    self.assertEqual(result.output, expected)
284
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
285
    self.assertEqual(result.output, "")
286
    self.assertFileContent(self.fname, expected)
287

    
288
  def testSignal(self):
289
    """Test signal"""
290
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
291
    self.assertEqual(result.signal, 15)
292
    self.assertEqual(result.output, "")
293

    
294
  def testTimeoutClean(self):
295
    cmd = "trap 'exit 0' TERM; read < %s" % self.fifo_file
296
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
297
    self.assertEqual(result.exit_code, 0)
298

    
299
  def testTimeoutKill(self):
300
    cmd = ["/bin/sh", "-c", "trap '' TERM; read < %s" % self.fifo_file]
301
    timeout = 0.2
302
    out, err, status, ta = utils._RunCmdPipe(cmd, {}, False, "/", False,
303
                                             timeout, _linger_timeout=0.2)
304
    self.assert_(status < 0)
305
    self.assertEqual(-status, signal.SIGKILL)
306

    
307
  def testTimeoutOutputAfterTerm(self):
308
    cmd = "trap 'echo sigtermed; exit 1' TERM; read < %s" % self.fifo_file
309
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
310
    self.assert_(result.failed)
311
    self.assertEqual(result.stdout, "sigtermed\n")
312

    
313
  def testListRun(self):
314
    """Test list runs"""
315
    result = RunCmd(["true"])
316
    self.assertEqual(result.signal, None)
317
    self.assertEqual(result.exit_code, 0)
318
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
319
    self.assertEqual(result.signal, None)
320
    self.assertEqual(result.exit_code, 1)
321
    result = RunCmd(["echo", "-n", self.magic])
322
    self.assertEqual(result.signal, None)
323
    self.assertEqual(result.exit_code, 0)
324
    self.assertEqual(result.stdout, self.magic)
325

    
326
  def testFileEmptyOutput(self):
327
    """Test file output"""
328
    result = RunCmd(["true"], output=self.fname)
329
    self.assertEqual(result.signal, None)
330
    self.assertEqual(result.exit_code, 0)
331
    self.assertFileContent(self.fname, "")
332

    
333
  def testLang(self):
334
    """Test locale environment"""
335
    old_env = os.environ.copy()
336
    try:
337
      os.environ["LANG"] = "en_US.UTF-8"
338
      os.environ["LC_ALL"] = "en_US.UTF-8"
339
      result = RunCmd(["locale"])
340
      for line in result.output.splitlines():
341
        key, value = line.split("=", 1)
342
        # Ignore these variables, they're overridden by LC_ALL
343
        if key == "LANG" or key == "LANGUAGE":
344
          continue
345
        self.failIf(value and value != "C" and value != '"C"',
346
            "Variable %s is set to the invalid value '%s'" % (key, value))
347
    finally:
348
      os.environ = old_env
349

    
350
  def testDefaultCwd(self):
351
    """Test default working directory"""
352
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
353

    
354
  def testCwd(self):
355
    """Test default working directory"""
356
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
357
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
358
    cwd = os.getcwd()
359
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
360

    
361
  def testResetEnv(self):
362
    """Test environment reset functionality"""
363
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
364
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
365
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
366

    
367
  def testNoFork(self):
368
    """Test that nofork raise an error"""
369
    self.assertFalse(utils._no_fork)
370
    utils.DisableFork()
371
    try:
372
      self.assertTrue(utils._no_fork)
373
      self.assertRaises(errors.ProgrammerError, RunCmd, ["true"])
374
    finally:
375
      utils._no_fork = False
376

    
377
  def testWrongParams(self):
378
    """Test wrong parameters"""
379
    self.assertRaises(errors.ProgrammerError, RunCmd, ["true"],
380
                      output="/dev/null", interactive=True)
381

    
382

    
383
class TestRunParts(testutils.GanetiTestCase):
384
  """Testing case for the RunParts function"""
385

    
386
  def setUp(self):
387
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
388

    
389
  def tearDown(self):
390
    shutil.rmtree(self.rundir)
391

    
392
  def testEmpty(self):
393
    """Test on an empty dir"""
394
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
395

    
396
  def testSkipWrongName(self):
397
    """Test that wrong files are skipped"""
398
    fname = os.path.join(self.rundir, "00test.dot")
399
    utils.WriteFile(fname, data="")
400
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
401
    relname = os.path.basename(fname)
402
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
403
                         [(relname, constants.RUNPARTS_SKIP, None)])
404

    
405
  def testSkipNonExec(self):
406
    """Test that non executable files are skipped"""
407
    fname = os.path.join(self.rundir, "00test")
408
    utils.WriteFile(fname, data="")
409
    relname = os.path.basename(fname)
410
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
411
                         [(relname, constants.RUNPARTS_SKIP, None)])
412

    
413
  def testError(self):
414
    """Test error on a broken executable"""
415
    fname = os.path.join(self.rundir, "00test")
416
    utils.WriteFile(fname, data="")
417
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
418
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
419
    self.failUnlessEqual(relname, os.path.basename(fname))
420
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
421
    self.failUnless(error)
422

    
423
  def testSorted(self):
424
    """Test executions are sorted"""
425
    files = []
426
    files.append(os.path.join(self.rundir, "64test"))
427
    files.append(os.path.join(self.rundir, "00test"))
428
    files.append(os.path.join(self.rundir, "42test"))
429

    
430
    for fname in files:
431
      utils.WriteFile(fname, data="")
432

    
433
    results = RunParts(self.rundir, reset_env=True)
434

    
435
    for fname in sorted(files):
436
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
437

    
438
  def testOk(self):
439
    """Test correct execution"""
440
    fname = os.path.join(self.rundir, "00test")
441
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
442
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
443
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
444
    self.failUnlessEqual(relname, os.path.basename(fname))
445
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
446
    self.failUnlessEqual(runresult.stdout, "ciao")
447

    
448
  def testRunFail(self):
449
    """Test correct execution, with run failure"""
450
    fname = os.path.join(self.rundir, "00test")
451
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
452
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
453
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
454
    self.failUnlessEqual(relname, os.path.basename(fname))
455
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
456
    self.failUnlessEqual(runresult.exit_code, 1)
457
    self.failUnless(runresult.failed)
458

    
459
  def testRunMix(self):
460
    files = []
461
    files.append(os.path.join(self.rundir, "00test"))
462
    files.append(os.path.join(self.rundir, "42test"))
463
    files.append(os.path.join(self.rundir, "64test"))
464
    files.append(os.path.join(self.rundir, "99test"))
465

    
466
    files.sort()
467

    
468
    # 1st has errors in execution
469
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
470
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
471

    
472
    # 2nd is skipped
473
    utils.WriteFile(files[1], data="")
474

    
475
    # 3rd cannot execute properly
476
    utils.WriteFile(files[2], data="")
477
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
478

    
479
    # 4th execs
480
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
481
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
482

    
483
    results = RunParts(self.rundir, reset_env=True)
484

    
485
    (relname, status, runresult) = results[0]
486
    self.failUnlessEqual(relname, os.path.basename(files[0]))
487
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
488
    self.failUnlessEqual(runresult.exit_code, 1)
489
    self.failUnless(runresult.failed)
490

    
491
    (relname, status, runresult) = results[1]
492
    self.failUnlessEqual(relname, os.path.basename(files[1]))
493
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
494
    self.failUnlessEqual(runresult, None)
495

    
496
    (relname, status, runresult) = results[2]
497
    self.failUnlessEqual(relname, os.path.basename(files[2]))
498
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
499
    self.failUnless(runresult)
500

    
501
    (relname, status, runresult) = results[3]
502
    self.failUnlessEqual(relname, os.path.basename(files[3]))
503
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
504
    self.failUnlessEqual(runresult.output, "ciao")
505
    self.failUnlessEqual(runresult.exit_code, 0)
506
    self.failUnless(not runresult.failed)
507

    
508
  def testMissingDirectory(self):
509
    nosuchdir = utils.PathJoin(self.rundir, "no/such/directory")
510
    self.assertEqual(RunParts(nosuchdir), [])
511

    
512

    
513
class TestStartDaemon(testutils.GanetiTestCase):
514
  def setUp(self):
515
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
516
    self.tmpfile = os.path.join(self.tmpdir, "test")
517

    
518
  def tearDown(self):
519
    shutil.rmtree(self.tmpdir)
520

    
521
  def testShell(self):
522
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
523
    self._wait(self.tmpfile, 60.0, "Hello World")
524

    
525
  def testShellOutput(self):
526
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
527
    self._wait(self.tmpfile, 60.0, "Hello World")
528

    
529
  def testNoShellNoOutput(self):
530
    utils.StartDaemon(["pwd"])
531

    
532
  def testNoShellNoOutputTouch(self):
533
    testfile = os.path.join(self.tmpdir, "check")
534
    self.failIf(os.path.exists(testfile))
535
    utils.StartDaemon(["touch", testfile])
536
    self._wait(testfile, 60.0, "")
537

    
538
  def testNoShellOutput(self):
539
    utils.StartDaemon(["pwd"], output=self.tmpfile)
540
    self._wait(self.tmpfile, 60.0, "/")
541

    
542
  def testNoShellOutputCwd(self):
543
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
544
    self._wait(self.tmpfile, 60.0, os.getcwd())
545

    
546
  def testShellEnv(self):
547
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
548
                      env={ "GNT_TEST_VAR": "Hello World", })
549
    self._wait(self.tmpfile, 60.0, "Hello World")
550

    
551
  def testNoShellEnv(self):
552
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
553
                      env={ "GNT_TEST_VAR": "Hello World", })
554
    self._wait(self.tmpfile, 60.0, "Hello World")
555

    
556
  def testOutputFd(self):
557
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
558
    try:
559
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
560
    finally:
561
      os.close(fd)
562
    self._wait(self.tmpfile, 60.0, os.getcwd())
563

    
564
  def testPid(self):
565
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
566
    self._wait(self.tmpfile, 60.0, str(pid))
567

    
568
  def testPidFile(self):
569
    pidfile = os.path.join(self.tmpdir, "pid")
570
    checkfile = os.path.join(self.tmpdir, "abort")
571

    
572
    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
573
                            output=self.tmpfile)
574
    try:
575
      fd = os.open(pidfile, os.O_RDONLY)
576
      try:
577
        # Check file is locked
578
        self.assertRaises(errors.LockError, utils.LockFile, fd)
579

    
580
        pidtext = os.read(fd, 100)
581
      finally:
582
        os.close(fd)
583

    
584
      self.assertEqual(int(pidtext.strip()), pid)
585

    
586
      self.assert_(utils.IsProcessAlive(pid))
587
    finally:
588
      # No matter what happens, kill daemon
589
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
590
      self.failIf(utils.IsProcessAlive(pid))
591

    
592
    self.assertEqual(utils.ReadFile(self.tmpfile), "")
593

    
594
  def _wait(self, path, timeout, expected):
595
    # Due to the asynchronous nature of daemon processes, polling is necessary.
596
    # A timeout makes sure the test doesn't hang forever.
597
    def _CheckFile():
598
      if not (os.path.isfile(path) and
599
              utils.ReadFile(path).strip() == expected):
600
        raise utils.RetryAgain()
601

    
602
    try:
603
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
604
    except utils.RetryTimeout:
605
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
606
                " didn't write the correct output" % timeout)
607

    
608
  def testError(self):
609
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
610
                      ["./does-NOT-EXIST/here/0123456789"])
611
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
612
                      ["./does-NOT-EXIST/here/0123456789"],
613
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
614
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
615
                      ["./does-NOT-EXIST/here/0123456789"],
616
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
617
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
618
                      ["./does-NOT-EXIST/here/0123456789"],
619
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
620

    
621
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
622
    try:
623
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
624
                        ["./does-NOT-EXIST/here/0123456789"],
625
                        output=self.tmpfile, output_fd=fd)
626
    finally:
627
      os.close(fd)
628

    
629

    
630
class TestSetCloseOnExecFlag(unittest.TestCase):
631
  """Tests for SetCloseOnExecFlag"""
632

    
633
  def setUp(self):
634
    self.tmpfile = tempfile.TemporaryFile()
635

    
636
  def testEnable(self):
637
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
638
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
639
                    fcntl.FD_CLOEXEC)
640

    
641
  def testDisable(self):
642
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
643
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
644
                fcntl.FD_CLOEXEC)
645

    
646

    
647
class TestSetNonblockFlag(unittest.TestCase):
648
  def setUp(self):
649
    self.tmpfile = tempfile.TemporaryFile()
650

    
651
  def testEnable(self):
652
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
653
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
654
                    os.O_NONBLOCK)
655

    
656
  def testDisable(self):
657
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
658
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
659
                os.O_NONBLOCK)
660

    
661

    
662
class TestRemoveFile(unittest.TestCase):
663
  """Test case for the RemoveFile function"""
664

    
665
  def setUp(self):
666
    """Create a temp dir and file for each case"""
667
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
668
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
669
    os.close(fd)
670

    
671
  def tearDown(self):
672
    if os.path.exists(self.tmpfile):
673
      os.unlink(self.tmpfile)
674
    os.rmdir(self.tmpdir)
675

    
676
  def testIgnoreDirs(self):
677
    """Test that RemoveFile() ignores directories"""
678
    self.assertEqual(None, RemoveFile(self.tmpdir))
679

    
680
  def testIgnoreNotExisting(self):
681
    """Test that RemoveFile() ignores non-existing files"""
682
    RemoveFile(self.tmpfile)
683
    RemoveFile(self.tmpfile)
684

    
685
  def testRemoveFile(self):
686
    """Test that RemoveFile does remove a file"""
687
    RemoveFile(self.tmpfile)
688
    if os.path.exists(self.tmpfile):
689
      self.fail("File '%s' not removed" % self.tmpfile)
690

    
691
  def testRemoveSymlink(self):
692
    """Test that RemoveFile does remove symlinks"""
693
    symlink = self.tmpdir + "/symlink"
694
    os.symlink("no-such-file", symlink)
695
    RemoveFile(symlink)
696
    if os.path.exists(symlink):
697
      self.fail("File '%s' not removed" % symlink)
698
    os.symlink(self.tmpfile, symlink)
699
    RemoveFile(symlink)
700
    if os.path.exists(symlink):
701
      self.fail("File '%s' not removed" % symlink)
702

    
703

    
704
class TestRemoveDir(unittest.TestCase):
705
  def setUp(self):
706
    self.tmpdir = tempfile.mkdtemp()
707

    
708
  def tearDown(self):
709
    try:
710
      shutil.rmtree(self.tmpdir)
711
    except EnvironmentError:
712
      pass
713

    
714
  def testEmptyDir(self):
715
    utils.RemoveDir(self.tmpdir)
716
    self.assertFalse(os.path.isdir(self.tmpdir))
717

    
718
  def testNonEmptyDir(self):
719
    self.tmpfile = os.path.join(self.tmpdir, "test1")
720
    open(self.tmpfile, "w").close()
721
    self.assertRaises(EnvironmentError, utils.RemoveDir, self.tmpdir)
722

    
723

    
724
class TestRename(unittest.TestCase):
725
  """Test case for RenameFile"""
726

    
727
  def setUp(self):
728
    """Create a temporary directory"""
729
    self.tmpdir = tempfile.mkdtemp()
730
    self.tmpfile = os.path.join(self.tmpdir, "test1")
731

    
732
    # Touch the file
733
    open(self.tmpfile, "w").close()
734

    
735
  def tearDown(self):
736
    """Remove temporary directory"""
737
    shutil.rmtree(self.tmpdir)
738

    
739
  def testSimpleRename1(self):
740
    """Simple rename 1"""
741
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
742
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
743

    
744
  def testSimpleRename2(self):
745
    """Simple rename 2"""
746
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
747
                     mkdir=True)
748
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
749

    
750
  def testRenameMkdir(self):
751
    """Rename with mkdir"""
752
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
753
                     mkdir=True)
754
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
755
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
756

    
757
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
758
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
759
                     mkdir=True)
760
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
761
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
762
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
763

    
764

    
765
class TestReadFile(testutils.GanetiTestCase):
766

    
767
  def testReadAll(self):
768
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
769
    self.assertEqual(len(data), 814)
770

    
771
    h = compat.md5_hash()
772
    h.update(data)
773
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
774

    
775
  def testReadSize(self):
776
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
777
                          size=100)
778
    self.assertEqual(len(data), 100)
779

    
780
    h = compat.md5_hash()
781
    h.update(data)
782
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
783

    
784
  def testError(self):
785
    self.assertRaises(EnvironmentError, utils.ReadFile,
786
                      "/dev/null/does-not-exist")
787

    
788

    
789
class TestReadOneLineFile(testutils.GanetiTestCase):
790

    
791
  def setUp(self):
792
    testutils.GanetiTestCase.setUp(self)
793

    
794
  def testDefault(self):
795
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
796
    self.assertEqual(len(data), 27)
797
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
798

    
799
  def testNotStrict(self):
800
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
801
    self.assertEqual(len(data), 27)
802
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
803

    
804
  def testStrictFailure(self):
805
    self.assertRaises(errors.GenericError, ReadOneLineFile,
806
                      self._TestDataFilename("cert1.pem"), strict=True)
807

    
808
  def testLongLine(self):
809
    dummydata = (1024 * "Hello World! ")
810
    myfile = self._CreateTempFile()
811
    utils.WriteFile(myfile, data=dummydata)
812
    datastrict = ReadOneLineFile(myfile, strict=True)
813
    datalax = ReadOneLineFile(myfile, strict=False)
814
    self.assertEqual(dummydata, datastrict)
815
    self.assertEqual(dummydata, datalax)
816

    
817
  def testNewline(self):
818
    myfile = self._CreateTempFile()
819
    myline = "myline"
820
    for nl in ["", "\n", "\r\n"]:
821
      dummydata = "%s%s" % (myline, nl)
822
      utils.WriteFile(myfile, data=dummydata)
823
      datalax = ReadOneLineFile(myfile, strict=False)
824
      self.assertEqual(myline, datalax)
825
      datastrict = ReadOneLineFile(myfile, strict=True)
826
      self.assertEqual(myline, datastrict)
827

    
828
  def testWhitespaceAndMultipleLines(self):
829
    myfile = self._CreateTempFile()
830
    for nl in ["", "\n", "\r\n"]:
831
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
832
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
833
        utils.WriteFile(myfile, data=dummydata)
834
        datalax = ReadOneLineFile(myfile, strict=False)
835
        if nl:
836
          self.assert_(set("\r\n") & set(dummydata))
837
          self.assertRaises(errors.GenericError, ReadOneLineFile,
838
                            myfile, strict=True)
839
          explen = len("Foo bar baz ") + len(ws)
840
          self.assertEqual(len(datalax), explen)
841
          self.assertEqual(datalax, dummydata[:explen])
842
          self.assertFalse(set("\r\n") & set(datalax))
843
        else:
844
          datastrict = ReadOneLineFile(myfile, strict=True)
845
          self.assertEqual(dummydata, datastrict)
846
          self.assertEqual(dummydata, datalax)
847

    
848
  def testEmptylines(self):
849
    myfile = self._CreateTempFile()
850
    myline = "myline"
851
    for nl in ["\n", "\r\n"]:
852
      for ol in ["", "otherline"]:
853
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
854
        utils.WriteFile(myfile, data=dummydata)
855
        self.assert_(set("\r\n") & set(dummydata))
856
        datalax = ReadOneLineFile(myfile, strict=False)
857
        self.assertEqual(myline, datalax)
858
        if ol:
859
          self.assertRaises(errors.GenericError, ReadOneLineFile,
860
                            myfile, strict=True)
861
        else:
862
          datastrict = ReadOneLineFile(myfile, strict=True)
863
          self.assertEqual(myline, datastrict)
864

    
865
  def testEmptyfile(self):
866
    myfile = self._CreateTempFile()
867
    self.assertRaises(errors.GenericError, ReadOneLineFile, myfile)
868

    
869

    
870
class TestTimestampForFilename(unittest.TestCase):
871
  def test(self):
872
    self.assert_("." not in utils.TimestampForFilename())
873
    self.assert_(":" not in utils.TimestampForFilename())
874

    
875

    
876
class TestCreateBackup(testutils.GanetiTestCase):
877
  def setUp(self):
878
    testutils.GanetiTestCase.setUp(self)
879

    
880
    self.tmpdir = tempfile.mkdtemp()
881

    
882
  def tearDown(self):
883
    testutils.GanetiTestCase.tearDown(self)
884

    
885
    shutil.rmtree(self.tmpdir)
886

    
887
  def testEmpty(self):
888
    filename = PathJoin(self.tmpdir, "config.data")
889
    utils.WriteFile(filename, data="")
890
    bname = utils.CreateBackup(filename)
891
    self.assertFileContent(bname, "")
892
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
893
    utils.CreateBackup(filename)
894
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
895
    utils.CreateBackup(filename)
896
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
897

    
898
    fifoname = PathJoin(self.tmpdir, "fifo")
899
    os.mkfifo(fifoname)
900
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
901

    
902
  def testContent(self):
903
    bkpcount = 0
904
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
905
      for rep in [1, 2, 10, 127]:
906
        testdata = data * rep
907

    
908
        filename = PathJoin(self.tmpdir, "test.data_")
909
        utils.WriteFile(filename, data=testdata)
910
        self.assertFileContent(filename, testdata)
911

    
912
        for _ in range(3):
913
          bname = utils.CreateBackup(filename)
914
          bkpcount += 1
915
          self.assertFileContent(bname, testdata)
916
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
917

    
918

    
919
class TestParseCpuMask(unittest.TestCase):
920
  """Test case for the ParseCpuMask function."""
921

    
922
  def testWellFormed(self):
923
    self.assertEqual(utils.ParseCpuMask(""), [])
924
    self.assertEqual(utils.ParseCpuMask("1"), [1])
925
    self.assertEqual(utils.ParseCpuMask("0-2,4,5-5"), [0,1,2,4,5])
926

    
927
  def testInvalidInput(self):
928
    for data in ["garbage", "0,", "0-1-2", "2-1", "1-a"]:
929
      self.assertRaises(errors.ParseError, utils.ParseCpuMask, data)
930

    
931

    
932
class TestSshKeys(testutils.GanetiTestCase):
933
  """Test case for the AddAuthorizedKey function"""
934

    
935
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
936
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
937
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
938

    
939
  def setUp(self):
940
    testutils.GanetiTestCase.setUp(self)
941
    self.tmpname = self._CreateTempFile()
942
    handle = open(self.tmpname, 'w')
943
    try:
944
      handle.write("%s\n" % TestSshKeys.KEY_A)
945
      handle.write("%s\n" % TestSshKeys.KEY_B)
946
    finally:
947
      handle.close()
948

    
949
  def testAddingNewKey(self):
950
    utils.AddAuthorizedKey(self.tmpname,
951
                           'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
952

    
953
    self.assertFileContent(self.tmpname,
954
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
955
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
956
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
957
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
958

    
959
  def testAddingAlmostButNotCompletelyTheSameKey(self):
960
    utils.AddAuthorizedKey(self.tmpname,
961
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
962

    
963
    self.assertFileContent(self.tmpname,
964
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
965
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
966
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
967
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
968

    
969
  def testAddingExistingKeyWithSomeMoreSpaces(self):
970
    utils.AddAuthorizedKey(self.tmpname,
971
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
972

    
973
    self.assertFileContent(self.tmpname,
974
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
975
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
976
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
977

    
978
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
979
    utils.RemoveAuthorizedKey(self.tmpname,
980
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
981

    
982
    self.assertFileContent(self.tmpname,
983
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
984
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
985

    
986
  def testRemovingNonExistingKey(self):
987
    utils.RemoveAuthorizedKey(self.tmpname,
988
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
989

    
990
    self.assertFileContent(self.tmpname,
991
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
992
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
993
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
994

    
995

    
996
class TestEtcHosts(testutils.GanetiTestCase):
997
  """Test functions modifying /etc/hosts"""
998

    
999
  def setUp(self):
1000
    testutils.GanetiTestCase.setUp(self)
1001
    self.tmpname = self._CreateTempFile()
1002
    handle = open(self.tmpname, 'w')
1003
    try:
1004
      handle.write('# This is a test file for /etc/hosts\n')
1005
      handle.write('127.0.0.1\tlocalhost\n')
1006
      handle.write('192.0.2.1 router gw\n')
1007
    finally:
1008
      handle.close()
1009

    
1010
  def testSettingNewIp(self):
1011
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost.example.com',
1012
                     ['myhost'])
1013

    
1014
    self.assertFileContent(self.tmpname,
1015
      "# This is a test file for /etc/hosts\n"
1016
      "127.0.0.1\tlocalhost\n"
1017
      "192.0.2.1 router gw\n"
1018
      "198.51.100.4\tmyhost.example.com myhost\n")
1019
    self.assertFileMode(self.tmpname, 0644)
1020

    
1021
  def testSettingExistingIp(self):
1022
    SetEtcHostsEntry(self.tmpname, '192.0.2.1', 'myhost.example.com',
1023
                     ['myhost'])
1024

    
1025
    self.assertFileContent(self.tmpname,
1026
      "# This is a test file for /etc/hosts\n"
1027
      "127.0.0.1\tlocalhost\n"
1028
      "192.0.2.1\tmyhost.example.com myhost\n")
1029
    self.assertFileMode(self.tmpname, 0644)
1030

    
1031
  def testSettingDuplicateName(self):
1032
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost', ['myhost'])
1033

    
1034
    self.assertFileContent(self.tmpname,
1035
      "# This is a test file for /etc/hosts\n"
1036
      "127.0.0.1\tlocalhost\n"
1037
      "192.0.2.1 router gw\n"
1038
      "198.51.100.4\tmyhost\n")
1039
    self.assertFileMode(self.tmpname, 0644)
1040

    
1041
  def testRemovingExistingHost(self):
1042
    RemoveEtcHostsEntry(self.tmpname, 'router')
1043

    
1044
    self.assertFileContent(self.tmpname,
1045
      "# This is a test file for /etc/hosts\n"
1046
      "127.0.0.1\tlocalhost\n"
1047
      "192.0.2.1 gw\n")
1048
    self.assertFileMode(self.tmpname, 0644)
1049

    
1050
  def testRemovingSingleExistingHost(self):
1051
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
1052

    
1053
    self.assertFileContent(self.tmpname,
1054
      "# This is a test file for /etc/hosts\n"
1055
      "192.0.2.1 router gw\n")
1056
    self.assertFileMode(self.tmpname, 0644)
1057

    
1058
  def testRemovingNonExistingHost(self):
1059
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
1060

    
1061
    self.assertFileContent(self.tmpname,
1062
      "# This is a test file for /etc/hosts\n"
1063
      "127.0.0.1\tlocalhost\n"
1064
      "192.0.2.1 router gw\n")
1065
    self.assertFileMode(self.tmpname, 0644)
1066

    
1067
  def testRemovingAlias(self):
1068
    RemoveEtcHostsEntry(self.tmpname, 'gw')
1069

    
1070
    self.assertFileContent(self.tmpname,
1071
      "# This is a test file for /etc/hosts\n"
1072
      "127.0.0.1\tlocalhost\n"
1073
      "192.0.2.1 router\n")
1074
    self.assertFileMode(self.tmpname, 0644)
1075

    
1076

    
1077
class TestGetMounts(unittest.TestCase):
1078
  """Test case for GetMounts()."""
1079

    
1080
  TESTDATA = (
1081
    "rootfs /     rootfs rw 0 0\n"
1082
    "none   /sys  sysfs  rw,nosuid,nodev,noexec,relatime 0 0\n"
1083
    "none   /proc proc   rw,nosuid,nodev,noexec,relatime 0 0\n")
1084

    
1085
  def setUp(self):
1086
    self.tmpfile = tempfile.NamedTemporaryFile()
1087
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1088

    
1089
  def testGetMounts(self):
1090
    self.assertEqual(utils.GetMounts(filename=self.tmpfile.name),
1091
      [
1092
        ("rootfs", "/", "rootfs", "rw"),
1093
        ("none", "/sys", "sysfs", "rw,nosuid,nodev,noexec,relatime"),
1094
        ("none", "/proc", "proc", "rw,nosuid,nodev,noexec,relatime"),
1095
      ])
1096

    
1097

    
1098
class TestListVisibleFiles(unittest.TestCase):
1099
  """Test case for ListVisibleFiles"""
1100

    
1101
  def setUp(self):
1102
    self.path = tempfile.mkdtemp()
1103

    
1104
  def tearDown(self):
1105
    shutil.rmtree(self.path)
1106

    
1107
  def _CreateFiles(self, files):
1108
    for name in files:
1109
      utils.WriteFile(os.path.join(self.path, name), data="test")
1110

    
1111
  def _test(self, files, expected):
1112
    self._CreateFiles(files)
1113
    found = ListVisibleFiles(self.path)
1114
    self.assertEqual(set(found), set(expected))
1115

    
1116
  def testAllVisible(self):
1117
    files = ["a", "b", "c"]
1118
    expected = files
1119
    self._test(files, expected)
1120

    
1121
  def testNoneVisible(self):
1122
    files = [".a", ".b", ".c"]
1123
    expected = []
1124
    self._test(files, expected)
1125

    
1126
  def testSomeVisible(self):
1127
    files = ["a", "b", ".c"]
1128
    expected = ["a", "b"]
1129
    self._test(files, expected)
1130

    
1131
  def testNonAbsolutePath(self):
1132
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1133

    
1134
  def testNonNormalizedPath(self):
1135
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1136
                          "/bin/../tmp")
1137

    
1138

    
1139
class TestNewUUID(unittest.TestCase):
1140
  """Test case for NewUUID"""
1141

    
1142
  def runTest(self):
1143
    self.failUnless(utils.UUID_RE.match(utils.NewUUID()))
1144

    
1145

    
1146
class TestFirstFree(unittest.TestCase):
1147
  """Test case for the FirstFree function"""
1148

    
1149
  def test(self):
1150
    """Test FirstFree"""
1151
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1152
    self.failUnlessEqual(FirstFree([]), None)
1153
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1154
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1155
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1156

    
1157

    
1158
class TestTailFile(testutils.GanetiTestCase):
1159
  """Test case for the TailFile function"""
1160

    
1161
  def testEmpty(self):
1162
    fname = self._CreateTempFile()
1163
    self.failUnlessEqual(TailFile(fname), [])
1164
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1165

    
1166
  def testAllLines(self):
1167
    data = ["test %d" % i for i in range(30)]
1168
    for i in range(30):
1169
      fname = self._CreateTempFile()
1170
      fd = open(fname, "w")
1171
      fd.write("\n".join(data[:i]))
1172
      if i > 0:
1173
        fd.write("\n")
1174
      fd.close()
1175
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1176

    
1177
  def testPartialLines(self):
1178
    data = ["test %d" % i for i in range(30)]
1179
    fname = self._CreateTempFile()
1180
    fd = open(fname, "w")
1181
    fd.write("\n".join(data))
1182
    fd.write("\n")
1183
    fd.close()
1184
    for i in range(1, 30):
1185
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1186

    
1187
  def testBigFile(self):
1188
    data = ["test %d" % i for i in range(30)]
1189
    fname = self._CreateTempFile()
1190
    fd = open(fname, "w")
1191
    fd.write("X" * 1048576)
1192
    fd.write("\n")
1193
    fd.write("\n".join(data))
1194
    fd.write("\n")
1195
    fd.close()
1196
    for i in range(1, 30):
1197
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1198

    
1199

    
1200
class _BaseFileLockTest:
1201
  """Test case for the FileLock class"""
1202

    
1203
  def testSharedNonblocking(self):
1204
    self.lock.Shared(blocking=False)
1205
    self.lock.Close()
1206

    
1207
  def testExclusiveNonblocking(self):
1208
    self.lock.Exclusive(blocking=False)
1209
    self.lock.Close()
1210

    
1211
  def testUnlockNonblocking(self):
1212
    self.lock.Unlock(blocking=False)
1213
    self.lock.Close()
1214

    
1215
  def testSharedBlocking(self):
1216
    self.lock.Shared(blocking=True)
1217
    self.lock.Close()
1218

    
1219
  def testExclusiveBlocking(self):
1220
    self.lock.Exclusive(blocking=True)
1221
    self.lock.Close()
1222

    
1223
  def testUnlockBlocking(self):
1224
    self.lock.Unlock(blocking=True)
1225
    self.lock.Close()
1226

    
1227
  def testSharedExclusiveUnlock(self):
1228
    self.lock.Shared(blocking=False)
1229
    self.lock.Exclusive(blocking=False)
1230
    self.lock.Unlock(blocking=False)
1231
    self.lock.Close()
1232

    
1233
  def testExclusiveSharedUnlock(self):
1234
    self.lock.Exclusive(blocking=False)
1235
    self.lock.Shared(blocking=False)
1236
    self.lock.Unlock(blocking=False)
1237
    self.lock.Close()
1238

    
1239
  def testSimpleTimeout(self):
1240
    # These will succeed on the first attempt, hence a short timeout
1241
    self.lock.Shared(blocking=True, timeout=10.0)
1242
    self.lock.Exclusive(blocking=False, timeout=10.0)
1243
    self.lock.Unlock(blocking=True, timeout=10.0)
1244
    self.lock.Close()
1245

    
1246
  @staticmethod
1247
  def _TryLockInner(filename, shared, blocking):
1248
    lock = utils.FileLock.Open(filename)
1249

    
1250
    if shared:
1251
      fn = lock.Shared
1252
    else:
1253
      fn = lock.Exclusive
1254

    
1255
    try:
1256
      # The timeout doesn't really matter as the parent process waits for us to
1257
      # finish anyway.
1258
      fn(blocking=blocking, timeout=0.01)
1259
    except errors.LockError, err:
1260
      return False
1261

    
1262
    return True
1263

    
1264
  def _TryLock(self, *args):
1265
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1266
                                      *args)
1267

    
1268
  def testTimeout(self):
1269
    for blocking in [True, False]:
1270
      self.lock.Exclusive(blocking=True)
1271
      self.failIf(self._TryLock(False, blocking))
1272
      self.failIf(self._TryLock(True, blocking))
1273

    
1274
      self.lock.Shared(blocking=True)
1275
      self.assert_(self._TryLock(True, blocking))
1276
      self.failIf(self._TryLock(False, blocking))
1277

    
1278
  def testCloseShared(self):
1279
    self.lock.Close()
1280
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1281

    
1282
  def testCloseExclusive(self):
1283
    self.lock.Close()
1284
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1285

    
1286
  def testCloseUnlock(self):
1287
    self.lock.Close()
1288
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1289

    
1290

    
1291
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1292
  TESTDATA = "Hello World\n" * 10
1293

    
1294
  def setUp(self):
1295
    testutils.GanetiTestCase.setUp(self)
1296

    
1297
    self.tmpfile = tempfile.NamedTemporaryFile()
1298
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1299
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1300

    
1301
    # Ensure "Open" didn't truncate file
1302
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1303

    
1304
  def tearDown(self):
1305
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1306

    
1307
    testutils.GanetiTestCase.tearDown(self)
1308

    
1309

    
1310
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1311
  def setUp(self):
1312
    self.tmpfile = tempfile.NamedTemporaryFile()
1313
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1314

    
1315

    
1316
class TestTimeFunctions(unittest.TestCase):
1317
  """Test case for time functions"""
1318

    
1319
  def runTest(self):
1320
    self.assertEqual(utils.SplitTime(1), (1, 0))
1321
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1322
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1323
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1324
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1325
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1326
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1327
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1328

    
1329
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1330

    
1331
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1332
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1333
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1334

    
1335
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1336
                     1218448917.481)
1337
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1338

    
1339
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1340
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1341
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1342
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1343
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1344

    
1345

    
1346
class FieldSetTestCase(unittest.TestCase):
1347
  """Test case for FieldSets"""
1348

    
1349
  def testSimpleMatch(self):
1350
    f = utils.FieldSet("a", "b", "c", "def")
1351
    self.failUnless(f.Matches("a"))
1352
    self.failIf(f.Matches("d"), "Substring matched")
1353
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1354
    self.failIf(f.NonMatching(["b", "c"]))
1355
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1356
    self.failUnless(f.NonMatching(["a", "d"]))
1357

    
1358
  def testRegexMatch(self):
1359
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1360
    self.failUnless(f.Matches("b1"))
1361
    self.failUnless(f.Matches("b99"))
1362
    self.failIf(f.Matches("b/1"))
1363
    self.failIf(f.NonMatching(["b12", "c"]))
1364
    self.failUnless(f.NonMatching(["a", "1"]))
1365

    
1366
class TestForceDictType(unittest.TestCase):
1367
  """Test case for ForceDictType"""
1368
  KEY_TYPES = {
1369
    "a": constants.VTYPE_INT,
1370
    "b": constants.VTYPE_BOOL,
1371
    "c": constants.VTYPE_STRING,
1372
    "d": constants.VTYPE_SIZE,
1373
    "e": constants.VTYPE_MAYBE_STRING,
1374
    }
1375

    
1376
  def _fdt(self, dict, allowed_values=None):
1377
    if allowed_values is None:
1378
      utils.ForceDictType(dict, self.KEY_TYPES)
1379
    else:
1380
      utils.ForceDictType(dict, self.KEY_TYPES, allowed_values=allowed_values)
1381

    
1382
    return dict
1383

    
1384
  def testSimpleDict(self):
1385
    self.assertEqual(self._fdt({}), {})
1386
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1387
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1388
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1389
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1390
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1391
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1392
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1393
    self.assertEqual(self._fdt({'b': False}), {'b': False})
1394
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1395
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1396
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1397
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1398
    self.assertEqual(self._fdt({"e": None, }), {"e": None, })
1399
    self.assertEqual(self._fdt({"e": "Hello World", }), {"e": "Hello World", })
1400
    self.assertEqual(self._fdt({"e": False, }), {"e": '', })
1401
    self.assertEqual(self._fdt({"b": "hello", }, ["hello"]), {"b": "hello"})
1402

    
1403
  def testErrors(self):
1404
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1405
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"b": "hello"})
1406
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1407
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1408
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1409
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": object(), })
1410
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": [], })
1411
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"x": None, })
1412
    self.assertRaises(errors.TypeEnforcementError, self._fdt, [])
1413
    self.assertRaises(errors.ProgrammerError, utils.ForceDictType,
1414
                      {"b": "hello"}, {"b": "no-such-type"})
1415

    
1416

    
1417
class TestIsNormAbsPath(unittest.TestCase):
1418
  """Testing case for IsNormAbsPath"""
1419

    
1420
  def _pathTestHelper(self, path, result):
1421
    if result:
1422
      self.assert_(utils.IsNormAbsPath(path),
1423
          "Path %s should result absolute and normalized" % path)
1424
    else:
1425
      self.assertFalse(utils.IsNormAbsPath(path),
1426
          "Path %s should not result absolute and normalized" % path)
1427

    
1428
  def testBase(self):
1429
    self._pathTestHelper('/etc', True)
1430
    self._pathTestHelper('/srv', True)
1431
    self._pathTestHelper('etc', False)
1432
    self._pathTestHelper('/etc/../root', False)
1433
    self._pathTestHelper('/etc/', False)
1434

    
1435

    
1436
class RunInSeparateProcess(unittest.TestCase):
1437
  def test(self):
1438
    for exp in [True, False]:
1439
      def _child():
1440
        return exp
1441

    
1442
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1443

    
1444
  def testArgs(self):
1445
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1446
      def _child(carg1, carg2):
1447
        return carg1 == "Foo" and carg2 == arg
1448

    
1449
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1450

    
1451
  def testPid(self):
1452
    parent_pid = os.getpid()
1453

    
1454
    def _check():
1455
      return os.getpid() == parent_pid
1456

    
1457
    self.failIf(utils.RunInSeparateProcess(_check))
1458

    
1459
  def testSignal(self):
1460
    def _kill():
1461
      os.kill(os.getpid(), signal.SIGTERM)
1462

    
1463
    self.assertRaises(errors.GenericError,
1464
                      utils.RunInSeparateProcess, _kill)
1465

    
1466
  def testException(self):
1467
    def _exc():
1468
      raise errors.GenericError("This is a test")
1469

    
1470
    self.assertRaises(errors.GenericError,
1471
                      utils.RunInSeparateProcess, _exc)
1472

    
1473

    
1474
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1475
  def setUp(self):
1476
    self.tmpdir = tempfile.mkdtemp()
1477

    
1478
  def tearDown(self):
1479
    shutil.rmtree(self.tmpdir)
1480

    
1481
  def _checkRsaPrivateKey(self, key):
1482
    lines = key.splitlines()
1483
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1484
            "-----END RSA PRIVATE KEY-----" in lines)
1485

    
1486
  def _checkCertificate(self, cert):
1487
    lines = cert.splitlines()
1488
    return ("-----BEGIN CERTIFICATE-----" in lines and
1489
            "-----END CERTIFICATE-----" in lines)
1490

    
1491
  def test(self):
1492
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1493
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1494
      self._checkRsaPrivateKey(key_pem)
1495
      self._checkCertificate(cert_pem)
1496

    
1497
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1498
                                           key_pem)
1499
      self.assert_(key.bits() >= 1024)
1500
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1501
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1502

    
1503
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1504
                                             cert_pem)
1505
      self.failIf(x509.has_expired())
1506
      self.assertEqual(x509.get_issuer().CN, common_name)
1507
      self.assertEqual(x509.get_subject().CN, common_name)
1508
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1509

    
1510
  def testLegacy(self):
1511
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1512

    
1513
    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1514

    
1515
    cert1 = utils.ReadFile(cert1_filename)
1516

    
1517
    self.assert_(self._checkRsaPrivateKey(cert1))
1518
    self.assert_(self._checkCertificate(cert1))
1519

    
1520

    
1521
class TestPathJoin(unittest.TestCase):
1522
  """Testing case for PathJoin"""
1523

    
1524
  def testBasicItems(self):
1525
    mlist = ["/a", "b", "c"]
1526
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1527

    
1528
  def testNonAbsPrefix(self):
1529
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1530

    
1531
  def testBackTrack(self):
1532
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1533

    
1534
  def testMultiAbs(self):
1535
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1536

    
1537

    
1538
class TestValidateServiceName(unittest.TestCase):
1539
  def testValid(self):
1540
    testnames = [
1541
      0, 1, 2, 3, 1024, 65000, 65534, 65535,
1542
      "ganeti",
1543
      "gnt-masterd",
1544
      "HELLO_WORLD_SVC",
1545
      "hello.world.1",
1546
      "0", "80", "1111", "65535",
1547
      ]
1548

    
1549
    for name in testnames:
1550
      self.assertEqual(utils.ValidateServiceName(name), name)
1551

    
1552
  def testInvalid(self):
1553
    testnames = [
1554
      -15756, -1, 65536, 133428083,
1555
      "", "Hello World!", "!", "'", "\"", "\t", "\n", "`",
1556
      "-8546", "-1", "65536",
1557
      (129 * "A"),
1558
      ]
1559

    
1560
    for name in testnames:
1561
      self.assertRaises(errors.OpPrereqError, utils.ValidateServiceName, name)
1562

    
1563

    
1564
class TestParseAsn1Generalizedtime(unittest.TestCase):
1565
  def test(self):
1566
    # UTC
1567
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1568
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1569
                     1266860512)
1570
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1571
                     (2**31) - 1)
1572

    
1573
    # With offset
1574
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1575
                     1266860512)
1576
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1577
                     1266931012)
1578
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1579
                     1266931088)
1580
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1581
                     1266931295)
1582
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1583
                     3600)
1584

    
1585
    # Leap seconds are not supported by datetime.datetime
1586
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1587
                      "19841231235960+0000")
1588
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1589
                      "19920630235960+0000")
1590

    
1591
    # Errors
1592
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1593
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1594
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1595
                      "20100222174152")
1596
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1597
                      "Mon Feb 22 17:47:02 UTC 2010")
1598
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1599
                      "2010-02-22 17:42:02")
1600

    
1601

    
1602
class TestGetX509CertValidity(testutils.GanetiTestCase):
1603
  def setUp(self):
1604
    testutils.GanetiTestCase.setUp(self)
1605

    
1606
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1607

    
1608
    # Test whether we have pyOpenSSL 0.7 or above
1609
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1610

    
1611
    if not self.pyopenssl0_7:
1612
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1613
                    " function correctly")
1614

    
1615
  def _LoadCert(self, name):
1616
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1617
                                           self._ReadTestData(name))
1618

    
1619
  def test(self):
1620
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1621
    if self.pyopenssl0_7:
1622
      self.assertEqual(validity, (1266919967, 1267524767))
1623
    else:
1624
      self.assertEqual(validity, (None, None))
1625

    
1626

    
1627
class TestSignX509Certificate(unittest.TestCase):
1628
  KEY = "My private key!"
1629
  KEY_OTHER = "Another key"
1630

    
1631
  def test(self):
1632
    # Generate certificate valid for 5 minutes
1633
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
1634

    
1635
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1636
                                           cert_pem)
1637

    
1638
    # No signature at all
1639
    self.assertRaises(errors.GenericError,
1640
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
1641

    
1642
    # Invalid input
1643
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1644
                      "", self.KEY)
1645
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1646
                      "X-Ganeti-Signature: \n", self.KEY)
1647
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1648
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
1649
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1650
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
1651
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1652
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
1653

    
1654
    # Invalid salt
1655
    for salt in list("-_@$,:;/\\ \t\n"):
1656
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
1657
                        cert_pem, self.KEY, "foo%sbar" % salt)
1658

    
1659
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
1660
                 utils.GenerateSecret(numbytes=4),
1661
                 utils.GenerateSecret(numbytes=16),
1662
                 "{123:456}".encode("hex")]:
1663
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
1664

    
1665
      self._Check(cert, salt, signed_pem)
1666

    
1667
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
1668
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
1669
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
1670
                               "lines----\n------ at\nthe end!"))
1671

    
1672
  def _Check(self, cert, salt, pem):
1673
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
1674
    self.assertEqual(salt, salt2)
1675
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
1676

    
1677
    # Other key
1678
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1679
                      pem, self.KEY_OTHER)
1680

    
1681

    
1682
class TestMakedirs(unittest.TestCase):
1683
  def setUp(self):
1684
    self.tmpdir = tempfile.mkdtemp()
1685

    
1686
  def tearDown(self):
1687
    shutil.rmtree(self.tmpdir)
1688

    
1689
  def testNonExisting(self):
1690
    path = PathJoin(self.tmpdir, "foo")
1691
    utils.Makedirs(path)
1692
    self.assert_(os.path.isdir(path))
1693

    
1694
  def testExisting(self):
1695
    path = PathJoin(self.tmpdir, "foo")
1696
    os.mkdir(path)
1697
    utils.Makedirs(path)
1698
    self.assert_(os.path.isdir(path))
1699

    
1700
  def testRecursiveNonExisting(self):
1701
    path = PathJoin(self.tmpdir, "foo/bar/baz")
1702
    utils.Makedirs(path)
1703
    self.assert_(os.path.isdir(path))
1704

    
1705
  def testRecursiveExisting(self):
1706
    path = PathJoin(self.tmpdir, "B/moo/xyz")
1707
    self.assertFalse(os.path.exists(path))
1708
    os.mkdir(PathJoin(self.tmpdir, "B"))
1709
    utils.Makedirs(path)
1710
    self.assert_(os.path.isdir(path))
1711

    
1712

    
1713
class TestReadLockedPidFile(unittest.TestCase):
1714
  def setUp(self):
1715
    self.tmpdir = tempfile.mkdtemp()
1716

    
1717
  def tearDown(self):
1718
    shutil.rmtree(self.tmpdir)
1719

    
1720
  def testNonExistent(self):
1721
    path = PathJoin(self.tmpdir, "nonexist")
1722
    self.assert_(utils.ReadLockedPidFile(path) is None)
1723

    
1724
  def testUnlocked(self):
1725
    path = PathJoin(self.tmpdir, "pid")
1726
    utils.WriteFile(path, data="123")
1727
    self.assert_(utils.ReadLockedPidFile(path) is None)
1728

    
1729
  def testLocked(self):
1730
    path = PathJoin(self.tmpdir, "pid")
1731
    utils.WriteFile(path, data="123")
1732

    
1733
    fl = utils.FileLock.Open(path)
1734
    try:
1735
      fl.Exclusive(blocking=True)
1736

    
1737
      self.assertEqual(utils.ReadLockedPidFile(path), 123)
1738
    finally:
1739
      fl.Close()
1740

    
1741
    self.assert_(utils.ReadLockedPidFile(path) is None)
1742

    
1743
  def testError(self):
1744
    path = PathJoin(self.tmpdir, "foobar", "pid")
1745
    utils.WriteFile(PathJoin(self.tmpdir, "foobar"), data="")
1746
    # open(2) should return ENOTDIR
1747
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
1748

    
1749

    
1750
class TestCertVerification(testutils.GanetiTestCase):
1751
  def setUp(self):
1752
    testutils.GanetiTestCase.setUp(self)
1753

    
1754
    self.tmpdir = tempfile.mkdtemp()
1755

    
1756
  def tearDown(self):
1757
    shutil.rmtree(self.tmpdir)
1758

    
1759
  def testVerifyCertificate(self):
1760
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
1761
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1762
                                           cert_pem)
1763

    
1764
    # Not checking return value as this certificate is expired
1765
    utils.VerifyX509Certificate(cert, 30, 7)
1766

    
1767

    
1768
class TestVerifyCertificateInner(unittest.TestCase):
1769
  def test(self):
1770
    vci = utils._VerifyCertificateInner
1771

    
1772
    # Valid
1773
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
1774
                     (None, None))
1775

    
1776
    # Not yet valid
1777
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
1778
    self.assertEqual(errcode, utils.CERT_WARNING)
1779

    
1780
    # Expiring soon
1781
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
1782
    self.assertEqual(errcode, utils.CERT_ERROR)
1783

    
1784
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
1785
    self.assertEqual(errcode, utils.CERT_WARNING)
1786

    
1787
    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
1788
    self.assertEqual(errcode, None)
1789

    
1790
    # Expired
1791
    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
1792
    self.assertEqual(errcode, utils.CERT_ERROR)
1793

    
1794
    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
1795
    self.assertEqual(errcode, utils.CERT_ERROR)
1796

    
1797
    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
1798
    self.assertEqual(errcode, utils.CERT_ERROR)
1799

    
1800
    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
1801
    self.assertEqual(errcode, utils.CERT_ERROR)
1802

    
1803

    
1804
class TestIgnoreSignals(unittest.TestCase):
1805
  """Test the IgnoreSignals decorator"""
1806

    
1807
  @staticmethod
1808
  def _Raise(exception):
1809
    raise exception
1810

    
1811
  @staticmethod
1812
  def _Return(rval):
1813
    return rval
1814

    
1815
  def testIgnoreSignals(self):
1816
    sock_err_intr = socket.error(errno.EINTR, "Message")
1817
    sock_err_inval = socket.error(errno.EINVAL, "Message")
1818

    
1819
    env_err_intr = EnvironmentError(errno.EINTR, "Message")
1820
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")
1821

    
1822
    self.assertRaises(socket.error, self._Raise, sock_err_intr)
1823
    self.assertRaises(socket.error, self._Raise, sock_err_inval)
1824
    self.assertRaises(EnvironmentError, self._Raise, env_err_intr)
1825
    self.assertRaises(EnvironmentError, self._Raise, env_err_inval)
1826

    
1827
    self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None)
1828
    self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None)
1829
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
1830
                      sock_err_inval)
1831
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
1832
                      env_err_inval)
1833

    
1834
    self.assertEquals(utils.IgnoreSignals(self._Return, True), True)
1835
    self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33)
1836

    
1837

    
1838
class TestEnsureDirs(unittest.TestCase):
1839
  """Tests for EnsureDirs"""
1840

    
1841
  def setUp(self):
1842
    self.dir = tempfile.mkdtemp()
1843
    self.old_umask = os.umask(0777)
1844

    
1845
  def testEnsureDirs(self):
1846
    utils.EnsureDirs([
1847
        (PathJoin(self.dir, "foo"), 0777),
1848
        (PathJoin(self.dir, "bar"), 0000),
1849
        ])
1850
    self.assertEquals(os.stat(PathJoin(self.dir, "foo"))[0] & 0777, 0777)
1851
    self.assertEquals(os.stat(PathJoin(self.dir, "bar"))[0] & 0777, 0000)
1852

    
1853
  def tearDown(self):
1854
    os.rmdir(PathJoin(self.dir, "foo"))
1855
    os.rmdir(PathJoin(self.dir, "bar"))
1856
    os.rmdir(self.dir)
1857
    os.umask(self.old_umask)
1858

    
1859

    
1860
class TestIgnoreProcessNotFound(unittest.TestCase):
1861
  @staticmethod
1862
  def _WritePid(fd):
1863
    os.write(fd, str(os.getpid()))
1864
    os.close(fd)
1865
    return True
1866

    
1867
  def test(self):
1868
    (pid_read_fd, pid_write_fd) = os.pipe()
1869

    
1870
    # Start short-lived process which writes its PID to pipe
1871
    self.assert_(utils.RunInSeparateProcess(self._WritePid, pid_write_fd))
1872
    os.close(pid_write_fd)
1873

    
1874
    # Read PID from pipe
1875
    pid = int(os.read(pid_read_fd, 1024))
1876
    os.close(pid_read_fd)
1877

    
1878
    # Try to send signal to process which exited recently
1879
    self.assertFalse(utils.IgnoreProcessNotFound(os.kill, pid, 0))
1880

    
1881

    
1882
class TestFindMatch(unittest.TestCase):
1883
  def test(self):
1884
    data = {
1885
      "aaaa": "Four A",
1886
      "bb": {"Two B": True},
1887
      re.compile(r"^x(foo|bar|bazX)([0-9]+)$"): (1, 2, 3),
1888
      }
1889

    
1890
    self.assertEqual(utils.FindMatch(data, "aaaa"), ("Four A", []))
1891
    self.assertEqual(utils.FindMatch(data, "bb"), ({"Two B": True}, []))
1892

    
1893
    for i in ["foo", "bar", "bazX"]:
1894
      for j in range(1, 100, 7):
1895
        self.assertEqual(utils.FindMatch(data, "x%s%s" % (i, j)),
1896
                         ((1, 2, 3), [i, str(j)]))
1897

    
1898
  def testNoMatch(self):
1899
    self.assert_(utils.FindMatch({}, "") is None)
1900
    self.assert_(utils.FindMatch({}, "foo") is None)
1901
    self.assert_(utils.FindMatch({}, 1234) is None)
1902

    
1903
    data = {
1904
      "X": "Hello World",
1905
      re.compile("^(something)$"): "Hello World",
1906
      }
1907

    
1908
    self.assert_(utils.FindMatch(data, "") is None)
1909
    self.assert_(utils.FindMatch(data, "Hello World") is None)
1910

    
1911

    
1912
class TestFileID(testutils.GanetiTestCase):
1913
  def testEquality(self):
1914
    name = self._CreateTempFile()
1915
    oldi = utils.GetFileID(path=name)
1916
    self.failUnless(utils.VerifyFileID(oldi, oldi))
1917

    
1918
  def testUpdate(self):
1919
    name = self._CreateTempFile()
1920
    oldi = utils.GetFileID(path=name)
1921
    os.utime(name, None)
1922
    fd = os.open(name, os.O_RDWR)
1923
    try:
1924
      newi = utils.GetFileID(fd=fd)
1925
      self.failUnless(utils.VerifyFileID(oldi, newi))
1926
      self.failUnless(utils.VerifyFileID(newi, oldi))
1927
    finally:
1928
      os.close(fd)
1929

    
1930
  def testWriteFile(self):
1931
    name = self._CreateTempFile()
1932
    oldi = utils.GetFileID(path=name)
1933
    mtime = oldi[2]
1934
    os.utime(name, (mtime + 10, mtime + 10))
1935
    self.assertRaises(errors.LockError, utils.SafeWriteFile, name,
1936
                      oldi, data="")
1937
    os.utime(name, (mtime - 10, mtime - 10))
1938
    utils.SafeWriteFile(name, oldi, data="")
1939
    oldi = utils.GetFileID(path=name)
1940
    mtime = oldi[2]
1941
    os.utime(name, (mtime + 10, mtime + 10))
1942
    # this doesn't raise, since we passed None
1943
    utils.SafeWriteFile(name, None, data="")
1944

    
1945
  def testError(self):
1946
    t = tempfile.NamedTemporaryFile()
1947
    self.assertRaises(errors.ProgrammerError, utils.GetFileID,
1948
                      path=t.name, fd=t.fileno())
1949

    
1950

    
1951
class TimeMock:
1952
  def __init__(self, values):
1953
    self.values = values
1954

    
1955
  def __call__(self):
1956
    return self.values.pop(0)
1957

    
1958

    
1959
class TestRunningTimeout(unittest.TestCase):
1960
  def setUp(self):
1961
    self.time_fn = TimeMock([0.0, 0.3, 4.6, 6.5])
1962

    
1963
  def testRemainingFloat(self):
1964
    timeout = utils.RunningTimeout(5.0, True, _time_fn=self.time_fn)
1965
    self.assertAlmostEqual(timeout.Remaining(), 4.7)
1966
    self.assertAlmostEqual(timeout.Remaining(), 0.4)
1967
    self.assertAlmostEqual(timeout.Remaining(), -1.5)
1968

    
1969
  def testRemaining(self):
1970
    self.time_fn = TimeMock([0, 2, 4, 5, 6])
1971
    timeout = utils.RunningTimeout(5, True, _time_fn=self.time_fn)
1972
    self.assertEqual(timeout.Remaining(), 3)
1973
    self.assertEqual(timeout.Remaining(), 1)
1974
    self.assertEqual(timeout.Remaining(), 0)
1975
    self.assertEqual(timeout.Remaining(), -1)
1976

    
1977
  def testRemainingNonNegative(self):
1978
    timeout = utils.RunningTimeout(5.0, False, _time_fn=self.time_fn)
1979
    self.assertAlmostEqual(timeout.Remaining(), 4.7)
1980
    self.assertAlmostEqual(timeout.Remaining(), 0.4)
1981
    self.assertEqual(timeout.Remaining(), 0.0)
1982

    
1983
  def testNegativeTimeout(self):
1984
    self.assertRaises(ValueError, utils.RunningTimeout, -1.0, True)
1985

    
1986

    
1987
class TestTryConvert(unittest.TestCase):
1988
  def test(self):
1989
    for src, fn, result in [
1990
      ("1", int, 1),
1991
      ("a", int, "a"),
1992
      ("", bool, False),
1993
      ("a", bool, True),
1994
      ]:
1995
      self.assertEqual(utils.TryConvert(fn, src), result)
1996

    
1997

    
1998
class TestIsValidShellParam(unittest.TestCase):
1999
  def test(self):
2000
    for val, result in [
2001
      ("abc", True),
2002
      ("ab;cd", False),
2003
      ]:
2004
      self.assertEqual(utils.IsValidShellParam(val), result)
2005

    
2006

    
2007
class TestBuildShellCmd(unittest.TestCase):
2008
  def test(self):
2009
    self.assertRaises(errors.ProgrammerError, utils.BuildShellCmd,
2010
                      "ls %s", "ab;cd")
2011
    self.assertEqual(utils.BuildShellCmd("ls %s", "ab"), "ls ab")
2012

    
2013

    
2014
class TestWriteFile(unittest.TestCase):
2015
  def setUp(self):
2016
    self.tfile = tempfile.NamedTemporaryFile()
2017
    self.did_pre = False
2018
    self.did_post = False
2019
    self.did_write = False
2020

    
2021
  def markPre(self, fd):
2022
    self.did_pre = True
2023

    
2024
  def markPost(self, fd):
2025
    self.did_post = True
2026

    
2027
  def markWrite(self, fd):
2028
    self.did_write = True
2029

    
2030
  def testWrite(self):
2031
    data = "abc"
2032
    utils.WriteFile(self.tfile.name, data=data)
2033
    self.assertEqual(utils.ReadFile(self.tfile.name), data)
2034

    
2035
  def testErrors(self):
2036
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
2037
                      self.tfile.name, data="test", fn=lambda fd: None)
2038
    self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name)
2039
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
2040
                      self.tfile.name, data="test", atime=0)
2041

    
2042
  def testCalls(self):
2043
    utils.WriteFile(self.tfile.name, fn=self.markWrite,
2044
                    prewrite=self.markPre, postwrite=self.markPost)
2045
    self.assertTrue(self.did_pre)
2046
    self.assertTrue(self.did_post)
2047
    self.assertTrue(self.did_write)
2048

    
2049
  def testDryRun(self):
2050
    orig = "abc"
2051
    self.tfile.write(orig)
2052
    self.tfile.flush()
2053
    utils.WriteFile(self.tfile.name, data="hello", dry_run=True)
2054
    self.assertEqual(utils.ReadFile(self.tfile.name), orig)
2055

    
2056
  def testTimes(self):
2057
    f = self.tfile.name
2058
    for at, mt in [(0, 0), (1000, 1000), (2000, 3000),
2059
                   (int(time.time()), 5000)]:
2060
      utils.WriteFile(f, data="hello", atime=at, mtime=mt)
2061
      st = os.stat(f)
2062
      self.assertEqual(st.st_atime, at)
2063
      self.assertEqual(st.st_mtime, mt)
2064

    
2065

    
2066
  def testNoClose(self):
2067
    data = "hello"
2068
    self.assertEqual(utils.WriteFile(self.tfile.name, data="abc"), None)
2069
    fd = utils.WriteFile(self.tfile.name, data=data, close=False)
2070
    try:
2071
      os.lseek(fd, 0, 0)
2072
      self.assertEqual(os.read(fd, 4096), data)
2073
    finally:
2074
      os.close(fd)
2075

    
2076

    
2077
if __name__ == '__main__':
2078
  testutils.GanetiTestProgram()