Use a 401 error when the token is missing and a 403 error when access is forbidden.
[pithos] / pithos / api / util.py
index 0e0aaed..296be84 100644 (file)
@@ -45,11 +45,11 @@ from django.utils.http import http_date, parse_etags
 from django.utils.encoding import smart_str
 
 from pithos.api.compat import parse_http_date_safe, parse_http_date
-from pithos.api.faults import (Fault, NotModified, BadRequest, Unauthorized, ItemNotFound,
-                                Conflict, LengthRequired, PreconditionFailed, RangeNotSatisfiable,
-                                ServiceUnavailable)
+from pithos.api.faults import (Fault, NotModified, BadRequest, Unauthorized, Forbidden, ItemNotFound,
+                                Conflict, LengthRequired, PreconditionFailed, RequestEntityTooLarge,
+                                RangeNotSatisfiable, ServiceUnavailable)
 from pithos.backends import connect_backend
-from pithos.backends.base import NotAllowedError
+from pithos.backends.base import NotAllowedError, QuotaError
 
 import logging
 import re
@@ -119,7 +119,7 @@ def get_account_headers(request):
             groups[n].remove('')
     return meta, groups
 
-def put_account_headers(response, meta, groups):
+def put_account_headers(response, meta, groups, policy):
     if 'count' in meta:
         response['X-Account-Container-Count'] = meta['count']
     if 'bytes' in meta:
@@ -134,7 +134,9 @@ def put_account_headers(response, meta, groups):
         k = format_header_key('X-Account-Group-' + k)
         v = smart_str(','.join(v), strings_only=True)
         response[k] = v
-    
+    for k, v in policy.iteritems():
+        response[smart_str(format_header_key('X-Account-Policy-' + k), strings_only=True)] = smart_str(v, strings_only=True)
+
 def get_container_headers(request):
     meta = get_header_prefix(request, 'X-Container-Meta-')
     policy = dict([(k[19:].lower(), v.replace(' ', '')) for k, v in get_header_prefix(request, 'X-Container-Policy-').iteritems()])
@@ -290,7 +292,6 @@ def copy_or_move_object(request, src_account, src_container, src_name, dest_acco
     """Copy or move an object."""
     
     meta, permissions, public = get_object_headers(request)
-    print '---', meta, permissions, public
     src_version = request.META.get('HTTP_X_SOURCE_VERSION')
     try:
         if move:
@@ -302,18 +303,20 @@ def copy_or_move_object(request, src_account, src_container, src_name, dest_acco
                                                         dest_account, dest_container, dest_name,
                                                         meta, False, permissions, src_version)
     except NotAllowedError:
-        raise Unauthorized('Access denied')
+        raise Forbidden('Not allowed')
     except (NameError, IndexError):
         raise ItemNotFound('Container or object does not exist')
     except ValueError:
         raise BadRequest('Invalid sharing header')
     except AttributeError, e:
         raise Conflict('\n'.join(e.data) + '\n')
+    except QuotaError:
+        raise RequestEntityTooLarge('Quota exceeded')
     if public is not None:
         try:
             request.backend.update_object_public(request.user, dest_account, dest_container, dest_name, public)
         except NotAllowedError:
-            raise Unauthorized('Access denied')
+            raise Forbidden('Not allowed')
         except NameError:
             raise ItemNotFound('Object does not exist')
     return version_id
@@ -497,7 +500,7 @@ def socket_read_iterator(request, length=0, blocksize=4096):
     sock = raw_input_socket(request)
     if length < 0: # Chunked transfers
         # Small version (server does the dechunking).
-        if request.environ.get('mod_wsgi.input_chunked', None):
+        if request.environ.get('mod_wsgi.input_chunked', None) or request.META['SERVER_SOFTWARE'].startswith('gunicorn'):
             while length < MAX_UPLOAD_SIZE:
                 data = sock.read(blocksize)
                 if data == '':
@@ -764,7 +767,7 @@ def request_serialization(request, format_allowed=False):
     
     return 'text'
 
-def api_method(http_method=None, format_allowed=False):
+def api_method(http_method=None, format_allowed=False, user_required=True):
     """Decorator function for views that implement an API method."""
     
     def decorator(func):
@@ -773,6 +776,8 @@ def api_method(http_method=None, format_allowed=False):
             try:
                 if http_method and request.method != http_method:
                     raise BadRequest('Method not allowed.')
+                if user_required and getattr(request, 'user', None) is None:
+                    raise Unauthorized('Access denied')
                 
                 # The args variable may contain up to (account, container, object).
                 if len(args) > 1 and len(args[1]) > 256:
@@ -794,6 +799,7 @@ def api_method(http_method=None, format_allowed=False):
                 fault = ServiceUnavailable('Unexpected error')
                 return render_fault(request, fault)
             finally:
-                request.backend.wrapper.conn.close()
+                if getattr(request, 'backend', None) is not None:
+                    request.backend.wrapper.conn.close()
         return wrapper
     return decorator