Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils.io_unittest.py @ 2dbc6857

History | View | Annotate | Download (31.8 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010, 2011 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.utils.io"""
23

    
24
import os
25
import tempfile
26
import unittest
27
import shutil
28
import glob
29
import time
30
import signal
31
import stat
32
import errno
33

    
34
from ganeti import constants
35
from ganeti import utils
36
from ganeti import compat
37
from ganeti import errors
38

    
39
import testutils
40

    
41

    
42
class TestReadFile(testutils.GanetiTestCase):
43
  def testReadAll(self):
44
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
45
    self.assertEqual(len(data), 814)
46

    
47
    h = compat.md5_hash()
48
    h.update(data)
49
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
50

    
51
  def testReadSize(self):
52
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
53
                          size=100)
54
    self.assertEqual(len(data), 100)
55

    
56
    h = compat.md5_hash()
57
    h.update(data)
58
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
59

    
60
  def testCallback(self):
61
    def _Cb(fh):
62
      self.assertEqual(fh.tell(), 0)
63
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"), preread=_Cb)
64
    self.assertEqual(len(data), 814)
65

    
66
  def testError(self):
67
    self.assertRaises(EnvironmentError, utils.ReadFile,
68
                      "/dev/null/does-not-exist")
69

    
70

    
71
class TestReadOneLineFile(testutils.GanetiTestCase):
72
  def setUp(self):
73
    testutils.GanetiTestCase.setUp(self)
74

    
75
  def testDefault(self):
76
    data = utils.ReadOneLineFile(self._TestDataFilename("cert1.pem"))
77
    self.assertEqual(len(data), 27)
78
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
79

    
80
  def testNotStrict(self):
81
    data = utils.ReadOneLineFile(self._TestDataFilename("cert1.pem"),
82
                                 strict=False)
83
    self.assertEqual(len(data), 27)
84
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
85

    
86
  def testStrictFailure(self):
87
    self.assertRaises(errors.GenericError, utils.ReadOneLineFile,
88
                      self._TestDataFilename("cert1.pem"), strict=True)
89

    
90
  def testLongLine(self):
91
    dummydata = (1024 * "Hello World! ")
92
    myfile = self._CreateTempFile()
93
    utils.WriteFile(myfile, data=dummydata)
94
    datastrict = utils.ReadOneLineFile(myfile, strict=True)
95
    datalax = utils.ReadOneLineFile(myfile, strict=False)
96
    self.assertEqual(dummydata, datastrict)
97
    self.assertEqual(dummydata, datalax)
98

    
99
  def testNewline(self):
100
    myfile = self._CreateTempFile()
101
    myline = "myline"
102
    for nl in ["", "\n", "\r\n"]:
103
      dummydata = "%s%s" % (myline, nl)
104
      utils.WriteFile(myfile, data=dummydata)
105
      datalax = utils.ReadOneLineFile(myfile, strict=False)
106
      self.assertEqual(myline, datalax)
107
      datastrict = utils.ReadOneLineFile(myfile, strict=True)
108
      self.assertEqual(myline, datastrict)
109

    
110
  def testWhitespaceAndMultipleLines(self):
111
    myfile = self._CreateTempFile()
112
    for nl in ["", "\n", "\r\n"]:
113
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
114
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
115
        utils.WriteFile(myfile, data=dummydata)
116
        datalax = utils.ReadOneLineFile(myfile, strict=False)
117
        if nl:
118
          self.assert_(set("\r\n") & set(dummydata))
119
          self.assertRaises(errors.GenericError, utils.ReadOneLineFile,
120
                            myfile, strict=True)
121
          explen = len("Foo bar baz ") + len(ws)
122
          self.assertEqual(len(datalax), explen)
123
          self.assertEqual(datalax, dummydata[:explen])
124
          self.assertFalse(set("\r\n") & set(datalax))
125
        else:
126
          datastrict = utils.ReadOneLineFile(myfile, strict=True)
127
          self.assertEqual(dummydata, datastrict)
128
          self.assertEqual(dummydata, datalax)
129

    
130
  def testEmptylines(self):
131
    myfile = self._CreateTempFile()
132
    myline = "myline"
133
    for nl in ["\n", "\r\n"]:
134
      for ol in ["", "otherline"]:
135
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
136
        utils.WriteFile(myfile, data=dummydata)
137
        self.assert_(set("\r\n") & set(dummydata))
138
        datalax = utils.ReadOneLineFile(myfile, strict=False)
139
        self.assertEqual(myline, datalax)
140
        if ol:
141
          self.assertRaises(errors.GenericError, utils.ReadOneLineFile,
142
                            myfile, strict=True)
143
        else:
144
          datastrict = utils.ReadOneLineFile(myfile, strict=True)
145
          self.assertEqual(myline, datastrict)
146

    
147
  def testEmptyfile(self):
148
    myfile = self._CreateTempFile()
149
    self.assertRaises(errors.GenericError, utils.ReadOneLineFile, myfile)
150

    
151

    
152
class TestTimestampForFilename(unittest.TestCase):
153
  def test(self):
154
    self.assert_("." not in utils.TimestampForFilename())
155
    self.assert_(":" not in utils.TimestampForFilename())
156

    
157

    
158
class TestCreateBackup(testutils.GanetiTestCase):
159
  def setUp(self):
160
    testutils.GanetiTestCase.setUp(self)
161

    
162
    self.tmpdir = tempfile.mkdtemp()
163

    
164
  def tearDown(self):
165
    testutils.GanetiTestCase.tearDown(self)
166

    
167
    shutil.rmtree(self.tmpdir)
168

    
169
  def testEmpty(self):
170
    filename = utils.PathJoin(self.tmpdir, "config.data")
171
    utils.WriteFile(filename, data="")
172
    bname = utils.CreateBackup(filename)
173
    self.assertFileContent(bname, "")
174
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
175
    utils.CreateBackup(filename)
176
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
177
    utils.CreateBackup(filename)
178
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
179

    
180
    fifoname = utils.PathJoin(self.tmpdir, "fifo")
181
    os.mkfifo(fifoname)
182
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
183

    
184
  def testContent(self):
185
    bkpcount = 0
186
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
187
      for rep in [1, 2, 10, 127]:
188
        testdata = data * rep
189

    
190
        filename = utils.PathJoin(self.tmpdir, "test.data_")
191
        utils.WriteFile(filename, data=testdata)
192
        self.assertFileContent(filename, testdata)
193

    
194
        for _ in range(3):
195
          bname = utils.CreateBackup(filename)
196
          bkpcount += 1
197
          self.assertFileContent(bname, testdata)
198
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
199

    
200

    
201
class TestListVisibleFiles(unittest.TestCase):
202
  """Test case for ListVisibleFiles"""
203

    
204
  def setUp(self):
205
    self.path = tempfile.mkdtemp()
206

    
207
  def tearDown(self):
208
    shutil.rmtree(self.path)
209

    
210
  def _CreateFiles(self, files):
211
    for name in files:
212
      utils.WriteFile(os.path.join(self.path, name), data="test")
213

    
214
  def _test(self, files, expected):
215
    self._CreateFiles(files)
216
    found = utils.ListVisibleFiles(self.path)
217
    self.assertEqual(set(found), set(expected))
218

    
219
  def testAllVisible(self):
220
    files = ["a", "b", "c"]
221
    expected = files
222
    self._test(files, expected)
223

    
224
  def testNoneVisible(self):
225
    files = [".a", ".b", ".c"]
226
    expected = []
227
    self._test(files, expected)
228

    
229
  def testSomeVisible(self):
230
    files = ["a", "b", ".c"]
231
    expected = ["a", "b"]
232
    self._test(files, expected)
233

    
234
  def testNonAbsolutePath(self):
235
    self.failUnlessRaises(errors.ProgrammerError, utils.ListVisibleFiles,
236
                          "abc")
237

    
238
  def testNonNormalizedPath(self):
239
    self.failUnlessRaises(errors.ProgrammerError, utils.ListVisibleFiles,
240
                          "/bin/../tmp")
241

    
242
  def testMountpoint(self):
243
    lvfmp_fn = compat.partial(utils.ListVisibleFiles,
244
                              _is_mountpoint=lambda _: True)
245
    self.assertEqual(lvfmp_fn(self.path), [])
246

    
247
    # Create "lost+found" as a regular file
248
    self._CreateFiles(["foo", "bar", ".baz", "lost+found"])
249
    self.assertEqual(set(lvfmp_fn(self.path)),
250
                     set(["foo", "bar", "lost+found"]))
251

    
252
    # Replace "lost+found" with a directory
253
    laf_path = utils.PathJoin(self.path, "lost+found")
254
    utils.RemoveFile(laf_path)
255
    os.mkdir(laf_path)
256
    self.assertEqual(set(lvfmp_fn(self.path)), set(["foo", "bar"]))
257

    
258
  def testLostAndFoundNoMountpoint(self):
259
    files = ["foo", "bar", ".Hello World", "lost+found"]
260
    expected = ["foo", "bar", "lost+found"]
261
    self._test(files, expected)
262

    
263

    
264
class TestWriteFile(unittest.TestCase):
265
  def setUp(self):
266
    self.tmpdir = None
267
    self.tfile = tempfile.NamedTemporaryFile()
268
    self.did_pre = False
269
    self.did_post = False
270
    self.did_write = False
271

    
272
  def tearDown(self):
273
    if self.tmpdir:
274
      shutil.rmtree(self.tmpdir)
275

    
276
  def markPre(self, fd):
277
    self.did_pre = True
278

    
279
  def markPost(self, fd):
280
    self.did_post = True
281

    
282
  def markWrite(self, fd):
283
    self.did_write = True
284

    
285
  def testWrite(self):
286
    data = "abc"
287
    utils.WriteFile(self.tfile.name, data=data)
288
    self.assertEqual(utils.ReadFile(self.tfile.name), data)
289

    
290
  def testWriteSimpleUnicode(self):
291
    data = u"abc"
292
    utils.WriteFile(self.tfile.name, data=data)
293
    self.assertEqual(utils.ReadFile(self.tfile.name), data)
294

    
295
  def testErrors(self):
296
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
297
                      self.tfile.name, data="test", fn=lambda fd: None)
298
    self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name)
299
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
300
                      self.tfile.name, data="test", atime=0)
301

    
302
  def testPreWrite(self):
303
    utils.WriteFile(self.tfile.name, data="", prewrite=self.markPre)
304
    self.assertTrue(self.did_pre)
305
    self.assertFalse(self.did_post)
306
    self.assertFalse(self.did_write)
307

    
308
  def testPostWrite(self):
309
    utils.WriteFile(self.tfile.name, data="", postwrite=self.markPost)
310
    self.assertFalse(self.did_pre)
311
    self.assertTrue(self.did_post)
312
    self.assertFalse(self.did_write)
313

    
314
  def testWriteFunction(self):
315
    utils.WriteFile(self.tfile.name, fn=self.markWrite)
316
    self.assertFalse(self.did_pre)
317
    self.assertFalse(self.did_post)
318
    self.assertTrue(self.did_write)
319

    
320
  def testDryRun(self):
321
    orig = "abc"
322
    self.tfile.write(orig)
323
    self.tfile.flush()
324
    utils.WriteFile(self.tfile.name, data="hello", dry_run=True)
325
    self.assertEqual(utils.ReadFile(self.tfile.name), orig)
326

    
327
  def testTimes(self):
328
    f = self.tfile.name
329
    for at, mt in [(0, 0), (1000, 1000), (2000, 3000),
330
                   (int(time.time()), 5000)]:
331
      utils.WriteFile(f, data="hello", atime=at, mtime=mt)
332
      st = os.stat(f)
333
      self.assertEqual(st.st_atime, at)
334
      self.assertEqual(st.st_mtime, mt)
335

    
336
  def testNoClose(self):
337
    data = "hello"
338
    self.assertEqual(utils.WriteFile(self.tfile.name, data="abc"), None)
339
    fd = utils.WriteFile(self.tfile.name, data=data, close=False)
340
    try:
341
      os.lseek(fd, 0, 0)
342
      self.assertEqual(os.read(fd, 4096), data)
343
    finally:
344
      os.close(fd)
345

    
346
  def testNoLeftovers(self):
347
    self.tmpdir = tempfile.mkdtemp()
348
    self.assertEqual(utils.WriteFile(utils.PathJoin(self.tmpdir, "test"),
349
                                     data="abc"),
350
                     None)
351
    self.assertEqual(os.listdir(self.tmpdir), ["test"])
352

    
353
  def testFailRename(self):
354
    self.tmpdir = tempfile.mkdtemp()
355
    target = utils.PathJoin(self.tmpdir, "target")
356
    os.mkdir(target)
357
    self.assertRaises(OSError, utils.WriteFile, target, data="abc")
358
    self.assertTrue(os.path.isdir(target))
359
    self.assertEqual(os.listdir(self.tmpdir), ["target"])
360
    self.assertFalse(os.listdir(target))
361

    
362
  def testFailRenameDryRun(self):
363
    self.tmpdir = tempfile.mkdtemp()
364
    target = utils.PathJoin(self.tmpdir, "target")
365
    os.mkdir(target)
366
    self.assertEqual(utils.WriteFile(target, data="abc", dry_run=True), None)
367
    self.assertTrue(os.path.isdir(target))
368
    self.assertEqual(os.listdir(self.tmpdir), ["target"])
369
    self.assertFalse(os.listdir(target))
370

    
371
  def testBackup(self):
372
    self.tmpdir = tempfile.mkdtemp()
373
    testfile = utils.PathJoin(self.tmpdir, "test")
374

    
375
    self.assertEqual(utils.WriteFile(testfile, data="foo", backup=True), None)
376
    self.assertEqual(utils.ReadFile(testfile), "foo")
377
    self.assertEqual(os.listdir(self.tmpdir), ["test"])
378

    
379
    # Write again
380
    assert os.path.isfile(testfile)
381
    self.assertEqual(utils.WriteFile(testfile, data="bar", backup=True), None)
382
    self.assertEqual(utils.ReadFile(testfile), "bar")
383
    self.assertEqual(len(glob.glob("%s.backup*" % testfile)), 1)
384
    self.assertTrue("test" in os.listdir(self.tmpdir))
385
    self.assertEqual(len(os.listdir(self.tmpdir)), 2)
386

    
387
    # Write again as dry-run
388
    assert os.path.isfile(testfile)
389
    self.assertEqual(utils.WriteFile(testfile, data="000", backup=True,
390
                                     dry_run=True),
391
                     None)
392
    self.assertEqual(utils.ReadFile(testfile), "bar")
393
    self.assertEqual(len(glob.glob("%s.backup*" % testfile)), 1)
394
    self.assertTrue("test" in os.listdir(self.tmpdir))
395
    self.assertEqual(len(os.listdir(self.tmpdir)), 2)
396

    
397

    
398
class TestFileID(testutils.GanetiTestCase):
399
  def testEquality(self):
400
    name = self._CreateTempFile()
401
    oldi = utils.GetFileID(path=name)
402
    self.failUnless(utils.VerifyFileID(oldi, oldi))
403

    
404
  def testUpdate(self):
405
    name = self._CreateTempFile()
406
    oldi = utils.GetFileID(path=name)
407
    fd = os.open(name, os.O_RDWR)
408
    try:
409
      newi = utils.GetFileID(fd=fd)
410
      self.failUnless(utils.VerifyFileID(oldi, newi))
411
      self.failUnless(utils.VerifyFileID(newi, oldi))
412
    finally:
413
      os.close(fd)
414

    
415
  def testWriteFile(self):
416
    name = self._CreateTempFile()
417
    oldi = utils.GetFileID(path=name)
418
    mtime = oldi[2]
419
    os.utime(name, (mtime + 10, mtime + 10))
420
    self.assertRaises(errors.LockError, utils.SafeWriteFile, name,
421
                      oldi, data="")
422
    os.utime(name, (mtime - 10, mtime - 10))
423
    utils.SafeWriteFile(name, oldi, data="")
424
    oldi = utils.GetFileID(path=name)
425
    mtime = oldi[2]
426
    os.utime(name, (mtime + 10, mtime + 10))
427
    # this doesn't raise, since we passed None
428
    utils.SafeWriteFile(name, None, data="")
429

    
430
  def testError(self):
431
    t = tempfile.NamedTemporaryFile()
432
    self.assertRaises(errors.ProgrammerError, utils.GetFileID,
433
                      path=t.name, fd=t.fileno())
434

    
435

    
436
class TestRemoveFile(unittest.TestCase):
437
  """Test case for the RemoveFile function"""
438

    
439
  def setUp(self):
440
    """Create a temp dir and file for each case"""
441
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
442
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
443
    os.close(fd)
444

    
445
  def tearDown(self):
446
    if os.path.exists(self.tmpfile):
447
      os.unlink(self.tmpfile)
448
    os.rmdir(self.tmpdir)
449

    
450
  def testIgnoreDirs(self):
451
    """Test that RemoveFile() ignores directories"""
452
    self.assertEqual(None, utils.RemoveFile(self.tmpdir))
453

    
454
  def testIgnoreNotExisting(self):
455
    """Test that RemoveFile() ignores non-existing files"""
456
    utils.RemoveFile(self.tmpfile)
457
    utils.RemoveFile(self.tmpfile)
458

    
459
  def testRemoveFile(self):
460
    """Test that RemoveFile does remove a file"""
461
    utils.RemoveFile(self.tmpfile)
462
    if os.path.exists(self.tmpfile):
463
      self.fail("File '%s' not removed" % self.tmpfile)
464

    
465
  def testRemoveSymlink(self):
466
    """Test that RemoveFile does remove symlinks"""
467
    symlink = self.tmpdir + "/symlink"
468
    os.symlink("no-such-file", symlink)
469
    utils.RemoveFile(symlink)
470
    if os.path.exists(symlink):
471
      self.fail("File '%s' not removed" % symlink)
472
    os.symlink(self.tmpfile, symlink)
473
    utils.RemoveFile(symlink)
474
    if os.path.exists(symlink):
475
      self.fail("File '%s' not removed" % symlink)
476

    
477

    
478
class TestRemoveDir(unittest.TestCase):
479
  def setUp(self):
480
    self.tmpdir = tempfile.mkdtemp()
481

    
482
  def tearDown(self):
483
    try:
484
      shutil.rmtree(self.tmpdir)
485
    except EnvironmentError:
486
      pass
487

    
488
  def testEmptyDir(self):
489
    utils.RemoveDir(self.tmpdir)
490
    self.assertFalse(os.path.isdir(self.tmpdir))
491

    
492
  def testNonEmptyDir(self):
493
    self.tmpfile = os.path.join(self.tmpdir, "test1")
494
    open(self.tmpfile, "w").close()
495
    self.assertRaises(EnvironmentError, utils.RemoveDir, self.tmpdir)
496

    
497

    
498
class TestRename(unittest.TestCase):
499
  """Test case for RenameFile"""
500

    
501
  def setUp(self):
502
    """Create a temporary directory"""
503
    self.tmpdir = tempfile.mkdtemp()
504
    self.tmpfile = os.path.join(self.tmpdir, "test1")
505

    
506
    # Touch the file
507
    open(self.tmpfile, "w").close()
508

    
509
  def tearDown(self):
510
    """Remove temporary directory"""
511
    shutil.rmtree(self.tmpdir)
512

    
513
  def testSimpleRename1(self):
514
    """Simple rename 1"""
515
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
516
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
517

    
518
  def testSimpleRename2(self):
519
    """Simple rename 2"""
520
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
521
                     mkdir=True)
522
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
523

    
524
  def testRenameMkdir(self):
525
    """Rename with mkdir"""
526
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
527
                     mkdir=True)
528
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
529
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
530

    
531
    self.assertRaises(EnvironmentError, utils.RenameFile,
532
                      os.path.join(self.tmpdir, "test/xyz"),
533
                      os.path.join(self.tmpdir, "test/foo/bar/baz"),
534
                      mkdir=True)
535

    
536
    self.assertTrue(os.path.exists(os.path.join(self.tmpdir, "test/xyz")))
537
    self.assertFalse(os.path.exists(os.path.join(self.tmpdir, "test/foo/bar")))
538
    self.assertFalse(os.path.exists(os.path.join(self.tmpdir,
539
                                                 "test/foo/bar/baz")))
540

    
541

    
542
class TestMakedirs(unittest.TestCase):
543
  def setUp(self):
544
    self.tmpdir = tempfile.mkdtemp()
545

    
546
  def tearDown(self):
547
    shutil.rmtree(self.tmpdir)
548

    
549
  def testNonExisting(self):
550
    path = utils.PathJoin(self.tmpdir, "foo")
551
    utils.Makedirs(path)
552
    self.assert_(os.path.isdir(path))
553

    
554
  def testExisting(self):
555
    path = utils.PathJoin(self.tmpdir, "foo")
556
    os.mkdir(path)
557
    utils.Makedirs(path)
558
    self.assert_(os.path.isdir(path))
559

    
560
  def testRecursiveNonExisting(self):
561
    path = utils.PathJoin(self.tmpdir, "foo/bar/baz")
562
    utils.Makedirs(path)
563
    self.assert_(os.path.isdir(path))
564

    
565
  def testRecursiveExisting(self):
566
    path = utils.PathJoin(self.tmpdir, "B/moo/xyz")
567
    self.assertFalse(os.path.exists(path))
568
    os.mkdir(utils.PathJoin(self.tmpdir, "B"))
569
    utils.Makedirs(path)
570
    self.assert_(os.path.isdir(path))
571

    
572

    
573
class TestEnsureDirs(unittest.TestCase):
574
  """Tests for EnsureDirs"""
575

    
576
  def setUp(self):
577
    self.dir = tempfile.mkdtemp()
578
    self.old_umask = os.umask(0777)
579

    
580
  def testEnsureDirs(self):
581
    utils.EnsureDirs([
582
        (utils.PathJoin(self.dir, "foo"), 0777),
583
        (utils.PathJoin(self.dir, "bar"), 0000),
584
        ])
585
    self.assertEquals(os.stat(utils.PathJoin(self.dir, "foo"))[0] & 0777, 0777)
586
    self.assertEquals(os.stat(utils.PathJoin(self.dir, "bar"))[0] & 0777, 0000)
587

    
588
  def tearDown(self):
589
    os.rmdir(utils.PathJoin(self.dir, "foo"))
590
    os.rmdir(utils.PathJoin(self.dir, "bar"))
591
    os.rmdir(self.dir)
592
    os.umask(self.old_umask)
593

    
594

    
595
class TestIsNormAbsPath(unittest.TestCase):
596
  """Testing case for IsNormAbsPath"""
597

    
598
  def _pathTestHelper(self, path, result):
599
    if result:
600
      self.assert_(utils.IsNormAbsPath(path),
601
          "Path %s should result absolute and normalized" % path)
602
    else:
603
      self.assertFalse(utils.IsNormAbsPath(path),
604
          "Path %s should not result absolute and normalized" % path)
605

    
606
  def testBase(self):
607
    self._pathTestHelper("/etc", True)
608
    self._pathTestHelper("/srv", True)
609
    self._pathTestHelper("etc", False)
610
    self._pathTestHelper("/etc/../root", False)
611
    self._pathTestHelper("/etc/", False)
612

    
613

    
614
class TestIsBelowDir(unittest.TestCase):
615
  """Testing case for IsBelowDir"""
616

    
617
  def testSamePrefix(self):
618
    self.assertTrue(utils.IsBelowDir("/a/b", "/a/b/c"))
619
    self.assertTrue(utils.IsBelowDir("/a/b/", "/a/b/e"))
620

    
621
  def testSamePrefixButDifferentDir(self):
622
    self.assertFalse(utils.IsBelowDir("/a/b", "/a/bc/d"))
623
    self.assertFalse(utils.IsBelowDir("/a/b/", "/a/bc/e"))
624

    
625
  def testSamePrefixButDirTraversal(self):
626
    self.assertFalse(utils.IsBelowDir("/a/b", "/a/b/../c"))
627
    self.assertFalse(utils.IsBelowDir("/a/b/", "/a/b/../d"))
628

    
629
  def testSamePrefixAndTraversal(self):
630
    self.assertTrue(utils.IsBelowDir("/a/b", "/a/b/c/../d"))
631
    self.assertTrue(utils.IsBelowDir("/a/b", "/a/b/c/./e"))
632
    self.assertTrue(utils.IsBelowDir("/a/b", "/a/b/../b/./e"))
633

    
634
  def testBothAbsPath(self):
635
    self.assertRaises(ValueError, utils.IsBelowDir, "/a/b/c", "d")
636
    self.assertRaises(ValueError, utils.IsBelowDir, "a/b/c", "/d")
637
    self.assertRaises(ValueError, utils.IsBelowDir, "a/b/c", "d")
638

    
639

    
640
class TestPathJoin(unittest.TestCase):
641
  """Testing case for PathJoin"""
642

    
643
  def testBasicItems(self):
644
    mlist = ["/a", "b", "c"]
645
    self.failUnlessEqual(utils.PathJoin(*mlist), "/".join(mlist))
646

    
647
  def testNonAbsPrefix(self):
648
    self.failUnlessRaises(ValueError, utils.PathJoin, "a", "b")
649

    
650
  def testBackTrack(self):
651
    self.failUnlessRaises(ValueError, utils.PathJoin, "/a", "b/../c")
652

    
653
  def testMultiAbs(self):
654
    self.failUnlessRaises(ValueError, utils.PathJoin, "/a", "/b")
655

    
656

    
657
class TestTailFile(testutils.GanetiTestCase):
658
  """Test case for the TailFile function"""
659

    
660
  def testEmpty(self):
661
    fname = self._CreateTempFile()
662
    self.failUnlessEqual(utils.TailFile(fname), [])
663
    self.failUnlessEqual(utils.TailFile(fname, lines=25), [])
664

    
665
  def testAllLines(self):
666
    data = ["test %d" % i for i in range(30)]
667
    for i in range(30):
668
      fname = self._CreateTempFile()
669
      fd = open(fname, "w")
670
      fd.write("\n".join(data[:i]))
671
      if i > 0:
672
        fd.write("\n")
673
      fd.close()
674
      self.failUnlessEqual(utils.TailFile(fname, lines=i), data[:i])
675

    
676
  def testPartialLines(self):
677
    data = ["test %d" % i for i in range(30)]
678
    fname = self._CreateTempFile()
679
    fd = open(fname, "w")
680
    fd.write("\n".join(data))
681
    fd.write("\n")
682
    fd.close()
683
    for i in range(1, 30):
684
      self.failUnlessEqual(utils.TailFile(fname, lines=i), data[-i:])
685

    
686
  def testBigFile(self):
687
    data = ["test %d" % i for i in range(30)]
688
    fname = self._CreateTempFile()
689
    fd = open(fname, "w")
690
    fd.write("X" * 1048576)
691
    fd.write("\n")
692
    fd.write("\n".join(data))
693
    fd.write("\n")
694
    fd.close()
695
    for i in range(1, 30):
696
      self.failUnlessEqual(utils.TailFile(fname, lines=i), data[-i:])
697

    
698

    
699
class TestPidFileFunctions(unittest.TestCase):
700
  """Tests for WritePidFile and ReadPidFile"""
701

    
702
  def setUp(self):
703
    self.dir = tempfile.mkdtemp()
704
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
705

    
706
  def testPidFileFunctions(self):
707
    pid_file = self.f_dpn('test')
708
    fd = utils.WritePidFile(self.f_dpn('test'))
709
    self.failUnless(os.path.exists(pid_file),
710
                    "PID file should have been created")
711
    read_pid = utils.ReadPidFile(pid_file)
712
    self.failUnlessEqual(read_pid, os.getpid())
713
    self.failUnless(utils.IsProcessAlive(read_pid))
714
    self.failUnlessRaises(errors.LockError, utils.WritePidFile,
715
                          self.f_dpn('test'))
716
    os.close(fd)
717
    utils.RemoveFile(self.f_dpn("test"))
718
    self.failIf(os.path.exists(pid_file),
719
                "PID file should not exist anymore")
720
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
721
                         "ReadPidFile should return 0 for missing pid file")
722
    fh = open(pid_file, "w")
723
    fh.write("blah\n")
724
    fh.close()
725
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
726
                         "ReadPidFile should return 0 for invalid pid file")
727
    # but now, even with the file existing, we should be able to lock it
728
    fd = utils.WritePidFile(self.f_dpn('test'))
729
    os.close(fd)
730
    utils.RemoveFile(self.f_dpn("test"))
731
    self.failIf(os.path.exists(pid_file),
732
                "PID file should not exist anymore")
733

    
734
  def testKill(self):
735
    pid_file = self.f_dpn('child')
736
    r_fd, w_fd = os.pipe()
737
    new_pid = os.fork()
738
    if new_pid == 0: #child
739
      utils.WritePidFile(self.f_dpn('child'))
740
      os.write(w_fd, 'a')
741
      signal.pause()
742
      os._exit(0)
743
      return
744
    # else we are in the parent
745
    # wait until the child has written the pid file
746
    os.read(r_fd, 1)
747
    read_pid = utils.ReadPidFile(pid_file)
748
    self.failUnlessEqual(read_pid, new_pid)
749
    self.failUnless(utils.IsProcessAlive(new_pid))
750
    utils.KillProcess(new_pid, waitpid=True)
751
    self.failIf(utils.IsProcessAlive(new_pid))
752
    utils.RemoveFile(self.f_dpn('child'))
753
    self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)
754

    
755
  def tearDown(self):
756
    shutil.rmtree(self.dir)
757

    
758

    
759
class TestSshKeys(testutils.GanetiTestCase):
760
  """Test case for the AddAuthorizedKey function"""
761

    
762
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
763
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
764
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
765

    
766
  def setUp(self):
767
    testutils.GanetiTestCase.setUp(self)
768
    self.tmpname = self._CreateTempFile()
769
    handle = open(self.tmpname, 'w')
770
    try:
771
      handle.write("%s\n" % TestSshKeys.KEY_A)
772
      handle.write("%s\n" % TestSshKeys.KEY_B)
773
    finally:
774
      handle.close()
775

    
776
  def testAddingNewKey(self):
777
    utils.AddAuthorizedKey(self.tmpname,
778
                           'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
779

    
780
    self.assertFileContent(self.tmpname,
781
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
782
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
783
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
784
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
785

    
786
  def testAddingAlmostButNotCompletelyTheSameKey(self):
787
    utils.AddAuthorizedKey(self.tmpname,
788
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
789

    
790
    self.assertFileContent(self.tmpname,
791
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
792
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
793
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
794
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
795

    
796
  def testAddingExistingKeyWithSomeMoreSpaces(self):
797
    utils.AddAuthorizedKey(self.tmpname,
798
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
799

    
800
    self.assertFileContent(self.tmpname,
801
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
802
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
803
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
804

    
805
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
806
    utils.RemoveAuthorizedKey(self.tmpname,
807
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
808

    
809
    self.assertFileContent(self.tmpname,
810
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
811
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
812

    
813
  def testRemovingNonExistingKey(self):
814
    utils.RemoveAuthorizedKey(self.tmpname,
815
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
816

    
817
    self.assertFileContent(self.tmpname,
818
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
819
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
820
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
821

    
822

    
823
class TestNewUUID(unittest.TestCase):
824
  """Test case for NewUUID"""
825

    
826
  def runTest(self):
827
    self.failUnless(utils.UUID_RE.match(utils.NewUUID()))
828

    
829

    
830
def _MockStatResult(cb, mode, uid, gid):
831
  def _fn(path):
832
    if cb:
833
      cb()
834
    return {
835
      stat.ST_MODE: mode,
836
      stat.ST_UID: uid,
837
      stat.ST_GID: gid,
838
      }
839
  return _fn
840

    
841

    
842
def _RaiseNoEntError():
843
  raise EnvironmentError(errno.ENOENT, "not found")
844

    
845

    
846
def _OtherStatRaise():
847
  raise EnvironmentError()
848

    
849

    
850
class TestPermissionEnforcements(unittest.TestCase):
851
  UID_A = 16024
852
  UID_B = 25850
853
  GID_A = 14028
854
  GID_B = 29801
855

    
856
  def setUp(self):
857
    self._chown_calls = []
858
    self._chmod_calls = []
859
    self._mkdir_calls = []
860

    
861
  def tearDown(self):
862
    self.assertRaises(IndexError, self._mkdir_calls.pop)
863
    self.assertRaises(IndexError, self._chmod_calls.pop)
864
    self.assertRaises(IndexError, self._chown_calls.pop)
865

    
866
  def _FakeMkdir(self, path):
867
    self._mkdir_calls.append(path)
868

    
869
  def _FakeChown(self, path, uid, gid):
870
    self._chown_calls.append((path, uid, gid))
871

    
872
  def _ChmodWrapper(self, cb):
873
    def _fn(path, mode):
874
      self._chmod_calls.append((path, mode))
875
      if cb:
876
        cb()
877
    return _fn
878

    
879
  def _VerifyPerm(self, path, mode, uid=-1, gid=-1):
880
    self.assertEqual(path, "/ganeti-qa-non-test")
881
    self.assertEqual(mode, 0700)
882
    self.assertEqual(uid, self.UID_A)
883
    self.assertEqual(gid, self.GID_A)
884

    
885
  def testMakeDirWithPerm(self):
886
    is_dir_stat = _MockStatResult(None, stat.S_IFDIR, 0, 0)
887
    utils.MakeDirWithPerm("/ganeti-qa-non-test", 0700, self.UID_A, self.GID_A,
888
                          _lstat_fn=is_dir_stat, _perm_fn=self._VerifyPerm)
889

    
890
  def testDirErrors(self):
891
    self.assertRaises(errors.GenericError, utils.MakeDirWithPerm,
892
                      "/ganeti-qa-non-test", 0700, 0, 0,
893
                      _lstat_fn=_MockStatResult(None, 0, 0, 0))
894
    self.assertRaises(IndexError, self._mkdir_calls.pop)
895

    
896
    other_stat_raise = _MockStatResult(_OtherStatRaise, stat.S_IFDIR, 0, 0)
897
    self.assertRaises(errors.GenericError, utils.MakeDirWithPerm,
898
                      "/ganeti-qa-non-test", 0700, 0, 0,
899
                      _lstat_fn=other_stat_raise)
900
    self.assertRaises(IndexError, self._mkdir_calls.pop)
901

    
902
    non_exist_stat = _MockStatResult(_RaiseNoEntError, stat.S_IFDIR, 0, 0)
903
    utils.MakeDirWithPerm("/ganeti-qa-non-test", 0700, self.UID_A, self.GID_A,
904
                          _lstat_fn=non_exist_stat, _mkdir_fn=self._FakeMkdir,
905
                          _perm_fn=self._VerifyPerm)
906
    self.assertEqual(self._mkdir_calls.pop(0), "/ganeti-qa-non-test")
907

    
908
  def testEnforcePermissionNoEnt(self):
909
    self.assertRaises(errors.GenericError, utils.EnforcePermission,
910
                      "/ganeti-qa-non-test", 0600,
911
                      _chmod_fn=NotImplemented, _chown_fn=NotImplemented,
912
                      _stat_fn=_MockStatResult(_RaiseNoEntError, 0, 0, 0))
913

    
914
  def testEnforcePermissionNoEntMustNotExist(self):
915
    utils.EnforcePermission("/ganeti-qa-non-test", 0600, must_exist=False,
916
                            _chmod_fn=NotImplemented,
917
                            _chown_fn=NotImplemented,
918
                            _stat_fn=_MockStatResult(_RaiseNoEntError,
919
                                                          0, 0, 0))
920

    
921
  def testEnforcePermissionOtherErrorMustNotExist(self):
922
    self.assertRaises(errors.GenericError, utils.EnforcePermission,
923
                      "/ganeti-qa-non-test", 0600, must_exist=False,
924
                      _chmod_fn=NotImplemented, _chown_fn=NotImplemented,
925
                      _stat_fn=_MockStatResult(_OtherStatRaise, 0, 0, 0))
926

    
927
  def testEnforcePermissionNoChanges(self):
928
    utils.EnforcePermission("/ganeti-qa-non-test", 0600,
929
                            _stat_fn=_MockStatResult(None, 0600, 0, 0),
930
                            _chmod_fn=self._ChmodWrapper(None),
931
                            _chown_fn=self._FakeChown)
932

    
933
  def testEnforcePermissionChangeMode(self):
934
    utils.EnforcePermission("/ganeti-qa-non-test", 0444,
935
                            _stat_fn=_MockStatResult(None, 0600, 0, 0),
936
                            _chmod_fn=self._ChmodWrapper(None),
937
                            _chown_fn=self._FakeChown)
938
    self.assertEqual(self._chmod_calls.pop(0), ("/ganeti-qa-non-test", 0444))
939

    
940
  def testEnforcePermissionSetUidGid(self):
941
    utils.EnforcePermission("/ganeti-qa-non-test", 0600,
942
                            uid=self.UID_B, gid=self.GID_B,
943
                            _stat_fn=_MockStatResult(None, 0600,
944
                                                     self.UID_A,
945
                                                     self.GID_A),
946
                            _chmod_fn=self._ChmodWrapper(None),
947
                            _chown_fn=self._FakeChown)
948
    self.assertEqual(self._chown_calls.pop(0),
949
                     ("/ganeti-qa-non-test", self.UID_B, self.GID_B))
950

    
951

    
952
if __name__ == "__main__":
953
  testutils.GanetiTestProgram()