X-Git-Url: https://code.grnet.gr/git/ganeti-local/blobdiff_plain/fde0203b0e640575f04249e0bbf5aab9abb0eca7..d1a0ab50cab99fd572ffd1100bada8412c23b221:/test/ganeti.utils_unittest.py diff --git a/test/ganeti.utils_unittest.py b/test/ganeti.utils_unittest.py index fa900c9..2c46afc 100755 --- a/test/ganeti.utils_unittest.py +++ b/test/ganeti.utils_unittest.py @@ -1,7 +1,7 @@ #!/usr/bin/python # -# Copyright (C) 2006, 2007 Google Inc. +# Copyright (C) 2006, 2007, 2010 Google Inc. # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -21,39 +21,34 @@ """Script for unittesting the utils module""" -import unittest +import distutils.version +import errno +import fcntl +import glob import os -import time -import tempfile import os.path -import os -import stat +import re +import shutil import signal import socket -import shutil -import re -import select +import stat import string -import OpenSSL +import tempfile +import time +import unittest import warnings -import distutils.version -import glob +import OpenSSL +from cStringIO import StringIO -import ganeti import testutils from ganeti import constants +from ganeti import compat from ganeti import utils from ganeti import errors -from ganeti.utils import IsProcessAlive, RunCmd, \ - RemoveFile, MatchNameComponent, FormatUnit, \ - ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \ - ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \ - SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \ - TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \ - UnescapeAndSplit, RunParts, PathJoin, HostInfo - -from ganeti.errors import LockError, UnitParseError, GenericError, \ - ProgrammerError, OpPrereqError +from ganeti.utils import RunCmd, RemoveFile, MatchNameComponent, FormatUnit, \ + ParseUnit, ShellQuote, ShellQuoteArgs, ListVisibleFiles, FirstFree, \ + TailFile, SafeEncode, FormatTime, UnescapeAndSplit, RunParts, PathJoin, \ + ReadOneLineFile, SetEtcHostsEntry, RemoveEtcHostsEntry class TestIsProcessAlive(unittest.TestCase): @@ -61,8 +56,7 @@ class TestIsProcessAlive(unittest.TestCase): def testExists(self): mypid = os.getpid() - self.assert_(IsProcessAlive(mypid), - "can't find myself running") + self.assert_(utils.IsProcessAlive(mypid), "can't find myself running") def testNotExisting(self): pid_non_existing = os.fork() @@ -71,8 +65,106 @@ class TestIsProcessAlive(unittest.TestCase): elif pid_non_existing < 0: raise SystemError("can't fork") os.waitpid(pid_non_existing, 0) - self.assert_(not IsProcessAlive(pid_non_existing), - "nonexisting process detected") + self.assertFalse(utils.IsProcessAlive(pid_non_existing), + "nonexisting process detected") + + +class TestGetProcStatusPath(unittest.TestCase): + def test(self): + self.assert_("/1234/" in utils._GetProcStatusPath(1234)) + self.assertNotEqual(utils._GetProcStatusPath(1), + utils._GetProcStatusPath(2)) + + +class TestIsProcessHandlingSignal(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def testParseSigsetT(self): + self.assertEqual(len(utils._ParseSigsetT("0")), 0) + self.assertEqual(utils._ParseSigsetT("1"), set([1])) + self.assertEqual(utils._ParseSigsetT("1000a"), set([2, 4, 17])) + self.assertEqual(utils._ParseSigsetT("810002"), set([2, 17, 24, ])) + self.assertEqual(utils._ParseSigsetT("0000000180000202"), + set([2, 10, 32, 33])) + self.assertEqual(utils._ParseSigsetT("0000000180000002"), + set([2, 32, 33])) + self.assertEqual(utils._ParseSigsetT("0000000188000002"), + set([2, 28, 32, 33])) + self.assertEqual(utils._ParseSigsetT("000000004b813efb"), + set([1, 2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 17, + 24, 25, 26, 28, 31])) + self.assertEqual(utils._ParseSigsetT("ffffff"), set(range(1, 25))) + + def testGetProcStatusField(self): + for field in ["SigCgt", "Name", "FDSize"]: + for value in ["", "0", "cat", " 1234 KB"]: + pstatus = "\n".join([ + "VmPeak: 999 kB", + "%s: %s" % (field, value), + "TracerPid: 0", + ]) + result = utils._GetProcStatusField(pstatus, field) + self.assertEqual(result, value.strip()) + + def test(self): + sp = PathJoin(self.tmpdir, "status") + + utils.WriteFile(sp, data="\n".join([ + "Name: bash", + "State: S (sleeping)", + "SleepAVG: 98%", + "Pid: 22250", + "PPid: 10858", + "TracerPid: 0", + "SigBlk: 0000000000010000", + "SigIgn: 0000000000384004", + "SigCgt: 000000004b813efb", + "CapEff: 0000000000000000", + ])) + + self.assert_(utils.IsProcessHandlingSignal(1234, 10, status_path=sp)) + + def testNoSigCgt(self): + sp = PathJoin(self.tmpdir, "status") + + utils.WriteFile(sp, data="\n".join([ + "Name: bash", + ])) + + self.assertRaises(RuntimeError, utils.IsProcessHandlingSignal, + 1234, 10, status_path=sp) + + def testNoSuchFile(self): + sp = PathJoin(self.tmpdir, "notexist") + + self.assertFalse(utils.IsProcessHandlingSignal(1234, 10, status_path=sp)) + + @staticmethod + def _TestRealProcess(): + signal.signal(signal.SIGUSR1, signal.SIG_DFL) + if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1): + raise Exception("SIGUSR1 is handled when it should not be") + + signal.signal(signal.SIGUSR1, lambda signum, frame: None) + if not utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1): + raise Exception("SIGUSR1 is not handled when it should be") + + signal.signal(signal.SIGUSR1, signal.SIG_IGN) + if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1): + raise Exception("SIGUSR1 is not handled when it should be") + + signal.signal(signal.SIGUSR1, signal.SIG_DFL) + if utils.IsProcessHandlingSignal(os.getpid(), signal.SIGUSR1): + raise Exception("SIGUSR1 is handled when it should not be") + + return True + + def testRealProcess(self): + self.assert_(utils.RunInSeparateProcess(self._TestRealProcess)) class TestPidFileFunctions(unittest.TestCase): @@ -85,13 +177,15 @@ class TestPidFileFunctions(unittest.TestCase): def testPidFileFunctions(self): pid_file = self.f_dpn('test') - utils.WritePidFile('test') + fd = utils.WritePidFile(self.f_dpn('test')) self.failUnless(os.path.exists(pid_file), "PID file should have been created") read_pid = utils.ReadPidFile(pid_file) self.failUnlessEqual(read_pid, os.getpid()) self.failUnless(utils.IsProcessAlive(read_pid)) - self.failUnlessRaises(GenericError, utils.WritePidFile, 'test') + self.failUnlessRaises(errors.LockError, utils.WritePidFile, + self.f_dpn('test')) + os.close(fd) utils.RemovePidFile('test') self.failIf(os.path.exists(pid_file), "PID file should not exist anymore") @@ -102,6 +196,9 @@ class TestPidFileFunctions(unittest.TestCase): fh.close() self.failUnlessEqual(utils.ReadPidFile(pid_file), 0, "ReadPidFile should return 0 for invalid pid file") + # but now, even with the file existing, we should be able to lock it + fd = utils.WritePidFile(self.f_dpn('test')) + os.close(fd) utils.RemovePidFile('test') self.failIf(os.path.exists(pid_file), "PID file should not exist anymore") @@ -111,7 +208,7 @@ class TestPidFileFunctions(unittest.TestCase): r_fd, w_fd = os.pipe() new_pid = os.fork() if new_pid == 0: #child - utils.WritePidFile('child') + utils.WritePidFile(self.f_dpn('child')) os.write(w_fd, 'a') signal.pause() os._exit(0) @@ -125,7 +222,7 @@ class TestPidFileFunctions(unittest.TestCase): utils.KillProcess(new_pid, waitpid=True) self.failIf(utils.IsProcessAlive(new_pid)) utils.RemovePidFile('child') - self.failUnlessRaises(ProgrammerError, utils.KillProcess, 0) + self.failUnlessRaises(errors.ProgrammerError, utils.KillProcess, 0) def tearDown(self): for name in os.listdir(self.dir): @@ -368,6 +465,155 @@ class TestRunParts(unittest.TestCase): self.failUnless(not runresult.failed) +class TestStartDaemon(testutils.GanetiTestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test") + self.tmpfile = os.path.join(self.tmpdir, "test") + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def testShell(self): + utils.StartDaemon("echo Hello World > %s" % self.tmpfile) + self._wait(self.tmpfile, 60.0, "Hello World") + + def testShellOutput(self): + utils.StartDaemon("echo Hello World", output=self.tmpfile) + self._wait(self.tmpfile, 60.0, "Hello World") + + def testNoShellNoOutput(self): + utils.StartDaemon(["pwd"]) + + def testNoShellNoOutputTouch(self): + testfile = os.path.join(self.tmpdir, "check") + self.failIf(os.path.exists(testfile)) + utils.StartDaemon(["touch", testfile]) + self._wait(testfile, 60.0, "") + + def testNoShellOutput(self): + utils.StartDaemon(["pwd"], output=self.tmpfile) + self._wait(self.tmpfile, 60.0, "/") + + def testNoShellOutputCwd(self): + utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd()) + self._wait(self.tmpfile, 60.0, os.getcwd()) + + def testShellEnv(self): + utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile, + env={ "GNT_TEST_VAR": "Hello World", }) + self._wait(self.tmpfile, 60.0, "Hello World") + + def testNoShellEnv(self): + utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile, + env={ "GNT_TEST_VAR": "Hello World", }) + self._wait(self.tmpfile, 60.0, "Hello World") + + def testOutputFd(self): + fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT) + try: + utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd()) + finally: + os.close(fd) + self._wait(self.tmpfile, 60.0, os.getcwd()) + + def testPid(self): + pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile) + self._wait(self.tmpfile, 60.0, str(pid)) + + def testPidFile(self): + pidfile = os.path.join(self.tmpdir, "pid") + checkfile = os.path.join(self.tmpdir, "abort") + + pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile, + output=self.tmpfile) + try: + fd = os.open(pidfile, os.O_RDONLY) + try: + # Check file is locked + self.assertRaises(errors.LockError, utils.LockFile, fd) + + pidtext = os.read(fd, 100) + finally: + os.close(fd) + + self.assertEqual(int(pidtext.strip()), pid) + + self.assert_(utils.IsProcessAlive(pid)) + finally: + # No matter what happens, kill daemon + utils.KillProcess(pid, timeout=5.0, waitpid=False) + self.failIf(utils.IsProcessAlive(pid)) + + self.assertEqual(utils.ReadFile(self.tmpfile), "") + + def _wait(self, path, timeout, expected): + # Due to the asynchronous nature of daemon processes, polling is necessary. + # A timeout makes sure the test doesn't hang forever. + def _CheckFile(): + if not (os.path.isfile(path) and + utils.ReadFile(path).strip() == expected): + raise utils.RetryAgain() + + try: + utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout) + except utils.RetryTimeout: + self.fail("Apparently the daemon didn't run in %s seconds and/or" + " didn't write the correct output" % timeout) + + def testError(self): + self.assertRaises(errors.OpExecError, utils.StartDaemon, + ["./does-NOT-EXIST/here/0123456789"]) + self.assertRaises(errors.OpExecError, utils.StartDaemon, + ["./does-NOT-EXIST/here/0123456789"], + output=os.path.join(self.tmpdir, "DIR/NOT/EXIST")) + self.assertRaises(errors.OpExecError, utils.StartDaemon, + ["./does-NOT-EXIST/here/0123456789"], + cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST")) + self.assertRaises(errors.OpExecError, utils.StartDaemon, + ["./does-NOT-EXIST/here/0123456789"], + output=os.path.join(self.tmpdir, "DIR/NOT/EXIST")) + + fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT) + try: + self.assertRaises(errors.ProgrammerError, utils.StartDaemon, + ["./does-NOT-EXIST/here/0123456789"], + output=self.tmpfile, output_fd=fd) + finally: + os.close(fd) + + +class TestSetCloseOnExecFlag(unittest.TestCase): + """Tests for SetCloseOnExecFlag""" + + def setUp(self): + self.tmpfile = tempfile.TemporaryFile() + + def testEnable(self): + utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True) + self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) & + fcntl.FD_CLOEXEC) + + def testDisable(self): + utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False) + self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) & + fcntl.FD_CLOEXEC) + + +class TestSetNonblockFlag(unittest.TestCase): + def setUp(self): + self.tmpfile = tempfile.TemporaryFile() + + def testEnable(self): + utils.SetNonblockFlag(self.tmpfile.fileno(), True) + self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) & + os.O_NONBLOCK) + + def testDisable(self): + utils.SetNonblockFlag(self.tmpfile.fileno(), False) + self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) & + os.O_NONBLOCK) + + class TestRemoveFile(unittest.TestCase): """Test case for the RemoveFile function""" @@ -382,25 +628,21 @@ class TestRemoveFile(unittest.TestCase): os.unlink(self.tmpfile) os.rmdir(self.tmpdir) - def testIgnoreDirs(self): """Test that RemoveFile() ignores directories""" self.assertEqual(None, RemoveFile(self.tmpdir)) - def testIgnoreNotExisting(self): """Test that RemoveFile() ignores non-existing files""" RemoveFile(self.tmpfile) RemoveFile(self.tmpfile) - def testRemoveFile(self): """Test that RemoveFile does remove a file""" RemoveFile(self.tmpfile) if os.path.exists(self.tmpfile): self.fail("File '%s' not removed" % self.tmpfile) - def testRemoveSymlink(self): """Test that RemoveFile does remove symlinks""" symlink = self.tmpdir + "/symlink" @@ -515,6 +757,107 @@ class TestMatchNameComponent(unittest.TestCase): None) +class TestReadFile(testutils.GanetiTestCase): + + def testReadAll(self): + data = utils.ReadFile(self._TestDataFilename("cert1.pem")) + self.assertEqual(len(data), 814) + + h = compat.md5_hash() + h.update(data) + self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4") + + def testReadSize(self): + data = utils.ReadFile(self._TestDataFilename("cert1.pem"), + size=100) + self.assertEqual(len(data), 100) + + h = compat.md5_hash() + h.update(data) + self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7") + + def testError(self): + self.assertRaises(EnvironmentError, utils.ReadFile, + "/dev/null/does-not-exist") + + +class TestReadOneLineFile(testutils.GanetiTestCase): + + def setUp(self): + testutils.GanetiTestCase.setUp(self) + + def testDefault(self): + data = ReadOneLineFile(self._TestDataFilename("cert1.pem")) + self.assertEqual(len(data), 27) + self.assertEqual(data, "-----BEGIN CERTIFICATE-----") + + def testNotStrict(self): + data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False) + self.assertEqual(len(data), 27) + self.assertEqual(data, "-----BEGIN CERTIFICATE-----") + + def testStrictFailure(self): + self.assertRaises(errors.GenericError, ReadOneLineFile, + self._TestDataFilename("cert1.pem"), strict=True) + + def testLongLine(self): + dummydata = (1024 * "Hello World! ") + myfile = self._CreateTempFile() + utils.WriteFile(myfile, data=dummydata) + datastrict = ReadOneLineFile(myfile, strict=True) + datalax = ReadOneLineFile(myfile, strict=False) + self.assertEqual(dummydata, datastrict) + self.assertEqual(dummydata, datalax) + + def testNewline(self): + myfile = self._CreateTempFile() + myline = "myline" + for nl in ["", "\n", "\r\n"]: + dummydata = "%s%s" % (myline, nl) + utils.WriteFile(myfile, data=dummydata) + datalax = ReadOneLineFile(myfile, strict=False) + self.assertEqual(myline, datalax) + datastrict = ReadOneLineFile(myfile, strict=True) + self.assertEqual(myline, datastrict) + + def testWhitespaceAndMultipleLines(self): + myfile = self._CreateTempFile() + for nl in ["", "\n", "\r\n"]: + for ws in [" ", "\t", "\t\t \t", "\t "]: + dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl))) + utils.WriteFile(myfile, data=dummydata) + datalax = ReadOneLineFile(myfile, strict=False) + if nl: + self.assert_(set("\r\n") & set(dummydata)) + self.assertRaises(errors.GenericError, ReadOneLineFile, + myfile, strict=True) + explen = len("Foo bar baz ") + len(ws) + self.assertEqual(len(datalax), explen) + self.assertEqual(datalax, dummydata[:explen]) + self.assertFalse(set("\r\n") & set(datalax)) + else: + datastrict = ReadOneLineFile(myfile, strict=True) + self.assertEqual(dummydata, datastrict) + self.assertEqual(dummydata, datalax) + + def testEmptylines(self): + myfile = self._CreateTempFile() + myline = "myline" + for nl in ["\n", "\r\n"]: + for ol in ["", "otherline"]: + dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl) + utils.WriteFile(myfile, data=dummydata) + self.assert_(set("\r\n") & set(dummydata)) + datalax = ReadOneLineFile(myfile, strict=False) + self.assertEqual(myline, datalax) + if ol: + self.assertRaises(errors.GenericError, ReadOneLineFile, + myfile, strict=True) + else: + datastrict = ReadOneLineFile(myfile, strict=True) + self.assertEqual(myline, datastrict) + + class TestTimestampForFilename(unittest.TestCase): def test(self): self.assert_("." not in utils.TimestampForFilename()) @@ -533,7 +876,7 @@ class TestCreateBackup(testutils.GanetiTestCase): shutil.rmtree(self.tmpdir) def testEmpty(self): - filename = utils.PathJoin(self.tmpdir, "config.data") + filename = PathJoin(self.tmpdir, "config.data") utils.WriteFile(filename, data="") bname = utils.CreateBackup(filename) self.assertFileContent(bname, "") @@ -543,7 +886,7 @@ class TestCreateBackup(testutils.GanetiTestCase): utils.CreateBackup(filename) self.assertEqual(len(glob.glob("%s*" % filename)), 4) - fifoname = utils.PathJoin(self.tmpdir, "fifo") + fifoname = PathJoin(self.tmpdir, "fifo") os.mkfifo(fifoname) self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname) @@ -553,7 +896,7 @@ class TestCreateBackup(testutils.GanetiTestCase): for rep in [1, 2, 10, 127]: testdata = data * rep - filename = utils.PathJoin(self.tmpdir, "test.data_") + filename = PathJoin(self.tmpdir, "test.data_") utils.WriteFile(filename, data=testdata) self.assertFileContent(filename, testdata) @@ -605,6 +948,7 @@ class TestFormatUnit(unittest.TestCase): self.assertEqual(FormatUnit(5120 * 1024, 't'), '5.0') self.assertEqual(FormatUnit(29829 * 1024, 't'), '29.1') + class TestParseUnit(unittest.TestCase): """Test case for the ParseUnit function""" @@ -651,17 +995,39 @@ class TestParseUnit(unittest.TestCase): def testInvalidInput(self): for sep in ('-', '_', ',', 'a'): for suffix, _ in TestParseUnit.SCALES: - self.assertRaises(UnitParseError, ParseUnit, '1' + sep + suffix) + self.assertRaises(errors.UnitParseError, ParseUnit, '1' + sep + suffix) for suffix, _ in TestParseUnit.SCALES: - self.assertRaises(UnitParseError, ParseUnit, '1,3' + suffix) + self.assertRaises(errors.UnitParseError, ParseUnit, '1,3' + suffix) + + +class TestParseCpuMask(unittest.TestCase): + """Test case for the ParseCpuMask function.""" + def testWellFormed(self): + self.assertEqual(utils.ParseCpuMask(""), []) + self.assertEqual(utils.ParseCpuMask("1"), [1]) + self.assertEqual(utils.ParseCpuMask("0-2,4,5-5"), [0,1,2,4,5]) + + def testInvalidInput(self): + self.assertRaises(errors.ParseError, + utils.ParseCpuMask, + "garbage") + self.assertRaises(errors.ParseError, + utils.ParseCpuMask, + "0,") + self.assertRaises(errors.ParseError, + utils.ParseCpuMask, + "0-1-2") + self.assertRaises(errors.ParseError, + utils.ParseCpuMask, + "2-1") class TestSshKeys(testutils.GanetiTestCase): """Test case for the AddAuthorizedKey function""" KEY_A = 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a' - KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="1.2.3.4" ' + KEY_B = ('command="/usr/bin/fooserver -t --verbose",from="198.51.100.4" ' 'ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b') def setUp(self): @@ -675,48 +1041,49 @@ class TestSshKeys(testutils.GanetiTestCase): handle.close() def testAddingNewKey(self): - AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test') + utils.AddAuthorizedKey(self.tmpname, + 'ssh-dss AAAAB3NzaC1kc3MAAACB root@test') self.assertFileContent(self.tmpname, "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n" - 'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"' + 'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"' " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n" "ssh-dss AAAAB3NzaC1kc3MAAACB root@test\n") def testAddingAlmostButNotCompletelyTheSameKey(self): - AddAuthorizedKey(self.tmpname, + utils.AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test') self.assertFileContent(self.tmpname, "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n" - 'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"' + 'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"' " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n" "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@test\n") def testAddingExistingKeyWithSomeMoreSpaces(self): - AddAuthorizedKey(self.tmpname, + utils.AddAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a') self.assertFileContent(self.tmpname, "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n" - 'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"' + 'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"' " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n") def testRemovingExistingKeyWithSomeMoreSpaces(self): - RemoveAuthorizedKey(self.tmpname, + utils.RemoveAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a') self.assertFileContent(self.tmpname, - 'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"' + 'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"' " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n") def testRemovingNonExistingKey(self): - RemoveAuthorizedKey(self.tmpname, + utils.RemoveAuthorizedKey(self.tmpname, 'ssh-dss AAAAB3Nsdfj230xxjxJjsjwjsjdjU root@test') self.assertFileContent(self.tmpname, "ssh-dss AAAAB3NzaC1w5256closdj32mZaQU root@key-a\n" - 'command="/usr/bin/fooserver -t --verbose",from="1.2.3.4"' + 'command="/usr/bin/fooserver -t --verbose",from="198.51.100.4"' " ssh-dss AAAAB3NzaC1w520smc01ms0jfJs22 root@key-b\n") @@ -730,38 +1097,39 @@ class TestEtcHosts(testutils.GanetiTestCase): try: handle.write('# This is a test file for /etc/hosts\n') handle.write('127.0.0.1\tlocalhost\n') - handle.write('192.168.1.1 router gw\n') + handle.write('192.0.2.1 router gw\n') finally: handle.close() def testSettingNewIp(self): - SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost.domain.tld', ['myhost']) + SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost.example.com', + ['myhost']) self.assertFileContent(self.tmpname, "# This is a test file for /etc/hosts\n" "127.0.0.1\tlocalhost\n" - "192.168.1.1 router gw\n" - "1.2.3.4\tmyhost.domain.tld myhost\n") + "192.0.2.1 router gw\n" + "198.51.100.4\tmyhost.example.com myhost\n") self.assertFileMode(self.tmpname, 0644) def testSettingExistingIp(self): - SetEtcHostsEntry(self.tmpname, '192.168.1.1', 'myhost.domain.tld', + SetEtcHostsEntry(self.tmpname, '192.0.2.1', 'myhost.example.com', ['myhost']) self.assertFileContent(self.tmpname, "# This is a test file for /etc/hosts\n" "127.0.0.1\tlocalhost\n" - "192.168.1.1\tmyhost.domain.tld myhost\n") + "192.0.2.1\tmyhost.example.com myhost\n") self.assertFileMode(self.tmpname, 0644) def testSettingDuplicateName(self): - SetEtcHostsEntry(self.tmpname, '1.2.3.4', 'myhost', ['myhost']) + SetEtcHostsEntry(self.tmpname, '198.51.100.4', 'myhost', ['myhost']) self.assertFileContent(self.tmpname, "# This is a test file for /etc/hosts\n" "127.0.0.1\tlocalhost\n" - "192.168.1.1 router gw\n" - "1.2.3.4\tmyhost\n") + "192.0.2.1 router gw\n" + "198.51.100.4\tmyhost\n") self.assertFileMode(self.tmpname, 0644) def testRemovingExistingHost(self): @@ -770,7 +1138,7 @@ class TestEtcHosts(testutils.GanetiTestCase): self.assertFileContent(self.tmpname, "# This is a test file for /etc/hosts\n" "127.0.0.1\tlocalhost\n" - "192.168.1.1 gw\n") + "192.0.2.1 gw\n") self.assertFileMode(self.tmpname, 0644) def testRemovingSingleExistingHost(self): @@ -778,7 +1146,7 @@ class TestEtcHosts(testutils.GanetiTestCase): self.assertFileContent(self.tmpname, "# This is a test file for /etc/hosts\n" - "192.168.1.1 router gw\n") + "192.0.2.1 router gw\n") self.assertFileMode(self.tmpname, 0644) def testRemovingNonExistingHost(self): @@ -787,7 +1155,7 @@ class TestEtcHosts(testutils.GanetiTestCase): self.assertFileContent(self.tmpname, "# This is a test file for /etc/hosts\n" "127.0.0.1\tlocalhost\n" - "192.168.1.1 router gw\n") + "192.0.2.1 router gw\n") self.assertFileMode(self.tmpname, 0644) def testRemovingAlias(self): @@ -796,10 +1164,31 @@ class TestEtcHosts(testutils.GanetiTestCase): self.assertFileContent(self.tmpname, "# This is a test file for /etc/hosts\n" "127.0.0.1\tlocalhost\n" - "192.168.1.1 router\n") + "192.0.2.1 router\n") self.assertFileMode(self.tmpname, 0644) +class TestGetMounts(unittest.TestCase): + """Test case for GetMounts().""" + + TESTDATA = ( + "rootfs / rootfs rw 0 0\n" + "none /sys sysfs rw,nosuid,nodev,noexec,relatime 0 0\n" + "none /proc proc rw,nosuid,nodev,noexec,relatime 0 0\n") + + def setUp(self): + self.tmpfile = tempfile.NamedTemporaryFile() + utils.WriteFile(self.tmpfile.name, data=self.TESTDATA) + + def testGetMounts(self): + self.assertEqual(utils.GetMounts(filename=self.tmpfile.name), + [ + ("rootfs", "/", "rootfs", "rw"), + ("none", "/sys", "sysfs", "rw,nosuid,nodev,noexec,relatime"), + ("none", "/proc", "proc", "rw,nosuid,nodev,noexec,relatime"), + ]) + + class TestShellQuoting(unittest.TestCase): """Test case for shell quoting functions""" @@ -816,100 +1205,6 @@ class TestShellQuoting(unittest.TestCase): self.assertEqual(ShellQuoteArgs(['a', 'b\'', 'c']), "a 'b'\\\''' c") -class TestTcpPing(unittest.TestCase): - """Testcase for TCP version of ping - against listen(2)ing port""" - - def setUp(self): - self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.listener.bind((constants.LOCALHOST_IP_ADDRESS, 0)) - self.listenerport = self.listener.getsockname()[1] - self.listener.listen(1) - - def tearDown(self): - self.listener.shutdown(socket.SHUT_RDWR) - del self.listener - del self.listenerport - - def testTcpPingToLocalHostAccept(self): - self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS, - self.listenerport, - timeout=10, - live_port_needed=True, - source=constants.LOCALHOST_IP_ADDRESS, - ), - "failed to connect to test listener") - - self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS, - self.listenerport, - timeout=10, - live_port_needed=True, - ), - "failed to connect to test listener (no source)") - - -class TestTcpPingDeaf(unittest.TestCase): - """Testcase for TCP version of ping - against non listen(2)ing port""" - - def setUp(self): - self.deaflistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.deaflistener.bind((constants.LOCALHOST_IP_ADDRESS, 0)) - self.deaflistenerport = self.deaflistener.getsockname()[1] - - def tearDown(self): - del self.deaflistener - del self.deaflistenerport - - def testTcpPingToLocalHostAcceptDeaf(self): - self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS, - self.deaflistenerport, - timeout=constants.TCP_PING_TIMEOUT, - live_port_needed=True, - source=constants.LOCALHOST_IP_ADDRESS, - ), # need successful connect(2) - "successfully connected to deaf listener") - - self.failIf(TcpPing(constants.LOCALHOST_IP_ADDRESS, - self.deaflistenerport, - timeout=constants.TCP_PING_TIMEOUT, - live_port_needed=True, - ), # need successful connect(2) - "successfully connected to deaf listener (no source addr)") - - def testTcpPingToLocalHostNoAccept(self): - self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS, - self.deaflistenerport, - timeout=constants.TCP_PING_TIMEOUT, - live_port_needed=False, - source=constants.LOCALHOST_IP_ADDRESS, - ), # ECONNREFUSED is OK - "failed to ping alive host on deaf port") - - self.assert_(TcpPing(constants.LOCALHOST_IP_ADDRESS, - self.deaflistenerport, - timeout=constants.TCP_PING_TIMEOUT, - live_port_needed=False, - ), # ECONNREFUSED is OK - "failed to ping alive host on deaf port (no source addr)") - - -class TestOwnIpAddress(unittest.TestCase): - """Testcase for OwnIpAddress""" - - def testOwnLoopback(self): - """check having the loopback ip""" - self.failUnless(OwnIpAddress(constants.LOCALHOST_IP_ADDRESS), - "Should own the loopback address") - - def testNowOwnAddress(self): - """check that I don't own an address""" - - # network 192.0.2.0/24 is reserved for test/documentation as per - # rfc 3330, so we *should* not have an address of this range... if - # this fails, we should extend the test to multiple addresses - DST_IP = "192.0.2.1" - self.failIf(OwnIpAddress(DST_IP), "Should not own IP address %s" % DST_IP) - - class TestListVisibleFiles(unittest.TestCase): """Test case for ListVisibleFiles""" @@ -919,22 +1214,14 @@ class TestListVisibleFiles(unittest.TestCase): def tearDown(self): shutil.rmtree(self.path) - def _test(self, files, expected): - # Sort a copy - expected = expected[:] - expected.sort() - + def _CreateFiles(self, files): for name in files: - f = open(os.path.join(self.path, name), 'w') - try: - f.write("Test\n") - finally: - f.close() + utils.WriteFile(os.path.join(self.path, name), data="test") + def _test(self, files, expected): + self._CreateFiles(files) found = ListVisibleFiles(self.path) - found.sort() - - self.assertEqual(found, expected) + self.assertEqual(set(found), set(expected)) def testAllVisible(self): files = ["a", "b", "c"] @@ -962,11 +1249,8 @@ class TestListVisibleFiles(unittest.TestCase): class TestNewUUID(unittest.TestCase): """Test case for NewUUID""" - _re_uuid = re.compile('^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-' - '[a-f0-9]{4}-[a-f0-9]{12}$') - def runTest(self): - self.failUnless(self._re_uuid.match(utils.NewUUID())) + self.failUnless(utils.UUID_RE.match(utils.NewUUID())) class TestUniqueSequence(unittest.TestCase): @@ -1221,13 +1505,14 @@ class TestForceDictType(unittest.TestCase): 'b': constants.VTYPE_BOOL, 'c': constants.VTYPE_STRING, 'd': constants.VTYPE_SIZE, + "e": constants.VTYPE_MAYBE_STRING, } def _fdt(self, dict, allowed_values=None): if allowed_values is None: - ForceDictType(dict, self.key_types) + utils.ForceDictType(dict, self.key_types) else: - ForceDictType(dict, self.key_types, allowed_values=allowed_values) + utils.ForceDictType(dict, self.key_types, allowed_values=allowed_values) return dict @@ -1244,23 +1529,28 @@ class TestForceDictType(unittest.TestCase): self.assertEqual(self._fdt({'b': 'True'}), {'b': True}) self.assertEqual(self._fdt({'d': '4'}), {'d': 4}) self.assertEqual(self._fdt({'d': '4M'}), {'d': 4}) + self.assertEqual(self._fdt({"e": None, }), {"e": None, }) + self.assertEqual(self._fdt({"e": "Hello World", }), {"e": "Hello World", }) + self.assertEqual(self._fdt({"e": False, }), {"e": '', }) def testErrors(self): self.assertRaises(errors.TypeEnforcementError, self._fdt, {'a': 'astring'}) self.assertRaises(errors.TypeEnforcementError, self._fdt, {'c': True}) self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': 'astring'}) self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'}) + self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": object(), }) + self.assertRaises(errors.TypeEnforcementError, self._fdt, {"e": [], }) -class TestIsAbsNormPath(unittest.TestCase): - """Testing case for IsProcessAlive""" +class TestIsNormAbsPath(unittest.TestCase): + """Testing case for IsNormAbsPath""" def _pathTestHelper(self, path, result): if result: - self.assert_(IsNormAbsPath(path), + self.assert_(utils.IsNormAbsPath(path), "Path %s should result absolute and normalized" % path) else: - self.assert_(not IsNormAbsPath(path), + self.assertFalse(utils.IsNormAbsPath(path), "Path %s should not result absolute and normalized" % path) def testBase(self): @@ -1388,6 +1678,53 @@ class TestUnescapeAndSplit(unittest.TestCase): self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b) +class TestGenerateSelfSignedX509Cert(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def _checkRsaPrivateKey(self, key): + lines = key.splitlines() + return ("-----BEGIN RSA PRIVATE KEY-----" in lines and + "-----END RSA PRIVATE KEY-----" in lines) + + def _checkCertificate(self, cert): + lines = cert.splitlines() + return ("-----BEGIN CERTIFICATE-----" in lines and + "-----END CERTIFICATE-----" in lines) + + def test(self): + for common_name in [None, ".", "Ganeti", "node1.example.com"]: + (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300) + self._checkRsaPrivateKey(key_pem) + self._checkCertificate(cert_pem) + + key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, + key_pem) + self.assert_(key.bits() >= 1024) + self.assertEqual(key.bits(), constants.RSA_KEY_BITS) + self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA) + + x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, + cert_pem) + self.failIf(x509.has_expired()) + self.assertEqual(x509.get_issuer().CN, common_name) + self.assertEqual(x509.get_subject().CN, common_name) + self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS) + + def testLegacy(self): + cert1_filename = os.path.join(self.tmpdir, "cert1.pem") + + utils.GenerateSelfSignedSslCert(cert1_filename, validity=1) + + cert1 = utils.ReadFile(cert1_filename) + + self.assert_(self._checkRsaPrivateKey(cert1)) + self.assert_(self._checkCertificate(cert1)) + + class TestPathJoin(unittest.TestCase): """Testing case for PathJoin""" @@ -1405,40 +1742,30 @@ class TestPathJoin(unittest.TestCase): self.failUnlessRaises(ValueError, PathJoin, "/a", "/b") -class TestHostInfo(unittest.TestCase): - """Testing case for HostInfo""" - - def testUppercase(self): - data = "AbC.example.com" - self.failUnlessEqual(HostInfo.NormalizeName(data), data.lower()) - - def testTooLongName(self): - data = "a.b." + "c" * 255 - self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, data) +class TestValidateServiceName(unittest.TestCase): + def testValid(self): + testnames = [ + 0, 1, 2, 3, 1024, 65000, 65534, 65535, + "ganeti", + "gnt-masterd", + "HELLO_WORLD_SVC", + "hello.world.1", + "0", "80", "1111", "65535", + ] - def testTrailingDot(self): - data = "a.b.c" - self.failUnlessEqual(HostInfo.NormalizeName(data + "."), data) + for name in testnames: + self.assertEqual(utils.ValidateServiceName(name), name) - def testInvalidName(self): - data = [ - "a b", - "a/b", - ".a.b", - "a..b", - ] - for value in data: - self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, value) - - def testValidName(self): - data = [ - "a.b", - "a-b", - "a_b", - "a.b.c", + def testInvalid(self): + testnames = [ + -15756, -1, 65536, 133428083, + "", "Hello World!", "!", "'", "\"", "\t", "\n", "`", + "-8546", "-1", "65536", + (129 * "A"), ] - for value in data: - HostInfo.NormalizeName(value) + + for name in testnames: + self.assertRaises(errors.OpPrereqError, utils.ValidateServiceName, name) class TestParseAsn1Generalizedtime(unittest.TestCase): @@ -1504,6 +1831,61 @@ class TestGetX509CertValidity(testutils.GanetiTestCase): self.assertEqual(validity, (None, None)) +class TestSignX509Certificate(unittest.TestCase): + KEY = "My private key!" + KEY_OTHER = "Another key" + + def test(self): + # Generate certificate valid for 5 minutes + (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300) + + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, + cert_pem) + + # No signature at all + self.assertRaises(errors.GenericError, + utils.LoadSignedX509Certificate, cert_pem, self.KEY) + + # Invalid input + self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate, + "", self.KEY) + self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate, + "X-Ganeti-Signature: \n", self.KEY) + self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate, + "X-Ganeti-Sign: $1234$abcdef\n", self.KEY) + self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate, + "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY) + self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate, + "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY) + + # Invalid salt + for salt in list("-_@$,:;/\\ \t\n"): + self.assertRaises(errors.GenericError, utils.SignX509Certificate, + cert_pem, self.KEY, "foo%sbar" % salt) + + for salt in ["HelloWorld", "salt", string.letters, string.digits, + utils.GenerateSecret(numbytes=4), + utils.GenerateSecret(numbytes=16), + "{123:456}".encode("hex")]: + signed_pem = utils.SignX509Certificate(cert, self.KEY, salt) + + self._Check(cert, salt, signed_pem) + + self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem) + self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem) + self._Check(cert, salt, (signed_pem + "\n\na few more\n" + "lines----\n------ at\nthe end!")) + + def _Check(self, cert, salt, pem): + (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY) + self.assertEqual(salt, salt2) + self.assertEqual(cert.digest("sha1"), cert2.digest("sha1")) + + # Other key + self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate, + pem, self.KEY_OTHER) + + class TestMakedirs(unittest.TestCase): def setUp(self): self.tmpdir = tempfile.mkdtemp() @@ -1512,43 +1894,64 @@ class TestMakedirs(unittest.TestCase): shutil.rmtree(self.tmpdir) def testNonExisting(self): - path = utils.PathJoin(self.tmpdir, "foo") + path = PathJoin(self.tmpdir, "foo") utils.Makedirs(path) self.assert_(os.path.isdir(path)) def testExisting(self): - path = utils.PathJoin(self.tmpdir, "foo") + path = PathJoin(self.tmpdir, "foo") os.mkdir(path) utils.Makedirs(path) self.assert_(os.path.isdir(path)) def testRecursiveNonExisting(self): - path = utils.PathJoin(self.tmpdir, "foo/bar/baz") + path = PathJoin(self.tmpdir, "foo/bar/baz") utils.Makedirs(path) self.assert_(os.path.isdir(path)) def testRecursiveExisting(self): - path = utils.PathJoin(self.tmpdir, "B/moo/xyz") - self.assert_(not os.path.exists(path)) - os.mkdir(utils.PathJoin(self.tmpdir, "B")) + path = PathJoin(self.tmpdir, "B/moo/xyz") + self.assertFalse(os.path.exists(path)) + os.mkdir(PathJoin(self.tmpdir, "B")) utils.Makedirs(path) self.assert_(os.path.isdir(path)) class TestRetry(testutils.GanetiTestCase): + def setUp(self): + testutils.GanetiTestCase.setUp(self) + self.retries = 0 + @staticmethod def _RaiseRetryAgain(): raise utils.RetryAgain() + @staticmethod + def _RaiseRetryAgainWithArg(args): + raise utils.RetryAgain(*args) + def _WrongNestedLoop(self): return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02) + def _RetryAndSucceed(self, retries): + if self.retries < retries: + self.retries += 1 + raise utils.RetryAgain() + else: + return True + def testRaiseTimeout(self): self.failUnlessRaises(utils.RetryTimeout, utils.Retry, self._RaiseRetryAgain, 0.01, 0.02) + self.failUnlessRaises(utils.RetryTimeout, utils.Retry, + self._RetryAndSucceed, 0.01, 0, args=[1]) + self.failUnlessEqual(self.retries, 1) def testComplete(self): self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True) + self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]), + True) + self.failUnlessEqual(self.retries, 2) def testNestedLoop(self): try: @@ -1557,6 +1960,45 @@ class TestRetry(testutils.GanetiTestCase): except utils.RetryTimeout: self.fail("Didn't detect inner loop's exception") + def testTimeoutArgument(self): + retry_arg="my_important_debugging_message" + try: + utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]]) + except utils.RetryTimeout, err: + self.failUnlessEqual(err.args, (retry_arg, )) + else: + self.fail("Expected timeout didn't happen") + + def testRaiseInnerWithExc(self): + retry_arg="my_important_debugging_message" + try: + try: + utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, + args=[[errors.GenericError(retry_arg, retry_arg)]]) + except utils.RetryTimeout, err: + err.RaiseInner() + else: + self.fail("Expected timeout didn't happen") + except errors.GenericError, err: + self.failUnlessEqual(err.args, (retry_arg, retry_arg)) + else: + self.fail("Expected GenericError didn't happen") + + def testRaiseInnerWithMsg(self): + retry_arg="my_important_debugging_message" + try: + try: + utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, + args=[[retry_arg, retry_arg]]) + except utils.RetryTimeout, err: + err.RaiseInner() + else: + self.fail("Expected timeout didn't happen") + except utils.RetryTimeout, err: + self.failUnlessEqual(err.args, (retry_arg, retry_arg)) + else: + self.fail("Expected RetryTimeout didn't happen") + class TestLineSplitter(unittest.TestCase): def test(self): @@ -1592,31 +2034,360 @@ class TestLineSplitter(unittest.TestCase): "", "x"]) -class TestPartial(testutils.GanetiTestCase): +class TestReadLockedPidFile(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def testNonExistent(self): + path = PathJoin(self.tmpdir, "nonexist") + self.assert_(utils.ReadLockedPidFile(path) is None) + + def testUnlocked(self): + path = PathJoin(self.tmpdir, "pid") + utils.WriteFile(path, data="123") + self.assert_(utils.ReadLockedPidFile(path) is None) + + def testLocked(self): + path = PathJoin(self.tmpdir, "pid") + utils.WriteFile(path, data="123") + + fl = utils.FileLock.Open(path) + try: + fl.Exclusive(blocking=True) + + self.assertEqual(utils.ReadLockedPidFile(path), 123) + finally: + fl.Close() + + self.assert_(utils.ReadLockedPidFile(path) is None) + + def testError(self): + path = PathJoin(self.tmpdir, "foobar", "pid") + utils.WriteFile(PathJoin(self.tmpdir, "foobar"), data="") + # open(2) should return ENOTDIR + self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path) + + +class TestCertVerification(testutils.GanetiTestCase): + def setUp(self): + testutils.GanetiTestCase.setUp(self) + + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def testVerifyCertificate(self): + cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem")) + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, + cert_pem) + + # Not checking return value as this certificate is expired + utils.VerifyX509Certificate(cert, 30, 7) + + +class TestVerifyCertificateInner(unittest.TestCase): + def test(self): + vci = utils._VerifyCertificateInner + + # Valid + self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7), + (None, None)) + + # Not yet valid + (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7) + self.assertEqual(errcode, utils.CERT_WARNING) + + # Expiring soon + (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7) + self.assertEqual(errcode, utils.CERT_ERROR) + + (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1) + self.assertEqual(errcode, utils.CERT_WARNING) + + (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7) + self.assertEqual(errcode, None) + + # Expired + (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7) + self.assertEqual(errcode, utils.CERT_ERROR) + + (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7) + self.assertEqual(errcode, utils.CERT_ERROR) + + (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7) + self.assertEqual(errcode, utils.CERT_ERROR) + + (errcode, msg) = vci(True, None, None, 1266939600, 30, 7) + self.assertEqual(errcode, utils.CERT_ERROR) + + +class TestHmacFunctions(unittest.TestCase): + # Digests can be checked with "openssl sha1 -hmac $key" + def testSha1Hmac(self): + self.assertEqual(utils.Sha1Hmac("", ""), + "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d") + self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World"), + "ef4f3bda82212ecb2f7ce868888a19092481f1fd") + self.assertEqual(utils.Sha1Hmac("TguMTA2K", ""), + "f904c2476527c6d3e6609ab683c66fa0652cb1dc") + + longtext = 1500 * "The quick brown fox jumps over the lazy dog\n" + self.assertEqual(utils.Sha1Hmac("3YzMxZWE", longtext), + "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54") + + def testSha1HmacSalt(self): + self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc0"), + "4999bf342470eadb11dfcd24ca5680cf9fd7cdce") + self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc9"), + "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8") + self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World", salt="xyz0"), + "7f264f8114c9066afc9bb7636e1786d996d3cc0d") + + def testVerifySha1Hmac(self): + self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b" + "7d64b71fb76370690e1d"))) + self.assert_(utils.VerifySha1Hmac("TguMTA2K", "", + ("f904c2476527c6d3e660" + "9ab683c66fa0652cb1dc"))) + + digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd" + self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", digest)) + self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", + digest.lower())) + self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", + digest.upper())) + self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", + digest.title())) + + def testVerifySha1HmacSalt(self): + self.assert_(utils.VerifySha1Hmac("TguMTA2K", "", + ("17a4adc34d69c0d367d4" + "ffbef96fd41d4df7a6e8"), + salt="abc9")) + self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", + ("7f264f8114c9066afc9b" + "b7636e1786d996d3cc0d"), + salt="xyz0")) + + +class TestIgnoreSignals(unittest.TestCase): + """Test the IgnoreSignals decorator""" + + @staticmethod + def _Raise(exception): + raise exception + + @staticmethod + def _Return(rval): + return rval + + def testIgnoreSignals(self): + sock_err_intr = socket.error(errno.EINTR, "Message") + sock_err_inval = socket.error(errno.EINVAL, "Message") + + env_err_intr = EnvironmentError(errno.EINTR, "Message") + env_err_inval = EnvironmentError(errno.EINVAL, "Message") + + self.assertRaises(socket.error, self._Raise, sock_err_intr) + self.assertRaises(socket.error, self._Raise, sock_err_inval) + self.assertRaises(EnvironmentError, self._Raise, env_err_intr) + self.assertRaises(EnvironmentError, self._Raise, env_err_inval) + + self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None) + self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None) + self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise, + sock_err_inval) + self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise, + env_err_inval) + + self.assertEquals(utils.IgnoreSignals(self._Return, True), True) + self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33) + + +class TestEnsureDirs(unittest.TestCase): + """Tests for EnsureDirs""" + + def setUp(self): + self.dir = tempfile.mkdtemp() + self.old_umask = os.umask(0777) + + def testEnsureDirs(self): + utils.EnsureDirs([ + (PathJoin(self.dir, "foo"), 0777), + (PathJoin(self.dir, "bar"), 0000), + ]) + self.assertEquals(os.stat(PathJoin(self.dir, "foo"))[0] & 0777, 0777) + self.assertEquals(os.stat(PathJoin(self.dir, "bar"))[0] & 0777, 0000) + + def tearDown(self): + os.rmdir(PathJoin(self.dir, "foo")) + os.rmdir(PathJoin(self.dir, "bar")) + os.rmdir(self.dir) + os.umask(self.old_umask) + + +class TestFormatSeconds(unittest.TestCase): + def test(self): + self.assertEqual(utils.FormatSeconds(1), "1s") + self.assertEqual(utils.FormatSeconds(3600), "1h 0m 0s") + self.assertEqual(utils.FormatSeconds(3599), "59m 59s") + self.assertEqual(utils.FormatSeconds(7200), "2h 0m 0s") + self.assertEqual(utils.FormatSeconds(7201), "2h 0m 1s") + self.assertEqual(utils.FormatSeconds(7281), "2h 1m 21s") + self.assertEqual(utils.FormatSeconds(29119), "8h 5m 19s") + self.assertEqual(utils.FormatSeconds(19431228), "224d 21h 33m 48s") + self.assertEqual(utils.FormatSeconds(-1), "-1s") + self.assertEqual(utils.FormatSeconds(-282), "-282s") + self.assertEqual(utils.FormatSeconds(-29119), "-29119s") + + def testFloat(self): + self.assertEqual(utils.FormatSeconds(1.3), "1s") + self.assertEqual(utils.FormatSeconds(1.9), "2s") + self.assertEqual(utils.FormatSeconds(3912.12311), "1h 5m 12s") + self.assertEqual(utils.FormatSeconds(3912.8), "1h 5m 13s") + + +class TestIgnoreProcessNotFound(unittest.TestCase): + @staticmethod + def _WritePid(fd): + os.write(fd, str(os.getpid())) + os.close(fd) + return True + + def test(self): + (pid_read_fd, pid_write_fd) = os.pipe() + + # Start short-lived process which writes its PID to pipe + self.assert_(utils.RunInSeparateProcess(self._WritePid, pid_write_fd)) + os.close(pid_write_fd) + + # Read PID from pipe + pid = int(os.read(pid_read_fd, 1024)) + os.close(pid_read_fd) + + # Try to send signal to process which exited recently + self.assertFalse(utils.IgnoreProcessNotFound(os.kill, pid, 0)) + + +class TestShellWriter(unittest.TestCase): def test(self): - self._Test(utils.partial) - self._Test(utils._partial) + buf = StringIO() + sw = utils.ShellWriter(buf) + sw.Write("#!/bin/bash") + sw.Write("if true; then") + sw.IncIndent() + try: + sw.Write("echo true") - def _Test(self, fn): - def _TestFunc1(x, power=2): - return x ** power + sw.Write("for i in 1 2 3") + sw.Write("do") + sw.IncIndent() + try: + self.assertEqual(sw._indent, 2) + sw.Write("date") + finally: + sw.DecIndent() + sw.Write("done") + finally: + sw.DecIndent() + sw.Write("echo %s", utils.ShellQuote("Hello World")) + sw.Write("exit 0") - cubic = fn(_TestFunc1, power=3) - self.assertEqual(cubic(1), 1) - self.assertEqual(cubic(3), 27) - self.assertEqual(cubic(4), 64) + self.assertEqual(sw._indent, 0) - def _TestFunc2(*args, **kwargs): - return (args, kwargs) + output = buf.getvalue() - self.assertEqualValues(fn(_TestFunc2, "Hello", "World")("Foo"), - (("Hello", "World", "Foo"), {})) + self.assert_(output.endswith("\n")) - self.assertEqualValues(fn(_TestFunc2, "Hello", xyz=123)("Foo"), - (("Hello", "Foo"), {"xyz": 123})) + lines = output.splitlines() + self.assertEqual(len(lines), 9) + self.assertEqual(lines[0], "#!/bin/bash") + self.assert_(re.match(r"^\s+date$", lines[5])) + self.assertEqual(lines[7], "echo 'Hello World'") + + def testEmpty(self): + buf = StringIO() + sw = utils.ShellWriter(buf) + sw = None + self.assertEqual(buf.getvalue(), "") - self.assertEqualValues(fn(_TestFunc2, xyz=123)("Foo", xyz=999), - (("Foo", ), {"xyz": 999,})) + +class TestCommaJoin(unittest.TestCase): + def test(self): + self.assertEqual(utils.CommaJoin([]), "") + self.assertEqual(utils.CommaJoin([1, 2, 3]), "1, 2, 3") + self.assertEqual(utils.CommaJoin(["Hello"]), "Hello") + self.assertEqual(utils.CommaJoin(["Hello", "World"]), "Hello, World") + self.assertEqual(utils.CommaJoin(["Hello", "World", 99]), + "Hello, World, 99") + + +class TestFindMatch(unittest.TestCase): + def test(self): + data = { + "aaaa": "Four A", + "bb": {"Two B": True}, + re.compile(r"^x(foo|bar|bazX)([0-9]+)$"): (1, 2, 3), + } + + self.assertEqual(utils.FindMatch(data, "aaaa"), ("Four A", [])) + self.assertEqual(utils.FindMatch(data, "bb"), ({"Two B": True}, [])) + + for i in ["foo", "bar", "bazX"]: + for j in range(1, 100, 7): + self.assertEqual(utils.FindMatch(data, "x%s%s" % (i, j)), + ((1, 2, 3), [i, str(j)])) + + def testNoMatch(self): + self.assert_(utils.FindMatch({}, "") is None) + self.assert_(utils.FindMatch({}, "foo") is None) + self.assert_(utils.FindMatch({}, 1234) is None) + + data = { + "X": "Hello World", + re.compile("^(something)$"): "Hello World", + } + + self.assert_(utils.FindMatch(data, "") is None) + self.assert_(utils.FindMatch(data, "Hello World") is None) + + +class TestFileID(testutils.GanetiTestCase): + def testEquality(self): + name = self._CreateTempFile() + oldi = utils.GetFileID(path=name) + self.failUnless(utils.VerifyFileID(oldi, oldi)) + + def testUpdate(self): + name = self._CreateTempFile() + oldi = utils.GetFileID(path=name) + os.utime(name, None) + fd = os.open(name, os.O_RDWR) + try: + newi = utils.GetFileID(fd=fd) + self.failUnless(utils.VerifyFileID(oldi, newi)) + self.failUnless(utils.VerifyFileID(newi, oldi)) + finally: + os.close(fd) + + def testWriteFile(self): + name = self._CreateTempFile() + oldi = utils.GetFileID(path=name) + mtime = oldi[2] + os.utime(name, (mtime + 10, mtime + 10)) + self.assertRaises(errors.LockError, utils.SafeWriteFile, name, + oldi, data="") + os.utime(name, (mtime - 10, mtime - 10)) + utils.SafeWriteFile(name, oldi, data="") + oldi = utils.GetFileID(path=name) + mtime = oldi[2] + os.utime(name, (mtime + 10, mtime + 10)) + # this doesn't raise, since we passed None + utils.SafeWriteFile(name, None, data="") if __name__ == '__main__':