eliminate nested sql aggregations
[pithos] / snf-pithos-backend / pithos / backends / lib / sqlalchemy / node.py
index 7a1f8b3..1ee4cca 100644 (file)
@@ -122,9 +122,11 @@ class Node(DBWorker):
                                          ondelete='CASCADE',
                                          onupdate='CASCADE'),
                               autoincrement=False))
+        columns.append(Column('latest_version', Integer))
         columns.append(Column('path', String(2048), default='', nullable=False))
         self.nodes = Table('nodes', metadata, *columns, mysql_engine='InnoDB')
         Index('idx_nodes_path', self.nodes.c.path, unique=True)
+        Index('idx_nodes_parent', self.nodes.c.parent)
         
         #create policy table
         columns=[]
@@ -170,6 +172,7 @@ class Node(DBWorker):
         self.versions = Table('versions', metadata, *columns, mysql_engine='InnoDB')
         Index('idx_versions_node_mtime', self.versions.c.node, self.versions.c.mtime)
         Index('idx_versions_node_uuid', self.versions.c.uuid)
+        Index('idx_versions_serial_cluster', self.versions.c.serial, self.versions.c.cluster)
         
         #create attributes table
         columns = []
@@ -514,10 +517,13 @@ class Node(DBWorker):
                     self.versions.c.uuid,
                     self.versions.c.checksum,
                     self.versions.c.cluster])
-        filtered = select([func.max(self.versions.c.serial)],
-                            self.versions.c.node == node)
         if before != inf:
+            filtered = select([func.max(self.versions.c.serial)],
+                            self.versions.c.node == node)
             filtered = filtered.where(self.versions.c.mtime < before)
+        else:
+            filtered = select([self.nodes.c.latest_version],
+                            self.versions.c.node == node)
         s = s.where(and_(self.versions.c.cluster != except_cluster,
                          self.versions.c.serial == filtered))
         r = self.conn.execute(s)
@@ -532,11 +538,15 @@ class Node(DBWorker):
         s = select([func.count(v.c.serial),
                     func.sum(v.c.size),
                     func.max(v.c.mtime)])
-        c1 = select([func.max(self.versions.c.serial)])
         if before != inf:
+            c1 = select([func.max(self.versions.c.serial)])
             c1 = c1.where(self.versions.c.mtime < before)
+            c1.where(self.versions.c.node == v.c.node)
+        else:
+            c1 = select([self.nodes.c.latest_version])
+            c1.where(self.nodes.c.node == v.c.node)
         c2 = select([self.nodes.c.node], self.nodes.c.parent == node)
-        s = s.where(and_(v.c.serial == c1.where(self.versions.c.node == v.c.node),
+        s = s.where(and_(v.c.serial == c1,
                          v.c.cluster != except_cluster,
                          v.c.node.in_(c2)))
         rp = self.conn.execute(s)
@@ -554,10 +564,13 @@ class Node(DBWorker):
         s = select([func.count(v.c.serial),
                     func.sum(v.c.size),
                     func.max(v.c.mtime)])
-        c1 = select([func.max(self.versions.c.serial)],
-            self.versions.c.node == v.c.node)
         if before != inf:
+            c1 = select([func.max(self.versions.c.serial)],
+                    self.versions.c.node == v.c.node)
             c1 = c1.where(self.versions.c.mtime < before)
+        else:
+            c1 = select([self.nodes.c.serial],
+                    self.nodes.c.node == v.c.node)
         c2 = select([self.nodes.c.node], self.nodes.c.path.like(self.escape_like(path) + '%', escape='\\'))
         s = s.where(and_(v.c.serial == c1,
                          v.c.cluster != except_cluster,
@@ -571,6 +584,11 @@ class Node(DBWorker):
         mtime = max(mtime, r[2])
         return (count, size, mtime)
     
+    def nodes_set_latest_version(self, node, serial):
+        """Store serial in the latest_version column of the given node's row."""
+        stmt = self.nodes.update().where(self.nodes.c.node == node).values(latest_version=serial)
+        self.conn.execute(stmt).close()
+    
     def version_create(self, node, hash, size, type, source, muser, uuid, checksum, cluster=0):
         """Create a new version from the given properties.
            Return the (serial, mtime) of the new version.
@@ -581,6 +599,9 @@ class Node(DBWorker):
                                           mtime=mtime, muser=muser, uuid=uuid, checksum=checksum, cluster=cluster)
         serial = self.conn.execute(s).inserted_primary_key[0]
         self.statistics_update_ancestors(node, 1, size, mtime, cluster)
+        
+        self.nodes_set_latest_version(node, serial)
+        
         return serial, mtime
     
     def version_lookup(self, node, before=inf, cluster=0, all_props=True):
@@ -598,10 +619,13 @@ class Node(DBWorker):
                         v.c.size, v.c.type, v.c.source,
                         v.c.mtime, v.c.muser, v.c.uuid,
                         v.c.checksum, v.c.cluster])
-        c = select([func.max(self.versions.c.serial)],
-            self.versions.c.node == node)
         if before != inf:
+            c = select([func.max(self.versions.c.serial)],
+                self.versions.c.node == node)
             c = c.where(self.versions.c.mtime < before)
+        else:
+            c = select([self.nodes.c.latest_version],
+                self.nodes.c.node == node)
         s = s.where(and_(v.c.serial == c,
                          v.c.cluster == cluster))
         r = self.conn.execute(s)
@@ -616,7 +640,8 @@ class Node(DBWorker):
            Return a list with their properties:
            (serial, node, hash, size, type, source, mtime, muser, uuid, checksum, cluster).
         """
-        
+        if not nodes:
+            return ()
         v = self.versions.alias('v')
         if not all_props:
             s = select([v.c.serial])
@@ -625,10 +650,14 @@ class Node(DBWorker):
                         v.c.size, v.c.type, v.c.source,
                         v.c.mtime, v.c.muser, v.c.uuid,
                         v.c.checksum, v.c.cluster])
-        c = select([func.max(self.versions.c.serial)],
-            self.versions.c.node.in_(nodes)).group_by(self.versions.c.node)
         if before != inf:
+            c = select([func.max(self.versions.c.serial)],
+                self.versions.c.node.in_(nodes))
             c = c.where(self.versions.c.mtime < before)
+            c = c.group_by(self.versions.c.node)
+        else:
+            c = select([self.nodes.c.latest_version],
+                self.nodes.c.node.in_(nodes))
         s = s.where(and_(v.c.serial.in_(c),
                          v.c.cluster == cluster))
         s = s.order_by(v.c.node)
@@ -706,6 +735,11 @@ class Node(DBWorker):
         
         s = self.versions.delete().where(self.versions.c.serial == serial)
         self.conn.execute(s).close()
+        
+        props = self.version_lookup(node, cluster=cluster, all_props=False)  # look up what is now the node's latest version
+        if props:
+            self.nodes_set_latest_version(v.node, serial)  # NOTE(review): `serial` is the row deleted just above and `v` is not bound in this hunk; presumably this should pass `node` and the serial from `props` -- confirm
+        
         return hash, size
     
     def attribute_get(self, serial, domain, keys=()):
@@ -799,10 +833,14 @@ class Node(DBWorker):
         v = self.versions.alias('v')
         n = self.nodes.alias('n')
         s = select([a.c.key]).distinct()
-        filtered = select([func.max(self.versions.c.serial)])
         if before != inf:
+            filtered = select([func.max(self.versions.c.serial)])
             filtered = filtered.where(self.versions.c.mtime < before)
-        s = s.where(v.c.serial == filtered.where(self.versions.c.node == v.c.node))
+            filtered = filtered.where(self.versions.c.node == v.c.node)
+        else:
+            filtered = select([self.nodes.c.latest_version])
+            filtered = filtered.where(self.nodes.c.node == v.c.node)
+        s = s.where(v.c.serial == filtered)
         s = s.where(v.c.cluster != except_cluster)
         s = s.where(v.c.node.in_(select([self.nodes.c.node],
             self.nodes.c.parent == parent)))
@@ -890,10 +928,12 @@ class Node(DBWorker):
                         v.c.size, v.c.type, v.c.source,
                         v.c.mtime, v.c.muser, v.c.uuid,
                         v.c.checksum, v.c.cluster]).distinct()
-        filtered = select([func.max(self.versions.c.serial)])
         if before != inf:
+            filtered = select([func.max(self.versions.c.serial)])
             filtered = filtered.where(self.versions.c.mtime < before)
-        s = s.where(v.c.serial == filtered.where(self.versions.c.node == v.c.node))
+        else:
+            filtered = select([self.nodes.c.latest_version])
+        s = s.where(v.c.serial == filtered.where(self.nodes.c.node == v.c.node))
         s = s.where(v.c.cluster != except_cluster)
         s = s.where(v.c.node.in_(select([self.nodes.c.node],
             self.nodes.c.parent == parent)))