Move cmdlib._VerifyCertificate to utils
authorMichael Hanselmann <hansmi@google.com>
Tue, 27 Apr 2010 15:24:21 +0000 (17:24 +0200)
committerMichael Hanselmann <hansmi@google.com>
Thu, 29 Apr 2010 13:30:17 +0000 (15:30 +0200)
This function will also be useful for inter-cluster instance
moves for verifying certificates.

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: RenĂ© Nussbaumer <rn@google.com>

lib/cmdlib.py
lib/utils.py
test/ganeti.cmdlib_unittest.py
test/ganeti.utils_unittest.py

index 7eb8c2a..263ec2b 100644 (file)
@@ -923,13 +923,6 @@ def _FindFaultyInstanceDisks(cfg, rpc, instance, node_name, prereq):
   return faulty
 
 
-def _FormatTimestamp(secs):
-  """Formats a Unix timestamp with the local timezone.
-
-  """
-  return time.strftime("%F %T %Z", time.gmtime(secs))
-
-
 class LUPostInitCluster(LogicalUnit):
   """Logical unit for running hooks after cluster initialization.
 
@@ -1021,45 +1014,6 @@ class LUDestroyCluster(LogicalUnit):
     return master
 
 
-def _VerifyCertificateInner(filename, expired, not_before, not_after, now,
-                            warn_days=constants.SSL_CERT_EXPIRATION_WARN,
-                            error_days=constants.SSL_CERT_EXPIRATION_ERROR):
-  """Verifies certificate details for LUVerifyCluster.
-
-  """
-  if expired:
-    msg = "Certificate %s is expired" % filename
-
-    if not_before is not None and not_after is not None:
-      msg += (" (valid from %s to %s)" %
-              (_FormatTimestamp(not_before),
-               _FormatTimestamp(not_after)))
-    elif not_before is not None:
-      msg += " (valid from %s)" % _FormatTimestamp(not_before)
-    elif not_after is not None:
-      msg += " (valid until %s)" % _FormatTimestamp(not_after)
-
-    return (LUVerifyCluster.ETYPE_ERROR, msg)
-
-  elif not_before is not None and not_before > now:
-    return (LUVerifyCluster.ETYPE_WARNING,
-            "Certificate %s not yet valid (valid from %s)" %
-            (filename, _FormatTimestamp(not_before)))
-
-  elif not_after is not None:
-    remaining_days = int((not_after - now) / (24 * 3600))
-
-    msg = ("Certificate %s expires in %d days" % (filename, remaining_days))
-
-    if remaining_days <= error_days:
-      return (LUVerifyCluster.ETYPE_ERROR, msg)
-
-    if remaining_days <= warn_days:
-      return (LUVerifyCluster.ETYPE_WARNING, msg)
-
-  return (None, None)
-
-
 def _VerifyCertificate(filename):
   """Verifies a certificate for LUVerifyCluster.
 
@@ -1074,11 +1028,23 @@ def _VerifyCertificate(filename):
     return (LUVerifyCluster.ETYPE_ERROR,
             "Failed to load X509 certificate %s: %s" % (filename, err))
 
-  # Depending on the pyOpenSSL version, this can just return (None, None)
-  (not_before, not_after) = utils.GetX509CertValidity(cert)
+  (errcode, msg) = \
+    utils.VerifyX509Certificate(cert, constants.SSL_CERT_EXPIRATION_WARN,
+                                constants.SSL_CERT_EXPIRATION_ERROR)
+
+  if msg:
+    fnamemsg = "While verifying %s: %s" % (filename, msg)
+  else:
+    fnamemsg = None
+
+  if errcode is None:
+    return (None, fnamemsg)
+  elif errcode == utils.CERT_WARNING:
+    return (LUVerifyCluster.ETYPE_WARNING, fnamemsg)
+  elif errcode == utils.CERT_ERROR:
+    return (LUVerifyCluster.ETYPE_ERROR, fnamemsg)
 
-  return _VerifyCertificateInner(filename, cert.has_expired(),
-                                 not_before, not_after, time.time())
+  raise errors.ProgrammerError("Unhandled certificate error code %r" % errcode)
 
 
 class LUVerifyCluster(LogicalUnit):
index 59ecaea..379172a 100644 (file)
@@ -92,6 +92,10 @@ X509_SIGNATURE = re.compile(r"^%s:\s*(?P<salt>%s+)/(?P<sign>%s+)$" %
 _STRUCT_UCRED = "iII"
 _STRUCT_UCRED_SIZE = struct.calcsize(_STRUCT_UCRED)
 
+# Certificate verification results
+(CERT_WARNING,
+ CERT_ERROR) = range(1, 3)
+
 
 class RunResult(object):
   """Holds the result of running external programs.
@@ -2442,6 +2446,13 @@ def TailFile(fname, lines=20):
   return rows[-lines:]
 
 
+def FormatTimestampWithTZ(secs):
+  """Formats a Unix timestamp with the local timezone.
+
+  """
+  return time.strftime("%F %T %Z", time.gmtime(secs))
+
+
 def _ParseAsn1Generalizedtime(value):
   """Parses an ASN1 GENERALIZEDTIME timestamp as used by pyOpenSSL.
 
@@ -2505,6 +2516,75 @@ def GetX509CertValidity(cert):
   return (not_before, not_after)
 
 
+def _VerifyCertificateInner(expired, not_before, not_after, now,
+                            warn_days, error_days):
+  """Verifies certificate validity.
+
+  @type expired: bool
+  @param expired: Whether pyOpenSSL considers the certificate as expired
+  @type not_before: number or None
+  @param not_before: Unix timestamp before which certificate is not valid
+  @type not_after: number or None
+  @param not_after: Unix timestamp after which certificate is invalid
+  @type now: number
+  @param now: Current time as Unix timestamp
+  @type warn_days: number or None
+  @param warn_days: How many days before expiration a warning should be reported
+  @type error_days: number or None
+  @param error_days: How many days before expiration an error should be reported
+
+  """
+  if expired:
+    msg = "Certificate is expired"
+
+    if not_before is not None and not_after is not None:
+      msg += (" (valid from %s to %s)" %
+              (FormatTimestampWithTZ(not_before),
+               FormatTimestampWithTZ(not_after)))
+    elif not_before is not None:
+      msg += " (valid from %s)" % FormatTimestampWithTZ(not_before)
+    elif not_after is not None:
+      msg += " (valid until %s)" % FormatTimestampWithTZ(not_after)
+
+    return (CERT_ERROR, msg)
+
+  elif not_before is not None and not_before > now:
+    return (CERT_WARNING,
+            "Certificate not yet valid (valid from %s)" %
+            FormatTimestampWithTZ(not_before))
+
+  elif not_after is not None:
+    remaining_days = int((not_after - now) / (24 * 3600))
+
+    msg = "Certificate expires in about %d days" % remaining_days
+
+    if error_days is not None and remaining_days <= error_days:
+      return (CERT_ERROR, msg)
+
+    if warn_days is not None and remaining_days <= warn_days:
+      return (CERT_WARNING, msg)
+
+  return (None, None)
+
+
+def VerifyX509Certificate(cert, warn_days, error_days):
+  """Verifies a certificate for LUVerifyCluster.
+
+  @type cert: OpenSSL.crypto.X509
+  @param cert: X509 certificate object
+  @type warn_days: number or None
+  @param warn_days: How many days before expiration a warning should be reported
+  @type error_days: number or None
+  @param error_days: How many days before expiration an error should be reported
+
+  """
+  # Depending on the pyOpenSSL version, this can just return (None, None)
+  (not_before, not_after) = GetX509CertValidity(cert)
+
+  return _VerifyCertificateInner(cert.has_expired(), not_before, not_after,
+                                 time.time(), warn_days, error_days)
+
+
 def SignX509Certificate(cert, key, salt):
   """Sign a X509 certificate.
 
index 8af6168..092225b 100755 (executable)
@@ -52,57 +52,10 @@ class TestCertVerification(testutils.GanetiTestCase):
     self.assertEqual(errcode, cmdlib.LUVerifyCluster.ETYPE_ERROR)
 
     # Try to load non-certificate file
-    invalid_cert = self._TestDataFilename("bdev-net1.txt")
+    invalid_cert = self._TestDataFilename("bdev-net.txt")
     (errcode, msg) = cmdlib._VerifyCertificate(invalid_cert)
     self.assertEqual(errcode, cmdlib.LUVerifyCluster.ETYPE_ERROR)
 
 
-class TestVerifyCertificateInner(unittest.TestCase):
-  FAKEFILE = "/tmp/fake/cert/file.pem"
-
-  def test(self):
-    vci = cmdlib._VerifyCertificateInner
-
-    # Valid
-    self.assertEqual(vci(self.FAKEFILE, False, 1263916313, 1298476313,
-                         1266940313, warn_days=30, error_days=7),
-                     (None, None))
-
-    # Not yet valid
-    (errcode, msg) = vci(self.FAKEFILE, False, 1266507600, 1267544400,
-                         1266075600, warn_days=30, error_days=7)
-    self.assertEqual(errcode, cmdlib.LUVerifyCluster.ETYPE_WARNING)
-
-    # Expiring soon
-    (errcode, msg) = vci(self.FAKEFILE, False, 1266507600, 1267544400,
-                         1266939600, warn_days=30, error_days=7)
-    self.assertEqual(errcode, cmdlib.LUVerifyCluster.ETYPE_ERROR)
-
-    (errcode, msg) = vci(self.FAKEFILE, False, 1266507600, 1267544400,
-                         1266939600, warn_days=30, error_days=1)
-    self.assertEqual(errcode, cmdlib.LUVerifyCluster.ETYPE_WARNING)
-
-    (errcode, msg) = vci(self.FAKEFILE, False, 1266507600, None,
-                         1266939600, warn_days=30, error_days=7)
-    self.assertEqual(errcode, None)
-
-    # Expired
-    (errcode, msg) = vci(self.FAKEFILE, True, 1266507600, 1267544400,
-                         1266939600, warn_days=30, error_days=7)
-    self.assertEqual(errcode, cmdlib.LUVerifyCluster.ETYPE_ERROR)
-
-    (errcode, msg) = vci(self.FAKEFILE, True, None, 1267544400,
-                         1266939600, warn_days=30, error_days=7)
-    self.assertEqual(errcode, cmdlib.LUVerifyCluster.ETYPE_ERROR)
-
-    (errcode, msg) = vci(self.FAKEFILE, True, 1266507600, None,
-                         1266939600, warn_days=30, error_days=7)
-    self.assertEqual(errcode, cmdlib.LUVerifyCluster.ETYPE_ERROR)
-
-    (errcode, msg) = vci(self.FAKEFILE, True, None, None,
-                         1266939600, warn_days=30, error_days=7)
-    self.assertEqual(errcode, cmdlib.LUVerifyCluster.ETYPE_ERROR)
-
-
 if __name__ == "__main__":
   testutils.GanetiTestProgram()
index 9c49cbd..836b200 100755 (executable)
@@ -1945,5 +1945,59 @@ class TestReadLockedPidFile(unittest.TestCase):
     self.assertRaises(EnvironmentError, utils.ReadLockedPidFile, path)
 
 
+class TestCertVerification(testutils.GanetiTestCase):
+  def setUp(self):
+    testutils.GanetiTestCase.setUp(self)
+
+    self.tmpdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    shutil.rmtree(self.tmpdir)
+
+  def testVerifyCertificate(self):
+    cert_pem = utils.ReadFile(self._TestDataFilename("cert1.pem"))
+    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM,
+                                           cert_pem)
+
+    # Not checking return value as this certificate is expired
+    utils.VerifyX509Certificate(cert, 30, 7)
+
+
+class TestVerifyCertificateInner(unittest.TestCase):
+  def test(self):
+    vci = utils._VerifyCertificateInner
+
+    # Valid
+    self.assertEqual(vci(False, 1263916313, 1298476313, 1266940313, 30, 7),
+                     (None, None))
+
+    # Not yet valid
+    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266075600, 30, 7)
+    self.assertEqual(errcode, utils.CERT_WARNING)
+
+    # Expiring soon
+    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 7)
+    self.assertEqual(errcode, utils.CERT_ERROR)
+
+    (errcode, msg) = vci(False, 1266507600, 1267544400, 1266939600, 30, 1)
+    self.assertEqual(errcode, utils.CERT_WARNING)
+
+    (errcode, msg) = vci(False, 1266507600, None, 1266939600, 30, 7)
+    self.assertEqual(errcode, None)
+
+    # Expired
+    (errcode, msg) = vci(True, 1266507600, 1267544400, 1266939600, 30, 7)
+    self.assertEqual(errcode, utils.CERT_ERROR)
+
+    (errcode, msg) = vci(True, None, 1267544400, 1266939600, 30, 7)
+    self.assertEqual(errcode, utils.CERT_ERROR)
+
+    (errcode, msg) = vci(True, 1266507600, None, 1266939600, 30, 7)
+    self.assertEqual(errcode, utils.CERT_ERROR)
+
+    (errcode, msg) = vci(True, None, None, 1266939600, 30, 7)
+    self.assertEqual(errcode, utils.CERT_ERROR)
+
+
 if __name__ == '__main__':
   testutils.GanetiTestProgram()