backend components in SQLAlchemy: Progress IV
[pithos] / pithos / backends / modular_alchemy.py
index 0ca00cc..1079d56 100644 (file)
@@ -42,8 +42,8 @@ from base import NotAllowedError, BaseBackend
 from lib_alchemy.node import Node, ROOTNODE, SERIAL, SIZE, MTIME, MUSER, CLUSTER
 from lib_alchemy.permissions import Permissions, READ, WRITE
 from lib_alchemy.policy import Policy
-from lib_alchemy.hashfiler import Mapper, Blocker
 from sqlalchemy import create_engine
+from lib.hashfiler import Mapper, Blocker
 
 ( CLUSTER_NORMAL, CLUSTER_HISTORY, CLUSTER_DELETED ) = range(3)
 
@@ -61,13 +61,13 @@ def backend_method(func=None, autocommit=1):
     if not autocommit:
         return func
     def fn(self, *args, **kw):
-        self.con.execute('begin deferred')
+        trans = self.con.begin()
         try:
             ret = func(self, *args, **kw)
-            self.con.commit()
+            trans.commit()
             return ret
         except:
-            self.con.rollback()
+            trans.rollback()
             raise
     return fn
 
@@ -78,7 +78,7 @@ class ModularBackend(BaseBackend):
     Uses modules for SQL functions and storage.
     """
     
-    def __init__(self, db):
+    def __init__(self, db, db_options):
         self.hash_algorithm = 'sha256'
         self.block_size = 4 * 1024 * 1024 # 4MB
         
@@ -90,12 +90,9 @@ class ModularBackend(BaseBackend):
         if not os.path.isdir(basepath):
             raise RuntimeError("Cannot open database at '%s'" % (db,))
         
-        dbuser = 'pithos'
-        dbpass = 'archipelagos'
-        dbhost = '62.217.112.56'
-        dbname = 'pithosdb'
-        connection_str = 'mysql://%s:%s@%s/%s' %(dbuser, dbpass, dbhost, dbname)
+        connection_str = 'postgresql://%s:%s@%s/%s' % db_options
         engine = create_engine(connection_str, echo=True)
+        self.con = engine.connect()
         
         params = {'blocksize': self.block_size,
                   'blockpath': basepath + '/blocks',
@@ -106,19 +103,17 @@ class ModularBackend(BaseBackend):
                   'namelen': self.blocker.hashlen}
         self.mapper = Mapper(**params)
         
-        params = {'connection': engine.connect(),
+        params = {'connection': self.con,
                   'engine': engine}
         self.permissions = Permissions(**params)
         self.policy = Policy(**params)
         self.node = Node(**params)
-        
-        self.con.commit()
     
     @backend_method
     def list_accounts(self, user, marker=None, limit=10000):
         """Return a list of accounts the user can access."""
         
-        logger.debug("list_accounts: %s %s", user, marker, limit)
+        logger.debug("list_accounts: %s %s %s", user, marker, limit)
         allowed = self._allowed_accounts(user)
         start, limit = self._list_limits(allowed, marker, limit)
         return allowed[start:start + limit]
@@ -369,6 +364,8 @@ class ModularBackend(BaseBackend):
         else:
             if shared:
                 allowed = self.permissions.access_list_shared('/'.join((account, container)))
+                if not allowed:
+                    return []
         path, node = self._lookup_container(account, container)
         return self._list_objects(node, path, prefix, delimiter, marker, limit, virtual, keys, until, allowed)
     
@@ -561,6 +558,7 @@ class ModularBackend(BaseBackend):
         
         logger.debug("list_versions: %s %s %s", account, container, name)
         self._can_read(user, account, container, name)
+        path, node = self._lookup_object(account, container, name)
         return self.node.node_get_versions(node, ['serial', 'mtime'])
     
     @backend_method(autocommit=0)
@@ -575,7 +573,7 @@ class ModularBackend(BaseBackend):
     
     @backend_method(autocommit=0)
     def put_block(self, data):
-        """Create a block and return the hash."""
+        """Store a block and return the hash."""
         
         logger.debug("put_block: %s", len(data))
         hashes, absent = self.blocker.block_stor((data,))
@@ -591,95 +589,6 @@ class ModularBackend(BaseBackend):
         h, e = self.blocker.block_delta(binascii.unhexlify(hash), ((offset, data),))
         return binascii.hexlify(h)
     
-    def _check_policy(self, policy):
-        for k in policy.keys():
-            if policy[k] == '':
-                policy[k] = self.default_policy.get(k)
-        for k, v in policy.iteritems():
-            if k == 'quota':
-                q = int(v) # May raise ValueError.
-                if q < 0:
-                    raise ValueError
-            elif k == 'versioning':
-                if v not in ['auto', 'manual', 'none']:
-                    raise ValueError
-            else:
-                raise ValueError
-    
-    def _sql_until(self, parent, until=None):
-        """Return the sql to get the latest versions until the timestamp given."""
-        
-        if until is None:
-            until = time.time()
-        sql = ("select v.serial, n.path, v.mtime, v.size "
-               "from versions v, nodes n "
-               "where v.serial = (select max(serial) "
-                                 "from versions "
-                                 "where node = v.node and mtime < %s) "
-               "and v.cluster != %s "
-               "and v.node = n.node "
-               "and v.node in (select node "
-                              "from nodes "
-                              "where parent = %s)")
-        return sql % (until, CLUSTER_DELETED, parent)
-    
-    def _list_limits(self, listing, marker, limit):
-        start = 0
-        if marker:
-            try:
-                start = listing.index(marker) + 1
-            except ValueError:
-                pass
-        if not limit or limit > 10000:
-            limit = 10000
-        return start, limit
-    
-    def _list_objects(self, parent, path, prefix='', delimiter=None, marker=None, limit=10000, virtual=True, keys=[], until=None, allowed=[]):
-        cont_prefix = path + '/'
-        if keys and len(keys) > 0:
-#             sql = '''select distinct o.name, o.version_id from (%s) o, metadata m where o.name like ? and
-#                         m.version_id = o.version_id and m.key in (%s)'''
-#             sql = sql % (self._sql_until(until), ', '.join('?' * len(keys)))
-#             param = (cont_prefix + prefix + '%',) + tuple(keys)
-#             if allowed:
-#                 sql += ' and (' + ' or '.join(('o.name like ?',) * len(allowed)) + ')'
-#                 param += tuple([x + '%' for x in allowed])
-#             sql += ' order by o.name'
-            return []
-        else:
-            sql = 'select path, serial from (%s) where path like ?'
-            sql = sql % self._sql_until(parent, until)
-            param = (cont_prefix + prefix + '%',)
-            if allowed:
-                sql += ' and (' + ' or '.join(('name like ?',) * len(allowed)) + ')'
-                param += tuple([x + '%' for x in allowed])
-            sql += ' order by path'
-        c = self.con.execute(sql, param)
-        objects = [(x[0][len(cont_prefix):], x[1]) for x in c.fetchall()]
-        if delimiter:
-            pseudo_objects = []
-            for x in objects:
-                pseudo_name = x[0]
-                i = pseudo_name.find(delimiter, len(prefix))
-                if not virtual:
-                    # If the delimiter is not found, or the name ends
-                    # with the delimiter's first occurence.
-                    if i == -1 or len(pseudo_name) == i + len(delimiter):
-                        pseudo_objects.append(x)
-                else:
-                    # If the delimiter is found, keep up to (and including) the delimiter.
-                    if i != -1:
-                        pseudo_name = pseudo_name[:i + len(delimiter)]
-                    if pseudo_name not in [y[0] for y in pseudo_objects]:
-                        if pseudo_name == x[0]:
-                            pseudo_objects.append(x)
-                        else:
-                            pseudo_objects.append((pseudo_name, None))
-            objects = pseudo_objects
-        
-        start, limit = self._list_limits([x[0] for x in objects], marker, limit)
-        return objects[start:start + limit]
-    
     # Path functions.
     
     def _put_object_node(self, account, container, name):
@@ -801,6 +710,49 @@ class ModularBackend(BaseBackend):
         if copy_data and src_version_id is not None:
             self._copy_data(src_version_id, dest_version_id)
     
+    def _list_limits(self, listing, marker, limit):
+        start = 0
+        if marker:
+            try:
+                start = listing.index(marker) + 1
+            except ValueError:
+                pass
+        if not limit or limit > 10000:
+            limit = 10000
+        return start, limit
+    
+    def _list_objects(self, parent, path, prefix='', delimiter=None, marker=None, limit=10000, virtual=True, keys=[], until=None, allowed=[]):
+        cont_prefix = path + '/'
+        prefix = cont_prefix + prefix
+        start = cont_prefix + marker if marker else None
+        before = until if until is not None else inf
+        filterq = ','.join(keys) if keys else None
+        
+        objects, prefixes = self.node.latest_version_list(parent, prefix, delimiter, start, limit, before, CLUSTER_DELETED, allowed, filterq)
+        objects.extend([(p, None) for p in prefixes] if virtual else [])
+        objects.sort()
+        objects = [(x[0][len(cont_prefix):], x[1]) for x in objects]
+        
+        start, limit = self._list_limits([x[0] for x in objects], marker, limit)
+        return objects[start:start + limit]
+    
+    # Policy functions.
+    
+    def _check_policy(self, policy):
+        for k in policy.keys():
+            if policy[k] == '':
+                policy[k] = self.default_policy.get(k)
+        for k, v in policy.iteritems():
+            if k == 'quota':
+                q = int(v) # May raise ValueError.
+                if q < 0:
+                    raise ValueError
+            elif k == 'versioning':
+                if v not in ['auto', 'manual', 'none']:
+                    raise ValueError
+            else:
+                raise ValueError
+    
     # Access control functions.
     
     def _check_groups(self, groups):