IPv6 support for utils.TcpPing()
[ganeti-local] / test / ganeti.utils_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the utils module"""
23
24 import unittest
25 import os
26 import time
27 import tempfile
28 import os.path
29 import os
30 import stat
31 import signal
32 import socket
33 import shutil
34 import re
35 import select
36 import string
37 import fcntl
38 import OpenSSL
39 import warnings
40 import distutils.version
41 import glob
42 import errno
43
44 import ganeti
45 import testutils
46 from ganeti import constants
47 from ganeti import compat
48 from ganeti import utils
49 from ganeti import errors
50 from ganeti import serializer
51 from ganeti.utils import IsProcessAlive, RunCmd, \
52      RemoveFile, MatchNameComponent, FormatUnit, \
53      ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
54      ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
55      SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
56      TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
57      UnescapeAndSplit, RunParts, PathJoin, HostInfo, ReadOneLineFile
58
59 from ganeti.errors import LockError, UnitParseError, GenericError, \
60      ProgrammerError, OpPrereqError
61
62
class TestIsProcessAlive(unittest.TestCase):
  """Testing case for IsProcessAlive"""

  def testExists(self):
    """The process running the test must be reported as alive"""
    mypid = os.getpid()
    self.assert_(IsProcessAlive(mypid),
                 "can't find myself running")

  def testNotExisting(self):
    """A child that exited and was reaped must not be reported as alive"""
    # os.fork() signals failure by raising OSError; it never returns a
    # negative value, so no explicit error branch is needed
    pid_non_existing = os.fork()
    if pid_non_existing == 0:
      # Child: leave immediately, bypassing the unittest machinery
      os._exit(0)
    # Parent: reap the child so that its PID is really gone
    os.waitpid(pid_non_existing, 0)
    self.assertFalse(IsProcessAlive(pid_non_existing),
                     "nonexisting process detected")
80
81
class TestGetProcStatusPath(unittest.TestCase):
  """Sanity checks for utils._GetProcStatusPath"""

  def test(self):
    # The PID has to show up as a component of the returned path
    path_for_pid = utils._GetProcStatusPath(1234)
    self.assert_("/1234/" in path_for_pid)

    # Two different PIDs must not share a path
    first = utils._GetProcStatusPath(1)
    second = utils._GetProcStatusPath(2)
    self.assertNotEqual(first, second)
87
88
class TestIsProcessHandlingSignal(unittest.TestCase):
  """Tests for utils.IsProcessHandlingSignal and its parsing helpers"""

  def setUp(self):
    # Scratch directory for fake "status" files
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testParseSigsetT(self):
    """Hexadecimal signal masks are decoded into sets of signal numbers"""
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
                     set([2, 10, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
                     set([2, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
                     set([2, 28, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
                          24, 25, 26, 28, 31]))
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))

  def testGetProcStatusField(self):
    """A field's value is returned with surrounding whitespace stripped"""
    for field in ["SigCgt", "Name", "FDSize"]:
      for value in ["", "0", "cat", "  1234 KB"]:
        pstatus = "\n".join([
          "VmPeak: 999 kB",
          "%s: %s" % (field, value),
          "TracerPid: 0",
          ])
        result = utils._GetProcStatusField(pstatus, field)
        self.assertEqual(result, value.strip())

  def test(self):
    """Signal 10 is reported as handled for a status file whose SigCgt
    mask contains it"""
    sp = utils.PathJoin(self.tmpdir, "status")

    utils.WriteFile(sp, data="\n".join([
      "Name:   bash",
      "State:  S (sleeping)",
      "SleepAVG:       98%",
      "Pid:    22250",
      "PPid:   10858",
      "TracerPid:      0",
      "SigBlk: 0000000000010000",
      "SigIgn: 0000000000384004",
      "SigCgt: 000000004b813efb",
      "CapEff: 0000000000000000",
      ]))

    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))

  def testNoSigCgt(self):
    """A status file without a SigCgt line makes the check raise
    RuntimeError"""
    sp = utils.PathJoin(self.tmpdir, "status")

    utils.WriteFile(sp, data="\n".join([
      "Name:   bash",
      ]))

    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
                      1234, 10, status_path=sp)

  def testNoSuchFile(self):
    """A missing status file yields a false result instead of an error"""
    sp = utils.PathJoin(self.tmpdir, "notexist")

    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))

  @staticmethod
  def _TestRealProcess():
    """Cycle SIGUSR1 through default, handled, ignored and default again,
    checking what IsProcessHandlingSignal reports for the current process.

    Meant to be run in a separate process (see testRealProcess) since it
    changes the process-wide SIGUSR1 disposition.

    @return: True if all checks passed
    @raise Exception: if a check reports the wrong state

    """
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is not handled when it should be")

    # An ignored signal must not count as handled; the message used to be
    # a copy of the "not handled" one and described the opposite failure
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    return True

  def testRealProcess(self):
    """Run the real-process checks via utils.RunInSeparateProcess"""
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
178
179
class TestPidFileFunctions(unittest.TestCase):
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""

  def setUp(self):
    self.dir = tempfile.mkdtemp()
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
    # Redirect PID files into the scratch directory; the original function
    # is remembered so that tearDown can undo the monkeypatch
    self._orig_dpn = utils.DaemonPidFileName
    utils.DaemonPidFileName = self.f_dpn

  def testPidFileFunctions(self):
    """Write, read back, reject a duplicate and remove a PID file"""
    pid_file = self.f_dpn('test')
    utils.WritePidFile('test')
    self.failUnless(os.path.exists(pid_file),
                    "PID file should have been created")
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, os.getpid())
    self.failUnless(utils.IsProcessAlive(read_pid))
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for missing pid file")
    fh = open(pid_file, "w")
    try:
      fh.write("blah\n")
    finally:
      fh.close()
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for invalid pid file")
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")

  def testKill(self):
    """Fork a child that writes its PID file, then kill it"""
    pid_file = self.f_dpn('child')
    r_fd, w_fd = os.pipe()
    new_pid = os.fork()
    if new_pid == 0: #child
      utils.WritePidFile('child')
      os.write(w_fd, 'a')
      signal.pause()
      os._exit(0)
    # else we are in the parent
    # The parent never writes to the pipe; dropping its copy of the write
    # end also means the read below cannot block once the child is gone
    os.close(w_fd)
    try:
      # wait until the child has written the pid file
      os.read(r_fd, 1)
    finally:
      os.close(r_fd)
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, new_pid)
    self.failUnless(utils.IsProcessAlive(new_pid))
    utils.KillProcess(new_pid, waitpid=True)
    self.failIf(utils.IsProcessAlive(new_pid))
    utils.RemovePidFile('child')
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)

  def tearDown(self):
    # Undo the monkeypatch first so that it is restored even if the
    # directory cleanup below fails
    utils.DaemonPidFileName = self._orig_dpn
    for name in os.listdir(self.dir):
      os.unlink(os.path.join(self.dir, name))
    os.rmdir(self.dir)
236
237
class TestRunCmd(testutils.GanetiTestCase):
  """Testing case for the RunCmd function"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    # Marker string the commands echo back
    self.magic = time.ctime() + " ganeti test"
    self.fname = self._CreateTempFile()

  def testOk(self):
    """Test successful exit code"""
    result = RunCmd("/bin/sh -c 'exit 0'")
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.output, "")

  def testFail(self):
    """Test fail exit code"""
    result = RunCmd("/bin/sh -c 'exit 1'")
    self.assertEqual(result.exit_code, 1)
    self.assertEqual(result.output, "")

  def testStdout(self):
    """Test standard output"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.stdout, self.magic)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testStderr(self):
    """Test standard error"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
    self.assertEqual(result.stderr, self.magic)
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testCombined(self):
    """Test combined output"""
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
    expected = "A" + self.magic + "B" + self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.output, expected)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, expected)

  def testSignal(self):
    """Test signal"""
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
    self.assertEqual(result.signal, 15)
    self.assertEqual(result.output, "")

  def testListRun(self):
    """Test list runs"""
    result = RunCmd(["true"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 1)
    result = RunCmd(["echo", "-n", self.magic])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.stdout, self.magic)

  def testFileEmptyOutput(self):
    """Test file output"""
    result = RunCmd(["true"], output=self.fname)
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertFileContent(self.fname, "")

  def testLang(self):
    """Test locale environment"""
    overrides = {
      "LANG": "en_US.UTF-8",
      "LC_ALL": "en_US.UTF-8",
      }
    # Remember only the variables touched here (None means "was unset").
    # Rebinding os.environ to a plain dict copy would lose the mapping that
    # keeps the real process environment in sync and would leave the
    # overrides set for every later test.
    saved = dict([(key, os.environ.get(key)) for key in overrides])
    try:
      for key, value in overrides.items():
        os.environ[key] = value
      result = RunCmd(["locale"])
      for line in result.output.splitlines():
        key, value = line.split("=", 1)
        # Ignore these variables, they're overridden by LC_ALL
        if key == "LANG" or key == "LANGUAGE":
          continue
        self.failIf(value and value != "C" and value != '"C"',
            "Variable %s is set to the invalid value '%s'" % (key, value))
    finally:
      for key, value in saved.items():
        if value is not None:
          os.environ[key] = value
        elif key in os.environ:
          del os.environ[key]

  def testDefaultCwd(self):
    """Test default working directory"""
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")

  def testCwd(self):
    """Test explicitly given working directory"""
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
    cwd = os.getcwd()
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)

  def testResetEnv(self):
    """Test environment reset functionality"""
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
345
346
class TestRunParts(unittest.TestCase):
  """Testing case for the RunParts function"""

  def setUp(self):
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")

  def tearDown(self):
    shutil.rmtree(self.rundir)

  def _AddPart(self, name, data="", executable=False):
    """Create a file in the run directory and return its full path"""
    path = os.path.join(self.rundir, name)
    utils.WriteFile(path, data=data)
    if executable:
      os.chmod(path, stat.S_IREAD | stat.S_IEXEC)
    return path

  def _RunAll(self):
    """Invoke RunParts on the run directory with a reset environment"""
    return RunParts(self.rundir, reset_env=True)

  def testEmpty(self):
    """Test on an empty dir"""
    self.assertEqual(self._RunAll(), [])

  def testSkipWrongName(self):
    """Test that wrong files are skipped"""
    path = self._AddPart("00test.dot", executable=True)
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.assertEqual(self._RunAll(), expected)

  def testSkipNonExec(self):
    """Test that non executable files are skipped"""
    path = self._AddPart("00test")
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.assertEqual(self._RunAll(), expected)

  def testError(self):
    """Test error on a broken executable"""
    path = self._AddPart("00test", executable=True)
    first = self._RunAll()[0]
    (relname, status, error) = first
    self.assertEqual(relname, os.path.basename(path))
    self.assertEqual(status, constants.RUNPARTS_ERR)
    self.assertTrue(error)

  def testSorted(self):
    """Test executions are sorted"""
    # Deliberately created out of order
    paths = [self._AddPart(name) for name in ["64test", "00test", "42test"]]

    results = self._RunAll()

    for idx, path in enumerate(sorted(paths)):
      self.assertEqual(os.path.basename(path), results[idx][0])

  def testOk(self):
    """Test correct execution"""
    path = self._AddPart("00test", data="#!/bin/sh\n\necho -n ciao",
                         executable=True)
    (relname, status, runresult) = self._RunAll()[0]
    self.assertEqual(relname, os.path.basename(path))
    self.assertEqual(status, constants.RUNPARTS_RUN)
    self.assertEqual(runresult.stdout, "ciao")

  def testRunFail(self):
    """Test correct execution, with run failure"""
    path = self._AddPart("00test", data="#!/bin/sh\n\nexit 1",
                         executable=True)
    (relname, status, runresult) = self._RunAll()[0]
    self.assertEqual(relname, os.path.basename(path))
    self.assertEqual(status, constants.RUNPARTS_RUN)
    self.assertEqual(runresult.exit_code, 1)
    self.assertTrue(runresult.failed)

  def testRunMix(self):
    """Test a directory mixing failing, skipped, broken and working parts"""
    # 1st has errors in execution
    failing = self._AddPart("00test", data="#!/bin/sh\n\nexit 1",
                            executable=True)
    # 2nd is skipped
    skipped = self._AddPart("42test")
    # 3rd cannot execute properly
    broken = self._AddPart("64test", executable=True)
    # 4th execs
    working = self._AddPart("99test", data="#!/bin/sh\n\necho -n ciao",
                            executable=True)

    results = self._RunAll()

    (relname, status, runresult) = results[0]
    self.assertEqual(relname, os.path.basename(failing))
    self.assertEqual(status, constants.RUNPARTS_RUN)
    self.assertEqual(runresult.exit_code, 1)
    self.assertTrue(runresult.failed)

    (relname, status, runresult) = results[1]
    self.assertEqual(relname, os.path.basename(skipped))
    self.assertEqual(status, constants.RUNPARTS_SKIP)
    self.assertEqual(runresult, None)

    (relname, status, runresult) = results[2]
    self.assertEqual(relname, os.path.basename(broken))
    self.assertEqual(status, constants.RUNPARTS_ERR)
    self.assertTrue(runresult)

    (relname, status, runresult) = results[3]
    self.assertEqual(relname, os.path.basename(working))
    self.assertEqual(status, constants.RUNPARTS_RUN)
    self.assertEqual(runresult.output, "ciao")
    self.assertEqual(runresult.exit_code, 0)
    self.assertTrue(not runresult.failed)
471
472
class TestStartDaemon(testutils.GanetiTestCase):
  """Tests for utils.StartDaemon"""

  def setUp(self):
    # Chain to the base class like the other GanetiTestCase users in this
    # file (e.g. TestRunCmd), so its per-test state is initialised
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
    self.tmpfile = os.path.join(self.tmpdir, "test")

  def tearDown(self):
    try:
      shutil.rmtree(self.tmpdir)
    finally:
      testutils.GanetiTestCase.tearDown(self)

  def testShell(self):
    """Shell command redirecting its own output to a file"""
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testShellOutput(self):
    """Shell command with output captured via the output argument"""
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testNoShellNoOutput(self):
    """Argument-list command without any output file"""
    utils.StartDaemon(["pwd"])

  def testNoShellNoOutputTouch(self):
    """Argument-list command whose only effect is creating a file"""
    testfile = os.path.join(self.tmpdir, "check")
    self.failIf(os.path.exists(testfile))
    utils.StartDaemon(["touch", testfile])
    self._wait(testfile, 60.0, "")

  def testNoShellOutput(self):
    """Without a cwd argument "pwd" is expected to print "/" """
    utils.StartDaemon(["pwd"], output=self.tmpfile)
    self._wait(self.tmpfile, 60.0, "/")

  def testNoShellOutputCwd(self):
    """An explicit cwd argument is honoured"""
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
    self._wait(self.tmpfile, 60.0, os.getcwd())

  def testShellEnv(self):
    """Environment variables are visible to a shell command"""
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
                      env={ "GNT_TEST_VAR": "Hello World", })
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testNoShellEnv(self):
    """Environment variables are visible to an argument-list command"""
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
                      env={ "GNT_TEST_VAR": "Hello World", })
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testOutputFd(self):
    """Output can be directed to an already opened file descriptor"""
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
    try:
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
    finally:
      os.close(fd)
    self._wait(self.tmpfile, 60.0, os.getcwd())

  def testPid(self):
    """The returned PID matches the one the shell reports as $$"""
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
    self._wait(self.tmpfile, 60.0, str(pid))

  def testPidFile(self):
    """The PID file is locked and contains the daemon's PID"""
    pidfile = os.path.join(self.tmpdir, "pid")

    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
                            output=self.tmpfile)
    try:
      fd = os.open(pidfile, os.O_RDONLY)
      try:
        # Check file is locked
        self.assertRaises(errors.LockError, utils.LockFile, fd)

        pidtext = os.read(fd, 100)
      finally:
        os.close(fd)

      self.assertEqual(int(pidtext.strip()), pid)

      self.assert_(utils.IsProcessAlive(pid))
    finally:
      # No matter what happens, kill daemon
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
      self.failIf(utils.IsProcessAlive(pid))

    self.assertEqual(utils.ReadFile(self.tmpfile), "")

  def _wait(self, path, timeout, expected):
    """Poll until a file exists and has the expected stripped content

    @param path: file to watch
    @param timeout: maximum number of seconds to wait
    @param expected: expected file content after stripping whitespace

    """
    # Due to the asynchronous nature of daemon processes, polling is necessary.
    # A timeout makes sure the test doesn't hang forever.
    def _CheckFile():
      if not (os.path.isfile(path) and
              utils.ReadFile(path).strip() == expected):
        raise utils.RetryAgain()

    try:
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
    except utils.RetryTimeout:
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
                " didn't write the correct output" % timeout)

  def testError(self):
    """Invalid program, output path, cwd or argument combination"""
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"])
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"],
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"],
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))

    # Passing both an output file and an output descriptor is rejected
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
    try:
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
                        ["./does-NOT-EXIST/here/0123456789"],
                        output=self.tmpfile, output_fd=fd)
    finally:
      os.close(fd)
588
589
class TestSetCloseOnExecFlag(unittest.TestCase):
  """Tests for SetCloseOnExecFlag"""

  def setUp(self):
    self.tmpfile = tempfile.TemporaryFile()

  def tearDown(self):
    # Release the descriptor explicitly instead of relying on garbage
    # collection
    self.tmpfile.close()

  def testEnable(self):
    """Enabling the flag sets FD_CLOEXEC on the descriptor"""
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
                    fcntl.FD_CLOEXEC)

  def testDisable(self):
    """Disabling the flag clears FD_CLOEXEC on the descriptor"""
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
                fcntl.FD_CLOEXEC)
605
606
class TestSetNonblockFlag(unittest.TestCase):
  """Tests for SetNonblockFlag"""

  def setUp(self):
    self.tmpfile = tempfile.TemporaryFile()

  def tearDown(self):
    # Release the descriptor explicitly instead of relying on garbage
    # collection
    self.tmpfile.close()

  def testEnable(self):
    """Enabling the flag sets O_NONBLOCK on the descriptor"""
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
                    os.O_NONBLOCK)

  def testDisable(self):
    """Disabling the flag clears O_NONBLOCK on the descriptor"""
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
                os.O_NONBLOCK)
620
621
class TestRemoveFile(unittest.TestCase):
  """Test case for the RemoveFile function"""

  def setUp(self):
    """Create a temp dir and file for each case"""
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
    os.close(fd)

  def tearDown(self):
    # rmtree also removes whatever a failed test left behind (for example
    # a symlink), which would make a plain rmdir fail
    shutil.rmtree(self.tmpdir)

  def testIgnoreDirs(self):
    """Test that RemoveFile() ignores directories"""
    self.assertEqual(None, RemoveFile(self.tmpdir))

  def testIgnoreNotExisting(self):
    """Test that RemoveFile() ignores non-existing files"""
    RemoveFile(self.tmpfile)
    RemoveFile(self.tmpfile)

  def testRemoveFile(self):
    """Test that RemoveFile does remove a file"""
    RemoveFile(self.tmpfile)
    if os.path.exists(self.tmpfile):
      self.fail("File '%s' not removed" % self.tmpfile)

  def testRemoveSymlink(self):
    """Test that RemoveFile does remove symlinks"""
    symlink = self.tmpdir + "/symlink"

    # Dangling symlink: os.path.exists() follows the link and is always
    # False here, so lexists() is needed to look at the link itself
    os.symlink("no-such-file", symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)

    # Symlink pointing to an existing file
    os.symlink(self.tmpfile, symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
662
663
class TestRename(unittest.TestCase):
  """Test case for RenameFile"""

  def setUp(self):
    """Create a temporary directory"""
    self.tmpdir = tempfile.mkdtemp()
    self.tmpfile = os.path.join(self.tmpdir, "test1")

    # Touch the file
    open(self.tmpfile, "w").close()

  def tearDown(self):
    """Remove temporary directory"""
    shutil.rmtree(self.tmpdir)

  def _InTmp(self, relpath):
    """Return the given relative path placed inside the scratch directory"""
    return os.path.join(self.tmpdir, relpath)

  def testSimpleRename1(self):
    """Simple rename 1"""
    target = self._InTmp("xyz")
    utils.RenameFile(self.tmpfile, target)
    self.assertTrue(os.path.isfile(self._InTmp("xyz")))

  def testSimpleRename2(self):
    """Simple rename 2"""
    target = self._InTmp("xyz")
    utils.RenameFile(self.tmpfile, target, mkdir=True)
    self.assertTrue(os.path.isfile(self._InTmp("xyz")))

  def testRenameMkdir(self):
    """Rename with mkdir"""
    # First move: one missing directory level
    utils.RenameFile(self.tmpfile, self._InTmp("test/xyz"), mkdir=True)
    self.assertTrue(os.path.isdir(self._InTmp("test")))
    self.assertTrue(os.path.isfile(self._InTmp("test/xyz")))

    # Second move: several missing directory levels
    source = self._InTmp("test/xyz")
    target = self._InTmp("test/foo/bar/baz")
    utils.RenameFile(source, target, mkdir=True)
    self.assertTrue(os.path.isdir(self._InTmp("test")))
    self.assertTrue(os.path.isdir(self._InTmp("test/foo/bar")))
    self.assertTrue(os.path.isfile(self._InTmp("test/foo/bar/baz")))
703
704
class TestMatchNameComponent(unittest.TestCase):
  """Test case for the MatchNameComponent function"""

  def testEmptyList(self):
    """Test that there is no match against an empty list"""
    for key in ["", "test"]:
      self.assertEqual(MatchNameComponent(key, []), None)

  def testSingleMatch(self):
    """Test that a single match is performed correctly"""
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
    wanted = mlist[1]
    for key in ["test2", "test2.example", "test2.example.com"]:
      self.assertEqual(MatchNameComponent(key, mlist), wanted)

  def testMultipleMatches(self):
    """Test that a multiple match is returned as None"""
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
    for key in ["test1", "test1.example"]:
      self.assertEqual(MatchNameComponent(key, mlist), None)

  def testFullMatch(self):
    """Test that a full match is returned correctly"""
    short_key = "test1"
    long_key = "test1.example"
    mlist = [long_key, long_key + ".com"]
    self.assertEqual(MatchNameComponent(short_key, mlist), None)
    self.assertEqual(MatchNameComponent(long_key, mlist), long_key)

  def testCaseInsensitivePartialMatch(self):
    """Test for the case_insensitive keyword"""
    mlist = ["test1.example.com", "test2.example.net"]
    for key in ["test2", "Test2", "teSt2", "TeSt2"]:
      found = MatchNameComponent(key, mlist, case_sensitive=False)
      self.assertEqual(found, "test2.example.net")

  def testCaseInsensitiveFullMatch(self):
    """Test full-name matches with case_sensitive=False"""
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
    cases = [
      # Between the two ts1 a full string match non-case insensitive should
      # work
      ("Ts1", None),
      ("Ts1.ex", "ts1.ex"),
      ("ts1.ex", "ts1.ex"),
      # Between the two ts2 only case differs, so only case-match works
      ("ts2.ex", "ts2.ex"),
      ("Ts2.ex", "Ts2.ex"),
      ("TS2.ex", None),
      ]
    for (key, wanted) in cases:
      found = MatchNameComponent(key, mlist, case_sensitive=False)
      self.assertEqual(found, wanted)
763
764
class TestReadFile(testutils.GanetiTestCase):
  """Tests for utils.ReadFile"""

  def _AssertMd5(self, data, hexdigest):
    """Check that the MD5 digest of data equals the given hex string"""
    digest = compat.md5_hash()
    digest.update(data)
    self.assertEqual(digest.hexdigest(), hexdigest)

  def testReadAll(self):
    """Reading without a size returns the whole test certificate"""
    filename = self._TestDataFilename("cert1.pem")
    data = utils.ReadFile(filename)
    self.assertEqual(len(data), 814)
    self._AssertMd5(data, "a491efb3efe56a0535f924d5f8680fd4")

  def testReadSize(self):
    """Reading with size=100 returns exactly 100 bytes"""
    filename = self._TestDataFilename("cert1.pem")
    data = utils.ReadFile(filename, size=100)
    self.assertEqual(len(data), 100)
    self._AssertMd5(data, "893772354e4e690b9efd073eed433ce7")

  def testError(self):
    """An unreadable path raises EnvironmentError"""
    bad_path = "/dev/null/does-not-exist"
    self.assertRaises(EnvironmentError, utils.ReadFile, bad_path)
787
788
789 class TestReadOneLineFile(testutils.GanetiTestCase):
790
  def setUp(self):
    """Run the base class setup; this class adds no fixtures of its own"""
    testutils.GanetiTestCase.setUp(self)
793
794   def testDefault(self):
795     data = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
796     self.assertEqual(len(data), 27)
797     self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
798
799   def testNotStrict(self):
800     data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
801     self.assertEqual(len(data), 27)
802     self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
803
804   def testStrictFailure(self):
805     self.assertRaises(errors.GenericError, ReadOneLineFile,
806                       self._TestDataFilename("cert1.pem"), strict=True)
807
808   def testLongLine(self):
809     dummydata = (1024 * "Hello World! ")
810     myfile = self._CreateTempFile()
811     utils.WriteFile(myfile, data=dummydata)
812     datastrict = ReadOneLineFile(myfile, strict=True)
813     datalax = ReadOneLineFile(myfile, strict=False)
814     self.assertEqual(dummydata, datastrict)
815     self.assertEqual(dummydata, datalax)
816
817   def testNewline(self):
818     myfile = self._CreateTempFile()
819     myline = "myline"
820     for nl in ["", "\n", "\r\n"]:
821       dummydata = "%s%s" % (myline, nl)
822       utils.WriteFile(myfile, data=dummydata)
823       datalax = ReadOneLineFile(myfile, strict=False)
824       self.assertEqual(myline, datalax)
825       datastrict = ReadOneLineFile(myfile, strict=True)
826       self.assertEqual(myline, datastrict)
827
828   def testWhitespaceAndMultipleLines(self):
829     myfile = self._CreateTempFile()
830     for nl in ["", "\n", "\r\n"]:
831       for ws in [" ", "\t", "\t\t  \t", "\t "]:
832         dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
833         utils.WriteFile(myfile, data=dummydata)
834         datalax = ReadOneLineFile(myfile, strict=False)
835         if nl:
836           self.assert_(set("\r\n") & set(dummydata))
837           self.assertRaises(errors.GenericError, ReadOneLineFile,
838                             myfile, strict=True)
839           explen = len("Foo bar baz ") + len(ws)
840           self.assertEqual(len(datalax), explen)
841           self.assertEqual(datalax, dummydata[:explen])
842           self.assertFalse(set("\r\n") & set(datalax))
843         else:
844           datastrict = ReadOneLineFile(myfile, strict=True)
845           self.assertEqual(dummydata, datastrict)
846           self.assertEqual(dummydata, datalax)
847
848   def testEmptylines(self):
849     myfile = self._CreateTempFile()
850     myline = "myline"
851     for nl in ["\n", "\r\n"]:
852       for ol in ["", "otherline"]:
853         dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
854         utils.WriteFile(myfile, data=dummydata)
855         self.assert_(set("\r\n") & set(dummydata))
856         datalax = ReadOneLineFile(myfile, strict=False)
857         self.assertEqual(myline, datalax)
858         if ol:
859           self.assertRaises(errors.GenericError, ReadOneLineFile,
860                             myfile, strict=True)
861         else:
862           datastrict = ReadOneLineFile(myfile, strict=True)
863           self.assertEqual(myline, datastrict)
864
865
class TestTimestampForFilename(unittest.TestCase):
  """Test case for utils.TimestampForFilename"""

  def test(self):
    # Neither dots nor colons may show up in the generated timestamp
    for char in (".", ":"):
      self.assert_(char not in utils.TimestampForFilename())
870
871
class TestCreateBackup(testutils.GanetiTestCase):
  """Test case for utils.CreateBackup, working in a scratch directory"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    # Scratch directory, removed again in tearDown
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  @staticmethod
  def _CountWithPrefix(prefix):
    """Returns the number of paths starting with C{prefix}.

    """
    return len(glob.glob("%s*" % prefix))

  def testEmpty(self):
    path = utils.PathJoin(self.tmpdir, "config.data")
    utils.WriteFile(path, data="")

    backup = utils.CreateBackup(path)
    self.assertFileContent(backup, "")
    self.assertEqual(self._CountWithPrefix(path), 2)

    # Every further backup has to add exactly one more file
    for expected in (3, 4):
      utils.CreateBackup(path)
      self.assertEqual(self._CountWithPrefix(path), expected)

    # Backing up a FIFO has to be refused
    fifo = utils.PathJoin(self.tmpdir, "fifo")
    os.mkfifo(fifo)
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifo)

  def testContent(self):
    payloads = ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]
    # Backups pile up over all iterations as the same file name is reused
    backups = 0

    for payload in payloads:
      for factor in [1, 2, 10, 127]:
        content = payload * factor

        path = utils.PathJoin(self.tmpdir, "test.data_")
        utils.WriteFile(path, data=content)
        self.assertFileContent(path, content)

        for _ in range(3):
          backup = utils.CreateBackup(path)
          backups += 1
          self.assertFileContent(backup, content)
          self.assertEqual(self._CountWithPrefix(path), 1 + backups)
913
914
class TestFormatUnit(unittest.TestCase):
  """Test case for the FormatUnit function"""

  def _Check(self, units, cases):
    """Asserts FormatUnit's output for a list of (value, expected) pairs.

    """
    for (value, expected) in cases:
      self.assertEqual(FormatUnit(value, units), expected)

  def testMiB(self):
    self._Check('h', [
      (1, '1M'),
      (100, '100M'),
      (1023, '1023M'),
      ])

    self._Check('m', [
      (1, '1'),
      (100, '100'),
      (1023, '1023'),
      (1024, '1024'),
      (1536, '1536'),
      (17133, '17133'),
      (1024 * 1024 - 1, '1048575'),
      ])

  def testGiB(self):
    self._Check('h', [
      (1024, '1.0G'),
      (1536, '1.5G'),
      (17133, '16.7G'),
      (1024 * 1024 - 1, '1024.0G'),
      ])

    self._Check('g', [
      (1024, '1.0'),
      (1536, '1.5'),
      (17133, '16.7'),
      (1024 * 1024 - 1, '1024.0'),
      (1024 * 1024, '1024.0'),
      (5120 * 1024, '5120.0'),
      (29829 * 1024, '29829.0'),
      ])

  def testTiB(self):
    self._Check('h', [
      (1024 * 1024, '1.0T'),
      (5120 * 1024, '5.0T'),
      (29829 * 1024, '29.1T'),
      ])

    self._Check('t', [
      (1024 * 1024, '1.0'),
      (5120 * 1024, '5.0'),
      (29829 * 1024, '29.1'),
      ])
955
class TestParseUnit(unittest.TestCase):
  """Test case for the ParseUnit function"""

  # (suffix, multiplier) pairs; the multiplier is what the tests below
  # expect a value carrying that suffix to be scaled by
  SCALES = (('', 1),
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))

  def _Check(self, cases):
    """Asserts ParseUnit's result for a list of (text, expected) pairs.

    """
    for (text, expected) in cases:
      self.assertEqual(ParseUnit(text), expected)

  def testRounding(self):
    # Results are expected to be rounded up to a multiple of four
    self._Check([
      ('0', 0),
      ('1', 4),
      ('2', 4),
      ('3', 4),
      ('124', 124),
      ('125', 128),
      ('126', 128),
      ('127', 128),
      ('128', 128),
      ('129', 132),
      ('130', 132),
      ])

  def testFloating(self):
    self._Check([
      ('0', 0),
      ('0.5', 4),
      ('1.75', 4),
      ('1.99', 4),
      ('2.00', 4),
      ('2.01', 4),
      ('3.99', 4),
      ('4.00', 4),
      ('4.01', 8),
      ('1.5G', 1536),
      ('1.8G', 1844),
      ('8.28T', 8682212),
      ])

  def testSuffixes(self):
    # Suffixes have to be accepted as written, lower-cased and upper-cased
    spellings = (lambda x: x, str.lower, str.upper)
    for separator in ('', ' ', '   ', "\t", "\t "):
      for (suffix, scale) in self.SCALES:
        for spell in spellings:
          text = '1024' + separator + spell(suffix)
          self.assertEqual(ParseUnit(text), 1024 * scale)

  def testInvalidInput(self):
    for separator in ('-', '_', ',', 'a'):
      for (suffix, _) in self.SCALES:
        self.assertRaises(UnitParseError, ParseUnit,
                          '1' + separator + suffix)

    # A comma is not accepted as decimal separator
    for (suffix, _) in self.SCALES:
      self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
1006
1007
class TestSshKeys(testutils.GanetiTestCase):
  """Test case for the AddAuthorizedKey and RemoveAuthorizedKey functions"""

  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    # Every test starts with a key file holding KEY_A and KEY_B
    self.tmpname = self._CreateTempFile()
    keyfile = open(self.tmpname, 'w')
    try:
      for key in (self.KEY_A, self.KEY_B):
        keyfile.write("%s\n" % key)
    finally:
      keyfile.close()

  def _AssertKeys(self, *keys):
    """Asserts that the key file holds exactly C{keys}, one per line.

    """
    expected = "".join(["%s\n" % key for key in keys])
    self.assertFileContent(self.tmpname, expected)

  def testAddingNewKey(self):
    newkey = 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test'
    AddAuthorizedKey(self.tmpname, newkey)

    self._AssertKeys(self.KEY_A, self.KEY_B,
                     "ssh-dss AAAAB3NzaC1kc3MAAACB root@test")

  def testAddingAlmostButNotCompletelyTheSameKey(self):
    # Same key data as KEY_A, but a different comment
    newkey = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test'
    AddAuthorizedKey(self.tmpname, newkey)

    self._AssertKeys(self.KEY_A, self.KEY_B,
                     "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test")

  def testAddingExistingKeyWithSomeMoreSpaces(self):
    # KEY_A with extra whitespace between its fields; file must not change
    spaced = 'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a'
    AddAuthorizedKey(self.tmpname, spaced)

    self._AssertKeys(self.KEY_A, self.KEY_B)

  def testRemovingExistingKeyWithSomeMoreSpaces(self):
    spaced = 'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a'
    RemoveAuthorizedKey(self.tmpname, spaced)

    self._AssertKeys(self.KEY_B)

  def testRemovingNonExistingKey(self):
    unknown = 'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test'
    RemoveAuthorizedKey(self.tmpname, unknown)

    self._AssertKeys(self.KEY_A, self.KEY_B)
1069
1070
class TestEtcHosts(testutils.GanetiTestCase):
  """Test functions modifying /etc/hosts"""

  # rw-r--r-- (octal 644), the mode the file must have after every change
  _HOSTS_MODE = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    hosts = open(self.tmpname, 'w')
    try:
      for line in ('# This is a test file for /etc/hosts\n',
                   '127.0.0.1\tlocalhost\n',
                   '192.168.1.1 router gw\n'):
        hosts.write(line)
    finally:
      hosts.close()

  def _AssertHosts(self, expected):
    """Asserts content and permissions of the temporary hosts file.

    """
    self.assertFileContent(self.tmpname, expected)
    self.assertFileMode(self.tmpname, self._HOSTS_MODE)

  def testSettingNewIp(self):
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])

    self._AssertHosts(
      "# This is a test file for /etc/hosts\n"
      "127.0.0.1\tlocalhost\n"
      "192.168.1.1 router gw\n"
      "1.2.3.4\tmyhost.domain.tld myhost\n")

  def testSettingExistingIp(self):
    # The existing entry for the address is replaced
    SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
                     ['myhost'])

    self._AssertHosts(
      "# This is a test file for /etc/hosts\n"
      "127.0.0.1\tlocalhost\n"
      "192.168.1.1\tmyhost.domain.tld myhost\n")

  def testSettingDuplicateName(self):
    # An alias equal to the host name must not be written twice
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])

    self._AssertHosts(
      "# This is a test file for /etc/hosts\n"
      "127.0.0.1\tlocalhost\n"
      "192.168.1.1 router gw\n"
      "1.2.3.4\tmyhost\n")

  def testRemovingExistingHost(self):
    RemoveEtcHostsEntry(self.tmpname, 'router')

    self._AssertHosts(
      "# This is a test file for /etc/hosts\n"
      "127.0.0.1\tlocalhost\n"
      "192.168.1.1 gw\n")

  def testRemovingSingleExistingHost(self):
    # Removing the only name of an address drops the whole line
    RemoveEtcHostsEntry(self.tmpname, 'localhost')

    self._AssertHosts(
      "# This is a test file for /etc/hosts\n"
      "192.168.1.1 router gw\n")

  def testRemovingNonExistingHost(self):
    RemoveEtcHostsEntry(self.tmpname, 'myhost')

    self._AssertHosts(
      "# This is a test file for /etc/hosts\n"
      "127.0.0.1\tlocalhost\n"
      "192.168.1.1 router gw\n")

  def testRemovingAlias(self):
    RemoveEtcHostsEntry(self.tmpname, 'gw')

    self._AssertHosts(
      "# This is a test file for /etc/hosts\n"
      "127.0.0.1\tlocalhost\n"
      "192.168.1.1 router\n")
1149
1150
class TestShellQuoting(unittest.TestCase):
  """Test case for shell quoting functions"""

  def testShellQuote(self):
    # (input, expected quoted form)
    cases = [
      ('abc', "abc"),
      ('ab"c', "'ab\"c'"),
      ("a'bc", "'a'\\''bc'"),
      ("a b c", "'a b c'"),
      ("a b\\ c", "'a b\\ c'"),
      ]
    for (text, quoted) in cases:
      self.assertEqual(ShellQuote(text), quoted)

  def testShellQuoteArgs(self):
    # (argument list, expected command line)
    cases = [
      (['a', 'b', 'c'], "a b c"),
      (['a', 'b"', 'c'], "a 'b\"' c"),
      (['a', 'b\'', 'c'], "a 'b'\\\''' c"),
      ]
    for (args, cmdline) in cases:
      self.assertEqual(ShellQuoteArgs(args), cmdline)
1165
1166
class _BaseTcpPingTest:
  """Base class for TcpPing tests against listen(2)ing port

  Meant to be mixed into a C{unittest.TestCase} subclass which sets
  C{family} and C{address} and calls L{setUp} and L{tearDown} from its own
  fixture methods.

  """
  # Socket address family (e.g. socket.AF_INET), set by subclasses
  family = None
  # Local address to bind to and to ping, set by subclasses
  address = None

  def setUp(self):
    self.listener = socket.socket(self.family, socket.SOCK_STREAM)
    # Port 0 lets the kernel choose a free port, which is then read back
    self.listener.bind((self.address, 0))
    self.listenerport = self.listener.getsockname()[1]
    self.listener.listen(1)

  def tearDown(self):
    # Close the socket explicitly instead of relying on garbage collection
    # to release the file descriptor
    try:
      self.listener.shutdown(socket.SHUT_RDWR)
    finally:
      self.listener.close()
    del self.listener
    del self.listenerport

  def testTcpPingToLocalHostAccept(self):
    # With an explicit source address
    self.assert_(TcpPing(self.address,
                         self.listenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=True,
                         source=self.address,
                         ),
                 "failed to connect to test listener")

    # Without a source address
    self.assert_(TcpPing(self.address, self.listenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=True),
                 "failed to connect to test listener (no source)")
1196
1197
class TestIP4TcpPing(unittest.TestCase, _BaseTcpPingTest):
  """Testcase for IPv4 TCP version of ping - against listen(2)ing port"""
  family = socket.AF_INET
  address = constants.IP4_ADDRESS_LOCALHOST

  # unittest.TestCase comes first in the bases, hence the fixture code of
  # the mix-in has to be invoked explicitly

  def setUp(self):
    unittest.TestCase.setUp(self)
    _BaseTcpPingTest.setUp(self)

  def tearDown(self):
    # Undo in reverse order of setUp
    _BaseTcpPingTest.tearDown(self)
    unittest.TestCase.tearDown(self)
1210
1211
class TestIP6TcpPing(unittest.TestCase, _BaseTcpPingTest):
  """Testcase for IPv6 TCP version of ping - against listen(2)ing port"""
  family = socket.AF_INET6
  address = constants.IP6_ADDRESS_LOCALHOST

  # unittest.TestCase comes first in the bases, hence the fixture code of
  # the mix-in has to be invoked explicitly

  def setUp(self):
    unittest.TestCase.setUp(self)
    _BaseTcpPingTest.setUp(self)

  def tearDown(self):
    # Undo in reverse order of setUp
    _BaseTcpPingTest.tearDown(self)
    unittest.TestCase.tearDown(self)
1224
1225
class _BaseTcpPingDeafTest:
  """Base class for TcpPing tests against non listen(2)ing port

  Meant to be mixed into a C{unittest.TestCase} subclass which sets
  C{family} and C{address}. The socket is bound but listen(2) is never
  called on it.

  """
  # Socket address family (e.g. socket.AF_INET), set by subclasses
  family = None
  # Local address to bind to and to ping, set by subclasses
  address = None

  def setUp(self):
    self.deaflistener = socket.socket(self.family, socket.SOCK_STREAM)
    # Port 0 lets the kernel choose a free port, which is then read back
    self.deaflistener.bind((self.address, 0))
    self.deaflistenerport = self.deaflistener.getsockname()[1]

  def tearDown(self):
    # Close the socket explicitly instead of relying on garbage collection
    # to release the file descriptor
    self.deaflistener.close()
    del self.deaflistener
    del self.deaflistenerport

  def testTcpPingToLocalHostAcceptDeaf(self):
    self.assertFalse(TcpPing(self.address,
                             self.deaflistenerport,
                             timeout=constants.TCP_PING_TIMEOUT,
                             live_port_needed=True,
                             source=self.address,
                             ), # need successful connect(2)
                     "successfully connected to deaf listener")

    self.assertFalse(TcpPing(self.address,
                             self.deaflistenerport,
                             timeout=constants.TCP_PING_TIMEOUT,
                             live_port_needed=True,
                             ), # need successful connect(2)
                     "successfully connected to deaf listener (no source)")

  def testTcpPingToLocalHostNoAccept(self):
    self.assert_(TcpPing(self.address,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         source=self.address,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port")

    self.assert_(TcpPing(self.address,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port (no source)")
1271
1272
class TestIP4TcpPingDeaf(unittest.TestCase, _BaseTcpPingDeafTest):
  """Testcase for IPv4 TCP version of ping - against non listen(2)ing port"""
  family = socket.AF_INET
  address = constants.IP4_ADDRESS_LOCALHOST

  # Delegate to the mix-in rather than duplicating its fixture code, the
  # same way the IPv6 variant of this test does

  def setUp(self):
    unittest.TestCase.setUp(self)
    _BaseTcpPingDeafTest.setUp(self)

  def tearDown(self):
    unittest.TestCase.tearDown(self)
    _BaseTcpPingDeafTest.tearDown(self)
1286
1287
class TestIP6TcpPingDeaf(unittest.TestCase, _BaseTcpPingDeafTest):
  """Testcase for IPv6 TCP version of ping - against non listen(2)ing port"""
  family = socket.AF_INET6
  address = constants.IP6_ADDRESS_LOCALHOST

  # unittest.TestCase comes first in the bases, hence the fixture code of
  # the mix-in has to be invoked explicitly

  def setUp(self):
    unittest.TestCase.setUp(self)
    _BaseTcpPingDeafTest.setUp(self)

  def tearDown(self):
    # Undo in reverse order of setUp
    _BaseTcpPingDeafTest.tearDown(self)
    unittest.TestCase.tearDown(self)
1300
1301
class TestOwnIpAddress(unittest.TestCase):
  """Testcase for OwnIpAddress"""

  def testOwnLoopback(self):
    """check having the loopback ip"""
    owned = OwnIpAddress(constants.IP4_ADDRESS_LOCALHOST)
    self.assertTrue(owned, "Should own the loopback address")

  def testNowOwnAddress(self):
    """check that I don't own an address"""

    # 192.0.2.0/24 is set aside for tests and documentation by RFC 5735,
    # hence no address of that network *should* be configured locally. If
    # this ever fails, the test should be extended to multiple addresses.
    foreign_ip = "192.0.2.1"
    owned = OwnIpAddress(foreign_ip)
    self.assertFalse(owned, "Should not own IP address %s" % foreign_ip)
1318
1319
def _GetSocketCredentials(path):
  """Connect to a Unix socket and return remote credentials.

  @param path: Filesystem path of the Unix stream socket to connect to
  @return: Whatever L{utils.GetSocketCredentials} reports for the connected
    socket

  """
  client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  try:
    # Don't hang forever should nobody be listening
    client.settimeout(10)
    client.connect(path)
    credentials = utils.GetSocketCredentials(client)
  finally:
    client.close()

  return credentials
1331
1332
class TestGetSocketCredentials(unittest.TestCase):
  """Test case for utils.GetSocketCredentials using a forked client"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()
    self.sockpath = utils.PathJoin(self.tmpdir, "sock")

    # Unix socket this (parent) process listens on
    server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    self.listener = server
    server.settimeout(10)
    server.bind(self.sockpath)
    server.listen(1)

  def tearDown(self):
    self.listener.shutdown(socket.SHUT_RDWR)
    self.listener.close()
    shutil.rmtree(self.tmpdir)

  def test(self):
    # Pipe used by the child to report its findings to the parent
    (readfd, writefd) = os.pipe()

    # Start child process
    child_pid = os.fork()
    if child_pid == 0:
      try:
        credentials = _GetSocketCredentials(self.sockpath)
        report = serializer.DumpJson(credentials)

        os.write(writefd, report)
        os.close(writefd)

        os._exit(0)
      finally:
        # Only reached if one of the statements above raised an exception
        os._exit(1)

    # The parent only ever reads from the pipe
    os.close(writefd)

    # Wait for one connection
    (connection, _) = self.listener.accept()
    connection.recv(1)
    connection.close()

    # Wait for result
    reply = os.read(readfd, 4096)
    os.close(readfd)

    # Check child's exit code
    (_, exitstatus) = os.waitpid(child_pid, 0)
    self.assertFalse(os.WIFSIGNALED(exitstatus))
    self.assertEqual(os.WEXITSTATUS(exitstatus), 0)

    # Check result; the child connected to a socket owned by this process,
    # hence the reported credentials must be the ones of this process
    (peer_pid, peer_uid, peer_gid) = serializer.LoadJson(reply)
    self.assertEqual(peer_pid, os.getpid())
    self.assertEqual(peer_uid, os.getuid())
    self.assertEqual(peer_gid, os.getgid())
1385
1386
class TestListVisibleFiles(unittest.TestCase):
  """Test case for ListVisibleFiles"""

  def setUp(self):
    # Scratch directory the test files are created in
    self.path = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.path)

  def _CheckListing(self, names, visible):
    """Creates files called C{names} and checks which ones are listed.

    """
    for name in names:
      utils.WriteFile(os.path.join(self.path, name), data="test")

    listed = ListVisibleFiles(self.path)
    # The order of the returned names is not checked
    self.assertEqual(set(listed), set(visible))

  def testAllVisible(self):
    names = ["a", "b", "c"]
    self._CheckListing(names, names)

  def testNoneVisible(self):
    self._CheckListing([".a", ".b", ".c"], [])

  def testSomeVisible(self):
    self._CheckListing(["a", "b", ".c"], ["a", "b"])

  def testNonAbsolutePath(self):
    self.assertRaises(errors.ProgrammerError, ListVisibleFiles, "abc")

  def testNonNormalizedPath(self):
    self.assertRaises(errors.ProgrammerError, ListVisibleFiles,
                      "/bin/../tmp")
1426
1427
class TestNewUUID(unittest.TestCase):
  """Test case for NewUUID"""

  # Five groups of lower-case hexadecimal digits (8-4-4-4-12) joined by
  # dashes, anchored at both ends
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
                        '[a-f0-9]{4}-[a-f0-9]{12}$')

  def runTest(self):
    generated = utils.NewUUID()
    self.assertTrue(self._re_uuid.match(generated))
1436
1437
class TestUniqueSequence(unittest.TestCase):
  """Test case for UniqueSequence"""

  # (input sequence, expected result); duplicates are dropped while the
  # order of the first occurrences is kept
  _CASES = [
    # Ordered input
    ([1, 2, 3], [1, 2, 3]),
    ([1, 1, 2, 2, 3, 3], [1, 2, 3]),
    ([1, 2, 2, 3], [1, 2, 3]),
    ([1, 2, 3, 3], [1, 2, 3]),

    # Unordered input
    ([1, 2, 3, 1, 2, 3], [1, 2, 3]),
    ([1, 1, 2, 3, 3, 1, 2], [1, 2, 3]),

    # Strings
    (["a", "a"], ["a"]),
    (["a", "b"], ["a", "b"]),
    (["a", "b", "a"], ["a", "b"]),
    ]

  def runTest(self):
    for (sequence, expected) in self._CASES:
      self.assertEqual(utils.UniqueSequence(sequence), expected)
1459
1460
class TestFirstFree(unittest.TestCase):
  """Test case for the FirstFree function"""

  def test(self):
    """Test FirstFree"""
    # Gap in the middle, empty input and gap right at the start
    self.assertEqual(FirstFree([0, 1, 3]), 2)
    self.assertEqual(FirstFree([]), None)
    self.assertEqual(FirstFree([3, 4, 6]), 0)

    # With a base the search starts there; a value below the base has to
    # trip an assertion
    self.assertEqual(FirstFree([3, 4, 6], base=3), 5)
    self.assertRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1471
1472
class TestTailFile(testutils.GanetiTestCase):
  """Test case for the TailFile function"""

  def _CreateFileWithLines(self, lines, header=""):
    """Creates a temporary file holding the given lines.

    @param lines: Lines to write, each one gets terminated by a newline
    @param header: Data written verbatim in front of the lines
    @return: Name of the file

    """
    fname = self._CreateTempFile()
    fd = open(fname, "w")
    try:
      fd.write(header)
      fd.write("".join(["%s\n" % line for line in lines]))
    finally:
      # Don't leak the handle in case writing fails
      fd.close()
    return fname

  def testEmpty(self):
    fname = self._CreateTempFile()
    self.failUnlessEqual(TailFile(fname), [])
    self.failUnlessEqual(TailFile(fname, lines=25), [])

  def testAllLines(self):
    # Asking for as many lines as the file has returns all of them
    data = ["test %d" % i for i in range(30)]
    for i in range(30):
      fname = self._CreateFileWithLines(data[:i])
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])

  def testPartialLines(self):
    # Asking for fewer lines returns the last ones only
    data = ["test %d" % i for i in range(30)]
    fname = self._CreateFileWithLines(data)
    for i in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])

  def testBigFile(self):
    # The wanted lines are preceded by a single line of 1 MiB
    data = ["test %d" % i for i in range(30)]
    fname = self._CreateFileWithLines(data, header=("X" * 1048576) + "\n")
    for i in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1513
1514
class _BaseFileLockTest:
  """Test case for the FileLock class

  Mix-in holding the actual tests. Classes using it have to provide
  C{self.lock}, the lock under test, and C{self.tmpfile}, whose C{name}
  attribute is the path of the locked file.

  """

  def testSharedNonblocking(self):
    self.lock.Shared(blocking=False)
    self.lock.Close()

  def testExclusiveNonblocking(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Close()

  def testUnlockNonblocking(self):
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSharedBlocking(self):
    self.lock.Shared(blocking=True)
    self.lock.Close()

  def testExclusiveBlocking(self):
    self.lock.Exclusive(blocking=True)
    self.lock.Close()

  def testUnlockBlocking(self):
    self.lock.Unlock(blocking=True)
    self.lock.Close()

  def testSharedExclusiveUnlock(self):
    self.lock.Shared(blocking=False)
    self.lock.Exclusive(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testExclusiveSharedUnlock(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Shared(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSimpleTimeout(self):
    # These will succeed on the first attempt, hence a short timeout
    self.lock.Shared(blocking=True, timeout=10.0)
    self.lock.Exclusive(blocking=False, timeout=10.0)
    self.lock.Unlock(blocking=True, timeout=10.0)
    self.lock.Close()

  @staticmethod
  def _TryLockInner(filename, shared, blocking):
    """Tries to lock a file through a newly opened lock object.

    @param filename: Path of the file to lock
    @param shared: Whether to request a shared instead of an exclusive lock
    @param blocking: Passed on to the locking function
    @return: False if L{errors.LockError} was raised, True otherwise

    """
    lock = utils.FileLock.Open(filename)

    if shared:
      fn = lock.Shared
    else:
      fn = lock.Exclusive

    try:
      # The timeout doesn't really matter as the parent process waits for us to
      # finish anyway.
      fn(blocking=blocking, timeout=0.01)
    except errors.LockError:
      return False

    return True

  def _TryLock(self, *args):
    """Runs L{_TryLockInner} on the test file via utils.RunInSeparateProcess.

    """
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
                                      *args)

  def testTimeout(self):
    for blocking in [True, False]:
      # While the exclusive lock is held here, neither kind of lock may be
      # obtained through _TryLock
      self.lock.Exclusive(blocking=True)
      self.failIf(self._TryLock(False, blocking))
      self.failIf(self._TryLock(True, blocking))

      # While only a shared lock is held here, another shared lock is
      # fine, an exclusive one is not
      self.lock.Shared(blocking=True)
      self.assert_(self._TryLock(True, blocking))
      self.failIf(self._TryLock(False, blocking))

  def testCloseShared(self):
    # Using a closed lock has to trip an assertion
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)

  def testCloseExclusive(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)

  def testCloseUnlock(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1604
1605
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
  """Runs the FileLock tests on a lock created with FileLock.Open"""

  TESTDATA = "Hello World\n" * 10

  def _AssertDataIntact(self):
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpfile = tempfile.NamedTemporaryFile()
    path = self.tmpfile.name
    utils.WriteFile(path, data=self.TESTDATA)
    self.lock = utils.FileLock.Open(path)

    # Ensure "Open" didn't truncate file
    self._AssertDataIntact()

  def tearDown(self):
    # None of the lock operations may have touched the file's content
    self._AssertDataIntact()

    testutils.GanetiTestCase.tearDown(self)
1623
1624
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
  """Runs the FileLock tests on a lock built from an open file object"""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()
    path = self.tmpfile.name
    # Hand an already opened file object plus its name to the constructor
    handle = open(path, "w")
    self.lock = utils.FileLock(handle, path)
1629
1630
class TestTimeFunctions(unittest.TestCase):
  """Test case for time functions"""

  def _CheckSplitTime(self):
    # (float timestamp, expected (seconds, microseconds) tuple)
    cases = [
      (1, (1, 0)),
      (1.5, (1, 500000)),
      (1218448917.4809151, (1218448917, 480915)),
      (123.48012, (123, 480120)),
      (123.9996, (123, 999600)),
      (123.9995, (123, 999500)),
      (123.9994, (123, 999400)),
      (123.999999999, (123, 999999)),
      ]
    for (value, expected) in cases:
      self.assertEqual(utils.SplitTime(value), expected)

    self.assertRaises(AssertionError, utils.SplitTime, -1)

  def _CheckMergeTime(self):
    # Values which can be compared exactly
    exact = [
      ((1, 0), 1.0),
      ((1, 500000), 1.5),
      ((1218448917, 500000), 1218448917.5),
      ]
    for (parts, expected) in exact:
      self.assertEqual(utils.MergeTime(parts), expected)

    # Values compared after rounding to three decimal places
    rounded = [
      ((1218448917, 481000), 1218448917.481),
      ((1, 801000), 1.801),
      ]
    for (parts, expected) in rounded:
      self.assertEqual(round(utils.MergeTime(parts), 3), expected)

    # Negative parts and microseconds of one second or more are refused
    invalid = [(0, -1), (0, 1000000), (0, 9999999), (-1, 0), (-9999, 0)]
    for parts in invalid:
      self.assertRaises(AssertionError, utils.MergeTime, parts)

  def runTest(self):
    self._CheckSplitTime()
    self._CheckMergeTime()
1658     self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1659
1660
class FieldSetTestCase(unittest.TestCase):
  """Test case for FieldSets"""

  def testSimpleMatch(self):
    fields = utils.FieldSet("a", "b", "c", "def")
    self.assertTrue(fields.Matches("a"))

    # Only complete names may match
    self.assertFalse(fields.Matches("d"), "Substring matched")
    self.assertFalse(fields.Matches("defghi"), "Prefix string matched")

    # NonMatching is false-ish if, and only if, all names are known
    self.assertFalse(fields.NonMatching(["b", "c"]))
    self.assertFalse(fields.NonMatching(["a", "b", "c", "def"]))
    self.assertTrue(fields.NonMatching(["a", "d"]))

  def testRegexMatch(self):
    # Field definitions are regular expressions
    fields = utils.FieldSet("a", "b([0-9]+)", "c")
    for name in ("b1", "b99"):
      self.assertTrue(fields.Matches(name))
    self.assertFalse(fields.Matches("b/1"))

    self.assertFalse(fields.NonMatching(["b12", "c"]))
    self.assertTrue(fields.NonMatching(["a", "1"]))
1680
class TestForceDictType(unittest.TestCase):
  """Test case for ForceDictType"""

  def setUp(self):
    # One key for each value type exercised by the tests below
    self.key_types = dict([
      ('a', constants.VTYPE_INT),
      ('b', constants.VTYPE_BOOL),
      ('c', constants.VTYPE_STRING),
      ('d', constants.VTYPE_SIZE),
      ])

  def _fdt(self, dict, allowed_values=None):
    """Runs ForceDictType on the given dictionary and returns that dictionary.

    The "allowed_values" keyword is only passed on when it was specified.

    """
    kwargs = {}
    if allowed_values is not None:
      kwargs["allowed_values"] = allowed_values

    ForceDictType(dict, self.key_types, **kwargs)

    return dict

  def testSimpleDict(self):
    """Checks the conversions applied to values of the various types."""
    cases = [
      ({}, {}),
      ({'a': 1}, {'a': 1}),
      ({'a': '1'}, {'a': 1}),
      ({'a': 1, 'b': 1}, {'a':1, 'b': True}),
      ({'b': 1, 'c': 'foo'}, {'b': True, 'c': 'foo'}),
      ({'b': 1, 'c': False}, {'b': True, 'c': ''}),
      ({'b': 'false'}, {'b': False}),
      ({'b': 'False'}, {'b': False}),
      ({'b': 'true'}, {'b': True}),
      ({'b': 'True'}, {'b': True}),
      ({'d': '4'}, {'d': 4}),
      ({'d': '4M'}, {'d': 4}),
      ]
    for (data, expected) in cases:
      self.assertEqual(self._fdt(data), expected)

  def testErrors(self):
    """Values which can't be converted must raise TypeEnforcementError."""
    for data in [{'a': 'astring'}, {'c': True}, {'d': 'astring'},
                 {'d': '4 L'}]:
      self.assertRaises(errors.TypeEnforcementError, self._fdt, data)
1719
1720
class TestIsNormAbsPath(unittest.TestCase):
  """Testing case for IsNormAbsPath"""

  def _pathTestHelper(self, path, result):
    """Asserts that IsNormAbsPath(path) is true or false as per "result"."""
    if not result:
      self.assertFalse(IsNormAbsPath(path),
          "Path %s should not result absolute and normalized" % path)
      return

    self.assert_(IsNormAbsPath(path),
        "Path %s should result absolute and normalized" % path)

  def testBase(self):
    """Checks absolute, relative, non-normalized and trailing-slash paths."""
    for (path, expected) in [
      ('/etc', True),
      ('/srv', True),
      ('etc', False),
      ('/etc/../root', False),
      ('/etc/', False),
      ]:
      self._pathTestHelper(path, expected)
1738
1739
class TestSafeEncode(unittest.TestCase):
  """Test case for SafeEncode"""

  def testAscii(self):
    """Printable ASCII text must pass through unchanged."""
    for txt in [string.digits, string.letters, string.punctuation]:
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testDoubleEncode(self):
    """Encoding an already encoded string must not change it again."""
    # chr() accepts 0..255; range(256) is needed to include the last byte
    # value (range(255) would stop at 254)
    for i in range(256):
      txt = SafeEncode(chr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testUnicode(self):
    """Same idempotence check as above, but for unicode input."""
    # 1024 is high enough to catch non-direct ASCII mappings
    for i in range(1024):
      txt = SafeEncode(unichr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))
1757
1758
class TestFormatTime(unittest.TestCase):
  """Testing case for FormatTime"""

  def testNone(self):
    """None is rendered as "N/A"."""
    formatted = FormatTime(None)
    self.failUnlessEqual(formatted, "N/A")

  def testInvalid(self):
    """An empty tuple is rendered as "N/A" as well."""
    formatted = FormatTime(())
    self.failUnlessEqual(formatted, "N/A")

  def testNow(self):
    """Current time must be accepted; only checks that nothing is raised."""
    # tests that we accept time.time input
    now_float = time.time()
    FormatTime(now_float)
    # tests that we accept int input
    now_int = int(time.time())
    FormatTime(now_int)
1773
1774
class RunInSeparateProcess(unittest.TestCase):
  """Test case for utils.RunInSeparateProcess"""

  def test(self):
    """The callable's boolean return value must be passed back."""
    for expected in [True, False]:
      result = utils.RunInSeparateProcess(lambda: expected)
      self.assertEqual(expected, result)

  def testArgs(self):
    """Additional positional arguments must reach the callable."""
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
      check_fn = lambda carg1, carg2: carg1 == "Foo" and carg2 == arg

      self.assert_(utils.RunInSeparateProcess(check_fn, "Foo", arg))

  def testPid(self):
    """The callable must not see the PID of the calling process."""
    parent_pid = os.getpid()

    same_pid = utils.RunInSeparateProcess(lambda: os.getpid() == parent_pid)
    self.failIf(same_pid)

  def testSignal(self):
    """A callable killing its own process must result in GenericError."""
    kill_self = lambda: os.kill(os.getpid(), signal.SIGTERM)

    self.assertRaises(errors.GenericError,
                      utils.RunInSeparateProcess, kill_self)

  def testException(self):
    """An exception raised by the callable must result in GenericError."""
    def _raise_error():
      raise errors.GenericError("This is a test")

    self.assertRaises(errors.GenericError,
                      utils.RunInSeparateProcess, _raise_error)
1811
1812
class TestFingerprintFile(unittest.TestCase):
  """Test case for utils._FingerprintFile"""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()

  def test(self):
    """Fingerprints an empty file, then the same file with known content."""
    path = self.tmpfile.name

    # The file is still empty at this point
    empty_digest = utils._FingerprintFile(path)
    self.assertEqual(empty_digest,
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")

    utils.WriteFile(path, data="Hello World\n")
    filled_digest = utils._FingerprintFile(path)
    self.assertEqual(filled_digest,
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1824
1825
class TestUnescapeAndSplit(unittest.TestCase):
  """Testing case for UnescapeAndSplit"""

  def setUp(self):
    # testing more that one separator for regexp safety
    self._seps = [",", "+", "."]

  def _CheckSplit(self, sep, parts, expected):
    """Joins "parts" using "sep" and verifies the result of splitting again.

    """
    joined = sep.join(parts)
    self.failUnlessEqual(UnescapeAndSplit(joined, sep=sep), expected)

  def testSimple(self):
    """Without any escaping the original items must come back."""
    items = ["a", "b", "c", "d"]
    for sep in self._seps:
      self._CheckSplit(sep, items, items)

  def testEscape(self):
    """An escaped separator stays part of its item."""
    for sep in self._seps:
      self._CheckSplit(sep,
                       ["a", "b\\" + sep + "c", "d"],
                       ["a", "b" + sep + "c", "d"])

  def testDoubleEscape(self):
    """An escaped backslash in front of a separator doesn't escape it."""
    for sep in self._seps:
      self._CheckSplit(sep,
                       ["a", "b\\\\", "c", "d"],
                       ["a", "b\\", "c", "d"])

  def testThreeEscape(self):
    """Escaped backslash followed by an escaped separator."""
    for sep in self._seps:
      self._CheckSplit(sep,
                       ["a", "b\\\\\\" + sep + "c", "d"],
                       ["a", "b\\" + sep + "c", "d"])
1855
1856
1857 class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1858   def setUp(self):
1859     self.tmpdir = tempfile.mkdtemp()
1860
1861   def tearDown(self):
1862     shutil.rmtree(self.tmpdir)
1863
1864   def _checkRsaPrivateKey(self, key):
1865     lines = key.splitlines()
1866     return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1867             "-----END RSA PRIVATE KEY-----" in lines)
1868
1869   def _checkCertificate(self, cert):
1870     lines = cert.splitlines()
1871     return ("-----BEGIN CERTIFICATE-----" in lines and
1872             "-----END CERTIFICATE-----" in lines)
1873
1874   def test(self):
1875     for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1876       (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1877       self._checkRsaPrivateKey(key_pem)
1878       self._checkCertificate(cert_pem)
1879
1880       key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1881                                            key_pem)
1882       self.assert_(key.bits() >= 1024)
1883       self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1884       self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1885
1886       x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1887                                              cert_pem)
1888       self.failIf(x509.has_expired())
1889       self.assertEqual(x509.get_issuer().CN, common_name)
1890       self.assertEqual(x509.get_subject().CN, common_name)
1891       self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1892
1893   def testLegacy(self):
1894     cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1895
1896     utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1897
1898     cert1 = utils.ReadFile(cert1_filename)
1899
1900     self.assert_(self._checkRsaPrivateKey(cert1))
1901     self.assert_(self._checkCertificate(cert1))
1902
1903
class TestPathJoin(unittest.TestCase):
  """Testing case for PathJoin"""

  def testBasicItems(self):
    """Components are joined using a single slash."""
    mlist = ["/a", "b", "c"]
    expected = "/".join(mlist)
    self.failUnlessEqual(PathJoin(*mlist), expected)

  def _CheckRejected(self, *parts):
    """Asserts that joining the given components raises ValueError."""
    self.failUnlessRaises(ValueError, PathJoin, *parts)

  def testNonAbsPrefix(self):
    """The first component must be absolute."""
    self._CheckRejected("a", "b")

  def testBackTrack(self):
    """Components containing ".." are refused."""
    self._CheckRejected("/a", "b/../c")

  def testMultiAbs(self):
    """A second absolute component is refused."""
    self._CheckRejected("/a", "/b")
1919
1920
class TestHostInfo(unittest.TestCase):
  """Testing case for HostInfo"""

  def testUppercase(self):
    """Names are converted to lower case."""
    name = "AbC.example.com"
    normalized = HostInfo.NormalizeName(name)
    self.failUnlessEqual(normalized, name.lower())

  def testTooLongName(self):
    """An overly long name is refused."""
    name = "a.b." + "c" * 255
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, name)

  def testTrailingDot(self):
    """A single trailing dot is removed."""
    name = "a.b.c"
    normalized = HostInfo.NormalizeName(name + ".")
    self.failUnlessEqual(normalized, name)

  def testInvalidName(self):
    """Names with bad characters or empty components are refused."""
    for name in ("a b", "a/b", ".a.b", "a..b"):
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, name)

  def testValidName(self):
    """Valid names must be accepted; only checks that nothing is raised."""
    for name in ("a.b", "a-b", "a_b", "a.b.c"):
      HostInfo.NormalizeName(name)
1955
1956
class TestValidateServiceName(unittest.TestCase):
  """Test case for utils.ValidateServiceName"""

  def testValid(self):
    """Valid port numbers and service names are returned unchanged."""
    numeric = [0, 1, 2, 3, 1024, 65000, 65534, 65535]
    named = ["ganeti", "gnt-masterd", "HELLO_WORLD_SVC", "hello.world.1"]
    numeric_strings = ["0", "80", "1111", "65535"]

    for candidate in numeric + named + numeric_strings:
      self.assertEqual(utils.ValidateServiceName(candidate), candidate)

  def testInvalid(self):
    """Out-of-range ports and malformed names raise OpPrereqError."""
    numeric = [-15756, -1, 65536, 133428083]
    malformed = ["", "Hello World!", "!", "'", "\"", "\t", "\n", "`"]
    numeric_strings = ["-8546", "-1", "65536"]
    overlong = [129 * "A"]

    for candidate in numeric + malformed + numeric_strings + overlong:
      self.assertRaises(OpPrereqError, utils.ValidateServiceName, candidate)
1981
1982
class TestParseAsn1Generalizedtime(unittest.TestCase):
  """Test case for utils._ParseAsn1Generalizedtime"""

  def test(self):
    """Parses valid timestamps and verifies bad ones raise ValueError."""
    parse = utils._ParseAsn1Generalizedtime

    valid = [
      # UTC
      ("19700101000000Z", 0),
      ("20100222174152Z", 1266860512),
      ("20380119031407Z", (2**31) - 1),

      # With offset
      ("20100222174152+0000", 1266860512),
      ("20100223131652+0000", 1266931012),
      ("20100223051808-0800", 1266931088),
      ("20100224002135+1100", 1266931295),
      ("19700101000000-0100", 3600),
      ]
    for (text, expected) in valid:
      self.assertEqual(parse(text), expected)

    invalid = [
      # Leap seconds are not supported by datetime.datetime
      "19841231235960+0000",
      "19920630235960+0000",

      # Errors
      "",
      "invalid",
      "20100222174152",
      "Mon Feb 22 17:47:02 UTC 2010",
      "2010-02-22 17:42:02",
      ]
    for text in invalid:
      self.assertRaises(ValueError, parse, text)
2019
2020
class TestGetX509CertValidity(testutils.GanetiTestCase):
  """Test case for utils.GetX509CertValidity"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    # Test whether we have pyOpenSSL 0.7 or above
    installed = distutils.version.LooseVersion(OpenSSL.__version__)
    self.pyopenssl0_7 = (installed >= "0.7")

    if self.pyopenssl0_7:
      return

    warnings.warn("This test requires pyOpenSSL 0.7 or above to"
                  " function correctly")

  def _LoadCert(self, name):
    """Loads the PEM certificate stored in the named test data file."""
    pem = self._ReadTestData(name)
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem)

  def test(self):
    """Verifies the validity period reported for cert1.pem."""
    # With older pyOpenSSL versions no validity information is expected
    if self.pyopenssl0_7:
      expected = (1266919967, 1267524767)
    else:
      expected = (None, None)

    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
    self.assertEqual(validity, expected)
2044
2045
class TestSignX509Certificate(unittest.TestCase):
  """Test case for utils.SignX509Certificate and LoadSignedX509Certificate"""

  KEY = "My private key!"
  KEY_OTHER = "Another key"

  def test(self):
    """Signs a fresh certificate and loads it back in various ways."""
    load_fn = utils.LoadSignedX509Certificate

    # Generate certificate valid for 5 minutes
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)

    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                           cert_pem)

    # No signature at all
    self.assertRaises(errors.GenericError, load_fn, cert_pem, self.KEY)

    # Invalid input
    invalid_input = [
      "",
      "X-Ganeti-Signature: \n",
      "X-Ganeti-Sign: $1234$abcdef\n",
      "X-Ganeti-Signature: $1234567890$abcdef\n",
      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem,
      ]
    for data in invalid_input:
      self.assertRaises(errors.GenericError, load_fn, data, self.KEY)

    # Invalid salt
    # NOTE(review): the PEM text is passed here while the loop below passes
    # the loaded certificate object; confirm whether that is intended
    for bad_char in list("-_@$,:;/\\ \t\n"):
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
                        cert_pem, self.KEY, "foo%sbar" % bad_char)

    valid_salts = [
      "HelloWorld",
      "salt",
      string.letters,
      string.digits,
      utils.GenerateSecret(numbytes=4),
      utils.GenerateSecret(numbytes=16),
      "{123:456}".encode("hex"),
      ]
    for salt in valid_salts:
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)

      # Text before or after the signed certificate must not matter
      variants = [
        signed_pem,
        "X-Another-Header: with a value\n" + signed_pem,
        (10 * "Hello World!\n") + signed_pem,
        (signed_pem + "\n\na few more\n"
         "lines----\n------ at\nthe end!"),
        ]
      for pem in variants:
        self._Check(cert, salt, pem)

  def _Check(self, cert, salt, pem):
    """Loads signed PEM data and compares it with the expected values.

    Also verifies that loading using a different key raises GenericError.

    """
    (loaded_cert, loaded_salt) = \
      utils.LoadSignedX509Certificate(pem, self.KEY)
    self.assertEqual(salt, loaded_salt)
    self.assertEqual(cert.digest("sha1"), loaded_cert.digest("sha1"))

    # Other key
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      pem, self.KEY_OTHER)
2099
2100
class TestMakedirs(unittest.TestCase):
  """Test case for utils.Makedirs"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _MakeAndVerify(self, path):
    """Runs utils.Makedirs and asserts the path is a directory afterwards."""
    utils.Makedirs(path)
    self.assert_(os.path.isdir(path))

  def testNonExisting(self):
    """A single missing directory is created."""
    self._MakeAndVerify(utils.PathJoin(self.tmpdir, "foo"))

  def testExisting(self):
    """An already existing directory is accepted."""
    target = utils.PathJoin(self.tmpdir, "foo")
    os.mkdir(target)
    self._MakeAndVerify(target)

  def testRecursiveNonExisting(self):
    """Missing parent directories are created too."""
    self._MakeAndVerify(utils.PathJoin(self.tmpdir, "foo/bar/baz"))

  def testRecursiveExisting(self):
    """Only the first path component exists beforehand."""
    target = utils.PathJoin(self.tmpdir, "B/moo/xyz")
    self.assertFalse(os.path.exists(target))
    os.mkdir(utils.PathJoin(self.tmpdir, "B"))
    self._MakeAndVerify(target)
2130
2131
class TestRetry(testutils.GanetiTestCase):
  """Test case for utils.Retry

  NOTE(review): the positional arguments given to utils.Retry after the
  function are presumably the delay and the timeout; confirm against the
  implementation.

  """
  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    # Counts how often _RetryAndSucceed asked to be called again
    self.retries = 0

  @staticmethod
  def _RaiseRetryAgain():
    """Always asks to be called again."""
    raise utils.RetryAgain()

  @staticmethod
  def _RaiseRetryAgainWithArg(args):
    """Always asks to be called again, passing "args" to the exception."""
    raise utils.RetryAgain(*args)

  def _WrongNestedLoop(self):
    """Runs a retry loop of its own; used as the function of an outer loop."""
    return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)

  def _RetryAndSucceed(self, retries):
    """Asks to be called again "retries" times, then returns True."""
    if self.retries < retries:
      self.retries += 1
      raise utils.RetryAgain()
    else:
      return True

  def testRaiseTimeout(self):
    """A function which never succeeds in time must cause RetryTimeout."""
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
                          self._RaiseRetryAgain, 0.01, 0.02)
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
                          self._RetryAndSucceed, 0.01, 0, args=[1])
    # The function is expected to have been called exactly once
    self.failUnlessEqual(self.retries, 1)

  def testComplete(self):
    """The function's return value must be returned upon success."""
    self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
    self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
                         True)
    self.failUnlessEqual(self.retries, 2)

  def testNestedLoop(self):
    """A timeout of an inner loop must surface as ProgrammerError."""
    try:
      self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
                            self._WrongNestedLoop, 0, 1)
    except utils.RetryTimeout:
      # failUnlessRaises lets exceptions of other types pass through
      self.fail("Didn't detect inner loop's exception")

  def testTimeoutArgument(self):
    """Arguments of RetryAgain must show up in the RetryTimeout exception."""
    retry_arg="my_important_debugging_message"
    try:
      utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
    except utils.RetryTimeout, err:
      self.failUnlessEqual(err.args, (retry_arg, ))
    else:
      self.fail("Expected timeout didn't happen")

  def testRaiseInnerWithExc(self):
    """RaiseInner must raise the exception which was given to RetryAgain."""
    retry_arg="my_important_debugging_message"
    try:
      try:
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
                    args=[[errors.GenericError(retry_arg, retry_arg)]])
      except utils.RetryTimeout, err:
        # Expected to raise the GenericError caught by the outer handler
        err.RaiseInner()
      else:
        self.fail("Expected timeout didn't happen")
    except errors.GenericError, err:
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
    else:
      self.fail("Expected GenericError didn't happen")

  def testRaiseInnerWithMsg(self):
    """RaiseInner must raise RetryTimeout if only plain values were given."""
    retry_arg="my_important_debugging_message"
    try:
      try:
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
                    args=[[retry_arg, retry_arg]])
      except utils.RetryTimeout, err:
        # Expected to raise a RetryTimeout caught by the outer handler
        err.RaiseInner()
      else:
        self.fail("Expected timeout didn't happen")
    except utils.RetryTimeout, err:
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
    else:
      self.fail("Expected RetryTimeout didn't happen")
2213
2214
2215 class TestLineSplitter(unittest.TestCase):
2216   def test(self):
2217     lines = []
2218     ls = utils.LineSplitter(lines.append)
2219     ls.write("Hello World\n")
2220     self.assertEqual(lines, [])
2221     ls.write("Foo\n Bar\r\n ")
2222     ls.write("Baz")
2223     ls.write("Moo")
2224     self.assertEqual(lines, [])
2225     ls.flush()
2226     self.assertEqual(lines, ["Hello World", "Foo", " Bar"])
2227     ls.close()
2228     self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"])
2229
2230   def _testExtra(self, line, all_lines, p1, p2):
2231     self.assertEqual(p1, 999)
2232     self.assertEqual(p2, "extra")
2233     all_lines.append(line)
2234
2235   def testExtraArgsNoFlush(self):
2236     lines = []
2237     ls = utils.LineSplitter(self._testExtra, lines, 999, "extra")
2238     ls.write("\n\nHello World\n")
2239     ls.write("Foo\n Bar\r\n ")
2240     ls.write("")
2241     ls.write("Baz")
2242     ls.write("Moo\n\nx\n")
2243     self.assertEqual(lines, [])
2244     ls.close()
2245     self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo",
2246                              "", "x"])
2247
2248
class TestReadLockedPidFile(unittest.TestCase):
  """Test case for utils.ReadLockedPidFile"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testNonExistent(self):
    """A missing file yields None."""
    pidfile = utils.PathJoin(self.tmpdir, "nonexist")
    result = utils.ReadLockedPidFile(pidfile)
    self.assert_(result is None)

  def testUnlocked(self):
    """An existing file nobody holds a lock on yields None."""
    pidfile = utils.PathJoin(self.tmpdir, "pid")
    utils.WriteFile(pidfile, data="123")
    result = utils.ReadLockedPidFile(pidfile)
    self.assert_(result is None)

  def testLocked(self):
    """The PID is only returned while the file is locked."""
    pidfile = utils.PathJoin(self.tmpdir, "pid")
    utils.WriteFile(pidfile, data="123")

    lock = utils.FileLock.Open(pidfile)
    try:
      lock.Exclusive(blocking=True)

      self.assertEqual(utils.ReadLockedPidFile(pidfile), 123)
    finally:
      lock.Close()

    # The lock is gone now, hence no PID is reported anymore
    self.assert_(utils.ReadLockedPidFile(pidfile) is None)

  def testError(self):
    """Errors other than a missing file are passed on to the caller."""
    blocker = utils.PathJoin(self.tmpdir, "foobar")
    pidfile = utils.PathJoin(self.tmpdir, "foobar", "pid")
    # "foobar" is a regular file, not a directory
    utils.WriteFile(blocker, data="")
    # open(2) should return ENOTDIR
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, pidfile)
2284
2285
class TestCertVerification(testutils.GanetiTestCase):
  """Test case for utils.VerifyX509Certificate"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

    # setUp calls the base class, hence its cleanup has to run as well
    testutils.GanetiTestCase.tearDown(self)

  def testVerifyCertificate(self):
    """Runs the verification on a certificate from the test data."""
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                           cert_pem)

    # Not checking return value as this certificate is expired
    utils.VerifyX509Certificate(cert, 30, 7)
2302
2303
class TestVerifyCertificateInner(unittest.TestCase):
  """Test case for utils._VerifyCertificateInner

  NOTE(review): the arguments are presumably (expired, not_before,
  not_after, now, warn_days, error_days); confirm against the implementation.

  """
  def test(self):
    """Checks the reported error code for a number of validity periods."""
    vci = utils._VerifyCertificateInner

    # Valid
    valid_result = vci(False, 1263916313, 1298476313, 1266940313, 30, 7)
    self.assertEqual(valid_result, (None, None))

    cases = [
      # Not yet valid
      ((False, 1266507600, 1267544400, 1266075600, 30, 7),
       utils.CERT_WARNING),

      # Expiring soon
      ((False, 1266507600, 1267544400, 1266939600, 30, 7),
       utils.CERT_ERROR),
      ((False, 1266507600, 1267544400, 1266939600, 30, 1),
       utils.CERT_WARNING),
      ((False, 1266507600, None, 1266939600, 30, 7), None),

      # Expired
      ((True, 1266507600, 1267544400, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, None, 1267544400, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, 1266507600, None, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, None, None, 1266939600, 30, 7), utils.CERT_ERROR),
      ]

    # Only the error code is verified for these, not the message
    for (args, expected_code) in cases:
      (errcode, _) = vci(*args)
      self.assertEqual(errcode, expected_code)
2338
2339
class TestHmacFunctions(unittest.TestCase):
  """Test case for utils.Sha1Hmac and utils.VerifySha1Hmac"""

  # Digests can be checked with "openssl sha1 -hmac $key"
  def testSha1Hmac(self):
    """Computes digests without salt, including one for a long text."""
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"

    for (key, text, expected) in [
      ("", "", "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d"),
      ("3YzMxZWE", "Hello World", "ef4f3bda82212ecb2f7ce868888a19092481f1fd"),
      ("TguMTA2K", "", "f904c2476527c6d3e6609ab683c66fa0652cb1dc"),
      ("3YzMxZWE", longtext, "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54"),
      ]:
      self.assertEqual(utils.Sha1Hmac(key, text), expected)

  def testSha1HmacSalt(self):
    """Computes digests using a salt."""
    for (key, text, salt, expected) in [
      ("TguMTA2K", "", "abc0", "4999bf342470eadb11dfcd24ca5680cf9fd7cdce"),
      ("TguMTA2K", "", "abc9", "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8"),
      ("3YzMxZWE", "Hello World", "xyz0",
       "7f264f8114c9066afc9bb7636e1786d996d3cc0d"),
      ]:
      self.assertEqual(utils.Sha1Hmac(key, text, salt=salt), expected)

  def testVerifySha1Hmac(self):
    """Verifies digests without salt, irrespective of the digest's case."""
    for (key, text, expected) in [
      ("", "", "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d"),
      ("TguMTA2K", "", "f904c2476527c6d3e6609ab683c66fa0652cb1dc"),
      ]:
      self.assert_(utils.VerifySha1Hmac(key, text, expected))

    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
    for variant in (digest, digest.lower(), digest.upper(), digest.title()):
      self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", variant))

  def testVerifySha1HmacSalt(self):
    """Verifies digests which were computed using a salt."""
    for (key, text, expected, salt) in [
      ("TguMTA2K", "", "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8", "abc9"),
      ("3YzMxZWE", "Hello World",
       "7f264f8114c9066afc9bb7636e1786d996d3cc0d", "xyz0"),
      ]:
      self.assert_(utils.VerifySha1Hmac(key, text, expected, salt=salt))
2387
2388
2389 class TestIgnoreSignals(unittest.TestCase):
2390   """Test the IgnoreSignals decorator"""
2391
2392   @staticmethod
2393   def _Raise(exception):
2394     raise exception
2395
2396   @staticmethod
2397   def _Return(rval):
2398     return rval
2399
2400   def testIgnoreSignals(self):
2401     sock_err_intr = socket.error(errno.EINTR, "Message")
2402     sock_err_inval = socket.error(errno.EINVAL, "Message")
2403
2404     env_err_intr = EnvironmentError(errno.EINTR, "Message")
2405     env_err_inval = EnvironmentError(errno.EINVAL, "Message")
2406
2407     self.assertRaises(socket.error, self._Raise, sock_err_intr)
2408     self.assertRaises(socket.error, self._Raise, sock_err_inval)
2409     self.assertRaises(EnvironmentError, self._Raise, env_err_intr)
2410     self.assertRaises(EnvironmentError, self._Raise, env_err_inval)
2411
2412     self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None)
2413     self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None)
2414     self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
2415                       sock_err_inval)
2416     self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
2417                       env_err_inval)
2418
2419     self.assertEquals(utils.IgnoreSignals(self._Return, True), True)
2420     self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33)
2421
2422
2423 class TestEnsureDirs(unittest.TestCase):
2424   """Tests for EnsureDirs"""
2425
2426   def setUp(self):
2427     self.dir = tempfile.mkdtemp()
2428     self.old_umask = os.umask(0777)
2429
2430   def testEnsureDirs(self):
2431     utils.EnsureDirs([
2432         (utils.PathJoin(self.dir, "foo"), 0777),
2433         (utils.PathJoin(self.dir, "bar"), 0000),
2434         ])
2435     self.assertEquals(os.stat(utils.PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2436     self.assertEquals(os.stat(utils.PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2437
2438   def tearDown(self):
2439     os.rmdir(utils.PathJoin(self.dir, "foo"))
2440     os.rmdir(utils.PathJoin(self.dir, "bar"))
2441     os.rmdir(self.dir)
2442     os.umask(self.old_umask)
2443
2444
class TestFormatSeconds(unittest.TestCase):
  """Test case for utils.FormatSeconds"""

  def _CheckAll(self, cases):
    """Asserts the formatted text for each (seconds, expected) pair."""
    for (seconds, expected) in cases:
      self.assertEqual(utils.FormatSeconds(seconds), expected)

  def test(self):
    """Integer input; negative values aren't split into units."""
    self._CheckAll([
      (1, "1s"),
      (3600, "1h 0m 0s"),
      (3599, "59m 59s"),
      (7200, "2h 0m 0s"),
      (7201, "2h 0m 1s"),
      (7281, "2h 1m 21s"),
      (29119, "8h 5m 19s"),
      (19431228, "224d 21h 33m 48s"),
      (-1, "-1s"),
      (-282, "-282s"),
      (-29119, "-29119s"),
      ])

  def testFloat(self):
    """Float input is rounded to whole seconds."""
    self._CheckAll([
      (1.3, "1s"),
      (1.9, "2s"),
      (3912.12311, "1h 5m 12s"),
      (3912.8, "1h 5m 13s"),
      ])
2464
2465
class RunIgnoreProcessNotFound(unittest.TestCase):
  """Test case for utils.IgnoreProcessNotFound"""

  @staticmethod
  def _WritePid(fd):
    """Writes the PID of the current process to "fd" and closes it."""
    pid_text = str(os.getpid())
    os.write(fd, pid_text)
    os.close(fd)
    return True

  def test(self):
    """Signalling a process which is gone must give a false result."""
    (read_end, write_end) = os.pipe()

    # Start short-lived process which writes its PID to pipe
    child_ok = utils.RunInSeparateProcess(self._WritePid, write_end)
    self.assert_(child_ok)
    os.close(write_end)

    # Read PID from pipe
    pid_data = os.read(read_end, 1024)
    pid = int(pid_data)
    os.close(read_end)

    # Try to send signal to process which exited recently
    result = utils.IgnoreProcessNotFound(os.kill, pid, 0)
    self.assertFalse(result)
2486
2487
class TestIsValidIP4(unittest.TestCase):
  """Test case for utils.IsValidIP4"""

  def test(self):
    """Accepts dotted quads only; an IPv6 address is refused."""
    for address in ("127.0.0.1", "0.0.0.0", "255.255.255.255"):
      self.assert_(utils.IsValidIP4(address))

    for address in ("0", "1", "1.1.1", "255.255.255.256", "::1"):
      self.assertFalse(utils.IsValidIP4(address))
2498
2499
class TestIsValidIP6(unittest.TestCase):
  """Tests for utils.IsValidIP6.

  """
  def test(self):
    accepted = [
      "::",
      "::1",
      # Eight groups, shortest and longest form
      "1" + (":1" * 7),
      "ffff" + (":ffff" * 7),
      ]

    rejected = [
      "0",
      ":1",
      # Only seven groups
      "f" + (":f" * 6),
      # Non-hexadecimal digit in the first group
      "fffg" + (":ffff" * 7),
      # First group has five digits
      "fffff" + (":ffff" * 7),
      # Nine groups
      "1" + (":1" * 8),
      # IPv4 notation
      "127.0.0.1",
      ]

    for address in accepted:
      self.assertTrue(utils.IsValidIP6(address),
                      msg="%r should be accepted as IPv6 address" %
                          (address, ))

    for address in rejected:
      self.assertFalse(utils.IsValidIP6(address),
                       msg="%r should be rejected as IPv6 address" %
                           (address, ))
2513
2514
class TestIsValidIP(unittest.TestCase):
  """Tests for utils.IsValidIP with IPv4 and IPv6 notation.

  """
  def test(self):
    # Addresses in either notation must be accepted
    for address in ["0.0.0.0", "127.0.0.1", "::", "::1"]:
      self.assertTrue(utils.IsValidIP(address),
                      msg="%r should be accepted as IP address" %
                          (address, ))

    # Malformed in both notations
    for address in ["0", "1.1.1.256", "a:g::1"]:
      self.assertFalse(utils.IsValidIP(address),
                       msg="%r should be rejected as IP address" %
                           (address, ))
2524
2525
class TestGetAddressFamily(unittest.TestCase):
  """Tests for utils.GetAddressFamily.

  """
  def test(self):
    # Address paired with the socket family expected for it
    families = [
      ("127.0.0.1", socket.AF_INET),
      ("10.2.0.127", socket.AF_INET),
      ("::1", socket.AF_INET6),
      ("fe80::a00:27ff:fe08:5048", socket.AF_INET6),
      ]

    for (address, family) in families:
      self.assertEqual(utils.GetAddressFamily(address), family)

    # A string in neither notation must raise instead of returning a family
    self.assertRaises(errors.GenericError, utils.GetAddressFamily, "0")
2534
2535
# Only start the test program when executed as a script, not on import
if __name__ == "__main__":
  testutils.GanetiTestProgram()