Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.http_unittest.py @ abbf2cd9

History | View | Annotate | Download (25 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2007, 2008 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the http module"""
23

    
24

    
25
import os
26
import unittest
27
import time
28
import tempfile
29
import pycurl
30
import itertools
31
import threading
32
from cStringIO import StringIO
33

    
34
from ganeti import http
35
from ganeti import compat
36

    
37
import ganeti.http.server
38
import ganeti.http.client
39
import ganeti.http.auth
40

    
41
import testutils
42

    
43

    
44
class TestStartLines(unittest.TestCase):
45
  """Test cases for start line classes"""
46

    
47
  def testClientToServerStartLine(self):
48
    """Test client to server start line (HTTP request)"""
49
    start_line = http.HttpClientToServerStartLine("GET", "/", "HTTP/1.1")
50
    self.assertEqual(str(start_line), "GET / HTTP/1.1")
51

    
52
  def testServerToClientStartLine(self):
53
    """Test server to client start line (HTTP response)"""
54
    start_line = http.HttpServerToClientStartLine("HTTP/1.1", 200, "OK")
55
    self.assertEqual(str(start_line), "HTTP/1.1 200 OK")
56

    
57

    
58
class TestMisc(unittest.TestCase):
59
  """Miscellaneous tests"""
60

    
61
  def _TestDateTimeHeader(self, gmnow, expected):
62
    self.assertEqual(http.server._DateTimeHeader(gmnow=gmnow), expected)
63

    
64
  def testDateTimeHeader(self):
65
    """Test ganeti.http._DateTimeHeader"""
66
    self._TestDateTimeHeader((2008, 1, 2, 3, 4, 5, 3, 0, 0),
67
                             "Thu, 02 Jan 2008 03:04:05 GMT")
68
    self._TestDateTimeHeader((2008, 1, 1, 0, 0, 0, 0, 0, 0),
69
                             "Mon, 01 Jan 2008 00:00:00 GMT")
70
    self._TestDateTimeHeader((2008, 12, 31, 0, 0, 0, 0, 0, 0),
71
                             "Mon, 31 Dec 2008 00:00:00 GMT")
72
    self._TestDateTimeHeader((2008, 12, 31, 23, 59, 59, 0, 0, 0),
73
                             "Mon, 31 Dec 2008 23:59:59 GMT")
74
    self._TestDateTimeHeader((2008, 12, 31, 0, 0, 0, 6, 0, 0),
75
                             "Sun, 31 Dec 2008 00:00:00 GMT")
76

    
77
  def testHttpServerRequest(self):
78
    """Test ganeti.http.server._HttpServerRequest"""
79
    server_request = http.server._HttpServerRequest("GET", "/", None, None)
80

    
81
    # These are expected by users of the HTTP server
82
    self.assert_(hasattr(server_request, "request_method"))
83
    self.assert_(hasattr(server_request, "request_path"))
84
    self.assert_(hasattr(server_request, "request_headers"))
85
    self.assert_(hasattr(server_request, "request_body"))
86
    self.assert_(isinstance(server_request.resp_headers, dict))
87
    self.assert_(hasattr(server_request, "private"))
88

    
89
  def testServerSizeLimits(self):
90
    """Test HTTP server size limits"""
91
    message_reader_class = http.server._HttpClientToServerMessageReader
92
    self.assert_(message_reader_class.START_LINE_LENGTH_MAX > 0)
93
    self.assert_(message_reader_class.HEADER_LENGTH_MAX > 0)
94

    
95
  def testFormatAuthHeader(self):
96
    self.assertEqual(http.auth._FormatAuthHeader("Basic", {}),
97
                     "Basic")
98
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "bar", }),
99
                     "Basic foo=bar")
100
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "", }),
101
                     "Basic foo=\"\"")
102
    self.assertEqual(http.auth._FormatAuthHeader("Basic", { "foo": "x,y", }),
103
                     "Basic foo=\"x,y\"")
104
    params = {
105
      "foo": "x,y",
106
      "realm": "secure",
107
      }
108
    # It's a dict whose order isn't guaranteed, hence checking a list
109
    self.assert_(http.auth._FormatAuthHeader("Digest", params) in
110
                 ("Digest foo=\"x,y\" realm=secure",
111
                  "Digest realm=secure foo=\"x,y\""))
112

    
113

    
114
class _FakeRequestAuth(http.auth.HttpServerRequestAuthentication):
115
  def __init__(self, realm, authreq, authenticate_fn):
116
    http.auth.HttpServerRequestAuthentication.__init__(self)
117

    
118
    self.realm = realm
119
    self.authreq = authreq
120
    self.authenticate_fn = authenticate_fn
121

    
122
  def AuthenticationRequired(self, req):
123
    return self.authreq
124

    
125
  def GetAuthRealm(self, req):
126
    return self.realm
127

    
128
  def Authenticate(self, *args):
129
    if self.authenticate_fn:
130
      return self.authenticate_fn(*args)
131
    raise NotImplementedError()
132

    
133

    
134
class TestAuth(unittest.TestCase):
135
  """Authentication tests"""
136

    
137
  hsra = http.auth.HttpServerRequestAuthentication
138

    
139
  def testConstants(self):
140
    for scheme in [self.hsra._CLEARTEXT_SCHEME, self.hsra._HA1_SCHEME]:
141
      self.assertEqual(scheme, scheme.upper())
142
      self.assert_(scheme.startswith("{"))
143
      self.assert_(scheme.endswith("}"))
144

    
145
  def _testVerifyBasicAuthPassword(self, realm, user, password, expected):
146
    ra = _FakeRequestAuth(realm, False, None)
147

    
148
    return ra.VerifyBasicAuthPassword(None, user, password, expected)
149

    
150
  def testVerifyBasicAuthPassword(self):
151
    tvbap = self._testVerifyBasicAuthPassword
152

    
153
    good_pws = ["pw", "pw{", "pw}", "pw{}", "pw{x}y", "}pw",
154
                "0", "123", "foo...:xyz", "TeST"]
155

    
156
    for pw in good_pws:
157
      # Try cleartext passwords
158
      self.assert_(tvbap("abc", "user", pw, pw))
159
      self.assert_(tvbap("abc", "user", pw, "{cleartext}" + pw))
160
      self.assert_(tvbap("abc", "user", pw, "{ClearText}" + pw))
161
      self.assert_(tvbap("abc", "user", pw, "{CLEARTEXT}" + pw))
162

    
163
      # Try with invalid password
164
      self.failIf(tvbap("abc", "user", pw, "something"))
165

    
166
      # Try with invalid scheme
167
      self.failIf(tvbap("abc", "user", pw, "{000}" + pw))
168
      self.failIf(tvbap("abc", "user", pw, "{unk}" + pw))
169
      self.failIf(tvbap("abc", "user", pw, "{Unk}" + pw))
170
      self.failIf(tvbap("abc", "user", pw, "{UNK}" + pw))
171

    
172
    # Try with invalid scheme format
173
    self.failIf(tvbap("abc", "user", "pw", "{something"))
174

    
175
    # Hash is MD5("user:This is only a test:pw")
176
    self.assert_(tvbap("This is only a test", "user", "pw",
177
                       "{ha1}92ea58ae804481498c257b2f65561a17"))
178
    self.assert_(tvbap("This is only a test", "user", "pw",
179
                       "{HA1}92ea58ae804481498c257b2f65561a17"))
180

    
181
    self.failUnlessRaises(AssertionError, tvbap, None, "user", "pw",
182
                          "{HA1}92ea58ae804481498c257b2f65561a17")
183
    self.failIf(tvbap("Admin area", "user", "pw",
184
                      "{HA1}92ea58ae804481498c257b2f65561a17"))
185
    self.failIf(tvbap("This is only a test", "someone", "pw",
186
                      "{HA1}92ea58ae804481498c257b2f65561a17"))
187
    self.failIf(tvbap("This is only a test", "user", "something",
188
                      "{HA1}92ea58ae804481498c257b2f65561a17"))
189

    
190

    
191
class _SimpleAuthenticator:
192
  def __init__(self, user, password):
193
    self.user = user
194
    self.password = password
195
    self.called = False
196

    
197
  def __call__(self, req, user, password):
198
    self.called = True
199
    return self.user == user and self.password == password
200

    
201

    
202
class TestHttpServerRequestAuthentication(unittest.TestCase):
203
  def testNoAuth(self):
204
    req = http.server._HttpServerRequest("GET", "/", None, None)
205
    _FakeRequestAuth("area1", False, None).PreHandleRequest(req)
206

    
207
  def testNoRealm(self):
208
    headers = { http.HTTP_AUTHORIZATION: "", }
209
    req = http.server._HttpServerRequest("GET", "/", headers, None)
210
    ra = _FakeRequestAuth(None, False, None)
211
    self.assertRaises(AssertionError, ra.PreHandleRequest, req)
212

    
213
  def testNoScheme(self):
214
    headers = { http.HTTP_AUTHORIZATION: "", }
215
    req = http.server._HttpServerRequest("GET", "/", headers, None)
216
    ra = _FakeRequestAuth("area1", False, None)
217
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
218

    
219
  def testUnknownScheme(self):
220
    headers = { http.HTTP_AUTHORIZATION: "NewStyleAuth abc", }
221
    req = http.server._HttpServerRequest("GET", "/", headers, None)
222
    ra = _FakeRequestAuth("area1", False, None)
223
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
224

    
225
  def testInvalidBase64(self):
226
    headers = { http.HTTP_AUTHORIZATION: "Basic x_=_", }
227
    req = http.server._HttpServerRequest("GET", "/", headers, None)
228
    ra = _FakeRequestAuth("area1", False, None)
229
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
230

    
231
  def testAuthForPublicResource(self):
232
    headers = {
233
      http.HTTP_AUTHORIZATION: "Basic %s" % ("foo".encode("base64").strip(), ),
234
      }
235
    req = http.server._HttpServerRequest("GET", "/", headers, None)
236
    ra = _FakeRequestAuth("area1", False, None)
237
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
238

    
239
  def testAuthForPublicResource(self):
240
    headers = {
241
      http.HTTP_AUTHORIZATION:
242
        "Basic %s" % ("foo:bar".encode("base64").strip(), ),
243
      }
244
    req = http.server._HttpServerRequest("GET", "/", headers, None)
245
    ac = _SimpleAuthenticator("foo", "bar")
246
    ra = _FakeRequestAuth("area1", False, ac)
247
    ra.PreHandleRequest(req)
248

    
249
    req = http.server._HttpServerRequest("GET", "/", headers, None)
250
    ac = _SimpleAuthenticator("something", "else")
251
    ra = _FakeRequestAuth("area1", False, ac)
252
    self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
253

    
254
  def testInvalidRequestHeader(self):
255
    checks = {
256
      http.HttpUnauthorized: ["", "\t", "-", ".", "@", "<", ">", "Digest",
257
                              "basic %s" % "foobar".encode("base64").strip()],
258
      http.HttpBadRequest: ["Basic"],
259
      }
260

    
261
    for exc, headers in checks.items():
262
      for i in headers:
263
        headers = { http.HTTP_AUTHORIZATION: i, }
264
        req = http.server._HttpServerRequest("GET", "/", headers, None)
265
        ra = _FakeRequestAuth("area1", False, None)
266
        self.assertRaises(exc, ra.PreHandleRequest, req)
267

    
268
  def testBasicAuth(self):
269
    for user in ["", "joe", "user name with spaces"]:
270
      for pw in ["", "-", ":", "foobar", "Foo Bar Baz", "@@@", "###",
271
                 "foo:bar:baz"]:
272
        for wrong_pw in [True, False]:
273
          basic_auth = "%s:%s" % (user, pw)
274
          if wrong_pw:
275
            basic_auth += "WRONG"
276
          headers = {
277
              http.HTTP_AUTHORIZATION:
278
                "Basic %s" % (basic_auth.encode("base64").strip(), ),
279
            }
280
          req = http.server._HttpServerRequest("GET", "/", headers, None)
281

    
282
          ac = _SimpleAuthenticator(user, pw)
283
          self.assertFalse(ac.called)
284
          ra = _FakeRequestAuth("area1", True, ac)
285
          if wrong_pw:
286
            try:
287
              ra.PreHandleRequest(req)
288
            except http.HttpUnauthorized, err:
289
              www_auth = err.headers[http.HTTP_WWW_AUTHENTICATE]
290
              self.assert_(www_auth.startswith(http.auth.HTTP_BASIC_AUTH))
291
            else:
292
              self.fail("Didn't raise HttpUnauthorized")
293
          else:
294
            ra.PreHandleRequest(req)
295
          self.assert_(ac.called)
296

    
297

    
298
class TestReadPasswordFile(unittest.TestCase):
299
  def testSimple(self):
300
    users = http.auth.ParsePasswordFile("user1 password")
301
    self.assertEqual(len(users), 1)
302
    self.assertEqual(users["user1"].password, "password")
303
    self.assertEqual(len(users["user1"].options), 0)
304

    
305
  def testOptions(self):
306
    buf = StringIO()
307
    buf.write("# Passwords\n")
308
    buf.write("user1 password\n")
309
    buf.write("\n")
310
    buf.write("# Comment\n")
311
    buf.write("user2 pw write,read\n")
312
    buf.write("   \t# Another comment\n")
313
    buf.write("invalidline\n")
314

    
315
    users = http.auth.ParsePasswordFile(buf.getvalue())
316
    self.assertEqual(len(users), 2)
317
    self.assertEqual(users["user1"].password, "password")
318
    self.assertEqual(len(users["user1"].options), 0)
319

    
320
    self.assertEqual(users["user2"].password, "pw")
321
    self.assertEqual(users["user2"].options, ["write", "read"])
322

    
323

    
324
class TestClientRequest(unittest.TestCase):
325
  def testRepr(self):
326
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
327
                                       headers=[], post_data="Hello World")
328
    self.assert_(repr(cr).startswith("<"))
329

    
330
  def testNoHeaders(self):
331
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
332
                                       headers=None)
333
    self.assert_(isinstance(cr.headers, list))
334
    self.assertEqual(cr.headers, [])
335
    self.assertEqual(cr.url, "https://localhost:1234/version")
336

    
337
  def testPlainAddressIPv4(self):
338
    cr = http.client.HttpClientRequest("192.0.2.9", 19956, "GET", "/version")
339
    self.assertEqual(cr.url, "https://192.0.2.9:19956/version")
340

    
341
  def testPlainAddressIPv6(self):
342
    cr = http.client.HttpClientRequest("2001:db8::cafe", 15110, "GET", "/info")
343
    self.assertEqual(cr.url, "https://[2001:db8::cafe]:15110/info")
344

    
345
  def testOldStyleHeaders(self):
346
    headers = {
347
      "Content-type": "text/plain",
348
      "Accept": "text/html",
349
      }
350
    cr = http.client.HttpClientRequest("localhost", 16481, "GET", "/vg_list",
351
                                       headers=headers)
352
    self.assert_(isinstance(cr.headers, list))
353
    self.assertEqual(sorted(cr.headers), [
354
      "Accept: text/html",
355
      "Content-type: text/plain",
356
      ])
357
    self.assertEqual(cr.url, "https://localhost:16481/vg_list")
358

    
359
  def testNewStyleHeaders(self):
360
    headers = [
361
      "Accept: text/html",
362
      "Content-type: text/plain; charset=ascii",
363
      "Server: httpd 1.0",
364
      ]
365
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
366
                                       headers=headers)
367
    self.assert_(isinstance(cr.headers, list))
368
    self.assertEqual(sorted(cr.headers), sorted(headers))
369
    self.assertEqual(cr.url, "https://localhost:1234/version")
370

    
371
  def testPostData(self):
372
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
373
                                       post_data="Hello World")
374
    self.assertEqual(cr.post_data, "Hello World")
375

    
376
  def testNoPostData(self):
377
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
378
    self.assertEqual(cr.post_data, "")
379

    
380

    
381
class _FakeCurl:
382
  def __init__(self):
383
    self.opts = {}
384
    self.info = NotImplemented
385

    
386
  def setopt(self, opt, value):
387
    assert opt not in self.opts, "Option set more than once"
388
    self.opts[opt] = value
389

    
390
  def getinfo(self, info):
391
    return self.info.pop(info)
392

    
393

    
394
class TestClientStartRequest(unittest.TestCase):
395
  @staticmethod
396
  def _TestCurlConfig(curl):
397
    curl.setopt(pycurl.SSLKEYTYPE, "PEM")
398

    
399
  def test(self):
400
    for method in [http.HTTP_GET, http.HTTP_PUT, "CUSTOM"]:
401
      for port in [8761, 29796, 19528]:
402
        for curl_config_fn in [None, self._TestCurlConfig]:
403
          for read_timeout in [None, 0, 1, 123, 36000]:
404
            self._TestInner(method, port, curl_config_fn, read_timeout)
405

    
406
  def _TestInner(self, method, port, curl_config_fn, read_timeout):
407
    for response_code in [http.HTTP_OK, http.HttpNotFound.code,
408
                          http.HTTP_NOT_MODIFIED]:
409
      for response_body in [None, "Hello World",
410
                            "Very Long\tContent here\n" * 171]:
411
        for errmsg in [None, "error"]:
412
          req = http.client.HttpClientRequest("localhost", port, method,
413
                                              "/version",
414
                                              curl_config_fn=curl_config_fn,
415
                                              read_timeout=read_timeout)
416
          curl = _FakeCurl()
417
          pending = http.client._StartRequest(curl, req)
418
          self.assertEqual(pending.GetCurlHandle(), curl)
419
          self.assertEqual(pending.GetCurrentRequest(), req)
420

    
421
          # Check options
422
          opts = curl.opts
423
          self.assertEqual(opts.pop(pycurl.CUSTOMREQUEST), method)
424
          self.assertEqual(opts.pop(pycurl.URL),
425
                           "https://localhost:%s/version" % port)
426
          if read_timeout is None:
427
            self.assertEqual(opts.pop(pycurl.TIMEOUT), 0)
428
          else:
429
            self.assertEqual(opts.pop(pycurl.TIMEOUT), read_timeout)
430
          self.assertFalse(opts.pop(pycurl.VERBOSE))
431
          self.assertTrue(opts.pop(pycurl.NOSIGNAL))
432
          self.assertEqual(opts.pop(pycurl.USERAGENT),
433
                           http.HTTP_GANETI_VERSION)
434
          self.assertEqual(opts.pop(pycurl.PROXY), "")
435
          self.assertFalse(opts.pop(pycurl.POSTFIELDS))
436
          self.assertFalse(opts.pop(pycurl.HTTPHEADER))
437
          write_fn = opts.pop(pycurl.WRITEFUNCTION)
438
          self.assertTrue(callable(write_fn))
439
          if hasattr(pycurl, "SSL_SESSIONID_CACHE"):
440
            self.assertFalse(opts.pop(pycurl.SSL_SESSIONID_CACHE))
441
          if curl_config_fn:
442
            self.assertEqual(opts.pop(pycurl.SSLKEYTYPE), "PEM")
443
          else:
444
            self.assertFalse(pycurl.SSLKEYTYPE in opts)
445
          self.assertFalse(opts)
446

    
447
          if response_body is not None:
448
            offset = 0
449
            while offset < len(response_body):
450
              piece = response_body[offset:offset + 10]
451
              write_fn(piece)
452
              offset += len(piece)
453

    
454
          curl.info = {
455
            pycurl.RESPONSE_CODE: response_code,
456
            }
457

    
458
          # Finalize request
459
          pending.Done(errmsg)
460

    
461
          self.assertFalse(curl.info)
462

    
463
          # Can only finalize once
464
          self.assertRaises(AssertionError, pending.Done, True)
465

    
466
          if errmsg:
467
            self.assertFalse(req.success)
468
          else:
469
            self.assertTrue(req.success)
470
          self.assertEqual(req.error, errmsg)
471
          self.assertEqual(req.resp_status_code, response_code)
472
          if response_body is None:
473
            self.assertEqual(req.resp_body, "")
474
          else:
475
            self.assertEqual(req.resp_body, response_body)
476

    
477
          # Check if resetting worked
478
          assert not hasattr(curl, "reset")
479
          opts = curl.opts
480
          self.assertFalse(opts.pop(pycurl.POSTFIELDS))
481
          self.assertTrue(callable(opts.pop(pycurl.WRITEFUNCTION)))
482
          self.assertFalse(opts)
483

    
484
          self.assertFalse(curl.opts,
485
                           msg="Previous checks did not consume all options")
486
          assert id(opts) == id(curl.opts)
487

    
488
  def _TestWrongTypes(self, *args, **kwargs):
489
    req = http.client.HttpClientRequest(*args, **kwargs)
490
    self.assertRaises(AssertionError, http.client._StartRequest,
491
                      _FakeCurl(), req)
492

    
493
  def testWrongHostType(self):
494
    self._TestWrongTypes(unicode("localhost"), 8080, "GET", "/version")
495

    
496
  def testWrongUrlType(self):
497
    self._TestWrongTypes("localhost", 8080, "GET", unicode("/version"))
498

    
499
  def testWrongMethodType(self):
500
    self._TestWrongTypes("localhost", 8080, unicode("GET"), "/version")
501

    
502
  def testWrongHeaderType(self):
503
    self._TestWrongTypes("localhost", 8080, "GET", "/version",
504
                         headers={
505
                           unicode("foo"): "bar",
506
                           })
507

    
508
  def testWrongPostDataType(self):
509
    self._TestWrongTypes("localhost", 8080, "GET", "/version",
510
                         post_data=unicode("verylongdata" * 100))
511

    
512

    
513
class _EmptyCurlMulti:
514
  def perform(self):
515
    return (pycurl.E_MULTI_OK, 0)
516

    
517
  def info_read(self):
518
    return (0, [], [])
519

    
520

    
521
class TestClientProcessRequests(unittest.TestCase):
522
  def testEmpty(self):
523
    requests = []
524
    http.client.ProcessRequests(requests, _curl=NotImplemented,
525
                                _curl_multi=_EmptyCurlMulti)
526
    self.assertEqual(requests, [])
527

    
528

    
529
class TestProcessCurlRequests(unittest.TestCase):
530
  class _FakeCurlMulti:
531
    def __init__(self):
532
      self.handles = []
533
      self.will_fail = []
534
      self._expect = ["perform"]
535
      self._counter = itertools.count()
536

    
537
    def add_handle(self, curl):
538
      assert curl not in self.handles
539
      self.handles.append(curl)
540
      if self._counter.next() % 3 == 0:
541
        self.will_fail.append(curl)
542

    
543
    def remove_handle(self, curl):
544
      self.handles.remove(curl)
545

    
546
    def perform(self):
547
      assert self._expect.pop(0) == "perform"
548

    
549
      if self._counter.next() % 2 == 0:
550
        self._expect.append("perform")
551
        return (pycurl.E_CALL_MULTI_PERFORM, None)
552

    
553
      self._expect.append("info_read")
554

    
555
      return (pycurl.E_MULTI_OK, len(self.handles))
556

    
557
    def info_read(self):
558
      assert self._expect.pop(0) == "info_read"
559
      successful = []
560
      failed = []
561
      if self.handles:
562
        if self._counter.next() % 17 == 0:
563
          curl = self.handles[0]
564
          if curl in self.will_fail:
565
            failed.append((curl, -1, "test error"))
566
          else:
567
            successful.append(curl)
568
        remaining_messages = len(self.handles) % 3
569
        if remaining_messages > 0:
570
          self._expect.append("info_read")
571
        else:
572
          self._expect.append("select")
573
      else:
574
        remaining_messages = 0
575
        self._expect.append("select")
576
      return (remaining_messages, successful, failed)
577

    
578
    def select(self, timeout):
579
      # Never compare floats for equality
580
      assert timeout >= 0.95 and timeout <= 1.05
581
      assert self._expect.pop(0) == "select"
582
      self._expect.append("perform")
583

    
584
  def test(self):
585
    requests = [_FakeCurl() for _ in range(10)]
586
    multi = self._FakeCurlMulti()
587
    for (curl, errmsg) in http.client._ProcessCurlRequests(multi, requests):
588
      self.assertTrue(curl not in multi.handles)
589
      if curl in multi.will_fail:
590
        self.assertTrue("test error" in errmsg)
591
      else:
592
        self.assertTrue(errmsg is None)
593
    self.assertFalse(multi.handles)
594
    self.assertEqual(multi._expect, ["select"])
595

    
596

    
597
class TestProcessRequests(unittest.TestCase):
598
  class _DummyCurlMulti:
599
    pass
600

    
601
  def testNoMonitor(self):
602
    self._Test(False)
603

    
604
  def testWithMonitor(self):
605
    self._Test(True)
606

    
607
  class _MonitorChecker:
608
    def __init__(self):
609
      self._monitor = None
610

    
611
    def GetMonitor(self):
612
      return self._monitor
613

    
614
    def __call__(self, monitor):
615
      assert callable(monitor.GetLockInfo)
616
      self._monitor = monitor
617

    
618
  def _Test(self, use_monitor):
619
    def cfg_fn(port, curl):
620
      curl.opts["__port__"] = port
621

    
622
    def _LockCheckReset(monitor, curl):
623
      self.assertTrue(monitor._lock.is_owned(shared=0),
624
                      msg="Lock must be owned in exclusive mode")
625
      curl.opts["__lockcheck__"] = True
626

    
627
    requests = \
628
      [http.client.HttpClientRequest("localhost", i, "POST", "/version%s" % i,
629
                                     curl_config_fn=compat.partial(cfg_fn, i))
630
       for i in range(15176, 15501)]
631
    requests_count = len(requests)
632

    
633
    if use_monitor:
634
      lock_monitor_cb = self._MonitorChecker()
635
    else:
636
      lock_monitor_cb = None
637

    
638
    def _ProcessRequests(multi, handles):
639
      self.assertTrue(isinstance(multi, self._DummyCurlMulti))
640
      self.assertEqual(len(requests), len(handles))
641
      self.assertTrue(compat.all(isinstance(curl, _FakeCurl)
642
                                 for curl in handles))
643

    
644
      for idx, curl in enumerate(handles):
645
        port = curl.opts["__port__"]
646

    
647
        if use_monitor:
648
          # Check if lock information is correct
649
          lock_info = lock_monitor_cb.GetMonitor().GetLockInfo(None)
650
          expected = \
651
            [("rpc/localhost/version%s" % handle.opts["__port__"], None,
652
              [threading.currentThread().getName()], None)
653
             for handle in handles[idx:]]
654
          self.assertEqual(sorted(lock_info), sorted(expected))
655

    
656
        if port % 3 == 0:
657
          response_code = http.HTTP_OK
658
          msg = None
659
        else:
660
          response_code = http.HttpNotFound.code
661
          msg = "test error"
662

    
663
        curl.info = {
664
          pycurl.RESPONSE_CODE: response_code,
665
          }
666

    
667
        # Unset options which will be reset
668
        assert not hasattr(curl, "reset")
669
        if use_monitor:
670
          setattr(curl, "reset",
671
                  compat.partial(_LockCheckReset, lock_monitor_cb.GetMonitor(),
672
                                 curl))
673
        else:
674
          self.assertFalse(curl.opts.pop(pycurl.POSTFIELDS))
675
          self.assertTrue(callable(curl.opts.pop(pycurl.WRITEFUNCTION)))
676

    
677
        yield (curl, msg)
678

    
679
      if use_monitor:
680
        self.assertTrue(compat.all(curl.opts["__lockcheck__"]
681
                                   for curl in handles))
682

    
683
    http.client.ProcessRequests(requests, lock_monitor_cb=lock_monitor_cb,
684
                                _curl=_FakeCurl,
685
                                _curl_multi=self._DummyCurlMulti,
686
                                _curl_process=_ProcessRequests)
687
    for req in requests:
688
      if req.port % 3 == 0:
689
        self.assertTrue(req.success)
690
        self.assertEqual(req.error, None)
691
      else:
692
        self.assertFalse(req.success)
693
        self.assertTrue("test error" in req.error)
694

    
695
    # See if monitor was disabled
696
    if use_monitor:
697
      monitor = lock_monitor_cb.GetMonitor()
698
      self.assertEqual(monitor._pending_fn, None)
699
      self.assertEqual(monitor.GetLockInfo(None), [])
700
    else:
701
      self.assertEqual(lock_monitor_cb, None)
702

    
703
    self.assertEqual(len(requests), requests_count)
704

    
705
  def testBadRequest(self):
706
    bad_request = http.client.HttpClientRequest("localhost", 27784,
707
                                                "POST", "/version")
708
    bad_request.success = False
709

    
710
    self.assertRaises(AssertionError, http.client.ProcessRequests,
711
                      [bad_request], _curl=NotImplemented,
712
                      _curl_multi=NotImplemented, _curl_process=NotImplemented)
713

    
714

    
715
if __name__ == '__main__':
716
  testutils.GanetiTestProgram()