Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils.io_unittest.py @ 3865ca48

History | View | Annotate | Download (22.3 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010, 2011 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.utils.io"""
23

    
24
import os
25
import tempfile
26
import unittest
27
import shutil
28
import glob
29
import time
30

    
31
from ganeti import constants
32
from ganeti import utils
33
from ganeti import compat
34
from ganeti import errors
35

    
36
import testutils
37

    
38

    
39
class TestReadFile(testutils.GanetiTestCase):
40
  def testReadAll(self):
41
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
42
    self.assertEqual(len(data), 814)
43

    
44
    h = compat.md5_hash()
45
    h.update(data)
46
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
47

    
48
  def testReadSize(self):
49
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
50
                          size=100)
51
    self.assertEqual(len(data), 100)
52

    
53
    h = compat.md5_hash()
54
    h.update(data)
55
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
56

    
57
  def testError(self):
58
    self.assertRaises(EnvironmentError, utils.ReadFile,
59
                      "/dev/null/does-not-exist")
60

    
61

    
62
class TestReadOneLineFile(testutils.GanetiTestCase):
63
  def setUp(self):
64
    testutils.GanetiTestCase.setUp(self)
65

    
66
  def testDefault(self):
67
    data = utils.ReadOneLineFile(self._TestDataFilename("cert1.pem"))
68
    self.assertEqual(len(data), 27)
69
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
70

    
71
  def testNotStrict(self):
72
    data = utils.ReadOneLineFile(self._TestDataFilename("cert1.pem"),
73
                                 strict=False)
74
    self.assertEqual(len(data), 27)
75
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
76

    
77
  def testStrictFailure(self):
78
    self.assertRaises(errors.GenericError, utils.ReadOneLineFile,
79
                      self._TestDataFilename("cert1.pem"), strict=True)
80

    
81
  def testLongLine(self):
82
    dummydata = (1024 * "Hello World! ")
83
    myfile = self._CreateTempFile()
84
    utils.WriteFile(myfile, data=dummydata)
85
    datastrict = utils.ReadOneLineFile(myfile, strict=True)
86
    datalax = utils.ReadOneLineFile(myfile, strict=False)
87
    self.assertEqual(dummydata, datastrict)
88
    self.assertEqual(dummydata, datalax)
89

    
90
  def testNewline(self):
91
    myfile = self._CreateTempFile()
92
    myline = "myline"
93
    for nl in ["", "\n", "\r\n"]:
94
      dummydata = "%s%s" % (myline, nl)
95
      utils.WriteFile(myfile, data=dummydata)
96
      datalax = utils.ReadOneLineFile(myfile, strict=False)
97
      self.assertEqual(myline, datalax)
98
      datastrict = utils.ReadOneLineFile(myfile, strict=True)
99
      self.assertEqual(myline, datastrict)
100

    
101
  def testWhitespaceAndMultipleLines(self):
102
    myfile = self._CreateTempFile()
103
    for nl in ["", "\n", "\r\n"]:
104
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
105
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
106
        utils.WriteFile(myfile, data=dummydata)
107
        datalax = utils.ReadOneLineFile(myfile, strict=False)
108
        if nl:
109
          self.assert_(set("\r\n") & set(dummydata))
110
          self.assertRaises(errors.GenericError, utils.ReadOneLineFile,
111
                            myfile, strict=True)
112
          explen = len("Foo bar baz ") + len(ws)
113
          self.assertEqual(len(datalax), explen)
114
          self.assertEqual(datalax, dummydata[:explen])
115
          self.assertFalse(set("\r\n") & set(datalax))
116
        else:
117
          datastrict = utils.ReadOneLineFile(myfile, strict=True)
118
          self.assertEqual(dummydata, datastrict)
119
          self.assertEqual(dummydata, datalax)
120

    
121
  def testEmptylines(self):
122
    myfile = self._CreateTempFile()
123
    myline = "myline"
124
    for nl in ["\n", "\r\n"]:
125
      for ol in ["", "otherline"]:
126
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
127
        utils.WriteFile(myfile, data=dummydata)
128
        self.assert_(set("\r\n") & set(dummydata))
129
        datalax = utils.ReadOneLineFile(myfile, strict=False)
130
        self.assertEqual(myline, datalax)
131
        if ol:
132
          self.assertRaises(errors.GenericError, utils.ReadOneLineFile,
133
                            myfile, strict=True)
134
        else:
135
          datastrict = utils.ReadOneLineFile(myfile, strict=True)
136
          self.assertEqual(myline, datastrict)
137

    
138
  def testEmptyfile(self):
139
    myfile = self._CreateTempFile()
140
    self.assertRaises(errors.GenericError, utils.ReadOneLineFile, myfile)
141

    
142

    
143
class TestTimestampForFilename(unittest.TestCase):
144
  def test(self):
145
    self.assert_("." not in utils.TimestampForFilename())
146
    self.assert_(":" not in utils.TimestampForFilename())
147

    
148

    
149
class TestCreateBackup(testutils.GanetiTestCase):
150
  def setUp(self):
151
    testutils.GanetiTestCase.setUp(self)
152

    
153
    self.tmpdir = tempfile.mkdtemp()
154

    
155
  def tearDown(self):
156
    testutils.GanetiTestCase.tearDown(self)
157

    
158
    shutil.rmtree(self.tmpdir)
159

    
160
  def testEmpty(self):
161
    filename = utils.PathJoin(self.tmpdir, "config.data")
162
    utils.WriteFile(filename, data="")
163
    bname = utils.CreateBackup(filename)
164
    self.assertFileContent(bname, "")
165
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
166
    utils.CreateBackup(filename)
167
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
168
    utils.CreateBackup(filename)
169
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
170

    
171
    fifoname = utils.PathJoin(self.tmpdir, "fifo")
172
    os.mkfifo(fifoname)
173
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
174

    
175
  def testContent(self):
176
    bkpcount = 0
177
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
178
      for rep in [1, 2, 10, 127]:
179
        testdata = data * rep
180

    
181
        filename = utils.PathJoin(self.tmpdir, "test.data_")
182
        utils.WriteFile(filename, data=testdata)
183
        self.assertFileContent(filename, testdata)
184

    
185
        for _ in range(3):
186
          bname = utils.CreateBackup(filename)
187
          bkpcount += 1
188
          self.assertFileContent(bname, testdata)
189
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
190

    
191

    
192
class TestListVisibleFiles(unittest.TestCase):
193
  """Test case for ListVisibleFiles"""
194

    
195
  def setUp(self):
196
    self.path = tempfile.mkdtemp()
197

    
198
  def tearDown(self):
199
    shutil.rmtree(self.path)
200

    
201
  def _CreateFiles(self, files):
202
    for name in files:
203
      utils.WriteFile(os.path.join(self.path, name), data="test")
204

    
205
  def _test(self, files, expected):
206
    self._CreateFiles(files)
207
    found = utils.ListVisibleFiles(self.path)
208
    self.assertEqual(set(found), set(expected))
209

    
210
  def testAllVisible(self):
211
    files = ["a", "b", "c"]
212
    expected = files
213
    self._test(files, expected)
214

    
215
  def testNoneVisible(self):
216
    files = [".a", ".b", ".c"]
217
    expected = []
218
    self._test(files, expected)
219

    
220
  def testSomeVisible(self):
221
    files = ["a", "b", ".c"]
222
    expected = ["a", "b"]
223
    self._test(files, expected)
224

    
225
  def testNonAbsolutePath(self):
226
    self.failUnlessRaises(errors.ProgrammerError, utils.ListVisibleFiles,
227
                          "abc")
228

    
229
  def testNonNormalizedPath(self):
230
    self.failUnlessRaises(errors.ProgrammerError, utils.ListVisibleFiles,
231
                          "/bin/../tmp")
232

    
233

    
234
class TestWriteFile(unittest.TestCase):
235
  def setUp(self):
236
    self.tfile = tempfile.NamedTemporaryFile()
237
    self.did_pre = False
238
    self.did_post = False
239
    self.did_write = False
240

    
241
  def markPre(self, fd):
242
    self.did_pre = True
243

    
244
  def markPost(self, fd):
245
    self.did_post = True
246

    
247
  def markWrite(self, fd):
248
    self.did_write = True
249

    
250
  def testWrite(self):
251
    data = "abc"
252
    utils.WriteFile(self.tfile.name, data=data)
253
    self.assertEqual(utils.ReadFile(self.tfile.name), data)
254

    
255
  def testErrors(self):
256
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
257
                      self.tfile.name, data="test", fn=lambda fd: None)
258
    self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name)
259
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
260
                      self.tfile.name, data="test", atime=0)
261

    
262
  def testCalls(self):
263
    utils.WriteFile(self.tfile.name, fn=self.markWrite,
264
                    prewrite=self.markPre, postwrite=self.markPost)
265
    self.assertTrue(self.did_pre)
266
    self.assertTrue(self.did_post)
267
    self.assertTrue(self.did_write)
268

    
269
  def testDryRun(self):
270
    orig = "abc"
271
    self.tfile.write(orig)
272
    self.tfile.flush()
273
    utils.WriteFile(self.tfile.name, data="hello", dry_run=True)
274
    self.assertEqual(utils.ReadFile(self.tfile.name), orig)
275

    
276
  def testTimes(self):
277
    f = self.tfile.name
278
    for at, mt in [(0, 0), (1000, 1000), (2000, 3000),
279
                   (int(time.time()), 5000)]:
280
      utils.WriteFile(f, data="hello", atime=at, mtime=mt)
281
      st = os.stat(f)
282
      self.assertEqual(st.st_atime, at)
283
      self.assertEqual(st.st_mtime, mt)
284

    
285
  def testNoClose(self):
286
    data = "hello"
287
    self.assertEqual(utils.WriteFile(self.tfile.name, data="abc"), None)
288
    fd = utils.WriteFile(self.tfile.name, data=data, close=False)
289
    try:
290
      os.lseek(fd, 0, 0)
291
      self.assertEqual(os.read(fd, 4096), data)
292
    finally:
293
      os.close(fd)
294

    
295

    
296
class TestFileID(testutils.GanetiTestCase):
297
  def testEquality(self):
298
    name = self._CreateTempFile()
299
    oldi = utils.GetFileID(path=name)
300
    self.failUnless(utils.VerifyFileID(oldi, oldi))
301

    
302
  def testUpdate(self):
303
    name = self._CreateTempFile()
304
    oldi = utils.GetFileID(path=name)
305
    os.utime(name, None)
306
    fd = os.open(name, os.O_RDWR)
307
    try:
308
      newi = utils.GetFileID(fd=fd)
309
      self.failUnless(utils.VerifyFileID(oldi, newi))
310
      self.failUnless(utils.VerifyFileID(newi, oldi))
311
    finally:
312
      os.close(fd)
313

    
314
  def testWriteFile(self):
315
    name = self._CreateTempFile()
316
    oldi = utils.GetFileID(path=name)
317
    mtime = oldi[2]
318
    os.utime(name, (mtime + 10, mtime + 10))
319
    self.assertRaises(errors.LockError, utils.SafeWriteFile, name,
320
                      oldi, data="")
321
    os.utime(name, (mtime - 10, mtime - 10))
322
    utils.SafeWriteFile(name, oldi, data="")
323
    oldi = utils.GetFileID(path=name)
324
    mtime = oldi[2]
325
    os.utime(name, (mtime + 10, mtime + 10))
326
    # this doesn't raise, since we passed None
327
    utils.SafeWriteFile(name, None, data="")
328

    
329
  def testError(self):
330
    t = tempfile.NamedTemporaryFile()
331
    self.assertRaises(errors.ProgrammerError, utils.GetFileID,
332
                      path=t.name, fd=t.fileno())
333

    
334

    
335
class TestRemoveFile(unittest.TestCase):
336
  """Test case for the RemoveFile function"""
337

    
338
  def setUp(self):
339
    """Create a temp dir and file for each case"""
340
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
341
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
342
    os.close(fd)
343

    
344
  def tearDown(self):
345
    if os.path.exists(self.tmpfile):
346
      os.unlink(self.tmpfile)
347
    os.rmdir(self.tmpdir)
348

    
349
  def testIgnoreDirs(self):
350
    """Test that RemoveFile() ignores directories"""
351
    self.assertEqual(None, utils.RemoveFile(self.tmpdir))
352

    
353
  def testIgnoreNotExisting(self):
354
    """Test that RemoveFile() ignores non-existing files"""
355
    utils.RemoveFile(self.tmpfile)
356
    utils.RemoveFile(self.tmpfile)
357

    
358
  def testRemoveFile(self):
359
    """Test that RemoveFile does remove a file"""
360
    utils.RemoveFile(self.tmpfile)
361
    if os.path.exists(self.tmpfile):
362
      self.fail("File '%s' not removed" % self.tmpfile)
363

    
364
  def testRemoveSymlink(self):
365
    """Test that RemoveFile does remove symlinks"""
366
    symlink = self.tmpdir + "/symlink"
367
    os.symlink("no-such-file", symlink)
368
    utils.RemoveFile(symlink)
369
    if os.path.exists(symlink):
370
      self.fail("File '%s' not removed" % symlink)
371
    os.symlink(self.tmpfile, symlink)
372
    utils.RemoveFile(symlink)
373
    if os.path.exists(symlink):
374
      self.fail("File '%s' not removed" % symlink)
375

    
376

    
377
class TestRemoveDir(unittest.TestCase):
378
  def setUp(self):
379
    self.tmpdir = tempfile.mkdtemp()
380

    
381
  def tearDown(self):
382
    try:
383
      shutil.rmtree(self.tmpdir)
384
    except EnvironmentError:
385
      pass
386

    
387
  def testEmptyDir(self):
388
    utils.RemoveDir(self.tmpdir)
389
    self.assertFalse(os.path.isdir(self.tmpdir))
390

    
391
  def testNonEmptyDir(self):
392
    self.tmpfile = os.path.join(self.tmpdir, "test1")
393
    open(self.tmpfile, "w").close()
394
    self.assertRaises(EnvironmentError, utils.RemoveDir, self.tmpdir)
395

    
396

    
397
class TestRename(unittest.TestCase):
398
  """Test case for RenameFile"""
399

    
400
  def setUp(self):
401
    """Create a temporary directory"""
402
    self.tmpdir = tempfile.mkdtemp()
403
    self.tmpfile = os.path.join(self.tmpdir, "test1")
404

    
405
    # Touch the file
406
    open(self.tmpfile, "w").close()
407

    
408
  def tearDown(self):
409
    """Remove temporary directory"""
410
    shutil.rmtree(self.tmpdir)
411

    
412
  def testSimpleRename1(self):
413
    """Simple rename 1"""
414
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
415
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
416

    
417
  def testSimpleRename2(self):
418
    """Simple rename 2"""
419
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
420
                     mkdir=True)
421
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
422

    
423
  def testRenameMkdir(self):
424
    """Rename with mkdir"""
425
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
426
                     mkdir=True)
427
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
428
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
429

    
430
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
431
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
432
                     mkdir=True)
433
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
434
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
435
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
436

    
437

    
438
class TestMakedirs(unittest.TestCase):
439
  def setUp(self):
440
    self.tmpdir = tempfile.mkdtemp()
441

    
442
  def tearDown(self):
443
    shutil.rmtree(self.tmpdir)
444

    
445
  def testNonExisting(self):
446
    path = utils.PathJoin(self.tmpdir, "foo")
447
    utils.Makedirs(path)
448
    self.assert_(os.path.isdir(path))
449

    
450
  def testExisting(self):
451
    path = utils.PathJoin(self.tmpdir, "foo")
452
    os.mkdir(path)
453
    utils.Makedirs(path)
454
    self.assert_(os.path.isdir(path))
455

    
456
  def testRecursiveNonExisting(self):
457
    path = utils.PathJoin(self.tmpdir, "foo/bar/baz")
458
    utils.Makedirs(path)
459
    self.assert_(os.path.isdir(path))
460

    
461
  def testRecursiveExisting(self):
462
    path = utils.PathJoin(self.tmpdir, "B/moo/xyz")
463
    self.assertFalse(os.path.exists(path))
464
    os.mkdir(utils.PathJoin(self.tmpdir, "B"))
465
    utils.Makedirs(path)
466
    self.assert_(os.path.isdir(path))
467

    
468

    
469
class TestEnsureDirs(unittest.TestCase):
470
  """Tests for EnsureDirs"""
471

    
472
  def setUp(self):
473
    self.dir = tempfile.mkdtemp()
474
    self.old_umask = os.umask(0777)
475

    
476
  def testEnsureDirs(self):
477
    utils.EnsureDirs([
478
        (utils.PathJoin(self.dir, "foo"), 0777),
479
        (utils.PathJoin(self.dir, "bar"), 0000),
480
        ])
481
    self.assertEquals(os.stat(utils.PathJoin(self.dir, "foo"))[0] & 0777, 0777)
482
    self.assertEquals(os.stat(utils.PathJoin(self.dir, "bar"))[0] & 0777, 0000)
483

    
484
  def tearDown(self):
485
    os.rmdir(utils.PathJoin(self.dir, "foo"))
486
    os.rmdir(utils.PathJoin(self.dir, "bar"))
487
    os.rmdir(self.dir)
488
    os.umask(self.old_umask)
489

    
490

    
491
class TestIsNormAbsPath(unittest.TestCase):
492
  """Testing case for IsNormAbsPath"""
493

    
494
  def _pathTestHelper(self, path, result):
495
    if result:
496
      self.assert_(utils.IsNormAbsPath(path),
497
          "Path %s should result absolute and normalized" % path)
498
    else:
499
      self.assertFalse(utils.IsNormAbsPath(path),
500
          "Path %s should not result absolute and normalized" % path)
501

    
502
  def testBase(self):
503
    self._pathTestHelper("/etc", True)
504
    self._pathTestHelper("/srv", True)
505
    self._pathTestHelper("etc", False)
506
    self._pathTestHelper("/etc/../root", False)
507
    self._pathTestHelper("/etc/", False)
508

    
509

    
510
class TestPathJoin(unittest.TestCase):
511
  """Testing case for PathJoin"""
512

    
513
  def testBasicItems(self):
514
    mlist = ["/a", "b", "c"]
515
    self.failUnlessEqual(utils.PathJoin(*mlist), "/".join(mlist))
516

    
517
  def testNonAbsPrefix(self):
518
    self.failUnlessRaises(ValueError, utils.PathJoin, "a", "b")
519

    
520
  def testBackTrack(self):
521
    self.failUnlessRaises(ValueError, utils.PathJoin, "/a", "b/../c")
522

    
523
  def testMultiAbs(self):
524
    self.failUnlessRaises(ValueError, utils.PathJoin, "/a", "/b")
525

    
526

    
527
class TestTailFile(testutils.GanetiTestCase):
528
  """Test case for the TailFile function"""
529

    
530
  def testEmpty(self):
531
    fname = self._CreateTempFile()
532
    self.failUnlessEqual(utils.TailFile(fname), [])
533
    self.failUnlessEqual(utils.TailFile(fname, lines=25), [])
534

    
535
  def testAllLines(self):
536
    data = ["test %d" % i for i in range(30)]
537
    for i in range(30):
538
      fname = self._CreateTempFile()
539
      fd = open(fname, "w")
540
      fd.write("\n".join(data[:i]))
541
      if i > 0:
542
        fd.write("\n")
543
      fd.close()
544
      self.failUnlessEqual(utils.TailFile(fname, lines=i), data[:i])
545

    
546
  def testPartialLines(self):
547
    data = ["test %d" % i for i in range(30)]
548
    fname = self._CreateTempFile()
549
    fd = open(fname, "w")
550
    fd.write("\n".join(data))
551
    fd.write("\n")
552
    fd.close()
553
    for i in range(1, 30):
554
      self.failUnlessEqual(utils.TailFile(fname, lines=i), data[-i:])
555

    
556
  def testBigFile(self):
557
    data = ["test %d" % i for i in range(30)]
558
    fname = self._CreateTempFile()
559
    fd = open(fname, "w")
560
    fd.write("X" * 1048576)
561
    fd.write("\n")
562
    fd.write("\n".join(data))
563
    fd.write("\n")
564
    fd.close()
565
    for i in range(1, 30):
566
      self.failUnlessEqual(utils.TailFile(fname, lines=i), data[-i:])
567

    
568

    
569
class TestPidFileFunctions(unittest.TestCase):
570
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
571

    
572
  def setUp(self):
573
    self.dir = tempfile.mkdtemp()
574
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
575

    
576
  def testPidFileFunctions(self):
577
    pid_file = self.f_dpn('test')
578
    fd = utils.WritePidFile(self.f_dpn('test'))
579
    self.failUnless(os.path.exists(pid_file),
580
                    "PID file should have been created")
581
    read_pid = utils.ReadPidFile(pid_file)
582
    self.failUnlessEqual(read_pid, os.getpid())
583
    self.failUnless(utils.IsProcessAlive(read_pid))
584
    self.failUnlessRaises(errors.LockError, utils.WritePidFile,
585
                          self.f_dpn('test'))
586
    os.close(fd)
587
    utils.RemovePidFile(self.f_dpn("test"))
588
    self.failIf(os.path.exists(pid_file),
589
                "PID file should not exist anymore")
590
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
591
                         "ReadPidFile should return 0 for missing pid file")
592
    fh = open(pid_file, "w")
593
    fh.write("blah\n")
594
    fh.close()
595
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
596
                         "ReadPidFile should return 0 for invalid pid file")
597
    # but now, even with the file existing, we should be able to lock it
598
    fd = utils.WritePidFile(self.f_dpn('test'))
599
    os.close(fd)
600
    utils.RemovePidFile(self.f_dpn("test"))
601
    self.failIf(os.path.exists(pid_file),
602
                "PID file should not exist anymore")
603

    
604
  def testKill(self):
605
    pid_file = self.f_dpn('child')
606
    r_fd, w_fd = os.pipe()
607
    new_pid = os.fork()
608
    if new_pid == 0: #child
609
      utils.WritePidFile(self.f_dpn('child'))
610
      os.write(w_fd, 'a')
611
      signal.pause()
612
      os._exit(0)
613
      return
614
    # else we are in the parent
615
    # wait until the child has written the pid file
616
    os.read(r_fd, 1)
617
    read_pid = utils.ReadPidFile(pid_file)
618
    self.failUnlessEqual(read_pid, new_pid)
619
    self.failUnless(utils.IsProcessAlive(new_pid))
620
    utils.KillProcess(new_pid, waitpid=True)
621
    self.failIf(utils.IsProcessAlive(new_pid))
622
    utils.RemovePidFile(self.f_dpn('child'))
623
    self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)
624

    
625
  def tearDown(self):
626
    shutil.rmtree(self.dir)
627

    
628

    
629
class TestSshKeys(testutils.GanetiTestCase):
630
  """Test case for the AddAuthorizedKey function"""
631

    
632
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
633
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
634
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
635

    
636
  def setUp(self):
637
    testutils.GanetiTestCase.setUp(self)
638
    self.tmpname = self._CreateTempFile()
639
    handle = open(self.tmpname, 'w')
640
    try:
641
      handle.write("%s\n" % TestSshKeys.KEY_A)
642
      handle.write("%s\n" % TestSshKeys.KEY_B)
643
    finally:
644
      handle.close()
645

    
646
  def testAddingNewKey(self):
647
    utils.AddAuthorizedKey(self.tmpname,
648
                           'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
649

    
650
    self.assertFileContent(self.tmpname,
651
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
652
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
653
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
654
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
655

    
656
  def testAddingAlmostButNotCompletelyTheSameKey(self):
657
    utils.AddAuthorizedKey(self.tmpname,
658
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
659

    
660
    self.assertFileContent(self.tmpname,
661
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
662
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
663
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
664
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
665

    
666
  def testAddingExistingKeyWithSomeMoreSpaces(self):
667
    utils.AddAuthorizedKey(self.tmpname,
668
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
669

    
670
    self.assertFileContent(self.tmpname,
671
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
672
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
673
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
674

    
675
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
676
    utils.RemoveAuthorizedKey(self.tmpname,
677
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
678

    
679
    self.assertFileContent(self.tmpname,
680
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
681
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
682

    
683
  def testRemovingNonExistingKey(self):
684
    utils.RemoveAuthorizedKey(self.tmpname,
685
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
686

    
687
    self.assertFileContent(self.tmpname,
688
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
689
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
690
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
691

    
692

    
693
if __name__ == "__main__":
694
  testutils.GanetiTestProgram()