Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ 858905fb

History | View | Annotate | Download (77 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import distutils.version
25
import errno
26
import fcntl
27
import glob
28
import os
29
import os.path
30
import re
31
import shutil
32
import signal
33
import socket
34
import stat
35
import string
36
import tempfile
37
import time
38
import unittest
39
import warnings
40
import OpenSSL
41
from cStringIO import StringIO
42

    
43
import testutils
44
from ganeti import constants
45
from ganeti import compat
46
from ganeti import utils
47
from ganeti import errors
48
from ganeti.utils import RunCmd, RemoveFile, MatchNameComponent, FormatUnit, \
49
     ParseUnit, ShellQuote, ShellQuoteArgs, ListVisibleFiles, FirstFree, \
50
     TailFile, SafeEncode, FormatTime, UnescapeAndSplit, RunParts, PathJoin, \
51
     ReadOneLineFile, SetEtcHostsEntry, RemoveEtcHostsEntry
52

    
53

    
54
class TestIsProcessAlive(unittest.TestCase):
55
  """Testing case for IsProcessAlive"""
56

    
57
  def testExists(self):
58
    mypid = os.getpid()
59
    self.assert_(utils.IsProcessAlive(mypid), "can't find myself running")
60

    
61
  def testNotExisting(self):
62
    pid_non_existing = os.fork()
63
    if pid_non_existing == 0:
64
      os._exit(0)
65
    elif pid_non_existing < 0:
66
      raise SystemError("can't fork")
67
    os.waitpid(pid_non_existing, 0)
68
    self.assertFalse(utils.IsProcessAlive(pid_non_existing),
69
                     "nonexisting process detected")
70

    
71

    
72
class TestGetProcStatusPath(unittest.TestCase):
73
  def test(self):
74
    self.assert_("/1234/" in utils._GetProcStatusPath(1234))
75
    self.assertNotEqual(utils._GetProcStatusPath(1),
76
                        utils._GetProcStatusPath(2))
77

    
78

    
79
class TestIsProcessHandlingSignal(unittest.TestCase):
80
  def setUp(self):
81
    self.tmpdir = tempfile.mkdtemp()
82

    
83
  def tearDown(self):
84
    shutil.rmtree(self.tmpdir)
85

    
86
  def testParseSigsetT(self):
87
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
88
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
89
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
90
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
91
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
92
                     set([2, 10, 32, 33]))
93
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
94
                     set([2, 32, 33]))
95
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
96
                     set([2, 28, 32, 33]))
97
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
98
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
99
                          24, 25, 26, 28, 31]))
100
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))
101

    
102
  def testGetProcStatusField(self):
103
    for field in ["SigCgt", "Name", "FDSize"]:
104
      for value in ["", "0", "cat", "  1234 KB"]:
105
        pstatus = "\n".join([
106
          "VmPeak: 999 kB",
107
          "%s: %s" % (field, value),
108
          "TracerPid: 0",
109
          ])
110
        result = utils._GetProcStatusField(pstatus, field)
111
        self.assertEqual(result, value.strip())
112

    
113
  def test(self):
114
    sp = PathJoin(self.tmpdir, "status")
115

    
116
    utils.WriteFile(sp, data="\n".join([
117
      "Name:   bash",
118
      "State:  S (sleeping)",
119
      "SleepAVG:       98%",
120
      "Pid:    22250",
121
      "PPid:   10858",
122
      "TracerPid:      0",
123
      "SigBlk: 0000000000010000",
124
      "SigIgn: 0000000000384004",
125
      "SigCgt: 000000004b813efb",
126
      "CapEff: 0000000000000000",
127
      ]))
128

    
129
    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
130

    
131
  def testNoSigCgt(self):
132
    sp = PathJoin(self.tmpdir, "status")
133

    
134
    utils.WriteFile(sp, data="\n".join([
135
      "Name:   bash",
136
      ]))
137

    
138
    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
139
                      1234, 10, status_path=sp)
140

    
141
  def testNoSuchFile(self):
142
    sp = PathJoin(self.tmpdir, "notexist")
143

    
144
    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
145

    
146
  @staticmethod
147
  def _TestRealProcess():
148
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
149
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
150
      raise Exception("SIGUSR1 is handled when it should not be")
151

    
152
    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
153
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
154
      raise Exception("SIGUSR1 is not handled when it should be")
155

    
156
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
157
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
158
      raise Exception("SIGUSR1 is not handled when it should be")
159

    
160
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
161
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
162
      raise Exception("SIGUSR1 is handled when it should not be")
163

    
164
    return True
165

    
166
  def testRealProcess(self):
167
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
168

    
169

    
170
class TestPidFileFunctions(unittest.TestCase):
171
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
172

    
173
  def setUp(self):
174
    self.dir = tempfile.mkdtemp()
175
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
176
    utils.DaemonPidFileName = self.f_dpn
177

    
178
  def testPidFileFunctions(self):
179
    pid_file = self.f_dpn('test')
180
    utils.WritePidFile('test')
181
    self.failUnless(os.path.exists(pid_file),
182
                    "PID file should have been created")
183
    read_pid = utils.ReadPidFile(pid_file)
184
    self.failUnlessEqual(read_pid, os.getpid())
185
    self.failUnless(utils.IsProcessAlive(read_pid))
186
    self.failUnlessRaises(errors.GenericError, utils.WritePidFile, 'test')
187
    utils.RemovePidFile('test')
188
    self.failIf(os.path.exists(pid_file),
189
                "PID file should not exist anymore")
190
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
191
                         "ReadPidFile should return 0 for missing pid file")
192
    fh = open(pid_file, "w")
193
    fh.write("blah\n")
194
    fh.close()
195
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
196
                         "ReadPidFile should return 0 for invalid pid file")
197
    utils.RemovePidFile('test')
198
    self.failIf(os.path.exists(pid_file),
199
                "PID file should not exist anymore")
200

    
201
  def testKill(self):
202
    pid_file = self.f_dpn('child')
203
    r_fd, w_fd = os.pipe()
204
    new_pid = os.fork()
205
    if new_pid == 0: #child
206
      utils.WritePidFile('child')
207
      os.write(w_fd, 'a')
208
      signal.pause()
209
      os._exit(0)
210
      return
211
    # else we are in the parent
212
    # wait until the child has written the pid file
213
    os.read(r_fd, 1)
214
    read_pid = utils.ReadPidFile(pid_file)
215
    self.failUnlessEqual(read_pid, new_pid)
216
    self.failUnless(utils.IsProcessAlive(new_pid))
217
    utils.KillProcess(new_pid, waitpid=True)
218
    self.failIf(utils.IsProcessAlive(new_pid))
219
    utils.RemovePidFile('child')
220
    self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)
221

    
222
  def tearDown(self):
223
    for name in os.listdir(self.dir):
224
      os.unlink(os.path.join(self.dir, name))
225
    os.rmdir(self.dir)
226

    
227

    
228
class TestRunCmd(testutils.GanetiTestCase):
229
  """Testing case for the RunCmd function"""
230

    
231
  def setUp(self):
232
    testutils.GanetiTestCase.setUp(self)
233
    self.magic = time.ctime() + " ganeti test"
234
    self.fname = self._CreateTempFile()
235

    
236
  def testOk(self):
237
    """Test successful exit code"""
238
    result = RunCmd("/bin/sh -c 'exit 0'")
239
    self.assertEqual(result.exit_code, 0)
240
    self.assertEqual(result.output, "")
241

    
242
  def testFail(self):
243
    """Test fail exit code"""
244
    result = RunCmd("/bin/sh -c 'exit 1'")
245
    self.assertEqual(result.exit_code, 1)
246
    self.assertEqual(result.output, "")
247

    
248
  def testStdout(self):
249
    """Test standard output"""
250
    cmd = 'echo -n "%s"' % self.magic
251
    result = RunCmd("/bin/sh -c '%s'" % cmd)
252
    self.assertEqual(result.stdout, self.magic)
253
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
254
    self.assertEqual(result.output, "")
255
    self.assertFileContent(self.fname, self.magic)
256

    
257
  def testStderr(self):
258
    """Test standard error"""
259
    cmd = 'echo -n "%s"' % self.magic
260
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
261
    self.assertEqual(result.stderr, self.magic)
262
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
263
    self.assertEqual(result.output, "")
264
    self.assertFileContent(self.fname, self.magic)
265

    
266
  def testCombined(self):
267
    """Test combined output"""
268
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
269
    expected = "A" + self.magic + "B" + self.magic
270
    result = RunCmd("/bin/sh -c '%s'" % cmd)
271
    self.assertEqual(result.output, expected)
272
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
273
    self.assertEqual(result.output, "")
274
    self.assertFileContent(self.fname, expected)
275

    
276
  def testSignal(self):
277
    """Test signal"""
278
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
279
    self.assertEqual(result.signal, 15)
280
    self.assertEqual(result.output, "")
281

    
282
  def testListRun(self):
283
    """Test list runs"""
284
    result = RunCmd(["true"])
285
    self.assertEqual(result.signal, None)
286
    self.assertEqual(result.exit_code, 0)
287
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
288
    self.assertEqual(result.signal, None)
289
    self.assertEqual(result.exit_code, 1)
290
    result = RunCmd(["echo", "-n", self.magic])
291
    self.assertEqual(result.signal, None)
292
    self.assertEqual(result.exit_code, 0)
293
    self.assertEqual(result.stdout, self.magic)
294

    
295
  def testFileEmptyOutput(self):
296
    """Test file output"""
297
    result = RunCmd(["true"], output=self.fname)
298
    self.assertEqual(result.signal, None)
299
    self.assertEqual(result.exit_code, 0)
300
    self.assertFileContent(self.fname, "")
301

    
302
  def testLang(self):
303
    """Test locale environment"""
304
    old_env = os.environ.copy()
305
    try:
306
      os.environ["LANG"] = "en_US.UTF-8"
307
      os.environ["LC_ALL"] = "en_US.UTF-8"
308
      result = RunCmd(["locale"])
309
      for line in result.output.splitlines():
310
        key, value = line.split("=", 1)
311
        # Ignore these variables, they're overridden by LC_ALL
312
        if key == "LANG" or key == "LANGUAGE":
313
          continue
314
        self.failIf(value and value != "C" and value != '"C"',
315
            "Variable %s is set to the invalid value '%s'" % (key, value))
316
    finally:
317
      os.environ = old_env
318

    
319
  def testDefaultCwd(self):
320
    """Test default working directory"""
321
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
322

    
323
  def testCwd(self):
324
    """Test default working directory"""
325
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
326
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
327
    cwd = os.getcwd()
328
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
329

    
330
  def testResetEnv(self):
331
    """Test environment reset functionality"""
332
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
333
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
334
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
335

    
336

    
337
class TestRunParts(unittest.TestCase):
338
  """Testing case for the RunParts function"""
339

    
340
  def setUp(self):
341
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
342

    
343
  def tearDown(self):
344
    shutil.rmtree(self.rundir)
345

    
346
  def testEmpty(self):
347
    """Test on an empty dir"""
348
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
349

    
350
  def testSkipWrongName(self):
351
    """Test that wrong files are skipped"""
352
    fname = os.path.join(self.rundir, "00test.dot")
353
    utils.WriteFile(fname, data="")
354
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
355
    relname = os.path.basename(fname)
356
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
357
                         [(relname, constants.RUNPARTS_SKIP, None)])
358

    
359
  def testSkipNonExec(self):
360
    """Test that non executable files are skipped"""
361
    fname = os.path.join(self.rundir, "00test")
362
    utils.WriteFile(fname, data="")
363
    relname = os.path.basename(fname)
364
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
365
                         [(relname, constants.RUNPARTS_SKIP, None)])
366

    
367
  def testError(self):
368
    """Test error on a broken executable"""
369
    fname = os.path.join(self.rundir, "00test")
370
    utils.WriteFile(fname, data="")
371
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
372
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
373
    self.failUnlessEqual(relname, os.path.basename(fname))
374
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
375
    self.failUnless(error)
376

    
377
  def testSorted(self):
378
    """Test executions are sorted"""
379
    files = []
380
    files.append(os.path.join(self.rundir, "64test"))
381
    files.append(os.path.join(self.rundir, "00test"))
382
    files.append(os.path.join(self.rundir, "42test"))
383

    
384
    for fname in files:
385
      utils.WriteFile(fname, data="")
386

    
387
    results = RunParts(self.rundir, reset_env=True)
388

    
389
    for fname in sorted(files):
390
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
391

    
392
  def testOk(self):
393
    """Test correct execution"""
394
    fname = os.path.join(self.rundir, "00test")
395
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
396
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
397
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
398
    self.failUnlessEqual(relname, os.path.basename(fname))
399
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
400
    self.failUnlessEqual(runresult.stdout, "ciao")
401

    
402
  def testRunFail(self):
403
    """Test correct execution, with run failure"""
404
    fname = os.path.join(self.rundir, "00test")
405
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
406
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
407
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
408
    self.failUnlessEqual(relname, os.path.basename(fname))
409
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
410
    self.failUnlessEqual(runresult.exit_code, 1)
411
    self.failUnless(runresult.failed)
412

    
413
  def testRunMix(self):
414
    files = []
415
    files.append(os.path.join(self.rundir, "00test"))
416
    files.append(os.path.join(self.rundir, "42test"))
417
    files.append(os.path.join(self.rundir, "64test"))
418
    files.append(os.path.join(self.rundir, "99test"))
419

    
420
    files.sort()
421

    
422
    # 1st has errors in execution
423
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
424
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
425

    
426
    # 2nd is skipped
427
    utils.WriteFile(files[1], data="")
428

    
429
    # 3rd cannot execute properly
430
    utils.WriteFile(files[2], data="")
431
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
432

    
433
    # 4th execs
434
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
435
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
436

    
437
    results = RunParts(self.rundir, reset_env=True)
438

    
439
    (relname, status, runresult) = results[0]
440
    self.failUnlessEqual(relname, os.path.basename(files[0]))
441
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
442
    self.failUnlessEqual(runresult.exit_code, 1)
443
    self.failUnless(runresult.failed)
444

    
445
    (relname, status, runresult) = results[1]
446
    self.failUnlessEqual(relname, os.path.basename(files[1]))
447
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
448
    self.failUnlessEqual(runresult, None)
449

    
450
    (relname, status, runresult) = results[2]
451
    self.failUnlessEqual(relname, os.path.basename(files[2]))
452
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
453
    self.failUnless(runresult)
454

    
455
    (relname, status, runresult) = results[3]
456
    self.failUnlessEqual(relname, os.path.basename(files[3]))
457
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
458
    self.failUnlessEqual(runresult.output, "ciao")
459
    self.failUnlessEqual(runresult.exit_code, 0)
460
    self.failUnless(not runresult.failed)
461

    
462

    
463
class TestStartDaemon(testutils.GanetiTestCase):
464
  def setUp(self):
465
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
466
    self.tmpfile = os.path.join(self.tmpdir, "test")
467

    
468
  def tearDown(self):
469
    shutil.rmtree(self.tmpdir)
470

    
471
  def testShell(self):
472
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
473
    self._wait(self.tmpfile, 60.0, "Hello World")
474

    
475
  def testShellOutput(self):
476
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
477
    self._wait(self.tmpfile, 60.0, "Hello World")
478

    
479
  def testNoShellNoOutput(self):
480
    utils.StartDaemon(["pwd"])
481

    
482
  def testNoShellNoOutputTouch(self):
483
    testfile = os.path.join(self.tmpdir, "check")
484
    self.failIf(os.path.exists(testfile))
485
    utils.StartDaemon(["touch", testfile])
486
    self._wait(testfile, 60.0, "")
487

    
488
  def testNoShellOutput(self):
489
    utils.StartDaemon(["pwd"], output=self.tmpfile)
490
    self._wait(self.tmpfile, 60.0, "/")
491

    
492
  def testNoShellOutputCwd(self):
493
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
494
    self._wait(self.tmpfile, 60.0, os.getcwd())
495

    
496
  def testShellEnv(self):
497
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
498
                      env={ "GNT_TEST_VAR": "Hello World", })
499
    self._wait(self.tmpfile, 60.0, "Hello World")
500

    
501
  def testNoShellEnv(self):
502
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
503
                      env={ "GNT_TEST_VAR": "Hello World", })
504
    self._wait(self.tmpfile, 60.0, "Hello World")
505

    
506
  def testOutputFd(self):
507
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
508
    try:
509
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
510
    finally:
511
      os.close(fd)
512
    self._wait(self.tmpfile, 60.0, os.getcwd())
513

    
514
  def testPid(self):
515
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
516
    self._wait(self.tmpfile, 60.0, str(pid))
517

    
518
  def testPidFile(self):
519
    pidfile = os.path.join(self.tmpdir, "pid")
520
    checkfile = os.path.join(self.tmpdir, "abort")
521

    
522
    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
523
                            output=self.tmpfile)
524
    try:
525
      fd = os.open(pidfile, os.O_RDONLY)
526
      try:
527
        # Check file is locked
528
        self.assertRaises(errors.LockError, utils.LockFile, fd)
529

    
530
        pidtext = os.read(fd, 100)
531
      finally:
532
        os.close(fd)
533

    
534
      self.assertEqual(int(pidtext.strip()), pid)
535

    
536
      self.assert_(utils.IsProcessAlive(pid))
537
    finally:
538
      # No matter what happens, kill daemon
539
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
540
      self.failIf(utils.IsProcessAlive(pid))
541

    
542
    self.assertEqual(utils.ReadFile(self.tmpfile), "")
543

    
544
  def _wait(self, path, timeout, expected):
545
    # Due to the asynchronous nature of daemon processes, polling is necessary.
546
    # A timeout makes sure the test doesn't hang forever.
547
    def _CheckFile():
548
      if not (os.path.isfile(path) and
549
              utils.ReadFile(path).strip() == expected):
550
        raise utils.RetryAgain()
551

    
552
    try:
553
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
554
    except utils.RetryTimeout:
555
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
556
                " didn't write the correct output" % timeout)
557

    
558
  def testError(self):
559
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
560
                      ["./does-NOT-EXIST/here/0123456789"])
561
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
562
                      ["./does-NOT-EXIST/here/0123456789"],
563
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
564
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
565
                      ["./does-NOT-EXIST/here/0123456789"],
566
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
567
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
568
                      ["./does-NOT-EXIST/here/0123456789"],
569
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
570

    
571
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
572
    try:
573
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
574
                        ["./does-NOT-EXIST/here/0123456789"],
575
                        output=self.tmpfile, output_fd=fd)
576
    finally:
577
      os.close(fd)
578

    
579

    
580
class TestSetCloseOnExecFlag(unittest.TestCase):
581
  """Tests for SetCloseOnExecFlag"""
582

    
583
  def setUp(self):
584
    self.tmpfile = tempfile.TemporaryFile()
585

    
586
  def testEnable(self):
587
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
588
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
589
                    fcntl.FD_CLOEXEC)
590

    
591
  def testDisable(self):
592
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
593
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
594
                fcntl.FD_CLOEXEC)
595

    
596

    
597
class TestSetNonblockFlag(unittest.TestCase):
598
  def setUp(self):
599
    self.tmpfile = tempfile.TemporaryFile()
600

    
601
  def testEnable(self):
602
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
603
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
604
                    os.O_NONBLOCK)
605

    
606
  def testDisable(self):
607
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
608
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
609
                os.O_NONBLOCK)
610

    
611

    
612
class TestRemoveFile(unittest.TestCase):
613
  """Test case for the RemoveFile function"""
614

    
615
  def setUp(self):
616
    """Create a temp dir and file for each case"""
617
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
618
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
619
    os.close(fd)
620

    
621
  def tearDown(self):
622
    if os.path.exists(self.tmpfile):
623
      os.unlink(self.tmpfile)
624
    os.rmdir(self.tmpdir)
625

    
626
  def testIgnoreDirs(self):
627
    """Test that RemoveFile() ignores directories"""
628
    self.assertEqual(None, RemoveFile(self.tmpdir))
629

    
630
  def testIgnoreNotExisting(self):
631
    """Test that RemoveFile() ignores non-existing files"""
632
    RemoveFile(self.tmpfile)
633
    RemoveFile(self.tmpfile)
634

    
635
  def testRemoveFile(self):
636
    """Test that RemoveFile does remove a file"""
637
    RemoveFile(self.tmpfile)
638
    if os.path.exists(self.tmpfile):
639
      self.fail("File '%s' not removed" % self.tmpfile)
640

    
641
  def testRemoveSymlink(self):
642
    """Test that RemoveFile does remove symlinks"""
643
    symlink = self.tmpdir + "/symlink"
644
    os.symlink("no-such-file", symlink)
645
    RemoveFile(symlink)
646
    if os.path.exists(symlink):
647
      self.fail("File '%s' not removed" % symlink)
648
    os.symlink(self.tmpfile, symlink)
649
    RemoveFile(symlink)
650
    if os.path.exists(symlink):
651
      self.fail("File '%s' not removed" % symlink)
652

    
653

    
654
class TestRename(unittest.TestCase):
655
  """Test case for RenameFile"""
656

    
657
  def setUp(self):
658
    """Create a temporary directory"""
659
    self.tmpdir = tempfile.mkdtemp()
660
    self.tmpfile = os.path.join(self.tmpdir, "test1")
661

    
662
    # Touch the file
663
    open(self.tmpfile, "w").close()
664

    
665
  def tearDown(self):
666
    """Remove temporary directory"""
667
    shutil.rmtree(self.tmpdir)
668

    
669
  def testSimpleRename1(self):
670
    """Simple rename 1"""
671
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
672
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
673

    
674
  def testSimpleRename2(self):
675
    """Simple rename 2"""
676
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
677
                     mkdir=True)
678
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
679

    
680
  def testRenameMkdir(self):
681
    """Rename with mkdir"""
682
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
683
                     mkdir=True)
684
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
685
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
686

    
687
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
688
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
689
                     mkdir=True)
690
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
691
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
692
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
693

    
694

    
695
class TestMatchNameComponent(unittest.TestCase):
696
  """Test case for the MatchNameComponent function"""
697

    
698
  def testEmptyList(self):
699
    """Test that there is no match against an empty list"""
700

    
701
    self.failUnlessEqual(MatchNameComponent("", []), None)
702
    self.failUnlessEqual(MatchNameComponent("test", []), None)
703

    
704
  def testSingleMatch(self):
705
    """Test that a single match is performed correctly"""
706
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
707
    for key in "test2", "test2.example", "test2.example.com":
708
      self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
709

    
710
  def testMultipleMatches(self):
711
    """Test that a multiple match is returned as None"""
712
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
713
    for key in "test1", "test1.example":
714
      self.failUnlessEqual(MatchNameComponent(key, mlist), None)
715

    
716
  def testFullMatch(self):
717
    """Test that a full match is returned correctly"""
718
    key1 = "test1"
719
    key2 = "test1.example"
720
    mlist = [key2, key2 + ".com"]
721
    self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
722
    self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
723

    
724
  def testCaseInsensitivePartialMatch(self):
725
    """Test for the case_insensitive keyword"""
726
    mlist = ["test1.example.com", "test2.example.net"]
727
    self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
728
                     "test2.example.net")
729
    self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
730
                     "test2.example.net")
731
    self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
732
                     "test2.example.net")
733
    self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
734
                     "test2.example.net")
735

    
736

    
737
  def testCaseInsensitiveFullMatch(self):
738
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
739
    # Between the two ts1 a full string match non-case insensitive should work
740
    self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
741
                     None)
742
    self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
743
                     "ts1.ex")
744
    self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
745
                     "ts1.ex")
746
    # Between the two ts2 only case differs, so only case-match works
747
    self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
748
                     "ts2.ex")
749
    self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
750
                     "Ts2.ex")
751
    self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
752
                     None)
753

    
754

    
755
class TestReadFile(testutils.GanetiTestCase):
756

    
757
  def testReadAll(self):
758
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
759
    self.assertEqual(len(data), 814)
760

    
761
    h = compat.md5_hash()
762
    h.update(data)
763
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
764

    
765
  def testReadSize(self):
766
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
767
                          size=100)
768
    self.assertEqual(len(data), 100)
769

    
770
    h = compat.md5_hash()
771
    h.update(data)
772
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
773

    
774
  def testError(self):
775
    self.assertRaises(EnvironmentError, utils.ReadFile,
776
                      "/dev/null/does-not-exist")
777

    
778

    
779
class TestReadOneLineFile(testutils.GanetiTestCase):
780

    
781
  def setUp(self):
782
    testutils.GanetiTestCase.setUp(self)
783

    
784
  def testDefault(self):
785
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
786
    self.assertEqual(len(data), 27)
787
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
788

    
789
  def testNotStrict(self):
790
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
791
    self.assertEqual(len(data), 27)
792
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
793

    
794
  def testStrictFailure(self):
795
    self.assertRaises(errors.GenericError, ReadOneLineFile,
796
                      self._TestDataFilename("cert1.pem"), strict=True)
797

    
798
  def testLongLine(self):
799
    dummydata = (1024 * "Hello World! ")
800
    myfile = self._CreateTempFile()
801
    utils.WriteFile(myfile, data=dummydata)
802
    datastrict = ReadOneLineFile(myfile, strict=True)
803
    datalax = ReadOneLineFile(myfile, strict=False)
804
    self.assertEqual(dummydata, datastrict)
805
    self.assertEqual(dummydata, datalax)
806

    
807
  def testNewline(self):
808
    myfile = self._CreateTempFile()
809
    myline = "myline"
810
    for nl in ["", "\n", "\r\n"]:
811
      dummydata = "%s%s" % (myline, nl)
812
      utils.WriteFile(myfile, data=dummydata)
813
      datalax = ReadOneLineFile(myfile, strict=False)
814
      self.assertEqual(myline, datalax)
815
      datastrict = ReadOneLineFile(myfile, strict=True)
816
      self.assertEqual(myline, datastrict)
817

    
818
  def testWhitespaceAndMultipleLines(self):
819
    myfile = self._CreateTempFile()
820
    for nl in ["", "\n", "\r\n"]:
821
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
822
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
823
        utils.WriteFile(myfile, data=dummydata)
824
        datalax = ReadOneLineFile(myfile, strict=False)
825
        if nl:
826
          self.assert_(set("\r\n") & set(dummydata))
827
          self.assertRaises(errors.GenericError, ReadOneLineFile,
828
                            myfile, strict=True)
829
          explen = len("Foo bar baz ") + len(ws)
830
          self.assertEqual(len(datalax), explen)
831
          self.assertEqual(datalax, dummydata[:explen])
832
          self.assertFalse(set("\r\n") & set(datalax))
833
        else:
834
          datastrict = ReadOneLineFile(myfile, strict=True)
835
          self.assertEqual(dummydata, datastrict)
836
          self.assertEqual(dummydata, datalax)
837

    
838
  def testEmptylines(self):
839
    myfile = self._CreateTempFile()
840
    myline = "myline"
841
    for nl in ["\n", "\r\n"]:
842
      for ol in ["", "otherline"]:
843
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
844
        utils.WriteFile(myfile, data=dummydata)
845
        self.assert_(set("\r\n") & set(dummydata))
846
        datalax = ReadOneLineFile(myfile, strict=False)
847
        self.assertEqual(myline, datalax)
848
        if ol:
849
          self.assertRaises(errors.GenericError, ReadOneLineFile,
850
                            myfile, strict=True)
851
        else:
852
          datastrict = ReadOneLineFile(myfile, strict=True)
853
          self.assertEqual(myline, datastrict)
854

    
855

    
856
class TestTimestampForFilename(unittest.TestCase):
857
  def test(self):
858
    self.assert_("." not in utils.TimestampForFilename())
859
    self.assert_(":" not in utils.TimestampForFilename())
860

    
861

    
862
class TestCreateBackup(testutils.GanetiTestCase):
863
  def setUp(self):
864
    testutils.GanetiTestCase.setUp(self)
865

    
866
    self.tmpdir = tempfile.mkdtemp()
867

    
868
  def tearDown(self):
869
    testutils.GanetiTestCase.tearDown(self)
870

    
871
    shutil.rmtree(self.tmpdir)
872

    
873
  def testEmpty(self):
874
    filename = PathJoin(self.tmpdir, "config.data")
875
    utils.WriteFile(filename, data="")
876
    bname = utils.CreateBackup(filename)
877
    self.assertFileContent(bname, "")
878
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
879
    utils.CreateBackup(filename)
880
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
881
    utils.CreateBackup(filename)
882
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
883

    
884
    fifoname = PathJoin(self.tmpdir, "fifo")
885
    os.mkfifo(fifoname)
886
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
887

    
888
  def testContent(self):
889
    bkpcount = 0
890
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
891
      for rep in [1, 2, 10, 127]:
892
        testdata = data * rep
893

    
894
        filename = PathJoin(self.tmpdir, "test.data_")
895
        utils.WriteFile(filename, data=testdata)
896
        self.assertFileContent(filename, testdata)
897

    
898
        for _ in range(3):
899
          bname = utils.CreateBackup(filename)
900
          bkpcount += 1
901
          self.assertFileContent(bname, testdata)
902
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
903

    
904

    
905
class TestFormatUnit(unittest.TestCase):
906
  """Test case for the FormatUnit function"""
907

    
908
  def testMiB(self):
909
    self.assertEqual(FormatUnit(1, 'h'), '1M')
910
    self.assertEqual(FormatUnit(100, 'h'), '100M')
911
    self.assertEqual(FormatUnit(1023, 'h'), '1023M')
912

    
913
    self.assertEqual(FormatUnit(1, 'm'), '1')
914
    self.assertEqual(FormatUnit(100, 'm'), '100')
915
    self.assertEqual(FormatUnit(1023, 'm'), '1023')
916

    
917
    self.assertEqual(FormatUnit(1024, 'm'), '1024')
918
    self.assertEqual(FormatUnit(1536, 'm'), '1536')
919
    self.assertEqual(FormatUnit(17133, 'm'), '17133')
920
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
921

    
922
  def testGiB(self):
923
    self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
924
    self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
925
    self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
926
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
927

    
928
    self.assertEqual(FormatUnit(1024, 'g'), '1.0')
929
    self.assertEqual(FormatUnit(1536, 'g'), '1.5')
930
    self.assertEqual(FormatUnit(17133, 'g'), '16.7')
931
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
932

    
933
    self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
934
    self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
935
    self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
936

    
937
  def testTiB(self):
938
    self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
939
    self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
940
    self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
941

    
942
    self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
943
    self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
944
    self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
945

    
946

    
947
class TestParseUnit(unittest.TestCase):
948
  """Test case for the ParseUnit function"""
949

    
950
  SCALES = (('', 1),
951
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
952
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
953
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
954

    
955
  def testRounding(self):
956
    self.assertEqual(ParseUnit('0'), 0)
957
    self.assertEqual(ParseUnit('1'), 4)
958
    self.assertEqual(ParseUnit('2'), 4)
959
    self.assertEqual(ParseUnit('3'), 4)
960

    
961
    self.assertEqual(ParseUnit('124'), 124)
962
    self.assertEqual(ParseUnit('125'), 128)
963
    self.assertEqual(ParseUnit('126'), 128)
964
    self.assertEqual(ParseUnit('127'), 128)
965
    self.assertEqual(ParseUnit('128'), 128)
966
    self.assertEqual(ParseUnit('129'), 132)
967
    self.assertEqual(ParseUnit('130'), 132)
968

    
969
  def testFloating(self):
970
    self.assertEqual(ParseUnit('0'), 0)
971
    self.assertEqual(ParseUnit('0.5'), 4)
972
    self.assertEqual(ParseUnit('1.75'), 4)
973
    self.assertEqual(ParseUnit('1.99'), 4)
974
    self.assertEqual(ParseUnit('2.00'), 4)
975
    self.assertEqual(ParseUnit('2.01'), 4)
976
    self.assertEqual(ParseUnit('3.99'), 4)
977
    self.assertEqual(ParseUnit('4.00'), 4)
978
    self.assertEqual(ParseUnit('4.01'), 8)
979
    self.assertEqual(ParseUnit('1.5G'), 1536)
980
    self.assertEqual(ParseUnit('1.8G'), 1844)
981
    self.assertEqual(ParseUnit('8.28T'), 8682212)
982

    
983
  def testSuffixes(self):
984
    for sep in ('', ' ', '   ', "\t", "\t "):
985
      for suffix, scale in TestParseUnit.SCALES:
986
        for func in (lambda x: x, str.lower, str.upper):
987
          self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
988
                           1024 * scale)
989

    
990
  def testInvalidInput(self):
991
    for sep in ('-', '_', ',', 'a'):
992
      for suffix, _ in TestParseUnit.SCALES:
993
        self.assertRaises(errors.UnitParseError, ParseUnit, '1' + sep + suffix)
994

    
995
    for suffix, _ in TestParseUnit.SCALES:
996
      self.assertRaises(errors.UnitParseError, ParseUnit, '1,3' + suffix)
997

    
998

    
999
class TestSshKeys(testutils.GanetiTestCase):
1000
  """Test case for the AddAuthorizedKey function"""
1001

    
1002
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
1003
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
1004
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
1005

    
1006
  def setUp(self):
1007
    testutils.GanetiTestCase.setUp(self)
1008
    self.tmpname = self._CreateTempFile()
1009
    handle = open(self.tmpname, 'w')
1010
    try:
1011
      handle.write("%s\n" % TestSshKeys.KEY_A)
1012
      handle.write("%s\n" % TestSshKeys.KEY_B)
1013
    finally:
1014
      handle.close()
1015

    
1016
  def testAddingNewKey(self):
1017
    utils.AddAuthorizedKey(self.tmpname,
1018
                           'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
1019

    
1020
    self.assertFileContent(self.tmpname,
1021
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1022
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1023
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1024
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
1025

    
1026
  def testAddingAlmostButNotCompletelyTheSameKey(self):
1027
    utils.AddAuthorizedKey(self.tmpname,
1028
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
1029

    
1030
    self.assertFileContent(self.tmpname,
1031
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1032
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1033
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1034
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
1035

    
1036
  def testAddingExistingKeyWithSomeMoreSpaces(self):
1037
    utils.AddAuthorizedKey(self.tmpname,
1038
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1039

    
1040
    self.assertFileContent(self.tmpname,
1041
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1042
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1043
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1044

    
1045
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
1046
    utils.RemoveAuthorizedKey(self.tmpname,
1047
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1048

    
1049
    self.assertFileContent(self.tmpname,
1050
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1051
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1052

    
1053
  def testRemovingNonExistingKey(self):
1054
    utils.RemoveAuthorizedKey(self.tmpname,
1055
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
1056

    
1057
    self.assertFileContent(self.tmpname,
1058
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1059
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1060
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1061

    
1062

    
1063
class TestEtcHosts(testutils.GanetiTestCase):
1064
  """Test functions modifying /etc/hosts"""
1065

    
1066
  def setUp(self):
1067
    testutils.GanetiTestCase.setUp(self)
1068
    self.tmpname = self._CreateTempFile()
1069
    handle = open(self.tmpname, 'w')
1070
    try:
1071
      handle.write('# This is a test file for /etc/hosts\n')
1072
      handle.write('127.0.0.1\tlocalhost\n')
1073
      handle.write('192.0.2.1 router gw\n')
1074
    finally:
1075
      handle.close()
1076

    
1077
  def testSettingNewIp(self):
1078
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost.example.com',
1079
                     ['myhost'])
1080

    
1081
    self.assertFileContent(self.tmpname,
1082
      "# This is a test file for /etc/hosts\n"
1083
      "127.0.0.1\tlocalhost\n"
1084
      "192.0.2.1 router gw\n"
1085
      "198.51.100.4\tmyhost.example.com myhost\n")
1086
    self.assertFileMode(self.tmpname, 0644)
1087

    
1088
  def testSettingExistingIp(self):
1089
    SetEtcHostsEntry(self.tmpname, '192.0.2.1', 'myhost.example.com',
1090
                     ['myhost'])
1091

    
1092
    self.assertFileContent(self.tmpname,
1093
      "# This is a test file for /etc/hosts\n"
1094
      "127.0.0.1\tlocalhost\n"
1095
      "192.0.2.1\tmyhost.example.com myhost\n")
1096
    self.assertFileMode(self.tmpname, 0644)
1097

    
1098
  def testSettingDuplicateName(self):
1099
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost', ['myhost'])
1100

    
1101
    self.assertFileContent(self.tmpname,
1102
      "# This is a test file for /etc/hosts\n"
1103
      "127.0.0.1\tlocalhost\n"
1104
      "192.0.2.1 router gw\n"
1105
      "198.51.100.4\tmyhost\n")
1106
    self.assertFileMode(self.tmpname, 0644)
1107

    
1108
  def testRemovingExistingHost(self):
1109
    RemoveEtcHostsEntry(self.tmpname, 'router')
1110

    
1111
    self.assertFileContent(self.tmpname,
1112
      "# This is a test file for /etc/hosts\n"
1113
      "127.0.0.1\tlocalhost\n"
1114
      "192.0.2.1 gw\n")
1115
    self.assertFileMode(self.tmpname, 0644)
1116

    
1117
  def testRemovingSingleExistingHost(self):
1118
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
1119

    
1120
    self.assertFileContent(self.tmpname,
1121
      "# This is a test file for /etc/hosts\n"
1122
      "192.0.2.1 router gw\n")
1123
    self.assertFileMode(self.tmpname, 0644)
1124

    
1125
  def testRemovingNonExistingHost(self):
1126
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
1127

    
1128
    self.assertFileContent(self.tmpname,
1129
      "# This is a test file for /etc/hosts\n"
1130
      "127.0.0.1\tlocalhost\n"
1131
      "192.0.2.1 router gw\n")
1132
    self.assertFileMode(self.tmpname, 0644)
1133

    
1134
  def testRemovingAlias(self):
1135
    RemoveEtcHostsEntry(self.tmpname, 'gw')
1136

    
1137
    self.assertFileContent(self.tmpname,
1138
      "# This is a test file for /etc/hosts\n"
1139
      "127.0.0.1\tlocalhost\n"
1140
      "192.0.2.1 router\n")
1141
    self.assertFileMode(self.tmpname, 0644)
1142

    
1143

    
1144
class TestGetMounts(unittest.TestCase):
1145
  """Test case for GetMounts()."""
1146

    
1147
  TESTDATA = (
1148
    "rootfs /     rootfs rw 0 0\n"
1149
    "none   /sys  sysfs  rw,nosuid,nodev,noexec,relatime 0 0\n"
1150
    "none   /proc proc   rw,nosuid,nodev,noexec,relatime 0 0\n")
1151

    
1152
  def setUp(self):
1153
    self.tmpfile = tempfile.NamedTemporaryFile()
1154
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1155

    
1156
  def testGetMounts(self):
1157
    self.assertEqual(utils.GetMounts(filename=self.tmpfile.name),
1158
      [
1159
        ("rootfs", "/", "rootfs", "rw"),
1160
        ("none", "/sys", "sysfs", "rw,nosuid,nodev,noexec,relatime"),
1161
        ("none", "/proc", "proc", "rw,nosuid,nodev,noexec,relatime"),
1162
      ])
1163

    
1164

    
1165
class TestShellQuoting(unittest.TestCase):
1166
  """Test case for shell quoting functions"""
1167

    
1168
  def testShellQuote(self):
1169
    self.assertEqual(ShellQuote('abc'), "abc")
1170
    self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
1171
    self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
1172
    self.assertEqual(ShellQuote("a b c"), "'a b c'")
1173
    self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
1174

    
1175
  def testShellQuoteArgs(self):
1176
    self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
1177
    self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
1178
    self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
1179

    
1180

    
1181
class TestListVisibleFiles(unittest.TestCase):
1182
  """Test case for ListVisibleFiles"""
1183

    
1184
  def setUp(self):
1185
    self.path = tempfile.mkdtemp()
1186

    
1187
  def tearDown(self):
1188
    shutil.rmtree(self.path)
1189

    
1190
  def _CreateFiles(self, files):
1191
    for name in files:
1192
      utils.WriteFile(os.path.join(self.path, name), data="test")
1193

    
1194
  def _test(self, files, expected):
1195
    self._CreateFiles(files)
1196
    found = ListVisibleFiles(self.path)
1197
    self.assertEqual(set(found), set(expected))
1198

    
1199
  def testAllVisible(self):
1200
    files = ["a", "b", "c"]
1201
    expected = files
1202
    self._test(files, expected)
1203

    
1204
  def testNoneVisible(self):
1205
    files = [".a", ".b", ".c"]
1206
    expected = []
1207
    self._test(files, expected)
1208

    
1209
  def testSomeVisible(self):
1210
    files = ["a", "b", ".c"]
1211
    expected = ["a", "b"]
1212
    self._test(files, expected)
1213

    
1214
  def testNonAbsolutePath(self):
1215
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1216

    
1217
  def testNonNormalizedPath(self):
1218
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1219
                          "/bin/../tmp")
1220

    
1221

    
1222
class TestNewUUID(unittest.TestCase):
1223
  """Test case for NewUUID"""
1224

    
1225
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
1226
                        '[a-f0-9]{4}-[a-f0-9]{12}$')
1227

    
1228
  def runTest(self):
1229
    self.failUnless(self._re_uuid.match(utils.NewUUID()))
1230

    
1231

    
1232
class TestUniqueSequence(unittest.TestCase):
1233
  """Test case for UniqueSequence"""
1234

    
1235
  def _test(self, input, expected):
1236
    self.assertEqual(utils.UniqueSequence(input), expected)
1237

    
1238
  def runTest(self):
1239
    # Ordered input
1240
    self._test([1, 2, 3], [1, 2, 3])
1241
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
1242
    self._test([1, 2, 2, 3], [1, 2, 3])
1243
    self._test([1, 2, 3, 3], [1, 2, 3])
1244

    
1245
    # Unordered input
1246
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
1247
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
1248

    
1249
    # Strings
1250
    self._test(["a", "a"], ["a"])
1251
    self._test(["a", "b"], ["a", "b"])
1252
    self._test(["a", "b", "a"], ["a", "b"])
1253

    
1254

    
1255
class TestFirstFree(unittest.TestCase):
1256
  """Test case for the FirstFree function"""
1257

    
1258
  def test(self):
1259
    """Test FirstFree"""
1260
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1261
    self.failUnlessEqual(FirstFree([]), None)
1262
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1263
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1264
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1265

    
1266

    
1267
class TestTailFile(testutils.GanetiTestCase):
1268
  """Test case for the TailFile function"""
1269

    
1270
  def testEmpty(self):
1271
    fname = self._CreateTempFile()
1272
    self.failUnlessEqual(TailFile(fname), [])
1273
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1274

    
1275
  def testAllLines(self):
1276
    data = ["test %d" % i for i in range(30)]
1277
    for i in range(30):
1278
      fname = self._CreateTempFile()
1279
      fd = open(fname, "w")
1280
      fd.write("\n".join(data[:i]))
1281
      if i > 0:
1282
        fd.write("\n")
1283
      fd.close()
1284
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1285

    
1286
  def testPartialLines(self):
1287
    data = ["test %d" % i for i in range(30)]
1288
    fname = self._CreateTempFile()
1289
    fd = open(fname, "w")
1290
    fd.write("\n".join(data))
1291
    fd.write("\n")
1292
    fd.close()
1293
    for i in range(1, 30):
1294
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1295

    
1296
  def testBigFile(self):
1297
    data = ["test %d" % i for i in range(30)]
1298
    fname = self._CreateTempFile()
1299
    fd = open(fname, "w")
1300
    fd.write("X" * 1048576)
1301
    fd.write("\n")
1302
    fd.write("\n".join(data))
1303
    fd.write("\n")
1304
    fd.close()
1305
    for i in range(1, 30):
1306
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1307

    
1308

    
1309
class _BaseFileLockTest:
1310
  """Test case for the FileLock class"""
1311

    
1312
  def testSharedNonblocking(self):
1313
    self.lock.Shared(blocking=False)
1314
    self.lock.Close()
1315

    
1316
  def testExclusiveNonblocking(self):
1317
    self.lock.Exclusive(blocking=False)
1318
    self.lock.Close()
1319

    
1320
  def testUnlockNonblocking(self):
1321
    self.lock.Unlock(blocking=False)
1322
    self.lock.Close()
1323

    
1324
  def testSharedBlocking(self):
1325
    self.lock.Shared(blocking=True)
1326
    self.lock.Close()
1327

    
1328
  def testExclusiveBlocking(self):
1329
    self.lock.Exclusive(blocking=True)
1330
    self.lock.Close()
1331

    
1332
  def testUnlockBlocking(self):
1333
    self.lock.Unlock(blocking=True)
1334
    self.lock.Close()
1335

    
1336
  def testSharedExclusiveUnlock(self):
1337
    self.lock.Shared(blocking=False)
1338
    self.lock.Exclusive(blocking=False)
1339
    self.lock.Unlock(blocking=False)
1340
    self.lock.Close()
1341

    
1342
  def testExclusiveSharedUnlock(self):
1343
    self.lock.Exclusive(blocking=False)
1344
    self.lock.Shared(blocking=False)
1345
    self.lock.Unlock(blocking=False)
1346
    self.lock.Close()
1347

    
1348
  def testSimpleTimeout(self):
1349
    # These will succeed on the first attempt, hence a short timeout
1350
    self.lock.Shared(blocking=True, timeout=10.0)
1351
    self.lock.Exclusive(blocking=False, timeout=10.0)
1352
    self.lock.Unlock(blocking=True, timeout=10.0)
1353
    self.lock.Close()
1354

    
1355
  @staticmethod
1356
  def _TryLockInner(filename, shared, blocking):
1357
    lock = utils.FileLock.Open(filename)
1358

    
1359
    if shared:
1360
      fn = lock.Shared
1361
    else:
1362
      fn = lock.Exclusive
1363

    
1364
    try:
1365
      # The timeout doesn't really matter as the parent process waits for us to
1366
      # finish anyway.
1367
      fn(blocking=blocking, timeout=0.01)
1368
    except errors.LockError, err:
1369
      return False
1370

    
1371
    return True
1372

    
1373
  def _TryLock(self, *args):
1374
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1375
                                      *args)
1376

    
1377
  def testTimeout(self):
1378
    for blocking in [True, False]:
1379
      self.lock.Exclusive(blocking=True)
1380
      self.failIf(self._TryLock(False, blocking))
1381
      self.failIf(self._TryLock(True, blocking))
1382

    
1383
      self.lock.Shared(blocking=True)
1384
      self.assert_(self._TryLock(True, blocking))
1385
      self.failIf(self._TryLock(False, blocking))
1386

    
1387
  def testCloseShared(self):
1388
    self.lock.Close()
1389
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1390

    
1391
  def testCloseExclusive(self):
1392
    self.lock.Close()
1393
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1394

    
1395
  def testCloseUnlock(self):
1396
    self.lock.Close()
1397
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1398

    
1399

    
1400
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1401
  TESTDATA = "Hello World\n" * 10
1402

    
1403
  def setUp(self):
1404
    testutils.GanetiTestCase.setUp(self)
1405

    
1406
    self.tmpfile = tempfile.NamedTemporaryFile()
1407
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1408
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1409

    
1410
    # Ensure "Open" didn't truncate file
1411
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1412

    
1413
  def tearDown(self):
1414
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1415

    
1416
    testutils.GanetiTestCase.tearDown(self)
1417

    
1418

    
1419
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1420
  def setUp(self):
1421
    self.tmpfile = tempfile.NamedTemporaryFile()
1422
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1423

    
1424

    
1425
class TestTimeFunctions(unittest.TestCase):
1426
  """Test case for time functions"""
1427

    
1428
  def runTest(self):
1429
    self.assertEqual(utils.SplitTime(1), (1, 0))
1430
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1431
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1432
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1433
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1434
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1435
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1436
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1437

    
1438
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1439

    
1440
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1441
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1442
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1443

    
1444
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1445
                     1218448917.481)
1446
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1447

    
1448
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1449
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1450
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1451
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1452
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1453

    
1454

    
1455
class FieldSetTestCase(unittest.TestCase):
1456
  """Test case for FieldSets"""
1457

    
1458
  def testSimpleMatch(self):
1459
    f = utils.FieldSet("a", "b", "c", "def")
1460
    self.failUnless(f.Matches("a"))
1461
    self.failIf(f.Matches("d"), "Substring matched")
1462
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1463
    self.failIf(f.NonMatching(["b", "c"]))
1464
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1465
    self.failUnless(f.NonMatching(["a", "d"]))
1466

    
1467
  def testRegexMatch(self):
1468
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1469
    self.failUnless(f.Matches("b1"))
1470
    self.failUnless(f.Matches("b99"))
1471
    self.failIf(f.Matches("b/1"))
1472
    self.failIf(f.NonMatching(["b12", "c"]))
1473
    self.failUnless(f.NonMatching(["a", "1"]))
1474

    
1475
class TestForceDictType(unittest.TestCase):
1476
  """Test case for ForceDictType"""
1477

    
1478
  def setUp(self):
1479
    self.key_types = {
1480
      'a': constants.VTYPE_INT,
1481
      'b': constants.VTYPE_BOOL,
1482
      'c': constants.VTYPE_STRING,
1483
      'd': constants.VTYPE_SIZE,
1484
      }
1485

    
1486
  def _fdt(self, dict, allowed_values=None):
1487
    if allowed_values is None:
1488
      utils.ForceDictType(dict, self.key_types)
1489
    else:
1490
      utils.ForceDictType(dict, self.key_types, allowed_values=allowed_values)
1491

    
1492
    return dict
1493

    
1494
  def testSimpleDict(self):
1495
    self.assertEqual(self._fdt({}), {})
1496
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1497
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1498
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1499
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1500
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1501
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1502
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1503
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1504
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1505
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1506
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1507

    
1508
  def testErrors(self):
1509
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1510
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1511
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1512
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1513

    
1514

    
1515
class TestIsNormAbsPath(unittest.TestCase):
1516
  """Testing case for IsNormAbsPath"""
1517

    
1518
  def _pathTestHelper(self, path, result):
1519
    if result:
1520
      self.assert_(utils.IsNormAbsPath(path),
1521
          "Path %s should result absolute and normalized" % path)
1522
    else:
1523
      self.assertFalse(utils.IsNormAbsPath(path),
1524
          "Path %s should not result absolute and normalized" % path)
1525

    
1526
  def testBase(self):
1527
    self._pathTestHelper('/etc', True)
1528
    self._pathTestHelper('/srv', True)
1529
    self._pathTestHelper('etc', False)
1530
    self._pathTestHelper('/etc/../root', False)
1531
    self._pathTestHelper('/etc/', False)
1532

    
1533

    
1534
class TestSafeEncode(unittest.TestCase):
1535
  """Test case for SafeEncode"""
1536

    
1537
  def testAscii(self):
1538
    for txt in [string.digits, string.letters, string.punctuation]:
1539
      self.failUnlessEqual(txt, SafeEncode(txt))
1540

    
1541
  def testDoubleEncode(self):
1542
    for i in range(255):
1543
      txt = SafeEncode(chr(i))
1544
      self.failUnlessEqual(txt, SafeEncode(txt))
1545

    
1546
  def testUnicode(self):
1547
    # 1024 is high enough to catch non-direct ASCII mappings
1548
    for i in range(1024):
1549
      txt = SafeEncode(unichr(i))
1550
      self.failUnlessEqual(txt, SafeEncode(txt))
1551

    
1552

    
1553
class TestFormatTime(unittest.TestCase):
1554
  """Testing case for FormatTime"""
1555

    
1556
  def testNone(self):
1557
    self.failUnlessEqual(FormatTime(None), "N/A")
1558

    
1559
  def testInvalid(self):
1560
    self.failUnlessEqual(FormatTime(()), "N/A")
1561

    
1562
  def testNow(self):
1563
    # tests that we accept time.time input
1564
    FormatTime(time.time())
1565
    # tests that we accept int input
1566
    FormatTime(int(time.time()))
1567

    
1568

    
1569
class RunInSeparateProcess(unittest.TestCase):
1570
  def test(self):
1571
    for exp in [True, False]:
1572
      def _child():
1573
        return exp
1574

    
1575
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1576

    
1577
  def testArgs(self):
1578
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1579
      def _child(carg1, carg2):
1580
        return carg1 == "Foo" and carg2 == arg
1581

    
1582
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1583

    
1584
  def testPid(self):
1585
    parent_pid = os.getpid()
1586

    
1587
    def _check():
1588
      return os.getpid() == parent_pid
1589

    
1590
    self.failIf(utils.RunInSeparateProcess(_check))
1591

    
1592
  def testSignal(self):
1593
    def _kill():
1594
      os.kill(os.getpid(), signal.SIGTERM)
1595

    
1596
    self.assertRaises(errors.GenericError,
1597
                      utils.RunInSeparateProcess, _kill)
1598

    
1599
  def testException(self):
1600
    def _exc():
1601
      raise errors.GenericError("This is a test")
1602

    
1603
    self.assertRaises(errors.GenericError,
1604
                      utils.RunInSeparateProcess, _exc)
1605

    
1606

    
1607
class TestFingerprintFile(unittest.TestCase):
1608
  def setUp(self):
1609
    self.tmpfile = tempfile.NamedTemporaryFile()
1610

    
1611
  def test(self):
1612
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1613
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")
1614

    
1615
    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
1616
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1617
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1618

    
1619

    
1620
class TestUnescapeAndSplit(unittest.TestCase):
1621
  """Testing case for UnescapeAndSplit"""
1622

    
1623
  def setUp(self):
1624
    # testing more that one separator for regexp safety
1625
    self._seps = [",", "+", "."]
1626

    
1627
  def testSimple(self):
1628
    a = ["a", "b", "c", "d"]
1629
    for sep in self._seps:
1630
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1631

    
1632
  def testEscape(self):
1633
    for sep in self._seps:
1634
      a = ["a", "b\\" + sep + "c", "d"]
1635
      b = ["a", "b" + sep + "c", "d"]
1636
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1637

    
1638
  def testDoubleEscape(self):
1639
    for sep in self._seps:
1640
      a = ["a", "b\\\\", "c", "d"]
1641
      b = ["a", "b\\", "c", "d"]
1642
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1643

    
1644
  def testThreeEscape(self):
1645
    for sep in self._seps:
1646
      a = ["a", "b\\\\\\" + sep + "c", "d"]
1647
      b = ["a", "b\\" + sep + "c", "d"]
1648
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1649

    
1650

    
1651
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1652
  def setUp(self):
1653
    self.tmpdir = tempfile.mkdtemp()
1654

    
1655
  def tearDown(self):
1656
    shutil.rmtree(self.tmpdir)
1657

    
1658
  def _checkRsaPrivateKey(self, key):
1659
    lines = key.splitlines()
1660
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1661
            "-----END RSA PRIVATE KEY-----" in lines)
1662

    
1663
  def _checkCertificate(self, cert):
1664
    lines = cert.splitlines()
1665
    return ("-----BEGIN CERTIFICATE-----" in lines and
1666
            "-----END CERTIFICATE-----" in lines)
1667

    
1668
  def test(self):
1669
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1670
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1671
      self._checkRsaPrivateKey(key_pem)
1672
      self._checkCertificate(cert_pem)
1673

    
1674
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1675
                                           key_pem)
1676
      self.assert_(key.bits() >= 1024)
1677
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1678
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1679

    
1680
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1681
                                             cert_pem)
1682
      self.failIf(x509.has_expired())
1683
      self.assertEqual(x509.get_issuer().CN, common_name)
1684
      self.assertEqual(x509.get_subject().CN, common_name)
1685
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1686

    
1687
  def testLegacy(self):
1688
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1689

    
1690
    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1691

    
1692
    cert1 = utils.ReadFile(cert1_filename)
1693

    
1694
    self.assert_(self._checkRsaPrivateKey(cert1))
1695
    self.assert_(self._checkCertificate(cert1))
1696

    
1697

    
1698
class TestPathJoin(unittest.TestCase):
1699
  """Testing case for PathJoin"""
1700

    
1701
  def testBasicItems(self):
1702
    mlist = ["/a", "b", "c"]
1703
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1704

    
1705
  def testNonAbsPrefix(self):
1706
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1707

    
1708
  def testBackTrack(self):
1709
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1710

    
1711
  def testMultiAbs(self):
1712
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1713

    
1714

    
1715
class TestValidateServiceName(unittest.TestCase):
1716
  def testValid(self):
1717
    testnames = [
1718
      0, 1, 2, 3, 1024, 65000, 65534, 65535,
1719
      "ganeti",
1720
      "gnt-masterd",
1721
      "HELLO_WORLD_SVC",
1722
      "hello.world.1",
1723
      "0", "80", "1111", "65535",
1724
      ]
1725

    
1726
    for name in testnames:
1727
      self.assertEqual(utils.ValidateServiceName(name), name)
1728

    
1729
  def testInvalid(self):
1730
    testnames = [
1731
      -15756, -1, 65536, 133428083,
1732
      "", "Hello World!", "!", "'", "\"", "\t", "\n", "`",
1733
      "-8546", "-1", "65536",
1734
      (129 * "A"),
1735
      ]
1736

    
1737
    for name in testnames:
1738
      self.assertRaises(errors.OpPrereqError, utils.ValidateServiceName, name)
1739

    
1740

    
1741
class TestParseAsn1Generalizedtime(unittest.TestCase):
1742
  def test(self):
1743
    # UTC
1744
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1745
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1746
                     1266860512)
1747
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1748
                     (2**31) - 1)
1749

    
1750
    # With offset
1751
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1752
                     1266860512)
1753
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1754
                     1266931012)
1755
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1756
                     1266931088)
1757
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1758
                     1266931295)
1759
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1760
                     3600)
1761

    
1762
    # Leap seconds are not supported by datetime.datetime
1763
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1764
                      "19841231235960+0000")
1765
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1766
                      "19920630235960+0000")
1767

    
1768
    # Errors
1769
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1770
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1771
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1772
                      "20100222174152")
1773
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1774
                      "Mon Feb 22 17:47:02 UTC 2010")
1775
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1776
                      "2010-02-22 17:42:02")
1777

    
1778

    
1779
class TestGetX509CertValidity(testutils.GanetiTestCase):
1780
  def setUp(self):
1781
    testutils.GanetiTestCase.setUp(self)
1782

    
1783
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1784

    
1785
    # Test whether we have pyOpenSSL 0.7 or above
1786
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1787

    
1788
    if not self.pyopenssl0_7:
1789
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1790
                    " function correctly")
1791

    
1792
  def _LoadCert(self, name):
1793
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1794
                                           self._ReadTestData(name))
1795

    
1796
  def test(self):
1797
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1798
    if self.pyopenssl0_7:
1799
      self.assertEqual(validity, (1266919967, 1267524767))
1800
    else:
1801
      self.assertEqual(validity, (None, None))
1802

    
1803

    
1804
class TestSignX509Certificate(unittest.TestCase):
1805
  KEY = "My private key!"
1806
  KEY_OTHER = "Another key"
1807

    
1808
  def test(self):
1809
    # Generate certificate valid for 5 minutes
1810
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
1811

    
1812
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1813
                                           cert_pem)
1814

    
1815
    # No signature at all
1816
    self.assertRaises(errors.GenericError,
1817
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
1818

    
1819
    # Invalid input
1820
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1821
                      "", self.KEY)
1822
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1823
                      "X-Ganeti-Signature: \n", self.KEY)
1824
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1825
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
1826
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1827
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
1828
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1829
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
1830

    
1831
    # Invalid salt
1832
    for salt in list("-_@$,:;/\\ \t\n"):
1833
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
1834
                        cert_pem, self.KEY, "foo%sbar" % salt)
1835

    
1836
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
1837
                 utils.GenerateSecret(numbytes=4),
1838
                 utils.GenerateSecret(numbytes=16),
1839
                 "{123:456}".encode("hex")]:
1840
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
1841

    
1842
      self._Check(cert, salt, signed_pem)
1843

    
1844
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
1845
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
1846
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
1847
                               "lines----\n------ at\nthe end!"))
1848

    
1849
  def _Check(self, cert, salt, pem):
1850
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
1851
    self.assertEqual(salt, salt2)
1852
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
1853

    
1854
    # Other key
1855
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1856
                      pem, self.KEY_OTHER)
1857

    
1858

    
1859
class TestMakedirs(unittest.TestCase):
1860
  def setUp(self):
1861
    self.tmpdir = tempfile.mkdtemp()
1862

    
1863
  def tearDown(self):
1864
    shutil.rmtree(self.tmpdir)
1865

    
1866
  def testNonExisting(self):
1867
    path = PathJoin(self.tmpdir, "foo")
1868
    utils.Makedirs(path)
1869
    self.assert_(os.path.isdir(path))
1870

    
1871
  def testExisting(self):
1872
    path = PathJoin(self.tmpdir, "foo")
1873
    os.mkdir(path)
1874
    utils.Makedirs(path)
1875
    self.assert_(os.path.isdir(path))
1876

    
1877
  def testRecursiveNonExisting(self):
1878
    path = PathJoin(self.tmpdir, "foo/bar/baz")
1879
    utils.Makedirs(path)
1880
    self.assert_(os.path.isdir(path))
1881

    
1882
  def testRecursiveExisting(self):
1883
    path = PathJoin(self.tmpdir, "B/moo/xyz")
1884
    self.assertFalse(os.path.exists(path))
1885
    os.mkdir(PathJoin(self.tmpdir, "B"))
1886
    utils.Makedirs(path)
1887
    self.assert_(os.path.isdir(path))
1888

    
1889

    
1890
class TestRetry(testutils.GanetiTestCase):
1891
  def setUp(self):
1892
    testutils.GanetiTestCase.setUp(self)
1893
    self.retries = 0
1894

    
1895
  @staticmethod
1896
  def _RaiseRetryAgain():
1897
    raise utils.RetryAgain()
1898

    
1899
  @staticmethod
1900
  def _RaiseRetryAgainWithArg(args):
1901
    raise utils.RetryAgain(*args)
1902

    
1903
  def _WrongNestedLoop(self):
1904
    return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
1905

    
1906
  def _RetryAndSucceed(self, retries):
1907
    if self.retries < retries:
1908
      self.retries += 1
1909
      raise utils.RetryAgain()
1910
    else:
1911
      return True
1912

    
1913
  def testRaiseTimeout(self):
1914
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1915
                          self._RaiseRetryAgain, 0.01, 0.02)
1916
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1917
                          self._RetryAndSucceed, 0.01, 0, args=[1])
1918
    self.failUnlessEqual(self.retries, 1)
1919

    
1920
  def testComplete(self):
1921
    self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
1922
    self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
1923
                         True)
1924
    self.failUnlessEqual(self.retries, 2)
1925

    
1926
  def testNestedLoop(self):
1927
    try:
1928
      self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
1929
                            self._WrongNestedLoop, 0, 1)
1930
    except utils.RetryTimeout:
1931
      self.fail("Didn't detect inner loop's exception")
1932

    
1933
  def testTimeoutArgument(self):
1934
    retry_arg="my_important_debugging_message"
1935
    try:
1936
      utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
1937
    except utils.RetryTimeout, err:
1938
      self.failUnlessEqual(err.args, (retry_arg, ))
1939
    else:
1940
      self.fail("Expected timeout didn't happen")
1941

    
1942
  def testRaiseInnerWithExc(self):
1943
    retry_arg="my_important_debugging_message"
1944
    try:
1945
      try:
1946
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
1947
                    args=[[errors.GenericError(retry_arg, retry_arg)]])
1948
      except utils.RetryTimeout, err:
1949
        err.RaiseInner()
1950
      else:
1951
        self.fail("Expected timeout didn't happen")
1952
    except errors.GenericError, err:
1953
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
1954
    else:
1955
      self.fail("Expected GenericError didn't happen")
1956

    
1957
  def testRaiseInnerWithMsg(self):
1958
    retry_arg="my_important_debugging_message"
1959
    try:
1960
      try:
1961
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
1962
                    args=[[retry_arg, retry_arg]])
1963
      except utils.RetryTimeout, err:
1964
        err.RaiseInner()
1965
      else:
1966
        self.fail("Expected timeout didn't happen")
1967
    except utils.RetryTimeout, err:
1968
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
1969
    else:
1970
      self.fail("Expected RetryTimeout didn't happen")
1971

    
1972

    
1973
class TestLineSplitter(unittest.TestCase):
1974
  def test(self):
1975
    lines = []
1976
    ls = utils.LineSplitter(lines.append)
1977
    ls.write("Hello World\n")
1978
    self.assertEqual(lines, [])
1979
    ls.write("Foo\n Bar\r\n ")
1980
    ls.write("Baz")
1981
    ls.write("Moo")
1982
    self.assertEqual(lines, [])
1983
    ls.flush()
1984
    self.assertEqual(lines, ["Hello World", "Foo", " Bar"])
1985
    ls.close()
1986
    self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"])
1987

    
1988
  def _testExtra(self, line, all_lines, p1, p2):
1989
    self.assertEqual(p1, 999)
1990
    self.assertEqual(p2, "extra")
1991
    all_lines.append(line)
1992

    
1993
  def testExtraArgsNoFlush(self):
1994
    lines = []
1995
    ls = utils.LineSplitter(self._testExtra, lines, 999, "extra")
1996
    ls.write("\n\nHello World\n")
1997
    ls.write("Foo\n Bar\r\n ")
1998
    ls.write("")
1999
    ls.write("Baz")
2000
    ls.write("Moo\n\nx\n")
2001
    self.assertEqual(lines, [])
2002
    ls.close()
2003
    self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo",
2004
                             "", "x"])
2005

    
2006

    
2007
class TestReadLockedPidFile(unittest.TestCase):
2008
  def setUp(self):
2009
    self.tmpdir = tempfile.mkdtemp()
2010

    
2011
  def tearDown(self):
2012
    shutil.rmtree(self.tmpdir)
2013

    
2014
  def testNonExistent(self):
2015
    path = PathJoin(self.tmpdir, "nonexist")
2016
    self.assert_(utils.ReadLockedPidFile(path) is None)
2017

    
2018
  def testUnlocked(self):
2019
    path = PathJoin(self.tmpdir, "pid")
2020
    utils.WriteFile(path, data="123")
2021
    self.assert_(utils.ReadLockedPidFile(path) is None)
2022

    
2023
  def testLocked(self):
2024
    path = PathJoin(self.tmpdir, "pid")
2025
    utils.WriteFile(path, data="123")
2026

    
2027
    fl = utils.FileLock.Open(path)
2028
    try:
2029
      fl.Exclusive(blocking=True)
2030

    
2031
      self.assertEqual(utils.ReadLockedPidFile(path), 123)
2032
    finally:
2033
      fl.Close()
2034

    
2035
    self.assert_(utils.ReadLockedPidFile(path) is None)
2036

    
2037
  def testError(self):
2038
    path = PathJoin(self.tmpdir, "foobar", "pid")
2039
    utils.WriteFile(PathJoin(self.tmpdir, "foobar"), data="")
2040
    # open(2) should return ENOTDIR
2041
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
2042

    
2043

    
2044
class TestCertVerification(testutils.GanetiTestCase):
2045
  def setUp(self):
2046
    testutils.GanetiTestCase.setUp(self)
2047

    
2048
    self.tmpdir = tempfile.mkdtemp()
2049

    
2050
  def tearDown(self):
2051
    shutil.rmtree(self.tmpdir)
2052

    
2053
  def testVerifyCertificate(self):
2054
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
2055
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
2056
                                           cert_pem)
2057

    
2058
    # Not checking return value as this certificate is expired
2059
    utils.VerifyX509Certificate(cert, 30, 7)
2060

    
2061

    
2062
class TestVerifyCertificateInner(unittest.TestCase):
2063
  def test(self):
2064
    vci = utils._VerifyCertificateInner
2065

    
2066
    # Valid
2067
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
2068
                     (None, None))
2069

    
2070
    # Not yet valid
2071
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
2072
    self.assertEqual(errcode, utils.CERT_WARNING)
2073

    
2074
    # Expiring soon
2075
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
2076
    self.assertEqual(errcode, utils.CERT_ERROR)
2077

    
2078
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
2079
    self.assertEqual(errcode, utils.CERT_WARNING)
2080

    
2081
    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
2082
    self.assertEqual(errcode, None)
2083

    
2084
    # Expired
2085
    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
2086
    self.assertEqual(errcode, utils.CERT_ERROR)
2087

    
2088
    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
2089
    self.assertEqual(errcode, utils.CERT_ERROR)
2090

    
2091
    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
2092
    self.assertEqual(errcode, utils.CERT_ERROR)
2093

    
2094
    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
2095
    self.assertEqual(errcode, utils.CERT_ERROR)
2096

    
2097

    
2098
class TestHmacFunctions(unittest.TestCase):
2099
  # Digests can be checked with "openssl sha1 -hmac $key"
2100
  def testSha1Hmac(self):
2101
    self.assertEqual(utils.Sha1Hmac("", ""),
2102
                     "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d")
2103
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World"),
2104
                     "ef4f3bda82212ecb2f7ce868888a19092481f1fd")
2105
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", ""),
2106
                     "f904c2476527c6d3e6609ab683c66fa0652cb1dc")
2107

    
2108
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
2109
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", longtext),
2110
                     "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54")
2111

    
2112
  def testSha1HmacSalt(self):
2113
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc0"),
2114
                     "4999bf342470eadb11dfcd24ca5680cf9fd7cdce")
2115
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc9"),
2116
                     "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8")
2117
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World", salt="xyz0"),
2118
                     "7f264f8114c9066afc9bb7636e1786d996d3cc0d")
2119

    
2120
  def testVerifySha1Hmac(self):
2121
    self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b"
2122
                                               "7d64b71fb76370690e1d")))
2123
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2124
                                      ("f904c2476527c6d3e660"
2125
                                       "9ab683c66fa0652cb1dc")))
2126

    
2127
    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
2128
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", digest))
2129
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2130
                                      digest.lower()))
2131
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2132
                                      digest.upper()))
2133
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2134
                                      digest.title()))
2135

    
2136
  def testVerifySha1HmacSalt(self):
2137
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2138
                                      ("17a4adc34d69c0d367d4"
2139
                                       "ffbef96fd41d4df7a6e8"),
2140
                                      salt="abc9"))
2141
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2142
                                      ("7f264f8114c9066afc9b"
2143
                                       "b7636e1786d996d3cc0d"),
2144
                                      salt="xyz0"))
2145

    
2146

    
2147
class TestIgnoreSignals(unittest.TestCase):
2148
  """Test the IgnoreSignals decorator"""
2149

    
2150
  @staticmethod
2151
  def _Raise(exception):
2152
    raise exception
2153

    
2154
  @staticmethod
2155
  def _Return(rval):
2156
    return rval
2157

    
2158
  def testIgnoreSignals(self):
2159
    sock_err_intr = socket.error(errno.EINTR, "Message")
2160
    sock_err_inval = socket.error(errno.EINVAL, "Message")
2161

    
2162
    env_err_intr = EnvironmentError(errno.EINTR, "Message")
2163
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")
2164

    
2165
    self.assertRaises(socket.error, self._Raise, sock_err_intr)
2166
    self.assertRaises(socket.error, self._Raise, sock_err_inval)
2167
    self.assertRaises(EnvironmentError, self._Raise, env_err_intr)
2168
    self.assertRaises(EnvironmentError, self._Raise, env_err_inval)
2169

    
2170
    self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None)
2171
    self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None)
2172
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
2173
                      sock_err_inval)
2174
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
2175
                      env_err_inval)
2176

    
2177
    self.assertEquals(utils.IgnoreSignals(self._Return, True), True)
2178
    self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33)
2179

    
2180

    
2181
class TestEnsureDirs(unittest.TestCase):
2182
  """Tests for EnsureDirs"""
2183

    
2184
  def setUp(self):
2185
    self.dir = tempfile.mkdtemp()
2186
    self.old_umask = os.umask(0777)
2187

    
2188
  def testEnsureDirs(self):
2189
    utils.EnsureDirs([
2190
        (PathJoin(self.dir, "foo"), 0777),
2191
        (PathJoin(self.dir, "bar"), 0000),
2192
        ])
2193
    self.assertEquals(os.stat(PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2194
    self.assertEquals(os.stat(PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2195

    
2196
  def tearDown(self):
2197
    os.rmdir(PathJoin(self.dir, "foo"))
2198
    os.rmdir(PathJoin(self.dir, "bar"))
2199
    os.rmdir(self.dir)
2200
    os.umask(self.old_umask)
2201

    
2202

    
2203
class TestFormatSeconds(unittest.TestCase):
2204
  def test(self):
2205
    self.assertEqual(utils.FormatSeconds(1), "1s")
2206
    self.assertEqual(utils.FormatSeconds(3600), "1h 0m 0s")
2207
    self.assertEqual(utils.FormatSeconds(3599), "59m 59s")
2208
    self.assertEqual(utils.FormatSeconds(7200), "2h 0m 0s")
2209
    self.assertEqual(utils.FormatSeconds(7201), "2h 0m 1s")
2210
    self.assertEqual(utils.FormatSeconds(7281), "2h 1m 21s")
2211
    self.assertEqual(utils.FormatSeconds(29119), "8h 5m 19s")
2212
    self.assertEqual(utils.FormatSeconds(19431228), "224d 21h 33m 48s")
2213
    self.assertEqual(utils.FormatSeconds(-1), "-1s")
2214
    self.assertEqual(utils.FormatSeconds(-282), "-282s")
2215
    self.assertEqual(utils.FormatSeconds(-29119), "-29119s")
2216

    
2217
  def testFloat(self):
2218
    self.assertEqual(utils.FormatSeconds(1.3), "1s")
2219
    self.assertEqual(utils.FormatSeconds(1.9), "2s")
2220
    self.assertEqual(utils.FormatSeconds(3912.12311), "1h 5m 12s")
2221
    self.assertEqual(utils.FormatSeconds(3912.8), "1h 5m 13s")
2222

    
2223

    
2224
class TestIgnoreProcessNotFound(unittest.TestCase):
2225
  @staticmethod
2226
  def _WritePid(fd):
2227
    os.write(fd, str(os.getpid()))
2228
    os.close(fd)
2229
    return True
2230

    
2231
  def test(self):
2232
    (pid_read_fd, pid_write_fd) = os.pipe()
2233

    
2234
    # Start short-lived process which writes its PID to pipe
2235
    self.assert_(utils.RunInSeparateProcess(self._WritePid, pid_write_fd))
2236
    os.close(pid_write_fd)
2237

    
2238
    # Read PID from pipe
2239
    pid = int(os.read(pid_read_fd, 1024))
2240
    os.close(pid_read_fd)
2241

    
2242
    # Try to send signal to process which exited recently
2243
    self.assertFalse(utils.IgnoreProcessNotFound(os.kill, pid, 0))
2244

    
2245

    
2246
class TestShellWriter(unittest.TestCase):
2247
  def test(self):
2248
    buf = StringIO()
2249
    sw = utils.ShellWriter(buf)
2250
    sw.Write("#!/bin/bash")
2251
    sw.Write("if true; then")
2252
    sw.IncIndent()
2253
    try:
2254
      sw.Write("echo true")
2255

    
2256
      sw.Write("for i in 1 2 3")
2257
      sw.Write("do")
2258
      sw.IncIndent()
2259
      try:
2260
        self.assertEqual(sw._indent, 2)
2261
        sw.Write("date")
2262
      finally:
2263
        sw.DecIndent()
2264
      sw.Write("done")
2265
    finally:
2266
      sw.DecIndent()
2267
    sw.Write("echo %s", utils.ShellQuote("Hello World"))
2268
    sw.Write("exit 0")
2269

    
2270
    self.assertEqual(sw._indent, 0)
2271

    
2272
    output = buf.getvalue()
2273

    
2274
    self.assert_(output.endswith("\n"))
2275

    
2276
    lines = output.splitlines()
2277
    self.assertEqual(len(lines), 9)
2278
    self.assertEqual(lines[0], "#!/bin/bash")
2279
    self.assert_(re.match(r"^\s+date$", lines[5]))
2280
    self.assertEqual(lines[7], "echo 'Hello World'")
2281

    
2282
  def testEmpty(self):
2283
    buf = StringIO()
2284
    sw = utils.ShellWriter(buf)
2285
    sw = None
2286
    self.assertEqual(buf.getvalue(), "")
2287

    
2288

    
2289
if __name__ == '__main__':
2290
  testutils.GanetiTestProgram()