Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ c74cda62

History | View | Annotate | Download (83.2 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import distutils.version
25
import errno
26
import fcntl
27
import glob
28
import os
29
import os.path
30
import re
31
import shutil
32
import signal
33
import socket
34
import stat
35
import string
36
import tempfile
37
import time
38
import unittest
39
import warnings
40
import OpenSSL
41
from cStringIO import StringIO
42

    
43
import testutils
44
from ganeti import constants
45
from ganeti import compat
46
from ganeti import utils
47
from ganeti import errors
48
from ganeti.utils import RunCmd, RemoveFile, MatchNameComponent, FormatUnit, \
49
     ParseUnit, ShellQuote, ShellQuoteArgs, ListVisibleFiles, FirstFree, \
50
     TailFile, SafeEncode, FormatTime, UnescapeAndSplit, RunParts, PathJoin, \
51
     ReadOneLineFile, SetEtcHostsEntry, RemoveEtcHostsEntry
52

    
53

    
54
class TestIsProcessAlive(unittest.TestCase):
55
  """Testing case for IsProcessAlive"""
56

    
57
  def testExists(self):
58
    mypid = os.getpid()
59
    self.assert_(utils.IsProcessAlive(mypid), "can't find myself running")
60

    
61
  def testNotExisting(self):
62
    pid_non_existing = os.fork()
63
    if pid_non_existing == 0:
64
      os._exit(0)
65
    elif pid_non_existing < 0:
66
      raise SystemError("can't fork")
67
    os.waitpid(pid_non_existing, 0)
68
    self.assertFalse(utils.IsProcessAlive(pid_non_existing),
69
                     "nonexisting process detected")
70

    
71

    
72
class TestGetProcStatusPath(unittest.TestCase):
73
  def test(self):
74
    self.assert_("/1234/" in utils._GetProcStatusPath(1234))
75
    self.assertNotEqual(utils._GetProcStatusPath(1),
76
                        utils._GetProcStatusPath(2))
77

    
78

    
79
class TestIsProcessHandlingSignal(unittest.TestCase):
80
  def setUp(self):
81
    self.tmpdir = tempfile.mkdtemp()
82

    
83
  def tearDown(self):
84
    shutil.rmtree(self.tmpdir)
85

    
86
  def testParseSigsetT(self):
87
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
88
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
89
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
90
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
91
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
92
                     set([2, 10, 32, 33]))
93
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
94
                     set([2, 32, 33]))
95
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
96
                     set([2, 28, 32, 33]))
97
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
98
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
99
                          24, 25, 26, 28, 31]))
100
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))
101

    
102
  def testGetProcStatusField(self):
103
    for field in ["SigCgt", "Name", "FDSize"]:
104
      for value in ["", "0", "cat", "  1234 KB"]:
105
        pstatus = "\n".join([
106
          "VmPeak: 999 kB",
107
          "%s: %s" % (field, value),
108
          "TracerPid: 0",
109
          ])
110
        result = utils._GetProcStatusField(pstatus, field)
111
        self.assertEqual(result, value.strip())
112

    
113
  def test(self):
114
    sp = PathJoin(self.tmpdir, "status")
115

    
116
    utils.WriteFile(sp, data="\n".join([
117
      "Name:   bash",
118
      "State:  S (sleeping)",
119
      "SleepAVG:       98%",
120
      "Pid:    22250",
121
      "PPid:   10858",
122
      "TracerPid:      0",
123
      "SigBlk: 0000000000010000",
124
      "SigIgn: 0000000000384004",
125
      "SigCgt: 000000004b813efb",
126
      "CapEff: 0000000000000000",
127
      ]))
128

    
129
    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
130

    
131
  def testNoSigCgt(self):
132
    sp = PathJoin(self.tmpdir, "status")
133

    
134
    utils.WriteFile(sp, data="\n".join([
135
      "Name:   bash",
136
      ]))
137

    
138
    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
139
                      1234, 10, status_path=sp)
140

    
141
  def testNoSuchFile(self):
142
    sp = PathJoin(self.tmpdir, "notexist")
143

    
144
    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
145

    
146
  @staticmethod
147
  def _TestRealProcess():
148
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
149
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
150
      raise Exception("SIGUSR1 is handled when it should not be")
151

    
152
    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
153
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
154
      raise Exception("SIGUSR1 is not handled when it should be")
155

    
156
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
157
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
158
      raise Exception("SIGUSR1 is not handled when it should be")
159

    
160
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
161
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
162
      raise Exception("SIGUSR1 is handled when it should not be")
163

    
164
    return True
165

    
166
  def testRealProcess(self):
167
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
168

    
169

    
170
class TestPidFileFunctions(unittest.TestCase):
171
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
172

    
173
  def setUp(self):
174
    self.dir = tempfile.mkdtemp()
175
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
176
    utils.DaemonPidFileName = self.f_dpn
177

    
178
  def testPidFileFunctions(self):
179
    pid_file = self.f_dpn('test')
180
    fd = utils.WritePidFile(self.f_dpn('test'))
181
    self.failUnless(os.path.exists(pid_file),
182
                    "PID file should have been created")
183
    read_pid = utils.ReadPidFile(pid_file)
184
    self.failUnlessEqual(read_pid, os.getpid())
185
    self.failUnless(utils.IsProcessAlive(read_pid))
186
    self.failUnlessRaises(errors.LockError, utils.WritePidFile,
187
                          self.f_dpn('test'))
188
    os.close(fd)
189
    utils.RemovePidFile('test')
190
    self.failIf(os.path.exists(pid_file),
191
                "PID file should not exist anymore")
192
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
193
                         "ReadPidFile should return 0 for missing pid file")
194
    fh = open(pid_file, "w")
195
    fh.write("blah\n")
196
    fh.close()
197
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
198
                         "ReadPidFile should return 0 for invalid pid file")
199
    # but now, even with the file existing, we should be able to lock it
200
    fd = utils.WritePidFile(self.f_dpn('test'))
201
    os.close(fd)
202
    utils.RemovePidFile('test')
203
    self.failIf(os.path.exists(pid_file),
204
                "PID file should not exist anymore")
205

    
206
  def testKill(self):
207
    pid_file = self.f_dpn('child')
208
    r_fd, w_fd = os.pipe()
209
    new_pid = os.fork()
210
    if new_pid == 0: #child
211
      utils.WritePidFile(self.f_dpn('child'))
212
      os.write(w_fd, 'a')
213
      signal.pause()
214
      os._exit(0)
215
      return
216
    # else we are in the parent
217
    # wait until the child has written the pid file
218
    os.read(r_fd, 1)
219
    read_pid = utils.ReadPidFile(pid_file)
220
    self.failUnlessEqual(read_pid, new_pid)
221
    self.failUnless(utils.IsProcessAlive(new_pid))
222
    utils.KillProcess(new_pid, waitpid=True)
223
    self.failIf(utils.IsProcessAlive(new_pid))
224
    utils.RemovePidFile('child')
225
    self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)
226

    
227
  def tearDown(self):
228
    for name in os.listdir(self.dir):
229
      os.unlink(os.path.join(self.dir, name))
230
    os.rmdir(self.dir)
231

    
232

    
233
class TestRunCmd(testutils.GanetiTestCase):
234
  """Testing case for the RunCmd function"""
235

    
236
  def setUp(self):
237
    testutils.GanetiTestCase.setUp(self)
238
    self.magic = time.ctime() + " ganeti test"
239
    self.fname = self._CreateTempFile()
240
    self.fifo_tmpdir = tempfile.mkdtemp()
241
    self.fifo_file = os.path.join(self.fifo_tmpdir, "ganeti_test_fifo")
242
    os.mkfifo(self.fifo_file)
243

    
244
  def tearDown(self):
245
    shutil.rmtree(self.fifo_tmpdir)
246

    
247
  def testOk(self):
248
    """Test successful exit code"""
249
    result = RunCmd("/bin/sh -c 'exit 0'")
250
    self.assertEqual(result.exit_code, 0)
251
    self.assertEqual(result.output, "")
252

    
253
  def testFail(self):
254
    """Test fail exit code"""
255
    result = RunCmd("/bin/sh -c 'exit 1'")
256
    self.assertEqual(result.exit_code, 1)
257
    self.assertEqual(result.output, "")
258

    
259
  def testStdout(self):
260
    """Test standard output"""
261
    cmd = 'echo -n "%s"' % self.magic
262
    result = RunCmd("/bin/sh -c '%s'" % cmd)
263
    self.assertEqual(result.stdout, self.magic)
264
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
265
    self.assertEqual(result.output, "")
266
    self.assertFileContent(self.fname, self.magic)
267

    
268
  def testStderr(self):
269
    """Test standard error"""
270
    cmd = 'echo -n "%s"' % self.magic
271
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
272
    self.assertEqual(result.stderr, self.magic)
273
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
274
    self.assertEqual(result.output, "")
275
    self.assertFileContent(self.fname, self.magic)
276

    
277
  def testCombined(self):
278
    """Test combined output"""
279
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
280
    expected = "A" + self.magic + "B" + self.magic
281
    result = RunCmd("/bin/sh -c '%s'" % cmd)
282
    self.assertEqual(result.output, expected)
283
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
284
    self.assertEqual(result.output, "")
285
    self.assertFileContent(self.fname, expected)
286

    
287
  def testSignal(self):
288
    """Test signal"""
289
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
290
    self.assertEqual(result.signal, 15)
291
    self.assertEqual(result.output, "")
292

    
293
  def testTimeoutClean(self):
294
    cmd = "trap 'exit 0' TERM; read < %s" % self.fifo_file
295
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
296
    self.assertEqual(result.exit_code, 0)
297

    
298
  def testTimeoutCleanInteractive(self):
299
    cmd = "trap 'exit 0' TERM; read"
300
    result = RunCmd(["/bin/sh", "-c", cmd], interactive=True, timeout=0.2)
301
    self.assertEqual(result.exit_code, 0)
302

    
303
  def testTimeoutNonClean(self):
304
    for exit_code in (1, 10, 17, 29):
305
      cmd = "trap 'exit %i' TERM; read" % exit_code
306
      result = RunCmd(["/bin/sh", "-c", cmd], interactive=True, timeout=0.2)
307
      self.assert_(result.failed)
308
      self.assertEqual(result.exit_code, exit_code)
309

    
310
  def testTimeoutKill(self):
311
    cmd = "trap '' TERM; read < %s" % self.fifo_file
312
    timeout = 0.2
313
    strcmd = utils.ShellQuoteArgs(["/bin/sh", "-c", cmd])
314
    out, err, status, ta = utils._RunCmdPipe(strcmd, {}, True, "/", False,
315
                                             timeout, _linger_timeout=0.2)
316
    self.assert_(status < 0)
317
    self.assertEqual(-status, signal.SIGKILL)
318

    
319
  def testTimeoutOutputAfterTerm(self):
320
    cmd = "trap 'echo sigtermed; exit 1' TERM; read < %s" % self.fifo_file
321
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
322
    self.assert_(result.failed)
323
    self.assertEqual(result.stdout, "sigtermed\n")
324

    
325
  def testListRun(self):
326
    """Test list runs"""
327
    result = RunCmd(["true"])
328
    self.assertEqual(result.signal, None)
329
    self.assertEqual(result.exit_code, 0)
330
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
331
    self.assertEqual(result.signal, None)
332
    self.assertEqual(result.exit_code, 1)
333
    result = RunCmd(["echo", "-n", self.magic])
334
    self.assertEqual(result.signal, None)
335
    self.assertEqual(result.exit_code, 0)
336
    self.assertEqual(result.stdout, self.magic)
337

    
338
  def testFileEmptyOutput(self):
339
    """Test file output"""
340
    result = RunCmd(["true"], output=self.fname)
341
    self.assertEqual(result.signal, None)
342
    self.assertEqual(result.exit_code, 0)
343
    self.assertFileContent(self.fname, "")
344

    
345
  def testLang(self):
346
    """Test locale environment"""
347
    old_env = os.environ.copy()
348
    try:
349
      os.environ["LANG"] = "en_US.UTF-8"
350
      os.environ["LC_ALL"] = "en_US.UTF-8"
351
      result = RunCmd(["locale"])
352
      for line in result.output.splitlines():
353
        key, value = line.split("=", 1)
354
        # Ignore these variables, they're overridden by LC_ALL
355
        if key == "LANG" or key == "LANGUAGE":
356
          continue
357
        self.failIf(value and value != "C" and value != '"C"',
358
            "Variable %s is set to the invalid value '%s'" % (key, value))
359
    finally:
360
      os.environ = old_env
361

    
362
  def testDefaultCwd(self):
363
    """Test default working directory"""
364
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
365

    
366
  def testCwd(self):
367
    """Test default working directory"""
368
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
369
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
370
    cwd = os.getcwd()
371
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
372

    
373
  def testResetEnv(self):
374
    """Test environment reset functionality"""
375
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
376
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
377
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
378

    
379

    
380
class TestRunParts(unittest.TestCase):
381
  """Testing case for the RunParts function"""
382

    
383
  def setUp(self):
384
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
385

    
386
  def tearDown(self):
387
    shutil.rmtree(self.rundir)
388

    
389
  def testEmpty(self):
390
    """Test on an empty dir"""
391
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
392

    
393
  def testSkipWrongName(self):
394
    """Test that wrong files are skipped"""
395
    fname = os.path.join(self.rundir, "00test.dot")
396
    utils.WriteFile(fname, data="")
397
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
398
    relname = os.path.basename(fname)
399
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
400
                         [(relname, constants.RUNPARTS_SKIP, None)])
401

    
402
  def testSkipNonExec(self):
403
    """Test that non executable files are skipped"""
404
    fname = os.path.join(self.rundir, "00test")
405
    utils.WriteFile(fname, data="")
406
    relname = os.path.basename(fname)
407
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
408
                         [(relname, constants.RUNPARTS_SKIP, None)])
409

    
410
  def testError(self):
411
    """Test error on a broken executable"""
412
    fname = os.path.join(self.rundir, "00test")
413
    utils.WriteFile(fname, data="")
414
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
415
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
416
    self.failUnlessEqual(relname, os.path.basename(fname))
417
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
418
    self.failUnless(error)
419

    
420
  def testSorted(self):
421
    """Test executions are sorted"""
422
    files = []
423
    files.append(os.path.join(self.rundir, "64test"))
424
    files.append(os.path.join(self.rundir, "00test"))
425
    files.append(os.path.join(self.rundir, "42test"))
426

    
427
    for fname in files:
428
      utils.WriteFile(fname, data="")
429

    
430
    results = RunParts(self.rundir, reset_env=True)
431

    
432
    for fname in sorted(files):
433
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
434

    
435
  def testOk(self):
436
    """Test correct execution"""
437
    fname = os.path.join(self.rundir, "00test")
438
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
439
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
440
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
441
    self.failUnlessEqual(relname, os.path.basename(fname))
442
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
443
    self.failUnlessEqual(runresult.stdout, "ciao")
444

    
445
  def testRunFail(self):
446
    """Test correct execution, with run failure"""
447
    fname = os.path.join(self.rundir, "00test")
448
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
449
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
450
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
451
    self.failUnlessEqual(relname, os.path.basename(fname))
452
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
453
    self.failUnlessEqual(runresult.exit_code, 1)
454
    self.failUnless(runresult.failed)
455

    
456
  def testRunMix(self):
457
    files = []
458
    files.append(os.path.join(self.rundir, "00test"))
459
    files.append(os.path.join(self.rundir, "42test"))
460
    files.append(os.path.join(self.rundir, "64test"))
461
    files.append(os.path.join(self.rundir, "99test"))
462

    
463
    files.sort()
464

    
465
    # 1st has errors in execution
466
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
467
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
468

    
469
    # 2nd is skipped
470
    utils.WriteFile(files[1], data="")
471

    
472
    # 3rd cannot execute properly
473
    utils.WriteFile(files[2], data="")
474
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
475

    
476
    # 4th execs
477
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
478
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
479

    
480
    results = RunParts(self.rundir, reset_env=True)
481

    
482
    (relname, status, runresult) = results[0]
483
    self.failUnlessEqual(relname, os.path.basename(files[0]))
484
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
485
    self.failUnlessEqual(runresult.exit_code, 1)
486
    self.failUnless(runresult.failed)
487

    
488
    (relname, status, runresult) = results[1]
489
    self.failUnlessEqual(relname, os.path.basename(files[1]))
490
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
491
    self.failUnlessEqual(runresult, None)
492

    
493
    (relname, status, runresult) = results[2]
494
    self.failUnlessEqual(relname, os.path.basename(files[2]))
495
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
496
    self.failUnless(runresult)
497

    
498
    (relname, status, runresult) = results[3]
499
    self.failUnlessEqual(relname, os.path.basename(files[3]))
500
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
501
    self.failUnlessEqual(runresult.output, "ciao")
502
    self.failUnlessEqual(runresult.exit_code, 0)
503
    self.failUnless(not runresult.failed)
504

    
505

    
506
class TestStartDaemon(testutils.GanetiTestCase):
507
  def setUp(self):
508
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
509
    self.tmpfile = os.path.join(self.tmpdir, "test")
510

    
511
  def tearDown(self):
512
    shutil.rmtree(self.tmpdir)
513

    
514
  def testShell(self):
515
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
516
    self._wait(self.tmpfile, 60.0, "Hello World")
517

    
518
  def testShellOutput(self):
519
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
520
    self._wait(self.tmpfile, 60.0, "Hello World")
521

    
522
  def testNoShellNoOutput(self):
523
    utils.StartDaemon(["pwd"])
524

    
525
  def testNoShellNoOutputTouch(self):
526
    testfile = os.path.join(self.tmpdir, "check")
527
    self.failIf(os.path.exists(testfile))
528
    utils.StartDaemon(["touch", testfile])
529
    self._wait(testfile, 60.0, "")
530

    
531
  def testNoShellOutput(self):
532
    utils.StartDaemon(["pwd"], output=self.tmpfile)
533
    self._wait(self.tmpfile, 60.0, "/")
534

    
535
  def testNoShellOutputCwd(self):
536
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
537
    self._wait(self.tmpfile, 60.0, os.getcwd())
538

    
539
  def testShellEnv(self):
540
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
541
                      env={ "GNT_TEST_VAR": "Hello World", })
542
    self._wait(self.tmpfile, 60.0, "Hello World")
543

    
544
  def testNoShellEnv(self):
545
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
546
                      env={ "GNT_TEST_VAR": "Hello World", })
547
    self._wait(self.tmpfile, 60.0, "Hello World")
548

    
549
  def testOutputFd(self):
550
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
551
    try:
552
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
553
    finally:
554
      os.close(fd)
555
    self._wait(self.tmpfile, 60.0, os.getcwd())
556

    
557
  def testPid(self):
558
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
559
    self._wait(self.tmpfile, 60.0, str(pid))
560

    
561
  def testPidFile(self):
562
    pidfile = os.path.join(self.tmpdir, "pid")
563
    checkfile = os.path.join(self.tmpdir, "abort")
564

    
565
    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
566
                            output=self.tmpfile)
567
    try:
568
      fd = os.open(pidfile, os.O_RDONLY)
569
      try:
570
        # Check file is locked
571
        self.assertRaises(errors.LockError, utils.LockFile, fd)
572

    
573
        pidtext = os.read(fd, 100)
574
      finally:
575
        os.close(fd)
576

    
577
      self.assertEqual(int(pidtext.strip()), pid)
578

    
579
      self.assert_(utils.IsProcessAlive(pid))
580
    finally:
581
      # No matter what happens, kill daemon
582
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
583
      self.failIf(utils.IsProcessAlive(pid))
584

    
585
    self.assertEqual(utils.ReadFile(self.tmpfile), "")
586

    
587
  def _wait(self, path, timeout, expected):
588
    # Due to the asynchronous nature of daemon processes, polling is necessary.
589
    # A timeout makes sure the test doesn't hang forever.
590
    def _CheckFile():
591
      if not (os.path.isfile(path) and
592
              utils.ReadFile(path).strip() == expected):
593
        raise utils.RetryAgain()
594

    
595
    try:
596
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
597
    except utils.RetryTimeout:
598
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
599
                " didn't write the correct output" % timeout)
600

    
601
  def testError(self):
602
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
603
                      ["./does-NOT-EXIST/here/0123456789"])
604
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
605
                      ["./does-NOT-EXIST/here/0123456789"],
606
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
607
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
608
                      ["./does-NOT-EXIST/here/0123456789"],
609
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
610
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
611
                      ["./does-NOT-EXIST/here/0123456789"],
612
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
613

    
614
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
615
    try:
616
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
617
                        ["./does-NOT-EXIST/here/0123456789"],
618
                        output=self.tmpfile, output_fd=fd)
619
    finally:
620
      os.close(fd)
621

    
622

    
623
class TestSetCloseOnExecFlag(unittest.TestCase):
624
  """Tests for SetCloseOnExecFlag"""
625

    
626
  def setUp(self):
627
    self.tmpfile = tempfile.TemporaryFile()
628

    
629
  def testEnable(self):
630
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
631
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
632
                    fcntl.FD_CLOEXEC)
633

    
634
  def testDisable(self):
635
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
636
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
637
                fcntl.FD_CLOEXEC)
638

    
639

    
640
class TestSetNonblockFlag(unittest.TestCase):
641
  def setUp(self):
642
    self.tmpfile = tempfile.TemporaryFile()
643

    
644
  def testEnable(self):
645
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
646
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
647
                    os.O_NONBLOCK)
648

    
649
  def testDisable(self):
650
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
651
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
652
                os.O_NONBLOCK)
653

    
654

    
655
class TestRemoveFile(unittest.TestCase):
656
  """Test case for the RemoveFile function"""
657

    
658
  def setUp(self):
659
    """Create a temp dir and file for each case"""
660
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
661
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
662
    os.close(fd)
663

    
664
  def tearDown(self):
665
    if os.path.exists(self.tmpfile):
666
      os.unlink(self.tmpfile)
667
    os.rmdir(self.tmpdir)
668

    
669
  def testIgnoreDirs(self):
670
    """Test that RemoveFile() ignores directories"""
671
    self.assertEqual(None, RemoveFile(self.tmpdir))
672

    
673
  def testIgnoreNotExisting(self):
674
    """Test that RemoveFile() ignores non-existing files"""
675
    RemoveFile(self.tmpfile)
676
    RemoveFile(self.tmpfile)
677

    
678
  def testRemoveFile(self):
679
    """Test that RemoveFile does remove a file"""
680
    RemoveFile(self.tmpfile)
681
    if os.path.exists(self.tmpfile):
682
      self.fail("File '%s' not removed" % self.tmpfile)
683

    
684
  def testRemoveSymlink(self):
685
    """Test that RemoveFile does remove symlinks"""
686
    symlink = self.tmpdir + "/symlink"
687
    os.symlink("no-such-file", symlink)
688
    RemoveFile(symlink)
689
    if os.path.exists(symlink):
690
      self.fail("File '%s' not removed" % symlink)
691
    os.symlink(self.tmpfile, symlink)
692
    RemoveFile(symlink)
693
    if os.path.exists(symlink):
694
      self.fail("File '%s' not removed" % symlink)
695

    
696

    
697
class TestRename(unittest.TestCase):
698
  """Test case for RenameFile"""
699

    
700
  def setUp(self):
701
    """Create a temporary directory"""
702
    self.tmpdir = tempfile.mkdtemp()
703
    self.tmpfile = os.path.join(self.tmpdir, "test1")
704

    
705
    # Touch the file
706
    open(self.tmpfile, "w").close()
707

    
708
  def tearDown(self):
709
    """Remove temporary directory"""
710
    shutil.rmtree(self.tmpdir)
711

    
712
  def testSimpleRename1(self):
713
    """Simple rename 1"""
714
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
715
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
716

    
717
  def testSimpleRename2(self):
718
    """Simple rename 2"""
719
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
720
                     mkdir=True)
721
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
722

    
723
  def testRenameMkdir(self):
724
    """Rename with mkdir"""
725
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
726
                     mkdir=True)
727
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
728
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
729

    
730
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
731
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
732
                     mkdir=True)
733
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
734
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
735
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
736

    
737

    
738
class TestMatchNameComponent(unittest.TestCase):
739
  """Test case for the MatchNameComponent function"""
740

    
741
  def testEmptyList(self):
742
    """Test that there is no match against an empty list"""
743

    
744
    self.failUnlessEqual(MatchNameComponent("", []), None)
745
    self.failUnlessEqual(MatchNameComponent("test", []), None)
746

    
747
  def testSingleMatch(self):
748
    """Test that a single match is performed correctly"""
749
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
750
    for key in "test2", "test2.example", "test2.example.com":
751
      self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
752

    
753
  def testMultipleMatches(self):
754
    """Test that a multiple match is returned as None"""
755
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
756
    for key in "test1", "test1.example":
757
      self.failUnlessEqual(MatchNameComponent(key, mlist), None)
758

    
759
  def testFullMatch(self):
760
    """Test that a full match is returned correctly"""
761
    key1 = "test1"
762
    key2 = "test1.example"
763
    mlist = [key2, key2 + ".com"]
764
    self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
765
    self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
766

    
767
  def testCaseInsensitivePartialMatch(self):
768
    """Test for the case_insensitive keyword"""
769
    mlist = ["test1.example.com", "test2.example.net"]
770
    self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
771
                     "test2.example.net")
772
    self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
773
                     "test2.example.net")
774
    self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
775
                     "test2.example.net")
776
    self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
777
                     "test2.example.net")
778

    
779

    
780
  def testCaseInsensitiveFullMatch(self):
781
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
782
    # Between the two ts1 a full string match non-case insensitive should work
783
    self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
784
                     None)
785
    self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
786
                     "ts1.ex")
787
    self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
788
                     "ts1.ex")
789
    # Between the two ts2 only case differs, so only case-match works
790
    self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
791
                     "ts2.ex")
792
    self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
793
                     "Ts2.ex")
794
    self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
795
                     None)
796

    
797

    
798
class TestReadFile(testutils.GanetiTestCase):
799

    
800
  def testReadAll(self):
801
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
802
    self.assertEqual(len(data), 814)
803

    
804
    h = compat.md5_hash()
805
    h.update(data)
806
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
807

    
808
  def testReadSize(self):
809
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
810
                          size=100)
811
    self.assertEqual(len(data), 100)
812

    
813
    h = compat.md5_hash()
814
    h.update(data)
815
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
816

    
817
  def testError(self):
818
    self.assertRaises(EnvironmentError, utils.ReadFile,
819
                      "/dev/null/does-not-exist")
820

    
821

    
822
class TestReadOneLineFile(testutils.GanetiTestCase):
823

    
824
  def setUp(self):
825
    testutils.GanetiTestCase.setUp(self)
826

    
827
  def testDefault(self):
828
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
829
    self.assertEqual(len(data), 27)
830
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
831

    
832
  def testNotStrict(self):
833
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
834
    self.assertEqual(len(data), 27)
835
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
836

    
837
  def testStrictFailure(self):
838
    self.assertRaises(errors.GenericError, ReadOneLineFile,
839
                      self._TestDataFilename("cert1.pem"), strict=True)
840

    
841
  def testLongLine(self):
842
    dummydata = (1024 * "Hello World! ")
843
    myfile = self._CreateTempFile()
844
    utils.WriteFile(myfile, data=dummydata)
845
    datastrict = ReadOneLineFile(myfile, strict=True)
846
    datalax = ReadOneLineFile(myfile, strict=False)
847
    self.assertEqual(dummydata, datastrict)
848
    self.assertEqual(dummydata, datalax)
849

    
850
  def testNewline(self):
851
    myfile = self._CreateTempFile()
852
    myline = "myline"
853
    for nl in ["", "\n", "\r\n"]:
854
      dummydata = "%s%s" % (myline, nl)
855
      utils.WriteFile(myfile, data=dummydata)
856
      datalax = ReadOneLineFile(myfile, strict=False)
857
      self.assertEqual(myline, datalax)
858
      datastrict = ReadOneLineFile(myfile, strict=True)
859
      self.assertEqual(myline, datastrict)
860

    
861
  def testWhitespaceAndMultipleLines(self):
862
    myfile = self._CreateTempFile()
863
    for nl in ["", "\n", "\r\n"]:
864
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
865
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
866
        utils.WriteFile(myfile, data=dummydata)
867
        datalax = ReadOneLineFile(myfile, strict=False)
868
        if nl:
869
          self.assert_(set("\r\n") & set(dummydata))
870
          self.assertRaises(errors.GenericError, ReadOneLineFile,
871
                            myfile, strict=True)
872
          explen = len("Foo bar baz ") + len(ws)
873
          self.assertEqual(len(datalax), explen)
874
          self.assertEqual(datalax, dummydata[:explen])
875
          self.assertFalse(set("\r\n") & set(datalax))
876
        else:
877
          datastrict = ReadOneLineFile(myfile, strict=True)
878
          self.assertEqual(dummydata, datastrict)
879
          self.assertEqual(dummydata, datalax)
880

    
881
  def testEmptylines(self):
882
    myfile = self._CreateTempFile()
883
    myline = "myline"
884
    for nl in ["\n", "\r\n"]:
885
      for ol in ["", "otherline"]:
886
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
887
        utils.WriteFile(myfile, data=dummydata)
888
        self.assert_(set("\r\n") & set(dummydata))
889
        datalax = ReadOneLineFile(myfile, strict=False)
890
        self.assertEqual(myline, datalax)
891
        if ol:
892
          self.assertRaises(errors.GenericError, ReadOneLineFile,
893
                            myfile, strict=True)
894
        else:
895
          datastrict = ReadOneLineFile(myfile, strict=True)
896
          self.assertEqual(myline, datastrict)
897

    
898

    
899
class TestTimestampForFilename(unittest.TestCase):
900
  def test(self):
901
    self.assert_("." not in utils.TimestampForFilename())
902
    self.assert_(":" not in utils.TimestampForFilename())
903

    
904

    
905
class TestCreateBackup(testutils.GanetiTestCase):
906
  def setUp(self):
907
    testutils.GanetiTestCase.setUp(self)
908

    
909
    self.tmpdir = tempfile.mkdtemp()
910

    
911
  def tearDown(self):
912
    testutils.GanetiTestCase.tearDown(self)
913

    
914
    shutil.rmtree(self.tmpdir)
915

    
916
  def testEmpty(self):
917
    filename = PathJoin(self.tmpdir, "config.data")
918
    utils.WriteFile(filename, data="")
919
    bname = utils.CreateBackup(filename)
920
    self.assertFileContent(bname, "")
921
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
922
    utils.CreateBackup(filename)
923
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
924
    utils.CreateBackup(filename)
925
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
926

    
927
    fifoname = PathJoin(self.tmpdir, "fifo")
928
    os.mkfifo(fifoname)
929
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
930

    
931
  def testContent(self):
932
    bkpcount = 0
933
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
934
      for rep in [1, 2, 10, 127]:
935
        testdata = data * rep
936

    
937
        filename = PathJoin(self.tmpdir, "test.data_")
938
        utils.WriteFile(filename, data=testdata)
939
        self.assertFileContent(filename, testdata)
940

    
941
        for _ in range(3):
942
          bname = utils.CreateBackup(filename)
943
          bkpcount += 1
944
          self.assertFileContent(bname, testdata)
945
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
946

    
947

    
948
class TestFormatUnit(unittest.TestCase):
949
  """Test case for the FormatUnit function"""
950

    
951
  def testMiB(self):
952
    self.assertEqual(FormatUnit(1, 'h'), '1M')
953
    self.assertEqual(FormatUnit(100, 'h'), '100M')
954
    self.assertEqual(FormatUnit(1023, 'h'), '1023M')
955

    
956
    self.assertEqual(FormatUnit(1, 'm'), '1')
957
    self.assertEqual(FormatUnit(100, 'm'), '100')
958
    self.assertEqual(FormatUnit(1023, 'm'), '1023')
959

    
960
    self.assertEqual(FormatUnit(1024, 'm'), '1024')
961
    self.assertEqual(FormatUnit(1536, 'm'), '1536')
962
    self.assertEqual(FormatUnit(17133, 'm'), '17133')
963
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
964

    
965
  def testGiB(self):
966
    self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
967
    self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
968
    self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
969
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
970

    
971
    self.assertEqual(FormatUnit(1024, 'g'), '1.0')
972
    self.assertEqual(FormatUnit(1536, 'g'), '1.5')
973
    self.assertEqual(FormatUnit(17133, 'g'), '16.7')
974
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
975

    
976
    self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
977
    self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
978
    self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
979

    
980
  def testTiB(self):
981
    self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
982
    self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
983
    self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
984

    
985
    self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
986
    self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
987
    self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
988

    
989

    
990
class TestParseUnit(unittest.TestCase):
991
  """Test case for the ParseUnit function"""
992

    
993
  SCALES = (('', 1),
994
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
995
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
996
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
997

    
998
  def testRounding(self):
999
    self.assertEqual(ParseUnit('0'), 0)
1000
    self.assertEqual(ParseUnit('1'), 4)
1001
    self.assertEqual(ParseUnit('2'), 4)
1002
    self.assertEqual(ParseUnit('3'), 4)
1003

    
1004
    self.assertEqual(ParseUnit('124'), 124)
1005
    self.assertEqual(ParseUnit('125'), 128)
1006
    self.assertEqual(ParseUnit('126'), 128)
1007
    self.assertEqual(ParseUnit('127'), 128)
1008
    self.assertEqual(ParseUnit('128'), 128)
1009
    self.assertEqual(ParseUnit('129'), 132)
1010
    self.assertEqual(ParseUnit('130'), 132)
1011

    
1012
  def testFloating(self):
1013
    self.assertEqual(ParseUnit('0'), 0)
1014
    self.assertEqual(ParseUnit('0.5'), 4)
1015
    self.assertEqual(ParseUnit('1.75'), 4)
1016
    self.assertEqual(ParseUnit('1.99'), 4)
1017
    self.assertEqual(ParseUnit('2.00'), 4)
1018
    self.assertEqual(ParseUnit('2.01'), 4)
1019
    self.assertEqual(ParseUnit('3.99'), 4)
1020
    self.assertEqual(ParseUnit('4.00'), 4)
1021
    self.assertEqual(ParseUnit('4.01'), 8)
1022
    self.assertEqual(ParseUnit('1.5G'), 1536)
1023
    self.assertEqual(ParseUnit('1.8G'), 1844)
1024
    self.assertEqual(ParseUnit('8.28T'), 8682212)
1025

    
1026
  def testSuffixes(self):
1027
    for sep in ('', ' ', '   ', "\t", "\t "):
1028
      for suffix, scale in TestParseUnit.SCALES:
1029
        for func in (lambda x: x, str.lower, str.upper):
1030
          self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
1031
                           1024 * scale)
1032

    
1033
  def testInvalidInput(self):
1034
    for sep in ('-', '_', ',', 'a'):
1035
      for suffix, _ in TestParseUnit.SCALES:
1036
        self.assertRaises(errors.UnitParseError, ParseUnit, '1' + sep + suffix)
1037

    
1038
    for suffix, _ in TestParseUnit.SCALES:
1039
      self.assertRaises(errors.UnitParseError, ParseUnit, '1,3' + suffix)
1040

    
1041

    
1042
class TestParseCpuMask(unittest.TestCase):
1043
  """Test case for the ParseCpuMask function."""
1044

    
1045
  def testWellFormed(self):
1046
    self.assertEqual(utils.ParseCpuMask(""), [])
1047
    self.assertEqual(utils.ParseCpuMask("1"), [1])
1048
    self.assertEqual(utils.ParseCpuMask("0-2,4,5-5"), [0,1,2,4,5])
1049

    
1050
  def testInvalidInput(self):
1051
    self.assertRaises(errors.ParseError,
1052
                      utils.ParseCpuMask,
1053
                      "garbage")
1054
    self.assertRaises(errors.ParseError,
1055
                      utils.ParseCpuMask,
1056
                      "0,")
1057
    self.assertRaises(errors.ParseError,
1058
                      utils.ParseCpuMask,
1059
                      "0-1-2")
1060
    self.assertRaises(errors.ParseError,
1061
                      utils.ParseCpuMask,
1062
                      "2-1")
1063

    
1064
class TestSshKeys(testutils.GanetiTestCase):
1065
  """Test case for the AddAuthorizedKey function"""
1066

    
1067
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
1068
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
1069
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
1070

    
1071
  def setUp(self):
1072
    testutils.GanetiTestCase.setUp(self)
1073
    self.tmpname = self._CreateTempFile()
1074
    handle = open(self.tmpname, 'w')
1075
    try:
1076
      handle.write("%s\n" % TestSshKeys.KEY_A)
1077
      handle.write("%s\n" % TestSshKeys.KEY_B)
1078
    finally:
1079
      handle.close()
1080

    
1081
  def testAddingNewKey(self):
1082
    utils.AddAuthorizedKey(self.tmpname,
1083
                           'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
1084

    
1085
    self.assertFileContent(self.tmpname,
1086
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1087
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1088
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1089
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
1090

    
1091
  def testAddingAlmostButNotCompletelyTheSameKey(self):
1092
    utils.AddAuthorizedKey(self.tmpname,
1093
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
1094

    
1095
    self.assertFileContent(self.tmpname,
1096
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1097
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1098
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1099
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
1100

    
1101
  def testAddingExistingKeyWithSomeMoreSpaces(self):
1102
    utils.AddAuthorizedKey(self.tmpname,
1103
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1104

    
1105
    self.assertFileContent(self.tmpname,
1106
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1107
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1108
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1109

    
1110
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
1111
    utils.RemoveAuthorizedKey(self.tmpname,
1112
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1113

    
1114
    self.assertFileContent(self.tmpname,
1115
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1116
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1117

    
1118
  def testRemovingNonExistingKey(self):
1119
    utils.RemoveAuthorizedKey(self.tmpname,
1120
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
1121

    
1122
    self.assertFileContent(self.tmpname,
1123
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1124
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1125
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1126

    
1127

    
1128
class TestEtcHosts(testutils.GanetiTestCase):
1129
  """Test functions modifying /etc/hosts"""
1130

    
1131
  def setUp(self):
1132
    testutils.GanetiTestCase.setUp(self)
1133
    self.tmpname = self._CreateTempFile()
1134
    handle = open(self.tmpname, 'w')
1135
    try:
1136
      handle.write('# This is a test file for /etc/hosts\n')
1137
      handle.write('127.0.0.1\tlocalhost\n')
1138
      handle.write('192.0.2.1 router gw\n')
1139
    finally:
1140
      handle.close()
1141

    
1142
  def testSettingNewIp(self):
1143
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost.example.com',
1144
                     ['myhost'])
1145

    
1146
    self.assertFileContent(self.tmpname,
1147
      "# This is a test file for /etc/hosts\n"
1148
      "127.0.0.1\tlocalhost\n"
1149
      "192.0.2.1 router gw\n"
1150
      "198.51.100.4\tmyhost.example.com myhost\n")
1151
    self.assertFileMode(self.tmpname, 0644)
1152

    
1153
  def testSettingExistingIp(self):
1154
    SetEtcHostsEntry(self.tmpname, '192.0.2.1', 'myhost.example.com',
1155
                     ['myhost'])
1156

    
1157
    self.assertFileContent(self.tmpname,
1158
      "# This is a test file for /etc/hosts\n"
1159
      "127.0.0.1\tlocalhost\n"
1160
      "192.0.2.1\tmyhost.example.com myhost\n")
1161
    self.assertFileMode(self.tmpname, 0644)
1162

    
1163
  def testSettingDuplicateName(self):
1164
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost', ['myhost'])
1165

    
1166
    self.assertFileContent(self.tmpname,
1167
      "# This is a test file for /etc/hosts\n"
1168
      "127.0.0.1\tlocalhost\n"
1169
      "192.0.2.1 router gw\n"
1170
      "198.51.100.4\tmyhost\n")
1171
    self.assertFileMode(self.tmpname, 0644)
1172

    
1173
  def testRemovingExistingHost(self):
1174
    RemoveEtcHostsEntry(self.tmpname, 'router')
1175

    
1176
    self.assertFileContent(self.tmpname,
1177
      "# This is a test file for /etc/hosts\n"
1178
      "127.0.0.1\tlocalhost\n"
1179
      "192.0.2.1 gw\n")
1180
    self.assertFileMode(self.tmpname, 0644)
1181

    
1182
  def testRemovingSingleExistingHost(self):
1183
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
1184

    
1185
    self.assertFileContent(self.tmpname,
1186
      "# This is a test file for /etc/hosts\n"
1187
      "192.0.2.1 router gw\n")
1188
    self.assertFileMode(self.tmpname, 0644)
1189

    
1190
  def testRemovingNonExistingHost(self):
1191
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
1192

    
1193
    self.assertFileContent(self.tmpname,
1194
      "# This is a test file for /etc/hosts\n"
1195
      "127.0.0.1\tlocalhost\n"
1196
      "192.0.2.1 router gw\n")
1197
    self.assertFileMode(self.tmpname, 0644)
1198

    
1199
  def testRemovingAlias(self):
1200
    RemoveEtcHostsEntry(self.tmpname, 'gw')
1201

    
1202
    self.assertFileContent(self.tmpname,
1203
      "# This is a test file for /etc/hosts\n"
1204
      "127.0.0.1\tlocalhost\n"
1205
      "192.0.2.1 router\n")
1206
    self.assertFileMode(self.tmpname, 0644)
1207

    
1208

    
1209
class TestGetMounts(unittest.TestCase):
1210
  """Test case for GetMounts()."""
1211

    
1212
  TESTDATA = (
1213
    "rootfs /     rootfs rw 0 0\n"
1214
    "none   /sys  sysfs  rw,nosuid,nodev,noexec,relatime 0 0\n"
1215
    "none   /proc proc   rw,nosuid,nodev,noexec,relatime 0 0\n")
1216

    
1217
  def setUp(self):
1218
    self.tmpfile = tempfile.NamedTemporaryFile()
1219
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1220

    
1221
  def testGetMounts(self):
1222
    self.assertEqual(utils.GetMounts(filename=self.tmpfile.name),
1223
      [
1224
        ("rootfs", "/", "rootfs", "rw"),
1225
        ("none", "/sys", "sysfs", "rw,nosuid,nodev,noexec,relatime"),
1226
        ("none", "/proc", "proc", "rw,nosuid,nodev,noexec,relatime"),
1227
      ])
1228

    
1229

    
1230
class TestShellQuoting(unittest.TestCase):
1231
  """Test case for shell quoting functions"""
1232

    
1233
  def testShellQuote(self):
1234
    self.assertEqual(ShellQuote('abc'), "abc")
1235
    self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
1236
    self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
1237
    self.assertEqual(ShellQuote("a b c"), "'a b c'")
1238
    self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
1239

    
1240
  def testShellQuoteArgs(self):
1241
    self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
1242
    self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
1243
    self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
1244

    
1245

    
1246
class TestListVisibleFiles(unittest.TestCase):
1247
  """Test case for ListVisibleFiles"""
1248

    
1249
  def setUp(self):
1250
    self.path = tempfile.mkdtemp()
1251

    
1252
  def tearDown(self):
1253
    shutil.rmtree(self.path)
1254

    
1255
  def _CreateFiles(self, files):
1256
    for name in files:
1257
      utils.WriteFile(os.path.join(self.path, name), data="test")
1258

    
1259
  def _test(self, files, expected):
1260
    self._CreateFiles(files)
1261
    found = ListVisibleFiles(self.path)
1262
    self.assertEqual(set(found), set(expected))
1263

    
1264
  def testAllVisible(self):
1265
    files = ["a", "b", "c"]
1266
    expected = files
1267
    self._test(files, expected)
1268

    
1269
  def testNoneVisible(self):
1270
    files = [".a", ".b", ".c"]
1271
    expected = []
1272
    self._test(files, expected)
1273

    
1274
  def testSomeVisible(self):
1275
    files = ["a", "b", ".c"]
1276
    expected = ["a", "b"]
1277
    self._test(files, expected)
1278

    
1279
  def testNonAbsolutePath(self):
1280
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1281

    
1282
  def testNonNormalizedPath(self):
1283
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1284
                          "/bin/../tmp")
1285

    
1286

    
1287
class TestNewUUID(unittest.TestCase):
1288
  """Test case for NewUUID"""
1289

    
1290
  def runTest(self):
1291
    self.failUnless(utils.UUID_RE.match(utils.NewUUID()))
1292

    
1293

    
1294
class TestUniqueSequence(unittest.TestCase):
1295
  """Test case for UniqueSequence"""
1296

    
1297
  def _test(self, input, expected):
1298
    self.assertEqual(utils.UniqueSequence(input), expected)
1299

    
1300
  def runTest(self):
1301
    # Ordered input
1302
    self._test([1, 2, 3], [1, 2, 3])
1303
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
1304
    self._test([1, 2, 2, 3], [1, 2, 3])
1305
    self._test([1, 2, 3, 3], [1, 2, 3])
1306

    
1307
    # Unordered input
1308
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
1309
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
1310

    
1311
    # Strings
1312
    self._test(["a", "a"], ["a"])
1313
    self._test(["a", "b"], ["a", "b"])
1314
    self._test(["a", "b", "a"], ["a", "b"])
1315

    
1316

    
1317
class TestFirstFree(unittest.TestCase):
1318
  """Test case for the FirstFree function"""
1319

    
1320
  def test(self):
1321
    """Test FirstFree"""
1322
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1323
    self.failUnlessEqual(FirstFree([]), None)
1324
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1325
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1326
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1327

    
1328

    
1329
class TestTailFile(testutils.GanetiTestCase):
1330
  """Test case for the TailFile function"""
1331

    
1332
  def testEmpty(self):
1333
    fname = self._CreateTempFile()
1334
    self.failUnlessEqual(TailFile(fname), [])
1335
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1336

    
1337
  def testAllLines(self):
1338
    data = ["test %d" % i for i in range(30)]
1339
    for i in range(30):
1340
      fname = self._CreateTempFile()
1341
      fd = open(fname, "w")
1342
      fd.write("\n".join(data[:i]))
1343
      if i > 0:
1344
        fd.write("\n")
1345
      fd.close()
1346
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1347

    
1348
  def testPartialLines(self):
1349
    data = ["test %d" % i for i in range(30)]
1350
    fname = self._CreateTempFile()
1351
    fd = open(fname, "w")
1352
    fd.write("\n".join(data))
1353
    fd.write("\n")
1354
    fd.close()
1355
    for i in range(1, 30):
1356
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1357

    
1358
  def testBigFile(self):
1359
    data = ["test %d" % i for i in range(30)]
1360
    fname = self._CreateTempFile()
1361
    fd = open(fname, "w")
1362
    fd.write("X" * 1048576)
1363
    fd.write("\n")
1364
    fd.write("\n".join(data))
1365
    fd.write("\n")
1366
    fd.close()
1367
    for i in range(1, 30):
1368
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1369

    
1370

    
1371
class _BaseFileLockTest:
1372
  """Test case for the FileLock class"""
1373

    
1374
  def testSharedNonblocking(self):
1375
    self.lock.Shared(blocking=False)
1376
    self.lock.Close()
1377

    
1378
  def testExclusiveNonblocking(self):
1379
    self.lock.Exclusive(blocking=False)
1380
    self.lock.Close()
1381

    
1382
  def testUnlockNonblocking(self):
1383
    self.lock.Unlock(blocking=False)
1384
    self.lock.Close()
1385

    
1386
  def testSharedBlocking(self):
1387
    self.lock.Shared(blocking=True)
1388
    self.lock.Close()
1389

    
1390
  def testExclusiveBlocking(self):
1391
    self.lock.Exclusive(blocking=True)
1392
    self.lock.Close()
1393

    
1394
  def testUnlockBlocking(self):
1395
    self.lock.Unlock(blocking=True)
1396
    self.lock.Close()
1397

    
1398
  def testSharedExclusiveUnlock(self):
1399
    self.lock.Shared(blocking=False)
1400
    self.lock.Exclusive(blocking=False)
1401
    self.lock.Unlock(blocking=False)
1402
    self.lock.Close()
1403

    
1404
  def testExclusiveSharedUnlock(self):
1405
    self.lock.Exclusive(blocking=False)
1406
    self.lock.Shared(blocking=False)
1407
    self.lock.Unlock(blocking=False)
1408
    self.lock.Close()
1409

    
1410
  def testSimpleTimeout(self):
1411
    # These will succeed on the first attempt, hence a short timeout
1412
    self.lock.Shared(blocking=True, timeout=10.0)
1413
    self.lock.Exclusive(blocking=False, timeout=10.0)
1414
    self.lock.Unlock(blocking=True, timeout=10.0)
1415
    self.lock.Close()
1416

    
1417
  @staticmethod
1418
  def _TryLockInner(filename, shared, blocking):
1419
    lock = utils.FileLock.Open(filename)
1420

    
1421
    if shared:
1422
      fn = lock.Shared
1423
    else:
1424
      fn = lock.Exclusive
1425

    
1426
    try:
1427
      # The timeout doesn't really matter as the parent process waits for us to
1428
      # finish anyway.
1429
      fn(blocking=blocking, timeout=0.01)
1430
    except errors.LockError, err:
1431
      return False
1432

    
1433
    return True
1434

    
1435
  def _TryLock(self, *args):
1436
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1437
                                      *args)
1438

    
1439
  def testTimeout(self):
1440
    for blocking in [True, False]:
1441
      self.lock.Exclusive(blocking=True)
1442
      self.failIf(self._TryLock(False, blocking))
1443
      self.failIf(self._TryLock(True, blocking))
1444

    
1445
      self.lock.Shared(blocking=True)
1446
      self.assert_(self._TryLock(True, blocking))
1447
      self.failIf(self._TryLock(False, blocking))
1448

    
1449
  def testCloseShared(self):
1450
    self.lock.Close()
1451
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1452

    
1453
  def testCloseExclusive(self):
1454
    self.lock.Close()
1455
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1456

    
1457
  def testCloseUnlock(self):
1458
    self.lock.Close()
1459
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1460

    
1461

    
1462
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1463
  TESTDATA = "Hello World\n" * 10
1464

    
1465
  def setUp(self):
1466
    testutils.GanetiTestCase.setUp(self)
1467

    
1468
    self.tmpfile = tempfile.NamedTemporaryFile()
1469
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1470
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1471

    
1472
    # Ensure "Open" didn't truncate file
1473
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1474

    
1475
  def tearDown(self):
1476
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1477

    
1478
    testutils.GanetiTestCase.tearDown(self)
1479

    
1480

    
1481
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1482
  def setUp(self):
1483
    self.tmpfile = tempfile.NamedTemporaryFile()
1484
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1485

    
1486

    
1487
class TestTimeFunctions(unittest.TestCase):
1488
  """Test case for time functions"""
1489

    
1490
  def runTest(self):
1491
    self.assertEqual(utils.SplitTime(1), (1, 0))
1492
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1493
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1494
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1495
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1496
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1497
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1498
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1499

    
1500
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1501

    
1502
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1503
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1504
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1505

    
1506
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1507
                     1218448917.481)
1508
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1509

    
1510
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1511
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1512
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1513
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1514
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1515

    
1516

    
1517
class FieldSetTestCase(unittest.TestCase):
1518
  """Test case for FieldSets"""
1519

    
1520
  def testSimpleMatch(self):
1521
    f = utils.FieldSet("a", "b", "c", "def")
1522
    self.failUnless(f.Matches("a"))
1523
    self.failIf(f.Matches("d"), "Substring matched")
1524
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1525
    self.failIf(f.NonMatching(["b", "c"]))
1526
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1527
    self.failUnless(f.NonMatching(["a", "d"]))
1528

    
1529
  def testRegexMatch(self):
1530
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1531
    self.failUnless(f.Matches("b1"))
1532
    self.failUnless(f.Matches("b99"))
1533
    self.failIf(f.Matches("b/1"))
1534
    self.failIf(f.NonMatching(["b12", "c"]))
1535
    self.failUnless(f.NonMatching(["a", "1"]))
1536

    
1537
class TestForceDictType(unittest.TestCase):
1538
  """Test case for ForceDictType"""
1539

    
1540
  def setUp(self):
1541
    self.key_types = {
1542
      'a': constants.VTYPE_INT,
1543
      'b': constants.VTYPE_BOOL,
1544
      'c': constants.VTYPE_STRING,
1545
      'd': constants.VTYPE_SIZE,
1546
      "e": constants.VTYPE_MAYBE_STRING,
1547
      }
1548

    
1549
  def _fdt(self, dict, allowed_values=None):
1550
    if allowed_values is None:
1551
      utils.ForceDictType(dict, self.key_types)
1552
    else:
1553
      utils.ForceDictType(dict, self.key_types, allowed_values=allowed_values)
1554

    
1555
    return dict
1556

    
1557
  def testSimpleDict(self):
1558
    self.assertEqual(self._fdt({}), {})
1559
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1560
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1561
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1562
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1563
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1564
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1565
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1566
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1567
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1568
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1569
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1570
    self.assertEqual(self._fdt({"e": None, }), {"e": None, })
1571
    self.assertEqual(self._fdt({"e": "Hello World", }), {"e": "Hello World", })
1572
    self.assertEqual(self._fdt({"e": False, }), {"e": '', })
1573

    
1574
  def testErrors(self):
1575
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1576
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1577
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1578
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1579
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": object(), })
1580
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": [], })
1581

    
1582

    
1583
class TestIsNormAbsPath(unittest.TestCase):
1584
  """Testing case for IsNormAbsPath"""
1585

    
1586
  def _pathTestHelper(self, path, result):
1587
    if result:
1588
      self.assert_(utils.IsNormAbsPath(path),
1589
          "Path %s should result absolute and normalized" % path)
1590
    else:
1591
      self.assertFalse(utils.IsNormAbsPath(path),
1592
          "Path %s should not result absolute and normalized" % path)
1593

    
1594
  def testBase(self):
1595
    self._pathTestHelper('/etc', True)
1596
    self._pathTestHelper('/srv', True)
1597
    self._pathTestHelper('etc', False)
1598
    self._pathTestHelper('/etc/../root', False)
1599
    self._pathTestHelper('/etc/', False)
1600

    
1601

    
1602
class TestSafeEncode(unittest.TestCase):
1603
  """Test case for SafeEncode"""
1604

    
1605
  def testAscii(self):
1606
    for txt in [string.digits, string.letters, string.punctuation]:
1607
      self.failUnlessEqual(txt, SafeEncode(txt))
1608

    
1609
  def testDoubleEncode(self):
1610
    for i in range(255):
1611
      txt = SafeEncode(chr(i))
1612
      self.failUnlessEqual(txt, SafeEncode(txt))
1613

    
1614
  def testUnicode(self):
1615
    # 1024 is high enough to catch non-direct ASCII mappings
1616
    for i in range(1024):
1617
      txt = SafeEncode(unichr(i))
1618
      self.failUnlessEqual(txt, SafeEncode(txt))
1619

    
1620

    
1621
class TestFormatTime(unittest.TestCase):
1622
  """Testing case for FormatTime"""
1623

    
1624
  def testNone(self):
1625
    self.failUnlessEqual(FormatTime(None), "N/A")
1626

    
1627
  def testInvalid(self):
1628
    self.failUnlessEqual(FormatTime(()), "N/A")
1629

    
1630
  def testNow(self):
1631
    # tests that we accept time.time input
1632
    FormatTime(time.time())
1633
    # tests that we accept int input
1634
    FormatTime(int(time.time()))
1635

    
1636

    
1637
class RunInSeparateProcess(unittest.TestCase):
1638
  def test(self):
1639
    for exp in [True, False]:
1640
      def _child():
1641
        return exp
1642

    
1643
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1644

    
1645
  def testArgs(self):
1646
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1647
      def _child(carg1, carg2):
1648
        return carg1 == "Foo" and carg2 == arg
1649

    
1650
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1651

    
1652
  def testPid(self):
1653
    parent_pid = os.getpid()
1654

    
1655
    def _check():
1656
      return os.getpid() == parent_pid
1657

    
1658
    self.failIf(utils.RunInSeparateProcess(_check))
1659

    
1660
  def testSignal(self):
1661
    def _kill():
1662
      os.kill(os.getpid(), signal.SIGTERM)
1663

    
1664
    self.assertRaises(errors.GenericError,
1665
                      utils.RunInSeparateProcess, _kill)
1666

    
1667
  def testException(self):
1668
    def _exc():
1669
      raise errors.GenericError("This is a test")
1670

    
1671
    self.assertRaises(errors.GenericError,
1672
                      utils.RunInSeparateProcess, _exc)
1673

    
1674

    
1675
class TestFingerprintFile(unittest.TestCase):
1676
  def setUp(self):
1677
    self.tmpfile = tempfile.NamedTemporaryFile()
1678

    
1679
  def test(self):
1680
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1681
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")
1682

    
1683
    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
1684
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1685
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1686

    
1687

    
1688
class TestUnescapeAndSplit(unittest.TestCase):
1689
  """Testing case for UnescapeAndSplit"""
1690

    
1691
  def setUp(self):
1692
    # testing more that one separator for regexp safety
1693
    self._seps = [",", "+", "."]
1694

    
1695
  def testSimple(self):
1696
    a = ["a", "b", "c", "d"]
1697
    for sep in self._seps:
1698
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1699

    
1700
  def testEscape(self):
1701
    for sep in self._seps:
1702
      a = ["a", "b\\" + sep + "c", "d"]
1703
      b = ["a", "b" + sep + "c", "d"]
1704
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1705

    
1706
  def testDoubleEscape(self):
1707
    for sep in self._seps:
1708
      a = ["a", "b\\\\", "c", "d"]
1709
      b = ["a", "b\\", "c", "d"]
1710
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1711

    
1712
  def testThreeEscape(self):
1713
    for sep in self._seps:
1714
      a = ["a", "b\\\\\\" + sep + "c", "d"]
1715
      b = ["a", "b\\" + sep + "c", "d"]
1716
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1717

    
1718

    
1719
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1720
  def setUp(self):
1721
    self.tmpdir = tempfile.mkdtemp()
1722

    
1723
  def tearDown(self):
1724
    shutil.rmtree(self.tmpdir)
1725

    
1726
  def _checkRsaPrivateKey(self, key):
1727
    lines = key.splitlines()
1728
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1729
            "-----END RSA PRIVATE KEY-----" in lines)
1730

    
1731
  def _checkCertificate(self, cert):
1732
    lines = cert.splitlines()
1733
    return ("-----BEGIN CERTIFICATE-----" in lines and
1734
            "-----END CERTIFICATE-----" in lines)
1735

    
1736
  def test(self):
1737
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1738
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1739
      self._checkRsaPrivateKey(key_pem)
1740
      self._checkCertificate(cert_pem)
1741

    
1742
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1743
                                           key_pem)
1744
      self.assert_(key.bits() >= 1024)
1745
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1746
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1747

    
1748
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1749
                                             cert_pem)
1750
      self.failIf(x509.has_expired())
1751
      self.assertEqual(x509.get_issuer().CN, common_name)
1752
      self.assertEqual(x509.get_subject().CN, common_name)
1753
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1754

    
1755
  def testLegacy(self):
1756
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1757

    
1758
    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1759

    
1760
    cert1 = utils.ReadFile(cert1_filename)
1761

    
1762
    self.assert_(self._checkRsaPrivateKey(cert1))
1763
    self.assert_(self._checkCertificate(cert1))
1764

    
1765

    
1766
class TestPathJoin(unittest.TestCase):
1767
  """Testing case for PathJoin"""
1768

    
1769
  def testBasicItems(self):
1770
    mlist = ["/a", "b", "c"]
1771
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1772

    
1773
  def testNonAbsPrefix(self):
1774
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1775

    
1776
  def testBackTrack(self):
1777
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1778

    
1779
  def testMultiAbs(self):
1780
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1781

    
1782

    
1783
class TestValidateServiceName(unittest.TestCase):
1784
  def testValid(self):
1785
    testnames = [
1786
      0, 1, 2, 3, 1024, 65000, 65534, 65535,
1787
      "ganeti",
1788
      "gnt-masterd",
1789
      "HELLO_WORLD_SVC",
1790
      "hello.world.1",
1791
      "0", "80", "1111", "65535",
1792
      ]
1793

    
1794
    for name in testnames:
1795
      self.assertEqual(utils.ValidateServiceName(name), name)
1796

    
1797
  def testInvalid(self):
1798
    testnames = [
1799
      -15756, -1, 65536, 133428083,
1800
      "", "Hello World!", "!", "'", "\"", "\t", "\n", "`",
1801
      "-8546", "-1", "65536",
1802
      (129 * "A"),
1803
      ]
1804

    
1805
    for name in testnames:
1806
      self.assertRaises(errors.OpPrereqError, utils.ValidateServiceName, name)
1807

    
1808

    
1809
class TestParseAsn1Generalizedtime(unittest.TestCase):
1810
  def test(self):
1811
    # UTC
1812
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1813
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1814
                     1266860512)
1815
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1816
                     (2**31) - 1)
1817

    
1818
    # With offset
1819
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1820
                     1266860512)
1821
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1822
                     1266931012)
1823
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1824
                     1266931088)
1825
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1826
                     1266931295)
1827
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1828
                     3600)
1829

    
1830
    # Leap seconds are not supported by datetime.datetime
1831
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1832
                      "19841231235960+0000")
1833
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1834
                      "19920630235960+0000")
1835

    
1836
    # Errors
1837
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1838
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1839
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1840
                      "20100222174152")
1841
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1842
                      "Mon Feb 22 17:47:02 UTC 2010")
1843
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1844
                      "2010-02-22 17:42:02")
1845

    
1846

    
1847
class TestGetX509CertValidity(testutils.GanetiTestCase):
1848
  def setUp(self):
1849
    testutils.GanetiTestCase.setUp(self)
1850

    
1851
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1852

    
1853
    # Test whether we have pyOpenSSL 0.7 or above
1854
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1855

    
1856
    if not self.pyopenssl0_7:
1857
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1858
                    " function correctly")
1859

    
1860
  def _LoadCert(self, name):
1861
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1862
                                           self._ReadTestData(name))
1863

    
1864
  def test(self):
1865
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1866
    if self.pyopenssl0_7:
1867
      self.assertEqual(validity, (1266919967, 1267524767))
1868
    else:
1869
      self.assertEqual(validity, (None, None))
1870

    
1871

    
1872
class TestSignX509Certificate(unittest.TestCase):
1873
  KEY = "My private key!"
1874
  KEY_OTHER = "Another key"
1875

    
1876
  def test(self):
1877
    # Generate certificate valid for 5 minutes
1878
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
1879

    
1880
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1881
                                           cert_pem)
1882

    
1883
    # No signature at all
1884
    self.assertRaises(errors.GenericError,
1885
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
1886

    
1887
    # Invalid input
1888
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1889
                      "", self.KEY)
1890
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1891
                      "X-Ganeti-Signature: \n", self.KEY)
1892
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1893
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
1894
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1895
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
1896
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1897
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
1898

    
1899
    # Invalid salt
1900
    for salt in list("-_@$,:;/\\ \t\n"):
1901
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
1902
                        cert_pem, self.KEY, "foo%sbar" % salt)
1903

    
1904
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
1905
                 utils.GenerateSecret(numbytes=4),
1906
                 utils.GenerateSecret(numbytes=16),
1907
                 "{123:456}".encode("hex")]:
1908
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
1909

    
1910
      self._Check(cert, salt, signed_pem)
1911

    
1912
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
1913
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
1914
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
1915
                               "lines----\n------ at\nthe end!"))
1916

    
1917
  def _Check(self, cert, salt, pem):
1918
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
1919
    self.assertEqual(salt, salt2)
1920
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
1921

    
1922
    # Other key
1923
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1924
                      pem, self.KEY_OTHER)
1925

    
1926

    
1927
class TestMakedirs(unittest.TestCase):
1928
  def setUp(self):
1929
    self.tmpdir = tempfile.mkdtemp()
1930

    
1931
  def tearDown(self):
1932
    shutil.rmtree(self.tmpdir)
1933

    
1934
  def testNonExisting(self):
1935
    path = PathJoin(self.tmpdir, "foo")
1936
    utils.Makedirs(path)
1937
    self.assert_(os.path.isdir(path))
1938

    
1939
  def testExisting(self):
1940
    path = PathJoin(self.tmpdir, "foo")
1941
    os.mkdir(path)
1942
    utils.Makedirs(path)
1943
    self.assert_(os.path.isdir(path))
1944

    
1945
  def testRecursiveNonExisting(self):
1946
    path = PathJoin(self.tmpdir, "foo/bar/baz")
1947
    utils.Makedirs(path)
1948
    self.assert_(os.path.isdir(path))
1949

    
1950
  def testRecursiveExisting(self):
1951
    path = PathJoin(self.tmpdir, "B/moo/xyz")
1952
    self.assertFalse(os.path.exists(path))
1953
    os.mkdir(PathJoin(self.tmpdir, "B"))
1954
    utils.Makedirs(path)
1955
    self.assert_(os.path.isdir(path))
1956

    
1957

    
1958
class TestRetry(testutils.GanetiTestCase):
1959
  def setUp(self):
1960
    testutils.GanetiTestCase.setUp(self)
1961
    self.retries = 0
1962

    
1963
  @staticmethod
1964
  def _RaiseRetryAgain():
1965
    raise utils.RetryAgain()
1966

    
1967
  @staticmethod
1968
  def _RaiseRetryAgainWithArg(args):
1969
    raise utils.RetryAgain(*args)
1970

    
1971
  def _WrongNestedLoop(self):
1972
    return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
1973

    
1974
  def _RetryAndSucceed(self, retries):
1975
    if self.retries < retries:
1976
      self.retries += 1
1977
      raise utils.RetryAgain()
1978
    else:
1979
      return True
1980

    
1981
  def testRaiseTimeout(self):
1982
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1983
                          self._RaiseRetryAgain, 0.01, 0.02)
1984
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1985
                          self._RetryAndSucceed, 0.01, 0, args=[1])
1986
    self.failUnlessEqual(self.retries, 1)
1987

    
1988
  def testComplete(self):
1989
    self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
1990
    self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
1991
                         True)
1992
    self.failUnlessEqual(self.retries, 2)
1993

    
1994
  def testNestedLoop(self):
1995
    try:
1996
      self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
1997
                            self._WrongNestedLoop, 0, 1)
1998
    except utils.RetryTimeout:
1999
      self.fail("Didn't detect inner loop's exception")
2000

    
2001
  def testTimeoutArgument(self):
2002
    retry_arg="my_important_debugging_message"
2003
    try:
2004
      utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
2005
    except utils.RetryTimeout, err:
2006
      self.failUnlessEqual(err.args, (retry_arg, ))
2007
    else:
2008
      self.fail("Expected timeout didn't happen")
2009

    
2010
  def testRaiseInnerWithExc(self):
2011
    retry_arg="my_important_debugging_message"
2012
    try:
2013
      try:
2014
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2015
                    args=[[errors.GenericError(retry_arg, retry_arg)]])
2016
      except utils.RetryTimeout, err:
2017
        err.RaiseInner()
2018
      else:
2019
        self.fail("Expected timeout didn't happen")
2020
    except errors.GenericError, err:
2021
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2022
    else:
2023
      self.fail("Expected GenericError didn't happen")
2024

    
2025
  def testRaiseInnerWithMsg(self):
2026
    retry_arg="my_important_debugging_message"
2027
    try:
2028
      try:
2029
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2030
                    args=[[retry_arg, retry_arg]])
2031
      except utils.RetryTimeout, err:
2032
        err.RaiseInner()
2033
      else:
2034
        self.fail("Expected timeout didn't happen")
2035
    except utils.RetryTimeout, err:
2036
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2037
    else:
2038
      self.fail("Expected RetryTimeout didn't happen")
2039

    
2040

    
2041
class TestLineSplitter(unittest.TestCase):
2042
  def test(self):
2043
    lines = []
2044
    ls = utils.LineSplitter(lines.append)
2045
    ls.write("Hello World\n")
2046
    self.assertEqual(lines, [])
2047
    ls.write("Foo\n Bar\r\n ")
2048
    ls.write("Baz")
2049
    ls.write("Moo")
2050
    self.assertEqual(lines, [])
2051
    ls.flush()
2052
    self.assertEqual(lines, ["Hello World", "Foo", " Bar"])
2053
    ls.close()
2054
    self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"])
2055

    
2056
  def _testExtra(self, line, all_lines, p1, p2):
2057
    self.assertEqual(p1, 999)
2058
    self.assertEqual(p2, "extra")
2059
    all_lines.append(line)
2060

    
2061
  def testExtraArgsNoFlush(self):
2062
    lines = []
2063
    ls = utils.LineSplitter(self._testExtra, lines, 999, "extra")
2064
    ls.write("\n\nHello World\n")
2065
    ls.write("Foo\n Bar\r\n ")
2066
    ls.write("")
2067
    ls.write("Baz")
2068
    ls.write("Moo\n\nx\n")
2069
    self.assertEqual(lines, [])
2070
    ls.close()
2071
    self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo",
2072
                             "", "x"])
2073

    
2074

    
2075
class TestReadLockedPidFile(unittest.TestCase):
2076
  def setUp(self):
2077
    self.tmpdir = tempfile.mkdtemp()
2078

    
2079
  def tearDown(self):
2080
    shutil.rmtree(self.tmpdir)
2081

    
2082
  def testNonExistent(self):
2083
    path = PathJoin(self.tmpdir, "nonexist")
2084
    self.assert_(utils.ReadLockedPidFile(path) is None)
2085

    
2086
  def testUnlocked(self):
2087
    path = PathJoin(self.tmpdir, "pid")
2088
    utils.WriteFile(path, data="123")
2089
    self.assert_(utils.ReadLockedPidFile(path) is None)
2090

    
2091
  def testLocked(self):
2092
    path = PathJoin(self.tmpdir, "pid")
2093
    utils.WriteFile(path, data="123")
2094

    
2095
    fl = utils.FileLock.Open(path)
2096
    try:
2097
      fl.Exclusive(blocking=True)
2098

    
2099
      self.assertEqual(utils.ReadLockedPidFile(path), 123)
2100
    finally:
2101
      fl.Close()
2102

    
2103
    self.assert_(utils.ReadLockedPidFile(path) is None)
2104

    
2105
  def testError(self):
2106
    path = PathJoin(self.tmpdir, "foobar", "pid")
2107
    utils.WriteFile(PathJoin(self.tmpdir, "foobar"), data="")
2108
    # open(2) should return ENOTDIR
2109
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
2110

    
2111

    
2112
class TestCertVerification(testutils.GanetiTestCase):
2113
  def setUp(self):
2114
    testutils.GanetiTestCase.setUp(self)
2115

    
2116
    self.tmpdir = tempfile.mkdtemp()
2117

    
2118
  def tearDown(self):
2119
    shutil.rmtree(self.tmpdir)
2120

    
2121
  def testVerifyCertificate(self):
2122
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
2123
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
2124
                                           cert_pem)
2125

    
2126
    # Not checking return value as this certificate is expired
2127
    utils.VerifyX509Certificate(cert, 30, 7)
2128

    
2129

    
2130
class TestVerifyCertificateInner(unittest.TestCase):
2131
  def test(self):
2132
    vci = utils._VerifyCertificateInner
2133

    
2134
    # Valid
2135
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
2136
                     (None, None))
2137

    
2138
    # Not yet valid
2139
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
2140
    self.assertEqual(errcode, utils.CERT_WARNING)
2141

    
2142
    # Expiring soon
2143
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
2144
    self.assertEqual(errcode, utils.CERT_ERROR)
2145

    
2146
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
2147
    self.assertEqual(errcode, utils.CERT_WARNING)
2148

    
2149
    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
2150
    self.assertEqual(errcode, None)
2151

    
2152
    # Expired
2153
    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
2154
    self.assertEqual(errcode, utils.CERT_ERROR)
2155

    
2156
    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
2157
    self.assertEqual(errcode, utils.CERT_ERROR)
2158

    
2159
    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
2160
    self.assertEqual(errcode, utils.CERT_ERROR)
2161

    
2162
    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
2163
    self.assertEqual(errcode, utils.CERT_ERROR)
2164

    
2165

    
2166
class TestHmacFunctions(unittest.TestCase):
2167
  # Digests can be checked with "openssl sha1 -hmac $key"
2168
  def testSha1Hmac(self):
2169
    self.assertEqual(utils.Sha1Hmac("", ""),
2170
                     "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d")
2171
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World"),
2172
                     "ef4f3bda82212ecb2f7ce868888a19092481f1fd")
2173
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", ""),
2174
                     "f904c2476527c6d3e6609ab683c66fa0652cb1dc")
2175

    
2176
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
2177
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", longtext),
2178
                     "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54")
2179

    
2180
  def testSha1HmacSalt(self):
2181
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc0"),
2182
                     "4999bf342470eadb11dfcd24ca5680cf9fd7cdce")
2183
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc9"),
2184
                     "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8")
2185
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World", salt="xyz0"),
2186
                     "7f264f8114c9066afc9bb7636e1786d996d3cc0d")
2187

    
2188
  def testVerifySha1Hmac(self):
2189
    self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b"
2190
                                               "7d64b71fb76370690e1d")))
2191
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2192
                                      ("f904c2476527c6d3e660"
2193
                                       "9ab683c66fa0652cb1dc")))
2194

    
2195
    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
2196
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", digest))
2197
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2198
                                      digest.lower()))
2199
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2200
                                      digest.upper()))
2201
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2202
                                      digest.title()))
2203

    
2204
  def testVerifySha1HmacSalt(self):
2205
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2206
                                      ("17a4adc34d69c0d367d4"
2207
                                       "ffbef96fd41d4df7a6e8"),
2208
                                      salt="abc9"))
2209
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2210
                                      ("7f264f8114c9066afc9b"
2211
                                       "b7636e1786d996d3cc0d"),
2212
                                      salt="xyz0"))
2213

    
2214

    
2215
class TestIgnoreSignals(unittest.TestCase):
2216
  """Test the IgnoreSignals decorator"""
2217

    
2218
  @staticmethod
2219
  def _Raise(exception):
2220
    raise exception
2221

    
2222
  @staticmethod
2223
  def _Return(rval):
2224
    return rval
2225

    
2226
  def testIgnoreSignals(self):
2227
    sock_err_intr = socket.error(errno.EINTR, "Message")
2228
    sock_err_inval = socket.error(errno.EINVAL, "Message")
2229

    
2230
    env_err_intr = EnvironmentError(errno.EINTR, "Message")
2231
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")
2232

    
2233
    self.assertRaises(socket.error, self._Raise, sock_err_intr)
2234
    self.assertRaises(socket.error, self._Raise, sock_err_inval)
2235
    self.assertRaises(EnvironmentError, self._Raise, env_err_intr)
2236
    self.assertRaises(EnvironmentError, self._Raise, env_err_inval)
2237

    
2238
    self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None)
2239
    self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None)
2240
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
2241
                      sock_err_inval)
2242
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
2243
                      env_err_inval)
2244

    
2245
    self.assertEquals(utils.IgnoreSignals(self._Return, True), True)
2246
    self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33)
2247

    
2248

    
2249
class TestEnsureDirs(unittest.TestCase):
2250
  """Tests for EnsureDirs"""
2251

    
2252
  def setUp(self):
2253
    self.dir = tempfile.mkdtemp()
2254
    self.old_umask = os.umask(0777)
2255

    
2256
  def testEnsureDirs(self):
2257
    utils.EnsureDirs([
2258
        (PathJoin(self.dir, "foo"), 0777),
2259
        (PathJoin(self.dir, "bar"), 0000),
2260
        ])
2261
    self.assertEquals(os.stat(PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2262
    self.assertEquals(os.stat(PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2263

    
2264
  def tearDown(self):
2265
    os.rmdir(PathJoin(self.dir, "foo"))
2266
    os.rmdir(PathJoin(self.dir, "bar"))
2267
    os.rmdir(self.dir)
2268
    os.umask(self.old_umask)
2269

    
2270

    
2271
class TestFormatSeconds(unittest.TestCase):
2272
  def test(self):
2273
    self.assertEqual(utils.FormatSeconds(1), "1s")
2274
    self.assertEqual(utils.FormatSeconds(3600), "1h 0m 0s")
2275
    self.assertEqual(utils.FormatSeconds(3599), "59m 59s")
2276
    self.assertEqual(utils.FormatSeconds(7200), "2h 0m 0s")
2277
    self.assertEqual(utils.FormatSeconds(7201), "2h 0m 1s")
2278
    self.assertEqual(utils.FormatSeconds(7281), "2h 1m 21s")
2279
    self.assertEqual(utils.FormatSeconds(29119), "8h 5m 19s")
2280
    self.assertEqual(utils.FormatSeconds(19431228), "224d 21h 33m 48s")
2281
    self.assertEqual(utils.FormatSeconds(-1), "-1s")
2282
    self.assertEqual(utils.FormatSeconds(-282), "-282s")
2283
    self.assertEqual(utils.FormatSeconds(-29119), "-29119s")
2284

    
2285
  def testFloat(self):
2286
    self.assertEqual(utils.FormatSeconds(1.3), "1s")
2287
    self.assertEqual(utils.FormatSeconds(1.9), "2s")
2288
    self.assertEqual(utils.FormatSeconds(3912.12311), "1h 5m 12s")
2289
    self.assertEqual(utils.FormatSeconds(3912.8), "1h 5m 13s")
2290

    
2291

    
2292
class TestIgnoreProcessNotFound(unittest.TestCase):
2293
  @staticmethod
2294
  def _WritePid(fd):
2295
    os.write(fd, str(os.getpid()))
2296
    os.close(fd)
2297
    return True
2298

    
2299
  def test(self):
2300
    (pid_read_fd, pid_write_fd) = os.pipe()
2301

    
2302
    # Start short-lived process which writes its PID to pipe
2303
    self.assert_(utils.RunInSeparateProcess(self._WritePid, pid_write_fd))
2304
    os.close(pid_write_fd)
2305

    
2306
    # Read PID from pipe
2307
    pid = int(os.read(pid_read_fd, 1024))
2308
    os.close(pid_read_fd)
2309

    
2310
    # Try to send signal to process which exited recently
2311
    self.assertFalse(utils.IgnoreProcessNotFound(os.kill, pid, 0))
2312

    
2313

    
2314
class TestShellWriter(unittest.TestCase):
2315
  def test(self):
2316
    buf = StringIO()
2317
    sw = utils.ShellWriter(buf)
2318
    sw.Write("#!/bin/bash")
2319
    sw.Write("if true; then")
2320
    sw.IncIndent()
2321
    try:
2322
      sw.Write("echo true")
2323

    
2324
      sw.Write("for i in 1 2 3")
2325
      sw.Write("do")
2326
      sw.IncIndent()
2327
      try:
2328
        self.assertEqual(sw._indent, 2)
2329
        sw.Write("date")
2330
      finally:
2331
        sw.DecIndent()
2332
      sw.Write("done")
2333
    finally:
2334
      sw.DecIndent()
2335
    sw.Write("echo %s", utils.ShellQuote("Hello World"))
2336
    sw.Write("exit 0")
2337

    
2338
    self.assertEqual(sw._indent, 0)
2339

    
2340
    output = buf.getvalue()
2341

    
2342
    self.assert_(output.endswith("\n"))
2343

    
2344
    lines = output.splitlines()
2345
    self.assertEqual(len(lines), 9)
2346
    self.assertEqual(lines[0], "#!/bin/bash")
2347
    self.assert_(re.match(r"^\s+date$", lines[5]))
2348
    self.assertEqual(lines[7], "echo 'Hello World'")
2349

    
2350
  def testEmpty(self):
2351
    buf = StringIO()
2352
    sw = utils.ShellWriter(buf)
2353
    sw = None
2354
    self.assertEqual(buf.getvalue(), "")
2355

    
2356

    
2357
class TestCommaJoin(unittest.TestCase):
2358
  def test(self):
2359
    self.assertEqual(utils.CommaJoin([]), "")
2360
    self.assertEqual(utils.CommaJoin([1, 2, 3]), "1, 2, 3")
2361
    self.assertEqual(utils.CommaJoin(["Hello"]), "Hello")
2362
    self.assertEqual(utils.CommaJoin(["Hello", "World"]), "Hello, World")
2363
    self.assertEqual(utils.CommaJoin(["Hello", "World", 99]),
2364
                     "Hello, World, 99")
2365

    
2366

    
2367
class TestFindMatch(unittest.TestCase):
2368
  def test(self):
2369
    data = {
2370
      "aaaa": "Four A",
2371
      "bb": {"Two B": True},
2372
      re.compile(r"^x(foo|bar|bazX)([0-9]+)$"): (1, 2, 3),
2373
      }
2374

    
2375
    self.assertEqual(utils.FindMatch(data, "aaaa"), ("Four A", []))
2376
    self.assertEqual(utils.FindMatch(data, "bb"), ({"Two B": True}, []))
2377

    
2378
    for i in ["foo", "bar", "bazX"]:
2379
      for j in range(1, 100, 7):
2380
        self.assertEqual(utils.FindMatch(data, "x%s%s" % (i, j)),
2381
                         ((1, 2, 3), [i, str(j)]))
2382

    
2383
  def testNoMatch(self):
2384
    self.assert_(utils.FindMatch({}, "") is None)
2385
    self.assert_(utils.FindMatch({}, "foo") is None)
2386
    self.assert_(utils.FindMatch({}, 1234) is None)
2387

    
2388
    data = {
2389
      "X": "Hello World",
2390
      re.compile("^(something)$"): "Hello World",
2391
      }
2392

    
2393
    self.assert_(utils.FindMatch(data, "") is None)
2394
    self.assert_(utils.FindMatch(data, "Hello World") is None)
2395

    
2396

    
2397
class TestFileID(testutils.GanetiTestCase):
2398
  def testEquality(self):
2399
    name = self._CreateTempFile()
2400
    oldi = utils.GetFileID(path=name)
2401
    self.failUnless(utils.VerifyFileID(oldi, oldi))
2402

    
2403
  def testUpdate(self):
2404
    name = self._CreateTempFile()
2405
    oldi = utils.GetFileID(path=name)
2406
    os.utime(name, None)
2407
    fd = os.open(name, os.O_RDWR)
2408
    try:
2409
      newi = utils.GetFileID(fd=fd)
2410
      self.failUnless(utils.VerifyFileID(oldi, newi))
2411
      self.failUnless(utils.VerifyFileID(newi, oldi))
2412
    finally:
2413
      os.close(fd)
2414

    
2415
  def testWriteFile(self):
2416
    name = self._CreateTempFile()
2417
    oldi = utils.GetFileID(path=name)
2418
    mtime = oldi[2]
2419
    os.utime(name, (mtime + 10, mtime + 10))
2420
    self.assertRaises(errors.LockError, utils.SafeWriteFile, name,
2421
                      oldi, data="")
2422
    os.utime(name, (mtime - 10, mtime - 10))
2423
    utils.SafeWriteFile(name, oldi, data="")
2424
    oldi = utils.GetFileID(path=name)
2425
    mtime = oldi[2]
2426
    os.utime(name, (mtime + 10, mtime + 10))
2427
    # this doesn't raise, since we passed None
2428
    utils.SafeWriteFile(name, None, data="")
2429

    
2430

    
2431
class TimeMock:
2432
  def __init__(self, values):
2433
    self.values = values
2434

    
2435
  def __call__(self):
2436
    return self.values.pop(0)
2437

    
2438

    
2439
class TestRunningTimeout(unittest.TestCase):
2440
  def setUp(self):
2441
    self.time_fn = TimeMock([0.0, 0.3, 4.6, 6.5])
2442

    
2443
  def testRemainingFloat(self):
2444
    timeout = utils.RunningTimeout(5.0, True, _time_fn=self.time_fn)
2445
    self.assertAlmostEqual(timeout.Remaining(), 4.7)
2446
    self.assertAlmostEqual(timeout.Remaining(), 0.4)
2447
    self.assertAlmostEqual(timeout.Remaining(), -1.5)
2448

    
2449
  def testRemaining(self):
2450
    self.time_fn = TimeMock([0, 2, 4, 5, 6])
2451
    timeout = utils.RunningTimeout(5, True, _time_fn=self.time_fn)
2452
    self.assertEqual(timeout.Remaining(), 3)
2453
    self.assertEqual(timeout.Remaining(), 1)
2454
    self.assertEqual(timeout.Remaining(), 0)
2455
    self.assertEqual(timeout.Remaining(), -1)
2456

    
2457
  def testRemainingNonNegative(self):
2458
    timeout = utils.RunningTimeout(5.0, False, _time_fn=self.time_fn)
2459
    self.assertAlmostEqual(timeout.Remaining(), 4.7)
2460
    self.assertAlmostEqual(timeout.Remaining(), 0.4)
2461
    self.assertEqual(timeout.Remaining(), 0.0)
2462

    
2463
  def testNegativeTimeout(self):
2464
    self.assertRaises(ValueError, utils.RunningTimeout, -1.0, True)
2465

    
2466

    
2467
if __name__ == '__main__':
2468
  testutils.GanetiTestProgram()