Merge branch 'devel-2.1'
[ganeti-local] / test / ganeti.utils_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the utils module"""
23
24 import unittest
25 import os
26 import time
27 import tempfile
28 import os.path
29 import os
30 import stat
31 import signal
32 import socket
33 import shutil
34 import re
35 import select
36 import string
37 import fcntl
38 import OpenSSL
39 import warnings
40 import distutils.version
41 import glob
42 import md5
43
44 import ganeti
45 import testutils
46 from ganeti import constants
47 from ganeti import utils
48 from ganeti import errors
49 from ganeti import serializer
50 from ganeti.utils import IsProcessAlive, RunCmd, \
51      RemoveFile, MatchNameComponent, FormatUnit, \
52      ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
53      ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
54      SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
55      TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
56      UnescapeAndSplit, RunParts, PathJoin, HostInfo
57
58 from ganeti.errors import LockError, UnitParseError, GenericError, \
59      ProgrammerError, OpPrereqError
60
61
62 class TestIsProcessAlive(unittest.TestCase):
63   """Testing case for IsProcessAlive"""
64
65   def testExists(self):
66     mypid = os.getpid()
67     self.assert_(IsProcessAlive(mypid),
68                  "can't find myself running")
69
70   def testNotExisting(self):
71     pid_non_existing = os.fork()
72     if pid_non_existing == 0:
73       os._exit(0)
74     elif pid_non_existing < 0:
75       raise SystemError("can't fork")
76     os.waitpid(pid_non_existing, 0)
77     self.assert_(not IsProcessAlive(pid_non_existing),
78                  "nonexisting process detected")
79
80
81 class TestPidFileFunctions(unittest.TestCase):
82   """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
83
84   def setUp(self):
85     self.dir = tempfile.mkdtemp()
86     self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
87     utils.DaemonPidFileName = self.f_dpn
88
89   def testPidFileFunctions(self):
90     pid_file = self.f_dpn('test')
91     utils.WritePidFile('test')
92     self.failUnless(os.path.exists(pid_file),
93                     "PID file should have been created")
94     read_pid = utils.ReadPidFile(pid_file)
95     self.failUnlessEqual(read_pid, os.getpid())
96     self.failUnless(utils.IsProcessAlive(read_pid))
97     self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
98     utils.RemovePidFile('test')
99     self.failIf(os.path.exists(pid_file),
100                 "PID file should not exist anymore")
101     self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
102                          "ReadPidFile should return 0 for missing pid file")
103     fh = open(pid_file, "w")
104     fh.write("blah\n")
105     fh.close()
106     self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
107                          "ReadPidFile should return 0 for invalid pid file")
108     utils.RemovePidFile('test')
109     self.failIf(os.path.exists(pid_file),
110                 "PID file should not exist anymore")
111
112   def testKill(self):
113     pid_file = self.f_dpn('child')
114     r_fd, w_fd = os.pipe()
115     new_pid = os.fork()
116     if new_pid == 0: #child
117       utils.WritePidFile('child')
118       os.write(w_fd, 'a')
119       signal.pause()
120       os._exit(0)
121       return
122     # else we are in the parent
123     # wait until the child has written the pid file
124     os.read(r_fd, 1)
125     read_pid = utils.ReadPidFile(pid_file)
126     self.failUnlessEqual(read_pid, new_pid)
127     self.failUnless(utils.IsProcessAlive(new_pid))
128     utils.KillProcess(new_pid, waitpid=True)
129     self.failIf(utils.IsProcessAlive(new_pid))
130     utils.RemovePidFile('child')
131     self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)
132
133   def tearDown(self):
134     for name in os.listdir(self.dir):
135       os.unlink(os.path.join(self.dir, name))
136     os.rmdir(self.dir)
137
138
139 class TestRunCmd(testutils.GanetiTestCase):
140   """Testing case for the RunCmd function"""
141
142   def setUp(self):
143     testutils.GanetiTestCase.setUp(self)
144     self.magic = time.ctime() + " ganeti test"
145     self.fname = self._CreateTempFile()
146
147   def testOk(self):
148     """Test successful exit code"""
149     result = RunCmd("/bin/sh -c 'exit 0'")
150     self.assertEqual(result.exit_code, 0)
151     self.assertEqual(result.output, "")
152
153   def testFail(self):
154     """Test fail exit code"""
155     result = RunCmd("/bin/sh -c 'exit 1'")
156     self.assertEqual(result.exit_code, 1)
157     self.assertEqual(result.output, "")
158
159   def testStdout(self):
160     """Test standard output"""
161     cmd = 'echo -n "%s"' % self.magic
162     result = RunCmd("/bin/sh -c '%s'" % cmd)
163     self.assertEqual(result.stdout, self.magic)
164     result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
165     self.assertEqual(result.output, "")
166     self.assertFileContent(self.fname, self.magic)
167
168   def testStderr(self):
169     """Test standard error"""
170     cmd = 'echo -n "%s"' % self.magic
171     result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
172     self.assertEqual(result.stderr, self.magic)
173     result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
174     self.assertEqual(result.output, "")
175     self.assertFileContent(self.fname, self.magic)
176
177   def testCombined(self):
178     """Test combined output"""
179     cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
180     expected = "A" + self.magic + "B" + self.magic
181     result = RunCmd("/bin/sh -c '%s'" % cmd)
182     self.assertEqual(result.output, expected)
183     result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
184     self.assertEqual(result.output, "")
185     self.assertFileContent(self.fname, expected)
186
187   def testSignal(self):
188     """Test signal"""
189     result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
190     self.assertEqual(result.signal, 15)
191     self.assertEqual(result.output, "")
192
193   def testListRun(self):
194     """Test list runs"""
195     result = RunCmd(["true"])
196     self.assertEqual(result.signal, None)
197     self.assertEqual(result.exit_code, 0)
198     result = RunCmd(["/bin/sh", "-c", "exit 1"])
199     self.assertEqual(result.signal, None)
200     self.assertEqual(result.exit_code, 1)
201     result = RunCmd(["echo", "-n", self.magic])
202     self.assertEqual(result.signal, None)
203     self.assertEqual(result.exit_code, 0)
204     self.assertEqual(result.stdout, self.magic)
205
206   def testFileEmptyOutput(self):
207     """Test file output"""
208     result = RunCmd(["true"], output=self.fname)
209     self.assertEqual(result.signal, None)
210     self.assertEqual(result.exit_code, 0)
211     self.assertFileContent(self.fname, "")
212
213   def testLang(self):
214     """Test locale environment"""
215     old_env = os.environ.copy()
216     try:
217       os.environ["LANG"] = "en_US.UTF-8"
218       os.environ["LC_ALL"] = "en_US.UTF-8"
219       result = RunCmd(["locale"])
220       for line in result.output.splitlines():
221         key, value = line.split("=", 1)
222         # Ignore these variables, they're overridden by LC_ALL
223         if key == "LANG" or key == "LANGUAGE":
224           continue
225         self.failIf(value and value != "C" and value != '"C"',
226             "Variable %s is set to the invalid value '%s'" % (key, value))
227     finally:
228       os.environ = old_env
229
230   def testDefaultCwd(self):
231     """Test default working directory"""
232     self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
233
234   def testCwd(self):
235     """Test default working directory"""
236     self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
237     self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
238     cwd = os.getcwd()
239     self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
240
241   def testResetEnv(self):
242     """Test environment reset functionality"""
243     self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
244     self.failUnlessEqual(RunCmd(["env"], reset_env=True,
245                                 env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
246
247
248 class TestRunParts(unittest.TestCase):
249   """Testing case for the RunParts function"""
250
251   def setUp(self):
252     self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
253
254   def tearDown(self):
255     shutil.rmtree(self.rundir)
256
257   def testEmpty(self):
258     """Test on an empty dir"""
259     self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
260
261   def testSkipWrongName(self):
262     """Test that wrong files are skipped"""
263     fname = os.path.join(self.rundir, "00test.dot")
264     utils.WriteFile(fname, data="")
265     os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
266     relname = os.path.basename(fname)
267     self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
268                          [(relname, constants.RUNPARTS_SKIP, None)])
269
270   def testSkipNonExec(self):
271     """Test that non executable files are skipped"""
272     fname = os.path.join(self.rundir, "00test")
273     utils.WriteFile(fname, data="")
274     relname = os.path.basename(fname)
275     self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
276                          [(relname, constants.RUNPARTS_SKIP, None)])
277
278   def testError(self):
279     """Test error on a broken executable"""
280     fname = os.path.join(self.rundir, "00test")
281     utils.WriteFile(fname, data="")
282     os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
283     (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
284     self.failUnlessEqual(relname, os.path.basename(fname))
285     self.failUnlessEqual(status, constants.RUNPARTS_ERR)
286     self.failUnless(error)
287
288   def testSorted(self):
289     """Test executions are sorted"""
290     files = []
291     files.append(os.path.join(self.rundir, "64test"))
292     files.append(os.path.join(self.rundir, "00test"))
293     files.append(os.path.join(self.rundir, "42test"))
294
295     for fname in files:
296       utils.WriteFile(fname, data="")
297
298     results = RunParts(self.rundir, reset_env=True)
299
300     for fname in sorted(files):
301       self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
302
303   def testOk(self):
304     """Test correct execution"""
305     fname = os.path.join(self.rundir, "00test")
306     utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
307     os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
308     (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
309     self.failUnlessEqual(relname, os.path.basename(fname))
310     self.failUnlessEqual(status, constants.RUNPARTS_RUN)
311     self.failUnlessEqual(runresult.stdout, "ciao")
312
313   def testRunFail(self):
314     """Test correct execution, with run failure"""
315     fname = os.path.join(self.rundir, "00test")
316     utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
317     os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
318     (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
319     self.failUnlessEqual(relname, os.path.basename(fname))
320     self.failUnlessEqual(status, constants.RUNPARTS_RUN)
321     self.failUnlessEqual(runresult.exit_code, 1)
322     self.failUnless(runresult.failed)
323
324   def testRunMix(self):
325     files = []
326     files.append(os.path.join(self.rundir, "00test"))
327     files.append(os.path.join(self.rundir, "42test"))
328     files.append(os.path.join(self.rundir, "64test"))
329     files.append(os.path.join(self.rundir, "99test"))
330
331     files.sort()
332
333     # 1st has errors in execution
334     utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
335     os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
336
337     # 2nd is skipped
338     utils.WriteFile(files[1], data="")
339
340     # 3rd cannot execute properly
341     utils.WriteFile(files[2], data="")
342     os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
343
344     # 4th execs
345     utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
346     os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
347
348     results = RunParts(self.rundir, reset_env=True)
349
350     (relname, status, runresult) = results[0]
351     self.failUnlessEqual(relname, os.path.basename(files[0]))
352     self.failUnlessEqual(status, constants.RUNPARTS_RUN)
353     self.failUnlessEqual(runresult.exit_code, 1)
354     self.failUnless(runresult.failed)
355
356     (relname, status, runresult) = results[1]
357     self.failUnlessEqual(relname, os.path.basename(files[1]))
358     self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
359     self.failUnlessEqual(runresult, None)
360
361     (relname, status, runresult) = results[2]
362     self.failUnlessEqual(relname, os.path.basename(files[2]))
363     self.failUnlessEqual(status, constants.RUNPARTS_ERR)
364     self.failUnless(runresult)
365
366     (relname, status, runresult) = results[3]
367     self.failUnlessEqual(relname, os.path.basename(files[3]))
368     self.failUnlessEqual(status, constants.RUNPARTS_RUN)
369     self.failUnlessEqual(runresult.output, "ciao")
370     self.failUnlessEqual(runresult.exit_code, 0)
371     self.failUnless(not runresult.failed)
372
373
374 class TestStartDaemon(testutils.GanetiTestCase):
375   def setUp(self):
376     self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
377     self.tmpfile = os.path.join(self.tmpdir, "test")
378
379   def tearDown(self):
380     shutil.rmtree(self.tmpdir)
381
382   def testShell(self):
383     utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
384     self._wait(self.tmpfile, 60.0, "Hello World")
385
386   def testShellOutput(self):
387     utils.StartDaemon("echo Hello World", output=self.tmpfile)
388     self._wait(self.tmpfile, 60.0, "Hello World")
389
390   def testNoShellNoOutput(self):
391     utils.StartDaemon(["pwd"])
392
393   def testNoShellNoOutputTouch(self):
394     testfile = os.path.join(self.tmpdir, "check")
395     self.failIf(os.path.exists(testfile))
396     utils.StartDaemon(["touch", testfile])
397     self._wait(testfile, 60.0, "")
398
399   def testNoShellOutput(self):
400     utils.StartDaemon(["pwd"], output=self.tmpfile)
401     self._wait(self.tmpfile, 60.0, "/")
402
403   def testNoShellOutputCwd(self):
404     utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
405     self._wait(self.tmpfile, 60.0, os.getcwd())
406
407   def testShellEnv(self):
408     utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
409                       env={ "GNT_TEST_VAR": "Hello World", })
410     self._wait(self.tmpfile, 60.0, "Hello World")
411
412   def testNoShellEnv(self):
413     utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
414                       env={ "GNT_TEST_VAR": "Hello World", })
415     self._wait(self.tmpfile, 60.0, "Hello World")
416
417   def testOutputFd(self):
418     fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
419     try:
420       utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
421     finally:
422       os.close(fd)
423     self._wait(self.tmpfile, 60.0, os.getcwd())
424
425   def testPid(self):
426     pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
427     self._wait(self.tmpfile, 60.0, str(pid))
428
429   def testPidFile(self):
430     pidfile = os.path.join(self.tmpdir, "pid")
431     checkfile = os.path.join(self.tmpdir, "abort")
432
433     pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
434                             output=self.tmpfile)
435     try:
436       fd = os.open(pidfile, os.O_RDONLY)
437       try:
438         # Check file is locked
439         self.assertRaises(errors.LockError, utils.LockFile, fd)
440
441         pidtext = os.read(fd, 100)
442       finally:
443         os.close(fd)
444
445       self.assertEqual(int(pidtext.strip()), pid)
446
447       self.assert_(utils.IsProcessAlive(pid))
448     finally:
449       # No matter what happens, kill daemon
450       utils.KillProcess(pid, timeout=5.0, waitpid=False)
451       self.failIf(utils.IsProcessAlive(pid))
452
453     self.assertEqual(utils.ReadFile(self.tmpfile), "")
454
455   def _wait(self, path, timeout, expected):
456     # Due to the asynchronous nature of daemon processes, polling is necessary.
457     # A timeout makes sure the test doesn't hang forever.
458     def _CheckFile():
459       if not (os.path.isfile(path) and
460               utils.ReadFile(path).strip() == expected):
461         raise utils.RetryAgain()
462
463     try:
464       utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
465     except utils.RetryTimeout:
466       self.fail("Apparently the daemon didn't run in %s seconds and/or"
467                 " didn't write the correct output" % timeout)
468
469   def testError(self):
470     self.assertRaises(errors.OpExecError, utils.StartDaemon,
471                       ["./does-NOT-EXIST/here/0123456789"])
472     self.assertRaises(errors.OpExecError, utils.StartDaemon,
473                       ["./does-NOT-EXIST/here/0123456789"],
474                       output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
475     self.assertRaises(errors.OpExecError, utils.StartDaemon,
476                       ["./does-NOT-EXIST/here/0123456789"],
477                       cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
478     self.assertRaises(errors.OpExecError, utils.StartDaemon,
479                       ["./does-NOT-EXIST/here/0123456789"],
480                       output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
481
482     fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
483     try:
484       self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
485                         ["./does-NOT-EXIST/here/0123456789"],
486                         output=self.tmpfile, output_fd=fd)
487     finally:
488       os.close(fd)
489
490
491 class TestSetCloseOnExecFlag(unittest.TestCase):
492   """Tests for SetCloseOnExecFlag"""
493
494   def setUp(self):
495     self.tmpfile = tempfile.TemporaryFile()
496
497   def testEnable(self):
498     utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
499     self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
500                     fcntl.FD_CLOEXEC)
501
502   def testDisable(self):
503     utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
504     self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
505                 fcntl.FD_CLOEXEC)
506
507
508 class TestSetNonblockFlag(unittest.TestCase):
509   def setUp(self):
510     self.tmpfile = tempfile.TemporaryFile()
511
512   def testEnable(self):
513     utils.SetNonblockFlag(self.tmpfile.fileno(), True)
514     self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
515                     os.O_NONBLOCK)
516
517   def testDisable(self):
518     utils.SetNonblockFlag(self.tmpfile.fileno(), False)
519     self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
520                 os.O_NONBLOCK)
521
522
523 class TestRemoveFile(unittest.TestCase):
524   """Test case for the RemoveFile function"""
525
526   def setUp(self):
527     """Create a temp dir and file for each case"""
528     self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
529     fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
530     os.close(fd)
531
532   def tearDown(self):
533     if os.path.exists(self.tmpfile):
534       os.unlink(self.tmpfile)
535     os.rmdir(self.tmpdir)
536
537   def testIgnoreDirs(self):
538     """Test that RemoveFile() ignores directories"""
539     self.assertEqual(None, RemoveFile(self.tmpdir))
540
541   def testIgnoreNotExisting(self):
542     """Test that RemoveFile() ignores non-existing files"""
543     RemoveFile(self.tmpfile)
544     RemoveFile(self.tmpfile)
545
546   def testRemoveFile(self):
547     """Test that RemoveFile does remove a file"""
548     RemoveFile(self.tmpfile)
549     if os.path.exists(self.tmpfile):
550       self.fail("File '%s' not removed" % self.tmpfile)
551
552   def testRemoveSymlink(self):
553     """Test that RemoveFile does remove symlinks"""
554     symlink = self.tmpdir + "/symlink"
555     os.symlink("no-such-file", symlink)
556     RemoveFile(symlink)
557     if os.path.exists(symlink):
558       self.fail("File '%s' not removed" % symlink)
559     os.symlink(self.tmpfile, symlink)
560     RemoveFile(symlink)
561     if os.path.exists(symlink):
562       self.fail("File '%s' not removed" % symlink)
563
564
565 class TestRename(unittest.TestCase):
566   """Test case for RenameFile"""
567
568   def setUp(self):
569     """Create a temporary directory"""
570     self.tmpdir = tempfile.mkdtemp()
571     self.tmpfile = os.path.join(self.tmpdir, "test1")
572
573     # Touch the file
574     open(self.tmpfile, "w").close()
575
576   def tearDown(self):
577     """Remove temporary directory"""
578     shutil.rmtree(self.tmpdir)
579
580   def testSimpleRename1(self):
581     """Simple rename 1"""
582     utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
583     self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
584
585   def testSimpleRename2(self):
586     """Simple rename 2"""
587     utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
588                      mkdir=True)
589     self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
590
591   def testRenameMkdir(self):
592     """Rename with mkdir"""
593     utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
594                      mkdir=True)
595     self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
596     self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
597
598     utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
599                      os.path.join(self.tmpdir, "test/foo/bar/baz"),
600                      mkdir=True)
601     self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
602     self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
603     self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
604
605
606 class TestMatchNameComponent(unittest.TestCase):
607   """Test case for the MatchNameComponent function"""
608
609   def testEmptyList(self):
610     """Test that there is no match against an empty list"""
611
612     self.failUnlessEqual(MatchNameComponent("", []), None)
613     self.failUnlessEqual(MatchNameComponent("test", []), None)
614
615   def testSingleMatch(self):
616     """Test that a single match is performed correctly"""
617     mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
618     for key in "test2", "test2.example", "test2.example.com":
619       self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
620
621   def testMultipleMatches(self):
622     """Test that a multiple match is returned as None"""
623     mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
624     for key in "test1", "test1.example":
625       self.failUnlessEqual(MatchNameComponent(key, mlist), None)
626
627   def testFullMatch(self):
628     """Test that a full match is returned correctly"""
629     key1 = "test1"
630     key2 = "test1.example"
631     mlist = [key2, key2 + ".com"]
632     self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
633     self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
634
635   def testCaseInsensitivePartialMatch(self):
636     """Test for the case_insensitive keyword"""
637     mlist = ["test1.example.com", "test2.example.net"]
638     self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
639                      "test2.example.net")
640     self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
641                      "test2.example.net")
642     self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
643                      "test2.example.net")
644     self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
645                      "test2.example.net")
646
647
648   def testCaseInsensitiveFullMatch(self):
649     mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
650     # Between the two ts1 a full string match non-case insensitive should work
651     self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
652                      None)
653     self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
654                      "ts1.ex")
655     self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
656                      "ts1.ex")
657     # Between the two ts2 only case differs, so only case-match works
658     self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
659                      "ts2.ex")
660     self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
661                      "Ts2.ex")
662     self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
663                      None)
664
665
666 class TestReadFile(testutils.GanetiTestCase):
667   def setUp(self):
668     testutils.GanetiTestCase.setUp(self)
669
670     self.tmpdir = tempfile.mkdtemp()
671     self.fname = utils.PathJoin(self.tmpdir, "data1")
672
673   def tearDown(self):
674     testutils.GanetiTestCase.tearDown(self)
675
676     shutil.rmtree(self.tmpdir)
677
678   def testReadAll(self):
679     data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
680     self.assertEqual(len(data), 814)
681
682     h = md5.new()
683     h.update(data)
684     self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
685
686   def testReadSize(self):
687     data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
688                           size=100)
689     self.assertEqual(len(data), 100)
690
691     h = md5.new()
692     h.update(data)
693     self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
694
695   def testReadOneline(self):
696     data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
697                           oneline=True)
698     self.assertEqual(len(data), 27)
699     self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
700
701   def testReadOnelineSize(self):
702     dummydata = (1024 * "Hello World! ")
703     self.assertFalse(set("\r\n") & set(dummydata))
704
705     utils.WriteFile(self.fname, data=dummydata)
706
707     data = utils.ReadFile(self.fname, oneline=True, size=555)
708     self.assertEqual(len(data), 555)
709     self.assertEqual(data, dummydata[:555])
710     self.assertFalse(set("\r\n") & set(data))
711
712   def testReadOnelineSize2(self):
713     for end in ["\n", "\r\n"]:
714       dummydata = (1024 * ("Hello World%s" % end))
715       self.assert_(set("\r\n") & set(dummydata))
716
717       utils.WriteFile(self.fname, data=dummydata)
718
719       data = utils.ReadFile(self.fname, oneline=True, size=555)
720       self.assertEqual(len(data), len("Hello World"))
721       self.assertEqual(data, dummydata[:11])
722       self.assertFalse(set("\r\n") & set(data))
723
724   def testReadOnelineWhitespace(self):
725     for ws in [" ", "\t", "\t\t  \t", "\t "]:
726       dummydata = (1024 * ("Foo bar baz %s\n" % ws))
727       self.assert_(set("\r\n") & set(dummydata))
728
729       utils.WriteFile(self.fname, data=dummydata)
730
731       data = utils.ReadFile(self.fname, oneline=True, size=555)
732       explen = len("Foo bar baz ") + len(ws)
733       self.assertEqual(len(data), explen)
734       self.assertEqual(data, dummydata[:explen])
735       self.assertFalse(set("\r\n") & set(data))
736
737   def testError(self):
738     self.assertRaises(EnvironmentError, utils.ReadFile,
739                       utils.PathJoin(self.tmpdir, "does-not-exist"))
740
741
742 class TestTimestampForFilename(unittest.TestCase):
743   def test(self):
744     self.assert_("." not in utils.TimestampForFilename())
745     self.assert_(":" not in utils.TimestampForFilename())
746
747
748 class TestCreateBackup(testutils.GanetiTestCase):
749   def setUp(self):
750     testutils.GanetiTestCase.setUp(self)
751
752     self.tmpdir = tempfile.mkdtemp()
753
754   def tearDown(self):
755     testutils.GanetiTestCase.tearDown(self)
756
757     shutil.rmtree(self.tmpdir)
758
759   def testEmpty(self):
760     filename = utils.PathJoin(self.tmpdir, "config.data")
761     utils.WriteFile(filename, data="")
762     bname = utils.CreateBackup(filename)
763     self.assertFileContent(bname, "")
764     self.assertEqual(len(glob.glob("%s*" % filename)), 2)
765     utils.CreateBackup(filename)
766     self.assertEqual(len(glob.glob("%s*" % filename)), 3)
767     utils.CreateBackup(filename)
768     self.assertEqual(len(glob.glob("%s*" % filename)), 4)
769
770     fifoname = utils.PathJoin(self.tmpdir, "fifo")
771     os.mkfifo(fifoname)
772     self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
773
774   def testContent(self):
775     bkpcount = 0
776     for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
777       for rep in [1, 2, 10, 127]:
778         testdata = data * rep
779
780         filename = utils.PathJoin(self.tmpdir, "test.data_")
781         utils.WriteFile(filename, data=testdata)
782         self.assertFileContent(filename, testdata)
783
784         for _ in range(3):
785           bname = utils.CreateBackup(filename)
786           bkpcount += 1
787           self.assertFileContent(bname, testdata)
788           self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
789
790
791 class TestFormatUnit(unittest.TestCase):
792   """Test case for the FormatUnit function"""
793
794   def testMiB(self):
795     self.assertEqual(FormatUnit(1, 'h'), '1M')
796     self.assertEqual(FormatUnit(100, 'h'), '100M')
797     self.assertEqual(FormatUnit(1023, 'h'), '1023M')
798
799     self.assertEqual(FormatUnit(1, 'm'), '1')
800     self.assertEqual(FormatUnit(100, 'm'), '100')
801     self.assertEqual(FormatUnit(1023, 'm'), '1023')
802
803     self.assertEqual(FormatUnit(1024, 'm'), '1024')
804     self.assertEqual(FormatUnit(1536, 'm'), '1536')
805     self.assertEqual(FormatUnit(17133, 'm'), '17133')
806     self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
807
808   def testGiB(self):
809     self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
810     self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
811     self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
812     self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
813
814     self.assertEqual(FormatUnit(1024, 'g'), '1.0')
815     self.assertEqual(FormatUnit(1536, 'g'), '1.5')
816     self.assertEqual(FormatUnit(17133, 'g'), '16.7')
817     self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
818
819     self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
820     self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
821     self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
822
823   def testTiB(self):
824     self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
825     self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
826     self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
827
828     self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
829     self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
830     self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
831
832 class TestParseUnit(unittest.TestCase):
833   """Test case for the ParseUnit function"""
834
835   SCALES = (('', 1),
836             ('M', 1), ('G', 1024), ('T', 1024 * 1024),
837             ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
838             ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
839
840   def testRounding(self):
841     self.assertEqual(ParseUnit('0'), 0)
842     self.assertEqual(ParseUnit('1'), 4)
843     self.assertEqual(ParseUnit('2'), 4)
844     self.assertEqual(ParseUnit('3'), 4)
845
846     self.assertEqual(ParseUnit('124'), 124)
847     self.assertEqual(ParseUnit('125'), 128)
848     self.assertEqual(ParseUnit('126'), 128)
849     self.assertEqual(ParseUnit('127'), 128)
850     self.assertEqual(ParseUnit('128'), 128)
851     self.assertEqual(ParseUnit('129'), 132)
852     self.assertEqual(ParseUnit('130'), 132)
853
854   def testFloating(self):
855     self.assertEqual(ParseUnit('0'), 0)
856     self.assertEqual(ParseUnit('0.5'), 4)
857     self.assertEqual(ParseUnit('1.75'), 4)
858     self.assertEqual(ParseUnit('1.99'), 4)
859     self.assertEqual(ParseUnit('2.00'), 4)
860     self.assertEqual(ParseUnit('2.01'), 4)
861     self.assertEqual(ParseUnit('3.99'), 4)
862     self.assertEqual(ParseUnit('4.00'), 4)
863     self.assertEqual(ParseUnit('4.01'), 8)
864     self.assertEqual(ParseUnit('1.5G'), 1536)
865     self.assertEqual(ParseUnit('1.8G'), 1844)
866     self.assertEqual(ParseUnit('8.28T'), 8682212)
867
868   def testSuffixes(self):
869     for sep in ('', ' ', '   ', "\t", "\t "):
870       for suffix, scale in TestParseUnit.SCALES:
871         for func in (lambda x: x, str.lower, str.upper):
872           self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
873                            1024 * scale)
874
875   def testInvalidInput(self):
876     for sep in ('-', '_', ',', 'a'):
877       for suffix, _ in TestParseUnit.SCALES:
878         self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)
879
880     for suffix, _ in TestParseUnit.SCALES:
881       self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
882
883
884 class TestSshKeys(testutils.GanetiTestCase):
885   """Test case for the AddAuthorizedKey function"""
886
887   KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
888   KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
889            'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
890
891   def setUp(self):
892     testutils.GanetiTestCase.setUp(self)
893     self.tmpname = self._CreateTempFile()
894     handle = open(self.tmpname, 'w')
895     try:
896       handle.write("%s\n" % TestSshKeys.KEY_A)
897       handle.write("%s\n" % TestSshKeys.KEY_B)
898     finally:
899       handle.close()
900
901   def testAddingNewKey(self):
902     AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
903
904     self.assertFileContent(self.tmpname,
905       "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
906       'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
907       " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
908       "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
909
910   def testAddingAlmostButNotCompletelyTheSameKey(self):
911     AddAuthorizedKey(self.tmpname,
912         'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
913
914     self.assertFileContent(self.tmpname,
915       "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
916       'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
917       " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
918       "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
919
920   def testAddingExistingKeyWithSomeMoreSpaces(self):
921     AddAuthorizedKey(self.tmpname,
922         'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
923
924     self.assertFileContent(self.tmpname,
925       "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
926       'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
927       " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
928
929   def testRemovingExistingKeyWithSomeMoreSpaces(self):
930     RemoveAuthorizedKey(self.tmpname,
931         'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
932
933     self.assertFileContent(self.tmpname,
934       'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
935       " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
936
937   def testRemovingNonExistingKey(self):
938     RemoveAuthorizedKey(self.tmpname,
939         'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
940
941     self.assertFileContent(self.tmpname,
942       "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
943       'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
944       " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
945
946
947 class TestEtcHosts(testutils.GanetiTestCase):
948   """Test functions modifying /etc/hosts"""
949
950   def setUp(self):
951     testutils.GanetiTestCase.setUp(self)
952     self.tmpname = self._CreateTempFile()
953     handle = open(self.tmpname, 'w')
954     try:
955       handle.write('# This is a test file for /etc/hosts\n')
956       handle.write('127.0.0.1\tlocalhost\n')
957       handle.write('192.168.1.1 router gw\n')
958     finally:
959       handle.close()
960
961   def testSettingNewIp(self):
962     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
963
964     self.assertFileContent(self.tmpname,
965       "# This is a test file for /etc/hosts\n"
966       "127.0.0.1\tlocalhost\n"
967       "192.168.1.1 router gw\n"
968       "1.2.3.4\tmyhost.domain.tld myhost\n")
969     self.assertFileMode(self.tmpname, 0644)
970
971   def testSettingExistingIp(self):
972     SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
973                      ['myhost'])
974
975     self.assertFileContent(self.tmpname,
976       "# This is a test file for /etc/hosts\n"
977       "127.0.0.1\tlocalhost\n"
978       "192.168.1.1\tmyhost.domain.tld myhost\n")
979     self.assertFileMode(self.tmpname, 0644)
980
981   def testSettingDuplicateName(self):
982     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
983
984     self.assertFileContent(self.tmpname,
985       "# This is a test file for /etc/hosts\n"
986       "127.0.0.1\tlocalhost\n"
987       "192.168.1.1 router gw\n"
988       "1.2.3.4\tmyhost\n")
989     self.assertFileMode(self.tmpname, 0644)
990
991   def testRemovingExistingHost(self):
992     RemoveEtcHostsEntry(self.tmpname, 'router')
993
994     self.assertFileContent(self.tmpname,
995       "# This is a test file for /etc/hosts\n"
996       "127.0.0.1\tlocalhost\n"
997       "192.168.1.1 gw\n")
998     self.assertFileMode(self.tmpname, 0644)
999
1000   def testRemovingSingleExistingHost(self):
1001     RemoveEtcHostsEntry(self.tmpname, 'localhost')
1002
1003     self.assertFileContent(self.tmpname,
1004       "# This is a test file for /etc/hosts\n"
1005       "192.168.1.1 router gw\n")
1006     self.assertFileMode(self.tmpname, 0644)
1007
1008   def testRemovingNonExistingHost(self):
1009     RemoveEtcHostsEntry(self.tmpname, 'myhost')
1010
1011     self.assertFileContent(self.tmpname,
1012       "# This is a test file for /etc/hosts\n"
1013       "127.0.0.1\tlocalhost\n"
1014       "192.168.1.1 router gw\n")
1015     self.assertFileMode(self.tmpname, 0644)
1016
1017   def testRemovingAlias(self):
1018     RemoveEtcHostsEntry(self.tmpname, 'gw')
1019
1020     self.assertFileContent(self.tmpname,
1021       "# This is a test file for /etc/hosts\n"
1022       "127.0.0.1\tlocalhost\n"
1023       "192.168.1.1 router\n")
1024     self.assertFileMode(self.tmpname, 0644)
1025
1026
1027 class TestShellQuoting(unittest.TestCase):
1028   """Test case for shell quoting functions"""
1029
1030   def testShellQuote(self):
1031     self.assertEqual(ShellQuote('abc'), "abc")
1032     self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
1033     self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
1034     self.assertEqual(ShellQuote("a b c"), "'a b c'")
1035     self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
1036
1037   def testShellQuoteArgs(self):
1038     self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
1039     self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
1040     self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
1041
1042
1043 class TestTcpPing(unittest.TestCase):
1044   """Testcase for TCP version of ping - against listen(2)ing port"""
1045
1046   def setUp(self):
1047     self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1048     self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
1049     self.listenerport = self.listener.getsockname()[1]
1050     self.listener.listen(1)
1051
1052   def tearDown(self):
1053     self.listener.shutdown(socket.SHUT_RDWR)
1054     del self.listener
1055     del self.listenerport
1056
1057   def testTcpPingToLocalHostAccept(self):
1058     self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1059                          self.listenerport,
1060                          timeout=10,
1061                          live_port_needed=True,
1062                          source=constants.LOCALHOST_IP_ADDRESS,
1063                          ),
1064                  "failed to connect to test listener")
1065
1066     self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1067                          self.listenerport,
1068                          timeout=10,
1069                          live_port_needed=True,
1070                          ),
1071                  "failed to connect to test listener (no source)")
1072
1073
1074 class TestTcpPingDeaf(unittest.TestCase):
1075   """Testcase for TCP version of ping - against non listen(2)ing port"""
1076
1077   def setUp(self):
1078     self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1079     self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
1080     self.deaflistenerport = self.deaflistener.getsockname()[1]
1081
1082   def tearDown(self):
1083     del self.deaflistener
1084     del self.deaflistenerport
1085
1086   def testTcpPingToLocalHostAcceptDeaf(self):
1087     self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1088                         self.deaflistenerport,
1089                         timeout=constants.TCP_PING_TIMEOUT,
1090                         live_port_needed=True,
1091                         source=constants.LOCALHOST_IP_ADDRESS,
1092                         ), # need successful connect(2)
1093                 "successfully connected to deaf listener")
1094
1095     self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1096                         self.deaflistenerport,
1097                         timeout=constants.TCP_PING_TIMEOUT,
1098                         live_port_needed=True,
1099                         ), # need successful connect(2)
1100                 "successfully connected to deaf listener (no source addr)")
1101
1102   def testTcpPingToLocalHostNoAccept(self):
1103     self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1104                          self.deaflistenerport,
1105                          timeout=constants.TCP_PING_TIMEOUT,
1106                          live_port_needed=False,
1107                          source=constants.LOCALHOST_IP_ADDRESS,
1108                          ), # ECONNREFUSED is OK
1109                  "failed to ping alive host on deaf port")
1110
1111     self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1112                          self.deaflistenerport,
1113                          timeout=constants.TCP_PING_TIMEOUT,
1114                          live_port_needed=False,
1115                          ), # ECONNREFUSED is OK
1116                  "failed to ping alive host on deaf port (no source addr)")
1117
1118
1119 class TestOwnIpAddress(unittest.TestCase):
1120   """Testcase for OwnIpAddress"""
1121
1122   def testOwnLoopback(self):
1123     """check having the loopback ip"""
1124     self.failUnless(OwnIpAddress(constants.LOCALHOST_IP_ADDRESS),
1125                     "Should own the loopback address")
1126
1127   def testNowOwnAddress(self):
1128     """check that I don't own an address"""
1129
1130     # network 192.0.2.0/24 is reserved for test/documentation as per
1131     # rfc 3330, so we *should* not have an address of this range... if
1132     # this fails, we should extend the test to multiple addresses
1133     DST_IP = "192.0.2.1"
1134     self.failIf(OwnIpAddress(DST_IP), "Should not own IP address %s" % DST_IP)
1135
1136
1137 def _GetSocketCredentials(path):
1138   """Connect to a Unix socket and return remote credentials.
1139
1140   """
1141   sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1142   try:
1143     sock.settimeout(10)
1144     sock.connect(path)
1145     return utils.GetSocketCredentials(sock)
1146   finally:
1147     sock.close()
1148
1149
1150 class TestGetSocketCredentials(unittest.TestCase):
1151   def setUp(self):
1152     self.tmpdir = tempfile.mkdtemp()
1153     self.sockpath = utils.PathJoin(self.tmpdir, "sock")
1154
1155     self.listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1156     self.listener.settimeout(10)
1157     self.listener.bind(self.sockpath)
1158     self.listener.listen(1)
1159
1160   def tearDown(self):
1161     self.listener.shutdown(socket.SHUT_RDWR)
1162     self.listener.close()
1163     shutil.rmtree(self.tmpdir)
1164
1165   def test(self):
1166     (c2pr, c2pw) = os.pipe()
1167
1168     # Start child process
1169     child = os.fork()
1170     if child == 0:
1171       try:
1172         data = serializer.DumpJson(_GetSocketCredentials(self.sockpath))
1173
1174         os.write(c2pw, data)
1175         os.close(c2pw)
1176
1177         os._exit(0)
1178       finally:
1179         os._exit(1)
1180
1181     os.close(c2pw)
1182
1183     # Wait for one connection
1184     (conn, _) = self.listener.accept()
1185     conn.recv(1)
1186     conn.close()
1187
1188     # Wait for result
1189     result = os.read(c2pr, 4096)
1190     os.close(c2pr)
1191
1192     # Check child's exit code
1193     (_, status) = os.waitpid(child, 0)
1194     self.assertFalse(os.WIFSIGNALED(status))
1195     self.assertEqual(os.WEXITSTATUS(status), 0)
1196
1197     # Check result
1198     (pid, uid, gid) = serializer.LoadJson(result)
1199     self.assertEqual(pid, os.getpid())
1200     self.assertEqual(uid, os.getuid())
1201     self.assertEqual(gid, os.getgid())
1202
1203
1204 class TestListVisibleFiles(unittest.TestCase):
1205   """Test case for ListVisibleFiles"""
1206
1207   def setUp(self):
1208     self.path = tempfile.mkdtemp()
1209
1210   def tearDown(self):
1211     shutil.rmtree(self.path)
1212
1213   def _test(self, files, expected):
1214     # Sort a copy
1215     expected = expected[:]
1216     expected.sort()
1217
1218     for name in files:
1219       f = open(os.path.join(self.path, name), 'w')
1220       try:
1221         f.write("Test\n")
1222       finally:
1223         f.close()
1224
1225     found = ListVisibleFiles(self.path)
1226     found.sort()
1227
1228     self.assertEqual(found, expected)
1229
1230   def testAllVisible(self):
1231     files = ["a", "b", "c"]
1232     expected = files
1233     self._test(files, expected)
1234
1235   def testNoneVisible(self):
1236     files = [".a", ".b", ".c"]
1237     expected = []
1238     self._test(files, expected)
1239
1240   def testSomeVisible(self):
1241     files = ["a", "b", ".c"]
1242     expected = ["a", "b"]
1243     self._test(files, expected)
1244
1245   def testNonAbsolutePath(self):
1246     self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1247
1248   def testNonNormalizedPath(self):
1249     self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1250                           "/bin/../tmp")
1251
1252
1253 class TestNewUUID(unittest.TestCase):
1254   """Test case for NewUUID"""
1255
1256   _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
1257                         '[a-f0-9]{4}-[a-f0-9]{12}$')
1258
1259   def runTest(self):
1260     self.failUnless(self._re_uuid.match(utils.NewUUID()))
1261
1262
1263 class TestUniqueSequence(unittest.TestCase):
1264   """Test case for UniqueSequence"""
1265
1266   def _test(self, input, expected):
1267     self.assertEqual(utils.UniqueSequence(input), expected)
1268
1269   def runTest(self):
1270     # Ordered input
1271     self._test([1, 2, 3], [1, 2, 3])
1272     self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
1273     self._test([1, 2, 2, 3], [1, 2, 3])
1274     self._test([1, 2, 3, 3], [1, 2, 3])
1275
1276     # Unordered input
1277     self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
1278     self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
1279
1280     # Strings
1281     self._test(["a", "a"], ["a"])
1282     self._test(["a", "b"], ["a", "b"])
1283     self._test(["a", "b", "a"], ["a", "b"])
1284
1285
1286 class TestFirstFree(unittest.TestCase):
1287   """Test case for the FirstFree function"""
1288
1289   def test(self):
1290     """Test FirstFree"""
1291     self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1292     self.failUnlessEqual(FirstFree([]), None)
1293     self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1294     self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1295     self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1296
1297
1298 class TestTailFile(testutils.GanetiTestCase):
1299   """Test case for the TailFile function"""
1300
1301   def testEmpty(self):
1302     fname = self._CreateTempFile()
1303     self.failUnlessEqual(TailFile(fname), [])
1304     self.failUnlessEqual(TailFile(fname, lines=25), [])
1305
1306   def testAllLines(self):
1307     data = ["test %d" % i for i in range(30)]
1308     for i in range(30):
1309       fname = self._CreateTempFile()
1310       fd = open(fname, "w")
1311       fd.write("\n".join(data[:i]))
1312       if i > 0:
1313         fd.write("\n")
1314       fd.close()
1315       self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1316
1317   def testPartialLines(self):
1318     data = ["test %d" % i for i in range(30)]
1319     fname = self._CreateTempFile()
1320     fd = open(fname, "w")
1321     fd.write("\n".join(data))
1322     fd.write("\n")
1323     fd.close()
1324     for i in range(1, 30):
1325       self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1326
1327   def testBigFile(self):
1328     data = ["test %d" % i for i in range(30)]
1329     fname = self._CreateTempFile()
1330     fd = open(fname, "w")
1331     fd.write("X" * 1048576)
1332     fd.write("\n")
1333     fd.write("\n".join(data))
1334     fd.write("\n")
1335     fd.close()
1336     for i in range(1, 30):
1337       self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1338
1339
1340 class _BaseFileLockTest:
1341   """Test case for the FileLock class"""
1342
1343   def testSharedNonblocking(self):
1344     self.lock.Shared(blocking=False)
1345     self.lock.Close()
1346
1347   def testExclusiveNonblocking(self):
1348     self.lock.Exclusive(blocking=False)
1349     self.lock.Close()
1350
1351   def testUnlockNonblocking(self):
1352     self.lock.Unlock(blocking=False)
1353     self.lock.Close()
1354
1355   def testSharedBlocking(self):
1356     self.lock.Shared(blocking=True)
1357     self.lock.Close()
1358
1359   def testExclusiveBlocking(self):
1360     self.lock.Exclusive(blocking=True)
1361     self.lock.Close()
1362
1363   def testUnlockBlocking(self):
1364     self.lock.Unlock(blocking=True)
1365     self.lock.Close()
1366
1367   def testSharedExclusiveUnlock(self):
1368     self.lock.Shared(blocking=False)
1369     self.lock.Exclusive(blocking=False)
1370     self.lock.Unlock(blocking=False)
1371     self.lock.Close()
1372
1373   def testExclusiveSharedUnlock(self):
1374     self.lock.Exclusive(blocking=False)
1375     self.lock.Shared(blocking=False)
1376     self.lock.Unlock(blocking=False)
1377     self.lock.Close()
1378
1379   def testSimpleTimeout(self):
1380     # These will succeed on the first attempt, hence a short timeout
1381     self.lock.Shared(blocking=True, timeout=10.0)
1382     self.lock.Exclusive(blocking=False, timeout=10.0)
1383     self.lock.Unlock(blocking=True, timeout=10.0)
1384     self.lock.Close()
1385
1386   @staticmethod
1387   def _TryLockInner(filename, shared, blocking):
1388     lock = utils.FileLock.Open(filename)
1389
1390     if shared:
1391       fn = lock.Shared
1392     else:
1393       fn = lock.Exclusive
1394
1395     try:
1396       # The timeout doesn't really matter as the parent process waits for us to
1397       # finish anyway.
1398       fn(blocking=blocking, timeout=0.01)
1399     except errors.LockError, err:
1400       return False
1401
1402     return True
1403
1404   def _TryLock(self, *args):
1405     return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1406                                       *args)
1407
1408   def testTimeout(self):
1409     for blocking in [True, False]:
1410       self.lock.Exclusive(blocking=True)
1411       self.failIf(self._TryLock(False, blocking))
1412       self.failIf(self._TryLock(True, blocking))
1413
1414       self.lock.Shared(blocking=True)
1415       self.assert_(self._TryLock(True, blocking))
1416       self.failIf(self._TryLock(False, blocking))
1417
1418   def testCloseShared(self):
1419     self.lock.Close()
1420     self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1421
1422   def testCloseExclusive(self):
1423     self.lock.Close()
1424     self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1425
1426   def testCloseUnlock(self):
1427     self.lock.Close()
1428     self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1429
1430
1431 class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1432   TESTDATA = "Hello World\n" * 10
1433
1434   def setUp(self):
1435     testutils.GanetiTestCase.setUp(self)
1436
1437     self.tmpfile = tempfile.NamedTemporaryFile()
1438     utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1439     self.lock = utils.FileLock.Open(self.tmpfile.name)
1440
1441     # Ensure "Open" didn't truncate file
1442     self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1443
1444   def tearDown(self):
1445     self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1446
1447     testutils.GanetiTestCase.tearDown(self)
1448
1449
1450 class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1451   def setUp(self):
1452     self.tmpfile = tempfile.NamedTemporaryFile()
1453     self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1454
1455
1456 class TestTimeFunctions(unittest.TestCase):
1457   """Test case for time functions"""
1458
1459   def runTest(self):
1460     self.assertEqual(utils.SplitTime(1), (1, 0))
1461     self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1462     self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1463     self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1464     self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1465     self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1466     self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1467     self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1468
1469     self.assertRaises(AssertionError, utils.SplitTime, -1)
1470
1471     self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1472     self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1473     self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1474
1475     self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1476                      1218448917.481)
1477     self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1478
1479     self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1480     self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1481     self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1482     self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1483     self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1484
1485
1486 class FieldSetTestCase(unittest.TestCase):
1487   """Test case for FieldSets"""
1488
1489   def testSimpleMatch(self):
1490     f = utils.FieldSet("a", "b", "c", "def")
1491     self.failUnless(f.Matches("a"))
1492     self.failIf(f.Matches("d"), "Substring matched")
1493     self.failIf(f.Matches("defghi"), "Prefix string matched")
1494     self.failIf(f.NonMatching(["b", "c"]))
1495     self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1496     self.failUnless(f.NonMatching(["a", "d"]))
1497
1498   def testRegexMatch(self):
1499     f = utils.FieldSet("a", "b([0-9]+)", "c")
1500     self.failUnless(f.Matches("b1"))
1501     self.failUnless(f.Matches("b99"))
1502     self.failIf(f.Matches("b/1"))
1503     self.failIf(f.NonMatching(["b12", "c"]))
1504     self.failUnless(f.NonMatching(["a", "1"]))
1505
1506 class TestForceDictType(unittest.TestCase):
1507   """Test case for ForceDictType"""
1508
1509   def setUp(self):
1510     self.key_types = {
1511       'a': constants.VTYPE_INT,
1512       'b': constants.VTYPE_BOOL,
1513       'c': constants.VTYPE_STRING,
1514       'd': constants.VTYPE_SIZE,
1515       }
1516
1517   def _fdt(self, dict, allowed_values=None):
1518     if allowed_values is None:
1519       ForceDictType(dict, self.key_types)
1520     else:
1521       ForceDictType(dict, self.key_types, allowed_values=allowed_values)
1522
1523     return dict
1524
1525   def testSimpleDict(self):
1526     self.assertEqual(self._fdt({}), {})
1527     self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1528     self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1529     self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1530     self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1531     self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1532     self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1533     self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1534     self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1535     self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1536     self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1537     self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1538
1539   def testErrors(self):
1540     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1541     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1542     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1543     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1544
1545
1546 class TestIsNormAbsPath(unittest.TestCase):
1547   """Testing case for IsNormAbsPath"""
1548
1549   def _pathTestHelper(self, path, result):
1550     if result:
1551       self.assert_(IsNormAbsPath(path),
1552           "Path %s should result absolute and normalized" % path)
1553     else:
1554       self.assert_(not IsNormAbsPath(path),
1555           "Path %s should not result absolute and normalized" % path)
1556
1557   def testBase(self):
1558     self._pathTestHelper('/etc', True)
1559     self._pathTestHelper('/srv', True)
1560     self._pathTestHelper('etc', False)
1561     self._pathTestHelper('/etc/../root', False)
1562     self._pathTestHelper('/etc/', False)
1563
1564
1565 class TestSafeEncode(unittest.TestCase):
1566   """Test case for SafeEncode"""
1567
1568   def testAscii(self):
1569     for txt in [string.digits, string.letters, string.punctuation]:
1570       self.failUnlessEqual(txt, SafeEncode(txt))
1571
1572   def testDoubleEncode(self):
1573     for i in range(255):
1574       txt = SafeEncode(chr(i))
1575       self.failUnlessEqual(txt, SafeEncode(txt))
1576
1577   def testUnicode(self):
1578     # 1024 is high enough to catch non-direct ASCII mappings
1579     for i in range(1024):
1580       txt = SafeEncode(unichr(i))
1581       self.failUnlessEqual(txt, SafeEncode(txt))
1582
1583
1584 class TestFormatTime(unittest.TestCase):
1585   """Testing case for FormatTime"""
1586
1587   def testNone(self):
1588     self.failUnlessEqual(FormatTime(None), "N/A")
1589
1590   def testInvalid(self):
1591     self.failUnlessEqual(FormatTime(()), "N/A")
1592
1593   def testNow(self):
1594     # tests that we accept time.time input
1595     FormatTime(time.time())
1596     # tests that we accept int input
1597     FormatTime(int(time.time()))
1598
1599
1600 class RunInSeparateProcess(unittest.TestCase):
1601   def test(self):
1602     for exp in [True, False]:
1603       def _child():
1604         return exp
1605
1606       self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1607
1608   def testArgs(self):
1609     for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1610       def _child(carg1, carg2):
1611         return carg1 == "Foo" and carg2 == arg
1612
1613       self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1614
1615   def testPid(self):
1616     parent_pid = os.getpid()
1617
1618     def _check():
1619       return os.getpid() == parent_pid
1620
1621     self.failIf(utils.RunInSeparateProcess(_check))
1622
1623   def testSignal(self):
1624     def _kill():
1625       os.kill(os.getpid(), signal.SIGTERM)
1626
1627     self.assertRaises(errors.GenericError,
1628                       utils.RunInSeparateProcess, _kill)
1629
1630   def testException(self):
1631     def _exc():
1632       raise errors.GenericError("This is a test")
1633
1634     self.assertRaises(errors.GenericError,
1635                       utils.RunInSeparateProcess, _exc)
1636
1637
1638 class TestFingerprintFile(unittest.TestCase):
1639   def setUp(self):
1640     self.tmpfile = tempfile.NamedTemporaryFile()
1641
1642   def test(self):
1643     self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1644                      "da39a3ee5e6b4b0d3255bfef95601890afd80709")
1645
1646     utils.WriteFile(self.tmpfile.name, data="Hello World\n")
1647     self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1648                      "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1649
1650
1651 class TestUnescapeAndSplit(unittest.TestCase):
1652   """Testing case for UnescapeAndSplit"""
1653
1654   def setUp(self):
1655     # testing more that one separator for regexp safety
1656     self._seps = [",", "+", "."]
1657
1658   def testSimple(self):
1659     a = ["a", "b", "c", "d"]
1660     for sep in self._seps:
1661       self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1662
1663   def testEscape(self):
1664     for sep in self._seps:
1665       a = ["a", "b\\" + sep + "c", "d"]
1666       b = ["a", "b" + sep + "c", "d"]
1667       self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1668
1669   def testDoubleEscape(self):
1670     for sep in self._seps:
1671       a = ["a", "b\\\\", "c", "d"]
1672       b = ["a", "b\\", "c", "d"]
1673       self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1674
1675   def testThreeEscape(self):
1676     for sep in self._seps:
1677       a = ["a", "b\\\\\\" + sep + "c", "d"]
1678       b = ["a", "b\\" + sep + "c", "d"]
1679       self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1680
1681
1682 class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1683   def setUp(self):
1684     self.tmpdir = tempfile.mkdtemp()
1685
1686   def tearDown(self):
1687     shutil.rmtree(self.tmpdir)
1688
1689   def _checkRsaPrivateKey(self, key):
1690     lines = key.splitlines()
1691     return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1692             "-----END RSA PRIVATE KEY-----" in lines)
1693
1694   def _checkCertificate(self, cert):
1695     lines = cert.splitlines()
1696     return ("-----BEGIN CERTIFICATE-----" in lines and
1697             "-----END CERTIFICATE-----" in lines)
1698
1699   def test(self):
1700     for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1701       (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1702       self._checkRsaPrivateKey(key_pem)
1703       self._checkCertificate(cert_pem)
1704
1705       key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1706                                            key_pem)
1707       self.assert_(key.bits() >= 1024)
1708       self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1709       self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1710
1711       x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1712                                              cert_pem)
1713       self.failIf(x509.has_expired())
1714       self.assertEqual(x509.get_issuer().CN, common_name)
1715       self.assertEqual(x509.get_subject().CN, common_name)
1716       self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1717
1718   def testLegacy(self):
1719     cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1720
1721     utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1722
1723     cert1 = utils.ReadFile(cert1_filename)
1724
1725     self.assert_(self._checkRsaPrivateKey(cert1))
1726     self.assert_(self._checkCertificate(cert1))
1727
1728
1729 class TestPathJoin(unittest.TestCase):
1730   """Testing case for PathJoin"""
1731
1732   def testBasicItems(self):
1733     mlist = ["/a", "b", "c"]
1734     self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1735
1736   def testNonAbsPrefix(self):
1737     self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1738
1739   def testBackTrack(self):
1740     self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1741
1742   def testMultiAbs(self):
1743     self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1744
1745
1746 class TestHostInfo(unittest.TestCase):
1747   """Testing case for HostInfo"""
1748
1749   def testUppercase(self):
1750     data = "AbC.example.com"
1751     self.failUnlessEqual(HostInfo.NormalizeName(data), data.lower())
1752
1753   def testTooLongName(self):
1754     data = "a.b." + "c" * 255
1755     self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, data)
1756
1757   def testTrailingDot(self):
1758     data = "a.b.c"
1759     self.failUnlessEqual(HostInfo.NormalizeName(data + "."), data)
1760
1761   def testInvalidName(self):
1762     data = [
1763       "a b",
1764       "a/b",
1765       ".a.b",
1766       "a..b",
1767       ]
1768     for value in data:
1769       self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, value)
1770
1771   def testValidName(self):
1772     data = [
1773       "a.b",
1774       "a-b",
1775       "a_b",
1776       "a.b.c",
1777       ]
1778     for value in data:
1779       HostInfo.NormalizeName(value)
1780
1781
1782 class TestParseAsn1Generalizedtime(unittest.TestCase):
1783   def test(self):
1784     # UTC
1785     self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1786     self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1787                      1266860512)
1788     self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1789                      (2**31) - 1)
1790
1791     # With offset
1792     self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1793                      1266860512)
1794     self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1795                      1266931012)
1796     self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1797                      1266931088)
1798     self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1799                      1266931295)
1800     self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1801                      3600)
1802
1803     # Leap seconds are not supported by datetime.datetime
1804     self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1805                       "19841231235960+0000")
1806     self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1807                       "19920630235960+0000")
1808
1809     # Errors
1810     self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1811     self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1812     self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1813                       "20100222174152")
1814     self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1815                       "Mon Feb 22 17:47:02 UTC 2010")
1816     self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1817                       "2010-02-22 17:42:02")
1818
1819
1820 class TestGetX509CertValidity(testutils.GanetiTestCase):
1821   def setUp(self):
1822     testutils.GanetiTestCase.setUp(self)
1823
1824     pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1825
1826     # Test whether we have pyOpenSSL 0.7 or above
1827     self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1828
1829     if not self.pyopenssl0_7:
1830       warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1831                     " function correctly")
1832
1833   def _LoadCert(self, name):
1834     return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1835                                            self._ReadTestData(name))
1836
1837   def test(self):
1838     validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1839     if self.pyopenssl0_7:
1840       self.assertEqual(validity, (1266919967, 1267524767))
1841     else:
1842       self.assertEqual(validity, (None, None))
1843
1844
1845 class TestSignX509Certificate(unittest.TestCase):
1846   KEY = "My private key!"
1847   KEY_OTHER = "Another key"
1848
1849   def test(self):
1850     # Generate certificate valid for 5 minutes
1851     (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
1852
1853     cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1854                                            cert_pem)
1855
1856     # No signature at all
1857     self.assertRaises(errors.GenericError,
1858                       utils.LoadSignedX509Certificate, cert_pem, self.KEY)
1859
1860     # Invalid input
1861     self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1862                       "", self.KEY)
1863     self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1864                       "X-Ganeti-Signature: \n", self.KEY)
1865     self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1866                       "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
1867     self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1868                       "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
1869     self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1870                       "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
1871
1872     # Invalid salt
1873     for salt in list("-_@$,:;/\\ \t\n"):
1874       self.assertRaises(errors.GenericError, utils.SignX509Certificate,
1875                         cert_pem, self.KEY, "foo%sbar" % salt)
1876
1877     for salt in ["HelloWorld", "salt", string.letters, string.digits,
1878                  utils.GenerateSecret(numbytes=4),
1879                  utils.GenerateSecret(numbytes=16),
1880                  "{123:456}".encode("hex")]:
1881       signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
1882
1883       self._Check(cert, salt, signed_pem)
1884
1885       self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
1886       self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
1887       self._Check(cert, salt, (signed_pem + "\n\na few more\n"
1888                                "lines----\n------ at\nthe end!"))
1889
1890   def _Check(self, cert, salt, pem):
1891     (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
1892     self.assertEqual(salt, salt2)
1893     self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
1894
1895     # Other key
1896     self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1897                       pem, self.KEY_OTHER)
1898
1899
1900 class TestMakedirs(unittest.TestCase):
1901   def setUp(self):
1902     self.tmpdir = tempfile.mkdtemp()
1903
1904   def tearDown(self):
1905     shutil.rmtree(self.tmpdir)
1906
1907   def testNonExisting(self):
1908     path = utils.PathJoin(self.tmpdir, "foo")
1909     utils.Makedirs(path)
1910     self.assert_(os.path.isdir(path))
1911
1912   def testExisting(self):
1913     path = utils.PathJoin(self.tmpdir, "foo")
1914     os.mkdir(path)
1915     utils.Makedirs(path)
1916     self.assert_(os.path.isdir(path))
1917
1918   def testRecursiveNonExisting(self):
1919     path = utils.PathJoin(self.tmpdir, "foo/bar/baz")
1920     utils.Makedirs(path)
1921     self.assert_(os.path.isdir(path))
1922
1923   def testRecursiveExisting(self):
1924     path = utils.PathJoin(self.tmpdir, "B/moo/xyz")
1925     self.assert_(not os.path.exists(path))
1926     os.mkdir(utils.PathJoin(self.tmpdir, "B"))
1927     utils.Makedirs(path)
1928     self.assert_(os.path.isdir(path))
1929
1930
1931 class TestRetry(testutils.GanetiTestCase):
1932   def setUp(self):
1933     testutils.GanetiTestCase.setUp(self)
1934     self.retries = 0
1935
1936   @staticmethod
1937   def _RaiseRetryAgain():
1938     raise utils.RetryAgain()
1939
1940   @staticmethod
1941   def _RaiseRetryAgainWithArg(args):
1942     raise utils.RetryAgain(*args)
1943
1944   def _WrongNestedLoop(self):
1945     return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
1946
1947   def _RetryAndSucceed(self, retries):
1948     if self.retries < retries:
1949       self.retries += 1
1950       raise utils.RetryAgain()
1951     else:
1952       return True
1953
1954   def testRaiseTimeout(self):
1955     self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1956                           self._RaiseRetryAgain, 0.01, 0.02)
1957     self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1958                           self._RetryAndSucceed, 0.01, 0, args=[1])
1959     self.failUnlessEqual(self.retries, 1)
1960
1961   def testComplete(self):
1962     self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
1963     self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
1964                          True)
1965     self.failUnlessEqual(self.retries, 2)
1966
1967   def testNestedLoop(self):
1968     try:
1969       self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
1970                             self._WrongNestedLoop, 0, 1)
1971     except utils.RetryTimeout:
1972       self.fail("Didn't detect inner loop's exception")
1973
1974   def testTimeoutArgument(self):
1975     retry_arg="my_important_debugging_message"
1976     try:
1977       utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
1978     except utils.RetryTimeout, err:
1979       self.failUnlessEqual(err.args, (retry_arg, ))
1980     else:
1981       self.fail("Expected timeout didn't happen")
1982
1983   def testRaiseInnerWithExc(self):
1984     retry_arg="my_important_debugging_message"
1985     try:
1986       try:
1987         utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
1988                     args=[[errors.GenericError(retry_arg, retry_arg)]])
1989       except utils.RetryTimeout, err:
1990         err.RaiseInner()
1991       else:
1992         self.fail("Expected timeout didn't happen")
1993     except errors.GenericError, err:
1994       self.failUnlessEqual(err.args, (retry_arg, retry_arg))
1995     else:
1996       self.fail("Expected GenericError didn't happen")
1997
1998   def testRaiseInnerWithMsg(self):
1999     retry_arg="my_important_debugging_message"
2000     try:
2001       try:
2002         utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2003                     args=[[retry_arg, retry_arg]])
2004       except utils.RetryTimeout, err:
2005         err.RaiseInner()
2006       else:
2007         self.fail("Expected timeout didn't happen")
2008     except utils.RetryTimeout, err:
2009       self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2010     else:
2011       self.fail("Expected RetryTimeout didn't happen")
2012
2013
2014 class TestLineSplitter(unittest.TestCase):
2015   def test(self):
2016     lines = []
2017     ls = utils.LineSplitter(lines.append)
2018     ls.write("Hello World\n")
2019     self.assertEqual(lines, [])
2020     ls.write("Foo\n Bar\r\n ")
2021     ls.write("Baz")
2022     ls.write("Moo")
2023     self.assertEqual(lines, [])
2024     ls.flush()
2025     self.assertEqual(lines, ["Hello World", "Foo", " Bar"])
2026     ls.close()
2027     self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"])
2028
2029   def _testExtra(self, line, all_lines, p1, p2):
2030     self.assertEqual(p1, 999)
2031     self.assertEqual(p2, "extra")
2032     all_lines.append(line)
2033
2034   def testExtraArgsNoFlush(self):
2035     lines = []
2036     ls = utils.LineSplitter(self._testExtra, lines, 999, "extra")
2037     ls.write("\n\nHello World\n")
2038     ls.write("Foo\n Bar\r\n ")
2039     ls.write("")
2040     ls.write("Baz")
2041     ls.write("Moo\n\nx\n")
2042     self.assertEqual(lines, [])
2043     ls.close()
2044     self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo",
2045                              "", "x"])
2046
2047
2048 class TestReadLockedPidFile(unittest.TestCase):
2049   def setUp(self):
2050     self.tmpdir = tempfile.mkdtemp()
2051
2052   def tearDown(self):
2053     shutil.rmtree(self.tmpdir)
2054
2055   def testNonExistent(self):
2056     path = utils.PathJoin(self.tmpdir, "nonexist")
2057     self.assert_(utils.ReadLockedPidFile(path) is None)
2058
2059   def testUnlocked(self):
2060     path = utils.PathJoin(self.tmpdir, "pid")
2061     utils.WriteFile(path, data="123")
2062     self.assert_(utils.ReadLockedPidFile(path) is None)
2063
2064   def testLocked(self):
2065     path = utils.PathJoin(self.tmpdir, "pid")
2066     utils.WriteFile(path, data="123")
2067
2068     fl = utils.FileLock.Open(path)
2069     try:
2070       fl.Exclusive(blocking=True)
2071
2072       self.assertEqual(utils.ReadLockedPidFile(path), 123)
2073     finally:
2074       fl.Close()
2075
2076     self.assert_(utils.ReadLockedPidFile(path) is None)
2077
2078   def testError(self):
2079     path = utils.PathJoin(self.tmpdir, "foobar", "pid")
2080     utils.WriteFile(utils.PathJoin(self.tmpdir, "foobar"), data="")
2081     # open(2) should return ENOTDIR
2082     self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
2083
2084
2085 class TestCertVerification(testutils.GanetiTestCase):
2086   def setUp(self):
2087     testutils.GanetiTestCase.setUp(self)
2088
2089     self.tmpdir = tempfile.mkdtemp()
2090
2091   def tearDown(self):
2092     shutil.rmtree(self.tmpdir)
2093
2094   def testVerifyCertificate(self):
2095     cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
2096     cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
2097                                            cert_pem)
2098
2099     # Not checking return value as this certificate is expired
2100     utils.VerifyX509Certificate(cert, 30, 7)
2101
2102
2103 class TestVerifyCertificateInner(unittest.TestCase):
2104   def test(self):
2105     vci = utils._VerifyCertificateInner
2106
2107     # Valid
2108     self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
2109                      (None, None))
2110
2111     # Not yet valid
2112     (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
2113     self.assertEqual(errcode, utils.CERT_WARNING)
2114
2115     # Expiring soon
2116     (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
2117     self.assertEqual(errcode, utils.CERT_ERROR)
2118
2119     (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
2120     self.assertEqual(errcode, utils.CERT_WARNING)
2121
2122     (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
2123     self.assertEqual(errcode, None)
2124
2125     # Expired
2126     (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
2127     self.assertEqual(errcode, utils.CERT_ERROR)
2128
2129     (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
2130     self.assertEqual(errcode, utils.CERT_ERROR)
2131
2132     (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
2133     self.assertEqual(errcode, utils.CERT_ERROR)
2134
2135     (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
2136     self.assertEqual(errcode, utils.CERT_ERROR)
2137
2138
2139 class TestHmacFunctions(unittest.TestCase):
2140   # Digests can be checked with "openssl sha1 -hmac $key"
2141   def testSha1Hmac(self):
2142     self.assertEqual(utils.Sha1Hmac("", ""),
2143                      "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d")
2144     self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World"),
2145                      "ef4f3bda82212ecb2f7ce868888a19092481f1fd")
2146     self.assertEqual(utils.Sha1Hmac("TguMTA2K", ""),
2147                      "f904c2476527c6d3e6609ab683c66fa0652cb1dc")
2148
2149     longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
2150     self.assertEqual(utils.Sha1Hmac("3YzMxZWE", longtext),
2151                      "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54")
2152
2153   def testSha1HmacSalt(self):
2154     self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc0"),
2155                      "4999bf342470eadb11dfcd24ca5680cf9fd7cdce")
2156     self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc9"),
2157                      "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8")
2158     self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World", salt="xyz0"),
2159                      "7f264f8114c9066afc9bb7636e1786d996d3cc0d")
2160
2161   def testVerifySha1Hmac(self):
2162     self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b"
2163                                                "7d64b71fb76370690e1d")))
2164     self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2165                                       ("f904c2476527c6d3e660"
2166                                        "9ab683c66fa0652cb1dc")))
2167
2168     digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
2169     self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", digest))
2170     self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2171                                       digest.lower()))
2172     self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2173                                       digest.upper()))
2174     self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2175                                       digest.title()))
2176
2177   def testVerifySha1HmacSalt(self):
2178     self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2179                                       ("17a4adc34d69c0d367d4"
2180                                        "ffbef96fd41d4df7a6e8"),
2181                                       salt="abc9"))
2182     self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2183                                       ("7f264f8114c9066afc9b"
2184                                        "b7636e1786d996d3cc0d"),
2185                                       salt="xyz0"))
2186
2187
2188 if __name__ == '__main__':
2189   testutils.GanetiTestProgram()