Add checks for master IP in cluster verify
[ganeti-local] / test / ganeti.utils_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the utils module"""
23
24 import unittest
25 import os
26 import time
27 import tempfile
28 import os.path
29 import os
30 import stat
31 import signal
32 import socket
33 import shutil
34 import re
35 import select
36 import string
37 import OpenSSL
38 import warnings
39 import distutils.version
40 import glob
41 import md5
42 import errno
43
44 import ganeti
45 import testutils
46 from ganeti import constants
47 from ganeti import utils
48 from ganeti import errors
49 from ganeti import serializer
50 from ganeti.utils import IsProcessAlive, RunCmd, \
51      RemoveFile, MatchNameComponent, FormatUnit, \
52      ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
53      ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
54      SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
55      TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
56      UnescapeAndSplit, RunParts, PathJoin, HostInfo, ReadOneLineFile
57
58 from ganeti.errors import LockError, UnitParseError, GenericError, \
59      ProgrammerError, OpPrereqError
60
61
class TestIsProcessAlive(unittest.TestCase):
  """Testing case for IsProcessAlive"""

  def testExists(self):
    """The PID of the running test process must be reported as alive"""
    mypid = os.getpid()
    self.assert_(IsProcessAlive(mypid),
                 "can't find myself running")

  def testNotExisting(self):
    """The PID of an already reaped child must not be reported as alive"""
    # os.fork() signals failure by raising OSError; it never returns a
    # negative value, hence there is no explicit error branch here
    pid_non_existing = os.fork()
    if pid_non_existing == 0:
      # child: terminate immediately, bypassing any cleanup handlers
      os._exit(0)
    # parent: reap the child before querying its PID
    os.waitpid(pid_non_existing, 0)
    self.assert_(not IsProcessAlive(pid_non_existing),
                 "nonexisting process detected")
79
80
class TestPidFileFunctions(unittest.TestCase):
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""

  def setUp(self):
    """Redirect PID files into a private temporary directory"""
    self.dir = tempfile.mkdtemp()
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
    # Remember the real function, tearDown has to undo the monkey-patching
    # so that the override does not leak into other test cases
    self._orig_dpn = utils.DaemonPidFileName
    utils.DaemonPidFileName = self.f_dpn

  def testPidFileFunctions(self):
    """Exercise writing, reading and removing a PID file"""
    pid_file = self.f_dpn('test')
    utils.WritePidFile('test')
    self.failUnless(os.path.exists(pid_file),
                    "PID file should have been created")
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, os.getpid())
    self.failUnless(utils.IsProcessAlive(read_pid))
    # A second write for the same name is expected to be refused
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for missing pid file")
    # A file not containing a number must be read as PID 0 as well
    fh = open(pid_file, "w")
    try:
      fh.write("blah\n")
    finally:
      fh.close()
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for invalid pid file")
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")

  def testKill(self):
    """Check KillProcess against a forked child which wrote its PID file"""
    pid_file = self.f_dpn('child')
    r_fd, w_fd = os.pipe()
    new_pid = os.fork()
    if new_pid == 0: #child
      try:
        utils.WritePidFile('child')
        os.write(w_fd, 'a')
        signal.pause()
      finally:
        # Never let the child return into the test runner, not even when
        # one of the calls above raises
        os._exit(0)
    # else we are in the parent
    # Drop our copy of the write end, so that the read below returns EOF
    # instead of blocking forever should the child die before writing
    os.close(w_fd)
    try:
      # wait until the child has written the pid file
      os.read(r_fd, 1)
    finally:
      os.close(r_fd)
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, new_pid)
    self.failUnless(utils.IsProcessAlive(new_pid))
    utils.KillProcess(new_pid, waitpid=True)
    self.failIf(utils.IsProcessAlive(new_pid))
    utils.RemovePidFile('child')
    # PID 0 must be rejected
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)

  def tearDown(self):
    """Undo the monkey-patching and remove the temporary directory"""
    utils.DaemonPidFileName = self._orig_dpn
    for name in os.listdir(self.dir):
      os.unlink(os.path.join(self.dir, name))
    os.rmdir(self.dir)
137
138
class TestRunCmd(testutils.GanetiTestCase):
  """Testing case for the RunCmd function"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    # Timestamp-based marker string echoed by the commands under test
    self.magic = time.ctime() + " ganeti test"
    # Temporary file used as target of the "output" parameter
    self.fname = self._CreateTempFile()

  def testOk(self):
    """Test successful exit code"""
    result = RunCmd("/bin/sh -c 'exit 0'")
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.output, "")

  def testFail(self):
    """Test fail exit code"""
    result = RunCmd("/bin/sh -c 'exit 1'")
    self.assertEqual(result.exit_code, 1)
    self.assertEqual(result.output, "")

  def testStdout(self):
    """Test standard output"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.stdout, self.magic)
    # With an output file the result carries no output, the file does
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testStderr(self):
    """Test standard error"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
    self.assertEqual(result.stderr, self.magic)
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testCombined(self):
    """Test combined output"""
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
    expected = "A" + self.magic + "B" + self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.output, expected)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, expected)

  def testSignal(self):
    """Test signal"""
    # The child sends signal 15 to itself
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
    self.assertEqual(result.signal, 15)
    self.assertEqual(result.output, "")

  def testListRun(self):
    """Test list runs"""
    result = RunCmd(["true"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 1)
    result = RunCmd(["echo", "-n", self.magic])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.stdout, self.magic)

  def testFileEmptyOutput(self):
    """Test file output"""
    result = RunCmd(["true"], output=self.fname)
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertFileContent(self.fname, "")

  def testLang(self):
    """Test locale environment

    Even with LANG and LC_ALL set to a UTF-8 locale in our environment,
    every variable printed by "locale" (except LANG and LANGUAGE) must be
    empty or "C".

    """
    old_env = os.environ.copy()
    try:
      os.environ["LANG"] = "en_US.UTF-8"
      os.environ["LC_ALL"] = "en_US.UTF-8"
      result = RunCmd(["locale"])
      for line in result.output.splitlines():
        key, value = line.split("=", 1)
        # Ignore these variables, they're overridden by LC_ALL
        if key == "LANG" or key == "LANGUAGE":
          continue
        self.failIf(value and value != "C" and value != '"C"',
            "Variable %s is set to the invalid value '%s'" % (key, value))
    finally:
      # Restore the mapping in place; rebinding os.environ to the plain
      # dict copy would stop later modifications from being propagated to
      # the process environment
      for key in list(os.environ.keys()):
        if key not in old_env:
          del os.environ[key]
      os.environ.update(old_env)

  def testDefaultCwd(self):
    """Test default working directory"""
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")

  def testCwd(self):
    """Test explicitly passed working directory"""
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
    cwd = os.getcwd()
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)

  def testResetEnv(self):
    """Test environment reset functionality"""
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
246
247
class TestRunParts(unittest.TestCase):
  """Testing case for the RunParts function"""

  def setUp(self):
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")

  def tearDown(self):
    shutil.rmtree(self.rundir)

  def _CreateFile(self, name, data="", executable=False):
    """Write a file into the run directory.

    @param name: file name, relative to the run directory
    @param data: contents of the file
    @param executable: whether to chmod the file to read+exec for the owner
    @return: the full path of the new file

    """
    path = os.path.join(self.rundir, name)
    utils.WriteFile(path, data=data)
    if executable:
      os.chmod(path, stat.S_IREAD | stat.S_IEXEC)
    return path

  def _Run(self):
    """Call RunParts on the run directory with a reset environment"""
    return RunParts(self.rundir, reset_env=True)

  def testEmpty(self):
    """Test on an empty dir"""
    self.failUnlessEqual(self._Run(), [])

  def testSkipWrongName(self):
    """Test that wrong files are skipped"""
    path = self._CreateFile("00test.dot", executable=True)
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.failUnlessEqual(self._Run(), expected)

  def testSkipNonExec(self):
    """Test that non executable files are skipped"""
    path = self._CreateFile("00test")
    expected = [(os.path.basename(path), constants.RUNPARTS_SKIP, None)]
    self.failUnlessEqual(self._Run(), expected)

  def testError(self):
    """Test error on a broken executable"""
    # An empty file marked executable can't be run
    path = self._CreateFile("00test", executable=True)
    (name, state, err) = self._Run()[0]
    self.failUnlessEqual(name, os.path.basename(path))
    self.failUnlessEqual(state, constants.RUNPARTS_ERR)
    self.failUnless(err)

  def testSorted(self):
    """Test executions are sorted"""
    # Deliberately created out of order
    paths = [self._CreateFile(name)
             for name in ["64test", "00test", "42test"]]

    results = self._Run()

    for path in sorted(paths):
      self.failUnlessEqual(os.path.basename(path), results.pop(0)[0])

  def testOk(self):
    """Test correct execution"""
    path = self._CreateFile("00test", data="#!/bin/sh\n\necho -n ciao",
                            executable=True)
    (name, state, res) = self._Run()[0]
    self.failUnlessEqual(name, os.path.basename(path))
    self.failUnlessEqual(state, constants.RUNPARTS_RUN)
    self.failUnlessEqual(res.stdout, "ciao")

  def testRunFail(self):
    """Test correct execution, with run failure"""
    path = self._CreateFile("00test", data="#!/bin/sh\n\nexit 1",
                            executable=True)
    (name, state, res) = self._Run()[0]
    self.failUnlessEqual(name, os.path.basename(path))
    self.failUnlessEqual(state, constants.RUNPARTS_RUN)
    self.failUnlessEqual(res.exit_code, 1)
    self.failUnless(res.failed)

  def testRunMix(self):
    """Test a directory mixing failing, skipped, broken and good files"""
    # 1st has errors in execution
    failing = self._CreateFile("00test", data="#!/bin/sh\n\nexit 1",
                               executable=True)
    # 2nd is skipped
    skipped = self._CreateFile("42test")
    # 3rd cannot execute properly
    broken = self._CreateFile("64test", executable=True)
    # 4th execs
    good = self._CreateFile("99test", data="#!/bin/sh\n\necho -n ciao",
                            executable=True)

    results = self._Run()

    (name, state, res) = results[0]
    self.failUnlessEqual(name, os.path.basename(failing))
    self.failUnlessEqual(state, constants.RUNPARTS_RUN)
    self.failUnlessEqual(res.exit_code, 1)
    self.failUnless(res.failed)

    (name, state, res) = results[1]
    self.failUnlessEqual(name, os.path.basename(skipped))
    self.failUnlessEqual(state, constants.RUNPARTS_SKIP)
    self.failUnlessEqual(res, None)

    (name, state, res) = results[2]
    self.failUnlessEqual(name, os.path.basename(broken))
    self.failUnlessEqual(state, constants.RUNPARTS_ERR)
    self.failUnless(res)

    (name, state, res) = results[3]
    self.failUnlessEqual(name, os.path.basename(good))
    self.failUnlessEqual(state, constants.RUNPARTS_RUN)
    self.failUnlessEqual(res.output, "ciao")
    self.failUnlessEqual(res.exit_code, 0)
    self.failUnless(not res.failed)
372
373
class TestRemoveFile(unittest.TestCase):
  """Test case for the RemoveFile function"""

  def setUp(self):
    """Create a temp dir and file for each case"""
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
    os.close(fd)

  def tearDown(self):
    """Remove the temp dir together with anything a test left behind"""
    # A recursive removal keeps a failed test (e.g. a symlink which was
    # not removed) from causing a second, misleading error during cleanup
    shutil.rmtree(self.tmpdir)

  def testIgnoreDirs(self):
    """Test that RemoveFile() ignores directories"""
    self.assertEqual(None, RemoveFile(self.tmpdir))

  def testIgnoreNotExisting(self):
    """Test that RemoveFile() ignores non-existing files"""
    RemoveFile(self.tmpfile)
    # The file is gone now, the second call must not raise
    RemoveFile(self.tmpfile)

  def testRemoveFile(self):
    """Test that RemoveFile does remove a file"""
    RemoveFile(self.tmpfile)
    if os.path.exists(self.tmpfile):
      self.fail("File '%s' not removed" % self.tmpfile)

  def testRemoveSymlink(self):
    """Test that RemoveFile does remove symlinks"""
    symlink = self.tmpdir + "/symlink"
    # Dangling symlink: os.path.exists() follows links and would report
    # False here even if the link were still present, hence lexists()
    os.symlink("no-such-file", symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
    # Symlink pointing to an existing file
    os.symlink(self.tmpfile, symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
418
419
class TestRename(unittest.TestCase):
  """Test case for RenameFile"""

  def setUp(self):
    """Create a temporary directory containing one empty file"""
    self.tmpdir = tempfile.mkdtemp()
    self.tmpfile = self._InTmpDir("test1")

    # Touch the file
    open(self.tmpfile, "w").close()

  def tearDown(self):
    """Remove temporary directory"""
    shutil.rmtree(self.tmpdir)

  def _InTmpDir(self, relpath):
    """Return C{relpath} joined onto the temporary directory"""
    return os.path.join(self.tmpdir, relpath)

  def testSimpleRename1(self):
    """Simple rename 1"""
    target = self._InTmpDir("xyz")
    utils.RenameFile(self.tmpfile, target)
    self.assert_(os.path.isfile(self._InTmpDir("xyz")))

  def testSimpleRename2(self):
    """Simple rename 2"""
    target = self._InTmpDir("xyz")
    utils.RenameFile(self.tmpfile, target, mkdir=True)
    self.assert_(os.path.isfile(self._InTmpDir("xyz")))

  def testRenameMkdir(self):
    """Rename with mkdir"""
    # One missing directory level
    first = self._InTmpDir("test/xyz")
    utils.RenameFile(self.tmpfile, first, mkdir=True)
    self.assert_(os.path.isdir(self._InTmpDir("test")))
    self.assert_(os.path.isfile(self._InTmpDir("test/xyz")))

    # Several missing directory levels
    second = self._InTmpDir("test/foo/bar/baz")
    utils.RenameFile(first, second, mkdir=True)
    self.assert_(os.path.isdir(self._InTmpDir("test")))
    self.assert_(os.path.isdir(self._InTmpDir("test/foo/bar")))
    self.assert_(os.path.isfile(self._InTmpDir("test/foo/bar/baz")))
459
460
class TestMatchNameComponent(unittest.TestCase):
  """Test case for the MatchNameComponent function"""

  def testEmptyList(self):
    """Test that there is no match against an empty list"""
    for key in ("", "test"):
      self.failUnlessEqual(MatchNameComponent(key, []), None)

  def testSingleMatch(self):
    """Test that a single match is performed correctly"""
    names = ["test1.example.com", "test2.example.com", "test3.example.com"]
    wanted = names[1]
    for key in ("test2", "test2.example", "test2.example.com"):
      self.failUnlessEqual(MatchNameComponent(key, names), wanted)

  def testMultipleMatches(self):
    """Test that a multiple match is returned as None"""
    names = ["test1.example.com", "test1.example.org", "test1.example.net"]
    for key in ("test1", "test1.example"):
      self.failUnlessEqual(MatchNameComponent(key, names), None)

  def testFullMatch(self):
    """Test that a full match is returned correctly"""
    short = "test1"
    full = "test1.example"
    names = [full, full + ".com"]
    # The short key is a prefix of both entries
    self.failUnlessEqual(MatchNameComponent(short, names), None)
    # The full key is identical to one entry, which wins over the prefix
    self.failUnlessEqual(MatchNameComponent(full, names), full)

  def testCaseInsensitivePartialMatch(self):
    """Test for the case_insensitive keyword"""
    names = ["test1.example.com", "test2.example.net"]
    for key in ("test2", "Test2", "teSt2", "TeSt2"):
      found = MatchNameComponent(key, names, case_sensitive=False)
      self.assertEqual(found, "test2.example.net")

  def testCaseInsensitiveFullMatch(self):
    """Test full-name matches with case-insensitive matching enabled"""
    names = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
    expectations = [
      # Among the two ts1 entries only a full-name match is unambiguous
      ("Ts1", None),
      ("Ts1.ex", "ts1.ex"),
      ("ts1.ex", "ts1.ex"),
      # The two ts2 entries differ only in case, so only a key with
      # exactly matching case selects one of them
      ("ts2.ex", "ts2.ex"),
      ("Ts2.ex", "Ts2.ex"),
      ("TS2.ex", None),
      ]
    for (key, wanted) in expectations:
      found = MatchNameComponent(key, names, case_sensitive=False)
      self.assertEqual(found, wanted)
519
520
class TestReadFile(testutils.GanetiTestCase):
  """Test case for the ReadFile function"""

  def _CheckData(self, data, length, digest):
    """Verify length and MD5 hex digest of the given data"""
    self.assertEqual(len(data), length)
    self.assertEqual(md5.new(data).hexdigest(), digest)

  def testReadAll(self):
    """Read a complete test data file"""
    path = self._TestDataFilename("cert1.pem")
    self._CheckData(utils.ReadFile(path), 814,
                    "a491efb3efe56a0535f924d5f8680fd4")

  def testReadSize(self):
    """Read only the first 100 bytes of a test data file"""
    path = self._TestDataFilename("cert1.pem")
    self._CheckData(utils.ReadFile(path, size=100), 100,
                    "893772354e4e690b9efd073eed433ce7")

  def testError(self):
    """Reading an unusable path must raise EnvironmentError"""
    self.assertRaises(EnvironmentError, utils.ReadFile,
                      "/dev/null/does-not-exist")
543
544
class TestReadOneLineFile(testutils.GanetiTestCase):
  """Test case for the ReadOneLineFile function"""

  def _AssertFirstCertLine(self, line):
    """Check that C{line} is the first line of the test certificate"""
    self.assertEqual(len(line), 27)
    self.assertEqual(line, "-----BEGIN CERTIFICATE-----")

  def testDefault(self):
    """Without the strict argument a multi-line file yields its first line"""
    path = self._TestDataFilename("cert1.pem")
    self._AssertFirstCertLine(ReadOneLineFile(path))

  def testNotStrict(self):
    """With strict=False a multi-line file yields its first line"""
    path = self._TestDataFilename("cert1.pem")
    self._AssertFirstCertLine(ReadOneLineFile(path, strict=False))

  def testStrictFailure(self):
    """With strict=True a multi-line file is rejected"""
    path = self._TestDataFilename("cert1.pem")
    self.assertRaises(errors.GenericError, ReadOneLineFile, path,
                      strict=True)

  def testLongLine(self):
    """A long single line is returned unmodified in both modes"""
    content = "Hello World! " * 1024
    tmpfile = self._CreateTempFile()
    utils.WriteFile(tmpfile, data=content)
    strict_result = ReadOneLineFile(tmpfile, strict=True)
    lax_result = ReadOneLineFile(tmpfile, strict=False)
    self.assertEqual(content, strict_result)
    self.assertEqual(content, lax_result)

  def testNewline(self):
    """A trailing line ending, if any, is not part of the result"""
    tmpfile = self._CreateTempFile()
    text = "myline"
    for eol in ("", "\n", "\r\n"):
      utils.WriteFile(tmpfile, data=(text + eol))
      self.assertEqual(text, ReadOneLineFile(tmpfile, strict=False))
      self.assertEqual(text, ReadOneLineFile(tmpfile, strict=True))

  def testWhitespaceAndMultipleLines(self):
    """Whitespace is preserved, several lines are refused in strict mode"""
    tmpfile = self._CreateTempFile()
    for eol in ("", "\n", "\r\n"):
      for blanks in (" ", "\t", "\t\t  \t", "\t "):
        content = ("Foo bar baz %s%s" % (blanks, eol)) * 1024
        utils.WriteFile(tmpfile, data=content)
        lax_result = ReadOneLineFile(tmpfile, strict=False)

        if not eol:
          # Without line endings everything is one single (long) line
          strict_result = ReadOneLineFile(tmpfile, strict=True)
          self.assertEqual(content, strict_result)
          self.assertEqual(content, lax_result)
          continue

        # Many lines: strict mode fails, lax mode returns the first line
        # including its trailing whitespace, but without the line ending
        self.assert_(set("\r\n") & set(content))
        self.assertRaises(errors.GenericError, ReadOneLineFile,
                          tmpfile, strict=True)
        first_len = len("Foo bar baz ") + len(blanks)
        self.assertEqual(len(lax_result), first_len)
        self.assertEqual(lax_result, content[:first_len])
        self.assertFalse(set("\r\n") & set(lax_result))

  def testEmptylines(self):
    """Leading empty lines are skipped when looking for the line"""
    tmpfile = self._CreateTempFile()
    text = "myline"
    for eol in ("\n", "\r\n"):
      for extra in ("", "otherline"):
        content = "".join([eol, eol, text, eol, extra, eol])
        utils.WriteFile(tmpfile, data=content)
        self.assert_(set("\r\n") & set(content))
        self.assertEqual(text, ReadOneLineFile(tmpfile, strict=False))
        if extra:
          # A second non-empty line is an error in strict mode
          self.assertRaises(errors.GenericError, ReadOneLineFile,
                            tmpfile, strict=True)
        else:
          self.assertEqual(text, ReadOneLineFile(tmpfile, strict=True))
620
621
class TestTimestampForFilename(unittest.TestCase):
  """Test case for the TimestampForFilename function"""

  def test(self):
    """The generated timestamp must contain neither dots nor colons"""
    for char in (".", ":"):
      self.assert_(char not in utils.TimestampForFilename())
626
627
class TestCreateBackup(testutils.GanetiTestCase):
  """Test case for the CreateBackup function"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)

    shutil.rmtree(self.tmpdir)

  def _AssertFileCount(self, prefix, count):
    """Check how many paths start with the given prefix"""
    self.assertEqual(len(glob.glob("%s*" % prefix)), count)

  def testEmpty(self):
    """Back up an empty file several times, then try a FIFO"""
    path = utils.PathJoin(self.tmpdir, "config.data")
    utils.WriteFile(path, data="")

    backup = utils.CreateBackup(path)
    self.assertFileContent(backup, "")
    # The original file plus one backup
    self._AssertFileCount(path, 2)

    # Every further call has to add exactly one more file
    for total in (3, 4):
      utils.CreateBackup(path)
      self._AssertFileCount(path, total)

    # Something which is not a regular file is refused
    fifo = utils.PathJoin(self.tmpdir, "fifo")
    os.mkfifo(fifo)
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifo)

  def testContent(self):
    """Backups must have the same content as the file they are made from"""
    path = utils.PathJoin(self.tmpdir, "test.data_")
    # The same file name is reused, so backups pile up over all iterations
    backups = 0
    for chunk in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
      for factor in [1, 2, 10, 127]:
        payload = chunk * factor

        utils.WriteFile(path, data=payload)
        self.assertFileContent(path, payload)

        for _ in range(3):
          backup = utils.CreateBackup(path)
          backups += 1
          self.assertFileContent(backup, payload)
          self._AssertFileCount(path, 1 + backups)
669
670
class TestFormatUnit(unittest.TestCase):
  """Test case for the FormatUnit function"""

  def _CheckAll(self, units, cases):
    """Compare FormatUnit results against a list of expectations.

    @param units: the units argument handed to FormatUnit
    @param cases: list of (value, expected string) pairs

    """
    for (value, expected) in cases:
      self.assertEqual(FormatUnit(value, units), expected)

  def testMiB(self):
    """Values below 1024 in 'h' mode and arbitrary values in 'm' mode"""
    self._CheckAll('h', [
      (1, '1M'),
      (100, '100M'),
      (1023, '1023M'),
      ])

    self._CheckAll('m', [
      (1, '1'),
      (100, '100'),
      (1023, '1023'),
      (1024, '1024'),
      (1536, '1536'),
      (17133, '17133'),
      (1024 * 1024 - 1, '1048575'),
      ])

  def testGiB(self):
    """Values from 1024 up in 'h' mode and arbitrary values in 'g' mode"""
    self._CheckAll('h', [
      (1024, '1.0G'),
      (1536, '1.5G'),
      (17133, '16.7G'),
      (1024 * 1024 - 1, '1024.0G'),
      ])

    self._CheckAll('g', [
      (1024, '1.0'),
      (1536, '1.5'),
      (17133, '16.7'),
      (1024 * 1024 - 1, '1024.0'),
      (1024 * 1024, '1024.0'),
      (5120 * 1024, '5120.0'),
      (29829 * 1024, '29829.0'),
      ])

  def testTiB(self):
    """Values from 1024 * 1024 up in 'h' mode and in 't' mode"""
    self._CheckAll('h', [
      (1024 * 1024, '1.0T'),
      (5120 * 1024, '5.0T'),
      (29829 * 1024, '29.1T'),
      ])

    self._CheckAll('t', [
      (1024 * 1024, '1.0'),
      (5120 * 1024, '5.0'),
      (29829 * 1024, '29.1'),
      ])
711
class TestParseUnit(unittest.TestCase):
  """Test case for the ParseUnit function"""

  # Accepted suffixes together with the factor applied to the number
  SCALES = (('', 1),
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))

  def _CheckAll(self, cases):
    """Compare ParseUnit results against (input, expected) pairs"""
    for (text, expected) in cases:
      self.assertEqual(ParseUnit(text), expected)

  def testRounding(self):
    """Integer input is rounded up to the next multiple of four"""
    self._CheckAll([
      ('0', 0),
      ('1', 4),
      ('2', 4),
      ('3', 4),
      ('124', 124),
      ('125', 128),
      ('126', 128),
      ('127', 128),
      ('128', 128),
      ('129', 132),
      ('130', 132),
      ])

  def testFloating(self):
    """Fractional input is accepted and rounded up as well"""
    self._CheckAll([
      ('0', 0),
      ('0.5', 4),
      ('1.75', 4),
      ('1.99', 4),
      ('2.00', 4),
      ('2.01', 4),
      ('3.99', 4),
      ('4.00', 4),
      ('4.01', 8),
      ('1.5G', 1536),
      ('1.8G', 1844),
      ('8.28T', 8682212),
      ])

  def testSuffixes(self):
    """Suffixes work in any case and after optional whitespace"""
    transforms = (lambda x: x, str.lower, str.upper)
    for blank in ('', ' ', '   ', "\t", "\t "):
      for (suffix, factor) in TestParseUnit.SCALES:
        for transform in transforms:
          text = '1024' + blank + transform(suffix)
          self.assertEqual(ParseUnit(text), 1024 * factor)

  def testInvalidInput(self):
    """Invalid separators and number formats raise UnitParseError"""
    suffixes = [suffix for (suffix, _) in TestParseUnit.SCALES]

    for separator in ('-', '_', ',', 'a'):
      for suffix in suffixes:
        self.assertRaises(UnitParseError, ParseUnit,
                          '1' + separator + suffix)

    # A comma is no valid decimal separator
    for suffix in suffixes:
      self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
762
763
class TestSshKeys(testutils.GanetiTestCase):
  """Test case for the AddAuthorizedKey and RemoveAuthorizedKey functions"""

  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')

  # The lines expected in the file for the two keys written by setUp
  _LINE_A = "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
  _LINE_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
             " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")

  def setUp(self):
    """Create a temporary file containing both KEY_A and KEY_B"""
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    keyfile = open(self.tmpname, 'w')
    try:
      for key in (TestSshKeys.KEY_A, TestSshKeys.KEY_B):
        keyfile.write("%s\n" % key)
    finally:
      keyfile.close()

  def _AssertLines(self, *lines):
    """Check that the key file consists of exactly the given lines"""
    self.assertFileContent(self.tmpname, "".join(lines))

  def testAddingNewKey(self):
    """An unknown key is appended to the file"""
    AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')

    self._AssertLines(self._LINE_A, self._LINE_B,
                      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")

  def testAddingAlmostButNotCompletelyTheSameKey(self):
    """A key differing only in its comment is appended, too"""
    AddAuthorizedKey(self.tmpname,
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')

    self._AssertLines(self._LINE_A, self._LINE_B,
                      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")

  def testAddingExistingKeyWithSomeMoreSpaces(self):
    """A known key with extra whitespace leaves the file unchanged"""
    AddAuthorizedKey(self.tmpname,
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')

    self._AssertLines(self._LINE_A, self._LINE_B)

  def testRemovingExistingKeyWithSomeMoreSpaces(self):
    """A known key is removed even when given with extra whitespace"""
    RemoveAuthorizedKey(self.tmpname,
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')

    self._AssertLines(self._LINE_B)

  def testRemovingNonExistingKey(self):
    """Removing an unknown key leaves the file unchanged"""
    RemoveAuthorizedKey(self.tmpname,
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')

    self._AssertLines(self._LINE_A, self._LINE_B)
825
826
827 class TestEtcHosts(testutils.GanetiTestCase):
828   """Test functions modifying /etc/hosts"""
829
830   def setUp(self):
831     testutils.GanetiTestCase.setUp(self)
832     self.tmpname = self._CreateTempFile()
833     handle = open(self.tmpname, 'w')
834     try:
835       handle.write('# This is a test file for /etc/hosts\n')
836       handle.write('127.0.0.1\tlocalhost\n')
837       handle.write('192.168.1.1 router gw\n')
838     finally:
839       handle.close()
840
841   def testSettingNewIp(self):
842     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
843
844     self.assertFileContent(self.tmpname,
845       "# This is a test file for /etc/hosts\n"
846       "127.0.0.1\tlocalhost\n"
847       "192.168.1.1 router gw\n"
848       "1.2.3.4\tmyhost.domain.tld myhost\n")
849     self.assertFileMode(self.tmpname, 0644)
850
851   def testSettingExistingIp(self):
852     SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
853                      ['myhost'])
854
855     self.assertFileContent(self.tmpname,
856       "# This is a test file for /etc/hosts\n"
857       "127.0.0.1\tlocalhost\n"
858       "192.168.1.1\tmyhost.domain.tld myhost\n")
859     self.assertFileMode(self.tmpname, 0644)
860
861   def testSettingDuplicateName(self):
862     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
863
864     self.assertFileContent(self.tmpname,
865       "# This is a test file for /etc/hosts\n"
866       "127.0.0.1\tlocalhost\n"
867       "192.168.1.1 router gw\n"
868       "1.2.3.4\tmyhost\n")
869     self.assertFileMode(self.tmpname, 0644)
870
871   def testRemovingExistingHost(self):
872     RemoveEtcHostsEntry(self.tmpname, 'router')
873
874     self.assertFileContent(self.tmpname,
875       "# This is a test file for /etc/hosts\n"
876       "127.0.0.1\tlocalhost\n"
877       "192.168.1.1 gw\n")
878     self.assertFileMode(self.tmpname, 0644)
879
880   def testRemovingSingleExistingHost(self):
881     RemoveEtcHostsEntry(self.tmpname, 'localhost')
882
883     self.assertFileContent(self.tmpname,
884       "# This is a test file for /etc/hosts\n"
885       "192.168.1.1 router gw\n")
886     self.assertFileMode(self.tmpname, 0644)
887
888   def testRemovingNonExistingHost(self):
889     RemoveEtcHostsEntry(self.tmpname, 'myhost')
890
891     self.assertFileContent(self.tmpname,
892       "# This is a test file for /etc/hosts\n"
893       "127.0.0.1\tlocalhost\n"
894       "192.168.1.1 router gw\n")
895     self.assertFileMode(self.tmpname, 0644)
896
897   def testRemovingAlias(self):
898     RemoveEtcHostsEntry(self.tmpname, 'gw')
899
900     self.assertFileContent(self.tmpname,
901       "# This is a test file for /etc/hosts\n"
902       "127.0.0.1\tlocalhost\n"
903       "192.168.1.1 router\n")
904     self.assertFileMode(self.tmpname, 0644)
905
906
class TestShellQuoting(unittest.TestCase):
  """Test case for shell quoting functions"""

  def testShellQuote(self):
    # (input, expected quoted form)
    cases = [
      ('abc', "abc"),
      ('ab"c', "'ab\"c'"),
      ("a'bc", "'a'\\''bc'"),
      ("a b c", "'a b c'"),
      ("a b\\ c", "'a b\\ c'"),
      ]
    for (value, quoted) in cases:
      self.assertEqual(ShellQuote(value), quoted)

  def testShellQuoteArgs(self):
    # (argument list, expected command line fragment)
    cases = [
      (['a', 'b', 'c'], "a b c"),
      (['a', 'b"', 'c'], "a 'b\"' c"),
      (['a', 'b\'', 'c'], "a 'b'\\\''' c"),
      ]
    for (args, quoted) in cases:
      self.assertEqual(ShellQuoteArgs(args), quoted)
921
922
class TestTcpPing(unittest.TestCase):
  """Testcase for TCP version of ping - against listen(2)ing port"""

  def setUp(self):
    # Bind to port 0 and read the port actually assigned back from the
    # socket, so that the test doesn't depend on a fixed port number
    self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
    self.listenerport = self.listener.getsockname()[1]
    self.listener.listen(1)

  def tearDown(self):
    # Close the socket explicitly instead of relying on garbage collection
    # of the deleted attribute to release the file descriptor
    try:
      self.listener.shutdown(socket.SHUT_RDWR)
    finally:
      self.listener.close()
    del self.listener
    del self.listenerport

  def testTcpPingToLocalHostAccept(self):
    """Ping the listening port with and without an explicit source"""
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.listenerport,
                         timeout=10,
                         live_port_needed=True,
                         source=constants.LOCALHOST_IP_ADDRESS,
                         ),
                 "failed to connect to test listener")

    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.listenerport,
                         timeout=10,
                         live_port_needed=True,
                         ),
                 "failed to connect to test listener (no source)")
952
953
class TestTcpPingDeaf(unittest.TestCase):
  """Testcase for TCP version of ping - against non listen(2)ing port"""

  def setUp(self):
    # The socket is bound but listen() is never called on it; the port
    # assigned by the bind to port 0 is read back from the socket
    self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
    self.deaflistenerport = self.deaflistener.getsockname()[1]

  def tearDown(self):
    # Close the socket explicitly instead of relying on garbage collection
    # of the deleted attribute to release the file descriptor
    self.deaflistener.close()
    del self.deaflistener
    del self.deaflistenerport

  def testTcpPingToLocalHostAcceptDeaf(self):
    """With live_port_needed=True the ping must be reported as failed"""
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                        self.deaflistenerport,
                        timeout=constants.TCP_PING_TIMEOUT,
                        live_port_needed=True,
                        source=constants.LOCALHOST_IP_ADDRESS,
                        ), # need successful connect(2)
                "successfully connected to deaf listener")

    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                        self.deaflistenerport,
                        timeout=constants.TCP_PING_TIMEOUT,
                        live_port_needed=True,
                        ), # need successful connect(2)
                "successfully connected to deaf listener (no source addr)")

  def testTcpPingToLocalHostNoAccept(self):
    """With live_port_needed=False the ping must be reported as successful"""
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         source=constants.LOCALHOST_IP_ADDRESS,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port")

    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port (no source addr)")
997
998
class TestOwnIpAddress(unittest.TestCase):
  """Testcase for OwnIpAddress"""

  def testOwnLoopback(self):
    """check having the loopback ip"""
    owned = OwnIpAddress(constants.LOCALHOST_IP_ADDRESS)
    self.failUnless(owned, "Should own the loopback address")

  def testNowOwnAddress(self):
    """check that I don't own an address"""

    # Network 192.0.2.0/24 is reserved for test/documentation as per
    # RFC 5735, so we *should* not have an address of this range... if
    # this fails, we should extend the test to multiple addresses
    DST_IP = "192.0.2.1"
    message = "Should not own IP address %s" % DST_IP
    self.failIf(OwnIpAddress(DST_IP), message)
1015
1016
def _GetSocketCredentials(path):
  """Connect to a Unix socket and return remote credentials.

  @param path: filesystem path of the Unix socket to connect to
  @return: whatever L{utils.GetSocketCredentials} returns for the
      connected socket

  """
  client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  try:
    client.settimeout(10)
    client.connect(path)
    credentials = utils.GetSocketCredentials(client)
  finally:
    # The socket is only needed for the query itself
    client.close()

  return credentials
1028
1029
class TestGetSocketCredentials(unittest.TestCase):
  """Test case for utils.GetSocketCredentials

  A forked child connects to a Unix socket the parent listens on, queries
  the credentials of its peer and sends them back through a pipe; the
  parent compares them with its own pid, uid and gid.

  """

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()
    self.sockpath = utils.PathJoin(self.tmpdir, "sock")

    server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    server.settimeout(10)
    server.bind(self.sockpath)
    server.listen(1)
    self.listener = server

  def tearDown(self):
    self.listener.shutdown(socket.SHUT_RDWR)
    self.listener.close()
    shutil.rmtree(self.tmpdir)

  def test(self):
    # Pipe used by the child to report its result to the parent
    (read_fd, write_fd) = os.pipe()

    # Start child process
    child_pid = os.fork()
    if child_pid == 0:
      try:
        credentials = _GetSocketCredentials(self.sockpath)

        os.write(write_fd, serializer.DumpJson(credentials))
        os.close(write_fd)

        os._exit(0)
      finally:
        # Only reached if something above raised; never return into the
        # test runner from the child
        os._exit(1)

    # The parent only reads from the pipe
    os.close(write_fd)

    # Wait for one connection
    connection = self.listener.accept()[0]
    connection.recv(1)
    connection.close()

    # Wait for result
    payload = os.read(read_fd, 4096)
    os.close(read_fd)

    # Check child's exit code
    exit_status = os.waitpid(child_pid, 0)[1]
    self.assertFalse(os.WIFSIGNALED(exit_status))
    self.assertEqual(os.WEXITSTATUS(exit_status), 0)

    # Check result
    (peer_pid, peer_uid, peer_gid) = serializer.LoadJson(payload)
    self.assertEqual(peer_pid, os.getpid())
    self.assertEqual(peer_uid, os.getuid())
    self.assertEqual(peer_gid, os.getgid())
1082
1083
class TestListVisibleFiles(unittest.TestCase):
  """Test case for ListVisibleFiles"""

  def setUp(self):
    self.path = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.path)

  def _test(self, files, expected):
    """Create files in the temporary directory and compare the listing.

    @param files: names of the files to create
    @param expected: names expected from ListVisibleFiles, in any order

    """
    for filename in files:
      handle = open(os.path.join(self.path, filename), 'w')
      try:
        handle.write("Test\n")
      finally:
        handle.close()

    listing = ListVisibleFiles(self.path)
    listing.sort()

    # Compare against a sorted copy, leaving the caller's list untouched
    self.assertEqual(listing, sorted(expected))

  def testAllVisible(self):
    names = ["a", "b", "c"]
    self._test(names, names)

  def testNoneVisible(self):
    self._test([".a", ".b", ".c"], [])

  def testSomeVisible(self):
    self._test(["a", "b", ".c"], ["a", "b"])

  def testNonAbsolutePath(self):
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")

  def testNonNormalizedPath(self):
    unnormalized = "/bin/../tmp"
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
                          unnormalized)
1131
1132
class TestNewUUID(unittest.TestCase):
  """Test case for NewUUID"""

  # Five dash-separated groups of 8, 4, 4, 4 and 12 lowercase hex digits
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
                        '[a-f0-9]{4}-[a-f0-9]{12}$')

  def runTest(self):
    generated = utils.NewUUID()
    self.failUnless(self._re_uuid.match(generated))
1141
1142
class TestUniqueSequence(unittest.TestCase):
  """Test case for UniqueSequence"""

  def _test(self, seq, expected):
    """Check the result of UniqueSequence for a single input.

    @param seq: sequence passed to utils.UniqueSequence
    @param expected: expected return value

    """
    self.assertEqual(utils.UniqueSequence(seq), expected)

  def runTest(self):
    cases = [
      # Ordered input
      ([1, 2, 3], [1, 2, 3]),
      ([1, 1, 2, 2, 3, 3], [1, 2, 3]),
      ([1, 2, 2, 3], [1, 2, 3]),
      ([1, 2, 3, 3], [1, 2, 3]),

      # Unordered input
      ([1, 2, 3, 1, 2, 3], [1, 2, 3]),
      ([1, 1, 2, 3, 3, 1, 2], [1, 2, 3]),

      # Strings
      (["a", "a"], ["a"]),
      (["a", "b"], ["a", "b"]),
      (["a", "b", "a"], ["a", "b"]),
      ]

    for (seq, expected) in cases:
      self._test(seq, expected)
1164
1165
class TestFirstFree(unittest.TestCase):
  """Test case for the FirstFree function"""

  def test(self):
    """Test FirstFree"""
    # Calls using the default base
    for (used, free) in [([0, 1, 3], 2), ([], None), ([3, 4, 6], 0)]:
      self.failUnlessEqual(FirstFree(used), free)

    # Calls with an explicit base
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1176
1177
class TestTailFile(testutils.GanetiTestCase):
  """Test case for the TailFile function"""

  def _CreateFileWithContent(self, content):
    """Create a temporary file holding the given content.

    The file handle is closed even if writing fails, matching the
    try/finally handling used by the other test cases in this module.

    @param content: string to write to the new file
    @return: the name of the temporary file

    """
    fname = self._CreateTempFile()
    fd = open(fname, "w")
    try:
      fd.write(content)
    finally:
      fd.close()
    return fname

  def testEmpty(self):
    fname = self._CreateTempFile()
    self.failUnlessEqual(TailFile(fname), [])
    self.failUnlessEqual(TailFile(fname, lines=25), [])

  def testAllLines(self):
    data = ["test %d" % i for i in range(30)]
    for i in range(30):
      # Every line is newline-terminated; for i == 0 the file is empty
      content = "".join(["%s\n" % line for line in data[:i]])
      fname = self._CreateFileWithContent(content)
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])

  def testPartialLines(self):
    data = ["test %d" % i for i in range(30)]
    fname = self._CreateFileWithContent("\n".join(data) + "\n")
    for i in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])

  def testBigFile(self):
    data = ["test %d" % i for i in range(30)]
    # A single 1 MiB line precedes the lines that are actually requested
    content = "X" * 1048576 + "\n" + "\n".join(data) + "\n"
    fname = self._CreateFileWithContent(content)
    for i in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1218
1219
class _BaseFileLockTest:
  """Test case for the FileLock class

  This is a mix-in holding the tests shared by the concrete test cases.
  The class it is mixed into must provide two attributes, normally from
  its setUp method:
    - self.lock: the lock object under test
    - self.tmpfile: an object whose "name" attribute is the path of the
      lock file, used to open the same file from a separate process

  """

  def testSharedNonblocking(self):
    self.lock.Shared(blocking=False)
    self.lock.Close()

  def testExclusiveNonblocking(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Close()

  def testUnlockNonblocking(self):
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSharedBlocking(self):
    self.lock.Shared(blocking=True)
    self.lock.Close()

  def testExclusiveBlocking(self):
    self.lock.Exclusive(blocking=True)
    self.lock.Close()

  def testUnlockBlocking(self):
    self.lock.Unlock(blocking=True)
    self.lock.Close()

  def testSharedExclusiveUnlock(self):
    self.lock.Shared(blocking=False)
    self.lock.Exclusive(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testExclusiveSharedUnlock(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Shared(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSimpleTimeout(self):
    # These will succeed on the first attempt, hence a short timeout
    self.lock.Shared(blocking=True, timeout=10.0)
    self.lock.Exclusive(blocking=False, timeout=10.0)
    self.lock.Unlock(blocking=True, timeout=10.0)
    self.lock.Close()

  @staticmethod
  def _TryLockInner(filename, shared, blocking):
    """Try to acquire a lock on a file through a newly opened FileLock.

    @param filename: path of the file to lock
    @param shared: whether to request a shared instead of an exclusive lock
    @param blocking: passed through to the lock function
    @return: False if acquiring raised errors.LockError, True otherwise

    """
    lock = utils.FileLock.Open(filename)

    if shared:
      fn = lock.Shared
    else:
      fn = lock.Exclusive

    try:
      # The timeout doesn't really matter as the parent process waits for us to
      # finish anyway.
      fn(blocking=blocking, timeout=0.01)
    except errors.LockError:
      # The exception object itself is not needed, only the outcome
      return False

    return True

  def _TryLock(self, *args):
    """Run L{_TryLockInner} on the lock file in a separate process.

    @param args: the "shared" and "blocking" arguments for _TryLockInner
    @return: the value returned by utils.RunInSeparateProcess

    """
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
                                      *args)

  def testTimeout(self):
    for blocking in [True, False]:
      # While we hold an exclusive lock, no other process may lock the file
      self.lock.Exclusive(blocking=True)
      self.failIf(self._TryLock(False, blocking))
      self.failIf(self._TryLock(True, blocking))

      # While we hold a shared lock, only shared locks may be acquired
      self.lock.Shared(blocking=True)
      self.assert_(self._TryLock(True, blocking))
      self.failIf(self._TryLock(False, blocking))

  def testCloseShared(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)

  def testCloseExclusive(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)

  def testCloseUnlock(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1309
1310
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
  """Run the FileLock tests on a lock created with utils.FileLock.Open

  The lock file is pre-filled with TESTDATA and its content is verified
  both after opening the lock and after every test.

  """
  TESTDATA = "Hello World\n" * 10

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpfile = tempfile.NamedTemporaryFile()
    lockfile = self.tmpfile.name
    utils.WriteFile(lockfile, data=self.TESTDATA)
    self.lock = utils.FileLock.Open(lockfile)

    # Ensure "Open" didn't truncate file
    self.assertFileContent(lockfile, self.TESTDATA)

  def tearDown(self):
    # The lock operations run by the test must not have changed the file
    lockfile = self.tmpfile.name
    self.assertFileContent(lockfile, self.TESTDATA)

    testutils.GanetiTestCase.tearDown(self)
1328
1329
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
  """Run the FileLock tests on a lock built from an open file object"""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()
    filename = self.tmpfile.name
    handle = open(filename, "w")
    self.lock = utils.FileLock(handle, filename)
1334
1335
class TestTimeFunctions(unittest.TestCase):
  """Test case for time functions"""

  def runTest(self):
    # (float timestamp, expected (seconds, microseconds) tuple)
    split_cases = [
      (1, (1, 0)),
      (1.5, (1, 500000)),
      (1218448917.4809151, (1218448917, 480915)),
      (123.48012, (123, 480120)),
      (123.9996, (123, 999600)),
      (123.9995, (123, 999500)),
      (123.9994, (123, 999400)),
      (123.999999999, (123, 999999)),
      ]
    for (timestamp, parts) in split_cases:
      self.assertEqual(utils.SplitTime(timestamp), parts)

    self.assertRaises(AssertionError, utils.SplitTime, -1)

    # ((seconds, microseconds) tuple, expected float timestamp)
    merge_cases = [
      ((1, 0), 1.0),
      ((1, 500000), 1.5),
      ((1218448917, 500000), 1218448917.5),
      ]
    for (parts, timestamp) in merge_cases:
      self.assertEqual(utils.MergeTime(parts), timestamp)

    # These results are only compared after rounding to three decimals
    rounded_cases = [
      ((1218448917, 481000), 1218448917.481),
      ((1, 801000), 1.801),
      ]
    for (parts, timestamp) in rounded_cases:
      self.assertEqual(round(utils.MergeTime(parts), 3), timestamp)

    # Tuples which must be rejected
    invalid_parts = [
      (0, -1),
      (0, 1000000),
      (0, 9999999),
      (-1, 0),
      (-9999, 0),
      ]
    for parts in invalid_parts:
      self.assertRaises(AssertionError, utils.MergeTime, parts)
1364
1365
class FieldSetTestCase(unittest.TestCase):
  """Test case for FieldSets"""

  def testSimpleMatch(self):
    fields = utils.FieldSet("a", "b", "c", "def")
    self.failUnless(fields.Matches("a"))
    self.failIf(fields.Matches("d"), "Substring matched")
    self.failIf(fields.Matches("defghi"), "Prefix string matched")

    # Lists made only of known fields have no non-matching entry
    for names in [["b", "c"], ["a", "b", "c", "def"]]:
      self.failIf(fields.NonMatching(names))
    self.failUnless(fields.NonMatching(["a", "d"]))

  def testRegexMatch(self):
    fields = utils.FieldSet("a", "b([0-9]+)", "c")
    for name in ["b1", "b99"]:
      self.failUnless(fields.Matches(name))
    self.failIf(fields.Matches("b/1"))
    self.failIf(fields.NonMatching(["b12", "c"]))
    self.failUnless(fields.NonMatching(["a", "1"]))
1385
class TestForceDictType(unittest.TestCase):
  """Test case for ForceDictType"""

  def setUp(self):
    self.key_types = {
      'a': constants.VTYPE_INT,
      'b': constants.VTYPE_BOOL,
      'c': constants.VTYPE_STRING,
      'd': constants.VTYPE_SIZE,
      }

  def _fdt(self, data, allowed_values=None):
    """Run ForceDictType on a dictionary using the test's key types.

    @param data: dictionary to convert; it is passed to ForceDictType as is
    @param allowed_values: only passed on to ForceDictType if not None
    @return: the same dictionary object that was passed in

    """
    kwargs = {}
    if allowed_values is not None:
      kwargs["allowed_values"] = allowed_values

    ForceDictType(data, self.key_types, **kwargs)

    return data

  def testSimpleDict(self):
    # (input dictionary, expected dictionary after conversion)
    cases = [
      ({}, {}),
      ({'a': 1}, {'a': 1}),
      ({'a': '1'}, {'a': 1}),
      ({'a': 1, 'b': 1}, {'a':1, 'b': True}),
      ({'b': 1, 'c': 'foo'}, {'b': True, 'c': 'foo'}),
      ({'b': 1, 'c': False}, {'b': True, 'c': ''}),
      ({'b': 'false'}, {'b': False}),
      ({'b': 'False'}, {'b': False}),
      ({'b': 'true'}, {'b': True}),
      ({'b': 'True'}, {'b': True}),
      ({'d': '4'}, {'d': 4}),
      ({'d': '4M'}, {'d': 4}),
      ]
    for (source, converted) in cases:
      self.assertEqual(self._fdt(source), converted)

  def testErrors(self):
    invalid = [
      {'a': 'astring'},
      {'c': True},
      {'d': 'astring'},
      {'d': '4 L'},
      ]
    for source in invalid:
      self.assertRaises(errors.TypeEnforcementError, self._fdt, source)
1424
1425
class TestIsNormAbsPath(unittest.TestCase):
  """Testing case for IsNormAbsPath"""

  def _pathTestHelper(self, path, result):
    """Check the outcome of IsNormAbsPath for one path.

    @param path: the path to check
    @param result: whether IsNormAbsPath is expected to accept the path

    """
    accepted = IsNormAbsPath(path)
    if result:
      self.assert_(accepted,
          "Path %s should result absolute and normalized" % path)
    else:
      self.assert_(not accepted,
          "Path %s should not result absolute and normalized" % path)

  def testBase(self):
    cases = [
      ('/etc', True),
      ('/srv', True),
      ('etc', False),
      ('/etc/../root', False),
      ('/etc/', False),
      ]
    for (path, result) in cases:
      self._pathTestHelper(path, result)
1443
1444
class TestSafeEncode(unittest.TestCase):
  """Test case for SafeEncode"""

  def testAscii(self):
    """Digits, letters and punctuation must be returned unchanged"""
    for txt in [string.digits, string.letters, string.punctuation]:
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testDoubleEncode(self):
    """Encoding an already encoded single-byte string must not change it"""
    # range(256) covers every possible byte value, including 0xff
    for i in range(256):
      txt = SafeEncode(chr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testUnicode(self):
    """Encoding an already encoded unicode character must not change it"""
    # 1024 is high enough to catch non-direct ASCII mappings
    for i in range(1024):
      txt = SafeEncode(unichr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))
1462
1463
class TestFormatTime(unittest.TestCase):
  """Testing case for FormatTime"""

  def testNone(self):
    formatted = FormatTime(None)
    self.failUnlessEqual(formatted, "N/A")

  def testInvalid(self):
    formatted = FormatTime(())
    self.failUnlessEqual(formatted, "N/A")

  def testNow(self):
    # Only checks that both calls complete without raising
    # tests that we accept time.time input
    FormatTime(time.time())
    # tests that we accept int input
    FormatTime(int(time.time()))
1478
1479
class RunInSeparateProcess(unittest.TestCase):
  """Test case for utils.RunInSeparateProcess"""

  def test(self):
    for expected in [True, False]:
      # The callable is run right away, within the same loop iteration
      result = utils.RunInSeparateProcess(lambda: expected)

      self.assertEqual(expected, result)

  def testArgs(self):
    for value in [0, 1, 999, "Hello World", (1, 2, 3)]:
      def _CheckArgs(first, second):
        return first == "Foo" and second == value

      self.assert_(utils.RunInSeparateProcess(_CheckArgs, "Foo", value))

  def testPid(self):
    own_pid = os.getpid()

    def _SamePid():
      return os.getpid() == own_pid

    # The function must see a pid different from ours
    self.failIf(utils.RunInSeparateProcess(_SamePid))

  def testSignal(self):
    def _KillSelf():
      os.kill(os.getpid(), signal.SIGTERM)

    self.assertRaises(errors.GenericError,
                      utils.RunInSeparateProcess, _KillSelf)

  def testException(self):
    def _Raise():
      raise errors.GenericError("This is a test")

    self.assertRaises(errors.GenericError,
                      utils.RunInSeparateProcess, _Raise)
1516
1517
class TestFingerprintFile(unittest.TestCase):
  """Test case for utils._FingerprintFile"""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()

  def test(self):
    path = self.tmpfile.name

    # Fingerprint of the newly created temporary file
    empty_fingerprint = utils._FingerprintFile(path)
    self.assertEqual(empty_fingerprint,
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")

    # Fingerprint after writing some data
    utils.WriteFile(path, data="Hello World\n")
    data_fingerprint = utils._FingerprintFile(path)
    self.assertEqual(data_fingerprint,
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1529
1530
class TestUnescapeAndSplit(unittest.TestCase):
  """Testing case for UnescapeAndSplit"""

  def setUp(self):
    # testing more that one separator for regexp safety
    self._seps = [",", "+", "."]

  def testSimple(self):
    fields = ["a", "b", "c", "d"]
    for sep in self._seps:
      joined = sep.join(fields)
      self.failUnlessEqual(UnescapeAndSplit(joined, sep=sep), fields)

  def testEscape(self):
    for sep in self._seps:
      # One backslash in front of a separator inside the second field
      escaped = ["a", "b\\" + sep + "c", "d"]
      expected = ["a", "b" + sep + "c", "d"]
      joined = sep.join(escaped)
      self.failUnlessEqual(UnescapeAndSplit(joined, sep=sep), expected)

  def testDoubleEscape(self):
    # Two backslashes at the end of the second field
    escaped = ["a", "b\\\\", "c", "d"]
    expected = ["a", "b\\", "c", "d"]
    for sep in self._seps:
      joined = sep.join(escaped)
      self.failUnlessEqual(UnescapeAndSplit(joined, sep=sep), expected)

  def testThreeEscape(self):
    for sep in self._seps:
      # Three backslashes in front of a separator inside the second field
      escaped = ["a", "b\\\\\\" + sep + "c", "d"]
      expected = ["a", "b\\" + sep + "c", "d"]
      joined = sep.join(escaped)
      self.failUnlessEqual(UnescapeAndSplit(joined, sep=sep), expected)
1560
1561
class TestPathJoin(unittest.TestCase):
  """Testing case for PathJoin"""

  def testBasicItems(self):
    components = ["/a", "b", "c"]
    expected = "/".join(components)
    self.failUnlessEqual(PathJoin(*components), expected)

  def testNonAbsPrefix(self):
    # First component is not absolute
    args = ("a", "b")
    self.failUnlessRaises(ValueError, PathJoin, *args)

  def testBackTrack(self):
    # Second component contains ".."
    args = ("/a", "b/../c")
    self.failUnlessRaises(ValueError, PathJoin, *args)

  def testMultiAbs(self):
    # Second component is absolute as well
    args = ("/a", "/b")
    self.failUnlessRaises(ValueError, PathJoin, *args)
1577
1578
class TestHostInfo(unittest.TestCase):
  """Testing case for HostInfo"""

  def testUppercase(self):
    name = "AbC.example.com"
    normalized = HostInfo.NormalizeName(name)
    self.failUnlessEqual(normalized, name.lower())

  def testTooLongName(self):
    name = "a.b." + "c" * 255
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, name)

  def testTrailingDot(self):
    name = "a.b.c"
    normalized = HostInfo.NormalizeName(name + ".")
    self.failUnlessEqual(normalized, name)

  def testInvalidName(self):
    invalid_names = [
      "a b",
      "a/b",
      ".a.b",
      "a..b",
      ]
    for name in invalid_names:
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, name)

  def testValidName(self):
    valid_names = [
      "a.b",
      "a-b",
      "a_b",
      "a.b.c",
      ]
    # Only checks that normalizing these names doesn't raise
    for name in valid_names:
      HostInfo.NormalizeName(name)
1613
1614
class TestParseAsn1Generalizedtime(unittest.TestCase):
  """Test case for utils._ParseAsn1Generalizedtime"""

  def test(self):
    parse = utils._ParseAsn1Generalizedtime

    # (input string, expected return value)
    valid = [
      # UTC
      ("19700101000000Z", 0),
      ("20100222174152Z", 1266860512),
      ("20380119031407Z", (2**31) - 1),

      # With offset
      ("20100222174152+0000", 1266860512),
      ("20100223131652+0000", 1266931012),
      ("20100223051808-0800", 1266931088),
      ("20100224002135+1100", 1266931295),
      ("19700101000000-0100", 3600),
      ]
    for (text, timestamp) in valid:
      self.assertEqual(parse(text), timestamp)

    invalid = [
      # Leap seconds are not supported by datetime.datetime
      "19841231235960+0000",
      "19920630235960+0000",

      # Errors
      "",
      "invalid",
      "20100222174152",
      "Mon Feb 22 17:47:02 UTC 2010",
      "2010-02-22 17:42:02",
      ]
    for text in invalid:
      self.assertRaises(ValueError, parse, text)
1651
1652
class TestGetX509CertValidity(testutils.GanetiTestCase):
  """Test case for utils.GetX509CertValidity"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    installed = distutils.version.LooseVersion(OpenSSL.__version__)

    # Test whether we have pyOpenSSL 0.7 or above
    self.pyopenssl0_7 = (installed >= "0.7")

    if not self.pyopenssl0_7:
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
                    " function correctly")

  def _LoadCert(self, name):
    """Load a PEM certificate from the test data.

    @param name: name of the test data file, passed to _ReadTestData
    @return: the object returned by OpenSSL.crypto.load_certificate

    """
    pem = self._ReadTestData(name)
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem)

  def test(self):
    cert = self._LoadCert("cert1.pem")
    validity = utils.GetX509CertValidity(cert)

    # With older pyOpenSSL versions both values are expected to be None
    if self.pyopenssl0_7:
      expected = (1266919967, 1267524767)
    else:
      expected = (None, None)

    self.assertEqual(validity, expected)
1676
1677
class TestMakedirs(unittest.TestCase):
  """Test case for utils.Makedirs"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testNonExisting(self):
    target = utils.PathJoin(self.tmpdir, "foo")
    utils.Makedirs(target)
    self.assert_(os.path.isdir(target))

  def testExisting(self):
    target = utils.PathJoin(self.tmpdir, "foo")
    # The directory already exists when Makedirs is called
    os.mkdir(target)
    utils.Makedirs(target)
    self.assert_(os.path.isdir(target))

  def testRecursiveNonExisting(self):
    target = utils.PathJoin(self.tmpdir, "foo/bar/baz")
    utils.Makedirs(target)
    self.assert_(os.path.isdir(target))

  def testRecursiveExisting(self):
    target = utils.PathJoin(self.tmpdir, "B/moo/xyz")
    self.assert_(not os.path.exists(target))
    # Only the first level of the hierarchy exists beforehand
    parent = utils.PathJoin(self.tmpdir, "B")
    os.mkdir(parent)
    utils.Makedirs(target)
    self.assert_(os.path.isdir(target))
1707
1708
1709 class TestRetry(testutils.GanetiTestCase):
1710   def setUp(self):
1711     testutils.GanetiTestCase.setUp(self)
1712     self.retries = 0
1713
1714   @staticmethod
1715   def _RaiseRetryAgain():
1716     raise utils.RetryAgain()
1717
1718   @staticmethod
1719   def _RaiseRetryAgainWithArg(args):
1720     raise utils.RetryAgain(*args)
1721
1722   def _WrongNestedLoop(self):
1723     return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
1724
1725   def _RetryAndSucceed(self, retries):
1726     if self.retries < retries:
1727       self.retries += 1
1728       raise utils.RetryAgain()
1729     else:
1730       return True
1731
1732   def testRaiseTimeout(self):
1733     self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1734                           self._RaiseRetryAgain, 0.01, 0.02)
1735     self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1736                           self._RetryAndSucceed, 0.01, 0, args=[1])
1737     self.failUnlessEqual(self.retries, 1)
1738
1739   def testComplete(self):
1740     self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
1741     self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
1742                          True)
1743     self.failUnlessEqual(self.retries, 2)
1744
1745   def testNestedLoop(self):
1746     try:
1747       self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
1748                             self._WrongNestedLoop, 0, 1)
1749     except utils.RetryTimeout:
1750       self.fail("Didn't detect inner loop's exception")
1751
1752   def testTimeoutArgument(self):
1753     retry_arg="my_important_debugging_message"
1754     try:
1755       utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
1756     except utils.RetryTimeout, err:
1757       self.failUnlessEqual(err.args, (retry_arg, ))
1758     else:
1759       self.fail("Expected timeout didn't happen")
1760
1761   def testRaiseInnerWithExc(self):
1762     retry_arg="my_important_debugging_message"
1763     try:
1764       try:
1765         utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
1766                     args=[[errors.GenericError(retry_arg, retry_arg)]])
1767       except utils.RetryTimeout, err:
1768         err.RaiseInner()
1769       else:
1770         self.fail("Expected timeout didn't happen")
1771     except errors.GenericError, err:
1772       self.failUnlessEqual(err.args, (retry_arg, retry_arg))
1773     else:
1774       self.fail("Expected GenericError didn't happen")
1775
1776   def testRaiseInnerWithMsg(self):
1777     retry_arg="my_important_debugging_message"
1778     try:
1779       try:
1780         utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
1781                     args=[[retry_arg, retry_arg]])
1782       except utils.RetryTimeout, err:
1783         err.RaiseInner()
1784       else:
1785         self.fail("Expected timeout didn't happen")
1786     except utils.RetryTimeout, err:
1787       self.failUnlessEqual(err.args, (retry_arg, retry_arg))
1788     else:
1789       self.fail("Expected RetryTimeout didn't happen")
1790
1791
class TestLineSplitter(unittest.TestCase):
  """Test utils.LineSplitter"""

  def test(self):
    """Checks what reaches the callback after write, flush and close.

    """
    after_flush = ["Hello World", "Foo", " Bar"]
    after_close = after_flush + [" BazMoo"]

    received = []
    splitter = utils.LineSplitter(received.append)

    splitter.write("Hello World\n")
    self.assertEqual(received, [])

    for chunk in ["Foo\n Bar\r\n ", "Baz", "Moo"]:
      splitter.write(chunk)
    self.assertEqual(received, [])

    splitter.flush()
    self.assertEqual(received, after_flush)

    splitter.close()
    self.assertEqual(received, after_close)

  def _testExtra(self, line, all_lines, p1, p2):
    """Line callback verifying the extra arguments before storing the line.

    """
    self.assertEqual(p1, 999)
    self.assertEqual(p2, "extra")
    all_lines.append(line)

  def testExtraArgsNoFlush(self):
    """Checks extra callback arguments and closing without a prior flush.

    """
    expected = ["", "", "Hello World", "Foo", " Bar", " BazMoo", "", "x"]

    received = []
    splitter = utils.LineSplitter(self._testExtra, received, 999, "extra")

    chunks = [
      "\n\nHello World\n",
      "Foo\n Bar\r\n ",
      "",
      "Baz",
      "Moo\n\nx\n",
      ]
    for chunk in chunks:
      splitter.write(chunk)

    self.assertEqual(received, [])
    splitter.close()
    self.assertEqual(received, expected)
1824
1825
class TestIgnoreSignals(unittest.TestCase):
  """Test the IgnoreSignals decorator"""

  @staticmethod
  def _Raise(exception):
    """Raises the given exception instance.

    """
    raise exception

  @staticmethod
  def _Return(rval):
    """Returns the given value unchanged.

    """
    return rval

  @staticmethod
  def _SocketError(code):
    """Builds a socket.error with both args and C{errno} set to C{code}.

    """
    err = socket.error(code, "Message")
    err.errno = code
    return err

  def testIgnoreSignals(self):
    sock_err_intr = self._SocketError(errno.EINTR)
    sock_err_inval = self._SocketError(errno.EINVAL)

    env_err_intr = EnvironmentError(errno.EINTR, "Message")
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")

    # Sanity check: without IgnoreSignals every one of these is raised
    plain_cases = [
      (socket.error, sock_err_intr),
      (socket.error, sock_err_inval),
      (EnvironmentError, env_err_intr),
      (EnvironmentError, env_err_inval),
      ]
    for (exc_class, exc) in plain_cases:
      self.assertRaises(exc_class, self._Raise, exc)

    # EINTR errors are swallowed and None is returned instead
    for exc in [sock_err_intr, env_err_intr]:
      self.assertEqual(utils.IgnoreSignals(self._Raise, exc), None)

    # Errors with another errno are still raised
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
                      sock_err_inval)
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
                      env_err_inval)

    # Return values of the wrapped function are passed through
    for value in [True, 33]:
      self.assertEqual(utils.IgnoreSignals(self._Return, value), value)
1860
1861
# Start the test program only when this file is executed as a script
if __name__ == "__main__":
  testutils.GanetiTestProgram()