Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.http_unittest.py @ 2287b920

History | View | Annotate | Download (14.4 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2007, 2008 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the http module"""
23

    
24

    
25
import os
26
import unittest
27
import time
28
import tempfile
29
from cStringIO import StringIO
30

    
31
from ganeti import http
32

    
33
import ganeti.http.server
34
import ganeti.http.client
35
import ganeti.http.auth
36

    
37
import testutils
38

    
39

    
40
class TestStartLines(unittest.TestCase):
41
  """Test cases for start line classes"""
42

    
43
  def testClientToServerStartLine(self):
44
    """Test client to server start line (HTTP request)"""
45
    start_line = http.HttpClientToServerStartLine("GET", "/", "HTTP/1.1")
46
    self.assertEqual(str(start_line), "GET / HTTP/1.1")
47

    
48
  def testServerToClientStartLine(self):
49
    """Test server to client start line (HTTP response)"""
50
    start_line = http.HttpServerToClientStartLine("HTTP/1.1", 200, "OK")
51
    self.assertEqual(str(start_line), "HTTP/1.1 200 OK")
52

    
53

    
54
class TestMisc(unittest.TestCase):
55
  """Miscellaneous tests"""
56

    
57
  def _TestDateTimeHeader(self, gmnow, expected):
58
    self.assertEqual(http.server._DateTimeHeader(gmnow=gmnow), expected)
59

    
60
  def testDateTimeHeader(self):
61
    """Test ganeti.http._DateTimeHeader"""
62
    self._TestDateTimeHeader((2008, 1, 2, 3, 4, 5, 3, 0, 0),
63
                             "Thu, 02 Jan 2008 03:04:05 GMT")
64
    self._TestDateTimeHeader((2008, 1, 1, 0, 0, 0, 0, 0, 0),
65
                             "Mon, 01 Jan 2008 00:00:00 GMT")
66
    self._TestDateTimeHeader((2008, 12, 31, 0, 0, 0, 0, 0, 0),
67
                             "Mon, 31 Dec 2008 00:00:00 GMT")
68
    self._TestDateTimeHeader((2008, 12, 31, 23, 59, 59, 0, 0, 0),
69
                             "Mon, 31 Dec 2008 23:59:59 GMT")
70
    self._TestDateTimeHeader((2008, 12, 31, 0, 0, 0, 6, 0, 0),
71
                             "Sun, 31 Dec 2008 00:00:00 GMT")
72

    
73
  def testHttpServerRequest(self):
74
    """Test ganeti.http.server._HttpServerRequest"""
75
    server_request = http.server._HttpServerRequest("GET", "/", None, None)
76

    
77
    # These are expected by users of the HTTP server
78
    self.assert_(hasattr(server_request, "request_method"))
79
    self.assert_(hasattr(server_request, "request_path"))
80
    self.assert_(hasattr(server_request, "request_headers"))
81
    self.assert_(hasattr(server_request, "request_body"))
82
    self.assert_(isinstance(server_request.resp_headers, dict))
83
    self.assert_(hasattr(server_request, "private"))
84

    
85
  def testServerSizeLimits(self):
86
    """Test HTTP server size limits"""
87
    message_reader_class = http.server._HttpClientToServerMessageReader
88
    self.assert_(message_reader_class.START_LINE_LENGTH_MAX > 0)
89
    self.assert_(message_reader_class.HEADER_LENGTH_MAX > 0)
90

    
91
  def testFormatAuthHeader(self):
92
    self.assertEqual(http.auth._FormatAuthHeader("Basic", {}),
93
                     "Basic")
94
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "bar", }),
95
                     "Basic foo=bar")
96
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "", }),
97
                     "Basic foo=\"\"")
98
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "x,y", }),
99
                     "Basic foo=\"x,y\"")
100
    params = {
101
      "foo": "x,y",
102
      "realm": "secure",
103
      }
104
    # It's a dict whose order isn't guaranteed, hence checking a list
105
    self.assert_(http.auth._FormatAuthHeader("Digest", params) in
106
                 ("Digest foo=\"x,y\" realm=secure",
107
                  "Digest realm=secure foo=\"x,y\""))
108

    
109

    
110
class _FakeRequestAuth(http.auth.HttpServerRequestAuthentication):
111
  def __init__(self, realm, authreq, authenticate_fn):
112
    http.auth.HttpServerRequestAuthentication.__init__(self)
113

    
114
    self.realm = realm
115
    self.authreq = authreq
116
    self.authenticate_fn = authenticate_fn
117

    
118
  def AuthenticationRequired(self, req):
119
    return self.authreq
120

    
121
  def GetAuthRealm(self, req):
122
    return self.realm
123

    
124
  def Authenticate(self, *args):
125
    if self.authenticate_fn:
126
      return self.authenticate_fn(*args)
127
    raise NotImplementedError()
128

    
129

    
130
class TestAuth(unittest.TestCase):
131
  """Authentication tests"""
132

    
133
  hsra = http.auth.HttpServerRequestAuthentication
134

    
135
  def testConstants(self):
136
    for scheme in [self.hsra._CLEARTEXT_SCHEME, self.hsra._HA1_SCHEME]:
137
      self.assertEqual(scheme, scheme.upper())
138
      self.assert_(scheme.startswith("{"))
139
      self.assert_(scheme.endswith("}"))
140

    
141
  def _testVerifyBasicAuthPassword(self, realm, user, password, expected):
142
    ra = _FakeRequestAuth(realm, False, None)
143

    
144
    return ra.VerifyBasicAuthPassword(None, user, password, expected)
145

    
146
  def testVerifyBasicAuthPassword(self):
147
    tvbap = self._testVerifyBasicAuthPassword
148

    
149
    good_pws = ["pw", "pw{", "pw}", "pw{}", "pw{x}y", "}pw",
150
                "0", "123", "foo...:xyz", "TeST"]
151

    
152
    for pw in good_pws:
153
      # Try cleartext passwords
154
      self.assert_(tvbap("abc", "user", pw, pw))
155
      self.assert_(tvbap("abc", "user", pw, "{cleartext}" + pw))
156
      self.assert_(tvbap("abc", "user", pw, "{ClearText}" + pw))
157
      self.assert_(tvbap("abc", "user", pw, "{CLEARTEXT}" + pw))
158

    
159
      # Try with invalid password
160
      self.failIf(tvbap("abc", "user", pw, "something"))
161

    
162
      # Try with invalid scheme
163
      self.failIf(tvbap("abc", "user", pw, "{000}" + pw))
164
      self.failIf(tvbap("abc", "user", pw, "{unk}" + pw))
165
      self.failIf(tvbap("abc", "user", pw, "{Unk}" + pw))
166
      self.failIf(tvbap("abc", "user", pw, "{UNK}" + pw))
167

    
168
    # Try with invalid scheme format
169
    self.failIf(tvbap("abc", "user", "pw", "{something"))
170

    
171
    # Hash is MD5("user:This is only a test:pw")
172
    self.assert_(tvbap("This is only a test", "user", "pw",
173
                       "{ha1}92ea58ae804481498c257b2f65561a17"))
174
    self.assert_(tvbap("This is only a test", "user", "pw",
175
                       "{HA1}92ea58ae804481498c257b2f65561a17"))
176

    
177
    self.failUnlessRaises(AssertionError, tvbap, None, "user", "pw",
178
                          "{HA1}92ea58ae804481498c257b2f65561a17")
179
    self.failIf(tvbap("Admin area", "user", "pw",
180
                      "{HA1}92ea58ae804481498c257b2f65561a17"))
181
    self.failIf(tvbap("This is only a test", "someone", "pw",
182
                      "{HA1}92ea58ae804481498c257b2f65561a17"))
183
    self.failIf(tvbap("This is only a test", "user", "something",
184
                      "{HA1}92ea58ae804481498c257b2f65561a17"))
185

    
186

    
187
class _SimpleAuthenticator:
188
  def __init__(self, user, password):
189
    self.user = user
190
    self.password = password
191
    self.called = False
192

    
193
  def __call__(self, req, user, password):
194
    self.called = True
195
    return self.user == user and self.password == password
196

    
197

    
198
class TestHttpServerRequestAuthentication(unittest.TestCase):
199
  def testNoAuth(self):
200
    req = http.server._HttpServerRequest("GET", "/", None, None)
201
    _FakeRequestAuth("area1", False, None).PreHandleRequest(req)
202

    
203
  def testNoRealm(self):
204
    headers = { http.HTTP_AUTHORIZATION: "", }
205
    req = http.server._HttpServerRequest("GET", "/", headers, None)
206
    ra = _FakeRequestAuth(None, False, None)
207
    self.assertRaises(AssertionError, ra.PreHandleRequest, req)
208

    
209
  def testNoScheme(self):
210
    headers = { http.HTTP_AUTHORIZATION: "", }
211
    req = http.server._HttpServerRequest("GET", "/", headers, None)
212
    ra = _FakeRequestAuth("area1", False, None)
213
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
214

    
215
  def testUnknownScheme(self):
216
    headers = { http.HTTP_AUTHORIZATION: "NewStyleAuth abc", }
217
    req = http.server._HttpServerRequest("GET", "/", headers, None)
218
    ra = _FakeRequestAuth("area1", False, None)
219
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
220

    
221
  def testInvalidBase64(self):
222
    headers = { http.HTTP_AUTHORIZATION: "Basic x_=_", }
223
    req = http.server._HttpServerRequest("GET", "/", headers, None)
224
    ra = _FakeRequestAuth("area1", False, None)
225
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
226

    
227
  def testAuthForPublicResource(self):
228
    headers = {
229
      http.HTTP_AUTHORIZATION: "Basic %s" % ("foo".encode("base64").strip(), ),
230
      }
231
    req = http.server._HttpServerRequest("GET", "/", headers, None)
232
    ra = _FakeRequestAuth("area1", False, None)
233
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
234

    
235
  def testAuthForPublicResource(self):
236
    headers = {
237
      http.HTTP_AUTHORIZATION:
238
        "Basic %s" % ("foo:bar".encode("base64").strip(), ),
239
      }
240
    req = http.server._HttpServerRequest("GET", "/", headers, None)
241
    ac = _SimpleAuthenticator("foo", "bar")
242
    ra = _FakeRequestAuth("area1", False, ac)
243
    ra.PreHandleRequest(req)
244

    
245
    req = http.server._HttpServerRequest("GET", "/", headers, None)
246
    ac = _SimpleAuthenticator("something", "else")
247
    ra = _FakeRequestAuth("area1", False, ac)
248
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
249

    
250
  def testInvalidRequestHeader(self):
251
    checks = {
252
      http.HttpUnauthorized: ["", "\t", "-", ".", "@", "<", ">", "Digest",
253
                              "basic %s" % "foobar".encode("base64").strip()],
254
      http.HttpBadRequest: ["Basic"],
255
      }
256

    
257
    for exc, headers in checks.items():
258
      for i in headers:
259
        headers = { http.HTTP_AUTHORIZATION: i, }
260
        req = http.server._HttpServerRequest("GET", "/", headers, None)
261
        ra = _FakeRequestAuth("area1", False, None)
262
        self.assertRaises(exc, ra.PreHandleRequest, req)
263

    
264
  def testBasicAuth(self):
265
    for user in ["", "joe", "user name with spaces"]:
266
      for pw in ["", "-", ":", "foobar", "Foo Bar Baz", "@@@", "###",
267
                 "foo:bar:baz"]:
268
        for wrong_pw in [True, False]:
269
          basic_auth = "%s:%s" % (user, pw)
270
          if wrong_pw:
271
            basic_auth += "WRONG"
272
          headers = {
273
              http.HTTP_AUTHORIZATION:
274
                "Basic %s" % (basic_auth.encode("base64").strip(), ),
275
            }
276
          req = http.server._HttpServerRequest("GET", "/", headers, None)
277

    
278
          ac = _SimpleAuthenticator(user, pw)
279
          self.assertFalse(ac.called)
280
          ra = _FakeRequestAuth("area1", True, ac)
281
          if wrong_pw:
282
            try:
283
              ra.PreHandleRequest(req)
284
            except http.HttpUnauthorized, err:
285
              www_auth = err.headers[http.HTTP_WWW_AUTHENTICATE]
286
              self.assert_(www_auth.startswith(http.auth.HTTP_BASIC_AUTH))
287
            else:
288
              self.fail("Didn't raise HttpUnauthorized")
289
          else:
290
            ra.PreHandleRequest(req)
291
          self.assert_(ac.called)
292

    
293

    
294
class TestReadPasswordFile(unittest.TestCase):
295
  def testSimple(self):
296
    users = http.auth.ParsePasswordFile("user1 password")
297
    self.assertEqual(len(users), 1)
298
    self.assertEqual(users["user1"].password, "password")
299
    self.assertEqual(len(users["user1"].options), 0)
300

    
301
  def testOptions(self):
302
    buf = StringIO()
303
    buf.write("# Passwords\n")
304
    buf.write("user1 password\n")
305
    buf.write("\n")
306
    buf.write("# Comment\n")
307
    buf.write("user2 pw write,read\n")
308
    buf.write("   \t# Another comment\n")
309
    buf.write("invalidline\n")
310

    
311
    users = http.auth.ParsePasswordFile(buf.getvalue())
312
    self.assertEqual(len(users), 2)
313
    self.assertEqual(users["user1"].password, "password")
314
    self.assertEqual(len(users["user1"].options), 0)
315

    
316
    self.assertEqual(users["user2"].password, "pw")
317
    self.assertEqual(users["user2"].options, ["write", "read"])
318

    
319

    
320
class TestClientRequest(unittest.TestCase):
321
  def testRepr(self):
322
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
323
                                       headers=[], post_data="Hello World")
324
    self.assert_(repr(cr).startswith("<"))
325

    
326
  def testNoHeaders(self):
327
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
328
                                       headers=None)
329
    self.assert_(isinstance(cr.headers, list))
330
    self.assertEqual(cr.headers, [])
331
    self.assertEqual(cr.url, "https://localhost:1234/version")
332

    
333
  def testOldStyleHeaders(self):
334
    headers = {
335
      "Content-type": "text/plain",
336
      "Accept": "text/html",
337
      }
338
    cr = http.client.HttpClientRequest("localhost", 16481, "GET", "/vg_list",
339
                                       headers=headers)
340
    self.assert_(isinstance(cr.headers, list))
341
    self.assertEqual(sorted(cr.headers), [
342
      "Accept: text/html",
343
      "Content-type: text/plain",
344
      ])
345
    self.assertEqual(cr.url, "https://localhost:16481/vg_list")
346

    
347
  def testNewStyleHeaders(self):
348
    headers = [
349
      "Accept: text/html",
350
      "Content-type: text/plain; charset=ascii",
351
      "Server: httpd 1.0",
352
      ]
353
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
354
                                       headers=headers)
355
    self.assert_(isinstance(cr.headers, list))
356
    self.assertEqual(sorted(cr.headers), sorted(headers))
357
    self.assertEqual(cr.url, "https://localhost:1234/version")
358

    
359
  def testPostData(self):
360
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
361
                                       post_data="Hello World")
362
    self.assertEqual(cr.post_data, "Hello World")
363

    
364
  def testNoPostData(self):
365
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
366
    self.assertEqual(cr.post_data, "")
367

    
368
  def testIdentity(self):
369
    # These should all use different connections, hence also have a different
370
    # identity
371
    cr1 = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
372
    cr2 = http.client.HttpClientRequest("localhost", 9999, "GET", "/version")
373
    cr3 = http.client.HttpClientRequest("node1", 1234, "GET", "/version")
374
    cr4 = http.client.HttpClientRequest("node1", 9999, "GET", "/version")
375

    
376
    self.assertEqual(len(set([cr1.identity, cr2.identity,
377
                              cr3.identity, cr4.identity])), 4)
378

    
379
    # But this one should have the same
380
    cr1vglist = http.client.HttpClientRequest("localhost", 1234,
381
                                              "GET", "/vg_list")
382
    self.assertEqual(cr1.identity, cr1vglist.identity)
383

    
384

    
385
class TestClient(unittest.TestCase):
386
  def test(self):
387
    pool = http.client.HttpClientPool(None)
388
    self.assertFalse(pool._pool)
389

    
390

    
391
if __name__ == '__main__':
392
  testutils.GanetiTestProgram()