Merge branch 'devel-2.1'
[ganeti-local] / test / ganeti.utils_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the utils module"""
23
24 import unittest
25 import os
26 import time
27 import tempfile
28 import os.path
29 import os
30 import stat
31 import signal
32 import socket
33 import shutil
34 import re
35 import select
36 import string
37 import fcntl
38 import OpenSSL
39 import warnings
40 import distutils.version
41 import glob
42
43 import ganeti
44 import testutils
45 from ganeti import constants
46 from ganeti import utils
47 from ganeti import errors
48 from ganeti import serializer
49 from ganeti.utils import IsProcessAlive, RunCmd, \
50      RemoveFile, MatchNameComponent, FormatUnit, \
51      ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
52      ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
53      SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
54      TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
55      UnescapeAndSplit, RunParts, PathJoin, HostInfo
56
57 from ganeti.errors import LockError, UnitParseError, GenericError, \
58      ProgrammerError, OpPrereqError
59
60
class TestIsProcessAlive(unittest.TestCase):
  """Testing case for IsProcessAlive"""

  def testExists(self):
    """The current process must be reported as alive."""
    mypid = os.getpid()
    self.assert_(IsProcessAlive(mypid),
                 "can't find myself running")

  def testNotExisting(self):
    """A child that has exited and been reaped must not be reported alive."""
    # os.fork() raises OSError on failure instead of returning a negative
    # value, so only the child (0) and parent (>0) cases exist
    pid_non_existing = os.fork()
    if pid_non_existing == 0:
      os._exit(0)
    # Reap the child so its PID no longer refers to any (zombie) process
    os.waitpid(pid_non_existing, 0)
    self.assert_(not IsProcessAlive(pid_non_existing),
                 "nonexisting process detected")
78
79
class TestPidFileFunctions(unittest.TestCase):
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""

  def setUp(self):
    self.dir = tempfile.mkdtemp()
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
    # Redirect pid file names into our temporary directory; the original
    # function is restored in tearDown so other tests are not affected
    self._orig_dpn = utils.DaemonPidFileName
    utils.DaemonPidFileName = self.f_dpn

  def testPidFileFunctions(self):
    pid_file = self.f_dpn('test')
    utils.WritePidFile('test')
    self.failUnless(os.path.exists(pid_file),
                    "PID file should have been created")
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, os.getpid())
    self.failUnless(utils.IsProcessAlive(read_pid))
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for missing pid file")
    fh = open(pid_file, "w")
    try:
      fh.write("blah\n")
    finally:
      fh.close()
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for invalid pid file")
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")

  def testKill(self):
    pid_file = self.f_dpn('child')
    r_fd, w_fd = os.pipe()
    new_pid = os.fork()
    if new_pid == 0: #child
      # Whatever happens, the child must leave through os._exit so that an
      # exception never propagates back into the test framework
      status = 1
      try:
        os.close(r_fd)
        utils.WritePidFile('child')
        os.write(w_fd, 'a')
        signal.pause()
        status = 0
      finally:
        os._exit(status)
    # Parent: close our copy of the write end so that the read below gets
    # EOF instead of blocking forever if the child dies before writing
    os.close(w_fd)
    try:
      # wait until the child has written the pid file
      self.failUnlessEqual(os.read(r_fd, 1), 'a')
    finally:
      os.close(r_fd)
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, new_pid)
    self.failUnless(utils.IsProcessAlive(new_pid))
    utils.KillProcess(new_pid, waitpid=True)
    self.failIf(utils.IsProcessAlive(new_pid))
    utils.RemovePidFile('child')
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)

  def tearDown(self):
    utils.DaemonPidFileName = self._orig_dpn
    shutil.rmtree(self.dir)
136
137
class TestRunCmd(testutils.GanetiTestCase):
  """Testing case for the RunCmd function"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.magic = time.ctime() + " ganeti test"
    self.fname = self._CreateTempFile()

  def testOk(self):
    """Test successful exit code"""
    result = RunCmd("/bin/sh -c 'exit 0'")
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.output, "")

  def testFail(self):
    """Test fail exit code"""
    result = RunCmd("/bin/sh -c 'exit 1'")
    self.assertEqual(result.exit_code, 1)
    self.assertEqual(result.output, "")

  def testStdout(self):
    """Test standard output"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.stdout, self.magic)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testStderr(self):
    """Test standard error"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
    self.assertEqual(result.stderr, self.magic)
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testCombined(self):
    """Test combined output"""
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
    expected = "A" + self.magic + "B" + self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.output, expected)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, expected)

  def testSignal(self):
    """Test signal"""
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
    self.assertEqual(result.signal, 15)
    self.assertEqual(result.output, "")

  def testListRun(self):
    """Test list runs"""
    result = RunCmd(["true"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 1)
    result = RunCmd(["echo", "-n", self.magic])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.stdout, self.magic)

  def testFileEmptyOutput(self):
    """Test file output"""
    result = RunCmd(["true"], output=self.fname)
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertFileContent(self.fname, "")

  def testLang(self):
    """Test locale environment"""
    # Remember the previous values (None when unset) and restore them in
    # place afterwards; rebinding os.environ to a plain dict copy would
    # detach it from the real process environment for the remaining tests
    saved = {}
    for name in ("LANG", "LC_ALL"):
      saved[name] = os.environ.get(name)
    try:
      os.environ["LANG"] = "en_US.UTF-8"
      os.environ["LC_ALL"] = "en_US.UTF-8"
      result = RunCmd(["locale"])
      for line in result.output.splitlines():
        key, value = line.split("=", 1)
        # Ignore these variables, they're overridden by LC_ALL
        if key == "LANG" or key == "LANGUAGE":
          continue
        self.failIf(value and value != "C" and value != '"C"',
            "Variable %s is set to the invalid value '%s'" % (key, value))
    finally:
      for (name, old_value) in saved.items():
        if old_value is None:
          if name in os.environ:
            del os.environ[name]
        else:
          os.environ[name] = old_value

  def testDefaultCwd(self):
    """Test default working directory"""
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")

  def testCwd(self):
    """Test explicitly given working directory"""
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
    cwd = os.getcwd()
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)

  def testResetEnv(self):
    """Test environment reset functionality"""
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
245
246
class TestRunParts(unittest.TestCase):
  """Testing case for the RunParts function"""

  def setUp(self):
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")

  def tearDown(self):
    shutil.rmtree(self.rundir)

  def _Write(self, name, data="", executable=False):
    """Create C{name} in the run directory and return its full path.

    When C{executable} is true the file mode is set to read+execute.

    """
    path = os.path.join(self.rundir, name)
    utils.WriteFile(path, data=data)
    if executable:
      os.chmod(path, stat.S_IREAD | stat.S_IEXEC)
    return path

  def _Run(self):
    """Run all parts in the run directory with a reset environment."""
    return RunParts(self.rundir, reset_env=True)

  def testEmpty(self):
    """Test on an empty dir"""
    self.failUnlessEqual(self._Run(), [])

  def testSkipWrongName(self):
    """Test that wrong files are skipped"""
    path = self._Write("00test.dot", executable=True)
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.failUnlessEqual(self._Run(), expected)

  def testSkipNonExec(self):
    """Test that non executable files are skipped"""
    path = self._Write("00test")
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.failUnlessEqual(self._Run(), expected)

  def testError(self):
    """Test error on a broken executable"""
    path = self._Write("00test", executable=True)
    (relname, status, error) = self._Run()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
    self.failUnless(error)

  def testSorted(self):
    """Test executions are sorted"""
    names = ["64test", "00test", "42test"]
    for name in names:
      self._Write(name)

    results = self._Run()

    for name in sorted(names):
      self.failUnlessEqual(name, results.pop(0)[0])

  def testOk(self):
    """Test correct execution"""
    path = self._Write("00test", data="#!/bin/sh\n\necho -n ciao",
                       executable=True)
    (relname, status, runresult) = self._Run()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.stdout, "ciao")

  def testRunFail(self):
    """Test correct execution, with run failure"""
    path = self._Write("00test", data="#!/bin/sh\n\nexit 1", executable=True)
    (relname, status, runresult) = self._Run()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.exit_code, 1)
    self.failUnless(runresult.failed)

  def testRunMix(self):
    """Test a mix of failing, skipped, broken and working parts"""
    names = sorted(["00test", "42test", "64test", "99test"])

    # 1st has errors in execution
    self._Write(names[0], data="#!/bin/sh\n\nexit 1", executable=True)
    # 2nd is skipped
    self._Write(names[1])
    # 3rd cannot execute properly
    self._Write(names[2], executable=True)
    # 4th execs
    self._Write(names[3], data="#!/bin/sh\n\necho -n ciao", executable=True)

    results = self._Run()

    (relname, status, runresult) = results[0]
    self.failUnlessEqual(relname, names[0])
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.exit_code, 1)
    self.failUnless(runresult.failed)

    (relname, status, runresult) = results[1]
    self.failUnlessEqual(relname, names[1])
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
    self.failUnlessEqual(runresult, None)

    (relname, status, runresult) = results[2]
    self.failUnlessEqual(relname, names[2])
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
    self.failUnless(runresult)

    (relname, status, runresult) = results[3]
    self.failUnlessEqual(relname, names[3])
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.output, "ciao")
    self.failUnlessEqual(runresult.exit_code, 0)
    self.failUnless(not runresult.failed)
371
372
class TestStartDaemon(testutils.GanetiTestCase):
  """Tests for utils.StartDaemon"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
    self.tmpfile = os.path.join(self.tmpdir, "test")

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def testShell(self):
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testShellOutput(self):
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testNoShellNoOutput(self):
    utils.StartDaemon(["pwd"])

  def testNoShellNoOutputTouch(self):
    testfile = os.path.join(self.tmpdir, "check")
    self.failIf(os.path.exists(testfile))
    utils.StartDaemon(["touch", testfile])
    self._wait(testfile, 60.0, "")

  def testNoShellOutput(self):
    utils.StartDaemon(["pwd"], output=self.tmpfile)
    self._wait(self.tmpfile, 60.0, "/")

  def testNoShellOutputCwd(self):
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
    self._wait(self.tmpfile, 60.0, os.getcwd())

  def testShellEnv(self):
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
                      env={ "GNT_TEST_VAR": "Hello World", })
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testNoShellEnv(self):
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
                      env={ "GNT_TEST_VAR": "Hello World", })
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testOutputFd(self):
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
    try:
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
    finally:
      os.close(fd)
    self._wait(self.tmpfile, 60.0, os.getcwd())

  def testPid(self):
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
    self._wait(self.tmpfile, 60.0, str(pid))

  def testPidFile(self):
    pidfile = os.path.join(self.tmpdir, "pid")

    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
                            output=self.tmpfile)
    try:
      fd = os.open(pidfile, os.O_RDONLY)
      try:
        # Check file is locked
        self.assertRaises(errors.LockError, utils.LockFile, fd)

        pidtext = os.read(fd, 100)
      finally:
        os.close(fd)

      self.assertEqual(int(pidtext.strip()), pid)

      self.assert_(utils.IsProcessAlive(pid))
    finally:
      # No matter what happens, kill daemon
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
      self.failIf(utils.IsProcessAlive(pid))

    self.assertEqual(utils.ReadFile(self.tmpfile), "")

  def _wait(self, path, timeout, expected):
    """Poll until C{path} is a file whose stripped content is C{expected}.

    Fails the test if that does not happen within C{timeout} seconds.

    """
    # Due to the asynchronous nature of daemon processes, polling is necessary.
    # A timeout makes sure the test doesn't hang forever.
    def _CheckFile():
      if not (os.path.isfile(path) and
              utils.ReadFile(path).strip() == expected):
        raise utils.RetryAgain()

    try:
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
    except utils.RetryTimeout:
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
                " didn't write the correct output" % timeout)

  def testError(self):
    nonexisting = "./does-NOT-EXIST/here/0123456789"
    missing_dir = os.path.join(self.tmpdir, "DIR/NOT/EXIST")

    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      [nonexisting])
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      [nonexisting], output=missing_dir)
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      [nonexisting], cwd=missing_dir)

    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
    try:
      # Passing both output and output_fd must be rejected
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
                        [nonexisting], output=self.tmpfile, output_fd=fd)
    finally:
      os.close(fd)
488
489
class TestSetCloseOnExecFlag(unittest.TestCase):
  """Tests for SetCloseOnExecFlag"""

  def setUp(self):
    self.tmpfile = tempfile.TemporaryFile()

  def tearDown(self):
    # Release the file descriptor instead of relying on garbage collection
    self.tmpfile.close()

  def _GetFlag(self):
    """Return the FD_CLOEXEC bit of the temporary file's descriptor flags."""
    return fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) & fcntl.FD_CLOEXEC

  def testEnable(self):
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
    self.failUnless(self._GetFlag())

  def testDisable(self):
    fd = self.tmpfile.fileno()
    # Set the flag explicitly first so the test observes an actual reset
    # rather than a flag that may never have been set
    fcntl.fcntl(fd, fcntl.F_SETFD,
                fcntl.fcntl(fd, fcntl.F_GETFD) | fcntl.FD_CLOEXEC)
    self.failUnless(self._GetFlag())
    utils.SetCloseOnExecFlag(fd, False)
    self.failIf(self._GetFlag())
505
506
class TestSetNonblockFlag(unittest.TestCase):
  """Tests for SetNonblockFlag"""

  def setUp(self):
    self.tmpfile = tempfile.TemporaryFile()

  def tearDown(self):
    # Release the file descriptor instead of relying on garbage collection
    self.tmpfile.close()

  def _GetFlag(self):
    """Return the O_NONBLOCK bit of the temporary file's status flags."""
    return fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) & os.O_NONBLOCK

  def testEnable(self):
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
    self.failUnless(self._GetFlag())

  def testDisable(self):
    fd = self.tmpfile.fileno()
    # Set the flag explicitly first so the test observes an actual reset
    # rather than a flag that may never have been set
    fcntl.fcntl(fd, fcntl.F_SETFL,
                fcntl.fcntl(fd, fcntl.F_GETFL) | os.O_NONBLOCK)
    self.failUnless(self._GetFlag())
    utils.SetNonblockFlag(fd, False)
    self.failIf(self._GetFlag())
520
521
class TestRemoveFile(unittest.TestCase):
  """Test case for the RemoveFile function"""

  def setUp(self):
    """Create a temp dir and file for each case"""
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
    os.close(fd)

  def tearDown(self):
    # rmtree also copes with leftovers (e.g. a symlink) from a failed test,
    # where a plain rmdir would fail on a non-empty directory
    shutil.rmtree(self.tmpdir)

  def testIgnoreDirs(self):
    """Test that RemoveFile() ignores directories"""
    self.assertEqual(None, RemoveFile(self.tmpdir))

  def testIgnoreNotExisting(self):
    """Test that RemoveFile() ignores non-existing files"""
    RemoveFile(self.tmpfile)
    RemoveFile(self.tmpfile)

  def testRemoveFile(self):
    """Test that RemoveFile does remove a file"""
    RemoveFile(self.tmpfile)
    if os.path.lexists(self.tmpfile):
      self.fail("File '%s' not removed" % self.tmpfile)

  def testRemoveSymlink(self):
    """Test that RemoveFile does remove symlinks"""
    symlink = os.path.join(self.tmpdir, "symlink")
    # os.path.exists() follows links and reports False for a dangling link
    # even if the link itself is still there, so lexists() is required
    os.symlink("no-such-file", symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
    os.symlink(self.tmpfile, symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
    # Removing the link must leave its target alone
    self.failUnless(os.path.exists(self.tmpfile))
562
563
class TestRename(unittest.TestCase):
  """Test case for RenameFile"""

  def setUp(self):
    """Create a temporary directory holding one empty file"""
    self.tmpdir = tempfile.mkdtemp()
    self.tmpfile = os.path.join(self.tmpdir, "test1")
    handle = open(self.tmpfile, "w")
    handle.close()

  def tearDown(self):
    """Remove the temporary directory and all its content"""
    shutil.rmtree(self.tmpdir)

  def _Path(self, relpath):
    """Return C{relpath} joined onto the temporary directory."""
    return os.path.join(self.tmpdir, relpath)

  def testSimpleRename1(self):
    """Simple rename 1"""
    target = self._Path("xyz")
    utils.RenameFile(self.tmpfile, target)
    self.assert_(os.path.isfile(target))

  def testSimpleRename2(self):
    """Simple rename 2"""
    target = self._Path("xyz")
    utils.RenameFile(self.tmpfile, target, mkdir=True)
    self.assert_(os.path.isfile(target))

  def testRenameMkdir(self):
    """Rename with mkdir"""
    first = self._Path("test/xyz")
    utils.RenameFile(self.tmpfile, first, mkdir=True)
    self.assert_(os.path.isdir(self._Path("test")))
    self.assert_(os.path.isfile(first))

    second = self._Path("test/foo/bar/baz")
    utils.RenameFile(first, second, mkdir=True)
    for dirname in ("test", "test/foo/bar"):
      self.assert_(os.path.isdir(self._Path(dirname)))
    self.assert_(os.path.isfile(second))
603
604
class TestMatchNameComponent(unittest.TestCase):
  """Test case for the MatchNameComponent function"""

  def testEmptyList(self):
    """Test that there is no match against an empty list"""
    for key in ("", "test"):
      self.failUnlessEqual(MatchNameComponent(key, []), None)

  def testSingleMatch(self):
    """Test that a single match is performed correctly"""
    names = ["test1.example.com", "test2.example.com", "test3.example.com"]
    wanted = names[1]
    for key in ("test2", "test2.example", "test2.example.com"):
      self.failUnlessEqual(MatchNameComponent(key, names), wanted)

  def testMultipleMatches(self):
    """Test that a multiple match is returned as None"""
    names = ["test1.example.com", "test1.example.org", "test1.example.net"]
    for key in ("test1", "test1.example"):
      self.failUnlessEqual(MatchNameComponent(key, names), None)

  def testFullMatch(self):
    """Test that a full match is returned correctly"""
    short_key = "test1"
    long_key = "test1.example"
    names = [long_key, long_key + ".com"]
    self.failUnlessEqual(MatchNameComponent(short_key, names), None)
    self.failUnlessEqual(MatchNameComponent(long_key, names), long_key)

  def testCaseInsensitivePartialMatch(self):
    """Test for the case_insensitive keyword"""
    names = ["test1.example.com", "test2.example.net"]
    for key in ("test2", "Test2", "teSt2", "TeSt2"):
      self.assertEqual(MatchNameComponent(key, names, case_sensitive=False),
                       "test2.example.net")

  def testCaseInsensitiveFullMatch(self):
    """Test full-name matching when case_sensitive is False"""
    names = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
    cases = [
      # "Ts1" prefixes both ts1 entries and is ambiguous, but a full name
      # match (regardless of case) picks the right one
      ("Ts1", None),
      ("Ts1.ex", "ts1.ex"),
      ("ts1.ex", "ts1.ex"),
      # The two ts2 entries differ only in case, so only an exact-case
      # match resolves them
      ("ts2.ex", "ts2.ex"),
      ("Ts2.ex", "Ts2.ex"),
      ("TS2.ex", None),
      ]
    for (key, wanted) in cases:
      self.assertEqual(MatchNameComponent(key, names, case_sensitive=False),
                       wanted)
663
664
class TestTimestampForFilename(unittest.TestCase):
  """Tests for TimestampForFilename"""

  def test(self):
    # The generated timestamp must contain neither "." nor ":"
    for forbidden in (".", ":"):
      self.assert_(forbidden not in utils.TimestampForFilename())
669
670
class TestCreateBackup(testutils.GanetiTestCase):
  """Tests for CreateBackup"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def _CountWithPrefix(self, prefix):
    """Return how many paths start with C{prefix} (original + backups)."""
    return len(glob.glob("%s*" % prefix))

  def testEmpty(self):
    filename = utils.PathJoin(self.tmpdir, "config.data")
    utils.WriteFile(filename, data="")
    backup_name = utils.CreateBackup(filename)
    self.assertFileContent(backup_name, "")
    self.assertEqual(self._CountWithPrefix(filename), 2)
    for wanted_count in (3, 4):
      utils.CreateBackup(filename)
      self.assertEqual(self._CountWithPrefix(filename), wanted_count)

    # Non-regular files cannot be backed up
    fifoname = utils.PathJoin(self.tmpdir, "fifo")
    os.mkfifo(fifoname)
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)

  def testContent(self):
    filename = utils.PathJoin(self.tmpdir, "test.data_")
    backups = 0
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
      for rep in [1, 2, 10, 127]:
        payload = data * rep

        utils.WriteFile(filename, data=payload)
        self.assertFileContent(filename, payload)

        for _ in range(3):
          backup_name = utils.CreateBackup(filename)
          backups += 1
          self.assertFileContent(backup_name, payload)
          # every backup made so far plus the original file itself
          self.assertEqual(self._CountWithPrefix(filename), backups + 1)
712
713
class TestFormatUnit(unittest.TestCase):
  """Test case for the FormatUnit function"""

  def _CheckCases(self, units, cases):
    """Assert FormatUnit(value, units) == text for each (value, text)."""
    for (value, text) in cases:
      self.assertEqual(FormatUnit(value, units), text)

  def testMiB(self):
    self._CheckCases('h', [
      (1, '1M'), (100, '100M'), (1023, '1023M'),
      ])
    self._CheckCases('m', [
      (1, '1'), (100, '100'), (1023, '1023'),
      (1024, '1024'), (1536, '1536'), (17133, '17133'),
      (1024 * 1024 - 1, '1048575'),
      ])

  def testGiB(self):
    self._CheckCases('h', [
      (1024, '1.0G'), (1536, '1.5G'), (17133, '16.7G'),
      (1024 * 1024 - 1, '1024.0G'),
      ])
    self._CheckCases('g', [
      (1024, '1.0'), (1536, '1.5'), (17133, '16.7'),
      (1024 * 1024 - 1, '1024.0'),
      (1024 * 1024, '1024.0'), (5120 * 1024, '5120.0'),
      (29829 * 1024, '29829.0'),
      ])

  def testTiB(self):
    self._CheckCases('h', [
      (1024 * 1024, '1.0T'), (5120 * 1024, '5.0T'), (29829 * 1024, '29.1T'),
      ])
    self._CheckCases('t', [
      (1024 * 1024, '1.0'), (5120 * 1024, '5.0'), (29829 * 1024, '29.1'),
      ])
754
class TestParseUnit(unittest.TestCase):
  """Test case for the ParseUnit function"""

  # (suffix, multiplier) pairs; results are expressed in MiB
  SCALES = (('', 1),
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))

  def _CheckParsed(self, cases):
    """Assert ParseUnit(text) == expected for each (text, expected)."""
    for (text, expected) in cases:
      self.assertEqual(ParseUnit(text), expected)

  def testRounding(self):
    # Results are rounded up to the next multiple of 4
    self._CheckParsed([
      ('0', 0), ('1', 4), ('2', 4), ('3', 4),
      ('124', 124), ('125', 128), ('126', 128), ('127', 128),
      ('128', 128), ('129', 132), ('130', 132),
      ])

  def testFloating(self):
    self._CheckParsed([
      ('0', 0), ('0.5', 4), ('1.75', 4), ('1.99', 4), ('2.00', 4),
      ('2.01', 4), ('3.99', 4), ('4.00', 4), ('4.01', 8),
      ('1.5G', 1536), ('1.8G', 1844), ('8.28T', 8682212),
      ])

  def testSuffixes(self):
    separators = ('', ' ', '   ', "\t", "\t ")
    transforms = (lambda x: x, str.lower, str.upper)
    for sep in separators:
      for (suffix, scale) in self.SCALES:
        for transform in transforms:
          self.assertEqual(ParseUnit('1024' + sep + transform(suffix)),
                           1024 * scale)

  def testInvalidInput(self):
    bad_inputs = []
    for (suffix, _) in self.SCALES:
      for sep in ('-', '_', ',', 'a'):
        bad_inputs.append('1' + sep + suffix)
      bad_inputs.append('1,3' + suffix)
    for text in bad_inputs:
      self.assertRaises(UnitParseError, ParseUnit, text)
805
806
class TestSshKeys(testutils.GanetiTestCase):
  """Test case for the AddAuthorizedKey and RemoveAuthorizedKey functions"""

  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    # Seed the file with both keys, one per line
    handle = open(self.tmpname, 'w')
    try:
      for key in (TestSshKeys.KEY_A, TestSshKeys.KEY_B):
        handle.write("%s\n" % key)
    finally:
      handle.close()

  def testAddingNewKey(self):
    new_key = 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test'
    AddAuthorizedKey(self.tmpname, new_key)

    # An unknown key is appended after the existing ones
    expected = ("ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
                'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
                " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
                "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
    self.assertFileContent(self.tmpname, expected)

  def testAddingAlmostButNotCompletelyTheSameKey(self):
    similar_key = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test'
    AddAuthorizedKey(self.tmpname, similar_key)

    # Same key material with a different comment is added as a new line
    expected = ("ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
                'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
                " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
                "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
    self.assertFileContent(self.tmpname, expected)

  def testAddingExistingKeyWithSomeMoreSpaces(self):
    spaced_key = 'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a'
    AddAuthorizedKey(self.tmpname, spaced_key)

    # Extra whitespace does not make the key different: nothing is added
    expected = ("ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
                'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
                " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
    self.assertFileContent(self.tmpname, expected)

  def testRemovingExistingKeyWithSomeMoreSpaces(self):
    spaced_key = 'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a'
    RemoveAuthorizedKey(self.tmpname, spaced_key)

    # Removal also ignores whitespace differences
    expected = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
                " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
    self.assertFileContent(self.tmpname, expected)

  def testRemovingNonExistingKey(self):
    unknown_key = 'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test'
    RemoveAuthorizedKey(self.tmpname, unknown_key)

    # Removing an unknown key leaves the file untouched
    expected = ("ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
                'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
                " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
    self.assertFileContent(self.tmpname, expected)
868
869
870 class TestEtcHosts(testutils.GanetiTestCase):
871   """Test functions modifying /etc/hosts"""
872
873   def setUp(self):
874     testutils.GanetiTestCase.setUp(self)
875     self.tmpname = self._CreateTempFile()
876     handle = open(self.tmpname, 'w')
877     try:
878       handle.write('# This is a test file for /etc/hosts\n')
879       handle.write('127.0.0.1\tlocalhost\n')
880       handle.write('192.168.1.1 router gw\n')
881     finally:
882       handle.close()
883
884   def testSettingNewIp(self):
885     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
886
887     self.assertFileContent(self.tmpname,
888       "# This is a test file for /etc/hosts\n"
889       "127.0.0.1\tlocalhost\n"
890       "192.168.1.1 router gw\n"
891       "1.2.3.4\tmyhost.domain.tld myhost\n")
892     self.assertFileMode(self.tmpname, 0644)
893
894   def testSettingExistingIp(self):
895     SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
896                      ['myhost'])
897
898     self.assertFileContent(self.tmpname,
899       "# This is a test file for /etc/hosts\n"
900       "127.0.0.1\tlocalhost\n"
901       "192.168.1.1\tmyhost.domain.tld myhost\n")
902     self.assertFileMode(self.tmpname, 0644)
903
904   def testSettingDuplicateName(self):
905     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
906
907     self.assertFileContent(self.tmpname,
908       "# This is a test file for /etc/hosts\n"
909       "127.0.0.1\tlocalhost\n"
910       "192.168.1.1 router gw\n"
911       "1.2.3.4\tmyhost\n")
912     self.assertFileMode(self.tmpname, 0644)
913
914   def testRemovingExistingHost(self):
915     RemoveEtcHostsEntry(self.tmpname, 'router')
916
917     self.assertFileContent(self.tmpname,
918       "# This is a test file for /etc/hosts\n"
919       "127.0.0.1\tlocalhost\n"
920       "192.168.1.1 gw\n")
921     self.assertFileMode(self.tmpname, 0644)
922
923   def testRemovingSingleExistingHost(self):
924     RemoveEtcHostsEntry(self.tmpname, 'localhost')
925
926     self.assertFileContent(self.tmpname,
927       "# This is a test file for /etc/hosts\n"
928       "192.168.1.1 router gw\n")
929     self.assertFileMode(self.tmpname, 0644)
930
931   def testRemovingNonExistingHost(self):
932     RemoveEtcHostsEntry(self.tmpname, 'myhost')
933
934     self.assertFileContent(self.tmpname,
935       "# This is a test file for /etc/hosts\n"
936       "127.0.0.1\tlocalhost\n"
937       "192.168.1.1 router gw\n")
938     self.assertFileMode(self.tmpname, 0644)
939
940   def testRemovingAlias(self):
941     RemoveEtcHostsEntry(self.tmpname, 'gw')
942
943     self.assertFileContent(self.tmpname,
944       "# This is a test file for /etc/hosts\n"
945       "127.0.0.1\tlocalhost\n"
946       "192.168.1.1 router\n")
947     self.assertFileMode(self.tmpname, 0644)
948
949
class TestShellQuoting(unittest.TestCase):
  """Tests for ShellQuote and ShellQuoteArgs"""

  def testShellQuote(self):
    # (raw input, expected shell-quoted output)
    cases = [
      ('abc', "abc"),
      ('ab"c', "'ab\"c'"),
      ("a'bc", "'a'\\''bc'"),
      ("a b c", "'a b c'"),
      ("a b\\ c", "'a b\\ c'"),
      ]
    for (raw, quoted) in cases:
      self.assertEqual(ShellQuote(raw), quoted)

  def testShellQuoteArgs(self):
    # (argument list, expected command line)
    cases = [
      (['a', 'b', 'c'], "a b c"),
      (['a', 'b"', 'c'], "a 'b\"' c"),
      (['a', 'b\'', 'c'], "a 'b'\\\''' c"),
      ]
    for (args, cmdline) in cases:
      self.assertEqual(ShellQuoteArgs(args), cmdline)
964
965
class TestTcpPing(unittest.TestCase):
  """Testcase for TCP version of ping - against listen(2)ing port"""

  def setUp(self):
    # Listen on an ephemeral loopback port so connect(2) can succeed
    self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
    self.listenerport = self.listener.getsockname()[1]
    self.listener.listen(1)

  def tearDown(self):
    try:
      self.listener.shutdown(socket.SHUT_RDWR)
    finally:
      # Release the descriptor explicitly instead of relying on garbage
      # collection, even if shutdown(2) fails on this platform
      self.listener.close()
    del self.listener
    del self.listenerport

  def testTcpPingToLocalHostAccept(self):
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.listenerport,
                         timeout=10,
                         live_port_needed=True,
                         source=constants.LOCALHOST_IP_ADDRESS,
                         ),
                 "failed to connect to test listener")

    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.listenerport,
                         timeout=10,
                         live_port_needed=True,
                         ),
                 "failed to connect to test listener (no source)")
995
996
class TestTcpPingDeaf(unittest.TestCase):
  """Testcase for TCP version of ping - against non listen(2)ing port"""

  def setUp(self):
    # Bound but never listen(2)ing, so connection attempts are refused
    self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
    self.deaflistenerport = self.deaflistener.getsockname()[1]

  def tearDown(self):
    # Close explicitly rather than relying on garbage collection to release
    # the descriptor and the bound port
    self.deaflistener.close()
    del self.deaflistener
    del self.deaflistenerport

  def testTcpPingToLocalHostAcceptDeaf(self):
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                        self.deaflistenerport,
                        timeout=constants.TCP_PING_TIMEOUT,
                        live_port_needed=True,
                        source=constants.LOCALHOST_IP_ADDRESS,
                        ), # need successful connect(2)
                "successfully connected to deaf listener")

    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                        self.deaflistenerport,
                        timeout=constants.TCP_PING_TIMEOUT,
                        live_port_needed=True,
                        ), # need successful connect(2)
                "successfully connected to deaf listener (no source addr)")

  def testTcpPingToLocalHostNoAccept(self):
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         source=constants.LOCALHOST_IP_ADDRESS,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port")

    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port (no source addr)")
1040
1041
class TestOwnIpAddress(unittest.TestCase):
  """Tests for OwnIpAddress"""

  def testOwnLoopback(self):
    """the loopback address must be reported as owned"""
    loopback = constants.LOCALHOST_IP_ADDRESS
    self.failUnless(OwnIpAddress(loopback), "Should own the loopback address")

  def testNowOwnAddress(self):
    """an address from a documentation-only range must not be owned"""
    # 192.0.2.0/24 is reserved for test/documentation by RFC 3330, so no
    # host should carry an address from it; should this ever fail, the test
    # needs to try several addresses instead
    DST_IP = "192.0.2.1"
    owned = OwnIpAddress(DST_IP)
    self.failIf(owned, "Should not own IP address %s" % DST_IP)
1058
1059
def _GetSocketCredentials(path):
  """Connect to a Unix socket and return the remote credentials.

  @param path: filesystem path of the listening Unix socket
  @return: the result of L{utils.GetSocketCredentials} for the connection

  """
  conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  try:
    # Avoid hanging forever if nobody accepts the connection
    conn.settimeout(10)
    conn.connect(path)
    creds = utils.GetSocketCredentials(conn)
  finally:
    conn.close()
  return creds
1071
1072
class TestGetSocketCredentials(unittest.TestCase):
  """Checks utils.GetSocketCredentials using a forked client process.

  The parent listens on a Unix socket; a child connects to it and reports
  the credentials it sees for its peer (the parent) back through a pipe.

  """
  def setUp(self):
    # Listening Unix socket inside a private temporary directory
    self.tmpdir = tempfile.mkdtemp()
    self.sockpath = utils.PathJoin(self.tmpdir, "sock")

    self.listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    self.listener.settimeout(10)
    self.listener.bind(self.sockpath)
    self.listener.listen(1)

  def tearDown(self):
    self.listener.shutdown(socket.SHUT_RDWR)
    self.listener.close()
    shutil.rmtree(self.tmpdir)

  def test(self):
    # Pipe used by the child to send its JSON-encoded result to the parent
    (c2pr, c2pw) = os.pipe()

    # Start child process
    child = os.fork()
    if child == 0:
      # Child: connect to the listener, read the peer credentials and write
      # them into the pipe
      try:
        data = serializer.DumpJson(_GetSocketCredentials(self.sockpath))

        os.write(c2pw, data)
        os.close(c2pw)

        # os._exit does not run the finally clause, so exit code 0 is only
        # reached on success; any exception ends up in os._exit(1) below
        os._exit(0)
      finally:
        os._exit(1)

    # Parent: drop its copy of the write end so only the child holds it
    os.close(c2pw)

    # Wait for one connection
    (conn, _) = self.listener.accept()
    # The child sends no data, so this presumably returns once the child
    # closes its side of the connection
    conn.recv(1)
    conn.close()

    # Wait for result
    result = os.read(c2pr, 4096)
    os.close(c2pr)

    # Check child's exit code
    (_, status) = os.waitpid(child, 0)
    self.assertFalse(os.WIFSIGNALED(status))
    self.assertEqual(os.WEXITSTATUS(status), 0)

    # Check result: the child's peer is this (the parent) process
    (pid, uid, gid) = serializer.LoadJson(result)
    self.assertEqual(pid, os.getpid())
    self.assertEqual(uid, os.getuid())
    self.assertEqual(gid, os.getgid())
1125
1126
class TestListVisibleFiles(unittest.TestCase):
  """Tests for ListVisibleFiles"""

  def setUp(self):
    self.path = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.path)

  def _test(self, files, expected):
    """Create the given files and compare the listing, ignoring order.

    The caller's C{expected} list is not modified.

    """
    for name in files:
      handle = open(os.path.join(self.path, name), 'w')
      try:
        handle.write("Test\n")
      finally:
        handle.close()

    self.assertEqual(sorted(ListVisibleFiles(self.path)), sorted(expected))

  def testAllVisible(self):
    self._test(["a", "b", "c"], ["a", "b", "c"])

  def testNoneVisible(self):
    # Dot-files are hidden
    self._test([".a", ".b", ".c"], [])

  def testSomeVisible(self):
    self._test(["a", "b", ".c"], ["a", "b"])

  def testNonAbsolutePath(self):
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")

  def testNonNormalizedPath(self):
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
                          "/bin/../tmp")
1174
1175
class TestNewUUID(unittest.TestCase):
  """Tests for NewUUID"""

  # Canonical 8-4-4-4-12 form using lowercase hexadecimal digits
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
                        '[a-f0-9]{4}-[a-f0-9]{12}$')

  def runTest(self):
    value = utils.NewUUID()
    self.failUnless(self._re_uuid.match(value))
1184
1185
class TestUniqueSequence(unittest.TestCase):
  """Test case for UniqueSequence"""

  def _test(self, seq, expected):
    """Check that UniqueSequence(seq) returns exactly C{expected}.

    @param seq: input sequence, possibly containing duplicates
    @param expected: first occurrence of every item, in input order

    """
    # Parameter named "seq" to avoid shadowing the "input" builtin
    self.assertEqual(utils.UniqueSequence(seq), expected)

  def runTest(self):
    # Ordered input
    self._test([1, 2, 3], [1, 2, 3])
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
    self._test([1, 2, 2, 3], [1, 2, 3])
    self._test([1, 2, 3, 3], [1, 2, 3])

    # Unordered input
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])

    # Strings
    self._test(["a", "a"], ["a"])
    self._test(["a", "b"], ["a", "b"])
    self._test(["a", "b", "a"], ["a", "b"])
1207
1208
class TestFirstFree(unittest.TestCase):
  """Tests for FirstFree"""

  def test(self):
    """Test FirstFree"""
    check = self.failUnlessEqual
    check(FirstFree([0, 1, 3]), 2)
    check(FirstFree([]), None)
    check(FirstFree([3, 4, 6]), 0)
    check(FirstFree([3, 4, 6], base=3), 5)
    # A sequence containing a value below the base trips an assertion
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1219
1220
class TestTailFile(testutils.GanetiTestCase):
  """Test case for the TailFile function"""

  def _WriteTemp(self, content):
    """Create a temporary file holding C{content} and return its name.

    The file handle is closed even if writing fails.

    """
    fname = self._CreateTempFile()
    fd = open(fname, "w")
    try:
      fd.write(content)
    finally:
      fd.close()
    return fname

  def testEmpty(self):
    fname = self._CreateTempFile()
    self.failUnlessEqual(TailFile(fname), [])
    self.failUnlessEqual(TailFile(fname, lines=25), [])

  def testAllLines(self):
    data = ["test %d" % i for i in range(30)]
    for i in range(30):
      # Exactly i newline-terminated lines (an empty file for i == 0)
      content = "\n".join(data[:i])
      if i > 0:
        content += "\n"
      fname = self._WriteTemp(content)
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])

  def testPartialLines(self):
    data = ["test %d" % i for i in range(30)]
    fname = self._WriteTemp("\n".join(data) + "\n")
    for i in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])

  def testBigFile(self):
    data = ["test %d" % i for i in range(30)]
    # A 1 MiB line precedes the short lines that are expected back
    fname = self._WriteTemp("X" * 1048576 + "\n" + "\n".join(data) + "\n")
    for i in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1261
1262
class _BaseFileLockTest:
  """Test case for the FileLock class

  Mixin shared by the concrete FileLock test cases. Subclasses must set
  C{self.lock} (the lock under test) and C{self.tmpfile} (an object whose
  C{name} is the path of the locked file) in their C{setUp}.

  """

  def testSharedNonblocking(self):
    self.lock.Shared(blocking=False)
    self.lock.Close()

  def testExclusiveNonblocking(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Close()

  def testUnlockNonblocking(self):
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSharedBlocking(self):
    self.lock.Shared(blocking=True)
    self.lock.Close()

  def testExclusiveBlocking(self):
    self.lock.Exclusive(blocking=True)
    self.lock.Close()

  def testUnlockBlocking(self):
    self.lock.Unlock(blocking=True)
    self.lock.Close()

  def testSharedExclusiveUnlock(self):
    self.lock.Shared(blocking=False)
    self.lock.Exclusive(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testExclusiveSharedUnlock(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Shared(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSimpleTimeout(self):
    # These will succeed on the first attempt, hence a short timeout
    self.lock.Shared(blocking=True, timeout=10.0)
    self.lock.Exclusive(blocking=False, timeout=10.0)
    self.lock.Unlock(blocking=True, timeout=10.0)
    self.lock.Close()

  @staticmethod
  def _TryLockInner(filename, shared, blocking):
    """Try to lock C{filename} through a freshly opened FileLock.

    Meant to be run in a separate process (see L{_TryLock}).

    @param filename: path of the file to lock
    @param shared: request a shared lock if true, an exclusive one otherwise
    @param blocking: passed through to the lock acquisition call
    @return: True if the lock was acquired, False on L{errors.LockError}

    """
    lock = utils.FileLock.Open(filename)

    if shared:
      fn = lock.Shared
    else:
      fn = lock.Exclusive

    try:
      # The timeout doesn't really matter as the parent process waits for us to
      # finish anyway.
      fn(blocking=blocking, timeout=0.01)
    except errors.LockError:
      return False

    return True

  def _TryLock(self, *args):
    """Run L{_TryLockInner} on the test file in a child process."""
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
                                      *args)

  def testTimeout(self):
    for blocking in [True, False]:
      # While this process holds an exclusive lock, no other process may
      # get either kind of lock
      self.lock.Exclusive(blocking=True)
      self.failIf(self._TryLock(False, blocking))
      self.failIf(self._TryLock(True, blocking))

      # A shared lock admits other shared holders but no exclusive one
      self.lock.Shared(blocking=True)
      self.assert_(self._TryLock(True, blocking))
      self.failIf(self._TryLock(False, blocking))

  def testCloseShared(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)

  def testCloseExclusive(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)

  def testCloseUnlock(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1352
1353
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
  """Runs the FileLock tests on a lock created with FileLock.Open"""

  TESTDATA = "Hello World\n" * 10

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpfile = tempfile.NamedTemporaryFile()
    path = self.tmpfile.name
    utils.WriteFile(path, data=self.TESTDATA)
    self.lock = utils.FileLock.Open(path)

    # Opening the lock must not truncate the file
    self.assertFileContent(path, self.TESTDATA)

  def tearDown(self):
    # Locking and unlocking must not have touched the content either
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)

    testutils.GanetiTestCase.tearDown(self)
1371
1372
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
  """Runs the FileLock tests on a lock wrapping an already opened file"""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()
    path = self.tmpfile.name
    handle = open(path, "w")
    self.lock = utils.FileLock(handle, path)
1377
1378
class TestTimeFunctions(unittest.TestCase):
  """Tests for SplitTime and MergeTime"""

  def runTest(self):
    # SplitTime: float seconds -> (seconds, microseconds)
    split_cases = [
      (1, (1, 0)),
      (1.5, (1, 500000)),
      (1218448917.4809151, (1218448917, 480915)),
      (123.48012, (123, 480120)),
      (123.9996, (123, 999600)),
      (123.9995, (123, 999500)),
      (123.9994, (123, 999400)),
      (123.999999999, (123, 999999)),
      ]
    for (value, expected) in split_cases:
      self.assertEqual(utils.SplitTime(value), expected)

    self.assertRaises(AssertionError, utils.SplitTime, -1)

    # MergeTime: (seconds, microseconds) -> float seconds
    merge_cases = [
      ((1, 0), 1.0),
      ((1, 500000), 1.5),
      ((1218448917, 500000), 1218448917.5),
      ]
    for (value, expected) in merge_cases:
      self.assertEqual(utils.MergeTime(value), expected)

    # These are compared after rounding to three decimals, presumably
    # because the float results are not exact
    rounded_cases = [
      ((1218448917, 481000), 1218448917.481),
      ((1, 801000), 1.801),
      ]
    for (value, expected) in rounded_cases:
      self.assertEqual(round(utils.MergeTime(value), 3), expected)

    for invalid in [(0, -1), (0, 1000000), (0, 9999999), (-1, 0), (-9999, 0)]:
      self.assertRaises(AssertionError, utils.MergeTime, invalid)
1407
1408
class FieldSetTestCase(unittest.TestCase):
  """Tests for utils.FieldSet"""

  def testSimpleMatch(self):
    fset = utils.FieldSet("a", "b", "c", "def")
    self.failUnless(fset.Matches("a"))
    # Neither a part nor an extension of a known name may match
    self.failIf(fset.Matches("d"), "Substring matched")
    self.failIf(fset.Matches("defghi"), "Prefix string matched")
    # NonMatching is false only when every given name matches
    self.failIf(fset.NonMatching(["b", "c"]))
    self.failIf(fset.NonMatching(["a", "b", "c", "def"]))
    self.failUnless(fset.NonMatching(["a", "d"]))

  def testRegexMatch(self):
    # Items may also be regular expressions
    fset = utils.FieldSet("a", "b([0-9]+)", "c")
    for name in ("b1", "b99"):
      self.failUnless(fset.Matches(name))
    self.failIf(fset.Matches("b/1"))
    self.failIf(fset.NonMatching(["b12", "c"]))
    self.failUnless(fset.NonMatching(["a", "1"]))
1428
class TestForceDictType(unittest.TestCase):
  """Test case for ForceDictType"""

  def setUp(self):
    self.key_types = {
      'a': constants.VTYPE_INT,
      'b': constants.VTYPE_BOOL,
      'c': constants.VTYPE_STRING,
      'd': constants.VTYPE_SIZE,
      }

  def _fdt(self, data, allowed_values=None):
    """Run ForceDictType on C{data} and return the same dictionary.

    @param data: dictionary to coerce; ForceDictType converts it in place
    @param allowed_values: forwarded to ForceDictType only when not None

    """
    # Parameter named "data" to avoid shadowing the "dict" builtin
    if allowed_values is None:
      ForceDictType(data, self.key_types)
    else:
      ForceDictType(data, self.key_types, allowed_values=allowed_values)

    return data

  def testSimpleDict(self):
    self.assertEqual(self._fdt({}), {})
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})

  def testErrors(self):
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1467
1468
class TestIsAbsNormPath(unittest.TestCase):
  """Tests for IsNormAbsPath"""

  def _pathTestHelper(self, path, result):
    """Assert that IsNormAbsPath(path) is true exactly when result is."""
    if result:
      msg = "Path %s should result absolute and normalized" % path
    else:
      msg = "Path %s should not result absolute and normalized" % path
    self.assert_(bool(IsNormAbsPath(path)) == bool(result), msg)

  def testBase(self):
    cases = [
      ('/etc', True),
      ('/srv', True),
      ('etc', False),
      ('/etc/../root', False),
      ('/etc/', False),
      ]
    for (path, expected) in cases:
      self._pathTestHelper(path, expected)
1486
1487
class TestSafeEncode(unittest.TestCase):
  """Test case for SafeEncode"""

  def testAscii(self):
    # Printable ASCII must pass through unchanged. string.letters depends on
    # the current locale and may contain non-ASCII characters, so the
    # locale-independent string.ascii_letters is used instead
    for txt in [string.digits, string.ascii_letters, string.punctuation]:
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testDoubleEncode(self):
    # Encoding must be idempotent for every single byte value, 0x00-0xFF
    for i in range(256):
      txt = SafeEncode(chr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testUnicode(self):
    # 1024 is high enough to catch non-direct ASCII mappings
    for i in range(1024):
      txt = SafeEncode(unichr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))
1505
1506
class TestFormatTime(unittest.TestCase):
  """Tests for FormatTime"""

  def testNone(self):
    self.failUnlessEqual(FormatTime(None), "N/A")

  def testInvalid(self):
    # An empty tuple is not a usable timestamp
    self.failUnlessEqual(FormatTime(()), "N/A")

  def testNow(self):
    # Both float timestamps (as returned by time.time) and integer ones
    # must be accepted without raising
    for stamp in (time.time(), int(time.time())):
      FormatTime(stamp)
1521
1522
class RunInSeparateProcess(unittest.TestCase):
  """Tests for utils.RunInSeparateProcess"""

  def test(self):
    # The child's return value is handed back to the caller
    for expected in [True, False]:
      def _child(value=expected):
        return value

      self.assertEqual(expected, utils.RunInSeparateProcess(_child))

  def testArgs(self):
    # Extra positional arguments are passed on to the function
    for value in [0, 1, 999, "Hello World", (1, 2, 3)]:
      def _child(first, second, wanted=value):
        return first == "Foo" and second == wanted

      self.assert_(utils.RunInSeparateProcess(_child, "Foo", value))

  def testPid(self):
    caller_pid = os.getpid()

    def _samePid():
      return os.getpid() == caller_pid

    # The function must run in a process other than the caller
    self.failIf(utils.RunInSeparateProcess(_samePid))

  def testSignal(self):
    def _terminateSelf():
      os.kill(os.getpid(), signal.SIGTERM)

    # A child killed by a signal makes the call raise GenericError
    self.assertRaises(errors.GenericError,
                      utils.RunInSeparateProcess, _terminateSelf)

  def testException(self):
    def _fail():
      raise errors.GenericError("This is a test")

    # An exception in the child makes the call fail with GenericError
    self.assertRaises(errors.GenericError,
                      utils.RunInSeparateProcess, _fail)
1559
1560
class TestFingerprintFile(unittest.TestCase):
  """Tests for utils._FingerprintFile"""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()

  def test(self):
    path = self.tmpfile.name

    # The new file is empty; this value equals the SHA-1 digest of empty
    # input
    self.assertEqual(utils._FingerprintFile(path),
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")

    utils.WriteFile(path, data="Hello World\n")
    self.assertEqual(utils._FingerprintFile(path),
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1572
1573
class TestUnescapeAndSplit(unittest.TestCase):
  """Tests for UnescapeAndSplit"""

  def setUp(self):
    # Several separators, some of them regex metacharacters, to make sure
    # the separator is not interpreted as a regular expression
    self._seps = [",", "+", "."]

  def _Check(self, sep, fields, expected):
    """Join C{fields} with C{sep}, split it back and compare."""
    joined = sep.join(fields)
    self.failUnlessEqual(UnescapeAndSplit(joined, sep=sep), expected)

  def testSimple(self):
    fields = ["a", "b", "c", "d"]
    for sep in self._seps:
      self._Check(sep, fields, fields)

  def testEscape(self):
    # A backslash-escaped separator stays part of the field
    for sep in self._seps:
      self._Check(sep, ["a", "b\\" + sep + "c", "d"],
                  ["a", "b" + sep + "c", "d"])

  def testDoubleEscape(self):
    # An escaped backslash becomes a literal backslash
    for sep in self._seps:
      self._Check(sep, ["a", "b\\\\", "c", "d"], ["a", "b\\", "c", "d"])

  def testThreeEscape(self):
    for sep in self._seps:
      self._Check(sep, ["a", "b\\\\\\" + sep + "c", "d"],
                  ["a", "b\\" + sep + "c", "d"])
1603
1604
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
  """Tests for GenerateSelfSignedX509Cert and GenerateSelfSignedSslCert"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _checkRsaPrivateKey(self, key):
    """Return whether C{key} contains PEM RSA private key markers."""
    lines = key.splitlines()
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
            "-----END RSA PRIVATE KEY-----" in lines)

  def _checkCertificate(self, cert):
    """Return whether C{cert} contains PEM certificate markers."""
    lines = cert.splitlines()
    return ("-----BEGIN CERTIFICATE-----" in lines and
            "-----END CERTIFICATE-----" in lines)

  def test(self):
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
      # The helpers only return a boolean; it has to be asserted, otherwise
      # malformed PEM output would go unnoticed
      self.assert_(self._checkRsaPrivateKey(key_pem))
      self.assert_(self._checkCertificate(cert_pem))

      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
                                           key_pem)
      self.assert_(key.bits() >= 1024)
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)

      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                             cert_pem)
      self.failIf(x509.has_expired())
      # Self-signed: issuer and subject carry the same common name
      self.assertEqual(x509.get_issuer().CN, common_name)
      self.assertEqual(x509.get_subject().CN, common_name)
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)

  def testLegacy(self):
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")

    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)

    cert1 = utils.ReadFile(cert1_filename)

    # The legacy function writes key and certificate into the same file
    self.assert_(self._checkRsaPrivateKey(cert1))
    self.assert_(self._checkCertificate(cert1))
1650
1651
class TestPathJoin(unittest.TestCase):
  """Tests for PathJoin"""

  def testBasicItems(self):
    parts = ["/a", "b", "c"]
    self.failUnlessEqual(PathJoin(*parts), "/".join(parts))

  def testNonAbsPrefix(self):
    # The first component must be absolute
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")

  def testBackTrack(self):
    # A component containing ".." is rejected
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")

  def testMultiAbs(self):
    # Only the first component may be absolute
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1667
1668
class TestHostInfo(unittest.TestCase):
  """Testing case for HostInfo"""

  def testUppercase(self):
    data = "AbC.example.com"
    self.failUnlessEqual(HostInfo.NormalizeName(data), data.lower())

  def testTooLongName(self):
    data = "a.b." + "c" * 255
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, data)

  def testTrailingDot(self):
    data = "a.b.c"
    self.failUnlessEqual(HostInfo.NormalizeName(data + "."), data)

  def testInvalidName(self):
    data = [
      "a b",
      "a/b",
      ".a.b",
      "a..b",
      ]
    for value in data:
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, value)

  def testValidName(self):
    data = [
      "a.b",
      "a-b",
      "a_b",
      "a.b.c",
      ]
    for value in data:
      # Valid names that are already lowercase and have no trailing dot
      # must not only be accepted, but also come back unchanged
      self.failUnlessEqual(HostInfo.NormalizeName(value), value)
1703
1704
class TestParseAsn1Generalizedtime(unittest.TestCase):
  """Tests for utils._ParseAsn1Generalizedtime."""

  def test(self):
    parse_fn = utils._ParseAsn1Generalizedtime

    # (input, expected seconds since the epoch)
    valid = [
      # UTC
      ("19700101000000Z", 0),
      ("20100222174152Z", 1266860512),
      ("20380119031407Z", (2**31) - 1),

      # With offset
      ("20100222174152+0000", 1266860512),
      ("20100223131652+0000", 1266931012),
      ("20100223051808-0800", 1266931088),
      ("20100224002135+1100", 1266931295),
      ("19700101000000-0100", 3600),
      ]
    for (text, expected) in valid:
      self.assertEqual(parse_fn(text), expected)

    invalid = [
      # Leap seconds are not supported by datetime.datetime
      "19841231235960+0000",
      "19920630235960+0000",

      # Malformed input, including a timestamp without timezone designator
      "",
      "invalid",
      "20100222174152",
      "Mon Feb 22 17:47:02 UTC 2010",
      "2010-02-22 17:42:02",
      ]
    for text in invalid:
      self.assertRaises(ValueError, parse_fn, text)
1741
1742
class TestGetX509CertValidity(testutils.GanetiTestCase):
  """Tests for utils.GetX509CertValidity."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    # With pyOpenSSL older than 0.7 the validity period is expected to be
    # reported as (None, None), see test() below
    installed = distutils.version.LooseVersion(OpenSSL.__version__)
    self.pyopenssl0_7 = (installed >= "0.7")
    if self.pyopenssl0_7:
      return

    warnings.warn("This test requires pyOpenSSL 0.7 or above to"
                  " function correctly")

  def _LoadCert(self, name):
    """Load the PEM certificate stored in test data file C{name}."""
    pem = self._ReadTestData(name)
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem)

  def test(self):
    cert = self._LoadCert("cert1.pem")
    if self.pyopenssl0_7:
      expected = (1266919967, 1267524767)
    else:
      expected = (None, None)
    self.assertEqual(utils.GetX509CertValidity(cert), expected)
1766
1767
class TestSignX509Certificate(unittest.TestCase):
  """Tests for utils.SignX509Certificate and LoadSignedX509Certificate."""

  KEY = "My private key!"
  KEY_OTHER = "Another key"

  def test(self):
    # Generate certificate valid for 5 minutes
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)

    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                           cert_pem)

    # No signature at all
    self.assertRaises(errors.GenericError,
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)

    # Invalid input
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: \n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)

    # Invalid salt. The certificate object (not its PEM string) is passed,
    # just like in the valid-salt loop below, so that the salt is the only
    # possible reason for the error.
    for salt in "-_@$,:;/\\ \t\n":
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
                        cert, self.KEY, "foo%sbar" % salt)

    # string.ascii_letters instead of string.letters: the latter depends on
    # the current locale and may contain non-ASCII characters
    for salt in ["HelloWorld", "salt", string.ascii_letters, string.digits,
                 utils.GenerateSecret(numbytes=4),
                 utils.GenerateSecret(numbytes=16),
                 "{123:456}".encode("hex")]:
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)

      self._Check(cert, salt, signed_pem)

      # Extra text before or after the signed certificate is tolerated
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
                               "lines----\n------ at\nthe end!"))

  def _Check(self, cert, salt, pem):
    """Check that C{pem} loads back as C{cert}/C{salt} with KEY only.

    Loading with L{KEY} must return the same salt and a certificate with
    the same SHA1 digest; loading with L{KEY_OTHER} must fail.

    """
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
    self.assertEqual(salt, salt2)
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))

    # Other key
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      pem, self.KEY_OTHER)
1821
1822
class TestMakedirs(unittest.TestCase):
  """Tests for utils.Makedirs."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _MakeAndCheck(self, target):
    """Run Makedirs on C{target} and verify a directory exists there."""
    utils.Makedirs(target)
    self.assert_(os.path.isdir(target))

  def testNonExisting(self):
    self._MakeAndCheck(utils.PathJoin(self.tmpdir, "foo"))

  def testExisting(self):
    # An already existing directory is accepted
    target = utils.PathJoin(self.tmpdir, "foo")
    os.mkdir(target)
    self._MakeAndCheck(target)

  def testRecursiveNonExisting(self):
    # All missing parent directories are created as well
    self._MakeAndCheck(utils.PathJoin(self.tmpdir, "foo/bar/baz"))

  def testRecursiveExisting(self):
    # Only part of the hierarchy exists beforehand
    target = utils.PathJoin(self.tmpdir, "B/moo/xyz")
    self.failIf(os.path.exists(target))
    os.mkdir(utils.PathJoin(self.tmpdir, "B"))
    self._MakeAndCheck(target)
1852
1853
class TestRetry(testutils.GanetiTestCase):
  """Tests for utils.Retry."""

  @staticmethod
  def _RaiseRetryAgain():
    """Retry callback which asks for another attempt every time."""
    raise utils.RetryAgain()

  def _WrongNestedLoop(self):
    """Run an inner Retry loop which can only time out."""
    return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)

  def testRaiseTimeout(self):
    self.assertRaises(utils.RetryTimeout, utils.Retry,
                      self._RaiseRetryAgain, 0.01, 0.02)

  def testComplete(self):
    # The value returned by a successful callback is passed through
    outcome = utils.Retry(lambda: True, 0, 1)
    self.assertEqual(outcome, True)

  def testNestedLoop(self):
    # The inner loop's RetryTimeout must not escape from the outer loop;
    # it is expected to be reported as a ProgrammerError instead
    try:
      self.assertRaises(errors.ProgrammerError, utils.Retry,
                        self._WrongNestedLoop, 0, 1)
    except utils.RetryTimeout:
      self.fail("Didn't detect inner loop's exception")
1875
1876
class TestLineSplitter(unittest.TestCase):
  """Tests for utils.LineSplitter."""

  def test(self):
    seen = []
    splitter = utils.LineSplitter(seen.append)

    # Writing alone does not hand any line to the callback
    splitter.write("Hello World\n")
    self.assertEqual(seen, [])
    for chunk in ["Foo\n Bar\r\n ", "Baz", "Moo"]:
      splitter.write(chunk)
    self.assertEqual(seen, [])

    # flush() delivers complete lines (line endings removed); the trailing
    # partial line is only delivered by close()
    splitter.flush()
    self.assertEqual(seen, ["Hello World", "Foo", " Bar"])
    splitter.close()
    self.assertEqual(seen, ["Hello World", "Foo", " Bar", " BazMoo"])

  def _testExtra(self, line, all_lines, p1, p2):
    """Callback checking that the extra arguments are passed along."""
    self.assertEqual(p1, 999)
    self.assertEqual(p2, "extra")
    all_lines.append(line)

  def testExtraArgsNoFlush(self):
    received = []
    splitter = utils.LineSplitter(self._testExtra, received, 999, "extra")
    for chunk in ["\n\nHello World\n", "Foo\n Bar\r\n ", "", "Baz",
                  "Moo\n\nx\n"]:
      splitter.write(chunk)
    self.assertEqual(received, [])

    # Without an explicit flush(), close() delivers everything, including
    # empty lines
    splitter.close()
    self.assertEqual(received, ["", "", "Hello World", "Foo", " Bar",
                                " BazMoo", "", "x"])
1909
1910
class TestReadLockedPidFile(unittest.TestCase):
  """Tests for utils.ReadLockedPidFile."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testNonExistent(self):
    missing = utils.PathJoin(self.tmpdir, "nonexist")
    self.assert_(utils.ReadLockedPidFile(missing) is None)

  def testUnlocked(self):
    # A PID file without a lock on it yields None
    pidfile = utils.PathJoin(self.tmpdir, "pid")
    utils.WriteFile(pidfile, data="123")
    self.assert_(utils.ReadLockedPidFile(pidfile) is None)

  def testLocked(self):
    pidfile = utils.PathJoin(self.tmpdir, "pid")
    utils.WriteFile(pidfile, data="123")

    lock = utils.FileLock.Open(pidfile)
    try:
      lock.Exclusive(blocking=True)
      # While the lock is held, the PID stored in the file is returned
      self.assertEqual(utils.ReadLockedPidFile(pidfile), 123)
    finally:
      lock.Close()

    # Once the lock is released, the file is treated as unlocked again
    self.assert_(utils.ReadLockedPidFile(pidfile) is None)

  def testError(self):
    pidfile = utils.PathJoin(self.tmpdir, "foobar", "pid")
    # "foobar" is created as a regular file, so it cannot act as directory
    utils.WriteFile(utils.PathJoin(self.tmpdir, "foobar"), data="")
    # open(2) should return ENOTDIR
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, pidfile)
1946
1947
if __name__ == '__main__':
  # Script entry point: hand control to testutils.GanetiTestProgram
  # (presumably a unittest.main()-style runner; see testutils)
  testutils.GanetiTestProgram()