Add utils.GetMounts()
[ganeti-local] / test / ganeti.utils_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the utils module"""
23
24 import unittest
25 import os
26 import time
27 import tempfile
28 import os.path
29 import os
30 import stat
31 import signal
32 import socket
33 import shutil
34 import re
35 import select
36 import string
37 import fcntl
38 import OpenSSL
39 import warnings
40 import distutils.version
41 import glob
42 import errno
43
44 import ganeti
45 import testutils
46 from ganeti import constants
47 from ganeti import compat
48 from ganeti import utils
49 from ganeti import errors
50 from ganeti import serializer
51 from ganeti.utils import IsProcessAlive, RunCmd, \
52      RemoveFile, MatchNameComponent, FormatUnit, \
53      ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
54      ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
55      SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
56      TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
57      UnescapeAndSplit, RunParts, PathJoin, HostInfo, ReadOneLineFile
58
59 from ganeti.errors import LockError, UnitParseError, GenericError, \
60      ProgrammerError, OpPrereqError
61
62
class TestIsProcessAlive(unittest.TestCase):
  """Testing case for IsProcessAlive"""

  def testExists(self):
    """Our own process must be reported as alive."""
    mypid = os.getpid()
    self.assert_(IsProcessAlive(mypid),
                 "can't find myself running")

  def testNotExisting(self):
    """A child that has exited and been reaped must not be reported alive."""
    # os.fork() raises OSError when it fails instead of returning a negative
    # value, so only the child/parent cases need to be distinguished here
    pid_non_existing = os.fork()
    if pid_non_existing == 0:
      os._exit(0)
    # Reap the child so that its PID no longer refers to a zombie process
    os.waitpid(pid_non_existing, 0)
    self.assertFalse(IsProcessAlive(pid_non_existing),
                     "nonexisting process detected")
80
81
class TestGetProcStatusPath(unittest.TestCase):
  """Tests for the _GetProcStatusPath helper"""

  def test(self):
    """The PID is part of the path and distinct PIDs give distinct paths."""
    self.assert_("/1234/" in utils._GetProcStatusPath(1234))
    first = utils._GetProcStatusPath(1)
    second = utils._GetProcStatusPath(2)
    self.assertNotEqual(first, second)
87
88
class TestIsProcessHandlingSignal(unittest.TestCase):
  """Tests for IsProcessHandlingSignal and its parsing helpers"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testParseSigsetT(self):
    """Hexadecimal sigset_t masks are decoded into sets of signal numbers."""
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
                     set([2, 10, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
                     set([2, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
                     set([2, 28, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
                          24, 25, 26, 28, 31]))
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))

  def testGetProcStatusField(self):
    """Field values are extracted and stripped of surrounding whitespace."""
    for field in ["SigCgt", "Name", "FDSize"]:
      for value in ["", "0", "cat", "  1234 KB"]:
        pstatus = "\n".join([
          "VmPeak: 999 kB",
          "%s: %s" % (field, value),
          "TracerPid: 0",
          ])
        result = utils._GetProcStatusField(pstatus, field)
        self.assertEqual(result, value.strip())

  def test(self):
    """A signal present in the SigCgt mask is reported as handled."""
    sp = utils.PathJoin(self.tmpdir, "status")

    utils.WriteFile(sp, data="\n".join([
      "Name:   bash",
      "State:  S (sleeping)",
      "SleepAVG:       98%",
      "Pid:    22250",
      "PPid:   10858",
      "TracerPid:      0",
      "SigBlk: 0000000000010000",
      "SigIgn: 0000000000384004",
      "SigCgt: 000000004b813efb",
      "CapEff: 0000000000000000",
      ]))

    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))

  def testNoSigCgt(self):
    """A status file without a SigCgt field raises RuntimeError."""
    sp = utils.PathJoin(self.tmpdir, "status")

    utils.WriteFile(sp, data="\n".join([
      "Name:   bash",
      ]))

    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
                      1234, 10, status_path=sp)

  def testNoSuchFile(self):
    """A missing status file means the signal is not handled."""
    sp = utils.PathJoin(self.tmpdir, "notexist")

    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))

  @staticmethod
  def _TestRealProcess():
    """Check signal handling detection against the current process.

    This changes the SIGUSR1 disposition, so testRealProcess runs it via
    utils.RunInSeparateProcess. An exception is raised on any mismatch.

    @return: True if all checks passed

    """
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is not handled when it should be")

    # An ignored signal has no handler installed, so it must not be reported
    # as handled
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    return True

  def testRealProcess(self):
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
178
179
class TestPidFileFunctions(unittest.TestCase):
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""

  def setUp(self):
    self.dir = tempfile.mkdtemp()
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
    # utils.DaemonPidFileName is replaced module-wide so the PID file helpers
    # use our temporary directory; remember the original so tearDown can put
    # it back and later tests are not affected
    self._orig_dpn = utils.DaemonPidFileName
    utils.DaemonPidFileName = self.f_dpn

  def testPidFileFunctions(self):
    """Test writing, reading and removing a PID file"""
    pid_file = self.f_dpn('test')
    utils.WritePidFile('test')
    self.failUnless(os.path.exists(pid_file),
                    "PID file should have been created")
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, os.getpid())
    self.failUnless(utils.IsProcessAlive(read_pid))
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for missing pid file")
    fh = open(pid_file, "w")
    fh.write("blah\n")
    fh.close()
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for invalid pid file")
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")

  def testKill(self):
    """Test killing a child process that wrote its own PID file"""
    pid_file = self.f_dpn('child')
    r_fd, w_fd = os.pipe()
    new_pid = os.fork()
    if new_pid == 0: # child
      try:
        utils.WritePidFile('child')
        # tell the parent the PID file is in place, then wait to be killed
        os.write(w_fd, 'a')
        signal.pause()
      finally:
        # never fall back into the test runner from the forked child
        os._exit(0)
    # else we are in the parent; drop our copy of the write end so the read
    # below returns EOF instead of blocking if the child dies before writing
    os.close(w_fd)
    try:
      # wait until the child has written the pid file
      os.read(r_fd, 1)
    finally:
      os.close(r_fd)
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, new_pid)
    self.failUnless(utils.IsProcessAlive(new_pid))
    utils.KillProcess(new_pid, waitpid=True)
    self.failIf(utils.IsProcessAlive(new_pid))
    utils.RemovePidFile('child')
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)

  def tearDown(self):
    # Restore first so the module is fixed even if the cleanup below fails
    utils.DaemonPidFileName = self._orig_dpn
    for name in os.listdir(self.dir):
      os.unlink(os.path.join(self.dir, name))
    os.rmdir(self.dir)
236
237
class TestRunCmd(testutils.GanetiTestCase):
  """Testing case for the RunCmd function"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.magic = time.ctime() + " ganeti test"
    self.fname = self._CreateTempFile()

  def testOk(self):
    """Test successful exit code"""
    result = RunCmd("/bin/sh -c 'exit 0'")
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.output, "")

  def testFail(self):
    """Test fail exit code"""
    result = RunCmd("/bin/sh -c 'exit 1'")
    self.assertEqual(result.exit_code, 1)
    self.assertEqual(result.output, "")

  def testStdout(self):
    """Test standard output"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.stdout, self.magic)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testStderr(self):
    """Test standard error"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
    self.assertEqual(result.stderr, self.magic)
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testCombined(self):
    """Test combined output"""
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
    expected = "A" + self.magic + "B" + self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.output, expected)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, expected)

  def testSignal(self):
    """Test signal"""
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
    self.assertEqual(result.signal, 15)
    self.assertEqual(result.output, "")

  def testListRun(self):
    """Test list runs"""
    result = RunCmd(["true"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 1)
    result = RunCmd(["echo", "-n", self.magic])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.stdout, self.magic)

  def testFileEmptyOutput(self):
    """Test file output"""
    result = RunCmd(["true"], output=self.fname)
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertFileContent(self.fname, "")

  def testLang(self):
    """Test locale environment"""
    old_env = os.environ.copy()
    try:
      os.environ["LANG"] = "en_US.UTF-8"
      os.environ["LC_ALL"] = "en_US.UTF-8"
      result = RunCmd(["locale"])
      for line in result.output.splitlines():
        key, value = line.split("=", 1)
        # Ignore these variables, they're overridden by LC_ALL
        if key == "LANG" or key == "LANGUAGE":
          continue
        self.failIf(value and value != "C" and value != '"C"',
            "Variable %s is set to the invalid value '%s'" % (key, value))
    finally:
      # Restore in place: rebinding os.environ to the dict copy would skip
      # unsetenv/putenv, leaving LANG/LC_ALL set in the real process
      # environment and making later os.environ changes stop reaching
      # child processes
      os.environ.clear()
      os.environ.update(old_env)

  def testDefaultCwd(self):
    """Test default working directory"""
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")

  def testCwd(self):
    """Test explicitly specified working directory"""
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
    cwd = os.getcwd()
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)

  def testResetEnv(self):
    """Test environment reset functionality"""
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
345
346
class TestRunParts(unittest.TestCase):
  """Testing case for the RunParts function"""

  def setUp(self):
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")

  def tearDown(self):
    shutil.rmtree(self.rundir)

  def _MakeFile(self, name, data="", executable=False):
    """Create a file inside the run directory.

    @param name: file name relative to the run directory
    @param data: contents to write
    @param executable: if true, set the mode to owner read+execute only
    @return: the full path of the new file

    """
    path = os.path.join(self.rundir, name)
    utils.WriteFile(path, data=data)
    if executable:
      os.chmod(path, stat.S_IREAD | stat.S_IEXEC)
    return path

  def _Run(self):
    """Run every part in the run directory with a reset environment."""
    return RunParts(self.rundir, reset_env=True)

  def testEmpty(self):
    """Test on an empty dir"""
    self.failUnlessEqual(self._Run(), [])

  def testSkipWrongName(self):
    """Test that wrong files are skipped"""
    path = self._MakeFile("00test.dot", executable=True)
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.failUnlessEqual(self._Run(), expected)

  def testSkipNonExec(self):
    """Test that non executable files are skipped"""
    path = self._MakeFile("00test")
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.failUnlessEqual(self._Run(), expected)

  def testError(self):
    """Test error on a broken executable"""
    path = self._MakeFile("00test", executable=True)
    (relname, status, error) = self._Run()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
    self.failUnless(error)

  def testSorted(self):
    """Test executions are sorted"""
    paths = [self._MakeFile(name) for name in ("64test", "00test", "42test")]
    results = self._Run()
    for path in sorted(paths):
      self.failUnlessEqual(os.path.basename(path), results.pop(0)[0])

  def testOk(self):
    """Test correct execution"""
    path = self._MakeFile("00test", data="#!/bin/sh\n\necho -n ciao",
                          executable=True)
    (relname, status, runresult) = self._Run()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.stdout, "ciao")

  def testRunFail(self):
    """Test correct execution, with run failure"""
    path = self._MakeFile("00test", data="#!/bin/sh\n\nexit 1",
                          executable=True)
    (relname, status, runresult) = self._Run()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.exit_code, 1)
    self.failUnless(runresult.failed)

  def testRunMix(self):
    """Test a mix of failing, skipped, broken and working parts"""
    # 1st has errors in execution
    failing = self._MakeFile("00test", data="#!/bin/sh\n\nexit 1",
                             executable=True)
    # 2nd is skipped
    skipped = self._MakeFile("42test")
    # 3rd cannot execute properly
    broken = self._MakeFile("64test", executable=True)
    # 4th execs
    working = self._MakeFile("99test", data="#!/bin/sh\n\necho -n ciao",
                             executable=True)

    results = self._Run()

    (relname, status, runresult) = results[0]
    self.failUnlessEqual(relname, os.path.basename(failing))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.exit_code, 1)
    self.failUnless(runresult.failed)

    (relname, status, runresult) = results[1]
    self.failUnlessEqual(relname, os.path.basename(skipped))
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
    self.failUnlessEqual(runresult, None)

    (relname, status, runresult) = results[2]
    self.failUnlessEqual(relname, os.path.basename(broken))
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
    self.failUnless(runresult)

    (relname, status, runresult) = results[3]
    self.failUnlessEqual(relname, os.path.basename(working))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.output, "ciao")
    self.failUnlessEqual(runresult.exit_code, 0)
    self.failUnless(not runresult.failed)
471
472
class TestStartDaemon(testutils.GanetiTestCase):
  """Tests for utils.StartDaemon"""

  def setUp(self):
    # Initialize the base class, as the other GanetiTestCase subclasses in
    # this file do
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
    self.tmpfile = os.path.join(self.tmpdir, "test")

  def tearDown(self):
    shutil.rmtree(self.tmpdir)
    testutils.GanetiTestCase.tearDown(self)

  def testShell(self):
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testShellOutput(self):
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testNoShellNoOutput(self):
    utils.StartDaemon(["pwd"])

  def testNoShellNoOutputTouch(self):
    testfile = os.path.join(self.tmpdir, "check")
    self.failIf(os.path.exists(testfile))
    utils.StartDaemon(["touch", testfile])
    self._wait(testfile, 60.0, "")

  def testNoShellOutput(self):
    utils.StartDaemon(["pwd"], output=self.tmpfile)
    self._wait(self.tmpfile, 60.0, "/")

  def testNoShellOutputCwd(self):
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
    self._wait(self.tmpfile, 60.0, os.getcwd())

  def testShellEnv(self):
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
                      env={ "GNT_TEST_VAR": "Hello World", })
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testNoShellEnv(self):
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
                      env={ "GNT_TEST_VAR": "Hello World", })
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testOutputFd(self):
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
    try:
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
    finally:
      os.close(fd)
    self._wait(self.tmpfile, 60.0, os.getcwd())

  def testPid(self):
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
    self._wait(self.tmpfile, 60.0, str(pid))

  def testPidFile(self):
    pidfile = os.path.join(self.tmpdir, "pid")

    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
                            output=self.tmpfile)
    try:
      fd = os.open(pidfile, os.O_RDONLY)
      try:
        # Check file is locked
        self.assertRaises(errors.LockError, utils.LockFile, fd)

        pidtext = os.read(fd, 100)
      finally:
        os.close(fd)

      self.assertEqual(int(pidtext.strip()), pid)

      self.assert_(utils.IsProcessAlive(pid))
    finally:
      # No matter what happens, kill daemon
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
      self.failIf(utils.IsProcessAlive(pid))

    self.assertEqual(utils.ReadFile(self.tmpfile), "")

  def _wait(self, path, timeout, expected):
    """Wait until C{path} exists and its stripped content equals C{expected}.

    Fails the test if that does not happen within C{timeout} seconds.

    """
    # Due to the asynchronous nature of daemon processes, polling is necessary.
    # A timeout makes sure the test doesn't hang forever.
    def _CheckFile():
      if not (os.path.isfile(path) and
              utils.ReadFile(path).strip() == expected):
        raise utils.RetryAgain()

    try:
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
    except utils.RetryTimeout:
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
                " didn't write the correct output" % timeout)

  def testError(self):
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"])
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"],
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"],
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))

    # Passing both output and output_fd is a programming error
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
    try:
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
                        ["./does-NOT-EXIST/here/0123456789"],
                        output=self.tmpfile, output_fd=fd)
    finally:
      os.close(fd)
588
589
class TestSetCloseOnExecFlag(unittest.TestCase):
  """Tests for SetCloseOnExecFlag"""

  def setUp(self):
    self.tmpfile = tempfile.TemporaryFile()

  def tearDown(self):
    # Release the descriptor explicitly instead of relying on garbage
    # collection; a TemporaryFile is removed when it is closed
    self.tmpfile.close()

  def testEnable(self):
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
                    fcntl.FD_CLOEXEC)

  def testDisable(self):
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
                fcntl.FD_CLOEXEC)
605
606
class TestSetNonblockFlag(unittest.TestCase):
  """Tests for SetNonblockFlag"""

  def setUp(self):
    self.tmpfile = tempfile.TemporaryFile()

  def tearDown(self):
    # Release the descriptor explicitly instead of relying on garbage
    # collection; a TemporaryFile is removed when it is closed
    self.tmpfile.close()

  def testEnable(self):
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
                    os.O_NONBLOCK)

  def testDisable(self):
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
                os.O_NONBLOCK)
620
621
class TestRemoveFile(unittest.TestCase):
  """Test case for the RemoveFile function"""

  def setUp(self):
    """Create a temp dir and file for each case"""
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
    os.close(fd)

  def tearDown(self):
    # rmtree also removes leftovers (e.g. a symlink) from a failed test,
    # where a plain os.rmdir would fail on the non-empty directory
    shutil.rmtree(self.tmpdir)

  def testIgnoreDirs(self):
    """Test that RemoveFile() ignores directories"""
    self.assertEqual(None, RemoveFile(self.tmpdir))

  def testIgnoreNotExisting(self):
    """Test that RemoveFile() ignores non-existing files"""
    RemoveFile(self.tmpfile)
    RemoveFile(self.tmpfile)

  def testRemoveFile(self):
    """Test that RemoveFile does remove a file"""
    RemoveFile(self.tmpfile)
    if os.path.exists(self.tmpfile):
      self.fail("File '%s' not removed" % self.tmpfile)

  def testRemoveSymlink(self):
    """Test that RemoveFile does remove symlinks"""
    symlink = self.tmpdir + "/symlink"
    # os.path.exists() follows symlinks and therefore reports False for a
    # dangling link even while it still exists; lexists() checks the link
    # itself
    os.symlink("no-such-file", symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
    os.symlink(self.tmpfile, symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
    # Removing the link must leave its target alone
    self.failUnless(os.path.exists(self.tmpfile))
662
663
class TestRename(unittest.TestCase):
  """Test case for RenameFile"""

  def setUp(self):
    """Create a temporary directory holding one empty file"""
    self.tmpdir = tempfile.mkdtemp()
    self.tmpfile = os.path.join(self.tmpdir, "test1")
    fh = open(self.tmpfile, "w")
    fh.close()

  def tearDown(self):
    """Remove temporary directory"""
    shutil.rmtree(self.tmpdir)

  def _Path(self, relpath):
    """Return C{relpath} joined onto the temporary directory."""
    return os.path.join(self.tmpdir, relpath)

  def testSimpleRename1(self):
    """Simple rename 1"""
    target = self._Path("xyz")
    utils.RenameFile(self.tmpfile, target)
    self.assert_(os.path.isfile(target))

  def testSimpleRename2(self):
    """Simple rename 2"""
    target = self._Path("xyz")
    utils.RenameFile(self.tmpfile, target, mkdir=True)
    self.assert_(os.path.isfile(target))

  def testRenameMkdir(self):
    """Rename with mkdir"""
    first = self._Path("test/xyz")
    utils.RenameFile(self.tmpfile, first, mkdir=True)
    self.assert_(os.path.isdir(self._Path("test")))
    self.assert_(os.path.isfile(first))

    # A deeper target requires creating several missing directories
    second = self._Path("test/foo/bar/baz")
    utils.RenameFile(first, second, mkdir=True)
    self.assert_(os.path.isdir(self._Path("test")))
    self.assert_(os.path.isdir(self._Path("test/foo/bar")))
    self.assert_(os.path.isfile(second))
703
704
class TestMatchNameComponent(unittest.TestCase):
  """Test case for the MatchNameComponent function"""

  def testEmptyList(self):
    """Test that there is no match against an empty list"""
    for key in ("", "test"):
      self.failUnlessEqual(MatchNameComponent(key, []), None)

  def testSingleMatch(self):
    """Test that a single match is performed correctly"""
    names = ["test1.example.com", "test2.example.com", "test3.example.com"]
    wanted = names[1]
    for key in ("test2", "test2.example", "test2.example.com"):
      self.failUnlessEqual(MatchNameComponent(key, names), wanted)

  def testMultipleMatches(self):
    """Test that a multiple match is returned as None"""
    names = ["test1.example.com", "test1.example.org", "test1.example.net"]
    for key in ("test1", "test1.example"):
      self.failUnlessEqual(MatchNameComponent(key, names), None)

  def testFullMatch(self):
    """Test that a full match is returned correctly"""
    short_key = "test1"
    full_key = "test1.example"
    names = [full_key, full_key + ".com"]
    self.failUnlessEqual(MatchNameComponent(short_key, names), None)
    self.failUnlessEqual(MatchNameComponent(full_key, names), full_key)

  def testCaseInsensitivePartialMatch(self):
    """Test for the case_insensitive keyword"""
    names = ["test1.example.com", "test2.example.net"]
    for key in ("test2", "Test2", "teSt2", "TeSt2"):
      self.assertEqual(MatchNameComponent(key, names, case_sensitive=False),
                       "test2.example.net")

  def testCaseInsensitiveFullMatch(self):
    """Test case-insensitive matching when names overlap"""
    names = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
    checks = [
      # "Ts1" is a prefix of both ts1 names and thus ambiguous, but the full
      # name "ts1.ex" matches regardless of case
      ("Ts1", None),
      ("Ts1.ex", "ts1.ex"),
      ("ts1.ex", "ts1.ex"),
      # The two ts2 names differ only in case, so only an exact-case match
      # is unambiguous
      ("ts2.ex", "ts2.ex"),
      ("Ts2.ex", "Ts2.ex"),
      ("TS2.ex", None),
      ]
    for key, expected in checks:
      self.assertEqual(MatchNameComponent(key, names, case_sensitive=False),
                       expected)
763
764
class TestReadFile(testutils.GanetiTestCase):
  """Tests for utils.ReadFile"""

  def _CheckMd5(self, data, expected):
    """Assert that the MD5 hex digest of C{data} equals C{expected}."""
    digest = compat.md5_hash()
    digest.update(data)
    self.assertEqual(digest.hexdigest(), expected)

  def testReadAll(self):
    content = utils.ReadFile(self._TestDataFilename("cert1.pem"))
    self.assertEqual(len(content), 814)
    self._CheckMd5(content, "a491efb3efe56a0535f924d5f8680fd4")

  def testReadSize(self):
    # Only the first 100 bytes are requested
    content = utils.ReadFile(self._TestDataFilename("cert1.pem"), size=100)
    self.assertEqual(len(content), 100)
    self._CheckMd5(content, "893772354e4e690b9efd073eed433ce7")

  def testError(self):
    missing = "/dev/null/does-not-exist"
    self.assertRaises(EnvironmentError, utils.ReadFile, missing)
787
788
class TestReadOneLineFile(testutils.GanetiTestCase):
  """Tests for ReadOneLineFile in strict and non-strict mode."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

  def _CertFile(self):
    """Return the path of the multi-line test certificate."""
    return self._TestDataFilename("cert1.pem")

  def testDefault(self):
    first = ReadOneLineFile(self._CertFile())
    self.assertEqual(len(first), 27)
    self.assertEqual(first, "-----BEGIN CERTIFICATE-----")

  def testNotStrict(self):
    first = ReadOneLineFile(self._CertFile(), strict=False)
    self.assertEqual(len(first), 27)
    self.assertEqual(first, "-----BEGIN CERTIFICATE-----")

  def testStrictFailure(self):
    # The certificate spans several lines, which strict mode rejects
    self.assertRaises(errors.GenericError, ReadOneLineFile,
                      self._CertFile(), strict=True)

  def testLongLine(self):
    content = 1024 * "Hello World! "
    path = self._CreateTempFile()
    utils.WriteFile(path, data=content)
    strict = ReadOneLineFile(path, strict=True)
    lax = ReadOneLineFile(path, strict=False)
    self.assertEqual(content, strict)
    self.assertEqual(content, lax)

  def testNewline(self):
    path = self._CreateTempFile()
    line = "myline"
    for terminator in ["", "\n", "\r\n"]:
      utils.WriteFile(path, data=line + terminator)
      # The line terminator is stripped in both modes
      self.assertEqual(line, ReadOneLineFile(path, strict=False))
      self.assertEqual(line, ReadOneLineFile(path, strict=True))

  def testWhitespaceAndMultipleLines(self):
    path = self._CreateTempFile()
    for terminator in ["", "\n", "\r\n"]:
      for blank in [" ", "\t", "\t\t  \t", "\t "]:
        content = 1024 * ("Foo bar baz %s%s" % (blank, terminator))
        utils.WriteFile(path, data=content)
        lax = ReadOneLineFile(path, strict=False)
        if not terminator:
          # One very long line: both modes return it verbatim
          strict = ReadOneLineFile(path, strict=True)
          self.assertEqual(content, strict)
          self.assertEqual(content, lax)
          continue
        # Many lines: strict mode fails, lax mode returns the first line
        # (trailing whitespace included, terminator excluded)
        self.assert_(set("\r\n") & set(content))
        self.assertRaises(errors.GenericError, ReadOneLineFile,
                          path, strict=True)
        first_len = len("Foo bar baz ") + len(blank)
        self.assertEqual(len(lax), first_len)
        self.assertEqual(lax, content[:first_len])
        self.assertFalse(set("\r\n") & set(lax))

  def testEmptylines(self):
    path = self._CreateTempFile()
    line = "myline"
    for nl in ["\n", "\r\n"]:
      for other in ["", "otherline"]:
        content = nl + nl + line + nl + other + nl
        utils.WriteFile(path, data=content)
        self.assert_(set("\r\n") & set(content))
        # Empty lines around the content are skipped
        self.assertEqual(line, ReadOneLineFile(path, strict=False))
        if not other:
          self.assertEqual(line, ReadOneLineFile(path, strict=True))
        else:
          # A second non-empty line makes strict mode fail
          self.assertRaises(errors.GenericError, ReadOneLineFile,
                            path, strict=True)
864
865
class TestTimestampForFilename(unittest.TestCase):
  def test(self):
    # The generated timestamp must contain neither "." nor ":"
    for char in (".", ":"):
      self.assert_(char not in utils.TimestampForFilename())
870
871
class TestCreateBackup(testutils.GanetiTestCase):
  """Tests for utils.CreateBackup."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  @staticmethod
  def _CountMatches(prefix):
    """Count the paths starting with C{prefix} (original plus backups)."""
    return len(glob.glob("%s*" % prefix))

  def testEmpty(self):
    filename = utils.PathJoin(self.tmpdir, "config.data")
    utils.WriteFile(filename, data="")
    backup = utils.CreateBackup(filename)
    self.assertFileContent(backup, "")
    self.assertEqual(self._CountMatches(filename), 2)
    for expected in (3, 4):
      utils.CreateBackup(filename)
      self.assertEqual(self._CountMatches(filename), expected)

    # Backing up a FIFO is refused
    fifoname = utils.PathJoin(self.tmpdir, "fifo")
    os.mkfifo(fifoname)
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)

  def testContent(self):
    created = 0
    for base in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
      for repeat in [1, 2, 10, 127]:
        payload = base * repeat

        filename = utils.PathJoin(self.tmpdir, "test.data_")
        utils.WriteFile(filename, data=payload)
        self.assertFileContent(filename, payload)

        for _ in range(3):
          backup = utils.CreateBackup(filename)
          created += 1
          self.assertFileContent(backup, payload)
          # The original file plus every backup made so far
          self.assertEqual(self._CountMatches(filename), 1 + created)
913
914
class TestFormatUnit(unittest.TestCase):
  """Test case for the FormatUnit function"""

  def _CheckAll(self, cases):
    """Assert C{FormatUnit(value, unit) == expected} for each triple."""
    for (value, unit, expected) in cases:
      self.assertEqual(FormatUnit(value, unit), expected)

  def testMiB(self):
    # Input values are in MiB (1 formats as "1M")
    self._CheckAll([
      (1, 'h', '1M'),
      (100, 'h', '100M'),
      (1023, 'h', '1023M'),

      (1, 'm', '1'),
      (100, 'm', '100'),
      (1023, 'm', '1023'),

      (1024, 'm', '1024'),
      (1536, 'm', '1536'),
      (17133, 'm', '17133'),
      (1024 * 1024 - 1, 'm', '1048575'),
      ])

  def testGiB(self):
    self._CheckAll([
      (1024, 'h', '1.0G'),
      (1536, 'h', '1.5G'),
      (17133, 'h', '16.7G'),
      (1024 * 1024 - 1, 'h', '1024.0G'),

      (1024, 'g', '1.0'),
      (1536, 'g', '1.5'),
      (17133, 'g', '16.7'),
      (1024 * 1024 - 1, 'g', '1024.0'),

      (1024 * 1024, 'g', '1024.0'),
      (5120 * 1024, 'g', '5120.0'),
      (29829 * 1024, 'g', '29829.0'),
      ])

  def testTiB(self):
    self._CheckAll([
      (1024 * 1024, 'h', '1.0T'),
      (5120 * 1024, 'h', '5.0T'),
      (29829 * 1024, 'h', '29.1T'),

      (1024 * 1024, 't', '1.0'),
      (5120 * 1024, 't', '5.0'),
      (29829 * 1024, 't', '29.1'),
      ])

955
956 class TestParseUnit(unittest.TestCase):
957   """Test case for the ParseUnit function"""
958
959   SCALES = (('', 1),
960             ('M', 1), ('G', 1024), ('T', 1024 * 1024),
961             ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
962             ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
963
964   def testRounding(self):
965     self.assertEqual(ParseUnit('0'), 0)
966     self.assertEqual(ParseUnit('1'), 4)
967     self.assertEqual(ParseUnit('2'), 4)
968     self.assertEqual(ParseUnit('3'), 4)
969
970     self.assertEqual(ParseUnit('124'), 124)
971     self.assertEqual(ParseUnit('125'), 128)
972     self.assertEqual(ParseUnit('126'), 128)
973     self.assertEqual(ParseUnit('127'), 128)
974     self.assertEqual(ParseUnit('128'), 128)
975     self.assertEqual(ParseUnit('129'), 132)
976     self.assertEqual(ParseUnit('130'), 132)
977
978   def testFloating(self):
979     self.assertEqual(ParseUnit('0'), 0)
980     self.assertEqual(ParseUnit('0.5'), 4)
981     self.assertEqual(ParseUnit('1.75'), 4)
982     self.assertEqual(ParseUnit('1.99'), 4)
983     self.assertEqual(ParseUnit('2.00'), 4)
984     self.assertEqual(ParseUnit('2.01'), 4)
985     self.assertEqual(ParseUnit('3.99'), 4)
986     self.assertEqual(ParseUnit('4.00'), 4)
987     self.assertEqual(ParseUnit('4.01'), 8)
988     self.assertEqual(ParseUnit('1.5G'), 1536)
989     self.assertEqual(ParseUnit('1.8G'), 1844)
990     self.assertEqual(ParseUnit('8.28T'), 8682212)
991
992   def testSuffixes(self):
993     for sep in ('', ' ', '   ', "\t", "\t "):
994       for suffix, scale in TestParseUnit.SCALES:
995         for func in (lambda x: x, str.lower, str.upper):
996           self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
997                            1024 * scale)
998
999   def testInvalidInput(self):
1000     for sep in ('-', '_', ',', 'a'):
1001       for suffix, _ in TestParseUnit.SCALES:
1002         self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)
1003
1004     for suffix, _ in TestParseUnit.SCALES:
1005       self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
1006
1007
class TestSshKeys(testutils.GanetiTestCase):
  """Test case for the AddAuthorizedKey function"""

  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')

  # Content written by setUp: one key per line
  _INITIAL = "%s\n%s\n" % (KEY_A, KEY_B)

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    handle = open(self.tmpname, 'w')
    try:
      handle.write(self._INITIAL)
    finally:
      handle.close()

  def testAddingNewKey(self):
    AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
    # New keys are appended after the existing ones
    self.assertFileContent(self.tmpname, self._INITIAL +
                           "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")

  def testAddingAlmostButNotCompletelyTheSameKey(self):
    AddAuthorizedKey(self.tmpname,
                     'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
    # Same key data but different comment counts as a new key
    self.assertFileContent(self.tmpname, self._INITIAL +
                           "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")

  def testAddingExistingKeyWithSomeMoreSpaces(self):
    AddAuthorizedKey(self.tmpname,
                     'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
    # Extra whitespace does not make the key different; file is unchanged
    self.assertFileContent(self.tmpname, self._INITIAL)

  def testRemovingExistingKeyWithSomeMoreSpaces(self):
    RemoveAuthorizedKey(self.tmpname,
                        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
    self.assertFileContent(self.tmpname, "%s\n" % self.KEY_B)

  def testRemovingNonExistingKey(self):
    RemoveAuthorizedKey(self.tmpname,
                        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
    self.assertFileContent(self.tmpname, self._INITIAL)
1069
1070
1071 class TestEtcHosts(testutils.GanetiTestCase):
1072   """Test functions modifying /etc/hosts"""
1073
1074   def setUp(self):
1075     testutils.GanetiTestCase.setUp(self)
1076     self.tmpname = self._CreateTempFile()
1077     handle = open(self.tmpname, 'w')
1078     try:
1079       handle.write('# This is a test file for /etc/hosts\n')
1080       handle.write('127.0.0.1\tlocalhost\n')
1081       handle.write('192.168.1.1 router gw\n')
1082     finally:
1083       handle.close()
1084
1085   def testSettingNewIp(self):
1086     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
1087
1088     self.assertFileContent(self.tmpname,
1089       "# This is a test file for /etc/hosts\n"
1090       "127.0.0.1\tlocalhost\n"
1091       "192.168.1.1 router gw\n"
1092       "1.2.3.4\tmyhost.domain.tld myhost\n")
1093     self.assertFileMode(self.tmpname, 0644)
1094
1095   def testSettingExistingIp(self):
1096     SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
1097                      ['myhost'])
1098
1099     self.assertFileContent(self.tmpname,
1100       "# This is a test file for /etc/hosts\n"
1101       "127.0.0.1\tlocalhost\n"
1102       "192.168.1.1\tmyhost.domain.tld myhost\n")
1103     self.assertFileMode(self.tmpname, 0644)
1104
1105   def testSettingDuplicateName(self):
1106     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
1107
1108     self.assertFileContent(self.tmpname,
1109       "# This is a test file for /etc/hosts\n"
1110       "127.0.0.1\tlocalhost\n"
1111       "192.168.1.1 router gw\n"
1112       "1.2.3.4\tmyhost\n")
1113     self.assertFileMode(self.tmpname, 0644)
1114
1115   def testRemovingExistingHost(self):
1116     RemoveEtcHostsEntry(self.tmpname, 'router')
1117
1118     self.assertFileContent(self.tmpname,
1119       "# This is a test file for /etc/hosts\n"
1120       "127.0.0.1\tlocalhost\n"
1121       "192.168.1.1 gw\n")
1122     self.assertFileMode(self.tmpname, 0644)
1123
1124   def testRemovingSingleExistingHost(self):
1125     RemoveEtcHostsEntry(self.tmpname, 'localhost')
1126
1127     self.assertFileContent(self.tmpname,
1128       "# This is a test file for /etc/hosts\n"
1129       "192.168.1.1 router gw\n")
1130     self.assertFileMode(self.tmpname, 0644)
1131
1132   def testRemovingNonExistingHost(self):
1133     RemoveEtcHostsEntry(self.tmpname, 'myhost')
1134
1135     self.assertFileContent(self.tmpname,
1136       "# This is a test file for /etc/hosts\n"
1137       "127.0.0.1\tlocalhost\n"
1138       "192.168.1.1 router gw\n")
1139     self.assertFileMode(self.tmpname, 0644)
1140
1141   def testRemovingAlias(self):
1142     RemoveEtcHostsEntry(self.tmpname, 'gw')
1143
1144     self.assertFileContent(self.tmpname,
1145       "# This is a test file for /etc/hosts\n"
1146       "127.0.0.1\tlocalhost\n"
1147       "192.168.1.1 router\n")
1148     self.assertFileMode(self.tmpname, 0644)
1149
1150
1151 class TestGetMounts(unittest.TestCase):
1152   """Test case for GetMounts()."""
1153
1154   TESTDATA = (
1155     "rootfs /     rootfs rw 0 0\n"
1156     "none   /sys  sysfs  rw,nosuid,nodev,noexec,relatime 0 0\n"
1157     "none   /proc proc   rw,nosuid,nodev,noexec,relatime 0 0\n")
1158
1159   def setUp(self):
1160     self.tmpfile = tempfile.NamedTemporaryFile()
1161     utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1162
1163   def testGetMounts(self):
1164     self.assertEqual(utils.GetMounts(filename=self.tmpfile.name),
1165       [
1166         ("rootfs", "/", "rootfs", "rw"),
1167         ("none", "/sys", "sysfs", "rw,nosuid,nodev,noexec,relatime"),
1168         ("none", "/proc", "proc", "rw,nosuid,nodev,noexec,relatime"),
1169       ])
1170
1171
class TestShellQuoting(unittest.TestCase):
  """Test case for shell quoting functions"""

  def testShellQuote(self):
    # (raw string, expected quoted form)
    cases = [
      ('abc', "abc"),
      ('ab"c', "'ab\"c'"),
      ("a'bc", "'a'\\''bc'"),
      ("a b c", "'a b c'"),
      ("a b\\ c", "'a b\\ c'"),
      ]
    for (raw, quoted) in cases:
      self.assertEqual(ShellQuote(raw), quoted)

  def testShellQuoteArgs(self):
    # (argument list, expected command line)
    cases = [
      (['a', 'b', 'c'], "a b c"),
      (['a', 'b"', 'c'], "a 'b\"' c"),
      (['a', 'b\'', 'c'], "a 'b'\\\''' c"),
      ]
    for (args, cmdline) in cases:
      self.assertEqual(ShellQuoteArgs(args), cmdline)
1186
1187
1188 class _BaseTcpPingTest:
1189   """Base class for TcpPing tests against listen(2)ing port"""
1190   family = None
1191   address = None
1192
1193   def setUp(self):
1194     self.listener = socket.socket(self.family, socket.SOCK_STREAM)
1195     self.listener.bind((self.address, 0))
1196     self.listenerport = self.listener.getsockname()[1]
1197     self.listener.listen(1)
1198
1199   def tearDown(self):
1200     self.listener.shutdown(socket.SHUT_RDWR)
1201     del self.listener
1202     del self.listenerport
1203
1204   def testTcpPingToLocalHostAccept(self):
1205     self.assert_(TcpPing(self.address,
1206                          self.listenerport,
1207                          timeout=constants.TCP_PING_TIMEOUT,
1208                          live_port_needed=True,
1209                          source=self.address,
1210                          ),
1211                  "failed to connect to test listener")
1212
1213     self.assert_(TcpPing(self.address, self.listenerport,
1214                          timeout=constants.TCP_PING_TIMEOUT,
1215                          live_port_needed=True),
1216                  "failed to connect to test listener (no source)")
1217
1218
class TestIP4TcpPing(unittest.TestCase, _BaseTcpPingTest):
  """IPv4 TcpPing tests against a listen(2)ing port"""
  family = socket.AF_INET
  address = constants.IP4_ADDRESS_LOCALHOST

  def setUp(self):
    for base in (unittest.TestCase, _BaseTcpPingTest):
      base.setUp(self)

  def tearDown(self):
    for base in (unittest.TestCase, _BaseTcpPingTest):
      base.tearDown(self)
1231
1232
class TestIP6TcpPing(unittest.TestCase, _BaseTcpPingTest):
  """IPv6 TcpPing tests against a listen(2)ing port"""
  family = socket.AF_INET6
  address = constants.IP6_ADDRESS_LOCALHOST

  def setUp(self):
    for base in (unittest.TestCase, _BaseTcpPingTest):
      base.setUp(self)

  def tearDown(self):
    for base in (unittest.TestCase, _BaseTcpPingTest):
      base.tearDown(self)
1245
1246
1247 class _BaseTcpPingDeafTest:
1248   """Base class for TcpPing tests against non listen(2)ing port"""
1249   family = None
1250   address = None
1251
1252   def setUp(self):
1253     self.deaflistener = socket.socket(self.family, socket.SOCK_STREAM)
1254     self.deaflistener.bind((self.address, 0))
1255     self.deaflistenerport = self.deaflistener.getsockname()[1]
1256
1257   def tearDown(self):
1258     del self.deaflistener
1259     del self.deaflistenerport
1260
1261   def testTcpPingToLocalHostAcceptDeaf(self):
1262     self.assertFalse(TcpPing(self.address,
1263                              self.deaflistenerport,
1264                              timeout=constants.TCP_PING_TIMEOUT,
1265                              live_port_needed=True,
1266                              source=self.address,
1267                              ), # need successful connect(2)
1268                      "successfully connected to deaf listener")
1269
1270     self.assertFalse(TcpPing(self.address,
1271                              self.deaflistenerport,
1272                              timeout=constants.TCP_PING_TIMEOUT,
1273                              live_port_needed=True,
1274                              ), # need successful connect(2)
1275                      "successfully connected to deaf listener (no source)")
1276
1277   def testTcpPingToLocalHostNoAccept(self):
1278     self.assert_(TcpPing(self.address,
1279                          self.deaflistenerport,
1280                          timeout=constants.TCP_PING_TIMEOUT,
1281                          live_port_needed=False,
1282                          source=self.address,
1283                          ), # ECONNREFUSED is OK
1284                  "failed to ping alive host on deaf port")
1285
1286     self.assert_(TcpPing(self.address,
1287                          self.deaflistenerport,
1288                          timeout=constants.TCP_PING_TIMEOUT,
1289                          live_port_needed=False,
1290                          ), # ECONNREFUSED is OK
1291                  "failed to ping alive host on deaf port (no source)")
1292
1293
class TestIP4TcpPingDeaf(unittest.TestCase, _BaseTcpPingDeafTest):
  """Testcase for IPv4 TCP version of ping - against non listen(2)ing port"""
  family = socket.AF_INET
  address = constants.IP4_ADDRESS_LOCALHOST

  # Delegate socket handling to the shared base class (like the IPv6
  # variant does) instead of keeping a second copy of it here
  def setUp(self):
    unittest.TestCase.setUp(self)
    _BaseTcpPingDeafTest.setUp(self)

  def tearDown(self):
    unittest.TestCase.tearDown(self)
    _BaseTcpPingDeafTest.tearDown(self)
1307
1308
class TestIP6TcpPingDeaf(unittest.TestCase, _BaseTcpPingDeafTest):
  """IPv6 TcpPing tests against a bound but non listen(2)ing port"""
  family = socket.AF_INET6
  address = constants.IP6_ADDRESS_LOCALHOST

  def setUp(self):
    for base in (unittest.TestCase, _BaseTcpPingDeafTest):
      base.setUp(self)

  def tearDown(self):
    for base in (unittest.TestCase, _BaseTcpPingDeafTest):
      base.tearDown(self)
1321
1322
class TestOwnIpAddress(unittest.TestCase):
  """Tests for OwnIpAddress with an owned and a surely-unowned address"""

  def testOwnLoopback(self):
    """check having the loopback ip"""
    loopback = constants.IP4_ADDRESS_LOCALHOST
    self.failUnless(OwnIpAddress(loopback),
                    "Should own the loopback address")

  def testNowOwnAddress(self):
    """check that I don't own an address"""
    # 192.0.2.0/24 is reserved for test/documentation use by RFC 5735, so
    # no address from it should be configured locally; if this ever
    # fails, the test should be extended to try several addresses
    unowned = "192.0.2.1"
    self.failIf(OwnIpAddress(unowned),
                "Should not own IP address %s" % unowned)
1339
1340
def _GetSocketCredentials(path):
  """Connect to a Unix socket and return the peer's credentials.

  @type path: string
  @param path: filesystem path of the listening Unix socket
  @return: whatever C{utils.GetSocketCredentials} reports for the
      connected socket; the socket is always closed before returning

  """
  client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  try:
    # Don't hang forever if nobody is listening/accepting
    client.settimeout(10)
    client.connect(path)
    creds = utils.GetSocketCredentials(client)
  finally:
    client.close()
  return creds
1352
1353
class TestGetSocketCredentials(unittest.TestCase):
  """Tests utils.GetSocketCredentials via a forked client process.

  The child connects to a Unix socket owned by this process and reports
  the peer credentials it sees; they must match this (parent) process.

  """
  def setUp(self):
    # Listening Unix socket inside a private temporary directory
    self.tmpdir = tempfile.mkdtemp()
    self.sockpath = utils.PathJoin(self.tmpdir, "sock")

    self.listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    # Makes accept() fail after 10 seconds instead of hanging forever if
    # the child never connects
    self.listener.settimeout(10)
    self.listener.bind(self.sockpath)
    self.listener.listen(1)

  def tearDown(self):
    self.listener.shutdown(socket.SHUT_RDWR)
    self.listener.close()
    shutil.rmtree(self.tmpdir)

  def test(self):
    # Pipe carrying the child's JSON-encoded result back to the parent
    (c2pr, c2pw) = os.pipe()

    # Start child process
    child = os.fork()
    if child == 0:
      # Child: connect, send the peer credentials through the pipe and
      # leave via os._exit so it never returns into the test runner
      try:
        data = serializer.DumpJson(_GetSocketCredentials(self.sockpath))

        os.write(c2pw, data)
        os.close(c2pw)

        os._exit(0)
      finally:
        # Only reached if something above raised: report failure through
        # the exit status
        os._exit(1)

    # Parent: drop its copy of the write end so the read below sees EOF
    # if the child exits without writing
    os.close(c2pw)

    # Wait for one connection
    (conn, _) = self.listener.accept()
    # Block until the child sends data or closes its end of the connection
    conn.recv(1)
    conn.close()

    # Wait for result
    result = os.read(c2pr, 4096)
    os.close(c2pr)

    # Check child's exit code
    (_, status) = os.waitpid(child, 0)
    self.assertFalse(os.WIFSIGNALED(status))
    self.assertEqual(os.WEXITSTATUS(status), 0)

    # Check result: the child's peer is the owner of the listening socket,
    # i.e. this process
    (pid, uid, gid) = serializer.LoadJson(result)
    self.assertEqual(pid, os.getpid())
    self.assertEqual(uid, os.getuid())
    self.assertEqual(gid, os.getgid())
1406
1407
class TestListVisibleFiles(unittest.TestCase):
  """Test case for ListVisibleFiles"""

  def setUp(self):
    self.path = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.path)

  def _CreateFiles(self, files):
    """Create each of C{files} inside the temporary directory."""
    for fname in files:
      utils.WriteFile(os.path.join(self.path, fname), data="test")

  def _test(self, files, expected):
    """Create C{files} and check that exactly C{expected} are listed."""
    self._CreateFiles(files)
    self.assertEqual(set(ListVisibleFiles(self.path)), set(expected))

  def testAllVisible(self):
    names = ["a", "b", "c"]
    self._test(names, names)

  def testNoneVisible(self):
    # Dot-files are hidden
    self._test([".a", ".b", ".c"], [])

  def testSomeVisible(self):
    self._test(["a", "b", ".c"], ["a", "b"])

  def testNonAbsolutePath(self):
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")

  def testNonNormalizedPath(self):
    # Paths containing ".." components are rejected
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
                          "/bin/../tmp")
1447
1448
1449 class TestNewUUID(unittest.TestCase):
1450   """Test case for NewUUID"""
1451
1452   _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
1453                         '[a-f0-9]{4}-[a-f0-9]{12}$')
1454
1455   def runTest(self):
1456     self.failUnless(self._re_uuid.match(utils.NewUUID()))
1457
1458
class TestUniqueSequence(unittest.TestCase):
  """Test case for UniqueSequence"""

  def _test(self, seq, expected):
    """Assert that de-duplicating C{seq} yields C{expected}."""
    self.assertEqual(utils.UniqueSequence(seq), expected)

  def runTest(self):
    cases = [
      # Ordered input
      ([1, 2, 3], [1, 2, 3]),
      ([1, 1, 2, 2, 3, 3], [1, 2, 3]),
      ([1, 2, 2, 3], [1, 2, 3]),
      ([1, 2, 3, 3], [1, 2, 3]),

      # Unordered input
      ([1, 2, 3, 1, 2, 3], [1, 2, 3]),
      ([1, 1, 2, 3, 3, 1, 2], [1, 2, 3]),

      # Strings
      (["a", "a"], ["a"]),
      (["a", "b"], ["a", "b"]),
      (["a", "b", "a"], ["a", "b"]),
      ]
    for (seq, expected) in cases:
      self._test(seq, expected)
1480
1481
class TestFirstFree(unittest.TestCase):
  """Test case for the FirstFree function"""

  def test(self):
    """Test FirstFree"""
    # (sequence, keyword arguments, expected first free slot)
    cases = [
      ([0, 1, 3], {}, 2),
      ([], {}, None),
      ([3, 4, 6], {}, 0),
      ([3, 4, 6], {"base": 3}, 5),
      ]
    for (seq, kwargs, expected) in cases:
      self.failUnlessEqual(FirstFree(seq, **kwargs), expected)

    # Values below the given base are a programming error
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1492
1493
class TestTailFile(testutils.GanetiTestCase):
  """Test case for the TailFile function"""

  def _WriteTemp(self, content):
    """Create a new temporary file holding C{content}; return its name."""
    fname = self._CreateTempFile()
    handle = open(fname, "w")
    handle.write(content)
    handle.close()
    return fname

  def testEmpty(self):
    fname = self._CreateTempFile()
    self.failUnlessEqual(TailFile(fname), [])
    self.failUnlessEqual(TailFile(fname, lines=25), [])

  def testAllLines(self):
    expected_lines = ["test %d" % idx for idx in range(30)]
    for count in range(30):
      head = expected_lines[:count]
      # Every line, including the last one, is newline-terminated
      fname = self._WriteTemp("".join(["%s\n" % line for line in head]))
      self.failUnlessEqual(TailFile(fname, lines=count), head)

  def testPartialLines(self):
    expected_lines = ["test %d" % idx for idx in range(30)]
    fname = self._WriteTemp("\n".join(expected_lines) + "\n")
    for count in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=count),
                           expected_lines[-count:])

  def testBigFile(self):
    expected_lines = ["test %d" % idx for idx in range(30)]
    # A 1 MiB line precedes the lines we ask for
    fname = self._WriteTemp("X" * 1048576 + "\n" +
                            "\n".join(expected_lines) + "\n")
    for count in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=count),
                           expected_lines[-count:])
1534
1535
1536 class _BaseFileLockTest:
1537   """Test case for the FileLock class"""
1538
1539   def testSharedNonblocking(self):
1540     self.lock.Shared(blocking=False)
1541     self.lock.Close()
1542
1543   def testExclusiveNonblocking(self):
1544     self.lock.Exclusive(blocking=False)
1545     self.lock.Close()
1546
1547   def testUnlockNonblocking(self):
1548     self.lock.Unlock(blocking=False)
1549     self.lock.Close()
1550
1551   def testSharedBlocking(self):
1552     self.lock.Shared(blocking=True)
1553     self.lock.Close()
1554
1555   def testExclusiveBlocking(self):
1556     self.lock.Exclusive(blocking=True)
1557     self.lock.Close()
1558
1559   def testUnlockBlocking(self):
1560     self.lock.Unlock(blocking=True)
1561     self.lock.Close()
1562
1563   def testSharedExclusiveUnlock(self):
1564     self.lock.Shared(blocking=False)
1565     self.lock.Exclusive(blocking=False)
1566     self.lock.Unlock(blocking=False)
1567     self.lock.Close()
1568
1569   def testExclusiveSharedUnlock(self):
1570     self.lock.Exclusive(blocking=False)
1571     self.lock.Shared(blocking=False)
1572     self.lock.Unlock(blocking=False)
1573     self.lock.Close()
1574
1575   def testSimpleTimeout(self):
1576     # These will succeed on the first attempt, hence a short timeout
1577     self.lock.Shared(blocking=True, timeout=10.0)
1578     self.lock.Exclusive(blocking=False, timeout=10.0)
1579     self.lock.Unlock(blocking=True, timeout=10.0)
1580     self.lock.Close()
1581
1582   @staticmethod
1583   def _TryLockInner(filename, shared, blocking):
1584     lock = utils.FileLock.Open(filename)
1585
1586     if shared:
1587       fn = lock.Shared
1588     else:
1589       fn = lock.Exclusive
1590
1591     try:
1592       # The timeout doesn't really matter as the parent process waits for us to
1593       # finish anyway.
1594       fn(blocking=blocking, timeout=0.01)
1595     except errors.LockError, err:
1596       return False
1597
1598     return True
1599
1600   def _TryLock(self, *args):
1601     return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1602                                       *args)
1603
1604   def testTimeout(self):
1605     for blocking in [True, False]:
1606       self.lock.Exclusive(blocking=True)
1607       self.failIf(self._TryLock(False, blocking))
1608       self.failIf(self._TryLock(True, blocking))
1609
1610       self.lock.Shared(blocking=True)
1611       self.assert_(self._TryLock(True, blocking))
1612       self.failIf(self._TryLock(False, blocking))
1613
1614   def testCloseShared(self):
1615     self.lock.Close()
1616     self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1617
1618   def testCloseExclusive(self):
1619     self.lock.Close()
1620     self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1621
1622   def testCloseUnlock(self):
1623     self.lock.Close()
1624     self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1625
1626
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
  """Run the FileLock tests on a lock opened by file name."""

  TESTDATA = "Hello World\n" * 10

  def _AssertDataIntact(self):
    """Check that the locked file still holds exactly L{TESTDATA}."""
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpfile = tempfile.NamedTemporaryFile()
    path = self.tmpfile.name
    utils.WriteFile(path, data=self.TESTDATA)
    self.lock = utils.FileLock.Open(path)

    # Opening the lock must not truncate the file
    self._AssertDataIntact()

  def tearDown(self):
    # Locking operations must leave the content untouched as well
    self._AssertDataIntact()

    testutils.GanetiTestCase.tearDown(self)
1644
1645
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
  """Run the FileLock tests on a lock wrapping an already open file."""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()
    path = self.tmpfile.name
    handle = open(path, "w")
    self.lock = utils.FileLock(handle, path)
1650
1651
class TestTimeFunctions(unittest.TestCase):
  """Test case for time functions"""

  def runTest(self):
    # SplitTime: float seconds -> (seconds, microseconds)
    split_cases = [
      (1, (1, 0)),
      (1.5, (1, 500000)),
      (1218448917.4809151, (1218448917, 480915)),
      (123.48012, (123, 480120)),
      (123.9996, (123, 999600)),
      (123.9995, (123, 999500)),
      (123.9994, (123, 999400)),
      (123.999999999, (123, 999999)),
      ]
    for (value, expected) in split_cases:
      self.assertEqual(utils.SplitTime(value), expected)

    self.assertRaises(AssertionError, utils.SplitTime, -1)

    # MergeTime: (seconds, microseconds) -> float seconds
    exact_merges = [
      ((1, 0), 1.0),
      ((1, 500000), 1.5),
      ((1218448917, 500000), 1218448917.5),
      ]
    for (parts, expected) in exact_merges:
      self.assertEqual(utils.MergeTime(parts), expected)

    # Not exactly representable as floats, compare rounded values
    rounded_merges = [
      ((1218448917, 481000), 1218448917.481),
      ((1, 801000), 1.801),
      ]
    for (parts, expected) in rounded_merges:
      self.assertEqual(round(utils.MergeTime(parts), 3), expected)

    # Out-of-range components are rejected
    for invalid in [(0, -1), (0, 1000000), (0, 9999999), (-1, 0), (-9999, 0)]:
      self.assertRaises(AssertionError, utils.MergeTime, invalid)
1680
1681
class FieldSetTestCase(unittest.TestCase):
  """Test case for FieldSets"""

  def testSimpleMatch(self):
    fields = utils.FieldSet("a", "b", "c", "def")
    self.failUnless(fields.Matches("a"))
    # Only whole names match, neither substrings nor prefixes
    self.failIf(fields.Matches("d"), "Substring matched")
    self.failIf(fields.Matches("defghi"), "Prefix string matched")
    for known in (["b", "c"], ["a", "b", "c", "def"]):
      self.failIf(fields.NonMatching(known))
    self.failUnless(fields.NonMatching(["a", "d"]))

  def testRegexMatch(self):
    fields = utils.FieldSet("a", "b([0-9]+)", "c")
    for name in ("b1", "b99"):
      self.failUnless(fields.Matches(name))
    self.failIf(fields.Matches("b/1"))
    self.failIf(fields.NonMatching(["b12", "c"]))
    self.failUnless(fields.NonMatching(["a", "1"]))
1701
class TestForceDictType(unittest.TestCase):
  """Test case for ForceDictType"""

  def setUp(self):
    self.key_types = {
      'a': constants.VTYPE_INT,
      'b': constants.VTYPE_BOOL,
      'c': constants.VTYPE_STRING,
      'd': constants.VTYPE_SIZE,
      }

  def _fdt(self, data, allowed_values=None):
    """Apply ForceDictType to data using self.key_types and return it.

    The return value of ForceDictType is not used; the tests rely on the
    passed dictionary itself being converted, so it is returned for
    comparison.

    @param data: dictionary to convert (modified by ForceDictType)
    @param allowed_values: optional value passed through to ForceDictType

    """
    if allowed_values is None:
      ForceDictType(data, self.key_types)
    else:
      ForceDictType(data, self.key_types, allowed_values=allowed_values)

    return data

  def testSimpleDict(self):
    self.assertEqual(self._fdt({}), {})
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a': 1, 'b': True})
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})

  def testErrors(self):
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1740
1741
class TestIsNormAbsPath(unittest.TestCase):
  """Testing case for IsNormAbsPath"""

  def _pathTestHelper(self, path, result):
    """Assert that IsNormAbsPath(path) agrees with the expected result."""
    if not result:
      self.assertFalse(IsNormAbsPath(path),
          "Path %s should not result absolute and normalized" % path)
      return

    self.assert_(IsNormAbsPath(path),
        "Path %s should result absolute and normalized" % path)

  def testBase(self):
    for (path, expected) in [
        ('/etc', True),
        ('/srv', True),
        ('etc', False),
        ('/etc/../root', False),
        ('/etc/', False),
        ]:
      self._pathTestHelper(path, expected)
1759
1760
class TestSafeEncode(unittest.TestCase):
  """Test case for SafeEncode"""

  def _assertStable(self, value):
    """Assert that re-encoding the encoded form of value changes nothing."""
    encoded = SafeEncode(value)
    self.failUnlessEqual(encoded, SafeEncode(encoded))

  def testAscii(self):
    # Digits, letters and punctuation pass through unchanged
    for text in (string.digits, string.letters, string.punctuation):
      self.failUnlessEqual(text, SafeEncode(text))

  def testDoubleEncode(self):
    for code in range(255):
      self._assertStable(chr(code))

  def testUnicode(self):
    # 1024 is high enough to catch non-direct ASCII mappings
    for code in range(1024):
      self._assertStable(unichr(code))
1778
1779
class TestFormatTime(unittest.TestCase):
  """Testing case for FormatTime"""

  def _assertNotAvailable(self, value):
    self.failUnlessEqual(FormatTime(value), "N/A")

  def testNone(self):
    self._assertNotAvailable(None)

  def testInvalid(self):
    self._assertNotAvailable(())

  def testNow(self):
    # Both float timestamps (as from time.time) and plain integers must be
    # accepted without raising
    now = time.time()
    for value in (now, int(now)):
      FormatTime(value)
1794
1795
class RunInSeparateProcess(unittest.TestCase):
  """Tests for utils.RunInSeparateProcess"""

  def test(self):
    # The child function's return value is handed back to the caller
    for expected in [True, False]:
      def _child(value=expected):
        return value

      self.assertEqual(expected, utils.RunInSeparateProcess(_child))

  def testArgs(self):
    # Extra positional arguments are forwarded to the child function
    for value in [0, 1, 999, "Hello World", (1, 2, 3)]:
      def _child(carg1, carg2, _expected=value):
        return carg1 == "Foo" and carg2 == _expected

      self.assert_(utils.RunInSeparateProcess(_child, "Foo", value))

  def testPid(self):
    # The function must not run in the calling process
    parent_pid = os.getpid()
    self.failIf(utils.RunInSeparateProcess(lambda: os.getpid() == parent_pid))

  def testSignal(self):
    # A child killed by a signal makes the call fail with GenericError
    self.assertRaises(errors.GenericError, utils.RunInSeparateProcess,
                      lambda: os.kill(os.getpid(), signal.SIGTERM))

  def testException(self):
    def _exc():
      raise errors.GenericError("This is a test")

    # An exception in the child makes the call fail with GenericError
    self.assertRaises(errors.GenericError,
                      utils.RunInSeparateProcess, _exc)
1832
1833
class TestFingerprintFile(unittest.TestCase):
  """Tests for utils._FingerprintFile"""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()

  def test(self):
    path = self.tmpfile.name

    # The freshly created file is still empty
    self.assertEqual(utils._FingerprintFile(path),
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")

    utils.WriteFile(path, data="Hello World\n")
    self.assertEqual(utils._FingerprintFile(path),
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1845
1846
class TestUnescapeAndSplit(unittest.TestCase):
  """Testing case for UnescapeAndSplit"""

  def setUp(self):
    # testing more than one separator for regexp safety
    self._seps = [",", "+", "."]

  def _checkSplit(self, sep, parts, expected):
    """Join parts with sep and assert that splitting yields expected."""
    self.failUnlessEqual(UnescapeAndSplit(sep.join(parts), sep=sep), expected)

  def testSimple(self):
    parts = ["a", "b", "c", "d"]
    for sep in self._seps:
      self._checkSplit(sep, parts, parts)

  def testEscape(self):
    # A backslash-escaped separator is kept, without the backslash
    for sep in self._seps:
      self._checkSplit(sep, ["a", "b\\" + sep + "c", "d"],
                       ["a", "b" + sep + "c", "d"])

  def testDoubleEscape(self):
    # An escaped backslash turns into a single backslash
    for sep in self._seps:
      self._checkSplit(sep, ["a", "b\\\\", "c", "d"], ["a", "b\\", "c", "d"])

  def testThreeEscape(self):
    # Escaped backslash directly followed by an escaped separator
    for sep in self._seps:
      self._checkSplit(sep, ["a", "b\\\\\\" + sep + "c", "d"],
                       ["a", "b\\" + sep + "c", "d"])
1876
1877
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
  """Tests for utils.GenerateSelfSignedX509Cert/GenerateSelfSignedSslCert"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _checkRsaPrivateKey(self, key):
    """Return whether key contains the PEM RSA private key delimiters."""
    lines = key.splitlines()
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
            "-----END RSA PRIVATE KEY-----" in lines)

  def _checkCertificate(self, cert):
    """Return whether cert contains the PEM certificate delimiters."""
    lines = cert.splitlines()
    return ("-----BEGIN CERTIFICATE-----" in lines and
            "-----END CERTIFICATE-----" in lines)

  def test(self):
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
      # The helpers only return a boolean, so their result has to be
      # asserted; otherwise malformed PEM output would go unnoticed
      self.assert_(self._checkRsaPrivateKey(key_pem))
      self.assert_(self._checkCertificate(cert_pem))

      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
                                           key_pem)
      self.assert_(key.bits() >= 1024)
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)

      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                             cert_pem)
      self.failIf(x509.has_expired())
      self.assertEqual(x509.get_issuer().CN, common_name)
      self.assertEqual(x509.get_subject().CN, common_name)
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)

  def testLegacy(self):
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")

    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)

    cert1 = utils.ReadFile(cert1_filename)

    # The legacy function writes key and certificate into the same file
    self.assert_(self._checkRsaPrivateKey(cert1))
    self.assert_(self._checkCertificate(cert1))
1923
1924
class TestPathJoin(unittest.TestCase):
  """Testing case for PathJoin"""

  def _checkInvalid(self, *parts):
    """Assert that PathJoin rejects the given components."""
    self.failUnlessRaises(ValueError, PathJoin, *parts)

  def testBasicItems(self):
    parts = ["/a", "b", "c"]
    self.failUnlessEqual(PathJoin(*parts), "/".join(parts))

  def testNonAbsPrefix(self):
    self._checkInvalid("a", "b")

  def testBackTrack(self):
    self._checkInvalid("/a", "b/../c")

  def testMultiAbs(self):
    self._checkInvalid("/a", "/b")
1940
1941
class TestHostInfo(unittest.TestCase):
  """Testing case for HostInfo"""

  def testUppercase(self):
    name = "AbC.example.com"
    self.failUnlessEqual(HostInfo.NormalizeName(name), name.lower())

  def testTooLongName(self):
    name = "a.b." + 255 * "c"
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, name)

  def testTrailingDot(self):
    name = "a.b.c"
    self.failUnlessEqual(HostInfo.NormalizeName("%s." % name), name)

  def testInvalidName(self):
    for name in ("a b", "a/b", ".a.b", "a..b"):
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, name)

  def testValidName(self):
    # None of these may raise
    for name in ("a.b", "a-b", "a_b", "a.b.c"):
      HostInfo.NormalizeName(name)
1976
1977
class TestValidateServiceName(unittest.TestCase):
  """Tests for utils.ValidateServiceName"""

  def testValid(self):
    # Accepted names are returned unchanged
    valid = [0, 1, 2, 3, 1024, 65000, 65534, 65535] + [
      "ganeti", "gnt-masterd", "HELLO_WORLD_SVC", "hello.world.1",
      "0", "80", "1111", "65535",
      ]
    for name in valid:
      self.assertEqual(utils.ValidateServiceName(name), name)

  def testInvalid(self):
    invalid = [-15756, -1, 65536, 133428083] + [
      "", "Hello World!", "!", "'", "\"", "\t", "\n", "`",
      "-8546", "-1", "65536",
      129 * "A",
      ]
    for name in invalid:
      self.assertRaises(OpPrereqError, utils.ValidateServiceName, name)
2002
2003
class TestParseAsn1Generalizedtime(unittest.TestCase):
  """Tests for utils._ParseAsn1Generalizedtime"""

  def test(self):
    parse = utils._ParseAsn1Generalizedtime

    for (value, expected) in [
        # UTC
        ("19700101000000Z", 0),
        ("20100222174152Z", 1266860512),
        ("20380119031407Z", (2**31) - 1),
        # With offset
        ("20100222174152+0000", 1266860512),
        ("20100223131652+0000", 1266931012),
        ("20100223051808-0800", 1266931088),
        ("20100224002135+1100", 1266931295),
        ("19700101000000-0100", 3600),
        ]:
      self.assertEqual(parse(value), expected)

    for value in [
        # Leap seconds are not supported by datetime.datetime
        "19841231235960+0000",
        "19920630235960+0000",
        # Errors
        "",
        "invalid",
        "20100222174152",
        "Mon Feb 22 17:47:02 UTC 2010",
        "2010-02-22 17:42:02",
        ]:
      self.assertRaises(ValueError, parse, value)
2040
2041
class TestGetX509CertValidity(testutils.GanetiTestCase):
  """Tests for utils.GetX509CertValidity"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    # Test whether we have pyOpenSSL 0.7 or above
    installed = distutils.version.LooseVersion(OpenSSL.__version__)
    self.pyopenssl0_7 = installed >= "0.7"

    if not self.pyopenssl0_7:
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
                    " function correctly")

  def _LoadCert(self, name):
    """Load the PEM certificate stored in the named test data file."""
    pem = self._ReadTestData(name)
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem)

  def test(self):
    cert = self._LoadCert("cert1.pem")
    validity = utils.GetX509CertValidity(cert)

    if self.pyopenssl0_7:
      expected = (1266919967, 1267524767)
    else:
      # Below pyOpenSSL 0.7 the validity period is expected to be unknown
      expected = (None, None)

    self.assertEqual(validity, expected)
2065
2066
class TestSignX509Certificate(unittest.TestCase):
  """Tests for utils.SignX509Certificate/utils.LoadSignedX509Certificate"""
  KEY = "My private key!"
  KEY_OTHER = "Another key"

  def test(self):
    # Generate certificate valid for 5 minutes
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)

    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                           cert_pem)

    # No signature at all
    self.assertRaises(errors.GenericError,
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)

    # Invalid input
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: \n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)

    # Invalid salt; the certificate object is passed exactly like in the
    # valid-salt calls below, so only the salt can cause the error
    for salt in list("-_@$,:;/\\ \t\n"):
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
                        cert, self.KEY, "foo%sbar" % salt)

    for salt in ["HelloWorld", "salt", string.letters, string.digits,
                 utils.GenerateSecret(numbytes=4),
                 utils.GenerateSecret(numbytes=16),
                 "{123:456}".encode("hex")]:
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)

      self._Check(cert, salt, signed_pem)

      # Additional text before and after the signed PEM must be tolerated
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
                               "lines----\n------ at\nthe end!"))

  def _Check(self, cert, salt, pem):
    """Check that pem loads back to cert and salt only with the right key."""
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
    self.assertEqual(salt, salt2)
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))

    # Other key
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      pem, self.KEY_OTHER)
2120
2121
class TestMakedirs(unittest.TestCase):
  """Tests for utils.Makedirs"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _MakeAndCheck(self, path):
    """Run Makedirs on path and assert that it is a directory afterwards."""
    utils.Makedirs(path)
    self.assert_(os.path.isdir(path))

  def testNonExisting(self):
    self._MakeAndCheck(utils.PathJoin(self.tmpdir, "foo"))

  def testExisting(self):
    target = utils.PathJoin(self.tmpdir, "foo")
    os.mkdir(target)
    # An already existing directory must be accepted
    self._MakeAndCheck(target)

  def testRecursiveNonExisting(self):
    self._MakeAndCheck(utils.PathJoin(self.tmpdir, "foo/bar/baz"))

  def testRecursiveExisting(self):
    target = utils.PathJoin(self.tmpdir, "B/moo/xyz")
    self.assertFalse(os.path.exists(target))
    # Only the first path component exists beforehand
    os.mkdir(utils.PathJoin(self.tmpdir, "B"))
    self._MakeAndCheck(target)
2151
2152
class TestRetry(testutils.GanetiTestCase):
  """Tests for utils.Retry and the RetryAgain/RetryTimeout exceptions"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    # Number of times _RetryAndSucceed has requested another attempt
    self.retries = 0

  @staticmethod
  def _RaiseRetryAgain():
    """Retry callback which never succeeds."""
    raise utils.RetryAgain()

  @staticmethod
  def _RaiseRetryAgainWithArg(args):
    """Retry callback which never succeeds, passing args to RetryAgain."""
    raise utils.RetryAgain(*args)

  def _WrongNestedLoop(self):
    """Retry callback which itself runs a Retry loop that times out."""
    return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)

  def _RetryAndSucceed(self, retries):
    """Raise RetryAgain on the first C{retries} calls, then return True."""
    if self.retries < retries:
      self.retries += 1
      raise utils.RetryAgain()
    else:
      return True

  def testRaiseTimeout(self):
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
                          self._RaiseRetryAgain, 0.01, 0.02)
    # With a zero timeout the test expects exactly one attempt before
    # RetryTimeout, even though a second attempt would succeed
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
                          self._RetryAndSucceed, 0.01, 0, args=[1])
    self.failUnlessEqual(self.retries, 1)

  def testComplete(self):
    self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
    self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
                         True)
    self.failUnlessEqual(self.retries, 2)

  def testNestedLoop(self):
    # A RetryTimeout escaping from an inner Retry loop must be reported as
    # ProgrammerError instead of ending the outer loop as its own timeout
    try:
      self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
                            self._WrongNestedLoop, 0, 1)
    except utils.RetryTimeout:
      self.fail("Didn't detect inner loop's exception")

  def testTimeoutArgument(self):
    # The arguments of the last RetryAgain end up in RetryTimeout.args
    retry_arg="my_important_debugging_message"
    try:
      utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
    except utils.RetryTimeout, err:
      self.failUnlessEqual(err.args, (retry_arg, ))
    else:
      self.fail("Expected timeout didn't happen")

  def testRaiseInnerWithExc(self):
    # If the RetryAgain argument is an exception, RaiseInner re-raises it
    retry_arg="my_important_debugging_message"
    try:
      try:
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
                    args=[[errors.GenericError(retry_arg, retry_arg)]])
      except utils.RetryTimeout, err:
        err.RaiseInner()
      else:
        self.fail("Expected timeout didn't happen")
    except errors.GenericError, err:
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
    else:
      self.fail("Expected GenericError didn't happen")

  def testRaiseInnerWithMsg(self):
    # With plain (non-exception) arguments, RaiseInner raises RetryTimeout
    # carrying those arguments
    retry_arg="my_important_debugging_message"
    try:
      try:
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
                    args=[[retry_arg, retry_arg]])
      except utils.RetryTimeout, err:
        err.RaiseInner()
      else:
        self.fail("Expected timeout didn't happen")
    except utils.RetryTimeout, err:
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
    else:
      self.fail("Expected RetryTimeout didn't happen")
2234
2235
class TestLineSplitter(unittest.TestCase):
  """Tests for utils.LineSplitter"""

  def test(self):
    lines = []
    ls = utils.LineSplitter(lines.append)
    # Even complete lines are not delivered by write() alone
    ls.write("Hello World\n")
    self.assertEqual(lines, [])
    ls.write("Foo\n Bar\r\n ")
    ls.write("Baz")
    ls.write("Moo")
    self.assertEqual(lines, [])
    # flush() delivers all complete lines (without "\n"/"\r\n") and keeps
    # the unterminated remainder buffered
    ls.flush()
    self.assertEqual(lines, ["Hello World", "Foo", " Bar"])
    # close() also delivers the unterminated remainder
    ls.close()
    self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"])

  def _testExtra(self, line, all_lines, p1, p2):
    """Line callback checking the extra constructor arguments it receives."""
    self.assertEqual(p1, 999)
    self.assertEqual(p2, "extra")
    all_lines.append(line)

  def testExtraArgsNoFlush(self):
    lines = []
    # Arguments after the callback are passed to it after each line
    ls = utils.LineSplitter(self._testExtra, lines, 999, "extra")
    ls.write("\n\nHello World\n")
    ls.write("Foo\n Bar\r\n ")
    ls.write("")
    ls.write("Baz")
    ls.write("Moo\n\nx\n")
    self.assertEqual(lines, [])
    # Without an explicit flush everything is delivered on close, including
    # empty lines
    ls.close()
    self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo",
                             "", "x"])
2268
2269
class TestReadLockedPidFile(unittest.TestCase):
  """Tests for utils.ReadLockedPidFile"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _Path(self, *components):
    """Return a path below the temporary directory."""
    return utils.PathJoin(self.tmpdir, *components)

  def testNonExistent(self):
    self.assert_(utils.ReadLockedPidFile(self._Path("nonexist")) is None)

  def testUnlocked(self):
    pidfile = self._Path("pid")
    utils.WriteFile(pidfile, data="123")
    self.assert_(utils.ReadLockedPidFile(pidfile) is None)

  def testLocked(self):
    pidfile = self._Path("pid")
    utils.WriteFile(pidfile, data="123")

    lock = utils.FileLock.Open(pidfile)
    try:
      lock.Exclusive(blocking=True)
      # While the lock is held the PID is reported
      self.assertEqual(utils.ReadLockedPidFile(pidfile), 123)
    finally:
      lock.Close()

    # After releasing the lock the file counts as unlocked again
    self.assert_(utils.ReadLockedPidFile(pidfile) is None)

  def testError(self):
    # A regular file where a directory component is expected
    utils.WriteFile(self._Path("foobar"), data="")
    # open(2) should return ENOTDIR
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile,
                      self._Path("foobar", "pid"))
2305
2306
class TestCertVerification(testutils.GanetiTestCase):
  """Tests for utils.VerifyX509Certificate"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    try:
      shutil.rmtree(self.tmpdir)
    finally:
      # setUp ran the base class setup, so its cleanup must run as well
      testutils.GanetiTestCase.tearDown(self)

  def testVerifyCertificate(self):
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                           cert_pem)

    # Not checking return value as this certificate is expired
    utils.VerifyX509Certificate(cert, 30, 7)
2323
2324
class TestVerifyCertificateInner(unittest.TestCase):
  """Tests for utils._VerifyCertificateInner"""

  def test(self):
    vci = utils._VerifyCertificateInner

    # Valid
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
                     (None, None))

    cases = [
      # Not yet valid
      ((False, 1266507600, 1267544400, 1266075600, 30, 7), utils.CERT_WARNING),
      # Expiring soon
      ((False, 1266507600, 1267544400, 1266939600, 30, 7), utils.CERT_ERROR),
      ((False, 1266507600, 1267544400, 1266939600, 30, 1), utils.CERT_WARNING),
      ((False, 1266507600, None, 1266939600, 30, 7), None),
      # Expired
      ((True, 1266507600, 1267544400, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, None, 1267544400, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, 1266507600, None, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, None, None, 1266939600, 30, 7), utils.CERT_ERROR),
      ]

    for (args, expected_code) in cases:
      # The result is always a (code, message) pair; only the code is checked
      (errcode, _) = vci(*args)
      self.assertEqual(errcode, expected_code)
2359
2360
class TestHmacFunctions(unittest.TestCase):
  # Digests can be checked with "openssl sha1 -hmac $key"
  def testSha1Hmac(self):
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
    for (key, text, expected) in [
        ("", "", "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d"),
        ("3YzMxZWE", "Hello World",
         "ef4f3bda82212ecb2f7ce868888a19092481f1fd"),
        ("TguMTA2K", "", "f904c2476527c6d3e6609ab683c66fa0652cb1dc"),
        ("3YzMxZWE", longtext, "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54"),
        ]:
      self.assertEqual(utils.Sha1Hmac(key, text), expected)

  def testSha1HmacSalt(self):
    for (key, text, salt, expected) in [
        ("TguMTA2K", "", "abc0", "4999bf342470eadb11dfcd24ca5680cf9fd7cdce"),
        ("TguMTA2K", "", "abc9", "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8"),
        ("3YzMxZWE", "Hello World", "xyz0",
         "7f264f8114c9066afc9bb7636e1786d996d3cc0d"),
        ]:
      self.assertEqual(utils.Sha1Hmac(key, text, salt=salt), expected)

  def testVerifySha1Hmac(self):
    for (key, text, expected) in [
        ("", "", "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d"),
        ("TguMTA2K", "", "f904c2476527c6d3e6609ab683c66fa0652cb1dc"),
        ]:
      self.assert_(utils.VerifySha1Hmac(key, text, expected))

    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
    # Verification must accept the digest in any letter case
    for variant in [digest, digest.lower(), digest.upper(), digest.title()]:
      self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", variant))

  def testVerifySha1HmacSalt(self):
    for (key, text, expected, salt) in [
        ("TguMTA2K", "", "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8", "abc9"),
        ("3YzMxZWE", "Hello World",
         "7f264f8114c9066afc9bb7636e1786d996d3cc0d", "xyz0"),
        ]:
      self.assert_(utils.VerifySha1Hmac(key, text, expected, salt=salt))
2408
2409
class TestIgnoreSignals(unittest.TestCase):
  """Test the IgnoreSignals decorator"""

  @staticmethod
  def _Raise(exception):
    raise exception

  @staticmethod
  def _Return(rval):
    return rval

  def testIgnoreSignals(self):
    # (exception class, instance, whether it carries EINTR)
    cases = [
      (socket.error, socket.error(errno.EINTR, "Message"), True),
      (socket.error, socket.error(errno.EINVAL, "Message"), False),
      (EnvironmentError, EnvironmentError(errno.EINTR, "Message"), True),
      (EnvironmentError, EnvironmentError(errno.EINVAL, "Message"), False),
      ]

    # Called directly, every error propagates
    for (exc_class, err, _) in cases:
      self.assertRaises(exc_class, self._Raise, err)

    # Through IgnoreSignals, EINTR errors yield None while others propagate
    for (exc_class, err, interrupted) in cases:
      if interrupted:
        self.assertEquals(utils.IgnoreSignals(self._Raise, err), None)
      else:
        self.assertRaises(exc_class, utils.IgnoreSignals, self._Raise, err)

    # Normal return values are passed through
    for value in (True, 33):
      self.assertEquals(utils.IgnoreSignals(self._Return, value), value)
2442
2443
2444 class TestEnsureDirs(unittest.TestCase):
2445   """Tests for EnsureDirs"""
2446
2447   def setUp(self):
2448     self.dir = tempfile.mkdtemp()
2449     self.old_umask = os.umask(0777)
2450
2451   def testEnsureDirs(self):
2452     utils.EnsureDirs([
2453         (utils.PathJoin(self.dir, "foo"), 0777),
2454         (utils.PathJoin(self.dir, "bar"), 0000),
2455         ])
2456     self.assertEquals(os.stat(utils.PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2457     self.assertEquals(os.stat(utils.PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2458
2459   def tearDown(self):
2460     os.rmdir(utils.PathJoin(self.dir, "foo"))
2461     os.rmdir(utils.PathJoin(self.dir, "bar"))
2462     os.rmdir(self.dir)
2463     os.umask(self.old_umask)
2464
2465
class TestFormatSeconds(unittest.TestCase):
  """Tests for utils.FormatSeconds"""

  def _checkAll(self, cases):
    """Assert FormatSeconds output for each (seconds, expected) pair."""
    for (seconds, expected) in cases:
      self.assertEqual(utils.FormatSeconds(seconds), expected)

  def test(self):
    self._checkAll([
      (1, "1s"),
      (3600, "1h 0m 0s"),
      (3599, "59m 59s"),
      (7200, "2h 0m 0s"),
      (7201, "2h 0m 1s"),
      (7281, "2h 1m 21s"),
      (29119, "8h 5m 19s"),
      (19431228, "224d 21h 33m 48s"),
      # Negative durations are only expressed in seconds
      (-1, "-1s"),
      (-282, "-282s"),
      (-29119, "-29119s"),
      ])

  def testFloat(self):
    # Fractional seconds are rounded to the nearest second
    self._checkAll([
      (1.3, "1s"),
      (1.9, "2s"),
      (3912.12311, "1h 5m 12s"),
      (3912.8, "1h 5m 13s"),
      ])
2485
2486
class RunIgnoreProcessNotFound(unittest.TestCase):
  """Tests for utils.IgnoreProcessNotFound"""

  @staticmethod
  def _WritePid(fd):
    """Child helper: write our PID to fd, close it and report success."""
    os.write(fd, str(os.getpid()))
    os.close(fd)
    return True

  def test(self):
    (read_end, write_end) = os.pipe()

    # Start short-lived process which writes its PID to pipe
    child_ok = utils.RunInSeparateProcess(self._WritePid, write_end)
    self.assert_(child_ok)
    os.close(write_end)

    # Read PID from pipe
    child_pid = int(os.read(read_end, 1024))
    os.close(read_end)

    # The child has already exited, so signalling it is expected to yield
    # False instead of raising
    self.assertFalse(utils.IgnoreProcessNotFound(os.kill, child_pid, 0))
2507
2508
class TestIsValidIP4(unittest.TestCase):
  """Tests for utils.IsValidIP4"""

  def test(self):
    for addr in ("127.0.0.1", "0.0.0.0", "255.255.255.255"):
      self.assertTrue(utils.IsValidIP4(addr), "%r should be valid" % addr)

    # Too few octets, an out-of-range octet and an IPv6 address
    for addr in ("0", "1", "1.1.1", "255.255.255.256", "::1"):
      self.assertFalse(utils.IsValidIP4(addr), "%r should be invalid" % addr)
2519
2520
class TestIsValidIP6(unittest.TestCase):
  """Tests for utils.IsValidIP6"""

  def test(self):
    good = [
      "::",
      "::1",
      "1" + (":1" * 7),
      "ffff" + (":ffff" * 7),
      ]
    for addr in good:
      self.assertTrue(utils.IsValidIP6(addr), "%r should be valid" % addr)

    bad = [
      "0",
      ":1",
      # Only seven groups
      "f" + (":f" * 6),
      # Non-hex digit
      "fffg" + (":ffff" * 7),
      # Group with five digits
      "fffff" + (":ffff" * 7),
      # Nine groups
      "1" + (":1" * 8),
      # IPv4 address
      "127.0.0.1",
      ]
    for addr in bad:
      self.assertFalse(utils.IsValidIP6(addr), "%r should be invalid" % addr)
2534
2535
class TestIsValidIP(unittest.TestCase):
  """Tests for utils.IsValidIP (accepting both IPv4 and IPv6)"""

  def test(self):
    for addr in ("0.0.0.0", "127.0.0.1", "::", "::1"):
      self.assertTrue(utils.IsValidIP(addr), "%r should be valid" % addr)

    for addr in ("0", "1.1.1.256", "a:g::1"):
      self.assertFalse(utils.IsValidIP(addr), "%r should be invalid" % addr)
2545
2546
class TestGetAddressFamily(unittest.TestCase):
  """Tests for utils.GetAddressFamily"""

  def test(self):
    expected = [
      ("127.0.0.1", socket.AF_INET),
      ("10.2.0.127", socket.AF_INET),
      ("::1", socket.AF_INET6),
      ("fe80::a00:27ff:fe08:5048", socket.AF_INET6),
      ]
    for (addr, family) in expected:
      self.assertEqual(utils.GetAddressFamily(addr), family)

    # Something that is neither IPv4 nor IPv6 must be rejected
    self.assertRaises(errors.GenericError, utils.GetAddressFamily, "0")
2555
2556
if __name__ == "__main__":
  # Run all test cases in this module through Ganeti's test program wrapper
  testutils.GanetiTestProgram()