Generate import-export unittest certs in parallel
[ganeti-local] / test / ganeti.http_unittest.py
index 6e4f9dc..7d0d477 100755 (executable)
@@ -25,6 +25,7 @@
 import os
 import unittest
 import time
+import tempfile  # temporary password files for TestReadPasswordFile
 
 from ganeti import http
 
@@ -32,6 +33,8 @@ import ganeti.http.server
 import ganeti.http.client
 import ganeti.http.auth
 
+import testutils  # GanetiTestCase and GanetiTestProgram
+
 
 class TestStartLines(unittest.TestCase):
   """Test cases for start line classes"""
@@ -68,10 +71,7 @@ class TestMisc(unittest.TestCase):
 
   def testHttpServerRequest(self):
     """Test ganeti.http.server._HttpServerRequest"""
-    fake_request = http.HttpMessage()
-    fake_request.start_line = \
-      http.HttpClientToServerStartLine("GET", "/", "HTTP/1.1")
-    server_request = http.server._HttpServerRequest(fake_request)
+    server_request = http.server._HttpServerRequest("GET", "/", None, None)
 
     # These are expected by users of the HTTP server
     self.assert_(hasattr(server_request, "request_method"))
@@ -93,16 +93,44 @@ class TestMisc(unittest.TestCase):
     self.assert_(message_reader_class.START_LINE_LENGTH_MAX > 0)
     self.assert_(message_reader_class.HEADER_LENGTH_MAX > 0)
 
+  def testFormatAuthHeader(self):
+    self.assertEqual(http.auth._FormatAuthHeader("Basic", {}),
+                     "Basic")
+    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "bar", }),
+                     "Basic foo=bar")
+    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "", }),
+                     "Basic foo=\"\"")
+    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "x,y", }),
+                     "Basic foo=\"x,y\"")
+    params = {
+      "foo": "x,y",
+      "realm": "secure",
+      }
+    # Dict order isn't guaranteed, hence accepting either parameter ordering
+    self.assert_(http.auth._FormatAuthHeader("Digest", params) in
+                 ("Digest foo=\"x,y\" realm=secure",
+                  "Digest realm=secure foo=\"x,y\""))
+
 
 class _FakeRequestAuth(http.auth.HttpServerRequestAuthentication):
-  def __init__(self, realm):
+  def __init__(self, realm, authreq, authenticate_fn):
     http.auth.HttpServerRequestAuthentication.__init__(self)
 
     self.realm = realm
+    self.authreq = authreq
+    self.authenticate_fn = authenticate_fn
+
+  def AuthenticationRequired(self, req):
+    return self.authreq
 
   def GetAuthRealm(self, req):
     return self.realm
 
+  def Authenticate(self, *args):
+    if self.authenticate_fn:
+      return self.authenticate_fn(*args)
+    raise NotImplementedError()  # no authentication callback configured
+
 
 class TestAuth(unittest.TestCase):
   """Authentication tests"""
@@ -110,17 +138,16 @@ class TestAuth(unittest.TestCase):
   hsra = http.auth.HttpServerRequestAuthentication
 
   def testConstants(self):
-    self.assertEqual(self.hsra._CLEARTEXT_SCHEME,
-                     self.hsra._CLEARTEXT_SCHEME.upper())
-    self.assertEqual(self.hsra._HA1_SCHEME,
-                     self.hsra._HA1_SCHEME.upper())
+    for scheme in [self.hsra._CLEARTEXT_SCHEME, self.hsra._HA1_SCHEME]:
+      self.assertEqual(scheme, scheme.upper())
+      self.assert_(scheme.startswith("{"))
+      self.assert_(scheme.endswith("}"))
 
   def _testVerifyBasicAuthPassword(self, realm, user, password, expected):
-    ra = _FakeRequestAuth(realm)
+    ra = _FakeRequestAuth(realm, False, None)  # no authreq, no callback
 
     return ra.VerifyBasicAuthPassword(None, user, password, expected)
 
-
   def testVerifyBasicAuthPassword(self):
     tvbap = self._testVerifyBasicAuthPassword
 
@@ -152,8 +179,8 @@ class TestAuth(unittest.TestCase):
     self.assert_(tvbap("This is only a test", "user", "pw",
                        "{HA1}92ea58ae804481498c257b2f65561a17"))
 
-    self.failIf(tvbap(None, "user", "pw",
-                      "{HA1}92ea58ae804481498c257b2f65561a17"))
+    self.assertRaises(AssertionError, tvbap, None, "user", "pw",
+                      "{HA1}92ea58ae804481498c257b2f65561a17")
     self.failIf(tvbap("Admin area", "user", "pw",
                       "{HA1}92ea58ae804481498c257b2f65561a17"))
     self.failIf(tvbap("This is only a test", "someone", "pw",
@@ -162,5 +189,157 @@ class TestAuth(unittest.TestCase):
                       "{HA1}92ea58ae804481498c257b2f65561a17"))
 
 
+class _SimpleAuthenticator:
+  """Authentication callback accepting exactly one user/password pair."""
+
+  def __init__(self, user, password):
+    self.user = user
+    self.password = password
+    self.called = False
+
+  def __call__(self, req, user, password):
+    self.called = True
+    return self.user == user and self.password == password
+
+
+class TestHttpServerRequestAuthentication(unittest.TestCase):
+  """Tests for HttpServerRequestAuthentication.PreHandleRequest"""
+
+  def testNoAuth(self):
+    req = http.server._HttpServerRequest("GET", "/", None, None)
+    _FakeRequestAuth("area1", False, None).PreHandleRequest(req)
+
+  def testNoRealm(self):
+    headers = { http.HTTP_AUTHORIZATION: "", }
+    req = http.server._HttpServerRequest("GET", "/", headers, None)
+    ra = _FakeRequestAuth(None, False, None)
+    self.assertRaises(AssertionError, ra.PreHandleRequest, req)
+
+  def testNoScheme(self):
+    headers = { http.HTTP_AUTHORIZATION: "", }
+    req = http.server._HttpServerRequest("GET", "/", headers, None)
+    ra = _FakeRequestAuth("area1", False, None)
+    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
+
+  def testUnknownScheme(self):
+    headers = { http.HTTP_AUTHORIZATION: "NewStyleAuth abc", }
+    req = http.server._HttpServerRequest("GET", "/", headers, None)
+    ra = _FakeRequestAuth("area1", False, None)
+    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
+
+  def testInvalidBase64(self):
+    headers = { http.HTTP_AUTHORIZATION: "Basic x_=_", }
+    req = http.server._HttpServerRequest("GET", "/", headers, None)
+    ra = _FakeRequestAuth("area1", False, None)
+    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
+
+  def testMalformedAuthForPublicResource(self):
+    # Credentials without a "user:password" separator
+    headers = {
+      http.HTTP_AUTHORIZATION: "Basic %s" % ("foo".encode("base64").strip(), ),
+      }
+    req = http.server._HttpServerRequest("GET", "/", headers, None)
+    ra = _FakeRequestAuth("area1", False, None)
+    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
+
+  def testAuthForPublicResource(self):
+    headers = {
+      http.HTTP_AUTHORIZATION:
+        "Basic %s" % ("foo:bar".encode("base64").strip(), ),
+      }
+    req = http.server._HttpServerRequest("GET", "/", headers, None)
+    ac = _SimpleAuthenticator("foo", "bar")
+    ra = _FakeRequestAuth("area1", False, ac)
+    ra.PreHandleRequest(req)
+
+    req = http.server._HttpServerRequest("GET", "/", headers, None)
+    ac = _SimpleAuthenticator("something", "else")
+    ra = _FakeRequestAuth("area1", False, ac)
+    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
+
+  def testInvalidRequestHeader(self):
+    checks = {
+      http.HttpUnauthorized: ["", "\t", "-", ".", "@", "<", ">", "Digest",
+                              "basic %s" % "foobar".encode("base64").strip()],
+      http.HttpBadRequest: ["Basic"],
+      }
+
+    for exc, values in checks.items():
+      for i in values:
+        headers = { http.HTTP_AUTHORIZATION: i, }
+        req = http.server._HttpServerRequest("GET", "/", headers, None)
+        ra = _FakeRequestAuth("area1", False, None)
+        self.assertRaises(exc, ra.PreHandleRequest, req)
+
+  def testBasicAuth(self):
+    for user in ["", "joe", "user name with spaces"]:
+      for pw in ["", "-", ":", "foobar", "Foo Bar Baz", "@@@", "###",
+                 "foo:bar:baz"]:
+        for wrong_pw in [True, False]:
+          basic_auth = "%s:%s" % (user, pw)
+          if wrong_pw:
+            basic_auth += "WRONG"
+          headers = {
+              http.HTTP_AUTHORIZATION:
+                "Basic %s" % (basic_auth.encode("base64").strip(), ),
+            }
+          req = http.server._HttpServerRequest("GET", "/", headers, None)
+
+          ac = _SimpleAuthenticator(user, pw)
+          self.assertFalse(ac.called)
+          ra = _FakeRequestAuth("area1", True, ac)
+          if wrong_pw:
+            try:
+              ra.PreHandleRequest(req)
+            except http.HttpUnauthorized, err:
+              www_auth = err.headers[http.HTTP_WWW_AUTHENTICATE]
+              self.assert_(www_auth.startswith(http.auth.HTTP_BASIC_AUTH))
+            else:
+              self.fail("Didn't raise HttpUnauthorized")
+          else:
+            ra.PreHandleRequest(req)
+          self.assert_(ac.called)
+
+
+class TestReadPasswordFile(testutils.GanetiTestCase):
+  """Tests for http.auth.ReadPasswordFile"""
+
+  def setUp(self):
+    testutils.GanetiTestCase.setUp(self)
+
+    self.tmpfile = tempfile.NamedTemporaryFile()
+
+  def tearDown(self):
+    self.tmpfile.close()
+    testutils.GanetiTestCase.tearDown(self)
+
+  def testSimple(self):
+    self.tmpfile.write("user1 password")
+    self.tmpfile.flush()
+
+    users = http.auth.ReadPasswordFile(self.tmpfile.name)
+    self.assertEqual(len(users), 1)
+    self.assertEqual(users["user1"].password, "password")
+    self.assertEqual(len(users["user1"].options), 0)
+
+  def testOptions(self):
+    self.tmpfile.write("# Passwords\n")
+    self.tmpfile.write("user1 password\n")
+    self.tmpfile.write("\n")
+    self.tmpfile.write("# Comment\n")
+    self.tmpfile.write("user2 pw write,read\n")
+    self.tmpfile.write("   \t# Another comment\n")
+    self.tmpfile.write("invalidline\n")
+    self.tmpfile.flush()
+
+    users = http.auth.ReadPasswordFile(self.tmpfile.name)
+    self.assertEqual(len(users), 2)
+    self.assertEqual(users["user1"].password, "password")
+    self.assertEqual(len(users["user1"].options), 0)
+
+    self.assertEqual(users["user2"].password, "pw")
+    self.assertEqual(users["user2"].options, ["write", "read"])
+
+
 if __name__ == '__main__':
-  unittest.main()
+  testutils.GanetiTestProgram()