utils.RunCmd: Test case with reset_env set and setting variables
[ganeti-local] / test / ganeti.utils_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the utils module"""
23
24 import unittest
25 import os
26 import time
27 import tempfile
28 import os.path
29 import os
30 import stat
31 import md5
32 import signal
33 import socket
34 import shutil
35 import re
36 import select
37 import string
38 import OpenSSL
39 import warnings
40 import distutils.version
41 import glob
42
43 import ganeti
44 import testutils
45 from ganeti import constants
46 from ganeti import utils
47 from ganeti import errors
48 from ganeti.utils import IsProcessAlive, RunCmd, \
49      RemoveFile, MatchNameComponent, FormatUnit, \
50      ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
51      ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
52      SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
53      TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
54      UnescapeAndSplit, RunParts, PathJoin, HostInfo
55
56 from ganeti.errors import LockError, UnitParseError, GenericError, \
57      ProgrammerError, OpPrereqError
58
59
class TestIsProcessAlive(unittest.TestCase):
  """Testing case for IsProcessAlive"""

  def testExists(self):
    """Our own process must be reported as alive"""
    mypid = os.getpid()
    self.assert_(IsProcessAlive(mypid),
                 "can't find myself running")

  def testNotExisting(self):
    """An exited and reaped child must not be reported as alive"""
    # os.fork() raises OSError on failure rather than returning a negative
    # value, so only the child (0) and parent (>0) cases exist
    pid_non_existing = os.fork()
    if pid_non_existing == 0:
      os._exit(0)
    # reap the child so its pid no longer refers to a (zombie) process
    os.waitpid(pid_non_existing, 0)
    self.assert_(not IsProcessAlive(pid_non_existing),
                 "nonexisting process detected")
77
78
class TestPidFileFunctions(unittest.TestCase):
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""

  def setUp(self):
    self.dir = tempfile.mkdtemp()
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
    # Redirect pid files into our temporary directory; the original function
    # is put back in tearDown so later tests see the real implementation
    self._orig_dpn = utils.DaemonPidFileName
    utils.DaemonPidFileName = self.f_dpn

  def testPidFileFunctions(self):
    pid_file = self.f_dpn('test')
    utils.WritePidFile('test')
    self.failUnless(os.path.exists(pid_file),
                    "PID file should have been created")
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, os.getpid())
    self.failUnless(utils.IsProcessAlive(read_pid))
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for missing pid file")
    fh = open(pid_file, "w")
    fh.write("blah\n")
    fh.close()
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for invalid pid file")
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")

  def testKill(self):
    pid_file = self.f_dpn('child')
    r_fd, w_fd = os.pipe()
    new_pid = os.fork()
    if new_pid == 0: # child
      # Whatever happens, the forked child must never return into the test
      # runner, so always leave via os._exit
      try:
        os.close(r_fd)
        utils.WritePidFile('child')
        os.write(w_fd, 'a')
        signal.pause()
      finally:
        os._exit(0)
    # else we are in the parent; drop our copy of the write end so that the
    # read below sees EOF (instead of blocking) if the child dies early
    os.close(w_fd)
    try:
      # wait until the child has written the pid file
      self.failUnlessEqual(os.read(r_fd, 1), 'a',
                           "child exited before writing its pid file")
    finally:
      os.close(r_fd)
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, new_pid)
    self.failUnless(utils.IsProcessAlive(new_pid))
    utils.KillProcess(new_pid, waitpid=True)
    self.failIf(utils.IsProcessAlive(new_pid))
    utils.RemovePidFile('child')
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)

  def tearDown(self):
    utils.DaemonPidFileName = self._orig_dpn
    shutil.rmtree(self.dir)
135
136
class TestRunCmd(testutils.GanetiTestCase):
  """Testing case for the RunCmd function"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.magic = time.ctime() + " ganeti test"
    self.fname = self._CreateTempFile()

  def testOk(self):
    """Test successful exit code"""
    result = RunCmd("/bin/sh -c 'exit 0'")
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.output, "")

  def testFail(self):
    """Test fail exit code"""
    result = RunCmd("/bin/sh -c 'exit 1'")
    self.assertEqual(result.exit_code, 1)
    self.assertEqual(result.output, "")

  def testStdout(self):
    """Test standard output"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.stdout, self.magic)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testStderr(self):
    """Test standard error"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
    self.assertEqual(result.stderr, self.magic)
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testCombined(self):
    """Test combined output"""
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
    expected = "A" + self.magic + "B" + self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.output, expected)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, expected)

  def testSignal(self):
    """Test signal"""
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
    self.assertEqual(result.signal, 15)
    self.assertEqual(result.output, "")

  def testListRun(self):
    """Test list runs"""
    result = RunCmd(["true"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 1)
    result = RunCmd(["echo", "-n", self.magic])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.stdout, self.magic)

  def testFileEmptyOutput(self):
    """Test file output"""
    result = RunCmd(["true"], output=self.fname)
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertFileContent(self.fname, "")

  def testLang(self):
    """Test locale environment"""
    # Remember only the variables modified below; they are restored through
    # os.environ itself, which also updates the real process environment.
    # Rebinding os.environ to a plain dict copy would leave the putenv()
    # changes in place and detach os.environ from the process environment.
    saved = {}
    for name in ("LANG", "LC_ALL"):
      saved[name] = os.environ.get(name)
    try:
      os.environ["LANG"] = "en_US.UTF-8"
      os.environ["LC_ALL"] = "en_US.UTF-8"
      result = RunCmd(["locale"])
      for line in result.output.splitlines():
        key, value = line.split("=", 1)
        # Ignore these variables, they're overridden by LC_ALL
        if key == "LANG" or key == "LANGUAGE":
          continue
        self.failIf(value and value != "C" and value != '"C"',
            "Variable %s is set to the invalid value '%s'" % (key, value))
    finally:
      for name, value in saved.items():
        if value is not None:
          os.environ[name] = value
        elif name in os.environ:
          del os.environ[name]

  def testDefaultCwd(self):
    """Test default working directory"""
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")

  def testCwd(self):
    """Test setting an explicit working directory"""
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
    cwd = os.getcwd()
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)

  def testResetEnv(self):
    """Test environment reset functionality"""
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
244
245
class TestRunParts(unittest.TestCase):
  """Testing case for the RunParts function"""

  def setUp(self):
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")

  def tearDown(self):
    shutil.rmtree(self.rundir)

  def _MakeFile(self, name, data, executable):
    """Creates a file inside the run directory.

    @param name: file name, relative to the run directory
    @param data: contents to write
    @param executable: if true, set the mode to owner read+execute
    @return: the full path of the created file

    """
    path = os.path.join(self.rundir, name)
    utils.WriteFile(path, data=data)
    if executable:
      os.chmod(path, stat.S_IREAD | stat.S_IEXEC)
    return path

  def _Run(self):
    """Calls RunParts on the run directory with a reset environment"""
    return RunParts(self.rundir, reset_env=True)

  def testEmpty(self):
    """Test on an empty dir"""
    self.failUnlessEqual(self._Run(), [])

  def testSkipWrongName(self):
    """Test that wrong files are skipped"""
    path = self._MakeFile("00test.dot", "", True)
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.failUnlessEqual(self._Run(), expected)

  def testSkipNonExec(self):
    """Test that non executable files are skipped"""
    path = self._MakeFile("00test", "", False)
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.failUnlessEqual(self._Run(), expected)

  def testError(self):
    """Test error on a broken executable"""
    path = self._MakeFile("00test", "", True)
    (relname, status, error) = self._Run()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
    self.failUnless(error)

  def testSorted(self):
    """Test executions are sorted"""
    names = ["64test", "00test", "42test"]
    for name in names:
      self._MakeFile(name, "", False)

    results = self._Run()

    for expected_name in sorted(names):
      self.failUnlessEqual(expected_name, results.pop(0)[0])

  def testOk(self):
    """Test correct execution"""
    path = self._MakeFile("00test", "#!/bin/sh\n\necho -n ciao", True)
    (relname, status, runresult) = self._Run()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.stdout, "ciao")

  def testRunFail(self):
    """Test correct execution, with run failure"""
    path = self._MakeFile("00test", "#!/bin/sh\n\nexit 1", True)
    (relname, status, runresult) = self._Run()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.exit_code, 1)
    self.failUnless(runresult.failed)

  def testRunMix(self):
    """Test a mix of failing, skipped, broken and working scripts"""
    names = ["00test", "42test", "64test", "99test"]

    # 1st has errors in execution
    self._MakeFile(names[0], "#!/bin/sh\n\nexit 1", True)
    # 2nd is skipped
    self._MakeFile(names[1], "", False)
    # 3rd cannot execute properly
    self._MakeFile(names[2], "", True)
    # 4th execs
    self._MakeFile(names[3], "#!/bin/sh\n\necho -n ciao", True)

    results = self._Run()

    (relname, status, runresult) = results[0]
    self.failUnlessEqual(relname, names[0])
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.exit_code, 1)
    self.failUnless(runresult.failed)

    (relname, status, runresult) = results[1]
    self.failUnlessEqual(relname, names[1])
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
    self.failUnlessEqual(runresult, None)

    (relname, status, runresult) = results[2]
    self.failUnlessEqual(relname, names[2])
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
    self.failUnless(runresult)

    (relname, status, runresult) = results[3]
    self.failUnlessEqual(relname, names[3])
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.output, "ciao")
    self.failUnlessEqual(runresult.exit_code, 0)
    self.failUnless(not runresult.failed)
370
371
class TestRemoveFile(unittest.TestCase):
  """Test case for the RemoveFile function"""

  def setUp(self):
    """Create a temp dir and file for each case"""
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
    os.close(fd)

  def tearDown(self):
    """Remove the temp dir, including anything a failed test left behind"""
    shutil.rmtree(self.tmpdir)

  def testIgnoreDirs(self):
    """Test that RemoveFile() ignores directories"""
    self.assertEqual(None, RemoveFile(self.tmpdir))

  def testIgnoreNotExisting(self):
    """Test that RemoveFile() ignores non-existing files"""
    RemoveFile(self.tmpfile)
    RemoveFile(self.tmpfile)

  def testRemoveFile(self):
    """Test that RemoveFile does remove a file"""
    RemoveFile(self.tmpfile)
    if os.path.exists(self.tmpfile):
      self.fail("File '%s' not removed" % self.tmpfile)

  def testRemoveSymlink(self):
    """Test that RemoveFile does remove symlinks"""
    symlink = self.tmpdir + "/symlink"
    # os.path.exists() follows the link and is always false for a dangling
    # symlink; os.path.lexists() checks for the link itself
    os.symlink("no-such-file", symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
    os.symlink(self.tmpfile, symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
416
417
class TestRename(unittest.TestCase):
  """Test case for RenameFile"""

  def setUp(self):
    """Create a temporary directory holding one empty file"""
    self.tmpdir = tempfile.mkdtemp()
    self.tmpfile = os.path.join(self.tmpdir, "test1")
    # Touch the file
    handle = open(self.tmpfile, "w")
    handle.close()

  def tearDown(self):
    """Remove temporary directory"""
    shutil.rmtree(self.tmpdir)

  def _InTmp(self, relpath):
    """Returns relpath joined onto the temporary directory"""
    return os.path.join(self.tmpdir, relpath)

  def testSimpleRename1(self):
    """Simple rename 1"""
    target = self._InTmp("xyz")
    utils.RenameFile(self.tmpfile, target)
    self.assert_(os.path.isfile(target))

  def testSimpleRename2(self):
    """Simple rename 2"""
    target = self._InTmp("xyz")
    utils.RenameFile(self.tmpfile, target, mkdir=True)
    self.assert_(os.path.isfile(target))

  def testRenameMkdir(self):
    """Rename with mkdir"""
    first = self._InTmp("test/xyz")
    utils.RenameFile(self.tmpfile, first, mkdir=True)
    self.assert_(os.path.isdir(self._InTmp("test")))
    self.assert_(os.path.isfile(first))

    second = self._InTmp("test/foo/bar/baz")
    utils.RenameFile(first, second, mkdir=True)
    self.assert_(os.path.isdir(self._InTmp("test")))
    self.assert_(os.path.isdir(self._InTmp("test/foo/bar")))
    self.assert_(os.path.isfile(second))
457
458
class TestMatchNameComponent(unittest.TestCase):
  """Test case for the MatchNameComponent function"""

  def testEmptyList(self):
    """Test that there is no match against an empty list"""
    for key in ("", "test"):
      self.failUnlessEqual(MatchNameComponent(key, []), None)

  def testSingleMatch(self):
    """Test that a single match is performed correctly"""
    names = ["test1.example.com", "test2.example.com", "test3.example.com"]
    for key in ("test2", "test2.example", "test2.example.com"):
      self.failUnlessEqual(MatchNameComponent(key, names), names[1])

  def testMultipleMatches(self):
    """Test that a multiple match is returned as None"""
    names = ["test1.example.com", "test1.example.org", "test1.example.net"]
    for key in ("test1", "test1.example"):
      self.failUnlessEqual(MatchNameComponent(key, names), None)

  def testFullMatch(self):
    """Test that a full match is returned correctly"""
    short_key = "test1"
    long_key = "test1.example"
    names = [long_key, long_key + ".com"]
    self.failUnlessEqual(MatchNameComponent(short_key, names), None)
    self.failUnlessEqual(MatchNameComponent(long_key, names), long_key)

  def testCaseInsensitivePartialMatch(self):
    """Test for the case_insensitive keyword"""
    names = ["test1.example.com", "test2.example.net"]
    for key in ("test2", "Test2", "teSt2", "TeSt2"):
      self.assertEqual(MatchNameComponent(key, names, case_sensitive=False),
                       "test2.example.net")

  def testCaseInsensitiveFullMatch(self):
    """Test case-insensitive lookups where candidates collide"""
    names = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
    cases = [
      # "Ts1" is a prefix of both ts1 names, but the full name "ts1.ex"
      # (in any case) still selects a single entry
      ("Ts1", None),
      ("Ts1.ex", "ts1.ex"),
      ("ts1.ex", "ts1.ex"),
      # the two ts2 names differ only in case, so only an exact-case
      # match is unambiguous
      ("ts2.ex", "ts2.ex"),
      ("Ts2.ex", "Ts2.ex"),
      ("TS2.ex", None),
      ]
    for key, expected in cases:
      self.assertEqual(MatchNameComponent(key, names, case_sensitive=False),
                       expected)
517
518
class TestTimestampForFilename(unittest.TestCase):
  """Test case for TimestampForFilename"""

  def test(self):
    """Generated timestamps must contain neither '.' nor ':'"""
    for forbidden in (".", ":"):
      self.assert_(forbidden not in utils.TimestampForFilename())
523
524
class TestCreateBackup(testutils.GanetiTestCase):
  """Test case for the CreateBackup function"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def _CountMatching(self, filename):
    """Returns the number of paths starting with filename"""
    return len(glob.glob("%s*" % filename))

  def testEmpty(self):
    """Back up an empty file repeatedly; refuse to back up a FIFO"""
    filename = utils.PathJoin(self.tmpdir, "config.data")
    utils.WriteFile(filename, data="")
    bname = utils.CreateBackup(filename)
    self.assertFileContent(bname, "")
    self.assertEqual(self._CountMatching(filename), 2)
    for expected_count in (3, 4):
      utils.CreateBackup(filename)
      self.assertEqual(self._CountMatching(filename), expected_count)

    fifoname = utils.PathJoin(self.tmpdir, "fifo")
    os.mkfifo(fifoname)
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)

  def testContent(self):
    """Every backup must have exactly the content of the source file"""
    filename = utils.PathJoin(self.tmpdir, "test.data_")
    backups = 0
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
      for rep in [1, 2, 10, 127]:
        testdata = data * rep

        utils.WriteFile(filename, data=testdata)
        self.assertFileContent(filename, testdata)

        for _ in range(3):
          bname = utils.CreateBackup(filename)
          backups += 1
          self.assertFileContent(bname, testdata)
          self.assertEqual(self._CountMatching(filename), 1 + backups)
566
567
class TestFormatUnit(unittest.TestCase):
  """Test case for the FormatUnit function"""

  def _CheckAll(self, unit, cases):
    """Asserts FormatUnit(value, unit) == expected for each pair in cases"""
    for value, expected in cases:
      self.assertEqual(FormatUnit(value, unit), expected)

  def testMiB(self):
    self._CheckAll('h', [(1, '1M'), (100, '100M'), (1023, '1023M')])
    self._CheckAll('m', [(1, '1'), (100, '100'), (1023, '1023'),
                         (1024, '1024'), (1536, '1536'), (17133, '17133'),
                         (1024 * 1024 - 1, '1048575')])

  def testGiB(self):
    self._CheckAll('h', [(1024, '1.0G'), (1536, '1.5G'), (17133, '16.7G'),
                         (1024 * 1024 - 1, '1024.0G')])
    self._CheckAll('g', [(1024, '1.0'), (1536, '1.5'), (17133, '16.7'),
                         (1024 * 1024 - 1, '1024.0'),
                         (1024 * 1024, '1024.0'), (5120 * 1024, '5120.0'),
                         (29829 * 1024, '29829.0')])

  def testTiB(self):
    self._CheckAll('h', [(1024 * 1024, '1.0T'), (5120 * 1024, '5.0T'),
                         (29829 * 1024, '29.1T')])
    self._CheckAll('t', [(1024 * 1024, '1.0'), (5120 * 1024, '5.0'),
                         (29829 * 1024, '29.1')])
608
class TestParseUnit(unittest.TestCase):
  """Test case for the ParseUnit function"""

  SCALES = (('', 1),
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))

  def _CheckAll(self, cases):
    """Asserts ParseUnit(text) == expected for each pair in cases"""
    for text, expected in cases:
      self.assertEqual(ParseUnit(text), expected)

  def testRounding(self):
    """Results are rounded up to the next multiple of 4"""
    self._CheckAll([('0', 0), ('1', 4), ('2', 4), ('3', 4),
                    ('124', 124), ('125', 128), ('126', 128), ('127', 128),
                    ('128', 128), ('129', 132), ('130', 132)])

  def testFloating(self):
    """Fractional values, with and without suffixes"""
    self._CheckAll([('0', 0), ('0.5', 4), ('1.75', 4), ('1.99', 4),
                    ('2.00', 4), ('2.01', 4), ('3.99', 4), ('4.00', 4),
                    ('4.01', 8), ('1.5G', 1536), ('1.8G', 1844),
                    ('8.28T', 8682212)])

  def testSuffixes(self):
    """Suffixes in any case, after any whitespace separator"""
    transforms = (lambda x: x, str.lower, str.upper)
    for sep in ('', ' ', '   ', "\t", "\t "):
      for suffix, scale in self.SCALES:
        for transform in transforms:
          self.assertEqual(ParseUnit('1024' + sep + transform(suffix)),
                           1024 * scale)

  def testInvalidInput(self):
    """Bad separators and decimal commas are rejected"""
    for sep in ('-', '_', ',', 'a'):
      for suffix, _ in self.SCALES:
        self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)

    for suffix, _ in self.SCALES:
      self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
659
660
class TestSshKeys(testutils.GanetiTestCase):
  """Test case for the AddAuthorizedKey and RemoveAuthorizedKey functions"""

  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    handle = open(self.tmpname, 'w')
    try:
      handle.write("%s\n%s\n" % (self.KEY_A, self.KEY_B))
    finally:
      handle.close()

  def _CheckKeys(self, *keys):
    """Asserts the key file holds exactly the given lines, in order"""
    expected = "".join(["%s\n" % key for key in keys])
    self.assertFileContent(self.tmpname, expected)

  def testAddingNewKey(self):
    AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
    self._CheckKeys(self.KEY_A, self.KEY_B,
                    "ssh-dss AAAAB3NzaC1kc3MAAACB root@test")

  def testAddingAlmostButNotCompletelyTheSameKey(self):
    AddAuthorizedKey(self.tmpname,
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
    self._CheckKeys(self.KEY_A, self.KEY_B,
                    "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test")

  def testAddingExistingKeyWithSomeMoreSpaces(self):
    AddAuthorizedKey(self.tmpname,
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
    self._CheckKeys(self.KEY_A, self.KEY_B)

  def testRemovingExistingKeyWithSomeMoreSpaces(self):
    RemoveAuthorizedKey(self.tmpname,
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
    self._CheckKeys(self.KEY_B)

  def testRemovingNonExistingKey(self):
    RemoveAuthorizedKey(self.tmpname,
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
    self._CheckKeys(self.KEY_A, self.KEY_B)
722
723
724 class TestEtcHosts(testutils.GanetiTestCase):
725   """Test functions modifying /etc/hosts"""
726
727   def setUp(self):
728     testutils.GanetiTestCase.setUp(self)
729     self.tmpname = self._CreateTempFile()
730     handle = open(self.tmpname, 'w')
731     try:
732       handle.write('# This is a test file for /etc/hosts\n')
733       handle.write('127.0.0.1\tlocalhost\n')
734       handle.write('192.168.1.1 router gw\n')
735     finally:
736       handle.close()
737
738   def testSettingNewIp(self):
739     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
740
741     self.assertFileContent(self.tmpname,
742       "# This is a test file for /etc/hosts\n"
743       "127.0.0.1\tlocalhost\n"
744       "192.168.1.1 router gw\n"
745       "1.2.3.4\tmyhost.domain.tld myhost\n")
746     self.assertFileMode(self.tmpname, 0644)
747
748   def testSettingExistingIp(self):
749     SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
750                      ['myhost'])
751
752     self.assertFileContent(self.tmpname,
753       "# This is a test file for /etc/hosts\n"
754       "127.0.0.1\tlocalhost\n"
755       "192.168.1.1\tmyhost.domain.tld myhost\n")
756     self.assertFileMode(self.tmpname, 0644)
757
758   def testSettingDuplicateName(self):
759     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
760
761     self.assertFileContent(self.tmpname,
762       "# This is a test file for /etc/hosts\n"
763       "127.0.0.1\tlocalhost\n"
764       "192.168.1.1 router gw\n"
765       "1.2.3.4\tmyhost\n")
766     self.assertFileMode(self.tmpname, 0644)
767
768   def testRemovingExistingHost(self):
769     RemoveEtcHostsEntry(self.tmpname, 'router')
770
771     self.assertFileContent(self.tmpname,
772       "# This is a test file for /etc/hosts\n"
773       "127.0.0.1\tlocalhost\n"
774       "192.168.1.1 gw\n")
775     self.assertFileMode(self.tmpname, 0644)
776
777   def testRemovingSingleExistingHost(self):
778     RemoveEtcHostsEntry(self.tmpname, 'localhost')
779
780     self.assertFileContent(self.tmpname,
781       "# This is a test file for /etc/hosts\n"
782       "192.168.1.1 router gw\n")
783     self.assertFileMode(self.tmpname, 0644)
784
785   def testRemovingNonExistingHost(self):
786     RemoveEtcHostsEntry(self.tmpname, 'myhost')
787
788     self.assertFileContent(self.tmpname,
789       "# This is a test file for /etc/hosts\n"
790       "127.0.0.1\tlocalhost\n"
791       "192.168.1.1 router gw\n")
792     self.assertFileMode(self.tmpname, 0644)
793
794   def testRemovingAlias(self):
795     RemoveEtcHostsEntry(self.tmpname, 'gw')
796
797     self.assertFileContent(self.tmpname,
798       "# This is a test file for /etc/hosts\n"
799       "127.0.0.1\tlocalhost\n"
800       "192.168.1.1 router\n")
801     self.assertFileMode(self.tmpname, 0644)
802
803
class TestShellQuoting(unittest.TestCase):
  """Test case for shell quoting functions"""

  def testShellQuote(self):
    """Single arguments are quoted only when they need it"""
    cases = [
      ('abc', "abc"),
      ('ab"c', "'ab\"c'"),
      ("a'bc", "'a'\\''bc'"),
      ("a b c", "'a b c'"),
      ("a b\\ c", "'a b\\ c'"),
      ]
    for raw, quoted in cases:
      self.assertEqual(ShellQuote(raw), quoted)

  def testShellQuoteArgs(self):
    """Argument lists are quoted element-wise and joined with spaces"""
    cases = [
      (['a', 'b', 'c'], "a b c"),
      (['a', 'b"', 'c'], "a 'b\"' c"),
      (['a', 'b\'', 'c'], "a 'b'\\\''' c"),
      ]
    for args, quoted in cases:
      self.assertEqual(ShellQuoteArgs(args), quoted)
818
819
class TestTcpPing(unittest.TestCase):
  """Testcase for TCP version of ping - against listen(2)ing port"""

  def setUp(self):
    # Bind to an ephemeral port chosen by the kernel and remember it
    self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
    self.listenerport = self.listener.getsockname()[1]
    self.listener.listen(1)

  def tearDown(self):
    # Close the socket explicitly instead of relying on garbage collection.
    # shutdown(2) is not used: POSIX allows it to fail with ENOTCONN on a
    # socket that is not connected, such as this listening one.
    self.listener.close()
    del self.listener
    del self.listenerport

  def testTcpPingToLocalHostAccept(self):
    """A listening port must be reported alive, with and without source."""
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.listenerport,
                         timeout=10,
                         live_port_needed=True,
                         source=constants.LOCALHOST_IP_ADDRESS,
                         ),
                 "failed to connect to test listener")

    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.listenerport,
                         timeout=10,
                         live_port_needed=True,
                         ),
                 "failed to connect to test listener (no source)")
849
850
class TestTcpPingDeaf(unittest.TestCase):
  """Testcase for TCP version of ping - against non listen(2)ing port"""

  def setUp(self):
    # Bound but never listen(2)ing, so connection attempts are refused
    self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
    self.deaflistenerport = self.deaflistener.getsockname()[1]

  def tearDown(self):
    # Close the socket explicitly instead of relying on garbage collection
    self.deaflistener.close()
    del self.deaflistener
    del self.deaflistenerport

  def testTcpPingToLocalHostAcceptDeaf(self):
    """With live_port_needed=True a refused connection counts as failure."""
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                        self.deaflistenerport,
                        timeout=constants.TCP_PING_TIMEOUT,
                        live_port_needed=True,
                        source=constants.LOCALHOST_IP_ADDRESS,
                        ), # need successful connect(2)
                "successfully connected to deaf listener")

    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                        self.deaflistenerport,
                        timeout=constants.TCP_PING_TIMEOUT,
                        live_port_needed=True,
                        ), # need successful connect(2)
                "successfully connected to deaf listener (no source addr)")

  def testTcpPingToLocalHostNoAccept(self):
    """With live_port_needed=False a refused connection means host is up."""
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         source=constants.LOCALHOST_IP_ADDRESS,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port")

    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port (no source addr)")
894
895
class TestOwnIpAddress(unittest.TestCase):
  """Testcase for OwnIpAddress"""

  def testOwnLoopback(self):
    """The loopback address must be reported as our own"""
    loopback = constants.LOCALHOST_IP_ADDRESS
    self.failUnless(OwnIpAddress(loopback), "Should own the loopback address")

  def testNowOwnAddress(self):
    """An address from a reserved documentation range must not be ours"""
    # 192.0.2.0/24 is reserved for test/documentation use (RFC 3330), so no
    # host should have an address from it; if this ever fails, extend the
    # test to try several addresses
    dst_ip = "192.0.2.1"
    owned = OwnIpAddress(dst_ip)
    self.failIf(owned, "Should not own IP address %s" % dst_ip)
912
913
class TestListVisibleFiles(unittest.TestCase):
  """Test case for ListVisibleFiles"""

  def setUp(self):
    self.path = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.path)

  def _test(self, files, expected):
    """Create C{files} in the test directory and check the listing.

    The comparison ignores order: both lists are sorted first.

    """
    wanted = sorted(expected)

    for name in files:
      handle = open(os.path.join(self.path, name), 'w')
      try:
        handle.write("Test\n")
      finally:
        handle.close()

    found = ListVisibleFiles(self.path)
    found.sort()

    self.assertEqual(found, wanted)

  def testAllVisible(self):
    names = ["a", "b", "c"]
    self._test(names, names)

  def testNoneVisible(self):
    # Dot-files are hidden
    self._test([".a", ".b", ".c"], [])

  def testSomeVisible(self):
    self._test(["a", "b", ".c"], ["a", "b"])

  def testNonAbsolutePath(self):
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")

  def testNonNormalizedPath(self):
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
                          "/bin/../tmp")
961
962
class TestNewUUID(unittest.TestCase):
  """Test case for NewUUID"""

  # Lowercase hex digits grouped 8-4-4-4-12, separated by dashes
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
                        '[a-f0-9]{4}-[a-f0-9]{12}$')

  def runTest(self):
    generated = utils.NewUUID()
    match = self._re_uuid.match(generated)
    self.failUnless(match)
971
972
class TestUniqueSequence(unittest.TestCase):
  """Test case for UniqueSequence"""

  def _test(self, seq, expected):
    """Assert that UniqueSequence(seq) equals C{expected}."""
    self.assertEqual(utils.UniqueSequence(seq), expected)

  def runTest(self):
    # Duplicates are dropped and the first occurrence order is kept
    cases = [
      # Ordered input
      ([1, 2, 3], [1, 2, 3]),
      ([1, 1, 2, 2, 3, 3], [1, 2, 3]),
      ([1, 2, 2, 3], [1, 2, 3]),
      ([1, 2, 3, 3], [1, 2, 3]),

      # Unordered input
      ([1, 2, 3, 1, 2, 3], [1, 2, 3]),
      ([1, 1, 2, 3, 3, 1, 2], [1, 2, 3]),

      # Strings
      (["a", "a"], ["a"]),
      (["a", "b"], ["a", "b"]),
      (["a", "b", "a"], ["a", "b"]),
      ]
    for (seq, expected) in cases:
      self._test(seq, expected)
994
995
class TestFirstFree(unittest.TestCase):
  """Test case for the FirstFree function"""

  def test(self):
    """Test FirstFree"""
    # A gap inside the sequence is reported
    self.assertEqual(FirstFree([0, 1, 3]), 2)
    # An empty sequence yields None
    self.assertEqual(FirstFree([]), None)
    # Without an explicit base, counting starts at 0
    self.assertEqual(FirstFree([3, 4, 6]), 0)
    self.assertEqual(FirstFree([3, 4, 6], base=3), 5)
    # A sequence containing values below the base is rejected
    self.assertRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1006
1007
class TestTailFile(testutils.GanetiTestCase):
  """Test case for the TailFile function"""

  def _WriteTempFile(self, data):
    """Create a temporary file containing C{data} and return its name.

    The file object is closed even if writing fails, so a failing test
    does not leak a file descriptor.

    """
    fname = self._CreateTempFile()
    fd = open(fname, "w")
    try:
      fd.write(data)
    finally:
      fd.close()
    return fname

  def testEmpty(self):
    fname = self._CreateTempFile()
    self.failUnlessEqual(TailFile(fname), [])
    self.failUnlessEqual(TailFile(fname, lines=25), [])

  def testAllLines(self):
    data = ["test %d" % i for i in range(30)]
    for i in range(30):
      content = "\n".join(data[:i])
      if i > 0:
        # Terminate the last line, except for the empty file
        content += "\n"
      fname = self._WriteTempFile(content)
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])

  def testPartialLines(self):
    data = ["test %d" % i for i in range(30)]
    fname = self._WriteTempFile("\n".join(data) + "\n")
    for i in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])

  def testBigFile(self):
    data = ["test %d" % i for i in range(30)]
    # A 1 MiB leading line to exercise reading from the end of a large file
    fname = self._WriteTempFile("X" * 1048576 + "\n" +
                                "\n".join(data) + "\n")
    for i in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1048
1049
class _BaseFileLockTest:
  """Shared test cases for the FileLock class.

  This is a mixin: concrete TestCase subclasses must set C{self.lock} (the
  L{utils.FileLock} under test) and C{self.tmpfile} (whose C{name} is the
  locked file) in their C{setUp}.

  """

  def testSharedNonblocking(self):
    self.lock.Shared(blocking=False)
    self.lock.Close()

  def testExclusiveNonblocking(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Close()

  def testUnlockNonblocking(self):
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSharedBlocking(self):
    self.lock.Shared(blocking=True)
    self.lock.Close()

  def testExclusiveBlocking(self):
    self.lock.Exclusive(blocking=True)
    self.lock.Close()

  def testUnlockBlocking(self):
    self.lock.Unlock(blocking=True)
    self.lock.Close()

  def testSharedExclusiveUnlock(self):
    self.lock.Shared(blocking=False)
    self.lock.Exclusive(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testExclusiveSharedUnlock(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Shared(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSimpleTimeout(self):
    # These will succeed on the first attempt, hence a short timeout
    self.lock.Shared(blocking=True, timeout=10.0)
    self.lock.Exclusive(blocking=False, timeout=10.0)
    self.lock.Unlock(blocking=True, timeout=10.0)
    self.lock.Close()

  @staticmethod
  def _TryLockInner(filename, shared, blocking):
    """Try to lock C{filename}; executed in a child process by L{_TryLock}.

    @param filename: path of the file to lock
    @param shared: whether to request a shared (True) or exclusive lock
    @param blocking: passed through to the locking call
    @return: True if the lock was acquired, False on L{errors.LockError}

    """
    lock = utils.FileLock.Open(filename)

    if shared:
      fn = lock.Shared
    else:
      fn = lock.Exclusive

    try:
      # The timeout doesn't really matter as the parent process waits for us to
      # finish anyway.
      fn(blocking=blocking, timeout=0.01)
    except errors.LockError:
      return False

    return True

  def _TryLock(self, *args):
    """Run L{_TryLockInner} on the test file in a separate process."""
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
                                      *args)

  def testTimeout(self):
    for blocking in [True, False]:
      # While we hold the lock exclusively, no other process gets any lock
      self.lock.Exclusive(blocking=True)
      self.failIf(self._TryLock(False, blocking))
      self.failIf(self._TryLock(True, blocking))

      # While we hold it shared, others may share it but not own it
      self.lock.Shared(blocking=True)
      self.assert_(self._TryLock(True, blocking))
      self.failIf(self._TryLock(False, blocking))

  def testCloseShared(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)

  def testCloseExclusive(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)

  def testCloseUnlock(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1139
1140
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
  """FileLock tests using a lock created via L{utils.FileLock.Open}."""
  TESTDATA = "Hello World\n" * 10

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpfile = tempfile.NamedTemporaryFile()
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
    self.lock = utils.FileLock.Open(self.tmpfile.name)

    # Ensure "Open" didn't truncate file
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)

  def tearDown(self):
    try:
      # None of the lock operations may have modified the file's contents
      self.assertFileContent(self.tmpfile.name, self.TESTDATA)
    finally:
      # Close (and thereby delete) the temporary file explicitly instead of
      # relying on garbage collection, and always run the base cleanup
      self.tmpfile.close()
      testutils.GanetiTestCase.tearDown(self)
1158
1159
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
  """FileLock tests using a lock wrapped around an already-open file."""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)

  def tearDown(self):
    # Close (and thereby delete) the temporary file explicitly instead of
    # relying on garbage collection
    self.tmpfile.close()
1164
1165
class TestTimeFunctions(unittest.TestCase):
  """Test case for time functions"""

  def runTest(self):
    # SplitTime: float seconds -> (seconds, microseconds); note that the
    # microsecond part is truncated, never rounded up to the next second
    split_cases = [
      (1, (1, 0)),
      (1.5, (1, 500000)),
      (1218448917.4809151, (1218448917, 480915)),
      (123.48012, (123, 480120)),
      (123.9996, (123, 999600)),
      (123.9995, (123, 999500)),
      (123.9994, (123, 999400)),
      (123.999999999, (123, 999999)),
      ]
    for (value, expected) in split_cases:
      self.assertEqual(utils.SplitTime(value), expected)

    self.assertRaises(AssertionError, utils.SplitTime, -1)

    # MergeTime: (seconds, microseconds) -> float seconds
    merge_cases = [
      ((1, 0), 1.0),
      ((1, 500000), 1.5),
      ((1218448917, 500000), 1218448917.5),
      ]
    for (pair, expected) in merge_cases:
      self.assertEqual(utils.MergeTime(pair), expected)

    # Values not exactly representable as floats are compared after rounding
    rounded_cases = [
      ((1218448917, 481000), 1218448917.481),
      ((1, 801000), 1.801),
      ]
    for (pair, expected) in rounded_cases:
      self.assertEqual(round(utils.MergeTime(pair), 3), expected)

    # Negative values and out-of-range microseconds are rejected
    for pair in [(0, -1), (0, 1000000), (0, 9999999), (-1, 0), (-9999, 0)]:
      self.assertRaises(AssertionError, utils.MergeTime, pair)
1194
1195
class FieldSetTestCase(unittest.TestCase):
  """Test case for FieldSets"""

  def testSimpleMatch(self):
    fields = utils.FieldSet("a", "b", "c", "def")
    self.failUnless(fields.Matches("a"))
    # Names must match a field completely, not as substring or prefix
    self.failIf(fields.Matches("d"), "Substring matched")
    self.failIf(fields.Matches("defghi"), "Prefix string matched")
    self.failIf(fields.NonMatching(["b", "c"]))
    self.failIf(fields.NonMatching(["a", "b", "c", "def"]))
    self.failUnless(fields.NonMatching(["a", "d"]))

  def testRegexMatch(self):
    # Field definitions may be regular expressions
    fields = utils.FieldSet("a", "b([0-9]+)", "c")
    self.failUnless(fields.Matches("b1"))
    self.failUnless(fields.Matches("b99"))
    self.failIf(fields.Matches("b/1"))
    self.failIf(fields.NonMatching(["b12", "c"]))
    self.failUnless(fields.NonMatching(["a", "1"]))
1215
class TestForceDictType(unittest.TestCase):
  """Test case for ForceDictType"""

  def setUp(self):
    self.key_types = {
      'a': constants.VTYPE_INT,
      'b': constants.VTYPE_BOOL,
      'c': constants.VTYPE_STRING,
      'd': constants.VTYPE_SIZE,
      }

  def _fdt(self, data, allowed_values=None):
    """Run ForceDictType on C{data} (modified in place) and return it."""
    extra = {}
    if allowed_values is not None:
      extra["allowed_values"] = allowed_values
    ForceDictType(data, self.key_types, **extra)
    return data

  def testSimpleDict(self):
    cases = [
      ({}, {}),
      ({'a': 1}, {'a': 1}),
      ({'a': '1'}, {'a': 1}),
      ({'a': 1, 'b': 1}, {'a':1, 'b': True}),
      ({'b': 1, 'c': 'foo'}, {'b': True, 'c': 'foo'}),
      # A false value for a string key becomes the empty string
      ({'b': 1, 'c': False}, {'b': True, 'c': ''}),
      ({'b': 'false'}, {'b': False}),
      ({'b': 'False'}, {'b': False}),
      ({'b': 'true'}, {'b': True}),
      ({'b': 'True'}, {'b': True}),
      # Size values may carry a unit suffix
      ({'d': '4'}, {'d': 4}),
      ({'d': '4M'}, {'d': 4}),
      ]
    for (data, expected) in cases:
      self.assertEqual(self._fdt(data), expected)

  def testErrors(self):
    bad_inputs = [
      {'a': 'astring'},
      {'c': True},
      {'d': 'astring'},
      {'d': '4 L'},
      ]
    for data in bad_inputs:
      self.assertRaises(errors.TypeEnforcementError, self._fdt, data)
1254
1255
class TestIsAbsNormPath(unittest.TestCase):
  """Testing case for IsNormAbsPath"""

  def _pathTestHelper(self, path, result):
    """Check that IsNormAbsPath(path) evaluates to C{result}.

    @param path: the path to check
    @param result: whether C{path} should be considered both absolute and
        normalized

    """
    if result:
      self.assert_(IsNormAbsPath(path),
          "Path %s should result absolute and normalized" % path)
    else:
      self.assert_(not IsNormAbsPath(path),
          "Path %s should not result absolute and normalized" % path)

  def testBase(self):
    self._pathTestHelper('/etc', True)
    self._pathTestHelper('/srv', True)
    # Relative path
    self._pathTestHelper('etc', False)
    # Contains a ".." component
    self._pathTestHelper('/etc/../root', False)
    # A trailing slash is not the normalized form
    self._pathTestHelper('/etc/', False)
1273
1274
class TestSafeEncode(unittest.TestCase):
  """Test case for SafeEncode"""

  def testAscii(self):
    # Printable ASCII text passes through unchanged
    for txt in [string.digits, string.letters, string.punctuation]:
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testDoubleEncode(self):
    # Encoding is idempotent for every possible byte value (0-255 inclusive)
    for i in range(256):
      txt = SafeEncode(chr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testUnicode(self):
    # 1024 is high enough to catch non-direct ASCII mappings
    for i in range(1024):
      txt = SafeEncode(unichr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))
1292
1293
class TestFormatTime(unittest.TestCase):
  """Testing case for FormatTime"""

  def testNone(self):
    self.assertEqual(FormatTime(None), "N/A")

  def testInvalid(self):
    # An empty tuple is not a usable timestamp
    self.assertEqual(FormatTime(()), "N/A")

  def testNow(self):
    # Both float timestamps (as from time.time) and integers are accepted
    float_now = time.time()
    FormatTime(float_now)
    int_now = int(time.time())
    FormatTime(int_now)
1308
1309
class RunInSeparateProcess(unittest.TestCase):
  """Test case for utils.RunInSeparateProcess"""

  def test(self):
    # The child's boolean result is reported back to the parent
    for expected in [True, False]:
      def _child():
        return expected

      result = utils.RunInSeparateProcess(_child)
      self.assertEqual(expected, result)

  def testArgs(self):
    # Extra positional arguments are passed on to the child function
    for value in [0, 1, 999, "Hello World", (1, 2, 3)]:
      def _child(first, second):
        return first == "Foo" and second == value

      self.assert_(utils.RunInSeparateProcess(_child, "Foo", value))

  def testPid(self):
    # The function must not run in the calling process
    my_pid = os.getpid()

    def _check():
      return os.getpid() == my_pid

    self.failIf(utils.RunInSeparateProcess(_check))

  def testSignal(self):
    # A child terminated by a signal is reported as GenericError
    def _kill():
      os.kill(os.getpid(), signal.SIGTERM)

    self.assertRaises(errors.GenericError, utils.RunInSeparateProcess, _kill)

  def testException(self):
    # An exception raised in the child surfaces as GenericError
    def _exc():
      raise errors.GenericError("This is a test")

    self.assertRaises(errors.GenericError, utils.RunInSeparateProcess, _exc)
1346
1347
class TestFingerprintFile(unittest.TestCase):
  """Test case for utils._FingerprintFile"""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()

  def tearDown(self):
    # Close (and thereby delete) the temporary file explicitly instead of
    # relying on garbage collection
    self.tmpfile.close()

  def test(self):
    # Empty file (this is the SHA-1 digest of empty input)
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")

    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1359
1360
class TestUnescapeAndSplit(unittest.TestCase):
  """Testing case for UnescapeAndSplit"""

  def setUp(self):
    # testing more than one separator for regexp safety ("+" and "." are
    # special characters in regular expressions)
    self._seps = [",", "+", "."]

  def testSimple(self):
    a = ["a", "b", "c", "d"]
    for sep in self._seps:
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)

  def testEscape(self):
    # A backslash before the separator makes it part of the item
    for sep in self._seps:
      a = ["a", "b\\" + sep + "c", "d"]
      b = ["a", "b" + sep + "c", "d"]
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)

  def testDoubleEscape(self):
    # Two backslashes yield one literal backslash; the following separator
    # still splits
    for sep in self._seps:
      a = ["a", "b\\\\", "c", "d"]
      b = ["a", "b\\", "c", "d"]
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)

  def testThreeEscape(self):
    # Three backslashes yield a literal backslash plus an escaped separator
    for sep in self._seps:
      a = ["a", "b\\\\\\" + sep + "c", "d"]
      b = ["a", "b\\" + sep + "c", "d"]
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1390
1391
class TestPathJoin(unittest.TestCase):
  """Testing case for PathJoin"""

  def testBasicItems(self):
    parts = ["/a", "b", "c"]
    joined = PathJoin(*parts)
    self.assertEqual(joined, "/".join(parts))

  def testNonAbsPrefix(self):
    # The first component must be absolute
    self.assertRaises(ValueError, PathJoin, "a", "b")

  def testBackTrack(self):
    # ".." components are refused
    self.assertRaises(ValueError, PathJoin, "/a", "b/../c")

  def testMultiAbs(self):
    # Only the first component may be absolute
    self.assertRaises(ValueError, PathJoin, "/a", "/b")
1407
1408
class TestHostInfo(unittest.TestCase):
  """Testing case for HostInfo"""

  def testUppercase(self):
    # Names are converted to lowercase
    name = "AbC.example.com"
    self.failUnlessEqual(HostInfo.NormalizeName(name), name.lower())

  def testTooLongName(self):
    # A 259-character name must be rejected
    name = "a.b." + "c" * 255
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, name)

  def testTrailingDot(self):
    # A single trailing dot is stripped
    name = "a.b.c"
    self.failUnlessEqual(HostInfo.NormalizeName("%s." % name), name)

  def testInvalidName(self):
    # Spaces, slashes and empty labels are not allowed
    for name in ["a b", "a/b", ".a.b", "a..b"]:
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, name)

  def testValidName(self):
    # These must be accepted without raising
    for name in ["a.b", "a-b", "a_b", "a.b.c"]:
      HostInfo.NormalizeName(name)
1443
1444
class TestParseAsn1Generalizedtime(unittest.TestCase):
  """Test case for utils._ParseAsn1Generalizedtime"""

  def test(self):
    parse = utils._ParseAsn1Generalizedtime

    # (input, expected seconds since the epoch)
    valid = [
      # UTC
      ("19700101000000Z", 0),
      ("20100222174152Z", 1266860512),
      ("20380119031407Z", (2**31) - 1),

      # With offset
      ("20100222174152+0000", 1266860512),
      ("20100223131652+0000", 1266931012),
      ("20100223051808-0800", 1266931088),
      ("20100224002135+1100", 1266931295),
      ("19700101000000-0100", 3600),
      ]
    for (value, expected) in valid:
      self.assertEqual(parse(value), expected)

    invalid = [
      # Leap seconds are not supported by datetime.datetime
      "19841231235960+0000",
      "19920630235960+0000",

      # Errors
      "",
      "invalid",
      "20100222174152",
      "Mon Feb 22 17:47:02 UTC 2010",
      "2010-02-22 17:42:02",
      ]
    for value in invalid:
      self.assertRaises(ValueError, parse, value)
1481
1482
class TestGetX509CertValidity(testutils.GanetiTestCase):
  """Test case for utils.GetX509CertValidity"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    # Below pyOpenSSL 0.7 this test expects (None, None) as the validity
    version = distutils.version.LooseVersion(OpenSSL.__version__)
    self.pyopenssl0_7 = (version >= "0.7")

    if not self.pyopenssl0_7:
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
                    " function correctly")

  def _LoadCert(self, name):
    """Load the PEM certificate C{name} from the test data directory."""
    pem = self._ReadTestData(name)
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem)

  def test(self):
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))

    if self.pyopenssl0_7:
      expected = (1266919967, 1267524767)
    else:
      expected = (None, None)

    self.assertEqual(validity, expected)
1506
1507
if __name__ == "__main__":
  # Run all test cases in this module through Ganeti's test runner
  testutils.GanetiTestProgram()