import os
import unittest
import time
+import tempfile
from ganeti import http
import ganeti.http.client
import ganeti.http.auth
+import testutils
+
class TestStartLines(unittest.TestCase):
  """Test cases for start line classes"""

  def testHttpServerRequest(self):
    """Test ganeti.http.server._HttpServerRequest"""
    server_request = http.server._HttpServerRequest("GET", "/", None, None)
    # These are expected by users of the HTTP server
    self.assert_(hasattr(server_request, "request_method"))
    # NOTE(review): message_reader_class is not defined anywhere in the
    # visible code. The two assertions below look like the tail of a
    # separate size-limit test whose header (and the assignment of
    # message_reader_class) is missing from this view; placed here they
    # would raise NameError. TODO confirm against the full file.
    self.assert_(message_reader_class.START_LINE_LENGTH_MAX > 0)
    self.assert_(message_reader_class.HEADER_LENGTH_MAX > 0)

  def testFormatAuthHeader(self):
    """Check http.auth._FormatAuthHeader for several parameter values.

    Per the expectations below: with no parameters only the scheme is
    returned; parameters follow the scheme as key=value separated by
    spaces; empty values and values containing a comma are double-quoted,
    while simple values such as "bar" are emitted bare.

    """
    self.assertEqual(http.auth._FormatAuthHeader("Basic", {}),
                     "Basic")
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "bar", }),
                     "Basic foo=bar")
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "", }),
                     "Basic foo=\"\"")
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "x,y", }),
                     "Basic foo=\"x,y\"")
    params = {
      "foo": "x,y",
      "realm": "secure",
      }
    # It's a dict whose order isn't guaranteed, hence accepting either
    # parameter ordering
    self.assert_(http.auth._FormatAuthHeader("Digest", params) in
                 ("Digest foo=\"x,y\" realm=secure",
                  "Digest realm=secure foo=\"x,y\""))


class _FakeRequestAuth(http.auth.HttpServerRequestAuthentication):
  """Configurable stand-in for a real request-authentication mixin.

  @param realm: value handed back by L{GetAuthRealm} for every request
  @param authreq: value handed back by L{AuthenticationRequired} for every
      request
  @param authenticate_fn: callable that L{Authenticate} forwards its
      arguments to; if it is falsy, L{Authenticate} raises
      C{NotImplementedError}

  """
  def __init__(self, realm, authreq, authenticate_fn):
    http.auth.HttpServerRequestAuthentication.__init__(self)
    self.realm = realm
    self.authreq = authreq
    self.authenticate_fn = authenticate_fn

  def AuthenticationRequired(self, req):
    """Return the fixed C{authreq} flag; C{req} is not inspected."""
    return self.authreq

  def GetAuthRealm(self, req):
    """Return the fixed realm; C{req} is not inspected."""
    return self.realm

  def Authenticate(self, *args):
    """Delegate to the configured callable, passing all arguments through.

    @raise NotImplementedError: when no callable was configured

    """
    delegate = self.authenticate_fn
    if not delegate:
      raise NotImplementedError()
    return delegate(*args)


class TestAuth(unittest.TestCase):
  """Authentication tests"""

  hsra = http.auth.HttpServerRequestAuthentication

  def testConstants(self):
    """Stored-password scheme prefixes are upper-case and brace-wrapped."""
    for prefix in (self.hsra._CLEARTEXT_SCHEME, self.hsra._HA1_SCHEME):
      self.assertEqual(prefix, prefix.upper())
      self.assertTrue(prefix.startswith("{"))
      self.assertTrue(prefix.endswith("}"))

  def _testVerifyBasicAuthPassword(self, realm, user, password, expected):
    """Run VerifyBasicAuthPassword on a fake authenticator for C{realm}.

    @return: the result of VerifyBasicAuthPassword (the request is None)

    """
    fake = _FakeRequestAuth(realm, False, None)
    return fake.VerifyBasicAuthPassword(None, user, password, expected)

  def testVerifyBasicAuthPassword(self):
    """Check verification of an HA1-hashed stored password."""
    verify = self._testVerifyBasicAuthPassword
    # Expected to verify for realm "This is only a test", user "user" and
    # password "pw" (first assertion); other realms or users must fail
    stored = "{HA1}92ea58ae804481498c257b2f65561a17"

    self.assertTrue(verify("This is only a test", "user", "pw", stored))
    # A missing realm raises AssertionError instead of returning False
    self.assertRaises(AssertionError, verify, None, "user", "pw", stored)
    self.assertFalse(verify("Admin area", "user", "pw", stored))
    self.assertFalse(verify("This is only a test", "someone", "pw", stored))


+class _SimpleAuthenticator:
+ def __init__(self, user, password):
+ self.user = user
+ self.password = password
+ self.called = False
+
+ def __call__(self, req, user, password):
+ self.called = True
+ return self.user == user and self.password == password
+
+
class TestHttpServerRequestAuthentication(unittest.TestCase):
  """Tests for PreHandleRequest, driven through _FakeRequestAuth."""

  @staticmethod
  def _MakeRequest(headers):
    """Build a GET request for "/" with C{headers} and no body."""
    return http.server._HttpServerRequest("GET", "/", headers, None)

  @staticmethod
  def _BasicAuthHeaders(credentials):
    """Return headers carrying a Basic Authorization for C{credentials}.

    @param credentials: raw (unencoded) credential string, base64-encoded
        here

    """
    return {
      http.HTTP_AUTHORIZATION:
        "Basic %s" % (credentials.encode("base64").strip(), ),
      }

  def testNoAuth(self):
    """Without credentials, and with auth not required, nothing is raised."""
    req = self._MakeRequest(None)
    _FakeRequestAuth("area1", False, None).PreHandleRequest(req)

  def testNoRealm(self):
    """An Authorization header with no configured realm: AssertionError."""
    req = self._MakeRequest({ http.HTTP_AUTHORIZATION: "", })
    ra = _FakeRequestAuth(None, False, None)
    self.assertRaises(AssertionError, ra.PreHandleRequest, req)

  def testNoScheme(self):
    """An empty Authorization header is rejected."""
    req = self._MakeRequest({ http.HTTP_AUTHORIZATION: "", })
    ra = _FakeRequestAuth("area1", False, None)
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)

  def testUnknownScheme(self):
    """An unrecognised authentication scheme is rejected."""
    req = self._MakeRequest({ http.HTTP_AUTHORIZATION: "NewStyleAuth abc", })
    ra = _FakeRequestAuth("area1", False, None)
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)

  def testInvalidBase64(self):
    """Basic credentials that do not decode as base64 are rejected."""
    req = self._MakeRequest({ http.HTTP_AUTHORIZATION: "Basic x_=_", })
    ra = _FakeRequestAuth("area1", False, None)
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)

  def testAuthForPublicResourceNoPassword(self):
    """Basic credentials without a password part ("foo") are rejected.

    This applies even though authentication is not required for the
    resource.

    """
    # Distinct name so this test is not shadowed by
    # testAuthForPublicResource below and actually gets run
    req = self._MakeRequest(self._BasicAuthHeaders("foo"))
    ra = _FakeRequestAuth("area1", False, None)
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)

  def testAuthForPublicResource(self):
    """Credentials sent for a public resource are still verified.

    Matching credentials are accepted; mismatching ones are rejected.

    """
    headers = self._BasicAuthHeaders("foo:bar")

    req = self._MakeRequest(headers)
    ac = _SimpleAuthenticator("foo", "bar")
    ra = _FakeRequestAuth("area1", False, ac)
    ra.PreHandleRequest(req)

    req = self._MakeRequest(headers)
    ac = _SimpleAuthenticator("something", "else")
    ra = _FakeRequestAuth("area1", False, ac)
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)

  def testInvalidRequestHeader(self):
    """Malformed Authorization values raise the expected HTTP errors."""
    checks = {
      http.HttpUnauthorized: ["", "\t", "-", ".", "@", "<", ">", "Digest",
                              "basic %s" % "foobar".encode("base64").strip()],
      http.HttpBadRequest: ["Basic"],
      }

    # Use a separate name for each header value so the list being iterated
    # is not rebound inside its own loop
    for exc, values in checks.items():
      for value in values:
        req = self._MakeRequest({ http.HTTP_AUTHORIZATION: value, })
        ra = _FakeRequestAuth("area1", False, None)
        self.assertRaises(exc, ra.PreHandleRequest, req)

  def testBasicAuth(self):
    """Exercise Basic auth over many user/password combinations.

    Correct credentials are accepted. Wrong ones raise HttpUnauthorized
    carrying a Basic WWW-Authenticate header.

    """
    for user in ["", "joe", "user name with spaces"]:
      for pw in ["", "-", ":", "foobar", "Foo Bar Baz", "@@@", "###",
                 "foo:bar:baz"]:
        for wrong_pw in [True, False]:
          basic_auth = "%s:%s" % (user, pw)
          if wrong_pw:
            basic_auth += "WRONG"
          req = self._MakeRequest(self._BasicAuthHeaders(basic_auth))

          ac = _SimpleAuthenticator(user, pw)
          self.assertFalse(ac.called)
          ra = _FakeRequestAuth("area1", True, ac)
          if wrong_pw:
            try:
              ra.PreHandleRequest(req)
            except http.HttpUnauthorized as err:
              www_auth = err.headers[http.HTTP_WWW_AUTHENTICATE]
              self.assertTrue(www_auth.startswith(http.auth.HTTP_BASIC_AUTH))
            else:
              self.fail("Didn't raise HttpUnauthorized")
          else:
            ra.PreHandleRequest(req)
          # Both outcomes must have gone through the authenticator
          self.assertTrue(ac.called)


class TestReadPasswordFile(testutils.GanetiTestCase):
  """Tests for http.auth.ReadPasswordFile."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpfile = tempfile.NamedTemporaryFile()

  def tearDown(self):
    # Close the temporary file explicitly; NamedTemporaryFile (delete=True
    # by default) removes it on close, so nothing is left to the GC
    self.tmpfile.close()

    testutils.GanetiTestCase.tearDown(self)

  def testSimple(self):
    """A single "user password" line yields one user without options."""
    self.tmpfile.write("user1 password")
    self.tmpfile.flush()

    users = http.auth.ReadPasswordFile(self.tmpfile.name)
    self.assertEqual(len(users), 1)
    self.assertEqual(users["user1"].password, "password")
    self.assertEqual(len(users["user1"].options), 0)

  def testOptions(self):
    """Skip comment, blank and one-word lines; split options on commas."""
    self.tmpfile.writelines([
      "# Passwords\n",
      "user1 password\n",
      "\n",
      "# Comment\n",
      "user2 pw write,read\n",
      " \t# Another comment\n",
      "invalidline\n",
      ])
    self.tmpfile.flush()

    users = http.auth.ReadPasswordFile(self.tmpfile.name)
    self.assertEqual(len(users), 2)
    self.assertEqual(users["user1"].password, "password")
    self.assertEqual(len(users["user1"].options), 0)

    self.assertEqual(users["user2"].password, "pw")
    self.assertEqual(users["user2"].options, ["write", "read"])


if __name__ == '__main__':
  # Run the tests through Ganeti's test-program wrapper from testutils
  testutils.GanetiTestProgram()