Revision a8950eb7

b/lib/http/server.py
74 74
  """Data structure for HTTP request on server side.
75 75

  
76 76
  """
77
  def __init__(self, request_msg):
77
  def __init__(self, method, path, headers, body):
78 78
    # Request attributes
79
    self.request_method = request_msg.start_line.method
80
    self.request_path = request_msg.start_line.path
81
    self.request_headers = request_msg.headers
82
    self.request_body = request_msg.decoded_body
79
    self.request_method = method
80
    self.request_path = path
81
    self.request_headers = headers
82
    self.request_body = body
83 83

  
84 84
    # Response attributes
85 85
    self.resp_headers = {}
......
308 308
    """Calls the handler function for the current request.
309 309

  
310 310
    """
311
    handler_context = _HttpServerRequest(self.request_msg)
311
    handler_context = _HttpServerRequest(self.request_msg.start_line.method,
312
                                         self.request_msg.start_line.path,
313
                                         self.request_msg.headers,
314
                                         self.request_msg.decoded_body)
312 315

  
313 316
    try:
314 317
      try:
b/test/ganeti.http_unittest.py
25 25
import os
26 26
import unittest
27 27
import time
28
import tempfile
28 29

  
29 30
from ganeti import http
30 31

  
......
70 71

  
71 72
  def testHttpServerRequest(self):
72 73
    """Test ganeti.http.server._HttpServerRequest"""
73
    fake_request = http.HttpMessage()
74
    fake_request.start_line = \
75
      http.HttpClientToServerStartLine("GET", "/", "HTTP/1.1")
76
    server_request = http.server._HttpServerRequest(fake_request)
74
    server_request = http.server._HttpServerRequest("GET", "/", None, None)
77 75

  
78 76
    # These are expected by users of the HTTP server
79 77
    self.assert_(hasattr(server_request, "request_method"))
......
95 93
    self.assert_(message_reader_class.START_LINE_LENGTH_MAX > 0)
96 94
    self.assert_(message_reader_class.HEADER_LENGTH_MAX > 0)
97 95

  
96
  def testFormatAuthHeader(self):
97
    self.assertEqual(http.auth._FormatAuthHeader("Basic", {}),
98
                     "Basic")
99
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "bar", }),
100
                     "Basic foo=bar")
101
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "", }),
102
                     "Basic foo=\"\"")
103
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "x,y", }),
104
                     "Basic foo=\"x,y\"")
105
    params = {
106
      "foo": "x,y",
107
      "realm": "secure",
108
      }
109
    # It's a dict whose order isn't guaranteed, hence checking a list
110
    self.assert_(http.auth._FormatAuthHeader("Digest", params) in
111
                 ("Digest foo=\"x,y\" realm=secure",
112
                  "Digest realm=secure foo=\"x,y\""))
113

  
98 114

  
99 115
class _FakeRequestAuth(http.auth.HttpServerRequestAuthentication):
100
  def __init__(self, realm):
116
  def __init__(self, realm, authreq, authenticate_fn):
101 117
    http.auth.HttpServerRequestAuthentication.__init__(self)
102 118

  
103 119
    self.realm = realm
120
    self.authreq = authreq
121
    self.authenticate_fn = authenticate_fn
122

  
123
  def AuthenticationRequired(self, req):
124
    return self.authreq
104 125

  
105 126
  def GetAuthRealm(self, req):
106 127
    return self.realm
107 128

  
129
  def Authenticate(self, *args):
130
    if self.authenticate_fn:
131
      return self.authenticate_fn(*args)
132
    raise NotImplementedError()
133

  
108 134

  
109 135
class TestAuth(unittest.TestCase):
110 136
  """Authentication tests"""
......
112 138
  hsra = http.auth.HttpServerRequestAuthentication
113 139

  
114 140
  def testConstants(self):
115
    self.assertEqual(self.hsra._CLEARTEXT_SCHEME,
116
                     self.hsra._CLEARTEXT_SCHEME.upper())
117
    self.assertEqual(self.hsra._HA1_SCHEME,
118
                     self.hsra._HA1_SCHEME.upper())
141
    for scheme in [self.hsra._CLEARTEXT_SCHEME, self.hsra._HA1_SCHEME]:
142
      self.assertEqual(scheme, scheme.upper())
143
      self.assert_(scheme.startswith("{"))
144
      self.assert_(scheme.endswith("}"))
119 145

  
120 146
  def _testVerifyBasicAuthPassword(self, realm, user, password, expected):
121
    ra = _FakeRequestAuth(realm)
147
    ra = _FakeRequestAuth(realm, False, None)
122 148

  
123 149
    return ra.VerifyBasicAuthPassword(None, user, password, expected)
124 150

  
125

  
126 151
  def testVerifyBasicAuthPassword(self):
127 152
    tvbap = self._testVerifyBasicAuthPassword
128 153

  
......
164 189
                      "{HA1}92ea58ae804481498c257b2f65561a17"))
165 190

  
166 191

  
192
class _SimpleAuthenticator:
193
  def __init__(self, user, password):
194
    self.user = user
195
    self.password = password
196
    self.called = False
197

  
198
  def __call__(self, req, user, password):
199
    self.called = True
200
    return self.user == user and self.password == password
201

  
202

  
203
class TestHttpServerRequestAuthentication(unittest.TestCase):
204
  def testNoAuth(self):
205
    req = http.server._HttpServerRequest("GET", "/", None, None)
206
    _FakeRequestAuth("area1", False, None).PreHandleRequest(req)
207

  
208
  def testNoRealm(self):
209
    headers = { http.HTTP_AUTHORIZATION: "", }
210
    req = http.server._HttpServerRequest("GET", "/", headers, None)
211
    ra = _FakeRequestAuth(None, False, None)
212
    self.assertRaises(AssertionError, ra.PreHandleRequest, req)
213

  
214
  def testNoScheme(self):
215
    headers = { http.HTTP_AUTHORIZATION: "", }
216
    req = http.server._HttpServerRequest("GET", "/", headers, None)
217
    ra = _FakeRequestAuth("area1", False, None)
218
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
219

  
220
  def testUnknownScheme(self):
221
    headers = { http.HTTP_AUTHORIZATION: "NewStyleAuth abc", }
222
    req = http.server._HttpServerRequest("GET", "/", headers, None)
223
    ra = _FakeRequestAuth("area1", False, None)
224
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
225

  
226
  def testInvalidBase64(self):
227
    headers = { http.HTTP_AUTHORIZATION: "Basic x_=_", }
228
    req = http.server._HttpServerRequest("GET", "/", headers, None)
229
    ra = _FakeRequestAuth("area1", False, None)
230
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
231

  
232
  def testAuthForPublicResource(self):
233
    headers = {
234
      http.HTTP_AUTHORIZATION: "Basic %s" % ("foo".encode("base64").strip(), ),
235
      }
236
    req = http.server._HttpServerRequest("GET", "/", headers, None)
237
    ra = _FakeRequestAuth("area1", False, None)
238
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
239

  
240
  def testAuthForPublicResource(self):
241
    headers = {
242
      http.HTTP_AUTHORIZATION:
243
        "Basic %s" % ("foo:bar".encode("base64").strip(), ),
244
      }
245
    req = http.server._HttpServerRequest("GET", "/", headers, None)
246
    ac = _SimpleAuthenticator("foo", "bar")
247
    ra = _FakeRequestAuth("area1", False, ac)
248
    ra.PreHandleRequest(req)
249

  
250
    req = http.server._HttpServerRequest("GET", "/", headers, None)
251
    ac = _SimpleAuthenticator("something", "else")
252
    ra = _FakeRequestAuth("area1", False, ac)
253
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
254

  
255
  def testInvalidRequestHeader(self):
256
    checks = {
257
      http.HttpUnauthorized: ["", "\t", "-", ".", "@", "<", ">", "Digest",
258
                              "basic %s" % "foobar".encode("base64").strip()],
259
      http.HttpBadRequest: ["Basic"],
260
      }
261

  
262
    for exc, headers in checks.items():
263
      for i in headers:
264
        headers = { http.HTTP_AUTHORIZATION: i, }
265
        req = http.server._HttpServerRequest("GET", "/", headers, None)
266
        ra = _FakeRequestAuth("area1", False, None)
267
        self.assertRaises(exc, ra.PreHandleRequest, req)
268

  
269
  def testBasicAuth(self):
270
    for user in ["", "joe", "user name with spaces"]:
271
      for pw in ["", "-", ":", "foobar", "Foo Bar Baz", "@@@", "###",
272
                 "foo:bar:baz"]:
273
        for wrong_pw in [True, False]:
274
          basic_auth = "%s:%s" % (user, pw)
275
          if wrong_pw:
276
            basic_auth += "WRONG"
277
          headers = {
278
              http.HTTP_AUTHORIZATION:
279
                "Basic %s" % (basic_auth.encode("base64").strip(), ),
280
            }
281
          req = http.server._HttpServerRequest("GET", "/", headers, None)
282

  
283
          ac = _SimpleAuthenticator(user, pw)
284
          self.assertFalse(ac.called)
285
          ra = _FakeRequestAuth("area1", True, ac)
286
          if wrong_pw:
287
            try:
288
              ra.PreHandleRequest(req)
289
            except http.HttpUnauthorized, err:
290
              www_auth = err.headers[http.HTTP_WWW_AUTHENTICATE]
291
              self.assert_(www_auth.startswith(http.auth.HTTP_BASIC_AUTH))
292
            else:
293
              self.fail("Didn't raise HttpUnauthorized")
294
          else:
295
            ra.PreHandleRequest(req)
296
          self.assert_(ac.called)
297

  
298

  
299
class TestReadPasswordFile(testutils.GanetiTestCase):
300
  def setUp(self):
301
    testutils.GanetiTestCase.setUp(self)
302

  
303
    self.tmpfile = tempfile.NamedTemporaryFile()
304

  
305
  def testSimple(self):
306
    self.tmpfile.write("user1 password")
307
    self.tmpfile.flush()
308

  
309
    users = http.auth.ReadPasswordFile(self.tmpfile.name)
310
    self.assertEqual(len(users), 1)
311
    self.assertEqual(users["user1"].password, "password")
312
    self.assertEqual(len(users["user1"].options), 0)
313

  
314
  def testOptions(self):
315
    self.tmpfile.write("# Passwords\n")
316
    self.tmpfile.write("user1 password\n")
317
    self.tmpfile.write("\n")
318
    self.tmpfile.write("# Comment\n")
319
    self.tmpfile.write("user2 pw write,read\n")
320
    self.tmpfile.write("   \t# Another comment\n")
321
    self.tmpfile.write("invalidline\n")
322
    self.tmpfile.flush()
323

  
324
    users = http.auth.ReadPasswordFile(self.tmpfile.name)
325
    self.assertEqual(len(users), 2)
326
    self.assertEqual(users["user1"].password, "password")
327
    self.assertEqual(len(users["user1"].options), 0)
328

  
329
    self.assertEqual(users["user2"].password, "pw")
330
    self.assertEqual(users["user2"].options, ["write", "read"])
331

  
332

  
167 333
if __name__ == '__main__':
168 334
  testutils.GanetiTestProgram()

Also available in: Unified diff