backend components in SQLAlchemy: Progress III
authorSofia Papagiannaki <papagian@gmail.com>
Wed, 31 Aug 2011 09:13:08 +0000 (12:13 +0300)
committerSofia Papagiannaki <papagian@gmail.com>
Wed, 31 Aug 2011 09:13:08 +0000 (12:13 +0300)
pithos/backends/lib_alchemy/node.py

index 359d237..6a3d949 100644 (file)
 from time import time
 from sqlalchemy import Table, Integer, Column, String, MetaData, ForeignKey
 from sqlalchemy.schema import Index, Sequence
-from sqlalchemy.sql import func, and_, null
-from sqlalchemy.sql import select
+from sqlalchemy.sql import func, and_, or_, null, select, bindparam
 
 from dbworker import DBWorker
 
-
 ROOTNODE  = 1
 
 ( SERIAL, NODE, SIZE, SOURCE, MTIME, MUSER, CLUSTER ) = range(7)
@@ -241,12 +239,13 @@ class Node(DBWorker):
            the serials of versions deleted.
            Clears out nodes with no remaining versions.
         """
-        
-        scalar = select([self.nodes.c.node],
-            self.nodes.c.parent == parent).as_scalar()
-        where_clause = and_(self.versions.c.node.in_(scalar),
+        #update statistics
+        #TODO handle before=inf
+        c1 = select([self.nodes.c.node],
+            self.nodes.c.parent == parent)
+        where_clause = and_(self.versions.c.node.in_(c1),
                             self.versions.c.cluster == cluster,
-                            "versions.mtime <= %f" %before)
+                            self.versions.c.mtime <= before)
         s = select([func.count(self.versions.c.serial),
                     func.sum(self.versions.c.size)])
         s = s.where(where_clause)
@@ -257,7 +256,6 @@ class Node(DBWorker):
             return ()
         nr, size = row[0], -row[1] if row[1] else 0
         mtime = time()
-        print '#', parent, -nr, size, mtime, cluster
         self.statistics_update(parent, -nr, size, mtime, cluster)
         self.statistics_update_ancestors(parent, -nr, size, mtime, cluster)
         
@@ -267,21 +265,22 @@ class Node(DBWorker):
         serials = [row[SERIAL] for row in r.fetchall()]
         r.close()
         
-        #delete versiosn
+        #delete versions
         s = self.versions.delete().where(where_clause)
         r = self.conn.execute(s)
         r.close()
         
         #delete nodes
-        a = self.nodes.alias()
-        no_versions = select([func.count(self.versions.c.serial)],
-            self.versions.c.node == a.c.node).as_scalar() == 0
-        n = select([self.nodes.c.node],
-            and_(no_versions, self.nodes.c.parent == parent))
-        s = s.where(self.nodes.c.node.in_(n))
-        s = self.nodes.delete().where(self.nodes.c.node == s)
-        print '#', s
+        s = select([self.nodes.c.node],
+            and_(self.nodes.c.parent == parent,
+                 select([func.count(self.versions.c.serial)],
+                    self.versions.c.node == self.nodes.c.node).as_scalar() == 0))
+        r = self.conn.execute(s)
+        nodes = r.fetchall()
+        r.close()
+        s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
         self.conn.execute(s).close()
+        
         return serials
     
     def node_purge(self, node, before=inf, cluster=0):
@@ -291,37 +290,43 @@ class Node(DBWorker):
            Clears out the node if it has no remaining versions.
         """
         
-        execute = self.execute
-        q = ("select count(serial), sum(size) from versions "
-             "where node = ? "
-             "and cluster = ? "
-             "and mtime <= ?")
-        args = (node, cluster, before)
-        execute(q, args)
-        nr, size = self.fetchone()
+        #update statistics
+        s = select([func.count(self.versions.c.serial),
+                    func.sum(self.versions.c.size)])
+        where_clause = and_(self.versions.c.node == node,
+                         self.versions.c.cluster == cluster,
+                         self.versions.c.mtime <= before)
+        s = s.where(where_clause)
+        r = self.conn.execute(s)
+        row = r.fetchone()
+        nr, size = row[0], row[1]
+        r.close()
         if not nr:
             return ()
         mtime = time()
         self.statistics_update_ancestors(node, -nr, -size, mtime, cluster)
         
-        q = ("select serial from versions "
-             "where node = ? "
-             "and cluster = ? "
-             "and mtime <= ?")
-        execute(q, args)
-        serials = [r[SERIAL] for r in self.fetchall()]
-        q = ("delete from versions "
-             "where node = ? "
-             "and cluster = ? "
-             "and mtime <= ?")
-        execute(q, args)
-        q = ("delete from nodes "
-             "where node in (select node from nodes n "
-                            "where (select count(serial) "
-                                   "from versions "
-                                   "where node = n.node) = 0 "
-                            "and node = ?)")
-        execute(q, (node,))
+        s = select([self.versions.c.serial])
+        s = s.where(where_clause)
+        r = self.conn.execute(s)
+        serials = [r[SERIAL] for r in r.fetchall()]
+        
+        #delete versions
+        s = self.versions.delete().where(where_clause)
+        r = self.conn.execute(s)
+        r.close()
+        
+        #delete nodes
+        s = select([self.nodes.c.node],
+            and_(self.nodes.c.node == node,
+                 select([func.count(self.versions.c.serial)],
+                    self.versions.c.node == self.nodes.c.node).as_scalar() == 0))
+        r = self.conn.execute(s)
+        nodes = r.fetchall()
+        r.close()
+        s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
+        self.conn.execute(s).close()
+        
         return serials
     
     def node_remove(self, node):
@@ -333,16 +338,18 @@ class Node(DBWorker):
             return False
         
         mtime = time()
-        q = ("select count(serial), sum(size), cluster "
-             "from versions "
-             "where node = ? "
-             "group by cluster")
-        self.execute(q, (node,))
-        for population, size, cluster in self.fetchall():
+        s = select([func.count(self.versions.c.serial),
+                    func.sum(self.versions.c.size),
+                    self.versions.c.cluster])
+        s = s.where(self.versions.c.node == node)
+        s = s.group_by(self.versions.c.cluster)
+        r = self.conn.execute(s)
+        for population, size, cluster in r.fetchall():
             self.statistics_update_ancestors(node, -population, -size, mtime, cluster)
+        r.close()
         
-        q = "delete from nodes where node = ?"
-        self.execute(q, (node,))
+        s = self.nodes.delete().where(self.nodes.c.node == node)
+        self.conn.execute(s).close()
         return True
     
     def statistics_get(self, node, cluster=0):
@@ -350,10 +357,15 @@ class Node(DBWorker):
            for all versions under node that belong to the cluster.
         """
         
-        q = ("select population, size, mtime from statistics "
-             "where node = ? and cluster = ?")
-        self.execute(q, (node, cluster))
-        return self.fetchone()
+        s = select([self.statistics.c.population,
+                    self.statistics.c.size,
+                    self.statistics.c.mtime])
+        s = s.where(and_(self.statistics.c.node == node,
+                         self.statistics.c.cluster == cluster))
+        r = self.conn.execute(s)
+        row = r.fetchone()
+        r.close()
+        return row
     
     def statistics_update(self, node, population, size, mtime, cluster=0):
         """Update the statistics of the given node.
@@ -365,9 +377,9 @@ class Node(DBWorker):
         s = select([self.statistics.c.population, self.statistics.c.size],
             and_(self.statistics.c.node == node,
                  self.statistics.c.cluster == cluster))
-        res = self.conn.execute(s)
-        r = res.fetchone()
-        res.close()
+        rp = self.conn.execute(s)
+        r = rp.fetchone()
+        rp.close()
         if not r:
             prepopulation, presize = (0, 0)
         else:
@@ -375,9 +387,8 @@ class Node(DBWorker):
         population += prepopulation
         size += presize
         
-        self.statistics.insert().values(node=node, population=population,
-                                        size=size, mtime=mtime, cluster=cluster)
-        self.conn.execute(s).close()
+        ins = self.statistics.insert().values(node, population, size, mtime, cluster)
+        self.conn.execute(ins).close()
     
     def statistics_update_ancestors(self, node, population, size, mtime, cluster=0):
         """Update the statistics of the given node's parent.
@@ -412,31 +423,40 @@ class Node(DBWorker):
         parent, path = props
         
         # The latest version.
-        q = ("select serial, node, size, source, mtime, muser, cluster "
-             "from versions "
-             "where serial = (select max(serial) "
-                             "from versions "
-                             "where node = ? and mtime < ?) "
-             "and cluster != ?")
-        execute(q, (node, before, except_cluster))
-        props = fetchone()
-        if props is None:
+        s = select([self.versions.c.serial,
+                    self.versions.c.node,
+                    self.versions.c.size,
+                    self.versions.c.mtime,
+                    self.versions.c.muser,
+                    self.versions.c.cluster])
+        s = s.where(and_(self.versions.c.cluster != except_cluster,
+                         self.versions.c.serial == select(
+                            [func.max(self.versions.c.serial)],
+                            and_(self.versions.c.node == node,
+                            self.versions.c.mtime < before))))
+        r = self.conn.execute(s)
+        props = r.fetchone()
+        r.close()
+        if not props:
             return None
         mtime = props[MTIME]
         
         # First level, just under node (get population).
-        q = ("select count(serial), sum(size), max(mtime) "
-             "from versions v "
-             "where serial = (select max(serial) "
-                             "from versions "
-                             "where node = v.node and mtime < ?) "
-             "and cluster != ? "
-             "and node in (select node "
-                          "from nodes "
-                          "where parent = ?)")
-        execute(q, (before, except_cluster, node))
-        r = fetchone()
-        if r is None:
+        v = self.versions.alias('v')
+        s = select([func.count(v.c.serial),
+                    func.sum(v.c.size),
+                    func.max(v.c.mtime)])
+        c1 = select([func.max(self.versions.c.serial)],
+            and_(self.versions.c.node == v.c.node,
+                 self.versions.c.mtime < before))
+        c2 = select([self.nodes.c.node], self.nodes.c.parent == node)
+        s = s.where(and_(v.c.serial == c1,
+                         v.c.cluster != except_cluster,
+                         v.c.node.in_(c2)))
+        rp = self.conn.execute(s)
+        r = rp.fetchone()
+        rp.close()
+        if not r:
             return None
         count = r[0]
         mtime = max(mtime, r[2])
@@ -445,18 +465,20 @@ class Node(DBWorker):
         
         # All children (get size and mtime).
         # XXX: This is why the full path is stored.
-        q = ("select count(serial), sum(size), max(mtime) "
-             "from versions v "
-             "where serial = (select max(serial) "
-                             "from versions "
-                             "where node = v.node and mtime < ?) "
-             "and cluster != ? "
-             "and node in (select node "
-                          "from nodes "
-                          "where path like ?)")
-        execute(q, (before, except_cluster, path + '%'))
-        r = fetchone()
-        if r is None:
+        s = select([func.count(v.c.serial),
+                    func.sum(v.c.size),
+                    func.max(v.c.mtime)])
+        c1 = select([func.max(self.versions.c.serial)],
+            and_(self.versions.c.node == v.c.node,
+                 self.versions.c.mtime < before))
+        c2 = select([self.nodes.c.node], self.nodes.c.path.like(path + '%'))
+        s = s.where(and_(v.c.serial == c1,
+                         v.c.cluster != except_cluster,
+                         v.c.node.in_(c2)))
+        rp = self.conn.execute(s)
+        r = rp.fetchone()
+        rp.close()
+        if not r:
             return None
         size = r[1] - props[SIZE]
         mtime = max(mtime, r[2])
@@ -467,11 +489,12 @@ class Node(DBWorker):
            Return the (serial, mtime) of the new version.
         """
         
-        q = ("insert into versions (node, size, source, mtime, muser, cluster) "
-             "values (?, ?, ?, ?, ?, ?)")
         mtime = time()
         props = (node, size, source, mtime, muser, cluster)
-        serial = self.execute(q, props).lastrowid
+        props = locals()
+        props.pop('self')
+        s = self.versions.insert().values(**props)
+        serial = self.conn.execute(s).inserted_primary_key[0]
         self.statistics_update_ancestors(node, 1, size, mtime, cluster)
         return serial, mtime
     
@@ -482,15 +505,18 @@ class Node(DBWorker):
            or None if the current version is not found in the given cluster.
         """
         
-        q = ("select serial, node, size, source, mtime, muser, cluster "
-             "from versions "
-             "where serial = (select max(serial) "
-                             "from versions "
-                             "where node = ? and mtime < ?) "
-             "and cluster = ?")
-        self.execute(q, (node, before, cluster))
-        props = self.fetchone()
-        if props is not None:
+        v = self.versions.alias('v')
+        s = select([v.c.serial, v.c.node, v.c.size, v.c.source, v.c.mtime,
+                    v.c.muser, v.c.cluster])
+        c = select([func.max(self.versions.c.serial)],
+            and_(self.versions.c.node == node,
+                 self.versions.c.mtime < before))
+        s = s.where(and_(v.c.serial == c,
+                         v.c.cluster == cluster))
+        r = self.conn.execute(s)
+        props = r.fetchone()
+        r.close()
+        if not props:
             return props
         return None
     
@@ -501,11 +527,12 @@ class Node(DBWorker):
            (serial, node, size, source, mtime, muser, cluster).
         """
         
-        q = ("select serial, node, size, source, mtime, muser, cluster "
-             "from versions "
-             "where serial = ?")
-        self.execute(q, (serial,))
-        r = self.fetchone()
+        v = self.versions.alias()
+        s = select([v.c.serial, v.c.node, v.c.size, v.c.source, v.c.mtime,
+                   v.c.muser, v.c.cluster], v.c.serial == serial)
+        rp = self.conn.execute(s)
+        r = rp.fetchone()
+        rp.close()
         if r is None:
             return r
         
@@ -529,8 +556,10 @@ class Node(DBWorker):
         self.statistics_update_ancestors(node, -1, -size, mtime, oldcluster)
         self.statistics_update_ancestors(node, 1, size, mtime, cluster)
         
-        q = "update versions set cluster = ? where serial = ?"
-        self.execute(q, (cluster, serial))
+        s = self.versions.update()
+        s = s.where(self.versions.c.serial == serial)
+        s = s.values(cluster = cluster)
+        self.conn.execute(s).close()
     
     def version_remove(self, serial):
         """Remove the serial specified."""
@@ -545,8 +574,8 @@ class Node(DBWorker):
         mtime = time()
         self.statistics_update_ancestors(node, -1, -size, mtime, cluster)
         
-        q = "delete from versions where serial = ?"
-        self.execute(q, (serial,))
+        s = self.versions.delete().where(self.versions.c.serial == serial)
+        self.conn.execute(s).close()
         return True
     
     def attribute_get(self, serial, keys=()):
@@ -557,23 +586,26 @@ class Node(DBWorker):
         
         execute = self.execute
         if keys:
-            marks = ','.join('?' for k in keys)
-            q = ("select key, value from attributes "
-                 "where key in (%s) and serial = ?" % (marks,))
-            execute(q, keys + (serial,))
+            attrs = self.attributes.alias()
+            s = select([attrs.c.key, attrs.c.value])
+            s = s.where(and_(attrs.c.key.in_(keys),
+                             attrs.c.serial == serial))
         else:
-            q = "select key, value from attributes where serial = ?"
-            execute(q, (serial,))
-        return self.fetchall()
+            attrs = self.attributes.alias()
+            s = select([attrs.c.key, attrs.c.value])
+            s = s.where(attrs.c.serial == serial)
+        r = self.conn.execute(s)
+        l = r.fetchall()
+        r.close()
+        return l
     
     def attribute_set(self, serial, items):
         """Set the attributes of the version specified by serial.
            Receive attributes as an iterable of (key, value) pairs.
         """
         
-        q = ("insert or replace into attributes (serial, key, value) "
-             "values (?, ?, ?)")
-        self.executemany(q, ((serial, k, v) for k, v in items))
+        values = [{'serial':serial, 'key':k, 'value':v} for k, v in items]
+        self.conn.execute(self.attributes.insert(), values).close()
     
     def attribute_del(self, serial, keys=()):
         """Delete attributes of the version specified by serial.
@@ -582,39 +614,37 @@ class Node(DBWorker):
         """
         
         if keys:
-            q = "delete from attributes where serial = ? and key = ?"
-            self.executemany(q, ((serial, key) for key in keys))
+            #TODO more efficient way to do this?
+            for key in keys:
+                s = self.attributes.delete()
+                s = s.where(and_(self.attributes.c.serial == serial,
+                                 self.attributes.c.key == key))
+                self.conn.execute(s).close()
         else:
-            q = "delete from attributes where serial = ?"
-            self.execute(q, (serial,))
+            s = self.attributes.delete()
+            s = s.where(self.attributes.c.serial == serial)
+            self.conn.execute(s).close()
     
     def attribute_copy(self, source, dest):
-        q = ("insert or replace into attributes "
-             "select ?, key, value from attributes "
-             "where serial = ?")
-        self.execute(q, (dest, source))
-    
-    def _construct_filters(self, filterq):
-        if not filterq:
-            return None, None
-        
-        args = filterq.split(',')
-        subq = " and a.key in ("
-        subq += ','.join(('?' for x in args))
-        subq += ")"
-        
-        return subq, args
-    
-    def _construct_paths(self, pathq):
-        if not pathq:
-            return None, None
-        
-        subq = " and ("
-        subq += ' or '.join(('n.path like ?' for x in pathq))
-        subq += ")"
-        args = tuple([x + '%' for x in pathq])
-        
-        return subq, args
+        from sqlalchemy.ext.compiler import compiles
+        from sqlalchemy.sql.expression import UpdateBase
+                
+        class InsertFromSelect(UpdateBase):
+            def __init__(self, table, select):
+                self.table = table
+                self.select = select
+        
+        @compiles(InsertFromSelect)
+        def visit_insert_from_select(element, compiler, **kw):
+            return "INSERT INTO %s (%s)" % (
+                compiler.process(element.table, asfrom=True),
+                compiler.process(element.select)
+            )
+        
+        s = select([dest, self.attributes.c.key, self.attributes.c.value],
+            self.attributes.c.serial == source)
+        ins = InsertFromSelect(self.attributes, s)
+        self.conn.execute(ins).close()
     
     def latest_attribute_keys(self, parent, before=inf, except_cluster=0, pathq=[]):
         """Return a list with all keys pairs defined
@@ -623,23 +653,26 @@ class Node(DBWorker):
         """
         
         # TODO: Use another table to store before=inf results.
-        q = ("select distinct a.key "
-             "from attributes a, versions v, nodes n "
-             "where v.serial = (select max(serial) "
-                              "from versions "
-                              "where node = v.node and mtime < ?) "
-             "and v.cluster != ? "
-             "and v.node in (select node "
-                           "from nodes "
-                           "where parent = ?) "
-             "and a.serial = v.serial "
-             "and n.node = v.node")
-        args = (before, except_cluster, parent)
-        subq, subargs = self._construct_paths(pathq)
-        if subq is not None:
-            q += subq
-            args += subargs
-        self.execute(q, args)
+        a = self.attributes.alias('a')
+        v = self.versions.alias('v')
+        n = self.nodes.alias('n')
+        s = select([a.c.key]).distinct()
+        s = s.where(v.c.serial == select([func.max(self.versions.c.serial)],
+                                          and_(self.versions.c.node == v.c.node,
+                                               self.versions.c.mtime < before)))
+        s = s.where(v.c.cluster != except_cluster)
+        s = s.where(v.c.node.in_(select([self.nodes.c.node],
+            self.nodes.c.parent == parent)))
+        s = s.where(a.c.serial == v.c.serial)
+        s = s.where(n.c.node == v.c.node)
+        conj = []
+        for x in pathq:
+            conj.append(n.c.path.like(x + '%'))
+        if conj:
+            s = s.where(or_(*conj))
+        rp = self.conn.execute(s)
+        r = rp.fetchall()
+        rp.close()
         return [r[0] for r in self.fetchall()]
     
     def latest_version_list(self, parent, prefix='', delimiter=None,
@@ -697,38 +730,40 @@ class Node(DBWorker):
             start = strprevling(prefix)
         nextling = strnextling(prefix)
         
-        q = ("select distinct n.path, v.serial "
-             "from attributes a, versions v, nodes n "
-             "where v.serial = (select max(serial) "
-                              "from versions "
-                              "where node = v.node and mtime < ?) "
-             "and v.cluster != ? "
-             "and v.node in (select node "
-                           "from nodes "
-                           "where parent = ?) "
-             "and a.serial = v.serial "
-             "and n.node = v.node "
-             "and n.path > ? and n.path < ?")
-        args = [before, except_cluster, parent, start, nextling]
-        
-        subq, subargs = self._construct_paths(pathq)
-        if subq is not None:
-            q += subq
-            args += subargs
-        subq, subargs = self._construct_filters(filterq)
-        if subq is not None:
-            q += subq
-            args += subargs
-        else:
-            q = q.replace("attributes a, ", "")
-            q = q.replace("and a.serial = v.serial ", "")
-        q += " order by n.path"
+        a = self.attributes.alias('a')
+        v = self.versions.alias('v')
+        n = self.nodes.alias('n')
+        s = select([n.c.path, v.c.serial]).distinct()
+        s = s.where(v.c.serial == select([func.max(self.versions.c.serial)],
+            and_(self.versions.c.node == v.c.node,
+                 self.versions.c.mtime < before)))
+        s = s.where(v.c.cluster != except_cluster)
+        s = s.where(v.c.node.in_(select([self.nodes.c.node],
+            self.nodes.c.parent == parent)))
+        if filterq:
+            s = s.where(a.c.serial == v.c.serial)
+        
+        s = s.where(n.c.node == v.c.node)
+        s = s.where(and_(n.c.path > bindparam('start'), n.c.path < nextling))
+        conj = []
+        for x in pathq:
+            print '#', x
+            conj.append(n.c.path.like(x + '%'))
+        
+        if conj:
+            s = s.where(or_(*conj))
+        
+        if filterq:
+            s = s.where(a.c.key.in_(filterq.split(',')))
+        
+        s = s.order_by(n.c.path)
         
         if not delimiter:
-            q += " limit ?"
-            args.append(limit)
-            execute(q, args)
-            return self.fetchall(), ()
+            s = s.limit(limit)
+            rp = self.conn.execute(s, start=start)
+            r = rp.fetchall()
+            rp.close()
+            return r, ()
         
         pfz = len(prefix)
         dz = len(delimiter)
@@ -739,9 +774,9 @@ class Node(DBWorker):
         matches = []
         mappend = matches.append
         
-        execute(q, args)
+        rp = self.conn.execute(s, start=start)
         while True:
-            props = fetchone()
+            props = rp.fetchone()
             if props is None:
                 break
             path, serial = props
@@ -762,7 +797,6 @@ class Node(DBWorker):
             if count >= limit: 
                 break
             
-            args[3] = strnextling(pf) # New start.
-            execute(q, args)
+            rp = self.conn.execute(s, start=strnextling(pf)) # New start.
         
         return matches, prefixes