Split handling HTTP requests into separate class
[ganeti-local] / lib / http / server.py
index d7e374c..8e9a638 100644 (file)
@@ -1,7 +1,7 @@
 #
 #
 
-# Copyright (C) 2007, 2008 Google Inc.
+# Copyright (C) 2007, 2008, 2010 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -26,19 +26,20 @@ import BaseHTTPServer
 import cgi
 import logging
 import os
-import select
 import socket
 import time
 import signal
 import asyncore
 
 from ganeti import http
+from ganeti import utils
+from ganeti import netutils
 
 
-WEEKDAYNAME = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
+WEEKDAYNAME = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
 MONTHNAME = [None,
-             'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
-             'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
+             "Jan", "Feb", "Mar", "Apr", "May", "Jun",
+             "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
 
 # Default error message
 DEFAULT_ERROR_CONTENT_TYPE = "text/html"
@@ -74,12 +75,12 @@ class _HttpServerRequest(object):
   """Data structure for HTTP request on server side.
 
   """
-  def __init__(self, request_msg):
+  def __init__(self, method, path, headers, body):
     # Request attributes
-    self.request_method = request_msg.start_line.method
-    self.request_path = request_msg.start_line.path
-    self.request_headers = request_msg.headers
-    self.request_body = request_msg.decoded_body
+    self.request_method = method
+    self.request_path = path
+    self.request_headers = headers
+    self.request_body = body
 
     # Response attributes
     self.resp_headers = {}
@@ -88,6 +89,14 @@ class _HttpServerRequest(object):
     # authentication)
     self.private = None
 
+  def __repr__(self):
+    status = ["%s.%s" % (self.__class__.__module__, self.__class__.__name__),
+              self.request_method, self.request_path,
+              "headers=%r" % str(self.request_headers),
+              "body=%r" % (self.request_body, )]
+
+    return "<%s at %#x>" % (" ".join(status), id(self))
+
 
 class _HttpServerToClientMessageWriter(http.HttpMessageWriter):
   """Writes an HTTP response to client.
@@ -169,7 +178,7 @@ class _HttpClientToServerMessageReader(http.HttpMessageReader):
 
     if len(words) == 3:
       [method, path, version] = words
-      if version[:5] != 'HTTP/':
+      if version[:5] != "HTTP/":
         raise http.HttpBadRequest("Bad request version (%r)" % version)
 
       try:
@@ -205,6 +214,42 @@ class _HttpClientToServerMessageReader(http.HttpMessageReader):
     return http.HttpClientToServerStartLine(method, path, version)
 
 
+def HandleServerRequest(handler, req_msg):
+  """Calls the handler function for the current request.
+
+  """
+  handler_context = _HttpServerRequest(req_msg.start_line.method,
+                                       req_msg.start_line.path,
+                                       req_msg.headers,
+                                       req_msg.body)
+
+  logging.debug("Handling request %r", handler_context)
+
+  try:
+    try:
+      # Authentication, etc.
+      handler.PreHandleRequest(handler_context)
+
+      # Call actual request handler
+      result = handler.HandleRequest(handler_context)
+    except (http.HttpException, KeyboardInterrupt, SystemExit):
+      raise
+    except Exception, err:
+      logging.exception("Caught exception")
+      raise http.HttpInternalServerError(message=str(err))
+    except:
+      logging.exception("Unknown exception")
+      raise http.HttpInternalServerError(message="Unknown error")
+
+    if not isinstance(result, basestring):
+      raise http.HttpError("Handler function didn't return string type")
+
+    return (http.HTTP_OK, handler_context.resp_headers, result)
+  finally:
+    # No reason to keep this any longer, even for exceptions
+    handler_context.private = None
+
+
 class HttpServerRequestExecutor(object):
   """Implements server side of HTTP.
 
@@ -231,11 +276,12 @@ class HttpServerRequestExecutor(object):
   READ_TIMEOUT = 10
   CLOSE_TIMEOUT = 1
 
-  def __init__(self, server, sock, client_addr):
+  def __init__(self, server, handler, sock, client_addr):
     """Initializes this class.
 
     """
     self.server = server
+    self.handler = handler
     self.sock = sock
     self.client_addr = client_addr
 
@@ -269,7 +315,17 @@ class HttpServerRequestExecutor(object):
         try:
           try:
             request_msg_reader = self._ReadRequest()
-            self._HandleRequest()
+
+            # RFC2616, 14.23: All Internet-based HTTP/1.1 servers MUST respond
+            # with a 400 (Bad Request) status code to any HTTP/1.1 request
+            # message which lacks a Host header field.
+            if (self.request_msg.start_line.version == http.HTTP_1_1 and
+                http.HTTP_HOST not in self.request_msg.headers):
+              raise http.HttpBadRequest(message="Missing Host header")
+
+            (self.response_msg.start_line.code, self.response_msg.headers,
+             self.response_msg.body) = \
+              HandleServerRequest(self.handler, self.request_msg)
 
             # Only wait for client to close if we didn't have any exception.
             force_close = False
@@ -304,42 +360,11 @@ class HttpServerRequestExecutor(object):
 
     return request_msg_reader
 
-  def _HandleRequest(self):
-    """Calls the handler function for the current request.
-
-    """
-    handler_context = _HttpServerRequest(self.request_msg)
-
-    try:
-      try:
-        # Authentication, etc.
-        self.server.PreHandleRequest(handler_context)
-
-        # Call actual request handler
-        result = self.server.HandleRequest(handler_context)
-      except (http.HttpException, KeyboardInterrupt, SystemExit):
-        raise
-      except Exception, err:
-        logging.exception("Caught exception")
-        raise http.HttpInternalServerError(message=str(err))
-      except:
-        logging.exception("Unknown exception")
-        raise http.HttpInternalServerError(message="Unknown error")
-
-      # TODO: Content-type
-      encoder = http.HttpJsonConverter()
-      self.response_msg.start_line.code = http.HTTP_OK
-      self.response_msg.body = encoder.Encode(result)
-      self.response_msg.headers = handler_context.resp_headers
-      self.response_msg.headers[http.HTTP_CONTENT_TYPE] = encoder.CONTENT_TYPE
-    finally:
-      # No reason to keep this any longer, even for exceptions
-      handler_context.private = None
-
   def _SendResponse(self):
     """Sends the response to the client.
 
     """
+    # HttpMessage.start_line can be of different types, pylint: disable=E1103
     if self.response_msg.start_line.code is None:
       return
 
@@ -416,15 +441,14 @@ class HttpServerRequestExecutor(object):
     """
     return self.error_message_format % values
 
+
 class HttpServer(http.HttpBase, asyncore.dispatcher):
   """Generic HTTP server class
 
-  Users of this class must subclass it and override the HandleRequest function.
-
   """
   MAX_CHILDREN = 20
 
-  def __init__(self, mainloop, local_address, port,
+  def __init__(self, mainloop, local_address, port, handler,
                ssl_params=None, ssl_verify_peer=False,
                request_executor_class=None):
     """Initializes the HTTP server
@@ -456,8 +480,9 @@ class HttpServer(http.HttpBase, asyncore.dispatcher):
     self.mainloop = mainloop
     self.local_address = local_address
     self.port = port
-
-    self.socket = self._CreateSocket(ssl_params, ssl_verify_peer)
+    self.handler = handler
+    family = netutils.IPAddress.GetAddressFamily(local_address)
+    self.socket = self._CreateSocket(ssl_params, ssl_verify_peer, family)
 
     # Allow port to be reused
     self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@@ -504,7 +529,7 @@ class HttpServer(http.HttpBase, asyncore.dispatcher):
 
     for child in self._children:
       try:
-        pid, status = os.waitpid(child, os.WNOHANG)
+        pid, _ = os.waitpid(child, os.WNOHANG)
       except os.error:
         pid = None
       if pid and pid in self._children:
@@ -514,6 +539,7 @@ class HttpServer(http.HttpBase, asyncore.dispatcher):
     """Called for each incoming connection
 
     """
+    # pylint: disable=W0212
     (connection, client_addr) = self.socket.accept()
 
     self._CollectChildren(False)
@@ -532,8 +558,11 @@ class HttpServer(http.HttpBase, asyncore.dispatcher):
           pass
         self.socket = None
 
-        self.request_executor(self, connection, client_addr)
-      except Exception:
+        # In case the handler code uses temporary files
+        utils.ResetTempfileModule()
+
+        self.request_executor(self, self.handler, connection, client_addr)
+      except Exception: # pylint: disable=W0703
         logging.exception("Error while handling request from %s:%s",
                           client_addr[0], client_addr[1])
         os._exit(1)
@@ -541,6 +570,14 @@ class HttpServer(http.HttpBase, asyncore.dispatcher):
     else:
       self._children.append(pid)
 
+
+class HttpServerHandler(object):
+  """Base class for handling HTTP server requests.
+
+  Users of this class must subclass it and override the L{HandleRequest}
+  function.
+
+  """
   def PreHandleRequest(self, req):
     """Called before handling a request.