Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ 7f81e2b9

History | View | Annotate | Download (82.7 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import distutils.version
25
import errno
26
import fcntl
27
import glob
28
import os
29
import os.path
30
import re
31
import shutil
32
import signal
33
import socket
34
import stat
35
import string
36
import tempfile
37
import time
38
import unittest
39
import warnings
40
import OpenSSL
41
from cStringIO import StringIO
42

    
43
import testutils
44
from ganeti import constants
45
from ganeti import compat
46
from ganeti import utils
47
from ganeti import errors
48
from ganeti.utils import RunCmd, RemoveFile, MatchNameComponent, FormatUnit, \
49
     ParseUnit, ShellQuote, ShellQuoteArgs, ListVisibleFiles, FirstFree, \
50
     TailFile, SafeEncode, FormatTime, UnescapeAndSplit, RunParts, PathJoin, \
51
     ReadOneLineFile, SetEtcHostsEntry, RemoveEtcHostsEntry
52

    
53

    
54
class TestIsProcessAlive(unittest.TestCase):
55
  """Testing case for IsProcessAlive"""
56

    
57
  def testExists(self):
58
    mypid = os.getpid()
59
    self.assert_(utils.IsProcessAlive(mypid), "can't find myself running")
60

    
61
  def testNotExisting(self):
62
    pid_non_existing = os.fork()
63
    if pid_non_existing == 0:
64
      os._exit(0)
65
    elif pid_non_existing < 0:
66
      raise SystemError("can't fork")
67
    os.waitpid(pid_non_existing, 0)
68
    self.assertFalse(utils.IsProcessAlive(pid_non_existing),
69
                     "nonexisting process detected")
70

    
71

    
72
class TestGetProcStatusPath(unittest.TestCase):
73
  def test(self):
74
    self.assert_("/1234/" in utils._GetProcStatusPath(1234))
75
    self.assertNotEqual(utils._GetProcStatusPath(1),
76
                        utils._GetProcStatusPath(2))
77

    
78

    
79
class TestIsProcessHandlingSignal(unittest.TestCase):
80
  def setUp(self):
81
    self.tmpdir = tempfile.mkdtemp()
82

    
83
  def tearDown(self):
84
    shutil.rmtree(self.tmpdir)
85

    
86
  def testParseSigsetT(self):
87
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
88
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
89
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
90
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
91
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
92
                     set([2, 10, 32, 33]))
93
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
94
                     set([2, 32, 33]))
95
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
96
                     set([2, 28, 32, 33]))
97
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
98
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
99
                          24, 25, 26, 28, 31]))
100
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))
101

    
102
  def testGetProcStatusField(self):
103
    for field in ["SigCgt", "Name", "FDSize"]:
104
      for value in ["", "0", "cat", "  1234 KB"]:
105
        pstatus = "\n".join([
106
          "VmPeak: 999 kB",
107
          "%s: %s" % (field, value),
108
          "TracerPid: 0",
109
          ])
110
        result = utils._GetProcStatusField(pstatus, field)
111
        self.assertEqual(result, value.strip())
112

    
113
  def test(self):
114
    sp = PathJoin(self.tmpdir, "status")
115

    
116
    utils.WriteFile(sp, data="\n".join([
117
      "Name:   bash",
118
      "State:  S (sleeping)",
119
      "SleepAVG:       98%",
120
      "Pid:    22250",
121
      "PPid:   10858",
122
      "TracerPid:      0",
123
      "SigBlk: 0000000000010000",
124
      "SigIgn: 0000000000384004",
125
      "SigCgt: 000000004b813efb",
126
      "CapEff: 0000000000000000",
127
      ]))
128

    
129
    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
130

    
131
  def testNoSigCgt(self):
132
    sp = PathJoin(self.tmpdir, "status")
133

    
134
    utils.WriteFile(sp, data="\n".join([
135
      "Name:   bash",
136
      ]))
137

    
138
    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
139
                      1234, 10, status_path=sp)
140

    
141
  def testNoSuchFile(self):
142
    sp = PathJoin(self.tmpdir, "notexist")
143

    
144
    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
145

    
146
  @staticmethod
147
  def _TestRealProcess():
148
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
149
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
150
      raise Exception("SIGUSR1 is handled when it should not be")
151

    
152
    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
153
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
154
      raise Exception("SIGUSR1 is not handled when it should be")
155

    
156
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
157
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
158
      raise Exception("SIGUSR1 is not handled when it should be")
159

    
160
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
161
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
162
      raise Exception("SIGUSR1 is handled when it should not be")
163

    
164
    return True
165

    
166
  def testRealProcess(self):
167
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
168

    
169

    
170
class TestPidFileFunctions(unittest.TestCase):
171
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
172

    
173
  def setUp(self):
174
    self.dir = tempfile.mkdtemp()
175
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
176
    utils.DaemonPidFileName = self.f_dpn
177

    
178
  def testPidFileFunctions(self):
179
    pid_file = self.f_dpn('test')
180
    fd = utils.WritePidFile(self.f_dpn('test'))
181
    self.failUnless(os.path.exists(pid_file),
182
                    "PID file should have been created")
183
    read_pid = utils.ReadPidFile(pid_file)
184
    self.failUnlessEqual(read_pid, os.getpid())
185
    self.failUnless(utils.IsProcessAlive(read_pid))
186
    self.failUnlessRaises(errors.LockError, utils.WritePidFile,
187
                          self.f_dpn('test'))
188
    os.close(fd)
189
    utils.RemovePidFile('test')
190
    self.failIf(os.path.exists(pid_file),
191
                "PID file should not exist anymore")
192
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
193
                         "ReadPidFile should return 0 for missing pid file")
194
    fh = open(pid_file, "w")
195
    fh.write("blah\n")
196
    fh.close()
197
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
198
                         "ReadPidFile should return 0 for invalid pid file")
199
    # but now, even with the file existing, we should be able to lock it
200
    fd = utils.WritePidFile(self.f_dpn('test'))
201
    os.close(fd)
202
    utils.RemovePidFile('test')
203
    self.failIf(os.path.exists(pid_file),
204
                "PID file should not exist anymore")
205

    
206
  def testKill(self):
207
    pid_file = self.f_dpn('child')
208
    r_fd, w_fd = os.pipe()
209
    new_pid = os.fork()
210
    if new_pid == 0: #child
211
      utils.WritePidFile(self.f_dpn('child'))
212
      os.write(w_fd, 'a')
213
      signal.pause()
214
      os._exit(0)
215
      return
216
    # else we are in the parent
217
    # wait until the child has written the pid file
218
    os.read(r_fd, 1)
219
    read_pid = utils.ReadPidFile(pid_file)
220
    self.failUnlessEqual(read_pid, new_pid)
221
    self.failUnless(utils.IsProcessAlive(new_pid))
222
    utils.KillProcess(new_pid, waitpid=True)
223
    self.failIf(utils.IsProcessAlive(new_pid))
224
    utils.RemovePidFile('child')
225
    self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)
226

    
227
  def tearDown(self):
228
    for name in os.listdir(self.dir):
229
      os.unlink(os.path.join(self.dir, name))
230
    os.rmdir(self.dir)
231

    
232

    
233
class TestRunCmd(testutils.GanetiTestCase):
234
  """Testing case for the RunCmd function"""
235

    
236
  def setUp(self):
237
    testutils.GanetiTestCase.setUp(self)
238
    self.magic = time.ctime() + " ganeti test"
239
    self.fname = self._CreateTempFile()
240
    self.fifo_tmpdir = tempfile.mkdtemp()
241
    self.fifo_file = os.path.join(self.fifo_tmpdir, "ganeti_test_fifo")
242
    os.mkfifo(self.fifo_file)
243

    
244
  def tearDown(self):
245
    shutil.rmtree(self.fifo_tmpdir)
246

    
247
  def testOk(self):
248
    """Test successful exit code"""
249
    result = RunCmd("/bin/sh -c 'exit 0'")
250
    self.assertEqual(result.exit_code, 0)
251
    self.assertEqual(result.output, "")
252

    
253
  def testFail(self):
254
    """Test fail exit code"""
255
    result = RunCmd("/bin/sh -c 'exit 1'")
256
    self.assertEqual(result.exit_code, 1)
257
    self.assertEqual(result.output, "")
258

    
259
  def testStdout(self):
260
    """Test standard output"""
261
    cmd = 'echo -n "%s"' % self.magic
262
    result = RunCmd("/bin/sh -c '%s'" % cmd)
263
    self.assertEqual(result.stdout, self.magic)
264
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
265
    self.assertEqual(result.output, "")
266
    self.assertFileContent(self.fname, self.magic)
267

    
268
  def testStderr(self):
269
    """Test standard error"""
270
    cmd = 'echo -n "%s"' % self.magic
271
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
272
    self.assertEqual(result.stderr, self.magic)
273
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
274
    self.assertEqual(result.output, "")
275
    self.assertFileContent(self.fname, self.magic)
276

    
277
  def testCombined(self):
278
    """Test combined output"""
279
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
280
    expected = "A" + self.magic + "B" + self.magic
281
    result = RunCmd("/bin/sh -c '%s'" % cmd)
282
    self.assertEqual(result.output, expected)
283
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
284
    self.assertEqual(result.output, "")
285
    self.assertFileContent(self.fname, expected)
286

    
287
  def testSignal(self):
288
    """Test signal"""
289
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
290
    self.assertEqual(result.signal, 15)
291
    self.assertEqual(result.output, "")
292

    
293
  def testTimeoutClean(self):
294
    cmd = "trap 'exit 0' TERM; read < %s" % self.fifo_file
295
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
296
    self.assertEqual(result.exit_code, 0)
297

    
298
  def testTimeoutKill(self):
299
    cmd = "trap '' TERM; read < %s" % self.fifo_file
300
    timeout = 0.2
301
    strcmd = utils.ShellQuoteArgs(["/bin/sh", "-c", cmd])
302
    out, err, status, ta = utils._RunCmdPipe(strcmd, {}, True, "/", False,
303
                                             timeout, _linger_timeout=0.2)
304
    self.assert_(status < 0)
305
    self.assertEqual(-status, signal.SIGKILL)
306

    
307
  def testTimeoutOutputAfterTerm(self):
308
    cmd = "trap 'echo sigtermed; exit 1' TERM; read < %s" % self.fifo_file
309
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
310
    self.assert_(result.failed)
311
    self.assertEqual(result.stdout, "sigtermed\n")
312

    
313
  def testListRun(self):
314
    """Test list runs"""
315
    result = RunCmd(["true"])
316
    self.assertEqual(result.signal, None)
317
    self.assertEqual(result.exit_code, 0)
318
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
319
    self.assertEqual(result.signal, None)
320
    self.assertEqual(result.exit_code, 1)
321
    result = RunCmd(["echo", "-n", self.magic])
322
    self.assertEqual(result.signal, None)
323
    self.assertEqual(result.exit_code, 0)
324
    self.assertEqual(result.stdout, self.magic)
325

    
326
  def testFileEmptyOutput(self):
327
    """Test file output"""
328
    result = RunCmd(["true"], output=self.fname)
329
    self.assertEqual(result.signal, None)
330
    self.assertEqual(result.exit_code, 0)
331
    self.assertFileContent(self.fname, "")
332

    
333
  def testLang(self):
334
    """Test locale environment"""
335
    old_env = os.environ.copy()
336
    try:
337
      os.environ["LANG"] = "en_US.UTF-8"
338
      os.environ["LC_ALL"] = "en_US.UTF-8"
339
      result = RunCmd(["locale"])
340
      for line in result.output.splitlines():
341
        key, value = line.split("=", 1)
342
        # Ignore these variables, they're overridden by LC_ALL
343
        if key == "LANG" or key == "LANGUAGE":
344
          continue
345
        self.failIf(value and value != "C" and value != '"C"',
346
            "Variable %s is set to the invalid value '%s'" % (key, value))
347
    finally:
348
      os.environ = old_env
349

    
350
  def testDefaultCwd(self):
351
    """Test default working directory"""
352
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
353

    
354
  def testCwd(self):
355
    """Test default working directory"""
356
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
357
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
358
    cwd = os.getcwd()
359
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
360

    
361
  def testResetEnv(self):
362
    """Test environment reset functionality"""
363
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
364
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
365
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
366

    
367

    
368
class TestRunParts(unittest.TestCase):
369
  """Testing case for the RunParts function"""
370

    
371
  def setUp(self):
372
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
373

    
374
  def tearDown(self):
375
    shutil.rmtree(self.rundir)
376

    
377
  def testEmpty(self):
378
    """Test on an empty dir"""
379
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
380

    
381
  def testSkipWrongName(self):
382
    """Test that wrong files are skipped"""
383
    fname = os.path.join(self.rundir, "00test.dot")
384
    utils.WriteFile(fname, data="")
385
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
386
    relname = os.path.basename(fname)
387
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
388
                         [(relname, constants.RUNPARTS_SKIP, None)])
389

    
390
  def testSkipNonExec(self):
391
    """Test that non executable files are skipped"""
392
    fname = os.path.join(self.rundir, "00test")
393
    utils.WriteFile(fname, data="")
394
    relname = os.path.basename(fname)
395
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
396
                         [(relname, constants.RUNPARTS_SKIP, None)])
397

    
398
  def testError(self):
399
    """Test error on a broken executable"""
400
    fname = os.path.join(self.rundir, "00test")
401
    utils.WriteFile(fname, data="")
402
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
403
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
404
    self.failUnlessEqual(relname, os.path.basename(fname))
405
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
406
    self.failUnless(error)
407

    
408
  def testSorted(self):
409
    """Test executions are sorted"""
410
    files = []
411
    files.append(os.path.join(self.rundir, "64test"))
412
    files.append(os.path.join(self.rundir, "00test"))
413
    files.append(os.path.join(self.rundir, "42test"))
414

    
415
    for fname in files:
416
      utils.WriteFile(fname, data="")
417

    
418
    results = RunParts(self.rundir, reset_env=True)
419

    
420
    for fname in sorted(files):
421
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
422

    
423
  def testOk(self):
424
    """Test correct execution"""
425
    fname = os.path.join(self.rundir, "00test")
426
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
427
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
428
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
429
    self.failUnlessEqual(relname, os.path.basename(fname))
430
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
431
    self.failUnlessEqual(runresult.stdout, "ciao")
432

    
433
  def testRunFail(self):
434
    """Test correct execution, with run failure"""
435
    fname = os.path.join(self.rundir, "00test")
436
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
437
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
438
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
439
    self.failUnlessEqual(relname, os.path.basename(fname))
440
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
441
    self.failUnlessEqual(runresult.exit_code, 1)
442
    self.failUnless(runresult.failed)
443

    
444
  def testRunMix(self):
445
    files = []
446
    files.append(os.path.join(self.rundir, "00test"))
447
    files.append(os.path.join(self.rundir, "42test"))
448
    files.append(os.path.join(self.rundir, "64test"))
449
    files.append(os.path.join(self.rundir, "99test"))
450

    
451
    files.sort()
452

    
453
    # 1st has errors in execution
454
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
455
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
456

    
457
    # 2nd is skipped
458
    utils.WriteFile(files[1], data="")
459

    
460
    # 3rd cannot execute properly
461
    utils.WriteFile(files[2], data="")
462
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
463

    
464
    # 4th execs
465
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
466
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
467

    
468
    results = RunParts(self.rundir, reset_env=True)
469

    
470
    (relname, status, runresult) = results[0]
471
    self.failUnlessEqual(relname, os.path.basename(files[0]))
472
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
473
    self.failUnlessEqual(runresult.exit_code, 1)
474
    self.failUnless(runresult.failed)
475

    
476
    (relname, status, runresult) = results[1]
477
    self.failUnlessEqual(relname, os.path.basename(files[1]))
478
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
479
    self.failUnlessEqual(runresult, None)
480

    
481
    (relname, status, runresult) = results[2]
482
    self.failUnlessEqual(relname, os.path.basename(files[2]))
483
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
484
    self.failUnless(runresult)
485

    
486
    (relname, status, runresult) = results[3]
487
    self.failUnlessEqual(relname, os.path.basename(files[3]))
488
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
489
    self.failUnlessEqual(runresult.output, "ciao")
490
    self.failUnlessEqual(runresult.exit_code, 0)
491
    self.failUnless(not runresult.failed)
492

    
493

    
494
class TestStartDaemon(testutils.GanetiTestCase):
495
  def setUp(self):
496
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
497
    self.tmpfile = os.path.join(self.tmpdir, "test")
498

    
499
  def tearDown(self):
500
    shutil.rmtree(self.tmpdir)
501

    
502
  def testShell(self):
503
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
504
    self._wait(self.tmpfile, 60.0, "Hello World")
505

    
506
  def testShellOutput(self):
507
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
508
    self._wait(self.tmpfile, 60.0, "Hello World")
509

    
510
  def testNoShellNoOutput(self):
511
    utils.StartDaemon(["pwd"])
512

    
513
  def testNoShellNoOutputTouch(self):
514
    testfile = os.path.join(self.tmpdir, "check")
515
    self.failIf(os.path.exists(testfile))
516
    utils.StartDaemon(["touch", testfile])
517
    self._wait(testfile, 60.0, "")
518

    
519
  def testNoShellOutput(self):
520
    utils.StartDaemon(["pwd"], output=self.tmpfile)
521
    self._wait(self.tmpfile, 60.0, "/")
522

    
523
  def testNoShellOutputCwd(self):
524
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
525
    self._wait(self.tmpfile, 60.0, os.getcwd())
526

    
527
  def testShellEnv(self):
528
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
529
                      env={ "GNT_TEST_VAR": "Hello World", })
530
    self._wait(self.tmpfile, 60.0, "Hello World")
531

    
532
  def testNoShellEnv(self):
533
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
534
                      env={ "GNT_TEST_VAR": "Hello World", })
535
    self._wait(self.tmpfile, 60.0, "Hello World")
536

    
537
  def testOutputFd(self):
538
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
539
    try:
540
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
541
    finally:
542
      os.close(fd)
543
    self._wait(self.tmpfile, 60.0, os.getcwd())
544

    
545
  def testPid(self):
546
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
547
    self._wait(self.tmpfile, 60.0, str(pid))
548

    
549
  def testPidFile(self):
550
    pidfile = os.path.join(self.tmpdir, "pid")
551
    checkfile = os.path.join(self.tmpdir, "abort")
552

    
553
    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
554
                            output=self.tmpfile)
555
    try:
556
      fd = os.open(pidfile, os.O_RDONLY)
557
      try:
558
        # Check file is locked
559
        self.assertRaises(errors.LockError, utils.LockFile, fd)
560

    
561
        pidtext = os.read(fd, 100)
562
      finally:
563
        os.close(fd)
564

    
565
      self.assertEqual(int(pidtext.strip()), pid)
566

    
567
      self.assert_(utils.IsProcessAlive(pid))
568
    finally:
569
      # No matter what happens, kill daemon
570
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
571
      self.failIf(utils.IsProcessAlive(pid))
572

    
573
    self.assertEqual(utils.ReadFile(self.tmpfile), "")
574

    
575
  def _wait(self, path, timeout, expected):
576
    # Due to the asynchronous nature of daemon processes, polling is necessary.
577
    # A timeout makes sure the test doesn't hang forever.
578
    def _CheckFile():
579
      if not (os.path.isfile(path) and
580
              utils.ReadFile(path).strip() == expected):
581
        raise utils.RetryAgain()
582

    
583
    try:
584
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
585
    except utils.RetryTimeout:
586
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
587
                " didn't write the correct output" % timeout)
588

    
589
  def testError(self):
590
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
591
                      ["./does-NOT-EXIST/here/0123456789"])
592
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
593
                      ["./does-NOT-EXIST/here/0123456789"],
594
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
595
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
596
                      ["./does-NOT-EXIST/here/0123456789"],
597
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
598
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
599
                      ["./does-NOT-EXIST/here/0123456789"],
600
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
601

    
602
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
603
    try:
604
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
605
                        ["./does-NOT-EXIST/here/0123456789"],
606
                        output=self.tmpfile, output_fd=fd)
607
    finally:
608
      os.close(fd)
609

    
610

    
611
class TestSetCloseOnExecFlag(unittest.TestCase):
612
  """Tests for SetCloseOnExecFlag"""
613

    
614
  def setUp(self):
615
    self.tmpfile = tempfile.TemporaryFile()
616

    
617
  def testEnable(self):
618
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
619
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
620
                    fcntl.FD_CLOEXEC)
621

    
622
  def testDisable(self):
623
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
624
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
625
                fcntl.FD_CLOEXEC)
626

    
627

    
628
class TestSetNonblockFlag(unittest.TestCase):
629
  def setUp(self):
630
    self.tmpfile = tempfile.TemporaryFile()
631

    
632
  def testEnable(self):
633
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
634
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
635
                    os.O_NONBLOCK)
636

    
637
  def testDisable(self):
638
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
639
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
640
                os.O_NONBLOCK)
641

    
642

    
643
class TestRemoveFile(unittest.TestCase):
644
  """Test case for the RemoveFile function"""
645

    
646
  def setUp(self):
647
    """Create a temp dir and file for each case"""
648
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
649
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
650
    os.close(fd)
651

    
652
  def tearDown(self):
653
    if os.path.exists(self.tmpfile):
654
      os.unlink(self.tmpfile)
655
    os.rmdir(self.tmpdir)
656

    
657
  def testIgnoreDirs(self):
658
    """Test that RemoveFile() ignores directories"""
659
    self.assertEqual(None, RemoveFile(self.tmpdir))
660

    
661
  def testIgnoreNotExisting(self):
662
    """Test that RemoveFile() ignores non-existing files"""
663
    RemoveFile(self.tmpfile)
664
    RemoveFile(self.tmpfile)
665

    
666
  def testRemoveFile(self):
667
    """Test that RemoveFile does remove a file"""
668
    RemoveFile(self.tmpfile)
669
    if os.path.exists(self.tmpfile):
670
      self.fail("File '%s' not removed" % self.tmpfile)
671

    
672
  def testRemoveSymlink(self):
673
    """Test that RemoveFile does remove symlinks"""
674
    symlink = self.tmpdir + "/symlink"
675
    os.symlink("no-such-file", symlink)
676
    RemoveFile(symlink)
677
    if os.path.exists(symlink):
678
      self.fail("File '%s' not removed" % symlink)
679
    os.symlink(self.tmpfile, symlink)
680
    RemoveFile(symlink)
681
    if os.path.exists(symlink):
682
      self.fail("File '%s' not removed" % symlink)
683

    
684

    
685
class TestRename(unittest.TestCase):
686
  """Test case for RenameFile"""
687

    
688
  def setUp(self):
689
    """Create a temporary directory"""
690
    self.tmpdir = tempfile.mkdtemp()
691
    self.tmpfile = os.path.join(self.tmpdir, "test1")
692

    
693
    # Touch the file
694
    open(self.tmpfile, "w").close()
695

    
696
  def tearDown(self):
697
    """Remove temporary directory"""
698
    shutil.rmtree(self.tmpdir)
699

    
700
  def testSimpleRename1(self):
701
    """Simple rename 1"""
702
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
703
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
704

    
705
  def testSimpleRename2(self):
706
    """Simple rename 2"""
707
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
708
                     mkdir=True)
709
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
710

    
711
  def testRenameMkdir(self):
712
    """Rename with mkdir"""
713
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
714
                     mkdir=True)
715
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
716
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
717

    
718
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
719
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
720
                     mkdir=True)
721
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
722
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
723
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
724

    
725

    
726
class TestMatchNameComponent(unittest.TestCase):
727
  """Test case for the MatchNameComponent function"""
728

    
729
  def testEmptyList(self):
730
    """Test that there is no match against an empty list"""
731

    
732
    self.failUnlessEqual(MatchNameComponent("", []), None)
733
    self.failUnlessEqual(MatchNameComponent("test", []), None)
734

    
735
  def testSingleMatch(self):
736
    """Test that a single match is performed correctly"""
737
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
738
    for key in "test2", "test2.example", "test2.example.com":
739
      self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
740

    
741
  def testMultipleMatches(self):
742
    """Test that a multiple match is returned as None"""
743
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
744
    for key in "test1", "test1.example":
745
      self.failUnlessEqual(MatchNameComponent(key, mlist), None)
746

    
747
  def testFullMatch(self):
748
    """Test that a full match is returned correctly"""
749
    key1 = "test1"
750
    key2 = "test1.example"
751
    mlist = [key2, key2 + ".com"]
752
    self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
753
    self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
754

    
755
  def testCaseInsensitivePartialMatch(self):
756
    """Test for the case_insensitive keyword"""
757
    mlist = ["test1.example.com", "test2.example.net"]
758
    self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
759
                     "test2.example.net")
760
    self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
761
                     "test2.example.net")
762
    self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
763
                     "test2.example.net")
764
    self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
765
                     "test2.example.net")
766

    
767

    
768
  def testCaseInsensitiveFullMatch(self):
769
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
770
    # Between the two ts1 a full string match non-case insensitive should work
771
    self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
772
                     None)
773
    self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
774
                     "ts1.ex")
775
    self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
776
                     "ts1.ex")
777
    # Between the two ts2 only case differs, so only case-match works
778
    self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
779
                     "ts2.ex")
780
    self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
781
                     "Ts2.ex")
782
    self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
783
                     None)
784

    
785

    
786
class TestReadFile(testutils.GanetiTestCase):
787

    
788
  def testReadAll(self):
789
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
790
    self.assertEqual(len(data), 814)
791

    
792
    h = compat.md5_hash()
793
    h.update(data)
794
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
795

    
796
  def testReadSize(self):
797
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
798
                          size=100)
799
    self.assertEqual(len(data), 100)
800

    
801
    h = compat.md5_hash()
802
    h.update(data)
803
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
804

    
805
  def testError(self):
806
    self.assertRaises(EnvironmentError, utils.ReadFile,
807
                      "/dev/null/does-not-exist")
808

    
809

    
810
class TestReadOneLineFile(testutils.GanetiTestCase):
811

    
812
  def setUp(self):
813
    testutils.GanetiTestCase.setUp(self)
814

    
815
  def testDefault(self):
816
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
817
    self.assertEqual(len(data), 27)
818
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
819

    
820
  def testNotStrict(self):
821
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
822
    self.assertEqual(len(data), 27)
823
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
824

    
825
  def testStrictFailure(self):
826
    self.assertRaises(errors.GenericError, ReadOneLineFile,
827
                      self._TestDataFilename("cert1.pem"), strict=True)
828

    
829
  def testLongLine(self):
830
    dummydata = (1024 * "Hello World! ")
831
    myfile = self._CreateTempFile()
832
    utils.WriteFile(myfile, data=dummydata)
833
    datastrict = ReadOneLineFile(myfile, strict=True)
834
    datalax = ReadOneLineFile(myfile, strict=False)
835
    self.assertEqual(dummydata, datastrict)
836
    self.assertEqual(dummydata, datalax)
837

    
838
  def testNewline(self):
839
    myfile = self._CreateTempFile()
840
    myline = "myline"
841
    for nl in ["", "\n", "\r\n"]:
842
      dummydata = "%s%s" % (myline, nl)
843
      utils.WriteFile(myfile, data=dummydata)
844
      datalax = ReadOneLineFile(myfile, strict=False)
845
      self.assertEqual(myline, datalax)
846
      datastrict = ReadOneLineFile(myfile, strict=True)
847
      self.assertEqual(myline, datastrict)
848

    
849
  def testWhitespaceAndMultipleLines(self):
850
    myfile = self._CreateTempFile()
851
    for nl in ["", "\n", "\r\n"]:
852
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
853
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
854
        utils.WriteFile(myfile, data=dummydata)
855
        datalax = ReadOneLineFile(myfile, strict=False)
856
        if nl:
857
          self.assert_(set("\r\n") & set(dummydata))
858
          self.assertRaises(errors.GenericError, ReadOneLineFile,
859
                            myfile, strict=True)
860
          explen = len("Foo bar baz ") + len(ws)
861
          self.assertEqual(len(datalax), explen)
862
          self.assertEqual(datalax, dummydata[:explen])
863
          self.assertFalse(set("\r\n") & set(datalax))
864
        else:
865
          datastrict = ReadOneLineFile(myfile, strict=True)
866
          self.assertEqual(dummydata, datastrict)
867
          self.assertEqual(dummydata, datalax)
868

    
869
  def testEmptylines(self):
870
    myfile = self._CreateTempFile()
871
    myline = "myline"
872
    for nl in ["\n", "\r\n"]:
873
      for ol in ["", "otherline"]:
874
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
875
        utils.WriteFile(myfile, data=dummydata)
876
        self.assert_(set("\r\n") & set(dummydata))
877
        datalax = ReadOneLineFile(myfile, strict=False)
878
        self.assertEqual(myline, datalax)
879
        if ol:
880
          self.assertRaises(errors.GenericError, ReadOneLineFile,
881
                            myfile, strict=True)
882
        else:
883
          datastrict = ReadOneLineFile(myfile, strict=True)
884
          self.assertEqual(myline, datastrict)
885

    
886

    
887
class TestTimestampForFilename(unittest.TestCase):
888
  def test(self):
889
    self.assert_("." not in utils.TimestampForFilename())
890
    self.assert_(":" not in utils.TimestampForFilename())
891

    
892

    
893
class TestCreateBackup(testutils.GanetiTestCase):
894
  def setUp(self):
895
    testutils.GanetiTestCase.setUp(self)
896

    
897
    self.tmpdir = tempfile.mkdtemp()
898

    
899
  def tearDown(self):
900
    testutils.GanetiTestCase.tearDown(self)
901

    
902
    shutil.rmtree(self.tmpdir)
903

    
904
  def testEmpty(self):
905
    filename = PathJoin(self.tmpdir, "config.data")
906
    utils.WriteFile(filename, data="")
907
    bname = utils.CreateBackup(filename)
908
    self.assertFileContent(bname, "")
909
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
910
    utils.CreateBackup(filename)
911
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
912
    utils.CreateBackup(filename)
913
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
914

    
915
    fifoname = PathJoin(self.tmpdir, "fifo")
916
    os.mkfifo(fifoname)
917
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
918

    
919
  def testContent(self):
920
    bkpcount = 0
921
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
922
      for rep in [1, 2, 10, 127]:
923
        testdata = data * rep
924

    
925
        filename = PathJoin(self.tmpdir, "test.data_")
926
        utils.WriteFile(filename, data=testdata)
927
        self.assertFileContent(filename, testdata)
928

    
929
        for _ in range(3):
930
          bname = utils.CreateBackup(filename)
931
          bkpcount += 1
932
          self.assertFileContent(bname, testdata)
933
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
934

    
935

    
936
class TestFormatUnit(unittest.TestCase):
937
  """Test case for the FormatUnit function"""
938

    
939
  def testMiB(self):
940
    self.assertEqual(FormatUnit(1, 'h'), '1M')
941
    self.assertEqual(FormatUnit(100, 'h'), '100M')
942
    self.assertEqual(FormatUnit(1023, 'h'), '1023M')
943

    
944
    self.assertEqual(FormatUnit(1, 'm'), '1')
945
    self.assertEqual(FormatUnit(100, 'm'), '100')
946
    self.assertEqual(FormatUnit(1023, 'm'), '1023')
947

    
948
    self.assertEqual(FormatUnit(1024, 'm'), '1024')
949
    self.assertEqual(FormatUnit(1536, 'm'), '1536')
950
    self.assertEqual(FormatUnit(17133, 'm'), '17133')
951
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
952

    
953
  def testGiB(self):
954
    self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
955
    self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
956
    self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
957
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
958

    
959
    self.assertEqual(FormatUnit(1024, 'g'), '1.0')
960
    self.assertEqual(FormatUnit(1536, 'g'), '1.5')
961
    self.assertEqual(FormatUnit(17133, 'g'), '16.7')
962
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
963

    
964
    self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
965
    self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
966
    self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
967

    
968
  def testTiB(self):
969
    self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
970
    self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
971
    self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
972

    
973
    self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
974
    self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
975
    self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
976

    
977

    
978
class TestParseUnit(unittest.TestCase):
979
  """Test case for the ParseUnit function"""
980

    
981
  SCALES = (('', 1),
982
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
983
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
984
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
985

    
986
  def testRounding(self):
987
    self.assertEqual(ParseUnit('0'), 0)
988
    self.assertEqual(ParseUnit('1'), 4)
989
    self.assertEqual(ParseUnit('2'), 4)
990
    self.assertEqual(ParseUnit('3'), 4)
991

    
992
    self.assertEqual(ParseUnit('124'), 124)
993
    self.assertEqual(ParseUnit('125'), 128)
994
    self.assertEqual(ParseUnit('126'), 128)
995
    self.assertEqual(ParseUnit('127'), 128)
996
    self.assertEqual(ParseUnit('128'), 128)
997
    self.assertEqual(ParseUnit('129'), 132)
998
    self.assertEqual(ParseUnit('130'), 132)
999

    
1000
  def testFloating(self):
1001
    self.assertEqual(ParseUnit('0'), 0)
1002
    self.assertEqual(ParseUnit('0.5'), 4)
1003
    self.assertEqual(ParseUnit('1.75'), 4)
1004
    self.assertEqual(ParseUnit('1.99'), 4)
1005
    self.assertEqual(ParseUnit('2.00'), 4)
1006
    self.assertEqual(ParseUnit('2.01'), 4)
1007
    self.assertEqual(ParseUnit('3.99'), 4)
1008
    self.assertEqual(ParseUnit('4.00'), 4)
1009
    self.assertEqual(ParseUnit('4.01'), 8)
1010
    self.assertEqual(ParseUnit('1.5G'), 1536)
1011
    self.assertEqual(ParseUnit('1.8G'), 1844)
1012
    self.assertEqual(ParseUnit('8.28T'), 8682212)
1013

    
1014
  def testSuffixes(self):
1015
    for sep in ('', ' ', '   ', "\t", "\t "):
1016
      for suffix, scale in TestParseUnit.SCALES:
1017
        for func in (lambda x: x, str.lower, str.upper):
1018
          self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
1019
                           1024 * scale)
1020

    
1021
  def testInvalidInput(self):
1022
    for sep in ('-', '_', ',', 'a'):
1023
      for suffix, _ in TestParseUnit.SCALES:
1024
        self.assertRaises(errors.UnitParseError, ParseUnit, '1' + sep + suffix)
1025

    
1026
    for suffix, _ in TestParseUnit.SCALES:
1027
      self.assertRaises(errors.UnitParseError, ParseUnit, '1,3' + suffix)
1028

    
1029

    
1030
class TestParseCpuMask(unittest.TestCase):
1031
  """Test case for the ParseCpuMask function."""
1032

    
1033
  def testWellFormed(self):
1034
    self.assertEqual(utils.ParseCpuMask(""), [])
1035
    self.assertEqual(utils.ParseCpuMask("1"), [1])
1036
    self.assertEqual(utils.ParseCpuMask("0-2,4,5-5"), [0,1,2,4,5])
1037

    
1038
  def testInvalidInput(self):
1039
    self.assertRaises(errors.ParseError,
1040
                      utils.ParseCpuMask,
1041
                      "garbage")
1042
    self.assertRaises(errors.ParseError,
1043
                      utils.ParseCpuMask,
1044
                      "0,")
1045
    self.assertRaises(errors.ParseError,
1046
                      utils.ParseCpuMask,
1047
                      "0-1-2")
1048
    self.assertRaises(errors.ParseError,
1049
                      utils.ParseCpuMask,
1050
                      "2-1")
1051

    
1052
class TestSshKeys(testutils.GanetiTestCase):
1053
  """Test case for the AddAuthorizedKey function"""
1054

    
1055
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
1056
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
1057
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
1058

    
1059
  def setUp(self):
1060
    testutils.GanetiTestCase.setUp(self)
1061
    self.tmpname = self._CreateTempFile()
1062
    handle = open(self.tmpname, 'w')
1063
    try:
1064
      handle.write("%s\n" % TestSshKeys.KEY_A)
1065
      handle.write("%s\n" % TestSshKeys.KEY_B)
1066
    finally:
1067
      handle.close()
1068

    
1069
  def testAddingNewKey(self):
1070
    utils.AddAuthorizedKey(self.tmpname,
1071
                           'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
1072

    
1073
    self.assertFileContent(self.tmpname,
1074
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1075
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1076
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1077
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
1078

    
1079
  def testAddingAlmostButNotCompletelyTheSameKey(self):
1080
    utils.AddAuthorizedKey(self.tmpname,
1081
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
1082

    
1083
    self.assertFileContent(self.tmpname,
1084
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1085
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1086
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1087
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
1088

    
1089
  def testAddingExistingKeyWithSomeMoreSpaces(self):
1090
    utils.AddAuthorizedKey(self.tmpname,
1091
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1092

    
1093
    self.assertFileContent(self.tmpname,
1094
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1095
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1096
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1097

    
1098
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
1099
    utils.RemoveAuthorizedKey(self.tmpname,
1100
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1101

    
1102
    self.assertFileContent(self.tmpname,
1103
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1104
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1105

    
1106
  def testRemovingNonExistingKey(self):
1107
    utils.RemoveAuthorizedKey(self.tmpname,
1108
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
1109

    
1110
    self.assertFileContent(self.tmpname,
1111
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1112
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1113
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1114

    
1115

    
1116
class TestEtcHosts(testutils.GanetiTestCase):
1117
  """Test functions modifying /etc/hosts"""
1118

    
1119
  def setUp(self):
1120
    testutils.GanetiTestCase.setUp(self)
1121
    self.tmpname = self._CreateTempFile()
1122
    handle = open(self.tmpname, 'w')
1123
    try:
1124
      handle.write('# This is a test file for /etc/hosts\n')
1125
      handle.write('127.0.0.1\tlocalhost\n')
1126
      handle.write('192.0.2.1 router gw\n')
1127
    finally:
1128
      handle.close()
1129

    
1130
  def testSettingNewIp(self):
1131
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost.example.com',
1132
                     ['myhost'])
1133

    
1134
    self.assertFileContent(self.tmpname,
1135
      "# This is a test file for /etc/hosts\n"
1136
      "127.0.0.1\tlocalhost\n"
1137
      "192.0.2.1 router gw\n"
1138
      "198.51.100.4\tmyhost.example.com myhost\n")
1139
    self.assertFileMode(self.tmpname, 0644)
1140

    
1141
  def testSettingExistingIp(self):
1142
    SetEtcHostsEntry(self.tmpname, '192.0.2.1', 'myhost.example.com',
1143
                     ['myhost'])
1144

    
1145
    self.assertFileContent(self.tmpname,
1146
      "# This is a test file for /etc/hosts\n"
1147
      "127.0.0.1\tlocalhost\n"
1148
      "192.0.2.1\tmyhost.example.com myhost\n")
1149
    self.assertFileMode(self.tmpname, 0644)
1150

    
1151
  def testSettingDuplicateName(self):
1152
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost', ['myhost'])
1153

    
1154
    self.assertFileContent(self.tmpname,
1155
      "# This is a test file for /etc/hosts\n"
1156
      "127.0.0.1\tlocalhost\n"
1157
      "192.0.2.1 router gw\n"
1158
      "198.51.100.4\tmyhost\n")
1159
    self.assertFileMode(self.tmpname, 0644)
1160

    
1161
  def testRemovingExistingHost(self):
1162
    RemoveEtcHostsEntry(self.tmpname, 'router')
1163

    
1164
    self.assertFileContent(self.tmpname,
1165
      "# This is a test file for /etc/hosts\n"
1166
      "127.0.0.1\tlocalhost\n"
1167
      "192.0.2.1 gw\n")
1168
    self.assertFileMode(self.tmpname, 0644)
1169

    
1170
  def testRemovingSingleExistingHost(self):
1171
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
1172

    
1173
    self.assertFileContent(self.tmpname,
1174
      "# This is a test file for /etc/hosts\n"
1175
      "192.0.2.1 router gw\n")
1176
    self.assertFileMode(self.tmpname, 0644)
1177

    
1178
  def testRemovingNonExistingHost(self):
1179
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
1180

    
1181
    self.assertFileContent(self.tmpname,
1182
      "# This is a test file for /etc/hosts\n"
1183
      "127.0.0.1\tlocalhost\n"
1184
      "192.0.2.1 router gw\n")
1185
    self.assertFileMode(self.tmpname, 0644)
1186

    
1187
  def testRemovingAlias(self):
1188
    RemoveEtcHostsEntry(self.tmpname, 'gw')
1189

    
1190
    self.assertFileContent(self.tmpname,
1191
      "# This is a test file for /etc/hosts\n"
1192
      "127.0.0.1\tlocalhost\n"
1193
      "192.0.2.1 router\n")
1194
    self.assertFileMode(self.tmpname, 0644)
1195

    
1196

    
1197
class TestGetMounts(unittest.TestCase):
1198
  """Test case for GetMounts()."""
1199

    
1200
  TESTDATA = (
1201
    "rootfs /     rootfs rw 0 0\n"
1202
    "none   /sys  sysfs  rw,nosuid,nodev,noexec,relatime 0 0\n"
1203
    "none   /proc proc   rw,nosuid,nodev,noexec,relatime 0 0\n")
1204

    
1205
  def setUp(self):
1206
    self.tmpfile = tempfile.NamedTemporaryFile()
1207
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1208

    
1209
  def testGetMounts(self):
1210
    self.assertEqual(utils.GetMounts(filename=self.tmpfile.name),
1211
      [
1212
        ("rootfs", "/", "rootfs", "rw"),
1213
        ("none", "/sys", "sysfs", "rw,nosuid,nodev,noexec,relatime"),
1214
        ("none", "/proc", "proc", "rw,nosuid,nodev,noexec,relatime"),
1215
      ])
1216

    
1217

    
1218
class TestShellQuoting(unittest.TestCase):
1219
  """Test case for shell quoting functions"""
1220

    
1221
  def testShellQuote(self):
1222
    self.assertEqual(ShellQuote('abc'), "abc")
1223
    self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
1224
    self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
1225
    self.assertEqual(ShellQuote("a b c"), "'a b c'")
1226
    self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
1227

    
1228
  def testShellQuoteArgs(self):
1229
    self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
1230
    self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
1231
    self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
1232

    
1233

    
1234
class TestListVisibleFiles(unittest.TestCase):
1235
  """Test case for ListVisibleFiles"""
1236

    
1237
  def setUp(self):
1238
    self.path = tempfile.mkdtemp()
1239

    
1240
  def tearDown(self):
1241
    shutil.rmtree(self.path)
1242

    
1243
  def _CreateFiles(self, files):
1244
    for name in files:
1245
      utils.WriteFile(os.path.join(self.path, name), data="test")
1246

    
1247
  def _test(self, files, expected):
1248
    self._CreateFiles(files)
1249
    found = ListVisibleFiles(self.path)
1250
    self.assertEqual(set(found), set(expected))
1251

    
1252
  def testAllVisible(self):
1253
    files = ["a", "b", "c"]
1254
    expected = files
1255
    self._test(files, expected)
1256

    
1257
  def testNoneVisible(self):
1258
    files = [".a", ".b", ".c"]
1259
    expected = []
1260
    self._test(files, expected)
1261

    
1262
  def testSomeVisible(self):
1263
    files = ["a", "b", ".c"]
1264
    expected = ["a", "b"]
1265
    self._test(files, expected)
1266

    
1267
  def testNonAbsolutePath(self):
1268
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1269

    
1270
  def testNonNormalizedPath(self):
1271
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1272
                          "/bin/../tmp")
1273

    
1274

    
1275
class TestNewUUID(unittest.TestCase):
1276
  """Test case for NewUUID"""
1277

    
1278
  def runTest(self):
1279
    self.failUnless(utils.UUID_RE.match(utils.NewUUID()))
1280

    
1281

    
1282
class TestUniqueSequence(unittest.TestCase):
1283
  """Test case for UniqueSequence"""
1284

    
1285
  def _test(self, input, expected):
1286
    self.assertEqual(utils.UniqueSequence(input), expected)
1287

    
1288
  def runTest(self):
1289
    # Ordered input
1290
    self._test([1, 2, 3], [1, 2, 3])
1291
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
1292
    self._test([1, 2, 2, 3], [1, 2, 3])
1293
    self._test([1, 2, 3, 3], [1, 2, 3])
1294

    
1295
    # Unordered input
1296
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
1297
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
1298

    
1299
    # Strings
1300
    self._test(["a", "a"], ["a"])
1301
    self._test(["a", "b"], ["a", "b"])
1302
    self._test(["a", "b", "a"], ["a", "b"])
1303

    
1304

    
1305
class TestFirstFree(unittest.TestCase):
1306
  """Test case for the FirstFree function"""
1307

    
1308
  def test(self):
1309
    """Test FirstFree"""
1310
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1311
    self.failUnlessEqual(FirstFree([]), None)
1312
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1313
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1314
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1315

    
1316

    
1317
class TestTailFile(testutils.GanetiTestCase):
1318
  """Test case for the TailFile function"""
1319

    
1320
  def testEmpty(self):
1321
    fname = self._CreateTempFile()
1322
    self.failUnlessEqual(TailFile(fname), [])
1323
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1324

    
1325
  def testAllLines(self):
1326
    data = ["test %d" % i for i in range(30)]
1327
    for i in range(30):
1328
      fname = self._CreateTempFile()
1329
      fd = open(fname, "w")
1330
      fd.write("\n".join(data[:i]))
1331
      if i > 0:
1332
        fd.write("\n")
1333
      fd.close()
1334
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1335

    
1336
  def testPartialLines(self):
1337
    data = ["test %d" % i for i in range(30)]
1338
    fname = self._CreateTempFile()
1339
    fd = open(fname, "w")
1340
    fd.write("\n".join(data))
1341
    fd.write("\n")
1342
    fd.close()
1343
    for i in range(1, 30):
1344
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1345

    
1346
  def testBigFile(self):
1347
    data = ["test %d" % i for i in range(30)]
1348
    fname = self._CreateTempFile()
1349
    fd = open(fname, "w")
1350
    fd.write("X" * 1048576)
1351
    fd.write("\n")
1352
    fd.write("\n".join(data))
1353
    fd.write("\n")
1354
    fd.close()
1355
    for i in range(1, 30):
1356
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1357

    
1358

    
1359
class _BaseFileLockTest:
1360
  """Test case for the FileLock class"""
1361

    
1362
  def testSharedNonblocking(self):
1363
    self.lock.Shared(blocking=False)
1364
    self.lock.Close()
1365

    
1366
  def testExclusiveNonblocking(self):
1367
    self.lock.Exclusive(blocking=False)
1368
    self.lock.Close()
1369

    
1370
  def testUnlockNonblocking(self):
1371
    self.lock.Unlock(blocking=False)
1372
    self.lock.Close()
1373

    
1374
  def testSharedBlocking(self):
1375
    self.lock.Shared(blocking=True)
1376
    self.lock.Close()
1377

    
1378
  def testExclusiveBlocking(self):
1379
    self.lock.Exclusive(blocking=True)
1380
    self.lock.Close()
1381

    
1382
  def testUnlockBlocking(self):
1383
    self.lock.Unlock(blocking=True)
1384
    self.lock.Close()
1385

    
1386
  def testSharedExclusiveUnlock(self):
1387
    self.lock.Shared(blocking=False)
1388
    self.lock.Exclusive(blocking=False)
1389
    self.lock.Unlock(blocking=False)
1390
    self.lock.Close()
1391

    
1392
  def testExclusiveSharedUnlock(self):
1393
    self.lock.Exclusive(blocking=False)
1394
    self.lock.Shared(blocking=False)
1395
    self.lock.Unlock(blocking=False)
1396
    self.lock.Close()
1397

    
1398
  def testSimpleTimeout(self):
1399
    # These will succeed on the first attempt, hence a short timeout
1400
    self.lock.Shared(blocking=True, timeout=10.0)
1401
    self.lock.Exclusive(blocking=False, timeout=10.0)
1402
    self.lock.Unlock(blocking=True, timeout=10.0)
1403
    self.lock.Close()
1404

    
1405
  @staticmethod
1406
  def _TryLockInner(filename, shared, blocking):
1407
    lock = utils.FileLock.Open(filename)
1408

    
1409
    if shared:
1410
      fn = lock.Shared
1411
    else:
1412
      fn = lock.Exclusive
1413

    
1414
    try:
1415
      # The timeout doesn't really matter as the parent process waits for us to
1416
      # finish anyway.
1417
      fn(blocking=blocking, timeout=0.01)
1418
    except errors.LockError, err:
1419
      return False
1420

    
1421
    return True
1422

    
1423
  def _TryLock(self, *args):
1424
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1425
                                      *args)
1426

    
1427
  def testTimeout(self):
1428
    for blocking in [True, False]:
1429
      self.lock.Exclusive(blocking=True)
1430
      self.failIf(self._TryLock(False, blocking))
1431
      self.failIf(self._TryLock(True, blocking))
1432

    
1433
      self.lock.Shared(blocking=True)
1434
      self.assert_(self._TryLock(True, blocking))
1435
      self.failIf(self._TryLock(False, blocking))
1436

    
1437
  def testCloseShared(self):
1438
    self.lock.Close()
1439
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1440

    
1441
  def testCloseExclusive(self):
1442
    self.lock.Close()
1443
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1444

    
1445
  def testCloseUnlock(self):
1446
    self.lock.Close()
1447
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1448

    
1449

    
1450
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1451
  TESTDATA = "Hello World\n" * 10
1452

    
1453
  def setUp(self):
1454
    testutils.GanetiTestCase.setUp(self)
1455

    
1456
    self.tmpfile = tempfile.NamedTemporaryFile()
1457
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1458
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1459

    
1460
    # Ensure "Open" didn't truncate file
1461
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1462

    
1463
  def tearDown(self):
1464
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1465

    
1466
    testutils.GanetiTestCase.tearDown(self)
1467

    
1468

    
1469
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1470
  def setUp(self):
1471
    self.tmpfile = tempfile.NamedTemporaryFile()
1472
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1473

    
1474

    
1475
class TestTimeFunctions(unittest.TestCase):
1476
  """Test case for time functions"""
1477

    
1478
  def runTest(self):
1479
    self.assertEqual(utils.SplitTime(1), (1, 0))
1480
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1481
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1482
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1483
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1484
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1485
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1486
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1487

    
1488
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1489

    
1490
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1491
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1492
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1493

    
1494
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1495
                     1218448917.481)
1496
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1497

    
1498
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1499
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1500
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1501
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1502
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1503

    
1504

    
1505
class FieldSetTestCase(unittest.TestCase):
1506
  """Test case for FieldSets"""
1507

    
1508
  def testSimpleMatch(self):
1509
    f = utils.FieldSet("a", "b", "c", "def")
1510
    self.failUnless(f.Matches("a"))
1511
    self.failIf(f.Matches("d"), "Substring matched")
1512
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1513
    self.failIf(f.NonMatching(["b", "c"]))
1514
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1515
    self.failUnless(f.NonMatching(["a", "d"]))
1516

    
1517
  def testRegexMatch(self):
1518
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1519
    self.failUnless(f.Matches("b1"))
1520
    self.failUnless(f.Matches("b99"))
1521
    self.failIf(f.Matches("b/1"))
1522
    self.failIf(f.NonMatching(["b12", "c"]))
1523
    self.failUnless(f.NonMatching(["a", "1"]))
1524

    
1525
class TestForceDictType(unittest.TestCase):
1526
  """Test case for ForceDictType"""
1527

    
1528
  def setUp(self):
1529
    self.key_types = {
1530
      'a': constants.VTYPE_INT,
1531
      'b': constants.VTYPE_BOOL,
1532
      'c': constants.VTYPE_STRING,
1533
      'd': constants.VTYPE_SIZE,
1534
      "e": constants.VTYPE_MAYBE_STRING,
1535
      }
1536

    
1537
  def _fdt(self, dict, allowed_values=None):
1538
    if allowed_values is None:
1539
      utils.ForceDictType(dict, self.key_types)
1540
    else:
1541
      utils.ForceDictType(dict, self.key_types, allowed_values=allowed_values)
1542

    
1543
    return dict
1544

    
1545
  def testSimpleDict(self):
1546
    self.assertEqual(self._fdt({}), {})
1547
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1548
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1549
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1550
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1551
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1552
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1553
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1554
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1555
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1556
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1557
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1558
    self.assertEqual(self._fdt({"e": None, }), {"e": None, })
1559
    self.assertEqual(self._fdt({"e": "Hello World", }), {"e": "Hello World", })
1560
    self.assertEqual(self._fdt({"e": False, }), {"e": '', })
1561

    
1562
  def testErrors(self):
1563
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1564
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1565
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1566
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1567
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": object(), })
1568
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": [], })
1569

    
1570

    
1571
class TestIsNormAbsPath(unittest.TestCase):
1572
  """Testing case for IsNormAbsPath"""
1573

    
1574
  def _pathTestHelper(self, path, result):
1575
    if result:
1576
      self.assert_(utils.IsNormAbsPath(path),
1577
          "Path %s should result absolute and normalized" % path)
1578
    else:
1579
      self.assertFalse(utils.IsNormAbsPath(path),
1580
          "Path %s should not result absolute and normalized" % path)
1581

    
1582
  def testBase(self):
1583
    self._pathTestHelper('/etc', True)
1584
    self._pathTestHelper('/srv', True)
1585
    self._pathTestHelper('etc', False)
1586
    self._pathTestHelper('/etc/../root', False)
1587
    self._pathTestHelper('/etc/', False)
1588

    
1589

    
1590
class TestSafeEncode(unittest.TestCase):
1591
  """Test case for SafeEncode"""
1592

    
1593
  def testAscii(self):
1594
    for txt in [string.digits, string.letters, string.punctuation]:
1595
      self.failUnlessEqual(txt, SafeEncode(txt))
1596

    
1597
  def testDoubleEncode(self):
1598
    for i in range(255):
1599
      txt = SafeEncode(chr(i))
1600
      self.failUnlessEqual(txt, SafeEncode(txt))
1601

    
1602
  def testUnicode(self):
1603
    # 1024 is high enough to catch non-direct ASCII mappings
1604
    for i in range(1024):
1605
      txt = SafeEncode(unichr(i))
1606
      self.failUnlessEqual(txt, SafeEncode(txt))
1607

    
1608

    
1609
class TestFormatTime(unittest.TestCase):
1610
  """Testing case for FormatTime"""
1611

    
1612
  def testNone(self):
1613
    self.failUnlessEqual(FormatTime(None), "N/A")
1614

    
1615
  def testInvalid(self):
1616
    self.failUnlessEqual(FormatTime(()), "N/A")
1617

    
1618
  def testNow(self):
1619
    # tests that we accept time.time input
1620
    FormatTime(time.time())
1621
    # tests that we accept int input
1622
    FormatTime(int(time.time()))
1623

    
1624

    
1625
class RunInSeparateProcess(unittest.TestCase):
1626
  def test(self):
1627
    for exp in [True, False]:
1628
      def _child():
1629
        return exp
1630

    
1631
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1632

    
1633
  def testArgs(self):
1634
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1635
      def _child(carg1, carg2):
1636
        return carg1 == "Foo" and carg2 == arg
1637

    
1638
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1639

    
1640
  def testPid(self):
1641
    parent_pid = os.getpid()
1642

    
1643
    def _check():
1644
      return os.getpid() == parent_pid
1645

    
1646
    self.failIf(utils.RunInSeparateProcess(_check))
1647

    
1648
  def testSignal(self):
1649
    def _kill():
1650
      os.kill(os.getpid(), signal.SIGTERM)
1651

    
1652
    self.assertRaises(errors.GenericError,
1653
                      utils.RunInSeparateProcess, _kill)
1654

    
1655
  def testException(self):
1656
    def _exc():
1657
      raise errors.GenericError("This is a test")
1658

    
1659
    self.assertRaises(errors.GenericError,
1660
                      utils.RunInSeparateProcess, _exc)
1661

    
1662

    
1663
class TestFingerprintFile(unittest.TestCase):
1664
  def setUp(self):
1665
    self.tmpfile = tempfile.NamedTemporaryFile()
1666

    
1667
  def test(self):
1668
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1669
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")
1670

    
1671
    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
1672
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1673
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1674

    
1675

    
1676
class TestUnescapeAndSplit(unittest.TestCase):
1677
  """Testing case for UnescapeAndSplit"""
1678

    
1679
  def setUp(self):
1680
    # testing more that one separator for regexp safety
1681
    self._seps = [",", "+", "."]
1682

    
1683
  def testSimple(self):
1684
    a = ["a", "b", "c", "d"]
1685
    for sep in self._seps:
1686
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1687

    
1688
  def testEscape(self):
1689
    for sep in self._seps:
1690
      a = ["a", "b\\" + sep + "c", "d"]
1691
      b = ["a", "b" + sep + "c", "d"]
1692
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1693

    
1694
  def testDoubleEscape(self):
1695
    for sep in self._seps:
1696
      a = ["a", "b\\\\", "c", "d"]
1697
      b = ["a", "b\\", "c", "d"]
1698
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1699

    
1700
  def testThreeEscape(self):
1701
    for sep in self._seps:
1702
      a = ["a", "b\\\\\\" + sep + "c", "d"]
1703
      b = ["a", "b\\" + sep + "c", "d"]
1704
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1705

    
1706

    
1707
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1708
  def setUp(self):
1709
    self.tmpdir = tempfile.mkdtemp()
1710

    
1711
  def tearDown(self):
1712
    shutil.rmtree(self.tmpdir)
1713

    
1714
  def _checkRsaPrivateKey(self, key):
1715
    lines = key.splitlines()
1716
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1717
            "-----END RSA PRIVATE KEY-----" in lines)
1718

    
1719
  def _checkCertificate(self, cert):
1720
    lines = cert.splitlines()
1721
    return ("-----BEGIN CERTIFICATE-----" in lines and
1722
            "-----END CERTIFICATE-----" in lines)
1723

    
1724
  def test(self):
1725
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1726
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1727
      self._checkRsaPrivateKey(key_pem)
1728
      self._checkCertificate(cert_pem)
1729

    
1730
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1731
                                           key_pem)
1732
      self.assert_(key.bits() >= 1024)
1733
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1734
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1735

    
1736
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1737
                                             cert_pem)
1738
      self.failIf(x509.has_expired())
1739
      self.assertEqual(x509.get_issuer().CN, common_name)
1740
      self.assertEqual(x509.get_subject().CN, common_name)
1741
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1742

    
1743
  def testLegacy(self):
1744
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1745

    
1746
    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1747

    
1748
    cert1 = utils.ReadFile(cert1_filename)
1749

    
1750
    self.assert_(self._checkRsaPrivateKey(cert1))
1751
    self.assert_(self._checkCertificate(cert1))
1752

    
1753

    
1754
class TestPathJoin(unittest.TestCase):
1755
  """Testing case for PathJoin"""
1756

    
1757
  def testBasicItems(self):
1758
    mlist = ["/a", "b", "c"]
1759
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1760

    
1761
  def testNonAbsPrefix(self):
1762
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1763

    
1764
  def testBackTrack(self):
1765
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1766

    
1767
  def testMultiAbs(self):
1768
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1769

    
1770

    
1771
class TestValidateServiceName(unittest.TestCase):
1772
  def testValid(self):
1773
    testnames = [
1774
      0, 1, 2, 3, 1024, 65000, 65534, 65535,
1775
      "ganeti",
1776
      "gnt-masterd",
1777
      "HELLO_WORLD_SVC",
1778
      "hello.world.1",
1779
      "0", "80", "1111", "65535",
1780
      ]
1781

    
1782
    for name in testnames:
1783
      self.assertEqual(utils.ValidateServiceName(name), name)
1784

    
1785
  def testInvalid(self):
1786
    testnames = [
1787
      -15756, -1, 65536, 133428083,
1788
      "", "Hello World!", "!", "'", "\"", "\t", "\n", "`",
1789
      "-8546", "-1", "65536",
1790
      (129 * "A"),
1791
      ]
1792

    
1793
    for name in testnames:
1794
      self.assertRaises(errors.OpPrereqError, utils.ValidateServiceName, name)
1795

    
1796

    
1797
class TestParseAsn1Generalizedtime(unittest.TestCase):
1798
  def test(self):
1799
    # UTC
1800
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1801
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1802
                     1266860512)
1803
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1804
                     (2**31) - 1)
1805

    
1806
    # With offset
1807
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1808
                     1266860512)
1809
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1810
                     1266931012)
1811
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1812
                     1266931088)
1813
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1814
                     1266931295)
1815
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1816
                     3600)
1817

    
1818
    # Leap seconds are not supported by datetime.datetime
1819
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1820
                      "19841231235960+0000")
1821
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1822
                      "19920630235960+0000")
1823

    
1824
    # Errors
1825
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1826
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1827
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1828
                      "20100222174152")
1829
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1830
                      "Mon Feb 22 17:47:02 UTC 2010")
1831
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1832
                      "2010-02-22 17:42:02")
1833

    
1834

    
1835
class TestGetX509CertValidity(testutils.GanetiTestCase):
1836
  def setUp(self):
1837
    testutils.GanetiTestCase.setUp(self)
1838

    
1839
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1840

    
1841
    # Test whether we have pyOpenSSL 0.7 or above
1842
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1843

    
1844
    if not self.pyopenssl0_7:
1845
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1846
                    " function correctly")
1847

    
1848
  def _LoadCert(self, name):
1849
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1850
                                           self._ReadTestData(name))
1851

    
1852
  def test(self):
1853
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1854
    if self.pyopenssl0_7:
1855
      self.assertEqual(validity, (1266919967, 1267524767))
1856
    else:
1857
      self.assertEqual(validity, (None, None))
1858

    
1859

    
1860
class TestSignX509Certificate(unittest.TestCase):
1861
  KEY = "My private key!"
1862
  KEY_OTHER = "Another key"
1863

    
1864
  def test(self):
1865
    # Generate certificate valid for 5 minutes
1866
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
1867

    
1868
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1869
                                           cert_pem)
1870

    
1871
    # No signature at all
1872
    self.assertRaises(errors.GenericError,
1873
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
1874

    
1875
    # Invalid input
1876
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1877
                      "", self.KEY)
1878
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1879
                      "X-Ganeti-Signature: \n", self.KEY)
1880
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1881
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
1882
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1883
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
1884
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1885
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
1886

    
1887
    # Invalid salt
1888
    for salt in list("-_@$,:;/\\ \t\n"):
1889
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
1890
                        cert_pem, self.KEY, "foo%sbar" % salt)
1891

    
1892
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
1893
                 utils.GenerateSecret(numbytes=4),
1894
                 utils.GenerateSecret(numbytes=16),
1895
                 "{123:456}".encode("hex")]:
1896
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
1897

    
1898
      self._Check(cert, salt, signed_pem)
1899

    
1900
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
1901
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
1902
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
1903
                               "lines----\n------ at\nthe end!"))
1904

    
1905
  def _Check(self, cert, salt, pem):
1906
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
1907
    self.assertEqual(salt, salt2)
1908
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
1909

    
1910
    # Other key
1911
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1912
                      pem, self.KEY_OTHER)
1913

    
1914

    
1915
class TestMakedirs(unittest.TestCase):
1916
  def setUp(self):
1917
    self.tmpdir = tempfile.mkdtemp()
1918

    
1919
  def tearDown(self):
1920
    shutil.rmtree(self.tmpdir)
1921

    
1922
  def testNonExisting(self):
1923
    path = PathJoin(self.tmpdir, "foo")
1924
    utils.Makedirs(path)
1925
    self.assert_(os.path.isdir(path))
1926

    
1927
  def testExisting(self):
1928
    path = PathJoin(self.tmpdir, "foo")
1929
    os.mkdir(path)
1930
    utils.Makedirs(path)
1931
    self.assert_(os.path.isdir(path))
1932

    
1933
  def testRecursiveNonExisting(self):
1934
    path = PathJoin(self.tmpdir, "foo/bar/baz")
1935
    utils.Makedirs(path)
1936
    self.assert_(os.path.isdir(path))
1937

    
1938
  def testRecursiveExisting(self):
1939
    path = PathJoin(self.tmpdir, "B/moo/xyz")
1940
    self.assertFalse(os.path.exists(path))
1941
    os.mkdir(PathJoin(self.tmpdir, "B"))
1942
    utils.Makedirs(path)
1943
    self.assert_(os.path.isdir(path))
1944

    
1945

    
1946
class TestRetry(testutils.GanetiTestCase):
1947
  def setUp(self):
1948
    testutils.GanetiTestCase.setUp(self)
1949
    self.retries = 0
1950

    
1951
  @staticmethod
1952
  def _RaiseRetryAgain():
1953
    raise utils.RetryAgain()
1954

    
1955
  @staticmethod
1956
  def _RaiseRetryAgainWithArg(args):
1957
    raise utils.RetryAgain(*args)
1958

    
1959
  def _WrongNestedLoop(self):
1960
    return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
1961

    
1962
  def _RetryAndSucceed(self, retries):
1963
    if self.retries < retries:
1964
      self.retries += 1
1965
      raise utils.RetryAgain()
1966
    else:
1967
      return True
1968

    
1969
  def testRaiseTimeout(self):
1970
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1971
                          self._RaiseRetryAgain, 0.01, 0.02)
1972
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1973
                          self._RetryAndSucceed, 0.01, 0, args=[1])
1974
    self.failUnlessEqual(self.retries, 1)
1975

    
1976
  def testComplete(self):
1977
    self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
1978
    self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
1979
                         True)
1980
    self.failUnlessEqual(self.retries, 2)
1981

    
1982
  def testNestedLoop(self):
1983
    try:
1984
      self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
1985
                            self._WrongNestedLoop, 0, 1)
1986
    except utils.RetryTimeout:
1987
      self.fail("Didn't detect inner loop's exception")
1988

    
1989
  def testTimeoutArgument(self):
1990
    retry_arg="my_important_debugging_message"
1991
    try:
1992
      utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
1993
    except utils.RetryTimeout, err:
1994
      self.failUnlessEqual(err.args, (retry_arg, ))
1995
    else:
1996
      self.fail("Expected timeout didn't happen")
1997

    
1998
  def testRaiseInnerWithExc(self):
1999
    retry_arg="my_important_debugging_message"
2000
    try:
2001
      try:
2002
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2003
                    args=[[errors.GenericError(retry_arg, retry_arg)]])
2004
      except utils.RetryTimeout, err:
2005
        err.RaiseInner()
2006
      else:
2007
        self.fail("Expected timeout didn't happen")
2008
    except errors.GenericError, err:
2009
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2010
    else:
2011
      self.fail("Expected GenericError didn't happen")
2012

    
2013
  def testRaiseInnerWithMsg(self):
2014
    retry_arg="my_important_debugging_message"
2015
    try:
2016
      try:
2017
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2018
                    args=[[retry_arg, retry_arg]])
2019
      except utils.RetryTimeout, err:
2020
        err.RaiseInner()
2021
      else:
2022
        self.fail("Expected timeout didn't happen")
2023
    except utils.RetryTimeout, err:
2024
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2025
    else:
2026
      self.fail("Expected RetryTimeout didn't happen")
2027

    
2028

    
2029
class TestLineSplitter(unittest.TestCase):
2030
  def test(self):
2031
    lines = []
2032
    ls = utils.LineSplitter(lines.append)
2033
    ls.write("Hello World\n")
2034
    self.assertEqual(lines, [])
2035
    ls.write("Foo\n Bar\r\n ")
2036
    ls.write("Baz")
2037
    ls.write("Moo")
2038
    self.assertEqual(lines, [])
2039
    ls.flush()
2040
    self.assertEqual(lines, ["Hello World", "Foo", " Bar"])
2041
    ls.close()
2042
    self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"])
2043

    
2044
  def _testExtra(self, line, all_lines, p1, p2):
2045
    self.assertEqual(p1, 999)
2046
    self.assertEqual(p2, "extra")
2047
    all_lines.append(line)
2048

    
2049
  def testExtraArgsNoFlush(self):
2050
    lines = []
2051
    ls = utils.LineSplitter(self._testExtra, lines, 999, "extra")
2052
    ls.write("\n\nHello World\n")
2053
    ls.write("Foo\n Bar\r\n ")
2054
    ls.write("")
2055
    ls.write("Baz")
2056
    ls.write("Moo\n\nx\n")
2057
    self.assertEqual(lines, [])
2058
    ls.close()
2059
    self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo",
2060
                             "", "x"])
2061

    
2062

    
2063
class TestReadLockedPidFile(unittest.TestCase):
2064
  def setUp(self):
2065
    self.tmpdir = tempfile.mkdtemp()
2066

    
2067
  def tearDown(self):
2068
    shutil.rmtree(self.tmpdir)
2069

    
2070
  def testNonExistent(self):
2071
    path = PathJoin(self.tmpdir, "nonexist")
2072
    self.assert_(utils.ReadLockedPidFile(path) is None)
2073

    
2074
  def testUnlocked(self):
2075
    path = PathJoin(self.tmpdir, "pid")
2076
    utils.WriteFile(path, data="123")
2077
    self.assert_(utils.ReadLockedPidFile(path) is None)
2078

    
2079
  def testLocked(self):
2080
    path = PathJoin(self.tmpdir, "pid")
2081
    utils.WriteFile(path, data="123")
2082

    
2083
    fl = utils.FileLock.Open(path)
2084
    try:
2085
      fl.Exclusive(blocking=True)
2086

    
2087
      self.assertEqual(utils.ReadLockedPidFile(path), 123)
2088
    finally:
2089
      fl.Close()
2090

    
2091
    self.assert_(utils.ReadLockedPidFile(path) is None)
2092

    
2093
  def testError(self):
2094
    path = PathJoin(self.tmpdir, "foobar", "pid")
2095
    utils.WriteFile(PathJoin(self.tmpdir, "foobar"), data="")
2096
    # open(2) should return ENOTDIR
2097
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
2098

    
2099

    
2100
class TestCertVerification(testutils.GanetiTestCase):
2101
  def setUp(self):
2102
    testutils.GanetiTestCase.setUp(self)
2103

    
2104
    self.tmpdir = tempfile.mkdtemp()
2105

    
2106
  def tearDown(self):
2107
    shutil.rmtree(self.tmpdir)
2108

    
2109
  def testVerifyCertificate(self):
2110
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
2111
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
2112
                                           cert_pem)
2113

    
2114
    # Not checking return value as this certificate is expired
2115
    utils.VerifyX509Certificate(cert, 30, 7)
2116

    
2117

    
2118
class TestVerifyCertificateInner(unittest.TestCase):
2119
  def test(self):
2120
    vci = utils._VerifyCertificateInner
2121

    
2122
    # Valid
2123
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
2124
                     (None, None))
2125

    
2126
    # Not yet valid
2127
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
2128
    self.assertEqual(errcode, utils.CERT_WARNING)
2129

    
2130
    # Expiring soon
2131
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
2132
    self.assertEqual(errcode, utils.CERT_ERROR)
2133

    
2134
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
2135
    self.assertEqual(errcode, utils.CERT_WARNING)
2136

    
2137
    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
2138
    self.assertEqual(errcode, None)
2139

    
2140
    # Expired
2141
    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
2142
    self.assertEqual(errcode, utils.CERT_ERROR)
2143

    
2144
    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
2145
    self.assertEqual(errcode, utils.CERT_ERROR)
2146

    
2147
    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
2148
    self.assertEqual(errcode, utils.CERT_ERROR)
2149

    
2150
    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
2151
    self.assertEqual(errcode, utils.CERT_ERROR)
2152

    
2153

    
2154
class TestHmacFunctions(unittest.TestCase):
2155
  # Digests can be checked with "openssl sha1 -hmac $key"
2156
  def testSha1Hmac(self):
2157
    self.assertEqual(utils.Sha1Hmac("", ""),
2158
                     "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d")
2159
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World"),
2160
                     "ef4f3bda82212ecb2f7ce868888a19092481f1fd")
2161
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", ""),
2162
                     "f904c2476527c6d3e6609ab683c66fa0652cb1dc")
2163

    
2164
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
2165
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", longtext),
2166
                     "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54")
2167

    
2168
  def testSha1HmacSalt(self):
2169
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc0"),
2170
                     "4999bf342470eadb11dfcd24ca5680cf9fd7cdce")
2171
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc9"),
2172
                     "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8")
2173
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World", salt="xyz0"),
2174
                     "7f264f8114c9066afc9bb7636e1786d996d3cc0d")
2175

    
2176
  def testVerifySha1Hmac(self):
2177
    self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b"
2178
                                               "7d64b71fb76370690e1d")))
2179
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2180
                                      ("f904c2476527c6d3e660"
2181
                                       "9ab683c66fa0652cb1dc")))
2182

    
2183
    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
2184
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", digest))
2185
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2186
                                      digest.lower()))
2187
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2188
                                      digest.upper()))
2189
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2190
                                      digest.title()))
2191

    
2192
  def testVerifySha1HmacSalt(self):
2193
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2194
                                      ("17a4adc34d69c0d367d4"
2195
                                       "ffbef96fd41d4df7a6e8"),
2196
                                      salt="abc9"))
2197
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2198
                                      ("7f264f8114c9066afc9b"
2199
                                       "b7636e1786d996d3cc0d"),
2200
                                      salt="xyz0"))
2201

    
2202

    
2203
class TestIgnoreSignals(unittest.TestCase):
2204
  """Test the IgnoreSignals decorator"""
2205

    
2206
  @staticmethod
2207
  def _Raise(exception):
2208
    raise exception
2209

    
2210
  @staticmethod
2211
  def _Return(rval):
2212
    return rval
2213

    
2214
  def testIgnoreSignals(self):
2215
    sock_err_intr = socket.error(errno.EINTR, "Message")
2216
    sock_err_inval = socket.error(errno.EINVAL, "Message")
2217

    
2218
    env_err_intr = EnvironmentError(errno.EINTR, "Message")
2219
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")
2220

    
2221
    self.assertRaises(socket.error, self._Raise, sock_err_intr)
2222
    self.assertRaises(socket.error, self._Raise, sock_err_inval)
2223
    self.assertRaises(EnvironmentError, self._Raise, env_err_intr)
2224
    self.assertRaises(EnvironmentError, self._Raise, env_err_inval)
2225

    
2226
    self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None)
2227
    self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None)
2228
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
2229
                      sock_err_inval)
2230
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
2231
                      env_err_inval)
2232

    
2233
    self.assertEquals(utils.IgnoreSignals(self._Return, True), True)
2234
    self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33)
2235

    
2236

    
2237
class TestEnsureDirs(unittest.TestCase):
2238
  """Tests for EnsureDirs"""
2239

    
2240
  def setUp(self):
2241
    self.dir = tempfile.mkdtemp()
2242
    self.old_umask = os.umask(0777)
2243

    
2244
  def testEnsureDirs(self):
2245
    utils.EnsureDirs([
2246
        (PathJoin(self.dir, "foo"), 0777),
2247
        (PathJoin(self.dir, "bar"), 0000),
2248
        ])
2249
    self.assertEquals(os.stat(PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2250
    self.assertEquals(os.stat(PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2251

    
2252
  def tearDown(self):
2253
    os.rmdir(PathJoin(self.dir, "foo"))
2254
    os.rmdir(PathJoin(self.dir, "bar"))
2255
    os.rmdir(self.dir)
2256
    os.umask(self.old_umask)
2257

    
2258

    
2259
class TestFormatSeconds(unittest.TestCase):
2260
  def test(self):
2261
    self.assertEqual(utils.FormatSeconds(1), "1s")
2262
    self.assertEqual(utils.FormatSeconds(3600), "1h 0m 0s")
2263
    self.assertEqual(utils.FormatSeconds(3599), "59m 59s")
2264
    self.assertEqual(utils.FormatSeconds(7200), "2h 0m 0s")
2265
    self.assertEqual(utils.FormatSeconds(7201), "2h 0m 1s")
2266
    self.assertEqual(utils.FormatSeconds(7281), "2h 1m 21s")
2267
    self.assertEqual(utils.FormatSeconds(29119), "8h 5m 19s")
2268
    self.assertEqual(utils.FormatSeconds(19431228), "224d 21h 33m 48s")
2269
    self.assertEqual(utils.FormatSeconds(-1), "-1s")
2270
    self.assertEqual(utils.FormatSeconds(-282), "-282s")
2271
    self.assertEqual(utils.FormatSeconds(-29119), "-29119s")
2272

    
2273
  def testFloat(self):
2274
    self.assertEqual(utils.FormatSeconds(1.3), "1s")
2275
    self.assertEqual(utils.FormatSeconds(1.9), "2s")
2276
    self.assertEqual(utils.FormatSeconds(3912.12311), "1h 5m 12s")
2277
    self.assertEqual(utils.FormatSeconds(3912.8), "1h 5m 13s")
2278

    
2279

    
2280
class TestIgnoreProcessNotFound(unittest.TestCase):
2281
  @staticmethod
2282
  def _WritePid(fd):
2283
    os.write(fd, str(os.getpid()))
2284
    os.close(fd)
2285
    return True
2286

    
2287
  def test(self):
2288
    (pid_read_fd, pid_write_fd) = os.pipe()
2289

    
2290
    # Start short-lived process which writes its PID to pipe
2291
    self.assert_(utils.RunInSeparateProcess(self._WritePid, pid_write_fd))
2292
    os.close(pid_write_fd)
2293

    
2294
    # Read PID from pipe
2295
    pid = int(os.read(pid_read_fd, 1024))
2296
    os.close(pid_read_fd)
2297

    
2298
    # Try to send signal to process which exited recently
2299
    self.assertFalse(utils.IgnoreProcessNotFound(os.kill, pid, 0))
2300

    
2301

    
2302
class TestShellWriter(unittest.TestCase):
2303
  def test(self):
2304
    buf = StringIO()
2305
    sw = utils.ShellWriter(buf)
2306
    sw.Write("#!/bin/bash")
2307
    sw.Write("if true; then")
2308
    sw.IncIndent()
2309
    try:
2310
      sw.Write("echo true")
2311

    
2312
      sw.Write("for i in 1 2 3")
2313
      sw.Write("do")
2314
      sw.IncIndent()
2315
      try:
2316
        self.assertEqual(sw._indent, 2)
2317
        sw.Write("date")
2318
      finally:
2319
        sw.DecIndent()
2320
      sw.Write("done")
2321
    finally:
2322
      sw.DecIndent()
2323
    sw.Write("echo %s", utils.ShellQuote("Hello World"))
2324
    sw.Write("exit 0")
2325

    
2326
    self.assertEqual(sw._indent, 0)
2327

    
2328
    output = buf.getvalue()
2329

    
2330
    self.assert_(output.endswith("\n"))
2331

    
2332
    lines = output.splitlines()
2333
    self.assertEqual(len(lines), 9)
2334
    self.assertEqual(lines[0], "#!/bin/bash")
2335
    self.assert_(re.match(r"^\s+date$", lines[5]))
2336
    self.assertEqual(lines[7], "echo 'Hello World'")
2337

    
2338
  def testEmpty(self):
2339
    buf = StringIO()
2340
    sw = utils.ShellWriter(buf)
2341
    sw = None
2342
    self.assertEqual(buf.getvalue(), "")
2343

    
2344

    
2345
class TestCommaJoin(unittest.TestCase):
2346
  def test(self):
2347
    self.assertEqual(utils.CommaJoin([]), "")
2348
    self.assertEqual(utils.CommaJoin([1, 2, 3]), "1, 2, 3")
2349
    self.assertEqual(utils.CommaJoin(["Hello"]), "Hello")
2350
    self.assertEqual(utils.CommaJoin(["Hello", "World"]), "Hello, World")
2351
    self.assertEqual(utils.CommaJoin(["Hello", "World", 99]),
2352
                     "Hello, World, 99")
2353

    
2354

    
2355
class TestFindMatch(unittest.TestCase):
2356
  def test(self):
2357
    data = {
2358
      "aaaa": "Four A",
2359
      "bb": {"Two B": True},
2360
      re.compile(r"^x(foo|bar|bazX)([0-9]+)$"): (1, 2, 3),
2361
      }
2362

    
2363
    self.assertEqual(utils.FindMatch(data, "aaaa"), ("Four A", []))
2364
    self.assertEqual(utils.FindMatch(data, "bb"), ({"Two B": True}, []))
2365

    
2366
    for i in ["foo", "bar", "bazX"]:
2367
      for j in range(1, 100, 7):
2368
        self.assertEqual(utils.FindMatch(data, "x%s%s" % (i, j)),
2369
                         ((1, 2, 3), [i, str(j)]))
2370

    
2371
  def testNoMatch(self):
2372
    self.assert_(utils.FindMatch({}, "") is None)
2373
    self.assert_(utils.FindMatch({}, "foo") is None)
2374
    self.assert_(utils.FindMatch({}, 1234) is None)
2375

    
2376
    data = {
2377
      "X": "Hello World",
2378
      re.compile("^(something)$"): "Hello World",
2379
      }
2380

    
2381
    self.assert_(utils.FindMatch(data, "") is None)
2382
    self.assert_(utils.FindMatch(data, "Hello World") is None)
2383

    
2384

    
2385
class TestFileID(testutils.GanetiTestCase):
2386
  def testEquality(self):
2387
    name = self._CreateTempFile()
2388
    oldi = utils.GetFileID(path=name)
2389
    self.failUnless(utils.VerifyFileID(oldi, oldi))
2390

    
2391
  def testUpdate(self):
2392
    name = self._CreateTempFile()
2393
    oldi = utils.GetFileID(path=name)
2394
    os.utime(name, None)
2395
    fd = os.open(name, os.O_RDWR)
2396
    try:
2397
      newi = utils.GetFileID(fd=fd)
2398
      self.failUnless(utils.VerifyFileID(oldi, newi))
2399
      self.failUnless(utils.VerifyFileID(newi, oldi))
2400
    finally:
2401
      os.close(fd)
2402

    
2403
  def testWriteFile(self):
2404
    name = self._CreateTempFile()
2405
    oldi = utils.GetFileID(path=name)
2406
    mtime = oldi[2]
2407
    os.utime(name, (mtime + 10, mtime + 10))
2408
    self.assertRaises(errors.LockError, utils.SafeWriteFile, name,
2409
                      oldi, data="")
2410
    os.utime(name, (mtime - 10, mtime - 10))
2411
    utils.SafeWriteFile(name, oldi, data="")
2412
    oldi = utils.GetFileID(path=name)
2413
    mtime = oldi[2]
2414
    os.utime(name, (mtime + 10, mtime + 10))
2415
    # this doesn't raise, since we passed None
2416
    utils.SafeWriteFile(name, None, data="")
2417

    
2418

    
2419
class TimeMock:
2420
  def __init__(self, values):
2421
    self.values = values
2422

    
2423
  def __call__(self):
2424
    return self.values.pop(0)
2425

    
2426

    
2427
class TestRunningTimeout(unittest.TestCase):
2428
  def setUp(self):
2429
    self.time_fn = TimeMock([0.0, 0.3, 4.6, 6.5])
2430

    
2431
  def testRemainingFloat(self):
2432
    timeout = utils.RunningTimeout(5.0, True, _time_fn=self.time_fn)
2433
    self.assertAlmostEqual(timeout.Remaining(), 4.7)
2434
    self.assertAlmostEqual(timeout.Remaining(), 0.4)
2435
    self.assertAlmostEqual(timeout.Remaining(), -1.5)
2436

    
2437
  def testRemaining(self):
2438
    self.time_fn = TimeMock([0, 2, 4, 5, 6])
2439
    timeout = utils.RunningTimeout(5, True, _time_fn=self.time_fn)
2440
    self.assertEqual(timeout.Remaining(), 3)
2441
    self.assertEqual(timeout.Remaining(), 1)
2442
    self.assertEqual(timeout.Remaining(), 0)
2443
    self.assertEqual(timeout.Remaining(), -1)
2444

    
2445
  def testRemainingNonNegative(self):
2446
    timeout = utils.RunningTimeout(5.0, False, _time_fn=self.time_fn)
2447
    self.assertAlmostEqual(timeout.Remaining(), 4.7)
2448
    self.assertAlmostEqual(timeout.Remaining(), 0.4)
2449
    self.assertEqual(timeout.Remaining(), 0.0)
2450

    
2451
  def testNegativeTimeout(self):
2452
    self.assertRaises(ValueError, utils.RunningTimeout, -1.0, True)
2453

    
2454

    
2455
if __name__ == '__main__':
2456
  testutils.GanetiTestProgram()