Revision 33231500

b/Makefile.am
389 389
	test/ganeti.rapi.client_unittest.py \
390 390
	test/ganeti.rapi.resources_unittest.py \
391 391
	test/ganeti.rapi.rlib2_unittest.py \
392
	test/ganeti.rpc_unittest.py \
392 393
	test/ganeti.serializer_unittest.py \
393 394
	test/ganeti.ssh_unittest.py \
394 395
	test/ganeti.uidpool_unittest.py \
b/daemons/ganeti-masterd
513 513
    sys.exit(constants.EXIT_FAILURE)
514 514

  
515 515

  
516
def ExecMasterd (options, args): # pylint: disable-msg=W0613
516
def ExecMasterd(options, args): # pylint: disable-msg=W0613
517 517
  """Main master daemon function, executed with the PID file held.
518 518

  
519 519
  """
b/lib/http/client.py
1 1
#
2 2
#
3 3

  
4
# Copyright (C) 2007, 2008 Google Inc.
4
# Copyright (C) 2007, 2008, 2010 Google Inc.
5 5
#
6 6
# This program is free software; you can redistribute it and/or modify
7 7
# it under the terms of the GNU General Public License as published by
......
22 22

  
23 23
"""
24 24

  
25
# pylint: disable-msg=E1103
25
import logging
26
import pycurl
27
from cStringIO import StringIO
26 28

  
27
# # E1103: %s %r has no %r member (but some types could not be
28
# inferred), since _socketobject could be ssl or not and pylint
29
# doesn't parse that
30

  
31

  
32
import os
33
import select
34
import socket
35
import errno
36
import threading
37

  
38
from ganeti import workerpool
39 29
from ganeti import http
40
from ganeti import utils
41

  
42

  
43
HTTP_CLIENT_THREADS = 10
30
from ganeti import compat
44 31

  
45 32

  
46 33
class HttpClientRequest(object):
47 34
  def __init__(self, host, port, method, path, headers=None, post_data=None,
48
               ssl_params=None, ssl_verify_peer=False, read_timeout=None):
35
               read_timeout=None, curl_config_fn=None):
49 36
    """Describes an HTTP request.
50 37

  
51 38
    @type host: string
......
56 43
    @param method: Method name
57 44
    @type path: string
58 45
    @param path: Request path
59
    @type headers: dict or None
60
    @param headers: Additional headers to send
46
    @type headers: list or None
47
    @param headers: Additional headers to send, list of strings
61 48
    @type post_data: string or None
62 49
    @param post_data: Additional data to send
63
    @type ssl_params: HttpSslParams
64
    @param ssl_params: SSL key and certificate
65
    @type ssl_verify_peer: bool
66
    @param ssl_verify_peer: Whether to compare our certificate with
67
        server's certificate
68 50
    @type read_timeout: int
69 51
    @param read_timeout: if passed, it will be used as the read
70 52
        timeout while reading the response from the server
53
    @type curl_config_fn: callable
54
    @param curl_config_fn: Function to configure cURL object before request
55
                           (Note: if the function configures the connection in
56
                           a way where it wouldn't be efficient to reuse them,
57
                           a "identity" property should be defined, see
58
                           L{HttpClientRequest.identity})
71 59

  
72 60
    """
73
    if post_data is not None:
74
      assert method.upper() in (http.HTTP_POST, http.HTTP_PUT), \
75
        "Only POST and GET requests support sending data"
76

  
77 61
    assert path.startswith("/"), "Path must start with slash (/)"
62
    assert curl_config_fn is None or callable(curl_config_fn)
78 63

  
79 64
    # Request attributes
80 65
    self.host = host
81 66
    self.port = port
82
    self.ssl_params = ssl_params
83
    self.ssl_verify_peer = ssl_verify_peer
84 67
    self.method = method
85 68
    self.path = path
86
    self.headers = headers
87
    self.post_data = post_data
88 69
    self.read_timeout = read_timeout
70
    self.curl_config_fn = curl_config_fn
89 71

  
72
    if post_data is None:
73
      self.post_data = ""
74
    else:
75
      self.post_data = post_data
76

  
77
    if headers is None:
78
      self.headers = []
79
    elif isinstance(headers, dict):
80
      # Support for old interface
81
      self.headers = ["%s: %s" % (name, value)
82
                      for name, value in headers.items()]
83
    else:
84
      self.headers = headers
85

  
86
    # Response status
90 87
    self.success = None
91 88
    self.error = None
92 89

  
93
    # Raw response
94
    self.response = None
95

  
96 90
    # Response attributes
97
    self.resp_version = None
98 91
    self.resp_status_code = None
99
    self.resp_reason = None
100
    self.resp_headers = None
101 92
    self.resp_body = None
102 93

  
103 94
  def __repr__(self):
......
108 99

  
109 100
    return "<%s at %#x>" % (" ".join(status), id(self))
110 101

  
102
  @property
103
  def url(self):
104
    """Returns the full URL for this requests.
111 105

  
112
class _HttpClientToServerMessageWriter(http.HttpMessageWriter):
113
  pass
114

  
106
    """
107
    # TODO: Support for non-SSL requests
108
    return "https://%s:%s%s" % (self.host, self.port, self.path)
115 109

  
116
class _HttpServerToClientMessageReader(http.HttpMessageReader):
117
  # Length limits
118
  START_LINE_LENGTH_MAX = 512
119
  HEADER_LENGTH_MAX = 4096
110
  @property
111
  def identity(self):
112
    """Returns identifier for retrieving a pooled connection for this request.
120 113

  
121
  def ParseStartLine(self, start_line):
122
    """Parses the status line sent by the server.
114
    This allows cURL client objects to be re-used and to cache information
115
    (e.g. SSL session IDs or connections).
123 116

  
124 117
    """
125
    # Empty lines are skipped when reading
126
    assert start_line
118
    parts = [self.host, self.port]
127 119

  
128
    try:
129
      [version, status, reason] = start_line.split(None, 2)
130
    except ValueError:
120
    if self.curl_config_fn:
131 121
      try:
132
        [version, status] = start_line.split(None, 1)
133
        reason = ""
134
      except ValueError:
135
        version = http.HTTP_0_9
122
        parts.append(self.curl_config_fn.identity)
123
      except AttributeError:
124
        pass
136 125

  
137
    if version:
138
      version = version.upper()
126
    return "/".join(str(i) for i in parts)
139 127

  
140
    # The status code is a three-digit number
141
    try:
142
      status = int(status)
143
      if status < 100 or status > 999:
144
        status = -1
145
    except (TypeError, ValueError):
146
      status = -1
147 128

  
148
    if status == -1:
149
      raise http.HttpError("Invalid status code (%r)" % start_line)
129
class _HttpClient(object):
130
  def __init__(self, curl_config_fn):
131
    """Initializes this class.
150 132

  
151
    return http.HttpServerToClientStartLine(version, status, reason)
133
    @type curl_config_fn: callable
134
    @param curl_config_fn: Function to configure cURL object after
135
                           initialization
152 136

  
137
    """
138
    self._req = None
153 139

  
154
class HttpClientRequestExecutor(http.HttpBase):
155
  # Default headers
156
  DEFAULT_HEADERS = {
157
    http.HTTP_USER_AGENT: http.HTTP_GANETI_VERSION,
158
    # TODO: For keep-alive, don't send "Connection: close"
159
    http.HTTP_CONNECTION: "close",
160
    }
140
    curl = self._CreateCurlHandle()
141
    curl.setopt(pycurl.VERBOSE, False)
142
    curl.setopt(pycurl.NOSIGNAL, True)
143
    curl.setopt(pycurl.USERAGENT, http.HTTP_GANETI_VERSION)
144
    curl.setopt(pycurl.PROXY, "")
161 145

  
162
  # Timeouts in seconds for socket layer
163
  # TODO: Soft timeout instead of only socket timeout?
164
  # TODO: Make read timeout configurable per OpCode?
165
  CONNECT_TIMEOUT = 5
166
  WRITE_TIMEOUT = 10
167
  READ_TIMEOUT = None
168
  CLOSE_TIMEOUT = 1
146
    # Pass cURL object to external config function
147
    if curl_config_fn:
148
      curl_config_fn(curl)
169 149

  
170
  def __init__(self, req):
171
    """Initializes the HttpClientRequestExecutor class.
150
    self._curl = curl
172 151

  
173
    @type req: HttpClientRequest
174
    @param req: Request object
152
  @staticmethod
153
  def _CreateCurlHandle():
154
    """Returns a new cURL object.
175 155

  
176 156
    """
177
    http.HttpBase.__init__(self)
178
    self.request = req
179

  
180
    try:
181
      # TODO: Implement connection caching/keep-alive
182
      self.sock = self._CreateSocket(req.ssl_params,
183
                                     req.ssl_verify_peer)
157
    return pycurl.Curl()
184 158

  
185
      # Disable Python's timeout
186
      self.sock.settimeout(None)
159
  def GetCurlHandle(self):
160
    """Returns the cURL object.
187 161

  
188
      # Operate in non-blocking mode
189
      self.sock.setblocking(0)
162
    """
163
    return self._curl
190 164

  
191
      response_msg_reader = None
192
      response_msg = None
193
      force_close = True
165
  def GetCurrentRequest(self):
166
    """Returns the current request.
194 167

  
195
      self._Connect()
196
      try:
197
        self._SendRequest()
198
        (response_msg_reader, response_msg) = self._ReadResponse()
168
    @rtype: L{HttpClientRequest} or None
199 169

  
200
        # Only wait for server to close if we didn't have any exception.
201
        force_close = False
202
      finally:
203
        # TODO: Keep-alive is not supported, always close connection
204
        force_close = True
205
        http.ShutdownConnection(self.sock, self.CLOSE_TIMEOUT,
206
                                self.WRITE_TIMEOUT, response_msg_reader,
207
                                force_close)
170
    """
171
    return self._req
208 172

  
209
      self.sock.close()
210
      self.sock = None
173
  def StartRequest(self, req):
174
    """Starts a request on this client.
211 175

  
212
      req.response = response_msg
176
    @type req: L{HttpClientRequest}
177
    @param req: HTTP request
213 178

  
214
      req.resp_version = req.response.start_line.version
215
      req.resp_status_code = req.response.start_line.code
216
      req.resp_reason = req.response.start_line.reason
217
      req.resp_headers = req.response.headers
218
      req.resp_body = req.response.body
179
    """
180
    assert not self._req, "Another request is already started"
181

  
182
    self._req = req
183
    self._resp_buffer = StringIO()
184

  
185
    url = req.url
186
    method = req.method
187
    post_data = req.post_data
188
    headers = req.headers
189

  
190
    # PycURL requires strings to be non-unicode
191
    assert isinstance(method, str)
192
    assert isinstance(url, str)
193
    assert isinstance(post_data, str)
194
    assert compat.all(isinstance(i, str) for i in headers)
195

  
196
    # Configure cURL object for request
197
    curl = self._curl
198
    curl.setopt(pycurl.CUSTOMREQUEST, str(method))
199
    curl.setopt(pycurl.URL, url)
200
    curl.setopt(pycurl.POSTFIELDS, post_data)
201
    curl.setopt(pycurl.WRITEFUNCTION, self._resp_buffer.write)
202
    curl.setopt(pycurl.HTTPHEADER, headers)
203

  
204
    if req.read_timeout is None:
205
      curl.setopt(pycurl.TIMEOUT, 0)
206
    else:
207
      curl.setopt(pycurl.TIMEOUT, int(req.read_timeout))
219 208

  
220
      req.success = True
221
      req.error = None
209
    # Pass cURL object to external config function
210
    if req.curl_config_fn:
211
      req.curl_config_fn(curl)
222 212

  
223
    except http.HttpError, err:
224
      req.success = False
225
      req.error = str(err)
213
  def Done(self, errmsg):
214
    """Finishes a request.
226 215

  
227
  def _Connect(self):
228
    """Non-blocking connect to host with timeout.
216
    @type errmsg: string or None
217
    @param errmsg: Error message if request failed
229 218

  
230 219
    """
231
    connected = False
232
    while True:
233
      try:
234
        connect_error = self.sock.connect_ex((self.request.host,
235
                                              self.request.port))
236
      except socket.gaierror, err:
237
        raise http.HttpError("Connection failed: %s" % str(err))
220
    req = self._req
221
    assert req, "No request"
238 222

  
239
      if connect_error == errno.EINTR:
240
        # Mask signals
241
        pass
223
    logging.debug("Request %s finished, errmsg=%s", req, errmsg)
242 224

  
243
      elif connect_error == 0:
244
        # Connection established
245
        connected = True
246
        break
225
    curl = self._curl
247 226

  
248
      elif connect_error == errno.EINPROGRESS:
249
        # Connection started
250
        break
227
    req.success = not bool(errmsg)
228
    req.error = errmsg
251 229

  
252
      raise http.HttpError("Connection failed (%s: %s)" %
253
                             (connect_error, os.strerror(connect_error)))
230
    # Get HTTP response code
231
    req.resp_status_code = curl.getinfo(pycurl.RESPONSE_CODE)
232
    req.resp_body = self._resp_buffer.getvalue()
254 233

  
255
    if not connected:
256
      # Wait for connection
257
      event = utils.WaitForFdCondition(self.sock, select.POLLOUT,
258
                                       self.CONNECT_TIMEOUT)
259
      if event is None:
260
        raise http.HttpError("Timeout while connecting to server")
234
    # Reset client object
235
    self._req = None
236
    self._resp_buffer = None
261 237

  
262
      # Get error code
263
      connect_error = self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
264
      if connect_error != 0:
265
        raise http.HttpError("Connection failed (%s: %s)" %
266
                               (connect_error, os.strerror(connect_error)))
238
    # Ensure no potentially large variables are referenced
239
    curl.setopt(pycurl.POSTFIELDS, "")
240
    curl.setopt(pycurl.WRITEFUNCTION, lambda _: None)
267 241

  
268
    # Enable TCP keep-alive
269
    self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
270 242

  
271
    # If needed, Linux specific options are available to change the TCP
272
    # keep-alive settings, see "man 7 tcp" for TCP_KEEPCNT, TCP_KEEPIDLE and
273
    # TCP_KEEPINTVL.
243
class _PooledHttpClient:
244
  """Data structure for HTTP client pool.
274 245

  
275
    # Do the secret SSL handshake
276
    if self.using_ssl:
277
      self.sock.set_connect_state() # pylint: disable-msg=E1103
278
      try:
279
        http.Handshake(self.sock, self.WRITE_TIMEOUT)
280
      except http.HttpSessionHandshakeUnexpectedEOF:
281
        raise http.HttpError("Server closed connection during SSL handshake")
246
  """
247
  def __init__(self, identity, client):
248
    """Initializes this class.
282 249

  
283
  def _SendRequest(self):
284
    """Sends request to server.
250
    @type identity: string
251
    @param identity: Client identifier for pool
252
    @type client: L{_HttpClient}
253
    @param client: HTTP client
285 254

  
286 255
    """
287
    # Headers
288
    send_headers = self.DEFAULT_HEADERS.copy()
256
    self.identity = identity
257
    self.client = client
258
    self.lastused = 0
289 259

  
290
    if self.request.headers:
291
      send_headers.update(self.request.headers)
260
  def __repr__(self):
261
    status = ["%s.%s" % (self.__class__.__module__, self.__class__.__name__),
262
              "id=%s" % self.identity,
263
              "lastuse=%s" % self.lastused,
264
              repr(self.client)]
292 265

  
293
    send_headers[http.HTTP_HOST] = "%s:%s" % (self.request.host,
294
                                              self.request.port)
266
    return "<%s at %#x>" % (" ".join(status), id(self))
295 267

  
296
    # Response message
297
    msg = http.HttpMessage()
298 268

  
299
    # Combine request line. We only support HTTP/1.0 (no chunked transfers and
300
    # no keep-alive).
301
    # TODO: For keep-alive, change to HTTP/1.1
302
    msg.start_line = \
303
      http.HttpClientToServerStartLine(method=self.request.method.upper(),
304
                                       path=self.request.path,
305
                                       version=http.HTTP_1_0)
306
    msg.headers = send_headers
307
    msg.body = self.request.post_data
269
class HttpClientPool:
270
  """A simple HTTP client pool.
308 271

  
309
    try:
310
      _HttpClientToServerMessageWriter(self.sock, msg, self.WRITE_TIMEOUT)
311
    except http.HttpSocketTimeout:
312
      raise http.HttpError("Timeout while sending request")
313
    except socket.error, err:
314
      raise http.HttpError("Error sending request: %s" % err)
272
  Supports one pooled connection per identity (see
273
  L{HttpClientRequest.identity}).
315 274

  
316
  def _ReadResponse(self):
317
    """Read response from server.
275
  """
276
  #: After how many generations to drop unused clients
277
  _MAX_GENERATIONS_DROP = 25
278

  
279
  def __init__(self, curl_config_fn):
280
    """Initializes this class.
281

  
282
    @type curl_config_fn: callable
283
    @param curl_config_fn: Function to configure cURL object after
284
                           initialization
318 285

  
319 286
    """
320
    response_msg = http.HttpMessage()
287
    self._curl_config_fn = curl_config_fn
288
    self._generation = 0
289
    self._pool = {}
321 290

  
322
    if self.request.read_timeout is None:
323
      timeout = self.READ_TIMEOUT
324
    else:
325
      timeout = self.request.read_timeout
291
  @staticmethod
292
  def _GetHttpClientCreator():
293
    """Returns callable to create HTTP client.
326 294

  
327
    try:
328
      response_msg_reader = \
329
        _HttpServerToClientMessageReader(self.sock, response_msg, timeout)
330
    except http.HttpSocketTimeout:
331
      raise http.HttpError("Timeout while reading response")
332
    except socket.error, err:
333
      raise http.HttpError("Error reading response: %s" % err)
295
    """
296
    return _HttpClient
334 297

  
335
    return (response_msg_reader, response_msg)
298
  def _Get(self, identity):
299
    """Gets an HTTP client from the pool.
336 300

  
301
    @type identity: string
302
    @param identity: Client identifier
337 303

  
338
class _HttpClientPendingRequest(object):
339
  """Data class for pending requests.
304
    """
305
    try:
306
      pclient  = self._pool.pop(identity)
307
    except KeyError:
308
      # Need to create new client
309
      client = self._GetHttpClientCreator()(self._curl_config_fn)
310
      pclient = _PooledHttpClient(identity, client)
311
      logging.debug("Created new client %s", pclient)
312
    else:
313
      logging.debug("Reusing client %s", pclient)
340 314

  
341
  """
342
  def __init__(self, request):
343
    self.request = request
315
    assert pclient.identity == identity
344 316

  
345
    # Thread synchronization
346
    self.done = threading.Event()
317
    return pclient
347 318

  
348
  def __repr__(self):
349
    status = ["%s.%s" % (self.__class__.__module__, self.__class__.__name__),
350
              "req=%r" % self.request]
319
  def _StartRequest(self, req):
320
    """Starts a request.
351 321

  
352
    return "<%s at %#x>" % (" ".join(status), id(self))
322
    @type req: L{HttpClientRequest}
323
    @param req: HTTP request
353 324

  
325
    """
326
    logging.debug("Starting request %r", req)
327
    pclient = self._Get(req.identity)
354 328

  
355
class HttpClientWorker(workerpool.BaseWorker):
356
  """HTTP client worker class.
329
    assert req.identity not in self._pool
357 330

  
358
  """
359
  def RunTask(self, pend_req): # pylint: disable-msg=W0221
360
    try:
361
      HttpClientRequestExecutor(pend_req.request)
362
    finally:
363
      pend_req.done.set()
331
    pclient.client.StartRequest(req)
332
    pclient.lastused = self._generation
333

  
334
    return pclient
364 335

  
336
  def _Return(self, pclients):
337
    """Returns HTTP clients to the pool.
365 338

  
366
class HttpClientWorkerPool(workerpool.WorkerPool):
367
  def __init__(self, manager):
368
    workerpool.WorkerPool.__init__(self, "HttpClient",
369
                                   HTTP_CLIENT_THREADS,
370
                                   HttpClientWorker)
371
    self.manager = manager
339
    """
340
    for pc in pclients:
341
      logging.debug("Returning client %s to pool", pc)
342
      assert pc.identity not in self._pool
343
      assert pc not in self._pool.values()
344
      self._pool[pc.identity] = pc
345

  
346
    # Check for unused clients
347
    for pc in self._pool.values():
348
      if (pc.lastused + self._MAX_GENERATIONS_DROP) < self._generation:
349
        logging.debug("Removing client %s which hasn't been used"
350
                      " for %s generations",
351
                      pc, self._MAX_GENERATIONS_DROP)
352
        self._pool.pop(pc.identity, None)
353

  
354
    assert compat.all(pc.lastused >= (self._generation -
355
                                      self._MAX_GENERATIONS_DROP)
356
                      for pc in self._pool.values())
357

  
358
  @staticmethod
359
  def _CreateCurlMultiHandle():
360
    """Creates new cURL multi handle.
372 361

  
362
    """
363
    return pycurl.CurlMulti()
373 364

  
374
class HttpClientManager(object):
375
  """Manages HTTP requests.
365
  def ProcessRequests(self, requests):
366
    """Processes any number of HTTP client requests using pooled objects.
376 367

  
377
  """
378
  def __init__(self):
379
    self._wpool = HttpClientWorkerPool(self)
368
    @type requests: list of L{HttpClientRequest}
369
    @param requests: List of all requests
380 370

  
381
  def __del__(self):
382
    self.Shutdown()
371
    """
372
    multi = self._CreateCurlMultiHandle()
383 373

  
384
  def ExecRequests(self, requests):
385
    """Execute HTTP requests.
374
    # For client cleanup
375
    self._generation += 1
386 376

  
387
    This function can be called from multiple threads at the same time.
377
    assert compat.all((req.error is None and
378
                       req.success is None and
379
                       req.resp_status_code is None and
380
                       req.resp_body is None)
381
                      for req in requests)
388 382

  
389
    @type requests: List of HttpClientRequest instances
390
    @param requests: The requests to execute
391
    @rtype: List of HttpClientRequest instances
392
    @return: The list of requests passed in
383
    curl_to_pclient = {}
384
    for req in requests:
385
      pclient = self._StartRequest(req)
386
      curl = pclient.client.GetCurlHandle()
387
      curl_to_pclient[curl] = pclient
388
      multi.add_handle(curl)
389
      assert pclient.client.GetCurrentRequest() == req
390
      assert pclient.lastused >= 0
393 391

  
394
    """
395
    # _HttpClientPendingRequest is used for internal thread synchronization
396
    pending = [_HttpClientPendingRequest(req) for req in requests]
392
    assert len(curl_to_pclient) == len(requests)
397 393

  
398
    try:
399
      # Add requests to queue
400
      for pend_req in pending:
401
        self._wpool.AddTask(pend_req)
394
    done_count = 0
395
    while True:
396
      (ret, _) = multi.perform()
397
      assert ret in (pycurl.E_MULTI_OK, pycurl.E_CALL_MULTI_PERFORM)
398

  
399
      if ret == pycurl.E_CALL_MULTI_PERFORM:
400
        # cURL wants to be called again
401
        continue
402

  
403
      while True:
404
        (remaining_messages, successful, failed) = multi.info_read()
405

  
406
        for curl in successful:
407
          multi.remove_handle(curl)
408
          done_count += 1
409
          pclient = curl_to_pclient[curl]
410
          req = pclient.client.GetCurrentRequest()
411
          pclient.client.Done(None)
412
          assert req.success
413
          assert not pclient.client.GetCurrentRequest()
414

  
415
        for curl, errnum, errmsg in failed:
416
          multi.remove_handle(curl)
417
          done_count += 1
418
          pclient = curl_to_pclient[curl]
419
          req = pclient.client.GetCurrentRequest()
420
          pclient.client.Done("Error %s: %s" % (errnum, errmsg))
421
          assert req.error
422
          assert not pclient.client.GetCurrentRequest()
423

  
424
        if remaining_messages == 0:
425
          break
426

  
427
      assert done_count <= len(requests)
428

  
429
      if done_count == len(requests):
430
        break
402 431

  
403
    finally:
404
      # In case of an exception we should still wait for the rest, otherwise
405
      # another thread from the worker pool could modify the request object
406
      # after we returned.
432
      # Wait for I/O. The I/O timeout shouldn't be too long so that HTTP
433
      # timeouts, which are only evaluated in multi.perform, aren't
434
      # unnecessarily delayed.
435
      multi.select(1.0)
407 436

  
408
      # And wait for them to finish
409
      for pend_req in pending:
410
        pend_req.done.wait()
437
    assert compat.all(pclient.client.GetCurrentRequest() is None
438
                      for pclient in curl_to_pclient.values())
411 439

  
412
    # Return original list
413
    return requests
440
    # Return clients to pool
441
    self._Return(curl_to_pclient.values())
414 442

  
415
  def Shutdown(self):
416
    self._wpool.Quiesce()
417
    self._wpool.TerminateWorkers()
443
    assert done_count == len(requests)
444
    assert compat.all(req.error is not None or
445
                      (req.success and
446
                       req.resp_status_code is not None and
447
                       req.resp_body is not None)
448
                      for req in requests)
b/lib/rpc.py
34 34
import logging
35 35
import zlib
36 36
import base64
37
import pycurl
38
import threading
37 39

  
38 40
from ganeti import utils
39 41
from ganeti import objects
......
47 49
import ganeti.http.client  # pylint: disable-msg=W0611
48 50

  
49 51

  
50
# Module level variable
51
_http_manager = None
52
# Timeout for connecting to nodes (seconds)
53
_RPC_CONNECT_TIMEOUT = 5
54

  
55
_RPC_CLIENT_HEADERS = [
56
  "Content-type: %s" % http.HTTP_APP_JSON,
57
  ]
52 58

  
53 59
# Various time constants for the timeout table
54 60
_TMO_URGENT = 60 # one minute
......
72 78
def Init():
73 79
  """Initializes the module-global HTTP client manager.
74 80

  
75
  Must be called before using any RPC function.
81
  Must be called before using any RPC function and while exactly one thread is
82
  running.
76 83

  
77 84
  """
78
  global _http_manager # pylint: disable-msg=W0603
79

  
80
  assert not _http_manager, "RPC module initialized more than once"
85
  # curl_global_init(3) and curl_global_cleanup(3) must be called with only
86
  # one thread running. This check is just a safety measure -- it doesn't
87
  # cover all cases.
88
  assert threading.activeCount() == 1, \
89
         "Found more than one active thread when initializing pycURL"
81 90

  
82
  http.InitSsl()
91
  logging.info("Using PycURL %s", pycurl.version)
83 92

  
84
  _http_manager = http.client.HttpClientManager()
93
  pycurl.global_init(pycurl.GLOBAL_ALL)
85 94

  
86 95

  
87 96
def Shutdown():
88 97
  """Stops the module-global HTTP client manager.
89 98

  
90
  Must be called before quitting the program.
99
  Must be called before quitting the program and while exactly one thread is
100
  running.
91 101

  
92 102
  """
93
  global _http_manager # pylint: disable-msg=W0603
103
  pycurl.global_cleanup()
104

  
105

  
106
def _ConfigRpcCurl(curl):
107
  noded_cert = str(constants.NODED_CERT_FILE)
94 108

  
95
  if _http_manager:
96
    _http_manager.Shutdown()
97
    _http_manager = None
109
  curl.setopt(pycurl.FOLLOWLOCATION, False)
110
  curl.setopt(pycurl.CAINFO, noded_cert)
111
  curl.setopt(pycurl.SSL_VERIFYHOST, 0)
112
  curl.setopt(pycurl.SSL_VERIFYPEER, True)
113
  curl.setopt(pycurl.SSLCERTTYPE, "PEM")
114
  curl.setopt(pycurl.SSLCERT, noded_cert)
115
  curl.setopt(pycurl.SSLKEYTYPE, "PEM")
116
  curl.setopt(pycurl.SSLKEY, noded_cert)
117
  curl.setopt(pycurl.CONNECTTIMEOUT, _RPC_CONNECT_TIMEOUT)
118

  
119

  
120
class _RpcThreadLocal(threading.local):
121
  def GetHttpClientPool(self):
122
    """Returns a per-thread HTTP client pool.
123

  
124
    @rtype: L{http.client.HttpClientPool}
125

  
126
    """
127
    try:
128
      pool = self.hcp
129
    except AttributeError:
130
      pool = http.client.HttpClientPool(_ConfigRpcCurl)
131
      self.hcp = pool
132

  
133
    return pool
134

  
135

  
136
_thread_local = _RpcThreadLocal()
98 137

  
99 138

  
100 139
def _RpcTimeout(secs):
......
218 257
    self.procedure = procedure
219 258
    self.body = body
220 259
    self.port = port
221
    self.nc = {}
222

  
223
    self._ssl_params = \
224
      http.HttpSslParams(ssl_key_path=constants.NODED_CERT_FILE,
225
                         ssl_cert_path=constants.NODED_CERT_FILE)
260
    self._request = {}
226 261

  
227 262
  def ConnectList(self, node_list, address_list=None, read_timeout=None):
228 263
    """Add a list of nodes to the target nodes.
......
260 295
    if read_timeout is None:
261 296
      read_timeout = _TIMEOUTS[self.procedure]
262 297

  
263
    self.nc[name] = \
264
      http.client.HttpClientRequest(address, self.port, http.HTTP_PUT,
265
                                    "/%s" % self.procedure,
266
                                    post_data=self.body,
267
                                    ssl_params=self._ssl_params,
268
                                    ssl_verify_peer=True,
298
    self._request[name] = \
299
      http.client.HttpClientRequest(str(address), self.port,
300
                                    http.HTTP_PUT, str("/%s" % self.procedure),
301
                                    headers=_RPC_CLIENT_HEADERS,
302
                                    post_data=str(self.body),
269 303
                                    read_timeout=read_timeout)
270 304

  
271
  def GetResults(self):
305
  def GetResults(self, http_pool=None):
272 306
    """Call nodes and return results.
273 307

  
274 308
    @rtype: list
275 309
    @return: List of RPC results
276 310

  
277 311
    """
278
    assert _http_manager, "RPC module not initialized"
312
    if not http_pool:
313
      http_pool = _thread_local.GetHttpClientPool()
279 314

  
280
    _http_manager.ExecRequests(self.nc.values())
315
    http_pool.ProcessRequests(self._request.values())
281 316

  
282 317
    results = {}
283 318

  
284
    for name, req in self.nc.iteritems():
319
    for name, req in self._request.iteritems():
285 320
      if req.success and req.resp_status_code == http.HTTP_OK:
286 321
        results[name] = RpcResult(data=serializer.LoadJson(req.resp_body),
287 322
                                  node=name, call=self.procedure)
b/test/ganeti.http_unittest.py
87 87
    self.assert_(message_reader_class.START_LINE_LENGTH_MAX > 0)
88 88
    self.assert_(message_reader_class.HEADER_LENGTH_MAX > 0)
89 89

  
90
  def testClientSizeLimits(self):
91
    """Test HTTP client size limits"""
92
    message_reader_class = http.client._HttpServerToClientMessageReader
93
    self.assert_(message_reader_class.START_LINE_LENGTH_MAX > 0)
94
    self.assert_(message_reader_class.HEADER_LENGTH_MAX > 0)
95

  
96 90
  def testFormatAuthHeader(self):
97 91
    self.assertEqual(http.auth._FormatAuthHeader("Basic", {}),
98 92
                     "Basic")
......
330 324
    self.assertEqual(users["user2"].options, ["write", "read"])
331 325

  
332 326

  
327
class TestClientRequest(unittest.TestCase):
328
  def testRepr(self):
329
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
330
                                       headers=[], post_data="Hello World")
331
    self.assert_(repr(cr).startswith("<"))
332

  
333
  def testNoHeaders(self):
334
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
335
                                       headers=None)
336
    self.assert_(isinstance(cr.headers, list))
337
    self.assertEqual(cr.headers, [])
338
    self.assertEqual(cr.url, "https://localhost:1234/version")
339

  
340
  def testOldStyleHeaders(self):
341
    headers = {
342
      "Content-type": "text/plain",
343
      "Accept": "text/html",
344
      }
345
    cr = http.client.HttpClientRequest("localhost", 16481, "GET", "/vg_list",
346
                                       headers=headers)
347
    self.assert_(isinstance(cr.headers, list))
348
    self.assertEqual(sorted(cr.headers), [
349
      "Accept: text/html",
350
      "Content-type: text/plain",
351
      ])
352
    self.assertEqual(cr.url, "https://localhost:16481/vg_list")
353

  
354
  def testNewStyleHeaders(self):
355
    headers = [
356
      "Accept: text/html",
357
      "Content-type: text/plain; charset=ascii",
358
      "Server: httpd 1.0",
359
      ]
360
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
361
                                       headers=headers)
362
    self.assert_(isinstance(cr.headers, list))
363
    self.assertEqual(sorted(cr.headers), sorted(headers))
364
    self.assertEqual(cr.url, "https://localhost:1234/version")
365

  
366
  def testPostData(self):
367
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version",
368
                                       post_data="Hello World")
369
    self.assertEqual(cr.post_data, "Hello World")
370

  
371
  def testNoPostData(self):
372
    cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
373
    self.assertEqual(cr.post_data, "")
374

  
375
  def testIdentity(self):
376
    # These should all use different connections, hence also have a different
377
    # identity
378
    cr1 = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
379
    cr2 = http.client.HttpClientRequest("localhost", 9999, "GET", "/version")
380
    cr3 = http.client.HttpClientRequest("node1", 1234, "GET", "/version")
381
    cr4 = http.client.HttpClientRequest("node1", 9999, "GET", "/version")
382

  
383
    self.assertEqual(len(set([cr1.identity, cr2.identity,
384
                              cr3.identity, cr4.identity])), 4)
385

  
386
    # But this one should have the same
387
    cr1vglist = http.client.HttpClientRequest("localhost", 1234,
388
                                              "GET", "/vg_list")
389
    self.assertEqual(cr1.identity, cr1vglist.identity)
390

  
391

  
392
class TestClient(unittest.TestCase):
393
  def test(self):
394
    pool = http.client.HttpClientPool(None)
395
    self.assertFalse(pool._pool)
396

  
397

  
333 398
if __name__ == '__main__':
334 399
  testutils.GanetiTestProgram()
b/test/ganeti.rpc_unittest.py
1
#!/usr/bin/python
2
#
3

  
4
# Copyright (C) 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

  
21

  
22
"""Script for testing ganeti.rpc"""
23

  
24
import os
25
import sys
26
import unittest
27

  
28
from ganeti import constants
29
from ganeti import compat
30
from ganeti import rpc
31
from ganeti import http
32
from ganeti import errors
33
from ganeti import serializer
34

  
35
import testutils
36

  
37

  
38
class TestTimeouts(unittest.TestCase):
39
  def test(self):
40
    names = [name[len("call_"):] for name in dir(rpc.RpcRunner)
41
             if name.startswith("call_")]
42
    self.assertEqual(len(names), len(rpc._TIMEOUTS))
43
    self.assertFalse([name for name in names
44
                      if not (rpc._TIMEOUTS[name] is None or
45
                              rpc._TIMEOUTS[name] > 0)])
46

  
47

  
48
class FakeHttpPool:
49
  def __init__(self, response_fn):
50
    self._response_fn = response_fn
51
    self.reqcount = 0
52

  
53
  def ProcessRequests(self, reqs):
54
    for req in reqs:
55
      self.reqcount += 1
56
      self._response_fn(req)
57

  
58

  
59
class TestClient(unittest.TestCase):
60
  def _GetVersionResponse(self, req):
61
    self.assertEqual(req.host, "localhost")
62
    self.assertEqual(req.port, 24094)
63
    self.assertEqual(req.path, "/version")
64
    req.success = True
65
    req.resp_status_code = http.HTTP_OK
66
    req.resp_body = serializer.DumpJson((True, 123))
67

  
68
  def testVersionSuccess(self):
69
    client = rpc.Client("version", None, 24094)
70
    client.ConnectNode("localhost")
71
    pool = FakeHttpPool(self._GetVersionResponse)
72
    result = client.GetResults(http_pool=pool)
73
    self.assertEqual(result.keys(), ["localhost"])
74
    lhresp = result["localhost"]
75
    self.assertFalse(lhresp.offline)
76
    self.assertEqual(lhresp.node, "localhost")
77
    self.assertFalse(lhresp.fail_msg)
78
    self.assertEqual(lhresp.payload, 123)
79
    self.assertEqual(lhresp.call, "version")
80
    lhresp.Raise("should not raise")
81
    self.assertEqual(pool.reqcount, 1)
82

  
83
  def _GetMultiVersionResponse(self, req):
84
    self.assert_(req.host.startswith("node"))
85
    self.assertEqual(req.port, 23245)
86
    self.assertEqual(req.path, "/version")
87
    req.success = True
88
    req.resp_status_code = http.HTTP_OK
89
    req.resp_body = serializer.DumpJson((True, 987))
90

  
91
  def testMultiVersionSuccess(self):
92
    nodes = ["node%s" % i for i in range(50)]
93
    client = rpc.Client("version", None, 23245)
94
    client.ConnectList(nodes)
95

  
96
    pool = FakeHttpPool(self._GetMultiVersionResponse)
97
    result = client.GetResults(http_pool=pool)
98
    self.assertEqual(sorted(result.keys()), sorted(nodes))
99

  
100
    for name in nodes:
101
      lhresp = result[name]
102
      self.assertFalse(lhresp.offline)
103
      self.assertEqual(lhresp.node, name)
104
      self.assertFalse(lhresp.fail_msg)
105
      self.assertEqual(lhresp.payload, 987)
106
      self.assertEqual(lhresp.call, "version")
107
      lhresp.Raise("should not raise")
108

  
109
    self.assertEqual(pool.reqcount, len(nodes))
110

  
111
  def _GetVersionResponseFail(self, req):
112
    self.assertEqual(req.path, "/version")
113
    req.success = True
114
    req.resp_status_code = http.HTTP_OK
115
    req.resp_body = serializer.DumpJson((False, "Unknown error"))
116

  
117
  def testVersionFailure(self):
118
    client = rpc.Client("version", None, 5903)
119
    client.ConnectNode("aef9ur4i.example.com")
120
    pool = FakeHttpPool(self._GetVersionResponseFail)
121
    result = client.GetResults(http_pool=pool)
122
    self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
123
    lhresp = result["aef9ur4i.example.com"]
124
    self.assertFalse(lhresp.offline)
125
    self.assertEqual(lhresp.node, "aef9ur4i.example.com")
126
    self.assert_(lhresp.fail_msg)
127
    self.assertFalse(lhresp.payload)
128
    self.assertEqual(lhresp.call, "version")
129
    self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
130
    self.assertEqual(pool.reqcount, 1)
131

  
132
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
133
    self.assertEqual(req.path, "/vg_list")
134
    self.assertEqual(req.port, 15165)
135

  
136
    if req.host in httperrnodes:
137
      req.success = False
138
      req.error = "Node set up for HTTP errors"
139

  
140
    elif req.host in failnodes:
141
      req.success = True
142
      req.resp_status_code = 404
143
      req.resp_body = serializer.DumpJson({
144
        "code": 404,
145
        "message": "Method not found",
146
        "explain": "Explanation goes here",
147
        })
148
    else:
149
      req.success = True
150
      req.resp_status_code = http.HTTP_OK
151
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
152

  
153
  def testHttpError(self):
154
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
155

  
156
    httperrnodes = set(nodes[1::7])
157
    self.assertEqual(len(httperrnodes), 7)
158

  
159
    failnodes = set(nodes[2::3]) - httperrnodes
160
    self.assertEqual(len(failnodes), 14)
161

  
162
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
163

  
164
    client = rpc.Client("vg_list", None, 15165)
165
    client.ConnectList(nodes)
166

  
167
    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
168
                                       httperrnodes, failnodes))
169
    result = client.GetResults(http_pool=pool)
170
    self.assertEqual(sorted(result.keys()), sorted(nodes))
171

  
172
    for name in nodes:
173
      lhresp = result[name]
174
      self.assertFalse(lhresp.offline)
175
      self.assertEqual(lhresp.node, name)
176
      self.assertEqual(lhresp.call, "vg_list")
177

  
178
      if name in httperrnodes:
179
        self.assert_(lhresp.fail_msg)
180
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
181
      elif name in failnodes:
182
        self.assert_(lhresp.fail_msg)
183
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
184
                          prereq=True, ecode=errors.ECODE_INVAL)
185
      else:
186
        self.assertFalse(lhresp.fail_msg)
187
        self.assertEqual(lhresp.payload, hash(name))
188
        lhresp.Raise("should not raise")
189

  
190
    self.assertEqual(pool.reqcount, len(nodes))
191

  
192
  def _GetInvalidResponseA(self, req):
193
    self.assertEqual(req.path, "/version")
194
    req.success = True
195
    req.resp_status_code = http.HTTP_OK
196
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
197
                                         "response", "!", 1, 2, 3))
198

  
199
  def _GetInvalidResponseB(self, req):
200
    self.assertEqual(req.path, "/version")
201
    req.success = True
202
    req.resp_status_code = http.HTTP_OK
203
    req.resp_body = serializer.DumpJson("invalid response")
204

  
205
  def testInvalidResponse(self):
206
    client = rpc.Client("version", None, 19978)
207
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
208
      client.ConnectNode("oqo7lanhly.example.com")
209
      pool = FakeHttpPool(fn)
210
      result = client.GetResults(http_pool=pool)
211
      self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
212
      lhresp = result["oqo7lanhly.example.com"]
213
      self.assertFalse(lhresp.offline)
214
      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
215
      self.assert_(lhresp.fail_msg)
216
      self.assertFalse(lhresp.payload)
217
      self.assertEqual(lhresp.call, "version")
218
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
219
      self.assertEqual(pool.reqcount, 1)
220

  
221

  
222
if __name__ == "__main__":
223
  testutils.GanetiTestProgram()

Also available in: Unified diff