Use 401 error when missing token and 403 when forbidden.
[pithos] / pithos / api / util.py
index 369c76c..296be84 100644 (file)
@@ -35,45 +35,69 @@ from functools import wraps
 from time import time
 from traceback import format_exc
 from wsgiref.handlers import format_date_time
-from binascii import hexlify
+from binascii import hexlify, unhexlify
+from datetime import datetime, tzinfo, timedelta
 
 from django.conf import settings
 from django.http import HttpResponse
 from django.utils import simplejson as json
 from django.utils.http import http_date, parse_etags
+from django.utils.encoding import smart_str
 
 from pithos.api.compat import parse_http_date_safe, parse_http_date
-from pithos.api.faults import (Fault, NotModified, BadRequest, Unauthorized, ItemNotFound,
-                                Conflict, LengthRequired, PreconditionFailed, RangeNotSatisfiable,
-                                ServiceUnavailable)
-from pithos.backends import backend
-from pithos.backends.base import NotAllowedError
+from pithos.api.faults import (Fault, NotModified, BadRequest, Unauthorized, Forbidden, ItemNotFound,
+                                Conflict, LengthRequired, PreconditionFailed, RequestEntityTooLarge,
+                                RangeNotSatisfiable, ServiceUnavailable)
+from pithos.backends import connect_backend
+from pithos.backends.base import NotAllowedError, QuotaError
 
-import datetime
 import logging
 import re
 import hashlib
 import uuid
-
+import decimal
 
 logger = logging.getLogger(__name__)
 
 
+class UTC(tzinfo):
+    def utcoffset(self, dt):
+        return timedelta(0)
+
+    def tzname(self, dt):
+        return 'UTC'
+
+    def dst(self, dt):
+        return timedelta(0)
+
+def json_encode_decimal(obj):
+    if isinstance(obj, decimal.Decimal):
+        return str(obj)
+    raise TypeError(repr(obj) + " is not JSON serializable")
+
+def isoformat(d):
+    """Return an ISO8601 string for naive UTC datetime d, tagged +00:00 (no conversion)."""
+
+    return d.replace(tzinfo=UTC()).isoformat()
+
+def rename_meta_key(d, old, new):
+    if old not in d:
+        return
+    d[new] = d[old]
+    del(d[old])
+
 def printable_header_dict(d):
     """Format a meta dictionary for printing out json/xml.
     
-    Convert all keys to lower case and replace dashes to underscores.
-    Change 'modified' key from backend to 'last_modified' and format date.
+    Convert all keys to lower case and replace dashes with underscores.
+    Format 'last_modified' timestamp.
     """
     
-    if 'modified' in d:
-        d['last_modified'] = datetime.datetime.fromtimestamp(int(d['modified'])).isoformat()
-        del(d['modified'])
+    d['last_modified'] = isoformat(datetime.fromtimestamp(d['last_modified']))
     return dict([(k.lower().replace('-', '_'), v) for k, v in d.iteritems()])
 
 def format_header_key(k):
     """Convert underscores to dashes and capitalize intra-dash strings."""
-    
     return '-'.join([x.capitalize() for x in k.replace('_', '-').split('-')])
 
 def get_header_prefix(request, prefix):
@@ -91,40 +115,49 @@ def get_account_headers(request):
         if '-' in n or '_' in n:
             raise BadRequest('Bad characters in group name')
         groups[n] = v.replace(' ', '').split(',')
-        if '' in groups[n]:
+        while '' in groups[n]:
             groups[n].remove('')
     return meta, groups
 
-def put_account_headers(response, meta, groups):
-    response['X-Account-Container-Count'] = meta['count']
-    response['X-Account-Bytes-Used'] = meta['bytes']
-    if 'modified' in meta:
-        response['Last-Modified'] = http_date(int(meta['modified']))
+def put_account_headers(response, meta, groups, policy):
+    if 'count' in meta:
+        response['X-Account-Container-Count'] = meta['count']
+    if 'bytes' in meta:
+        response['X-Account-Bytes-Used'] = meta['bytes']
+    response['Last-Modified'] = http_date(int(meta['modified']))
     for k in [x for x in meta.keys() if x.startswith('X-Account-Meta-')]:
-        response[k.encode('utf-8')] = meta[k].encode('utf-8')
+        response[smart_str(k, strings_only=True)] = smart_str(meta[k], strings_only=True)
     if 'until_timestamp' in meta:
         response['X-Account-Until-Timestamp'] = http_date(int(meta['until_timestamp']))
     for k, v in groups.iteritems():
-        response[format_header_key('X-Account-Group-' + k).encode('utf-8')] = (','.join(v)).encode('utf-8')
+        k = smart_str(k, strings_only=True)
+        k = format_header_key('X-Account-Group-' + k)
+        v = smart_str(','.join(v), strings_only=True)
+        response[k] = v
+    for k, v in policy.iteritems():
+        response[smart_str(format_header_key('X-Account-Policy-' + k), strings_only=True)] = smart_str(v, strings_only=True)
 
 def get_container_headers(request):
     meta = get_header_prefix(request, 'X-Container-Meta-')
     policy = dict([(k[19:].lower(), v.replace(' ', '')) for k, v in get_header_prefix(request, 'X-Container-Policy-').iteritems()])
     return meta, policy
 
-def put_container_headers(response, meta, policy):
-    response['X-Container-Object-Count'] = meta['count']
-    response['X-Container-Bytes-Used'] = meta['bytes']
+def put_container_headers(request, response, meta, policy):
+    if 'count' in meta:
+        response['X-Container-Object-Count'] = meta['count']
+    if 'bytes' in meta:
+        response['X-Container-Bytes-Used'] = meta['bytes']
     response['Last-Modified'] = http_date(int(meta['modified']))
     for k in [x for x in meta.keys() if x.startswith('X-Container-Meta-')]:
-        response[k.encode('utf-8')] = meta[k].encode('utf-8')
-    response['X-Container-Object-Meta'] = [x[14:] for x in meta['object_meta'] if x.startswith('X-Object-Meta-')]
-    response['X-Container-Block-Size'] = backend.block_size
-    response['X-Container-Block-Hash'] = backend.hash_algorithm
+        response[smart_str(k, strings_only=True)] = smart_str(meta[k], strings_only=True)
+    l = [smart_str(x, strings_only=True) for x in meta['object_meta'] if x.startswith('X-Object-Meta-')]
+    response['X-Container-Object-Meta'] = ','.join([x[14:] for x in l])
+    response['X-Container-Block-Size'] = request.backend.block_size
+    response['X-Container-Block-Hash'] = request.backend.hash_algorithm
     if 'until_timestamp' in meta:
         response['X-Container-Until-Timestamp'] = http_date(int(meta['until_timestamp']))
     for k, v in policy.iteritems():
-        response[format_header_key('X-Container-Policy-' + k).encode('utf-8')] = v.encode('utf-8')
+        response[smart_str(format_header_key('X-Container-Policy-' + k), strings_only=True)] = smart_str(v, strings_only=True)
 
 def get_object_headers(request):
     meta = get_header_prefix(request, 'X-Object-Meta-')
@@ -144,14 +177,16 @@ def put_object_headers(response, meta, restricted=False):
     response['Content-Type'] = meta.get('Content-Type', 'application/octet-stream')
     response['Last-Modified'] = http_date(int(meta['modified']))
     if not restricted:
-        response['X-Object-Modified-By'] = meta['modified_by']
+        response['X-Object-Modified-By'] = smart_str(meta['modified_by'], strings_only=True)
         response['X-Object-Version'] = meta['version']
         response['X-Object-Version-Timestamp'] = http_date(int(meta['version_timestamp']))
         for k in [x for x in meta.keys() if x.startswith('X-Object-Meta-')]:
-            response[k.encode('utf-8')] = meta[k].encode('utf-8')
-        for k in ('Content-Encoding', 'Content-Disposition', 'X-Object-Manifest', 'X-Object-Sharing', 'X-Object-Shared-By', 'X-Object-Public'):
+            response[smart_str(k, strings_only=True)] = smart_str(meta[k], strings_only=True)
+        for k in ('Content-Encoding', 'Content-Disposition', 'X-Object-Manifest',
+                  'X-Object-Sharing', 'X-Object-Shared-By', 'X-Object-Allowed-To',
+                  'X-Object-Public'):
             if k in meta:
-                response[k] = meta[k]
+                response[k] = smart_str(meta[k], strings_only=True)
     else:
         for k in ('Content-Encoding', 'Content-Disposition'):
             if k in meta:
@@ -165,9 +200,11 @@ def update_manifest_meta(request, v_account, meta):
         bytes = 0
         try:
             src_container, src_name = split_container_object_string('/' + meta['X-Object-Manifest'])
-            objects = backend.list_objects(request.user, v_account, src_container, prefix=src_name, virtual=False)
+            objects = request.backend.list_objects(request.user, v_account,
+                                src_container, prefix=src_name, virtual=False)
             for x in objects:
-                src_meta = backend.get_object_meta(request.user, v_account, src_container, x[0], x[1])
+                src_meta = request.backend.get_object_meta(request.user,
+                                        v_account, src_container, x[0], x[1])
                 hash += src_meta['hash']
                 bytes += src_meta['bytes']
         except:
@@ -178,10 +215,10 @@ def update_manifest_meta(request, v_account, meta):
         md5.update(hash)
         meta['hash'] = md5.hexdigest().lower()
 
-def update_sharing_meta(permissions, v_account, v_container, v_object, meta):
+def update_sharing_meta(request, permissions, v_account, v_container, v_object, meta):
     if permissions is None:
         return
-    perm_path, perms = permissions
+    allowed, perm_path, perms = permissions
     if len(perms) == 0:
         return
     ret = []
@@ -194,6 +231,8 @@ def update_sharing_meta(permissions, v_account, v_container, v_object, meta):
     meta['X-Object-Sharing'] = '; '.join(ret)
     if '/'.join((v_account, v_container, v_object)) != perm_path:
         meta['X-Object-Shared-By'] = perm_path
+    if request.user != v_account:
+        meta['X-Object-Allowed-To'] = allowed
 
 def update_public_meta(public, meta):
     if not public:
@@ -221,18 +260,24 @@ def validate_modification_preconditions(request, meta):
 def validate_matching_preconditions(request, meta):
     """Check that the ETag conforms with the preconditions set."""
     
-    if 'hash' not in meta:
-        return # TODO: Always return?
+    hash = meta.get('hash', None)
     
     if_match = request.META.get('HTTP_IF_MATCH')
-    if if_match is not None and if_match != '*':
-        if meta['hash'] not in [x.lower() for x in parse_etags(if_match)]:
-            raise PreconditionFailed('Resource Etag does not match')
+    if if_match is not None:
+        if hash is None:
+            raise PreconditionFailed('Resource does not exist')
+        if if_match != '*' and hash not in [x.lower() for x in parse_etags(if_match)]:
+            raise PreconditionFailed('Resource ETag does not match')
     
     if_none_match = request.META.get('HTTP_IF_NONE_MATCH')
     if if_none_match is not None:
-        if if_none_match == '*' or meta['hash'] in [x.lower() for x in parse_etags(if_none_match)]:
-            raise NotModified('Resource Etag matches')
+        # TODO: If this passes, must ignore If-Modified-Since header.
+        if hash is not None:
+            if if_none_match == '*' or hash in [x.lower() for x in parse_etags(if_none_match)]:
+                # TODO: Continue if an If-Modified-Since header is present.
+                if request.method in ('HEAD', 'GET'):
+                    raise NotModified('Resource ETag matches')
+                raise PreconditionFailed('Resource exists or ETag matches')
 
 def split_container_object_string(s):
     if not len(s) > 0 or s[0] != '/':
@@ -243,31 +288,38 @@ def split_container_object_string(s):
         raise ValueError
     return s[:pos], s[(pos + 1):]
 
-def copy_or_move_object(request, v_account, src_container, src_name, dest_container, dest_name, move=False):
+def copy_or_move_object(request, src_account, src_container, src_name, dest_account, dest_container, dest_name, move=False):
     """Copy or move an object."""
     
     meta, permissions, public = get_object_headers(request)
-    src_version = request.META.get('HTTP_X_SOURCE_VERSION')    
+    src_version = request.META.get('HTTP_X_SOURCE_VERSION')
     try:
         if move:
-            backend.move_object(request.user, v_account, src_container, src_name, dest_container, dest_name, meta, False, permissions)
+            version_id = request.backend.move_object(request.user, src_account, src_container, src_name,
+                                                        dest_account, dest_container, dest_name,
+                                                        meta, False, permissions)
         else:
-            backend.copy_object(request.user, v_account, src_container, src_name, dest_container, dest_name, meta, False, permissions, src_version)
+            version_id = request.backend.copy_object(request.user, src_account, src_container, src_name,
+                                                        dest_account, dest_container, dest_name,
+                                                        meta, False, permissions, src_version)
     except NotAllowedError:
-        raise Unauthorized('Access denied')
-    except NameError, IndexError:
+        raise Forbidden('Not allowed')
+    except (NameError, IndexError):
         raise ItemNotFound('Container or object does not exist')
     except ValueError:
         raise BadRequest('Invalid sharing header')
     except AttributeError, e:
-        raise Conflict(json.dumps(e.data))
+        raise Conflict('\n'.join(e.data) + '\n')
+    except QuotaError:
+        raise RequestEntityTooLarge('Quota exceeded')
     if public is not None:
         try:
-            backend.update_object_public(request.user, v_account, dest_container, dest_name, public)
+            request.backend.update_object_public(request.user, dest_account, dest_container, dest_name, public)
         except NotAllowedError:
-            raise Unauthorized('Access denied')
+            raise Forbidden('Not allowed')
         except NameError:
             raise ItemNotFound('Object does not exist')
+    return version_id
 
 def get_int_parameter(p):
     if p is not None:
@@ -372,13 +424,16 @@ def get_sharing(request):
     if permissions is None:
         return None
     
+    # TODO: Document or remove '~' replacing.
+    permissions = permissions.replace('~', '')
+    
     ret = {}
     permissions = permissions.replace(' ', '')
     if permissions == '':
         return ret
     for perm in (x for x in permissions.split(';')):
         if perm.startswith('read='):
-            ret['read'] = [v.replace(' ','').lower() for v in perm[5:].split(',')]
+            ret['read'] = list(set([v.replace(' ','').lower() for v in perm[5:].split(',')]))
             if '' in ret['read']:
                 ret['read'].remove('')
             if '*' in ret['read']:
@@ -386,7 +441,7 @@ def get_sharing(request):
             if len(ret['read']) == 0:
                 raise BadRequest('Bad X-Object-Sharing header value')
         elif perm.startswith('write='):
-            ret['write'] = [v.replace(' ','').lower() for v in perm[6:].split(',')]
+            ret['write'] = list(set([v.replace(' ','').lower() for v in perm[6:].split(',')]))
             if '' in ret['write']:
                 ret['write'].remove('')
             if '*' in ret['write']:
@@ -395,6 +450,15 @@ def get_sharing(request):
                 raise BadRequest('Bad X-Object-Sharing header value')
         else:
             raise BadRequest('Bad X-Object-Sharing header value')
+    
+    # Keep duplicates only in write list.
+    dups = [x for x in ret.get('read', []) if x in ret.get('write', []) and x != '*']
+    if dups:
+        for x in dups:
+            ret['read'].remove(x)
+        if len(ret['read']) == 0:
+            del(ret['read'])
+    
     return ret
 
 def get_public(request):
@@ -418,26 +482,33 @@ def raw_input_socket(request):
     """Return the socket for reading the rest of the request."""
     
     server_software = request.META.get('SERVER_SOFTWARE')
-    if not server_software:
-        if 'wsgi.input' in request.environ:
-            return request.environ['wsgi.input']
-        raise ServiceUnavailable('Unknown server software')
-    if server_software.startswith('WSGIServer'):
-        return request.environ['wsgi.input']
-    elif server_software.startswith('mod_python'):
+    if server_software and server_software.startswith('mod_python'):
         return request._req
+    if 'wsgi.input' in request.environ:
+        return request.environ['wsgi.input']
     raise ServiceUnavailable('Unknown server software')
 
-MAX_UPLOAD_SIZE = 10 * (1024 * 1024) # 10MB
+MAX_UPLOAD_SIZE = 5 * (1024 * 1024 * 1024) # 5GB
 
-def socket_read_iterator(sock, length=0, blocksize=4096):
+def socket_read_iterator(request, length=0, blocksize=4096):
     """Return a maximum of blocksize data read from the socket in each iteration.
     
     Read up to 'length'. If 'length' is negative, will attempt a chunked read.
     The maximum ammount of data read is controlled by MAX_UPLOAD_SIZE.
     """
     
+    sock = raw_input_socket(request)
     if length < 0: # Chunked transfers
+        # Small version (server does the dechunking); count bytes read to enforce MAX_UPLOAD_SIZE.
+        if request.environ.get('mod_wsgi.input_chunked', None) or (request.META.get('SERVER_SOFTWARE') or '').startswith('gunicorn'):
+            while length < MAX_UPLOAD_SIZE:
+                data = sock.read(blocksize)
+                if data == '':
+                    return
+                length += len(data)
+                yield data
+            raise BadRequest('Maximum size is reached')
+        # Long version (do the dechunking).
         data = ''
         while length < MAX_UPLOAD_SIZE:
             # Get chunk size.
@@ -478,6 +549,8 @@ def socket_read_iterator(sock, length=0, blocksize=4096):
             raise BadRequest('Maximum size is reached')
         while length > 0:
             data = sock.read(min(length, blocksize))
+            if not data:
+                raise BadRequest()
             length -= len(data)
             yield data
 
@@ -487,7 +560,8 @@ class ObjectWrapper(object):
     Read from the object using the offset and length provided in each entry of the range list.
     """
     
-    def __init__(self, ranges, sizes, hashmaps, boundary):
+    def __init__(self, backend, ranges, sizes, hashmaps, boundary):
+        self.backend = backend
         self.ranges = ranges
         self.sizes = sizes
         self.hashmaps = hashmaps
@@ -515,16 +589,16 @@ class ObjectWrapper(object):
                 file_size = self.sizes[self.file_index]
             
             # Get the block for the current position.
-            self.block_index = int(self.offset / backend.block_size)
+            self.block_index = int(self.offset / self.backend.block_size)
             if self.block_hash != self.hashmaps[self.file_index][self.block_index]:
                 self.block_hash = self.hashmaps[self.file_index][self.block_index]
                 try:
-                    self.block = backend.get_block(self.block_hash)
+                    self.block = self.backend.get_block(self.block_hash)
                 except NameError:
                     raise ItemNotFound('Block does not exist')
             
             # Get the data from the block.
-            bo = self.offset % backend.block_size
+            bo = self.offset % self.backend.block_size
             bl = min(self.length, len(self.block) - bo)
             data = self.block[bo:bo + bl]
             self.offset += bl
@@ -598,7 +672,7 @@ def object_data_response(request, sizes, hashmaps, meta, public=False):
         boundary = uuid.uuid4().hex
     else:
         boundary = ''
-    wrapper = ObjectWrapper(ranges, sizes, hashmaps, boundary)
+    wrapper = ObjectWrapper(request.backend, ranges, sizes, hashmaps, boundary)
     response = HttpResponse(wrapper, status=ret)
     put_object_headers(response, meta, public)
     if ret == 206:
@@ -611,37 +685,38 @@ def object_data_response(request, sizes, hashmaps, meta, public=False):
             response['Content-Type'] = 'multipart/byteranges; boundary=%s' % (boundary,)
     return response
 
-def put_object_block(hashmap, data, offset):
+def put_object_block(request, hashmap, data, offset):
     """Put one block of data at the given offset."""
     
-    bi = int(offset / backend.block_size)
-    bo = offset % backend.block_size
-    bl = min(len(data), backend.block_size - bo)
+    bi = int(offset / request.backend.block_size)
+    bo = offset % request.backend.block_size
+    bl = min(len(data), request.backend.block_size - bo)
     if bi < len(hashmap):
-        hashmap[bi] = backend.update_block(hashmap[bi], data[:bl], bo)
+        hashmap[bi] = request.backend.update_block(hashmap[bi], data[:bl], bo)
     else:
-        hashmap.append(backend.put_block(('\x00' * bo) + data[:bl]))
+        hashmap.append(request.backend.put_block(('\x00' * bo) + data[:bl]))
     return bl # Return ammount of data written.
 
-def hashmap_hash(hashmap):
+def hashmap_hash(request, hashmap):
     """Produce the root hash, treating the hashmap as a Merkle-like tree."""
     
     def subhash(d):
-        h = hashlib.new(backend.hash_algorithm)
+        h = hashlib.new(request.backend.hash_algorithm)
         h.update(d)
         return h.digest()
     
     if len(hashmap) == 0:
         return hexlify(subhash(''))
     if len(hashmap) == 1:
-        return hexlify(subhash(hashmap[0]))
+        return hashmap[0]
+    
     s = 2
     while s < len(hashmap):
         s = s * 2
-    h = hashmap + ([('\x00' * len(hashmap[0]))] * (s - len(hashmap)))
-    h = [subhash(h[x] + (h[x + 1] if x + 1 < len(h) else '')) for x in range(0, len(h), 2)]
+    h = [unhexlify(x) for x in hashmap]
+    h += [('\x00' * len(h[0]))] * (s - len(hashmap))
     while len(h) > 1:
-        h = [subhash(h[x] + (h[x + 1] if x + 1 < len(h) else '')) for x in range(0, len(h), 2)]
+        h = [subhash(h[x] + h[x + 1]) for x in range(0, len(h), 2)]
     return hexlify(h[0])
 
 def update_response_headers(request, response):
@@ -683,16 +758,16 @@ def request_serialization(request, format_allowed=False):
     elif format == 'xml':
         return 'xml'
     
-#     for item in request.META.get('HTTP_ACCEPT', '').split(','):
-#         accept, sep, rest = item.strip().partition(';')
-#         if accept == 'application/json':
-#             return 'json'
-#         elif accept == 'application/xml' or accept == 'text/xml':
-#             return 'xml'
+    for item in request.META.get('HTTP_ACCEPT', '').split(','):
+        accept, sep, rest = item.strip().partition(';')
+        if accept == 'application/json':
+            return 'json'
+        elif accept == 'application/xml' or accept == 'text/xml':
+            return 'xml'
     
     return 'text'
 
-def api_method(http_method=None, format_allowed=False):
+def api_method(http_method=None, format_allowed=False, user_required=True):
     """Decorator function for views that implement an API method."""
     
     def decorator(func):
@@ -701,6 +776,8 @@ def api_method(http_method=None, format_allowed=False):
             try:
                 if http_method and request.method != http_method:
                     raise BadRequest('Method not allowed.')
+                if user_required and getattr(request, 'user', None) is None:
+                    raise Unauthorized('Access denied')
                 
                 # The args variable may contain up to (account, container, object).
                 if len(args) > 1 and len(args[1]) > 256:
@@ -710,7 +787,8 @@ def api_method(http_method=None, format_allowed=False):
                 
                 # Fill in custom request variables.
                 request.serialization = request_serialization(request, format_allowed)
-                
+                request.backend = connect_backend()
+
                 response = func(request, *args, **kwargs)
                 update_response_headers(request, response)
                 return response
@@ -720,5 +798,8 @@ def api_method(http_method=None, format_allowed=False):
                 logger.exception('Unexpected error: %s' % e)
                 fault = ServiceUnavailable('Unexpected error')
                 return render_fault(request, fault)
+            finally:
+                if getattr(request, 'backend', None) is not None:
+                    request.backend.wrapper.conn.close()
         return wrapper
     return decorator