cli: Fix wrong argument kind for groups
[ganeti-local] / lib / http / server.py
index 128d205..7a46af6 100644 (file)
@@ -1,7 +1,7 @@
 #
 #
 
 #
 #
 
-# Copyright (C) 2007, 2008 Google Inc.
+# Copyright (C) 2007, 2008, 2010 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 #
 # This program is free software; you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -26,15 +26,14 @@ import BaseHTTPServer
 import cgi
 import logging
 import os
 import cgi
 import logging
 import os
-import select
 import socket
 import time
 import signal
 import socket
 import time
 import signal
+import asyncore
 
 
-from ganeti import constants
-from ganeti import serializer
-from ganeti import utils
 from ganeti import http
 from ganeti import http
+from ganeti import utils
+from ganeti import netutils
 
 
 WEEKDAYNAME = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
 
 
 WEEKDAYNAME = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
@@ -59,11 +58,15 @@ DEFAULT_ERROR_MESSAGE = """\
 """
 
 
 """
 
 
-def _DateTimeHeader():
+def _DateTimeHeader(gmnow=None):
   """Return the current date and time formatted for a message header.
 
   """Return the current date and time formatted for a message header.
 
+  The time MUST be in the GMT timezone.
+
   """
   """
-  (year, month, day, hh, mm, ss, wd, _, _) = time.gmtime()
+  if gmnow is None:
+    gmnow = time.gmtime()
+  (year, month, day, hh, mm, ss, wd, _, _) = gmnow
   return ("%s, %02d %3s %4d %02d:%02d:%02d GMT" %
           (WEEKDAYNAME[wd], day, MONTHNAME[month], year, hh, mm, ss))
 
   return ("%s, %02d %3s %4d %02d:%02d:%02d GMT" %
           (WEEKDAYNAME[wd], day, MONTHNAME[month], year, hh, mm, ss))
 
@@ -72,16 +75,28 @@ class _HttpServerRequest(object):
   """Data structure for HTTP request on server side.
 
   """
   """Data structure for HTTP request on server side.
 
   """
-  def __init__(self, request_msg):
+  def __init__(self, method, path, headers, body):
     # Request attributes
     # Request attributes
-    self.request_method = request_msg.start_line.method
-    self.request_path = request_msg.start_line.path
-    self.request_headers = request_msg.headers
-    self.request_body = request_msg.decoded_body
+    self.request_method = method
+    self.request_path = path
+    self.request_headers = headers
+    self.request_body = body
 
     # Response attributes
     self.resp_headers = {}
 
 
     # Response attributes
     self.resp_headers = {}
 
+    # Private data for request handler (useful in combination with
+    # authentication)
+    self.private = None
+
+  def __repr__(self):
+    status = ["%s.%s" % (self.__class__.__module__, self.__class__.__name__),
+              self.request_method, self.request_path,
+              "headers=%r" % str(self.request_headers),
+              "body=%r" % (self.request_body, )]
+
+    return "<%s at %#x>" % (" ".join(status), id(self))
+
 
 class _HttpServerToClientMessageWriter(http.HttpMessageWriter):
   """Writes an HTTP response to client.
 
 class _HttpServerToClientMessageWriter(http.HttpMessageWriter):
   """Writes an HTTP response to client.
@@ -94,7 +109,7 @@ class _HttpServerToClientMessageWriter(http.HttpMessageWriter):
     @param sock: Target socket
     @type request_msg: http.HttpMessage
     @param request_msg: Request message, required to determine whether
     @param sock: Target socket
     @type request_msg: http.HttpMessage
     @param request_msg: Request message, required to determine whether
-                        response may have a message body
+        response may have a message body
     @type response_msg: http.HttpMessage
     @param response_msg: Response message
     @type write_timeout: float
     @type response_msg: http.HttpMessage
     @param response_msg: Response message
     @type write_timeout: float
@@ -130,9 +145,11 @@ class _HttpServerToClientMessageWriter(http.HttpMessageWriter):
     # message-body, [...]"
 
     return (http.HttpMessageWriter.HasMessageBody(self) and
     # message-body, [...]"
 
     return (http.HttpMessageWriter.HasMessageBody(self) and
-            (request_method is not None and request_method != http.HTTP_HEAD) and
+            (request_method is not None and
+             request_method != http.HTTP_HEAD) and
             response_code >= http.HTTP_OK and
             response_code >= http.HTTP_OK and
-            response_code not in (http.HTTP_NO_CONTENT, http.HTTP_NOT_MODIFIED))
+            response_code not in (http.HTTP_NO_CONTENT,
+                                  http.HTTP_NOT_MODIFIED))
 
 
 class _HttpClientToServerMessageReader(http.HttpMessageReader):
 
 
 class _HttpClientToServerMessageReader(http.HttpMessageReader):
@@ -197,12 +214,13 @@ class _HttpClientToServerMessageReader(http.HttpMessageReader):
     return http.HttpClientToServerStartLine(method, path, version)
 
 
     return http.HttpClientToServerStartLine(method, path, version)
 
 
-class _HttpServerRequestExecutor(object):
+class HttpServerRequestExecutor(object):
   """Implements server side of HTTP.
 
   """Implements server side of HTTP.
 
-  This class implements the server side of HTTP. It's based on code of Python's
-  BaseHTTPServer, from both version 2.4 and 3k. It does not support non-ASCII
-  character encodings. Keep-alive connections are not supported.
+  This class implements the server side of HTTP. It's based on code of
+  Python's BaseHTTPServer, from both version 2.4 and 3k. It does not
+  support non-ASCII character encodings. Keep-alive connections are
+  not supported.
 
   """
   # The default request version.  This only affects responses up until
 
   """
   # The default request version.  This only affects responses up until
@@ -230,8 +248,6 @@ class _HttpServerRequestExecutor(object):
     self.sock = sock
     self.client_addr = client_addr
 
     self.sock = sock
     self.client_addr = client_addr
 
-    self.poller = select.poll()
-
     self.request_msg = http.HttpMessage()
     self.response_msg = http.HttpMessage()
 
     self.request_msg = http.HttpMessage()
     self.response_msg = http.HttpMessage()
 
@@ -245,14 +261,31 @@ class _HttpServerRequestExecutor(object):
     # Operate in non-blocking mode
     self.sock.setblocking(0)
 
     # Operate in non-blocking mode
     self.sock.setblocking(0)
 
-    logging.info("Connection from %s:%s", client_addr[0], client_addr[1])
+    logging.debug("Connection from %s:%s", client_addr[0], client_addr[1])
     try:
       request_msg_reader = None
       force_close = True
       try:
     try:
       request_msg_reader = None
       force_close = True
       try:
+        # Do the secret SSL handshake
+        if self.server.using_ssl:
+          self.sock.set_accept_state()
+          try:
+            http.Handshake(self.sock, self.WRITE_TIMEOUT)
+          except http.HttpSessionHandshakeUnexpectedEOF:
+            # Ignore rest
+            return
+
         try:
           try:
             request_msg_reader = self._ReadRequest()
         try:
           try:
             request_msg_reader = self._ReadRequest()
+
+            # RFC2616, 14.23: All Internet-based HTTP/1.1 servers MUST respond
+            # with a 400 (Bad Request) status code to any HTTP/1.1 request
+            # message which lacks a Host header field.
+            if (self.request_msg.start_line.version == http.HTTP_1_1 and
+                http.HTTP_HOST not in self.request_msg.headers):
+              raise http.HttpBadRequest(message="Missing Host header")
+
             self._HandleRequest()
 
             # Only wait for client to close if we didn't have any exception.
             self._HandleRequest()
 
             # Only wait for client to close if we didn't have any exception.
@@ -263,14 +296,13 @@ class _HttpServerRequestExecutor(object):
           # Try to send a response
           self._SendResponse()
       finally:
           # Try to send a response
           self._SendResponse()
       finally:
-        http.ShutdownConnection(self.poller, sock,
-                                self.CLOSE_TIMEOUT, self.WRITE_TIMEOUT,
+        http.ShutdownConnection(sock, self.CLOSE_TIMEOUT, self.WRITE_TIMEOUT,
                                 request_msg_reader, force_close)
 
       self.sock.close()
       self.sock = None
     finally:
                                 request_msg_reader, force_close)
 
       self.sock.close()
       self.sock = None
     finally:
-      logging.info("Disconnected %s:%s", client_addr[0], client_addr[1])
+      logging.debug("Disconnected %s:%s", client_addr[0], client_addr[1])
 
   def _ReadRequest(self):
     """Reads a request sent by client.
 
   def _ReadRequest(self):
     """Reads a request sent by client.
@@ -293,25 +325,38 @@ class _HttpServerRequestExecutor(object):
     """Calls the handler function for the current request.
 
     """
     """Calls the handler function for the current request.
 
     """
-    handler_context = _HttpServerRequest(self.request_msg)
+    handler_context = _HttpServerRequest(self.request_msg.start_line.method,
+                                         self.request_msg.start_line.path,
+                                         self.request_msg.headers,
+                                         self.request_msg.body)
+
+    logging.debug("Handling request %r", handler_context)
 
     try:
 
     try:
-      result = self.server.HandleRequest(handler_context)
-    except (http.HttpException, KeyboardInterrupt, SystemExit):
-      raise
-    except Exception, err:
-      logging.exception("Caught exception")
-      raise http.HttpInternalError(message=str(err))
-    except:
-      logging.exception("Unknown exception")
-      raise http.HttpInternalError(message="Unknown error")
-
-    # TODO: Content-type
-    encoder = http.HttpJsonConverter()
-    self.response_msg.start_line.code = http.HTTP_OK
-    self.response_msg.body = encoder.Encode(result)
-    self.response_msg.headers = handler_context.resp_headers
-    self.response_msg.headers[http.HTTP_CONTENT_TYPE] = encoder.CONTENT_TYPE
+      try:
+        # Authentication, etc.
+        self.server.PreHandleRequest(handler_context)
+
+        # Call actual request handler
+        result = self.server.HandleRequest(handler_context)
+      except (http.HttpException, KeyboardInterrupt, SystemExit):
+        raise
+      except Exception, err:
+        logging.exception("Caught exception")
+        raise http.HttpInternalServerError(message=str(err))
+      except:
+        logging.exception("Unknown exception")
+        raise http.HttpInternalServerError(message="Unknown error")
+
+      if not isinstance(result, basestring):
+        raise http.HttpError("Handler function didn't return string type")
+
+      self.response_msg.start_line.code = http.HTTP_OK
+      self.response_msg.headers = handler_context.resp_headers
+      self.response_msg.body = result
+    finally:
+      # No reason to keep this any longer, even for exceptions
+      handler_context.private = None
 
   def _SendResponse(self):
     """Sends the response to the client.
 
   def _SendResponse(self):
     """Sends the response to the client.
@@ -373,13 +418,28 @@ class _HttpServerRequestExecutor(object):
       }
 
     self.response_msg.start_line.code = err.code
       }
 
     self.response_msg.start_line.code = err.code
-    self.response_msg.headers = {
-      http.HTTP_CONTENT_TYPE: self.error_content_type,
-      }
-    self.response_msg.body = self.error_message_format % values
 
 
+    headers = {}
+    if err.headers:
+      headers.update(err.headers)
+    headers[http.HTTP_CONTENT_TYPE] = self.error_content_type
+    self.response_msg.headers = headers
+
+    self.response_msg.body = self._FormatErrorMessage(values)
 
 
-class HttpServer(http.HttpBase):
+  def _FormatErrorMessage(self, values):
+    """Formats the body of an error message.
+
+    @type values: dict
+    @param values: dictionary with keys code, message and explain.
+    @rtype: string
+    @return: the body of the message
+
+    """
+    return self.error_message_format % values
+
+
+class HttpServer(http.HttpBase, asyncore.dispatcher):
   """Generic HTTP server class
 
   Users of this class must subclass it and override the HandleRequest function.
   """Generic HTTP server class
 
   Users of this class must subclass it and override the HandleRequest function.
@@ -388,7 +448,8 @@ class HttpServer(http.HttpBase):
   MAX_CHILDREN = 20
 
   def __init__(self, mainloop, local_address, port,
   MAX_CHILDREN = 20
 
   def __init__(self, mainloop, local_address, port,
-               ssl_params=None, ssl_verify_peer=False):
+               ssl_params=None, ssl_verify_peer=False,
+               request_executor_class=None):
     """Initializes the HTTP server
 
     @type mainloop: ganeti.daemon.Mainloop
     """Initializes the HTTP server
 
     @type mainloop: ganeti.daemon.Mainloop
@@ -400,24 +461,33 @@ class HttpServer(http.HttpBase):
     @type ssl_params: HttpSslParams
     @param ssl_params: SSL key and certificate
     @type ssl_verify_peer: bool
     @type ssl_params: HttpSslParams
     @param ssl_params: SSL key and certificate
     @type ssl_verify_peer: bool
-    @param ssl_verify_peer: Whether to require client certificate and compare
-                            it with our certificate
+    @param ssl_verify_peer: Whether to require client certificate
+        and compare it with our certificate
+    @type request_executor_class: class
+    @param request_executor_class: an class derived from the
+        HttpServerRequestExecutor class
 
     """
     http.HttpBase.__init__(self)
 
     """
     http.HttpBase.__init__(self)
+    asyncore.dispatcher.__init__(self)
+
+    if request_executor_class is None:
+      self.request_executor = HttpServerRequestExecutor
+    else:
+      self.request_executor = request_executor_class
 
     self.mainloop = mainloop
     self.local_address = local_address
     self.port = port
 
     self.mainloop = mainloop
     self.local_address = local_address
     self.port = port
-
-    self.socket = self._CreateSocket(ssl_params, ssl_verify_peer)
+    family = netutils.IPAddress.GetAddressFamily(local_address)
+    self.socket = self._CreateSocket(ssl_params, ssl_verify_peer, family)
 
     # Allow port to be reused
     self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 
     self._children = []
 
     # Allow port to be reused
     self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 
     self._children = []
-
-    mainloop.RegisterIO(self, self.socket.fileno(), select.POLLIN)
+    self.set_socket(self.socket)
+    self.accepting = True
     mainloop.RegisterSignal(self)
 
   def Start(self):
     mainloop.RegisterSignal(self)
 
   def Start(self):
@@ -427,9 +497,8 @@ class HttpServer(http.HttpBase):
   def Stop(self):
     self.socket.close()
 
   def Stop(self):
     self.socket.close()
 
-  def OnIO(self, fd, condition):
-    if condition & select.POLLIN:
-      self._IncomingConnection()
+  def handle_accept(self):
+    self._IncomingConnection()
 
   def OnSignal(self, signum):
     if signum == signal.SIGCHLD:
 
   def OnSignal(self, signum):
     if signum == signal.SIGCHLD:
@@ -450,7 +519,7 @@ class HttpServer(http.HttpBase):
           # As soon as too many children run, we'll not respond to new
           # requests. The real solution would be to add a timeout for children
           # and killing them after some time.
           # As soon as too many children run, we'll not respond to new
           # requests. The real solution would be to add a timeout for children
           # and killing them after some time.
-          pid, status = os.waitpid(0, 0)
+          pid, _ = os.waitpid(0, 0)
         except os.error:
           pid = None
         if pid and pid in self._children:
         except os.error:
           pid = None
         if pid and pid in self._children:
@@ -458,7 +527,7 @@ class HttpServer(http.HttpBase):
 
     for child in self._children:
       try:
 
     for child in self._children:
       try:
-        pid, status = os.waitpid(child, os.WNOHANG)
+        pid, _ = os.waitpid(child, os.WNOHANG)
       except os.error:
         pid = None
       if pid and pid in self._children:
       except os.error:
         pid = None
       if pid and pid in self._children:
@@ -468,6 +537,7 @@ class HttpServer(http.HttpBase):
     """Called for each incoming connection
 
     """
     """Called for each incoming connection
 
     """
+    # pylint: disable-msg=W0212
     (connection, client_addr) = self.socket.accept()
 
     self._CollectChildren(False)
     (connection, client_addr) = self.socket.accept()
 
     self._CollectChildren(False)
@@ -476,8 +546,21 @@ class HttpServer(http.HttpBase):
     if pid == 0:
       # Child process
       try:
     if pid == 0:
       # Child process
       try:
-        _HttpServerRequestExecutor(self, connection, client_addr)
-      except Exception:
+        # The client shouldn't keep the listening socket open. If the parent
+        # process is restarted, it would fail when there's already something
+        # listening (in this case its own child from a previous run) on the
+        # same port.
+        try:
+          self.socket.close()
+        except socket.error:
+          pass
+        self.socket = None
+
+        # In case the handler code uses temporary files
+        utils.ResetTempfileModule()
+
+        self.request_executor(self, connection, client_addr)
+      except Exception: # pylint: disable-msg=W0703
         logging.exception("Error while handling request from %s:%s",
                           client_addr[0], client_addr[1])
         os._exit(1)
         logging.exception("Error while handling request from %s:%s",
                           client_addr[0], client_addr[1])
         os._exit(1)
@@ -485,10 +568,17 @@ class HttpServer(http.HttpBase):
     else:
       self._children.append(pid)
 
     else:
       self._children.append(pid)
 
+  def PreHandleRequest(self, req):
+    """Called before handling a request.
+
+    Can be overridden by a subclass.
+
+    """
+
   def HandleRequest(self, req):
     """Handles a request.
 
   def HandleRequest(self, req):
     """Handles a request.
 
-    Must be overriden by subclass.
+    Must be overridden by subclass.
 
     """
     raise NotImplementedError()
 
     """
     raise NotImplementedError()