Revision abbf2cd9

b/lib/http/client.py
55 55
        timeout while reading the response from the server
56 56
    @type curl_config_fn: callable
57 57
    @param curl_config_fn: Function to configure cURL object before request
58
                           (Note: if the function configures the connection in
59
                           a way where it wouldn't be efficient to reuse them,
60
                           an "identity" property should be defined, see
61
                           L{HttpClientRequest.identity})
62 58
    @type nicename: string
63 59
    @param nicename: Name, presentable to a user, to describe this request (no
64 60
                     whitespace)
......
118 114
    # TODO: Support for non-SSL requests
119 115
    return "https://%s%s" % (address, self.path)
120 116

  
121
  @property
122
  def identity(self):
123
    """Returns identifier for retrieving a pooled connection for this request.
124 117

  
125
    This allows cURL client objects to be re-used and to cache information
126
    (e.g. SSL session IDs or connections).
118
def _StartRequest(curl, req):
119
  """Starts a request on a cURL object.
127 120

  
128
    """
129
    parts = [self.host, self.port]
121
  @type curl: pycurl.Curl
122
  @param curl: cURL object
123
  @type req: L{HttpClientRequest}
124
  @param req: HTTP request
130 125

  
131
    if self.curl_config_fn:
132
      try:
133
        parts.append(self.curl_config_fn.identity)
134
      except AttributeError:
135
        pass
126
  """
127
  logging.debug("Starting request %r", req)
136 128

  
137
    return "/".join(str(i) for i in parts)
129
  url = req.url
130
  method = req.method
131
  post_data = req.post_data
132
  headers = req.headers
138 133

  
134
  # PycURL requires strings to be non-unicode
135
  assert isinstance(method, str)
136
  assert isinstance(url, str)
137
  assert isinstance(post_data, str)
138
  assert compat.all(isinstance(i, str) for i in headers)
139 139

  
140
class _HttpClient(object):
141
  def __init__(self, curl_config_fn):
142
    """Initializes this class.
140
  # Buffer for response
141
  resp_buffer = StringIO()
143 142

  
144
    @type curl_config_fn: callable
145
    @param curl_config_fn: Function to configure cURL object after
146
                           initialization
143
  # Configure client for request
144
  curl.setopt(pycurl.VERBOSE, False)
145
  curl.setopt(pycurl.NOSIGNAL, True)
146
  curl.setopt(pycurl.USERAGENT, http.HTTP_GANETI_VERSION)
147
  curl.setopt(pycurl.PROXY, "")
148
  curl.setopt(pycurl.CUSTOMREQUEST, str(method))
149
  curl.setopt(pycurl.URL, url)
150
  curl.setopt(pycurl.POSTFIELDS, post_data)
151
  curl.setopt(pycurl.HTTPHEADER, headers)
147 152

  
148
    """
149
    self._req = None
153
  if req.read_timeout is None:
154
    curl.setopt(pycurl.TIMEOUT, 0)
155
  else:
156
    curl.setopt(pycurl.TIMEOUT, int(req.read_timeout))
150 157

  
151
    curl = self._CreateCurlHandle()
152
    curl.setopt(pycurl.VERBOSE, False)
153
    curl.setopt(pycurl.NOSIGNAL, True)
154
    curl.setopt(pycurl.USERAGENT, http.HTTP_GANETI_VERSION)
155
    curl.setopt(pycurl.PROXY, "")
158
  # Disable SSL session ID caching (pycurl >= 7.16.0)
159
  if hasattr(pycurl, "SSL_SESSIONID_CACHE"):
160
    curl.setopt(pycurl.SSL_SESSIONID_CACHE, False)
156 161

  
157
    # Disable SSL session ID caching (pycurl >= 7.16.0)
158
    if hasattr(pycurl, "SSL_SESSIONID_CACHE"):
159
      curl.setopt(pycurl.SSL_SESSIONID_CACHE, False)
162
  curl.setopt(pycurl.WRITEFUNCTION, resp_buffer.write)
160 163

  
161
    # Pass cURL object to external config function
162
    if curl_config_fn:
163
      curl_config_fn(curl)
164
  # Pass cURL object to external config function
165
  if req.curl_config_fn:
166
    req.curl_config_fn(curl)
164 167

  
165
    self._curl = curl
168
  return _PendingRequest(curl, req, resp_buffer.getvalue)
166 169

  
167
  @staticmethod
168
  def _CreateCurlHandle():
169
    """Returns a new cURL object.
170

  
171
class _PendingRequest:
172
  def __init__(self, curl, req, resp_buffer_read):
173
    """Initializes this class.
174

  
175
    @type curl: pycurl.Curl
176
    @param curl: cURL object
177
    @type req: L{HttpClientRequest}
178
    @param req: HTTP request
179
    @type resp_buffer_read: callable
180
    @param resp_buffer_read: Function to read response body
170 181

  
171 182
    """
172
    return pycurl.Curl()
183
    assert req.success is None
184

  
185
    self._curl = curl
186
    self._req = req
187
    self._resp_buffer_read = resp_buffer_read
173 188

  
174 189
  def GetCurlHandle(self):
175 190
    """Returns the cURL object.
......
180 195
  def GetCurrentRequest(self):
181 196
    """Returns the current request.
182 197

  
183
    @rtype: L{HttpClientRequest} or None
184

  
185 198
    """
186 199
    return self._req
187 200

  
188
  def StartRequest(self, req):
189
    """Starts a request on this client.
190

  
191
    @type req: L{HttpClientRequest}
192
    @param req: HTTP request
193

  
194
    """
195
    assert not self._req, "Another request is already started"
196

  
197
    logging.debug("Starting request %r", req)
198

  
199
    self._req = req
200
    self._resp_buffer = StringIO()
201

  
202
    url = req.url
203
    method = req.method
204
    post_data = req.post_data
205
    headers = req.headers
206

  
207
    # PycURL requires strings to be non-unicode
208
    assert isinstance(method, str)
209
    assert isinstance(url, str)
210
    assert isinstance(post_data, str)
211
    assert compat.all(isinstance(i, str) for i in headers)
212

  
213
    # Configure cURL object for request
214
    curl = self._curl
215
    curl.setopt(pycurl.CUSTOMREQUEST, str(method))
216
    curl.setopt(pycurl.URL, url)
217
    curl.setopt(pycurl.POSTFIELDS, post_data)
218
    curl.setopt(pycurl.WRITEFUNCTION, self._resp_buffer.write)
219
    curl.setopt(pycurl.HTTPHEADER, headers)
220

  
221
    if req.read_timeout is None:
222
      curl.setopt(pycurl.TIMEOUT, 0)
223
    else:
224
      curl.setopt(pycurl.TIMEOUT, int(req.read_timeout))
225

  
226
    # Pass cURL object to external config function
227
    if req.curl_config_fn:
228
      req.curl_config_fn(curl)
229

  
230 201
  def Done(self, errmsg):
231 202
    """Finishes a request.
232 203

  
......
234 205
    @param errmsg: Error message if request failed
235 206

  
236 207
    """
208
    curl = self._curl
237 209
    req = self._req
238
    assert req, "No request"
239 210

  
240
    logging.debug("Request %s finished, errmsg=%s", req, errmsg)
211
    assert req.success is None, "Request has already been finalized"
241 212

  
242
    curl = self._curl
213
    logging.debug("Request %s finished, errmsg=%s", req, errmsg)
243 214

  
244 215
    req.success = not bool(errmsg)
245 216
    req.error = errmsg
246 217

  
247 218
    # Get HTTP response code
248 219
    req.resp_status_code = curl.getinfo(pycurl.RESPONSE_CODE)
249
    req.resp_body = self._resp_buffer.getvalue()
250

  
251
    # Reset client object
252
    self._req = None
253
    self._resp_buffer = None
220
    req.resp_body = self._resp_buffer_read()
254 221

  
255 222
    # Ensure no potentially large variables are referenced
256
    curl.setopt(pycurl.POSTFIELDS, "")
257
    curl.setopt(pycurl.WRITEFUNCTION, lambda _: None)
258

  
259

  
260
class _PooledHttpClient:
261
  """Data structure for HTTP client pool.
262

  
263
  """
264
  def __init__(self, identity, client):
265
    """Initializes this class.
266

  
267
    @type identity: string
268
    @param identity: Client identifier for pool
269
    @type client: L{_HttpClient}
270
    @param client: HTTP client
271

  
272
    """
273
    self.identity = identity
274
    self.client = client
275
    self.lastused = 0
276

  
277
  def __repr__(self):
278
    status = ["%s.%s" % (self.__class__.__module__, self.__class__.__name__),
279
              "id=%s" % self.identity,
280
              "lastuse=%s" % self.lastused,
281
              repr(self.client)]
282

  
283
    return "<%s at %#x>" % (" ".join(status), id(self))
284

  
285

  
286
class HttpClientPool:
287
  """A simple HTTP client pool.
288

  
289
  Supports one pooled connection per identity (see
290
  L{HttpClientRequest.identity}).
291

  
292
  """
293
  #: After how many generations to drop unused clients
294
  _MAX_GENERATIONS_DROP = 25
295

  
296
  def __init__(self, curl_config_fn):
297
    """Initializes this class.
298

  
299
    @type curl_config_fn: callable
300
    @param curl_config_fn: Function to configure cURL object after
301
                           initialization
302

  
303
    """
304
    self._curl_config_fn = curl_config_fn
305
    self._generation = 0
306
    self._pool = {}
307

  
308
    # Create custom logger for HTTP client pool. Change logging level to
309
    # C{logging.NOTSET} to get more details.
310
    self._logger = logging.getLogger(self.__class__.__name__)
311
    self._logger.setLevel(logging.INFO)
312

  
313
  @staticmethod
314
  def _GetHttpClientCreator():
315
    """Returns callable to create HTTP client.
316

  
317
    """
318
    return _HttpClient
319

  
320
  def _Get(self, identity):
321
    """Gets an HTTP client from the pool.
322

  
323
    @type identity: string
324
    @param identity: Client identifier
325

  
326
    """
327 223
    try:
328
      pclient = self._pool.pop(identity)
329
    except KeyError:
330
      # Need to create new client
331
      client = self._GetHttpClientCreator()(self._curl_config_fn)
332
      pclient = _PooledHttpClient(identity, client)
333
      self._logger.debug("Created new client %s", pclient)
224
      # Only available in PycURL 7.19.0 and above
225
      reset_fn = curl.reset
226
    except AttributeError:
227
      curl.setopt(pycurl.POSTFIELDS, "")
228
      curl.setopt(pycurl.WRITEFUNCTION, lambda _: None)
334 229
    else:
335
      self._logger.debug("Reusing client %s", pclient)
336

  
337
    assert pclient.identity == identity
338

  
339
    return pclient
340

  
341
  def _StartRequest(self, req):
342
    """Starts a request.
343

  
344
    @type req: L{HttpClientRequest}
345
    @param req: HTTP request
346

  
347
    """
348
    pclient = self._Get(req.identity)
349

  
350
    assert req.identity not in self._pool
351

  
352
    pclient.client.StartRequest(req)
353
    pclient.lastused = self._generation
354

  
355
    return pclient
356

  
357
  def _Return(self, pclients):
358
    """Returns HTTP clients to the pool.
359

  
360
    """
361
    assert not frozenset(pclients) & frozenset(self._pool.values())
362

  
363
    for pc in pclients:
364
      self._logger.debug("Returning client %s to pool", pc)
365
      assert pc.identity not in self._pool
366
      self._pool[pc.identity] = pc
367

  
368
    # Check for unused clients
369
    for pc in self._pool.values():
370
      if (pc.lastused + self._MAX_GENERATIONS_DROP) < self._generation:
371
        self._logger.debug("Removing client %s which hasn't been used"
372
                           " for %s generations",
373
                           pc, self._MAX_GENERATIONS_DROP)
374
        self._pool.pop(pc.identity, None)
375

  
376
    assert compat.all(pc.lastused >= (self._generation -
377
                                      self._MAX_GENERATIONS_DROP)
378
                      for pc in self._pool.values())
379

  
380
  @staticmethod
381
  def _CreateCurlMultiHandle():
382
    """Creates new cURL multi handle.
383

  
384
    """
385
    return pycurl.CurlMulti()
386

  
387
  def ProcessRequests(self, requests, lock_monitor_cb=None):
388
    """Processes any number of HTTP client requests using pooled objects.
389

  
390
    @type requests: list of L{HttpClientRequest}
391
    @param requests: List of all requests
392
    @param lock_monitor_cb: Callable for registering with lock monitor
393

  
394
    """
395
    # For client cleanup
396
    self._generation += 1
397

  
398
    assert compat.all((req.error is None and
399
                       req.success is None and
400
                       req.resp_status_code is None and
401
                       req.resp_body is None)
402
                      for req in requests)
403

  
404
    curl_to_pclient = {}
405
    for req in requests:
406
      pclient = self._StartRequest(req)
407
      curl_to_pclient[pclient.client.GetCurlHandle()] = pclient
408
      assert pclient.client.GetCurrentRequest() == req
409
      assert pclient.lastused >= 0
410

  
411
    assert len(curl_to_pclient) == len(requests)
412

  
413
    if lock_monitor_cb:
414
      monitor = _PendingRequestMonitor(threading.currentThread(),
415
                                       curl_to_pclient.values)
416
      lock_monitor_cb(monitor)
417
    else:
418
      monitor = _NoOpRequestMonitor
419

  
420
    # Process all requests and act based on the returned values
421
    for (curl, msg) in _ProcessCurlRequests(self._CreateCurlMultiHandle(),
422
                                            curl_to_pclient.keys()):
423
      pclient = curl_to_pclient[curl]
424
      req = pclient.client.GetCurrentRequest()
425

  
426
      monitor.acquire(shared=0)
427
      try:
428
        pclient.client.Done(msg)
429
      finally:
430
        monitor.release()
431

  
432
      assert ((msg is None and req.success and req.error is None) ^
433
              (msg is not None and not req.success and req.error == msg))
434

  
435
    assert compat.all(pclient.client.GetCurrentRequest() is None
436
                      for pclient in curl_to_pclient.values())
437

  
438
    monitor.acquire(shared=0)
439
    try:
440
      # Don't try to read information from returned clients
441
      monitor.Disable()
442

  
443
      # Return clients to pool
444
      self._Return(curl_to_pclient.values())
445
    finally:
446
      monitor.release()
447

  
448
    assert compat.all(req.error is not None or
449
                      (req.success and
450
                       req.resp_status_code is not None and
451
                       req.resp_body is not None)
452
                      for req in requests)
230
      reset_fn()
453 231

  
454 232

  
455 233
class _NoOpRequestMonitor: # pylint: disable=W0232
......
479 257
    self.acquire = self._lock.acquire
480 258
    self.release = self._lock.release
481 259

  
260
  @locking.ssynchronized(_LOCK)
482 261
  def Disable(self):
483 262
    """Disable monitor.
484 263

  
......
501 280
    if self._pending_fn:
502 281
      owner_name = self._owner.getName()
503 282

  
504
      for pclient in self._pending_fn():
505
        req = pclient.client.GetCurrentRequest()
283
      for client in self._pending_fn():
284
        req = client.GetCurrentRequest()
506 285
        if req:
507 286
          if req.nicename is None:
508 287
            name = "%s%s" % (req.host, req.path)
......
559 338
    # timeouts, which are only evaluated in multi.perform, aren't
560 339
    # unnecessarily delayed.
561 340
    multi.select(1.0)
341

  
342

  
343
def ProcessRequests(requests, lock_monitor_cb=None, _curl=pycurl.Curl,
344
                    _curl_multi=pycurl.CurlMulti,
345
                    _curl_process=_ProcessCurlRequests):
346
  """Processes any number of HTTP client requests.
347

  
348
  @type requests: list of L{HttpClientRequest}
349
  @param requests: List of all requests
350
  @param lock_monitor_cb: Callable for registering with lock monitor
351

  
352
  """
353
  assert compat.all((req.error is None and
354
                     req.success is None and
355
                     req.resp_status_code is None and
356
                     req.resp_body is None)
357
                    for req in requests)
358

  
359
  # Prepare all requests
360
  curl_to_client = \
361
    dict((client.GetCurlHandle(), client)
362
         for client in map(lambda req: _StartRequest(_curl(), req), requests))
363

  
364
  assert len(curl_to_client) == len(requests)
365

  
366
  if lock_monitor_cb:
367
    monitor = _PendingRequestMonitor(threading.currentThread(),
368
                                     curl_to_client.values)
369
    lock_monitor_cb(monitor)
370
  else:
371
    monitor = _NoOpRequestMonitor
372

  
373
  # Process all requests and act based on the returned values
374
  for (curl, msg) in _curl_process(_curl_multi(), curl_to_client.keys()):
375
    monitor.acquire(shared=0)
376
    try:
377
      curl_to_client.pop(curl).Done(msg)
378
    finally:
379
      monitor.release()
380

  
381
  assert not curl_to_client, "Not all requests were processed"
382

  
383
  # Don't try to read information anymore as all requests have been processed
384
  monitor.Disable()
385

  
386
  assert compat.all(req.error is not None or
387
                    (req.success and
388
                     req.resp_status_code is not None and
389
                     req.resp_body is not None)
390
                    for req in requests)
b/lib/rpc.py
374 374
                                        headers=_RPC_CLIENT_HEADERS,
375 375
                                        post_data=body,
376 376
                                        read_timeout=read_timeout,
377
                                        nicename="%s/%s" % (name, procedure))
377
                                        nicename="%s/%s" % (name, procedure),
378
                                        curl_config_fn=_ConfigRpcCurl)
378 379

  
379 380
    return (results, requests)
380 381

  
......
402 403

  
403 404
    return results
404 405

  
405
  def __call__(self, hosts, procedure, body, read_timeout=None, http_pool=None):
406
  def __call__(self, hosts, procedure, body, read_timeout=None,
407
               _req_process_fn=http.client.ProcessRequests):
406 408
    """Makes an RPC request to a number of nodes.
407 409

  
408 410
    @type hosts: sequence
......
417 419
    """
418 420
    assert procedure in _TIMEOUTS, "RPC call not declared in the timeouts table"
419 421

  
420
    if not http_pool:
421
      http_pool = http.client.HttpClientPool(_ConfigRpcCurl)
422

  
423 422
    if read_timeout is None:
424 423
      read_timeout = _TIMEOUTS[procedure]
425 424

  
......
427 426
      self._PrepareRequests(self._resolver(hosts), self._port, procedure,
428 427
                            str(body), read_timeout)
429 428

  
430
    http_pool.ProcessRequests(requests.values(),
431
                              lock_monitor_cb=self._lock_monitor_cb)
429
    _req_process_fn(requests.values(), lock_monitor_cb=self._lock_monitor_cb)
432 430

  
433 431
    assert not frozenset(results).intersection(requests)
434 432

  
b/test/ganeti.http_unittest.py
26 26
import unittest
27 27
import time
28 28
import tempfile
29
import pycurl
30
import itertools
31
import threading
29 32
from cStringIO import StringIO
30 33

  
31 34
from ganeti import http
35
from ganeti import compat
32 36

  
33 37
import ganeti.http.server
34 38
import ganeti.http.client
......
330 334
    self.assertEqual(cr.headers, [])
331 335
    self.assertEqual(cr.url, "https://localhost:1234/version")
332 336

  
337
  def testPlainAddressIPv4(self):
338
    cr = http.client.HttpClientRequest("192.0.2.9", 19956, "GET", "/version")
339
    self.assertEqual(cr.url, "https://192.0.2.9:19956/version")
340

  
341
  def testPlainAddressIPv6(self):
342
    cr = http.client.HttpClientRequest("2001:db8::cafe", 15110, "GET", "/info")
343
    self.assertEqual(cr.url, "https://[2001:db8::cafe]:15110/info")
344

  
333 345
  def testOldStyleHeaders(self):
334 346
    headers = {
335 347
      "Content-type": "text/plain",
......
365 377
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
366 378
    self.assertEqual(cr.post_data, "")
367 379

  
368
  def testIdentity(self):
369
    # These should all use different connections, hence also have a different
370
    # identity
371
    cr1 = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
372
    cr2 = http.client.HttpClientRequest("localhost", 9999, "GET", "/version")
373
    cr3 = http.client.HttpClientRequest("node1", 1234, "GET", "/version")
374
    cr4 = http.client.HttpClientRequest("node1", 9999, "GET", "/version")
375 380

  
376
    self.assertEqual(len(set([cr1.identity, cr2.identity,
377
                              cr3.identity, cr4.identity])), 4)
381
class _FakeCurl:
382
  def __init__(self):
383
    self.opts = {}
384
    self.info = NotImplemented
385

  
386
  def setopt(self, opt, value):
387
    assert opt not in self.opts, "Option set more than once"
388
    self.opts[opt] = value
389

  
390
  def getinfo(self, info):
391
    return self.info.pop(info)
392

  
393

  
394
class TestClientStartRequest(unittest.TestCase):
395
  @staticmethod
396
  def _TestCurlConfig(curl):
397
    curl.setopt(pycurl.SSLKEYTYPE, "PEM")
398

  
399
  def test(self):
400
    for method in [http.HTTP_GET, http.HTTP_PUT, "CUSTOM"]:
401
      for port in [8761, 29796, 19528]:
402
        for curl_config_fn in [None, self._TestCurlConfig]:
403
          for read_timeout in [None, 0, 1, 123, 36000]:
404
            self._TestInner(method, port, curl_config_fn, read_timeout)
405

  
406
  def _TestInner(self, method, port, curl_config_fn, read_timeout):
407
    for response_code in [http.HTTP_OK, http.HttpNotFound.code,
408
                          http.HTTP_NOT_MODIFIED]:
409
      for response_body in [None, "Hello World",
410
                            "Very Long\tContent here\n" * 171]:
411
        for errmsg in [None, "error"]:
412
          req = http.client.HttpClientRequest("localhost", port, method,
413
                                              "/version",
414
                                              curl_config_fn=curl_config_fn,
415
                                              read_timeout=read_timeout)
416
          curl = _FakeCurl()
417
          pending = http.client._StartRequest(curl, req)
418
          self.assertEqual(pending.GetCurlHandle(), curl)
419
          self.assertEqual(pending.GetCurrentRequest(), req)
420

  
421
          # Check options
422
          opts = curl.opts
423
          self.assertEqual(opts.pop(pycurl.CUSTOMREQUEST), method)
424
          self.assertEqual(opts.pop(pycurl.URL),
425
                           "https://localhost:%s/version" % port)
426
          if read_timeout is None:
427
            self.assertEqual(opts.pop(pycurl.TIMEOUT), 0)
428
          else:
429
            self.assertEqual(opts.pop(pycurl.TIMEOUT), read_timeout)
430
          self.assertFalse(opts.pop(pycurl.VERBOSE))
431
          self.assertTrue(opts.pop(pycurl.NOSIGNAL))
432
          self.assertEqual(opts.pop(pycurl.USERAGENT),
433
                           http.HTTP_GANETI_VERSION)
434
          self.assertEqual(opts.pop(pycurl.PROXY), "")
435
          self.assertFalse(opts.pop(pycurl.POSTFIELDS))
436
          self.assertFalse(opts.pop(pycurl.HTTPHEADER))
437
          write_fn = opts.pop(pycurl.WRITEFUNCTION)
438
          self.assertTrue(callable(write_fn))
439
          if hasattr(pycurl, "SSL_SESSIONID_CACHE"):
440
            self.assertFalse(opts.pop(pycurl.SSL_SESSIONID_CACHE))
441
          if curl_config_fn:
442
            self.assertEqual(opts.pop(pycurl.SSLKEYTYPE), "PEM")
443
          else:
444
            self.assertFalse(pycurl.SSLKEYTYPE in opts)
445
          self.assertFalse(opts)
446

  
447
          if response_body is not None:
448
            offset = 0
449
            while offset < len(response_body):
450
              piece = response_body[offset:offset + 10]
451
              write_fn(piece)
452
              offset += len(piece)
453

  
454
          curl.info = {
455
            pycurl.RESPONSE_CODE: response_code,
456
            }
457

  
458
          # Finalize request
459
          pending.Done(errmsg)
460

  
461
          self.assertFalse(curl.info)
462

  
463
          # Can only finalize once
464
          self.assertRaises(AssertionError, pending.Done, True)
465

  
466
          if errmsg:
467
            self.assertFalse(req.success)
468
          else:
469
            self.assertTrue(req.success)
470
          self.assertEqual(req.error, errmsg)
471
          self.assertEqual(req.resp_status_code, response_code)
472
          if response_body is None:
473
            self.assertEqual(req.resp_body, "")
474
          else:
475
            self.assertEqual(req.resp_body, response_body)
476

  
477
          # Check if resetting worked
478
          assert not hasattr(curl, "reset")
479
          opts = curl.opts
480
          self.assertFalse(opts.pop(pycurl.POSTFIELDS))
481
          self.assertTrue(callable(opts.pop(pycurl.WRITEFUNCTION)))
482
          self.assertFalse(opts)
483

  
484
          self.assertFalse(curl.opts,
485
                           msg="Previous checks did not consume all options")
486
          assert id(opts) == id(curl.opts)
378 487

  
379
    # But this one should have the same
380
    cr1vglist = http.client.HttpClientRequest("localhost", 1234,
381
                                              "GET", "/vg_list")
382
    self.assertEqual(cr1.identity, cr1vglist.identity)
488
  def _TestWrongTypes(self, *args, **kwargs):
489
    req = http.client.HttpClientRequest(*args, **kwargs)
490
    self.assertRaises(AssertionError, http.client._StartRequest,
491
                      _FakeCurl(), req)
383 492

  
493
  def testWrongHostType(self):
494
    self._TestWrongTypes(unicode("localhost"), 8080, "GET", "/version")
495

  
496
  def testWrongUrlType(self):
497
    self._TestWrongTypes("localhost", 8080, "GET", unicode("/version"))
498

  
499
  def testWrongMethodType(self):
500
    self._TestWrongTypes("localhost", 8080, unicode("GET"), "/version")
501

  
502
  def testWrongHeaderType(self):
503
    self._TestWrongTypes("localhost", 8080, "GET", "/version",
504
                         headers={
505
                           unicode("foo"): "bar",
506
                           })
507

  
508
  def testWrongPostDataType(self):
509
    self._TestWrongTypes("localhost", 8080, "GET", "/version",
510
                         post_data=unicode("verylongdata" * 100))
511

  
512

  
513
class _EmptyCurlMulti:
514
  def perform(self):
515
    return (pycurl.E_MULTI_OK, 0)
516

  
517
  def info_read(self):
518
    return (0, [], [])
519

  
520

  
521
class TestClientProcessRequests(unittest.TestCase):
522
  def testEmpty(self):
523
    requests = []
524
    http.client.ProcessRequests(requests, _curl=NotImplemented,
525
                                _curl_multi=_EmptyCurlMulti)
526
    self.assertEqual(requests, [])
527

  
528

  
529
class TestProcessCurlRequests(unittest.TestCase):
530
  class _FakeCurlMulti:
531
    def __init__(self):
532
      self.handles = []
533
      self.will_fail = []
534
      self._expect = ["perform"]
535
      self._counter = itertools.count()
536

  
537
    def add_handle(self, curl):
538
      assert curl not in self.handles
539
      self.handles.append(curl)
540
      if self._counter.next() % 3 == 0:
541
        self.will_fail.append(curl)
542

  
543
    def remove_handle(self, curl):
544
      self.handles.remove(curl)
545

  
546
    def perform(self):
547
      assert self._expect.pop(0) == "perform"
548

  
549
      if self._counter.next() % 2 == 0:
550
        self._expect.append("perform")
551
        return (pycurl.E_CALL_MULTI_PERFORM, None)
552

  
553
      self._expect.append("info_read")
554

  
555
      return (pycurl.E_MULTI_OK, len(self.handles))
556

  
557
    def info_read(self):
558
      assert self._expect.pop(0) == "info_read"
559
      successful = []
560
      failed = []
561
      if self.handles:
562
        if self._counter.next() % 17 == 0:
563
          curl = self.handles[0]
564
          if curl in self.will_fail:
565
            failed.append((curl, -1, "test error"))
566
          else:
567
            successful.append(curl)
568
        remaining_messages = len(self.handles) % 3
569
        if remaining_messages > 0:
570
          self._expect.append("info_read")
571
        else:
572
          self._expect.append("select")
573
      else:
574
        remaining_messages = 0
575
        self._expect.append("select")
576
      return (remaining_messages, successful, failed)
577

  
578
    def select(self, timeout):
579
      # Never compare floats for equality
580
      assert timeout >= 0.95 and timeout <= 1.05
581
      assert self._expect.pop(0) == "select"
582
      self._expect.append("perform")
384 583

  
385
class TestClient(unittest.TestCase):
386 584
  def test(self):
387
    pool = http.client.HttpClientPool(None)
388
    self.assertFalse(pool._pool)
585
    requests = [_FakeCurl() for _ in range(10)]
586
    multi = self._FakeCurlMulti()
587
    for (curl, errmsg) in http.client._ProcessCurlRequests(multi, requests):
588
      self.assertTrue(curl not in multi.handles)
589
      if curl in multi.will_fail:
590
        self.assertTrue("test error" in errmsg)
591
      else:
592
        self.assertTrue(errmsg is None)
593
    self.assertFalse(multi.handles)
594
    self.assertEqual(multi._expect, ["select"])
595

  
596

  
597
class TestProcessRequests(unittest.TestCase):
598
  class _DummyCurlMulti:
599
    pass
600

  
601
  def testNoMonitor(self):
602
    self._Test(False)
603

  
604
  def testWithMonitor(self):
605
    self._Test(True)
606

  
607
  class _MonitorChecker:
608
    def __init__(self):
609
      self._monitor = None
610

  
611
    def GetMonitor(self):
612
      return self._monitor
613

  
614
    def __call__(self, monitor):
615
      assert callable(monitor.GetLockInfo)
616
      self._monitor = monitor
617

  
618
  def _Test(self, use_monitor):
619
    def cfg_fn(port, curl):
620
      curl.opts["__port__"] = port
621

  
622
    def _LockCheckReset(monitor, curl):
623
      self.assertTrue(monitor._lock.is_owned(shared=0),
624
                      msg="Lock must be owned in exclusive mode")
625
      curl.opts["__lockcheck__"] = True
626

  
627
    requests = \
628
      [http.client.HttpClientRequest("localhost", i, "POST", "/version%s" % i,
629
                                     curl_config_fn=compat.partial(cfg_fn, i))
630
       for i in range(15176, 15501)]
631
    requests_count = len(requests)
632

  
633
    if use_monitor:
634
      lock_monitor_cb = self._MonitorChecker()
635
    else:
636
      lock_monitor_cb = None
637

  
638
    def _ProcessRequests(multi, handles):
639
      self.assertTrue(isinstance(multi, self._DummyCurlMulti))
640
      self.assertEqual(len(requests), len(handles))
641
      self.assertTrue(compat.all(isinstance(curl, _FakeCurl)
642
                                 for curl in handles))
643

  
644
      for idx, curl in enumerate(handles):
645
        port = curl.opts["__port__"]
646

  
647
        if use_monitor:
648
          # Check if lock information is correct
649
          lock_info = lock_monitor_cb.GetMonitor().GetLockInfo(None)
650
          expected = \
651
            [("rpc/localhost/version%s" % handle.opts["__port__"], None,
652
              [threading.currentThread().getName()], None)
653
             for handle in handles[idx:]]
654
          self.assertEqual(sorted(lock_info), sorted(expected))
655

  
656
        if port % 3 == 0:
657
          response_code = http.HTTP_OK
658
          msg = None
659
        else:
660
          response_code = http.HttpNotFound.code
661
          msg = "test error"
662

  
663
        curl.info = {
664
          pycurl.RESPONSE_CODE: response_code,
665
          }
666

  
667
        # Unset options which will be reset
668
        assert not hasattr(curl, "reset")
669
        if use_monitor:
670
          setattr(curl, "reset",
671
                  compat.partial(_LockCheckReset, lock_monitor_cb.GetMonitor(),
672
                                 curl))
673
        else:
674
          self.assertFalse(curl.opts.pop(pycurl.POSTFIELDS))
675
          self.assertTrue(callable(curl.opts.pop(pycurl.WRITEFUNCTION)))
676

  
677
        yield (curl, msg)
678

  
679
      if use_monitor:
680
        self.assertTrue(compat.all(curl.opts["__lockcheck__"]
681
                                   for curl in handles))
682

  
683
    http.client.ProcessRequests(requests, lock_monitor_cb=lock_monitor_cb,
684
                                _curl=_FakeCurl,
685
                                _curl_multi=self._DummyCurlMulti,
686
                                _curl_process=_ProcessRequests)
687
    for req in requests:
688
      if req.port % 3 == 0:
689
        self.assertTrue(req.success)
690
        self.assertEqual(req.error, None)
691
      else:
692
        self.assertFalse(req.success)
693
        self.assertTrue("test error" in req.error)
694

  
695
    # See if monitor was disabled
696
    if use_monitor:
697
      monitor = lock_monitor_cb.GetMonitor()
698
      self.assertEqual(monitor._pending_fn, None)
699
      self.assertEqual(monitor.GetLockInfo(None), [])
700
    else:
701
      self.assertEqual(lock_monitor_cb, None)
702

  
703
    self.assertEqual(len(requests), requests_count)
704

  
705
  def testBadRequest(self):
706
    bad_request = http.client.HttpClientRequest("localhost", 27784,
707
                                                "POST", "/version")
708
    bad_request.success = False
709

  
710
    self.assertRaises(AssertionError, http.client.ProcessRequests,
711
                      [bad_request], _curl=NotImplemented,
712
                      _curl_multi=NotImplemented, _curl_process=NotImplemented)
389 713

  
390 714

  
391 715
if __name__ == '__main__':
b/test/ganeti.rpc_unittest.py
46 46
                              rpc._TIMEOUTS[name] > 0)])
47 47

  
48 48

  
49
class FakeHttpPool:
49
class _FakeRequestProcessor:
50 50
  def __init__(self, response_fn):
51 51
    self._response_fn = response_fn
52 52
    self.reqcount = 0
53 53

  
54
  def ProcessRequests(self, reqs, lock_monitor_cb=None):
54
  def __call__(self, reqs, lock_monitor_cb=None):
55
    assert lock_monitor_cb is None or callable(lock_monitor_cb)
55 56
    for req in reqs:
56 57
      self.reqcount += 1
57 58
      self._response_fn(req)
......
80 81

  
81 82
  def testVersionSuccess(self):
82 83
    resolver = rpc._StaticResolver(["127.0.0.1"])
83
    pool = FakeHttpPool(self._GetVersionResponse)
84
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
84 85
    proc = rpc._RpcProcessor(resolver, 24094)
85
    result = proc(["localhost"], "version", None, http_pool=pool)
86
    result = proc(["localhost"], "version", None, _req_process_fn=http_proc)
86 87
    self.assertEqual(result.keys(), ["localhost"])
87 88
    lhresp = result["localhost"]
88 89
    self.assertFalse(lhresp.offline)
......
91 92
    self.assertEqual(lhresp.payload, 123)
92 93
    self.assertEqual(lhresp.call, "version")
93 94
    lhresp.Raise("should not raise")
94
    self.assertEqual(pool.reqcount, 1)
95
    self.assertEqual(http_proc.reqcount, 1)
95 96

  
96 97
  def _ReadTimeoutResponse(self, req):
97 98
    self.assertEqual(req.host, "192.0.2.13")
......
104 105

  
105 106
  def testReadTimeout(self):
106 107
    resolver = rpc._StaticResolver(["192.0.2.13"])
107
    pool = FakeHttpPool(self._ReadTimeoutResponse)
108
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
108 109
    proc = rpc._RpcProcessor(resolver, 19176)
109
    result = proc(["node31856"], "version", None, http_pool=pool,
110
    result = proc(["node31856"], "version", None, _req_process_fn=http_proc,
110 111
                  read_timeout=12356)
111 112
    self.assertEqual(result.keys(), ["node31856"])
112 113
    lhresp = result["node31856"]
......
116 117
    self.assertEqual(lhresp.payload, -1)
117 118
    self.assertEqual(lhresp.call, "version")
118 119
    lhresp.Raise("should not raise")
119
    self.assertEqual(pool.reqcount, 1)
120
    self.assertEqual(http_proc.reqcount, 1)
120 121

  
121 122
  def testOfflineNode(self):
122 123
    resolver = rpc._StaticResolver([rpc._OFFLINE])
123
    pool = FakeHttpPool(NotImplemented)
124
    http_proc = _FakeRequestProcessor(NotImplemented)
124 125
    proc = rpc._RpcProcessor(resolver, 30668)
125
    result = proc(["n17296"], "version", None, http_pool=pool)
126
    result = proc(["n17296"], "version", None, _req_process_fn=http_proc)
126 127
    self.assertEqual(result.keys(), ["n17296"])
127 128
    lhresp = result["n17296"]
128 129
    self.assertTrue(lhresp.offline)
......
137 138
    # No message
138 139
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)
139 140

  
140
    self.assertEqual(pool.reqcount, 0)
141
    self.assertEqual(http_proc.reqcount, 0)
141 142

  
142 143
  def _GetMultiVersionResponse(self, req):
143 144
    self.assert_(req.host.startswith("node"))
......
150 151
  def testMultiVersionSuccess(self):
151 152
    nodes = ["node%s" % i for i in range(50)]
152 153
    resolver = rpc._StaticResolver(nodes)
153
    pool = FakeHttpPool(self._GetMultiVersionResponse)
154
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
154 155
    proc = rpc._RpcProcessor(resolver, 23245)
155
    result = proc(nodes, "version", None, http_pool=pool)
156
    result = proc(nodes, "version", None, _req_process_fn=http_proc)
156 157
    self.assertEqual(sorted(result.keys()), sorted(nodes))
157 158

  
158 159
    for name in nodes:
......
164 165
      self.assertEqual(lhresp.call, "version")
165 166
      lhresp.Raise("should not raise")
166 167

  
167
    self.assertEqual(pool.reqcount, len(nodes))
168
    self.assertEqual(http_proc.reqcount, len(nodes))
168 169

  
169 170
  def _GetVersionResponseFail(self, errinfo, req):
170 171
    self.assertEqual(req.path, "/version")
......
176 177
    resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
177 178
    proc = rpc._RpcProcessor(resolver, 5903)
178 179
    for errinfo in [None, "Unknown error"]:
179
      pool = FakeHttpPool(compat.partial(self._GetVersionResponseFail, errinfo))
180
      result = proc(["aef9ur4i.example.com"], "version", None, http_pool=pool)
180
      http_proc = \
181
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
182
                                             errinfo))
183
      result = proc(["aef9ur4i.example.com"], "version", None,
184
                    _req_process_fn=http_proc)
181 185
      self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
182 186
      lhresp = result["aef9ur4i.example.com"]
183 187
      self.assertFalse(lhresp.offline)
......
186 190
      self.assertFalse(lhresp.payload)
187 191
      self.assertEqual(lhresp.call, "version")
188 192
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
189
      self.assertEqual(pool.reqcount, 1)
193
      self.assertEqual(http_proc.reqcount, 1)
190 194

  
191 195
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
192 196
    self.assertEqual(req.path, "/vg_list")
......
222 226
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
223 227

  
224 228
    proc = rpc._RpcProcessor(resolver, 15165)
225
    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
226
                                       httperrnodes, failnodes))
227
    result = proc(nodes, "vg_list", None, http_pool=pool)
229
    http_proc = \
230
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
231
                                           httperrnodes, failnodes))
232
    result = proc(nodes, "vg_list", None, _req_process_fn=http_proc)
228 233
    self.assertEqual(sorted(result.keys()), sorted(nodes))
229 234

  
230 235
    for name in nodes:
......
245 250
        self.assertEqual(lhresp.payload, hash(name))
246 251
        lhresp.Raise("should not raise")
247 252

  
248
    self.assertEqual(pool.reqcount, len(nodes))
253
    self.assertEqual(http_proc.reqcount, len(nodes))
249 254

  
250 255
  def _GetInvalidResponseA(self, req):
251 256
    self.assertEqual(req.path, "/version")
......
265 270
    proc = rpc._RpcProcessor(resolver, 19978)
266 271

  
267 272
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
268
      pool = FakeHttpPool(fn)
269
      result = proc(["oqo7lanhly.example.com"], "version", None, http_pool=pool)
273
      http_proc = _FakeRequestProcessor(fn)
274
      result = proc(["oqo7lanhly.example.com"], "version", None,
275
                    _req_process_fn=http_proc)
270 276
      self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
271 277
      lhresp = result["oqo7lanhly.example.com"]
272 278
      self.assertFalse(lhresp.offline)
......
275 281
      self.assertFalse(lhresp.payload)
276 282
      self.assertEqual(lhresp.call, "version")
277 283
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
278
      self.assertEqual(pool.reqcount, 1)
284
      self.assertEqual(http_proc.reqcount, 1)
279 285

  
280 286
  def _GetBodyTestResponse(self, test_data, req):
281 287
    self.assertEqual(req.host, "192.0.2.84")
......
292 298
      "xyz": range(10),
293 299
      }
294 300
    resolver = rpc._StaticResolver(["192.0.2.84"])
295
    pool = FakeHttpPool(compat.partial(self._GetBodyTestResponse, test_data))
301
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
302
                                                     test_data))
296 303
    proc = rpc._RpcProcessor(resolver, 18700)
297 304
    body = serializer.DumpJson(test_data)
298
    result = proc(["node19759"], "upload_file", body, http_pool=pool)
305
    result = proc(["node19759"], "upload_file", body, _req_process_fn=http_proc)
299 306
    self.assertEqual(result.keys(), ["node19759"])
300 307
    lhresp = result["node19759"]
301 308
    self.assertFalse(lhresp.offline)
......
304 311
    self.assertEqual(lhresp.payload, None)
305 312
    self.assertEqual(lhresp.call, "upload_file")
306 313
    lhresp.Raise("should not raise")
307
    self.assertEqual(pool.reqcount, 1)
314
    self.assertEqual(http_proc.reqcount, 1)
308 315

  
309 316

  
310 317
class TestSsconfResolver(unittest.TestCase):

Also available in: Unified diff