Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils.io_unittest.py @ b6522276

History | View | Annotate | Download (32.4 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010, 2011 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.utils.io"""
23

    
24
import os
25
import tempfile
26
import unittest
27
import shutil
28
import glob
29
import time
30
import signal
31
import stat
32
import errno
33

    
34
from ganeti import constants
35
from ganeti import utils
36
from ganeti import compat
37
from ganeti import errors
38

    
39
import testutils
40

    
41

    
42
class TestReadFile(testutils.GanetiTestCase):
43
  def testReadAll(self):
44
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
45
    self.assertEqual(len(data), 814)
46

    
47
    h = compat.md5_hash()
48
    h.update(data)
49
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
50

    
51
  def testReadSize(self):
52
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
53
                          size=100)
54
    self.assertEqual(len(data), 100)
55

    
56
    h = compat.md5_hash()
57
    h.update(data)
58
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
59

    
60
  def testCallback(self):
61
    def _Cb(fh):
62
      self.assertEqual(fh.tell(), 0)
63
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"), preread=_Cb)
64
    self.assertEqual(len(data), 814)
65

    
66
  def testError(self):
67
    self.assertRaises(EnvironmentError, utils.ReadFile,
68
                      "/dev/null/does-not-exist")
69

    
70

    
71
class TestReadOneLineFile(testutils.GanetiTestCase):
72
  def setUp(self):
73
    testutils.GanetiTestCase.setUp(self)
74

    
75
  def testDefault(self):
76
    data = utils.ReadOneLineFile(self._TestDataFilename("cert1.pem"))
77
    self.assertEqual(len(data), 27)
78
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
79

    
80
  def testNotStrict(self):
81
    data = utils.ReadOneLineFile(self._TestDataFilename("cert1.pem"),
82
                                 strict=False)
83
    self.assertEqual(len(data), 27)
84
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
85

    
86
  def testStrictFailure(self):
87
    self.assertRaises(errors.GenericError, utils.ReadOneLineFile,
88
                      self._TestDataFilename("cert1.pem"), strict=True)
89

    
90
  def testLongLine(self):
91
    dummydata = (1024 * "Hello World! ")
92
    myfile = self._CreateTempFile()
93
    utils.WriteFile(myfile, data=dummydata)
94
    datastrict = utils.ReadOneLineFile(myfile, strict=True)
95
    datalax = utils.ReadOneLineFile(myfile, strict=False)
96
    self.assertEqual(dummydata, datastrict)
97
    self.assertEqual(dummydata, datalax)
98

    
99
  def testNewline(self):
100
    myfile = self._CreateTempFile()
101
    myline = "myline"
102
    for nl in ["", "\n", "\r\n"]:
103
      dummydata = "%s%s" % (myline, nl)
104
      utils.WriteFile(myfile, data=dummydata)
105
      datalax = utils.ReadOneLineFile(myfile, strict=False)
106
      self.assertEqual(myline, datalax)
107
      datastrict = utils.ReadOneLineFile(myfile, strict=True)
108
      self.assertEqual(myline, datastrict)
109

    
110
  def testWhitespaceAndMultipleLines(self):
111
    myfile = self._CreateTempFile()
112
    for nl in ["", "\n", "\r\n"]:
113
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
114
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
115
        utils.WriteFile(myfile, data=dummydata)
116
        datalax = utils.ReadOneLineFile(myfile, strict=False)
117
        if nl:
118
          self.assert_(set("\r\n") & set(dummydata))
119
          self.assertRaises(errors.GenericError, utils.ReadOneLineFile,
120
                            myfile, strict=True)
121
          explen = len("Foo bar baz ") + len(ws)
122
          self.assertEqual(len(datalax), explen)
123
          self.assertEqual(datalax, dummydata[:explen])
124
          self.assertFalse(set("\r\n") & set(datalax))
125
        else:
126
          datastrict = utils.ReadOneLineFile(myfile, strict=True)
127
          self.assertEqual(dummydata, datastrict)
128
          self.assertEqual(dummydata, datalax)
129

    
130
  def testEmptylines(self):
131
    myfile = self._CreateTempFile()
132
    myline = "myline"
133
    for nl in ["\n", "\r\n"]:
134
      for ol in ["", "otherline"]:
135
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
136
        utils.WriteFile(myfile, data=dummydata)
137
        self.assert_(set("\r\n") & set(dummydata))
138
        datalax = utils.ReadOneLineFile(myfile, strict=False)
139
        self.assertEqual(myline, datalax)
140
        if ol:
141
          self.assertRaises(errors.GenericError, utils.ReadOneLineFile,
142
                            myfile, strict=True)
143
        else:
144
          datastrict = utils.ReadOneLineFile(myfile, strict=True)
145
          self.assertEqual(myline, datastrict)
146

    
147
  def testEmptyfile(self):
148
    myfile = self._CreateTempFile()
149
    self.assertRaises(errors.GenericError, utils.ReadOneLineFile, myfile)
150

    
151

    
152
class TestTimestampForFilename(unittest.TestCase):
153
  def test(self):
154
    self.assert_("." not in utils.TimestampForFilename())
155
    self.assert_(":" not in utils.TimestampForFilename())
156

    
157

    
158
class TestCreateBackup(testutils.GanetiTestCase):
159
  def setUp(self):
160
    testutils.GanetiTestCase.setUp(self)
161

    
162
    self.tmpdir = tempfile.mkdtemp()
163

    
164
  def tearDown(self):
165
    testutils.GanetiTestCase.tearDown(self)
166

    
167
    shutil.rmtree(self.tmpdir)
168

    
169
  def testEmpty(self):
170
    filename = utils.PathJoin(self.tmpdir, "config.data")
171
    utils.WriteFile(filename, data="")
172
    bname = utils.CreateBackup(filename)
173
    self.assertFileContent(bname, "")
174
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
175
    utils.CreateBackup(filename)
176
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
177
    utils.CreateBackup(filename)
178
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
179

    
180
    fifoname = utils.PathJoin(self.tmpdir, "fifo")
181
    os.mkfifo(fifoname)
182
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
183

    
184
  def testContent(self):
185
    bkpcount = 0
186
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
187
      for rep in [1, 2, 10, 127]:
188
        testdata = data * rep
189

    
190
        filename = utils.PathJoin(self.tmpdir, "test.data_")
191
        utils.WriteFile(filename, data=testdata)
192
        self.assertFileContent(filename, testdata)
193

    
194
        for _ in range(3):
195
          bname = utils.CreateBackup(filename)
196
          bkpcount += 1
197
          self.assertFileContent(bname, testdata)
198
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
199

    
200

    
201
class TestListVisibleFiles(unittest.TestCase):
202
  """Test case for ListVisibleFiles"""
203

    
204
  def setUp(self):
205
    self.path = tempfile.mkdtemp()
206

    
207
  def tearDown(self):
208
    shutil.rmtree(self.path)
209

    
210
  def _CreateFiles(self, files):
211
    for name in files:
212
      utils.WriteFile(os.path.join(self.path, name), data="test")
213

    
214
  def _test(self, files, expected):
215
    self._CreateFiles(files)
216
    found = utils.ListVisibleFiles(self.path)
217
    self.assertEqual(set(found), set(expected))
218

    
219
  def testAllVisible(self):
220
    files = ["a", "b", "c"]
221
    expected = files
222
    self._test(files, expected)
223

    
224
  def testNoneVisible(self):
225
    files = [".a", ".b", ".c"]
226
    expected = []
227
    self._test(files, expected)
228

    
229
  def testSomeVisible(self):
230
    files = ["a", "b", ".c"]
231
    expected = ["a", "b"]
232
    self._test(files, expected)
233

    
234
  def testNonAbsolutePath(self):
235
    self.failUnlessRaises(errors.ProgrammerError, utils.ListVisibleFiles,
236
                          "abc")
237

    
238
  def testNonNormalizedPath(self):
239
    self.failUnlessRaises(errors.ProgrammerError, utils.ListVisibleFiles,
240
                          "/bin/../tmp")
241

    
242
  def testMountpoint(self):
243
    lvfmp_fn = compat.partial(utils.ListVisibleFiles,
244
                              _is_mountpoint=lambda _: True)
245
    self.assertEqual(lvfmp_fn(self.path), [])
246

    
247
    # Create "lost+found" as a regular file
248
    self._CreateFiles(["foo", "bar", ".baz", "lost+found"])
249
    self.assertEqual(set(lvfmp_fn(self.path)),
250
                     set(["foo", "bar", "lost+found"]))
251

    
252
    # Replace "lost+found" with a directory
253
    laf_path = utils.PathJoin(self.path, "lost+found")
254
    utils.RemoveFile(laf_path)
255
    os.mkdir(laf_path)
256
    self.assertEqual(set(lvfmp_fn(self.path)), set(["foo", "bar"]))
257

    
258
  def testLostAndFoundNoMountpoint(self):
259
    files = ["foo", "bar", ".Hello World", "lost+found"]
260
    expected = ["foo", "bar", "lost+found"]
261
    self._test(files, expected)
262

    
263

    
264
class TestWriteFile(unittest.TestCase):
265
  def setUp(self):
266
    self.tmpdir = None
267
    self.tfile = tempfile.NamedTemporaryFile()
268
    self.did_pre = False
269
    self.did_post = False
270
    self.did_write = False
271

    
272
  def tearDown(self):
273
    if self.tmpdir:
274
      shutil.rmtree(self.tmpdir)
275

    
276
  def markPre(self, fd):
277
    self.did_pre = True
278

    
279
  def markPost(self, fd):
280
    self.did_post = True
281

    
282
  def markWrite(self, fd):
283
    self.did_write = True
284

    
285
  def testWrite(self):
286
    data = "abc"
287
    utils.WriteFile(self.tfile.name, data=data)
288
    self.assertEqual(utils.ReadFile(self.tfile.name), data)
289

    
290
  def testWriteSimpleUnicode(self):
291
    data = u"abc"
292
    utils.WriteFile(self.tfile.name, data=data)
293
    self.assertEqual(utils.ReadFile(self.tfile.name), data)
294

    
295
  def testErrors(self):
296
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
297
                      self.tfile.name, data="test", fn=lambda fd: None)
298
    self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name)
299
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
300
                      self.tfile.name, data="test", atime=0)
301

    
302
  def testPreWrite(self):
303
    utils.WriteFile(self.tfile.name, data="", prewrite=self.markPre)
304
    self.assertTrue(self.did_pre)
305
    self.assertFalse(self.did_post)
306
    self.assertFalse(self.did_write)
307

    
308
  def testPostWrite(self):
309
    utils.WriteFile(self.tfile.name, data="", postwrite=self.markPost)
310
    self.assertFalse(self.did_pre)
311
    self.assertTrue(self.did_post)
312
    self.assertFalse(self.did_write)
313

    
314
  def testWriteFunction(self):
315
    utils.WriteFile(self.tfile.name, fn=self.markWrite)
316
    self.assertFalse(self.did_pre)
317
    self.assertFalse(self.did_post)
318
    self.assertTrue(self.did_write)
319

    
320
  def testDryRun(self):
321
    orig = "abc"
322
    self.tfile.write(orig)
323
    self.tfile.flush()
324
    utils.WriteFile(self.tfile.name, data="hello", dry_run=True)
325
    self.assertEqual(utils.ReadFile(self.tfile.name), orig)
326

    
327
  def testTimes(self):
328
    f = self.tfile.name
329
    for at, mt in [(0, 0), (1000, 1000), (2000, 3000),
330
                   (int(time.time()), 5000)]:
331
      utils.WriteFile(f, data="hello", atime=at, mtime=mt)
332
      st = os.stat(f)
333
      self.assertEqual(st.st_atime, at)
334
      self.assertEqual(st.st_mtime, mt)
335

    
336
  def testNoClose(self):
337
    data = "hello"
338
    self.assertEqual(utils.WriteFile(self.tfile.name, data="abc"), None)
339
    fd = utils.WriteFile(self.tfile.name, data=data, close=False)
340
    try:
341
      os.lseek(fd, 0, 0)
342
      self.assertEqual(os.read(fd, 4096), data)
343
    finally:
344
      os.close(fd)
345

    
346
  def testNoLeftovers(self):
347
    self.tmpdir = tempfile.mkdtemp()
348
    self.assertEqual(utils.WriteFile(utils.PathJoin(self.tmpdir, "test"),
349
                                     data="abc"),
350
                     None)
351
    self.assertEqual(os.listdir(self.tmpdir), ["test"])
352

    
353
  def testFailRename(self):
354
    self.tmpdir = tempfile.mkdtemp()
355
    target = utils.PathJoin(self.tmpdir, "target")
356
    os.mkdir(target)
357
    self.assertRaises(OSError, utils.WriteFile, target, data="abc")
358
    self.assertTrue(os.path.isdir(target))
359
    self.assertEqual(os.listdir(self.tmpdir), ["target"])
360
    self.assertFalse(os.listdir(target))
361

    
362
  def testFailRenameDryRun(self):
363
    self.tmpdir = tempfile.mkdtemp()
364
    target = utils.PathJoin(self.tmpdir, "target")
365
    os.mkdir(target)
366
    self.assertEqual(utils.WriteFile(target, data="abc", dry_run=True), None)
367
    self.assertTrue(os.path.isdir(target))
368
    self.assertEqual(os.listdir(self.tmpdir), ["target"])
369
    self.assertFalse(os.listdir(target))
370

    
371
  def testBackup(self):
372
    self.tmpdir = tempfile.mkdtemp()
373
    testfile = utils.PathJoin(self.tmpdir, "test")
374

    
375
    self.assertEqual(utils.WriteFile(testfile, data="foo", backup=True), None)
376
    self.assertEqual(utils.ReadFile(testfile), "foo")
377
    self.assertEqual(os.listdir(self.tmpdir), ["test"])
378

    
379
    # Write again
380
    assert os.path.isfile(testfile)
381
    self.assertEqual(utils.WriteFile(testfile, data="bar", backup=True), None)
382
    self.assertEqual(utils.ReadFile(testfile), "bar")
383
    self.assertEqual(len(glob.glob("%s.backup*" % testfile)), 1)
384
    self.assertTrue("test" in os.listdir(self.tmpdir))
385
    self.assertEqual(len(os.listdir(self.tmpdir)), 2)
386

    
387
    # Write again as dry-run
388
    assert os.path.isfile(testfile)
389
    self.assertEqual(utils.WriteFile(testfile, data="000", backup=True,
390
                                     dry_run=True),
391
                     None)
392
    self.assertEqual(utils.ReadFile(testfile), "bar")
393
    self.assertEqual(len(glob.glob("%s.backup*" % testfile)), 1)
394
    self.assertTrue("test" in os.listdir(self.tmpdir))
395
    self.assertEqual(len(os.listdir(self.tmpdir)), 2)
396

    
397

    
398
class TestFileID(testutils.GanetiTestCase):
399
  def testEquality(self):
400
    name = self._CreateTempFile()
401
    oldi = utils.GetFileID(path=name)
402
    self.failUnless(utils.VerifyFileID(oldi, oldi))
403

    
404
  def testUpdate(self):
405
    name = self._CreateTempFile()
406
    oldi = utils.GetFileID(path=name)
407
    fd = os.open(name, os.O_RDWR)
408
    try:
409
      newi = utils.GetFileID(fd=fd)
410
      self.failUnless(utils.VerifyFileID(oldi, newi))
411
      self.failUnless(utils.VerifyFileID(newi, oldi))
412
    finally:
413
      os.close(fd)
414

    
415
  def testWriteFile(self):
416
    name = self._CreateTempFile()
417
    oldi = utils.GetFileID(path=name)
418
    mtime = oldi[2]
419
    os.utime(name, (mtime + 10, mtime + 10))
420
    self.assertRaises(errors.LockError, utils.SafeWriteFile, name,
421
                      oldi, data="")
422
    os.utime(name, (mtime - 10, mtime - 10))
423
    utils.SafeWriteFile(name, oldi, data="")
424
    oldi = utils.GetFileID(path=name)
425
    mtime = oldi[2]
426
    os.utime(name, (mtime + 10, mtime + 10))
427
    # this doesn't raise, since we passed None
428
    utils.SafeWriteFile(name, None, data="")
429

    
430
  def testError(self):
431
    t = tempfile.NamedTemporaryFile()
432
    self.assertRaises(errors.ProgrammerError, utils.GetFileID,
433
                      path=t.name, fd=t.fileno())
434

    
435

    
436
class TestRemoveFile(unittest.TestCase):
437
  """Test case for the RemoveFile function"""
438

    
439
  def setUp(self):
440
    """Create a temp dir and file for each case"""
441
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
442
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
443
    os.close(fd)
444

    
445
  def tearDown(self):
446
    if os.path.exists(self.tmpfile):
447
      os.unlink(self.tmpfile)
448
    os.rmdir(self.tmpdir)
449

    
450
  def testIgnoreDirs(self):
451
    """Test that RemoveFile() ignores directories"""
452
    self.assertEqual(None, utils.RemoveFile(self.tmpdir))
453

    
454
  def testIgnoreNotExisting(self):
455
    """Test that RemoveFile() ignores non-existing files"""
456
    utils.RemoveFile(self.tmpfile)
457
    utils.RemoveFile(self.tmpfile)
458

    
459
  def testRemoveFile(self):
460
    """Test that RemoveFile does remove a file"""
461
    utils.RemoveFile(self.tmpfile)
462
    if os.path.exists(self.tmpfile):
463
      self.fail("File '%s' not removed" % self.tmpfile)
464

    
465
  def testRemoveSymlink(self):
466
    """Test that RemoveFile does remove symlinks"""
467
    symlink = self.tmpdir + "/symlink"
468
    os.symlink("no-such-file", symlink)
469
    utils.RemoveFile(symlink)
470
    if os.path.exists(symlink):
471
      self.fail("File '%s' not removed" % symlink)
472
    os.symlink(self.tmpfile, symlink)
473
    utils.RemoveFile(symlink)
474
    if os.path.exists(symlink):
475
      self.fail("File '%s' not removed" % symlink)
476

    
477

    
478
class TestRemoveDir(unittest.TestCase):
479
  def setUp(self):
480
    self.tmpdir = tempfile.mkdtemp()
481

    
482
  def tearDown(self):
483
    try:
484
      shutil.rmtree(self.tmpdir)
485
    except EnvironmentError:
486
      pass
487

    
488
  def testEmptyDir(self):
489
    utils.RemoveDir(self.tmpdir)
490
    self.assertFalse(os.path.isdir(self.tmpdir))
491

    
492
  def testNonEmptyDir(self):
493
    self.tmpfile = os.path.join(self.tmpdir, "test1")
494
    open(self.tmpfile, "w").close()
495
    self.assertRaises(EnvironmentError, utils.RemoveDir, self.tmpdir)
496

    
497

    
498
class TestRename(unittest.TestCase):
499
  """Test case for RenameFile"""
500

    
501
  def setUp(self):
502
    """Create a temporary directory"""
503
    self.tmpdir = tempfile.mkdtemp()
504
    self.tmpfile = os.path.join(self.tmpdir, "test1")
505

    
506
    # Touch the file
507
    open(self.tmpfile, "w").close()
508

    
509
  def tearDown(self):
510
    """Remove temporary directory"""
511
    shutil.rmtree(self.tmpdir)
512

    
513
  def testSimpleRename1(self):
514
    """Simple rename 1"""
515
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
516
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
517

    
518
  def testSimpleRename2(self):
519
    """Simple rename 2"""
520
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
521
                     mkdir=True)
522
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
523

    
524
  def testRenameMkdir(self):
525
    """Rename with mkdir"""
526
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
527
                     mkdir=True)
528
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
529
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
530

    
531
    self.assertRaises(EnvironmentError, utils.RenameFile,
532
                      os.path.join(self.tmpdir, "test/xyz"),
533
                      os.path.join(self.tmpdir, "test/foo/bar/baz"),
534
                      mkdir=True)
535

    
536
    self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "test/xyz")))
537
    self.assertFalse(os.path.exists(os.path.join(self.tmpdir, "test/foo/bar")))
538
    self.assertFalse(os.path.exists(os.path.join(self.tmpdir,
539
                                                 "test/foo/bar/baz")))
540

    
541

    
542
class TestMakedirs(unittest.TestCase):
543
  def setUp(self):
544
    self.tmpdir = tempfile.mkdtemp()
545

    
546
  def tearDown(self):
547
    shutil.rmtree(self.tmpdir)
548

    
549
  def testNonExisting(self):
550
    path = utils.PathJoin(self.tmpdir, "foo")
551
    utils.Makedirs(path)
552
    self.assert_(os.path.isdir(path))
553

    
554
  def testExisting(self):
555
    path = utils.PathJoin(self.tmpdir, "foo")
556
    os.mkdir(path)
557
    utils.Makedirs(path)
558
    self.assert_(os.path.isdir(path))
559

    
560
  def testRecursiveNonExisting(self):
561
    path = utils.PathJoin(self.tmpdir, "foo/bar/baz")
562
    utils.Makedirs(path)
563
    self.assert_(os.path.isdir(path))
564

    
565
  def testRecursiveExisting(self):
566
    path = utils.PathJoin(self.tmpdir, "B/moo/xyz")
567
    self.assertFalse(os.path.exists(path))
568
    os.mkdir(utils.PathJoin(self.tmpdir, "B"))
569
    utils.Makedirs(path)
570
    self.assert_(os.path.isdir(path))
571

    
572

    
573
class TestEnsureDirs(unittest.TestCase):
574
  """Tests for EnsureDirs"""
575

    
576
  def setUp(self):
577
    self.dir = tempfile.mkdtemp()
578
    self.old_umask = os.umask(0777)
579

    
580
  def testEnsureDirs(self):
581
    utils.EnsureDirs([
582
        (utils.PathJoin(self.dir, "foo"), 0777),
583
        (utils.PathJoin(self.dir, "bar"), 0000),
584
        ])
585
    self.assertEquals(os.stat(utils.PathJoin(self.dir, "foo"))[0] & 0777, 0777)
586
    self.assertEquals(os.stat(utils.PathJoin(self.dir, "bar"))[0] & 0777, 0000)
587

    
588
  def tearDown(self):
589
    os.rmdir(utils.PathJoin(self.dir, "foo"))
590
    os.rmdir(utils.PathJoin(self.dir, "bar"))
591
    os.rmdir(self.dir)
592
    os.umask(self.old_umask)
593

    
594

    
595
class TestIsNormAbsPath(unittest.TestCase):
596
  """Testing case for IsNormAbsPath"""
597

    
598
  def _pathTestHelper(self, path, result):
599
    if result:
600
      self.assert_(utils.IsNormAbsPath(path),
601
          "Path %s should result absolute and normalized" % path)
602
    else:
603
      self.assertFalse(utils.IsNormAbsPath(path),
604
          "Path %s should not result absolute and normalized" % path)
605

    
606
  def testBase(self):
607
    self._pathTestHelper("/etc", True)
608
    self._pathTestHelper("/srv", True)
609
    self._pathTestHelper("etc", False)
610
    self._pathTestHelper("/etc/../root", False)
611
    self._pathTestHelper("/etc/", False)
612

    
613

    
614
class TestIsBelowDir(unittest.TestCase):
615
  """Testing case for IsBelowDir"""
616

    
617
  def testSamePrefix(self):
618
    self.assertTrue(utils.IsBelowDir("/a/b", "/a/b/c"))
619
    self.assertTrue(utils.IsBelowDir("/a/b/", "/a/b/e"))
620

    
621
  def testSamePrefixButDifferentDir(self):
622
    self.assertFalse(utils.IsBelowDir("/a/b", "/a/bc/d"))
623
    self.assertFalse(utils.IsBelowDir("/a/b/", "/a/bc/e"))
624

    
625
  def testSamePrefixButDirTraversal(self):
626
    self.assertFalse(utils.IsBelowDir("/a/b", "/a/b/../c"))
627
    self.assertFalse(utils.IsBelowDir("/a/b/", "/a/b/../d"))
628

    
629
  def testSamePrefixAndTraversal(self):
630
    self.assertTrue(utils.IsBelowDir("/a/b", "/a/b/c/../d"))
631
    self.assertTrue(utils.IsBelowDir("/a/b", "/a/b/c/./e"))
632
    self.assertTrue(utils.IsBelowDir("/a/b", "/a/b/../b/./e"))
633

    
634
  def testBothAbsPath(self):
635
    self.assertRaises(ValueError, utils.IsBelowDir, "/a/b/c", "d")
636
    self.assertRaises(ValueError, utils.IsBelowDir, "a/b/c", "/d")
637
    self.assertRaises(ValueError, utils.IsBelowDir, "a/b/c", "d")
638

    
639

    
640
class TestPathJoin(unittest.TestCase):
641
  """Testing case for PathJoin"""
642

    
643
  def testBasicItems(self):
644
    mlist = ["/a", "b", "c"]
645
    self.failUnlessEqual(utils.PathJoin(*mlist), "/".join(mlist))
646

    
647
  def testNonAbsPrefix(self):
648
    self.failUnlessRaises(ValueError, utils.PathJoin, "a", "b")
649

    
650
  def testBackTrack(self):
651
    self.failUnlessRaises(ValueError, utils.PathJoin, "/a", "b/../c")
652

    
653
  def testMultiAbs(self):
654
    self.failUnlessRaises(ValueError, utils.PathJoin, "/a", "/b")
655

    
656

    
657
class TestTailFile(testutils.GanetiTestCase):
658
  """Test case for the TailFile function"""
659

    
660
  def testEmpty(self):
661
    fname = self._CreateTempFile()
662
    self.failUnlessEqual(utils.TailFile(fname), [])
663
    self.failUnlessEqual(utils.TailFile(fname, lines=25), [])
664

    
665
  def testAllLines(self):
666
    data = ["test %d" % i for i in range(30)]
667
    for i in range(30):
668
      fname = self._CreateTempFile()
669
      fd = open(fname, "w")
670
      fd.write("\n".join(data[:i]))
671
      if i > 0:
672
        fd.write("\n")
673
      fd.close()
674
      self.failUnlessEqual(utils.TailFile(fname, lines=i), data[:i])
675

    
676
  def testPartialLines(self):
677
    data = ["test %d" % i for i in range(30)]
678
    fname = self._CreateTempFile()
679
    fd = open(fname, "w")
680
    fd.write("\n".join(data))
681
    fd.write("\n")
682
    fd.close()
683
    for i in range(1, 30):
684
      self.failUnlessEqual(utils.TailFile(fname, lines=i), data[-i:])
685

    
686
  def testBigFile(self):
687
    data = ["test %d" % i for i in range(30)]
688
    fname = self._CreateTempFile()
689
    fd = open(fname, "w")
690
    fd.write("X" * 1048576)
691
    fd.write("\n")
692
    fd.write("\n".join(data))
693
    fd.write("\n")
694
    fd.close()
695
    for i in range(1, 30):
696
      self.failUnlessEqual(utils.TailFile(fname, lines=i), data[-i:])
697

    
698

    
699
class TestPidFileFunctions(unittest.TestCase):
700
  """Tests for WritePidFile and ReadPidFile"""
701

    
702
  def setUp(self):
703
    self.dir = tempfile.mkdtemp()
704
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
705

    
706
  def testPidFileFunctions(self):
707
    pid_file = self.f_dpn('test')
708
    fd = utils.WritePidFile(self.f_dpn('test'))
709
    self.failUnless(os.path.exists(pid_file),
710
                    "PID file should have been created")
711
    read_pid = utils.ReadPidFile(pid_file)
712
    self.failUnlessEqual(read_pid, os.getpid())
713
    self.failUnless(utils.IsProcessAlive(read_pid))
714
    self.failUnlessRaises(errors.PidFileLockError, utils.WritePidFile,
715
                          self.f_dpn('test'))
716
    os.close(fd)
717
    utils.RemoveFile(self.f_dpn("test"))
718
    self.failIf(os.path.exists(pid_file),
719
                "PID file should not exist anymore")
720
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
721
                         "ReadPidFile should return 0 for missing pid file")
722
    fh = open(pid_file, "w")
723
    fh.write("blah\n")
724
    fh.close()
725
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
726
                         "ReadPidFile should return 0 for invalid pid file")
727
    # but now, even with the file existing, we should be able to lock it
728
    fd = utils.WritePidFile(self.f_dpn('test'))
729
    os.close(fd)
730
    utils.RemoveFile(self.f_dpn("test"))
731
    self.failIf(os.path.exists(pid_file),
732
                "PID file should not exist anymore")
733

    
734
  def testKill(self):
735
    pid_file = self.f_dpn('child')
736
    r_fd, w_fd = os.pipe()
737
    new_pid = os.fork()
738
    if new_pid == 0: #child
739
      utils.WritePidFile(self.f_dpn('child'))
740
      os.write(w_fd, 'a')
741
      signal.pause()
742
      os._exit(0)
743
      return
744
    # else we are in the parent
745
    # wait until the child has written the pid file
746
    os.read(r_fd, 1)
747
    read_pid = utils.ReadPidFile(pid_file)
748
    self.failUnlessEqual(read_pid, new_pid)
749
    self.failUnless(utils.IsProcessAlive(new_pid))
750

    
751
    # Try writing to locked file
752
    try:
753
      utils.WritePidFile(pid_file)
754
    except errors.PidFileLockError, err:
755
      errmsg = str(err)
756
      self.assertTrue(errmsg.endswith(" %s" % new_pid),
757
                      msg=("Error message ('%s') didn't contain correct"
758
                           " PID (%s)" % (errmsg, new_pid)))
759
    else:
760
      self.fail("Writing to locked file didn't fail")
761

    
762
    utils.KillProcess(new_pid, waitpid=True)
763
    self.failIf(utils.IsProcessAlive(new_pid))
764
    utils.RemoveFile(self.f_dpn('child'))
765
    self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)
766

    
767
  def testExceptionType(self):
768
    # Make sure the PID lock error is a subclass of LockError in case some code
769
    # depends on it
770
    self.assertTrue(issubclass(errors.PidFileLockError, errors.LockError))
771

    
772
  def tearDown(self):
773
    shutil.rmtree(self.dir)
774

    
775

    
776
class TestSshKeys(testutils.GanetiTestCase):
777
  """Test case for the AddAuthorizedKey function"""
778

    
779
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
780
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
781
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
782

    
783
  def setUp(self):
784
    testutils.GanetiTestCase.setUp(self)
785
    self.tmpname = self._CreateTempFile()
786
    handle = open(self.tmpname, 'w')
787
    try:
788
      handle.write("%s\n" % TestSshKeys.KEY_A)
789
      handle.write("%s\n" % TestSshKeys.KEY_B)
790
    finally:
791
      handle.close()
792

    
793
  def testAddingNewKey(self):
794
    utils.AddAuthorizedKey(self.tmpname,
795
                           'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
796

    
797
    self.assertFileContent(self.tmpname,
798
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
799
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
800
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
801
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
802

    
803
  def testAddingAlmostButNotCompletelyTheSameKey(self):
804
    utils.AddAuthorizedKey(self.tmpname,
805
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
806

    
807
    self.assertFileContent(self.tmpname,
808
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
809
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
810
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
811
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
812

    
813
  def testAddingExistingKeyWithSomeMoreSpaces(self):
814
    utils.AddAuthorizedKey(self.tmpname,
815
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
816

    
817
    self.assertFileContent(self.tmpname,
818
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
819
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
820
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
821

    
822
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
823
    utils.RemoveAuthorizedKey(self.tmpname,
824
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
825

    
826
    self.assertFileContent(self.tmpname,
827
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
828
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
829

    
830
  def testRemovingNonExistingKey(self):
831
    utils.RemoveAuthorizedKey(self.tmpname,
832
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
833

    
834
    self.assertFileContent(self.tmpname,
835
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
836
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
837
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
838

    
839

    
840
class TestNewUUID(unittest.TestCase):
841
  """Test case for NewUUID"""
842

    
843
  def runTest(self):
844
    self.failUnless(utils.UUID_RE.match(utils.NewUUID()))
845

    
846

    
847
def _MockStatResult(cb, mode, uid, gid):
848
  def _fn(path):
849
    if cb:
850
      cb()
851
    return {
852
      stat.ST_MODE: mode,
853
      stat.ST_UID: uid,
854
      stat.ST_GID: gid,
855
      }
856
  return _fn
857

    
858

    
859
def _RaiseNoEntError():
860
  raise EnvironmentError(errno.ENOENT, "not found")
861

    
862

    
863
def _OtherStatRaise():
864
  raise EnvironmentError()
865

    
866

    
867
class TestPermissionEnforcements(unittest.TestCase):
868
  UID_A = 16024
869
  UID_B = 25850
870
  GID_A = 14028
871
  GID_B = 29801
872

    
873
  def setUp(self):
874
    self._chown_calls = []
875
    self._chmod_calls = []
876
    self._mkdir_calls = []
877

    
878
  def tearDown(self):
879
    self.assertRaises(IndexError, self._mkdir_calls.pop)
880
    self.assertRaises(IndexError, self._chmod_calls.pop)
881
    self.assertRaises(IndexError, self._chown_calls.pop)
882

    
883
  def _FakeMkdir(self, path):
884
    self._mkdir_calls.append(path)
885

    
886
  def _FakeChown(self, path, uid, gid):
887
    self._chown_calls.append((path, uid, gid))
888

    
889
  def _ChmodWrapper(self, cb):
890
    def _fn(path, mode):
891
      self._chmod_calls.append((path, mode))
892
      if cb:
893
        cb()
894
    return _fn
895

    
896
  def _VerifyPerm(self, path, mode, uid=-1, gid=-1):
897
    self.assertEqual(path, "/ganeti-qa-non-test")
898
    self.assertEqual(mode, 0700)
899
    self.assertEqual(uid, self.UID_A)
900
    self.assertEqual(gid, self.GID_A)
901

    
902
  def testMakeDirWithPerm(self):
903
    is_dir_stat = _MockStatResult(None, stat.S_IFDIR, 0, 0)
904
    utils.MakeDirWithPerm("/ganeti-qa-non-test", 0700, self.UID_A, self.GID_A,
905
                          _lstat_fn=is_dir_stat, _perm_fn=self._VerifyPerm)
906

    
907
  def testDirErrors(self):
908
    self.assertRaises(errors.GenericError, utils.MakeDirWithPerm,
909
                      "/ganeti-qa-non-test", 0700, 0, 0,
910
                      _lstat_fn=_MockStatResult(None, 0, 0, 0))
911
    self.assertRaises(IndexError, self._mkdir_calls.pop)
912

    
913
    other_stat_raise = _MockStatResult(_OtherStatRaise, stat.S_IFDIR, 0, 0)
914
    self.assertRaises(errors.GenericError, utils.MakeDirWithPerm,
915
                      "/ganeti-qa-non-test", 0700, 0, 0,
916
                      _lstat_fn=other_stat_raise)
917
    self.assertRaises(IndexError, self._mkdir_calls.pop)
918

    
919
    non_exist_stat = _MockStatResult(_RaiseNoEntError, stat.S_IFDIR, 0, 0)
920
    utils.MakeDirWithPerm("/ganeti-qa-non-test", 0700, self.UID_A, self.GID_A,
921
                          _lstat_fn=non_exist_stat, _mkdir_fn=self._FakeMkdir,
922
                          _perm_fn=self._VerifyPerm)
923
    self.assertEqual(self._mkdir_calls.pop(0), "/ganeti-qa-non-test")
924

    
925
  def testEnforcePermissionNoEnt(self):
926
    self.assertRaises(errors.GenericError, utils.EnforcePermission,
927
                      "/ganeti-qa-non-test", 0600,
928
                      _chmod_fn=NotImplemented, _chown_fn=NotImplemented,
929
                      _stat_fn=_MockStatResult(_RaiseNoEntError, 0, 0, 0))
930

    
931
  def testEnforcePermissionNoEntMustNotExist(self):
932
    utils.EnforcePermission("/ganeti-qa-non-test", 0600, must_exist=False,
933
                            _chmod_fn=NotImplemented,
934
                            _chown_fn=NotImplemented,
935
                            _stat_fn=_MockStatResult(_RaiseNoEntError,
936
                                                          0, 0, 0))
937

    
938
  def testEnforcePermissionOtherErrorMustNotExist(self):
939
    self.assertRaises(errors.GenericError, utils.EnforcePermission,
940
                      "/ganeti-qa-non-test", 0600, must_exist=False,
941
                      _chmod_fn=NotImplemented, _chown_fn=NotImplemented,
942
                      _stat_fn=_MockStatResult(_OtherStatRaise, 0, 0, 0))
943

    
944
  def testEnforcePermissionNoChanges(self):
945
    utils.EnforcePermission("/ganeti-qa-non-test", 0600,
946
                            _stat_fn=_MockStatResult(None, 0600, 0, 0),
947
                            _chmod_fn=self._ChmodWrapper(None),
948
                            _chown_fn=self._FakeChown)
949

    
950
  def testEnforcePermissionChangeMode(self):
951
    utils.EnforcePermission("/ganeti-qa-non-test", 0444,
952
                            _stat_fn=_MockStatResult(None, 0600, 0, 0),
953
                            _chmod_fn=self._ChmodWrapper(None),
954
                            _chown_fn=self._FakeChown)
955
    self.assertEqual(self._chmod_calls.pop(0), ("/ganeti-qa-non-test", 0444))
956

    
957
  def testEnforcePermissionSetUidGid(self):
958
    utils.EnforcePermission("/ganeti-qa-non-test", 0600,
959
                            uid=self.UID_B, gid=self.GID_B,
960
                            _stat_fn=_MockStatResult(None, 0600,
961
                                                     self.UID_A,
962
                                                     self.GID_A),
963
                            _chmod_fn=self._ChmodWrapper(None),
964
                            _chown_fn=self._FakeChown)
965
    self.assertEqual(self._chown_calls.pop(0),
966
                     ("/ganeti-qa-non-test", self.UID_B, self.GID_B))
967

    
968

    
969
if __name__ == "__main__":
970
  testutils.GanetiTestProgram()