Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.http_unittest.py @ 7352d33b

History | View | Annotate | Download (26.2 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2007, 2008 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the http module"""
23

    
24

    
25
import os
26
import unittest
27
import time
28
import tempfile
29
import pycurl
30
import itertools
31
import threading
32
from cStringIO import StringIO
33

    
34
from ganeti import http
35
from ganeti import compat
36

    
37
import ganeti.http.server
38
import ganeti.http.client
39
import ganeti.http.auth
40

    
41
import testutils
42

    
43

    
44
class TestStartLines(unittest.TestCase):
45
  """Test cases for start line classes"""
46

    
47
  def testClientToServerStartLine(self):
48
    """Test client to server start line (HTTP request)"""
49
    start_line = http.HttpClientToServerStartLine("GET", "/", "HTTP/1.1")
50
    self.assertEqual(str(start_line), "GET / HTTP/1.1")
51

    
52
  def testServerToClientStartLine(self):
53
    """Test server to client start line (HTTP response)"""
54
    start_line = http.HttpServerToClientStartLine("HTTP/1.1", 200, "OK")
55
    self.assertEqual(str(start_line), "HTTP/1.1 200 OK")
56

    
57

    
58
class TestMisc(unittest.TestCase):
59
  """Miscellaneous tests"""
60

    
61
  def _TestDateTimeHeader(self, gmnow, expected):
62
    self.assertEqual(http.server._DateTimeHeader(gmnow=gmnow), expected)
63

    
64
  def testDateTimeHeader(self):
65
    """Test ganeti.http._DateTimeHeader"""
66
    self._TestDateTimeHeader((2008, 1, 2, 3, 4, 5, 3, 0, 0),
67
                             "Thu, 02 Jan 2008 03:04:05 GMT")
68
    self._TestDateTimeHeader((2008, 1, 1, 0, 0, 0, 0, 0, 0),
69
                             "Mon, 01 Jan 2008 00:00:00 GMT")
70
    self._TestDateTimeHeader((2008, 12, 31, 0, 0, 0, 0, 0, 0),
71
                             "Mon, 31 Dec 2008 00:00:00 GMT")
72
    self._TestDateTimeHeader((2008, 12, 31, 23, 59, 59, 0, 0, 0),
73
                             "Mon, 31 Dec 2008 23:59:59 GMT")
74
    self._TestDateTimeHeader((2008, 12, 31, 0, 0, 0, 6, 0, 0),
75
                             "Sun, 31 Dec 2008 00:00:00 GMT")
76

    
77
  def testHttpServerRequest(self):
78
    """Test ganeti.http.server._HttpServerRequest"""
79
    server_request = http.server._HttpServerRequest("GET", "/", None, None)
80

    
81
    # These are expected by users of the HTTP server
82
    self.assert_(hasattr(server_request, "request_method"))
83
    self.assert_(hasattr(server_request, "request_path"))
84
    self.assert_(hasattr(server_request, "request_headers"))
85
    self.assert_(hasattr(server_request, "request_body"))
86
    self.assert_(isinstance(server_request.resp_headers, dict))
87
    self.assert_(hasattr(server_request, "private"))
88

    
89
  def testServerSizeLimits(self):
90
    """Test HTTP server size limits"""
91
    message_reader_class = http.server._HttpClientToServerMessageReader
92
    self.assert_(message_reader_class.START_LINE_LENGTH_MAX > 0)
93
    self.assert_(message_reader_class.HEADER_LENGTH_MAX > 0)
94

    
95
  def testFormatAuthHeader(self):
96
    self.assertEqual(http.auth._FormatAuthHeader("Basic", {}),
97
                     "Basic")
98
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "bar", }),
99
                     "Basic foo=bar")
100
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "", }),
101
                     "Basic foo=\"\"")
102
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "x,y", }),
103
                     "Basic foo=\"x,y\"")
104
    params = {
105
      "foo": "x,y",
106
      "realm": "secure",
107
      }
108
    # It's a dict whose order isn't guaranteed, hence checking a list
109
    self.assert_(http.auth._FormatAuthHeader("Digest", params) in
110
                 ("Digest foo=\"x,y\" realm=secure",
111
                  "Digest realm=secure foo=\"x,y\""))
112

    
113

    
114
class _FakeRequestAuth(http.auth.HttpServerRequestAuthentication):
115
  def __init__(self, realm, authreq, authenticate_fn):
116
    http.auth.HttpServerRequestAuthentication.__init__(self)
117

    
118
    self.realm = realm
119
    self.authreq = authreq
120
    self.authenticate_fn = authenticate_fn
121

    
122
  def AuthenticationRequired(self, req):
123
    return self.authreq
124

    
125
  def GetAuthRealm(self, req):
126
    return self.realm
127

    
128
  def Authenticate(self, *args):
129
    if self.authenticate_fn:
130
      return self.authenticate_fn(*args)
131
    raise NotImplementedError()
132

    
133

    
134
class TestAuth(unittest.TestCase):
135
  """Authentication tests"""
136

    
137
  hsra = http.auth.HttpServerRequestAuthentication
138

    
139
  def testConstants(self):
140
    for scheme in [self.hsra._CLEARTEXT_SCHEME, self.hsra._HA1_SCHEME]:
141
      self.assertEqual(scheme, scheme.upper())
142
      self.assert_(scheme.startswith("{"))
143
      self.assert_(scheme.endswith("}"))
144

    
145
  def _testVerifyBasicAuthPassword(self, realm, user, password, expected):
146
    ra = _FakeRequestAuth(realm, False, None)
147

    
148
    return ra.VerifyBasicAuthPassword(None, user, password, expected)
149

    
150
  def testVerifyBasicAuthPassword(self):
151
    tvbap = self._testVerifyBasicAuthPassword
152

    
153
    good_pws = ["pw", "pw{", "pw}", "pw{}", "pw{x}y", "}pw",
154
                "0", "123", "foo...:xyz", "TeST"]
155

    
156
    for pw in good_pws:
157
      # Try cleartext passwords
158
      self.assert_(tvbap("abc", "user", pw, pw))
159
      self.assert_(tvbap("abc", "user", pw, "{cleartext}" + pw))
160
      self.assert_(tvbap("abc", "user", pw, "{ClearText}" + pw))
161
      self.assert_(tvbap("abc", "user", pw, "{CLEARTEXT}" + pw))
162

    
163
      # Try with invalid password
164
      self.failIf(tvbap("abc", "user", pw, "something"))
165

    
166
      # Try with invalid scheme
167
      self.failIf(tvbap("abc", "user", pw, "{000}" + pw))
168
      self.failIf(tvbap("abc", "user", pw, "{unk}" + pw))
169
      self.failIf(tvbap("abc", "user", pw, "{Unk}" + pw))
170
      self.failIf(tvbap("abc", "user", pw, "{UNK}" + pw))
171

    
172
    # Try with invalid scheme format
173
    self.failIf(tvbap("abc", "user", "pw", "{something"))
174

    
175
    # Hash is MD5("user:This is only a test:pw")
176
    self.assert_(tvbap("This is only a test", "user", "pw",
177
                       "{ha1}92ea58ae804481498c257b2f65561a17"))
178
    self.assert_(tvbap("This is only a test", "user", "pw",
179
                       "{HA1}92ea58ae804481498c257b2f65561a17"))
180

    
181
    self.failUnlessRaises(AssertionError, tvbap, None, "user", "pw",
182
                          "{HA1}92ea58ae804481498c257b2f65561a17")
183
    self.failIf(tvbap("Admin area", "user", "pw",
184
                      "{HA1}92ea58ae804481498c257b2f65561a17"))
185
    self.failIf(tvbap("This is only a test", "someone", "pw",
186
                      "{HA1}92ea58ae804481498c257b2f65561a17"))
187
    self.failIf(tvbap("This is only a test", "user", "something",
188
                      "{HA1}92ea58ae804481498c257b2f65561a17"))
189

    
190

    
191
class _SimpleAuthenticator:
192
  def __init__(self, user, password):
193
    self.user = user
194
    self.password = password
195
    self.called = False
196

    
197
  def __call__(self, req, user, password):
198
    self.called = True
199
    return self.user == user and self.password == password
200

    
201

    
202
class TestHttpServerRequestAuthentication(unittest.TestCase):
203
  def testNoAuth(self):
204
    req = http.server._HttpServerRequest("GET", "/", None, None)
205
    _FakeRequestAuth("area1", False, None).PreHandleRequest(req)
206

    
207
  def testNoRealm(self):
208
    headers = { http.HTTP_AUTHORIZATION: "", }
209
    req = http.server._HttpServerRequest("GET", "/", headers, None)
210
    ra = _FakeRequestAuth(None, False, None)
211
    self.assertRaises(AssertionError, ra.PreHandleRequest, req)
212

    
213
  def testNoScheme(self):
214
    headers = { http.HTTP_AUTHORIZATION: "", }
215
    req = http.server._HttpServerRequest("GET", "/", headers, None)
216
    ra = _FakeRequestAuth("area1", False, None)
217
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
218

    
219
  def testUnknownScheme(self):
220
    headers = { http.HTTP_AUTHORIZATION: "NewStyleAuth abc", }
221
    req = http.server._HttpServerRequest("GET", "/", headers, None)
222
    ra = _FakeRequestAuth("area1", False, None)
223
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
224

    
225
  def testInvalidBase64(self):
226
    headers = { http.HTTP_AUTHORIZATION: "Basic x_=_", }
227
    req = http.server._HttpServerRequest("GET", "/", headers, None)
228
    ra = _FakeRequestAuth("area1", False, None)
229
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
230

    
231
  def testAuthForPublicResource(self):
232
    headers = {
233
      http.HTTP_AUTHORIZATION: "Basic %s" % ("foo".encode("base64").strip(), ),
234
      }
235
    req = http.server._HttpServerRequest("GET", "/", headers, None)
236
    ra = _FakeRequestAuth("area1", False, None)
237
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
238

    
239
  def testAuthForPublicResource(self):
240
    headers = {
241
      http.HTTP_AUTHORIZATION:
242
        "Basic %s" % ("foo:bar".encode("base64").strip(), ),
243
      }
244
    req = http.server._HttpServerRequest("GET", "/", headers, None)
245
    ac = _SimpleAuthenticator("foo", "bar")
246
    ra = _FakeRequestAuth("area1", False, ac)
247
    ra.PreHandleRequest(req)
248

    
249
    req = http.server._HttpServerRequest("GET", "/", headers, None)
250
    ac = _SimpleAuthenticator("something", "else")
251
    ra = _FakeRequestAuth("area1", False, ac)
252
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
253

    
254
  def testInvalidRequestHeader(self):
255
    checks = {
256
      http.HttpUnauthorized: ["", "\t", "-", ".", "@", "<", ">", "Digest",
257
                              "basic %s" % "foobar".encode("base64").strip()],
258
      http.HttpBadRequest: ["Basic"],
259
      }
260

    
261
    for exc, headers in checks.items():
262
      for i in headers:
263
        headers = { http.HTTP_AUTHORIZATION: i, }
264
        req = http.server._HttpServerRequest("GET", "/", headers, None)
265
        ra = _FakeRequestAuth("area1", False, None)
266
        self.assertRaises(exc, ra.PreHandleRequest, req)
267

    
268
  def testBasicAuth(self):
269
    for user in ["", "joe", "user name with spaces"]:
270
      for pw in ["", "-", ":", "foobar", "Foo Bar Baz", "@@@", "###",
271
                 "foo:bar:baz"]:
272
        for wrong_pw in [True, False]:
273
          basic_auth = "%s:%s" % (user, pw)
274
          if wrong_pw:
275
            basic_auth += "WRONG"
276
          headers = {
277
              http.HTTP_AUTHORIZATION:
278
                "Basic %s" % (basic_auth.encode("base64").strip(), ),
279
            }
280
          req = http.server._HttpServerRequest("GET", "/", headers, None)
281

    
282
          ac = _SimpleAuthenticator(user, pw)
283
          self.assertFalse(ac.called)
284
          ra = _FakeRequestAuth("area1", True, ac)
285
          if wrong_pw:
286
            try:
287
              ra.PreHandleRequest(req)
288
            except http.HttpUnauthorized, err:
289
              www_auth = err.headers[http.HTTP_WWW_AUTHENTICATE]
290
              self.assert_(www_auth.startswith(http.auth.HTTP_BASIC_AUTH))
291
            else:
292
              self.fail("Didn't raise HttpUnauthorized")
293
          else:
294
            ra.PreHandleRequest(req)
295
          self.assert_(ac.called)
296

    
297

    
298
class TestReadPasswordFile(unittest.TestCase):
299
  def testSimple(self):
300
    users = http.auth.ParsePasswordFile("user1 password")
301
    self.assertEqual(len(users), 1)
302
    self.assertEqual(users["user1"].password, "password")
303
    self.assertEqual(len(users["user1"].options), 0)
304

    
305
  def testOptions(self):
306
    buf = StringIO()
307
    buf.write("# Passwords\n")
308
    buf.write("user1 password\n")
309
    buf.write("\n")
310
    buf.write("# Comment\n")
311
    buf.write("user2 pw write,read\n")
312
    buf.write("   \t# Another comment\n")
313
    buf.write("invalidline\n")
314

    
315
    users = http.auth.ParsePasswordFile(buf.getvalue())
316
    self.assertEqual(len(users), 2)
317
    self.assertEqual(users["user1"].password, "password")
318
    self.assertEqual(len(users["user1"].options), 0)
319

    
320
    self.assertEqual(users["user2"].password, "pw")
321
    self.assertEqual(users["user2"].options, ["write", "read"])
322

    
323

    
324
class TestClientRequest(unittest.TestCase):
325
  def testRepr(self):
326
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
327
                                       headers=[], post_data="Hello World")
328
    self.assert_(repr(cr).startswith("<"))
329

    
330
  def testNoHeaders(self):
331
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
332
                                       headers=None)
333
    self.assert_(isinstance(cr.headers, list))
334
    self.assertEqual(cr.headers, [])
335
    self.assertEqual(cr.url, "https://localhost:1234/version")
336

    
337
  def testPlainAddressIPv4(self):
338
    cr = http.client.HttpClientRequest("192.0.2.9", 19956, "GET", "/version")
339
    self.assertEqual(cr.url, "https://192.0.2.9:19956/version")
340

    
341
  def testPlainAddressIPv6(self):
342
    cr = http.client.HttpClientRequest("2001:db8::cafe", 15110, "GET", "/info")
343
    self.assertEqual(cr.url, "https://[2001:db8::cafe]:15110/info")
344

    
345
  def testOldStyleHeaders(self):
346
    headers = {
347
      "Content-type": "text/plain",
348
      "Accept": "text/html",
349
      }
350
    cr = http.client.HttpClientRequest("localhost", 16481, "GET", "/vg_list",
351
                                       headers=headers)
352
    self.assert_(isinstance(cr.headers, list))
353
    self.assertEqual(sorted(cr.headers), [
354
      "Accept: text/html",
355
      "Content-type: text/plain",
356
      ])
357
    self.assertEqual(cr.url, "https://localhost:16481/vg_list")
358

    
359
  def testNewStyleHeaders(self):
360
    headers = [
361
      "Accept: text/html",
362
      "Content-type: text/plain; charset=ascii",
363
      "Server: httpd 1.0",
364
      ]
365
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
366
                                       headers=headers)
367
    self.assert_(isinstance(cr.headers, list))
368
    self.assertEqual(sorted(cr.headers), sorted(headers))
369
    self.assertEqual(cr.url, "https://localhost:1234/version")
370

    
371
  def testPostData(self):
372
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
373
                                       post_data="Hello World")
374
    self.assertEqual(cr.post_data, "Hello World")
375

    
376
  def testNoPostData(self):
377
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
378
    self.assertEqual(cr.post_data, "")
379

    
380
  def testCompletionCallback(self):
381
    for argname in ["completion_cb", "curl_config_fn"]:
382
      kwargs = {
383
        argname: NotImplementedError,
384
        }
385
      cr = http.client.HttpClientRequest("localhost", 14038, "GET", "/version",
386
                                         **kwargs)
387
      self.assertEqual(getattr(cr, argname), NotImplementedError)
388

    
389
      for fn in [NotImplemented, {}, 1]:
390
        kwargs = {
391
          argname: fn,
392
          }
393
        self.assertRaises(AssertionError, http.client.HttpClientRequest,
394
                          "localhost", 23150, "GET", "/version", **kwargs)
395

    
396

    
397
class _FakeCurl:
398
  def __init__(self):
399
    self.opts = {}
400
    self.info = NotImplemented
401

    
402
  def setopt(self, opt, value):
403
    assert opt not in self.opts, "Option set more than once"
404
    self.opts[opt] = value
405

    
406
  def getinfo(self, info):
407
    return self.info.pop(info)
408

    
409

    
410
class TestClientStartRequest(unittest.TestCase):
411
  @staticmethod
412
  def _TestCurlConfig(curl):
413
    curl.setopt(pycurl.SSLKEYTYPE, "PEM")
414

    
415
  def test(self):
416
    for method in [http.HTTP_GET, http.HTTP_PUT, "CUSTOM"]:
417
      for port in [8761, 29796, 19528]:
418
        for curl_config_fn in [None, self._TestCurlConfig]:
419
          for read_timeout in [None, 0, 1, 123, 36000]:
420
            self._TestInner(method, port, curl_config_fn, read_timeout)
421

    
422
  def _TestInner(self, method, port, curl_config_fn, read_timeout):
423
    for response_code in [http.HTTP_OK, http.HttpNotFound.code,
424
                          http.HTTP_NOT_MODIFIED]:
425
      for response_body in [None, "Hello World",
426
                            "Very Long\tContent here\n" * 171]:
427
        for errmsg in [None, "error"]:
428
          req = http.client.HttpClientRequest("localhost", port, method,
429
                                              "/version",
430
                                              curl_config_fn=curl_config_fn,
431
                                              read_timeout=read_timeout)
432
          curl = _FakeCurl()
433
          pending = http.client._StartRequest(curl, req)
434
          self.assertEqual(pending.GetCurlHandle(), curl)
435
          self.assertEqual(pending.GetCurrentRequest(), req)
436

    
437
          # Check options
438
          opts = curl.opts
439
          self.assertEqual(opts.pop(pycurl.CUSTOMREQUEST), method)
440
          self.assertEqual(opts.pop(pycurl.URL),
441
                           "https://localhost:%s/version" % port)
442
          if read_timeout is None:
443
            self.assertEqual(opts.pop(pycurl.TIMEOUT), 0)
444
          else:
445
            self.assertEqual(opts.pop(pycurl.TIMEOUT), read_timeout)
446
          self.assertFalse(opts.pop(pycurl.VERBOSE))
447
          self.assertTrue(opts.pop(pycurl.NOSIGNAL))
448
          self.assertEqual(opts.pop(pycurl.USERAGENT),
449
                           http.HTTP_GANETI_VERSION)
450
          self.assertEqual(opts.pop(pycurl.PROXY), "")
451
          self.assertFalse(opts.pop(pycurl.POSTFIELDS))
452
          self.assertFalse(opts.pop(pycurl.HTTPHEADER))
453
          write_fn = opts.pop(pycurl.WRITEFUNCTION)
454
          self.assertTrue(callable(write_fn))
455
          if hasattr(pycurl, "SSL_SESSIONID_CACHE"):
456
            self.assertFalse(opts.pop(pycurl.SSL_SESSIONID_CACHE))
457
          if curl_config_fn:
458
            self.assertEqual(opts.pop(pycurl.SSLKEYTYPE), "PEM")
459
          else:
460
            self.assertFalse(pycurl.SSLKEYTYPE in opts)
461
          self.assertFalse(opts)
462

    
463
          if response_body is not None:
464
            offset = 0
465
            while offset < len(response_body):
466
              piece = response_body[offset:offset + 10]
467
              write_fn(piece)
468
              offset += len(piece)
469

    
470
          curl.info = {
471
            pycurl.RESPONSE_CODE: response_code,
472
            }
473

    
474
          # Finalize request
475
          pending.Done(errmsg)
476

    
477
          self.assertFalse(curl.info)
478

    
479
          # Can only finalize once
480
          self.assertRaises(AssertionError, pending.Done, True)
481

    
482
          if errmsg:
483
            self.assertFalse(req.success)
484
          else:
485
            self.assertTrue(req.success)
486
          self.assertEqual(req.error, errmsg)
487
          self.assertEqual(req.resp_status_code, response_code)
488
          if response_body is None:
489
            self.assertEqual(req.resp_body, "")
490
          else:
491
            self.assertEqual(req.resp_body, response_body)
492

    
493
          # Check if resetting worked
494
          assert not hasattr(curl, "reset")
495
          opts = curl.opts
496
          self.assertFalse(opts.pop(pycurl.POSTFIELDS))
497
          self.assertTrue(callable(opts.pop(pycurl.WRITEFUNCTION)))
498
          self.assertFalse(opts)
499

    
500
          self.assertFalse(curl.opts,
501
                           msg="Previous checks did not consume all options")
502
          assert id(opts) == id(curl.opts)
503

    
504
  def _TestWrongTypes(self, *args, **kwargs):
505
    req = http.client.HttpClientRequest(*args, **kwargs)
506
    self.assertRaises(AssertionError, http.client._StartRequest,
507
                      _FakeCurl(), req)
508

    
509
  def testWrongHostType(self):
510
    self._TestWrongTypes(unicode("localhost"), 8080, "GET", "/version")
511

    
512
  def testWrongUrlType(self):
513
    self._TestWrongTypes("localhost", 8080, "GET", unicode("/version"))
514

    
515
  def testWrongMethodType(self):
516
    self._TestWrongTypes("localhost", 8080, unicode("GET"), "/version")
517

    
518
  def testWrongHeaderType(self):
519
    self._TestWrongTypes("localhost", 8080, "GET", "/version",
520
                         headers={
521
                           unicode("foo"): "bar",
522
                           })
523

    
524
  def testWrongPostDataType(self):
525
    self._TestWrongTypes("localhost", 8080, "GET", "/version",
526
                         post_data=unicode("verylongdata" * 100))
527

    
528

    
529
class _EmptyCurlMulti:
530
  def perform(self):
531
    return (pycurl.E_MULTI_OK, 0)
532

    
533
  def info_read(self):
534
    return (0, [], [])
535

    
536

    
537
class TestClientProcessRequests(unittest.TestCase):
538
  def testEmpty(self):
539
    requests = []
540
    http.client.ProcessRequests(requests, _curl=NotImplemented,
541
                                _curl_multi=_EmptyCurlMulti)
542
    self.assertEqual(requests, [])
543

    
544

    
545
class TestProcessCurlRequests(unittest.TestCase):
546
  class _FakeCurlMulti:
547
    def __init__(self):
548
      self.handles = []
549
      self.will_fail = []
550
      self._expect = ["perform"]
551
      self._counter = itertools.count()
552

    
553
    def add_handle(self, curl):
554
      assert curl not in self.handles
555
      self.handles.append(curl)
556
      if self._counter.next() % 3 == 0:
557
        self.will_fail.append(curl)
558

    
559
    def remove_handle(self, curl):
560
      self.handles.remove(curl)
561

    
562
    def perform(self):
563
      assert self._expect.pop(0) == "perform"
564

    
565
      if self._counter.next() % 2 == 0:
566
        self._expect.append("perform")
567
        return (pycurl.E_CALL_MULTI_PERFORM, None)
568

    
569
      self._expect.append("info_read")
570

    
571
      return (pycurl.E_MULTI_OK, len(self.handles))
572

    
573
    def info_read(self):
574
      assert self._expect.pop(0) == "info_read"
575
      successful = []
576
      failed = []
577
      if self.handles:
578
        if self._counter.next() % 17 == 0:
579
          curl = self.handles[0]
580
          if curl in self.will_fail:
581
            failed.append((curl, -1, "test error"))
582
          else:
583
            successful.append(curl)
584
        remaining_messages = len(self.handles) % 3
585
        if remaining_messages > 0:
586
          self._expect.append("info_read")
587
        else:
588
          self._expect.append("select")
589
      else:
590
        remaining_messages = 0
591
        self._expect.append("select")
592
      return (remaining_messages, successful, failed)
593

    
594
    def select(self, timeout):
595
      # Never compare floats for equality
596
      assert timeout >= 0.95 and timeout <= 1.05
597
      assert self._expect.pop(0) == "select"
598
      self._expect.append("perform")
599

    
600
  def test(self):
601
    requests = [_FakeCurl() for _ in range(10)]
602
    multi = self._FakeCurlMulti()
603
    for (curl, errmsg) in http.client._ProcessCurlRequests(multi, requests):
604
      self.assertTrue(curl not in multi.handles)
605
      if curl in multi.will_fail:
606
        self.assertTrue("test error" in errmsg)
607
      else:
608
        self.assertTrue(errmsg is None)
609
    self.assertFalse(multi.handles)
610
    self.assertEqual(multi._expect, ["select"])
611

    
612

    
613
class TestProcessRequests(unittest.TestCase):
614
  class _DummyCurlMulti:
615
    pass
616

    
617
  def testNoMonitor(self):
618
    self._Test(False)
619

    
620
  def testWithMonitor(self):
621
    self._Test(True)
622

    
623
  class _MonitorChecker:
624
    def __init__(self):
625
      self._monitor = None
626

    
627
    def GetMonitor(self):
628
      return self._monitor
629

    
630
    def __call__(self, monitor):
631
      assert callable(monitor.GetLockInfo)
632
      self._monitor = monitor
633

    
634
  def _Test(self, use_monitor):
635
    def cfg_fn(port, curl):
636
      curl.opts["__port__"] = port
637

    
638
    def _LockCheckReset(monitor, req):
639
      self.assertTrue(monitor._lock.is_owned(shared=0),
640
                      msg="Lock must be owned in exclusive mode")
641
      assert not hasattr(req, "lockcheck__")
642
      setattr(req, "lockcheck__", True)
643

    
644
    def _BuildNiceName(port, default=None):
645
      if port % 5 == 0:
646
        return "nicename%s" % port
647
      else:
648
        # Use standard name
649
        return default
650

    
651
    requests = \
652
      [http.client.HttpClientRequest("localhost", i, "POST", "/version%s" % i,
653
                                     curl_config_fn=compat.partial(cfg_fn, i),
654
                                     completion_cb=NotImplementedError,
655
                                     nicename=_BuildNiceName(i))
656
       for i in range(15176, 15501)]
657
    requests_count = len(requests)
658

    
659
    if use_monitor:
660
      lock_monitor_cb = self._MonitorChecker()
661
    else:
662
      lock_monitor_cb = None
663

    
664
    def _ProcessRequests(multi, handles):
665
      self.assertTrue(isinstance(multi, self._DummyCurlMulti))
666
      self.assertEqual(len(requests), len(handles))
667
      self.assertTrue(compat.all(isinstance(curl, _FakeCurl)
668
                                 for curl in handles))
669

    
670
      # Prepare for lock check
671
      for req in requests:
672
        assert req.completion_cb is NotImplementedError
673
        if use_monitor:
674
          req.completion_cb = \
675
            compat.partial(_LockCheckReset, lock_monitor_cb.GetMonitor())
676

    
677
      for idx, curl in enumerate(handles):
678
        try:
679
          port = curl.opts["__port__"]
680
        except KeyError:
681
          self.fail("Per-request config function was not called")
682

    
683
        if use_monitor:
684
          # Check if lock information is correct
685
          lock_info = lock_monitor_cb.GetMonitor().GetLockInfo(None)
686
          expected = \
687
            [("rpc/%s" % (_BuildNiceName(handle.opts["__port__"],
688
                                         default=("localhost/version%s" %
689
                                                  handle.opts["__port__"]))),
690
              None,
691
              [threading.currentThread().getName()], None)
692
             for handle in handles[idx:]]
693
          self.assertEqual(sorted(lock_info), sorted(expected))
694

    
695
        if port % 3 == 0:
696
          response_code = http.HTTP_OK
697
          msg = None
698
        else:
699
          response_code = http.HttpNotFound.code
700
          msg = "test error"
701

    
702
        curl.info = {
703
          pycurl.RESPONSE_CODE: response_code,
704
          }
705

    
706
        # Prepare for reset
707
        self.assertFalse(curl.opts.pop(pycurl.POSTFIELDS))
708
        self.assertTrue(callable(curl.opts.pop(pycurl.WRITEFUNCTION)))
709

    
710
        yield (curl, msg)
711

    
712
      if use_monitor:
713
        self.assertTrue(compat.all(req.lockcheck__ for req in requests))
714

    
715
    if use_monitor:
716
      self.assertEqual(lock_monitor_cb.GetMonitor(), None)
717

    
718
    http.client.ProcessRequests(requests, lock_monitor_cb=lock_monitor_cb,
719
                                _curl=_FakeCurl,
720
                                _curl_multi=self._DummyCurlMulti,
721
                                _curl_process=_ProcessRequests)
722
    for req in requests:
723
      if req.port % 3 == 0:
724
        self.assertTrue(req.success)
725
        self.assertEqual(req.error, None)
726
      else:
727
        self.assertFalse(req.success)
728
        self.assertTrue("test error" in req.error)
729

    
730
    # See if monitor was disabled
731
    if use_monitor:
732
      monitor = lock_monitor_cb.GetMonitor()
733
      self.assertEqual(monitor._pending_fn, None)
734
      self.assertEqual(monitor.GetLockInfo(None), [])
735
    else:
736
      self.assertEqual(lock_monitor_cb, None)
737

    
738
    self.assertEqual(len(requests), requests_count)
739

    
740
  def testBadRequest(self):
741
    bad_request = http.client.HttpClientRequest("localhost", 27784,
742
                                                "POST", "/version")
743
    bad_request.success = False
744

    
745
    self.assertRaises(AssertionError, http.client.ProcessRequests,
746
                      [bad_request], _curl=NotImplemented,
747
                      _curl_multi=NotImplemented, _curl_process=NotImplemented)
748

    
749

    
750
if __name__ == "__main__":
751
  testutils.GanetiTestProgram()