backend components in SQLAlchemy: Progress II
author Sofia Papagiannaki <papagian@gmail.com>
Fri, 12 Aug 2011 14:25:38 +0000 (17:25 +0300)
committer Sofia Papagiannaki <papagian@gmail.com>
Fri, 12 Aug 2011 14:25:38 +0000 (17:25 +0300)
pithos/backends/lib_alchemy/node.py
pithos/backends/lib_alchemy/xfeatures.py

index 3deac74..359d237 100644 (file)
 
 from time import time
 from sqlalchemy import Table, Integer, Column, String, MetaData, ForeignKey
-from sqlalchemy.schema import Index
+from sqlalchemy.schema import Index, Sequence
+from sqlalchemy.sql import func, and_, null
+from sqlalchemy.sql import select
 
 from dbworker import DBWorker
 
 
-ROOTNODE  = 0
+ROOTNODE  = 1
 
 ( SERIAL, NODE, SIZE, SOURCE, MTIME, MUSER, CLUSTER ) = range(7)
 
@@ -110,7 +112,7 @@ class Node(DBWorker):
                               ForeignKey('nodes.node',
                                          ondelete='CASCADE',
                                          onupdate='CASCADE'),
-                              autoincrement=False, default=0))
+                              autoincrement=False))
         #columns.append(Column('path', String(2048), default='', nullable=False))
         columns.append(Column('path', String(255), default='', nullable=False))
         self.nodes = Table('nodes', metadata, *columns)
@@ -124,36 +126,30 @@ class Node(DBWorker):
                                          ondelete='CASCADE',
                                          onupdate='CASCADE'),
                               primary_key=True))
-        columns.append(Column('population', Integer, nullable=False,
-                              autoincrement=False, default=0))
-        columns.append(Column('size', Integer, nullable=False,
-                              autoincrement=False, default=0))
-        columns.append(Column('mtime', Integer, autoincrement=False))
-        columns.append(Column('cluster', Integer, nullable=False,
-                              autoincrement=False, default=0, primary_key=True))
+        columns.append(Column('population', Integer, nullable=False, default=0))
+        columns.append(Column('size', Integer, nullable=False, default=0))
+        columns.append(Column('mtime', Integer))
+        columns.append(Column('cluster', Integer, nullable=False, default=0,
+                              primary_key=True))
         self.statistics = Table('statistics', metadata, *columns)
         
         #create versions table
         columns=[]
-        columns.append(Column('serial', Integer, autoincrement=False,
-                              primary_key=True))
+        columns.append(Column('serial', Integer, primary_key=True))
         columns.append(Column('node', Integer,
                               ForeignKey('nodes.node',
                                          ondelete='CASCADE',
-                                         onupdate='CASCADE'),
-                              autoincrement=False))
-        columns.append(Column('size', Integer, nullable=False,
-                              autoincrement=False, default=0))
-        columns.append(Column('source', Integer, autoincrement=False))
-        columns.append(Column('mtime', Integer, autoincrement=False))
+                                         onupdate='CASCADE')))
+        columns.append(Column('size', Integer, nullable=False, default=0))
+        columns.append(Column('source', Integer))
+        columns.append(Column('mtime', Integer))
         columns.append(Column('muser', String(255), nullable=False, default=''))
-        columns.append(Column('cluster', Integer, nullable=False,
-                              autoincrement=False, default=0))
+        columns.append(Column('cluster', Integer, nullable=False, default=0))
         self.versions = Table('versions', metadata, *columns)
         # place an index on node
-        Index('idx_versions_node', self.versions.c.mtime)
+        Index('idx_versions_node', self.versions.c.node)
         # TODO: Sort out if more indexes are needed.
-        #Index('idx_versions_node', self.versions.c.node)
+        #Index('idx_versions_node', self.versions.c.mtime)
         
         #create attributes table
         columns = []
@@ -161,7 +157,6 @@ class Node(DBWorker):
                               ForeignKey('versions.serial',
                                          ondelete='CASCADE',
                                          onupdate='CASCADE'),
-                              autoincrement=False,
                               primary_key=True))
         columns.append(Column('key', String(255), primary_key=True))
         columns.append(Column('value', String(255)))
@@ -169,29 +164,35 @@ class Node(DBWorker):
         
         metadata.create_all(self.engine)
         
-        s = self.nodes.insert(node=ROOTNODE, parent=ROOTNODE)
-        self.conn.execute(s)
+        s = self.nodes.select().where(and_(self.nodes.c.node == 1,
+                                           self.nodes.c.parent == 1))
+        r = self.conn.execute(s).fetchone()
+        if not r:
+            s = self.nodes.insert().values(node=ROOTNODE, parent=ROOTNODE)
+            self.conn.execute(s)
     
     def node_create(self, parent, path):
         """Create a new node from the given properties.
            Return the node identifier of the new node.
         """
-        
-        q = ("insert into nodes (parent, path) "
-             "values (?, ?)")
-        props = (parent, path)
-        return self.execute(q, props).lastrowid
+        #TODO catch IntegrityError?
+        s = self.nodes.insert().values(parent=parent, path=path)
+        r = self.conn.execute(s)
+        inserted_primary_key = r.inserted_primary_key[0]
+        r.close()
+        return inserted_primary_key
     
     def node_lookup(self, path):
         """Lookup the current node of the given path.
            Return None if the path is not found.
         """
         
-        q = "select node from nodes where path = ?"
-        self.execute(q, (path,))
-        r = self.fetchone()
-        if r is not None:
-            return r[0]
+        s = select([self.nodes.c.node], self.nodes.c.path == path)
+        r = self.conn.execute(s)
+        row = r.fetchone()
+        r.close()
+        if row:
+            return row[0]
         return None
     
     def node_get_properties(self, node):
@@ -199,9 +200,12 @@ class Node(DBWorker):
            Return None if the node is not found.
         """
         
-        q = "select parent, path from nodes where node = ?"
-        self.execute(q, (node,))
-        return self.fetchone()
+        s = select([self.nodes.c.parent, self.nodes.c.path])
+        s = s.where(self.nodes.c.node == node)
+        r = self.conn.execute(s)
+        l = r.fetchone()
+        r.close()
+        return l
     
     def node_get_versions(self, node, keys=(), propnames=_propnames):
         """Return the properties of all versions at node.
@@ -209,27 +213,27 @@ class Node(DBWorker):
            (serial, node, size, source, mtime, muser, cluster).
         """
         
-        q = ("select serial, node, size, source, mtime, muser, cluster "
-             "from versions "
-             "where node = ?")
-        self.execute(q, (node,))
-        r = self.fetchall()
-        if r is None:
-            return r
+        s = select(['*'], self.versions.c.node == node)
+        r = self.conn.execute(s)
+        rows = r.fetchall()
+        if not rows:
+            return rows
         
         if not keys:
-            return r
-        return [[p[propnames[k]] for k in keys if k in propnames] for p in r]
+            return rows
+        
+        return [[p[propnames[k]] for k in keys if k in propnames] for p in rows]
     
     def node_count_children(self, node):
         """Return node's child count."""
         
-        q = "select count(node) from nodes where parent = ? and node != 0"
-        self.execute(q, (node,))
-        r = self.fetchone()
-        if r is None:
-            return 0
-        return r[0]
+        s = select([func.count(self.nodes.c.node)])
+        s = s.where(and_(self.nodes.c.parent == node,
+                         self.nodes.c.node != ROOTNODE))
+        r = self.conn.execute(s)
+        row = r.fetchone()
+        r.close()
+        return row[0]
     
     def node_purge_children(self, parent, before=inf, cluster=0):
         """Delete all versions with the specified
@@ -238,44 +242,46 @@ class Node(DBWorker):
            Clears out nodes with no remaining versions.
         """
         
-        execute = self.execute
-        q = ("select count(serial), sum(size) from versions "
-             "where node in (select node "
-                            "from nodes "
-                            "where parent = ?) "
-             "and cluster = ? "
-             "and mtime <= ?")
-        args = (parent, cluster, before)
-        execute(q, args)
-        nr, size = self.fetchone()
-        if not nr:
+        scalar = select([self.nodes.c.node],
+            self.nodes.c.parent == parent).as_scalar()
+        where_clause = and_(self.versions.c.node.in_(scalar),
+                            self.versions.c.cluster == cluster,
+                            "versions.mtime <= %f" %before)
+        s = select([func.count(self.versions.c.serial),
+                    func.sum(self.versions.c.size)])
+        s = s.where(where_clause)
+        r = self.conn.execute(s)
+        row = r.fetchone()
+        r.close()
+        if not row:
             return ()
+        nr, size = row[0], -row[1] if row[1] else 0
         mtime = time()
-        self.statistics_update(parent, -nr, -size, mtime, cluster)
-        self.statistics_update_ancestors(parent, -nr, -size, mtime, cluster)
-        
-        q = ("select serial from versions "
-             "where node in (select node "
-                            "from nodes "
-                            "where parent = ?) "
-             "and cluster = ? "
-             "and mtime <= ?")
-        execute(q, args)
-        serials = [r[SERIAL] for r in self.fetchall()]
-        q = ("delete from versions "
-             "where node in (select node "
-                            "from nodes "
-                            "where parent = ?) "
-             "and cluster = ? "
-             "and mtime <= ?")
-        execute(q, args)
-        q = ("delete from nodes "
-             "where node in (select node from nodes n "
-                            "where (select count(serial) "
-                                   "from versions "
-                                   "where node = n.node) = 0 "
-                            "and parent = ?)")
-        execute(q, (parent,))
+        print '#', parent, -nr, size, mtime, cluster
+        self.statistics_update(parent, -nr, size, mtime, cluster)
+        self.statistics_update_ancestors(parent, -nr, size, mtime, cluster)
+        
+        s = select([self.versions.c.serial])
+        s = s.where(where_clause)
+        r = self.conn.execute(s)
+        serials = [row[SERIAL] for row in r.fetchall()]
+        r.close()
+        
+        #delete versiosn
+        s = self.versions.delete().where(where_clause)
+        r = self.conn.execute(s)
+        r.close()
+        
+        #delete nodes
+        a = self.nodes.alias()
+        no_versions = select([func.count(self.versions.c.serial)],
+            self.versions.c.node == a.c.node).as_scalar() == 0
+        n = select([self.nodes.c.node],
+            and_(no_versions, self.nodes.c.parent == parent))
+        s = s.where(self.nodes.c.node.in_(n))
+        s = self.nodes.delete().where(self.nodes.c.node == s)
+        print '#', s
+        self.conn.execute(s).close()
         return serials
     
     def node_purge(self, node, before=inf, cluster=0):
@@ -356,19 +362,22 @@ class Node(DBWorker):
            May be zero or positive or negative numbers.
         """
         
-        qs = ("select population, size from statistics "
-              "where node = ? and cluster = ?")
-        qu = ("insert or replace into statistics (node, population, size, mtime, cluster) "
-              "values (?, ?, ?, ?, ?)")
-        self.execute(qs, (node, cluster))
-        r = self.fetchone()
-        if r is None:
+        s = select([self.statistics.c.population, self.statistics.c.size],
+            and_(self.statistics.c.node == node,
+                 self.statistics.c.cluster == cluster))
+        res = self.conn.execute(s)
+        r = res.fetchone()
+        res.close()
+        if not r:
             prepopulation, presize = (0, 0)
         else:
             prepopulation, presize = r
         population += prepopulation
         size += presize
-        self.execute(qu, (node, population, size, mtime, cluster))
+        
+        self.statistics.insert().values(node=node, population=population,
+                                        size=size, mtime=mtime, cluster=cluster)
+        self.conn.execute(s).close()
     
     def statistics_update_ancestors(self, node, population, size, mtime, cluster=0):
         """Update the statistics of the given node's parent.
@@ -377,7 +386,7 @@ class Node(DBWorker):
         """
         
         while True:
-            if node == 0:
+            if node == ROOTNODE:
                 break
             props = self.node_get_properties(node)
             if props is None:
index 8ff0d5c..70901cb 100644 (file)
@@ -60,8 +60,7 @@ class XFeatures(DBWorker):
                               ForeignKey('xfeatures.feature_id',
                                          ondelete='CASCADE'),
                               primary_key=True))
-        columns.append(Column('key', Integer, autoincrement=False,
-                              primary_key=True))
+        columns.append(Column('key', Integer, primary_key=True))
         columns.append(Column('value', String(255), primary_key=True))
         self.xfeaturevals = Table('xfeaturevals', metadata, *columns)