Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ 79d22269

History | View | Annotate | Download (85.2 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import distutils.version
25
import errno
26
import fcntl
27
import glob
28
import os
29
import os.path
30
import re
31
import shutil
32
import signal
33
import socket
34
import stat
35
import string
36
import tempfile
37
import time
38
import unittest
39
import warnings
40
import OpenSSL
41
import random
42
import operator
43
from cStringIO import StringIO
44

    
45
import testutils
46
from ganeti import constants
47
from ganeti import compat
48
from ganeti import utils
49
from ganeti import errors
50
from ganeti.utils import RunCmd, RemoveFile, MatchNameComponent, FormatUnit, \
51
     ParseUnit, ShellQuote, ShellQuoteArgs, ListVisibleFiles, FirstFree, \
52
     TailFile, SafeEncode, FormatTime, UnescapeAndSplit, RunParts, PathJoin, \
53
     ReadOneLineFile, SetEtcHostsEntry, RemoveEtcHostsEntry
54

    
55

    
56
class TestIsProcessAlive(unittest.TestCase):
57
  """Testing case for IsProcessAlive"""
58

    
59
  def testExists(self):
60
    mypid = os.getpid()
61
    self.assert_(utils.IsProcessAlive(mypid), "can't find myself running")
62

    
63
  def testNotExisting(self):
64
    pid_non_existing = os.fork()
65
    if pid_non_existing == 0:
66
      os._exit(0)
67
    elif pid_non_existing < 0:
68
      raise SystemError("can't fork")
69
    os.waitpid(pid_non_existing, 0)
70
    self.assertFalse(utils.IsProcessAlive(pid_non_existing),
71
                     "nonexisting process detected")
72

    
73

    
74
class TestGetProcStatusPath(unittest.TestCase):
75
  def test(self):
76
    self.assert_("/1234/" in utils._GetProcStatusPath(1234))
77
    self.assertNotEqual(utils._GetProcStatusPath(1),
78
                        utils._GetProcStatusPath(2))
79

    
80

    
81
class TestIsProcessHandlingSignal(unittest.TestCase):
82
  def setUp(self):
83
    self.tmpdir = tempfile.mkdtemp()
84

    
85
  def tearDown(self):
86
    shutil.rmtree(self.tmpdir)
87

    
88
  def testParseSigsetT(self):
89
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
90
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
91
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
92
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
93
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
94
                     set([2, 10, 32, 33]))
95
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
96
                     set([2, 32, 33]))
97
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
98
                     set([2, 28, 32, 33]))
99
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
100
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
101
                          24, 25, 26, 28, 31]))
102
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))
103

    
104
  def testGetProcStatusField(self):
105
    for field in ["SigCgt", "Name", "FDSize"]:
106
      for value in ["", "0", "cat", "  1234 KB"]:
107
        pstatus = "\n".join([
108
          "VmPeak: 999 kB",
109
          "%s: %s" % (field, value),
110
          "TracerPid: 0",
111
          ])
112
        result = utils._GetProcStatusField(pstatus, field)
113
        self.assertEqual(result, value.strip())
114

    
115
  def test(self):
116
    sp = PathJoin(self.tmpdir, "status")
117

    
118
    utils.WriteFile(sp, data="\n".join([
119
      "Name:   bash",
120
      "State:  S (sleeping)",
121
      "SleepAVG:       98%",
122
      "Pid:    22250",
123
      "PPid:   10858",
124
      "TracerPid:      0",
125
      "SigBlk: 0000000000010000",
126
      "SigIgn: 0000000000384004",
127
      "SigCgt: 000000004b813efb",
128
      "CapEff: 0000000000000000",
129
      ]))
130

    
131
    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
132

    
133
  def testNoSigCgt(self):
134
    sp = PathJoin(self.tmpdir, "status")
135

    
136
    utils.WriteFile(sp, data="\n".join([
137
      "Name:   bash",
138
      ]))
139

    
140
    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
141
                      1234, 10, status_path=sp)
142

    
143
  def testNoSuchFile(self):
144
    sp = PathJoin(self.tmpdir, "notexist")
145

    
146
    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
147

    
148
  @staticmethod
149
  def _TestRealProcess():
150
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
151
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
152
      raise Exception("SIGUSR1 is handled when it should not be")
153

    
154
    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
155
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
156
      raise Exception("SIGUSR1 is not handled when it should be")
157

    
158
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
159
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
160
      raise Exception("SIGUSR1 is not handled when it should be")
161

    
162
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
163
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
164
      raise Exception("SIGUSR1 is handled when it should not be")
165

    
166
    return True
167

    
168
  def testRealProcess(self):
169
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
170

    
171

    
172
class TestPidFileFunctions(unittest.TestCase):
173
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
174

    
175
  def setUp(self):
176
    self.dir = tempfile.mkdtemp()
177
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
178
    utils.DaemonPidFileName = self.f_dpn
179

    
180
  def testPidFileFunctions(self):
181
    pid_file = self.f_dpn('test')
182
    fd = utils.WritePidFile(self.f_dpn('test'))
183
    self.failUnless(os.path.exists(pid_file),
184
                    "PID file should have been created")
185
    read_pid = utils.ReadPidFile(pid_file)
186
    self.failUnlessEqual(read_pid, os.getpid())
187
    self.failUnless(utils.IsProcessAlive(read_pid))
188
    self.failUnlessRaises(errors.LockError, utils.WritePidFile,
189
                          self.f_dpn('test'))
190
    os.close(fd)
191
    utils.RemovePidFile('test')
192
    self.failIf(os.path.exists(pid_file),
193
                "PID file should not exist anymore")
194
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
195
                         "ReadPidFile should return 0 for missing pid file")
196
    fh = open(pid_file, "w")
197
    fh.write("blah\n")
198
    fh.close()
199
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
200
                         "ReadPidFile should return 0 for invalid pid file")
201
    # but now, even with the file existing, we should be able to lock it
202
    fd = utils.WritePidFile(self.f_dpn('test'))
203
    os.close(fd)
204
    utils.RemovePidFile('test')
205
    self.failIf(os.path.exists(pid_file),
206
                "PID file should not exist anymore")
207

    
208
  def testKill(self):
209
    pid_file = self.f_dpn('child')
210
    r_fd, w_fd = os.pipe()
211
    new_pid = os.fork()
212
    if new_pid == 0: #child
213
      utils.WritePidFile(self.f_dpn('child'))
214
      os.write(w_fd, 'a')
215
      signal.pause()
216
      os._exit(0)
217
      return
218
    # else we are in the parent
219
    # wait until the child has written the pid file
220
    os.read(r_fd, 1)
221
    read_pid = utils.ReadPidFile(pid_file)
222
    self.failUnlessEqual(read_pid, new_pid)
223
    self.failUnless(utils.IsProcessAlive(new_pid))
224
    utils.KillProcess(new_pid, waitpid=True)
225
    self.failIf(utils.IsProcessAlive(new_pid))
226
    utils.RemovePidFile('child')
227
    self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)
228

    
229
  def tearDown(self):
230
    for name in os.listdir(self.dir):
231
      os.unlink(os.path.join(self.dir, name))
232
    os.rmdir(self.dir)
233

    
234

    
235
class TestRunCmd(testutils.GanetiTestCase):
236
  """Testing case for the RunCmd function"""
237

    
238
  def setUp(self):
239
    testutils.GanetiTestCase.setUp(self)
240
    self.magic = time.ctime() + " ganeti test"
241
    self.fname = self._CreateTempFile()
242
    self.fifo_tmpdir = tempfile.mkdtemp()
243
    self.fifo_file = os.path.join(self.fifo_tmpdir, "ganeti_test_fifo")
244
    os.mkfifo(self.fifo_file)
245

    
246
  def tearDown(self):
247
    shutil.rmtree(self.fifo_tmpdir)
248
    testutils.GanetiTestCase.tearDown(self)
249

    
250
  def testOk(self):
251
    """Test successful exit code"""
252
    result = RunCmd("/bin/sh -c 'exit 0'")
253
    self.assertEqual(result.exit_code, 0)
254
    self.assertEqual(result.output, "")
255

    
256
  def testFail(self):
257
    """Test fail exit code"""
258
    result = RunCmd("/bin/sh -c 'exit 1'")
259
    self.assertEqual(result.exit_code, 1)
260
    self.assertEqual(result.output, "")
261

    
262
  def testStdout(self):
263
    """Test standard output"""
264
    cmd = 'echo -n "%s"' % self.magic
265
    result = RunCmd("/bin/sh -c '%s'" % cmd)
266
    self.assertEqual(result.stdout, self.magic)
267
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
268
    self.assertEqual(result.output, "")
269
    self.assertFileContent(self.fname, self.magic)
270

    
271
  def testStderr(self):
272
    """Test standard error"""
273
    cmd = 'echo -n "%s"' % self.magic
274
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
275
    self.assertEqual(result.stderr, self.magic)
276
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
277
    self.assertEqual(result.output, "")
278
    self.assertFileContent(self.fname, self.magic)
279

    
280
  def testCombined(self):
281
    """Test combined output"""
282
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
283
    expected = "A" + self.magic + "B" + self.magic
284
    result = RunCmd("/bin/sh -c '%s'" % cmd)
285
    self.assertEqual(result.output, expected)
286
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
287
    self.assertEqual(result.output, "")
288
    self.assertFileContent(self.fname, expected)
289

    
290
  def testSignal(self):
291
    """Test signal"""
292
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
293
    self.assertEqual(result.signal, 15)
294
    self.assertEqual(result.output, "")
295

    
296
  def testTimeoutClean(self):
297
    cmd = "trap 'exit 0' TERM; read < %s" % self.fifo_file
298
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
299
    self.assertEqual(result.exit_code, 0)
300

    
301
  def testTimeoutKill(self):
302
    cmd = ["/bin/sh", "-c", "trap '' TERM; read < %s" % self.fifo_file]
303
    timeout = 0.2
304
    out, err, status, ta = utils._RunCmdPipe(cmd, {}, False, "/", False,
305
                                             timeout, _linger_timeout=0.2)
306
    self.assert_(status < 0)
307
    self.assertEqual(-status, signal.SIGKILL)
308

    
309
  def testTimeoutOutputAfterTerm(self):
310
    cmd = "trap 'echo sigtermed; exit 1' TERM; read < %s" % self.fifo_file
311
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
312
    self.assert_(result.failed)
313
    self.assertEqual(result.stdout, "sigtermed\n")
314

    
315
  def testListRun(self):
316
    """Test list runs"""
317
    result = RunCmd(["true"])
318
    self.assertEqual(result.signal, None)
319
    self.assertEqual(result.exit_code, 0)
320
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
321
    self.assertEqual(result.signal, None)
322
    self.assertEqual(result.exit_code, 1)
323
    result = RunCmd(["echo", "-n", self.magic])
324
    self.assertEqual(result.signal, None)
325
    self.assertEqual(result.exit_code, 0)
326
    self.assertEqual(result.stdout, self.magic)
327

    
328
  def testFileEmptyOutput(self):
329
    """Test file output"""
330
    result = RunCmd(["true"], output=self.fname)
331
    self.assertEqual(result.signal, None)
332
    self.assertEqual(result.exit_code, 0)
333
    self.assertFileContent(self.fname, "")
334

    
335
  def testLang(self):
336
    """Test locale environment"""
337
    old_env = os.environ.copy()
338
    try:
339
      os.environ["LANG"] = "en_US.UTF-8"
340
      os.environ["LC_ALL"] = "en_US.UTF-8"
341
      result = RunCmd(["locale"])
342
      for line in result.output.splitlines():
343
        key, value = line.split("=", 1)
344
        # Ignore these variables, they're overridden by LC_ALL
345
        if key == "LANG" or key == "LANGUAGE":
346
          continue
347
        self.failIf(value and value != "C" and value != '"C"',
348
            "Variable %s is set to the invalid value '%s'" % (key, value))
349
    finally:
350
      os.environ = old_env
351

    
352
  def testDefaultCwd(self):
353
    """Test default working directory"""
354
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
355

    
356
  def testCwd(self):
357
    """Test default working directory"""
358
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
359
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
360
    cwd = os.getcwd()
361
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
362

    
363
  def testResetEnv(self):
364
    """Test environment reset functionality"""
365
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
366
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
367
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
368

    
369
  def testNoFork(self):
370
    """Test that nofork raise an error"""
371
    assert not utils.no_fork
372
    utils.no_fork = True
373
    try:
374
      self.assertRaises(errors.ProgrammerError, RunCmd, ["true"])
375
    finally:
376
      utils.no_fork = False
377

    
378
  def testWrongParams(self):
379
    """Test wrong parameters"""
380
    self.assertRaises(errors.ProgrammerError, RunCmd, ["true"],
381
                      output="/dev/null", interactive=True)
382

    
383

    
384
class TestRunParts(testutils.GanetiTestCase):
385
  """Testing case for the RunParts function"""
386

    
387
  def setUp(self):
388
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
389

    
390
  def tearDown(self):
391
    shutil.rmtree(self.rundir)
392

    
393
  def testEmpty(self):
394
    """Test on an empty dir"""
395
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
396

    
397
  def testSkipWrongName(self):
398
    """Test that wrong files are skipped"""
399
    fname = os.path.join(self.rundir, "00test.dot")
400
    utils.WriteFile(fname, data="")
401
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
402
    relname = os.path.basename(fname)
403
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
404
                         [(relname, constants.RUNPARTS_SKIP, None)])
405

    
406
  def testSkipNonExec(self):
407
    """Test that non executable files are skipped"""
408
    fname = os.path.join(self.rundir, "00test")
409
    utils.WriteFile(fname, data="")
410
    relname = os.path.basename(fname)
411
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
412
                         [(relname, constants.RUNPARTS_SKIP, None)])
413

    
414
  def testError(self):
415
    """Test error on a broken executable"""
416
    fname = os.path.join(self.rundir, "00test")
417
    utils.WriteFile(fname, data="")
418
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
419
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
420
    self.failUnlessEqual(relname, os.path.basename(fname))
421
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
422
    self.failUnless(error)
423

    
424
  def testSorted(self):
425
    """Test executions are sorted"""
426
    files = []
427
    files.append(os.path.join(self.rundir, "64test"))
428
    files.append(os.path.join(self.rundir, "00test"))
429
    files.append(os.path.join(self.rundir, "42test"))
430

    
431
    for fname in files:
432
      utils.WriteFile(fname, data="")
433

    
434
    results = RunParts(self.rundir, reset_env=True)
435

    
436
    for fname in sorted(files):
437
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
438

    
439
  def testOk(self):
440
    """Test correct execution"""
441
    fname = os.path.join(self.rundir, "00test")
442
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
443
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
444
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
445
    self.failUnlessEqual(relname, os.path.basename(fname))
446
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
447
    self.failUnlessEqual(runresult.stdout, "ciao")
448

    
449
  def testRunFail(self):
450
    """Test correct execution, with run failure"""
451
    fname = os.path.join(self.rundir, "00test")
452
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
453
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
454
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
455
    self.failUnlessEqual(relname, os.path.basename(fname))
456
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
457
    self.failUnlessEqual(runresult.exit_code, 1)
458
    self.failUnless(runresult.failed)
459

    
460
  def testRunMix(self):
461
    files = []
462
    files.append(os.path.join(self.rundir, "00test"))
463
    files.append(os.path.join(self.rundir, "42test"))
464
    files.append(os.path.join(self.rundir, "64test"))
465
    files.append(os.path.join(self.rundir, "99test"))
466

    
467
    files.sort()
468

    
469
    # 1st has errors in execution
470
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
471
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
472

    
473
    # 2nd is skipped
474
    utils.WriteFile(files[1], data="")
475

    
476
    # 3rd cannot execute properly
477
    utils.WriteFile(files[2], data="")
478
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
479

    
480
    # 4th execs
481
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
482
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
483

    
484
    results = RunParts(self.rundir, reset_env=True)
485

    
486
    (relname, status, runresult) = results[0]
487
    self.failUnlessEqual(relname, os.path.basename(files[0]))
488
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
489
    self.failUnlessEqual(runresult.exit_code, 1)
490
    self.failUnless(runresult.failed)
491

    
492
    (relname, status, runresult) = results[1]
493
    self.failUnlessEqual(relname, os.path.basename(files[1]))
494
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
495
    self.failUnlessEqual(runresult, None)
496

    
497
    (relname, status, runresult) = results[2]
498
    self.failUnlessEqual(relname, os.path.basename(files[2]))
499
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
500
    self.failUnless(runresult)
501

    
502
    (relname, status, runresult) = results[3]
503
    self.failUnlessEqual(relname, os.path.basename(files[3]))
504
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
505
    self.failUnlessEqual(runresult.output, "ciao")
506
    self.failUnlessEqual(runresult.exit_code, 0)
507
    self.failUnless(not runresult.failed)
508

    
509
  def testMissingDirectory(self):
510
    nosuchdir = utils.PathJoin(self.rundir, "no/such/directory")
511
    self.assertEqual(RunParts(nosuchdir), [])
512

    
513

    
514
class TestStartDaemon(testutils.GanetiTestCase):
515
  def setUp(self):
516
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
517
    self.tmpfile = os.path.join(self.tmpdir, "test")
518

    
519
  def tearDown(self):
520
    shutil.rmtree(self.tmpdir)
521

    
522
  def testShell(self):
523
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
524
    self._wait(self.tmpfile, 60.0, "Hello World")
525

    
526
  def testShellOutput(self):
527
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
528
    self._wait(self.tmpfile, 60.0, "Hello World")
529

    
530
  def testNoShellNoOutput(self):
531
    utils.StartDaemon(["pwd"])
532

    
533
  def testNoShellNoOutputTouch(self):
534
    testfile = os.path.join(self.tmpdir, "check")
535
    self.failIf(os.path.exists(testfile))
536
    utils.StartDaemon(["touch", testfile])
537
    self._wait(testfile, 60.0, "")
538

    
539
  def testNoShellOutput(self):
540
    utils.StartDaemon(["pwd"], output=self.tmpfile)
541
    self._wait(self.tmpfile, 60.0, "/")
542

    
543
  def testNoShellOutputCwd(self):
544
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
545
    self._wait(self.tmpfile, 60.0, os.getcwd())
546

    
547
  def testShellEnv(self):
548
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
549
                      env={ "GNT_TEST_VAR": "Hello World", })
550
    self._wait(self.tmpfile, 60.0, "Hello World")
551

    
552
  def testNoShellEnv(self):
553
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
554
                      env={ "GNT_TEST_VAR": "Hello World", })
555
    self._wait(self.tmpfile, 60.0, "Hello World")
556

    
557
  def testOutputFd(self):
558
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
559
    try:
560
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
561
    finally:
562
      os.close(fd)
563
    self._wait(self.tmpfile, 60.0, os.getcwd())
564

    
565
  def testPid(self):
566
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
567
    self._wait(self.tmpfile, 60.0, str(pid))
568

    
569
  def testPidFile(self):
570
    pidfile = os.path.join(self.tmpdir, "pid")
571
    checkfile = os.path.join(self.tmpdir, "abort")
572

    
573
    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
574
                            output=self.tmpfile)
575
    try:
576
      fd = os.open(pidfile, os.O_RDONLY)
577
      try:
578
        # Check file is locked
579
        self.assertRaises(errors.LockError, utils.LockFile, fd)
580

    
581
        pidtext = os.read(fd, 100)
582
      finally:
583
        os.close(fd)
584

    
585
      self.assertEqual(int(pidtext.strip()), pid)
586

    
587
      self.assert_(utils.IsProcessAlive(pid))
588
    finally:
589
      # No matter what happens, kill daemon
590
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
591
      self.failIf(utils.IsProcessAlive(pid))
592

    
593
    self.assertEqual(utils.ReadFile(self.tmpfile), "")
594

    
595
  def _wait(self, path, timeout, expected):
596
    # Due to the asynchronous nature of daemon processes, polling is necessary.
597
    # A timeout makes sure the test doesn't hang forever.
598
    def _CheckFile():
599
      if not (os.path.isfile(path) and
600
              utils.ReadFile(path).strip() == expected):
601
        raise utils.RetryAgain()
602

    
603
    try:
604
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
605
    except utils.RetryTimeout:
606
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
607
                " didn't write the correct output" % timeout)
608

    
609
  def testError(self):
610
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
611
                      ["./does-NOT-EXIST/here/0123456789"])
612
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
613
                      ["./does-NOT-EXIST/here/0123456789"],
614
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
615
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
616
                      ["./does-NOT-EXIST/here/0123456789"],
617
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
618
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
619
                      ["./does-NOT-EXIST/here/0123456789"],
620
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
621

    
622
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
623
    try:
624
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
625
                        ["./does-NOT-EXIST/here/0123456789"],
626
                        output=self.tmpfile, output_fd=fd)
627
    finally:
628
      os.close(fd)
629

    
630

    
631
class TestSetCloseOnExecFlag(unittest.TestCase):
632
  """Tests for SetCloseOnExecFlag"""
633

    
634
  def setUp(self):
635
    self.tmpfile = tempfile.TemporaryFile()
636

    
637
  def testEnable(self):
638
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
639
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
640
                    fcntl.FD_CLOEXEC)
641

    
642
  def testDisable(self):
643
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
644
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
645
                fcntl.FD_CLOEXEC)
646

    
647

    
648
class TestSetNonblockFlag(unittest.TestCase):
649
  def setUp(self):
650
    self.tmpfile = tempfile.TemporaryFile()
651

    
652
  def testEnable(self):
653
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
654
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
655
                    os.O_NONBLOCK)
656

    
657
  def testDisable(self):
658
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
659
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
660
                os.O_NONBLOCK)
661

    
662

    
663
class TestRemoveFile(unittest.TestCase):
664
  """Test case for the RemoveFile function"""
665

    
666
  def setUp(self):
667
    """Create a temp dir and file for each case"""
668
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
669
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
670
    os.close(fd)
671

    
672
  def tearDown(self):
673
    if os.path.exists(self.tmpfile):
674
      os.unlink(self.tmpfile)
675
    os.rmdir(self.tmpdir)
676

    
677
  def testIgnoreDirs(self):
678
    """Test that RemoveFile() ignores directories"""
679
    self.assertEqual(None, RemoveFile(self.tmpdir))
680

    
681
  def testIgnoreNotExisting(self):
682
    """Test that RemoveFile() ignores non-existing files"""
683
    RemoveFile(self.tmpfile)
684
    RemoveFile(self.tmpfile)
685

    
686
  def testRemoveFile(self):
687
    """Test that RemoveFile does remove a file"""
688
    RemoveFile(self.tmpfile)
689
    if os.path.exists(self.tmpfile):
690
      self.fail("File '%s' not removed" % self.tmpfile)
691

    
692
  def testRemoveSymlink(self):
693
    """Test that RemoveFile does remove symlinks"""
694
    symlink = self.tmpdir + "/symlink"
695
    os.symlink("no-such-file", symlink)
696
    RemoveFile(symlink)
697
    if os.path.exists(symlink):
698
      self.fail("File '%s' not removed" % symlink)
699
    os.symlink(self.tmpfile, symlink)
700
    RemoveFile(symlink)
701
    if os.path.exists(symlink):
702
      self.fail("File '%s' not removed" % symlink)
703

    
704

    
705
class TestRemoveDir(unittest.TestCase):
706
  def setUp(self):
707
    self.tmpdir = tempfile.mkdtemp()
708

    
709
  def tearDown(self):
710
    try:
711
      shutil.rmtree(self.tmpdir)
712
    except EnvironmentError:
713
      pass
714

    
715
  def testEmptyDir(self):
716
    utils.RemoveDir(self.tmpdir)
717
    self.assertFalse(os.path.isdir(self.tmpdir))
718

    
719
  def testNonEmptyDir(self):
720
    self.tmpfile = os.path.join(self.tmpdir, "test1")
721
    open(self.tmpfile, "w").close()
722
    self.assertRaises(EnvironmentError, utils.RemoveDir, self.tmpdir)
723

    
724

    
725
class TestRename(unittest.TestCase):
726
  """Test case for RenameFile"""
727

    
728
  def setUp(self):
729
    """Create a temporary directory"""
730
    self.tmpdir = tempfile.mkdtemp()
731
    self.tmpfile = os.path.join(self.tmpdir, "test1")
732

    
733
    # Touch the file
734
    open(self.tmpfile, "w").close()
735

    
736
  def tearDown(self):
737
    """Remove temporary directory"""
738
    shutil.rmtree(self.tmpdir)
739

    
740
  def testSimpleRename1(self):
741
    """Simple rename 1"""
742
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
743
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
744

    
745
  def testSimpleRename2(self):
746
    """Simple rename 2"""
747
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
748
                     mkdir=True)
749
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
750

    
751
  def testRenameMkdir(self):
752
    """Rename with mkdir"""
753
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
754
                     mkdir=True)
755
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
756
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
757

    
758
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
759
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
760
                     mkdir=True)
761
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
762
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
763
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
764

    
765

    
766
class TestMatchNameComponent(unittest.TestCase):
767
  """Test case for the MatchNameComponent function"""
768

    
769
  def testEmptyList(self):
770
    """Test that there is no match against an empty list"""
771

    
772
    self.failUnlessEqual(MatchNameComponent("", []), None)
773
    self.failUnlessEqual(MatchNameComponent("test", []), None)
774

    
775
  def testSingleMatch(self):
776
    """Test that a single match is performed correctly"""
777
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
778
    for key in "test2", "test2.example", "test2.example.com":
779
      self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
780

    
781
  def testMultipleMatches(self):
782
    """Test that a multiple match is returned as None"""
783
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
784
    for key in "test1", "test1.example":
785
      self.failUnlessEqual(MatchNameComponent(key, mlist), None)
786

    
787
  def testFullMatch(self):
788
    """Test that a full match is returned correctly"""
789
    key1 = "test1"
790
    key2 = "test1.example"
791
    mlist = [key2, key2 + ".com"]
792
    self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
793
    self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
794

    
795
  def testCaseInsensitivePartialMatch(self):
796
    """Test for the case_insensitive keyword"""
797
    mlist = ["test1.example.com", "test2.example.net"]
798
    self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
799
                     "test2.example.net")
800
    self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
801
                     "test2.example.net")
802
    self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
803
                     "test2.example.net")
804
    self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
805
                     "test2.example.net")
806

    
807

    
808
  def testCaseInsensitiveFullMatch(self):
809
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
810
    # Between the two ts1 a full string match non-case insensitive should work
811
    self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
812
                     None)
813
    self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
814
                     "ts1.ex")
815
    self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
816
                     "ts1.ex")
817
    # Between the two ts2 only case differs, so only case-match works
818
    self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
819
                     "ts2.ex")
820
    self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
821
                     "Ts2.ex")
822
    self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
823
                     None)
824

    
825

    
826
class TestReadFile(testutils.GanetiTestCase):
827

    
828
  def testReadAll(self):
829
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
830
    self.assertEqual(len(data), 814)
831

    
832
    h = compat.md5_hash()
833
    h.update(data)
834
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
835

    
836
  def testReadSize(self):
837
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
838
                          size=100)
839
    self.assertEqual(len(data), 100)
840

    
841
    h = compat.md5_hash()
842
    h.update(data)
843
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
844

    
845
  def testError(self):
846
    self.assertRaises(EnvironmentError, utils.ReadFile,
847
                      "/dev/null/does-not-exist")
848

    
849

    
850
class TestReadOneLineFile(testutils.GanetiTestCase):
851

    
852
  def setUp(self):
853
    testutils.GanetiTestCase.setUp(self)
854

    
855
  def testDefault(self):
856
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
857
    self.assertEqual(len(data), 27)
858
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
859

    
860
  def testNotStrict(self):
861
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
862
    self.assertEqual(len(data), 27)
863
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
864

    
865
  def testStrictFailure(self):
866
    self.assertRaises(errors.GenericError, ReadOneLineFile,
867
                      self._TestDataFilename("cert1.pem"), strict=True)
868

    
869
  def testLongLine(self):
870
    dummydata = (1024 * "Hello World! ")
871
    myfile = self._CreateTempFile()
872
    utils.WriteFile(myfile, data=dummydata)
873
    datastrict = ReadOneLineFile(myfile, strict=True)
874
    datalax = ReadOneLineFile(myfile, strict=False)
875
    self.assertEqual(dummydata, datastrict)
876
    self.assertEqual(dummydata, datalax)
877

    
878
  def testNewline(self):
879
    myfile = self._CreateTempFile()
880
    myline = "myline"
881
    for nl in ["", "\n", "\r\n"]:
882
      dummydata = "%s%s" % (myline, nl)
883
      utils.WriteFile(myfile, data=dummydata)
884
      datalax = ReadOneLineFile(myfile, strict=False)
885
      self.assertEqual(myline, datalax)
886
      datastrict = ReadOneLineFile(myfile, strict=True)
887
      self.assertEqual(myline, datastrict)
888

    
889
  def testWhitespaceAndMultipleLines(self):
890
    myfile = self._CreateTempFile()
891
    for nl in ["", "\n", "\r\n"]:
892
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
893
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
894
        utils.WriteFile(myfile, data=dummydata)
895
        datalax = ReadOneLineFile(myfile, strict=False)
896
        if nl:
897
          self.assert_(set("\r\n") & set(dummydata))
898
          self.assertRaises(errors.GenericError, ReadOneLineFile,
899
                            myfile, strict=True)
900
          explen = len("Foo bar baz ") + len(ws)
901
          self.assertEqual(len(datalax), explen)
902
          self.assertEqual(datalax, dummydata[:explen])
903
          self.assertFalse(set("\r\n") & set(datalax))
904
        else:
905
          datastrict = ReadOneLineFile(myfile, strict=True)
906
          self.assertEqual(dummydata, datastrict)
907
          self.assertEqual(dummydata, datalax)
908

    
909
  def testEmptylines(self):
910
    myfile = self._CreateTempFile()
911
    myline = "myline"
912
    for nl in ["\n", "\r\n"]:
913
      for ol in ["", "otherline"]:
914
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
915
        utils.WriteFile(myfile, data=dummydata)
916
        self.assert_(set("\r\n") & set(dummydata))
917
        datalax = ReadOneLineFile(myfile, strict=False)
918
        self.assertEqual(myline, datalax)
919
        if ol:
920
          self.assertRaises(errors.GenericError, ReadOneLineFile,
921
                            myfile, strict=True)
922
        else:
923
          datastrict = ReadOneLineFile(myfile, strict=True)
924
          self.assertEqual(myline, datastrict)
925

    
926
  def testEmptyfile(self):
927
    myfile = self._CreateTempFile()
928
    self.assertRaises(errors.GenericError, ReadOneLineFile, myfile)
929

    
930

    
931
class TestTimestampForFilename(unittest.TestCase):
932
  def test(self):
933
    self.assert_("." not in utils.TimestampForFilename())
934
    self.assert_(":" not in utils.TimestampForFilename())
935

    
936

    
937
class TestCreateBackup(testutils.GanetiTestCase):
938
  def setUp(self):
939
    testutils.GanetiTestCase.setUp(self)
940

    
941
    self.tmpdir = tempfile.mkdtemp()
942

    
943
  def tearDown(self):
944
    testutils.GanetiTestCase.tearDown(self)
945

    
946
    shutil.rmtree(self.tmpdir)
947

    
948
  def testEmpty(self):
949
    filename = PathJoin(self.tmpdir, "config.data")
950
    utils.WriteFile(filename, data="")
951
    bname = utils.CreateBackup(filename)
952
    self.assertFileContent(bname, "")
953
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
954
    utils.CreateBackup(filename)
955
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
956
    utils.CreateBackup(filename)
957
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
958

    
959
    fifoname = PathJoin(self.tmpdir, "fifo")
960
    os.mkfifo(fifoname)
961
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
962

    
963
  def testContent(self):
964
    bkpcount = 0
965
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
966
      for rep in [1, 2, 10, 127]:
967
        testdata = data * rep
968

    
969
        filename = PathJoin(self.tmpdir, "test.data_")
970
        utils.WriteFile(filename, data=testdata)
971
        self.assertFileContent(filename, testdata)
972

    
973
        for _ in range(3):
974
          bname = utils.CreateBackup(filename)
975
          bkpcount += 1
976
          self.assertFileContent(bname, testdata)
977
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
978

    
979

    
980
class TestFormatUnit(unittest.TestCase):
981
  """Test case for the FormatUnit function"""
982

    
983
  def testMiB(self):
984
    self.assertEqual(FormatUnit(1, 'h'), '1M')
985
    self.assertEqual(FormatUnit(100, 'h'), '100M')
986
    self.assertEqual(FormatUnit(1023, 'h'), '1023M')
987

    
988
    self.assertEqual(FormatUnit(1, 'm'), '1')
989
    self.assertEqual(FormatUnit(100, 'm'), '100')
990
    self.assertEqual(FormatUnit(1023, 'm'), '1023')
991

    
992
    self.assertEqual(FormatUnit(1024, 'm'), '1024')
993
    self.assertEqual(FormatUnit(1536, 'm'), '1536')
994
    self.assertEqual(FormatUnit(17133, 'm'), '17133')
995
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
996

    
997
  def testGiB(self):
998
    self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
999
    self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
1000
    self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
1001
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
1002

    
1003
    self.assertEqual(FormatUnit(1024, 'g'), '1.0')
1004
    self.assertEqual(FormatUnit(1536, 'g'), '1.5')
1005
    self.assertEqual(FormatUnit(17133, 'g'), '16.7')
1006
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
1007

    
1008
    self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
1009
    self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
1010
    self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
1011

    
1012
  def testTiB(self):
1013
    self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
1014
    self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
1015
    self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
1016

    
1017
    self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
1018
    self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
1019
    self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
1020

    
1021
  def testErrors(self):
1022
    self.assertRaises(errors.ProgrammerError, FormatUnit, 1, "a")
1023

    
1024

    
1025
class TestParseUnit(unittest.TestCase):
1026
  """Test case for the ParseUnit function"""
1027

    
1028
  SCALES = (('', 1),
1029
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
1030
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
1031
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
1032

    
1033
  def testRounding(self):
1034
    self.assertEqual(ParseUnit('0'), 0)
1035
    self.assertEqual(ParseUnit('1'), 4)
1036
    self.assertEqual(ParseUnit('2'), 4)
1037
    self.assertEqual(ParseUnit('3'), 4)
1038

    
1039
    self.assertEqual(ParseUnit('124'), 124)
1040
    self.assertEqual(ParseUnit('125'), 128)
1041
    self.assertEqual(ParseUnit('126'), 128)
1042
    self.assertEqual(ParseUnit('127'), 128)
1043
    self.assertEqual(ParseUnit('128'), 128)
1044
    self.assertEqual(ParseUnit('129'), 132)
1045
    self.assertEqual(ParseUnit('130'), 132)
1046

    
1047
  def testFloating(self):
1048
    self.assertEqual(ParseUnit('0'), 0)
1049
    self.assertEqual(ParseUnit('0.5'), 4)
1050
    self.assertEqual(ParseUnit('1.75'), 4)
1051
    self.assertEqual(ParseUnit('1.99'), 4)
1052
    self.assertEqual(ParseUnit('2.00'), 4)
1053
    self.assertEqual(ParseUnit('2.01'), 4)
1054
    self.assertEqual(ParseUnit('3.99'), 4)
1055
    self.assertEqual(ParseUnit('4.00'), 4)
1056
    self.assertEqual(ParseUnit('4.01'), 8)
1057
    self.assertEqual(ParseUnit('1.5G'), 1536)
1058
    self.assertEqual(ParseUnit('1.8G'), 1844)
1059
    self.assertEqual(ParseUnit('8.28T'), 8682212)
1060

    
1061
  def testSuffixes(self):
1062
    for sep in ('', ' ', '   ', "\t", "\t "):
1063
      for suffix, scale in TestParseUnit.SCALES:
1064
        for func in (lambda x: x, str.lower, str.upper):
1065
          self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
1066
                           1024 * scale)
1067

    
1068
  def testInvalidInput(self):
1069
    for sep in ('-', '_', ',', 'a'):
1070
      for suffix, _ in TestParseUnit.SCALES:
1071
        self.assertRaises(errors.UnitParseError, ParseUnit, '1' + sep + suffix)
1072

    
1073
    for suffix, _ in TestParseUnit.SCALES:
1074
      self.assertRaises(errors.UnitParseError, ParseUnit, '1,3' + suffix)
1075

    
1076

    
1077
class TestParseCpuMask(unittest.TestCase):
1078
  """Test case for the ParseCpuMask function."""
1079

    
1080
  def testWellFormed(self):
1081
    self.assertEqual(utils.ParseCpuMask(""), [])
1082
    self.assertEqual(utils.ParseCpuMask("1"), [1])
1083
    self.assertEqual(utils.ParseCpuMask("0-2,4,5-5"), [0,1,2,4,5])
1084

    
1085
  def testInvalidInput(self):
1086
    for data in ["garbage", "0,", "0-1-2", "2-1", "1-a"]:
1087
      self.assertRaises(errors.ParseError, utils.ParseCpuMask, data)
1088

    
1089

    
1090
class TestSshKeys(testutils.GanetiTestCase):
1091
  """Test case for the AddAuthorizedKey function"""
1092

    
1093
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
1094
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
1095
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
1096

    
1097
  def setUp(self):
1098
    testutils.GanetiTestCase.setUp(self)
1099
    self.tmpname = self._CreateTempFile()
1100
    handle = open(self.tmpname, 'w')
1101
    try:
1102
      handle.write("%s\n" % TestSshKeys.KEY_A)
1103
      handle.write("%s\n" % TestSshKeys.KEY_B)
1104
    finally:
1105
      handle.close()
1106

    
1107
  def testAddingNewKey(self):
1108
    utils.AddAuthorizedKey(self.tmpname,
1109
                           'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
1110

    
1111
    self.assertFileContent(self.tmpname,
1112
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1113
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1114
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1115
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
1116

    
1117
  def testAddingAlmostButNotCompletelyTheSameKey(self):
1118
    utils.AddAuthorizedKey(self.tmpname,
1119
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
1120

    
1121
    self.assertFileContent(self.tmpname,
1122
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1123
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1124
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1125
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
1126

    
1127
  def testAddingExistingKeyWithSomeMoreSpaces(self):
1128
    utils.AddAuthorizedKey(self.tmpname,
1129
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1130

    
1131
    self.assertFileContent(self.tmpname,
1132
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1133
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1134
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1135

    
1136
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
1137
    utils.RemoveAuthorizedKey(self.tmpname,
1138
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1139

    
1140
    self.assertFileContent(self.tmpname,
1141
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1142
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1143

    
1144
  def testRemovingNonExistingKey(self):
1145
    utils.RemoveAuthorizedKey(self.tmpname,
1146
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
1147

    
1148
    self.assertFileContent(self.tmpname,
1149
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1150
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1151
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1152

    
1153

    
1154
class TestEtcHosts(testutils.GanetiTestCase):
1155
  """Test functions modifying /etc/hosts"""
1156

    
1157
  def setUp(self):
1158
    testutils.GanetiTestCase.setUp(self)
1159
    self.tmpname = self._CreateTempFile()
1160
    handle = open(self.tmpname, 'w')
1161
    try:
1162
      handle.write('# This is a test file for /etc/hosts\n')
1163
      handle.write('127.0.0.1\tlocalhost\n')
1164
      handle.write('192.0.2.1 router gw\n')
1165
    finally:
1166
      handle.close()
1167

    
1168
  def testSettingNewIp(self):
1169
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost.example.com',
1170
                     ['myhost'])
1171

    
1172
    self.assertFileContent(self.tmpname,
1173
      "# This is a test file for /etc/hosts\n"
1174
      "127.0.0.1\tlocalhost\n"
1175
      "192.0.2.1 router gw\n"
1176
      "198.51.100.4\tmyhost.example.com myhost\n")
1177
    self.assertFileMode(self.tmpname, 0644)
1178

    
1179
  def testSettingExistingIp(self):
1180
    SetEtcHostsEntry(self.tmpname, '192.0.2.1', 'myhost.example.com',
1181
                     ['myhost'])
1182

    
1183
    self.assertFileContent(self.tmpname,
1184
      "# This is a test file for /etc/hosts\n"
1185
      "127.0.0.1\tlocalhost\n"
1186
      "192.0.2.1\tmyhost.example.com myhost\n")
1187
    self.assertFileMode(self.tmpname, 0644)
1188

    
1189
  def testSettingDuplicateName(self):
1190
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost', ['myhost'])
1191

    
1192
    self.assertFileContent(self.tmpname,
1193
      "# This is a test file for /etc/hosts\n"
1194
      "127.0.0.1\tlocalhost\n"
1195
      "192.0.2.1 router gw\n"
1196
      "198.51.100.4\tmyhost\n")
1197
    self.assertFileMode(self.tmpname, 0644)
1198

    
1199
  def testRemovingExistingHost(self):
1200
    RemoveEtcHostsEntry(self.tmpname, 'router')
1201

    
1202
    self.assertFileContent(self.tmpname,
1203
      "# This is a test file for /etc/hosts\n"
1204
      "127.0.0.1\tlocalhost\n"
1205
      "192.0.2.1 gw\n")
1206
    self.assertFileMode(self.tmpname, 0644)
1207

    
1208
  def testRemovingSingleExistingHost(self):
1209
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
1210

    
1211
    self.assertFileContent(self.tmpname,
1212
      "# This is a test file for /etc/hosts\n"
1213
      "192.0.2.1 router gw\n")
1214
    self.assertFileMode(self.tmpname, 0644)
1215

    
1216
  def testRemovingNonExistingHost(self):
1217
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
1218

    
1219
    self.assertFileContent(self.tmpname,
1220
      "# This is a test file for /etc/hosts\n"
1221
      "127.0.0.1\tlocalhost\n"
1222
      "192.0.2.1 router gw\n")
1223
    self.assertFileMode(self.tmpname, 0644)
1224

    
1225
  def testRemovingAlias(self):
1226
    RemoveEtcHostsEntry(self.tmpname, 'gw')
1227

    
1228
    self.assertFileContent(self.tmpname,
1229
      "# This is a test file for /etc/hosts\n"
1230
      "127.0.0.1\tlocalhost\n"
1231
      "192.0.2.1 router\n")
1232
    self.assertFileMode(self.tmpname, 0644)
1233

    
1234

    
1235
class TestGetMounts(unittest.TestCase):
1236
  """Test case for GetMounts()."""
1237

    
1238
  TESTDATA = (
1239
    "rootfs /     rootfs rw 0 0\n"
1240
    "none   /sys  sysfs  rw,nosuid,nodev,noexec,relatime 0 0\n"
1241
    "none   /proc proc   rw,nosuid,nodev,noexec,relatime 0 0\n")
1242

    
1243
  def setUp(self):
1244
    self.tmpfile = tempfile.NamedTemporaryFile()
1245
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1246

    
1247
  def testGetMounts(self):
1248
    self.assertEqual(utils.GetMounts(filename=self.tmpfile.name),
1249
      [
1250
        ("rootfs", "/", "rootfs", "rw"),
1251
        ("none", "/sys", "sysfs", "rw,nosuid,nodev,noexec,relatime"),
1252
        ("none", "/proc", "proc", "rw,nosuid,nodev,noexec,relatime"),
1253
      ])
1254

    
1255

    
1256
class TestShellQuoting(unittest.TestCase):
1257
  """Test case for shell quoting functions"""
1258

    
1259
  def testShellQuote(self):
1260
    self.assertEqual(ShellQuote('abc'), "abc")
1261
    self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
1262
    self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
1263
    self.assertEqual(ShellQuote("a b c"), "'a b c'")
1264
    self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
1265

    
1266
  def testShellQuoteArgs(self):
1267
    self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
1268
    self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
1269
    self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
1270

    
1271

    
1272
class TestListVisibleFiles(unittest.TestCase):
1273
  """Test case for ListVisibleFiles"""
1274

    
1275
  def setUp(self):
1276
    self.path = tempfile.mkdtemp()
1277

    
1278
  def tearDown(self):
1279
    shutil.rmtree(self.path)
1280

    
1281
  def _CreateFiles(self, files):
1282
    for name in files:
1283
      utils.WriteFile(os.path.join(self.path, name), data="test")
1284

    
1285
  def _test(self, files, expected):
1286
    self._CreateFiles(files)
1287
    found = ListVisibleFiles(self.path)
1288
    self.assertEqual(set(found), set(expected))
1289

    
1290
  def testAllVisible(self):
1291
    files = ["a", "b", "c"]
1292
    expected = files
1293
    self._test(files, expected)
1294

    
1295
  def testNoneVisible(self):
1296
    files = [".a", ".b", ".c"]
1297
    expected = []
1298
    self._test(files, expected)
1299

    
1300
  def testSomeVisible(self):
1301
    files = ["a", "b", ".c"]
1302
    expected = ["a", "b"]
1303
    self._test(files, expected)
1304

    
1305
  def testNonAbsolutePath(self):
1306
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1307

    
1308
  def testNonNormalizedPath(self):
1309
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1310
                          "/bin/../tmp")
1311

    
1312

    
1313
class TestNewUUID(unittest.TestCase):
1314
  """Test case for NewUUID"""
1315

    
1316
  def runTest(self):
1317
    self.failUnless(utils.UUID_RE.match(utils.NewUUID()))
1318

    
1319

    
1320
class TestFirstFree(unittest.TestCase):
1321
  """Test case for the FirstFree function"""
1322

    
1323
  def test(self):
1324
    """Test FirstFree"""
1325
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1326
    self.failUnlessEqual(FirstFree([]), None)
1327
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1328
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1329
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1330

    
1331

    
1332
class TestTailFile(testutils.GanetiTestCase):
1333
  """Test case for the TailFile function"""
1334

    
1335
  def testEmpty(self):
1336
    fname = self._CreateTempFile()
1337
    self.failUnlessEqual(TailFile(fname), [])
1338
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1339

    
1340
  def testAllLines(self):
1341
    data = ["test %d" % i for i in range(30)]
1342
    for i in range(30):
1343
      fname = self._CreateTempFile()
1344
      fd = open(fname, "w")
1345
      fd.write("\n".join(data[:i]))
1346
      if i > 0:
1347
        fd.write("\n")
1348
      fd.close()
1349
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1350

    
1351
  def testPartialLines(self):
1352
    data = ["test %d" % i for i in range(30)]
1353
    fname = self._CreateTempFile()
1354
    fd = open(fname, "w")
1355
    fd.write("\n".join(data))
1356
    fd.write("\n")
1357
    fd.close()
1358
    for i in range(1, 30):
1359
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1360

    
1361
  def testBigFile(self):
1362
    data = ["test %d" % i for i in range(30)]
1363
    fname = self._CreateTempFile()
1364
    fd = open(fname, "w")
1365
    fd.write("X" * 1048576)
1366
    fd.write("\n")
1367
    fd.write("\n".join(data))
1368
    fd.write("\n")
1369
    fd.close()
1370
    for i in range(1, 30):
1371
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1372

    
1373

    
1374
class _BaseFileLockTest:
1375
  """Test case for the FileLock class"""
1376

    
1377
  def testSharedNonblocking(self):
1378
    self.lock.Shared(blocking=False)
1379
    self.lock.Close()
1380

    
1381
  def testExclusiveNonblocking(self):
1382
    self.lock.Exclusive(blocking=False)
1383
    self.lock.Close()
1384

    
1385
  def testUnlockNonblocking(self):
1386
    self.lock.Unlock(blocking=False)
1387
    self.lock.Close()
1388

    
1389
  def testSharedBlocking(self):
1390
    self.lock.Shared(blocking=True)
1391
    self.lock.Close()
1392

    
1393
  def testExclusiveBlocking(self):
1394
    self.lock.Exclusive(blocking=True)
1395
    self.lock.Close()
1396

    
1397
  def testUnlockBlocking(self):
1398
    self.lock.Unlock(blocking=True)
1399
    self.lock.Close()
1400

    
1401
  def testSharedExclusiveUnlock(self):
1402
    self.lock.Shared(blocking=False)
1403
    self.lock.Exclusive(blocking=False)
1404
    self.lock.Unlock(blocking=False)
1405
    self.lock.Close()
1406

    
1407
  def testExclusiveSharedUnlock(self):
1408
    self.lock.Exclusive(blocking=False)
1409
    self.lock.Shared(blocking=False)
1410
    self.lock.Unlock(blocking=False)
1411
    self.lock.Close()
1412

    
1413
  def testSimpleTimeout(self):
1414
    # These will succeed on the first attempt, hence a short timeout
1415
    self.lock.Shared(blocking=True, timeout=10.0)
1416
    self.lock.Exclusive(blocking=False, timeout=10.0)
1417
    self.lock.Unlock(blocking=True, timeout=10.0)
1418
    self.lock.Close()
1419

    
1420
  @staticmethod
1421
  def _TryLockInner(filename, shared, blocking):
1422
    lock = utils.FileLock.Open(filename)
1423

    
1424
    if shared:
1425
      fn = lock.Shared
1426
    else:
1427
      fn = lock.Exclusive
1428

    
1429
    try:
1430
      # The timeout doesn't really matter as the parent process waits for us to
1431
      # finish anyway.
1432
      fn(blocking=blocking, timeout=0.01)
1433
    except errors.LockError, err:
1434
      return False
1435

    
1436
    return True
1437

    
1438
  def _TryLock(self, *args):
1439
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1440
                                      *args)
1441

    
1442
  def testTimeout(self):
1443
    for blocking in [True, False]:
1444
      self.lock.Exclusive(blocking=True)
1445
      self.failIf(self._TryLock(False, blocking))
1446
      self.failIf(self._TryLock(True, blocking))
1447

    
1448
      self.lock.Shared(blocking=True)
1449
      self.assert_(self._TryLock(True, blocking))
1450
      self.failIf(self._TryLock(False, blocking))
1451

    
1452
  def testCloseShared(self):
1453
    self.lock.Close()
1454
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1455

    
1456
  def testCloseExclusive(self):
1457
    self.lock.Close()
1458
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1459

    
1460
  def testCloseUnlock(self):
1461
    self.lock.Close()
1462
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1463

    
1464

    
1465
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1466
  TESTDATA = "Hello World\n" * 10
1467

    
1468
  def setUp(self):
1469
    testutils.GanetiTestCase.setUp(self)
1470

    
1471
    self.tmpfile = tempfile.NamedTemporaryFile()
1472
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1473
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1474

    
1475
    # Ensure "Open" didn't truncate file
1476
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1477

    
1478
  def tearDown(self):
1479
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1480

    
1481
    testutils.GanetiTestCase.tearDown(self)
1482

    
1483

    
1484
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1485
  def setUp(self):
1486
    self.tmpfile = tempfile.NamedTemporaryFile()
1487
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1488

    
1489

    
1490
class TestTimeFunctions(unittest.TestCase):
1491
  """Test case for time functions"""
1492

    
1493
  def runTest(self):
1494
    self.assertEqual(utils.SplitTime(1), (1, 0))
1495
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1496
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1497
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1498
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1499
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1500
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1501
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1502

    
1503
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1504

    
1505
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1506
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1507
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1508

    
1509
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1510
                     1218448917.481)
1511
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1512

    
1513
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1514
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1515
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1516
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1517
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1518

    
1519

    
1520
class FieldSetTestCase(unittest.TestCase):
1521
  """Test case for FieldSets"""
1522

    
1523
  def testSimpleMatch(self):
1524
    f = utils.FieldSet("a", "b", "c", "def")
1525
    self.failUnless(f.Matches("a"))
1526
    self.failIf(f.Matches("d"), "Substring matched")
1527
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1528
    self.failIf(f.NonMatching(["b", "c"]))
1529
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1530
    self.failUnless(f.NonMatching(["a", "d"]))
1531

    
1532
  def testRegexMatch(self):
1533
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1534
    self.failUnless(f.Matches("b1"))
1535
    self.failUnless(f.Matches("b99"))
1536
    self.failIf(f.Matches("b/1"))
1537
    self.failIf(f.NonMatching(["b12", "c"]))
1538
    self.failUnless(f.NonMatching(["a", "1"]))
1539

    
1540
class TestForceDictType(unittest.TestCase):
1541
  """Test case for ForceDictType"""
1542
  KEY_TYPES = {
1543
    "a": constants.VTYPE_INT,
1544
    "b": constants.VTYPE_BOOL,
1545
    "c": constants.VTYPE_STRING,
1546
    "d": constants.VTYPE_SIZE,
1547
    "e": constants.VTYPE_MAYBE_STRING,
1548
    }
1549

    
1550
  def _fdt(self, dict, allowed_values=None):
1551
    if allowed_values is None:
1552
      utils.ForceDictType(dict, self.KEY_TYPES)
1553
    else:
1554
      utils.ForceDictType(dict, self.KEY_TYPES, allowed_values=allowed_values)
1555

    
1556
    return dict
1557

    
1558
  def testSimpleDict(self):
1559
    self.assertEqual(self._fdt({}), {})
1560
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1561
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1562
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1563
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1564
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1565
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1566
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1567
    self.assertEqual(self._fdt({'b': False}), {'b': False})
1568
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1569
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1570
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1571
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1572
    self.assertEqual(self._fdt({"e": None, }), {"e": None, })
1573
    self.assertEqual(self._fdt({"e": "Hello World", }), {"e": "Hello World", })
1574
    self.assertEqual(self._fdt({"e": False, }), {"e": '', })
1575
    self.assertEqual(self._fdt({"b": "hello", }, ["hello"]), {"b": "hello"})
1576

    
1577
  def testErrors(self):
1578
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1579
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"b": "hello"})
1580
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1581
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1582
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1583
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": object(), })
1584
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": [], })
1585
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"x": None, })
1586
    self.assertRaises(errors.TypeEnforcementError, self._fdt, [])
1587
    self.assertRaises(errors.ProgrammerError, utils.ForceDictType,
1588
                      {"b": "hello"}, {"b": "no-such-type"})
1589

    
1590

    
1591
class TestIsNormAbsPath(unittest.TestCase):
1592
  """Testing case for IsNormAbsPath"""
1593

    
1594
  def _pathTestHelper(self, path, result):
1595
    if result:
1596
      self.assert_(utils.IsNormAbsPath(path),
1597
          "Path %s should result absolute and normalized" % path)
1598
    else:
1599
      self.assertFalse(utils.IsNormAbsPath(path),
1600
          "Path %s should not result absolute and normalized" % path)
1601

    
1602
  def testBase(self):
1603
    self._pathTestHelper('/etc', True)
1604
    self._pathTestHelper('/srv', True)
1605
    self._pathTestHelper('etc', False)
1606
    self._pathTestHelper('/etc/../root', False)
1607
    self._pathTestHelper('/etc/', False)
1608

    
1609

    
1610
class TestSafeEncode(unittest.TestCase):
1611
  """Test case for SafeEncode"""
1612

    
1613
  def testAscii(self):
1614
    for txt in [string.digits, string.letters, string.punctuation]:
1615
      self.failUnlessEqual(txt, SafeEncode(txt))
1616

    
1617
  def testDoubleEncode(self):
1618
    for i in range(255):
1619
      txt = SafeEncode(chr(i))
1620
      self.failUnlessEqual(txt, SafeEncode(txt))
1621

    
1622
  def testUnicode(self):
1623
    # 1024 is high enough to catch non-direct ASCII mappings
1624
    for i in range(1024):
1625
      txt = SafeEncode(unichr(i))
1626
      self.failUnlessEqual(txt, SafeEncode(txt))
1627

    
1628

    
1629
class TestFormatTime(unittest.TestCase):
1630
  """Testing case for FormatTime"""
1631

    
1632
  @staticmethod
1633
  def _TestInProcess(tz, timestamp, expected):
1634
    os.environ["TZ"] = tz
1635
    time.tzset()
1636
    return utils.FormatTime(timestamp) == expected
1637

    
1638
  def _Test(self, *args):
1639
    # Need to use separate process as we want to change TZ
1640
    self.assert_(utils.RunInSeparateProcess(self._TestInProcess, *args))
1641

    
1642
  def test(self):
1643
    self._Test("UTC", 0, "1970-01-01 00:00:00")
1644
    self._Test("America/Sao_Paulo", 1292606926, "2010-12-17 15:28:46")
1645
    self._Test("Europe/London", 1292606926, "2010-12-17 17:28:46")
1646
    self._Test("Europe/Zurich", 1292606926, "2010-12-17 18:28:46")
1647
    self._Test("Australia/Sydney", 1292606926, "2010-12-18 04:28:46")
1648

    
1649
  def testNone(self):
1650
    self.failUnlessEqual(FormatTime(None), "N/A")
1651

    
1652
  def testInvalid(self):
1653
    self.failUnlessEqual(FormatTime(()), "N/A")
1654

    
1655
  def testNow(self):
1656
    # tests that we accept time.time input
1657
    FormatTime(time.time())
1658
    # tests that we accept int input
1659
    FormatTime(int(time.time()))
1660

    
1661

    
1662
class RunInSeparateProcess(unittest.TestCase):
1663
  def test(self):
1664
    for exp in [True, False]:
1665
      def _child():
1666
        return exp
1667

    
1668
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1669

    
1670
  def testArgs(self):
1671
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1672
      def _child(carg1, carg2):
1673
        return carg1 == "Foo" and carg2 == arg
1674

    
1675
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1676

    
1677
  def testPid(self):
1678
    parent_pid = os.getpid()
1679

    
1680
    def _check():
1681
      return os.getpid() == parent_pid
1682

    
1683
    self.failIf(utils.RunInSeparateProcess(_check))
1684

    
1685
  def testSignal(self):
1686
    def _kill():
1687
      os.kill(os.getpid(), signal.SIGTERM)
1688

    
1689
    self.assertRaises(errors.GenericError,
1690
                      utils.RunInSeparateProcess, _kill)
1691

    
1692
  def testException(self):
1693
    def _exc():
1694
      raise errors.GenericError("This is a test")
1695

    
1696
    self.assertRaises(errors.GenericError,
1697
                      utils.RunInSeparateProcess, _exc)
1698

    
1699

    
1700
class TestFingerprintFiles(unittest.TestCase):
1701
  def setUp(self):
1702
    self.tmpfile = tempfile.NamedTemporaryFile()
1703
    self.tmpfile2 = tempfile.NamedTemporaryFile()
1704
    utils.WriteFile(self.tmpfile2.name, data="Hello World\n")
1705
    self.results = {
1706
      self.tmpfile.name: "da39a3ee5e6b4b0d3255bfef95601890afd80709",
1707
      self.tmpfile2.name: "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a",
1708
      }
1709

    
1710
  def testSingleFile(self):
1711
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1712
                     self.results[self.tmpfile.name])
1713

    
1714
    self.assertEqual(utils._FingerprintFile("/no/such/file"), None)
1715

    
1716
  def testBigFile(self):
1717
    self.tmpfile.write("A" * 8192)
1718
    self.tmpfile.flush()
1719
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1720
                     "35b6795ca20d6dc0aff8c7c110c96cd1070b8c38")
1721

    
1722
  def testMultiple(self):
1723
    all_files = self.results.keys()
1724
    all_files.append("/no/such/file")
1725
    self.assertEqual(utils.FingerprintFiles(self.results.keys()), self.results)
1726

    
1727

    
1728
class TestUnescapeAndSplit(unittest.TestCase):
1729
  """Testing case for UnescapeAndSplit"""
1730

    
1731
  def setUp(self):
1732
    # testing more that one separator for regexp safety
1733
    self._seps = [",", "+", "."]
1734

    
1735
  def testSimple(self):
1736
    a = ["a", "b", "c", "d"]
1737
    for sep in self._seps:
1738
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1739

    
1740
  def testEscape(self):
1741
    for sep in self._seps:
1742
      a = ["a", "b\\" + sep + "c", "d"]
1743
      b = ["a", "b" + sep + "c", "d"]
1744
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1745

    
1746
  def testDoubleEscape(self):
1747
    for sep in self._seps:
1748
      a = ["a", "b\\\\", "c", "d"]
1749
      b = ["a", "b\\", "c", "d"]
1750
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1751

    
1752
  def testThreeEscape(self):
1753
    for sep in self._seps:
1754
      a = ["a", "b\\\\\\" + sep + "c", "d"]
1755
      b = ["a", "b\\" + sep + "c", "d"]
1756
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1757

    
1758

    
1759
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1760
  def setUp(self):
1761
    self.tmpdir = tempfile.mkdtemp()
1762

    
1763
  def tearDown(self):
1764
    shutil.rmtree(self.tmpdir)
1765

    
1766
  def _checkRsaPrivateKey(self, key):
1767
    lines = key.splitlines()
1768
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1769
            "-----END RSA PRIVATE KEY-----" in lines)
1770

    
1771
  def _checkCertificate(self, cert):
1772
    lines = cert.splitlines()
1773
    return ("-----BEGIN CERTIFICATE-----" in lines and
1774
            "-----END CERTIFICATE-----" in lines)
1775

    
1776
  def test(self):
1777
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1778
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1779
      self._checkRsaPrivateKey(key_pem)
1780
      self._checkCertificate(cert_pem)
1781

    
1782
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1783
                                           key_pem)
1784
      self.assert_(key.bits() >= 1024)
1785
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1786
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1787

    
1788
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1789
                                             cert_pem)
1790
      self.failIf(x509.has_expired())
1791
      self.assertEqual(x509.get_issuer().CN, common_name)
1792
      self.assertEqual(x509.get_subject().CN, common_name)
1793
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1794

    
1795
  def testLegacy(self):
1796
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1797

    
1798
    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1799

    
1800
    cert1 = utils.ReadFile(cert1_filename)
1801

    
1802
    self.assert_(self._checkRsaPrivateKey(cert1))
1803
    self.assert_(self._checkCertificate(cert1))
1804

    
1805

    
1806
class TestPathJoin(unittest.TestCase):
1807
  """Testing case for PathJoin"""
1808

    
1809
  def testBasicItems(self):
1810
    mlist = ["/a", "b", "c"]
1811
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1812

    
1813
  def testNonAbsPrefix(self):
1814
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1815

    
1816
  def testBackTrack(self):
1817
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1818

    
1819
  def testMultiAbs(self):
1820
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1821

    
1822

    
1823
class TestValidateServiceName(unittest.TestCase):
1824
  def testValid(self):
1825
    testnames = [
1826
      0, 1, 2, 3, 1024, 65000, 65534, 65535,
1827
      "ganeti",
1828
      "gnt-masterd",
1829
      "HELLO_WORLD_SVC",
1830
      "hello.world.1",
1831
      "0", "80", "1111", "65535",
1832
      ]
1833

    
1834
    for name in testnames:
1835
      self.assertEqual(utils.ValidateServiceName(name), name)
1836

    
1837
  def testInvalid(self):
1838
    testnames = [
1839
      -15756, -1, 65536, 133428083,
1840
      "", "Hello World!", "!", "'", "\"", "\t", "\n", "`",
1841
      "-8546", "-1", "65536",
1842
      (129 * "A"),
1843
      ]
1844

    
1845
    for name in testnames:
1846
      self.assertRaises(errors.OpPrereqError, utils.ValidateServiceName, name)
1847

    
1848

    
1849
class TestParseAsn1Generalizedtime(unittest.TestCase):
1850
  def test(self):
1851
    # UTC
1852
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1853
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1854
                     1266860512)
1855
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1856
                     (2**31) - 1)
1857

    
1858
    # With offset
1859
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1860
                     1266860512)
1861
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1862
                     1266931012)
1863
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1864
                     1266931088)
1865
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1866
                     1266931295)
1867
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1868
                     3600)
1869

    
1870
    # Leap seconds are not supported by datetime.datetime
1871
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1872
                      "19841231235960+0000")
1873
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1874
                      "19920630235960+0000")
1875

    
1876
    # Errors
1877
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1878
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1879
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1880
                      "20100222174152")
1881
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1882
                      "Mon Feb 22 17:47:02 UTC 2010")
1883
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1884
                      "2010-02-22 17:42:02")
1885

    
1886

    
1887
class TestGetX509CertValidity(testutils.GanetiTestCase):
1888
  def setUp(self):
1889
    testutils.GanetiTestCase.setUp(self)
1890

    
1891
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1892

    
1893
    # Test whether we have pyOpenSSL 0.7 or above
1894
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1895

    
1896
    if not self.pyopenssl0_7:
1897
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1898
                    " function correctly")
1899

    
1900
  def _LoadCert(self, name):
1901
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1902
                                           self._ReadTestData(name))
1903

    
1904
  def test(self):
1905
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1906
    if self.pyopenssl0_7:
1907
      self.assertEqual(validity, (1266919967, 1267524767))
1908
    else:
1909
      self.assertEqual(validity, (None, None))
1910

    
1911

    
1912
class TestSignX509Certificate(unittest.TestCase):
1913
  KEY = "My private key!"
1914
  KEY_OTHER = "Another key"
1915

    
1916
  def test(self):
1917
    # Generate certificate valid for 5 minutes
1918
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
1919

    
1920
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1921
                                           cert_pem)
1922

    
1923
    # No signature at all
1924
    self.assertRaises(errors.GenericError,
1925
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
1926

    
1927
    # Invalid input
1928
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1929
                      "", self.KEY)
1930
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1931
                      "X-Ganeti-Signature: \n", self.KEY)
1932
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1933
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
1934
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1935
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
1936
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1937
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
1938

    
1939
    # Invalid salt
1940
    for salt in list("-_@$,:;/\\ \t\n"):
1941
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
1942
                        cert_pem, self.KEY, "foo%sbar" % salt)
1943

    
1944
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
1945
                 utils.GenerateSecret(numbytes=4),
1946
                 utils.GenerateSecret(numbytes=16),
1947
                 "{123:456}".encode("hex")]:
1948
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
1949

    
1950
      self._Check(cert, salt, signed_pem)
1951

    
1952
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
1953
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
1954
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
1955
                               "lines----\n------ at\nthe end!"))
1956

    
1957
  def _Check(self, cert, salt, pem):
1958
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
1959
    self.assertEqual(salt, salt2)
1960
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
1961

    
1962
    # Other key
1963
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1964
                      pem, self.KEY_OTHER)
1965

    
1966

    
1967
class TestMakedirs(unittest.TestCase):
1968
  def setUp(self):
1969
    self.tmpdir = tempfile.mkdtemp()
1970

    
1971
  def tearDown(self):
1972
    shutil.rmtree(self.tmpdir)
1973

    
1974
  def testNonExisting(self):
1975
    path = PathJoin(self.tmpdir, "foo")
1976
    utils.Makedirs(path)
1977
    self.assert_(os.path.isdir(path))
1978

    
1979
  def testExisting(self):
1980
    path = PathJoin(self.tmpdir, "foo")
1981
    os.mkdir(path)
1982
    utils.Makedirs(path)
1983
    self.assert_(os.path.isdir(path))
1984

    
1985
  def testRecursiveNonExisting(self):
1986
    path = PathJoin(self.tmpdir, "foo/bar/baz")
1987
    utils.Makedirs(path)
1988
    self.assert_(os.path.isdir(path))
1989

    
1990
  def testRecursiveExisting(self):
1991
    path = PathJoin(self.tmpdir, "B/moo/xyz")
1992
    self.assertFalse(os.path.exists(path))
1993
    os.mkdir(PathJoin(self.tmpdir, "B"))
1994
    utils.Makedirs(path)
1995
    self.assert_(os.path.isdir(path))
1996

    
1997

    
1998
class TestLineSplitter(unittest.TestCase):
1999
  def test(self):
2000
    lines = []
2001
    ls = utils.LineSplitter(lines.append)
2002
    ls.write("Hello World\n")
2003
    self.assertEqual(lines, [])
2004
    ls.write("Foo\n Bar\r\n ")
2005
    ls.write("Baz")
2006
    ls.write("Moo")
2007
    self.assertEqual(lines, [])
2008
    ls.flush()
2009
    self.assertEqual(lines, ["Hello World", "Foo", " Bar"])
2010
    ls.close()
2011
    self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"])
2012

    
2013
  def _testExtra(self, line, all_lines, p1, p2):
2014
    self.assertEqual(p1, 999)
2015
    self.assertEqual(p2, "extra")
2016
    all_lines.append(line)
2017

    
2018
  def testExtraArgsNoFlush(self):
2019
    lines = []
2020
    ls = utils.LineSplitter(self._testExtra, lines, 999, "extra")
2021
    ls.write("\n\nHello World\n")
2022
    ls.write("Foo\n Bar\r\n ")
2023
    ls.write("")
2024
    ls.write("Baz")
2025
    ls.write("Moo\n\nx\n")
2026
    self.assertEqual(lines, [])
2027
    ls.close()
2028
    self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo",
2029
                             "", "x"])
2030

    
2031

    
2032
class TestReadLockedPidFile(unittest.TestCase):
2033
  def setUp(self):
2034
    self.tmpdir = tempfile.mkdtemp()
2035

    
2036
  def tearDown(self):
2037
    shutil.rmtree(self.tmpdir)
2038

    
2039
  def testNonExistent(self):
2040
    path = PathJoin(self.tmpdir, "nonexist")
2041
    self.assert_(utils.ReadLockedPidFile(path) is None)
2042

    
2043
  def testUnlocked(self):
2044
    path = PathJoin(self.tmpdir, "pid")
2045
    utils.WriteFile(path, data="123")
2046
    self.assert_(utils.ReadLockedPidFile(path) is None)
2047

    
2048
  def testLocked(self):
2049
    path = PathJoin(self.tmpdir, "pid")
2050
    utils.WriteFile(path, data="123")
2051

    
2052
    fl = utils.FileLock.Open(path)
2053
    try:
2054
      fl.Exclusive(blocking=True)
2055

    
2056
      self.assertEqual(utils.ReadLockedPidFile(path), 123)
2057
    finally:
2058
      fl.Close()
2059

    
2060
    self.assert_(utils.ReadLockedPidFile(path) is None)
2061

    
2062
  def testError(self):
2063
    path = PathJoin(self.tmpdir, "foobar", "pid")
2064
    utils.WriteFile(PathJoin(self.tmpdir, "foobar"), data="")
2065
    # open(2) should return ENOTDIR
2066
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
2067

    
2068

    
2069
class TestCertVerification(testutils.GanetiTestCase):
2070
  def setUp(self):
2071
    testutils.GanetiTestCase.setUp(self)
2072

    
2073
    self.tmpdir = tempfile.mkdtemp()
2074

    
2075
  def tearDown(self):
2076
    shutil.rmtree(self.tmpdir)
2077

    
2078
  def testVerifyCertificate(self):
2079
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
2080
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
2081
                                           cert_pem)
2082

    
2083
    # Not checking return value as this certificate is expired
2084
    utils.VerifyX509Certificate(cert, 30, 7)
2085

    
2086

    
2087
class TestVerifyCertificateInner(unittest.TestCase):
2088
  def test(self):
2089
    vci = utils._VerifyCertificateInner
2090

    
2091
    # Valid
2092
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
2093
                     (None, None))
2094

    
2095
    # Not yet valid
2096
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
2097
    self.assertEqual(errcode, utils.CERT_WARNING)
2098

    
2099
    # Expiring soon
2100
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
2101
    self.assertEqual(errcode, utils.CERT_ERROR)
2102

    
2103
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
2104
    self.assertEqual(errcode, utils.CERT_WARNING)
2105

    
2106
    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
2107
    self.assertEqual(errcode, None)
2108

    
2109
    # Expired
2110
    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
2111
    self.assertEqual(errcode, utils.CERT_ERROR)
2112

    
2113
    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
2114
    self.assertEqual(errcode, utils.CERT_ERROR)
2115

    
2116
    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
2117
    self.assertEqual(errcode, utils.CERT_ERROR)
2118

    
2119
    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
2120
    self.assertEqual(errcode, utils.CERT_ERROR)
2121

    
2122

    
2123
class TestHmacFunctions(unittest.TestCase):
2124
  # Digests can be checked with "openssl sha1 -hmac $key"
2125
  def testSha1Hmac(self):
2126
    self.assertEqual(utils.Sha1Hmac("", ""),
2127
                     "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d")
2128
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World"),
2129
                     "ef4f3bda82212ecb2f7ce868888a19092481f1fd")
2130
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", ""),
2131
                     "f904c2476527c6d3e6609ab683c66fa0652cb1dc")
2132

    
2133
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
2134
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", longtext),
2135
                     "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54")
2136

    
2137
  def testSha1HmacSalt(self):
2138
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc0"),
2139
                     "4999bf342470eadb11dfcd24ca5680cf9fd7cdce")
2140
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc9"),
2141
                     "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8")
2142
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World", salt="xyz0"),
2143
                     "7f264f8114c9066afc9bb7636e1786d996d3cc0d")
2144

    
2145
  def testVerifySha1Hmac(self):
2146
    self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b"
2147
                                               "7d64b71fb76370690e1d")))
2148
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2149
                                      ("f904c2476527c6d3e660"
2150
                                       "9ab683c66fa0652cb1dc")))
2151

    
2152
    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
2153
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", digest))
2154
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2155
                                      digest.lower()))
2156
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2157
                                      digest.upper()))
2158
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2159
                                      digest.title()))
2160

    
2161
  def testVerifySha1HmacSalt(self):
2162
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2163
                                      ("17a4adc34d69c0d367d4"
2164
                                       "ffbef96fd41d4df7a6e8"),
2165
                                      salt="abc9"))
2166
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2167
                                      ("7f264f8114c9066afc9b"
2168
                                       "b7636e1786d996d3cc0d"),
2169
                                      salt="xyz0"))
2170

    
2171

    
2172
class TestIgnoreSignals(unittest.TestCase):
2173
  """Test the IgnoreSignals decorator"""
2174

    
2175
  @staticmethod
2176
  def _Raise(exception):
2177
    raise exception
2178

    
2179
  @staticmethod
2180
  def _Return(rval):
2181
    return rval
2182

    
2183
  def testIgnoreSignals(self):
2184
    sock_err_intr = socket.error(errno.EINTR, "Message")
2185
    sock_err_inval = socket.error(errno.EINVAL, "Message")
2186

    
2187
    env_err_intr = EnvironmentError(errno.EINTR, "Message")
2188
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")
2189

    
2190
    self.assertRaises(socket.error, self._Raise, sock_err_intr)
2191
    self.assertRaises(socket.error, self._Raise, sock_err_inval)
2192
    self.assertRaises(EnvironmentError, self._Raise, env_err_intr)
2193
    self.assertRaises(EnvironmentError, self._Raise, env_err_inval)
2194

    
2195
    self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None)
2196
    self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None)
2197
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
2198
                      sock_err_inval)
2199
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
2200
                      env_err_inval)
2201

    
2202
    self.assertEquals(utils.IgnoreSignals(self._Return, True), True)
2203
    self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33)
2204

    
2205

    
2206
class TestEnsureDirs(unittest.TestCase):
2207
  """Tests for EnsureDirs"""
2208

    
2209
  def setUp(self):
2210
    self.dir = tempfile.mkdtemp()
2211
    self.old_umask = os.umask(0777)
2212

    
2213
  def testEnsureDirs(self):
2214
    utils.EnsureDirs([
2215
        (PathJoin(self.dir, "foo"), 0777),
2216
        (PathJoin(self.dir, "bar"), 0000),
2217
        ])
2218
    self.assertEquals(os.stat(PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2219
    self.assertEquals(os.stat(PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2220

    
2221
  def tearDown(self):
2222
    os.rmdir(PathJoin(self.dir, "foo"))
2223
    os.rmdir(PathJoin(self.dir, "bar"))
2224
    os.rmdir(self.dir)
2225
    os.umask(self.old_umask)
2226

    
2227

    
2228
class TestFormatSeconds(unittest.TestCase):
2229
  def test(self):
2230
    self.assertEqual(utils.FormatSeconds(1), "1s")
2231
    self.assertEqual(utils.FormatSeconds(3600), "1h 0m 0s")
2232
    self.assertEqual(utils.FormatSeconds(3599), "59m 59s")
2233
    self.assertEqual(utils.FormatSeconds(7200), "2h 0m 0s")
2234
    self.assertEqual(utils.FormatSeconds(7201), "2h 0m 1s")
2235
    self.assertEqual(utils.FormatSeconds(7281), "2h 1m 21s")
2236
    self.assertEqual(utils.FormatSeconds(29119), "8h 5m 19s")
2237
    self.assertEqual(utils.FormatSeconds(19431228), "224d 21h 33m 48s")
2238
    self.assertEqual(utils.FormatSeconds(-1), "-1s")
2239
    self.assertEqual(utils.FormatSeconds(-282), "-282s")
2240
    self.assertEqual(utils.FormatSeconds(-29119), "-29119s")
2241

    
2242
  def testFloat(self):
2243
    self.assertEqual(utils.FormatSeconds(1.3), "1s")
2244
    self.assertEqual(utils.FormatSeconds(1.9), "2s")
2245
    self.assertEqual(utils.FormatSeconds(3912.12311), "1h 5m 12s")
2246
    self.assertEqual(utils.FormatSeconds(3912.8), "1h 5m 13s")
2247

    
2248

    
2249
class TestIgnoreProcessNotFound(unittest.TestCase):
2250
  @staticmethod
2251
  def _WritePid(fd):
2252
    os.write(fd, str(os.getpid()))
2253
    os.close(fd)
2254
    return True
2255

    
2256
  def test(self):
2257
    (pid_read_fd, pid_write_fd) = os.pipe()
2258

    
2259
    # Start short-lived process which writes its PID to pipe
2260
    self.assert_(utils.RunInSeparateProcess(self._WritePid, pid_write_fd))
2261
    os.close(pid_write_fd)
2262

    
2263
    # Read PID from pipe
2264
    pid = int(os.read(pid_read_fd, 1024))
2265
    os.close(pid_read_fd)
2266

    
2267
    # Try to send signal to process which exited recently
2268
    self.assertFalse(utils.IgnoreProcessNotFound(os.kill, pid, 0))
2269

    
2270

    
2271
class TestShellWriter(unittest.TestCase):
2272
  def test(self):
2273
    buf = StringIO()
2274
    sw = utils.ShellWriter(buf)
2275
    sw.Write("#!/bin/bash")
2276
    sw.Write("if true; then")
2277
    sw.IncIndent()
2278
    try:
2279
      sw.Write("echo true")
2280

    
2281
      sw.Write("for i in 1 2 3")
2282
      sw.Write("do")
2283
      sw.IncIndent()
2284
      try:
2285
        self.assertEqual(sw._indent, 2)
2286
        sw.Write("date")
2287
      finally:
2288
        sw.DecIndent()
2289
      sw.Write("done")
2290
    finally:
2291
      sw.DecIndent()
2292
    sw.Write("echo %s", utils.ShellQuote("Hello World"))
2293
    sw.Write("exit 0")
2294

    
2295
    self.assertEqual(sw._indent, 0)
2296

    
2297
    output = buf.getvalue()
2298

    
2299
    self.assert_(output.endswith("\n"))
2300

    
2301
    lines = output.splitlines()
2302
    self.assertEqual(len(lines), 9)
2303
    self.assertEqual(lines[0], "#!/bin/bash")
2304
    self.assert_(re.match(r"^\s+date$", lines[5]))
2305
    self.assertEqual(lines[7], "echo 'Hello World'")
2306

    
2307
  def testEmpty(self):
2308
    buf = StringIO()
2309
    sw = utils.ShellWriter(buf)
2310
    sw = None
2311
    self.assertEqual(buf.getvalue(), "")
2312

    
2313

    
2314
class TestCommaJoin(unittest.TestCase):
2315
  def test(self):
2316
    self.assertEqual(utils.CommaJoin([]), "")
2317
    self.assertEqual(utils.CommaJoin([1, 2, 3]), "1, 2, 3")
2318
    self.assertEqual(utils.CommaJoin(["Hello"]), "Hello")
2319
    self.assertEqual(utils.CommaJoin(["Hello", "World"]), "Hello, World")
2320
    self.assertEqual(utils.CommaJoin(["Hello", "World", 99]),
2321
                     "Hello, World, 99")
2322

    
2323

    
2324
class TestFindMatch(unittest.TestCase):
2325
  def test(self):
2326
    data = {
2327
      "aaaa": "Four A",
2328
      "bb": {"Two B": True},
2329
      re.compile(r"^x(foo|bar|bazX)([0-9]+)$"): (1, 2, 3),
2330
      }
2331

    
2332
    self.assertEqual(utils.FindMatch(data, "aaaa"), ("Four A", []))
2333
    self.assertEqual(utils.FindMatch(data, "bb"), ({"Two B": True}, []))
2334

    
2335
    for i in ["foo", "bar", "bazX"]:
2336
      for j in range(1, 100, 7):
2337
        self.assertEqual(utils.FindMatch(data, "x%s%s" % (i, j)),
2338
                         ((1, 2, 3), [i, str(j)]))
2339

    
2340
  def testNoMatch(self):
2341
    self.assert_(utils.FindMatch({}, "") is None)
2342
    self.assert_(utils.FindMatch({}, "foo") is None)
2343
    self.assert_(utils.FindMatch({}, 1234) is None)
2344

    
2345
    data = {
2346
      "X": "Hello World",
2347
      re.compile("^(something)$"): "Hello World",
2348
      }
2349

    
2350
    self.assert_(utils.FindMatch(data, "") is None)
2351
    self.assert_(utils.FindMatch(data, "Hello World") is None)
2352

    
2353

    
2354
class TestFileID(testutils.GanetiTestCase):
2355
  def testEquality(self):
2356
    name = self._CreateTempFile()
2357
    oldi = utils.GetFileID(path=name)
2358
    self.failUnless(utils.VerifyFileID(oldi, oldi))
2359

    
2360
  def testUpdate(self):
2361
    name = self._CreateTempFile()
2362
    oldi = utils.GetFileID(path=name)
2363
    os.utime(name, None)
2364
    fd = os.open(name, os.O_RDWR)
2365
    try:
2366
      newi = utils.GetFileID(fd=fd)
2367
      self.failUnless(utils.VerifyFileID(oldi, newi))
2368
      self.failUnless(utils.VerifyFileID(newi, oldi))
2369
    finally:
2370
      os.close(fd)
2371

    
2372
  def testWriteFile(self):
2373
    name = self._CreateTempFile()
2374
    oldi = utils.GetFileID(path=name)
2375
    mtime = oldi[2]
2376
    os.utime(name, (mtime + 10, mtime + 10))
2377
    self.assertRaises(errors.LockError, utils.SafeWriteFile, name,
2378
                      oldi, data="")
2379
    os.utime(name, (mtime - 10, mtime - 10))
2380
    utils.SafeWriteFile(name, oldi, data="")
2381
    oldi = utils.GetFileID(path=name)
2382
    mtime = oldi[2]
2383
    os.utime(name, (mtime + 10, mtime + 10))
2384
    # this doesn't raise, since we passed None
2385
    utils.SafeWriteFile(name, None, data="")
2386

    
2387
  def testError(self):
2388
    t = tempfile.NamedTemporaryFile()
2389
    self.assertRaises(errors.ProgrammerError, utils.GetFileID,
2390
                      path=t.name, fd=t.fileno())
2391

    
2392

    
2393
class TimeMock:
2394
  def __init__(self, values):
2395
    self.values = values
2396

    
2397
  def __call__(self):
2398
    return self.values.pop(0)
2399

    
2400

    
2401
class TestRunningTimeout(unittest.TestCase):
2402
  def setUp(self):
2403
    self.time_fn = TimeMock([0.0, 0.3, 4.6, 6.5])
2404

    
2405
  def testRemainingFloat(self):
2406
    timeout = utils.RunningTimeout(5.0, True, _time_fn=self.time_fn)
2407
    self.assertAlmostEqual(timeout.Remaining(), 4.7)
2408
    self.assertAlmostEqual(timeout.Remaining(), 0.4)
2409
    self.assertAlmostEqual(timeout.Remaining(), -1.5)
2410

    
2411
  def testRemaining(self):
2412
    self.time_fn = TimeMock([0, 2, 4, 5, 6])
2413
    timeout = utils.RunningTimeout(5, True, _time_fn=self.time_fn)
2414
    self.assertEqual(timeout.Remaining(), 3)
2415
    self.assertEqual(timeout.Remaining(), 1)
2416
    self.assertEqual(timeout.Remaining(), 0)
2417
    self.assertEqual(timeout.Remaining(), -1)
2418

    
2419
  def testRemainingNonNegative(self):
2420
    timeout = utils.RunningTimeout(5.0, False, _time_fn=self.time_fn)
2421
    self.assertAlmostEqual(timeout.Remaining(), 4.7)
2422
    self.assertAlmostEqual(timeout.Remaining(), 0.4)
2423
    self.assertEqual(timeout.Remaining(), 0.0)
2424

    
2425
  def testNegativeTimeout(self):
2426
    self.assertRaises(ValueError, utils.RunningTimeout, -1.0, True)
2427

    
2428

    
2429
class TestTryConvert(unittest.TestCase):
2430
  def test(self):
2431
    for src, fn, result in [
2432
      ("1", int, 1),
2433
      ("a", int, "a"),
2434
      ("", bool, False),
2435
      ("a", bool, True),
2436
      ]:
2437
      self.assertEqual(utils.TryConvert(fn, src), result)
2438

    
2439

    
2440
class TestIsValidShellParam(unittest.TestCase):
2441
  def test(self):
2442
    for val, result in [
2443
      ("abc", True),
2444
      ("ab;cd", False),
2445
      ]:
2446
      self.assertEqual(utils.IsValidShellParam(val), result)
2447

    
2448

    
2449
class TestBuildShellCmd(unittest.TestCase):
2450
  def test(self):
2451
    self.assertRaises(errors.ProgrammerError, utils.BuildShellCmd,
2452
                      "ls %s", "ab;cd")
2453
    self.assertEqual(utils.BuildShellCmd("ls %s", "ab"), "ls ab")
2454

    
2455

    
2456
class TestWriteFile(unittest.TestCase):
2457
  def setUp(self):
2458
    self.tfile = tempfile.NamedTemporaryFile()
2459
    self.did_pre = False
2460
    self.did_post = False
2461
    self.did_write = False
2462

    
2463
  def markPre(self, fd):
2464
    self.did_pre = True
2465

    
2466
  def markPost(self, fd):
2467
    self.did_post = True
2468

    
2469
  def markWrite(self, fd):
2470
    self.did_write = True
2471

    
2472
  def testWrite(self):
2473
    data = "abc"
2474
    utils.WriteFile(self.tfile.name, data=data)
2475
    self.assertEqual(utils.ReadFile(self.tfile.name), data)
2476

    
2477
  def testErrors(self):
2478
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
2479
                      self.tfile.name, data="test", fn=lambda fd: None)
2480
    self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name)
2481
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
2482
                      self.tfile.name, data="test", atime=0)
2483

    
2484
  def testCalls(self):
2485
    utils.WriteFile(self.tfile.name, fn=self.markWrite,
2486
                    prewrite=self.markPre, postwrite=self.markPost)
2487
    self.assertTrue(self.did_pre)
2488
    self.assertTrue(self.did_post)
2489
    self.assertTrue(self.did_write)
2490

    
2491
  def testDryRun(self):
2492
    orig = "abc"
2493
    self.tfile.write(orig)
2494
    self.tfile.flush()
2495
    utils.WriteFile(self.tfile.name, data="hello", dry_run=True)
2496
    self.assertEqual(utils.ReadFile(self.tfile.name), orig)
2497

    
2498
  def testTimes(self):
2499
    f = self.tfile.name
2500
    for at, mt in [(0, 0), (1000, 1000), (2000, 3000),
2501
                   (int(time.time()), 5000)]:
2502
      utils.WriteFile(f, data="hello", atime=at, mtime=mt)
2503
      st = os.stat(f)
2504
      self.assertEqual(st.st_atime, at)
2505
      self.assertEqual(st.st_mtime, mt)
2506

    
2507

    
2508
  def testNoClose(self):
2509
    data = "hello"
2510
    self.assertEqual(utils.WriteFile(self.tfile.name, data="abc"), None)
2511
    fd = utils.WriteFile(self.tfile.name, data=data, close=False)
2512
    try:
2513
      os.lseek(fd, 0, 0)
2514
      self.assertEqual(os.read(fd, 4096), data)
2515
    finally:
2516
      os.close(fd)
2517

    
2518

    
2519
class TestNormalizeAndValidateMac(unittest.TestCase):
2520
  def testInvalid(self):
2521
    self.assertRaises(errors.OpPrereqError,
2522
                      utils.NormalizeAndValidateMac, "xxx")
2523

    
2524
  def testNormalization(self):
2525
    for mac in ["aa:bb:cc:dd:ee:ff", "00:AA:11:bB:22:cc"]:
2526
      self.assertEqual(utils.NormalizeAndValidateMac(mac), mac.lower())
2527

    
2528

    
2529
if __name__ == '__main__':
2530
  testutils.GanetiTestProgram()