utils: Move wrapper code around os.makedirs into separate function
[ganeti-local] / test / ganeti.utils_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the utils module"""
23
24 import unittest
25 import os
26 import time
27 import tempfile
28 import os.path
29 import os
30 import stat
31 import signal
32 import socket
33 import shutil
34 import re
35 import select
36 import string
37 import OpenSSL
38 import warnings
39 import distutils.version
40 import glob
41
42 import ganeti
43 import testutils
44 from ganeti import constants
45 from ganeti import utils
46 from ganeti import errors
47 from ganeti.utils import IsProcessAlive, RunCmd, \
48      RemoveFile, MatchNameComponent, FormatUnit, \
49      ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
50      ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
51      SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
52      TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
53      UnescapeAndSplit, RunParts, PathJoin, HostInfo
54
55 from ganeti.errors import LockError, UnitParseError, GenericError, \
56      ProgrammerError, OpPrereqError
57
58
59 class TestIsProcessAlive(unittest.TestCase):
60   """Testing case for IsProcessAlive"""
61
62   def testExists(self):
63     mypid = os.getpid()
64     self.assert_(IsProcessAlive(mypid),
65                  "can't find myself running")
66
67   def testNotExisting(self):
68     pid_non_existing = os.fork()
69     if pid_non_existing == 0:
70       os._exit(0)
71     elif pid_non_existing < 0:
72       raise SystemError("can't fork")
73     os.waitpid(pid_non_existing, 0)
74     self.assert_(not IsProcessAlive(pid_non_existing),
75                  "nonexisting process detected")
76
77
78 class TestPidFileFunctions(unittest.TestCase):
79   """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
80
81   def setUp(self):
82     self.dir = tempfile.mkdtemp()
83     self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
84     utils.DaemonPidFileName = self.f_dpn
85
86   def testPidFileFunctions(self):
87     pid_file = self.f_dpn('test')
88     utils.WritePidFile('test')
89     self.failUnless(os.path.exists(pid_file),
90                     "PID file should have been created")
91     read_pid = utils.ReadPidFile(pid_file)
92     self.failUnlessEqual(read_pid, os.getpid())
93     self.failUnless(utils.IsProcessAlive(read_pid))
94     self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
95     utils.RemovePidFile('test')
96     self.failIf(os.path.exists(pid_file),
97                 "PID file should not exist anymore")
98     self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
99                          "ReadPidFile should return 0 for missing pid file")
100     fh = open(pid_file, "w")
101     fh.write("blah\n")
102     fh.close()
103     self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
104                          "ReadPidFile should return 0 for invalid pid file")
105     utils.RemovePidFile('test')
106     self.failIf(os.path.exists(pid_file),
107                 "PID file should not exist anymore")
108
109   def testKill(self):
110     pid_file = self.f_dpn('child')
111     r_fd, w_fd = os.pipe()
112     new_pid = os.fork()
113     if new_pid == 0: #child
114       utils.WritePidFile('child')
115       os.write(w_fd, 'a')
116       signal.pause()
117       os._exit(0)
118       return
119     # else we are in the parent
120     # wait until the child has written the pid file
121     os.read(r_fd, 1)
122     read_pid = utils.ReadPidFile(pid_file)
123     self.failUnlessEqual(read_pid, new_pid)
124     self.failUnless(utils.IsProcessAlive(new_pid))
125     utils.KillProcess(new_pid, waitpid=True)
126     self.failIf(utils.IsProcessAlive(new_pid))
127     utils.RemovePidFile('child')
128     self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)
129
130   def tearDown(self):
131     for name in os.listdir(self.dir):
132       os.unlink(os.path.join(self.dir, name))
133     os.rmdir(self.dir)
134
135
136 class TestRunCmd(testutils.GanetiTestCase):
137   """Testing case for the RunCmd function"""
138
139   def setUp(self):
140     testutils.GanetiTestCase.setUp(self)
141     self.magic = time.ctime() + " ganeti test"
142     self.fname = self._CreateTempFile()
143
144   def testOk(self):
145     """Test successful exit code"""
146     result = RunCmd("/bin/sh -c 'exit 0'")
147     self.assertEqual(result.exit_code, 0)
148     self.assertEqual(result.output, "")
149
150   def testFail(self):
151     """Test fail exit code"""
152     result = RunCmd("/bin/sh -c 'exit 1'")
153     self.assertEqual(result.exit_code, 1)
154     self.assertEqual(result.output, "")
155
156   def testStdout(self):
157     """Test standard output"""
158     cmd = 'echo -n "%s"' % self.magic
159     result = RunCmd("/bin/sh -c '%s'" % cmd)
160     self.assertEqual(result.stdout, self.magic)
161     result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
162     self.assertEqual(result.output, "")
163     self.assertFileContent(self.fname, self.magic)
164
165   def testStderr(self):
166     """Test standard error"""
167     cmd = 'echo -n "%s"' % self.magic
168     result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
169     self.assertEqual(result.stderr, self.magic)
170     result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
171     self.assertEqual(result.output, "")
172     self.assertFileContent(self.fname, self.magic)
173
174   def testCombined(self):
175     """Test combined output"""
176     cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
177     expected = "A" + self.magic + "B" + self.magic
178     result = RunCmd("/bin/sh -c '%s'" % cmd)
179     self.assertEqual(result.output, expected)
180     result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
181     self.assertEqual(result.output, "")
182     self.assertFileContent(self.fname, expected)
183
184   def testSignal(self):
185     """Test signal"""
186     result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
187     self.assertEqual(result.signal, 15)
188     self.assertEqual(result.output, "")
189
190   def testListRun(self):
191     """Test list runs"""
192     result = RunCmd(["true"])
193     self.assertEqual(result.signal, None)
194     self.assertEqual(result.exit_code, 0)
195     result = RunCmd(["/bin/sh", "-c", "exit 1"])
196     self.assertEqual(result.signal, None)
197     self.assertEqual(result.exit_code, 1)
198     result = RunCmd(["echo", "-n", self.magic])
199     self.assertEqual(result.signal, None)
200     self.assertEqual(result.exit_code, 0)
201     self.assertEqual(result.stdout, self.magic)
202
203   def testFileEmptyOutput(self):
204     """Test file output"""
205     result = RunCmd(["true"], output=self.fname)
206     self.assertEqual(result.signal, None)
207     self.assertEqual(result.exit_code, 0)
208     self.assertFileContent(self.fname, "")
209
210   def testLang(self):
211     """Test locale environment"""
212     old_env = os.environ.copy()
213     try:
214       os.environ["LANG"] = "en_US.UTF-8"
215       os.environ["LC_ALL"] = "en_US.UTF-8"
216       result = RunCmd(["locale"])
217       for line in result.output.splitlines():
218         key, value = line.split("=", 1)
219         # Ignore these variables, they're overridden by LC_ALL
220         if key == "LANG" or key == "LANGUAGE":
221           continue
222         self.failIf(value and value != "C" and value != '"C"',
223             "Variable %s is set to the invalid value '%s'" % (key, value))
224     finally:
225       os.environ = old_env
226
227   def testDefaultCwd(self):
228     """Test default working directory"""
229     self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
230
231   def testCwd(self):
232     """Test default working directory"""
233     self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
234     self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
235     cwd = os.getcwd()
236     self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
237
238   def testResetEnv(self):
239     """Test environment reset functionality"""
240     self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
241     self.failUnlessEqual(RunCmd(["env"], reset_env=True,
242                                 env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
243
244
245 class TestRunParts(unittest.TestCase):
246   """Testing case for the RunParts function"""
247
248   def setUp(self):
249     self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
250
251   def tearDown(self):
252     shutil.rmtree(self.rundir)
253
254   def testEmpty(self):
255     """Test on an empty dir"""
256     self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
257
258   def testSkipWrongName(self):
259     """Test that wrong files are skipped"""
260     fname = os.path.join(self.rundir, "00test.dot")
261     utils.WriteFile(fname, data="")
262     os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
263     relname = os.path.basename(fname)
264     self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
265                          [(relname, constants.RUNPARTS_SKIP, None)])
266
267   def testSkipNonExec(self):
268     """Test that non executable files are skipped"""
269     fname = os.path.join(self.rundir, "00test")
270     utils.WriteFile(fname, data="")
271     relname = os.path.basename(fname)
272     self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
273                          [(relname, constants.RUNPARTS_SKIP, None)])
274
275   def testError(self):
276     """Test error on a broken executable"""
277     fname = os.path.join(self.rundir, "00test")
278     utils.WriteFile(fname, data="")
279     os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
280     (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
281     self.failUnlessEqual(relname, os.path.basename(fname))
282     self.failUnlessEqual(status, constants.RUNPARTS_ERR)
283     self.failUnless(error)
284
285   def testSorted(self):
286     """Test executions are sorted"""
287     files = []
288     files.append(os.path.join(self.rundir, "64test"))
289     files.append(os.path.join(self.rundir, "00test"))
290     files.append(os.path.join(self.rundir, "42test"))
291
292     for fname in files:
293       utils.WriteFile(fname, data="")
294
295     results = RunParts(self.rundir, reset_env=True)
296
297     for fname in sorted(files):
298       self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
299
300   def testOk(self):
301     """Test correct execution"""
302     fname = os.path.join(self.rundir, "00test")
303     utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
304     os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
305     (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
306     self.failUnlessEqual(relname, os.path.basename(fname))
307     self.failUnlessEqual(status, constants.RUNPARTS_RUN)
308     self.failUnlessEqual(runresult.stdout, "ciao")
309
310   def testRunFail(self):
311     """Test correct execution, with run failure"""
312     fname = os.path.join(self.rundir, "00test")
313     utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
314     os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
315     (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
316     self.failUnlessEqual(relname, os.path.basename(fname))
317     self.failUnlessEqual(status, constants.RUNPARTS_RUN)
318     self.failUnlessEqual(runresult.exit_code, 1)
319     self.failUnless(runresult.failed)
320
321   def testRunMix(self):
322     files = []
323     files.append(os.path.join(self.rundir, "00test"))
324     files.append(os.path.join(self.rundir, "42test"))
325     files.append(os.path.join(self.rundir, "64test"))
326     files.append(os.path.join(self.rundir, "99test"))
327
328     files.sort()
329
330     # 1st has errors in execution
331     utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
332     os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
333
334     # 2nd is skipped
335     utils.WriteFile(files[1], data="")
336
337     # 3rd cannot execute properly
338     utils.WriteFile(files[2], data="")
339     os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
340
341     # 4th execs
342     utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
343     os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
344
345     results = RunParts(self.rundir, reset_env=True)
346
347     (relname, status, runresult) = results[0]
348     self.failUnlessEqual(relname, os.path.basename(files[0]))
349     self.failUnlessEqual(status, constants.RUNPARTS_RUN)
350     self.failUnlessEqual(runresult.exit_code, 1)
351     self.failUnless(runresult.failed)
352
353     (relname, status, runresult) = results[1]
354     self.failUnlessEqual(relname, os.path.basename(files[1]))
355     self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
356     self.failUnlessEqual(runresult, None)
357
358     (relname, status, runresult) = results[2]
359     self.failUnlessEqual(relname, os.path.basename(files[2]))
360     self.failUnlessEqual(status, constants.RUNPARTS_ERR)
361     self.failUnless(runresult)
362
363     (relname, status, runresult) = results[3]
364     self.failUnlessEqual(relname, os.path.basename(files[3]))
365     self.failUnlessEqual(status, constants.RUNPARTS_RUN)
366     self.failUnlessEqual(runresult.output, "ciao")
367     self.failUnlessEqual(runresult.exit_code, 0)
368     self.failUnless(not runresult.failed)
369
370
371 class TestRemoveFile(unittest.TestCase):
372   """Test case for the RemoveFile function"""
373
374   def setUp(self):
375     """Create a temp dir and file for each case"""
376     self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
377     fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
378     os.close(fd)
379
380   def tearDown(self):
381     if os.path.exists(self.tmpfile):
382       os.unlink(self.tmpfile)
383     os.rmdir(self.tmpdir)
384
385
386   def testIgnoreDirs(self):
387     """Test that RemoveFile() ignores directories"""
388     self.assertEqual(None, RemoveFile(self.tmpdir))
389
390
391   def testIgnoreNotExisting(self):
392     """Test that RemoveFile() ignores non-existing files"""
393     RemoveFile(self.tmpfile)
394     RemoveFile(self.tmpfile)
395
396
397   def testRemoveFile(self):
398     """Test that RemoveFile does remove a file"""
399     RemoveFile(self.tmpfile)
400     if os.path.exists(self.tmpfile):
401       self.fail("File '%s' not removed" % self.tmpfile)
402
403
404   def testRemoveSymlink(self):
405     """Test that RemoveFile does remove symlinks"""
406     symlink = self.tmpdir + "/symlink"
407     os.symlink("no-such-file", symlink)
408     RemoveFile(symlink)
409     if os.path.exists(symlink):
410       self.fail("File '%s' not removed" % symlink)
411     os.symlink(self.tmpfile, symlink)
412     RemoveFile(symlink)
413     if os.path.exists(symlink):
414       self.fail("File '%s' not removed" % symlink)
415
416
417 class TestRename(unittest.TestCase):
418   """Test case for RenameFile"""
419
420   def setUp(self):
421     """Create a temporary directory"""
422     self.tmpdir = tempfile.mkdtemp()
423     self.tmpfile = os.path.join(self.tmpdir, "test1")
424
425     # Touch the file
426     open(self.tmpfile, "w").close()
427
428   def tearDown(self):
429     """Remove temporary directory"""
430     shutil.rmtree(self.tmpdir)
431
432   def testSimpleRename1(self):
433     """Simple rename 1"""
434     utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
435     self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
436
437   def testSimpleRename2(self):
438     """Simple rename 2"""
439     utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
440                      mkdir=True)
441     self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
442
443   def testRenameMkdir(self):
444     """Rename with mkdir"""
445     utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
446                      mkdir=True)
447     self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
448     self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
449
450     utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
451                      os.path.join(self.tmpdir, "test/foo/bar/baz"),
452                      mkdir=True)
453     self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
454     self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
455     self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
456
457
458 class TestMatchNameComponent(unittest.TestCase):
459   """Test case for the MatchNameComponent function"""
460
461   def testEmptyList(self):
462     """Test that there is no match against an empty list"""
463
464     self.failUnlessEqual(MatchNameComponent("", []), None)
465     self.failUnlessEqual(MatchNameComponent("test", []), None)
466
467   def testSingleMatch(self):
468     """Test that a single match is performed correctly"""
469     mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
470     for key in "test2", "test2.example", "test2.example.com":
471       self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
472
473   def testMultipleMatches(self):
474     """Test that a multiple match is returned as None"""
475     mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
476     for key in "test1", "test1.example":
477       self.failUnlessEqual(MatchNameComponent(key, mlist), None)
478
479   def testFullMatch(self):
480     """Test that a full match is returned correctly"""
481     key1 = "test1"
482     key2 = "test1.example"
483     mlist = [key2, key2 + ".com"]
484     self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
485     self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
486
487   def testCaseInsensitivePartialMatch(self):
488     """Test for the case_insensitive keyword"""
489     mlist = ["test1.example.com", "test2.example.net"]
490     self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
491                      "test2.example.net")
492     self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
493                      "test2.example.net")
494     self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
495                      "test2.example.net")
496     self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
497                      "test2.example.net")
498
499
500   def testCaseInsensitiveFullMatch(self):
501     mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
502     # Between the two ts1 a full string match non-case insensitive should work
503     self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
504                      None)
505     self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
506                      "ts1.ex")
507     self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
508                      "ts1.ex")
509     # Between the two ts2 only case differs, so only case-match works
510     self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
511                      "ts2.ex")
512     self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
513                      "Ts2.ex")
514     self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
515                      None)
516
517
518 class TestTimestampForFilename(unittest.TestCase):
519   def test(self):
520     self.assert_("." not in utils.TimestampForFilename())
521     self.assert_(":" not in utils.TimestampForFilename())
522
523
524 class TestCreateBackup(testutils.GanetiTestCase):
525   def setUp(self):
526     testutils.GanetiTestCase.setUp(self)
527
528     self.tmpdir = tempfile.mkdtemp()
529
530   def tearDown(self):
531     testutils.GanetiTestCase.tearDown(self)
532
533     shutil.rmtree(self.tmpdir)
534
535   def testEmpty(self):
536     filename = utils.PathJoin(self.tmpdir, "config.data")
537     utils.WriteFile(filename, data="")
538     bname = utils.CreateBackup(filename)
539     self.assertFileContent(bname, "")
540     self.assertEqual(len(glob.glob("%s*" % filename)), 2)
541     utils.CreateBackup(filename)
542     self.assertEqual(len(glob.glob("%s*" % filename)), 3)
543     utils.CreateBackup(filename)
544     self.assertEqual(len(glob.glob("%s*" % filename)), 4)
545
546     fifoname = utils.PathJoin(self.tmpdir, "fifo")
547     os.mkfifo(fifoname)
548     self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
549
550   def testContent(self):
551     bkpcount = 0
552     for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
553       for rep in [1, 2, 10, 127]:
554         testdata = data * rep
555
556         filename = utils.PathJoin(self.tmpdir, "test.data_")
557         utils.WriteFile(filename, data=testdata)
558         self.assertFileContent(filename, testdata)
559
560         for _ in range(3):
561           bname = utils.CreateBackup(filename)
562           bkpcount += 1
563           self.assertFileContent(bname, testdata)
564           self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
565
566
567 class TestFormatUnit(unittest.TestCase):
568   """Test case for the FormatUnit function"""
569
570   def testMiB(self):
571     self.assertEqual(FormatUnit(1, 'h'), '1M')
572     self.assertEqual(FormatUnit(100, 'h'), '100M')
573     self.assertEqual(FormatUnit(1023, 'h'), '1023M')
574
575     self.assertEqual(FormatUnit(1, 'm'), '1')
576     self.assertEqual(FormatUnit(100, 'm'), '100')
577     self.assertEqual(FormatUnit(1023, 'm'), '1023')
578
579     self.assertEqual(FormatUnit(1024, 'm'), '1024')
580     self.assertEqual(FormatUnit(1536, 'm'), '1536')
581     self.assertEqual(FormatUnit(17133, 'm'), '17133')
582     self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
583
584   def testGiB(self):
585     self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
586     self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
587     self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
588     self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
589
590     self.assertEqual(FormatUnit(1024, 'g'), '1.0')
591     self.assertEqual(FormatUnit(1536, 'g'), '1.5')
592     self.assertEqual(FormatUnit(17133, 'g'), '16.7')
593     self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
594
595     self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
596     self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
597     self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
598
599   def testTiB(self):
600     self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
601     self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
602     self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
603
604     self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
605     self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
606     self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
607
608 class TestParseUnit(unittest.TestCase):
609   """Test case for the ParseUnit function"""
610
611   SCALES = (('', 1),
612             ('M', 1), ('G', 1024), ('T', 1024 * 1024),
613             ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
614             ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
615
616   def testRounding(self):
617     self.assertEqual(ParseUnit('0'), 0)
618     self.assertEqual(ParseUnit('1'), 4)
619     self.assertEqual(ParseUnit('2'), 4)
620     self.assertEqual(ParseUnit('3'), 4)
621
622     self.assertEqual(ParseUnit('124'), 124)
623     self.assertEqual(ParseUnit('125'), 128)
624     self.assertEqual(ParseUnit('126'), 128)
625     self.assertEqual(ParseUnit('127'), 128)
626     self.assertEqual(ParseUnit('128'), 128)
627     self.assertEqual(ParseUnit('129'), 132)
628     self.assertEqual(ParseUnit('130'), 132)
629
630   def testFloating(self):
631     self.assertEqual(ParseUnit('0'), 0)
632     self.assertEqual(ParseUnit('0.5'), 4)
633     self.assertEqual(ParseUnit('1.75'), 4)
634     self.assertEqual(ParseUnit('1.99'), 4)
635     self.assertEqual(ParseUnit('2.00'), 4)
636     self.assertEqual(ParseUnit('2.01'), 4)
637     self.assertEqual(ParseUnit('3.99'), 4)
638     self.assertEqual(ParseUnit('4.00'), 4)
639     self.assertEqual(ParseUnit('4.01'), 8)
640     self.assertEqual(ParseUnit('1.5G'), 1536)
641     self.assertEqual(ParseUnit('1.8G'), 1844)
642     self.assertEqual(ParseUnit('8.28T'), 8682212)
643
644   def testSuffixes(self):
645     for sep in ('', ' ', '   ', "\t", "\t "):
646       for suffix, scale in TestParseUnit.SCALES:
647         for func in (lambda x: x, str.lower, str.upper):
648           self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
649                            1024 * scale)
650
651   def testInvalidInput(self):
652     for sep in ('-', '_', ',', 'a'):
653       for suffix, _ in TestParseUnit.SCALES:
654         self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)
655
656     for suffix, _ in TestParseUnit.SCALES:
657       self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
658
659
660 class TestSshKeys(testutils.GanetiTestCase):
661   """Test case for the AddAuthorizedKey function"""
662
663   KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
664   KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
665            'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
666
667   def setUp(self):
668     testutils.GanetiTestCase.setUp(self)
669     self.tmpname = self._CreateTempFile()
670     handle = open(self.tmpname, 'w')
671     try:
672       handle.write("%s\n" % TestSshKeys.KEY_A)
673       handle.write("%s\n" % TestSshKeys.KEY_B)
674     finally:
675       handle.close()
676
677   def testAddingNewKey(self):
678     AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
679
680     self.assertFileContent(self.tmpname,
681       "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
682       'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
683       " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
684       "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
685
686   def testAddingAlmostButNotCompletelyTheSameKey(self):
687     AddAuthorizedKey(self.tmpname,
688         'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
689
690     self.assertFileContent(self.tmpname,
691       "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
692       'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
693       " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
694       "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
695
696   def testAddingExistingKeyWithSomeMoreSpaces(self):
697     AddAuthorizedKey(self.tmpname,
698         'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
699
700     self.assertFileContent(self.tmpname,
701       "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
702       'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
703       " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
704
705   def testRemovingExistingKeyWithSomeMoreSpaces(self):
706     RemoveAuthorizedKey(self.tmpname,
707         'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
708
709     self.assertFileContent(self.tmpname,
710       'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
711       " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
712
713   def testRemovingNonExistingKey(self):
714     RemoveAuthorizedKey(self.tmpname,
715         'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
716
717     self.assertFileContent(self.tmpname,
718       "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
719       'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
720       " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
721
722
723 class TestEtcHosts(testutils.GanetiTestCase):
724   """Test functions modifying /etc/hosts"""
725
726   def setUp(self):
727     testutils.GanetiTestCase.setUp(self)
728     self.tmpname = self._CreateTempFile()
729     handle = open(self.tmpname, 'w')
730     try:
731       handle.write('# This is a test file for /etc/hosts\n')
732       handle.write('127.0.0.1\tlocalhost\n')
733       handle.write('192.168.1.1 router gw\n')
734     finally:
735       handle.close()
736
737   def testSettingNewIp(self):
738     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
739
740     self.assertFileContent(self.tmpname,
741       "# This is a test file for /etc/hosts\n"
742       "127.0.0.1\tlocalhost\n"
743       "192.168.1.1 router gw\n"
744       "1.2.3.4\tmyhost.domain.tld myhost\n")
745     self.assertFileMode(self.tmpname, 0644)
746
747   def testSettingExistingIp(self):
748     SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
749                      ['myhost'])
750
751     self.assertFileContent(self.tmpname,
752       "# This is a test file for /etc/hosts\n"
753       "127.0.0.1\tlocalhost\n"
754       "192.168.1.1\tmyhost.domain.tld myhost\n")
755     self.assertFileMode(self.tmpname, 0644)
756
757   def testSettingDuplicateName(self):
758     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
759
760     self.assertFileContent(self.tmpname,
761       "# This is a test file for /etc/hosts\n"
762       "127.0.0.1\tlocalhost\n"
763       "192.168.1.1 router gw\n"
764       "1.2.3.4\tmyhost\n")
765     self.assertFileMode(self.tmpname, 0644)
766
767   def testRemovingExistingHost(self):
768     RemoveEtcHostsEntry(self.tmpname, 'router')
769
770     self.assertFileContent(self.tmpname,
771       "# This is a test file for /etc/hosts\n"
772       "127.0.0.1\tlocalhost\n"
773       "192.168.1.1 gw\n")
774     self.assertFileMode(self.tmpname, 0644)
775
776   def testRemovingSingleExistingHost(self):
777     RemoveEtcHostsEntry(self.tmpname, 'localhost')
778
779     self.assertFileContent(self.tmpname,
780       "# This is a test file for /etc/hosts\n"
781       "192.168.1.1 router gw\n")
782     self.assertFileMode(self.tmpname, 0644)
783
784   def testRemovingNonExistingHost(self):
785     RemoveEtcHostsEntry(self.tmpname, 'myhost')
786
787     self.assertFileContent(self.tmpname,
788       "# This is a test file for /etc/hosts\n"
789       "127.0.0.1\tlocalhost\n"
790       "192.168.1.1 router gw\n")
791     self.assertFileMode(self.tmpname, 0644)
792
793   def testRemovingAlias(self):
794     RemoveEtcHostsEntry(self.tmpname, 'gw')
795
796     self.assertFileContent(self.tmpname,
797       "# This is a test file for /etc/hosts\n"
798       "127.0.0.1\tlocalhost\n"
799       "192.168.1.1 router\n")
800     self.assertFileMode(self.tmpname, 0644)
801
802
803 class TestShellQuoting(unittest.TestCase):
804   """Test case for shell quoting functions"""
805
806   def testShellQuote(self):
807     self.assertEqual(ShellQuote('abc'), "abc")
808     self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
809     self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
810     self.assertEqual(ShellQuote("a b c"), "'a b c'")
811     self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
812
813   def testShellQuoteArgs(self):
814     self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
815     self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
816     self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
817
818
819 class TestTcpPing(unittest.TestCase):
820   """Testcase for TCP version of ping - against listen(2)ing port"""
821
822   def setUp(self):
823     self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
824     self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
825     self.listenerport = self.listener.getsockname()[1]
826     self.listener.listen(1)
827
828   def tearDown(self):
829     self.listener.shutdown(socket.SHUT_RDWR)
830     del self.listener
831     del self.listenerport
832
833   def testTcpPingToLocalHostAccept(self):
834     self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
835                          self.listenerport,
836                          timeout=10,
837                          live_port_needed=True,
838                          source=constants.LOCALHOST_IP_ADDRESS,
839                          ),
840                  "failed to connect to test listener")
841
842     self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
843                          self.listenerport,
844                          timeout=10,
845                          live_port_needed=True,
846                          ),
847                  "failed to connect to test listener (no source)")
848
849
850 class TestTcpPingDeaf(unittest.TestCase):
851   """Testcase for TCP version of ping - against non listen(2)ing port"""
852
853   def setUp(self):
854     self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
855     self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
856     self.deaflistenerport = self.deaflistener.getsockname()[1]
857
858   def tearDown(self):
859     del self.deaflistener
860     del self.deaflistenerport
861
862   def testTcpPingToLocalHostAcceptDeaf(self):
863     self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
864                         self.deaflistenerport,
865                         timeout=constants.TCP_PING_TIMEOUT,
866                         live_port_needed=True,
867                         source=constants.LOCALHOST_IP_ADDRESS,
868                         ), # need successful connect(2)
869                 "successfully connected to deaf listener")
870
871     self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
872                         self.deaflistenerport,
873                         timeout=constants.TCP_PING_TIMEOUT,
874                         live_port_needed=True,
875                         ), # need successful connect(2)
876                 "successfully connected to deaf listener (no source addr)")
877
878   def testTcpPingToLocalHostNoAccept(self):
879     self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
880                          self.deaflistenerport,
881                          timeout=constants.TCP_PING_TIMEOUT,
882                          live_port_needed=False,
883                          source=constants.LOCALHOST_IP_ADDRESS,
884                          ), # ECONNREFUSED is OK
885                  "failed to ping alive host on deaf port")
886
887     self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
888                          self.deaflistenerport,
889                          timeout=constants.TCP_PING_TIMEOUT,
890                          live_port_needed=False,
891                          ), # ECONNREFUSED is OK
892                  "failed to ping alive host on deaf port (no source addr)")
893
894
895 class TestOwnIpAddress(unittest.TestCase):
896   """Testcase for OwnIpAddress"""
897
898   def testOwnLoopback(self):
899     """check having the loopback ip"""
900     self.failUnless(OwnIpAddress(constants.LOCALHOST_IP_ADDRESS),
901                     "Should own the loopback address")
902
903   def testNowOwnAddress(self):
904     """check that I don't own an address"""
905
906     # network 192.0.2.0/24 is reserved for test/documentation as per
907     # rfc 3330, so we *should* not have an address of this range... if
908     # this fails, we should extend the test to multiple addresses
909     DST_IP = "192.0.2.1"
910     self.failIf(OwnIpAddress(DST_IP), "Should not own IP address %s" % DST_IP)
911
912
913 class TestListVisibleFiles(unittest.TestCase):
914   """Test case for ListVisibleFiles"""
915
916   def setUp(self):
917     self.path = tempfile.mkdtemp()
918
919   def tearDown(self):
920     shutil.rmtree(self.path)
921
922   def _test(self, files, expected):
923     # Sort a copy
924     expected = expected[:]
925     expected.sort()
926
927     for name in files:
928       f = open(os.path.join(self.path, name), 'w')
929       try:
930         f.write("Test\n")
931       finally:
932         f.close()
933
934     found = ListVisibleFiles(self.path)
935     found.sort()
936
937     self.assertEqual(found, expected)
938
939   def testAllVisible(self):
940     files = ["a", "b", "c"]
941     expected = files
942     self._test(files, expected)
943
944   def testNoneVisible(self):
945     files = [".a", ".b", ".c"]
946     expected = []
947     self._test(files, expected)
948
949   def testSomeVisible(self):
950     files = ["a", "b", ".c"]
951     expected = ["a", "b"]
952     self._test(files, expected)
953
954   def testNonAbsolutePath(self):
955     self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
956
957   def testNonNormalizedPath(self):
958     self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
959                           "/bin/../tmp")
960
961
962 class TestNewUUID(unittest.TestCase):
963   """Test case for NewUUID"""
964
965   _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
966                         '[a-f0-9]{4}-[a-f0-9]{12}$')
967
968   def runTest(self):
969     self.failUnless(self._re_uuid.match(utils.NewUUID()))
970
971
972 class TestUniqueSequence(unittest.TestCase):
973   """Test case for UniqueSequence"""
974
975   def _test(self, input, expected):
976     self.assertEqual(utils.UniqueSequence(input), expected)
977
978   def runTest(self):
979     # Ordered input
980     self._test([1, 2, 3], [1, 2, 3])
981     self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
982     self._test([1, 2, 2, 3], [1, 2, 3])
983     self._test([1, 2, 3, 3], [1, 2, 3])
984
985     # Unordered input
986     self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
987     self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
988
989     # Strings
990     self._test(["a", "a"], ["a"])
991     self._test(["a", "b"], ["a", "b"])
992     self._test(["a", "b", "a"], ["a", "b"])
993
994
995 class TestFirstFree(unittest.TestCase):
996   """Test case for the FirstFree function"""
997
998   def test(self):
999     """Test FirstFree"""
1000     self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1001     self.failUnlessEqual(FirstFree([]), None)
1002     self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1003     self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1004     self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1005
1006
1007 class TestTailFile(testutils.GanetiTestCase):
1008   """Test case for the TailFile function"""
1009
1010   def testEmpty(self):
1011     fname = self._CreateTempFile()
1012     self.failUnlessEqual(TailFile(fname), [])
1013     self.failUnlessEqual(TailFile(fname, lines=25), [])
1014
1015   def testAllLines(self):
1016     data = ["test %d" % i for i in range(30)]
1017     for i in range(30):
1018       fname = self._CreateTempFile()
1019       fd = open(fname, "w")
1020       fd.write("\n".join(data[:i]))
1021       if i > 0:
1022         fd.write("\n")
1023       fd.close()
1024       self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1025
1026   def testPartialLines(self):
1027     data = ["test %d" % i for i in range(30)]
1028     fname = self._CreateTempFile()
1029     fd = open(fname, "w")
1030     fd.write("\n".join(data))
1031     fd.write("\n")
1032     fd.close()
1033     for i in range(1, 30):
1034       self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1035
1036   def testBigFile(self):
1037     data = ["test %d" % i for i in range(30)]
1038     fname = self._CreateTempFile()
1039     fd = open(fname, "w")
1040     fd.write("X" * 1048576)
1041     fd.write("\n")
1042     fd.write("\n".join(data))
1043     fd.write("\n")
1044     fd.close()
1045     for i in range(1, 30):
1046       self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1047
1048
1049 class _BaseFileLockTest:
1050   """Test case for the FileLock class"""
1051
1052   def testSharedNonblocking(self):
1053     self.lock.Shared(blocking=False)
1054     self.lock.Close()
1055
1056   def testExclusiveNonblocking(self):
1057     self.lock.Exclusive(blocking=False)
1058     self.lock.Close()
1059
1060   def testUnlockNonblocking(self):
1061     self.lock.Unlock(blocking=False)
1062     self.lock.Close()
1063
1064   def testSharedBlocking(self):
1065     self.lock.Shared(blocking=True)
1066     self.lock.Close()
1067
1068   def testExclusiveBlocking(self):
1069     self.lock.Exclusive(blocking=True)
1070     self.lock.Close()
1071
1072   def testUnlockBlocking(self):
1073     self.lock.Unlock(blocking=True)
1074     self.lock.Close()
1075
1076   def testSharedExclusiveUnlock(self):
1077     self.lock.Shared(blocking=False)
1078     self.lock.Exclusive(blocking=False)
1079     self.lock.Unlock(blocking=False)
1080     self.lock.Close()
1081
1082   def testExclusiveSharedUnlock(self):
1083     self.lock.Exclusive(blocking=False)
1084     self.lock.Shared(blocking=False)
1085     self.lock.Unlock(blocking=False)
1086     self.lock.Close()
1087
1088   def testSimpleTimeout(self):
1089     # These will succeed on the first attempt, hence a short timeout
1090     self.lock.Shared(blocking=True, timeout=10.0)
1091     self.lock.Exclusive(blocking=False, timeout=10.0)
1092     self.lock.Unlock(blocking=True, timeout=10.0)
1093     self.lock.Close()
1094
1095   @staticmethod
1096   def _TryLockInner(filename, shared, blocking):
1097     lock = utils.FileLock.Open(filename)
1098
1099     if shared:
1100       fn = lock.Shared
1101     else:
1102       fn = lock.Exclusive
1103
1104     try:
1105       # The timeout doesn't really matter as the parent process waits for us to
1106       # finish anyway.
1107       fn(blocking=blocking, timeout=0.01)
1108     except errors.LockError, err:
1109       return False
1110
1111     return True
1112
1113   def _TryLock(self, *args):
1114     return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1115                                       *args)
1116
1117   def testTimeout(self):
1118     for blocking in [True, False]:
1119       self.lock.Exclusive(blocking=True)
1120       self.failIf(self._TryLock(False, blocking))
1121       self.failIf(self._TryLock(True, blocking))
1122
1123       self.lock.Shared(blocking=True)
1124       self.assert_(self._TryLock(True, blocking))
1125       self.failIf(self._TryLock(False, blocking))
1126
1127   def testCloseShared(self):
1128     self.lock.Close()
1129     self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1130
1131   def testCloseExclusive(self):
1132     self.lock.Close()
1133     self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1134
1135   def testCloseUnlock(self):
1136     self.lock.Close()
1137     self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1138
1139
1140 class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1141   TESTDATA = "Hello World\n" * 10
1142
1143   def setUp(self):
1144     testutils.GanetiTestCase.setUp(self)
1145
1146     self.tmpfile = tempfile.NamedTemporaryFile()
1147     utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1148     self.lock = utils.FileLock.Open(self.tmpfile.name)
1149
1150     # Ensure "Open" didn't truncate file
1151     self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1152
1153   def tearDown(self):
1154     self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1155
1156     testutils.GanetiTestCase.tearDown(self)
1157
1158
1159 class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1160   def setUp(self):
1161     self.tmpfile = tempfile.NamedTemporaryFile()
1162     self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1163
1164
1165 class TestTimeFunctions(unittest.TestCase):
1166   """Test case for time functions"""
1167
1168   def runTest(self):
1169     self.assertEqual(utils.SplitTime(1), (1, 0))
1170     self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1171     self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1172     self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1173     self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1174     self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1175     self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1176     self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1177
1178     self.assertRaises(AssertionError, utils.SplitTime, -1)
1179
1180     self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1181     self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1182     self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1183
1184     self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1185                      1218448917.481)
1186     self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1187
1188     self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1189     self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1190     self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1191     self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1192     self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1193
1194
1195 class FieldSetTestCase(unittest.TestCase):
1196   """Test case for FieldSets"""
1197
1198   def testSimpleMatch(self):
1199     f = utils.FieldSet("a", "b", "c", "def")
1200     self.failUnless(f.Matches("a"))
1201     self.failIf(f.Matches("d"), "Substring matched")
1202     self.failIf(f.Matches("defghi"), "Prefix string matched")
1203     self.failIf(f.NonMatching(["b", "c"]))
1204     self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1205     self.failUnless(f.NonMatching(["a", "d"]))
1206
1207   def testRegexMatch(self):
1208     f = utils.FieldSet("a", "b([0-9]+)", "c")
1209     self.failUnless(f.Matches("b1"))
1210     self.failUnless(f.Matches("b99"))
1211     self.failIf(f.Matches("b/1"))
1212     self.failIf(f.NonMatching(["b12", "c"]))
1213     self.failUnless(f.NonMatching(["a", "1"]))
1214
1215 class TestForceDictType(unittest.TestCase):
1216   """Test case for ForceDictType"""
1217
1218   def setUp(self):
1219     self.key_types = {
1220       'a': constants.VTYPE_INT,
1221       'b': constants.VTYPE_BOOL,
1222       'c': constants.VTYPE_STRING,
1223       'd': constants.VTYPE_SIZE,
1224       }
1225
1226   def _fdt(self, dict, allowed_values=None):
1227     if allowed_values is None:
1228       ForceDictType(dict, self.key_types)
1229     else:
1230       ForceDictType(dict, self.key_types, allowed_values=allowed_values)
1231
1232     return dict
1233
1234   def testSimpleDict(self):
1235     self.assertEqual(self._fdt({}), {})
1236     self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1237     self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1238     self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1239     self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1240     self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1241     self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1242     self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1243     self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1244     self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1245     self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1246     self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1247
1248   def testErrors(self):
1249     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1250     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1251     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1252     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1253
1254
1255 class TestIsAbsNormPath(unittest.TestCase):
1256   """Testing case for IsProcessAlive"""
1257
1258   def _pathTestHelper(self, path, result):
1259     if result:
1260       self.assert_(IsNormAbsPath(path),
1261           "Path %s should result absolute and normalized" % path)
1262     else:
1263       self.assert_(not IsNormAbsPath(path),
1264           "Path %s should not result absolute and normalized" % path)
1265
1266   def testBase(self):
1267     self._pathTestHelper('/etc', True)
1268     self._pathTestHelper('/srv', True)
1269     self._pathTestHelper('etc', False)
1270     self._pathTestHelper('/etc/../root', False)
1271     self._pathTestHelper('/etc/', False)
1272
1273
1274 class TestSafeEncode(unittest.TestCase):
1275   """Test case for SafeEncode"""
1276
1277   def testAscii(self):
1278     for txt in [string.digits, string.letters, string.punctuation]:
1279       self.failUnlessEqual(txt, SafeEncode(txt))
1280
1281   def testDoubleEncode(self):
1282     for i in range(255):
1283       txt = SafeEncode(chr(i))
1284       self.failUnlessEqual(txt, SafeEncode(txt))
1285
1286   def testUnicode(self):
1287     # 1024 is high enough to catch non-direct ASCII mappings
1288     for i in range(1024):
1289       txt = SafeEncode(unichr(i))
1290       self.failUnlessEqual(txt, SafeEncode(txt))
1291
1292
1293 class TestFormatTime(unittest.TestCase):
1294   """Testing case for FormatTime"""
1295
1296   def testNone(self):
1297     self.failUnlessEqual(FormatTime(None), "N/A")
1298
1299   def testInvalid(self):
1300     self.failUnlessEqual(FormatTime(()), "N/A")
1301
1302   def testNow(self):
1303     # tests that we accept time.time input
1304     FormatTime(time.time())
1305     # tests that we accept int input
1306     FormatTime(int(time.time()))
1307
1308
1309 class RunInSeparateProcess(unittest.TestCase):
1310   def test(self):
1311     for exp in [True, False]:
1312       def _child():
1313         return exp
1314
1315       self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1316
1317   def testArgs(self):
1318     for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1319       def _child(carg1, carg2):
1320         return carg1 == "Foo" and carg2 == arg
1321
1322       self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1323
1324   def testPid(self):
1325     parent_pid = os.getpid()
1326
1327     def _check():
1328       return os.getpid() == parent_pid
1329
1330     self.failIf(utils.RunInSeparateProcess(_check))
1331
1332   def testSignal(self):
1333     def _kill():
1334       os.kill(os.getpid(), signal.SIGTERM)
1335
1336     self.assertRaises(errors.GenericError,
1337                       utils.RunInSeparateProcess, _kill)
1338
1339   def testException(self):
1340     def _exc():
1341       raise errors.GenericError("This is a test")
1342
1343     self.assertRaises(errors.GenericError,
1344                       utils.RunInSeparateProcess, _exc)
1345
1346
1347 class TestFingerprintFile(unittest.TestCase):
1348   def setUp(self):
1349     self.tmpfile = tempfile.NamedTemporaryFile()
1350
1351   def test(self):
1352     self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1353                      "da39a3ee5e6b4b0d3255bfef95601890afd80709")
1354
1355     utils.WriteFile(self.tmpfile.name, data="Hello World\n")
1356     self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1357                      "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1358
1359
1360 class TestUnescapeAndSplit(unittest.TestCase):
1361   """Testing case for UnescapeAndSplit"""
1362
1363   def setUp(self):
1364     # testing more that one separator for regexp safety
1365     self._seps = [",", "+", "."]
1366
1367   def testSimple(self):
1368     a = ["a", "b", "c", "d"]
1369     for sep in self._seps:
1370       self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1371
1372   def testEscape(self):
1373     for sep in self._seps:
1374       a = ["a", "b\\" + sep + "c", "d"]
1375       b = ["a", "b" + sep + "c", "d"]
1376       self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1377
1378   def testDoubleEscape(self):
1379     for sep in self._seps:
1380       a = ["a", "b\\\\", "c", "d"]
1381       b = ["a", "b\\", "c", "d"]
1382       self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1383
1384   def testThreeEscape(self):
1385     for sep in self._seps:
1386       a = ["a", "b\\\\\\" + sep + "c", "d"]
1387       b = ["a", "b\\" + sep + "c", "d"]
1388       self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1389
1390
1391 class TestPathJoin(unittest.TestCase):
1392   """Testing case for PathJoin"""
1393
1394   def testBasicItems(self):
1395     mlist = ["/a", "b", "c"]
1396     self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1397
1398   def testNonAbsPrefix(self):
1399     self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1400
1401   def testBackTrack(self):
1402     self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1403
1404   def testMultiAbs(self):
1405     self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1406
1407
1408 class TestHostInfo(unittest.TestCase):
1409   """Testing case for HostInfo"""
1410
1411   def testUppercase(self):
1412     data = "AbC.example.com"
1413     self.failUnlessEqual(HostInfo.NormalizeName(data), data.lower())
1414
1415   def testTooLongName(self):
1416     data = "a.b." + "c" * 255
1417     self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, data)
1418
1419   def testTrailingDot(self):
1420     data = "a.b.c"
1421     self.failUnlessEqual(HostInfo.NormalizeName(data + "."), data)
1422
1423   def testInvalidName(self):
1424     data = [
1425       "a b",
1426       "a/b",
1427       ".a.b",
1428       "a..b",
1429       ]
1430     for value in data:
1431       self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, value)
1432
1433   def testValidName(self):
1434     data = [
1435       "a.b",
1436       "a-b",
1437       "a_b",
1438       "a.b.c",
1439       ]
1440     for value in data:
1441       HostInfo.NormalizeName(value)
1442
1443
1444 class TestParseAsn1Generalizedtime(unittest.TestCase):
1445   def test(self):
1446     # UTC
1447     self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1448     self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1449                      1266860512)
1450     self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1451                      (2**31) - 1)
1452
1453     # With offset
1454     self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1455                      1266860512)
1456     self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1457                      1266931012)
1458     self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1459                      1266931088)
1460     self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1461                      1266931295)
1462     self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1463                      3600)
1464
1465     # Leap seconds are not supported by datetime.datetime
1466     self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1467                       "19841231235960+0000")
1468     self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1469                       "19920630235960+0000")
1470
1471     # Errors
1472     self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1473     self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1474     self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1475                       "20100222174152")
1476     self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1477                       "Mon Feb 22 17:47:02 UTC 2010")
1478     self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1479                       "2010-02-22 17:42:02")
1480
1481
1482 class TestGetX509CertValidity(testutils.GanetiTestCase):
1483   def setUp(self):
1484     testutils.GanetiTestCase.setUp(self)
1485
1486     pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1487
1488     # Test whether we have pyOpenSSL 0.7 or above
1489     self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1490
1491     if not self.pyopenssl0_7:
1492       warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1493                     " function correctly")
1494
1495   def _LoadCert(self, name):
1496     return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1497                                            self._ReadTestData(name))
1498
1499   def test(self):
1500     validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1501     if self.pyopenssl0_7:
1502       self.assertEqual(validity, (1266919967, 1267524767))
1503     else:
1504       self.assertEqual(validity, (None, None))
1505
1506
1507 class TestMakedirs(unittest.TestCase):
1508   def setUp(self):
1509     self.tmpdir = tempfile.mkdtemp()
1510
1511   def tearDown(self):
1512     shutil.rmtree(self.tmpdir)
1513
1514   def testNonExisting(self):
1515     path = utils.PathJoin(self.tmpdir, "foo")
1516     utils.Makedirs(path)
1517     self.assert_(os.path.isdir(path))
1518
1519   def testExisting(self):
1520     path = utils.PathJoin(self.tmpdir, "foo")
1521     os.mkdir(path)
1522     utils.Makedirs(path)
1523     self.assert_(os.path.isdir(path))
1524
1525   def testRecursiveNonExisting(self):
1526     path = utils.PathJoin(self.tmpdir, "foo/bar/baz")
1527     utils.Makedirs(path)
1528     self.assert_(os.path.isdir(path))
1529
1530   def testRecursiveExisting(self):
1531     path = utils.PathJoin(self.tmpdir, "B/moo/xyz")
1532     self.assert_(not os.path.exists(path))
1533     os.mkdir(utils.PathJoin(self.tmpdir, "B"))
1534     utils.Makedirs(path)
1535     self.assert_(os.path.isdir(path))
1536
1537
1538 if __name__ == '__main__':
1539   testutils.GanetiTestProgram()