Implement replacing cluster certs and keys via “gnt-cluster renew-crypto”
[ganeti-local] / test / ganeti.utils_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the utils module"""
23
24 import unittest
25 import os
26 import time
27 import tempfile
28 import os.path
29 import os
30 import stat
31 import md5
32 import signal
33 import socket
34 import shutil
35 import re
36 import select
37 import string
38 import OpenSSL
39 import warnings
40 import distutils.version
41 import glob
42
43 import ganeti
44 import testutils
45 from ganeti import constants
46 from ganeti import utils
47 from ganeti import errors
48 from ganeti.utils import IsProcessAlive, RunCmd, \
49      RemoveFile, MatchNameComponent, FormatUnit, \
50      ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
51      ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
52      SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
53      TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
54      UnescapeAndSplit, RunParts, PathJoin, HostInfo
55
56 from ganeti.errors import LockError, UnitParseError, GenericError, \
57      ProgrammerError, OpPrereqError
58
59
class TestIsProcessAlive(unittest.TestCase):
  """Testing case for IsProcessAlive"""

  def testExists(self):
    """Our own process must be reported as alive"""
    mypid = os.getpid()
    self.assert_(IsProcessAlive(mypid),
                 "can't find myself running")

  def testNotExisting(self):
    """A child that has exited and been reaped must not be reported alive"""
    pid_non_existing = os.fork()
    if pid_non_existing == 0:
      # child: leave immediately without running unittest/atexit handlers
      os._exit(0)
    # os.fork() raises OSError on failure instead of returning a negative
    # value, so there is no error return code to check here.
    # Reap the child so its PID no longer refers to a (zombie) process.
    os.waitpid(pid_non_existing, 0)
    self.assert_(not IsProcessAlive(pid_non_existing),
                 "nonexisting process detected")
77
78
class TestPidFileFunctions(unittest.TestCase):
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""

  def setUp(self):
    self.dir = tempfile.mkdtemp()
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
    # Redirect PID files into our temporary directory. The original function
    # is remembered so tearDown can undo the monkey-patch; otherwise every
    # later test in the process would keep using the temporary path.
    self._orig_dpn = utils.DaemonPidFileName
    utils.DaemonPidFileName = self.f_dpn

  def testPidFileFunctions(self):
    """Write, read, re-write (must fail) and remove a PID file"""
    pid_file = self.f_dpn('test')
    utils.WritePidFile('test')
    self.failUnless(os.path.exists(pid_file),
                    "PID file should have been created")
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, os.getpid())
    self.failUnless(utils.IsProcessAlive(read_pid))
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for missing pid file")
    fh = open(pid_file, "w")
    fh.write("blah\n")
    fh.close()
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for invalid pid file")
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")

  def testKill(self):
    """A child that wrote a PID file can be found and killed"""
    pid_file = self.f_dpn('child')
    r_fd, w_fd = os.pipe()
    new_pid = os.fork()
    if new_pid == 0: # child
      status = 1
      try:
        utils.WritePidFile('child')
        os.write(w_fd, 'a')
        signal.pause()
        status = 0
      finally:
        # Never let the forked child return into the test runner, even if
        # one of the calls above raised an exception
        os._exit(status)
    # else we are in the parent
    # Drop our copy of the write end: should the child die before
    # signalling, os.read then sees EOF instead of blocking forever.
    os.close(w_fd)
    try:
      # wait until the child has written the pid file
      self.failUnlessEqual(os.read(r_fd, 1), 'a',
                           "child process did not signal readiness")
    finally:
      os.close(r_fd)
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, new_pid)
    self.failUnless(utils.IsProcessAlive(new_pid))
    utils.KillProcess(new_pid, waitpid=True)
    self.failIf(utils.IsProcessAlive(new_pid))
    utils.RemovePidFile('child')
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)

  def tearDown(self):
    # Undo the monkey-patch from setUp before removing the directory
    utils.DaemonPidFileName = self._orig_dpn
    shutil.rmtree(self.dir)
135
136
class TestRunCmd(testutils.GanetiTestCase):
  """Testing case for the RunCmd function"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.magic = time.ctime() + " ganeti test"
    self.fname = self._CreateTempFile()

  def testOk(self):
    """Test successful exit code"""
    result = RunCmd("/bin/sh -c 'exit 0'")
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.output, "")

  def testFail(self):
    """Test fail exit code"""
    result = RunCmd("/bin/sh -c 'exit 1'")
    self.assertEqual(result.exit_code, 1)
    self.assertEqual(result.output, "")

  def testStdout(self):
    """Test standard output"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.stdout, self.magic)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testStderr(self):
    """Test standard error"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
    self.assertEqual(result.stderr, self.magic)
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testCombined(self):
    """Test combined output"""
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
    expected = "A" + self.magic + "B" + self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.output, expected)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, expected)

  def testSignal(self):
    """Test signal"""
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
    self.assertEqual(result.signal, 15)
    self.assertEqual(result.output, "")

  def testListRun(self):
    """Test list runs"""
    result = RunCmd(["true"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 1)
    result = RunCmd(["echo", "-n", self.magic])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.stdout, self.magic)

  def testFileEmptyOutput(self):
    """Test file output"""
    result = RunCmd(["true"], output=self.fname)
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertFileContent(self.fname, "")

  def testLang(self):
    """Test locale environment"""
    # Save only the variables this test overrides and restore them in place
    # afterwards. Rebinding os.environ to os.environ.copy() (a plain dict)
    # would leave the modified values in the real process environment and
    # replace os.environ itself with that plain dict.
    saved = {}
    for name in ("LANG", "LC_ALL"):
      saved[name] = os.environ.get(name)
    try:
      os.environ["LANG"] = "en_US.UTF-8"
      os.environ["LC_ALL"] = "en_US.UTF-8"
      result = RunCmd(["locale"])
      for line in result.output.splitlines():
        key, value = line.split("=", 1)
        # Ignore these variables, they're overridden by LC_ALL
        if key == "LANG" or key == "LANGUAGE":
          continue
        self.failIf(value and value != "C" and value != '"C"',
            "Variable %s is set to the invalid value '%s'" % (key, value))
    finally:
      for name, old_value in saved.items():
        if old_value is not None:
          os.environ[name] = old_value
        elif name in os.environ:
          del os.environ[name]

  def testDefaultCwd(self):
    """Test default working directory"""
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")

  def testCwd(self):
    """Test explicitly specified working directory"""
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
    cwd = os.getcwd()
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)

  def testResetEnv(self):
    """Test environment reset functionality"""
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
242
243
class TestRunParts(unittest.TestCase):
  """Tests for RunParts, run against a private temporary directory"""

  def setUp(self):
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")

  def tearDown(self):
    shutil.rmtree(self.rundir)

  def _MakeFile(self, name, data, executable=False):
    """Write C{data} to C{name} inside the run directory.

    When C{executable} is true the file mode is set to owner read + execute.
    Returns the full path of the created file.

    """
    path = os.path.join(self.rundir, name)
    utils.WriteFile(path, data=data)
    if executable:
      os.chmod(path, stat.S_IREAD | stat.S_IEXEC)
    return path

  def _RunAll(self):
    """Invoke RunParts on the run directory with a reset environment"""
    return RunParts(self.rundir, reset_env=True)

  def testEmpty(self):
    """An empty directory produces no results"""
    self.failUnlessEqual(self._RunAll(), [])

  def testSkipWrongName(self):
    """An executable whose name contains a dot is skipped"""
    path = self._MakeFile("00test.dot", "", executable=True)
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.failUnlessEqual(self._RunAll(), expected)

  def testSkipNonExec(self):
    """A file without the execute bit is skipped"""
    path = self._MakeFile("00test", "")
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.failUnlessEqual(self._RunAll(), expected)

  def testError(self):
    """An empty executable file is reported as RUNPARTS_ERR"""
    path = self._MakeFile("00test", "", executable=True)
    relname, status, error = self._RunAll()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
    self.failUnless(error)

  def testSorted(self):
    """Results come back in sorted file-name order"""
    names = ["64test", "00test", "42test"]
    for name in names:
      self._MakeFile(name, "")

    results = self._RunAll()

    for name in sorted(names):
      self.failUnlessEqual(name, results.pop(0)[0])

  def testOk(self):
    """A working script is run and its output captured"""
    path = self._MakeFile("00test", "#!/bin/sh\n\necho -n ciao",
                          executable=True)
    relname, status, runresult = self._RunAll()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.stdout, "ciao")

  def testRunFail(self):
    """A script exiting non-zero is run and marked as failed"""
    path = self._MakeFile("00test", "#!/bin/sh\n\nexit 1", executable=True)
    relname, status, runresult = self._RunAll()[0]
    self.failUnlessEqual(relname, os.path.basename(path))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.exit_code, 1)
    self.failUnless(runresult.failed)

  def testRunMix(self):
    """Failing, skipped, broken and working scripts in one directory"""
    # The names below are already in sorted order
    failing = self._MakeFile("00test", "#!/bin/sh\n\nexit 1",
                             executable=True)
    skipped = self._MakeFile("42test", "")
    broken = self._MakeFile("64test", "", executable=True)
    working = self._MakeFile("99test", "#!/bin/sh\n\necho -n ciao",
                             executable=True)

    results = self._RunAll()

    relname, status, runresult = results[0]
    self.failUnlessEqual(relname, os.path.basename(failing))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.exit_code, 1)
    self.failUnless(runresult.failed)

    relname, status, runresult = results[1]
    self.failUnlessEqual(relname, os.path.basename(skipped))
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
    self.failUnlessEqual(runresult, None)

    relname, status, runresult = results[2]
    self.failUnlessEqual(relname, os.path.basename(broken))
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
    self.failUnless(runresult)

    relname, status, runresult = results[3]
    self.failUnlessEqual(relname, os.path.basename(working))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.output, "ciao")
    self.failUnlessEqual(runresult.exit_code, 0)
    self.failUnless(not runresult.failed)
368
369
class TestRemoveFile(unittest.TestCase):
  """Test case for the RemoveFile function"""

  def setUp(self):
    """Create a temp dir and file for each case"""
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
    os.close(fd)

  def tearDown(self):
    # rmtree also removes anything a failing test left behind (such as the
    # symlink from testRemoveSymlink), which a plain os.rmdir could not
    shutil.rmtree(self.tmpdir)

  def testIgnoreDirs(self):
    """Test that RemoveFile() ignores directories"""
    self.assertEqual(None, RemoveFile(self.tmpdir))

  def testIgnoreNotExisting(self):
    """Test that RemoveFile() ignores non-existing files"""
    RemoveFile(self.tmpfile)
    RemoveFile(self.tmpfile)

  def testRemoveFile(self):
    """Test that RemoveFile does remove a file"""
    RemoveFile(self.tmpfile)
    if os.path.exists(self.tmpfile):
      self.fail("File '%s' not removed" % self.tmpfile)

  def testRemoveSymlink(self):
    """Test that RemoveFile does remove symlinks"""
    symlink = os.path.join(self.tmpdir, "symlink")
    # Dangling symlink: os.path.exists() follows the link and reports False
    # even if the link itself is still present, so use os.path.lexists()
    os.symlink("no-such-file", symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
    # Symlink to an existing file: only the link may go, not its target
    os.symlink(self.tmpfile, symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
    self.failUnless(os.path.exists(self.tmpfile),
                    "Symlink target '%s' was removed" % self.tmpfile)
414
415
class TestRename(unittest.TestCase):
  """Test case for RenameFile"""

  def setUp(self):
    """Create a temporary directory"""
    self.tmpdir = tempfile.mkdtemp()
    self.tmpfile = os.path.join(self.tmpdir, "test1")

    # Touch the file
    open(self.tmpfile, "w").close()

  def tearDown(self):
    """Remove temporary directory"""
    shutil.rmtree(self.tmpdir)

  def testSimpleRename1(self):
    """Simple rename 1"""
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
    # A rename must not leave the source behind (a copy would)
    self.failIf(os.path.exists(self.tmpfile))

  def testSimpleRename2(self):
    """Simple rename 2"""
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
                     mkdir=True)
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
    self.failIf(os.path.exists(self.tmpfile))

  def testRenameMkdir(self):
    """Rename with mkdir"""
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
                     mkdir=True)
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
    self.failIf(os.path.exists(self.tmpfile))

    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
                     mkdir=True)
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
    self.failIf(os.path.exists(os.path.join(self.tmpdir, "test/xyz")))
455
456
class TestMatchNameComponent(unittest.TestCase):
  """Test case for the MatchNameComponent function"""

  def testEmptyList(self):
    """Nothing ever matches against an empty candidate list"""
    for key in ("", "test"):
      self.failUnlessEqual(MatchNameComponent(key, []), None)

  def testSingleMatch(self):
    """Any dotted prefix of exactly one candidate selects that candidate"""
    names = ["test1.example.com", "test2.example.com", "test3.example.com"]
    wanted = names[1]
    for key in ("test2", "test2.example", "test2.example.com"):
      self.failUnlessEqual(MatchNameComponent(key, names), wanted)

  def testMultipleMatches(self):
    """A prefix shared by several candidates is ambiguous and yields None"""
    names = ["test1.example.com", "test1.example.org", "test1.example.net"]
    for key in ("test1", "test1.example"):
      self.failUnlessEqual(MatchNameComponent(key, names), None)

  def testFullMatch(self):
    """An exact match wins even if it is also a prefix of another name"""
    short_key = "test1"
    full_key = "test1.example"
    names = [full_key, full_key + ".com"]
    self.failUnlessEqual(MatchNameComponent(short_key, names), None)
    self.failUnlessEqual(MatchNameComponent(full_key, names), full_key)

  def testCaseInsensitivePartialMatch(self):
    """Test for the case_insensitive keyword"""
    names = ["test1.example.com", "test2.example.net"]
    for key in ("test2", "Test2", "teSt2", "TeSt2"):
      self.assertEqual(MatchNameComponent(key, names, case_sensitive=False),
                       "test2.example.net")

  def testCaseInsensitiveFullMatch(self):
    """Exact-name matching when case_sensitive is False"""
    names = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
    expectations = [
      # "Ts1" is a prefix of both ts1 entries, hence ambiguous; the full
      # name, in any case, still selects "ts1.ex" exactly
      ("Ts1", None),
      ("Ts1.ex", "ts1.ex"),
      ("ts1.ex", "ts1.ex"),
      # The two ts2 entries differ only in case: only a case-exact key
      # resolves the tie, any other spelling stays ambiguous
      ("ts2.ex", "ts2.ex"),
      ("Ts2.ex", "Ts2.ex"),
      ("TS2.ex", None),
      ]
    for key, expected in expectations:
      self.assertEqual(MatchNameComponent(key, names, case_sensitive=False),
                       expected)
515
516
class TestTimestampForFilename(unittest.TestCase):
  """Tests for utils.TimestampForFilename"""

  def test(self):
    """The generated timestamp contains neither '.' nor ':'"""
    for forbidden in (".", ":"):
      self.assert_(forbidden not in utils.TimestampForFilename())
521
522
class TestCreateBackup(testutils.GanetiTestCase):
  """Tests for utils.CreateBackup"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def _CountMatching(self, filename):
    """Return the number of paths starting with C{filename}"""
    return len(glob.glob("%s*" % filename))

  def testEmpty(self):
    """Back up an empty file repeatedly; a FIFO raises ProgrammerError"""
    filename = utils.PathJoin(self.tmpdir, "config.data")
    utils.WriteFile(filename, data="")
    self.assertFileContent(utils.CreateBackup(filename), "")
    self.assertEqual(self._CountMatching(filename), 2)
    for expected_count in (3, 4):
      utils.CreateBackup(filename)
      self.assertEqual(self._CountMatching(filename), expected_count)

    fifoname = utils.PathJoin(self.tmpdir, "fifo")
    os.mkfifo(fifoname)
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)

  def testContent(self):
    """Each backup holds the file's content at the time it was taken"""
    filename = utils.PathJoin(self.tmpdir, "test.data_")
    total_backups = 0
    for chunk in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
      for repeat in [1, 2, 10, 127]:
        payload = chunk * repeat
        utils.WriteFile(filename, data=payload)
        self.assertFileContent(filename, payload)

        for _ in range(3):
          backup = utils.CreateBackup(filename)
          total_backups += 1
          self.assertFileContent(backup, payload)
          # the original file plus every backup taken so far
          self.assertEqual(self._CountMatching(filename), 1 + total_backups)
564
565
class TestFormatUnit(unittest.TestCase):
  """Test case for the FormatUnit function"""

  def _Check(self, units, pairs):
    """Assert FormatUnit(value, units) == text for each (value, text)"""
    for value, text in pairs:
      self.assertEqual(FormatUnit(value, units), text)

  def testMiB(self):
    """Values below 1024 in 'h' mode, and the 'm' unit"""
    self._Check('h', [(1, '1M'), (100, '100M'), (1023, '1023M')])
    self._Check('m', [(1, '1'), (100, '100'), (1023, '1023'),
                      (1024, '1024'), (1536, '1536'), (17133, '17133'),
                      (1024 * 1024 - 1, '1048575')])

  def testGiB(self):
    """Gibibyte-range values in 'h' mode, and the 'g' unit"""
    self._Check('h', [(1024, '1.0G'), (1536, '1.5G'), (17133, '16.7G'),
                      (1024 * 1024 - 1, '1024.0G')])
    self._Check('g', [(1024, '1.0'), (1536, '1.5'), (17133, '16.7'),
                      (1024 * 1024 - 1, '1024.0'),
                      (1024 * 1024, '1024.0'), (5120 * 1024, '5120.0'),
                      (29829 * 1024, '29829.0')])

  def testTiB(self):
    """Tebibyte-range values in 'h' mode, and the 't' unit"""
    self._Check('h', [(1024 * 1024, '1.0T'), (5120 * 1024, '5.0T'),
                      (29829 * 1024, '29.1T')])
    self._Check('t', [(1024 * 1024, '1.0'), (5120 * 1024, '5.0'),
                      (29829 * 1024, '29.1')])
606
class TestParseUnit(unittest.TestCase):
  """Test case for the ParseUnit function"""

  # (suffix, multiplier relative to the unsuffixed value)
  SCALES = (('', 1),
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))

  def _Check(self, pairs):
    """Assert ParseUnit(text) == value for each (text, value)"""
    for text, value in pairs:
      self.assertEqual(ParseUnit(text), value)

  def testRounding(self):
    """Integer inputs are rounded up to the next multiple of 4"""
    self._Check([('0', 0), ('1', 4), ('2', 4), ('3', 4),
                 ('124', 124), ('125', 128), ('126', 128), ('127', 128),
                 ('128', 128), ('129', 132), ('130', 132)])

  def testFloating(self):
    """Fractional inputs, with and without a unit suffix"""
    self._Check([('0', 0), ('0.5', 4), ('1.75', 4), ('1.99', 4),
                 ('2.00', 4), ('2.01', 4), ('3.99', 4), ('4.00', 4),
                 ('4.01', 8), ('1.5G', 1536), ('1.8G', 1844),
                 ('8.28T', 8682212)])

  def testSuffixes(self):
    """Every suffix, in any case and after any whitespace, scales 1024"""
    case_variants = (lambda x: x, str.lower, str.upper)
    for whitespace in ('', ' ', '   ', "\t", "\t "):
      for suffix, multiplier in self.SCALES:
        for variant in case_variants:
          text = '1024' + whitespace + variant(suffix)
          self.assertEqual(ParseUnit(text), 1024 * multiplier)

  def testInvalidInput(self):
    """Bad separators and comma decimals raise UnitParseError"""
    invalid = []
    for separator in ('-', '_', ',', 'a'):
      for suffix, _ in self.SCALES:
        invalid.append('1' + separator + suffix)
    for suffix, _ in self.SCALES:
      invalid.append('1,3' + suffix)

    for text in invalid:
      self.assertRaises(UnitParseError, ParseUnit, text)
657
658
class TestSshKeys(testutils.GanetiTestCase):
  """Tests for AddAuthorizedKey and RemoveAuthorizedKey"""

  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    fh = open(self.tmpname, 'w')
    try:
      fh.write("%s\n%s\n" % (self.KEY_A, self.KEY_B))
    finally:
      fh.close()

  def _CheckKeys(self, *keys):
    """Assert the key file holds exactly C{keys}, one per line, in order"""
    expected = "".join(["%s\n" % key for key in keys])
    self.assertFileContent(self.tmpname, expected)

  def testAddingNewKey(self):
    """A key not yet present is appended after the existing ones"""
    new_key = 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test'
    AddAuthorizedKey(self.tmpname, new_key)
    self._CheckKeys(self.KEY_A, self.KEY_B, new_key)

  def testAddingAlmostButNotCompletelyTheSameKey(self):
    """Same key material with a different comment counts as a new key"""
    similar_key = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test'
    AddAuthorizedKey(self.tmpname, similar_key)
    self._CheckKeys(self.KEY_A, self.KEY_B, similar_key)

  def testAddingExistingKeyWithSomeMoreSpaces(self):
    """Extra whitespace does not make an existing key look new"""
    AddAuthorizedKey(self.tmpname,
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
    self._CheckKeys(self.KEY_A, self.KEY_B)

  def testRemovingExistingKeyWithSomeMoreSpaces(self):
    """Extra whitespace in the argument still removes the stored key"""
    RemoveAuthorizedKey(self.tmpname,
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
    self._CheckKeys(self.KEY_B)

  def testRemovingNonExistingKey(self):
    """Removing an absent key leaves the file unchanged"""
    RemoveAuthorizedKey(self.tmpname,
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
    self._CheckKeys(self.KEY_A, self.KEY_B)
720
721
class TestEtcHosts(testutils.GanetiTestCase):
  """Tests for SetEtcHostsEntry and RemoveEtcHostsEntry"""

  # Expected mode after every modification: 0644 (rw-r--r--)
  HOSTS_MODE = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH

  # Lines of the initial test file
  HEADER = "# This is a test file for /etc/hosts\n"
  LOCALHOST = "127.0.0.1\tlocalhost\n"
  ROUTER = "192.168.1.1 router gw\n"

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    fh = open(self.tmpname, 'w')
    try:
      fh.write(self.HEADER + self.LOCALHOST + self.ROUTER)
    finally:
      fh.close()

  def _CheckHosts(self, *lines):
    """Assert the file consists of C{lines} and has the expected mode"""
    self.assertFileContent(self.tmpname, "".join(lines))
    self.assertFileMode(self.tmpname, self.HOSTS_MODE)

  def testSettingNewIp(self):
    """An entry for an unknown address is appended"""
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
    self._CheckHosts(self.HEADER, self.LOCALHOST, self.ROUTER,
                     "1.2.3.4\tmyhost.domain.tld myhost\n")

  def testSettingExistingIp(self):
    """An entry for a known address replaces the existing line"""
    SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
                     ['myhost'])
    self._CheckHosts(self.HEADER, self.LOCALHOST,
                     "192.168.1.1\tmyhost.domain.tld myhost\n")

  def testSettingDuplicateName(self):
    """An alias equal to the host name is written only once"""
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
    self._CheckHosts(self.HEADER, self.LOCALHOST, self.ROUTER,
                     "1.2.3.4\tmyhost\n")

  def testRemovingExistingHost(self):
    """Removing a primary name keeps the remaining aliases"""
    RemoveEtcHostsEntry(self.tmpname, 'router')
    self._CheckHosts(self.HEADER, self.LOCALHOST, "192.168.1.1 gw\n")

  def testRemovingSingleExistingHost(self):
    """Removing a line's only name drops the whole line"""
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
    self._CheckHosts(self.HEADER, self.ROUTER)

  def testRemovingNonExistingHost(self):
    """Removing an unknown name leaves the content untouched"""
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
    self._CheckHosts(self.HEADER, self.LOCALHOST, self.ROUTER)

  def testRemovingAlias(self):
    """Removing an alias keeps the primary name"""
    RemoveEtcHostsEntry(self.tmpname, 'gw')
    self._CheckHosts(self.HEADER, self.LOCALHOST, "192.168.1.1 router\n")
800
801
class TestShellQuoting(unittest.TestCase):
  """Tests for ShellQuote and ShellQuoteArgs"""

  def testShellQuote(self):
    """Single arguments are quoted only when needed"""
    cases = [
      ('abc', "abc"),
      ('ab"c', "'ab\"c'"),
      ("a'bc", "'a'\\''bc'"),
      ("a b c", "'a b c'"),
      ("a b\\ c", "'a b\\ c'"),
      ]
    for raw, quoted in cases:
      self.assertEqual(ShellQuote(raw), quoted)

  def testShellQuoteArgs(self):
    """Argument lists are quoted element-wise and joined with spaces"""
    cases = [
      (['a', 'b', 'c'], "a b c"),
      (['a', 'b"', 'c'], "a 'b\"' c"),
      (['a', 'b\'', 'c'], "a 'b'\\\''' c"),
      ]
    for args, quoted in cases:
      self.assertEqual(ShellQuoteArgs(args), quoted)
816
817
class TestTcpPing(unittest.TestCase):
  """Testcase for TCP version of ping - against listen(2)ing port"""

  def setUp(self):
    # Listen on an ephemeral loopback port so TcpPing has a live port to
    # connect to
    self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
    self.listenerport = self.listener.getsockname()[1]
    self.listener.listen(1)

  def tearDown(self):
    # Close the socket explicitly instead of relying on garbage collection.
    # shutdown() is avoided here: it does not release the descriptor, and on
    # some (e.g. BSD-derived) platforms it fails with ENOTCONN when called
    # on an unconnected, listening socket.
    self.listener.close()
    del self.listener
    del self.listenerport

  def testTcpPingToLocalHostAccept(self):
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.listenerport,
                         timeout=10,
                         live_port_needed=True,
                         source=constants.LOCALHOST_IP_ADDRESS,
                         ),
                 "failed to connect to test listener")

    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.listenerport,
                         timeout=10,
                         live_port_needed=True,
                         ),
                 "failed to connect to test listener (no source)")
847
848
class TestTcpPingDeaf(unittest.TestCase):
  """Testcase for TCP version of ping - against non listen(2)ing port"""

  def setUp(self):
    # Bound but never listen(2)ing: connections to this port are refused
    self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
    self.deaflistenerport = self.deaflistener.getsockname()[1]

  def tearDown(self):
    # Release the descriptor explicitly rather than relying on garbage
    # collection to close the socket
    self.deaflistener.close()
    del self.deaflistener
    del self.deaflistenerport

  def testTcpPingToLocalHostAcceptDeaf(self):
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                        self.deaflistenerport,
                        timeout=constants.TCP_PING_TIMEOUT,
                        live_port_needed=True,
                        source=constants.LOCALHOST_IP_ADDRESS,
                        ), # need successful connect(2)
                "successfully connected to deaf listener")

    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                        self.deaflistenerport,
                        timeout=constants.TCP_PING_TIMEOUT,
                        live_port_needed=True,
                        ), # need successful connect(2)
                "successfully connected to deaf listener (no source addr)")

  def testTcpPingToLocalHostNoAccept(self):
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         source=constants.LOCALHOST_IP_ADDRESS,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port")

    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port (no source addr)")
892
893
class TestOwnIpAddress(unittest.TestCase):
  """Tests for OwnIpAddress"""

  def testOwnLoopback(self):
    """check having the loopback ip"""
    owned = OwnIpAddress(constants.LOCALHOST_IP_ADDRESS)
    self.failUnless(owned, "Should own the loopback address")

  def testNowOwnAddress(self):
    """check that I don't own an address"""
    # 192.0.2.0/24 is reserved for test/documentation use by RFC 3330, so no
    # local interface should carry an address from it; if this ever fails,
    # the test should be extended to try several addresses
    test_net_ip = "192.0.2.1"
    self.failIf(OwnIpAddress(test_net_ip),
                "Should not own IP address %s" % test_net_ip)
910
911
class TestListVisibleFiles(unittest.TestCase):
  """Tests for ListVisibleFiles"""

  def setUp(self):
    self.path = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.path)

  def _test(self, files, expected):
    """Create C{files} in the test directory and check the visible ones.

    The listing and C{expected} are compared as sorted lists; the caller's
    C{expected} list itself is left unmodified.

    """
    for name in files:
      handle = open(os.path.join(self.path, name), 'w')
      try:
        handle.write("Test\n")
      finally:
        handle.close()

    self.assertEqual(sorted(ListVisibleFiles(self.path)), sorted(expected))

  def testAllVisible(self):
    names = ["a", "b", "c"]
    self._test(names, names)

  def testNoneVisible(self):
    # Names starting with a dot are hidden
    self._test([".a", ".b", ".c"], [])

  def testSomeVisible(self):
    self._test(["a", "b", ".c"], ["a", "b"])

  def testNonAbsolutePath(self):
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")

  def testNonNormalizedPath(self):
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
                          "/bin/../tmp")
959
960
class TestNewUUID(unittest.TestCase):
  """Tests for NewUUID"""

  # Lowercase hexadecimal digits in the 8-4-4-4-12 grouping
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
                        '[a-f0-9]{4}-[a-f0-9]{12}$')

  def runTest(self):
    generated = utils.NewUUID()
    self.failUnless(self._re_uuid.match(generated))
969
970
class TestUniqueSequence(unittest.TestCase):
  """Test case for UniqueSequence"""

  def _test(self, seq, expected):
    """Check that UniqueSequence(C{seq}) returns exactly C{expected}.

    Comparing against a list checks both the de-duplication and the order
    of the result.

    """
    # "seq" instead of "input" avoids shadowing the builtin of that name
    self.assertEqual(utils.UniqueSequence(seq), expected)

  def runTest(self):
    # Ordered input
    self._test([1, 2, 3], [1, 2, 3])
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
    self._test([1, 2, 2, 3], [1, 2, 3])
    self._test([1, 2, 3, 3], [1, 2, 3])

    # Unordered input
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])

    # Strings
    self._test(["a", "a"], ["a"])
    self._test(["a", "b"], ["a", "b"])
    self._test(["a", "b", "a"], ["a", "b"])
992
993
class TestFirstFree(unittest.TestCase):
  """Tests for the FirstFree function"""

  def test(self):
    """Test FirstFree"""
    # (sequence, keyword arguments, expected result)
    cases = [
      ([0, 1, 3], {}, 2),
      ([], {}, None),
      ([3, 4, 6], {}, 0),
      ([3, 4, 6], {"base": 3}, 5),
      ]
    for (seq, kwargs, result) in cases:
      self.failUnlessEqual(FirstFree(seq, **kwargs), result)

    # A sequence not starting at the given base is rejected
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1004
1005
class TestTailFile(testutils.GanetiTestCase):
  """Test case for the TailFile function"""

  def _WriteTempFile(self, content):
    """Create a new temporary file containing C{content}.

    The file handle is closed even if writing fails.

    @return: the name of the created file

    """
    fname = self._CreateTempFile()
    fd = open(fname, "w")
    try:
      fd.write(content)
    finally:
      fd.close()
    return fname

  def testEmpty(self):
    fname = self._CreateTempFile()
    self.failUnlessEqual(TailFile(fname), [])
    self.failUnlessEqual(TailFile(fname, lines=25), [])

  def testAllLines(self):
    data = ["test %d" % i for i in range(30)]
    for i in range(30):
      # Every line (if any) is newline-terminated; i == 0 gives an empty file
      fname = self._WriteTempFile("".join(["%s\n" % line
                                           for line in data[:i]]))
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])

  def testPartialLines(self):
    data = ["test %d" % i for i in range(30)]
    fname = self._WriteTempFile("\n".join(data) + "\n")
    for i in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])

  def testBigFile(self):
    data = ["test %d" % i for i in range(30)]
    # A 1 MiB line precedes the data, so the tail is far from the file start
    fname = self._WriteTempFile("X" * 1048576 + "\n" + "\n".join(data) + "\n")
    for i in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1046
1047
class _BaseFileLockTest:
  """Test case for the FileLock class.

  This is a mixin: concrete test classes must set C{self.lock} to the
  L{utils.FileLock} under test and C{self.tmpfile} to an object whose
  C{name} attribute is the path of the locked file.

  """

  def testSharedNonblocking(self):
    self.lock.Shared(blocking=False)
    self.lock.Close()

  def testExclusiveNonblocking(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Close()

  def testUnlockNonblocking(self):
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSharedBlocking(self):
    self.lock.Shared(blocking=True)
    self.lock.Close()

  def testExclusiveBlocking(self):
    self.lock.Exclusive(blocking=True)
    self.lock.Close()

  def testUnlockBlocking(self):
    self.lock.Unlock(blocking=True)
    self.lock.Close()

  def testSharedExclusiveUnlock(self):
    self.lock.Shared(blocking=False)
    self.lock.Exclusive(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testExclusiveSharedUnlock(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Shared(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSimpleTimeout(self):
    # These will succeed on the first attempt, hence a short timeout
    self.lock.Shared(blocking=True, timeout=10.0)
    self.lock.Exclusive(blocking=False, timeout=10.0)
    self.lock.Unlock(blocking=True, timeout=10.0)
    self.lock.Close()

  @staticmethod
  def _TryLockInner(filename, shared, blocking):
    """Try to lock C{filename} through a newly opened lock object.

    @param filename: path of the file to lock
    @param shared: whether to request a shared (instead of exclusive) lock
    @param blocking: passed through to the locking method
    @return: True if the lock was acquired, False on L{errors.LockError}

    """
    lock = utils.FileLock.Open(filename)

    if shared:
      fn = lock.Shared
    else:
      fn = lock.Exclusive

    try:
      # The timeout doesn't really matter as the parent process waits for us to
      # finish anyway.
      fn(blocking=blocking, timeout=0.01)
    except errors.LockError:
      return False

    return True

  def _TryLock(self, *args):
    """Run L{_TryLockInner} on the test file in a separate process.

    @return: the child's result (whether the lock could be acquired)

    """
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
                                      *args)

  def testTimeout(self):
    for blocking in [True, False]:
      # While we hold the lock exclusively, nobody else may lock it
      self.lock.Exclusive(blocking=True)
      self.failIf(self._TryLock(False, blocking))
      self.failIf(self._TryLock(True, blocking))

      # While we hold it shared, only other shared locks are possible
      self.lock.Shared(blocking=True)
      self.assert_(self._TryLock(True, blocking))
      self.failIf(self._TryLock(False, blocking))

    # Release the lock and its file like all the other tests do
    self.lock.Close()

  def testCloseShared(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)

  def testCloseExclusive(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)

  def testCloseUnlock(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1137
1138
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
  """Runs the FileLock tests on a lock opened by file name"""

  TESTDATA = "Hello World\n" * 10

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpfile = tempfile.NamedTemporaryFile()
    path = self.tmpfile.name
    utils.WriteFile(path, data=self.TESTDATA)
    self.lock = utils.FileLock.Open(path)

    # Opening the lock must not have truncated the file
    self.assertFileContent(path, self.TESTDATA)

  def tearDown(self):
    # No lock operation may have modified the file either
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)

    testutils.GanetiTestCase.tearDown(self)
1156
1157
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
  """Runs the FileLock tests on a lock built from an open file object"""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()
    path = self.tmpfile.name
    self.lock = utils.FileLock(open(path, "w"), path)
1162
1163
class TestTimeFunctions(unittest.TestCase):
  """Tests for SplitTime and MergeTime"""

  def runTest(self):
    # SplitTime: seconds -> (seconds, microseconds)
    split_cases = [
      (1, (1, 0)),
      (1.5, (1, 500000)),
      (1218448917.4809151, (1218448917, 480915)),
      (123.48012, (123, 480120)),
      (123.9996, (123, 999600)),
      (123.9995, (123, 999500)),
      (123.9994, (123, 999400)),
      (123.999999999, (123, 999999)),
      ]
    for (value, expected) in split_cases:
      self.assertEqual(utils.SplitTime(value), expected)

    self.assertRaises(AssertionError, utils.SplitTime, -1)

    # MergeTime: (seconds, microseconds) -> seconds
    exact_merge_cases = [
      ((1, 0), 1.0),
      ((1, 500000), 1.5),
      ((1218448917, 500000), 1218448917.5),
      ]
    for (value, expected) in exact_merge_cases:
      self.assertEqual(utils.MergeTime(value), expected)

    # These results are compared after rounding to three decimals
    rounded_merge_cases = [
      ((1218448917, 481000), 1218448917.481),
      ((1, 801000), 1.801),
      ]
    for (value, expected) in rounded_merge_cases:
      self.assertEqual(round(utils.MergeTime(value), 3), expected)

    invalid_merge_inputs = [(0, -1), (0, 1000000), (0, 9999999), (-1, 0),
                            (-9999, 0)]
    for value in invalid_merge_inputs:
      self.assertRaises(AssertionError, utils.MergeTime, value)
1192
1193
class FieldSetTestCase(unittest.TestCase):
  """Tests for FieldSet"""

  def testSimpleMatch(self):
    fields = utils.FieldSet("a", "b", "c", "def")
    self.failUnless(fields.Matches("a"))
    # Only whole names match, neither substrings nor extensions of them
    self.failIf(fields.Matches("d"), "Substring matched")
    self.failIf(fields.Matches("defghi"), "Prefix string matched")
    self.failIf(fields.NonMatching(["b", "c"]))
    self.failIf(fields.NonMatching(["a", "b", "c", "def"]))
    self.failUnless(fields.NonMatching(["a", "d"]))

  def testRegexMatch(self):
    fields = utils.FieldSet("a", "b([0-9]+)", "c")
    for name in ["b1", "b99"]:
      self.failUnless(fields.Matches(name))
    self.failIf(fields.Matches("b/1"))
    self.failIf(fields.NonMatching(["b12", "c"]))
    self.failUnless(fields.NonMatching(["a", "1"]))
1213
class TestForceDictType(unittest.TestCase):
  """Test case for ForceDictType"""

  def setUp(self):
    self.key_types = {
      'a': constants.VTYPE_INT,
      'b': constants.VTYPE_BOOL,
      'c': constants.VTYPE_STRING,
      'd': constants.VTYPE_SIZE,
      }

  def _fdt(self, data, allowed_values=None):
    """Run ForceDictType on C{data} and return the (converted) dict.

    C{allowed_values} is only passed on when given, so that the function's
    own default is used otherwise.

    """
    # "data" instead of "dict" avoids shadowing the builtin
    if allowed_values is None:
      ForceDictType(data, self.key_types)
    else:
      ForceDictType(data, self.key_types, allowed_values=allowed_values)

    return data

  def testSimpleDict(self):
    self.assertEqual(self._fdt({}), {})
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})

  def testErrors(self):
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1252
1253
class TestIsAbsNormPath(unittest.TestCase):
  """Testing case for IsNormAbsPath"""

  def _pathTestHelper(self, path, result):
    """Assert that IsNormAbsPath(C{path}) is true exactly when C{result} is.

    @param path: the path to check
    @param result: whether C{path} is expected to be absolute and normalized

    """
    if result:
      self.assert_(IsNormAbsPath(path),
                   "Path %s should result absolute and normalized" % path)
    else:
      self.assert_(not IsNormAbsPath(path),
                   "Path %s should not result absolute and normalized" % path)

  def testBase(self):
    self._pathTestHelper('/etc', True)
    self._pathTestHelper('/srv', True)
    # Relative, containing "..", or with a trailing slash: all rejected
    self._pathTestHelper('etc', False)
    self._pathTestHelper('/etc/../root', False)
    self._pathTestHelper('/etc/', False)
1271
1272
class TestSafeEncode(unittest.TestCase):
  """Test case for SafeEncode"""

  def testAscii(self):
    # Digits, letters and punctuation must pass through unchanged
    for txt in [string.digits, string.letters, string.punctuation]:
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testDoubleEncode(self):
    # Encoding must be idempotent for every byte value; range(256) covers
    # chr(0) through chr(255) inclusive (range(255) used to skip chr(255))
    for i in range(256):
      txt = SafeEncode(chr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testUnicode(self):
    # 1024 is high enough to catch non-direct ASCII mappings
    for i in range(1024):
      txt = SafeEncode(unichr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))
1290
1291
class TestFormatTime(unittest.TestCase):
  """Tests for FormatTime"""

  def testNone(self):
    self.failUnlessEqual(FormatTime(None), "N/A")

  def testInvalid(self):
    # An empty tuple is not a usable timestamp either
    self.failUnlessEqual(FormatTime(()), "N/A")

  def testNow(self):
    # Both float (time.time) and integer timestamps must be accepted
    # without raising
    now = time.time()
    FormatTime(now)
    FormatTime(int(now))
1306
1307
class RunInSeparateProcess(unittest.TestCase):
  """Tests for utils.RunInSeparateProcess"""

  def test(self):
    for expected in [True, False]:
      # The child's return value must be handed back to the caller
      def _child():
        return expected

      self.assertEqual(expected, utils.RunInSeparateProcess(_child))

  def testArgs(self):
    for value in [0, 1, 999, "Hello World", (1, 2, 3)]:
      # Extra positional arguments must reach the child function
      def _child(carg1, carg2):
        return carg1 == "Foo" and carg2 == value

      self.assert_(utils.RunInSeparateProcess(_child, "Foo", value))

  def testPid(self):
    # The function must not run in the calling process
    parent_pid = os.getpid()

    def _check():
      return os.getpid() == parent_pid

    self.failIf(utils.RunInSeparateProcess(_check))

  def testSignal(self):
    # A child killed by a signal must make the call raise GenericError
    def _kill():
      os.kill(os.getpid(), signal.SIGTERM)

    self.assertRaises(errors.GenericError,
                      utils.RunInSeparateProcess, _kill)

  def testException(self):
    # An exception in the child must make the call raise GenericError
    def _exc():
      raise errors.GenericError("This is a test")

    self.assertRaises(errors.GenericError,
                      utils.RunInSeparateProcess, _exc)
1344
1345
class TestFingerprintFile(unittest.TestCase):
  """Tests for utils._FingerprintFile"""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()

  def test(self):
    path = self.tmpfile.name

    # Freshly created, empty file: SHA-1 digest of empty input
    self.assertEqual(utils._FingerprintFile(path),
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")

    utils.WriteFile(path, data="Hello World\n")
    self.assertEqual(utils._FingerprintFile(path),
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1357
1358
class TestUnescapeAndSplit(unittest.TestCase):
  """Tests for UnescapeAndSplit"""

  def setUp(self):
    # Several separators, including regex metacharacters ("+", "."), to
    # check that the separator is not treated as a regular expression
    self._seps = [",", "+", "."]

  def _check(self, sep, pieces, expected):
    """Join C{pieces} with C{sep} and check how it is split back"""
    joined = sep.join(pieces)
    self.failUnlessEqual(UnescapeAndSplit(joined, sep=sep), expected)

  def testSimple(self):
    items = ["a", "b", "c", "d"]
    for sep in self._seps:
      self._check(sep, items, items)

  def testEscape(self):
    # A backslash-escaped separator is kept as part of the item
    for sep in self._seps:
      self._check(sep, ["a", "b\\" + sep + "c", "d"],
                  ["a", "b" + sep + "c", "d"])

  def testDoubleEscape(self):
    # An escaped backslash yields a literal backslash, not an escape
    for sep in self._seps:
      self._check(sep, ["a", "b\\\\", "c", "d"], ["a", "b\\", "c", "d"])

  def testThreeEscape(self):
    for sep in self._seps:
      self._check(sep, ["a", "b\\\\\\" + sep + "c", "d"],
                  ["a", "b\\" + sep + "c", "d"])
1388
1389
class TestPathJoin(unittest.TestCase):
  """Tests for PathJoin"""

  def testBasicItems(self):
    parts = ["/a", "b", "c"]
    self.failUnlessEqual(PathJoin(*parts), "/".join(parts))

  def testNonAbsPrefix(self):
    # A relative first component is rejected
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")

  def testBackTrack(self):
    # A component containing ".." is rejected
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")

  def testMultiAbs(self):
    # A second absolute component is rejected
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1405
1406
class TestHostInfo(unittest.TestCase):
  """Testing case for HostInfo"""

  def testUppercase(self):
    data = "AbC.example.com"
    self.failUnlessEqual(HostInfo.NormalizeName(data), data.lower())

  def testTooLongName(self):
    data = "a.b." + "c" * 255
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, data)

  def testTrailingDot(self):
    data = "a.b.c"
    self.failUnlessEqual(HostInfo.NormalizeName(data + "."), data)

  def testInvalidName(self):
    data = [
      "a b",
      "a/b",
      ".a.b",
      "a..b",
      ]
    for value in data:
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, value)

  def testValidName(self):
    data = [
      "a.b",
      "a-b",
      "a_b",
      "a.b.c",
      ]
    for value in data:
      # These names are already lowercase and have no trailing dot, so
      # besides being accepted they must come back unchanged
      self.failUnlessEqual(HostInfo.NormalizeName(value), value)
1441
1442
class TestParseAsn1Generalizedtime(unittest.TestCase):
  """Tests for utils._ParseAsn1Generalizedtime"""

  def test(self):
    parse = utils._ParseAsn1Generalizedtime

    # (input, expected seconds since the epoch)
    valid = [
      # UTC
      ("19700101000000Z", 0),
      ("20100222174152Z", 1266860512),
      ("20380119031407Z", (2**31) - 1),

      # With offset
      ("20100222174152+0000", 1266860512),
      ("20100223131652+0000", 1266931012),
      ("20100223051808-0800", 1266931088),
      ("20100224002135+1100", 1266931295),
      ("19700101000000-0100", 3600),
      ]
    for (value, expected) in valid:
      self.assertEqual(parse(value), expected)

    invalid = [
      # Leap seconds are not supported by datetime.datetime
      "19841231235960+0000",
      "19920630235960+0000",

      # Errors
      "",
      "invalid",
      "20100222174152",
      "Mon Feb 22 17:47:02 UTC 2010",
      "2010-02-22 17:42:02",
      ]
    for value in invalid:
      self.assertRaises(ValueError, parse, value)
1479
1480
class TestGetX509CertValidity(testutils.GanetiTestCase):
  """Tests for utils.GetX509CertValidity"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    # Test whether we have pyOpenSSL 0.7 or above
    installed = distutils.version.LooseVersion(OpenSSL.__version__)
    self.pyopenssl0_7 = (installed >= "0.7")

    if not self.pyopenssl0_7:
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
                    " function correctly")

  def _LoadCert(self, name):
    """Load the PEM certificate stored in test data file C{name}"""
    pem = self._ReadTestData(name)
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem)

  def test(self):
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))

    # Below pyOpenSSL 0.7 the expected result is (None, None)
    if self.pyopenssl0_7:
      expected = (1266919967, 1267524767)
    else:
      expected = (None, None)
    self.assertEqual(validity, expected)
1504
1505
if __name__ == "__main__":
  # Entry point when this file is executed directly as a script
  testutils.GanetiTestProgram()