Increase maximum HTTP message size
[ganeti-local] / lib / http / server.py
index f63bc50..b6f0504 100644 (file)
@@ -1,7 +1,7 @@
 #
 #
 
-# Copyright (C) 2007, 2008 Google Inc.
+# Copyright (C) 2007, 2008, 2010, 2012 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -26,21 +26,22 @@ import BaseHTTPServer
 import cgi
 import logging
 import os
-import select
 import socket
 import time
 import signal
+import asyncore
 
-from ganeti import constants
-from ganeti import serializer
-from ganeti import utils
 from ganeti import http
+from ganeti import utils
+from ganeti import netutils
+from ganeti import compat
+from ganeti import errors
 
 
-WEEKDAYNAME = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
+WEEKDAYNAME = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
 MONTHNAME = [None,
-             'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
-             'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
+             "Jan", "Feb", "Mar", "Apr", "May", "Jun",
+             "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]
 
 # Default error message
 DEFAULT_ERROR_CONTENT_TYPE = "text/html"
@@ -76,12 +77,12 @@ class _HttpServerRequest(object):
   """Data structure for HTTP request on server side.
 
   """
-  def __init__(self, request_msg):
+  def __init__(self, method, path, headers, body):
     # Request attributes
-    self.request_method = request_msg.start_line.method
-    self.request_path = request_msg.start_line.path
-    self.request_headers = request_msg.headers
-    self.request_body = request_msg.decoded_body
+    self.request_method = method
+    self.request_path = path
+    self.request_headers = headers
+    self.request_body = body
 
     # Response attributes
     self.resp_headers = {}
@@ -90,6 +91,14 @@ class _HttpServerRequest(object):
     # authentication)
     self.private = None
 
+  def __repr__(self):
+    status = ["%s.%s" % (self.__class__.__module__, self.__class__.__name__),
+              self.request_method, self.request_path,
+              "headers=%r" % str(self.request_headers),
+              "body=%r" % (self.request_body, )]
+
+    return "<%s at %#x>" % (" ".join(status), id(self))
+
 
 class _HttpServerToClientMessageWriter(http.HttpMessageWriter):
   """Writes an HTTP response to client.
@@ -102,7 +111,7 @@ class _HttpServerToClientMessageWriter(http.HttpMessageWriter):
     @param sock: Target socket
     @type request_msg: http.HttpMessage
     @param request_msg: Request message, required to determine whether
-                        response may have a message body
+        response may have a message body
     @type response_msg: http.HttpMessage
     @param response_msg: Response message
     @type write_timeout: float
@@ -150,7 +159,7 @@ class _HttpClientToServerMessageReader(http.HttpMessageReader):
 
   """
   # Length limits
-  START_LINE_LENGTH_MAX = 4096
+  START_LINE_LENGTH_MAX = 8192
   HEADER_LENGTH_MAX = 4096
 
   def ParseStartLine(self, start_line):
@@ -171,7 +180,7 @@ class _HttpClientToServerMessageReader(http.HttpMessageReader):
 
     if len(words) == 3:
       [method, path, version] = words
-      if version[:5] != 'HTTP/':
+      if version[:5] != "HTTP/":
         raise http.HttpBadRequest("Bad request version (%r)" % version)
 
       try:
@@ -193,7 +202,7 @@ class _HttpClientToServerMessageReader(http.HttpMessageReader):
 
       if version_number >= (2, 0):
         raise http.HttpVersionNotSupported("Invalid HTTP Version (%s)" %
-                                      base_version_number)
+                                           base_version_number)
 
     elif len(words) == 2:
       version = http.HTTP_0_9
@@ -207,147 +216,144 @@ class _HttpClientToServerMessageReader(http.HttpMessageReader):
     return http.HttpClientToServerStartLine(method, path, version)
 
 
-class _HttpServerRequestExecutor(object):
-  """Implements server side of HTTP.
-
-  This class implements the server side of HTTP. It's based on code of Python's
-  BaseHTTPServer, from both version 2.4 and 3k. It does not support non-ASCII
-  character encodings. Keep-alive connections are not supported.
+def _HandleServerRequestInner(handler, req_msg):
+  """Calls the handler function for the current request.
 
   """
+  handler_context = _HttpServerRequest(req_msg.start_line.method,
+                                       req_msg.start_line.path,
+                                       req_msg.headers,
+                                       req_msg.body)
+
+  logging.debug("Handling request %r", handler_context)
+
+  try:
+    try:
+      # Authentication, etc.
+      handler.PreHandleRequest(handler_context)
+
+      # Call actual request handler
+      result = handler.HandleRequest(handler_context)
+    except (http.HttpException, errors.RapiTestResult,
+            KeyboardInterrupt, SystemExit):
+      raise
+    except Exception, err:
+      logging.exception("Caught exception")
+      raise http.HttpInternalServerError(message=str(err))
+    except:
+      logging.exception("Unknown exception")
+      raise http.HttpInternalServerError(message="Unknown error")
+
+    if not isinstance(result, basestring):
+      raise http.HttpError("Handler function didn't return string type")
+
+    return (http.HTTP_OK, handler_context.resp_headers, result)
+  finally:
+    # No reason to keep this any longer, even for exceptions
+    handler_context.private = None
+
+
+class HttpResponder(object):
   # The default request version.  This only affects responses up until
   # the point where the request line is parsed, so it mainly decides what
   # the client gets back when sending a malformed request line.
   # Most web servers default to HTTP 0.9, i.e. don't send a status line.
   default_request_version = http.HTTP_0_9
 
-  # Error message settings
-  error_message_format = DEFAULT_ERROR_MESSAGE
-  error_content_type = DEFAULT_ERROR_CONTENT_TYPE
-
   responses = BaseHTTPServer.BaseHTTPRequestHandler.responses
 
-  # Timeouts in seconds for socket layer
-  WRITE_TIMEOUT = 10
-  READ_TIMEOUT = 10
-  CLOSE_TIMEOUT = 1
-
-  def __init__(self, server, sock, client_addr):
+  def __init__(self, handler):
     """Initializes this class.
 
     """
-    self.server = server
-    self.sock = sock
-    self.client_addr = client_addr
+    self._handler = handler
+
+  def __call__(self, fn):
+    """Handles a request.
 
-    self.request_msg = http.HttpMessage()
-    self.response_msg = http.HttpMessage()
+    @type fn: callable
+    @param fn: Callback for retrieving HTTP request, must return a tuple
+      containing the request message (L{http.HttpMessage}) and either
+      C{None} or the message reader (L{_HttpClientToServerMessageReader})
 
-    self.response_msg.start_line = \
+    """
+    response_msg = http.HttpMessage()
+    response_msg.start_line = \
       http.HttpServerToClientStartLine(version=self.default_request_version,
                                        code=None, reason=None)
 
-    # Disable Python's timeout
-    self.sock.settimeout(None)
-
-    # Operate in non-blocking mode
-    self.sock.setblocking(0)
+    force_close = True
 
-    logging.info("Connection from %s:%s", client_addr[0], client_addr[1])
     try:
-      request_msg_reader = None
-      force_close = True
-      try:
-        # Do the secret SSL handshake
-        if self.server.using_ssl:
-          self.sock.set_accept_state()
-          try:
-            http.Handshake(self.sock, self.WRITE_TIMEOUT)
-          except http.HttpSessionHandshakeUnexpectedEOF:
-            # Ignore rest
-            return
+      (request_msg, req_msg_reader) = fn()
+
+      response_msg.start_line.version = request_msg.start_line.version
+
+      # RFC2616, 14.23: All Internet-based HTTP/1.1 servers MUST respond
+      # with a 400 (Bad Request) status code to any HTTP/1.1 request
+      # message which lacks a Host header field.
+      if (request_msg.start_line.version == http.HTTP_1_1 and
+          not (request_msg.headers and
+               http.HTTP_HOST in request_msg.headers)):
+        raise http.HttpBadRequest(message="Missing Host header")
+
+      (response_msg.start_line.code, response_msg.headers,
+       response_msg.body) = \
+        _HandleServerRequestInner(self._handler, request_msg)
+    except http.HttpException, err:
+      self._SetError(self.responses, self._handler, response_msg, err)
+    else:
+      # Only wait for client to close if we didn't have any exception.
+      force_close = False
 
-        try:
-          try:
-            request_msg_reader = self._ReadRequest()
-            self._HandleRequest()
-
-            # Only wait for client to close if we didn't have any exception.
-            force_close = False
-          except http.HttpException, err:
-            self._SetErrorStatus(err)
-        finally:
-          # Try to send a response
-          self._SendResponse()
-      finally:
-        http.ShutdownConnection(sock, self.CLOSE_TIMEOUT, self.WRITE_TIMEOUT,
-                                request_msg_reader, force_close)
+    return (request_msg, req_msg_reader, force_close,
+            self._Finalize(self.responses, response_msg))
 
-      self.sock.close()
-      self.sock = None
-    finally:
-      logging.info("Disconnected %s:%s", client_addr[0], client_addr[1])
+  @staticmethod
+  def _SetError(responses, handler, response_msg, err):
+    """Sets the response code and body from a HttpException.
 
-  def _ReadRequest(self):
-    """Reads a request sent by client.
+    @type err: HttpException
+    @param err: Exception instance
 
     """
     try:
-      request_msg_reader = \
-        _HttpClientToServerMessageReader(self.sock, self.request_msg,
-                                         self.READ_TIMEOUT)
-    except http.HttpSocketTimeout:
-      raise http.HttpError("Timeout while reading request")
-    except socket.error, err:
-      raise http.HttpError("Error reading request: %s" % err)
+      (shortmsg, longmsg) = responses[err.code]
+    except KeyError:
+      shortmsg = longmsg = "Unknown"
 
-    self.response_msg.start_line.version = self.request_msg.start_line.version
+    if err.message:
+      message = err.message
+    else:
+      message = shortmsg
 
-    return request_msg_reader
+    values = {
+      "code": err.code,
+      "message": cgi.escape(message),
+      "explain": longmsg,
+      }
 
-  def _HandleRequest(self):
-    """Calls the handler function for the current request.
+    (content_type, body) = handler.FormatErrorMessage(values)
 
-    """
-    handler_context = _HttpServerRequest(self.request_msg)
+    headers = {
+      http.HTTP_CONTENT_TYPE: content_type,
+      }
 
-    try:
-      try:
-        # Authentication, etc.
-        self.server.PreHandleRequest(handler_context)
-
-        # Call actual request handler
-        result = self.server.HandleRequest(handler_context)
-      except (http.HttpException, KeyboardInterrupt, SystemExit):
-        raise
-      except Exception, err:
-        logging.exception("Caught exception")
-        raise http.HttpInternalServerError(message=str(err))
-      except:
-        logging.exception("Unknown exception")
-        raise http.HttpInternalServerError(message="Unknown error")
-
-      # TODO: Content-type
-      encoder = http.HttpJsonConverter()
-      self.response_msg.start_line.code = http.HTTP_OK
-      self.response_msg.body = encoder.Encode(result)
-      self.response_msg.headers = handler_context.resp_headers
-      self.response_msg.headers[http.HTTP_CONTENT_TYPE] = encoder.CONTENT_TYPE
-    finally:
-      # No reason to keep this any longer, even for exceptions
-      handler_context.private = None
+    if err.headers:
+      headers.update(err.headers)
 
-  def _SendResponse(self):
-    """Sends the response to the client.
+    response_msg.start_line.code = err.code
+    response_msg.headers = headers
+    response_msg.body = body
 
-    """
-    if self.response_msg.start_line.code is None:
-      return
+  @staticmethod
+  def _Finalize(responses, msg):
+    assert msg.start_line.reason is None
 
-    if not self.response_msg.headers:
-      self.response_msg.headers = {}
+    if not msg.headers:
+      msg.headers = {}
 
-    self.response_msg.headers.update({
+    msg.headers.update({
       # TODO: Keep-alive is not supported
       http.HTTP_CONNECTION: "close",
       http.HTTP_DATE: _DateTimeHeader(),
@@ -355,68 +361,116 @@ class _HttpServerRequestExecutor(object):
       })
 
     # Get response reason based on code
-    response_code = self.response_msg.start_line.code
-    if response_code in self.responses:
-      response_reason = self.responses[response_code][0]
+    try:
+      code_desc = responses[msg.start_line.code]
+    except KeyError:
+      reason = ""
     else:
-      response_reason = ""
-    self.response_msg.start_line.reason = response_reason
+      (reason, _) = code_desc
 
-    logging.info("%s:%s %s %s", self.client_addr[0], self.client_addr[1],
-                 self.request_msg.start_line, response_code)
+    msg.start_line.reason = reason
 
-    try:
-      _HttpServerToClientMessageWriter(self.sock, self.request_msg,
-                                       self.response_msg, self.WRITE_TIMEOUT)
-    except http.HttpSocketTimeout:
-      raise http.HttpError("Timeout while sending response")
-    except socket.error, err:
-      raise http.HttpError("Error sending response: %s" % err)
+    return msg
 
-  def _SetErrorStatus(self, err):
-    """Sets the response code and body from a HttpException.
 
-    @type err: HttpException
-    @param err: Exception instance
+class HttpServerRequestExecutor(object):
+  """Implements server side of HTTP.
+
+  This class implements the server side of HTTP. It's based on code of
+  Python's BaseHTTPServer, from both version 2.4 and 3k. It does not
+  support non-ASCII character encodings. Keep-alive connections are
+  not supported.
+
+  """
+  # Timeouts in seconds for socket layer
+  WRITE_TIMEOUT = 10
+  READ_TIMEOUT = 10
+  CLOSE_TIMEOUT = 1
+
+  def __init__(self, server, handler, sock, client_addr):
+    """Initializes this class.
 
     """
+    responder = HttpResponder(handler)
+
+    # Disable Python's timeout
+    sock.settimeout(None)
+
+    # Operate in non-blocking mode
+    sock.setblocking(0)
+
+    request_msg_reader = None
+    force_close = True
+
+    logging.debug("Connection from %s:%s", client_addr[0], client_addr[1])
     try:
-      (shortmsg, longmsg) = self.responses[err.code]
-    except KeyError:
-      shortmsg = longmsg = "Unknown"
+      # Block for closing connection
+      try:
+        # Do the secret SSL handshake
+        if server.using_ssl:
+          sock.set_accept_state()
+          try:
+            http.Handshake(sock, self.WRITE_TIMEOUT)
+          except http.HttpSessionHandshakeUnexpectedEOF:
+            # Ignore rest
+            return
 
-    if err.message:
-      message = err.message
-    else:
-      message = shortmsg
+        (request_msg, request_msg_reader, force_close, response_msg) = \
+          responder(compat.partial(self._ReadRequest, sock, self.READ_TIMEOUT))
+        if response_msg:
+          # HttpMessage.start_line can be of different types
+          # Instance of 'HttpClientToServerStartLine' has no 'code' member
+          # pylint: disable=E1103,E1101
+          logging.info("%s:%s %s %s", client_addr[0], client_addr[1],
+                       request_msg.start_line, response_msg.start_line.code)
+          self._SendResponse(sock, request_msg, response_msg,
+                             self.WRITE_TIMEOUT)
+      finally:
+        http.ShutdownConnection(sock, self.CLOSE_TIMEOUT, self.WRITE_TIMEOUT,
+                                request_msg_reader, force_close)
 
-    values = {
-      "code": err.code,
-      "message": cgi.escape(message),
-      "explain": longmsg,
-      }
+      sock.close()
+    finally:
+      logging.debug("Disconnected %s:%s", client_addr[0], client_addr[1])
 
-    self.response_msg.start_line.code = err.code
+  @staticmethod
+  def _ReadRequest(sock, timeout):
+    """Reads a request sent by client.
 
-    headers = {}
-    if err.headers:
-      headers.update(err.headers)
-    headers[http.HTTP_CONTENT_TYPE] = self.error_content_type
-    self.response_msg.headers = headers
+    """
+    msg = http.HttpMessage()
+
+    try:
+      reader = _HttpClientToServerMessageReader(sock, msg, timeout)
+    except http.HttpSocketTimeout:
+      raise http.HttpError("Timeout while reading request")
+    except socket.error, err:
+      raise http.HttpError("Error reading request: %s" % err)
 
-    self.response_msg.body = self.error_message_format % values
+    return (msg, reader)
 
+  @staticmethod
+  def _SendResponse(sock, req_msg, msg, timeout):
+    """Sends the response to the client.
+
+    """
+    try:
+      _HttpServerToClientMessageWriter(sock, req_msg, msg, timeout)
+    except http.HttpSocketTimeout:
+      raise http.HttpError("Timeout while sending response")
+    except socket.error, err:
+      raise http.HttpError("Error sending response: %s" % err)
 
-class HttpServer(http.HttpBase):
-  """Generic HTTP server class
 
-  Users of this class must subclass it and override the HandleRequest function.
+class HttpServer(http.HttpBase, asyncore.dispatcher):
+  """Generic HTTP server class
 
   """
   MAX_CHILDREN = 20
 
-  def __init__(self, mainloop, local_address, port,
-               ssl_params=None, ssl_verify_peer=False):
+  def __init__(self, mainloop, local_address, port, handler,
+               ssl_params=None, ssl_verify_peer=False,
+               request_executor_class=None):
     """Initializes the HTTP server
 
     @type mainloop: ganeti.daemon.Mainloop
@@ -428,24 +482,34 @@ class HttpServer(http.HttpBase):
     @type ssl_params: HttpSslParams
     @param ssl_params: SSL key and certificate
     @type ssl_verify_peer: bool
-    @param ssl_verify_peer: Whether to require client certificate and compare
-                            it with our certificate
+    @param ssl_verify_peer: Whether to require client certificate
+        and compare it with our certificate
+    @type request_executor_class: class
+    @param request_executor_class: a class derived from the
+        HttpServerRequestExecutor class
 
     """
     http.HttpBase.__init__(self)
+    asyncore.dispatcher.__init__(self)
+
+    if request_executor_class is None:
+      self.request_executor = HttpServerRequestExecutor
+    else:
+      self.request_executor = request_executor_class
 
     self.mainloop = mainloop
     self.local_address = local_address
     self.port = port
-
-    self.socket = self._CreateSocket(ssl_params, ssl_verify_peer)
+    self.handler = handler
+    family = netutils.IPAddress.GetAddressFamily(local_address)
+    self.socket = self._CreateSocket(ssl_params, ssl_verify_peer, family)
 
     # Allow port to be reused
     self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 
     self._children = []
-
-    mainloop.RegisterIO(self, self.socket.fileno(), select.POLLIN)
+    self.set_socket(self.socket)
+    self.accepting = True
     mainloop.RegisterSignal(self)
 
   def Start(self):
@@ -455,9 +519,8 @@ class HttpServer(http.HttpBase):
   def Stop(self):
     self.socket.close()
 
-  def OnIO(self, fd, condition):
-    if condition & select.POLLIN:
-      self._IncomingConnection()
+  def handle_accept(self):
+    self._IncomingConnection()
 
   def OnSignal(self, signum):
     if signum == signal.SIGCHLD:
@@ -478,7 +541,7 @@ class HttpServer(http.HttpBase):
           # As soon as too many children run, we'll not respond to new
           # requests. The real solution would be to add a timeout for children
           # and killing them after some time.
-          pid, status = os.waitpid(0, 0)
+          pid, _ = os.waitpid(0, 0)
         except os.error:
           pid = None
         if pid and pid in self._children:
@@ -486,7 +549,7 @@ class HttpServer(http.HttpBase):
 
     for child in self._children:
       try:
-        pid, status = os.waitpid(child, os.WNOHANG)
+        pid, _ = os.waitpid(child, os.WNOHANG)
       except os.error:
         pid = None
       if pid and pid in self._children:
@@ -496,6 +559,7 @@ class HttpServer(http.HttpBase):
     """Called for each incoming connection
 
     """
+    # pylint: disable=W0212
     (connection, client_addr) = self.socket.accept()
 
     self._CollectChildren(False)
@@ -504,8 +568,21 @@ class HttpServer(http.HttpBase):
     if pid == 0:
       # Child process
       try:
-        _HttpServerRequestExecutor(self, connection, client_addr)
-      except Exception:
+        # The child shouldn't keep the listening socket open. If the parent
+        # process is restarted, it would fail when there's already something
+        # listening (in this case its own child from a previous run) on the
+        # same port.
+        try:
+          self.socket.close()
+        except socket.error:
+          pass
+        self.socket = None
+
+        # In case the handler code uses temporary files
+        utils.ResetTempfileModule()
+
+        self.request_executor(self, self.handler, connection, client_addr)
+      except Exception: # pylint: disable=W0703
         logging.exception("Error while handling request from %s:%s",
                           client_addr[0], client_addr[1])
         os._exit(1)
@@ -513,17 +590,37 @@ class HttpServer(http.HttpBase):
     else:
       self._children.append(pid)
 
+
+class HttpServerHandler(object):
+  """Base class for handling HTTP server requests.
+
+  Users of this class must subclass it and override the L{HandleRequest}
+  function.
+
+  """
   def PreHandleRequest(self, req):
     """Called before handling a request.
 
-    Can be overriden by a subclass.
+    Can be overridden by a subclass.
 
     """
 
   def HandleRequest(self, req):
     """Handles a request.
 
-    Must be overriden by subclass.
+    Must be overridden by subclass.
 
     """
     raise NotImplementedError()
+
+  @staticmethod
+  def FormatErrorMessage(values):
+    """Formats the body of an error message.
+
+    @type values: dict
+    @param values: dictionary with keys C{code}, C{message} and C{explain}.
+    @rtype: tuple; (string, string)
+    @return: Content-type and response body
+
+    """
+    return (DEFAULT_ERROR_CONTENT_TYPE, DEFAULT_ERROR_MESSAGE % values)