Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ 560cbec1

History | View | Annotate | Download (81.4 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import unittest
25
import os
26
import time
27
import tempfile
28
import os.path
29
import os
30
import stat
31
import signal
32
import socket
33
import shutil
34
import re
35
import select
36
import string
37
import fcntl
38
import OpenSSL
39
import warnings
40
import distutils.version
41
import glob
42
import errno
43

    
44
import ganeti
45
import testutils
46
from ganeti import constants
47
from ganeti import compat
48
from ganeti import utils
49
from ganeti import errors
50
from ganeti import serializer
51
from ganeti.utils import IsProcessAlive, RunCmd, \
52
     RemoveFile, MatchNameComponent, FormatUnit, \
53
     ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
54
     ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
55
     SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
56
     TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
57
     UnescapeAndSplit, RunParts, PathJoin, HostInfo, ReadOneLineFile
58

    
59
from ganeti.errors import LockError, UnitParseError, GenericError, \
60
     ProgrammerError, OpPrereqError
61

    
62

    
63
class TestIsProcessAlive(unittest.TestCase):
64
  """Testing case for IsProcessAlive"""
65

    
66
  def testExists(self):
67
    mypid = os.getpid()
68
    self.assert_(IsProcessAlive(mypid),
69
                 "can't find myself running")
70

    
71
  def testNotExisting(self):
72
    pid_non_existing = os.fork()
73
    if pid_non_existing == 0:
74
      os._exit(0)
75
    elif pid_non_existing < 0:
76
      raise SystemError("can't fork")
77
    os.waitpid(pid_non_existing, 0)
78
    self.assert_(not IsProcessAlive(pid_non_existing),
79
                 "nonexisting process detected")
80

    
81

    
82
class TestGetProcStatusPath(unittest.TestCase):
83
  def test(self):
84
    self.assert_("/1234/" in utils._GetProcStatusPath(1234))
85
    self.assertNotEqual(utils._GetProcStatusPath(1),
86
                        utils._GetProcStatusPath(2))
87

    
88

    
89
class TestIsProcessHandlingSignal(unittest.TestCase):
90
  def setUp(self):
91
    self.tmpdir = tempfile.mkdtemp()
92

    
93
  def tearDown(self):
94
    shutil.rmtree(self.tmpdir)
95

    
96
  def testParseSigsetT(self):
97
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
98
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
99
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
100
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
101
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
102
                     set([2, 10, 32, 33]))
103
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
104
                     set([2, 32, 33]))
105
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
106
                     set([2, 28, 32, 33]))
107
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
108
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
109
                          24, 25, 26, 28, 31]))
110
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))
111

    
112
  def testGetProcStatusField(self):
113
    for field in ["SigCgt", "Name", "FDSize"]:
114
      for value in ["", "0", "cat", "  1234 KB"]:
115
        pstatus = "\n".join([
116
          "VmPeak: 999 kB",
117
          "%s: %s" % (field, value),
118
          "TracerPid: 0",
119
          ])
120
        result = utils._GetProcStatusField(pstatus, field)
121
        self.assertEqual(result, value.strip())
122

    
123
  def test(self):
124
    sp = utils.PathJoin(self.tmpdir, "status")
125

    
126
    utils.WriteFile(sp, data="\n".join([
127
      "Name:   bash",
128
      "State:  S (sleeping)",
129
      "SleepAVG:       98%",
130
      "Pid:    22250",
131
      "PPid:   10858",
132
      "TracerPid:      0",
133
      "SigBlk: 0000000000010000",
134
      "SigIgn: 0000000000384004",
135
      "SigCgt: 000000004b813efb",
136
      "CapEff: 0000000000000000",
137
      ]))
138

    
139
    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
140

    
141
  def testNoSigCgt(self):
142
    sp = utils.PathJoin(self.tmpdir, "status")
143

    
144
    utils.WriteFile(sp, data="\n".join([
145
      "Name:   bash",
146
      ]))
147

    
148
    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
149
                      1234, 10, status_path=sp)
150

    
151
  def testNoSuchFile(self):
152
    sp = utils.PathJoin(self.tmpdir, "notexist")
153

    
154
    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
155

    
156
  @staticmethod
157
  def _TestRealProcess():
158
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
159
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
160
      raise Exception("SIGUSR1 is handled when it should not be")
161

    
162
    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
163
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
164
      raise Exception("SIGUSR1 is not handled when it should be")
165

    
166
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
167
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
168
      raise Exception("SIGUSR1 is not handled when it should be")
169

    
170
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
171
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
172
      raise Exception("SIGUSR1 is handled when it should not be")
173

    
174
    return True
175

    
176
  def testRealProcess(self):
177
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
178

    
179

    
180
class TestPidFileFunctions(unittest.TestCase):
181
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
182

    
183
  def setUp(self):
184
    self.dir = tempfile.mkdtemp()
185
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
186
    utils.DaemonPidFileName = self.f_dpn
187

    
188
  def testPidFileFunctions(self):
189
    pid_file = self.f_dpn('test')
190
    utils.WritePidFile('test')
191
    self.failUnless(os.path.exists(pid_file),
192
                    "PID file should have been created")
193
    read_pid = utils.ReadPidFile(pid_file)
194
    self.failUnlessEqual(read_pid, os.getpid())
195
    self.failUnless(utils.IsProcessAlive(read_pid))
196
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
197
    utils.RemovePidFile('test')
198
    self.failIf(os.path.exists(pid_file),
199
                "PID file should not exist anymore")
200
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
201
                         "ReadPidFile should return 0 for missing pid file")
202
    fh = open(pid_file, "w")
203
    fh.write("blah\n")
204
    fh.close()
205
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
206
                         "ReadPidFile should return 0 for invalid pid file")
207
    utils.RemovePidFile('test')
208
    self.failIf(os.path.exists(pid_file),
209
                "PID file should not exist anymore")
210

    
211
  def testKill(self):
212
    pid_file = self.f_dpn('child')
213
    r_fd, w_fd = os.pipe()
214
    new_pid = os.fork()
215
    if new_pid == 0: #child
216
      utils.WritePidFile('child')
217
      os.write(w_fd, 'a')
218
      signal.pause()
219
      os._exit(0)
220
      return
221
    # else we are in the parent
222
    # wait until the child has written the pid file
223
    os.read(r_fd, 1)
224
    read_pid = utils.ReadPidFile(pid_file)
225
    self.failUnlessEqual(read_pid, new_pid)
226
    self.failUnless(utils.IsProcessAlive(new_pid))
227
    utils.KillProcess(new_pid, waitpid=True)
228
    self.failIf(utils.IsProcessAlive(new_pid))
229
    utils.RemovePidFile('child')
230
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)
231

    
232
  def tearDown(self):
233
    for name in os.listdir(self.dir):
234
      os.unlink(os.path.join(self.dir, name))
235
    os.rmdir(self.dir)
236

    
237

    
238
class TestRunCmd(testutils.GanetiTestCase):
239
  """Testing case for the RunCmd function"""
240

    
241
  def setUp(self):
242
    testutils.GanetiTestCase.setUp(self)
243
    self.magic = time.ctime() + " ganeti test"
244
    self.fname = self._CreateTempFile()
245

    
246
  def testOk(self):
247
    """Test successful exit code"""
248
    result = RunCmd("/bin/sh -c 'exit 0'")
249
    self.assertEqual(result.exit_code, 0)
250
    self.assertEqual(result.output, "")
251

    
252
  def testFail(self):
253
    """Test fail exit code"""
254
    result = RunCmd("/bin/sh -c 'exit 1'")
255
    self.assertEqual(result.exit_code, 1)
256
    self.assertEqual(result.output, "")
257

    
258
  def testStdout(self):
259
    """Test standard output"""
260
    cmd = 'echo -n "%s"' % self.magic
261
    result = RunCmd("/bin/sh -c '%s'" % cmd)
262
    self.assertEqual(result.stdout, self.magic)
263
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
264
    self.assertEqual(result.output, "")
265
    self.assertFileContent(self.fname, self.magic)
266

    
267
  def testStderr(self):
268
    """Test standard error"""
269
    cmd = 'echo -n "%s"' % self.magic
270
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
271
    self.assertEqual(result.stderr, self.magic)
272
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
273
    self.assertEqual(result.output, "")
274
    self.assertFileContent(self.fname, self.magic)
275

    
276
  def testCombined(self):
277
    """Test combined output"""
278
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
279
    expected = "A" + self.magic + "B" + self.magic
280
    result = RunCmd("/bin/sh -c '%s'" % cmd)
281
    self.assertEqual(result.output, expected)
282
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
283
    self.assertEqual(result.output, "")
284
    self.assertFileContent(self.fname, expected)
285

    
286
  def testSignal(self):
287
    """Test signal"""
288
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
289
    self.assertEqual(result.signal, 15)
290
    self.assertEqual(result.output, "")
291

    
292
  def testListRun(self):
293
    """Test list runs"""
294
    result = RunCmd(["true"])
295
    self.assertEqual(result.signal, None)
296
    self.assertEqual(result.exit_code, 0)
297
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
298
    self.assertEqual(result.signal, None)
299
    self.assertEqual(result.exit_code, 1)
300
    result = RunCmd(["echo", "-n", self.magic])
301
    self.assertEqual(result.signal, None)
302
    self.assertEqual(result.exit_code, 0)
303
    self.assertEqual(result.stdout, self.magic)
304

    
305
  def testFileEmptyOutput(self):
306
    """Test file output"""
307
    result = RunCmd(["true"], output=self.fname)
308
    self.assertEqual(result.signal, None)
309
    self.assertEqual(result.exit_code, 0)
310
    self.assertFileContent(self.fname, "")
311

    
312
  def testLang(self):
313
    """Test locale environment"""
314
    old_env = os.environ.copy()
315
    try:
316
      os.environ["LANG"] = "en_US.UTF-8"
317
      os.environ["LC_ALL"] = "en_US.UTF-8"
318
      result = RunCmd(["locale"])
319
      for line in result.output.splitlines():
320
        key, value = line.split("=", 1)
321
        # Ignore these variables, they're overridden by LC_ALL
322
        if key == "LANG" or key == "LANGUAGE":
323
          continue
324
        self.failIf(value and value != "C" and value != '"C"',
325
            "Variable %s is set to the invalid value '%s'" % (key, value))
326
    finally:
327
      os.environ = old_env
328

    
329
  def testDefaultCwd(self):
330
    """Test default working directory"""
331
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
332

    
333
  def testCwd(self):
334
    """Test default working directory"""
335
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
336
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
337
    cwd = os.getcwd()
338
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
339

    
340
  def testResetEnv(self):
341
    """Test environment reset functionality"""
342
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
343
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
344
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
345

    
346

    
347
class TestRunParts(unittest.TestCase):
348
  """Testing case for the RunParts function"""
349

    
350
  def setUp(self):
351
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
352

    
353
  def tearDown(self):
354
    shutil.rmtree(self.rundir)
355

    
356
  def testEmpty(self):
357
    """Test on an empty dir"""
358
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
359

    
360
  def testSkipWrongName(self):
361
    """Test that wrong files are skipped"""
362
    fname = os.path.join(self.rundir, "00test.dot")
363
    utils.WriteFile(fname, data="")
364
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
365
    relname = os.path.basename(fname)
366
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
367
                         [(relname, constants.RUNPARTS_SKIP, None)])
368

    
369
  def testSkipNonExec(self):
370
    """Test that non executable files are skipped"""
371
    fname = os.path.join(self.rundir, "00test")
372
    utils.WriteFile(fname, data="")
373
    relname = os.path.basename(fname)
374
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
375
                         [(relname, constants.RUNPARTS_SKIP, None)])
376

    
377
  def testError(self):
378
    """Test error on a broken executable"""
379
    fname = os.path.join(self.rundir, "00test")
380
    utils.WriteFile(fname, data="")
381
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
382
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
383
    self.failUnlessEqual(relname, os.path.basename(fname))
384
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
385
    self.failUnless(error)
386

    
387
  def testSorted(self):
388
    """Test executions are sorted"""
389
    files = []
390
    files.append(os.path.join(self.rundir, "64test"))
391
    files.append(os.path.join(self.rundir, "00test"))
392
    files.append(os.path.join(self.rundir, "42test"))
393

    
394
    for fname in files:
395
      utils.WriteFile(fname, data="")
396

    
397
    results = RunParts(self.rundir, reset_env=True)
398

    
399
    for fname in sorted(files):
400
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
401

    
402
  def testOk(self):
403
    """Test correct execution"""
404
    fname = os.path.join(self.rundir, "00test")
405
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
406
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
407
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
408
    self.failUnlessEqual(relname, os.path.basename(fname))
409
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
410
    self.failUnlessEqual(runresult.stdout, "ciao")
411

    
412
  def testRunFail(self):
413
    """Test correct execution, with run failure"""
414
    fname = os.path.join(self.rundir, "00test")
415
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
416
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
417
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
418
    self.failUnlessEqual(relname, os.path.basename(fname))
419
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
420
    self.failUnlessEqual(runresult.exit_code, 1)
421
    self.failUnless(runresult.failed)
422

    
423
  def testRunMix(self):
424
    files = []
425
    files.append(os.path.join(self.rundir, "00test"))
426
    files.append(os.path.join(self.rundir, "42test"))
427
    files.append(os.path.join(self.rundir, "64test"))
428
    files.append(os.path.join(self.rundir, "99test"))
429

    
430
    files.sort()
431

    
432
    # 1st has errors in execution
433
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
434
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
435

    
436
    # 2nd is skipped
437
    utils.WriteFile(files[1], data="")
438

    
439
    # 3rd cannot execute properly
440
    utils.WriteFile(files[2], data="")
441
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
442

    
443
    # 4th execs
444
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
445
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
446

    
447
    results = RunParts(self.rundir, reset_env=True)
448

    
449
    (relname, status, runresult) = results[0]
450
    self.failUnlessEqual(relname, os.path.basename(files[0]))
451
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
452
    self.failUnlessEqual(runresult.exit_code, 1)
453
    self.failUnless(runresult.failed)
454

    
455
    (relname, status, runresult) = results[1]
456
    self.failUnlessEqual(relname, os.path.basename(files[1]))
457
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
458
    self.failUnlessEqual(runresult, None)
459

    
460
    (relname, status, runresult) = results[2]
461
    self.failUnlessEqual(relname, os.path.basename(files[2]))
462
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
463
    self.failUnless(runresult)
464

    
465
    (relname, status, runresult) = results[3]
466
    self.failUnlessEqual(relname, os.path.basename(files[3]))
467
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
468
    self.failUnlessEqual(runresult.output, "ciao")
469
    self.failUnlessEqual(runresult.exit_code, 0)
470
    self.failUnless(not runresult.failed)
471

    
472

    
473
class TestStartDaemon(testutils.GanetiTestCase):
474
  def setUp(self):
475
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
476
    self.tmpfile = os.path.join(self.tmpdir, "test")
477

    
478
  def tearDown(self):
479
    shutil.rmtree(self.tmpdir)
480

    
481
  def testShell(self):
482
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
483
    self._wait(self.tmpfile, 60.0, "Hello World")
484

    
485
  def testShellOutput(self):
486
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
487
    self._wait(self.tmpfile, 60.0, "Hello World")
488

    
489
  def testNoShellNoOutput(self):
490
    utils.StartDaemon(["pwd"])
491

    
492
  def testNoShellNoOutputTouch(self):
493
    testfile = os.path.join(self.tmpdir, "check")
494
    self.failIf(os.path.exists(testfile))
495
    utils.StartDaemon(["touch", testfile])
496
    self._wait(testfile, 60.0, "")
497

    
498
  def testNoShellOutput(self):
499
    utils.StartDaemon(["pwd"], output=self.tmpfile)
500
    self._wait(self.tmpfile, 60.0, "/")
501

    
502
  def testNoShellOutputCwd(self):
503
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
504
    self._wait(self.tmpfile, 60.0, os.getcwd())
505

    
506
  def testShellEnv(self):
507
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
508
                      env={ "GNT_TEST_VAR": "Hello World", })
509
    self._wait(self.tmpfile, 60.0, "Hello World")
510

    
511
  def testNoShellEnv(self):
512
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
513
                      env={ "GNT_TEST_VAR": "Hello World", })
514
    self._wait(self.tmpfile, 60.0, "Hello World")
515

    
516
  def testOutputFd(self):
517
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
518
    try:
519
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
520
    finally:
521
      os.close(fd)
522
    self._wait(self.tmpfile, 60.0, os.getcwd())
523

    
524
  def testPid(self):
525
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
526
    self._wait(self.tmpfile, 60.0, str(pid))
527

    
528
  def testPidFile(self):
529
    pidfile = os.path.join(self.tmpdir, "pid")
530
    checkfile = os.path.join(self.tmpdir, "abort")
531

    
532
    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
533
                            output=self.tmpfile)
534
    try:
535
      fd = os.open(pidfile, os.O_RDONLY)
536
      try:
537
        # Check file is locked
538
        self.assertRaises(errors.LockError, utils.LockFile, fd)
539

    
540
        pidtext = os.read(fd, 100)
541
      finally:
542
        os.close(fd)
543

    
544
      self.assertEqual(int(pidtext.strip()), pid)
545

    
546
      self.assert_(utils.IsProcessAlive(pid))
547
    finally:
548
      # No matter what happens, kill daemon
549
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
550
      self.failIf(utils.IsProcessAlive(pid))
551

    
552
    self.assertEqual(utils.ReadFile(self.tmpfile), "")
553

    
554
  def _wait(self, path, timeout, expected):
555
    # Due to the asynchronous nature of daemon processes, polling is necessary.
556
    # A timeout makes sure the test doesn't hang forever.
557
    def _CheckFile():
558
      if not (os.path.isfile(path) and
559
              utils.ReadFile(path).strip() == expected):
560
        raise utils.RetryAgain()
561

    
562
    try:
563
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
564
    except utils.RetryTimeout:
565
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
566
                " didn't write the correct output" % timeout)
567

    
568
  def testError(self):
569
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
570
                      ["./does-NOT-EXIST/here/0123456789"])
571
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
572
                      ["./does-NOT-EXIST/here/0123456789"],
573
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
574
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
575
                      ["./does-NOT-EXIST/here/0123456789"],
576
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
577
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
578
                      ["./does-NOT-EXIST/here/0123456789"],
579
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
580

    
581
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
582
    try:
583
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
584
                        ["./does-NOT-EXIST/here/0123456789"],
585
                        output=self.tmpfile, output_fd=fd)
586
    finally:
587
      os.close(fd)
588

    
589

    
590
class TestSetCloseOnExecFlag(unittest.TestCase):
591
  """Tests for SetCloseOnExecFlag"""
592

    
593
  def setUp(self):
594
    self.tmpfile = tempfile.TemporaryFile()
595

    
596
  def testEnable(self):
597
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
598
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
599
                    fcntl.FD_CLOEXEC)
600

    
601
  def testDisable(self):
602
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
603
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
604
                fcntl.FD_CLOEXEC)
605

    
606

    
607
class TestSetNonblockFlag(unittest.TestCase):
608
  def setUp(self):
609
    self.tmpfile = tempfile.TemporaryFile()
610

    
611
  def testEnable(self):
612
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
613
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
614
                    os.O_NONBLOCK)
615

    
616
  def testDisable(self):
617
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
618
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
619
                os.O_NONBLOCK)
620

    
621

    
622
class TestRemoveFile(unittest.TestCase):
623
  """Test case for the RemoveFile function"""
624

    
625
  def setUp(self):
626
    """Create a temp dir and file for each case"""
627
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
628
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
629
    os.close(fd)
630

    
631
  def tearDown(self):
632
    if os.path.exists(self.tmpfile):
633
      os.unlink(self.tmpfile)
634
    os.rmdir(self.tmpdir)
635

    
636
  def testIgnoreDirs(self):
637
    """Test that RemoveFile() ignores directories"""
638
    self.assertEqual(None, RemoveFile(self.tmpdir))
639

    
640
  def testIgnoreNotExisting(self):
641
    """Test that RemoveFile() ignores non-existing files"""
642
    RemoveFile(self.tmpfile)
643
    RemoveFile(self.tmpfile)
644

    
645
  def testRemoveFile(self):
646
    """Test that RemoveFile does remove a file"""
647
    RemoveFile(self.tmpfile)
648
    if os.path.exists(self.tmpfile):
649
      self.fail("File '%s' not removed" % self.tmpfile)
650

    
651
  def testRemoveSymlink(self):
652
    """Test that RemoveFile does remove symlinks"""
653
    symlink = self.tmpdir + "/symlink"
654
    os.symlink("no-such-file", symlink)
655
    RemoveFile(symlink)
656
    if os.path.exists(symlink):
657
      self.fail("File '%s' not removed" % symlink)
658
    os.symlink(self.tmpfile, symlink)
659
    RemoveFile(symlink)
660
    if os.path.exists(symlink):
661
      self.fail("File '%s' not removed" % symlink)
662

    
663

    
664
class TestRename(unittest.TestCase):
665
  """Test case for RenameFile"""
666

    
667
  def setUp(self):
668
    """Create a temporary directory"""
669
    self.tmpdir = tempfile.mkdtemp()
670
    self.tmpfile = os.path.join(self.tmpdir, "test1")
671

    
672
    # Touch the file
673
    open(self.tmpfile, "w").close()
674

    
675
  def tearDown(self):
676
    """Remove temporary directory"""
677
    shutil.rmtree(self.tmpdir)
678

    
679
  def testSimpleRename1(self):
680
    """Simple rename 1"""
681
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
682
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
683

    
684
  def testSimpleRename2(self):
685
    """Simple rename 2"""
686
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
687
                     mkdir=True)
688
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
689

    
690
  def testRenameMkdir(self):
691
    """Rename with mkdir"""
692
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
693
                     mkdir=True)
694
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
695
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
696

    
697
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
698
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
699
                     mkdir=True)
700
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
701
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
702
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
703

    
704

    
705
class TestMatchNameComponent(unittest.TestCase):
706
  """Test case for the MatchNameComponent function"""
707

    
708
  def testEmptyList(self):
709
    """Test that there is no match against an empty list"""
710

    
711
    self.failUnlessEqual(MatchNameComponent("", []), None)
712
    self.failUnlessEqual(MatchNameComponent("test", []), None)
713

    
714
  def testSingleMatch(self):
715
    """Test that a single match is performed correctly"""
716
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
717
    for key in "test2", "test2.example", "test2.example.com":
718
      self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
719

    
720
  def testMultipleMatches(self):
721
    """Test that a multiple match is returned as None"""
722
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
723
    for key in "test1", "test1.example":
724
      self.failUnlessEqual(MatchNameComponent(key, mlist), None)
725

    
726
  def testFullMatch(self):
727
    """Test that a full match is returned correctly"""
728
    key1 = "test1"
729
    key2 = "test1.example"
730
    mlist = [key2, key2 + ".com"]
731
    self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
732
    self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
733

    
734
  def testCaseInsensitivePartialMatch(self):
735
    """Test for the case_insensitive keyword"""
736
    mlist = ["test1.example.com", "test2.example.net"]
737
    self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
738
                     "test2.example.net")
739
    self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
740
                     "test2.example.net")
741
    self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
742
                     "test2.example.net")
743
    self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
744
                     "test2.example.net")
745

    
746

    
747
  def testCaseInsensitiveFullMatch(self):
748
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
749
    # Between the two ts1 a full string match non-case insensitive should work
750
    self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
751
                     None)
752
    self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
753
                     "ts1.ex")
754
    self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
755
                     "ts1.ex")
756
    # Between the two ts2 only case differs, so only case-match works
757
    self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
758
                     "ts2.ex")
759
    self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
760
                     "Ts2.ex")
761
    self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
762
                     None)
763

    
764

    
765
class TestReadFile(testutils.GanetiTestCase):
766

    
767
  def testReadAll(self):
768
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
769
    self.assertEqual(len(data), 814)
770

    
771
    h = compat.md5_hash()
772
    h.update(data)
773
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
774

    
775
  def testReadSize(self):
776
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
777
                          size=100)
778
    self.assertEqual(len(data), 100)
779

    
780
    h = compat.md5_hash()
781
    h.update(data)
782
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
783

    
784
  def testError(self):
785
    self.assertRaises(EnvironmentError, utils.ReadFile,
786
                      "/dev/null/does-not-exist")
787

    
788

    
789
class TestReadOneLineFile(testutils.GanetiTestCase):
790

    
791
  def setUp(self):
792
    testutils.GanetiTestCase.setUp(self)
793

    
794
  def testDefault(self):
795
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
796
    self.assertEqual(len(data), 27)
797
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
798

    
799
  def testNotStrict(self):
800
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
801
    self.assertEqual(len(data), 27)
802
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
803

    
804
  def testStrictFailure(self):
805
    self.assertRaises(errors.GenericError, ReadOneLineFile,
806
                      self._TestDataFilename("cert1.pem"), strict=True)
807

    
808
  def testLongLine(self):
809
    dummydata = (1024 * "Hello World! ")
810
    myfile = self._CreateTempFile()
811
    utils.WriteFile(myfile, data=dummydata)
812
    datastrict = ReadOneLineFile(myfile, strict=True)
813
    datalax = ReadOneLineFile(myfile, strict=False)
814
    self.assertEqual(dummydata, datastrict)
815
    self.assertEqual(dummydata, datalax)
816

    
817
  def testNewline(self):
818
    myfile = self._CreateTempFile()
819
    myline = "myline"
820
    for nl in ["", "\n", "\r\n"]:
821
      dummydata = "%s%s" % (myline, nl)
822
      utils.WriteFile(myfile, data=dummydata)
823
      datalax = ReadOneLineFile(myfile, strict=False)
824
      self.assertEqual(myline, datalax)
825
      datastrict = ReadOneLineFile(myfile, strict=True)
826
      self.assertEqual(myline, datastrict)
827

    
828
  def testWhitespaceAndMultipleLines(self):
829
    myfile = self._CreateTempFile()
830
    for nl in ["", "\n", "\r\n"]:
831
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
832
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
833
        utils.WriteFile(myfile, data=dummydata)
834
        datalax = ReadOneLineFile(myfile, strict=False)
835
        if nl:
836
          self.assert_(set("\r\n") & set(dummydata))
837
          self.assertRaises(errors.GenericError, ReadOneLineFile,
838
                            myfile, strict=True)
839
          explen = len("Foo bar baz ") + len(ws)
840
          self.assertEqual(len(datalax), explen)
841
          self.assertEqual(datalax, dummydata[:explen])
842
          self.assertFalse(set("\r\n") & set(datalax))
843
        else:
844
          datastrict = ReadOneLineFile(myfile, strict=True)
845
          self.assertEqual(dummydata, datastrict)
846
          self.assertEqual(dummydata, datalax)
847

    
848
  def testEmptylines(self):
849
    myfile = self._CreateTempFile()
850
    myline = "myline"
851
    for nl in ["\n", "\r\n"]:
852
      for ol in ["", "otherline"]:
853
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
854
        utils.WriteFile(myfile, data=dummydata)
855
        self.assert_(set("\r\n") & set(dummydata))
856
        datalax = ReadOneLineFile(myfile, strict=False)
857
        self.assertEqual(myline, datalax)
858
        if ol:
859
          self.assertRaises(errors.GenericError, ReadOneLineFile,
860
                            myfile, strict=True)
861
        else:
862
          datastrict = ReadOneLineFile(myfile, strict=True)
863
          self.assertEqual(myline, datastrict)
864

    
865

    
866
class TestTimestampForFilename(unittest.TestCase):
867
  def test(self):
868
    self.assert_("." not in utils.TimestampForFilename())
869
    self.assert_(":" not in utils.TimestampForFilename())
870

    
871

    
872
class TestCreateBackup(testutils.GanetiTestCase):
873
  def setUp(self):
874
    testutils.GanetiTestCase.setUp(self)
875

    
876
    self.tmpdir = tempfile.mkdtemp()
877

    
878
  def tearDown(self):
879
    testutils.GanetiTestCase.tearDown(self)
880

    
881
    shutil.rmtree(self.tmpdir)
882

    
883
  def testEmpty(self):
884
    filename = utils.PathJoin(self.tmpdir, "config.data")
885
    utils.WriteFile(filename, data="")
886
    bname = utils.CreateBackup(filename)
887
    self.assertFileContent(bname, "")
888
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
889
    utils.CreateBackup(filename)
890
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
891
    utils.CreateBackup(filename)
892
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
893

    
894
    fifoname = utils.PathJoin(self.tmpdir, "fifo")
895
    os.mkfifo(fifoname)
896
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
897

    
898
  def testContent(self):
899
    bkpcount = 0
900
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
901
      for rep in [1, 2, 10, 127]:
902
        testdata = data * rep
903

    
904
        filename = utils.PathJoin(self.tmpdir, "test.data_")
905
        utils.WriteFile(filename, data=testdata)
906
        self.assertFileContent(filename, testdata)
907

    
908
        for _ in range(3):
909
          bname = utils.CreateBackup(filename)
910
          bkpcount += 1
911
          self.assertFileContent(bname, testdata)
912
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
913

    
914

    
915
class TestFormatUnit(unittest.TestCase):
916
  """Test case for the FormatUnit function"""
917

    
918
  def testMiB(self):
919
    self.assertEqual(FormatUnit(1, 'h'), '1M')
920
    self.assertEqual(FormatUnit(100, 'h'), '100M')
921
    self.assertEqual(FormatUnit(1023, 'h'), '1023M')
922

    
923
    self.assertEqual(FormatUnit(1, 'm'), '1')
924
    self.assertEqual(FormatUnit(100, 'm'), '100')
925
    self.assertEqual(FormatUnit(1023, 'm'), '1023')
926

    
927
    self.assertEqual(FormatUnit(1024, 'm'), '1024')
928
    self.assertEqual(FormatUnit(1536, 'm'), '1536')
929
    self.assertEqual(FormatUnit(17133, 'm'), '17133')
930
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
931

    
932
  def testGiB(self):
933
    self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
934
    self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
935
    self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
936
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
937

    
938
    self.assertEqual(FormatUnit(1024, 'g'), '1.0')
939
    self.assertEqual(FormatUnit(1536, 'g'), '1.5')
940
    self.assertEqual(FormatUnit(17133, 'g'), '16.7')
941
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
942

    
943
    self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
944
    self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
945
    self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
946

    
947
  def testTiB(self):
948
    self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
949
    self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
950
    self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
951

    
952
    self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
953
    self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
954
    self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
955

    
956
class TestParseUnit(unittest.TestCase):
957
  """Test case for the ParseUnit function"""
958

    
959
  SCALES = (('', 1),
960
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
961
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
962
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
963

    
964
  def testRounding(self):
965
    self.assertEqual(ParseUnit('0'), 0)
966
    self.assertEqual(ParseUnit('1'), 4)
967
    self.assertEqual(ParseUnit('2'), 4)
968
    self.assertEqual(ParseUnit('3'), 4)
969

    
970
    self.assertEqual(ParseUnit('124'), 124)
971
    self.assertEqual(ParseUnit('125'), 128)
972
    self.assertEqual(ParseUnit('126'), 128)
973
    self.assertEqual(ParseUnit('127'), 128)
974
    self.assertEqual(ParseUnit('128'), 128)
975
    self.assertEqual(ParseUnit('129'), 132)
976
    self.assertEqual(ParseUnit('130'), 132)
977

    
978
  def testFloating(self):
979
    self.assertEqual(ParseUnit('0'), 0)
980
    self.assertEqual(ParseUnit('0.5'), 4)
981
    self.assertEqual(ParseUnit('1.75'), 4)
982
    self.assertEqual(ParseUnit('1.99'), 4)
983
    self.assertEqual(ParseUnit('2.00'), 4)
984
    self.assertEqual(ParseUnit('2.01'), 4)
985
    self.assertEqual(ParseUnit('3.99'), 4)
986
    self.assertEqual(ParseUnit('4.00'), 4)
987
    self.assertEqual(ParseUnit('4.01'), 8)
988
    self.assertEqual(ParseUnit('1.5G'), 1536)
989
    self.assertEqual(ParseUnit('1.8G'), 1844)
990
    self.assertEqual(ParseUnit('8.28T'), 8682212)
991

    
992
  def testSuffixes(self):
993
    for sep in ('', ' ', '   ', "\t", "\t "):
994
      for suffix, scale in TestParseUnit.SCALES:
995
        for func in (lambda x: x, str.lower, str.upper):
996
          self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
997
                           1024 * scale)
998

    
999
  def testInvalidInput(self):
1000
    for sep in ('-', '_', ',', 'a'):
1001
      for suffix, _ in TestParseUnit.SCALES:
1002
        self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)
1003

    
1004
    for suffix, _ in TestParseUnit.SCALES:
1005
      self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
1006

    
1007

    
1008
class TestSshKeys(testutils.GanetiTestCase):
1009
  """Test case for the AddAuthorizedKey function"""
1010

    
1011
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
1012
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
1013
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
1014

    
1015
  def setUp(self):
1016
    testutils.GanetiTestCase.setUp(self)
1017
    self.tmpname = self._CreateTempFile()
1018
    handle = open(self.tmpname, 'w')
1019
    try:
1020
      handle.write("%s\n" % TestSshKeys.KEY_A)
1021
      handle.write("%s\n" % TestSshKeys.KEY_B)
1022
    finally:
1023
      handle.close()
1024

    
1025
  def testAddingNewKey(self):
1026
    AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
1027

    
1028
    self.assertFileContent(self.tmpname,
1029
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1030
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1031
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1032
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
1033

    
1034
  def testAddingAlmostButNotCompletelyTheSameKey(self):
1035
    AddAuthorizedKey(self.tmpname,
1036
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
1037

    
1038
    self.assertFileContent(self.tmpname,
1039
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1040
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1041
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1042
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
1043

    
1044
  def testAddingExistingKeyWithSomeMoreSpaces(self):
1045
    AddAuthorizedKey(self.tmpname,
1046
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1047

    
1048
    self.assertFileContent(self.tmpname,
1049
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1050
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1051
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1052

    
1053
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
1054
    RemoveAuthorizedKey(self.tmpname,
1055
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1056

    
1057
    self.assertFileContent(self.tmpname,
1058
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1059
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1060

    
1061
  def testRemovingNonExistingKey(self):
1062
    RemoveAuthorizedKey(self.tmpname,
1063
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
1064

    
1065
    self.assertFileContent(self.tmpname,
1066
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1067
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1068
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1069

    
1070

    
1071
class TestEtcHosts(testutils.GanetiTestCase):
1072
  """Test functions modifying /etc/hosts"""
1073

    
1074
  def setUp(self):
1075
    testutils.GanetiTestCase.setUp(self)
1076
    self.tmpname = self._CreateTempFile()
1077
    handle = open(self.tmpname, 'w')
1078
    try:
1079
      handle.write('# This is a test file for /etc/hosts\n')
1080
      handle.write('127.0.0.1\tlocalhost\n')
1081
      handle.write('192.168.1.1 router gw\n')
1082
    finally:
1083
      handle.close()
1084

    
1085
  def testSettingNewIp(self):
1086
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
1087

    
1088
    self.assertFileContent(self.tmpname,
1089
      "# This is a test file for /etc/hosts\n"
1090
      "127.0.0.1\tlocalhost\n"
1091
      "192.168.1.1 router gw\n"
1092
      "1.2.3.4\tmyhost.domain.tld myhost\n")
1093
    self.assertFileMode(self.tmpname, 0644)
1094

    
1095
  def testSettingExistingIp(self):
1096
    SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
1097
                     ['myhost'])
1098

    
1099
    self.assertFileContent(self.tmpname,
1100
      "# This is a test file for /etc/hosts\n"
1101
      "127.0.0.1\tlocalhost\n"
1102
      "192.168.1.1\tmyhost.domain.tld myhost\n")
1103
    self.assertFileMode(self.tmpname, 0644)
1104

    
1105
  def testSettingDuplicateName(self):
1106
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
1107

    
1108
    self.assertFileContent(self.tmpname,
1109
      "# This is a test file for /etc/hosts\n"
1110
      "127.0.0.1\tlocalhost\n"
1111
      "192.168.1.1 router gw\n"
1112
      "1.2.3.4\tmyhost\n")
1113
    self.assertFileMode(self.tmpname, 0644)
1114

    
1115
  def testRemovingExistingHost(self):
1116
    RemoveEtcHostsEntry(self.tmpname, 'router')
1117

    
1118
    self.assertFileContent(self.tmpname,
1119
      "# This is a test file for /etc/hosts\n"
1120
      "127.0.0.1\tlocalhost\n"
1121
      "192.168.1.1 gw\n")
1122
    self.assertFileMode(self.tmpname, 0644)
1123

    
1124
  def testRemovingSingleExistingHost(self):
1125
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
1126

    
1127
    self.assertFileContent(self.tmpname,
1128
      "# This is a test file for /etc/hosts\n"
1129
      "192.168.1.1 router gw\n")
1130
    self.assertFileMode(self.tmpname, 0644)
1131

    
1132
  def testRemovingNonExistingHost(self):
1133
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
1134

    
1135
    self.assertFileContent(self.tmpname,
1136
      "# This is a test file for /etc/hosts\n"
1137
      "127.0.0.1\tlocalhost\n"
1138
      "192.168.1.1 router gw\n")
1139
    self.assertFileMode(self.tmpname, 0644)
1140

    
1141
  def testRemovingAlias(self):
1142
    RemoveEtcHostsEntry(self.tmpname, 'gw')
1143

    
1144
    self.assertFileContent(self.tmpname,
1145
      "# This is a test file for /etc/hosts\n"
1146
      "127.0.0.1\tlocalhost\n"
1147
      "192.168.1.1 router\n")
1148
    self.assertFileMode(self.tmpname, 0644)
1149

    
1150

    
1151
class TestShellQuoting(unittest.TestCase):
1152
  """Test case for shell quoting functions"""
1153

    
1154
  def testShellQuote(self):
1155
    self.assertEqual(ShellQuote('abc'), "abc")
1156
    self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
1157
    self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
1158
    self.assertEqual(ShellQuote("a b c"), "'a b c'")
1159
    self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
1160

    
1161
  def testShellQuoteArgs(self):
1162
    self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
1163
    self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
1164
    self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
1165

    
1166

    
1167
class TestTcpPing(unittest.TestCase):
1168
  """Testcase for TCP version of ping - against listen(2)ing port"""
1169

    
1170
  def setUp(self):
1171
    self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1172
    self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
1173
    self.listenerport = self.listener.getsockname()[1]
1174
    self.listener.listen(1)
1175

    
1176
  def tearDown(self):
1177
    self.listener.shutdown(socket.SHUT_RDWR)
1178
    del self.listener
1179
    del self.listenerport
1180

    
1181
  def testTcpPingToLocalHostAccept(self):
1182
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1183
                         self.listenerport,
1184
                         timeout=10,
1185
                         live_port_needed=True,
1186
                         source=constants.LOCALHOST_IP_ADDRESS,
1187
                         ),
1188
                 "failed to connect to test listener")
1189

    
1190
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1191
                         self.listenerport,
1192
                         timeout=10,
1193
                         live_port_needed=True,
1194
                         ),
1195
                 "failed to connect to test listener (no source)")
1196

    
1197

    
1198
class TestTcpPingDeaf(unittest.TestCase):
1199
  """Testcase for TCP version of ping - against non listen(2)ing port"""
1200

    
1201
  def setUp(self):
1202
    self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1203
    self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
1204
    self.deaflistenerport = self.deaflistener.getsockname()[1]
1205

    
1206
  def tearDown(self):
1207
    del self.deaflistener
1208
    del self.deaflistenerport
1209

    
1210
  def testTcpPingToLocalHostAcceptDeaf(self):
1211
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1212
                        self.deaflistenerport,
1213
                        timeout=constants.TCP_PING_TIMEOUT,
1214
                        live_port_needed=True,
1215
                        source=constants.LOCALHOST_IP_ADDRESS,
1216
                        ), # need successful connect(2)
1217
                "successfully connected to deaf listener")
1218

    
1219
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1220
                        self.deaflistenerport,
1221
                        timeout=constants.TCP_PING_TIMEOUT,
1222
                        live_port_needed=True,
1223
                        ), # need successful connect(2)
1224
                "successfully connected to deaf listener (no source addr)")
1225

    
1226
  def testTcpPingToLocalHostNoAccept(self):
1227
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1228
                         self.deaflistenerport,
1229
                         timeout=constants.TCP_PING_TIMEOUT,
1230
                         live_port_needed=False,
1231
                         source=constants.LOCALHOST_IP_ADDRESS,
1232
                         ), # ECONNREFUSED is OK
1233
                 "failed to ping alive host on deaf port")
1234

    
1235
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1236
                         self.deaflistenerport,
1237
                         timeout=constants.TCP_PING_TIMEOUT,
1238
                         live_port_needed=False,
1239
                         ), # ECONNREFUSED is OK
1240
                 "failed to ping alive host on deaf port (no source addr)")
1241

    
1242

    
1243
class TestOwnIpAddress(unittest.TestCase):
1244
  """Testcase for OwnIpAddress"""
1245

    
1246
  def testOwnLoopback(self):
1247
    """check having the loopback ip"""
1248
    self.failUnless(OwnIpAddress(constants.LOCALHOST_IP_ADDRESS),
1249
                    "Should own the loopback address")
1250

    
1251
  def testNowOwnAddress(self):
1252
    """check that I don't own an address"""
1253

    
1254
    # Network 192.0.2.0/24 is reserved for test/documentation as per
1255
    # RFC 5735, so we *should* not have an address of this range... if
1256
    # this fails, we should extend the test to multiple addresses
1257
    DST_IP = "192.0.2.1"
1258
    self.failIf(OwnIpAddress(DST_IP), "Should not own IP address %s" % DST_IP)
1259

    
1260

    
1261
def _GetSocketCredentials(path):
1262
  """Connect to a Unix socket and return remote credentials.
1263

1264
  """
1265
  sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1266
  try:
1267
    sock.settimeout(10)
1268
    sock.connect(path)
1269
    return utils.GetSocketCredentials(sock)
1270
  finally:
1271
    sock.close()
1272

    
1273

    
1274
class TestGetSocketCredentials(unittest.TestCase):
1275
  def setUp(self):
1276
    self.tmpdir = tempfile.mkdtemp()
1277
    self.sockpath = utils.PathJoin(self.tmpdir, "sock")
1278

    
1279
    self.listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1280
    self.listener.settimeout(10)
1281
    self.listener.bind(self.sockpath)
1282
    self.listener.listen(1)
1283

    
1284
  def tearDown(self):
1285
    self.listener.shutdown(socket.SHUT_RDWR)
1286
    self.listener.close()
1287
    shutil.rmtree(self.tmpdir)
1288

    
1289
  def test(self):
1290
    (c2pr, c2pw) = os.pipe()
1291

    
1292
    # Start child process
1293
    child = os.fork()
1294
    if child == 0:
1295
      try:
1296
        data = serializer.DumpJson(_GetSocketCredentials(self.sockpath))
1297

    
1298
        os.write(c2pw, data)
1299
        os.close(c2pw)
1300

    
1301
        os._exit(0)
1302
      finally:
1303
        os._exit(1)
1304

    
1305
    os.close(c2pw)
1306

    
1307
    # Wait for one connection
1308
    (conn, _) = self.listener.accept()
1309
    conn.recv(1)
1310
    conn.close()
1311

    
1312
    # Wait for result
1313
    result = os.read(c2pr, 4096)
1314
    os.close(c2pr)
1315

    
1316
    # Check child's exit code
1317
    (_, status) = os.waitpid(child, 0)
1318
    self.assertFalse(os.WIFSIGNALED(status))
1319
    self.assertEqual(os.WEXITSTATUS(status), 0)
1320

    
1321
    # Check result
1322
    (pid, uid, gid) = serializer.LoadJson(result)
1323
    self.assertEqual(pid, os.getpid())
1324
    self.assertEqual(uid, os.getuid())
1325
    self.assertEqual(gid, os.getgid())
1326

    
1327

    
1328
class TestListVisibleFiles(unittest.TestCase):
1329
  """Test case for ListVisibleFiles"""
1330

    
1331
  def setUp(self):
1332
    self.path = tempfile.mkdtemp()
1333

    
1334
  def tearDown(self):
1335
    shutil.rmtree(self.path)
1336

    
1337
  def _CreateFiles(self, files):
1338
    for name in files:
1339
      utils.WriteFile(os.path.join(self.path, name), data="test")
1340

    
1341
  def _test(self, files, expected):
1342
    self._CreateFiles(files)
1343
    found = ListVisibleFiles(self.path)
1344
    # by default ListVisibleFiles sorts its output
1345
    self.assertEqual(found, sorted(expected))
1346

    
1347
  def testAllVisible(self):
1348
    files = ["a", "b", "c"]
1349
    expected = files
1350
    self._test(files, expected)
1351

    
1352
  def testNoneVisible(self):
1353
    files = [".a", ".b", ".c"]
1354
    expected = []
1355
    self._test(files, expected)
1356

    
1357
  def testSomeVisible(self):
1358
    files = ["a", "b", ".c"]
1359
    expected = ["a", "b"]
1360
    self._test(files, expected)
1361

    
1362
  def testForceSort(self):
1363
    files = ["c", "b", "a"]
1364
    self._CreateFiles(files)
1365
    found = ListVisibleFiles(self.path, sort=True)
1366
    self.assertEqual(found, sorted(files))
1367

    
1368
  def testForceNonSort(self):
1369
    files = ["c", "b", "a"]
1370
    self._CreateFiles(files)
1371
    found = ListVisibleFiles(self.path, sort=False)
1372
    # We can't actually check that they weren't sorted, because they might come
1373
    # out sorted by chance
1374
    self.assertEqual(set(found), set(files))
1375

    
1376
  def testNonAbsolutePath(self):
1377
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1378

    
1379
  def testNonNormalizedPath(self):
1380
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1381
                          "/bin/../tmp")
1382

    
1383

    
1384
class TestNewUUID(unittest.TestCase):
1385
  """Test case for NewUUID"""
1386

    
1387
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
1388
                        '[a-f0-9]{4}-[a-f0-9]{12}$')
1389

    
1390
  def runTest(self):
1391
    self.failUnless(self._re_uuid.match(utils.NewUUID()))
1392

    
1393

    
1394
class TestUniqueSequence(unittest.TestCase):
1395
  """Test case for UniqueSequence"""
1396

    
1397
  def _test(self, input, expected):
1398
    self.assertEqual(utils.UniqueSequence(input), expected)
1399

    
1400
  def runTest(self):
1401
    # Ordered input
1402
    self._test([1, 2, 3], [1, 2, 3])
1403
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
1404
    self._test([1, 2, 2, 3], [1, 2, 3])
1405
    self._test([1, 2, 3, 3], [1, 2, 3])
1406

    
1407
    # Unordered input
1408
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
1409
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
1410

    
1411
    # Strings
1412
    self._test(["a", "a"], ["a"])
1413
    self._test(["a", "b"], ["a", "b"])
1414
    self._test(["a", "b", "a"], ["a", "b"])
1415

    
1416

    
1417
class TestFirstFree(unittest.TestCase):
1418
  """Test case for the FirstFree function"""
1419

    
1420
  def test(self):
1421
    """Test FirstFree"""
1422
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1423
    self.failUnlessEqual(FirstFree([]), None)
1424
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1425
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1426
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1427

    
1428

    
1429
class TestTailFile(testutils.GanetiTestCase):
1430
  """Test case for the TailFile function"""
1431

    
1432
  def testEmpty(self):
1433
    fname = self._CreateTempFile()
1434
    self.failUnlessEqual(TailFile(fname), [])
1435
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1436

    
1437
  def testAllLines(self):
1438
    data = ["test %d" % i for i in range(30)]
1439
    for i in range(30):
1440
      fname = self._CreateTempFile()
1441
      fd = open(fname, "w")
1442
      fd.write("\n".join(data[:i]))
1443
      if i > 0:
1444
        fd.write("\n")
1445
      fd.close()
1446
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1447

    
1448
  def testPartialLines(self):
1449
    data = ["test %d" % i for i in range(30)]
1450
    fname = self._CreateTempFile()
1451
    fd = open(fname, "w")
1452
    fd.write("\n".join(data))
1453
    fd.write("\n")
1454
    fd.close()
1455
    for i in range(1, 30):
1456
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1457

    
1458
  def testBigFile(self):
1459
    data = ["test %d" % i for i in range(30)]
1460
    fname = self._CreateTempFile()
1461
    fd = open(fname, "w")
1462
    fd.write("X" * 1048576)
1463
    fd.write("\n")
1464
    fd.write("\n".join(data))
1465
    fd.write("\n")
1466
    fd.close()
1467
    for i in range(1, 30):
1468
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1469

    
1470

    
1471
class _BaseFileLockTest:
1472
  """Test case for the FileLock class"""
1473

    
1474
  def testSharedNonblocking(self):
1475
    self.lock.Shared(blocking=False)
1476
    self.lock.Close()
1477

    
1478
  def testExclusiveNonblocking(self):
1479
    self.lock.Exclusive(blocking=False)
1480
    self.lock.Close()
1481

    
1482
  def testUnlockNonblocking(self):
1483
    self.lock.Unlock(blocking=False)
1484
    self.lock.Close()
1485

    
1486
  def testSharedBlocking(self):
1487
    self.lock.Shared(blocking=True)
1488
    self.lock.Close()
1489

    
1490
  def testExclusiveBlocking(self):
1491
    self.lock.Exclusive(blocking=True)
1492
    self.lock.Close()
1493

    
1494
  def testUnlockBlocking(self):
1495
    self.lock.Unlock(blocking=True)
1496
    self.lock.Close()
1497

    
1498
  def testSharedExclusiveUnlock(self):
1499
    self.lock.Shared(blocking=False)
1500
    self.lock.Exclusive(blocking=False)
1501
    self.lock.Unlock(blocking=False)
1502
    self.lock.Close()
1503

    
1504
  def testExclusiveSharedUnlock(self):
1505
    self.lock.Exclusive(blocking=False)
1506
    self.lock.Shared(blocking=False)
1507
    self.lock.Unlock(blocking=False)
1508
    self.lock.Close()
1509

    
1510
  def testSimpleTimeout(self):
1511
    # These will succeed on the first attempt, hence a short timeout
1512
    self.lock.Shared(blocking=True, timeout=10.0)
1513
    self.lock.Exclusive(blocking=False, timeout=10.0)
1514
    self.lock.Unlock(blocking=True, timeout=10.0)
1515
    self.lock.Close()
1516

    
1517
  @staticmethod
1518
  def _TryLockInner(filename, shared, blocking):
1519
    lock = utils.FileLock.Open(filename)
1520

    
1521
    if shared:
1522
      fn = lock.Shared
1523
    else:
1524
      fn = lock.Exclusive
1525

    
1526
    try:
1527
      # The timeout doesn't really matter as the parent process waits for us to
1528
      # finish anyway.
1529
      fn(blocking=blocking, timeout=0.01)
1530
    except errors.LockError, err:
1531
      return False
1532

    
1533
    return True
1534

    
1535
  def _TryLock(self, *args):
1536
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1537
                                      *args)
1538

    
1539
  def testTimeout(self):
1540
    for blocking in [True, False]:
1541
      self.lock.Exclusive(blocking=True)
1542
      self.failIf(self._TryLock(False, blocking))
1543
      self.failIf(self._TryLock(True, blocking))
1544

    
1545
      self.lock.Shared(blocking=True)
1546
      self.assert_(self._TryLock(True, blocking))
1547
      self.failIf(self._TryLock(False, blocking))
1548

    
1549
  def testCloseShared(self):
1550
    self.lock.Close()
1551
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1552

    
1553
  def testCloseExclusive(self):
1554
    self.lock.Close()
1555
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1556

    
1557
  def testCloseUnlock(self):
1558
    self.lock.Close()
1559
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1560

    
1561

    
1562
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1563
  TESTDATA = "Hello World\n" * 10
1564

    
1565
  def setUp(self):
1566
    testutils.GanetiTestCase.setUp(self)
1567

    
1568
    self.tmpfile = tempfile.NamedTemporaryFile()
1569
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1570
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1571

    
1572
    # Ensure "Open" didn't truncate file
1573
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1574

    
1575
  def tearDown(self):
1576
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1577

    
1578
    testutils.GanetiTestCase.tearDown(self)
1579

    
1580

    
1581
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1582
  def setUp(self):
1583
    self.tmpfile = tempfile.NamedTemporaryFile()
1584
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1585

    
1586

    
1587
class TestTimeFunctions(unittest.TestCase):
1588
  """Test case for time functions"""
1589

    
1590
  def runTest(self):
1591
    self.assertEqual(utils.SplitTime(1), (1, 0))
1592
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1593
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1594
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1595
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1596
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1597
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1598
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1599

    
1600
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1601

    
1602
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1603
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1604
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1605

    
1606
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1607
                     1218448917.481)
1608
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1609

    
1610
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1611
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1612
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1613
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1614
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1615

    
1616

    
1617
class FieldSetTestCase(unittest.TestCase):
1618
  """Test case for FieldSets"""
1619

    
1620
  def testSimpleMatch(self):
1621
    f = utils.FieldSet("a", "b", "c", "def")
1622
    self.failUnless(f.Matches("a"))
1623
    self.failIf(f.Matches("d"), "Substring matched")
1624
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1625
    self.failIf(f.NonMatching(["b", "c"]))
1626
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1627
    self.failUnless(f.NonMatching(["a", "d"]))
1628

    
1629
  def testRegexMatch(self):
1630
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1631
    self.failUnless(f.Matches("b1"))
1632
    self.failUnless(f.Matches("b99"))
1633
    self.failIf(f.Matches("b/1"))
1634
    self.failIf(f.NonMatching(["b12", "c"]))
1635
    self.failUnless(f.NonMatching(["a", "1"]))
1636

    
1637
class TestForceDictType(unittest.TestCase):
1638
  """Test case for ForceDictType"""
1639

    
1640
  def setUp(self):
1641
    self.key_types = {
1642
      'a': constants.VTYPE_INT,
1643
      'b': constants.VTYPE_BOOL,
1644
      'c': constants.VTYPE_STRING,
1645
      'd': constants.VTYPE_SIZE,
1646
      }
1647

    
1648
  def _fdt(self, dict, allowed_values=None):
1649
    if allowed_values is None:
1650
      ForceDictType(dict, self.key_types)
1651
    else:
1652
      ForceDictType(dict, self.key_types, allowed_values=allowed_values)
1653

    
1654
    return dict
1655

    
1656
  def testSimpleDict(self):
1657
    self.assertEqual(self._fdt({}), {})
1658
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1659
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1660
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1661
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1662
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1663
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1664
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1665
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1666
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1667
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1668
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1669

    
1670
  def testErrors(self):
1671
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1672
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1673
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1674
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1675

    
1676

    
1677
class TestIsNormAbsPath(unittest.TestCase):
1678
  """Testing case for IsNormAbsPath"""
1679

    
1680
  def _pathTestHelper(self, path, result):
1681
    if result:
1682
      self.assert_(IsNormAbsPath(path),
1683
          "Path %s should result absolute and normalized" % path)
1684
    else:
1685
      self.assert_(not IsNormAbsPath(path),
1686
          "Path %s should not result absolute and normalized" % path)
1687

    
1688
  def testBase(self):
1689
    self._pathTestHelper('/etc', True)
1690
    self._pathTestHelper('/srv', True)
1691
    self._pathTestHelper('etc', False)
1692
    self._pathTestHelper('/etc/../root', False)
1693
    self._pathTestHelper('/etc/', False)
1694

    
1695

    
1696
class TestSafeEncode(unittest.TestCase):
1697
  """Test case for SafeEncode"""
1698

    
1699
  def testAscii(self):
1700
    for txt in [string.digits, string.letters, string.punctuation]:
1701
      self.failUnlessEqual(txt, SafeEncode(txt))
1702

    
1703
  def testDoubleEncode(self):
1704
    for i in range(255):
1705
      txt = SafeEncode(chr(i))
1706
      self.failUnlessEqual(txt, SafeEncode(txt))
1707

    
1708
  def testUnicode(self):
1709
    # 1024 is high enough to catch non-direct ASCII mappings
1710
    for i in range(1024):
1711
      txt = SafeEncode(unichr(i))
1712
      self.failUnlessEqual(txt, SafeEncode(txt))
1713

    
1714

    
1715
class TestFormatTime(unittest.TestCase):
1716
  """Testing case for FormatTime"""
1717

    
1718
  def testNone(self):
1719
    self.failUnlessEqual(FormatTime(None), "N/A")
1720

    
1721
  def testInvalid(self):
1722
    self.failUnlessEqual(FormatTime(()), "N/A")
1723

    
1724
  def testNow(self):
1725
    # tests that we accept time.time input
1726
    FormatTime(time.time())
1727
    # tests that we accept int input
1728
    FormatTime(int(time.time()))
1729

    
1730

    
1731
class RunInSeparateProcess(unittest.TestCase):
1732
  def test(self):
1733
    for exp in [True, False]:
1734
      def _child():
1735
        return exp
1736

    
1737
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1738

    
1739
  def testArgs(self):
1740
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1741
      def _child(carg1, carg2):
1742
        return carg1 == "Foo" and carg2 == arg
1743

    
1744
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1745

    
1746
  def testPid(self):
1747
    parent_pid = os.getpid()
1748

    
1749
    def _check():
1750
      return os.getpid() == parent_pid
1751

    
1752
    self.failIf(utils.RunInSeparateProcess(_check))
1753

    
1754
  def testSignal(self):
1755
    def _kill():
1756
      os.kill(os.getpid(), signal.SIGTERM)
1757

    
1758
    self.assertRaises(errors.GenericError,
1759
                      utils.RunInSeparateProcess, _kill)
1760

    
1761
  def testException(self):
1762
    def _exc():
1763
      raise errors.GenericError("This is a test")
1764

    
1765
    self.assertRaises(errors.GenericError,
1766
                      utils.RunInSeparateProcess, _exc)
1767

    
1768

    
1769
class TestFingerprintFile(unittest.TestCase):
1770
  def setUp(self):
1771
    self.tmpfile = tempfile.NamedTemporaryFile()
1772

    
1773
  def test(self):
1774
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1775
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")
1776

    
1777
    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
1778
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1779
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1780

    
1781

    
1782
class TestUnescapeAndSplit(unittest.TestCase):
1783
  """Testing case for UnescapeAndSplit"""
1784

    
1785
  def setUp(self):
1786
    # testing more that one separator for regexp safety
1787
    self._seps = [",", "+", "."]
1788

    
1789
  def testSimple(self):
1790
    a = ["a", "b", "c", "d"]
1791
    for sep in self._seps:
1792
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1793

    
1794
  def testEscape(self):
1795
    for sep in self._seps:
1796
      a = ["a", "b\\" + sep + "c", "d"]
1797
      b = ["a", "b" + sep + "c", "d"]
1798
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1799

    
1800
  def testDoubleEscape(self):
1801
    for sep in self._seps:
1802
      a = ["a", "b\\\\", "c", "d"]
1803
      b = ["a", "b\\", "c", "d"]
1804
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1805

    
1806
  def testThreeEscape(self):
1807
    for sep in self._seps:
1808
      a = ["a", "b\\\\\\" + sep + "c", "d"]
1809
      b = ["a", "b\\" + sep + "c", "d"]
1810
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1811

    
1812

    
1813
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1814
  def setUp(self):
1815
    self.tmpdir = tempfile.mkdtemp()
1816

    
1817
  def tearDown(self):
1818
    shutil.rmtree(self.tmpdir)
1819

    
1820
  def _checkRsaPrivateKey(self, key):
1821
    lines = key.splitlines()
1822
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1823
            "-----END RSA PRIVATE KEY-----" in lines)
1824

    
1825
  def _checkCertificate(self, cert):
1826
    lines = cert.splitlines()
1827
    return ("-----BEGIN CERTIFICATE-----" in lines and
1828
            "-----END CERTIFICATE-----" in lines)
1829

    
1830
  def test(self):
1831
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1832
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1833
      self._checkRsaPrivateKey(key_pem)
1834
      self._checkCertificate(cert_pem)
1835

    
1836
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1837
                                           key_pem)
1838
      self.assert_(key.bits() >= 1024)
1839
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1840
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1841

    
1842
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1843
                                             cert_pem)
1844
      self.failIf(x509.has_expired())
1845
      self.assertEqual(x509.get_issuer().CN, common_name)
1846
      self.assertEqual(x509.get_subject().CN, common_name)
1847
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1848

    
1849
  def testLegacy(self):
1850
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1851

    
1852
    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1853

    
1854
    cert1 = utils.ReadFile(cert1_filename)
1855

    
1856
    self.assert_(self._checkRsaPrivateKey(cert1))
1857
    self.assert_(self._checkCertificate(cert1))
1858

    
1859

    
1860
class TestPathJoin(unittest.TestCase):
1861
  """Testing case for PathJoin"""
1862

    
1863
  def testBasicItems(self):
1864
    mlist = ["/a", "b", "c"]
1865
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1866

    
1867
  def testNonAbsPrefix(self):
1868
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1869

    
1870
  def testBackTrack(self):
1871
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1872

    
1873
  def testMultiAbs(self):
1874
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1875

    
1876

    
1877
class TestHostInfo(unittest.TestCase):
1878
  """Testing case for HostInfo"""
1879

    
1880
  def testUppercase(self):
1881
    data = "AbC.example.com"
1882
    self.failUnlessEqual(HostInfo.NormalizeName(data), data.lower())
1883

    
1884
  def testTooLongName(self):
1885
    data = "a.b." + "c" * 255
1886
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, data)
1887

    
1888
  def testTrailingDot(self):
1889
    data = "a.b.c"
1890
    self.failUnlessEqual(HostInfo.NormalizeName(data + "."), data)
1891

    
1892
  def testInvalidName(self):
1893
    data = [
1894
      "a b",
1895
      "a/b",
1896
      ".a.b",
1897
      "a..b",
1898
      ]
1899
    for value in data:
1900
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, value)
1901

    
1902
  def testValidName(self):
1903
    data = [
1904
      "a.b",
1905
      "a-b",
1906
      "a_b",
1907
      "a.b.c",
1908
      ]
1909
    for value in data:
1910
      HostInfo.NormalizeName(value)
1911

    
1912

    
1913
class TestParseAsn1Generalizedtime(unittest.TestCase):
1914
  def test(self):
1915
    # UTC
1916
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1917
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1918
                     1266860512)
1919
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1920
                     (2**31) - 1)
1921

    
1922
    # With offset
1923
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1924
                     1266860512)
1925
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1926
                     1266931012)
1927
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1928
                     1266931088)
1929
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1930
                     1266931295)
1931
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1932
                     3600)
1933

    
1934
    # Leap seconds are not supported by datetime.datetime
1935
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1936
                      "19841231235960+0000")
1937
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1938
                      "19920630235960+0000")
1939

    
1940
    # Errors
1941
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1942
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1943
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1944
                      "20100222174152")
1945
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1946
                      "Mon Feb 22 17:47:02 UTC 2010")
1947
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1948
                      "2010-02-22 17:42:02")
1949

    
1950

    
1951
class TestGetX509CertValidity(testutils.GanetiTestCase):
1952
  def setUp(self):
1953
    testutils.GanetiTestCase.setUp(self)
1954

    
1955
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1956

    
1957
    # Test whether we have pyOpenSSL 0.7 or above
1958
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1959

    
1960
    if not self.pyopenssl0_7:
1961
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1962
                    " function correctly")
1963

    
1964
  def _LoadCert(self, name):
1965
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1966
                                           self._ReadTestData(name))
1967

    
1968
  def test(self):
1969
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1970
    if self.pyopenssl0_7:
1971
      self.assertEqual(validity, (1266919967, 1267524767))
1972
    else:
1973
      self.assertEqual(validity, (None, None))
1974

    
1975

    
1976
class TestSignX509Certificate(unittest.TestCase):
1977
  KEY = "My private key!"
1978
  KEY_OTHER = "Another key"
1979

    
1980
  def test(self):
1981
    # Generate certificate valid for 5 minutes
1982
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
1983

    
1984
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1985
                                           cert_pem)
1986

    
1987
    # No signature at all
1988
    self.assertRaises(errors.GenericError,
1989
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
1990

    
1991
    # Invalid input
1992
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1993
                      "", self.KEY)
1994
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1995
                      "X-Ganeti-Signature: \n", self.KEY)
1996
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1997
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
1998
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1999
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
2000
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2001
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
2002

    
2003
    # Invalid salt
2004
    for salt in list("-_@$,:;/\\ \t\n"):
2005
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
2006
                        cert_pem, self.KEY, "foo%sbar" % salt)
2007

    
2008
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
2009
                 utils.GenerateSecret(numbytes=4),
2010
                 utils.GenerateSecret(numbytes=16),
2011
                 "{123:456}".encode("hex")]:
2012
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
2013

    
2014
      self._Check(cert, salt, signed_pem)
2015

    
2016
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
2017
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
2018
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
2019
                               "lines----\n------ at\nthe end!"))
2020

    
2021
  def _Check(self, cert, salt, pem):
2022
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
2023
    self.assertEqual(salt, salt2)
2024
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
2025

    
2026
    # Other key
2027
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2028
                      pem, self.KEY_OTHER)
2029

    
2030

    
2031
class TestMakedirs(unittest.TestCase):
2032
  def setUp(self):
2033
    self.tmpdir = tempfile.mkdtemp()
2034

    
2035
  def tearDown(self):
2036
    shutil.rmtree(self.tmpdir)
2037

    
2038
  def testNonExisting(self):
2039
    path = utils.PathJoin(self.tmpdir, "foo")
2040
    utils.Makedirs(path)
2041
    self.assert_(os.path.isdir(path))
2042

    
2043
  def testExisting(self):
2044
    path = utils.PathJoin(self.tmpdir, "foo")
2045
    os.mkdir(path)
2046
    utils.Makedirs(path)
2047
    self.assert_(os.path.isdir(path))
2048

    
2049
  def testRecursiveNonExisting(self):
2050
    path = utils.PathJoin(self.tmpdir, "foo/bar/baz")
2051
    utils.Makedirs(path)
2052
    self.assert_(os.path.isdir(path))
2053

    
2054
  def testRecursiveExisting(self):
2055
    path = utils.PathJoin(self.tmpdir, "B/moo/xyz")
2056
    self.assert_(not os.path.exists(path))
2057
    os.mkdir(utils.PathJoin(self.tmpdir, "B"))
2058
    utils.Makedirs(path)
2059
    self.assert_(os.path.isdir(path))
2060

    
2061

    
2062
class TestRetry(testutils.GanetiTestCase):
2063
  def setUp(self):
2064
    testutils.GanetiTestCase.setUp(self)
2065
    self.retries = 0
2066

    
2067
  @staticmethod
2068
  def _RaiseRetryAgain():
2069
    raise utils.RetryAgain()
2070

    
2071
  @staticmethod
2072
  def _RaiseRetryAgainWithArg(args):
2073
    raise utils.RetryAgain(*args)
2074

    
2075
  def _WrongNestedLoop(self):
2076
    return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
2077

    
2078
  def _RetryAndSucceed(self, retries):
2079
    if self.retries < retries:
2080
      self.retries += 1
2081
      raise utils.RetryAgain()
2082
    else:
2083
      return True
2084

    
2085
  def testRaiseTimeout(self):
2086
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
2087
                          self._RaiseRetryAgain, 0.01, 0.02)
2088
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
2089
                          self._RetryAndSucceed, 0.01, 0, args=[1])
2090
    self.failUnlessEqual(self.retries, 1)
2091

    
2092
  def testComplete(self):
2093
    self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
2094
    self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
2095
                         True)
2096
    self.failUnlessEqual(self.retries, 2)
2097

    
2098
  def testNestedLoop(self):
2099
    try:
2100
      self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
2101
                            self._WrongNestedLoop, 0, 1)
2102
    except utils.RetryTimeout:
2103
      self.fail("Didn't detect inner loop's exception")
2104

    
2105
  def testTimeoutArgument(self):
2106
    retry_arg="my_important_debugging_message"
2107
    try:
2108
      utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
2109
    except utils.RetryTimeout, err:
2110
      self.failUnlessEqual(err.args, (retry_arg, ))
2111
    else:
2112
      self.fail("Expected timeout didn't happen")
2113

    
2114
  def testRaiseInnerWithExc(self):
2115
    retry_arg="my_important_debugging_message"
2116
    try:
2117
      try:
2118
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2119
                    args=[[errors.GenericError(retry_arg, retry_arg)]])
2120
      except utils.RetryTimeout, err:
2121
        err.RaiseInner()
2122
      else:
2123
        self.fail("Expected timeout didn't happen")
2124
    except errors.GenericError, err:
2125
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2126
    else:
2127
      self.fail("Expected GenericError didn't happen")
2128

    
2129
  def testRaiseInnerWithMsg(self):
2130
    retry_arg="my_important_debugging_message"
2131
    try:
2132
      try:
2133
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2134
                    args=[[retry_arg, retry_arg]])
2135
      except utils.RetryTimeout, err:
2136
        err.RaiseInner()
2137
      else:
2138
        self.fail("Expected timeout didn't happen")
2139
    except utils.RetryTimeout, err:
2140
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2141
    else:
2142
      self.fail("Expected RetryTimeout didn't happen")
2143

    
2144

    
2145
class TestLineSplitter(unittest.TestCase):
2146
  def test(self):
2147
    lines = []
2148
    ls = utils.LineSplitter(lines.append)
2149
    ls.write("Hello World\n")
2150
    self.assertEqual(lines, [])
2151
    ls.write("Foo\n Bar\r\n ")
2152
    ls.write("Baz")
2153
    ls.write("Moo")
2154
    self.assertEqual(lines, [])
2155
    ls.flush()
2156
    self.assertEqual(lines, ["Hello World", "Foo", " Bar"])
2157
    ls.close()
2158
    self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"])
2159

    
2160
  def _testExtra(self, line, all_lines, p1, p2):
2161
    self.assertEqual(p1, 999)
2162
    self.assertEqual(p2, "extra")
2163
    all_lines.append(line)
2164

    
2165
  def testExtraArgsNoFlush(self):
2166
    lines = []
2167
    ls = utils.LineSplitter(self._testExtra, lines, 999, "extra")
2168
    ls.write("\n\nHello World\n")
2169
    ls.write("Foo\n Bar\r\n ")
2170
    ls.write("")
2171
    ls.write("Baz")
2172
    ls.write("Moo\n\nx\n")
2173
    self.assertEqual(lines, [])
2174
    ls.close()
2175
    self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo",
2176
                             "", "x"])
2177

    
2178

    
2179
class TestReadLockedPidFile(unittest.TestCase):
2180
  def setUp(self):
2181
    self.tmpdir = tempfile.mkdtemp()
2182

    
2183
  def tearDown(self):
2184
    shutil.rmtree(self.tmpdir)
2185

    
2186
  def testNonExistent(self):
2187
    path = utils.PathJoin(self.tmpdir, "nonexist")
2188
    self.assert_(utils.ReadLockedPidFile(path) is None)
2189

    
2190
  def testUnlocked(self):
2191
    path = utils.PathJoin(self.tmpdir, "pid")
2192
    utils.WriteFile(path, data="123")
2193
    self.assert_(utils.ReadLockedPidFile(path) is None)
2194

    
2195
  def testLocked(self):
2196
    path = utils.PathJoin(self.tmpdir, "pid")
2197
    utils.WriteFile(path, data="123")
2198

    
2199
    fl = utils.FileLock.Open(path)
2200
    try:
2201
      fl.Exclusive(blocking=True)
2202

    
2203
      self.assertEqual(utils.ReadLockedPidFile(path), 123)
2204
    finally:
2205
      fl.Close()
2206

    
2207
    self.assert_(utils.ReadLockedPidFile(path) is None)
2208

    
2209
  def testError(self):
2210
    path = utils.PathJoin(self.tmpdir, "foobar", "pid")
2211
    utils.WriteFile(utils.PathJoin(self.tmpdir, "foobar"), data="")
2212
    # open(2) should return ENOTDIR
2213
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
2214

    
2215

    
2216
class TestCertVerification(testutils.GanetiTestCase):
2217
  def setUp(self):
2218
    testutils.GanetiTestCase.setUp(self)
2219

    
2220
    self.tmpdir = tempfile.mkdtemp()
2221

    
2222
  def tearDown(self):
2223
    shutil.rmtree(self.tmpdir)
2224

    
2225
  def testVerifyCertificate(self):
2226
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
2227
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
2228
                                           cert_pem)
2229

    
2230
    # Not checking return value as this certificate is expired
2231
    utils.VerifyX509Certificate(cert, 30, 7)
2232

    
2233

    
2234
class TestVerifyCertificateInner(unittest.TestCase):
2235
  def test(self):
2236
    vci = utils._VerifyCertificateInner
2237

    
2238
    # Valid
2239
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
2240
                     (None, None))
2241

    
2242
    # Not yet valid
2243
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
2244
    self.assertEqual(errcode, utils.CERT_WARNING)
2245

    
2246
    # Expiring soon
2247
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
2248
    self.assertEqual(errcode, utils.CERT_ERROR)
2249

    
2250
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
2251
    self.assertEqual(errcode, utils.CERT_WARNING)
2252

    
2253
    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
2254
    self.assertEqual(errcode, None)
2255

    
2256
    # Expired
2257
    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
2258
    self.assertEqual(errcode, utils.CERT_ERROR)
2259

    
2260
    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
2261
    self.assertEqual(errcode, utils.CERT_ERROR)
2262

    
2263
    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
2264
    self.assertEqual(errcode, utils.CERT_ERROR)
2265

    
2266
    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
2267
    self.assertEqual(errcode, utils.CERT_ERROR)
2268

    
2269

    
2270
class TestHmacFunctions(unittest.TestCase):
2271
  # Digests can be checked with "openssl sha1 -hmac $key"
2272
  def testSha1Hmac(self):
2273
    self.assertEqual(utils.Sha1Hmac("", ""),
2274
                     "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d")
2275
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World"),
2276
                     "ef4f3bda82212ecb2f7ce868888a19092481f1fd")
2277
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", ""),
2278
                     "f904c2476527c6d3e6609ab683c66fa0652cb1dc")
2279

    
2280
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
2281
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", longtext),
2282
                     "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54")
2283

    
2284
  def testSha1HmacSalt(self):
2285
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc0"),
2286
                     "4999bf342470eadb11dfcd24ca5680cf9fd7cdce")
2287
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc9"),
2288
                     "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8")
2289
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World", salt="xyz0"),
2290
                     "7f264f8114c9066afc9bb7636e1786d996d3cc0d")
2291

    
2292
  def testVerifySha1Hmac(self):
2293
    self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b"
2294
                                               "7d64b71fb76370690e1d")))
2295
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2296
                                      ("f904c2476527c6d3e660"
2297
                                       "9ab683c66fa0652cb1dc")))
2298

    
2299
    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
2300
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", digest))
2301
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2302
                                      digest.lower()))
2303
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2304
                                      digest.upper()))
2305
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2306
                                      digest.title()))
2307

    
2308
  def testVerifySha1HmacSalt(self):
2309
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2310
                                      ("17a4adc34d69c0d367d4"
2311
                                       "ffbef96fd41d4df7a6e8"),
2312
                                      salt="abc9"))
2313
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2314
                                      ("7f264f8114c9066afc9b"
2315
                                       "b7636e1786d996d3cc0d"),
2316
                                      salt="xyz0"))
2317

    
2318

    
2319
class TestIgnoreSignals(unittest.TestCase):
2320
  """Test the IgnoreSignals decorator"""
2321

    
2322
  @staticmethod
2323
  def _Raise(exception):
2324
    raise exception
2325

    
2326
  @staticmethod
2327
  def _Return(rval):
2328
    return rval
2329

    
2330
  def testIgnoreSignals(self):
2331
    sock_err_intr = socket.error(errno.EINTR, "Message")
2332
    sock_err_inval = socket.error(errno.EINVAL, "Message")
2333

    
2334
    env_err_intr = EnvironmentError(errno.EINTR, "Message")
2335
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")
2336

    
2337
    self.assertRaises(socket.error, self._Raise, sock_err_intr)
2338
    self.assertRaises(socket.error, self._Raise, sock_err_inval)
2339
    self.assertRaises(EnvironmentError, self._Raise, env_err_intr)
2340
    self.assertRaises(EnvironmentError, self._Raise, env_err_inval)
2341

    
2342
    self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None)
2343
    self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None)
2344
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
2345
                      sock_err_inval)
2346
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
2347
                      env_err_inval)
2348

    
2349
    self.assertEquals(utils.IgnoreSignals(self._Return, True), True)
2350
    self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33)
2351

    
2352

    
2353
class TestEnsureDirs(unittest.TestCase):
2354
  """Tests for EnsureDirs"""
2355

    
2356
  def setUp(self):
2357
    self.dir = tempfile.mkdtemp()
2358
    self.old_umask = os.umask(0777)
2359

    
2360
  def testEnsureDirs(self):
2361
    utils.EnsureDirs([
2362
        (utils.PathJoin(self.dir, "foo"), 0777),
2363
        (utils.PathJoin(self.dir, "bar"), 0000),
2364
        ])
2365
    self.assertEquals(os.stat(utils.PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2366
    self.assertEquals(os.stat(utils.PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2367

    
2368
  def tearDown(self):
2369
    os.rmdir(utils.PathJoin(self.dir, "foo"))
2370
    os.rmdir(utils.PathJoin(self.dir, "bar"))
2371
    os.rmdir(self.dir)
2372
    os.umask(self.old_umask)
2373

    
2374

    
2375
class TestFormatSeconds(unittest.TestCase):
2376
  def test(self):
2377
    self.assertEqual(utils.FormatSeconds(1), "1s")
2378
    self.assertEqual(utils.FormatSeconds(3600), "1h 0m 0s")
2379
    self.assertEqual(utils.FormatSeconds(3599), "59m 59s")
2380
    self.assertEqual(utils.FormatSeconds(7200), "2h 0m 0s")
2381
    self.assertEqual(utils.FormatSeconds(7201), "2h 0m 1s")
2382
    self.assertEqual(utils.FormatSeconds(7281), "2h 1m 21s")
2383
    self.assertEqual(utils.FormatSeconds(29119), "8h 5m 19s")
2384
    self.assertEqual(utils.FormatSeconds(19431228), "224d 21h 33m 48s")
2385
    self.assertEqual(utils.FormatSeconds(-1), "-1s")
2386
    self.assertEqual(utils.FormatSeconds(-282), "-282s")
2387
    self.assertEqual(utils.FormatSeconds(-29119), "-29119s")
2388

    
2389
  def testFloat(self):
2390
    self.assertEqual(utils.FormatSeconds(1.3), "1s")
2391
    self.assertEqual(utils.FormatSeconds(1.9), "2s")
2392
    self.assertEqual(utils.FormatSeconds(3912.12311), "1h 5m 12s")
2393
    self.assertEqual(utils.FormatSeconds(3912.8), "1h 5m 13s")
2394

    
2395

    
2396
class RunIgnoreProcessNotFound(unittest.TestCase):
2397
  @staticmethod
2398
  def _WritePid(fd):
2399
    os.write(fd, str(os.getpid()))
2400
    os.close(fd)
2401
    return True
2402

    
2403
  def test(self):
2404
    (pid_read_fd, pid_write_fd) = os.pipe()
2405

    
2406
    # Start short-lived process which writes its PID to pipe
2407
    self.assert_(utils.RunInSeparateProcess(self._WritePid, pid_write_fd))
2408
    os.close(pid_write_fd)
2409

    
2410
    # Read PID from pipe
2411
    pid = int(os.read(pid_read_fd, 1024))
2412
    os.close(pid_read_fd)
2413

    
2414
    # Try to send signal to process which exited recently
2415
    self.assertFalse(utils.IgnoreProcessNotFound(os.kill, pid, 0))
2416

    
2417

    
2418
if __name__ == '__main__':
2419
  testutils.GanetiTestProgram()