Merge branch 'devel-2.1'
[ganeti-local] / test / ganeti.utils_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the utils module"""
23
24 import unittest
25 import os
26 import time
27 import tempfile
28 import os.path
29 import os
30 import md5
31 import signal
32 import socket
33 import shutil
34 import re
35 import select
36 import string
37
38 import ganeti
39 import testutils
40 from ganeti import constants
41 from ganeti import utils
42 from ganeti import errors
43 from ganeti.utils import IsProcessAlive, RunCmd, \
44      RemoveFile, MatchNameComponent, FormatUnit, \
45      ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
46      ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
47      SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
48      TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
49      UnescapeAndSplit
50
51 from ganeti.errors import LockError, UnitParseError, GenericError, \
52      ProgrammerError
53
54
class TestIsProcessAlive(unittest.TestCase):
  """Testing case for IsProcessAlive"""

  def testExists(self):
    """The running test process must be reported as alive"""
    mypid = os.getpid()
    self.assert_(IsProcessAlive(mypid),
                 "can't find myself running")

  def testNotExisting(self):
    """A child that has exited and been reaped must not be reported alive"""
    # os.fork() signals failure by raising OSError; it never returns a
    # negative value, so only the child (0) and parent (> 0) cases exist
    pid_non_existing = os.fork()
    if pid_non_existing == 0:
      # Child: terminate immediately, without unwinding into the test runner
      os._exit(0)
    # Parent: reap the child so its PID no longer refers to a live process
    os.waitpid(pid_non_existing, 0)
    self.assert_(not IsProcessAlive(pid_non_existing),
                 "nonexisting process detected")
72
73
class TestPidFileFunctions(unittest.TestCase):
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""

  def setUp(self):
    """Redirect PID files into a private temporary directory"""
    self.dir = tempfile.mkdtemp()
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
    # Remember the real implementation so tearDown can put it back; the
    # replacement would otherwise stay installed for every later test
    self._orig_dpn = utils.DaemonPidFileName
    utils.DaemonPidFileName = self.f_dpn

  def testPidFileFunctions(self):
    """Write, read back and remove a PID file for the current process"""
    pid_file = self.f_dpn('test')
    utils.WritePidFile('test')
    self.failUnless(os.path.exists(pid_file),
                    "PID file should have been created")
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, os.getpid())
    self.failUnless(utils.IsProcessAlive(read_pid))
    # Writing the same PID file a second time must be refused
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for missing pid file")
    fh = open(pid_file, "w")
    try:
      fh.write("blah\n")
    finally:
      fh.close()
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
                         "ReadPidFile should return 0 for invalid pid file")
    utils.RemovePidFile('test')
    self.failIf(os.path.exists(pid_file),
                "PID file should not exist anymore")

  def testKill(self):
    """Kill a forked child that has written its own PID file"""
    pid_file = self.f_dpn('child')
    r_fd, w_fd = os.pipe()
    new_pid = os.fork()
    if new_pid == 0: #child
      try:
        utils.WritePidFile('child')
        os.write(w_fd, 'a')
        signal.pause()
      finally:
        # Whatever happens, never fall back into the test runner from the
        # forked child
        os._exit(0)
    # else we are in the parent
    # Drop our copy of the write end so that a child dying early results in
    # EOF on the read below instead of blocking forever
    os.close(w_fd)
    try:
      # wait until the child has written the pid file
      os.read(r_fd, 1)
    finally:
      os.close(r_fd)
    read_pid = utils.ReadPidFile(pid_file)
    self.failUnlessEqual(read_pid, new_pid)
    self.failUnless(utils.IsProcessAlive(new_pid))
    utils.KillProcess(new_pid, waitpid=True)
    self.failIf(utils.IsProcessAlive(new_pid))
    utils.RemovePidFile('child')
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)

  def tearDown(self):
    """Restore the patched function and remove the temporary directory"""
    utils.DaemonPidFileName = self._orig_dpn
    for name in os.listdir(self.dir):
      os.unlink(os.path.join(self.dir, name))
    os.rmdir(self.dir)
130
131
class TestRunCmd(testutils.GanetiTestCase):
  """Testing case for the RunCmd function"""

  def setUp(self):
    """Prepare a unique marker string and a temporary output file"""
    testutils.GanetiTestCase.setUp(self)
    self.magic = time.ctime() + " ganeti test"
    self.fname = self._CreateTempFile()

  def testOk(self):
    """Test successful exit code"""
    result = RunCmd("/bin/sh -c 'exit 0'")
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.output, "")

  def testFail(self):
    """Test fail exit code"""
    result = RunCmd("/bin/sh -c 'exit 1'")
    self.assertEqual(result.exit_code, 1)
    self.assertEqual(result.output, "")

  def testStdout(self):
    """Test standard output"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.stdout, self.magic)
    # With output= the data must land in the file, not in the result
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testStderr(self):
    """Test standard error"""
    cmd = 'echo -n "%s"' % self.magic
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
    self.assertEqual(result.stderr, self.magic)
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, self.magic)

  def testCombined(self):
    """Test combined output"""
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
    expected = "A" + self.magic + "B" + self.magic
    result = RunCmd("/bin/sh -c '%s'" % cmd)
    self.assertEqual(result.output, expected)
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
    self.assertEqual(result.output, "")
    self.assertFileContent(self.fname, expected)

  def testSignal(self):
    """Test signal"""
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
    self.assertEqual(result.signal, 15)
    self.assertEqual(result.output, "")

  def testListRun(self):
    """Test list runs"""
    result = RunCmd(["true"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 1)
    result = RunCmd(["echo", "-n", self.magic])
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertEqual(result.stdout, self.magic)

  def testFileEmptyOutput(self):
    """Test file output"""
    result = RunCmd(["true"], output=self.fname)
    self.assertEqual(result.signal, None)
    self.assertEqual(result.exit_code, 0)
    self.assertFileContent(self.fname, "")

  def testLang(self):
    """Test locale environment"""
    # Save and restore only the variables touched here. Rebinding
    # os.environ to a plain dict copy would replace the special mapping
    # object and leave the values set below in the real process environment
    names = ("LANG", "LC_ALL")
    saved = {}
    for name in names:
      saved[name] = os.environ.get(name)
    try:
      os.environ["LANG"] = "en_US.UTF-8"
      os.environ["LC_ALL"] = "en_US.UTF-8"
      result = RunCmd(["locale"])
      for line in result.output.splitlines():
        key, value = line.split("=", 1)
        # Ignore these variables, they're overridden by LC_ALL
        if key == "LANG" or key == "LANGUAGE":
          continue
        self.failIf(value and value != "C" and value != '"C"',
            "Variable %s is set to the invalid value '%s'" % (key, value))
    finally:
      for name in names:
        if saved[name] is None:
          # Variable was not set before the test, remove it again
          if name in os.environ:
            del os.environ[name]
        else:
          os.environ[name] = saved[name]

  def testDefaultCwd(self):
    """Test default working directory"""
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")

  def testCwd(self):
    """Test explicitly specified working directory"""
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
    cwd = os.getcwd()
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
233
234
class TestRemoveFile(unittest.TestCase):
  """Test case for the RemoveFile function"""

  def setUp(self):
    """Create a temp dir and file for each case"""
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
    os.close(fd)

  def tearDown(self):
    """Remove the temp dir together with anything a failed test left in it"""
    # rmtree instead of unlink+rmdir: a symlink left behind by a failing
    # testRemoveSymlink would otherwise make rmdir fail as well
    shutil.rmtree(self.tmpdir)

  def testIgnoreDirs(self):
    """Test that RemoveFile() ignores directories"""
    self.assertEqual(None, RemoveFile(self.tmpdir))

  def testIgnoreNotExisting(self):
    """Test that RemoveFile() ignores non-existing files"""
    RemoveFile(self.tmpfile)
    RemoveFile(self.tmpfile)

  def testRemoveFile(self):
    """Test that RemoveFile does remove a file"""
    RemoveFile(self.tmpfile)
    if os.path.exists(self.tmpfile):
      self.fail("File '%s' not removed" % self.tmpfile)

  def testRemoveSymlink(self):
    """Test that RemoveFile does remove symlinks"""
    symlink = self.tmpdir + "/symlink"
    # Dangling symlink: os.path.exists() follows the link and would report
    # False even if the link itself were still there, so lexists() is needed
    os.symlink("no-such-file", symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
    # Symlink pointing to an existing file
    os.symlink(self.tmpfile, symlink)
    RemoveFile(symlink)
    if os.path.lexists(symlink):
      self.fail("File '%s' not removed" % symlink)
279
280
class TestRename(unittest.TestCase):
  """Test case for RenameFile"""

  def setUp(self):
    """Set up a scratch directory containing one empty file"""
    self.tmpdir = tempfile.mkdtemp()
    self.tmpfile = os.path.join(self.tmpdir, "test1")

    # Create the (empty) source file
    handle = open(self.tmpfile, "w")
    handle.close()

  def tearDown(self):
    """Delete the scratch directory and its contents"""
    shutil.rmtree(self.tmpdir)

  def _Path(self, relname):
    """Return the path of relname inside the scratch directory"""
    return os.path.join(self.tmpdir, relname)

  def testSimpleRename1(self):
    """Simple rename 1"""
    target = self._Path("xyz")
    utils.RenameFile(self.tmpfile, target)
    self.assert_(os.path.isfile(target))

  def testSimpleRename2(self):
    """Simple rename 2"""
    target = self._Path("xyz")
    utils.RenameFile(self.tmpfile, target, mkdir=True)
    self.assert_(os.path.isfile(target))

  def testRenameMkdir(self):
    """Rename with mkdir"""
    first = self._Path("test/xyz")
    second = self._Path("test/foo/bar/baz")

    # One missing directory level
    utils.RenameFile(self.tmpfile, first, mkdir=True)
    self.assert_(os.path.isdir(self._Path("test")))
    self.assert_(os.path.isfile(first))

    # Several missing directory levels
    utils.RenameFile(first, second, mkdir=True)
    self.assert_(os.path.isdir(self._Path("test")))
    self.assert_(os.path.isdir(self._Path("test/foo/bar")))
    self.assert_(os.path.isfile(second))
320
321
class TestMatchNameComponent(unittest.TestCase):
  """Test case for the MatchNameComponent function"""

  def testEmptyList(self):
    """Test that there is no match against an empty list"""
    for key in ("", "test"):
      self.failUnlessEqual(MatchNameComponent(key, []), None)

  def testSingleMatch(self):
    """Test that a single match is performed correctly"""
    names = ["test1.example.com", "test2.example.com", "test3.example.com"]
    wanted = names[1]
    for key in ("test2", "test2.example", "test2.example.com"):
      self.failUnlessEqual(MatchNameComponent(key, names), wanted)

  def testMultipleMatches(self):
    """Test that a multiple match is returned as None"""
    names = ["test1.example.com", "test1.example.org", "test1.example.net"]
    for key in ("test1", "test1.example"):
      self.failUnlessEqual(MatchNameComponent(key, names), None)

  def testFullMatch(self):
    """Test that a full match is returned correctly"""
    short = "test1"
    full = "test1.example"
    names = [full, full + ".com"]
    self.failUnlessEqual(MatchNameComponent(short, names), None)
    self.failUnlessEqual(MatchNameComponent(full, names), full)

  def testCaseInsensitivePartialMatch(self):
    """Test for the case_insensitive keyword"""
    names = ["test1.example.com", "test2.example.net"]
    for key in ("test2", "Test2", "teSt2", "TeSt2"):
      found = MatchNameComponent(key, names, case_sensitive=False)
      self.assertEqual(found, "test2.example.net")

  def testCaseInsensitiveFullMatch(self):
    """Test full-name matching with case_sensitive=False"""
    names = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
    cases = [
      # The two ts1 entries differ in more than case: a full-name match
      # is found regardless of the key's case, a bare prefix is ambiguous
      ("Ts1", None),
      ("Ts1.ex", "ts1.ex"),
      ("ts1.ex", "ts1.ex"),
      # The two ts2 entries differ in case only, so only a key with
      # exactly matching case is resolved
      ("ts2.ex", "ts2.ex"),
      ("Ts2.ex", "Ts2.ex"),
      ("TS2.ex", None),
      ]
    for key, wanted in cases:
      found = MatchNameComponent(key, names, case_sensitive=False)
      self.assertEqual(found, wanted)
380
381
class TestFormatUnit(unittest.TestCase):
  """Test case for the FormatUnit function"""

  def _Check(self, cases):
    """Verify FormatUnit for a list of (value, units, expected) tuples"""
    for value, units, wanted in cases:
      self.assertEqual(FormatUnit(value, units), wanted)

  def testMiB(self):
    """Values shown in mebibytes"""
    self._Check([
      # "h" mode, below one GiB
      (1, 'h', '1M'),
      (100, 'h', '100M'),
      (1023, 'h', '1023M'),
      # Explicit "m" never gets a suffix...
      (1, 'm', '1'),
      (100, 'm', '100'),
      (1023, 'm', '1023'),
      # ... and does not switch to a larger unit either
      (1024, 'm', '1024'),
      (1536, 'm', '1536'),
      (17133, 'm', '17133'),
      (1024 * 1024 - 1, 'm', '1048575'),
      ])

  def testGiB(self):
    """Values shown in gibibytes"""
    self._Check([
      (1024, 'h', '1.0G'),
      (1536, 'h', '1.5G'),
      (17133, 'h', '16.7G'),
      (1024 * 1024 - 1, 'h', '1024.0G'),
      (1024, 'g', '1.0'),
      (1536, 'g', '1.5'),
      (17133, 'g', '16.7'),
      (1024 * 1024 - 1, 'g', '1024.0'),
      (1024 * 1024, 'g', '1024.0'),
      (5120 * 1024, 'g', '5120.0'),
      (29829 * 1024, 'g', '29829.0'),
      ])

  def testTiB(self):
    """Values shown in tebibytes"""
    self._Check([
      (1024 * 1024, 'h', '1.0T'),
      (5120 * 1024, 'h', '5.0T'),
      (29829 * 1024, 'h', '29.1T'),
      (1024 * 1024, 't', '1.0'),
      (5120 * 1024, 't', '5.0'),
      (29829 * 1024, 't', '29.1'),
      ])
422
class TestParseUnit(unittest.TestCase):
  """Test case for the ParseUnit function"""

  # (suffix, multiplier relative to one mebibyte)
  SCALES = (('', 1),
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))

  def _Check(self, cases):
    """Verify ParseUnit for a list of (text, expected) pairs"""
    for text, wanted in cases:
      self.assertEqual(ParseUnit(text), wanted)

  def testRounding(self):
    """Integer input is rounded up to a multiple of four"""
    self._Check([
      ('0', 0),
      ('1', 4),
      ('2', 4),
      ('3', 4),
      ('124', 124),
      ('125', 128),
      ('126', 128),
      ('127', 128),
      ('128', 128),
      ('129', 132),
      ('130', 132),
      ])

  def testFloating(self):
    """Fractional input, with and without unit suffix"""
    self._Check([
      ('0', 0),
      ('0.5', 4),
      ('1.75', 4),
      ('1.99', 4),
      ('2.00', 4),
      ('2.01', 4),
      ('3.99', 4),
      ('4.00', 4),
      ('4.01', 8),
      ('1.5G', 1536),
      ('1.8G', 1844),
      ('8.28T', 8682212),
      ])

  def testSuffixes(self):
    """All suffixes work in any case and after optional whitespace"""
    for sep in ('', ' ', '   ', "\t", "\t "):
      for suffix, scale in TestParseUnit.SCALES:
        # Suffix as given, all lower-case and all upper-case
        for variant in (suffix, suffix.lower(), suffix.upper()):
          text = '1024' + sep + variant
          self.assertEqual(ParseUnit(text), 1024 * scale)

  def testInvalidInput(self):
    """Malformed values raise UnitParseError"""
    suffixes = [suffix for (suffix, _) in TestParseUnit.SCALES]

    for sep in ('-', '_', ',', 'a'):
      for suffix in suffixes:
        self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)

    for suffix in suffixes:
      self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
473
474
class TestSshKeys(testutils.GanetiTestCase):
  """Test case for the AddAuthorizedKey function"""

  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')

  def setUp(self):
    """Create a temporary authorized_keys file holding KEY_A and KEY_B"""
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    handle = open(self.tmpname, 'w')
    try:
      for key in (TestSshKeys.KEY_A, TestSshKeys.KEY_B):
        handle.write("%s\n" % key)
    finally:
      handle.close()

  def _AssertKeys(self, keys):
    """Check that the file consists of exactly the given key lines"""
    content = "".join(["%s\n" % key for key in keys])
    self.assertFileContent(self.tmpname, content)

  def testAddingNewKey(self):
    """A key not yet in the file is appended"""
    new_key = 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test'
    AddAuthorizedKey(self.tmpname, new_key)

    self._AssertKeys([TestSshKeys.KEY_A, TestSshKeys.KEY_B,
                      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test"])

  def testAddingAlmostButNotCompletelyTheSameKey(self):
    """Same key data as KEY_A but another comment counts as a new key"""
    new_key = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test'
    AddAuthorizedKey(self.tmpname, new_key)

    self._AssertKeys([TestSshKeys.KEY_A, TestSshKeys.KEY_B,
                      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test"])

  def testAddingExistingKeyWithSomeMoreSpaces(self):
    """KEY_A written with extra whitespace is not added a second time"""
    spaced = 'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a'
    AddAuthorizedKey(self.tmpname, spaced)

    self._AssertKeys([TestSshKeys.KEY_A, TestSshKeys.KEY_B])

  def testRemovingExistingKeyWithSomeMoreSpaces(self):
    """KEY_A written with extra whitespace is still found and removed"""
    spaced = 'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a'
    RemoveAuthorizedKey(self.tmpname, spaced)

    self._AssertKeys([TestSshKeys.KEY_B])

  def testRemovingNonExistingKey(self):
    """Removing a key that is not in the file leaves it unchanged"""
    unknown = 'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test'
    RemoveAuthorizedKey(self.tmpname, unknown)

    self._AssertKeys([TestSshKeys.KEY_A, TestSshKeys.KEY_B])
536
537
538 class TestEtcHosts(testutils.GanetiTestCase):
539   """Test functions modifying /etc/hosts"""
540
541   def setUp(self):
542     testutils.GanetiTestCase.setUp(self)
543     self.tmpname = self._CreateTempFile()
544     handle = open(self.tmpname, 'w')
545     try:
546       handle.write('# This is a test file for /etc/hosts\n')
547       handle.write('127.0.0.1\tlocalhost\n')
548       handle.write('192.168.1.1 router gw\n')
549     finally:
550       handle.close()
551
552   def testSettingNewIp(self):
553     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
554
555     self.assertFileContent(self.tmpname,
556       "# This is a test file for /etc/hosts\n"
557       "127.0.0.1\tlocalhost\n"
558       "192.168.1.1 router gw\n"
559       "1.2.3.4\tmyhost.domain.tld myhost\n")
560     self.assertFileMode(self.tmpname, 0644)
561
562   def testSettingExistingIp(self):
563     SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
564                      ['myhost'])
565
566     self.assertFileContent(self.tmpname,
567       "# This is a test file for /etc/hosts\n"
568       "127.0.0.1\tlocalhost\n"
569       "192.168.1.1\tmyhost.domain.tld myhost\n")
570     self.assertFileMode(self.tmpname, 0644)
571
572   def testSettingDuplicateName(self):
573     SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
574
575     self.assertFileContent(self.tmpname,
576       "# This is a test file for /etc/hosts\n"
577       "127.0.0.1\tlocalhost\n"
578       "192.168.1.1 router gw\n"
579       "1.2.3.4\tmyhost\n")
580     self.assertFileMode(self.tmpname, 0644)
581
582   def testRemovingExistingHost(self):
583     RemoveEtcHostsEntry(self.tmpname, 'router')
584
585     self.assertFileContent(self.tmpname,
586       "# This is a test file for /etc/hosts\n"
587       "127.0.0.1\tlocalhost\n"
588       "192.168.1.1 gw\n")
589     self.assertFileMode(self.tmpname, 0644)
590
591   def testRemovingSingleExistingHost(self):
592     RemoveEtcHostsEntry(self.tmpname, 'localhost')
593
594     self.assertFileContent(self.tmpname,
595       "# This is a test file for /etc/hosts\n"
596       "192.168.1.1 router gw\n")
597     self.assertFileMode(self.tmpname, 0644)
598
599   def testRemovingNonExistingHost(self):
600     RemoveEtcHostsEntry(self.tmpname, 'myhost')
601
602     self.assertFileContent(self.tmpname,
603       "# This is a test file for /etc/hosts\n"
604       "127.0.0.1\tlocalhost\n"
605       "192.168.1.1 router gw\n")
606     self.assertFileMode(self.tmpname, 0644)
607
608   def testRemovingAlias(self):
609     RemoveEtcHostsEntry(self.tmpname, 'gw')
610
611     self.assertFileContent(self.tmpname,
612       "# This is a test file for /etc/hosts\n"
613       "127.0.0.1\tlocalhost\n"
614       "192.168.1.1 router\n")
615     self.assertFileMode(self.tmpname, 0644)
616
617
class TestShellQuoting(unittest.TestCase):
  """Test case for shell quoting functions"""

  def testShellQuote(self):
    """ShellQuote on single values"""
    cases = [
      ('abc', "abc"),
      ('ab"c', "'ab\"c'"),
      ("a'bc", "'a'\\''bc'"),
      ("a b c", "'a b c'"),
      ("a b\\ c", "'a b\\ c'"),
      ]
    for value, wanted in cases:
      self.assertEqual(ShellQuote(value), wanted)

  def testShellQuoteArgs(self):
    """ShellQuoteArgs on argument lists"""
    cases = [
      (['a', 'b', 'c'], "a b c"),
      (['a', 'b"', 'c'], "a 'b\"' c"),
      (['a', 'b\'', 'c'], "a 'b'\\\''' c"),
      ]
    for args, wanted in cases:
      self.assertEqual(ShellQuoteArgs(args), wanted)
632
633
class TestTcpPing(unittest.TestCase):
  """Testcase for TCP version of ping - against listen(2)ing port"""

  def setUp(self):
    """Open a listening socket on an unused port of the loopback address"""
    self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Port 0 asks the kernel for any free port; it is read back below
    self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
    self.listenerport = self.listener.getsockname()[1]
    self.listener.listen(1)

  def tearDown(self):
    """Shut down and close the listening socket"""
    try:
      self.listener.shutdown(socket.SHUT_RDWR)
    finally:
      # Close explicitly instead of relying on garbage collection to
      # release the file descriptor
      self.listener.close()
    del self.listener
    del self.listenerport

  def testTcpPingToLocalHostAccept(self):
    """TcpPing must succeed against the listening port"""
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.listenerport,
                         timeout=10,
                         live_port_needed=True,
                         source=constants.LOCALHOST_IP_ADDRESS,
                         ),
                 "failed to connect to test listener")

    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.listenerport,
                         timeout=10,
                         live_port_needed=True,
                         ),
                 "failed to connect to test listener (no source)")
663
664
class TestTcpPingDeaf(unittest.TestCase):
  """Testcase for TCP version of ping - against non listen(2)ing port"""

  def setUp(self):
    """Bind a socket to an unused loopback port without listening on it"""
    self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Port 0 asks the kernel for any free port; it is read back below
    self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
    self.deaflistenerport = self.deaflistener.getsockname()[1]

  def tearDown(self):
    """Close the bound socket"""
    # Close explicitly instead of relying on garbage collection to release
    # the file descriptor and the port
    self.deaflistener.close()
    del self.deaflistener
    del self.deaflistenerport

  def testTcpPingToLocalHostAcceptDeaf(self):
    """With live_port_needed=True a port nobody listens on must fail"""
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                        self.deaflistenerport,
                        timeout=constants.TCP_PING_TIMEOUT,
                        live_port_needed=True,
                        source=constants.LOCALHOST_IP_ADDRESS,
                        ), # need successful connect(2)
                "successfully connected to deaf listener")

    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                        self.deaflistenerport,
                        timeout=constants.TCP_PING_TIMEOUT,
                        live_port_needed=True,
                        ), # need successful connect(2)
                "successfully connected to deaf listener (no source addr)")

  def testTcpPingToLocalHostNoAccept(self):
    """With live_port_needed=False a refused connection still counts"""
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         source=constants.LOCALHOST_IP_ADDRESS,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port")

    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
                         self.deaflistenerport,
                         timeout=constants.TCP_PING_TIMEOUT,
                         live_port_needed=False,
                         ), # ECONNREFUSED is OK
                 "failed to ping alive host on deaf port (no source addr)")
708
709
class TestOwnIpAddress(unittest.TestCase):
  """Testcase for OwnIpAddress"""

  def testOwnLoopback(self):
    """The loopback address must be reported as owned"""
    owned = OwnIpAddress(constants.LOCALHOST_IP_ADDRESS)
    self.failUnless(owned, "Should own the loopback address")

  def testNowOwnAddress(self):
    """An address from the documentation network must not be owned"""
    # RFC 3330 reserves 192.0.2.0/24 for tests and documentation, so this
    # machine *should* not have an address from that range. Should this
    # ever fail, the test ought to be extended to try several addresses.
    foreign_ip = "192.0.2.1"
    owned = OwnIpAddress(foreign_ip)
    self.failIf(owned, "Should not own IP address %s" % foreign_ip)
726
727
class TestListVisibleFiles(unittest.TestCase):
  """Test case for ListVisibleFiles"""

  def setUp(self):
    """Create an empty scratch directory"""
    self.path = tempfile.mkdtemp()

  def tearDown(self):
    """Remove the scratch directory and its contents"""
    shutil.rmtree(self.path)

  def _test(self, files, expected):
    """Create the given files and compare the listing with expected

    The comparison ignores ordering; the caller's lists are not modified.

    """
    for name in files:
      handle = open(os.path.join(self.path, name), 'w')
      try:
        handle.write("Test\n")
      finally:
        handle.close()

    listed = ListVisibleFiles(self.path)
    listed.sort()

    # Compare against a sorted copy of the expectation
    wanted = list(expected)
    wanted.sort()

    self.assertEqual(listed, wanted)

  def testAllVisible(self):
    """No dot files: everything is listed"""
    names = ["a", "b", "c"]
    self._test(names, names)

  def testNoneVisible(self):
    """Only dot files: nothing is listed"""
    self._test([".a", ".b", ".c"], [])

  def testSomeVisible(self):
    """Mixed: only names without leading dot are listed"""
    self._test(["a", "b", ".c"], ["a", "b"])
768
769
class TestNewUUID(unittest.TestCase):
  """Test case for NewUUID"""

  # Five dash-separated groups of lower-case hex digits (8-4-4-4-12)
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
                        '[a-f0-9]{4}-[a-f0-9]{12}$')

  def runTest(self):
    """A generated UUID must match the expected textual format"""
    generated = utils.NewUUID()
    matched = self._re_uuid.match(generated)
    self.failUnless(matched)
778
779
class TestUniqueSequence(unittest.TestCase):
  """Test case for UniqueSequence"""

  def _test(self, seq, expected):
    """Check the result of UniqueSequence for one input sequence"""
    result = utils.UniqueSequence(seq)
    self.assertEqual(result, expected)

  def runTest(self):
    """Run UniqueSequence over a table of inputs and expected results"""
    cases = [
      # Ordered input
      ([1, 2, 3], [1, 2, 3]),
      ([1, 1, 2, 2, 3, 3], [1, 2, 3]),
      ([1, 2, 2, 3], [1, 2, 3]),
      ([1, 2, 3, 3], [1, 2, 3]),
      # Unordered input
      ([1, 2, 3, 1, 2, 3], [1, 2, 3]),
      ([1, 1, 2, 3, 3, 1, 2], [1, 2, 3]),
      # Strings
      (["a", "a"], ["a"]),
      (["a", "b"], ["a", "b"]),
      (["a", "b", "a"], ["a", "b"]),
      ]
    for seq, expected in cases:
      self._test(seq, expected)
801
802
class TestFirstFree(unittest.TestCase):
  """Test case for the FirstFree function"""

  def test(self):
    """Test FirstFree"""
    # Without an explicit base
    for seq, wanted in [([0, 1, 3], 2), ([], None), ([3, 4, 6], 0)]:
      self.failUnlessEqual(FirstFree(seq), wanted)

    # With base=3
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
    # An element below the base is expected to trip an assertion
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
813
814
class TestTailFile(testutils.GanetiTestCase):
  """Test case for the TailFile function"""

  def _WriteTempFile(self, chunks):
    """Create a temporary file, write the chunks in order, return its name"""
    fname = self._CreateTempFile()
    handle = open(fname, "w")
    for chunk in chunks:
      handle.write(chunk)
    handle.close()
    return fname

  def testEmpty(self):
    """An empty file yields no lines"""
    fname = self._CreateTempFile()
    for result in (TailFile(fname), TailFile(fname, lines=25)):
      self.failUnlessEqual(result, [])

  def testAllLines(self):
    """Requesting exactly as many lines as the file has returns them all"""
    data = ["test %d" % i for i in range(30)]
    for count in range(30):
      chunks = ["\n".join(data[:count])]
      # A file with at least one line ends with a newline
      if count > 0:
        chunks.append("\n")
      fname = self._WriteTempFile(chunks)
      self.failUnlessEqual(TailFile(fname, lines=count), data[:count])

  def testPartialLines(self):
    """Requesting fewer lines than the file has returns the last ones"""
    data = ["test %d" % i for i in range(30)]
    fname = self._WriteTempFile(["\n".join(data), "\n"])
    for count in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=count), data[-count:])

  def testBigFile(self):
    """A 1 MiB first line does not disturb reading the last lines"""
    data = ["test %d" % i for i in range(30)]
    fname = self._WriteTempFile(["X" * 1048576, "\n",
                                 "\n".join(data), "\n"])
    for count in range(1, 30):
      self.failUnlessEqual(TailFile(fname, lines=count), data[-count:])
855
856
class TestFileLock(unittest.TestCase):
  """Test case for the FileLock class

  Every test works on a fresh lock which setUp creates using the name of a
  named temporary file; tearDown closes that temporary file again.

  """

  def setUp(self):
    """Create a temporary file and a FileLock using its name"""
    self.tmpfile = tempfile.NamedTemporaryFile()
    self.lock = utils.FileLock(self.tmpfile.name)

  def tearDown(self):
    """Close the temporary file instead of leaving it to the collector"""
    # Only the temporary file is handled here; the tests close the lock
    # themselves.
    self.tmpfile.close()

  def testSharedNonblocking(self):
    """Shared lock, non-blocking, followed by Close"""
    self.lock.Shared(blocking=False)
    self.lock.Close()

  def testExclusiveNonblocking(self):
    """Exclusive lock, non-blocking, followed by Close"""
    self.lock.Exclusive(blocking=False)
    self.lock.Close()

  def testUnlockNonblocking(self):
    """Unlock without a prior lock call, non-blocking, followed by Close"""
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testSharedBlocking(self):
    """Shared lock, blocking, followed by Close"""
    self.lock.Shared(blocking=True)
    self.lock.Close()

  def testExclusiveBlocking(self):
    """Exclusive lock, blocking, followed by Close"""
    self.lock.Exclusive(blocking=True)
    self.lock.Close()

  def testUnlockBlocking(self):
    """Unlock without a prior lock call, blocking, followed by Close"""
    self.lock.Unlock(blocking=True)
    self.lock.Close()

  def testSharedExclusiveUnlock(self):
    """Shared, then Exclusive, then Unlock on the same lock object"""
    self.lock.Shared(blocking=False)
    self.lock.Exclusive(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testExclusiveSharedUnlock(self):
    """Exclusive, then Shared, then Unlock on the same lock object"""
    self.lock.Exclusive(blocking=False)
    self.lock.Shared(blocking=False)
    self.lock.Unlock(blocking=False)
    self.lock.Close()

  def testCloseShared(self):
    """Shared on a closed lock must raise AssertionError"""
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)

  def testCloseExclusive(self):
    """Exclusive on a closed lock must raise AssertionError"""
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)

  def testCloseUnlock(self):
    """Unlock on a closed lock must raise AssertionError"""
    self.lock.Close()
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
911
912
class TestTimeFunctions(unittest.TestCase):
  """Checks utils.SplitTime and utils.MergeTime"""

  def runTest(self):
    """Run the SplitTime checks, then the MergeTime checks"""
    # (timestamp, expected (whole part, millionths) pair)
    split_cases = [
      (1, (1, 0)),
      (1.5, (1, 500000)),
      (1218448917.4809151, (1218448917, 480915)),
      (123.48012, (123, 480120)),
      (123.9996, (123, 999600)),
      (123.9995, (123, 999500)),
      (123.9994, (123, 999400)),
      (123.999999999, (123, 999999)),
      ]
    for (value, expected) in split_cases:
      self.assertEqual(utils.SplitTime(value), expected)

    self.assertRaises(AssertionError, utils.SplitTime, -1)

    # pairs whose merged value is compared without rounding
    merge_cases = [
      ((1, 0), 1.0),
      ((1, 500000), 1.5),
      ((1218448917, 500000), 1218448917.5),
      ]
    for (pair, expected) in merge_cases:
      self.assertEqual(utils.MergeTime(pair), expected)

    # pairs whose merged value is rounded to three decimals first
    rounded_cases = [
      ((1218448917, 481000), 1218448917.481),
      ((1, 801000), 1.801),
      ]
    for (pair, expected) in rounded_cases:
      self.assertEqual(round(utils.MergeTime(pair), 3), expected)

    # pairs that MergeTime is expected to reject with an assertion
    rejected = [(0, -1), (0, 1000000), (0, 9999999), (-1, 0), (-9999, 0)]
    for pair in rejected:
      self.assertRaises(AssertionError, utils.MergeTime, pair)
941
942
class FieldSetTestCase(unittest.TestCase):
  """Checks Matches and NonMatching of utils.FieldSet"""

  def testSimpleMatch(self):
    """Plain field names must match as a whole, not as substring/prefix"""
    fields = utils.FieldSet("a", "b", "c", "def")
    self.assert_(fields.Matches("a"))
    self.failIf(fields.Matches("d"), "Substring matched")
    self.failIf(fields.Matches("defghi"), "Prefix string matched")
    # lists made only of known names leave nothing unmatched
    for known in (["b", "c"], ["a", "b", "c", "def"]):
      self.failIf(fields.NonMatching(known))
    self.assert_(fields.NonMatching(["a", "d"]))

  def testRegexMatch(self):
    """A field given as "b([0-9]+)" matches "b" followed by digits only"""
    fields = utils.FieldSet("a", "b([0-9]+)", "c")
    for name in ("b1", "b99"):
      self.assert_(fields.Matches(name))
    self.failIf(fields.Matches("b/1"))
    self.failIf(fields.NonMatching(["b12", "c"]))
    self.assert_(fields.NonMatching(["a", "1"]))
962
class TestForceDictType(unittest.TestCase):
  """Checks the in-place conversions done by ForceDictType"""

  def setUp(self):
    # one key per value type used by the tests below
    self.key_types = {
      'a': constants.VTYPE_INT,
      'b': constants.VTYPE_BOOL,
      'c': constants.VTYPE_STRING,
      'd': constants.VTYPE_SIZE,
      }

  def _fdt(self, data, allowed_values=None):
    """Run ForceDictType on data and hand the very same dict back

    allowed_values is only forwarded when the caller supplied it.

    """
    kwargs = {}
    if allowed_values is not None:
      kwargs["allowed_values"] = allowed_values
    ForceDictType(data, self.key_types, **kwargs)

    return data

  def testSimpleDict(self):
    """Values that are expected to be accepted, possibly after conversion"""
    # (input dict, expected dict after ForceDictType)
    cases = [
      ({}, {}),
      ({'a': 1}, {'a': 1}),
      ({'a': '1'}, {'a': 1}),
      ({'a': 1, 'b': 1}, {'a': 1, 'b': True}),
      ({'b': 1, 'c': 'foo'}, {'b': True, 'c': 'foo'}),
      ({'b': 1, 'c': False}, {'b': True, 'c': ''}),
      ({'b': 'false'}, {'b': False}),
      ({'b': 'False'}, {'b': False}),
      ({'b': 'true'}, {'b': True}),
      ({'b': 'True'}, {'b': True}),
      ({'d': '4'}, {'d': 4}),
      ({'d': '4M'}, {'d': 4}),
      ]
    for (given, expected) in cases:
      self.assertEqual(self._fdt(given), expected)

  def testErrors(self):
    """Values that are expected to raise TypeEnforcementError"""
    rejected = [
      {'a': 'astring'},
      {'c': True},
      {'d': 'astring'},
      {'d': '4 L'},
      ]
    for given in rejected:
      self.assertRaises(errors.TypeEnforcementError, self._fdt, given)
1000     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1001
1002
class TestIsAbsNormPath(unittest.TestCase):
  """Testing case for IsNormAbsPath"""

  def _pathTestHelper(self, path, result):
    """Assert that IsNormAbsPath(path) is true or false, as per result"""
    verdict = IsNormAbsPath(path)
    if not result:
      self.assert_(not verdict,
          "Path %s should not result absolute and normalized" % path)
    else:
      self.assert_(verdict,
          "Path %s should result absolute and normalized" % path)

  def testBase(self):
    """Check a few accepted and rejected paths"""
    # (path, whether it is expected to be accepted)
    samples = [
      ('/etc', True),
      ('/srv', True),
      ('etc', False),
      ('/etc/../root', False),
      ('/etc/', False),
      ]
    for (path, accepted) in samples:
      self._pathTestHelper(path, accepted)
1020
1021
class TestSafeEncode(unittest.TestCase):
  """Test case for SafeEncode"""

  def testAscii(self):
    """Digits, letters and punctuation must come back unchanged"""
    for txt in [string.digits, string.letters, string.punctuation]:
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testDoubleEncode(self):
    """Encoding an already encoded single-byte string must change nothing"""
    # chr() accepts 0..255, so the range has to stop at 256 for 0xff to be
    # covered as well
    for i in range(256):
      txt = SafeEncode(chr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))

  def testUnicode(self):
    """Same double-encoding check, starting from unicode characters"""
    # 1024 is high enough to catch non-direct ASCII mappings
    for i in range(1024):
      txt = SafeEncode(unichr(i))
      self.failUnlessEqual(txt, SafeEncode(txt))
1039
1040
class TestFormatTime(unittest.TestCase):
  """Checks FormatTime with missing, invalid and current timestamps"""

  def testNone(self):
    """None is rendered as "N/A\""""
    self.assertEqual(FormatTime(None), "N/A")

  def testInvalid(self):
    """An empty tuple is rendered as "N/A\""""
    self.assertEqual(FormatTime(()), "N/A")

  def testNow(self):
    """Both float and int timestamps have to be accepted"""
    # the results are not inspected, the calls just must not raise
    for stamp in (time.time(), int(time.time())):
      FormatTime(stamp)
1055
1056
class TestFingerprintFile(unittest.TestCase):
  """Test case for utils._FingerprintFile

  The file being fingerprinted is a named temporary file created in setUp
  and closed again in tearDown.

  """

  def setUp(self):
    """Create the (initially empty) temporary file"""
    self.tmpfile = tempfile.NamedTemporaryFile()

  def tearDown(self):
    """Close the temporary file instead of leaving it to the collector"""
    self.tmpfile.close()

  def test(self):
    """Compare the fingerprint of an empty and of a small file"""
    # nothing has been written yet, so this is the empty-file fingerprint
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")

    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1067                      "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1068
1069
class TestUnescapeAndSplit(unittest.TestCase):
  """Checks UnescapeAndSplit with plain and backslash-escaped separators"""

  def setUp(self):
    # more than one separator is tried, for regexp safety
    self._seps = [",", "+", "."]

  def _check(self, sep, parts, expected):
    """Join parts with sep, split the result again, compare with expected"""
    joined = sep.join(parts)
    self.assertEqual(UnescapeAndSplit(joined, sep=sep), expected)

  def testSimple(self):
    """Without any backslash the original parts come back"""
    parts = ["a", "b", "c", "d"]
    for sep in self._seps:
      self._check(sep, parts, parts)

  def testEscape(self):
    """One backslash before the separator keeps it inside the element"""
    for sep in self._seps:
      escaped = ["a", "b\\" + sep + "c", "d"]
      wanted = ["a", "b" + sep + "c", "d"]
      self._check(sep, escaped, wanted)

  def testDoubleEscape(self):
    """Two backslashes become one and the separator still splits"""
    escaped = ["a", "b\\\\", "c", "d"]
    wanted = ["a", "b\\", "c", "d"]
    for sep in self._seps:
      self._check(sep, escaped, wanted)

  def testThreeEscape(self):
    """Three backslashes yield one backslash plus a literal separator"""
    for sep in self._seps:
      escaped = ["a", "b\\\\\\" + sep + "c", "d"]
      wanted = ["a", "b\\" + sep + "c", "d"]
      self._check(sep, escaped, wanted)
1099
1100
class TestGenerateSelfSignedSslCert(unittest.TestCase):
  """Checks the file written by utils.GenerateSelfSignedSslCert"""

  def setUp(self):
    # scratch directory, removed again in tearDown
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _checkPrivateRsaKey(self, key):
    """Assert that both RSA private key marker lines occur in key"""
    key_lines = key.splitlines()
    for marker in ("-----BEGIN RSA PRIVATE KEY-----",
                   "-----END RSA PRIVATE KEY-----"):
      self.assert_(marker in key_lines)

  def _checkRsaCertificate(self, cert):
    """Assert that both certificate marker lines occur in cert"""
    cert_lines = cert.splitlines()
    for marker in ("-----BEGIN CERTIFICATE-----",
                   "-----END CERTIFICATE-----"):
      self.assert_(marker in cert_lines)

  def testSingleFile(self):
    """The single generated file has to hold both key and certificate"""
    path = os.path.join(self.tmpdir, "cert1.pem")

    utils.GenerateSelfSignedSslCert(path, validity=1)

    contents = utils.ReadFile(path)

    for check in (self._checkPrivateRsaKey, self._checkRsaCertificate):
      check(contents)
1127
1128
# Hand over to testutils.GanetiTestProgram when executed as a script
if __name__ == "__main__":
  testutils.GanetiTestProgram()