Convert SnapshotBlockDevice's docstring to epydoc
[ganeti-local] / lib / http.py
index a9febda..e32a236 100644 (file)
 
 """
 
-import socket
 import BaseHTTPServer
+import cgi
+import logging
+import mimetools
 import OpenSSL
+import os
+import select
+import socket
+import sys
 import time
+import signal
 import logging
 
-from ganeti import errors
-from ganeti import logger
+from ganeti import constants
 from ganeti import serializer
 
 
+WEEKDAYNAME = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
+MONTHNAME = [None,
+             'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
+             'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
+
+# Default error message
+DEFAULT_ERROR_CONTENT_TYPE = "text/html"
+DEFAULT_ERROR_MESSAGE = """\
+<head>
+<title>Error response</title>
+</head>
+<body>
+<h1>Error response</h1>
+<p>Error code %(code)d.
+<p>Message: %(message)s.
+<p>Error code explanation: %(code)s = %(explain)s.
+</body>
+"""
+
+HTTP_OK = 200
+HTTP_NO_CONTENT = 204
+HTTP_NOT_MODIFIED = 304
+
+HTTP_0_9 = "HTTP/0.9"
+HTTP_1_0 = "HTTP/1.0"
+HTTP_1_1 = "HTTP/1.1"
+
+HTTP_GET = "GET"
+HTTP_HEAD = "HEAD"
+HTTP_ETAG = "ETag"
+
+
+class SocketClosed(socket.error):
+  pass
+
+
 class HTTPException(Exception):
   code = None
   message = None
 
   def __init__(self, message=None):
+    Exception.__init__(self)
     if message is not None:
       self.message = message
 
@@ -71,6 +114,10 @@ class HTTPServiceUnavailable(HTTPException):
   code = 503
 
 
+class HTTPVersionNotSupported(HTTPException):
+  code = 505
+
+
 class ApacheLogfile:
   """Utility class to write HTTP server log files.
 
@@ -78,10 +125,6 @@ class ApacheLogfile:
   http://httpd.apache.org/docs/2.2/mod/mod_log_config.html#examples
 
   """
-  MONTHNAME = [None,
-               'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
-               'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
-
   def __init__(self, fd):
     """Constructor for ApacheLogfile class.
 
@@ -108,6 +151,7 @@ class ApacheLogfile:
       # Message
       format % args,
       ))
+    self._fd.flush()
 
   def _FormatCurrentTime(self):
     """Formats current time in Common Log Format.
@@ -125,153 +169,515 @@ class ApacheLogfile:
 
     """
     (_, month, _, _, _, _, _, _, _) = tm = time.gmtime(seconds)
-    format = "%d/" + self.MONTHNAME[month] + "/%Y:%H:%M:%S +0000"
+    format = "%d/" + MONTHNAME[month] + "/%Y:%H:%M:%S +0000"
     return time.strftime(format, tm)
 
 
-class HTTPServer(BaseHTTPServer.HTTPServer, object):
-  """Class to provide an HTTP/HTTPS server.
+class HTTPJsonConverter:
+  CONTENT_TYPE = "application/json"
+
+  def Encode(self, data):
+    return serializer.DumpJson(data)
+
+  def Decode(self, data):
+    return serializer.LoadJson(data)
+
+
+class _HttpConnectionHandler(object):
+  """Implements server side of HTTP
+
+  This class implements the server side of HTTP. It's based on code of Python's
+  BaseHTTPServer, from both version 2.4 and 3k. It does not support non-ASCII
+  character encodings. Keep-alive connections are not supported.
 
   """
-  allow_reuse_address = True
+  # String for "Server" header
+  server_version = "Ganeti %s" % constants.RELEASE_VERSION
 
-  def __init__(self, server_address, HandlerClass, httplog=None,
-               enable_ssl=False, ssl_key=None, ssl_cert=None):
-    """Server constructor.
+  # The default request version.  This only affects responses up until
+  # the point where the request line is parsed, so it mainly decides what
+  # the client gets back when sending a malformed request line.
+  # Most web servers default to HTTP 0.9, i.e. don't send a status line.
+  default_request_version = HTTP_0_9
 
-    Args:
-      server_address: a touple containing:
-        ip: a string with IP address, localhost if empty string
-        port: port number, integer
-      HandlerClass: HTTPRequestHandler object
-      httplog: Access log object
-      enable_ssl: Whether to enable SSL
-      ssl_key: SSL key file
-      ssl_cert: SSL certificate key
+  # Error message settings
+  error_message_format = DEFAULT_ERROR_MESSAGE
+  error_content_type = DEFAULT_ERROR_CONTENT_TYPE
+
+  responses = BaseHTTPServer.BaseHTTPRequestHandler.responses
+
+  def __init__(self, server, conn, client_addr, fileio_class):
+    """Initializes this class.
+
+    Part of the initialization is reading the request and eventual POST/PUT
+    data sent by the client.
 
     """
-    BaseHTTPServer.HTTPServer.__init__(self, server_address, HandlerClass)
-
-    self.httplog = httplog
-
-    if enable_ssl:
-      # Set up SSL
-      context = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
-      context.use_privatekey_file(ssl_key)
-      context.use_certificate_file(ssl_cert)
-      self.socket = OpenSSL.SSL.Connection(context,
-                                           socket.socket(self.address_family,
-                                           self.socket_type))
-    else:
-      self.socket = socket.socket(self.address_family, self.socket_type)
+    self._server = server
 
-    self.server_bind()
-    self.server_activate()
+    # We default rfile to buffered because otherwise it could be
+    # really slow for large data (a getc() call per byte); we make
+    # wfile unbuffered because (a) often after a write() we want to
+    # read and we need to flush the line; (b) big writes to unbuffered
+    # files are typically optimized by stdio even when big reads
+    # aren't.
+    self.rfile = fileio_class(conn, mode="rb", bufsize=-1)
+    self.wfile = fileio_class(conn, mode="wb", bufsize=0)
 
+    self.client_addr = client_addr
 
-class HTTPJsonConverter:
-  CONTENT_TYPE = "application/json"
+    self.request_headers = None
+    self.request_method = None
+    self.request_path = None
+    self.request_requestline = None
+    self.request_version = self.default_request_version
 
-  def Encode(self, data):
-    return serializer.DumpJson(data)
+    self.response_body = None
+    self.response_code = HTTP_OK
+    self.response_content_type = None
+    self.response_headers = {}
 
-  def Decode(self, data):
-    return serializer.LoadJson(data)
+    self.should_fork = False
 
+    try:
+      self._ReadRequest()
+      self._ReadPostData()
 
-class HTTPRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler, object):
-  """Request handler class.
+      self.should_fork = self._server.ForkForRequest(self)
+    except HTTPException, err:
+      self._SetErrorStatus(err)
 
-  """
-  def setup(self):
-    """Setup secure read and write file objects.
+  def Close(self):
+    if not self.wfile.closed:
+      self.wfile.flush()
+    self.wfile.close()
+    self.rfile.close()
+
+  def _DateTimeHeader(self):
+    """Return the current date and time formatted for a message header.
 
     """
-    self.connection = self.request
-    self.rfile = socket._fileobject(self.request, "rb", self.rbufsize)
-    self.wfile = socket._fileobject(self.request, "wb", self.wbufsize)
+    (year, month, day, hh, mm, ss, wd, _, _) = time.gmtime()
+    return ("%s, %02d %3s %4d %02d:%02d:%02d GMT" %
+            (WEEKDAYNAME[wd], day, MONTHNAME[month], year, hh, mm, ss))
+
+  def _SetErrorStatus(self, err):
+    """Sets the response code and body from a HTTPException.
 
-  def handle_one_request(self):
-    """Parses a request and calls the handler function.
+    @type err: HTTPException
+    @param err: Exception instance
 
     """
-    self.raw_requestline = None
     try:
-      self.raw_requestline = self.rfile.readline()
-    except OpenSSL.SSL.Error, ex:
-      logger.Error("Error in SSL: %s" % str(ex))
-    if not self.raw_requestline:
-      self.close_connection = 1
-      return
-    if not self.parse_request(): # An error code has been sent, just exit
+      (shortmsg, longmsg) = self.responses[err.code]
+    except KeyError:
+      shortmsg = longmsg = "Unknown"
+
+    if err.message:
+      message = err.message
+    else:
+      message = shortmsg
+
+    values = {
+      "code": err.code,
+      "message": cgi.escape(message),
+      "explain": longmsg,
+      }
+
+    self.response_code = err.code
+    self.response_content_type = self.error_content_type
+    self.response_body = self.error_message_format % values
+
+  def HandleRequest(self):
+    """Handle the actual request.
+
+    Calls the actual handler function and converts exceptions into HTTP errors.
+
+    """
+    # Don't do anything if there's already been a problem
+    if self.response_code != HTTP_OK:
       return
-    logging.debug("HTTP request: %s", self.raw_requestline.rstrip("\r\n"))
+
+    assert self.request_method, "Status code %s requires a method" % HTTP_OK
+
+    # Check whether client is still there
+    self.rfile.read(0)
 
     try:
-      self._ReadPostData()
+      try:
+        result = self._server.HandleRequest(self)
+
+        # TODO: Content-type
+        encoder = HTTPJsonConverter()
+        body = encoder.Encode(result)
+
+        self.response_content_type = encoder.CONTENT_TYPE
+        self.response_body = body
+      except (HTTPException, KeyboardInterrupt, SystemExit):
+        raise
+      except Exception, err:
+        logging.exception("Caught exception")
+        raise HTTPInternalError(message=str(err))
+      except:
+        logging.exception("Unknown exception")
+        raise HTTPInternalError(message="Unknown error")
 
-      result = self.HandleRequest()
+    except HTTPException, err:
+      self._SetErrorStatus(err)
 
-      # TODO: Content-type
-      encoder = HTTPJsonConverter()
-      encoded_result = encoder.Encode(result)
+  def SendResponse(self):
+    """Sends response to the client.
 
-      self.send_response(200)
-      self.send_header("Content-Type", encoder.CONTENT_TYPE)
-      self.send_header("Content-Length", str(len(encoded_result)))
-      self.end_headers()
+    """
+    # Check whether client is still there
+    self.rfile.read(0)
 
-      self.wfile.write(encoded_result)
+    logging.info("%s:%s %s %s", self.client_addr[0], self.client_addr[1],
+                 self.request_requestline, self.response_code)
 
-    except HTTPException, err:
-      self.send_error(err.code, message=err.message)
+    if self.response_code in self.responses:
+      response_message = self.responses[self.response_code][0]
+    else:
+      response_message = ""
+
+    if self.request_version != HTTP_0_9:
+      self.wfile.write("%s %d %s\r\n" %
+                       (self.request_version, self.response_code,
+                        response_message))
+      self._SendHeader("Server", self.server_version)
+      self._SendHeader("Date", self._DateTimeHeader())
+      self._SendHeader("Content-Type", self.response_content_type)
+      self._SendHeader("Content-Length", str(len(self.response_body)))
+      for key, val in self.response_headers.iteritems():
+        self._SendHeader(key, val)
+
+      # We don't support keep-alive at this time
+      self._SendHeader("Connection", "close")
+      self.wfile.write("\r\n")
+
+    if (self.request_method != HTTP_HEAD and
+        self.response_code >= HTTP_OK and
+        self.response_code not in (HTTP_NO_CONTENT, HTTP_NOT_MODIFIED)):
+      self.wfile.write(self.response_body)
+
+  def _SendHeader(self, name, value):
+    if self.request_version != HTTP_0_9:
+      self.wfile.write("%s: %s\r\n" % (name, value))
+
+  def _ReadRequest(self):
+    """Reads and parses request line
+
+    """
+    raw_requestline = self.rfile.readline()
+
+    requestline = raw_requestline
+    if requestline[-2:] == '\r\n':
+      requestline = requestline[:-2]
+    elif requestline[-1:] == '\n':
+      requestline = requestline[:-1]
+
+    if not requestline:
+      raise HTTPBadRequest("Empty request line")
+
+    self.request_requestline = requestline
+
+    logging.debug("HTTP request: %s", raw_requestline.rstrip("\r\n"))
+
+    words = requestline.split()
+
+    if len(words) == 3:
+      [method, path, version] = words
+      if version[:5] != 'HTTP/':
+        raise HTTPBadRequest("Bad request version (%r)" % version)
+
+      try:
+        base_version_number = version.split('/', 1)[1]
+        version_number = base_version_number.split(".")
+
+        # RFC 2145 section 3.1 says there can be only one "." and
+        #   - major and minor numbers MUST be treated as
+        #      separate integers;
+        #   - HTTP/2.4 is a lower version than HTTP/2.13, which in
+        #      turn is lower than HTTP/12.3;
+        #   - Leading zeros MUST be ignored by recipients.
+        if len(version_number) != 2:
+          raise HTTPBadRequest("Bad request version (%r)" % version)
 
-    except Exception, err:
-      self.send_error(HTTPInternalError.code, message=str(err))
+        version_number = int(version_number[0]), int(version_number[1])
+      except (ValueError, IndexError):
+        raise HTTPBadRequest("Bad request version (%r)" % version)
 
-    except:
-      self.send_error(HTTPInternalError.code, message="Unknown error")
+      if version_number >= (2, 0):
+        raise HTTPVersionNotSupported("Invalid HTTP Version (%s)" %
+                                      base_version_number)
+
+    elif len(words) == 2:
+      version = HTTP_0_9
+      [method, path] = words
+      if method != HTTP_GET:
+        raise HTTPBadRequest("Bad HTTP/0.9 request type (%r)" % method)
+
+    else:
+      raise HTTPBadRequest("Bad request syntax (%r)" % requestline)
+
+    # Examine the headers and look for a Connection directive
+    headers = mimetools.Message(self.rfile, 0)
+
+    self.request_method = method
+    self.request_path = path
+    self.request_version = version
+    self.request_headers = headers
 
   def _ReadPostData(self):
-    if self.command.upper() not in ("POST", "PUT"):
-      self.post_data = None
+    """Reads POST/PUT data
+
+    """
+    if not self.request_method or self.request_method.upper() not in ("POST", "PUT"):
+      self.request_post_data = None
       return
 
     # TODO: Decide what to do when Content-Length header was not sent
     try:
-      content_length = int(self.headers.get('Content-Length', 0))
+      content_length = int(self.request_headers.get('Content-Length', 0))
     except ValueError:
       raise HTTPBadRequest("No Content-Length header or invalid format")
 
-    try:
-      data = self.rfile.read(content_length)
-    except socket.error, err:
-      logger.Error("Socket error while reading: %s" % str(err))
-      return
+    data = self.rfile.read(content_length)
 
     # TODO: Content-type, error handling
-    self.post_data = HTTPJsonConverter().Decode(data)
+    self.request_post_data = HTTPJsonConverter().Decode(data)
 
-    logging.debug("HTTP POST data: %s", self.post_data)
+    logging.debug("HTTP POST data: %s", self.request_post_data)
 
-  def HandleRequest(self):
-    """Handles a request.
+
+class HttpServer(object):
+  """Generic HTTP server class
+
+  Users of this class must subclass it and override the HandleRequest function.
+  Optionally, the ForkForRequest function can be overriden.
+
+  """
+  MAX_CHILDREN = 20
+
+  def __init__(self, mainloop, server_address):
+    self.mainloop = mainloop
+    self.server_address = server_address
+
+    # TODO: SSL support
+    self.ssl_cert = None
+    self.ssl_key = self.ssl_cert
+
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+    if self.ssl_cert and self.ssl_key:
+      ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
+      ctx.set_options(OpenSSL.SSL.OP_NO_SSLv2)
+
+      ctx.use_certificate_file(self.ssl_cert)
+      ctx.use_privatekey_file(self.ssl_key)
+
+      self.socket = OpenSSL.SSL.Connection(ctx, sock)
+      self._fileio_class = _SSLFileObject
+    else:
+      self.socket = sock
+      self._fileio_class = socket._fileobject
+
+    # Allow port to be reused
+    self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+
+    self._children = []
+
+    mainloop.RegisterIO(self, self.socket.fileno(), select.POLLIN)
+    mainloop.RegisterSignal(self)
+
+  def Start(self):
+    self.socket.bind(self.server_address)
+    self.socket.listen(5)
+
+  def Stop(self):
+    self.socket.close()
+
+  def OnIO(self, fd, condition):
+    if condition & select.POLLIN:
+      self._IncomingConnection()
+
+  def OnSignal(self, signum):
+    if signum == signal.SIGCHLD:
+      self._CollectChildren(True)
+
+  def _CollectChildren(self, quick):
+    """Checks whether any child processes are done
+
+    @type quick: bool
+    @param quick: Whether to only use non-blocking functions
 
     """
+    if not quick:
+      # Don't wait for other processes if it should be a quick check
+      while len(self._children) > self.MAX_CHILDREN:
+        try:
+          pid, status = os.waitpid(0, 0)
+        except os.error:
+          pid = None
+        if pid and pid in self._children:
+          self._children.remove(pid)
+
+    for child in self._children:
+      try:
+        pid, status = os.waitpid(child, os.WNOHANG)
+      except os.error:
+        pid = None
+      if pid and pid in self._children:
+        self._children.remove(pid)
+
+  def _IncomingConnection(self):
+    connection, client_addr = self.socket.accept()
+    logging.info("Connection from %s:%s", client_addr[0], client_addr[1])
+    try:
+      handler = _HttpConnectionHandler(self, connection, client_addr, self._fileio_class)
+    except (socket.error, SocketClosed):
+      return
+
+    def FinishRequest():
+      try:
+        try:
+          try:
+            handler.HandleRequest()
+          finally:
+            # Try to send a response
+            handler.SendResponse()
+            handler.Close()
+        except SocketClosed:
+          pass
+      finally:
+        logging.info("Disconnected %s:%s", client_addr[0], client_addr[1])
+
+    # Check whether we should fork or not
+    if not handler.should_fork:
+      FinishRequest()
+      return
+
+    self._CollectChildren(False)
+
+    pid = os.fork()
+    if pid == 0:
+      # Child process
+      try:
+        FinishRequest()
+      except:
+        logging.exception("Error while handling request from %s:%s",
+                          client_addr[0], client_addr[1])
+        os._exit(1)
+      os._exit(0)
+    else:
+      self._children.append(pid)
+
+  def HandleRequest(self, req):
     raise NotImplementedError()
 
-  def log_message(self, format, *args):
-    """Log an arbitrary message.
+  def ForkForRequest(self, req):
+    return True
 
-    This is used by all other logging functions.
 
-    The first argument, FORMAT, is a format string for the
-    message to be logged.  If the format string contains
-    any % escapes requiring parameters, they should be
-    specified as subsequent arguments (it's just like
-    printf!).
+class _SSLFileObject(object):
+  """Wrapper around socket._fileobject
 
-    """
-    logging.debug("Handled request: %s", format % args)
-    if self.server.httplog:
-      self.server.httplog.LogRequest(self, format, *args)
+  This wrapper is required to handle OpenSSL exceptions.
+
+  """
+  def _RequireOpenSocket(fn):
+    def wrapper(self, *args, **kwargs):
+      if self.closed:
+        raise SocketClosed("Socket is closed")
+      return fn(self, *args, **kwargs)
+    return wrapper
+
+  def __init__(self, sock, mode='rb', bufsize=-1):
+    self._base = socket._fileobject(sock, mode=mode, bufsize=bufsize)
+
+  def _ConnectionLost(self):
+    self._base = None
+
+  def _getclosed(self):
+    return self._base is None or self._base.closed
+  closed = property(_getclosed, doc="True if the file is closed")
+
+  @_RequireOpenSocket
+  def close(self):
+    return self._base.close()
+
+  @_RequireOpenSocket
+  def flush(self):
+    return self._base.flush()
+
+  @_RequireOpenSocket
+  def fileno(self):
+    return self._base.fileno()
+
+  @_RequireOpenSocket
+  def read(self, size=-1):
+    return self._ReadWrapper(self._base.read, size=size)
+
+  @_RequireOpenSocket
+  def readline(self, size=-1):
+    return self._ReadWrapper(self._base.readline, size=size)
+
+  def _ReadWrapper(self, fn, *args, **kwargs):
+    while True:
+      try:
+        return fn(*args, **kwargs)
+
+      except OpenSSL.SSL.ZeroReturnError, err:
+        self._ConnectionLost()
+        return ""
+
+      except OpenSSL.SSL.WantReadError:
+        continue
+
+      #except OpenSSL.SSL.WantWriteError:
+      # TODO
+
+      except OpenSSL.SSL.SysCallError, (retval, desc):
+        if ((retval == -1 and desc == "Unexpected EOF")
+            or retval > 0):
+          self._ConnectionLost()
+          return ""
+
+        logging.exception("Error in OpenSSL")
+        self._ConnectionLost()
+        raise socket.error(err.args)
+
+      except OpenSSL.SSL.Error, err:
+        self._ConnectionLost()
+        raise socket.error(err.args)
+
+  @_RequireOpenSocket
+  def write(self, data):
+    return self._WriteWrapper(self._base.write, data)
+
+  def _WriteWrapper(self, fn, *args, **kwargs):
+    while True:
+      try:
+        return fn(*args, **kwargs)
+      except OpenSSL.SSL.ZeroReturnError, err:
+        self._ConnectionLost()
+        return 0
+
+      except OpenSSL.SSL.WantWriteError:
+        continue
+
+      #except OpenSSL.SSL.WantReadError:
+      # TODO
+
+      except OpenSSL.SSL.SysCallError, err:
+        if err.args[0] == -1 and data == "":
+          # errors when writing empty strings are expected
+          # and can be ignored
+          return 0
+
+        self._ConnectionLost()
+        raise socket.error(err.args)
+
+      except OpenSSL.SSL.Error, err:
+        self._ConnectionLost()
+        raise socket.error(err.args)