X-Git-Url: https://code.grnet.gr/git/ganeti-local/blobdiff_plain/da961187f97344fde390140ebb2f10d10d334d51..39019f912d61d301e593dcbc8e72d5be44f75d7a:/test/ganeti.utils_unittest.py diff --git a/test/ganeti.utils_unittest.py b/test/ganeti.utils_unittest.py index 433bfe3..511e48a 100755 --- a/test/ganeti.utils_unittest.py +++ b/test/ganeti.utils_unittest.py @@ -27,27 +27,37 @@ import time import tempfile import os.path import os -import md5 +import stat import signal import socket import shutil import re import select +import string +import fcntl +import OpenSSL +import warnings +import distutils.version +import glob +import md5 +import errno import ganeti import testutils from ganeti import constants from ganeti import utils from ganeti import errors +from ganeti import serializer from ganeti.utils import IsProcessAlive, RunCmd, \ - RemoveFile, CheckDict, MatchNameComponent, FormatUnit, \ + RemoveFile, MatchNameComponent, FormatUnit, \ ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \ ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \ SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \ - TailFile, ForceDictType, IsNormAbsPath + TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \ + UnescapeAndSplit, RunParts, PathJoin, HostInfo, ReadOneLineFile from ganeti.errors import LockError, UnitParseError, GenericError, \ - ProgrammerError + ProgrammerError, OpPrereqError class TestIsProcessAlive(unittest.TestCase): @@ -229,6 +239,287 @@ class TestRunCmd(testutils.GanetiTestCase): cwd = os.getcwd() self.failUnlessEqual(RunCmd(["pwd"], cwd=cwd).stdout.strip(), cwd) + def testResetEnv(self): + """Test environment reset functionality""" + self.failUnlessEqual(RunCmd(["env"], reset_env=True).stdout.strip(), "") + self.failUnlessEqual(RunCmd(["env"], reset_env=True, + env={"FOO": "bar",}).stdout.strip(), "FOO=bar") + + +class TestRunParts(unittest.TestCase): + """Testing case for the RunParts function""" + + def setUp(self): + self.rundir = tempfile.mkdtemp(prefix="ganeti-test", suffix=".tmp") + + def tearDown(self): + shutil.rmtree(self.rundir) + + def testEmpty(self): + """Test on an empty dir""" + self.failUnlessEqual(RunParts(self.rundir, reset_env=True), []) + + def testSkipWrongName(self): + """Test that wrong files are skipped""" + fname = os.path.join(self.rundir, "00test.dot") + utils.WriteFile(fname, data="") + os.chmod(fname, stat.S_IREAD | stat.S_IEXEC) + relname = os.path.basename(fname) + self.failUnlessEqual(RunParts(self.rundir, reset_env=True), + [(relname, constants.RUNPARTS_SKIP, None)]) + + def testSkipNonExec(self): + """Test that non executable files are skipped""" + fname = os.path.join(self.rundir, "00test") + utils.WriteFile(fname, data="") + relname = os.path.basename(fname) + self.failUnlessEqual(RunParts(self.rundir, reset_env=True), + [(relname, constants.RUNPARTS_SKIP, None)]) + + def testError(self): + """Test error on a broken executable""" + fname = os.path.join(self.rundir, "00test") + utils.WriteFile(fname, data="") + os.chmod(fname, stat.S_IREAD | stat.S_IEXEC) + (relname, status, error) = RunParts(self.rundir, reset_env=True)[0] + self.failUnlessEqual(relname, os.path.basename(fname)) + self.failUnlessEqual(status, constants.RUNPARTS_ERR) + self.failUnless(error) + + def testSorted(self): + """Test executions are sorted""" + files = [] + files.append(os.path.join(self.rundir, "64test")) + files.append(os.path.join(self.rundir, "00test")) + files.append(os.path.join(self.rundir, "42test")) + + for fname in files: + utils.WriteFile(fname, data="") + + results = RunParts(self.rundir, reset_env=True) + + for fname in sorted(files): + self.failUnlessEqual(os.path.basename(fname), results.pop(0)[0]) + + def testOk(self): + """Test correct execution""" + fname = os.path.join(self.rundir, "00test") + utils.WriteFile(fname, data="#!/bin/sh\n\necho -n ciao") + os.chmod(fname, stat.S_IREAD | stat.S_IEXEC) + (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0] + self.failUnlessEqual(relname, os.path.basename(fname)) + self.failUnlessEqual(status, constants.RUNPARTS_RUN) + self.failUnlessEqual(runresult.stdout, "ciao") + + def testRunFail(self): + """Test correct execution, with run failure""" + fname = os.path.join(self.rundir, "00test") + utils.WriteFile(fname, data="#!/bin/sh\n\nexit 1") + os.chmod(fname, stat.S_IREAD | stat.S_IEXEC) + (relname, status, runresult) = RunParts(self.rundir, reset_env=True)[0] + self.failUnlessEqual(relname, os.path.basename(fname)) + self.failUnlessEqual(status, constants.RUNPARTS_RUN) + self.failUnlessEqual(runresult.exit_code, 1) + self.failUnless(runresult.failed) + + def testRunMix(self): + files = [] + files.append(os.path.join(self.rundir, "00test")) + files.append(os.path.join(self.rundir, "42test")) + files.append(os.path.join(self.rundir, "64test")) + files.append(os.path.join(self.rundir, "99test")) + + files.sort() + + # 1st has errors in execution + utils.WriteFile(files[0], data="#!/bin/sh\n\nexit 1") + os.chmod(files[0], stat.S_IREAD | stat.S_IEXEC) + + # 2nd is skipped + utils.WriteFile(files[1], data="") + + # 3rd cannot execute properly + utils.WriteFile(files[2], data="") + os.chmod(files[2], stat.S_IREAD | stat.S_IEXEC) + + # 4th execs + utils.WriteFile(files[3], data="#!/bin/sh\n\necho -n ciao") + os.chmod(files[3], stat.S_IREAD | stat.S_IEXEC) + + results = RunParts(self.rundir, reset_env=True) + + (relname, status, runresult) = results[0] + self.failUnlessEqual(relname, os.path.basename(files[0])) + self.failUnlessEqual(status, constants.RUNPARTS_RUN) + self.failUnlessEqual(runresult.exit_code, 1) + self.failUnless(runresult.failed) + + (relname, status, runresult) = results[1] + self.failUnlessEqual(relname, os.path.basename(files[1])) + self.failUnlessEqual(status, constants.RUNPARTS_SKIP) + self.failUnlessEqual(runresult, None) + + (relname, status, runresult) = results[2] + self.failUnlessEqual(relname, os.path.basename(files[2])) + self.failUnlessEqual(status, constants.RUNPARTS_ERR) + self.failUnless(runresult) + + (relname, status, runresult) = results[3] + self.failUnlessEqual(relname, os.path.basename(files[3])) + self.failUnlessEqual(status, constants.RUNPARTS_RUN) + self.failUnlessEqual(runresult.output, "ciao") + self.failUnlessEqual(runresult.exit_code, 0) + self.failUnless(not runresult.failed) + + +class TestStartDaemon(testutils.GanetiTestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp(prefix="ganeti-test") + self.tmpfile = os.path.join(self.tmpdir, "test") + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def testShell(self): + utils.StartDaemon("echo Hello World > %s" % self.tmpfile) + self._wait(self.tmpfile, 60.0, "Hello World") + + def testShellOutput(self): + utils.StartDaemon("echo Hello World", output=self.tmpfile) + self._wait(self.tmpfile, 60.0, "Hello World") + + def testNoShellNoOutput(self): + utils.StartDaemon(["pwd"]) + + def testNoShellNoOutputTouch(self): + testfile = os.path.join(self.tmpdir, "check") + self.failIf(os.path.exists(testfile)) + utils.StartDaemon(["touch", testfile]) + self._wait(testfile, 60.0, "") + + def testNoShellOutput(self): + utils.StartDaemon(["pwd"], output=self.tmpfile) + self._wait(self.tmpfile, 60.0, "/") + + def testNoShellOutputCwd(self): + utils.StartDaemon(["pwd"], output=self.tmpfile, cwd=os.getcwd()) + self._wait(self.tmpfile, 60.0, os.getcwd()) + + def testShellEnv(self): + utils.StartDaemon("echo \"$GNT_TEST_VAR\"", output=self.tmpfile, + env={ "GNT_TEST_VAR": "Hello World", }) + self._wait(self.tmpfile, 60.0, "Hello World") + + def testNoShellEnv(self): + utils.StartDaemon(["printenv", "GNT_TEST_VAR"], output=self.tmpfile, + env={ "GNT_TEST_VAR": "Hello World", }) + self._wait(self.tmpfile, 60.0, "Hello World") + + def testOutputFd(self): + fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT) + try: + utils.StartDaemon(["pwd"], output_fd=fd, cwd=os.getcwd()) + finally: + os.close(fd) + self._wait(self.tmpfile, 60.0, os.getcwd()) + + def testPid(self): + pid = utils.StartDaemon("echo $$ > %s" % self.tmpfile) + self._wait(self.tmpfile, 60.0, str(pid)) + + def testPidFile(self): + pidfile = os.path.join(self.tmpdir, "pid") + checkfile = os.path.join(self.tmpdir, "abort") + + pid = utils.StartDaemon("while sleep 5; do :; done", pidfile=pidfile, + output=self.tmpfile) + try: + fd = os.open(pidfile, os.O_RDONLY) + try: + # Check file is locked + self.assertRaises(errors.LockError, utils.LockFile, fd) + + pidtext = os.read(fd, 100) + finally: + os.close(fd) + + self.assertEqual(int(pidtext.strip()), pid) + + self.assert_(utils.IsProcessAlive(pid)) + finally: + # No matter what happens, kill daemon + utils.KillProcess(pid, timeout=5.0, waitpid=False) + self.failIf(utils.IsProcessAlive(pid)) + + self.assertEqual(utils.ReadFile(self.tmpfile), "") + + def _wait(self, path, timeout, expected): + # Due to the asynchronous nature of daemon processes, polling is necessary. + # A timeout makes sure the test doesn't hang forever. + def _CheckFile(): + if not (os.path.isfile(path) and + utils.ReadFile(path).strip() == expected): + raise utils.RetryAgain() + + try: + utils.Retry(_CheckFile, (0.01, 1.5, 1.0), timeout) + except utils.RetryTimeout: + self.fail("Apparently the daemon didn't run in %s seconds and/or" + " didn't write the correct output" % timeout) + + def testError(self): + self.assertRaises(errors.OpExecError, utils.StartDaemon, + ["./does-NOT-EXIST/here/0123456789"]) + self.assertRaises(errors.OpExecError, utils.StartDaemon, + ["./does-NOT-EXIST/here/0123456789"], + output=os.path.join(self.tmpdir, "DIR/NOT/EXIST")) + self.assertRaises(errors.OpExecError, utils.StartDaemon, + ["./does-NOT-EXIST/here/0123456789"], + cwd=os.path.join(self.tmpdir, "DIR/NOT/EXIST")) + self.assertRaises(errors.OpExecError, utils.StartDaemon, + ["./does-NOT-EXIST/here/0123456789"], + output=os.path.join(self.tmpdir, "DIR/NOT/EXIST")) + + fd = os.open(self.tmpfile, os.O_WRONLY | os.O_CREAT) + try: + self.assertRaises(errors.ProgrammerError, utils.StartDaemon, + ["./does-NOT-EXIST/here/0123456789"], + output=self.tmpfile, output_fd=fd) + finally: + os.close(fd) + + +class TestSetCloseOnExecFlag(unittest.TestCase): + """Tests for SetCloseOnExecFlag""" + + def setUp(self): + self.tmpfile = tempfile.TemporaryFile() + + def testEnable(self): + utils.SetCloseOnExecFlag(self.tmpfile.fileno(), True) + self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) & + fcntl.FD_CLOEXEC) + + def testDisable(self): + utils.SetCloseOnExecFlag(self.tmpfile.fileno(), False) + self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFD) & + fcntl.FD_CLOEXEC) + + +class TestSetNonblockFlag(unittest.TestCase): + def setUp(self): + self.tmpfile = tempfile.TemporaryFile() + + def testEnable(self): + utils.SetNonblockFlag(self.tmpfile.fileno(), True) + self.failUnless(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) & + os.O_NONBLOCK) + + def testDisable(self): + utils.SetNonblockFlag(self.tmpfile.fileno(), False) + self.failIf(fcntl.fcntl(self.tmpfile.fileno(), fcntl.F_GETFL) & + os.O_NONBLOCK) + class TestRemoveFile(unittest.TestCase): """Test case for the RemoveFile function""" @@ -244,25 +535,21 @@ class TestRemoveFile(unittest.TestCase): os.unlink(self.tmpfile) os.rmdir(self.tmpdir) - def testIgnoreDirs(self): """Test that RemoveFile() ignores directories""" self.assertEqual(None, RemoveFile(self.tmpdir)) - def testIgnoreNotExisting(self): """Test that RemoveFile() ignores non-existing files""" RemoveFile(self.tmpfile) RemoveFile(self.tmpfile) - def testRemoveFile(self): """Test that RemoveFile does remove a file""" RemoveFile(self.tmpfile) if os.path.exists(self.tmpfile): self.fail("File '%s' not removed" % self.tmpfile) - def testRemoveSymlink(self): """Test that RemoveFile does remove symlinks""" symlink = self.tmpdir + "/symlink" @@ -294,37 +581,27 @@ class TestRename(unittest.TestCase): def testSimpleRename1(self): """Simple rename 1""" utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz")) + self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz"))) def testSimpleRename2(self): """Simple rename 2""" utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "xyz"), mkdir=True) + self.assert_(os.path.isfile(os.path.join(self.tmpdir, "xyz"))) def testRenameMkdir(self): """Rename with mkdir""" utils.RenameFile(self.tmpfile, os.path.join(self.tmpdir, "test/xyz"), mkdir=True) + self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test"))) + self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/xyz"))) - -class TestCheckdict(unittest.TestCase): - """Test case for the CheckDict function""" - - def testAdd(self): - """Test that CheckDict adds a missing key with the correct value""" - - tgt = {'a':1} - tmpl = {'b': 2} - CheckDict(tgt, tmpl) - if 'b' not in tgt or tgt['b'] != 2: - self.fail("Failed to update dict") - - - def testNoUpdate(self): - """Test that CheckDict does not overwrite an existing key""" - tgt = {'a':1, 'b': 3} - tmpl = {'b': 2} - CheckDict(tgt, tmpl) - self.failUnlessEqual(tgt['b'], 3) + utils.RenameFile(os.path.join(self.tmpdir, "test/xyz"), + os.path.join(self.tmpdir, "test/foo/bar/baz"), + mkdir=True) + self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test"))) + self.assert_(os.path.isdir(os.path.join(self.tmpdir, "test/foo/bar"))) + self.assert_(os.path.isfile(os.path.join(self.tmpdir, "test/foo/bar/baz"))) class TestMatchNameComponent(unittest.TestCase): @@ -348,6 +625,194 @@ class TestMatchNameComponent(unittest.TestCase): for key in "test1", "test1.example": self.failUnlessEqual(MatchNameComponent(key, mlist), None) + def testFullMatch(self): + """Test that a full match is returned correctly""" + key1 = "test1" + key2 = "test1.example" + mlist = [key2, key2 + ".com"] + self.failUnlessEqual(MatchNameComponent(key1, mlist), None) + self.failUnlessEqual(MatchNameComponent(key2, mlist), key2) + + def testCaseInsensitivePartialMatch(self): + """Test for the case_insensitive keyword""" + mlist = ["test1.example.com", "test2.example.net"] + self.assertEqual(MatchNameComponent("test2", mlist, case_sensitive=False), + "test2.example.net") + self.assertEqual(MatchNameComponent("Test2", mlist, case_sensitive=False), + "test2.example.net") + self.assertEqual(MatchNameComponent("teSt2", mlist, case_sensitive=False), + "test2.example.net") + self.assertEqual(MatchNameComponent("TeSt2", mlist, case_sensitive=False), + "test2.example.net") + + + def testCaseInsensitiveFullMatch(self): + mlist = ["ts1.ex", "ts1.ex.org", "ts2.ex", "Ts2.ex"] + # Between the two ts1 a full string match non-case insensitive should work + self.assertEqual(MatchNameComponent("Ts1", mlist, case_sensitive=False), + None) + self.assertEqual(MatchNameComponent("Ts1.ex", mlist, case_sensitive=False), + "ts1.ex") + self.assertEqual(MatchNameComponent("ts1.ex", mlist, case_sensitive=False), + "ts1.ex") + # Between the two ts2 only case differs, so only case-match works + self.assertEqual(MatchNameComponent("ts2.ex", mlist, case_sensitive=False), + "ts2.ex") + self.assertEqual(MatchNameComponent("Ts2.ex", mlist, case_sensitive=False), + "Ts2.ex") + self.assertEqual(MatchNameComponent("TS2.ex", mlist, case_sensitive=False), + None) + + +class TestReadFile(testutils.GanetiTestCase): + + def testReadAll(self): + data = utils.ReadFile(self._TestDataFilename("cert1.pem")) + self.assertEqual(len(data), 814) + + h = md5.new() + h.update(data) + self.assertEqual(h.hexdigest(), "a491efb3efe56a0535f924d5f8680fd4") + + def testReadSize(self): + data = utils.ReadFile(self._TestDataFilename("cert1.pem"), + size=100) + self.assertEqual(len(data), 100) + + h = md5.new() + h.update(data) + self.assertEqual(h.hexdigest(), "893772354e4e690b9efd073eed433ce7") + + def testError(self): + self.assertRaises(EnvironmentError, utils.ReadFile, + "/dev/null/does-not-exist") + + +class TestReadOneLineFile(testutils.GanetiTestCase): + + def setUp(self): + testutils.GanetiTestCase.setUp(self) + + def testDefault(self): + data = ReadOneLineFile(self._TestDataFilename("cert1.pem")) + self.assertEqual(len(data), 27) + self.assertEqual(data, "-----BEGIN CERTIFICATE-----") + + def testNotStrict(self): + data = ReadOneLineFile(self._TestDataFilename("cert1.pem"), strict=False) + self.assertEqual(len(data), 27) + self.assertEqual(data, "-----BEGIN CERTIFICATE-----") + + def testStrictFailure(self): + self.assertRaises(errors.GenericError, ReadOneLineFile, + self._TestDataFilename("cert1.pem"), strict=True) + + def testLongLine(self): + dummydata = (1024 * "Hello World! ") + myfile = self._CreateTempFile() + utils.WriteFile(myfile, data=dummydata) + datastrict = ReadOneLineFile(myfile, strict=True) + datalax = ReadOneLineFile(myfile, strict=False) + self.assertEqual(dummydata, datastrict) + self.assertEqual(dummydata, datalax) + + def testNewline(self): + myfile = self._CreateTempFile() + myline = "myline" + for nl in ["", "\n", "\r\n"]: + dummydata = "%s%s" % (myline, nl) + utils.WriteFile(myfile, data=dummydata) + datalax = ReadOneLineFile(myfile, strict=False) + self.assertEqual(myline, datalax) + datastrict = ReadOneLineFile(myfile, strict=True) + self.assertEqual(myline, datastrict) + + def testWhitespaceAndMultipleLines(self): + myfile = self._CreateTempFile() + for nl in ["", "\n", "\r\n"]: + for ws in [" ", "\t", "\t\t \t", "\t "]: + dummydata = (1024 * ("Foo bar baz %s%s" % (ws, nl))) + utils.WriteFile(myfile, data=dummydata) + datalax = ReadOneLineFile(myfile, strict=False) + if nl: + self.assert_(set("\r\n") & set(dummydata)) + self.assertRaises(errors.GenericError, ReadOneLineFile, + myfile, strict=True) + explen = len("Foo bar baz ") + len(ws) + self.assertEqual(len(datalax), explen) + self.assertEqual(datalax, dummydata[:explen]) + self.assertFalse(set("\r\n") & set(datalax)) + else: + datastrict = ReadOneLineFile(myfile, strict=True) + self.assertEqual(dummydata, datastrict) + self.assertEqual(dummydata, datalax) + + def testEmptylines(self): + myfile = self._CreateTempFile() + myline = "myline" + for nl in ["\n", "\r\n"]: + for ol in ["", "otherline"]: + dummydata = "%s%s%s%s%s%s" % (nl, nl, myline, nl, ol, nl) + utils.WriteFile(myfile, data=dummydata) + self.assert_(set("\r\n") & set(dummydata)) + datalax = ReadOneLineFile(myfile, strict=False) + self.assertEqual(myline, datalax) + if ol: + self.assertRaises(errors.GenericError, ReadOneLineFile, + myfile, strict=True) + else: + datastrict = ReadOneLineFile(myfile, strict=True) + self.assertEqual(myline, datastrict) + + +class TestTimestampForFilename(unittest.TestCase): + def test(self): + self.assert_("." not in utils.TimestampForFilename()) + self.assert_(":" not in utils.TimestampForFilename()) + + +class TestCreateBackup(testutils.GanetiTestCase): + def setUp(self): + testutils.GanetiTestCase.setUp(self) + + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + testutils.GanetiTestCase.tearDown(self) + + shutil.rmtree(self.tmpdir) + + def testEmpty(self): + filename = utils.PathJoin(self.tmpdir, "config.data") + utils.WriteFile(filename, data="") + bname = utils.CreateBackup(filename) + self.assertFileContent(bname, "") + self.assertEqual(len(glob.glob("%s*" % filename)), 2) + utils.CreateBackup(filename) + self.assertEqual(len(glob.glob("%s*" % filename)), 3) + utils.CreateBackup(filename) + self.assertEqual(len(glob.glob("%s*" % filename)), 4) + + fifoname = utils.PathJoin(self.tmpdir, "fifo") + os.mkfifo(fifoname) + self.assertRaises(errors.ProgrammerError, utils.CreateBackup, fifoname) + + def testContent(self): + bkpcount = 0 + for data in ["", "X", "Hello World!\n" * 100, "Binary data\0\x01\x02\n"]: + for rep in [1, 2, 10, 127]: + testdata = data * rep + + filename = utils.PathJoin(self.tmpdir, "test.data_") + utils.WriteFile(filename, data=testdata) + self.assertFileContent(filename, testdata) + + for _ in range(3): + bname = utils.CreateBackup(filename) + bkpcount += 1 + self.assertFileContent(bname, testdata) + self.assertEqual(len(glob.glob("%s*" % filename)), 1 + bkpcount) + class TestFormatUnit(unittest.TestCase): """Test case for the FormatUnit function""" @@ -688,13 +1153,80 @@ class TestOwnIpAddress(unittest.TestCase): def testNowOwnAddress(self): """check that I don't own an address""" - # network 192.0.2.0/24 is reserved for test/documentation as per - # rfc 3330, so we *should* not have an address of this range... if + # Network 192.0.2.0/24 is reserved for test/documentation as per + # RFC 5735, so we *should* not have an address of this range... if # this fails, we should extend the test to multiple addresses DST_IP = "192.0.2.1" self.failIf(OwnIpAddress(DST_IP), "Should not own IP address %s" % DST_IP) +def _GetSocketCredentials(path): + """Connect to a Unix socket and return remote credentials. + + """ + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + sock.settimeout(10) + sock.connect(path) + return utils.GetSocketCredentials(sock) + finally: + sock.close() + + +class TestGetSocketCredentials(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + self.sockpath = utils.PathJoin(self.tmpdir, "sock") + + self.listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.listener.settimeout(10) + self.listener.bind(self.sockpath) + self.listener.listen(1) + + def tearDown(self): + self.listener.shutdown(socket.SHUT_RDWR) + self.listener.close() + shutil.rmtree(self.tmpdir) + + def test(self): + (c2pr, c2pw) = os.pipe() + + # Start child process + child = os.fork() + if child == 0: + try: + data = serializer.DumpJson(_GetSocketCredentials(self.sockpath)) + + os.write(c2pw, data) + os.close(c2pw) + + os._exit(0) + finally: + os._exit(1) + + os.close(c2pw) + + # Wait for one connection + (conn, _) = self.listener.accept() + conn.recv(1) + conn.close() + + # Wait for result + result = os.read(c2pr, 4096) + os.close(c2pr) + + # Check child's exit code + (_, status) = os.waitpid(child, 0) + self.assertFalse(os.WIFSIGNALED(status)) + self.assertEqual(os.WEXITSTATUS(status), 0) + + # Check result + (pid, uid, gid) = serializer.LoadJson(result) + self.assertEqual(pid, os.getpid()) + self.assertEqual(uid, os.getuid()) + self.assertEqual(gid, os.getgid()) + + class TestListVisibleFiles(unittest.TestCase): """Test case for ListVisibleFiles""" @@ -736,6 +1268,13 @@ class TestListVisibleFiles(unittest.TestCase): expected = ["a", "b"] self._test(files, expected) + def testNonAbsolutePath(self): + self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, "abc") + + def testNonNormalizedPath(self): + self.failUnlessRaises(errors.ProgrammerError, ListVisibleFiles, + "/bin/../tmp") + class TestNewUUID(unittest.TestCase): """Test case for NewUUID""" @@ -824,13 +1363,9 @@ class TestTailFile(testutils.GanetiTestCase): self.failUnlessEqual(TailFile(fname, lines=i), data[-i:]) -class TestFileLock(unittest.TestCase): +class _BaseFileLockTest: """Test case for the FileLock class""" - def setUp(self): - self.tmpfile = tempfile.NamedTemporaryFile() - self.lock = utils.FileLock(self.tmpfile.name) - def testSharedNonblocking(self): self.lock.Shared(blocking=False) self.lock.Close() @@ -867,6 +1402,45 @@ class TestFileLock(unittest.TestCase): self.lock.Unlock(blocking=False) self.lock.Close() + def testSimpleTimeout(self): + # These will succeed on the first attempt, hence a short timeout + self.lock.Shared(blocking=True, timeout=10.0) + self.lock.Exclusive(blocking=False, timeout=10.0) + self.lock.Unlock(blocking=True, timeout=10.0) + self.lock.Close() + + @staticmethod + def _TryLockInner(filename, shared, blocking): + lock = utils.FileLock.Open(filename) + + if shared: + fn = lock.Shared + else: + fn = lock.Exclusive + + try: + # The timeout doesn't really matter as the parent process waits for us to + # finish anyway. + fn(blocking=blocking, timeout=0.01) + except errors.LockError, err: + return False + + return True + + def _TryLock(self, *args): + return utils.RunInSeparateProcess(self._TryLockInner, self.tmpfile.name, + *args) + + def testTimeout(self): + for blocking in [True, False]: + self.lock.Exclusive(blocking=True) + self.failIf(self._TryLock(False, blocking)) + self.failIf(self._TryLock(True, blocking)) + + self.lock.Shared(blocking=True) + self.assert_(self._TryLock(True, blocking)) + self.failIf(self._TryLock(False, blocking)) + def testCloseShared(self): self.lock.Close() self.assertRaises(AssertionError, self.lock.Shared, blocking=False) @@ -880,6 +1454,31 @@ class TestFileLock(unittest.TestCase): self.assertRaises(AssertionError, self.lock.Unlock, blocking=False) +class TestFileLockWithFilename(testutils.GanetiTestCase, _BaseFileLockTest): + TESTDATA = "Hello World\n" * 10 + + def setUp(self): + testutils.GanetiTestCase.setUp(self) + + self.tmpfile = tempfile.NamedTemporaryFile() + utils.WriteFile(self.tmpfile.name, data=self.TESTDATA) + self.lock = utils.FileLock.Open(self.tmpfile.name) + + # Ensure "Open" didn't truncate file + self.assertFileContent(self.tmpfile.name, self.TESTDATA) + + def tearDown(self): + self.assertFileContent(self.tmpfile.name, self.TESTDATA) + + testutils.GanetiTestCase.tearDown(self) + + +class TestFileLockWithFileObject(unittest.TestCase, _BaseFileLockTest): + def setUp(self): + self.tmpfile = tempfile.NamedTemporaryFile() + self.lock = utils.FileLock(open(self.tmpfile.name, "w"), self.tmpfile.name) + + class TestTimeFunctions(unittest.TestCase): """Test case for time functions""" @@ -899,7 +1498,8 @@ class TestTimeFunctions(unittest.TestCase): self.assertEqual(utils.MergeTime((1, 500000)), 1.5) self.assertEqual(utils.MergeTime((1218448917, 500000)), 1218448917.5) - self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3), 1218448917.481) + self.assertEqual(round(utils.MergeTime((1218448917, 481000)), 3), + 1218448917.481) self.assertEqual(round(utils.MergeTime((1, 801000)), 3), 1.801) self.assertRaises(AssertionError, utils.MergeTime, (0, -1)) @@ -969,16 +1569,16 @@ class TestForceDictType(unittest.TestCase): self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'}) -class TestIsAbsNormPath(unittest.TestCase): - """Testing case for IsProcessAlive""" +class TestIsNormAbsPath(unittest.TestCase): + """Testing case for IsNormAbsPath""" def _pathTestHelper(self, path, result): if result: self.assert_(IsNormAbsPath(path), - "Path %s should be absolute and normal" % path) + "Path %s should result absolute and normalized" % path) else: self.assert_(not IsNormAbsPath(path), - "Path %s should not be absolute and normal" % path) + "Path %s should not result absolute and normalized" % path) def testBase(self): self._pathTestHelper('/etc', True) @@ -987,5 +1587,665 @@ class TestIsAbsNormPath(unittest.TestCase): self._pathTestHelper('/etc/../root', False) self._pathTestHelper('/etc/', False) + +class TestSafeEncode(unittest.TestCase): + """Test case for SafeEncode""" + + def testAscii(self): + for txt in [string.digits, string.letters, string.punctuation]: + self.failUnlessEqual(txt, SafeEncode(txt)) + + def testDoubleEncode(self): + for i in range(255): + txt = SafeEncode(chr(i)) + self.failUnlessEqual(txt, SafeEncode(txt)) + + def testUnicode(self): + # 1024 is high enough to catch non-direct ASCII mappings + for i in range(1024): + txt = SafeEncode(unichr(i)) + self.failUnlessEqual(txt, SafeEncode(txt)) + + +class TestFormatTime(unittest.TestCase): + """Testing case for FormatTime""" + + def testNone(self): + self.failUnlessEqual(FormatTime(None), "N/A") + + def testInvalid(self): + self.failUnlessEqual(FormatTime(()), "N/A") + + def testNow(self): + # tests that we accept time.time input + FormatTime(time.time()) + # tests that we accept int input + FormatTime(int(time.time())) + + +class RunInSeparateProcess(unittest.TestCase): + def test(self): + for exp in [True, False]: + def _child(): + return exp + + self.assertEqual(exp, utils.RunInSeparateProcess(_child)) + + def testArgs(self): + for arg in [0, 1, 999, "Hello World", (1, 2, 3)]: + def _child(carg1, carg2): + return carg1 == "Foo" and carg2 == arg + + self.assert_(utils.RunInSeparateProcess(_child, "Foo", arg)) + + def testPid(self): + parent_pid = os.getpid() + + def _check(): + return os.getpid() == parent_pid + + self.failIf(utils.RunInSeparateProcess(_check)) + + def testSignal(self): + def _kill(): + os.kill(os.getpid(), signal.SIGTERM) + + self.assertRaises(errors.GenericError, + utils.RunInSeparateProcess, _kill) + + def testException(self): + def _exc(): + raise errors.GenericError("This is a test") + + self.assertRaises(errors.GenericError, + utils.RunInSeparateProcess, _exc) + + +class TestFingerprintFile(unittest.TestCase): + def setUp(self): + self.tmpfile = tempfile.NamedTemporaryFile() + + def test(self): + self.assertEqual(utils._FingerprintFile(self.tmpfile.name), + "da39a3ee5e6b4b0d3255bfef95601890afd80709") + + utils.WriteFile(self.tmpfile.name, data="Hello World\n") + self.assertEqual(utils._FingerprintFile(self.tmpfile.name), + "648a6a6ffffdaa0badb23b8baf90b6168dd16b3a") + + +class TestUnescapeAndSplit(unittest.TestCase): + """Testing case for UnescapeAndSplit""" + + def setUp(self): + # testing more that one separator for regexp safety + self._seps = [",", "+", "."] + + def testSimple(self): + a = ["a", "b", "c", "d"] + for sep in self._seps: + self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), a) + + def testEscape(self): + for sep in self._seps: + a = ["a", "b\\" + sep + "c", "d"] + b = ["a", "b" + sep + "c", "d"] + self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b) + + def testDoubleEscape(self): + for sep in self._seps: + a = ["a", "b\\\\", "c", "d"] + b = ["a", "b\\", "c", "d"] + self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b) + + def testThreeEscape(self): + for sep in self._seps: + a = ["a", "b\\\\\\" + sep + "c", "d"] + b = ["a", "b\\" + sep + "c", "d"] + self.failUnlessEqual(UnescapeAndSplit(sep.join(a), sep=sep), b) + + +class TestGenerateSelfSignedX509Cert(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def _checkRsaPrivateKey(self, key): + lines = key.splitlines() + return ("-----BEGIN RSA PRIVATE KEY-----" in lines and + "-----END RSA PRIVATE KEY-----" in lines) + + def _checkCertificate(self, cert): + lines = cert.splitlines() + return ("-----BEGIN CERTIFICATE-----" in lines and + "-----END CERTIFICATE-----" in lines) + + def test(self): + for common_name in [None, ".", "Ganeti", "node1.example.com"]: + (key_pem, cert_pem) = utils.GenerateSelfSignedX509Cert(common_name, 300) + self._checkRsaPrivateKey(key_pem) + self._checkCertificate(cert_pem) + + key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, + key_pem) + self.assert_(key.bits() >= 1024) + self.assertEqual(key.bits(), constants.RSA_KEY_BITS) + self.assertEqual(key.type(), OpenSSL.crypto.TYPE_RSA) + + x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, + cert_pem) + self.failIf(x509.has_expired()) + self.assertEqual(x509.get_issuer().CN, common_name) + self.assertEqual(x509.get_subject().CN, common_name) + self.assertEqual(x509.get_pubkey().bits(), constants.RSA_KEY_BITS) + + def testLegacy(self): + cert1_filename = os.path.join(self.tmpdir, "cert1.pem") + + utils.GenerateSelfSignedSslCert(cert1_filename, validity=1) + + cert1 = utils.ReadFile(cert1_filename) + + self.assert_(self._checkRsaPrivateKey(cert1)) + self.assert_(self._checkCertificate(cert1)) + + +class TestPathJoin(unittest.TestCase): + """Testing case for PathJoin""" + + def testBasicItems(self): + mlist = ["/a", "b", "c"] + self.failUnlessEqual(PathJoin(*mlist), "/".join(mlist)) + + def testNonAbsPrefix(self): + self.failUnlessRaises(ValueError, PathJoin, "a", "b") + + def testBackTrack(self): + self.failUnlessRaises(ValueError, PathJoin, "/a", "b/../c") + + def testMultiAbs(self): + self.failUnlessRaises(ValueError, PathJoin, "/a", "/b") + + +class TestHostInfo(unittest.TestCase): + """Testing case for HostInfo""" + + def testUppercase(self): + data = "AbC.example.com" + self.failUnlessEqual(HostInfo.NormalizeName(data), data.lower()) + + def testTooLongName(self): + data = "a.b." + "c" * 255 + self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, data) + + def testTrailingDot(self): + data = "a.b.c" + self.failUnlessEqual(HostInfo.NormalizeName(data + "."), data) + + def testInvalidName(self): + data = [ + "a b", + "a/b", + ".a.b", + "a..b", + ] + for value in data: + self.failUnlessRaises(OpPrereqError, HostInfo.NormalizeName, value) + + def testValidName(self): + data = [ + "a.b", + "a-b", + "a_b", + "a.b.c", + ] + for value in data: + HostInfo.NormalizeName(value) + + +class TestParseAsn1Generalizedtime(unittest.TestCase): + def test(self): + # UTC + self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000Z"), 0) + self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152Z"), + 1266860512) + self.assertEqual(utils._ParseAsn1Generalizedtime("20380119031407Z"), + (2**31) - 1) + + # With offset + self.assertEqual(utils._ParseAsn1Generalizedtime("20100222174152+0000"), + 1266860512) + self.assertEqual(utils._ParseAsn1Generalizedtime("20100223131652+0000"), + 1266931012) + self.assertEqual(utils._ParseAsn1Generalizedtime("20100223051808-0800"), + 1266931088) + self.assertEqual(utils._ParseAsn1Generalizedtime("20100224002135+1100"), + 1266931295) + self.assertEqual(utils._ParseAsn1Generalizedtime("19700101000000-0100"), + 3600) + + # Leap seconds are not supported by datetime.datetime + self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, + "19841231235960+0000") + self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, + "19920630235960+0000") + + # Errors + self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "") + self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, "invalid") + self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, + "20100222174152") + self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, + "Mon Feb 22 17:47:02 UTC 2010") + self.assertRaises(ValueError, utils._ParseAsn1Generalizedtime, + "2010-02-22 17:42:02") + + +class TestGetX509CertValidity(testutils.GanetiTestCase): + def setUp(self): + testutils.GanetiTestCase.setUp(self) + + pyopenssl_version = distutils.version.LooseVersion(OpenSSL.__version__) + + # Test whether we have pyOpenSSL 0.7 or above + self.pyopenssl0_7 = (pyopenssl_version >= "0.7") + + if not self.pyopenssl0_7: + warnings.warn("This test requires pyOpenSSL 0.7 or above to" + " function correctly") + + def _LoadCert(self, name): + return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, + self._ReadTestData(name)) + + def test(self): + validity = utils.GetX509CertValidity(self._LoadCert("cert1.pem")) + if self.pyopenssl0_7: + self.assertEqual(validity, (1266919967, 1267524767)) + else: + self.assertEqual(validity, (None, None)) + + +class TestSignX509Certificate(unittest.TestCase): + KEY = "My private key!" + KEY_OTHER = "Another key" + + def test(self): + # Generate certificate valid for 5 minutes + (_, cert_pem) = utils.GenerateSelfSignedX509Cert(None, 300) + + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, + cert_pem) + + # No signature at all + self.assertRaises(errors.GenericError, + utils.LoadSignedX509Certificate, cert_pem, self.KEY) + + # Invalid input + self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate, + "", self.KEY) + self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate, + "X-Ganeti-Signature: \n", self.KEY) + self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate, + "X-Ganeti-Sign: $1234$abcdef\n", self.KEY) + self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate, + "X-Ganeti-Signature: $1234567890$abcdef\n", self.KEY) + self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate, + "X-Ganeti-Signature: $1234$abc\n\n" + cert_pem, self.KEY) + + # Invalid salt + for salt in list("-_@$,:;/\\ \t\n"): + self.assertRaises(errors.GenericError, utils.SignX509Certificate, + cert_pem, self.KEY, "foo%sbar" % salt) + + for salt in ["HelloWorld", "salt", string.letters, string.digits, + utils.GenerateSecret(numbytes=4), + utils.GenerateSecret(numbytes=16), + "{123:456}".encode("hex")]: + signed_pem = utils.SignX509Certificate(cert, self.KEY, salt) + + self._Check(cert, salt, signed_pem) + + self._Check(cert, salt, "X-Another-Header: with a value\n" + signed_pem) + self._Check(cert, salt, (10 * "Hello World!\n") + signed_pem) + self._Check(cert, salt, (signed_pem + "\n\na few more\n" + "lines----\n------ at\nthe end!")) + + def _Check(self, cert, salt, pem): + (cert2, salt2) = utils.LoadSignedX509Certificate(pem, self.KEY) + self.assertEqual(salt, salt2) + self.assertEqual(cert.digest("sha1"), cert2.digest("sha1")) + + # Other key + self.assertRaises(errors.GenericError, utils.LoadSignedX509Certificate, + pem, self.KEY_OTHER) + + +class TestMakedirs(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def testNonExisting(self): + path = utils.PathJoin(self.tmpdir, "foo") + utils.Makedirs(path) + self.assert_(os.path.isdir(path)) + + def testExisting(self): + path = utils.PathJoin(self.tmpdir, "foo") + os.mkdir(path) + utils.Makedirs(path) + self.assert_(os.path.isdir(path)) + + def testRecursiveNonExisting(self): + path = utils.PathJoin(self.tmpdir, "foo/bar/baz") + utils.Makedirs(path) + self.assert_(os.path.isdir(path)) + + def testRecursiveExisting(self): + path = utils.PathJoin(self.tmpdir, "B/moo/xyz") + self.assert_(not os.path.exists(path)) + os.mkdir(utils.PathJoin(self.tmpdir, "B")) + utils.Makedirs(path) + self.assert_(os.path.isdir(path)) + + +class TestRetry(testutils.GanetiTestCase): + def setUp(self): + testutils.GanetiTestCase.setUp(self) + self.retries = 0 + + @staticmethod + def _RaiseRetryAgain(): + raise utils.RetryAgain() + + @staticmethod + def _RaiseRetryAgainWithArg(args): + raise utils.RetryAgain(*args) + + def _WrongNestedLoop(self): + return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02) + + def _RetryAndSucceed(self, retries): + if self.retries < retries: + self.retries += 1 + raise utils.RetryAgain() + else: + return True + + def testRaiseTimeout(self): + self.failUnlessRaises(utils.RetryTimeout, utils.Retry, + self._RaiseRetryAgain, 0.01, 0.02) + self.failUnlessRaises(utils.RetryTimeout, utils.Retry, + self._RetryAndSucceed, 0.01, 0, args=[1]) + self.failUnlessEqual(self.retries, 1) + + def testComplete(self): + self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True) + self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]), + True) + self.failUnlessEqual(self.retries, 2) + + def testNestedLoop(self): + try: + self.failUnlessRaises(errors.ProgrammerError, utils.Retry, + self._WrongNestedLoop, 0, 1) + except utils.RetryTimeout: + self.fail("Didn't detect inner loop's exception") + + def testTimeoutArgument(self): + retry_arg="my_important_debugging_message" + try: + utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[retry_arg]]) + except utils.RetryTimeout, err: + self.failUnlessEqual(err.args, (retry_arg, )) + else: + self.fail("Expected timeout didn't happen") + + def testRaiseInnerWithExc(self): + retry_arg="my_important_debugging_message" + try: + try: + utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, + args=[[errors.GenericError(retry_arg, retry_arg)]]) + except utils.RetryTimeout, err: + err.RaiseInner() + else: + self.fail("Expected timeout didn't happen") + except errors.GenericError, err: + self.failUnlessEqual(err.args, (retry_arg, retry_arg)) + else: + self.fail("Expected GenericError didn't happen") + + def testRaiseInnerWithMsg(self): + retry_arg="my_important_debugging_message" + try: + try: + utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, + args=[[retry_arg, retry_arg]]) + except utils.RetryTimeout, err: + err.RaiseInner() + else: + self.fail("Expected timeout didn't happen") + except utils.RetryTimeout, err: + self.failUnlessEqual(err.args, (retry_arg, retry_arg)) + else: + self.fail("Expected RetryTimeout didn't happen") + + +class TestLineSplitter(unittest.TestCase): + def test(self): + lines = [] + ls = utils.LineSplitter(lines.append) + ls.write("Hello World\n") + self.assertEqual(lines, []) + ls.write("Foo\n Bar\r\n ") + ls.write("Baz") + ls.write("Moo") + self.assertEqual(lines, []) + ls.flush() + self.assertEqual(lines, ["Hello World", "Foo", " Bar"]) + ls.close() + self.assertEqual(lines, ["Hello World", "Foo", " Bar", " BazMoo"]) + + def _testExtra(self, line, all_lines, p1, p2): + self.assertEqual(p1, 999) + self.assertEqual(p2, "extra") + all_lines.append(line) + + def testExtraArgsNoFlush(self): + lines = [] + ls = utils.LineSplitter(self._testExtra, lines, 999, "extra") + ls.write("\n\nHello World\n") + ls.write("Foo\n Bar\r\n ") + ls.write("") + ls.write("Baz") + ls.write("Moo\n\nx\n") + self.assertEqual(lines, []) + ls.close() + self.assertEqual(lines, ["", "", "Hello World", "Foo", " Bar", " BazMoo", + "", "x"]) + + +class TestReadLockedPidFile(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def testNonExistent(self): + path = utils.PathJoin(self.tmpdir, "nonexist") + self.assert_(utils.ReadLockedPidFile(path) is None) + + def testUnlocked(self): + path = utils.PathJoin(self.tmpdir, "pid") + utils.WriteFile(path, data="123") + self.assert_(utils.ReadLockedPidFile(path) is None) + + def testLocked(self): + path = utils.PathJoin(self.tmpdir, "pid") + utils.WriteFile(path, data="123") + + fl = utils.FileLock.Open(path) + try: + fl.Exclusive(blocking=True) + + self.assertEqual(utils.ReadLockedPidFile(path), 123) + finally: + fl.Close() + + self.assert_(utils.ReadLockedPidFile(path) is None) + + def testError(self): + path = utils.PathJoin(self.tmpdir, "foobar", "pid") + utils.WriteFile(utils.PathJoin(self.tmpdir, "foobar"), data="") + # open(2) should return ENOTDIR + self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path) + + +class TestCertVerification(testutils.GanetiTestCase): + def setUp(self): + testutils.GanetiTestCase.setUp(self) + + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def testVerifyCertificate(self): + cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem")) + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, + cert_pem) + + # Not checking return value as this certificate is expired + utils.VerifyX509Certificate(cert, 30, 7) + + +class TestVerifyCertificateInner(unittest.TestCase): + def test(self): + vci = utils._VerifyCertificateInner + + # Valid + self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7), + (None, None)) + + # Not yet valid + (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7) + self.assertEqual(errcode, utils.CERT_WARNING) + + # Expiring soon + (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7) + self.assertEqual(errcode, utils.CERT_ERROR) + + (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1) + self.assertEqual(errcode, utils.CERT_WARNING) + + (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7) + self.assertEqual(errcode, None) + + # Expired + (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7) + self.assertEqual(errcode, utils.CERT_ERROR) + + (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7) + self.assertEqual(errcode, utils.CERT_ERROR) + + (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7) + self.assertEqual(errcode, utils.CERT_ERROR) + + (errcode, msg) = vci(True, None, None, 1266939600, 30, 7) + self.assertEqual(errcode, utils.CERT_ERROR) + + +class TestHmacFunctions(unittest.TestCase): + # Digests can be checked with "openssl sha1 -hmac $key" + def testSha1Hmac(self): + self.assertEqual(utils.Sha1Hmac("", ""), + "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d") + self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World"), + "ef4f3bda82212ecb2f7ce868888a19092481f1fd") + self.assertEqual(utils.Sha1Hmac("TguMTA2K", ""), + "f904c2476527c6d3e6609ab683c66fa0652cb1dc") + + longtext = 1500 * "The quick brown fox jumps over the lazy dog\n" + self.assertEqual(utils.Sha1Hmac("3YzMxZWE", longtext), + "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54") + + def testSha1HmacSalt(self): + self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc0"), + "4999bf342470eadb11dfcd24ca5680cf9fd7cdce") + self.assertEqual(utils.Sha1Hmac("TguMTA2K", "", salt="abc9"), + "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8") + self.assertEqual(utils.Sha1Hmac("3YzMxZWE", "Hello World", salt="xyz0"), + "7f264f8114c9066afc9bb7636e1786d996d3cc0d") + + def testVerifySha1Hmac(self): + self.assert_(utils.VerifySha1Hmac("", "", ("fbdb1d1b18aa6c08324b" + "7d64b71fb76370690e1d"))) + self.assert_(utils.VerifySha1Hmac("TguMTA2K", "", + ("f904c2476527c6d3e660" + "9ab683c66fa0652cb1dc"))) + + digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd" + self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", digest)) + self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", + digest.lower())) + self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", + digest.upper())) + self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", + digest.title())) + + def testVerifySha1HmacSalt(self): + self.assert_(utils.VerifySha1Hmac("TguMTA2K", "", + ("17a4adc34d69c0d367d4" + "ffbef96fd41d4df7a6e8"), + salt="abc9")) + self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", + ("7f264f8114c9066afc9b" + "b7636e1786d996d3cc0d"), + salt="xyz0")) + + +class TestIgnoreSignals(unittest.TestCase): + """Test the IgnoreSignals decorator""" + + @staticmethod + def _Raise(exception): + raise exception + + @staticmethod + def _Return(rval): + return rval + + def testIgnoreSignals(self): + sock_err_intr = socket.error(errno.EINTR, "Message") + sock_err_intr.errno = errno.EINTR + sock_err_inval = socket.error(errno.EINVAL, "Message") + sock_err_inval.errno = errno.EINVAL + + env_err_intr = EnvironmentError(errno.EINTR, "Message") + env_err_inval = EnvironmentError(errno.EINVAL, "Message") + + self.assertRaises(socket.error, self._Raise, sock_err_intr) + self.assertRaises(socket.error, self._Raise, sock_err_inval) + self.assertRaises(EnvironmentError, self._Raise, env_err_intr) + self.assertRaises(EnvironmentError, self._Raise, env_err_inval) + + self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None) + self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None) + self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise, + sock_err_inval) + self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise, + env_err_inval) + + self.assertEquals(utils.IgnoreSignals(self._Return, True), True) + self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33) + + if __name__ == '__main__': - unittest.main() + testutils.GanetiTestProgram()