Ensure temp files from RunCmd tests are removed
[ganeti-local] / test / ganeti.utils_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007, 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the utils module"""
23
24 import distutils.version
25 import errno
26 import fcntl
27 import glob
28 import os
29 import os.path
30 import re
31 import shutil
32 import signal
33 import socket
34 import stat
35 import string
36 import tempfile
37 import time
38 import unittest
39 import warnings
40 import OpenSSL
41 import random
42 import operator
43 from cStringIO import StringIO
44
45 import testutils
46 from ganeti import constants
47 from ganeti import compat
48 from ganeti import utils
49 from ganeti import errors
50 from ganeti.utils import RunCmd, RemoveFile, MatchNameComponent, FormatUnit, \
51      ParseUnit, ShellQuote, ShellQuoteArgs, ListVisibleFiles, FirstFree, \
52      TailFile, SafeEncode, FormatTime, UnescapeAndSplit, RunParts, PathJoin, \
53      ReadOneLineFile, SetEtcHostsEntry, RemoveEtcHostsEntry
54
55
class TestIsProcessAlive(unittest.TestCase):
  """Testing case for IsProcessAlive"""

  def testExists(self):
    """The current process must be reported as alive"""
    mypid = os.getpid()
    self.assert_(utils.IsProcessAlive(mypid), "can't find myself running")

  def testNotExisting(self):
    """A child that has exited and been reaped must not be reported alive"""
    # os.fork() raises OSError on failure rather than returning a negative
    # value, so only the child (0) and parent (> 0) cases exist
    pid_non_existing = os.fork()
    if pid_non_existing == 0:
      os._exit(0)
    # Reap the child so the PID no longer belongs to a zombie (this assumes
    # the PID is not reused by another process in the meantime)
    os.waitpid(pid_non_existing, 0)
    self.assertFalse(utils.IsProcessAlive(pid_non_existing),
                     "nonexisting process detected")
72
73
class TestGetProcStatusPath(unittest.TestCase):
  """Tests for the private _GetProcStatusPath helper"""

  def test(self):
    """The PID appears as a path component and distinct PIDs differ"""
    path_1234 = utils._GetProcStatusPath(1234)
    self.assertTrue("/1234/" in path_1234)
    (first, second) = [utils._GetProcStatusPath(pid) for pid in (1, 2)]
    self.assertNotEqual(first, second)
79
80
class TestIsProcessHandlingSignal(unittest.TestCase):
  """Tests for IsProcessHandlingSignal and its status-file helpers"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testParseSigsetT(self):
    """Hex signal masks decode to sets of signal numbers (bit 0 = signal 1)"""
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
                     set([2, 10, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
                     set([2, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
                     set([2, 28, 32, 33]))
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
                          24, 25, 26, 28, 31]))
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))

  def testGetProcStatusField(self):
    """A field's value is extracted with surrounding whitespace stripped"""
    for field in ["SigCgt", "Name", "FDSize"]:
      for value in ["", "0", "cat", "  1234 KB"]:
        pstatus = "\n".join([
          "VmPeak: 999 kB",
          "%s: %s" % (field, value),
          "TracerPid: 0",
          ])
        result = utils._GetProcStatusField(pstatus, field)
        self.assertEqual(result, value.strip())

  def test(self):
    """Signal 10 is part of the SigCgt mask, so it is reported as handled"""
    sp = PathJoin(self.tmpdir, "status")

    utils.WriteFile(sp, data="\n".join([
      "Name:   bash",
      "State:  S (sleeping)",
      "SleepAVG:       98%",
      "Pid:    22250",
      "PPid:   10858",
      "TracerPid:      0",
      "SigBlk: 0000000000010000",
      "SigIgn: 0000000000384004",
      "SigCgt: 000000004b813efb",
      "CapEff: 0000000000000000",
      ]))

    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))

  def testNoSigCgt(self):
    """A status file without a SigCgt line raises RuntimeError"""
    sp = PathJoin(self.tmpdir, "status")

    utils.WriteFile(sp, data="\n".join([
      "Name:   bash",
      ]))

    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
                      1234, 10, status_path=sp)

  def testNoSuchFile(self):
    """A missing status file yields False"""
    sp = PathJoin(self.tmpdir, "notexist")

    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))

  @staticmethod
  def _TestRealProcess():
    """Check SIGUSR1 detection against the real process status.

    Run via RunInSeparateProcess so that changing signal dispositions does
    not affect the test runner itself.

    """
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is not handled when it should be")

    # An ignored signal must not count as handled
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
      raise Exception("SIGUSR1 is handled when it should not be")

    return True

  def testRealProcess(self):
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
170
171
class TestPidFileFunctions(unittest.TestCase):
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""

  def setUp(self):
    self.dir = tempfile.mkdtemp()
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
    # RemovePidFile is called with a daemon name below; redirecting
    # DaemonPidFileName makes such names resolve inside our temporary
    # directory. Keep the original so tearDown can restore it and later
    # test cases are not affected by the monkeypatch.
    self._orig_daemon_pid_file_name = utils.DaemonPidFileName
    utils.DaemonPidFileName = self.f_dpn

  def testPidFileFunctions(self):
    """Write, lock, read and remove a PID file"""
    pid_file = self.f_dpn('test')
    fd = utils.WritePidFile(self.f_dpn('test'))
    self.failUnless(os.path.exists(pid_file),
                    "PID file should have been created")
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, os.getpid())
    self.failUnless(utils.IsProcessAlive(read_pid))
    # While we hold the lock, a second write must fail
    self.failUnlessRaises(errors.LockError, utils.WritePidFile,
                          self.f_dpn('test'))
    os.close(fd)
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for missing pid file")
    fh = open(pid_file, "w")
    fh.write("blah\n")
    fh.close()
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for invalid pid file")
    # but now, even with the file existing, we should be able to lock it
    fd = utils.WritePidFile(self.f_dpn('test'))
    os.close(fd)
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")

  def testKill(self):
    """KillProcess terminates a child that holds a PID file"""
    pid_file = self.f_dpn('child')
    r_fd, w_fd = os.pipe()
    new_pid = os.fork()
    if new_pid == 0: # child
      # Never let the child fall back into the test runner, even if
      # something below raises; the parent notices via EOF on the pipe
      try:
        utils.WritePidFile(self.f_dpn('child'))
        os.write(w_fd, 'a')
        signal.pause()
      finally:
        os._exit(0)

    # else we are in the parent; drop our copy of the write end so that
    # reading returns EOF instead of hanging if the child exits early
    os.close(w_fd)
    try:
      # wait until the child has written the pid file
      self.failUnlessEqual(os.read(r_fd, 1), 'a',
                           "child exited before writing its PID file")
    finally:
      os.close(r_fd)
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, new_pid)
    self.failUnless(utils.IsProcessAlive(new_pid))
    utils.KillProcess(new_pid, waitpid=True)
    self.failIf(utils.IsProcessAlive(new_pid))
    utils.RemovePidFile('child')
    self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)

  def tearDown(self):
    # Undo the monkeypatch first so it is restored even if cleanup fails
    utils.DaemonPidFileName = self._orig_daemon_pid_file_name
    for name in os.listdir(self.dir):
      os.unlink(os.path.join(self.dir, name))
    os.rmdir(self.dir)
233
234
class TestRunCmd(testutils.GanetiTestCase):
  """Testing case for the RunCmd function"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.magic = time.ctime() + " ganeti test"
    self.fname = self._CreateTempFile()
    self.fifo_tmpdir = tempfile.mkdtemp()
    try:
      # Nothing in these tests writes to the FIFO, so commands that read
      # from it block; the timeout tests rely on this
      self.fifo_file = os.path.join(self.fifo_tmpdir, "ganeti_test_fifo")
      os.mkfifo(self.fifo_file)
    except EnvironmentError:
      # unittest does not run tearDown when setUp fails, so release the
      # temporary directory and the base class' resources here
      shutil.rmtree(self.fifo_tmpdir, ignore_errors=True)
      testutils.GanetiTestCase.tearDown(self)
      raise

  def tearDown(self):
    shutil.rmtree(self.fifo_tmpdir)
    testutils.GanetiTestCase.tearDown(self)

  def testOk(self):
    """Test successful exit code"""
    result = RunCmd("/bin/sh -c 'exit 0'")
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.output, "")

  def testFail(self):
    """Test fail exit code"""
    result = RunCmd("/bin/sh -c 'exit 1'")
    self.assertEqual(result.exit_code, 1)
    self.assertEqual(result.output, "")

  def testStdout(self):
    """Test standard output"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.stdout, self.magic)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testStderr(self):
    """Test standard error"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
    self.assertEqual(result.stderr, self.magic)
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testCombined(self):
    """Test combined output"""
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
    expected = "A" + self.magic + "B" + self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.output, expected)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, expected)

  def testSignal(self):
    """Test signal"""
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
    self.assertEqual(result.signal, 15)
    self.assertEqual(result.output, "")

  def testTimeoutClean(self):
    """A command trapping SIGTERM can exit cleanly after the timeout"""
    cmd = "trap 'exit 0' TERM; read < %s" % self.fifo_file
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
    self.assertEqual(result.exit_code, 0)

  def testTimeoutKill(self):
    """A command ignoring SIGTERM is killed with SIGKILL after lingering"""
    cmd = ["/bin/sh", "-c", "trap '' TERM; read < %s" % self.fifo_file]
    timeout = 0.2
    out, err, status, ta = utils._RunCmdPipe(cmd, {}, False, "/", False,
                                             timeout, _linger_timeout=0.2)
    self.assert_(status < 0)
    self.assertEqual(-status, signal.SIGKILL)

  def testTimeoutOutputAfterTerm(self):
    """Output written by the SIGTERM handler is still collected"""
    cmd = "trap 'echo sigtermed; exit 1' TERM; read < %s" % self.fifo_file
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
    self.assert_(result.failed)
    self.assertEqual(result.stdout, "sigtermed\n")

  def testListRun(self):
    """Test list runs"""
    result = RunCmd(["true"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 1)
    result = RunCmd(["echo", "-n", self.magic])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.stdout, self.magic)

  def testFileEmptyOutput(self):
    """Test file output"""
    result = RunCmd(["true"], output=self.fname)
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertFileContent(self.fname, "")

  def testLang(self):
    """Test locale environment"""
    # Remember only the variables we change. Restoring through item
    # assignment/deletion keeps os.environ and the real process environment
    # (putenv/unsetenv) in sync; rebinding os.environ to a plain dict copy
    # would leave the process environment modified.
    saved = {}
    for name in ("LANG", "LC_ALL"):
      saved[name] = os.environ.get(name)
    try:
      os.environ["LANG"] = "en_US.UTF-8"
      os.environ["LC_ALL"] = "en_US.UTF-8"
      result = RunCmd(["locale"])
      for line in result.output.splitlines():
        key, value = line.split("=", 1)
        # Ignore these variables, they're overridden by LC_ALL
        if key == "LANG" or key == "LANGUAGE":
          continue
        self.failIf(value and value != "C" and value != '"C"',
            "Variable %s is set to the invalid value '%s'" % (key, value))
    finally:
      for (name, value) in saved.items():
        if value is None:
          if name in os.environ:
            del os.environ[name]
        else:
          os.environ[name] = value

  def testDefaultCwd(self):
    """Test default working directory"""
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")

  def testCwd(self):
    """Test explicitly given working directories"""
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
    cwd = os.getcwd()
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)

  def testResetEnv(self):
    """Test environment reset functionality"""
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")

  def testNoFork(self):
    """Test that nofork raise an error"""
    assert not utils.no_fork
    utils.no_fork = True
    try:
      self.assertRaises(errors.ProgrammerError, RunCmd, ["true"])
    finally:
      utils.no_fork = False

  def testWrongParams(self):
    """Test wrong parameters"""
    self.assertRaises(errors.ProgrammerError, RunCmd, ["true"],
                      output="/dev/null", interactive=True)
382
383
class TestRunParts(testutils.GanetiTestCase):
  """Testing case for the RunParts function"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")

  def tearDown(self):
    shutil.rmtree(self.rundir)
    testutils.GanetiTestCase.tearDown(self)

  def _WriteRunPart(self, name, data="", executable=False):
    """Create a file in the run directory and return its full path.

    @param name: file name relative to the run directory
    @param data: file contents
    @param executable: if true, set the mode to C{S_IREAD | S_IEXEC}
        (owner read and execute only); otherwise keep the mode chosen by
        utils.WriteFile

    """
    fname = os.path.join(self.rundir, name)
    utils.WriteFile(fname, data=data)
    if executable:
      os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
    return fname

  def testEmpty(self):
    """Test on an empty dir"""
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])

  def testSkipWrongName(self):
    """Test that wrong files are skipped"""
    # Presumably the dot makes the name fail RunParts' name filter
    fname = self._WriteRunPart("00test.dot", executable=True)
    relname = os.path.basename(fname)
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
                         [(relname, constants.RUNPARTS_SKIP, None)])

  def testSkipNonExec(self):
    """Test that non executable files are skipped"""
    fname = self._WriteRunPart("00test")
    relname = os.path.basename(fname)
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
                         [(relname, constants.RUNPARTS_SKIP, None)])

  def testError(self):
    """Test error on a broken executable"""
    # An empty executable has no interpreter line, so running it fails
    fname = self._WriteRunPart("00test", executable=True)
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
    self.failUnlessEqual(relname, os.path.basename(fname))
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
    self.failUnless(error)

  def testSorted(self):
    """Test executions are sorted"""
    # Created deliberately out of order
    files = [self._WriteRunPart(name)
             for name in ("64test", "00test", "42test")]

    results = RunParts(self.rundir, reset_env=True)

    self.failUnlessEqual(len(results), len(files))
    for fname in sorted(files):
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])

  def testOk(self):
    """Test correct execution"""
    fname = self._WriteRunPart("00test", data="#!/bin/sh\n\necho -n ciao",
                               executable=True)
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
    self.failUnlessEqual(relname, os.path.basename(fname))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.stdout, "ciao")

  def testRunFail(self):
    """Test correct execution, with run failure"""
    fname = self._WriteRunPart("00test", data="#!/bin/sh\n\nexit 1",
                               executable=True)
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
    self.failUnlessEqual(relname, os.path.basename(fname))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.exit_code, 1)
    self.failUnless(runresult.failed)

  def testRunMix(self):
    """Test a mix of failing, skipped, broken and successful scripts"""
    files = [
      # 1st has errors in execution
      self._WriteRunPart("00test", data="#!/bin/sh\n\nexit 1",
                         executable=True),
      # 2nd is skipped (not executable)
      self._WriteRunPart("42test"),
      # 3rd cannot execute properly (empty, no interpreter line)
      self._WriteRunPart("64test", executable=True),
      # 4th execs
      self._WriteRunPart("99test", data="#!/bin/sh\n\necho -n ciao",
                         executable=True),
      ]

    results = RunParts(self.rundir, reset_env=True)

    (relname, status, runresult) = results[0]
    self.failUnlessEqual(relname, os.path.basename(files[0]))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.exit_code, 1)
    self.failUnless(runresult.failed)

    (relname, status, runresult) = results[1]
    self.failUnlessEqual(relname, os.path.basename(files[1]))
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
    self.failUnlessEqual(runresult, None)

    (relname, status, runresult) = results[2]
    self.failUnlessEqual(relname, os.path.basename(files[2]))
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
    self.failUnless(runresult)

    (relname, status, runresult) = results[3]
    self.failUnlessEqual(relname, os.path.basename(files[3]))
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
    self.failUnlessEqual(runresult.output, "ciao")
    self.failUnlessEqual(runresult.exit_code, 0)
    self.failUnless(not runresult.failed)

  def testMissingDirectory(self):
    """A non-existing directory yields no results"""
    nosuchdir = utils.PathJoin(self.rundir, "no/such/directory")
    self.assertEqual(RunParts(nosuchdir), [])
512
513
class TestStartDaemon(testutils.GanetiTestCase):
  """Tests for utils.StartDaemon"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
    self.tmpfile = os.path.join(self.tmpdir, "test")

  def tearDown(self):
    shutil.rmtree(self.tmpdir)
    testutils.GanetiTestCase.tearDown(self)

  def testShell(self):
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testShellOutput(self):
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testNoShellNoOutput(self):
    utils.StartDaemon(["pwd"])

  def testNoShellNoOutputTouch(self):
    testfile = os.path.join(self.tmpdir, "check")
    self.failIf(os.path.exists(testfile))
    utils.StartDaemon(["touch", testfile])
    self._wait(testfile, 60.0, "")

  def testNoShellOutput(self):
    utils.StartDaemon(["pwd"], output=self.tmpfile)
    self._wait(self.tmpfile, 60.0, "/")

  def testNoShellOutputCwd(self):
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
    self._wait(self.tmpfile, 60.0, os.getcwd())

  def testShellEnv(self):
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
                      env={ "GNT_TEST_VAR": "Hello World", })
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testNoShellEnv(self):
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
                      env={ "GNT_TEST_VAR": "Hello World", })
    self._wait(self.tmpfile, 60.0, "Hello World")

  def testOutputFd(self):
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
    try:
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
    finally:
      os.close(fd)
    self._wait(self.tmpfile, 60.0, os.getcwd())

  def testPid(self):
    """The returned PID is the PID of the started shell"""
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
    self._wait(self.tmpfile, 60.0, str(pid))

  def testPidFile(self):
    """The PID file holds the daemon's PID and is locked while it runs"""
    pidfile = os.path.join(self.tmpdir, "pid")

    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
                            output=self.tmpfile)
    try:
      fd = os.open(pidfile, os.O_RDONLY)
      try:
        # Check file is locked
        self.assertRaises(errors.LockError, utils.LockFile, fd)

        pidtext = os.read(fd, 100)
      finally:
        os.close(fd)

      self.assertEqual(int(pidtext.strip()), pid)

      self.assert_(utils.IsProcessAlive(pid))
    finally:
      # No matter what happens, kill daemon
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
      self.failIf(utils.IsProcessAlive(pid))

    self.assertEqual(utils.ReadFile(self.tmpfile), "")

  def _wait(self, path, timeout, expected):
    """Poll until C{path} exists and its stripped content is C{expected}.

    Fails the test if that does not happen within C{timeout} seconds.

    """
    # Due to the asynchronous nature of daemon processes, polling is necessary.
    # A timeout makes sure the test doesn't hang forever.
    def _CheckFile():
      if not (os.path.isfile(path) and
              utils.ReadFile(path).strip() == expected):
        raise utils.RetryAgain()

    try:
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
    except utils.RetryTimeout:
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
                " didn't write the correct output" % timeout)

  def testError(self):
    """Invalid commands, output paths or working directories raise errors"""
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"])
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"],
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
                      ["./does-NOT-EXIST/here/0123456789"],
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))

    # Passing both an output path and an output fd is a programming error
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
    try:
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
                        ["./does-NOT-EXIST/here/0123456789"],
                        output=self.tmpfile, output_fd=fd)
    finally:
      os.close(fd)
629
630
class TestSetCloseOnExecFlag(unittest.TestCase):
  """Tests for SetCloseOnExecFlag"""

  def setUp(self):
    self.tmpfile = tempfile.TemporaryFile()

  def tearDown(self):
    # Close explicitly instead of relying on garbage collection
    self.tmpfile.close()

  def _GetFlags(self):
    """Return the descriptor flags (F_GETFD) of the temporary file"""
    return fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD)

  def _SetFlags(self, flags):
    """Set the descriptor flags (F_SETFD) of the temporary file"""
    fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_SETFD, flags)

  def testEnable(self):
    # The initial state depends on how tempfile opened the descriptor, so
    # clear the flag first to make the test observe an actual change
    self._SetFlags(self._GetFlags() & ~fcntl.FD_CLOEXEC)
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
    self.failUnless(self._GetFlags() & fcntl.FD_CLOEXEC)

  def testDisable(self):
    # Set the flag first so disabling it is actually exercised
    self._SetFlags(self._GetFlags() | fcntl.FD_CLOEXEC)
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
    self.failIf(self._GetFlags() & fcntl.FD_CLOEXEC)
646
647
class TestSetNonblockFlag(unittest.TestCase):
  """Tests for SetNonblockFlag"""

  def setUp(self):
    self.tmpfile = tempfile.TemporaryFile()

  def tearDown(self):
    # Close explicitly instead of relying on garbage collection
    self.tmpfile.close()

  def _GetFlags(self):
    """Return the file status flags (F_GETFL) of the temporary file"""
    return fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL)

  def _SetFlags(self, flags):
    """Set the file status flags (F_SETFL) of the temporary file"""
    fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_SETFL, flags)

  def testEnable(self):
    # Clear the flag first so the test observes an actual change
    self._SetFlags(self._GetFlags() & ~os.O_NONBLOCK)
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
    self.failUnless(self._GetFlags() & os.O_NONBLOCK)

  def testDisable(self):
    # A freshly opened file is normally blocking already; set the flag
    # first, otherwise this test would pass even if disabling did nothing
    self._SetFlags(self._GetFlags() | os.O_NONBLOCK)
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
    self.failIf(self._GetFlags() & os.O_NONBLOCK)
661
662
class TestRemoveFile(unittest.TestCase):
  """Test case for the RemoveFile function"""

  def setUp(self):
    """Create a temp dir and file for each case"""
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
    os.close(fd)

  def tearDown(self):
    if os.path.exists(self.tmpfile):
      os.unlink(self.tmpfile)
    os.rmdir(self.tmpdir)

  def testIgnoreDirs(self):
    """Test that RemoveFile() ignores directories"""
    self.assertEqual(None, RemoveFile(self.tmpdir))

  def testIgnoreNotExisting(self):
    """Test that RemoveFile() ignores non-existing files"""
    RemoveFile(self.tmpfile)
    RemoveFile(self.tmpfile)

  def testRemoveFile(self):
    """Test that RemoveFile does remove a file"""
    RemoveFile(self.tmpfile)
    if os.path.exists(self.tmpfile):
      self.fail("File '%s' not removed" % self.tmpfile)

  def testRemoveSymlink(self):
    """Test that RemoveFile does remove symlinks"""
    symlink = os.path.join(self.tmpdir, "symlink")

    # Dangling symlink: os.path.exists() follows links and reports False for
    # a dangling one even if it is still present, hence lexists()
    os.symlink("no-such-file", symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)

    # Symlink to an existing file: only the link itself may disappear
    os.symlink(self.tmpfile, symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
    self.failUnless(os.path.exists(self.tmpfile),
                    "Symlink target '%s' was removed" % self.tmpfile)
703
704
class TestRemoveDir(unittest.TestCase):
  """Tests for utils.RemoveDir"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    # testEmptyDir deletes the directory itself, so failures here are
    # expected and tolerated
    shutil.rmtree(self.tmpdir, ignore_errors=True)

  def testEmptyDir(self):
    """An empty directory is removed"""
    utils.RemoveDir(self.tmpdir)
    still_there = os.path.isdir(self.tmpdir)
    self.assertFalse(still_there)

  def testNonEmptyDir(self):
    """Removing a directory that still has contents raises an error"""
    self.tmpfile = os.path.join(self.tmpdir, "test1")
    handle = open(self.tmpfile, "w")
    handle.close()
    self.assertRaises(EnvironmentError, utils.RemoveDir, self.tmpdir)
723
724
class TestRename(unittest.TestCase):
  """Test case for RenameFile"""

  def setUp(self):
    """Create a temporary directory containing one empty file"""
    self.tmpdir = tempfile.mkdtemp()
    self.tmpfile = os.path.join(self.tmpdir, "test1")
    # Touch the file
    handle = open(self.tmpfile, "w")
    handle.close()

  def tearDown(self):
    """Remove the temporary directory and everything below it"""
    shutil.rmtree(self.tmpdir)

  def _Sub(self, relpath):
    """Return C{relpath} joined onto the temporary directory"""
    return os.path.join(self.tmpdir, relpath)

  def testSimpleRename1(self):
    """Rename within the same directory"""
    target = self._Sub("xyz")
    utils.RenameFile(self.tmpfile, target)
    self.assertTrue(os.path.isfile(target))

  def testSimpleRename2(self):
    """Rename within the same directory with mkdir enabled"""
    target = self._Sub("xyz")
    utils.RenameFile(self.tmpfile, target, mkdir=True)
    self.assertTrue(os.path.isfile(target))

  def testRenameMkdir(self):
    """Missing parent directories are created when mkdir is enabled"""
    first = self._Sub("test/xyz")
    utils.RenameFile(self.tmpfile, first, mkdir=True)
    self.assertTrue(os.path.isdir(self._Sub("test")))
    self.assertTrue(os.path.isfile(first))

    second = self._Sub("test/foo/bar/baz")
    utils.RenameFile(first, second, mkdir=True)
    self.assertTrue(os.path.isdir(self._Sub("test")))
    self.assertTrue(os.path.isdir(self._Sub("test/foo/bar")))
    self.assertTrue(os.path.isfile(second))
764
765
class TestMatchNameComponent(unittest.TestCase):
  """Test case for the MatchNameComponent function"""

  def testEmptyList(self):
    """Nothing matches against an empty list"""
    for key in ("", "test"):
      self.assertEqual(MatchNameComponent(key, []), None)

  def testSingleMatch(self):
    """Keys made of leading name components select the single match"""
    names = ["test1.example.com", "test2.example.com", "test3.example.com"]
    wanted = names[1]
    for key in ("test2", "test2.example", "test2.example.com"):
      self.assertEqual(MatchNameComponent(key, names), wanted)

  def testMultipleMatches(self):
    """An ambiguous key yields None"""
    names = ["test1.example.com", "test1.example.org", "test1.example.net"]
    for key in ("test1", "test1.example"):
      self.assertEqual(MatchNameComponent(key, names), None)

  def testFullMatch(self):
    """A key equal to one of the names wins over its longer variants"""
    short_key = "test1"
    full_key = "test1.example"
    names = [full_key, full_key + ".com"]
    self.assertEqual(MatchNameComponent(short_key, names), None)
    self.assertEqual(MatchNameComponent(full_key, names), full_key)

  def testCaseInsensitivePartialMatch(self):
    """Test for the case_insensitive keyword"""
    names = ["test1.example.com", "test2.example.net"]
    for key in ("test2", "Test2", "teSt2", "TeSt2"):
      found = MatchNameComponent(key, names, case_sensitive=False)
      self.assertEqual(found, "test2.example.net")

  def testCaseInsensitiveFullMatch(self):
    """Case-insensitive lookups still resolve exact and full-name matches"""
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
    expectations = [
      # Between the two ts1 names a bare prefix is ambiguous, while a full
      # (case-insensitive) name match selects one of them
      ("Ts1", None),
      ("Ts1.ex", "ts1.ex"),
      ("ts1.ex", "ts1.ex"),
      # The two ts2 names differ only in case, so only an exact-case match
      # can pick one of them
      ("ts2.ex", "ts2.ex"),
      ("Ts2.ex", "Ts2.ex"),
      ("TS2.ex", None),
      ]
    for (key, expected) in expectations:
      self.assertEqual(MatchNameComponent(key, mlist, case_sensitive=False),
                       expected)
824
825
class TestReadFile(testutils.GanetiTestCase):
  """Tests for utils.ReadFile"""

  def _AssertLengthAndDigest(self, data, length, digest):
    """Checks the size and the MD5 hex digest of C{data}"""
    self.assertEqual(len(data), length)
    hasher = compat.md5_hash()
    hasher.update(data)
    self.assertEqual(hasher.hexdigest(), digest)

  def testReadAll(self):
    content = utils.ReadFile(self._TestDataFilename("cert1.pem"))
    self._AssertLengthAndDigest(content, 814,
                                "a491efb3efe56a0535f924d5f8680fd4")

  def testReadSize(self):
    # Only the first 100 bytes are requested
    content = utils.ReadFile(self._TestDataFilename("cert1.pem"), size=100)
    self._AssertLengthAndDigest(content, 100,
                                "893772354e4e690b9efd073eed433ce7")

  def testError(self):
    missing = "/dev/null/does-not-exist"
    self.assertRaises(EnvironmentError, utils.ReadFile, missing)
848
849
class TestReadOneLineFile(testutils.GanetiTestCase):
  """Tests for utils.ReadOneLineFile in strict and non-strict mode

  The inherited setUp from GanetiTestCase is used as-is; the former
  override only delegated to it.

  """

  def testDefault(self):
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
    self.assertEqual(len(data), 27)
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")

  def testNotStrict(self):
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
    self.assertEqual(len(data), 27)
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")

  def testStrictFailure(self):
    # cert1.pem has several lines, which strict mode rejects
    self.assertRaises(errors.GenericError, ReadOneLineFile,
                      self._TestDataFilename("cert1.pem"), strict=True)

  def testLongLine(self):
    dummydata = (1024 * "Hello World! ")
    myfile = self._CreateTempFile()
    utils.WriteFile(myfile, data=dummydata)
    datastrict = ReadOneLineFile(myfile, strict=True)
    datalax = ReadOneLineFile(myfile, strict=False)
    self.assertEqual(dummydata, datastrict)
    self.assertEqual(dummydata, datalax)

  def testNewline(self):
    # The trailing line terminator (if any) is not part of the result
    myfile = self._CreateTempFile()
    myline = "myline"
    for nl in ["", "\n", "\r\n"]:
      dummydata = "%s%s" % (myline, nl)
      utils.WriteFile(myfile, data=dummydata)
      datalax = ReadOneLineFile(myfile, strict=False)
      self.assertEqual(myline, datalax)
      datastrict = ReadOneLineFile(myfile, strict=True)
      self.assertEqual(myline, datastrict)

  def testWhitespaceAndMultipleLines(self):
    myfile = self._CreateTempFile()
    for nl in ["", "\n", "\r\n"]:
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
        utils.WriteFile(myfile, data=dummydata)
        datalax = ReadOneLineFile(myfile, strict=False)
        if nl:
          # Many lines: strict mode fails, lax mode returns only the first
          # line, keeping its trailing whitespace
          self.assert_(set("\r\n") & set(dummydata))
          self.assertRaises(errors.GenericError, ReadOneLineFile,
                            myfile, strict=True)
          explen = len("Foo bar baz ") + len(ws)
          self.assertEqual(len(datalax), explen)
          self.assertEqual(datalax, dummydata[:explen])
          self.assertFalse(set("\r\n") & set(datalax))
        else:
          datastrict = ReadOneLineFile(myfile, strict=True)
          self.assertEqual(dummydata, datastrict)
          self.assertEqual(dummydata, datalax)

  def testEmptylines(self):
    # Surrounding empty lines are skipped; strict mode only fails when a
    # second non-empty line ("otherline") is present
    myfile = self._CreateTempFile()
    myline = "myline"
    for nl in ["\n", "\r\n"]:
      for ol in ["", "otherline"]:
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
        utils.WriteFile(myfile, data=dummydata)
        self.assert_(set("\r\n") & set(dummydata))
        datalax = ReadOneLineFile(myfile, strict=False)
        self.assertEqual(myline, datalax)
        if ol:
          self.assertRaises(errors.GenericError, ReadOneLineFile,
                            myfile, strict=True)
        else:
          datastrict = ReadOneLineFile(myfile, strict=True)
          self.assertEqual(myline, datastrict)

  def testEmptyfile(self):
    myfile = self._CreateTempFile()
    self.assertRaises(errors.GenericError, ReadOneLineFile, myfile)
929
930
class TestTimestampForFilename(unittest.TestCase):
  """Tests for utils.TimestampForFilename"""

  def test(self):
    # Neither dots nor colons may appear in the generated timestamp
    for forbidden in (".", ":"):
      self.assert_(forbidden not in utils.TimestampForFilename())
935
936
class TestCreateBackup(testutils.GanetiTestCase):
  """Tests for utils.CreateBackup"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def _CountMatching(self, filename):
    """Returns how many paths start with C{filename}"""
    return len(glob.glob("%s*" % filename))

  def testEmpty(self):
    filename = PathJoin(self.tmpdir, "config.data")
    utils.WriteFile(filename, data="")

    backup = utils.CreateBackup(filename)
    self.assertFileContent(backup, "")
    self.assertEqual(self._CountMatching(filename), 2)

    # Every further backup adds exactly one more file
    for expected in (3, 4):
      utils.CreateBackup(filename)
      self.assertEqual(self._CountMatching(filename), expected)

    # A FIFO is rejected with ProgrammerError
    fifo = PathJoin(self.tmpdir, "fifo")
    os.mkfifo(fifo)
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifo)

  def testContent(self):
    backups = 0
    filename = PathJoin(self.tmpdir, "test.data_")
    for chunk in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
      for repeat in [1, 2, 10, 127]:
        content = chunk * repeat
        utils.WriteFile(filename, data=content)
        self.assertFileContent(filename, content)

        for _ in range(3):
          backup = utils.CreateBackup(filename)
          backups += 1
          self.assertFileContent(backup, content)
          # the original file plus every backup made so far
          self.assertEqual(self._CountMatching(filename), backups + 1)
978
979
class TestFormatUnit(unittest.TestCase):
  """Test case for the FormatUnit function"""

  def _Check(self, cases):
    """Asserts FormatUnit(value, unit) == expected for every tuple"""
    for value, unit, expected in cases:
      self.assertEqual(FormatUnit(value, unit), expected)

  def testMiB(self):
    self._Check([
      (1, "h", "1M"),
      (100, "h", "100M"),
      (1023, "h", "1023M"),
      (1, "m", "1"),
      (100, "m", "100"),
      (1023, "m", "1023"),
      (1024, "m", "1024"),
      (1536, "m", "1536"),
      (17133, "m", "17133"),
      (1024 * 1024 - 1, "m", "1048575"),
      ])

  def testGiB(self):
    self._Check([
      (1024, "h", "1.0G"),
      (1536, "h", "1.5G"),
      (17133, "h", "16.7G"),
      (1024 * 1024 - 1, "h", "1024.0G"),
      (1024, "g", "1.0"),
      (1536, "g", "1.5"),
      (17133, "g", "16.7"),
      (1024 * 1024 - 1, "g", "1024.0"),
      (1024 * 1024, "g", "1024.0"),
      (5120 * 1024, "g", "5120.0"),
      (29829 * 1024, "g", "29829.0"),
      ])

  def testTiB(self):
    self._Check([
      (1024 * 1024, "h", "1.0T"),
      (5120 * 1024, "h", "5.0T"),
      (29829 * 1024, "h", "29.1T"),
      (1024 * 1024, "t", "1.0"),
      (5120 * 1024, "t", "5.0"),
      (29829 * 1024, "t", "29.1"),
      ])

  def testErrors(self):
    # "a" is not a recognised unit
    self.assertRaises(errors.ProgrammerError, FormatUnit, 1, "a")
1023
1024
class TestParseUnit(unittest.TestCase):
  """Test case for the ParseUnit function"""

  # (suffix, multiplier relative to mebibytes)
  SCALES = (('', 1),
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))

  def _CheckAll(self, pairs):
    """Asserts ParseUnit(text) == expected for every pair"""
    for text, expected in pairs:
      self.assertEqual(ParseUnit(text), expected)

  def testRounding(self):
    # Results are rounded up to a multiple of four
    self._CheckAll([
      ("0", 0), ("1", 4), ("2", 4), ("3", 4),
      ("124", 124), ("125", 128), ("126", 128), ("127", 128),
      ("128", 128), ("129", 132), ("130", 132),
      ])

  def testFloating(self):
    self._CheckAll([
      ("0", 0), ("0.5", 4), ("1.75", 4), ("1.99", 4), ("2.00", 4),
      ("2.01", 4), ("3.99", 4), ("4.00", 4), ("4.01", 8),
      ("1.5G", 1536), ("1.8G", 1844), ("8.28T", 8682212),
      ])

  def testSuffixes(self):
    transforms = (lambda x: x, str.lower, str.upper)
    for sep in ("", " ", "   ", "\t", "\t "):
      for suffix, scale in self.SCALES:
        for transform in transforms:
          text = "1024" + sep + transform(suffix)
          self.assertEqual(ParseUnit(text), 1024 * scale)

  def testInvalidInput(self):
    suffixes = [suffix for (suffix, _) in self.SCALES]
    for sep in ("-", "_", ",", "a"):
      for suffix in suffixes:
        self.assertRaises(errors.UnitParseError, ParseUnit, "1" + sep + suffix)

    for suffix in suffixes:
      self.assertRaises(errors.UnitParseError, ParseUnit, "1,3" + suffix)
1075
1076
class TestParseCpuMask(unittest.TestCase):
  """Test case for the ParseCpuMask function."""

  def testWellFormed(self):
    cases = [
      ("", []),
      ("1", [1]),
      ("0-2,4,5-5", [0, 1, 2, 4, 5]),
      ]
    for mask, cpus in cases:
      self.assertEqual(utils.ParseCpuMask(mask), cpus)

  def testInvalidInput(self):
    bad_masks = ("garbage", "0,", "0-1-2", "2-1", "1-a")
    for mask in bad_masks:
      self.assertRaises(errors.ParseError, utils.ParseCpuMask, mask)
1088
1089
class TestSshKeys(testutils.GanetiTestCase):
  """Tests for AddAuthorizedKey and RemoveAuthorizedKey"""

  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    fh = open(self.tmpname, "w")
    try:
      fh.write("".join("%s\n" % key for key in (self.KEY_A, self.KEY_B)))
    finally:
      fh.close()

  def _AssertKeys(self, *lines):
    """Checks that the file holds exactly C{lines}, one per line"""
    self.assertFileContent(self.tmpname,
                           "".join("%s\n" % line for line in lines))

  def testAddingNewKey(self):
    utils.AddAuthorizedKey(self.tmpname,
                           "ssh-dss AAAAB3NzaC1kc3MAAACB root@test")
    self._AssertKeys(self.KEY_A, self.KEY_B,
                     "ssh-dss AAAAB3NzaC1kc3MAAACB root@test")

  def testAddingAlmostButNotCompletelyTheSameKey(self):
    # Same key material, different comment: appended as a new entry
    utils.AddAuthorizedKey(self.tmpname,
                           "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test")
    self._AssertKeys(self.KEY_A, self.KEY_B,
                     "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test")

  def testAddingExistingKeyWithSomeMoreSpaces(self):
    # Extra whitespace does not make a key different: file is unchanged
    utils.AddAuthorizedKey(self.tmpname,
                           "ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a")
    self._AssertKeys(self.KEY_A, self.KEY_B)

  def testRemovingExistingKeyWithSomeMoreSpaces(self):
    utils.RemoveAuthorizedKey(self.tmpname,
                              "ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU"
                              "   root@key-a")
    self._AssertKeys(self.KEY_B)

  def testRemovingNonExistingKey(self):
    utils.RemoveAuthorizedKey(self.tmpname,
                              "ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU"
                              "   root@test")
    self._AssertKeys(self.KEY_A, self.KEY_B)
1152
1153
1154 class TestEtcHosts(testutils.GanetiTestCase):
1155   """Test functions modifying /etc/hosts"""
1156
1157   def setUp(self):
1158     testutils.GanetiTestCase.setUp(self)
1159     self.tmpname = self._CreateTempFile()
1160     handle = open(self.tmpname, 'w')
1161     try:
1162       handle.write('# This is a test file for /etc/hosts\n')
1163       handle.write('127.0.0.1\tlocalhost\n')
1164       handle.write('192.0.2.1 router gw\n')
1165     finally:
1166       handle.close()
1167
1168   def testSettingNewIp(self):
1169     SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost.example.com',
1170                      ['myhost'])
1171
1172     self.assertFileContent(self.tmpname,
1173       "# This is a test file for /etc/hosts\n"
1174       "127.0.0.1\tlocalhost\n"
1175       "192.0.2.1 router gw\n"
1176       "198.51.100.4\tmyhost.example.com myhost\n")
1177     self.assertFileMode(self.tmpname, 0644)
1178
1179   def testSettingExistingIp(self):
1180     SetEtcHostsEntry(self.tmpname, '192.0.2.1', 'myhost.example.com',
1181                      ['myhost'])
1182
1183     self.assertFileContent(self.tmpname,
1184       "# This is a test file for /etc/hosts\n"
1185       "127.0.0.1\tlocalhost\n"
1186       "192.0.2.1\tmyhost.example.com myhost\n")
1187     self.assertFileMode(self.tmpname, 0644)
1188
1189   def testSettingDuplicateName(self):
1190     SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost', ['myhost'])
1191
1192     self.assertFileContent(self.tmpname,
1193       "# This is a test file for /etc/hosts\n"
1194       "127.0.0.1\tlocalhost\n"
1195       "192.0.2.1 router gw\n"
1196       "198.51.100.4\tmyhost\n")
1197     self.assertFileMode(self.tmpname, 0644)
1198
1199   def testRemovingExistingHost(self):
1200     RemoveEtcHostsEntry(self.tmpname, 'router')
1201
1202     self.assertFileContent(self.tmpname,
1203       "# This is a test file for /etc/hosts\n"
1204       "127.0.0.1\tlocalhost\n"
1205       "192.0.2.1 gw\n")
1206     self.assertFileMode(self.tmpname, 0644)
1207
1208   def testRemovingSingleExistingHost(self):
1209     RemoveEtcHostsEntry(self.tmpname, 'localhost')
1210
1211     self.assertFileContent(self.tmpname,
1212       "# This is a test file for /etc/hosts\n"
1213       "192.0.2.1 router gw\n")
1214     self.assertFileMode(self.tmpname, 0644)
1215
1216   def testRemovingNonExistingHost(self):
1217     RemoveEtcHostsEntry(self.tmpname, 'myhost')
1218
1219     self.assertFileContent(self.tmpname,
1220       "# This is a test file for /etc/hosts\n"
1221       "127.0.0.1\tlocalhost\n"
1222       "192.0.2.1 router gw\n")
1223     self.assertFileMode(self.tmpname, 0644)
1224
1225   def testRemovingAlias(self):
1226     RemoveEtcHostsEntry(self.tmpname, 'gw')
1227
1228     self.assertFileContent(self.tmpname,
1229       "# This is a test file for /etc/hosts\n"
1230       "127.0.0.1\tlocalhost\n"
1231       "192.0.2.1 router\n")
1232     self.assertFileMode(self.tmpname, 0644)
1233
1234
class TestGetMounts(unittest.TestCase):
  """Test case for GetMounts()."""

  # Sample data in /proc/mounts format
  TESTDATA = (
    "rootfs /     rootfs rw 0 0\n"
    "none   /sys  sysfs  rw,nosuid,nodev,noexec,relatime 0 0\n"
    "none   /proc proc   rw,nosuid,nodev,noexec,relatime 0 0\n")

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)

  def tearDown(self):
    # Closing the NamedTemporaryFile removes its path from disk right away,
    # instead of leaving that to garbage collection
    self.tmpfile.close()

  def testGetMounts(self):
    # Only the first four fields of every line are returned
    self.assertEqual(utils.GetMounts(filename=self.tmpfile.name),
      [
        ("rootfs", "/", "rootfs", "rw"),
        ("none", "/sys", "sysfs", "rw,nosuid,nodev,noexec,relatime"),
        ("none", "/proc", "proc", "rw,nosuid,nodev,noexec,relatime"),
      ])
1254
1255
class TestShellQuoting(unittest.TestCase):
  """Test case for shell quoting functions"""

  def testShellQuote(self):
    cases = [
      ("abc", "abc"),
      ('ab"c', "'ab\"c'"),
      ("a'bc", "'a'\\''bc'"),
      ("a b c", "'a b c'"),
      ("a b\\ c", "'a b\\ c'"),
      ]
    for raw, quoted in cases:
      self.assertEqual(ShellQuote(raw), quoted)

  def testShellQuoteArgs(self):
    cases = [
      (["a", "b", "c"], "a b c"),
      (["a", 'b"', "c"], "a 'b\"' c"),
      (["a", "b'", "c"], "a 'b'\\\''' c"),
      ]
    for args, quoted in cases:
      self.assertEqual(ShellQuoteArgs(args), quoted)
1270
1271
class TestListVisibleFiles(unittest.TestCase):
  """Test case for ListVisibleFiles"""

  def setUp(self):
    self.path = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.path)

  def _CreateFiles(self, files):
    """Creates every named file inside the scratch directory"""
    for filename in files:
      utils.WriteFile(os.path.join(self.path, filename), data="test")

  def _test(self, files, expected):
    """Creates C{files} and checks which of them get listed"""
    self._CreateFiles(files)
    self.assertEqual(set(ListVisibleFiles(self.path)), set(expected))

  def testAllVisible(self):
    names = ["a", "b", "c"]
    self._test(names, names)

  def testNoneVisible(self):
    self._test([".a", ".b", ".c"], [])

  def testSomeVisible(self):
    self._test(["a", "b", ".c"], ["a", "b"])

  def testNonAbsolutePath(self):
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")

  def testNonNormalizedPath(self):
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
                          "/bin/../tmp")
1311
1312
class TestNewUUID(unittest.TestCase):
  """Test case for NewUUID"""

  def runTest(self):
    uuid = utils.NewUUID()
    # The generated value must match the UUID regular expression
    self.failUnless(utils.UUID_RE.match(uuid))
1318
1319
class TestUniqueSequence(unittest.TestCase):
  """Test case for UniqueSequence"""

  def _test(self, seq, expected):
    """Asserts that UniqueSequence(seq) equals C{expected}

    (The parameter used to be called C{input}, shadowing the builtin.)

    """
    self.assertEqual(utils.UniqueSequence(seq), expected)

  def runTest(self):
    # Ordered input
    self._test([1, 2, 3], [1, 2, 3])
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
    self._test([1, 2, 2, 3], [1, 2, 3])
    self._test([1, 2, 3, 3], [1, 2, 3])

    # Unordered input: the first occurrence determines the position
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])

    # Strings
    self._test(["a", "a"], ["a"])
    self._test(["a", "b"], ["a", "b"])
    self._test(["a", "b", "a"], ["a", "b"])
1341
1342
class TestFindDuplicates(unittest.TestCase):
  """Test case for FindDuplicates"""

  def _Test(self, seq, expected):
    """Checks FindDuplicates(seq) against C{expected}, ignoring order

    The result itself must not contain repeated entries.

    """
    dups = utils.FindDuplicates(seq)
    self.assertEqual(dups, utils.UniqueSequence(dups))
    self.assertEqual(set(dups), set(expected))

  def test(self):
    cases = [
      ([], []),
      ([1, 2, 3], []),
      ([9, 8, 8, 0, 5, 1, 7, 0, 6, 7], [8, 0, 7]),
      ([1, 1, 2, 2, 3, 3], [1, 2, 3]),
      ([1, 1, 2, 2, 3, 3], [3, 2, 1]),
      # comparison is case-sensitive
      (["A", "a", "B"], []),
      (["a", "A", "a", "B"], ["a"]),
      ("Hello World out there!", ["e", " ", "o", "r", "t", "l"]),
      # generators are accepted as well
      (self._Gen(False), []),
      (self._Gen(True), range(1, 10)),
      ]
    for seq, expected in cases:
      self._Test(seq, expected)

  @staticmethod
  def _Gen(dup):
    """Yields 0..9; with C{dup}, each number i appears i extra times"""
    for num in range(10):
      if dup:
        count = num + 1
      else:
        count = 1
      for _ in range(count):
        yield num
1372
1373
class TestFirstFree(unittest.TestCase):
  """Test case for the FirstFree function"""

  def test(self):
    """Test FirstFree"""
    # Without a base, counting starts at zero
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
    # None is returned for an empty sequence
    self.failUnlessEqual(FirstFree([]), None)
    # With base=3, the search starts at 3
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
    # A value below the base raises AssertionError
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1384
1385
class TestTailFile(testutils.GanetiTestCase):
  """Test case for the TailFile function

  Test files are written with utils.WriteFile, as elsewhere in this
  module, instead of a manual open/write/close that leaked the handle if
  a write failed.

  """

  @staticmethod
  def _Lines(lines):
    """Joins C{lines}, terminating each one with a newline"""
    return "".join("%s\n" % line for line in lines)

  def testEmpty(self):
    fname = self._CreateTempFile()
    self.failUnlessEqual(TailFile(fname), [])
    self.failUnlessEqual(TailFile(fname, lines=25), [])

  def testAllLines(self):
    data = ["test %d" % i for i in range(30)]
    for i in range(30):
      fname = self._CreateTempFile()
      utils.WriteFile(fname, data=self._Lines(data[:i]))
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])

  def testPartialLines(self):
    data = ["test %d" % i for i in range(30)]
    fname = self._CreateTempFile()
    utils.WriteFile(fname, data=self._Lines(data))
    for i in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])

  def testBigFile(self):
    data = ["test %d" % i for i in range(30)]
    fname = self._CreateTempFile()
    # A 1 MiB line in front of the interesting data
    utils.WriteFile(fname, data=("X" * 1048576) + "\n" + self._Lines(data))
    for i in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1426
1427
class _BaseFileLockTest:
  """Test case for the FileLock class

  Mixin: the concrete test case must set C{self.lock} (a utils.FileLock)
  and C{self.tmpfile} (an object with a C{name} attribute) in its setUp.

  """

  def testSharedNonblocking(self):
    self.lock.Shared(blocking=False)
    self.lock.Close()

  def testExclusiveNonblocking(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Close()

  def testUnlockNonblocking(self):
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSharedBlocking(self):
    self.lock.Shared(blocking=True)
    self.lock.Close()

  def testExclusiveBlocking(self):
    self.lock.Exclusive(blocking=True)
    self.lock.Close()

  def testUnlockBlocking(self):
    self.lock.Unlock(blocking=True)
    self.lock.Close()

  def testSharedExclusiveUnlock(self):
    self.lock.Shared(blocking=False)
    self.lock.Exclusive(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testExclusiveSharedUnlock(self):
    self.lock.Exclusive(blocking=False)
    self.lock.Shared(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSimpleTimeout(self):
    # These will succeed on the first attempt, hence a short timeout
    self.lock.Shared(blocking=True, timeout=10.0)
    self.lock.Exclusive(blocking=False, timeout=10.0)
    self.lock.Unlock(blocking=True, timeout=10.0)
    self.lock.Close()

  @staticmethod
  def _TryLockInner(filename, shared, blocking):
    """Tries to lock C{filename}; run via RunInSeparateProcess by _TryLock

    @return: True if the lock was acquired, False on errors.LockError

    """
    lock = utils.FileLock.Open(filename)

    if shared:
      fn = lock.Shared
    else:
      fn = lock.Exclusive

    try:
      # The timeout doesn't really matter as the parent process waits for us to
      # finish anyway.
      fn(blocking=blocking, timeout=0.01)
    except errors.LockError:
      return False

    return True

  def _TryLock(self, *args):
    """Runs _TryLockInner(self.tmpfile.name, *args) in a child process"""
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
                                      *args)

  def testTimeout(self):
    for blocking in [True, False]:
      # While we hold it exclusively, no other process gets any lock
      self.lock.Exclusive(blocking=True)
      self.failIf(self._TryLock(False, blocking))
      self.failIf(self._TryLock(True, blocking))

      # While we hold it shared, others may only share it
      self.lock.Shared(blocking=True)
      self.assert_(self._TryLock(True, blocking))
      self.failIf(self._TryLock(False, blocking))

  def testCloseShared(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)

  def testCloseExclusive(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)

  def testCloseUnlock(self):
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1517
1518
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
  """Runs the FileLock tests on a lock opened via FileLock.Open(filename)"""

  TESTDATA = "Hello World\n" * 10

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpfile = tempfile.NamedTemporaryFile()
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
    self.lock = utils.FileLock.Open(self.tmpfile.name)

    # Ensure "Open" didn't truncate file
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)

  def tearDown(self):
    try:
      # Locking operations must never modify the file's content
      self.assertFileContent(self.tmpfile.name, self.TESTDATA)
    finally:
      # Clean up even if the assertion above failed; closing the
      # NamedTemporaryFile removes its path from disk
      try:
        self.tmpfile.close()
      finally:
        testutils.GanetiTestCase.tearDown(self)
1536
1537
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
  """Runs the FileLock tests on a lock built from an open file object"""

  def setUp(self):
    self.tmpfile = tempfile.NamedTemporaryFile()
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)

  def tearDown(self):
    # Remove the temporary file deterministically instead of relying on
    # garbage collection of the NamedTemporaryFile
    self.tmpfile.close()
1542
1543
class TestTimeFunctions(unittest.TestCase):
  """Test case for time functions"""

  def runTest(self):
    # SplitTime: float seconds -> (seconds, microseconds)
    split_cases = [
      (1, (1, 0)),
      (1.5, (1, 500000)),
      (1218448917.4809151, (1218448917, 480915)),
      (123.48012, (123, 480120)),
      (123.9996, (123, 999600)),
      (123.9995, (123, 999500)),
      (123.9994, (123, 999400)),
      (123.999999999, (123, 999999)),
      ]
    for value, expected in split_cases:
      self.assertEqual(utils.SplitTime(value), expected)

    self.assertRaises(AssertionError, utils.SplitTime, -1)

    # MergeTime: (seconds, microseconds) -> float seconds
    merge_cases = [
      ((1, 0), 1.0),
      ((1, 500000), 1.5),
      ((1218448917, 500000), 1218448917.5),
      ]
    for value, expected in merge_cases:
      self.assertEqual(utils.MergeTime(value), expected)

    # These are compared after rounding to three decimal places
    rounded_cases = [
      ((1218448917, 481000), 1218448917.481),
      ((1, 801000), 1.801),
      ]
    for value, expected in rounded_cases:
      self.assertEqual(round(utils.MergeTime(value), 3), expected)

    for bad in [(0, -1), (0, 1000000), (0, 9999999), (-1, 0), (-9999, 0)]:
      self.assertRaises(AssertionError, utils.MergeTime, bad)
1572
1573
class FieldSetTestCase(unittest.TestCase):
  """Test case for FieldSets"""

  def testSimpleMatch(self):
    fields = utils.FieldSet("a", "b", "c", "def")
    self.failUnless(fields.Matches("a"))
    # Neither a substring nor an extension of a field name matches
    self.failIf(fields.Matches("d"), "Substring matched")
    self.failIf(fields.Matches("defghi"), "Prefix string matched")
    for subset in (["b", "c"], ["a", "b", "c", "def"]):
      self.failIf(fields.NonMatching(subset))
    self.failUnless(fields.NonMatching(["a", "d"]))

  def testRegexMatch(self):
    fields = utils.FieldSet("a", "b([0-9]+)", "c")
    for name in ("b1", "b99"):
      self.failUnless(fields.Matches(name))
    self.failIf(fields.Matches("b/1"))
    self.failIf(fields.NonMatching(["b12", "c"]))
    self.failUnless(fields.NonMatching(["a", "1"]))
1591     self.failIf(f.NonMatching(["b12", "c"]))
1592     self.failUnless(f.NonMatching(["a", "1"]))
1593
1594 class TestForceDictType(unittest.TestCase):
1595   """Test case for ForceDictType"""
1596   KEY_TYPES = {
1597     "a": constants.VTYPE_INT,
1598     "b": constants.VTYPE_BOOL,
1599     "c": constants.VTYPE_STRING,
1600     "d": constants.VTYPE_SIZE,
1601     "e": constants.VTYPE_MAYBE_STRING,
1602     }
1603
1604   def _fdt(self, dict, allowed_values=None):
1605     if allowed_values is None:
1606       utils.ForceDictType(dict, self.KEY_TYPES)
1607     else:
1608       utils.ForceDictType(dict, self.KEY_TYPES, allowed_values=allowed_values)
1609
1610     return dict
1611
1612   def testSimpleDict(self):
1613     self.assertEqual(self._fdt({}), {})
1614     self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1615     self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1616     self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1617     self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1618     self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1619     self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1620     self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1621     self.assertEqual(self._fdt({'b': False}), {'b': False})
1622     self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1623     self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1624     self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1625     self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1626     self.assertEqual(self._fdt({"e": None, }), {"e": None, })
1627     self.assertEqual(self._fdt({"e": "Hello World", }), {"e": "Hello World", })
1628     self.assertEqual(self._fdt({"e": False, }), {"e": '', })
1629     self.assertEqual(self._fdt({"b": "hello", }, ["hello"]), {"b": "hello"})
1630
1631   def testErrors(self):
1632     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1633     self.assertRaises(errors.TypeEnforcementError, self._fdt, {"b": "hello"})
1634     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1635     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1636     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1637     self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": object(), })
1638     self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": [], })
1639     self.assertRaises(errors.TypeEnforcementError, self._fdt, {"x": None, })
1640     self.assertRaises(errors.TypeEnforcementError, self._fdt, [])
1641     self.assertRaises(errors.ProgrammerError, utils.ForceDictType,
1642                       {"b": "hello"}, {"b": "no-such-type"})
1643
1644
class TestIsNormAbsPath(unittest.TestCase):
  """Tests for utils.IsNormAbsPath"""

  def _pathTestHelper(self, path, result):
    """Assert that IsNormAbsPath agrees with the expected verdict.

    @param path: the path to check
    @param result: whether C{path} should be reported as absolute and
        normalized

    """
    verdict = utils.IsNormAbsPath(path)
    if not result:
      self.assertFalse(verdict,
                       "Path %s should not result absolute and normalized" %
                       path)
      return
    self.assert_(verdict,
                 "Path %s should result absolute and normalized" % path)

  def testBase(self):
    for (path, expected) in [('/etc', True),
                             ('/srv', True),
                             ('etc', False),
                             ('/etc/../root', False),
                             ('/etc/', False)]:
      self._pathTestHelper(path, expected)
1662
1663
class TestSafeEncode(unittest.TestCase):
  """Tests for SafeEncode"""

  def _checkStable(self, value):
    """Encoding an already-encoded value must leave it unchanged."""
    encoded = SafeEncode(value)
    self.failUnlessEqual(encoded, SafeEncode(encoded))

  def testAscii(self):
    # Printable ASCII passes through untouched
    for text in (string.digits, string.letters, string.punctuation):
      self.failUnlessEqual(text, SafeEncode(text))

  def testDoubleEncode(self):
    for code in range(255):
      self._checkStable(chr(code))

  def testUnicode(self):
    # 1024 is high enough to catch non-direct ASCII mappings
    for code in range(1024):
      self._checkStable(unichr(code))
1681
1682
class TestFormatTime(unittest.TestCase):
  """Tests for FormatTime"""

  @staticmethod
  def _TestInProcess(tz, timestamp, expected):
    """Format C{timestamp} under timezone C{tz} and compare with C{expected}.

    Changes the process environment, so it must run in a child process.

    """
    os.environ["TZ"] = tz
    time.tzset()
    formatted = utils.FormatTime(timestamp)
    return formatted == expected

  def _Test(self, *args):
    # TZ is modified, so run in a child to keep the test runner unaffected
    outcome = utils.RunInSeparateProcess(self._TestInProcess, *args)
    self.assert_(outcome)

  def test(self):
    for (tz, timestamp, expected) in [
      ("UTC", 0, "1970-01-01 00:00:00"),
      ("America/Sao_Paulo", 1292606926, "2010-12-17 15:28:46"),
      ("Europe/London", 1292606926, "2010-12-17 17:28:46"),
      ("Europe/Zurich", 1292606926, "2010-12-17 18:28:46"),
      ("Australia/Sydney", 1292606926, "2010-12-18 04:28:46"),
      ]:
      self._Test(tz, timestamp, expected)

  def testNone(self):
    self.failUnlessEqual(FormatTime(None), "N/A")

  def testInvalid(self):
    self.failUnlessEqual(FormatTime(()), "N/A")

  def testNow(self):
    # Both float (as returned by time.time) and int input are accepted
    now = time.time()
    FormatTime(now)
    FormatTime(int(now))
1714
1715
class TestFormatTimestampWithTZ(unittest.TestCase):
  """Tests for utils.FormatTimestampWithTZ"""

  @staticmethod
  def _TestInProcess(tz, timestamp, expected):
    """Format C{timestamp} under timezone C{tz} and compare with C{expected}.

    Changes the process environment, so it must run in a child process.

    """
    os.environ["TZ"] = tz
    time.tzset()
    formatted = utils.FormatTimestampWithTZ(timestamp)
    return formatted == expected

  def _Test(self, *args):
    # TZ is modified, so run in a child to keep the test runner unaffected
    outcome = utils.RunInSeparateProcess(self._TestInProcess, *args)
    self.assert_(outcome)

  def test(self):
    for (tz, timestamp, expected) in [
      ("UTC", 0, "1970-01-01 00:00:00 UTC"),
      ("America/Sao_Paulo", 1292606926, "2010-12-17 15:28:46 BRST"),
      ("Europe/London", 1292606926, "2010-12-17 17:28:46 GMT"),
      ("Europe/Zurich", 1292606926, "2010-12-17 18:28:46 CET"),
      ("Australia/Sydney", 1292606926, "2010-12-18 04:28:46 EST"),
      ]:
      self._Test(tz, timestamp, expected)
1733
1734
class RunInSeparateProcess(unittest.TestCase):
  """Tests for utils.RunInSeparateProcess"""

  def test(self):
    # The child's return value is handed back to the parent
    for value in (True, False):
      def _child(_result=value):
        return _result

      self.assertEqual(value, utils.RunInSeparateProcess(_child))

  def testArgs(self):
    # Positional arguments reach the child unchanged
    for value in (0, 1, 999, "Hello World", (1, 2, 3)):
      def _child(first, second, _expected=value):
        return first == "Foo" and second == _expected

      self.assert_(utils.RunInSeparateProcess(_child, "Foo", value))

  def testPid(self):
    # The function runs in a process other than the caller's
    caller_pid = os.getpid()
    self.failIf(utils.RunInSeparateProcess(lambda: os.getpid() == caller_pid))

  def testSignal(self):
    # A child terminated by a signal is reported as an error
    def _terminate():
      os.kill(os.getpid(), signal.SIGTERM)

    self.assertRaises(errors.GenericError, utils.RunInSeparateProcess,
                      _terminate)

  def testException(self):
    # An exception raised in the child is reported as an error
    def _fail():
      raise errors.GenericError("This is a test")

    self.assertRaises(errors.GenericError, utils.RunInSeparateProcess, _fail)
1771
1772
class TestFingerprintFiles(unittest.TestCase):
  """Tests for utils._FingerprintFile and utils.FingerprintFiles"""

  def setUp(self):
    # One empty file and one with known content, mapped to their SHA1 sums
    self.tmpfile = tempfile.NamedTemporaryFile()
    self.tmpfile2 = tempfile.NamedTemporaryFile()
    utils.WriteFile(self.tmpfile2.name, data="Hello World\n")
    self.results = {
      self.tmpfile.name: "da39a3ee5e6b4b0d3255bfef95601890afd80709",
      self.tmpfile2.name: "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a",
      }

  def tearDown(self):
    # Closing a NamedTemporaryFile deletes it; do so explicitly instead of
    # relying on the objects being garbage-collected
    self.tmpfile.close()
    self.tmpfile2.close()

  def testSingleFile(self):
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
                     self.results[self.tmpfile.name])

    # Non-existent files have no fingerprint
    self.assertEqual(utils._FingerprintFile("/no/such/file"), None)

  def testBigFile(self):
    self.tmpfile.write("A" * 8192)
    self.tmpfile.flush()
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
                     "35b6795ca20d6dc0aff8c7c110c96cd1070b8c38")

  def testMultiple(self):
    all_files = list(self.results.keys())
    all_files.append("/no/such/file")
    # Files which do not exist must be left out of the result
    self.assertEqual(utils.FingerprintFiles(all_files), self.results)
1799
1800
class TestUnescapeAndSplit(unittest.TestCase):
  """Tests for UnescapeAndSplit"""

  def setUp(self):
    # Testing more than one separator, to ensure none is treated as a
    # regular expression metacharacter
    self._seps = [",", "+", "."]

  def _checkAllSeps(self, make_fields, make_expected):
    """Join, split again and compare, once per separator.

    @param make_fields: function returning the fields to join for a
        given separator
    @param make_expected: function returning the expected split result
        for a given separator

    """
    for sep in self._seps:
      joined = sep.join(make_fields(sep))
      self.failUnlessEqual(UnescapeAndSplit(joined, sep=sep),
                           make_expected(sep))

  def testSimple(self):
    fields = ["a", "b", "c", "d"]
    self._checkAllSeps(lambda _: fields, lambda _: fields)

  def testEscape(self):
    # A backslash-escaped separator is kept as part of the field
    self._checkAllSeps(lambda sep: ["a", "b\\" + sep + "c", "d"],
                       lambda sep: ["a", "b" + sep + "c", "d"])

  def testDoubleEscape(self):
    # An escaped backslash does not escape the following separator
    self._checkAllSeps(lambda _: ["a", "b\\\\", "c", "d"],
                       lambda _: ["a", "b\\", "c", "d"])

  def testThreeEscape(self):
    self._checkAllSeps(lambda sep: ["a", "b\\\\\\" + sep + "c", "d"],
                       lambda sep: ["a", "b\\" + sep + "c", "d"])
1830
1831
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
  """Tests for utils.GenerateSelfSignedX509Cert and its legacy wrapper"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _checkRsaPrivateKey(self, key):
    """Return whether C{key} contains RSA private key PEM markers."""
    lines = key.splitlines()
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
            "-----END RSA PRIVATE KEY-----" in lines)

  def _checkCertificate(self, cert):
    """Return whether C{cert} contains certificate PEM markers."""
    lines = cert.splitlines()
    return ("-----BEGIN CERTIFICATE-----" in lines and
            "-----END CERTIFICATE-----" in lines)

  def test(self):
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
      # The helpers only return a boolean, so their result must be asserted
      self.assert_(self._checkRsaPrivateKey(key_pem))
      self.assert_(self._checkCertificate(cert_pem))

      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
                                           key_pem)
      self.assert_(key.bits() >= 1024)
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)

      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                             cert_pem)
      self.failIf(x509.has_expired())
      self.assertEqual(x509.get_issuer().CN, common_name)
      self.assertEqual(x509.get_subject().CN, common_name)
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)

  def testLegacy(self):
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")

    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)

    cert1 = utils.ReadFile(cert1_filename)

    # The legacy file holds both the private key and the certificate
    self.assert_(self._checkRsaPrivateKey(cert1))
    self.assert_(self._checkCertificate(cert1))
1877
1878
class TestPathJoin(unittest.TestCase):
  """Tests for PathJoin"""

  def testBasicItems(self):
    components = ["/a", "b", "c"]
    joined = PathJoin(*components)
    self.failUnlessEqual(joined, "/".join(components))

  def testNonAbsPrefix(self):
    # A relative first component is rejected
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")

  def testBackTrack(self):
    # A component containing ".." is rejected
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")

  def testMultiAbs(self):
    # An absolute second component is rejected
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1894
1895
class TestValidateServiceName(unittest.TestCase):
  """Tests for utils.ValidateServiceName"""

  def testValid(self):
    # Port numbers (as integers or strings) and plain service names are
    # returned unchanged
    for name in [0, 1, 2, 3, 1024, 65000, 65534, 65535,
                 "ganeti", "gnt-masterd", "HELLO_WORLD_SVC", "hello.world.1",
                 "0", "80", "1111", "65535"]:
      self.assertEqual(utils.ValidateServiceName(name), name)

  def testInvalid(self):
    # Out-of-range ports, disallowed characters and a 129-character name
    for name in [-15756, -1, 65536, 133428083,
                 "", "Hello World!", "!", "'", "\"", "\t", "\n", "`",
                 "-8546", "-1", "65536",
                 "A" * 129]:
      self.assertRaises(errors.OpPrereqError, utils.ValidateServiceName,
                        name)
1920
1921
class TestParseAsn1Generalizedtime(unittest.TestCase):
  """Tests for utils._ParseAsn1Generalizedtime"""

  def test(self):
    parse = utils._ParseAsn1Generalizedtime

    # (ASN.1 GeneralizedTime string, expected UNIX timestamp)
    accepted = [
      # UTC
      ("19700101000000Z", 0),
      ("20100222174152Z", 1266860512),
      ("20380119031407Z", (2 ** 31) - 1),
      # With offset
      ("20100222174152+0000", 1266860512),
      ("20100223131652+0000", 1266931012),
      ("20100223051808-0800", 1266931088),
      ("20100224002135+1100", 1266931295),
      ("19700101000000-0100", 3600),
      ]
    for (value, expected) in accepted:
      self.assertEqual(parse(value), expected)

    rejected = [
      # Leap seconds are not supported by datetime.datetime
      "19841231235960+0000",
      "19920630235960+0000",
      # Errors
      "",
      "invalid",
      "20100222174152",
      "Mon Feb 22 17:47:02 UTC 2010",
      "2010-02-22 17:42:02",
      ]
    for value in rejected:
      self.assertRaises(ValueError, parse, value)
1958
1959
class TestGetX509CertValidity(testutils.GanetiTestCase):
  """Tests for utils.GetX509CertValidity"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    # Validity dates are only available with pyOpenSSL 0.7 or above; older
    # versions are expected to yield (None, None)
    installed = distutils.version.LooseVersion(OpenSSL.__version__)
    self.pyopenssl0_7 = (installed >= "0.7")

    if not self.pyopenssl0_7:
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
                    " function correctly")

  def _LoadCert(self, name):
    """Load the named PEM certificate using _ReadTestData."""
    pem = self._ReadTestData(name)
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem)

  def test(self):
    cert = self._LoadCert("cert1.pem")
    if self.pyopenssl0_7:
      expected = (1266919967, 1267524767)
    else:
      expected = (None, None)
    self.assertEqual(utils.GetX509CertValidity(cert), expected)
1983
1984
class TestSignX509Certificate(unittest.TestCase):
  """Tests for utils.SignX509Certificate and utils.LoadSignedX509Certificate"""

  KEY = "My private key!"
  KEY_OTHER = "Another key"

  def test(self):
    # Generate certificate valid for 5 minutes
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)

    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                           cert_pem)

    # No signature at all
    self.assertRaises(errors.GenericError,
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)

    # Invalid input
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: \n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)

    # Invalid salt; the certificate object is passed (as with the valid
    # salts below) so that the salt is the only invalid argument
    for salt in list("-_@$,:;/\\ \t\n"):
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
                        cert, self.KEY, "foo%sbar" % salt)

    for salt in ["HelloWorld", "salt", string.letters, string.digits,
                 utils.GenerateSecret(numbytes=4),
                 utils.GenerateSecret(numbytes=16),
                 "{123:456}".encode("hex")]:
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)

      self._Check(cert, salt, signed_pem)

      # Extra text before or after the signed PEM must be tolerated
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
                               "lines----\n------ at\nthe end!"))

  def _Check(self, cert, salt, pem):
    """Check that C{pem} loads back to C{cert}/C{salt}, and only with KEY.

    """
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
    self.assertEqual(salt, salt2)
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))

    # Other key
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
                      pem, self.KEY_OTHER)
2038
2039
class TestMakedirs(unittest.TestCase):
  """Tests for utils.Makedirs"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _MakeAndCheck(self, target):
    """Run Makedirs on C{target} and assert that a directory exists."""
    utils.Makedirs(target)
    self.assert_(os.path.isdir(target))

  def testNonExisting(self):
    self._MakeAndCheck(PathJoin(self.tmpdir, "foo"))

  def testExisting(self):
    # An already existing directory is accepted
    target = PathJoin(self.tmpdir, "foo")
    os.mkdir(target)
    self._MakeAndCheck(target)

  def testRecursiveNonExisting(self):
    self._MakeAndCheck(PathJoin(self.tmpdir, "foo/bar/baz"))

  def testRecursiveExisting(self):
    # Only the first level of the hierarchy exists beforehand
    target = PathJoin(self.tmpdir, "B/moo/xyz")
    self.assertFalse(os.path.exists(target))
    os.mkdir(PathJoin(self.tmpdir, "B"))
    self._MakeAndCheck(target)
2069
2070
2071 class TestRetry(testutils.GanetiTestCase):
2072   def setUp(self):
2073     testutils.GanetiTestCase.setUp(self)
2074     self.retries = 0
2075
2076   @staticmethod
2077   def _RaiseRetryAgain():
2078     raise utils.RetryAgain()
2079
2080   @staticmethod
2081   def _RaiseRetryAgainWithArg(args):
2082     raise utils.RetryAgain(*args)
2083
2084   def _WrongNestedLoop(self):
2085     return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
2086
2087   def _RetryAndSucceed(self, retries):
2088     if self.retries < retries:
2089       self.retries += 1
2090       raise utils.RetryAgain()
2091     else:
2092       return True
2093
2094   def testRaiseTimeout(self):
2095     self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
2096                           self._RaiseRetryAgain, 0.01, 0.02)
2097     self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
2098                           self._RetryAndSucceed, 0.01, 0, args=[1])
2099     self.failUnlessEqual(self.retries, 1)
2100
2101   def testComplete(self):
2102     self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
2103     self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
2104                          True)
2105     self.failUnlessEqual(self.retries, 2)
2106
2107   def testNestedLoop(self):
2108     try:
2109       self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
2110                             self._WrongNestedLoop, 0, 1)
2111     except utils.RetryTimeout:
2112       self.fail("Didn't detect inner loop's exception")
2113
2114   def testTimeoutArgument(self):
2115     retry_arg="my_important_debugging_message"
2116     try:
2117       utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
2118     except utils.RetryTimeout, err:
2119       self.failUnlessEqual(err.args, (retry_arg, ))
2120     else:
2121       self.fail("Expected timeout didn't happen")
2122
2123   def testRaiseInnerWithExc(self):
2124     retry_arg="my_important_debugging_message"
2125     try:
2126       try:
2127         utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2128                     args=[[errors.GenericError(retry_arg, retry_arg)]])
2129       except utils.RetryTimeout, err:
2130         err.RaiseInner()
2131       else:
2132         self.fail("Expected timeout didn't happen")
2133     except errors.GenericError, err:
2134       self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2135     else:
2136       self.fail("Expected GenericError didn't happen")
2137
2138   def testRaiseInnerWithMsg(self):
2139     retry_arg="my_important_debugging_message"
2140     try:
2141       try:
2142         utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2143                     args=[[retry_arg, retry_arg]])
2144       except utils.RetryTimeout, err:
2145         err.RaiseInner()
2146       else:
2147         self.fail("Expected timeout didn't happen")
2148     except utils.RetryTimeout, err:
2149       self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2150     else:
2151       self.fail("Expected RetryTimeout didn't happen")
2152
2153
class TestLineSplitter(unittest.TestCase):
  """Tests for utils.LineSplitter"""

  def test(self):
    collected = []
    splitter = utils.LineSplitter(collected.append)
    splitter.write("Hello World\n")
    # Nothing is passed on before flushing
    self.assertEqual(collected, [])
    for chunk in ["Foo\n Bar\r\n ", "Baz", "Moo"]:
      splitter.write(chunk)
    self.assertEqual(collected, [])
    splitter.flush()
    self.assertEqual(collected, ["Hello World", "Foo", " Bar"])
    # Closing also emits the trailing, unterminated line
    splitter.close()
    self.assertEqual(collected, ["Hello World", "Foo", " Bar", " BazMoo"])

  def _testExtra(self, line, all_lines, p1, p2):
    # Line callback which also checks that extra arguments are passed on
    self.assertEqual(p1, 999)
    self.assertEqual(p2, "extra")
    all_lines.append(line)

  def testExtraArgsNoFlush(self):
    collected = []
    splitter = utils.LineSplitter(self._testExtra, collected, 999, "extra")
    for chunk in ["\n\nHello World\n", "Foo\n Bar\r\n ", "", "Baz",
                  "Moo\n\nx\n"]:
      splitter.write(chunk)
    self.assertEqual(collected, [])
    splitter.close()
    self.assertEqual(collected, ["", "", "Hello World", "Foo", " Bar",
                                 " BazMoo", "", "x"])
2186
2187
class TestReadLockedPidFile(unittest.TestCase):
  """Tests for utils.ReadLockedPidFile"""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def testNonExistent(self):
    missing = PathJoin(self.tmpdir, "nonexist")
    self.assert_(utils.ReadLockedPidFile(missing) is None)

  def testUnlocked(self):
    pidfile = PathJoin(self.tmpdir, "pid")
    utils.WriteFile(pidfile, data="123")
    # Without a lock held on the file no PID is reported
    self.assert_(utils.ReadLockedPidFile(pidfile) is None)

  def testLocked(self):
    pidfile = PathJoin(self.tmpdir, "pid")
    utils.WriteFile(pidfile, data="123")

    lock = utils.FileLock.Open(pidfile)
    try:
      lock.Exclusive(blocking=True)

      self.assertEqual(utils.ReadLockedPidFile(pidfile), 123)
    finally:
      lock.Close()

    # Once the lock is released the PID is no longer reported
    self.assert_(utils.ReadLockedPidFile(pidfile) is None)

  def testError(self):
    # Using a regular file as a directory makes open(2) fail with ENOTDIR
    not_a_dir = PathJoin(self.tmpdir, "foobar")
    utils.WriteFile(not_a_dir, data="")
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile,
                      PathJoin(not_a_dir, "pid"))
2223
2224
class TestCertVerification(testutils.GanetiTestCase):
  """Tests for utils.VerifyX509Certificate"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    try:
      shutil.rmtree(self.tmpdir)
    finally:
      # setUp initialized the base class, so its cleanup must run as well
      testutils.GanetiTestCase.tearDown(self)

  def testVerifyCertificate(self):
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                           cert_pem)

    # Not checking return value as this certificate is expired
    utils.VerifyX509Certificate(cert, 30, 7)
2241
2242
class TestVerifyCertificateInner(unittest.TestCase):
  """Tests for utils._VerifyCertificateInner"""

  def test(self):
    vci = utils._VerifyCertificateInner

    # Valid
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
                     (None, None))

    # (arguments for _VerifyCertificateInner, expected error code)
    checks = [
      # Not yet valid
      ((False, 1266507600, 1267544400, 1266075600, 30, 7),
       utils.CERT_WARNING),
      # Expiring soon
      ((False, 1266507600, 1267544400, 1266939600, 30, 7),
       utils.CERT_ERROR),
      ((False, 1266507600, 1267544400, 1266939600, 30, 1),
       utils.CERT_WARNING),
      ((False, 1266507600, None, 1266939600, 30, 7),
       None),
      # Expired
      ((True, 1266507600, 1267544400, 1266939600, 30, 7),
       utils.CERT_ERROR),
      ((True, None, 1267544400, 1266939600, 30, 7),
       utils.CERT_ERROR),
      ((True, 1266507600, None, 1266939600, 30, 7),
       utils.CERT_ERROR),
      ((True, None, None, 1266939600, 30, 7),
       utils.CERT_ERROR),
      ]
    for (args, expected) in checks:
      (errcode, _) = vci(*args)
      self.assertEqual(errcode, expected)
2277
2278
class TestHmacFunctions(unittest.TestCase):
  """Tests for utils.Sha1Hmac and utils.VerifySha1Hmac"""

  # Digests can be checked with "openssl sha1 -hmac $key"
  def testSha1Hmac(self):
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
    known = [
      ("", "", "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d"),
      ("3YzMxZWE", "Hello World", "ef4f3bda82212ecb2f7ce868888a19092481f1fd"),
      ("TguMTA2K", "", "f904c2476527c6d3e6609ab683c66fa0652cb1dc"),
      ("3YzMxZWE", longtext, "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54"),
      ]
    for (key, text, digest) in known:
      self.assertEqual(utils.Sha1Hmac(key, text), digest)

  def testSha1HmacSalt(self):
    known = [
      ("TguMTA2K", "", "abc0", "4999bf342470eadb11dfcd24ca5680cf9fd7cdce"),
      ("TguMTA2K", "", "abc9", "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8"),
      ("3YzMxZWE", "Hello World", "xyz0",
       "7f264f8114c9066afc9bb7636e1786d996d3cc0d"),
      ]
    for (key, text, salt, digest) in known:
      self.assertEqual(utils.Sha1Hmac(key, text, salt=salt), digest)

  def testVerifySha1Hmac(self):
    for (key, text, digest) in [
      ("", "", "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d"),
      ("TguMTA2K", "", "f904c2476527c6d3e6609ab683c66fa0652cb1dc"),
      ]:
      self.assert_(utils.VerifySha1Hmac(key, text, digest))

    # The case of the hexadecimal digest does not matter
    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
    for variant in [digest, digest.lower(), digest.upper(), digest.title()]:
      self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", variant))

  def testVerifySha1HmacSalt(self):
    for (key, text, digest, salt) in [
      ("TguMTA2K", "", "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8", "abc9"),
      ("3YzMxZWE", "Hello World",
       "7f264f8114c9066afc9bb7636e1786d996d3cc0d", "xyz0"),
      ]:
      self.assert_(utils.VerifySha1Hmac(key, text, digest, salt=salt))
2326
2327
class TestIgnoreSignals(unittest.TestCase):
  """Tests for utils.IgnoreSignals"""

  @staticmethod
  def _Raise(exception):
    raise exception

  @staticmethod
  def _Return(rval):
    return rval

  def testIgnoreSignals(self):
    sock_err_intr = socket.error(errno.EINTR, "Message")
    sock_err_inval = socket.error(errno.EINVAL, "Message")

    env_err_intr = EnvironmentError(errno.EINTR, "Message")
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")

    # Called directly, every error propagates
    for (exc_class, exc) in [(socket.error, sock_err_intr),
                             (socket.error, sock_err_inval),
                             (EnvironmentError, env_err_intr),
                             (EnvironmentError, env_err_inval)]:
      self.assertRaises(exc_class, self._Raise, exc)

    # Through IgnoreSignals, EINTR errors yield None ...
    for exc in [sock_err_intr, env_err_intr]:
      self.assertEquals(utils.IgnoreSignals(self._Raise, exc), None)

    # ... while other errors still propagate
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
                      sock_err_inval)
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
                      env_err_inval)

    # Return values are passed through unchanged
    for value in [True, 33]:
      self.assertEquals(utils.IgnoreSignals(self._Return, value), value)
2360
2361
2362 class TestEnsureDirs(unittest.TestCase):
2363   """Tests for EnsureDirs"""
2364
2365   def setUp(self):
2366     self.dir = tempfile.mkdtemp()
2367     self.old_umask = os.umask(0777)
2368
2369   def testEnsureDirs(self):
2370     utils.EnsureDirs([
2371         (PathJoin(self.dir, "foo"), 0777),
2372         (PathJoin(self.dir, "bar"), 0000),
2373         ])
2374     self.assertEquals(os.stat(PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2375     self.assertEquals(os.stat(PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2376
2377   def tearDown(self):
2378     os.rmdir(PathJoin(self.dir, "foo"))
2379     os.rmdir(PathJoin(self.dir, "bar"))
2380     os.rmdir(self.dir)
2381     os.umask(self.old_umask)
2382
2383
class TestFormatSeconds(unittest.TestCase):
  """Tests for utils.FormatSeconds"""

  def _CheckAll(self, cases):
    """Assert FormatSeconds output for (seconds, expected) pairs."""
    for (seconds, expected) in cases:
      self.assertEqual(utils.FormatSeconds(seconds), expected)

  def test(self):
    self._CheckAll([
      (1, "1s"),
      (3600, "1h 0m 0s"),
      (3599, "59m 59s"),
      (7200, "2h 0m 0s"),
      (7201, "2h 0m 1s"),
      (7281, "2h 1m 21s"),
      (29119, "8h 5m 19s"),
      (19431228, "224d 21h 33m 48s"),
      # Negative durations are only given in seconds
      (-1, "-1s"),
      (-282, "-282s"),
      (-29119, "-29119s"),
      ])

  def testFloat(self):
    # Fractional seconds are rounded to the nearest whole second
    self._CheckAll([
      (1.3, "1s"),
      (1.9, "2s"),
      (3912.12311, "1h 5m 12s"),
      (3912.8, "1h 5m 13s"),
      ])
2403
2404
class TestIgnoreProcessNotFound(unittest.TestCase):
  """Tests for utils.IgnoreProcessNotFound"""

  @staticmethod
  def _WritePid(fd):
    # Runs in the child: report our PID through the pipe
    os.write(fd, str(os.getpid()))
    os.close(fd)
    return True

  def test(self):
    (read_fd, write_fd) = os.pipe()

    # Run a short-lived child which sends its PID through the pipe
    self.assert_(utils.RunInSeparateProcess(self._WritePid, write_fd))
    os.close(write_fd)

    child_pid = int(os.read(read_fd, 1024))
    os.close(read_fd)

    # Signalling a process which exited recently must yield False
    self.assertFalse(utils.IgnoreProcessNotFound(os.kill, child_pid, 0))
2425
2426
class TestShellWriter(unittest.TestCase):
  """Tests for utils.ShellWriter"""

  def test(self):
    out = StringIO()
    writer = utils.ShellWriter(out)
    writer.Write("#!/bin/bash")
    writer.Write("if true; then")
    writer.IncIndent()
    try:
      writer.Write("echo true")

      writer.Write("for i in 1 2 3")
      writer.Write("do")
      writer.IncIndent()
      try:
        # Two IncIndent calls are in effect at this point
        self.assertEqual(writer._indent, 2)
        writer.Write("date")
      finally:
        writer.DecIndent()
      writer.Write("done")
    finally:
      writer.DecIndent()
    writer.Write("echo %s", utils.ShellQuote("Hello World"))
    writer.Write("exit 0")

    # Every IncIndent has been matched by a DecIndent
    self.assertEqual(writer._indent, 0)

    text = out.getvalue()

    self.assert_(text.endswith("\n"))

    lines = text.splitlines()
    self.assertEqual(len(lines), 9)
    self.assertEqual(lines[0], "#!/bin/bash")
    # The loop body line is indented
    self.assert_(re.match(r"^\s+date$", lines[5]))
    self.assertEqual(lines[7], "echo 'Hello World'")

  def testEmpty(self):
    # Creating and dropping a writer must not produce any output
    out = StringIO()
    writer = utils.ShellWriter(out)
    del writer
    self.assertEqual(out.getvalue(), "")
2468
2469
class TestCommaJoin(unittest.TestCase):
  """Tests for utils.CommaJoin."""

  def test(self):
    # (input items, expected joined string); non-string items are included
    cases = [
      ([], ""),
      ([1, 2, 3], "1, 2, 3"),
      (["Hello"], "Hello"),
      (["Hello", "World"], "Hello, World"),
      (["Hello", "World", 99], "Hello, World, 99"),
      ]
    for items, expected in cases:
      self.assertEqual(utils.CommaJoin(items), expected)
2478
2479
class TestFindMatch(unittest.TestCase):
  """Tests for utils.FindMatch."""

  def test(self):
    table = {
      "aaaa": "Four A",
      "bb": {"Two B": True},
      re.compile(r"^x(foo|bar|bazX)([0-9]+)$"): (1, 2, 3),
      }

    # Plain string keys match literally and come with an empty group list
    for key, value in [("aaaa", "Four A"), ("bb", {"Two B": True})]:
      self.assertEqual(utils.FindMatch(table, key), (value, []))

    # Regex keys return their value together with the captured groups
    for word in ("foo", "bar", "bazX"):
      for num in range(1, 100, 7):
        self.assertEqual(utils.FindMatch(table, "x%s%s" % (word, num)),
                         ((1, 2, 3), [word, str(num)]))

  def testNoMatch(self):
    for needle in ["", "foo", 1234]:
      self.assert_(utils.FindMatch({}, needle) is None)

    table = {
      "X": "Hello World",
      re.compile("^(something)$"): "Hello World",
      }

    # Looking up a value (rather than a key) must not match
    for needle in ["", "Hello World"]:
      self.assert_(utils.FindMatch(table, needle) is None)
2508
2509
class TestFileID(testutils.GanetiTestCase):
  """Tests for utils.GetFileID, utils.VerifyFileID and utils.SafeWriteFile."""

  def testEquality(self):
    name = self._CreateTempFile()
    oldi = utils.GetFileID(path=name)
    self.failUnless(utils.VerifyFileID(oldi, oldi))

  def testUpdate(self):
    name = self._CreateTempFile()
    oldi = utils.GetFileID(path=name)
    # Touch the file, then compare against an ID obtained via a descriptor
    os.utime(name, None)
    fd = os.open(name, os.O_RDWR)
    try:
      newi = utils.GetFileID(fd=fd)
      self.failUnless(utils.VerifyFileID(oldi, newi))
      self.failUnless(utils.VerifyFileID(newi, oldi))
    finally:
      os.close(fd)

  def testWriteFile(self):
    name = self._CreateTempFile()
    oldi = utils.GetFileID(path=name)
    # Element 2 of the file ID is used as the modification time below
    mtime = oldi[2]
    # File modified after oldi was taken: SafeWriteFile must refuse
    os.utime(name, (mtime + 10, mtime + 10))
    self.assertRaises(errors.LockError, utils.SafeWriteFile, name,
                      oldi, data="")
    # File mtime older than recorded in oldi: the write goes through
    os.utime(name, (mtime - 10, mtime - 10))
    utils.SafeWriteFile(name, oldi, data="")
    oldi = utils.GetFileID(path=name)
    mtime = oldi[2]
    os.utime(name, (mtime + 10, mtime + 10))
    # this doesn't raise, since we passed None
    utils.SafeWriteFile(name, None, data="")

  def testError(self):
    # Passing both a path and a file descriptor is a programming error
    t = tempfile.NamedTemporaryFile()
    try:
      self.assertRaises(errors.ProgrammerError, utils.GetFileID,
                        path=t.name, fd=t.fileno())
    finally:
      # Closing a NamedTemporaryFile also removes it from disk
      t.close()
2547
2548
class TimeMock:
  """Fake clock returning a predefined sequence of timestamps.

  Each call removes and returns the first remaining entry of the list given
  to the constructor; that list is consumed in place. Calling the mock once
  the list is exhausted raises IndexError.

  """
  def __init__(self, values):
    # Keep a reference to the caller's list rather than a copy
    self.values = values

  def __call__(self):
    next_value = self.values.pop(0)
    return next_value
2555
2556
class TestRunningTimeout(unittest.TestCase):
  """Tests for utils.RunningTimeout driven by a fake clock."""

  def setUp(self):
    # One timestamp is consumed when the timeout is created, then one per
    # Remaining() call
    self.time_fn = TimeMock([0.0, 0.3, 4.6, 6.5])

  def testRemainingFloat(self):
    timeout = utils.RunningTimeout(5.0, True, _time_fn=self.time_fn)
    for expected in (4.7, 0.4, -1.5):
      self.assertAlmostEqual(timeout.Remaining(), expected)

  def testRemaining(self):
    # Integer timestamps; replaces the clock prepared in setUp
    self.time_fn = TimeMock([0, 2, 4, 5, 6])
    timeout = utils.RunningTimeout(5, True, _time_fn=self.time_fn)
    for expected in (3, 1, 0, -1):
      self.assertEqual(timeout.Remaining(), expected)

  def testRemainingNonNegative(self):
    # With the second argument False, an expired timeout reports 0.0
    timeout = utils.RunningTimeout(5.0, False, _time_fn=self.time_fn)
    self.assertAlmostEqual(timeout.Remaining(), 4.7)
    self.assertAlmostEqual(timeout.Remaining(), 0.4)
    self.assertEqual(timeout.Remaining(), 0.0)

  def testNegativeTimeout(self):
    self.assertRaises(ValueError, utils.RunningTimeout, -1.0, True)
2583
2584
class TestTryConvert(unittest.TestCase):
  """Tests for utils.TryConvert."""

  def test(self):
    # (conversion function, input, expected result); when the conversion
    # fails, the input is expected back unchanged
    expectations = [
      (int, "1", 1),
      (int, "a", "a"),
      (bool, "", False),
      (bool, "a", True),
      ]
    for conv, value, wanted in expectations:
      self.assertEqual(utils.TryConvert(conv, value), wanted)
2594
2595
class TestIsValidShellParam(unittest.TestCase):
  """Tests for utils.IsValidShellParam."""

  def test(self):
    self.assertEqual(utils.IsValidShellParam("abc"), True)
    # Contains a shell command separator
    self.assertEqual(utils.IsValidShellParam("ab;cd"), False)
2603
2604
class TestBuildShellCmd(unittest.TestCase):
  """Tests for utils.BuildShellCmd."""

  def test(self):
    # An argument containing a shell separator is rejected
    self.assertRaises(errors.ProgrammerError, utils.BuildShellCmd,
                      "ls %s", "ab;cd")
    # A plain argument is substituted into the template
    built = utils.BuildShellCmd("ls %s", "ab")
    self.assertEqual(built, "ls ab")
2610
2611
class TestWriteFile(unittest.TestCase):
  """Tests for utils.WriteFile."""

  def setUp(self):
    self.tfile = tempfile.NamedTemporaryFile()
    # Flags set by the callbacks below when WriteFile invokes them
    self.did_pre = False
    self.did_post = False
    self.did_write = False

  def tearDown(self):
    # Close the temporary file, which also deletes it. Test case instances
    # can outlive their test, so relying on garbage collection would keep
    # the descriptor open and the file on disk far longer than needed.
    self.tfile.close()

  def markPre(self, fd):
    self.did_pre = True

  def markPost(self, fd):
    self.did_post = True

  def markWrite(self, fd):
    self.did_write = True

  def testWrite(self):
    data = "abc"
    utils.WriteFile(self.tfile.name, data=data)
    self.assertEqual(utils.ReadFile(self.tfile.name), data)

  def testErrors(self):
    # Both "data" and "fn" given
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
                      self.tfile.name, data="test", fn=lambda fd: None)
    # Neither "data" nor "fn" given
    self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name)
    # "atime" given without "mtime"
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
                      self.tfile.name, data="test", atime=0)

  def testCalls(self):
    utils.WriteFile(self.tfile.name, fn=self.markWrite,
                    prewrite=self.markPre, postwrite=self.markPost)
    self.assertTrue(self.did_pre)
    self.assertTrue(self.did_post)
    self.assertTrue(self.did_write)

  def testDryRun(self):
    orig = "abc"
    self.tfile.write(orig)
    self.tfile.flush()
    # In dry-run mode the existing contents must be left untouched
    utils.WriteFile(self.tfile.name, data="hello", dry_run=True)
    self.assertEqual(utils.ReadFile(self.tfile.name), orig)

  def testTimes(self):
    f = self.tfile.name
    for at, mt in [(0, 0), (1000, 1000), (2000, 3000),
                   (int(time.time()), 5000)]:
      utils.WriteFile(f, data="hello", atime=at, mtime=mt)
      st = os.stat(f)
      self.assertEqual(st.st_atime, at)
      self.assertEqual(st.st_mtime, mt)

  def testNoClose(self):
    data = "hello"
    # By default nothing is returned...
    self.assertEqual(utils.WriteFile(self.tfile.name, data="abc"), None)
    # ...but with close=False the still-open descriptor is handed back
    fd = utils.WriteFile(self.tfile.name, data=data, close=False)
    try:
      os.lseek(fd, 0, 0)
      self.assertEqual(os.read(fd, 4096), data)
    finally:
      os.close(fd)
2673
2674
class TestNormalizeAndValidateMac(unittest.TestCase):
  """Tests for utils.NormalizeAndValidateMac."""

  def testInvalid(self):
    self.assertRaises(errors.OpPrereqError, utils.NormalizeAndValidateMac,
                      "xxx")

  def testNormalization(self):
    # Valid addresses come back lower-cased
    samples = ["aa:bb:cc:dd:ee:ff", "00:AA:11:bB:22:cc"]
    for addr in samples:
      normalized = utils.NormalizeAndValidateMac(addr)
      self.assertEqual(normalized, addr.lower())
2683
2684
class TestNiceSort(unittest.TestCase):
  """Tests for utils.NiceSort.

  The expected orderings below show that runs of digits are compared by
  their numeric value (e.g. "a2" before "a20" before "a99"), other text is
  compared as plain strings (upper case before lower case), and an optional
  key function is applied to each element before comparison.

  """
  def test(self):
    self.assertEqual(utils.NiceSort([]), [])
    self.assertEqual(utils.NiceSort(["foo"]), ["foo"])
    self.assertEqual(utils.NiceSort(["bar", ""]), ["", "bar"])
    self.assertEqual(utils.NiceSort([",", "."]), [",", "."])
    self.assertEqual(utils.NiceSort(["0.1", "0.2"]), ["0.1", "0.2"])
    self.assertEqual(utils.NiceSort(["0;099", "0,099", "0.1", "0.2"]),
                     ["0,099", "0.1", "0.2", "0;099"])

    data = ["a0", "a1", "a99", "a20", "a2", "b10", "b70", "b00", "0000"]
    self.assertEqual(utils.NiceSort(data),
                     ["0000", "a0", "a1", "a2", "a20", "a99",
                      "b00", "b10", "b70"])

    data = ["a0-0", "a1-0", "a99-10", "a20-3", "a0-4", "a99-3", "a09-2",
            "Z", "a9-1", "A", "b"]
    self.assertEqual(utils.NiceSort(data),
                     ["A", "Z", "a0-0", "a0-4", "a1-0", "a9-1", "a09-2",
                      "a20-3", "a99-3", "a99-10", "b"])
    # With a case-normalizing key, "Z" moves after the lower-case entries
    self.assertEqual(utils.NiceSort(data, key=str.lower),
                     ["A", "a0-0", "a0-4", "a1-0", "a9-1", "a09-2",
                      "a20-3", "a99-3", "a99-10", "b", "Z"])
    self.assertEqual(utils.NiceSort(data, key=str.upper),
                     ["A", "a0-0", "a0-4", "a1-0", "a9-1", "a09-2",
                      "a20-3", "a99-3", "a99-10", "b", "Z"])

  def testLargeA(self):
    # Random-looking mixed-case strings with embedded digit runs
    data = [
      "Eegah9ei", "xij88brTulHYAv8IEOyU", "3jTwJPtrXOY22bwL2YoW",
      "Z8Ljf1Pf5eBfNg171wJR", "WvNJd91OoXvLzdEiEXa6", "uHXAyYYftCSG1o7qcCqe",
      "xpIUJeVT1Rp", "KOt7vn1dWXi", "a07h8feON165N67PIE", "bH4Q7aCu3PUPjK3JtH",
      "cPRi0lM7HLnSuWA2G9", "KVQqLPDjcPjf8T3oyzjcOsfkb",
      "guKJkXnkULealVC8CyF1xefym", "pqF8dkU5B1cMnyZuREaSOADYx",
      ]
    self.assertEqual(utils.NiceSort(data), [
      "3jTwJPtrXOY22bwL2YoW", "Eegah9ei", "KOt7vn1dWXi",
      "KVQqLPDjcPjf8T3oyzjcOsfkb", "WvNJd91OoXvLzdEiEXa6",
      "Z8Ljf1Pf5eBfNg171wJR", "a07h8feON165N67PIE", "bH4Q7aCu3PUPjK3JtH",
      "cPRi0lM7HLnSuWA2G9", "guKJkXnkULealVC8CyF1xefym",
      "pqF8dkU5B1cMnyZuREaSOADYx", "uHXAyYYftCSG1o7qcCqe",
      "xij88brTulHYAv8IEOyU", "xpIUJeVT1Rp"
      ])

  def testLargeB(self):
    # Already in the expected order; shuffled copies must sort back to it
    data = [
      "inst-0.0.0.0-0.0.0.0",
      "inst-0.1.0.0-0.0.0.0",
      "inst-0.2.0.0-0.0.0.0",
      "inst-0.2.1.0-0.0.0.0",
      "inst-0.2.2.0-0.0.0.0",
      "inst-0.2.2.0-0.0.0.9",
      "inst-0.2.2.0-0.0.3.9",
      "inst-0.2.2.0-0.2.0.9",
      "inst-0.2.2.0-0.9.0.9",
      "inst-0.20.2.0-0.0.0.0",
      "inst-0.20.2.0-0.9.0.9",
      "inst-10.020.2.0-0.9.0.10",
      "inst-15.020.2.0-0.9.1.00",
      "inst-100.020.2.0-0.9.0.9",

      # Only the last group, not converted to a number anymore, differs
      "inst-100.020.2.0a999",
      "inst-100.020.2.0b000",
      "inst-100.020.2.0c10",
      "inst-100.020.2.0c101",
      "inst-100.020.2.0c2",
      "inst-100.020.2.0c20",
      "inst-100.020.2.0c3",
      "inst-100.020.2.0c39123",
      ]

    # Fixed seed keeps the shuffles reproducible
    rnd = random.Random(16205)
    for _ in range(10):
      testdata = data[:]
      rnd.shuffle(testdata)
      # Sanity check: the shuffle must actually change the order
      assert testdata != data
      self.assertEqual(utils.NiceSort(testdata), data)

  class _CallCount:
    """Wraps a callable and counts how many times it is invoked."""
    def __init__(self, fn):
      self.count = 0
      self.fn = fn

    def __call__(self, *args):
      self.count += 1
      return self.fn(*args)

  def testKeyfuncA(self):
    # Generate some random numbers
    rnd = random.Random(21131)
    numbers = [rnd.randint(0, 10000) for _ in range(999)]
    assert numbers != sorted(numbers)

    # Convert to hex
    data = [hex(i) for i in numbers]
    datacopy = data[:]

    keyfn = self._CallCount(lambda value: str(int(value, 16)))

    # Sort with key function converting hex to decimal
    result = utils.NiceSort(data, key=keyfn)

    self.assertEqual([hex(i) for i in sorted(numbers)], result)
    self.assertEqual(data, datacopy, msg="Input data was modified in NiceSort")
    # The key function must run exactly once per element
    self.assertEqual(keyfn.count, len(numbers),
                     msg="Key function was not called once per value")

  class _TestData:
    """Record with a name (used as sort key) and the tuple it encodes."""
    def __init__(self, name, value):
      self.name = name
      self.value = value

  def testKeyfuncB(self):
    # Names embed v1, v2 and i as numbers, so a natural sort by name is
    # expected to match sorting by the (v1, v2, i) value tuple
    rnd = random.Random(27396)
    data = []
    for i in range(123):
      v1 = rnd.randint(0, 5)
      v2 = rnd.randint(0, 5)
      data.append(self._TestData("inst-%s-%s-%s" % (v1, v2, i),
                                 (v1, v2, i)))
    rnd.shuffle(data)
    assert data != sorted(data, key=operator.attrgetter("name"))

    keyfn = self._CallCount(operator.attrgetter("name"))

    # Sort by name
    result = utils.NiceSort(data, key=keyfn)

    self.assertEqual(result, sorted(data, key=operator.attrgetter("value")))
    self.assertEqual(keyfn.count, len(data),
                     msg="Key function was not called once per value")
2817
2818
if __name__ == "__main__":
  # Hand control to the test runner entry point provided by testutils
  testutils.GanetiTestProgram()