Use 401 error when missing token and 403 when forbidden.
[pithos] / pithos / api / util.py
index ab49423..296be84 100644 (file)
@@ -35,7 +35,8 @@ from functools import wraps
 from time import time
 from traceback import format_exc
 from wsgiref.handlers import format_date_time
-from binascii import hexlify
+from binascii import hexlify, unhexlify
+from datetime import datetime, tzinfo, timedelta
 
 from django.conf import settings
 from django.http import HttpResponse
@@ -44,22 +45,41 @@ from django.utils.http import http_date, parse_etags
 from django.utils.encoding import smart_str
 
 from pithos.api.compat import parse_http_date_safe, parse_http_date
-from pithos.api.faults import (Fault, NotModified, BadRequest, Unauthorized, ItemNotFound,
-                                Conflict, LengthRequired, PreconditionFailed, RangeNotSatisfiable,
-                                ServiceUnavailable)
-from pithos.backends import backend
-from pithos.backends.base import NotAllowedError
+from pithos.api.faults import (Fault, NotModified, BadRequest, Unauthorized, Forbidden, ItemNotFound,
+                                Conflict, LengthRequired, PreconditionFailed, RequestEntityTooLarge,
+                                RangeNotSatisfiable, ServiceUnavailable)
+from pithos.backends import connect_backend
+from pithos.backends.base import NotAllowedError, QuotaError
 
-import datetime
 import logging
 import re
 import hashlib
 import uuid
-
+import decimal
 
 logger = logging.getLogger(__name__)
 
 
+class UTC(tzinfo):  # fixed zero-offset tzinfo; Python 2 datetime has no timezone.utc
+    def utcoffset(self, dt):
+        return timedelta(0)
+
+    def tzname(self, dt):
+        return 'UTC'
+
+    def dst(self, dt):
+        return timedelta(0)
+
+def json_encode_decimal(obj):  # presumably a json `default=` hook: Decimal -> str
+    if not isinstance(obj, decimal.Decimal):
+        raise TypeError(repr(obj) + " is not JSON serializable")
+    return str(obj)
+
+def isoformat(d):
+    """Return an ISO8601 string with a UTC offset; a naive d is taken to be UTC."""
+    d = d.replace(tzinfo=UTC()) if d.tzinfo is None else d.astimezone(UTC())
+    return d.isoformat()
+
 def rename_meta_key(d, old, new):
     if old not in d:
         return
@@ -73,7 +93,7 @@ def printable_header_dict(d):
     Format 'last_modified' timestamp.
     """
     
-    d['last_modified'] = datetime.datetime.fromtimestamp(int(d['last_modified'])).isoformat()
+    d['last_modified'] = isoformat(datetime.utcfromtimestamp(d['last_modified']))  # UTC, as isoformat() tags it
     return dict([(k.lower().replace('-', '_'), v) for k, v in d.iteritems()])
 
 def format_header_key(k):
@@ -95,11 +115,11 @@ def get_account_headers(request):
         if '-' in n or '_' in n:
             raise BadRequest('Bad characters in group name')
         groups[n] = v.replace(' ', '').split(',')
-        if '' in groups[n]:
+        while '' in groups[n]:
             groups[n].remove('')
     return meta, groups
 
-def put_account_headers(response, meta, groups):
+def put_account_headers(response, meta, groups, policy):
     if 'count' in meta:
         response['X-Account-Container-Count'] = meta['count']
     if 'bytes' in meta:
@@ -114,13 +134,15 @@ def put_account_headers(response, meta, groups):
         k = format_header_key('X-Account-Group-' + k)
         v = smart_str(','.join(v), strings_only=True)
         response[k] = v
-    
+    for k, v in policy.iteritems():
+        response[smart_str(format_header_key('X-Account-Policy-' + k), strings_only=True)] = smart_str(v, strings_only=True)
+
 def get_container_headers(request):
     meta = get_header_prefix(request, 'X-Container-Meta-')
     policy = dict([(k[19:].lower(), v.replace(' ', '')) for k, v in get_header_prefix(request, 'X-Container-Policy-').iteritems()])
     return meta, policy
 
-def put_container_headers(response, meta, policy):
+def put_container_headers(request, response, meta, policy):
     if 'count' in meta:
         response['X-Container-Object-Count'] = meta['count']
     if 'bytes' in meta:
@@ -130,8 +152,8 @@ def put_container_headers(response, meta, policy):
         response[smart_str(k, strings_only=True)] = smart_str(meta[k], strings_only=True)
     l = [smart_str(x, strings_only=True) for x in meta['object_meta'] if x.startswith('X-Object-Meta-')]
     response['X-Container-Object-Meta'] = ','.join([x[14:] for x in l])
-    response['X-Container-Block-Size'] = backend.block_size
-    response['X-Container-Block-Hash'] = backend.hash_algorithm
+    response['X-Container-Block-Size'] = request.backend.block_size
+    response['X-Container-Block-Hash'] = request.backend.hash_algorithm
     if 'until_timestamp' in meta:
         response['X-Container-Until-Timestamp'] = http_date(int(meta['until_timestamp']))
     for k, v in policy.iteritems():
@@ -155,12 +177,14 @@ def put_object_headers(response, meta, restricted=False):
     response['Content-Type'] = meta.get('Content-Type', 'application/octet-stream')
     response['Last-Modified'] = http_date(int(meta['modified']))
     if not restricted:
-        response['X-Object-Modified-By'] = meta['modified_by']
+        response['X-Object-Modified-By'] = smart_str(meta['modified_by'], strings_only=True)
         response['X-Object-Version'] = meta['version']
         response['X-Object-Version-Timestamp'] = http_date(int(meta['version_timestamp']))
         for k in [x for x in meta.keys() if x.startswith('X-Object-Meta-')]:
             response[smart_str(k, strings_only=True)] = smart_str(meta[k], strings_only=True)
-        for k in ('Content-Encoding', 'Content-Disposition', 'X-Object-Manifest', 'X-Object-Sharing', 'X-Object-Shared-By', 'X-Object-Public'):
+        for k in ('Content-Encoding', 'Content-Disposition', 'X-Object-Manifest',
+                  'X-Object-Sharing', 'X-Object-Shared-By', 'X-Object-Allowed-To',
+                  'X-Object-Public'):
             if k in meta:
                 response[k] = smart_str(meta[k], strings_only=True)
     else:
@@ -176,9 +200,11 @@ def update_manifest_meta(request, v_account, meta):
         bytes = 0
         try:
             src_container, src_name = split_container_object_string('/' + meta['X-Object-Manifest'])
-            objects = backend.list_objects(request.user, v_account, src_container, prefix=src_name, virtual=False)
+            objects = request.backend.list_objects(request.user, v_account,
+                                src_container, prefix=src_name, virtual=False)
             for x in objects:
-                src_meta = backend.get_object_meta(request.user, v_account, src_container, x[0], x[1])
+                src_meta = request.backend.get_object_meta(request.user,
+                                        v_account, src_container, x[0], x[1])
                 hash += src_meta['hash']
                 bytes += src_meta['bytes']
         except:
@@ -189,10 +215,10 @@ def update_manifest_meta(request, v_account, meta):
         md5.update(hash)
         meta['hash'] = md5.hexdigest().lower()
 
-def update_sharing_meta(permissions, v_account, v_container, v_object, meta):
+def update_sharing_meta(request, permissions, v_account, v_container, v_object, meta):
     if permissions is None:
         return
-    perm_path, perms = permissions
+    allowed, perm_path, perms = permissions
     if len(perms) == 0:
         return
     ret = []
@@ -205,6 +231,8 @@ def update_sharing_meta(permissions, v_account, v_container, v_object, meta):
     meta['X-Object-Sharing'] = '; '.join(ret)
     if '/'.join((v_account, v_container, v_object)) != perm_path:
         meta['X-Object-Shared-By'] = perm_path
+    if request.user != v_account:
+        meta['X-Object-Allowed-To'] = allowed
 
 def update_public_meta(public, meta):
     if not public:
@@ -260,31 +288,38 @@ def split_container_object_string(s):
         raise ValueError
     return s[:pos], s[(pos + 1):]
 
-def copy_or_move_object(request, v_account, src_container, src_name, dest_container, dest_name, move=False):
+def copy_or_move_object(request, src_account, src_container, src_name, dest_account, dest_container, dest_name, move=False):
     """Copy or move an object."""
     
     meta, permissions, public = get_object_headers(request)
-    src_version = request.META.get('HTTP_X_SOURCE_VERSION')    
+    src_version = request.META.get('HTTP_X_SOURCE_VERSION')
     try:
         if move:
-            backend.move_object(request.user, v_account, src_container, src_name, dest_container, dest_name, meta, False, permissions)
+            version_id = request.backend.move_object(
+                request.user, src_account, src_container, src_name,
+                dest_account, dest_container, dest_name, meta, False, permissions)
         else:
-            backend.copy_object(request.user, v_account, src_container, src_name, dest_container, dest_name, meta, False, permissions, src_version)
+            version_id = request.backend.copy_object(
+                request.user, src_account, src_container, src_name, dest_account,
+                dest_container, dest_name, meta, False, permissions, src_version)
     except NotAllowedError:
-        raise Unauthorized('Access denied')
+        raise Forbidden('Not allowed')
     except (NameError, IndexError):
         raise ItemNotFound('Container or object does not exist')
     except ValueError:
         raise BadRequest('Invalid sharing header')
     except AttributeError, e:
-        raise Conflict(json.dumps(e.data))
+        raise Conflict('\n'.join(e.data) + '\n')
+    except QuotaError:
+        raise RequestEntityTooLarge('Quota exceeded')
     if public is not None:
         try:
-            backend.update_object_public(request.user, v_account, dest_container, dest_name, public)
+            request.backend.update_object_public(request.user, dest_account, dest_container, dest_name, public)
         except NotAllowedError:
-            raise Unauthorized('Access denied')
+            raise Forbidden('Not allowed')
         except NameError:
             raise ItemNotFound('Object does not exist')
+    return version_id
 
 def get_int_parameter(p):
     if p is not None:
@@ -389,6 +424,9 @@ def get_sharing(request):
     if permissions is None:
         return None
     
+    # TODO: Document or remove '~' replacing.
+    permissions = permissions.replace('~', '')
+    
     ret = {}
     permissions = permissions.replace(' ', '')
     if permissions == '':
@@ -412,6 +450,15 @@ def get_sharing(request):
                 raise BadRequest('Bad X-Object-Sharing header value')
         else:
             raise BadRequest('Bad X-Object-Sharing header value')
+    
+    # Keep duplicates only in write list.
+    dups = [x for x in ret.get('read', []) if x in ret.get('write', []) and x != '*']
+    if dups:
+        for x in dups:
+            ret['read'].remove(x)
+        if len(ret['read']) == 0:
+            del(ret['read'])
+    
     return ret
 
 def get_public(request):
@@ -441,7 +488,7 @@ def raw_input_socket(request):
         return request.environ['wsgi.input']
     raise ServiceUnavailable('Unknown server software')
 
-MAX_UPLOAD_SIZE = 10 * (1024 * 1024) # 10MB
+MAX_UPLOAD_SIZE = 5 * 1024 ** 3  # 5GB
 
 def socket_read_iterator(request, length=0, blocksize=4096):
     """Return a maximum of blocksize data read from the socket in each iteration.
@@ -453,7 +500,7 @@ def socket_read_iterator(request, length=0, blocksize=4096):
     sock = raw_input_socket(request)
     if length < 0: # Chunked transfers
         # Small version (server does the dechunking).
-        if request.environ.get('mod_wsgi.input_chunked', None):
+        if request.environ.get('mod_wsgi.input_chunked', None) or request.META['SERVER_SOFTWARE'].startswith('gunicorn'):
             while length < MAX_UPLOAD_SIZE:
                 data = sock.read(blocksize)
                 if data == '':
@@ -502,6 +549,8 @@ def socket_read_iterator(request, length=0, blocksize=4096):
             raise BadRequest('Maximum size is reached')
         while length > 0:
             data = sock.read(min(length, blocksize))
+            if not data:
+                raise BadRequest()
             length -= len(data)
             yield data
 
@@ -511,7 +560,8 @@ class ObjectWrapper(object):
     Read from the object using the offset and length provided in each entry of the range list.
     """
     
-    def __init__(self, ranges, sizes, hashmaps, boundary):
+    def __init__(self, backend, ranges, sizes, hashmaps, boundary):
+        self.backend = backend
         self.ranges = ranges
         self.sizes = sizes
         self.hashmaps = hashmaps
@@ -539,16 +589,16 @@ class ObjectWrapper(object):
                 file_size = self.sizes[self.file_index]
             
             # Get the block for the current position.
-            self.block_index = int(self.offset / backend.block_size)
+            self.block_index = int(self.offset / self.backend.block_size)
             if self.block_hash != self.hashmaps[self.file_index][self.block_index]:
                 self.block_hash = self.hashmaps[self.file_index][self.block_index]
                 try:
-                    self.block = backend.get_block(self.block_hash)
+                    self.block = self.backend.get_block(self.block_hash)
                 except NameError:
                     raise ItemNotFound('Block does not exist')
             
             # Get the data from the block.
-            bo = self.offset % backend.block_size
+            bo = self.offset % self.backend.block_size
             bl = min(self.length, len(self.block) - bo)
             data = self.block[bo:bo + bl]
             self.offset += bl
@@ -622,7 +672,7 @@ def object_data_response(request, sizes, hashmaps, meta, public=False):
         boundary = uuid.uuid4().hex
     else:
         boundary = ''
-    wrapper = ObjectWrapper(ranges, sizes, hashmaps, boundary)
+    wrapper = ObjectWrapper(request.backend, ranges, sizes, hashmaps, boundary)
     response = HttpResponse(wrapper, status=ret)
     put_object_headers(response, meta, public)
     if ret == 206:
@@ -635,37 +685,38 @@ def object_data_response(request, sizes, hashmaps, meta, public=False):
             response['Content-Type'] = 'multipart/byteranges; boundary=%s' % (boundary,)
     return response
 
-def put_object_block(hashmap, data, offset):
+def put_object_block(request, hashmap, data, offset):
     """Put one block of data at the given offset."""
     
-    bi = int(offset / backend.block_size)
-    bo = offset % backend.block_size
-    bl = min(len(data), backend.block_size - bo)
+    block_size = request.backend.block_size
+    bi, bo = int(offset / block_size), offset % block_size
+    bl = min(len(data), block_size - bo)
     if bi < len(hashmap):
-        hashmap[bi] = backend.update_block(hashmap[bi], data[:bl], bo)
+        hashmap[bi] = request.backend.update_block(hashmap[bi], data[:bl], bo)
     else:
-        hashmap.append(backend.put_block(('\x00' * bo) + data[:bl]))
+        hashmap.append(request.backend.put_block('\x00' * bo + data[:bl]))
     return bl # Return ammount of data written.
 
-def hashmap_hash(hashmap):
+def hashmap_hash(request, hashmap):
     """Produce the root hash, treating the hashmap as a Merkle-like tree."""
     
     def subhash(d):
-        h = hashlib.new(backend.hash_algorithm)
+        h = hashlib.new(request.backend.hash_algorithm)
         h.update(d)
         return h.digest()
     
     if len(hashmap) == 0:
         return hexlify(subhash(''))
     if len(hashmap) == 1:
-        return hexlify(hashmap[0])
+        return hashmap[0]
+
     s = 2
     while s < len(hashmap):
         s = s * 2
-    h = hashmap + ([('\x00' * len(hashmap[0]))] * (s - len(hashmap)))
-    h = [subhash(h[x] + (h[x + 1] if x + 1 < len(h) else '')) for x in range(0, len(h), 2)]
+    h = list(map(unhexlify, hashmap))
+    h.extend(['\x00' * len(h[0])] * (s - len(h)))
     while len(h) > 1:
-        h = [subhash(h[x] + (h[x + 1] if x + 1 < len(h) else '')) for x in range(0, len(h), 2)]
+        h = [subhash(left + right) for left, right in zip(h[::2], h[1::2])]
     return hexlify(h[0])
 
 def update_response_headers(request, response):
@@ -707,16 +758,16 @@ def request_serialization(request, format_allowed=False):
     elif format == 'xml':
         return 'xml'
     
-#     for item in request.META.get('HTTP_ACCEPT', '').split(','):
-#         accept, sep, rest = item.strip().partition(';')
-#         if accept == 'application/json':
-#             return 'json'
-#         elif accept == 'application/xml' or accept == 'text/xml':
-#             return 'xml'
+    for item in request.META.get('HTTP_ACCEPT', '').split(','):
+        accept, sep, rest = item.strip().partition(';')
+        if accept == 'application/json':
+            return 'json'
+        elif accept == 'application/xml' or accept == 'text/xml':
+            return 'xml'
     
     return 'text'
 
-def api_method(http_method=None, format_allowed=False):
+def api_method(http_method=None, format_allowed=False, user_required=True):
     """Decorator function for views that implement an API method."""
     
     def decorator(func):
@@ -725,6 +776,8 @@ def api_method(http_method=None, format_allowed=False):
             try:
                 if http_method and request.method != http_method:
                     raise BadRequest('Method not allowed.')
+                if user_required and getattr(request, 'user', None) is None:
+                    raise Unauthorized('Access denied')
                 
                 # The args variable may contain up to (account, container, object).
                 if len(args) > 1 and len(args[1]) > 256:
@@ -734,7 +787,8 @@ def api_method(http_method=None, format_allowed=False):
                 
                 # Fill in custom request variables.
                 request.serialization = request_serialization(request, format_allowed)
-                
+                request.backend = connect_backend()
+
                 response = func(request, *args, **kwargs)
                 update_response_headers(request, response)
                 return response
@@ -744,5 +798,8 @@ def api_method(http_method=None, format_allowed=False):
                 logger.exception('Unexpected error: %s' % e)
                 fault = ServiceUnavailable('Unexpected error')
                 return render_fault(request, fault)
+            finally:
+                if getattr(request, 'backend', None) is not None:
+                    request.backend.wrapper.conn.close()
         return wrapper
     return decorator