Add size queries in backend object lists.
[pithos] / pithos / backends / lib / sqlalchemy / node.py
index cca025e..094e8d4 100644 (file)
 
 from time import time
 from sqlalchemy import Table, Integer, BigInteger, DECIMAL, Column, String, MetaData, ForeignKey
+from sqlalchemy.types import Text
 from sqlalchemy.schema import Index, Sequence
-from sqlalchemy.sql import func, and_, or_, null, select, bindparam
+from sqlalchemy.sql import func, and_, or_, not_, null, select, bindparam, text, exists
 from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.engine.reflection import Inspector
 
 from dbworker import DBWorker
 
+from pithos.lib.filter import parse_filters
+
+
 ROOTNODE  = 0
 
-( SERIAL, NODE, HASH, SIZE, SOURCE, MTIME, MUSER, CLUSTER ) = range(8)
+( SERIAL, NODE, HASH, SIZE, SOURCE, MTIME, MUSER, UUID, CLUSTER ) = range(9)
 
 inf = float('inf')
 
@@ -77,11 +82,9 @@ def strprevling(prefix):
     s = prefix[:-1]
     c = ord(prefix[-1])
     if c > 0:
-        #s += unichr(c-1) + unichr(0xffff)
-        s += unichr(c-1)
+        s += unichr(c-1) + unichr(0xffff)
     return s
 
-
 _propnames = {
     'serial'    : 0,
     'node'      : 1,
@@ -90,7 +93,8 @@ _propnames = {
     'source'    : 4,
     'mtime'     : 5,
     'muser'     : 6,
-    'cluster'   : 7,
+    'uuid'      : 7,
+    'cluster'   : 8
 }
 
 
@@ -114,10 +118,20 @@ class Node(DBWorker):
                                          ondelete='CASCADE',
                                          onupdate='CASCADE'),
                               autoincrement=False))
-        columns.append(Column('path', String(2048), default='', nullable=False))
-        self.nodes = Table('nodes', metadata, *columns, mysql_engine='InnoDB', mysql_charset='utf8')
-        # place an index on path
-        Index('idx_nodes_path', self.nodes.c.path)
+        path_length = 2048
+        columns.append(Column('path', String(path_length), default='', nullable=False))
+        self.nodes = Table('nodes', metadata, *columns, mysql_engine='InnoDB')
+        
+        #create policy table
+        columns=[]
+        columns.append(Column('node', Integer,
+                              ForeignKey('nodes.node',
+                                         ondelete='CASCADE',
+                                         onupdate='CASCADE'),
+                              primary_key=True))
+        columns.append(Column('key', String(255), primary_key=True))
+        columns.append(Column('value', String(255)))
+        self.policies = Table('policy', metadata, *columns, mysql_engine='InnoDB')
         
         #create statistics table
         columns=[]
@@ -128,10 +142,10 @@ class Node(DBWorker):
                               primary_key=True))
         columns.append(Column('population', Integer, nullable=False, default=0))
         columns.append(Column('size', BigInteger, nullable=False, default=0))
-        columns.append(Column('mtime', DECIMAL))
+        columns.append(Column('mtime', DECIMAL(precision=16, scale=6)))
         columns.append(Column('cluster', Integer, nullable=False, default=0,
                               primary_key=True, autoincrement=False))
-        self.statistics = Table('statistics', metadata, *columns, mysql_engine='InnoDB', mysql_charset='utf8')
+        self.statistics = Table('statistics', metadata, *columns, mysql_engine='InnoDB')
         
         #create versions table
         columns=[]
@@ -143,12 +157,13 @@ class Node(DBWorker):
         columns.append(Column('hash', String(255)))
         columns.append(Column('size', BigInteger, nullable=False, default=0))
         columns.append(Column('source', Integer))
-        columns.append(Column('mtime', DECIMAL))
+        columns.append(Column('mtime', DECIMAL(precision=16, scale=6)))
         columns.append(Column('muser', String(255), nullable=False, default=''))
+        columns.append(Column('uuid', String(64), nullable=False, default=''))
         columns.append(Column('cluster', Integer, nullable=False, default=0))
-        self.versions = Table('versions', metadata, *columns, mysql_engine='InnoDB', mysql_charset='utf8')
-        Index('idx_versions_node_mtime', self.versions.c.node,
-              self.versions.c.mtime)
+        self.versions = Table('versions', metadata, *columns, mysql_engine='InnoDB')
+        Index('idx_versions_node_mtime', self.versions.c.node, self.versions.c.mtime)
+        Index('idx_versions_node_uuid', self.versions.c.uuid)
         
         #create attributes table
         columns = []
@@ -157,12 +172,23 @@ class Node(DBWorker):
                                          ondelete='CASCADE',
                                          onupdate='CASCADE'),
                               primary_key=True))
+        columns.append(Column('domain', String(255), primary_key=True))
         columns.append(Column('key', String(255), primary_key=True))
         columns.append(Column('value', String(255)))
-        self.attributes = Table('attributes', metadata, *columns, mysql_engine='InnoDB', mysql_charset='utf8')
+        self.attributes = Table('attributes', metadata, *columns, mysql_engine='InnoDB')
         
         metadata.create_all(self.engine)
         
+        # the following code creates an index of specific length
+        # this can be accompliced in sqlalchemy >= 0.7.3
+        # providing mysql_length option during index creation
+        insp = Inspector.from_engine(self.engine)
+        indexes = [elem['name'] for elem in insp.get_indexes('nodes')]
+        if 'idx_nodes_path' not in indexes:
+            explicit_length = '(%s)' %path_length if self.engine.name == 'mysql' else ''
+            s = text('CREATE UNIQUE INDEX idx_nodes_path ON nodes (path%s)' %explicit_length)
+            self.conn.execute(s).close()
+        
         s = self.nodes.select().where(and_(self.nodes.c.node == ROOTNODE,
                                            self.nodes.c.parent == ROOTNODE))
         rp = self.conn.execute(s)
@@ -188,7 +214,8 @@ class Node(DBWorker):
            Return None if the path is not found.
         """
         
-        s = select([self.nodes.c.node], self.nodes.c.path == path)
+        # Use LIKE for comparison to avoid MySQL problems with trailing spaces.
+        s = select([self.nodes.c.node], self.nodes.c.path.like(self.escape_like(path), escape='\\'))
         r = self.conn.execute(s)
         row = r.fetchone()
         r.close()
@@ -211,7 +238,7 @@ class Node(DBWorker):
     def node_get_versions(self, node, keys=(), propnames=_propnames):
         """Return the properties of all versions at node.
            If keys is empty, return all properties in the order
-           (serial, node, size, source, mtime, muser, cluster).
+           (serial, node, hash, size, source, mtime, muser, uuid, cluster).
         """
         
         s = select([self.versions.c.serial,
@@ -221,6 +248,7 @@ class Node(DBWorker):
                     self.versions.c.source,
                     self.versions.c.mtime,
                     self.versions.c.muser,
+                    self.versions.c.uuid,
                     self.versions.c.cluster], self.versions.c.node == node)
         s = s.order_by(self.versions.c.serial)
         r = self.conn.execute(s)
@@ -248,7 +276,7 @@ class Node(DBWorker):
     def node_purge_children(self, parent, before=inf, cluster=0):
         """Delete all versions with the specified
            parent and cluster, and return
-           the serials of versions deleted.
+           the hashes of versions deleted.
            Clears out nodes with no remaining versions.
         """
         #update statistics
@@ -271,10 +299,10 @@ class Node(DBWorker):
         self.statistics_update(parent, -nr, size, mtime, cluster)
         self.statistics_update_ancestors(parent, -nr, size, mtime, cluster)
         
-        s = select([self.versions.c.serial])
+        s = select([self.versions.c.hash])
         s = s.where(where_clause)
         r = self.conn.execute(s)
-        serials = [row[SERIAL] for row in r.fetchall()]
+        hashes = [row[0] for row in r.fetchall()]
         r.close()
         
         #delete versions
@@ -293,12 +321,12 @@ class Node(DBWorker):
         s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
         self.conn.execute(s).close()
         
-        return serials
+        return hashes
     
     def node_purge(self, node, before=inf, cluster=0):
         """Delete all versions with the specified
            node and cluster, and return
-           the serials of versions deleted.
+           the hashes of versions deleted.
            Clears out the node if it has no remaining versions.
         """
         
@@ -319,10 +347,10 @@ class Node(DBWorker):
         mtime = time()
         self.statistics_update_ancestors(node, -nr, -size, mtime, cluster)
         
-        s = select([self.versions.c.serial])
+        s = select([self.versions.c.hash])
         s = s.where(where_clause)
         r = self.conn.execute(s)
-        serials = [r[SERIAL] for r in r.fetchall()]
+        hashes = [r[0] for r in r.fetchall()]
         r.close()
         
         #delete versions
@@ -341,7 +369,7 @@ class Node(DBWorker):
         s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
         self.conn.execute(s).close()
         
-        return serials
+        return hashes
     
     def node_remove(self, node):
         """Remove the node specified.
@@ -366,6 +394,28 @@ class Node(DBWorker):
         self.conn.execute(s).close()
         return True
     
+    def policy_get(self, node):
+        s = select([self.policies.c.key, self.policies.c.value],
+            self.policies.c.node==node)
+        r = self.conn.execute(s)
+        d = dict(r.fetchall())
+        r.close()
+        return d
+    
+    def policy_set(self, node, policy):
+        #insert or replace
+        for k, v in policy.iteritems():
+            s = self.policies.update().where(and_(self.policies.c.node == node,
+                                                  self.policies.c.key == k))
+            s = s.values(value = v)
+            rp = self.conn.execute(s)
+            rp.close()
+            if rp.rowcount == 0:
+                s = self.policies.insert()
+                values = {'node':node, 'key':k, 'value':v}
+                r = self.conn.execute(s, values)
+                r.close()
+    
     def statistics_get(self, node, cluster=0):
         """Return population, total size and last mtime
            for all versions under node that belong to the cluster.
@@ -387,7 +437,6 @@ class Node(DBWorker):
            size of objects and mtime in the node's namespace.
            May be zero or positive or negative numbers.
         """
-        
         s = select([self.statistics.c.population, self.statistics.c.size],
             and_(self.statistics.c.node == node,
                  self.statistics.c.cluster == cluster))
@@ -451,6 +500,7 @@ class Node(DBWorker):
                     self.versions.c.source,
                     self.versions.c.mtime,
                     self.versions.c.muser,
+                    self.versions.c.uuid,
                     self.versions.c.cluster])
         filtered = select([func.max(self.versions.c.serial)],
                             self.versions.c.node == node)
@@ -496,7 +546,7 @@ class Node(DBWorker):
             self.versions.c.node == v.c.node)
         if before != inf:
             c1 = c1.where(self.versions.c.mtime < before)
-        c2 = select([self.nodes.c.node], self.nodes.c.path.like(path + '%'))
+        c2 = select([self.nodes.c.node], self.nodes.c.path.like(self.escape_like(path) + '%', escape='\\'))
         s = s.where(and_(v.c.serial == c1,
                          v.c.cluster != except_cluster,
                          v.c.node.in_(c2)))
@@ -509,14 +559,14 @@ class Node(DBWorker):
         mtime = max(mtime, r[2])
         return (count, size, mtime)
     
-    def version_create(self, node, hash, size, source, muser, cluster=0):
+    def version_create(self, node, hash, size, source, muser, uuid, cluster=0):
         """Create a new version from the given properties.
            Return the (serial, mtime) of the new version.
         """
         
         mtime = time()
         s = self.versions.insert().values(node=node, hash=hash, size=size, source=source,
-                                          mtime=mtime, muser=muser, cluster=cluster)
+                                          mtime=mtime, muser=muser, uuid=uuid, cluster=cluster)
         serial = self.conn.execute(s).inserted_primary_key[0]
         self.statistics_update_ancestors(node, 1, size, mtime, cluster)
         return serial, mtime
@@ -524,13 +574,14 @@ class Node(DBWorker):
     def version_lookup(self, node, before=inf, cluster=0):
         """Lookup the current version of the given node.
            Return a list with its properties:
-           (serial, node, hash, size, source, mtime, muser, cluster)
+           (serial, node, hash, size, source, mtime, muser, uuid, cluster)
            or None if the current version is not found in the given cluster.
         """
         
         v = self.versions.alias('v')
-        s = select([v.c.serial, v.c.node, v.c.hash, v.c.size,
-                    v.c.source, v.c.mtime, v.c.muser, v.c.cluster])
+        s = select([v.c.serial, v.c.node, v.c.hash,
+                    v.c.size, v.c.source, v.c.mtime,
+                    v.c.muser, v.c.uuid, v.c.cluster])
         c = select([func.max(self.versions.c.serial)],
             self.versions.c.node == node)
         if before != inf:
@@ -548,12 +599,13 @@ class Node(DBWorker):
         """Return a sequence of values for the properties of
            the version specified by serial and the keys, in the order given.
            If keys is empty, return all properties in the order
-           (serial, node, hash, size, source, mtime, muser, cluster).
+           (serial, node, hash, size, source, mtime, muser, uuid, cluster).
         """
         
         v = self.versions.alias()
-        s = select([v.c.serial, v.c.node, v.c.hash, v.c.size,
-                    v.c.source, v.c.mtime, v.c.muser, v.c.cluster], v.c.serial == serial)
+        s = select([v.c.serial, v.c.node, v.c.hash,
+                    v.c.size, v.c.source, v.c.mtime,
+                    v.c.muser, v.c.uuid, v.c.cluster], v.c.serial == serial)
         rp = self.conn.execute(s)
         r = rp.fetchone()
         rp.close()
@@ -588,10 +640,11 @@ class Node(DBWorker):
     def version_remove(self, serial):
         """Remove the serial specified."""
         
-        props = self.node_get_properties(serial)
+        props = self.version_get_properties(serial)
         if not props:
             return
         node = props[NODE]
+        hash = props[HASH]
         size = props[SIZE]
         cluster = props[CLUSTER]
         
@@ -600,9 +653,9 @@ class Node(DBWorker):
         
         s = self.versions.delete().where(self.versions.c.serial == serial)
         self.conn.execute(s).close()
-        return True
+        return hash
     
-    def attribute_get(self, serial, keys=()):
+    def attribute_get(self, serial, domain, keys=()):
         """Return a list of (key, value) pairs of the version specified by serial.
            If keys is empty, return all attributes.
            Othwerise, return only those specified.
@@ -612,17 +665,19 @@ class Node(DBWorker):
             attrs = self.attributes.alias()
             s = select([attrs.c.key, attrs.c.value])
             s = s.where(and_(attrs.c.key.in_(keys),
-                             attrs.c.serial == serial))
+                             attrs.c.serial == serial,
+                             attrs.c.domain == domain))
         else:
             attrs = self.attributes.alias()
             s = select([attrs.c.key, attrs.c.value])
-            s = s.where(attrs.c.serial == serial)
+            s = s.where(and_(attrs.c.serial == serial,
+                             attrs.c.domain == domain))
         r = self.conn.execute(s)
         l = r.fetchall()
         r.close()
         return l
     
-    def attribute_set(self, serial, items):
+    def attribute_set(self, serial, domain, items):
         """Set the attributes of the version specified by serial.
            Receive attributes as an iterable of (key, value) pairs.
         """
@@ -631,16 +686,17 @@ class Node(DBWorker):
         for k, v in items:
             s = self.attributes.update()
             s = s.where(and_(self.attributes.c.serial == serial,
+                             self.attributes.c.domain == domain,
                              self.attributes.c.key == k))
             s = s.values(value = v)
             rp = self.conn.execute(s)
             rp.close()
             if rp.rowcount == 0:
                 s = self.attributes.insert()
-                s = s.values(serial=serial, key=k, value=v)
+                s = s.values(serial=serial, domain=domain, key=k, value=v)
                 self.conn.execute(s).close()
     
-    def attribute_del(self, serial, keys=()):
+    def attribute_del(self, serial, domain, keys=()):
         """Delete attributes of the version specified by serial.
            If keys is empty, delete all attributes.
            Otherwise delete those specified.
@@ -651,32 +707,35 @@ class Node(DBWorker):
             for key in keys:
                 s = self.attributes.delete()
                 s = s.where(and_(self.attributes.c.serial == serial,
+                                 self.attributes.c.domain == domain,
                                  self.attributes.c.key == key))
                 self.conn.execute(s).close()
         else:
             s = self.attributes.delete()
-            s = s.where(self.attributes.c.serial == serial)
+            s = s.where(and_(self.attributes.c.serial == serial,
+                             self.attributes.c.domain == domain))
             self.conn.execute(s).close()
     
     def attribute_copy(self, source, dest):
-        s = select([dest, self.attributes.c.key, self.attributes.c.value],
+        s = select([dest, self.attributes.c.domain, self.attributes.c.key, self.attributes.c.value],
             self.attributes.c.serial == source)
         rp = self.conn.execute(s)
         attributes = rp.fetchall()
         rp.close()
-        for dest, k, v in attributes:
+        for dest, domain, k, v in attributes:
             #insert or replace
             s = self.attributes.update().where(and_(
                 self.attributes.c.serial == dest,
+                self.attributes.c.domain == domain,
                 self.attributes.c.key == k))
             rp = self.conn.execute(s, value=v)
             rp.close()
             if rp.rowcount == 0:
                 s = self.attributes.insert()
-                values = {'serial':dest, 'key':k, 'value':v}
+                values = {'serial':dest, 'domain':domain, 'key':k, 'value':v}
                 self.conn.execute(s, values).close()
     
-    def latest_attribute_keys(self, parent, before=inf, except_cluster=0, pathq=[]):
+    def latest_attribute_keys(self, parent, domain, before=inf, except_cluster=0, pathq=[]):
         """Return a list with all keys pairs defined
            for all latest versions under parent that
            do not belong to the cluster.
@@ -695,10 +754,11 @@ class Node(DBWorker):
         s = s.where(v.c.node.in_(select([self.nodes.c.node],
             self.nodes.c.parent == parent)))
         s = s.where(a.c.serial == v.c.serial)
+        s = s.where(a.c.domain == domain)
         s = s.where(n.c.node == v.c.node)
         conj = []
         for x in pathq:
-            conj.append(n.c.path.like(x + '%'))
+            conj.append(n.c.path.like(self.escape_like(x) + '%', escape='\\'))
         if conj:
             s = s.where(or_(*conj))
         rp = self.conn.execute(s)
@@ -708,7 +768,7 @@ class Node(DBWorker):
     
     def latest_version_list(self, parent, prefix='', delimiter=None,
                             start='', limit=10000, before=inf,
-                            except_cluster=0, pathq=[], filterq=None):
+                            except_cluster=0, pathq=[], domain=None, filterq=[], sizeq=None):
         """Return a (list of (path, serial) tuples, list of common prefixes)
            for the current versions of the paths with the given parent,
            matching the following criteria.
@@ -743,6 +803,8 @@ class Node(DBWorker):
                    key ?op value
                        the attribute with this key satisfies the value
                        where ?op is one of ==, != <=, >=, <, >.
+                
+                h. the size is in the range set by sizeq
            
            The list of common prefixes includes the prefixes
            matching up to the first delimiter after prefix,
@@ -759,7 +821,6 @@ class Node(DBWorker):
             start = strprevling(prefix)
         nextling = strnextling(prefix)
         
-        a = self.attributes.alias('a')
         v = self.versions.alias('v')
         n = self.nodes.alias('n')
         s = select([n.c.path, v.c.serial]).distinct()
@@ -770,20 +831,43 @@ class Node(DBWorker):
         s = s.where(v.c.cluster != except_cluster)
         s = s.where(v.c.node.in_(select([self.nodes.c.node],
             self.nodes.c.parent == parent)))
-        if filterq:
-            s = s.where(a.c.serial == v.c.serial)
         
         s = s.where(n.c.node == v.c.node)
         s = s.where(and_(n.c.path > bindparam('start'), n.c.path < nextling))
         conj = []
         for x in pathq:
-            conj.append(n.c.path.like(x + '%'))
-        
+            conj.append(n.c.path.like(self.escape_like(x) + '%', escape='\\'))
         if conj:
             s = s.where(or_(*conj))
         
-        if filterq:
-            s = s.where(a.c.key.in_(filterq.split(',')))
+        if sizeq and len(sizeq) == 2:
+            if sizeq[0]:
+                s = s.where(v.c.size >= sizeq[0])
+            if sizeq[1]:
+                s = s.where(v.c.size < sizeq[1])
+        
+        if domain and filterq:
+            a = self.attributes.alias('a')
+            included, excluded, opers = parse_filters(filterq)
+            if included:
+                subs = select([1])
+                subs = subs.where(a.c.serial == v.c.serial).correlate(v)
+                subs = subs.where(a.c.domain == domain)
+                subs = subs.where(or_(*[a.c.key.op('=')(x) for x in included]))
+                s = s.where(exists(subs))
+            if excluded:
+                subs = select([1])
+                subs = subs.where(a.c.serial == v.c.serial).correlate(v)
+                subs = subs.where(a.c.domain == domain)
+                subs = subs.where(or_(*[a.c.key.op('=')(x) for x in excluded]))
+                s = s.where(not_(exists(subs)))
+            if opers:
+                for k, o, val in opers:
+                    subs = select([1])
+                    subs = subs.where(a.c.serial == v.c.serial).correlate(v)
+                    subs = subs.where(a.c.domain == domain)
+                    subs = subs.where(and_(a.c.key.op('=')(k), a.c.value.op(o)(val)))
+                    s = s.where(exists(subs))
         
         s = s.order_by(n.c.path)
         
@@ -830,3 +914,18 @@ class Node(DBWorker):
         rp.close()
         
         return matches, prefixes
+    
+    def latest_uuid(self, uuid):
+        """Return a (path, serial) tuple, for the latest version of the given uuid."""
+        
+        v = self.versions.alias('v')
+        n = self.nodes.alias('n')
+        s = select([n.c.path, v.c.serial])
+        filtered = select([func.max(self.versions.c.serial)])
+        s = s.where(v.c.serial == filtered.where(self.versions.c.uuid == uuid))
+        s = s.where(n.c.node == v.c.node)
+        
+        r = self.conn.execute(s)
+        l = r.fetchone()
+        r.close()
+        return l