Locking-related fixes for networks
[ganeti-local] / test / ganeti.utils.io_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2006, 2007, 2010, 2011 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.utils.io"""
23
24 import os
25 import tempfile
26 import unittest
27 import shutil
28 import glob
29 import time
30 import signal
31 import stat
32 import errno
33
34 from ganeti import constants
35 from ganeti import utils
36 from ganeti import compat
37 from ganeti import errors
38
39 import testutils
40
41
class TestReadFile(testutils.GanetiTestCase):
  """Tests for utils.ReadFile."""

  def _CheckContent(self, content, expected_len, expected_digest):
    """Asserts the length and the MD5 hex digest of C{content}."""
    self.assertEqual(len(content), expected_len)
    digest = compat.md5_hash()
    digest.update(content)
    self.assertEqual(digest.hexdigest(), expected_digest)

  def testReadAll(self):
    # Without a size limit the whole test certificate is returned
    content = utils.ReadFile(self._TestDataFilename("cert1.pem"))
    self._CheckContent(content, 814, "a491efb3efe56a0535f924d5f8680fd4")

  def testReadSize(self):
    # With size=100 only the first 100 bytes are returned
    content = utils.ReadFile(self._TestDataFilename("cert1.pem"), size=100)
    self._CheckContent(content, 100, "893772354e4e690b9efd073eed433ce7")

  def testCallback(self):
    def _CheckAtStart(handle):
      # The preread callback must be invoked before anything is read
      self.assertEqual(handle.tell(), 0)

    content = utils.ReadFile(self._TestDataFilename("cert1.pem"),
                             preread=_CheckAtStart)
    self.assertEqual(len(content), 814)

  def testError(self):
    # Nothing can exist below /dev/null, so reading must fail
    self.assertRaises(EnvironmentError, utils.ReadFile,
                      "/dev/null/does-not-exist")
69
70
class TestReadOneLineFile(testutils.GanetiTestCase):
  """Tests for utils.ReadOneLineFile."""

  _CERT_FIRST_LINE = "-----BEGIN CERTIFICATE-----"

  def _CheckCertFirstLine(self, **kwargs):
    """Reads the test certificate and checks the returned first line."""
    line = utils.ReadOneLineFile(self._TestDataFilename("cert1.pem"),
                                 **kwargs)
    self.assertEqual(len(line), 27)
    self.assertEqual(line, self._CERT_FIRST_LINE)

  def testDefault(self):
    self._CheckCertFirstLine()

  def testNotStrict(self):
    self._CheckCertFirstLine(strict=False)

  def testStrictFailure(self):
    # The certificate spans several lines, which strict mode rejects
    self.assertRaises(errors.GenericError, utils.ReadOneLineFile,
                      self._TestDataFilename("cert1.pem"), strict=True)

  def testLongLine(self):
    content = 1024 * "Hello World! "
    path = self._CreateTempFile()
    utils.WriteFile(path, data=content)
    for strict in (True, False):
      self.assertEqual(content, utils.ReadOneLineFile(path, strict=strict))

  def testNewline(self):
    path = self._CreateTempFile()
    text = "myline"
    for suffix in ("", "\n", "\r\n"):
      utils.WriteFile(path, data=text + suffix)
      # The trailing line terminator is stripped in both modes
      for strict in (False, True):
        self.assertEqual(text, utils.ReadOneLineFile(path, strict=strict))

  def testWhitespaceAndMultipleLines(self):
    path = self._CreateTempFile()
    for nl in ("", "\n", "\r\n"):
      for ws in (" ", "\t", "\t\t  \t", "\t "):
        content = 1024 * ("Foo bar baz %s%s" % (ws, nl))
        utils.WriteFile(path, data=content)
        lax_line = utils.ReadOneLineFile(path, strict=False)

        if not nl:
          # No line breaks at all: the whole content is a single line
          strict_line = utils.ReadOneLineFile(path, strict=True)
          self.assertEqual(content, strict_line)
          self.assertEqual(content, lax_line)
          continue

        # Several lines: strict mode fails, lax mode returns the first line
        # including its trailing whitespace but without the line break
        self.assert_(set("\r\n") & set(content))
        self.assertRaises(errors.GenericError, utils.ReadOneLineFile,
                          path, strict=True)
        expected_len = len("Foo bar baz ") + len(ws)
        self.assertEqual(len(lax_line), expected_len)
        self.assertEqual(lax_line, content[:expected_len])
        self.assertFalse(set("\r\n") & set(lax_line))

  def testEmptylines(self):
    path = self._CreateTempFile()
    text = "myline"
    for nl in ("\n", "\r\n"):
      for other in ("", "otherline"):
        content = "".join([nl, nl, text, nl, other, nl])
        utils.WriteFile(path, data=content)
        self.assert_(set("\r\n") & set(content))
        # Surrounding empty lines do not count as content
        self.assertEqual(text, utils.ReadOneLineFile(path, strict=False))
        if other:
          # A second non-empty line is an error only in strict mode
          self.assertRaises(errors.GenericError, utils.ReadOneLineFile,
                            path, strict=True)
        else:
          self.assertEqual(text, utils.ReadOneLineFile(path, strict=True))

  def testEmptyfile(self):
    path = self._CreateTempFile()
    self.assertRaises(errors.GenericError, utils.ReadOneLineFile, path)
150
151
class TestTimestampForFilename(unittest.TestCase):
  """Tests for utils.TimestampForFilename."""

  def test(self):
    # Neither dots nor colons may appear in the generated timestamp
    for char in (".", ":"):
      self.assert_(char not in utils.TimestampForFilename())
156
157
class TestCreateBackup(testutils.GanetiTestCase):
  """Tests for utils.CreateBackup."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    testutils.GanetiTestCase.tearDown(self)
    shutil.rmtree(self.tmpdir)

  def _CountMatching(self, path):
    """Returns how many filesystem entries start with C{path}."""
    return len(glob.glob("%s*" % path))

  def testEmpty(self):
    path = utils.PathJoin(self.tmpdir, "config.data")
    utils.WriteFile(path, data="")
    backup = utils.CreateBackup(path)
    self.assertFileContent(backup, "")
    self.assertEqual(self._CountMatching(path), 2)

    # Every further call adds exactly one more backup file
    for expected in (3, 4):
      utils.CreateBackup(path)
      self.assertEqual(self._CountMatching(path), expected)

    # A FIFO must be rejected
    fifo = utils.PathJoin(self.tmpdir, "fifo")
    os.mkfifo(fifo)
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifo)

  def testContent(self):
    backups = 0
    for chunk in ("", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"):
      for times in (1, 2, 10, 127):
        content = chunk * times

        path = utils.PathJoin(self.tmpdir, "test.data_")
        utils.WriteFile(path, data=content)
        self.assertFileContent(path, content)

        for _ in range(3):
          backup = utils.CreateBackup(path)
          backups += 1
          self.assertFileContent(backup, content)
          # The original file plus all backups made so far
          self.assertEqual(self._CountMatching(path), 1 + backups)
199
200
class TestListVisibleFiles(unittest.TestCase):
  """Test case for ListVisibleFiles"""

  def setUp(self):
    self.path = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.path)

  def _CreateFiles(self, files):
    """Creates a small regular file for every name in C{files}."""
    for fname in files:
      utils.WriteFile(os.path.join(self.path, fname), data="test")

  def _test(self, files, expected):
    """Creates C{files} and compares the visible ones with C{expected}."""
    self._CreateFiles(files)
    self.assertEqual(set(utils.ListVisibleFiles(self.path)), set(expected))

  def testAllVisible(self):
    names = ["a", "b", "c"]
    self._test(names, names)

  def testNoneVisible(self):
    self._test([".a", ".b", ".c"], [])

  def testSomeVisible(self):
    self._test(["a", "b", ".c"], ["a", "b"])

  def testNonAbsolutePath(self):
    self.failUnlessRaises(errors.ProgrammerError, utils.ListVisibleFiles,
                          "abc")

  def testNonNormalizedPath(self):
    self.failUnlessRaises(errors.ProgrammerError, utils.ListVisibleFiles,
                          "/bin/../tmp")

  def testMountpoint(self):
    # Pretend that the test directory is a mount point
    list_fn = compat.partial(utils.ListVisibleFiles,
                             _is_mountpoint=lambda _: True)
    self.assertEqual(list_fn(self.path), [])

    # A regular file called "lost+found" is still listed ...
    self._CreateFiles(["foo", "bar", ".baz", "lost+found"])
    self.assertEqual(set(list_fn(self.path)),
                     set(["foo", "bar", "lost+found"]))

    # ... but a directory of that name is hidden
    lost_found = utils.PathJoin(self.path, "lost+found")
    utils.RemoveFile(lost_found)
    os.mkdir(lost_found)
    self.assertEqual(set(list_fn(self.path)), set(["foo", "bar"]))

  def testLostAndFoundNoMountpoint(self):
    self._test(["foo", "bar", ".Hello World", "lost+found"],
               ["foo", "bar", "lost+found"])
262
263
264 class TestWriteFile(testutils.GanetiTestCase):
265   def setUp(self):
266     testutils.GanetiTestCase.setUp(self)
267     self.tmpdir = None
268     self.tfile = tempfile.NamedTemporaryFile()
269     self.did_pre = False
270     self.did_post = False
271     self.did_write = False
272
273   def tearDown(self):
274     testutils.GanetiTestCase.tearDown(self)
275     if self.tmpdir:
276       shutil.rmtree(self.tmpdir)
277
278   def markPre(self, fd):
279     self.did_pre = True
280
281   def markPost(self, fd):
282     self.did_post = True
283
284   def markWrite(self, fd):
285     self.did_write = True
286
287   def testWrite(self):
288     data = "abc"
289     utils.WriteFile(self.tfile.name, data=data)
290     self.assertEqual(utils.ReadFile(self.tfile.name), data)
291
292   def testWriteSimpleUnicode(self):
293     data = u"abc"
294     utils.WriteFile(self.tfile.name, data=data)
295     self.assertEqual(utils.ReadFile(self.tfile.name), data)
296
297   def testErrors(self):
298     self.assertRaises(errors.ProgrammerError, utils.WriteFile,
299                       self.tfile.name, data="test", fn=lambda fd: None)
300     self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name)
301     self.assertRaises(errors.ProgrammerError, utils.WriteFile,
302                       self.tfile.name, data="test", atime=0)
303     self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name,
304                       mode=0400, keep_perms=utils.KP_ALWAYS)
305     self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name,
306                       uid=0, keep_perms=utils.KP_ALWAYS)
307     self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name,
308                       gid=0, keep_perms=utils.KP_ALWAYS)
309     self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name,
310                       mode=0400, uid=0, keep_perms=utils.KP_ALWAYS)
311
312   def testPreWrite(self):
313     utils.WriteFile(self.tfile.name, data="", prewrite=self.markPre)
314     self.assertTrue(self.did_pre)
315     self.assertFalse(self.did_post)
316     self.assertFalse(self.did_write)
317
318   def testPostWrite(self):
319     utils.WriteFile(self.tfile.name, data="", postwrite=self.markPost)
320     self.assertFalse(self.did_pre)
321     self.assertTrue(self.did_post)
322     self.assertFalse(self.did_write)
323
324   def testWriteFunction(self):
325     utils.WriteFile(self.tfile.name, fn=self.markWrite)
326     self.assertFalse(self.did_pre)
327     self.assertFalse(self.did_post)
328     self.assertTrue(self.did_write)
329
330   def testDryRun(self):
331     orig = "abc"
332     self.tfile.write(orig)
333     self.tfile.flush()
334     utils.WriteFile(self.tfile.name, data="hello", dry_run=True)
335     self.assertEqual(utils.ReadFile(self.tfile.name), orig)
336
337   def testTimes(self):
338     f = self.tfile.name
339     for at, mt in [(0, 0), (1000, 1000), (2000, 3000),
340                    (int(time.time()), 5000)]:
341       utils.WriteFile(f, data="hello", atime=at, mtime=mt)
342       st = os.stat(f)
343       self.assertEqual(st.st_atime, at)
344       self.assertEqual(st.st_mtime, mt)
345
346   def testNoClose(self):
347     data = "hello"
348     self.assertEqual(utils.WriteFile(self.tfile.name, data="abc"), None)
349     fd = utils.WriteFile(self.tfile.name, data=data, close=False)
350     try:
351       os.lseek(fd, 0, 0)
352       self.assertEqual(os.read(fd, 4096), data)
353     finally:
354       os.close(fd)
355
356   def testNoLeftovers(self):
357     self.tmpdir = tempfile.mkdtemp()
358     self.assertEqual(utils.WriteFile(utils.PathJoin(self.tmpdir, "test"),
359                                      data="abc"),
360                      None)
361     self.assertEqual(os.listdir(self.tmpdir), ["test"])
362
363   def testFailRename(self):
364     self.tmpdir = tempfile.mkdtemp()
365     target = utils.PathJoin(self.tmpdir, "target")
366     os.mkdir(target)
367     self.assertRaises(OSError, utils.WriteFile, target, data="abc")
368     self.assertTrue(os.path.isdir(target))
369     self.assertEqual(os.listdir(self.tmpdir), ["target"])
370     self.assertFalse(os.listdir(target))
371
372   def testFailRenameDryRun(self):
373     self.tmpdir = tempfile.mkdtemp()
374     target = utils.PathJoin(self.tmpdir, "target")
375     os.mkdir(target)
376     self.assertEqual(utils.WriteFile(target, data="abc", dry_run=True), None)
377     self.assertTrue(os.path.isdir(target))
378     self.assertEqual(os.listdir(self.tmpdir), ["target"])
379     self.assertFalse(os.listdir(target))
380
381   def testBackup(self):
382     self.tmpdir = tempfile.mkdtemp()
383     testfile = utils.PathJoin(self.tmpdir, "test")
384
385     self.assertEqual(utils.WriteFile(testfile, data="foo", backup=True), None)
386     self.assertEqual(utils.ReadFile(testfile), "foo")
387     self.assertEqual(os.listdir(self.tmpdir), ["test"])
388
389     # Write again
390     assert os.path.isfile(testfile)
391     self.assertEqual(utils.WriteFile(testfile, data="bar", backup=True), None)
392     self.assertEqual(utils.ReadFile(testfile), "bar")
393     self.assertEqual(len(glob.glob("%s.backup*" % testfile)), 1)
394     self.assertTrue("test" in os.listdir(self.tmpdir))
395     self.assertEqual(len(os.listdir(self.tmpdir)), 2)
396
397     # Write again as dry-run
398     assert os.path.isfile(testfile)
399     self.assertEqual(utils.WriteFile(testfile, data="000", backup=True,
400                                      dry_run=True),
401                      None)
402     self.assertEqual(utils.ReadFile(testfile), "bar")
403     self.assertEqual(len(glob.glob("%s.backup*" % testfile)), 1)
404     self.assertTrue("test" in os.listdir(self.tmpdir))
405     self.assertEqual(len(os.listdir(self.tmpdir)), 2)
406
407   def testFileMode(self):
408     self.tmpdir = tempfile.mkdtemp()
409     target = utils.PathJoin(self.tmpdir, "target")
410     self.assertRaises(OSError, utils.WriteFile, target, data="data",
411                       keep_perms=utils.KP_ALWAYS)
412     # All masks have only user bits set, to avoid interactions with umask
413     utils.WriteFile(target, data="data", mode=0200)
414     self.assertFileMode(target, 0200)
415     utils.WriteFile(target, data="data", mode=0400,
416                     keep_perms=utils.KP_IF_EXISTS)
417     self.assertFileMode(target, 0200)
418     utils.WriteFile(target, data="data", keep_perms=utils.KP_ALWAYS)
419     self.assertFileMode(target, 0200)
420     utils.WriteFile(target, data="data", mode=0700)
421     self.assertFileMode(target, 0700)
422
423   def testNewFileMode(self):
424     self.tmpdir = tempfile.mkdtemp()
425     target = utils.PathJoin(self.tmpdir, "target")
426     utils.WriteFile(target, data="data", mode=0400,
427                     keep_perms=utils.KP_IF_EXISTS)
428     self.assertFileMode(target, 0400)
429
class TestFileID(testutils.GanetiTestCase):
  """Tests for utils.GetFileID, utils.VerifyFileID and utils.SafeWriteFile."""

  def testEquality(self):
    name = self._CreateTempFile()
    oldi = utils.GetFileID(path=name)
    # A file ID must verify against itself
    self.failUnless(utils.VerifyFileID(oldi, oldi))

  def testUpdate(self):
    name = self._CreateTempFile()
    oldi = utils.GetFileID(path=name)
    fd = os.open(name, os.O_RDWR)
    try:
      # IDs taken by path and by descriptor of an unchanged file must
      # verify against each other in both directions
      newi = utils.GetFileID(fd=fd)
      self.failUnless(utils.VerifyFileID(oldi, newi))
      self.failUnless(utils.VerifyFileID(newi, oldi))
    finally:
      os.close(fd)

  def testWriteFile(self):
    name = self._CreateTempFile()
    oldi = utils.GetFileID(path=name)
    # Element 2 of the file ID is used as the modification time below
    mtime = oldi[2]
    # After bumping the mtime, SafeWriteFile must refuse to overwrite
    os.utime(name, (mtime + 10, mtime + 10))
    self.assertRaises(errors.LockError, utils.SafeWriteFile, name,
                      oldi, data="")
    # An older mtime is accepted
    os.utime(name, (mtime - 10, mtime - 10))
    utils.SafeWriteFile(name, oldi, data="")
    oldi = utils.GetFileID(path=name)
    mtime = oldi[2]
    os.utime(name, (mtime + 10, mtime + 10))
    # this doesn't raise, since we passed None
    utils.SafeWriteFile(name, None, data="")

  def testError(self):
    tmp = tempfile.NamedTemporaryFile()
    try:
      # Passing both a path and a file descriptor is rejected
      self.assertRaises(errors.ProgrammerError, utils.GetFileID,
                        path=tmp.name, fd=tmp.fileno())
    finally:
      # Release the temporary file explicitly instead of relying on GC
      tmp.close()
466
467
class TestRemoveFile(unittest.TestCase):
  """Test case for the RemoveFile function"""

  def setUp(self):
    """Create a temp dir and file for each case"""
    self.tmpdir = tempfile.mkdtemp("", "ganeti-unittest-")
    fd, self.tmpfile = tempfile.mkstemp("", "", self.tmpdir)
    os.close(fd)

  def tearDown(self):
    if os.path.exists(self.tmpfile):
      os.unlink(self.tmpfile)
    os.rmdir(self.tmpdir)

  def testIgnoreDirs(self):
    """Test that RemoveFile() ignores directories"""
    self.assertEqual(None, utils.RemoveFile(self.tmpdir))

  def testIgnoreNotExisting(self):
    """Test that RemoveFile() ignores non-existing files"""
    utils.RemoveFile(self.tmpfile)
    utils.RemoveFile(self.tmpfile)

  def testRemoveFile(self):
    """Test that RemoveFile does remove a file"""
    utils.RemoveFile(self.tmpfile)
    if os.path.exists(self.tmpfile):
      self.fail("File '%s' not removed" % self.tmpfile)

  def testRemoveSymlink(self):
    """Test that RemoveFile does remove symlinks"""
    symlink = os.path.join(self.tmpdir, "symlink")
    # First a dangling link, then one pointing to an existing file
    for target in ["no-such-file", self.tmpfile]:
      os.symlink(target, symlink)
      utils.RemoveFile(symlink)
      # os.path.exists() follows symlinks and is always false for a dangling
      # link; lexists() checks for the link itself
      if os.path.lexists(symlink):
        self.fail("File '%s' not removed" % symlink)
    # Removing a link must not touch its target
    self.assertTrue(os.path.exists(self.tmpfile))
508
509
class TestRemoveDir(unittest.TestCase):
  """Tests for utils.RemoveDir."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    # The directory may already have been removed by the test itself
    try:
      shutil.rmtree(self.tmpdir)
    except EnvironmentError:
      pass

  def testEmptyDir(self):
    utils.RemoveDir(self.tmpdir)
    self.assertFalse(os.path.isdir(self.tmpdir))

  def testNonEmptyDir(self):
    self.tmpfile = os.path.join(self.tmpdir, "test1")
    # An empty file is enough to make the directory non-empty
    handle = open(self.tmpfile, "w")
    handle.close()
    self.assertRaises(EnvironmentError, utils.RemoveDir, self.tmpdir)
528
529
class TestRename(unittest.TestCase):
  """Test case for RenameFile"""

  def setUp(self):
    """Create a temporary directory"""
    self.tmpdir = tempfile.mkdtemp()
    self.tmpfile = os.path.join(self.tmpdir, "test1")

    # Create an empty source file
    handle = open(self.tmpfile, "w")
    handle.close()

  def tearDown(self):
    """Remove temporary directory"""
    shutil.rmtree(self.tmpdir)

  def _Path(self, relpath):
    """Returns C{relpath} joined to the temporary directory."""
    return os.path.join(self.tmpdir, relpath)

  def testSimpleRename1(self):
    """Simple rename 1"""
    utils.RenameFile(self.tmpfile, self._Path("xyz"))
    self.assert_(os.path.isfile(self._Path("xyz")))

  def testSimpleRename2(self):
    """Simple rename 2"""
    utils.RenameFile(self.tmpfile, self._Path("xyz"), mkdir=True)
    self.assert_(os.path.isfile(self._Path("xyz")))

  def testRenameMkdir(self):
    """Rename with mkdir"""
    utils.RenameFile(self.tmpfile, self._Path("test/xyz"), mkdir=True)
    self.assert_(os.path.isdir(self._Path("test")))
    self.assert_(os.path.isfile(self._Path("test/xyz")))

    # Several missing directory levels make the rename fail even with mkdir
    self.assertRaises(EnvironmentError, utils.RenameFile,
                      self._Path("test/xyz"), self._Path("test/foo/bar/baz"),
                      mkdir=True)

    # The failed rename leaves the source alone and creates nothing
    self.assertTrue(os.path.exists(self._Path("test/xyz")))
    self.assertFalse(os.path.exists(self._Path("test/foo/bar")))
    self.assertFalse(os.path.exists(self._Path("test/foo/bar/baz")))
572
573
class TestMakedirs(unittest.TestCase):
  """Tests for utils.Makedirs."""

  def setUp(self):
    self.tmpdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self.tmpdir)

  def _MakeAndCheck(self, relpath):
    """Runs Makedirs on C{relpath} and checks a directory now exists."""
    target = utils.PathJoin(self.tmpdir, relpath)
    utils.Makedirs(target)
    self.assert_(os.path.isdir(target))

  def testNonExisting(self):
    self._MakeAndCheck("foo")

  def testExisting(self):
    os.mkdir(utils.PathJoin(self.tmpdir, "foo"))
    # An already existing directory must not cause an error
    self._MakeAndCheck("foo")

  def testRecursiveNonExisting(self):
    self._MakeAndCheck("foo/bar/baz")

  def testRecursiveExisting(self):
    self.assertFalse(os.path.exists(utils.PathJoin(self.tmpdir, "B/moo/xyz")))
    os.mkdir(utils.PathJoin(self.tmpdir, "B"))
    self._MakeAndCheck("B/moo/xyz")
603
604
605 class TestEnsureDirs(unittest.TestCase):
606   """Tests for EnsureDirs"""
607
608   def setUp(self):
609     self.dir = tempfile.mkdtemp()
610     self.old_umask = os.umask(0777)
611
612   def testEnsureDirs(self):
613     utils.EnsureDirs([
614         (utils.PathJoin(self.dir, "foo"), 0777),
615         (utils.PathJoin(self.dir, "bar"), 0000),
616         ])
617     self.assertEquals(os.stat(utils.PathJoin(self.dir, "foo"))[0] & 0777, 0777)
618     self.assertEquals(os.stat(utils.PathJoin(self.dir, "bar"))[0] & 0777, 0000)
619
620   def tearDown(self):
621     os.rmdir(utils.PathJoin(self.dir, "foo"))
622     os.rmdir(utils.PathJoin(self.dir, "bar"))
623     os.rmdir(self.dir)
624     os.umask(self.old_umask)
625
626
class TestIsNormAbsPath(unittest.TestCase):
  """Testing case for IsNormAbsPath"""

  def _pathTestHelper(self, path, result):
    """Checks that IsNormAbsPath(C{path}) agrees with C{result}."""
    actual = utils.IsNormAbsPath(path)
    if not result:
      self.assertFalse(actual,
          msg="Path %s should not result absolute and normalized" % path)
      return
    self.assert_(actual,
        msg="Path %s should result absolute and normalized" % path)

  def testBase(self):
    for path, expected in [("/etc", True), ("/srv", True), ("etc", False),
                           ("/etc/../root", False), ("/etc/", False)]:
      self._pathTestHelper(path, expected)

  def testSlashes(self):
    # The root directory itself
    self._pathTestHelper("/", True)

    # POSIX leaves a leading double slash "implementation-defined"
    self._pathTestHelper("//", True)

    # Three or more slashes are equivalent to one, hence not normalized
    for count in range(3, 10):
      self._pathTestHelper("/" * count, False)
655
656
class TestIsBelowDir(unittest.TestCase):
  """Testing case for IsBelowDir"""

  def _CheckAll(self, expected, pairs):
    """Asserts IsBelowDir(root, path) is true/false for every pair."""
    for (root, path) in pairs:
      if expected:
        self.assertTrue(utils.IsBelowDir(root, path))
      else:
        self.assertFalse(utils.IsBelowDir(root, path))

  def testExactlyTheSame(self):
    # A directory is never below itself, trailing slashes notwithstanding
    self._CheckAll(False, [(root, path)
                           for root in ("/a/b", "/a/b/")
                           for path in ("/a/b", "/a/b/")])

  def testSamePrefix(self):
    self._CheckAll(True, [("/a/b", "/a/b/c"), ("/a/b/", "/a/b/e")])

  def testSamePrefixButDifferentDir(self):
    # A shared string prefix ("/a/b" vs. "/a/bc") is not sufficient
    self._CheckAll(False, [("/a/b", "/a/bc/d"), ("/a/b/", "/a/bc/e")])

  def testSamePrefixButDirTraversal(self):
    self._CheckAll(False, [("/a/b", "/a/b/../c"), ("/a/b/", "/a/b/../d")])

  def testSamePrefixAndTraversal(self):
    self._CheckAll(True, [("/a/b", "/a/b/c/../d"), ("/a/b", "/a/b/c/./e"),
                          ("/a/b", "/a/b/../b/./e")])

  def testBothAbsPath(self):
    # Both arguments must be non-empty absolute paths
    for (root, path) in [("/a/b/c", "d"), ("a/b/c", "/d"), ("a/b/c", "d"),
                         ("", "/"), ("/", "")]:
      self.assertRaises(ValueError, utils.IsBelowDir, root, path)

  def testRoot(self):
    self.assertFalse(utils.IsBelowDir("/", "/"))
    self._CheckAll(True, [("/", path)
                          for path in ["/a", "/tmp", "/tmp/foo/bar", "/tmp/"]])

  def testSlashes(self):
    # In POSIX a double slash is "implementation-defined".
    self._CheckAll(False, [("//", "//"), ("//", "/tmp")])
    self.assertTrue(utils.IsBelowDir("//tmp", "//tmp/x"))

    # Three (or more) slashes count as one
    self.assertFalse(utils.IsBelowDir("/", "///"))
    self._CheckAll(True, [("/", "///tmp"), ("/tmp", "///tmp/a/b")])
706
707
class TestPathJoin(unittest.TestCase):
  """Testing case for PathJoin"""

  def testBasicItems(self):
    parts = ["/a", "b", "c"]
    self.assertEqual(utils.PathJoin(*parts), "/a/b/c")

  def testNonAbsPrefix(self):
    # The first component has to be absolute
    self.assertRaises(ValueError, utils.PathJoin, "a", "b")

  def testBackTrack(self):
    # Components must not climb out via ".."
    self.assertRaises(ValueError, utils.PathJoin, "/a", "b/../c")

  def testMultiAbs(self):
    # Only the first component may be absolute
    self.assertRaises(ValueError, utils.PathJoin, "/a", "/b")
723
724
class TestTailFile(testutils.GanetiTestCase):
  """Test case for the TailFile function"""

  def _CreateFileWithLines(self, lines, prefix=""):
    """Creates a temporary file containing C{prefix} and then C{lines}.

    Every entry of C{lines} is terminated by a newline.

    @return: the name of the created file

    """
    fname = self._CreateTempFile()
    handle = open(fname, "w")
    try:
      handle.write(prefix)
      handle.write("".join(["%s\n" % line for line in lines]))
    finally:
      handle.close()
    return fname

  def testEmpty(self):
    fname = self._CreateTempFile()
    self.assertEqual(utils.TailFile(fname), [])
    self.assertEqual(utils.TailFile(fname, lines=25), [])

  def testAllLines(self):
    data = ["test %d" % i for i in range(30)]
    for count in range(30):
      fname = self._CreateFileWithLines(data[:count])
      self.assertEqual(utils.TailFile(fname, lines=count), data[:count])

  def testPartialLines(self):
    data = ["test %d" % i for i in range(30)]
    fname = self._CreateFileWithLines(data)
    for count in range(1, 30):
      self.assertEqual(utils.TailFile(fname, lines=count), data[-count:])

  def testBigFile(self):
    data = ["test %d" % i for i in range(30)]
    # A single 1 MiB line precedes the lines which are tailed
    fname = self._CreateFileWithLines(data, prefix=("X" * 1048576) + "\n")
    for count in range(1, 30):
      self.assertEqual(utils.TailFile(fname, lines=count), data[-count:])
765
766
767 class TestPidFileFunctions(unittest.TestCase):
768   """Tests for WritePidFile and ReadPidFile"""
769
770   def setUp(self):
771     self.dir = tempfile.mkdtemp()
772     self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
773
774   def testPidFileFunctions(self):
775     pid_file = self.f_dpn("test")
776     fd = utils.WritePidFile(self.f_dpn("test"))
777     self.failUnless(os.path.exists(pid_file),
778                     "PID file should have been created")
779     read_pid = utils.ReadPidFile(pid_file)
780     self.failUnlessEqual(read_pid, os.getpid())
781     self.failUnless(utils.IsProcessAlive(read_pid))
782     self.failUnlessRaises(errors.PidFileLockError, utils.WritePidFile,
783                           self.f_dpn("test"))
784     os.close(fd)
785     utils.RemoveFile(self.f_dpn("test"))
786     self.failIf(os.path.exists(pid_file),
787                 "PID file should not exist anymore")
788     self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
789                          "ReadPidFile should return 0 for missing pid file")
790     fh = open(pid_file, "w")
791     fh.write("blah\n")
792     fh.close()
793     self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
794                          "ReadPidFile should return 0 for invalid pid file")
795     # but now, even with the file existing, we should be able to lock it
796     fd = utils.WritePidFile(self.f_dpn("test"))
797     os.close(fd)
798     utils.RemoveFile(self.f_dpn("test"))
799     self.failIf(os.path.exists(pid_file),
800                 "PID file should not exist anymore")
801
802   def testKill(self):
803     pid_file = self.f_dpn("child")
804     r_fd, w_fd = os.pipe()
805     new_pid = os.fork()
806     if new_pid == 0: #child
807       utils.WritePidFile(self.f_dpn("child"))
808       os.write(w_fd, "a")
809       signal.pause()
810       os._exit(0)
811       return
812     # else we are in the parent
813     # wait until the child has written the pid file
814     os.read(r_fd, 1)
815     read_pid = utils.ReadPidFile(pid_file)
816     self.failUnlessEqual(read_pid, new_pid)
817     self.failUnless(utils.IsProcessAlive(new_pid))
818
819     # Try writing to locked file
820     try:
821       utils.WritePidFile(pid_file)
822     except errors.PidFileLockError, err:
823       errmsg = str(err)
824       self.assertTrue(errmsg.endswith(" %s" % new_pid),
825                       msg=("Error message ('%s') didn't contain correct"
826                            " PID (%s)" % (errmsg, new_pid)))
827     else:
828       self.fail("Writing to locked file didn't fail")
829
830     utils.KillProcess(new_pid, waitpid=True)
831     self.failIf(utils.IsProcessAlive(new_pid))
832     utils.RemoveFile(self.f_dpn("child"))
833     self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)
834
835   def testExceptionType(self):
836     # Make sure the PID lock error is a subclass of LockError in case some code
837     # depends on it
838     self.assertTrue(issubclass(errors.PidFileLockError, errors.LockError))
839
840   def tearDown(self):
841     shutil.rmtree(self.dir)
842
843
class TestSshKeys(testutils.GanetiTestCase):
  """Test cases for AddAuthorizedKey and RemoveAuthorizedKey"""

  KEY_A = "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a"
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
           "ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b")

  @staticmethod
  def _JoinLines(*lines):
    """Return C{lines} as file content, each terminated by a newline."""
    return "".join(["%s\n" % line for line in lines])

  def _AssertKeys(self, *lines):
    """Assert the key file contains exactly C{lines}, in this order."""
    self.assertFileContent(self.tmpname, self._JoinLines(*lines))

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.tmpname = self._CreateTempFile()
    # Start every test from a file holding key A followed by key B
    handle = open(self.tmpname, "w")
    try:
      handle.write(self._JoinLines(TestSshKeys.KEY_A, TestSshKeys.KEY_B))
    finally:
      handle.close()

  def testAddingNewKey(self):
    new_key = "ssh-dss AAAAB3NzaC1kc3MAAACB root@test"
    utils.AddAuthorizedKey(self.tmpname, new_key)

    # An unknown key is appended after the existing ones
    self._AssertKeys(self.KEY_A, self.KEY_B, new_key)

  def testAddingAlmostButNotCompletelyTheSameKey(self):
    utils.AddAuthorizedKey(self.tmpname,
                           "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test")

    # Only significant fields are compared, therefore the key won't be
    # updated/added
    self._AssertKeys(self.KEY_A, self.KEY_B)

  def testAddingExistingKeyWithSomeMoreSpaces(self):
    # Key A with extra whitespace: must not be added a second time
    spaced_key_a = "ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a"
    # Key B's type and data without its options: ends up as a new line
    bare_key_b = "ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22"
    utils.AddAuthorizedKey(self.tmpname, spaced_key_a)
    utils.AddAuthorizedKey(self.tmpname, bare_key_b)

    self._AssertKeys(self.KEY_A, self.KEY_B, bare_key_b)

  def testRemovingExistingKeyWithSomeMoreSpaces(self):
    spaced_key_a = "ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a"
    utils.RemoveAuthorizedKey(self.tmpname, spaced_key_a)

    # Key A is removed despite the whitespace differences
    self._AssertKeys(self.KEY_B)

  def testRemovingNonExistingKey(self):
    unknown_key = "ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test"
    utils.RemoveAuthorizedKey(self.tmpname, unknown_key)

    # Nothing matched, so the file is left unchanged
    self._AssertKeys(self.KEY_A, self.KEY_B)
910
911
class TestNewUUID(unittest.TestCase):
  """Test case for NewUUID"""

  def runTest(self):
    """A freshly generated UUID must match L{utils.UUID_RE}."""
    new_uuid = utils.NewUUID()
    self.assertTrue(utils.UUID_RE.match(new_uuid))
917
918
def _MockStatResult(cb, mode, uid, gid):
  """Build a fake replacement for a stat-like function.

  @param cb: optional callable run on every call before the result is
    built; tests use it to simulate failures by raising
  @param mode: value reported under C{stat.ST_MODE}
  @param uid: value reported under C{stat.ST_UID}
  @param gid: value reported under C{stat.ST_GID}
  @return: function taking a path (ignored) and returning a new dict keyed
    by the C{stat.ST_*} index constants

  """
  fields = ((stat.ST_MODE, mode), (stat.ST_UID, uid), (stat.ST_GID, gid))

  def _fn(path):
    if cb:
      cb()
    # A fresh dict per call, so callers cannot affect later results
    return dict(fields)

  return _fn
929
930
def _RaiseNoEntError():
  """Simulate a stat failure for a path that does not exist.

  @raise EnvironmentError: always, with errno set to C{errno.ENOENT}

  """
  error = EnvironmentError(errno.ENOENT, "not found")
  raise error
933
934
def _OtherStatRaise():
  """Simulate a stat failure that carries no errno (so it is not ENOENT).

  @raise EnvironmentError: always, constructed without arguments

  """
  raise EnvironmentError
937
938
939 class TestPermissionEnforcements(unittest.TestCase):
940   UID_A = 16024
941   UID_B = 25850
942   GID_A = 14028
943   GID_B = 29801
944
945   def setUp(self):
946     self._chown_calls = []
947     self._chmod_calls = []
948     self._mkdir_calls = []
949
950   def tearDown(self):
951     self.assertRaises(IndexError, self._mkdir_calls.pop)
952     self.assertRaises(IndexError, self._chmod_calls.pop)
953     self.assertRaises(IndexError, self._chown_calls.pop)
954
955   def _FakeMkdir(self, path):
956     self._mkdir_calls.append(path)
957
958   def _FakeChown(self, path, uid, gid):
959     self._chown_calls.append((path, uid, gid))
960
961   def _ChmodWrapper(self, cb):
962     def _fn(path, mode):
963       self._chmod_calls.append((path, mode))
964       if cb:
965         cb()
966     return _fn
967
968   def _VerifyPerm(self, path, mode, uid=-1, gid=-1):
969     self.assertEqual(path, "/ganeti-qa-non-test")
970     self.assertEqual(mode, 0700)
971     self.assertEqual(uid, self.UID_A)
972     self.assertEqual(gid, self.GID_A)
973
974   def testMakeDirWithPerm(self):
975     is_dir_stat = _MockStatResult(None, stat.S_IFDIR, 0, 0)
976     utils.MakeDirWithPerm("/ganeti-qa-non-test", 0700, self.UID_A, self.GID_A,
977                           _lstat_fn=is_dir_stat, _perm_fn=self._VerifyPerm)
978
979   def testDirErrors(self):
980     self.assertRaises(errors.GenericError, utils.MakeDirWithPerm,
981                       "/ganeti-qa-non-test", 0700, 0, 0,
982                       _lstat_fn=_MockStatResult(None, 0, 0, 0))
983     self.assertRaises(IndexError, self._mkdir_calls.pop)
984
985     other_stat_raise = _MockStatResult(_OtherStatRaise, stat.S_IFDIR, 0, 0)
986     self.assertRaises(errors.GenericError, utils.MakeDirWithPerm,
987                       "/ganeti-qa-non-test", 0700, 0, 0,
988                       _lstat_fn=other_stat_raise)
989     self.assertRaises(IndexError, self._mkdir_calls.pop)
990
991     non_exist_stat = _MockStatResult(_RaiseNoEntError, stat.S_IFDIR, 0, 0)
992     utils.MakeDirWithPerm("/ganeti-qa-non-test", 0700, self.UID_A, self.GID_A,
993                           _lstat_fn=non_exist_stat, _mkdir_fn=self._FakeMkdir,
994                           _perm_fn=self._VerifyPerm)
995     self.assertEqual(self._mkdir_calls.pop(0), "/ganeti-qa-non-test")
996
997   def testEnforcePermissionNoEnt(self):
998     self.assertRaises(errors.GenericError, utils.EnforcePermission,
999                       "/ganeti-qa-non-test", 0600,
1000                       _chmod_fn=NotImplemented, _chown_fn=NotImplemented,
1001                       _stat_fn=_MockStatResult(_RaiseNoEntError, 0, 0, 0))
1002
1003   def testEnforcePermissionNoEntMustNotExist(self):
1004     utils.EnforcePermission("/ganeti-qa-non-test", 0600, must_exist=False,
1005                             _chmod_fn=NotImplemented,
1006                             _chown_fn=NotImplemented,
1007                             _stat_fn=_MockStatResult(_RaiseNoEntError,
1008                                                           0, 0, 0))
1009
1010   def testEnforcePermissionOtherErrorMustNotExist(self):
1011     self.assertRaises(errors.GenericError, utils.EnforcePermission,
1012                       "/ganeti-qa-non-test", 0600, must_exist=False,
1013                       _chmod_fn=NotImplemented, _chown_fn=NotImplemented,
1014                       _stat_fn=_MockStatResult(_OtherStatRaise, 0, 0, 0))
1015
1016   def testEnforcePermissionNoChanges(self):
1017     utils.EnforcePermission("/ganeti-qa-non-test", 0600,
1018                             _stat_fn=_MockStatResult(None, 0600, 0, 0),
1019                             _chmod_fn=self._ChmodWrapper(None),
1020                             _chown_fn=self._FakeChown)
1021
1022   def testEnforcePermissionChangeMode(self):
1023     utils.EnforcePermission("/ganeti-qa-non-test", 0444,
1024                             _stat_fn=_MockStatResult(None, 0600, 0, 0),
1025                             _chmod_fn=self._ChmodWrapper(None),
1026                             _chown_fn=self._FakeChown)
1027     self.assertEqual(self._chmod_calls.pop(0), ("/ganeti-qa-non-test", 0444))
1028
1029   def testEnforcePermissionSetUidGid(self):
1030     utils.EnforcePermission("/ganeti-qa-non-test", 0600,
1031                             uid=self.UID_B, gid=self.GID_B,
1032                             _stat_fn=_MockStatResult(None, 0600,
1033                                                      self.UID_A,
1034                                                      self.GID_A),
1035                             _chmod_fn=self._ChmodWrapper(None),
1036                             _chown_fn=self._FakeChown)
1037     self.assertEqual(self._chown_calls.pop(0),
1038                      ("/ganeti-qa-non-test", self.UID_B, self.GID_B))
1039
1040
if __name__ == "__main__":
  # Run the tests only when executed as a script, via testutils' runner
  testutils.GanetiTestProgram()