Improve unittests for the utils module
[ganeti-local] / test / ganeti.utils_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007, 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the utils module"""
23
24 import distutils.version
25 import errno
26 import fcntl
27 import glob
28 import os
29 import os.path
30 import re
31 import shutil
32 import signal
33 import socket
34 import stat
35 import string
36 import tempfile
37 import time
38 import unittest
39 import warnings
40 import OpenSSL
41 from cStringIO import StringIO
42
43 import testutils
44 from ganeti import constants
45 from ganeti import compat
46 from ganeti import utils
47 from ganeti import errors
48 from ganeti.utils import RunCmd, RemoveFile, MatchNameComponent, FormatUnit, \
49      ParseUnit, ShellQuote, ShellQuoteArgs, ListVisibleFiles, FirstFree, \
50      TailFile, SafeEncode, FormatTime, UnescapeAndSplit, RunParts, PathJoin, \
51      ReadOneLineFile, SetEtcHostsEntry, RemoveEtcHostsEntry
52
53
class TestIsProcessAlive(unittest.TestCase):
  """Testing case for IsProcessAlive"""

  def testExists(self):
    """The test process itself must be reported as alive"""
    mypid = os.getpid()
    self.assert_(utils.IsProcessAlive(mypid), "can't find myself running")

  def testNotExisting(self):
    """A child that has exited and been reaped must not be reported alive"""
    # os.fork() signals failure by raising OSError, never by returning a
    # negative PID, so only the child/parent distinction needs handling
    pid_non_existing = os.fork()
    if pid_non_existing == 0:
      # Child: exit immediately without running any cleanup handlers
      os._exit(0)
    # Parent: reap the child so that its PID is no longer in use by it
    os.waitpid(pid_non_existing, 0)
    self.assertFalse(utils.IsProcessAlive(pid_non_existing),
                     "nonexisting process detected")
70
71
class TestGetProcStatusPath(unittest.TestCase):
  """Tests for the _GetProcStatusPath helper"""

  def test(self):
    """The PID shows up as a path component and different PIDs differ"""
    path = utils._GetProcStatusPath(1234)
    self.assert_("/1234/" in path)

    first = utils._GetProcStatusPath(1)
    second = utils._GetProcStatusPath(2)
    self.assertNotEqual(first, second)
77
78
class TestIsProcessHandlingSignal(unittest.TestCase):
  """Tests for IsProcessHandlingSignal and its status-file parsing helpers"""

  def setUp(self):
    # Scratch directory for the fake status files written by the tests
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testParseSigsetT(self):
    """Hexadecimal signal masks are decoded into sets of signal numbers"""
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
    # The lowest bit stands for signal 1
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
                     set([2, 10, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
                     set([2, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
                     set([2, 28, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
                          24, 25, 26, 28, 31]))
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))

  def testGetProcStatusField(self):
    """A field's value is returned with surrounding whitespace stripped"""
    for field in ["SigCgt", "Name", "FDSize"]:
      for value in ["", "0", "cat", "  1234 KB"]:
        # Put the field under test between two unrelated fields
        pstatus = "\n".join([
          "VmPeak: 999 kB",
          "%s: %s" % (field, value),
          "TracerPid: 0",
          ])
        result = utils._GetProcStatusField(pstatus, field)
        self.assertEqual(result, value.strip())

  def test(self):
    """Signal 10 is reported as handled based on a fake status file"""
    sp = PathJoin(self.tmpdir, "status")

    # The SigCgt mask below is one of the values checked in testParseSigsetT
    # and contains signal 10
    utils.WriteFile(sp, data="\n".join([
      "Name:   bash",
      "State:  S (sleeping)",
      "SleepAVG:       98%",
      "Pid:    22250",
      "PPid:   10858",
      "TracerPid:      0",
      "SigBlk: 0000000000010000",
      "SigIgn: 0000000000384004",
      "SigCgt: 000000004b813efb",
      "CapEff: 0000000000000000",
      ]))

    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))

  def testNoSigCgt(self):
    """A status file without a SigCgt field raises RuntimeError"""
    sp = PathJoin(self.tmpdir, "status")

    utils.WriteFile(sp, data="\n".join([
      "Name:   bash",
      ]))

    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
                      1234, 10, status_path=sp)

  def testNoSuchFile(self):
    """A missing status file means the signal is reported as not handled"""
    sp = PathJoin(self.tmpdir, "notexist")

    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))

  @staticmethod
  def _TestRealProcess():
    """Check SIGUSR1 detection against the current process

    Goes through default, custom handler, ignored and default again. Raises
    an exception on the first wrong answer and returns True otherwise.

    """
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is not handled when it should be")

    # An ignored signal must not count as handled; the message used to claim
    # the opposite of what this check detects
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    return True

  def testRealProcess(self):
    """Run the real-process checks via utils.RunInSeparateProcess"""
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
168
169
class TestPidFileFunctions(unittest.TestCase):
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""

  def setUp(self):
    self.dir = tempfile.mkdtemp()
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
    # Point PID file name generation at the temporary directory. The original
    # function is remembered so that tearDown can restore it and the patch
    # doesn't leak into other test cases.
    self._orig_dpn = utils.DaemonPidFileName
    utils.DaemonPidFileName = self.f_dpn

  def tearDown(self):
    utils.DaemonPidFileName = self._orig_dpn
    for name in os.listdir(self.dir):
      os.unlink(os.path.join(self.dir, name))
    os.rmdir(self.dir)

  def testPidFileFunctions(self):
    """Write, read, lock and remove a PID file for the current process"""
    pid_file = self.f_dpn('test')
    fd = utils.WritePidFile(self.f_dpn('test'))
    self.failUnless(os.path.exists(pid_file),
                    "PID file should have been created")
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, os.getpid())
    self.failUnless(utils.IsProcessAlive(read_pid))
    # While the first descriptor is open, a second write must fail
    self.failUnlessRaises(errors.LockError, utils.WritePidFile,
                          self.f_dpn('test'))
    os.close(fd)
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for missing pid file")
    fh = open(pid_file, "w")
    fh.write("blah\n")
    fh.close()
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for invalid pid file")
    # but now, even with the file existing, we should be able to lock it
    fd = utils.WritePidFile(self.f_dpn('test'))
    os.close(fd)
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")

  def testKill(self):
    """A forked child writes its PID file and is then killed by the parent"""
    pid_file = self.f_dpn('child')
    r_fd, w_fd = os.pipe()
    new_pid = os.fork()
    if new_pid == 0: #child
      try:
        utils.WritePidFile(self.f_dpn('child'))
        os.write(w_fd, 'a')
        signal.pause()
      finally:
        # Never fall back into the test runner from within the child, not
        # even if writing the PID file failed
        os._exit(0)
    # else we are in the parent
    # Drop our copy of the write end so the read below sees end-of-file,
    # instead of blocking forever, should the child die before writing
    os.close(w_fd)
    try:
      # wait until the child has written the pid file
      os.read(r_fd, 1)
    finally:
      os.close(r_fd)
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, new_pid)
    self.failUnless(utils.IsProcessAlive(new_pid))
    utils.KillProcess(new_pid, waitpid=True)
    self.failIf(utils.IsProcessAlive(new_pid))
    utils.RemovePidFile('child')
    self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)
231
232
class TestRunCmd(testutils.GanetiTestCase):
  """Testing case for the RunCmd function"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.magic = time.ctime() + " ganeti test"
    self.fname = self._CreateTempFile()
    # A FIFO without a writer; the timeout tests make a shell read from it
    # so that the command doesn't finish on its own
    self.fifo_tmpdir = tempfile.mkdtemp()
    self.fifo_file = os.path.join(self.fifo_tmpdir, "ganeti_test_fifo")
    os.mkfifo(self.fifo_file)

  def tearDown(self):
    shutil.rmtree(self.fifo_tmpdir)
    # setUp chains to the base class, hence tearDown has to do the same
    # NOTE(review): presumably this is what removes the file created via
    # _CreateTempFile -- confirm against testutils
    testutils.GanetiTestCase.tearDown(self)

  def testOk(self):
    """Test successful exit code"""
    result = RunCmd("/bin/sh -c 'exit 0'")
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.output, "")

  def testFail(self):
    """Test fail exit code"""
    result = RunCmd("/bin/sh -c 'exit 1'")
    self.assertEqual(result.exit_code, 1)
    self.assertEqual(result.output, "")

  def testStdout(self):
    """Test standard output"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.stdout, self.magic)
    # With an output file the text must end up in the file, not the result
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testStderr(self):
    """Test standard error"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
    self.assertEqual(result.stderr, self.magic)
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testCombined(self):
    """Test combined output"""
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
    expected = "A" + self.magic + "B" + self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.output, expected)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, expected)

  def testSignal(self):
    """Test signal"""
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
    self.assertEqual(result.signal, 15)
    self.assertEqual(result.output, "")

  def testTimeoutClean(self):
    """Test a command which exits cleanly upon SIGTERM after the timeout"""
    cmd = "trap 'exit 0' TERM; read < %s" % self.fifo_file
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
    self.assertEqual(result.exit_code, 0)

  def testTimeoutKill(self):
    """Test a command ignoring SIGTERM, which must end with SIGKILL"""
    cmd = ["/bin/sh", "-c", "trap '' TERM; read < %s" % self.fifo_file]
    timeout = 0.2
    # Only the status is of interest here
    (_, _, status, _) = utils._RunCmdPipe(cmd, {}, False, "/", False,
                                          timeout, _linger_timeout=0.2)
    # A negative status carries the number of the terminating signal
    self.assert_(status < 0)
    self.assertEqual(-status, signal.SIGKILL)

  def testTimeoutOutputAfterTerm(self):
    """Test that output written from the SIGTERM trap is still collected"""
    cmd = "trap 'echo sigtermed; exit 1' TERM; read < %s" % self.fifo_file
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
    self.assert_(result.failed)
    self.assertEqual(result.stdout, "sigtermed\n")

  def testListRun(self):
    """Test list runs"""
    result = RunCmd(["true"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 1)
    result = RunCmd(["echo", "-n", self.magic])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.stdout, self.magic)

  def testFileEmptyOutput(self):
    """Test file output"""
    result = RunCmd(["true"], output=self.fname)
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertFileContent(self.fname, "")

  def testLang(self):
    """Test locale environment"""
    old_env = os.environ.copy()
    try:
      os.environ["LANG"] = "en_US.UTF-8"
      os.environ["LC_ALL"] = "en_US.UTF-8"
      result = RunCmd(["locale"])
      for line in result.output.splitlines():
        key, value = line.split("=", 1)
        # Ignore these variables, they're overridden by LC_ALL
        if key == "LANG" or key == "LANGUAGE":
          continue
        self.failIf(value and value != "C" and value != '"C"',
            "Variable %s is set to the invalid value '%s'" % (key, value))
    finally:
      # Restore the environment in place. Rebinding os.environ to the copy
      # would replace the special mapping by a plain dictionary, after which
      # changes would no longer reach the process environment.
      for key in list(os.environ.keys()):
        if key not in old_env:
          del os.environ[key]
      os.environ.update(old_env)

  def testDefaultCwd(self):
    """Test default working directory"""
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")

  def testCwd(self):
    """Test explicitly specified working directory"""
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
    cwd = os.getcwd()
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)

  def testResetEnv(self):
    """Test environment reset functionality"""
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")

  def testNoFork(self):
    """Test that nofork raise an error"""
    assert not utils.no_fork
    utils.no_fork = True
    try:
      self.assertRaises(errors.ProgrammerError, RunCmd, ["true"])
    finally:
      # Always re-enable forking for the remaining tests
      utils.no_fork = False

  def testWrongParams(self):
    """Test wrong parameters"""
    # An output file and interactive mode can't be combined
    self.assertRaises(errors.ProgrammerError, RunCmd, ["true"],
                      output="/dev/null", interactive=True)
379
380
class TestRunParts(testutils.GanetiTestCase):
  """Test cases for utils.RunParts"""

  def setUp(self):
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")

  def tearDown(self):
    shutil.rmtree(self.rundir)

  def testEmpty(self):
    """An empty directory gives an empty result list"""
    results = RunParts(self.rundir, reset_env=True)
    self.failUnlessEqual(results, [])

  def testSkipWrongName(self):
    """A file named "00test.dot" is skipped despite being executable"""
    path = os.path.join(self.rundir, "00test.dot")
    utils.WriteFile(path, data="")
    os.chmod(path, stat.S_IREAD | stat.S_IEXEC)
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), expected)

  def testSkipNonExec(self):
    """A file which was not made executable is skipped"""
    path = os.path.join(self.rundir, "00test")
    utils.WriteFile(path, data="")
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), expected)

  def testError(self):
    """An empty executable file is reported as an error"""
    path = os.path.join(self.rundir, "00test")
    utils.WriteFile(path, data="")
    os.chmod(path, stat.S_IREAD | stat.S_IEXEC)
    first = RunParts(self.rundir, reset_env=True)[0]
    (name, state, message) = first
    self.failUnlessEqual(name, os.path.basename(path))
    self.failUnlessEqual(state, constants.RUNPARTS_ERR)
    self.failUnless(message)

  def testSorted(self):
    """Results are returned sorted by file name"""
    paths = [os.path.join(self.rundir, name)
             for name in ["64test", "00test", "42test"]]

    for path in paths:
      utils.WriteFile(path, data="")

    results = RunParts(self.rundir, reset_env=True)

    expected = [os.path.basename(path) for path in sorted(paths)]
    for (idx, name) in enumerate(expected):
      self.failUnlessEqual(name, results[idx][0])

  def testOk(self):
    """A working script is run and its output captured"""
    path = os.path.join(self.rundir, "00test")
    utils.WriteFile(path, data="#!/bin/sh\n\necho -n ciao")
    os.chmod(path, stat.S_IREAD | stat.S_IEXEC)
    first = RunParts(self.rundir, reset_env=True)[0]
    (name, state, outcome) = first
    self.failUnlessEqual(name, os.path.basename(path))
    self.failUnlessEqual(state, constants.RUNPARTS_RUN)
    self.failUnlessEqual(outcome.stdout, "ciao")

  def testRunFail(self):
    """A script exiting with code 1 counts as run, with a failed result"""
    path = os.path.join(self.rundir, "00test")
    utils.WriteFile(path, data="#!/bin/sh\n\nexit 1")
    os.chmod(path, stat.S_IREAD | stat.S_IEXEC)
    first = RunParts(self.rundir, reset_env=True)[0]
    (name, state, outcome) = first
    self.failUnlessEqual(name, os.path.basename(path))
    self.failUnlessEqual(state, constants.RUNPARTS_RUN)
    self.failUnlessEqual(outcome.exit_code, 1)
    self.failUnless(outcome.failed)

  def testRunMix(self):
    """Failing, skipped, broken and working files in a single directory"""
    paths = [os.path.join(self.rundir, name)
             for name in ["00test", "42test", "64test", "99test"]]
    paths.sort()
    (failing, skipped, broken, working) = paths

    # 1st has errors in execution
    utils.WriteFile(failing, data="#!/bin/sh\n\nexit 1")
    os.chmod(failing, stat.S_IREAD | stat.S_IEXEC)

    # 2nd is skipped
    utils.WriteFile(skipped, data="")

    # 3rd cannot execute properly
    utils.WriteFile(broken, data="")
    os.chmod(broken, stat.S_IREAD | stat.S_IEXEC)

    # 4th execs
    utils.WriteFile(working, data="#!/bin/sh\n\necho -n ciao")
    os.chmod(working, stat.S_IREAD | stat.S_IEXEC)

    results = RunParts(self.rundir, reset_env=True)

    (name, state, outcome) = results[0]
    self.failUnlessEqual(name, os.path.basename(failing))
    self.failUnlessEqual(state, constants.RUNPARTS_RUN)
    self.failUnlessEqual(outcome.exit_code, 1)
    self.failUnless(outcome.failed)

    (name, state, outcome) = results[1]
    self.failUnlessEqual(name, os.path.basename(skipped))
    self.failUnlessEqual(state, constants.RUNPARTS_SKIP)
    self.failUnlessEqual(outcome, None)

    (name, state, outcome) = results[2]
    self.failUnlessEqual(name, os.path.basename(broken))
    self.failUnlessEqual(state, constants.RUNPARTS_ERR)
    self.failUnless(outcome)

    (name, state, outcome) = results[3]
    self.failUnlessEqual(name, os.path.basename(working))
    self.failUnlessEqual(state, constants.RUNPARTS_RUN)
    self.failUnlessEqual(outcome.output, "ciao")
    self.failUnlessEqual(outcome.exit_code, 0)
    self.failIf(outcome.failed)

  def testMissingDirectory(self):
    """A directory which doesn't exist gives an empty result list"""
    missing = utils.PathJoin(self.rundir, "no/such/directory")
    self.assertEqual(RunParts(missing), [])
509
510
class TestStartDaemon(testutils.GanetiTestCase):
  """Test cases for utils.StartDaemon"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
    self.tmpfile = os.path.join(self.tmpdir, "test")

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testShell(self):
    """Shell command redirecting its own output"""
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testShellOutput(self):
    """Shell command with an output file"""
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testNoShellNoOutput(self):
    """Argument list without an output file; must simply not raise"""
    utils.StartDaemon(["pwd"])

  def testNoShellNoOutputTouch(self):
    """Argument list whose only effect is creating an empty file"""
    testfile = os.path.join(self.tmpdir, "check")
    self.failIf(os.path.exists(testfile))
    utils.StartDaemon(["touch", testfile])
    self._wait(testfile, 60.0, "")

  def testNoShellOutput(self):
    """Argument list with an output file; "pwd" is expected to print "/" """
    utils.StartDaemon(["pwd"], output=self.tmpfile)
    self._wait(self.tmpfile, 60.0, "/")

  def testNoShellOutputCwd(self):
    """Argument list with an output file and a working directory"""
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
    self._wait(self.tmpfile, 60.0, os.getcwd())

  def testShellEnv(self):
    """Shell command reading a variable passed via env"""
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
                      env={ "GNT_TEST_VAR": "Hello World", })
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testNoShellEnv(self):
    """Argument list reading a variable passed via env"""
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
                      env={ "GNT_TEST_VAR": "Hello World", })
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testOutputFd(self):
    """Output sent to an already opened file descriptor"""
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
    try:
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
    finally:
      os.close(fd)
    self._wait(self.tmpfile, 60.0, os.getcwd())

  def testPid(self):
    """The returned PID equals the shell's own $$"""
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
    self._wait(self.tmpfile, 60.0, str(pid))

  def testPidFile(self):
    """The PID file is locked and contains the daemon's PID"""
    pidfile = os.path.join(self.tmpdir, "pid")

    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
                            output=self.tmpfile)
    try:
      fd = os.open(pidfile, os.O_RDONLY)
      try:
        # Check file is locked
        self.assertRaises(errors.LockError, utils.LockFile, fd)

        pidtext = os.read(fd, 100)
      finally:
        os.close(fd)

      self.assertEqual(int(pidtext.strip()), pid)

      self.assert_(utils.IsProcessAlive(pid))
    finally:
      # No matter what happens, kill daemon
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
      self.failIf(utils.IsProcessAlive(pid))

    # The loop doesn't print anything
    self.assertEqual(utils.ReadFile(self.tmpfile), "")

  def _wait(self, path, timeout, expected):
    """Poll until C{path} exists and its stripped content equals C{expected}

    Fails the test if this hasn't happened after C{timeout} seconds.

    """
    # Due to the asynchronous nature of daemon processes, polling is necessary.
    # A timeout makes sure the test doesn't hang forever.
    def _CheckFile():
      if not (os.path.isfile(path) and
              utils.ReadFile(path).strip() == expected):
        raise utils.RetryAgain()

    try:
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
    except utils.RetryTimeout:
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
                " didn't write the correct output" % timeout)

  def testError(self):
    """Invalid programs, paths and parameter combinations are rejected"""
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"])
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"],
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"],
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))

    # An output file and an output descriptor can't be combined
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
    try:
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
                        ["./does-NOT-EXIST/here/0123456789"],
                        output=self.tmpfile, output_fd=fd)
    finally:
      os.close(fd)
626
627
class TestSetCloseOnExecFlag(unittest.TestCase):
  """Tests for SetCloseOnExecFlag"""

  def setUp(self):
    self.tmpfile = tempfile.TemporaryFile()

  def tearDown(self):
    # Release the descriptor right away instead of waiting for the file
    # object to be garbage-collected
    self.tmpfile.close()

  def testEnable(self):
    """FD_CLOEXEC is set on the descriptor after enabling"""
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
                    fcntl.FD_CLOEXEC)

  def testDisable(self):
    """FD_CLOEXEC is not set on the descriptor after disabling"""
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
                fcntl.FD_CLOEXEC)
643
644
class TestSetNonblockFlag(unittest.TestCase):
  """Tests for SetNonblockFlag"""

  def setUp(self):
    self.tmpfile = tempfile.TemporaryFile()

  def tearDown(self):
    # Release the descriptor right away instead of waiting for the file
    # object to be garbage-collected
    self.tmpfile.close()

  def testEnable(self):
    """O_NONBLOCK is set in the file status flags after enabling"""
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
                    os.O_NONBLOCK)

  def testDisable(self):
    """O_NONBLOCK is not set in the file status flags after disabling"""
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
                os.O_NONBLOCK)
658
659
class TestRemoveFile(unittest.TestCase):
  """Test case for the RemoveFile function"""

  def setUp(self):
    """Create a temp dir and file for each case"""
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
    os.close(fd)

  def tearDown(self):
    # Remove the whole tree, so that something left behind by a failed test
    # (e.g. a symlink) doesn't turn the cleanup into a second error
    shutil.rmtree(self.tmpdir)

  def testIgnoreDirs(self):
    """Test that RemoveFile() ignores directories"""
    self.assertEqual(None, RemoveFile(self.tmpdir))

  def testIgnoreNotExisting(self):
    """Test that RemoveFile() ignores non-existing files"""
    RemoveFile(self.tmpfile)
    # The file is gone by now; the second call must not raise
    RemoveFile(self.tmpfile)

  def testRemoveFile(self):
    """Test that RemoveFile does remove a file"""
    RemoveFile(self.tmpfile)
    if os.path.exists(self.tmpfile):
      self.fail("File '%s' not removed" % self.tmpfile)

  def testRemoveSymlink(self):
    """Test that RemoveFile does remove symlinks"""
    symlink = os.path.join(self.tmpdir, "symlink")
    # First a dangling link, then a link to an existing file
    for target in ["no-such-file", self.tmpfile]:
      os.symlink(target, symlink)
      RemoveFile(symlink)
      # os.path.exists() follows symlinks and hence always returns False for
      # a dangling one; lexists() looks at the link itself
      if os.path.lexists(symlink):
        self.fail("File '%s' not removed" % symlink)
700
701
class TestRemoveDir(unittest.TestCase):
  """Test cases for utils.RemoveDir"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    # testEmptyDir has already removed the directory, so errors are ignored
    try:
      shutil.rmtree(self.tmpdir)
    except EnvironmentError:
      pass

  def testEmptyDir(self):
    """An empty directory is removed"""
    utils.RemoveDir(self.tmpdir)
    still_there = os.path.isdir(self.tmpdir)
    self.assertFalse(still_there)

  def testNonEmptyDir(self):
    """Removing a directory containing a file raises EnvironmentError"""
    self.tmpfile = os.path.join(self.tmpdir, "test1")
    # Touch the file
    fh = open(self.tmpfile, "w")
    fh.close()
    self.assertRaises(EnvironmentError, utils.RemoveDir, self.tmpdir)
720
721
class TestRename(unittest.TestCase):
  """Test case for RenameFile"""

  def setUp(self):
    """Create a temporary directory containing an empty file"""
    self.tmpdir = tempfile.mkdtemp()
    self.tmpfile = os.path.join(self.tmpdir, "test1")

    # Touch the file
    fh = open(self.tmpfile, "w")
    fh.close()

  def tearDown(self):
    """Remove temporary directory"""
    shutil.rmtree(self.tmpdir)

  def testSimpleRename1(self):
    """Rename within the same directory"""
    target = os.path.join(self.tmpdir, "xyz")
    utils.RenameFile(self.tmpfile, target)
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))

  def testSimpleRename2(self):
    """Rename within the same directory, passing mkdir=True"""
    target = os.path.join(self.tmpdir, "xyz")
    utils.RenameFile(self.tmpfile, target, mkdir=True)
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))

  def testRenameMkdir(self):
    """Rename into directories which don't exist yet"""
    subdir = os.path.join(self.tmpdir, "test")
    first = os.path.join(self.tmpdir, "test/xyz")
    second = os.path.join(self.tmpdir, "test/foo/bar/baz")

    # One missing directory level
    utils.RenameFile(self.tmpfile, first, mkdir=True)
    self.assert_(os.path.isdir(subdir))
    self.assert_(os.path.isfile(first))

    # Several missing directory levels
    utils.RenameFile(first, second, mkdir=True)
    self.assert_(os.path.isdir(subdir))
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
    self.assert_(os.path.isfile(second))
761
762
class TestMatchNameComponent(unittest.TestCase):
  """Test case for the MatchNameComponent function"""

  def testEmptyList(self):
    """Nothing can match against an empty list"""
    for key in ["", "test"]:
      self.failUnlessEqual(MatchNameComponent(key, []), None)

  def testSingleMatch(self):
    """A key matching exactly one name returns that name"""
    names = ["test1.example.com", "test2.example.com", "test3.example.com"]
    wanted = names[1]
    for key in ["test2", "test2.example", "test2.example.com"]:
      self.failUnlessEqual(MatchNameComponent(key, names), wanted)

  def testMultipleMatches(self):
    """A key matching several names returns None"""
    names = ["test1.example.com", "test1.example.org", "test1.example.net"]
    for key in ["test1", "test1.example"]:
      self.failUnlessEqual(MatchNameComponent(key, names), None)

  def testFullMatch(self):
    """A full match wins even if the key is also a prefix of another name"""
    short_key = "test1"
    full_key = "test1.example"
    names = [full_key, full_key + ".com"]
    self.failUnlessEqual(MatchNameComponent(short_key, names), None)
    self.failUnlessEqual(MatchNameComponent(full_key, names), full_key)

  def testCaseInsensitivePartialMatch(self):
    """With case_sensitive=False the key's case doesn't matter"""
    names = ["test1.example.com", "test2.example.net"]
    for key in ["test2", "Test2", "teSt2", "TeSt2"]:
      found = MatchNameComponent(key, names, case_sensitive=False)
      self.assertEqual(found, "test2.example.net")

  def testCaseInsensitiveFullMatch(self):
    """Full matches in case-insensitive mode"""
    names = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
    checks = [
      # Of the two ts1 names only one is a full-string match once case is
      # ignored; the bare component remains ambiguous
      ("Ts1", None),
      ("Ts1.ex", "ts1.ex"),
      ("ts1.ex", "ts1.ex"),
      # Between the two ts2 only case differs, so only case-match works
      ("ts2.ex", "ts2.ex"),
      ("Ts2.ex", "Ts2.ex"),
      ("TS2.ex", None),
      ]
    for (key, wanted) in checks:
      found = MatchNameComponent(key, names, case_sensitive=False)
      self.assertEqual(found, wanted)
821
822
class TestReadFile(testutils.GanetiTestCase):
  """Test case for utils.ReadFile"""

  @staticmethod
  def _HexDigest(data):
    """Returns the hex digest of C{data} computed with compat.md5_hash.

    """
    digest = compat.md5_hash()
    digest.update(data)
    return digest.hexdigest()

  def testReadAll(self):
    """Reads the test certificate without a size limit"""
    contents = utils.ReadFile(self._TestDataFilename("cert1.pem"))
    self.assertEqual(len(contents), 814)
    self.assertEqual(self._HexDigest(contents),
                     "a491efb3efe56a0535f924d5f8680fd4")

  def testReadSize(self):
    """Reads the test certificate with size=100"""
    certfile = self._TestDataFilename("cert1.pem")
    contents = utils.ReadFile(certfile, size=100)
    self.assertEqual(len(contents), 100)
    self.assertEqual(self._HexDigest(contents),
                     "893772354e4e690b9efd073eed433ce7")

  def testError(self):
    """Reading an unusable path must raise EnvironmentError"""
    bad_path = "/dev/null/does-not-exist"
    self.assertRaises(EnvironmentError, utils.ReadFile, bad_path)
845
846
class TestReadOneLineFile(testutils.GanetiTestCase):
  """Test case for the ReadOneLineFile function"""

  def testDefault(self):
    """Reads the test certificate without passing "strict" """
    line = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
    self.assertEqual(len(line), 27)
    self.assertEqual(line, "-----BEGIN CERTIFICATE-----")

  def testNotStrict(self):
    """Reads the test certificate with strict=False"""
    certfile = self._TestDataFilename("cert1.pem")
    line = ReadOneLineFile(certfile, strict=False)
    self.assertEqual(len(line), 27)
    self.assertEqual(line, "-----BEGIN CERTIFICATE-----")

  def testStrictFailure(self):
    """strict=True on the test certificate must raise GenericError"""
    certfile = self._TestDataFilename("cert1.pem")
    self.assertRaises(errors.GenericError, ReadOneLineFile,
                      certfile, strict=True)

  def testLongLine(self):
    """A long text without line break is returned unchanged"""
    content = 1024 * "Hello World! "
    path = self._CreateTempFile()
    utils.WriteFile(path, data=content)
    strict_result = ReadOneLineFile(path, strict=True)
    lax_result = ReadOneLineFile(path, strict=False)
    self.assertEqual(content, strict_result)
    self.assertEqual(content, lax_result)

  def testNewline(self):
    """The line terminator, if any, is not part of the result"""
    path = self._CreateTempFile()
    expected = "myline"
    for terminator in ["", "\n", "\r\n"]:
      utils.WriteFile(path, data=expected + terminator)
      lax_result = ReadOneLineFile(path, strict=False)
      self.assertEqual(expected, lax_result)
      strict_result = ReadOneLineFile(path, strict=True)
      self.assertEqual(expected, strict_result)

  def testWhitespaceAndMultipleLines(self):
    """Checks files repeating a line that ends in whitespace"""
    path = self._CreateTempFile()
    for terminator in ["", "\n", "\r\n"]:
      for blanks in [" ", "\t", "\t\t  \t", "\t "]:
        content = 1024 * ("Foo bar baz " + blanks + terminator)
        utils.WriteFile(path, data=content)
        lax_result = ReadOneLineFile(path, strict=False)

        if not terminator:
          # No line breaks at all, both modes return the complete content
          strict_result = ReadOneLineFile(path, strict=True)
          self.assertEqual(content, strict_result)
          self.assertEqual(content, lax_result)
          continue

        # With line breaks strict mode fails and lax mode returns the text
        # in front of the first line break
        self.assertTrue(set("\r\n") & set(content))
        self.assertRaises(errors.GenericError, ReadOneLineFile,
                          path, strict=True)
        prefix_len = len("Foo bar baz ") + len(blanks)
        self.assertEqual(len(lax_result), prefix_len)
        self.assertEqual(lax_result, content[:prefix_len])
        self.assertFalse(set("\r\n") & set(lax_result))

  def testEmptylines(self):
    """Checks files starting with two empty lines"""
    path = self._CreateTempFile()
    expected = "myline"
    for terminator in ["\n", "\r\n"]:
      for trailer in ["", "otherline"]:
        content = "".join([terminator, terminator, expected, terminator,
                           trailer, terminator])
        utils.WriteFile(path, data=content)
        self.assertTrue(set("\r\n") & set(content))
        lax_result = ReadOneLineFile(path, strict=False)
        self.assertEqual(expected, lax_result)
        if not trailer:
          strict_result = ReadOneLineFile(path, strict=True)
          self.assertEqual(expected, strict_result)
        else:
          # A second non-empty line is not accepted in strict mode
          self.assertRaises(errors.GenericError, ReadOneLineFile,
                            path, strict=True)

  def testEmptyfile(self):
    """A freshly created temporary file must raise GenericError"""
    path = self._CreateTempFile()
    self.assertRaises(errors.GenericError, ReadOneLineFile, path)
926
927
class TestTimestampForFilename(unittest.TestCase):
  """Test case for utils.TimestampForFilename"""

  def test(self):
    """The generated timestamp contains neither "." nor ":" """
    for forbidden in (".", ":"):
      self.assertTrue(forbidden not in utils.TimestampForFilename())
932
933
class TestCreateBackup(testutils.GanetiTestCase):
  """Test case for utils.CreateBackup"""

  def setUp(self):
    """Creates the temporary directory used by the tests"""
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    """Removes the temporary directory including its content"""
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def testEmpty(self):
    """Backs up an empty file several times, then tries a FIFO"""
    filename = PathJoin(self.tmpdir, "config.data")
    pattern = "%s*" % filename
    utils.WriteFile(filename, data="")

    first_backup = utils.CreateBackup(filename)
    self.assertFileContent(first_backup, "")
    self.assertEqual(len(glob.glob(pattern)), 2)

    # Each further call must add exactly one more file matching the pattern
    for expected_count in (3, 4):
      utils.CreateBackup(filename)
      self.assertEqual(len(glob.glob(pattern)), expected_count)

    fifoname = PathJoin(self.tmpdir, "fifo")
    os.mkfifo(fifoname)
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)

  def testContent(self):
    """Checks that backups carry the content of the original file"""
    filename = PathJoin(self.tmpdir, "test.data_")
    pattern = "%s*" % filename
    backups = 0

    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
      for rep in [1, 2, 10, 127]:
        testdata = data * rep

        utils.WriteFile(filename, data=testdata)
        self.assertFileContent(filename, testdata)

        for _ in range(3):
          bname = utils.CreateBackup(filename)
          backups += 1
          self.assertFileContent(bname, testdata)
          # The original file plus all backups created so far
          self.assertEqual(len(glob.glob(pattern)), 1 + backups)
975
976
class TestFormatUnit(unittest.TestCase):
  """Test case for the FormatUnit function"""

  def _CheckFormat(self, units, cases):
    """Asserts the FormatUnit result for a list of values.

    @param units: units argument passed to FormatUnit
    @param cases: list of (value, expected string) tuples

    """
    for (value, expected) in cases:
      self.assertEqual(FormatUnit(value, units), expected)

  def testMiB(self):
    self._CheckFormat('h', [
      (1, '1M'),
      (100, '100M'),
      (1023, '1023M'),
      ])

    self._CheckFormat('m', [
      (1, '1'),
      (100, '100'),
      (1023, '1023'),
      (1024, '1024'),
      (1536, '1536'),
      (17133, '17133'),
      (1024 * 1024 - 1, '1048575'),
      ])

  def testGiB(self):
    self._CheckFormat('h', [
      (1024, '1.0G'),
      (1536, '1.5G'),
      (17133, '16.7G'),
      (1024 * 1024 - 1, '1024.0G'),
      ])

    self._CheckFormat('g', [
      (1024, '1.0'),
      (1536, '1.5'),
      (17133, '16.7'),
      (1024 * 1024 - 1, '1024.0'),
      (1024 * 1024, '1024.0'),
      (5120 * 1024, '5120.0'),
      (29829 * 1024, '29829.0'),
      ])

  def testTiB(self):
    self._CheckFormat('h', [
      (1024 * 1024, '1.0T'),
      (5120 * 1024, '5.0T'),
      (29829 * 1024, '29.1T'),
      ])

    self._CheckFormat('t', [
      (1024 * 1024, '1.0'),
      (5120 * 1024, '5.0'),
      (29829 * 1024, '29.1'),
      ])

  def testErrors(self):
    """A units value of "a" must raise ProgrammerError"""
    self.assertRaises(errors.ProgrammerError, FormatUnit, 1, "a")
1020
1021
class TestParseUnit(unittest.TestCase):
  """Test case for the ParseUnit function"""

  # (suffix, multiplier) pairs used by the suffix tests
  SCALES = (
    ('', 1),
    ('M', 1), ('G', 1024), ('T', 1024 * 1024),
    ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
    ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024),
    )

  def _CheckParse(self, cases):
    """Asserts the ParseUnit result for (text, expected value) tuples.

    """
    for (text, expected) in cases:
      self.assertEqual(ParseUnit(text), expected)

  def testRounding(self):
    self._CheckParse([
      ('0', 0),
      ('1', 4),
      ('2', 4),
      ('3', 4),
      ])

    self._CheckParse([
      ('124', 124),
      ('125', 128),
      ('126', 128),
      ('127', 128),
      ('128', 128),
      ('129', 132),
      ('130', 132),
      ])

  def testFloating(self):
    self._CheckParse([
      ('0', 0),
      ('0.5', 4),
      ('1.75', 4),
      ('1.99', 4),
      ('2.00', 4),
      ('2.01', 4),
      ('3.99', 4),
      ('4.00', 4),
      ('4.01', 8),
      ('1.5G', 1536),
      ('1.8G', 1844),
      ('8.28T', 8682212),
      ])

  def testSuffixes(self):
    """All suffixes are accepted in any case and after whitespace"""
    transforms = (lambda x: x, str.lower, str.upper)
    for sep in ('', ' ', '   ', "\t", "\t "):
      for (suffix, scale) in self.SCALES:
        for transform in transforms:
          text = '1024' + sep + transform(suffix)
          self.assertEqual(ParseUnit(text), 1024 * scale)

  def testInvalidInput(self):
    """Unsupported separators must raise UnitParseError"""
    suffixes = [suffix for (suffix, _) in self.SCALES]

    for sep in ('-', '_', ',', 'a'):
      for suffix in suffixes:
        self.assertRaises(errors.UnitParseError, ParseUnit, '1' + sep + suffix)

    for suffix in suffixes:
      self.assertRaises(errors.UnitParseError, ParseUnit, '1,3' + suffix)
1072
1073
class TestParseCpuMask(unittest.TestCase):
  """Test case for the ParseCpuMask function."""

  def testWellFormed(self):
    """Valid masks are turned into lists of integers"""
    cases = [
      ("", []),
      ("1", [1]),
      ("0-2,4,5-5", [0, 1, 2, 4, 5]),
      ]
    for (mask, expected) in cases:
      self.assertEqual(utils.ParseCpuMask(mask), expected)

  def testInvalidInput(self):
    """Malformed masks must raise ParseError"""
    for mask in ("garbage", "0,", "0-1-2", "2-1", "1-a"):
      self.assertRaises(errors.ParseError, utils.ParseCpuMask, mask)
1085
1086
class TestSshKeys(testutils.GanetiTestCase):
  """Test case for the AddAuthorizedKey and RemoveAuthorizedKey functions"""

  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')

  # File lines expected for the two keys written by setUp
  _LINE_A = "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
  _LINE_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
             " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")

  def setUp(self):
    """Creates a temporary key file holding KEY_A and KEY_B"""
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    handle = open(self.tmpname, 'w')
    try:
      handle.write("%s\n%s\n" % (TestSshKeys.KEY_A, TestSshKeys.KEY_B))
    finally:
      handle.close()

  def testAddingNewKey(self):
    new_key = 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test'
    utils.AddAuthorizedKey(self.tmpname, new_key)

    expected = (self._LINE_A + self._LINE_B +
                "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
    self.assertFileContent(self.tmpname, expected)

  def testAddingAlmostButNotCompletelyTheSameKey(self):
    # Same key data as KEY_A, but the trailing field differs
    similar_key = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test'
    utils.AddAuthorizedKey(self.tmpname, similar_key)

    expected = (self._LINE_A + self._LINE_B +
                "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
    self.assertFileContent(self.tmpname, expected)

  def testAddingExistingKeyWithSomeMoreSpaces(self):
    # KEY_A with additional spaces between its fields
    spaced_key = 'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a'
    utils.AddAuthorizedKey(self.tmpname, spaced_key)

    # The file must be unchanged
    self.assertFileContent(self.tmpname, self._LINE_A + self._LINE_B)

  def testRemovingExistingKeyWithSomeMoreSpaces(self):
    # KEY_A with additional spaces between its fields
    spaced_key = 'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a'
    utils.RemoveAuthorizedKey(self.tmpname, spaced_key)

    self.assertFileContent(self.tmpname, self._LINE_B)

  def testRemovingNonExistingKey(self):
    unknown_key = 'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test'
    utils.RemoveAuthorizedKey(self.tmpname, unknown_key)

    # The file must be unchanged
    self.assertFileContent(self.tmpname, self._LINE_A + self._LINE_B)
1149
1150
1151 class TestEtcHosts(testutils.GanetiTestCase):
1152   """Test functions modifying /etc/hosts"""
1153
1154   def setUp(self):
1155     testutils.GanetiTestCase.setUp(self)
1156     self.tmpname = self._CreateTempFile()
1157     handle = open(self.tmpname, 'w')
1158     try:
1159       handle.write('# This is a test file for /etc/hosts\n')
1160       handle.write('127.0.0.1\tlocalhost\n')
1161       handle.write('192.0.2.1 router gw\n')
1162     finally:
1163       handle.close()
1164
1165   def testSettingNewIp(self):
1166     SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost.example.com',
1167                      ['myhost'])
1168
1169     self.assertFileContent(self.tmpname,
1170       "# This is a test file for /etc/hosts\n"
1171       "127.0.0.1\tlocalhost\n"
1172       "192.0.2.1 router gw\n"
1173       "198.51.100.4\tmyhost.example.com myhost\n")
1174     self.assertFileMode(self.tmpname, 0644)
1175
1176   def testSettingExistingIp(self):
1177     SetEtcHostsEntry(self.tmpname, '192.0.2.1', 'myhost.example.com',
1178                      ['myhost'])
1179
1180     self.assertFileContent(self.tmpname,
1181       "# This is a test file for /etc/hosts\n"
1182       "127.0.0.1\tlocalhost\n"
1183       "192.0.2.1\tmyhost.example.com myhost\n")
1184     self.assertFileMode(self.tmpname, 0644)
1185
1186   def testSettingDuplicateName(self):
1187     SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost', ['myhost'])
1188
1189     self.assertFileContent(self.tmpname,
1190       "# This is a test file for /etc/hosts\n"
1191       "127.0.0.1\tlocalhost\n"
1192       "192.0.2.1 router gw\n"
1193       "198.51.100.4\tmyhost\n")
1194     self.assertFileMode(self.tmpname, 0644)
1195
1196   def testRemovingExistingHost(self):
1197     RemoveEtcHostsEntry(self.tmpname, 'router')
1198
1199     self.assertFileContent(self.tmpname,
1200       "# This is a test file for /etc/hosts\n"
1201       "127.0.0.1\tlocalhost\n"
1202       "192.0.2.1 gw\n")
1203     self.assertFileMode(self.tmpname, 0644)
1204
1205   def testRemovingSingleExistingHost(self):
1206     RemoveEtcHostsEntry(self.tmpname, 'localhost')
1207
1208     self.assertFileContent(self.tmpname,
1209       "# This is a test file for /etc/hosts\n"
1210       "192.0.2.1 router gw\n")
1211     self.assertFileMode(self.tmpname, 0644)
1212
1213   def testRemovingNonExistingHost(self):
1214     RemoveEtcHostsEntry(self.tmpname, 'myhost')
1215
1216     self.assertFileContent(self.tmpname,
1217       "# This is a test file for /etc/hosts\n"
1218       "127.0.0.1\tlocalhost\n"
1219       "192.0.2.1 router gw\n")
1220     self.assertFileMode(self.tmpname, 0644)
1221
1222   def testRemovingAlias(self):
1223     RemoveEtcHostsEntry(self.tmpname, 'gw')
1224
1225     self.assertFileContent(self.tmpname,
1226       "# This is a test file for /etc/hosts\n"
1227       "127.0.0.1\tlocalhost\n"
1228       "192.0.2.1 router\n")
1229     self.assertFileMode(self.tmpname, 0644)
1230
1231
class TestGetMounts(unittest.TestCase):
  """Test case for GetMounts()."""

  TESTDATA = (
    "rootfs /     rootfs rw 0 0\n"
    "none   /sys  sysfs  rw,nosuid,nodev,noexec,relatime 0 0\n"
    "none   /proc proc   rw,nosuid,nodev,noexec,relatime 0 0\n")

  def setUp(self):
    """Writes TESTDATA to a named temporary file"""
    self.tmpfile = tempfile.NamedTemporaryFile()
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)

  def testGetMounts(self):
    """Each line of the file is returned as a tuple of its first 4 fields"""
    expected = [
      ("rootfs", "/", "rootfs", "rw"),
      ("none", "/sys", "sysfs", "rw,nosuid,nodev,noexec,relatime"),
      ("none", "/proc", "proc", "rw,nosuid,nodev,noexec,relatime"),
      ]
    mounts = utils.GetMounts(filename=self.tmpfile.name)
    self.assertEqual(mounts, expected)
1251
1252
class TestShellQuoting(unittest.TestCase):
  """Test case for shell quoting functions"""

  def testShellQuote(self):
    """Checks quoting of single values"""
    cases = [
      ('abc', "abc"),
      ('ab"c', "'ab\"c'"),
      ("a'bc", "'a'\\''bc'"),
      ("a b c", "'a b c'"),
      ("a b\\ c", "'a b\\ c'"),
      ]
    for (value, expected) in cases:
      self.assertEqual(ShellQuote(value), expected)

  def testShellQuoteArgs(self):
    """Checks quoting and joining of argument lists"""
    cases = [
      (['a', 'b', 'c'], "a b c"),
      (['a', 'b"', 'c'], "a 'b\"' c"),
      (['a', 'b\'', 'c'], "a 'b'\\\''' c"),
      ]
    for (args, expected) in cases:
      self.assertEqual(ShellQuoteArgs(args), expected)
1267
1268
class TestListVisibleFiles(unittest.TestCase):
  """Test case for ListVisibleFiles"""

  def setUp(self):
    """Creates the temporary directory populated by the tests"""
    self.path = tempfile.mkdtemp()

  def tearDown(self):
    """Removes the temporary directory"""
    shutil.rmtree(self.path)

  def _test(self, files, expected):
    """Creates files and compares the listing with the expected names.

    @param files: names of the files to create in the temporary directory
    @param expected: names expected from ListVisibleFiles, in any order

    """
    for name in files:
      utils.WriteFile(os.path.join(self.path, name), data="test")
    self.assertEqual(set(ListVisibleFiles(self.path)), set(expected))

  def testAllVisible(self):
    names = ["a", "b", "c"]
    self._test(names, names)

  def testNoneVisible(self):
    # Names starting with a dot must not be listed
    self._test([".a", ".b", ".c"], [])

  def testSomeVisible(self):
    self._test(["a", "b", ".c"], ["a", "b"])

  def testNonAbsolutePath(self):
    self.assertRaises(errors.ProgrammerError, ListVisibleFiles, "abc")

  def testNonNormalizedPath(self):
    self.assertRaises(errors.ProgrammerError, ListVisibleFiles,
                      "/bin/../tmp")
1308
1309
class TestNewUUID(unittest.TestCase):
  """Test case for NewUUID"""

  def runTest(self):
    """A new UUID must be matched by utils.UUID_RE"""
    new_uuid = utils.NewUUID()
    self.assertTrue(utils.UUID_RE.match(new_uuid))
1315
1316
class TestUniqueSequence(unittest.TestCase):
  """Test case for UniqueSequence"""

  def _test(self, seq, expected):
    """Asserts that UniqueSequence(seq) equals the expected list"""
    self.assertEqual(utils.UniqueSequence(seq), expected)

  def runTest(self):
    cases = [
      # Ordered input
      ([1, 2, 3], [1, 2, 3]),
      ([1, 1, 2, 2, 3, 3], [1, 2, 3]),
      ([1, 2, 2, 3], [1, 2, 3]),
      ([1, 2, 3, 3], [1, 2, 3]),

      # Unordered input
      ([1, 2, 3, 1, 2, 3], [1, 2, 3]),
      ([1, 1, 2, 3, 3, 1, 2], [1, 2, 3]),

      # Strings
      (["a", "a"], ["a"]),
      (["a", "b"], ["a", "b"]),
      (["a", "b", "a"], ["a", "b"]),
      ]
    for (seq, expected) in cases:
      self._test(seq, expected)
1338
1339
class TestFirstFree(unittest.TestCase):
  """Test case for the FirstFree function"""

  def test(self):
    """Test FirstFree with and without an explicit base"""
    self.assertEqual(FirstFree([0, 1, 3]), 2)
    self.assertEqual(FirstFree([]), None)
    self.assertEqual(FirstFree([3, 4, 6]), 0)
    self.assertEqual(FirstFree([3, 4, 6], base=3), 5)
    # An element below the base is expected to trip an assertion
    self.assertRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1350
1351
class TestTailFile(testutils.GanetiTestCase):
  """Test case for the TailFile function"""

  def _CreateFileWithChunks(self, chunks):
    """Creates a temporary file and writes the given strings to it.

    The file handle is closed even if one of the writes fails.

    @param chunks: sequence of strings, written in the given order
    @return: name of the temporary file

    """
    fname = self._CreateTempFile()
    fd = open(fname, "w")
    try:
      for chunk in chunks:
        fd.write(chunk)
    finally:
      fd.close()
    return fname

  def testEmpty(self):
    """An empty file yields an empty list"""
    fname = self._CreateTempFile()
    self.assertEqual(TailFile(fname), [])
    self.assertEqual(TailFile(fname, lines=25), [])

  def testAllLines(self):
    """Requesting as many lines as the file holds returns all of them"""
    data = ["test %d" % i for i in range(30)]
    for i in range(30):
      chunks = ["\n".join(data[:i])]
      if i > 0:
        # Terminate the last line, except in the empty file
        chunks.append("\n")
      fname = self._CreateFileWithChunks(chunks)
      self.assertEqual(TailFile(fname, lines=i), data[:i])

  def testPartialLines(self):
    """Requesting fewer lines than available returns the last ones"""
    data = ["test %d" % i for i in range(30)]
    fname = self._CreateFileWithChunks(["\n".join(data), "\n"])
    for i in range(1, 30):
      self.assertEqual(TailFile(fname, lines=i), data[-i:])

  def testBigFile(self):
    """The last lines are also found behind one very long line"""
    data = ["test %d" % i for i in range(30)]
    fname = self._CreateFileWithChunks([
      "X" * 1048576,
      "\n",
      "\n".join(data),
      "\n",
      ])
    for i in range(1, 30):
      self.assertEqual(TailFile(fname, lines=i), data[-i:])
1392
1393
class _BaseFileLockTest:
  """Shared tests for the FileLock class.

  This class is not a test case on its own; it is combined with a
  unittest.TestCase class. The resulting class must provide C{self.lock}
  (the lock object under test) and C{self.tmpfile} (an object whose C{name}
  attribute is the path of the locked file), e.g. from its C{setUp} method.

  """

  def testSharedNonblocking(self):
    self.lock.Shared(blocking=False)
    self.lock.Close()

  def testExclusiveNonblocking(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Close()

  def testUnlockNonblocking(self):
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSharedBlocking(self):
    self.lock.Shared(blocking=True)
    self.lock.Close()

  def testExclusiveBlocking(self):
    self.lock.Exclusive(blocking=True)
    self.lock.Close()

  def testUnlockBlocking(self):
    self.lock.Unlock(blocking=True)
    self.lock.Close()

  def testSharedExclusiveUnlock(self):
    self.lock.Shared(blocking=False)
    self.lock.Exclusive(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testExclusiveSharedUnlock(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Shared(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSimpleTimeout(self):
    # These are expected to succeed on the first attempt, so the timeout
    # value should never be waited for
    self.lock.Shared(blocking=True, timeout=10.0)
    self.lock.Exclusive(blocking=False, timeout=10.0)
    self.lock.Unlock(blocking=True, timeout=10.0)
    self.lock.Close()

  @staticmethod
  def _TryLockInner(filename, shared, blocking):
    """Tries to lock a file through a newly opened FileLock.

    @param filename: path of the file to lock
    @param shared: whether to request a shared instead of an exclusive lock
    @param blocking: passed on to the locking function
    @return: False if errors.LockError was raised, True otherwise

    """
    lock = utils.FileLock.Open(filename)

    if shared:
      fn = lock.Shared
    else:
      fn = lock.Exclusive

    try:
      # The timeout doesn't really matter as the parent process waits for us to
      # finish anyway.
      fn(blocking=blocking, timeout=0.01)
    except errors.LockError:
      return False

    return True

  def _TryLock(self, *args):
    """Runs L{_TryLockInner} on the temporary file in a separate process.

    @param args: the "shared" and "blocking" arguments of L{_TryLockInner}

    """
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
                                      *args)

  def testTimeout(self):
    for blocking in [True, False]:
      # While this object holds an exclusive lock, another process must get
      # neither an exclusive nor a shared lock
      self.lock.Exclusive(blocking=True)
      self.failIf(self._TryLock(False, blocking))
      self.failIf(self._TryLock(True, blocking))

      # While this object holds a shared lock, another process must get a
      # shared lock, but no exclusive one
      self.lock.Shared(blocking=True)
      self.assert_(self._TryLock(True, blocking))
      self.failIf(self._TryLock(False, blocking))

  def testCloseShared(self):
    # Using a closed lock must trip an assertion
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)

  def testCloseExclusive(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)

  def testCloseUnlock(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1483
1484
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
  """Runs the file lock tests on a lock created via FileLock.Open"""

  TESTDATA = "Hello World\n" * 10

  def setUp(self):
    """Writes TESTDATA to a temporary file and opens a lock on it"""
    testutils.GanetiTestCase.setUp(self)

    self.tmpfile = tempfile.NamedTemporaryFile()
    path = self.tmpfile.name
    utils.WriteFile(path, data=self.TESTDATA)
    self.lock = utils.FileLock.Open(path)

    # Ensure "Open" didn't truncate file
    self.assertFileContent(path, self.TESTDATA)

  def tearDown(self):
    """Verifies that the test left the file content untouched"""
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)

    testutils.GanetiTestCase.tearDown(self)
1502
1503
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
  """Runs the file lock tests on a lock built from an open file object"""

  def setUp(self):
    """Creates a temporary file and wraps a new handle to it in a lock"""
    self.tmpfile = tempfile.NamedTemporaryFile()
    path = self.tmpfile.name
    handle = open(path, "w")
    self.lock = utils.FileLock(handle, path)
1508
1509
class TestTimeFunctions(unittest.TestCase):
  """Test case for time functions"""

  def runTest(self):
    """Checks SplitTime and MergeTime with valid and invalid input"""
    split_cases = [
      (1, (1, 0)),
      (1.5, (1, 500000)),
      (1218448917.4809151, (1218448917, 480915)),
      (123.48012, (123, 480120)),
      (123.9996, (123, 999600)),
      (123.9995, (123, 999500)),
      (123.9994, (123, 999400)),
      (123.999999999, (123, 999999)),
      ]
    for (value, expected) in split_cases:
      self.assertEqual(utils.SplitTime(value), expected)

    self.assertRaises(AssertionError, utils.SplitTime, -1)

    merge_cases = [
      ((1, 0), 1.0),
      ((1, 500000), 1.5),
      ((1218448917, 500000), 1218448917.5),
      ]
    for (parts, expected) in merge_cases:
      self.assertEqual(utils.MergeTime(parts), expected)

    # These results are compared after rounding to three decimal places
    rounded_cases = [
      ((1218448917, 481000), 1218448917.481),
      ((1, 801000), 1.801),
      ]
    for (parts, expected) in rounded_cases:
      self.assertEqual(round(utils.MergeTime(parts), 3), expected)

    invalid_parts = [
      (0, -1),
      (0, 1000000),
      (0, 9999999),
      (-1, 0),
      (-9999, 0),
      ]
    for parts in invalid_parts:
      self.assertRaises(AssertionError, utils.MergeTime, parts)
1538
1539
class FieldSetTestCase(unittest.TestCase):
  """Test case for FieldSets"""

  def testSimpleMatch(self):
    """Checks a field set made of plain names"""
    fields = utils.FieldSet("a", "b", "c", "def")
    self.assertTrue(fields.Matches("a"))
    self.assertFalse(fields.Matches("d"), "Substring matched")
    self.assertFalse(fields.Matches("defghi"), "Prefix string matched")
    self.assertFalse(fields.NonMatching(["b", "c"]))
    self.assertFalse(fields.NonMatching(["a", "b", "c", "def"]))
    self.assertTrue(fields.NonMatching(["a", "d"]))

  def testRegexMatch(self):
    """Checks a field set containing a regular expression"""
    fields = utils.FieldSet("a", "b([0-9]+)", "c")
    self.assertTrue(fields.Matches("b1"))
    self.assertTrue(fields.Matches("b99"))
    self.assertFalse(fields.Matches("b/1"))
    self.assertFalse(fields.NonMatching(["b12", "c"]))
    self.assertTrue(fields.NonMatching(["a", "1"]))
1559
class TestForceDictType(unittest.TestCase):
  """Test case for ForceDictType"""

  # Key types passed to ForceDictType by the _fdt helper
  KEY_TYPES = {
    "a": constants.VTYPE_INT,
    "b": constants.VTYPE_BOOL,
    "c": constants.VTYPE_STRING,
    "d": constants.VTYPE_SIZE,
    "e": constants.VTYPE_MAYBE_STRING,
    }

  def _fdt(self, data, allowed_values=None):
    """Runs ForceDictType with KEY_TYPES and returns the passed object.

    @param data: object handed to ForceDictType
    @param allowed_values: passed on to ForceDictType unless it is None
    @return: the C{data} object itself

    """
    kwargs = {}
    if allowed_values is not None:
      kwargs["allowed_values"] = allowed_values
    utils.ForceDictType(data, self.KEY_TYPES, **kwargs)
    return data

  def testSimpleDict(self):
    """Checks the values found in the dictionary after enforcement"""
    cases = [
      ({}, {}),
      ({'a': 1}, {'a': 1}),
      ({'a': '1'}, {'a': 1}),
      ({'a': 1, 'b': 1}, {'a':1, 'b': True}),
      ({'b': 1, 'c': 'foo'}, {'b': True, 'c': 'foo'}),
      ({'b': 1, 'c': False}, {'b': True, 'c': ''}),
      ({'b': 'false'}, {'b': False}),
      ({'b': 'False'}, {'b': False}),
      ({'b': False}, {'b': False}),
      ({'b': 'true'}, {'b': True}),
      ({'b': 'True'}, {'b': True}),
      ({'d': '4'}, {'d': 4}),
      ({'d': '4M'}, {'d': 4}),
      ({"e": None, }, {"e": None, }),
      ({"e": "Hello World", }, {"e": "Hello World", }),
      ({"e": False, }, {"e": '', }),
      ]
    for (value, expected) in cases:
      self.assertEqual(self._fdt(value), expected)

    # A value listed in allowed_values is left untouched
    self.assertEqual(self._fdt({"b": "hello", }, ["hello"]), {"b": "hello"})

  def testErrors(self):
    """Checks the exceptions raised for unacceptable input"""
    rejected = [
      {'a': 'astring'},
      {"b": "hello"},
      {'c': True},
      {'d': 'astring'},
      {'d': '4 L'},
      {"e": object(), },
      {"e": [], },
      {"x": None, },
      [],
      ]
    for value in rejected:
      self.assertRaises(errors.TypeEnforcementError, self._fdt, value)

    # An unknown type name in the key types is a programming error
    self.assertRaises(errors.ProgrammerError, utils.ForceDictType,
                      {"b": "hello"}, {"b": "no-such-type"})
1609
1610
class TestIsNormAbsPath(unittest.TestCase):
  """Testing case for IsNormAbsPath"""

  def _pathTestHelper(self, path, result):
    """Asserts that IsNormAbsPath(path) is true or false as expected.

    @param path: path to check
    @param result: whether the path is expected to be accepted

    """
    if not result:
      self.assertFalse(utils.IsNormAbsPath(path),
          "Path %s should not result absolute and normalized" % path)
      return

    self.assertTrue(utils.IsNormAbsPath(path),
        "Path %s should result absolute and normalized" % path)

  def testBase(self):
    cases = [
      ('/etc', True),
      ('/srv', True),
      ('etc', False),
      ('/etc/../root', False),
      ('/etc/', False),
      ]
    for (path, result) in cases:
      self._pathTestHelper(path, result)
1628
1629
class TestSafeEncode(unittest.TestCase):
  """Test case for SafeEncode"""

  def testAscii(self):
    """Digits, letters and punctuation must pass through unchanged"""
    for txt in [string.digits, string.letters, string.punctuation]:
      self.assertEqual(txt, SafeEncode(txt))

  def testDoubleEncode(self):
    """Encoding an already encoded single-byte string must not change it"""
    # range(256) covers every possible byte value, including 0xff
    for i in range(256):
      txt = SafeEncode(chr(i))
      self.assertEqual(txt, SafeEncode(txt))

  def testUnicode(self):
    """Encoding an already encoded unicode character must not change it"""
    # 1024 is high enough to catch non-direct ASCII mappings
    for i in range(1024):
      txt = SafeEncode(unichr(i))
      self.assertEqual(txt, SafeEncode(txt))
1647
1648
class TestFormatTime(unittest.TestCase):
  """Tests for FormatTime"""

  def testNone(self):
    formatted = FormatTime(None)
    self.failUnlessEqual(formatted, "N/A")

  def testInvalid(self):
    formatted = FormatTime(())
    self.failUnlessEqual(formatted, "N/A")

  def testNow(self):
    # Only the absence of an exception is checked here, first for a float
    # timestamp as returned by time.time ...
    float_stamp = time.time()
    FormatTime(float_stamp)
    # ... and then for an integer one
    int_stamp = int(time.time())
    FormatTime(int_stamp)
1663
1664
class RunInSeparateProcess(unittest.TestCase):
  """Tests for utils.RunInSeparateProcess"""

  def test(self):
    for expected in (True, False):
      def _report():
        return expected

      self.assertEqual(expected, utils.RunInSeparateProcess(_report))

  def testArgs(self):
    for value in [0, 1, 999, "Hello World", (1, 2, 3)]:
      def _compare(first, second):
        return first == "Foo" and second == value

      self.assert_(utils.RunInSeparateProcess(_compare, "Foo", value))

  def testPid(self):
    pid_before = os.getpid()

    def _same_pid():
      return os.getpid() == pid_before

    # The function is expected to see a PID different from ours
    self.failIf(utils.RunInSeparateProcess(_same_pid))

  def testSignal(self):
    def _terminate_self():
      os.kill(os.getpid(), signal.SIGTERM)

    self.assertRaises(errors.GenericError,
                      utils.RunInSeparateProcess, _terminate_self)

  def testException(self):
    def _fail():
      raise errors.GenericError("This is a test")

    self.assertRaises(errors.GenericError,
                      utils.RunInSeparateProcess, _fail)
1701
1702
class TestFingerprintFiles(unittest.TestCase):
  """Tests for utils._FingerprintFile and utils.FingerprintFiles"""

  def setUp(self):
    # The first file stays empty, the second one gets a known payload
    self.tmpfile = tempfile.NamedTemporaryFile()
    self.tmpfile2 = tempfile.NamedTemporaryFile()
    utils.WriteFile(self.tmpfile2.name, data="Hello World\n")
    # Expected fingerprints, keyed by file name
    self.results = {
      self.tmpfile.name: "da39a3ee5e6b4b0d3255bfef95601890afd80709",
      self.tmpfile2.name: "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a",
      }

  def testSingleFile(self):
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
                     self.results[self.tmpfile.name])

    # Missing files are reported as None instead of raising
    self.assertEqual(utils._FingerprintFile("/no/such/file"), None)

  def testBigFile(self):
    self.tmpfile.write("A" * 8192)
    self.tmpfile.flush()
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
                     "35b6795ca20d6dc0aff8c7c110c96cd1070b8c38")

  def testMultiple(self):
    # The nonexistent file must actually be part of the request, otherwise
    # this test wouldn't show that it is left out of the result
    all_files = list(self.results.keys())
    all_files.append("/no/such/file")
    self.assertEqual(utils.FingerprintFiles(all_files), self.results)
1729
1730
class TestUnescapeAndSplit(unittest.TestCase):
  """Tests for UnescapeAndSplit"""

  def setUp(self):
    # several separators are used to make sure they are regexp-safe
    self._seps = [",", "+", "."]

  def _checkSplit(self, sep, parts, expected):
    """Join C{parts} using C{sep} and compare the result of splitting again.

    """
    self.failUnlessEqual(UnescapeAndSplit(sep.join(parts), sep=sep), expected)

  def testSimple(self):
    parts = ["a", "b", "c", "d"]
    for sep in self._seps:
      self._checkSplit(sep, parts, parts)

  def testEscape(self):
    for sep in self._seps:
      self._checkSplit(sep,
                       ["a", "b\\" + sep + "c", "d"],
                       ["a", "b" + sep + "c", "d"])

  def testDoubleEscape(self):
    for sep in self._seps:
      self._checkSplit(sep,
                       ["a", "b\\\\", "c", "d"],
                       ["a", "b\\", "c", "d"])

  def testThreeEscape(self):
    for sep in self._seps:
      self._checkSplit(sep,
                       ["a", "b\\\\\\" + sep + "c", "d"],
                       ["a", "b\\" + sep + "c", "d"])
1760
1761
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
  """Tests for the self-signed X509 certificate generators in utils"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _checkRsaPrivateKey(self, key):
    """Tell whether C{key} contains both RSA private key PEM marker lines.

    @rtype: bool

    """
    lines = key.splitlines()
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
            "-----END RSA PRIVATE KEY-----" in lines)

  def _checkCertificate(self, cert):
    """Tell whether C{cert} contains both certificate PEM marker lines.

    @rtype: bool

    """
    lines = cert.splitlines()
    return ("-----BEGIN CERTIFICATE-----" in lines and
            "-----END CERTIFICATE-----" in lines)

  def test(self):
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
      # The helpers merely return a boolean, hence their result has to be
      # asserted for the check to have any effect
      self.assert_(self._checkRsaPrivateKey(key_pem))
      self.assert_(self._checkCertificate(cert_pem))

      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
                                           key_pem)
      self.assert_(key.bits() >= 1024)
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)

      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                             cert_pem)
      self.failIf(x509.has_expired())
      self.assertEqual(x509.get_issuer().CN, common_name)
      self.assertEqual(x509.get_subject().CN, common_name)
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)

  def testLegacy(self):
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")

    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)

    cert1 = utils.ReadFile(cert1_filename)

    # The file is expected to hold the private key as well as the certificate
    self.assert_(self._checkRsaPrivateKey(cert1))
    self.assert_(self._checkCertificate(cert1))
1807
1808
class TestPathJoin(unittest.TestCase):
  """Tests for PathJoin"""

  def _assertRejected(self, *components):
    """Assert that joining C{components} raises ValueError.

    """
    self.failUnlessRaises(ValueError, PathJoin, *components)

  def testBasicItems(self):
    components = ["/a", "b", "c"]
    expected = "/".join(components)
    self.failUnlessEqual(PathJoin(*components), expected)

  def testNonAbsPrefix(self):
    self._assertRejected("a", "b")

  def testBackTrack(self):
    self._assertRejected("/a", "b/../c")

  def testMultiAbs(self):
    self._assertRejected("/a", "/b")
1824
1825
class TestValidateServiceName(unittest.TestCase):
  """Tests for utils.ValidateServiceName"""

  def testValid(self):
    numeric = [0, 1, 2, 3, 1024, 65000, 65534, 65535]
    textual = [
      "ganeti",
      "gnt-masterd",
      "HELLO_WORLD_SVC",
      "hello.world.1",
      "0", "80", "1111", "65535",
      ]

    # Accepted values are expected to be returned unchanged
    for candidate in numeric + textual:
      self.assertEqual(utils.ValidateServiceName(candidate), candidate)

  def testInvalid(self):
    numeric = [-15756, -1, 65536, 133428083]
    textual = [
      "", "Hello World!", "!", "'", "\"", "\t", "\n", "`",
      "-8546", "-1", "65536",
      (129 * "A"),
      ]

    for candidate in numeric + textual:
      self.assertRaises(errors.OpPrereqError, utils.ValidateServiceName,
                        candidate)
1850
1851
class TestParseAsn1Generalizedtime(unittest.TestCase):
  """Tests for utils._ParseAsn1Generalizedtime"""

  def test(self):
    parse = utils._ParseAsn1Generalizedtime

    accepted = [
      # UTC
      ("19700101000000Z", 0),
      ("20100222174152Z", 1266860512),
      ("20380119031407Z", (2**31) - 1),
      # With offset
      ("20100222174152+0000", 1266860512),
      ("20100223131652+0000", 1266931012),
      ("20100223051808-0800", 1266931088),
      ("20100224002135+1100", 1266931295),
      ("19700101000000-0100", 3600),
      ]

    for (text, expected) in accepted:
      self.assertEqual(parse(text), expected)

    rejected = [
      # Leap seconds are not supported by datetime.datetime
      "19841231235960+0000",
      "19920630235960+0000",
      # Errors
      "",
      "invalid",
      "20100222174152",
      "Mon Feb 22 17:47:02 UTC 2010",
      "2010-02-22 17:42:02",
      ]

    for text in rejected:
      self.assertRaises(ValueError, parse, text)
1888
1889
class TestGetX509CertValidity(testutils.GanetiTestCase):
  """Tests for utils.GetX509CertValidity"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    installed = distutils.version.LooseVersion(OpenSSL.__version__)

    # Test whether we have pyOpenSSL 0.7 or above
    self.pyopenssl0_7 = (installed >= "0.7")

    if self.pyopenssl0_7:
      return

    warnings.warn("This test requires pyOpenSSL 0.7 or above to"
                  " function correctly")

  def _LoadCert(self, name):
    """Load the PEM certificate stored in test data file C{name}.

    """
    pem = self._ReadTestData(name)
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem)

  def test(self):
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))

    # With older pyOpenSSL versions no validity is expected to be available
    if not self.pyopenssl0_7:
      expected = (None, None)
    else:
      expected = (1266919967, 1267524767)

    self.assertEqual(validity, expected)
1913
1914
class TestSignX509Certificate(unittest.TestCase):
  """Tests for signing X509 certificates and loading signed ones"""

  KEY = "My private key!"
  KEY_OTHER = "Another key"

  def test(self):
    # Generate certificate valid for 5 minutes
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)

    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                           cert_pem)

    # No signature at all
    self.assertRaises(errors.GenericError,
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)

    # Invalid input
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: \n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)

    # Invalid salt; the certificate object (not its PEM text) is passed, just
    # like for the valid salts below, so that the salt is the only thing
    # wrong with the call
    for salt in list("-_@$,:;/\\ \t\n"):
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
                        cert, self.KEY, "foo%sbar" % salt)

    for salt in ["HelloWorld", "salt", string.letters, string.digits,
                 utils.GenerateSecret(numbytes=4),
                 utils.GenerateSecret(numbytes=16),
                 "{123:456}".encode("hex")]:
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)

      self._Check(cert, salt, signed_pem)

      # Additional text before or after the signed certificate is expected
      # to be tolerated
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
                               "lines----\n------ at\nthe end!"))

  def _Check(self, cert, salt, pem):
    """Verify that C{pem} loads back to C{cert} and C{salt} using our key.

    Loading the same data using another key must fail.

    """
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
    self.assertEqual(salt, salt2)
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))

    # Other key
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      pem, self.KEY_OTHER)
1968
1969
class TestMakedirs(unittest.TestCase):
  """Tests for utils.Makedirs"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _makeAndCheck(self, target):
    """Run Makedirs on C{target} and assert that it is a directory afterwards.

    """
    utils.Makedirs(target)
    self.assert_(os.path.isdir(target))

  def testNonExisting(self):
    self._makeAndCheck(PathJoin(self.tmpdir, "foo"))

  def testExisting(self):
    target = PathJoin(self.tmpdir, "foo")
    os.mkdir(target)
    self._makeAndCheck(target)

  def testRecursiveNonExisting(self):
    self._makeAndCheck(PathJoin(self.tmpdir, "foo/bar/baz"))

  def testRecursiveExisting(self):
    target = PathJoin(self.tmpdir, "B/moo/xyz")
    self.assertFalse(os.path.exists(target))
    # Only the topmost component exists beforehand
    os.mkdir(PathJoin(self.tmpdir, "B"))
    self._makeAndCheck(target)
1999
2000
class TestRetry(testutils.GanetiTestCase):
  """Tests for utils.Retry and the RetryAgain/RetryTimeout exceptions"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    # Counter incremented by _RetryAndSucceed on every failed attempt
    self.retries = 0

  @staticmethod
  def _RaiseRetryAgain():
    """Always raise RetryAgain without arguments."""
    raise utils.RetryAgain()

  @staticmethod
  def _RaiseRetryAgainWithArg(args):
    """Always raise RetryAgain, constructed from the items of C{args}."""
    raise utils.RetryAgain(*args)

  def _WrongNestedLoop(self):
    """Run a Retry loop around a function which always raises RetryAgain.

    Used as the inner loop by L{testNestedLoop}.

    """
    return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)

  def _RetryAndSucceed(self, retries):
    """Raise RetryAgain while fewer than C{retries} attempts were counted.

    Every failed attempt increments C{self.retries}; once the counter has
    reached C{retries}, True is returned.

    """
    if self.retries < retries:
      self.retries += 1
      raise utils.RetryAgain()
    else:
      return True

  def testRaiseTimeout(self):
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
                          self._RaiseRetryAgain, 0.01, 0.02)
    # NOTE(review): the two numbers are presumably (delay, timeout); with the
    # second one being 0 the loop must give up after the first attempt
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
                          self._RetryAndSucceed, 0.01, 0, args=[1])
    self.failUnlessEqual(self.retries, 1)

  def testComplete(self):
    self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
    # Two failed attempts are counted before the function succeeds
    self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
                         True)
    self.failUnlessEqual(self.retries, 2)

  def testNestedLoop(self):
    # failUnlessRaises lets exceptions of other types propagate, hence a
    # RetryTimeout arriving here means the outer loop didn't raise
    # ProgrammerError for the inner loop's timeout
    try:
      self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
                            self._WrongNestedLoop, 0, 1)
    except utils.RetryTimeout:
      self.fail("Didn't detect inner loop's exception")

  def testTimeoutArgument(self):
    # Arguments given to RetryAgain are expected to show up in the arguments
    # of the resulting RetryTimeout
    retry_arg="my_important_debugging_message"
    try:
      utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
    except utils.RetryTimeout, err:
      self.failUnlessEqual(err.args, (retry_arg, ))
    else:
      self.fail("Expected timeout didn't happen")

  def testRaiseInnerWithExc(self):
    # RetryAgain is given an exception instance here; RaiseInner is expected
    # to raise that very exception
    retry_arg="my_important_debugging_message"
    try:
      try:
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
                    args=[[errors.GenericError(retry_arg, retry_arg)]])
      except utils.RetryTimeout, err:
        err.RaiseInner()
      else:
        self.fail("Expected timeout didn't happen")
    except errors.GenericError, err:
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
    else:
      self.fail("Expected GenericError didn't happen")

  def testRaiseInnerWithMsg(self):
    # RetryAgain is given plain strings here; RaiseInner is expected to raise
    # a RetryTimeout carrying the same arguments
    retry_arg="my_important_debugging_message"
    try:
      try:
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
                    args=[[retry_arg, retry_arg]])
      except utils.RetryTimeout, err:
        err.RaiseInner()
      else:
        self.fail("Expected timeout didn't happen")
    except utils.RetryTimeout, err:
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
    else:
      self.fail("Expected RetryTimeout didn't happen")
2082
2083
class TestLineSplitter(unittest.TestCase):
  """Tests for utils.LineSplitter"""

  def test(self):
    received = []
    splitter = utils.LineSplitter(received.append)

    splitter.write("Hello World\n")
    self.assertEqual(received, [])

    for chunk in ("Foo\n Bar\r\n ", "Baz", "Moo"):
      splitter.write(chunk)
    self.assertEqual(received, [])

    # Flushing is expected to deliver the complete lines only ...
    splitter.flush()
    self.assertEqual(received, ["Hello World", "Foo", " Bar"])

    # ... while closing also delivers the unterminated rest
    splitter.close()
    self.assertEqual(received, ["Hello World", "Foo", " Bar", " BazMoo"])

  def _testExtra(self, line, all_lines, p1, p2):
    """Line callback checking its extra arguments before storing the line.

    """
    self.assertEqual(p1, 999)
    self.assertEqual(p2, "extra")
    all_lines.append(line)

  def testExtraArgsNoFlush(self):
    received = []
    splitter = utils.LineSplitter(self._testExtra, received, 999, "extra")

    for chunk in ("\n\nHello World\n", "Foo\n Bar\r\n ", "", "Baz",
                  "Moo\n\nx\n"):
      splitter.write(chunk)
    self.assertEqual(received, [])

    splitter.close()
    expected = ["", "", "Hello World", "Foo", " Bar", " BazMoo", "", "x"]
    self.assertEqual(received, expected)
2116
2117
class TestReadLockedPidFile(unittest.TestCase):
  """Tests for utils.ReadLockedPidFile"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testNonExistent(self):
    missing = PathJoin(self.tmpdir, "nonexist")
    self.assert_(utils.ReadLockedPidFile(missing) is None)

  def testUnlocked(self):
    pidfile = PathJoin(self.tmpdir, "pid")
    utils.WriteFile(pidfile, data="123")
    self.assert_(utils.ReadLockedPidFile(pidfile) is None)

  def testLocked(self):
    pidfile = PathJoin(self.tmpdir, "pid")
    utils.WriteFile(pidfile, data="123")

    lock = utils.FileLock.Open(pidfile)
    try:
      lock.Exclusive(blocking=True)

      # The PID is expected to be reported while the lock is held ...
      self.assertEqual(utils.ReadLockedPidFile(pidfile), 123)
    finally:
      lock.Close()

    # ... and no longer once the lock was given up
    self.assert_(utils.ReadLockedPidFile(pidfile) is None)

  def testError(self):
    pidfile = PathJoin(self.tmpdir, "foobar", "pid")
    # "foobar" is created as a regular file, not as a directory
    utils.WriteFile(PathJoin(self.tmpdir, "foobar"), data="")
    # open(2) should return ENOTDIR
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, pidfile)
2153
2154
class TestCertVerification(testutils.GanetiTestCase):
  """Smoke test for utils.VerifyX509Certificate"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    # setUp chains to the base class, hence tearDown has to do the same for
    # the base class to get a chance to clean up; this also needs to happen
    # if removing our own directory fails
    try:
      shutil.rmtree(self.tmpdir)
    finally:
      testutils.GanetiTestCase.tearDown(self)

  def testVerifyCertificate(self):
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                           cert_pem)

    # Not checking return value as this certificate is expired
    utils.VerifyX509Certificate(cert, 30, 7)
2171
2172
class TestVerifyCertificateInner(unittest.TestCase):
  """Tests for utils._VerifyCertificateInner"""

  def test(self):
    verify = utils._VerifyCertificateInner

    # Valid
    outcome = verify(False, 1263916313, 1298476313, 1266940313, 30, 7)
    self.assertEqual(outcome, (None, None))

    # For the remaining cases only the error code is compared, the message
    # is ignored
    scenarios = [
      # Not yet valid
      ((False, 1266507600, 1267544400, 1266075600, 30, 7),
       utils.CERT_WARNING),

      # Expiring soon
      ((False, 1266507600, 1267544400, 1266939600, 30, 7), utils.CERT_ERROR),
      ((False, 1266507600, 1267544400, 1266939600, 30, 1),
       utils.CERT_WARNING),
      ((False, 1266507600, None, 1266939600, 30, 7), None),

      # Expired
      ((True, 1266507600, 1267544400, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, None, 1267544400, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, 1266507600, None, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, None, None, 1266939600, 30, 7), utils.CERT_ERROR),
      ]

    for (call_args, expected_code) in scenarios:
      (errcode, _) = verify(*call_args)
      self.assertEqual(errcode, expected_code)
2207
2208
class TestHmacFunctions(unittest.TestCase):
  """Tests for utils.Sha1Hmac and utils.VerifySha1Hmac"""

  # Digests can be checked with "openssl sha1 -hmac $key"
  def testSha1Hmac(self):
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"

    for (key, text, expected) in [
      ("", "", "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d"),
      ("3YzMxZWE", "Hello World", "ef4f3bda82212ecb2f7ce868888a19092481f1fd"),
      ("TguMTA2K", "", "f904c2476527c6d3e6609ab683c66fa0652cb1dc"),
      ("3YzMxZWE", longtext, "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54"),
      ]:
      self.assertEqual(utils.Sha1Hmac(key, text), expected)

  def testSha1HmacSalt(self):
    for (key, text, salt, expected) in [
      ("TguMTA2K", "", "abc0", "4999bf342470eadb11dfcd24ca5680cf9fd7cdce"),
      ("TguMTA2K", "", "abc9", "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8"),
      ("3YzMxZWE", "Hello World", "xyz0",
       "7f264f8114c9066afc9bb7636e1786d996d3cc0d"),
      ]:
      self.assertEqual(utils.Sha1Hmac(key, text, salt=salt), expected)

  def testVerifySha1Hmac(self):
    for (key, text, expected) in [
      ("", "", "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d"),
      ("TguMTA2K", "", "f904c2476527c6d3e6609ab683c66fa0652cb1dc"),
      ]:
      self.assert_(utils.VerifySha1Hmac(key, text, expected))

    # The letter case of the digest to verify must not matter
    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
    for variant in (digest, digest.lower(), digest.upper(), digest.title()):
      self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", variant))

  def testVerifySha1HmacSalt(self):
    for (key, text, expected, salt) in [
      ("TguMTA2K", "", "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8", "abc9"),
      ("3YzMxZWE", "Hello World",
       "7f264f8114c9066afc9bb7636e1786d996d3cc0d", "xyz0"),
      ]:
      self.assert_(utils.VerifySha1Hmac(key, text, expected, salt=salt))
2256
2257
class TestIgnoreSignals(unittest.TestCase):
  """Tests for utils.IgnoreSignals"""

  @staticmethod
  def _Raise(exception):
    """Raise the exception passed in."""
    raise exception

  @staticmethod
  def _Return(rval):
    """Return the value passed in."""
    return rval

  def testIgnoreSignals(self):
    interrupted = [
      (socket.error, socket.error(errno.EINTR, "Message")),
      (EnvironmentError, EnvironmentError(errno.EINTR, "Message")),
      ]
    invalid = [
      (socket.error, socket.error(errno.EINVAL, "Message")),
      (EnvironmentError, EnvironmentError(errno.EINVAL, "Message")),
      ]

    # Make sure the helper really raises what it is given
    for (sock_case, env_case) in [(interrupted[0], invalid[0]),
                                  (interrupted[1], invalid[1])]:
      for (exc_class, exc) in (sock_case, env_case):
        self.assertRaises(exc_class, self._Raise, exc)

    # EINTR is expected to be swallowed, the result being None ...
    for (_, exc) in interrupted:
      self.assertEqual(utils.IgnoreSignals(self._Raise, exc), None)

    # ... while other error codes must come through
    for (exc_class, exc) in invalid:
      self.assertRaises(exc_class, utils.IgnoreSignals, self._Raise, exc)

    # Return values are handed back unchanged
    for value in (True, 33):
      self.assertEqual(utils.IgnoreSignals(self._Return, value), value)
2290
2291
2292 class TestEnsureDirs(unittest.TestCase):
2293   """Tests for EnsureDirs"""
2294
2295   def setUp(self):
2296     self.dir = tempfile.mkdtemp()
2297     self.old_umask = os.umask(0777)
2298
2299   def testEnsureDirs(self):
2300     utils.EnsureDirs([
2301         (PathJoin(self.dir, "foo"), 0777),
2302         (PathJoin(self.dir, "bar"), 0000),
2303         ])
2304     self.assertEquals(os.stat(PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2305     self.assertEquals(os.stat(PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2306
2307   def tearDown(self):
2308     os.rmdir(PathJoin(self.dir, "foo"))
2309     os.rmdir(PathJoin(self.dir, "bar"))
2310     os.rmdir(self.dir)
2311     os.umask(self.old_umask)
2312
2313
class TestFormatSeconds(unittest.TestCase):
  """Tests for utils.FormatSeconds"""

  def _checkAll(self, cases):
    """Compare FormatSeconds output for each (seconds, text) pair.

    """
    for (seconds, expected) in cases:
      self.assertEqual(utils.FormatSeconds(seconds), expected)

  def test(self):
    self._checkAll([
      (1, "1s"),
      (3600, "1h 0m 0s"),
      (3599, "59m 59s"),
      (7200, "2h 0m 0s"),
      (7201, "2h 0m 1s"),
      (7281, "2h 1m 21s"),
      (29119, "8h 5m 19s"),
      (19431228, "224d 21h 33m 48s"),
      (-1, "-1s"),
      (-282, "-282s"),
      (-29119, "-29119s"),
      ])

  def testFloat(self):
    self._checkAll([
      (1.3, "1s"),
      (1.9, "2s"),
      (3912.12311, "1h 5m 12s"),
      (3912.8, "1h 5m 13s"),
      ])
2333
2334
class TestIgnoreProcessNotFound(unittest.TestCase):
  """Tests for utils.IgnoreProcessNotFound"""

  @staticmethod
  def _WritePid(fd):
    """Write the PID of the calling process to C{fd} and close it.

    """
    os.write(fd, str(os.getpid()))
    os.close(fd)
    return True

  def test(self):
    (read_end, write_end) = os.pipe()

    # Start short-lived process which writes its PID to pipe
    child_ok = utils.RunInSeparateProcess(self._WritePid, write_end)
    self.assert_(child_ok)
    os.close(write_end)

    # Read PID from pipe
    child_pid = int(os.read(read_end, 1024))
    os.close(read_end)

    # Try to send signal to process which exited recently
    self.assertFalse(utils.IgnoreProcessNotFound(os.kill, child_pid, 0))
2355
2356
class TestShellWriter(unittest.TestCase):
  """Tests for utils.ShellWriter"""

  def test(self):
    stream = StringIO()
    writer = utils.ShellWriter(stream)

    writer.Write("#!/bin/bash")
    writer.Write("if true; then")
    writer.IncIndent()
    try:
      writer.Write("echo true")

      writer.Write("for i in 1 2 3")
      writer.Write("do")
      writer.IncIndent()
      try:
        # Two levels of indentation are active at this point
        self.assertEqual(writer._indent, 2)
        writer.Write("date")
      finally:
        writer.DecIndent()
      writer.Write("done")
    finally:
      writer.DecIndent()

    writer.Write("echo %s", utils.ShellQuote("Hello World"))
    writer.Write("exit 0")

    self.assertEqual(writer._indent, 0)

    script = stream.getvalue()

    self.assert_(script.endswith("\n"))

    script_lines = script.splitlines()
    self.assertEqual(len(script_lines), 9)
    self.assertEqual(script_lines[0], "#!/bin/bash")
    self.assert_(re.match(r"^\s+date$", script_lines[5]))
    self.assertEqual(script_lines[7], "echo 'Hello World'")

  def testEmpty(self):
    stream = StringIO()
    writer = utils.ShellWriter(stream)
    # Dropping the writer must not cause anything to be written
    writer = None
    self.assertEqual(stream.getvalue(), "")
2398
2399
class TestCommaJoin(unittest.TestCase):
  """Tests for utils.CommaJoin"""

  def test(self):
    for (items, expected) in [
      ([], ""),
      ([1, 2, 3], "1, 2, 3"),
      (["Hello"], "Hello"),
      (["Hello", "World"], "Hello, World"),
      (["Hello", "World", 99], "Hello, World, 99"),
      ]:
      self.assertEqual(utils.CommaJoin(items), expected)
2408
2409
class TestFindMatch(unittest.TestCase):
  """Tests for utils.FindMatch"""

  def test(self):
    pattern = re.compile(r"^x(foo|bar|bazX)([0-9]+)$")
    table = {
      "aaaa": "Four A",
      "bb": {"Two B": True},
      pattern: (1, 2, 3),
      }

    # Plain string keys come with an empty list of groups
    self.assertEqual(utils.FindMatch(table, "aaaa"), ("Four A", []))
    self.assertEqual(utils.FindMatch(table, "bb"), ({"Two B": True}, []))

    # For the compiled pattern the matched groups are expected as well
    for word in ["foo", "bar", "bazX"]:
      for number in range(1, 100, 7):
        found = utils.FindMatch(table, "x%s%s" % (word, number))
        self.assertEqual(found, ((1, 2, 3), [word, str(number)]))

  def testNoMatch(self):
    for name in ("", "foo", 1234):
      self.assert_(utils.FindMatch({}, name) is None)

    table = {
      "X": "Hello World",
      re.compile("^(something)$"): "Hello World",
      }

    for name in ("", "Hello World"):
      self.assert_(utils.FindMatch(table, name) is None)
2438
2439
class TestFileID(testutils.GanetiTestCase):
  """Tests for utils.GetFileID, utils.VerifyFileID and utils.SafeWriteFile"""

  def testEquality(self):
    path = self._CreateTempFile()
    file_id = utils.GetFileID(path=path)
    self.failUnless(utils.VerifyFileID(file_id, file_id))

  def testUpdate(self):
    path = self._CreateTempFile()
    id_by_path = utils.GetFileID(path=path)
    os.utime(path, None)
    handle = os.open(path, os.O_RDWR)
    try:
      id_by_fd = utils.GetFileID(fd=handle)
      # Verification is expected to succeed in both directions
      self.failUnless(utils.VerifyFileID(id_by_path, id_by_fd))
      self.failUnless(utils.VerifyFileID(id_by_fd, id_by_path))
    finally:
      os.close(handle)

  def testWriteFile(self):
    path = self._CreateTempFile()
    file_id = utils.GetFileID(path=path)
    mtime = file_id[2]

    # Writing must be refused after moving the timestamps forward ...
    os.utime(path, (mtime + 10, mtime + 10))
    self.assertRaises(errors.LockError, utils.SafeWriteFile, path,
                      file_id, data="")

    # ... but not after moving them backward
    os.utime(path, (mtime - 10, mtime - 10))
    utils.SafeWriteFile(path, file_id, data="")

    file_id = utils.GetFileID(path=path)
    mtime = file_id[2]
    os.utime(path, (mtime + 10, mtime + 10))
    # this doesn't raise, since we passed None
    utils.SafeWriteFile(path, None, data="")

  def testError(self):
    # Passing a path together with a file descriptor is an error
    tmpfile = tempfile.NamedTemporaryFile()
    self.assertRaises(errors.ProgrammerError, utils.GetFileID,
                      path=tmpfile.name, fd=tmpfile.fileno())
2477
2478
class TimeMock:
  """Callable returning a pre-defined series of values.

  Intended as a replacement for a time function in tests: every call returns
  the next value of the series passed to the constructor.

  """
  def __init__(self, values):
    """Initializes this class.

    @param values: Values to be returned, in order; any iterable is accepted

    """
    # Keep a private copy; consuming values must not modify the sequence
    # owned by the caller (it might be shared between several mocks)
    self.values = list(values)

  def __call__(self):
    """Returns the next pre-defined value.

    @raise IndexError: If all values have been handed out already

    """
    return self.values.pop(0)
2485
2486
class TestRunningTimeout(unittest.TestCase):
  """Testing case for RunningTimeout, driven by a mocked time function"""

  def setUp(self):
    # The mock hands out these values one per call, in this order
    self.time_fn = TimeMock([0.0, 0.3, 4.6, 6.5])

  def testRemainingFloat(self):
    running = utils.RunningTimeout(5.0, True, _time_fn=self.time_fn)
    # Second argument is True: the last value is expected to be negative
    for expected in [4.7, 0.4, -1.5]:
      self.assertAlmostEqual(running.Remaining(), expected)

  def testRemaining(self):
    # Integer variant with its own series of time values
    self.time_fn = TimeMock([0, 2, 4, 5, 6])
    running = utils.RunningTimeout(5, True, _time_fn=self.time_fn)
    for expected in [3, 1, 0, -1]:
      self.assertEqual(running.Remaining(), expected)

  def testRemainingNonNegative(self):
    running = utils.RunningTimeout(5.0, False, _time_fn=self.time_fn)
    for expected in [4.7, 0.4]:
      self.assertAlmostEqual(running.Remaining(), expected)
    # Second argument is False: the result is expected to stop at zero
    self.assertEqual(running.Remaining(), 0.0)

  def testNegativeTimeout(self):
    # A negative timeout is rejected at construction time
    self.assertRaises(ValueError, utils.RunningTimeout, -1.0, True)
2513
2514
class TestTryConvert(unittest.TestCase):
  """Testing case for TryConvert"""

  def test(self):
    # (conversion function, input value, expected result); for "a" and int
    # the input value itself is the expected result
    cases = [
      (int, "1", 1),
      (int, "a", "a"),
      (bool, "", False),
      (bool, "a", True),
      ]

    for (convert, value, expected) in cases:
      self.assertEqual(utils.TryConvert(convert, value), expected)
2524
2525
class TestIsValidShellParam(unittest.TestCase):
  """Testing case for IsValidShellParam"""

  def test(self):
    # (parameter, whether it is expected to be reported as valid)
    expectations = [
      ("abc", True),
      ("ab;cd", False),
      ]

    for (param, valid) in expectations:
      self.assertEqual(utils.IsValidShellParam(param), valid)
2533
2534
class TestBuildShellCmd(unittest.TestCase):
  """Testing case for BuildShellCmd"""

  def test(self):
    template = "ls %s"

    # An argument containing ";" has to be refused
    self.assertRaises(errors.ProgrammerError, utils.BuildShellCmd,
                      template, "ab;cd")

    # A plain argument is substituted into the template
    built = utils.BuildShellCmd(template, "ab")
    self.assertEqual(built, "ls ab")
2540
2541
class TestWriteFile(unittest.TestCase):
  """Testing case for WriteFile"""

  def setUp(self):
    # File all tests write to, plus flags recording which callbacks were run
    self.tfile = tempfile.NamedTemporaryFile()
    self.did_pre = False
    self.did_post = False
    self.did_write = False

  def tearDown(self):
    # Close (and thereby remove) the temporary file right after each test
    # instead of leaving it to garbage collection of the test instance
    self.tfile.close()

  def markPre(self, fd):
    # Passed as "prewrite" callback
    self.did_pre = True

  def markPost(self, fd):
    # Passed as "postwrite" callback
    self.did_post = True

  def markWrite(self, fd):
    # Passed as "fn" callback
    self.did_write = True

  def testWrite(self):
    # Data written must be readable again through ReadFile
    data = "abc"
    utils.WriteFile(self.tfile.name, data=data)
    self.assertEqual(utils.ReadFile(self.tfile.name), data)

  def testErrors(self):
    # Both "data" and "fn" passed
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
                      self.tfile.name, data="test", fn=lambda fd: None)
    # Neither "data" nor "fn" passed
    self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name)
    # "atime" passed without "mtime"
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
                      self.tfile.name, data="test", atime=0)

  def testCalls(self):
    # All three callbacks have to be invoked by a single call
    utils.WriteFile(self.tfile.name, fn=self.markWrite,
                    prewrite=self.markPre, postwrite=self.markPost)
    self.assertTrue(self.did_pre)
    self.assertTrue(self.did_post)
    self.assertTrue(self.did_write)

  def testDryRun(self):
    # With dry_run=True the existing contents must stay untouched
    orig = "abc"
    self.tfile.write(orig)
    self.tfile.flush()
    utils.WriteFile(self.tfile.name, data="hello", dry_run=True)
    self.assertEqual(utils.ReadFile(self.tfile.name), orig)

  def testTimes(self):
    # Access and modification times passed in must show up in os.stat
    f = self.tfile.name
    for at, mt in [(0, 0), (1000, 1000), (2000, 3000),
                   (int(time.time()), 5000)]:
      utils.WriteFile(f, data="hello", atime=at, mtime=mt)
      st = os.stat(f)
      self.assertEqual(st.st_atime, at)
      self.assertEqual(st.st_mtime, mt)

  def testNoClose(self):
    data = "hello"
    # By default nothing is returned
    self.assertEqual(utils.WriteFile(self.tfile.name, data="abc"), None)
    # With close=False the return value is used as a file descriptor, which
    # this test has to close itself
    fd = utils.WriteFile(self.tfile.name, data=data, close=False)
    try:
      os.lseek(fd, 0, 0)
      self.assertEqual(os.read(fd, 4096), data)
    finally:
      os.close(fd)
2603
2604
class TestNormalizeAndValidateMac(unittest.TestCase):
  """Testing case for NormalizeAndValidateMac"""

  def testInvalid(self):
    # Something not looking like a MAC address at all is refused
    bogus = "xxx"
    self.assertRaises(errors.OpPrereqError,
                      utils.NormalizeAndValidateMac, bogus)

  def testNormalization(self):
    # Valid addresses are expected back in lower case
    addresses = ["aa:bb:cc:dd:ee:ff", "00:AA:11:bB:22:cc"]
    for address in addresses:
      normalized = utils.NormalizeAndValidateMac(address)
      self.assertEqual(normalized, address.lower())
2613
2614
# Run the tests through Ganeti's test program when executed as a script
if __name__ == "__main__":
  testutils.GanetiTestProgram()