Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ 8342c325

History | View | Annotate | Download (68.6 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import distutils.version
25
import errno
26
import fcntl
27
import glob
28
import os
29
import os.path
30
import re
31
import shutil
32
import signal
33
import socket
34
import stat
35
import string
36
import tempfile
37
import time
38
import unittest
39
import warnings
40
import OpenSSL
41
import random
42
import operator
43

    
44
import testutils
45
from ganeti import constants
46
from ganeti import compat
47
from ganeti import utils
48
from ganeti import errors
49
from ganeti.utils import RunCmd, RemoveFile, \
50
     ListVisibleFiles, FirstFree, \
51
     TailFile, RunParts, PathJoin, \
52
     ReadOneLineFile, SetEtcHostsEntry, RemoveEtcHostsEntry
53

    
54

    
55
class TestIsProcessAlive(unittest.TestCase):
56
  """Testing case for IsProcessAlive"""
57

    
58
  def testExists(self):
59
    mypid = os.getpid()
60
    self.assert_(utils.IsProcessAlive(mypid), "can't find myself running")
61

    
62
  def testNotExisting(self):
63
    pid_non_existing = os.fork()
64
    if pid_non_existing == 0:
65
      os._exit(0)
66
    elif pid_non_existing < 0:
67
      raise SystemError("can't fork")
68
    os.waitpid(pid_non_existing, 0)
69
    self.assertFalse(utils.IsProcessAlive(pid_non_existing),
70
                     "nonexisting process detected")
71

    
72

    
73
class TestGetProcStatusPath(unittest.TestCase):
74
  def test(self):
75
    self.assert_("/1234/" in utils._GetProcStatusPath(1234))
76
    self.assertNotEqual(utils._GetProcStatusPath(1),
77
                        utils._GetProcStatusPath(2))
78

    
79

    
80
class TestIsProcessHandlingSignal(unittest.TestCase):
81
  def setUp(self):
82
    self.tmpdir = tempfile.mkdtemp()
83

    
84
  def tearDown(self):
85
    shutil.rmtree(self.tmpdir)
86

    
87
  def testParseSigsetT(self):
88
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
89
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
90
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
91
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
92
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
93
                     set([2, 10, 32, 33]))
94
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
95
                     set([2, 32, 33]))
96
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
97
                     set([2, 28, 32, 33]))
98
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
99
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
100
                          24, 25, 26, 28, 31]))
101
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))
102

    
103
  def testGetProcStatusField(self):
104
    for field in ["SigCgt", "Name", "FDSize"]:
105
      for value in ["", "0", "cat", "  1234 KB"]:
106
        pstatus = "\n".join([
107
          "VmPeak: 999 kB",
108
          "%s: %s" % (field, value),
109
          "TracerPid: 0",
110
          ])
111
        result = utils._GetProcStatusField(pstatus, field)
112
        self.assertEqual(result, value.strip())
113

    
114
  def test(self):
115
    sp = PathJoin(self.tmpdir, "status")
116

    
117
    utils.WriteFile(sp, data="\n".join([
118
      "Name:   bash",
119
      "State:  S (sleeping)",
120
      "SleepAVG:       98%",
121
      "Pid:    22250",
122
      "PPid:   10858",
123
      "TracerPid:      0",
124
      "SigBlk: 0000000000010000",
125
      "SigIgn: 0000000000384004",
126
      "SigCgt: 000000004b813efb",
127
      "CapEff: 0000000000000000",
128
      ]))
129

    
130
    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
131

    
132
  def testNoSigCgt(self):
133
    sp = PathJoin(self.tmpdir, "status")
134

    
135
    utils.WriteFile(sp, data="\n".join([
136
      "Name:   bash",
137
      ]))
138

    
139
    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
140
                      1234, 10, status_path=sp)
141

    
142
  def testNoSuchFile(self):
143
    sp = PathJoin(self.tmpdir, "notexist")
144

    
145
    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
146

    
147
  @staticmethod
148
  def _TestRealProcess():
149
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
150
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
151
      raise Exception("SIGUSR1 is handled when it should not be")
152

    
153
    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
154
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
155
      raise Exception("SIGUSR1 is not handled when it should be")
156

    
157
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
158
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
159
      raise Exception("SIGUSR1 is not handled when it should be")
160

    
161
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
162
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
163
      raise Exception("SIGUSR1 is handled when it should not be")
164

    
165
    return True
166

    
167
  def testRealProcess(self):
168
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
169

    
170

    
171
class TestPidFileFunctions(unittest.TestCase):
172
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
173

    
174
  def setUp(self):
175
    self.dir = tempfile.mkdtemp()
176
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
177

    
178
  def testPidFileFunctions(self):
179
    pid_file = self.f_dpn('test')
180
    fd = utils.WritePidFile(self.f_dpn('test'))
181
    self.failUnless(os.path.exists(pid_file),
182
                    "PID file should have been created")
183
    read_pid = utils.ReadPidFile(pid_file)
184
    self.failUnlessEqual(read_pid, os.getpid())
185
    self.failUnless(utils.IsProcessAlive(read_pid))
186
    self.failUnlessRaises(errors.LockError, utils.WritePidFile,
187
                          self.f_dpn('test'))
188
    os.close(fd)
189
    utils.RemovePidFile(self.f_dpn("test"))
190
    self.failIf(os.path.exists(pid_file),
191
                "PID file should not exist anymore")
192
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
193
                         "ReadPidFile should return 0 for missing pid file")
194
    fh = open(pid_file, "w")
195
    fh.write("blah\n")
196
    fh.close()
197
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
198
                         "ReadPidFile should return 0 for invalid pid file")
199
    # but now, even with the file existing, we should be able to lock it
200
    fd = utils.WritePidFile(self.f_dpn('test'))
201
    os.close(fd)
202
    utils.RemovePidFile(self.f_dpn("test"))
203
    self.failIf(os.path.exists(pid_file),
204
                "PID file should not exist anymore")
205

    
206
  def testKill(self):
207
    pid_file = self.f_dpn('child')
208
    r_fd, w_fd = os.pipe()
209
    new_pid = os.fork()
210
    if new_pid == 0: #child
211
      utils.WritePidFile(self.f_dpn('child'))
212
      os.write(w_fd, 'a')
213
      signal.pause()
214
      os._exit(0)
215
      return
216
    # else we are in the parent
217
    # wait until the child has written the pid file
218
    os.read(r_fd, 1)
219
    read_pid = utils.ReadPidFile(pid_file)
220
    self.failUnlessEqual(read_pid, new_pid)
221
    self.failUnless(utils.IsProcessAlive(new_pid))
222
    utils.KillProcess(new_pid, waitpid=True)
223
    self.failIf(utils.IsProcessAlive(new_pid))
224
    utils.RemovePidFile(self.f_dpn('child'))
225
    self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)
226

    
227
  def tearDown(self):
228
    for name in os.listdir(self.dir):
229
      os.unlink(os.path.join(self.dir, name))
230
    os.rmdir(self.dir)
231

    
232

    
233
class TestRunCmd(testutils.GanetiTestCase):
234
  """Testing case for the RunCmd function"""
235

    
236
  def setUp(self):
237
    testutils.GanetiTestCase.setUp(self)
238
    self.magic = time.ctime() + " ganeti test"
239
    self.fname = self._CreateTempFile()
240
    self.fifo_tmpdir = tempfile.mkdtemp()
241
    self.fifo_file = os.path.join(self.fifo_tmpdir, "ganeti_test_fifo")
242
    os.mkfifo(self.fifo_file)
243

    
244
  def tearDown(self):
245
    shutil.rmtree(self.fifo_tmpdir)
246
    testutils.GanetiTestCase.tearDown(self)
247

    
248
  def testOk(self):
249
    """Test successful exit code"""
250
    result = RunCmd("/bin/sh -c 'exit 0'")
251
    self.assertEqual(result.exit_code, 0)
252
    self.assertEqual(result.output, "")
253

    
254
  def testFail(self):
255
    """Test fail exit code"""
256
    result = RunCmd("/bin/sh -c 'exit 1'")
257
    self.assertEqual(result.exit_code, 1)
258
    self.assertEqual(result.output, "")
259

    
260
  def testStdout(self):
261
    """Test standard output"""
262
    cmd = 'echo -n "%s"' % self.magic
263
    result = RunCmd("/bin/sh -c '%s'" % cmd)
264
    self.assertEqual(result.stdout, self.magic)
265
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
266
    self.assertEqual(result.output, "")
267
    self.assertFileContent(self.fname, self.magic)
268

    
269
  def testStderr(self):
270
    """Test standard error"""
271
    cmd = 'echo -n "%s"' % self.magic
272
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
273
    self.assertEqual(result.stderr, self.magic)
274
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
275
    self.assertEqual(result.output, "")
276
    self.assertFileContent(self.fname, self.magic)
277

    
278
  def testCombined(self):
279
    """Test combined output"""
280
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
281
    expected = "A" + self.magic + "B" + self.magic
282
    result = RunCmd("/bin/sh -c '%s'" % cmd)
283
    self.assertEqual(result.output, expected)
284
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
285
    self.assertEqual(result.output, "")
286
    self.assertFileContent(self.fname, expected)
287

    
288
  def testSignal(self):
289
    """Test signal"""
290
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
291
    self.assertEqual(result.signal, 15)
292
    self.assertEqual(result.output, "")
293

    
294
  def testTimeoutClean(self):
295
    cmd = "trap 'exit 0' TERM; read < %s" % self.fifo_file
296
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
297
    self.assertEqual(result.exit_code, 0)
298

    
299
  def testTimeoutKill(self):
300
    cmd = ["/bin/sh", "-c", "trap '' TERM; read < %s" % self.fifo_file]
301
    timeout = 0.2
302
    out, err, status, ta = utils._RunCmdPipe(cmd, {}, False, "/", False,
303
                                             timeout, _linger_timeout=0.2)
304
    self.assert_(status < 0)
305
    self.assertEqual(-status, signal.SIGKILL)
306

    
307
  def testTimeoutOutputAfterTerm(self):
308
    cmd = "trap 'echo sigtermed; exit 1' TERM; read < %s" % self.fifo_file
309
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
310
    self.assert_(result.failed)
311
    self.assertEqual(result.stdout, "sigtermed\n")
312

    
313
  def testListRun(self):
314
    """Test list runs"""
315
    result = RunCmd(["true"])
316
    self.assertEqual(result.signal, None)
317
    self.assertEqual(result.exit_code, 0)
318
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
319
    self.assertEqual(result.signal, None)
320
    self.assertEqual(result.exit_code, 1)
321
    result = RunCmd(["echo", "-n", self.magic])
322
    self.assertEqual(result.signal, None)
323
    self.assertEqual(result.exit_code, 0)
324
    self.assertEqual(result.stdout, self.magic)
325

    
326
  def testFileEmptyOutput(self):
327
    """Test file output"""
328
    result = RunCmd(["true"], output=self.fname)
329
    self.assertEqual(result.signal, None)
330
    self.assertEqual(result.exit_code, 0)
331
    self.assertFileContent(self.fname, "")
332

    
333
  def testLang(self):
334
    """Test locale environment"""
335
    old_env = os.environ.copy()
336
    try:
337
      os.environ["LANG"] = "en_US.UTF-8"
338
      os.environ["LC_ALL"] = "en_US.UTF-8"
339
      result = RunCmd(["locale"])
340
      for line in result.output.splitlines():
341
        key, value = line.split("=", 1)
342
        # Ignore these variables, they're overridden by LC_ALL
343
        if key == "LANG" or key == "LANGUAGE":
344
          continue
345
        self.failIf(value and value != "C" and value != '"C"',
346
            "Variable %s is set to the invalid value '%s'" % (key, value))
347
    finally:
348
      os.environ = old_env
349

    
350
  def testDefaultCwd(self):
351
    """Test default working directory"""
352
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
353

    
354
  def testCwd(self):
355
    """Test default working directory"""
356
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
357
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
358
    cwd = os.getcwd()
359
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
360

    
361
  def testResetEnv(self):
362
    """Test environment reset functionality"""
363
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
364
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
365
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
366

    
367
  def testNoFork(self):
368
    """Test that nofork raise an error"""
369
    assert not utils.no_fork
370
    utils.no_fork = True
371
    try:
372
      self.assertRaises(errors.ProgrammerError, RunCmd, ["true"])
373
    finally:
374
      utils.no_fork = False
375

    
376
  def testWrongParams(self):
377
    """Test wrong parameters"""
378
    self.assertRaises(errors.ProgrammerError, RunCmd, ["true"],
379
                      output="/dev/null", interactive=True)
380

    
381

    
382
class TestRunParts(testutils.GanetiTestCase):
383
  """Testing case for the RunParts function"""
384

    
385
  def setUp(self):
386
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
387

    
388
  def tearDown(self):
389
    shutil.rmtree(self.rundir)
390

    
391
  def testEmpty(self):
392
    """Test on an empty dir"""
393
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
394

    
395
  def testSkipWrongName(self):
396
    """Test that wrong files are skipped"""
397
    fname = os.path.join(self.rundir, "00test.dot")
398
    utils.WriteFile(fname, data="")
399
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
400
    relname = os.path.basename(fname)
401
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
402
                         [(relname, constants.RUNPARTS_SKIP, None)])
403

    
404
  def testSkipNonExec(self):
405
    """Test that non executable files are skipped"""
406
    fname = os.path.join(self.rundir, "00test")
407
    utils.WriteFile(fname, data="")
408
    relname = os.path.basename(fname)
409
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
410
                         [(relname, constants.RUNPARTS_SKIP, None)])
411

    
412
  def testError(self):
413
    """Test error on a broken executable"""
414
    fname = os.path.join(self.rundir, "00test")
415
    utils.WriteFile(fname, data="")
416
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
417
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
418
    self.failUnlessEqual(relname, os.path.basename(fname))
419
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
420
    self.failUnless(error)
421

    
422
  def testSorted(self):
423
    """Test executions are sorted"""
424
    files = []
425
    files.append(os.path.join(self.rundir, "64test"))
426
    files.append(os.path.join(self.rundir, "00test"))
427
    files.append(os.path.join(self.rundir, "42test"))
428

    
429
    for fname in files:
430
      utils.WriteFile(fname, data="")
431

    
432
    results = RunParts(self.rundir, reset_env=True)
433

    
434
    for fname in sorted(files):
435
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
436

    
437
  def testOk(self):
438
    """Test correct execution"""
439
    fname = os.path.join(self.rundir, "00test")
440
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
441
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
442
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
443
    self.failUnlessEqual(relname, os.path.basename(fname))
444
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
445
    self.failUnlessEqual(runresult.stdout, "ciao")
446

    
447
  def testRunFail(self):
448
    """Test correct execution, with run failure"""
449
    fname = os.path.join(self.rundir, "00test")
450
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
451
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
452
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
453
    self.failUnlessEqual(relname, os.path.basename(fname))
454
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
455
    self.failUnlessEqual(runresult.exit_code, 1)
456
    self.failUnless(runresult.failed)
457

    
458
  def testRunMix(self):
459
    files = []
460
    files.append(os.path.join(self.rundir, "00test"))
461
    files.append(os.path.join(self.rundir, "42test"))
462
    files.append(os.path.join(self.rundir, "64test"))
463
    files.append(os.path.join(self.rundir, "99test"))
464

    
465
    files.sort()
466

    
467
    # 1st has errors in execution
468
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
469
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
470

    
471
    # 2nd is skipped
472
    utils.WriteFile(files[1], data="")
473

    
474
    # 3rd cannot execute properly
475
    utils.WriteFile(files[2], data="")
476
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
477

    
478
    # 4th execs
479
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
480
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
481

    
482
    results = RunParts(self.rundir, reset_env=True)
483

    
484
    (relname, status, runresult) = results[0]
485
    self.failUnlessEqual(relname, os.path.basename(files[0]))
486
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
487
    self.failUnlessEqual(runresult.exit_code, 1)
488
    self.failUnless(runresult.failed)
489

    
490
    (relname, status, runresult) = results[1]
491
    self.failUnlessEqual(relname, os.path.basename(files[1]))
492
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
493
    self.failUnlessEqual(runresult, None)
494

    
495
    (relname, status, runresult) = results[2]
496
    self.failUnlessEqual(relname, os.path.basename(files[2]))
497
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
498
    self.failUnless(runresult)
499

    
500
    (relname, status, runresult) = results[3]
501
    self.failUnlessEqual(relname, os.path.basename(files[3]))
502
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
503
    self.failUnlessEqual(runresult.output, "ciao")
504
    self.failUnlessEqual(runresult.exit_code, 0)
505
    self.failUnless(not runresult.failed)
506

    
507
  def testMissingDirectory(self):
508
    nosuchdir = utils.PathJoin(self.rundir, "no/such/directory")
509
    self.assertEqual(RunParts(nosuchdir), [])
510

    
511

    
512
class TestStartDaemon(testutils.GanetiTestCase):
513
  def setUp(self):
514
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
515
    self.tmpfile = os.path.join(self.tmpdir, "test")
516

    
517
  def tearDown(self):
518
    shutil.rmtree(self.tmpdir)
519

    
520
  def testShell(self):
521
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
522
    self._wait(self.tmpfile, 60.0, "Hello World")
523

    
524
  def testShellOutput(self):
525
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
526
    self._wait(self.tmpfile, 60.0, "Hello World")
527

    
528
  def testNoShellNoOutput(self):
529
    utils.StartDaemon(["pwd"])
530

    
531
  def testNoShellNoOutputTouch(self):
532
    testfile = os.path.join(self.tmpdir, "check")
533
    self.failIf(os.path.exists(testfile))
534
    utils.StartDaemon(["touch", testfile])
535
    self._wait(testfile, 60.0, "")
536

    
537
  def testNoShellOutput(self):
538
    utils.StartDaemon(["pwd"], output=self.tmpfile)
539
    self._wait(self.tmpfile, 60.0, "/")
540

    
541
  def testNoShellOutputCwd(self):
542
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
543
    self._wait(self.tmpfile, 60.0, os.getcwd())
544

    
545
  def testShellEnv(self):
546
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
547
                      env={ "GNT_TEST_VAR": "Hello World", })
548
    self._wait(self.tmpfile, 60.0, "Hello World")
549

    
550
  def testNoShellEnv(self):
551
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
552
                      env={ "GNT_TEST_VAR": "Hello World", })
553
    self._wait(self.tmpfile, 60.0, "Hello World")
554

    
555
  def testOutputFd(self):
556
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
557
    try:
558
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
559
    finally:
560
      os.close(fd)
561
    self._wait(self.tmpfile, 60.0, os.getcwd())
562

    
563
  def testPid(self):
564
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
565
    self._wait(self.tmpfile, 60.0, str(pid))
566

    
567
  def testPidFile(self):
568
    pidfile = os.path.join(self.tmpdir, "pid")
569
    checkfile = os.path.join(self.tmpdir, "abort")
570

    
571
    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
572
                            output=self.tmpfile)
573
    try:
574
      fd = os.open(pidfile, os.O_RDONLY)
575
      try:
576
        # Check file is locked
577
        self.assertRaises(errors.LockError, utils.LockFile, fd)
578

    
579
        pidtext = os.read(fd, 100)
580
      finally:
581
        os.close(fd)
582

    
583
      self.assertEqual(int(pidtext.strip()), pid)
584

    
585
      self.assert_(utils.IsProcessAlive(pid))
586
    finally:
587
      # No matter what happens, kill daemon
588
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
589
      self.failIf(utils.IsProcessAlive(pid))
590

    
591
    self.assertEqual(utils.ReadFile(self.tmpfile), "")
592

    
593
  def _wait(self, path, timeout, expected):
594
    # Due to the asynchronous nature of daemon processes, polling is necessary.
595
    # A timeout makes sure the test doesn't hang forever.
596
    def _CheckFile():
597
      if not (os.path.isfile(path) and
598
              utils.ReadFile(path).strip() == expected):
599
        raise utils.RetryAgain()
600

    
601
    try:
602
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
603
    except utils.RetryTimeout:
604
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
605
                " didn't write the correct output" % timeout)
606

    
607
  def testError(self):
608
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
609
                      ["./does-NOT-EXIST/here/0123456789"])
610
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
611
                      ["./does-NOT-EXIST/here/0123456789"],
612
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
613
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
614
                      ["./does-NOT-EXIST/here/0123456789"],
615
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
616
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
617
                      ["./does-NOT-EXIST/here/0123456789"],
618
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
619

    
620
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
621
    try:
622
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
623
                        ["./does-NOT-EXIST/here/0123456789"],
624
                        output=self.tmpfile, output_fd=fd)
625
    finally:
626
      os.close(fd)
627

    
628

    
629
class TestSetCloseOnExecFlag(unittest.TestCase):
630
  """Tests for SetCloseOnExecFlag"""
631

    
632
  def setUp(self):
633
    self.tmpfile = tempfile.TemporaryFile()
634

    
635
  def testEnable(self):
636
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
637
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
638
                    fcntl.FD_CLOEXEC)
639

    
640
  def testDisable(self):
641
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
642
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
643
                fcntl.FD_CLOEXEC)
644

    
645

    
646
class TestSetNonblockFlag(unittest.TestCase):
647
  def setUp(self):
648
    self.tmpfile = tempfile.TemporaryFile()
649

    
650
  def testEnable(self):
651
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
652
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
653
                    os.O_NONBLOCK)
654

    
655
  def testDisable(self):
656
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
657
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
658
                os.O_NONBLOCK)
659

    
660

    
661
class TestRemoveFile(unittest.TestCase):
662
  """Test case for the RemoveFile function"""
663

    
664
  def setUp(self):
665
    """Create a temp dir and file for each case"""
666
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
667
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
668
    os.close(fd)
669

    
670
  def tearDown(self):
671
    if os.path.exists(self.tmpfile):
672
      os.unlink(self.tmpfile)
673
    os.rmdir(self.tmpdir)
674

    
675
  def testIgnoreDirs(self):
676
    """Test that RemoveFile() ignores directories"""
677
    self.assertEqual(None, RemoveFile(self.tmpdir))
678

    
679
  def testIgnoreNotExisting(self):
680
    """Test that RemoveFile() ignores non-existing files"""
681
    RemoveFile(self.tmpfile)
682
    RemoveFile(self.tmpfile)
683

    
684
  def testRemoveFile(self):
685
    """Test that RemoveFile does remove a file"""
686
    RemoveFile(self.tmpfile)
687
    if os.path.exists(self.tmpfile):
688
      self.fail("File '%s' not removed" % self.tmpfile)
689

    
690
  def testRemoveSymlink(self):
691
    """Test that RemoveFile does remove symlinks"""
692
    symlink = self.tmpdir + "/symlink"
693
    os.symlink("no-such-file", symlink)
694
    RemoveFile(symlink)
695
    if os.path.exists(symlink):
696
      self.fail("File '%s' not removed" % symlink)
697
    os.symlink(self.tmpfile, symlink)
698
    RemoveFile(symlink)
699
    if os.path.exists(symlink):
700
      self.fail("File '%s' not removed" % symlink)
701

    
702

    
703
class TestRemoveDir(unittest.TestCase):
704
  def setUp(self):
705
    self.tmpdir = tempfile.mkdtemp()
706

    
707
  def tearDown(self):
708
    try:
709
      shutil.rmtree(self.tmpdir)
710
    except EnvironmentError:
711
      pass
712

    
713
  def testEmptyDir(self):
714
    utils.RemoveDir(self.tmpdir)
715
    self.assertFalse(os.path.isdir(self.tmpdir))
716

    
717
  def testNonEmptyDir(self):
718
    self.tmpfile = os.path.join(self.tmpdir, "test1")
719
    open(self.tmpfile, "w").close()
720
    self.assertRaises(EnvironmentError, utils.RemoveDir, self.tmpdir)
721

    
722

    
723
class TestRename(unittest.TestCase):
724
  """Test case for RenameFile"""
725

    
726
  def setUp(self):
727
    """Create a temporary directory"""
728
    self.tmpdir = tempfile.mkdtemp()
729
    self.tmpfile = os.path.join(self.tmpdir, "test1")
730

    
731
    # Touch the file
732
    open(self.tmpfile, "w").close()
733

    
734
  def tearDown(self):
735
    """Remove temporary directory"""
736
    shutil.rmtree(self.tmpdir)
737

    
738
  def testSimpleRename1(self):
739
    """Simple rename 1"""
740
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
741
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
742

    
743
  def testSimpleRename2(self):
744
    """Simple rename 2"""
745
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
746
                     mkdir=True)
747
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
748

    
749
  def testRenameMkdir(self):
750
    """Rename with mkdir"""
751
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
752
                     mkdir=True)
753
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
754
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
755

    
756
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
757
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
758
                     mkdir=True)
759
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
760
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
761
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
762

    
763

    
764
class TestReadFile(testutils.GanetiTestCase):
765

    
766
  def testReadAll(self):
767
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
768
    self.assertEqual(len(data), 814)
769

    
770
    h = compat.md5_hash()
771
    h.update(data)
772
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
773

    
774
  def testReadSize(self):
775
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
776
                          size=100)
777
    self.assertEqual(len(data), 100)
778

    
779
    h = compat.md5_hash()
780
    h.update(data)
781
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
782

    
783
  def testError(self):
784
    self.assertRaises(EnvironmentError, utils.ReadFile,
785
                      "/dev/null/does-not-exist")
786

    
787

    
788
class TestReadOneLineFile(testutils.GanetiTestCase):
789

    
790
  def setUp(self):
791
    testutils.GanetiTestCase.setUp(self)
792

    
793
  def testDefault(self):
794
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
795
    self.assertEqual(len(data), 27)
796
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
797

    
798
  def testNotStrict(self):
799
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
800
    self.assertEqual(len(data), 27)
801
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
802

    
803
  def testStrictFailure(self):
804
    self.assertRaises(errors.GenericError, ReadOneLineFile,
805
                      self._TestDataFilename("cert1.pem"), strict=True)
806

    
807
  def testLongLine(self):
808
    dummydata = (1024 * "Hello World! ")
809
    myfile = self._CreateTempFile()
810
    utils.WriteFile(myfile, data=dummydata)
811
    datastrict = ReadOneLineFile(myfile, strict=True)
812
    datalax = ReadOneLineFile(myfile, strict=False)
813
    self.assertEqual(dummydata, datastrict)
814
    self.assertEqual(dummydata, datalax)
815

    
816
  def testNewline(self):
817
    myfile = self._CreateTempFile()
818
    myline = "myline"
819
    for nl in ["", "\n", "\r\n"]:
820
      dummydata = "%s%s" % (myline, nl)
821
      utils.WriteFile(myfile, data=dummydata)
822
      datalax = ReadOneLineFile(myfile, strict=False)
823
      self.assertEqual(myline, datalax)
824
      datastrict = ReadOneLineFile(myfile, strict=True)
825
      self.assertEqual(myline, datastrict)
826

    
827
  def testWhitespaceAndMultipleLines(self):
828
    myfile = self._CreateTempFile()
829
    for nl in ["", "\n", "\r\n"]:
830
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
831
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
832
        utils.WriteFile(myfile, data=dummydata)
833
        datalax = ReadOneLineFile(myfile, strict=False)
834
        if nl:
835
          self.assert_(set("\r\n") & set(dummydata))
836
          self.assertRaises(errors.GenericError, ReadOneLineFile,
837
                            myfile, strict=True)
838
          explen = len("Foo bar baz ") + len(ws)
839
          self.assertEqual(len(datalax), explen)
840
          self.assertEqual(datalax, dummydata[:explen])
841
          self.assertFalse(set("\r\n") & set(datalax))
842
        else:
843
          datastrict = ReadOneLineFile(myfile, strict=True)
844
          self.assertEqual(dummydata, datastrict)
845
          self.assertEqual(dummydata, datalax)
846

    
847
  def testEmptylines(self):
848
    myfile = self._CreateTempFile()
849
    myline = "myline"
850
    for nl in ["\n", "\r\n"]:
851
      for ol in ["", "otherline"]:
852
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
853
        utils.WriteFile(myfile, data=dummydata)
854
        self.assert_(set("\r\n") & set(dummydata))
855
        datalax = ReadOneLineFile(myfile, strict=False)
856
        self.assertEqual(myline, datalax)
857
        if ol:
858
          self.assertRaises(errors.GenericError, ReadOneLineFile,
859
                            myfile, strict=True)
860
        else:
861
          datastrict = ReadOneLineFile(myfile, strict=True)
862
          self.assertEqual(myline, datastrict)
863

    
864
  def testEmptyfile(self):
865
    myfile = self._CreateTempFile()
866
    self.assertRaises(errors.GenericError, ReadOneLineFile, myfile)
867

    
868

    
869
class TestTimestampForFilename(unittest.TestCase):
870
  def test(self):
871
    self.assert_("." not in utils.TimestampForFilename())
872
    self.assert_(":" not in utils.TimestampForFilename())
873

    
874

    
875
class TestCreateBackup(testutils.GanetiTestCase):
876
  def setUp(self):
877
    testutils.GanetiTestCase.setUp(self)
878

    
879
    self.tmpdir = tempfile.mkdtemp()
880

    
881
  def tearDown(self):
882
    testutils.GanetiTestCase.tearDown(self)
883

    
884
    shutil.rmtree(self.tmpdir)
885

    
886
  def testEmpty(self):
887
    filename = PathJoin(self.tmpdir, "config.data")
888
    utils.WriteFile(filename, data="")
889
    bname = utils.CreateBackup(filename)
890
    self.assertFileContent(bname, "")
891
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
892
    utils.CreateBackup(filename)
893
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
894
    utils.CreateBackup(filename)
895
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
896

    
897
    fifoname = PathJoin(self.tmpdir, "fifo")
898
    os.mkfifo(fifoname)
899
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
900

    
901
  def testContent(self):
902
    bkpcount = 0
903
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
904
      for rep in [1, 2, 10, 127]:
905
        testdata = data * rep
906

    
907
        filename = PathJoin(self.tmpdir, "test.data_")
908
        utils.WriteFile(filename, data=testdata)
909
        self.assertFileContent(filename, testdata)
910

    
911
        for _ in range(3):
912
          bname = utils.CreateBackup(filename)
913
          bkpcount += 1
914
          self.assertFileContent(bname, testdata)
915
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
916

    
917

    
918
class TestParseCpuMask(unittest.TestCase):
919
  """Test case for the ParseCpuMask function."""
920

    
921
  def testWellFormed(self):
922
    self.assertEqual(utils.ParseCpuMask(""), [])
923
    self.assertEqual(utils.ParseCpuMask("1"), [1])
924
    self.assertEqual(utils.ParseCpuMask("0-2,4,5-5"), [0,1,2,4,5])
925

    
926
  def testInvalidInput(self):
927
    for data in ["garbage", "0,", "0-1-2", "2-1", "1-a"]:
928
      self.assertRaises(errors.ParseError, utils.ParseCpuMask, data)
929

    
930

    
931
class TestSshKeys(testutils.GanetiTestCase):
932
  """Test case for the AddAuthorizedKey function"""
933

    
934
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
935
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
936
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
937

    
938
  def setUp(self):
939
    testutils.GanetiTestCase.setUp(self)
940
    self.tmpname = self._CreateTempFile()
941
    handle = open(self.tmpname, 'w')
942
    try:
943
      handle.write("%s\n" % TestSshKeys.KEY_A)
944
      handle.write("%s\n" % TestSshKeys.KEY_B)
945
    finally:
946
      handle.close()
947

    
948
  def testAddingNewKey(self):
949
    utils.AddAuthorizedKey(self.tmpname,
950
                           'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
951

    
952
    self.assertFileContent(self.tmpname,
953
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
954
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
955
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
956
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
957

    
958
  def testAddingAlmostButNotCompletelyTheSameKey(self):
959
    utils.AddAuthorizedKey(self.tmpname,
960
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
961

    
962
    self.assertFileContent(self.tmpname,
963
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
964
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
965
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
966
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
967

    
968
  def testAddingExistingKeyWithSomeMoreSpaces(self):
969
    utils.AddAuthorizedKey(self.tmpname,
970
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
971

    
972
    self.assertFileContent(self.tmpname,
973
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
974
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
975
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
976

    
977
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
978
    utils.RemoveAuthorizedKey(self.tmpname,
979
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
980

    
981
    self.assertFileContent(self.tmpname,
982
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
983
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
984

    
985
  def testRemovingNonExistingKey(self):
986
    utils.RemoveAuthorizedKey(self.tmpname,
987
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
988

    
989
    self.assertFileContent(self.tmpname,
990
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
991
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
992
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
993

    
994

    
995
class TestEtcHosts(testutils.GanetiTestCase):
996
  """Test functions modifying /etc/hosts"""
997

    
998
  def setUp(self):
999
    testutils.GanetiTestCase.setUp(self)
1000
    self.tmpname = self._CreateTempFile()
1001
    handle = open(self.tmpname, 'w')
1002
    try:
1003
      handle.write('# This is a test file for /etc/hosts\n')
1004
      handle.write('127.0.0.1\tlocalhost\n')
1005
      handle.write('192.0.2.1 router gw\n')
1006
    finally:
1007
      handle.close()
1008

    
1009
  def testSettingNewIp(self):
1010
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost.example.com',
1011
                     ['myhost'])
1012

    
1013
    self.assertFileContent(self.tmpname,
1014
      "# This is a test file for /etc/hosts\n"
1015
      "127.0.0.1\tlocalhost\n"
1016
      "192.0.2.1 router gw\n"
1017
      "198.51.100.4\tmyhost.example.com myhost\n")
1018
    self.assertFileMode(self.tmpname, 0644)
1019

    
1020
  def testSettingExistingIp(self):
1021
    SetEtcHostsEntry(self.tmpname, '192.0.2.1', 'myhost.example.com',
1022
                     ['myhost'])
1023

    
1024
    self.assertFileContent(self.tmpname,
1025
      "# This is a test file for /etc/hosts\n"
1026
      "127.0.0.1\tlocalhost\n"
1027
      "192.0.2.1\tmyhost.example.com myhost\n")
1028
    self.assertFileMode(self.tmpname, 0644)
1029

    
1030
  def testSettingDuplicateName(self):
1031
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost', ['myhost'])
1032

    
1033
    self.assertFileContent(self.tmpname,
1034
      "# This is a test file for /etc/hosts\n"
1035
      "127.0.0.1\tlocalhost\n"
1036
      "192.0.2.1 router gw\n"
1037
      "198.51.100.4\tmyhost\n")
1038
    self.assertFileMode(self.tmpname, 0644)
1039

    
1040
  def testRemovingExistingHost(self):
1041
    RemoveEtcHostsEntry(self.tmpname, 'router')
1042

    
1043
    self.assertFileContent(self.tmpname,
1044
      "# This is a test file for /etc/hosts\n"
1045
      "127.0.0.1\tlocalhost\n"
1046
      "192.0.2.1 gw\n")
1047
    self.assertFileMode(self.tmpname, 0644)
1048

    
1049
  def testRemovingSingleExistingHost(self):
1050
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
1051

    
1052
    self.assertFileContent(self.tmpname,
1053
      "# This is a test file for /etc/hosts\n"
1054
      "192.0.2.1 router gw\n")
1055
    self.assertFileMode(self.tmpname, 0644)
1056

    
1057
  def testRemovingNonExistingHost(self):
1058
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
1059

    
1060
    self.assertFileContent(self.tmpname,
1061
      "# This is a test file for /etc/hosts\n"
1062
      "127.0.0.1\tlocalhost\n"
1063
      "192.0.2.1 router gw\n")
1064
    self.assertFileMode(self.tmpname, 0644)
1065

    
1066
  def testRemovingAlias(self):
1067
    RemoveEtcHostsEntry(self.tmpname, 'gw')
1068

    
1069
    self.assertFileContent(self.tmpname,
1070
      "# This is a test file for /etc/hosts\n"
1071
      "127.0.0.1\tlocalhost\n"
1072
      "192.0.2.1 router\n")
1073
    self.assertFileMode(self.tmpname, 0644)
1074

    
1075

    
1076
class TestGetMounts(unittest.TestCase):
1077
  """Test case for GetMounts()."""
1078

    
1079
  TESTDATA = (
1080
    "rootfs /     rootfs rw 0 0\n"
1081
    "none   /sys  sysfs  rw,nosuid,nodev,noexec,relatime 0 0\n"
1082
    "none   /proc proc   rw,nosuid,nodev,noexec,relatime 0 0\n")
1083

    
1084
  def setUp(self):
1085
    self.tmpfile = tempfile.NamedTemporaryFile()
1086
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1087

    
1088
  def testGetMounts(self):
1089
    self.assertEqual(utils.GetMounts(filename=self.tmpfile.name),
1090
      [
1091
        ("rootfs", "/", "rootfs", "rw"),
1092
        ("none", "/sys", "sysfs", "rw,nosuid,nodev,noexec,relatime"),
1093
        ("none", "/proc", "proc", "rw,nosuid,nodev,noexec,relatime"),
1094
      ])
1095

    
1096

    
1097
class TestListVisibleFiles(unittest.TestCase):
1098
  """Test case for ListVisibleFiles"""
1099

    
1100
  def setUp(self):
1101
    self.path = tempfile.mkdtemp()
1102

    
1103
  def tearDown(self):
1104
    shutil.rmtree(self.path)
1105

    
1106
  def _CreateFiles(self, files):
1107
    for name in files:
1108
      utils.WriteFile(os.path.join(self.path, name), data="test")
1109

    
1110
  def _test(self, files, expected):
1111
    self._CreateFiles(files)
1112
    found = ListVisibleFiles(self.path)
1113
    self.assertEqual(set(found), set(expected))
1114

    
1115
  def testAllVisible(self):
1116
    files = ["a", "b", "c"]
1117
    expected = files
1118
    self._test(files, expected)
1119

    
1120
  def testNoneVisible(self):
1121
    files = [".a", ".b", ".c"]
1122
    expected = []
1123
    self._test(files, expected)
1124

    
1125
  def testSomeVisible(self):
1126
    files = ["a", "b", ".c"]
1127
    expected = ["a", "b"]
1128
    self._test(files, expected)
1129

    
1130
  def testNonAbsolutePath(self):
1131
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1132

    
1133
  def testNonNormalizedPath(self):
1134
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1135
                          "/bin/../tmp")
1136

    
1137

    
1138
class TestNewUUID(unittest.TestCase):
1139
  """Test case for NewUUID"""
1140

    
1141
  def runTest(self):
1142
    self.failUnless(utils.UUID_RE.match(utils.NewUUID()))
1143

    
1144

    
1145
class TestFirstFree(unittest.TestCase):
1146
  """Test case for the FirstFree function"""
1147

    
1148
  def test(self):
1149
    """Test FirstFree"""
1150
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1151
    self.failUnlessEqual(FirstFree([]), None)
1152
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1153
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1154
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1155

    
1156

    
1157
class TestTailFile(testutils.GanetiTestCase):
1158
  """Test case for the TailFile function"""
1159

    
1160
  def testEmpty(self):
1161
    fname = self._CreateTempFile()
1162
    self.failUnlessEqual(TailFile(fname), [])
1163
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1164

    
1165
  def testAllLines(self):
1166
    data = ["test %d" % i for i in range(30)]
1167
    for i in range(30):
1168
      fname = self._CreateTempFile()
1169
      fd = open(fname, "w")
1170
      fd.write("\n".join(data[:i]))
1171
      if i > 0:
1172
        fd.write("\n")
1173
      fd.close()
1174
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1175

    
1176
  def testPartialLines(self):
1177
    data = ["test %d" % i for i in range(30)]
1178
    fname = self._CreateTempFile()
1179
    fd = open(fname, "w")
1180
    fd.write("\n".join(data))
1181
    fd.write("\n")
1182
    fd.close()
1183
    for i in range(1, 30):
1184
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1185

    
1186
  def testBigFile(self):
1187
    data = ["test %d" % i for i in range(30)]
1188
    fname = self._CreateTempFile()
1189
    fd = open(fname, "w")
1190
    fd.write("X" * 1048576)
1191
    fd.write("\n")
1192
    fd.write("\n".join(data))
1193
    fd.write("\n")
1194
    fd.close()
1195
    for i in range(1, 30):
1196
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1197

    
1198

    
1199
class _BaseFileLockTest:
1200
  """Test case for the FileLock class"""
1201

    
1202
  def testSharedNonblocking(self):
1203
    self.lock.Shared(blocking=False)
1204
    self.lock.Close()
1205

    
1206
  def testExclusiveNonblocking(self):
1207
    self.lock.Exclusive(blocking=False)
1208
    self.lock.Close()
1209

    
1210
  def testUnlockNonblocking(self):
1211
    self.lock.Unlock(blocking=False)
1212
    self.lock.Close()
1213

    
1214
  def testSharedBlocking(self):
1215
    self.lock.Shared(blocking=True)
1216
    self.lock.Close()
1217

    
1218
  def testExclusiveBlocking(self):
1219
    self.lock.Exclusive(blocking=True)
1220
    self.lock.Close()
1221

    
1222
  def testUnlockBlocking(self):
1223
    self.lock.Unlock(blocking=True)
1224
    self.lock.Close()
1225

    
1226
  def testSharedExclusiveUnlock(self):
1227
    self.lock.Shared(blocking=False)
1228
    self.lock.Exclusive(blocking=False)
1229
    self.lock.Unlock(blocking=False)
1230
    self.lock.Close()
1231

    
1232
  def testExclusiveSharedUnlock(self):
1233
    self.lock.Exclusive(blocking=False)
1234
    self.lock.Shared(blocking=False)
1235
    self.lock.Unlock(blocking=False)
1236
    self.lock.Close()
1237

    
1238
  def testSimpleTimeout(self):
1239
    # These will succeed on the first attempt, hence a short timeout
1240
    self.lock.Shared(blocking=True, timeout=10.0)
1241
    self.lock.Exclusive(blocking=False, timeout=10.0)
1242
    self.lock.Unlock(blocking=True, timeout=10.0)
1243
    self.lock.Close()
1244

    
1245
  @staticmethod
1246
  def _TryLockInner(filename, shared, blocking):
1247
    lock = utils.FileLock.Open(filename)
1248

    
1249
    if shared:
1250
      fn = lock.Shared
1251
    else:
1252
      fn = lock.Exclusive
1253

    
1254
    try:
1255
      # The timeout doesn't really matter as the parent process waits for us to
1256
      # finish anyway.
1257
      fn(blocking=blocking, timeout=0.01)
1258
    except errors.LockError, err:
1259
      return False
1260

    
1261
    return True
1262

    
1263
  def _TryLock(self, *args):
1264
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1265
                                      *args)
1266

    
1267
  def testTimeout(self):
1268
    for blocking in [True, False]:
1269
      self.lock.Exclusive(blocking=True)
1270
      self.failIf(self._TryLock(False, blocking))
1271
      self.failIf(self._TryLock(True, blocking))
1272

    
1273
      self.lock.Shared(blocking=True)
1274
      self.assert_(self._TryLock(True, blocking))
1275
      self.failIf(self._TryLock(False, blocking))
1276

    
1277
  def testCloseShared(self):
1278
    self.lock.Close()
1279
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1280

    
1281
  def testCloseExclusive(self):
1282
    self.lock.Close()
1283
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1284

    
1285
  def testCloseUnlock(self):
1286
    self.lock.Close()
1287
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1288

    
1289

    
1290
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1291
  TESTDATA = "Hello World\n" * 10
1292

    
1293
  def setUp(self):
1294
    testutils.GanetiTestCase.setUp(self)
1295

    
1296
    self.tmpfile = tempfile.NamedTemporaryFile()
1297
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1298
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1299

    
1300
    # Ensure "Open" didn't truncate file
1301
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1302

    
1303
  def tearDown(self):
1304
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1305

    
1306
    testutils.GanetiTestCase.tearDown(self)
1307

    
1308

    
1309
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1310
  def setUp(self):
1311
    self.tmpfile = tempfile.NamedTemporaryFile()
1312
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1313

    
1314

    
1315
class TestTimeFunctions(unittest.TestCase):
1316
  """Test case for time functions"""
1317

    
1318
  def runTest(self):
1319
    self.assertEqual(utils.SplitTime(1), (1, 0))
1320
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1321
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1322
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1323
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1324
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1325
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1326
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1327

    
1328
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1329

    
1330
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1331
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1332
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1333

    
1334
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1335
                     1218448917.481)
1336
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1337

    
1338
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1339
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1340
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1341
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1342
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1343

    
1344

    
1345
class FieldSetTestCase(unittest.TestCase):
1346
  """Test case for FieldSets"""
1347

    
1348
  def testSimpleMatch(self):
1349
    f = utils.FieldSet("a", "b", "c", "def")
1350
    self.failUnless(f.Matches("a"))
1351
    self.failIf(f.Matches("d"), "Substring matched")
1352
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1353
    self.failIf(f.NonMatching(["b", "c"]))
1354
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1355
    self.failUnless(f.NonMatching(["a", "d"]))
1356

    
1357
  def testRegexMatch(self):
1358
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1359
    self.failUnless(f.Matches("b1"))
1360
    self.failUnless(f.Matches("b99"))
1361
    self.failIf(f.Matches("b/1"))
1362
    self.failIf(f.NonMatching(["b12", "c"]))
1363
    self.failUnless(f.NonMatching(["a", "1"]))
1364

    
1365
class TestForceDictType(unittest.TestCase):
1366
  """Test case for ForceDictType"""
1367
  KEY_TYPES = {
1368
    "a": constants.VTYPE_INT,
1369
    "b": constants.VTYPE_BOOL,
1370
    "c": constants.VTYPE_STRING,
1371
    "d": constants.VTYPE_SIZE,
1372
    "e": constants.VTYPE_MAYBE_STRING,
1373
    }
1374

    
1375
  def _fdt(self, dict, allowed_values=None):
1376
    if allowed_values is None:
1377
      utils.ForceDictType(dict, self.KEY_TYPES)
1378
    else:
1379
      utils.ForceDictType(dict, self.KEY_TYPES, allowed_values=allowed_values)
1380

    
1381
    return dict
1382

    
1383
  def testSimpleDict(self):
1384
    self.assertEqual(self._fdt({}), {})
1385
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1386
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1387
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1388
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1389
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1390
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1391
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1392
    self.assertEqual(self._fdt({'b': False}), {'b': False})
1393
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1394
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1395
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1396
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1397
    self.assertEqual(self._fdt({"e": None, }), {"e": None, })
1398
    self.assertEqual(self._fdt({"e": "Hello World", }), {"e": "Hello World", })
1399
    self.assertEqual(self._fdt({"e": False, }), {"e": '', })
1400
    self.assertEqual(self._fdt({"b": "hello", }, ["hello"]), {"b": "hello"})
1401

    
1402
  def testErrors(self):
1403
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1404
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"b": "hello"})
1405
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1406
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1407
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1408
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": object(), })
1409
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": [], })
1410
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"x": None, })
1411
    self.assertRaises(errors.TypeEnforcementError, self._fdt, [])
1412
    self.assertRaises(errors.ProgrammerError, utils.ForceDictType,
1413
                      {"b": "hello"}, {"b": "no-such-type"})
1414

    
1415

    
1416
class TestIsNormAbsPath(unittest.TestCase):
1417
  """Testing case for IsNormAbsPath"""
1418

    
1419
  def _pathTestHelper(self, path, result):
1420
    if result:
1421
      self.assert_(utils.IsNormAbsPath(path),
1422
          "Path %s should result absolute and normalized" % path)
1423
    else:
1424
      self.assertFalse(utils.IsNormAbsPath(path),
1425
          "Path %s should not result absolute and normalized" % path)
1426

    
1427
  def testBase(self):
1428
    self._pathTestHelper('/etc', True)
1429
    self._pathTestHelper('/srv', True)
1430
    self._pathTestHelper('etc', False)
1431
    self._pathTestHelper('/etc/../root', False)
1432
    self._pathTestHelper('/etc/', False)
1433

    
1434

    
1435
class RunInSeparateProcess(unittest.TestCase):
1436
  def test(self):
1437
    for exp in [True, False]:
1438
      def _child():
1439
        return exp
1440

    
1441
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1442

    
1443
  def testArgs(self):
1444
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1445
      def _child(carg1, carg2):
1446
        return carg1 == "Foo" and carg2 == arg
1447

    
1448
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1449

    
1450
  def testPid(self):
1451
    parent_pid = os.getpid()
1452

    
1453
    def _check():
1454
      return os.getpid() == parent_pid
1455

    
1456
    self.failIf(utils.RunInSeparateProcess(_check))
1457

    
1458
  def testSignal(self):
1459
    def _kill():
1460
      os.kill(os.getpid(), signal.SIGTERM)
1461

    
1462
    self.assertRaises(errors.GenericError,
1463
                      utils.RunInSeparateProcess, _kill)
1464

    
1465
  def testException(self):
1466
    def _exc():
1467
      raise errors.GenericError("This is a test")
1468

    
1469
    self.assertRaises(errors.GenericError,
1470
                      utils.RunInSeparateProcess, _exc)
1471

    
1472

    
1473
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1474
  def setUp(self):
1475
    self.tmpdir = tempfile.mkdtemp()
1476

    
1477
  def tearDown(self):
1478
    shutil.rmtree(self.tmpdir)
1479

    
1480
  def _checkRsaPrivateKey(self, key):
1481
    lines = key.splitlines()
1482
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1483
            "-----END RSA PRIVATE KEY-----" in lines)
1484

    
1485
  def _checkCertificate(self, cert):
1486
    lines = cert.splitlines()
1487
    return ("-----BEGIN CERTIFICATE-----" in lines and
1488
            "-----END CERTIFICATE-----" in lines)
1489

    
1490
  def test(self):
1491
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1492
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1493
      self._checkRsaPrivateKey(key_pem)
1494
      self._checkCertificate(cert_pem)
1495

    
1496
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1497
                                           key_pem)
1498
      self.assert_(key.bits() >= 1024)
1499
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1500
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1501

    
1502
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1503
                                             cert_pem)
1504
      self.failIf(x509.has_expired())
1505
      self.assertEqual(x509.get_issuer().CN, common_name)
1506
      self.assertEqual(x509.get_subject().CN, common_name)
1507
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1508

    
1509
  def testLegacy(self):
1510
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1511

    
1512
    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1513

    
1514
    cert1 = utils.ReadFile(cert1_filename)
1515

    
1516
    self.assert_(self._checkRsaPrivateKey(cert1))
1517
    self.assert_(self._checkCertificate(cert1))
1518

    
1519

    
1520
class TestPathJoin(unittest.TestCase):
1521
  """Testing case for PathJoin"""
1522

    
1523
  def testBasicItems(self):
1524
    mlist = ["/a", "b", "c"]
1525
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1526

    
1527
  def testNonAbsPrefix(self):
1528
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1529

    
1530
  def testBackTrack(self):
1531
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1532

    
1533
  def testMultiAbs(self):
1534
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1535

    
1536

    
1537
class TestValidateServiceName(unittest.TestCase):
1538
  def testValid(self):
1539
    testnames = [
1540
      0, 1, 2, 3, 1024, 65000, 65534, 65535,
1541
      "ganeti",
1542
      "gnt-masterd",
1543
      "HELLO_WORLD_SVC",
1544
      "hello.world.1",
1545
      "0", "80", "1111", "65535",
1546
      ]
1547

    
1548
    for name in testnames:
1549
      self.assertEqual(utils.ValidateServiceName(name), name)
1550

    
1551
  def testInvalid(self):
1552
    testnames = [
1553
      -15756, -1, 65536, 133428083,
1554
      "", "Hello World!", "!", "'", "\"", "\t", "\n", "`",
1555
      "-8546", "-1", "65536",
1556
      (129 * "A"),
1557
      ]
1558

    
1559
    for name in testnames:
1560
      self.assertRaises(errors.OpPrereqError, utils.ValidateServiceName, name)
1561

    
1562

    
1563
class TestParseAsn1Generalizedtime(unittest.TestCase):
1564
  def test(self):
1565
    # UTC
1566
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1567
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1568
                     1266860512)
1569
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1570
                     (2**31) - 1)
1571

    
1572
    # With offset
1573
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1574
                     1266860512)
1575
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1576
                     1266931012)
1577
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1578
                     1266931088)
1579
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1580
                     1266931295)
1581
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1582
                     3600)
1583

    
1584
    # Leap seconds are not supported by datetime.datetime
1585
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1586
                      "19841231235960+0000")
1587
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1588
                      "19920630235960+0000")
1589

    
1590
    # Errors
1591
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1592
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1593
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1594
                      "20100222174152")
1595
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1596
                      "Mon Feb 22 17:47:02 UTC 2010")
1597
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1598
                      "2010-02-22 17:42:02")
1599

    
1600

    
1601
class TestGetX509CertValidity(testutils.GanetiTestCase):
1602
  def setUp(self):
1603
    testutils.GanetiTestCase.setUp(self)
1604

    
1605
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1606

    
1607
    # Test whether we have pyOpenSSL 0.7 or above
1608
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1609

    
1610
    if not self.pyopenssl0_7:
1611
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1612
                    " function correctly")
1613

    
1614
  def _LoadCert(self, name):
1615
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1616
                                           self._ReadTestData(name))
1617

    
1618
  def test(self):
1619
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1620
    if self.pyopenssl0_7:
1621
      self.assertEqual(validity, (1266919967, 1267524767))
1622
    else:
1623
      self.assertEqual(validity, (None, None))
1624

    
1625

    
1626
class TestSignX509Certificate(unittest.TestCase):
1627
  KEY = "My private key!"
1628
  KEY_OTHER = "Another key"
1629

    
1630
  def test(self):
1631
    # Generate certificate valid for 5 minutes
1632
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
1633

    
1634
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1635
                                           cert_pem)
1636

    
1637
    # No signature at all
1638
    self.assertRaises(errors.GenericError,
1639
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
1640

    
1641
    # Invalid input
1642
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1643
                      "", self.KEY)
1644
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1645
                      "X-Ganeti-Signature: \n", self.KEY)
1646
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1647
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
1648
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1649
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
1650
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1651
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
1652

    
1653
    # Invalid salt
1654
    for salt in list("-_@$,:;/\\ \t\n"):
1655
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
1656
                        cert_pem, self.KEY, "foo%sbar" % salt)
1657

    
1658
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
1659
                 utils.GenerateSecret(numbytes=4),
1660
                 utils.GenerateSecret(numbytes=16),
1661
                 "{123:456}".encode("hex")]:
1662
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
1663

    
1664
      self._Check(cert, salt, signed_pem)
1665

    
1666
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
1667
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
1668
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
1669
                               "lines----\n------ at\nthe end!"))
1670

    
1671
  def _Check(self, cert, salt, pem):
1672
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
1673
    self.assertEqual(salt, salt2)
1674
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
1675

    
1676
    # Other key
1677
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
1678
                      pem, self.KEY_OTHER)
1679

    
1680

    
1681
class TestMakedirs(unittest.TestCase):
1682
  def setUp(self):
1683
    self.tmpdir = tempfile.mkdtemp()
1684

    
1685
  def tearDown(self):
1686
    shutil.rmtree(self.tmpdir)
1687

    
1688
  def testNonExisting(self):
1689
    path = PathJoin(self.tmpdir, "foo")
1690
    utils.Makedirs(path)
1691
    self.assert_(os.path.isdir(path))
1692

    
1693
  def testExisting(self):
1694
    path = PathJoin(self.tmpdir, "foo")
1695
    os.mkdir(path)
1696
    utils.Makedirs(path)
1697
    self.assert_(os.path.isdir(path))
1698

    
1699
  def testRecursiveNonExisting(self):
1700
    path = PathJoin(self.tmpdir, "foo/bar/baz")
1701
    utils.Makedirs(path)
1702
    self.assert_(os.path.isdir(path))
1703

    
1704
  def testRecursiveExisting(self):
1705
    path = PathJoin(self.tmpdir, "B/moo/xyz")
1706
    self.assertFalse(os.path.exists(path))
1707
    os.mkdir(PathJoin(self.tmpdir, "B"))
1708
    utils.Makedirs(path)
1709
    self.assert_(os.path.isdir(path))
1710

    
1711

    
1712
class TestReadLockedPidFile(unittest.TestCase):
1713
  def setUp(self):
1714
    self.tmpdir = tempfile.mkdtemp()
1715

    
1716
  def tearDown(self):
1717
    shutil.rmtree(self.tmpdir)
1718

    
1719
  def testNonExistent(self):
1720
    path = PathJoin(self.tmpdir, "nonexist")
1721
    self.assert_(utils.ReadLockedPidFile(path) is None)
1722

    
1723
  def testUnlocked(self):
1724
    path = PathJoin(self.tmpdir, "pid")
1725
    utils.WriteFile(path, data="123")
1726
    self.assert_(utils.ReadLockedPidFile(path) is None)
1727

    
1728
  def testLocked(self):
1729
    path = PathJoin(self.tmpdir, "pid")
1730
    utils.WriteFile(path, data="123")
1731

    
1732
    fl = utils.FileLock.Open(path)
1733
    try:
1734
      fl.Exclusive(blocking=True)
1735

    
1736
      self.assertEqual(utils.ReadLockedPidFile(path), 123)
1737
    finally:
1738
      fl.Close()
1739

    
1740
    self.assert_(utils.ReadLockedPidFile(path) is None)
1741

    
1742
  def testError(self):
1743
    path = PathJoin(self.tmpdir, "foobar", "pid")
1744
    utils.WriteFile(PathJoin(self.tmpdir, "foobar"), data="")
1745
    # open(2) should return ENOTDIR
1746
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
1747

    
1748

    
1749
class TestCertVerification(testutils.GanetiTestCase):
1750
  def setUp(self):
1751
    testutils.GanetiTestCase.setUp(self)
1752

    
1753
    self.tmpdir = tempfile.mkdtemp()
1754

    
1755
  def tearDown(self):
1756
    shutil.rmtree(self.tmpdir)
1757

    
1758
  def testVerifyCertificate(self):
1759
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
1760
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1761
                                           cert_pem)
1762

    
1763
    # Not checking return value as this certificate is expired
1764
    utils.VerifyX509Certificate(cert, 30, 7)
1765

    
1766

    
1767
class TestVerifyCertificateInner(unittest.TestCase):
1768
  def test(self):
1769
    vci = utils._VerifyCertificateInner
1770

    
1771
    # Valid
1772
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
1773
                     (None, None))
1774

    
1775
    # Not yet valid
1776
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
1777
    self.assertEqual(errcode, utils.CERT_WARNING)
1778

    
1779
    # Expiring soon
1780
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
1781
    self.assertEqual(errcode, utils.CERT_ERROR)
1782

    
1783
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
1784
    self.assertEqual(errcode, utils.CERT_WARNING)
1785

    
1786
    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
1787
    self.assertEqual(errcode, None)
1788

    
1789
    # Expired
1790
    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
1791
    self.assertEqual(errcode, utils.CERT_ERROR)
1792

    
1793
    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
1794
    self.assertEqual(errcode, utils.CERT_ERROR)
1795

    
1796
    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
1797
    self.assertEqual(errcode, utils.CERT_ERROR)
1798

    
1799
    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
1800
    self.assertEqual(errcode, utils.CERT_ERROR)
1801

    
1802

    
1803
class TestIgnoreSignals(unittest.TestCase):
1804
  """Test the IgnoreSignals decorator"""
1805

    
1806
  @staticmethod
1807
  def _Raise(exception):
1808
    raise exception
1809

    
1810
  @staticmethod
1811
  def _Return(rval):
1812
    return rval
1813

    
1814
  def testIgnoreSignals(self):
1815
    sock_err_intr = socket.error(errno.EINTR, "Message")
1816
    sock_err_inval = socket.error(errno.EINVAL, "Message")
1817

    
1818
    env_err_intr = EnvironmentError(errno.EINTR, "Message")
1819
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")
1820

    
1821
    self.assertRaises(socket.error, self._Raise, sock_err_intr)
1822
    self.assertRaises(socket.error, self._Raise, sock_err_inval)
1823
    self.assertRaises(EnvironmentError, self._Raise, env_err_intr)
1824
    self.assertRaises(EnvironmentError, self._Raise, env_err_inval)
1825

    
1826
    self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None)
1827
    self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None)
1828
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
1829
                      sock_err_inval)
1830
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
1831
                      env_err_inval)
1832

    
1833
    self.assertEquals(utils.IgnoreSignals(self._Return, True), True)
1834
    self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33)
1835

    
1836

    
1837
class TestEnsureDirs(unittest.TestCase):
1838
  """Tests for EnsureDirs"""
1839

    
1840
  def setUp(self):
1841
    self.dir = tempfile.mkdtemp()
1842
    self.old_umask = os.umask(0777)
1843

    
1844
  def testEnsureDirs(self):
1845
    utils.EnsureDirs([
1846
        (PathJoin(self.dir, "foo"), 0777),
1847
        (PathJoin(self.dir, "bar"), 0000),
1848
        ])
1849
    self.assertEquals(os.stat(PathJoin(self.dir, "foo"))[0] & 0777, 0777)
1850
    self.assertEquals(os.stat(PathJoin(self.dir, "bar"))[0] & 0777, 0000)
1851

    
1852
  def tearDown(self):
1853
    os.rmdir(PathJoin(self.dir, "foo"))
1854
    os.rmdir(PathJoin(self.dir, "bar"))
1855
    os.rmdir(self.dir)
1856
    os.umask(self.old_umask)
1857

    
1858

    
1859
class TestIgnoreProcessNotFound(unittest.TestCase):
1860
  @staticmethod
1861
  def _WritePid(fd):
1862
    os.write(fd, str(os.getpid()))
1863
    os.close(fd)
1864
    return True
1865

    
1866
  def test(self):
1867
    (pid_read_fd, pid_write_fd) = os.pipe()
1868

    
1869
    # Start short-lived process which writes its PID to pipe
1870
    self.assert_(utils.RunInSeparateProcess(self._WritePid, pid_write_fd))
1871
    os.close(pid_write_fd)
1872

    
1873
    # Read PID from pipe
1874
    pid = int(os.read(pid_read_fd, 1024))
1875
    os.close(pid_read_fd)
1876

    
1877
    # Try to send signal to process which exited recently
1878
    self.assertFalse(utils.IgnoreProcessNotFound(os.kill, pid, 0))
1879

    
1880

    
1881
class TestFindMatch(unittest.TestCase):
1882
  def test(self):
1883
    data = {
1884
      "aaaa": "Four A",
1885
      "bb": {"Two B": True},
1886
      re.compile(r"^x(foo|bar|bazX)([0-9]+)$"): (1, 2, 3),
1887
      }
1888

    
1889
    self.assertEqual(utils.FindMatch(data, "aaaa"), ("Four A", []))
1890
    self.assertEqual(utils.FindMatch(data, "bb"), ({"Two B": True}, []))
1891

    
1892
    for i in ["foo", "bar", "bazX"]:
1893
      for j in range(1, 100, 7):
1894
        self.assertEqual(utils.FindMatch(data, "x%s%s" % (i, j)),
1895
                         ((1, 2, 3), [i, str(j)]))
1896

    
1897
  def testNoMatch(self):
1898
    self.assert_(utils.FindMatch({}, "") is None)
1899
    self.assert_(utils.FindMatch({}, "foo") is None)
1900
    self.assert_(utils.FindMatch({}, 1234) is None)
1901

    
1902
    data = {
1903
      "X": "Hello World",
1904
      re.compile("^(something)$"): "Hello World",
1905
      }
1906

    
1907
    self.assert_(utils.FindMatch(data, "") is None)
1908
    self.assert_(utils.FindMatch(data, "Hello World") is None)
1909

    
1910

    
1911
class TestFileID(testutils.GanetiTestCase):
1912
  def testEquality(self):
1913
    name = self._CreateTempFile()
1914
    oldi = utils.GetFileID(path=name)
1915
    self.failUnless(utils.VerifyFileID(oldi, oldi))
1916

    
1917
  def testUpdate(self):
1918
    name = self._CreateTempFile()
1919
    oldi = utils.GetFileID(path=name)
1920
    os.utime(name, None)
1921
    fd = os.open(name, os.O_RDWR)
1922
    try:
1923
      newi = utils.GetFileID(fd=fd)
1924
      self.failUnless(utils.VerifyFileID(oldi, newi))
1925
      self.failUnless(utils.VerifyFileID(newi, oldi))
1926
    finally:
1927
      os.close(fd)
1928

    
1929
  def testWriteFile(self):
1930
    name = self._CreateTempFile()
1931
    oldi = utils.GetFileID(path=name)
1932
    mtime = oldi[2]
1933
    os.utime(name, (mtime + 10, mtime + 10))
1934
    self.assertRaises(errors.LockError, utils.SafeWriteFile, name,
1935
                      oldi, data="")
1936
    os.utime(name, (mtime - 10, mtime - 10))
1937
    utils.SafeWriteFile(name, oldi, data="")
1938
    oldi = utils.GetFileID(path=name)
1939
    mtime = oldi[2]
1940
    os.utime(name, (mtime + 10, mtime + 10))
1941
    # this doesn't raise, since we passed None
1942
    utils.SafeWriteFile(name, None, data="")
1943

    
1944
  def testError(self):
1945
    t = tempfile.NamedTemporaryFile()
1946
    self.assertRaises(errors.ProgrammerError, utils.GetFileID,
1947
                      path=t.name, fd=t.fileno())
1948

    
1949

    
1950
class TimeMock:
1951
  def __init__(self, values):
1952
    self.values = values
1953

    
1954
  def __call__(self):
1955
    return self.values.pop(0)
1956

    
1957

    
1958
class TestRunningTimeout(unittest.TestCase):
1959
  def setUp(self):
1960
    self.time_fn = TimeMock([0.0, 0.3, 4.6, 6.5])
1961

    
1962
  def testRemainingFloat(self):
1963
    timeout = utils.RunningTimeout(5.0, True, _time_fn=self.time_fn)
1964
    self.assertAlmostEqual(timeout.Remaining(), 4.7)
1965
    self.assertAlmostEqual(timeout.Remaining(), 0.4)
1966
    self.assertAlmostEqual(timeout.Remaining(), -1.5)
1967

    
1968
  def testRemaining(self):
1969
    self.time_fn = TimeMock([0, 2, 4, 5, 6])
1970
    timeout = utils.RunningTimeout(5, True, _time_fn=self.time_fn)
1971
    self.assertEqual(timeout.Remaining(), 3)
1972
    self.assertEqual(timeout.Remaining(), 1)
1973
    self.assertEqual(timeout.Remaining(), 0)
1974
    self.assertEqual(timeout.Remaining(), -1)
1975

    
1976
  def testRemainingNonNegative(self):
1977
    timeout = utils.RunningTimeout(5.0, False, _time_fn=self.time_fn)
1978
    self.assertAlmostEqual(timeout.Remaining(), 4.7)
1979
    self.assertAlmostEqual(timeout.Remaining(), 0.4)
1980
    self.assertEqual(timeout.Remaining(), 0.0)
1981

    
1982
  def testNegativeTimeout(self):
1983
    self.assertRaises(ValueError, utils.RunningTimeout, -1.0, True)
1984

    
1985

    
1986
class TestTryConvert(unittest.TestCase):
1987
  def test(self):
1988
    for src, fn, result in [
1989
      ("1", int, 1),
1990
      ("a", int, "a"),
1991
      ("", bool, False),
1992
      ("a", bool, True),
1993
      ]:
1994
      self.assertEqual(utils.TryConvert(fn, src), result)
1995

    
1996

    
1997
class TestIsValidShellParam(unittest.TestCase):
1998
  def test(self):
1999
    for val, result in [
2000
      ("abc", True),
2001
      ("ab;cd", False),
2002
      ]:
2003
      self.assertEqual(utils.IsValidShellParam(val), result)
2004

    
2005

    
2006
class TestBuildShellCmd(unittest.TestCase):
2007
  def test(self):
2008
    self.assertRaises(errors.ProgrammerError, utils.BuildShellCmd,
2009
                      "ls %s", "ab;cd")
2010
    self.assertEqual(utils.BuildShellCmd("ls %s", "ab"), "ls ab")
2011

    
2012

    
2013
class TestWriteFile(unittest.TestCase):
2014
  def setUp(self):
2015
    self.tfile = tempfile.NamedTemporaryFile()
2016
    self.did_pre = False
2017
    self.did_post = False
2018
    self.did_write = False
2019

    
2020
  def markPre(self, fd):
2021
    self.did_pre = True
2022

    
2023
  def markPost(self, fd):
2024
    self.did_post = True
2025

    
2026
  def markWrite(self, fd):
2027
    self.did_write = True
2028

    
2029
  def testWrite(self):
2030
    data = "abc"
2031
    utils.WriteFile(self.tfile.name, data=data)
2032
    self.assertEqual(utils.ReadFile(self.tfile.name), data)
2033

    
2034
  def testErrors(self):
2035
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
2036
                      self.tfile.name, data="test", fn=lambda fd: None)
2037
    self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name)
2038
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
2039
                      self.tfile.name, data="test", atime=0)
2040

    
2041
  def testCalls(self):
2042
    utils.WriteFile(self.tfile.name, fn=self.markWrite,
2043
                    prewrite=self.markPre, postwrite=self.markPost)
2044
    self.assertTrue(self.did_pre)
2045
    self.assertTrue(self.did_post)
2046
    self.assertTrue(self.did_write)
2047

    
2048
  def testDryRun(self):
2049
    orig = "abc"
2050
    self.tfile.write(orig)
2051
    self.tfile.flush()
2052
    utils.WriteFile(self.tfile.name, data="hello", dry_run=True)
2053
    self.assertEqual(utils.ReadFile(self.tfile.name), orig)
2054

    
2055
  def testTimes(self):
2056
    f = self.tfile.name
2057
    for at, mt in [(0, 0), (1000, 1000), (2000, 3000),
2058
                   (int(time.time()), 5000)]:
2059
      utils.WriteFile(f, data="hello", atime=at, mtime=mt)
2060
      st = os.stat(f)
2061
      self.assertEqual(st.st_atime, at)
2062
      self.assertEqual(st.st_mtime, mt)
2063

    
2064

    
2065
  def testNoClose(self):
2066
    data = "hello"
2067
    self.assertEqual(utils.WriteFile(self.tfile.name, data="abc"), None)
2068
    fd = utils.WriteFile(self.tfile.name, data=data, close=False)
2069
    try:
2070
      os.lseek(fd, 0, 0)
2071
      self.assertEqual(os.read(fd, 4096), data)
2072
    finally:
2073
      os.close(fd)
2074

    
2075

    
2076
if __name__ == '__main__':
2077
  testutils.GanetiTestProgram()