cleanup pithos backend pools, new pool api support
[pithos] / snf-pithos-app / pithos / api / util.py
index c9fed29..f8233f9 100644 (file)
@@ -41,20 +41,30 @@ from urllib import quote, unquote
 
 from django.conf import settings
 from django.http import HttpResponse
+from django.template.loader import render_to_string
 from django.utils import simplejson as json
 from django.utils.http import http_date, parse_etags
 from django.utils.encoding import smart_unicode, smart_str
 from django.core.files.uploadhandler import FileUploadHandler
 from django.core.files.uploadedfile import UploadedFile
 
-from pithos.lib.compat import parse_http_date_safe, parse_http_date
+from synnefo.lib.parsedate import parse_http_date_safe, parse_http_date
+from synnefo.lib.astakos import get_user
 
 from pithos.api.faults import (Fault, NotModified, BadRequest, Unauthorized, Forbidden, ItemNotFound,
                                 Conflict, LengthRequired, PreconditionFailed, RequestEntityTooLarge,
                                 RangeNotSatisfiable, InternalServerError, NotImplemented)
 from pithos.api.short_url import encode_url
+from pithos.api.settings import (BACKEND_DB_MODULE, BACKEND_DB_CONNECTION,
+                                    BACKEND_BLOCK_MODULE, BACKEND_BLOCK_PATH,
+                                    BACKEND_BLOCK_UMASK,
+                                    BACKEND_QUEUE_MODULE, BACKEND_QUEUE_CONNECTION,
+                                    BACKEND_QUOTA, BACKEND_VERSIONING,
+                                    AUTHENTICATION_URL, AUTHENTICATION_USERS,
+                                    SERVICE_TOKEN, COOKIE_NAME)
+
 from pithos.backends import connect_backend
-from pithos.backends.base import NotAllowedError, QuotaError
+from pithos.backends.base import NotAllowedError, QuotaError, ItemNotExists, VersionNotExists
 
 import logging
 import re
@@ -99,7 +109,7 @@ def printable_header_dict(d):
     Format 'last_modified' timestamp.
     """
     
-    if 'last_modified' in d:
+    if 'last_modified' in d and d['last_modified']:
         d['last_modified'] = isoformat(datetime.fromtimestamp(d['last_modified']))
     return dict([(k.lower().replace('-', '_'), v) for k, v in d.iteritems()])
 
@@ -114,8 +124,18 @@ def get_header_prefix(request, prefix):
     # TODO: Document or remove '~' replacing.
     return dict([(format_header_key(k[5:]), v.replace('~', '')) for k, v in request.META.iteritems() if k.startswith(prefix) and len(k) > len(prefix)])
 
+def check_meta_headers(meta):
+    if len(meta) > 90:
+        raise BadRequest('Too many headers.')
+    for k, v in meta.iteritems():
+        if len(k) > 128:
+            raise BadRequest('Header name too large.')
+        if len(v) > 256:
+            raise BadRequest('Header value too large.')
+
 def get_account_headers(request):
     meta = get_header_prefix(request, 'X-Account-Meta-')
+    check_meta_headers(meta)
     groups = {}
     for k, v in get_header_prefix(request, 'X-Account-Group-').iteritems():
         n = k[16:].lower()
@@ -146,6 +166,7 @@ def put_account_headers(response, meta, groups, policy):
 
 def get_container_headers(request):
     meta = get_header_prefix(request, 'X-Container-Meta-')
+    check_meta_headers(meta)
     policy = dict([(k[19:].lower(), v.replace(' ', '')) for k, v in get_header_prefix(request, 'X-Container-Policy-').iteritems()])
     return meta, policy
 
@@ -167,22 +188,21 @@ def put_container_headers(request, response, meta, policy):
         response[smart_str(format_header_key('X-Container-Policy-' + k), strings_only=True)] = smart_str(v, strings_only=True)
 
 def get_object_headers(request):
+    content_type = request.META.get('CONTENT_TYPE', None)
     meta = get_header_prefix(request, 'X-Object-Meta-')
-    if request.META.get('CONTENT_TYPE'):
-        meta['Content-Type'] = request.META['CONTENT_TYPE']
+    check_meta_headers(meta)
     if request.META.get('HTTP_CONTENT_ENCODING'):
         meta['Content-Encoding'] = request.META['HTTP_CONTENT_ENCODING']
     if request.META.get('HTTP_CONTENT_DISPOSITION'):
         meta['Content-Disposition'] = request.META['HTTP_CONTENT_DISPOSITION']
     if request.META.get('HTTP_X_OBJECT_MANIFEST'):
         meta['X-Object-Manifest'] = request.META['HTTP_X_OBJECT_MANIFEST']
-    return meta, get_sharing(request), get_public(request)
+    return content_type, meta, get_sharing(request), get_public(request)
 
 def put_object_headers(response, meta, restricted=False):
-    if 'ETag' in meta:
-        response['ETag'] = meta['ETag']
+    response['ETag'] = meta['checksum']
     response['Content-Length'] = meta['bytes']
-    response['Content-Type'] = meta.get('Content-Type', 'application/octet-stream')
+    response['Content-Type'] = meta.get('type', 'application/octet-stream')
     response['Last-Modified'] = http_date(int(meta['modified']))
     if not restricted:
         response['X-Object-Hash'] = meta['hash']
@@ -215,8 +235,7 @@ def update_manifest_meta(request, v_account, meta):
             for x in objects:
                 src_meta = request.backend.get_object_meta(request.user_uniq,
                                         v_account, src_container, x[0], 'pithos', x[1])
-                if 'ETag' in src_meta:
-                    etag += src_meta['ETag']
+                etag += src_meta['checksum']
                 bytes += src_meta['bytes']
         except:
             # Ignore errors.
@@ -224,7 +243,7 @@ def update_manifest_meta(request, v_account, meta):
         meta['bytes'] = bytes
         md5 = hashlib.md5()
         md5.update(etag)
-        meta['ETag'] = md5.hexdigest().lower()
+        meta['checksum'] = md5.hexdigest().lower()
 
 def update_sharing_meta(request, permissions, v_account, v_container, v_object, meta):
     if permissions is None:
@@ -271,7 +290,9 @@ def validate_modification_preconditions(request, meta):
 def validate_matching_preconditions(request, meta):
     """Check that the ETag conforms with the preconditions set."""
     
-    etag = meta.get('ETag', None)
+    etag = meta['checksum']
+    if not etag:
+        etag = None
     
     if_match = request.META.get('HTTP_IF_MATCH')
     if if_match is not None:
@@ -299,28 +320,28 @@ def split_container_object_string(s):
         raise ValueError
     return s[:pos], s[(pos + 1):]
 
-def copy_or_move_object(request, src_account, src_container, src_name, dest_account, dest_container, dest_name, move=False):
+def copy_or_move_object(request, src_account, src_container, src_name, dest_account, dest_container, dest_name, move=False, delimiter=None):
     """Copy or move an object."""
     
-    meta, permissions, public = get_object_headers(request)
+    if 'ignore_content_type' in request.GET and 'CONTENT_TYPE' in request.META:
+        del(request.META['CONTENT_TYPE'])
+    content_type, meta, permissions, public = get_object_headers(request)
     src_version = request.META.get('HTTP_X_SOURCE_VERSION')
     try:
         if move:
             version_id = request.backend.move_object(request.user_uniq, src_account, src_container, src_name,
                                                         dest_account, dest_container, dest_name,
-                                                        'pithos', meta, False, permissions)
+                                                        content_type, 'pithos', meta, False, permissions, delimiter)
         else:
             version_id = request.backend.copy_object(request.user_uniq, src_account, src_container, src_name,
                                                         dest_account, dest_container, dest_name,
-                                                        'pithos', meta, False, permissions, src_version)
+                                                        content_type, 'pithos', meta, False, permissions, src_version, delimiter)
     except NotAllowedError:
         raise Forbidden('Not allowed')
-    except (NameError, IndexError):
+    except (ItemNotExists, VersionNotExists):
         raise ItemNotFound('Container or object does not exist')
     except ValueError:
         raise BadRequest('Invalid sharing header')
-    except AttributeError, e:
-        raise Conflict('\n'.join(e.data) + '\n')
     except QuotaError:
         raise RequestEntityTooLarge('Quota exceeded')
     if public is not None:
@@ -328,7 +349,7 @@ def copy_or_move_object(request, src_account, src_container, src_name, dest_acco
             request.backend.update_object_public(request.user_uniq, dest_account, dest_container, dest_name, public)
         except NotAllowedError:
             raise Forbidden('Not allowed')
-        except NameError:
+        except ItemNotExists:
             raise ItemNotFound('Object does not exist')
     return version_id
 
@@ -639,12 +660,16 @@ class ObjectWrapper(object):
                 self.block_hash = self.hashmaps[self.file_index][self.block_index]
                 try:
                     self.block = self.backend.get_block(self.block_hash)
-                except NameError:
+                except ItemNotExists:
                     raise ItemNotFound('Block does not exist')
             
             # Get the data from the block.
             bo = self.offset % self.backend.block_size
-            bl = min(self.length, len(self.block) - bo)
+            bs = self.backend.block_size
+            if (self.block_index == len(self.hashmaps[self.file_index]) - 1 and
+                self.sizes[self.file_index] % self.backend.block_size):
+                bs = self.sizes[self.file_index] % self.backend.block_size
+            bl = min(self.length, bs - bo)
             data = self.block[bo:bo + bl]
             self.offset += bl
             self.length -= bl
@@ -709,7 +734,7 @@ def object_data_response(request, sizes, hashmaps, meta, public=False):
                     ranges = [(0, size)]
                     ret = 200
             except ValueError:
-                if if_range != meta['ETag']:
+                if if_range != meta['checksum']:
                     ranges = [(0, size)]
                     ret = 200
     
@@ -742,40 +767,88 @@ def put_object_block(request, hashmap, data, offset):
         hashmap.append(request.backend.put_block(('\x00' * bo) + data[:bl]))
     return bl # Return amount of data written.
 
-def hashmap_md5(request, hashmap, size):
+def hashmap_md5(backend, hashmap, size):
     """Produce the MD5 sum from the data in the hashmap."""
     
     # TODO: Search backend for the MD5 of another object with the same hashmap and size...
     md5 = hashlib.md5()
-    bs = request.backend.block_size
+    bs = backend.block_size
     for bi, hash in enumerate(hashmap):
-        data = request.backend.get_block(hash)
+        data = backend.get_block(hash) # Blocks come in padded.
         if bi == len(hashmap) - 1:
-            bs = size % bs
-        pad = bs - min(len(data), bs)
-        md5.update(data + ('\x00' * pad))
+            data = data[:size % bs]
+        md5.update(data)
     return md5.hexdigest().lower()
 
-def get_backend():
-    backend = connect_backend(db_module=settings.BACKEND_DB_MODULE,
-                              db_connection=settings.BACKEND_DB_CONNECTION,
-                              block_module=settings.BACKEND_BLOCK_MODULE,
-                              block_path=settings.BACKEND_BLOCK_PATH)
-    backend.default_policy['quota'] = settings.BACKEND_QUOTA
-    backend.default_policy['versioning'] = settings.BACKEND_VERSIONING
+def simple_list_response(request, l):
+    if request.serialization == 'text':
+        return '\n'.join(l) + '\n'
+    if request.serialization == 'xml':
+        return render_to_string('items.xml', {'items': l})
+    if request.serialization == 'json':
+        return json.dumps(l)
+
+
+def _get_backend():
+    backend = connect_backend(db_module=BACKEND_DB_MODULE,
+                              db_connection=BACKEND_DB_CONNECTION,
+                              block_module=BACKEND_BLOCK_MODULE,
+                              block_path=BACKEND_BLOCK_PATH,
+                              block_umask=BACKEND_BLOCK_UMASK,
+                              queue_module=BACKEND_QUEUE_MODULE,
+                              queue_connection=BACKEND_QUEUE_CONNECTION)
+    backend.default_policy['quota'] = BACKEND_QUOTA
+    backend.default_policy['versioning'] = BACKEND_VERSIONING
     return backend
 
+
+def _pooled_backend_close(backend):
+    backend._pool.pool_put(backend)
+
+
+from synnefo.lib.pool import ObjectPool
+from new import instancemethod
+
+USAGE_LIMIT = 500
+POOL_SIZE = 5
+
+class PithosBackendPool(ObjectPool):
+    def _pool_create(self):
+        backend = _get_backend()
+        backend._real_close = backend.close
+        backend.close = instancemethod(_pooled_backend_close, backend,
+                                       type(backend))
+        backend._pool = self
+        backend._use_count = USAGE_LIMIT
+        return backend
+
+    def _pool_verify(self, backend):
+        return 1
+
+    def _pool_cleanup(self, backend):
+        c = backend._use_count - 1
+        if c < 0:
+            backend._real_close()
+            return True
+
+        backend._use_count = c
+        if backend.trans is not None:
+            backend.wrapper.rollback()
+        if backend.messages:
+            backend.messages = []
+        return False
+
+_pithos_backend_pool = PithosBackendPool(size=POOL_SIZE)
+
+
+def get_backend():
+    return _pithos_backend_pool.pool_get()
+
+
 def update_request_headers(request):
     # Handle URL-encoded keys and values.
-    # Handle URL-encoded keys and values.
     meta = dict([(k, v) for k, v in request.META.iteritems() if k.startswith('HTTP_')])
-    if len(meta) > 90:
-        raise BadRequest('Too many headers.')
     for k, v in meta.iteritems():
-        if len(k) > 128:
-            raise BadRequest('Header name too large.')
-        if len(v) > 256:
-            raise BadRequest('Header value too large.')
         try:
             k.decode('ascii')
             v.decode('ascii')
@@ -805,12 +878,9 @@ def update_response_headers(request, response):
             k.startswith('X-Object-') or k.startswith('Content-')):
             del(response[k])
             response[quote(k)] = quote(v, safe='/=,:@; ')
-    
-    if settings.TEST:
-        response['Date'] = format_date_time(time())
 
 def render_fault(request, fault):
-    if isinstance(fault, InternalServerError) and (settings.DEBUG or settings.TEST):
+    if isinstance(fault, InternalServerError) and settings.DEBUG:
         fault.details = format_exc(fault)
     
     request.serialization = 'text'
@@ -845,6 +915,7 @@ def request_serialization(request, format_allowed=False):
     
     return 'text'
 
+
 def api_method(http_method=None, format_allowed=False, user_required=True):
     """Decorator function for views that implement an API method."""
     
@@ -854,8 +925,16 @@ def api_method(http_method=None, format_allowed=False, user_required=True):
             try:
                 if http_method and request.method != http_method:
                     raise BadRequest('Method not allowed.')
-                if user_required and getattr(request, 'user', None) is None:
-                    raise Unauthorized('Access denied')
+                
+                if user_required:
+                    token = None
+                    if request.method in ('HEAD', 'GET') and COOKIE_NAME in request.COOKIES:
+                        cookie_value = unquote(request.COOKIES.get(COOKIE_NAME, ''))
+                        if cookie_value and '|' in cookie_value:
+                            token = cookie_value.split('|', 1)[1]
+                    get_user(request, AUTHENTICATION_URL, AUTHENTICATION_USERS, token)
+                    if  getattr(request, 'user', None) is None:
+                        raise Unauthorized('Access denied')
                 
                 # The args variable may contain up to (account, container, object).
                 if len(args) > 1 and len(args[1]) > 256:
@@ -877,7 +956,7 @@ def api_method(http_method=None, format_allowed=False, user_required=True):
                 return render_fault(request, fault)
             except BaseException, e:
                 logger.exception('Unexpected error: %s' % e)
-                fault = InternalServerError('Unexpected error')
+                fault = InternalServerError('Unexpected error: %s' % e)
                 return render_fault(request, fault)
             finally:
                 if getattr(request, 'backend', None) is not None: