Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ 68857643

History | View | Annotate | Download (59.4 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import unittest
25
import os
26
import time
27
import tempfile
28
import os.path
29
import os
30
import stat
31
import md5
32
import signal
33
import socket
34
import shutil
35
import re
36
import select
37
import string
38
import fcntl
39
import OpenSSL
40
import warnings
41
import distutils.version
42
import glob
43

    
44
import ganeti
45
import testutils
46
from ganeti import constants
47
from ganeti import utils
48
from ganeti import errors
49
from ganeti.utils import IsProcessAlive, RunCmd, \
50
     RemoveFile, MatchNameComponent, FormatUnit, \
51
     ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
52
     ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
53
     SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
54
     TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
55
     UnescapeAndSplit, RunParts, PathJoin, HostInfo
56

    
57
from ganeti.errors import LockError, UnitParseError, GenericError, \
58
     ProgrammerError, OpPrereqError
59

    
60

    
61
class TestIsProcessAlive(unittest.TestCase):
62
  """Testing case for IsProcessAlive"""
63

    
64
  def testExists(self):
65
    mypid = os.getpid()
66
    self.assert_(IsProcessAlive(mypid),
67
                 "can't find myself running")
68

    
69
  def testNotExisting(self):
70
    pid_non_existing = os.fork()
71
    if pid_non_existing == 0:
72
      os._exit(0)
73
    elif pid_non_existing < 0:
74
      raise SystemError("can't fork")
75
    os.waitpid(pid_non_existing, 0)
76
    self.assert_(not IsProcessAlive(pid_non_existing),
77
                 "nonexisting process detected")
78

    
79

    
80
class TestPidFileFunctions(unittest.TestCase):
81
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
82

    
83
  def setUp(self):
84
    self.dir = tempfile.mkdtemp()
85
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
86
    utils.DaemonPidFileName = self.f_dpn
87

    
88
  def testPidFileFunctions(self):
89
    pid_file = self.f_dpn('test')
90
    utils.WritePidFile('test')
91
    self.failUnless(os.path.exists(pid_file),
92
                    "PID file should have been created")
93
    read_pid = utils.ReadPidFile(pid_file)
94
    self.failUnlessEqual(read_pid, os.getpid())
95
    self.failUnless(utils.IsProcessAlive(read_pid))
96
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
97
    utils.RemovePidFile('test')
98
    self.failIf(os.path.exists(pid_file),
99
                "PID file should not exist anymore")
100
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
101
                         "ReadPidFile should return 0 for missing pid file")
102
    fh = open(pid_file, "w")
103
    fh.write("blah\n")
104
    fh.close()
105
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
106
                         "ReadPidFile should return 0 for invalid pid file")
107
    utils.RemovePidFile('test')
108
    self.failIf(os.path.exists(pid_file),
109
                "PID file should not exist anymore")
110

    
111
  def testKill(self):
112
    pid_file = self.f_dpn('child')
113
    r_fd, w_fd = os.pipe()
114
    new_pid = os.fork()
115
    if new_pid == 0: #child
116
      utils.WritePidFile('child')
117
      os.write(w_fd, 'a')
118
      signal.pause()
119
      os._exit(0)
120
      return
121
    # else we are in the parent
122
    # wait until the child has written the pid file
123
    os.read(r_fd, 1)
124
    read_pid = utils.ReadPidFile(pid_file)
125
    self.failUnlessEqual(read_pid, new_pid)
126
    self.failUnless(utils.IsProcessAlive(new_pid))
127
    utils.KillProcess(new_pid, waitpid=True)
128
    self.failIf(utils.IsProcessAlive(new_pid))
129
    utils.RemovePidFile('child')
130
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)
131

    
132
  def tearDown(self):
133
    for name in os.listdir(self.dir):
134
      os.unlink(os.path.join(self.dir, name))
135
    os.rmdir(self.dir)
136

    
137

    
138
class TestRunCmd(testutils.GanetiTestCase):
139
  """Testing case for the RunCmd function"""
140

    
141
  def setUp(self):
142
    testutils.GanetiTestCase.setUp(self)
143
    self.magic = time.ctime() + " ganeti test"
144
    self.fname = self._CreateTempFile()
145

    
146
  def testOk(self):
147
    """Test successful exit code"""
148
    result = RunCmd("/bin/sh -c 'exit 0'")
149
    self.assertEqual(result.exit_code, 0)
150
    self.assertEqual(result.output, "")
151

    
152
  def testFail(self):
153
    """Test fail exit code"""
154
    result = RunCmd("/bin/sh -c 'exit 1'")
155
    self.assertEqual(result.exit_code, 1)
156
    self.assertEqual(result.output, "")
157

    
158
  def testStdout(self):
159
    """Test standard output"""
160
    cmd = 'echo -n "%s"' % self.magic
161
    result = RunCmd("/bin/sh -c '%s'" % cmd)
162
    self.assertEqual(result.stdout, self.magic)
163
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
164
    self.assertEqual(result.output, "")
165
    self.assertFileContent(self.fname, self.magic)
166

    
167
  def testStderr(self):
168
    """Test standard error"""
169
    cmd = 'echo -n "%s"' % self.magic
170
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
171
    self.assertEqual(result.stderr, self.magic)
172
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
173
    self.assertEqual(result.output, "")
174
    self.assertFileContent(self.fname, self.magic)
175

    
176
  def testCombined(self):
177
    """Test combined output"""
178
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
179
    expected = "A" + self.magic + "B" + self.magic
180
    result = RunCmd("/bin/sh -c '%s'" % cmd)
181
    self.assertEqual(result.output, expected)
182
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
183
    self.assertEqual(result.output, "")
184
    self.assertFileContent(self.fname, expected)
185

    
186
  def testSignal(self):
187
    """Test signal"""
188
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
189
    self.assertEqual(result.signal, 15)
190
    self.assertEqual(result.output, "")
191

    
192
  def testListRun(self):
193
    """Test list runs"""
194
    result = RunCmd(["true"])
195
    self.assertEqual(result.signal, None)
196
    self.assertEqual(result.exit_code, 0)
197
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
198
    self.assertEqual(result.signal, None)
199
    self.assertEqual(result.exit_code, 1)
200
    result = RunCmd(["echo", "-n", self.magic])
201
    self.assertEqual(result.signal, None)
202
    self.assertEqual(result.exit_code, 0)
203
    self.assertEqual(result.stdout, self.magic)
204

    
205
  def testFileEmptyOutput(self):
206
    """Test file output"""
207
    result = RunCmd(["true"], output=self.fname)
208
    self.assertEqual(result.signal, None)
209
    self.assertEqual(result.exit_code, 0)
210
    self.assertFileContent(self.fname, "")
211

    
212
  def testLang(self):
213
    """Test locale environment"""
214
    old_env = os.environ.copy()
215
    try:
216
      os.environ["LANG"] = "en_US.UTF-8"
217
      os.environ["LC_ALL"] = "en_US.UTF-8"
218
      result = RunCmd(["locale"])
219
      for line in result.output.splitlines():
220
        key, value = line.split("=", 1)
221
        # Ignore these variables, they're overridden by LC_ALL
222
        if key == "LANG" or key == "LANGUAGE":
223
          continue
224
        self.failIf(value and value != "C" and value != '"C"',
225
            "Variable %s is set to the invalid value '%s'" % (key, value))
226
    finally:
227
      os.environ = old_env
228

    
229
  def testDefaultCwd(self):
230
    """Test default working directory"""
231
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
232

    
233
  def testCwd(self):
234
    """Test default working directory"""
235
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
236
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
237
    cwd = os.getcwd()
238
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
239

    
240
  def testResetEnv(self):
241
    """Test environment reset functionality"""
242
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
243
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
244
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
245

    
246

    
247
class TestRunParts(unittest.TestCase):
248
  """Testing case for the RunParts function"""
249

    
250
  def setUp(self):
251
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
252

    
253
  def tearDown(self):
254
    shutil.rmtree(self.rundir)
255

    
256
  def testEmpty(self):
257
    """Test on an empty dir"""
258
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
259

    
260
  def testSkipWrongName(self):
261
    """Test that wrong files are skipped"""
262
    fname = os.path.join(self.rundir, "00test.dot")
263
    utils.WriteFile(fname, data="")
264
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
265
    relname = os.path.basename(fname)
266
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
267
                         [(relname, constants.RUNPARTS_SKIP, None)])
268

    
269
  def testSkipNonExec(self):
270
    """Test that non executable files are skipped"""
271
    fname = os.path.join(self.rundir, "00test")
272
    utils.WriteFile(fname, data="")
273
    relname = os.path.basename(fname)
274
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
275
                         [(relname, constants.RUNPARTS_SKIP, None)])
276

    
277
  def testError(self):
278
    """Test error on a broken executable"""
279
    fname = os.path.join(self.rundir, "00test")
280
    utils.WriteFile(fname, data="")
281
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
282
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
283
    self.failUnlessEqual(relname, os.path.basename(fname))
284
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
285
    self.failUnless(error)
286

    
287
  def testSorted(self):
288
    """Test executions are sorted"""
289
    files = []
290
    files.append(os.path.join(self.rundir, "64test"))
291
    files.append(os.path.join(self.rundir, "00test"))
292
    files.append(os.path.join(self.rundir, "42test"))
293

    
294
    for fname in files:
295
      utils.WriteFile(fname, data="")
296

    
297
    results = RunParts(self.rundir, reset_env=True)
298

    
299
    for fname in sorted(files):
300
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
301

    
302
  def testOk(self):
303
    """Test correct execution"""
304
    fname = os.path.join(self.rundir, "00test")
305
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
306
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
307
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
308
    self.failUnlessEqual(relname, os.path.basename(fname))
309
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
310
    self.failUnlessEqual(runresult.stdout, "ciao")
311

    
312
  def testRunFail(self):
313
    """Test correct execution, with run failure"""
314
    fname = os.path.join(self.rundir, "00test")
315
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
316
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
317
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
318
    self.failUnlessEqual(relname, os.path.basename(fname))
319
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
320
    self.failUnlessEqual(runresult.exit_code, 1)
321
    self.failUnless(runresult.failed)
322

    
323
  def testRunMix(self):
324
    files = []
325
    files.append(os.path.join(self.rundir, "00test"))
326
    files.append(os.path.join(self.rundir, "42test"))
327
    files.append(os.path.join(self.rundir, "64test"))
328
    files.append(os.path.join(self.rundir, "99test"))
329

    
330
    files.sort()
331

    
332
    # 1st has errors in execution
333
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
334
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
335

    
336
    # 2nd is skipped
337
    utils.WriteFile(files[1], data="")
338

    
339
    # 3rd cannot execute properly
340
    utils.WriteFile(files[2], data="")
341
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
342

    
343
    # 4th execs
344
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
345
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
346

    
347
    results = RunParts(self.rundir, reset_env=True)
348

    
349
    (relname, status, runresult) = results[0]
350
    self.failUnlessEqual(relname, os.path.basename(files[0]))
351
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
352
    self.failUnlessEqual(runresult.exit_code, 1)
353
    self.failUnless(runresult.failed)
354

    
355
    (relname, status, runresult) = results[1]
356
    self.failUnlessEqual(relname, os.path.basename(files[1]))
357
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
358
    self.failUnlessEqual(runresult, None)
359

    
360
    (relname, status, runresult) = results[2]
361
    self.failUnlessEqual(relname, os.path.basename(files[2]))
362
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
363
    self.failUnless(runresult)
364

    
365
    (relname, status, runresult) = results[3]
366
    self.failUnlessEqual(relname, os.path.basename(files[3]))
367
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
368
    self.failUnlessEqual(runresult.output, "ciao")
369
    self.failUnlessEqual(runresult.exit_code, 0)
370
    self.failUnless(not runresult.failed)
371

    
372

    
373
class TestStartDaemon(testutils.GanetiTestCase):
374
  def setUp(self):
375
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
376
    self.tmpfile = os.path.join(self.tmpdir, "test")
377

    
378
  def tearDown(self):
379
    shutil.rmtree(self.tmpdir)
380

    
381
  def testShell(self):
382
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
383
    self._wait(self.tmpfile, 60.0, "Hello World")
384

    
385
  def testShellOutput(self):
386
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
387
    self._wait(self.tmpfile, 60.0, "Hello World")
388

    
389
  def testNoShellNoOutput(self):
390
    utils.StartDaemon(["pwd"])
391

    
392
  def testNoShellNoOutputTouch(self):
393
    testfile = os.path.join(self.tmpdir, "check")
394
    self.failIf(os.path.exists(testfile))
395
    utils.StartDaemon(["touch", testfile])
396
    self._wait(testfile, 60.0, "")
397

    
398
  def testNoShellOutput(self):
399
    utils.StartDaemon(["pwd"], output=self.tmpfile)
400
    self._wait(self.tmpfile, 60.0, "/")
401

    
402
  def testNoShellOutputCwd(self):
403
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
404
    self._wait(self.tmpfile, 60.0, os.getcwd())
405

    
406
  def testShellEnv(self):
407
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
408
                      env={ "GNT_TEST_VAR": "Hello World", })
409
    self._wait(self.tmpfile, 60.0, "Hello World")
410

    
411
  def testNoShellEnv(self):
412
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
413
                      env={ "GNT_TEST_VAR": "Hello World", })
414
    self._wait(self.tmpfile, 60.0, "Hello World")
415

    
416
  def testOutputFd(self):
417
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
418
    try:
419
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
420
    finally:
421
      os.close(fd)
422
    self._wait(self.tmpfile, 60.0, os.getcwd())
423

    
424
  def testPid(self):
425
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
426
    self._wait(self.tmpfile, 60.0, str(pid))
427

    
428
  def testPidFile(self):
429
    pidfile = os.path.join(self.tmpdir, "pid")
430
    checkfile = os.path.join(self.tmpdir, "abort")
431

    
432
    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
433
                            output=self.tmpfile)
434
    try:
435
      fd = os.open(pidfile, os.O_RDONLY)
436
      try:
437
        # Check file is locked
438
        self.assertRaises(errors.LockError, utils.LockFile, fd)
439

    
440
        pidtext = os.read(fd, 100)
441
      finally:
442
        os.close(fd)
443

    
444
      self.assertEqual(int(pidtext.strip()), pid)
445

    
446
      self.assert_(utils.IsProcessAlive(pid))
447
    finally:
448
      # No matter what happens, kill daemon
449
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
450
      self.failIf(utils.IsProcessAlive(pid))
451

    
452
    self.assertEqual(utils.ReadFile(self.tmpfile), "")
453

    
454
  def _wait(self, path, timeout, expected):
455
    # Due to the asynchronous nature of daemon processes, polling is necessary.
456
    # A timeout makes sure the test doesn't hang forever.
457
    def _CheckFile():
458
      if not (os.path.isfile(path) and
459
              utils.ReadFile(path).strip() == expected):
460
        raise utils.RetryAgain()
461

    
462
    try:
463
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
464
    except utils.RetryTimeout:
465
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
466
                " didn't write the correct output" % timeout)
467

    
468
  def testError(self):
469
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
470
                      ["./does-NOT-EXIST/here/0123456789"])
471
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
472
                      ["./does-NOT-EXIST/here/0123456789"],
473
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
474
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
475
                      ["./does-NOT-EXIST/here/0123456789"],
476
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
477
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
478
                      ["./does-NOT-EXIST/here/0123456789"],
479
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
480

    
481
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
482
    try:
483
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
484
                        ["./does-NOT-EXIST/here/0123456789"],
485
                        output=self.tmpfile, output_fd=fd)
486
    finally:
487
      os.close(fd)
488

    
489

    
490
class TestSetCloseOnExecFlag(unittest.TestCase):
491
  """Tests for SetCloseOnExecFlag"""
492

    
493
  def setUp(self):
494
    self.tmpfile = tempfile.TemporaryFile()
495

    
496
  def testEnable(self):
497
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
498
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
499
                    fcntl.FD_CLOEXEC)
500

    
501
  def testDisable(self):
502
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
503
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
504
                fcntl.FD_CLOEXEC)
505

    
506

    
507
class TestSetNonblockFlag(unittest.TestCase):
508
  def setUp(self):
509
    self.tmpfile = tempfile.TemporaryFile()
510

    
511
  def testEnable(self):
512
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
513
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
514
                    os.O_NONBLOCK)
515

    
516
  def testDisable(self):
517
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
518
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
519
                os.O_NONBLOCK)
520

    
521

    
522
class TestRemoveFile(unittest.TestCase):
523
  """Test case for the RemoveFile function"""
524

    
525
  def setUp(self):
526
    """Create a temp dir and file for each case"""
527
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
528
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
529
    os.close(fd)
530

    
531
  def tearDown(self):
532
    if os.path.exists(self.tmpfile):
533
      os.unlink(self.tmpfile)
534
    os.rmdir(self.tmpdir)
535

    
536
  def testIgnoreDirs(self):
537
    """Test that RemoveFile() ignores directories"""
538
    self.assertEqual(None, RemoveFile(self.tmpdir))
539

    
540
  def testIgnoreNotExisting(self):
541
    """Test that RemoveFile() ignores non-existing files"""
542
    RemoveFile(self.tmpfile)
543
    RemoveFile(self.tmpfile)
544

    
545
  def testRemoveFile(self):
546
    """Test that RemoveFile does remove a file"""
547
    RemoveFile(self.tmpfile)
548
    if os.path.exists(self.tmpfile):
549
      self.fail("File '%s' not removed" % self.tmpfile)
550

    
551
  def testRemoveSymlink(self):
552
    """Test that RemoveFile does remove symlinks"""
553
    symlink = self.tmpdir + "/symlink"
554
    os.symlink("no-such-file", symlink)
555
    RemoveFile(symlink)
556
    if os.path.exists(symlink):
557
      self.fail("File '%s' not removed" % symlink)
558
    os.symlink(self.tmpfile, symlink)
559
    RemoveFile(symlink)
560
    if os.path.exists(symlink):
561
      self.fail("File '%s' not removed" % symlink)
562

    
563

    
564
class TestRename(unittest.TestCase):
565
  """Test case for RenameFile"""
566

    
567
  def setUp(self):
568
    """Create a temporary directory"""
569
    self.tmpdir = tempfile.mkdtemp()
570
    self.tmpfile = os.path.join(self.tmpdir, "test1")
571

    
572
    # Touch the file
573
    open(self.tmpfile, "w").close()
574

    
575
  def tearDown(self):
576
    """Remove temporary directory"""
577
    shutil.rmtree(self.tmpdir)
578

    
579
  def testSimpleRename1(self):
580
    """Simple rename 1"""
581
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
582
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
583

    
584
  def testSimpleRename2(self):
585
    """Simple rename 2"""
586
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
587
                     mkdir=True)
588
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
589

    
590
  def testRenameMkdir(self):
591
    """Rename with mkdir"""
592
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
593
                     mkdir=True)
594
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
595
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
596

    
597
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
598
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
599
                     mkdir=True)
600
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
601
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
602
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
603

    
604

    
605
class TestMatchNameComponent(unittest.TestCase):
606
  """Test case for the MatchNameComponent function"""
607

    
608
  def testEmptyList(self):
609
    """Test that there is no match against an empty list"""
610

    
611
    self.failUnlessEqual(MatchNameComponent("", []), None)
612
    self.failUnlessEqual(MatchNameComponent("test", []), None)
613

    
614
  def testSingleMatch(self):
615
    """Test that a single match is performed correctly"""
616
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
617
    for key in "test2", "test2.example", "test2.example.com":
618
      self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
619

    
620
  def testMultipleMatches(self):
621
    """Test that a multiple match is returned as None"""
622
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
623
    for key in "test1", "test1.example":
624
      self.failUnlessEqual(MatchNameComponent(key, mlist), None)
625

    
626
  def testFullMatch(self):
627
    """Test that a full match is returned correctly"""
628
    key1 = "test1"
629
    key2 = "test1.example"
630
    mlist = [key2, key2 + ".com"]
631
    self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
632
    self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
633

    
634
  def testCaseInsensitivePartialMatch(self):
635
    """Test for the case_insensitive keyword"""
636
    mlist = ["test1.example.com", "test2.example.net"]
637
    self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
638
                     "test2.example.net")
639
    self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
640
                     "test2.example.net")
641
    self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
642
                     "test2.example.net")
643
    self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
644
                     "test2.example.net")
645

    
646

    
647
  def testCaseInsensitiveFullMatch(self):
648
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
649
    # Between the two ts1 a full string match non-case insensitive should work
650
    self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
651
                     None)
652
    self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
653
                     "ts1.ex")
654
    self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
655
                     "ts1.ex")
656
    # Between the two ts2 only case differs, so only case-match works
657
    self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
658
                     "ts2.ex")
659
    self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
660
                     "Ts2.ex")
661
    self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
662
                     None)
663

    
664

    
665
class TestTimestampForFilename(unittest.TestCase):
666
  def test(self):
667
    self.assert_("." not in utils.TimestampForFilename())
668
    self.assert_(":" not in utils.TimestampForFilename())
669

    
670

    
671
class TestCreateBackup(testutils.GanetiTestCase):
672
  def setUp(self):
673
    testutils.GanetiTestCase.setUp(self)
674

    
675
    self.tmpdir = tempfile.mkdtemp()
676

    
677
  def tearDown(self):
678
    testutils.GanetiTestCase.tearDown(self)
679

    
680
    shutil.rmtree(self.tmpdir)
681

    
682
  def testEmpty(self):
683
    filename = utils.PathJoin(self.tmpdir, "config.data")
684
    utils.WriteFile(filename, data="")
685
    bname = utils.CreateBackup(filename)
686
    self.assertFileContent(bname, "")
687
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
688
    utils.CreateBackup(filename)
689
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
690
    utils.CreateBackup(filename)
691
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
692

    
693
    fifoname = utils.PathJoin(self.tmpdir, "fifo")
694
    os.mkfifo(fifoname)
695
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
696

    
697
  def testContent(self):
698
    bkpcount = 0
699
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
700
      for rep in [1, 2, 10, 127]:
701
        testdata = data * rep
702

    
703
        filename = utils.PathJoin(self.tmpdir, "test.data_")
704
        utils.WriteFile(filename, data=testdata)
705
        self.assertFileContent(filename, testdata)
706

    
707
        for _ in range(3):
708
          bname = utils.CreateBackup(filename)
709
          bkpcount += 1
710
          self.assertFileContent(bname, testdata)
711
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
712

    
713

    
714
class TestFormatUnit(unittest.TestCase):
715
  """Test case for the FormatUnit function"""
716

    
717
  def testMiB(self):
718
    self.assertEqual(FormatUnit(1, 'h'), '1M')
719
    self.assertEqual(FormatUnit(100, 'h'), '100M')
720
    self.assertEqual(FormatUnit(1023, 'h'), '1023M')
721

    
722
    self.assertEqual(FormatUnit(1, 'm'), '1')
723
    self.assertEqual(FormatUnit(100, 'm'), '100')
724
    self.assertEqual(FormatUnit(1023, 'm'), '1023')
725

    
726
    self.assertEqual(FormatUnit(1024, 'm'), '1024')
727
    self.assertEqual(FormatUnit(1536, 'm'), '1536')
728
    self.assertEqual(FormatUnit(17133, 'm'), '17133')
729
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
730

    
731
  def testGiB(self):
732
    self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
733
    self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
734
    self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
735
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
736

    
737
    self.assertEqual(FormatUnit(1024, 'g'), '1.0')
738
    self.assertEqual(FormatUnit(1536, 'g'), '1.5')
739
    self.assertEqual(FormatUnit(17133, 'g'), '16.7')
740
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
741

    
742
    self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
743
    self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
744
    self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
745

    
746
  def testTiB(self):
747
    self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
748
    self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
749
    self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
750

    
751
    self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
752
    self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
753
    self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
754

    
755
class TestParseUnit(unittest.TestCase):
756
  """Test case for the ParseUnit function"""
757

    
758
  SCALES = (('', 1),
759
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
760
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
761
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
762

    
763
  def testRounding(self):
764
    self.assertEqual(ParseUnit('0'), 0)
765
    self.assertEqual(ParseUnit('1'), 4)
766
    self.assertEqual(ParseUnit('2'), 4)
767
    self.assertEqual(ParseUnit('3'), 4)
768

    
769
    self.assertEqual(ParseUnit('124'), 124)
770
    self.assertEqual(ParseUnit('125'), 128)
771
    self.assertEqual(ParseUnit('126'), 128)
772
    self.assertEqual(ParseUnit('127'), 128)
773
    self.assertEqual(ParseUnit('128'), 128)
774
    self.assertEqual(ParseUnit('129'), 132)
775
    self.assertEqual(ParseUnit('130'), 132)
776

    
777
  def testFloating(self):
778
    self.assertEqual(ParseUnit('0'), 0)
779
    self.assertEqual(ParseUnit('0.5'), 4)
780
    self.assertEqual(ParseUnit('1.75'), 4)
781
    self.assertEqual(ParseUnit('1.99'), 4)
782
    self.assertEqual(ParseUnit('2.00'), 4)
783
    self.assertEqual(ParseUnit('2.01'), 4)
784
    self.assertEqual(ParseUnit('3.99'), 4)
785
    self.assertEqual(ParseUnit('4.00'), 4)
786
    self.assertEqual(ParseUnit('4.01'), 8)
787
    self.assertEqual(ParseUnit('1.5G'), 1536)
788
    self.assertEqual(ParseUnit('1.8G'), 1844)
789
    self.assertEqual(ParseUnit('8.28T'), 8682212)
790

    
791
  def testSuffixes(self):
792
    for sep in ('', ' ', '   ', "\t", "\t "):
793
      for suffix, scale in TestParseUnit.SCALES:
794
        for func in (lambda x: x, str.lower, str.upper):
795
          self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
796
                           1024 * scale)
797

    
798
  def testInvalidInput(self):
799
    for sep in ('-', '_', ',', 'a'):
800
      for suffix, _ in TestParseUnit.SCALES:
801
        self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)
802

    
803
    for suffix, _ in TestParseUnit.SCALES:
804
      self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
805

    
806

    
807
class TestSshKeys(testutils.GanetiTestCase):
808
  """Test case for the AddAuthorizedKey function"""
809

    
810
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
811
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
812
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
813

    
814
  def setUp(self):
815
    testutils.GanetiTestCase.setUp(self)
816
    self.tmpname = self._CreateTempFile()
817
    handle = open(self.tmpname, 'w')
818
    try:
819
      handle.write("%s\n" % TestSshKeys.KEY_A)
820
      handle.write("%s\n" % TestSshKeys.KEY_B)
821
    finally:
822
      handle.close()
823

    
824
  def testAddingNewKey(self):
825
    AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
826

    
827
    self.assertFileContent(self.tmpname,
828
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
829
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
830
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
831
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
832

    
833
  def testAddingAlmostButNotCompletelyTheSameKey(self):
834
    AddAuthorizedKey(self.tmpname,
835
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
836

    
837
    self.assertFileContent(self.tmpname,
838
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
839
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
840
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
841
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
842

    
843
  def testAddingExistingKeyWithSomeMoreSpaces(self):
844
    AddAuthorizedKey(self.tmpname,
845
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
846

    
847
    self.assertFileContent(self.tmpname,
848
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
849
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
850
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
851

    
852
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
853
    RemoveAuthorizedKey(self.tmpname,
854
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
855

    
856
    self.assertFileContent(self.tmpname,
857
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
858
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
859

    
860
  def testRemovingNonExistingKey(self):
861
    RemoveAuthorizedKey(self.tmpname,
862
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
863

    
864
    self.assertFileContent(self.tmpname,
865
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
866
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
867
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
868

    
869

    
870
class TestEtcHosts(testutils.GanetiTestCase):
871
  """Test functions modifying /etc/hosts"""
872

    
873
  def setUp(self):
874
    testutils.GanetiTestCase.setUp(self)
875
    self.tmpname = self._CreateTempFile()
876
    handle = open(self.tmpname, 'w')
877
    try:
878
      handle.write('# This is a test file for /etc/hosts\n')
879
      handle.write('127.0.0.1\tlocalhost\n')
880
      handle.write('192.168.1.1 router gw\n')
881
    finally:
882
      handle.close()
883

    
884
  def testSettingNewIp(self):
885
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
886

    
887
    self.assertFileContent(self.tmpname,
888
      "# This is a test file for /etc/hosts\n"
889
      "127.0.0.1\tlocalhost\n"
890
      "192.168.1.1 router gw\n"
891
      "1.2.3.4\tmyhost.domain.tld myhost\n")
892
    self.assertFileMode(self.tmpname, 0644)
893

    
894
  def testSettingExistingIp(self):
895
    SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
896
                     ['myhost'])
897

    
898
    self.assertFileContent(self.tmpname,
899
      "# This is a test file for /etc/hosts\n"
900
      "127.0.0.1\tlocalhost\n"
901
      "192.168.1.1\tmyhost.domain.tld myhost\n")
902
    self.assertFileMode(self.tmpname, 0644)
903

    
904
  def testSettingDuplicateName(self):
905
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
906

    
907
    self.assertFileContent(self.tmpname,
908
      "# This is a test file for /etc/hosts\n"
909
      "127.0.0.1\tlocalhost\n"
910
      "192.168.1.1 router gw\n"
911
      "1.2.3.4\tmyhost\n")
912
    self.assertFileMode(self.tmpname, 0644)
913

    
914
  def testRemovingExistingHost(self):
915
    RemoveEtcHostsEntry(self.tmpname, 'router')
916

    
917
    self.assertFileContent(self.tmpname,
918
      "# This is a test file for /etc/hosts\n"
919
      "127.0.0.1\tlocalhost\n"
920
      "192.168.1.1 gw\n")
921
    self.assertFileMode(self.tmpname, 0644)
922

    
923
  def testRemovingSingleExistingHost(self):
924
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
925

    
926
    self.assertFileContent(self.tmpname,
927
      "# This is a test file for /etc/hosts\n"
928
      "192.168.1.1 router gw\n")
929
    self.assertFileMode(self.tmpname, 0644)
930

    
931
  def testRemovingNonExistingHost(self):
932
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
933

    
934
    self.assertFileContent(self.tmpname,
935
      "# This is a test file for /etc/hosts\n"
936
      "127.0.0.1\tlocalhost\n"
937
      "192.168.1.1 router gw\n")
938
    self.assertFileMode(self.tmpname, 0644)
939

    
940
  def testRemovingAlias(self):
941
    RemoveEtcHostsEntry(self.tmpname, 'gw')
942

    
943
    self.assertFileContent(self.tmpname,
944
      "# This is a test file for /etc/hosts\n"
945
      "127.0.0.1\tlocalhost\n"
946
      "192.168.1.1 router\n")
947
    self.assertFileMode(self.tmpname, 0644)
948

    
949

    
950
class TestShellQuoting(unittest.TestCase):
951
  """Test case for shell quoting functions"""
952

    
953
  def testShellQuote(self):
954
    self.assertEqual(ShellQuote('abc'), "abc")
955
    self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
956
    self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
957
    self.assertEqual(ShellQuote("a b c"), "'a b c'")
958
    self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
959

    
960
  def testShellQuoteArgs(self):
961
    self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
962
    self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
963
    self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
964

    
965

    
966
class TestTcpPing(unittest.TestCase):
967
  """Testcase for TCP version of ping - against listen(2)ing port"""
968

    
969
  def setUp(self):
970
    self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
971
    self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
972
    self.listenerport = self.listener.getsockname()[1]
973
    self.listener.listen(1)
974

    
975
  def tearDown(self):
976
    self.listener.shutdown(socket.SHUT_RDWR)
977
    del self.listener
978
    del self.listenerport
979

    
980
  def testTcpPingToLocalHostAccept(self):
981
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
982
                         self.listenerport,
983
                         timeout=10,
984
                         live_port_needed=True,
985
                         source=constants.LOCALHOST_IP_ADDRESS,
986
                         ),
987
                 "failed to connect to test listener")
988

    
989
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
990
                         self.listenerport,
991
                         timeout=10,
992
                         live_port_needed=True,
993
                         ),
994
                 "failed to connect to test listener (no source)")
995

    
996

    
997
class TestTcpPingDeaf(unittest.TestCase):
998
  """Testcase for TCP version of ping - against non listen(2)ing port"""
999

    
1000
  def setUp(self):
1001
    self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1002
    self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
1003
    self.deaflistenerport = self.deaflistener.getsockname()[1]
1004

    
1005
  def tearDown(self):
1006
    del self.deaflistener
1007
    del self.deaflistenerport
1008

    
1009
  def testTcpPingToLocalHostAcceptDeaf(self):
1010
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1011
                        self.deaflistenerport,
1012
                        timeout=constants.TCP_PING_TIMEOUT,
1013
                        live_port_needed=True,
1014
                        source=constants.LOCALHOST_IP_ADDRESS,
1015
                        ), # need successful connect(2)
1016
                "successfully connected to deaf listener")
1017

    
1018
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1019
                        self.deaflistenerport,
1020
                        timeout=constants.TCP_PING_TIMEOUT,
1021
                        live_port_needed=True,
1022
                        ), # need successful connect(2)
1023
                "successfully connected to deaf listener (no source addr)")
1024

    
1025
  def testTcpPingToLocalHostNoAccept(self):
1026
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1027
                         self.deaflistenerport,
1028
                         timeout=constants.TCP_PING_TIMEOUT,
1029
                         live_port_needed=False,
1030
                         source=constants.LOCALHOST_IP_ADDRESS,
1031
                         ), # ECONNREFUSED is OK
1032
                 "failed to ping alive host on deaf port")
1033

    
1034
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1035
                         self.deaflistenerport,
1036
                         timeout=constants.TCP_PING_TIMEOUT,
1037
                         live_port_needed=False,
1038
                         ), # ECONNREFUSED is OK
1039
                 "failed to ping alive host on deaf port (no source addr)")
1040

    
1041

    
1042
class TestOwnIpAddress(unittest.TestCase):
1043
  """Testcase for OwnIpAddress"""
1044

    
1045
  def testOwnLoopback(self):
1046
    """check having the loopback ip"""
1047
    self.failUnless(OwnIpAddress(constants.LOCALHOST_IP_ADDRESS),
1048
                    "Should own the loopback address")
1049

    
1050
  def testNowOwnAddress(self):
1051
    """check that I don't own an address"""
1052

    
1053
    # network 192.0.2.0/24 is reserved for test/documentation as per
1054
    # rfc 3330, so we *should* not have an address of this range... if
1055
    # this fails, we should extend the test to multiple addresses
1056
    DST_IP = "192.0.2.1"
1057
    self.failIf(OwnIpAddress(DST_IP), "Should not own IP address %s" % DST_IP)
1058

    
1059

    
1060
class TestListVisibleFiles(unittest.TestCase):
1061
  """Test case for ListVisibleFiles"""
1062

    
1063
  def setUp(self):
1064
    self.path = tempfile.mkdtemp()
1065

    
1066
  def tearDown(self):
1067
    shutil.rmtree(self.path)
1068

    
1069
  def _test(self, files, expected):
1070
    # Sort a copy
1071
    expected = expected[:]
1072
    expected.sort()
1073

    
1074
    for name in files:
1075
      f = open(os.path.join(self.path, name), 'w')
1076
      try:
1077
        f.write("Test\n")
1078
      finally:
1079
        f.close()
1080

    
1081
    found = ListVisibleFiles(self.path)
1082
    found.sort()
1083

    
1084
    self.assertEqual(found, expected)
1085

    
1086
  def testAllVisible(self):
1087
    files = ["a", "b", "c"]
1088
    expected = files
1089
    self._test(files, expected)
1090

    
1091
  def testNoneVisible(self):
1092
    files = [".a", ".b", ".c"]
1093
    expected = []
1094
    self._test(files, expected)
1095

    
1096
  def testSomeVisible(self):
1097
    files = ["a", "b", ".c"]
1098
    expected = ["a", "b"]
1099
    self._test(files, expected)
1100

    
1101
  def testNonAbsolutePath(self):
1102
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1103

    
1104
  def testNonNormalizedPath(self):
1105
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1106
                          "/bin/../tmp")
1107

    
1108

    
1109
class TestNewUUID(unittest.TestCase):
1110
  """Test case for NewUUID"""
1111

    
1112
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
1113
                        '[a-f0-9]{4}-[a-f0-9]{12}$')
1114

    
1115
  def runTest(self):
1116
    self.failUnless(self._re_uuid.match(utils.NewUUID()))
1117

    
1118

    
1119
class TestUniqueSequence(unittest.TestCase):
1120
  """Test case for UniqueSequence"""
1121

    
1122
  def _test(self, input, expected):
1123
    self.assertEqual(utils.UniqueSequence(input), expected)
1124

    
1125
  def runTest(self):
1126
    # Ordered input
1127
    self._test([1, 2, 3], [1, 2, 3])
1128
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
1129
    self._test([1, 2, 2, 3], [1, 2, 3])
1130
    self._test([1, 2, 3, 3], [1, 2, 3])
1131

    
1132
    # Unordered input
1133
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
1134
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
1135

    
1136
    # Strings
1137
    self._test(["a", "a"], ["a"])
1138
    self._test(["a", "b"], ["a", "b"])
1139
    self._test(["a", "b", "a"], ["a", "b"])
1140

    
1141

    
1142
class TestFirstFree(unittest.TestCase):
1143
  """Test case for the FirstFree function"""
1144

    
1145
  def test(self):
1146
    """Test FirstFree"""
1147
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1148
    self.failUnlessEqual(FirstFree([]), None)
1149
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1150
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1151
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1152

    
1153

    
1154
class TestTailFile(testutils.GanetiTestCase):
1155
  """Test case for the TailFile function"""
1156

    
1157
  def testEmpty(self):
1158
    fname = self._CreateTempFile()
1159
    self.failUnlessEqual(TailFile(fname), [])
1160
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1161

    
1162
  def testAllLines(self):
1163
    data = ["test %d" % i for i in range(30)]
1164
    for i in range(30):
1165
      fname = self._CreateTempFile()
1166
      fd = open(fname, "w")
1167
      fd.write("\n".join(data[:i]))
1168
      if i > 0:
1169
        fd.write("\n")
1170
      fd.close()
1171
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1172

    
1173
  def testPartialLines(self):
1174
    data = ["test %d" % i for i in range(30)]
1175
    fname = self._CreateTempFile()
1176
    fd = open(fname, "w")
1177
    fd.write("\n".join(data))
1178
    fd.write("\n")
1179
    fd.close()
1180
    for i in range(1, 30):
1181
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1182

    
1183
  def testBigFile(self):
1184
    data = ["test %d" % i for i in range(30)]
1185
    fname = self._CreateTempFile()
1186
    fd = open(fname, "w")
1187
    fd.write("X" * 1048576)
1188
    fd.write("\n")
1189
    fd.write("\n".join(data))
1190
    fd.write("\n")
1191
    fd.close()
1192
    for i in range(1, 30):
1193
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1194

    
1195

    
1196
class _BaseFileLockTest:
1197
  """Test case for the FileLock class"""
1198

    
1199
  def testSharedNonblocking(self):
1200
    self.lock.Shared(blocking=False)
1201
    self.lock.Close()
1202

    
1203
  def testExclusiveNonblocking(self):
1204
    self.lock.Exclusive(blocking=False)
1205
    self.lock.Close()
1206

    
1207
  def testUnlockNonblocking(self):
1208
    self.lock.Unlock(blocking=False)
1209
    self.lock.Close()
1210

    
1211
  def testSharedBlocking(self):
1212
    self.lock.Shared(blocking=True)
1213
    self.lock.Close()
1214

    
1215
  def testExclusiveBlocking(self):
1216
    self.lock.Exclusive(blocking=True)
1217
    self.lock.Close()
1218

    
1219
  def testUnlockBlocking(self):
1220
    self.lock.Unlock(blocking=True)
1221
    self.lock.Close()
1222

    
1223
  def testSharedExclusiveUnlock(self):
1224
    self.lock.Shared(blocking=False)
1225
    self.lock.Exclusive(blocking=False)
1226
    self.lock.Unlock(blocking=False)
1227
    self.lock.Close()
1228

    
1229
  def testExclusiveSharedUnlock(self):
1230
    self.lock.Exclusive(blocking=False)
1231
    self.lock.Shared(blocking=False)
1232
    self.lock.Unlock(blocking=False)
1233
    self.lock.Close()
1234

    
1235
  def testSimpleTimeout(self):
1236
    # These will succeed on the first attempt, hence a short timeout
1237
    self.lock.Shared(blocking=True, timeout=10.0)
1238
    self.lock.Exclusive(blocking=False, timeout=10.0)
1239
    self.lock.Unlock(blocking=True, timeout=10.0)
1240
    self.lock.Close()
1241

    
1242
  @staticmethod
1243
  def _TryLockInner(filename, shared, blocking):
1244
    lock = utils.FileLock.Open(filename)
1245

    
1246
    if shared:
1247
      fn = lock.Shared
1248
    else:
1249
      fn = lock.Exclusive
1250

    
1251
    try:
1252
      # The timeout doesn't really matter as the parent process waits for us to
1253
      # finish anyway.
1254
      fn(blocking=blocking, timeout=0.01)
1255
    except errors.LockError, err:
1256
      return False
1257

    
1258
    return True
1259

    
1260
  def _TryLock(self, *args):
1261
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1262
                                      *args)
1263

    
1264
  def testTimeout(self):
1265
    for blocking in [True, False]:
1266
      self.lock.Exclusive(blocking=True)
1267
      self.failIf(self._TryLock(False, blocking))
1268
      self.failIf(self._TryLock(True, blocking))
1269

    
1270
      self.lock.Shared(blocking=True)
1271
      self.assert_(self._TryLock(True, blocking))
1272
      self.failIf(self._TryLock(False, blocking))
1273

    
1274
  def testCloseShared(self):
1275
    self.lock.Close()
1276
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1277

    
1278
  def testCloseExclusive(self):
1279
    self.lock.Close()
1280
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1281

    
1282
  def testCloseUnlock(self):
1283
    self.lock.Close()
1284
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1285

    
1286

    
1287
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1288
  TESTDATA = "Hello World\n" * 10
1289

    
1290
  def setUp(self):
1291
    testutils.GanetiTestCase.setUp(self)
1292

    
1293
    self.tmpfile = tempfile.NamedTemporaryFile()
1294
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1295
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1296

    
1297
    # Ensure "Open" didn't truncate file
1298
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1299

    
1300
  def tearDown(self):
1301
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1302

    
1303
    testutils.GanetiTestCase.tearDown(self)
1304

    
1305

    
1306
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1307
  def setUp(self):
1308
    self.tmpfile = tempfile.NamedTemporaryFile()
1309
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1310

    
1311

    
1312
class TestTimeFunctions(unittest.TestCase):
1313
  """Test case for time functions"""
1314

    
1315
  def runTest(self):
1316
    self.assertEqual(utils.SplitTime(1), (1, 0))
1317
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1318
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1319
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1320
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1321
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1322
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1323
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1324

    
1325
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1326

    
1327
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1328
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1329
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1330

    
1331
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1332
                     1218448917.481)
1333
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1334

    
1335
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1336
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1337
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1338
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1339
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1340

    
1341

    
1342
class FieldSetTestCase(unittest.TestCase):
1343
  """Test case for FieldSets"""
1344

    
1345
  def testSimpleMatch(self):
1346
    f = utils.FieldSet("a", "b", "c", "def")
1347
    self.failUnless(f.Matches("a"))
1348
    self.failIf(f.Matches("d"), "Substring matched")
1349
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1350
    self.failIf(f.NonMatching(["b", "c"]))
1351
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1352
    self.failUnless(f.NonMatching(["a", "d"]))
1353

    
1354
  def testRegexMatch(self):
1355
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1356
    self.failUnless(f.Matches("b1"))
1357
    self.failUnless(f.Matches("b99"))
1358
    self.failIf(f.Matches("b/1"))
1359
    self.failIf(f.NonMatching(["b12", "c"]))
1360
    self.failUnless(f.NonMatching(["a", "1"]))
1361

    
1362
class TestForceDictType(unittest.TestCase):
1363
  """Test case for ForceDictType"""
1364

    
1365
  def setUp(self):
1366
    self.key_types = {
1367
      'a': constants.VTYPE_INT,
1368
      'b': constants.VTYPE_BOOL,
1369
      'c': constants.VTYPE_STRING,
1370
      'd': constants.VTYPE_SIZE,
1371
      }
1372

    
1373
  def _fdt(self, dict, allowed_values=None):
1374
    if allowed_values is None:
1375
      ForceDictType(dict, self.key_types)
1376
    else:
1377
      ForceDictType(dict, self.key_types, allowed_values=allowed_values)
1378

    
1379
    return dict
1380

    
1381
  def testSimpleDict(self):
1382
    self.assertEqual(self._fdt({}), {})
1383
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1384
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1385
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1386
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1387
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1388
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1389
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1390
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1391
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1392
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1393
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1394

    
1395
  def testErrors(self):
1396
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1397
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1398
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1399
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1400

    
1401

    
1402
class TestIsAbsNormPath(unittest.TestCase):
1403
  """Testing case for IsNormAbsPath"""
1404

    
1405
  def _pathTestHelper(self, path, result):
1406
    if result:
1407
      self.assert_(IsNormAbsPath(path),
1408
          "Path %s should result absolute and normalized" % path)
1409
    else:
1410
      self.assert_(not IsNormAbsPath(path),
1411
          "Path %s should not result absolute and normalized" % path)
1412

    
1413
  def testBase(self):
1414
    self._pathTestHelper('/etc', True)
1415
    self._pathTestHelper('/srv', True)
1416
    self._pathTestHelper('etc', False)
1417
    self._pathTestHelper('/etc/../root', False)
1418
    self._pathTestHelper('/etc/', False)
1419

    
1420

    
1421
class TestSafeEncode(unittest.TestCase):
1422
  """Test case for SafeEncode"""
1423

    
1424
  def testAscii(self):
1425
    for txt in [string.digits, string.letters, string.punctuation]:
1426
      self.failUnlessEqual(txt, SafeEncode(txt))
1427

    
1428
  def testDoubleEncode(self):
1429
    for i in range(255):
1430
      txt = SafeEncode(chr(i))
1431
      self.failUnlessEqual(txt, SafeEncode(txt))
1432

    
1433
  def testUnicode(self):
1434
    # 1024 is high enough to catch non-direct ASCII mappings
1435
    for i in range(1024):
1436
      txt = SafeEncode(unichr(i))
1437
      self.failUnlessEqual(txt, SafeEncode(txt))
1438

    
1439

    
1440
class TestFormatTime(unittest.TestCase):
1441
  """Testing case for FormatTime"""
1442

    
1443
  def testNone(self):
1444
    self.failUnlessEqual(FormatTime(None), "N/A")
1445

    
1446
  def testInvalid(self):
1447
    self.failUnlessEqual(FormatTime(()), "N/A")
1448

    
1449
  def testNow(self):
1450
    # tests that we accept time.time input
1451
    FormatTime(time.time())
1452
    # tests that we accept int input
1453
    FormatTime(int(time.time()))
1454

    
1455

    
1456
class RunInSeparateProcess(unittest.TestCase):
1457
  def test(self):
1458
    for exp in [True, False]:
1459
      def _child():
1460
        return exp
1461

    
1462
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1463

    
1464
  def testArgs(self):
1465
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1466
      def _child(carg1, carg2):
1467
        return carg1 == "Foo" and carg2 == arg
1468

    
1469
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1470

    
1471
  def testPid(self):
1472
    parent_pid = os.getpid()
1473

    
1474
    def _check():
1475
      return os.getpid() == parent_pid
1476

    
1477
    self.failIf(utils.RunInSeparateProcess(_check))
1478

    
1479
  def testSignal(self):
1480
    def _kill():
1481
      os.kill(os.getpid(), signal.SIGTERM)
1482

    
1483
    self.assertRaises(errors.GenericError,
1484
                      utils.RunInSeparateProcess, _kill)
1485

    
1486
  def testException(self):
1487
    def _exc():
1488
      raise errors.GenericError("This is a test")
1489

    
1490
    self.assertRaises(errors.GenericError,
1491
                      utils.RunInSeparateProcess, _exc)
1492

    
1493

    
1494
class TestFingerprintFile(unittest.TestCase):
1495
  def setUp(self):
1496
    self.tmpfile = tempfile.NamedTemporaryFile()
1497

    
1498
  def test(self):
1499
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1500
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")
1501

    
1502
    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
1503
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1504
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1505

    
1506

    
1507
class TestUnescapeAndSplit(unittest.TestCase):
1508
  """Testing case for UnescapeAndSplit"""
1509

    
1510
  def setUp(self):
1511
    # testing more that one separator for regexp safety
1512
    self._seps = [",", "+", "."]
1513

    
1514
  def testSimple(self):
1515
    a = ["a", "b", "c", "d"]
1516
    for sep in self._seps:
1517
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1518

    
1519
  def testEscape(self):
1520
    for sep in self._seps:
1521
      a = ["a", "b\\" + sep + "c", "d"]
1522
      b = ["a", "b" + sep + "c", "d"]
1523
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1524

    
1525
  def testDoubleEscape(self):
1526
    for sep in self._seps:
1527
      a = ["a", "b\\\\", "c", "d"]
1528
      b = ["a", "b\\", "c", "d"]
1529
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1530

    
1531
  def testThreeEscape(self):
1532
    for sep in self._seps:
1533
      a = ["a", "b\\\\\\" + sep + "c", "d"]
1534
      b = ["a", "b\\" + sep + "c", "d"]
1535
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1536

    
1537

    
1538
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1539
  def setUp(self):
1540
    self.tmpdir = tempfile.mkdtemp()
1541

    
1542
  def tearDown(self):
1543
    shutil.rmtree(self.tmpdir)
1544

    
1545
  def _checkRsaPrivateKey(self, key):
1546
    lines = key.splitlines()
1547
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1548
            "-----END RSA PRIVATE KEY-----" in lines)
1549

    
1550
  def _checkCertificate(self, cert):
1551
    lines = cert.splitlines()
1552
    return ("-----BEGIN CERTIFICATE-----" in lines and
1553
            "-----END CERTIFICATE-----" in lines)
1554

    
1555
  def test(self):
1556
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1557
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1558
      self._checkRsaPrivateKey(key_pem)
1559
      self._checkCertificate(cert_pem)
1560

    
1561
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1562
                                           key_pem)
1563
      self.assert_(key.bits() >= 1024)
1564
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1565
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1566

    
1567
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1568
                                             cert_pem)
1569
      self.failIf(x509.has_expired())
1570
      self.assertEqual(x509.get_issuer().CN, common_name)
1571
      self.assertEqual(x509.get_subject().CN, common_name)
1572
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1573

    
1574
  def testLegacy(self):
1575
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1576

    
1577
    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1578

    
1579
    cert1 = utils.ReadFile(cert1_filename)
1580

    
1581
    self.assert_(self._checkRsaPrivateKey(cert1))
1582
    self.assert_(self._checkCertificate(cert1))
1583

    
1584

    
1585
class TestPathJoin(unittest.TestCase):
1586
  """Testing case for PathJoin"""
1587

    
1588
  def testBasicItems(self):
1589
    mlist = ["/a", "b", "c"]
1590
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1591

    
1592
  def testNonAbsPrefix(self):
1593
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1594

    
1595
  def testBackTrack(self):
1596
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1597

    
1598
  def testMultiAbs(self):
1599
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1600

    
1601

    
1602
class TestHostInfo(unittest.TestCase):
1603
  """Testing case for HostInfo"""
1604

    
1605
  def testUppercase(self):
1606
    data = "AbC.example.com"
1607
    self.failUnlessEqual(HostInfo.NormalizeName(data), data.lower())
1608

    
1609
  def testTooLongName(self):
1610
    data = "a.b." + "c" * 255
1611
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, data)
1612

    
1613
  def testTrailingDot(self):
1614
    data = "a.b.c"
1615
    self.failUnlessEqual(HostInfo.NormalizeName(data + "."), data)
1616

    
1617
  def testInvalidName(self):
1618
    data = [
1619
      "a b",
1620
      "a/b",
1621
      ".a.b",
1622
      "a..b",
1623
      ]
1624
    for value in data:
1625
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, value)
1626

    
1627
  def testValidName(self):
1628
    data = [
1629
      "a.b",
1630
      "a-b",
1631
      "a_b",
1632
      "a.b.c",
1633
      ]
1634
    for value in data:
1635
      HostInfo.NormalizeName(value)
1636

    
1637

    
1638
class TestParseAsn1Generalizedtime(unittest.TestCase):
1639
  def test(self):
1640
    # UTC
1641
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1642
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1643
                     1266860512)
1644
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1645
                     (2**31) - 1)
1646

    
1647
    # With offset
1648
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1649
                     1266860512)
1650
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1651
                     1266931012)
1652
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1653
                     1266931088)
1654
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1655
                     1266931295)
1656
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1657
                     3600)
1658

    
1659
    # Leap seconds are not supported by datetime.datetime
1660
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1661
                      "19841231235960+0000")
1662
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1663
                      "19920630235960+0000")
1664

    
1665
    # Errors
1666
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1667
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1668
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1669
                      "20100222174152")
1670
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1671
                      "Mon Feb 22 17:47:02 UTC 2010")
1672
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1673
                      "2010-02-22 17:42:02")
1674

    
1675

    
1676
class TestGetX509CertValidity(testutils.GanetiTestCase):
1677
  def setUp(self):
1678
    testutils.GanetiTestCase.setUp(self)
1679

    
1680
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1681

    
1682
    # Test whether we have pyOpenSSL 0.7 or above
1683
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1684

    
1685
    if not self.pyopenssl0_7:
1686
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1687
                    " function correctly")
1688

    
1689
  def _LoadCert(self, name):
1690
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1691
                                           self._ReadTestData(name))
1692

    
1693
  def test(self):
1694
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1695
    if self.pyopenssl0_7:
1696
      self.assertEqual(validity, (1266919967, 1267524767))
1697
    else:
1698
      self.assertEqual(validity, (None, None))
1699

    
1700

    
1701
class TestSignX509Certificate(unittest.TestCase):
1702
  KEY = "My private key!"
1703
  KEY_OTHER = "Another key"
1704

    
1705
  def test(self):
1706
    # Generate certificate valid for 5 minutes
1707
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
1708

    
1709
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1710
                                           cert_pem)
1711

    
1712
    # No signature at all
1713
    self.assertRaises(errors.GenericError,
1714
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
1715

    
1716
    # Invalid input
1717
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1718
                      "", self.KEY)
1719
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1720
                      "X-Ganeti-Signature: \n", self.KEY)
1721
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1722
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
1723
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1724
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
1725
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1726
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
1727

    
1728
    # Invalid salt
1729
    for salt in list("-_@$,:;/\\ \t\n"):
1730
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
1731
                        cert_pem, self.KEY, "foo%sbar" % salt)
1732

    
1733
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
1734
                 utils.GenerateSecret(numbytes=4),
1735
                 utils.GenerateSecret(numbytes=16),
1736
                 "{123:456}".encode("hex")]:
1737
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
1738

    
1739
      self._Check(cert, salt, signed_pem)
1740

    
1741
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
1742
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
1743
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
1744
                               "lines----\n------ at\nthe end!"))
1745

    
1746
  def _Check(self, cert, salt, pem):
1747
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
1748
    self.assertEqual(salt, salt2)
1749
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
1750

    
1751
    # Other key
1752
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1753
                      pem, self.KEY_OTHER)
1754

    
1755

    
1756
if __name__ == '__main__':
1757
  testutils.GanetiTestProgram()