Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ cf0e6df7

History | View | Annotate | Download (85 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import unittest
25
import os
26
import time
27
import tempfile
28
import os.path
29
import os
30
import stat
31
import signal
32
import socket
33
import shutil
34
import re
35
import select
36
import string
37
import fcntl
38
import OpenSSL
39
import warnings
40
import distutils.version
41
import glob
42
import errno
43

    
44
import ganeti
45
import testutils
46
from ganeti import constants
47
from ganeti import compat
48
from ganeti import utils
49
from ganeti import errors
50
from ganeti import serializer
51
from ganeti.utils import IsProcessAlive, RunCmd, \
52
     RemoveFile, MatchNameComponent, FormatUnit, \
53
     ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
54
     ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
55
     SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
56
     TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
57
     UnescapeAndSplit, RunParts, PathJoin, HostInfo, ReadOneLineFile
58

    
59
from ganeti.errors import LockError, UnitParseError, GenericError, \
60
     ProgrammerError, OpPrereqError
61

    
62

    
63
class TestIsProcessAlive(unittest.TestCase):
64
  """Testing case for IsProcessAlive"""
65

    
66
  def testExists(self):
67
    mypid = os.getpid()
68
    self.assert_(IsProcessAlive(mypid),
69
                 "can't find myself running")
70

    
71
  def testNotExisting(self):
72
    pid_non_existing = os.fork()
73
    if pid_non_existing == 0:
74
      os._exit(0)
75
    elif pid_non_existing < 0:
76
      raise SystemError("can't fork")
77
    os.waitpid(pid_non_existing, 0)
78
    self.assertFalse(IsProcessAlive(pid_non_existing),
79
                     "nonexisting process detected")
80

    
81

    
82
class TestGetProcStatusPath(unittest.TestCase):
83
  def test(self):
84
    self.assert_("/1234/" in utils._GetProcStatusPath(1234))
85
    self.assertNotEqual(utils._GetProcStatusPath(1),
86
                        utils._GetProcStatusPath(2))
87

    
88

    
89
class TestIsProcessHandlingSignal(unittest.TestCase):
90
  def setUp(self):
91
    self.tmpdir = tempfile.mkdtemp()
92

    
93
  def tearDown(self):
94
    shutil.rmtree(self.tmpdir)
95

    
96
  def testParseSigsetT(self):
97
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
98
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
99
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
100
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
101
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
102
                     set([2, 10, 32, 33]))
103
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
104
                     set([2, 32, 33]))
105
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
106
                     set([2, 28, 32, 33]))
107
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
108
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
109
                          24, 25, 26, 28, 31]))
110
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))
111

    
112
  def testGetProcStatusField(self):
113
    for field in ["SigCgt", "Name", "FDSize"]:
114
      for value in ["", "0", "cat", "  1234 KB"]:
115
        pstatus = "\n".join([
116
          "VmPeak: 999 kB",
117
          "%s: %s" % (field, value),
118
          "TracerPid: 0",
119
          ])
120
        result = utils._GetProcStatusField(pstatus, field)
121
        self.assertEqual(result, value.strip())
122

    
123
  def test(self):
124
    sp = utils.PathJoin(self.tmpdir, "status")
125

    
126
    utils.WriteFile(sp, data="\n".join([
127
      "Name:   bash",
128
      "State:  S (sleeping)",
129
      "SleepAVG:       98%",
130
      "Pid:    22250",
131
      "PPid:   10858",
132
      "TracerPid:      0",
133
      "SigBlk: 0000000000010000",
134
      "SigIgn: 0000000000384004",
135
      "SigCgt: 000000004b813efb",
136
      "CapEff: 0000000000000000",
137
      ]))
138

    
139
    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
140

    
141
  def testNoSigCgt(self):
142
    sp = utils.PathJoin(self.tmpdir, "status")
143

    
144
    utils.WriteFile(sp, data="\n".join([
145
      "Name:   bash",
146
      ]))
147

    
148
    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
149
                      1234, 10, status_path=sp)
150

    
151
  def testNoSuchFile(self):
152
    sp = utils.PathJoin(self.tmpdir, "notexist")
153

    
154
    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
155

    
156
  @staticmethod
157
  def _TestRealProcess():
158
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
159
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
160
      raise Exception("SIGUSR1 is handled when it should not be")
161

    
162
    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
163
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
164
      raise Exception("SIGUSR1 is not handled when it should be")
165

    
166
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
167
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
168
      raise Exception("SIGUSR1 is not handled when it should be")
169

    
170
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
171
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
172
      raise Exception("SIGUSR1 is handled when it should not be")
173

    
174
    return True
175

    
176
  def testRealProcess(self):
177
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
178

    
179

    
180
class TestPidFileFunctions(unittest.TestCase):
181
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
182

    
183
  def setUp(self):
184
    self.dir = tempfile.mkdtemp()
185
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
186
    utils.DaemonPidFileName = self.f_dpn
187

    
188
  def testPidFileFunctions(self):
189
    pid_file = self.f_dpn('test')
190
    utils.WritePidFile('test')
191
    self.failUnless(os.path.exists(pid_file),
192
                    "PID file should have been created")
193
    read_pid = utils.ReadPidFile(pid_file)
194
    self.failUnlessEqual(read_pid, os.getpid())
195
    self.failUnless(utils.IsProcessAlive(read_pid))
196
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
197
    utils.RemovePidFile('test')
198
    self.failIf(os.path.exists(pid_file),
199
                "PID file should not exist anymore")
200
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
201
                         "ReadPidFile should return 0 for missing pid file")
202
    fh = open(pid_file, "w")
203
    fh.write("blah\n")
204
    fh.close()
205
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
206
                         "ReadPidFile should return 0 for invalid pid file")
207
    utils.RemovePidFile('test')
208
    self.failIf(os.path.exists(pid_file),
209
                "PID file should not exist anymore")
210

    
211
  def testKill(self):
212
    pid_file = self.f_dpn('child')
213
    r_fd, w_fd = os.pipe()
214
    new_pid = os.fork()
215
    if new_pid == 0: #child
216
      utils.WritePidFile('child')
217
      os.write(w_fd, 'a')
218
      signal.pause()
219
      os._exit(0)
220
      return
221
    # else we are in the parent
222
    # wait until the child has written the pid file
223
    os.read(r_fd, 1)
224
    read_pid = utils.ReadPidFile(pid_file)
225
    self.failUnlessEqual(read_pid, new_pid)
226
    self.failUnless(utils.IsProcessAlive(new_pid))
227
    utils.KillProcess(new_pid, waitpid=True)
228
    self.failIf(utils.IsProcessAlive(new_pid))
229
    utils.RemovePidFile('child')
230
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)
231

    
232
  def tearDown(self):
233
    for name in os.listdir(self.dir):
234
      os.unlink(os.path.join(self.dir, name))
235
    os.rmdir(self.dir)
236

    
237

    
238
class TestRunCmd(testutils.GanetiTestCase):
239
  """Testing case for the RunCmd function"""
240

    
241
  def setUp(self):
242
    testutils.GanetiTestCase.setUp(self)
243
    self.magic = time.ctime() + " ganeti test"
244
    self.fname = self._CreateTempFile()
245

    
246
  def testOk(self):
247
    """Test successful exit code"""
248
    result = RunCmd("/bin/sh -c 'exit 0'")
249
    self.assertEqual(result.exit_code, 0)
250
    self.assertEqual(result.output, "")
251

    
252
  def testFail(self):
253
    """Test fail exit code"""
254
    result = RunCmd("/bin/sh -c 'exit 1'")
255
    self.assertEqual(result.exit_code, 1)
256
    self.assertEqual(result.output, "")
257

    
258
  def testStdout(self):
259
    """Test standard output"""
260
    cmd = 'echo -n "%s"' % self.magic
261
    result = RunCmd("/bin/sh -c '%s'" % cmd)
262
    self.assertEqual(result.stdout, self.magic)
263
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
264
    self.assertEqual(result.output, "")
265
    self.assertFileContent(self.fname, self.magic)
266

    
267
  def testStderr(self):
268
    """Test standard error"""
269
    cmd = 'echo -n "%s"' % self.magic
270
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
271
    self.assertEqual(result.stderr, self.magic)
272
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
273
    self.assertEqual(result.output, "")
274
    self.assertFileContent(self.fname, self.magic)
275

    
276
  def testCombined(self):
277
    """Test combined output"""
278
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
279
    expected = "A" + self.magic + "B" + self.magic
280
    result = RunCmd("/bin/sh -c '%s'" % cmd)
281
    self.assertEqual(result.output, expected)
282
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
283
    self.assertEqual(result.output, "")
284
    self.assertFileContent(self.fname, expected)
285

    
286
  def testSignal(self):
287
    """Test signal"""
288
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
289
    self.assertEqual(result.signal, 15)
290
    self.assertEqual(result.output, "")
291

    
292
  def testListRun(self):
293
    """Test list runs"""
294
    result = RunCmd(["true"])
295
    self.assertEqual(result.signal, None)
296
    self.assertEqual(result.exit_code, 0)
297
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
298
    self.assertEqual(result.signal, None)
299
    self.assertEqual(result.exit_code, 1)
300
    result = RunCmd(["echo", "-n", self.magic])
301
    self.assertEqual(result.signal, None)
302
    self.assertEqual(result.exit_code, 0)
303
    self.assertEqual(result.stdout, self.magic)
304

    
305
  def testFileEmptyOutput(self):
306
    """Test file output"""
307
    result = RunCmd(["true"], output=self.fname)
308
    self.assertEqual(result.signal, None)
309
    self.assertEqual(result.exit_code, 0)
310
    self.assertFileContent(self.fname, "")
311

    
312
  def testLang(self):
313
    """Test locale environment"""
314
    old_env = os.environ.copy()
315
    try:
316
      os.environ["LANG"] = "en_US.UTF-8"
317
      os.environ["LC_ALL"] = "en_US.UTF-8"
318
      result = RunCmd(["locale"])
319
      for line in result.output.splitlines():
320
        key, value = line.split("=", 1)
321
        # Ignore these variables, they're overridden by LC_ALL
322
        if key == "LANG" or key == "LANGUAGE":
323
          continue
324
        self.failIf(value and value != "C" and value != '"C"',
325
            "Variable %s is set to the invalid value '%s'" % (key, value))
326
    finally:
327
      os.environ = old_env
328

    
329
  def testDefaultCwd(self):
330
    """Test default working directory"""
331
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
332

    
333
  def testCwd(self):
334
    """Test default working directory"""
335
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
336
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
337
    cwd = os.getcwd()
338
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
339

    
340
  def testResetEnv(self):
341
    """Test environment reset functionality"""
342
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
343
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
344
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
345

    
346

    
347
class TestRunParts(unittest.TestCase):
348
  """Testing case for the RunParts function"""
349

    
350
  def setUp(self):
351
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
352

    
353
  def tearDown(self):
354
    shutil.rmtree(self.rundir)
355

    
356
  def testEmpty(self):
357
    """Test on an empty dir"""
358
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
359

    
360
  def testSkipWrongName(self):
361
    """Test that wrong files are skipped"""
362
    fname = os.path.join(self.rundir, "00test.dot")
363
    utils.WriteFile(fname, data="")
364
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
365
    relname = os.path.basename(fname)
366
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
367
                         [(relname, constants.RUNPARTS_SKIP, None)])
368

    
369
  def testSkipNonExec(self):
370
    """Test that non executable files are skipped"""
371
    fname = os.path.join(self.rundir, "00test")
372
    utils.WriteFile(fname, data="")
373
    relname = os.path.basename(fname)
374
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
375
                         [(relname, constants.RUNPARTS_SKIP, None)])
376

    
377
  def testError(self):
378
    """Test error on a broken executable"""
379
    fname = os.path.join(self.rundir, "00test")
380
    utils.WriteFile(fname, data="")
381
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
382
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
383
    self.failUnlessEqual(relname, os.path.basename(fname))
384
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
385
    self.failUnless(error)
386

    
387
  def testSorted(self):
388
    """Test executions are sorted"""
389
    files = []
390
    files.append(os.path.join(self.rundir, "64test"))
391
    files.append(os.path.join(self.rundir, "00test"))
392
    files.append(os.path.join(self.rundir, "42test"))
393

    
394
    for fname in files:
395
      utils.WriteFile(fname, data="")
396

    
397
    results = RunParts(self.rundir, reset_env=True)
398

    
399
    for fname in sorted(files):
400
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
401

    
402
  def testOk(self):
403
    """Test correct execution"""
404
    fname = os.path.join(self.rundir, "00test")
405
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
406
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
407
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
408
    self.failUnlessEqual(relname, os.path.basename(fname))
409
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
410
    self.failUnlessEqual(runresult.stdout, "ciao")
411

    
412
  def testRunFail(self):
413
    """Test correct execution, with run failure"""
414
    fname = os.path.join(self.rundir, "00test")
415
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
416
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
417
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
418
    self.failUnlessEqual(relname, os.path.basename(fname))
419
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
420
    self.failUnlessEqual(runresult.exit_code, 1)
421
    self.failUnless(runresult.failed)
422

    
423
  def testRunMix(self):
424
    files = []
425
    files.append(os.path.join(self.rundir, "00test"))
426
    files.append(os.path.join(self.rundir, "42test"))
427
    files.append(os.path.join(self.rundir, "64test"))
428
    files.append(os.path.join(self.rundir, "99test"))
429

    
430
    files.sort()
431

    
432
    # 1st has errors in execution
433
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
434
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
435

    
436
    # 2nd is skipped
437
    utils.WriteFile(files[1], data="")
438

    
439
    # 3rd cannot execute properly
440
    utils.WriteFile(files[2], data="")
441
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
442

    
443
    # 4th execs
444
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
445
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
446

    
447
    results = RunParts(self.rundir, reset_env=True)
448

    
449
    (relname, status, runresult) = results[0]
450
    self.failUnlessEqual(relname, os.path.basename(files[0]))
451
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
452
    self.failUnlessEqual(runresult.exit_code, 1)
453
    self.failUnless(runresult.failed)
454

    
455
    (relname, status, runresult) = results[1]
456
    self.failUnlessEqual(relname, os.path.basename(files[1]))
457
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
458
    self.failUnlessEqual(runresult, None)
459

    
460
    (relname, status, runresult) = results[2]
461
    self.failUnlessEqual(relname, os.path.basename(files[2]))
462
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
463
    self.failUnless(runresult)
464

    
465
    (relname, status, runresult) = results[3]
466
    self.failUnlessEqual(relname, os.path.basename(files[3]))
467
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
468
    self.failUnlessEqual(runresult.output, "ciao")
469
    self.failUnlessEqual(runresult.exit_code, 0)
470
    self.failUnless(not runresult.failed)
471

    
472

    
473
class TestStartDaemon(testutils.GanetiTestCase):
474
  def setUp(self):
475
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
476
    self.tmpfile = os.path.join(self.tmpdir, "test")
477

    
478
  def tearDown(self):
479
    shutil.rmtree(self.tmpdir)
480

    
481
  def testShell(self):
482
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
483
    self._wait(self.tmpfile, 60.0, "Hello World")
484

    
485
  def testShellOutput(self):
486
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
487
    self._wait(self.tmpfile, 60.0, "Hello World")
488

    
489
  def testNoShellNoOutput(self):
490
    utils.StartDaemon(["pwd"])
491

    
492
  def testNoShellNoOutputTouch(self):
493
    testfile = os.path.join(self.tmpdir, "check")
494
    self.failIf(os.path.exists(testfile))
495
    utils.StartDaemon(["touch", testfile])
496
    self._wait(testfile, 60.0, "")
497

    
498
  def testNoShellOutput(self):
499
    utils.StartDaemon(["pwd"], output=self.tmpfile)
500
    self._wait(self.tmpfile, 60.0, "/")
501

    
502
  def testNoShellOutputCwd(self):
503
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
504
    self._wait(self.tmpfile, 60.0, os.getcwd())
505

    
506
  def testShellEnv(self):
507
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
508
                      env={ "GNT_TEST_VAR": "Hello World", })
509
    self._wait(self.tmpfile, 60.0, "Hello World")
510

    
511
  def testNoShellEnv(self):
512
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
513
                      env={ "GNT_TEST_VAR": "Hello World", })
514
    self._wait(self.tmpfile, 60.0, "Hello World")
515

    
516
  def testOutputFd(self):
517
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
518
    try:
519
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
520
    finally:
521
      os.close(fd)
522
    self._wait(self.tmpfile, 60.0, os.getcwd())
523

    
524
  def testPid(self):
525
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
526
    self._wait(self.tmpfile, 60.0, str(pid))
527

    
528
  def testPidFile(self):
529
    pidfile = os.path.join(self.tmpdir, "pid")
530
    checkfile = os.path.join(self.tmpdir, "abort")
531

    
532
    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
533
                            output=self.tmpfile)
534
    try:
535
      fd = os.open(pidfile, os.O_RDONLY)
536
      try:
537
        # Check file is locked
538
        self.assertRaises(errors.LockError, utils.LockFile, fd)
539

    
540
        pidtext = os.read(fd, 100)
541
      finally:
542
        os.close(fd)
543

    
544
      self.assertEqual(int(pidtext.strip()), pid)
545

    
546
      self.assert_(utils.IsProcessAlive(pid))
547
    finally:
548
      # No matter what happens, kill daemon
549
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
550
      self.failIf(utils.IsProcessAlive(pid))
551

    
552
    self.assertEqual(utils.ReadFile(self.tmpfile), "")
553

    
554
  def _wait(self, path, timeout, expected):
555
    # Due to the asynchronous nature of daemon processes, polling is necessary.
556
    # A timeout makes sure the test doesn't hang forever.
557
    def _CheckFile():
558
      if not (os.path.isfile(path) and
559
              utils.ReadFile(path).strip() == expected):
560
        raise utils.RetryAgain()
561

    
562
    try:
563
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
564
    except utils.RetryTimeout:
565
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
566
                " didn't write the correct output" % timeout)
567

    
568
  def testError(self):
569
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
570
                      ["./does-NOT-EXIST/here/0123456789"])
571
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
572
                      ["./does-NOT-EXIST/here/0123456789"],
573
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
574
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
575
                      ["./does-NOT-EXIST/here/0123456789"],
576
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
577
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
578
                      ["./does-NOT-EXIST/here/0123456789"],
579
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
580

    
581
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
582
    try:
583
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
584
                        ["./does-NOT-EXIST/here/0123456789"],
585
                        output=self.tmpfile, output_fd=fd)
586
    finally:
587
      os.close(fd)
588

    
589

    
590
class TestSetCloseOnExecFlag(unittest.TestCase):
591
  """Tests for SetCloseOnExecFlag"""
592

    
593
  def setUp(self):
594
    self.tmpfile = tempfile.TemporaryFile()
595

    
596
  def testEnable(self):
597
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
598
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
599
                    fcntl.FD_CLOEXEC)
600

    
601
  def testDisable(self):
602
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
603
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
604
                fcntl.FD_CLOEXEC)
605

    
606

    
607
class TestSetNonblockFlag(unittest.TestCase):
608
  def setUp(self):
609
    self.tmpfile = tempfile.TemporaryFile()
610

    
611
  def testEnable(self):
612
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
613
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
614
                    os.O_NONBLOCK)
615

    
616
  def testDisable(self):
617
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
618
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
619
                os.O_NONBLOCK)
620

    
621

    
622
class TestRemoveFile(unittest.TestCase):
623
  """Test case for the RemoveFile function"""
624

    
625
  def setUp(self):
626
    """Create a temp dir and file for each case"""
627
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
628
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
629
    os.close(fd)
630

    
631
  def tearDown(self):
632
    if os.path.exists(self.tmpfile):
633
      os.unlink(self.tmpfile)
634
    os.rmdir(self.tmpdir)
635

    
636
  def testIgnoreDirs(self):
637
    """Test that RemoveFile() ignores directories"""
638
    self.assertEqual(None, RemoveFile(self.tmpdir))
639

    
640
  def testIgnoreNotExisting(self):
641
    """Test that RemoveFile() ignores non-existing files"""
642
    RemoveFile(self.tmpfile)
643
    RemoveFile(self.tmpfile)
644

    
645
  def testRemoveFile(self):
646
    """Test that RemoveFile does remove a file"""
647
    RemoveFile(self.tmpfile)
648
    if os.path.exists(self.tmpfile):
649
      self.fail("File '%s' not removed" % self.tmpfile)
650

    
651
  def testRemoveSymlink(self):
652
    """Test that RemoveFile does remove symlinks"""
653
    symlink = self.tmpdir + "/symlink"
654
    os.symlink("no-such-file", symlink)
655
    RemoveFile(symlink)
656
    if os.path.exists(symlink):
657
      self.fail("File '%s' not removed" % symlink)
658
    os.symlink(self.tmpfile, symlink)
659
    RemoveFile(symlink)
660
    if os.path.exists(symlink):
661
      self.fail("File '%s' not removed" % symlink)
662

    
663

    
664
class TestRename(unittest.TestCase):
665
  """Test case for RenameFile"""
666

    
667
  def setUp(self):
668
    """Create a temporary directory"""
669
    self.tmpdir = tempfile.mkdtemp()
670
    self.tmpfile = os.path.join(self.tmpdir, "test1")
671

    
672
    # Touch the file
673
    open(self.tmpfile, "w").close()
674

    
675
  def tearDown(self):
676
    """Remove temporary directory"""
677
    shutil.rmtree(self.tmpdir)
678

    
679
  def testSimpleRename1(self):
680
    """Simple rename 1"""
681
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
682
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
683

    
684
  def testSimpleRename2(self):
685
    """Simple rename 2"""
686
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
687
                     mkdir=True)
688
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
689

    
690
  def testRenameMkdir(self):
691
    """Rename with mkdir"""
692
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
693
                     mkdir=True)
694
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
695
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
696

    
697
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
698
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
699
                     mkdir=True)
700
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
701
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
702
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
703

    
704

    
705
class TestMatchNameComponent(unittest.TestCase):
706
  """Test case for the MatchNameComponent function"""
707

    
708
  def testEmptyList(self):
709
    """Test that there is no match against an empty list"""
710

    
711
    self.failUnlessEqual(MatchNameComponent("", []), None)
712
    self.failUnlessEqual(MatchNameComponent("test", []), None)
713

    
714
  def testSingleMatch(self):
715
    """Test that a single match is performed correctly"""
716
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
717
    for key in "test2", "test2.example", "test2.example.com":
718
      self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
719

    
720
  def testMultipleMatches(self):
721
    """Test that a multiple match is returned as None"""
722
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
723
    for key in "test1", "test1.example":
724
      self.failUnlessEqual(MatchNameComponent(key, mlist), None)
725

    
726
  def testFullMatch(self):
727
    """Test that a full match is returned correctly"""
728
    key1 = "test1"
729
    key2 = "test1.example"
730
    mlist = [key2, key2 + ".com"]
731
    self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
732
    self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
733

    
734
  def testCaseInsensitivePartialMatch(self):
735
    """Test for the case_insensitive keyword"""
736
    mlist = ["test1.example.com", "test2.example.net"]
737
    self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
738
                     "test2.example.net")
739
    self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
740
                     "test2.example.net")
741
    self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
742
                     "test2.example.net")
743
    self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
744
                     "test2.example.net")
745

    
746

    
747
  def testCaseInsensitiveFullMatch(self):
748
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
749
    # Between the two ts1 a full string match non-case insensitive should work
750
    self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
751
                     None)
752
    self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
753
                     "ts1.ex")
754
    self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
755
                     "ts1.ex")
756
    # Between the two ts2 only case differs, so only case-match works
757
    self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
758
                     "ts2.ex")
759
    self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
760
                     "Ts2.ex")
761
    self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
762
                     None)
763

    
764

    
765
class TestReadFile(testutils.GanetiTestCase):
766

    
767
  def testReadAll(self):
768
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
769
    self.assertEqual(len(data), 814)
770

    
771
    h = compat.md5_hash()
772
    h.update(data)
773
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
774

    
775
  def testReadSize(self):
776
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
777
                          size=100)
778
    self.assertEqual(len(data), 100)
779

    
780
    h = compat.md5_hash()
781
    h.update(data)
782
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
783

    
784
  def testError(self):
785
    self.assertRaises(EnvironmentError, utils.ReadFile,
786
                      "/dev/null/does-not-exist")
787

    
788

    
789
class TestReadOneLineFile(testutils.GanetiTestCase):
790

    
791
  def setUp(self):
792
    testutils.GanetiTestCase.setUp(self)
793

    
794
  def testDefault(self):
795
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
796
    self.assertEqual(len(data), 27)
797
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
798

    
799
  def testNotStrict(self):
800
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
801
    self.assertEqual(len(data), 27)
802
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
803

    
804
  def testStrictFailure(self):
805
    self.assertRaises(errors.GenericError, ReadOneLineFile,
806
                      self._TestDataFilename("cert1.pem"), strict=True)
807

    
808
  def testLongLine(self):
809
    dummydata = (1024 * "Hello World! ")
810
    myfile = self._CreateTempFile()
811
    utils.WriteFile(myfile, data=dummydata)
812
    datastrict = ReadOneLineFile(myfile, strict=True)
813
    datalax = ReadOneLineFile(myfile, strict=False)
814
    self.assertEqual(dummydata, datastrict)
815
    self.assertEqual(dummydata, datalax)
816

    
817
  def testNewline(self):
818
    myfile = self._CreateTempFile()
819
    myline = "myline"
820
    for nl in ["", "\n", "\r\n"]:
821
      dummydata = "%s%s" % (myline, nl)
822
      utils.WriteFile(myfile, data=dummydata)
823
      datalax = ReadOneLineFile(myfile, strict=False)
824
      self.assertEqual(myline, datalax)
825
      datastrict = ReadOneLineFile(myfile, strict=True)
826
      self.assertEqual(myline, datastrict)
827

    
828
  def testWhitespaceAndMultipleLines(self):
829
    myfile = self._CreateTempFile()
830
    for nl in ["", "\n", "\r\n"]:
831
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
832
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
833
        utils.WriteFile(myfile, data=dummydata)
834
        datalax = ReadOneLineFile(myfile, strict=False)
835
        if nl:
836
          self.assert_(set("\r\n") & set(dummydata))
837
          self.assertRaises(errors.GenericError, ReadOneLineFile,
838
                            myfile, strict=True)
839
          explen = len("Foo bar baz ") + len(ws)
840
          self.assertEqual(len(datalax), explen)
841
          self.assertEqual(datalax, dummydata[:explen])
842
          self.assertFalse(set("\r\n") & set(datalax))
843
        else:
844
          datastrict = ReadOneLineFile(myfile, strict=True)
845
          self.assertEqual(dummydata, datastrict)
846
          self.assertEqual(dummydata, datalax)
847

    
848
  def testEmptylines(self):
849
    myfile = self._CreateTempFile()
850
    myline = "myline"
851
    for nl in ["\n", "\r\n"]:
852
      for ol in ["", "otherline"]:
853
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
854
        utils.WriteFile(myfile, data=dummydata)
855
        self.assert_(set("\r\n") & set(dummydata))
856
        datalax = ReadOneLineFile(myfile, strict=False)
857
        self.assertEqual(myline, datalax)
858
        if ol:
859
          self.assertRaises(errors.GenericError, ReadOneLineFile,
860
                            myfile, strict=True)
861
        else:
862
          datastrict = ReadOneLineFile(myfile, strict=True)
863
          self.assertEqual(myline, datastrict)
864

    
865

    
866
class TestTimestampForFilename(unittest.TestCase):
867
  def test(self):
868
    self.assert_("." not in utils.TimestampForFilename())
869
    self.assert_(":" not in utils.TimestampForFilename())
870

    
871

    
872
class TestCreateBackup(testutils.GanetiTestCase):
873
  def setUp(self):
874
    testutils.GanetiTestCase.setUp(self)
875

    
876
    self.tmpdir = tempfile.mkdtemp()
877

    
878
  def tearDown(self):
879
    testutils.GanetiTestCase.tearDown(self)
880

    
881
    shutil.rmtree(self.tmpdir)
882

    
883
  def testEmpty(self):
884
    filename = utils.PathJoin(self.tmpdir, "config.data")
885
    utils.WriteFile(filename, data="")
886
    bname = utils.CreateBackup(filename)
887
    self.assertFileContent(bname, "")
888
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
889
    utils.CreateBackup(filename)
890
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
891
    utils.CreateBackup(filename)
892
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
893

    
894
    fifoname = utils.PathJoin(self.tmpdir, "fifo")
895
    os.mkfifo(fifoname)
896
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
897

    
898
  def testContent(self):
899
    bkpcount = 0
900
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
901
      for rep in [1, 2, 10, 127]:
902
        testdata = data * rep
903

    
904
        filename = utils.PathJoin(self.tmpdir, "test.data_")
905
        utils.WriteFile(filename, data=testdata)
906
        self.assertFileContent(filename, testdata)
907

    
908
        for _ in range(3):
909
          bname = utils.CreateBackup(filename)
910
          bkpcount += 1
911
          self.assertFileContent(bname, testdata)
912
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
913

    
914

    
915
class TestFormatUnit(unittest.TestCase):
916
  """Test case for the FormatUnit function"""
917

    
918
  def testMiB(self):
919
    self.assertEqual(FormatUnit(1, 'h'), '1M')
920
    self.assertEqual(FormatUnit(100, 'h'), '100M')
921
    self.assertEqual(FormatUnit(1023, 'h'), '1023M')
922

    
923
    self.assertEqual(FormatUnit(1, 'm'), '1')
924
    self.assertEqual(FormatUnit(100, 'm'), '100')
925
    self.assertEqual(FormatUnit(1023, 'm'), '1023')
926

    
927
    self.assertEqual(FormatUnit(1024, 'm'), '1024')
928
    self.assertEqual(FormatUnit(1536, 'm'), '1536')
929
    self.assertEqual(FormatUnit(17133, 'm'), '17133')
930
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
931

    
932
  def testGiB(self):
933
    self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
934
    self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
935
    self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
936
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
937

    
938
    self.assertEqual(FormatUnit(1024, 'g'), '1.0')
939
    self.assertEqual(FormatUnit(1536, 'g'), '1.5')
940
    self.assertEqual(FormatUnit(17133, 'g'), '16.7')
941
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
942

    
943
    self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
944
    self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
945
    self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
946

    
947
  def testTiB(self):
948
    self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
949
    self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
950
    self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
951

    
952
    self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
953
    self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
954
    self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
955

    
956
class TestParseUnit(unittest.TestCase):
957
  """Test case for the ParseUnit function"""
958

    
959
  SCALES = (('', 1),
960
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
961
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
962
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
963

    
964
  def testRounding(self):
965
    self.assertEqual(ParseUnit('0'), 0)
966
    self.assertEqual(ParseUnit('1'), 4)
967
    self.assertEqual(ParseUnit('2'), 4)
968
    self.assertEqual(ParseUnit('3'), 4)
969

    
970
    self.assertEqual(ParseUnit('124'), 124)
971
    self.assertEqual(ParseUnit('125'), 128)
972
    self.assertEqual(ParseUnit('126'), 128)
973
    self.assertEqual(ParseUnit('127'), 128)
974
    self.assertEqual(ParseUnit('128'), 128)
975
    self.assertEqual(ParseUnit('129'), 132)
976
    self.assertEqual(ParseUnit('130'), 132)
977

    
978
  def testFloating(self):
979
    self.assertEqual(ParseUnit('0'), 0)
980
    self.assertEqual(ParseUnit('0.5'), 4)
981
    self.assertEqual(ParseUnit('1.75'), 4)
982
    self.assertEqual(ParseUnit('1.99'), 4)
983
    self.assertEqual(ParseUnit('2.00'), 4)
984
    self.assertEqual(ParseUnit('2.01'), 4)
985
    self.assertEqual(ParseUnit('3.99'), 4)
986
    self.assertEqual(ParseUnit('4.00'), 4)
987
    self.assertEqual(ParseUnit('4.01'), 8)
988
    self.assertEqual(ParseUnit('1.5G'), 1536)
989
    self.assertEqual(ParseUnit('1.8G'), 1844)
990
    self.assertEqual(ParseUnit('8.28T'), 8682212)
991

    
992
  def testSuffixes(self):
993
    for sep in ('', ' ', '   ', "\t", "\t "):
994
      for suffix, scale in TestParseUnit.SCALES:
995
        for func in (lambda x: x, str.lower, str.upper):
996
          self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
997
                           1024 * scale)
998

    
999
  def testInvalidInput(self):
1000
    for sep in ('-', '_', ',', 'a'):
1001
      for suffix, _ in TestParseUnit.SCALES:
1002
        self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)
1003

    
1004
    for suffix, _ in TestParseUnit.SCALES:
1005
      self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
1006

    
1007

    
1008
class TestSshKeys(testutils.GanetiTestCase):
1009
  """Test case for the AddAuthorizedKey function"""
1010

    
1011
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
1012
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
1013
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
1014

    
1015
  def setUp(self):
1016
    testutils.GanetiTestCase.setUp(self)
1017
    self.tmpname = self._CreateTempFile()
1018
    handle = open(self.tmpname, 'w')
1019
    try:
1020
      handle.write("%s\n" % TestSshKeys.KEY_A)
1021
      handle.write("%s\n" % TestSshKeys.KEY_B)
1022
    finally:
1023
      handle.close()
1024

    
1025
  def testAddingNewKey(self):
1026
    AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
1027

    
1028
    self.assertFileContent(self.tmpname,
1029
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1030
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1031
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1032
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
1033

    
1034
  def testAddingAlmostButNotCompletelyTheSameKey(self):
1035
    AddAuthorizedKey(self.tmpname,
1036
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
1037

    
1038
    self.assertFileContent(self.tmpname,
1039
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1040
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1041
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1042
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
1043

    
1044
  def testAddingExistingKeyWithSomeMoreSpaces(self):
1045
    AddAuthorizedKey(self.tmpname,
1046
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1047

    
1048
    self.assertFileContent(self.tmpname,
1049
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1050
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1051
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1052

    
1053
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
1054
    RemoveAuthorizedKey(self.tmpname,
1055
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1056

    
1057
    self.assertFileContent(self.tmpname,
1058
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1059
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1060

    
1061
  def testRemovingNonExistingKey(self):
1062
    RemoveAuthorizedKey(self.tmpname,
1063
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
1064

    
1065
    self.assertFileContent(self.tmpname,
1066
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1067
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
1068
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1069

    
1070

    
1071
class TestEtcHosts(testutils.GanetiTestCase):
1072
  """Test functions modifying /etc/hosts"""
1073

    
1074
  def setUp(self):
1075
    testutils.GanetiTestCase.setUp(self)
1076
    self.tmpname = self._CreateTempFile()
1077
    handle = open(self.tmpname, 'w')
1078
    try:
1079
      handle.write('# This is a test file for /etc/hosts\n')
1080
      handle.write('127.0.0.1\tlocalhost\n')
1081
      handle.write('192.168.1.1 router gw\n')
1082
    finally:
1083
      handle.close()
1084

    
1085
  def testSettingNewIp(self):
1086
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
1087

    
1088
    self.assertFileContent(self.tmpname,
1089
      "# This is a test file for /etc/hosts\n"
1090
      "127.0.0.1\tlocalhost\n"
1091
      "192.168.1.1 router gw\n"
1092
      "1.2.3.4\tmyhost.domain.tld myhost\n")
1093
    self.assertFileMode(self.tmpname, 0644)
1094

    
1095
  def testSettingExistingIp(self):
1096
    SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
1097
                     ['myhost'])
1098

    
1099
    self.assertFileContent(self.tmpname,
1100
      "# This is a test file for /etc/hosts\n"
1101
      "127.0.0.1\tlocalhost\n"
1102
      "192.168.1.1\tmyhost.domain.tld myhost\n")
1103
    self.assertFileMode(self.tmpname, 0644)
1104

    
1105
  def testSettingDuplicateName(self):
1106
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
1107

    
1108
    self.assertFileContent(self.tmpname,
1109
      "# This is a test file for /etc/hosts\n"
1110
      "127.0.0.1\tlocalhost\n"
1111
      "192.168.1.1 router gw\n"
1112
      "1.2.3.4\tmyhost\n")
1113
    self.assertFileMode(self.tmpname, 0644)
1114

    
1115
  def testRemovingExistingHost(self):
1116
    RemoveEtcHostsEntry(self.tmpname, 'router')
1117

    
1118
    self.assertFileContent(self.tmpname,
1119
      "# This is a test file for /etc/hosts\n"
1120
      "127.0.0.1\tlocalhost\n"
1121
      "192.168.1.1 gw\n")
1122
    self.assertFileMode(self.tmpname, 0644)
1123

    
1124
  def testRemovingSingleExistingHost(self):
1125
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
1126

    
1127
    self.assertFileContent(self.tmpname,
1128
      "# This is a test file for /etc/hosts\n"
1129
      "192.168.1.1 router gw\n")
1130
    self.assertFileMode(self.tmpname, 0644)
1131

    
1132
  def testRemovingNonExistingHost(self):
1133
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
1134

    
1135
    self.assertFileContent(self.tmpname,
1136
      "# This is a test file for /etc/hosts\n"
1137
      "127.0.0.1\tlocalhost\n"
1138
      "192.168.1.1 router gw\n")
1139
    self.assertFileMode(self.tmpname, 0644)
1140

    
1141
  def testRemovingAlias(self):
1142
    RemoveEtcHostsEntry(self.tmpname, 'gw')
1143

    
1144
    self.assertFileContent(self.tmpname,
1145
      "# This is a test file for /etc/hosts\n"
1146
      "127.0.0.1\tlocalhost\n"
1147
      "192.168.1.1 router\n")
1148
    self.assertFileMode(self.tmpname, 0644)
1149

    
1150

    
1151
class TestShellQuoting(unittest.TestCase):
1152
  """Test case for shell quoting functions"""
1153

    
1154
  def testShellQuote(self):
1155
    self.assertEqual(ShellQuote('abc'), "abc")
1156
    self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
1157
    self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
1158
    self.assertEqual(ShellQuote("a b c"), "'a b c'")
1159
    self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
1160

    
1161
  def testShellQuoteArgs(self):
1162
    self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
1163
    self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
1164
    self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
1165

    
1166

    
1167
class _BaseTcpPingTest:
1168
  """Base class for TcpPing tests against listen(2)ing port"""
1169
  family = None
1170
  address = None
1171

    
1172
  def setUp(self):
1173
    self.listener = socket.socket(self.family, socket.SOCK_STREAM)
1174
    self.listener.bind((self.address, 0))
1175
    self.listenerport = self.listener.getsockname()[1]
1176
    self.listener.listen(1)
1177

    
1178
  def tearDown(self):
1179
    self.listener.shutdown(socket.SHUT_RDWR)
1180
    del self.listener
1181
    del self.listenerport
1182

    
1183
  def testTcpPingToLocalHostAccept(self):
1184
    self.assert_(TcpPing(self.address,
1185
                         self.listenerport,
1186
                         timeout=constants.TCP_PING_TIMEOUT,
1187
                         live_port_needed=True,
1188
                         source=self.address,
1189
                         ),
1190
                 "failed to connect to test listener")
1191

    
1192
    self.assert_(TcpPing(self.address, self.listenerport,
1193
                         timeout=constants.TCP_PING_TIMEOUT,
1194
                         live_port_needed=True),
1195
                 "failed to connect to test listener (no source)")
1196

    
1197

    
1198
class TestIP4TcpPing(unittest.TestCase, _BaseTcpPingTest):
1199
  """Testcase for IPv4 TCP version of ping - against listen(2)ing port"""
1200
  family = socket.AF_INET
1201
  address = constants.IP4_ADDRESS_LOCALHOST
1202

    
1203
  def setUp(self):
1204
    unittest.TestCase.setUp(self)
1205
    _BaseTcpPingTest.setUp(self)
1206

    
1207
  def tearDown(self):
1208
    unittest.TestCase.tearDown(self)
1209
    _BaseTcpPingTest.tearDown(self)
1210

    
1211

    
1212
class TestIP6TcpPing(unittest.TestCase, _BaseTcpPingTest):
1213
  """Testcase for IPv6 TCP version of ping - against listen(2)ing port"""
1214
  family = socket.AF_INET6
1215
  address = constants.IP6_ADDRESS_LOCALHOST
1216

    
1217
  def setUp(self):
1218
    unittest.TestCase.setUp(self)
1219
    _BaseTcpPingTest.setUp(self)
1220

    
1221
  def tearDown(self):
1222
    unittest.TestCase.tearDown(self)
1223
    _BaseTcpPingTest.tearDown(self)
1224

    
1225

    
1226
class _BaseTcpPingDeafTest:
1227
  """Base class for TcpPing tests against non listen(2)ing port"""
1228
  family = None
1229
  address = None
1230

    
1231
  def setUp(self):
1232
    self.deaflistener = socket.socket(self.family, socket.SOCK_STREAM)
1233
    self.deaflistener.bind((self.address, 0))
1234
    self.deaflistenerport = self.deaflistener.getsockname()[1]
1235

    
1236
  def tearDown(self):
1237
    del self.deaflistener
1238
    del self.deaflistenerport
1239

    
1240
  def testTcpPingToLocalHostAcceptDeaf(self):
1241
    self.assertFalse(TcpPing(self.address,
1242
                             self.deaflistenerport,
1243
                             timeout=constants.TCP_PING_TIMEOUT,
1244
                             live_port_needed=True,
1245
                             source=self.address,
1246
                             ), # need successful connect(2)
1247
                     "successfully connected to deaf listener")
1248

    
1249
    self.assertFalse(TcpPing(self.address,
1250
                             self.deaflistenerport,
1251
                             timeout=constants.TCP_PING_TIMEOUT,
1252
                             live_port_needed=True,
1253
                             ), # need successful connect(2)
1254
                     "successfully connected to deaf listener (no source)")
1255

    
1256
  def testTcpPingToLocalHostNoAccept(self):
1257
    self.assert_(TcpPing(self.address,
1258
                         self.deaflistenerport,
1259
                         timeout=constants.TCP_PING_TIMEOUT,
1260
                         live_port_needed=False,
1261
                         source=self.address,
1262
                         ), # ECONNREFUSED is OK
1263
                 "failed to ping alive host on deaf port")
1264

    
1265
    self.assert_(TcpPing(self.address,
1266
                         self.deaflistenerport,
1267
                         timeout=constants.TCP_PING_TIMEOUT,
1268
                         live_port_needed=False,
1269
                         ), # ECONNREFUSED is OK
1270
                 "failed to ping alive host on deaf port (no source)")
1271

    
1272

    
1273
class TestIP4TcpPingDeaf(unittest.TestCase, _BaseTcpPingDeafTest):
1274
  """Testcase for IPv4 TCP version of ping - against non listen(2)ing port"""
1275
  family = socket.AF_INET
1276
  address = constants.IP4_ADDRESS_LOCALHOST
1277

    
1278
  def setUp(self):
1279
    self.deaflistener = socket.socket(self.family, socket.SOCK_STREAM)
1280
    self.deaflistener.bind((self.address, 0))
1281
    self.deaflistenerport = self.deaflistener.getsockname()[1]
1282

    
1283
  def tearDown(self):
1284
    del self.deaflistener
1285
    del self.deaflistenerport
1286

    
1287

    
1288
class TestIP6TcpPingDeaf(unittest.TestCase, _BaseTcpPingDeafTest):
1289
  """Testcase for IPv6 TCP version of ping - against non listen(2)ing port"""
1290
  family = socket.AF_INET6
1291
  address = constants.IP6_ADDRESS_LOCALHOST
1292

    
1293
  def setUp(self):
1294
    unittest.TestCase.setUp(self)
1295
    _BaseTcpPingDeafTest.setUp(self)
1296

    
1297
  def tearDown(self):
1298
    unittest.TestCase.tearDown(self)
1299
    _BaseTcpPingDeafTest.tearDown(self)
1300

    
1301

    
1302
class TestOwnIpAddress(unittest.TestCase):
1303
  """Testcase for OwnIpAddress"""
1304

    
1305
  def testOwnLoopback(self):
1306
    """check having the loopback ip"""
1307
    self.failUnless(OwnIpAddress(constants.IP4_ADDRESS_LOCALHOST),
1308
                    "Should own the loopback address")
1309

    
1310
  def testNowOwnAddress(self):
1311
    """check that I don't own an address"""
1312

    
1313
    # Network 192.0.2.0/24 is reserved for test/documentation as per
1314
    # RFC 5735, so we *should* not have an address of this range... if
1315
    # this fails, we should extend the test to multiple addresses
1316
    DST_IP = "192.0.2.1"
1317
    self.failIf(OwnIpAddress(DST_IP), "Should not own IP address %s" % DST_IP)
1318

    
1319

    
1320
def _GetSocketCredentials(path):
1321
  """Connect to a Unix socket and return remote credentials.
1322

1323
  """
1324
  sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1325
  try:
1326
    sock.settimeout(10)
1327
    sock.connect(path)
1328
    return utils.GetSocketCredentials(sock)
1329
  finally:
1330
    sock.close()
1331

    
1332

    
1333
class TestGetSocketCredentials(unittest.TestCase):
1334
  def setUp(self):
1335
    self.tmpdir = tempfile.mkdtemp()
1336
    self.sockpath = utils.PathJoin(self.tmpdir, "sock")
1337

    
1338
    self.listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1339
    self.listener.settimeout(10)
1340
    self.listener.bind(self.sockpath)
1341
    self.listener.listen(1)
1342

    
1343
  def tearDown(self):
1344
    self.listener.shutdown(socket.SHUT_RDWR)
1345
    self.listener.close()
1346
    shutil.rmtree(self.tmpdir)
1347

    
1348
  def test(self):
1349
    (c2pr, c2pw) = os.pipe()
1350

    
1351
    # Start child process
1352
    child = os.fork()
1353
    if child == 0:
1354
      try:
1355
        data = serializer.DumpJson(_GetSocketCredentials(self.sockpath))
1356

    
1357
        os.write(c2pw, data)
1358
        os.close(c2pw)
1359

    
1360
        os._exit(0)
1361
      finally:
1362
        os._exit(1)
1363

    
1364
    os.close(c2pw)
1365

    
1366
    # Wait for one connection
1367
    (conn, _) = self.listener.accept()
1368
    conn.recv(1)
1369
    conn.close()
1370

    
1371
    # Wait for result
1372
    result = os.read(c2pr, 4096)
1373
    os.close(c2pr)
1374

    
1375
    # Check child's exit code
1376
    (_, status) = os.waitpid(child, 0)
1377
    self.assertFalse(os.WIFSIGNALED(status))
1378
    self.assertEqual(os.WEXITSTATUS(status), 0)
1379

    
1380
    # Check result
1381
    (pid, uid, gid) = serializer.LoadJson(result)
1382
    self.assertEqual(pid, os.getpid())
1383
    self.assertEqual(uid, os.getuid())
1384
    self.assertEqual(gid, os.getgid())
1385

    
1386

    
1387
class TestListVisibleFiles(unittest.TestCase):
1388
  """Test case for ListVisibleFiles"""
1389

    
1390
  def setUp(self):
1391
    self.path = tempfile.mkdtemp()
1392

    
1393
  def tearDown(self):
1394
    shutil.rmtree(self.path)
1395

    
1396
  def _CreateFiles(self, files):
1397
    for name in files:
1398
      utils.WriteFile(os.path.join(self.path, name), data="test")
1399

    
1400
  def _test(self, files, expected):
1401
    self._CreateFiles(files)
1402
    found = ListVisibleFiles(self.path)
1403
    self.assertEqual(set(found), set(expected))
1404

    
1405
  def testAllVisible(self):
1406
    files = ["a", "b", "c"]
1407
    expected = files
1408
    self._test(files, expected)
1409

    
1410
  def testNoneVisible(self):
1411
    files = [".a", ".b", ".c"]
1412
    expected = []
1413
    self._test(files, expected)
1414

    
1415
  def testSomeVisible(self):
1416
    files = ["a", "b", ".c"]
1417
    expected = ["a", "b"]
1418
    self._test(files, expected)
1419

    
1420
  def testNonAbsolutePath(self):
1421
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1422

    
1423
  def testNonNormalizedPath(self):
1424
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1425
                          "/bin/../tmp")
1426

    
1427

    
1428
class TestNewUUID(unittest.TestCase):
1429
  """Test case for NewUUID"""
1430

    
1431
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
1432
                        '[a-f0-9]{4}-[a-f0-9]{12}$')
1433

    
1434
  def runTest(self):
1435
    self.failUnless(self._re_uuid.match(utils.NewUUID()))
1436

    
1437

    
1438
class TestUniqueSequence(unittest.TestCase):
1439
  """Test case for UniqueSequence"""
1440

    
1441
  def _test(self, input, expected):
1442
    self.assertEqual(utils.UniqueSequence(input), expected)
1443

    
1444
  def runTest(self):
1445
    # Ordered input
1446
    self._test([1, 2, 3], [1, 2, 3])
1447
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
1448
    self._test([1, 2, 2, 3], [1, 2, 3])
1449
    self._test([1, 2, 3, 3], [1, 2, 3])
1450

    
1451
    # Unordered input
1452
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
1453
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
1454

    
1455
    # Strings
1456
    self._test(["a", "a"], ["a"])
1457
    self._test(["a", "b"], ["a", "b"])
1458
    self._test(["a", "b", "a"], ["a", "b"])
1459

    
1460

    
1461
class TestFirstFree(unittest.TestCase):
1462
  """Test case for the FirstFree function"""
1463

    
1464
  def test(self):
1465
    """Test FirstFree"""
1466
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1467
    self.failUnlessEqual(FirstFree([]), None)
1468
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1469
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1470
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1471

    
1472

    
1473
class TestTailFile(testutils.GanetiTestCase):
1474
  """Test case for the TailFile function"""
1475

    
1476
  def testEmpty(self):
1477
    fname = self._CreateTempFile()
1478
    self.failUnlessEqual(TailFile(fname), [])
1479
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1480

    
1481
  def testAllLines(self):
1482
    data = ["test %d" % i for i in range(30)]
1483
    for i in range(30):
1484
      fname = self._CreateTempFile()
1485
      fd = open(fname, "w")
1486
      fd.write("\n".join(data[:i]))
1487
      if i > 0:
1488
        fd.write("\n")
1489
      fd.close()
1490
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1491

    
1492
  def testPartialLines(self):
1493
    data = ["test %d" % i for i in range(30)]
1494
    fname = self._CreateTempFile()
1495
    fd = open(fname, "w")
1496
    fd.write("\n".join(data))
1497
    fd.write("\n")
1498
    fd.close()
1499
    for i in range(1, 30):
1500
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1501

    
1502
  def testBigFile(self):
1503
    data = ["test %d" % i for i in range(30)]
1504
    fname = self._CreateTempFile()
1505
    fd = open(fname, "w")
1506
    fd.write("X" * 1048576)
1507
    fd.write("\n")
1508
    fd.write("\n".join(data))
1509
    fd.write("\n")
1510
    fd.close()
1511
    for i in range(1, 30):
1512
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1513

    
1514

    
1515
class _BaseFileLockTest:
1516
  """Test case for the FileLock class"""
1517

    
1518
  def testSharedNonblocking(self):
1519
    self.lock.Shared(blocking=False)
1520
    self.lock.Close()
1521

    
1522
  def testExclusiveNonblocking(self):
1523
    self.lock.Exclusive(blocking=False)
1524
    self.lock.Close()
1525

    
1526
  def testUnlockNonblocking(self):
1527
    self.lock.Unlock(blocking=False)
1528
    self.lock.Close()
1529

    
1530
  def testSharedBlocking(self):
1531
    self.lock.Shared(blocking=True)
1532
    self.lock.Close()
1533

    
1534
  def testExclusiveBlocking(self):
1535
    self.lock.Exclusive(blocking=True)
1536
    self.lock.Close()
1537

    
1538
  def testUnlockBlocking(self):
1539
    self.lock.Unlock(blocking=True)
1540
    self.lock.Close()
1541

    
1542
  def testSharedExclusiveUnlock(self):
1543
    self.lock.Shared(blocking=False)
1544
    self.lock.Exclusive(blocking=False)
1545
    self.lock.Unlock(blocking=False)
1546
    self.lock.Close()
1547

    
1548
  def testExclusiveSharedUnlock(self):
1549
    self.lock.Exclusive(blocking=False)
1550
    self.lock.Shared(blocking=False)
1551
    self.lock.Unlock(blocking=False)
1552
    self.lock.Close()
1553

    
1554
  def testSimpleTimeout(self):
1555
    # These will succeed on the first attempt, hence a short timeout
1556
    self.lock.Shared(blocking=True, timeout=10.0)
1557
    self.lock.Exclusive(blocking=False, timeout=10.0)
1558
    self.lock.Unlock(blocking=True, timeout=10.0)
1559
    self.lock.Close()
1560

    
1561
  @staticmethod
1562
  def _TryLockInner(filename, shared, blocking):
1563
    lock = utils.FileLock.Open(filename)
1564

    
1565
    if shared:
1566
      fn = lock.Shared
1567
    else:
1568
      fn = lock.Exclusive
1569

    
1570
    try:
1571
      # The timeout doesn't really matter as the parent process waits for us to
1572
      # finish anyway.
1573
      fn(blocking=blocking, timeout=0.01)
1574
    except errors.LockError, err:
1575
      return False
1576

    
1577
    return True
1578

    
1579
  def _TryLock(self, *args):
1580
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1581
                                      *args)
1582

    
1583
  def testTimeout(self):
1584
    for blocking in [True, False]:
1585
      self.lock.Exclusive(blocking=True)
1586
      self.failIf(self._TryLock(False, blocking))
1587
      self.failIf(self._TryLock(True, blocking))
1588

    
1589
      self.lock.Shared(blocking=True)
1590
      self.assert_(self._TryLock(True, blocking))
1591
      self.failIf(self._TryLock(False, blocking))
1592

    
1593
  def testCloseShared(self):
1594
    self.lock.Close()
1595
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1596

    
1597
  def testCloseExclusive(self):
1598
    self.lock.Close()
1599
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1600

    
1601
  def testCloseUnlock(self):
1602
    self.lock.Close()
1603
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1604

    
1605

    
1606
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1607
  TESTDATA = "Hello World\n" * 10
1608

    
1609
  def setUp(self):
1610
    testutils.GanetiTestCase.setUp(self)
1611

    
1612
    self.tmpfile = tempfile.NamedTemporaryFile()
1613
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1614
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1615

    
1616
    # Ensure "Open" didn't truncate file
1617
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1618

    
1619
  def tearDown(self):
1620
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1621

    
1622
    testutils.GanetiTestCase.tearDown(self)
1623

    
1624

    
1625
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1626
  def setUp(self):
1627
    self.tmpfile = tempfile.NamedTemporaryFile()
1628
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1629

    
1630

    
1631
class TestTimeFunctions(unittest.TestCase):
1632
  """Test case for time functions"""
1633

    
1634
  def runTest(self):
1635
    self.assertEqual(utils.SplitTime(1), (1, 0))
1636
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1637
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1638
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1639
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1640
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1641
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1642
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1643

    
1644
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1645

    
1646
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1647
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1648
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1649

    
1650
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1651
                     1218448917.481)
1652
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1653

    
1654
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1655
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1656
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1657
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1658
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1659

    
1660

    
1661
class FieldSetTestCase(unittest.TestCase):
1662
  """Test case for FieldSets"""
1663

    
1664
  def testSimpleMatch(self):
1665
    f = utils.FieldSet("a", "b", "c", "def")
1666
    self.failUnless(f.Matches("a"))
1667
    self.failIf(f.Matches("d"), "Substring matched")
1668
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1669
    self.failIf(f.NonMatching(["b", "c"]))
1670
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1671
    self.failUnless(f.NonMatching(["a", "d"]))
1672

    
1673
  def testRegexMatch(self):
1674
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1675
    self.failUnless(f.Matches("b1"))
1676
    self.failUnless(f.Matches("b99"))
1677
    self.failIf(f.Matches("b/1"))
1678
    self.failIf(f.NonMatching(["b12", "c"]))
1679
    self.failUnless(f.NonMatching(["a", "1"]))
1680

    
1681
class TestForceDictType(unittest.TestCase):
1682
  """Test case for ForceDictType"""
1683

    
1684
  def setUp(self):
1685
    self.key_types = {
1686
      'a': constants.VTYPE_INT,
1687
      'b': constants.VTYPE_BOOL,
1688
      'c': constants.VTYPE_STRING,
1689
      'd': constants.VTYPE_SIZE,
1690
      }
1691

    
1692
  def _fdt(self, dict, allowed_values=None):
1693
    if allowed_values is None:
1694
      ForceDictType(dict, self.key_types)
1695
    else:
1696
      ForceDictType(dict, self.key_types, allowed_values=allowed_values)
1697

    
1698
    return dict
1699

    
1700
  def testSimpleDict(self):
1701
    self.assertEqual(self._fdt({}), {})
1702
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1703
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1704
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1705
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1706
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1707
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1708
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1709
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1710
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1711
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1712
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1713

    
1714
  def testErrors(self):
1715
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1716
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1717
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1718
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1719

    
1720

    
1721
class TestIsNormAbsPath(unittest.TestCase):
1722
  """Testing case for IsNormAbsPath"""
1723

    
1724
  def _pathTestHelper(self, path, result):
1725
    if result:
1726
      self.assert_(IsNormAbsPath(path),
1727
          "Path %s should result absolute and normalized" % path)
1728
    else:
1729
      self.assertFalse(IsNormAbsPath(path),
1730
          "Path %s should not result absolute and normalized" % path)
1731

    
1732
  def testBase(self):
1733
    self._pathTestHelper('/etc', True)
1734
    self._pathTestHelper('/srv', True)
1735
    self._pathTestHelper('etc', False)
1736
    self._pathTestHelper('/etc/../root', False)
1737
    self._pathTestHelper('/etc/', False)
1738

    
1739

    
1740
class TestSafeEncode(unittest.TestCase):
1741
  """Test case for SafeEncode"""
1742

    
1743
  def testAscii(self):
1744
    for txt in [string.digits, string.letters, string.punctuation]:
1745
      self.failUnlessEqual(txt, SafeEncode(txt))
1746

    
1747
  def testDoubleEncode(self):
1748
    for i in range(255):
1749
      txt = SafeEncode(chr(i))
1750
      self.failUnlessEqual(txt, SafeEncode(txt))
1751

    
1752
  def testUnicode(self):
1753
    # 1024 is high enough to catch non-direct ASCII mappings
1754
    for i in range(1024):
1755
      txt = SafeEncode(unichr(i))
1756
      self.failUnlessEqual(txt, SafeEncode(txt))
1757

    
1758

    
1759
class TestFormatTime(unittest.TestCase):
1760
  """Testing case for FormatTime"""
1761

    
1762
  def testNone(self):
1763
    self.failUnlessEqual(FormatTime(None), "N/A")
1764

    
1765
  def testInvalid(self):
1766
    self.failUnlessEqual(FormatTime(()), "N/A")
1767

    
1768
  def testNow(self):
1769
    # tests that we accept time.time input
1770
    FormatTime(time.time())
1771
    # tests that we accept int input
1772
    FormatTime(int(time.time()))
1773

    
1774

    
1775
class RunInSeparateProcess(unittest.TestCase):
1776
  def test(self):
1777
    for exp in [True, False]:
1778
      def _child():
1779
        return exp
1780

    
1781
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1782

    
1783
  def testArgs(self):
1784
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1785
      def _child(carg1, carg2):
1786
        return carg1 == "Foo" and carg2 == arg
1787

    
1788
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1789

    
1790
  def testPid(self):
1791
    parent_pid = os.getpid()
1792

    
1793
    def _check():
1794
      return os.getpid() == parent_pid
1795

    
1796
    self.failIf(utils.RunInSeparateProcess(_check))
1797

    
1798
  def testSignal(self):
1799
    def _kill():
1800
      os.kill(os.getpid(), signal.SIGTERM)
1801

    
1802
    self.assertRaises(errors.GenericError,
1803
                      utils.RunInSeparateProcess, _kill)
1804

    
1805
  def testException(self):
1806
    def _exc():
1807
      raise errors.GenericError("This is a test")
1808

    
1809
    self.assertRaises(errors.GenericError,
1810
                      utils.RunInSeparateProcess, _exc)
1811

    
1812

    
1813
class TestFingerprintFile(unittest.TestCase):
1814
  def setUp(self):
1815
    self.tmpfile = tempfile.NamedTemporaryFile()
1816

    
1817
  def test(self):
1818
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1819
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")
1820

    
1821
    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
1822
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1823
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1824

    
1825

    
1826
class TestUnescapeAndSplit(unittest.TestCase):
1827
  """Testing case for UnescapeAndSplit"""
1828

    
1829
  def setUp(self):
1830
    # testing more that one separator for regexp safety
1831
    self._seps = [",", "+", "."]
1832

    
1833
  def testSimple(self):
1834
    a = ["a", "b", "c", "d"]
1835
    for sep in self._seps:
1836
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1837

    
1838
  def testEscape(self):
1839
    for sep in self._seps:
1840
      a = ["a", "b\\" + sep + "c", "d"]
1841
      b = ["a", "b" + sep + "c", "d"]
1842
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1843

    
1844
  def testDoubleEscape(self):
1845
    for sep in self._seps:
1846
      a = ["a", "b\\\\", "c", "d"]
1847
      b = ["a", "b\\", "c", "d"]
1848
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1849

    
1850
  def testThreeEscape(self):
1851
    for sep in self._seps:
1852
      a = ["a", "b\\\\\\" + sep + "c", "d"]
1853
      b = ["a", "b\\" + sep + "c", "d"]
1854
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1855

    
1856

    
1857
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1858
  def setUp(self):
1859
    self.tmpdir = tempfile.mkdtemp()
1860

    
1861
  def tearDown(self):
1862
    shutil.rmtree(self.tmpdir)
1863

    
1864
  def _checkRsaPrivateKey(self, key):
1865
    lines = key.splitlines()
1866
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1867
            "-----END RSA PRIVATE KEY-----" in lines)
1868

    
1869
  def _checkCertificate(self, cert):
1870
    lines = cert.splitlines()
1871
    return ("-----BEGIN CERTIFICATE-----" in lines and
1872
            "-----END CERTIFICATE-----" in lines)
1873

    
1874
  def test(self):
1875
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1876
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1877
      self._checkRsaPrivateKey(key_pem)
1878
      self._checkCertificate(cert_pem)
1879

    
1880
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1881
                                           key_pem)
1882
      self.assert_(key.bits() >= 1024)
1883
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1884
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1885

    
1886
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1887
                                             cert_pem)
1888
      self.failIf(x509.has_expired())
1889
      self.assertEqual(x509.get_issuer().CN, common_name)
1890
      self.assertEqual(x509.get_subject().CN, common_name)
1891
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1892

    
1893
  def testLegacy(self):
1894
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1895

    
1896
    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1897

    
1898
    cert1 = utils.ReadFile(cert1_filename)
1899

    
1900
    self.assert_(self._checkRsaPrivateKey(cert1))
1901
    self.assert_(self._checkCertificate(cert1))
1902

    
1903

    
1904
class TestPathJoin(unittest.TestCase):
1905
  """Testing case for PathJoin"""
1906

    
1907
  def testBasicItems(self):
1908
    mlist = ["/a", "b", "c"]
1909
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1910

    
1911
  def testNonAbsPrefix(self):
1912
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1913

    
1914
  def testBackTrack(self):
1915
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1916

    
1917
  def testMultiAbs(self):
1918
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1919

    
1920

    
1921
class TestHostInfo(unittest.TestCase):
1922
  """Testing case for HostInfo"""
1923

    
1924
  def testUppercase(self):
1925
    data = "AbC.example.com"
1926
    self.failUnlessEqual(HostInfo.NormalizeName(data), data.lower())
1927

    
1928
  def testTooLongName(self):
1929
    data = "a.b." + "c" * 255
1930
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, data)
1931

    
1932
  def testTrailingDot(self):
1933
    data = "a.b.c"
1934
    self.failUnlessEqual(HostInfo.NormalizeName(data + "."), data)
1935

    
1936
  def testInvalidName(self):
1937
    data = [
1938
      "a b",
1939
      "a/b",
1940
      ".a.b",
1941
      "a..b",
1942
      ]
1943
    for value in data:
1944
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, value)
1945

    
1946
  def testValidName(self):
1947
    data = [
1948
      "a.b",
1949
      "a-b",
1950
      "a_b",
1951
      "a.b.c",
1952
      ]
1953
    for value in data:
1954
      HostInfo.NormalizeName(value)
1955

    
1956

    
1957
class TestValidateServiceName(unittest.TestCase):
1958
  def testValid(self):
1959
    testnames = [
1960
      0, 1, 2, 3, 1024, 65000, 65534, 65535,
1961
      "ganeti",
1962
      "gnt-masterd",
1963
      "HELLO_WORLD_SVC",
1964
      "hello.world.1",
1965
      "0", "80", "1111", "65535",
1966
      ]
1967

    
1968
    for name in testnames:
1969
      self.assertEqual(utils.ValidateServiceName(name), name)
1970

    
1971
  def testInvalid(self):
1972
    testnames = [
1973
      -15756, -1, 65536, 133428083,
1974
      "", "Hello World!", "!", "'", "\"", "\t", "\n", "`",
1975
      "-8546", "-1", "65536",
1976
      (129 * "A"),
1977
      ]
1978

    
1979
    for name in testnames:
1980
      self.assertRaises(OpPrereqError, utils.ValidateServiceName, name)
1981

    
1982

    
1983
class TestParseAsn1Generalizedtime(unittest.TestCase):
1984
  def test(self):
1985
    # UTC
1986
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1987
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1988
                     1266860512)
1989
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1990
                     (2**31) - 1)
1991

    
1992
    # With offset
1993
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1994
                     1266860512)
1995
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1996
                     1266931012)
1997
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1998
                     1266931088)
1999
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
2000
                     1266931295)
2001
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
2002
                     3600)
2003

    
2004
    # Leap seconds are not supported by datetime.datetime
2005
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
2006
                      "19841231235960+0000")
2007
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
2008
                      "19920630235960+0000")
2009

    
2010
    # Errors
2011
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
2012
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
2013
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
2014
                      "20100222174152")
2015
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
2016
                      "Mon Feb 22 17:47:02 UTC 2010")
2017
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
2018
                      "2010-02-22 17:42:02")
2019

    
2020

    
2021
class TestGetX509CertValidity(testutils.GanetiTestCase):
2022
  def setUp(self):
2023
    testutils.GanetiTestCase.setUp(self)
2024

    
2025
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
2026

    
2027
    # Test whether we have pyOpenSSL 0.7 or above
2028
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
2029

    
2030
    if not self.pyopenssl0_7:
2031
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
2032
                    " function correctly")
2033

    
2034
  def _LoadCert(self, name):
2035
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
2036
                                           self._ReadTestData(name))
2037

    
2038
  def test(self):
2039
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
2040
    if self.pyopenssl0_7:
2041
      self.assertEqual(validity, (1266919967, 1267524767))
2042
    else:
2043
      self.assertEqual(validity, (None, None))
2044

    
2045

    
2046
class TestSignX509Certificate(unittest.TestCase):
2047
  KEY = "My private key!"
2048
  KEY_OTHER = "Another key"
2049

    
2050
  def test(self):
2051
    # Generate certificate valid for 5 minutes
2052
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
2053

    
2054
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
2055
                                           cert_pem)
2056

    
2057
    # No signature at all
2058
    self.assertRaises(errors.GenericError,
2059
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
2060

    
2061
    # Invalid input
2062
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2063
                      "", self.KEY)
2064
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2065
                      "X-Ganeti-Signature: \n", self.KEY)
2066
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2067
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
2068
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2069
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
2070
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2071
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
2072

    
2073
    # Invalid salt
2074
    for salt in list("-_@$,:;/\\ \t\n"):
2075
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
2076
                        cert_pem, self.KEY, "foo%sbar" % salt)
2077

    
2078
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
2079
                 utils.GenerateSecret(numbytes=4),
2080
                 utils.GenerateSecret(numbytes=16),
2081
                 "{123:456}".encode("hex")]:
2082
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
2083

    
2084
      self._Check(cert, salt, signed_pem)
2085

    
2086
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
2087
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
2088
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
2089
                               "lines----\n------ at\nthe end!"))
2090

    
2091
  def _Check(self, cert, salt, pem):
2092
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
2093
    self.assertEqual(salt, salt2)
2094
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
2095

    
2096
    # Other key
2097
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2098
                      pem, self.KEY_OTHER)
2099

    
2100

    
2101
class TestMakedirs(unittest.TestCase):
2102
  def setUp(self):
2103
    self.tmpdir = tempfile.mkdtemp()
2104

    
2105
  def tearDown(self):
2106
    shutil.rmtree(self.tmpdir)
2107

    
2108
  def testNonExisting(self):
2109
    path = utils.PathJoin(self.tmpdir, "foo")
2110
    utils.Makedirs(path)
2111
    self.assert_(os.path.isdir(path))
2112

    
2113
  def testExisting(self):
2114
    path = utils.PathJoin(self.tmpdir, "foo")
2115
    os.mkdir(path)
2116
    utils.Makedirs(path)
2117
    self.assert_(os.path.isdir(path))
2118

    
2119
  def testRecursiveNonExisting(self):
2120
    path = utils.PathJoin(self.tmpdir, "foo/bar/baz")
2121
    utils.Makedirs(path)
2122
    self.assert_(os.path.isdir(path))
2123

    
2124
  def testRecursiveExisting(self):
2125
    path = utils.PathJoin(self.tmpdir, "B/moo/xyz")
2126
    self.assertFalse(os.path.exists(path))
2127
    os.mkdir(utils.PathJoin(self.tmpdir, "B"))
2128
    utils.Makedirs(path)
2129
    self.assert_(os.path.isdir(path))
2130

    
2131

    
2132
class TestRetry(testutils.GanetiTestCase):
2133
  def setUp(self):
2134
    testutils.GanetiTestCase.setUp(self)
2135
    self.retries = 0
2136

    
2137
  @staticmethod
2138
  def _RaiseRetryAgain():
2139
    raise utils.RetryAgain()
2140

    
2141
  @staticmethod
2142
  def _RaiseRetryAgainWithArg(args):
2143
    raise utils.RetryAgain(*args)
2144

    
2145
  def _WrongNestedLoop(self):
2146
    return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
2147

    
2148
  def _RetryAndSucceed(self, retries):
2149
    if self.retries < retries:
2150
      self.retries += 1
2151
      raise utils.RetryAgain()
2152
    else:
2153
      return True
2154

    
2155
  def testRaiseTimeout(self):
2156
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
2157
                          self._RaiseRetryAgain, 0.01, 0.02)
2158
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
2159
                          self._RetryAndSucceed, 0.01, 0, args=[1])
2160
    self.failUnlessEqual(self.retries, 1)
2161

    
2162
  def testComplete(self):
2163
    self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
2164
    self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
2165
                         True)
2166
    self.failUnlessEqual(self.retries, 2)
2167

    
2168
  def testNestedLoop(self):
2169
    try:
2170
      self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
2171
                            self._WrongNestedLoop, 0, 1)
2172
    except utils.RetryTimeout:
2173
      self.fail("Didn't detect inner loop's exception")
2174

    
2175
  def testTimeoutArgument(self):
2176
    retry_arg="my_important_debugging_message"
2177
    try:
2178
      utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
2179
    except utils.RetryTimeout, err:
2180
      self.failUnlessEqual(err.args, (retry_arg, ))
2181
    else:
2182
      self.fail("Expected timeout didn't happen")
2183

    
2184
  def testRaiseInnerWithExc(self):
2185
    retry_arg="my_important_debugging_message"
2186
    try:
2187
      try:
2188
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2189
                    args=[[errors.GenericError(retry_arg, retry_arg)]])
2190
      except utils.RetryTimeout, err:
2191
        err.RaiseInner()
2192
      else:
2193
        self.fail("Expected timeout didn't happen")
2194
    except errors.GenericError, err:
2195
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2196
    else:
2197
      self.fail("Expected GenericError didn't happen")
2198

    
2199
  def testRaiseInnerWithMsg(self):
2200
    retry_arg="my_important_debugging_message"
2201
    try:
2202
      try:
2203
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2204
                    args=[[retry_arg, retry_arg]])
2205
      except utils.RetryTimeout, err:
2206
        err.RaiseInner()
2207
      else:
2208
        self.fail("Expected timeout didn't happen")
2209
    except utils.RetryTimeout, err:
2210
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2211
    else:
2212
      self.fail("Expected RetryTimeout didn't happen")
2213

    
2214

    
2215
class TestLineSplitter(unittest.TestCase):
2216
  def test(self):
2217
    lines = []
2218
    ls = utils.LineSplitter(lines.append)
2219
    ls.write("Hello World\n")
2220
    self.assertEqual(lines, [])
2221
    ls.write("Foo\n Bar\r\n ")
2222
    ls.write("Baz")
2223
    ls.write("Moo")
2224
    self.assertEqual(lines, [])
2225
    ls.flush()
2226
    self.assertEqual(lines, ["Hello World", "Foo", " Bar"])
2227
    ls.close()
2228
    self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"])
2229

    
2230
  def _testExtra(self, line, all_lines, p1, p2):
2231
    self.assertEqual(p1, 999)
2232
    self.assertEqual(p2, "extra")
2233
    all_lines.append(line)
2234

    
2235
  def testExtraArgsNoFlush(self):
2236
    lines = []
2237
    ls = utils.LineSplitter(self._testExtra, lines, 999, "extra")
2238
    ls.write("\n\nHello World\n")
2239
    ls.write("Foo\n Bar\r\n ")
2240
    ls.write("")
2241
    ls.write("Baz")
2242
    ls.write("Moo\n\nx\n")
2243
    self.assertEqual(lines, [])
2244
    ls.close()
2245
    self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo",
2246
                             "", "x"])
2247

    
2248

    
2249
class TestReadLockedPidFile(unittest.TestCase):
2250
  def setUp(self):
2251
    self.tmpdir = tempfile.mkdtemp()
2252

    
2253
  def tearDown(self):
2254
    shutil.rmtree(self.tmpdir)
2255

    
2256
  def testNonExistent(self):
2257
    path = utils.PathJoin(self.tmpdir, "nonexist")
2258
    self.assert_(utils.ReadLockedPidFile(path) is None)
2259

    
2260
  def testUnlocked(self):
2261
    path = utils.PathJoin(self.tmpdir, "pid")
2262
    utils.WriteFile(path, data="123")
2263
    self.assert_(utils.ReadLockedPidFile(path) is None)
2264

    
2265
  def testLocked(self):
2266
    path = utils.PathJoin(self.tmpdir, "pid")
2267
    utils.WriteFile(path, data="123")
2268

    
2269
    fl = utils.FileLock.Open(path)
2270
    try:
2271
      fl.Exclusive(blocking=True)
2272

    
2273
      self.assertEqual(utils.ReadLockedPidFile(path), 123)
2274
    finally:
2275
      fl.Close()
2276

    
2277
    self.assert_(utils.ReadLockedPidFile(path) is None)
2278

    
2279
  def testError(self):
2280
    path = utils.PathJoin(self.tmpdir, "foobar", "pid")
2281
    utils.WriteFile(utils.PathJoin(self.tmpdir, "foobar"), data="")
2282
    # open(2) should return ENOTDIR
2283
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
2284

    
2285

    
2286
class TestCertVerification(testutils.GanetiTestCase):
2287
  def setUp(self):
2288
    testutils.GanetiTestCase.setUp(self)
2289

    
2290
    self.tmpdir = tempfile.mkdtemp()
2291

    
2292
  def tearDown(self):
2293
    shutil.rmtree(self.tmpdir)
2294

    
2295
  def testVerifyCertificate(self):
2296
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
2297
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
2298
                                           cert_pem)
2299

    
2300
    # Not checking return value as this certificate is expired
2301
    utils.VerifyX509Certificate(cert, 30, 7)
2302

    
2303

    
2304
class TestVerifyCertificateInner(unittest.TestCase):
2305
  def test(self):
2306
    vci = utils._VerifyCertificateInner
2307

    
2308
    # Valid
2309
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
2310
                     (None, None))
2311

    
2312
    # Not yet valid
2313
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
2314
    self.assertEqual(errcode, utils.CERT_WARNING)
2315

    
2316
    # Expiring soon
2317
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
2318
    self.assertEqual(errcode, utils.CERT_ERROR)
2319

    
2320
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
2321
    self.assertEqual(errcode, utils.CERT_WARNING)
2322

    
2323
    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
2324
    self.assertEqual(errcode, None)
2325

    
2326
    # Expired
2327
    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
2328
    self.assertEqual(errcode, utils.CERT_ERROR)
2329

    
2330
    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
2331
    self.assertEqual(errcode, utils.CERT_ERROR)
2332

    
2333
    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
2334
    self.assertEqual(errcode, utils.CERT_ERROR)
2335

    
2336
    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
2337
    self.assertEqual(errcode, utils.CERT_ERROR)
2338

    
2339

    
2340
class TestHmacFunctions(unittest.TestCase):
2341
  # Digests can be checked with "openssl sha1 -hmac $key"
2342
  def testSha1Hmac(self):
2343
    self.assertEqual(utils.Sha1Hmac("", ""),
2344
                     "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d")
2345
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World"),
2346
                     "ef4f3bda82212ecb2f7ce868888a19092481f1fd")
2347
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", ""),
2348
                     "f904c2476527c6d3e6609ab683c66fa0652cb1dc")
2349

    
2350
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
2351
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", longtext),
2352
                     "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54")
2353

    
2354
  def testSha1HmacSalt(self):
2355
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc0"),
2356
                     "4999bf342470eadb11dfcd24ca5680cf9fd7cdce")
2357
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc9"),
2358
                     "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8")
2359
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World", salt="xyz0"),
2360
                     "7f264f8114c9066afc9bb7636e1786d996d3cc0d")
2361

    
2362
  def testVerifySha1Hmac(self):
2363
    self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b"
2364
                                               "7d64b71fb76370690e1d")))
2365
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2366
                                      ("f904c2476527c6d3e660"
2367
                                       "9ab683c66fa0652cb1dc")))
2368

    
2369
    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
2370
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", digest))
2371
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2372
                                      digest.lower()))
2373
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2374
                                      digest.upper()))
2375
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2376
                                      digest.title()))
2377

    
2378
  def testVerifySha1HmacSalt(self):
2379
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2380
                                      ("17a4adc34d69c0d367d4"
2381
                                       "ffbef96fd41d4df7a6e8"),
2382
                                      salt="abc9"))
2383
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2384
                                      ("7f264f8114c9066afc9b"
2385
                                       "b7636e1786d996d3cc0d"),
2386
                                      salt="xyz0"))
2387

    
2388

    
2389
class TestIgnoreSignals(unittest.TestCase):
2390
  """Test the IgnoreSignals decorator"""
2391

    
2392
  @staticmethod
2393
  def _Raise(exception):
2394
    raise exception
2395

    
2396
  @staticmethod
2397
  def _Return(rval):
2398
    return rval
2399

    
2400
  def testIgnoreSignals(self):
2401
    sock_err_intr = socket.error(errno.EINTR, "Message")
2402
    sock_err_inval = socket.error(errno.EINVAL, "Message")
2403

    
2404
    env_err_intr = EnvironmentError(errno.EINTR, "Message")
2405
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")
2406

    
2407
    self.assertRaises(socket.error, self._Raise, sock_err_intr)
2408
    self.assertRaises(socket.error, self._Raise, sock_err_inval)
2409
    self.assertRaises(EnvironmentError, self._Raise, env_err_intr)
2410
    self.assertRaises(EnvironmentError, self._Raise, env_err_inval)
2411

    
2412
    self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None)
2413
    self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None)
2414
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
2415
                      sock_err_inval)
2416
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
2417
                      env_err_inval)
2418

    
2419
    self.assertEquals(utils.IgnoreSignals(self._Return, True), True)
2420
    self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33)
2421

    
2422

    
2423
class TestEnsureDirs(unittest.TestCase):
2424
  """Tests for EnsureDirs"""
2425

    
2426
  def setUp(self):
2427
    self.dir = tempfile.mkdtemp()
2428
    self.old_umask = os.umask(0777)
2429

    
2430
  def testEnsureDirs(self):
2431
    utils.EnsureDirs([
2432
        (utils.PathJoin(self.dir, "foo"), 0777),
2433
        (utils.PathJoin(self.dir, "bar"), 0000),
2434
        ])
2435
    self.assertEquals(os.stat(utils.PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2436
    self.assertEquals(os.stat(utils.PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2437

    
2438
  def tearDown(self):
2439
    os.rmdir(utils.PathJoin(self.dir, "foo"))
2440
    os.rmdir(utils.PathJoin(self.dir, "bar"))
2441
    os.rmdir(self.dir)
2442
    os.umask(self.old_umask)
2443

    
2444

    
2445
class TestFormatSeconds(unittest.TestCase):
2446
  def test(self):
2447
    self.assertEqual(utils.FormatSeconds(1), "1s")
2448
    self.assertEqual(utils.FormatSeconds(3600), "1h 0m 0s")
2449
    self.assertEqual(utils.FormatSeconds(3599), "59m 59s")
2450
    self.assertEqual(utils.FormatSeconds(7200), "2h 0m 0s")
2451
    self.assertEqual(utils.FormatSeconds(7201), "2h 0m 1s")
2452
    self.assertEqual(utils.FormatSeconds(7281), "2h 1m 21s")
2453
    self.assertEqual(utils.FormatSeconds(29119), "8h 5m 19s")
2454
    self.assertEqual(utils.FormatSeconds(19431228), "224d 21h 33m 48s")
2455
    self.assertEqual(utils.FormatSeconds(-1), "-1s")
2456
    self.assertEqual(utils.FormatSeconds(-282), "-282s")
2457
    self.assertEqual(utils.FormatSeconds(-29119), "-29119s")
2458

    
2459
  def testFloat(self):
2460
    self.assertEqual(utils.FormatSeconds(1.3), "1s")
2461
    self.assertEqual(utils.FormatSeconds(1.9), "2s")
2462
    self.assertEqual(utils.FormatSeconds(3912.12311), "1h 5m 12s")
2463
    self.assertEqual(utils.FormatSeconds(3912.8), "1h 5m 13s")
2464

    
2465

    
2466
class RunIgnoreProcessNotFound(unittest.TestCase):
2467
  @staticmethod
2468
  def _WritePid(fd):
2469
    os.write(fd, str(os.getpid()))
2470
    os.close(fd)
2471
    return True
2472

    
2473
  def test(self):
2474
    (pid_read_fd, pid_write_fd) = os.pipe()
2475

    
2476
    # Start short-lived process which writes its PID to pipe
2477
    self.assert_(utils.RunInSeparateProcess(self._WritePid, pid_write_fd))
2478
    os.close(pid_write_fd)
2479

    
2480
    # Read PID from pipe
2481
    pid = int(os.read(pid_read_fd, 1024))
2482
    os.close(pid_read_fd)
2483

    
2484
    # Try to send signal to process which exited recently
2485
    self.assertFalse(utils.IgnoreProcessNotFound(os.kill, pid, 0))
2486

    
2487

    
2488
class TestIsValidIP4(unittest.TestCase):
2489
  def test(self):
2490
    self.assert_(utils.IsValidIP4("127.0.0.1"))
2491
    self.assert_(utils.IsValidIP4("0.0.0.0"))
2492
    self.assert_(utils.IsValidIP4("255.255.255.255"))
2493
    self.assertFalse(utils.IsValidIP4("0"))
2494
    self.assertFalse(utils.IsValidIP4("1"))
2495
    self.assertFalse(utils.IsValidIP4("1.1.1"))
2496
    self.assertFalse(utils.IsValidIP4("255.255.255.256"))
2497
    self.assertFalse(utils.IsValidIP4("::1"))
2498

    
2499

    
2500
class TestIsValidIP6(unittest.TestCase):
2501
  def test(self):
2502
    self.assert_(utils.IsValidIP6("::"))
2503
    self.assert_(utils.IsValidIP6("::1"))
2504
    self.assert_(utils.IsValidIP6("1" + (":1" * 7)))
2505
    self.assert_(utils.IsValidIP6("ffff" + (":ffff" * 7)))
2506
    self.assertFalse(utils.IsValidIP6("0"))
2507
    self.assertFalse(utils.IsValidIP6(":1"))
2508
    self.assertFalse(utils.IsValidIP6("f" + (":f" * 6)))
2509
    self.assertFalse(utils.IsValidIP6("fffg" + (":ffff" * 7)))
2510
    self.assertFalse(utils.IsValidIP6("fffff" + (":ffff" * 7)))
2511
    self.assertFalse(utils.IsValidIP6("1" + (":1" * 8)))
2512
    self.assertFalse(utils.IsValidIP6("127.0.0.1"))
2513

    
2514

    
2515
class TestIsValidIP(unittest.TestCase):
2516
  def test(self):
2517
    self.assert_(utils.IsValidIP("0.0.0.0"))
2518
    self.assert_(utils.IsValidIP("127.0.0.1"))
2519
    self.assert_(utils.IsValidIP("::"))
2520
    self.assert_(utils.IsValidIP("::1"))
2521
    self.assertFalse(utils.IsValidIP("0"))
2522
    self.assertFalse(utils.IsValidIP("1.1.1.256"))
2523
    self.assertFalse(utils.IsValidIP("a:g::1"))
2524

    
2525

    
2526
class TestGetAddressFamily(unittest.TestCase):
2527
  def test(self):
2528
    self.assertEqual(utils.GetAddressFamily("127.0.0.1"), socket.AF_INET)
2529
    self.assertEqual(utils.GetAddressFamily("10.2.0.127"), socket.AF_INET)
2530
    self.assertEqual(utils.GetAddressFamily("::1"), socket.AF_INET6)
2531
    self.assertEqual(utils.GetAddressFamily("fe80::a00:27ff:fe08:5048"),
2532
                     socket.AF_INET6)
2533
    self.assertRaises(errors.GenericError, utils.GetAddressFamily, "0")
2534

    
2535

    
2536
if __name__ == '__main__':
2537
  testutils.GanetiTestProgram()