eliminate nested sql aggregations
authorSofia Papagiannaki <papagian@gmail.com>
Tue, 17 Jul 2012 01:22:49 +0000 (04:22 +0300)
committerSofia Papagiannaki <papagian@gmail.com>
Tue, 17 Jul 2012 01:22:49 +0000 (04:22 +0300)
Refs: #2675

snf-pithos-backend/pithos/backends/lib/sqlalchemy/node.py
snf-pithos-backend/pithos/backends/lib/sqlite/node.py

index 7a1f8b3..1ee4cca 100644 (file)
@@ -122,9 +122,11 @@ class Node(DBWorker):
                                          ondelete='CASCADE',
                                          onupdate='CASCADE'),
                               autoincrement=False))
+        columns.append(Column('latest_version', Integer))
         columns.append(Column('path', String(2048), default='', nullable=False))
         self.nodes = Table('nodes', metadata, *columns, mysql_engine='InnoDB')
         Index('idx_nodes_path', self.nodes.c.path, unique=True)
+        Index('idx_nodes_parent', self.nodes.c.parent)
         
         #create policy table
         columns=[]
@@ -170,6 +172,7 @@ class Node(DBWorker):
         self.versions = Table('versions', metadata, *columns, mysql_engine='InnoDB')
         Index('idx_versions_node_mtime', self.versions.c.node, self.versions.c.mtime)
         Index('idx_versions_node_uuid', self.versions.c.uuid)
+        Index('idx_versions_serial_cluster', self.versions.c.serial, self.versions.c.cluster)
         
         #create attributes table
         columns = []
@@ -514,10 +517,13 @@ class Node(DBWorker):
                     self.versions.c.uuid,
                     self.versions.c.checksum,
                     self.versions.c.cluster])
-        filtered = select([func.max(self.versions.c.serial)],
-                            self.versions.c.node == node)
         if before != inf:
+            filtered = select([func.max(self.versions.c.serial)],
+                            self.versions.c.node == node)
             filtered = filtered.where(self.versions.c.mtime < before)
+        else:
+            filtered = select([self.nodes.c.latest_version],
+                            self.versions.c.node == node)
         s = s.where(and_(self.versions.c.cluster != except_cluster,
                          self.versions.c.serial == filtered))
         r = self.conn.execute(s)
@@ -532,11 +538,15 @@ class Node(DBWorker):
         s = select([func.count(v.c.serial),
                     func.sum(v.c.size),
                     func.max(v.c.mtime)])
-        c1 = select([func.max(self.versions.c.serial)])
         if before != inf:
+            c1 = select([func.max(self.versions.c.serial)])
             c1 = c1.where(self.versions.c.mtime < before)
+            c1.where(self.versions.c.node == v.c.node)
+        else:
+            c1 = select([self.nodes.c.latest_version])
+            c1.where(self.nodes.c.node == v.c.node)
         c2 = select([self.nodes.c.node], self.nodes.c.parent == node)
-        s = s.where(and_(v.c.serial == c1.where(self.versions.c.node == v.c.node),
+        s = s.where(and_(v.c.serial == c1,
                          v.c.cluster != except_cluster,
                          v.c.node.in_(c2)))
         rp = self.conn.execute(s)
@@ -554,10 +564,13 @@ class Node(DBWorker):
         s = select([func.count(v.c.serial),
                     func.sum(v.c.size),
                     func.max(v.c.mtime)])
-        c1 = select([func.max(self.versions.c.serial)],
-            self.versions.c.node == v.c.node)
         if before != inf:
+            c1 = select([func.max(self.versions.c.serial)],
+                    self.versions.c.node == v.c.node)
             c1 = c1.where(self.versions.c.mtime < before)
+        else:
+            c1 = select([self.nodes.c.serial],
+                    self.nodes.c.node == v.c.node)
         c2 = select([self.nodes.c.node], self.nodes.c.path.like(self.escape_like(path) + '%', escape='\\'))
         s = s.where(and_(v.c.serial == c1,
                          v.c.cluster != except_cluster,
@@ -571,6 +584,11 @@ class Node(DBWorker):
         mtime = max(mtime, r[2])
         return (count, size, mtime)
     
+    def nodes_set_latest_version(self, node, serial):
+        s = self.nodes.update().where(self.nodes.c.node == node)
+        s = s.values(latest_version = serial)
+        self.conn.execute(s).close()
+    
     def version_create(self, node, hash, size, type, source, muser, uuid, checksum, cluster=0):
         """Create a new version from the given properties.
            Return the (serial, mtime) of the new version.
@@ -581,6 +599,9 @@ class Node(DBWorker):
                                           mtime=mtime, muser=muser, uuid=uuid, checksum=checksum, cluster=cluster)
         serial = self.conn.execute(s).inserted_primary_key[0]
         self.statistics_update_ancestors(node, 1, size, mtime, cluster)
+        
+        self.nodes_set_latest_version(node, serial)
+        
         return serial, mtime
     
     def version_lookup(self, node, before=inf, cluster=0, all_props=True):
@@ -598,10 +619,13 @@ class Node(DBWorker):
                         v.c.size, v.c.type, v.c.source,
                         v.c.mtime, v.c.muser, v.c.uuid,
                         v.c.checksum, v.c.cluster])
-        c = select([func.max(self.versions.c.serial)],
-            self.versions.c.node == node)
         if before != inf:
+            c = select([func.max(self.versions.c.serial)],
+                self.versions.c.node == node)
             c = c.where(self.versions.c.mtime < before)
+        else:
+            c = select([self.nodes.c.latest_version],
+                self.nodes.c.node == node)
         s = s.where(and_(v.c.serial == c,
                          v.c.cluster == cluster))
         r = self.conn.execute(s)
@@ -616,7 +640,8 @@ class Node(DBWorker):
            Return a list with their properties:
            (serial, node, hash, size, type, source, mtime, muser, uuid, checksum, cluster).
         """
-        
+        if not nodes:
+            return ()
         v = self.versions.alias('v')
         if not all_props:
             s = select([v.c.serial])
@@ -625,10 +650,14 @@ class Node(DBWorker):
                         v.c.size, v.c.type, v.c.source,
                         v.c.mtime, v.c.muser, v.c.uuid,
                         v.c.checksum, v.c.cluster])
-        c = select([func.max(self.versions.c.serial)],
-            self.versions.c.node.in_(nodes)).group_by(self.versions.c.node)
         if before != inf:
+            c = select([func.max(self.versions.c.serial)],
+                self.versions.c.node.in_(nodes))
             c = c.where(self.versions.c.mtime < before)
+            c = c.group_by(self.versions.c.node)
+        else:
+            c = select([self.nodes.c.latest_version],
+                self.nodes.c.node.in_(nodes))
         s = s.where(and_(v.c.serial.in_(c),
                          v.c.cluster == cluster))
         s = s.order_by(v.c.node)
@@ -706,6 +735,11 @@ class Node(DBWorker):
         
         s = self.versions.delete().where(self.versions.c.serial == serial)
         self.conn.execute(s).close()
+        
+        props = self.version_lookup(node, cluster=cluster, all_props=False)
+        if props:
+            self.nodes_set_latest_version(v.node, serial)
+        
         return hash, size
     
     def attribute_get(self, serial, domain, keys=()):
@@ -799,10 +833,14 @@ class Node(DBWorker):
         v = self.versions.alias('v')
         n = self.nodes.alias('n')
         s = select([a.c.key]).distinct()
-        filtered = select([func.max(self.versions.c.serial)])
         if before != inf:
+            filtered = select([func.max(self.versions.c.serial)])
             filtered = filtered.where(self.versions.c.mtime < before)
-        s = s.where(v.c.serial == filtered.where(self.versions.c.node == v.c.node))
+            filtered = filtered.where(self.versions.c.node == v.c.node)
+        else:
+            filtered = select([self.nodes.c.latest_version])
+            filtered = filtered.where(self.nodes.c.node == v.c.node)
+        s = s.where(v.c.serial == filtered)
         s = s.where(v.c.cluster != except_cluster)
         s = s.where(v.c.node.in_(select([self.nodes.c.node],
             self.nodes.c.parent == parent)))
@@ -890,10 +928,12 @@ class Node(DBWorker):
                         v.c.size, v.c.type, v.c.source,
                         v.c.mtime, v.c.muser, v.c.uuid,
                         v.c.checksum, v.c.cluster]).distinct()
-        filtered = select([func.max(self.versions.c.serial)])
         if before != inf:
+            filtered = select([func.max(self.versions.c.serial)])
             filtered = filtered.where(self.versions.c.mtime < before)
-        s = s.where(v.c.serial == filtered.where(self.versions.c.node == v.c.node))
+        else:
+            filtered = select([self.nodes.c.latest_version])
+        s = s.where(v.c.serial == filtered.where(self.nodes.c.node == v.c.node))
         s = s.where(v.c.cluster != except_cluster)
         s = s.where(v.c.node.in_(select([self.nodes.c.node],
             self.nodes.c.parent == parent)))
index d3649b1..36dfb9a 100644 (file)
@@ -115,12 +115,15 @@ class Node(DBWorker):
                           ( node       integer primary key,
                             parent     integer default 0,
                             path       text    not null default '',
+                            latest_version     integer,
                             foreign key (parent)
                             references nodes(node)
                             on update cascade
                             on delete cascade ) """)
         execute(""" create unique index if not exists idx_nodes_path
                     on nodes(path) """)
+        execute(""" create index if not exists idx_nodes_parent
+                    on nodes(parent) """)
         
         execute(""" create table if not exists policy
                           ( node   integer,
@@ -164,6 +167,8 @@ class Node(DBWorker):
                     on versions(node, mtime) """)
         execute(""" create index if not exists idx_versions_node_uuid
                     on versions(uuid) """)
+        execute(""" create index if not exists idx_versions_serial_cluster
+                    on versions(serial, cluster) """)
         
         execute(""" create table if not exists attributes
                           ( serial integer,
@@ -433,12 +438,11 @@ class Node(DBWorker):
         
         # The latest version.
         q = ("select serial, node, hash, size, type, source, mtime, muser, uuid, checksum, cluster "
-             "from versions "
-             "where serial = (select max(serial) "
-                             "from versions "
-                             "where node = ? and mtime < ?) "
+             "from versions v "
+             "where serial = %s "
              "and cluster != ?")
-        execute(q, (node, before, except_cluster))
+        subq, args = self._construct_latest_version_subquery(node=node, before=before)
+        execute(q % subq, args + [except_cluster])
         props = fetchone()
         if props is None:
             return None
@@ -447,14 +451,13 @@ class Node(DBWorker):
         # First level, just under node (get population).
         q = ("select count(serial), sum(size), max(mtime) "
              "from versions v "
-             "where serial = (select max(serial) "
-                             "from versions "
-                             "where node = v.node and mtime < ?) "
+             "where serial = %s "
              "and cluster != ? "
              "and node in (select node "
                           "from nodes "
                           "where parent = ?)")
-        execute(q, (before, except_cluster, node))
+        subq, args = self._construct_latest_version_subquery(node=None, before=before)
+        execute(q % subq, args + [except_cluster, node])
         r = fetchone()
         if r is None:
             return None
@@ -467,14 +470,13 @@ class Node(DBWorker):
         # This is why the full path is stored.
         q = ("select count(serial), sum(size), max(mtime) "
              "from versions v "
-             "where serial = (select max(serial) "
-                             "from versions "
-                             "where node = v.node and mtime < ?) "
+             "where serial = %s "
              "and cluster != ? "
              "and node in (select node "
                           "from nodes "
                           "where path like ? escape '\\')")
-        execute(q, (before, except_cluster, self.escape_like(path) + '%'))
+        subq, args = self._construct_latest_version_subquery(node=None, before=before)
+        execute(q % subq, args + [except_cluster, self.escape_like(path) + '%'])
         r = fetchone()
         if r is None:
             return None
@@ -482,6 +484,11 @@ class Node(DBWorker):
         mtime = max(mtime, r[2])
         return (count, size, mtime)
     
+    def nodes_set_latest_version(self, node, serial):
+       q = ("update nodes set latest_version = ? where node = ?")
+        props = (serial, node)
+        self.execute(q, props)
+    
     def version_create(self, node, hash, size, type, source, muser, uuid, checksum, cluster=0):
         """Create a new version from the given properties.
            Return the (serial, mtime) of the new version.
@@ -493,6 +500,9 @@ class Node(DBWorker):
         props = (node, hash, size, type, source, mtime, muser, uuid, checksum, cluster)
         serial = self.execute(q, props).lastrowid
         self.statistics_update_ancestors(node, 1, size, mtime, cluster)
+        
+        self.nodes_set_latest_version(node, serial)
+        
         return serial, mtime
     
     def version_lookup(self, node, before=inf, cluster=0, all_props=True):
@@ -503,17 +513,16 @@ class Node(DBWorker):
         """
         
         q = ("select %s "
-             "from versions "
-             "where serial = (select max(serial) "
-                             "from versions "
-                             "where node = ? and mtime < ?) "
+             "from versions v "
+             "where serial = %s "
              "and cluster = ?")
+        subq, args = self._construct_latest_version_subquery(node=node, before=before)
         if not all_props:
-            q = q % "serial"
+            q = q % ("serial", subq)
         else:
-            q = q % "serial, node, hash, size, type, source, mtime, muser, uuid, checksum, cluster"
+            q = q % ("serial, node, hash, size, type, source, mtime, muser, uuid, checksum, cluster", subq)
         
-        self.execute(q, (node, before, cluster))
+        self.execute(q, args + [cluster])
         props = self.fetchone()
         if props is not None:
             return props
@@ -525,20 +534,19 @@ class Node(DBWorker):
            (serial, node, hash, size, type, source, mtime, muser, uuid, checksum, cluster).
         """
         
+        if not nodes:
+               return ()
         q = ("select %s "
              "from versions "
-             "where serial in (select max(serial) "
-                             "from versions "
-                             "where node in (%s) and mtime < ? group by node) "
+             "where serial in %s "
              "and cluster = ? %s")
-        placeholders = ','.join('?' for node in nodes)
+        subq, args = self._construct_latest_versions_subquery(nodes=nodes, before = before)
         if not all_props:
-            q = q % ("serial",  placeholders, '')
+            q = q % ("serial", subq, '')
         else:
-            q = q % ("serial, node, hash, size, type, source, mtime, muser, uuid, checksum, cluster",  placeholders, 'order by node')
+            q = q % ("serial, node, hash, size, type, source, mtime, muser, uuid, checksum, cluster",  subq, 'order by node')
         
-        args = nodes
-        args.extend((before, cluster))
+        args += [cluster]
         self.execute(q, args)
         return self.fetchall()
     
@@ -604,6 +612,10 @@ class Node(DBWorker):
         
         q = "delete from versions where serial = ?"
         self.execute(q, (serial,))
+        
+        props = self.version_lookup(node, cluster=cluster, all_props=False)
+        if props:
+               self.nodes_set_latest_version(node, props[0])
         return hash, size
     
     def attribute_get(self, serial, domain, keys=()):
@@ -725,6 +737,53 @@ class Node(DBWorker):
         
         return subq, args
     
+    def _construct_versions_nodes_latest_version_subquery(self, before=inf):
+        if before == inf:
+            q = ("n.latest_version ")
+            args = []
+        else:
+            q = ("(select max(serial) "
+                                  "from versions "
+                                  "where node = v.node and mtime < ?) ")
+            args = [before]
+        return q, args
+    
+    def _construct_latest_version_subquery(self, node=None, before=inf):
+        where_cond = "node = v.node"
+        args = []
+        if node:
+            where_cond = "node = ? "
+            args = [node]
+        
+        if before == inf:
+            q = ("(select latest_version "
+                   "from nodes "
+                   "where %s) ")
+        else:
+            q = ("(select max(serial) "
+                   "from versions "
+                   "where %s and mtime < ?) ")
+            args += [before]
+        return q % where_cond, args
+    
+    def _construct_latest_versions_subquery(self, nodes=(), before=inf):
+        where_cond = ""
+        args = []
+        if nodes:
+            where_cond = "node in (%s) " % ','.join('?' for node in nodes)
+            args = nodes
+        
+        if before == inf:
+            q = ("(select latest_version "
+                   "from nodes "
+                   "where %s ) ")
+        else:
+            q = ("(select max(serial) "
+                                "from versions "
+                                "where %s and mtime < ? group by node) ")
+            args += [before]
+        return q % where_cond, args
+    
     def latest_attribute_keys(self, parent, domain, before=inf, except_cluster=0, pathq=[]):
         """Return a list with all keys pairs defined
            for all latest versions under parent that
@@ -734,9 +793,7 @@ class Node(DBWorker):
         # TODO: Use another table to store before=inf results.
         q = ("select distinct a.key "
              "from attributes a, versions v, nodes n "
-             "where v.serial = (select max(serial) "
-                               "from versions "
-                               "where node = v.node and mtime < ?) "
+             "where v.serial = %s "
              "and v.cluster != ? "
              "and v.node in (select node "
                             "from nodes "
@@ -744,7 +801,9 @@ class Node(DBWorker):
              "and a.serial = v.serial "
              "and a.domain = ? "
              "and n.node = v.node")
-        args = (before, except_cluster, parent, domain)
+        subq, subargs = self._construct_latest_version_subquery(node=None, before=before)
+        args = subargs + [except_cluster, parent, domain]
+        q = q % subq
         subq, subargs = self._construct_paths(pathq)
         if subq is not None:
             q += subq
@@ -814,20 +873,20 @@ class Node(DBWorker):
         
         q = ("select distinct n.path, %s "
              "from versions v, nodes n "
-             "where v.serial = (select max(serial) "
-                               "from versions "
-                               "where node = v.node and mtime < ?) "
+             "where v.serial = %s "
              "and v.cluster != ? "
              "and v.node in (select node "
                             "from nodes "
                             "where parent = ?) "
              "and n.node = v.node "
              "and n.path > ? and n.path < ?")
+        subq, args = self._construct_versions_nodes_latest_version_subquery(before)
         if not all_props:
-            q = q % "v.serial"
+            q = q % ("v.serial", subq)
         else:
-            q = q % "v.serial, v.node, v.hash, v.size, v.type, v.source, v.mtime, v.muser, v.uuid, v.checksum, v.cluster"
-        args = [before, except_cluster, parent, start, nextling]
+            q = q % ("v.serial, v.node, v.hash, v.size, v.type, v.source, v.mtime, v.muser, v.uuid, v.checksum, v.cluster", subq)
+        args += [except_cluster, parent, start, nextling]
+        start_index = len(args) - 2
         
         subq, subargs = self._construct_paths(pathq)
         if subq is not None:
@@ -886,7 +945,7 @@ class Node(DBWorker):
             if count >= limit: 
                 break
             
-            args[3] = strnextling(pf) # New start.
+            args[start_index] = strnextling(pf) # New start.
             execute(q, args)
         
         return matches, prefixes