utils.Retry: pass up timeout arguments
[ganeti-local] / test / ganeti.utils_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the utils module"""
23
24 import unittest
25 import os
26 import time
27 import tempfile
28 import os.path
29 import os
30 import stat
31 import signal
32 import socket
33 import shutil
34 import re
35 import select
36 import string
37 import OpenSSL
38 import warnings
39 import distutils.version
40 import glob
41
42 import ganeti
43 import testutils
44 from ganeti import constants
45 from ganeti import utils
46 from ganeti import errors
47 from ganeti import serializer
48 from ganeti.utils import IsProcessAlive, RunCmd, \
49      RemoveFile, MatchNameComponent, FormatUnit, \
50      ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
51      ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
52      SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
53      TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
54      UnescapeAndSplit, RunParts, PathJoin, HostInfo
55
56 from ganeti.errors import LockError, UnitParseError, GenericError, \
57      ProgrammerError, OpPrereqError
58
59
60 class TestIsProcessAlive(unittest.TestCase):
61   """Testing case for IsProcessAlive"""
62
63   def testExists(self):
64     mypid = os.getpid()
65     self.assert_(IsProcessAlive(mypid),
66                  "can't find myself running")
67
68   def testNotExisting(self):
69     pid_non_existing = os.fork()
70     if pid_non_existing == 0:
71       os._exit(0)
72     elif pid_non_existing < 0:
73       raise SystemError("can't fork")
74     os.waitpid(pid_non_existing, 0)
75     self.assert_(not IsProcessAlive(pid_non_existing),
76                  "nonexisting process detected")
77
78
79 class TestPidFileFunctions(unittest.TestCase):
80   """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
81
82   def setUp(self):
83     self.dir = tempfile.mkdtemp()
84     self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
85     utils.DaemonPidFileName = self.f_dpn
86
87   def testPidFileFunctions(self):
88     pid_file = self.f_dpn('test')
89     utils.WritePidFile('test')
90     self.failUnless(os.path.exists(pid_file),
91                     "PID file should have been created")
92     read_pid = utils.ReadPidFile(pid_file)
93     self.failUnlessEqual(read_pid, os.getpid())
94     self.failUnless(utils.IsProcessAlive(read_pid))
95     self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
96     utils.RemovePidFile('test')
97     self.failIf(os.path.exists(pid_file),
98                 "PID file should not exist anymore")
99     self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
100                          "ReadPidFile should return 0 for missing pid file")
101     fh = open(pid_file, "w")
102     fh.write("blah\n")
103     fh.close()
104     self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
105                          "ReadPidFile should return 0 for invalid pid file")
106     utils.RemovePidFile('test')
107     self.failIf(os.path.exists(pid_file),
108                 "PID file should not exist anymore")
109
110   def testKill(self):
111     pid_file = self.f_dpn('child')
112     r_fd, w_fd = os.pipe()
113     new_pid = os.fork()
114     if new_pid == 0: #child
115       utils.WritePidFile('child')
116       os.write(w_fd, 'a')
117       signal.pause()
118       os._exit(0)
119       return
120     # else we are in the parent
121     # wait until the child has written the pid file
122     os.read(r_fd, 1)
123     read_pid = utils.ReadPidFile(pid_file)
124     self.failUnlessEqual(read_pid, new_pid)
125     self.failUnless(utils.IsProcessAlive(new_pid))
126     utils.KillProcess(new_pid, waitpid=True)
127     self.failIf(utils.IsProcessAlive(new_pid))
128     utils.RemovePidFile('child')
129     self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)
130
131   def tearDown(self):
132     for name in os.listdir(self.dir):
133       os.unlink(os.path.join(self.dir, name))
134     os.rmdir(self.dir)
135
136
class TestRunCmd(testutils.GanetiTestCase):
  """Testing case for the RunCmd function"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.magic = time.ctime() + " ganeti test"
    self.fname = self._CreateTempFile()

  def testOk(self):
    """Test successful exit code"""
    result = RunCmd("/bin/sh -c 'exit 0'")
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.output, "")

  def testFail(self):
    """Test fail exit code"""
    result = RunCmd("/bin/sh -c 'exit 1'")
    self.assertEqual(result.exit_code, 1)
    self.assertEqual(result.output, "")

  def testStdout(self):
    """Test standard output"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.stdout, self.magic)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testStderr(self):
    """Test standard error"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
    self.assertEqual(result.stderr, self.magic)
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testCombined(self):
    """Test combined output"""
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
    expected = "A" + self.magic + "B" + self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.output, expected)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, expected)

  def testSignal(self):
    """Test signal"""
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
    self.assertEqual(result.signal, 15)
    self.assertEqual(result.output, "")

  def testListRun(self):
    """Test list runs"""
    result = RunCmd(["true"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 1)
    result = RunCmd(["echo", "-n", self.magic])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.stdout, self.magic)

  def testFileEmptyOutput(self):
    """Test file output"""
    result = RunCmd(["true"], output=self.fname)
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertFileContent(self.fname, "")

  def testLang(self):
    """Test locale environment"""
    # Save only the variables modified below and restore them one by one.
    # Assigning a dict copy back to os.environ would replace the mapping
    # object itself and leave the values already handed to putenv() in
    # the real process environment inherited by child processes.
    saved_env = {}
    for name in ("LANG", "LC_ALL"):
      saved_env[name] = os.environ.get(name, None)
    try:
      os.environ["LANG"] = "en_US.UTF-8"
      os.environ["LC_ALL"] = "en_US.UTF-8"
      result = RunCmd(["locale"])
      for line in result.output.splitlines():
        key, value = line.split("=", 1)
        # Ignore these variables, they're overridden by LC_ALL
        if key == "LANG" or key == "LANGUAGE":
          continue
        self.failIf(value and value != "C" and value != '"C"',
            "Variable %s is set to the invalid value '%s'" % (key, value))
    finally:
      for name, value in saved_env.items():
        if value is not None:
          os.environ[name] = value
        elif name in os.environ:
          # Deleting through os.environ also unsets the real variable
          del os.environ[name]

  def testDefaultCwd(self):
    """Test default working directory"""
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")

  def testCwd(self):
    """Test explicitly given working directories"""
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
    cwd = os.getcwd()
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)

  def testResetEnv(self):
    """Test environment reset functionality"""
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
244
245
246 class TestRunParts(unittest.TestCase):
247   """Testing case for the RunParts function"""
248
249   def setUp(self):
250     self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
251
252   def tearDown(self):
253     shutil.rmtree(self.rundir)
254
255   def testEmpty(self):
256     """Test on an empty dir"""
257     self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
258
259   def testSkipWrongName(self):
260     """Test that wrong files are skipped"""
261     fname = os.path.join(self.rundir, "00test.dot")
262     utils.WriteFile(fname, data="")
263     os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
264     relname = os.path.basename(fname)
265     self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
266                          [(relname, constants.RUNPARTS_SKIP, None)])
267
268   def testSkipNonExec(self):
269     """Test that non executable files are skipped"""
270     fname = os.path.join(self.rundir, "00test")
271     utils.WriteFile(fname, data="")
272     relname = os.path.basename(fname)
273     self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
274                          [(relname, constants.RUNPARTS_SKIP, None)])
275
276   def testError(self):
277     """Test error on a broken executable"""
278     fname = os.path.join(self.rundir, "00test")
279     utils.WriteFile(fname, data="")
280     os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
281     (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
282     self.failUnlessEqual(relname, os.path.basename(fname))
283     self.failUnlessEqual(status, constants.RUNPARTS_ERR)
284     self.failUnless(error)
285
286   def testSorted(self):
287     """Test executions are sorted"""
288     files = []
289     files.append(os.path.join(self.rundir, "64test"))
290     files.append(os.path.join(self.rundir, "00test"))
291     files.append(os.path.join(self.rundir, "42test"))
292
293     for fname in files:
294       utils.WriteFile(fname, data="")
295
296     results = RunParts(self.rundir, reset_env=True)
297
298     for fname in sorted(files):
299       self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
300
301   def testOk(self):
302     """Test correct execution"""
303     fname = os.path.join(self.rundir, "00test")
304     utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
305     os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
306     (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
307     self.failUnlessEqual(relname, os.path.basename(fname))
308     self.failUnlessEqual(status, constants.RUNPARTS_RUN)
309     self.failUnlessEqual(runresult.stdout, "ciao")
310
311   def testRunFail(self):
312     """Test correct execution, with run failure"""
313     fname = os.path.join(self.rundir, "00test")
314     utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
315     os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
316     (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
317     self.failUnlessEqual(relname, os.path.basename(fname))
318     self.failUnlessEqual(status, constants.RUNPARTS_RUN)
319     self.failUnlessEqual(runresult.exit_code, 1)
320     self.failUnless(runresult.failed)
321
322   def testRunMix(self):
323     files = []
324     files.append(os.path.join(self.rundir, "00test"))
325     files.append(os.path.join(self.rundir, "42test"))
326     files.append(os.path.join(self.rundir, "64test"))
327     files.append(os.path.join(self.rundir, "99test"))
328
329     files.sort()
330
331     # 1st has errors in execution
332     utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
333     os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
334
335     # 2nd is skipped
336     utils.WriteFile(files[1], data="")
337
338     # 3rd cannot execute properly
339     utils.WriteFile(files[2], data="")
340     os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
341
342     # 4th execs
343     utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
344     os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
345
346     results = RunParts(self.rundir, reset_env=True)
347
348     (relname, status, runresult) = results[0]
349     self.failUnlessEqual(relname, os.path.basename(files[0]))
350     self.failUnlessEqual(status, constants.RUNPARTS_RUN)
351     self.failUnlessEqual(runresult.exit_code, 1)
352     self.failUnless(runresult.failed)
353
354     (relname, status, runresult) = results[1]
355     self.failUnlessEqual(relname, os.path.basename(files[1]))
356     self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
357     self.failUnlessEqual(runresult, None)
358
359     (relname, status, runresult) = results[2]
360     self.failUnlessEqual(relname, os.path.basename(files[2]))
361     self.failUnlessEqual(status, constants.RUNPARTS_ERR)
362     self.failUnless(runresult)
363
364     (relname, status, runresult) = results[3]
365     self.failUnlessEqual(relname, os.path.basename(files[3]))
366     self.failUnlessEqual(status, constants.RUNPARTS_RUN)
367     self.failUnlessEqual(runresult.output, "ciao")
368     self.failUnlessEqual(runresult.exit_code, 0)
369     self.failUnless(not runresult.failed)
370
371
372 class TestRemoveFile(unittest.TestCase):
373   """Test case for the RemoveFile function"""
374
375   def setUp(self):
376     """Create a temp dir and file for each case"""
377     self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
378     fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
379     os.close(fd)
380
381   def tearDown(self):
382     if os.path.exists(self.tmpfile):
383       os.unlink(self.tmpfile)
384     os.rmdir(self.tmpdir)
385
386
387   def testIgnoreDirs(self):
388     """Test that RemoveFile() ignores directories"""
389     self.assertEqual(None, RemoveFile(self.tmpdir))
390
391
392   def testIgnoreNotExisting(self):
393     """Test that RemoveFile() ignores non-existing files"""
394     RemoveFile(self.tmpfile)
395     RemoveFile(self.tmpfile)
396
397
398   def testRemoveFile(self):
399     """Test that RemoveFile does remove a file"""
400     RemoveFile(self.tmpfile)
401     if os.path.exists(self.tmpfile):
402       self.fail("File '%s' not removed" % self.tmpfile)
403
404
405   def testRemoveSymlink(self):
406     """Test that RemoveFile does remove symlinks"""
407     symlink = self.tmpdir + "/symlink"
408     os.symlink("no-such-file", symlink)
409     RemoveFile(symlink)
410     if os.path.exists(symlink):
411       self.fail("File '%s' not removed" % symlink)
412     os.symlink(self.tmpfile, symlink)
413     RemoveFile(symlink)
414     if os.path.exists(symlink):
415       self.fail("File '%s' not removed" % symlink)
416
417
418 class TestRename(unittest.TestCase):
419   """Test case for RenameFile"""
420
421   def setUp(self):
422     """Create a temporary directory"""
423     self.tmpdir = tempfile.mkdtemp()
424     self.tmpfile = os.path.join(self.tmpdir, "test1")
425
426     # Touch the file
427     open(self.tmpfile, "w").close()
428
429   def tearDown(self):
430     """Remove temporary directory"""
431     shutil.rmtree(self.tmpdir)
432
433   def testSimpleRename1(self):
434     """Simple rename 1"""
435     utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
436     self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
437
438   def testSimpleRename2(self):
439     """Simple rename 2"""
440     utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
441                      mkdir=True)
442     self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
443
444   def testRenameMkdir(self):
445     """Rename with mkdir"""
446     utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
447                      mkdir=True)
448     self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
449     self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
450
451     utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
452                      os.path.join(self.tmpdir, "test/foo/bar/baz"),
453                      mkdir=True)
454     self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
455     self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
456     self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
457
458
459 class TestMatchNameComponent(unittest.TestCase):
460   """Test case for the MatchNameComponent function"""
461
462   def testEmptyList(self):
463     """Test that there is no match against an empty list"""
464
465     self.failUnlessEqual(MatchNameComponent("", []), None)
466     self.failUnlessEqual(MatchNameComponent("test", []), None)
467
468   def testSingleMatch(self):
469     """Test that a single match is performed correctly"""
470     mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
471     for key in "test2", "test2.example", "test2.example.com":
472       self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
473
474   def testMultipleMatches(self):
475     """Test that a multiple match is returned as None"""
476     mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
477     for key in "test1", "test1.example":
478       self.failUnlessEqual(MatchNameComponent(key, mlist), None)
479
480   def testFullMatch(self):
481     """Test that a full match is returned correctly"""
482     key1 = "test1"
483     key2 = "test1.example"
484     mlist = [key2, key2 + ".com"]
485     self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
486     self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
487
488   def testCaseInsensitivePartialMatch(self):
489     """Test for the case_insensitive keyword"""
490     mlist = ["test1.example.com", "test2.example.net"]
491     self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
492                      "test2.example.net")
493     self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
494                      "test2.example.net")
495     self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
496                      "test2.example.net")
497     self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
498                      "test2.example.net")
499
500
501   def testCaseInsensitiveFullMatch(self):
502     mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
503     # Between the two ts1 a full string match non-case insensitive should work
504     self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
505                      None)
506     self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
507                      "ts1.ex")
508     self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
509                      "ts1.ex")
510     # Between the two ts2 only case differs, so only case-match works
511     self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
512                      "ts2.ex")
513     self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
514                      "Ts2.ex")
515     self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
516                      None)
517
518
519 class TestTimestampForFilename(unittest.TestCase):
520   def test(self):
521     self.assert_("." not in utils.TimestampForFilename())
522     self.assert_(":" not in utils.TimestampForFilename())
523
524
class TestCreateBackup(testutils.GanetiTestCase):
  """Tests for utils.CreateBackup"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def _CountFiles(self, filename):
    """Return the number of paths matching filename followed by anything"""
    return len(glob.glob("%s*" % filename))

  def testEmpty(self):
    """Back up an empty file; backing up a FIFO raises ProgrammerError"""
    filename = utils.PathJoin(self.tmpdir, "config.data")
    utils.WriteFile(filename, data="")
    bname = utils.CreateBackup(filename)
    self.assertFileContent(bname, "")
    self.assertEqual(self._CountFiles(filename), 2)
    # Every further backup adds exactly one more file
    for expected in (3, 4):
      utils.CreateBackup(filename)
      self.assertEqual(self._CountFiles(filename), expected)

    fifoname = utils.PathJoin(self.tmpdir, "fifo")
    os.mkfifo(fifoname)
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)

  def testContent(self):
    """Backups must reproduce the file content for various payloads"""
    total = 0
    payloads = ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]
    for data in payloads:
      for rep in (1, 2, 10, 127):
        testdata = data * rep

        filename = utils.PathJoin(self.tmpdir, "test.data_")
        utils.WriteFile(filename, data=testdata)
        self.assertFileContent(filename, testdata)

        for _ in range(3):
          bname = utils.CreateBackup(filename)
          total += 1
          self.assertFileContent(bname, testdata)
          # The original file plus every backup made so far
          self.assertEqual(self._CountFiles(filename), 1 + total)
566
567
568 class TestFormatUnit(unittest.TestCase):
569   """Test case for the FormatUnit function"""
570
571   def testMiB(self):
572     self.assertEqual(FormatUnit(1, 'h'), '1M')
573     self.assertEqual(FormatUnit(100, 'h'), '100M')
574     self.assertEqual(FormatUnit(1023, 'h'), '1023M')
575
576     self.assertEqual(FormatUnit(1, 'm'), '1')
577     self.assertEqual(FormatUnit(100, 'm'), '100')
578     self.assertEqual(FormatUnit(1023, 'm'), '1023')
579
580     self.assertEqual(FormatUnit(1024, 'm'), '1024')
581     self.assertEqual(FormatUnit(1536, 'm'), '1536')
582     self.assertEqual(FormatUnit(17133, 'm'), '17133')
583     self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
584
585   def testGiB(self):
586     self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
587     self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
588     self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
589     self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
590
591     self.assertEqual(FormatUnit(1024, 'g'), '1.0')
592     self.assertEqual(FormatUnit(1536, 'g'), '1.5')
593     self.assertEqual(FormatUnit(17133, 'g'), '16.7')
594     self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
595
596     self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
597     self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
598     self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
599
600   def testTiB(self):
601     self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
602     self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
603     self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
604
605     self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
606     self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
607     self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
608
609 class TestParseUnit(unittest.TestCase):
610   """Test case for the ParseUnit function"""
611
612   SCALES = (('', 1),
613             ('M', 1), ('G', 1024), ('T', 1024 * 1024),
614             ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
615             ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
616
617   def testRounding(self):
618     self.assertEqual(ParseUnit('0'), 0)
619     self.assertEqual(ParseUnit('1'), 4)
620     self.assertEqual(ParseUnit('2'), 4)
621     self.assertEqual(ParseUnit('3'), 4)
622
623     self.assertEqual(ParseUnit('124'), 124)
624     self.assertEqual(ParseUnit('125'), 128)
625     self.assertEqual(ParseUnit('126'), 128)
626     self.assertEqual(ParseUnit('127'), 128)
627     self.assertEqual(ParseUnit('128'), 128)
628     self.assertEqual(ParseUnit('129'), 132)
629     self.assertEqual(ParseUnit('130'), 132)
630
631   def testFloating(self):
632     self.assertEqual(ParseUnit('0'), 0)
633     self.assertEqual(ParseUnit('0.5'), 4)
634     self.assertEqual(ParseUnit('1.75'), 4)
635     self.assertEqual(ParseUnit('1.99'), 4)
636     self.assertEqual(ParseUnit('2.00'), 4)
637     self.assertEqual(ParseUnit('2.01'), 4)
638     self.assertEqual(ParseUnit('3.99'), 4)
639     self.assertEqual(ParseUnit('4.00'), 4)
640     self.assertEqual(ParseUnit('4.01'), 8)
641     self.assertEqual(ParseUnit('1.5G'), 1536)
642     self.assertEqual(ParseUnit('1.8G'), 1844)
643     self.assertEqual(ParseUnit('8.28T'), 8682212)
644
645   def testSuffixes(self):
646     for sep in ('', ' ', '   ', "\t", "\t "):
647       for suffix, scale in TestParseUnit.SCALES:
648         for func in (lambda x: x, str.lower, str.upper):
649           self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
650                            1024 * scale)
651
652   def testInvalidInput(self):
653     for sep in ('-', '_', ',', 'a'):
654       for suffix, _ in TestParseUnit.SCALES:
655         self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)
656
657     for suffix, _ in TestParseUnit.SCALES:
658       self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
659
660
class TestSshKeys(testutils.GanetiTestCase):
  """Test case for the AddAuthorizedKey and RemoveAuthorizedKey functions"""

  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
  # A key carrying options (command=, from=) in front of the key type
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')

  def setUp(self):
    """Create a temporary keys file containing KEY_A and KEY_B"""
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    handle = open(self.tmpname, 'w')
    try:
      handle.write("%s\n" % TestSshKeys.KEY_A)
      handle.write("%s\n" % TestSshKeys.KEY_B)
    finally:
      handle.close()

  def testAddingNewKey(self):
    """A key not yet present is appended at the end"""
    AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')

    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")

  def testAddingAlmostButNotCompletelyTheSameKey(self):
    """A key differing from an existing one only in the comment is added"""
    AddAuthorizedKey(self.tmpname,
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')

    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")

  def testAddingExistingKeyWithSomeMoreSpaces(self):
    """An existing key written with extra whitespace is not added again"""
    AddAuthorizedKey(self.tmpname,
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')

    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")

  def testRemovingExistingKeyWithSomeMoreSpaces(self):
    """An existing key given with extra whitespace is still removed"""
    RemoveAuthorizedKey(self.tmpname,
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')

    self.assertFileContent(self.tmpname,
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")

  def testRemovingNonExistingKey(self):
    """Removing an unknown key leaves the file unchanged"""
    RemoveAuthorizedKey(self.tmpname,
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')

    self.assertFileContent(self.tmpname,
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
722
723
724 class TestEtcHosts(testutils.GanetiTestCase):
725   """Test functions modifying /etc/hosts"""
726
727   def setUp(self):
728     testutils.GanetiTestCase.setUp(self)
729     self.tmpname = self._CreateTempFile()
730     handle = open(self.tmpname, 'w')
731     try:
732       handle.write('# This is a test file for /etc/hosts\n')
733       handle.write('127.0.0.1\tlocalhost\n')
734       handle.write('192.168.1.1 router gw\n')
735     finally:
736       handle.close()
737
738   def testSettingNewIp(self):
739     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
740
741     self.assertFileContent(self.tmpname,
742       "# This is a test file for /etc/hosts\n"
743       "127.0.0.1\tlocalhost\n"
744       "192.168.1.1 router gw\n"
745       "1.2.3.4\tmyhost.domain.tld myhost\n")
746     self.assertFileMode(self.tmpname, 0644)
747
748   def testSettingExistingIp(self):
749     SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
750                      ['myhost'])
751
752     self.assertFileContent(self.tmpname,
753       "# This is a test file for /etc/hosts\n"
754       "127.0.0.1\tlocalhost\n"
755       "192.168.1.1\tmyhost.domain.tld myhost\n")
756     self.assertFileMode(self.tmpname, 0644)
757
758   def testSettingDuplicateName(self):
759     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
760
761     self.assertFileContent(self.tmpname,
762       "# This is a test file for /etc/hosts\n"
763       "127.0.0.1\tlocalhost\n"
764       "192.168.1.1 router gw\n"
765       "1.2.3.4\tmyhost\n")
766     self.assertFileMode(self.tmpname, 0644)
767
768   def testRemovingExistingHost(self):
769     RemoveEtcHostsEntry(self.tmpname, 'router')
770
771     self.assertFileContent(self.tmpname,
772       "# This is a test file for /etc/hosts\n"
773       "127.0.0.1\tlocalhost\n"
774       "192.168.1.1 gw\n")
775     self.assertFileMode(self.tmpname, 0644)
776
777   def testRemovingSingleExistingHost(self):
778     RemoveEtcHostsEntry(self.tmpname, 'localhost')
779
780     self.assertFileContent(self.tmpname,
781       "# This is a test file for /etc/hosts\n"
782       "192.168.1.1 router gw\n")
783     self.assertFileMode(self.tmpname, 0644)
784
785   def testRemovingNonExistingHost(self):
786     RemoveEtcHostsEntry(self.tmpname, 'myhost')
787
788     self.assertFileContent(self.tmpname,
789       "# This is a test file for /etc/hosts\n"
790       "127.0.0.1\tlocalhost\n"
791       "192.168.1.1 router gw\n")
792     self.assertFileMode(self.tmpname, 0644)
793
794   def testRemovingAlias(self):
795     RemoveEtcHostsEntry(self.tmpname, 'gw')
796
797     self.assertFileContent(self.tmpname,
798       "# This is a test file for /etc/hosts\n"
799       "127.0.0.1\tlocalhost\n"
800       "192.168.1.1 router\n")
801     self.assertFileMode(self.tmpname, 0644)
802
803
class TestShellQuoting(unittest.TestCase):
  """Test case for shell quoting functions"""

  def testShellQuote(self):
    # Pairs of (raw argument, expected shell-quoted form)
    cases = [
      ('abc', "abc"),
      ('ab"c', "'ab\"c'"),
      ("a'bc", "'a'\\''bc'"),
      ("a b c", "'a b c'"),
      ("a b\\ c", "'a b\\ c'"),
      ]
    for (raw, quoted) in cases:
      self.assertEqual(ShellQuote(raw), quoted)

  def testShellQuoteArgs(self):
    # Pairs of (argument list, expected quoted command line)
    cases = [
      (['a', 'b', 'c'], "a b c"),
      (['a', 'b"', 'c'], "a 'b\"' c"),
      (['a', 'b\'', 'c'], "a 'b'\\\''' c"),
      ]
    for (args, quoted) in cases:
      self.assertEqual(ShellQuoteArgs(args), quoted)
818
819
class TestTcpPing(unittest.TestCase):
  """Testcase for TCP version of ping - against listen(2)ing port"""

  def setUp(self):
    # Listen on an ephemeral port of the loopback address
    self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
    self.listenerport = self.listener.getsockname()[1]
    self.listener.listen(1)

  def tearDown(self):
    try:
      self.listener.shutdown(socket.SHUT_RDWR)
    finally:
      # Release the descriptor explicitly instead of relying on garbage
      # collection; this also runs if shutdown() itself fails
      self.listener.close()
    del self.listener
    del self.listenerport

  def testTcpPingToLocalHostAccept(self):
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.listenerport,
                         timeout=10,
                         live_port_needed=True,
                         source=constants.LOCALHOST_IP_ADDRESS,
                         ),
                 "failed to connect to test listener")

    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.listenerport,
                         timeout=10,
                         live_port_needed=True,
                         ),
                 "failed to connect to test listener (no source)")
849
850
class TestTcpPingDeaf(unittest.TestCase):
  """Testcase for TCP version of ping - against non listen(2)ing port"""

  def setUp(self):
    # Bound to an ephemeral port but never listen()ed on, so connection
    # attempts to it are refused
    self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
    self.deaflistenerport = self.deaflistener.getsockname()[1]

  def tearDown(self):
    # Release the descriptor explicitly rather than relying on garbage
    # collection of the socket object
    self.deaflistener.close()
    del self.deaflistener
    del self.deaflistenerport

  def testTcpPingToLocalHostAcceptDeaf(self):
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                        self.deaflistenerport,
                        timeout=constants.TCP_PING_TIMEOUT,
                        live_port_needed=True,
                        source=constants.LOCALHOST_IP_ADDRESS,
                        ), # need successful connect(2)
                "successfully connected to deaf listener")

    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                        self.deaflistenerport,
                        timeout=constants.TCP_PING_TIMEOUT,
                        live_port_needed=True,
                        ), # need successful connect(2)
                "successfully connected to deaf listener (no source addr)")

  def testTcpPingToLocalHostNoAccept(self):
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         source=constants.LOCALHOST_IP_ADDRESS,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port")

    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port (no source addr)")
894
895
class TestOwnIpAddress(unittest.TestCase):
  """Testcase for OwnIpAddress"""

  def testOwnLoopback(self):
    """check having the loopback ip"""
    owned = OwnIpAddress(constants.LOCALHOST_IP_ADDRESS)
    self.failUnless(owned, "Should own the loopback address")

  def testNowOwnAddress(self):
    """check that I don't own an address"""
    # 192.0.2.0/24 is reserved for test/documentation use (RFC 3330, now
    # RFC 5737), so no local interface *should* carry an address from it;
    # if this ever fails, extend the test to try several addresses
    test_ip = "192.0.2.1"
    owned = OwnIpAddress(test_ip)
    self.failIf(owned, "Should not own IP address %s" % test_ip)
912
913
def _GetSocketCredentials(path):
  """Connect to the Unix socket at C{path} and query the peer's credentials.

  The socket is always closed before returning, even on errors.

  """
  client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  try:
    # Avoid hanging forever if nobody accepts the connection
    client.settimeout(10)
    client.connect(path)
    creds = utils.GetSocketCredentials(client)
  finally:
    client.close()
  return creds
925
926
class TestGetSocketCredentials(unittest.TestCase):
  """Check peer credentials reported for a connected Unix socket.

  A child process connects to a listener owned by this (parent) process and
  reports back, via a pipe, the credentials it reads for its peer.

  """
  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()
    self.sockpath = utils.PathJoin(self.tmpdir, "sock")

    self.listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    self.listener.settimeout(10)
    self.listener.bind(self.sockpath)
    self.listener.listen(1)

  def tearDown(self):
    self.listener.shutdown(socket.SHUT_RDWR)
    self.listener.close()
    shutil.rmtree(self.tmpdir)

  def test(self):
    (c2pr, c2pw) = os.pipe()

    # Start child process
    child = os.fork()
    if child == 0:
      try:
        # The child only writes to the pipe
        os.close(c2pr)

        data = serializer.DumpJson(_GetSocketCredentials(self.sockpath))

        # os.write may write less than requested; loop until all is sent
        while data:
          written = os.write(c2pw, data)
          data = data[written:]
        os.close(c2pw)

        os._exit(0)
      finally:
        # Never let the child return into the test runner
        os._exit(1)

    os.close(c2pw)

    # Wait for one connection; recv returns once the child closes its end
    (conn, _) = self.listener.accept()
    conn.recv(1)
    conn.close()

    # Read the child's result until EOF so a short read cannot truncate it
    chunks = []
    while True:
      chunk = os.read(c2pr, 4096)
      if not chunk:
        break
      chunks.append(chunk)
    os.close(c2pr)
    result = "".join(chunks)

    # Check child's exit code
    (_, status) = os.waitpid(child, 0)
    self.assertFalse(os.WIFSIGNALED(status))
    self.assertEqual(os.WEXITSTATUS(status), 0)

    # The credentials seen by the child must be those of this process
    (pid, uid, gid) = serializer.LoadJson(result)
    self.assertEqual(pid, os.getpid())
    self.assertEqual(uid, os.getuid())
    self.assertEqual(gid, os.getgid())
979
980
class TestListVisibleFiles(unittest.TestCase):
  """Test case for ListVisibleFiles"""

  def setUp(self):
    self.path = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.path)

  def _test(self, files, expected):
    # Populate the directory with the requested (possibly hidden) files
    for filename in files:
      handle = open(os.path.join(self.path, filename), 'w')
      try:
        handle.write("Test\n")
      finally:
        handle.close()

    # Listing order is irrelevant, so compare sorted copies
    self.assertEqual(sorted(ListVisibleFiles(self.path)), sorted(expected))

  def testAllVisible(self):
    names = ["a", "b", "c"]
    self._test(names, names)

  def testNoneVisible(self):
    self._test([".a", ".b", ".c"], [])

  def testSomeVisible(self):
    self._test(["a", "b", ".c"], ["a", "b"])

  def testNonAbsolutePath(self):
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")

  def testNonNormalizedPath(self):
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
                          "/bin/../tmp")
1028
1029
class TestNewUUID(unittest.TestCase):
  """Test case for NewUUID"""

  # Lowercase hex groups laid out as 8-4-4-4-12
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
                        '[a-f0-9]{4}-[a-f0-9]{12}$')

  def runTest(self):
    generated = utils.NewUUID()
    self.failUnless(self._re_uuid.match(generated))
1038
1039
class TestUniqueSequence(unittest.TestCase):
  """Test case for UniqueSequence"""

  def _test(self, input, expected):
    result = utils.UniqueSequence(input)
    self.assertEqual(result, expected)

  def runTest(self):
    # Duplicates are dropped while first-occurrence order is kept
    cases = [
      # Ordered input
      ([1, 2, 3], [1, 2, 3]),
      ([1, 1, 2, 2, 3, 3], [1, 2, 3]),
      ([1, 2, 2, 3], [1, 2, 3]),
      ([1, 2, 3, 3], [1, 2, 3]),
      # Unordered input
      ([1, 2, 3, 1, 2, 3], [1, 2, 3]),
      ([1, 1, 2, 3, 3, 1, 2], [1, 2, 3]),
      # Strings
      (["a", "a"], ["a"]),
      (["a", "b"], ["a", "b"]),
      (["a", "b", "a"], ["a", "b"]),
      ]
    for (seq, expected) in cases:
      self._test(seq, expected)
1061
1062
class TestFirstFree(unittest.TestCase):
  """Test case for the FirstFree function"""

  def test(self):
    """Test FirstFree"""
    # Without a base: the first gap is reported, counting from zero
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
    self.failUnlessEqual(FirstFree([]), None)
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)

    # With a base: counting starts there; a value below it is rejected
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1073
1074
class TestTailFile(testutils.GanetiTestCase):
  """Test case for the TailFile function"""

  def testEmpty(self):
    fname = self._CreateTempFile()
    # An empty file yields no lines, with or without an explicit count
    for kwargs in [{}, {"lines": 25}]:
      self.failUnlessEqual(TailFile(fname, **kwargs), [])

  def testAllLines(self):
    data = ["test %d" % i for i in range(30)]
    for count in range(30):
      wanted = data[:count]
      content = "\n".join(wanted)
      if wanted:
        content += "\n"
      fname = self._CreateTempFile()
      fd = open(fname, "w")
      fd.write(content)
      fd.close()
      # Requesting exactly as many lines as exist returns all of them
      self.failUnlessEqual(TailFile(fname, lines=count), wanted)

  def testPartialLines(self):
    data = ["test %d" % i for i in range(30)]
    fname = self._CreateTempFile()
    fd = open(fname, "w")
    fd.write("\n".join(data) + "\n")
    fd.close()
    # Requesting fewer lines returns only the trailing ones
    for count in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=count), data[-count:])

  def testBigFile(self):
    data = ["test %d" % i for i in range(30)]
    fname = self._CreateTempFile()
    # A 1 MiB line before the interesting data must not disturb tailing
    content = "X" * 1048576 + "\n" + "\n".join(data) + "\n"
    fd = open(fname, "w")
    fd.write(content)
    fd.close()
    for count in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=count), data[-count:])
1115
1116
class _BaseFileLockTest:
  """Test case for the FileLock class

  Mixin: subclasses must provide C{self.lock} (the lock under test) and
  C{self.tmpfile} (whose C{name} is the locked file) in their C{setUp}.

  """

  def testSharedNonblocking(self):
    self.lock.Shared(blocking=False)
    self.lock.Close()

  def testExclusiveNonblocking(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Close()

  def testUnlockNonblocking(self):
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSharedBlocking(self):
    self.lock.Shared(blocking=True)
    self.lock.Close()

  def testExclusiveBlocking(self):
    self.lock.Exclusive(blocking=True)
    self.lock.Close()

  def testUnlockBlocking(self):
    self.lock.Unlock(blocking=True)
    self.lock.Close()

  def testSharedExclusiveUnlock(self):
    self.lock.Shared(blocking=False)
    self.lock.Exclusive(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testExclusiveSharedUnlock(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Shared(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSimpleTimeout(self):
    # These will succeed on the first attempt, hence a short timeout
    self.lock.Shared(blocking=True, timeout=10.0)
    self.lock.Exclusive(blocking=False, timeout=10.0)
    self.lock.Unlock(blocking=True, timeout=10.0)
    self.lock.Close()

  @staticmethod
  def _TryLockInner(filename, shared, blocking):
    """Try to lock C{filename}; meant to run in a separate process.

    @return: True if the lock was acquired, False on L{errors.LockError}

    """
    lock = utils.FileLock.Open(filename)
    try:
      if shared:
        fn = lock.Shared
      else:
        fn = lock.Exclusive

      try:
        # The timeout doesn't really matter as the parent process waits for
        # us to finish anyway.
        fn(blocking=blocking, timeout=0.01)
      except errors.LockError:
        return False

      return True
    finally:
      # Release our file descriptor (and any lock held) before exiting
      lock.Close()

  def _TryLock(self, *args):
    # Locks must be contended from another process, not from this one
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
                                      *args)

  def testTimeout(self):
    for blocking in [True, False]:
      # An exclusive lock held here blocks both lock kinds in another process
      self.lock.Exclusive(blocking=True)
      self.failIf(self._TryLock(False, blocking))
      self.failIf(self._TryLock(True, blocking))

      # A shared lock held here still allows other shared holders only
      self.lock.Shared(blocking=True)
      self.assert_(self._TryLock(True, blocking))
      self.failIf(self._TryLock(False, blocking))

    # Close the lock like every other test does
    self.lock.Close()

  def testCloseShared(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)

  def testCloseExclusive(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)

  def testCloseUnlock(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1206
1207
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
  # Content of the locked file; locking must never modify it
  TESTDATA = "Hello World\n" * 10

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpfile = tempfile.NamedTemporaryFile()
    filename = self.tmpfile.name
    utils.WriteFile(filename, data=self.TESTDATA)
    self.lock = utils.FileLock.Open(filename)

    # Opening the lock by name must not truncate the file
    self.assertFileContent(filename, self.TESTDATA)

  def tearDown(self):
    # Whatever the test did with the lock, the data must be intact
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)

    testutils.GanetiTestCase.tearDown(self)
1225
1226
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
  def setUp(self):
    # Build the lock from an already-open file object rather than a path
    self.tmpfile = tempfile.NamedTemporaryFile()
    path = self.tmpfile.name
    handle = open(path, "w")
    self.lock = utils.FileLock(handle, path)
1231
1232
class TestTimeFunctions(unittest.TestCase):
  """Test case for time functions"""

  def runTest(self):
    # SplitTime: float seconds -> (seconds, microseconds)
    split_cases = [
      (1, (1, 0)),
      (1.5, (1, 500000)),
      (1218448917.4809151, (1218448917, 480915)),
      (123.48012, (123, 480120)),
      (123.9996, (123, 999600)),
      (123.9995, (123, 999500)),
      (123.9994, (123, 999400)),
      (123.999999999, (123, 999999)),
      ]
    for (value, expected) in split_cases:
      self.assertEqual(utils.SplitTime(value), expected)

    self.assertRaises(AssertionError, utils.SplitTime, -1)

    # MergeTime: (seconds, microseconds) -> float seconds
    merge_cases = [
      ((1, 0), 1.0),
      ((1, 500000), 1.5),
      ((1218448917, 500000), 1218448917.5),
      ]
    for (value, expected) in merge_cases:
      self.assertEqual(utils.MergeTime(value), expected)

    # These results are rounded to millisecond precision before comparing
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
                     1218448917.481)
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)

    # Negative values and out-of-range microseconds are rejected
    for invalid in [(0, -1), (0, 1000000), (0, 9999999), (-1, 0), (-9999, 0)]:
      self.assertRaises(AssertionError, utils.MergeTime, invalid)
1261
1262
class FieldSetTestCase(unittest.TestCase):
  """Test case for FieldSets"""

  def testSimpleMatch(self):
    fields = utils.FieldSet("a", "b", "c", "def")
    self.failUnless(fields.Matches("a"))
    # Only complete field names match, not parts of them
    self.failIf(fields.Matches("d"), "Substring matched")
    self.failIf(fields.Matches("defghi"), "Prefix string matched")
    for subset in (["b", "c"], ["a", "b", "c", "def"]):
      self.failIf(fields.NonMatching(subset))
    self.failUnless(fields.NonMatching(["a", "d"]))

  def testRegexMatch(self):
    # Field definitions may be regular expressions
    fields = utils.FieldSet("a", "b([0-9]+)", "c")
    for name in ("b1", "b99"):
      self.failUnless(fields.Matches(name))
    self.failIf(fields.Matches("b/1"))
    self.failIf(fields.NonMatching(["b12", "c"]))
    self.failUnless(fields.NonMatching(["a", "1"]))
1282
class TestForceDictType(unittest.TestCase):
  """Test case for ForceDictType"""

  def setUp(self):
    # Declared type of every key used by the tests below
    self.key_types = {
      'a': constants.VTYPE_INT,
      'b': constants.VTYPE_BOOL,
      'c': constants.VTYPE_STRING,
      'd': constants.VTYPE_SIZE,
      }

  def _fdt(self, dict, allowed_values=None):
    # ForceDictType converts the dict in place; hand it back for comparison
    extra = {}
    if allowed_values is not None:
      extra["allowed_values"] = allowed_values
    ForceDictType(dict, self.key_types, **extra)
    return dict

  def testSimpleDict(self):
    # Pairs of (input dict, dict after type enforcement)
    cases = [
      ({}, {}),
      ({'a': 1}, {'a': 1}),
      ({'a': '1'}, {'a': 1}),
      ({'a': 1, 'b': 1}, {'a':1, 'b': True}),
      ({'b': 1, 'c': 'foo'}, {'b': True, 'c': 'foo'}),
      ({'b': 1, 'c': False}, {'b': True, 'c': ''}),
      ({'b': 'false'}, {'b': False}),
      ({'b': 'False'}, {'b': False}),
      ({'b': 'true'}, {'b': True}),
      ({'b': 'True'}, {'b': True}),
      ({'d': '4'}, {'d': 4}),
      ({'d': '4M'}, {'d': 4}),
      ]
    for (value, expected) in cases:
      self.assertEqual(self._fdt(value), expected)

  def testErrors(self):
    invalid = [{'a': 'astring'}, {'c': True}, {'d': 'astring'}, {'d': '4 L'}]
    for value in invalid:
      self.assertRaises(errors.TypeEnforcementError, self._fdt, value)
1321
1322
class TestIsAbsNormPath(unittest.TestCase):
  """Testing case for IsNormAbsPath"""

  def _pathTestHelper(self, path, result):
    # Assert that IsNormAbsPath(path) is true exactly when result is true
    if result:
      self.failUnless(IsNormAbsPath(path),
                      "Path %s should result absolute and normalized" % path)
    else:
      self.failIf(IsNormAbsPath(path),
                  "Path %s should not result absolute and normalized" % path)

  def testBase(self):
    self._pathTestHelper('/etc', True)
    self._pathTestHelper('/srv', True)
    # Relative path
    self._pathTestHelper('etc', False)
    # Not normalized: contains '..'
    self._pathTestHelper('/etc/../root', False)
    # Not normalized: trailing slash
    self._pathTestHelper('/etc/', False)
1340
1341
class TestSafeEncode(unittest.TestCase):
  """Test case for SafeEncode"""

  def testAscii(self):
    # Printable ASCII text must pass through unchanged
    for txt in [string.digits, string.letters, string.punctuation]:
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testDoubleEncode(self):
    # Encoding must be idempotent for every single-byte value; range(256)
    # includes 0xff, which range(255) used to skip
    for i in range(256):
      txt = SafeEncode(chr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testUnicode(self):
    # 1024 is high enough to catch non-direct ASCII mappings
    for i in range(1024):
      txt = SafeEncode(unichr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))
1359
1360
class TestFormatTime(unittest.TestCase):
  """Testing case for FormatTime"""

  def testNone(self):
    self.failUnlessEqual(FormatTime(None), "N/A")

  def testInvalid(self):
    # An empty tuple is not a usable timestamp either
    self.failUnlessEqual(FormatTime(()), "N/A")

  def testNow(self):
    # Both the float returned by time.time() and a plain int are accepted
    for value in (time.time(), int(time.time())):
      FormatTime(value)
1375
1376
class RunInSeparateProcess(unittest.TestCase):
  """Tests for utils.RunInSeparateProcess"""

  def test(self):
    # The child's return value is handed back to the caller
    for expected in [True, False]:
      def _child():
        return expected

      self.assertEqual(expected, utils.RunInSeparateProcess(_child))

  def testArgs(self):
    # Extra positional arguments are forwarded to the function
    for value in [0, 1, 999, "Hello World", (1, 2, 3)]:
      def _child(first, second):
        return first == "Foo" and second == value

      self.assert_(utils.RunInSeparateProcess(_child, "Foo", value))

  def testPid(self):
    my_pid = os.getpid()

    def _check():
      return os.getpid() == my_pid

    # The function must not run inside the calling process
    self.failIf(utils.RunInSeparateProcess(_check))

  def testSignal(self):
    def _kill():
      os.kill(os.getpid(), signal.SIGTERM)

    # A child terminated by a signal makes the call raise GenericError
    self.assertRaises(errors.GenericError,
                      utils.RunInSeparateProcess, _kill)

  def testException(self):
    def _exc():
      raise errors.GenericError("This is a test")

    # A child failing with an exception makes the call raise GenericError
    self.assertRaises(errors.GenericError,
                      utils.RunInSeparateProcess, _exc)
1413
1414
class TestFingerprintFile(unittest.TestCase):
  """Tests for utils._FingerprintFile"""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()

  def tearDown(self):
    # Release the temporary file explicitly; closing a NamedTemporaryFile
    # also removes it from disk
    self.tmpfile.close()

  def test(self):
    # Empty file; da39a3ee... is the well-known SHA-1 digest of empty input
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")

    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1426
1427
class TestUnescapeAndSplit(unittest.TestCase):
  """Testing case for UnescapeAndSplit"""

  def setUp(self):
    # testing more than one separator for regexp safety
    self._seps = [",", "+", "."]

  def _Check(self, parts, expected, sep):
    # Join the raw parts with sep and verify how they are split back
    joined = sep.join(parts)
    self.failUnlessEqual(UnescapeAndSplit(joined, sep=sep), expected)

  def testSimple(self):
    items = ["a", "b", "c", "d"]
    for sep in self._seps:
      self._Check(items, items, sep)

  def testEscape(self):
    # A backslash-escaped separator stays inside its element
    for sep in self._seps:
      self._Check(["a", "b\\" + sep + "c", "d"],
                  ["a", "b" + sep + "c", "d"], sep)

  def testDoubleEscape(self):
    # An escaped backslash becomes a literal backslash
    for sep in self._seps:
      self._Check(["a", "b\\\\", "c", "d"],
                  ["a", "b\\", "c", "d"], sep)

  def testThreeEscape(self):
    # Escaped backslash followed by an escaped separator
    for sep in self._seps:
      self._Check(["a", "b\\\\\\" + sep + "c", "d"],
                  ["a", "b\\" + sep + "c", "d"], sep)
1457
1458
class TestPathJoin(unittest.TestCase):
  """Testing case for PathJoin"""

  def testBasicItems(self):
    parts = ["/a", "b", "c"]
    self.failUnlessEqual(PathJoin(*parts), "/".join(parts))

  def testNonAbsPrefix(self):
    # The first component must be absolute
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")

  def testBackTrack(self):
    # Escaping via '..' is refused
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")

  def testMultiAbs(self):
    # Only the first component may be absolute
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1474
1475
class TestHostInfo(unittest.TestCase):
  """Testing case for HostInfo"""

  def testUppercase(self):
    name = "AbC.example.com"
    # Names are normalized to lowercase
    self.failUnlessEqual(HostInfo.NormalizeName(name), name.lower())

  def testTooLongName(self):
    name = "a.b." + "c" * 255
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, name)

  def testTrailingDot(self):
    name = "a.b.c"
    # A trailing dot is stripped
    self.failUnlessEqual(HostInfo.NormalizeName("%s." % name), name)

  def testInvalidName(self):
    for name in ("a b", "a/b", ".a.b", "a..b"):
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, name)

  def testValidName(self):
    # Must not raise
    for name in ("a.b", "a-b", "a_b", "a.b.c"):
      HostInfo.NormalizeName(name)
1510
1511
class TestParseAsn1Generalizedtime(unittest.TestCase):
  """Tests for utils._ParseAsn1Generalizedtime"""

  def test(self):
    # Pairs of (ASN.1 GeneralizedTime string, expected Unix timestamp)
    valid = [
      # UTC
      ("19700101000000Z", 0),
      ("20100222174152Z", 1266860512),
      ("20380119031407Z", (2**31) - 1),
      # With offset
      ("20100222174152+0000", 1266860512),
      ("20100223131652+0000", 1266931012),
      ("20100223051808-0800", 1266931088),
      ("20100224002135+1100", 1266931295),
      ("19700101000000-0100", 3600),
      ]
    for (value, expected) in valid:
      self.assertEqual(utils._ParseAsn1Generalizedtime(value), expected)

    invalid = [
      # Leap seconds are not supported by datetime.datetime
      "19841231235960+0000",
      "19920630235960+0000",
      # Errors
      "",
      "invalid",
      "20100222174152",
      "Mon Feb 22 17:47:02 UTC 2010",
      "2010-02-22 17:42:02",
      ]
    for value in invalid:
      self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, value)
1548
1549
class TestGetX509CertValidity(testutils.GanetiTestCase):
  """Tests for utils.GetX509CertValidity"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    # Test whether we have pyOpenSSL 0.7 or above
    installed = distutils.version.LooseVersion(OpenSSL.__version__)
    self.pyopenssl0_7 = (installed >= "0.7")

    if not self.pyopenssl0_7:
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
                    " function correctly")

  def _LoadCert(self, name):
    # Parse a PEM certificate from the test data directory
    pem = self._ReadTestData(name)
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem)

  def test(self):
    cert = self._LoadCert("cert1.pem")
    if self.pyopenssl0_7:
      expected = (1266919967, 1267524767)
    else:
      # Older pyOpenSSL versions are expected to yield no validity data
      expected = (None, None)
    self.assertEqual(utils.GetX509CertValidity(cert), expected)
1573
1574
class TestMakedirs(unittest.TestCase):
  """Tests for utils.Makedirs"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _MakeAndCheck(self, target):
    # Makedirs must leave a directory at target
    utils.Makedirs(target)
    self.assert_(os.path.isdir(target))

  def testNonExisting(self):
    self._MakeAndCheck(utils.PathJoin(self.tmpdir, "foo"))

  def testExisting(self):
    target = utils.PathJoin(self.tmpdir, "foo")
    os.mkdir(target)
    # An already existing directory is not an error
    self._MakeAndCheck(target)

  def testRecursiveNonExisting(self):
    self._MakeAndCheck(utils.PathJoin(self.tmpdir, "foo/bar/baz"))

  def testRecursiveExisting(self):
    target = utils.PathJoin(self.tmpdir, "B/moo/xyz")
    self.assert_(not os.path.exists(target))
    # Only the first path component exists beforehand
    os.mkdir(utils.PathJoin(self.tmpdir, "B"))
    self._MakeAndCheck(target)
1604
1605
1606 class TestRetry(testutils.GanetiTestCase):
1607   def setUp(self):
1608     testutils.GanetiTestCase.setUp(self)
1609     self.retries = 0
1610
1611   @staticmethod
1612   def _RaiseRetryAgain():
1613     raise utils.RetryAgain()
1614
1615   @staticmethod
1616   def _RaiseRetryAgainWithArg(args):
1617     raise utils.RetryAgain(*args)
1618
1619   def _WrongNestedLoop(self):
1620     return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
1621
1622   def _RetryAndSucceed(self, retries):
1623     if self.retries < retries:
1624       self.retries += 1
1625       raise utils.RetryAgain()
1626     else:
1627       return True
1628
1629   def testRaiseTimeout(self):
1630     self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1631                           self._RaiseRetryAgain, 0.01, 0.02)
1632     self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1633                           self._RetryAndSucceed, 0.01, 0, args=[1])
1634     self.failUnlessEqual(self.retries, 1)
1635
1636   def testComplete(self):
1637     self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
1638     self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
1639                          True)
1640     self.failUnlessEqual(self.retries, 2)
1641
1642   def testNestedLoop(self):
1643     try:
1644       self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
1645                             self._WrongNestedLoop, 0, 1)
1646     except utils.RetryTimeout:
1647       self.fail("Didn't detect inner loop's exception")
1648
1649   def testTimeoutArgument(self):
1650     retry_arg="my_important_debugging_message"
1651     try:
1652       utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
1653     except utils.RetryTimeout, err:
1654       self.failUnlessEqual(err.args, (retry_arg, ))
1655     else:
1656       self.fail("Expected timeout didn't happen")
1657
1658   def testRaiseInnerWithExc(self):
1659     retry_arg="my_important_debugging_message"
1660     try:
1661       try:
1662         utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
1663                     args=[[errors.GenericError(retry_arg, retry_arg)]])
1664       except utils.RetryTimeout, err:
1665         err.RaiseInner()
1666       else:
1667         self.fail("Expected timeout didn't happen")
1668     except errors.GenericError, err:
1669       self.failUnlessEqual(err.args, (retry_arg, retry_arg))
1670     else:
1671       self.fail("Expected GenericError didn't happen")
1672
1673   def testRaiseInnerWithMsg(self):
1674     retry_arg="my_important_debugging_message"
1675     try:
1676       try:
1677         utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
1678                     args=[[retry_arg, retry_arg]])
1679       except utils.RetryTimeout, err:
1680         err.RaiseInner()
1681       else:
1682         self.fail("Expected timeout didn't happen")
1683     except utils.RetryTimeout, err:
1684       self.failUnlessEqual(err.args, (retry_arg, retry_arg))
1685     else:
1686       self.fail("Expected RetryTimeout didn't happen")
1687
1688
class TestLineSplitter(unittest.TestCase):
  """Tests for utils.LineSplitter"""

  def test(self):
    collected = []
    splitter = utils.LineSplitter(collected.append)
    splitter.write("Hello World\n")
    # Nothing is forwarded before flushing
    self.assertEqual(collected, [])
    for chunk in ["Foo\n Bar\r\n ", "Baz", "Moo"]:
      splitter.write(chunk)
    self.assertEqual(collected, [])
    # flush() forwards complete lines, without their line endings
    splitter.flush()
    self.assertEqual(collected, ["Hello World", "Foo", " Bar"])
    # close() also forwards the trailing incomplete line
    splitter.close()
    self.assertEqual(collected, ["Hello World", "Foo", " Bar", " BazMoo"])

  def _testExtra(self, line, all_lines, p1, p2):
    # Line callback checking that the extra constructor args are passed on
    self.assertEqual(p1, 999)
    self.assertEqual(p2, "extra")
    all_lines.append(line)

  def testExtraArgsNoFlush(self):
    collected = []
    splitter = utils.LineSplitter(self._testExtra, collected, 999, "extra")
    for chunk in ["\n\nHello World\n", "Foo\n Bar\r\n ", "", "Baz",
                  "Moo\n\nx\n"]:
      splitter.write(chunk)
    self.assertEqual(collected, [])
    splitter.close()
    self.assertEqual(collected, ["", "", "Hello World", "Foo", " Bar",
                                 " BazMoo", "", "x"])
1721
1722
if __name__ == "__main__":
  # Entry point when this file is executed directly as a script
  testutils.GanetiTestProgram()