Use hashmap lib in api.
[pithos] / pithos / api / util.py
index bff9fc3..1ac568f 100644 (file)
@@ -37,19 +37,25 @@ from traceback import format_exc
 from wsgiref.handlers import format_date_time
 from binascii import hexlify, unhexlify
 from datetime import datetime, tzinfo, timedelta
+from urllib import quote, unquote
 
 from django.conf import settings
 from django.http import HttpResponse
 from django.utils import simplejson as json
 from django.utils.http import http_date, parse_etags
-from django.utils.encoding import smart_str
+from django.utils.encoding import smart_unicode, smart_str
+from django.core.files.uploadhandler import FileUploadHandler
+from django.core.files.uploadedfile import UploadedFile
 
-from pithos.api.compat import parse_http_date_safe, parse_http_date
-from pithos.api.faults import (Fault, NotModified, BadRequest, Unauthorized, ItemNotFound,
-                                Conflict, LengthRequired, PreconditionFailed, RangeNotSatisfiable,
-                                ServiceUnavailable)
+from pithos.lib.compat import parse_http_date_safe, parse_http_date
+from pithos.lib.hashmap import HashMap
+
+from pithos.api.faults import (Fault, NotModified, BadRequest, Unauthorized, Forbidden, ItemNotFound,
+                                Conflict, LengthRequired, PreconditionFailed, RequestEntityTooLarge,
+                                RangeNotSatisfiable, ServiceUnavailable)
+from pithos.api.short_url import encode_url
 from pithos.backends import connect_backend
-from pithos.backends.base import NotAllowedError
+from pithos.backends.base import NotAllowedError, QuotaError
 
 import logging
 import re
@@ -57,6 +63,7 @@ import hashlib
 import uuid
 import decimal
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -93,7 +100,8 @@ def printable_header_dict(d):
     Format 'last_modified' timestamp.
     """
     
-    d['last_modified'] = isoformat(datetime.fromtimestamp(d['last_modified']))
+    if 'last_modified' in d:
+        d['last_modified'] = isoformat(datetime.fromtimestamp(d['last_modified']))
     return dict([(k.lower().replace('-', '_'), v) for k, v in d.iteritems()])
 
 def format_header_key(k):
@@ -172,11 +180,12 @@ def get_object_headers(request):
     return meta, get_sharing(request), get_public(request)
 
 def put_object_headers(response, meta, restricted=False):
-    response['ETag'] = meta['hash']
+    response['ETag'] = meta['ETag']
     response['Content-Length'] = meta['bytes']
     response['Content-Type'] = meta.get('Content-Type', 'application/octet-stream')
     response['Last-Modified'] = http_date(int(meta['modified']))
     if not restricted:
+        response['X-Object-Hash'] = meta['hash']
         response['X-Object-Modified-By'] = smart_str(meta['modified_by'], strings_only=True)
         response['X-Object-Version'] = meta['version']
         response['X-Object-Version-Timestamp'] = http_date(int(meta['version_timestamp']))
@@ -190,30 +199,30 @@ def put_object_headers(response, meta, restricted=False):
     else:
         for k in ('Content-Encoding', 'Content-Disposition'):
             if k in meta:
-                response[k] = meta[k]
+                response[k] = smart_str(meta[k], strings_only=True)
 
 def update_manifest_meta(request, v_account, meta):
     """Update metadata if the object has an X-Object-Manifest."""
     
     if 'X-Object-Manifest' in meta:
-        hash = ''
+        etag = ''
         bytes = 0
         try:
             src_container, src_name = split_container_object_string('/' + meta['X-Object-Manifest'])
-            objects = request.backend.list_objects(request.user, v_account,
+            objects = request.backend.list_objects(request.user_uniq, v_account,
                                 src_container, prefix=src_name, virtual=False)
             for x in objects:
-                src_meta = request.backend.get_object_meta(request.user,
+                src_meta = request.backend.get_object_meta(request.user_uniq,
                                         v_account, src_container, x[0], x[1])
-                hash += src_meta['hash']
+                etag += src_meta['ETag']
                 bytes += src_meta['bytes']
         except:
             # Ignore errors.
             return
         meta['bytes'] = bytes
         md5 = hashlib.md5()
-        md5.update(hash)
-        meta['hash'] = md5.hexdigest().lower()
+        md5.update(etag)
+        meta['ETag'] = md5.hexdigest().lower()
 
 def update_sharing_meta(request, permissions, v_account, v_container, v_object, meta):
     if permissions is None:
@@ -231,13 +240,13 @@ def update_sharing_meta(request, permissions, v_account, v_container, v_object,
     meta['X-Object-Sharing'] = '; '.join(ret)
     if '/'.join((v_account, v_container, v_object)) != perm_path:
         meta['X-Object-Shared-By'] = perm_path
-    if request.user != v_account:
+    if request.user_uniq != v_account:
         meta['X-Object-Allowed-To'] = allowed
 
 def update_public_meta(public, meta):
     if not public:
         return
-    meta['X-Object-Public'] = public
+    meta['X-Object-Public'] = '/public/' + encode_url(public)
 
 def validate_modification_preconditions(request, meta):
     """Check that the modified timestamp conforms with the preconditions set."""
@@ -260,20 +269,20 @@ def validate_modification_preconditions(request, meta):
 def validate_matching_preconditions(request, meta):
     """Check that the ETag conforms with the preconditions set."""
     
-    hash = meta.get('hash', None)
+    etag = meta.get('ETag', None)
     
     if_match = request.META.get('HTTP_IF_MATCH')
     if if_match is not None:
-        if hash is None:
+        if etag is None:
             raise PreconditionFailed('Resource does not exist')
-        if if_match != '*' and hash not in [x.lower() for x in parse_etags(if_match)]:
+        if if_match != '*' and etag not in [x.lower() for x in parse_etags(if_match)]:
             raise PreconditionFailed('Resource ETag does not match')
     
     if_none_match = request.META.get('HTTP_IF_NONE_MATCH')
     if if_none_match is not None:
         # TODO: If this passes, must ignore If-Modified-Since header.
-        if hash is not None:
-            if if_none_match == '*' or hash in [x.lower() for x in parse_etags(if_none_match)]:
+        if etag is not None:
+            if if_none_match == '*' or etag in [x.lower() for x in parse_etags(if_none_match)]:
                 # TODO: Continue if an If-Modified-Since header is present.
                 if request.method in ('HEAD', 'GET'):
                     raise NotModified('Resource ETag matches')
@@ -284,7 +293,7 @@ def split_container_object_string(s):
         raise ValueError
     s = s[1:]
     pos = s.find('/')
-    if pos == -1:
+    if pos == -1 or pos == len(s) - 1:
         raise ValueError
     return s[:pos], s[(pos + 1):]
 
@@ -295,26 +304,28 @@ def copy_or_move_object(request, src_account, src_container, src_name, dest_acco
     src_version = request.META.get('HTTP_X_SOURCE_VERSION')
     try:
         if move:
-            version_id = request.backend.move_object(request.user, src_account, src_container, src_name,
+            version_id = request.backend.move_object(request.user_uniq, src_account, src_container, src_name,
                                                         dest_account, dest_container, dest_name,
                                                         meta, False, permissions)
         else:
-            version_id = request.backend.copy_object(request.user, src_account, src_container, src_name,
+            version_id = request.backend.copy_object(request.user_uniq, src_account, src_container, src_name,
                                                         dest_account, dest_container, dest_name,
                                                         meta, False, permissions, src_version)
     except NotAllowedError:
-        raise Unauthorized('Access denied')
+        raise Forbidden('Not allowed')
     except (NameError, IndexError):
         raise ItemNotFound('Container or object does not exist')
     except ValueError:
         raise BadRequest('Invalid sharing header')
     except AttributeError, e:
         raise Conflict('\n'.join(e.data) + '\n')
+    except QuotaError:
+        raise RequestEntityTooLarge('Quota exceeded')
     if public is not None:
         try:
-            request.backend.update_object_public(request.user, dest_account, dest_container, dest_name, public)
+            request.backend.update_object_public(request.user_uniq, dest_account, dest_container, dest_name, public)
         except NotAllowedError:
-            raise Unauthorized('Access denied')
+            raise Forbidden('Not allowed')
         except NameError:
             raise ItemNotFound('Object does not exist')
     return version_id
@@ -498,7 +509,7 @@ def socket_read_iterator(request, length=0, blocksize=4096):
     sock = raw_input_socket(request)
     if length < 0: # Chunked transfers
         # Small version (server does the dechunking).
-        if request.environ.get('mod_wsgi.input_chunked', None):
+        if request.environ.get('mod_wsgi.input_chunked', None) or request.META['SERVER_SOFTWARE'].startswith('gunicorn'):
             while length < MAX_UPLOAD_SIZE:
                 data = sock.read(blocksize)
                 if data == '':
@@ -552,6 +563,40 @@ def socket_read_iterator(request, length=0, blocksize=4096):
             length -= len(data)
             yield data
 
+class SaveToBackendHandler(FileUploadHandler):
+    """Handle a file from an HTML form the Django way."""
+    
+    def __init__(self, request=None):
+        super(SaveToBackendHandler, self).__init__(request)
+        self.backend = request.backend
+    
+    def put_data(self, length):
+        if len(self.data) >= length:
+            block = self.data[:length]
+            self.file.hashmap.append(self.backend.put_block(block))
+            self.md5.update(block)
+            self.data = self.data[length:]
+    
+    def new_file(self, field_name, file_name, content_type, content_length, charset=None):
+        self.md5 = hashlib.md5()        
+        self.data = ''
+        self.file = UploadedFile(name=file_name, content_type=content_type, charset=charset)
+        self.file.size = 0
+        self.file.hashmap = []
+    
+    def receive_data_chunk(self, raw_data, start):
+        self.data += raw_data
+        self.file.size += len(raw_data)
+        self.put_data(self.request.backend.block_size)
+        return None
+    
+    def file_complete(self, file_size):
+        l = len(self.data)
+        if l > 0:
+            self.put_data(l)
+        self.file.etag = self.md5.hexdigest().lower()
+        return self.file
+
 class ObjectWrapper(object):
     """Return the object's data block-per-block in each iteration.
     
@@ -662,7 +707,7 @@ def object_data_response(request, sizes, hashmaps, meta, public=False):
                     ranges = [(0, size)]
                     ret = 200
             except ValueError:
-                if if_range != meta['hash']:
+                if if_range != meta['ETag']:
                     ranges = [(0, size)]
                     ret = 200
     
@@ -698,24 +743,29 @@ def put_object_block(request, hashmap, data, offset):
 def hashmap_hash(request, hashmap):
     """Produce the root hash, treating the hashmap as a Merkle-like tree."""
     
-    def subhash(d):
-        h = hashlib.new(request.backend.hash_algorithm)
-        h.update(d)
-        return h.digest()
-    
-    if len(hashmap) == 0:
-        return hexlify(subhash(''))
-    if len(hashmap) == 1:
-        return hashmap[0]
-    
-    s = 2
-    while s < len(hashmap):
-        s = s * 2
-    h = [unhexlify(x) for x in hashmap]
-    h += [('\x00' * len(h[0]))] * (s - len(hashmap))
-    while len(h) > 1:
-        h = [subhash(h[x] + h[x + 1]) for x in range(0, len(h), 2)]
-    return hexlify(h[0])
+    map = HashMap(request.backend.block_size, request.backend.hash_algorithm)
+    map.extend([unhexlify(x) for x in hashmap])
+    return hexlify(map.hash())
+
+def update_request_headers(request):
+    # Handle URL-encoded keys and values.
+    # Also reject requests with too many, oversized or non-ASCII headers.
+    meta = dict([(k, v) for k, v in request.META.iteritems() if k.startswith('HTTP_')])
+    if len(meta) > 90:
+        raise BadRequest('Too many headers.')
+    for k, v in meta.iteritems():
+        if len(k) > 128:
+            raise BadRequest('Header name too large.')
+        if len(v) > 256:
+            raise BadRequest('Header value too large.')
+        try:
+            k.decode('ascii')
+            v.decode('ascii')
+        except UnicodeDecodeError:
+            raise BadRequest('Bad character in headers.')
+        if '%' in k or '%' in v:
+            del(request.META[k])
+            request.META[unquote(k)] = smart_unicode(unquote(v), strings_only=True)
 
 def update_response_headers(request, response):
     if request.serialization == 'xml':
@@ -725,9 +775,19 @@ def update_response_headers(request, response):
     elif not response['Content-Type']:
         response['Content-Type'] = 'text/plain; charset=UTF-8'
     
-    if not response.has_header('Content-Length') and not (response.has_header('Content-Type') and response['Content-Type'].startswith('multipart/byteranges')):
+    if (not response.has_header('Content-Length') and
+        not (response.has_header('Content-Type') and
+             response['Content-Type'].startswith('multipart/byteranges'))):
         response['Content-Length'] = len(response.content)
     
+    # URL-encode unicode in headers.
+    meta = response.items()
+    for k, v in meta:
+        if (k.startswith('X-Account-') or k.startswith('X-Container-') or
+            k.startswith('X-Object-') or k.startswith('Content-')):
+            del(response[k])
+            response[quote(k)] = quote(v, safe='/=,:@; ')
+    
     if settings.TEST:
         response['Date'] = format_date_time(time())
 
@@ -765,7 +825,7 @@ def request_serialization(request, format_allowed=False):
     
     return 'text'
 
-def api_method(http_method=None, format_allowed=False):
+def api_method(http_method=None, format_allowed=False, user_required=True):
     """Decorator function for views that implement an API method."""
     
     def decorator(func):
@@ -774,6 +834,8 @@ def api_method(http_method=None, format_allowed=False):
             try:
                 if http_method and request.method != http_method:
                     raise BadRequest('Method not allowed.')
+                if user_required and getattr(request, 'user', None) is None:
+                    raise Unauthorized('Access denied')
                 
                 # The args variable may contain up to (account, container, object).
                 if len(args) > 1 and len(args[1]) > 256:
@@ -781,10 +843,13 @@ def api_method(http_method=None, format_allowed=False):
                 if len(args) > 2 and len(args[2]) > 1024:
                     raise BadRequest('Object name too large.')
                 
+                # Format and check headers.
+                update_request_headers(request)
+                
                 # Fill in custom request variables.
                 request.serialization = request_serialization(request, format_allowed)
                 request.backend = connect_backend()
-
+                
                 response = func(request, *args, **kwargs)
                 update_response_headers(request, response)
                 return response
@@ -795,6 +860,7 @@ def api_method(http_method=None, format_allowed=False):
                 fault = ServiceUnavailable('Unexpected error')
                 return render_fault(request, fault)
             finally:
-                request.backend.wrapper.conn.close()
+                if getattr(request, 'backend', None) is not None:
+                    request.backend.close()
         return wrapper
     return decorator