Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.utils_unittest.py @ a426508d

History | View | Annotate | Download (35.6 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the utils module"""
23

    
24
import unittest
25
import os
26
import time
27
import tempfile
28
import os.path
29
import os
30
import md5
31
import signal
32
import socket
33
import shutil
34
import re
35
import select
36
import string
37

    
38
import ganeti
39
import testutils
40
from ganeti import constants
41
from ganeti import utils
42
from ganeti import errors
43
from ganeti.utils import IsProcessAlive, RunCmd, \
44
     RemoveFile, MatchNameComponent, FormatUnit, \
45
     ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
46
     ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
47
     SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
48
     TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime
49

    
50
from ganeti.errors import LockError, UnitParseError, GenericError, \
51
     ProgrammerError
52

    
53

    
54
class TestIsProcessAlive(unittest.TestCase):
55
  """Testing case for IsProcessAlive"""
56

    
57
  def testExists(self):
58
    mypid = os.getpid()
59
    self.assert_(IsProcessAlive(mypid),
60
                 "can't find myself running")
61

    
62
  def testNotExisting(self):
63
    pid_non_existing = os.fork()
64
    if pid_non_existing == 0:
65
      os._exit(0)
66
    elif pid_non_existing < 0:
67
      raise SystemError("can't fork")
68
    os.waitpid(pid_non_existing, 0)
69
    self.assert_(not IsProcessAlive(pid_non_existing),
70
                 "nonexisting process detected")
71

    
72

    
73
class TestPidFileFunctions(unittest.TestCase):
74
  """Tests for WritePidFile, RemovePidFile and ReadPidFile"""
75

    
76
  def setUp(self):
77
    self.dir = tempfile.mkdtemp()
78
    self.f_dpn = lambda name: os.path.join(self.dir, "%s.pid" % name)
79
    utils.DaemonPidFileName = self.f_dpn
80

    
81
  def testPidFileFunctions(self):
82
    pid_file = self.f_dpn('test')
83
    utils.WritePidFile('test')
84
    self.failUnless(os.path.exists(pid_file),
85
                    "PID file should have been created")
86
    read_pid = utils.ReadPidFile(pid_file)
87
    self.failUnlessEqual(read_pid, os.getpid())
88
    self.failUnless(utils.IsProcessAlive(read_pid))
89
    self.failUnlessRaises(GenericError, utils.WritePidFile, 'test')
90
    utils.RemovePidFile('test')
91
    self.failIf(os.path.exists(pid_file),
92
                "PID file should not exist anymore")
93
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
94
                         "ReadPidFile should return 0 for missing pid file")
95
    fh = open(pid_file, "w")
96
    fh.write("blah\n")
97
    fh.close()
98
    self.failUnlessEqual(utils.ReadPidFile(pid_file), 0,
99
                         "ReadPidFile should return 0 for invalid pid file")
100
    utils.RemovePidFile('test')
101
    self.failIf(os.path.exists(pid_file),
102
                "PID file should not exist anymore")
103

    
104
  def testKill(self):
105
    pid_file = self.f_dpn('child')
106
    r_fd, w_fd = os.pipe()
107
    new_pid = os.fork()
108
    if new_pid == 0: #child
109
      utils.WritePidFile('child')
110
      os.write(w_fd, 'a')
111
      signal.pause()
112
      os._exit(0)
113
      return
114
    # else we are in the parent
115
    # wait until the child has written the pid file
116
    os.read(r_fd, 1)
117
    read_pid = utils.ReadPidFile(pid_file)
118
    self.failUnlessEqual(read_pid, new_pid)
119
    self.failUnless(utils.IsProcessAlive(new_pid))
120
    utils.KillProcess(new_pid, waitpid=True)
121
    self.failIf(utils.IsProcessAlive(new_pid))
122
    utils.RemovePidFile('child')
123
    self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0)
124

    
125
  def tearDown(self):
126
    for name in os.listdir(self.dir):
127
      os.unlink(os.path.join(self.dir, name))
128
    os.rmdir(self.dir)
129

    
130

    
131
class TestRunCmd(testutils.GanetiTestCase):
132
  """Testing case for the RunCmd function"""
133

    
134
  def setUp(self):
135
    testutils.GanetiTestCase.setUp(self)
136
    self.magic = time.ctime() + " ganeti test"
137
    self.fname = self._CreateTempFile()
138

    
139
  def testOk(self):
140
    """Test successful exit code"""
141
    result = RunCmd("/bin/sh -c 'exit 0'")
142
    self.assertEqual(result.exit_code, 0)
143
    self.assertEqual(result.output, "")
144

    
145
  def testFail(self):
146
    """Test fail exit code"""
147
    result = RunCmd("/bin/sh -c 'exit 1'")
148
    self.assertEqual(result.exit_code, 1)
149
    self.assertEqual(result.output, "")
150

    
151
  def testStdout(self):
152
    """Test standard output"""
153
    cmd = 'echo -n "%s"' % self.magic
154
    result = RunCmd("/bin/sh -c '%s'" % cmd)
155
    self.assertEqual(result.stdout, self.magic)
156
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
157
    self.assertEqual(result.output, "")
158
    self.assertFileContent(self.fname, self.magic)
159

    
160
  def testStderr(self):
161
    """Test standard error"""
162
    cmd = 'echo -n "%s"' % self.magic
163
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd)
164
    self.assertEqual(result.stderr, self.magic)
165
    result = RunCmd("/bin/sh -c '%s' 1>&2" % cmd, output=self.fname)
166
    self.assertEqual(result.output, "")
167
    self.assertFileContent(self.fname, self.magic)
168

    
169
  def testCombined(self):
170
    """Test combined output"""
171
    cmd = 'echo -n "A%s"; echo -n "B%s" 1>&2' % (self.magic, self.magic)
172
    expected = "A" + self.magic + "B" + self.magic
173
    result = RunCmd("/bin/sh -c '%s'" % cmd)
174
    self.assertEqual(result.output, expected)
175
    result = RunCmd("/bin/sh -c '%s'" % cmd, output=self.fname)
176
    self.assertEqual(result.output, "")
177
    self.assertFileContent(self.fname, expected)
178

    
179
  def testSignal(self):
180
    """Test signal"""
181
    result = RunCmd(["python", "-c", "import os; os.kill(os.getpid(), 15)"])
182
    self.assertEqual(result.signal, 15)
183
    self.assertEqual(result.output, "")
184

    
185
  def testListRun(self):
186
    """Test list runs"""
187
    result = RunCmd(["true"])
188
    self.assertEqual(result.signal, None)
189
    self.assertEqual(result.exit_code, 0)
190
    result = RunCmd(["/bin/sh", "-c", "exit 1"])
191
    self.assertEqual(result.signal, None)
192
    self.assertEqual(result.exit_code, 1)
193
    result = RunCmd(["echo", "-n", self.magic])
194
    self.assertEqual(result.signal, None)
195
    self.assertEqual(result.exit_code, 0)
196
    self.assertEqual(result.stdout, self.magic)
197

    
198
  def testFileEmptyOutput(self):
199
    """Test file output"""
200
    result = RunCmd(["true"], output=self.fname)
201
    self.assertEqual(result.signal, None)
202
    self.assertEqual(result.exit_code, 0)
203
    self.assertFileContent(self.fname, "")
204

    
205
  def testLang(self):
206
    """Test locale environment"""
207
    old_env = os.environ.copy()
208
    try:
209
      os.environ["LANG"] = "en_US.UTF-8"
210
      os.environ["LC_ALL"] = "en_US.UTF-8"
211
      result = RunCmd(["locale"])
212
      for line in result.output.splitlines():
213
        key, value = line.split("=", 1)
214
        # Ignore these variables, they're overridden by LC_ALL
215
        if key == "LANG" or key == "LANGUAGE":
216
          continue
217
        self.failIf(value and value != "C" and value != '"C"',
218
            "Variable %s is set to the invalid value '%s'" % (key, value))
219
    finally:
220
      os.environ = old_env
221

    
222
  def testDefaultCwd(self):
223
    """Test default working directory"""
224
    self.failUnlessEqual(RunCmd(["pwd"]).stdout.strip(), "/")
225

    
226
  def testCwd(self):
227
    """Test default working directory"""
228
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/").stdout.strip(), "/")
229
    self.failUnlessEqual(RunCmd(["pwd"], cwd="/tmp").stdout.strip(), "/tmp")
230
    cwd = os.getcwd()
231
    self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd)
232

    
233

    
234
class TestRemoveFile(unittest.TestCase):
235
  """Test case for the RemoveFile function"""
236

    
237
  def setUp(self):
238
    """Create a temp dir and file for each case"""
239
    self.tmpdir = tempfile.mkdtemp('', 'ganeti-unittest-')
240
    fd, self.tmpfile = tempfile.mkstemp('', '', self.tmpdir)
241
    os.close(fd)
242

    
243
  def tearDown(self):
244
    if os.path.exists(self.tmpfile):
245
      os.unlink(self.tmpfile)
246
    os.rmdir(self.tmpdir)
247

    
248

    
249
  def testIgnoreDirs(self):
250
    """Test that RemoveFile() ignores directories"""
251
    self.assertEqual(None, RemoveFile(self.tmpdir))
252

    
253

    
254
  def testIgnoreNotExisting(self):
255
    """Test that RemoveFile() ignores non-existing files"""
256
    RemoveFile(self.tmpfile)
257
    RemoveFile(self.tmpfile)
258

    
259

    
260
  def testRemoveFile(self):
261
    """Test that RemoveFile does remove a file"""
262
    RemoveFile(self.tmpfile)
263
    if os.path.exists(self.tmpfile):
264
      self.fail("File '%s' not removed" % self.tmpfile)
265

    
266

    
267
  def testRemoveSymlink(self):
268
    """Test that RemoveFile does remove symlinks"""
269
    symlink = self.tmpdir + "/symlink"
270
    os.symlink("no-such-file", symlink)
271
    RemoveFile(symlink)
272
    if os.path.exists(symlink):
273
      self.fail("File '%s' not removed" % symlink)
274
    os.symlink(self.tmpfile, symlink)
275
    RemoveFile(symlink)
276
    if os.path.exists(symlink):
277
      self.fail("File '%s' not removed" % symlink)
278

    
279

    
280
class TestRename(unittest.TestCase):
281
  """Test case for RenameFile"""
282

    
283
  def setUp(self):
284
    """Create a temporary directory"""
285
    self.tmpdir = tempfile.mkdtemp()
286
    self.tmpfile = os.path.join(self.tmpdir, "test1")
287

    
288
    # Touch the file
289
    open(self.tmpfile, "w").close()
290

    
291
  def tearDown(self):
292
    """Remove temporary directory"""
293
    shutil.rmtree(self.tmpdir)
294

    
295
  def testSimpleRename1(self):
296
    """Simple rename 1"""
297
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"))
298
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
299

    
300
  def testSimpleRename2(self):
301
    """Simple rename 2"""
302
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"),
303
                     mkdir=True)
304
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz")))
305

    
306
  def testRenameMkdir(self):
307
    """Rename with mkdir"""
308
    utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"),
309
                     mkdir=True)
310
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
311
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz")))
312

    
313
    utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"),
314
                     os.path.join(self.tmpdir, "test/foo/bar/baz"),
315
                     mkdir=True)
316
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test")))
317
    self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar")))
318
    self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz")))
319

    
320

    
321
class TestMatchNameComponent(unittest.TestCase):
322
  """Test case for the MatchNameComponent function"""
323

    
324
  def testEmptyList(self):
325
    """Test that there is no match against an empty list"""
326

    
327
    self.failUnlessEqual(MatchNameComponent("", []), None)
328
    self.failUnlessEqual(MatchNameComponent("test", []), None)
329

    
330
  def testSingleMatch(self):
331
    """Test that a single match is performed correctly"""
332
    mlist = ["test1.example.com", "test2.example.com", "test3.example.com"]
333
    for key in "test2", "test2.example", "test2.example.com":
334
      self.failUnlessEqual(MatchNameComponent(key, mlist), mlist[1])
335

    
336
  def testMultipleMatches(self):
337
    """Test that a multiple match is returned as None"""
338
    mlist = ["test1.example.com", "test1.example.org", "test1.example.net"]
339
    for key in "test1", "test1.example":
340
      self.failUnlessEqual(MatchNameComponent(key, mlist), None)
341

    
342
  def testFullMatch(self):
343
    """Test that a full match is returned correctly"""
344
    key1 = "test1"
345
    key2 = "test1.example"
346
    mlist = [key2, key2 + ".com"]
347
    self.failUnlessEqual(MatchNameComponent(key1, mlist), None)
348
    self.failUnlessEqual(MatchNameComponent(key2, mlist), key2)
349

    
350
  def testCaseInsensitivePartialMatch(self):
351
    """Test for the case_insensitive keyword"""
352
    mlist = ["test1.example.com", "test2.example.net"]
353
    self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False),
354
                     "test2.example.net")
355
    self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False),
356
                     "test2.example.net")
357
    self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False),
358
                     "test2.example.net")
359
    self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False),
360
                     "test2.example.net")
361

    
362

    
363
  def testCaseInsensitiveFullMatch(self):
364
    mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"]
365
    # Between the two ts1 a full string match non-case insensitive should work
366
    self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False),
367
                     None)
368
    self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False),
369
                     "ts1.ex")
370
    self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False),
371
                     "ts1.ex")
372
    # Between the two ts2 only case differs, so only case-match works
373
    self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False),
374
                     "ts2.ex")
375
    self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False),
376
                     "Ts2.ex")
377
    self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False),
378
                     None)
379

    
380

    
381
class TestFormatUnit(unittest.TestCase):
382
  """Test case for the FormatUnit function"""
383

    
384
  def testMiB(self):
385
    self.assertEqual(FormatUnit(1, 'h'), '1M')
386
    self.assertEqual(FormatUnit(100, 'h'), '100M')
387
    self.assertEqual(FormatUnit(1023, 'h'), '1023M')
388

    
389
    self.assertEqual(FormatUnit(1, 'm'), '1')
390
    self.assertEqual(FormatUnit(100, 'm'), '100')
391
    self.assertEqual(FormatUnit(1023, 'm'), '1023')
392

    
393
    self.assertEqual(FormatUnit(1024, 'm'), '1024')
394
    self.assertEqual(FormatUnit(1536, 'm'), '1536')
395
    self.assertEqual(FormatUnit(17133, 'm'), '17133')
396
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'm'), '1048575')
397

    
398
  def testGiB(self):
399
    self.assertEqual(FormatUnit(1024, 'h'), '1.0G')
400
    self.assertEqual(FormatUnit(1536, 'h'), '1.5G')
401
    self.assertEqual(FormatUnit(17133, 'h'), '16.7G')
402
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'h'), '1024.0G')
403

    
404
    self.assertEqual(FormatUnit(1024, 'g'), '1.0')
405
    self.assertEqual(FormatUnit(1536, 'g'), '1.5')
406
    self.assertEqual(FormatUnit(17133, 'g'), '16.7')
407
    self.assertEqual(FormatUnit(1024 * 1024 - 1, 'g'), '1024.0')
408

    
409
    self.assertEqual(FormatUnit(1024 * 1024, 'g'), '1024.0')
410
    self.assertEqual(FormatUnit(5120 * 1024, 'g'), '5120.0')
411
    self.assertEqual(FormatUnit(29829 * 1024, 'g'), '29829.0')
412

    
413
  def testTiB(self):
414
    self.assertEqual(FormatUnit(1024 * 1024, 'h'), '1.0T')
415
    self.assertEqual(FormatUnit(5120 * 1024, 'h'), '5.0T')
416
    self.assertEqual(FormatUnit(29829 * 1024, 'h'), '29.1T')
417

    
418
    self.assertEqual(FormatUnit(1024 * 1024, 't'), '1.0')
419
    self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0')
420
    self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1')
421

    
422
class TestParseUnit(unittest.TestCase):
423
  """Test case for the ParseUnit function"""
424

    
425
  SCALES = (('', 1),
426
            ('M', 1), ('G', 1024), ('T', 1024 * 1024),
427
            ('MB', 1), ('GB', 1024), ('TB', 1024 * 1024),
428
            ('MiB', 1), ('GiB', 1024), ('TiB', 1024 * 1024))
429

    
430
  def testRounding(self):
431
    self.assertEqual(ParseUnit('0'), 0)
432
    self.assertEqual(ParseUnit('1'), 4)
433
    self.assertEqual(ParseUnit('2'), 4)
434
    self.assertEqual(ParseUnit('3'), 4)
435

    
436
    self.assertEqual(ParseUnit('124'), 124)
437
    self.assertEqual(ParseUnit('125'), 128)
438
    self.assertEqual(ParseUnit('126'), 128)
439
    self.assertEqual(ParseUnit('127'), 128)
440
    self.assertEqual(ParseUnit('128'), 128)
441
    self.assertEqual(ParseUnit('129'), 132)
442
    self.assertEqual(ParseUnit('130'), 132)
443

    
444
  def testFloating(self):
445
    self.assertEqual(ParseUnit('0'), 0)
446
    self.assertEqual(ParseUnit('0.5'), 4)
447
    self.assertEqual(ParseUnit('1.75'), 4)
448
    self.assertEqual(ParseUnit('1.99'), 4)
449
    self.assertEqual(ParseUnit('2.00'), 4)
450
    self.assertEqual(ParseUnit('2.01'), 4)
451
    self.assertEqual(ParseUnit('3.99'), 4)
452
    self.assertEqual(ParseUnit('4.00'), 4)
453
    self.assertEqual(ParseUnit('4.01'), 8)
454
    self.assertEqual(ParseUnit('1.5G'), 1536)
455
    self.assertEqual(ParseUnit('1.8G'), 1844)
456
    self.assertEqual(ParseUnit('8.28T'), 8682212)
457

    
458
  def testSuffixes(self):
459
    for sep in ('', ' ', '   ', "\t", "\t "):
460
      for suffix, scale in TestParseUnit.SCALES:
461
        for func in (lambda x: x, str.lower, str.upper):
462
          self.assertEqual(ParseUnit('1024' + sep + func(suffix)),
463
                           1024 * scale)
464

    
465
  def testInvalidInput(self):
466
    for sep in ('-', '_', ',', 'a'):
467
      for suffix, _ in TestParseUnit.SCALES:
468
        self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix)
469

    
470
    for suffix, _ in TestParseUnit.SCALES:
471
      self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix)
472

    
473

    
474
class TestSshKeys(testutils.GanetiTestCase):
475
  """Test case for the AddAuthorizedKey function"""
476

    
477
  KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a'
478
  KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" '
479
           'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b')
480

    
481
  def setUp(self):
482
    testutils.GanetiTestCase.setUp(self)
483
    self.tmpname = self._CreateTempFile()
484
    handle = open(self.tmpname, 'w')
485
    try:
486
      handle.write("%s\n" % TestSshKeys.KEY_A)
487
      handle.write("%s\n" % TestSshKeys.KEY_B)
488
    finally:
489
      handle.close()
490

    
491
  def testAddingNewKey(self):
492
    AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test')
493

    
494
    self.assertFileContent(self.tmpname,
495
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
496
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
497
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
498
      "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n")
499

    
500
  def testAddingAlmostButNotCompletelyTheSameKey(self):
501
    AddAuthorizedKey(self.tmpname,
502
        'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test')
503

    
504
    self.assertFileContent(self.tmpname,
505
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
506
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
507
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n"
508
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n")
509

    
510
  def testAddingExistingKeyWithSomeMoreSpaces(self):
511
    AddAuthorizedKey(self.tmpname,
512
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
513

    
514
    self.assertFileContent(self.tmpname,
515
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
516
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
517
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
518

    
519
  def testRemovingExistingKeyWithSomeMoreSpaces(self):
520
    RemoveAuthorizedKey(self.tmpname,
521
        'ssh-dss  AAAAB3NzaC1w5256closdj32mZaQU   root@key-a')
522

    
523
    self.assertFileContent(self.tmpname,
524
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
525
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
526

    
527
  def testRemovingNonExistingKey(self):
528
    RemoveAuthorizedKey(self.tmpname,
529
        'ssh-dss  AAAAB3Nsdfj230xxjxJjsjwjsjdjU   root@test')
530

    
531
    self.assertFileContent(self.tmpname,
532
      "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n"
533
      'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"'
534
      " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n")
535

    
536

    
537
class TestEtcHosts(testutils.GanetiTestCase):
538
  """Test functions modifying /etc/hosts"""
539

    
540
  def setUp(self):
541
    testutils.GanetiTestCase.setUp(self)
542
    self.tmpname = self._CreateTempFile()
543
    handle = open(self.tmpname, 'w')
544
    try:
545
      handle.write('# This is a test file for /etc/hosts\n')
546
      handle.write('127.0.0.1\tlocalhost\n')
547
      handle.write('192.168.1.1 router gw\n')
548
    finally:
549
      handle.close()
550

    
551
  def testSettingNewIp(self):
552
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost'])
553

    
554
    self.assertFileContent(self.tmpname,
555
      "# This is a test file for /etc/hosts\n"
556
      "127.0.0.1\tlocalhost\n"
557
      "192.168.1.1 router gw\n"
558
      "1.2.3.4\tmyhost.domain.tld myhost\n")
559
    self.assertFileMode(self.tmpname, 0644)
560

    
561
  def testSettingExistingIp(self):
562
    SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld',
563
                     ['myhost'])
564

    
565
    self.assertFileContent(self.tmpname,
566
      "# This is a test file for /etc/hosts\n"
567
      "127.0.0.1\tlocalhost\n"
568
      "192.168.1.1\tmyhost.domain.tld myhost\n")
569
    self.assertFileMode(self.tmpname, 0644)
570

    
571
  def testSettingDuplicateName(self):
572
    SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost'])
573

    
574
    self.assertFileContent(self.tmpname,
575
      "# This is a test file for /etc/hosts\n"
576
      "127.0.0.1\tlocalhost\n"
577
      "192.168.1.1 router gw\n"
578
      "1.2.3.4\tmyhost\n")
579
    self.assertFileMode(self.tmpname, 0644)
580

    
581
  def testRemovingExistingHost(self):
582
    RemoveEtcHostsEntry(self.tmpname, 'router')
583

    
584
    self.assertFileContent(self.tmpname,
585
      "# This is a test file for /etc/hosts\n"
586
      "127.0.0.1\tlocalhost\n"
587
      "192.168.1.1 gw\n")
588
    self.assertFileMode(self.tmpname, 0644)
589

    
590
  def testRemovingSingleExistingHost(self):
591
    RemoveEtcHostsEntry(self.tmpname, 'localhost')
592

    
593
    self.assertFileContent(self.tmpname,
594
      "# This is a test file for /etc/hosts\n"
595
      "192.168.1.1 router gw\n")
596
    self.assertFileMode(self.tmpname, 0644)
597

    
598
  def testRemovingNonExistingHost(self):
599
    RemoveEtcHostsEntry(self.tmpname, 'myhost')
600

    
601
    self.assertFileContent(self.tmpname,
602
      "# This is a test file for /etc/hosts\n"
603
      "127.0.0.1\tlocalhost\n"
604
      "192.168.1.1 router gw\n")
605
    self.assertFileMode(self.tmpname, 0644)
606

    
607
  def testRemovingAlias(self):
608
    RemoveEtcHostsEntry(self.tmpname, 'gw')
609

    
610
    self.assertFileContent(self.tmpname,
611
      "# This is a test file for /etc/hosts\n"
612
      "127.0.0.1\tlocalhost\n"
613
      "192.168.1.1 router\n")
614
    self.assertFileMode(self.tmpname, 0644)
615

    
616

    
617
class TestShellQuoting(unittest.TestCase):
618
  """Test case for shell quoting functions"""
619

    
620
  def testShellQuote(self):
621
    self.assertEqual(ShellQuote('abc'), "abc")
622
    self.assertEqual(ShellQuote('ab"c'), "'ab\"c'")
623
    self.assertEqual(ShellQuote("a'bc"), "'a'\\''bc'")
624
    self.assertEqual(ShellQuote("a b c"), "'a b c'")
625
    self.assertEqual(ShellQuote("a b\\ c"), "'a b\\ c'")
626

    
627
  def testShellQuoteArgs(self):
628
    self.assertEqual(ShellQuoteArgs(['a', 'b', 'c']), "a b c")
629
    self.assertEqual(ShellQuoteArgs(['a', 'b"', 'c']), "a 'b\"' c")
630
    self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c")
631

    
632

    
633
class TestTcpPing(unittest.TestCase):
634
  """Testcase for TCP version of ping - against listen(2)ing port"""
635

    
636
  def setUp(self):
637
    self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
638
    self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
639
    self.listenerport = self.listener.getsockname()[1]
640
    self.listener.listen(1)
641

    
642
  def tearDown(self):
643
    self.listener.shutdown(socket.SHUT_RDWR)
644
    del self.listener
645
    del self.listenerport
646

    
647
  def testTcpPingToLocalHostAccept(self):
648
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
649
                         self.listenerport,
650
                         timeout=10,
651
                         live_port_needed=True,
652
                         source=constants.LOCALHOST_IP_ADDRESS,
653
                         ),
654
                 "failed to connect to test listener")
655

    
656
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
657
                         self.listenerport,
658
                         timeout=10,
659
                         live_port_needed=True,
660
                         ),
661
                 "failed to connect to test listener (no source)")
662

    
663

    
664
class TestTcpPingDeaf(unittest.TestCase):
665
  """Testcase for TCP version of ping - against non listen(2)ing port"""
666

    
667
  def setUp(self):
668
    self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
669
    self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0))
670
    self.deaflistenerport = self.deaflistener.getsockname()[1]
671

    
672
  def tearDown(self):
673
    del self.deaflistener
674
    del self.deaflistenerport
675

    
676
  def testTcpPingToLocalHostAcceptDeaf(self):
677
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
678
                        self.deaflistenerport,
679
                        timeout=constants.TCP_PING_TIMEOUT,
680
                        live_port_needed=True,
681
                        source=constants.LOCALHOST_IP_ADDRESS,
682
                        ), # need successful connect(2)
683
                "successfully connected to deaf listener")
684

    
685
    self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS,
686
                        self.deaflistenerport,
687
                        timeout=constants.TCP_PING_TIMEOUT,
688
                        live_port_needed=True,
689
                        ), # need successful connect(2)
690
                "successfully connected to deaf listener (no source addr)")
691

    
692
  def testTcpPingToLocalHostNoAccept(self):
693
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
694
                         self.deaflistenerport,
695
                         timeout=constants.TCP_PING_TIMEOUT,
696
                         live_port_needed=False,
697
                         source=constants.LOCALHOST_IP_ADDRESS,
698
                         ), # ECONNREFUSED is OK
699
                 "failed to ping alive host on deaf port")
700

    
701
    self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS,
702
                         self.deaflistenerport,
703
                         timeout=constants.TCP_PING_TIMEOUT,
704
                         live_port_needed=False,
705
                         ), # ECONNREFUSED is OK
706
                 "failed to ping alive host on deaf port (no source addr)")
707

    
708

    
709
class TestOwnIpAddress(unittest.TestCase):
710
  """Testcase for OwnIpAddress"""
711

    
712
  def testOwnLoopback(self):
713
    """check having the loopback ip"""
714
    self.failUnless(OwnIpAddress(constants.LOCALHOST_IP_ADDRESS),
715
                    "Should own the loopback address")
716

    
717
  def testNowOwnAddress(self):
718
    """check that I don't own an address"""
719

    
720
    # network 192.0.2.0/24 is reserved for test/documentation as per
721
    # rfc 3330, so we *should* not have an address of this range... if
722
    # this fails, we should extend the test to multiple addresses
723
    DST_IP = "192.0.2.1"
724
    self.failIf(OwnIpAddress(DST_IP), "Should not own IP address %s" % DST_IP)
725

    
726

    
727
class TestListVisibleFiles(unittest.TestCase):
728
  """Test case for ListVisibleFiles"""
729

    
730
  def setUp(self):
731
    self.path = tempfile.mkdtemp()
732

    
733
  def tearDown(self):
734
    shutil.rmtree(self.path)
735

    
736
  def _test(self, files, expected):
737
    # Sort a copy
738
    expected = expected[:]
739
    expected.sort()
740

    
741
    for name in files:
742
      f = open(os.path.join(self.path, name), 'w')
743
      try:
744
        f.write("Test\n")
745
      finally:
746
        f.close()
747

    
748
    found = ListVisibleFiles(self.path)
749
    found.sort()
750

    
751
    self.assertEqual(found, expected)
752

    
753
  def testAllVisible(self):
754
    files = ["a", "b", "c"]
755
    expected = files
756
    self._test(files, expected)
757

    
758
  def testNoneVisible(self):
759
    files = [".a", ".b", ".c"]
760
    expected = []
761
    self._test(files, expected)
762

    
763
  def testSomeVisible(self):
764
    files = ["a", "b", ".c"]
765
    expected = ["a", "b"]
766
    self._test(files, expected)
767

    
768

    
769
class TestNewUUID(unittest.TestCase):
770
  """Test case for NewUUID"""
771

    
772
  _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-'
773
                        '[a-f0-9]{4}-[a-f0-9]{12}$')
774

    
775
  def runTest(self):
776
    self.failUnless(self._re_uuid.match(utils.NewUUID()))
777

    
778

    
779
class TestUniqueSequence(unittest.TestCase):
780
  """Test case for UniqueSequence"""
781

    
782
  def _test(self, input, expected):
783
    self.assertEqual(utils.UniqueSequence(input), expected)
784

    
785
  def runTest(self):
786
    # Ordered input
787
    self._test([1, 2, 3], [1, 2, 3])
788
    self._test([1, 1, 2, 2, 3, 3], [1, 2, 3])
789
    self._test([1, 2, 2, 3], [1, 2, 3])
790
    self._test([1, 2, 3, 3], [1, 2, 3])
791

    
792
    # Unordered input
793
    self._test([1, 2, 3, 1, 2, 3], [1, 2, 3])
794
    self._test([1, 1, 2, 3, 3, 1, 2], [1, 2, 3])
795

    
796
    # Strings
797
    self._test(["a", "a"], ["a"])
798
    self._test(["a", "b"], ["a", "b"])
799
    self._test(["a", "b", "a"], ["a", "b"])
800

    
801

    
802
class TestFirstFree(unittest.TestCase):
803
  """Test case for the FirstFree function"""
804

    
805
  def test(self):
806
    """Test FirstFree"""
807
    self.failUnlessEqual(FirstFree([0, 1, 3]), 2)
808
    self.failUnlessEqual(FirstFree([]), None)
809
    self.failUnlessEqual(FirstFree([3, 4, 6]), 0)
810
    self.failUnlessEqual(FirstFree([3, 4, 6], base=3), 5)
811
    self.failUnlessRaises(AssertionError, FirstFree, [0, 3, 4, 6], base=3)
812

    
813

    
814
class TestTailFile(testutils.GanetiTestCase):
815
  """Test case for the TailFile function"""
816

    
817
  def testEmpty(self):
818
    fname = self._CreateTempFile()
819
    self.failUnlessEqual(TailFile(fname), [])
820
    self.failUnlessEqual(TailFile(fname, lines=25), [])
821

    
822
  def testAllLines(self):
823
    data = ["test %d" % i for i in range(30)]
824
    for i in range(30):
825
      fname = self._CreateTempFile()
826
      fd = open(fname, "w")
827
      fd.write("\n".join(data[:i]))
828
      if i > 0:
829
        fd.write("\n")
830
      fd.close()
831
      self.failUnlessEqual(TailFile(fname, lines=i), data[:i])
832

    
833
  def testPartialLines(self):
834
    data = ["test %d" % i for i in range(30)]
835
    fname = self._CreateTempFile()
836
    fd = open(fname, "w")
837
    fd.write("\n".join(data))
838
    fd.write("\n")
839
    fd.close()
840
    for i in range(1, 30):
841
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
842

    
843
  def testBigFile(self):
844
    data = ["test %d" % i for i in range(30)]
845
    fname = self._CreateTempFile()
846
    fd = open(fname, "w")
847
    fd.write("X" * 1048576)
848
    fd.write("\n")
849
    fd.write("\n".join(data))
850
    fd.write("\n")
851
    fd.close()
852
    for i in range(1, 30):
853
      self.failUnlessEqual(TailFile(fname, lines=i), data[-i:])
854

    
855

    
856
class TestFileLock(unittest.TestCase):
857
  """Test case for the FileLock class"""
858

    
859
  def setUp(self):
860
    self.tmpfile = tempfile.NamedTemporaryFile()
861
    self.lock = utils.FileLock(self.tmpfile.name)
862

    
863
  def testSharedNonblocking(self):
864
    self.lock.Shared(blocking=False)
865
    self.lock.Close()
866

    
867
  def testExclusiveNonblocking(self):
868
    self.lock.Exclusive(blocking=False)
869
    self.lock.Close()
870

    
871
  def testUnlockNonblocking(self):
872
    self.lock.Unlock(blocking=False)
873
    self.lock.Close()
874

    
875
  def testSharedBlocking(self):
876
    self.lock.Shared(blocking=True)
877
    self.lock.Close()
878

    
879
  def testExclusiveBlocking(self):
880
    self.lock.Exclusive(blocking=True)
881
    self.lock.Close()
882

    
883
  def testUnlockBlocking(self):
884
    self.lock.Unlock(blocking=True)
885
    self.lock.Close()
886

    
887
  def testSharedExclusiveUnlock(self):
888
    self.lock.Shared(blocking=False)
889
    self.lock.Exclusive(blocking=False)
890
    self.lock.Unlock(blocking=False)
891
    self.lock.Close()
892

    
893
  def testExclusiveSharedUnlock(self):
894
    self.lock.Exclusive(blocking=False)
895
    self.lock.Shared(blocking=False)
896
    self.lock.Unlock(blocking=False)
897
    self.lock.Close()
898

    
899
  def testCloseShared(self):
900
    self.lock.Close()
901
    self.assertRaises(AssertionError, self.lock.Shared, blocking=False)
902

    
903
  def testCloseExclusive(self):
904
    self.lock.Close()
905
    self.assertRaises(AssertionError, self.lock.Exclusive, blocking=False)
906

    
907
  def testCloseUnlock(self):
908
    self.lock.Close()
909
    self.assertRaises(AssertionError, self.lock.Unlock, blocking=False)
910

    
911

    
912
class TestTimeFunctions(unittest.TestCase):
913
  """Test case for time functions"""
914

    
915
  def runTest(self):
916
    self.assertEqual(utils.SplitTime(1), (1, 0))
917
    self.assertEqual(utils.SplitTime(1.5), (1, 500000))
918
    self.assertEqual(utils.SplitTime(1218448917.4809151), (1218448917, 480915))
919
    self.assertEqual(utils.SplitTime(123.48012), (123, 480120))
920
    self.assertEqual(utils.SplitTime(123.9996), (123, 999600))
921
    self.assertEqual(utils.SplitTime(123.9995), (123, 999500))
922
    self.assertEqual(utils.SplitTime(123.9994), (123, 999400))
923
    self.assertEqual(utils.SplitTime(123.999999999), (123, 999999))
924

    
925
    self.assertRaises(AssertionError, utils.SplitTime, -1)
926

    
927
    self.assertEqual(utils.MergeTime((1, 0)), 1.0)
928
    self.assertEqual(utils.MergeTime((1, 500000)), 1.5)
929
    self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5)
930

    
931
    self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3),
932
                     1218448917.481)
933
    self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801)
934

    
935
    self.assertRaises(AssertionError, utils.MergeTime, (0, -1))
936
    self.assertRaises(AssertionError, utils.MergeTime, (0, 1000000))
937
    self.assertRaises(AssertionError, utils.MergeTime, (0, 9999999))
938
    self.assertRaises(AssertionError, utils.MergeTime, (-1, 0))
939
    self.assertRaises(AssertionError, utils.MergeTime, (-9999, 0))
940

    
941

    
942
class FieldSetTestCase(unittest.TestCase):
943
  """Test case for FieldSets"""
944

    
945
  def testSimpleMatch(self):
946
    f = utils.FieldSet("a", "b", "c", "def")
947
    self.failUnless(f.Matches("a"))
948
    self.failIf(f.Matches("d"), "Substring matched")
949
    self.failIf(f.Matches("defghi"), "Prefix string matched")
950
    self.failIf(f.NonMatching(["b", "c"]))
951
    self.failIf(f.NonMatching(["a", "b", "c", "def"]))
952
    self.failUnless(f.NonMatching(["a", "d"]))
953

    
954
  def testRegexMatch(self):
955
    f = utils.FieldSet("a", "b([0-9]+)", "c")
956
    self.failUnless(f.Matches("b1"))
957
    self.failUnless(f.Matches("b99"))
958
    self.failIf(f.Matches("b/1"))
959
    self.failIf(f.NonMatching(["b12", "c"]))
960
    self.failUnless(f.NonMatching(["a", "1"]))
961

    
962
class TestForceDictType(unittest.TestCase):
963
  """Test case for ForceDictType"""
964

    
965
  def setUp(self):
966
    self.key_types = {
967
      'a': constants.VTYPE_INT,
968
      'b': constants.VTYPE_BOOL,
969
      'c': constants.VTYPE_STRING,
970
      'd': constants.VTYPE_SIZE,
971
      }
972

    
973
  def _fdt(self, dict, allowed_values=None):
974
    if allowed_values is None:
975
      ForceDictType(dict, self.key_types)
976
    else:
977
      ForceDictType(dict, self.key_types, allowed_values=allowed_values)
978

    
979
    return dict
980

    
981
  def testSimpleDict(self):
982
    self.assertEqual(self._fdt({}), {})
983
    self.assertEqual(self._fdt({'a': 1}), {'a': 1})
984
    self.assertEqual(self._fdt({'a': '1'}), {'a': 1})
985
    self.assertEqual(self._fdt({'a': 1, 'b': 1}), {'a':1, 'b': True})
986
    self.assertEqual(self._fdt({'b': 1, 'c': 'foo'}), {'b': True, 'c': 'foo'})
987
    self.assertEqual(self._fdt({'b': 1, 'c': False}), {'b': True, 'c': ''})
988
    self.assertEqual(self._fdt({'b': 'false'}), {'b': False})
989
    self.assertEqual(self._fdt({'b': 'False'}), {'b': False})
990
    self.assertEqual(self._fdt({'b': 'true'}), {'b': True})
991
    self.assertEqual(self._fdt({'b': 'True'}), {'b': True})
992
    self.assertEqual(self._fdt({'d': '4'}), {'d': 4})
993
    self.assertEqual(self._fdt({'d': '4M'}), {'d': 4})
994

    
995
  def testErrors(self):
996
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'})
997
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True})
998
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'})
999
    self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
1000

    
1001

    
1002
class TestIsAbsNormPath(unittest.TestCase):
1003
  """Testing case for IsProcessAlive"""
1004

    
1005
  def _pathTestHelper(self, path, result):
1006
    if result:
1007
      self.assert_(IsNormAbsPath(path),
1008
          "Path %s should result absolute and normalized" % path)
1009
    else:
1010
      self.assert_(not IsNormAbsPath(path),
1011
          "Path %s should not result absolute and normalized" % path)
1012

    
1013
  def testBase(self):
1014
    self._pathTestHelper('/etc', True)
1015
    self._pathTestHelper('/srv', True)
1016
    self._pathTestHelper('etc', False)
1017
    self._pathTestHelper('/etc/../root', False)
1018
    self._pathTestHelper('/etc/', False)
1019

    
1020

    
1021
class TestSafeEncode(unittest.TestCase):
1022
  """Test case for SafeEncode"""
1023

    
1024
  def testAscii(self):
1025
    for txt in [string.digits, string.letters, string.punctuation]:
1026
      self.failUnlessEqual(txt, SafeEncode(txt))
1027

    
1028
  def testDoubleEncode(self):
1029
    for i in range(255):
1030
      txt = SafeEncode(chr(i))
1031
      self.failUnlessEqual(txt, SafeEncode(txt))
1032

    
1033
  def testUnicode(self):
1034
    # 1024 is high enough to catch non-direct ASCII mappings
1035
    for i in range(1024):
1036
      txt = SafeEncode(unichr(i))
1037
      self.failUnlessEqual(txt, SafeEncode(txt))
1038

    
1039

    
1040
class TestFormatTime(unittest.TestCase):
1041
  """Testing case for FormatTime"""
1042

    
1043
  def testNone(self):
1044
    self.failUnlessEqual(FormatTime(None), "N/A")
1045

    
1046
  def testInvalid(self):
1047
    self.failUnlessEqual(FormatTime(()), "N/A")
1048

    
1049
  def testNow(self):
1050
    # tests that we accept time.time input
1051
    FormatTime(time.time())
1052
    # tests that we accept int input
1053
    FormatTime(int(time.time()))
1054

    
1055

    
1056
if __name__ == '__main__':
1057
  testutils.GanetiTestProgram()