bash_completion: Enable extglob while parsing file
[ganeti-local] / test / ganeti.http_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2007, 2008 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the http module"""
23
24
25 import os
26 import unittest
27 import time
28 import tempfile
29 import pycurl
30 import itertools
31 import threading
32 from cStringIO import StringIO
33
34 from ganeti import http
35 from ganeti import compat
36
37 import ganeti.http.server
38 import ganeti.http.client
39 import ganeti.http.auth
40
41 import testutils
42
43
class TestStartLines(unittest.TestCase):
  """Check the string form of the HTTP start line classes"""

  def testClientToServerStartLine(self):
    """Request start line is rendered as "METHOD PATH VERSION"."""
    line = http.HttpClientToServerStartLine("GET", "/", "HTTP/1.1")
    expected = "GET / HTTP/1.1"
    self.assertEqual(str(line), expected)

  def testServerToClientStartLine(self):
    """Response start line is rendered as "VERSION CODE REASON"."""
    line = http.HttpServerToClientStartLine("HTTP/1.1", 200, "OK")
    expected = "HTTP/1.1 200 OK"
    self.assertEqual(str(line), expected)
56
57
class TestMisc(unittest.TestCase):
  """Tests for assorted small helpers of the http package"""

  def _TestDateTimeHeader(self, gmnow, expected):
    header = http.server._DateTimeHeader(gmnow=gmnow)
    self.assertEqual(header, expected)

  def testDateTimeHeader(self):
    """Test ganeti.http._DateTimeHeader"""
    # The seventh tuple element selects the weekday name (0=Mon ... 6=Sun)
    cases = [
      ((2008, 1, 2, 3, 4, 5, 3, 0, 0),
       "Thu, 02 Jan 2008 03:04:05 GMT"),
      ((2008, 1, 1, 0, 0, 0, 0, 0, 0),
       "Mon, 01 Jan 2008 00:00:00 GMT"),
      ((2008, 12, 31, 0, 0, 0, 0, 0, 0),
       "Mon, 31 Dec 2008 00:00:00 GMT"),
      ((2008, 12, 31, 23, 59, 59, 0, 0, 0),
       "Mon, 31 Dec 2008 23:59:59 GMT"),
      ((2008, 12, 31, 0, 0, 0, 6, 0, 0),
       "Sun, 31 Dec 2008 00:00:00 GMT"),
      ]
    for (gmnow, expected) in cases:
      self._TestDateTimeHeader(gmnow, expected)

  def testHttpServerRequest(self):
    """Test ganeti.http.server._HttpServerRequest"""
    req = http.server._HttpServerRequest("GET", "/", None, None)

    # Users of the HTTP server rely on these attributes
    for name in ["request_method", "request_path", "request_headers",
                 "request_body"]:
      self.assertTrue(hasattr(req, name))
    self.assertTrue(isinstance(req.resp_headers, dict))
    self.assertTrue(hasattr(req, "private"))

  def testServerSizeLimits(self):
    """Test HTTP server size limits"""
    reader_cls = http.server._HttpClientToServerMessageReader
    self.assertTrue(reader_cls.START_LINE_LENGTH_MAX > 0)
    self.assertTrue(reader_cls.HEADER_LENGTH_MAX > 0)

  def testFormatAuthHeader(self):
    fmt = http.auth._FormatAuthHeader

    self.assertEqual(fmt("Basic", {}), "Basic")
    self.assertEqual(fmt("Basic", { "foo": "bar", }), "Basic foo=bar")
    # Empty values and values containing a comma come out quoted
    self.assertEqual(fmt("Basic", { "foo": "", }), 'Basic foo=""')
    self.assertEqual(fmt("Basic", { "foo": "x,y", }), 'Basic foo="x,y"')

    # Parameters are passed as a dict with unspecified iteration order, so
    # either ordering of the output is acceptable
    digest = fmt("Digest", { "foo": "x,y", "realm": "secure", })
    self.assertTrue(digest in ('Digest foo="x,y" realm=secure',
                               'Digest realm=secure foo="x,y"'))
112
113
class _FakeRequestAuth(http.auth.HttpServerRequestAuthentication):
  """Authentication stub with injected realm, requirement and checker.

  @param realm: value returned by L{GetAuthRealm}
  @param authreq: value returned by L{AuthenticationRequired}
  @param authenticate_fn: callable used by L{Authenticate}; when it is not
    set, L{Authenticate} raises C{NotImplementedError}

  """
  def __init__(self, realm, authreq, authenticate_fn):
    http.auth.HttpServerRequestAuthentication.__init__(self)

    (self.realm, self.authreq, self.authenticate_fn) = \
      (realm, authreq, authenticate_fn)

  def AuthenticationRequired(self, req):
    return self.authreq

  def GetAuthRealm(self, req):
    return self.realm

  def Authenticate(self, *args):
    checker = self.authenticate_fn
    if not checker:
      raise NotImplementedError()
    return checker(*args)
132
133
class TestAuth(unittest.TestCase):
  """Tests for password verification in HttpServerRequestAuthentication"""

  hsra = http.auth.HttpServerRequestAuthentication

  def testConstants(self):
    # Password scheme prefixes are upper-case and wrapped in curly braces
    schemes = [self.hsra._CLEARTEXT_SCHEME, self.hsra._HA1_SCHEME]
    for scheme in schemes:
      self.assertEqual(scheme, scheme.upper())
      self.assertTrue(scheme.startswith("{"))
      self.assertTrue(scheme.endswith("}"))

  def _testVerifyBasicAuthPassword(self, realm, user, password, expected):
    auth = _FakeRequestAuth(realm, False, None)
    return auth.VerifyBasicAuthPassword(None, user, password, expected)

  def testVerifyBasicAuthPassword(self):
    verify = self._testVerifyBasicAuthPassword

    good_pws = ["pw", "pw{", "pw}", "pw{}", "pw{x}y", "}pw",
                "0", "123", "foo...:xyz", "TeST"]

    for pw in good_pws:
      # Cleartext passwords, bare or with a case-insensitive scheme prefix
      for stored in [pw, "{cleartext}" + pw, "{ClearText}" + pw,
                     "{CLEARTEXT}" + pw]:
        self.assertTrue(verify("abc", "user", pw, stored))

      # Wrong password
      self.assertFalse(verify("abc", "user", pw, "something"))

      # Unknown schemes never match
      for scheme in ["{000}", "{unk}", "{Unk}", "{UNK}"]:
        self.assertFalse(verify("abc", "user", pw, scheme + pw))

    # Scheme prefix without closing brace
    self.assertFalse(verify("abc", "user", "pw", "{something"))

    # Hash is MD5("user:This is only a test:pw")
    ha1 = "92ea58ae804481498c257b2f65561a17"
    realm = "This is only a test"
    self.assertTrue(verify(realm, "user", "pw", "{ha1}" + ha1))
    self.assertTrue(verify(realm, "user", "pw", "{HA1}" + ha1))

    # Checking an HA1 hash without a realm fails an assertion
    self.assertRaises(AssertionError, verify, None, "user", "pw",
                      "{HA1}" + ha1)

    # Wrong realm, wrong user, wrong password
    self.assertFalse(verify("Admin area", "user", "pw", "{HA1}" + ha1))
    self.assertFalse(verify(realm, "someone", "pw", "{HA1}" + ha1))
    self.assertFalse(verify(realm, "user", "something", "{HA1}" + ha1))
189
190
191 class _SimpleAuthenticator:
192   def __init__(self, user, password):
193     self.user = user
194     self.password = password
195     self.called = False
196
197   def __call__(self, req, user, password):
198     self.called = True
199     return self.user == user and self.password == password
200
201
202 class TestHttpServerRequestAuthentication(unittest.TestCase):
203   def testNoAuth(self):
204     req = http.server._HttpServerRequest("GET", "/", None, None)
205     _FakeRequestAuth("area1", False, None).PreHandleRequest(req)
206
207   def testNoRealm(self):
208     headers = { http.HTTP_AUTHORIZATION: "", }
209     req = http.server._HttpServerRequest("GET", "/", headers, None)
210     ra = _FakeRequestAuth(None, False, None)
211     self.assertRaises(AssertionError, ra.PreHandleRequest, req)
212
213   def testNoScheme(self):
214     headers = { http.HTTP_AUTHORIZATION: "", }
215     req = http.server._HttpServerRequest("GET", "/", headers, None)
216     ra = _FakeRequestAuth("area1", False, None)
217     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
218
219   def testUnknownScheme(self):
220     headers = { http.HTTP_AUTHORIZATION: "NewStyleAuth abc", }
221     req = http.server._HttpServerRequest("GET", "/", headers, None)
222     ra = _FakeRequestAuth("area1", False, None)
223     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
224
225   def testInvalidBase64(self):
226     headers = { http.HTTP_AUTHORIZATION: "Basic x_=_", }
227     req = http.server._HttpServerRequest("GET", "/", headers, None)
228     ra = _FakeRequestAuth("area1", False, None)
229     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
230
231   def testAuthForPublicResource(self):
232     headers = {
233       http.HTTP_AUTHORIZATION: "Basic %s" % ("foo".encode("base64").strip(), ),
234       }
235     req = http.server._HttpServerRequest("GET", "/", headers, None)
236     ra = _FakeRequestAuth("area1", False, None)
237     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
238
239   def testAuthForPublicResource(self):
240     headers = {
241       http.HTTP_AUTHORIZATION:
242         "Basic %s" % ("foo:bar".encode("base64").strip(), ),
243       }
244     req = http.server._HttpServerRequest("GET", "/", headers, None)
245     ac = _SimpleAuthenticator("foo", "bar")
246     ra = _FakeRequestAuth("area1", False, ac)
247     ra.PreHandleRequest(req)
248
249     req = http.server._HttpServerRequest("GET", "/", headers, None)
250     ac = _SimpleAuthenticator("something", "else")
251     ra = _FakeRequestAuth("area1", False, ac)
252     self.assertRaises(http.HttpUnauthorized, ra.PreHandleRequest, req)
253
254   def testInvalidRequestHeader(self):
255     checks = {
256       http.HttpUnauthorized: ["", "\t", "-", ".", "@", "<", ">", "Digest",
257                               "basic %s" % "foobar".encode("base64").strip()],
258       http.HttpBadRequest: ["Basic"],
259       }
260
261     for exc, headers in checks.items():
262       for i in headers:
263         headers = { http.HTTP_AUTHORIZATION: i, }
264         req = http.server._HttpServerRequest("GET", "/", headers, None)
265         ra = _FakeRequestAuth("area1", False, None)
266         self.assertRaises(exc, ra.PreHandleRequest, req)
267
268   def testBasicAuth(self):
269     for user in ["", "joe", "user name with spaces"]:
270       for pw in ["", "-", ":", "foobar", "Foo Bar Baz", "@@@", "###",
271                  "foo:bar:baz"]:
272         for wrong_pw in [True, False]:
273           basic_auth = "%s:%s" % (user, pw)
274           if wrong_pw:
275             basic_auth += "WRONG"
276           headers = {
277               http.HTTP_AUTHORIZATION:
278                 "Basic %s" % (basic_auth.encode("base64").strip(), ),
279             }
280           req = http.server._HttpServerRequest("GET", "/", headers, None)
281
282           ac = _SimpleAuthenticator(user, pw)
283           self.assertFalse(ac.called)
284           ra = _FakeRequestAuth("area1", True, ac)
285           if wrong_pw:
286             try:
287               ra.PreHandleRequest(req)
288             except http.HttpUnauthorized, err:
289               www_auth = err.headers[http.HTTP_WWW_AUTHENTICATE]
290               self.assert_(www_auth.startswith(http.auth.HTTP_BASIC_AUTH))
291             else:
292               self.fail("Didn't raise HttpUnauthorized")
293           else:
294             ra.PreHandleRequest(req)
295           self.assert_(ac.called)
296
297
class TestReadPasswordFile(unittest.TestCase):
  """Tests for http.auth.ParsePasswordFile"""

  def testSimple(self):
    users = http.auth.ParsePasswordFile("user1 password")
    self.assertEqual(len(users), 1)

    user1 = users["user1"]
    self.assertEqual(user1.password, "password")
    self.assertEqual(len(user1.options), 0)

  def testOptions(self):
    # Comments, blank lines and the malformed last line must not produce
    # entries
    lines = [
      "# Passwords",
      "user1 password",
      "",
      "# Comment",
      "user2 pw write,read",
      "   \t# Another comment",
      "invalidline",
      ]
    buf = StringIO()
    for line in lines:
      buf.write(line + "\n")

    users = http.auth.ParsePasswordFile(buf.getvalue())
    self.assertEqual(len(users), 2)

    user1 = users["user1"]
    self.assertEqual(user1.password, "password")
    self.assertEqual(len(user1.options), 0)

    # Comma-separated options after the password become a list
    user2 = users["user2"]
    self.assertEqual(user2.password, "pw")
    self.assertEqual(user2.options, ["write", "read"])
322
323
class TestClientRequest(unittest.TestCase):
  """Tests for constructing http.client.HttpClientRequest objects"""

  def testRepr(self):
    req = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
                                        headers=[], post_data="Hello World")
    self.assertTrue(repr(req).startswith("<"))

  def testNoHeaders(self):
    req = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
                                        headers=None)
    # Missing headers are represented by an empty list
    self.assertTrue(isinstance(req.headers, list))
    self.assertEqual(req.headers, [])
    self.assertEqual(req.url, "https://localhost:1234/version")

  def testPlainAddressIPv4(self):
    req = http.client.HttpClientRequest("192.0.2.9", 19956, "GET",
                                        "/version")
    self.assertEqual(req.url, "https://192.0.2.9:19956/version")

  def testPlainAddressIPv6(self):
    # IPv6 literals are wrapped in brackets inside the URL
    req = http.client.HttpClientRequest("2001:db8::cafe", 15110, "GET",
                                        "/info")
    self.assertEqual(req.url, "https://[2001:db8::cafe]:15110/info")

  def testOldStyleHeaders(self):
    # A header dict is turned into a list of "Name: value" strings
    headers = {
      "Content-type": "text/plain",
      "Accept": "text/html",
      }
    req = http.client.HttpClientRequest("localhost", 16481, "GET", "/vg_list",
                                        headers=headers)
    self.assertTrue(isinstance(req.headers, list))
    expected = ["Accept: text/html", "Content-type: text/plain"]
    self.assertEqual(sorted(req.headers), expected)
    self.assertEqual(req.url, "https://localhost:16481/vg_list")

  def testNewStyleHeaders(self):
    # A header list is kept with the same entries
    headers = [
      "Accept: text/html",
      "Content-type: text/plain; charset=ascii",
      "Server: httpd 1.0",
      ]
    req = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
                                        headers=headers)
    self.assertTrue(isinstance(req.headers, list))
    self.assertEqual(sorted(req.headers), sorted(headers))
    self.assertEqual(req.url, "https://localhost:1234/version")

  def testPostData(self):
    req = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
                                        post_data="Hello World")
    self.assertEqual(req.post_data, "Hello World")

  def testNoPostData(self):
    # Omitted POST data ends up as an empty string
    req = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
    self.assertEqual(req.post_data, "")

  def testCompletionCallback(self):
    for argname in ["completion_cb", "curl_config_fn"]:
      # A callable is stored unchanged
      req = http.client.HttpClientRequest("localhost", 14038, "GET",
                                          "/version",
                                          **{argname: NotImplementedError})
      self.assertEqual(getattr(req, argname), NotImplementedError)

      # Non-callables are refused
      for value in [NotImplemented, {}, 1]:
        self.assertRaises(AssertionError, http.client.HttpClientRequest,
                          "localhost", 23150, "GET", "/version",
                          **{argname: value})
395
396
397 class _FakeCurl:
398   def __init__(self):
399     self.opts = {}
400     self.info = NotImplemented
401
402   def setopt(self, opt, value):
403     assert opt not in self.opts, "Option set more than once"
404     self.opts[opt] = value
405
406   def getinfo(self, info):
407     return self.info.pop(info)
408
409
class TestClientStartRequest(unittest.TestCase):
  """Tests for http.client._StartRequest using a fake curl handle"""

  @staticmethod
  def _TestCurlConfig(curl):
    # Per-request configuration hook setting one easily checked option
    curl.setopt(pycurl.SSLKEYTYPE, "PEM")

  def test(self):
    methods = [http.HTTP_GET, http.HTTP_PUT, "CUSTOM"]
    ports = [8761, 29796, 19528]
    config_fns = [None, self._TestCurlConfig]
    timeouts = [None, 0, 1, 123, 36000]

    for method in methods:
      for port in ports:
        for curl_config_fn in config_fns:
          for read_timeout in timeouts:
            self._TestInner(method, port, curl_config_fn, read_timeout)

  def _TestInner(self, method, port, curl_config_fn, read_timeout):
    response_codes = [http.HTTP_OK, http.HttpNotFound.code,
                      http.HTTP_NOT_MODIFIED]
    response_bodies = [None, "Hello World",
                       "Very Long\tContent here\n" * 171]

    for response_code in response_codes:
      for response_body in response_bodies:
        for errmsg in [None, "error"]:
          req = http.client.HttpClientRequest("localhost", port, method,
                                              "/version",
                                              curl_config_fn=curl_config_fn,
                                              read_timeout=read_timeout)
          curl = _FakeCurl()
          pending = http.client._StartRequest(curl, req)
          self.assertEqual(pending.GetCurlHandle(), curl)
          self.assertEqual(pending.GetCurrentRequest(), req)

          # Every option is popped once checked; none may remain afterwards
          opts = curl.opts
          self.assertEqual(opts.pop(pycurl.CUSTOMREQUEST), method)
          self.assertEqual(opts.pop(pycurl.URL),
                           "https://localhost:%s/version" % port)

          # No read timeout is passed to curl as 0
          if read_timeout is None:
            expected_timeout = 0
          else:
            expected_timeout = read_timeout
          self.assertEqual(opts.pop(pycurl.TIMEOUT), expected_timeout)

          self.assertFalse(opts.pop(pycurl.VERBOSE))
          self.assertTrue(opts.pop(pycurl.NOSIGNAL))
          self.assertEqual(opts.pop(pycurl.USERAGENT),
                           http.HTTP_GANETI_VERSION)
          self.assertEqual(opts.pop(pycurl.PROXY), "")
          self.assertFalse(opts.pop(pycurl.POSTFIELDS))
          self.assertFalse(opts.pop(pycurl.HTTPHEADER))

          write_fn = opts.pop(pycurl.WRITEFUNCTION)
          self.assertTrue(callable(write_fn))

          # Not every pycurl build exposes this constant
          if hasattr(pycurl, "SSL_SESSIONID_CACHE"):
            self.assertFalse(opts.pop(pycurl.SSL_SESSIONID_CACHE))

          if curl_config_fn:
            self.assertEqual(opts.pop(pycurl.SSLKEYTYPE), "PEM")
          else:
            self.assertFalse(pycurl.SSLKEYTYPE in opts)
          self.assertFalse(opts)

          # Deliver the response body in chunks of ten characters
          if response_body is not None:
            for start in range(0, len(response_body), 10):
              write_fn(response_body[start:start + 10])

          curl.info = {
            pycurl.RESPONSE_CODE: response_code,
            }

          # Finalize request; all queued info must have been consumed
          pending.Done(errmsg)
          self.assertFalse(curl.info)

          # Finalizing a second time is refused
          self.assertRaises(AssertionError, pending.Done, True)

          if errmsg:
            self.assertFalse(req.success)
          else:
            self.assertTrue(req.success)
          self.assertEqual(req.error, errmsg)
          self.assertEqual(req.resp_status_code, response_code)

          if response_body is None:
            expected_body = ""
          else:
            expected_body = response_body
          self.assertEqual(req.resp_body, expected_body)

          # _FakeCurl has no reset(); after Done only POSTFIELDS and
          # WRITEFUNCTION may have been set again
          assert not hasattr(curl, "reset")
          opts = curl.opts
          self.assertFalse(opts.pop(pycurl.POSTFIELDS))
          self.assertTrue(callable(opts.pop(pycurl.WRITEFUNCTION)))
          self.assertFalse(opts)

          self.assertFalse(curl.opts,
                           msg="Previous checks did not consume all options")
          assert opts is curl.opts

  def _TestWrongTypes(self, *args, **kwargs):
    # Constructing the request succeeds; _StartRequest rejects the bad type
    req = http.client.HttpClientRequest(*args, **kwargs)
    self.assertRaises(AssertionError, http.client._StartRequest,
                      _FakeCurl(), req)

  def testWrongHostType(self):
    self._TestWrongTypes(unicode("localhost"), 8080, "GET", "/version")

  def testWrongUrlType(self):
    self._TestWrongTypes("localhost", 8080, "GET", unicode("/version"))

  def testWrongMethodType(self):
    self._TestWrongTypes("localhost", 8080, unicode("GET"), "/version")

  def testWrongHeaderType(self):
    self._TestWrongTypes("localhost", 8080, "GET", "/version",
                         headers={ unicode("foo"): "bar", })

  def testWrongPostDataType(self):
    self._TestWrongTypes("localhost", 8080, "GET", "/version",
                         post_data=unicode("verylongdata" * 100))
527
528
class _EmptyCurlMulti:
  """Fake curl multi handle that never has any transfer in progress"""

  def perform(self):
    # (status, number of active handles)
    status = pycurl.E_MULTI_OK
    return (status, 0)

  def info_read(self):
    # (remaining messages, successful handles, failed handles)
    (remaining, succeeded, failed) = (0, [], [])
    return (remaining, succeeded, failed)
535
536
class TestClientProcessRequests(unittest.TestCase):
  """ProcessRequests called with an empty request list"""

  def testEmpty(self):
    reqs = []
    http.client.ProcessRequests(reqs, _curl=NotImplemented,
                                _curl_multi=_EmptyCurlMulti)
    # The list must still be empty afterwards
    self.assertEqual(reqs, [])
543
544
class TestProcessCurlRequests(unittest.TestCase):
  """Drives http.client._ProcessCurlRequests with a scripted multi handle"""

  class _FakeCurlMulti:
    """Fake curl multi handle enforcing the order of calls.

    C{_expect} is a queue of method names: each method pops its own name
    from the front and appends the name of the call that must come next.
    A shared counter decides retries, completions and failing handles.

    """
    def __init__(self):
      self.handles = []
      self.will_fail = []
      self._expect = ["perform"]
      self._counter = itertools.count()

    def add_handle(self, curl):
      assert curl not in self.handles
      self.handles.append(curl)
      # Depending on the shared counter, this handle is going to fail
      should_fail = (self._counter.next() % 3 == 0)
      if should_fail:
        self.will_fail.append(curl)

    def remove_handle(self, curl):
      self.handles.remove(curl)

    def perform(self):
      assert self._expect.pop(0) == "perform"

      retry = (self._counter.next() % 2 == 0)
      if retry:
        # Ask the caller to call perform() again right away
        self._expect.append("perform")
        return (pycurl.E_CALL_MULTI_PERFORM, None)

      self._expect.append("info_read")
      return (pycurl.E_MULTI_OK, len(self.handles))

    def info_read(self):
      assert self._expect.pop(0) == "info_read"

      successful = []
      failed = []

      if not self.handles:
        self._expect.append("select")
        return (0, successful, failed)

      if self._counter.next() % 17 == 0:
        # Report the first registered handle as finished
        curl = self.handles[0]
        if curl in self.will_fail:
          failed.append((curl, -1, "test error"))
        else:
          successful.append(curl)

      remaining_messages = len(self.handles) % 3
      if remaining_messages > 0:
        self._expect.append("info_read")
      else:
        self._expect.append("select")

      return (remaining_messages, successful, failed)

    def select(self, timeout):
      # Never compare floats for equality
      assert 0.95 <= timeout <= 1.05
      assert self._expect.pop(0) == "select"
      self._expect.append("perform")

  def test(self):
    handles = [_FakeCurl() for _ in range(10)]
    multi = self._FakeCurlMulti()

    for (curl, errmsg) in http.client._ProcessCurlRequests(multi, handles):
      # Finished handles must no longer be registered
      self.assertTrue(curl not in multi.handles)
      if curl in multi.will_fail:
        self.assertTrue("test error" in errmsg)
      else:
        self.assertTrue(errmsg is None)

    # Everything was processed; the next expected call is select()
    self.assertFalse(multi.handles)
    self.assertEqual(multi._expect, ["select"])
611
612
class TestProcessRequests(unittest.TestCase):
  """Tests for http.client.ProcessRequests with injected curl fakes"""

  class _DummyCurlMulti:
    pass

  def testNoMonitor(self):
    self._Test(False)

  def testWithMonitor(self):
    self._Test(True)

  class _MonitorChecker:
    """Lock monitor callback remembering the monitor it was handed"""

    def __init__(self):
      self._monitor = None

    def GetMonitor(self):
      return self._monitor

    def __call__(self, monitor):
      assert callable(monitor.GetLockInfo)
      self._monitor = monitor

  def _Test(self, use_monitor):
    def cfg_fn(port, curl):
      # Tag each handle with its port so it can be matched to its request
      curl.opts["__port__"] = port

    def _LockCheckReset(monitor, req):
      # Completion callback; must run with the monitor lock held exclusively
      self.assertTrue(monitor._lock.is_owned(shared=0),
                      msg="Lock must be owned in exclusive mode")
      assert not hasattr(req, "lockcheck__")
      setattr(req, "lockcheck__", True)

    def _BuildNiceName(port, default=None):
      # Ports divisible by five get a custom name, others use the default
      if port % 5 == 0:
        return "nicename%s" % port
      return default

    requests = []
    for port in range(15176, 15501):
      req = http.client.HttpClientRequest(
        "localhost", port, "POST", "/version%s" % port,
        curl_config_fn=compat.partial(cfg_fn, port),
        completion_cb=NotImplementedError,
        nicename=_BuildNiceName(port))
      requests.append(req)
    requests_count = len(requests)

    if use_monitor:
      lock_monitor_cb = self._MonitorChecker()
    else:
      lock_monitor_cb = None

    def _ProcessRequests(multi, handles):
      # Replacement for the curl processing loop: yields (handle, error
      # message) pairs without doing any network I/O
      self.assertTrue(isinstance(multi, self._DummyCurlMulti))
      self.assertEqual(len(requests), len(handles))
      self.assertTrue(compat.all(isinstance(curl, _FakeCurl)
                                 for curl in handles))

      # Prepare for lock check
      for req in requests:
        assert req.completion_cb is NotImplementedError
        if use_monitor:
          req.completion_cb = \
            compat.partial(_LockCheckReset, lock_monitor_cb.GetMonitor())

      for (idx, curl) in enumerate(handles):
        try:
          port = curl.opts["__port__"]
        except KeyError:
          self.fail("Per-request config function was not called")

        if use_monitor:
          # Lock information must list exactly the handles not yet yielded
          lock_info = lock_monitor_cb.GetMonitor().GetLockInfo(None)
          owner = [threading.currentThread().getName()]
          expected = []
          for handle in handles[idx:]:
            handle_port = handle.opts["__port__"]
            default_name = "localhost/version%s" % handle_port
            name = "rpc/%s" % _BuildNiceName(handle_port,
                                              default=default_name)
            expected.append((name, None, owner, None))
          self.assertEqual(sorted(lock_info), sorted(expected))

        if port % 3 == 0:
          response_code = http.HTTP_OK
          msg = None
        else:
          response_code = http.HttpNotFound.code
          msg = "test error"

        curl.info = {
          pycurl.RESPONSE_CODE: response_code,
          }

        # Drop the options set again after the request finishes;
        # _FakeCurl refuses to set an option twice
        self.assertFalse(curl.opts.pop(pycurl.POSTFIELDS))
        self.assertTrue(callable(curl.opts.pop(pycurl.WRITEFUNCTION)))

        yield (curl, msg)

      if use_monitor:
        self.assertTrue(compat.all(req.lockcheck__ for req in requests))

    if use_monitor:
      self.assertEqual(lock_monitor_cb.GetMonitor(), None)

    http.client.ProcessRequests(requests, lock_monitor_cb=lock_monitor_cb,
                                _curl=_FakeCurl,
                                _curl_multi=self._DummyCurlMulti,
                                _curl_process=_ProcessRequests)

    for req in requests:
      if req.port % 3 == 0:
        self.assertTrue(req.success)
        self.assertEqual(req.error, None)
      else:
        self.assertFalse(req.success)
        self.assertTrue("test error" in req.error)

    # See if monitor was disabled
    if use_monitor:
      monitor = lock_monitor_cb.GetMonitor()
      self.assertEqual(monitor._pending_fn, None)
      self.assertEqual(monitor.GetLockInfo(None), [])
    else:
      self.assertEqual(lock_monitor_cb, None)

    self.assertEqual(len(requests), requests_count)

  def testBadRequest(self):
    # A request that already carries a result must not be processed again
    bad_request = http.client.HttpClientRequest("localhost", 27784,
                                                "POST", "/version")
    bad_request.success = False

    self.assertRaises(AssertionError, http.client.ProcessRequests,
                      [bad_request], _curl=NotImplemented,
                      _curl_multi=NotImplemented, _curl_process=NotImplemented)
748
749
if __name__ == "__main__":
  # Hand control to Ganeti's unittest program wrapper
  testutils.GanetiTestProgram()