Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.http_unittest.py @ 33231500

History | View | Annotate | Download (14.6 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2007, 2008 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the http module"""
23

    
24

    
25
import os
26
import unittest
27
import time
28
import tempfile
29

    
30
from ganeti import http
31

    
32
import ganeti.http.server
33
import ganeti.http.client
34
import ganeti.http.auth
35

    
36
import testutils
37

    
38

    
39
class TestStartLines(unittest.TestCase):
40
  """Test cases for start line classes"""
41

    
42
  def testClientToServerStartLine(self):
43
    """Test client to server start line (HTTP request)"""
44
    start_line = http.HttpClientToServerStartLine("GET", "/", "HTTP/1.1")
45
    self.assertEqual(str(start_line), "GET / HTTP/1.1")
46

    
47
  def testServerToClientStartLine(self):
48
    """Test server to client start line (HTTP response)"""
49
    start_line = http.HttpServerToClientStartLine("HTTP/1.1", 200, "OK")
50
    self.assertEqual(str(start_line), "HTTP/1.1 200 OK")
51

    
52

    
53
class TestMisc(unittest.TestCase):
54
  """Miscellaneous tests"""
55

    
56
  def _TestDateTimeHeader(self, gmnow, expected):
57
    self.assertEqual(http.server._DateTimeHeader(gmnow=gmnow), expected)
58

    
59
  def testDateTimeHeader(self):
60
    """Test ganeti.http._DateTimeHeader"""
61
    self._TestDateTimeHeader((2008, 1, 2, 3, 4, 5, 3, 0, 0),
62
                             "Thu, 02 Jan 2008 03:04:05 GMT")
63
    self._TestDateTimeHeader((2008, 1, 1, 0, 0, 0, 0, 0, 0),
64
                             "Mon, 01 Jan 2008 00:00:00 GMT")
65
    self._TestDateTimeHeader((2008, 12, 31, 0, 0, 0, 0, 0, 0),
66
                             "Mon, 31 Dec 2008 00:00:00 GMT")
67
    self._TestDateTimeHeader((2008, 12, 31, 23, 59, 59, 0, 0, 0),
68
                             "Mon, 31 Dec 2008 23:59:59 GMT")
69
    self._TestDateTimeHeader((2008, 12, 31, 0, 0, 0, 6, 0, 0),
70
                             "Sun, 31 Dec 2008 00:00:00 GMT")
71

    
72
  def testHttpServerRequest(self):
73
    """Test ganeti.http.server._HttpServerRequest"""
74
    server_request = http.server._HttpServerRequest("GET", "/", None, None)
75

    
76
    # These are expected by users of the HTTP server
77
    self.assert_(hasattr(server_request, "request_method"))
78
    self.assert_(hasattr(server_request, "request_path"))
79
    self.assert_(hasattr(server_request, "request_headers"))
80
    self.assert_(hasattr(server_request, "request_body"))
81
    self.assert_(isinstance(server_request.resp_headers, dict))
82
    self.assert_(hasattr(server_request, "private"))
83

    
84
  def testServerSizeLimits(self):
85
    """Test HTTP server size limits"""
86
    message_reader_class = http.server._HttpClientToServerMessageReader
87
    self.assert_(message_reader_class.START_LINE_LENGTH_MAX > 0)
88
    self.assert_(message_reader_class.HEADER_LENGTH_MAX > 0)
89

    
90
  def testFormatAuthHeader(self):
91
    self.assertEqual(http.auth._FormatAuthHeader("Basic", {}),
92
                     "Basic")
93
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "bar", }),
94
                     "Basic foo=bar")
95
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "", }),
96
                     "Basic foo=\"\"")
97
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "x,y", }),
98
                     "Basic foo=\"x,y\"")
99
    params = {
100
      "foo": "x,y",
101
      "realm": "secure",
102
      }
103
    # It's a dict whose order isn't guaranteed, hence checking a list
104
    self.assert_(http.auth._FormatAuthHeader("Digest", params) in
105
                 ("Digest foo=\"x,y\" realm=secure",
106
                  "Digest realm=secure foo=\"x,y\""))
107

    
108

    
109
class _FakeRequestAuth(http.auth.HttpServerRequestAuthentication):
110
  def __init__(self, realm, authreq, authenticate_fn):
111
    http.auth.HttpServerRequestAuthentication.__init__(self)
112

    
113
    self.realm = realm
114
    self.authreq = authreq
115
    self.authenticate_fn = authenticate_fn
116

    
117
  def AuthenticationRequired(self, req):
118
    return self.authreq
119

    
120
  def GetAuthRealm(self, req):
121
    return self.realm
122

    
123
  def Authenticate(self, *args):
124
    if self.authenticate_fn:
125
      return self.authenticate_fn(*args)
126
    raise NotImplementedError()
127

    
128

    
129
class TestAuth(unittest.TestCase):
130
  """Authentication tests"""
131

    
132
  hsra = http.auth.HttpServerRequestAuthentication
133

    
134
  def testConstants(self):
135
    for scheme in [self.hsra._CLEARTEXT_SCHEME, self.hsra._HA1_SCHEME]:
136
      self.assertEqual(scheme, scheme.upper())
137
      self.assert_(scheme.startswith("{"))
138
      self.assert_(scheme.endswith("}"))
139

    
140
  def _testVerifyBasicAuthPassword(self, realm, user, password, expected):
141
    ra = _FakeRequestAuth(realm, False, None)
142

    
143
    return ra.VerifyBasicAuthPassword(None, user, password, expected)
144

    
145
  def testVerifyBasicAuthPassword(self):
146
    tvbap = self._testVerifyBasicAuthPassword
147

    
148
    good_pws = ["pw", "pw{", "pw}", "pw{}", "pw{x}y", "}pw",
149
                "0", "123", "foo...:xyz", "TeST"]
150

    
151
    for pw in good_pws:
152
      # Try cleartext passwords
153
      self.assert_(tvbap("abc", "user", pw, pw))
154
      self.assert_(tvbap("abc", "user", pw, "{cleartext}" + pw))
155
      self.assert_(tvbap("abc", "user", pw, "{ClearText}" + pw))
156
      self.assert_(tvbap("abc", "user", pw, "{CLEARTEXT}" + pw))
157

    
158
      # Try with invalid password
159
      self.failIf(tvbap("abc", "user", pw, "something"))
160

    
161
      # Try with invalid scheme
162
      self.failIf(tvbap("abc", "user", pw, "{000}" + pw))
163
      self.failIf(tvbap("abc", "user", pw, "{unk}" + pw))
164
      self.failIf(tvbap("abc", "user", pw, "{Unk}" + pw))
165
      self.failIf(tvbap("abc", "user", pw, "{UNK}" + pw))
166

    
167
    # Try with invalid scheme format
168
    self.failIf(tvbap("abc", "user", "pw", "{something"))
169

    
170
    # Hash is MD5("user:This is only a test:pw")
171
    self.assert_(tvbap("This is only a test", "user", "pw",
172
                       "{ha1}92ea58ae804481498c257b2f65561a17"))
173
    self.assert_(tvbap("This is only a test", "user", "pw",
174
                       "{HA1}92ea58ae804481498c257b2f65561a17"))
175

    
176
    self.failUnlessRaises(AssertionError, tvbap, None, "user", "pw",
177
                          "{HA1}92ea58ae804481498c257b2f65561a17")
178
    self.failIf(tvbap("Admin area", "user", "pw",
179
                      "{HA1}92ea58ae804481498c257b2f65561a17"))
180
    self.failIf(tvbap("This is only a test", "someone", "pw",
181
                      "{HA1}92ea58ae804481498c257b2f65561a17"))
182
    self.failIf(tvbap("This is only a test", "user", "something",
183
                      "{HA1}92ea58ae804481498c257b2f65561a17"))
184

    
185

    
186
class _SimpleAuthenticator:
187
  def __init__(self, user, password):
188
    self.user = user
189
    self.password = password
190
    self.called = False
191

    
192
  def __call__(self, req, user, password):
193
    self.called = True
194
    return self.user == user and self.password == password
195

    
196

    
197
class TestHttpServerRequestAuthentication(unittest.TestCase):
198
  def testNoAuth(self):
199
    req = http.server._HttpServerRequest("GET", "/", None, None)
200
    _FakeRequestAuth("area1", False, None).PreHandleRequest(req)
201

    
202
  def testNoRealm(self):
203
    headers = { http.HTTP_AUTHORIZATION: "", }
204
    req = http.server._HttpServerRequest("GET", "/", headers, None)
205
    ra = _FakeRequestAuth(None, False, None)
206
    self.assertRaises(AssertionError, ra.PreHandleRequest, req)
207

    
208
  def testNoScheme(self):
209
    headers = { http.HTTP_AUTHORIZATION: "", }
210
    req = http.server._HttpServerRequest("GET", "/", headers, None)
211
    ra = _FakeRequestAuth("area1", False, None)
212
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
213

    
214
  def testUnknownScheme(self):
215
    headers = { http.HTTP_AUTHORIZATION: "NewStyleAuth abc", }
216
    req = http.server._HttpServerRequest("GET", "/", headers, None)
217
    ra = _FakeRequestAuth("area1", False, None)
218
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
219

    
220
  def testInvalidBase64(self):
221
    headers = { http.HTTP_AUTHORIZATION: "Basic x_=_", }
222
    req = http.server._HttpServerRequest("GET", "/", headers, None)
223
    ra = _FakeRequestAuth("area1", False, None)
224
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
225

    
226
  def testAuthForPublicResource(self):
227
    headers = {
228
      http.HTTP_AUTHORIZATION: "Basic %s" % ("foo".encode("base64").strip(), ),
229
      }
230
    req = http.server._HttpServerRequest("GET", "/", headers, None)
231
    ra = _FakeRequestAuth("area1", False, None)
232
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
233

    
234
  def testAuthForPublicResource(self):
235
    headers = {
236
      http.HTTP_AUTHORIZATION:
237
        "Basic %s" % ("foo:bar".encode("base64").strip(), ),
238
      }
239
    req = http.server._HttpServerRequest("GET", "/", headers, None)
240
    ac = _SimpleAuthenticator("foo", "bar")
241
    ra = _FakeRequestAuth("area1", False, ac)
242
    ra.PreHandleRequest(req)
243

    
244
    req = http.server._HttpServerRequest("GET", "/", headers, None)
245
    ac = _SimpleAuthenticator("something", "else")
246
    ra = _FakeRequestAuth("area1", False, ac)
247
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
248

    
249
  def testInvalidRequestHeader(self):
250
    checks = {
251
      http.HttpUnauthorized: ["", "\t", "-", ".", "@", "<", ">", "Digest",
252
                              "basic %s" % "foobar".encode("base64").strip()],
253
      http.HttpBadRequest: ["Basic"],
254
      }
255

    
256
    for exc, headers in checks.items():
257
      for i in headers:
258
        headers = { http.HTTP_AUTHORIZATION: i, }
259
        req = http.server._HttpServerRequest("GET", "/", headers, None)
260
        ra = _FakeRequestAuth("area1", False, None)
261
        self.assertRaises(exc, ra.PreHandleRequest, req)
262

    
263
  def testBasicAuth(self):
264
    for user in ["", "joe", "user name with spaces"]:
265
      for pw in ["", "-", ":", "foobar", "Foo Bar Baz", "@@@", "###",
266
                 "foo:bar:baz"]:
267
        for wrong_pw in [True, False]:
268
          basic_auth = "%s:%s" % (user, pw)
269
          if wrong_pw:
270
            basic_auth += "WRONG"
271
          headers = {
272
              http.HTTP_AUTHORIZATION:
273
                "Basic %s" % (basic_auth.encode("base64").strip(), ),
274
            }
275
          req = http.server._HttpServerRequest("GET", "/", headers, None)
276

    
277
          ac = _SimpleAuthenticator(user, pw)
278
          self.assertFalse(ac.called)
279
          ra = _FakeRequestAuth("area1", True, ac)
280
          if wrong_pw:
281
            try:
282
              ra.PreHandleRequest(req)
283
            except http.HttpUnauthorized, err:
284
              www_auth = err.headers[http.HTTP_WWW_AUTHENTICATE]
285
              self.assert_(www_auth.startswith(http.auth.HTTP_BASIC_AUTH))
286
            else:
287
              self.fail("Didn't raise HttpUnauthorized")
288
          else:
289
            ra.PreHandleRequest(req)
290
          self.assert_(ac.called)
291

    
292

    
293
class TestReadPasswordFile(testutils.GanetiTestCase):
294
  def setUp(self):
295
    testutils.GanetiTestCase.setUp(self)
296

    
297
    self.tmpfile = tempfile.NamedTemporaryFile()
298

    
299
  def testSimple(self):
300
    self.tmpfile.write("user1 password")
301
    self.tmpfile.flush()
302

    
303
    users = http.auth.ReadPasswordFile(self.tmpfile.name)
304
    self.assertEqual(len(users), 1)
305
    self.assertEqual(users["user1"].password, "password")
306
    self.assertEqual(len(users["user1"].options), 0)
307

    
308
  def testOptions(self):
309
    self.tmpfile.write("# Passwords\n")
310
    self.tmpfile.write("user1 password\n")
311
    self.tmpfile.write("\n")
312
    self.tmpfile.write("# Comment\n")
313
    self.tmpfile.write("user2 pw write,read\n")
314
    self.tmpfile.write("   \t# Another comment\n")
315
    self.tmpfile.write("invalidline\n")
316
    self.tmpfile.flush()
317

    
318
    users = http.auth.ReadPasswordFile(self.tmpfile.name)
319
    self.assertEqual(len(users), 2)
320
    self.assertEqual(users["user1"].password, "password")
321
    self.assertEqual(len(users["user1"].options), 0)
322

    
323
    self.assertEqual(users["user2"].password, "pw")
324
    self.assertEqual(users["user2"].options, ["write", "read"])
325

    
326

    
327
class TestClientRequest(unittest.TestCase):
328
  def testRepr(self):
329
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
330
                                       headers=[], post_data="Hello World")
331
    self.assert_(repr(cr).startswith("<"))
332

    
333
  def testNoHeaders(self):
334
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
335
                                       headers=None)
336
    self.assert_(isinstance(cr.headers, list))
337
    self.assertEqual(cr.headers, [])
338
    self.assertEqual(cr.url, "https://localhost:1234/version")
339

    
340
  def testOldStyleHeaders(self):
341
    headers = {
342
      "Content-type": "text/plain",
343
      "Accept": "text/html",
344
      }
345
    cr = http.client.HttpClientRequest("localhost", 16481, "GET", "/vg_list",
346
                                       headers=headers)
347
    self.assert_(isinstance(cr.headers, list))
348
    self.assertEqual(sorted(cr.headers), [
349
      "Accept: text/html",
350
      "Content-type: text/plain",
351
      ])
352
    self.assertEqual(cr.url, "https://localhost:16481/vg_list")
353

    
354
  def testNewStyleHeaders(self):
355
    headers = [
356
      "Accept: text/html",
357
      "Content-type: text/plain; charset=ascii",
358
      "Server: httpd 1.0",
359
      ]
360
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
361
                                       headers=headers)
362
    self.assert_(isinstance(cr.headers, list))
363
    self.assertEqual(sorted(cr.headers), sorted(headers))
364
    self.assertEqual(cr.url, "https://localhost:1234/version")
365

    
366
  def testPostData(self):
367
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
368
                                       post_data="Hello World")
369
    self.assertEqual(cr.post_data, "Hello World")
370

    
371
  def testNoPostData(self):
372
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
373
    self.assertEqual(cr.post_data, "")
374

    
375
  def testIdentity(self):
376
    # These should all use different connections, hence also have a different
377
    # identity
378
    cr1 = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
379
    cr2 = http.client.HttpClientRequest("localhost", 9999, "GET", "/version")
380
    cr3 = http.client.HttpClientRequest("node1", 1234, "GET", "/version")
381
    cr4 = http.client.HttpClientRequest("node1", 9999, "GET", "/version")
382

    
383
    self.assertEqual(len(set([cr1.identity, cr2.identity,
384
                              cr3.identity, cr4.identity])), 4)
385

    
386
    # But this one should have the same
387
    cr1vglist = http.client.HttpClientRequest("localhost", 1234,
388
                                              "GET", "/vg_list")
389
    self.assertEqual(cr1.identity, cr1vglist.identity)
390

    
391

    
392
class TestClient(unittest.TestCase):
393
  def test(self):
394
    pool = http.client.HttpClientPool(None)
395
    self.assertFalse(pool._pool)
396

    
397

    
398
if __name__ == '__main__':
399
  testutils.GanetiTestProgram()