Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ 615aaaba

History | View | Annotate | Download (67.7 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import unittest
25
import os
26
import time
27
import tempfile
28
import os.path
29
import os
30
import stat
31
import signal
32
import socket
33
import shutil
34
import re
35
import select
36
import string
37
import fcntl
38
import OpenSSL
39
import warnings
40
import distutils.version
41
import glob
42

    
43
import ganeti
44
import testutils
45
from ganeti import constants
46
from ganeti import utils
47
from ganeti import errors
48
from ganeti import serializer
49
from ganeti.utils import IsProcessAlive, RunCmd, \
50
     RemoveFile, MatchNameComponent, FormatUnit, \
51
     ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
52
     ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
53
     SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
54
     TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
55
     UnescapeAndSplit, RunParts, PathJoin, HostInfo
56

    
57
from ganeti.errors import LockError, UnitParseError, GenericError, \
58
     ProgrammerError, OpPrereqError
59

    
60

    
61
class TestIsProcessAlive(unittest.TestCase):
62
  """Testing case for IsProcessAlive"""
63

    
64
  def testExists(self):
65
    mypid = os.getpid()
66
    self.assert_(IsProcessAlive(mypid),
67
                 "can't find myself running")
68

    
69
  def testNotExisting(self):
70
    pid_non_existing = os.fork()
71
    if pid_non_existing == 0:
72
      os._exit(0)
73
    elif pid_non_existing < 0:
74
      raise SystemError("can't fork")
75
    os.waitpid(pid_non_existing, 0)
76
    self.assert_(not IsProcessAlive(pid_non_existing),
77
                 "nonexisting process detected")
78

    
79

    
80
class TestPidFileFunctions(unittest.TestCase):
81
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
82

    
83
  def setUp(self):
84
    self.dir = tempfile.mkdtemp()
85
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
86
    utils.DaemonPidFileName = self.f_dpn
87

    
88
  def testPidFileFunctions(self):
89
    pid_file = self.f_dpn('test')
90
    utils.WritePidFile('test')
91
    self.failUnless(os.path.exists(pid_file),
92
                    "PID file should have been created")
93
    read_pid = utils.ReadPidFile(pid_file)
94
    self.failUnlessEqual(read_pid, os.getpid())
95
    self.failUnless(utils.IsProcessAlive(read_pid))
96
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
97
    utils.RemovePidFile('test')
98
    self.failIf(os.path.exists(pid_file),
99
                "PID file should not exist anymore")
100
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
101
                         "ReadPidFile should return 0 for missing pid file")
102
    fh = open(pid_file, "w")
103
    fh.write("blah\n")
104
    fh.close()
105
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
106
                         "ReadPidFile should return 0 for invalid pid file")
107
    utils.RemovePidFile('test')
108
    self.failIf(os.path.exists(pid_file),
109
                "PID file should not exist anymore")
110

    
111
  def testKill(self):
112
    pid_file = self.f_dpn('child')
113
    r_fd, w_fd = os.pipe()
114
    new_pid = os.fork()
115
    if new_pid == 0: #child
116
      utils.WritePidFile('child')
117
      os.write(w_fd, 'a')
118
      signal.pause()
119
      os._exit(0)
120
      return
121
    # else we are in the parent
122
    # wait until the child has written the pid file
123
    os.read(r_fd, 1)
124
    read_pid = utils.ReadPidFile(pid_file)
125
    self.failUnlessEqual(read_pid, new_pid)
126
    self.failUnless(utils.IsProcessAlive(new_pid))
127
    utils.KillProcess(new_pid, waitpid=True)
128
    self.failIf(utils.IsProcessAlive(new_pid))
129
    utils.RemovePidFile('child')
130
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)
131

    
132
  def tearDown(self):
133
    for name in os.listdir(self.dir):
134
      os.unlink(os.path.join(self.dir, name))
135
    os.rmdir(self.dir)
136

    
137

    
138
class TestRunCmd(testutils.GanetiTestCase):
139
  """Testing case for the RunCmd function"""
140

    
141
  def setUp(self):
142
    testutils.GanetiTestCase.setUp(self)
143
    self.magic = time.ctime() + " ganeti test"
144
    self.fname = self._CreateTempFile()
145

    
146
  def testOk(self):
147
    """Test successful exit code"""
148
    result = RunCmd("/bin/sh -c 'exit 0'")
149
    self.assertEqual(result.exit_code, 0)
150
    self.assertEqual(result.output, "")
151

    
152
  def testFail(self):
153
    """Test fail exit code"""
154
    result = RunCmd("/bin/sh -c 'exit 1'")
155
    self.assertEqual(result.exit_code, 1)
156
    self.assertEqual(result.output, "")
157

    
158
  def testStdout(self):
159
    """Test standard output"""
160
    cmd = 'echo -n "%s"' % self.magic
161
    result = RunCmd("/bin/sh -c '%s'" % cmd)
162
    self.assertEqual(result.stdout, self.magic)
163
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
164
    self.assertEqual(result.output, "")
165
    self.assertFileContent(self.fname, self.magic)
166

    
167
  def testStderr(self):
168
    """Test standard error"""
169
    cmd = 'echo -n "%s"' % self.magic
170
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
171
    self.assertEqual(result.stderr, self.magic)
172
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
173
    self.assertEqual(result.output, "")
174
    self.assertFileContent(self.fname, self.magic)
175

    
176
  def testCombined(self):
177
    """Test combined output"""
178
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
179
    expected = "A" + self.magic + "B" + self.magic
180
    result = RunCmd("/bin/sh -c '%s'" % cmd)
181
    self.assertEqual(result.output, expected)
182
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
183
    self.assertEqual(result.output, "")
184
    self.assertFileContent(self.fname, expected)
185

    
186
  def testSignal(self):
187
    """Test signal"""
188
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
189
    self.assertEqual(result.signal, 15)
190
    self.assertEqual(result.output, "")
191

    
192
  def testListRun(self):
193
    """Test list runs"""
194
    result = RunCmd(["true"])
195
    self.assertEqual(result.signal, None)
196
    self.assertEqual(result.exit_code, 0)
197
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
198
    self.assertEqual(result.signal, None)
199
    self.assertEqual(result.exit_code, 1)
200
    result = RunCmd(["echo", "-n", self.magic])
201
    self.assertEqual(result.signal, None)
202
    self.assertEqual(result.exit_code, 0)
203
    self.assertEqual(result.stdout, self.magic)
204

    
205
  def testFileEmptyOutput(self):
206
    """Test file output"""
207
    result = RunCmd(["true"], output=self.fname)
208
    self.assertEqual(result.signal, None)
209
    self.assertEqual(result.exit_code, 0)
210
    self.assertFileContent(self.fname, "")
211

    
212
  def testLang(self):
213
    """Test locale environment"""
214
    old_env = os.environ.copy()
215
    try:
216
      os.environ["LANG"] = "en_US.UTF-8"
217
      os.environ["LC_ALL"] = "en_US.UTF-8"
218
      result = RunCmd(["locale"])
219
      for line in result.output.splitlines():
220
        key, value = line.split("=", 1)
221
        # Ignore these variables, they're overridden by LC_ALL
222
        if key == "LANG" or key == "LANGUAGE":
223
          continue
224
        self.failIf(value and value != "C" and value != '"C"',
225
            "Variable %s is set to the invalid value '%s'" % (key, value))
226
    finally:
227
      os.environ = old_env
228

    
229
  def testDefaultCwd(self):
230
    """Test default working directory"""
231
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
232

    
233
  def testCwd(self):
234
    """Test default working directory"""
235
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
236
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
237
    cwd = os.getcwd()
238
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
239

    
240
  def testResetEnv(self):
241
    """Test environment reset functionality"""
242
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
243
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
244
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
245

    
246

    
247
class TestRunParts(unittest.TestCase):
248
  """Testing case for the RunParts function"""
249

    
250
  def setUp(self):
251
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
252

    
253
  def tearDown(self):
254
    shutil.rmtree(self.rundir)
255

    
256
  def testEmpty(self):
257
    """Test on an empty dir"""
258
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
259

    
260
  def testSkipWrongName(self):
261
    """Test that wrong files are skipped"""
262
    fname = os.path.join(self.rundir, "00test.dot")
263
    utils.WriteFile(fname, data="")
264
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
265
    relname = os.path.basename(fname)
266
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
267
                         [(relname, constants.RUNPARTS_SKIP, None)])
268

    
269
  def testSkipNonExec(self):
270
    """Test that non executable files are skipped"""
271
    fname = os.path.join(self.rundir, "00test")
272
    utils.WriteFile(fname, data="")
273
    relname = os.path.basename(fname)
274
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
275
                         [(relname, constants.RUNPARTS_SKIP, None)])
276

    
277
  def testError(self):
278
    """Test error on a broken executable"""
279
    fname = os.path.join(self.rundir, "00test")
280
    utils.WriteFile(fname, data="")
281
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
282
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
283
    self.failUnlessEqual(relname, os.path.basename(fname))
284
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
285
    self.failUnless(error)
286

    
287
  def testSorted(self):
288
    """Test executions are sorted"""
289
    files = []
290
    files.append(os.path.join(self.rundir, "64test"))
291
    files.append(os.path.join(self.rundir, "00test"))
292
    files.append(os.path.join(self.rundir, "42test"))
293

    
294
    for fname in files:
295
      utils.WriteFile(fname, data="")
296

    
297
    results = RunParts(self.rundir, reset_env=True)
298

    
299
    for fname in sorted(files):
300
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
301

    
302
  def testOk(self):
303
    """Test correct execution"""
304
    fname = os.path.join(self.rundir, "00test")
305
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
306
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
307
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
308
    self.failUnlessEqual(relname, os.path.basename(fname))
309
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
310
    self.failUnlessEqual(runresult.stdout, "ciao")
311

    
312
  def testRunFail(self):
313
    """Test correct execution, with run failure"""
314
    fname = os.path.join(self.rundir, "00test")
315
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
316
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
317
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
318
    self.failUnlessEqual(relname, os.path.basename(fname))
319
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
320
    self.failUnlessEqual(runresult.exit_code, 1)
321
    self.failUnless(runresult.failed)
322

    
323
  def testRunMix(self):
324
    files = []
325
    files.append(os.path.join(self.rundir, "00test"))
326
    files.append(os.path.join(self.rundir, "42test"))
327
    files.append(os.path.join(self.rundir, "64test"))
328
    files.append(os.path.join(self.rundir, "99test"))
329

    
330
    files.sort()
331

    
332
    # 1st has errors in execution
333
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
334
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
335

    
336
    # 2nd is skipped
337
    utils.WriteFile(files[1], data="")
338

    
339
    # 3rd cannot execute properly
340
    utils.WriteFile(files[2], data="")
341
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
342

    
343
    # 4th execs
344
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
345
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
346

    
347
    results = RunParts(self.rundir, reset_env=True)
348

    
349
    (relname, status, runresult) = results[0]
350
    self.failUnlessEqual(relname, os.path.basename(files[0]))
351
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
352
    self.failUnlessEqual(runresult.exit_code, 1)
353
    self.failUnless(runresult.failed)
354

    
355
    (relname, status, runresult) = results[1]
356
    self.failUnlessEqual(relname, os.path.basename(files[1]))
357
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
358
    self.failUnlessEqual(runresult, None)
359

    
360
    (relname, status, runresult) = results[2]
361
    self.failUnlessEqual(relname, os.path.basename(files[2]))
362
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
363
    self.failUnless(runresult)
364

    
365
    (relname, status, runresult) = results[3]
366
    self.failUnlessEqual(relname, os.path.basename(files[3]))
367
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
368
    self.failUnlessEqual(runresult.output, "ciao")
369
    self.failUnlessEqual(runresult.exit_code, 0)
370
    self.failUnless(not runresult.failed)
371

    
372

    
373
class TestStartDaemon(testutils.GanetiTestCase):
374
  def setUp(self):
375
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
376
    self.tmpfile = os.path.join(self.tmpdir, "test")
377

    
378
  def tearDown(self):
379
    shutil.rmtree(self.tmpdir)
380

    
381
  def testShell(self):
382
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
383
    self._wait(self.tmpfile, 60.0, "Hello World")
384

    
385
  def testShellOutput(self):
386
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
387
    self._wait(self.tmpfile, 60.0, "Hello World")
388

    
389
  def testNoShellNoOutput(self):
390
    utils.StartDaemon(["pwd"])
391

    
392
  def testNoShellNoOutputTouch(self):
393
    testfile = os.path.join(self.tmpdir, "check")
394
    self.failIf(os.path.exists(testfile))
395
    utils.StartDaemon(["touch", testfile])
396
    self._wait(testfile, 60.0, "")
397

    
398
  def testNoShellOutput(self):
399
    utils.StartDaemon(["pwd"], output=self.tmpfile)
400
    self._wait(self.tmpfile, 60.0, "/")
401

    
402
  def testNoShellOutputCwd(self):
403
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
404
    self._wait(self.tmpfile, 60.0, os.getcwd())
405

    
406
  def testShellEnv(self):
407
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
408
                      env={ "GNT_TEST_VAR": "Hello World", })
409
    self._wait(self.tmpfile, 60.0, "Hello World")
410

    
411
  def testNoShellEnv(self):
412
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
413
                      env={ "GNT_TEST_VAR": "Hello World", })
414
    self._wait(self.tmpfile, 60.0, "Hello World")
415

    
416
  def testOutputFd(self):
417
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
418
    try:
419
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
420
    finally:
421
      os.close(fd)
422
    self._wait(self.tmpfile, 60.0, os.getcwd())
423

    
424
  def testPid(self):
425
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
426
    self._wait(self.tmpfile, 60.0, str(pid))
427

    
428
  def testPidFile(self):
429
    pidfile = os.path.join(self.tmpdir, "pid")
430
    checkfile = os.path.join(self.tmpdir, "abort")
431

    
432
    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
433
                            output=self.tmpfile)
434
    try:
435
      fd = os.open(pidfile, os.O_RDONLY)
436
      try:
437
        # Check file is locked
438
        self.assertRaises(errors.LockError, utils.LockFile, fd)
439

    
440
        pidtext = os.read(fd, 100)
441
      finally:
442
        os.close(fd)
443

    
444
      self.assertEqual(int(pidtext.strip()), pid)
445

    
446
      self.assert_(utils.IsProcessAlive(pid))
447
    finally:
448
      # No matter what happens, kill daemon
449
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
450
      self.failIf(utils.IsProcessAlive(pid))
451

    
452
    self.assertEqual(utils.ReadFile(self.tmpfile), "")
453

    
454
  def _wait(self, path, timeout, expected):
455
    # Due to the asynchronous nature of daemon processes, polling is necessary.
456
    # A timeout makes sure the test doesn't hang forever.
457
    def _CheckFile():
458
      if not (os.path.isfile(path) and
459
              utils.ReadFile(path).strip() == expected):
460
        raise utils.RetryAgain()
461

    
462
    try:
463
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
464
    except utils.RetryTimeout:
465
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
466
                " didn't write the correct output" % timeout)
467

    
468
  def testError(self):
469
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
470
                      ["./does-NOT-EXIST/here/0123456789"])
471
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
472
                      ["./does-NOT-EXIST/here/0123456789"],
473
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
474
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
475
                      ["./does-NOT-EXIST/here/0123456789"],
476
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
477
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
478
                      ["./does-NOT-EXIST/here/0123456789"],
479
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
480

    
481
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
482
    try:
483
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
484
                        ["./does-NOT-EXIST/here/0123456789"],
485
                        output=self.tmpfile, output_fd=fd)
486
    finally:
487
      os.close(fd)
488

    
489

    
490
class TestSetCloseOnExecFlag(unittest.TestCase):
491
  """Tests for SetCloseOnExecFlag"""
492

    
493
  def setUp(self):
494
    self.tmpfile = tempfile.TemporaryFile()
495

    
496
  def testEnable(self):
497
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
498
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
499
                    fcntl.FD_CLOEXEC)
500

    
501
  def testDisable(self):
502
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
503
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
504
                fcntl.FD_CLOEXEC)
505

    
506

    
507
class TestSetNonblockFlag(unittest.TestCase):
508
  def setUp(self):
509
    self.tmpfile = tempfile.TemporaryFile()
510

    
511
  def testEnable(self):
512
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
513
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
514
                    os.O_NONBLOCK)
515

    
516
  def testDisable(self):
517
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
518
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
519
                os.O_NONBLOCK)
520

    
521

    
522
class TestRemoveFile(unittest.TestCase):
523
  """Test case for the RemoveFile function"""
524

    
525
  def setUp(self):
526
    """Create a temp dir and file for each case"""
527
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
528
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
529
    os.close(fd)
530

    
531
  def tearDown(self):
532
    if os.path.exists(self.tmpfile):
533
      os.unlink(self.tmpfile)
534
    os.rmdir(self.tmpdir)
535

    
536
  def testIgnoreDirs(self):
537
    """Test that RemoveFile() ignores directories"""
538
    self.assertEqual(None, RemoveFile(self.tmpdir))
539

    
540
  def testIgnoreNotExisting(self):
541
    """Test that RemoveFile() ignores non-existing files"""
542
    RemoveFile(self.tmpfile)
543
    RemoveFile(self.tmpfile)
544

    
545
  def testRemoveFile(self):
546
    """Test that RemoveFile does remove a file"""
547
    RemoveFile(self.tmpfile)
548
    if os.path.exists(self.tmpfile):
549
      self.fail("File '%s' not removed" % self.tmpfile)
550

    
551
  def testRemoveSymlink(self):
552
    """Test that RemoveFile does remove symlinks"""
553
    symlink = self.tmpdir + "/symlink"
554
    os.symlink("no-such-file", symlink)
555
    RemoveFile(symlink)
556
    if os.path.exists(symlink):
557
      self.fail("File '%s' not removed" % symlink)
558
    os.symlink(self.tmpfile, symlink)
559
    RemoveFile(symlink)
560
    if os.path.exists(symlink):
561
      self.fail("File '%s' not removed" % symlink)
562

    
563

    
564
class TestRename(unittest.TestCase):
565
  """Test case for RenameFile"""
566

    
567
  def setUp(self):
568
    """Create a temporary directory"""
569
    self.tmpdir = tempfile.mkdtemp()
570
    self.tmpfile = os.path.join(self.tmpdir, "test1")
571

    
572
    # Touch the file
573
    open(self.tmpfile, "w").close()
574

    
575
  def tearDown(self):
576
    """Remove temporary directory"""
577
    shutil.rmtree(self.tmpdir)
578

    
579
  def testSimpleRename1(self):
580
    """Simple rename 1"""
581
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
582
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
583

    
584
  def testSimpleRename2(self):
585
    """Simple rename 2"""
586
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
587
                     mkdir=True)
588
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
589

    
590
  def testRenameMkdir(self):
591
    """Rename with mkdir"""
592
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
593
                     mkdir=True)
594
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
595
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
596

    
597
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
598
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
599
                     mkdir=True)
600
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
601
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
602
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
603

    
604

    
605
class TestMatchNameComponent(unittest.TestCase):
606
  """Test case for the MatchNameComponent function"""
607

    
608
  def testEmptyList(self):
609
    """Test that there is no match against an empty list"""
610

    
611
    self.failUnlessEqual(MatchNameComponent("", []), None)
612
    self.failUnlessEqual(MatchNameComponent("test", []), None)
613

    
614
  def testSingleMatch(self):
615
    """Test that a single match is performed correctly"""
616
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
617
    for key in "test2", "test2.example", "test2.example.com":
618
      self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
619

    
620
  def testMultipleMatches(self):
621
    """Test that a multiple match is returned as None"""
622
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
623
    for key in "test1", "test1.example":
624
      self.failUnlessEqual(MatchNameComponent(key, mlist), None)
625

    
626
  def testFullMatch(self):
627
    """Test that a full match is returned correctly"""
628
    key1 = "test1"
629
    key2 = "test1.example"
630
    mlist = [key2, key2 + ".com"]
631
    self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
632
    self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
633

    
634
  def testCaseInsensitivePartialMatch(self):
635
    """Test for the case_insensitive keyword"""
636
    mlist = ["test1.example.com", "test2.example.net"]
637
    self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
638
                     "test2.example.net")
639
    self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
640
                     "test2.example.net")
641
    self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
642
                     "test2.example.net")
643
    self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
644
                     "test2.example.net")
645

    
646

    
647
  def testCaseInsensitiveFullMatch(self):
648
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
649
    # Between the two ts1 a full string match non-case insensitive should work
650
    self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
651
                     None)
652
    self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
653
                     "ts1.ex")
654
    self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
655
                     "ts1.ex")
656
    # Between the two ts2 only case differs, so only case-match works
657
    self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
658
                     "ts2.ex")
659
    self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
660
                     "Ts2.ex")
661
    self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
662
                     None)
663

    
664

    
665
class TestTimestampForFilename(unittest.TestCase):
666
  def test(self):
667
    self.assert_("." not in utils.TimestampForFilename())
668
    self.assert_(":" not in utils.TimestampForFilename())
669

    
670

    
671
class TestCreateBackup(testutils.GanetiTestCase):
672
  def setUp(self):
673
    testutils.GanetiTestCase.setUp(self)
674

    
675
    self.tmpdir = tempfile.mkdtemp()
676

    
677
  def tearDown(self):
678
    testutils.GanetiTestCase.tearDown(self)
679

    
680
    shutil.rmtree(self.tmpdir)
681

    
682
  def testEmpty(self):
683
    filename = utils.PathJoin(self.tmpdir, "config.data")
684
    utils.WriteFile(filename, data="")
685
    bname = utils.CreateBackup(filename)
686
    self.assertFileContent(bname, "")
687
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
688
    utils.CreateBackup(filename)
689
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
690
    utils.CreateBackup(filename)
691
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
692

    
693
    fifoname = utils.PathJoin(self.tmpdir, "fifo")
694
    os.mkfifo(fifoname)
695
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
696

    
697
  def testContent(self):
698
    bkpcount = 0
699
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
700
      for rep in [1, 2, 10, 127]:
701
        testdata = data * rep
702

    
703
        filename = utils.PathJoin(self.tmpdir, "test.data_")
704
        utils.WriteFile(filename, data=testdata)
705
        self.assertFileContent(filename, testdata)
706

    
707
        for _ in range(3):
708
          bname = utils.CreateBackup(filename)
709
          bkpcount += 1
710
          self.assertFileContent(bname, testdata)
711
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
712

    
713

    
714
class TestFormatUnit(unittest.TestCase):
715
  """Test case for the FormatUnit function"""
716

    
717
  def testMiB(self):
718
    self.assertEqual(FormatUnit(1, 'h'), '1M')
719
    self.assertEqual(FormatUnit(100, 'h'), '100M')
720
    self.assertEqual(FormatUnit(1023, 'h'), '1023M')
721

    
722
    self.assertEqual(FormatUnit(1, 'm'), '1')
723
    self.assertEqual(FormatUnit(100, 'm'), '100')
724
    self.assertEqual(FormatUnit(1023, 'm'), '1023')
725

    
726
    self.assertEqual(FormatUnit(1024, 'm'), '1024')
727
    self.assertEqual(FormatUnit(1536, 'm'), '1536')
728
    self.assertEqual(FormatUnit(17133, 'm'), '17133')
729
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
730

    
731
  def testGiB(self):
732
    self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
733
    self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
734
    self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
735
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
736

    
737
    self.assertEqual(FormatUnit(1024, 'g'), '1.0')
738
    self.assertEqual(FormatUnit(1536, 'g'), '1.5')
739
    self.assertEqual(FormatUnit(17133, 'g'), '16.7')
740
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
741

    
742
    self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
743
    self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
744
    self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
745

    
746
  def testTiB(self):
747
    self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
748
    self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
749
    self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
750

    
751
    self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
752
    self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
753
    self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
754

    
755
class TestParseUnit(unittest.TestCase):
756
  """Test case for the ParseUnit function"""
757

    
758
  SCALES = (('', 1),
759
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
760
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
761
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
762

    
763
  def testRounding(self):
764
    self.assertEqual(ParseUnit('0'), 0)
765
    self.assertEqual(ParseUnit('1'), 4)
766
    self.assertEqual(ParseUnit('2'), 4)
767
    self.assertEqual(ParseUnit('3'), 4)
768

    
769
    self.assertEqual(ParseUnit('124'), 124)
770
    self.assertEqual(ParseUnit('125'), 128)
771
    self.assertEqual(ParseUnit('126'), 128)
772
    self.assertEqual(ParseUnit('127'), 128)
773
    self.assertEqual(ParseUnit('128'), 128)
774
    self.assertEqual(ParseUnit('129'), 132)
775
    self.assertEqual(ParseUnit('130'), 132)
776

    
777
  def testFloating(self):
778
    self.assertEqual(ParseUnit('0'), 0)
779
    self.assertEqual(ParseUnit('0.5'), 4)
780
    self.assertEqual(ParseUnit('1.75'), 4)
781
    self.assertEqual(ParseUnit('1.99'), 4)
782
    self.assertEqual(ParseUnit('2.00'), 4)
783
    self.assertEqual(ParseUnit('2.01'), 4)
784
    self.assertEqual(ParseUnit('3.99'), 4)
785
    self.assertEqual(ParseUnit('4.00'), 4)
786
    self.assertEqual(ParseUnit('4.01'), 8)
787
    self.assertEqual(ParseUnit('1.5G'), 1536)
788
    self.assertEqual(ParseUnit('1.8G'), 1844)
789
    self.assertEqual(ParseUnit('8.28T'), 8682212)
790

    
791
  def testSuffixes(self):
792
    for sep in ('', ' ', '   ', "\t", "\t "):
793
      for suffix, scale in TestParseUnit.SCALES:
794
        for func in (lambda x: x, str.lower, str.upper):
795
          self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
796
                           1024 * scale)
797

    
798
  def testInvalidInput(self):
799
    for sep in ('-', '_', ',', 'a'):
800
      for suffix, _ in TestParseUnit.SCALES:
801
        self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)
802

    
803
    for suffix, _ in TestParseUnit.SCALES:
804
      self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
805

    
806

    
807
class TestSshKeys(testutils.GanetiTestCase):
808
  """Test case for the AddAuthorizedKey function"""
809

    
810
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
811
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
812
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
813

    
814
  def setUp(self):
815
    testutils.GanetiTestCase.setUp(self)
816
    self.tmpname = self._CreateTempFile()
817
    handle = open(self.tmpname, 'w')
818
    try:
819
      handle.write("%s\n" % TestSshKeys.KEY_A)
820
      handle.write("%s\n" % TestSshKeys.KEY_B)
821
    finally:
822
      handle.close()
823

    
824
  def testAddingNewKey(self):
825
    AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
826

    
827
    self.assertFileContent(self.tmpname,
828
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
829
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
830
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
831
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
832

    
833
  def testAddingAlmostButNotCompletelyTheSameKey(self):
834
    AddAuthorizedKey(self.tmpname,
835
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
836

    
837
    self.assertFileContent(self.tmpname,
838
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
839
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
840
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
841
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
842

    
843
  def testAddingExistingKeyWithSomeMoreSpaces(self):
844
    AddAuthorizedKey(self.tmpname,
845
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
846

    
847
    self.assertFileContent(self.tmpname,
848
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
849
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
850
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
851

    
852
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
853
    RemoveAuthorizedKey(self.tmpname,
854
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
855

    
856
    self.assertFileContent(self.tmpname,
857
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
858
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
859

    
860
  def testRemovingNonExistingKey(self):
861
    RemoveAuthorizedKey(self.tmpname,
862
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
863

    
864
    self.assertFileContent(self.tmpname,
865
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
866
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
867
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
868

    
869

    
870
class TestEtcHosts(testutils.GanetiTestCase):
871
  """Test functions modifying /etc/hosts"""
872

    
873
  def setUp(self):
874
    testutils.GanetiTestCase.setUp(self)
875
    self.tmpname = self._CreateTempFile()
876
    handle = open(self.tmpname, 'w')
877
    try:
878
      handle.write('# This is a test file for /etc/hosts\n')
879
      handle.write('127.0.0.1\tlocalhost\n')
880
      handle.write('192.168.1.1 router gw\n')
881
    finally:
882
      handle.close()
883

    
884
  def testSettingNewIp(self):
885
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
886

    
887
    self.assertFileContent(self.tmpname,
888
      "# This is a test file for /etc/hosts\n"
889
      "127.0.0.1\tlocalhost\n"
890
      "192.168.1.1 router gw\n"
891
      "1.2.3.4\tmyhost.domain.tld myhost\n")
892
    self.assertFileMode(self.tmpname, 0644)
893

    
894
  def testSettingExistingIp(self):
895
    SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
896
                     ['myhost'])
897

    
898
    self.assertFileContent(self.tmpname,
899
      "# This is a test file for /etc/hosts\n"
900
      "127.0.0.1\tlocalhost\n"
901
      "192.168.1.1\tmyhost.domain.tld myhost\n")
902
    self.assertFileMode(self.tmpname, 0644)
903

    
904
  def testSettingDuplicateName(self):
905
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
906

    
907
    self.assertFileContent(self.tmpname,
908
      "# This is a test file for /etc/hosts\n"
909
      "127.0.0.1\tlocalhost\n"
910
      "192.168.1.1 router gw\n"
911
      "1.2.3.4\tmyhost\n")
912
    self.assertFileMode(self.tmpname, 0644)
913

    
914
  def testRemovingExistingHost(self):
915
    RemoveEtcHostsEntry(self.tmpname, 'router')
916

    
917
    self.assertFileContent(self.tmpname,
918
      "# This is a test file for /etc/hosts\n"
919
      "127.0.0.1\tlocalhost\n"
920
      "192.168.1.1 gw\n")
921
    self.assertFileMode(self.tmpname, 0644)
922

    
923
  def testRemovingSingleExistingHost(self):
924
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
925

    
926
    self.assertFileContent(self.tmpname,
927
      "# This is a test file for /etc/hosts\n"
928
      "192.168.1.1 router gw\n")
929
    self.assertFileMode(self.tmpname, 0644)
930

    
931
  def testRemovingNonExistingHost(self):
932
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
933

    
934
    self.assertFileContent(self.tmpname,
935
      "# This is a test file for /etc/hosts\n"
936
      "127.0.0.1\tlocalhost\n"
937
      "192.168.1.1 router gw\n")
938
    self.assertFileMode(self.tmpname, 0644)
939

    
940
  def testRemovingAlias(self):
941
    RemoveEtcHostsEntry(self.tmpname, 'gw')
942

    
943
    self.assertFileContent(self.tmpname,
944
      "# This is a test file for /etc/hosts\n"
945
      "127.0.0.1\tlocalhost\n"
946
      "192.168.1.1 router\n")
947
    self.assertFileMode(self.tmpname, 0644)
948

    
949

    
950
class TestShellQuoting(unittest.TestCase):
951
  """Test case for shell quoting functions"""
952

    
953
  def testShellQuote(self):
954
    self.assertEqual(ShellQuote('abc'), "abc")
955
    self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
956
    self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
957
    self.assertEqual(ShellQuote("a b c"), "'a b c'")
958
    self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
959

    
960
  def testShellQuoteArgs(self):
961
    self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
962
    self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
963
    self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
964

    
965

    
966
class TestTcpPing(unittest.TestCase):
967
  """Testcase for TCP version of ping - against listen(2)ing port"""
968

    
969
  def setUp(self):
970
    self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
971
    self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
972
    self.listenerport = self.listener.getsockname()[1]
973
    self.listener.listen(1)
974

    
975
  def tearDown(self):
976
    self.listener.shutdown(socket.SHUT_RDWR)
977
    del self.listener
978
    del self.listenerport
979

    
980
  def testTcpPingToLocalHostAccept(self):
981
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
982
                         self.listenerport,
983
                         timeout=10,
984
                         live_port_needed=True,
985
                         source=constants.LOCALHOST_IP_ADDRESS,
986
                         ),
987
                 "failed to connect to test listener")
988

    
989
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
990
                         self.listenerport,
991
                         timeout=10,
992
                         live_port_needed=True,
993
                         ),
994
                 "failed to connect to test listener (no source)")
995

    
996

    
997
class TestTcpPingDeaf(unittest.TestCase):
998
  """Testcase for TCP version of ping - against non listen(2)ing port"""
999

    
1000
  def setUp(self):
1001
    self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1002
    self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
1003
    self.deaflistenerport = self.deaflistener.getsockname()[1]
1004

    
1005
  def tearDown(self):
1006
    del self.deaflistener
1007
    del self.deaflistenerport
1008

    
1009
  def testTcpPingToLocalHostAcceptDeaf(self):
1010
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1011
                        self.deaflistenerport,
1012
                        timeout=constants.TCP_PING_TIMEOUT,
1013
                        live_port_needed=True,
1014
                        source=constants.LOCALHOST_IP_ADDRESS,
1015
                        ), # need successful connect(2)
1016
                "successfully connected to deaf listener")
1017

    
1018
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1019
                        self.deaflistenerport,
1020
                        timeout=constants.TCP_PING_TIMEOUT,
1021
                        live_port_needed=True,
1022
                        ), # need successful connect(2)
1023
                "successfully connected to deaf listener (no source addr)")
1024

    
1025
  def testTcpPingToLocalHostNoAccept(self):
1026
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1027
                         self.deaflistenerport,
1028
                         timeout=constants.TCP_PING_TIMEOUT,
1029
                         live_port_needed=False,
1030
                         source=constants.LOCALHOST_IP_ADDRESS,
1031
                         ), # ECONNREFUSED is OK
1032
                 "failed to ping alive host on deaf port")
1033

    
1034
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
1035
                         self.deaflistenerport,
1036
                         timeout=constants.TCP_PING_TIMEOUT,
1037
                         live_port_needed=False,
1038
                         ), # ECONNREFUSED is OK
1039
                 "failed to ping alive host on deaf port (no source addr)")
1040

    
1041

    
1042
class TestOwnIpAddress(unittest.TestCase):
1043
  """Testcase for OwnIpAddress"""
1044

    
1045
  def testOwnLoopback(self):
1046
    """check having the loopback ip"""
1047
    self.failUnless(OwnIpAddress(constants.LOCALHOST_IP_ADDRESS),
1048
                    "Should own the loopback address")
1049

    
1050
  def testNowOwnAddress(self):
1051
    """check that I don't own an address"""
1052

    
1053
    # network 192.0.2.0/24 is reserved for test/documentation as per
1054
    # rfc 3330, so we *should* not have an address of this range... if
1055
    # this fails, we should extend the test to multiple addresses
1056
    DST_IP = "192.0.2.1"
1057
    self.failIf(OwnIpAddress(DST_IP), "Should not own IP address %s" % DST_IP)
1058

    
1059

    
1060
def _GetSocketCredentials(path):
1061
  """Connect to a Unix socket and return remote credentials.
1062

1063
  """
1064
  sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1065
  try:
1066
    sock.settimeout(10)
1067
    sock.connect(path)
1068
    return utils.GetSocketCredentials(sock)
1069
  finally:
1070
    sock.close()
1071

    
1072

    
1073
class TestGetSocketCredentials(unittest.TestCase):
1074
  def setUp(self):
1075
    self.tmpdir = tempfile.mkdtemp()
1076
    self.sockpath = utils.PathJoin(self.tmpdir, "sock")
1077

    
1078
    self.listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
1079
    self.listener.settimeout(10)
1080
    self.listener.bind(self.sockpath)
1081
    self.listener.listen(1)
1082

    
1083
  def tearDown(self):
1084
    self.listener.shutdown(socket.SHUT_RDWR)
1085
    self.listener.close()
1086
    shutil.rmtree(self.tmpdir)
1087

    
1088
  def test(self):
1089
    (c2pr, c2pw) = os.pipe()
1090

    
1091
    # Start child process
1092
    child = os.fork()
1093
    if child == 0:
1094
      try:
1095
        data = serializer.DumpJson(_GetSocketCredentials(self.sockpath))
1096

    
1097
        os.write(c2pw, data)
1098
        os.close(c2pw)
1099

    
1100
        os._exit(0)
1101
      finally:
1102
        os._exit(1)
1103

    
1104
    os.close(c2pw)
1105

    
1106
    # Wait for one connection
1107
    (conn, _) = self.listener.accept()
1108
    conn.recv(1)
1109
    conn.close()
1110

    
1111
    # Wait for result
1112
    result = os.read(c2pr, 4096)
1113
    os.close(c2pr)
1114

    
1115
    # Check child's exit code
1116
    (_, status) = os.waitpid(child, 0)
1117
    self.assertFalse(os.WIFSIGNALED(status))
1118
    self.assertEqual(os.WEXITSTATUS(status), 0)
1119

    
1120
    # Check result
1121
    (pid, uid, gid) = serializer.LoadJson(result)
1122
    self.assertEqual(pid, os.getpid())
1123
    self.assertEqual(uid, os.getuid())
1124
    self.assertEqual(gid, os.getgid())
1125

    
1126

    
1127
class TestListVisibleFiles(unittest.TestCase):
1128
  """Test case for ListVisibleFiles"""
1129

    
1130
  def setUp(self):
1131
    self.path = tempfile.mkdtemp()
1132

    
1133
  def tearDown(self):
1134
    shutil.rmtree(self.path)
1135

    
1136
  def _test(self, files, expected):
1137
    # Sort a copy
1138
    expected = expected[:]
1139
    expected.sort()
1140

    
1141
    for name in files:
1142
      f = open(os.path.join(self.path, name), 'w')
1143
      try:
1144
        f.write("Test\n")
1145
      finally:
1146
        f.close()
1147

    
1148
    found = ListVisibleFiles(self.path)
1149
    found.sort()
1150

    
1151
    self.assertEqual(found, expected)
1152

    
1153
  def testAllVisible(self):
1154
    files = ["a", "b", "c"]
1155
    expected = files
1156
    self._test(files, expected)
1157

    
1158
  def testNoneVisible(self):
1159
    files = [".a", ".b", ".c"]
1160
    expected = []
1161
    self._test(files, expected)
1162

    
1163
  def testSomeVisible(self):
1164
    files = ["a", "b", ".c"]
1165
    expected = ["a", "b"]
1166
    self._test(files, expected)
1167

    
1168
  def testNonAbsolutePath(self):
1169
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1170

    
1171
  def testNonNormalizedPath(self):
1172
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1173
                          "/bin/../tmp")
1174

    
1175

    
1176
class TestNewUUID(unittest.TestCase):
1177
  """Test case for NewUUID"""
1178

    
1179
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
1180
                        '[a-f0-9]{4}-[a-f0-9]{12}$')
1181

    
1182
  def runTest(self):
1183
    self.failUnless(self._re_uuid.match(utils.NewUUID()))
1184

    
1185

    
1186
class TestUniqueSequence(unittest.TestCase):
1187
  """Test case for UniqueSequence"""
1188

    
1189
  def _test(self, input, expected):
1190
    self.assertEqual(utils.UniqueSequence(input), expected)
1191

    
1192
  def runTest(self):
1193
    # Ordered input
1194
    self._test([1, 2, 3], [1, 2, 3])
1195
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
1196
    self._test([1, 2, 2, 3], [1, 2, 3])
1197
    self._test([1, 2, 3, 3], [1, 2, 3])
1198

    
1199
    # Unordered input
1200
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
1201
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
1202

    
1203
    # Strings
1204
    self._test(["a", "a"], ["a"])
1205
    self._test(["a", "b"], ["a", "b"])
1206
    self._test(["a", "b", "a"], ["a", "b"])
1207

    
1208

    
1209
class TestFirstFree(unittest.TestCase):
1210
  """Test case for the FirstFree function"""
1211

    
1212
  def test(self):
1213
    """Test FirstFree"""
1214
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1215
    self.failUnlessEqual(FirstFree([]), None)
1216
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1217
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1218
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1219

    
1220

    
1221
class TestTailFile(testutils.GanetiTestCase):
1222
  """Test case for the TailFile function"""
1223

    
1224
  def testEmpty(self):
1225
    fname = self._CreateTempFile()
1226
    self.failUnlessEqual(TailFile(fname), [])
1227
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1228

    
1229
  def testAllLines(self):
1230
    data = ["test %d" % i for i in range(30)]
1231
    for i in range(30):
1232
      fname = self._CreateTempFile()
1233
      fd = open(fname, "w")
1234
      fd.write("\n".join(data[:i]))
1235
      if i > 0:
1236
        fd.write("\n")
1237
      fd.close()
1238
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1239

    
1240
  def testPartialLines(self):
1241
    data = ["test %d" % i for i in range(30)]
1242
    fname = self._CreateTempFile()
1243
    fd = open(fname, "w")
1244
    fd.write("\n".join(data))
1245
    fd.write("\n")
1246
    fd.close()
1247
    for i in range(1, 30):
1248
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1249

    
1250
  def testBigFile(self):
1251
    data = ["test %d" % i for i in range(30)]
1252
    fname = self._CreateTempFile()
1253
    fd = open(fname, "w")
1254
    fd.write("X" * 1048576)
1255
    fd.write("\n")
1256
    fd.write("\n".join(data))
1257
    fd.write("\n")
1258
    fd.close()
1259
    for i in range(1, 30):
1260
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1261

    
1262

    
1263
class _BaseFileLockTest:
1264
  """Test case for the FileLock class"""
1265

    
1266
  def testSharedNonblocking(self):
1267
    self.lock.Shared(blocking=False)
1268
    self.lock.Close()
1269

    
1270
  def testExclusiveNonblocking(self):
1271
    self.lock.Exclusive(blocking=False)
1272
    self.lock.Close()
1273

    
1274
  def testUnlockNonblocking(self):
1275
    self.lock.Unlock(blocking=False)
1276
    self.lock.Close()
1277

    
1278
  def testSharedBlocking(self):
1279
    self.lock.Shared(blocking=True)
1280
    self.lock.Close()
1281

    
1282
  def testExclusiveBlocking(self):
1283
    self.lock.Exclusive(blocking=True)
1284
    self.lock.Close()
1285

    
1286
  def testUnlockBlocking(self):
1287
    self.lock.Unlock(blocking=True)
1288
    self.lock.Close()
1289

    
1290
  def testSharedExclusiveUnlock(self):
1291
    self.lock.Shared(blocking=False)
1292
    self.lock.Exclusive(blocking=False)
1293
    self.lock.Unlock(blocking=False)
1294
    self.lock.Close()
1295

    
1296
  def testExclusiveSharedUnlock(self):
1297
    self.lock.Exclusive(blocking=False)
1298
    self.lock.Shared(blocking=False)
1299
    self.lock.Unlock(blocking=False)
1300
    self.lock.Close()
1301

    
1302
  def testSimpleTimeout(self):
1303
    # These will succeed on the first attempt, hence a short timeout
1304
    self.lock.Shared(blocking=True, timeout=10.0)
1305
    self.lock.Exclusive(blocking=False, timeout=10.0)
1306
    self.lock.Unlock(blocking=True, timeout=10.0)
1307
    self.lock.Close()
1308

    
1309
  @staticmethod
1310
  def _TryLockInner(filename, shared, blocking):
1311
    lock = utils.FileLock.Open(filename)
1312

    
1313
    if shared:
1314
      fn = lock.Shared
1315
    else:
1316
      fn = lock.Exclusive
1317

    
1318
    try:
1319
      # The timeout doesn't really matter as the parent process waits for us to
1320
      # finish anyway.
1321
      fn(blocking=blocking, timeout=0.01)
1322
    except errors.LockError, err:
1323
      return False
1324

    
1325
    return True
1326

    
1327
  def _TryLock(self, *args):
1328
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1329
                                      *args)
1330

    
1331
  def testTimeout(self):
1332
    for blocking in [True, False]:
1333
      self.lock.Exclusive(blocking=True)
1334
      self.failIf(self._TryLock(False, blocking))
1335
      self.failIf(self._TryLock(True, blocking))
1336

    
1337
      self.lock.Shared(blocking=True)
1338
      self.assert_(self._TryLock(True, blocking))
1339
      self.failIf(self._TryLock(False, blocking))
1340

    
1341
  def testCloseShared(self):
1342
    self.lock.Close()
1343
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1344

    
1345
  def testCloseExclusive(self):
1346
    self.lock.Close()
1347
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1348

    
1349
  def testCloseUnlock(self):
1350
    self.lock.Close()
1351
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1352

    
1353

    
1354
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1355
  TESTDATA = "Hello World\n" * 10
1356

    
1357
  def setUp(self):
1358
    testutils.GanetiTestCase.setUp(self)
1359

    
1360
    self.tmpfile = tempfile.NamedTemporaryFile()
1361
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1362
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1363

    
1364
    # Ensure "Open" didn't truncate file
1365
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1366

    
1367
  def tearDown(self):
1368
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1369

    
1370
    testutils.GanetiTestCase.tearDown(self)
1371

    
1372

    
1373
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1374
  def setUp(self):
1375
    self.tmpfile = tempfile.NamedTemporaryFile()
1376
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1377

    
1378

    
1379
class TestTimeFunctions(unittest.TestCase):
1380
  """Test case for time functions"""
1381

    
1382
  def runTest(self):
1383
    self.assertEqual(utils.SplitTime(1), (1, 0))
1384
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1385
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1386
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1387
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1388
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1389
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1390
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1391

    
1392
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1393

    
1394
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1395
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1396
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1397

    
1398
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1399
                     1218448917.481)
1400
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1401

    
1402
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1403
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1404
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1405
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1406
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1407

    
1408

    
1409
class FieldSetTestCase(unittest.TestCase):
1410
  """Test case for FieldSets"""
1411

    
1412
  def testSimpleMatch(self):
1413
    f = utils.FieldSet("a", "b", "c", "def")
1414
    self.failUnless(f.Matches("a"))
1415
    self.failIf(f.Matches("d"), "Substring matched")
1416
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1417
    self.failIf(f.NonMatching(["b", "c"]))
1418
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1419
    self.failUnless(f.NonMatching(["a", "d"]))
1420

    
1421
  def testRegexMatch(self):
1422
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1423
    self.failUnless(f.Matches("b1"))
1424
    self.failUnless(f.Matches("b99"))
1425
    self.failIf(f.Matches("b/1"))
1426
    self.failIf(f.NonMatching(["b12", "c"]))
1427
    self.failUnless(f.NonMatching(["a", "1"]))
1428

    
1429
class TestForceDictType(unittest.TestCase):
1430
  """Test case for ForceDictType"""
1431

    
1432
  def setUp(self):
1433
    self.key_types = {
1434
      'a': constants.VTYPE_INT,
1435
      'b': constants.VTYPE_BOOL,
1436
      'c': constants.VTYPE_STRING,
1437
      'd': constants.VTYPE_SIZE,
1438
      }
1439

    
1440
  def _fdt(self, dict, allowed_values=None):
1441
    if allowed_values is None:
1442
      ForceDictType(dict, self.key_types)
1443
    else:
1444
      ForceDictType(dict, self.key_types, allowed_values=allowed_values)
1445

    
1446
    return dict
1447

    
1448
  def testSimpleDict(self):
1449
    self.assertEqual(self._fdt({}), {})
1450
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1451
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1452
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1453
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1454
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1455
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1456
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1457
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1458
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1459
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1460
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1461

    
1462
  def testErrors(self):
1463
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1464
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1465
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1466
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1467

    
1468

    
1469
class TestIsAbsNormPath(unittest.TestCase):
1470
  """Testing case for IsNormAbsPath"""
1471

    
1472
  def _pathTestHelper(self, path, result):
1473
    if result:
1474
      self.assert_(IsNormAbsPath(path),
1475
          "Path %s should result absolute and normalized" % path)
1476
    else:
1477
      self.assert_(not IsNormAbsPath(path),
1478
          "Path %s should not result absolute and normalized" % path)
1479

    
1480
  def testBase(self):
1481
    self._pathTestHelper('/etc', True)
1482
    self._pathTestHelper('/srv', True)
1483
    self._pathTestHelper('etc', False)
1484
    self._pathTestHelper('/etc/../root', False)
1485
    self._pathTestHelper('/etc/', False)
1486

    
1487

    
1488
class TestSafeEncode(unittest.TestCase):
1489
  """Test case for SafeEncode"""
1490

    
1491
  def testAscii(self):
1492
    for txt in [string.digits, string.letters, string.punctuation]:
1493
      self.failUnlessEqual(txt, SafeEncode(txt))
1494

    
1495
  def testDoubleEncode(self):
1496
    for i in range(255):
1497
      txt = SafeEncode(chr(i))
1498
      self.failUnlessEqual(txt, SafeEncode(txt))
1499

    
1500
  def testUnicode(self):
1501
    # 1024 is high enough to catch non-direct ASCII mappings
1502
    for i in range(1024):
1503
      txt = SafeEncode(unichr(i))
1504
      self.failUnlessEqual(txt, SafeEncode(txt))
1505

    
1506

    
1507
class TestFormatTime(unittest.TestCase):
1508
  """Testing case for FormatTime"""
1509

    
1510
  def testNone(self):
1511
    self.failUnlessEqual(FormatTime(None), "N/A")
1512

    
1513
  def testInvalid(self):
1514
    self.failUnlessEqual(FormatTime(()), "N/A")
1515

    
1516
  def testNow(self):
1517
    # tests that we accept time.time input
1518
    FormatTime(time.time())
1519
    # tests that we accept int input
1520
    FormatTime(int(time.time()))
1521

    
1522

    
1523
class RunInSeparateProcess(unittest.TestCase):
1524
  def test(self):
1525
    for exp in [True, False]:
1526
      def _child():
1527
        return exp
1528

    
1529
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1530

    
1531
  def testArgs(self):
1532
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1533
      def _child(carg1, carg2):
1534
        return carg1 == "Foo" and carg2 == arg
1535

    
1536
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1537

    
1538
  def testPid(self):
1539
    parent_pid = os.getpid()
1540

    
1541
    def _check():
1542
      return os.getpid() == parent_pid
1543

    
1544
    self.failIf(utils.RunInSeparateProcess(_check))
1545

    
1546
  def testSignal(self):
1547
    def _kill():
1548
      os.kill(os.getpid(), signal.SIGTERM)
1549

    
1550
    self.assertRaises(errors.GenericError,
1551
                      utils.RunInSeparateProcess, _kill)
1552

    
1553
  def testException(self):
1554
    def _exc():
1555
      raise errors.GenericError("This is a test")
1556

    
1557
    self.assertRaises(errors.GenericError,
1558
                      utils.RunInSeparateProcess, _exc)
1559

    
1560

    
1561
class TestFingerprintFile(unittest.TestCase):
1562
  def setUp(self):
1563
    self.tmpfile = tempfile.NamedTemporaryFile()
1564

    
1565
  def test(self):
1566
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1567
                     "da39a3ee5e6b4b0d3255bfef95601890afd80709")
1568

    
1569
    utils.WriteFile(self.tmpfile.name, data="Hello World\n")
1570
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1571
                     "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a")
1572

    
1573

    
1574
class TestUnescapeAndSplit(unittest.TestCase):
1575
  """Testing case for UnescapeAndSplit"""
1576

    
1577
  def setUp(self):
1578
    # testing more that one separator for regexp safety
1579
    self._seps = [",", "+", "."]
1580

    
1581
  def testSimple(self):
1582
    a = ["a", "b", "c", "d"]
1583
    for sep in self._seps:
1584
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1585

    
1586
  def testEscape(self):
1587
    for sep in self._seps:
1588
      a = ["a", "b\\" + sep + "c", "d"]
1589
      b = ["a", "b" + sep + "c", "d"]
1590
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1591

    
1592
  def testDoubleEscape(self):
1593
    for sep in self._seps:
1594
      a = ["a", "b\\\\", "c", "d"]
1595
      b = ["a", "b\\", "c", "d"]
1596
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1597

    
1598
  def testThreeEscape(self):
1599
    for sep in self._seps:
1600
      a = ["a", "b\\\\\\" + sep + "c", "d"]
1601
      b = ["a", "b\\" + sep + "c", "d"]
1602
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1603

    
1604

    
1605
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1606
  def setUp(self):
1607
    self.tmpdir = tempfile.mkdtemp()
1608

    
1609
  def tearDown(self):
1610
    shutil.rmtree(self.tmpdir)
1611

    
1612
  def _checkRsaPrivateKey(self, key):
1613
    lines = key.splitlines()
1614
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1615
            "-----END RSA PRIVATE KEY-----" in lines)
1616

    
1617
  def _checkCertificate(self, cert):
1618
    lines = cert.splitlines()
1619
    return ("-----BEGIN CERTIFICATE-----" in lines and
1620
            "-----END CERTIFICATE-----" in lines)
1621

    
1622
  def test(self):
1623
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1624
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1625
      self._checkRsaPrivateKey(key_pem)
1626
      self._checkCertificate(cert_pem)
1627

    
1628
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1629
                                           key_pem)
1630
      self.assert_(key.bits() >= 1024)
1631
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1632
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1633

    
1634
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1635
                                             cert_pem)
1636
      self.failIf(x509.has_expired())
1637
      self.assertEqual(x509.get_issuer().CN, common_name)
1638
      self.assertEqual(x509.get_subject().CN, common_name)
1639
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1640

    
1641
  def testLegacy(self):
1642
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1643

    
1644
    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1645

    
1646
    cert1 = utils.ReadFile(cert1_filename)
1647

    
1648
    self.assert_(self._checkRsaPrivateKey(cert1))
1649
    self.assert_(self._checkCertificate(cert1))
1650

    
1651

    
1652
class TestPathJoin(unittest.TestCase):
1653
  """Testing case for PathJoin"""
1654

    
1655
  def testBasicItems(self):
1656
    mlist = ["/a", "b", "c"]
1657
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1658

    
1659
  def testNonAbsPrefix(self):
1660
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1661

    
1662
  def testBackTrack(self):
1663
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1664

    
1665
  def testMultiAbs(self):
1666
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1667

    
1668

    
1669
class TestHostInfo(unittest.TestCase):
1670
  """Testing case for HostInfo"""
1671

    
1672
  def testUppercase(self):
1673
    data = "AbC.example.com"
1674
    self.failUnlessEqual(HostInfo.NormalizeName(data), data.lower())
1675

    
1676
  def testTooLongName(self):
1677
    data = "a.b." + "c" * 255
1678
    self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, data)
1679

    
1680
  def testTrailingDot(self):
1681
    data = "a.b.c"
1682
    self.failUnlessEqual(HostInfo.NormalizeName(data + "."), data)
1683

    
1684
  def testInvalidName(self):
1685
    data = [
1686
      "a b",
1687
      "a/b",
1688
      ".a.b",
1689
      "a..b",
1690
      ]
1691
    for value in data:
1692
      self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, value)
1693

    
1694
  def testValidName(self):
1695
    data = [
1696
      "a.b",
1697
      "a-b",
1698
      "a_b",
1699
      "a.b.c",
1700
      ]
1701
    for value in data:
1702
      HostInfo.NormalizeName(value)
1703

    
1704

    
1705
class TestParseAsn1Generalizedtime(unittest.TestCase):
1706
  def test(self):
1707
    # UTC
1708
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1709
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1710
                     1266860512)
1711
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1712
                     (2**31) - 1)
1713

    
1714
    # With offset
1715
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1716
                     1266860512)
1717
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1718
                     1266931012)
1719
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1720
                     1266931088)
1721
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1722
                     1266931295)
1723
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1724
                     3600)
1725

    
1726
    # Leap seconds are not supported by datetime.datetime
1727
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1728
                      "19841231235960+0000")
1729
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1730
                      "19920630235960+0000")
1731

    
1732
    # Errors
1733
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1734
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1735
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1736
                      "20100222174152")
1737
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1738
                      "Mon Feb 22 17:47:02 UTC 2010")
1739
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1740
                      "2010-02-22 17:42:02")
1741

    
1742

    
1743
class TestGetX509CertValidity(testutils.GanetiTestCase):
1744
  def setUp(self):
1745
    testutils.GanetiTestCase.setUp(self)
1746

    
1747
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1748

    
1749
    # Test whether we have pyOpenSSL 0.7 or above
1750
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1751

    
1752
    if not self.pyopenssl0_7:
1753
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1754
                    " function correctly")
1755

    
1756
  def _LoadCert(self, name):
1757
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1758
                                           self._ReadTestData(name))
1759

    
1760
  def test(self):
1761
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1762
    if self.pyopenssl0_7:
1763
      self.assertEqual(validity, (1266919967, 1267524767))
1764
    else:
1765
      self.assertEqual(validity, (None, None))
1766

    
1767

    
1768
class TestSignX509Certificate(unittest.TestCase):
1769
  KEY = "My private key!"
1770
  KEY_OTHER = "Another key"
1771

    
1772
  def test(self):
1773
    # Generate certificate valid for 5 minutes
1774
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
1775

    
1776
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1777
                                           cert_pem)
1778

    
1779
    # No signature at all
1780
    self.assertRaises(errors.GenericError,
1781
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
1782

    
1783
    # Invalid input
1784
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1785
                      "", self.KEY)
1786
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1787
                      "X-Ganeti-Signature: \n", self.KEY)
1788
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1789
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
1790
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1791
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
1792
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1793
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
1794

    
1795
    # Invalid salt
1796
    for salt in list("-_@$,:;/\\ \t\n"):
1797
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
1798
                        cert_pem, self.KEY, "foo%sbar" % salt)
1799

    
1800
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
1801
                 utils.GenerateSecret(numbytes=4),
1802
                 utils.GenerateSecret(numbytes=16),
1803
                 "{123:456}".encode("hex")]:
1804
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
1805

    
1806
      self._Check(cert, salt, signed_pem)
1807

    
1808
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
1809
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
1810
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
1811
                               "lines----\n------ at\nthe end!"))
1812

    
1813
  def _Check(self, cert, salt, pem):
1814
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
1815
    self.assertEqual(salt, salt2)
1816
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
1817

    
1818
    # Other key
1819
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1820
                      pem, self.KEY_OTHER)
1821

    
1822

    
1823
class TestMakedirs(unittest.TestCase):
1824
  def setUp(self):
1825
    self.tmpdir = tempfile.mkdtemp()
1826

    
1827
  def tearDown(self):
1828
    shutil.rmtree(self.tmpdir)
1829

    
1830
  def testNonExisting(self):
1831
    path = utils.PathJoin(self.tmpdir, "foo")
1832
    utils.Makedirs(path)
1833
    self.assert_(os.path.isdir(path))
1834

    
1835
  def testExisting(self):
1836
    path = utils.PathJoin(self.tmpdir, "foo")
1837
    os.mkdir(path)
1838
    utils.Makedirs(path)
1839
    self.assert_(os.path.isdir(path))
1840

    
1841
  def testRecursiveNonExisting(self):
1842
    path = utils.PathJoin(self.tmpdir, "foo/bar/baz")
1843
    utils.Makedirs(path)
1844
    self.assert_(os.path.isdir(path))
1845

    
1846
  def testRecursiveExisting(self):
1847
    path = utils.PathJoin(self.tmpdir, "B/moo/xyz")
1848
    self.assert_(not os.path.exists(path))
1849
    os.mkdir(utils.PathJoin(self.tmpdir, "B"))
1850
    utils.Makedirs(path)
1851
    self.assert_(os.path.isdir(path))
1852

    
1853

    
1854
class TestRetry(testutils.GanetiTestCase):
1855
  @staticmethod
1856
  def _RaiseRetryAgain():
1857
    raise utils.RetryAgain()
1858

    
1859
  def _WrongNestedLoop(self):
1860
    return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
1861

    
1862
  def testRaiseTimeout(self):
1863
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
1864
                          self._RaiseRetryAgain, 0.01, 0.02)
1865

    
1866
  def testComplete(self):
1867
    self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
1868

    
1869
  def testNestedLoop(self):
1870
    try:
1871
      self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
1872
                            self._WrongNestedLoop, 0, 1)
1873
    except utils.RetryTimeout:
1874
      self.fail("Didn't detect inner loop's exception")
1875

    
1876

    
1877
class TestLineSplitter(unittest.TestCase):
1878
  def test(self):
1879
    lines = []
1880
    ls = utils.LineSplitter(lines.append)
1881
    ls.write("Hello World\n")
1882
    self.assertEqual(lines, [])
1883
    ls.write("Foo\n Bar\r\n ")
1884
    ls.write("Baz")
1885
    ls.write("Moo")
1886
    self.assertEqual(lines, [])
1887
    ls.flush()
1888
    self.assertEqual(lines, ["Hello World", "Foo", " Bar"])
1889
    ls.close()
1890
    self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"])
1891

    
1892
  def _testExtra(self, line, all_lines, p1, p2):
1893
    self.assertEqual(p1, 999)
1894
    self.assertEqual(p2, "extra")
1895
    all_lines.append(line)
1896

    
1897
  def testExtraArgsNoFlush(self):
1898
    lines = []
1899
    ls = utils.LineSplitter(self._testExtra, lines, 999, "extra")
1900
    ls.write("\n\nHello World\n")
1901
    ls.write("Foo\n Bar\r\n ")
1902
    ls.write("")
1903
    ls.write("Baz")
1904
    ls.write("Moo\n\nx\n")
1905
    self.assertEqual(lines, [])
1906
    ls.close()
1907
    self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo",
1908
                             "", "x"])
1909

    
1910

    
1911
class TestReadLockedPidFile(unittest.TestCase):
1912
  def setUp(self):
1913
    self.tmpdir = tempfile.mkdtemp()
1914

    
1915
  def tearDown(self):
1916
    shutil.rmtree(self.tmpdir)
1917

    
1918
  def testNonExistent(self):
1919
    path = utils.PathJoin(self.tmpdir, "nonexist")
1920
    self.assert_(utils.ReadLockedPidFile(path) is None)
1921

    
1922
  def testUnlocked(self):
1923
    path = utils.PathJoin(self.tmpdir, "pid")
1924
    utils.WriteFile(path, data="123")
1925
    self.assert_(utils.ReadLockedPidFile(path) is None)
1926

    
1927
  def testLocked(self):
1928
    path = utils.PathJoin(self.tmpdir, "pid")
1929
    utils.WriteFile(path, data="123")
1930

    
1931
    fl = utils.FileLock.Open(path)
1932
    try:
1933
      fl.Exclusive(blocking=True)
1934

    
1935
      self.assertEqual(utils.ReadLockedPidFile(path), 123)
1936
    finally:
1937
      fl.Close()
1938

    
1939
    self.assert_(utils.ReadLockedPidFile(path) is None)
1940

    
1941
  def testError(self):
1942
    path = utils.PathJoin(self.tmpdir, "foobar", "pid")
1943
    utils.WriteFile(utils.PathJoin(self.tmpdir, "foobar"), data="")
1944
    # open(2) should return ENOTDIR
1945
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
1946

    
1947

    
1948
class TestCertVerification(testutils.GanetiTestCase):
1949
  def setUp(self):
1950
    testutils.GanetiTestCase.setUp(self)
1951

    
1952
    self.tmpdir = tempfile.mkdtemp()
1953

    
1954
  def tearDown(self):
1955
    shutil.rmtree(self.tmpdir)
1956

    
1957
  def testVerifyCertificate(self):
1958
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
1959
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1960
                                           cert_pem)
1961

    
1962
    # Not checking return value as this certificate is expired
1963
    utils.VerifyX509Certificate(cert, 30, 7)
1964

    
1965

    
1966
class TestVerifyCertificateInner(unittest.TestCase):
1967
  def test(self):
1968
    vci = utils._VerifyCertificateInner
1969

    
1970
    # Valid
1971
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
1972
                     (None, None))
1973

    
1974
    # Not yet valid
1975
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
1976
    self.assertEqual(errcode, utils.CERT_WARNING)
1977

    
1978
    # Expiring soon
1979
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
1980
    self.assertEqual(errcode, utils.CERT_ERROR)
1981

    
1982
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
1983
    self.assertEqual(errcode, utils.CERT_WARNING)
1984

    
1985
    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
1986
    self.assertEqual(errcode, None)
1987

    
1988
    # Expired
1989
    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
1990
    self.assertEqual(errcode, utils.CERT_ERROR)
1991

    
1992
    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
1993
    self.assertEqual(errcode, utils.CERT_ERROR)
1994

    
1995
    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
1996
    self.assertEqual(errcode, utils.CERT_ERROR)
1997

    
1998
    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
1999
    self.assertEqual(errcode, utils.CERT_ERROR)
2000

    
2001

    
2002
class TestHmacFunctions(unittest.TestCase):
2003
  # Digests can be checked with "openssl sha1 -hmac $key"
2004
  def testSha1Hmac(self):
2005
    self.assertEqual(utils.Sha1Hmac("", ""),
2006
                     "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d")
2007
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World"),
2008
                     "ef4f3bda82212ecb2f7ce868888a19092481f1fd")
2009
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", ""),
2010
                     "f904c2476527c6d3e6609ab683c66fa0652cb1dc")
2011

    
2012
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
2013
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", longtext),
2014
                     "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54")
2015

    
2016
  def testVerifySha1Hmac(self):
2017
    self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b"
2018
                                               "7d64b71fb76370690e1d")))
2019
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2020
                                      ("f904c2476527c6d3e660"
2021
                                       "9ab683c66fa0652cb1dc")))
2022

    
2023
    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
2024
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", digest))
2025
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2026
                                      digest.lower()))
2027
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2028
                                      digest.upper()))
2029
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2030
                                      digest.title()))
2031

    
2032

    
2033
if __name__ == '__main__':
2034
  testutils.GanetiTestProgram()