Merge branch 'devel-2.1'
[ganeti-local] / test / ganeti.utils_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the utils module"""
23
24 import unittest
25 import os
26 import time
27 import tempfile
28 import os.path
29 import os
30 import stat
31 import signal
32 import socket
33 import shutil
34 import re
35 import select
36 import string
37 import fcntl
38 import OpenSSL
39 import warnings
40 import distutils.version
41 import glob
42 import errno
43
44 import ganeti
45 import testutils
46 from ganeti import constants
47 from ganeti import compat
48 from ganeti import utils
49 from ganeti import errors
50 from ganeti import serializer
51 from ganeti.utils import IsProcessAlive, RunCmd, \
52      RemoveFile, MatchNameComponent, FormatUnit, \
53      ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
54      ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
55      SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
56      TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
57      UnescapeAndSplit, RunParts, PathJoin, HostInfo, ReadOneLineFile
58
59 from ganeti.errors import LockError, UnitParseError, GenericError, \
60      ProgrammerError, OpPrereqError
61
62
class TestIsProcessAlive(unittest.TestCase):
  """Testing case for IsProcessAlive"""

  def testExists(self):
    """The current process must be reported as alive"""
    mypid = os.getpid()
    self.assertTrue(IsProcessAlive(mypid),
                    "can't find myself running")

  def testNotExisting(self):
    """An exited and reaped child must not be reported as alive"""
    # os.fork() raises OSError on failure rather than returning a negative
    # value, so only the child/parent distinction needs handling here
    pid_non_existing = os.fork()
    if pid_non_existing == 0:
      os._exit(0)
    # Reap the child so its PID no longer refers to any process
    os.waitpid(pid_non_existing, 0)
    self.assertFalse(IsProcessAlive(pid_non_existing),
                     "nonexisting process detected")


class TestPidFileFunctions(unittest.TestCase):
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""

  def setUp(self):
    self.dir = tempfile.mkdtemp()
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
    # Redirect PID files into the temporary directory; keep the original
    # function so tearDown can undo the monkey-patch for later tests
    self._orig_dpn = utils.DaemonPidFileName
    utils.DaemonPidFileName = self.f_dpn

  def tearDown(self):
    utils.DaemonPidFileName = self._orig_dpn
    shutil.rmtree(self.dir)

  def testPidFileFunctions(self):
    pid_file = self.f_dpn('test')
    utils.WritePidFile('test')
    self.assertTrue(os.path.exists(pid_file),
                    "PID file should have been created")
    read_pid = utils.ReadPidFile(pid_file)
    self.assertEqual(read_pid, os.getpid())
    self.assertTrue(utils.IsProcessAlive(read_pid))
    # Writing the same PID file a second time must raise GenericError
    self.assertRaises(GenericError, utils.WritePidFile, 'test')
    utils.RemovePidFile('test')
    self.assertFalse(os.path.exists(pid_file),
                     "PID file should not exist anymore")
    self.assertEqual(utils.ReadPidFile(pid_file), 0,
                     "ReadPidFile should return 0 for missing pid file")
    utils.WriteFile(pid_file, data="blah\n")
    self.assertEqual(utils.ReadPidFile(pid_file), 0,
                     "ReadPidFile should return 0 for invalid pid file")
    utils.RemovePidFile('test')
    self.assertFalse(os.path.exists(pid_file),
                     "PID file should not exist anymore")

  def testKill(self):
    pid_file = self.f_dpn('child')
    r_fd, w_fd = os.pipe()
    new_pid = os.fork()
    if new_pid == 0: #child
      # The child must never return into the test runner, even if writing
      # the PID file fails
      try:
        os.close(r_fd)
        utils.WritePidFile('child')
        os.write(w_fd, 'a')
        signal.pause()
      finally:
        os._exit(0)
    # else we are in the parent; closing our copy of the write end makes the
    # read below return (empty) instead of blocking forever if the child
    # exits before signalling
    os.close(w_fd)
    try:
      try:
        # wait until the child has written the pid file
        os.read(r_fd, 1)
      finally:
        os.close(r_fd)
      read_pid = utils.ReadPidFile(pid_file)
      self.assertEqual(read_pid, new_pid)
      self.assertTrue(utils.IsProcessAlive(new_pid))
    finally:
      # Always terminate the child so a failed assertion above does not leave
      # it paused forever
      utils.KillProcess(new_pid, waitpid=True)
    self.assertFalse(utils.IsProcessAlive(new_pid))
    utils.RemovePidFile('child')
    self.assertRaises(ProgrammerError, utils.KillProcess, 0)


class TestRunCmd(testutils.GanetiTestCase):
  """Testing case for the RunCmd function"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.magic = time.ctime() + " ganeti test"
    self.fname = self._CreateTempFile()

  def testOk(self):
    """Test successful exit code"""
    result = RunCmd("/bin/sh -c 'exit 0'")
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.output, "")

  def testFail(self):
    """Test fail exit code"""
    result = RunCmd("/bin/sh -c 'exit 1'")
    self.assertEqual(result.exit_code, 1)
    self.assertEqual(result.output, "")

  def testStdout(self):
    """Test standard output"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.stdout, self.magic)
    # With output= the data goes to the file and result.output stays empty
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testStderr(self):
    """Test standard error"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
    self.assertEqual(result.stderr, self.magic)
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testCombined(self):
    """Test combined output"""
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
    expected = "A" + self.magic + "B" + self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.output, expected)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, expected)

  def testSignal(self):
    """Test signal"""
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
    self.assertEqual(result.signal, 15)
    self.assertEqual(result.output, "")

  def testListRun(self):
    """Test list runs"""
    result = RunCmd(["true"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 1)
    result = RunCmd(["echo", "-n", self.magic])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.stdout, self.magic)

  def testFileEmptyOutput(self):
    """Test file output"""
    result = RunCmd(["true"], output=self.fname)
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertFileContent(self.fname, "")

  def testLang(self):
    """Test locale environment"""
    old_env = os.environ.copy()
    try:
      os.environ["LANG"] = "en_US.UTF-8"
      os.environ["LC_ALL"] = "en_US.UTF-8"
      result = RunCmd(["locale"])
      for line in result.output.splitlines():
        key, value = line.split("=", 1)
        # Ignore these variables, they're overridden by LC_ALL
        if key == "LANG" or key == "LANGUAGE":
          continue
        self.assertFalse(value and value != "C" and value != '"C"',
            "Variable %s is set to the invalid value '%s'" % (key, value))
    finally:
      # Restore through os.environ's item protocol so the change reaches the
      # real process environment; rebinding os.environ to a plain dict copy
      # would leave LANG/LC_ALL modified for child processes that inherit
      # the process environment
      for name in ("LANG", "LC_ALL"):
        if name in old_env:
          os.environ[name] = old_env[name]
        elif name in os.environ:
          del os.environ[name]

  def testDefaultCwd(self):
    """Test default working directory"""
    self.assertEqual(RunCmd(["pwd"]).stdout.strip(), "/")

  def testCwd(self):
    """Test explicitly specified working directory"""
    self.assertEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
    self.assertEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
    cwd = os.getcwd()
    self.assertEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)

  def testResetEnv(self):
    """Test environment reset functionality"""
    self.assertEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
    self.assertEqual(RunCmd(["env"], reset_env=True,
                            env={"FOO": "bar",}).stdout.strip(), "FOO=bar")


class TestRunParts(unittest.TestCase):
  """Testing case for the RunParts function"""

  def setUp(self):
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")

  def tearDown(self):
    shutil.rmtree(self.rundir)

  def _WriteEntry(self, name, data, executable):
    """Create a file inside the run directory and return its full path.

    @param name: file name relative to the run directory
    @param data: file content
    @param executable: whether to chmod the file to S_IREAD | S_IEXEC

    """
    path = os.path.join(self.rundir, name)
    utils.WriteFile(path, data=data)
    if executable:
      os.chmod(path, stat.S_IREAD | stat.S_IEXEC)
    return path

  def testEmpty(self):
    """Test on an empty dir"""
    self.assertEqual(RunParts(self.rundir, reset_env=True), [])

  def testSkipWrongName(self):
    """Test that wrong files are skipped"""
    path = self._WriteEntry("00test.dot", "", True)
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.assertEqual(RunParts(self.rundir, reset_env=True), expected)

  def testSkipNonExec(self):
    """Test that non executable files are skipped"""
    path = self._WriteEntry("00test", "", False)
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.assertEqual(RunParts(self.rundir, reset_env=True), expected)

  def testError(self):
    """Test error on a broken executable"""
    path = self._WriteEntry("00test", "", True)
    relname, status, error = RunParts(self.rundir, reset_env=True)[0]
    self.assertEqual(relname, os.path.basename(path))
    self.assertEqual(status, constants.RUNPARTS_ERR)
    self.assertTrue(error)

  def testSorted(self):
    """Test executions are sorted"""
    names = ["64test", "00test", "42test"]
    for name in names:
      self._WriteEntry(name, "", False)

    results = RunParts(self.rundir, reset_env=True)

    for name in sorted(names):
      self.assertEqual(name, results.pop(0)[0])

  def testOk(self):
    """Test correct execution"""
    path = self._WriteEntry("00test", "#!/bin/sh\n\necho -n ciao", True)
    relname, status, runresult = RunParts(self.rundir, reset_env=True)[0]
    self.assertEqual(relname, os.path.basename(path))
    self.assertEqual(status, constants.RUNPARTS_RUN)
    self.assertEqual(runresult.stdout, "ciao")

  def testRunFail(self):
    """Test correct execution, with run failure"""
    path = self._WriteEntry("00test", "#!/bin/sh\n\nexit 1", True)
    relname, status, runresult = RunParts(self.rundir, reset_env=True)[0]
    self.assertEqual(relname, os.path.basename(path))
    self.assertEqual(status, constants.RUNPARTS_RUN)
    self.assertEqual(runresult.exit_code, 1)
    self.assertTrue(runresult.failed)

  def testRunMix(self):
    """Test a mix of failing, skipped, broken and working entries"""
    # 1st has errors in execution
    failing = self._WriteEntry("00test", "#!/bin/sh\n\nexit 1", True)
    # 2nd is skipped (not executable)
    skipped = self._WriteEntry("42test", "", False)
    # 3rd cannot execute properly (empty file)
    broken = self._WriteEntry("64test", "", True)
    # 4th execs
    working = self._WriteEntry("99test", "#!/bin/sh\n\necho -n ciao", True)

    results = RunParts(self.rundir, reset_env=True)

    relname, status, runresult = results[0]
    self.assertEqual(relname, os.path.basename(failing))
    self.assertEqual(status, constants.RUNPARTS_RUN)
    self.assertEqual(runresult.exit_code, 1)
    self.assertTrue(runresult.failed)

    relname, status, runresult = results[1]
    self.assertEqual(relname, os.path.basename(skipped))
    self.assertEqual(status, constants.RUNPARTS_SKIP)
    self.assertEqual(runresult, None)

    relname, status, runresult = results[2]
    self.assertEqual(relname, os.path.basename(broken))
    self.assertEqual(status, constants.RUNPARTS_ERR)
    self.assertTrue(runresult)

    relname, status, runresult = results[3]
    self.assertEqual(relname, os.path.basename(working))
    self.assertEqual(status, constants.RUNPARTS_RUN)
    self.assertEqual(runresult.output, "ciao")
    self.assertEqual(runresult.exit_code, 0)
    self.assertFalse(runresult.failed)


class TestStartDaemon(testutils.GanetiTestCase):
  """Tests for the StartDaemon function"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
    self.tmpfile = os.path.join(self.tmpdir, "test")

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def testShell(self):
    """Shell command redirecting its own output"""
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testShellOutput(self):
    """Shell command with output file given to StartDaemon"""
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testNoShellNoOutput(self):
    """Argument list without output; only checks that no error is raised"""
    utils.StartDaemon(["pwd"])

  def testNoShellNoOutputTouch(self):
    """Argument list without output still runs the command"""
    testfile = os.path.join(self.tmpdir, "check")
    self.assertFalse(os.path.exists(testfile))
    utils.StartDaemon(["touch", testfile])
    self._wait(testfile, 60.0, "")

  def testNoShellOutput(self):
    """Argument list with output file; default working directory is /"""
    utils.StartDaemon(["pwd"], output=self.tmpfile)
    self._wait(self.tmpfile, 60.0, "/")

  def testNoShellOutputCwd(self):
    """Argument list with output file and explicit working directory"""
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
    self._wait(self.tmpfile, 60.0, os.getcwd())

  def testShellEnv(self):
    """Shell command sees variables passed via env"""
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
                      env={ "GNT_TEST_VAR": "Hello World", })
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testNoShellEnv(self):
    """Argument list sees variables passed via env"""
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
                      env={ "GNT_TEST_VAR": "Hello World", })
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testOutputFd(self):
    """Output written to an already open file descriptor"""
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
    try:
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
    finally:
      os.close(fd)
    self._wait(self.tmpfile, 60.0, os.getcwd())

  def testPid(self):
    """The returned PID is the one of the started shell"""
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
    self._wait(self.tmpfile, 60.0, str(pid))

  def testPidFile(self):
    """The PID file is locked and contains the daemon's PID"""
    pidfile = os.path.join(self.tmpdir, "pid")

    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
                            output=self.tmpfile)
    try:
      fd = os.open(pidfile, os.O_RDONLY)
      try:
        # Check file is locked
        self.assertRaises(errors.LockError, utils.LockFile, fd)

        pidtext = os.read(fd, 100)
      finally:
        os.close(fd)

      self.assertEqual(int(pidtext.strip()), pid)

      self.assertTrue(utils.IsProcessAlive(pid))
    finally:
      # No matter what happens, kill daemon
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
      self.assertFalse(utils.IsProcessAlive(pid))

    self.assertEqual(utils.ReadFile(self.tmpfile), "")

  def _wait(self, path, timeout, expected):
    # Due to the asynchronous nature of daemon processes, polling is necessary.
    # A timeout makes sure the test doesn't hang forever.
    def _CheckFile():
      if not (os.path.isfile(path) and
              utils.ReadFile(path).strip() == expected):
        raise utils.RetryAgain()

    try:
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
    except utils.RetryTimeout:
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
                " didn't write the correct output" % timeout)

  def testError(self):
    """Missing programs, output paths or working directories raise errors"""
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"])
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"],
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"],
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))

    # Passing both output and output_fd is a programming error
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
    try:
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
                        ["./does-NOT-EXIST/here/0123456789"],
                        output=self.tmpfile, output_fd=fd)
    finally:
      os.close(fd)


class TestSetCloseOnExecFlag(unittest.TestCase):
  """Tests for SetCloseOnExecFlag"""

  def setUp(self):
    self.tmpfile = tempfile.TemporaryFile()

  def tearDown(self):
    self.tmpfile.close()

  def _GetFlag(self):
    """Return whether FD_CLOEXEC is currently set on the test file"""
    return bool(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
                fcntl.FD_CLOEXEC)

  def _SetFlag(self, enable):
    """Set or clear FD_CLOEXEC directly, bypassing the function under test"""
    fd = self.tmpfile.fileno()
    flags = fcntl.fcntl(fd, fcntl.F_GETFD)
    if enable:
      flags |= fcntl.FD_CLOEXEC
    else:
      flags &= ~fcntl.FD_CLOEXEC
    fcntl.fcntl(fd, fcntl.F_SETFD, flags)

  def testEnable(self):
    # Descriptors created by tempfile may already be close-on-exec; clear the
    # flag first so the test fails if SetCloseOnExecFlag does nothing
    self._SetFlag(False)
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
    self.assertTrue(self._GetFlag())

  def testDisable(self):
    # Set the flag first so the test fails if SetCloseOnExecFlag does nothing
    self._SetFlag(True)
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
    self.assertFalse(self._GetFlag())


class TestSetNonblockFlag(unittest.TestCase):
  """Tests for SetNonblockFlag"""

  def setUp(self):
    self.tmpfile = tempfile.TemporaryFile()

  def tearDown(self):
    self.tmpfile.close()

  def _GetFlag(self):
    """Return whether O_NONBLOCK is currently set on the test file"""
    return bool(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
                os.O_NONBLOCK)

  def _SetFlag(self, enable):
    """Set or clear O_NONBLOCK directly, bypassing the function under test"""
    fd = self.tmpfile.fileno()
    flags = fcntl.fcntl(fd, fcntl.F_GETFL)
    if enable:
      flags |= os.O_NONBLOCK
    else:
      flags &= ~os.O_NONBLOCK
    fcntl.fcntl(fd, fcntl.F_SETFL, flags)

  def testEnable(self):
    # Start from a cleared flag so the test fails if SetNonblockFlag does
    # nothing
    self._SetFlag(False)
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
    self.assertTrue(self._GetFlag())

  def testDisable(self):
    # A fresh file is not non-blocking, so set the flag first; otherwise the
    # test would pass even if SetNonblockFlag did nothing
    self._SetFlag(True)
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
    self.assertFalse(self._GetFlag())


class TestRemoveFile(unittest.TestCase):
  """Test case for the RemoveFile function"""

  def setUp(self):
    """Create a temp dir and file for each case"""
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
    os.close(fd)

  def tearDown(self):
    # rmtree also copes with leftovers from a failed test (e.g. a symlink that
    # was not removed), where os.rmdir would fail and hide the real error
    shutil.rmtree(self.tmpdir)

  def testIgnoreDirs(self):
    """Test that RemoveFile() ignores directories"""
    self.assertEqual(None, RemoveFile(self.tmpdir))
    self.assertTrue(os.path.isdir(self.tmpdir))

  def testIgnoreNotExisting(self):
    """Test that RemoveFile() ignores non-existing files"""
    RemoveFile(self.tmpfile)
    RemoveFile(self.tmpfile)

  def testRemoveFile(self):
    """Test that RemoveFile does remove a file"""
    RemoveFile(self.tmpfile)
    if os.path.exists(self.tmpfile):
      self.fail("File '%s' not removed" % self.tmpfile)

  def testRemoveSymlink(self):
    """Test that RemoveFile removes symlinks but not their targets"""
    symlink = os.path.join(self.tmpdir, "symlink")
    # os.path.exists() follows symlinks and reports False for a dangling link
    # even when the link itself is still present, so lexists() is needed to
    # verify that the link was really removed
    os.symlink("no-such-file", symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
    os.symlink(self.tmpfile, symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
    self.assertTrue(os.path.exists(self.tmpfile),
                    "Symlink target '%s' was removed" % self.tmpfile)


class TestRename(unittest.TestCase):
  """Test case for RenameFile"""

  def setUp(self):
    """Create a temporary directory holding one empty file"""
    self.tmpdir = tempfile.mkdtemp()
    self.tmpfile = os.path.join(self.tmpdir, "test1")

    # Create the (empty) source file
    handle = open(self.tmpfile, "w")
    handle.close()

  def tearDown(self):
    """Remove temporary directory"""
    shutil.rmtree(self.tmpdir)

  def _Path(self, relpath):
    """Return relpath joined to the temporary directory"""
    return os.path.join(self.tmpdir, relpath)

  def testSimpleRename1(self):
    """Simple rename 1"""
    target = self._Path("xyz")
    utils.RenameFile(self.tmpfile, target)
    self.assertTrue(os.path.isfile(target))

  def testSimpleRename2(self):
    """Simple rename 2"""
    target = self._Path("xyz")
    utils.RenameFile(self.tmpfile, target, mkdir=True)
    self.assertTrue(os.path.isfile(target))

  def testRenameMkdir(self):
    """Rename with mkdir"""
    first = self._Path("test/xyz")
    utils.RenameFile(self.tmpfile, first, mkdir=True)
    self.assertTrue(os.path.isdir(self._Path("test")))
    self.assertTrue(os.path.isfile(first))

    second = self._Path("test/foo/bar/baz")
    utils.RenameFile(first, second, mkdir=True)
    self.assertTrue(os.path.isdir(self._Path("test")))
    self.assertTrue(os.path.isdir(self._Path("test/foo/bar")))
    self.assertTrue(os.path.isfile(second))


class TestMatchNameComponent(unittest.TestCase):
  """Test case for the MatchNameComponent function"""

  def testEmptyList(self):
    """Test that there is no match against an empty list"""
    for key in ("", "test"):
      self.assertEqual(MatchNameComponent(key, []), None)

  def testSingleMatch(self):
    """Test that a single match is performed correctly"""
    names = ["test1.example.com", "test2.example.com", "test3.example.com"]
    for key in ("test2", "test2.example", "test2.example.com"):
      self.assertEqual(MatchNameComponent(key, names), names[1])

  def testMultipleMatches(self):
    """Test that a multiple match is returned as None"""
    names = ["test1.example.com", "test1.example.org", "test1.example.net"]
    for key in ("test1", "test1.example"):
      self.assertEqual(MatchNameComponent(key, names), None)

  def testFullMatch(self):
    """Test that a full match is returned correctly"""
    short_key = "test1"
    long_key = "test1.example"
    names = [long_key, long_key + ".com"]
    self.assertEqual(MatchNameComponent(short_key, names), None)
    self.assertEqual(MatchNameComponent(long_key, names), long_key)

  def testCaseInsensitivePartialMatch(self):
    """Test for the case_insensitive keyword"""
    names = ["test1.example.com", "test2.example.net"]
    for key in ("test2", "Test2", "teSt2", "TeSt2"):
      self.assertEqual(MatchNameComponent(key, names, case_sensitive=False),
                       "test2.example.net")

  def testCaseInsensitiveFullMatch(self):
    """Test case-insensitive matching when full names are present"""
    names = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
    cases = [
      # "Ts1" matches both ts1 entries case-insensitively (ambiguous), while
      # the full name "ts1.ex" picks one of them regardless of case
      ("Ts1", None),
      ("Ts1.ex", "ts1.ex"),
      ("ts1.ex", "ts1.ex"),
      # The two ts2 entries differ only in case, so only an exact-case match
      # is unambiguous
      ("ts2.ex", "ts2.ex"),
      ("Ts2.ex", "Ts2.ex"),
      ("TS2.ex", None),
      ]
    for key, expected in cases:
      self.assertEqual(MatchNameComponent(key, names, case_sensitive=False),
                       expected)


class TestReadFile(testutils.GanetiTestCase):
  """Tests for ReadFile"""

  def _CheckDigest(self, data, expected):
    """Assert that the MD5 hex digest of data equals expected"""
    digest = compat.md5_hash()
    digest.update(data)
    self.assertEqual(digest.hexdigest(), expected)

  def testReadAll(self):
    """Reading without a size limit returns the whole file"""
    content = utils.ReadFile(self._TestDataFilename("cert1.pem"))
    self.assertEqual(len(content), 814)
    self._CheckDigest(content, "a491efb3efe56a0535f924d5f8680fd4")

  def testReadSize(self):
    """Reading with size=100 returns only the first 100 bytes"""
    content = utils.ReadFile(self._TestDataFilename("cert1.pem"), size=100)
    self.assertEqual(len(content), 100)
    self._CheckDigest(content, "893772354e4e690b9efd073eed433ce7")

  def testError(self):
    """Reading an unreachable path raises EnvironmentError"""
    self.assertRaises(EnvironmentError, utils.ReadFile,
                      "/dev/null/does-not-exist")


class TestReadOneLineFile(testutils.GanetiTestCase):
  """Tests for ReadOneLineFile"""

  def testDefault(self):
    """By default the first line of the file is returned"""
    line = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
    self.assertEqual(len(line), 27)
    self.assertEqual(line, "-----BEGIN CERTIFICATE-----")

  def testNotStrict(self):
    """Non-strict mode returns the first line of a multi-line file"""
    line = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
    self.assertEqual(len(line), 27)
    self.assertEqual(line, "-----BEGIN CERTIFICATE-----")

  def testStrictFailure(self):
    """Strict mode rejects a multi-line file"""
    self.assertRaises(errors.GenericError, ReadOneLineFile,
                      self._TestDataFilename("cert1.pem"), strict=True)

  def testLongLine(self):
    """A single long line is returned unchanged in both modes"""
    content = 1024 * "Hello World! "
    path = self._CreateTempFile()
    utils.WriteFile(path, data=content)
    strict_result = ReadOneLineFile(path, strict=True)
    lax_result = ReadOneLineFile(path, strict=False)
    self.assertEqual(content, strict_result)
    self.assertEqual(content, lax_result)

  def testNewline(self):
    """A trailing line terminator is stripped in both modes"""
    path = self._CreateTempFile()
    text = "myline"
    for terminator in ("", "\n", "\r\n"):
      utils.WriteFile(path, data=text + terminator)
      self.assertEqual(text, ReadOneLineFile(path, strict=False))
      self.assertEqual(text, ReadOneLineFile(path, strict=True))

  def testWhitespaceAndMultipleLines(self):
    """Whitespace is preserved; only the first of many lines is returned"""
    path = self._CreateTempFile()
    for terminator in ("", "\n", "\r\n"):
      for space in (" ", "\t", "\t\t  \t", "\t "):
        content = 1024 * ("Foo bar baz %s%s" % (space, terminator))
        utils.WriteFile(path, data=content)
        lax_result = ReadOneLineFile(path, strict=False)
        if not terminator:
          # Without terminators the whole content is a single line
          self.assertEqual(content, ReadOneLineFile(path, strict=True))
          self.assertEqual(content, lax_result)
          continue
        self.assertTrue(set("\r\n") & set(content))
        self.assertRaises(errors.GenericError, ReadOneLineFile,
                          path, strict=True)
        expected_len = len("Foo bar baz ") + len(space)
        self.assertEqual(len(lax_result), expected_len)
        self.assertEqual(lax_result, content[:expected_len])
        self.assertFalse(set("\r\n") & set(lax_result))

  def testEmptylines(self):
    """Leading empty lines are skipped; strict mode rejects extra lines"""
    path = self._CreateTempFile()
    text = "myline"
    for terminator in ("\n", "\r\n"):
      for other in ("", "otherline"):
        content = "".join([terminator, terminator, text, terminator,
                           other, terminator])
        utils.WriteFile(path, data=content)
        self.assertTrue(set("\r\n") & set(content))
        self.assertEqual(text, ReadOneLineFile(path, strict=False))
        if other:
          self.assertRaises(errors.GenericError, ReadOneLineFile,
                            path, strict=True)
        else:
          self.assertEqual(text, ReadOneLineFile(path, strict=True))


class TestTimestampForFilename(unittest.TestCase):
  """Tests for TimestampForFilename"""

  def test(self):
    """The generated timestamp contains neither "." nor ":" characters"""
    for forbidden in (".", ":"):
      self.assertFalse(forbidden in utils.TimestampForFilename())


class TestCreateBackup(testutils.GanetiTestCase):
  """Tests for CreateBackup"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def _CountMatching(self, filename):
    """Return how many paths match filename followed by anything"""
    return len(glob.glob("%s*" % filename))

  def testEmpty(self):
    """Back up an empty file repeatedly; a FIFO is rejected"""
    filename = utils.PathJoin(self.tmpdir, "config.data")
    utils.WriteFile(filename, data="")
    backup = utils.CreateBackup(filename)
    self.assertFileContent(backup, "")
    self.assertEqual(self._CountMatching(filename), 2)
    for expected in (3, 4):
      utils.CreateBackup(filename)
      self.assertEqual(self._CountMatching(filename), expected)

    fifoname = utils.PathJoin(self.tmpdir, "fifo")
    os.mkfifo(fifoname)
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)

  def testContent(self):
    """Every backup has exactly the content of the original file"""
    backups = 0
    for chunk in ("", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"):
      for repeat in (1, 2, 10, 127):
        content = chunk * repeat

        filename = utils.PathJoin(self.tmpdir, "test.data_")
        utils.WriteFile(filename, data=content)
        self.assertFileContent(filename, content)

        for _ in range(3):
          backup = utils.CreateBackup(filename)
          backups += 1
          self.assertFileContent(backup, content)
          self.assertEqual(self._CountMatching(filename), backups + 1)


class TestFormatUnit(unittest.TestCase):
  """Test case for the FormatUnit function"""

  def _Check(self, unit, cases):
    """Verifies FormatUnit output for one unit.

    @param unit: unit argument passed to FormatUnit ('h' picks the unit
        automatically, as the expected suffixes below show)
    @param cases: list of (value in MiB, expected string) tuples, checked
        in order

    """
    for (value, expected) in cases:
      self.assertEqual(FormatUnit(value, unit), expected)

  def testMiB(self):
    self._Check('h', [(1, '1M'), (100, '100M'), (1023, '1023M')])
    self._Check('m', [(1, '1'), (100, '100'), (1023, '1023'),
                      (1024, '1024'), (1536, '1536'), (17133, '17133'),
                      (1024 * 1024 - 1, '1048575')])

  def testGiB(self):
    self._Check('h', [(1024, '1.0G'), (1536, '1.5G'), (17133, '16.7G'),
                      (1024 * 1024 - 1, '1024.0G')])
    self._Check('g', [(1024, '1.0'), (1536, '1.5'), (17133, '16.7'),
                      (1024 * 1024 - 1, '1024.0'),
                      (1024 * 1024, '1024.0'), (5120 * 1024, '5120.0'),
                      (29829 * 1024, '29829.0')])

  def testTiB(self):
    self._Check('h', [(1024 * 1024, '1.0T'), (5120 * 1024, '5.0T'),
                      (29829 * 1024, '29.1T')])
    self._Check('t', [(1024 * 1024, '1.0'), (5120 * 1024, '5.0'),
                      (29829 * 1024, '29.1')])
857
class TestParseUnit(unittest.TestCase):
  """Test case for the ParseUnit function"""

  # (suffix, multiplier) pairs; ParseUnit results are expressed in MiB
  SCALES = (('', 1),
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))

  def _Check(self, cases):
    """Asserts ParseUnit(text) == expected for each (text, expected) pair."""
    for (text, expected) in cases:
      self.assertEqual(ParseUnit(text), expected)

  def testRounding(self):
    # Results are rounded up to the next multiple of 4
    self._Check([('0', 0), ('1', 4), ('2', 4), ('3', 4),
                 ('124', 124), ('125', 128), ('126', 128), ('127', 128),
                 ('128', 128), ('129', 132), ('130', 132)])

  def testFloating(self):
    self._Check([('0', 0), ('0.5', 4), ('1.75', 4), ('1.99', 4),
                 ('2.00', 4), ('2.01', 4), ('3.99', 4), ('4.00', 4),
                 ('4.01', 8), ('1.5G', 1536), ('1.8G', 1844),
                 ('8.28T', 8682212)])

  def testSuffixes(self):
    # Suffixes are accepted in any case and after any blank separator
    case_variants = (lambda text: text, str.lower, str.upper)
    for sep in ('', ' ', '   ', "\t", "\t "):
      for (suffix, scale) in self.SCALES:
        for variant in case_variants:
          self.assertEqual(ParseUnit('1024' + sep + variant(suffix)),
                           1024 * scale)

  def testInvalidInput(self):
    for sep in ('-', '_', ',', 'a'):
      for (suffix, _) in self.SCALES:
        self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)

    # A comma is not accepted as decimal separator
    for (suffix, _) in self.SCALES:
      self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
908
909
class TestSshKeys(testutils.GanetiTestCase):
  """Test case for the AddAuthorizedKey and RemoveAuthorizedKey functions"""

  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    # Start every test from a file holding KEY_A and KEY_B, one per line
    handle = open(self.tmpname, 'w')
    try:
      for key in (self.KEY_A, self.KEY_B):
        handle.write("%s\n" % key)
    finally:
      handle.close()

  def testAddingNewKey(self):
    new_key = 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test'
    AddAuthorizedKey(self.tmpname, new_key)

    # Unknown keys are appended at the end
    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")

  def testAddingAlmostButNotCompletelyTheSameKey(self):
    # Same key material as KEY_A, but a different comment
    similar_key = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test'
    AddAuthorizedKey(self.tmpname, similar_key)

    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")

  def testAddingExistingKeyWithSomeMoreSpaces(self):
    # Extra whitespace must not make an existing key look new
    spaced_key = 'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a'
    AddAuthorizedKey(self.tmpname, spaced_key)

    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")

  def testRemovingExistingKeyWithSomeMoreSpaces(self):
    spaced_key = 'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a'
    RemoveAuthorizedKey(self.tmpname, spaced_key)

    self.assertFileContent(self.tmpname,
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")

  def testRemovingNonExistingKey(self):
    unknown_key = 'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test'
    RemoveAuthorizedKey(self.tmpname, unknown_key)

    # The file must be left untouched
    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
971
972
973 class TestEtcHosts(testutils.GanetiTestCase):
974   """Test functions modifying /etc/hosts"""
975
976   def setUp(self):
977     testutils.GanetiTestCase.setUp(self)
978     self.tmpname = self._CreateTempFile()
979     handle = open(self.tmpname, 'w')
980     try:
981       handle.write('# This is a test file for /etc/hosts\n')
982       handle.write('127.0.0.1\tlocalhost\n')
983       handle.write('192.168.1.1 router gw\n')
984     finally:
985       handle.close()
986
987   def testSettingNewIp(self):
988     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
989
990     self.assertFileContent(self.tmpname,
991       "# This is a test file for /etc/hosts\n"
992       "127.0.0.1\tlocalhost\n"
993       "192.168.1.1 router gw\n"
994       "1.2.3.4\tmyhost.domain.tld myhost\n")
995     self.assertFileMode(self.tmpname, 0644)
996
997   def testSettingExistingIp(self):
998     SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
999                      ['myhost'])
1000
1001     self.assertFileContent(self.tmpname,
1002       "# This is a test file for /etc/hosts\n"
1003       "127.0.0.1\tlocalhost\n"
1004       "192.168.1.1\tmyhost.domain.tld myhost\n")
1005     self.assertFileMode(self.tmpname, 0644)
1006
1007   def testSettingDuplicateName(self):
1008     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
1009
1010     self.assertFileContent(self.tmpname,
1011       "# This is a test file for /etc/hosts\n"
1012       "127.0.0.1\tlocalhost\n"
1013       "192.168.1.1 router gw\n"
1014       "1.2.3.4\tmyhost\n")
1015     self.assertFileMode(self.tmpname, 0644)
1016
1017   def testRemovingExistingHost(self):
1018     RemoveEtcHostsEntry(self.tmpname, 'router')
1019
1020     self.assertFileContent(self.tmpname,
1021       "# This is a test file for /etc/hosts\n"
1022       "127.0.0.1\tlocalhost\n"
1023       "192.168.1.1 gw\n")
1024     self.assertFileMode(self.tmpname, 0644)
1025
1026   def testRemovingSingleExistingHost(self):
1027     RemoveEtcHostsEntry(self.tmpname, 'localhost')
1028
1029     self.assertFileContent(self.tmpname,
1030       "# This is a test file for /etc/hosts\n"
1031       "192.168.1.1 router gw\n")
1032     self.assertFileMode(self.tmpname, 0644)
1033
1034   def testRemovingNonExistingHost(self):
1035     RemoveEtcHostsEntry(self.tmpname, 'myhost')
1036
1037     self.assertFileContent(self.tmpname,
1038       "# This is a test file for /etc/hosts\n"
1039       "127.0.0.1\tlocalhost\n"
1040       "192.168.1.1 router gw\n")
1041     self.assertFileMode(self.tmpname, 0644)
1042
1043   def testRemovingAlias(self):
1044     RemoveEtcHostsEntry(self.tmpname, 'gw')
1045
1046     self.assertFileContent(self.tmpname,
1047       "# This is a test file for /etc/hosts\n"
1048       "127.0.0.1\tlocalhost\n"
1049       "192.168.1.1 router\n")
1050     self.assertFileMode(self.tmpname, 0644)
1051
1052
class TestShellQuoting(unittest.TestCase):
  """Test case for shell quoting functions"""

  def testShellQuote(self):
    # Plain words stay as they are; anything else is single-quoted, with
    # embedded single quotes written as '\''
    for (text, expected) in [
      ('abc', "abc"),
      ('ab"c', "'ab\"c'"),
      ("a'bc", "'a'\\''bc'"),
      ("a b c", "'a b c'"),
      ("a b\\ c", "'a b\\ c'"),
      ]:
      self.assertEqual(ShellQuote(text), expected)

  def testShellQuoteArgs(self):
    # Each argument is quoted individually and joined with single spaces
    for (args, expected) in [
      (['a', 'b', 'c'], "a b c"),
      (['a', 'b"', 'c'], "a 'b\"' c"),
      (['a', 'b\'', 'c'], "a 'b'\\\''' c"),
      ]:
      self.assertEqual(ShellQuoteArgs(args), expected)
1067
1068
class TestTcpPing(unittest.TestCase):
  """Testcase for TCP version of ping - against listen(2)ing port"""

  def setUp(self):
    # Listen on an ephemeral port on the loopback address
    self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
    self.listenerport = self.listener.getsockname()[1]
    self.listener.listen(1)

  def tearDown(self):
    try:
      self.listener.shutdown(socket.SHUT_RDWR)
    finally:
      # Close explicitly so the file descriptor is released right away
      # instead of whenever the socket object gets garbage-collected
      self.listener.close()
    del self.listener
    del self.listenerport

  def testTcpPingToLocalHostAccept(self):
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.listenerport,
                         timeout=10,
                         live_port_needed=True,
                         source=constants.LOCALHOST_IP_ADDRESS,
                         ),
                 "failed to connect to test listener")

    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.listenerport,
                         timeout=10,
                         live_port_needed=True,
                         ),
                 "failed to connect to test listener (no source)")
1098
1099
class TestTcpPingDeaf(unittest.TestCase):
  """Testcase for TCP version of ping - against non listen(2)ing port"""

  def setUp(self):
    # Bound but never listen(2)ing, so connection attempts get refused
    self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
    self.deaflistenerport = self.deaflistener.getsockname()[1]

  def tearDown(self):
    # Release the descriptor deterministically rather than relying on GC
    self.deaflistener.close()
    del self.deaflistener
    del self.deaflistenerport

  def testTcpPingToLocalHostAcceptDeaf(self):
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                        self.deaflistenerport,
                        timeout=constants.TCP_PING_TIMEOUT,
                        live_port_needed=True,
                        source=constants.LOCALHOST_IP_ADDRESS,
                        ), # need successful connect(2)
                "successfully connected to deaf listener")

    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                        self.deaflistenerport,
                        timeout=constants.TCP_PING_TIMEOUT,
                        live_port_needed=True,
                        ), # need successful connect(2)
                "successfully connected to deaf listener (no source addr)")

  def testTcpPingToLocalHostNoAccept(self):
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         source=constants.LOCALHOST_IP_ADDRESS,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port")

    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port (no source addr)")
1143
1144
class TestOwnIpAddress(unittest.TestCase):
  """Testcase for OwnIpAddress"""

  def testOwnLoopback(self):
    """check having the loopback ip"""
    loopback = constants.LOCALHOST_IP_ADDRESS
    self.failUnless(OwnIpAddress(loopback),
                    "Should own the loopback address")

  def testNowOwnAddress(self):
    """check that I don't own an address"""
    # 192.0.2.0/24 is reserved for documentation and testing (RFC 5735),
    # so no host running the tests should have an address from it; should
    # this ever fail, the test needs to try several addresses instead
    test_ip = "192.0.2.1"
    self.failIf(OwnIpAddress(test_ip),
                "Should not own IP address %s" % test_ip)
1161
1162
def _GetSocketCredentials(path):
  """Connect to a Unix socket and return the remote end's credentials.

  @param path: filesystem path of a listening Unix stream socket
  @return: the value of L{utils.GetSocketCredentials} for the connection

  """
  client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  try:
    client.settimeout(10)
    client.connect(path)
    creds = utils.GetSocketCredentials(client)
  finally:
    client.close()
  return creds
1174
1175
class TestGetSocketCredentials(unittest.TestCase):
  """Tests for utils.GetSocketCredentials.

  A forked child connects to a Unix socket on which this (parent) process
  listens and reports the credentials it obtained; they must equal the
  parent's pid, uid and gid.

  """

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()
    self.sockpath = utils.PathJoin(self.tmpdir, "sock")

    self.listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    self.listener.settimeout(10)
    self.listener.bind(self.sockpath)
    self.listener.listen(1)

  def tearDown(self):
    self.listener.shutdown(socket.SHUT_RDWR)
    self.listener.close()
    shutil.rmtree(self.tmpdir)

  def test(self):
    (c2pr, c2pw) = os.pipe()

    # Start child process
    child = os.fork()
    if child == 0:
      try:
        data = serializer.DumpJson(_GetSocketCredentials(self.sockpath))

        os.write(c2pw, data)
        os.close(c2pw)

        os._exit(0)
      finally:
        # Only reached if something above raised; the child must never
        # return into the test runner
        os._exit(1)

    os.close(c2pw)

    try:
      # Wait for one connection
      (conn, _) = self.listener.accept()
      conn.recv(1)
      conn.close()

      # Wait for result
      result = os.read(c2pr, 4096)
    finally:
      os.close(c2pr)
      # Reap the child even if accept() timed out or reading failed, so
      # neither the pipe descriptor nor a zombie process is left behind
      (_, status) = os.waitpid(child, 0)

    # Check child's exit code
    self.assertFalse(os.WIFSIGNALED(status))
    self.assertEqual(os.WEXITSTATUS(status), 0)

    # Check result
    (pid, uid, gid) = serializer.LoadJson(result)
    self.assertEqual(pid, os.getpid())
    self.assertEqual(uid, os.getuid())
    self.assertEqual(gid, os.getgid())
1228
1229
class TestListVisibleFiles(unittest.TestCase):
  """Test case for ListVisibleFiles"""

  def setUp(self):
    self.path = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.path)

  def _test(self, files, expected):
    """Creates files in the test directory and checks the visible listing.

    @param files: names of the files to create
    @param expected: names ListVisibleFiles must return, in any order

    """
    for name in files:
      handle = open(os.path.join(self.path, name), 'w')
      try:
        handle.write("Test\n")
      finally:
        handle.close()

    # Order of the listing is irrelevant, so compare sorted copies
    self.assertEqual(sorted(ListVisibleFiles(self.path)), sorted(expected))

  def testAllVisible(self):
    names = ["a", "b", "c"]
    self._test(names, names)

  def testNoneVisible(self):
    # Dot-files are hidden
    self._test([".a", ".b", ".c"], [])

  def testSomeVisible(self):
    self._test(["a", "b", ".c"], ["a", "b"])

  def testNonAbsolutePath(self):
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")

  def testNonNormalizedPath(self):
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
                          "/bin/../tmp")
1277
1278
class TestNewUUID(unittest.TestCase):
  """Test case for NewUUID"""

  # Textual UUID form: groups of 8-4-4-4-12 lowercase hex digits
  _re_uuid = re.compile("^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-"
                        "[a-f0-9]{4}-[a-f0-9]{12}$")

  def runTest(self):
    value = utils.NewUUID()
    self.failUnless(self._re_uuid.match(value))
1287
1288
class TestUniqueSequence(unittest.TestCase):
  """Test case for UniqueSequence"""

  def _test(self, seq, expected):
    """Asserts that UniqueSequence(seq) equals expected.

    The parameter is named C{seq} rather than C{input} so it does not
    shadow the builtin.

    """
    self.assertEqual(utils.UniqueSequence(seq), expected)

  def runTest(self):
    # Ordered input
    self._test([1, 2, 3], [1, 2, 3])
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
    self._test([1, 2, 2, 3], [1, 2, 3])
    self._test([1, 2, 3, 3], [1, 2, 3])

    # Unordered input; the order of first occurrence is kept
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])

    # Strings
    self._test(["a", "a"], ["a"])
    self._test(["a", "b"], ["a", "b"])
    self._test(["a", "b", "a"], ["a", "b"])
1310
1311
class TestFirstFree(unittest.TestCase):
  """Test case for the FirstFree function"""

  def test(self):
    """Test FirstFree"""
    for (seq, kwargs, expected) in [
      ([0, 1, 3], {}, 2),
      ([], {}, None),
      ([3, 4, 6], {}, 0),
      ([3, 4, 6], {"base": 3}, 5),
      ]:
      self.failUnlessEqual(FirstFree(seq, **kwargs), expected)

    # A sequence containing a value below base triggers an assertion
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1322
1323
class TestTailFile(testutils.GanetiTestCase):
  """Test case for the TailFile function"""

  def _WriteTempFile(self, text):
    """Creates a temporary file containing text and returns its name."""
    fname = self._CreateTempFile()
    handle = open(fname, "w")
    try:
      handle.write(text)
    finally:
      handle.close()
    return fname

  def testEmpty(self):
    fname = self._CreateTempFile()
    self.failUnlessEqual(TailFile(fname), [])
    self.failUnlessEqual(TailFile(fname, lines=25), [])

  def testAllLines(self):
    data = ["test %d" % i for i in range(30)]
    for count in range(30):
      # Every written line, including the last, ends with a newline
      text = "".join(["%s\n" % line for line in data[:count]])
      fname = self._WriteTempFile(text)
      self.failUnlessEqual(TailFile(fname, lines=count), data[:count])

  def testPartialLines(self):
    data = ["test %d" % i for i in range(30)]
    fname = self._WriteTempFile("\n".join(data) + "\n")
    for count in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=count), data[-count:])

  def testBigFile(self):
    data = ["test %d" % i for i in range(30)]
    # A single 1 MiB line precedes the lines we are interested in
    fname = self._WriteTempFile("X" * 1048576 + "\n" +
                                "\n".join(data) + "\n")
    for count in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=count), data[-count:])
1364
1365
class _BaseFileLockTest:
  """Test case for the FileLock class

  Mixin for the concrete test classes; they must set C{self.lock} to the
  lock under test and C{self.tmpfile} to an object whose C{name} attribute
  is the path of the locked file.

  """

  def testSharedNonblocking(self):
    self.lock.Shared(blocking=False)
    self.lock.Close()

  def testExclusiveNonblocking(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Close()

  def testUnlockNonblocking(self):
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSharedBlocking(self):
    self.lock.Shared(blocking=True)
    self.lock.Close()

  def testExclusiveBlocking(self):
    self.lock.Exclusive(blocking=True)
    self.lock.Close()

  def testUnlockBlocking(self):
    self.lock.Unlock(blocking=True)
    self.lock.Close()

  def testSharedExclusiveUnlock(self):
    self.lock.Shared(blocking=False)
    self.lock.Exclusive(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testExclusiveSharedUnlock(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Shared(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSimpleTimeout(self):
    # These will succeed on the first attempt, hence a short timeout
    self.lock.Shared(blocking=True, timeout=10.0)
    self.lock.Exclusive(blocking=False, timeout=10.0)
    self.lock.Unlock(blocking=True, timeout=10.0)
    self.lock.Close()

  @staticmethod
  def _TryLockInner(filename, shared, blocking):
    """Tries to lock filename; run in a separate process by L{_TryLock}.

    @param filename: path of the file to lock
    @param shared: whether to request a shared (True) or exclusive lock
    @param blocking: passed through to the lock method
    @rtype: bool
    @return: whether the lock could be acquired

    """
    lock = utils.FileLock.Open(filename)

    if shared:
      fn = lock.Shared
    else:
      fn = lock.Exclusive

    try:
      # The timeout doesn't really matter as the parent process waits for us to
      # finish anyway.
      fn(blocking=blocking, timeout=0.01)
    except errors.LockError:
      return False

    return True

  def _TryLock(self, *args):
    # A separate process is needed, as the locks held by this process would
    # not conflict with a second attempt from the same process
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
                                      *args)

  def testTimeout(self):
    for blocking in [True, False]:
      # While we hold the lock exclusively, nobody else gets any lock
      self.lock.Exclusive(blocking=True)
      self.failIf(self._TryLock(False, blocking))
      self.failIf(self._TryLock(True, blocking))

      # While we hold it shared, others may share but not take it exclusively
      self.lock.Shared(blocking=True)
      self.assert_(self._TryLock(True, blocking))
      self.failIf(self._TryLock(False, blocking))

    # Release the lock, as every other test in this class does
    self.lock.Close()

  def testCloseShared(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)

  def testCloseExclusive(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)

  def testCloseUnlock(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1455
1456
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
  """Runs the FileLock tests on a lock created from a filename.

  The file is pre-filled with known data, which is verified both right after
  opening the lock and after each test.

  """
  TESTDATA = "Hello World\n" * 10

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpfile = tempfile.NamedTemporaryFile()
    path = self.tmpfile.name
    utils.WriteFile(path, data=self.TESTDATA)
    self.lock = utils.FileLock.Open(path)

    # Opening the lock must not have truncated the file
    self.assertFileContent(path, self.TESTDATA)

  def tearDown(self):
    # Whatever the test did with the lock, the data must be intact
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
    testutils.GanetiTestCase.tearDown(self)
1474
1475
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
  """Runs the FileLock tests on a lock wrapping an already-open file."""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()
    path = self.tmpfile.name
    self.lock = utils.FileLock(open(path, "w"), path)
1480
1481
class TestTimeFunctions(unittest.TestCase):
  """Test case for time functions"""

  def runTest(self):
    # SplitTime: seconds -> (whole seconds, microseconds)
    for (value, expected) in [
      (1, (1, 0)),
      (1.5, (1, 500000)),
      (1218448917.4809151, (1218448917, 480915)),
      (123.48012, (123, 480120)),
      (123.9996, (123, 999600)),
      (123.9995, (123, 999500)),
      (123.9994, (123, 999400)),
      (123.999999999, (123, 999999)),
      ]:
      self.assertEqual(utils.SplitTime(value), expected)

    self.assertRaises(AssertionError, utils.SplitTime, -1)

    # MergeTime: results that are exactly representable
    for (value, expected) in [((1, 0), 1.0),
                              ((1, 500000), 1.5),
                              ((1218448917, 500000), 1218448917.5)]:
      self.assertEqual(utils.MergeTime(value), expected)

    # MergeTime: results compared only after rounding to three decimals
    for (value, expected) in [((1218448917, 481000), 1218448917.481),
                              ((1, 801000), 1.801)]:
      self.assertEqual(round(utils.MergeTime(value), 3), expected)

    # Out-of-range components are rejected
    for value in [(0, -1), (0, 1000000), (0, 9999999), (-1, 0), (-9999, 0)]:
      self.assertRaises(AssertionError, utils.MergeTime, value)
1510
1511
class FieldSetTestCase(unittest.TestCase):
  """Test case for FieldSets"""

  def testSimpleMatch(self):
    fields = utils.FieldSet("a", "b", "c", "def")
    self.failUnless(fields.Matches("a"))
    # Neither a shorter nor a longer string may match a field name
    self.failIf(fields.Matches("d"), "Substring matched")
    self.failIf(fields.Matches("defghi"), "Prefix string matched")
    for names in (["b", "c"], ["a", "b", "c", "def"]):
      self.failIf(fields.NonMatching(names))
    self.failUnless(fields.NonMatching(["a", "d"]))

  def testRegexMatch(self):
    fields = utils.FieldSet("a", "b([0-9]+)", "c")
    for name in ("b1", "b99"):
      self.failUnless(fields.Matches(name))
    self.failIf(fields.Matches("b/1"))
    self.failIf(fields.NonMatching(["b12", "c"]))
    self.failUnless(fields.NonMatching(["a", "1"]))
1531
class TestForceDictType(unittest.TestCase):
  """Test case for ForceDictType"""

  def setUp(self):
    self.key_types = {
      'a': constants.VTYPE_INT,
      'b': constants.VTYPE_BOOL,
      'c': constants.VTYPE_STRING,
      'd': constants.VTYPE_SIZE,
      }

  def _fdt(self, values, allowed_values=None):
    """Runs ForceDictType on values and returns the same dict.

    ForceDictType converts the dict in place, so the (modified) argument is
    returned for comparison. The parameter is named C{values} so it does not
    shadow the builtin C{dict}.

    """
    if allowed_values is None:
      ForceDictType(values, self.key_types)
    else:
      ForceDictType(values, self.key_types, allowed_values=allowed_values)

    return values

  def testSimpleDict(self):
    self.assertEqual(self._fdt({}), {})
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})

  def testErrors(self):
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1570
1571
class TestIsNormAbsPath(unittest.TestCase):
  """Testing case for IsNormAbsPath"""

  def _pathTestHelper(self, path, result):
    """Asserts that IsNormAbsPath(path) is true exactly when result is."""
    is_norm = IsNormAbsPath(path)
    if not result:
      self.assert_(not is_norm,
          "Path %s should not result absolute and normalized" % path)
      return
    self.assert_(is_norm,
        "Path %s should result absolute and normalized" % path)

  def testBase(self):
    for (path, expected) in [('/etc', True),
                             ('/srv', True),
                             ('etc', False),
                             ('/etc/../root', False),
                             ('/etc/', False)]:
      self._pathTestHelper(path, expected)
1589
1590
class TestSafeEncode(unittest.TestCase):
  """Test case for SafeEncode"""

  def testAscii(self):
    # Printable ASCII text must be passed through unchanged
    for txt in [string.digits, string.letters, string.punctuation]:
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testDoubleEncode(self):
    # Encoding an already-encoded value must not change it; range(256) so
    # that every byte value, including chr(255), is covered
    for i in range(256):
      txt = SafeEncode(chr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testUnicode(self):
    # 1024 is high enough to catch non-direct ASCII mappings
    for i in range(1024):
      txt = SafeEncode(unichr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))
1608
1609
class TestFormatTime(unittest.TestCase):
  """Testing case for FormatTime"""

  def testNone(self):
    self.failUnlessEqual(FormatTime(None), "N/A")

  def testInvalid(self):
    # An empty tuple is not a usable timestamp either
    self.failUnlessEqual(FormatTime(()), "N/A")

  def testNow(self):
    # Float input (as returned by time.time) and int input must both be
    # accepted; only the absence of an exception is checked
    FormatTime(time.time())
    FormatTime(int(time.time()))
1624
1625
class RunInSeparateProcess(unittest.TestCase):
  """Tests for utils.RunInSeparateProcess."""

  def test(self):
    # The function's return value is passed back to the caller
    for expected in [True, False]:
      def _child(value=expected):
        return value

      self.assertEqual(expected, utils.RunInSeparateProcess(_child))

  def testArgs(self):
    # Extra positional arguments are forwarded to the function
    for value in [0, 1, 999, "Hello World", (1, 2, 3)]:
      def _child(carg1, carg2):
        return carg1 == "Foo" and carg2 == value

      self.assert_(utils.RunInSeparateProcess(_child, "Foo", value))

  def testPid(self):
    # The function must not run in the calling process
    caller_pid = os.getpid()
    self.failIf(utils.RunInSeparateProcess(lambda: os.getpid() == caller_pid))

  def testSignal(self):
    # A child killed by a signal is reported as GenericError
    self.assertRaises(errors.GenericError, utils.RunInSeparateProcess,
                      lambda: os.kill(os.getpid(), signal.SIGTERM))

  def testException(self):
    def _raise():
      raise errors.GenericError("This is a test")

    self.assertRaises(errors.GenericError, utils.RunInSeparateProcess, _raise)
1662
1663
class TestFingerprintFile(unittest.TestCase):
  """Tests for utils._FingerprintFile."""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()

  def _CheckFingerprint(self, expected):
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name), expected)

  def test(self):
    # The freshly created temporary file is still empty
    self._CheckFingerprint("da39a3ee5e6b4b0d3255bfef95601890afd80709")

    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
    self._CheckFingerprint("648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1675
1676
class TestUnescapeAndSplit(unittest.TestCase):
  """Tests for UnescapeAndSplit"""

  def setUp(self):
    # More than one separator is tested; "+" and "." have a special meaning
    # in regular expressions and must still be handled literally
    self._seps = [",", "+", "."]

  def _CheckAllSeps(self, build):
    """Runs UnescapeAndSplit once for every separator.

    @param build: callable receiving the separator and returning a tuple of
        (list of fields to join with it, expected result of the split)

    """
    for sep in self._seps:
      (fields, expected) = build(sep)
      self.failUnlessEqual(UnescapeAndSplit(sep.join(fields), sep=sep),
                           expected)

  def testSimple(self):
    plain = ["a", "b", "c", "d"]
    self._CheckAllSeps(lambda sep: (plain, plain))

  def testEscape(self):
    # An escaped separator stays part of its field
    self._CheckAllSeps(lambda sep: (["a", "b\\" + sep + "c", "d"],
                                    ["a", "b" + sep + "c", "d"]))

  def testDoubleEscape(self):
    # An escaped backslash becomes one literal backslash and the following
    # separator still splits
    self._CheckAllSeps(lambda sep: (["a", "b\\\\", "c", "d"],
                                    ["a", "b\\", "c", "d"]))

  def testThreeEscape(self):
    # An escaped backslash followed by an escaped separator
    self._CheckAllSeps(lambda sep: (["a", "b\\\\\\" + sep + "c", "d"],
                                    ["a", "b\\" + sep + "c", "d"]))
1706
1707
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
  """Tests for utils.GenerateSelfSignedX509Cert/GenerateSelfSignedSslCert"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _checkRsaPrivateKey(self, key):
    """Returns whether C{key} contains PEM RSA private key begin/end lines.

    """
    lines = key.splitlines()
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
            "-----END RSA PRIVATE KEY-----" in lines)

  def _checkCertificate(self, cert):
    """Returns whether C{cert} contains PEM certificate begin/end lines.

    """
    lines = cert.splitlines()
    return ("-----BEGIN CERTIFICATE-----" in lines and
            "-----END CERTIFICATE-----" in lines)

  def test(self):
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
      # The helpers only return a boolean; it has to be asserted, otherwise a
      # malformed key or certificate would go unnoticed
      self.assert_(self._checkRsaPrivateKey(key_pem))
      self.assert_(self._checkCertificate(cert_pem))

      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
                                           key_pem)
      self.assert_(key.bits() >= 1024)
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)

      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                             cert_pem)
      self.failIf(x509.has_expired())
      # Self-signed: issuer and subject carry the same common name
      self.assertEqual(x509.get_issuer().CN, common_name)
      self.assertEqual(x509.get_subject().CN, common_name)
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)

  def testLegacy(self):
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")

    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)

    # The written file must contain both the private key and the certificate
    cert1 = utils.ReadFile(cert1_filename)

    self.assert_(self._checkRsaPrivateKey(cert1))
    self.assert_(self._checkCertificate(cert1))
1753
1754
class TestPathJoin(unittest.TestCase):
  """Tests for PathJoin"""

  def testBasicItems(self):
    items = ["/a", "b", "c"]
    expected = "/".join(items)
    self.failUnlessEqual(PathJoin(*items), expected)

  def testNonAbsPrefix(self):
    # A relative first component is rejected
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")

  def testBackTrack(self):
    # Components containing ".." are rejected
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")

  def testMultiAbs(self):
    # A second absolute component is rejected
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1770
1771
class TestHostInfo(unittest.TestCase):
  """Testing case for HostInfo"""

  def testUppercase(self):
    # Names are converted to lower case
    data = "AbC.example.com"
    self.failUnlessEqual(HostInfo.NormalizeName(data), data.lower())

  def testTooLongName(self):
    # A 259-character name is rejected
    data = "a.b." + "c" * 255
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, data)

  def testTrailingDot(self):
    # A trailing dot is removed
    data = "a.b.c"
    self.failUnlessEqual(HostInfo.NormalizeName(data + "."), data)

  def testInvalidName(self):
    # Whitespace, slashes and empty labels are not allowed
    data = [
      "a b",
      "a/b",
      ".a.b",
      "a..b",
      ]
    for value in data:
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, value)

  def testValidName(self):
    data = [
      "a.b",
      "a-b",
      "a_b",
      "a.b.c",
      ]
    for value in data:
      # Valid, already-normalized names must be accepted and returned as-is;
      # previously the return value was not checked at all
      self.failUnlessEqual(HostInfo.NormalizeName(value), value)
1806
1807
class TestParseAsn1Generalizedtime(unittest.TestCase):
  """Tests for utils._ParseAsn1Generalizedtime"""

  def test(self):
    parse = utils._ParseAsn1Generalizedtime

    # Pairs of (GeneralizedTime string, expected seconds since the epoch)
    valid = [
      # UTC
      ("19700101000000Z", 0),
      ("20100222174152Z", 1266860512),
      ("20380119031407Z", (2**31) - 1),

      # With offset
      ("20100222174152+0000", 1266860512),
      ("20100223131652+0000", 1266931012),
      ("20100223051808-0800", 1266931088),
      ("20100224002135+1100", 1266931295),
      ("19700101000000-0100", 3600),
      ]

    for (text, expected) in valid:
      self.assertEqual(parse(text), expected)

    invalid = [
      # Leap seconds are not supported by datetime.datetime
      "19841231235960+0000",
      "19920630235960+0000",

      # Errors
      "",
      "invalid",
      "20100222174152",
      "Mon Feb 22 17:47:02 UTC 2010",
      "2010-02-22 17:42:02",
      ]

    for text in invalid:
      self.assertRaises(ValueError, parse, text)
1844
1845
class TestGetX509CertValidity(testutils.GanetiTestCase):
  """Tests for utils.GetX509CertValidity"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    # With pyOpenSSL older than 0.7 the test expects (None, None) instead of
    # real timestamps
    installed = distutils.version.LooseVersion(OpenSSL.__version__)
    self.pyopenssl0_7 = (installed >= "0.7")

    if not self.pyopenssl0_7:
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
                    " function correctly")

  def _LoadCert(self, name):
    """Loads a PEM certificate from the test data directory."""
    pem = self._ReadTestData(name)
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem)

  def test(self):
    cert = self._LoadCert("cert1.pem")

    if self.pyopenssl0_7:
      expected = (1266919967, 1267524767)
    else:
      expected = (None, None)

    self.assertEqual(utils.GetX509CertValidity(cert), expected)
1869
1870
class TestSignX509Certificate(unittest.TestCase):
  """Tests for signing and loading signed X509 certificates"""

  KEY = "My private key!"
  KEY_OTHER = "Another key"

  def test(self):
    # Generate certificate valid for 5 minutes
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)

    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                           cert_pem)

    # No signature at all
    self.assertRaises(errors.GenericError,
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)

    # Invalid input
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: \n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)

    # Invalid salt; the certificate object is passed exactly as in the valid
    # cases below, so that the salt is the only thing wrong with the call
    for salt in list("-_@$,:;/\\ \t\n"):
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
                        cert, self.KEY, "foo%sbar" % salt)

    for salt in ["HelloWorld", "salt", string.letters, string.digits,
                 utils.GenerateSecret(numbytes=4),
                 utils.GenerateSecret(numbytes=16),
                 "{123:456}".encode("hex")]:
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)

      self._Check(cert, salt, signed_pem)

      # Additional text before or after the signed PEM must be tolerated
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
                               "lines----\n------ at\nthe end!"))

  def _Check(self, cert, salt, pem):
    """Verifies that C{pem} loads back as C{cert} signed with C{salt}.

    Also checks that loading with a different key fails.

    """
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
    self.assertEqual(salt, salt2)
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))

    # Other key
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      pem, self.KEY_OTHER)
1924
1925
class TestMakedirs(unittest.TestCase):
  """Tests for utils.Makedirs"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _TmpPath(self, relpath):
    """Returns C{relpath} joined onto the test's temporary directory."""
    return utils.PathJoin(self.tmpdir, relpath)

  def testNonExisting(self):
    target = self._TmpPath("foo")
    utils.Makedirs(target)
    self.assert_(os.path.isdir(target))

  def testExisting(self):
    # An already existing directory must not cause an error
    target = self._TmpPath("foo")
    os.mkdir(target)
    utils.Makedirs(target)
    self.assert_(os.path.isdir(target))

  def testRecursiveNonExisting(self):
    # All missing parent directories are created as well
    target = self._TmpPath("foo/bar/baz")
    utils.Makedirs(target)
    self.assert_(os.path.isdir(target))

  def testRecursiveExisting(self):
    # Only the top-most parent exists beforehand
    target = self._TmpPath("B/moo/xyz")
    self.failIf(os.path.exists(target))
    os.mkdir(self._TmpPath("B"))
    utils.Makedirs(target)
    self.assert_(os.path.isdir(target))
1955
1956
1957 class TestRetry(testutils.GanetiTestCase):
1958   def setUp(self):
1959     testutils.GanetiTestCase.setUp(self)
1960     self.retries = 0
1961
1962   @staticmethod
1963   def _RaiseRetryAgain():
1964     raise utils.RetryAgain()
1965
1966   @staticmethod
1967   def _RaiseRetryAgainWithArg(args):
1968     raise utils.RetryAgain(*args)
1969
1970   def _WrongNestedLoop(self):
1971     return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
1972
1973   def _RetryAndSucceed(self, retries):
1974     if self.retries < retries:
1975       self.retries += 1
1976       raise utils.RetryAgain()
1977     else:
1978       return True
1979
1980   def testRaiseTimeout(self):
1981     self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1982                           self._RaiseRetryAgain, 0.01, 0.02)
1983     self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1984                           self._RetryAndSucceed, 0.01, 0, args=[1])
1985     self.failUnlessEqual(self.retries, 1)
1986
1987   def testComplete(self):
1988     self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
1989     self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
1990                          True)
1991     self.failUnlessEqual(self.retries, 2)
1992
1993   def testNestedLoop(self):
1994     try:
1995       self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
1996                             self._WrongNestedLoop, 0, 1)
1997     except utils.RetryTimeout:
1998       self.fail("Didn't detect inner loop's exception")
1999
2000   def testTimeoutArgument(self):
2001     retry_arg="my_important_debugging_message"
2002     try:
2003       utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
2004     except utils.RetryTimeout, err:
2005       self.failUnlessEqual(err.args, (retry_arg, ))
2006     else:
2007       self.fail("Expected timeout didn't happen")
2008
2009   def testRaiseInnerWithExc(self):
2010     retry_arg="my_important_debugging_message"
2011     try:
2012       try:
2013         utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2014                     args=[[errors.GenericError(retry_arg, retry_arg)]])
2015       except utils.RetryTimeout, err:
2016         err.RaiseInner()
2017       else:
2018         self.fail("Expected timeout didn't happen")
2019     except errors.GenericError, err:
2020       self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2021     else:
2022       self.fail("Expected GenericError didn't happen")
2023
2024   def testRaiseInnerWithMsg(self):
2025     retry_arg="my_important_debugging_message"
2026     try:
2027       try:
2028         utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2029                     args=[[retry_arg, retry_arg]])
2030       except utils.RetryTimeout, err:
2031         err.RaiseInner()
2032       else:
2033         self.fail("Expected timeout didn't happen")
2034     except utils.RetryTimeout, err:
2035       self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2036     else:
2037       self.fail("Expected RetryTimeout didn't happen")
2038
2039
class TestLineSplitter(unittest.TestCase):
  """Tests for utils.LineSplitter"""

  def test(self):
    received = []
    splitter = utils.LineSplitter(received.append)

    # Nothing is passed on by write() alone
    splitter.write("Hello World\n")
    self.assertEqual(received, [])
    for chunk in ["Foo\n Bar\r\n ", "Baz", "Moo"]:
      splitter.write(chunk)
    self.assertEqual(received, [])

    # flush() emits complete lines only, without their line terminators
    splitter.flush()
    self.assertEqual(received, ["Hello World", "Foo", " Bar"])

    # close() also emits the unterminated remainder
    splitter.close()
    self.assertEqual(received, ["Hello World", "Foo", " Bar", " BazMoo"])

  def _testExtra(self, line, all_lines, p1, p2):
    """Line callback checking the extra arguments given to LineSplitter."""
    self.assertEqual(p1, 999)
    self.assertEqual(p2, "extra")
    all_lines.append(line)

  def testExtraArgsNoFlush(self):
    collected = []
    # Additional constructor arguments are passed on to the callback
    splitter = utils.LineSplitter(self._testExtra, collected, 999, "extra")
    for chunk in ["\n\nHello World\n", "Foo\n Bar\r\n ", "", "Baz",
                  "Moo\n\nx\n"]:
      splitter.write(chunk)
    self.assertEqual(collected, [])

    splitter.close()
    self.assertEqual(collected, ["", "", "Hello World", "Foo", " Bar",
                                 " BazMoo", "", "x"])
2072
2073
class TestReadLockedPidFile(unittest.TestCase):
  """Tests for utils.ReadLockedPidFile"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testNonExistent(self):
    # A missing file yields None
    missing = utils.PathJoin(self.tmpdir, "nonexist")
    self.assert_(utils.ReadLockedPidFile(missing) is None)

  def testUnlocked(self):
    # A PID file without a lock held on it yields None
    pidfile = utils.PathJoin(self.tmpdir, "pid")
    utils.WriteFile(pidfile, data="123")
    self.assert_(utils.ReadLockedPidFile(pidfile) is None)

  def testLocked(self):
    pidfile = utils.PathJoin(self.tmpdir, "pid")
    utils.WriteFile(pidfile, data="123")

    lock = utils.FileLock.Open(pidfile)
    try:
      lock.Exclusive(blocking=True)
      # While the lock is held the stored PID is returned as an integer
      self.assertEqual(utils.ReadLockedPidFile(pidfile), 123)
    finally:
      lock.Close()

    # Once released, the result is None again
    self.assert_(utils.ReadLockedPidFile(pidfile) is None)

  def testError(self):
    # A regular file is used as if it were a directory, so open(2) fails
    # with ENOTDIR; that error must reach the caller
    notdir = utils.PathJoin(self.tmpdir, "foobar")
    utils.WriteFile(notdir, data="")
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile,
                      utils.PathJoin(notdir, "pid"))
2109
2110
class TestCertVerification(testutils.GanetiTestCase):
  """Tests for utils.VerifyX509Certificate"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    try:
      shutil.rmtree(self.tmpdir)
    finally:
      # setUp delegates to the base class, so tearDown has to as well;
      # otherwise the base class never gets to clean up after itself
      testutils.GanetiTestCase.tearDown(self)

  def testVerifyCertificate(self):
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                           cert_pem)

    # Not checking return value as this certificate is expired
    utils.VerifyX509Certificate(cert, 30, 7)
2127
2128
class TestVerifyCertificateInner(unittest.TestCase):
  """Tests for utils._VerifyCertificateInner"""

  def test(self):
    vci = utils._VerifyCertificateInner

    # NOTE(review): the positional arguments presumably are (expired,
    # not_before, not_after, now, warn_days, error_days); confirm in utils

    # Valid: neither an error code nor a message
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
                     (None, None))

    # Pairs of (arguments, expected error code)
    cases = [
      # Not yet valid
      ((False, 1266507600, 1267544400, 1266075600, 30, 7), utils.CERT_WARNING),

      # Expiring soon
      ((False, 1266507600, 1267544400, 1266939600, 30, 7), utils.CERT_ERROR),
      ((False, 1266507600, 1267544400, 1266939600, 30, 1), utils.CERT_WARNING),
      ((False, 1266507600, None, 1266939600, 30, 7), None),

      # Expired
      ((True, 1266507600, 1267544400, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, None, 1267544400, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, 1266507600, None, 1266939600, 30, 7), utils.CERT_ERROR),
      ((True, None, None, 1266939600, 30, 7), utils.CERT_ERROR),
      ]

    for (args, expected_errcode) in cases:
      (errcode, _) = vci(*args)
      self.assertEqual(errcode, expected_errcode)
2163
2164
class TestHmacFunctions(unittest.TestCase):
  """Tests for utils.Sha1Hmac and utils.VerifySha1Hmac"""

  # Digests can be checked with "openssl sha1 -hmac $key"

  def testSha1Hmac(self):
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
    for (key, text, digest) in [
      ("", "", "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d"),
      ("3YzMxZWE", "Hello World", "ef4f3bda82212ecb2f7ce868888a19092481f1fd"),
      ("TguMTA2K", "", "f904c2476527c6d3e6609ab683c66fa0652cb1dc"),
      ("3YzMxZWE", longtext, "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54"),
      ]:
      self.assertEqual(utils.Sha1Hmac(key, text), digest)

  def testSha1HmacSalt(self):
    for (key, text, salt, digest) in [
      ("TguMTA2K", "", "abc0", "4999bf342470eadb11dfcd24ca5680cf9fd7cdce"),
      ("TguMTA2K", "", "abc9", "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8"),
      ("3YzMxZWE", "Hello World", "xyz0",
       "7f264f8114c9066afc9bb7636e1786d996d3cc0d"),
      ]:
      self.assertEqual(utils.Sha1Hmac(key, text, salt=salt), digest)

  def testVerifySha1Hmac(self):
    self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b"
                                               "7d64b71fb76370690e1d")))
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
                                      ("f904c2476527c6d3e660"
                                       "9ab683c66fa0652cb1dc")))

    # The hexadecimal digest is accepted regardless of letter case
    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
    for variant in [digest, digest.lower(), digest.upper(), digest.title()]:
      self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", variant))

  def testVerifySha1HmacSalt(self):
    for (key, text, digest, salt) in [
      ("TguMTA2K", "", "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8", "abc9"),
      ("3YzMxZWE", "Hello World", "7f264f8114c9066afc9bb7636e1786d996d3cc0d",
       "xyz0"),
      ]:
      self.assert_(utils.VerifySha1Hmac(key, text, digest, salt=salt))
2212
2213
class TestIgnoreSignals(unittest.TestCase):
  """Tests for the utils.IgnoreSignals wrapper"""

  @staticmethod
  def _Raise(exception):
    raise exception

  @staticmethod
  def _Return(rval):
    return rval

  def testIgnoreSignals(self):
    sock_err_intr = socket.error(errno.EINTR, "Message")
    sock_err_inval = socket.error(errno.EINVAL, "Message")
    env_err_intr = EnvironmentError(errno.EINTR, "Message")
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")

    # Sanity check: called directly, every error propagates
    for (exc_class, err) in [(socket.error, sock_err_intr),
                             (socket.error, sock_err_inval),
                             (EnvironmentError, env_err_intr),
                             (EnvironmentError, env_err_inval)]:
      self.assertRaises(exc_class, self._Raise, err)

    # EINTR is swallowed by the wrapper, which then returns None
    for err in [sock_err_intr, env_err_intr]:
      self.assertEquals(utils.IgnoreSignals(self._Raise, err), None)

    # Other error numbers still propagate
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
                      sock_err_inval)
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
                      env_err_inval)

    # Return values of the wrapped function are passed through
    for value in [True, 33]:
      self.assertEquals(utils.IgnoreSignals(self._Return, value), value)
2246
2247
2248 class TestEnsureDirs(unittest.TestCase):
2249   """Tests for EnsureDirs"""
2250
2251   def setUp(self):
2252     self.dir = tempfile.mkdtemp()
2253     self.old_umask = os.umask(0777)
2254
2255   def testEnsureDirs(self):
2256     utils.EnsureDirs([
2257         (utils.PathJoin(self.dir, "foo"), 0777),
2258         (utils.PathJoin(self.dir, "bar"), 0000),
2259         ])
2260     self.assertEquals(os.stat(utils.PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2261     self.assertEquals(os.stat(utils.PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2262
2263   def tearDown(self):
2264     os.rmdir(utils.PathJoin(self.dir, "foo"))
2265     os.rmdir(utils.PathJoin(self.dir, "bar"))
2266     os.rmdir(self.dir)
2267     os.umask(self.old_umask)
2268
# Script entry point; delegates to testutils.GanetiTestProgram (presumably a
# unittest.main-style runner -- see testutils)
if __name__ == '__main__':
  testutils.GanetiTestProgram()