DRBD IPv6 support
[ganeti-local] / test / ganeti.http_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2007, 2008 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the http module"""
23
24
import os
import sys
import tempfile
import time
import unittest
29
30 from ganeti import http
31
32 import ganeti.http.server
33 import ganeti.http.client
34 import ganeti.http.auth
35
36 import testutils
37
38
class TestStartLines(unittest.TestCase):
  """Check the string form of HTTP start-line objects"""

  def testClientToServerStartLine(self):
    """Request line: method, path and protocol version"""
    line = http.HttpClientToServerStartLine("GET", "/", "HTTP/1.1")
    wanted = "GET / HTTP/1.1"
    self.assertEqual(str(line), wanted)

  def testServerToClientStartLine(self):
    """Status line: protocol version, numeric status code and reason"""
    line = http.HttpServerToClientStartLine("HTTP/1.1", 200, "OK")
    wanted = "HTTP/1.1 200 OK"
    self.assertEqual(str(line), wanted)
51
52
class TestMisc(unittest.TestCase):
  """Miscellaneous tests"""

  def _TestDateTimeHeader(self, gmnow, expected):
    """Check the output of http.server._DateTimeHeader for one time tuple.

    @type gmnow: tuple
    @param gmnow: 9-field time tuple in C{time.struct_time} field order
    @type expected: str
    @param expected: expected header value

    """
    self.assertEqual(http.server._DateTimeHeader(gmnow=gmnow), expected)

  def testDateTimeHeader(self):
    """Test ganeti.http.server._DateTimeHeader"""
    # The weekday name is taken from field 6 of the tuple (0 = Monday) rather
    # than derived from the date: the two 2008-12-31 00:00:00 cases below
    # differ only in that field and yield "Mon" and "Sun" respectively.
    self._TestDateTimeHeader((2008, 1, 2, 3, 4, 5, 3, 0, 0),
                             "Thu, 02 Jan 2008 03:04:05 GMT")
    self._TestDateTimeHeader((2008, 1, 1, 0, 0, 0, 0, 0, 0),
                             "Mon, 01 Jan 2008 00:00:00 GMT")
    self._TestDateTimeHeader((2008, 12, 31, 0, 0, 0, 0, 0, 0),
                             "Mon, 31 Dec 2008 00:00:00 GMT")
    self._TestDateTimeHeader((2008, 12, 31, 23, 59, 59, 0, 0, 0),
                             "Mon, 31 Dec 2008 23:59:59 GMT")
    self._TestDateTimeHeader((2008, 12, 31, 0, 0, 0, 6, 0, 0),
                             "Sun, 31 Dec 2008 00:00:00 GMT")

  def testHttpServerRequest(self):
    """Test ganeti.http.server._HttpServerRequest"""
    server_request = http.server._HttpServerRequest("GET", "/", None, None)

    # These are expected by users of the HTTP server
    self.assertTrue(hasattr(server_request, "request_method"))
    self.assertTrue(hasattr(server_request, "request_path"))
    self.assertTrue(hasattr(server_request, "request_headers"))
    self.assertTrue(hasattr(server_request, "request_body"))
    self.assertTrue(isinstance(server_request.resp_headers, dict))
    self.assertTrue(hasattr(server_request, "private"))

  def testServerSizeLimits(self):
    """Test HTTP server size limits"""
    message_reader_class = http.server._HttpClientToServerMessageReader
    self.assertTrue(message_reader_class.START_LINE_LENGTH_MAX > 0)
    self.assertTrue(message_reader_class.HEADER_LENGTH_MAX > 0)

  def testClientSizeLimits(self):
    """Test HTTP client size limits"""
    message_reader_class = http.client._HttpServerToClientMessageReader
    self.assertTrue(message_reader_class.START_LINE_LENGTH_MAX > 0)
    self.assertTrue(message_reader_class.HEADER_LENGTH_MAX > 0)

  def testFormatAuthHeader(self):
    """Test ganeti.http.auth._FormatAuthHeader"""
    # No parameters: only the scheme name
    self.assertEqual(http.auth._FormatAuthHeader("Basic", {}),
                     "Basic")
    # Plain values are emitted bare; empty values and values containing a
    # comma are quoted
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "bar", }),
                     "Basic foo=bar")
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "", }),
                     "Basic foo=\"\"")
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "x,y", }),
                     "Basic foo=\"x,y\"")
    params = {
      "foo": "x,y",
      "realm": "secure",
      }
    # It's a dict whose order isn't guaranteed, hence checking a list
    self.assertTrue(http.auth._FormatAuthHeader("Digest", params) in
                    ("Digest foo=\"x,y\" realm=secure",
                     "Digest realm=secure foo=\"x,y\""))
111                  ("Digest foo=\"x,y\" realm=secure",
112                   "Digest realm=secure foo=\"x,y\""))
113
114
class _FakeRequestAuth(http.auth.HttpServerRequestAuthentication):
  """Configurable HttpServerRequestAuthentication for tests.

  The realm, whether authentication is required and the credential check are
  all fixed when the object is constructed.

  """
  def __init__(self, realm, authreq, authenticate_fn):
    """Initializes this class.

    @param realm: value returned by L{GetAuthRealm}
    @param authreq: value returned by L{AuthenticationRequired}
    @param authenticate_fn: callable used by L{Authenticate}, or a false
      value to make L{Authenticate} raise NotImplementedError

    """
    http.auth.HttpServerRequestAuthentication.__init__(self)

    self.realm = realm
    self.authreq = authreq
    self.authenticate_fn = authenticate_fn

  def AuthenticationRequired(self, req):
    """Returns the configured flag; C{req} is ignored."""
    return self.authreq

  def GetAuthRealm(self, req):
    """Returns the configured realm; C{req} is ignored."""
    return self.realm

  def Authenticate(self, *args):
    """Forwards all arguments to the configured callable."""
    check = self.authenticate_fn
    if not check:
      raise NotImplementedError()
    return check(*args)
131       return self.authenticate_fn(*args)
132     raise NotImplementedError()
133
134
class TestAuth(unittest.TestCase):
  """Authentication tests"""

  hsra = http.auth.HttpServerRequestAuthentication

  def testConstants(self):
    """Password scheme markers are upper-case and wrapped in braces"""
    schemes = (self.hsra._CLEARTEXT_SCHEME, self.hsra._HA1_SCHEME)
    for name in schemes:
      self.assertEqual(name, name.upper())
      self.assertTrue(name.startswith("{"))
      self.assertTrue(name.endswith("}"))

  def _testVerifyBasicAuthPassword(self, realm, user, password, expected):
    """Calls VerifyBasicAuthPassword on a fresh fake authenticator.

    @param realm: realm reported by the fake authenticator
    @param user: user name supplied by the client
    @param password: password supplied by the client
    @param expected: stored password, optionally with a scheme prefix

    """
    fake = _FakeRequestAuth(realm, False, None)
    return fake.VerifyBasicAuthPassword(None, user, password, expected)

  def testVerifyBasicAuthPassword(self):
    check = self._testVerifyBasicAuthPassword

    good_pws = ["pw", "pw{", "pw}", "pw{}", "pw{x}y", "}pw",
                "0", "123", "foo...:xyz", "TeST"]

    for pw in good_pws:
      # Untagged and cleartext-tagged (in any case) passwords must match
      for prefix in ["", "{cleartext}", "{ClearText}", "{CLEARTEXT}"]:
        self.assertTrue(check("abc", "user", pw, prefix + pw))

      # A different stored password must not match
      self.assertFalse(check("abc", "user", pw, "something"))

      # Unknown schemes must not match, whatever their case
      for prefix in ["{000}", "{unk}", "{Unk}", "{UNK}"]:
        self.assertFalse(check("abc", "user", pw, prefix + pw))

    # Stored value with a malformed scheme prefix (no closing brace)
    self.assertFalse(check("abc", "user", "pw", "{something"))

    # Hash is MD5("user:This is only a test:pw")
    realm = "This is only a test"
    ha1 = "92ea58ae804481498c257b2f65561a17"

    for prefix in ["{ha1}", "{HA1}"]:
      self.assertTrue(check(realm, "user", "pw", prefix + ha1))

    # An HA1 check without a realm trips an assertion
    self.assertRaises(AssertionError, check, None, "user", "pw",
                      "{HA1}" + ha1)

    # Changing realm, user or password must each break the HA1 match
    self.assertFalse(check("Admin area", "user", "pw", "{HA1}" + ha1))
    self.assertFalse(check(realm, "someone", "pw", "{HA1}" + ha1))
    self.assertFalse(check(realm, "user", "something", "{HA1}" + ha1))
188     self.failIf(tvbap("This is only a test", "user", "something",
189                       "{HA1}92ea58ae804481498c257b2f65561a17"))
190
191
class _SimpleAuthenticator:
  """Authentication callback accepting exactly one user/password pair.

  The C{called} attribute records whether the callback was invoked.

  """
  def __init__(self, user, password):
    self.called = False
    self.user = user
    self.password = password

  def __call__(self, req, user, password):
    """Returns True only if both user and password match; C{req} is unused."""
    self.called = True
    if self.user != user:
      return False
    return self.password == password
199     self.called = True
200     return self.user == user and self.password == password
201
202
class TestHttpServerRequestAuthentication(unittest.TestCase):
  """Tests for HttpServerRequestAuthentication.PreHandleRequest"""

  def testNoAuth(self):
    """No Authorization header and no authentication required"""
    req = http.server._HttpServerRequest("GET", "/", None, None)
    _FakeRequestAuth("area1", False, None).PreHandleRequest(req)

  def testNoRealm(self):
    """An Authorization header without a configured realm trips an assert"""
    headers = { http.HTTP_AUTHORIZATION: "", }
    req = http.server._HttpServerRequest("GET", "/", headers, None)
    ra = _FakeRequestAuth(None, False, None)
    self.assertRaises(AssertionError, ra.PreHandleRequest, req)

  def testNoScheme(self):
    """An empty Authorization header is rejected"""
    headers = { http.HTTP_AUTHORIZATION: "", }
    req = http.server._HttpServerRequest("GET", "/", headers, None)
    ra = _FakeRequestAuth("area1", False, None)
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)

  def testUnknownScheme(self):
    """An unsupported authentication scheme is rejected"""
    headers = { http.HTTP_AUTHORIZATION: "NewStyleAuth abc", }
    req = http.server._HttpServerRequest("GET", "/", headers, None)
    ra = _FakeRequestAuth("area1", False, None)
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)

  def testInvalidBase64(self):
    """Basic credentials that are not valid base64 are rejected"""
    headers = { http.HTTP_AUTHORIZATION: "Basic x_=_", }
    req = http.server._HttpServerRequest("GET", "/", headers, None)
    ra = _FakeRequestAuth("area1", False, None)
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)

  def testMalformedCredentialsForPublicResource(self):
    """Basic credentials lacking a "user:password" separator are rejected.

    This applies even though authentication is not required for the
    resource.

    """
    # This test must not share its name with testAuthForPublicResource,
    # otherwise the later definition silently replaces this one
    headers = {
      http.HTTP_AUTHORIZATION: "Basic %s" % ("foo".encode("base64").strip(), ),
      }
    req = http.server._HttpServerRequest("GET", "/", headers, None)
    ra = _FakeRequestAuth("area1", False, None)
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)

  def testAuthForPublicResource(self):
    """Supplied credentials are checked even for public resources"""
    headers = {
      http.HTTP_AUTHORIZATION:
        "Basic %s" % ("foo:bar".encode("base64").strip(), ),
      }
    req = http.server._HttpServerRequest("GET", "/", headers, None)
    ac = _SimpleAuthenticator("foo", "bar")
    ra = _FakeRequestAuth("area1", False, ac)
    ra.PreHandleRequest(req)

    req = http.server._HttpServerRequest("GET", "/", headers, None)
    ac = _SimpleAuthenticator("something", "else")
    ra = _FakeRequestAuth("area1", False, ac)
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)

  def testInvalidRequestHeader(self):
    """Malformed Authorization headers raise the expected HTTP error"""
    checks = {
      http.HttpUnauthorized: ["", "\t", "-", ".", "@", "<", ">", "Digest",
                              "basic %s" % "foobar".encode("base64").strip()],
      http.HttpBadRequest: ["Basic"],
      }

    # Distinct names for the value list and the per-request header dict, so
    # the list being iterated is never rebound inside its own loop
    for exc, values in checks.items():
      for value in values:
        headers = { http.HTTP_AUTHORIZATION: value, }
        req = http.server._HttpServerRequest("GET", "/", headers, None)
        ra = _FakeRequestAuth("area1", False, None)
        self.assertRaises(exc, ra.PreHandleRequest, req)

  def testBasicAuth(self):
    """Basic authentication with correct and wrong passwords"""
    for user in ["", "joe", "user name with spaces"]:
      for pw in ["", "-", ":", "foobar", "Foo Bar Baz", "@@@", "###",
                 "foo:bar:baz"]:
        for wrong_pw in [True, False]:
          basic_auth = "%s:%s" % (user, pw)
          if wrong_pw:
            basic_auth += "WRONG"
          headers = {
              http.HTTP_AUTHORIZATION:
                "Basic %s" % (basic_auth.encode("base64").strip(), ),
            }
          req = http.server._HttpServerRequest("GET", "/", headers, None)

          ac = _SimpleAuthenticator(user, pw)
          self.assertFalse(ac.called)
          ra = _FakeRequestAuth("area1", True, ac)
          if wrong_pw:
            try:
              ra.PreHandleRequest(req)
            except http.HttpUnauthorized:
              # sys.exc_info() keeps this clause valid syntax on every Python
              # version ("except X, err" is Python 2 only, "except X as err"
              # needs Python 2.6 or later)
              err = sys.exc_info()[1]
              www_auth = err.headers[http.HTTP_WWW_AUTHENTICATE]
              self.assertTrue(www_auth.startswith(http.auth.HTTP_BASIC_AUTH))
            else:
              self.fail("Didn't raise HttpUnauthorized")
          else:
            ra.PreHandleRequest(req)
          self.assertTrue(ac.called)
295             ra.PreHandleRequest(req)
296           self.assert_(ac.called)
297
298
class TestReadPasswordFile(testutils.GanetiTestCase):
  """Tests for http.auth.ReadPasswordFile"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpfile = tempfile.NamedTemporaryFile()

  def tearDown(self):
    # Close (and thereby delete) the temporary file explicitly instead of
    # leaving it open until the object happens to be garbage-collected
    self.tmpfile.close()

    testutils.GanetiTestCase.tearDown(self)

  def testSimple(self):
    """A single entry without options"""
    self.tmpfile.write("user1 password")
    self.tmpfile.flush()

    users = http.auth.ReadPasswordFile(self.tmpfile.name)
    self.assertEqual(len(users), 1)
    self.assertEqual(users["user1"].password, "password")
    self.assertEqual(len(users["user1"].options), 0)

  def testOptions(self):
    """Comments, blank and invalid lines are skipped; options are split"""
    self.tmpfile.write("# Passwords\n")
    self.tmpfile.write("user1 password\n")
    self.tmpfile.write("\n")
    self.tmpfile.write("# Comment\n")
    self.tmpfile.write("user2 pw write,read\n")
    self.tmpfile.write("   \t# Another comment\n")
    self.tmpfile.write("invalidline\n")
    self.tmpfile.flush()

    users = http.auth.ReadPasswordFile(self.tmpfile.name)
    self.assertEqual(len(users), 2)
    self.assertEqual(users["user1"].password, "password")
    self.assertEqual(len(users["user1"].options), 0)

    self.assertEqual(users["user2"].password, "pw")
    self.assertEqual(users["user2"].options, ["write", "read"])
329     self.assertEqual(users["user2"].password, "pw")
330     self.assertEqual(users["user2"].options, ["write", "read"])
331
332
if __name__ == "__main__":
  # Run every test case in this module through testutils.GanetiTestProgram
  testutils.GanetiTestProgram()