Modify gnt-node add to call external script
[ganeti-local] / test / ganeti.http_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2007, 2008 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the http module"""
23
24
25 import os
26 import unittest
27 import time
28 import tempfile
29
30 from ganeti import http
31
32 import ganeti.http.server
33 import ganeti.http.client
34 import ganeti.http.auth
35
36 import testutils
37
38
class TestStartLines(unittest.TestCase):
  """Tests for the HTTP start line classes"""

  def testClientToServerStartLine(self):
    """Test client to server start line (HTTP request)"""
    line = http.HttpClientToServerStartLine("GET", "/", "HTTP/1.1")
    expected = "GET / HTTP/1.1"
    self.assertEqual(str(line), expected)

  def testServerToClientStartLine(self):
    """Test server to client start line (HTTP response)"""
    line = http.HttpServerToClientStartLine("HTTP/1.1", 200, "OK")
    expected = "HTTP/1.1 200 OK"
    self.assertEqual(str(line), expected)
51
52
class TestMisc(unittest.TestCase):
  """Miscellaneous tests"""

  def _TestDateTimeHeader(self, gmnow, expected):
    """Check that _DateTimeHeader formats C{gmnow} as C{expected}"""
    formatted = http.server._DateTimeHeader(gmnow=gmnow)
    self.assertEqual(formatted, expected)

  def testDateTimeHeader(self):
    """Test ganeti.http._DateTimeHeader"""
    # The expected weekday names follow the seventh tuple element (tm_wday),
    # not the calendar date -- compare the two "31 Dec 2008 00:00:00" cases
    cases = [
      ((2008, 1, 2, 3, 4, 5, 3, 0, 0), "Thu, 02 Jan 2008 03:04:05 GMT"),
      ((2008, 1, 1, 0, 0, 0, 0, 0, 0), "Mon, 01 Jan 2008 00:00:00 GMT"),
      ((2008, 12, 31, 0, 0, 0, 0, 0, 0), "Mon, 31 Dec 2008 00:00:00 GMT"),
      ((2008, 12, 31, 23, 59, 59, 0, 0, 0), "Mon, 31 Dec 2008 23:59:59 GMT"),
      ((2008, 12, 31, 0, 0, 0, 6, 0, 0), "Sun, 31 Dec 2008 00:00:00 GMT"),
      ]
    for (gmnow, expected) in cases:
      self._TestDateTimeHeader(gmnow, expected)

  def testHttpServerRequest(self):
    """Test ganeti.http.server._HttpServerRequest"""
    req = http.server._HttpServerRequest("GET", "/", None, None)

    # Users of the HTTP server rely on these attributes being present
    for attr in ("request_method", "request_path", "request_headers",
                 "request_body"):
      self.assert_(hasattr(req, attr))
    self.assert_(isinstance(req.resp_headers, dict))
    self.assert_(hasattr(req, "private"))

  def testServerSizeLimits(self):
    """Test HTTP server size limits"""
    reader_cls = http.server._HttpClientToServerMessageReader
    for name in ("START_LINE_LENGTH_MAX", "HEADER_LENGTH_MAX"):
      self.assert_(getattr(reader_cls, name) > 0)

  def testFormatAuthHeader(self):
    """Test ganeti.http.auth._FormatAuthHeader"""
    fmt = http.auth._FormatAuthHeader

    basic_cases = [
      ({}, "Basic"),
      ({ "foo": "bar", }, "Basic foo=bar"),
      ({ "foo": "", }, "Basic foo=\"\""),
      ({ "foo": "x,y", }, "Basic foo=\"x,y\""),
      ]
    for (params, expected) in basic_cases:
      self.assertEqual(fmt("Basic", params), expected)

    # Dict iteration order isn't guaranteed, so accept either parameter order
    digest_params = {
      "foo": "x,y",
      "realm": "secure",
      }
    self.assert_(fmt("Digest", digest_params) in
                 ("Digest foo=\"x,y\" realm=secure",
                  "Digest realm=secure foo=\"x,y\""))
107
108
class _FakeRequestAuth(http.auth.HttpServerRequestAuthentication):
  """Configurable stand-in for HttpServerRequestAuthentication.

  @param realm: value returned by L{GetAuthRealm}
  @param authreq: value returned by L{AuthenticationRequired}
  @param authenticate_fn: callable that L{Authenticate} delegates to; when
      it is false, L{Authenticate} raises NotImplementedError

  """
  def __init__(self, realm, authreq, authenticate_fn):
    http.auth.HttpServerRequestAuthentication.__init__(self)

    (self.realm, self.authreq, self.authenticate_fn) = \
      (realm, authreq, authenticate_fn)

  def AuthenticationRequired(self, req):
    return self.authreq

  def GetAuthRealm(self, req):
    return self.realm

  def Authenticate(self, *args):
    fn = self.authenticate_fn
    if not fn:
      raise NotImplementedError()
    return fn(*args)
127
128
class TestAuth(unittest.TestCase):
  """Authentication tests"""

  hsra = http.auth.HttpServerRequestAuthentication

  def testConstants(self):
    """Password scheme constants are upper-case and enclosed in braces"""
    for scheme in (self.hsra._CLEARTEXT_SCHEME, self.hsra._HA1_SCHEME):
      self.assertEqual(scheme, scheme.upper())
      self.assert_(scheme.startswith("{"))
      self.assert_(scheme.endswith("}"))

  def _testVerifyBasicAuthPassword(self, realm, user, password, expected):
    """Run VerifyBasicAuthPassword on a fake, non-mandatory authenticator"""
    return _FakeRequestAuth(realm, False, None).VerifyBasicAuthPassword(
      None, user, password, expected)

  def testVerifyBasicAuthPassword(self):
    check = self._testVerifyBasicAuthPassword

    good_pws = ["pw", "pw{", "pw}", "pw{}", "pw{x}y", "}pw",
                "0", "123", "foo...:xyz", "TeST"]

    for pw in good_pws:
      # Cleartext passwords, bare and with a scheme prefix in various cases
      for prefix in ("", "{cleartext}", "{ClearText}", "{CLEARTEXT}"):
        self.assert_(check("abc", "user", pw, prefix + pw))

      # Wrong password
      self.failIf(check("abc", "user", pw, "something"))

      # Unknown schemes
      for prefix in ("{000}", "{unk}", "{Unk}", "{UNK}"):
        self.failIf(check("abc", "user", pw, prefix + pw))

    # Scheme prefix without a closing brace
    self.failIf(check("abc", "user", "pw", "{something"))

    # Hash is MD5("user:This is only a test:pw")
    ha1 = "92ea58ae804481498c257b2f65561a17"
    realm = "This is only a test"

    for scheme in ("{ha1}", "{HA1}"):
      self.assert_(check(realm, "user", "pw", scheme + ha1))

    self.failUnlessRaises(AssertionError, check, None, "user", "pw",
                          "{HA1}" + ha1)
    self.failIf(check("Admin area", "user", "pw", "{HA1}" + ha1))
    self.failIf(check(realm, "someone", "pw", "{HA1}" + ha1))
    self.failIf(check(realm, "user", "something", "{HA1}" + ha1))
184
185
186 class _SimpleAuthenticator:
187   def __init__(self, user, password):
188     self.user = user
189     self.password = password
190     self.called = False
191
192   def __call__(self, req, user, password):
193     self.called = True
194     return self.user == user and self.password == password
195
196
197 class TestHttpServerRequestAuthentication(unittest.TestCase):
198   def testNoAuth(self):
199     req = http.server._HttpServerRequest("GET", "/", None, None)
200     _FakeRequestAuth("area1", False, None).PreHandleRequest(req)
201
202   def testNoRealm(self):
203     headers = { http.HTTP_AUTHORIZATION: "", }
204     req = http.server._HttpServerRequest("GET", "/", headers, None)
205     ra = _FakeRequestAuth(None, False, None)
206     self.assertRaises(AssertionError, ra.PreHandleRequest, req)
207
208   def testNoScheme(self):
209     headers = { http.HTTP_AUTHORIZATION: "", }
210     req = http.server._HttpServerRequest("GET", "/", headers, None)
211     ra = _FakeRequestAuth("area1", False, None)
212     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
213
214   def testUnknownScheme(self):
215     headers = { http.HTTP_AUTHORIZATION: "NewStyleAuth abc", }
216     req = http.server._HttpServerRequest("GET", "/", headers, None)
217     ra = _FakeRequestAuth("area1", False, None)
218     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
219
220   def testInvalidBase64(self):
221     headers = { http.HTTP_AUTHORIZATION: "Basic x_=_", }
222     req = http.server._HttpServerRequest("GET", "/", headers, None)
223     ra = _FakeRequestAuth("area1", False, None)
224     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
225
226   def testAuthForPublicResource(self):
227     headers = {
228       http.HTTP_AUTHORIZATION: "Basic %s" % ("foo".encode("base64").strip(), ),
229       }
230     req = http.server._HttpServerRequest("GET", "/", headers, None)
231     ra = _FakeRequestAuth("area1", False, None)
232     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
233
234   def testAuthForPublicResource(self):
235     headers = {
236       http.HTTP_AUTHORIZATION:
237         "Basic %s" % ("foo:bar".encode("base64").strip(), ),
238       }
239     req = http.server._HttpServerRequest("GET", "/", headers, None)
240     ac = _SimpleAuthenticator("foo", "bar")
241     ra = _FakeRequestAuth("area1", False, ac)
242     ra.PreHandleRequest(req)
243
244     req = http.server._HttpServerRequest("GET", "/", headers, None)
245     ac = _SimpleAuthenticator("something", "else")
246     ra = _FakeRequestAuth("area1", False, ac)
247     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
248
249   def testInvalidRequestHeader(self):
250     checks = {
251       http.HttpUnauthorized: ["", "\t", "-", ".", "@", "<", ">", "Digest",
252                               "basic %s" % "foobar".encode("base64").strip()],
253       http.HttpBadRequest: ["Basic"],
254       }
255
256     for exc, headers in checks.items():
257       for i in headers:
258         headers = { http.HTTP_AUTHORIZATION: i, }
259         req = http.server._HttpServerRequest("GET", "/", headers, None)
260         ra = _FakeRequestAuth("area1", False, None)
261         self.assertRaises(exc, ra.PreHandleRequest, req)
262
263   def testBasicAuth(self):
264     for user in ["", "joe", "user name with spaces"]:
265       for pw in ["", "-", ":", "foobar", "Foo Bar Baz", "@@@", "###",
266                  "foo:bar:baz"]:
267         for wrong_pw in [True, False]:
268           basic_auth = "%s:%s" % (user, pw)
269           if wrong_pw:
270             basic_auth += "WRONG"
271           headers = {
272               http.HTTP_AUTHORIZATION:
273                 "Basic %s" % (basic_auth.encode("base64").strip(), ),
274             }
275           req = http.server._HttpServerRequest("GET", "/", headers, None)
276
277           ac = _SimpleAuthenticator(user, pw)
278           self.assertFalse(ac.called)
279           ra = _FakeRequestAuth("area1", True, ac)
280           if wrong_pw:
281             try:
282               ra.PreHandleRequest(req)
283             except http.HttpUnauthorized, err:
284               www_auth = err.headers[http.HTTP_WWW_AUTHENTICATE]
285               self.assert_(www_auth.startswith(http.auth.HTTP_BASIC_AUTH))
286             else:
287               self.fail("Didn't raise HttpUnauthorized")
288           else:
289             ra.PreHandleRequest(req)
290           self.assert_(ac.called)
291
292
class TestReadPasswordFile(testutils.GanetiTestCase):
  """Tests for ganeti.http.auth.ReadPasswordFile"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)

    self.tmpfile = tempfile.NamedTemporaryFile()

  def tearDown(self):
    # Close (and thereby delete) the temporary file explicitly instead of
    # leaving the open file object to the garbage collector
    self.tmpfile.close()

    testutils.GanetiTestCase.tearDown(self)

  def testSimple(self):
    """A single entry without options"""
    self.tmpfile.write("user1 password")
    self.tmpfile.flush()

    users = http.auth.ReadPasswordFile(self.tmpfile.name)
    self.assertEqual(len(users), 1)
    self.assertEqual(users["user1"].password, "password")
    self.assertEqual(len(users["user1"].options), 0)

  def testOptions(self):
    """Comments, blank and invalid lines are skipped; options are parsed"""
    self.tmpfile.write("# Passwords\n")
    self.tmpfile.write("user1 password\n")
    self.tmpfile.write("\n")
    self.tmpfile.write("# Comment\n")
    self.tmpfile.write("user2 pw write,read\n")
    self.tmpfile.write("   \t# Another comment\n")
    self.tmpfile.write("invalidline\n")
    self.tmpfile.flush()

    users = http.auth.ReadPasswordFile(self.tmpfile.name)
    self.assertEqual(len(users), 2)
    self.assertEqual(users["user1"].password, "password")
    self.assertEqual(len(users["user1"].options), 0)

    self.assertEqual(users["user2"].password, "pw")
    self.assertEqual(users["user2"].options, ["write", "read"])
325
326
class TestClientRequest(unittest.TestCase):
  """Tests for ganeti.http.client.HttpClientRequest"""

  def testRepr(self):
    req = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
                                        headers=[], post_data="Hello World")
    self.assert_(repr(req).startswith("<"))

  def testNoHeaders(self):
    req = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
                                        headers=None)
    # headers=None ends up as an empty list
    self.assert_(isinstance(req.headers, list))
    self.assertEqual(req.headers, [])
    self.assertEqual(req.url, "https://localhost:1234/version")

  def testOldStyleHeaders(self):
    # Headers passed as a dict end up as a list of "Name: value" strings
    req = http.client.HttpClientRequest("localhost", 16481, "GET", "/vg_list",
                                        headers={
                                          "Content-type": "text/plain",
                                          "Accept": "text/html",
                                          })
    self.assert_(isinstance(req.headers, list))
    expected = [
      "Accept: text/html",
      "Content-type: text/plain",
      ]
    self.assertEqual(sorted(req.headers), expected)
    self.assertEqual(req.url, "https://localhost:16481/vg_list")

  def testNewStyleHeaders(self):
    header_lines = [
      "Accept: text/html",
      "Content-type: text/plain; charset=ascii",
      "Server: httpd 1.0",
      ]
    req = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
                                        headers=header_lines)
    self.assert_(isinstance(req.headers, list))
    self.assertEqual(sorted(req.headers), sorted(header_lines))
    self.assertEqual(req.url, "https://localhost:1234/version")

  def testPostData(self):
    req = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
                                        post_data="Hello World")
    self.assertEqual(req.post_data, "Hello World")

  def testNoPostData(self):
    req = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
    self.assertEqual(req.post_data, "")

  def testIdentity(self):
    # Requests to different host/port pairs use different connections, hence
    # each must have its own identity
    endpoints = [("localhost", 1234), ("localhost", 9999),
                 ("node1", 1234), ("node1", 9999)]
    requests = [http.client.HttpClientRequest(host, port, "GET", "/version")
                for (host, port) in endpoints]
    identities = set([req.identity for req in requests])
    self.assertEqual(len(identities), 4)

    # Another path on the same host/port shares the identity
    vglist_req = http.client.HttpClientRequest("localhost", 1234,
                                               "GET", "/vg_list")
    self.assertEqual(requests[0].identity, vglist_req.identity)
390
391
class TestClient(unittest.TestCase):
  """Tests for ganeti.http.client.HttpClientPool"""

  def test(self):
    # A newly created pool starts with an empty _pool
    self.assertFalse(http.client.HttpClientPool(None)._pool)
396
397
if __name__ == "__main__":
  # Script entry point: hand control to testutils.GanetiTestProgram
  testutils.GanetiTestProgram()