Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ 31155d60

History | View | Annotate | Download (77.7 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import distutils.version
25
import errno
26
import fcntl
27
import glob
28
import os
29
import os.path
30
import re
31
import shutil
32
import signal
33
import socket
34
import stat
35
import string
36
import tempfile
37
import time
38
import unittest
39
import warnings
40
import OpenSSL
41
from cStringIO import StringIO
42

    
43
import testutils
44
from ganeti import constants
45
from ganeti import compat
46
from ganeti import utils
47
from ganeti import errors
48
from ganeti.utils import RunCmd, RemoveFile, MatchNameComponent, FormatUnit, \
49
     ParseUnit, ShellQuote, ShellQuoteArgs, ListVisibleFiles, FirstFree, \
50
     TailFile, SafeEncode, FormatTime, UnescapeAndSplit, RunParts, PathJoin, \
51
     ReadOneLineFile, SetEtcHostsEntry, RemoveEtcHostsEntry
52

    
53

    
54
class TestIsProcessAlive(unittest.TestCase):
55
  """Testing case for IsProcessAlive"""
56

    
57
  def testExists(self):
58
    mypid = os.getpid()
59
    self.assert_(utils.IsProcessAlive(mypid), "can't find myself running")
60

    
61
  def testNotExisting(self):
62
    pid_non_existing = os.fork()
63
    if pid_non_existing == 0:
64
      os._exit(0)
65
    elif pid_non_existing < 0:
66
      raise SystemError("can't fork")
67
    os.waitpid(pid_non_existing, 0)
68
    self.assertFalse(utils.IsProcessAlive(pid_non_existing),
69
                     "nonexisting process detected")
70

    
71

    
72
class TestGetProcStatusPath(unittest.TestCase):
73
  def test(self):
74
    self.assert_("/1234/" in utils._GetProcStatusPath(1234))
75
    self.assertNotEqual(utils._GetProcStatusPath(1),
76
                        utils._GetProcStatusPath(2))
77

    
78

    
79
class TestIsProcessHandlingSignal(unittest.TestCase):
80
  def setUp(self):
81
    self.tmpdir = tempfile.mkdtemp()
82

    
83
  def tearDown(self):
84
    shutil.rmtree(self.tmpdir)
85

    
86
  def testParseSigsetT(self):
87
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
88
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
89
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
90
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
91
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
92
                     set([2, 10, 32, 33]))
93
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
94
                     set([2, 32, 33]))
95
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
96
                     set([2, 28, 32, 33]))
97
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
98
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
99
                          24, 25, 26, 28, 31]))
100
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))
101

    
102
  def testGetProcStatusField(self):
103
    for field in ["SigCgt", "Name", "FDSize"]:
104
      for value in ["", "0", "cat", "  1234 KB"]:
105
        pstatus = "\n".join([
106
          "VmPeak: 999 kB",
107
          "%s: %s" % (field, value),
108
          "TracerPid: 0",
109
          ])
110
        result = utils._GetProcStatusField(pstatus, field)
111
        self.assertEqual(result, value.strip())
112

    
113
  def test(self):
114
    sp = PathJoin(self.tmpdir, "status")
115

    
116
    utils.WriteFile(sp, data="\n".join([
117
      "Name:   bash",
118
      "State:  S (sleeping)",
119
      "SleepAVG:       98%",
120
      "Pid:    22250",
121
      "PPid:   10858",
122
      "TracerPid:      0",
123
      "SigBlk: 0000000000010000",
124
      "SigIgn: 0000000000384004",
125
      "SigCgt: 000000004b813efb",
126
      "CapEff: 0000000000000000",
127
      ]))
128

    
129
    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
130

    
131
  def testNoSigCgt(self):
132
    sp = PathJoin(self.tmpdir, "status")
133

    
134
    utils.WriteFile(sp, data="\n".join([
135
      "Name:   bash",
136
      ]))
137

    
138
    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
139
                      1234, 10, status_path=sp)
140

    
141
  def testNoSuchFile(self):
142
    sp = PathJoin(self.tmpdir, "notexist")
143

    
144
    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
145

    
146
  @staticmethod
147
  def _TestRealProcess():
148
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
149
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
150
      raise Exception("SIGUSR1 is handled when it should not be")
151

    
152
    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
153
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
154
      raise Exception("SIGUSR1 is not handled when it should be")
155

    
156
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
157
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
158
      raise Exception("SIGUSR1 is not handled when it should be")
159

    
160
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
161
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
162
      raise Exception("SIGUSR1 is handled when it should not be")
163

    
164
    return True
165

    
166
  def testRealProcess(self):
167
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
168

    
169

    
170
class TestPidFileFunctions(unittest.TestCase):
171
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
172

    
173
  def setUp(self):
174
    self.dir = tempfile.mkdtemp()
175
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
176
    utils.DaemonPidFileName = self.f_dpn
177

    
178
  def testPidFileFunctions(self):
179
    pid_file = self.f_dpn('test')
180
    utils.WritePidFile('test')
181
    self.failUnless(os.path.exists(pid_file),
182
                    "PID file should have been created")
183
    read_pid = utils.ReadPidFile(pid_file)
184
    self.failUnlessEqual(read_pid, os.getpid())
185
    self.failUnless(utils.IsProcessAlive(read_pid))
186
    self.failUnlessRaises(errors.GenericError, utils.WritePidFile, 'test')
187
    utils.RemovePidFile('test')
188
    self.failIf(os.path.exists(pid_file),
189
                "PID file should not exist anymore")
190
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
191
                         "ReadPidFile should return 0 for missing pid file")
192
    fh = open(pid_file, "w")
193
    fh.write("blah\n")
194
    fh.close()
195
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
196
                         "ReadPidFile should return 0 for invalid pid file")
197
    utils.RemovePidFile('test')
198
    self.failIf(os.path.exists(pid_file),
199
                "PID file should not exist anymore")
200

    
201
  def testKill(self):
202
    pid_file = self.f_dpn('child')
203
    r_fd, w_fd = os.pipe()
204
    new_pid = os.fork()
205
    if new_pid == 0: #child
206
      utils.WritePidFile('child')
207
      os.write(w_fd, 'a')
208
      signal.pause()
209
      os._exit(0)
210
      return
211
    # else we are in the parent
212
    # wait until the child has written the pid file
213
    os.read(r_fd, 1)
214
    read_pid = utils.ReadPidFile(pid_file)
215
    self.failUnlessEqual(read_pid, new_pid)
216
    self.failUnless(utils.IsProcessAlive(new_pid))
217
    utils.KillProcess(new_pid, waitpid=True)
218
    self.failIf(utils.IsProcessAlive(new_pid))
219
    utils.RemovePidFile('child')
220
    self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)
221

    
222
  def tearDown(self):
223
    for name in os.listdir(self.dir):
224
      os.unlink(os.path.join(self.dir, name))
225
    os.rmdir(self.dir)
226

    
227

    
228
class TestRunCmd(testutils.GanetiTestCase):
229
  """Testing case for the RunCmd function"""
230

    
231
  def setUp(self):
232
    testutils.GanetiTestCase.setUp(self)
233
    self.magic = time.ctime() + " ganeti test"
234
    self.fname = self._CreateTempFile()
235

    
236
  def testOk(self):
237
    """Test successful exit code"""
238
    result = RunCmd("/bin/sh -c 'exit 0'")
239
    self.assertEqual(result.exit_code, 0)
240
    self.assertEqual(result.output, "")
241

    
242
  def testFail(self):
243
    """Test fail exit code"""
244
    result = RunCmd("/bin/sh -c 'exit 1'")
245
    self.assertEqual(result.exit_code, 1)
246
    self.assertEqual(result.output, "")
247

    
248
  def testStdout(self):
249
    """Test standard output"""
250
    cmd = 'echo -n "%s"' % self.magic
251
    result = RunCmd("/bin/sh -c '%s'" % cmd)
252
    self.assertEqual(result.stdout, self.magic)
253
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
254
    self.assertEqual(result.output, "")
255
    self.assertFileContent(self.fname, self.magic)
256

    
257
  def testStderr(self):
258
    """Test standard error"""
259
    cmd = 'echo -n "%s"' % self.magic
260
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
261
    self.assertEqual(result.stderr, self.magic)
262
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
263
    self.assertEqual(result.output, "")
264
    self.assertFileContent(self.fname, self.magic)
265

    
266
  def testCombined(self):
267
    """Test combined output"""
268
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
269
    expected = "A" + self.magic + "B" + self.magic
270
    result = RunCmd("/bin/sh -c '%s'" % cmd)
271
    self.assertEqual(result.output, expected)
272
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
273
    self.assertEqual(result.output, "")
274
    self.assertFileContent(self.fname, expected)
275

    
276
  def testSignal(self):
277
    """Test signal"""
278
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
279
    self.assertEqual(result.signal, 15)
280
    self.assertEqual(result.output, "")
281

    
282
  def testListRun(self):
283
    """Test list runs"""
284
    result = RunCmd(["true"])
285
    self.assertEqual(result.signal, None)
286
    self.assertEqual(result.exit_code, 0)
287
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
288
    self.assertEqual(result.signal, None)
289
    self.assertEqual(result.exit_code, 1)
290
    result = RunCmd(["echo", "-n", self.magic])
291
    self.assertEqual(result.signal, None)
292
    self.assertEqual(result.exit_code, 0)
293
    self.assertEqual(result.stdout, self.magic)
294

    
295
  def testFileEmptyOutput(self):
296
    """Test file output"""
297
    result = RunCmd(["true"], output=self.fname)
298
    self.assertEqual(result.signal, None)
299
    self.assertEqual(result.exit_code, 0)
300
    self.assertFileContent(self.fname, "")
301

    
302
  def testLang(self):
303
    """Test locale environment"""
304
    old_env = os.environ.copy()
305
    try:
306
      os.environ["LANG"] = "en_US.UTF-8"
307
      os.environ["LC_ALL"] = "en_US.UTF-8"
308
      result = RunCmd(["locale"])
309
      for line in result.output.splitlines():
310
        key, value = line.split("=", 1)
311
        # Ignore these variables, they're overridden by LC_ALL
312
        if key == "LANG" or key == "LANGUAGE":
313
          continue
314
        self.failIf(value and value != "C" and value != '"C"',
315
            "Variable %s is set to the invalid value '%s'" % (key, value))
316
    finally:
317
      os.environ = old_env
318

    
319
  def testDefaultCwd(self):
320
    """Test default working directory"""
321
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
322

    
323
  def testCwd(self):
324
    """Test default working directory"""
325
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
326
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
327
    cwd = os.getcwd()
328
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
329

    
330
  def testResetEnv(self):
331
    """Test environment reset functionality"""
332
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
333
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
334
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
335

    
336

    
337
class TestRunParts(unittest.TestCase):
338
  """Testing case for the RunParts function"""
339

    
340
  def setUp(self):
341
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
342

    
343
  def tearDown(self):
344
    shutil.rmtree(self.rundir)
345

    
346
  def testEmpty(self):
347
    """Test on an empty dir"""
348
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
349

    
350
  def testSkipWrongName(self):
351
    """Test that wrong files are skipped"""
352
    fname = os.path.join(self.rundir, "00test.dot")
353
    utils.WriteFile(fname, data="")
354
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
355
    relname = os.path.basename(fname)
356
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
357
                         [(relname, constants.RUNPARTS_SKIP, None)])
358

    
359
  def testSkipNonExec(self):
360
    """Test that non executable files are skipped"""
361
    fname = os.path.join(self.rundir, "00test")
362
    utils.WriteFile(fname, data="")
363
    relname = os.path.basename(fname)
364
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
365
                         [(relname, constants.RUNPARTS_SKIP, None)])
366

    
367
  def testError(self):
368
    """Test error on a broken executable"""
369
    fname = os.path.join(self.rundir, "00test")
370
    utils.WriteFile(fname, data="")
371
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
372
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
373
    self.failUnlessEqual(relname, os.path.basename(fname))
374
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
375
    self.failUnless(error)
376

    
377
  def testSorted(self):
378
    """Test executions are sorted"""
379
    files = []
380
    files.append(os.path.join(self.rundir, "64test"))
381
    files.append(os.path.join(self.rundir, "00test"))
382
    files.append(os.path.join(self.rundir, "42test"))
383

    
384
    for fname in files:
385
      utils.WriteFile(fname, data="")
386

    
387
    results = RunParts(self.rundir, reset_env=True)
388

    
389
    for fname in sorted(files):
390
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
391

    
392
  def testOk(self):
393
    """Test correct execution"""
394
    fname = os.path.join(self.rundir, "00test")
395
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
396
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
397
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
398
    self.failUnlessEqual(relname, os.path.basename(fname))
399
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
400
    self.failUnlessEqual(runresult.stdout, "ciao")
401

    
402
  def testRunFail(self):
403
    """Test correct execution, with run failure"""
404
    fname = os.path.join(self.rundir, "00test")
405
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
406
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
407
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
408
    self.failUnlessEqual(relname, os.path.basename(fname))
409
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
410
    self.failUnlessEqual(runresult.exit_code, 1)
411
    self.failUnless(runresult.failed)
412

    
413
  def testRunMix(self):
414
    files = []
415
    files.append(os.path.join(self.rundir, "00test"))
416
    files.append(os.path.join(self.rundir, "42test"))
417
    files.append(os.path.join(self.rundir, "64test"))
418
    files.append(os.path.join(self.rundir, "99test"))
419

    
420
    files.sort()
421

    
422
    # 1st has errors in execution
423
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
424
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
425

    
426
    # 2nd is skipped
427
    utils.WriteFile(files[1], data="")
428

    
429
    # 3rd cannot execute properly
430
    utils.WriteFile(files[2], data="")
431
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
432

    
433
    # 4th execs
434
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
435
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
436

    
437
    results = RunParts(self.rundir, reset_env=True)
438

    
439
    (relname, status, runresult) = results[0]
440
    self.failUnlessEqual(relname, os.path.basename(files[0]))
441
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
442
    self.failUnlessEqual(runresult.exit_code, 1)
443
    self.failUnless(runresult.failed)
444

    
445
    (relname, status, runresult) = results[1]
446
    self.failUnlessEqual(relname, os.path.basename(files[1]))
447
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
448
    self.failUnlessEqual(runresult, None)
449

    
450
    (relname, status, runresult) = results[2]
451
    self.failUnlessEqual(relname, os.path.basename(files[2]))
452
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
453
    self.failUnless(runresult)
454

    
455
    (relname, status, runresult) = results[3]
456
    self.failUnlessEqual(relname, os.path.basename(files[3]))
457
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
458
    self.failUnlessEqual(runresult.output, "ciao")
459
    self.failUnlessEqual(runresult.exit_code, 0)
460
    self.failUnless(not runresult.failed)
461

    
462

    
463
class TestStartDaemon(testutils.GanetiTestCase):
464
  def setUp(self):
465
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
466
    self.tmpfile = os.path.join(self.tmpdir, "test")
467

    
468
  def tearDown(self):
469
    shutil.rmtree(self.tmpdir)
470

    
471
  def testShell(self):
472
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
473
    self._wait(self.tmpfile, 60.0, "Hello World")
474

    
475
  def testShellOutput(self):
476
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
477
    self._wait(self.tmpfile, 60.0, "Hello World")
478

    
479
  def testNoShellNoOutput(self):
480
    utils.StartDaemon(["pwd"])
481

    
482
  def testNoShellNoOutputTouch(self):
483
    testfile = os.path.join(self.tmpdir, "check")
484
    self.failIf(os.path.exists(testfile))
485
    utils.StartDaemon(["touch", testfile])
486
    self._wait(testfile, 60.0, "")
487

    
488
  def testNoShellOutput(self):
489
    utils.StartDaemon(["pwd"], output=self.tmpfile)
490
    self._wait(self.tmpfile, 60.0, "/")
491

    
492
  def testNoShellOutputCwd(self):
493
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
494
    self._wait(self.tmpfile, 60.0, os.getcwd())
495

    
496
  def testShellEnv(self):
497
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
498
                      env={ "GNT_TEST_VAR": "Hello World", })
499
    self._wait(self.tmpfile, 60.0, "Hello World")
500

    
501
  def testNoShellEnv(self):
502
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
503
                      env={ "GNT_TEST_VAR": "Hello World", })
504
    self._wait(self.tmpfile, 60.0, "Hello World")
505

    
506
  def testOutputFd(self):
507
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
508
    try:
509
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
510
    finally:
511
      os.close(fd)
512
    self._wait(self.tmpfile, 60.0, os.getcwd())
513

    
514
  def testPid(self):
515
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
516
    self._wait(self.tmpfile, 60.0, str(pid))
517

    
518
  def testPidFile(self):
519
    pidfile = os.path.join(self.tmpdir, "pid")
520
    checkfile = os.path.join(self.tmpdir, "abort")
521

    
522
    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
523
                            output=self.tmpfile)
524
    try:
525
      fd = os.open(pidfile, os.O_RDONLY)
526
      try:
527
        # Check file is locked
528
        self.assertRaises(errors.LockError, utils.LockFile, fd)
529

    
530
        pidtext = os.read(fd, 100)
531
      finally:
532
        os.close(fd)
533

    
534
      self.assertEqual(int(pidtext.strip()), pid)
535

    
536
      self.assert_(utils.IsProcessAlive(pid))
537
    finally:
538
      # No matter what happens, kill daemon
539
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
540
      self.failIf(utils.IsProcessAlive(pid))
541

    
542
    self.assertEqual(utils.ReadFile(self.tmpfile), "")
543

    
544
  def _wait(self, path, timeout, expected):
545
    # Due to the asynchronous nature of daemon processes, polling is necessary.
546
    # A timeout makes sure the test doesn't hang forever.
547
    def _CheckFile():
548
      if not (os.path.isfile(path) and
549
              utils.ReadFile(path).strip() == expected):
550
        raise utils.RetryAgain()
551

    
552
    try:
553
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
554
    except utils.RetryTimeout:
555
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
556
                " didn't write the correct output" % timeout)
557

    
558
  def testError(self):
559
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
560
                      ["./does-NOT-EXIST/here/0123456789"])
561
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
562
                      ["./does-NOT-EXIST/here/0123456789"],
563
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
564
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
565
                      ["./does-NOT-EXIST/here/0123456789"],
566
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
567
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
568
                      ["./does-NOT-EXIST/here/0123456789"],
569
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
570

    
571
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
572
    try:
573
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
574
                        ["./does-NOT-EXIST/here/0123456789"],
575
                        output=self.tmpfile, output_fd=fd)
576
    finally:
577
      os.close(fd)
578

    
579

    
580
class TestSetCloseOnExecFlag(unittest.TestCase):
581
  """Tests for SetCloseOnExecFlag"""
582

    
583
  def setUp(self):
584
    self.tmpfile = tempfile.TemporaryFile()
585

    
586
  def testEnable(self):
587
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
588
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
589
                    fcntl.FD_CLOEXEC)
590

    
591
  def testDisable(self):
592
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
593
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
594
                fcntl.FD_CLOEXEC)
595

    
596

    
597
class TestSetNonblockFlag(unittest.TestCase):
598
  def setUp(self):
599
    self.tmpfile = tempfile.TemporaryFile()
600

    
601
  def testEnable(self):
602
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
603
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
604
                    os.O_NONBLOCK)
605

    
606
  def testDisable(self):
607
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
608
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
609
                os.O_NONBLOCK)
610

    
611

    
612
class TestRemoveFile(unittest.TestCase):
613
  """Test case for the RemoveFile function"""
614

    
615
  def setUp(self):
616
    """Create a temp dir and file for each case"""
617
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
618
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
619
    os.close(fd)
620

    
621
  def tearDown(self):
622
    if os.path.exists(self.tmpfile):
623
      os.unlink(self.tmpfile)
624
    os.rmdir(self.tmpdir)
625

    
626
  def testIgnoreDirs(self):
627
    """Test that RemoveFile() ignores directories"""
628
    self.assertEqual(None, RemoveFile(self.tmpdir))
629

    
630
  def testIgnoreNotExisting(self):
631
    """Test that RemoveFile() ignores non-existing files"""
632
    RemoveFile(self.tmpfile)
633
    RemoveFile(self.tmpfile)
634

    
635
  def testRemoveFile(self):
636
    """Test that RemoveFile does remove a file"""
637
    RemoveFile(self.tmpfile)
638
    if os.path.exists(self.tmpfile):
639
      self.fail("File '%s' not removed" % self.tmpfile)
640

    
641
  def testRemoveSymlink(self):
642
    """Test that RemoveFile does remove symlinks"""
643
    symlink = self.tmpdir + "/symlink"
644
    os.symlink("no-such-file", symlink)
645
    RemoveFile(symlink)
646
    if os.path.exists(symlink):
647
      self.fail("File '%s' not removed" % symlink)
648
    os.symlink(self.tmpfile, symlink)
649
    RemoveFile(symlink)
650
    if os.path.exists(symlink):
651
      self.fail("File '%s' not removed" % symlink)
652

    
653

    
654
class TestRename(unittest.TestCase):
655
  """Test case for RenameFile"""
656

    
657
  def setUp(self):
658
    """Create a temporary directory"""
659
    self.tmpdir = tempfile.mkdtemp()
660
    self.tmpfile = os.path.join(self.tmpdir, "test1")
661

    
662
    # Touch the file
663
    open(self.tmpfile, "w").close()
664

    
665
  def tearDown(self):
666
    """Remove temporary directory"""
667
    shutil.rmtree(self.tmpdir)
668

    
669
  def testSimpleRename1(self):
670
    """Simple rename 1"""
671
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
672
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
673

    
674
  def testSimpleRename2(self):
675
    """Simple rename 2"""
676
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
677
                     mkdir=True)
678
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
679

    
680
  def testRenameMkdir(self):
681
    """Rename with mkdir"""
682
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
683
                     mkdir=True)
684
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
685
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
686

    
687
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
688
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
689
                     mkdir=True)
690
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
691
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
692
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
693

    
694

    
695
class TestMatchNameComponent(unittest.TestCase):
696
  """Test case for the MatchNameComponent function"""
697

    
698
  def testEmptyList(self):
699
    """Test that there is no match against an empty list"""
700

    
701
    self.failUnlessEqual(MatchNameComponent("", []), None)
702
    self.failUnlessEqual(MatchNameComponent("test", []), None)
703

    
704
  def testSingleMatch(self):
705
    """Test that a single match is performed correctly"""
706
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
707
    for key in "test2", "test2.example", "test2.example.com":
708
      self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
709

    
710
  def testMultipleMatches(self):
711
    """Test that a multiple match is returned as None"""
712
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
713
    for key in "test1", "test1.example":
714
      self.failUnlessEqual(MatchNameComponent(key, mlist), None)
715

    
716
  def testFullMatch(self):
717
    """Test that a full match is returned correctly"""
718
    key1 = "test1"
719
    key2 = "test1.example"
720
    mlist = [key2, key2 + ".com"]
721
    self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
722
    self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
723

    
724
  def testCaseInsensitivePartialMatch(self):
725
    """Test for the case_insensitive keyword"""
726
    mlist = ["test1.example.com", "test2.example.net"]
727
    self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
728
                     "test2.example.net")
729
    self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
730
                     "test2.example.net")
731
    self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
732
                     "test2.example.net")
733
    self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
734
                     "test2.example.net")
735

    
736

    
737
  def testCaseInsensitiveFullMatch(self):
738
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
739
    # Between the two ts1 a full string match non-case insensitive should work
740
    self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
741
                     None)
742
    self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
743
                     "ts1.ex")
744
    self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
745
                     "ts1.ex")
746
    # Between the two ts2 only case differs, so only case-match works
747
    self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
748
                     "ts2.ex")
749
    self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
750
                     "Ts2.ex")
751
    self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
752
                     None)
753

    
754

    
755
class TestReadFile(testutils.GanetiTestCase):
756

    
757
  def testReadAll(self):
758
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
759
    self.assertEqual(len(data), 814)
760

    
761
    h = compat.md5_hash()
762
    h.update(data)
763
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
764

    
765
  def testReadSize(self):
766
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
767
                          size=100)
768
    self.assertEqual(len(data), 100)
769

    
770
    h = compat.md5_hash()
771
    h.update(data)
772
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
773

    
774
  def testError(self):
775
    self.assertRaises(EnvironmentError, utils.ReadFile,
776
                      "/dev/null/does-not-exist")
777

    
778

    
779
class TestReadOneLineFile(testutils.GanetiTestCase):
780

    
781
  def setUp(self):
782
    testutils.GanetiTestCase.setUp(self)
783

    
784
  def testDefault(self):
785
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
786
    self.assertEqual(len(data), 27)
787
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
788

    
789
  def testNotStrict(self):
790
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
791
    self.assertEqual(len(data), 27)
792
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
793

    
794
  def testStrictFailure(self):
795
    self.assertRaises(errors.GenericError, ReadOneLineFile,
796
                      self._TestDataFilename("cert1.pem"), strict=True)
797

    
798
  def testLongLine(self):
799
    dummydata = (1024 * "Hello World! ")
800
    myfile = self._CreateTempFile()
801
    utils.WriteFile(myfile, data=dummydata)
802
    datastrict = ReadOneLineFile(myfile, strict=True)
803
    datalax = ReadOneLineFile(myfile, strict=False)
804
    self.assertEqual(dummydata, datastrict)
805
    self.assertEqual(dummydata, datalax)
806

    
807
  def testNewline(self):
808
    myfile = self._CreateTempFile()
809
    myline = "myline"
810
    for nl in ["", "\n", "\r\n"]:
811
      dummydata = "%s%s" % (myline, nl)
812
      utils.WriteFile(myfile, data=dummydata)
813
      datalax = ReadOneLineFile(myfile, strict=False)
814
      self.assertEqual(myline, datalax)
815
      datastrict = ReadOneLineFile(myfile, strict=True)
816
      self.assertEqual(myline, datastrict)
817

    
818
  def testWhitespaceAndMultipleLines(self):
819
    myfile = self._CreateTempFile()
820
    for nl in ["", "\n", "\r\n"]:
821
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
822
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
823
        utils.WriteFile(myfile, data=dummydata)
824
        datalax = ReadOneLineFile(myfile, strict=False)
825
        if nl:
826
          self.assert_(set("\r\n") & set(dummydata))
827
          self.assertRaises(errors.GenericError, ReadOneLineFile,
828
                            myfile, strict=True)
829
          explen = len("Foo bar baz ") + len(ws)
830
          self.assertEqual(len(datalax), explen)
831
          self.assertEqual(datalax, dummydata[:explen])
832
          self.assertFalse(set("\r\n") & set(datalax))
833
        else:
834
          datastrict = ReadOneLineFile(myfile, strict=True)
835
          self.assertEqual(dummydata, datastrict)
836
          self.assertEqual(dummydata, datalax)
837

    
838
  def testEmptylines(self):
839
    myfile = self._CreateTempFile()
840
    myline = "myline"
841
    for nl in ["\n", "\r\n"]:
842
      for ol in ["", "otherline"]:
843
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
844
        utils.WriteFile(myfile, data=dummydata)
845
        self.assert_(set("\r\n") & set(dummydata))
846
        datalax = ReadOneLineFile(myfile, strict=False)
847
        self.assertEqual(myline, datalax)
848
        if ol:
849
          self.assertRaises(errors.GenericError, ReadOneLineFile,
850
                            myfile, strict=True)
851
        else:
852
          datastrict = ReadOneLineFile(myfile, strict=True)
853
          self.assertEqual(myline, datastrict)
854

    
855

    
856
class TestTimestampForFilename(unittest.TestCase):
857
  def test(self):
858
    self.assert_("." not in utils.TimestampForFilename())
859
    self.assert_(":" not in utils.TimestampForFilename())
860

    
861

    
862
class TestCreateBackup(testutils.GanetiTestCase):
863
  def setUp(self):
864
    testutils.GanetiTestCase.setUp(self)
865

    
866
    self.tmpdir = tempfile.mkdtemp()
867

    
868
  def tearDown(self):
869
    testutils.GanetiTestCase.tearDown(self)
870

    
871
    shutil.rmtree(self.tmpdir)
872

    
873
  def testEmpty(self):
874
    filename = PathJoin(self.tmpdir, "config.data")
875
    utils.WriteFile(filename, data="")
876
    bname = utils.CreateBackup(filename)
877
    self.assertFileContent(bname, "")
878
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
879
    utils.CreateBackup(filename)
880
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
881
    utils.CreateBackup(filename)
882
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
883

    
884
    fifoname = PathJoin(self.tmpdir, "fifo")
885
    os.mkfifo(fifoname)
886
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
887

    
888
  def testContent(self):
889
    bkpcount = 0
890
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
891
      for rep in [1, 2, 10, 127]:
892
        testdata = data * rep
893

    
894
        filename = PathJoin(self.tmpdir, "test.data_")
895
        utils.WriteFile(filename, data=testdata)
896
        self.assertFileContent(filename, testdata)
897

    
898
        for _ in range(3):
899
          bname = utils.CreateBackup(filename)
900
          bkpcount += 1
901
          self.assertFileContent(bname, testdata)
902
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
903

    
904

    
905
class TestFormatUnit(unittest.TestCase):
906
  """Test case for the FormatUnit function"""
907

    
908
  def testMiB(self):
909
    self.assertEqual(FormatUnit(1, 'h'), '1M')
910
    self.assertEqual(FormatUnit(100, 'h'), '100M')
911
    self.assertEqual(FormatUnit(1023, 'h'), '1023M')
912

    
913
    self.assertEqual(FormatUnit(1, 'm'), '1')
914
    self.assertEqual(FormatUnit(100, 'm'), '100')
915
    self.assertEqual(FormatUnit(1023, 'm'), '1023')
916

    
917
    self.assertEqual(FormatUnit(1024, 'm'), '1024')
918
    self.assertEqual(FormatUnit(1536, 'm'), '1536')
919
    self.assertEqual(FormatUnit(17133, 'm'), '17133')
920
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
921

    
922
  def testGiB(self):
923
    self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
924
    self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
925
    self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
926
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
927

    
928
    self.assertEqual(FormatUnit(1024, 'g'), '1.0')
929
    self.assertEqual(FormatUnit(1536, 'g'), '1.5')
930
    self.assertEqual(FormatUnit(17133, 'g'), '16.7')
931
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
932

    
933
    self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
934
    self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
935
    self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
936

    
937
  def testTiB(self):
938
    self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
939
    self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
940
    self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
941

    
942
    self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
943
    self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
944
    self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
945

    
946

    
947
class TestParseUnit(unittest.TestCase):
948
  """Test case for the ParseUnit function"""
949

    
950
  SCALES = (('', 1),
951
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
952
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
953
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
954

    
955
  def testRounding(self):
956
    self.assertEqual(ParseUnit('0'), 0)
957
    self.assertEqual(ParseUnit('1'), 4)
958
    self.assertEqual(ParseUnit('2'), 4)
959
    self.assertEqual(ParseUnit('3'), 4)
960

    
961
    self.assertEqual(ParseUnit('124'), 124)
962
    self.assertEqual(ParseUnit('125'), 128)
963
    self.assertEqual(ParseUnit('126'), 128)
964
    self.assertEqual(ParseUnit('127'), 128)
965
    self.assertEqual(ParseUnit('128'), 128)
966
    self.assertEqual(ParseUnit('129'), 132)
967
    self.assertEqual(ParseUnit('130'), 132)
968

    
969
  def testFloating(self):
970
    self.assertEqual(ParseUnit('0'), 0)
971
    self.assertEqual(ParseUnit('0.5'), 4)
972
    self.assertEqual(ParseUnit('1.75'), 4)
973
    self.assertEqual(ParseUnit('1.99'), 4)
974
    self.assertEqual(ParseUnit('2.00'), 4)
975
    self.assertEqual(ParseUnit('2.01'), 4)
976
    self.assertEqual(ParseUnit('3.99'), 4)
977
    self.assertEqual(ParseUnit('4.00'), 4)
978
    self.assertEqual(ParseUnit('4.01'), 8)
979
    self.assertEqual(ParseUnit('1.5G'), 1536)
980
    self.assertEqual(ParseUnit('1.8G'), 1844)
981
    self.assertEqual(ParseUnit('8.28T'), 8682212)
982

    
983
  def testSuffixes(self):
984
    for sep in ('', ' ', '   ', "\t", "\t "):
985
      for suffix, scale in TestParseUnit.SCALES:
986
        for func in (lambda x: x, str.lower, str.upper):
987
          self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
988
                           1024 * scale)
989

    
990
  def testInvalidInput(self):
991
    for sep in ('-', '_', ',', 'a'):
992
      for suffix, _ in TestParseUnit.SCALES:
993
        self.assertRaises(errors.UnitParseError, ParseUnit, '1' + sep + suffix)
994

    
995
    for suffix, _ in TestParseUnit.SCALES:
996
      self.assertRaises(errors.UnitParseError, ParseUnit, '1,3' + suffix)
997

    
998

    
999
class TestParseCpuMask(unittest.TestCase):
1000
  """Test case for the ParseCpuMask function."""
1001

    
1002
  def testWellFormed(self):
1003
    self.assertEqual(utils.ParseCpuMask(""), [])
1004
    self.assertEqual(utils.ParseCpuMask("1"), [1])
1005
    self.assertEqual(utils.ParseCpuMask("0-2,4,5-5"), [0,1,2,4,5])
1006

    
1007
  def testInvalidInput(self):
1008
    self.assertRaises(errors.ParseError,
1009
                      utils.ParseCpuMask,
1010
                      "garbage")
1011
    self.assertRaises(errors.ParseError,
1012
                      utils.ParseCpuMask,
1013
                      "0,")
1014
    self.assertRaises(errors.ParseError,
1015
                      utils.ParseCpuMask,
1016
                      "0-1-2")
1017
    self.assertRaises(errors.ParseError,
1018
                      utils.ParseCpuMask,
1019
                      "2-1")
1020

    
1021
class TestSshKeys(testutils.GanetiTestCase):
1022
  """Test case for the AddAuthorizedKey function"""
1023

    
1024
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
1025
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
1026
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
1027

    
1028
  def setUp(self):
1029
    testutils.GanetiTestCase.setUp(self)
1030
    self.tmpname = self._CreateTempFile()
1031
    handle = open(self.tmpname, 'w')
1032
    try:
1033
      handle.write("%s\n" % TestSshKeys.KEY_A)
1034
      handle.write("%s\n" % TestSshKeys.KEY_B)
1035
    finally:
1036
      handle.close()
1037

    
1038
  def testAddingNewKey(self):
1039
    utils.AddAuthorizedKey(self.tmpname,
1040
                           'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
1041

    
1042
    self.assertFileContent(self.tmpname,
1043
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1044
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1045
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1046
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
1047

    
1048
  def testAddingAlmostButNotCompletelyTheSameKey(self):
1049
    utils.AddAuthorizedKey(self.tmpname,
1050
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
1051

    
1052
    self.assertFileContent(self.tmpname,
1053
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1054
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1055
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1056
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
1057

    
1058
  def testAddingExistingKeyWithSomeMoreSpaces(self):
1059
    utils.AddAuthorizedKey(self.tmpname,
1060
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1061

    
1062
    self.assertFileContent(self.tmpname,
1063
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1064
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1065
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1066

    
1067
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
1068
    utils.RemoveAuthorizedKey(self.tmpname,
1069
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1070

    
1071
    self.assertFileContent(self.tmpname,
1072
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1073
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1074

    
1075
  def testRemovingNonExistingKey(self):
1076
    utils.RemoveAuthorizedKey(self.tmpname,
1077
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
1078

    
1079
    self.assertFileContent(self.tmpname,
1080
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1081
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1082
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1083

    
1084

    
1085
class TestEtcHosts(testutils.GanetiTestCase):
1086
  """Test functions modifying /etc/hosts"""
1087

    
1088
  def setUp(self):
1089
    testutils.GanetiTestCase.setUp(self)
1090
    self.tmpname = self._CreateTempFile()
1091
    handle = open(self.tmpname, 'w')
1092
    try:
1093
      handle.write('# This is a test file for /etc/hosts\n')
1094
      handle.write('127.0.0.1\tlocalhost\n')
1095
      handle.write('192.0.2.1 router gw\n')
1096
    finally:
1097
      handle.close()
1098

    
1099
  def testSettingNewIp(self):
1100
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost.example.com',
1101
                     ['myhost'])
1102

    
1103
    self.assertFileContent(self.tmpname,
1104
      "# This is a test file for /etc/hosts\n"
1105
      "127.0.0.1\tlocalhost\n"
1106
      "192.0.2.1 router gw\n"
1107
      "198.51.100.4\tmyhost.example.com myhost\n")
1108
    self.assertFileMode(self.tmpname, 0644)
1109

    
1110
  def testSettingExistingIp(self):
1111
    SetEtcHostsEntry(self.tmpname, '192.0.2.1', 'myhost.example.com',
1112
                     ['myhost'])
1113

    
1114
    self.assertFileContent(self.tmpname,
1115
      "# This is a test file for /etc/hosts\n"
1116
      "127.0.0.1\tlocalhost\n"
1117
      "192.0.2.1\tmyhost.example.com myhost\n")
1118
    self.assertFileMode(self.tmpname, 0644)
1119

    
1120
  def testSettingDuplicateName(self):
1121
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost', ['myhost'])
1122

    
1123
    self.assertFileContent(self.tmpname,
1124
      "# This is a test file for /etc/hosts\n"
1125
      "127.0.0.1\tlocalhost\n"
1126
      "192.0.2.1 router gw\n"
1127
      "198.51.100.4\tmyhost\n")
1128
    self.assertFileMode(self.tmpname, 0644)
1129

    
1130
  def testRemovingExistingHost(self):
1131
    RemoveEtcHostsEntry(self.tmpname, 'router')
1132

    
1133
    self.assertFileContent(self.tmpname,
1134
      "# This is a test file for /etc/hosts\n"
1135
      "127.0.0.1\tlocalhost\n"
1136
      "192.0.2.1 gw\n")
1137
    self.assertFileMode(self.tmpname, 0644)
1138

    
1139
  def testRemovingSingleExistingHost(self):
1140
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
1141

    
1142
    self.assertFileContent(self.tmpname,
1143
      "# This is a test file for /etc/hosts\n"
1144
      "192.0.2.1 router gw\n")
1145
    self.assertFileMode(self.tmpname, 0644)
1146

    
1147
  def testRemovingNonExistingHost(self):
1148
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
1149

    
1150
    self.assertFileContent(self.tmpname,
1151
      "# This is a test file for /etc/hosts\n"
1152
      "127.0.0.1\tlocalhost\n"
1153
      "192.0.2.1 router gw\n")
1154
    self.assertFileMode(self.tmpname, 0644)
1155

    
1156
  def testRemovingAlias(self):
1157
    RemoveEtcHostsEntry(self.tmpname, 'gw')
1158

    
1159
    self.assertFileContent(self.tmpname,
1160
      "# This is a test file for /etc/hosts\n"
1161
      "127.0.0.1\tlocalhost\n"
1162
      "192.0.2.1 router\n")
1163
    self.assertFileMode(self.tmpname, 0644)
1164

    
1165

    
1166
class TestGetMounts(unittest.TestCase):
1167
  """Test case for GetMounts()."""
1168

    
1169
  TESTDATA = (
1170
    "rootfs /     rootfs rw 0 0\n"
1171
    "none   /sys  sysfs  rw,nosuid,nodev,noexec,relatime 0 0\n"
1172
    "none   /proc proc   rw,nosuid,nodev,noexec,relatime 0 0\n")
1173

    
1174
  def setUp(self):
1175
    self.tmpfile = tempfile.NamedTemporaryFile()
1176
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1177

    
1178
  def testGetMounts(self):
1179
    self.assertEqual(utils.GetMounts(filename=self.tmpfile.name),
1180
      [
1181
        ("rootfs", "/", "rootfs", "rw"),
1182
        ("none", "/sys", "sysfs", "rw,nosuid,nodev,noexec,relatime"),
1183
        ("none", "/proc", "proc", "rw,nosuid,nodev,noexec,relatime"),
1184
      ])
1185

    
1186

    
1187
class TestShellQuoting(unittest.TestCase):
1188
  """Test case for shell quoting functions"""
1189

    
1190
  def testShellQuote(self):
1191
    self.assertEqual(ShellQuote('abc'), "abc")
1192
    self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
1193
    self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
1194
    self.assertEqual(ShellQuote("a b c"), "'a b c'")
1195
    self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
1196

    
1197
  def testShellQuoteArgs(self):
1198
    self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
1199
    self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
1200
    self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
1201

    
1202

    
1203
class TestListVisibleFiles(unittest.TestCase):
1204
  """Test case for ListVisibleFiles"""
1205

    
1206
  def setUp(self):
1207
    self.path = tempfile.mkdtemp()
1208

    
1209
  def tearDown(self):
1210
    shutil.rmtree(self.path)
1211

    
1212
  def _CreateFiles(self, files):
1213
    for name in files:
1214
      utils.WriteFile(os.path.join(self.path, name), data="test")
1215

    
1216
  def _test(self, files, expected):
1217
    self._CreateFiles(files)
1218
    found = ListVisibleFiles(self.path)
1219
    self.assertEqual(set(found), set(expected))
1220

    
1221
  def testAllVisible(self):
1222
    files = ["a", "b", "c"]
1223
    expected = files
1224
    self._test(files, expected)
1225

    
1226
  def testNoneVisible(self):
1227
    files = [".a", ".b", ".c"]
1228
    expected = []
1229
    self._test(files, expected)
1230

    
1231
  def testSomeVisible(self):
1232
    files = ["a", "b", ".c"]
1233
    expected = ["a", "b"]
1234
    self._test(files, expected)
1235

    
1236
  def testNonAbsolutePath(self):
1237
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1238

    
1239
  def testNonNormalizedPath(self):
1240
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1241
                          "/bin/../tmp")
1242

    
1243

    
1244
class TestNewUUID(unittest.TestCase):
1245
  """Test case for NewUUID"""
1246

    
1247
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
1248
                        '[a-f0-9]{4}-[a-f0-9]{12}$')
1249

    
1250
  def runTest(self):
1251
    self.failUnless(self._re_uuid.match(utils.NewUUID()))
1252

    
1253

    
1254
class TestUniqueSequence(unittest.TestCase):
1255
  """Test case for UniqueSequence"""
1256

    
1257
  def _test(self, input, expected):
1258
    self.assertEqual(utils.UniqueSequence(input), expected)
1259

    
1260
  def runTest(self):
1261
    # Ordered input
1262
    self._test([1, 2, 3], [1, 2, 3])
1263
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
1264
    self._test([1, 2, 2, 3], [1, 2, 3])
1265
    self._test([1, 2, 3, 3], [1, 2, 3])
1266

    
1267
    # Unordered input
1268
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
1269
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
1270

    
1271
    # Strings
1272
    self._test(["a", "a"], ["a"])
1273
    self._test(["a", "b"], ["a", "b"])
1274
    self._test(["a", "b", "a"], ["a", "b"])
1275

    
1276

    
1277
class TestFirstFree(unittest.TestCase):
1278
  """Test case for the FirstFree function"""
1279

    
1280
  def test(self):
1281
    """Test FirstFree"""
1282
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1283
    self.failUnlessEqual(FirstFree([]), None)
1284
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1285
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1286
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1287

    
1288

    
1289
class TestTailFile(testutils.GanetiTestCase):
1290
  """Test case for the TailFile function"""
1291

    
1292
  def testEmpty(self):
1293
    fname = self._CreateTempFile()
1294
    self.failUnlessEqual(TailFile(fname), [])
1295
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1296

    
1297
  def testAllLines(self):
1298
    data = ["test %d" % i for i in range(30)]
1299
    for i in range(30):
1300
      fname = self._CreateTempFile()
1301
      fd = open(fname, "w")
1302
      fd.write("\n".join(data[:i]))
1303
      if i > 0:
1304
        fd.write("\n")
1305
      fd.close()
1306
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1307

    
1308
  def testPartialLines(self):
1309
    data = ["test %d" % i for i in range(30)]
1310
    fname = self._CreateTempFile()
1311
    fd = open(fname, "w")
1312
    fd.write("\n".join(data))
1313
    fd.write("\n")
1314
    fd.close()
1315
    for i in range(1, 30):
1316
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1317

    
1318
  def testBigFile(self):
1319
    data = ["test %d" % i for i in range(30)]
1320
    fname = self._CreateTempFile()
1321
    fd = open(fname, "w")
1322
    fd.write("X" * 1048576)
1323
    fd.write("\n")
1324
    fd.write("\n".join(data))
1325
    fd.write("\n")
1326
    fd.close()
1327
    for i in range(1, 30):
1328
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1329

    
1330

    
1331
class _BaseFileLockTest:
1332
  """Test case for the FileLock class"""
1333

    
1334
  def testSharedNonblocking(self):
1335
    self.lock.Shared(blocking=False)
1336
    self.lock.Close()
1337

    
1338
  def testExclusiveNonblocking(self):
1339
    self.lock.Exclusive(blocking=False)
1340
    self.lock.Close()
1341

    
1342
  def testUnlockNonblocking(self):
1343
    self.lock.Unlock(blocking=False)
1344
    self.lock.Close()
1345

    
1346
  def testSharedBlocking(self):
1347
    self.lock.Shared(blocking=True)
1348
    self.lock.Close()
1349

    
1350
  def testExclusiveBlocking(self):
1351
    self.lock.Exclusive(blocking=True)
1352
    self.lock.Close()
1353

    
1354
  def testUnlockBlocking(self):
1355
    self.lock.Unlock(blocking=True)
1356
    self.lock.Close()
1357

    
1358
  def testSharedExclusiveUnlock(self):
1359
    self.lock.Shared(blocking=False)
1360
    self.lock.Exclusive(blocking=False)
1361
    self.lock.Unlock(blocking=False)
1362
    self.lock.Close()
1363

    
1364
  def testExclusiveSharedUnlock(self):
1365
    self.lock.Exclusive(blocking=False)
1366
    self.lock.Shared(blocking=False)
1367
    self.lock.Unlock(blocking=False)
1368
    self.lock.Close()
1369

    
1370
  def testSimpleTimeout(self):
1371
    # These will succeed on the first attempt, hence a short timeout
1372
    self.lock.Shared(blocking=True, timeout=10.0)
1373
    self.lock.Exclusive(blocking=False, timeout=10.0)
1374
    self.lock.Unlock(blocking=True, timeout=10.0)
1375
    self.lock.Close()
1376

    
1377
  @staticmethod
1378
  def _TryLockInner(filename, shared, blocking):
1379
    lock = utils.FileLock.Open(filename)
1380

    
1381
    if shared:
1382
      fn = lock.Shared
1383
    else:
1384
      fn = lock.Exclusive
1385

    
1386
    try:
1387
      # The timeout doesn't really matter as the parent process waits for us to
1388
      # finish anyway.
1389
      fn(blocking=blocking, timeout=0.01)
1390
    except errors.LockError, err:
1391
      return False
1392

    
1393
    return True
1394

    
1395
  def _TryLock(self, *args):
1396
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1397
                                      *args)
1398

    
1399
  def testTimeout(self):
1400
    for blocking in [True, False]:
1401
      self.lock.Exclusive(blocking=True)
1402
      self.failIf(self._TryLock(False, blocking))
1403
      self.failIf(self._TryLock(True, blocking))
1404

    
1405
      self.lock.Shared(blocking=True)
1406
      self.assert_(self._TryLock(True, blocking))
1407
      self.failIf(self._TryLock(False, blocking))
1408

    
1409
  def testCloseShared(self):
1410
    self.lock.Close()
1411
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1412

    
1413
  def testCloseExclusive(self):
1414
    self.lock.Close()
1415
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1416

    
1417
  def testCloseUnlock(self):
1418
    self.lock.Close()
1419
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1420

    
1421

    
1422
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1423
  TESTDATA = "Hello World\n" * 10
1424

    
1425
  def setUp(self):
1426
    testutils.GanetiTestCase.setUp(self)
1427

    
1428
    self.tmpfile = tempfile.NamedTemporaryFile()
1429
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1430
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1431

    
1432
    # Ensure "Open" didn't truncate file
1433
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1434

    
1435
  def tearDown(self):
1436
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1437

    
1438
    testutils.GanetiTestCase.tearDown(self)
1439

    
1440

    
1441
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1442
  def setUp(self):
1443
    self.tmpfile = tempfile.NamedTemporaryFile()
1444
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1445

    
1446

    
1447
class TestTimeFunctions(unittest.TestCase):
1448
  """Test case for time functions"""
1449

    
1450
  def runTest(self):
1451
    self.assertEqual(utils.SplitTime(1), (1, 0))
1452
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1453
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1454
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1455
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1456
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1457
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1458
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1459

    
1460
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1461

    
1462
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1463
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1464
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1465

    
1466
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1467
                     1218448917.481)
1468
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1469

    
1470
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1471
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1472
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1473
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1474
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1475

    
1476

    
1477
class FieldSetTestCase(unittest.TestCase):
1478
  """Test case for FieldSets"""
1479

    
1480
  def testSimpleMatch(self):
1481
    f = utils.FieldSet("a", "b", "c", "def")
1482
    self.failUnless(f.Matches("a"))
1483
    self.failIf(f.Matches("d"), "Substring matched")
1484
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1485
    self.failIf(f.NonMatching(["b", "c"]))
1486
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1487
    self.failUnless(f.NonMatching(["a", "d"]))
1488

    
1489
  def testRegexMatch(self):
1490
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1491
    self.failUnless(f.Matches("b1"))
1492
    self.failUnless(f.Matches("b99"))
1493
    self.failIf(f.Matches("b/1"))
1494
    self.failIf(f.NonMatching(["b12", "c"]))
1495
    self.failUnless(f.NonMatching(["a", "1"]))
1496

    
1497
class TestForceDictType(unittest.TestCase):
1498
  """Test case for ForceDictType"""
1499

    
1500
  def setUp(self):
1501
    self.key_types = {
1502
      'a': constants.VTYPE_INT,
1503
      'b': constants.VTYPE_BOOL,
1504
      'c': constants.VTYPE_STRING,
1505
      'd': constants.VTYPE_SIZE,
1506
      }
1507

    
1508
  def _fdt(self, dict, allowed_values=None):
1509
    if allowed_values is None:
1510
      utils.ForceDictType(dict, self.key_types)
1511
    else:
1512
      utils.ForceDictType(dict, self.key_types, allowed_values=allowed_values)
1513

    
1514
    return dict
1515

    
1516
  def testSimpleDict(self):
1517
    self.assertEqual(self._fdt({}), {})
1518
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1519
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1520
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1521
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1522
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1523
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1524
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1525
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1526
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1527
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1528
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1529

    
1530
  def testErrors(self):
1531
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1532
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1533
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1534
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1535

    
1536

    
1537
class TestIsNormAbsPath(unittest.TestCase):
1538
  """Testing case for IsNormAbsPath"""
1539

    
1540
  def _pathTestHelper(self, path, result):
1541
    if result:
1542
      self.assert_(utils.IsNormAbsPath(path),
1543
          "Path %s should result absolute and normalized" % path)
1544
    else:
1545
      self.assertFalse(utils.IsNormAbsPath(path),
1546
          "Path %s should not result absolute and normalized" % path)
1547

    
1548
  def testBase(self):
1549
    self._pathTestHelper('/etc', True)
1550
    self._pathTestHelper('/srv', True)
1551
    self._pathTestHelper('etc', False)
1552
    self._pathTestHelper('/etc/../root', False)
1553
    self._pathTestHelper('/etc/', False)
1554

    
1555

    
1556
class TestSafeEncode(unittest.TestCase):
1557
  """Test case for SafeEncode"""
1558

    
1559
  def testAscii(self):
1560
    for txt in [string.digits, string.letters, string.punctuation]:
1561
      self.failUnlessEqual(txt, SafeEncode(txt))
1562

    
1563
  def testDoubleEncode(self):
1564
    for i in range(255):
1565
      txt = SafeEncode(chr(i))
1566
      self.failUnlessEqual(txt, SafeEncode(txt))
1567

    
1568
  def testUnicode(self):
1569
    # 1024 is high enough to catch non-direct ASCII mappings
1570
    for i in range(1024):
1571
      txt = SafeEncode(unichr(i))
1572
      self.failUnlessEqual(txt, SafeEncode(txt))
1573

    
1574

    
1575
class TestFormatTime(unittest.TestCase):
1576
  """Testing case for FormatTime"""
1577

    
1578
  def testNone(self):
1579
    self.failUnlessEqual(FormatTime(None), "N/A")
1580

    
1581
  def testInvalid(self):
1582
    self.failUnlessEqual(FormatTime(()), "N/A")
1583

    
1584
  def testNow(self):
1585
    # tests that we accept time.time input
1586
    FormatTime(time.time())
1587
    # tests that we accept int input
1588
    FormatTime(int(time.time()))
1589

    
1590

    
1591
class RunInSeparateProcess(unittest.TestCase):
1592
  def test(self):
1593
    for exp in [True, False]:
1594
      def _child():
1595
        return exp
1596

    
1597
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1598

    
1599
  def testArgs(self):
1600
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1601
      def _child(carg1, carg2):
1602
        return carg1 == "Foo" and carg2 == arg
1603

    
1604
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1605

    
1606
  def testPid(self):
1607
    parent_pid = os.getpid()
1608

    
1609
    def _check():
1610
      return os.getpid() == parent_pid
1611

    
1612
    self.failIf(utils.RunInSeparateProcess(_check))
1613

    
1614
  def testSignal(self):
1615
    def _kill():
1616
      os.kill(os.getpid(), signal.SIGTERM)
1617

    
1618
    self.assertRaises(errors.GenericError,
1619
                      utils.RunInSeparateProcess, _kill)
1620

    
1621
  def testException(self):
1622
    def _exc():
1623
      raise errors.GenericError("This is a test")
1624

    
1625
    self.assertRaises(errors.GenericError,
1626
                      utils.RunInSeparateProcess, _exc)
1627

    
1628

    
1629
class TestFingerprintFile(unittest.TestCase):
1630
  def setUp(self):
1631
    self.tmpfile = tempfile.NamedTemporaryFile()
1632

    
1633
  def test(self):
1634
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1635
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")
1636

    
1637
    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
1638
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1639
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1640

    
1641

    
1642
class TestUnescapeAndSplit(unittest.TestCase):
1643
  """Testing case for UnescapeAndSplit"""
1644

    
1645
  def setUp(self):
1646
    # testing more that one separator for regexp safety
1647
    self._seps = [",", "+", "."]
1648

    
1649
  def testSimple(self):
1650
    a = ["a", "b", "c", "d"]
1651
    for sep in self._seps:
1652
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1653

    
1654
  def testEscape(self):
1655
    for sep in self._seps:
1656
      a = ["a", "b\\" + sep + "c", "d"]
1657
      b = ["a", "b" + sep + "c", "d"]
1658
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1659

    
1660
  def testDoubleEscape(self):
1661
    for sep in self._seps:
1662
      a = ["a", "b\\\\", "c", "d"]
1663
      b = ["a", "b\\", "c", "d"]
1664
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1665

    
1666
  def testThreeEscape(self):
1667
    for sep in self._seps:
1668
      a = ["a", "b\\\\\\" + sep + "c", "d"]
1669
      b = ["a", "b\\" + sep + "c", "d"]
1670
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1671

    
1672

    
1673
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1674
  def setUp(self):
1675
    self.tmpdir = tempfile.mkdtemp()
1676

    
1677
  def tearDown(self):
1678
    shutil.rmtree(self.tmpdir)
1679

    
1680
  def _checkRsaPrivateKey(self, key):
1681
    lines = key.splitlines()
1682
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1683
            "-----END RSA PRIVATE KEY-----" in lines)
1684

    
1685
  def _checkCertificate(self, cert):
1686
    lines = cert.splitlines()
1687
    return ("-----BEGIN CERTIFICATE-----" in lines and
1688
            "-----END CERTIFICATE-----" in lines)
1689

    
1690
  def test(self):
1691
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1692
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1693
      self._checkRsaPrivateKey(key_pem)
1694
      self._checkCertificate(cert_pem)
1695

    
1696
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1697
                                           key_pem)
1698
      self.assert_(key.bits() >= 1024)
1699
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1700
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1701

    
1702
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1703
                                             cert_pem)
1704
      self.failIf(x509.has_expired())
1705
      self.assertEqual(x509.get_issuer().CN, common_name)
1706
      self.assertEqual(x509.get_subject().CN, common_name)
1707
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1708

    
1709
  def testLegacy(self):
1710
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1711

    
1712
    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1713

    
1714
    cert1 = utils.ReadFile(cert1_filename)
1715

    
1716
    self.assert_(self._checkRsaPrivateKey(cert1))
1717
    self.assert_(self._checkCertificate(cert1))
1718

    
1719

    
1720
class TestPathJoin(unittest.TestCase):
1721
  """Testing case for PathJoin"""
1722

    
1723
  def testBasicItems(self):
1724
    mlist = ["/a", "b", "c"]
1725
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1726

    
1727
  def testNonAbsPrefix(self):
1728
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1729

    
1730
  def testBackTrack(self):
1731
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1732

    
1733
  def testMultiAbs(self):
1734
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1735

    
1736

    
1737
class TestValidateServiceName(unittest.TestCase):
1738
  def testValid(self):
1739
    testnames = [
1740
      0, 1, 2, 3, 1024, 65000, 65534, 65535,
1741
      "ganeti",
1742
      "gnt-masterd",
1743
      "HELLO_WORLD_SVC",
1744
      "hello.world.1",
1745
      "0", "80", "1111", "65535",
1746
      ]
1747

    
1748
    for name in testnames:
1749
      self.assertEqual(utils.ValidateServiceName(name), name)
1750

    
1751
  def testInvalid(self):
1752
    testnames = [
1753
      -15756, -1, 65536, 133428083,
1754
      "", "Hello World!", "!", "'", "\"", "\t", "\n", "`",
1755
      "-8546", "-1", "65536",
1756
      (129 * "A"),
1757
      ]
1758

    
1759
    for name in testnames:
1760
      self.assertRaises(errors.OpPrereqError, utils.ValidateServiceName, name)
1761

    
1762

    
1763
class TestParseAsn1Generalizedtime(unittest.TestCase):
1764
  def test(self):
1765
    # UTC
1766
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1767
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1768
                     1266860512)
1769
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1770
                     (2**31) - 1)
1771

    
1772
    # With offset
1773
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1774
                     1266860512)
1775
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1776
                     1266931012)
1777
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1778
                     1266931088)
1779
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1780
                     1266931295)
1781
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1782
                     3600)
1783

    
1784
    # Leap seconds are not supported by datetime.datetime
1785
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1786
                      "19841231235960+0000")
1787
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1788
                      "19920630235960+0000")
1789

    
1790
    # Errors
1791
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1792
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1793
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1794
                      "20100222174152")
1795
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1796
                      "Mon Feb 22 17:47:02 UTC 2010")
1797
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1798
                      "2010-02-22 17:42:02")
1799

    
1800

    
1801
class TestGetX509CertValidity(testutils.GanetiTestCase):
1802
  def setUp(self):
1803
    testutils.GanetiTestCase.setUp(self)
1804

    
1805
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1806

    
1807
    # Test whether we have pyOpenSSL 0.7 or above
1808
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1809

    
1810
    if not self.pyopenssl0_7:
1811
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1812
                    " function correctly")
1813

    
1814
  def _LoadCert(self, name):
1815
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1816
                                           self._ReadTestData(name))
1817

    
1818
  def test(self):
1819
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1820
    if self.pyopenssl0_7:
1821
      self.assertEqual(validity, (1266919967, 1267524767))
1822
    else:
1823
      self.assertEqual(validity, (None, None))
1824

    
1825

    
1826
class TestSignX509Certificate(unittest.TestCase):
1827
  KEY = "My private key!"
1828
  KEY_OTHER = "Another key"
1829

    
1830
  def test(self):
1831
    # Generate certificate valid for 5 minutes
1832
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
1833

    
1834
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1835
                                           cert_pem)
1836

    
1837
    # No signature at all
1838
    self.assertRaises(errors.GenericError,
1839
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
1840

    
1841
    # Invalid input
1842
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1843
                      "", self.KEY)
1844
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1845
                      "X-Ganeti-Signature: \n", self.KEY)
1846
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1847
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
1848
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1849
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
1850
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1851
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
1852

    
1853
    # Invalid salt
1854
    for salt in list("-_@$,:;/\\ \t\n"):
1855
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
1856
                        cert_pem, self.KEY, "foo%sbar" % salt)
1857

    
1858
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
1859
                 utils.GenerateSecret(numbytes=4),
1860
                 utils.GenerateSecret(numbytes=16),
1861
                 "{123:456}".encode("hex")]:
1862
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
1863

    
1864
      self._Check(cert, salt, signed_pem)
1865

    
1866
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
1867
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
1868
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
1869
                               "lines----\n------ at\nthe end!"))
1870

    
1871
  def _Check(self, cert, salt, pem):
1872
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
1873
    self.assertEqual(salt, salt2)
1874
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
1875

    
1876
    # Other key
1877
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1878
                      pem, self.KEY_OTHER)
1879

    
1880

    
1881
class TestMakedirs(unittest.TestCase):
1882
  def setUp(self):
1883
    self.tmpdir = tempfile.mkdtemp()
1884

    
1885
  def tearDown(self):
1886
    shutil.rmtree(self.tmpdir)
1887

    
1888
  def testNonExisting(self):
1889
    path = PathJoin(self.tmpdir, "foo")
1890
    utils.Makedirs(path)
1891
    self.assert_(os.path.isdir(path))
1892

    
1893
  def testExisting(self):
1894
    path = PathJoin(self.tmpdir, "foo")
1895
    os.mkdir(path)
1896
    utils.Makedirs(path)
1897
    self.assert_(os.path.isdir(path))
1898

    
1899
  def testRecursiveNonExisting(self):
1900
    path = PathJoin(self.tmpdir, "foo/bar/baz")
1901
    utils.Makedirs(path)
1902
    self.assert_(os.path.isdir(path))
1903

    
1904
  def testRecursiveExisting(self):
1905
    path = PathJoin(self.tmpdir, "B/moo/xyz")
1906
    self.assertFalse(os.path.exists(path))
1907
    os.mkdir(PathJoin(self.tmpdir, "B"))
1908
    utils.Makedirs(path)
1909
    self.assert_(os.path.isdir(path))
1910

    
1911

    
1912
class TestRetry(testutils.GanetiTestCase):
1913
  def setUp(self):
1914
    testutils.GanetiTestCase.setUp(self)
1915
    self.retries = 0
1916

    
1917
  @staticmethod
1918
  def _RaiseRetryAgain():
1919
    raise utils.RetryAgain()
1920

    
1921
  @staticmethod
1922
  def _RaiseRetryAgainWithArg(args):
1923
    raise utils.RetryAgain(*args)
1924

    
1925
  def _WrongNestedLoop(self):
1926
    return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
1927

    
1928
  def _RetryAndSucceed(self, retries):
1929
    if self.retries < retries:
1930
      self.retries += 1
1931
      raise utils.RetryAgain()
1932
    else:
1933
      return True
1934

    
1935
  def testRaiseTimeout(self):
1936
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1937
                          self._RaiseRetryAgain, 0.01, 0.02)
1938
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1939
                          self._RetryAndSucceed, 0.01, 0, args=[1])
1940
    self.failUnlessEqual(self.retries, 1)
1941

    
1942
  def testComplete(self):
1943
    self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
1944
    self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
1945
                         True)
1946
    self.failUnlessEqual(self.retries, 2)
1947

    
1948
  def testNestedLoop(self):
1949
    try:
1950
      self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
1951
                            self._WrongNestedLoop, 0, 1)
1952
    except utils.RetryTimeout:
1953
      self.fail("Didn't detect inner loop's exception")
1954

    
1955
  def testTimeoutArgument(self):
1956
    retry_arg="my_important_debugging_message"
1957
    try:
1958
      utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
1959
    except utils.RetryTimeout, err:
1960
      self.failUnlessEqual(err.args, (retry_arg, ))
1961
    else:
1962
      self.fail("Expected timeout didn't happen")
1963

    
1964
  def testRaiseInnerWithExc(self):
1965
    retry_arg="my_important_debugging_message"
1966
    try:
1967
      try:
1968
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
1969
                    args=[[errors.GenericError(retry_arg, retry_arg)]])
1970
      except utils.RetryTimeout, err:
1971
        err.RaiseInner()
1972
      else:
1973
        self.fail("Expected timeout didn't happen")
1974
    except errors.GenericError, err:
1975
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
1976
    else:
1977
      self.fail("Expected GenericError didn't happen")
1978

    
1979
  def testRaiseInnerWithMsg(self):
1980
    retry_arg="my_important_debugging_message"
1981
    try:
1982
      try:
1983
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
1984
                    args=[[retry_arg, retry_arg]])
1985
      except utils.RetryTimeout, err:
1986
        err.RaiseInner()
1987
      else:
1988
        self.fail("Expected timeout didn't happen")
1989
    except utils.RetryTimeout, err:
1990
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
1991
    else:
1992
      self.fail("Expected RetryTimeout didn't happen")
1993

    
1994

    
1995
class TestLineSplitter(unittest.TestCase):
1996
  def test(self):
1997
    lines = []
1998
    ls = utils.LineSplitter(lines.append)
1999
    ls.write("Hello World\n")
2000
    self.assertEqual(lines, [])
2001
    ls.write("Foo\n Bar\r\n ")
2002
    ls.write("Baz")
2003
    ls.write("Moo")
2004
    self.assertEqual(lines, [])
2005
    ls.flush()
2006
    self.assertEqual(lines, ["Hello World", "Foo", " Bar"])
2007
    ls.close()
2008
    self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"])
2009

    
2010
  def _testExtra(self, line, all_lines, p1, p2):
2011
    self.assertEqual(p1, 999)
2012
    self.assertEqual(p2, "extra")
2013
    all_lines.append(line)
2014

    
2015
  def testExtraArgsNoFlush(self):
2016
    lines = []
2017
    ls = utils.LineSplitter(self._testExtra, lines, 999, "extra")
2018
    ls.write("\n\nHello World\n")
2019
    ls.write("Foo\n Bar\r\n ")
2020
    ls.write("")
2021
    ls.write("Baz")
2022
    ls.write("Moo\n\nx\n")
2023
    self.assertEqual(lines, [])
2024
    ls.close()
2025
    self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo",
2026
                             "", "x"])
2027

    
2028

    
2029
class TestReadLockedPidFile(unittest.TestCase):
2030
  def setUp(self):
2031
    self.tmpdir = tempfile.mkdtemp()
2032

    
2033
  def tearDown(self):
2034
    shutil.rmtree(self.tmpdir)
2035

    
2036
  def testNonExistent(self):
2037
    path = PathJoin(self.tmpdir, "nonexist")
2038
    self.assert_(utils.ReadLockedPidFile(path) is None)
2039

    
2040
  def testUnlocked(self):
2041
    path = PathJoin(self.tmpdir, "pid")
2042
    utils.WriteFile(path, data="123")
2043
    self.assert_(utils.ReadLockedPidFile(path) is None)
2044

    
2045
  def testLocked(self):
2046
    path = PathJoin(self.tmpdir, "pid")
2047
    utils.WriteFile(path, data="123")
2048

    
2049
    fl = utils.FileLock.Open(path)
2050
    try:
2051
      fl.Exclusive(blocking=True)
2052

    
2053
      self.assertEqual(utils.ReadLockedPidFile(path), 123)
2054
    finally:
2055
      fl.Close()
2056

    
2057
    self.assert_(utils.ReadLockedPidFile(path) is None)
2058

    
2059
  def testError(self):
2060
    path = PathJoin(self.tmpdir, "foobar", "pid")
2061
    utils.WriteFile(PathJoin(self.tmpdir, "foobar"), data="")
2062
    # open(2) should return ENOTDIR
2063
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
2064

    
2065

    
2066
class TestCertVerification(testutils.GanetiTestCase):
2067
  def setUp(self):
2068
    testutils.GanetiTestCase.setUp(self)
2069

    
2070
    self.tmpdir = tempfile.mkdtemp()
2071

    
2072
  def tearDown(self):
2073
    shutil.rmtree(self.tmpdir)
2074

    
2075
  def testVerifyCertificate(self):
2076
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
2077
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
2078
                                           cert_pem)
2079

    
2080
    # Not checking return value as this certificate is expired
2081
    utils.VerifyX509Certificate(cert, 30, 7)
2082

    
2083

    
2084
class TestVerifyCertificateInner(unittest.TestCase):
2085
  def test(self):
2086
    vci = utils._VerifyCertificateInner
2087

    
2088
    # Valid
2089
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
2090
                     (None, None))
2091

    
2092
    # Not yet valid
2093
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
2094
    self.assertEqual(errcode, utils.CERT_WARNING)
2095

    
2096
    # Expiring soon
2097
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
2098
    self.assertEqual(errcode, utils.CERT_ERROR)
2099

    
2100
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
2101
    self.assertEqual(errcode, utils.CERT_WARNING)
2102

    
2103
    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
2104
    self.assertEqual(errcode, None)
2105

    
2106
    # Expired
2107
    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
2108
    self.assertEqual(errcode, utils.CERT_ERROR)
2109

    
2110
    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
2111
    self.assertEqual(errcode, utils.CERT_ERROR)
2112

    
2113
    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
2114
    self.assertEqual(errcode, utils.CERT_ERROR)
2115

    
2116
    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
2117
    self.assertEqual(errcode, utils.CERT_ERROR)
2118

    
2119

    
2120
class TestHmacFunctions(unittest.TestCase):
2121
  # Digests can be checked with "openssl sha1 -hmac $key"
2122
  def testSha1Hmac(self):
2123
    self.assertEqual(utils.Sha1Hmac("", ""),
2124
                     "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d")
2125
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World"),
2126
                     "ef4f3bda82212ecb2f7ce868888a19092481f1fd")
2127
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", ""),
2128
                     "f904c2476527c6d3e6609ab683c66fa0652cb1dc")
2129

    
2130
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
2131
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", longtext),
2132
                     "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54")
2133

    
2134
  def testSha1HmacSalt(self):
2135
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc0"),
2136
                     "4999bf342470eadb11dfcd24ca5680cf9fd7cdce")
2137
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc9"),
2138
                     "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8")
2139
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World", salt="xyz0"),
2140
                     "7f264f8114c9066afc9bb7636e1786d996d3cc0d")
2141

    
2142
  def testVerifySha1Hmac(self):
2143
    self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b"
2144
                                               "7d64b71fb76370690e1d")))
2145
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2146
                                      ("f904c2476527c6d3e660"
2147
                                       "9ab683c66fa0652cb1dc")))
2148

    
2149
    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
2150
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", digest))
2151
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2152
                                      digest.lower()))
2153
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2154
                                      digest.upper()))
2155
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2156
                                      digest.title()))
2157

    
2158
  def testVerifySha1HmacSalt(self):
2159
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2160
                                      ("17a4adc34d69c0d367d4"
2161
                                       "ffbef96fd41d4df7a6e8"),
2162
                                      salt="abc9"))
2163
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2164
                                      ("7f264f8114c9066afc9b"
2165
                                       "b7636e1786d996d3cc0d"),
2166
                                      salt="xyz0"))
2167

    
2168

    
2169
class TestIgnoreSignals(unittest.TestCase):
2170
  """Test the IgnoreSignals decorator"""
2171

    
2172
  @staticmethod
2173
  def _Raise(exception):
2174
    raise exception
2175

    
2176
  @staticmethod
2177
  def _Return(rval):
2178
    return rval
2179

    
2180
  def testIgnoreSignals(self):
2181
    sock_err_intr = socket.error(errno.EINTR, "Message")
2182
    sock_err_inval = socket.error(errno.EINVAL, "Message")
2183

    
2184
    env_err_intr = EnvironmentError(errno.EINTR, "Message")
2185
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")
2186

    
2187
    self.assertRaises(socket.error, self._Raise, sock_err_intr)
2188
    self.assertRaises(socket.error, self._Raise, sock_err_inval)
2189
    self.assertRaises(EnvironmentError, self._Raise, env_err_intr)
2190
    self.assertRaises(EnvironmentError, self._Raise, env_err_inval)
2191

    
2192
    self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None)
2193
    self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None)
2194
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
2195
                      sock_err_inval)
2196
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
2197
                      env_err_inval)
2198

    
2199
    self.assertEquals(utils.IgnoreSignals(self._Return, True), True)
2200
    self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33)
2201

    
2202

    
2203
class TestEnsureDirs(unittest.TestCase):
2204
  """Tests for EnsureDirs"""
2205

    
2206
  def setUp(self):
2207
    self.dir = tempfile.mkdtemp()
2208
    self.old_umask = os.umask(0777)
2209

    
2210
  def testEnsureDirs(self):
2211
    utils.EnsureDirs([
2212
        (PathJoin(self.dir, "foo"), 0777),
2213
        (PathJoin(self.dir, "bar"), 0000),
2214
        ])
2215
    self.assertEquals(os.stat(PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2216
    self.assertEquals(os.stat(PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2217

    
2218
  def tearDown(self):
2219
    os.rmdir(PathJoin(self.dir, "foo"))
2220
    os.rmdir(PathJoin(self.dir, "bar"))
2221
    os.rmdir(self.dir)
2222
    os.umask(self.old_umask)
2223

    
2224

    
2225
class TestFormatSeconds(unittest.TestCase):
2226
  def test(self):
2227
    self.assertEqual(utils.FormatSeconds(1), "1s")
2228
    self.assertEqual(utils.FormatSeconds(3600), "1h 0m 0s")
2229
    self.assertEqual(utils.FormatSeconds(3599), "59m 59s")
2230
    self.assertEqual(utils.FormatSeconds(7200), "2h 0m 0s")
2231
    self.assertEqual(utils.FormatSeconds(7201), "2h 0m 1s")
2232
    self.assertEqual(utils.FormatSeconds(7281), "2h 1m 21s")
2233
    self.assertEqual(utils.FormatSeconds(29119), "8h 5m 19s")
2234
    self.assertEqual(utils.FormatSeconds(19431228), "224d 21h 33m 48s")
2235
    self.assertEqual(utils.FormatSeconds(-1), "-1s")
2236
    self.assertEqual(utils.FormatSeconds(-282), "-282s")
2237
    self.assertEqual(utils.FormatSeconds(-29119), "-29119s")
2238

    
2239
  def testFloat(self):
2240
    self.assertEqual(utils.FormatSeconds(1.3), "1s")
2241
    self.assertEqual(utils.FormatSeconds(1.9), "2s")
2242
    self.assertEqual(utils.FormatSeconds(3912.12311), "1h 5m 12s")
2243
    self.assertEqual(utils.FormatSeconds(3912.8), "1h 5m 13s")
2244

    
2245

    
2246
class TestIgnoreProcessNotFound(unittest.TestCase):
2247
  @staticmethod
2248
  def _WritePid(fd):
2249
    os.write(fd, str(os.getpid()))
2250
    os.close(fd)
2251
    return True
2252

    
2253
  def test(self):
2254
    (pid_read_fd, pid_write_fd) = os.pipe()
2255

    
2256
    # Start short-lived process which writes its PID to pipe
2257
    self.assert_(utils.RunInSeparateProcess(self._WritePid, pid_write_fd))
2258
    os.close(pid_write_fd)
2259

    
2260
    # Read PID from pipe
2261
    pid = int(os.read(pid_read_fd, 1024))
2262
    os.close(pid_read_fd)
2263

    
2264
    # Try to send signal to process which exited recently
2265
    self.assertFalse(utils.IgnoreProcessNotFound(os.kill, pid, 0))
2266

    
2267

    
2268
class TestShellWriter(unittest.TestCase):
2269
  def test(self):
2270
    buf = StringIO()
2271
    sw = utils.ShellWriter(buf)
2272
    sw.Write("#!/bin/bash")
2273
    sw.Write("if true; then")
2274
    sw.IncIndent()
2275
    try:
2276
      sw.Write("echo true")
2277

    
2278
      sw.Write("for i in 1 2 3")
2279
      sw.Write("do")
2280
      sw.IncIndent()
2281
      try:
2282
        self.assertEqual(sw._indent, 2)
2283
        sw.Write("date")
2284
      finally:
2285
        sw.DecIndent()
2286
      sw.Write("done")
2287
    finally:
2288
      sw.DecIndent()
2289
    sw.Write("echo %s", utils.ShellQuote("Hello World"))
2290
    sw.Write("exit 0")
2291

    
2292
    self.assertEqual(sw._indent, 0)
2293

    
2294
    output = buf.getvalue()
2295

    
2296
    self.assert_(output.endswith("\n"))
2297

    
2298
    lines = output.splitlines()
2299
    self.assertEqual(len(lines), 9)
2300
    self.assertEqual(lines[0], "#!/bin/bash")
2301
    self.assert_(re.match(r"^\s+date$", lines[5]))
2302
    self.assertEqual(lines[7], "echo 'Hello World'")
2303

    
2304
  def testEmpty(self):
2305
    buf = StringIO()
2306
    sw = utils.ShellWriter(buf)
2307
    sw = None
2308
    self.assertEqual(buf.getvalue(), "")
2309

    
2310

    
2311
if __name__ == '__main__':
2312
  testutils.GanetiTestProgram()