Merge policy into node. Needs database reset, or the following commands:
[pithos] / pithos / backends / lib / sqlalchemy / node.py
index b865382..3ca5eef 100644 (file)
 # or implied, of GRNET S.A.
 
 from time import time
-from sqlalchemy import Table, Integer, BigInteger, Float, Column, String, MetaData, ForeignKey
+from sqlalchemy import Table, Integer, BigInteger, DECIMAL, Column, String, MetaData, ForeignKey
 from sqlalchemy.schema import Index, Sequence
 from sqlalchemy.sql import func, and_, or_, null, select, bindparam
 from sqlalchemy.ext.compiler import compiles
 
 from dbworker import DBWorker
 
-ROOTNODE  = 1
+ROOTNODE  = 0
 
 ( SERIAL, NODE, HASH, SIZE, SOURCE, MTIME, MUSER, CLUSTER ) = range(8)
 
@@ -115,10 +115,21 @@ class Node(DBWorker):
                                          onupdate='CASCADE'),
                               autoincrement=False))
         columns.append(Column('path', String(2048), default='', nullable=False))
-        self.nodes = Table('nodes', metadata, *columns)
+        self.nodes = Table('nodes', metadata, *columns, mysql_engine='InnoDB')
         # place an index on path
         Index('idx_nodes_path', self.nodes.c.path)
         
+        #create policy table
+        columns=[]
+        columns.append(Column('node', Integer,
+                              ForeignKey('nodes.node',
+                                         ondelete='CASCADE',
+                                         onupdate='CASCADE'),
+                              primary_key=True))
+        columns.append(Column('key', String(255), primary_key=True))
+        columns.append(Column('value', String(255)))
+        self.policies = Table('policy', metadata, *columns, mysql_engine='InnoDB')
+        
         #create statistics table
         columns=[]
         columns.append(Column('node', Integer,
@@ -128,10 +139,10 @@ class Node(DBWorker):
                               primary_key=True))
         columns.append(Column('population', Integer, nullable=False, default=0))
         columns.append(Column('size', BigInteger, nullable=False, default=0))
-        columns.append(Column('mtime', Float))
+        columns.append(Column('mtime', DECIMAL))
         columns.append(Column('cluster', Integer, nullable=False, default=0,
-                              primary_key=True))
-        self.statistics = Table('statistics', metadata, *columns)
+                              primary_key=True, autoincrement=False))
+        self.statistics = Table('statistics', metadata, *columns, mysql_engine='InnoDB')
         
         #create versions table
         columns=[]
@@ -143,10 +154,10 @@ class Node(DBWorker):
         columns.append(Column('hash', String(255)))
         columns.append(Column('size', BigInteger, nullable=False, default=0))
         columns.append(Column('source', Integer))
-        columns.append(Column('mtime', Float))
+        columns.append(Column('mtime', DECIMAL))
         columns.append(Column('muser', String(255), nullable=False, default=''))
         columns.append(Column('cluster', Integer, nullable=False, default=0))
-        self.versions = Table('versions', metadata, *columns)
+        self.versions = Table('versions', metadata, *columns, mysql_engine='InnoDB')
         Index('idx_versions_node_mtime', self.versions.c.node,
               self.versions.c.mtime)
         
@@ -159,7 +170,7 @@ class Node(DBWorker):
                               primary_key=True))
         columns.append(Column('key', String(255), primary_key=True))
         columns.append(Column('value', String(255)))
-        self.attributes = Table('attributes', metadata, *columns)
+        self.attributes = Table('attributes', metadata, *columns, mysql_engine='InnoDB')
         
         metadata.create_all(self.engine)
         
@@ -366,6 +377,28 @@ class Node(DBWorker):
         self.conn.execute(s).close()
         return True
     
+    def policy_get(self, node):
+        s = select([self.policies.c.key, self.policies.c.value],
+            self.policies.c.node==node)
+        r = self.conn.execute(s)
+        d = dict(r.fetchall())
+        r.close()
+        return d
+    
+    def policy_set(self, node, policy):
+        #insert or replace
+        for k, v in policy.iteritems():
+            s = self.policies.update().where(and_(self.policies.c.node == node,
+                                                  self.policies.c.key == k))
+            s = s.values(value = v)
+            rp = self.conn.execute(s)
+            rp.close()
+            if rp.rowcount == 0:
+                s = self.policies.insert()
+                values = {'node':node, 'key':k, 'value':v}
+                r = self.conn.execute(s, values)
+                r.close()
+    
     def statistics_get(self, node, cluster=0):
         """Return population, total size and last mtime
            for all versions under node that belong to the cluster.