Include sub-second precision (milliseconds and finer) in mtime - specify decimal precision explicitly as DECIMAL(precision=16, scale=6) instead of Float
[pithos] / pithos / backends / lib / sqlalchemy / node.py
index 585c7c0..255fe95 100644 (file)
@@ -32,7 +32,7 @@
 # or implied, of GRNET S.A.
 
 from time import time
-from sqlalchemy import Table, Integer, BigInteger, Float, Column, String, MetaData, ForeignKey
+from sqlalchemy import Table, Integer, BigInteger, DECIMAL, Column, String, MetaData, ForeignKey
 from sqlalchemy.schema import Index, Sequence
 from sqlalchemy.sql import func, and_, or_, null, select, bindparam
 from sqlalchemy.ext.compiler import compiles
@@ -115,9 +115,9 @@ class Node(DBWorker):
                                          onupdate='CASCADE'),
                               autoincrement=False))
         columns.append(Column('path', String(2048), default='', nullable=False))
-        self.nodes = Table('nodes', metadata, *columns)
+        self.nodes = Table('nodes', metadata, *columns, mysql_engine='InnoDB')
         # place an index on path
-        Index('idx_nodes_path', self.nodes.c.path, unique=True)
+        Index('idx_nodes_path', self.nodes.c.path)
         
         #create statistics table
         columns=[]
@@ -128,10 +128,10 @@ class Node(DBWorker):
                               primary_key=True))
         columns.append(Column('population', Integer, nullable=False, default=0))
         columns.append(Column('size', BigInteger, nullable=False, default=0))
-        columns.append(Column('mtime', Float))
+        columns.append(Column('mtime', DECIMAL(precision=16, scale=6)))
         columns.append(Column('cluster', Integer, nullable=False, default=0,
-                              primary_key=True))
-        self.statistics = Table('statistics', metadata, *columns)
+                              primary_key=True, autoincrement=False))
+        self.statistics = Table('statistics', metadata, *columns, mysql_engine='InnoDB')
         
         #create versions table
         columns=[]
@@ -143,10 +143,10 @@ class Node(DBWorker):
         columns.append(Column('hash', String(255)))
         columns.append(Column('size', BigInteger, nullable=False, default=0))
         columns.append(Column('source', Integer))
-        columns.append(Column('mtime', Float))
+        columns.append(Column('mtime', DECIMAL(precision=16, scale=6)))
         columns.append(Column('muser', String(255), nullable=False, default=''))
         columns.append(Column('cluster', Integer, nullable=False, default=0))
-        self.versions = Table('versions', metadata, *columns)
+        self.versions = Table('versions', metadata, *columns, mysql_engine='InnoDB')
         Index('idx_versions_node_mtime', self.versions.c.node,
               self.versions.c.mtime)
         
@@ -159,7 +159,7 @@ class Node(DBWorker):
                               primary_key=True))
         columns.append(Column('key', String(255), primary_key=True))
         columns.append(Column('value', String(255)))
-        self.attributes = Table('attributes', metadata, *columns)
+        self.attributes = Table('attributes', metadata, *columns, mysql_engine='InnoDB')
         
         metadata.create_all(self.engine)
         
@@ -255,11 +255,12 @@ class Node(DBWorker):
         c1 = select([self.nodes.c.node],
             self.nodes.c.parent == parent)
         where_clause = and_(self.versions.c.node.in_(c1),
-                            self.versions.c.cluster == cluster,
-                            self.versions.c.mtime <= before)
+                            self.versions.c.cluster == cluster)
         s = select([func.count(self.versions.c.serial),
                     func.sum(self.versions.c.size)])
         s = s.where(where_clause)
+        if before != inf:
+            s = s.where(self.versions.c.mtime <= before)
         r = self.conn.execute(s)
         row = r.fetchone()
         r.close()
@@ -305,9 +306,10 @@ class Node(DBWorker):
         s = select([func.count(self.versions.c.serial),
                     func.sum(self.versions.c.size)])
         where_clause = and_(self.versions.c.node == node,
-                         self.versions.c.cluster == cluster,
-                         self.versions.c.mtime <= before)
+                         self.versions.c.cluster == cluster)
         s = s.where(where_clause)
+        if before != inf:
+            s = s.where(self.versions.c.mtime <= before)
         r = self.conn.execute(s)
         row = r.fetchone()
         nr, size = row[0], row[1]
@@ -450,11 +452,12 @@ class Node(DBWorker):
                     self.versions.c.mtime,
                     self.versions.c.muser,
                     self.versions.c.cluster])
+        filtered = select([func.max(self.versions.c.serial)],
+                            self.versions.c.node == node)
+        if before != inf:
+            filtered = filtered.where(self.versions.c.mtime < before)
         s = s.where(and_(self.versions.c.cluster != except_cluster,
-                         self.versions.c.serial == select(
-                            [func.max(self.versions.c.serial)],
-                            and_(self.versions.c.node == node,
-                            self.versions.c.mtime < before))))
+                         self.versions.c.serial == filtered))
         r = self.conn.execute(s)
         props = r.fetchone()
         r.close()
@@ -467,11 +470,11 @@ class Node(DBWorker):
         s = select([func.count(v.c.serial),
                     func.sum(v.c.size),
                     func.max(v.c.mtime)])
-        c1 = select([func.max(self.versions.c.serial)],
-            and_(self.versions.c.node == v.c.node,
-                 self.versions.c.mtime < before))
+        c1 = select([func.max(self.versions.c.serial)])
+        if before != inf:
+            c1 = c1.where(self.versions.c.mtime < before)
         c2 = select([self.nodes.c.node], self.nodes.c.parent == node)
-        s = s.where(and_(v.c.serial == c1,
+        s = s.where(and_(v.c.serial == c1.where(self.versions.c.node == v.c.node),
                          v.c.cluster != except_cluster,
                          v.c.node.in_(c2)))
         rp = self.conn.execute(s)
@@ -490,8 +493,9 @@ class Node(DBWorker):
                     func.sum(v.c.size),
                     func.max(v.c.mtime)])
         c1 = select([func.max(self.versions.c.serial)],
-            and_(self.versions.c.node == v.c.node,
-                 self.versions.c.mtime < before))
+            self.versions.c.node == v.c.node)
+        if before != inf:
+            c1 = c1.where(self.versions.c.mtime < before)
         c2 = select([self.nodes.c.node], self.nodes.c.path.like(path + '%'))
         s = s.where(and_(v.c.serial == c1,
                          v.c.cluster != except_cluster,
@@ -528,8 +532,9 @@ class Node(DBWorker):
         s = select([v.c.serial, v.c.node, v.c.hash, v.c.size,
                     v.c.source, v.c.mtime, v.c.muser, v.c.cluster])
         c = select([func.max(self.versions.c.serial)],
-            and_(self.versions.c.node == node,
-                 self.versions.c.mtime < before))
+            self.versions.c.node == node)
+        if before != inf:
+            c = c.where(self.versions.c.mtime < before)
         s = s.where(and_(v.c.serial == c,
                          v.c.cluster == cluster))
         r = self.conn.execute(s)
@@ -682,9 +687,10 @@ class Node(DBWorker):
         v = self.versions.alias('v')
         n = self.nodes.alias('n')
         s = select([a.c.key]).distinct()
-        s = s.where(v.c.serial == select([func.max(self.versions.c.serial)],
-                                          and_(self.versions.c.node == v.c.node,
-                                               self.versions.c.mtime < before)))
+        filtered = select([func.max(self.versions.c.serial)])
+        if before != inf:
+            filtered = filtered.where(self.versions.c.mtime < before)
+        s = s.where(v.c.serial == filtered.where(self.versions.c.node == v.c.node))
         s = s.where(v.c.cluster != except_cluster)
         s = s.where(v.c.node.in_(select([self.nodes.c.node],
             self.nodes.c.parent == parent)))
@@ -757,9 +763,10 @@ class Node(DBWorker):
         v = self.versions.alias('v')
         n = self.nodes.alias('n')
         s = select([n.c.path, v.c.serial]).distinct()
-        s = s.where(v.c.serial == select([func.max(self.versions.c.serial)],
-            and_(self.versions.c.node == v.c.node,
-                 self.versions.c.mtime < before)))
+        filtered = select([func.max(self.versions.c.serial)])
+        if before != inf:
+            filtered = filtered.where(self.versions.c.mtime < before)
+        s = s.where(v.c.serial == filtered.where(self.versions.c.node == v.c.node))
         s = s.where(v.c.cluster != except_cluster)
         s = s.where(v.c.node.in_(select([self.nodes.c.node],
             self.nodes.c.parent == parent)))