Use 401 error when missing token and 403 when forbidden.
[pithos] / pithos / api / util.py
index d60803d..296be84 100644 (file)
@@ -45,17 +45,17 @@ from django.utils.http import http_date, parse_etags
 from django.utils.encoding import smart_str
 
 from pithos.api.compat import parse_http_date_safe, parse_http_date
-from pithos.api.faults import (Fault, NotModified, BadRequest, Unauthorized, ItemNotFound,
-                                Conflict, LengthRequired, PreconditionFailed, RangeNotSatisfiable,
-                                ServiceUnavailable)
+from pithos.api.faults import (Fault, NotModified, BadRequest, Unauthorized, Forbidden, ItemNotFound,
+                                Conflict, LengthRequired, PreconditionFailed, RequestEntityTooLarge,
+                                RangeNotSatisfiable, ServiceUnavailable)
 from pithos.backends import connect_backend
-from pithos.backends.base import NotAllowedError
+from pithos.backends.base import NotAllowedError, QuotaError
 
 import logging
 import re
 import hashlib
 import uuid
-
+import decimal
 
 logger = logging.getLogger(__name__)
 
@@ -70,6 +70,11 @@ class UTC(tzinfo):
    def dst(self, dt):
        return timedelta(0)
 
+def json_encode_decimal(obj):
+    if isinstance(obj, decimal.Decimal):
+        return str(obj)
+    raise TypeError(repr(obj) + " is not JSON serializable")
+
 def isoformat(d):
    """Return an ISO8601 date string that includes a timezone."""
 
@@ -110,11 +115,11 @@ def get_account_headers(request):
         if '-' in n or '_' in n:
             raise BadRequest('Bad characters in group name')
         groups[n] = v.replace(' ', '').split(',')
-        if '' in groups[n]:
+        while '' in groups[n]:
             groups[n].remove('')
     return meta, groups
 
-def put_account_headers(response, meta, groups):
+def put_account_headers(response, meta, groups, policy):
     if 'count' in meta:
         response['X-Account-Container-Count'] = meta['count']
     if 'bytes' in meta:
@@ -129,7 +134,9 @@ def put_account_headers(response, meta, groups):
         k = format_header_key('X-Account-Group-' + k)
         v = smart_str(','.join(v), strings_only=True)
         response[k] = v
-    
+    for k, v in policy.iteritems():
+        response[smart_str(format_header_key('X-Account-Policy-' + k), strings_only=True)] = smart_str(v, strings_only=True)
+
 def get_container_headers(request):
     meta = get_header_prefix(request, 'X-Container-Meta-')
     policy = dict([(k[19:].lower(), v.replace(' ', '')) for k, v in get_header_prefix(request, 'X-Container-Policy-').iteritems()])
@@ -285,7 +292,6 @@ def copy_or_move_object(request, src_account, src_container, src_name, dest_acco
     """Copy or move an object."""
     
     meta, permissions, public = get_object_headers(request)
-    print '---', meta, permissions, public
     src_version = request.META.get('HTTP_X_SOURCE_VERSION')
     try:
         if move:
@@ -297,18 +303,20 @@ def copy_or_move_object(request, src_account, src_container, src_name, dest_acco
                                                         dest_account, dest_container, dest_name,
                                                         meta, False, permissions, src_version)
     except NotAllowedError:
-        raise Unauthorized('Access denied')
+        raise Forbidden('Not allowed')
     except (NameError, IndexError):
         raise ItemNotFound('Container or object does not exist')
     except ValueError:
         raise BadRequest('Invalid sharing header')
     except AttributeError, e:
         raise Conflict('\n'.join(e.data) + '\n')
+    except QuotaError:
+        raise RequestEntityTooLarge('Quota exceeded')
     if public is not None:
         try:
             request.backend.update_object_public(request.user, dest_account, dest_container, dest_name, public)
         except NotAllowedError:
-            raise Unauthorized('Access denied')
+            raise Forbidden('Not allowed')
         except NameError:
             raise ItemNotFound('Object does not exist')
     return version_id
@@ -492,7 +500,7 @@ def socket_read_iterator(request, length=0, blocksize=4096):
     sock = raw_input_socket(request)
     if length < 0: # Chunked transfers
         # Small version (server does the dechunking).
-        if request.environ.get('mod_wsgi.input_chunked', None):
+        if request.environ.get('mod_wsgi.input_chunked', None) or request.META['SERVER_SOFTWARE'].startswith('gunicorn'):
             while length < MAX_UPLOAD_SIZE:
                 data = sock.read(blocksize)
                 if data == '':
@@ -759,7 +767,7 @@ def request_serialization(request, format_allowed=False):
     
     return 'text'
 
-def api_method(http_method=None, format_allowed=False):
+def api_method(http_method=None, format_allowed=False, user_required=True):
     """Decorator function for views that implement an API method."""
     
     def decorator(func):
@@ -768,6 +776,8 @@ def api_method(http_method=None, format_allowed=False):
             try:
                 if http_method and request.method != http_method:
                     raise BadRequest('Method not allowed.')
+                if user_required and getattr(request, 'user', None) is None:
+                    raise Unauthorized('Access denied')
                 
                 # The args variable may contain up to (account, container, object).
                 if len(args) > 1 and len(args[1]) > 256:
@@ -789,6 +799,7 @@ def api_method(http_method=None, format_allowed=False):
                 fault = ServiceUnavailable('Unexpected error')
                 return render_fault(request, fault)
             finally:
-                request.backend.wrapper.conn.close()
+                if getattr(request, 'backend', None) is not None:
+                    request.backend.wrapper.conn.close()
         return wrapper
     return decorator