Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ f8ea4ada

History | View | Annotate | Download (80.4 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import unittest
25
import os
26
import time
27
import tempfile
28
import os.path
29
import os
30
import stat
31
import signal
32
import socket
33
import shutil
34
import re
35
import select
36
import string
37
import fcntl
38
import OpenSSL
39
import warnings
40
import distutils.version
41
import glob
42
import errno
43

    
44
import ganeti
45
import testutils
46
from ganeti import constants
47
from ganeti import compat
48
from ganeti import utils
49
from ganeti import errors
50
from ganeti import serializer
51
from ganeti.utils import IsProcessAlive, RunCmd, \
52
     RemoveFile, MatchNameComponent, FormatUnit, \
53
     ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
54
     ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
55
     SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
56
     TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
57
     UnescapeAndSplit, RunParts, PathJoin, HostInfo, ReadOneLineFile
58

    
59
from ganeti.errors import LockError, UnitParseError, GenericError, \
60
     ProgrammerError, OpPrereqError
61

    
62

    
63
class TestIsProcessAlive(unittest.TestCase):
64
  """Testing case for IsProcessAlive"""
65

    
66
  def testExists(self):
67
    mypid = os.getpid()
68
    self.assert_(IsProcessAlive(mypid),
69
                 "can't find myself running")
70

    
71
  def testNotExisting(self):
72
    pid_non_existing = os.fork()
73
    if pid_non_existing == 0:
74
      os._exit(0)
75
    elif pid_non_existing < 0:
76
      raise SystemError("can't fork")
77
    os.waitpid(pid_non_existing, 0)
78
    self.assert_(not IsProcessAlive(pid_non_existing),
79
                 "nonexisting process detected")
80

    
81

    
82
class TestGetProcStatusPath(unittest.TestCase):
83
  def test(self):
84
    self.assert_("/1234/" in utils._GetProcStatusPath(1234))
85
    self.assertNotEqual(utils._GetProcStatusPath(1),
86
                        utils._GetProcStatusPath(2))
87

    
88

    
89
class TestIsProcessHandlingSignal(unittest.TestCase):
90
  def setUp(self):
91
    self.tmpdir = tempfile.mkdtemp()
92

    
93
  def tearDown(self):
94
    shutil.rmtree(self.tmpdir)
95

    
96
  def testParseSigsetT(self):
97
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
98
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
99
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
100
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
101
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
102
                     set([2, 10, 32, 33]))
103
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
104
                     set([2, 32, 33]))
105
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
106
                     set([2, 28, 32, 33]))
107
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
108
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
109
                          24, 25, 26, 28, 31]))
110
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))
111

    
112
  def testGetProcStatusField(self):
113
    for field in ["SigCgt", "Name", "FDSize"]:
114
      for value in ["", "0", "cat", "  1234 KB"]:
115
        pstatus = "\n".join([
116
          "VmPeak: 999 kB",
117
          "%s: %s" % (field, value),
118
          "TracerPid: 0",
119
          ])
120
        result = utils._GetProcStatusField(pstatus, field)
121
        self.assertEqual(result, value.strip())
122

    
123
  def test(self):
124
    sp = utils.PathJoin(self.tmpdir, "status")
125

    
126
    utils.WriteFile(sp, data="\n".join([
127
      "Name:   bash",
128
      "State:  S (sleeping)",
129
      "SleepAVG:       98%",
130
      "Pid:    22250",
131
      "PPid:   10858",
132
      "TracerPid:      0",
133
      "SigBlk: 0000000000010000",
134
      "SigIgn: 0000000000384004",
135
      "SigCgt: 000000004b813efb",
136
      "CapEff: 0000000000000000",
137
      ]))
138

    
139
    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
140

    
141
  def testNoSigCgt(self):
142
    sp = utils.PathJoin(self.tmpdir, "status")
143

    
144
    utils.WriteFile(sp, data="\n".join([
145
      "Name:   bash",
146
      ]))
147

    
148
    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
149
                      1234, 10, status_path=sp)
150

    
151
  def testNoSuchFile(self):
152
    sp = utils.PathJoin(self.tmpdir, "notexist")
153

    
154
    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
155

    
156
  @staticmethod
157
  def _TestRealProcess():
158
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
159
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
160
      raise Exception("SIGUSR1 is handled when it should not be")
161

    
162
    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
163
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
164
      raise Exception("SIGUSR1 is not handled when it should be")
165

    
166
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
167
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
168
      raise Exception("SIGUSR1 is not handled when it should be")
169

    
170
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
171
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
172
      raise Exception("SIGUSR1 is handled when it should not be")
173

    
174
    return True
175

    
176
  def testRealProcess(self):
177
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
178

    
179

    
180
class TestPidFileFunctions(unittest.TestCase):
181
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
182

    
183
  def setUp(self):
184
    self.dir = tempfile.mkdtemp()
185
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
186
    utils.DaemonPidFileName = self.f_dpn
187

    
188
  def testPidFileFunctions(self):
189
    pid_file = self.f_dpn('test')
190
    utils.WritePidFile('test')
191
    self.failUnless(os.path.exists(pid_file),
192
                    "PID file should have been created")
193
    read_pid = utils.ReadPidFile(pid_file)
194
    self.failUnlessEqual(read_pid, os.getpid())
195
    self.failUnless(utils.IsProcessAlive(read_pid))
196
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
197
    utils.RemovePidFile('test')
198
    self.failIf(os.path.exists(pid_file),
199
                "PID file should not exist anymore")
200
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
201
                         "ReadPidFile should return 0 for missing pid file")
202
    fh = open(pid_file, "w")
203
    fh.write("blah\n")
204
    fh.close()
205
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
206
                         "ReadPidFile should return 0 for invalid pid file")
207
    utils.RemovePidFile('test')
208
    self.failIf(os.path.exists(pid_file),
209
                "PID file should not exist anymore")
210

    
211
  def testKill(self):
212
    pid_file = self.f_dpn('child')
213
    r_fd, w_fd = os.pipe()
214
    new_pid = os.fork()
215
    if new_pid == 0: #child
216
      utils.WritePidFile('child')
217
      os.write(w_fd, 'a')
218
      signal.pause()
219
      os._exit(0)
220
      return
221
    # else we are in the parent
222
    # wait until the child has written the pid file
223
    os.read(r_fd, 1)
224
    read_pid = utils.ReadPidFile(pid_file)
225
    self.failUnlessEqual(read_pid, new_pid)
226
    self.failUnless(utils.IsProcessAlive(new_pid))
227
    utils.KillProcess(new_pid, waitpid=True)
228
    self.failIf(utils.IsProcessAlive(new_pid))
229
    utils.RemovePidFile('child')
230
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)
231

    
232
  def tearDown(self):
233
    for name in os.listdir(self.dir):
234
      os.unlink(os.path.join(self.dir, name))
235
    os.rmdir(self.dir)
236

    
237

    
238
class TestRunCmd(testutils.GanetiTestCase):
239
  """Testing case for the RunCmd function"""
240

    
241
  def setUp(self):
242
    testutils.GanetiTestCase.setUp(self)
243
    self.magic = time.ctime() + " ganeti test"
244
    self.fname = self._CreateTempFile()
245

    
246
  def testOk(self):
247
    """Test successful exit code"""
248
    result = RunCmd("/bin/sh -c 'exit 0'")
249
    self.assertEqual(result.exit_code, 0)
250
    self.assertEqual(result.output, "")
251

    
252
  def testFail(self):
253
    """Test fail exit code"""
254
    result = RunCmd("/bin/sh -c 'exit 1'")
255
    self.assertEqual(result.exit_code, 1)
256
    self.assertEqual(result.output, "")
257

    
258
  def testStdout(self):
259
    """Test standard output"""
260
    cmd = 'echo -n "%s"' % self.magic
261
    result = RunCmd("/bin/sh -c '%s'" % cmd)
262
    self.assertEqual(result.stdout, self.magic)
263
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
264
    self.assertEqual(result.output, "")
265
    self.assertFileContent(self.fname, self.magic)
266

    
267
  def testStderr(self):
268
    """Test standard error"""
269
    cmd = 'echo -n "%s"' % self.magic
270
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
271
    self.assertEqual(result.stderr, self.magic)
272
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
273
    self.assertEqual(result.output, "")
274
    self.assertFileContent(self.fname, self.magic)
275

    
276
  def testCombined(self):
277
    """Test combined output"""
278
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
279
    expected = "A" + self.magic + "B" + self.magic
280
    result = RunCmd("/bin/sh -c '%s'" % cmd)
281
    self.assertEqual(result.output, expected)
282
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
283
    self.assertEqual(result.output, "")
284
    self.assertFileContent(self.fname, expected)
285

    
286
  def testSignal(self):
287
    """Test signal"""
288
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
289
    self.assertEqual(result.signal, 15)
290
    self.assertEqual(result.output, "")
291

    
292
  def testListRun(self):
293
    """Test list runs"""
294
    result = RunCmd(["true"])
295
    self.assertEqual(result.signal, None)
296
    self.assertEqual(result.exit_code, 0)
297
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
298
    self.assertEqual(result.signal, None)
299
    self.assertEqual(result.exit_code, 1)
300
    result = RunCmd(["echo", "-n", self.magic])
301
    self.assertEqual(result.signal, None)
302
    self.assertEqual(result.exit_code, 0)
303
    self.assertEqual(result.stdout, self.magic)
304

    
305
  def testFileEmptyOutput(self):
306
    """Test file output"""
307
    result = RunCmd(["true"], output=self.fname)
308
    self.assertEqual(result.signal, None)
309
    self.assertEqual(result.exit_code, 0)
310
    self.assertFileContent(self.fname, "")
311

    
312
  def testLang(self):
313
    """Test locale environment"""
314
    old_env = os.environ.copy()
315
    try:
316
      os.environ["LANG"] = "en_US.UTF-8"
317
      os.environ["LC_ALL"] = "en_US.UTF-8"
318
      result = RunCmd(["locale"])
319
      for line in result.output.splitlines():
320
        key, value = line.split("=", 1)
321
        # Ignore these variables, they're overridden by LC_ALL
322
        if key == "LANG" or key == "LANGUAGE":
323
          continue
324
        self.failIf(value and value != "C" and value != '"C"',
325
            "Variable %s is set to the invalid value '%s'" % (key, value))
326
    finally:
327
      os.environ = old_env
328

    
329
  def testDefaultCwd(self):
330
    """Test default working directory"""
331
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
332

    
333
  def testCwd(self):
334
    """Test default working directory"""
335
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
336
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
337
    cwd = os.getcwd()
338
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
339

    
340
  def testResetEnv(self):
341
    """Test environment reset functionality"""
342
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
343
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
344
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
345

    
346

    
347
class TestRunParts(unittest.TestCase):
348
  """Testing case for the RunParts function"""
349

    
350
  def setUp(self):
351
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
352

    
353
  def tearDown(self):
354
    shutil.rmtree(self.rundir)
355

    
356
  def testEmpty(self):
357
    """Test on an empty dir"""
358
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
359

    
360
  def testSkipWrongName(self):
361
    """Test that wrong files are skipped"""
362
    fname = os.path.join(self.rundir, "00test.dot")
363
    utils.WriteFile(fname, data="")
364
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
365
    relname = os.path.basename(fname)
366
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
367
                         [(relname, constants.RUNPARTS_SKIP, None)])
368

    
369
  def testSkipNonExec(self):
370
    """Test that non executable files are skipped"""
371
    fname = os.path.join(self.rundir, "00test")
372
    utils.WriteFile(fname, data="")
373
    relname = os.path.basename(fname)
374
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
375
                         [(relname, constants.RUNPARTS_SKIP, None)])
376

    
377
  def testError(self):
378
    """Test error on a broken executable"""
379
    fname = os.path.join(self.rundir, "00test")
380
    utils.WriteFile(fname, data="")
381
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
382
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
383
    self.failUnlessEqual(relname, os.path.basename(fname))
384
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
385
    self.failUnless(error)
386

    
387
  def testSorted(self):
388
    """Test executions are sorted"""
389
    files = []
390
    files.append(os.path.join(self.rundir, "64test"))
391
    files.append(os.path.join(self.rundir, "00test"))
392
    files.append(os.path.join(self.rundir, "42test"))
393

    
394
    for fname in files:
395
      utils.WriteFile(fname, data="")
396

    
397
    results = RunParts(self.rundir, reset_env=True)
398

    
399
    for fname in sorted(files):
400
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
401

    
402
  def testOk(self):
403
    """Test correct execution"""
404
    fname = os.path.join(self.rundir, "00test")
405
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
406
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
407
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
408
    self.failUnlessEqual(relname, os.path.basename(fname))
409
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
410
    self.failUnlessEqual(runresult.stdout, "ciao")
411

    
412
  def testRunFail(self):
413
    """Test correct execution, with run failure"""
414
    fname = os.path.join(self.rundir, "00test")
415
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
416
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
417
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
418
    self.failUnlessEqual(relname, os.path.basename(fname))
419
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
420
    self.failUnlessEqual(runresult.exit_code, 1)
421
    self.failUnless(runresult.failed)
422

    
423
  def testRunMix(self):
424
    files = []
425
    files.append(os.path.join(self.rundir, "00test"))
426
    files.append(os.path.join(self.rundir, "42test"))
427
    files.append(os.path.join(self.rundir, "64test"))
428
    files.append(os.path.join(self.rundir, "99test"))
429

    
430
    files.sort()
431

    
432
    # 1st has errors in execution
433
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
434
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
435

    
436
    # 2nd is skipped
437
    utils.WriteFile(files[1], data="")
438

    
439
    # 3rd cannot execute properly
440
    utils.WriteFile(files[2], data="")
441
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
442

    
443
    # 4th execs
444
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
445
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
446

    
447
    results = RunParts(self.rundir, reset_env=True)
448

    
449
    (relname, status, runresult) = results[0]
450
    self.failUnlessEqual(relname, os.path.basename(files[0]))
451
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
452
    self.failUnlessEqual(runresult.exit_code, 1)
453
    self.failUnless(runresult.failed)
454

    
455
    (relname, status, runresult) = results[1]
456
    self.failUnlessEqual(relname, os.path.basename(files[1]))
457
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
458
    self.failUnlessEqual(runresult, None)
459

    
460
    (relname, status, runresult) = results[2]
461
    self.failUnlessEqual(relname, os.path.basename(files[2]))
462
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
463
    self.failUnless(runresult)
464

    
465
    (relname, status, runresult) = results[3]
466
    self.failUnlessEqual(relname, os.path.basename(files[3]))
467
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
468
    self.failUnlessEqual(runresult.output, "ciao")
469
    self.failUnlessEqual(runresult.exit_code, 0)
470
    self.failUnless(not runresult.failed)
471

    
472

    
473
class TestStartDaemon(testutils.GanetiTestCase):
474
  def setUp(self):
475
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
476
    self.tmpfile = os.path.join(self.tmpdir, "test")
477

    
478
  def tearDown(self):
479
    shutil.rmtree(self.tmpdir)
480

    
481
  def testShell(self):
482
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
483
    self._wait(self.tmpfile, 60.0, "Hello World")
484

    
485
  def testShellOutput(self):
486
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
487
    self._wait(self.tmpfile, 60.0, "Hello World")
488

    
489
  def testNoShellNoOutput(self):
490
    utils.StartDaemon(["pwd"])
491

    
492
  def testNoShellNoOutputTouch(self):
493
    testfile = os.path.join(self.tmpdir, "check")
494
    self.failIf(os.path.exists(testfile))
495
    utils.StartDaemon(["touch", testfile])
496
    self._wait(testfile, 60.0, "")
497

    
498
  def testNoShellOutput(self):
499
    utils.StartDaemon(["pwd"], output=self.tmpfile)
500
    self._wait(self.tmpfile, 60.0, "/")
501

    
502
  def testNoShellOutputCwd(self):
503
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
504
    self._wait(self.tmpfile, 60.0, os.getcwd())
505

    
506
  def testShellEnv(self):
507
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
508
                      env={ "GNT_TEST_VAR": "Hello World", })
509
    self._wait(self.tmpfile, 60.0, "Hello World")
510

    
511
  def testNoShellEnv(self):
512
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
513
                      env={ "GNT_TEST_VAR": "Hello World", })
514
    self._wait(self.tmpfile, 60.0, "Hello World")
515

    
516
  def testOutputFd(self):
517
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
518
    try:
519
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
520
    finally:
521
      os.close(fd)
522
    self._wait(self.tmpfile, 60.0, os.getcwd())
523

    
524
  def testPid(self):
525
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
526
    self._wait(self.tmpfile, 60.0, str(pid))
527

    
528
  def testPidFile(self):
529
    pidfile = os.path.join(self.tmpdir, "pid")
530
    checkfile = os.path.join(self.tmpdir, "abort")
531

    
532
    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
533
                            output=self.tmpfile)
534
    try:
535
      fd = os.open(pidfile, os.O_RDONLY)
536
      try:
537
        # Check file is locked
538
        self.assertRaises(errors.LockError, utils.LockFile, fd)
539

    
540
        pidtext = os.read(fd, 100)
541
      finally:
542
        os.close(fd)
543

    
544
      self.assertEqual(int(pidtext.strip()), pid)
545

    
546
      self.assert_(utils.IsProcessAlive(pid))
547
    finally:
548
      # No matter what happens, kill daemon
549
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
550
      self.failIf(utils.IsProcessAlive(pid))
551

    
552
    self.assertEqual(utils.ReadFile(self.tmpfile), "")
553

    
554
  def _wait(self, path, timeout, expected):
555
    # Due to the asynchronous nature of daemon processes, polling is necessary.
556
    # A timeout makes sure the test doesn't hang forever.
557
    def _CheckFile():
558
      if not (os.path.isfile(path) and
559
              utils.ReadFile(path).strip() == expected):
560
        raise utils.RetryAgain()
561

    
562
    try:
563
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
564
    except utils.RetryTimeout:
565
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
566
                " didn't write the correct output" % timeout)
567

    
568
  def testError(self):
569
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
570
                      ["./does-NOT-EXIST/here/0123456789"])
571
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
572
                      ["./does-NOT-EXIST/here/0123456789"],
573
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
574
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
575
                      ["./does-NOT-EXIST/here/0123456789"],
576
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
577
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
578
                      ["./does-NOT-EXIST/here/0123456789"],
579
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
580

    
581
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
582
    try:
583
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
584
                        ["./does-NOT-EXIST/here/0123456789"],
585
                        output=self.tmpfile, output_fd=fd)
586
    finally:
587
      os.close(fd)
588

    
589

    
590
class TestSetCloseOnExecFlag(unittest.TestCase):
591
  """Tests for SetCloseOnExecFlag"""
592

    
593
  def setUp(self):
594
    self.tmpfile = tempfile.TemporaryFile()
595

    
596
  def testEnable(self):
597
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
598
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
599
                    fcntl.FD_CLOEXEC)
600

    
601
  def testDisable(self):
602
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
603
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
604
                fcntl.FD_CLOEXEC)
605

    
606

    
607
class TestSetNonblockFlag(unittest.TestCase):
608
  def setUp(self):
609
    self.tmpfile = tempfile.TemporaryFile()
610

    
611
  def testEnable(self):
612
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
613
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
614
                    os.O_NONBLOCK)
615

    
616
  def testDisable(self):
617
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
618
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
619
                os.O_NONBLOCK)
620

    
621

    
622
class TestRemoveFile(unittest.TestCase):
623
  """Test case for the RemoveFile function"""
624

    
625
  def setUp(self):
626
    """Create a temp dir and file for each case"""
627
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
628
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
629
    os.close(fd)
630

    
631
  def tearDown(self):
632
    if os.path.exists(self.tmpfile):
633
      os.unlink(self.tmpfile)
634
    os.rmdir(self.tmpdir)
635

    
636
  def testIgnoreDirs(self):
637
    """Test that RemoveFile() ignores directories"""
638
    self.assertEqual(None, RemoveFile(self.tmpdir))
639

    
640
  def testIgnoreNotExisting(self):
641
    """Test that RemoveFile() ignores non-existing files"""
642
    RemoveFile(self.tmpfile)
643
    RemoveFile(self.tmpfile)
644

    
645
  def testRemoveFile(self):
646
    """Test that RemoveFile does remove a file"""
647
    RemoveFile(self.tmpfile)
648
    if os.path.exists(self.tmpfile):
649
      self.fail("File '%s' not removed" % self.tmpfile)
650

    
651
  def testRemoveSymlink(self):
652
    """Test that RemoveFile does remove symlinks"""
653
    symlink = self.tmpdir + "/symlink"
654
    os.symlink("no-such-file", symlink)
655
    RemoveFile(symlink)
656
    if os.path.exists(symlink):
657
      self.fail("File '%s' not removed" % symlink)
658
    os.symlink(self.tmpfile, symlink)
659
    RemoveFile(symlink)
660
    if os.path.exists(symlink):
661
      self.fail("File '%s' not removed" % symlink)
662

    
663

    
664
class TestRename(unittest.TestCase):
665
  """Test case for RenameFile"""
666

    
667
  def setUp(self):
668
    """Create a temporary directory"""
669
    self.tmpdir = tempfile.mkdtemp()
670
    self.tmpfile = os.path.join(self.tmpdir, "test1")
671

    
672
    # Touch the file
673
    open(self.tmpfile, "w").close()
674

    
675
  def tearDown(self):
676
    """Remove temporary directory"""
677
    shutil.rmtree(self.tmpdir)
678

    
679
  def testSimpleRename1(self):
680
    """Simple rename 1"""
681
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
682
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
683

    
684
  def testSimpleRename2(self):
685
    """Simple rename 2"""
686
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
687
                     mkdir=True)
688
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
689

    
690
  def testRenameMkdir(self):
691
    """Rename with mkdir"""
692
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
693
                     mkdir=True)
694
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
695
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
696

    
697
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
698
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
699
                     mkdir=True)
700
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
701
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
702
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
703

    
704

    
705
class TestMatchNameComponent(unittest.TestCase):
706
  """Test case for the MatchNameComponent function"""
707

    
708
  def testEmptyList(self):
709
    """Test that there is no match against an empty list"""
710

    
711
    self.failUnlessEqual(MatchNameComponent("", []), None)
712
    self.failUnlessEqual(MatchNameComponent("test", []), None)
713

    
714
  def testSingleMatch(self):
715
    """Test that a single match is performed correctly"""
716
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
717
    for key in "test2", "test2.example", "test2.example.com":
718
      self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
719

    
720
  def testMultipleMatches(self):
721
    """Test that a multiple match is returned as None"""
722
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
723
    for key in "test1", "test1.example":
724
      self.failUnlessEqual(MatchNameComponent(key, mlist), None)
725

    
726
  def testFullMatch(self):
727
    """Test that a full match is returned correctly"""
728
    key1 = "test1"
729
    key2 = "test1.example"
730
    mlist = [key2, key2 + ".com"]
731
    self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
732
    self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
733

    
734
  def testCaseInsensitivePartialMatch(self):
735
    """Test for the case_insensitive keyword"""
736
    mlist = ["test1.example.com", "test2.example.net"]
737
    self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
738
                     "test2.example.net")
739
    self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
740
                     "test2.example.net")
741
    self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
742
                     "test2.example.net")
743
    self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
744
                     "test2.example.net")
745

    
746

    
747
  def testCaseInsensitiveFullMatch(self):
748
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
749
    # Between the two ts1 a full string match non-case insensitive should work
750
    self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
751
                     None)
752
    self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
753
                     "ts1.ex")
754
    self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
755
                     "ts1.ex")
756
    # Between the two ts2 only case differs, so only case-match works
757
    self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
758
                     "ts2.ex")
759
    self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
760
                     "Ts2.ex")
761
    self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
762
                     None)
763

    
764

    
765
class TestReadFile(testutils.GanetiTestCase):
766

    
767
  def testReadAll(self):
768
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
769
    self.assertEqual(len(data), 814)
770

    
771
    h = compat.md5_hash()
772
    h.update(data)
773
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
774

    
775
  def testReadSize(self):
776
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
777
                          size=100)
778
    self.assertEqual(len(data), 100)
779

    
780
    h = compat.md5_hash()
781
    h.update(data)
782
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
783

    
784
  def testError(self):
785
    self.assertRaises(EnvironmentError, utils.ReadFile,
786
                      "/dev/null/does-not-exist")
787

    
788

    
789
class TestReadOneLineFile(testutils.GanetiTestCase):
790

    
791
  def setUp(self):
792
    testutils.GanetiTestCase.setUp(self)
793

    
794
  def testDefault(self):
795
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
796
    self.assertEqual(len(data), 27)
797
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
798

    
799
  def testNotStrict(self):
800
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
801
    self.assertEqual(len(data), 27)
802
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
803

    
804
  def testStrictFailure(self):
805
    self.assertRaises(errors.GenericError, ReadOneLineFile,
806
                      self._TestDataFilename("cert1.pem"), strict=True)
807

    
808
  def testLongLine(self):
809
    dummydata = (1024 * "Hello World! ")
810
    myfile = self._CreateTempFile()
811
    utils.WriteFile(myfile, data=dummydata)
812
    datastrict = ReadOneLineFile(myfile, strict=True)
813
    datalax = ReadOneLineFile(myfile, strict=False)
814
    self.assertEqual(dummydata, datastrict)
815
    self.assertEqual(dummydata, datalax)
816

    
817
  def testNewline(self):
818
    myfile = self._CreateTempFile()
819
    myline = "myline"
820
    for nl in ["", "\n", "\r\n"]:
821
      dummydata = "%s%s" % (myline, nl)
822
      utils.WriteFile(myfile, data=dummydata)
823
      datalax = ReadOneLineFile(myfile, strict=False)
824
      self.assertEqual(myline, datalax)
825
      datastrict = ReadOneLineFile(myfile, strict=True)
826
      self.assertEqual(myline, datastrict)
827

    
828
  def testWhitespaceAndMultipleLines(self):
829
    myfile = self._CreateTempFile()
830
    for nl in ["", "\n", "\r\n"]:
831
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
832
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
833
        utils.WriteFile(myfile, data=dummydata)
834
        datalax = ReadOneLineFile(myfile, strict=False)
835
        if nl:
836
          self.assert_(set("\r\n") & set(dummydata))
837
          self.assertRaises(errors.GenericError, ReadOneLineFile,
838
                            myfile, strict=True)
839
          explen = len("Foo bar baz ") + len(ws)
840
          self.assertEqual(len(datalax), explen)
841
          self.assertEqual(datalax, dummydata[:explen])
842
          self.assertFalse(set("\r\n") & set(datalax))
843
        else:
844
          datastrict = ReadOneLineFile(myfile, strict=True)
845
          self.assertEqual(dummydata, datastrict)
846
          self.assertEqual(dummydata, datalax)
847

    
848
  def testEmptylines(self):
849
    myfile = self._CreateTempFile()
850
    myline = "myline"
851
    for nl in ["\n", "\r\n"]:
852
      for ol in ["", "otherline"]:
853
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
854
        utils.WriteFile(myfile, data=dummydata)
855
        self.assert_(set("\r\n") & set(dummydata))
856
        datalax = ReadOneLineFile(myfile, strict=False)
857
        self.assertEqual(myline, datalax)
858
        if ol:
859
          self.assertRaises(errors.GenericError, ReadOneLineFile,
860
                            myfile, strict=True)
861
        else:
862
          datastrict = ReadOneLineFile(myfile, strict=True)
863
          self.assertEqual(myline, datastrict)
864

    
865

    
866
class TestTimestampForFilename(unittest.TestCase):
867
  def test(self):
868
    self.assert_("." not in utils.TimestampForFilename())
869
    self.assert_(":" not in utils.TimestampForFilename())
870

    
871

    
872
class TestCreateBackup(testutils.GanetiTestCase):
873
  def setUp(self):
874
    testutils.GanetiTestCase.setUp(self)
875

    
876
    self.tmpdir = tempfile.mkdtemp()
877

    
878
  def tearDown(self):
879
    testutils.GanetiTestCase.tearDown(self)
880

    
881
    shutil.rmtree(self.tmpdir)
882

    
883
  def testEmpty(self):
884
    filename = utils.PathJoin(self.tmpdir, "config.data")
885
    utils.WriteFile(filename, data="")
886
    bname = utils.CreateBackup(filename)
887
    self.assertFileContent(bname, "")
888
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
889
    utils.CreateBackup(filename)
890
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
891
    utils.CreateBackup(filename)
892
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
893

    
894
    fifoname = utils.PathJoin(self.tmpdir, "fifo")
895
    os.mkfifo(fifoname)
896
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
897

    
898
  def testContent(self):
899
    bkpcount = 0
900
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
901
      for rep in [1, 2, 10, 127]:
902
        testdata = data * rep
903

    
904
        filename = utils.PathJoin(self.tmpdir, "test.data_")
905
        utils.WriteFile(filename, data=testdata)
906
        self.assertFileContent(filename, testdata)
907

    
908
        for _ in range(3):
909
          bname = utils.CreateBackup(filename)
910
          bkpcount += 1
911
          self.assertFileContent(bname, testdata)
912
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
913

    
914

    
915
class TestFormatUnit(unittest.TestCase):
916
  """Test case for the FormatUnit function"""
917

    
918
  def testMiB(self):
919
    self.assertEqual(FormatUnit(1, 'h'), '1M')
920
    self.assertEqual(FormatUnit(100, 'h'), '100M')
921
    self.assertEqual(FormatUnit(1023, 'h'), '1023M')
922

    
923
    self.assertEqual(FormatUnit(1, 'm'), '1')
924
    self.assertEqual(FormatUnit(100, 'm'), '100')
925
    self.assertEqual(FormatUnit(1023, 'm'), '1023')
926

    
927
    self.assertEqual(FormatUnit(1024, 'm'), '1024')
928
    self.assertEqual(FormatUnit(1536, 'm'), '1536')
929
    self.assertEqual(FormatUnit(17133, 'm'), '17133')
930
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
931

    
932
  def testGiB(self):
933
    self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
934
    self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
935
    self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
936
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
937

    
938
    self.assertEqual(FormatUnit(1024, 'g'), '1.0')
939
    self.assertEqual(FormatUnit(1536, 'g'), '1.5')
940
    self.assertEqual(FormatUnit(17133, 'g'), '16.7')
941
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
942

    
943
    self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
944
    self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
945
    self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
946

    
947
  def testTiB(self):
948
    self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
949
    self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
950
    self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
951

    
952
    self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
953
    self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
954
    self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
955

    
956
class TestParseUnit(unittest.TestCase):
957
  """Test case for the ParseUnit function"""
958

    
959
  SCALES = (('', 1),
960
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
961
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
962
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
963

    
964
  def testRounding(self):
965
    self.assertEqual(ParseUnit('0'), 0)
966
    self.assertEqual(ParseUnit('1'), 4)
967
    self.assertEqual(ParseUnit('2'), 4)
968
    self.assertEqual(ParseUnit('3'), 4)
969

    
970
    self.assertEqual(ParseUnit('124'), 124)
971
    self.assertEqual(ParseUnit('125'), 128)
972
    self.assertEqual(ParseUnit('126'), 128)
973
    self.assertEqual(ParseUnit('127'), 128)
974
    self.assertEqual(ParseUnit('128'), 128)
975
    self.assertEqual(ParseUnit('129'), 132)
976
    self.assertEqual(ParseUnit('130'), 132)
977

    
978
  def testFloating(self):
979
    self.assertEqual(ParseUnit('0'), 0)
980
    self.assertEqual(ParseUnit('0.5'), 4)
981
    self.assertEqual(ParseUnit('1.75'), 4)
982
    self.assertEqual(ParseUnit('1.99'), 4)
983
    self.assertEqual(ParseUnit('2.00'), 4)
984
    self.assertEqual(ParseUnit('2.01'), 4)
985
    self.assertEqual(ParseUnit('3.99'), 4)
986
    self.assertEqual(ParseUnit('4.00'), 4)
987
    self.assertEqual(ParseUnit('4.01'), 8)
988
    self.assertEqual(ParseUnit('1.5G'), 1536)
989
    self.assertEqual(ParseUnit('1.8G'), 1844)
990
    self.assertEqual(ParseUnit('8.28T'), 8682212)
991

    
992
  def testSuffixes(self):
993
    for sep in ('', ' ', '   ', "\t", "\t "):
994
      for suffix, scale in TestParseUnit.SCALES:
995
        for func in (lambda x: x, str.lower, str.upper):
996
          self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
997
                           1024 * scale)
998

    
999
  def testInvalidInput(self):
1000
    for sep in ('-', '_', ',', 'a'):
1001
      for suffix, _ in TestParseUnit.SCALES:
1002
        self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)
1003

    
1004
    for suffix, _ in TestParseUnit.SCALES:
1005
      self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
1006

    
1007

    
1008
class TestSshKeys(testutils.GanetiTestCase):
1009
  """Test case for the AddAuthorizedKey function"""
1010

    
1011
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
1012
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
1013
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
1014

    
1015
  def setUp(self):
1016
    testutils.GanetiTestCase.setUp(self)
1017
    self.tmpname = self._CreateTempFile()
1018
    handle = open(self.tmpname, 'w')
1019
    try:
1020
      handle.write("%s\n" % TestSshKeys.KEY_A)
1021
      handle.write("%s\n" % TestSshKeys.KEY_B)
1022
    finally:
1023
      handle.close()
1024

    
1025
  def testAddingNewKey(self):
1026
    AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
1027

    
1028
    self.assertFileContent(self.tmpname,
1029
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1030
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1031
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1032
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
1033

    
1034
  def testAddingAlmostButNotCompletelyTheSameKey(self):
1035
    AddAuthorizedKey(self.tmpname,
1036
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
1037

    
1038
    self.assertFileContent(self.tmpname,
1039
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1040
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1041
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1042
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
1043

    
1044
  def testAddingExistingKeyWithSomeMoreSpaces(self):
1045
    AddAuthorizedKey(self.tmpname,
1046
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1047

    
1048
    self.assertFileContent(self.tmpname,
1049
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1050
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1051
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1052

    
1053
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
1054
    RemoveAuthorizedKey(self.tmpname,
1055
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1056

    
1057
    self.assertFileContent(self.tmpname,
1058
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1059
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1060

    
1061
  def testRemovingNonExistingKey(self):
1062
    RemoveAuthorizedKey(self.tmpname,
1063
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
1064

    
1065
    self.assertFileContent(self.tmpname,
1066
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1067
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1068
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1069

    
1070

    
1071
class TestEtcHosts(testutils.GanetiTestCase):
1072
  """Test functions modifying /etc/hosts"""
1073

    
1074
  def setUp(self):
1075
    testutils.GanetiTestCase.setUp(self)
1076
    self.tmpname = self._CreateTempFile()
1077
    handle = open(self.tmpname, 'w')
1078
    try:
1079
      handle.write('# This is a test file for /etc/hosts\n')
1080
      handle.write('127.0.0.1\tlocalhost\n')
1081
      handle.write('192.168.1.1 router gw\n')
1082
    finally:
1083
      handle.close()
1084

    
1085
  def testSettingNewIp(self):
1086
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
1087

    
1088
    self.assertFileContent(self.tmpname,
1089
      "# This is a test file for /etc/hosts\n"
1090
      "127.0.0.1\tlocalhost\n"
1091
      "192.168.1.1 router gw\n"
1092
      "1.2.3.4\tmyhost.domain.tld myhost\n")
1093
    self.assertFileMode(self.tmpname, 0644)
1094

    
1095
  def testSettingExistingIp(self):
1096
    SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
1097
                     ['myhost'])
1098

    
1099
    self.assertFileContent(self.tmpname,
1100
      "# This is a test file for /etc/hosts\n"
1101
      "127.0.0.1\tlocalhost\n"
1102
      "192.168.1.1\tmyhost.domain.tld myhost\n")
1103
    self.assertFileMode(self.tmpname, 0644)
1104

    
1105
  def testSettingDuplicateName(self):
1106
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
1107

    
1108
    self.assertFileContent(self.tmpname,
1109
      "# This is a test file for /etc/hosts\n"
1110
      "127.0.0.1\tlocalhost\n"
1111
      "192.168.1.1 router gw\n"
1112
      "1.2.3.4\tmyhost\n")
1113
    self.assertFileMode(self.tmpname, 0644)
1114

    
1115
  def testRemovingExistingHost(self):
1116
    RemoveEtcHostsEntry(self.tmpname, 'router')
1117

    
1118
    self.assertFileContent(self.tmpname,
1119
      "# This is a test file for /etc/hosts\n"
1120
      "127.0.0.1\tlocalhost\n"
1121
      "192.168.1.1 gw\n")
1122
    self.assertFileMode(self.tmpname, 0644)
1123

    
1124
  def testRemovingSingleExistingHost(self):
1125
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
1126

    
1127
    self.assertFileContent(self.tmpname,
1128
      "# This is a test file for /etc/hosts\n"
1129
      "192.168.1.1 router gw\n")
1130
    self.assertFileMode(self.tmpname, 0644)
1131

    
1132
  def testRemovingNonExistingHost(self):
1133
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
1134

    
1135
    self.assertFileContent(self.tmpname,
1136
      "# This is a test file for /etc/hosts\n"
1137
      "127.0.0.1\tlocalhost\n"
1138
      "192.168.1.1 router gw\n")
1139
    self.assertFileMode(self.tmpname, 0644)
1140

    
1141
  def testRemovingAlias(self):
1142
    RemoveEtcHostsEntry(self.tmpname, 'gw')
1143

    
1144
    self.assertFileContent(self.tmpname,
1145
      "# This is a test file for /etc/hosts\n"
1146
      "127.0.0.1\tlocalhost\n"
1147
      "192.168.1.1 router\n")
1148
    self.assertFileMode(self.tmpname, 0644)
1149

    
1150

    
1151
class TestShellQuoting(unittest.TestCase):
1152
  """Test case for shell quoting functions"""
1153

    
1154
  def testShellQuote(self):
1155
    self.assertEqual(ShellQuote('abc'), "abc")
1156
    self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
1157
    self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
1158
    self.assertEqual(ShellQuote("a b c"), "'a b c'")
1159
    self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
1160

    
1161
  def testShellQuoteArgs(self):
1162
    self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
1163
    self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
1164
    self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
1165

    
1166

    
1167
class TestTcpPing(unittest.TestCase):
1168
  """Testcase for TCP version of ping - against listen(2)ing port"""
1169

    
1170
  def setUp(self):
1171
    self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1172
    self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
1173
    self.listenerport = self.listener.getsockname()[1]
1174
    self.listener.listen(1)
1175

    
1176
  def tearDown(self):
1177
    self.listener.shutdown(socket.SHUT_RDWR)
1178
    del self.listener
1179
    del self.listenerport
1180

    
1181
  def testTcpPingToLocalHostAccept(self):
1182
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1183
                         self.listenerport,
1184
                         timeout=10,
1185
                         live_port_needed=True,
1186
                         source=constants.LOCALHOST_IP_ADDRESS,
1187
                         ),
1188
                 "failed to connect to test listener")
1189

    
1190
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1191
                         self.listenerport,
1192
                         timeout=10,
1193
                         live_port_needed=True,
1194
                         ),
1195
                 "failed to connect to test listener (no source)")
1196

    
1197

    
1198
class TestTcpPingDeaf(unittest.TestCase):
1199
  """Testcase for TCP version of ping - against non listen(2)ing port"""
1200

    
1201
  def setUp(self):
1202
    self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1203
    self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
1204
    self.deaflistenerport = self.deaflistener.getsockname()[1]
1205

    
1206
  def tearDown(self):
1207
    del self.deaflistener
1208
    del self.deaflistenerport
1209

    
1210
  def testTcpPingToLocalHostAcceptDeaf(self):
1211
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1212
                        self.deaflistenerport,
1213
                        timeout=constants.TCP_PING_TIMEOUT,
1214
                        live_port_needed=True,
1215
                        source=constants.LOCALHOST_IP_ADDRESS,
1216
                        ), # need successful connect(2)
1217
                "successfully connected to deaf listener")
1218

    
1219
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1220
                        self.deaflistenerport,
1221
                        timeout=constants.TCP_PING_TIMEOUT,
1222
                        live_port_needed=True,
1223
                        ), # need successful connect(2)
1224
                "successfully connected to deaf listener (no source addr)")
1225

    
1226
  def testTcpPingToLocalHostNoAccept(self):
1227
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1228
                         self.deaflistenerport,
1229
                         timeout=constants.TCP_PING_TIMEOUT,
1230
                         live_port_needed=False,
1231
                         source=constants.LOCALHOST_IP_ADDRESS,
1232
                         ), # ECONNREFUSED is OK
1233
                 "failed to ping alive host on deaf port")
1234

    
1235
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1236
                         self.deaflistenerport,
1237
                         timeout=constants.TCP_PING_TIMEOUT,
1238
                         live_port_needed=False,
1239
                         ), # ECONNREFUSED is OK
1240
                 "failed to ping alive host on deaf port (no source addr)")
1241

    
1242

    
1243
class TestOwnIpAddress(unittest.TestCase):
1244
  """Testcase for OwnIpAddress"""
1245

    
1246
  def testOwnLoopback(self):
1247
    """check having the loopback ip"""
1248
    self.failUnless(OwnIpAddress(constants.LOCALHOST_IP_ADDRESS),
1249
                    "Should own the loopback address")
1250

    
1251
  def testNowOwnAddress(self):
1252
    """check that I don't own an address"""
1253

    
1254
    # Network 192.0.2.0/24 is reserved for test/documentation as per
1255
    # RFC 5735, so we *should* not have an address of this range... if
1256
    # this fails, we should extend the test to multiple addresses
1257
    DST_IP = "192.0.2.1"
1258
    self.failIf(OwnIpAddress(DST_IP), "Should not own IP address %s" % DST_IP)
1259

    
1260

    
1261
def _GetSocketCredentials(path):
1262
  """Connect to a Unix socket and return remote credentials.
1263

1264
  """
1265
  sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1266
  try:
1267
    sock.settimeout(10)
1268
    sock.connect(path)
1269
    return utils.GetSocketCredentials(sock)
1270
  finally:
1271
    sock.close()
1272

    
1273

    
1274
class TestGetSocketCredentials(unittest.TestCase):
1275
  def setUp(self):
1276
    self.tmpdir = tempfile.mkdtemp()
1277
    self.sockpath = utils.PathJoin(self.tmpdir, "sock")
1278

    
1279
    self.listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1280
    self.listener.settimeout(10)
1281
    self.listener.bind(self.sockpath)
1282
    self.listener.listen(1)
1283

    
1284
  def tearDown(self):
1285
    self.listener.shutdown(socket.SHUT_RDWR)
1286
    self.listener.close()
1287
    shutil.rmtree(self.tmpdir)
1288

    
1289
  def test(self):
1290
    (c2pr, c2pw) = os.pipe()
1291

    
1292
    # Start child process
1293
    child = os.fork()
1294
    if child == 0:
1295
      try:
1296
        data = serializer.DumpJson(_GetSocketCredentials(self.sockpath))
1297

    
1298
        os.write(c2pw, data)
1299
        os.close(c2pw)
1300

    
1301
        os._exit(0)
1302
      finally:
1303
        os._exit(1)
1304

    
1305
    os.close(c2pw)
1306

    
1307
    # Wait for one connection
1308
    (conn, _) = self.listener.accept()
1309
    conn.recv(1)
1310
    conn.close()
1311

    
1312
    # Wait for result
1313
    result = os.read(c2pr, 4096)
1314
    os.close(c2pr)
1315

    
1316
    # Check child's exit code
1317
    (_, status) = os.waitpid(child, 0)
1318
    self.assertFalse(os.WIFSIGNALED(status))
1319
    self.assertEqual(os.WEXITSTATUS(status), 0)
1320

    
1321
    # Check result
1322
    (pid, uid, gid) = serializer.LoadJson(result)
1323
    self.assertEqual(pid, os.getpid())
1324
    self.assertEqual(uid, os.getuid())
1325
    self.assertEqual(gid, os.getgid())
1326

    
1327

    
1328
class TestListVisibleFiles(unittest.TestCase):
1329
  """Test case for ListVisibleFiles"""
1330

    
1331
  def setUp(self):
1332
    self.path = tempfile.mkdtemp()
1333

    
1334
  def tearDown(self):
1335
    shutil.rmtree(self.path)
1336

    
1337
  def _test(self, files, expected):
1338
    # Sort a copy
1339
    expected = expected[:]
1340
    expected.sort()
1341

    
1342
    for name in files:
1343
      f = open(os.path.join(self.path, name), 'w')
1344
      try:
1345
        f.write("Test\n")
1346
      finally:
1347
        f.close()
1348

    
1349
    found = ListVisibleFiles(self.path)
1350
    found.sort()
1351

    
1352
    self.assertEqual(found, expected)
1353

    
1354
  def testAllVisible(self):
1355
    files = ["a", "b", "c"]
1356
    expected = files
1357
    self._test(files, expected)
1358

    
1359
  def testNoneVisible(self):
1360
    files = [".a", ".b", ".c"]
1361
    expected = []
1362
    self._test(files, expected)
1363

    
1364
  def testSomeVisible(self):
1365
    files = ["a", "b", ".c"]
1366
    expected = ["a", "b"]
1367
    self._test(files, expected)
1368

    
1369
  def testNonAbsolutePath(self):
1370
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1371

    
1372
  def testNonNormalizedPath(self):
1373
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1374
                          "/bin/../tmp")
1375

    
1376

    
1377
class TestNewUUID(unittest.TestCase):
1378
  """Test case for NewUUID"""
1379

    
1380
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
1381
                        '[a-f0-9]{4}-[a-f0-9]{12}$')
1382

    
1383
  def runTest(self):
1384
    self.failUnless(self._re_uuid.match(utils.NewUUID()))
1385

    
1386

    
1387
class TestUniqueSequence(unittest.TestCase):
1388
  """Test case for UniqueSequence"""
1389

    
1390
  def _test(self, input, expected):
1391
    self.assertEqual(utils.UniqueSequence(input), expected)
1392

    
1393
  def runTest(self):
1394
    # Ordered input
1395
    self._test([1, 2, 3], [1, 2, 3])
1396
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
1397
    self._test([1, 2, 2, 3], [1, 2, 3])
1398
    self._test([1, 2, 3, 3], [1, 2, 3])
1399

    
1400
    # Unordered input
1401
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
1402
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
1403

    
1404
    # Strings
1405
    self._test(["a", "a"], ["a"])
1406
    self._test(["a", "b"], ["a", "b"])
1407
    self._test(["a", "b", "a"], ["a", "b"])
1408

    
1409

    
1410
class TestFirstFree(unittest.TestCase):
1411
  """Test case for the FirstFree function"""
1412

    
1413
  def test(self):
1414
    """Test FirstFree"""
1415
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1416
    self.failUnlessEqual(FirstFree([]), None)
1417
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1418
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1419
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1420

    
1421

    
1422
class TestTailFile(testutils.GanetiTestCase):
1423
  """Test case for the TailFile function"""
1424

    
1425
  def testEmpty(self):
1426
    fname = self._CreateTempFile()
1427
    self.failUnlessEqual(TailFile(fname), [])
1428
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1429

    
1430
  def testAllLines(self):
1431
    data = ["test %d" % i for i in range(30)]
1432
    for i in range(30):
1433
      fname = self._CreateTempFile()
1434
      fd = open(fname, "w")
1435
      fd.write("\n".join(data[:i]))
1436
      if i > 0:
1437
        fd.write("\n")
1438
      fd.close()
1439
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1440

    
1441
  def testPartialLines(self):
1442
    data = ["test %d" % i for i in range(30)]
1443
    fname = self._CreateTempFile()
1444
    fd = open(fname, "w")
1445
    fd.write("\n".join(data))
1446
    fd.write("\n")
1447
    fd.close()
1448
    for i in range(1, 30):
1449
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1450

    
1451
  def testBigFile(self):
1452
    data = ["test %d" % i for i in range(30)]
1453
    fname = self._CreateTempFile()
1454
    fd = open(fname, "w")
1455
    fd.write("X" * 1048576)
1456
    fd.write("\n")
1457
    fd.write("\n".join(data))
1458
    fd.write("\n")
1459
    fd.close()
1460
    for i in range(1, 30):
1461
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1462

    
1463

    
1464
class _BaseFileLockTest:
1465
  """Test case for the FileLock class"""
1466

    
1467
  def testSharedNonblocking(self):
1468
    self.lock.Shared(blocking=False)
1469
    self.lock.Close()
1470

    
1471
  def testExclusiveNonblocking(self):
1472
    self.lock.Exclusive(blocking=False)
1473
    self.lock.Close()
1474

    
1475
  def testUnlockNonblocking(self):
1476
    self.lock.Unlock(blocking=False)
1477
    self.lock.Close()
1478

    
1479
  def testSharedBlocking(self):
1480
    self.lock.Shared(blocking=True)
1481
    self.lock.Close()
1482

    
1483
  def testExclusiveBlocking(self):
1484
    self.lock.Exclusive(blocking=True)
1485
    self.lock.Close()
1486

    
1487
  def testUnlockBlocking(self):
1488
    self.lock.Unlock(blocking=True)
1489
    self.lock.Close()
1490

    
1491
  def testSharedExclusiveUnlock(self):
1492
    self.lock.Shared(blocking=False)
1493
    self.lock.Exclusive(blocking=False)
1494
    self.lock.Unlock(blocking=False)
1495
    self.lock.Close()
1496

    
1497
  def testExclusiveSharedUnlock(self):
1498
    self.lock.Exclusive(blocking=False)
1499
    self.lock.Shared(blocking=False)
1500
    self.lock.Unlock(blocking=False)
1501
    self.lock.Close()
1502

    
1503
  def testSimpleTimeout(self):
1504
    # These will succeed on the first attempt, hence a short timeout
1505
    self.lock.Shared(blocking=True, timeout=10.0)
1506
    self.lock.Exclusive(blocking=False, timeout=10.0)
1507
    self.lock.Unlock(blocking=True, timeout=10.0)
1508
    self.lock.Close()
1509

    
1510
  @staticmethod
1511
  def _TryLockInner(filename, shared, blocking):
1512
    lock = utils.FileLock.Open(filename)
1513

    
1514
    if shared:
1515
      fn = lock.Shared
1516
    else:
1517
      fn = lock.Exclusive
1518

    
1519
    try:
1520
      # The timeout doesn't really matter as the parent process waits for us to
1521
      # finish anyway.
1522
      fn(blocking=blocking, timeout=0.01)
1523
    except errors.LockError, err:
1524
      return False
1525

    
1526
    return True
1527

    
1528
  def _TryLock(self, *args):
1529
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1530
                                      *args)
1531

    
1532
  def testTimeout(self):
1533
    for blocking in [True, False]:
1534
      self.lock.Exclusive(blocking=True)
1535
      self.failIf(self._TryLock(False, blocking))
1536
      self.failIf(self._TryLock(True, blocking))
1537

    
1538
      self.lock.Shared(blocking=True)
1539
      self.assert_(self._TryLock(True, blocking))
1540
      self.failIf(self._TryLock(False, blocking))
1541

    
1542
  def testCloseShared(self):
1543
    self.lock.Close()
1544
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1545

    
1546
  def testCloseExclusive(self):
1547
    self.lock.Close()
1548
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1549

    
1550
  def testCloseUnlock(self):
1551
    self.lock.Close()
1552
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1553

    
1554

    
1555
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1556
  TESTDATA = "Hello World\n" * 10
1557

    
1558
  def setUp(self):
1559
    testutils.GanetiTestCase.setUp(self)
1560

    
1561
    self.tmpfile = tempfile.NamedTemporaryFile()
1562
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1563
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1564

    
1565
    # Ensure "Open" didn't truncate file
1566
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1567

    
1568
  def tearDown(self):
1569
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1570

    
1571
    testutils.GanetiTestCase.tearDown(self)
1572

    
1573

    
1574
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1575
  def setUp(self):
1576
    self.tmpfile = tempfile.NamedTemporaryFile()
1577
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1578

    
1579

    
1580
class TestTimeFunctions(unittest.TestCase):
1581
  """Test case for time functions"""
1582

    
1583
  def runTest(self):
1584
    self.assertEqual(utils.SplitTime(1), (1, 0))
1585
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1586
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1587
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1588
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1589
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1590
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1591
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1592

    
1593
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1594

    
1595
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1596
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1597
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1598

    
1599
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1600
                     1218448917.481)
1601
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1602

    
1603
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1604
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1605
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1606
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1607
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1608

    
1609

    
1610
class FieldSetTestCase(unittest.TestCase):
1611
  """Test case for FieldSets"""
1612

    
1613
  def testSimpleMatch(self):
1614
    f = utils.FieldSet("a", "b", "c", "def")
1615
    self.failUnless(f.Matches("a"))
1616
    self.failIf(f.Matches("d"), "Substring matched")
1617
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1618
    self.failIf(f.NonMatching(["b", "c"]))
1619
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1620
    self.failUnless(f.NonMatching(["a", "d"]))
1621

    
1622
  def testRegexMatch(self):
1623
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1624
    self.failUnless(f.Matches("b1"))
1625
    self.failUnless(f.Matches("b99"))
1626
    self.failIf(f.Matches("b/1"))
1627
    self.failIf(f.NonMatching(["b12", "c"]))
1628
    self.failUnless(f.NonMatching(["a", "1"]))
1629

    
1630
class TestForceDictType(unittest.TestCase):
1631
  """Test case for ForceDictType"""
1632

    
1633
  def setUp(self):
1634
    self.key_types = {
1635
      'a': constants.VTYPE_INT,
1636
      'b': constants.VTYPE_BOOL,
1637
      'c': constants.VTYPE_STRING,
1638
      'd': constants.VTYPE_SIZE,
1639
      }
1640

    
1641
  def _fdt(self, dict, allowed_values=None):
1642
    if allowed_values is None:
1643
      ForceDictType(dict, self.key_types)
1644
    else:
1645
      ForceDictType(dict, self.key_types, allowed_values=allowed_values)
1646

    
1647
    return dict
1648

    
1649
  def testSimpleDict(self):
1650
    self.assertEqual(self._fdt({}), {})
1651
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1652
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1653
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1654
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1655
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1656
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1657
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1658
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1659
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1660
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1661
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1662

    
1663
  def testErrors(self):
1664
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1665
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1666
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1667
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1668

    
1669

    
1670
class TestIsNormAbsPath(unittest.TestCase):
1671
  """Testing case for IsNormAbsPath"""
1672

    
1673
  def _pathTestHelper(self, path, result):
1674
    if result:
1675
      self.assert_(IsNormAbsPath(path),
1676
          "Path %s should result absolute and normalized" % path)
1677
    else:
1678
      self.assert_(not IsNormAbsPath(path),
1679
          "Path %s should not result absolute and normalized" % path)
1680

    
1681
  def testBase(self):
1682
    self._pathTestHelper('/etc', True)
1683
    self._pathTestHelper('/srv', True)
1684
    self._pathTestHelper('etc', False)
1685
    self._pathTestHelper('/etc/../root', False)
1686
    self._pathTestHelper('/etc/', False)
1687

    
1688

    
1689
class TestSafeEncode(unittest.TestCase):
1690
  """Test case for SafeEncode"""
1691

    
1692
  def testAscii(self):
1693
    for txt in [string.digits, string.letters, string.punctuation]:
1694
      self.failUnlessEqual(txt, SafeEncode(txt))
1695

    
1696
  def testDoubleEncode(self):
1697
    for i in range(255):
1698
      txt = SafeEncode(chr(i))
1699
      self.failUnlessEqual(txt, SafeEncode(txt))
1700

    
1701
  def testUnicode(self):
1702
    # 1024 is high enough to catch non-direct ASCII mappings
1703
    for i in range(1024):
1704
      txt = SafeEncode(unichr(i))
1705
      self.failUnlessEqual(txt, SafeEncode(txt))
1706

    
1707

    
1708
class TestFormatTime(unittest.TestCase):
1709
  """Testing case for FormatTime"""
1710

    
1711
  def testNone(self):
1712
    self.failUnlessEqual(FormatTime(None), "N/A")
1713

    
1714
  def testInvalid(self):
1715
    self.failUnlessEqual(FormatTime(()), "N/A")
1716

    
1717
  def testNow(self):
1718
    # tests that we accept time.time input
1719
    FormatTime(time.time())
1720
    # tests that we accept int input
1721
    FormatTime(int(time.time()))
1722

    
1723

    
1724
class RunInSeparateProcess(unittest.TestCase):
1725
  def test(self):
1726
    for exp in [True, False]:
1727
      def _child():
1728
        return exp
1729

    
1730
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1731

    
1732
  def testArgs(self):
1733
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1734
      def _child(carg1, carg2):
1735
        return carg1 == "Foo" and carg2 == arg
1736

    
1737
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1738

    
1739
  def testPid(self):
1740
    parent_pid = os.getpid()
1741

    
1742
    def _check():
1743
      return os.getpid() == parent_pid
1744

    
1745
    self.failIf(utils.RunInSeparateProcess(_check))
1746

    
1747
  def testSignal(self):
1748
    def _kill():
1749
      os.kill(os.getpid(), signal.SIGTERM)
1750

    
1751
    self.assertRaises(errors.GenericError,
1752
                      utils.RunInSeparateProcess, _kill)
1753

    
1754
  def testException(self):
1755
    def _exc():
1756
      raise errors.GenericError("This is a test")
1757

    
1758
    self.assertRaises(errors.GenericError,
1759
                      utils.RunInSeparateProcess, _exc)
1760

    
1761

    
1762
class TestFingerprintFile(unittest.TestCase):
1763
  def setUp(self):
1764
    self.tmpfile = tempfile.NamedTemporaryFile()
1765

    
1766
  def test(self):
1767
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1768
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")
1769

    
1770
    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
1771
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1772
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1773

    
1774

    
1775
class TestUnescapeAndSplit(unittest.TestCase):
1776
  """Testing case for UnescapeAndSplit"""
1777

    
1778
  def setUp(self):
1779
    # testing more that one separator for regexp safety
1780
    self._seps = [",", "+", "."]
1781

    
1782
  def testSimple(self):
1783
    a = ["a", "b", "c", "d"]
1784
    for sep in self._seps:
1785
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1786

    
1787
  def testEscape(self):
1788
    for sep in self._seps:
1789
      a = ["a", "b\\" + sep + "c", "d"]
1790
      b = ["a", "b" + sep + "c", "d"]
1791
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1792

    
1793
  def testDoubleEscape(self):
1794
    for sep in self._seps:
1795
      a = ["a", "b\\\\", "c", "d"]
1796
      b = ["a", "b\\", "c", "d"]
1797
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1798

    
1799
  def testThreeEscape(self):
1800
    for sep in self._seps:
1801
      a = ["a", "b\\\\\\" + sep + "c", "d"]
1802
      b = ["a", "b\\" + sep + "c", "d"]
1803
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1804

    
1805

    
1806
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1807
  def setUp(self):
1808
    self.tmpdir = tempfile.mkdtemp()
1809

    
1810
  def tearDown(self):
1811
    shutil.rmtree(self.tmpdir)
1812

    
1813
  def _checkRsaPrivateKey(self, key):
1814
    lines = key.splitlines()
1815
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1816
            "-----END RSA PRIVATE KEY-----" in lines)
1817

    
1818
  def _checkCertificate(self, cert):
1819
    lines = cert.splitlines()
1820
    return ("-----BEGIN CERTIFICATE-----" in lines and
1821
            "-----END CERTIFICATE-----" in lines)
1822

    
1823
  def test(self):
1824
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1825
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1826
      self._checkRsaPrivateKey(key_pem)
1827
      self._checkCertificate(cert_pem)
1828

    
1829
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1830
                                           key_pem)
1831
      self.assert_(key.bits() >= 1024)
1832
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1833
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1834

    
1835
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1836
                                             cert_pem)
1837
      self.failIf(x509.has_expired())
1838
      self.assertEqual(x509.get_issuer().CN, common_name)
1839
      self.assertEqual(x509.get_subject().CN, common_name)
1840
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1841

    
1842
  def testLegacy(self):
1843
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1844

    
1845
    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1846

    
1847
    cert1 = utils.ReadFile(cert1_filename)
1848

    
1849
    self.assert_(self._checkRsaPrivateKey(cert1))
1850
    self.assert_(self._checkCertificate(cert1))
1851

    
1852

    
1853
class TestPathJoin(unittest.TestCase):
1854
  """Testing case for PathJoin"""
1855

    
1856
  def testBasicItems(self):
1857
    mlist = ["/a", "b", "c"]
1858
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1859

    
1860
  def testNonAbsPrefix(self):
1861
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1862

    
1863
  def testBackTrack(self):
1864
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1865

    
1866
  def testMultiAbs(self):
1867
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1868

    
1869

    
1870
class TestHostInfo(unittest.TestCase):
1871
  """Testing case for HostInfo"""
1872

    
1873
  def testUppercase(self):
1874
    data = "AbC.example.com"
1875
    self.failUnlessEqual(HostInfo.NormalizeName(data), data.lower())
1876

    
1877
  def testTooLongName(self):
1878
    data = "a.b." + "c" * 255
1879
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, data)
1880

    
1881
  def testTrailingDot(self):
1882
    data = "a.b.c"
1883
    self.failUnlessEqual(HostInfo.NormalizeName(data + "."), data)
1884

    
1885
  def testInvalidName(self):
1886
    data = [
1887
      "a b",
1888
      "a/b",
1889
      ".a.b",
1890
      "a..b",
1891
      ]
1892
    for value in data:
1893
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, value)
1894

    
1895
  def testValidName(self):
1896
    data = [
1897
      "a.b",
1898
      "a-b",
1899
      "a_b",
1900
      "a.b.c",
1901
      ]
1902
    for value in data:
1903
      HostInfo.NormalizeName(value)
1904

    
1905

    
1906
class TestParseAsn1Generalizedtime(unittest.TestCase):
1907
  def test(self):
1908
    # UTC
1909
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1910
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1911
                     1266860512)
1912
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1913
                     (2**31) - 1)
1914

    
1915
    # With offset
1916
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1917
                     1266860512)
1918
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1919
                     1266931012)
1920
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1921
                     1266931088)
1922
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1923
                     1266931295)
1924
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1925
                     3600)
1926

    
1927
    # Leap seconds are not supported by datetime.datetime
1928
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1929
                      "19841231235960+0000")
1930
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1931
                      "19920630235960+0000")
1932

    
1933
    # Errors
1934
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1935
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1936
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1937
                      "20100222174152")
1938
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1939
                      "Mon Feb 22 17:47:02 UTC 2010")
1940
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1941
                      "2010-02-22 17:42:02")
1942

    
1943

    
1944
class TestGetX509CertValidity(testutils.GanetiTestCase):
1945
  def setUp(self):
1946
    testutils.GanetiTestCase.setUp(self)
1947

    
1948
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1949

    
1950
    # Test whether we have pyOpenSSL 0.7 or above
1951
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1952

    
1953
    if not self.pyopenssl0_7:
1954
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1955
                    " function correctly")
1956

    
1957
  def _LoadCert(self, name):
1958
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1959
                                           self._ReadTestData(name))
1960

    
1961
  def test(self):
1962
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1963
    if self.pyopenssl0_7:
1964
      self.assertEqual(validity, (1266919967, 1267524767))
1965
    else:
1966
      self.assertEqual(validity, (None, None))
1967

    
1968

    
1969
class TestSignX509Certificate(unittest.TestCase):
1970
  KEY = "My private key!"
1971
  KEY_OTHER = "Another key"
1972

    
1973
  def test(self):
1974
    # Generate certificate valid for 5 minutes
1975
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
1976

    
1977
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1978
                                           cert_pem)
1979

    
1980
    # No signature at all
1981
    self.assertRaises(errors.GenericError,
1982
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
1983

    
1984
    # Invalid input
1985
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1986
                      "", self.KEY)
1987
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1988
                      "X-Ganeti-Signature: \n", self.KEY)
1989
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1990
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
1991
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1992
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
1993
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1994
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
1995

    
1996
    # Invalid salt
1997
    for salt in list("-_@$,:;/\\ \t\n"):
1998
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
1999
                        cert_pem, self.KEY, "foo%sbar" % salt)
2000

    
2001
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
2002
                 utils.GenerateSecret(numbytes=4),
2003
                 utils.GenerateSecret(numbytes=16),
2004
                 "{123:456}".encode("hex")]:
2005
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
2006

    
2007
      self._Check(cert, salt, signed_pem)
2008

    
2009
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
2010
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
2011
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
2012
                               "lines----\n------ at\nthe end!"))
2013

    
2014
  def _Check(self, cert, salt, pem):
2015
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
2016
    self.assertEqual(salt, salt2)
2017
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
2018

    
2019
    # Other key
2020
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2021
                      pem, self.KEY_OTHER)
2022

    
2023

    
2024
class TestMakedirs(unittest.TestCase):
2025
  def setUp(self):
2026
    self.tmpdir = tempfile.mkdtemp()
2027

    
2028
  def tearDown(self):
2029
    shutil.rmtree(self.tmpdir)
2030

    
2031
  def testNonExisting(self):
2032
    path = utils.PathJoin(self.tmpdir, "foo")
2033
    utils.Makedirs(path)
2034
    self.assert_(os.path.isdir(path))
2035

    
2036
  def testExisting(self):
2037
    path = utils.PathJoin(self.tmpdir, "foo")
2038
    os.mkdir(path)
2039
    utils.Makedirs(path)
2040
    self.assert_(os.path.isdir(path))
2041

    
2042
  def testRecursiveNonExisting(self):
2043
    path = utils.PathJoin(self.tmpdir, "foo/bar/baz")
2044
    utils.Makedirs(path)
2045
    self.assert_(os.path.isdir(path))
2046

    
2047
  def testRecursiveExisting(self):
2048
    path = utils.PathJoin(self.tmpdir, "B/moo/xyz")
2049
    self.assert_(not os.path.exists(path))
2050
    os.mkdir(utils.PathJoin(self.tmpdir, "B"))
2051
    utils.Makedirs(path)
2052
    self.assert_(os.path.isdir(path))
2053

    
2054

    
2055
class TestRetry(testutils.GanetiTestCase):
2056
  def setUp(self):
2057
    testutils.GanetiTestCase.setUp(self)
2058
    self.retries = 0
2059

    
2060
  @staticmethod
2061
  def _RaiseRetryAgain():
2062
    raise utils.RetryAgain()
2063

    
2064
  @staticmethod
2065
  def _RaiseRetryAgainWithArg(args):
2066
    raise utils.RetryAgain(*args)
2067

    
2068
  def _WrongNestedLoop(self):
2069
    return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
2070

    
2071
  def _RetryAndSucceed(self, retries):
2072
    if self.retries < retries:
2073
      self.retries += 1
2074
      raise utils.RetryAgain()
2075
    else:
2076
      return True
2077

    
2078
  def testRaiseTimeout(self):
2079
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
2080
                          self._RaiseRetryAgain, 0.01, 0.02)
2081
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
2082
                          self._RetryAndSucceed, 0.01, 0, args=[1])
2083
    self.failUnlessEqual(self.retries, 1)
2084

    
2085
  def testComplete(self):
2086
    self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
2087
    self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
2088
                         True)
2089
    self.failUnlessEqual(self.retries, 2)
2090

    
2091
  def testNestedLoop(self):
2092
    try:
2093
      self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
2094
                            self._WrongNestedLoop, 0, 1)
2095
    except utils.RetryTimeout:
2096
      self.fail("Didn't detect inner loop's exception")
2097

    
2098
  def testTimeoutArgument(self):
2099
    retry_arg="my_important_debugging_message"
2100
    try:
2101
      utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
2102
    except utils.RetryTimeout, err:
2103
      self.failUnlessEqual(err.args, (retry_arg, ))
2104
    else:
2105
      self.fail("Expected timeout didn't happen")
2106

    
2107
  def testRaiseInnerWithExc(self):
2108
    retry_arg="my_important_debugging_message"
2109
    try:
2110
      try:
2111
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2112
                    args=[[errors.GenericError(retry_arg, retry_arg)]])
2113
      except utils.RetryTimeout, err:
2114
        err.RaiseInner()
2115
      else:
2116
        self.fail("Expected timeout didn't happen")
2117
    except errors.GenericError, err:
2118
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2119
    else:
2120
      self.fail("Expected GenericError didn't happen")
2121

    
2122
  def testRaiseInnerWithMsg(self):
2123
    retry_arg="my_important_debugging_message"
2124
    try:
2125
      try:
2126
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2127
                    args=[[retry_arg, retry_arg]])
2128
      except utils.RetryTimeout, err:
2129
        err.RaiseInner()
2130
      else:
2131
        self.fail("Expected timeout didn't happen")
2132
    except utils.RetryTimeout, err:
2133
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2134
    else:
2135
      self.fail("Expected RetryTimeout didn't happen")
2136

    
2137

    
2138
class TestLineSplitter(unittest.TestCase):
2139
  def test(self):
2140
    lines = []
2141
    ls = utils.LineSplitter(lines.append)
2142
    ls.write("Hello World\n")
2143
    self.assertEqual(lines, [])
2144
    ls.write("Foo\n Bar\r\n ")
2145
    ls.write("Baz")
2146
    ls.write("Moo")
2147
    self.assertEqual(lines, [])
2148
    ls.flush()
2149
    self.assertEqual(lines, ["Hello World", "Foo", " Bar"])
2150
    ls.close()
2151
    self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"])
2152

    
2153
  def _testExtra(self, line, all_lines, p1, p2):
2154
    self.assertEqual(p1, 999)
2155
    self.assertEqual(p2, "extra")
2156
    all_lines.append(line)
2157

    
2158
  def testExtraArgsNoFlush(self):
2159
    lines = []
2160
    ls = utils.LineSplitter(self._testExtra, lines, 999, "extra")
2161
    ls.write("\n\nHello World\n")
2162
    ls.write("Foo\n Bar\r\n ")
2163
    ls.write("")
2164
    ls.write("Baz")
2165
    ls.write("Moo\n\nx\n")
2166
    self.assertEqual(lines, [])
2167
    ls.close()
2168
    self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo",
2169
                             "", "x"])
2170

    
2171

    
2172
class TestReadLockedPidFile(unittest.TestCase):
2173
  def setUp(self):
2174
    self.tmpdir = tempfile.mkdtemp()
2175

    
2176
  def tearDown(self):
2177
    shutil.rmtree(self.tmpdir)
2178

    
2179
  def testNonExistent(self):
2180
    path = utils.PathJoin(self.tmpdir, "nonexist")
2181
    self.assert_(utils.ReadLockedPidFile(path) is None)
2182

    
2183
  def testUnlocked(self):
2184
    path = utils.PathJoin(self.tmpdir, "pid")
2185
    utils.WriteFile(path, data="123")
2186
    self.assert_(utils.ReadLockedPidFile(path) is None)
2187

    
2188
  def testLocked(self):
2189
    path = utils.PathJoin(self.tmpdir, "pid")
2190
    utils.WriteFile(path, data="123")
2191

    
2192
    fl = utils.FileLock.Open(path)
2193
    try:
2194
      fl.Exclusive(blocking=True)
2195

    
2196
      self.assertEqual(utils.ReadLockedPidFile(path), 123)
2197
    finally:
2198
      fl.Close()
2199

    
2200
    self.assert_(utils.ReadLockedPidFile(path) is None)
2201

    
2202
  def testError(self):
2203
    path = utils.PathJoin(self.tmpdir, "foobar", "pid")
2204
    utils.WriteFile(utils.PathJoin(self.tmpdir, "foobar"), data="")
2205
    # open(2) should return ENOTDIR
2206
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
2207

    
2208

    
2209
class TestCertVerification(testutils.GanetiTestCase):
2210
  def setUp(self):
2211
    testutils.GanetiTestCase.setUp(self)
2212

    
2213
    self.tmpdir = tempfile.mkdtemp()
2214

    
2215
  def tearDown(self):
2216
    shutil.rmtree(self.tmpdir)
2217

    
2218
  def testVerifyCertificate(self):
2219
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
2220
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
2221
                                           cert_pem)
2222

    
2223
    # Not checking return value as this certificate is expired
2224
    utils.VerifyX509Certificate(cert, 30, 7)
2225

    
2226

    
2227
class TestVerifyCertificateInner(unittest.TestCase):
2228
  def test(self):
2229
    vci = utils._VerifyCertificateInner
2230

    
2231
    # Valid
2232
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
2233
                     (None, None))
2234

    
2235
    # Not yet valid
2236
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
2237
    self.assertEqual(errcode, utils.CERT_WARNING)
2238

    
2239
    # Expiring soon
2240
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
2241
    self.assertEqual(errcode, utils.CERT_ERROR)
2242

    
2243
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
2244
    self.assertEqual(errcode, utils.CERT_WARNING)
2245

    
2246
    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
2247
    self.assertEqual(errcode, None)
2248

    
2249
    # Expired
2250
    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
2251
    self.assertEqual(errcode, utils.CERT_ERROR)
2252

    
2253
    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
2254
    self.assertEqual(errcode, utils.CERT_ERROR)
2255

    
2256
    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
2257
    self.assertEqual(errcode, utils.CERT_ERROR)
2258

    
2259
    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
2260
    self.assertEqual(errcode, utils.CERT_ERROR)
2261

    
2262

    
2263
class TestHmacFunctions(unittest.TestCase):
2264
  # Digests can be checked with "openssl sha1 -hmac $key"
2265
  def testSha1Hmac(self):
2266
    self.assertEqual(utils.Sha1Hmac("", ""),
2267
                     "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d")
2268
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World"),
2269
                     "ef4f3bda82212ecb2f7ce868888a19092481f1fd")
2270
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", ""),
2271
                     "f904c2476527c6d3e6609ab683c66fa0652cb1dc")
2272

    
2273
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
2274
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", longtext),
2275
                     "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54")
2276

    
2277
  def testSha1HmacSalt(self):
2278
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc0"),
2279
                     "4999bf342470eadb11dfcd24ca5680cf9fd7cdce")
2280
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc9"),
2281
                     "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8")
2282
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World", salt="xyz0"),
2283
                     "7f264f8114c9066afc9bb7636e1786d996d3cc0d")
2284

    
2285
  def testVerifySha1Hmac(self):
2286
    self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b"
2287
                                               "7d64b71fb76370690e1d")))
2288
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2289
                                      ("f904c2476527c6d3e660"
2290
                                       "9ab683c66fa0652cb1dc")))
2291

    
2292
    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
2293
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", digest))
2294
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2295
                                      digest.lower()))
2296
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2297
                                      digest.upper()))
2298
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2299
                                      digest.title()))
2300

    
2301
  def testVerifySha1HmacSalt(self):
2302
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2303
                                      ("17a4adc34d69c0d367d4"
2304
                                       "ffbef96fd41d4df7a6e8"),
2305
                                      salt="abc9"))
2306
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2307
                                      ("7f264f8114c9066afc9b"
2308
                                       "b7636e1786d996d3cc0d"),
2309
                                      salt="xyz0"))
2310

    
2311

    
2312
class TestIgnoreSignals(unittest.TestCase):
2313
  """Test the IgnoreSignals decorator"""
2314

    
2315
  @staticmethod
2316
  def _Raise(exception):
2317
    raise exception
2318

    
2319
  @staticmethod
2320
  def _Return(rval):
2321
    return rval
2322

    
2323
  def testIgnoreSignals(self):
2324
    sock_err_intr = socket.error(errno.EINTR, "Message")
2325
    sock_err_inval = socket.error(errno.EINVAL, "Message")
2326

    
2327
    env_err_intr = EnvironmentError(errno.EINTR, "Message")
2328
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")
2329

    
2330
    self.assertRaises(socket.error, self._Raise, sock_err_intr)
2331
    self.assertRaises(socket.error, self._Raise, sock_err_inval)
2332
    self.assertRaises(EnvironmentError, self._Raise, env_err_intr)
2333
    self.assertRaises(EnvironmentError, self._Raise, env_err_inval)
2334

    
2335
    self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None)
2336
    self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None)
2337
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
2338
                      sock_err_inval)
2339
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
2340
                      env_err_inval)
2341

    
2342
    self.assertEquals(utils.IgnoreSignals(self._Return, True), True)
2343
    self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33)
2344

    
2345

    
2346
class TestEnsureDirs(unittest.TestCase):
2347
  """Tests for EnsureDirs"""
2348

    
2349
  def setUp(self):
2350
    self.dir = tempfile.mkdtemp()
2351
    self.old_umask = os.umask(0777)
2352

    
2353
  def testEnsureDirs(self):
2354
    utils.EnsureDirs([
2355
        (utils.PathJoin(self.dir, "foo"), 0777),
2356
        (utils.PathJoin(self.dir, "bar"), 0000),
2357
        ])
2358
    self.assertEquals(os.stat(utils.PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2359
    self.assertEquals(os.stat(utils.PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2360

    
2361
  def tearDown(self):
2362
    os.rmdir(utils.PathJoin(self.dir, "foo"))
2363
    os.rmdir(utils.PathJoin(self.dir, "bar"))
2364
    os.rmdir(self.dir)
2365
    os.umask(self.old_umask)
2366

    
2367

    
2368
class TestFormatSeconds(unittest.TestCase):
2369
  def test(self):
2370
    self.assertEqual(utils.FormatSeconds(1), "1s")
2371
    self.assertEqual(utils.FormatSeconds(3600), "1h 0m 0s")
2372
    self.assertEqual(utils.FormatSeconds(3599), "59m 59s")
2373
    self.assertEqual(utils.FormatSeconds(7200), "2h 0m 0s")
2374
    self.assertEqual(utils.FormatSeconds(7201), "2h 0m 1s")
2375
    self.assertEqual(utils.FormatSeconds(7281), "2h 1m 21s")
2376
    self.assertEqual(utils.FormatSeconds(29119), "8h 5m 19s")
2377
    self.assertEqual(utils.FormatSeconds(19431228), "224d 21h 33m 48s")
2378
    self.assertEqual(utils.FormatSeconds(-1), "-1s")
2379
    self.assertEqual(utils.FormatSeconds(-282), "-282s")
2380
    self.assertEqual(utils.FormatSeconds(-29119), "-29119s")
2381

    
2382
  def testFloat(self):
2383
    self.assertEqual(utils.FormatSeconds(1.3), "1s")
2384
    self.assertEqual(utils.FormatSeconds(1.9), "2s")
2385
    self.assertEqual(utils.FormatSeconds(3912.12311), "1h 5m 12s")
2386
    self.assertEqual(utils.FormatSeconds(3912.8), "1h 5m 13s")
2387

    
2388

    
2389
if __name__ == '__main__':
2390
  testutils.GanetiTestProgram()