Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ 1d466a4f

History | View | Annotate | Download (50.2 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import unittest
25
import os
26
import time
27
import tempfile
28
import os.path
29
import os
30
import stat
31
import md5
32
import signal
33
import socket
34
import shutil
35
import re
36
import select
37
import string
38
import OpenSSL
39
import warnings
40
import distutils.version
41
import glob
42

    
43
import ganeti
44
import testutils
45
from ganeti import constants
46
from ganeti import utils
47
from ganeti import errors
48
from ganeti.utils import IsProcessAlive, RunCmd, \
49
     RemoveFile, MatchNameComponent, FormatUnit, \
50
     ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
51
     ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
52
     SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
53
     TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
54
     UnescapeAndSplit, RunParts, PathJoin, HostInfo
55

    
56
from ganeti.errors import LockError, UnitParseError, GenericError, \
57
     ProgrammerError, OpPrereqError
58

    
59

    
60
class TestIsProcessAlive(unittest.TestCase):
61
  """Testing case for IsProcessAlive"""
62

    
63
  def testExists(self):
64
    mypid = os.getpid()
65
    self.assert_(IsProcessAlive(mypid),
66
                 "can't find myself running")
67

    
68
  def testNotExisting(self):
69
    pid_non_existing = os.fork()
70
    if pid_non_existing == 0:
71
      os._exit(0)
72
    elif pid_non_existing < 0:
73
      raise SystemError("can't fork")
74
    os.waitpid(pid_non_existing, 0)
75
    self.assert_(not IsProcessAlive(pid_non_existing),
76
                 "nonexisting process detected")
77

    
78

    
79
class TestPidFileFunctions(unittest.TestCase):
80
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
81

    
82
  def setUp(self):
83
    self.dir = tempfile.mkdtemp()
84
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
85
    utils.DaemonPidFileName = self.f_dpn
86

    
87
  def testPidFileFunctions(self):
88
    pid_file = self.f_dpn('test')
89
    utils.WritePidFile('test')
90
    self.failUnless(os.path.exists(pid_file),
91
                    "PID file should have been created")
92
    read_pid = utils.ReadPidFile(pid_file)
93
    self.failUnlessEqual(read_pid, os.getpid())
94
    self.failUnless(utils.IsProcessAlive(read_pid))
95
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
96
    utils.RemovePidFile('test')
97
    self.failIf(os.path.exists(pid_file),
98
                "PID file should not exist anymore")
99
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
100
                         "ReadPidFile should return 0 for missing pid file")
101
    fh = open(pid_file, "w")
102
    fh.write("blah\n")
103
    fh.close()
104
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
105
                         "ReadPidFile should return 0 for invalid pid file")
106
    utils.RemovePidFile('test')
107
    self.failIf(os.path.exists(pid_file),
108
                "PID file should not exist anymore")
109

    
110
  def testKill(self):
111
    pid_file = self.f_dpn('child')
112
    r_fd, w_fd = os.pipe()
113
    new_pid = os.fork()
114
    if new_pid == 0: #child
115
      utils.WritePidFile('child')
116
      os.write(w_fd, 'a')
117
      signal.pause()
118
      os._exit(0)
119
      return
120
    # else we are in the parent
121
    # wait until the child has written the pid file
122
    os.read(r_fd, 1)
123
    read_pid = utils.ReadPidFile(pid_file)
124
    self.failUnlessEqual(read_pid, new_pid)
125
    self.failUnless(utils.IsProcessAlive(new_pid))
126
    utils.KillProcess(new_pid, waitpid=True)
127
    self.failIf(utils.IsProcessAlive(new_pid))
128
    utils.RemovePidFile('child')
129
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)
130

    
131
  def tearDown(self):
132
    for name in os.listdir(self.dir):
133
      os.unlink(os.path.join(self.dir, name))
134
    os.rmdir(self.dir)
135

    
136

    
137
class TestRunCmd(testutils.GanetiTestCase):
138
  """Testing case for the RunCmd function"""
139

    
140
  def setUp(self):
141
    testutils.GanetiTestCase.setUp(self)
142
    self.magic = time.ctime() + " ganeti test"
143
    self.fname = self._CreateTempFile()
144

    
145
  def testOk(self):
146
    """Test successful exit code"""
147
    result = RunCmd("/bin/sh -c 'exit 0'")
148
    self.assertEqual(result.exit_code, 0)
149
    self.assertEqual(result.output, "")
150

    
151
  def testFail(self):
152
    """Test fail exit code"""
153
    result = RunCmd("/bin/sh -c 'exit 1'")
154
    self.assertEqual(result.exit_code, 1)
155
    self.assertEqual(result.output, "")
156

    
157
  def testStdout(self):
158
    """Test standard output"""
159
    cmd = 'echo -n "%s"' % self.magic
160
    result = RunCmd("/bin/sh -c '%s'" % cmd)
161
    self.assertEqual(result.stdout, self.magic)
162
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
163
    self.assertEqual(result.output, "")
164
    self.assertFileContent(self.fname, self.magic)
165

    
166
  def testStderr(self):
167
    """Test standard error"""
168
    cmd = 'echo -n "%s"' % self.magic
169
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
170
    self.assertEqual(result.stderr, self.magic)
171
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
172
    self.assertEqual(result.output, "")
173
    self.assertFileContent(self.fname, self.magic)
174

    
175
  def testCombined(self):
176
    """Test combined output"""
177
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
178
    expected = "A" + self.magic + "B" + self.magic
179
    result = RunCmd("/bin/sh -c '%s'" % cmd)
180
    self.assertEqual(result.output, expected)
181
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
182
    self.assertEqual(result.output, "")
183
    self.assertFileContent(self.fname, expected)
184

    
185
  def testSignal(self):
186
    """Test signal"""
187
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
188
    self.assertEqual(result.signal, 15)
189
    self.assertEqual(result.output, "")
190

    
191
  def testListRun(self):
192
    """Test list runs"""
193
    result = RunCmd(["true"])
194
    self.assertEqual(result.signal, None)
195
    self.assertEqual(result.exit_code, 0)
196
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
197
    self.assertEqual(result.signal, None)
198
    self.assertEqual(result.exit_code, 1)
199
    result = RunCmd(["echo", "-n", self.magic])
200
    self.assertEqual(result.signal, None)
201
    self.assertEqual(result.exit_code, 0)
202
    self.assertEqual(result.stdout, self.magic)
203

    
204
  def testFileEmptyOutput(self):
205
    """Test file output"""
206
    result = RunCmd(["true"], output=self.fname)
207
    self.assertEqual(result.signal, None)
208
    self.assertEqual(result.exit_code, 0)
209
    self.assertFileContent(self.fname, "")
210

    
211
  def testLang(self):
212
    """Test locale environment"""
213
    old_env = os.environ.copy()
214
    try:
215
      os.environ["LANG"] = "en_US.UTF-8"
216
      os.environ["LC_ALL"] = "en_US.UTF-8"
217
      result = RunCmd(["locale"])
218
      for line in result.output.splitlines():
219
        key, value = line.split("=", 1)
220
        # Ignore these variables, they're overridden by LC_ALL
221
        if key == "LANG" or key == "LANGUAGE":
222
          continue
223
        self.failIf(value and value != "C" and value != '"C"',
224
            "Variable %s is set to the invalid value '%s'" % (key, value))
225
    finally:
226
      os.environ = old_env
227

    
228
  def testDefaultCwd(self):
229
    """Test default working directory"""
230
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
231

    
232
  def testCwd(self):
233
    """Test default working directory"""
234
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
235
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
236
    cwd = os.getcwd()
237
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
238

    
239
  def testResetEnv(self):
240
    """Test environment reset functionality"""
241
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
242

    
243

    
244
class TestRunParts(unittest.TestCase):
245
  """Testing case for the RunParts function"""
246

    
247
  def setUp(self):
248
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
249

    
250
  def tearDown(self):
251
    shutil.rmtree(self.rundir)
252

    
253
  def testEmpty(self):
254
    """Test on an empty dir"""
255
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
256

    
257
  def testSkipWrongName(self):
258
    """Test that wrong files are skipped"""
259
    fname = os.path.join(self.rundir, "00test.dot")
260
    utils.WriteFile(fname, data="")
261
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
262
    relname = os.path.basename(fname)
263
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
264
                         [(relname, constants.RUNPARTS_SKIP, None)])
265

    
266
  def testSkipNonExec(self):
267
    """Test that non executable files are skipped"""
268
    fname = os.path.join(self.rundir, "00test")
269
    utils.WriteFile(fname, data="")
270
    relname = os.path.basename(fname)
271
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
272
                         [(relname, constants.RUNPARTS_SKIP, None)])
273

    
274
  def testError(self):
275
    """Test error on a broken executable"""
276
    fname = os.path.join(self.rundir, "00test")
277
    utils.WriteFile(fname, data="")
278
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
279
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
280
    self.failUnlessEqual(relname, os.path.basename(fname))
281
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
282
    self.failUnless(error)
283

    
284
  def testSorted(self):
285
    """Test executions are sorted"""
286
    files = []
287
    files.append(os.path.join(self.rundir, "64test"))
288
    files.append(os.path.join(self.rundir, "00test"))
289
    files.append(os.path.join(self.rundir, "42test"))
290

    
291
    for fname in files:
292
      utils.WriteFile(fname, data="")
293

    
294
    results = RunParts(self.rundir, reset_env=True)
295

    
296
    for fname in sorted(files):
297
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
298

    
299
  def testOk(self):
300
    """Test correct execution"""
301
    fname = os.path.join(self.rundir, "00test")
302
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
303
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
304
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
305
    self.failUnlessEqual(relname, os.path.basename(fname))
306
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
307
    self.failUnlessEqual(runresult.stdout, "ciao")
308

    
309
  def testRunFail(self):
310
    """Test correct execution, with run failure"""
311
    fname = os.path.join(self.rundir, "00test")
312
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
313
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
314
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
315
    self.failUnlessEqual(relname, os.path.basename(fname))
316
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
317
    self.failUnlessEqual(runresult.exit_code, 1)
318
    self.failUnless(runresult.failed)
319

    
320
  def testRunMix(self):
321
    files = []
322
    files.append(os.path.join(self.rundir, "00test"))
323
    files.append(os.path.join(self.rundir, "42test"))
324
    files.append(os.path.join(self.rundir, "64test"))
325
    files.append(os.path.join(self.rundir, "99test"))
326

    
327
    files.sort()
328

    
329
    # 1st has errors in execution
330
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
331
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
332

    
333
    # 2nd is skipped
334
    utils.WriteFile(files[1], data="")
335

    
336
    # 3rd cannot execute properly
337
    utils.WriteFile(files[2], data="")
338
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
339

    
340
    # 4th execs
341
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
342
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
343

    
344
    results = RunParts(self.rundir, reset_env=True)
345

    
346
    (relname, status, runresult) = results[0]
347
    self.failUnlessEqual(relname, os.path.basename(files[0]))
348
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
349
    self.failUnlessEqual(runresult.exit_code, 1)
350
    self.failUnless(runresult.failed)
351

    
352
    (relname, status, runresult) = results[1]
353
    self.failUnlessEqual(relname, os.path.basename(files[1]))
354
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
355
    self.failUnlessEqual(runresult, None)
356

    
357
    (relname, status, runresult) = results[2]
358
    self.failUnlessEqual(relname, os.path.basename(files[2]))
359
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
360
    self.failUnless(runresult)
361

    
362
    (relname, status, runresult) = results[3]
363
    self.failUnlessEqual(relname, os.path.basename(files[3]))
364
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
365
    self.failUnlessEqual(runresult.output, "ciao")
366
    self.failUnlessEqual(runresult.exit_code, 0)
367
    self.failUnless(not runresult.failed)
368

    
369

    
370
class TestRemoveFile(unittest.TestCase):
371
  """Test case for the RemoveFile function"""
372

    
373
  def setUp(self):
374
    """Create a temp dir and file for each case"""
375
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
376
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
377
    os.close(fd)
378

    
379
  def tearDown(self):
380
    if os.path.exists(self.tmpfile):
381
      os.unlink(self.tmpfile)
382
    os.rmdir(self.tmpdir)
383

    
384

    
385
  def testIgnoreDirs(self):
386
    """Test that RemoveFile() ignores directories"""
387
    self.assertEqual(None, RemoveFile(self.tmpdir))
388

    
389

    
390
  def testIgnoreNotExisting(self):
391
    """Test that RemoveFile() ignores non-existing files"""
392
    RemoveFile(self.tmpfile)
393
    RemoveFile(self.tmpfile)
394

    
395

    
396
  def testRemoveFile(self):
397
    """Test that RemoveFile does remove a file"""
398
    RemoveFile(self.tmpfile)
399
    if os.path.exists(self.tmpfile):
400
      self.fail("File '%s' not removed" % self.tmpfile)
401

    
402

    
403
  def testRemoveSymlink(self):
404
    """Test that RemoveFile does remove symlinks"""
405
    symlink = self.tmpdir + "/symlink"
406
    os.symlink("no-such-file", symlink)
407
    RemoveFile(symlink)
408
    if os.path.exists(symlink):
409
      self.fail("File '%s' not removed" % symlink)
410
    os.symlink(self.tmpfile, symlink)
411
    RemoveFile(symlink)
412
    if os.path.exists(symlink):
413
      self.fail("File '%s' not removed" % symlink)
414

    
415

    
416
class TestRename(unittest.TestCase):
417
  """Test case for RenameFile"""
418

    
419
  def setUp(self):
420
    """Create a temporary directory"""
421
    self.tmpdir = tempfile.mkdtemp()
422
    self.tmpfile = os.path.join(self.tmpdir, "test1")
423

    
424
    # Touch the file
425
    open(self.tmpfile, "w").close()
426

    
427
  def tearDown(self):
428
    """Remove temporary directory"""
429
    shutil.rmtree(self.tmpdir)
430

    
431
  def testSimpleRename1(self):
432
    """Simple rename 1"""
433
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
434
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
435

    
436
  def testSimpleRename2(self):
437
    """Simple rename 2"""
438
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
439
                     mkdir=True)
440
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
441

    
442
  def testRenameMkdir(self):
443
    """Rename with mkdir"""
444
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
445
                     mkdir=True)
446
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
447
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
448

    
449
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
450
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
451
                     mkdir=True)
452
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
453
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
454
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
455

    
456

    
457
class TestMatchNameComponent(unittest.TestCase):
458
  """Test case for the MatchNameComponent function"""
459

    
460
  def testEmptyList(self):
461
    """Test that there is no match against an empty list"""
462

    
463
    self.failUnlessEqual(MatchNameComponent("", []), None)
464
    self.failUnlessEqual(MatchNameComponent("test", []), None)
465

    
466
  def testSingleMatch(self):
467
    """Test that a single match is performed correctly"""
468
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
469
    for key in "test2", "test2.example", "test2.example.com":
470
      self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
471

    
472
  def testMultipleMatches(self):
473
    """Test that a multiple match is returned as None"""
474
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
475
    for key in "test1", "test1.example":
476
      self.failUnlessEqual(MatchNameComponent(key, mlist), None)
477

    
478
  def testFullMatch(self):
479
    """Test that a full match is returned correctly"""
480
    key1 = "test1"
481
    key2 = "test1.example"
482
    mlist = [key2, key2 + ".com"]
483
    self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
484
    self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
485

    
486
  def testCaseInsensitivePartialMatch(self):
487
    """Test for the case_insensitive keyword"""
488
    mlist = ["test1.example.com", "test2.example.net"]
489
    self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
490
                     "test2.example.net")
491
    self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
492
                     "test2.example.net")
493
    self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
494
                     "test2.example.net")
495
    self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
496
                     "test2.example.net")
497

    
498

    
499
  def testCaseInsensitiveFullMatch(self):
500
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
501
    # Between the two ts1 a full string match non-case insensitive should work
502
    self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
503
                     None)
504
    self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
505
                     "ts1.ex")
506
    self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
507
                     "ts1.ex")
508
    # Between the two ts2 only case differs, so only case-match works
509
    self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
510
                     "ts2.ex")
511
    self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
512
                     "Ts2.ex")
513
    self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
514
                     None)
515

    
516

    
517
class TestTimestampForFilename(unittest.TestCase):
518
  def test(self):
519
    self.assert_("." not in utils.TimestampForFilename())
520
    self.assert_(":" not in utils.TimestampForFilename())
521

    
522

    
523
class TestCreateBackup(testutils.GanetiTestCase):
524
  def setUp(self):
525
    testutils.GanetiTestCase.setUp(self)
526

    
527
    self.tmpdir = tempfile.mkdtemp()
528

    
529
  def tearDown(self):
530
    testutils.GanetiTestCase.tearDown(self)
531

    
532
    shutil.rmtree(self.tmpdir)
533

    
534
  def testEmpty(self):
535
    filename = utils.PathJoin(self.tmpdir, "config.data")
536
    utils.WriteFile(filename, data="")
537
    bname = utils.CreateBackup(filename)
538
    self.assertFileContent(bname, "")
539
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
540
    utils.CreateBackup(filename)
541
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
542
    utils.CreateBackup(filename)
543
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
544

    
545
    fifoname = utils.PathJoin(self.tmpdir, "fifo")
546
    os.mkfifo(fifoname)
547
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
548

    
549
  def testContent(self):
550
    bkpcount = 0
551
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
552
      for rep in [1, 2, 10, 127]:
553
        testdata = data * rep
554

    
555
        filename = utils.PathJoin(self.tmpdir, "test.data_")
556
        utils.WriteFile(filename, data=testdata)
557
        self.assertFileContent(filename, testdata)
558

    
559
        for _ in range(3):
560
          bname = utils.CreateBackup(filename)
561
          bkpcount += 1
562
          self.assertFileContent(bname, testdata)
563
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
564

    
565

    
566
class TestFormatUnit(unittest.TestCase):
567
  """Test case for the FormatUnit function"""
568

    
569
  def testMiB(self):
570
    self.assertEqual(FormatUnit(1, 'h'), '1M')
571
    self.assertEqual(FormatUnit(100, 'h'), '100M')
572
    self.assertEqual(FormatUnit(1023, 'h'), '1023M')
573

    
574
    self.assertEqual(FormatUnit(1, 'm'), '1')
575
    self.assertEqual(FormatUnit(100, 'm'), '100')
576
    self.assertEqual(FormatUnit(1023, 'm'), '1023')
577

    
578
    self.assertEqual(FormatUnit(1024, 'm'), '1024')
579
    self.assertEqual(FormatUnit(1536, 'm'), '1536')
580
    self.assertEqual(FormatUnit(17133, 'm'), '17133')
581
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
582

    
583
  def testGiB(self):
584
    self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
585
    self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
586
    self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
587
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
588

    
589
    self.assertEqual(FormatUnit(1024, 'g'), '1.0')
590
    self.assertEqual(FormatUnit(1536, 'g'), '1.5')
591
    self.assertEqual(FormatUnit(17133, 'g'), '16.7')
592
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
593

    
594
    self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
595
    self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
596
    self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
597

    
598
  def testTiB(self):
599
    self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
600
    self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
601
    self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
602

    
603
    self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
604
    self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
605
    self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
606

    
607
class TestParseUnit(unittest.TestCase):
608
  """Test case for the ParseUnit function"""
609

    
610
  SCALES = (('', 1),
611
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
612
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
613
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
614

    
615
  def testRounding(self):
616
    self.assertEqual(ParseUnit('0'), 0)
617
    self.assertEqual(ParseUnit('1'), 4)
618
    self.assertEqual(ParseUnit('2'), 4)
619
    self.assertEqual(ParseUnit('3'), 4)
620

    
621
    self.assertEqual(ParseUnit('124'), 124)
622
    self.assertEqual(ParseUnit('125'), 128)
623
    self.assertEqual(ParseUnit('126'), 128)
624
    self.assertEqual(ParseUnit('127'), 128)
625
    self.assertEqual(ParseUnit('128'), 128)
626
    self.assertEqual(ParseUnit('129'), 132)
627
    self.assertEqual(ParseUnit('130'), 132)
628

    
629
  def testFloating(self):
630
    self.assertEqual(ParseUnit('0'), 0)
631
    self.assertEqual(ParseUnit('0.5'), 4)
632
    self.assertEqual(ParseUnit('1.75'), 4)
633
    self.assertEqual(ParseUnit('1.99'), 4)
634
    self.assertEqual(ParseUnit('2.00'), 4)
635
    self.assertEqual(ParseUnit('2.01'), 4)
636
    self.assertEqual(ParseUnit('3.99'), 4)
637
    self.assertEqual(ParseUnit('4.00'), 4)
638
    self.assertEqual(ParseUnit('4.01'), 8)
639
    self.assertEqual(ParseUnit('1.5G'), 1536)
640
    self.assertEqual(ParseUnit('1.8G'), 1844)
641
    self.assertEqual(ParseUnit('8.28T'), 8682212)
642

    
643
  def testSuffixes(self):
644
    for sep in ('', ' ', '   ', "\t", "\t "):
645
      for suffix, scale in TestParseUnit.SCALES:
646
        for func in (lambda x: x, str.lower, str.upper):
647
          self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
648
                           1024 * scale)
649

    
650
  def testInvalidInput(self):
651
    for sep in ('-', '_', ',', 'a'):
652
      for suffix, _ in TestParseUnit.SCALES:
653
        self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)
654

    
655
    for suffix, _ in TestParseUnit.SCALES:
656
      self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
657

    
658

    
659
class TestSshKeys(testutils.GanetiTestCase):
660
  """Test case for the AddAuthorizedKey function"""
661

    
662
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
663
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
664
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
665

    
666
  def setUp(self):
667
    testutils.GanetiTestCase.setUp(self)
668
    self.tmpname = self._CreateTempFile()
669
    handle = open(self.tmpname, 'w')
670
    try:
671
      handle.write("%s\n" % TestSshKeys.KEY_A)
672
      handle.write("%s\n" % TestSshKeys.KEY_B)
673
    finally:
674
      handle.close()
675

    
676
  def testAddingNewKey(self):
677
    AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
678

    
679
    self.assertFileContent(self.tmpname,
680
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
681
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
682
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
683
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
684

    
685
  def testAddingAlmostButNotCompletelyTheSameKey(self):
686
    AddAuthorizedKey(self.tmpname,
687
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
688

    
689
    self.assertFileContent(self.tmpname,
690
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
691
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
692
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
693
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
694

    
695
  def testAddingExistingKeyWithSomeMoreSpaces(self):
696
    AddAuthorizedKey(self.tmpname,
697
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
698

    
699
    self.assertFileContent(self.tmpname,
700
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
701
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
702
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
703

    
704
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
705
    RemoveAuthorizedKey(self.tmpname,
706
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
707

    
708
    self.assertFileContent(self.tmpname,
709
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
710
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
711

    
712
  def testRemovingNonExistingKey(self):
713
    RemoveAuthorizedKey(self.tmpname,
714
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
715

    
716
    self.assertFileContent(self.tmpname,
717
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
718
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
719
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
720

    
721

    
722
class TestEtcHosts(testutils.GanetiTestCase):
723
  """Test functions modifying /etc/hosts"""
724

    
725
  def setUp(self):
726
    testutils.GanetiTestCase.setUp(self)
727
    self.tmpname = self._CreateTempFile()
728
    handle = open(self.tmpname, 'w')
729
    try:
730
      handle.write('# This is a test file for /etc/hosts\n')
731
      handle.write('127.0.0.1\tlocalhost\n')
732
      handle.write('192.168.1.1 router gw\n')
733
    finally:
734
      handle.close()
735

    
736
  def testSettingNewIp(self):
737
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
738

    
739
    self.assertFileContent(self.tmpname,
740
      "# This is a test file for /etc/hosts\n"
741
      "127.0.0.1\tlocalhost\n"
742
      "192.168.1.1 router gw\n"
743
      "1.2.3.4\tmyhost.domain.tld myhost\n")
744
    self.assertFileMode(self.tmpname, 0644)
745

    
746
  def testSettingExistingIp(self):
747
    SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
748
                     ['myhost'])
749

    
750
    self.assertFileContent(self.tmpname,
751
      "# This is a test file for /etc/hosts\n"
752
      "127.0.0.1\tlocalhost\n"
753
      "192.168.1.1\tmyhost.domain.tld myhost\n")
754
    self.assertFileMode(self.tmpname, 0644)
755

    
756
  def testSettingDuplicateName(self):
757
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
758

    
759
    self.assertFileContent(self.tmpname,
760
      "# This is a test file for /etc/hosts\n"
761
      "127.0.0.1\tlocalhost\n"
762
      "192.168.1.1 router gw\n"
763
      "1.2.3.4\tmyhost\n")
764
    self.assertFileMode(self.tmpname, 0644)
765

    
766
  def testRemovingExistingHost(self):
767
    RemoveEtcHostsEntry(self.tmpname, 'router')
768

    
769
    self.assertFileContent(self.tmpname,
770
      "# This is a test file for /etc/hosts\n"
771
      "127.0.0.1\tlocalhost\n"
772
      "192.168.1.1 gw\n")
773
    self.assertFileMode(self.tmpname, 0644)
774

    
775
  def testRemovingSingleExistingHost(self):
776
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
777

    
778
    self.assertFileContent(self.tmpname,
779
      "# This is a test file for /etc/hosts\n"
780
      "192.168.1.1 router gw\n")
781
    self.assertFileMode(self.tmpname, 0644)
782

    
783
  def testRemovingNonExistingHost(self):
784
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
785

    
786
    self.assertFileContent(self.tmpname,
787
      "# This is a test file for /etc/hosts\n"
788
      "127.0.0.1\tlocalhost\n"
789
      "192.168.1.1 router gw\n")
790
    self.assertFileMode(self.tmpname, 0644)
791

    
792
  def testRemovingAlias(self):
793
    RemoveEtcHostsEntry(self.tmpname, 'gw')
794

    
795
    self.assertFileContent(self.tmpname,
796
      "# This is a test file for /etc/hosts\n"
797
      "127.0.0.1\tlocalhost\n"
798
      "192.168.1.1 router\n")
799
    self.assertFileMode(self.tmpname, 0644)
800

    
801

    
802
class TestShellQuoting(unittest.TestCase):
803
  """Test case for shell quoting functions"""
804

    
805
  def testShellQuote(self):
806
    self.assertEqual(ShellQuote('abc'), "abc")
807
    self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
808
    self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
809
    self.assertEqual(ShellQuote("a b c"), "'a b c'")
810
    self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
811

    
812
  def testShellQuoteArgs(self):
813
    self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
814
    self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
815
    self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
816

    
817

    
818
class TestTcpPing(unittest.TestCase):
819
  """Testcase for TCP version of ping - against listen(2)ing port"""
820

    
821
  def setUp(self):
822
    self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
823
    self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
824
    self.listenerport = self.listener.getsockname()[1]
825
    self.listener.listen(1)
826

    
827
  def tearDown(self):
828
    self.listener.shutdown(socket.SHUT_RDWR)
829
    del self.listener
830
    del self.listenerport
831

    
832
  def testTcpPingToLocalHostAccept(self):
833
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
834
                         self.listenerport,
835
                         timeout=10,
836
                         live_port_needed=True,
837
                         source=constants.LOCALHOST_IP_ADDRESS,
838
                         ),
839
                 "failed to connect to test listener")
840

    
841
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
842
                         self.listenerport,
843
                         timeout=10,
844
                         live_port_needed=True,
845
                         ),
846
                 "failed to connect to test listener (no source)")
847

    
848

    
849
class TestTcpPingDeaf(unittest.TestCase):
850
  """Testcase for TCP version of ping - against non listen(2)ing port"""
851

    
852
  def setUp(self):
853
    self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
854
    self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
855
    self.deaflistenerport = self.deaflistener.getsockname()[1]
856

    
857
  def tearDown(self):
858
    del self.deaflistener
859
    del self.deaflistenerport
860

    
861
  def testTcpPingToLocalHostAcceptDeaf(self):
862
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
863
                        self.deaflistenerport,
864
                        timeout=constants.TCP_PING_TIMEOUT,
865
                        live_port_needed=True,
866
                        source=constants.LOCALHOST_IP_ADDRESS,
867
                        ), # need successful connect(2)
868
                "successfully connected to deaf listener")
869

    
870
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
871
                        self.deaflistenerport,
872
                        timeout=constants.TCP_PING_TIMEOUT,
873
                        live_port_needed=True,
874
                        ), # need successful connect(2)
875
                "successfully connected to deaf listener (no source addr)")
876

    
877
  def testTcpPingToLocalHostNoAccept(self):
878
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
879
                         self.deaflistenerport,
880
                         timeout=constants.TCP_PING_TIMEOUT,
881
                         live_port_needed=False,
882
                         source=constants.LOCALHOST_IP_ADDRESS,
883
                         ), # ECONNREFUSED is OK
884
                 "failed to ping alive host on deaf port")
885

    
886
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
887
                         self.deaflistenerport,
888
                         timeout=constants.TCP_PING_TIMEOUT,
889
                         live_port_needed=False,
890
                         ), # ECONNREFUSED is OK
891
                 "failed to ping alive host on deaf port (no source addr)")
892

    
893

    
894
class TestOwnIpAddress(unittest.TestCase):
895
  """Testcase for OwnIpAddress"""
896

    
897
  def testOwnLoopback(self):
898
    """check having the loopback ip"""
899
    self.failUnless(OwnIpAddress(constants.LOCALHOST_IP_ADDRESS),
900
                    "Should own the loopback address")
901

    
902
  def testNowOwnAddress(self):
903
    """check that I don't own an address"""
904

    
905
    # network 192.0.2.0/24 is reserved for test/documentation as per
906
    # rfc 3330, so we *should* not have an address of this range... if
907
    # this fails, we should extend the test to multiple addresses
908
    DST_IP = "192.0.2.1"
909
    self.failIf(OwnIpAddress(DST_IP), "Should not own IP address %s" % DST_IP)
910

    
911

    
912
class TestListVisibleFiles(unittest.TestCase):
913
  """Test case for ListVisibleFiles"""
914

    
915
  def setUp(self):
916
    self.path = tempfile.mkdtemp()
917

    
918
  def tearDown(self):
919
    shutil.rmtree(self.path)
920

    
921
  def _test(self, files, expected):
922
    # Sort a copy
923
    expected = expected[:]
924
    expected.sort()
925

    
926
    for name in files:
927
      f = open(os.path.join(self.path, name), 'w')
928
      try:
929
        f.write("Test\n")
930
      finally:
931
        f.close()
932

    
933
    found = ListVisibleFiles(self.path)
934
    found.sort()
935

    
936
    self.assertEqual(found, expected)
937

    
938
  def testAllVisible(self):
939
    files = ["a", "b", "c"]
940
    expected = files
941
    self._test(files, expected)
942

    
943
  def testNoneVisible(self):
944
    files = [".a", ".b", ".c"]
945
    expected = []
946
    self._test(files, expected)
947

    
948
  def testSomeVisible(self):
949
    files = ["a", "b", ".c"]
950
    expected = ["a", "b"]
951
    self._test(files, expected)
952

    
953
  def testNonAbsolutePath(self):
954
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
955

    
956
  def testNonNormalizedPath(self):
957
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
958
                          "/bin/../tmp")
959

    
960

    
961
class TestNewUUID(unittest.TestCase):
962
  """Test case for NewUUID"""
963

    
964
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
965
                        '[a-f0-9]{4}-[a-f0-9]{12}$')
966

    
967
  def runTest(self):
968
    self.failUnless(self._re_uuid.match(utils.NewUUID()))
969

    
970

    
971
class TestUniqueSequence(unittest.TestCase):
972
  """Test case for UniqueSequence"""
973

    
974
  def _test(self, input, expected):
975
    self.assertEqual(utils.UniqueSequence(input), expected)
976

    
977
  def runTest(self):
978
    # Ordered input
979
    self._test([1, 2, 3], [1, 2, 3])
980
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
981
    self._test([1, 2, 2, 3], [1, 2, 3])
982
    self._test([1, 2, 3, 3], [1, 2, 3])
983

    
984
    # Unordered input
985
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
986
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
987

    
988
    # Strings
989
    self._test(["a", "a"], ["a"])
990
    self._test(["a", "b"], ["a", "b"])
991
    self._test(["a", "b", "a"], ["a", "b"])
992

    
993

    
994
class TestFirstFree(unittest.TestCase):
995
  """Test case for the FirstFree function"""
996

    
997
  def test(self):
998
    """Test FirstFree"""
999
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1000
    self.failUnlessEqual(FirstFree([]), None)
1001
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1002
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1003
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1004

    
1005

    
1006
class TestTailFile(testutils.GanetiTestCase):
1007
  """Test case for the TailFile function"""
1008

    
1009
  def testEmpty(self):
1010
    fname = self._CreateTempFile()
1011
    self.failUnlessEqual(TailFile(fname), [])
1012
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1013

    
1014
  def testAllLines(self):
1015
    data = ["test %d" % i for i in range(30)]
1016
    for i in range(30):
1017
      fname = self._CreateTempFile()
1018
      fd = open(fname, "w")
1019
      fd.write("\n".join(data[:i]))
1020
      if i > 0:
1021
        fd.write("\n")
1022
      fd.close()
1023
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1024

    
1025
  def testPartialLines(self):
1026
    data = ["test %d" % i for i in range(30)]
1027
    fname = self._CreateTempFile()
1028
    fd = open(fname, "w")
1029
    fd.write("\n".join(data))
1030
    fd.write("\n")
1031
    fd.close()
1032
    for i in range(1, 30):
1033
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1034

    
1035
  def testBigFile(self):
1036
    data = ["test %d" % i for i in range(30)]
1037
    fname = self._CreateTempFile()
1038
    fd = open(fname, "w")
1039
    fd.write("X" * 1048576)
1040
    fd.write("\n")
1041
    fd.write("\n".join(data))
1042
    fd.write("\n")
1043
    fd.close()
1044
    for i in range(1, 30):
1045
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1046

    
1047

    
1048
class _BaseFileLockTest:
1049
  """Test case for the FileLock class"""
1050

    
1051
  def testSharedNonblocking(self):
1052
    self.lock.Shared(blocking=False)
1053
    self.lock.Close()
1054

    
1055
  def testExclusiveNonblocking(self):
1056
    self.lock.Exclusive(blocking=False)
1057
    self.lock.Close()
1058

    
1059
  def testUnlockNonblocking(self):
1060
    self.lock.Unlock(blocking=False)
1061
    self.lock.Close()
1062

    
1063
  def testSharedBlocking(self):
1064
    self.lock.Shared(blocking=True)
1065
    self.lock.Close()
1066

    
1067
  def testExclusiveBlocking(self):
1068
    self.lock.Exclusive(blocking=True)
1069
    self.lock.Close()
1070

    
1071
  def testUnlockBlocking(self):
1072
    self.lock.Unlock(blocking=True)
1073
    self.lock.Close()
1074

    
1075
  def testSharedExclusiveUnlock(self):
1076
    self.lock.Shared(blocking=False)
1077
    self.lock.Exclusive(blocking=False)
1078
    self.lock.Unlock(blocking=False)
1079
    self.lock.Close()
1080

    
1081
  def testExclusiveSharedUnlock(self):
1082
    self.lock.Exclusive(blocking=False)
1083
    self.lock.Shared(blocking=False)
1084
    self.lock.Unlock(blocking=False)
1085
    self.lock.Close()
1086

    
1087
  def testSimpleTimeout(self):
1088
    # These will succeed on the first attempt, hence a short timeout
1089
    self.lock.Shared(blocking=True, timeout=10.0)
1090
    self.lock.Exclusive(blocking=False, timeout=10.0)
1091
    self.lock.Unlock(blocking=True, timeout=10.0)
1092
    self.lock.Close()
1093

    
1094
  @staticmethod
1095
  def _TryLockInner(filename, shared, blocking):
1096
    lock = utils.FileLock.Open(filename)
1097

    
1098
    if shared:
1099
      fn = lock.Shared
1100
    else:
1101
      fn = lock.Exclusive
1102

    
1103
    try:
1104
      # The timeout doesn't really matter as the parent process waits for us to
1105
      # finish anyway.
1106
      fn(blocking=blocking, timeout=0.01)
1107
    except errors.LockError, err:
1108
      return False
1109

    
1110
    return True
1111

    
1112
  def _TryLock(self, *args):
1113
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1114
                                      *args)
1115

    
1116
  def testTimeout(self):
1117
    for blocking in [True, False]:
1118
      self.lock.Exclusive(blocking=True)
1119
      self.failIf(self._TryLock(False, blocking))
1120
      self.failIf(self._TryLock(True, blocking))
1121

    
1122
      self.lock.Shared(blocking=True)
1123
      self.assert_(self._TryLock(True, blocking))
1124
      self.failIf(self._TryLock(False, blocking))
1125

    
1126
  def testCloseShared(self):
1127
    self.lock.Close()
1128
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1129

    
1130
  def testCloseExclusive(self):
1131
    self.lock.Close()
1132
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1133

    
1134
  def testCloseUnlock(self):
1135
    self.lock.Close()
1136
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1137

    
1138

    
1139
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1140
  TESTDATA = "Hello World\n" * 10
1141

    
1142
  def setUp(self):
1143
    testutils.GanetiTestCase.setUp(self)
1144

    
1145
    self.tmpfile = tempfile.NamedTemporaryFile()
1146
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1147
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1148

    
1149
    # Ensure "Open" didn't truncate file
1150
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1151

    
1152
  def tearDown(self):
1153
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1154

    
1155
    testutils.GanetiTestCase.tearDown(self)
1156

    
1157

    
1158
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1159
  def setUp(self):
1160
    self.tmpfile = tempfile.NamedTemporaryFile()
1161
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1162

    
1163

    
1164
class TestTimeFunctions(unittest.TestCase):
1165
  """Test case for time functions"""
1166

    
1167
  def runTest(self):
1168
    self.assertEqual(utils.SplitTime(1), (1, 0))
1169
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1170
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1171
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1172
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1173
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1174
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1175
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1176

    
1177
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1178

    
1179
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1180
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1181
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1182

    
1183
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1184
                     1218448917.481)
1185
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1186

    
1187
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1188
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1189
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1190
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1191
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1192

    
1193

    
1194
class FieldSetTestCase(unittest.TestCase):
1195
  """Test case for FieldSets"""
1196

    
1197
  def testSimpleMatch(self):
1198
    f = utils.FieldSet("a", "b", "c", "def")
1199
    self.failUnless(f.Matches("a"))
1200
    self.failIf(f.Matches("d"), "Substring matched")
1201
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1202
    self.failIf(f.NonMatching(["b", "c"]))
1203
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1204
    self.failUnless(f.NonMatching(["a", "d"]))
1205

    
1206
  def testRegexMatch(self):
1207
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1208
    self.failUnless(f.Matches("b1"))
1209
    self.failUnless(f.Matches("b99"))
1210
    self.failIf(f.Matches("b/1"))
1211
    self.failIf(f.NonMatching(["b12", "c"]))
1212
    self.failUnless(f.NonMatching(["a", "1"]))
1213

    
1214
class TestForceDictType(unittest.TestCase):
1215
  """Test case for ForceDictType"""
1216

    
1217
  def setUp(self):
1218
    self.key_types = {
1219
      'a': constants.VTYPE_INT,
1220
      'b': constants.VTYPE_BOOL,
1221
      'c': constants.VTYPE_STRING,
1222
      'd': constants.VTYPE_SIZE,
1223
      }
1224

    
1225
  def _fdt(self, dict, allowed_values=None):
1226
    if allowed_values is None:
1227
      ForceDictType(dict, self.key_types)
1228
    else:
1229
      ForceDictType(dict, self.key_types, allowed_values=allowed_values)
1230

    
1231
    return dict
1232

    
1233
  def testSimpleDict(self):
1234
    self.assertEqual(self._fdt({}), {})
1235
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1236
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1237
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1238
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1239
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1240
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1241
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1242
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1243
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1244
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1245
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1246

    
1247
  def testErrors(self):
1248
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1249
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1250
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1251
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1252

    
1253

    
1254
class TestIsAbsNormPath(unittest.TestCase):
1255
  """Testing case for IsProcessAlive"""
1256

    
1257
  def _pathTestHelper(self, path, result):
1258
    if result:
1259
      self.assert_(IsNormAbsPath(path),
1260
          "Path %s should result absolute and normalized" % path)
1261
    else:
1262
      self.assert_(not IsNormAbsPath(path),
1263
          "Path %s should not result absolute and normalized" % path)
1264

    
1265
  def testBase(self):
1266
    self._pathTestHelper('/etc', True)
1267
    self._pathTestHelper('/srv', True)
1268
    self._pathTestHelper('etc', False)
1269
    self._pathTestHelper('/etc/../root', False)
1270
    self._pathTestHelper('/etc/', False)
1271

    
1272

    
1273
class TestSafeEncode(unittest.TestCase):
1274
  """Test case for SafeEncode"""
1275

    
1276
  def testAscii(self):
1277
    for txt in [string.digits, string.letters, string.punctuation]:
1278
      self.failUnlessEqual(txt, SafeEncode(txt))
1279

    
1280
  def testDoubleEncode(self):
1281
    for i in range(255):
1282
      txt = SafeEncode(chr(i))
1283
      self.failUnlessEqual(txt, SafeEncode(txt))
1284

    
1285
  def testUnicode(self):
1286
    # 1024 is high enough to catch non-direct ASCII mappings
1287
    for i in range(1024):
1288
      txt = SafeEncode(unichr(i))
1289
      self.failUnlessEqual(txt, SafeEncode(txt))
1290

    
1291

    
1292
class TestFormatTime(unittest.TestCase):
1293
  """Testing case for FormatTime"""
1294

    
1295
  def testNone(self):
1296
    self.failUnlessEqual(FormatTime(None), "N/A")
1297

    
1298
  def testInvalid(self):
1299
    self.failUnlessEqual(FormatTime(()), "N/A")
1300

    
1301
  def testNow(self):
1302
    # tests that we accept time.time input
1303
    FormatTime(time.time())
1304
    # tests that we accept int input
1305
    FormatTime(int(time.time()))
1306

    
1307

    
1308
class RunInSeparateProcess(unittest.TestCase):
1309
  def test(self):
1310
    for exp in [True, False]:
1311
      def _child():
1312
        return exp
1313

    
1314
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1315

    
1316
  def testArgs(self):
1317
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1318
      def _child(carg1, carg2):
1319
        return carg1 == "Foo" and carg2 == arg
1320

    
1321
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1322

    
1323
  def testPid(self):
1324
    parent_pid = os.getpid()
1325

    
1326
    def _check():
1327
      return os.getpid() == parent_pid
1328

    
1329
    self.failIf(utils.RunInSeparateProcess(_check))
1330

    
1331
  def testSignal(self):
1332
    def _kill():
1333
      os.kill(os.getpid(), signal.SIGTERM)
1334

    
1335
    self.assertRaises(errors.GenericError,
1336
                      utils.RunInSeparateProcess, _kill)
1337

    
1338
  def testException(self):
1339
    def _exc():
1340
      raise errors.GenericError("This is a test")
1341

    
1342
    self.assertRaises(errors.GenericError,
1343
                      utils.RunInSeparateProcess, _exc)
1344

    
1345

    
1346
class TestFingerprintFile(unittest.TestCase):
1347
  def setUp(self):
1348
    self.tmpfile = tempfile.NamedTemporaryFile()
1349

    
1350
  def test(self):
1351
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1352
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")
1353

    
1354
    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
1355
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1356
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1357

    
1358

    
1359
class TestUnescapeAndSplit(unittest.TestCase):
1360
  """Testing case for UnescapeAndSplit"""
1361

    
1362
  def setUp(self):
1363
    # testing more that one separator for regexp safety
1364
    self._seps = [",", "+", "."]
1365

    
1366
  def testSimple(self):
1367
    a = ["a", "b", "c", "d"]
1368
    for sep in self._seps:
1369
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1370

    
1371
  def testEscape(self):
1372
    for sep in self._seps:
1373
      a = ["a", "b\\" + sep + "c", "d"]
1374
      b = ["a", "b" + sep + "c", "d"]
1375
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1376

    
1377
  def testDoubleEscape(self):
1378
    for sep in self._seps:
1379
      a = ["a", "b\\\\", "c", "d"]
1380
      b = ["a", "b\\", "c", "d"]
1381
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1382

    
1383
  def testThreeEscape(self):
1384
    for sep in self._seps:
1385
      a = ["a", "b\\\\\\" + sep + "c", "d"]
1386
      b = ["a", "b\\" + sep + "c", "d"]
1387
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1388

    
1389

    
1390
class TestPathJoin(unittest.TestCase):
1391
  """Testing case for PathJoin"""
1392

    
1393
  def testBasicItems(self):
1394
    mlist = ["/a", "b", "c"]
1395
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1396

    
1397
  def testNonAbsPrefix(self):
1398
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1399

    
1400
  def testBackTrack(self):
1401
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1402

    
1403
  def testMultiAbs(self):
1404
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1405

    
1406

    
1407
class TestHostInfo(unittest.TestCase):
1408
  """Testing case for HostInfo"""
1409

    
1410
  def testUppercase(self):
1411
    data = "AbC.example.com"
1412
    self.failUnlessEqual(HostInfo.NormalizeName(data), data.lower())
1413

    
1414
  def testTooLongName(self):
1415
    data = "a.b." + "c" * 255
1416
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, data)
1417

    
1418
  def testTrailingDot(self):
1419
    data = "a.b.c"
1420
    self.failUnlessEqual(HostInfo.NormalizeName(data + "."), data)
1421

    
1422
  def testInvalidName(self):
1423
    data = [
1424
      "a b",
1425
      "a/b",
1426
      ".a.b",
1427
      "a..b",
1428
      ]
1429
    for value in data:
1430
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, value)
1431

    
1432
  def testValidName(self):
1433
    data = [
1434
      "a.b",
1435
      "a-b",
1436
      "a_b",
1437
      "a.b.c",
1438
      ]
1439
    for value in data:
1440
      HostInfo.NormalizeName(value)
1441

    
1442

    
1443
class TestParseAsn1Generalizedtime(unittest.TestCase):
1444
  def test(self):
1445
    # UTC
1446
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1447
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1448
                     1266860512)
1449
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1450
                     (2**31) - 1)
1451

    
1452
    # With offset
1453
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1454
                     1266860512)
1455
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1456
                     1266931012)
1457
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1458
                     1266931088)
1459
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1460
                     1266931295)
1461
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1462
                     3600)
1463

    
1464
    # Leap seconds are not supported by datetime.datetime
1465
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1466
                      "19841231235960+0000")
1467
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1468
                      "19920630235960+0000")
1469

    
1470
    # Errors
1471
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1472
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1473
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1474
                      "20100222174152")
1475
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1476
                      "Mon Feb 22 17:47:02 UTC 2010")
1477
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1478
                      "2010-02-22 17:42:02")
1479

    
1480

    
1481
class TestGetX509CertValidity(testutils.GanetiTestCase):
1482
  def setUp(self):
1483
    testutils.GanetiTestCase.setUp(self)
1484

    
1485
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1486

    
1487
    # Test whether we have pyOpenSSL 0.7 or above
1488
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1489

    
1490
    if not self.pyopenssl0_7:
1491
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1492
                    " function correctly")
1493

    
1494
  def _LoadCert(self, name):
1495
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1496
                                           self._ReadTestData(name))
1497

    
1498
  def test(self):
1499
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1500
    if self.pyopenssl0_7:
1501
      self.assertEqual(validity, (1266919967, 1267524767))
1502
    else:
1503
      self.assertEqual(validity, (None, None))
1504

    
1505

    
1506
if __name__ == '__main__':
1507
  testutils.GanetiTestProgram()