RAPI client: Switch to pycURL
author Michael Hanselmann <hansmi@google.com>
Thu, 1 Jul 2010 11:38:59 +0000 (13:38 +0200)
committer Michael Hanselmann <hansmi@google.com>
Thu, 1 Jul 2010 12:13:15 +0000 (14:13 +0200)
Currently the RAPI client uses the urllib2 and httplib modules from
Python's standard library. They're used with pyOpenSSL in a very fragile
way, and there are known issues when receiving large responses from a RAPI
server.

By switching to PycURL we leverage the power and stability of the
widely used curl library (libcurl). This gives us much more flexibility
than before and made timeouts easy to implement (something that would
have required a lot of work with the built-in modules).

There's one small drawback: programs using libcurl have to call
curl_global_init(3) (available as pycurl.global_init) while exactly one
thread is running (i.e. before any other threads are started) and are
supposed to call curl_global_cleanup(3) (available as
pycurl.global_cleanup) before exiting. See the manpages for details. A
decorator, rapi.client.UsesRapiClient, is provided to simplify this.

Unittests for the new code are provided, increasing the test coverage of
the RAPI client from 74% to 89%.

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Guido Trotter <ultrotter@google.com>
Reviewed-by: Iustin Pop <iustin@google.com>

INSTALL
daemons/ganeti-watcher
lib/rapi/client.py
qa/ganeti-qa.py
qa/qa_rapi.py
test/ganeti.rapi.client_unittest.py
tools/move-instance

diff --git a/INSTALL b/INSTALL
index 74355b7..2348848 100644 (file)
--- a/INSTALL
+++ b/INSTALL
@@ -29,6 +29,7 @@ Before installing, please verify that you have the following programs:
 - `simplejson Python module <http://code.google.com/p/simplejson/>`_
 - `pyparsing Python module <http://pyparsing.wikispaces.com/>`_
 - `pyinotify Python module <http://trac.dbzteam.org/pyinotify/>`_
+- `PycURL Python module <http://pycurl.sourceforge.net/>`_
 - `socat <http://www.dest-unreach.org/socat/>`_
 
 These programs are supplied as part of most Linux distributions, so
@@ -39,7 +40,8 @@ packages, except for DRBD and Xen::
 
   $ apt-get install lvm2 ssh bridge-utils iproute iputils-arping \
                     python python-pyopenssl openssl python-pyparsing \
-                    python-simplejson python-pyinotify socat
+                    python-simplejson python-pyinotify python-pycurl \
+                    socat
 
 If you want to build from source, please see doc/devnotes.rst for more
 dependencies.
index e5ce913..84a9c29 100755 (executable)
@@ -610,10 +610,9 @@ def IsRapiResponding(hostname):
   @return: Whether RAPI is working properly
 
   """
-  ssl_config = rapi.client.CertAuthorityVerify(constants.RAPI_CERT_FILE)
-  rapi_client = \
-    rapi.client.GanetiRapiClient(hostname,
-                                 config_ssl_verification=ssl_config)
+  curl_config = rapi.client.GenericCurlConfig(cafile=constants.RAPI_CERT_FILE)
+  rapi_client = rapi.client.GanetiRapiClient(hostname,
+                                             curl_config_fn=curl_config)
   try:
     master_version = rapi_client.GetVersion()
   except rapi.client.CertificateError, err:
@@ -646,6 +645,7 @@ def ParseOptions():
   return options, args
 
 
+@rapi.client.UsesRapiClient
 def main():
   """Main function.
 
index 3f8fb76..6a8c021 100644 (file)
 # 02110-1301, USA.
 
 
-"""Ganeti RAPI client."""
+"""Ganeti RAPI client.
+
+@attention: To use the RAPI client, the application B{must} call
+            C{pycurl.global_init} during initialization and
+            C{pycurl.global_cleanup} before exiting the process. This is very
+            important in multi-threaded programs. See curl_global_init(3) and
+            curl_global_cleanup(3) for details. The decorator L{UsesRapiClient}
+            can be used.
+
+"""
 
 # No Ganeti-specific modules should be imported. The RAPI client is supposed to
 # be standalone.
 
-import sys
-import httplib
-import urllib2
 import logging
 import simplejson
-import socket
 import urllib
-import OpenSSL
-import distutils.version
+import threading
+import pycurl
+
+try:
+  from cStringIO import StringIO
+except ImportError:
+  from StringIO import StringIO
 
 
 GANETI_RAPI_PORT = 5080
@@ -61,6 +71,19 @@ NODE_ROLE_REGULAR = "regular"
 _REQ_DATA_VERSION_FIELD = "__version__"
 _INST_CREATE_REQV1 = "instance-create-reqv1"
 
+# Older pycURL versions don't have all error constants
+try:
+  _CURLE_SSL_CACERT = pycurl.E_SSL_CACERT
+  _CURLE_SSL_CACERT_BADFILE = pycurl.E_SSL_CACERT_BADFILE
+except AttributeError:
+  _CURLE_SSL_CACERT = 60
+  _CURLE_SSL_CACERT_BADFILE = 77
+
+_CURL_SSL_CERT_ERRORS = frozenset([
+  _CURLE_SSL_CACERT,
+  _CURLE_SSL_CACERT_BADFILE,
+  ])
+
 
 class Error(Exception):
   """Base error class for this module.
@@ -85,239 +108,123 @@ class GanetiApiError(Error):
     self.code = code
 
 
-def FormatX509Name(x509_name):
-  """Formats an X509 name.
-
-  @type x509_name: OpenSSL.crypto.X509Name
+def UsesRapiClient(fn):
+  """Decorator for code using RAPI client to initialize pycURL.
 
   """
-  try:
-    # Only supported in pyOpenSSL 0.7 and above
-    get_components_fn = x509_name.get_components
-  except AttributeError:
-    return repr(x509_name)
-  else:
-    return "".join("/%s=%s" % (name, value)
-                   for name, value in get_components_fn())
-
-
-class CertAuthorityVerify:
-  """Certificate verificator for SSL context.
-
-  Configures SSL context to verify server's certificate.
+  def wrapper(*args, **kwargs):
+    # curl_global_init(3) and curl_global_cleanup(3) must be called with only
+    # one thread running. This check is just a safety measure -- it doesn't
+    # cover all cases.
+    assert threading.activeCount() == 1, \
+           "Found active threads when initializing pycURL"
+
+    pycurl.global_init(pycurl.GLOBAL_ALL)
+    try:
+      return fn(*args, **kwargs)
+    finally:
+      pycurl.global_cleanup()
+
+  return wrapper
+
+
+def GenericCurlConfig(verbose=False, use_signal=False,
+                      use_curl_cabundle=False, cafile=None, capath=None,
+                      proxy=None, verify_hostname=False,
+                      connect_timeout=None, timeout=None,
+                      _pycurl_version_fn=pycurl.version_info):
+  """Curl configuration function generator.
+
+  @type verbose: bool
+  @param verbose: Whether to set cURL to verbose mode
+  @type use_signal: bool
+  @param use_signal: Whether to allow cURL to use signals
+  @type use_curl_cabundle: bool
+  @param use_curl_cabundle: Whether to use cURL's default CA bundle
+  @type cafile: string
+  @param cafile: In which file we can find the certificates
+  @type capath: string
+  @param capath: In which directory we can find the certificates
+  @type proxy: string
+  @param proxy: Proxy to use, None for default behaviour and empty string for
+                disabling proxies (see curl_easy_setopt(3))
+  @type verify_hostname: bool
+  @param verify_hostname: Whether to verify the remote peer certificate's
+                          commonName
+  @type connect_timeout: number
+  @param connect_timeout: Timeout for establishing connection in seconds
+  @type timeout: number
+  @param timeout: Timeout for complete transfer in seconds (see
+                  curl_easy_setopt(3)).
 
   """
-  _CAPATH_MINVERSION = "0.9"
-  _DEFVFYPATHS_MINVERSION = "0.9"
+  if use_curl_cabundle and (cafile or capath):
+    raise Error("Can not use default CA bundle when CA file or path is set")
 
-  _PYOPENSSL_VERSION = OpenSSL.__version__
-  _PARSED_PYOPENSSL_VERSION = distutils.version.LooseVersion(_PYOPENSSL_VERSION)
-
-  _SUPPORT_CAPATH = (_PARSED_PYOPENSSL_VERSION >= _CAPATH_MINVERSION)
-  _SUPPORT_DEFVFYPATHS = (_PARSED_PYOPENSSL_VERSION >= _DEFVFYPATHS_MINVERSION)
-
-  def __init__(self, cafile=None, capath=None, use_default_verify_paths=False):
-    """Initializes this class.
+  def _ConfigCurl(curl, logger):
+    """Configures a cURL object
 
-    @type cafile: string
-    @param cafile: In which file we can find the certificates
-    @type capath: string
-    @param capath: In which directory we can find the certificates
-    @type use_default_verify_paths: bool
-    @param use_default_verify_paths: Whether the platform provided CA
-                                     certificates are to be used for
-                                     verification purposes
+    @type curl: pycurl.Curl
+    @param curl: cURL object
 
     """
-    self._cafile = cafile
-    self._capath = capath
-    self._use_default_verify_paths = use_default_verify_paths
-
-    if self._capath is not None and not self._SUPPORT_CAPATH:
-      raise Error(("PyOpenSSL %s has no support for a CA directory,"
-                   " version %s or above is required") %
-                  (self._PYOPENSSL_VERSION, self._CAPATH_MINVERSION))
-
-    if self._use_default_verify_paths and not self._SUPPORT_DEFVFYPATHS:
-      raise Error(("PyOpenSSL %s has no support for using default verification"
-                   " paths, version %s or above is required") %
-                  (self._PYOPENSSL_VERSION, self._DEFVFYPATHS_MINVERSION))
-
-  @staticmethod
-  def _VerifySslCertCb(logger, _, cert, errnum, errdepth, ok):
-    """Callback for SSL certificate verification.
-
-    @param logger: Logging object
-
-    """
-    if ok:
-      log_fn = logger.debug
+    logger.debug("Using cURL version %s", pycurl.version)
+
+    # pycurl.version_info returns a tuple with information about the used
+    # version of libcurl. Item 5 is the SSL library linked to it.
+    # e.g.: (3, '7.18.0', 463360, 'x86_64-pc-linux-gnu', 1581, 'GnuTLS/2.0.4',
+    # 0, '1.2.3.3', ...)
+    sslver = _pycurl_version_fn()[5]
+    if not sslver:
+      raise Error("No SSL support in cURL")
+
+    lcsslver = sslver.lower()
+    if lcsslver.startswith("openssl/"):
+      pass
+    elif lcsslver.startswith("gnutls/"):
+      if capath:
+        raise Error("cURL linked against GnuTLS has no support for a"
+                    " CA path (%s)" % (pycurl.version, ))
     else:
-      log_fn = logger.error
-
-    log_fn("Verifying SSL certificate at depth %s, subject '%s', issuer '%s'",
-           errdepth, FormatX509Name(cert.get_subject()),
-           FormatX509Name(cert.get_issuer()))
-
-    if not ok:
-      try:
-        # Only supported in pyOpenSSL 0.7 and above
-        # pylint: disable-msg=E1101
-        fn = OpenSSL.crypto.X509_verify_cert_error_string
-      except AttributeError:
-        errmsg = ""
-      else:
-        errmsg = ":%s" % fn(errnum)
-
-      logger.error("verify error:num=%s%s", errnum, errmsg)
-
-    return ok
-
-  def __call__(self, ctx, logger):
-    """Configures an SSL context to verify certificates.
-
-    @type ctx: OpenSSL.SSL.Context
-    @param ctx: SSL context
-
-    """
-    if self._use_default_verify_paths:
-      ctx.set_default_verify_paths()
-
-    if self._cafile or self._capath:
-      if self._SUPPORT_CAPATH:
-        ctx.load_verify_locations(self._cafile, self._capath)
-      else:
-        ctx.load_verify_locations(self._cafile)
-
-    ctx.set_verify(OpenSSL.SSL.VERIFY_PEER,
-                   lambda conn, cert, errnum, errdepth, ok: \
-                     self._VerifySslCertCb(logger, conn, cert,
-                                           errnum, errdepth, ok))
-
-
-class _HTTPSConnectionOpenSSL(httplib.HTTPSConnection):
-  """HTTPS Connection handler that verifies the SSL certificate.
-
-  """
-  # Python before version 2.6 had its own httplib.FakeSocket wrapper for
-  # sockets
-  _SUPPORT_FAKESOCKET = (sys.hexversion < 0x2060000)
-
-  def __init__(self, *args, **kwargs):
-    """Initializes this class.
-
-    """
-    httplib.HTTPSConnection.__init__(self, *args, **kwargs)
-    self._logger = None
-    self._config_ssl_verification = None
-
-  def Setup(self, logger, config_ssl_verification):
-    """Sets the SSL verification config function.
-
-    @param logger: Logging object
-    @type config_ssl_verification: callable
-
-    """
-    assert self._logger is None
-    assert self._config_ssl_verification is None
-
-    self._logger = logger
-    self._config_ssl_verification = config_ssl_verification
-
-  def connect(self):
-    """Connect to the server specified when the object was created.
-
-    This ensures that SSL certificates are verified.
-
-    """
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-
-    ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
-    ctx.set_options(OpenSSL.SSL.OP_NO_SSLv2)
-
-    if self._config_ssl_verification:
-      self._config_ssl_verification(ctx, self._logger)
-
-    ssl = OpenSSL.SSL.Connection(ctx, sock)
-    ssl.connect((self.host, self.port))
-
-    if self._SUPPORT_FAKESOCKET:
-      self.sock = httplib.FakeSocket(sock, ssl)
+      raise NotImplementedError("cURL uses unsupported SSL version '%s'" %
+                                sslver)
+
+    curl.setopt(pycurl.VERBOSE, verbose)
+    curl.setopt(pycurl.NOSIGNAL, not use_signal)
+
+    # Whether to verify remote peer's CN
+    if verify_hostname:
+      # curl_easy_setopt(3): "When CURLOPT_SSL_VERIFYHOST is 2, that
+      # certificate must indicate that the server is the server to which you
+      # meant to connect, or the connection fails. [...] When the value is 1,
+      # the certificate must contain a Common Name field, but it doesn't matter
+      # what name it says. [...]"
+      curl.setopt(pycurl.SSL_VERIFYHOST, 2)
     else:
-      self.sock = _SslSocketWrapper(ssl)
-
-
-class _SslSocketWrapper(object):
-  def __init__(self, sock):
-    """Initializes this class.
-
-    """
-    self._sock = sock
-
-  def __getattr__(self, name):
-    """Forward everything to underlying socket.
-
-    """
-    return getattr(self._sock, name)
-
-  def makefile(self, mode, bufsize):
-    """Fake makefile method.
-
-    makefile() on normal file descriptors uses dup2(2), which doesn't work with
-    SSL sockets and therefore is not implemented by pyOpenSSL. This fake method
-    works with the httplib module, but might not work for other modules.
-
-    """
-    # pylint: disable-msg=W0212
-    return socket._fileobject(self._sock, mode, bufsize)
-
-
-class _HTTPSHandler(urllib2.HTTPSHandler):
-  def __init__(self, logger, config_ssl_verification):
-    """Initializes this class.
-
-    @param logger: Logging object
-    @type config_ssl_verification: callable
-    @param config_ssl_verification: Function to configure SSL context for
-                                    certificate verification
-
-    """
-    urllib2.HTTPSHandler.__init__(self)
-    self._logger = logger
-    self._config_ssl_verification = config_ssl_verification
-
-  def _CreateHttpsConnection(self, *args, **kwargs):
-    """Wrapper around L{_HTTPSConnectionOpenSSL} to add SSL verification.
-
-    This wrapper is necessary provide a compatible API to urllib2.
-
-    """
-    conn = _HTTPSConnectionOpenSSL(*args, **kwargs)
-    conn.Setup(self._logger, self._config_ssl_verification)
-    return conn
-
-  def https_open(self, req):
-    """Creates HTTPS connection.
-
-    Called by urllib2.
-
-    """
-    return self.do_open(self._CreateHttpsConnection, req)
-
-
-class _RapiRequest(urllib2.Request):
-  def __init__(self, method, url, headers, data):
-    """Initializes this class.
+      curl.setopt(pycurl.SSL_VERIFYHOST, 0)
+
+    if cafile or capath or use_curl_cabundle:
+      # Require certificates to be checked
+      curl.setopt(pycurl.SSL_VERIFYPEER, True)
+      if cafile:
+        curl.setopt(pycurl.CAINFO, str(cafile))
+      if capath:
+        curl.setopt(pycurl.CAPATH, str(capath))
+      # Not changing anything for using default CA bundle
+    else:
+      # Disable SSL certificate verification
+      curl.setopt(pycurl.SSL_VERIFYPEER, False)
 
-    """
-    urllib2.Request.__init__(self, url, data=data, headers=headers)
-    self._method = method
+    if proxy is not None:
+      curl.setopt(pycurl.PROXY, str(proxy))
 
-  def get_method(self):
-    """Returns the HTTP request method.
+    # Timeouts
+    if connect_timeout is not None:
+      curl.setopt(pycurl.CONNECTTIMEOUT, connect_timeout)
+    if timeout is not None:
+      curl.setopt(pycurl.TIMEOUT, timeout)
 
-    """
-    return self._method
+  return _ConfigCurl
 
 
 class GanetiRapiClient(object):
@@ -328,10 +235,9 @@ class GanetiRapiClient(object):
   _json_encoder = simplejson.JSONEncoder(sort_keys=True)
 
   def __init__(self, host, port=GANETI_RAPI_PORT,
-               username=None, password=None,
-               config_ssl_verification=None, ignore_proxy=False,
-               logger=logging):
-    """Constructor.
+               username=None, password=None, logger=logging,
+               curl_config_fn=None, curl=None):
+    """Initializes this class.
 
     @type host: string
     @param host: the ganeti cluster master to interact with
@@ -341,11 +247,8 @@ class GanetiRapiClient(object):
     @param username: the username to connect with
     @type password: string
     @param password: the password to connect with
-    @type config_ssl_verification: callable
-    @param config_ssl_verification: Function to configure SSL context for
-                                    certificate verification
-    @type ignore_proxy: bool
-    @param ignore_proxy: Whether to ignore proxy settings
+    @type curl_config_fn: callable
+    @param curl_config_fn: Function to configure C{pycurl.Curl} object
     @param logger: Logging object
 
     """
@@ -355,25 +258,37 @@ class GanetiRapiClient(object):
 
     self._base_url = "https://%s:%s" % (host, port)
 
-    handlers = [_HTTPSHandler(self._logger, config_ssl_verification)]
-
+    # Create pycURL object if not supplied
+    if not curl:
+      curl = pycurl.Curl()
+
+    # Default cURL settings
+    curl.setopt(pycurl.VERBOSE, False)
+    curl.setopt(pycurl.FOLLOWLOCATION, False)
+    curl.setopt(pycurl.MAXREDIRS, 5)
+    curl.setopt(pycurl.NOSIGNAL, True)
+    curl.setopt(pycurl.USERAGENT, self.USER_AGENT)
+    curl.setopt(pycurl.SSL_VERIFYHOST, 0)
+    curl.setopt(pycurl.SSL_VERIFYPEER, False)
+    curl.setopt(pycurl.HTTPHEADER, [
+      "Accept: %s" % HTTP_APP_JSON,
+      "Content-type: %s" % HTTP_APP_JSON,
+      ])
+
+    # Setup authentication
     if username is not None:
-      pwmgr = urllib2.HTTPPasswordMgrWithDefaultRealm()
-      pwmgr.add_password(None, self._base_url, username, password)
-      handlers.append(urllib2.HTTPBasicAuthHandler(pwmgr))
+      if password is None:
+        raise Error("Password not specified")
+      curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
+      curl.setopt(pycurl.USERPWD, str("%s:%s" % (username, password)))
     elif password:
       raise Error("Specified password without username")
 
-    if ignore_proxy:
-      handlers.append(urllib2.ProxyHandler({}))
-
-    self._http = urllib2.build_opener(*handlers) # pylint: disable-msg=W0142
+    # Call external configuration function
+    if curl_config_fn:
+      curl_config_fn(curl, logger)
 
-    self._headers = {
-      "Accept": HTTP_APP_JSON,
-      "Content-type": HTTP_APP_JSON,
-      "User-Agent": self.USER_AGENT,
-      }
+    self._curl = curl
 
   @staticmethod
   def _EncodeQuery(query):
@@ -427,10 +342,12 @@ class GanetiRapiClient(object):
     """
     assert path.startswith("/")
 
+    curl = self._curl
+
     if content:
       encoded_content = self._json_encoder.encode(content)
     else:
-      encoded_content = None
+      encoded_content = ""
 
     # Build URL
     urlparts = [self._base_url, path]
@@ -440,30 +357,43 @@ class GanetiRapiClient(object):
 
     url = "".join(urlparts)
 
-    self._logger.debug("Sending request %s %s to %s:%s"
-                       " (headers=%r, content=%r)",
-                       method, url, self._host, self._port, self._headers,
-                       encoded_content)
+    self._logger.debug("Sending request %s %s to %s:%s (content=%r)",
+                       method, url, self._host, self._port, encoded_content)
+
+    # Buffer for response
+    encoded_resp_body = StringIO()
 
-    req = _RapiRequest(method, url, self._headers, encoded_content)
+    # Configure cURL
+    curl.setopt(pycurl.CUSTOMREQUEST, str(method))
+    curl.setopt(pycurl.URL, str(url))
+    curl.setopt(pycurl.POSTFIELDS, str(encoded_content))
+    curl.setopt(pycurl.WRITEFUNCTION, encoded_resp_body.write)
 
     try:
-      resp = self._http.open(req)
-      encoded_response_content = resp.read()
-    except (OpenSSL.SSL.Error, OpenSSL.crypto.Error), err:
-      raise CertificateError("SSL issue: %s (%r)" % (err, err))
-    except urllib2.HTTPError, err:
-      raise GanetiApiError(str(err), code=err.code)
-    except urllib2.URLError, err:
-      raise GanetiApiError(str(err))
-
-    if encoded_response_content:
-      response_content = simplejson.loads(encoded_response_content)
+      # Send request and wait for response
+      try:
+        curl.perform()
+      except pycurl.error, err:
+        if err.args[0] in _CURL_SSL_CERT_ERRORS:
+          raise CertificateError("SSL certificate error %s" % err)
+
+        raise GanetiApiError(str(err))
+    finally:
+      # Reset settings to not keep references to large objects in memory
+      # between requests
+      curl.setopt(pycurl.POSTFIELDS, "")
+      curl.setopt(pycurl.WRITEFUNCTION, lambda _: None)
+
+    # Get HTTP response code
+    http_code = curl.getinfo(pycurl.RESPONSE_CODE)
+
+    # Was anything written to the response buffer?
+    if encoded_resp_body.tell():
+      response_content = simplejson.loads(encoded_resp_body.getvalue())
     else:
       response_content = None
 
-    # TODO: Are there other status codes that are valid? (redirect?)
-    if resp.code != HTTP_OK:
+    if http_code != HTTP_OK:
       if isinstance(response_content, dict):
         msg = ("%s %s: %s" %
                (response_content["code"],
@@ -472,7 +402,7 @@ class GanetiRapiClient(object):
       else:
         msg = str(response_content)
 
-      raise GanetiApiError(msg, code=resp.code)
+      raise GanetiApiError(msg, code=http_code)
 
     return response_content
 
index 3bc8199..b31f78c 100755 (executable)
@@ -39,6 +39,9 @@ import qa_tags
 import qa_utils
 
 from ganeti import utils
+from ganeti import rapi
+
+import ganeti.rapi.client
 
 
 def RunTest(fn, *args):
@@ -269,6 +272,7 @@ def RunHardwareFailureTests(instance, pnode, snode):
             instance, pnode, snode)
 
 
+@rapi.client.UsesRapiClient
 def main():
   """Main program.
 
index a171121..09dd74a 100644 (file)
@@ -72,13 +72,13 @@ def Setup(username, password):
   _rapi_ca.flush()
 
   port = qa_config.get("rapi-port", default=constants.DEFAULT_RAPI_PORT)
-  cfg_ssl = rapi.client.CertAuthorityVerify(cafile=_rapi_ca.name)
+  cfg_curl = rapi.client.GenericCurlConfig(cafile=_rapi_ca.name,
+                                           proxy="")
 
   _rapi_client = rapi.client.GanetiRapiClient(master["primary"], port=port,
                                               username=username,
                                               password=password,
-                                              config_ssl_verification=cfg_ssl,
-                                              ignore_proxy=True)
+                                              curl_config_fn=cfg_curl)
 
   print "RAPI protocol version: %s" % _rapi_client.GetVersion()
 
index 10c23d0..adfe1be 100755 (executable)
@@ -25,7 +25,9 @@
 import re
 import unittest
 import warnings
+import pycurl
 
+from ganeti import constants
 from ganeti import http
 from ganeti import serializer
 
@@ -50,34 +52,36 @@ def _GetPathFromUri(uri):
     return None
 
 
-class HttpResponseMock:
-  """Dumb mock of httplib.HTTPResponse.
-
-  """
-
-  def __init__(self, code, data):
-    self.code = code
-    self._data = data
+class FakeCurl:
+  def __init__(self, rapi):
+    self._rapi = rapi
+    self._opts = {}
+    self._info = {}
 
-  def read(self):
-    return self._data
+  def setopt(self, opt, value):
+    self._opts[opt] = value
 
+  def getopt(self, opt):
+    return self._opts.get(opt)
 
-class OpenerDirectorMock:
-  """Mock for urllib.OpenerDirector.
+  def unsetopt(self, opt):
+    self._opts.pop(opt, None)
 
-  """
+  def getinfo(self, info):
+    return self._info[info]
 
-  def __init__(self, rapi):
-    self._rapi = rapi
-    self.last_request = None
+  def perform(self):
+    method = self._opts[pycurl.CUSTOMREQUEST]
+    url = self._opts[pycurl.URL]
+    request_body = self._opts[pycurl.POSTFIELDS]
+    writefn = self._opts[pycurl.WRITEFUNCTION]
 
-  def open(self, req):
-    self.last_request = req
+    path = _GetPathFromUri(url)
+    (code, resp_body) = self._rapi.FetchResponse(path, method, request_body)
 
-    path = _GetPathFromUri(req.get_full_url())
-    code, resp_body = self._rapi.FetchResponse(path, req.get_method())
-    return HttpResponseMock(code, resp_body)
+    self._info[pycurl.RESPONSE_CODE] = code
+    if resp_body is not None:
+      writefn(resp_body)
 
 
 class RapiMock(object):
@@ -85,6 +89,7 @@ class RapiMock(object):
     self._mapper = connector.Mapper()
     self._responses = []
     self._last_handler = None
+    self._last_req_data = None
 
   def AddResponse(self, response, code=200):
     self._responses.insert(0, (code, response))
@@ -92,7 +97,12 @@ class RapiMock(object):
   def GetLastHandler(self):
     return self._last_handler
 
-  def FetchResponse(self, path, method):
+  def GetLastRequestData(self):
+    return self._last_req_data
+
+  def FetchResponse(self, path, method, request_body):
+    self._last_req_data = request_body
+
     try:
       HandlerClass, items, args = self._mapper.getController(path)
       self._last_handler = HandlerClass(items, args, None)
@@ -111,30 +121,213 @@ class RapiMock(object):
     return code, response
 
 
+class TestConstants(unittest.TestCase):
+  def test(self):
+    self.assertEqual(client.GANETI_RAPI_PORT, constants.DEFAULT_RAPI_PORT)
+    self.assertEqual(client.GANETI_RAPI_VERSION, constants.RAPI_VERSION)
+    self.assertEqual(client.HTTP_APP_JSON, http.HTTP_APP_JSON)
+    self.assertEqual(client._REQ_DATA_VERSION_FIELD, rlib2._REQ_DATA_VERSION)
+    self.assertEqual(client._INST_CREATE_REQV1, rlib2._INST_CREATE_REQV1)
+
+
 class RapiMockTest(unittest.TestCase):
   def test(self):
     rapi = RapiMock()
     path = "/version"
-    self.assertEqual((404, None), rapi.FetchResponse("/foo", "GET"))
+    self.assertEqual((404, None), rapi.FetchResponse("/foo", "GET", None))
     self.assertEqual((501, "Method not implemented"),
-                     rapi.FetchResponse("/version", "POST"))
+                     rapi.FetchResponse("/version", "POST", None))
     rapi.AddResponse("2")
-    code, response = rapi.FetchResponse("/version", "GET")
+    code, response = rapi.FetchResponse("/version", "GET", None)
     self.assertEqual(200, code)
     self.assertEqual("2", response)
     self.failUnless(isinstance(rapi.GetLastHandler(), rlib2.R_version))
 
 
+def _FakeNoSslPycurlVersion():
+  # Note: incomplete version tuple
+  return (3, "7.16.0", 462848, "mysystem", 1581, None, 0)
+
+
+def _FakeFancySslPycurlVersion():
+  # Note: incomplete version tuple
+  return (3, "7.16.0", 462848, "mysystem", 1581, "FancySSL/1.2.3", 0)
+
+
+def _FakeOpenSslPycurlVersion():
+  # Note: incomplete version tuple
+  return (2, "7.15.5", 462597, "othersystem", 668, "OpenSSL/0.9.8c", 0)
+
+
+def _FakeGnuTlsPycurlVersion():
+  # Note: incomplete version tuple
+  return (3, "7.18.0", 463360, "somesystem", 1581, "GnuTLS/2.0.4", 0)
+
+
+class TestExtendedConfig(unittest.TestCase):
+  def testAuth(self):
+    curl = FakeCurl(RapiMock())
+    cl = client.GanetiRapiClient("master.example.com",
+                                 username="user", password="pw",
+                                 curl=curl)
+
+    self.assertEqual(curl.getopt(pycurl.HTTPAUTH), pycurl.HTTPAUTH_BASIC)
+    self.assertEqual(curl.getopt(pycurl.USERPWD), "user:pw")
+
+  def testInvalidAuth(self):
+    # No username
+    self.assertRaises(client.Error, client.GanetiRapiClient,
+                      "master-a.example.com", password="pw")
+    # No password
+    self.assertRaises(client.Error, client.GanetiRapiClient,
+                      "master-b.example.com", username="user")
+
+  def testCertVerifyInvalidCombinations(self):
+    self.assertRaises(client.Error, client.GenericCurlConfig,
+                      use_curl_cabundle=True, cafile="cert1.pem")
+    self.assertRaises(client.Error, client.GenericCurlConfig,
+                      use_curl_cabundle=True, capath="certs/")
+    self.assertRaises(client.Error, client.GenericCurlConfig,
+                      use_curl_cabundle=True,
+                      cafile="cert1.pem", capath="certs/")
+
+  def testProxySignalVerifyHostname(self):
+    for use_gnutls in [False, True]:
+      if use_gnutls:
+        pcverfn = _FakeGnuTlsPycurlVersion
+      else:
+        pcverfn = _FakeOpenSslPycurlVersion
+
+      for proxy in ["", "http://127.0.0.1:1234"]:
+        for use_signal in [False, True]:
+          for verify_hostname in [False, True]:
+            cfgfn = client.GenericCurlConfig(proxy=proxy, use_signal=use_signal,
+                                             verify_hostname=verify_hostname,
+                                             _pycurl_version_fn=pcverfn)
+
+            curl = FakeCurl(RapiMock())
+            cl = client.GanetiRapiClient("master.example.com",
+                                         curl_config_fn=cfgfn, curl=curl)
+
+            self.assertEqual(curl.getopt(pycurl.PROXY), proxy)
+            self.assertEqual(curl.getopt(pycurl.NOSIGNAL), not use_signal)
+
+            if verify_hostname:
+              self.assertEqual(curl.getopt(pycurl.SSL_VERIFYHOST), 2)
+            else:
+              self.assertEqual(curl.getopt(pycurl.SSL_VERIFYHOST), 0)
+
+  def testNoCertVerify(self):
+    cfgfn = client.GenericCurlConfig()
+
+    curl = FakeCurl(RapiMock())
+    cl = client.GanetiRapiClient("master.example.com", curl_config_fn=cfgfn,
+                                 curl=curl)
+
+    self.assertFalse(curl.getopt(pycurl.SSL_VERIFYPEER))
+    self.assertFalse(curl.getopt(pycurl.CAINFO))
+    self.assertFalse(curl.getopt(pycurl.CAPATH))
+
+  def testCertVerifyCurlBundle(self):
+    cfgfn = client.GenericCurlConfig(use_curl_cabundle=True)
+
+    curl = FakeCurl(RapiMock())
+    cl = client.GanetiRapiClient("master.example.com", curl_config_fn=cfgfn,
+                                 curl=curl)
+
+    self.assert_(curl.getopt(pycurl.SSL_VERIFYPEER))
+    self.assertFalse(curl.getopt(pycurl.CAINFO))
+    self.assertFalse(curl.getopt(pycurl.CAPATH))
+
+  def testCertVerifyCafile(self):
+    mycert = "/tmp/some/UNUSED/cert/file.pem"
+    cfgfn = client.GenericCurlConfig(cafile=mycert)
+
+    curl = FakeCurl(RapiMock())
+    cl = client.GanetiRapiClient("master.example.com", curl_config_fn=cfgfn,
+                                 curl=curl)
+
+    self.assert_(curl.getopt(pycurl.SSL_VERIFYPEER))
+    self.assertEqual(curl.getopt(pycurl.CAINFO), mycert)
+    self.assertFalse(curl.getopt(pycurl.CAPATH))
+
+  def testCertVerifyCapath(self):
+    certdir = "/tmp/some/UNUSED/cert/directory"
+    pcverfn = _FakeOpenSslPycurlVersion
+    cfgfn = client.GenericCurlConfig(capath=certdir,
+                                     _pycurl_version_fn=pcverfn)
+
+    curl = FakeCurl(RapiMock())
+    cl = client.GanetiRapiClient("master.example.com", curl_config_fn=cfgfn,
+                                 curl=curl)
+
+    self.assert_(curl.getopt(pycurl.SSL_VERIFYPEER))
+    self.assertEqual(curl.getopt(pycurl.CAPATH), certdir)
+    self.assertFalse(curl.getopt(pycurl.CAINFO))
+
+  def testCertVerifyCapathGnuTls(self):
+    certdir = "/tmp/some/UNUSED/cert/directory"
+    pcverfn = _FakeGnuTlsPycurlVersion
+    cfgfn = client.GenericCurlConfig(capath=certdir,
+                                     _pycurl_version_fn=pcverfn)
+
+    curl = FakeCurl(RapiMock())
+    self.assertRaises(client.Error, client.GanetiRapiClient,
+                      "master.example.com", curl_config_fn=cfgfn, curl=curl)
+
+  def testCertVerifyNoSsl(self):
+    certdir = "/tmp/some/UNUSED/cert/directory"
+    pcverfn = _FakeNoSslPycurlVersion
+    cfgfn = client.GenericCurlConfig(capath=certdir,
+                                     _pycurl_version_fn=pcverfn)
+
+    curl = FakeCurl(RapiMock())
+    self.assertRaises(client.Error, client.GanetiRapiClient,
+                      "master.example.com", curl_config_fn=cfgfn, curl=curl)
+
+  def testCertVerifyFancySsl(self):
+    certdir = "/tmp/some/UNUSED/cert/directory"
+    pcverfn = _FakeFancySslPycurlVersion
+    cfgfn = client.GenericCurlConfig(capath=certdir,
+                                     _pycurl_version_fn=pcverfn)
+
+    curl = FakeCurl(RapiMock())
+    self.assertRaises(NotImplementedError, client.GanetiRapiClient,
+                      "master.example.com", curl_config_fn=cfgfn, curl=curl)
+
+  def testTimeouts(self):
+    for connect_timeout in [None, 1, 5, 10, 30, 60, 300]:
+      for timeout in [None, 1, 30, 60, 3600, 24 * 3600]:
+        cfgfn = client.GenericCurlConfig(connect_timeout=connect_timeout,
+                                         timeout=timeout)
+
+        curl = FakeCurl(RapiMock())
+        cl = client.GanetiRapiClient("master.example.com", curl_config_fn=cfgfn,
+                                     curl=curl)
+        # A None timeout is never set, so FakeCurl.getopt() returns None
+        self.assertEqual(curl.getopt(pycurl.CONNECTTIMEOUT), connect_timeout)
+        self.assertEqual(curl.getopt(pycurl.TIMEOUT), timeout)
+
+
 class GanetiRapiClientTests(testutils.GanetiTestCase):
   def setUp(self):
     testutils.GanetiTestCase.setUp(self)
 
     self.rapi = RapiMock()
-    self.http = OpenerDirectorMock(self.rapi)
-    self.client = client.GanetiRapiClient('master.foo.com')
-    self.client._http = self.http
-    # Hard-code the version for easier testing.
-    self.client._version = 2
+    self.curl = FakeCurl(self.rapi)
+    self.client = client.GanetiRapiClient("master.example.com",
+                                          curl=self.curl)
+
+    # Signals should be disabled by default
+    self.assert_(self.curl.getopt(pycurl.NOSIGNAL))
+
+    # No auth and no proxy
+    self.assertFalse(self.curl.getopt(pycurl.USERPWD))
+    self.assert_(self.curl.getopt(pycurl.PROXY) is None)
+
+    # Content-type is required for requests
+    headers = self.curl.getopt(pycurl.HTTPHEADER)
+    self.assert_("Content-type: application/json" in headers)
 
   def assertHandler(self, handler_cls):
     self.failUnless(isinstance(self.rapi.GetLastHandler(), handler_cls))
@@ -273,7 +466,7 @@ class GanetiRapiClientTests(testutils.GanetiTestCase):
     self.assertHandler(rlib2.R_2_instances)
     self.assertDryRun()
 
-    data = serializer.LoadJson(self.http.last_request.data)
+    data = serializer.LoadJson(self.rapi.GetLastRequestData())
 
     for field in ["dry_run", "beparams", "hvparams", "start"]:
       self.assertFalse(field in data)
@@ -293,7 +486,7 @@ class GanetiRapiClientTests(testutils.GanetiTestCase):
     self.assertEqual(job_id, 24740)
     self.assertHandler(rlib2.R_2_instances)
 
-    data = serializer.LoadJson(self.http.last_request.data)
+    data = serializer.LoadJson(self.rapi.GetLastRequestData())
     self.assertEqual(data[rlib2._REQ_DATA_VERSION], 1)
     self.assertEqual(data["name"], "inst2.example.com")
     self.assertEqual(data["disk_template"], "drbd8")
@@ -411,7 +604,7 @@ class GanetiRapiClientTests(testutils.GanetiTestCase):
     self.assertHandler(rlib2.R_2_instances_name_export)
     self.assertItems(["inst2"])
 
-    data = serializer.LoadJson(self.http.last_request.data)
+    data = serializer.LoadJson(self.rapi.GetLastRequestData())
     self.assertEqual(data["mode"], "local")
     self.assertEqual(data["destination"], "nodeX")
     self.assertEqual(data["shutdown"], True)
@@ -509,7 +702,7 @@ class GanetiRapiClientTests(testutils.GanetiTestCase):
     self.assertHandler(rlib2.R_2_nodes_name_role)
     self.assertItems(["node-foo"])
     self.assertQuery("force", ["1"])
-    self.assertEqual("\"master-candidate\"", self.http.last_request.data)
+    self.assertEqual("\"master-candidate\"", self.rapi.GetLastRequestData())
 
   def testGetNodeStorageUnits(self):
     self.rapi.AddResponse("42")
@@ -576,4 +769,4 @@ class GanetiRapiClientTests(testutils.GanetiTestCase):
 
 
 if __name__ == '__main__':
-  testutils.GanetiTestProgram()
+  client.UsesRapiClient(testutils.GanetiTestProgram)()
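diff --git a/tools/move-instance b/tools/move-instance
index 05b3b24..9399302 100755 (executable)
--- a/tools/move-instance
+++ b/tools/move-instance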
@@ -148,18 +148,18 @@ class RapiClientFactory:
     self.src_cluster_name = src_cluster_name
     self.dest_cluster_name = dest_cluster_name
 
+    # TODO: Implement timeouts for RAPI connections
     # TODO: Support for using system default paths for verifying SSL certificate
-    # (already implemented in CertAuthorityVerify)
     logging.debug("Using '%s' as source CA", options.src_ca_file)
-    src_ssl_config = rapi.client.CertAuthorityVerify(cafile=options.src_ca_file)
+    src_curl_config = rapi.client.GenericCurlConfig(cafile=options.src_ca_file)
 
     if options.dest_ca_file:
       logging.debug("Using '%s' as destination CA", options.dest_ca_file)
-      dest_ssl_config = \
-        rapi.client.CertAuthorityVerify(cafile=options.dest_ca_file)
+      dest_curl_config = \
+        rapi.client.GenericCurlConfig(cafile=options.dest_ca_file)
     else:
       logging.debug("Using source CA for destination")
-      dest_ssl_config = src_ssl_config
+      dest_curl_config = src_curl_config
 
     logging.debug("Source RAPI server is %s:%s",
                   src_cluster_name, options.src_rapi_port)
@@ -182,7 +182,7 @@ class RapiClientFactory:
     self.GetSourceClient = lambda: \
       rapi.client.GanetiRapiClient(src_cluster_name,
                                    port=options.src_rapi_port,
-                                   config_ssl_verification=src_ssl_config,
+                                   curl_config_fn=src_curl_config,
                                    username=src_username,
                                    password=src_password)
 
@@ -212,7 +212,7 @@ class RapiClientFactory:
     self.GetDestClient = lambda: \
       rapi.client.GanetiRapiClient(dest_cluster_name,
                                    port=dest_rapi_port,
-                                   config_ssl_verification=dest_ssl_config,
+                                   curl_config_fn=dest_curl_config,
                                    username=dest_username,
                                    password=dest_password)
 
@@ -771,6 +771,7 @@ def CheckOptions(parser, options, args):
   return (src_cluster_name, dest_cluster_name, instance_names)
 
 
+@rapi.client.UsesRapiClient
 def main():
   """Main routine.