Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ c0c3fa27

History | View | Annotate | Download (83 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import unittest
25
import os
26
import time
27
import tempfile
28
import os.path
29
import os
30
import stat
31
import signal
32
import socket
33
import shutil
34
import re
35
import select
36
import string
37
import fcntl
38
import OpenSSL
39
import warnings
40
import distutils.version
41
import glob
42
import errno
43

    
44
import ganeti
45
import testutils
46
from ganeti import constants
47
from ganeti import compat
48
from ganeti import utils
49
from ganeti import errors
50
from ganeti import serializer
51
from ganeti.utils import IsProcessAlive, RunCmd, \
52
     RemoveFile, MatchNameComponent, FormatUnit, \
53
     ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
54
     ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
55
     SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
56
     TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
57
     UnescapeAndSplit, RunParts, PathJoin, HostInfo, ReadOneLineFile
58

    
59
from ganeti.errors import LockError, UnitParseError, GenericError, \
60
     ProgrammerError, OpPrereqError
61

    
62

    
63
class TestIsProcessAlive(unittest.TestCase):
64
  """Testing case for IsProcessAlive"""
65

    
66
  def testExists(self):
67
    mypid = os.getpid()
68
    self.assert_(IsProcessAlive(mypid),
69
                 "can't find myself running")
70

    
71
  def testNotExisting(self):
72
    pid_non_existing = os.fork()
73
    if pid_non_existing == 0:
74
      os._exit(0)
75
    elif pid_non_existing < 0:
76
      raise SystemError("can't fork")
77
    os.waitpid(pid_non_existing, 0)
78
    self.assertFalse(IsProcessAlive(pid_non_existing),
79
                     "nonexisting process detected")
80

    
81

    
82
class TestGetProcStatusPath(unittest.TestCase):
83
  def test(self):
84
    self.assert_("/1234/" in utils._GetProcStatusPath(1234))
85
    self.assertNotEqual(utils._GetProcStatusPath(1),
86
                        utils._GetProcStatusPath(2))
87

    
88

    
89
class TestIsProcessHandlingSignal(unittest.TestCase):
90
  def setUp(self):
91
    self.tmpdir = tempfile.mkdtemp()
92

    
93
  def tearDown(self):
94
    shutil.rmtree(self.tmpdir)
95

    
96
  def testParseSigsetT(self):
97
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
98
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
99
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
100
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
101
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
102
                     set([2, 10, 32, 33]))
103
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
104
                     set([2, 32, 33]))
105
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
106
                     set([2, 28, 32, 33]))
107
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
108
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
109
                          24, 25, 26, 28, 31]))
110
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))
111

    
112
  def testGetProcStatusField(self):
113
    for field in ["SigCgt", "Name", "FDSize"]:
114
      for value in ["", "0", "cat", "  1234 KB"]:
115
        pstatus = "\n".join([
116
          "VmPeak: 999 kB",
117
          "%s: %s" % (field, value),
118
          "TracerPid: 0",
119
          ])
120
        result = utils._GetProcStatusField(pstatus, field)
121
        self.assertEqual(result, value.strip())
122

    
123
  def test(self):
124
    sp = utils.PathJoin(self.tmpdir, "status")
125

    
126
    utils.WriteFile(sp, data="\n".join([
127
      "Name:   bash",
128
      "State:  S (sleeping)",
129
      "SleepAVG:       98%",
130
      "Pid:    22250",
131
      "PPid:   10858",
132
      "TracerPid:      0",
133
      "SigBlk: 0000000000010000",
134
      "SigIgn: 0000000000384004",
135
      "SigCgt: 000000004b813efb",
136
      "CapEff: 0000000000000000",
137
      ]))
138

    
139
    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
140

    
141
  def testNoSigCgt(self):
142
    sp = utils.PathJoin(self.tmpdir, "status")
143

    
144
    utils.WriteFile(sp, data="\n".join([
145
      "Name:   bash",
146
      ]))
147

    
148
    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
149
                      1234, 10, status_path=sp)
150

    
151
  def testNoSuchFile(self):
152
    sp = utils.PathJoin(self.tmpdir, "notexist")
153

    
154
    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
155

    
156
  @staticmethod
157
  def _TestRealProcess():
158
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
159
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
160
      raise Exception("SIGUSR1 is handled when it should not be")
161

    
162
    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
163
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
164
      raise Exception("SIGUSR1 is not handled when it should be")
165

    
166
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
167
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
168
      raise Exception("SIGUSR1 is not handled when it should be")
169

    
170
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
171
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
172
      raise Exception("SIGUSR1 is handled when it should not be")
173

    
174
    return True
175

    
176
  def testRealProcess(self):
177
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
178

    
179

    
180
class TestPidFileFunctions(unittest.TestCase):
181
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
182

    
183
  def setUp(self):
184
    self.dir = tempfile.mkdtemp()
185
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
186
    utils.DaemonPidFileName = self.f_dpn
187

    
188
  def testPidFileFunctions(self):
189
    pid_file = self.f_dpn('test')
190
    utils.WritePidFile('test')
191
    self.failUnless(os.path.exists(pid_file),
192
                    "PID file should have been created")
193
    read_pid = utils.ReadPidFile(pid_file)
194
    self.failUnlessEqual(read_pid, os.getpid())
195
    self.failUnless(utils.IsProcessAlive(read_pid))
196
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
197
    utils.RemovePidFile('test')
198
    self.failIf(os.path.exists(pid_file),
199
                "PID file should not exist anymore")
200
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
201
                         "ReadPidFile should return 0 for missing pid file")
202
    fh = open(pid_file, "w")
203
    fh.write("blah\n")
204
    fh.close()
205
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
206
                         "ReadPidFile should return 0 for invalid pid file")
207
    utils.RemovePidFile('test')
208
    self.failIf(os.path.exists(pid_file),
209
                "PID file should not exist anymore")
210

    
211
  def testKill(self):
212
    pid_file = self.f_dpn('child')
213
    r_fd, w_fd = os.pipe()
214
    new_pid = os.fork()
215
    if new_pid == 0: #child
216
      utils.WritePidFile('child')
217
      os.write(w_fd, 'a')
218
      signal.pause()
219
      os._exit(0)
220
      return
221
    # else we are in the parent
222
    # wait until the child has written the pid file
223
    os.read(r_fd, 1)
224
    read_pid = utils.ReadPidFile(pid_file)
225
    self.failUnlessEqual(read_pid, new_pid)
226
    self.failUnless(utils.IsProcessAlive(new_pid))
227
    utils.KillProcess(new_pid, waitpid=True)
228
    self.failIf(utils.IsProcessAlive(new_pid))
229
    utils.RemovePidFile('child')
230
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)
231

    
232
  def tearDown(self):
233
    for name in os.listdir(self.dir):
234
      os.unlink(os.path.join(self.dir, name))
235
    os.rmdir(self.dir)
236

    
237

    
238
class TestRunCmd(testutils.GanetiTestCase):
239
  """Testing case for the RunCmd function"""
240

    
241
  def setUp(self):
242
    testutils.GanetiTestCase.setUp(self)
243
    self.magic = time.ctime() + " ganeti test"
244
    self.fname = self._CreateTempFile()
245

    
246
  def testOk(self):
247
    """Test successful exit code"""
248
    result = RunCmd("/bin/sh -c 'exit 0'")
249
    self.assertEqual(result.exit_code, 0)
250
    self.assertEqual(result.output, "")
251

    
252
  def testFail(self):
253
    """Test fail exit code"""
254
    result = RunCmd("/bin/sh -c 'exit 1'")
255
    self.assertEqual(result.exit_code, 1)
256
    self.assertEqual(result.output, "")
257

    
258
  def testStdout(self):
259
    """Test standard output"""
260
    cmd = 'echo -n "%s"' % self.magic
261
    result = RunCmd("/bin/sh -c '%s'" % cmd)
262
    self.assertEqual(result.stdout, self.magic)
263
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
264
    self.assertEqual(result.output, "")
265
    self.assertFileContent(self.fname, self.magic)
266

    
267
  def testStderr(self):
268
    """Test standard error"""
269
    cmd = 'echo -n "%s"' % self.magic
270
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
271
    self.assertEqual(result.stderr, self.magic)
272
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
273
    self.assertEqual(result.output, "")
274
    self.assertFileContent(self.fname, self.magic)
275

    
276
  def testCombined(self):
277
    """Test combined output"""
278
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
279
    expected = "A" + self.magic + "B" + self.magic
280
    result = RunCmd("/bin/sh -c '%s'" % cmd)
281
    self.assertEqual(result.output, expected)
282
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
283
    self.assertEqual(result.output, "")
284
    self.assertFileContent(self.fname, expected)
285

    
286
  def testSignal(self):
287
    """Test signal"""
288
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
289
    self.assertEqual(result.signal, 15)
290
    self.assertEqual(result.output, "")
291

    
292
  def testListRun(self):
293
    """Test list runs"""
294
    result = RunCmd(["true"])
295
    self.assertEqual(result.signal, None)
296
    self.assertEqual(result.exit_code, 0)
297
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
298
    self.assertEqual(result.signal, None)
299
    self.assertEqual(result.exit_code, 1)
300
    result = RunCmd(["echo", "-n", self.magic])
301
    self.assertEqual(result.signal, None)
302
    self.assertEqual(result.exit_code, 0)
303
    self.assertEqual(result.stdout, self.magic)
304

    
305
  def testFileEmptyOutput(self):
306
    """Test file output"""
307
    result = RunCmd(["true"], output=self.fname)
308
    self.assertEqual(result.signal, None)
309
    self.assertEqual(result.exit_code, 0)
310
    self.assertFileContent(self.fname, "")
311

    
312
  def testLang(self):
313
    """Test locale environment"""
314
    old_env = os.environ.copy()
315
    try:
316
      os.environ["LANG"] = "en_US.UTF-8"
317
      os.environ["LC_ALL"] = "en_US.UTF-8"
318
      result = RunCmd(["locale"])
319
      for line in result.output.splitlines():
320
        key, value = line.split("=", 1)
321
        # Ignore these variables, they're overridden by LC_ALL
322
        if key == "LANG" or key == "LANGUAGE":
323
          continue
324
        self.failIf(value and value != "C" and value != '"C"',
325
            "Variable %s is set to the invalid value '%s'" % (key, value))
326
    finally:
327
      os.environ = old_env
328

    
329
  def testDefaultCwd(self):
330
    """Test default working directory"""
331
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
332

    
333
  def testCwd(self):
334
    """Test default working directory"""
335
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
336
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
337
    cwd = os.getcwd()
338
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
339

    
340
  def testResetEnv(self):
341
    """Test environment reset functionality"""
342
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
343
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
344
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
345

    
346

    
347
class TestRunParts(unittest.TestCase):
348
  """Testing case for the RunParts function"""
349

    
350
  def setUp(self):
351
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
352

    
353
  def tearDown(self):
354
    shutil.rmtree(self.rundir)
355

    
356
  def testEmpty(self):
357
    """Test on an empty dir"""
358
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
359

    
360
  def testSkipWrongName(self):
361
    """Test that wrong files are skipped"""
362
    fname = os.path.join(self.rundir, "00test.dot")
363
    utils.WriteFile(fname, data="")
364
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
365
    relname = os.path.basename(fname)
366
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
367
                         [(relname, constants.RUNPARTS_SKIP, None)])
368

    
369
  def testSkipNonExec(self):
370
    """Test that non executable files are skipped"""
371
    fname = os.path.join(self.rundir, "00test")
372
    utils.WriteFile(fname, data="")
373
    relname = os.path.basename(fname)
374
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
375
                         [(relname, constants.RUNPARTS_SKIP, None)])
376

    
377
  def testError(self):
378
    """Test error on a broken executable"""
379
    fname = os.path.join(self.rundir, "00test")
380
    utils.WriteFile(fname, data="")
381
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
382
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
383
    self.failUnlessEqual(relname, os.path.basename(fname))
384
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
385
    self.failUnless(error)
386

    
387
  def testSorted(self):
388
    """Test executions are sorted"""
389
    files = []
390
    files.append(os.path.join(self.rundir, "64test"))
391
    files.append(os.path.join(self.rundir, "00test"))
392
    files.append(os.path.join(self.rundir, "42test"))
393

    
394
    for fname in files:
395
      utils.WriteFile(fname, data="")
396

    
397
    results = RunParts(self.rundir, reset_env=True)
398

    
399
    for fname in sorted(files):
400
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
401

    
402
  def testOk(self):
403
    """Test correct execution"""
404
    fname = os.path.join(self.rundir, "00test")
405
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
406
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
407
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
408
    self.failUnlessEqual(relname, os.path.basename(fname))
409
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
410
    self.failUnlessEqual(runresult.stdout, "ciao")
411

    
412
  def testRunFail(self):
413
    """Test correct execution, with run failure"""
414
    fname = os.path.join(self.rundir, "00test")
415
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
416
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
417
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
418
    self.failUnlessEqual(relname, os.path.basename(fname))
419
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
420
    self.failUnlessEqual(runresult.exit_code, 1)
421
    self.failUnless(runresult.failed)
422

    
423
  def testRunMix(self):
424
    files = []
425
    files.append(os.path.join(self.rundir, "00test"))
426
    files.append(os.path.join(self.rundir, "42test"))
427
    files.append(os.path.join(self.rundir, "64test"))
428
    files.append(os.path.join(self.rundir, "99test"))
429

    
430
    files.sort()
431

    
432
    # 1st has errors in execution
433
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
434
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
435

    
436
    # 2nd is skipped
437
    utils.WriteFile(files[1], data="")
438

    
439
    # 3rd cannot execute properly
440
    utils.WriteFile(files[2], data="")
441
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
442

    
443
    # 4th execs
444
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
445
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
446

    
447
    results = RunParts(self.rundir, reset_env=True)
448

    
449
    (relname, status, runresult) = results[0]
450
    self.failUnlessEqual(relname, os.path.basename(files[0]))
451
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
452
    self.failUnlessEqual(runresult.exit_code, 1)
453
    self.failUnless(runresult.failed)
454

    
455
    (relname, status, runresult) = results[1]
456
    self.failUnlessEqual(relname, os.path.basename(files[1]))
457
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
458
    self.failUnlessEqual(runresult, None)
459

    
460
    (relname, status, runresult) = results[2]
461
    self.failUnlessEqual(relname, os.path.basename(files[2]))
462
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
463
    self.failUnless(runresult)
464

    
465
    (relname, status, runresult) = results[3]
466
    self.failUnlessEqual(relname, os.path.basename(files[3]))
467
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
468
    self.failUnlessEqual(runresult.output, "ciao")
469
    self.failUnlessEqual(runresult.exit_code, 0)
470
    self.failUnless(not runresult.failed)
471

    
472

    
473
class TestStartDaemon(testutils.GanetiTestCase):
474
  def setUp(self):
475
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
476
    self.tmpfile = os.path.join(self.tmpdir, "test")
477

    
478
  def tearDown(self):
479
    shutil.rmtree(self.tmpdir)
480

    
481
  def testShell(self):
482
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
483
    self._wait(self.tmpfile, 60.0, "Hello World")
484

    
485
  def testShellOutput(self):
486
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
487
    self._wait(self.tmpfile, 60.0, "Hello World")
488

    
489
  def testNoShellNoOutput(self):
490
    utils.StartDaemon(["pwd"])
491

    
492
  def testNoShellNoOutputTouch(self):
493
    testfile = os.path.join(self.tmpdir, "check")
494
    self.failIf(os.path.exists(testfile))
495
    utils.StartDaemon(["touch", testfile])
496
    self._wait(testfile, 60.0, "")
497

    
498
  def testNoShellOutput(self):
499
    utils.StartDaemon(["pwd"], output=self.tmpfile)
500
    self._wait(self.tmpfile, 60.0, "/")
501

    
502
  def testNoShellOutputCwd(self):
503
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
504
    self._wait(self.tmpfile, 60.0, os.getcwd())
505

    
506
  def testShellEnv(self):
507
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
508
                      env={ "GNT_TEST_VAR": "Hello World", })
509
    self._wait(self.tmpfile, 60.0, "Hello World")
510

    
511
  def testNoShellEnv(self):
512
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
513
                      env={ "GNT_TEST_VAR": "Hello World", })
514
    self._wait(self.tmpfile, 60.0, "Hello World")
515

    
516
  def testOutputFd(self):
517
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
518
    try:
519
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
520
    finally:
521
      os.close(fd)
522
    self._wait(self.tmpfile, 60.0, os.getcwd())
523

    
524
  def testPid(self):
525
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
526
    self._wait(self.tmpfile, 60.0, str(pid))
527

    
528
  def testPidFile(self):
529
    pidfile = os.path.join(self.tmpdir, "pid")
530
    checkfile = os.path.join(self.tmpdir, "abort")
531

    
532
    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
533
                            output=self.tmpfile)
534
    try:
535
      fd = os.open(pidfile, os.O_RDONLY)
536
      try:
537
        # Check file is locked
538
        self.assertRaises(errors.LockError, utils.LockFile, fd)
539

    
540
        pidtext = os.read(fd, 100)
541
      finally:
542
        os.close(fd)
543

    
544
      self.assertEqual(int(pidtext.strip()), pid)
545

    
546
      self.assert_(utils.IsProcessAlive(pid))
547
    finally:
548
      # No matter what happens, kill daemon
549
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
550
      self.failIf(utils.IsProcessAlive(pid))
551

    
552
    self.assertEqual(utils.ReadFile(self.tmpfile), "")
553

    
554
  def _wait(self, path, timeout, expected):
555
    # Due to the asynchronous nature of daemon processes, polling is necessary.
556
    # A timeout makes sure the test doesn't hang forever.
557
    def _CheckFile():
558
      if not (os.path.isfile(path) and
559
              utils.ReadFile(path).strip() == expected):
560
        raise utils.RetryAgain()
561

    
562
    try:
563
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
564
    except utils.RetryTimeout:
565
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
566
                " didn't write the correct output" % timeout)
567

    
568
  def testError(self):
569
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
570
                      ["./does-NOT-EXIST/here/0123456789"])
571
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
572
                      ["./does-NOT-EXIST/here/0123456789"],
573
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
574
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
575
                      ["./does-NOT-EXIST/here/0123456789"],
576
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
577
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
578
                      ["./does-NOT-EXIST/here/0123456789"],
579
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
580

    
581
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
582
    try:
583
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
584
                        ["./does-NOT-EXIST/here/0123456789"],
585
                        output=self.tmpfile, output_fd=fd)
586
    finally:
587
      os.close(fd)
588

    
589

    
590
class TestSetCloseOnExecFlag(unittest.TestCase):
591
  """Tests for SetCloseOnExecFlag"""
592

    
593
  def setUp(self):
594
    self.tmpfile = tempfile.TemporaryFile()
595

    
596
  def testEnable(self):
597
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
598
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
599
                    fcntl.FD_CLOEXEC)
600

    
601
  def testDisable(self):
602
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
603
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
604
                fcntl.FD_CLOEXEC)
605

    
606

    
607
class TestSetNonblockFlag(unittest.TestCase):
608
  def setUp(self):
609
    self.tmpfile = tempfile.TemporaryFile()
610

    
611
  def testEnable(self):
612
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
613
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
614
                    os.O_NONBLOCK)
615

    
616
  def testDisable(self):
617
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
618
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
619
                os.O_NONBLOCK)
620

    
621

    
622
class TestRemoveFile(unittest.TestCase):
623
  """Test case for the RemoveFile function"""
624

    
625
  def setUp(self):
626
    """Create a temp dir and file for each case"""
627
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
628
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
629
    os.close(fd)
630

    
631
  def tearDown(self):
632
    if os.path.exists(self.tmpfile):
633
      os.unlink(self.tmpfile)
634
    os.rmdir(self.tmpdir)
635

    
636
  def testIgnoreDirs(self):
637
    """Test that RemoveFile() ignores directories"""
638
    self.assertEqual(None, RemoveFile(self.tmpdir))
639

    
640
  def testIgnoreNotExisting(self):
641
    """Test that RemoveFile() ignores non-existing files"""
642
    RemoveFile(self.tmpfile)
643
    RemoveFile(self.tmpfile)
644

    
645
  def testRemoveFile(self):
646
    """Test that RemoveFile does remove a file"""
647
    RemoveFile(self.tmpfile)
648
    if os.path.exists(self.tmpfile):
649
      self.fail("File '%s' not removed" % self.tmpfile)
650

    
651
  def testRemoveSymlink(self):
652
    """Test that RemoveFile does remove symlinks"""
653
    symlink = self.tmpdir + "/symlink"
654
    os.symlink("no-such-file", symlink)
655
    RemoveFile(symlink)
656
    if os.path.exists(symlink):
657
      self.fail("File '%s' not removed" % symlink)
658
    os.symlink(self.tmpfile, symlink)
659
    RemoveFile(symlink)
660
    if os.path.exists(symlink):
661
      self.fail("File '%s' not removed" % symlink)
662

    
663

    
664
class TestRename(unittest.TestCase):
665
  """Test case for RenameFile"""
666

    
667
  def setUp(self):
668
    """Create a temporary directory"""
669
    self.tmpdir = tempfile.mkdtemp()
670
    self.tmpfile = os.path.join(self.tmpdir, "test1")
671

    
672
    # Touch the file
673
    open(self.tmpfile, "w").close()
674

    
675
  def tearDown(self):
676
    """Remove temporary directory"""
677
    shutil.rmtree(self.tmpdir)
678

    
679
  def testSimpleRename1(self):
680
    """Simple rename 1"""
681
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
682
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
683

    
684
  def testSimpleRename2(self):
685
    """Simple rename 2"""
686
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
687
                     mkdir=True)
688
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
689

    
690
  def testRenameMkdir(self):
691
    """Rename with mkdir"""
692
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
693
                     mkdir=True)
694
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
695
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
696

    
697
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
698
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
699
                     mkdir=True)
700
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
701
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
702
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
703

    
704

    
705
class TestMatchNameComponent(unittest.TestCase):
706
  """Test case for the MatchNameComponent function"""
707

    
708
  def testEmptyList(self):
709
    """Test that there is no match against an empty list"""
710

    
711
    self.failUnlessEqual(MatchNameComponent("", []), None)
712
    self.failUnlessEqual(MatchNameComponent("test", []), None)
713

    
714
  def testSingleMatch(self):
715
    """Test that a single match is performed correctly"""
716
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
717
    for key in "test2", "test2.example", "test2.example.com":
718
      self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
719

    
720
  def testMultipleMatches(self):
721
    """Test that a multiple match is returned as None"""
722
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
723
    for key in "test1", "test1.example":
724
      self.failUnlessEqual(MatchNameComponent(key, mlist), None)
725

    
726
  def testFullMatch(self):
727
    """Test that a full match is returned correctly"""
728
    key1 = "test1"
729
    key2 = "test1.example"
730
    mlist = [key2, key2 + ".com"]
731
    self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
732
    self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
733

    
734
  def testCaseInsensitivePartialMatch(self):
735
    """Test for the case_insensitive keyword"""
736
    mlist = ["test1.example.com", "test2.example.net"]
737
    self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
738
                     "test2.example.net")
739
    self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
740
                     "test2.example.net")
741
    self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
742
                     "test2.example.net")
743
    self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
744
                     "test2.example.net")
745

    
746

    
747
  def testCaseInsensitiveFullMatch(self):
748
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
749
    # Between the two ts1 a full string match non-case insensitive should work
750
    self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
751
                     None)
752
    self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
753
                     "ts1.ex")
754
    self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
755
                     "ts1.ex")
756
    # Between the two ts2 only case differs, so only case-match works
757
    self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
758
                     "ts2.ex")
759
    self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
760
                     "Ts2.ex")
761
    self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
762
                     None)
763

    
764

    
765
class TestReadFile(testutils.GanetiTestCase):
766

    
767
  def testReadAll(self):
768
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
769
    self.assertEqual(len(data), 814)
770

    
771
    h = compat.md5_hash()
772
    h.update(data)
773
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
774

    
775
  def testReadSize(self):
776
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
777
                          size=100)
778
    self.assertEqual(len(data), 100)
779

    
780
    h = compat.md5_hash()
781
    h.update(data)
782
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
783

    
784
  def testError(self):
785
    self.assertRaises(EnvironmentError, utils.ReadFile,
786
                      "/dev/null/does-not-exist")
787

    
788

    
789
class TestReadOneLineFile(testutils.GanetiTestCase):
790

    
791
  def setUp(self):
792
    testutils.GanetiTestCase.setUp(self)
793

    
794
  def testDefault(self):
795
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
796
    self.assertEqual(len(data), 27)
797
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
798

    
799
  def testNotStrict(self):
800
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
801
    self.assertEqual(len(data), 27)
802
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
803

    
804
  def testStrictFailure(self):
805
    self.assertRaises(errors.GenericError, ReadOneLineFile,
806
                      self._TestDataFilename("cert1.pem"), strict=True)
807

    
808
  def testLongLine(self):
809
    dummydata = (1024 * "Hello World! ")
810
    myfile = self._CreateTempFile()
811
    utils.WriteFile(myfile, data=dummydata)
812
    datastrict = ReadOneLineFile(myfile, strict=True)
813
    datalax = ReadOneLineFile(myfile, strict=False)
814
    self.assertEqual(dummydata, datastrict)
815
    self.assertEqual(dummydata, datalax)
816

    
817
  def testNewline(self):
818
    myfile = self._CreateTempFile()
819
    myline = "myline"
820
    for nl in ["", "\n", "\r\n"]:
821
      dummydata = "%s%s" % (myline, nl)
822
      utils.WriteFile(myfile, data=dummydata)
823
      datalax = ReadOneLineFile(myfile, strict=False)
824
      self.assertEqual(myline, datalax)
825
      datastrict = ReadOneLineFile(myfile, strict=True)
826
      self.assertEqual(myline, datastrict)
827

    
828
  def testWhitespaceAndMultipleLines(self):
829
    myfile = self._CreateTempFile()
830
    for nl in ["", "\n", "\r\n"]:
831
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
832
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
833
        utils.WriteFile(myfile, data=dummydata)
834
        datalax = ReadOneLineFile(myfile, strict=False)
835
        if nl:
836
          self.assert_(set("\r\n") & set(dummydata))
837
          self.assertRaises(errors.GenericError, ReadOneLineFile,
838
                            myfile, strict=True)
839
          explen = len("Foo bar baz ") + len(ws)
840
          self.assertEqual(len(datalax), explen)
841
          self.assertEqual(datalax, dummydata[:explen])
842
          self.assertFalse(set("\r\n") & set(datalax))
843
        else:
844
          datastrict = ReadOneLineFile(myfile, strict=True)
845
          self.assertEqual(dummydata, datastrict)
846
          self.assertEqual(dummydata, datalax)
847

    
848
  def testEmptylines(self):
849
    myfile = self._CreateTempFile()
850
    myline = "myline"
851
    for nl in ["\n", "\r\n"]:
852
      for ol in ["", "otherline"]:
853
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
854
        utils.WriteFile(myfile, data=dummydata)
855
        self.assert_(set("\r\n") & set(dummydata))
856
        datalax = ReadOneLineFile(myfile, strict=False)
857
        self.assertEqual(myline, datalax)
858
        if ol:
859
          self.assertRaises(errors.GenericError, ReadOneLineFile,
860
                            myfile, strict=True)
861
        else:
862
          datastrict = ReadOneLineFile(myfile, strict=True)
863
          self.assertEqual(myline, datastrict)
864

    
865

    
866
class TestTimestampForFilename(unittest.TestCase):
867
  def test(self):
868
    self.assert_("." not in utils.TimestampForFilename())
869
    self.assert_(":" not in utils.TimestampForFilename())
870

    
871

    
872
class TestCreateBackup(testutils.GanetiTestCase):
873
  def setUp(self):
874
    testutils.GanetiTestCase.setUp(self)
875

    
876
    self.tmpdir = tempfile.mkdtemp()
877

    
878
  def tearDown(self):
879
    testutils.GanetiTestCase.tearDown(self)
880

    
881
    shutil.rmtree(self.tmpdir)
882

    
883
  def testEmpty(self):
884
    filename = utils.PathJoin(self.tmpdir, "config.data")
885
    utils.WriteFile(filename, data="")
886
    bname = utils.CreateBackup(filename)
887
    self.assertFileContent(bname, "")
888
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
889
    utils.CreateBackup(filename)
890
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
891
    utils.CreateBackup(filename)
892
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
893

    
894
    fifoname = utils.PathJoin(self.tmpdir, "fifo")
895
    os.mkfifo(fifoname)
896
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
897

    
898
  def testContent(self):
899
    bkpcount = 0
900
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
901
      for rep in [1, 2, 10, 127]:
902
        testdata = data * rep
903

    
904
        filename = utils.PathJoin(self.tmpdir, "test.data_")
905
        utils.WriteFile(filename, data=testdata)
906
        self.assertFileContent(filename, testdata)
907

    
908
        for _ in range(3):
909
          bname = utils.CreateBackup(filename)
910
          bkpcount += 1
911
          self.assertFileContent(bname, testdata)
912
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
913

    
914

    
915
class TestFormatUnit(unittest.TestCase):
916
  """Test case for the FormatUnit function"""
917

    
918
  def testMiB(self):
919
    self.assertEqual(FormatUnit(1, 'h'), '1M')
920
    self.assertEqual(FormatUnit(100, 'h'), '100M')
921
    self.assertEqual(FormatUnit(1023, 'h'), '1023M')
922

    
923
    self.assertEqual(FormatUnit(1, 'm'), '1')
924
    self.assertEqual(FormatUnit(100, 'm'), '100')
925
    self.assertEqual(FormatUnit(1023, 'm'), '1023')
926

    
927
    self.assertEqual(FormatUnit(1024, 'm'), '1024')
928
    self.assertEqual(FormatUnit(1536, 'm'), '1536')
929
    self.assertEqual(FormatUnit(17133, 'm'), '17133')
930
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
931

    
932
  def testGiB(self):
933
    self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
934
    self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
935
    self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
936
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
937

    
938
    self.assertEqual(FormatUnit(1024, 'g'), '1.0')
939
    self.assertEqual(FormatUnit(1536, 'g'), '1.5')
940
    self.assertEqual(FormatUnit(17133, 'g'), '16.7')
941
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
942

    
943
    self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
944
    self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
945
    self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
946

    
947
  def testTiB(self):
948
    self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
949
    self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
950
    self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
951

    
952
    self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
953
    self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
954
    self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
955

    
956
class TestParseUnit(unittest.TestCase):
957
  """Test case for the ParseUnit function"""
958

    
959
  SCALES = (('', 1),
960
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
961
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
962
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
963

    
964
  def testRounding(self):
965
    self.assertEqual(ParseUnit('0'), 0)
966
    self.assertEqual(ParseUnit('1'), 4)
967
    self.assertEqual(ParseUnit('2'), 4)
968
    self.assertEqual(ParseUnit('3'), 4)
969

    
970
    self.assertEqual(ParseUnit('124'), 124)
971
    self.assertEqual(ParseUnit('125'), 128)
972
    self.assertEqual(ParseUnit('126'), 128)
973
    self.assertEqual(ParseUnit('127'), 128)
974
    self.assertEqual(ParseUnit('128'), 128)
975
    self.assertEqual(ParseUnit('129'), 132)
976
    self.assertEqual(ParseUnit('130'), 132)
977

    
978
  def testFloating(self):
979
    self.assertEqual(ParseUnit('0'), 0)
980
    self.assertEqual(ParseUnit('0.5'), 4)
981
    self.assertEqual(ParseUnit('1.75'), 4)
982
    self.assertEqual(ParseUnit('1.99'), 4)
983
    self.assertEqual(ParseUnit('2.00'), 4)
984
    self.assertEqual(ParseUnit('2.01'), 4)
985
    self.assertEqual(ParseUnit('3.99'), 4)
986
    self.assertEqual(ParseUnit('4.00'), 4)
987
    self.assertEqual(ParseUnit('4.01'), 8)
988
    self.assertEqual(ParseUnit('1.5G'), 1536)
989
    self.assertEqual(ParseUnit('1.8G'), 1844)
990
    self.assertEqual(ParseUnit('8.28T'), 8682212)
991

    
992
  def testSuffixes(self):
993
    for sep in ('', ' ', '   ', "\t", "\t "):
994
      for suffix, scale in TestParseUnit.SCALES:
995
        for func in (lambda x: x, str.lower, str.upper):
996
          self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
997
                           1024 * scale)
998

    
999
  def testInvalidInput(self):
1000
    for sep in ('-', '_', ',', 'a'):
1001
      for suffix, _ in TestParseUnit.SCALES:
1002
        self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)
1003

    
1004
    for suffix, _ in TestParseUnit.SCALES:
1005
      self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
1006

    
1007

    
1008
class TestSshKeys(testutils.GanetiTestCase):
1009
  """Test case for the AddAuthorizedKey function"""
1010

    
1011
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
1012
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
1013
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
1014

    
1015
  def setUp(self):
1016
    testutils.GanetiTestCase.setUp(self)
1017
    self.tmpname = self._CreateTempFile()
1018
    handle = open(self.tmpname, 'w')
1019
    try:
1020
      handle.write("%s\n" % TestSshKeys.KEY_A)
1021
      handle.write("%s\n" % TestSshKeys.KEY_B)
1022
    finally:
1023
      handle.close()
1024

    
1025
  def testAddingNewKey(self):
1026
    AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
1027

    
1028
    self.assertFileContent(self.tmpname,
1029
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1030
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1031
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1032
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
1033

    
1034
  def testAddingAlmostButNotCompletelyTheSameKey(self):
1035
    AddAuthorizedKey(self.tmpname,
1036
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
1037

    
1038
    self.assertFileContent(self.tmpname,
1039
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1040
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1041
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1042
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
1043

    
1044
  def testAddingExistingKeyWithSomeMoreSpaces(self):
1045
    AddAuthorizedKey(self.tmpname,
1046
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1047

    
1048
    self.assertFileContent(self.tmpname,
1049
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1050
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1051
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1052

    
1053
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
1054
    RemoveAuthorizedKey(self.tmpname,
1055
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1056

    
1057
    self.assertFileContent(self.tmpname,
1058
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1059
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1060

    
1061
  def testRemovingNonExistingKey(self):
1062
    RemoveAuthorizedKey(self.tmpname,
1063
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
1064

    
1065
    self.assertFileContent(self.tmpname,
1066
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1067
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1068
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1069

    
1070

    
1071
class TestEtcHosts(testutils.GanetiTestCase):
1072
  """Test functions modifying /etc/hosts"""
1073

    
1074
  def setUp(self):
1075
    testutils.GanetiTestCase.setUp(self)
1076
    self.tmpname = self._CreateTempFile()
1077
    handle = open(self.tmpname, 'w')
1078
    try:
1079
      handle.write('# This is a test file for /etc/hosts\n')
1080
      handle.write('127.0.0.1\tlocalhost\n')
1081
      handle.write('192.168.1.1 router gw\n')
1082
    finally:
1083
      handle.close()
1084

    
1085
  def testSettingNewIp(self):
1086
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
1087

    
1088
    self.assertFileContent(self.tmpname,
1089
      "# This is a test file for /etc/hosts\n"
1090
      "127.0.0.1\tlocalhost\n"
1091
      "192.168.1.1 router gw\n"
1092
      "1.2.3.4\tmyhost.domain.tld myhost\n")
1093
    self.assertFileMode(self.tmpname, 0644)
1094

    
1095
  def testSettingExistingIp(self):
1096
    SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
1097
                     ['myhost'])
1098

    
1099
    self.assertFileContent(self.tmpname,
1100
      "# This is a test file for /etc/hosts\n"
1101
      "127.0.0.1\tlocalhost\n"
1102
      "192.168.1.1\tmyhost.domain.tld myhost\n")
1103
    self.assertFileMode(self.tmpname, 0644)
1104

    
1105
  def testSettingDuplicateName(self):
1106
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
1107

    
1108
    self.assertFileContent(self.tmpname,
1109
      "# This is a test file for /etc/hosts\n"
1110
      "127.0.0.1\tlocalhost\n"
1111
      "192.168.1.1 router gw\n"
1112
      "1.2.3.4\tmyhost\n")
1113
    self.assertFileMode(self.tmpname, 0644)
1114

    
1115
  def testRemovingExistingHost(self):
1116
    RemoveEtcHostsEntry(self.tmpname, 'router')
1117

    
1118
    self.assertFileContent(self.tmpname,
1119
      "# This is a test file for /etc/hosts\n"
1120
      "127.0.0.1\tlocalhost\n"
1121
      "192.168.1.1 gw\n")
1122
    self.assertFileMode(self.tmpname, 0644)
1123

    
1124
  def testRemovingSingleExistingHost(self):
1125
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
1126

    
1127
    self.assertFileContent(self.tmpname,
1128
      "# This is a test file for /etc/hosts\n"
1129
      "192.168.1.1 router gw\n")
1130
    self.assertFileMode(self.tmpname, 0644)
1131

    
1132
  def testRemovingNonExistingHost(self):
1133
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
1134

    
1135
    self.assertFileContent(self.tmpname,
1136
      "# This is a test file for /etc/hosts\n"
1137
      "127.0.0.1\tlocalhost\n"
1138
      "192.168.1.1 router gw\n")
1139
    self.assertFileMode(self.tmpname, 0644)
1140

    
1141
  def testRemovingAlias(self):
1142
    RemoveEtcHostsEntry(self.tmpname, 'gw')
1143

    
1144
    self.assertFileContent(self.tmpname,
1145
      "# This is a test file for /etc/hosts\n"
1146
      "127.0.0.1\tlocalhost\n"
1147
      "192.168.1.1 router\n")
1148
    self.assertFileMode(self.tmpname, 0644)
1149

    
1150

    
1151
class TestShellQuoting(unittest.TestCase):
1152
  """Test case for shell quoting functions"""
1153

    
1154
  def testShellQuote(self):
1155
    self.assertEqual(ShellQuote('abc'), "abc")
1156
    self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
1157
    self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
1158
    self.assertEqual(ShellQuote("a b c"), "'a b c'")
1159
    self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
1160

    
1161
  def testShellQuoteArgs(self):
1162
    self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
1163
    self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
1164
    self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
1165

    
1166

    
1167
class TestTcpPing(unittest.TestCase):
1168
  """Testcase for TCP version of ping - against listen(2)ing port"""
1169

    
1170
  def setUp(self):
1171
    self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1172
    self.listener.bind((constants.IP4_ADDRESS_LOCALHOST, 0))
1173
    self.listenerport = self.listener.getsockname()[1]
1174
    self.listener.listen(1)
1175

    
1176
  def tearDown(self):
1177
    self.listener.shutdown(socket.SHUT_RDWR)
1178
    del self.listener
1179
    del self.listenerport
1180

    
1181
  def testTcpPingToLocalHostAccept(self):
1182
    self.assert_(TcpPing(constants.IP4_ADDRESS_LOCALHOST,
1183
                         self.listenerport,
1184
                         timeout=10,
1185
                         live_port_needed=True,
1186
                         source=constants.IP4_ADDRESS_LOCALHOST,
1187
                         ),
1188
                 "failed to connect to test listener")
1189

    
1190
    self.assert_(TcpPing(constants.IP4_ADDRESS_LOCALHOST,
1191
                         self.listenerport,
1192
                         timeout=10,
1193
                         live_port_needed=True,
1194
                         ),
1195
                 "failed to connect to test listener (no source)")
1196

    
1197

    
1198
class TestTcpPingDeaf(unittest.TestCase):
1199
  """Testcase for TCP version of ping - against non listen(2)ing port"""
1200

    
1201
  def setUp(self):
1202
    self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1203
    self.deaflistener.bind((constants.IP4_ADDRESS_LOCALHOST, 0))
1204
    self.deaflistenerport = self.deaflistener.getsockname()[1]
1205

    
1206
  def tearDown(self):
1207
    del self.deaflistener
1208
    del self.deaflistenerport
1209

    
1210
  def testTcpPingToLocalHostAcceptDeaf(self):
1211
    self.failIf(TcpPing(constants.IP4_ADDRESS_LOCALHOST,
1212
                        self.deaflistenerport,
1213
                        timeout=constants.TCP_PING_TIMEOUT,
1214
                        live_port_needed=True,
1215
                        source=constants.IP4_ADDRESS_LOCALHOST,
1216
                        ), # need successful connect(2)
1217
                "successfully connected to deaf listener")
1218

    
1219
    self.failIf(TcpPing(constants.IP4_ADDRESS_LOCALHOST,
1220
                        self.deaflistenerport,
1221
                        timeout=constants.TCP_PING_TIMEOUT,
1222
                        live_port_needed=True,
1223
                        ), # need successful connect(2)
1224
                "successfully connected to deaf listener (no source addr)")
1225

    
1226
  def testTcpPingToLocalHostNoAccept(self):
1227
    self.assert_(TcpPing(constants.IP4_ADDRESS_LOCALHOST,
1228
                         self.deaflistenerport,
1229
                         timeout=constants.TCP_PING_TIMEOUT,
1230
                         live_port_needed=False,
1231
                         source=constants.IP4_ADDRESS_LOCALHOST,
1232
                         ), # ECONNREFUSED is OK
1233
                 "failed to ping alive host on deaf port")
1234

    
1235
    self.assert_(TcpPing(constants.IP4_ADDRESS_LOCALHOST,
1236
                         self.deaflistenerport,
1237
                         timeout=constants.TCP_PING_TIMEOUT,
1238
                         live_port_needed=False,
1239
                         ), # ECONNREFUSED is OK
1240
                 "failed to ping alive host on deaf port (no source addr)")
1241

    
1242

    
1243
class TestOwnIpAddress(unittest.TestCase):
1244
  """Testcase for OwnIpAddress"""
1245

    
1246
  def testOwnLoopback(self):
1247
    """check having the loopback ip"""
1248
    self.failUnless(OwnIpAddress(constants.IP4_ADDRESS_LOCALHOST),
1249
                    "Should own the loopback address")
1250

    
1251
  def testNowOwnAddress(self):
1252
    """check that I don't own an address"""
1253

    
1254
    # Network 192.0.2.0/24 is reserved for test/documentation as per
1255
    # RFC 5735, so we *should* not have an address of this range... if
1256
    # this fails, we should extend the test to multiple addresses
1257
    DST_IP = "192.0.2.1"
1258
    self.failIf(OwnIpAddress(DST_IP), "Should not own IP address %s" % DST_IP)
1259

    
1260

    
1261
def _GetSocketCredentials(path):
1262
  """Connect to a Unix socket and return remote credentials.
1263

1264
  """
1265
  sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1266
  try:
1267
    sock.settimeout(10)
1268
    sock.connect(path)
1269
    return utils.GetSocketCredentials(sock)
1270
  finally:
1271
    sock.close()
1272

    
1273

    
1274
class TestGetSocketCredentials(unittest.TestCase):
1275
  def setUp(self):
1276
    self.tmpdir = tempfile.mkdtemp()
1277
    self.sockpath = utils.PathJoin(self.tmpdir, "sock")
1278

    
1279
    self.listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1280
    self.listener.settimeout(10)
1281
    self.listener.bind(self.sockpath)
1282
    self.listener.listen(1)
1283

    
1284
  def tearDown(self):
1285
    self.listener.shutdown(socket.SHUT_RDWR)
1286
    self.listener.close()
1287
    shutil.rmtree(self.tmpdir)
1288

    
1289
  def test(self):
1290
    (c2pr, c2pw) = os.pipe()
1291

    
1292
    # Start child process
1293
    child = os.fork()
1294
    if child == 0:
1295
      try:
1296
        data = serializer.DumpJson(_GetSocketCredentials(self.sockpath))
1297

    
1298
        os.write(c2pw, data)
1299
        os.close(c2pw)
1300

    
1301
        os._exit(0)
1302
      finally:
1303
        os._exit(1)
1304

    
1305
    os.close(c2pw)
1306

    
1307
    # Wait for one connection
1308
    (conn, _) = self.listener.accept()
1309
    conn.recv(1)
1310
    conn.close()
1311

    
1312
    # Wait for result
1313
    result = os.read(c2pr, 4096)
1314
    os.close(c2pr)
1315

    
1316
    # Check child's exit code
1317
    (_, status) = os.waitpid(child, 0)
1318
    self.assertFalse(os.WIFSIGNALED(status))
1319
    self.assertEqual(os.WEXITSTATUS(status), 0)
1320

    
1321
    # Check result
1322
    (pid, uid, gid) = serializer.LoadJson(result)
1323
    self.assertEqual(pid, os.getpid())
1324
    self.assertEqual(uid, os.getuid())
1325
    self.assertEqual(gid, os.getgid())
1326

    
1327

    
1328
class TestListVisibleFiles(unittest.TestCase):
1329
  """Test case for ListVisibleFiles"""
1330

    
1331
  def setUp(self):
1332
    self.path = tempfile.mkdtemp()
1333

    
1334
  def tearDown(self):
1335
    shutil.rmtree(self.path)
1336

    
1337
  def _CreateFiles(self, files):
1338
    for name in files:
1339
      utils.WriteFile(os.path.join(self.path, name), data="test")
1340

    
1341
  def _test(self, files, expected):
1342
    self._CreateFiles(files)
1343
    found = ListVisibleFiles(self.path)
1344
    self.assertEqual(set(found), set(expected))
1345

    
1346
  def testAllVisible(self):
1347
    files = ["a", "b", "c"]
1348
    expected = files
1349
    self._test(files, expected)
1350

    
1351
  def testNoneVisible(self):
1352
    files = [".a", ".b", ".c"]
1353
    expected = []
1354
    self._test(files, expected)
1355

    
1356
  def testSomeVisible(self):
1357
    files = ["a", "b", ".c"]
1358
    expected = ["a", "b"]
1359
    self._test(files, expected)
1360

    
1361
  def testNonAbsolutePath(self):
1362
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1363

    
1364
  def testNonNormalizedPath(self):
1365
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1366
                          "/bin/../tmp")
1367

    
1368

    
1369
class TestNewUUID(unittest.TestCase):
1370
  """Test case for NewUUID"""
1371

    
1372
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
1373
                        '[a-f0-9]{4}-[a-f0-9]{12}$')
1374

    
1375
  def runTest(self):
1376
    self.failUnless(self._re_uuid.match(utils.NewUUID()))
1377

    
1378

    
1379
class TestUniqueSequence(unittest.TestCase):
1380
  """Test case for UniqueSequence"""
1381

    
1382
  def _test(self, input, expected):
1383
    self.assertEqual(utils.UniqueSequence(input), expected)
1384

    
1385
  def runTest(self):
1386
    # Ordered input
1387
    self._test([1, 2, 3], [1, 2, 3])
1388
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
1389
    self._test([1, 2, 2, 3], [1, 2, 3])
1390
    self._test([1, 2, 3, 3], [1, 2, 3])
1391

    
1392
    # Unordered input
1393
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
1394
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
1395

    
1396
    # Strings
1397
    self._test(["a", "a"], ["a"])
1398
    self._test(["a", "b"], ["a", "b"])
1399
    self._test(["a", "b", "a"], ["a", "b"])
1400

    
1401

    
1402
class TestFirstFree(unittest.TestCase):
1403
  """Test case for the FirstFree function"""
1404

    
1405
  def test(self):
1406
    """Test FirstFree"""
1407
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1408
    self.failUnlessEqual(FirstFree([]), None)
1409
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1410
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1411
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1412

    
1413

    
1414
class TestTailFile(testutils.GanetiTestCase):
1415
  """Test case for the TailFile function"""
1416

    
1417
  def testEmpty(self):
1418
    fname = self._CreateTempFile()
1419
    self.failUnlessEqual(TailFile(fname), [])
1420
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1421

    
1422
  def testAllLines(self):
1423
    data = ["test %d" % i for i in range(30)]
1424
    for i in range(30):
1425
      fname = self._CreateTempFile()
1426
      fd = open(fname, "w")
1427
      fd.write("\n".join(data[:i]))
1428
      if i > 0:
1429
        fd.write("\n")
1430
      fd.close()
1431
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1432

    
1433
  def testPartialLines(self):
1434
    data = ["test %d" % i for i in range(30)]
1435
    fname = self._CreateTempFile()
1436
    fd = open(fname, "w")
1437
    fd.write("\n".join(data))
1438
    fd.write("\n")
1439
    fd.close()
1440
    for i in range(1, 30):
1441
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1442

    
1443
  def testBigFile(self):
1444
    data = ["test %d" % i for i in range(30)]
1445
    fname = self._CreateTempFile()
1446
    fd = open(fname, "w")
1447
    fd.write("X" * 1048576)
1448
    fd.write("\n")
1449
    fd.write("\n".join(data))
1450
    fd.write("\n")
1451
    fd.close()
1452
    for i in range(1, 30):
1453
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1454

    
1455

    
1456
class _BaseFileLockTest:
1457
  """Test case for the FileLock class"""
1458

    
1459
  def testSharedNonblocking(self):
1460
    self.lock.Shared(blocking=False)
1461
    self.lock.Close()
1462

    
1463
  def testExclusiveNonblocking(self):
1464
    self.lock.Exclusive(blocking=False)
1465
    self.lock.Close()
1466

    
1467
  def testUnlockNonblocking(self):
1468
    self.lock.Unlock(blocking=False)
1469
    self.lock.Close()
1470

    
1471
  def testSharedBlocking(self):
1472
    self.lock.Shared(blocking=True)
1473
    self.lock.Close()
1474

    
1475
  def testExclusiveBlocking(self):
1476
    self.lock.Exclusive(blocking=True)
1477
    self.lock.Close()
1478

    
1479
  def testUnlockBlocking(self):
1480
    self.lock.Unlock(blocking=True)
1481
    self.lock.Close()
1482

    
1483
  def testSharedExclusiveUnlock(self):
1484
    self.lock.Shared(blocking=False)
1485
    self.lock.Exclusive(blocking=False)
1486
    self.lock.Unlock(blocking=False)
1487
    self.lock.Close()
1488

    
1489
  def testExclusiveSharedUnlock(self):
1490
    self.lock.Exclusive(blocking=False)
1491
    self.lock.Shared(blocking=False)
1492
    self.lock.Unlock(blocking=False)
1493
    self.lock.Close()
1494

    
1495
  def testSimpleTimeout(self):
1496
    # These will succeed on the first attempt, hence a short timeout
1497
    self.lock.Shared(blocking=True, timeout=10.0)
1498
    self.lock.Exclusive(blocking=False, timeout=10.0)
1499
    self.lock.Unlock(blocking=True, timeout=10.0)
1500
    self.lock.Close()
1501

    
1502
  @staticmethod
1503
  def _TryLockInner(filename, shared, blocking):
1504
    lock = utils.FileLock.Open(filename)
1505

    
1506
    if shared:
1507
      fn = lock.Shared
1508
    else:
1509
      fn = lock.Exclusive
1510

    
1511
    try:
1512
      # The timeout doesn't really matter as the parent process waits for us to
1513
      # finish anyway.
1514
      fn(blocking=blocking, timeout=0.01)
1515
    except errors.LockError, err:
1516
      return False
1517

    
1518
    return True
1519

    
1520
  def _TryLock(self, *args):
1521
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1522
                                      *args)
1523

    
1524
  def testTimeout(self):
1525
    for blocking in [True, False]:
1526
      self.lock.Exclusive(blocking=True)
1527
      self.failIf(self._TryLock(False, blocking))
1528
      self.failIf(self._TryLock(True, blocking))
1529

    
1530
      self.lock.Shared(blocking=True)
1531
      self.assert_(self._TryLock(True, blocking))
1532
      self.failIf(self._TryLock(False, blocking))
1533

    
1534
  def testCloseShared(self):
1535
    self.lock.Close()
1536
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1537

    
1538
  def testCloseExclusive(self):
1539
    self.lock.Close()
1540
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1541

    
1542
  def testCloseUnlock(self):
1543
    self.lock.Close()
1544
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1545

    
1546

    
1547
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1548
  TESTDATA = "Hello World\n" * 10
1549

    
1550
  def setUp(self):
1551
    testutils.GanetiTestCase.setUp(self)
1552

    
1553
    self.tmpfile = tempfile.NamedTemporaryFile()
1554
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1555
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1556

    
1557
    # Ensure "Open" didn't truncate file
1558
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1559

    
1560
  def tearDown(self):
1561
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1562

    
1563
    testutils.GanetiTestCase.tearDown(self)
1564

    
1565

    
1566
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1567
  def setUp(self):
1568
    self.tmpfile = tempfile.NamedTemporaryFile()
1569
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1570

    
1571

    
1572
class TestTimeFunctions(unittest.TestCase):
1573
  """Test case for time functions"""
1574

    
1575
  def runTest(self):
1576
    self.assertEqual(utils.SplitTime(1), (1, 0))
1577
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1578
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1579
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1580
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1581
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1582
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1583
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1584

    
1585
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1586

    
1587
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1588
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1589
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1590

    
1591
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1592
                     1218448917.481)
1593
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1594

    
1595
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1596
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1597
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1598
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1599
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1600

    
1601

    
1602
class FieldSetTestCase(unittest.TestCase):
1603
  """Test case for FieldSets"""
1604

    
1605
  def testSimpleMatch(self):
1606
    f = utils.FieldSet("a", "b", "c", "def")
1607
    self.failUnless(f.Matches("a"))
1608
    self.failIf(f.Matches("d"), "Substring matched")
1609
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1610
    self.failIf(f.NonMatching(["b", "c"]))
1611
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1612
    self.failUnless(f.NonMatching(["a", "d"]))
1613

    
1614
  def testRegexMatch(self):
1615
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1616
    self.failUnless(f.Matches("b1"))
1617
    self.failUnless(f.Matches("b99"))
1618
    self.failIf(f.Matches("b/1"))
1619
    self.failIf(f.NonMatching(["b12", "c"]))
1620
    self.failUnless(f.NonMatching(["a", "1"]))
1621

    
1622
class TestForceDictType(unittest.TestCase):
1623
  """Test case for ForceDictType"""
1624

    
1625
  def setUp(self):
1626
    self.key_types = {
1627
      'a': constants.VTYPE_INT,
1628
      'b': constants.VTYPE_BOOL,
1629
      'c': constants.VTYPE_STRING,
1630
      'd': constants.VTYPE_SIZE,
1631
      }
1632

    
1633
  def _fdt(self, dict, allowed_values=None):
1634
    if allowed_values is None:
1635
      ForceDictType(dict, self.key_types)
1636
    else:
1637
      ForceDictType(dict, self.key_types, allowed_values=allowed_values)
1638

    
1639
    return dict
1640

    
1641
  def testSimpleDict(self):
1642
    self.assertEqual(self._fdt({}), {})
1643
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1644
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1645
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1646
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1647
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1648
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1649
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1650
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1651
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1652
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1653
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1654

    
1655
  def testErrors(self):
1656
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1657
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1658
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1659
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1660

    
1661

    
1662
class TestIsNormAbsPath(unittest.TestCase):
1663
  """Testing case for IsNormAbsPath"""
1664

    
1665
  def _pathTestHelper(self, path, result):
1666
    if result:
1667
      self.assert_(IsNormAbsPath(path),
1668
          "Path %s should result absolute and normalized" % path)
1669
    else:
1670
      self.assertFalse(IsNormAbsPath(path),
1671
          "Path %s should not result absolute and normalized" % path)
1672

    
1673
  def testBase(self):
1674
    self._pathTestHelper('/etc', True)
1675
    self._pathTestHelper('/srv', True)
1676
    self._pathTestHelper('etc', False)
1677
    self._pathTestHelper('/etc/../root', False)
1678
    self._pathTestHelper('/etc/', False)
1679

    
1680

    
1681
class TestSafeEncode(unittest.TestCase):
1682
  """Test case for SafeEncode"""
1683

    
1684
  def testAscii(self):
1685
    for txt in [string.digits, string.letters, string.punctuation]:
1686
      self.failUnlessEqual(txt, SafeEncode(txt))
1687

    
1688
  def testDoubleEncode(self):
1689
    for i in range(255):
1690
      txt = SafeEncode(chr(i))
1691
      self.failUnlessEqual(txt, SafeEncode(txt))
1692

    
1693
  def testUnicode(self):
1694
    # 1024 is high enough to catch non-direct ASCII mappings
1695
    for i in range(1024):
1696
      txt = SafeEncode(unichr(i))
1697
      self.failUnlessEqual(txt, SafeEncode(txt))
1698

    
1699

    
1700
class TestFormatTime(unittest.TestCase):
1701
  """Testing case for FormatTime"""
1702

    
1703
  def testNone(self):
1704
    self.failUnlessEqual(FormatTime(None), "N/A")
1705

    
1706
  def testInvalid(self):
1707
    self.failUnlessEqual(FormatTime(()), "N/A")
1708

    
1709
  def testNow(self):
1710
    # tests that we accept time.time input
1711
    FormatTime(time.time())
1712
    # tests that we accept int input
1713
    FormatTime(int(time.time()))
1714

    
1715

    
1716
class RunInSeparateProcess(unittest.TestCase):
1717
  def test(self):
1718
    for exp in [True, False]:
1719
      def _child():
1720
        return exp
1721

    
1722
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1723

    
1724
  def testArgs(self):
1725
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1726
      def _child(carg1, carg2):
1727
        return carg1 == "Foo" and carg2 == arg
1728

    
1729
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1730

    
1731
  def testPid(self):
1732
    parent_pid = os.getpid()
1733

    
1734
    def _check():
1735
      return os.getpid() == parent_pid
1736

    
1737
    self.failIf(utils.RunInSeparateProcess(_check))
1738

    
1739
  def testSignal(self):
1740
    def _kill():
1741
      os.kill(os.getpid(), signal.SIGTERM)
1742

    
1743
    self.assertRaises(errors.GenericError,
1744
                      utils.RunInSeparateProcess, _kill)
1745

    
1746
  def testException(self):
1747
    def _exc():
1748
      raise errors.GenericError("This is a test")
1749

    
1750
    self.assertRaises(errors.GenericError,
1751
                      utils.RunInSeparateProcess, _exc)
1752

    
1753

    
1754
class TestFingerprintFile(unittest.TestCase):
1755
  def setUp(self):
1756
    self.tmpfile = tempfile.NamedTemporaryFile()
1757

    
1758
  def test(self):
1759
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1760
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")
1761

    
1762
    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
1763
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1764
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1765

    
1766

    
1767
class TestUnescapeAndSplit(unittest.TestCase):
1768
  """Testing case for UnescapeAndSplit"""
1769

    
1770
  def setUp(self):
1771
    # testing more that one separator for regexp safety
1772
    self._seps = [",", "+", "."]
1773

    
1774
  def testSimple(self):
1775
    a = ["a", "b", "c", "d"]
1776
    for sep in self._seps:
1777
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1778

    
1779
  def testEscape(self):
1780
    for sep in self._seps:
1781
      a = ["a", "b\\" + sep + "c", "d"]
1782
      b = ["a", "b" + sep + "c", "d"]
1783
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1784

    
1785
  def testDoubleEscape(self):
1786
    for sep in self._seps:
1787
      a = ["a", "b\\\\", "c", "d"]
1788
      b = ["a", "b\\", "c", "d"]
1789
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1790

    
1791
  def testThreeEscape(self):
1792
    for sep in self._seps:
1793
      a = ["a", "b\\\\\\" + sep + "c", "d"]
1794
      b = ["a", "b\\" + sep + "c", "d"]
1795
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1796

    
1797

    
1798
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1799
  def setUp(self):
1800
    self.tmpdir = tempfile.mkdtemp()
1801

    
1802
  def tearDown(self):
1803
    shutil.rmtree(self.tmpdir)
1804

    
1805
  def _checkRsaPrivateKey(self, key):
1806
    lines = key.splitlines()
1807
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1808
            "-----END RSA PRIVATE KEY-----" in lines)
1809

    
1810
  def _checkCertificate(self, cert):
1811
    lines = cert.splitlines()
1812
    return ("-----BEGIN CERTIFICATE-----" in lines and
1813
            "-----END CERTIFICATE-----" in lines)
1814

    
1815
  def test(self):
1816
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1817
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1818
      self._checkRsaPrivateKey(key_pem)
1819
      self._checkCertificate(cert_pem)
1820

    
1821
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1822
                                           key_pem)
1823
      self.assert_(key.bits() >= 1024)
1824
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1825
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1826

    
1827
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1828
                                             cert_pem)
1829
      self.failIf(x509.has_expired())
1830
      self.assertEqual(x509.get_issuer().CN, common_name)
1831
      self.assertEqual(x509.get_subject().CN, common_name)
1832
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1833

    
1834
  def testLegacy(self):
1835
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1836

    
1837
    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1838

    
1839
    cert1 = utils.ReadFile(cert1_filename)
1840

    
1841
    self.assert_(self._checkRsaPrivateKey(cert1))
1842
    self.assert_(self._checkCertificate(cert1))
1843

    
1844

    
1845
class TestPathJoin(unittest.TestCase):
1846
  """Testing case for PathJoin"""
1847

    
1848
  def testBasicItems(self):
1849
    mlist = ["/a", "b", "c"]
1850
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1851

    
1852
  def testNonAbsPrefix(self):
1853
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1854

    
1855
  def testBackTrack(self):
1856
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1857

    
1858
  def testMultiAbs(self):
1859
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1860

    
1861

    
1862
class TestHostInfo(unittest.TestCase):
1863
  """Testing case for HostInfo"""
1864

    
1865
  def testUppercase(self):
1866
    data = "AbC.example.com"
1867
    self.failUnlessEqual(HostInfo.NormalizeName(data), data.lower())
1868

    
1869
  def testTooLongName(self):
1870
    data = "a.b." + "c" * 255
1871
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, data)
1872

    
1873
  def testTrailingDot(self):
1874
    data = "a.b.c"
1875
    self.failUnlessEqual(HostInfo.NormalizeName(data + "."), data)
1876

    
1877
  def testInvalidName(self):
1878
    data = [
1879
      "a b",
1880
      "a/b",
1881
      ".a.b",
1882
      "a..b",
1883
      ]
1884
    for value in data:
1885
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, value)
1886

    
1887
  def testValidName(self):
1888
    data = [
1889
      "a.b",
1890
      "a-b",
1891
      "a_b",
1892
      "a.b.c",
1893
      ]
1894
    for value in data:
1895
      HostInfo.NormalizeName(value)
1896

    
1897

    
1898
class TestValidateServiceName(unittest.TestCase):
1899
  def testValid(self):
1900
    testnames = [
1901
      0, 1, 2, 3, 1024, 65000, 65534, 65535,
1902
      "ganeti",
1903
      "gnt-masterd",
1904
      "HELLO_WORLD_SVC",
1905
      "hello.world.1",
1906
      "0", "80", "1111", "65535",
1907
      ]
1908

    
1909
    for name in testnames:
1910
      self.assertEqual(utils.ValidateServiceName(name), name)
1911

    
1912
  def testInvalid(self):
1913
    testnames = [
1914
      -15756, -1, 65536, 133428083,
1915
      "", "Hello World!", "!", "'", "\"", "\t", "\n", "`",
1916
      "-8546", "-1", "65536",
1917
      (129 * "A"),
1918
      ]
1919

    
1920
    for name in testnames:
1921
      self.assertRaises(OpPrereqError, utils.ValidateServiceName, name)
1922

    
1923

    
1924
class TestParseAsn1Generalizedtime(unittest.TestCase):
1925
  def test(self):
1926
    # UTC
1927
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1928
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1929
                     1266860512)
1930
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1931
                     (2**31) - 1)
1932

    
1933
    # With offset
1934
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1935
                     1266860512)
1936
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1937
                     1266931012)
1938
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1939
                     1266931088)
1940
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1941
                     1266931295)
1942
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1943
                     3600)
1944

    
1945
    # Leap seconds are not supported by datetime.datetime
1946
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1947
                      "19841231235960+0000")
1948
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1949
                      "19920630235960+0000")
1950

    
1951
    # Errors
1952
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1953
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1954
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1955
                      "20100222174152")
1956
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1957
                      "Mon Feb 22 17:47:02 UTC 2010")
1958
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1959
                      "2010-02-22 17:42:02")
1960

    
1961

    
1962
class TestGetX509CertValidity(testutils.GanetiTestCase):
1963
  def setUp(self):
1964
    testutils.GanetiTestCase.setUp(self)
1965

    
1966
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1967

    
1968
    # Test whether we have pyOpenSSL 0.7 or above
1969
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1970

    
1971
    if not self.pyopenssl0_7:
1972
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1973
                    " function correctly")
1974

    
1975
  def _LoadCert(self, name):
1976
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1977
                                           self._ReadTestData(name))
1978

    
1979
  def test(self):
1980
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1981
    if self.pyopenssl0_7:
1982
      self.assertEqual(validity, (1266919967, 1267524767))
1983
    else:
1984
      self.assertEqual(validity, (None, None))
1985

    
1986

    
1987
class TestSignX509Certificate(unittest.TestCase):
1988
  KEY = "My private key!"
1989
  KEY_OTHER = "Another key"
1990

    
1991
  def test(self):
1992
    # Generate certificate valid for 5 minutes
1993
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
1994

    
1995
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1996
                                           cert_pem)
1997

    
1998
    # No signature at all
1999
    self.assertRaises(errors.GenericError,
2000
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
2001

    
2002
    # Invalid input
2003
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2004
                      "", self.KEY)
2005
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2006
                      "X-Ganeti-Signature: \n", self.KEY)
2007
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2008
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
2009
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2010
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
2011
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2012
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
2013

    
2014
    # Invalid salt
2015
    for salt in list("-_@$,:;/\\ \t\n"):
2016
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
2017
                        cert_pem, self.KEY, "foo%sbar" % salt)
2018

    
2019
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
2020
                 utils.GenerateSecret(numbytes=4),
2021
                 utils.GenerateSecret(numbytes=16),
2022
                 "{123:456}".encode("hex")]:
2023
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
2024

    
2025
      self._Check(cert, salt, signed_pem)
2026

    
2027
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
2028
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
2029
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
2030
                               "lines----\n------ at\nthe end!"))
2031

    
2032
  def _Check(self, cert, salt, pem):
2033
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
2034
    self.assertEqual(salt, salt2)
2035
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
2036

    
2037
    # Other key
2038
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2039
                      pem, self.KEY_OTHER)
2040

    
2041

    
2042
class TestMakedirs(unittest.TestCase):
2043
  def setUp(self):
2044
    self.tmpdir = tempfile.mkdtemp()
2045

    
2046
  def tearDown(self):
2047
    shutil.rmtree(self.tmpdir)
2048

    
2049
  def testNonExisting(self):
2050
    path = utils.PathJoin(self.tmpdir, "foo")
2051
    utils.Makedirs(path)
2052
    self.assert_(os.path.isdir(path))
2053

    
2054
  def testExisting(self):
2055
    path = utils.PathJoin(self.tmpdir, "foo")
2056
    os.mkdir(path)
2057
    utils.Makedirs(path)
2058
    self.assert_(os.path.isdir(path))
2059

    
2060
  def testRecursiveNonExisting(self):
2061
    path = utils.PathJoin(self.tmpdir, "foo/bar/baz")
2062
    utils.Makedirs(path)
2063
    self.assert_(os.path.isdir(path))
2064

    
2065
  def testRecursiveExisting(self):
2066
    path = utils.PathJoin(self.tmpdir, "B/moo/xyz")
2067
    self.assertFalse(os.path.exists(path))
2068
    os.mkdir(utils.PathJoin(self.tmpdir, "B"))
2069
    utils.Makedirs(path)
2070
    self.assert_(os.path.isdir(path))
2071

    
2072

    
2073
class TestRetry(testutils.GanetiTestCase):
2074
  def setUp(self):
2075
    testutils.GanetiTestCase.setUp(self)
2076
    self.retries = 0
2077

    
2078
  @staticmethod
2079
  def _RaiseRetryAgain():
2080
    raise utils.RetryAgain()
2081

    
2082
  @staticmethod
2083
  def _RaiseRetryAgainWithArg(args):
2084
    raise utils.RetryAgain(*args)
2085

    
2086
  def _WrongNestedLoop(self):
2087
    return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
2088

    
2089
  def _RetryAndSucceed(self, retries):
2090
    if self.retries < retries:
2091
      self.retries += 1
2092
      raise utils.RetryAgain()
2093
    else:
2094
      return True
2095

    
2096
  def testRaiseTimeout(self):
2097
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
2098
                          self._RaiseRetryAgain, 0.01, 0.02)
2099
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
2100
                          self._RetryAndSucceed, 0.01, 0, args=[1])
2101
    self.failUnlessEqual(self.retries, 1)
2102

    
2103
  def testComplete(self):
2104
    self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
2105
    self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
2106
                         True)
2107
    self.failUnlessEqual(self.retries, 2)
2108

    
2109
  def testNestedLoop(self):
2110
    try:
2111
      self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
2112
                            self._WrongNestedLoop, 0, 1)
2113
    except utils.RetryTimeout:
2114
      self.fail("Didn't detect inner loop's exception")
2115

    
2116
  def testTimeoutArgument(self):
2117
    retry_arg="my_important_debugging_message"
2118
    try:
2119
      utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
2120
    except utils.RetryTimeout, err:
2121
      self.failUnlessEqual(err.args, (retry_arg, ))
2122
    else:
2123
      self.fail("Expected timeout didn't happen")
2124

    
2125
  def testRaiseInnerWithExc(self):
2126
    retry_arg="my_important_debugging_message"
2127
    try:
2128
      try:
2129
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2130
                    args=[[errors.GenericError(retry_arg, retry_arg)]])
2131
      except utils.RetryTimeout, err:
2132
        err.RaiseInner()
2133
      else:
2134
        self.fail("Expected timeout didn't happen")
2135
    except errors.GenericError, err:
2136
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2137
    else:
2138
      self.fail("Expected GenericError didn't happen")
2139

    
2140
  def testRaiseInnerWithMsg(self):
2141
    retry_arg="my_important_debugging_message"
2142
    try:
2143
      try:
2144
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2145
                    args=[[retry_arg, retry_arg]])
2146
      except utils.RetryTimeout, err:
2147
        err.RaiseInner()
2148
      else:
2149
        self.fail("Expected timeout didn't happen")
2150
    except utils.RetryTimeout, err:
2151
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2152
    else:
2153
      self.fail("Expected RetryTimeout didn't happen")
2154

    
2155

    
2156
class TestLineSplitter(unittest.TestCase):
2157
  def test(self):
2158
    lines = []
2159
    ls = utils.LineSplitter(lines.append)
2160
    ls.write("Hello World\n")
2161
    self.assertEqual(lines, [])
2162
    ls.write("Foo\n Bar\r\n ")
2163
    ls.write("Baz")
2164
    ls.write("Moo")
2165
    self.assertEqual(lines, [])
2166
    ls.flush()
2167
    self.assertEqual(lines, ["Hello World", "Foo", " Bar"])
2168
    ls.close()
2169
    self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"])
2170

    
2171
  def _testExtra(self, line, all_lines, p1, p2):
2172
    self.assertEqual(p1, 999)
2173
    self.assertEqual(p2, "extra")
2174
    all_lines.append(line)
2175

    
2176
  def testExtraArgsNoFlush(self):
2177
    lines = []
2178
    ls = utils.LineSplitter(self._testExtra, lines, 999, "extra")
2179
    ls.write("\n\nHello World\n")
2180
    ls.write("Foo\n Bar\r\n ")
2181
    ls.write("")
2182
    ls.write("Baz")
2183
    ls.write("Moo\n\nx\n")
2184
    self.assertEqual(lines, [])
2185
    ls.close()
2186
    self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo",
2187
                             "", "x"])
2188

    
2189

    
2190
class TestReadLockedPidFile(unittest.TestCase):
2191
  def setUp(self):
2192
    self.tmpdir = tempfile.mkdtemp()
2193

    
2194
  def tearDown(self):
2195
    shutil.rmtree(self.tmpdir)
2196

    
2197
  def testNonExistent(self):
2198
    path = utils.PathJoin(self.tmpdir, "nonexist")
2199
    self.assert_(utils.ReadLockedPidFile(path) is None)
2200

    
2201
  def testUnlocked(self):
2202
    path = utils.PathJoin(self.tmpdir, "pid")
2203
    utils.WriteFile(path, data="123")
2204
    self.assert_(utils.ReadLockedPidFile(path) is None)
2205

    
2206
  def testLocked(self):
2207
    path = utils.PathJoin(self.tmpdir, "pid")
2208
    utils.WriteFile(path, data="123")
2209

    
2210
    fl = utils.FileLock.Open(path)
2211
    try:
2212
      fl.Exclusive(blocking=True)
2213

    
2214
      self.assertEqual(utils.ReadLockedPidFile(path), 123)
2215
    finally:
2216
      fl.Close()
2217

    
2218
    self.assert_(utils.ReadLockedPidFile(path) is None)
2219

    
2220
  def testError(self):
2221
    path = utils.PathJoin(self.tmpdir, "foobar", "pid")
2222
    utils.WriteFile(utils.PathJoin(self.tmpdir, "foobar"), data="")
2223
    # open(2) should return ENOTDIR
2224
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
2225

    
2226

    
2227
class TestCertVerification(testutils.GanetiTestCase):
2228
  def setUp(self):
2229
    testutils.GanetiTestCase.setUp(self)
2230

    
2231
    self.tmpdir = tempfile.mkdtemp()
2232

    
2233
  def tearDown(self):
2234
    shutil.rmtree(self.tmpdir)
2235

    
2236
  def testVerifyCertificate(self):
2237
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
2238
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
2239
                                           cert_pem)
2240

    
2241
    # Not checking return value as this certificate is expired
2242
    utils.VerifyX509Certificate(cert, 30, 7)
2243

    
2244

    
2245
class TestVerifyCertificateInner(unittest.TestCase):
2246
  def test(self):
2247
    vci = utils._VerifyCertificateInner
2248

    
2249
    # Valid
2250
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
2251
                     (None, None))
2252

    
2253
    # Not yet valid
2254
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
2255
    self.assertEqual(errcode, utils.CERT_WARNING)
2256

    
2257
    # Expiring soon
2258
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
2259
    self.assertEqual(errcode, utils.CERT_ERROR)
2260

    
2261
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
2262
    self.assertEqual(errcode, utils.CERT_WARNING)
2263

    
2264
    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
2265
    self.assertEqual(errcode, None)
2266

    
2267
    # Expired
2268
    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
2269
    self.assertEqual(errcode, utils.CERT_ERROR)
2270

    
2271
    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
2272
    self.assertEqual(errcode, utils.CERT_ERROR)
2273

    
2274
    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
2275
    self.assertEqual(errcode, utils.CERT_ERROR)
2276

    
2277
    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
2278
    self.assertEqual(errcode, utils.CERT_ERROR)
2279

    
2280

    
2281
class TestHmacFunctions(unittest.TestCase):
2282
  # Digests can be checked with "openssl sha1 -hmac $key"
2283
  def testSha1Hmac(self):
2284
    self.assertEqual(utils.Sha1Hmac("", ""),
2285
                     "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d")
2286
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World"),
2287
                     "ef4f3bda82212ecb2f7ce868888a19092481f1fd")
2288
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", ""),
2289
                     "f904c2476527c6d3e6609ab683c66fa0652cb1dc")
2290

    
2291
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
2292
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", longtext),
2293
                     "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54")
2294

    
2295
  def testSha1HmacSalt(self):
2296
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc0"),
2297
                     "4999bf342470eadb11dfcd24ca5680cf9fd7cdce")
2298
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc9"),
2299
                     "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8")
2300
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World", salt="xyz0"),
2301
                     "7f264f8114c9066afc9bb7636e1786d996d3cc0d")
2302

    
2303
  def testVerifySha1Hmac(self):
2304
    self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b"
2305
                                               "7d64b71fb76370690e1d")))
2306
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2307
                                      ("f904c2476527c6d3e660"
2308
                                       "9ab683c66fa0652cb1dc")))
2309

    
2310
    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
2311
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", digest))
2312
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2313
                                      digest.lower()))
2314
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2315
                                      digest.upper()))
2316
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2317
                                      digest.title()))
2318

    
2319
  def testVerifySha1HmacSalt(self):
2320
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2321
                                      ("17a4adc34d69c0d367d4"
2322
                                       "ffbef96fd41d4df7a6e8"),
2323
                                      salt="abc9"))
2324
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2325
                                      ("7f264f8114c9066afc9b"
2326
                                       "b7636e1786d996d3cc0d"),
2327
                                      salt="xyz0"))
2328

    
2329

    
2330
class TestIgnoreSignals(unittest.TestCase):
2331
  """Test the IgnoreSignals decorator"""
2332

    
2333
  @staticmethod
2334
  def _Raise(exception):
2335
    raise exception
2336

    
2337
  @staticmethod
2338
  def _Return(rval):
2339
    return rval
2340

    
2341
  def testIgnoreSignals(self):
2342
    sock_err_intr = socket.error(errno.EINTR, "Message")
2343
    sock_err_inval = socket.error(errno.EINVAL, "Message")
2344

    
2345
    env_err_intr = EnvironmentError(errno.EINTR, "Message")
2346
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")
2347

    
2348
    self.assertRaises(socket.error, self._Raise, sock_err_intr)
2349
    self.assertRaises(socket.error, self._Raise, sock_err_inval)
2350
    self.assertRaises(EnvironmentError, self._Raise, env_err_intr)
2351
    self.assertRaises(EnvironmentError, self._Raise, env_err_inval)
2352

    
2353
    self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None)
2354
    self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None)
2355
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
2356
                      sock_err_inval)
2357
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
2358
                      env_err_inval)
2359

    
2360
    self.assertEquals(utils.IgnoreSignals(self._Return, True), True)
2361
    self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33)
2362

    
2363

    
2364
class TestEnsureDirs(unittest.TestCase):
2365
  """Tests for EnsureDirs"""
2366

    
2367
  def setUp(self):
2368
    self.dir = tempfile.mkdtemp()
2369
    self.old_umask = os.umask(0777)
2370

    
2371
  def testEnsureDirs(self):
2372
    utils.EnsureDirs([
2373
        (utils.PathJoin(self.dir, "foo"), 0777),
2374
        (utils.PathJoin(self.dir, "bar"), 0000),
2375
        ])
2376
    self.assertEquals(os.stat(utils.PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2377
    self.assertEquals(os.stat(utils.PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2378

    
2379
  def tearDown(self):
2380
    os.rmdir(utils.PathJoin(self.dir, "foo"))
2381
    os.rmdir(utils.PathJoin(self.dir, "bar"))
2382
    os.rmdir(self.dir)
2383
    os.umask(self.old_umask)
2384

    
2385

    
2386
class TestFormatSeconds(unittest.TestCase):
2387
  def test(self):
2388
    self.assertEqual(utils.FormatSeconds(1), "1s")
2389
    self.assertEqual(utils.FormatSeconds(3600), "1h 0m 0s")
2390
    self.assertEqual(utils.FormatSeconds(3599), "59m 59s")
2391
    self.assertEqual(utils.FormatSeconds(7200), "2h 0m 0s")
2392
    self.assertEqual(utils.FormatSeconds(7201), "2h 0m 1s")
2393
    self.assertEqual(utils.FormatSeconds(7281), "2h 1m 21s")
2394
    self.assertEqual(utils.FormatSeconds(29119), "8h 5m 19s")
2395
    self.assertEqual(utils.FormatSeconds(19431228), "224d 21h 33m 48s")
2396
    self.assertEqual(utils.FormatSeconds(-1), "-1s")
2397
    self.assertEqual(utils.FormatSeconds(-282), "-282s")
2398
    self.assertEqual(utils.FormatSeconds(-29119), "-29119s")
2399

    
2400
  def testFloat(self):
2401
    self.assertEqual(utils.FormatSeconds(1.3), "1s")
2402
    self.assertEqual(utils.FormatSeconds(1.9), "2s")
2403
    self.assertEqual(utils.FormatSeconds(3912.12311), "1h 5m 12s")
2404
    self.assertEqual(utils.FormatSeconds(3912.8), "1h 5m 13s")
2405

    
2406

    
2407
class RunIgnoreProcessNotFound(unittest.TestCase):
2408
  @staticmethod
2409
  def _WritePid(fd):
2410
    os.write(fd, str(os.getpid()))
2411
    os.close(fd)
2412
    return True
2413

    
2414
  def test(self):
2415
    (pid_read_fd, pid_write_fd) = os.pipe()
2416

    
2417
    # Start short-lived process which writes its PID to pipe
2418
    self.assert_(utils.RunInSeparateProcess(self._WritePid, pid_write_fd))
2419
    os.close(pid_write_fd)
2420

    
2421
    # Read PID from pipe
2422
    pid = int(os.read(pid_read_fd, 1024))
2423
    os.close(pid_read_fd)
2424

    
2425
    # Try to send signal to process which exited recently
2426
    self.assertFalse(utils.IgnoreProcessNotFound(os.kill, pid, 0))
2427

    
2428

    
2429
class TestIsValidIP4(unittest.TestCase):
2430
  def test(self):
2431
    self.assert_(utils.IsValidIP4("127.0.0.1"))
2432
    self.assert_(utils.IsValidIP4("0.0.0.0"))
2433
    self.assert_(utils.IsValidIP4("255.255.255.255"))
2434
    self.assertFalse(utils.IsValidIP4("0"))
2435
    self.assertFalse(utils.IsValidIP4("1"))
2436
    self.assertFalse(utils.IsValidIP4("1.1.1"))
2437
    self.assertFalse(utils.IsValidIP4("255.255.255.256"))
2438
    self.assertFalse(utils.IsValidIP4("::1"))
2439

    
2440

    
2441
class TestIsValidIP6(unittest.TestCase):
2442
  def test(self):
2443
    self.assert_(utils.IsValidIP6("::"))
2444
    self.assert_(utils.IsValidIP6("::1"))
2445
    self.assert_(utils.IsValidIP6("1" + (":1" * 7)))
2446
    self.assert_(utils.IsValidIP6("ffff" + (":ffff" * 7)))
2447
    self.assertFalse(utils.IsValidIP6("0"))
2448
    self.assertFalse(utils.IsValidIP6(":1"))
2449
    self.assertFalse(utils.IsValidIP6("f" + (":f" * 6)))
2450
    self.assertFalse(utils.IsValidIP6("fffg" + (":ffff" * 7)))
2451
    self.assertFalse(utils.IsValidIP6("fffff" + (":ffff" * 7)))
2452
    self.assertFalse(utils.IsValidIP6("1" + (":1" * 8)))
2453
    self.assertFalse(utils.IsValidIP6("127.0.0.1"))
2454

    
2455

    
2456
class TestIsValidIP(unittest.TestCase):
2457
  def test(self):
2458
    self.assert_(utils.IsValidIP("0.0.0.0"))
2459
    self.assert_(utils.IsValidIP("127.0.0.1"))
2460
    self.assert_(utils.IsValidIP("::"))
2461
    self.assert_(utils.IsValidIP("::1"))
2462
    self.assertFalse(utils.IsValidIP("0"))
2463
    self.assertFalse(utils.IsValidIP("1.1.1.256"))
2464
    self.assertFalse(utils.IsValidIP("a:g::1"))
2465

    
2466

    
2467
if __name__ == '__main__':
2468
  testutils.GanetiTestProgram()