http.server: Factorize request handling even more
author Michael Hanselmann <hansmi@google.com>
Tue, 21 Feb 2012 19:14:46 +0000 (20:14 +0100)
committer Michael Hanselmann <hansmi@google.com>
Wed, 22 Feb 2012 13:13:01 +0000 (14:13 +0100)
This splits even more parts of the request handling code into a separate
class. Doing so allows us to reuse this part of the code for tests (e.g.
mocks). Unlike before, the error handling can now also be reused.

The patch became a bit more convoluted than intended, but the end result
is easier to read than the original code.

Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: René Nussbaumer <rn@google.com>

lib/http/server.py

index 25e7928..d28bf66 100644 (file)
@@ -34,6 +34,7 @@ import asyncore
 from ganeti import http
 from ganeti import utils
 from ganeti import netutils
+from ganeti import compat
 
 
 WEEKDAYNAME = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
@@ -214,7 +215,7 @@ class _HttpClientToServerMessageReader(http.HttpMessageReader):
     return http.HttpClientToServerStartLine(method, path, version)
 
 
-def HandleServerRequest(handler, req_msg):
+def _HandleServerRequestInner(handler, req_msg):
   """Calls the handler function for the current request.
 
   """
@@ -250,6 +251,126 @@ def HandleServerRequest(handler, req_msg):
     handler_context.private = None
 
 
+class HttpResponder(object):
+  # The default request version.  This only affects responses up until
+  # the point where the request line is parsed, so it mainly decides what
+  # the client gets back when sending a malformed request line.
+  # Most web servers default to HTTP 0.9, i.e. don't send a status line.
+  default_request_version = http.HTTP_0_9
+
+  responses = BaseHTTPServer.BaseHTTPRequestHandler.responses
+
+  def __init__(self, handler):
+    """Initializes this class.
+
+    """
+    self._handler = handler
+
+  def __call__(self, fn):
+    """Handles a request.
+
+    @type fn: callable
+    @param fn: Callback for retrieving HTTP request, must return a tuple
+      containing request message (L{http.HttpMessage}) and C{None} or the
+      message reader (L{_HttpClientToServerMessageReader})
+
+    """
+    response_msg = http.HttpMessage()
+    response_msg.start_line = \
+      http.HttpServerToClientStartLine(version=self.default_request_version,
+                                       code=None, reason=None)
+
+    force_close = True
+
+    try:
+      (request_msg, req_msg_reader) = fn()
+
+      response_msg.start_line.version = request_msg.start_line.version
+
+      # RFC2616, 14.23: All Internet-based HTTP/1.1 servers MUST respond
+      # with a 400 (Bad Request) status code to any HTTP/1.1 request
+      # message which lacks a Host header field.
+      if (request_msg.start_line.version == http.HTTP_1_1 and
+          not (request_msg.headers and
+               http.HTTP_HOST in request_msg.headers)):
+        raise http.HttpBadRequest(message="Missing Host header")
+
+      (response_msg.start_line.code, response_msg.headers,
+       response_msg.body) = \
+        _HandleServerRequestInner(self._handler, request_msg)
+    except http.HttpException, err:
+      self._SetError(self.responses, self._handler, response_msg, err)
+    else:
+      # Only wait for client to close if we didn't have any exception.
+      force_close = False
+
+    return (request_msg, req_msg_reader, force_close,
+            self._Finalize(self.responses, response_msg))
+
+  @staticmethod
+  def _SetError(responses, handler, response_msg, err):
+    """Sets the response code and body from a HttpException.
+
+    @type err: HttpException
+    @param err: Exception instance
+
+    """
+    try:
+      (shortmsg, longmsg) = responses[err.code]
+    except KeyError:
+      shortmsg = longmsg = "Unknown"
+
+    if err.message:
+      message = err.message
+    else:
+      message = shortmsg
+
+    values = {
+      "code": err.code,
+      "message": cgi.escape(message),
+      "explain": longmsg,
+      }
+
+    (content_type, body) = handler.FormatErrorMessage(values)
+
+    headers = {
+      http.HTTP_CONTENT_TYPE: content_type,
+      }
+
+    if err.headers:
+      headers.update(err.headers)
+
+    response_msg.start_line.code = err.code
+    response_msg.headers = headers
+    response_msg.body = body
+
+  @staticmethod
+  def _Finalize(responses, msg):
+    assert msg.start_line.reason is None
+
+    if not msg.headers:
+      msg.headers = {}
+
+    msg.headers.update({
+      # TODO: Keep-alive is not supported
+      http.HTTP_CONNECTION: "close",
+      http.HTTP_DATE: _DateTimeHeader(),
+      http.HTTP_SERVER: http.HTTP_GANETI_VERSION,
+      })
+
+    # Get response reason based on code
+    try:
+      code_desc = responses[msg.start_line.code]
+    except KeyError:
+      reason = ""
+    else:
+      (reason, _) = code_desc
+
+    msg.start_line.reason = reason
+
+    return msg
+
+
 class HttpServerRequestExecutor(object):
   """Implements server side of HTTP.
 
@@ -259,14 +380,6 @@ class HttpServerRequestExecutor(object):
   not supported.
 
   """
-  # The default request version.  This only affects responses up until
-  # the point where the request line is parsed, so it mainly decides what
-  # the client gets back when sending a malformed request line.
-  # Most web servers default to HTTP 0.9, i.e. don't send a status line.
-  default_request_version = http.HTTP_0_9
-
-  responses = BaseHTTPServer.BaseHTTPRequestHandler.responses
-
   # Timeouts in seconds for socket layer
   WRITE_TIMEOUT = 10
   READ_TIMEOUT = 10
@@ -276,159 +389,75 @@ class HttpServerRequestExecutor(object):
     """Initializes this class.
 
     """
-    self.server = server
-    self.handler = handler
-    self.sock = sock
-    self.client_addr = client_addr
-
-    self.request_msg = http.HttpMessage()
-    self.response_msg = http.HttpMessage()
-
-    self.response_msg.start_line = \
-      http.HttpServerToClientStartLine(version=self.default_request_version,
-                                       code=None, reason=None)
+    responder = HttpResponder(handler)
 
     # Disable Python's timeout
-    self.sock.settimeout(None)
+    sock.settimeout(None)
 
     # Operate in non-blocking mode
-    self.sock.setblocking(0)
+    sock.setblocking(0)
+
+    request_msg_reader = None
+    force_close = True
 
     logging.debug("Connection from %s:%s", client_addr[0], client_addr[1])
     try:
-      request_msg_reader = None
-      force_close = True
+      # Block for closing connection
       try:
         # Do the secret SSL handshake
-        if self.server.using_ssl:
-          self.sock.set_accept_state()
+        if server.using_ssl:
+          sock.set_accept_state()
           try:
-            http.Handshake(self.sock, self.WRITE_TIMEOUT)
+            http.Handshake(sock, self.WRITE_TIMEOUT)
           except http.HttpSessionHandshakeUnexpectedEOF:
             # Ignore rest
             return
 
-        try:
-          try:
-            request_msg_reader = self._ReadRequest()
-
-            # RFC2616, 14.23: All Internet-based HTTP/1.1 servers MUST respond
-            # with a 400 (Bad Request) status code to any HTTP/1.1 request
-            # message which lacks a Host header field.
-            if (self.request_msg.start_line.version == http.HTTP_1_1 and
-                http.HTTP_HOST not in self.request_msg.headers):
-              raise http.HttpBadRequest(message="Missing Host header")
-
-            (self.response_msg.start_line.code, self.response_msg.headers,
-             self.response_msg.body) = \
-              HandleServerRequest(self.handler, self.request_msg)
-
-            # Only wait for client to close if we didn't have any exception.
-            force_close = False
-          except http.HttpException, err:
-            self._SetErrorStatus(err)
-        finally:
-          # Try to send a response
-          self._SendResponse()
+        (request_msg, request_msg_reader, force_close, response_msg) = \
+          responder(compat.partial(self._ReadRequest, sock, self.READ_TIMEOUT))
+        if response_msg:
+          # HttpMessage.start_line can be of different types
+          # pylint: disable=E1103
+          logging.info("%s:%s %s %s", client_addr[0], client_addr[1],
+                       request_msg.start_line, response_msg.start_line.code)
+          self._SendResponse(sock, request_msg, response_msg,
+                             self.WRITE_TIMEOUT)
       finally:
         http.ShutdownConnection(sock, self.CLOSE_TIMEOUT, self.WRITE_TIMEOUT,
                                 request_msg_reader, force_close)
 
-      self.sock.close()
-      self.sock = None
+      sock.close()
     finally:
       logging.debug("Disconnected %s:%s", client_addr[0], client_addr[1])
 
-  def _ReadRequest(self):
+  @staticmethod
+  def _ReadRequest(sock, timeout):
     """Reads a request sent by client.
 
     """
+    msg = http.HttpMessage()
+
     try:
-      request_msg_reader = \
-        _HttpClientToServerMessageReader(self.sock, self.request_msg,
-                                         self.READ_TIMEOUT)
+      reader = _HttpClientToServerMessageReader(sock, msg, timeout)
     except http.HttpSocketTimeout:
       raise http.HttpError("Timeout while reading request")
     except socket.error, err:
       raise http.HttpError("Error reading request: %s" % err)
 
-    self.response_msg.start_line.version = self.request_msg.start_line.version
-
-    return request_msg_reader
+    return (msg, reader)
 
-  def _SendResponse(self):
+  @staticmethod
+  def _SendResponse(sock, req_msg, msg, timeout):
     """Sends the response to the client.
 
     """
-    # HttpMessage.start_line can be of different types, pylint: disable=E1103
-    if self.response_msg.start_line.code is None:
-      return
-
-    if not self.response_msg.headers:
-      self.response_msg.headers = {}
-
-    self.response_msg.headers.update({
-      # TODO: Keep-alive is not supported
-      http.HTTP_CONNECTION: "close",
-      http.HTTP_DATE: _DateTimeHeader(),
-      http.HTTP_SERVER: http.HTTP_GANETI_VERSION,
-      })
-
-    # Get response reason based on code
-    response_code = self.response_msg.start_line.code
-    if response_code in self.responses:
-      response_reason = self.responses[response_code][0]
-    else:
-      response_reason = ""
-    self.response_msg.start_line.reason = response_reason
-
-    logging.info("%s:%s %s %s", self.client_addr[0], self.client_addr[1],
-                 self.request_msg.start_line, response_code)
-
     try:
-      _HttpServerToClientMessageWriter(self.sock, self.request_msg,
-                                       self.response_msg, self.WRITE_TIMEOUT)
+      _HttpServerToClientMessageWriter(sock, req_msg, msg, timeout)
     except http.HttpSocketTimeout:
       raise http.HttpError("Timeout while sending response")
     except socket.error, err:
       raise http.HttpError("Error sending response: %s" % err)
 
-  def _SetErrorStatus(self, err):
-    """Sets the response code and body from a HttpException.
-
-    @type err: HttpException
-    @param err: Exception instance
-
-    """
-    try:
-      (shortmsg, longmsg) = self.responses[err.code]
-    except KeyError:
-      shortmsg = longmsg = "Unknown"
-
-    if err.message:
-      message = err.message
-    else:
-      message = shortmsg
-
-    values = {
-      "code": err.code,
-      "message": cgi.escape(message),
-      "explain": longmsg,
-      }
-
-    (content_type, body) = self.handler.FormatErrorMessage(values)
-
-    headers = {
-      http.HTTP_CONTENT_TYPE: content_type,
-      }
-
-    if err.headers:
-      headers.update(err.headers)
-
-    self.response_msg.start_line.code = err.code
-    self.response_msg.headers = headers
-    self.response_msg.body = body
-
 
 class HttpServer(http.HttpBase, asyncore.dispatcher):
   """Generic HTTP server class