Merge branch 'devel-2.1'
[ganeti-local] / test / ganeti.utils_unittest.py
index 7fc48a9..511e48a 100755 (executable)
@@ -39,19 +39,22 @@ import OpenSSL
 import warnings
 import distutils.version
 import glob
 import warnings
 import distutils.version
 import glob
+import md5
+import errno
 
 import ganeti
 import testutils
 from ganeti import constants
 from ganeti import utils
 from ganeti import errors
 
 import ganeti
 import testutils
 from ganeti import constants
 from ganeti import utils
 from ganeti import errors
+from ganeti import serializer
 from ganeti.utils import IsProcessAlive, RunCmd, \
      RemoveFile, MatchNameComponent, FormatUnit, \
      ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
      ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
      SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
      TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
 from ganeti.utils import IsProcessAlive, RunCmd, \
      RemoveFile, MatchNameComponent, FormatUnit, \
      ParseUnit, AddAuthorizedKey, RemoveAuthorizedKey, \
      ShellQuote, ShellQuoteArgs, TcpPing, ListVisibleFiles, \
      SetEtcHostsEntry, RemoveEtcHostsEntry, FirstFree, OwnIpAddress, \
      TailFile, ForceDictType, SafeEncode, IsNormAbsPath, FormatTime, \
-     UnescapeAndSplit, RunParts, PathJoin, HostInfo
+     UnescapeAndSplit, RunParts, PathJoin, HostInfo, ReadOneLineFile
 
 from ganeti.errors import LockError, UnitParseError, GenericError, \
      ProgrammerError, OpPrereqError
 
 from ganeti.errors import LockError, UnitParseError, GenericError, \
      ProgrammerError, OpPrereqError
@@ -661,6 +664,107 @@ class TestMatchNameComponent(unittest.TestCase):
                      None)
 
 
                      None)
 
 
+class TestReadFile(testutils.GanetiTestCase):
+  """Test case for utils.ReadFile"""
+
+  def testReadAll(self):
+    """Read the whole test file and check its length and MD5 digest.
+
+    """
+    contents = utils.ReadFile(self._TestDataFilename("cert1.pem"))
+    self.assertEqual(len(contents), 814)
+
+    digest = md5.new(contents).hexdigest()
+    self.assertEqual(digest, "a491efb3efe56a0535f924d5f8680fd4")
+
+  def testReadSize(self):
+    """Read with size=100 and check length and MD5 digest of the result.
+
+    """
+    filename = self._TestDataFilename("cert1.pem")
+    contents = utils.ReadFile(filename, size=100)
+    self.assertEqual(len(contents), 100)
+
+    digest = md5.new(contents).hexdigest()
+    self.assertEqual(digest, "893772354e4e690b9efd073eed433ce7")
+
+  def testError(self):
+    """Reading from an unusable path must raise EnvironmentError.
+
+    """
+    bad_path = "/dev/null/does-not-exist"
+    self.assertRaises(EnvironmentError, utils.ReadFile, bad_path)
+
+
+class TestReadOneLineFile(testutils.GanetiTestCase):
+  """Test case for ReadOneLineFile"""
+
+  def setUp(self):
+    testutils.GanetiTestCase.setUp(self)
+
+  def testDefault(self):
+    """Call without "strict"; the PEM header line is the expected result.
+
+    """
+    first = ReadOneLineFile(self._TestDataFilename("cert1.pem"))
+    self.assertEqual(len(first), 27)
+    self.assertEqual(first, "-----BEGIN CERTIFICATE-----")
+
+  def testNotStrict(self):
+    """Same expectation as testDefault, with strict=False given explicitly.
+
+    """
+    filename = self._TestDataFilename("cert1.pem")
+    first = ReadOneLineFile(filename, strict=False)
+    self.assertEqual(len(first), 27)
+    self.assertEqual(first, "-----BEGIN CERTIFICATE-----")
+
+  def testStrictFailure(self):
+    """Strict mode on the certificate file must raise GenericError.
+
+    """
+    filename = self._TestDataFilename("cert1.pem")
+    self.assertRaises(errors.GenericError, ReadOneLineFile,
+                      filename, strict=True)
+
+  def testLongLine(self):
+    """A long line without any newline is returned whole in both modes.
+
+    """
+    content = 1024 * "Hello World! "
+    tmpname = self._CreateTempFile()
+    utils.WriteFile(tmpname, data=content)
+    strict_result = ReadOneLineFile(tmpname, strict=True)
+    lax_result = ReadOneLineFile(tmpname, strict=False)
+    self.assertEqual(content, strict_result)
+    self.assertEqual(content, lax_result)
+
+  def testNewline(self):
+    """A single line, with or without a line ending, is returned bare.
+
+    """
+    tmpname = self._CreateTempFile()
+    expected = "myline"
+    for eol in ["", "\n", "\r\n"]:
+      utils.WriteFile(tmpname, data=(expected + eol))
+      lax_result = ReadOneLineFile(tmpname, strict=False)
+      self.assertEqual(expected, lax_result)
+      strict_result = ReadOneLineFile(tmpname, strict=True)
+      self.assertEqual(expected, strict_result)
+
+  def testWhitespaceAndMultipleLines(self):
+    """Check trailing whitespace handling with and without line endings.
+
+    """
+    tmpname = self._CreateTempFile()
+    for eol in ["", "\n", "\r\n"]:
+      for blank in [" ", "\t", "\t\t  \t", "\t "]:
+        content = 1024 * ("Foo bar baz %s%s" % (blank, eol))
+        utils.WriteFile(tmpname, data=content)
+        lax_result = ReadOneLineFile(tmpname, strict=False)
+
+        if not eol:
+          # Without a line ending the file is one long line and both modes
+          # must return it unchanged
+          strict_result = ReadOneLineFile(tmpname, strict=True)
+          self.assertEqual(content, strict_result)
+          self.assertEqual(content, lax_result)
+          continue
+
+        # With a line ending the file has many lines: strict mode has to
+        # refuse it, lax mode returns the first line including whitespace
+        self.assert_(set("\r\n") & set(content))
+        self.assertRaises(errors.GenericError, ReadOneLineFile,
+                          tmpname, strict=True)
+        wanted_len = len("Foo bar baz ") + len(blank)
+        self.assertEqual(len(lax_result), wanted_len)
+        self.assertEqual(lax_result, content[:wanted_len])
+        self.assertFalse(set("\r\n") & set(lax_result))
+
+  def testEmptylines(self):
+    """Leading empty lines are skipped; extra content fails strict mode.
+
+    """
+    tmpname = self._CreateTempFile()
+    expected = "myline"
+    for eol in ["\n", "\r\n"]:
+      for extra in ["", "otherline"]:
+        content = "".join([eol, eol, expected, eol, extra, eol])
+        utils.WriteFile(tmpname, data=content)
+        self.assert_(set("\r\n") & set(content))
+        lax_result = ReadOneLineFile(tmpname, strict=False)
+        self.assertEqual(expected, lax_result)
+        if extra:
+          self.assertRaises(errors.GenericError, ReadOneLineFile,
+                            tmpname, strict=True)
+        else:
+          strict_result = ReadOneLineFile(tmpname, strict=True)
+          self.assertEqual(expected, strict_result)
+
+
 class TestTimestampForFilename(unittest.TestCase):
   def test(self):
     self.assert_("." not in utils.TimestampForFilename())
 class TestTimestampForFilename(unittest.TestCase):
   def test(self):
     self.assert_("." not in utils.TimestampForFilename())
@@ -1049,13 +1153,80 @@ class TestOwnIpAddress(unittest.TestCase):
   def testNowOwnAddress(self):
     """check that I don't own an address"""
 
   def testNowOwnAddress(self):
     """check that I don't own an address"""
 
-    # network 192.0.2.0/24 is reserved for test/documentation as per
-    # rfc 3330, so we *should* not have an address of this range... if
+    # Network 192.0.2.0/24 is reserved for test/documentation as per
+    # RFC 5735, so we *should* not have an address of this range... if
     # this fails, we should extend the test to multiple addresses
     DST_IP = "192.0.2.1"
     self.failIf(OwnIpAddress(DST_IP), "Should not own IP address %s" % DST_IP)
 
 
     # this fails, we should extend the test to multiple addresses
     DST_IP = "192.0.2.1"
     self.failIf(OwnIpAddress(DST_IP), "Should not own IP address %s" % DST_IP)
 
 
+def _GetSocketCredentials(path):
+  """Connect to a Unix socket and return the credentials of its peer.
+
+  @param path: filesystem path of the Unix socket to connect to
+  @return: whatever utils.GetSocketCredentials reports for the connection
+
+  """
+  client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+  try:
+    client.settimeout(10)
+    client.connect(path)
+    credentials = utils.GetSocketCredentials(client)
+  finally:
+    # Close the socket whether or not connecting/querying succeeded
+    client.close()
+  return credentials
+
+
+class TestGetSocketCredentials(unittest.TestCase):
+  """Test case for utils.GetSocketCredentials
+
+  A forked child connects to a Unix socket the parent listens on and
+  reports the credentials it obtains back through a pipe.
+
+  """
+  def setUp(self):
+    """Create a listening Unix socket inside a fresh temporary directory.
+
+    """
+    self.tmpdir = tempfile.mkdtemp()
+    self.sockpath = utils.PathJoin(self.tmpdir, "sock")
+
+    self.listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+    # Timeout so that a child which never connects can't block accept()
+    # indefinitely
+    self.listener.settimeout(10)
+    self.listener.bind(self.sockpath)
+    self.listener.listen(1)
+
+  def tearDown(self):
+    """Shut down the listener and remove the temporary directory.
+
+    """
+    self.listener.shutdown(socket.SHUT_RDWR)
+    self.listener.close()
+    shutil.rmtree(self.tmpdir)
+
+  def test(self):
+    """Compare the credentials seen by the child with the parent's own.
+
+    """
+    # Child-to-parent pipe: read end and write end
+    (c2pr, c2pw) = os.pipe()
+
+    # Start child process
+    child = os.fork()
+    if child == 0:
+      try:
+        data = serializer.DumpJson(_GetSocketCredentials(self.sockpath))
+
+        os.write(c2pw, data)
+        os.close(c2pw)
+
+        # Success: leave right away, os._exit does not run the "finally"
+        os._exit(0)
+      finally:
+        # Reached only if something above raised; exit with a failure code
+        # so the child never returns into the test runner
+        os._exit(1)
+
+    # Parent only from here on; it doesn't write to the pipe
+    os.close(c2pw)
+
+    # Wait for one connection
+    (conn, _) = self.listener.accept()
+    # The child never sends data; _GetSocketCredentials closes its socket in
+    # a "finally" clause, so presumably this returns on that close
+    # (TODO confirm)
+    conn.recv(1)
+    conn.close()
+
+    # Wait for result
+    result = os.read(c2pr, 4096)
+    os.close(c2pr)
+
+    # Check child's exit code
+    (_, status) = os.waitpid(child, 0)
+    self.assertFalse(os.WIFSIGNALED(status))
+    self.assertEqual(os.WEXITSTATUS(status), 0)
+
+    # Check result
+    # The child queried the peer of its connection, i.e. this (listening)
+    # process, hence the comparison with our own pid/uid/gid
+    (pid, uid, gid) = serializer.LoadJson(result)
+    self.assertEqual(pid, os.getpid())
+    self.assertEqual(uid, os.getuid())
+    self.assertEqual(gid, os.getgid())
+
+
 class TestListVisibleFiles(unittest.TestCase):
   """Test case for ListVisibleFiles"""
 
 class TestListVisibleFiles(unittest.TestCase):
   """Test case for ListVisibleFiles"""
 
@@ -1398,7 +1569,7 @@ class TestForceDictType(unittest.TestCase):
     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
 
 
     self.assertRaises(errors.TypeEnforcementError, self._fdt, {'d': '4 L'})
 
 
-class TestIsAbsNormPath(unittest.TestCase):
+class TestIsNormAbsPath(unittest.TestCase):
   """Testing case for IsNormAbsPath"""
 
   def _pathTestHelper(self, path, result):
   """Testing case for IsNormAbsPath"""
 
   def _pathTestHelper(self, path, result):
@@ -1784,19 +1955,40 @@ class TestMakedirs(unittest.TestCase):
 
 
 class TestRetry(testutils.GanetiTestCase):
 
 
 class TestRetry(testutils.GanetiTestCase):
+  def setUp(self):
+    """Run the base class setUp and reset the retry counter.
+
+    """
+    testutils.GanetiTestCase.setUp(self)
+    # Counter incremented by _RetryAndSucceed and checked by the tests
+    self.retries = 0
+
   @staticmethod
   def _RaiseRetryAgain():
     raise utils.RetryAgain()
 
   @staticmethod
   def _RaiseRetryAgain():
     raise utils.RetryAgain()
 
+  @staticmethod
+  def _RaiseRetryAgainWithArg(args):
+    """Always raise utils.RetryAgain, built from the given arguments.
+
+    @param args: sequence unpacked as positional arguments of the exception
+
+    """
+    raise utils.RetryAgain(*args)
+
   def _WrongNestedLoop(self):
     return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
 
   def _WrongNestedLoop(self):
     return utils.Retry(self._RaiseRetryAgain, 0.01, 0.02)
 
+  def _RetryAndSucceed(self, retries):
+    """Request a retry until called often enough, then return True.
+
+    @param retries: number of calls which must raise utils.RetryAgain before
+        the function succeeds
+
+    """
+    if self.retries >= retries:
+      return True
+
+    self.retries += 1
+    raise utils.RetryAgain()
+
   def testRaiseTimeout(self):
     self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
                           self._RaiseRetryAgain, 0.01, 0.02)
   def testRaiseTimeout(self):
     self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
                           self._RaiseRetryAgain, 0.01, 0.02)
+    self.failUnlessRaises(utils.RetryTimeout, utils.Retry,
+                          self._RetryAndSucceed, 0.01, 0, args=[1])
+    self.failUnlessEqual(self.retries, 1)
 
   def testComplete(self):
     self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
 
   def testComplete(self):
     self.failUnlessEqual(utils.Retry(lambda: True, 0, 1), True)
+    self.failUnlessEqual(utils.Retry(self._RetryAndSucceed, 0, 1, args=[2]),
+                         True)
+    self.failUnlessEqual(self.retries, 2)
 
   def testNestedLoop(self):
     try:
 
   def testNestedLoop(self):
     try:
@@ -1805,6 +1997,45 @@ class TestRetry(testutils.GanetiTestCase):
     except utils.RetryTimeout:
       self.fail("Didn't detect inner loop's exception")
 
     except utils.RetryTimeout:
       self.fail("Didn't detect inner loop's exception")
 
+  def testTimeoutArgument(self):
+    """Arguments of RetryAgain must show up on the RetryTimeout exception.
+
+    """
+    message = "my_important_debugging_message"
+    try:
+      utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02, args=[[message]])
+    except utils.RetryTimeout, err:
+      self.failUnlessEqual(err.args, (message, ))
+    else:
+      self.fail("Expected timeout didn't happen")
+
+  def testRaiseInnerWithExc(self):
+    """RaiseInner must re-raise an exception passed along with RetryAgain.
+
+    """
+    message = "my_important_debugging_message"
+    inner_exc = errors.GenericError(message, message)
+    try:
+      try:
+        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
+                    args=[[inner_exc]])
+      except utils.RetryTimeout, err:
+        err.RaiseInner()
+      else:
+        self.fail("Expected timeout didn't happen")
+    except errors.GenericError, err:
+      self.failUnlessEqual(err.args, (message, message))
+    else:
+      self.fail("Expected GenericError didn't happen")
+
+  def testRaiseInnerWithMsg(self):
+    """With plain message arguments RaiseInner must raise RetryTimeout.
+
+    """
+    message = "my_important_debugging_message"
+    retry_args = [message, message]
+    try:
+      try:
+        utils.Retry(self._RaiseRetryAgainWithArg, 0.01, 0.02,
+                    args=[retry_args])
+      except utils.RetryTimeout, err:
+        err.RaiseInner()
+      else:
+        self.fail("Expected timeout didn't happen")
+    except utils.RetryTimeout, err:
+      self.failUnlessEqual(err.args, (message, message))
+    else:
+      self.fail("Expected RetryTimeout didn't happen")
+
 
 class TestLineSplitter(unittest.TestCase):
   def test(self):
 
 class TestLineSplitter(unittest.TestCase):
   def test(self):
@@ -1877,5 +2108,144 @@ class TestReadLockedPidFile(unittest.TestCase):
     self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
 
 
     self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
 
 
+class TestCertVerification(testutils.GanetiTestCase):
+  """Smoke test for utils.VerifyX509Certificate"""
+
+  def setUp(self):
+    """Run the base class setUp and create a scratch directory.
+
+    """
+    testutils.GanetiTestCase.setUp(self)
+
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    """Remove the scratch directory and run the base class tearDown.
+
+    """
+    # setUp chains to the base class, so tearDown has to do the same or any
+    # cleanup done by the base class would be skipped; "finally" makes sure
+    # it runs even if removing the directory fails
+    try:
+      shutil.rmtree(self.tmpdir)
+    finally:
+      testutils.GanetiTestCase.tearDown(self)
+
+  def testVerifyCertificate(self):
+    """Load the test certificate and run the verification on it.
+
+    """
+    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
+    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
+                                           cert_pem)
+
+    # Not checking return value as this certificate is expired
+    utils.VerifyX509Certificate(cert, 30, 7)
+
+
+class TestVerifyCertificateInner(unittest.TestCase):
+  """Test case for utils._VerifyCertificateInner"""
+
+  def test(self):
+    """Check the returned error code for a number of validity scenarios.
+
+    """
+    verify = utils._VerifyCertificateInner
+
+    # Valid
+    result = verify(False, 1263916313, 1298476313, 1266940313, 30, 7)
+    self.assertEqual(result, (None, None))
+
+    # Not yet valid
+    (errcode, _) = verify(False, 1266507600, 1267544400, 1266075600, 30, 7)
+    self.assertEqual(errcode, utils.CERT_WARNING)
+
+    # Expiring soon
+    (errcode, _) = verify(False, 1266507600, 1267544400, 1266939600, 30, 7)
+    self.assertEqual(errcode, utils.CERT_ERROR)
+
+    (errcode, _) = verify(False, 1266507600, 1267544400, 1266939600, 30, 1)
+    self.assertEqual(errcode, utils.CERT_WARNING)
+
+    (errcode, _) = verify(False, 1266507600, None, 1266939600, 30, 7)
+    self.assertEqual(errcode, None)
+
+    # Expired
+    (errcode, _) = verify(True, 1266507600, 1267544400, 1266939600, 30, 7)
+    self.assertEqual(errcode, utils.CERT_ERROR)
+
+    (errcode, _) = verify(True, None, 1267544400, 1266939600, 30, 7)
+    self.assertEqual(errcode, utils.CERT_ERROR)
+
+    (errcode, _) = verify(True, 1266507600, None, 1266939600, 30, 7)
+    self.assertEqual(errcode, utils.CERT_ERROR)
+
+    (errcode, _) = verify(True, None, None, 1266939600, 30, 7)
+    self.assertEqual(errcode, utils.CERT_ERROR)
+
+
+class TestHmacFunctions(unittest.TestCase):
+  """Test case for utils.Sha1Hmac and utils.VerifySha1Hmac"""
+
+  # Digests can be checked with "openssl sha1 -hmac $key"
+  def testSha1Hmac(self):
+    """Compare computed digests with known values, no salt.
+
+    """
+    longtext = 1500 * "The quick brown fox jumps over the lazy dog\n"
+
+    # (key, text, expected hex digest)
+    vectors = [
+      ("", "", "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d"),
+      ("3YzMxZWE", "Hello World", "ef4f3bda82212ecb2f7ce868888a19092481f1fd"),
+      ("TguMTA2K", "", "f904c2476527c6d3e6609ab683c66fa0652cb1dc"),
+      ("3YzMxZWE", longtext, "35901b9a3001a7cdcf8e0e9d7c2e79df2223af54"),
+      ]
+    for (key, text, expected) in vectors:
+      self.assertEqual(utils.Sha1Hmac(key, text), expected)
+
+  def testSha1HmacSalt(self):
+    """Compare computed digests with known values, salt given.
+
+    """
+    # (key, text, salt, expected hex digest)
+    vectors = [
+      ("TguMTA2K", "", "abc0", "4999bf342470eadb11dfcd24ca5680cf9fd7cdce"),
+      ("TguMTA2K", "", "abc9", "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8"),
+      ("3YzMxZWE", "Hello World", "xyz0",
+       "7f264f8114c9066afc9bb7636e1786d996d3cc0d"),
+      ]
+    for (key, text, salt, expected) in vectors:
+      self.assertEqual(utils.Sha1Hmac(key, text, salt=salt), expected)
+
+  def testVerifySha1Hmac(self):
+    """Known digests must verify, whatever the letter case of the digest.
+
+    """
+    self.assert_(utils.VerifySha1Hmac(
+      "", "", "fbdb1d1b18aa6c08324b7d64b71fb76370690e1d"))
+    self.assert_(utils.VerifySha1Hmac(
+      "TguMTA2K", "", "f904c2476527c6d3e6609ab683c66fa0652cb1dc"))
+
+    digest = "ef4f3bda82212ecb2f7ce868888a19092481f1fd"
+    for variant in [digest, digest.lower(), digest.upper(), digest.title()]:
+      self.assert_(utils.VerifySha1Hmac("3YzMxZWE", "Hello World", variant))
+
+  def testVerifySha1HmacSalt(self):
+    """Known salted digests must verify when the same salt is given.
+
+    """
+    # (key, text, expected hex digest, salt)
+    vectors = [
+      ("TguMTA2K", "", "17a4adc34d69c0d367d4ffbef96fd41d4df7a6e8", "abc9"),
+      ("3YzMxZWE", "Hello World",
+       "7f264f8114c9066afc9bb7636e1786d996d3cc0d", "xyz0"),
+      ]
+    for (key, text, digest, salt) in vectors:
+      self.assert_(utils.VerifySha1Hmac(key, text, digest, salt=salt))
+
+
+class TestIgnoreSignals(unittest.TestCase):
+  """Test the IgnoreSignals decorator"""
+
+  @staticmethod
+  def _Raise(exception):
+    raise exception
+
+  @staticmethod
+  def _Return(rval):
+    return rval
+
+  @staticmethod
+  def _MakeSocketError(code):
+    """Build a socket.error whose "errno" attribute is set explicitly.
+
+    """
+    err = socket.error(code, "Message")
+    err.errno = code
+    return err
+
+  def testIgnoreSignals(self):
+    """EINTR errors are swallowed, other errors and return values are not.
+
+    """
+    sock_intr = self._MakeSocketError(errno.EINTR)
+    sock_inval = self._MakeSocketError(errno.EINVAL)
+
+    env_intr = EnvironmentError(errno.EINTR, "Message")
+    env_inval = EnvironmentError(errno.EINVAL, "Message")
+
+    # Without the wrapper every one of these exceptions is propagated
+    unwrapped = [
+      (socket.error, sock_intr),
+      (socket.error, sock_inval),
+      (EnvironmentError, env_intr),
+      (EnvironmentError, env_inval),
+      ]
+    for (exc_class, exc) in unwrapped:
+      self.assertRaises(exc_class, self._Raise, exc)
+
+    # Interrupted calls are ignored and yield None
+    for exc in [sock_intr, env_intr]:
+      self.assertEqual(utils.IgnoreSignals(self._Raise, exc), None)
+
+    # Any other error code still gets through
+    for (exc_class, exc) in [(socket.error, sock_inval),
+                             (EnvironmentError, env_inval)]:
+      self.assertRaises(exc_class, utils.IgnoreSignals, self._Raise, exc)
+
+    # Return values are handed back unchanged
+    for value in [True, 33]:
+      self.assertEqual(utils.IgnoreSignals(self._Return, value), value)
+
+
 if __name__ == '__main__':
   testutils.GanetiTestProgram()
 if __name__ == '__main__':
   testutils.GanetiTestProgram()