ganeti.http: Add another class to contain SSL key and certificate
[ganeti-local] / lib / http.py
1 #
2 #
3 # This program is free software; you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation; either version 2 of the License, or
6 # (at your option) any later version.
7 #
8 # This program is distributed in the hope that it will be useful, but
9 # WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
11 # General Public License for more details.
12 #
13 # You should have received a copy of the GNU General Public License
14 # along with this program; if not, write to the Free Software
15 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
16 # 02110-1301, USA.
17
18 """HTTP server module.
19
20 """
21
22 import BaseHTTPServer
23 import cgi
24 import logging
25 import mimetools
26 import OpenSSL
27 import os
28 import select
29 import socket
30 import sys
31 import time
32 import signal
33 import logging
34 import errno
35 import threading
36
37 from cStringIO import StringIO
38
39 from ganeti import constants
40 from ganeti import serializer
41 from ganeti import workerpool
42 from ganeti import utils
43
44
45 HTTP_CLIENT_THREADS = 10
46
47 HTTP_GANETI_VERSION = "Ganeti %s" % constants.RELEASE_VERSION
48
49 WEEKDAYNAME = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
50 MONTHNAME = [None,
51              'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
52              'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
53
54 # Default error message
55 DEFAULT_ERROR_CONTENT_TYPE = "text/html"
56 DEFAULT_ERROR_MESSAGE = """\
57 <head>
58 <title>Error response</title>
59 </head>
60 <body>
61 <h1>Error response</h1>
62 <p>Error code %(code)d.
63 <p>Message: %(message)s.
64 <p>Error code explanation: %(code)s = %(explain)s.
65 </body>
66 """
67
68 HTTP_OK = 200
69 HTTP_NO_CONTENT = 204
70 HTTP_NOT_MODIFIED = 304
71
72 HTTP_0_9 = "HTTP/0.9"
73 HTTP_1_0 = "HTTP/1.0"
74 HTTP_1_1 = "HTTP/1.1"
75
76 HTTP_GET = "GET"
77 HTTP_HEAD = "HEAD"
78 HTTP_POST = "POST"
79 HTTP_PUT = "PUT"
80
81 HTTP_ETAG = "ETag"
82 HTTP_HOST = "Host"
83 HTTP_SERVER = "Server"
84 HTTP_DATE = "Date"
85 HTTP_USER_AGENT = "User-Agent"
86 HTTP_CONTENT_TYPE = "Content-Type"
87 HTTP_CONTENT_LENGTH = "Content-Length"
88 HTTP_CONNECTION = "Connection"
89 HTTP_KEEP_ALIVE = "Keep-Alive"
90
91 _SSL_UNEXPECTED_EOF = "Unexpected EOF"
92
93
94 class SocketClosed(socket.error):
95   pass
96
97
98 class _HttpClientError(Exception):
99   """Internal exception for HTTP client errors.
100
101   This should only be used for internal error reporting.
102
103   """
104   pass
105
106
107 class HTTPException(Exception):
108   code = None
109   message = None
110
111   def __init__(self, message=None):
112     Exception.__init__(self)
113     if message is not None:
114       self.message = message
115
116
117 class HTTPBadRequest(HTTPException):
118   code = 400
119
120
121 class HTTPForbidden(HTTPException):
122   code = 403
123
124
125 class HTTPNotFound(HTTPException):
126   code = 404
127
128
129 class HTTPGone(HTTPException):
130   code = 410
131
132
133 class HTTPLengthRequired(HTTPException):
134   code = 411
135
136
137 class HTTPInternalError(HTTPException):
138   code = 500
139
140
141 class HTTPNotImplemented(HTTPException):
142   code = 501
143
144
145 class HTTPServiceUnavailable(HTTPException):
146   code = 503
147
148
149 class HTTPVersionNotSupported(HTTPException):
150   code = 505
151
152
153 class ApacheLogfile:
154   """Utility class to write HTTP server log files.
155
156   The written format is the "Common Log Format" as defined by Apache:
157   http://httpd.apache.org/docs/2.2/mod/mod_log_config.html#examples
158
159   """
160   def __init__(self, fd):
161     """Constructor for ApacheLogfile class.
162
163     Args:
164     - fd: Open file object
165
166     """
167     self._fd = fd
168
169   def LogRequest(self, request, format, *args):
170     self._fd.write("%s %s %s [%s] %s\n" % (
171       # Remote host address
172       request.address_string(),
173
174       # RFC1413 identity (identd)
175       "-",
176
177       # Remote user
178       "-",
179
180       # Request time
181       self._FormatCurrentTime(),
182
183       # Message
184       format % args,
185       ))
186     self._fd.flush()
187
188   def _FormatCurrentTime(self):
189     """Formats current time in Common Log Format.
190
191     """
192     return self._FormatLogTime(time.time())
193
194   def _FormatLogTime(self, seconds):
195     """Formats time for Common Log Format.
196
197     All timestamps are logged in the UTC timezone.
198
199     Args:
200     - seconds: Time in seconds since the epoch
201
202     """
203     (_, month, _, _, _, _, _, _, _) = tm = time.gmtime(seconds)
204     format = "%d/" + MONTHNAME[month] + "/%Y:%H:%M:%S +0000"
205     return time.strftime(format, tm)
206
207
208 class HTTPJsonConverter:
209   CONTENT_TYPE = "application/json"
210
211   def Encode(self, data):
212     return serializer.DumpJson(data)
213
214   def Decode(self, data):
215     return serializer.LoadJson(data)
216
217
218 class HttpSslParams(object):
219   """Data class for SSL key and certificate.
220
221   """
222   def __init__(self, ssl_key_path, ssl_cert_path):
223     """Initializes this class.
224
225     @type ssl_key_path: string
226     @param ssl_key_path: Path to file containing SSL key in PEM format
227     @type ssl_cert_path: string
228     @param ssl_cert_path: Path to file containing SSL certificate in PEM format
229
230     """
231     ssl_key_pem = utils.ReadFile(ssl_key_path)
232     ssl_cert_pem = utils.ReadFile(ssl_cert_path)
233
234     cr = OpenSSL.crypto
235     self.cert = cr.load_certificate(cr.FILETYPE_PEM, ssl_cert_pem)
236     self.key = cr.load_privatekey(cr.FILETYPE_PEM, ssl_key_pem)
237     del cr
238
239
240 class _HttpSocketBase(object):
241   """Base class for HTTP server and client.
242
243   """
244   def __init__(self):
245     self._using_ssl = None
246     self._ssl_params = None
247
248   def _CreateSocket(self, ssl_params, ssl_verify_peer):
249     """Creates a TCP socket and initializes SSL if needed.
250
251     @type ssl_params: HttpSslParams
252     @param ssl_params: SSL key and certificate
253     @type ssl_verify_peer: bool
254     @param ssl_verify_peer: Whether to require client certificate and compare
255                             it with our certificate
256
257     """
258     self._ssl_params = ssl_params
259
260     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
261
262     # Should we enable SSL?
263     self._using_ssl = ssl_params is not None
264
265     if not self._using_ssl:
266       return sock
267
268     ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
269     ctx.set_options(OpenSSL.SSL.OP_NO_SSLv2)
270
271     ctx.use_privatekey(ssl_params.key)
272     ctx.use_certificate(ssl_params.cert)
273     ctx.check_privatekey()
274
275     if ssl_verify_peer:
276       ctx.set_verify(OpenSSL.SSL.VERIFY_PEER |
277                      OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
278                      self._SSLVerifyCallback)
279
280     return OpenSSL.SSL.Connection(ctx, sock)
281
282   def _SSLVerifyCallback(self, conn, cert, errnum, errdepth, ok):
283     """Verify the certificate provided by the peer
284
285     We only compare fingerprints. The client must use the same certificate as
286     we do on our side.
287
288     """
289     assert self._ssl_params, "SSL not initialized"
290
291     mykey = self._ssl_params.key
292     mycert = self._ssl_params.cert
293
294     return (mycert.digest("sha1") == cert.digest("sha1") and
295             mycert.digest("md5") == cert.digest("md5"))
296
297
298 class _HttpConnectionHandler(object):
299   """Implements server side of HTTP
300
301   This class implements the server side of HTTP. It's based on code of Python's
302   BaseHTTPServer, from both version 2.4 and 3k. It does not support non-ASCII
303   character encodings. Keep-alive connections are not supported.
304
305   """
306   # The default request version.  This only affects responses up until
307   # the point where the request line is parsed, so it mainly decides what
308   # the client gets back when sending a malformed request line.
309   # Most web servers default to HTTP 0.9, i.e. don't send a status line.
310   default_request_version = HTTP_0_9
311
312   # Error message settings
313   error_message_format = DEFAULT_ERROR_MESSAGE
314   error_content_type = DEFAULT_ERROR_CONTENT_TYPE
315
316   responses = BaseHTTPServer.BaseHTTPRequestHandler.responses
317
318   def __init__(self, server, conn, client_addr, fileio_class):
319     """Initializes this class.
320
321     Part of the initialization is reading the request and eventual POST/PUT
322     data sent by the client.
323
324     """
325     self._server = server
326
327     # We default rfile to buffered because otherwise it could be
328     # really slow for large data (a getc() call per byte); we make
329     # wfile unbuffered because (a) often after a write() we want to
330     # read and we need to flush the line; (b) big writes to unbuffered
331     # files are typically optimized by stdio even when big reads
332     # aren't.
333     self.rfile = fileio_class(conn, mode="rb", bufsize=-1)
334     self.wfile = fileio_class(conn, mode="wb", bufsize=0)
335
336     self.client_addr = client_addr
337
338     self.request_headers = None
339     self.request_method = None
340     self.request_path = None
341     self.request_requestline = None
342     self.request_version = self.default_request_version
343
344     self.response_body = None
345     self.response_code = HTTP_OK
346     self.response_content_type = None
347     self.response_headers = {}
348
349     self.should_fork = False
350
351     try:
352       self._ReadRequest()
353       self._ReadPostData()
354     except HTTPException, err:
355       self._SetErrorStatus(err)
356
357   def Close(self):
358     if not self.wfile.closed:
359       self.wfile.flush()
360     self.wfile.close()
361     self.rfile.close()
362
363   def _DateTimeHeader(self):
364     """Return the current date and time formatted for a message header.
365
366     """
367     (year, month, day, hh, mm, ss, wd, _, _) = time.gmtime()
368     return ("%s, %02d %3s %4d %02d:%02d:%02d GMT" %
369             (WEEKDAYNAME[wd], day, MONTHNAME[month], year, hh, mm, ss))
370
371   def _SetErrorStatus(self, err):
372     """Sets the response code and body from a HTTPException.
373
374     @type err: HTTPException
375     @param err: Exception instance
376
377     """
378     try:
379       (shortmsg, longmsg) = self.responses[err.code]
380     except KeyError:
381       shortmsg = longmsg = "Unknown"
382
383     if err.message:
384       message = err.message
385     else:
386       message = shortmsg
387
388     values = {
389       "code": err.code,
390       "message": cgi.escape(message),
391       "explain": longmsg,
392       }
393
394     self.response_code = err.code
395     self.response_content_type = self.error_content_type
396     self.response_body = self.error_message_format % values
397
398   def HandleRequest(self):
399     """Handle the actual request.
400
401     Calls the actual handler function and converts exceptions into HTTP errors.
402
403     """
404     # Don't do anything if there's already been a problem
405     if self.response_code != HTTP_OK:
406       return
407
408     assert self.request_method, "Status code %s requires a method" % HTTP_OK
409
410     # Check whether client is still there
411     self.rfile.read(0)
412
413     try:
414       try:
415         result = self._server.HandleRequest(self)
416
417         # TODO: Content-type
418         encoder = HTTPJsonConverter()
419         body = encoder.Encode(result)
420
421         self.response_content_type = encoder.CONTENT_TYPE
422         self.response_body = body
423       except (HTTPException, KeyboardInterrupt, SystemExit):
424         raise
425       except Exception, err:
426         logging.exception("Caught exception")
427         raise HTTPInternalError(message=str(err))
428       except:
429         logging.exception("Unknown exception")
430         raise HTTPInternalError(message="Unknown error")
431
432     except HTTPException, err:
433       self._SetErrorStatus(err)
434
435   def SendResponse(self):
436     """Sends response to the client.
437
438     """
439     # Check whether client is still there
440     self.rfile.read(0)
441
442     logging.info("%s:%s %s %s", self.client_addr[0], self.client_addr[1],
443                  self.request_requestline, self.response_code)
444
445     if self.response_code in self.responses:
446       response_message = self.responses[self.response_code][0]
447     else:
448       response_message = ""
449
450     if self.request_version != HTTP_0_9:
451       self.wfile.write("%s %d %s\r\n" %
452                        (self.request_version, self.response_code,
453                         response_message))
454       self._SendHeader(HTTP_SERVER, HTTP_GANETI_VERSION)
455       self._SendHeader(HTTP_DATE, self._DateTimeHeader())
456       self._SendHeader(HTTP_CONTENT_TYPE, self.response_content_type)
457       self._SendHeader(HTTP_CONTENT_LENGTH, str(len(self.response_body)))
458       for key, val in self.response_headers.iteritems():
459         self._SendHeader(key, val)
460
461       # We don't support keep-alive at this time
462       self._SendHeader(HTTP_CONNECTION, "close")
463       self.wfile.write("\r\n")
464
465     if (self.request_method != HTTP_HEAD and
466         self.response_code >= HTTP_OK and
467         self.response_code not in (HTTP_NO_CONTENT, HTTP_NOT_MODIFIED)):
468       self.wfile.write(self.response_body)
469
470   def _SendHeader(self, name, value):
471     if self.request_version != HTTP_0_9:
472       self.wfile.write("%s: %s\r\n" % (name, value))
473
474   def _ReadRequest(self):
475     """Reads and parses request line
476
477     """
478     raw_requestline = self.rfile.readline()
479
480     requestline = raw_requestline
481     if requestline[-2:] == '\r\n':
482       requestline = requestline[:-2]
483     elif requestline[-1:] == '\n':
484       requestline = requestline[:-1]
485
486     if not requestline:
487       raise HTTPBadRequest("Empty request line")
488
489     self.request_requestline = requestline
490
491     logging.debug("HTTP request: %s", raw_requestline.rstrip("\r\n"))
492
493     words = requestline.split()
494
495     if len(words) == 3:
496       [method, path, version] = words
497       if version[:5] != 'HTTP/':
498         raise HTTPBadRequest("Bad request version (%r)" % version)
499
500       try:
501         base_version_number = version.split('/', 1)[1]
502         version_number = base_version_number.split(".")
503
504         # RFC 2145 section 3.1 says there can be only one "." and
505         #   - major and minor numbers MUST be treated as
506         #      separate integers;
507         #   - HTTP/2.4 is a lower version than HTTP/2.13, which in
508         #      turn is lower than HTTP/12.3;
509         #   - Leading zeros MUST be ignored by recipients.
510         if len(version_number) != 2:
511           raise HTTPBadRequest("Bad request version (%r)" % version)
512
513         version_number = int(version_number[0]), int(version_number[1])
514       except (ValueError, IndexError):
515         raise HTTPBadRequest("Bad request version (%r)" % version)
516
517       if version_number >= (2, 0):
518         raise HTTPVersionNotSupported("Invalid HTTP Version (%s)" %
519                                       base_version_number)
520
521     elif len(words) == 2:
522       version = HTTP_0_9
523       [method, path] = words
524       if method != HTTP_GET:
525         raise HTTPBadRequest("Bad HTTP/0.9 request type (%r)" % method)
526
527     else:
528       raise HTTPBadRequest("Bad request syntax (%r)" % requestline)
529
530     # Examine the headers and look for a Connection directive
531     headers = mimetools.Message(self.rfile, 0)
532
533     self.request_method = method
534     self.request_path = path
535     self.request_version = version
536     self.request_headers = headers
537
538   def _ReadPostData(self):
539     """Reads POST/PUT data
540
541     Quoting RFC1945, section 7.2 (HTTP/1.0): "The presence of an entity body in
542     a request is signaled by the inclusion of a Content-Length header field in
543     the request message headers. HTTP/1.0 requests containing an entity body
544     must include a valid Content-Length header field."
545
546     """
547     # While not according to specification, we only support an entity body for
548     # POST and PUT.
549     if (not self.request_method or
550         self.request_method.upper() not in (HTTP_POST, HTTP_PUT)):
551       self.request_post_data = None
552       return
553
554     content_length = None
555     try:
556       if HTTP_CONTENT_LENGTH in self.request_headers:
557         content_length = int(self.request_headers[HTTP_CONTENT_LENGTH])
558     except TypeError:
559       pass
560     except ValueError:
561       pass
562
563     # 411 Length Required is specified in RFC2616, section 10.4.12 (HTTP/1.1)
564     if content_length is None:
565       raise HTTPLengthRequired("Missing Content-Length header or"
566                                " invalid format")
567
568     data = self.rfile.read(content_length)
569
570     # TODO: Content-type, error handling
571     if data:
572       self.request_post_data = HTTPJsonConverter().Decode(data)
573     else:
574       self.request_post_data = None
575
576     logging.debug("HTTP POST data: %s", self.request_post_data)
577
578
579 class HttpServer(_HttpSocketBase):
580   """Generic HTTP server class
581
582   Users of this class must subclass it and override the HandleRequest function.
583
584   """
585   MAX_CHILDREN = 20
586
587   def __init__(self, mainloop, local_address, port,
588                ssl_params=None, ssl_verify_peer=False):
589     """Initializes the HTTP server
590
591     @type mainloop: ganeti.daemon.Mainloop
592     @param mainloop: Mainloop used to poll for I/O events
593     @type local_addess: string
594     @param local_address: Local IP address to bind to
595     @type port: int
596     @param port: TCP port to listen on
597     @type ssl_params: HttpSslParams
598     @param ssl_params: SSL key and certificate
599     @type ssl_verify_peer: bool
600     @param ssl_verify_peer: Whether to require client certificate and compare
601                             it with our certificate
602
603     """
604     _HttpSocketBase.__init__(self)
605
606     self.mainloop = mainloop
607     self.local_address = local_address
608     self.port = port
609
610     self.socket = self._CreateSocket(ssl_params, ssl_verify_peer)
611
612     # Allow port to be reused
613     self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
614
615     if self._using_ssl:
616       self._fileio_class = _SSLFileObject
617     else:
618       self._fileio_class = socket._fileobject
619
620     self._children = []
621
622     mainloop.RegisterIO(self, self.socket.fileno(), select.POLLIN)
623     mainloop.RegisterSignal(self)
624
625   def Start(self):
626     self.socket.bind((self.local_address, self.port))
627     self.socket.listen(5)
628
629   def Stop(self):
630     self.socket.close()
631
632   def OnIO(self, fd, condition):
633     if condition & select.POLLIN:
634       self._IncomingConnection()
635
636   def OnSignal(self, signum):
637     if signum == signal.SIGCHLD:
638       self._CollectChildren(True)
639
640   def _CollectChildren(self, quick):
641     """Checks whether any child processes are done
642
643     @type quick: bool
644     @param quick: Whether to only use non-blocking functions
645
646     """
647     if not quick:
648       # Don't wait for other processes if it should be a quick check
649       while len(self._children) > self.MAX_CHILDREN:
650         try:
651           # Waiting without a timeout brings us into a potential DoS situation.
652           # As soon as too many children run, we'll not respond to new
653           # requests. The real solution would be to add a timeout for children
654           # and killing them after some time.
655           pid, status = os.waitpid(0, 0)
656         except os.error:
657           pid = None
658         if pid and pid in self._children:
659           self._children.remove(pid)
660
661     for child in self._children:
662       try:
663         pid, status = os.waitpid(child, os.WNOHANG)
664       except os.error:
665         pid = None
666       if pid and pid in self._children:
667         self._children.remove(pid)
668
669   def _IncomingConnection(self):
670     """Called for each incoming connection
671
672     """
673     (connection, client_addr) = self.socket.accept()
674
675     self._CollectChildren(False)
676
677     pid = os.fork()
678     if pid == 0:
679       # Child process
680       logging.info("Connection from %s:%s", client_addr[0], client_addr[1])
681
682       try:
683         try:
684           try:
685             handler = None
686             try:
687               # Read, parse and handle request
688               handler = _HttpConnectionHandler(self, connection, client_addr,
689                                                self._fileio_class)
690               handler.HandleRequest()
691             finally:
692               # Try to send a response
693               if handler:
694                 handler.SendResponse()
695                 handler.Close()
696           except SocketClosed:
697             pass
698         finally:
699           logging.info("Disconnected %s:%s", client_addr[0], client_addr[1])
700       except:
701         logging.exception("Error while handling request from %s:%s",
702                           client_addr[0], client_addr[1])
703         os._exit(1)
704       os._exit(0)
705     else:
706       self._children.append(pid)
707
708   def HandleRequest(self, req):
709     raise NotImplementedError()
710
711
712 class HttpClientRequest(object):
713   def __init__(self, host, port, method, path, headers=None, post_data=None,
714                ssl_params=None, ssl_verify_peer=False):
715     """Describes an HTTP request.
716
717     @type host: string
718     @param host: Hostname
719     @type port: int
720     @param port: Port
721     @type method: string
722     @param method: Method name
723     @type path: string
724     @param path: Request path
725     @type headers: dict or None
726     @param headers: Additional headers to send
727     @type post_data: string or None
728     @param post_data: Additional data to send
729     @type ssl_params: HttpSslParams
730     @param ssl_params: SSL key and certificate
731     @type ssl_verify_peer: bool
732     @param ssl_verify_peer: Whether to compare our certificate with server's
733                             certificate
734
735     """
736     if post_data is not None:
737       assert method.upper() in (HTTP_POST, HTTP_PUT), \
738         "Only POST and GET requests support sending data"
739
740     assert path.startswith("/"), "Path must start with slash (/)"
741
742     self.host = host
743     self.port = port
744     self.ssl_params = ssl_params
745     self.ssl_verify_peer = ssl_verify_peer
746     self.method = method
747     self.path = path
748     self.headers = headers
749     self.post_data = post_data
750
751     self.success = None
752     self.error = None
753
754     self.resp_status_line = None
755     self.resp_version = None
756     self.resp_status = None
757     self.resp_reason = None
758     self.resp_headers = None
759     self.resp_body = None
760
761
762 class HttpClientRequestExecutor(_HttpSocketBase):
763   # Default headers
764   DEFAULT_HEADERS = {
765     HTTP_USER_AGENT: HTTP_GANETI_VERSION,
766     # TODO: For keep-alive, don't send "Connection: close"
767     HTTP_CONNECTION: "close",
768     }
769
770   # Length limits
771   STATUS_LINE_LENGTH_MAX = 512
772   HEADER_LENGTH_MAX = 4 * 1024
773
774   # Timeouts in seconds for socket layer
775   # TODO: Make read timeout configurable per OpCode
776   CONNECT_TIMEOUT = 5.0
777   WRITE_TIMEOUT = 10
778   READ_TIMEOUT = None
779   CLOSE_TIMEOUT = 1
780
781   # Parser state machine
782   PS_STATUS_LINE = "status-line"
783   PS_HEADERS = "headers"
784   PS_BODY = "body"
785   PS_COMPLETE = "complete"
786
787   # Socket operations
788   (OP_SEND,
789    OP_RECV,
790    OP_CLOSE_CHECK,
791    OP_SHUTDOWN) = range(4)
792
793   def __init__(self, req):
794     """Initializes the HttpClientRequestExecutor class.
795
796     @type req: HttpClientRequest
797     @param req: Request object
798
799     """
800     _HttpSocketBase.__init__(self)
801
802     self.request = req
803
804     self.parser_status = self.PS_STATUS_LINE
805     self.header_buffer = StringIO()
806     self.body_buffer = StringIO()
807     self.content_length = None
808     self.server_will_close = None
809
810     self.poller = select.poll()
811
812     try:
813       # TODO: Implement connection caching/keep-alive
814       self.sock = self._CreateSocket(req.ssl_params,
815                                      req.ssl_verify_peer)
816
817       # Disable Python's timeout
818       self.sock.settimeout(None)
819
820       # Operate in non-blocking mode
821       self.sock.setblocking(0)
822
823       force_close = True
824       self._Connect()
825       try:
826         self._SendRequest()
827         self._ReadResponse()
828
829         # Only wait for server to close if we didn't have any exception.
830         force_close = False
831       finally:
832         self._CloseConnection(force_close)
833
834       self.sock.close()
835       self.sock = None
836
837       req.resp_body = self.body_buffer.getvalue()
838
839       req.success = True
840       req.error = None
841
842     except _HttpClientError, err:
843       req.success = False
844       req.error = str(err)
845
846   def _BuildRequest(self):
847     """Build HTTP request.
848
849     @rtype: string
850     @return: Complete request
851
852     """
853     # Headers
854     send_headers = self.DEFAULT_HEADERS.copy()
855
856     if self.request.headers:
857       send_headers.update(self.request.headers)
858
859     send_headers[HTTP_HOST] = "%s:%s" % (self.request.host, self.request.port)
860
861     if self.request.post_data:
862       send_headers[HTTP_CONTENT_LENGTH] = len(self.request.post_data)
863
864     buf = StringIO()
865
866     # Add request line. We only support HTTP/1.0 (no chunked transfers and no
867     # keep-alive).
868     # TODO: For keep-alive, change to HTTP/1.1
869     buf.write("%s %s %s\r\n" % (self.request.method.upper(),
870                                 self.request.path, HTTP_1_0))
871
872     # Add headers
873     for name, value in send_headers.iteritems():
874       buf.write("%s: %s\r\n" % (name, value))
875
876     buf.write("\r\n")
877
878     if self.request.post_data:
879       buf.write(self.request.post_data)
880
881     return buf.getvalue()
882
883   def _ParseStatusLine(self):
884     """Parses the status line sent by the server.
885
886     """
887     line = self.request.resp_status_line
888
889     if not line:
890       raise _HttpClientError("Empty status line")
891
892     try:
893       [version, status, reason] = line.split(None, 2)
894     except ValueError:
895       try:
896         [version, status] = line.split(None, 1)
897         reason = ""
898       except ValueError:
899         version = HTTP_9_0
900
901     if version:
902       version = version.upper()
903
904     if version not in (HTTP_1_0, HTTP_1_1):
905       # We do not support HTTP/0.9, despite the specification requiring it
906       # (RFC2616, section 19.6)
907       raise _HttpClientError("Only HTTP/1.0 and HTTP/1.1 are supported (%r)" %
908                              line)
909
910     # The status code is a three-digit number
911     try:
912       status = int(status)
913       if status < 100 or status > 999:
914         status = -1
915     except ValueError:
916       status = -1
917
918     if status == -1:
919       raise _HttpClientError("Invalid status code (%r)" % line)
920
921     self.request.resp_version = version
922     self.request.resp_status = status
923     self.request.resp_reason = reason
924
925   def _WillServerCloseConnection(self):
926     """Evaluate whether server will close the connection.
927
928     @rtype: bool
929     @return: Whether server will close the connection
930
931     """
932     hdr_connection = self.request.resp_headers.get(HTTP_CONNECTION, None)
933     if hdr_connection:
934       hdr_connection = hdr_connection.lower()
935
936     # An HTTP/1.1 server is assumed to stay open unless explicitly closed.
937     if self.request.resp_version == HTTP_1_1:
938       return (hdr_connection and "close" in hdr_connection)
939
940     # Some HTTP/1.0 implementations have support for persistent connections,
941     # using rules different than HTTP/1.1.
942
943     # For older HTTP, Keep-Alive indicates persistent connection.
944     if self.request.resp_headers.get(HTTP_KEEP_ALIVE):
945       return False
946
947     # At least Akamai returns a "Connection: Keep-Alive" header, which was
948     # supposed to be sent by the client.
949     if hdr_connection and "keep-alive" in hdr_connection:
950       return False
951
952     return True
953
954   def _ParseHeaders(self):
955     """Parses the headers sent by the server.
956
957     This function also adjusts internal variables based on the header values.
958
959     """
960     req = self.request
961
962     # Parse headers
963     self.header_buffer.seek(0, 0)
964     req.resp_headers = mimetools.Message(self.header_buffer, 0)
965
966     self.server_will_close = self._WillServerCloseConnection()
967
968     # Do we have a Content-Length header?
969     hdr_content_length = req.resp_headers.get(HTTP_CONTENT_LENGTH, None)
970     if hdr_content_length:
971       try:
972         self.content_length = int(hdr_content_length)
973       except ValueError:
974         pass
975       if self.content_length is not None and self.content_length < 0:
976         self.content_length = None
977
978     # does the body have a fixed length? (of zero)
979     if (req.resp_status in (HTTP_NO_CONTENT, HTTP_NOT_MODIFIED) or
980         100 <= req.resp_status < 200 or req.method == HTTP_HEAD):
981       self.content_length = 0
982
983     # if the connection remains open and a content-length was not provided,
984     # then assume that the connection WILL close.
985     if self.content_length is None:
986       self.server_will_close = True
987
988   def _CheckStatusLineLength(self, length):
989     if length > self.STATUS_LINE_LENGTH_MAX:
990       raise _HttpClientError("Status line longer than %d chars" %
991                              self.STATUS_LINE_LENGTH_MAX)
992
993   def _CheckHeaderLength(self, length):
994     if length > self.HEADER_LENGTH_MAX:
995       raise _HttpClientError("Headers longer than %d chars" %
996                              self.HEADER_LENGTH_MAX)
997
998   def _ParseBuffer(self, buf, eof):
999     """Main function for HTTP response state machine.
1000
1001     @type buf: string
1002     @param buf: Receive buffer
1003     @type eof: bool
1004     @param eof: Whether we've reached EOF on the socket
1005     @rtype: string
1006     @return: Updated receive buffer
1007
1008     """
1009     if self.parser_status == self.PS_STATUS_LINE:
1010       # Expect status line
1011       idx = buf.find("\r\n")
1012       if idx >= 0:
1013         self.request.resp_status_line = buf[:idx]
1014
1015         self._CheckStatusLineLength(len(self.request.resp_status_line))
1016
1017         # Remove status line, including CRLF
1018         buf = buf[idx + 2:]
1019
1020         self._ParseStatusLine()
1021
1022         self.parser_status = self.PS_HEADERS
1023       else:
1024         # Check whether incoming data is getting too large, otherwise we just
1025         # fill our read buffer.
1026         self._CheckStatusLineLength(len(buf))
1027
1028     if self.parser_status == self.PS_HEADERS:
1029       # Wait for header end
1030       idx = buf.find("\r\n\r\n")
1031       if idx >= 0:
1032         self.header_buffer.write(buf[:idx + 2])
1033
1034         self._CheckHeaderLength(self.header_buffer.tell())
1035
1036         # Remove headers, including CRLF
1037         buf = buf[idx + 4:]
1038
1039         self._ParseHeaders()
1040
1041         self.parser_status = self.PS_BODY
1042       else:
1043         # Check whether incoming data is getting too large, otherwise we just
1044         # fill our read buffer.
1045         self._CheckHeaderLength(len(buf))
1046
1047     if self.parser_status == self.PS_BODY:
1048       self.body_buffer.write(buf)
1049       buf = ""
1050
1051       # Check whether we've read everything
1052       if (eof or
1053           (self.content_length is not None and
1054            self.body_buffer.tell() >= self.content_length)):
1055         self.parser_status = self.PS_COMPLETE
1056
1057     return buf
1058
1059   def _WaitForCondition(self, event, timeout):
1060     """Waits for a condition to occur on the socket.
1061
1062     @type event: int
1063     @param event: ORed condition (see select module)
1064     @type timeout: float or None
1065     @param timeout: Timeout in seconds
1066     @rtype: int or None
1067     @return: None for timeout, otherwise occured conditions
1068
1069     """
1070     check = (event | select.POLLPRI |
1071              select.POLLNVAL | select.POLLHUP | select.POLLERR)
1072
1073     if timeout is not None:
1074       # Poller object expects milliseconds
1075       timeout *= 1000
1076
1077     self.poller.register(self.sock, event)
1078     try:
1079       while True:
1080         # TODO: If the main thread receives a signal and we have no timeout, we
1081         # could wait forever. This should check a global "quit" flag or
1082         # something every so often.
1083         io_events = self.poller.poll(timeout)
1084         if io_events:
1085           for (evfd, evcond) in io_events:
1086             if evcond & check:
1087               return evcond
1088         else:
1089           # Timeout
1090           return None
1091     finally:
1092       self.poller.unregister(self.sock)
1093
1094   def _SocketOperation(self, op, arg1, error_msg, timeout_msg):
1095     """Wrapper around socket functions.
1096
1097     This function abstracts error handling for socket operations, especially
1098     for the complicated interaction with OpenSSL.
1099
1100     """
1101     if op == self.OP_SEND:
1102       event_poll = select.POLLOUT
1103       event_check = select.POLLOUT
1104       timeout = self.WRITE_TIMEOUT
1105
1106     elif op in (self.OP_RECV, self.OP_CLOSE_CHECK):
1107       event_poll = select.POLLIN
1108       event_check = select.POLLIN | select.POLLPRI
1109       if op == self.OP_CLOSE_CHECK:
1110         timeout = self.CLOSE_TIMEOUT
1111       else:
1112         timeout = self.READ_TIMEOUT
1113
1114     elif op == self.OP_SHUTDOWN:
1115       event_poll = None
1116       event_check = None
1117
1118       # The timeout is only used when OpenSSL requests polling for a condition.
1119       # It is not advisable to have no timeout for shutdown.
1120       timeout = self.WRITE_TIMEOUT
1121
1122     else:
1123       raise AssertionError("Invalid socket operation")
1124
1125     # No override by default
1126     event_override = 0
1127
1128     while True:
1129       # Poll only for certain operations and when asked for by an override
1130       if (event_override or
1131           op in (self.OP_SEND, self.OP_RECV, self.OP_CLOSE_CHECK)):
1132         if event_override:
1133           wait_for_event = event_override
1134         else:
1135           wait_for_event = event_poll
1136
1137         event = self._WaitForCondition(wait_for_event, timeout)
1138         if event is None:
1139           raise _HttpClientTimeout(timeout_msg)
1140
1141         if (op == self.OP_RECV and
1142             event & (select.POLLNVAL | select.POLLHUP | select.POLLERR)):
1143           return ""
1144
1145         if not event & wait_for_event:
1146           continue
1147
1148       # Reset override
1149       event_override = 0
1150
1151       try:
1152         try:
1153           if op == self.OP_SEND:
1154             return self.sock.send(arg1)
1155
1156           elif op in (self.OP_RECV, self.OP_CLOSE_CHECK):
1157             return self.sock.recv(arg1)
1158
1159           elif op == self.OP_SHUTDOWN:
1160             if self._using_ssl:
1161               # PyOpenSSL's shutdown() doesn't take arguments
1162               return self.sock.shutdown()
1163             else:
1164               return self.sock.shutdown(arg1)
1165
1166         except OpenSSL.SSL.WantWriteError:
1167           # OpenSSL wants to write, poll for POLLOUT
1168           event_override = select.POLLOUT
1169           continue
1170
1171         except OpenSSL.SSL.WantReadError:
1172           # OpenSSL wants to read, poll for POLLIN
1173           event_override = select.POLLIN | select.POLLPRI
1174           continue
1175
1176         except OpenSSL.SSL.WantX509LookupError:
1177           continue
1178
1179         except OpenSSL.SSL.SysCallError, err:
1180           if op == self.OP_SEND:
1181             # arg1 is the data when writing
1182             if err.args and err.args[0] == -1 and arg1 == "":
1183               # errors when writing empty strings are expected
1184               # and can be ignored
1185               return 0
1186
1187           elif op == self.OP_RECV:
1188             if err.args == (-1, _SSL_UNEXPECTED_EOF):
1189               return ""
1190
1191           raise socket.error(err.args)
1192
1193         except OpenSSL.SSL.Error, err:
1194           raise socket.error(err.args)
1195
1196       except socket.error, err:
1197         if err.args and err.args[0] == errno.EAGAIN:
1198           # Ignore EAGAIN
1199           continue
1200
1201         raise _HttpClientError("%s: %s" % (error_msg, str(err)))
1202
1203   def _Connect(self):
1204     """Non-blocking connect to host with timeout.
1205
1206     """
1207     connected = False
1208     while True:
1209       try:
1210         connect_error = self.sock.connect_ex((self.request.host,
1211                                               self.request.port))
1212       except socket.gaierror, err:
1213         raise _HttpClientError("Connection failed: %s" % str(err))
1214
1215       if connect_error == errno.EINTR:
1216         # Mask signals
1217         pass
1218
1219       elif connect_error == 0:
1220         # Connection established
1221         connected = True
1222         break
1223
1224       elif connect_error == errno.EINPROGRESS:
1225         # Connection started
1226         break
1227
1228       raise _HttpClientError("Connection failed (%s: %s)" %
1229                              (connect_error, os.strerror(connect_error)))
1230
1231     if not connected:
1232       # Wait for connection
1233       event = self._WaitForCondition(select.POLLOUT, self.CONNECT_TIMEOUT)
1234       if event is None:
1235         raise _HttpClientError("Timeout while connecting to server")
1236
1237       # Get error code
1238       connect_error = self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
1239       if connect_error != 0:
1240         raise _HttpClientError("Connection failed (%s: %s)" %
1241                                (connect_error, os.strerror(connect_error)))
1242
1243     # Enable TCP keep-alive
1244     self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
1245
1246     # If needed, Linux specific options are available to change the TCP
1247     # keep-alive settings, see "man 7 tcp" for TCP_KEEPCNT, TCP_KEEPIDLE and
1248     # TCP_KEEPINTVL.
1249
1250   def _SendRequest(self):
1251     """Sends request to server.
1252
1253     """
1254     buf = self._BuildRequest()
1255
1256     while buf:
1257       # Send only 4 KB at a time
1258       data = buf[:4096]
1259
1260       sent = self._SocketOperation(self.OP_SEND, data,
1261                                    "Error while sending request",
1262                                    "Timeout while sending request")
1263
1264       # Remove sent bytes
1265       buf = buf[sent:]
1266
1267     assert not buf, "Request wasn't sent completely"
1268
1269   def _ReadResponse(self):
1270     """Read response from server.
1271
1272     Calls the parser function after reading a chunk of data.
1273
1274     """
1275     buf = ""
1276     eof = False
1277     while self.parser_status != self.PS_COMPLETE:
1278       data = self._SocketOperation(self.OP_RECV, 4096,
1279                                    "Error while reading response",
1280                                    "Timeout while reading response")
1281
1282       if data:
1283         buf += data
1284       else:
1285         eof = True
1286
1287       # Do some parsing and error checking while more data arrives
1288       buf = self._ParseBuffer(buf, eof)
1289
1290       # Must be done only after the buffer has been evaluated
1291       if (eof and
1292           self.parser_status in (self.PS_STATUS_LINE,
1293                                  self.PS_HEADERS)):
1294         raise _HttpClientError("Connection closed prematurely")
1295
1296     # Parse rest
1297     buf = self._ParseBuffer(buf, True)
1298
1299     assert self.parser_status == self.PS_COMPLETE
1300     assert not buf, "Parser didn't read full response"
1301
1302   def _CloseConnection(self, force):
1303     """Closes the connection.
1304
1305     """
1306     if self.server_will_close and not force:
1307       # Wait for server to close
1308       try:
1309         # Check whether it's actually closed
1310         if not self._SocketOperation(self.OP_CLOSE_CHECK, 1,
1311                                      "Error", "Timeout"):
1312           return
1313       except (socket.error, _HttpClientError):
1314         # Ignore errors at this stage
1315         pass
1316
1317     # Close the connection from our side
1318     self._SocketOperation(self.OP_SHUTDOWN, socket.SHUT_RDWR,
1319                           "Error while shutting down connection",
1320                           "Timeout while shutting down connection")
1321
1322
1323 class _HttpClientPendingRequest(object):
1324   """Data class for pending requests.
1325
1326   """
1327   def __init__(self, request):
1328     self.request = request
1329
1330     # Thread synchronization
1331     self.done = threading.Event()
1332
1333
1334 class HttpClientWorker(workerpool.BaseWorker):
1335   """HTTP client worker class.
1336
1337   """
1338   def RunTask(self, pend_req):
1339     try:
1340       HttpClientRequestExecutor(pend_req.request)
1341     finally:
1342       pend_req.done.set()
1343
1344
1345 class HttpClientWorkerPool(workerpool.WorkerPool):
1346   def __init__(self, manager):
1347     workerpool.WorkerPool.__init__(self, HTTP_CLIENT_THREADS,
1348                                    HttpClientWorker)
1349     self.manager = manager
1350
1351
1352 class HttpClientManager(object):
1353   """Manages HTTP requests.
1354
1355   """
1356   def __init__(self):
1357     self._wpool = HttpClientWorkerPool(self)
1358
1359   def __del__(self):
1360     self.Shutdown()
1361
1362   def ExecRequests(self, requests):
1363     """Execute HTTP requests.
1364
1365     This function can be called from multiple threads at the same time.
1366
1367     @type requests: List of HttpClientRequest instances
1368     @param requests: The requests to execute
1369     @rtype: List of HttpClientRequest instances
1370     @returns: The list of requests passed in
1371
1372     """
1373     # _HttpClientPendingRequest is used for internal thread synchronization
1374     pending = [_HttpClientPendingRequest(req) for req in requests]
1375
1376     try:
1377       # Add requests to queue
1378       for pend_req in pending:
1379         self._wpool.AddTask(pend_req)
1380
1381     finally:
1382       # In case of an exception we should still wait for the rest, otherwise
1383       # another thread from the worker pool could modify the request object
1384       # after we returned.
1385
1386       # And wait for them to finish
1387       for pend_req in pending:
1388         pend_req.done.wait()
1389
1390     # Return original list
1391     return requests
1392
1393   def Shutdown(self):
1394     self._wpool.Quiesce()
1395     self._wpool.TerminateWorkers()
1396
1397
1398 class _SSLFileObject(object):
1399   """Wrapper around socket._fileobject
1400
1401   This wrapper is required to handle OpenSSL exceptions.
1402
1403   """
1404   def _RequireOpenSocket(fn):
1405     def wrapper(self, *args, **kwargs):
1406       if self.closed:
1407         raise SocketClosed("Socket is closed")
1408       return fn(self, *args, **kwargs)
1409     return wrapper
1410
1411   def __init__(self, sock, mode='rb', bufsize=-1):
1412     self._base = socket._fileobject(sock, mode=mode, bufsize=bufsize)
1413
1414   def _ConnectionLost(self):
1415     self._base = None
1416
1417   def _getclosed(self):
1418     return self._base is None or self._base.closed
1419   closed = property(_getclosed, doc="True if the file is closed")
1420
1421   @_RequireOpenSocket
1422   def close(self):
1423     return self._base.close()
1424
1425   @_RequireOpenSocket
1426   def flush(self):
1427     return self._base.flush()
1428
1429   @_RequireOpenSocket
1430   def fileno(self):
1431     return self._base.fileno()
1432
1433   @_RequireOpenSocket
1434   def read(self, size=-1):
1435     return self._ReadWrapper(self._base.read, size=size)
1436
1437   @_RequireOpenSocket
1438   def readline(self, size=-1):
1439     return self._ReadWrapper(self._base.readline, size=size)
1440
1441   def _ReadWrapper(self, fn, *args, **kwargs):
1442     while True:
1443       try:
1444         return fn(*args, **kwargs)
1445
1446       except OpenSSL.SSL.ZeroReturnError, err:
1447         self._ConnectionLost()
1448         return ""
1449
1450       except OpenSSL.SSL.WantReadError:
1451         continue
1452
1453       #except OpenSSL.SSL.WantWriteError:
1454       # TODO
1455
1456       except OpenSSL.SSL.SysCallError, (retval, desc):
1457         if ((retval == -1 and desc == _SSL_UNEXPECTED_EOF)
1458             or retval > 0):
1459           self._ConnectionLost()
1460           return ""
1461
1462         logging.exception("Error in OpenSSL")
1463         self._ConnectionLost()
1464         raise socket.error(err.args)
1465
1466       except OpenSSL.SSL.Error, err:
1467         self._ConnectionLost()
1468         raise socket.error(err.args)
1469
1470   @_RequireOpenSocket
1471   def write(self, data):
1472     return self._WriteWrapper(self._base.write, data)
1473
1474   def _WriteWrapper(self, fn, *args, **kwargs):
1475     while True:
1476       try:
1477         return fn(*args, **kwargs)
1478       except OpenSSL.SSL.ZeroReturnError, err:
1479         self._ConnectionLost()
1480         return 0
1481
1482       except OpenSSL.SSL.WantWriteError:
1483         continue
1484
1485       #except OpenSSL.SSL.WantReadError:
1486       # TODO
1487
1488       except OpenSSL.SSL.SysCallError, err:
1489         if err.args[0] == -1 and data == "":
1490           # errors when writing empty strings are expected
1491           # and can be ignored
1492           return 0
1493
1494         self._ConnectionLost()
1495         raise socket.error(err.args)
1496
1497       except OpenSSL.SSL.Error, err:
1498         self._ConnectionLost()
1499         raise socket.error(err.args)