Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ 9f37f689

History | View | Annotate | Download (94.6 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007, 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import distutils.version
25
import errno
26
import fcntl
27
import glob
28
import os
29
import os.path
30
import re
31
import shutil
32
import signal
33
import socket
34
import stat
35
import string
36
import tempfile
37
import time
38
import unittest
39
import warnings
40
import OpenSSL
41
import random
42
import operator
43
from cStringIO import StringIO
44

    
45
import testutils
46
from ganeti import constants
47
from ganeti import compat
48
from ganeti import utils
49
from ganeti import errors
50
from ganeti.utils import RunCmd, RemoveFile, MatchNameComponent, FormatUnit, \
51
     ParseUnit, ShellQuote, ShellQuoteArgs, ListVisibleFiles, FirstFree, \
52
     TailFile, SafeEncode, FormatTime, UnescapeAndSplit, RunParts, PathJoin, \
53
     ReadOneLineFile, SetEtcHostsEntry, RemoveEtcHostsEntry
54

    
55

    
56
class TestIsProcessAlive(unittest.TestCase):
57
  """Testing case for IsProcessAlive"""
58

    
59
  def testExists(self):
60
    mypid = os.getpid()
61
    self.assert_(utils.IsProcessAlive(mypid), "can't find myself running")
62

    
63
  def testNotExisting(self):
64
    pid_non_existing = os.fork()
65
    if pid_non_existing == 0:
66
      os._exit(0)
67
    elif pid_non_existing < 0:
68
      raise SystemError("can't fork")
69
    os.waitpid(pid_non_existing, 0)
70
    self.assertFalse(utils.IsProcessAlive(pid_non_existing),
71
                     "nonexisting process detected")
72

    
73

    
74
class TestGetProcStatusPath(unittest.TestCase):
75
  def test(self):
76
    self.assert_("/1234/" in utils._GetProcStatusPath(1234))
77
    self.assertNotEqual(utils._GetProcStatusPath(1),
78
                        utils._GetProcStatusPath(2))
79

    
80

    
81
class TestIsProcessHandlingSignal(unittest.TestCase):
82
  def setUp(self):
83
    self.tmpdir = tempfile.mkdtemp()
84

    
85
  def tearDown(self):
86
    shutil.rmtree(self.tmpdir)
87

    
88
  def testParseSigsetT(self):
89
    self.assertEqual(len(utils._ParseSigsetT("0")), 0)
90
    self.assertEqual(utils._ParseSigsetT("1"), set([1]))
91
    self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17]))
92
    self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ]))
93
    self.assertEqual(utils._ParseSigsetT("0000000180000202"),
94
                     set([2, 10, 32, 33]))
95
    self.assertEqual(utils._ParseSigsetT("0000000180000002"),
96
                     set([2, 32, 33]))
97
    self.assertEqual(utils._ParseSigsetT("0000000188000002"),
98
                     set([2, 28, 32, 33]))
99
    self.assertEqual(utils._ParseSigsetT("000000004b813efb"),
100
                     set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17,
101
                          24, 25, 26, 28, 31]))
102
    self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25)))
103

    
104
  def testGetProcStatusField(self):
105
    for field in ["SigCgt", "Name", "FDSize"]:
106
      for value in ["", "0", "cat", "  1234 KB"]:
107
        pstatus = "\n".join([
108
          "VmPeak: 999 kB",
109
          "%s: %s" % (field, value),
110
          "TracerPid: 0",
111
          ])
112
        result = utils._GetProcStatusField(pstatus, field)
113
        self.assertEqual(result, value.strip())
114

    
115
  def test(self):
116
    sp = PathJoin(self.tmpdir, "status")
117

    
118
    utils.WriteFile(sp, data="\n".join([
119
      "Name:   bash",
120
      "State:  S (sleeping)",
121
      "SleepAVG:       98%",
122
      "Pid:    22250",
123
      "PPid:   10858",
124
      "TracerPid:      0",
125
      "SigBlk: 0000000000010000",
126
      "SigIgn: 0000000000384004",
127
      "SigCgt: 000000004b813efb",
128
      "CapEff: 0000000000000000",
129
      ]))
130

    
131
    self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
132

    
133
  def testNoSigCgt(self):
134
    sp = PathJoin(self.tmpdir, "status")
135

    
136
    utils.WriteFile(sp, data="\n".join([
137
      "Name:   bash",
138
      ]))
139

    
140
    self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal,
141
                      1234, 10, status_path=sp)
142

    
143
  def testNoSuchFile(self):
144
    sp = PathJoin(self.tmpdir, "notexist")
145

    
146
    self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp))
147

    
148
  @staticmethod
149
  def _TestRealProcess():
150
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
151
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
152
      raise Exception("SIGUSR1 is handled when it should not be")
153

    
154
    signal.signal(signal.SIGUSR1, lambda signum, frame: None)
155
    if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
156
      raise Exception("SIGUSR1 is not handled when it should be")
157

    
158
    signal.signal(signal.SIGUSR1, signal.SIG_IGN)
159
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
160
      raise Exception("SIGUSR1 is not handled when it should be")
161

    
162
    signal.signal(signal.SIGUSR1, signal.SIG_DFL)
163
    if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1):
164
      raise Exception("SIGUSR1 is handled when it should not be")
165

    
166
    return True
167

    
168
  def testRealProcess(self):
169
    self.assert_(utils.RunInSeparateProcess(self._TestRealProcess))
170

    
171

    
172
class TestPidFileFunctions(unittest.TestCase):
173
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
174

    
175
  def setUp(self):
176
    self.dir = tempfile.mkdtemp()
177
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
178
    utils.DaemonPidFileName = self.f_dpn
179

    
180
  def testPidFileFunctions(self):
181
    pid_file = self.f_dpn('test')
182
    fd = utils.WritePidFile(self.f_dpn('test'))
183
    self.failUnless(os.path.exists(pid_file),
184
                    "PID file should have been created")
185
    read_pid = utils.ReadPidFile(pid_file)
186
    self.failUnlessEqual(read_pid, os.getpid())
187
    self.failUnless(utils.IsProcessAlive(read_pid))
188
    self.failUnlessRaises(errors.LockError, utils.WritePidFile,
189
                          self.f_dpn('test'))
190
    os.close(fd)
191
    utils.RemovePidFile('test')
192
    self.failIf(os.path.exists(pid_file),
193
                "PID file should not exist anymore")
194
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
195
                         "ReadPidFile should return 0 for missing pid file")
196
    fh = open(pid_file, "w")
197
    fh.write("blah\n")
198
    fh.close()
199
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
200
                         "ReadPidFile should return 0 for invalid pid file")
201
    # but now, even with the file existing, we should be able to lock it
202
    fd = utils.WritePidFile(self.f_dpn('test'))
203
    os.close(fd)
204
    utils.RemovePidFile('test')
205
    self.failIf(os.path.exists(pid_file),
206
                "PID file should not exist anymore")
207

    
208
  def testKill(self):
209
    pid_file = self.f_dpn('child')
210
    r_fd, w_fd = os.pipe()
211
    new_pid = os.fork()
212
    if new_pid == 0: #child
213
      utils.WritePidFile(self.f_dpn('child'))
214
      os.write(w_fd, 'a')
215
      signal.pause()
216
      os._exit(0)
217
      return
218
    # else we are in the parent
219
    # wait until the child has written the pid file
220
    os.read(r_fd, 1)
221
    read_pid = utils.ReadPidFile(pid_file)
222
    self.failUnlessEqual(read_pid, new_pid)
223
    self.failUnless(utils.IsProcessAlive(new_pid))
224
    utils.KillProcess(new_pid, waitpid=True)
225
    self.failIf(utils.IsProcessAlive(new_pid))
226
    utils.RemovePidFile('child')
227
    self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0)
228

    
229
  def tearDown(self):
230
    for name in os.listdir(self.dir):
231
      os.unlink(os.path.join(self.dir, name))
232
    os.rmdir(self.dir)
233

    
234

    
235
class TestRunCmd(testutils.GanetiTestCase):
236
  """Testing case for the RunCmd function"""
237

    
238
  def setUp(self):
239
    testutils.GanetiTestCase.setUp(self)
240
    self.magic = time.ctime() + " ganeti test"
241
    self.fname = self._CreateTempFile()
242
    self.fifo_tmpdir = tempfile.mkdtemp()
243
    self.fifo_file = os.path.join(self.fifo_tmpdir, "ganeti_test_fifo")
244
    os.mkfifo(self.fifo_file)
245

    
246
  def tearDown(self):
247
    shutil.rmtree(self.fifo_tmpdir)
248

    
249
  def testOk(self):
250
    """Test successful exit code"""
251
    result = RunCmd("/bin/sh -c 'exit 0'")
252
    self.assertEqual(result.exit_code, 0)
253
    self.assertEqual(result.output, "")
254

    
255
  def testFail(self):
256
    """Test fail exit code"""
257
    result = RunCmd("/bin/sh -c 'exit 1'")
258
    self.assertEqual(result.exit_code, 1)
259
    self.assertEqual(result.output, "")
260

    
261
  def testStdout(self):
262
    """Test standard output"""
263
    cmd = 'echo -n "%s"' % self.magic
264
    result = RunCmd("/bin/sh -c '%s'" % cmd)
265
    self.assertEqual(result.stdout, self.magic)
266
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
267
    self.assertEqual(result.output, "")
268
    self.assertFileContent(self.fname, self.magic)
269

    
270
  def testStderr(self):
271
    """Test standard error"""
272
    cmd = 'echo -n "%s"' % self.magic
273
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
274
    self.assertEqual(result.stderr, self.magic)
275
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
276
    self.assertEqual(result.output, "")
277
    self.assertFileContent(self.fname, self.magic)
278

    
279
  def testCombined(self):
280
    """Test combined output"""
281
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
282
    expected = "A" + self.magic + "B" + self.magic
283
    result = RunCmd("/bin/sh -c '%s'" % cmd)
284
    self.assertEqual(result.output, expected)
285
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
286
    self.assertEqual(result.output, "")
287
    self.assertFileContent(self.fname, expected)
288

    
289
  def testSignal(self):
290
    """Test signal"""
291
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
292
    self.assertEqual(result.signal, 15)
293
    self.assertEqual(result.output, "")
294

    
295
  def testTimeoutClean(self):
296
    cmd = "trap 'exit 0' TERM; read < %s" % self.fifo_file
297
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
298
    self.assertEqual(result.exit_code, 0)
299

    
300
  def testTimeoutKill(self):
301
    cmd = ["/bin/sh", "-c", "trap '' TERM; read < %s" % self.fifo_file]
302
    timeout = 0.2
303
    out, err, status, ta = utils._RunCmdPipe(cmd, {}, False, "/", False,
304
                                             timeout, _linger_timeout=0.2)
305
    self.assert_(status < 0)
306
    self.assertEqual(-status, signal.SIGKILL)
307

    
308
  def testTimeoutOutputAfterTerm(self):
309
    cmd = "trap 'echo sigtermed; exit 1' TERM; read < %s" % self.fifo_file
310
    result = RunCmd(["/bin/sh", "-c", cmd], timeout=0.2)
311
    self.assert_(result.failed)
312
    self.assertEqual(result.stdout, "sigtermed\n")
313

    
314
  def testListRun(self):
315
    """Test list runs"""
316
    result = RunCmd(["true"])
317
    self.assertEqual(result.signal, None)
318
    self.assertEqual(result.exit_code, 0)
319
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
320
    self.assertEqual(result.signal, None)
321
    self.assertEqual(result.exit_code, 1)
322
    result = RunCmd(["echo", "-n", self.magic])
323
    self.assertEqual(result.signal, None)
324
    self.assertEqual(result.exit_code, 0)
325
    self.assertEqual(result.stdout, self.magic)
326

    
327
  def testFileEmptyOutput(self):
328
    """Test file output"""
329
    result = RunCmd(["true"], output=self.fname)
330
    self.assertEqual(result.signal, None)
331
    self.assertEqual(result.exit_code, 0)
332
    self.assertFileContent(self.fname, "")
333

    
334
  def testLang(self):
335
    """Test locale environment"""
336
    old_env = os.environ.copy()
337
    try:
338
      os.environ["LANG"] = "en_US.UTF-8"
339
      os.environ["LC_ALL"] = "en_US.UTF-8"
340
      result = RunCmd(["locale"])
341
      for line in result.output.splitlines():
342
        key, value = line.split("=", 1)
343
        # Ignore these variables, they're overridden by LC_ALL
344
        if key == "LANG" or key == "LANGUAGE":
345
          continue
346
        self.failIf(value and value != "C" and value != '"C"',
347
            "Variable %s is set to the invalid value '%s'" % (key, value))
348
    finally:
349
      os.environ = old_env
350

    
351
  def testDefaultCwd(self):
352
    """Test default working directory"""
353
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
354

    
355
  def testCwd(self):
356
    """Test default working directory"""
357
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
358
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
359
    cwd = os.getcwd()
360
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
361

    
362
  def testResetEnv(self):
363
    """Test environment reset functionality"""
364
    self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "")
365
    self.failUnlessEqual(RunCmd(["env"], reset_env=True,
366
                                env={"FOO": "bar",}).stdout.strip(), "FOO=bar")
367

    
368
  def testNoFork(self):
369
    """Test that nofork raise an error"""
370
    assert not utils.no_fork
371
    utils.no_fork = True
372
    try:
373
      self.assertRaises(errors.ProgrammerError, RunCmd, ["true"])
374
    finally:
375
      utils.no_fork = False
376

    
377
  def testWrongParams(self):
378
    """Test wrong parameters"""
379
    self.assertRaises(errors.ProgrammerError, RunCmd, ["true"],
380
                      output="/dev/null", interactive=True)
381

    
382

    
383
class TestRunParts(testutils.GanetiTestCase):
384
  """Testing case for the RunParts function"""
385

    
386
  def setUp(self):
387
    self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp")
388

    
389
  def tearDown(self):
390
    shutil.rmtree(self.rundir)
391

    
392
  def testEmpty(self):
393
    """Test on an empty dir"""
394
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True), [])
395

    
396
  def testSkipWrongName(self):
397
    """Test that wrong files are skipped"""
398
    fname = os.path.join(self.rundir, "00test.dot")
399
    utils.WriteFile(fname, data="")
400
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
401
    relname = os.path.basename(fname)
402
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
403
                         [(relname, constants.RUNPARTS_SKIP, None)])
404

    
405
  def testSkipNonExec(self):
406
    """Test that non executable files are skipped"""
407
    fname = os.path.join(self.rundir, "00test")
408
    utils.WriteFile(fname, data="")
409
    relname = os.path.basename(fname)
410
    self.failUnlessEqual(RunParts(self.rundir, reset_env=True),
411
                         [(relname, constants.RUNPARTS_SKIP, None)])
412

    
413
  def testError(self):
414
    """Test error on a broken executable"""
415
    fname = os.path.join(self.rundir, "00test")
416
    utils.WriteFile(fname, data="")
417
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
418
    (relname, status, error) = RunParts(self.rundir, reset_env=True)[0]
419
    self.failUnlessEqual(relname, os.path.basename(fname))
420
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
421
    self.failUnless(error)
422

    
423
  def testSorted(self):
424
    """Test executions are sorted"""
425
    files = []
426
    files.append(os.path.join(self.rundir, "64test"))
427
    files.append(os.path.join(self.rundir, "00test"))
428
    files.append(os.path.join(self.rundir, "42test"))
429

    
430
    for fname in files:
431
      utils.WriteFile(fname, data="")
432

    
433
    results = RunParts(self.rundir, reset_env=True)
434

    
435
    for fname in sorted(files):
436
      self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0])
437

    
438
  def testOk(self):
439
    """Test correct execution"""
440
    fname = os.path.join(self.rundir, "00test")
441
    utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao")
442
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
443
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
444
    self.failUnlessEqual(relname, os.path.basename(fname))
445
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
446
    self.failUnlessEqual(runresult.stdout, "ciao")
447

    
448
  def testRunFail(self):
449
    """Test correct execution, with run failure"""
450
    fname = os.path.join(self.rundir, "00test")
451
    utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1")
452
    os.chmod(fname, stat.S_IREAD | stat.S_IEXEC)
453
    (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0]
454
    self.failUnlessEqual(relname, os.path.basename(fname))
455
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
456
    self.failUnlessEqual(runresult.exit_code, 1)
457
    self.failUnless(runresult.failed)
458

    
459
  def testRunMix(self):
460
    files = []
461
    files.append(os.path.join(self.rundir, "00test"))
462
    files.append(os.path.join(self.rundir, "42test"))
463
    files.append(os.path.join(self.rundir, "64test"))
464
    files.append(os.path.join(self.rundir, "99test"))
465

    
466
    files.sort()
467

    
468
    # 1st has errors in execution
469
    utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1")
470
    os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC)
471

    
472
    # 2nd is skipped
473
    utils.WriteFile(files[1], data="")
474

    
475
    # 3rd cannot execute properly
476
    utils.WriteFile(files[2], data="")
477
    os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC)
478

    
479
    # 4th execs
480
    utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao")
481
    os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC)
482

    
483
    results = RunParts(self.rundir, reset_env=True)
484

    
485
    (relname, status, runresult) = results[0]
486
    self.failUnlessEqual(relname, os.path.basename(files[0]))
487
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
488
    self.failUnlessEqual(runresult.exit_code, 1)
489
    self.failUnless(runresult.failed)
490

    
491
    (relname, status, runresult) = results[1]
492
    self.failUnlessEqual(relname, os.path.basename(files[1]))
493
    self.failUnlessEqual(status, constants.RUNPARTS_SKIP)
494
    self.failUnlessEqual(runresult, None)
495

    
496
    (relname, status, runresult) = results[2]
497
    self.failUnlessEqual(relname, os.path.basename(files[2]))
498
    self.failUnlessEqual(status, constants.RUNPARTS_ERR)
499
    self.failUnless(runresult)
500

    
501
    (relname, status, runresult) = results[3]
502
    self.failUnlessEqual(relname, os.path.basename(files[3]))
503
    self.failUnlessEqual(status, constants.RUNPARTS_RUN)
504
    self.failUnlessEqual(runresult.output, "ciao")
505
    self.failUnlessEqual(runresult.exit_code, 0)
506
    self.failUnless(not runresult.failed)
507

    
508
  def testMissingDirectory(self):
509
    nosuchdir = utils.PathJoin(self.rundir, "no/such/directory")
510
    self.assertEqual(RunParts(nosuchdir), [])
511

    
512

    
513
class TestStartDaemon(testutils.GanetiTestCase):
514
  def setUp(self):
515
    self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test")
516
    self.tmpfile = os.path.join(self.tmpdir, "test")
517

    
518
  def tearDown(self):
519
    shutil.rmtree(self.tmpdir)
520

    
521
  def testShell(self):
522
    utils.StartDaemon("echo Hello World > %s" % self.tmpfile)
523
    self._wait(self.tmpfile, 60.0, "Hello World")
524

    
525
  def testShellOutput(self):
526
    utils.StartDaemon("echo Hello World", output=self.tmpfile)
527
    self._wait(self.tmpfile, 60.0, "Hello World")
528

    
529
  def testNoShellNoOutput(self):
530
    utils.StartDaemon(["pwd"])
531

    
532
  def testNoShellNoOutputTouch(self):
533
    testfile = os.path.join(self.tmpdir, "check")
534
    self.failIf(os.path.exists(testfile))
535
    utils.StartDaemon(["touch", testfile])
536
    self._wait(testfile, 60.0, "")
537

    
538
  def testNoShellOutput(self):
539
    utils.StartDaemon(["pwd"], output=self.tmpfile)
540
    self._wait(self.tmpfile, 60.0, "/")
541

    
542
  def testNoShellOutputCwd(self):
543
    utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd())
544
    self._wait(self.tmpfile, 60.0, os.getcwd())
545

    
546
  def testShellEnv(self):
547
    utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile,
548
                      env={ "GNT_TEST_VAR": "Hello World", })
549
    self._wait(self.tmpfile, 60.0, "Hello World")
550

    
551
  def testNoShellEnv(self):
552
    utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile,
553
                      env={ "GNT_TEST_VAR": "Hello World", })
554
    self._wait(self.tmpfile, 60.0, "Hello World")
555

    
556
  def testOutputFd(self):
557
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
558
    try:
559
      utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd())
560
    finally:
561
      os.close(fd)
562
    self._wait(self.tmpfile, 60.0, os.getcwd())
563

    
564
  def testPid(self):
565
    pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile)
566
    self._wait(self.tmpfile, 60.0, str(pid))
567

    
568
  def testPidFile(self):
569
    pidfile = os.path.join(self.tmpdir, "pid")
570
    checkfile = os.path.join(self.tmpdir, "abort")
571

    
572
    pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile,
573
                            output=self.tmpfile)
574
    try:
575
      fd = os.open(pidfile, os.O_RDONLY)
576
      try:
577
        # Check file is locked
578
        self.assertRaises(errors.LockError, utils.LockFile, fd)
579

    
580
        pidtext = os.read(fd, 100)
581
      finally:
582
        os.close(fd)
583

    
584
      self.assertEqual(int(pidtext.strip()), pid)
585

    
586
      self.assert_(utils.IsProcessAlive(pid))
587
    finally:
588
      # No matter what happens, kill daemon
589
      utils.KillProcess(pid, timeout=5.0, waitpid=False)
590
      self.failIf(utils.IsProcessAlive(pid))
591

    
592
    self.assertEqual(utils.ReadFile(self.tmpfile), "")
593

    
594
  def _wait(self, path, timeout, expected):
595
    # Due to the asynchronous nature of daemon processes, polling is necessary.
596
    # A timeout makes sure the test doesn't hang forever.
597
    def _CheckFile():
598
      if not (os.path.isfile(path) and
599
              utils.ReadFile(path).strip() == expected):
600
        raise utils.RetryAgain()
601

    
602
    try:
603
      utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout)
604
    except utils.RetryTimeout:
605
      self.fail("Apparently the daemon didn't run in %s seconds and/or"
606
                " didn't write the correct output" % timeout)
607

    
608
  def testError(self):
609
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
610
                      ["./does-NOT-EXIST/here/0123456789"])
611
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
612
                      ["./does-NOT-EXIST/here/0123456789"],
613
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
614
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
615
                      ["./does-NOT-EXIST/here/0123456789"],
616
                      cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
617
    self.assertRaises(errors.OpExecError, utils.StartDaemon,
618
                      ["./does-NOT-EXIST/here/0123456789"],
619
                      output=os.path.join(self.tmpdir, "DIR/NOT/EXIST"))
620

    
621
    fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT)
622
    try:
623
      self.assertRaises(errors.ProgrammerError, utils.StartDaemon,
624
                        ["./does-NOT-EXIST/here/0123456789"],
625
                        output=self.tmpfile, output_fd=fd)
626
    finally:
627
      os.close(fd)
628

    
629

    
630
class TestSetCloseOnExecFlag(unittest.TestCase):
631
  """Tests for SetCloseOnExecFlag"""
632

    
633
  def setUp(self):
634
    self.tmpfile = tempfile.TemporaryFile()
635

    
636
  def testEnable(self):
637
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True)
638
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
639
                    fcntl.FD_CLOEXEC)
640

    
641
  def testDisable(self):
642
    utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False)
643
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) &
644
                fcntl.FD_CLOEXEC)
645

    
646

    
647
class TestSetNonblockFlag(unittest.TestCase):
648
  def setUp(self):
649
    self.tmpfile = tempfile.TemporaryFile()
650

    
651
  def testEnable(self):
652
    utils.SetNonblockFlag(self.tmpfile.fileno(), True)
653
    self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
654
                    os.O_NONBLOCK)
655

    
656
  def testDisable(self):
657
    utils.SetNonblockFlag(self.tmpfile.fileno(), False)
658
    self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) &
659
                os.O_NONBLOCK)
660

    
661

    
662
class TestRemoveFile(unittest.TestCase):
663
  """Test case for the RemoveFile function"""
664

    
665
  def setUp(self):
666
    """Create a temp dir and file for each case"""
667
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
668
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
669
    os.close(fd)
670

    
671
  def tearDown(self):
672
    if os.path.exists(self.tmpfile):
673
      os.unlink(self.tmpfile)
674
    os.rmdir(self.tmpdir)
675

    
676
  def testIgnoreDirs(self):
677
    """Test that RemoveFile() ignores directories"""
678
    self.assertEqual(None, RemoveFile(self.tmpdir))
679

    
680
  def testIgnoreNotExisting(self):
681
    """Test that RemoveFile() ignores non-existing files"""
682
    RemoveFile(self.tmpfile)
683
    RemoveFile(self.tmpfile)
684

    
685
  def testRemoveFile(self):
686
    """Test that RemoveFile does remove a file"""
687
    RemoveFile(self.tmpfile)
688
    if os.path.exists(self.tmpfile):
689
      self.fail("File '%s' not removed" % self.tmpfile)
690

    
691
  def testRemoveSymlink(self):
692
    """Test that RemoveFile does remove symlinks"""
693
    symlink = self.tmpdir + "/symlink"
694
    os.symlink("no-such-file", symlink)
695
    RemoveFile(symlink)
696
    if os.path.exists(symlink):
697
      self.fail("File '%s' not removed" % symlink)
698
    os.symlink(self.tmpfile, symlink)
699
    RemoveFile(symlink)
700
    if os.path.exists(symlink):
701
      self.fail("File '%s' not removed" % symlink)
702

    
703

    
704
class TestRemoveDir(unittest.TestCase):
705
  def setUp(self):
706
    self.tmpdir = tempfile.mkdtemp()
707

    
708
  def tearDown(self):
709
    try:
710
      shutil.rmtree(self.tmpdir)
711
    except EnvironmentError:
712
      pass
713

    
714
  def testEmptyDir(self):
715
    utils.RemoveDir(self.tmpdir)
716
    self.assertFalse(os.path.isdir(self.tmpdir))
717

    
718
  def testNonEmptyDir(self):
719
    self.tmpfile = os.path.join(self.tmpdir, "test1")
720
    open(self.tmpfile, "w").close()
721
    self.assertRaises(EnvironmentError, utils.RemoveDir, self.tmpdir)
722

    
723

    
724
class TestRename(unittest.TestCase):
725
  """Test case for RenameFile"""
726

    
727
  def setUp(self):
728
    """Create a temporary directory"""
729
    self.tmpdir = tempfile.mkdtemp()
730
    self.tmpfile = os.path.join(self.tmpdir, "test1")
731

    
732
    # Touch the file
733
    open(self.tmpfile, "w").close()
734

    
735
  def tearDown(self):
736
    """Remove temporary directory"""
737
    shutil.rmtree(self.tmpdir)
738

    
739
  def testSimpleRename1(self):
740
    """Simple rename 1"""
741
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
742
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
743

    
744
  def testSimpleRename2(self):
745
    """Simple rename 2"""
746
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
747
                     mkdir=True)
748
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
749

    
750
  def testRenameMkdir(self):
751
    """Rename with mkdir"""
752
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
753
                     mkdir=True)
754
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
755
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
756

    
757
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
758
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
759
                     mkdir=True)
760
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
761
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
762
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
763

    
764

    
765
class TestMatchNameComponent(unittest.TestCase):
766
  """Test case for the MatchNameComponent function"""
767

    
768
  def testEmptyList(self):
769
    """Test that there is no match against an empty list"""
770

    
771
    self.failUnlessEqual(MatchNameComponent("", []), None)
772
    self.failUnlessEqual(MatchNameComponent("test", []), None)
773

    
774
  def testSingleMatch(self):
775
    """Test that a single match is performed correctly"""
776
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
777
    for key in "test2", "test2.example", "test2.example.com":
778
      self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
779

    
780
  def testMultipleMatches(self):
781
    """Test that a multiple match is returned as None"""
782
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
783
    for key in "test1", "test1.example":
784
      self.failUnlessEqual(MatchNameComponent(key, mlist), None)
785

    
786
  def testFullMatch(self):
787
    """Test that a full match is returned correctly"""
788
    key1 = "test1"
789
    key2 = "test1.example"
790
    mlist = [key2, key2 + ".com"]
791
    self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
792
    self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
793

    
794
  def testCaseInsensitivePartialMatch(self):
795
    """Test for the case_insensitive keyword"""
796
    mlist = ["test1.example.com", "test2.example.net"]
797
    self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
798
                     "test2.example.net")
799
    self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
800
                     "test2.example.net")
801
    self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
802
                     "test2.example.net")
803
    self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
804
                     "test2.example.net")
805

    
806

    
807
  def testCaseInsensitiveFullMatch(self):
808
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
809
    # Between the two ts1 a full string match non-case insensitive should work
810
    self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
811
                     None)
812
    self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
813
                     "ts1.ex")
814
    self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
815
                     "ts1.ex")
816
    # Between the two ts2 only case differs, so only case-match works
817
    self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
818
                     "ts2.ex")
819
    self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
820
                     "Ts2.ex")
821
    self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
822
                     None)
823

    
824

    
825
class TestReadFile(testutils.GanetiTestCase):
826

    
827
  def testReadAll(self):
828
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"))
829
    self.assertEqual(len(data), 814)
830

    
831
    h = compat.md5_hash()
832
    h.update(data)
833
    self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4")
834

    
835
  def testReadSize(self):
836
    data = utils.ReadFile(self._TestDataFilename("cert1.pem"),
837
                          size=100)
838
    self.assertEqual(len(data), 100)
839

    
840
    h = compat.md5_hash()
841
    h.update(data)
842
    self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7")
843

    
844
  def testError(self):
845
    self.assertRaises(EnvironmentError, utils.ReadFile,
846
                      "/dev/null/does-not-exist")
847

    
848

    
849
class TestReadOneLineFile(testutils.GanetiTestCase):
850

    
851
  def setUp(self):
852
    testutils.GanetiTestCase.setUp(self)
853

    
854
  def testDefault(self):
855
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
856
    self.assertEqual(len(data), 27)
857
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
858

    
859
  def testNotStrict(self):
860
    data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False)
861
    self.assertEqual(len(data), 27)
862
    self.assertEqual(data, "-----BEGIN CERTIFICATE-----")
863

    
864
  def testStrictFailure(self):
865
    self.assertRaises(errors.GenericError, ReadOneLineFile,
866
                      self._TestDataFilename("cert1.pem"), strict=True)
867

    
868
  def testLongLine(self):
869
    dummydata = (1024 * "Hello World! ")
870
    myfile = self._CreateTempFile()
871
    utils.WriteFile(myfile, data=dummydata)
872
    datastrict = ReadOneLineFile(myfile, strict=True)
873
    datalax = ReadOneLineFile(myfile, strict=False)
874
    self.assertEqual(dummydata, datastrict)
875
    self.assertEqual(dummydata, datalax)
876

    
877
  def testNewline(self):
878
    myfile = self._CreateTempFile()
879
    myline = "myline"
880
    for nl in ["", "\n", "\r\n"]:
881
      dummydata = "%s%s" % (myline, nl)
882
      utils.WriteFile(myfile, data=dummydata)
883
      datalax = ReadOneLineFile(myfile, strict=False)
884
      self.assertEqual(myline, datalax)
885
      datastrict = ReadOneLineFile(myfile, strict=True)
886
      self.assertEqual(myline, datastrict)
887

    
888
  def testWhitespaceAndMultipleLines(self):
889
    myfile = self._CreateTempFile()
890
    for nl in ["", "\n", "\r\n"]:
891
      for ws in [" ", "\t", "\t\t  \t", "\t "]:
892
        dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl)))
893
        utils.WriteFile(myfile, data=dummydata)
894
        datalax = ReadOneLineFile(myfile, strict=False)
895
        if nl:
896
          self.assert_(set("\r\n") & set(dummydata))
897
          self.assertRaises(errors.GenericError, ReadOneLineFile,
898
                            myfile, strict=True)
899
          explen = len("Foo bar baz ") + len(ws)
900
          self.assertEqual(len(datalax), explen)
901
          self.assertEqual(datalax, dummydata[:explen])
902
          self.assertFalse(set("\r\n") & set(datalax))
903
        else:
904
          datastrict = ReadOneLineFile(myfile, strict=True)
905
          self.assertEqual(dummydata, datastrict)
906
          self.assertEqual(dummydata, datalax)
907

    
908
  def testEmptylines(self):
909
    myfile = self._CreateTempFile()
910
    myline = "myline"
911
    for nl in ["\n", "\r\n"]:
912
      for ol in ["", "otherline"]:
913
        dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl)
914
        utils.WriteFile(myfile, data=dummydata)
915
        self.assert_(set("\r\n") & set(dummydata))
916
        datalax = ReadOneLineFile(myfile, strict=False)
917
        self.assertEqual(myline, datalax)
918
        if ol:
919
          self.assertRaises(errors.GenericError, ReadOneLineFile,
920
                            myfile, strict=True)
921
        else:
922
          datastrict = ReadOneLineFile(myfile, strict=True)
923
          self.assertEqual(myline, datastrict)
924

    
925
  def testEmptyfile(self):
926
    myfile = self._CreateTempFile()
927
    self.assertRaises(errors.GenericError, ReadOneLineFile, myfile)
928

    
929

    
930
class TestTimestampForFilename(unittest.TestCase):
931
  def test(self):
932
    self.assert_("." not in utils.TimestampForFilename())
933
    self.assert_(":" not in utils.TimestampForFilename())
934

    
935

    
936
class TestCreateBackup(testutils.GanetiTestCase):
937
  def setUp(self):
938
    testutils.GanetiTestCase.setUp(self)
939

    
940
    self.tmpdir = tempfile.mkdtemp()
941

    
942
  def tearDown(self):
943
    testutils.GanetiTestCase.tearDown(self)
944

    
945
    shutil.rmtree(self.tmpdir)
946

    
947
  def testEmpty(self):
948
    filename = PathJoin(self.tmpdir, "config.data")
949
    utils.WriteFile(filename, data="")
950
    bname = utils.CreateBackup(filename)
951
    self.assertFileContent(bname, "")
952
    self.assertEqual(len(glob.glob("%s*" % filename)), 2)
953
    utils.CreateBackup(filename)
954
    self.assertEqual(len(glob.glob("%s*" % filename)), 3)
955
    utils.CreateBackup(filename)
956
    self.assertEqual(len(glob.glob("%s*" % filename)), 4)
957

    
958
    fifoname = PathJoin(self.tmpdir, "fifo")
959
    os.mkfifo(fifoname)
960
    self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname)
961

    
962
  def testContent(self):
963
    bkpcount = 0
964
    for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]:
965
      for rep in [1, 2, 10, 127]:
966
        testdata = data * rep
967

    
968
        filename = PathJoin(self.tmpdir, "test.data_")
969
        utils.WriteFile(filename, data=testdata)
970
        self.assertFileContent(filename, testdata)
971

    
972
        for _ in range(3):
973
          bname = utils.CreateBackup(filename)
974
          bkpcount += 1
975
          self.assertFileContent(bname, testdata)
976
          self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount)
977

    
978

    
979
class TestFormatUnit(unittest.TestCase):
980
  """Test case for the FormatUnit function"""
981

    
982
  def testMiB(self):
983
    self.assertEqual(FormatUnit(1, 'h'), '1M')
984
    self.assertEqual(FormatUnit(100, 'h'), '100M')
985
    self.assertEqual(FormatUnit(1023, 'h'), '1023M')
986

    
987
    self.assertEqual(FormatUnit(1, 'm'), '1')
988
    self.assertEqual(FormatUnit(100, 'm'), '100')
989
    self.assertEqual(FormatUnit(1023, 'm'), '1023')
990

    
991
    self.assertEqual(FormatUnit(1024, 'm'), '1024')
992
    self.assertEqual(FormatUnit(1536, 'm'), '1536')
993
    self.assertEqual(FormatUnit(17133, 'm'), '17133')
994
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
995

    
996
  def testGiB(self):
997
    self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
998
    self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
999
    self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
1000
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
1001

    
1002
    self.assertEqual(FormatUnit(1024, 'g'), '1.0')
1003
    self.assertEqual(FormatUnit(1536, 'g'), '1.5')
1004
    self.assertEqual(FormatUnit(17133, 'g'), '16.7')
1005
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
1006

    
1007
    self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
1008
    self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
1009
    self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
1010

    
1011
  def testTiB(self):
1012
    self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
1013
    self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
1014
    self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
1015

    
1016
    self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
1017
    self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
1018
    self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
1019

    
1020
  def testErrors(self):
1021
    self.assertRaises(errors.ProgrammerError, FormatUnit, 1, "a")
1022

    
1023

    
1024
class TestParseUnit(unittest.TestCase):
1025
  """Test case for the ParseUnit function"""
1026

    
1027
  SCALES = (('', 1),
1028
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
1029
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
1030
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
1031

    
1032
  def testRounding(self):
1033
    self.assertEqual(ParseUnit('0'), 0)
1034
    self.assertEqual(ParseUnit('1'), 4)
1035
    self.assertEqual(ParseUnit('2'), 4)
1036
    self.assertEqual(ParseUnit('3'), 4)
1037

    
1038
    self.assertEqual(ParseUnit('124'), 124)
1039
    self.assertEqual(ParseUnit('125'), 128)
1040
    self.assertEqual(ParseUnit('126'), 128)
1041
    self.assertEqual(ParseUnit('127'), 128)
1042
    self.assertEqual(ParseUnit('128'), 128)
1043
    self.assertEqual(ParseUnit('129'), 132)
1044
    self.assertEqual(ParseUnit('130'), 132)
1045

    
1046
  def testFloating(self):
1047
    self.assertEqual(ParseUnit('0'), 0)
1048
    self.assertEqual(ParseUnit('0.5'), 4)
1049
    self.assertEqual(ParseUnit('1.75'), 4)
1050
    self.assertEqual(ParseUnit('1.99'), 4)
1051
    self.assertEqual(ParseUnit('2.00'), 4)
1052
    self.assertEqual(ParseUnit('2.01'), 4)
1053
    self.assertEqual(ParseUnit('3.99'), 4)
1054
    self.assertEqual(ParseUnit('4.00'), 4)
1055
    self.assertEqual(ParseUnit('4.01'), 8)
1056
    self.assertEqual(ParseUnit('1.5G'), 1536)
1057
    self.assertEqual(ParseUnit('1.8G'), 1844)
1058
    self.assertEqual(ParseUnit('8.28T'), 8682212)
1059

    
1060
  def testSuffixes(self):
1061
    for sep in ('', ' ', '   ', "\t", "\t "):
1062
      for suffix, scale in TestParseUnit.SCALES:
1063
        for func in (lambda x: x, str.lower, str.upper):
1064
          self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
1065
                           1024 * scale)
1066

    
1067
  def testInvalidInput(self):
1068
    for sep in ('-', '_', ',', 'a'):
1069
      for suffix, _ in TestParseUnit.SCALES:
1070
        self.assertRaises(errors.UnitParseError, ParseUnit, '1' + sep + suffix)
1071

    
1072
    for suffix, _ in TestParseUnit.SCALES:
1073
      self.assertRaises(errors.UnitParseError, ParseUnit, '1,3' + suffix)
1074

    
1075

    
1076
class TestParseCpuMask(unittest.TestCase):
1077
  """Test case for the ParseCpuMask function."""
1078

    
1079
  def testWellFormed(self):
1080
    self.assertEqual(utils.ParseCpuMask(""), [])
1081
    self.assertEqual(utils.ParseCpuMask("1"), [1])
1082
    self.assertEqual(utils.ParseCpuMask("0-2,4,5-5"), [0,1,2,4,5])
1083

    
1084
  def testInvalidInput(self):
1085
    for data in ["garbage", "0,", "0-1-2", "2-1", "1-a"]:
1086
      self.assertRaises(errors.ParseError, utils.ParseCpuMask, data)
1087

    
1088

    
1089
class TestSshKeys(testutils.GanetiTestCase):
1090
  """Test case for the AddAuthorizedKey function"""
1091

    
1092
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
1093
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" '
1094
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
1095

    
1096
  def setUp(self):
1097
    testutils.GanetiTestCase.setUp(self)
1098
    self.tmpname = self._CreateTempFile()
1099
    handle = open(self.tmpname, 'w')
1100
    try:
1101
      handle.write("%s\n" % TestSshKeys.KEY_A)
1102
      handle.write("%s\n" % TestSshKeys.KEY_B)
1103
    finally:
1104
      handle.close()
1105

    
1106
  def testAddingNewKey(self):
1107
    utils.AddAuthorizedKey(self.tmpname,
1108
                           'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
1109

    
1110
    self.assertFileContent(self.tmpname,
1111
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1112
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1113
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1114
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
1115

    
1116
  def testAddingAlmostButNotCompletelyTheSameKey(self):
1117
    utils.AddAuthorizedKey(self.tmpname,
1118
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
1119

    
1120
    self.assertFileContent(self.tmpname,
1121
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1122
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1123
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
1124
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
1125

    
1126
  def testAddingExistingKeyWithSomeMoreSpaces(self):
1127
    utils.AddAuthorizedKey(self.tmpname,
1128
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1129

    
1130
    self.assertFileContent(self.tmpname,
1131
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1132
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1133
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1134

    
1135
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
1136
    utils.RemoveAuthorizedKey(self.tmpname,
1137
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
1138

    
1139
    self.assertFileContent(self.tmpname,
1140
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1141
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1142

    
1143
  def testRemovingNonExistingKey(self):
1144
    utils.RemoveAuthorizedKey(self.tmpname,
1145
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
1146

    
1147
    self.assertFileContent(self.tmpname,
1148
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
1149
      'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"'
1150
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
1151

    
1152

    
1153
class TestEtcHosts(testutils.GanetiTestCase):
1154
  """Test functions modifying /etc/hosts"""
1155

    
1156
  def setUp(self):
1157
    testutils.GanetiTestCase.setUp(self)
1158
    self.tmpname = self._CreateTempFile()
1159
    handle = open(self.tmpname, 'w')
1160
    try:
1161
      handle.write('# This is a test file for /etc/hosts\n')
1162
      handle.write('127.0.0.1\tlocalhost\n')
1163
      handle.write('192.0.2.1 router gw\n')
1164
    finally:
1165
      handle.close()
1166

    
1167
  def testSettingNewIp(self):
1168
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost.example.com',
1169
                     ['myhost'])
1170

    
1171
    self.assertFileContent(self.tmpname,
1172
      "# This is a test file for /etc/hosts\n"
1173
      "127.0.0.1\tlocalhost\n"
1174
      "192.0.2.1 router gw\n"
1175
      "198.51.100.4\tmyhost.example.com myhost\n")
1176
    self.assertFileMode(self.tmpname, 0644)
1177

    
1178
  def testSettingExistingIp(self):
1179
    SetEtcHostsEntry(self.tmpname, '192.0.2.1', 'myhost.example.com',
1180
                     ['myhost'])
1181

    
1182
    self.assertFileContent(self.tmpname,
1183
      "# This is a test file for /etc/hosts\n"
1184
      "127.0.0.1\tlocalhost\n"
1185
      "192.0.2.1\tmyhost.example.com myhost\n")
1186
    self.assertFileMode(self.tmpname, 0644)
1187

    
1188
  def testSettingDuplicateName(self):
1189
    SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost', ['myhost'])
1190

    
1191
    self.assertFileContent(self.tmpname,
1192
      "# This is a test file for /etc/hosts\n"
1193
      "127.0.0.1\tlocalhost\n"
1194
      "192.0.2.1 router gw\n"
1195
      "198.51.100.4\tmyhost\n")
1196
    self.assertFileMode(self.tmpname, 0644)
1197

    
1198
  def testRemovingExistingHost(self):
1199
    RemoveEtcHostsEntry(self.tmpname, 'router')
1200

    
1201
    self.assertFileContent(self.tmpname,
1202
      "# This is a test file for /etc/hosts\n"
1203
      "127.0.0.1\tlocalhost\n"
1204
      "192.0.2.1 gw\n")
1205
    self.assertFileMode(self.tmpname, 0644)
1206

    
1207
  def testRemovingSingleExistingHost(self):
1208
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
1209

    
1210
    self.assertFileContent(self.tmpname,
1211
      "# This is a test file for /etc/hosts\n"
1212
      "192.0.2.1 router gw\n")
1213
    self.assertFileMode(self.tmpname, 0644)
1214

    
1215
  def testRemovingNonExistingHost(self):
1216
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
1217

    
1218
    self.assertFileContent(self.tmpname,
1219
      "# This is a test file for /etc/hosts\n"
1220
      "127.0.0.1\tlocalhost\n"
1221
      "192.0.2.1 router gw\n")
1222
    self.assertFileMode(self.tmpname, 0644)
1223

    
1224
  def testRemovingAlias(self):
1225
    RemoveEtcHostsEntry(self.tmpname, 'gw')
1226

    
1227
    self.assertFileContent(self.tmpname,
1228
      "# This is a test file for /etc/hosts\n"
1229
      "127.0.0.1\tlocalhost\n"
1230
      "192.0.2.1 router\n")
1231
    self.assertFileMode(self.tmpname, 0644)
1232

    
1233

    
1234
class TestGetMounts(unittest.TestCase):
1235
  """Test case for GetMounts()."""
1236

    
1237
  TESTDATA = (
1238
    "rootfs /     rootfs rw 0 0\n"
1239
    "none   /sys  sysfs  rw,nosuid,nodev,noexec,relatime 0 0\n"
1240
    "none   /proc proc   rw,nosuid,nodev,noexec,relatime 0 0\n")
1241

    
1242
  def setUp(self):
1243
    self.tmpfile = tempfile.NamedTemporaryFile()
1244
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1245

    
1246
  def testGetMounts(self):
1247
    self.assertEqual(utils.GetMounts(filename=self.tmpfile.name),
1248
      [
1249
        ("rootfs", "/", "rootfs", "rw"),
1250
        ("none", "/sys", "sysfs", "rw,nosuid,nodev,noexec,relatime"),
1251
        ("none", "/proc", "proc", "rw,nosuid,nodev,noexec,relatime"),
1252
      ])
1253

    
1254

    
1255
class TestShellQuoting(unittest.TestCase):
1256
  """Test case for shell quoting functions"""
1257

    
1258
  def testShellQuote(self):
1259
    self.assertEqual(ShellQuote('abc'), "abc")
1260
    self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
1261
    self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
1262
    self.assertEqual(ShellQuote("a b c"), "'a b c'")
1263
    self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
1264

    
1265
  def testShellQuoteArgs(self):
1266
    self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
1267
    self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
1268
    self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
1269

    
1270

    
1271
class TestListVisibleFiles(unittest.TestCase):
1272
  """Test case for ListVisibleFiles"""
1273

    
1274
  def setUp(self):
1275
    self.path = tempfile.mkdtemp()
1276

    
1277
  def tearDown(self):
1278
    shutil.rmtree(self.path)
1279

    
1280
  def _CreateFiles(self, files):
1281
    for name in files:
1282
      utils.WriteFile(os.path.join(self.path, name), data="test")
1283

    
1284
  def _test(self, files, expected):
1285
    self._CreateFiles(files)
1286
    found = ListVisibleFiles(self.path)
1287
    self.assertEqual(set(found), set(expected))
1288

    
1289
  def testAllVisible(self):
1290
    files = ["a", "b", "c"]
1291
    expected = files
1292
    self._test(files, expected)
1293

    
1294
  def testNoneVisible(self):
1295
    files = [".a", ".b", ".c"]
1296
    expected = []
1297
    self._test(files, expected)
1298

    
1299
  def testSomeVisible(self):
1300
    files = ["a", "b", ".c"]
1301
    expected = ["a", "b"]
1302
    self._test(files, expected)
1303

    
1304
  def testNonAbsolutePath(self):
1305
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc")
1306

    
1307
  def testNonNormalizedPath(self):
1308
    self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles,
1309
                          "/bin/../tmp")
1310

    
1311

    
1312
class TestNewUUID(unittest.TestCase):
1313
  """Test case for NewUUID"""
1314

    
1315
  def runTest(self):
1316
    self.failUnless(utils.UUID_RE.match(utils.NewUUID()))
1317

    
1318

    
1319
class TestUniqueSequence(unittest.TestCase):
1320
  """Test case for UniqueSequence"""
1321

    
1322
  def _test(self, input, expected):
1323
    self.assertEqual(utils.UniqueSequence(input), expected)
1324

    
1325
  def runTest(self):
1326
    # Ordered input
1327
    self._test([1, 2, 3], [1, 2, 3])
1328
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
1329
    self._test([1, 2, 2, 3], [1, 2, 3])
1330
    self._test([1, 2, 3, 3], [1, 2, 3])
1331

    
1332
    # Unordered input
1333
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
1334
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
1335

    
1336
    # Strings
1337
    self._test(["a", "a"], ["a"])
1338
    self._test(["a", "b"], ["a", "b"])
1339
    self._test(["a", "b", "a"], ["a", "b"])
1340

    
1341

    
1342
class TestFindDuplicates(unittest.TestCase):
1343
  """Test case for FindDuplicates"""
1344

    
1345
  def _Test(self, seq, expected):
1346
    result = utils.FindDuplicates(seq)
1347
    self.assertEqual(result, utils.UniqueSequence(result))
1348
    self.assertEqual(set(result), set(expected))
1349

    
1350
  def test(self):
1351
    self._Test([], [])
1352
    self._Test([1, 2, 3], [])
1353
    self._Test([9, 8, 8, 0, 5, 1, 7, 0, 6, 7], [8, 0, 7])
1354
    for exp in [[1, 2, 3], [3, 2, 1]]:
1355
      self._Test([1, 1, 2, 2, 3, 3], exp)
1356

    
1357
    self._Test(["A", "a", "B"], [])
1358
    self._Test(["a", "A", "a", "B"], ["a"])
1359
    self._Test("Hello World out there!", ["e", " ", "o", "r", "t", "l"])
1360

    
1361
    self._Test(self._Gen(False), [])
1362
    self._Test(self._Gen(True), range(1, 10))
1363

    
1364
  @staticmethod
1365
  def _Gen(dup):
1366
    for i in range(10):
1367
      yield i
1368
      if dup:
1369
        for _ in range(i):
1370
          yield i
1371

    
1372

    
1373
class TestFirstFree(unittest.TestCase):
1374
  """Test case for the FirstFree function"""
1375

    
1376
  def test(self):
1377
    """Test FirstFree"""
1378
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
1379
    self.failUnlessEqual(FirstFree([]), None)
1380
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
1381
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
1382
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
1383

    
1384

    
1385
class TestTailFile(testutils.GanetiTestCase):
1386
  """Test case for the TailFile function"""
1387

    
1388
  def testEmpty(self):
1389
    fname = self._CreateTempFile()
1390
    self.failUnlessEqual(TailFile(fname), [])
1391
    self.failUnlessEqual(TailFile(fname, lines=25), [])
1392

    
1393
  def testAllLines(self):
1394
    data = ["test %d" % i for i in range(30)]
1395
    for i in range(30):
1396
      fname = self._CreateTempFile()
1397
      fd = open(fname, "w")
1398
      fd.write("\n".join(data[:i]))
1399
      if i > 0:
1400
        fd.write("\n")
1401
      fd.close()
1402
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
1403

    
1404
  def testPartialLines(self):
1405
    data = ["test %d" % i for i in range(30)]
1406
    fname = self._CreateTempFile()
1407
    fd = open(fname, "w")
1408
    fd.write("\n".join(data))
1409
    fd.write("\n")
1410
    fd.close()
1411
    for i in range(1, 30):
1412
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1413

    
1414
  def testBigFile(self):
1415
    data = ["test %d" % i for i in range(30)]
1416
    fname = self._CreateTempFile()
1417
    fd = open(fname, "w")
1418
    fd.write("X" * 1048576)
1419
    fd.write("\n")
1420
    fd.write("\n".join(data))
1421
    fd.write("\n")
1422
    fd.close()
1423
    for i in range(1, 30):
1424
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
1425

    
1426

    
1427
class _BaseFileLockTest:
1428
  """Test case for the FileLock class"""
1429

    
1430
  def testSharedNonblocking(self):
1431
    self.lock.Shared(blocking=False)
1432
    self.lock.Close()
1433

    
1434
  def testExclusiveNonblocking(self):
1435
    self.lock.Exclusive(blocking=False)
1436
    self.lock.Close()
1437

    
1438
  def testUnlockNonblocking(self):
1439
    self.lock.Unlock(blocking=False)
1440
    self.lock.Close()
1441

    
1442
  def testSharedBlocking(self):
1443
    self.lock.Shared(blocking=True)
1444
    self.lock.Close()
1445

    
1446
  def testExclusiveBlocking(self):
1447
    self.lock.Exclusive(blocking=True)
1448
    self.lock.Close()
1449

    
1450
  def testUnlockBlocking(self):
1451
    self.lock.Unlock(blocking=True)
1452
    self.lock.Close()
1453

    
1454
  def testSharedExclusiveUnlock(self):
1455
    self.lock.Shared(blocking=False)
1456
    self.lock.Exclusive(blocking=False)
1457
    self.lock.Unlock(blocking=False)
1458
    self.lock.Close()
1459

    
1460
  def testExclusiveSharedUnlock(self):
1461
    self.lock.Exclusive(blocking=False)
1462
    self.lock.Shared(blocking=False)
1463
    self.lock.Unlock(blocking=False)
1464
    self.lock.Close()
1465

    
1466
  def testSimpleTimeout(self):
1467
    # These will succeed on the first attempt, hence a short timeout
1468
    self.lock.Shared(blocking=True, timeout=10.0)
1469
    self.lock.Exclusive(blocking=False, timeout=10.0)
1470
    self.lock.Unlock(blocking=True, timeout=10.0)
1471
    self.lock.Close()
1472

    
1473
  @staticmethod
1474
  def _TryLockInner(filename, shared, blocking):
1475
    lock = utils.FileLock.Open(filename)
1476

    
1477
    if shared:
1478
      fn = lock.Shared
1479
    else:
1480
      fn = lock.Exclusive
1481

    
1482
    try:
1483
      # The timeout doesn't really matter as the parent process waits for us to
1484
      # finish anyway.
1485
      fn(blocking=blocking, timeout=0.01)
1486
    except errors.LockError, err:
1487
      return False
1488

    
1489
    return True
1490

    
1491
  def _TryLock(self, *args):
1492
    return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name,
1493
                                      *args)
1494

    
1495
  def testTimeout(self):
1496
    for blocking in [True, False]:
1497
      self.lock.Exclusive(blocking=True)
1498
      self.failIf(self._TryLock(False, blocking))
1499
      self.failIf(self._TryLock(True, blocking))
1500

    
1501
      self.lock.Shared(blocking=True)
1502
      self.assert_(self._TryLock(True, blocking))
1503
      self.failIf(self._TryLock(False, blocking))
1504

    
1505
  def testCloseShared(self):
1506
    self.lock.Close()
1507
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
1508

    
1509
  def testCloseExclusive(self):
1510
    self.lock.Close()
1511
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
1512

    
1513
  def testCloseUnlock(self):
1514
    self.lock.Close()
1515
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
1516

    
1517

    
1518
class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest):
1519
  TESTDATA = "Hello World\n" * 10
1520

    
1521
  def setUp(self):
1522
    testutils.GanetiTestCase.setUp(self)
1523

    
1524
    self.tmpfile = tempfile.NamedTemporaryFile()
1525
    utils.WriteFile(self.tmpfile.name, data=self.TESTDATA)
1526
    self.lock = utils.FileLock.Open(self.tmpfile.name)
1527

    
1528
    # Ensure "Open" didn't truncate file
1529
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1530

    
1531
  def tearDown(self):
1532
    self.assertFileContent(self.tmpfile.name, self.TESTDATA)
1533

    
1534
    testutils.GanetiTestCase.tearDown(self)
1535

    
1536

    
1537
class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest):
1538
  def setUp(self):
1539
    self.tmpfile = tempfile.NamedTemporaryFile()
1540
    self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name)
1541

    
1542

    
1543
class TestTimeFunctions(unittest.TestCase):
1544
  """Test case for time functions"""
1545

    
1546
  def runTest(self):
1547
    self.assertEqual(utils.SplitTime(1), (1, 0))
1548
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
1549
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
1550
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
1551
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
1552
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
1553
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
1554
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
1555

    
1556
    self.assertRaises(AssertionError, utils.SplitTime, -1)
1557

    
1558
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
1559
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
1560
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
1561

    
1562
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
1563
                     1218448917.481)
1564
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
1565

    
1566
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
1567
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
1568
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
1569
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
1570
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
1571

    
1572

    
1573
class FieldSetTestCase(unittest.TestCase):
1574
  """Test case for FieldSets"""
1575

    
1576
  def testSimpleMatch(self):
1577
    f = utils.FieldSet("a", "b", "c", "def")
1578
    self.failUnless(f.Matches("a"))
1579
    self.failIf(f.Matches("d"), "Substring matched")
1580
    self.failIf(f.Matches("defghi"), "Prefix string matched")
1581
    self.failIf(f.NonMatching(["b", "c"]))
1582
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
1583
    self.failUnless(f.NonMatching(["a", "d"]))
1584

    
1585
  def testRegexMatch(self):
1586
    f = utils.FieldSet("a", "b([0-9]+)", "c")
1587
    self.failUnless(f.Matches("b1"))
1588
    self.failUnless(f.Matches("b99"))
1589
    self.failIf(f.Matches("b/1"))
1590
    self.failIf(f.NonMatching(["b12", "c"]))
1591
    self.failUnless(f.NonMatching(["a", "1"]))
1592

    
1593
class TestForceDictType(unittest.TestCase):
1594
  """Test case for ForceDictType"""
1595
  KEY_TYPES = {
1596
    "a": constants.VTYPE_INT,
1597
    "b": constants.VTYPE_BOOL,
1598
    "c": constants.VTYPE_STRING,
1599
    "d": constants.VTYPE_SIZE,
1600
    "e": constants.VTYPE_MAYBE_STRING,
1601
    }
1602

    
1603
  def _fdt(self, dict, allowed_values=None):
1604
    if allowed_values is None:
1605
      utils.ForceDictType(dict, self.KEY_TYPES)
1606
    else:
1607
      utils.ForceDictType(dict, self.KEY_TYPES, allowed_values=allowed_values)
1608

    
1609
    return dict
1610

    
1611
  def testSimpleDict(self):
1612
    self.assertEqual(self._fdt({}), {})
1613
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
1614
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
1615
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
1616
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
1617
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
1618
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
1619
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
1620
    self.assertEqual(self._fdt({'b': False}), {'b': False})
1621
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
1622
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
1623
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
1624
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
1625
    self.assertEqual(self._fdt({"e": None, }), {"e": None, })
1626
    self.assertEqual(self._fdt({"e": "Hello World", }), {"e": "Hello World", })
1627
    self.assertEqual(self._fdt({"e": False, }), {"e": '', })
1628
    self.assertEqual(self._fdt({"b": "hello", }, ["hello"]), {"b": "hello"})
1629

    
1630
  def testErrors(self):
1631
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
1632
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"b": "hello"})
1633
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
1634
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
1635
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1636
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": object(), })
1637
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": [], })
1638
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {"x": None, })
1639
    self.assertRaises(errors.TypeEnforcementError, self._fdt, [])
1640
    self.assertRaises(errors.ProgrammerError, utils.ForceDictType,
1641
                      {"b": "hello"}, {"b": "no-such-type"})
1642

    
1643

    
1644
class TestIsNormAbsPath(unittest.TestCase):
1645
  """Testing case for IsNormAbsPath"""
1646

    
1647
  def _pathTestHelper(self, path, result):
1648
    if result:
1649
      self.assert_(utils.IsNormAbsPath(path),
1650
          "Path %s should result absolute and normalized" % path)
1651
    else:
1652
      self.assertFalse(utils.IsNormAbsPath(path),
1653
          "Path %s should not result absolute and normalized" % path)
1654

    
1655
  def testBase(self):
1656
    self._pathTestHelper('/etc', True)
1657
    self._pathTestHelper('/srv', True)
1658
    self._pathTestHelper('etc', False)
1659
    self._pathTestHelper('/etc/../root', False)
1660
    self._pathTestHelper('/etc/', False)
1661

    
1662

    
1663
class TestSafeEncode(unittest.TestCase):
1664
  """Test case for SafeEncode"""
1665

    
1666
  def testAscii(self):
1667
    for txt in [string.digits, string.letters, string.punctuation]:
1668
      self.failUnlessEqual(txt, SafeEncode(txt))
1669

    
1670
  def testDoubleEncode(self):
1671
    for i in range(255):
1672
      txt = SafeEncode(chr(i))
1673
      self.failUnlessEqual(txt, SafeEncode(txt))
1674

    
1675
  def testUnicode(self):
1676
    # 1024 is high enough to catch non-direct ASCII mappings
1677
    for i in range(1024):
1678
      txt = SafeEncode(unichr(i))
1679
      self.failUnlessEqual(txt, SafeEncode(txt))
1680

    
1681

    
1682
class TestFormatTime(unittest.TestCase):
1683
  """Testing case for FormatTime"""
1684

    
1685
  @staticmethod
1686
  def _TestInProcess(tz, timestamp, expected):
1687
    os.environ["TZ"] = tz
1688
    time.tzset()
1689
    return utils.FormatTime(timestamp) == expected
1690

    
1691
  def _Test(self, *args):
1692
    # Need to use separate process as we want to change TZ
1693
    self.assert_(utils.RunInSeparateProcess(self._TestInProcess, *args))
1694

    
1695
  def test(self):
1696
    self._Test("UTC", 0, "1970-01-01 00:00:00")
1697
    self._Test("America/Sao_Paulo", 1292606926, "2010-12-17 15:28:46")
1698
    self._Test("Europe/London", 1292606926, "2010-12-17 17:28:46")
1699
    self._Test("Europe/Zurich", 1292606926, "2010-12-17 18:28:46")
1700
    self._Test("Australia/Sydney", 1292606926, "2010-12-18 04:28:46")
1701

    
1702
  def testNone(self):
1703
    self.failUnlessEqual(FormatTime(None), "N/A")
1704

    
1705
  def testInvalid(self):
1706
    self.failUnlessEqual(FormatTime(()), "N/A")
1707

    
1708
  def testNow(self):
1709
    # tests that we accept time.time input
1710
    FormatTime(time.time())
1711
    # tests that we accept int input
1712
    FormatTime(int(time.time()))
1713

    
1714

    
1715
class TestFormatTimestampWithTZ(unittest.TestCase):
1716
  @staticmethod
1717
  def _TestInProcess(tz, timestamp, expected):
1718
    os.environ["TZ"] = tz
1719
    time.tzset()
1720
    return utils.FormatTimestampWithTZ(timestamp) == expected
1721

    
1722
  def _Test(self, *args):
1723
    # Need to use separate process as we want to change TZ
1724
    self.assert_(utils.RunInSeparateProcess(self._TestInProcess, *args))
1725

    
1726
  def test(self):
1727
    self._Test("UTC", 0, "1970-01-01 00:00:00 UTC")
1728
    self._Test("America/Sao_Paulo", 1292606926, "2010-12-17 15:28:46 BRST")
1729
    self._Test("Europe/London", 1292606926, "2010-12-17 17:28:46 GMT")
1730
    self._Test("Europe/Zurich", 1292606926, "2010-12-17 18:28:46 CET")
1731
    self._Test("Australia/Sydney", 1292606926, "2010-12-18 04:28:46 EST")
1732

    
1733

    
1734
class RunInSeparateProcess(unittest.TestCase):
1735
  def test(self):
1736
    for exp in [True, False]:
1737
      def _child():
1738
        return exp
1739

    
1740
      self.assertEqual(exp, utils.RunInSeparateProcess(_child))
1741

    
1742
  def testArgs(self):
1743
    for arg in [0, 1, 999, "Hello World", (1, 2, 3)]:
1744
      def _child(carg1, carg2):
1745
        return carg1 == "Foo" and carg2 == arg
1746

    
1747
      self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg))
1748

    
1749
  def testPid(self):
1750
    parent_pid = os.getpid()
1751

    
1752
    def _check():
1753
      return os.getpid() == parent_pid
1754

    
1755
    self.failIf(utils.RunInSeparateProcess(_check))
1756

    
1757
  def testSignal(self):
1758
    def _kill():
1759
      os.kill(os.getpid(), signal.SIGTERM)
1760

    
1761
    self.assertRaises(errors.GenericError,
1762
                      utils.RunInSeparateProcess, _kill)
1763

    
1764
  def testException(self):
1765
    def _exc():
1766
      raise errors.GenericError("This is a test")
1767

    
1768
    self.assertRaises(errors.GenericError,
1769
                      utils.RunInSeparateProcess, _exc)
1770

    
1771

    
1772
class TestFingerprintFiles(unittest.TestCase):
1773
  def setUp(self):
1774
    self.tmpfile = tempfile.NamedTemporaryFile()
1775
    self.tmpfile2 = tempfile.NamedTemporaryFile()
1776
    utils.WriteFile(self.tmpfile2.name, data="Hello World\n")
1777
    self.results = {
1778
      self.tmpfile.name: "da39a3ee5e6b4b0d3255bfef95601890afd80709",
1779
      self.tmpfile2.name: "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a",
1780
      }
1781

    
1782
  def testSingleFile(self):
1783
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1784
                     self.results[self.tmpfile.name])
1785

    
1786
    self.assertEqual(utils._FingerprintFile("/no/such/file"), None)
1787

    
1788
  def testBigFile(self):
1789
    self.tmpfile.write("A" * 8192)
1790
    self.tmpfile.flush()
1791
    self.assertEqual(utils._FingerprintFile(self.tmpfile.name),
1792
                     "35b6795ca20d6dc0aff8c7c110c96cd1070b8c38")
1793

    
1794
  def testMultiple(self):
1795
    all_files = self.results.keys()
1796
    all_files.append("/no/such/file")
1797
    self.assertEqual(utils.FingerprintFiles(self.results.keys()), self.results)
1798

    
1799

    
1800
class TestUnescapeAndSplit(unittest.TestCase):
1801
  """Testing case for UnescapeAndSplit"""
1802

    
1803
  def setUp(self):
1804
    # testing more that one separator for regexp safety
1805
    self._seps = [",", "+", "."]
1806

    
1807
  def testSimple(self):
1808
    a = ["a", "b", "c", "d"]
1809
    for sep in self._seps:
1810
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a)
1811

    
1812
  def testEscape(self):
1813
    for sep in self._seps:
1814
      a = ["a", "b\\" + sep + "c", "d"]
1815
      b = ["a", "b" + sep + "c", "d"]
1816
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1817

    
1818
  def testDoubleEscape(self):
1819
    for sep in self._seps:
1820
      a = ["a", "b\\\\", "c", "d"]
1821
      b = ["a", "b\\", "c", "d"]
1822
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1823

    
1824
  def testThreeEscape(self):
1825
    for sep in self._seps:
1826
      a = ["a", "b\\\\\\" + sep + "c", "d"]
1827
      b = ["a", "b\\" + sep + "c", "d"]
1828
      self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b)
1829

    
1830

    
1831
class TestGenerateSelfSignedX509Cert(unittest.TestCase):
1832
  def setUp(self):
1833
    self.tmpdir = tempfile.mkdtemp()
1834

    
1835
  def tearDown(self):
1836
    shutil.rmtree(self.tmpdir)
1837

    
1838
  def _checkRsaPrivateKey(self, key):
1839
    lines = key.splitlines()
1840
    return ("-----BEGIN RSA PRIVATE KEY-----" in lines and
1841
            "-----END RSA PRIVATE KEY-----" in lines)
1842

    
1843
  def _checkCertificate(self, cert):
1844
    lines = cert.splitlines()
1845
    return ("-----BEGIN CERTIFICATE-----" in lines and
1846
            "-----END CERTIFICATE-----" in lines)
1847

    
1848
  def test(self):
1849
    for common_name in [None, ".", "Ganeti", "node1.example.com"]:
1850
      (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300)
1851
      self._checkRsaPrivateKey(key_pem)
1852
      self._checkCertificate(cert_pem)
1853

    
1854
      key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM,
1855
                                           key_pem)
1856
      self.assert_(key.bits() >= 1024)
1857
      self.assertEqual(key.bits(), constants.RSA_KEY_BITS)
1858
      self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA)
1859

    
1860
      x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1861
                                             cert_pem)
1862
      self.failIf(x509.has_expired())
1863
      self.assertEqual(x509.get_issuer().CN, common_name)
1864
      self.assertEqual(x509.get_subject().CN, common_name)
1865
      self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS)
1866

    
1867
  def testLegacy(self):
1868
    cert1_filename = os.path.join(self.tmpdir, "cert1.pem")
1869

    
1870
    utils.GenerateSelfSignedSslCert(cert1_filename, validity=1)
1871

    
1872
    cert1 = utils.ReadFile(cert1_filename)
1873

    
1874
    self.assert_(self._checkRsaPrivateKey(cert1))
1875
    self.assert_(self._checkCertificate(cert1))
1876

    
1877

    
1878
class TestPathJoin(unittest.TestCase):
1879
  """Testing case for PathJoin"""
1880

    
1881
  def testBasicItems(self):
1882
    mlist = ["/a", "b", "c"]
1883
    self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist))
1884

    
1885
  def testNonAbsPrefix(self):
1886
    self.failUnlessRaises(ValueError, PathJoin, "a", "b")
1887

    
1888
  def testBackTrack(self):
1889
    self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c")
1890

    
1891
  def testMultiAbs(self):
1892
    self.failUnlessRaises(ValueError, PathJoin, "/a", "/b")
1893

    
1894

    
1895
class TestValidateServiceName(unittest.TestCase):
1896
  def testValid(self):
1897
    testnames = [
1898
      0, 1, 2, 3, 1024, 65000, 65534, 65535,
1899
      "ganeti",
1900
      "gnt-masterd",
1901
      "HELLO_WORLD_SVC",
1902
      "hello.world.1",
1903
      "0", "80", "1111", "65535",
1904
      ]
1905

    
1906
    for name in testnames:
1907
      self.assertEqual(utils.ValidateServiceName(name), name)
1908

    
1909
  def testInvalid(self):
1910
    testnames = [
1911
      -15756, -1, 65536, 133428083,
1912
      "", "Hello World!", "!", "'", "\"", "\t", "\n", "`",
1913
      "-8546", "-1", "65536",
1914
      (129 * "A"),
1915
      ]
1916

    
1917
    for name in testnames:
1918
      self.assertRaises(errors.OpPrereqError, utils.ValidateServiceName, name)
1919

    
1920

    
1921
class TestParseAsn1Generalizedtime(unittest.TestCase):
1922
  def test(self):
1923
    # UTC
1924
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0)
1925
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"),
1926
                     1266860512)
1927
    self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"),
1928
                     (2**31) - 1)
1929

    
1930
    # With offset
1931
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"),
1932
                     1266860512)
1933
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"),
1934
                     1266931012)
1935
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"),
1936
                     1266931088)
1937
    self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"),
1938
                     1266931295)
1939
    self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"),
1940
                     3600)
1941

    
1942
    # Leap seconds are not supported by datetime.datetime
1943
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1944
                      "19841231235960+0000")
1945
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1946
                      "19920630235960+0000")
1947

    
1948
    # Errors
1949
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "")
1950
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid")
1951
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1952
                      "20100222174152")
1953
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1954
                      "Mon Feb 22 17:47:02 UTC 2010")
1955
    self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime,
1956
                      "2010-02-22 17:42:02")
1957

    
1958

    
1959
class TestGetX509CertValidity(testutils.GanetiTestCase):
1960
  def setUp(self):
1961
    testutils.GanetiTestCase.setUp(self)
1962

    
1963
    pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__)
1964

    
1965
    # Test whether we have pyOpenSSL 0.7 or above
1966
    self.pyopenssl0_7 = (pyopenssl_version >= "0.7")
1967

    
1968
    if not self.pyopenssl0_7:
1969
      warnings.warn("This test requires pyOpenSSL 0.7 or above to"
1970
                    " function correctly")
1971

    
1972
  def _LoadCert(self, name):
1973
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1974
                                           self._ReadTestData(name))
1975

    
1976
  def test(self):
1977
    validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem"))
1978
    if self.pyopenssl0_7:
1979
      self.assertEqual(validity, (1266919967, 1267524767))
1980
    else:
1981
      self.assertEqual(validity, (None, None))
1982

    
1983

    
1984
class TestSignX509Certificate(unittest.TestCase):
1985
  KEY = "My private key!"
1986
  KEY_OTHER = "Another key"
1987

    
1988
  def test(self):
1989
    # Generate certificate valid for 5 minutes
1990
    (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300)
1991

    
1992
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
1993
                                           cert_pem)
1994

    
1995
    # No signature at all
1996
    self.assertRaises(errors.GenericError,
1997
                      utils.LoadSignedX509Certificate, cert_pem, self.KEY)
1998

    
1999
    # Invalid input
2000
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2001
                      "", self.KEY)
2002
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2003
                      "X-Ganeti-Signature: \n", self.KEY)
2004
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2005
                      "X-Ganeti-Sign: $1234$abcdef\n", self.KEY)
2006
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2007
                      "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY)
2008
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2009
                      "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY)
2010

    
2011
    # Invalid salt
2012
    for salt in list("-_@$,:;/\\ \t\n"):
2013
      self.assertRaises(errors.GenericError, utils.SignX509Certificate,
2014
                        cert_pem, self.KEY, "foo%sbar" % salt)
2015

    
2016
    for salt in ["HelloWorld", "salt", string.letters, string.digits,
2017
                 utils.GenerateSecret(numbytes=4),
2018
                 utils.GenerateSecret(numbytes=16),
2019
                 "{123:456}".encode("hex")]:
2020
      signed_pem = utils.SignX509Certificate(cert, self.KEY, salt)
2021

    
2022
      self._Check(cert, salt, signed_pem)
2023

    
2024
      self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem)
2025
      self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem)
2026
      self._Check(cert, salt, (signed_pem + "\n\na few more\n"
2027
                               "lines----\n------ at\nthe end!"))
2028

    
2029
  def _Check(self, cert, salt, pem):
2030
    (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY)
2031
    self.assertEqual(salt, salt2)
2032
    self.assertEqual(cert.digest("sha1"), cert2.digest("sha1"))
2033

    
2034
    # Other key
2035
    self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate,
2036
                      pem, self.KEY_OTHER)
2037

    
2038

    
2039
class TestMakedirs(unittest.TestCase):
2040
  def setUp(self):
2041
    self.tmpdir = tempfile.mkdtemp()
2042

    
2043
  def tearDown(self):
2044
    shutil.rmtree(self.tmpdir)
2045

    
2046
  def testNonExisting(self):
2047
    path = PathJoin(self.tmpdir, "foo")
2048
    utils.Makedirs(path)
2049
    self.assert_(os.path.isdir(path))
2050

    
2051
  def testExisting(self):
2052
    path = PathJoin(self.tmpdir, "foo")
2053
    os.mkdir(path)
2054
    utils.Makedirs(path)
2055
    self.assert_(os.path.isdir(path))
2056

    
2057
  def testRecursiveNonExisting(self):
2058
    path = PathJoin(self.tmpdir, "foo/bar/baz")
2059
    utils.Makedirs(path)
2060
    self.assert_(os.path.isdir(path))
2061

    
2062
  def testRecursiveExisting(self):
2063
    path = PathJoin(self.tmpdir, "B/moo/xyz")
2064
    self.assertFalse(os.path.exists(path))
2065
    os.mkdir(PathJoin(self.tmpdir, "B"))
2066
    utils.Makedirs(path)
2067
    self.assert_(os.path.isdir(path))
2068

    
2069

    
2070
class TestRetry(testutils.GanetiTestCase):
2071
  def setUp(self):
2072
    testutils.GanetiTestCase.setUp(self)
2073
    self.retries = 0
2074

    
2075
  @staticmethod
2076
  def _RaiseRetryAgain():
2077
    raise utils.RetryAgain()
2078

    
2079
  @staticmethod
2080
  def _RaiseRetryAgainWithArg(args):
2081
    raise utils.RetryAgain(*args)
2082

    
2083
  def _WrongNestedLoop(self):
2084
    return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
2085

    
2086
  def _RetryAndSucceed(self, retries):
2087
    if self.retries < retries:
2088
      self.retries += 1
2089
      raise utils.RetryAgain()
2090
    else:
2091
      return True
2092

    
2093
  def testRaiseTimeout(self):
2094
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
2095
                          self._RaiseRetryAgain, 0.01, 0.02)
2096
    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
2097
                          self._RetryAndSucceed, 0.01, 0, args=[1])
2098
    self.failUnlessEqual(self.retries, 1)
2099

    
2100
  def testComplete(self):
2101
    self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
2102
    self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
2103
                         True)
2104
    self.failUnlessEqual(self.retries, 2)
2105

    
2106
  def testNestedLoop(self):
2107
    try:
2108
      self.failUnlessRaises(errors.ProgrammerError, utils.Retry,
2109
                            self._WrongNestedLoop, 0, 1)
2110
    except utils.RetryTimeout:
2111
      self.fail("Didn't detect inner loop's exception")
2112

    
2113
  def testTimeoutArgument(self):
2114
    retry_arg="my_important_debugging_message"
2115
    try:
2116
      utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]])
2117
    except utils.RetryTimeout, err:
2118
      self.failUnlessEqual(err.args, (retry_arg, ))
2119
    else:
2120
      self.fail("Expected timeout didn't happen")
2121

    
2122
  def testRaiseInnerWithExc(self):
2123
    retry_arg="my_important_debugging_message"
2124
    try:
2125
      try:
2126
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2127
                    args=[[errors.GenericError(retry_arg, retry_arg)]])
2128
      except utils.RetryTimeout, err:
2129
        err.RaiseInner()
2130
      else:
2131
        self.fail("Expected timeout didn't happen")
2132
    except errors.GenericError, err:
2133
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2134
    else:
2135
      self.fail("Expected GenericError didn't happen")
2136

    
2137
  def testRaiseInnerWithMsg(self):
2138
    retry_arg="my_important_debugging_message"
2139
    try:
2140
      try:
2141
        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
2142
                    args=[[retry_arg, retry_arg]])
2143
      except utils.RetryTimeout, err:
2144
        err.RaiseInner()
2145
      else:
2146
        self.fail("Expected timeout didn't happen")
2147
    except utils.RetryTimeout, err:
2148
      self.failUnlessEqual(err.args, (retry_arg, retry_arg))
2149
    else:
2150
      self.fail("Expected RetryTimeout didn't happen")
2151

    
2152

    
2153
class TestLineSplitter(unittest.TestCase):
2154
  def test(self):
2155
    lines = []
2156
    ls = utils.LineSplitter(lines.append)
2157
    ls.write("Hello World\n")
2158
    self.assertEqual(lines, [])
2159
    ls.write("Foo\n Bar\r\n ")
2160
    ls.write("Baz")
2161
    ls.write("Moo")
2162
    self.assertEqual(lines, [])
2163
    ls.flush()
2164
    self.assertEqual(lines, ["Hello World", "Foo", " Bar"])
2165
    ls.close()
2166
    self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"])
2167

    
2168
  def _testExtra(self, line, all_lines, p1, p2):
2169
    self.assertEqual(p1, 999)
2170
    self.assertEqual(p2, "extra")
2171
    all_lines.append(line)
2172

    
2173
  def testExtraArgsNoFlush(self):
2174
    lines = []
2175
    ls = utils.LineSplitter(self._testExtra, lines, 999, "extra")
2176
    ls.write("\n\nHello World\n")
2177
    ls.write("Foo\n Bar\r\n ")
2178
    ls.write("")
2179
    ls.write("Baz")
2180
    ls.write("Moo\n\nx\n")
2181
    self.assertEqual(lines, [])
2182
    ls.close()
2183
    self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo",
2184
                             "", "x"])
2185

    
2186

    
2187
class TestReadLockedPidFile(unittest.TestCase):
2188
  def setUp(self):
2189
    self.tmpdir = tempfile.mkdtemp()
2190

    
2191
  def tearDown(self):
2192
    shutil.rmtree(self.tmpdir)
2193

    
2194
  def testNonExistent(self):
2195
    path = PathJoin(self.tmpdir, "nonexist")
2196
    self.assert_(utils.ReadLockedPidFile(path) is None)
2197

    
2198
  def testUnlocked(self):
2199
    path = PathJoin(self.tmpdir, "pid")
2200
    utils.WriteFile(path, data="123")
2201
    self.assert_(utils.ReadLockedPidFile(path) is None)
2202

    
2203
  def testLocked(self):
2204
    path = PathJoin(self.tmpdir, "pid")
2205
    utils.WriteFile(path, data="123")
2206

    
2207
    fl = utils.FileLock.Open(path)
2208
    try:
2209
      fl.Exclusive(blocking=True)
2210

    
2211
      self.assertEqual(utils.ReadLockedPidFile(path), 123)
2212
    finally:
2213
      fl.Close()
2214

    
2215
    self.assert_(utils.ReadLockedPidFile(path) is None)
2216

    
2217
  def testError(self):
2218
    path = PathJoin(self.tmpdir, "foobar", "pid")
2219
    utils.WriteFile(PathJoin(self.tmpdir, "foobar"), data="")
2220
    # open(2) should return ENOTDIR
2221
    self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
2222

    
2223

    
2224
class TestCertVerification(testutils.GanetiTestCase):
2225
  def setUp(self):
2226
    testutils.GanetiTestCase.setUp(self)
2227

    
2228
    self.tmpdir = tempfile.mkdtemp()
2229

    
2230
  def tearDown(self):
2231
    shutil.rmtree(self.tmpdir)
2232

    
2233
  def testVerifyCertificate(self):
2234
    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
2235
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
2236
                                           cert_pem)
2237

    
2238
    # Not checking return value as this certificate is expired
2239
    utils.VerifyX509Certificate(cert, 30, 7)
2240

    
2241

    
2242
class TestVerifyCertificateInner(unittest.TestCase):
2243
  def test(self):
2244
    vci = utils._VerifyCertificateInner
2245

    
2246
    # Valid
2247
    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
2248
                     (None, None))
2249

    
2250
    # Not yet valid
2251
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
2252
    self.assertEqual(errcode, utils.CERT_WARNING)
2253

    
2254
    # Expiring soon
2255
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
2256
    self.assertEqual(errcode, utils.CERT_ERROR)
2257

    
2258
    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
2259
    self.assertEqual(errcode, utils.CERT_WARNING)
2260

    
2261
    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
2262
    self.assertEqual(errcode, None)
2263

    
2264
    # Expired
2265
    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
2266
    self.assertEqual(errcode, utils.CERT_ERROR)
2267

    
2268
    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
2269
    self.assertEqual(errcode, utils.CERT_ERROR)
2270

    
2271
    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
2272
    self.assertEqual(errcode, utils.CERT_ERROR)
2273

    
2274
    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
2275
    self.assertEqual(errcode, utils.CERT_ERROR)
2276

    
2277

    
2278
class TestHmacFunctions(unittest.TestCase):
2279
  # Digests can be checked with "openssl sha1 -hmac $key"
2280
  def testSha1Hmac(self):
2281
    self.assertEqual(utils.Sha1Hmac("", ""),
2282
                     "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d")
2283
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World"),
2284
                     "ef4f3bda82212ecb2f7ce868888a19092481f1fd")
2285
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", ""),
2286
                     "f904c2476527c6d3e6609ab683c66fa0652cb1dc")
2287

    
2288
    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
2289
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", longtext),
2290
                     "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54")
2291

    
2292
  def testSha1HmacSalt(self):
2293
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc0"),
2294
                     "4999bf342470eadb11dfcd24ca5680cf9fd7cdce")
2295
    self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc9"),
2296
                     "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8")
2297
    self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World", salt="xyz0"),
2298
                     "7f264f8114c9066afc9bb7636e1786d996d3cc0d")
2299

    
2300
  def testVerifySha1Hmac(self):
2301
    self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b"
2302
                                               "7d64b71fb76370690e1d")))
2303
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2304
                                      ("f904c2476527c6d3e660"
2305
                                       "9ab683c66fa0652cb1dc")))
2306

    
2307
    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
2308
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", digest))
2309
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2310
                                      digest.lower()))
2311
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2312
                                      digest.upper()))
2313
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2314
                                      digest.title()))
2315

    
2316
  def testVerifySha1HmacSalt(self):
2317
    self.assert_(utils.VerifySha1Hmac("TguMTA2K", "",
2318
                                      ("17a4adc34d69c0d367d4"
2319
                                       "ffbef96fd41d4df7a6e8"),
2320
                                      salt="abc9"))
2321
    self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World",
2322
                                      ("7f264f8114c9066afc9b"
2323
                                       "b7636e1786d996d3cc0d"),
2324
                                      salt="xyz0"))
2325

    
2326

    
2327
class TestIgnoreSignals(unittest.TestCase):
2328
  """Test the IgnoreSignals decorator"""
2329

    
2330
  @staticmethod
2331
  def _Raise(exception):
2332
    raise exception
2333

    
2334
  @staticmethod
2335
  def _Return(rval):
2336
    return rval
2337

    
2338
  def testIgnoreSignals(self):
2339
    sock_err_intr = socket.error(errno.EINTR, "Message")
2340
    sock_err_inval = socket.error(errno.EINVAL, "Message")
2341

    
2342
    env_err_intr = EnvironmentError(errno.EINTR, "Message")
2343
    env_err_inval = EnvironmentError(errno.EINVAL, "Message")
2344

    
2345
    self.assertRaises(socket.error, self._Raise, sock_err_intr)
2346
    self.assertRaises(socket.error, self._Raise, sock_err_inval)
2347
    self.assertRaises(EnvironmentError, self._Raise, env_err_intr)
2348
    self.assertRaises(EnvironmentError, self._Raise, env_err_inval)
2349

    
2350
    self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None)
2351
    self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None)
2352
    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
2353
                      sock_err_inval)
2354
    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
2355
                      env_err_inval)
2356

    
2357
    self.assertEquals(utils.IgnoreSignals(self._Return, True), True)
2358
    self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33)
2359

    
2360

    
2361
class TestEnsureDirs(unittest.TestCase):
2362
  """Tests for EnsureDirs"""
2363

    
2364
  def setUp(self):
2365
    self.dir = tempfile.mkdtemp()
2366
    self.old_umask = os.umask(0777)
2367

    
2368
  def testEnsureDirs(self):
2369
    utils.EnsureDirs([
2370
        (PathJoin(self.dir, "foo"), 0777),
2371
        (PathJoin(self.dir, "bar"), 0000),
2372
        ])
2373
    self.assertEquals(os.stat(PathJoin(self.dir, "foo"))[0] & 0777, 0777)
2374
    self.assertEquals(os.stat(PathJoin(self.dir, "bar"))[0] & 0777, 0000)
2375

    
2376
  def tearDown(self):
2377
    os.rmdir(PathJoin(self.dir, "foo"))
2378
    os.rmdir(PathJoin(self.dir, "bar"))
2379
    os.rmdir(self.dir)
2380
    os.umask(self.old_umask)
2381

    
2382

    
2383
class TestFormatSeconds(unittest.TestCase):
2384
  def test(self):
2385
    self.assertEqual(utils.FormatSeconds(1), "1s")
2386
    self.assertEqual(utils.FormatSeconds(3600), "1h 0m 0s")
2387
    self.assertEqual(utils.FormatSeconds(3599), "59m 59s")
2388
    self.assertEqual(utils.FormatSeconds(7200), "2h 0m 0s")
2389
    self.assertEqual(utils.FormatSeconds(7201), "2h 0m 1s")
2390
    self.assertEqual(utils.FormatSeconds(7281), "2h 1m 21s")
2391
    self.assertEqual(utils.FormatSeconds(29119), "8h 5m 19s")
2392
    self.assertEqual(utils.FormatSeconds(19431228), "224d 21h 33m 48s")
2393
    self.assertEqual(utils.FormatSeconds(-1), "-1s")
2394
    self.assertEqual(utils.FormatSeconds(-282), "-282s")
2395
    self.assertEqual(utils.FormatSeconds(-29119), "-29119s")
2396

    
2397
  def testFloat(self):
2398
    self.assertEqual(utils.FormatSeconds(1.3), "1s")
2399
    self.assertEqual(utils.FormatSeconds(1.9), "2s")
2400
    self.assertEqual(utils.FormatSeconds(3912.12311), "1h 5m 12s")
2401
    self.assertEqual(utils.FormatSeconds(3912.8), "1h 5m 13s")
2402

    
2403

    
2404
class TestIgnoreProcessNotFound(unittest.TestCase):
2405
  @staticmethod
2406
  def _WritePid(fd):
2407
    os.write(fd, str(os.getpid()))
2408
    os.close(fd)
2409
    return True
2410

    
2411
  def test(self):
2412
    (pid_read_fd, pid_write_fd) = os.pipe()
2413

    
2414
    # Start short-lived process which writes its PID to pipe
2415
    self.assert_(utils.RunInSeparateProcess(self._WritePid, pid_write_fd))
2416
    os.close(pid_write_fd)
2417

    
2418
    # Read PID from pipe
2419
    pid = int(os.read(pid_read_fd, 1024))
2420
    os.close(pid_read_fd)
2421

    
2422
    # Try to send signal to process which exited recently
2423
    self.assertFalse(utils.IgnoreProcessNotFound(os.kill, pid, 0))
2424

    
2425

    
2426
class TestShellWriter(unittest.TestCase):
2427
  def test(self):
2428
    buf = StringIO()
2429
    sw = utils.ShellWriter(buf)
2430
    sw.Write("#!/bin/bash")
2431
    sw.Write("if true; then")
2432
    sw.IncIndent()
2433
    try:
2434
      sw.Write("echo true")
2435

    
2436
      sw.Write("for i in 1 2 3")
2437
      sw.Write("do")
2438
      sw.IncIndent()
2439
      try:
2440
        self.assertEqual(sw._indent, 2)
2441
        sw.Write("date")
2442
      finally:
2443
        sw.DecIndent()
2444
      sw.Write("done")
2445
    finally:
2446
      sw.DecIndent()
2447
    sw.Write("echo %s", utils.ShellQuote("Hello World"))
2448
    sw.Write("exit 0")
2449

    
2450
    self.assertEqual(sw._indent, 0)
2451

    
2452
    output = buf.getvalue()
2453

    
2454
    self.assert_(output.endswith("\n"))
2455

    
2456
    lines = output.splitlines()
2457
    self.assertEqual(len(lines), 9)
2458
    self.assertEqual(lines[0], "#!/bin/bash")
2459
    self.assert_(re.match(r"^\s+date$", lines[5]))
2460
    self.assertEqual(lines[7], "echo 'Hello World'")
2461

    
2462
  def testEmpty(self):
2463
    buf = StringIO()
2464
    sw = utils.ShellWriter(buf)
2465
    sw = None
2466
    self.assertEqual(buf.getvalue(), "")
2467

    
2468

    
2469
class TestCommaJoin(unittest.TestCase):
2470
  def test(self):
2471
    self.assertEqual(utils.CommaJoin([]), "")
2472
    self.assertEqual(utils.CommaJoin([1, 2, 3]), "1, 2, 3")
2473
    self.assertEqual(utils.CommaJoin(["Hello"]), "Hello")
2474
    self.assertEqual(utils.CommaJoin(["Hello", "World"]), "Hello, World")
2475
    self.assertEqual(utils.CommaJoin(["Hello", "World", 99]),
2476
                     "Hello, World, 99")
2477

    
2478

    
2479
class TestFindMatch(unittest.TestCase):
2480
  def test(self):
2481
    data = {
2482
      "aaaa": "Four A",
2483
      "bb": {"Two B": True},
2484
      re.compile(r"^x(foo|bar|bazX)([0-9]+)$"): (1, 2, 3),
2485
      }
2486

    
2487
    self.assertEqual(utils.FindMatch(data, "aaaa"), ("Four A", []))
2488
    self.assertEqual(utils.FindMatch(data, "bb"), ({"Two B": True}, []))
2489

    
2490
    for i in ["foo", "bar", "bazX"]:
2491
      for j in range(1, 100, 7):
2492
        self.assertEqual(utils.FindMatch(data, "x%s%s" % (i, j)),
2493
                         ((1, 2, 3), [i, str(j)]))
2494

    
2495
  def testNoMatch(self):
2496
    self.assert_(utils.FindMatch({}, "") is None)
2497
    self.assert_(utils.FindMatch({}, "foo") is None)
2498
    self.assert_(utils.FindMatch({}, 1234) is None)
2499

    
2500
    data = {
2501
      "X": "Hello World",
2502
      re.compile("^(something)$"): "Hello World",
2503
      }
2504

    
2505
    self.assert_(utils.FindMatch(data, "") is None)
2506
    self.assert_(utils.FindMatch(data, "Hello World") is None)
2507

    
2508

    
2509
class TestFileID(testutils.GanetiTestCase):
2510
  def testEquality(self):
2511
    name = self._CreateTempFile()
2512
    oldi = utils.GetFileID(path=name)
2513
    self.failUnless(utils.VerifyFileID(oldi, oldi))
2514

    
2515
  def testUpdate(self):
2516
    name = self._CreateTempFile()
2517
    oldi = utils.GetFileID(path=name)
2518
    os.utime(name, None)
2519
    fd = os.open(name, os.O_RDWR)
2520
    try:
2521
      newi = utils.GetFileID(fd=fd)
2522
      self.failUnless(utils.VerifyFileID(oldi, newi))
2523
      self.failUnless(utils.VerifyFileID(newi, oldi))
2524
    finally:
2525
      os.close(fd)
2526

    
2527
  def testWriteFile(self):
2528
    name = self._CreateTempFile()
2529
    oldi = utils.GetFileID(path=name)
2530
    mtime = oldi[2]
2531
    os.utime(name, (mtime + 10, mtime + 10))
2532
    self.assertRaises(errors.LockError, utils.SafeWriteFile, name,
2533
                      oldi, data="")
2534
    os.utime(name, (mtime - 10, mtime - 10))
2535
    utils.SafeWriteFile(name, oldi, data="")
2536
    oldi = utils.GetFileID(path=name)
2537
    mtime = oldi[2]
2538
    os.utime(name, (mtime + 10, mtime + 10))
2539
    # this doesn't raise, since we passed None
2540
    utils.SafeWriteFile(name, None, data="")
2541

    
2542
  def testError(self):
2543
    t = tempfile.NamedTemporaryFile()
2544
    self.assertRaises(errors.ProgrammerError, utils.GetFileID,
2545
                      path=t.name, fd=t.fileno())
2546

    
2547

    
2548
class TimeMock:
2549
  def __init__(self, values):
2550
    self.values = values
2551

    
2552
  def __call__(self):
2553
    return self.values.pop(0)
2554

    
2555

    
2556
class TestRunningTimeout(unittest.TestCase):
2557
  def setUp(self):
2558
    self.time_fn = TimeMock([0.0, 0.3, 4.6, 6.5])
2559

    
2560
  def testRemainingFloat(self):
2561
    timeout = utils.RunningTimeout(5.0, True, _time_fn=self.time_fn)
2562
    self.assertAlmostEqual(timeout.Remaining(), 4.7)
2563
    self.assertAlmostEqual(timeout.Remaining(), 0.4)
2564
    self.assertAlmostEqual(timeout.Remaining(), -1.5)
2565

    
2566
  def testRemaining(self):
2567
    self.time_fn = TimeMock([0, 2, 4, 5, 6])
2568
    timeout = utils.RunningTimeout(5, True, _time_fn=self.time_fn)
2569
    self.assertEqual(timeout.Remaining(), 3)
2570
    self.assertEqual(timeout.Remaining(), 1)
2571
    self.assertEqual(timeout.Remaining(), 0)
2572
    self.assertEqual(timeout.Remaining(), -1)
2573

    
2574
  def testRemainingNonNegative(self):
2575
    timeout = utils.RunningTimeout(5.0, False, _time_fn=self.time_fn)
2576
    self.assertAlmostEqual(timeout.Remaining(), 4.7)
2577
    self.assertAlmostEqual(timeout.Remaining(), 0.4)
2578
    self.assertEqual(timeout.Remaining(), 0.0)
2579

    
2580
  def testNegativeTimeout(self):
2581
    self.assertRaises(ValueError, utils.RunningTimeout, -1.0, True)
2582

    
2583

    
2584
class TestTryConvert(unittest.TestCase):
2585
  def test(self):
2586
    for src, fn, result in [
2587
      ("1", int, 1),
2588
      ("a", int, "a"),
2589
      ("", bool, False),
2590
      ("a", bool, True),
2591
      ]:
2592
      self.assertEqual(utils.TryConvert(fn, src), result)
2593

    
2594

    
2595
class TestIsValidShellParam(unittest.TestCase):
2596
  def test(self):
2597
    for val, result in [
2598
      ("abc", True),
2599
      ("ab;cd", False),
2600
      ]:
2601
      self.assertEqual(utils.IsValidShellParam(val), result)
2602

    
2603

    
2604
class TestBuildShellCmd(unittest.TestCase):
2605
  def test(self):
2606
    self.assertRaises(errors.ProgrammerError, utils.BuildShellCmd,
2607
                      "ls %s", "ab;cd")
2608
    self.assertEqual(utils.BuildShellCmd("ls %s", "ab"), "ls ab")
2609

    
2610

    
2611
class TestWriteFile(unittest.TestCase):
2612
  def setUp(self):
2613
    self.tfile = tempfile.NamedTemporaryFile()
2614
    self.did_pre = False
2615
    self.did_post = False
2616
    self.did_write = False
2617

    
2618
  def markPre(self, fd):
2619
    self.did_pre = True
2620

    
2621
  def markPost(self, fd):
2622
    self.did_post = True
2623

    
2624
  def markWrite(self, fd):
2625
    self.did_write = True
2626

    
2627
  def testWrite(self):
2628
    data = "abc"
2629
    utils.WriteFile(self.tfile.name, data=data)
2630
    self.assertEqual(utils.ReadFile(self.tfile.name), data)
2631

    
2632
  def testErrors(self):
2633
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
2634
                      self.tfile.name, data="test", fn=lambda fd: None)
2635
    self.assertRaises(errors.ProgrammerError, utils.WriteFile, self.tfile.name)
2636
    self.assertRaises(errors.ProgrammerError, utils.WriteFile,
2637
                      self.tfile.name, data="test", atime=0)
2638

    
2639
  def testCalls(self):
2640
    utils.WriteFile(self.tfile.name, fn=self.markWrite,
2641
                    prewrite=self.markPre, postwrite=self.markPost)
2642
    self.assertTrue(self.did_pre)
2643
    self.assertTrue(self.did_post)
2644
    self.assertTrue(self.did_write)
2645

    
2646
  def testDryRun(self):
2647
    orig = "abc"
2648
    self.tfile.write(orig)
2649
    self.tfile.flush()
2650
    utils.WriteFile(self.tfile.name, data="hello", dry_run=True)
2651
    self.assertEqual(utils.ReadFile(self.tfile.name), orig)
2652

    
2653
  def testTimes(self):
2654
    f = self.tfile.name
2655
    for at, mt in [(0, 0), (1000, 1000), (2000, 3000),
2656
                   (int(time.time()), 5000)]:
2657
      utils.WriteFile(f, data="hello", atime=at, mtime=mt)
2658
      st = os.stat(f)
2659
      self.assertEqual(st.st_atime, at)
2660
      self.assertEqual(st.st_mtime, mt)
2661

    
2662

    
2663
  def testNoClose(self):
2664
    data = "hello"
2665
    self.assertEqual(utils.WriteFile(self.tfile.name, data="abc"), None)
2666
    fd = utils.WriteFile(self.tfile.name, data=data, close=False)
2667
    try:
2668
      os.lseek(fd, 0, 0)
2669
      self.assertEqual(os.read(fd, 4096), data)
2670
    finally:
2671
      os.close(fd)
2672

    
2673

    
2674
class TestNormalizeAndValidateMac(unittest.TestCase):
2675
  def testInvalid(self):
2676
    self.assertRaises(errors.OpPrereqError,
2677
                      utils.NormalizeAndValidateMac, "xxx")
2678

    
2679
  def testNormalization(self):
2680
    for mac in ["aa:bb:cc:dd:ee:ff", "00:AA:11:bB:22:cc"]:
2681
      self.assertEqual(utils.NormalizeAndValidateMac(mac), mac.lower())
2682

    
2683

    
2684
class TestNiceSort(unittest.TestCase):
2685
  def test(self):
2686
    self.assertEqual(utils.NiceSort([]), [])
2687
    self.assertEqual(utils.NiceSort(["foo"]), ["foo"])
2688
    self.assertEqual(utils.NiceSort(["bar", ""]), ["", "bar"])
2689
    self.assertEqual(utils.NiceSort([",", "."]), [",", "."])
2690
    self.assertEqual(utils.NiceSort(["0.1", "0.2"]), ["0.1", "0.2"])
2691
    self.assertEqual(utils.NiceSort(["0;099", "0,099", "0.1", "0.2"]),
2692
                     ["0,099", "0.1", "0.2", "0;099"])
2693

    
2694
    data = ["a0", "a1", "a99", "a20", "a2", "b10", "b70", "b00", "0000"]
2695
    self.assertEqual(utils.NiceSort(data),
2696
                     ["0000", "a0", "a1", "a2", "a20", "a99",
2697
                      "b00", "b10", "b70"])
2698

    
2699
    data = ["a0-0", "a1-0", "a99-10", "a20-3", "a0-4", "a99-3", "a09-2",
2700
            "Z", "a9-1", "A", "b"]
2701
    self.assertEqual(utils.NiceSort(data),
2702
                     ["A", "Z", "a0-0", "a0-4", "a1-0", "a9-1", "a09-2",
2703
                      "a20-3", "a99-3", "a99-10", "b"])
2704
    self.assertEqual(utils.NiceSort(data, key=str.lower),
2705
                     ["A", "a0-0", "a0-4", "a1-0", "a9-1", "a09-2",
2706
                      "a20-3", "a99-3", "a99-10", "b", "Z"])
2707
    self.assertEqual(utils.NiceSort(data, key=str.upper),
2708
                     ["A", "a0-0", "a0-4", "a1-0", "a9-1", "a09-2",
2709
                      "a20-3", "a99-3", "a99-10", "b", "Z"])
2710

    
2711
  def testLargeA(self):
2712
    data = [
2713
      "Eegah9ei", "xij88brTulHYAv8IEOyU", "3jTwJPtrXOY22bwL2YoW",
2714
      "Z8Ljf1Pf5eBfNg171wJR", "WvNJd91OoXvLzdEiEXa6", "uHXAyYYftCSG1o7qcCqe",
2715
      "xpIUJeVT1Rp", "KOt7vn1dWXi", "a07h8feON165N67PIE", "bH4Q7aCu3PUPjK3JtH",
2716
      "cPRi0lM7HLnSuWA2G9", "KVQqLPDjcPjf8T3oyzjcOsfkb",
2717
      "guKJkXnkULealVC8CyF1xefym", "pqF8dkU5B1cMnyZuREaSOADYx",
2718
      ]
2719
    self.assertEqual(utils.NiceSort(data), [
2720
      "3jTwJPtrXOY22bwL2YoW", "Eegah9ei", "KOt7vn1dWXi",
2721
      "KVQqLPDjcPjf8T3oyzjcOsfkb", "WvNJd91OoXvLzdEiEXa6",
2722
      "Z8Ljf1Pf5eBfNg171wJR", "a07h8feON165N67PIE", "bH4Q7aCu3PUPjK3JtH",
2723
      "cPRi0lM7HLnSuWA2G9", "guKJkXnkULealVC8CyF1xefym",
2724
      "pqF8dkU5B1cMnyZuREaSOADYx", "uHXAyYYftCSG1o7qcCqe",
2725
      "xij88brTulHYAv8IEOyU", "xpIUJeVT1Rp"
2726
      ])
2727

    
2728
  def testLargeB(self):
2729
    data = [
2730
      "inst-0.0.0.0-0.0.0.0",
2731
      "inst-0.1.0.0-0.0.0.0",
2732
      "inst-0.2.0.0-0.0.0.0",
2733
      "inst-0.2.1.0-0.0.0.0",
2734
      "inst-0.2.2.0-0.0.0.0",
2735
      "inst-0.2.2.0-0.0.0.9",
2736
      "inst-0.2.2.0-0.0.3.9",
2737
      "inst-0.2.2.0-0.2.0.9",
2738
      "inst-0.2.2.0-0.9.0.9",
2739
      "inst-0.20.2.0-0.0.0.0",
2740
      "inst-0.20.2.0-0.9.0.9",
2741
      "inst-10.020.2.0-0.9.0.10",
2742
      "inst-15.020.2.0-0.9.1.00",
2743
      "inst-100.020.2.0-0.9.0.9",
2744

    
2745
      # Only the last group, not converted to a number anymore, differs
2746
      "inst-100.020.2.0a999",
2747
      "inst-100.020.2.0b000",
2748
      "inst-100.020.2.0c10",
2749
      "inst-100.020.2.0c101",
2750
      "inst-100.020.2.0c2",
2751
      "inst-100.020.2.0c20",
2752
      "inst-100.020.2.0c3",
2753
      "inst-100.020.2.0c39123",
2754
      ]
2755

    
2756
    rnd = random.Random(16205)
2757
    for _ in range(10):
2758
      testdata = data[:]
2759
      rnd.shuffle(testdata)
2760
      assert testdata != data
2761
      self.assertEqual(utils.NiceSort(testdata), data)
2762

    
2763
  class _CallCount:
2764
    def __init__(self, fn):
2765
      self.count = 0
2766
      self.fn = fn
2767

    
2768
    def __call__(self, *args):
2769
      self.count += 1
2770
      return self.fn(*args)
2771

    
2772
  def testKeyfuncA(self):
2773
    # Generate some random numbers
2774
    rnd = random.Random(21131)
2775
    numbers = [rnd.randint(0, 10000) for _ in range(999)]
2776
    assert numbers != sorted(numbers)
2777

    
2778
    # Convert to hex
2779
    data = [hex(i) for i in numbers]
2780
    datacopy = data[:]
2781

    
2782
    keyfn = self._CallCount(lambda value: str(int(value, 16)))
2783

    
2784
    # Sort with key function converting hex to decimal
2785
    result = utils.NiceSort(data, key=keyfn)
2786

    
2787
    self.assertEqual([hex(i) for i in sorted(numbers)], result)
2788
    self.assertEqual(data, datacopy, msg="Input data was modified in NiceSort")
2789
    self.assertEqual(keyfn.count, len(numbers),
2790
                     msg="Key function was not called once per value")
2791

    
2792
  class _TestData:
2793
    def __init__(self, name, value):
2794
      self.name = name
2795
      self.value = value
2796

    
2797
  def testKeyfuncB(self):
2798
    rnd = random.Random(27396)
2799
    data = []
2800
    for i in range(123):
2801
      v1 = rnd.randint(0, 5)
2802
      v2 = rnd.randint(0, 5)
2803
      data.append(self._TestData("inst-%s-%s-%s" % (v1, v2, i),
2804
                                 (v1, v2, i)))
2805
    rnd.shuffle(data)
2806
    assert data != sorted(data, key=operator.attrgetter("name"))
2807

    
2808
    keyfn = self._CallCount(operator.attrgetter("name"))
2809

    
2810
    # Sort by name
2811
    result = utils.NiceSort(data, key=keyfn)
2812

    
2813
    self.assertEqual(result, sorted(data, key=operator.attrgetter("value")))
2814
    self.assertEqual(keyfn.count, len(data),
2815
                     msg="Key function was not called once per value")
2816

    
2817

    
2818
if __name__ == '__main__':
2819
  testutils.GanetiTestProgram()