utils: Move wrappers into separate file
[ganeti-local] / test / ganeti.utils_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007, 2010, 2011 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the utils module"""
23
24 import distutils.version
25 import errno
26 import fcntl
27 import glob
28 import os
29 import os.path
30 import re
31 import shutil
32 import signal
33 import socket
34 import stat
35 import string
36 import tempfile
37 import time
38 import unittest
39 import warnings
40 import OpenSSL
41 import random
42 import operator
43
44 import testutils
45 from ganeti import constants
46 from ganeti import compat
47 from ganeti import utils
48 from ganeti import errors
49 from ganeti.utils import RunCmd, RemoveFile, \
50      ListVisibleFiles, FirstFree, \
51      TailFile, RunParts, PathJoin, \
52      ReadOneLineFile, SetEtcHostsEntry, RemoveEtcHostsEntry
53
54
class TestIsProcessAlive(unittest.TestCase):
  """Tests for utils.IsProcessAlive"""

  def testExists(self):
    """The process running the test must be reported as alive"""
    mypid = os.getpid()
    self.assertTrue(utils.IsProcessAlive(mypid), "can't find myself running")

  def testNotExisting(self):
    """A child which has exited and been reaped must not be reported alive"""
    # os.fork() reports failure by raising OSError; it never returns a
    # negative value, hence no explicit error branch is needed here
    pid_non_existing = os.fork()
    if pid_non_existing == 0:
      # Child: terminate right away, skipping any cleanup handlers
      os._exit(0)

    # Parent: reap the child before asking whether it is still alive
    os.waitpid(pid_non_existing, 0)
    self.assertFalse(utils.IsProcessAlive(pid_non_existing),
                     "nonexisting process detected")
71
72
class TestGetProcStatusPath(unittest.TestCase):
  """Checks on the path returned by utils._GetProcStatusPath"""

  def test(self):
    """The PID is a path component and different PIDs give different paths"""
    path = utils._GetProcStatusPath(1234)
    self.assert_("/1234/" in path)

    first = utils._GetProcStatusPath(1)
    second = utils._GetProcStatusPath(2)
    self.assertNotEqual(first, second)
78
79
class TestIsProcessHandlingSignal(unittest.TestCase):
  """Tests for utils.IsProcessHandlingSignal and its parsing helpers"""

  def setUp(self):
    """Create a scratch directory for fake status files"""
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    """Remove the scratch directory and everything in it"""
    shutil.rmtree(self.tmpdir)

  def testParseSigsetT(self):
    """Hexadecimal signal masks must be converted to sets of signal numbers"""
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
                     set([2, 10, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
                     set([2, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
                     set([2, 28, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
                          24, 25, 26, 28, 31]))
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))

  def testGetProcStatusField(self):
    """Field values must be returned with surrounding whitespace removed"""
    for field in ["SigCgt", "Name", "FDSize"]:
      for value in ["", "0", "cat", "  1234 KB"]:
        # The field under test is placed between two unrelated lines
        pstatus = "\n".join([
          "VmPeak: 999 kB",
          "%s: %s" % (field, value),
          "TracerPid: 0",
          ])
        result = utils._GetProcStatusField(pstatus, field)
        self.assertEqual(result, value.strip())

  def test(self):
    """A status file whose SigCgt mask includes signal 10 reports it handled"""
    sp = PathJoin(self.tmpdir, "status")

    utils.WriteFile(sp, data="\n".join([
      "Name:   bash",
      "State:  S (sleeping)",
      "SleepAVG:       98%",
      "Pid:    22250",
      "PPid:   10858",
      "TracerPid:      0",
      "SigBlk: 0000000000010000",
      "SigIgn: 0000000000384004",
      "SigCgt: 000000004b813efb",
      "CapEff: 0000000000000000",
      ]))

    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))

  def testNoSigCgt(self):
    """A status file lacking the SigCgt field must raise RuntimeError"""
    sp = PathJoin(self.tmpdir, "status")

    utils.WriteFile(sp, data="\n".join([
      "Name:   bash",
      ]))

    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
                      1234, 10, status_path=sp)

  def testNoSuchFile(self):
    """A missing status file must yield False instead of an error"""
    sp = PathJoin(self.tmpdir, "notexist")

    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))

  @staticmethod
  def _TestRealProcess():
    """Check the current process with various SIGUSR1 dispositions.

    Raises an exception naming the first wrong result; returns True when all
    checks pass.

    """
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is not handled when it should be")

    # An ignored signal is not a handled one; the message has to describe
    # the failure this check actually detects
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    return True

  def testRealProcess(self):
    """Run the live checks via utils.RunInSeparateProcess"""
    # The checks replace the SIGUSR1 handler several times, hence they are
    # not run directly in the test runner's own process
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
169
170
class TestPidFileFunctions(unittest.TestCase):
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""

  def setUp(self):
    """Create a scratch directory and a helper building PID file paths"""
    self.dir = tempfile.mkdtemp()
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)

  def tearDown(self):
    """Remove all files left in the scratch directory, then the directory"""
    for name in os.listdir(self.dir):
      os.unlink(os.path.join(self.dir, name))
    os.rmdir(self.dir)

  def testPidFileFunctions(self):
    """Write, read, lock and remove a PID file for the current process"""
    pid_file = self.f_dpn('test')
    fd = utils.WritePidFile(self.f_dpn('test'))
    self.failUnless(os.path.exists(pid_file),
                    "PID file should have been created")
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, os.getpid())
    self.failUnless(utils.IsProcessAlive(read_pid))
    # While the descriptor returned above is open, a second write must fail
    self.failUnlessRaises(errors.LockError, utils.WritePidFile,
                          self.f_dpn('test'))
    os.close(fd)
    utils.RemovePidFile(self.f_dpn("test"))
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for missing pid file")
    fh = open(pid_file, "w")
    try:
      fh.write("blah\n")
    finally:
      # Close the handle even if the write fails
      fh.close()
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for invalid pid file")
    # but now, even with the file existing, we should be able to lock it
    fd = utils.WritePidFile(self.f_dpn('test'))
    os.close(fd)
    utils.RemovePidFile(self.f_dpn("test"))
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")

  def testKill(self):
    """Kill a forked child which has written its own PID file"""
    pid_file = self.f_dpn('child')
    r_fd, w_fd = os.pipe()
    new_pid = os.fork()
    if new_pid == 0: #child
      try:
        os.close(r_fd)
        utils.WritePidFile(self.f_dpn('child'))
        os.write(w_fd, 'a')
        signal.pause()
      finally:
        # Never fall back into the test runner from within the child
        os._exit(0)

    # else we are in the parent
    # Drop our copy of the write end; should the child die before writing,
    # the read below then returns on end-of-file instead of blocking forever
    os.close(w_fd)
    try:
      # wait until the child has written the pid file
      os.read(r_fd, 1)
    finally:
      os.close(r_fd)

    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, new_pid)
    self.failUnless(utils.IsProcessAlive(new_pid))
    utils.KillProcess(new_pid, waitpid=True)
    self.failIf(utils.IsProcessAlive(new_pid))
    utils.RemovePidFile(self.f_dpn('child'))
    self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)
231
232
class TestRunCmd(testutils.GanetiTestCase):
  """Testing case for the RunCmd function"""

  def setUp(self):
    """Prepare a unique marker string, an output file and a FIFO"""
    testutils.GanetiTestCase.setUp(self)
    self.magic = time.ctime() + " ganeti test"
    self.fname = self._CreateTempFile()
    # Nothing in these tests writes to the FIFO; the timeout tests read from
    # it to get a command which does not finish on its own
    self.fifo_tmpdir = tempfile.mkdtemp()
    self.fifo_file = os.path.join(self.fifo_tmpdir, "ganeti_test_fifo")
    os.mkfifo(self.fifo_file)

  def tearDown(self):
    """Remove the FIFO directory and run the base class cleanup"""
    shutil.rmtree(self.fifo_tmpdir)
    testutils.GanetiTestCase.tearDown(self)

  def testOk(self):
    """Test successful exit code"""
    result = RunCmd("/bin/sh -c 'exit 0'")
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.output, "")

  def testFail(self):
    """Test fail exit code"""
    result = RunCmd("/bin/sh -c 'exit 1'")
    self.assertEqual(result.exit_code, 1)
    self.assertEqual(result.output, "")

  def testStdout(self):
    """Test standard output"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.stdout, self.magic)
    # With an output file the text must end up in the file, not the result
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testStderr(self):
    """Test standard error"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
    self.assertEqual(result.stderr, self.magic)
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testCombined(self):
    """Test combined output"""
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
    expected = "A" + self.magic + "B" + self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.output, expected)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, expected)

  def testSignal(self):
    """Test signal"""
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
    self.assertEqual(result.signal, 15)
    self.assertEqual(result.output, "")

  def testTimeoutClean(self):
    """Test a command which exits cleanly upon SIGTERM after the timeout"""
    cmd = "trap 'exit 0' TERM; read < %s" % self.fifo_file
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
    self.assertEqual(result.exit_code, 0)

  def testTimeoutKill(self):
    """Test a command which ignores SIGTERM and must end by SIGKILL"""
    cmd = ["/bin/sh", "-c", "trap '' TERM; read < %s" % self.fifo_file]
    timeout = 0.2
    out, err, status, ta = utils._RunCmdPipe(cmd, {}, False, "/", False,
                                             timeout, _linger_timeout=0.2)
    # A negative status is the negated number of the terminating signal
    self.assert_(status < 0)
    self.assertEqual(-status, signal.SIGKILL)

  def testTimeoutOutputAfterTerm(self):
    """Test that output written from the SIGTERM handler is still captured"""
    cmd = "trap 'echo sigtermed; exit 1' TERM; read < %s" % self.fifo_file
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
    self.assert_(result.failed)
    self.assertEqual(result.stdout, "sigtermed\n")

  def testListRun(self):
    """Test list runs"""
    result = RunCmd(["true"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 1)
    result = RunCmd(["echo", "-n", self.magic])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.stdout, self.magic)

  def testFileEmptyOutput(self):
    """Test file output"""
    result = RunCmd(["true"], output=self.fname)
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertFileContent(self.fname, "")

  def testLang(self):
    """Test locale environment"""
    old_env = os.environ.copy()
    try:
      os.environ["LANG"] = "en_US.UTF-8"
      os.environ["LC_ALL"] = "en_US.UTF-8"
      result = RunCmd(["locale"])
      for line in result.output.splitlines():
        key, value = line.split("=", 1)
        # Ignore these variables, they're overridden by LC_ALL
        if key == "LANG" or key == "LANGUAGE":
          continue
        self.failIf(value and value != "C" and value != '"C"',
            "Variable %s is set to the invalid value '%s'" % (key, value))
    finally:
      # Restore in place: rebinding os.environ to the copy would replace the
      # special mapping by a plain dictionary and leave the variables set
      # above in the real process environment
      os.environ.clear()
      os.environ.update(old_env)

  def testDefaultCwd(self):
    """Test default working directory"""
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")

  def testCwd(self):
    """Test explicitly given working directory"""
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
    cwd = os.getcwd()
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)

  def testResetEnv(self):
    """Test environment reset functionality"""
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")

  def testNoFork(self):
    """Test that nofork raise an error"""
    self.assertFalse(utils._no_fork)
    utils.DisableFork()
    try:
      self.assertTrue(utils._no_fork)
      self.assertRaises(errors.ProgrammerError, RunCmd, ["true"])
    finally:
      # Re-enable forking for the tests running after this one
      utils._no_fork = False

  def testWrongParams(self):
    """Test wrong parameters"""
    self.assertRaises(errors.ProgrammerError, RunCmd, ["true"],
                      output="/dev/null", interactive=True)
381
382
class TestRunParts(testutils.GanetiTestCase):
  """Unit tests covering the RunParts function"""

  def setUp(self):
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")

  def tearDown(self):
    shutil.rmtree(self.rundir)

  def _Create(self, name, data, executable):
    """Write a file into the run directory and return its full path.

    If C{executable} is true, the file is made readable and executable for
    its owner after it has been written.

    """
    path = os.path.join(self.rundir, name)
    utils.WriteFile(path, data=data)
    if executable:
      os.chmod(path, stat.S_IREAD | stat.S_IEXEC)
    return path

  def _Run(self):
    """Invoke RunParts on the run directory with a reset environment"""
    return RunParts(self.rundir, reset_env=True)

  def testEmpty(self):
    """An empty directory gives an empty result list"""
    self.failUnlessEqual(self._Run(), [])

  def testSkipWrongName(self):
    """Executable files with an unsuitable name are reported as skipped"""
    path = self._Create("00test.dot", "", True)
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.failUnlessEqual(self._Run(), expected)

  def testSkipNonExec(self):
    """Files without the executable bit are reported as skipped"""
    path = self._Create("00test", "", False)
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.failUnlessEqual(self._Run(), expected)

  def testError(self):
    """An empty executable file is reported as an error"""
    path = self._Create("00test", "", True)
    (relname, status, error) = self._Run()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
    self.failUnless(error)

  def testSorted(self):
    """Results come back ordered by file name"""
    # Deliberately created out of order
    files = [self._Create(name, "", False)
             for name in ["64test", "00test", "42test"]]

    results = self._Run()

    for path in sorted(files):
      self.failUnlessEqual(os.path.basename(path), results.pop(0)[0])

  def testOk(self):
    """A working script is run and its output captured"""
    path = self._Create("00test", "#!/bin/sh\n\necho -n ciao", True)
    (relname, status, runresult) = self._Run()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.stdout, "ciao")

  def testRunFail(self):
    """A script exiting non-zero counts as run, with a failed result"""
    path = self._Create("00test", "#!/bin/sh\n\nexit 1", True)
    (relname, status, runresult) = self._Run()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.exit_code, 1)
    self.failUnless(runresult.failed)

  def testRunMix(self):
    """A directory mixing failing, skipped, broken and working files"""
    # Names are chosen so that creation order equals sorted order
    failing = self._Create("00test", "#!/bin/sh\n\nexit 1", True)
    skipped = self._Create("42test", "", False)
    broken = self._Create("64test", "", True)
    working = self._Create("99test", "#!/bin/sh\n\necho -n ciao", True)

    results = self._Run()

    # Script with a non-zero exit code
    (relname, status, runresult) = results[0]
    self.failUnlessEqual(relname, os.path.basename(failing))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.exit_code, 1)
    self.failUnless(runresult.failed)

    # Non-executable file
    (relname, status, runresult) = results[1]
    self.failUnlessEqual(relname, os.path.basename(skipped))
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
    self.failUnlessEqual(runresult, None)

    # Executable file which cannot be run
    (relname, status, runresult) = results[2]
    self.failUnlessEqual(relname, os.path.basename(broken))
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
    self.failUnless(runresult)

    # Working script
    (relname, status, runresult) = results[3]
    self.failUnlessEqual(relname, os.path.basename(working))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.output, "ciao")
    self.failUnlessEqual(runresult.exit_code, 0)
    self.failUnless(not runresult.failed)

  def testMissingDirectory(self):
    """A directory which does not exist gives an empty result list"""
    missing = utils.PathJoin(self.rundir, "no/such/directory")
    self.assertEqual(RunParts(missing), [])
511
512
class TestStartDaemon(testutils.GanetiTestCase):
  """Tests for utils.StartDaemon"""

  def setUp(self):
    """Create a scratch directory and the path of an output file in it"""
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
    self.tmpfile = os.path.join(self.tmpdir, "test")

  def tearDown(self):
    """Remove the scratch directory and everything in it"""
    shutil.rmtree(self.tmpdir)

  def testShell(self):
    """Shell command string redirecting its own output"""
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testShellOutput(self):
    """Shell command string with an output file"""
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testNoShellNoOutput(self):
    """Argument list without an output file; must simply not raise"""
    utils.StartDaemon(["pwd"])

  def testNoShellNoOutputTouch(self):
    """Argument list without an output file, checked via a created file"""
    testfile = os.path.join(self.tmpdir, "check")
    self.failIf(os.path.exists(testfile))
    utils.StartDaemon(["touch", testfile])
    self._wait(testfile, 60.0, "")

  def testNoShellOutput(self):
    """Argument list with an output file; "/" is the expected directory"""
    utils.StartDaemon(["pwd"], output=self.tmpfile)
    self._wait(self.tmpfile, 60.0, "/")

  def testNoShellOutputCwd(self):
    """Argument list with an explicit working directory"""
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
    self._wait(self.tmpfile, 60.0, os.getcwd())

  def testShellEnv(self):
    """Environment variables are visible to a shell command string"""
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
                      env={ "GNT_TEST_VAR": "Hello World", })
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testNoShellEnv(self):
    """Environment variables are visible to an argument list command"""
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
                      env={ "GNT_TEST_VAR": "Hello World", })
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testOutputFd(self):
    """Output given as an already opened file descriptor"""
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
    try:
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
    finally:
      os.close(fd)
    self._wait(self.tmpfile, 60.0, os.getcwd())

  def testPid(self):
    """The returned PID equals the one seen by the started shell"""
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
    self._wait(self.tmpfile, 60.0, str(pid))

  def testPidFile(self):
    """The PID file is locked and contains the PID of the running daemon"""
    pidfile = os.path.join(self.tmpdir, "pid")

    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
                            output=self.tmpfile)
    try:
      fd = os.open(pidfile, os.O_RDONLY)
      try:
        # Check file is locked
        self.assertRaises(errors.LockError, utils.LockFile, fd)

        pidtext = os.read(fd, 100)
      finally:
        os.close(fd)

      self.assertEqual(int(pidtext.strip()), pid)

      self.assert_(utils.IsProcessAlive(pid))
    finally:
      # No matter what happens, kill daemon
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
      self.failIf(utils.IsProcessAlive(pid))

    self.assertEqual(utils.ReadFile(self.tmpfile), "")

  def _wait(self, path, timeout, expected):
    """Poll until C{path} is a file whose stripped content is C{expected}.

    The test fails if this does not happen within C{timeout} seconds.

    """
    # Due to the asynchronous nature of daemon processes, polling is necessary.
    # A timeout makes sure the test doesn't hang forever.
    def _CheckFile():
      if not (os.path.isfile(path) and
              utils.ReadFile(path).strip() == expected):
        raise utils.RetryAgain()

    try:
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
    except utils.RetryTimeout:
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
                " didn't write the correct output" % timeout)

  def testError(self):
    """Invalid programs, paths and parameter combinations must raise"""
    # Each failure mode is checked once
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"])
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"],
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"],
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))

    # Passing both an output file and an output descriptor is rejected
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
    try:
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
                        ["./does-NOT-EXIST/here/0123456789"],
                        output=self.tmpfile, output_fd=fd)
    finally:
      os.close(fd)
628
629
class TestRemoveFile(unittest.TestCase):
  """Test case for the RemoveFile function"""

  def setUp(self):
    """Create a temp dir and file for each case"""
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
    os.close(fd)

  def tearDown(self):
    """Remove the temp file, if still there, and the temp dir"""
    if os.path.exists(self.tmpfile):
      os.unlink(self.tmpfile)
    os.rmdir(self.tmpdir)

  def testIgnoreDirs(self):
    """Test that RemoveFile() ignores directories"""
    self.assertEqual(None, RemoveFile(self.tmpdir))

  def testIgnoreNotExisting(self):
    """Test that RemoveFile() ignores non-existing files"""
    RemoveFile(self.tmpfile)
    # The file is gone by now; removing it once more must not raise
    RemoveFile(self.tmpfile)

  def testRemoveFile(self):
    """Test that RemoveFile does remove a file"""
    RemoveFile(self.tmpfile)
    if os.path.exists(self.tmpfile):
      self.fail("File '%s' not removed" % self.tmpfile)

  def testRemoveSymlink(self):
    """Test that RemoveFile does remove symlinks"""
    symlink = self.tmpdir + "/symlink"

    # Dangling link: os.path.exists() follows links and is always false for
    # a broken one, hence only lexists() can tell whether the link is gone
    os.symlink("no-such-file", symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)

    # Link pointing to an existing file
    os.symlink(self.tmpfile, symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
670
671
class TestRemoveDir(unittest.TestCase):
  """Tests exercising utils.RemoveDir"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    # The directory may already have been removed by the test itself
    try:
      shutil.rmtree(self.tmpdir)
    except EnvironmentError:
      pass

  def testEmptyDir(self):
    """An empty directory is removed"""
    utils.RemoveDir(self.tmpdir)
    still_there = os.path.isdir(self.tmpdir)
    self.assertFalse(still_there)

  def testNonEmptyDir(self):
    """A directory containing a file makes RemoveDir raise"""
    self.tmpfile = os.path.join(self.tmpdir, "test1")
    # Create an empty file inside the directory
    fh = open(self.tmpfile, "w")
    fh.close()
    self.assertRaises(EnvironmentError, utils.RemoveDir, self.tmpdir)
690
691
class TestRename(unittest.TestCase):
  """Tests exercising utils.RenameFile"""

  def setUp(self):
    """Set up a scratch directory containing one empty file"""
    self.tmpdir = tempfile.mkdtemp()
    self.tmpfile = os.path.join(self.tmpdir, "test1")

    # Create the (empty) file which is going to be renamed
    fh = open(self.tmpfile, "w")
    fh.close()

  def tearDown(self):
    """Delete the scratch directory"""
    shutil.rmtree(self.tmpdir)

  def testSimpleRename1(self):
    """Rename within the same directory, default arguments"""
    target = os.path.join(self.tmpdir, "xyz")
    utils.RenameFile(self.tmpfile, target)
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))

  def testSimpleRename2(self):
    """Rename within the same directory, mkdir enabled"""
    target = os.path.join(self.tmpdir, "xyz")
    utils.RenameFile(self.tmpfile, target, mkdir=True)
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))

  def testRenameMkdir(self):
    """Rename into directories which do not exist yet"""
    # One missing directory level
    first = os.path.join(self.tmpdir, "test/xyz")
    utils.RenameFile(self.tmpfile, first, mkdir=True)
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))

    # Two missing directory levels
    second = os.path.join(self.tmpdir, "test/foo/bar/baz")
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"), second,
                     mkdir=True)
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
731
732
class TestReadFile(testutils.GanetiTestCase):
  """Tests exercising utils.ReadFile"""

  def _CheckData(self, data, length, digest):
    """Verify length and MD5 hex digest of C{data}"""
    self.assertEqual(len(data), length)

    md5 = compat.md5_hash()
    md5.update(data)
    self.assertEqual(md5.hexdigest(), digest)

  def testReadAll(self):
    """Read a whole test data file"""
    path = self._TestDataFilename("cert1.pem")
    self._CheckData(utils.ReadFile(path), 814,
                    "a491efb3efe56a0535f924d5f8680fd4")

  def testReadSize(self):
    """Read only the first 100 bytes of a test data file"""
    path = self._TestDataFilename("cert1.pem")
    self._CheckData(utils.ReadFile(path, size=100), 100,
                    "893772354e4e690b9efd073eed433ce7")

  def testError(self):
    """A path which cannot exist raises EnvironmentError"""
    bad_path = "/dev/null/does-not-exist"
    self.assertRaises(EnvironmentError, utils.ReadFile, bad_path)
755
756
class TestReadOneLineFile(testutils.GanetiTestCase):
  """Tests exercising ReadOneLineFile in strict and non-strict mode"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

  def testDefault(self):
    """Default arguments return the first line of a multi-line file"""
    line = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
    self.assertEqual(len(line), 27)
    self.assertEqual(line, "-----BEGIN CERTIFICATE-----")

  def testNotStrict(self):
    """Non-strict mode returns the first line of a multi-line file"""
    line = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
    self.assertEqual(len(line), 27)
    self.assertEqual(line, "-----BEGIN CERTIFICATE-----")

  def testStrictFailure(self):
    """Strict mode rejects a multi-line file"""
    path = self._TestDataFilename("cert1.pem")
    self.assertRaises(errors.GenericError, ReadOneLineFile, path, strict=True)

  def testLongLine(self):
    """A long line without newline is returned unchanged in both modes"""
    content = (1024 * "Hello World! ")
    path = self._CreateTempFile()
    utils.WriteFile(path, data=content)
    strict_result = ReadOneLineFile(path, strict=True)
    lax_result = ReadOneLineFile(path, strict=False)
    self.assertEqual(content, strict_result)
    self.assertEqual(content, lax_result)

  def testNewline(self):
    """A trailing newline, if any, is not part of the result"""
    path = self._CreateTempFile()
    expected = "myline"
    for newline in ["", "\n", "\r\n"]:
      utils.WriteFile(path, data="%s%s" % (expected, newline))
      lax_result = ReadOneLineFile(path, strict=False)
      self.assertEqual(expected, lax_result)
      strict_result = ReadOneLineFile(path, strict=True)
      self.assertEqual(expected, strict_result)

  def testWhitespaceAndMultipleLines(self):
    """Trailing whitespace is kept; several lines fail only in strict mode"""
    path = self._CreateTempFile()
    for newline in ["", "\n", "\r\n"]:
      for space in [" ", "\t", "\t\t  \t", "\t "]:
        content = (1024 * ("Foo bar baz %s%s" % (space, newline)))
        utils.WriteFile(path, data=content)
        lax_result = ReadOneLineFile(path, strict=False)

        if not newline:
          # Without newlines the whole content is one single line
          strict_result = ReadOneLineFile(path, strict=True)
          self.assertEqual(content, strict_result)
          self.assertEqual(content, lax_result)
          continue

        # With newlines there are 1024 lines in the file
        self.assert_(set("\r\n") & set(content))
        self.assertRaises(errors.GenericError, ReadOneLineFile,
                          path, strict=True)
        first_len = len("Foo bar baz ") + len(space)
        self.assertEqual(len(lax_result), first_len)
        self.assertEqual(lax_result, content[:first_len])
        self.assertFalse(set("\r\n") & set(lax_result))

  def testEmptylines(self):
    """Empty lines are skipped; a second non-empty line fails strict mode"""
    path = self._CreateTempFile()
    expected = "myline"
    for newline in ["\n", "\r\n"]:
      for other in ["", "otherline"]:
        content = "%s%s%s%s%s%s" % (newline, newline, expected, newline,
                                    other, newline)
        utils.WriteFile(path, data=content)
        self.assert_(set("\r\n") & set(content))
        lax_result = ReadOneLineFile(path, strict=False)
        self.assertEqual(expected, lax_result)

        if not other:
          strict_result = ReadOneLineFile(path, strict=True)
          self.assertEqual(expected, strict_result)
        else:
          self.assertRaises(errors.GenericError, ReadOneLineFile,
                            path, strict=True)

  def testEmptyfile(self):
    """An empty file is rejected"""
    path = self._CreateTempFile()
    self.assertRaises(errors.GenericError, ReadOneLineFile, path)
836
837
class TestTimestampForFilename(unittest.TestCase):
  """Test case for the TimestampForFilename function"""

  def test(self):
    """Checks that timestamps contain neither "." nor ":"."""
    for char in (".", ":"):
      # A new timestamp is generated for each character checked
      self.failIf(char in utils.TimestampForFilename())
842
843
class TestCreateBackup(testutils.GanetiTestCase):
  """Test case for the CreateBackup function"""

  def setUp(self):
    """Sets up the base class and creates a scratch directory."""
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    """Tears down the base class, then removes the scratch directory."""
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def testEmpty(self):
    """Backs up an empty file three times and checks a FIFO is refused."""
    filename = PathJoin(self.tmpdir, "config.data")
    pattern = "%s*" % filename
    utils.WriteFile(filename, data="")

    first_backup = utils.CreateBackup(filename)
    self.assertFileContent(first_backup, "")
    self.failUnlessEqual(len(glob.glob(pattern)), 2)

    # Each further backup must add exactly one file matching the pattern
    for expected in (3, 4):
      utils.CreateBackup(filename)
      self.failUnlessEqual(len(glob.glob(pattern)), expected)

    fifoname = PathJoin(self.tmpdir, "fifo")
    os.mkfifo(fifoname)
    self.failUnlessRaises(errors.ProgrammerError, utils.CreateBackup,
                          fifoname)

  def testContent(self):
    """Checks backup content and backup count for various payloads."""
    payloads = ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]
    filename = PathJoin(self.tmpdir, "test.data_")
    pattern = "%s*" % filename
    backups = 0

    for payload in payloads:
      for factor in (1, 2, 10, 127):
        testdata = payload * factor

        utils.WriteFile(filename, data=testdata)
        self.assertFileContent(filename, testdata)

        for _ in range(3):
          backup = utils.CreateBackup(filename)
          backups += 1
          self.assertFileContent(backup, testdata)
          # The backed up file matches the pattern too, hence the "+ 1"
          self.failUnlessEqual(len(glob.glob(pattern)), backups + 1)
885
886
class TestParseCpuMask(unittest.TestCase):
  """Test case for the ParseCpuMask function."""

  def testWellFormed(self):
    """Checks the CPU lists returned for well-formed masks."""
    cases = [
      ("", []),
      ("1", [1]),
      ("0-2,4,5-5", [0, 1, 2, 4, 5]),
      ]

    for (mask, expected) in cases:
      self.failUnlessEqual(utils.ParseCpuMask(mask), expected)

  def testInvalidInput(self):
    """Checks that malformed masks raise errors.ParseError."""
    for mask in ("garbage", "0,", "0-1-2", "2-1", "1-a"):
      self.failUnlessRaises(errors.ParseError, utils.ParseCpuMask, mask)
898
899
class TestSshKeys(testutils.GanetiTestCase):
  """Test case for the AddAuthorizedKey and RemoveAuthorizedKey functions"""

  KEY_A = "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a"
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
           " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b")

  def setUp(self):
    """Writes a temporary key file holding KEY_A and KEY_B, one per line."""
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    handle = open(self.tmpname, "w")
    try:
      handle.writelines(["%s\n" % key for key in (self.KEY_A, self.KEY_B)])
    finally:
      handle.close()

  def _AssertKeyFile(self, keys):
    """Asserts that the key file holds exactly C{keys}, one per line."""
    self.assertFileContent(self.tmpname,
                           "".join(["%s\n" % key for key in keys]))

  def testAddingNewKey(self):
    """Checks that an unknown key is appended to the file."""
    new_key = "ssh-dss AAAAB3NzaC1kc3MAAACB root@test"
    utils.AddAuthorizedKey(self.tmpname, new_key)
    self._AssertKeyFile([self.KEY_A, self.KEY_B, new_key])

  def testAddingAlmostButNotCompletelyTheSameKey(self):
    """Checks that KEY_A's data with another comment is appended."""
    similar_key = "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test"
    utils.AddAuthorizedKey(self.tmpname, similar_key)
    self._AssertKeyFile([self.KEY_A, self.KEY_B, similar_key])

  def testAddingExistingKeyWithSomeMoreSpaces(self):
    """Checks that KEY_A written with extra spaces is not added again."""
    spaced_key = "ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a"
    utils.AddAuthorizedKey(self.tmpname, spaced_key)
    self._AssertKeyFile([self.KEY_A, self.KEY_B])

  def testRemovingExistingKeyWithSomeMoreSpaces(self):
    """Checks that KEY_A is removed even if given with extra spaces."""
    spaced_key = "ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a"
    utils.RemoveAuthorizedKey(self.tmpname, spaced_key)
    self._AssertKeyFile([self.KEY_B])

  def testRemovingNonExistingKey(self):
    """Checks that removing an unknown key leaves the file unchanged."""
    unknown_key = "ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test"
    utils.RemoveAuthorizedKey(self.tmpname, unknown_key)
    self._AssertKeyFile([self.KEY_A, self.KEY_B])
962
963
class TestEtcHosts(testutils.GanetiTestCase):
  """Test functions modifying /etc/hosts"""

  # Mode expected on the file after each operation (rw-r--r--)
  _FILE_MODE = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH

  # Lines written to the file before each test
  _COMMENT = "# This is a test file for /etc/hosts\n"
  _LOCALHOST = "127.0.0.1\tlocalhost\n"
  _ROUTER = "192.0.2.1 router gw\n"

  def setUp(self):
    """Writes the initial three-line hosts file to a temporary file."""
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    handle = open(self.tmpname, "w")
    try:
      handle.writelines([self._COMMENT, self._LOCALHOST, self._ROUTER])
    finally:
      handle.close()

  def _AssertHostsFile(self, lines):
    """Asserts the file consists of C{lines} and has the expected mode."""
    self.assertFileContent(self.tmpname, "".join(lines))
    self.assertFileMode(self.tmpname, self._FILE_MODE)

  def testSettingNewIp(self):
    """Checks that an entry for an unknown IP address is appended."""
    SetEtcHostsEntry(self.tmpname, "198.51.100.4", "myhost.example.com",
                     ["myhost"])

    self._AssertHostsFile([self._COMMENT, self._LOCALHOST, self._ROUTER,
                           "198.51.100.4\tmyhost.example.com myhost\n"])

  def testSettingExistingIp(self):
    """Checks that the entry of an IP address already listed is replaced."""
    SetEtcHostsEntry(self.tmpname, "192.0.2.1", "myhost.example.com",
                     ["myhost"])

    self._AssertHostsFile([self._COMMENT, self._LOCALHOST,
                           "192.0.2.1\tmyhost.example.com myhost\n"])

  def testSettingDuplicateName(self):
    """Checks that an alias equal to the hostname is written only once."""
    SetEtcHostsEntry(self.tmpname, "198.51.100.4", "myhost", ["myhost"])

    self._AssertHostsFile([self._COMMENT, self._LOCALHOST, self._ROUTER,
                           "198.51.100.4\tmyhost\n"])

  def testRemovingExistingHost(self):
    """Checks that removing the first name keeps the remaining alias."""
    RemoveEtcHostsEntry(self.tmpname, "router")

    self._AssertHostsFile([self._COMMENT, self._LOCALHOST,
                           "192.0.2.1 gw\n"])

  def testRemovingSingleExistingHost(self):
    """Checks that a line is dropped once its only name is removed."""
    RemoveEtcHostsEntry(self.tmpname, "localhost")

    self._AssertHostsFile([self._COMMENT, self._ROUTER])

  def testRemovingNonExistingHost(self):
    """Checks that removing an unknown name keeps the content as is."""
    RemoveEtcHostsEntry(self.tmpname, "myhost")

    self._AssertHostsFile([self._COMMENT, self._LOCALHOST, self._ROUTER])

  def testRemovingAlias(self):
    """Checks that removing an alias keeps the first name of the line."""
    RemoveEtcHostsEntry(self.tmpname, "gw")

    self._AssertHostsFile([self._COMMENT, self._LOCALHOST,
                           "192.0.2.1 router\n"])
1043
1044
class TestGetMounts(unittest.TestCase):
  """Test case for GetMounts()."""

  TESTDATA = "".join([
    "rootfs /     rootfs rw 0 0\n",
    "none   /sys  sysfs  rw,nosuid,nodev,noexec,relatime 0 0\n",
    "none   /proc proc   rw,nosuid,nodev,noexec,relatime 0 0\n",
    ])

  def setUp(self):
    """Writes TESTDATA to a named temporary file."""
    self.tmpfile = tempfile.NamedTemporaryFile()
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)

  def testGetMounts(self):
    """Checks that the first four columns of each line are returned."""
    pseudo_opts = "rw,nosuid,nodev,noexec,relatime"
    expected = [
      ("rootfs", "/", "rootfs", "rw"),
      ("none", "/sys", "sysfs", pseudo_opts),
      ("none", "/proc", "proc", pseudo_opts),
      ]

    mounts = utils.GetMounts(filename=self.tmpfile.name)
    self.failUnlessEqual(mounts, expected)
1064
1065
class TestListVisibleFiles(unittest.TestCase):
  """Test case for ListVisibleFiles"""

  def setUp(self):
    """Creates the directory which gets listed by the tests."""
    self.path = tempfile.mkdtemp()

  def tearDown(self):
    """Removes the directory including all files created in it."""
    shutil.rmtree(self.path)

  def _CreateFiles(self, files):
    """Creates one file in the test directory per name in C{files}."""
    for basename in files:
      target = os.path.join(self.path, basename)
      utils.WriteFile(target, data="test")

  def _test(self, files, expected):
    """Creates C{files} and compares the listing, as a set, to C{expected}."""
    self._CreateFiles(files)
    listed = set(ListVisibleFiles(self.path))
    self.failUnlessEqual(listed, set(expected))

  def testAllVisible(self):
    """Checks a directory holding only names without leading dot."""
    names = ["a", "b", "c"]
    self._test(names, names)

  def testNoneVisible(self):
    """Checks a directory holding only names with a leading dot."""
    self._test([".a", ".b", ".c"], [])

  def testSomeVisible(self):
    """Checks a directory holding names with and without leading dot."""
    self._test(["a", "b", ".c"], ["a", "b"])

  def testNonAbsolutePath(self):
    """Checks that a relative path raises errors.ProgrammerError."""
    self.assertRaises(errors.ProgrammerError, ListVisibleFiles, "abc")

  def testNonNormalizedPath(self):
    """Checks that a path containing ".." raises errors.ProgrammerError."""
    self.assertRaises(errors.ProgrammerError, ListVisibleFiles,
                      "/bin/../tmp")
1105
1106
class TestNewUUID(unittest.TestCase):
  """Test case for NewUUID"""

  def runTest(self):
    """Checks that a newly generated UUID is matched by utils.UUID_RE."""
    generated = utils.NewUUID()
    self.assert_(utils.UUID_RE.match(generated))
1112
1113
class TestFirstFree(unittest.TestCase):
  """Test case for the FirstFree function"""

  def test(self):
    """Test FirstFree"""
    # (sequence, keyword arguments, expected result)
    cases = [
      ([0, 1, 3], {}, 2),
      ([], {}, None),
      ([3, 4, 6], {}, 0),
      ([3, 4, 6], {"base": 3}, 5),
      ]

    for (seq, kwargs, expected) in cases:
      self.assertEqual(FirstFree(seq, **kwargs), expected)

    # An element below the base is expected to trip an assertion
    self.assertRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1124
1125
class TestTailFile(testutils.GanetiTestCase):
  """Test case for the TailFile function"""

  def _WriteTempFile(self, chunks):
    """Creates a temporary file and writes all of C{chunks} to it.

    The file object is closed in a C{finally} clause, so a failing write
    can't leave an open handle behind for the remaining tests.

    @param chunks: list of strings, written in the order given
    @return: name of the temporary file

    """
    fname = self._CreateTempFile()
    fd = open(fname, "w")
    try:
      fd.writelines(chunks)
    finally:
      fd.close()
    return fname

  def testEmpty(self):
    """Checks that a fresh temporary file yields an empty list."""
    fname = self._CreateTempFile()
    self.failUnlessEqual(TailFile(fname), [])
    self.failUnlessEqual(TailFile(fname, lines=25), [])

  def testAllLines(self):
    """Requests exactly as many lines as were written to the file."""
    data = ["test %d" % i for i in range(30)]
    for i in range(30):
      # Every line written, the last one included, ends with a newline
      fname = self._WriteTempFile(["%s\n" % line for line in data[:i]])
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])

  def testPartialLines(self):
    """Requests fewer lines than the 30 lines written to the file."""
    data = ["test %d" % i for i in range(30)]
    fname = self._WriteTempFile(["%s\n" % line for line in data])
    for i in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])

  def testBigFile(self):
    """Requests the last lines of a file starting with a 1 MiB line."""
    data = ["test %d" % i for i in range(30)]
    fname = self._WriteTempFile(["X" * 1048576, "\n"] +
                                ["%s\n" % line for line in data])
    for i in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1166
1167
class _BaseFileLockTest:
  """Test case for the FileLock class

  Mix-in with the tests shared by all FileLock test cases. The tests use
  two attributes which must be provided by the class mixing this in:
  C{self.lock}, the lock object under test, and C{self.tmpfile}, whose
  C{name} attribute is the path opened again by the child processes started
  through L{_TryLock}.

  """

  def testSharedNonblocking(self):
    """Takes a shared lock without blocking, then closes the lock."""
    self.lock.Shared(blocking=False)
    self.lock.Close()

  def testExclusiveNonblocking(self):
    """Takes an exclusive lock without blocking, then closes the lock."""
    self.lock.Exclusive(blocking=False)
    self.lock.Close()

  def testUnlockNonblocking(self):
    """Unlocks, non-blocking, without having taken a lock before."""
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSharedBlocking(self):
    """Takes a shared lock in blocking mode, then closes the lock."""
    self.lock.Shared(blocking=True)
    self.lock.Close()

  def testExclusiveBlocking(self):
    """Takes an exclusive lock in blocking mode, then closes the lock."""
    self.lock.Exclusive(blocking=True)
    self.lock.Close()

  def testUnlockBlocking(self):
    """Unlocks, blocking, without having taken a lock before."""
    self.lock.Unlock(blocking=True)
    self.lock.Close()

  def testSharedExclusiveUnlock(self):
    """Goes from shared to exclusive mode, then unlocks."""
    self.lock.Shared(blocking=False)
    self.lock.Exclusive(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testExclusiveSharedUnlock(self):
    """Goes from exclusive to shared mode, then unlocks."""
    self.lock.Exclusive(blocking=False)
    self.lock.Shared(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSimpleTimeout(self):
    """Calls all three locking functions with a timeout given."""
    # These will succeed on the first attempt, hence a short timeout
    self.lock.Shared(blocking=True, timeout=10.0)
    self.lock.Exclusive(blocking=False, timeout=10.0)
    self.lock.Unlock(blocking=True, timeout=10.0)
    self.lock.Close()

  @staticmethod
  def _TryLockInner(filename, shared, blocking):
    """Tries to lock C{filename} using a lock object of its own.

    @param filename: path of the file to lock
    @param shared: whether to request a shared instead of an exclusive lock
    @param blocking: passed on to the locking function
    @return: False if L{errors.LockError} was raised, True otherwise

    """
    lock = utils.FileLock.Open(filename)

    if shared:
      fn = lock.Shared
    else:
      fn = lock.Exclusive

    try:
      # The timeout doesn't really matter as the parent process waits for us to
      # finish anyway.
      fn(blocking=blocking, timeout=0.01)
    except errors.LockError:
      # Only the failure matters here, not the exception object
      return False

    return True

  def _TryLock(self, *args):
    """Runs L{_TryLockInner} on the lock file in a separate process.

    @param args: the C{shared} and C{blocking} arguments of L{_TryLockInner}
    @return: the value returned by L{utils.RunInSeparateProcess}

    """
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
                                      *args)

  def testTimeout(self):
    """Checks which locks another process gets while we hold the lock.

    While the lock is held exclusively, both lock types must fail in the
    other process. While it is held shared, a shared lock must succeed and
    an exclusive lock must fail.

    """
    for blocking in [True, False]:
      self.lock.Exclusive(blocking=True)
      self.failIf(self._TryLock(False, blocking))
      self.failIf(self._TryLock(True, blocking))

      self.lock.Shared(blocking=True)
      self.assert_(self._TryLock(True, blocking))
      self.failIf(self._TryLock(False, blocking))

  def testCloseShared(self):
    """Checks that a shared lock on a closed lock raises AssertionError."""
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)

  def testCloseExclusive(self):
    """Checks that an exclusive lock on a closed lock raises AssertionError."""
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)

  def testCloseUnlock(self):
    """Checks that unlocking a closed lock raises AssertionError."""
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1257
1258
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
  """Runs the FileLock tests on a lock opened through its filename"""

  TESTDATA = 10 * "Hello World\n"

  def _AssertFileUntouched(self):
    """Asserts that the lock file still contains TESTDATA."""
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)

  def setUp(self):
    """Writes TESTDATA to a temporary file and opens a lock on it."""
    testutils.GanetiTestCase.setUp(self)

    self.tmpfile = tempfile.NamedTemporaryFile()
    lockfile = self.tmpfile.name
    utils.WriteFile(lockfile, data=self.TESTDATA)
    self.lock = utils.FileLock.Open(lockfile)

    # Ensure "Open" didn't truncate file
    self._AssertFileUntouched()

  def tearDown(self):
    """Verifies the file content survived the test, then cleans up."""
    self._AssertFileUntouched()

    testutils.GanetiTestCase.tearDown(self)
1276
1277
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
  """Runs the FileLock tests on a lock built from an open file object"""

  def setUp(self):
    """Opens a temporary file for writing and wraps it in a FileLock."""
    self.tmpfile = tempfile.NamedTemporaryFile()
    lockfile = self.tmpfile.name
    handle = open(lockfile, "w")
    self.lock = utils.FileLock(handle, lockfile)
1282
1283
class TestTimeFunctions(unittest.TestCase):
  """Test case for time functions"""

  def runTest(self):
    """Checks SplitTime and MergeTime with valid and invalid values."""
    split_cases = [
      (1, (1, 0)),
      (1.5, (1, 500000)),
      (1218448917.4809151, (1218448917, 480915)),
      (123.48012, (123, 480120)),
      (123.9996, (123, 999600)),
      (123.9995, (123, 999500)),
      (123.9994, (123, 999400)),
      (123.999999999, (123, 999999)),
      ]
    for (value, expected) in split_cases:
      self.failUnlessEqual(utils.SplitTime(value), expected)

    self.failUnlessRaises(AssertionError, utils.SplitTime, -1)

    merge_cases = [
      ((1, 0), 1.0),
      ((1, 500000), 1.5),
      ((1218448917, 500000), 1218448917.5),
      ]
    for (value, expected) in merge_cases:
      self.failUnlessEqual(utils.MergeTime(value), expected)

    # These results are only compared after rounding to three digits
    rounded_cases = [
      ((1218448917, 481000), 1218448917.481),
      ((1, 801000), 1.801),
      ]
    for (value, expected) in rounded_cases:
      self.failUnlessEqual(round(utils.MergeTime(value), 3), expected)

    for value in [(0, -1), (0, 1000000), (0, 9999999), (-1, 0), (-9999, 0)]:
      self.failUnlessRaises(AssertionError, utils.MergeTime, value)
1312
1313
class FieldSetTestCase(unittest.TestCase):
  """Test case for FieldSets"""

  def testSimpleMatch(self):
    """Checks matching against plain field names."""
    fields = utils.FieldSet("a", "b", "c", "def")

    self.assert_(fields.Matches("a"))
    self.assertFalse(fields.Matches("d"), "Substring matched")
    self.assertFalse(fields.Matches("defghi"), "Prefix string matched")

    for names in (["b", "c"], ["a", "b", "c", "def"]):
      self.assertFalse(fields.NonMatching(names))
    self.assert_(fields.NonMatching(["a", "d"]))

  def testRegexMatch(self):
    """Checks matching against a field given as regular expression."""
    fields = utils.FieldSet("a", "b([0-9]+)", "c")

    for name in ("b1", "b99"):
      self.assert_(fields.Matches(name))
    self.assertFalse(fields.Matches("b/1"))

    self.assertFalse(fields.NonMatching(["b12", "c"]))
    self.assert_(fields.NonMatching(["a", "1"]))
1333
class TestForceDictType(unittest.TestCase):
  """Test case for ForceDictType"""

  KEY_TYPES = {
    "a": constants.VTYPE_INT,
    "b": constants.VTYPE_BOOL,
    "c": constants.VTYPE_STRING,
    "d": constants.VTYPE_SIZE,
    "e": constants.VTYPE_MAYBE_STRING,
    }

  def _fdt(self, dict, allowed_values=None):
    """Runs ForceDictType with KEY_TYPES on C{dict} and returns C{dict}.

    C{allowed_values} is only passed on if it isn't None.

    """
    kwargs = {}
    if allowed_values is not None:
      kwargs["allowed_values"] = allowed_values

    utils.ForceDictType(dict, self.KEY_TYPES, **kwargs)

    return dict

  def testSimpleDict(self):
    """Checks the values found in the dictionary after enforcing types."""
    cases = [
      ({}, {}),
      ({"a": 1}, {"a": 1}),
      ({"a": "1"}, {"a": 1}),
      ({"a": 1, "b": 1}, {"a": 1, "b": True}),
      ({"b": 1, "c": "foo"}, {"b": True, "c": "foo"}),
      ({"b": 1, "c": False}, {"b": True, "c": ""}),
      ({"b": "false"}, {"b": False}),
      ({"b": "False"}, {"b": False}),
      ({"b": False}, {"b": False}),
      ({"b": "true"}, {"b": True}),
      ({"b": "True"}, {"b": True}),
      ({"d": "4"}, {"d": 4}),
      ({"d": "4M"}, {"d": 4}),
      ({"e": None}, {"e": None}),
      ({"e": "Hello World"}, {"e": "Hello World"}),
      ({"e": False}, {"e": ""}),
      ]
    for (data, expected) in cases:
      self.failUnlessEqual(self._fdt(data), expected)

    # A value listed in "allowed_values" must be left alone
    self.failUnlessEqual(self._fdt({"b": "hello"}, ["hello"]),
                         {"b": "hello"})

  def testErrors(self):
    """Checks the exceptions raised for unacceptable input."""
    rejected = [
      {"a": "astring"},
      {"b": "hello"},
      {"c": True},
      {"d": "astring"},
      {"d": "4 L"},
      {"e": object()},
      {"e": []},
      {"x": None},
      [],
      ]
    for data in rejected:
      self.failUnlessRaises(errors.TypeEnforcementError, self._fdt, data)

    self.failUnlessRaises(errors.ProgrammerError, utils.ForceDictType,
                          {"b": "hello"}, {"b": "no-such-type"})
1383
1384
class TestIsNormAbsPath(unittest.TestCase):
  """Testing case for IsNormAbsPath"""

  def _pathTestHelper(self, path, result):
    """Asserts that IsNormAbsPath is true or false for C{path}.

    @param path: path to check
    @param result: whether a true value is expected

    """
    verdict = utils.IsNormAbsPath(path)

    if not result:
      self.failIf(verdict,
                  "Path %s should not result absolute and normalized" % path)
      return

    self.failUnless(verdict,
                    "Path %s should result absolute and normalized" % path)

  def testBase(self):
    """Checks a few absolute, relative and non-normalized paths."""
    cases = [
      ("/etc", True),
      ("/srv", True),
      ("etc", False),
      ("/etc/../root", False),
      ("/etc/", False),
      ]
    for (path, expected) in cases:
      self._pathTestHelper(path, expected)
1402
1403
class RunInSeparateProcess(unittest.TestCase):
  """Test case for the RunInSeparateProcess function"""

  def test(self):
    """Checks that the value returned by the callable is passed back."""
    for exp in (True, False):
      self.failUnlessEqual(exp, utils.RunInSeparateProcess(lambda: exp))

  def testArgs(self):
    """Checks that additional arguments are handed to the callable."""
    for arg in (0, 1, 999, "Hello World", (1, 2, 3)):
      matches = lambda carg1, carg2: carg1 == "Foo" and carg2 == arg

      self.failUnless(utils.RunInSeparateProcess(matches, "Foo", arg))

  def testPid(self):
    """Checks that the callable sees a PID other than ours."""
    parent_pid = os.getpid()
    same_pid = lambda: os.getpid() == parent_pid

    self.assertFalse(utils.RunInSeparateProcess(same_pid))

  def testSignal(self):
    """Checks that a callable killing itself raises GenericError."""
    def _terminate():
      os.kill(os.getpid(), signal.SIGTERM)

    self.failUnlessRaises(errors.GenericError,
                          utils.RunInSeparateProcess, _terminate)

  def testException(self):
    """Checks that a callable raising GenericError raises it here too."""
    def _fail():
      raise errors.GenericError("This is a test")

    self.failUnlessRaises(errors.GenericError,
                          utils.RunInSeparateProcess, _fail)
1440
1441
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
  """Test case for generating self-signed X509 certificates"""

  def setUp(self):
    """Creates the directory used for certificate files."""
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    """Removes the directory used for certificate files."""
    shutil.rmtree(self.tmpdir)

  def _checkRsaPrivateKey(self, key):
    """Checks for the PEM begin and end lines of an RSA private key.

    Nothing is asserted in here; callers must check the return value.

    @param key: PEM data as string
    @return: whether both marker lines are among the lines of C{key}

    """
    lines = key.splitlines()
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
            "-----END RSA PRIVATE KEY-----" in lines)

  def _checkCertificate(self, cert):
    """Checks for the PEM begin and end lines of a certificate.

    Nothing is asserted in here; callers must check the return value.

    @param cert: PEM data as string
    @return: whether both marker lines are among the lines of C{cert}

    """
    lines = cert.splitlines()
    return ("-----BEGIN CERTIFICATE-----" in lines and
            "-----END CERTIFICATE-----" in lines)

  def test(self):
    """Generates certificates for several common names and inspects them."""
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
      # The helpers merely return a boolean; without asserting it, these
      # checks could never fail
      self.assert_(self._checkRsaPrivateKey(key_pem))
      self.assert_(self._checkCertificate(cert_pem))

      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
                                           key_pem)
      self.assert_(key.bits() >= 1024)
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)

      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                             cert_pem)
      self.failIf(x509.has_expired())
      self.assertEqual(x509.get_issuer().CN, common_name)
      self.assertEqual(x509.get_subject().CN, common_name)
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)

  def testLegacy(self):
    """Checks the file written by GenerateSelfSignedSslCert.

    The one file is expected to contain both private key and certificate.

    """
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")

    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)

    cert1 = utils.ReadFile(cert1_filename)

    self.assert_(self._checkRsaPrivateKey(cert1))
    self.assert_(self._checkCertificate(cert1))
1487
1488
class TestPathJoin(unittest.TestCase):
  """Testing case for PathJoin"""

  def testBasicItems(self):
    """Checks that plain components are joined using slashes."""
    parts = ("/a", "b", "c")
    self.assertEqual(PathJoin(*parts), "/".join(parts))

  def testNonAbsPrefix(self):
    """Checks that a relative first component raises ValueError."""
    self.assertRaises(ValueError, PathJoin, "a", "b")

  def testBackTrack(self):
    """Checks that a component containing ".." raises ValueError."""
    self.assertRaises(ValueError, PathJoin, "/a", "b/../c")

  def testMultiAbs(self):
    """Checks that a second absolute component raises ValueError."""
    self.assertRaises(ValueError, PathJoin, "/a", "/b")
1504
1505
class TestValidateServiceName(unittest.TestCase):
  """Test case for the ValidateServiceName function"""

  def testValid(self):
    """Checks that accepted names are returned unchanged."""
    integers = [0, 1, 2, 3, 1024, 65000, 65534, 65535]
    names = ["ganeti", "gnt-masterd", "HELLO_WORLD_SVC", "hello.world.1"]
    numeric_strings = ["0", "80", "1111", "65535"]

    for name in integers + names + numeric_strings:
      self.failUnlessEqual(utils.ValidateServiceName(name), name)

  def testInvalid(self):
    """Checks that refused names raise errors.OpPrereqError."""
    integers = [-15756, -1, 65536, 133428083]
    names = ["", "Hello World!", "!", "'", "\"", "\t", "\n", "`"]
    numeric_strings = ["-8546", "-1", "65536"]
    long_name = 129 * "A"

    for name in integers + names + numeric_strings + [long_name]:
      self.failUnlessRaises(errors.OpPrereqError,
                            utils.ValidateServiceName, name)
1530
1531
class TestParseAsn1Generalizedtime(unittest.TestCase):
  """Test case for the _ParseAsn1Generalizedtime function"""

  def test(self):
    """Checks the values returned and the inputs raising ValueError."""
    parse = utils._ParseAsn1Generalizedtime

    accepted = [
      # UTC
      ("19700101000000Z", 0),
      ("20100222174152Z", 1266860512),
      ("20380119031407Z", (2**31) - 1),
      # With offset
      ("20100222174152+0000", 1266860512),
      ("20100223131652+0000", 1266931012),
      ("20100223051808-0800", 1266931088),
      ("20100224002135+1100", 1266931295),
      ("19700101000000-0100", 3600),
      ]
    for (text, expected) in accepted:
      self.failUnlessEqual(parse(text), expected)

    refused = [
      # Leap seconds are not supported by datetime.datetime
      "19841231235960+0000",
      "19920630235960+0000",
      # Errors
      "",
      "invalid",
      "20100222174152",
      "Mon Feb 22 17:47:02 UTC 2010",
      "2010-02-22 17:42:02",
      ]
    for text in refused:
      self.failUnlessRaises(ValueError, parse, text)
1568
1569
class TestGetX509CertValidity(testutils.GanetiTestCase):
  """Test case for the GetX509CertValidity function"""

  def setUp(self):
    """Records whether the pyOpenSSL in use is version 0.7 or above."""
    testutils.GanetiTestCase.setUp(self)

    installed = distutils.version.LooseVersion(OpenSSL.__version__)

    # Test whether we have pyOpenSSL 0.7 or above
    self.pyopenssl0_7 = (installed >= "0.7")

    if self.pyopenssl0_7:
      return

    warnings.warn("This test requires pyOpenSSL 0.7 or above to"
                  " function correctly")

  def _LoadCert(self, name):
    """Loads the test data file C{name} as PEM-encoded certificate."""
    pem = self._ReadTestData(name)
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem)

  def test(self):
    """Checks the validity returned for the "cert1.pem" test certificate."""
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))

    # With older pyOpenSSL versions no validity is expected to be returned
    if self.pyopenssl0_7:
      expected = (1266919967, 1267524767)
    else:
      expected = (None, None)

    self.failUnlessEqual(validity, expected)
1593
1594
class TestSignX509Certificate(unittest.TestCase):
  """Test case for signing and loading signed X509 certificates"""

  KEY = "My private key!"
  KEY_OTHER = "Another key"

  def test(self):
    """Checks refused input, refused salts and the sign/load round trip."""
    # Generate certificate valid for 5 minutes
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)

    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                           cert_pem)

    unloadable = [
      # No signature at all
      cert_pem,
      # Invalid input
      "",
      "X-Ganeti-Signature: \n",
      "X-Ganeti-Sign: $1234$abcdef\n",
      "X-Ganeti-Signature: $1234567890$abcdef\n",
      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem,
      ]
    for data in unloadable:
      self.failUnlessRaises(errors.GenericError,
                            utils.LoadSignedX509Certificate, data, self.KEY)

    # Invalid salt
    for char in "-_@$,:;/\\ \t\n":
      self.failUnlessRaises(errors.GenericError, utils.SignX509Certificate,
                            cert_pem, self.KEY, "foo%sbar" % char)

    usable_salts = [
      "HelloWorld",
      "salt",
      string.letters,
      string.digits,
      utils.GenerateSecret(numbytes=4),
      utils.GenerateSecret(numbytes=16),
      "{123:456}".encode("hex"),
      ]
    for salt in usable_salts:
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)

      # Loading must also work with other data around the signed certificate
      variants = [
        signed_pem,
        "X-Another-Header: with a value\n" + signed_pem,
        (10 * "Hello World!\n") + signed_pem,
        (signed_pem + "\n\na few more\n"
         "lines----\n------ at\nthe end!"),
        ]
      for pem in variants:
        self._Check(cert, salt, pem)

  def _Check(self, cert, salt, pem):
    """Loads C{pem} and compares the result with C{cert} and C{salt}.

    Loading the same data using KEY_OTHER must raise errors.GenericError.

    """
    (loaded_cert, loaded_salt) = \
      utils.LoadSignedX509Certificate(pem, self.KEY)
    self.failUnlessEqual(salt, loaded_salt)
    self.failUnlessEqual(cert.digest("sha1"), loaded_cert.digest("sha1"))

    # Other key
    self.failUnlessRaises(errors.GenericError,
                          utils.LoadSignedX509Certificate,
                          pem, self.KEY_OTHER)
1648
1649
class TestMakedirs(unittest.TestCase):
  """Test case for the Makedirs function"""

  def setUp(self):
    """Creates the directory below which the tests create directories."""
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    """Removes the test directory with everything below it."""
    shutil.rmtree(self.tmpdir)

  def _MakeAndCheck(self, path):
    """Calls Makedirs on C{path} and asserts it is a directory afterwards."""
    utils.Makedirs(path)
    self.failUnless(os.path.isdir(path))

  def testNonExisting(self):
    """Creates a single directory."""
    self._MakeAndCheck(PathJoin(self.tmpdir, "foo"))

  def testExisting(self):
    """Calls Makedirs on a directory created beforehand."""
    path = PathJoin(self.tmpdir, "foo")
    os.mkdir(path)
    self._MakeAndCheck(path)

  def testRecursiveNonExisting(self):
    """Creates a directory three levels below the test directory."""
    self._MakeAndCheck(PathJoin(self.tmpdir, "foo/bar/baz"))

  def testRecursiveExisting(self):
    """Creates nested directories below a parent created beforehand."""
    path = PathJoin(self.tmpdir, "B/moo/xyz")
    self.assertFalse(os.path.exists(path))
    os.mkdir(PathJoin(self.tmpdir, "B"))
    self._MakeAndCheck(path)
1679
1680
class TestReadLockedPidFile(unittest.TestCase):
  """Tests for utils.ReadLockedPidFile using files in a scratch directory"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _WritePidFile(self):
    """Create a file named "pid" containing "123" and return its path"""
    pidfile = PathJoin(self.tmpdir, "pid")
    utils.WriteFile(pidfile, data="123")
    return pidfile

  def testNonExistent(self):
    missing = PathJoin(self.tmpdir, "nonexist")
    self.assertTrue(utils.ReadLockedPidFile(missing) is None)

  def testUnlocked(self):
    pidfile = self._WritePidFile()
    self.assertTrue(utils.ReadLockedPidFile(pidfile) is None)

  def testLocked(self):
    pidfile = self._WritePidFile()

    lock = utils.FileLock.Open(pidfile)
    try:
      lock.Exclusive(blocking=True)

      # While the exclusive lock is held the PID is reported
      self.assertEqual(utils.ReadLockedPidFile(pidfile), 123)
    finally:
      lock.Close()

    # Once the lock has been closed, None is returned again
    self.assertTrue(utils.ReadLockedPidFile(pidfile) is None)

  def testError(self):
    # "foobar" is created as a regular file and then used as a directory
    utils.WriteFile(PathJoin(self.tmpdir, "foobar"), data="")
    pidfile = PathJoin(self.tmpdir, "foobar", "pid")
    # open(2) should return ENOTDIR
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, pidfile)
1716
1717
class TestCertVerification(testutils.GanetiTestCase):
  """Smoke test for utils.VerifyX509Certificate on a test data certificate"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    # setUp chains to the base class, so tearDown has to do the same;
    # the base class cleanup runs even if removing the directory fails
    try:
      shutil.rmtree(self.tmpdir)
    finally:
      testutils.GanetiTestCase.tearDown(self)

  def testVerifyCertificate(self):
    """Run the verification on "cert1.pem"; only checks it doesn't raise"""
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                           cert_pem)

    # Not checking return value as this certificate is expired
    utils.VerifyX509Certificate(cert, 30, 7)
1734
1735
class TestVerifyCertificateInner(unittest.TestCase):
  """Tests for utils._VerifyCertificateInner"""

  def test(self):
    verify = utils._VerifyCertificateInner

    # Valid
    self.assertEqual(verify(False, 1263916313, 1298476313, 1266940313, 30, 7),
                     (None, None))

    # Each entry holds the arguments for the function and the expected error
    # code; the message returned alongside the code is not checked
    cases = [
      # Not yet valid
      ((False, 1266507600, 1267544400, 1266075600, 30, 7), utils.CERT_WARNING),

      # Expiring soon
      ((False, 1266507600, 1267544400, 1266939600, 30, 7), utils.CERT_ERROR),
      ((False, 1266507600, 1267544400, 1266939600, 30, 1), utils.CERT_WARNING),
      ((False, 1266507600, None, 1266939600, 30, 7), None),

      # Expired
      ((True, 1266507600, 1267544400, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, None, 1267544400, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, 1266507600, None, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, None, None, 1266939600, 30, 7), utils.CERT_ERROR),
      ]

    for (args, expected) in cases:
      (errcode, _) = verify(*args)
      self.assertEqual(errcode, expected)
1770
1771
1772 class TestEnsureDirs(unittest.TestCase):
1773   """Tests for EnsureDirs"""
1774
1775   def setUp(self):
1776     self.dir = tempfile.mkdtemp()
1777     self.old_umask = os.umask(0777)
1778
1779   def testEnsureDirs(self):
1780     utils.EnsureDirs([
1781         (PathJoin(self.dir, "foo"), 0777),
1782         (PathJoin(self.dir, "bar"), 0000),
1783         ])
1784     self.assertEquals(os.stat(PathJoin(self.dir, "foo"))[0] & 0777, 0777)
1785     self.assertEquals(os.stat(PathJoin(self.dir, "bar"))[0] & 0777, 0000)
1786
1787   def tearDown(self):
1788     os.rmdir(PathJoin(self.dir, "foo"))
1789     os.rmdir(PathJoin(self.dir, "bar"))
1790     os.rmdir(self.dir)
1791     os.umask(self.old_umask)
1792
1793
class TestFindMatch(unittest.TestCase):
  """Tests for utils.FindMatch with plain string and compiled regex keys"""

  def test(self):
    lookup = {
      "aaaa": "Four A",
      "bb": {"Two B": True},
      re.compile(r"^x(foo|bar|bazX)([0-9]+)$"): (1, 2, 3),
      }

    # Plain string keys are returned with an empty list
    self.assertEqual(utils.FindMatch(lookup, "aaaa"), ("Four A", []))
    self.assertEqual(utils.FindMatch(lookup, "bb"), ({"Two B": True}, []))

    # For the regular expression key the list holds the matched groups
    for word in ["foo", "bar", "bazX"]:
      for number in range(1, 100, 7):
        name = "x%s%s" % (word, number)
        self.assertEqual(utils.FindMatch(lookup, name),
                         ((1, 2, 3), [word, str(number)]))

  def testNoMatch(self):
    # Nothing can be found in an empty dictionary
    for name in ["", "foo", 1234]:
      self.assertTrue(utils.FindMatch({}, name) is None)

    lookup = {
      "X": "Hello World",
      re.compile("^(something)$"): "Hello World",
      }

    # Neither the empty string nor one of the values is a key
    for name in ["", "Hello World"]:
      self.assertTrue(utils.FindMatch(lookup, name) is None)
1822
1823
class TestFileID(testutils.GanetiTestCase):
  """Tests for utils.GetFileID, utils.VerifyFileID and utils.SafeWriteFile"""

  def testEquality(self):
    name = self._CreateTempFile()
    oldi = utils.GetFileID(path=name)
    self.assertTrue(utils.VerifyFileID(oldi, oldi))

  def testUpdate(self):
    name = self._CreateTempFile()
    oldi = utils.GetFileID(path=name)
    os.utime(name, None)
    fd = os.open(name, os.O_RDWR)
    try:
      # The ID is read through the file descriptor this time and has to
      # verify against the one read by path, in both directions
      newi = utils.GetFileID(fd=fd)
      self.assertTrue(utils.VerifyFileID(oldi, newi))
      self.assertTrue(utils.VerifyFileID(newi, oldi))
    finally:
      os.close(fd)

  def testWriteFile(self):
    name = self._CreateTempFile()
    oldi = utils.GetFileID(path=name)
    # The element at index 2 of the file ID is used as the mtime here
    mtime = oldi[2]
    # With the file's timestamps moved forward the old ID is refused
    os.utime(name, (mtime + 10, mtime + 10))
    self.assertRaises(errors.LockError, utils.SafeWriteFile, name,
                      oldi, data="")
    # With the timestamps moved backwards the write goes through
    os.utime(name, (mtime - 10, mtime - 10))
    utils.SafeWriteFile(name, oldi, data="")
    oldi = utils.GetFileID(path=name)
    mtime = oldi[2]
    os.utime(name, (mtime + 10, mtime + 10))
    # this doesn't raise, since we passed None
    utils.SafeWriteFile(name, None, data="")

  def testError(self):
    t = tempfile.NamedTemporaryFile()
    try:
      # Passing a path and a file descriptor at the same time is an error
      self.assertRaises(errors.ProgrammerError, utils.GetFileID,
                        path=t.name, fd=t.fileno())
    finally:
      # Close (and thereby remove) the file right away instead of leaving
      # it to garbage collection
      t.close()
1861
1862
class TimeMock:
  """Callable returning a preset series of values, one per call.

  Used in place of a clock function. Calling the object when all values have
  been consumed raises C{IndexError}.

  """
  def __init__(self, values):
    """Initializes this class.

    @param values: values to return, in order; any iterable is accepted

    """
    # Keep a private copy, the caller's sequence must not be consumed
    self.values = list(values)

  def __call__(self):
    """Removes and returns the next preset value.

    """
    return self.values.pop(0)
1869
1870
class TestRunningTimeout(unittest.TestCase):
  """Tests for utils.RunningTimeout driven by a mocked time function"""

  def setUp(self):
    self.time_fn = TimeMock([0.0, 0.3, 4.6, 6.5])

  def testRemainingFloat(self):
    timeout = utils.RunningTimeout(5.0, True, _time_fn=self.time_fn)
    for expected in [4.7, 0.4, -1.5]:
      self.assertAlmostEqual(timeout.Remaining(), expected)

  def testRemaining(self):
    # Integer values need their own series of mocked timestamps
    self.time_fn = TimeMock([0, 2, 4, 5, 6])
    timeout = utils.RunningTimeout(5, True, _time_fn=self.time_fn)
    for expected in [3, 1, 0, -1]:
      self.assertEqual(timeout.Remaining(), expected)

  def testRemainingNonNegative(self):
    timeout = utils.RunningTimeout(5.0, False, _time_fn=self.time_fn)
    for expected in [4.7, 0.4]:
      self.assertAlmostEqual(timeout.Remaining(), expected)
    # The mocked time is past the timeout here, yet the result is not negative
    self.assertEqual(timeout.Remaining(), 0.0)

  def testNegativeTimeout(self):
    self.assertRaises(ValueError, utils.RunningTimeout, -1.0, True)
1897
1898
class TestTryConvert(unittest.TestCase):
  """Tests for utils.TryConvert"""

  def test(self):
    # (input value, conversion function, expected result)
    cases = [
      ("1", int, 1),
      ("a", int, "a"),
      ("", bool, False),
      ("a", bool, True),
      ]

    for (value, converter, expected) in cases:
      self.assertEqual(utils.TryConvert(converter, value), expected)
1908
1909
class TestIsValidShellParam(unittest.TestCase):
  """Tests for utils.IsValidShellParam"""

  def test(self):
    # (parameter, expected result)
    cases = [
      ("abc", True),
      ("ab;cd", False),
      ]

    for (param, expected) in cases:
      self.assertEqual(utils.IsValidShellParam(param), expected)
1917
1918
class TestBuildShellCmd(unittest.TestCase):
  """Tests for utils.BuildShellCmd"""

  def test(self):
    template = "ls %s"

    # An argument containing ";" is refused
    self.assertRaises(errors.ProgrammerError, utils.BuildShellCmd,
                      template, "ab;cd")

    command = utils.BuildShellCmd(template, "ab")
    self.assertEqual(command, "ls ab")
1924
1925
class TestWriteFile(unittest.TestCase):
  """Tests for utils.WriteFile operating on a named temporary file"""

  def setUp(self):
    self.tfile = tempfile.NamedTemporaryFile()
    # Flags set by the callbacks below
    self.did_pre = False
    self.did_post = False
    self.did_write = False

  def markPre(self, fd):
    """Callback recording that the "prewrite" hook was invoked"""
    self.did_pre = True

  def markPost(self, fd):
    """Callback recording that the "postwrite" hook was invoked"""
    self.did_post = True

  def markWrite(self, fd):
    """Callback recording that the write function was invoked"""
    self.did_write = True

  def testWrite(self):
    content = "abc"
    utils.WriteFile(self.tfile.name, data=content)
    self.assertEqual(utils.ReadFile(self.tfile.name), content)

  def testErrors(self):
    filename = self.tfile.name
    # Both data and a write function
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
                      filename, data="test", fn=lambda fd: None)
    # Neither data nor a write function
    self.assertRaises(errors.ProgrammerError, utils.WriteFile, filename)
    # Access time without modification time
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
                      filename, data="test", atime=0)

  def testCalls(self):
    utils.WriteFile(self.tfile.name, fn=self.markWrite,
                    prewrite=self.markPre, postwrite=self.markPost)
    for called in [self.did_pre, self.did_post, self.did_write]:
      self.assertTrue(called)

  def testDryRun(self):
    original = "abc"
    self.tfile.write(original)
    self.tfile.flush()
    # A dry run must leave the existing content alone
    utils.WriteFile(self.tfile.name, data="hello", dry_run=True)
    self.assertEqual(utils.ReadFile(self.tfile.name), original)

  def testTimes(self):
    filename = self.tfile.name
    # (access time, modification time)
    timestamps = [
      (0, 0),
      (1000, 1000),
      (2000, 3000),
      (int(time.time()), 5000),
      ]
    for (atime, mtime) in timestamps:
      utils.WriteFile(filename, data="hello", atime=atime, mtime=mtime)
      result = os.stat(filename)
      self.assertEqual(result.st_atime, atime)
      self.assertEqual(result.st_mtime, mtime)

  def testNoClose(self):
    content = "hello"
    # Without "close=False" nothing is returned
    self.assertEqual(utils.WriteFile(self.tfile.name, data="abc"), None)
    # With "close=False" the still open file descriptor is returned
    fd = utils.WriteFile(self.tfile.name, data=content, close=False)
    try:
      os.lseek(fd, 0, 0)
      self.assertEqual(os.read(fd, 4096), content)
    finally:
      os.close(fd)
1987
1988
# Run the tests of this module when the file is executed as a script
if __name__ == "__main__":
  testutils.GanetiTestProgram()