use alembic to initialize the backend database
author Sofia Papagiannaki <papagian@gmail.com>
Wed, 18 Jul 2012 14:41:57 +0000 (17:41 +0300)
committer Sofia Papagiannaki <papagian@gmail.com>
Wed, 18 Jul 2012 14:41:57 +0000 (17:41 +0300)
snf-pithos-backend/pithos/backends/lib/sqlalchemy/groups.py
snf-pithos-backend/pithos/backends/lib/sqlalchemy/node.py
snf-pithos-backend/pithos/backends/lib/sqlalchemy/public.py
snf-pithos-backend/pithos/backends/lib/sqlalchemy/xfeatures.py
snf-pithos-backend/pithos/backends/migrate.py

index d931c31..f75a446 100644 (file)
@@ -35,25 +35,42 @@ from collections import defaultdict
 from sqlalchemy import Table, Column, String, MetaData
 from sqlalchemy.sql import select, and_
 from sqlalchemy.schema import Index
+from sqlalchemy.exc import NoSuchTableError
+
 from dbworker import DBWorker
 
+
+def create_tables(engine):
+    """Create the `groups` table on `engine` if it does not exist.
+
+    Returns the defined tables (``metadata.sorted_tables``).
+    """
+    metadata = MetaData()
+    columns = []
+    columns.append(Column('owner', String(256), primary_key=True))
+    columns.append(Column('name', String(256), primary_key=True))
+    columns.append(Column('member', String(256), primary_key=True))
+    groups = Table('groups', metadata, *columns, mysql_engine='InnoDB')
+
+    # place an index on member
+    Index('idx_groups_member', groups.c.member)
+
+    metadata.create_all(engine)
+    return metadata.sorted_tables
+
+
 class Groups(DBWorker):
     """Groups are named collections of members, belonging to an owner."""
     
     def __init__(self, **params):
         DBWorker.__init__(self, **params)
-        metadata = MetaData()
-        columns=[]
-        columns.append(Column('owner', String(256), primary_key=True))
-        columns.append(Column('name', String(256), primary_key=True))
-        columns.append(Column('member', String(256), primary_key=True))
-        self.groups = Table('groups', metadata, *columns, mysql_engine='InnoDB')
-        
-        # place an index on member
-        Index('idx_groups_member', self.groups.c.member)
+        metadata = MetaData(self.engine)
+        try:
+            self.groups = Table('groups', metadata, autoload=True)
+        except NoSuchTableError:
+            for t in create_tables(self.engine):
+                setattr(self, t.name, t)
         
-        metadata.create_all(self.engine)
-    
     def group_names(self, owner):
         """List all group names belonging to owner."""
         
index 8ef828c..a72f9e5 100644 (file)
@@ -38,6 +38,7 @@ from sqlalchemy.schema import Index, Sequence
 from sqlalchemy.sql import func, and_, or_, not_, null, select, bindparam, text, exists
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.engine.reflection import Inspector
+from sqlalchemy.exc import NoSuchTableError
 
 from dbworker import DBWorker
 
@@ -101,6 +102,89 @@ _propnames = {
     'cluster'   : 10
 }
 
+
+def create_tables(engine):
+    """Create the nodes, policy, statistics, versions and attributes tables.
+
+    Tables that already exist on `engine` are left untouched.  Returns the
+    defined tables in dependency order (``metadata.sorted_tables``).
+    """
+    metadata = MetaData()
+
+    # create nodes table
+    columns = []
+    columns.append(Column('node', Integer, primary_key=True))
+    columns.append(Column('parent', Integer,
+                          ForeignKey('nodes.node',
+                                     ondelete='CASCADE',
+                                     onupdate='CASCADE'),
+                          autoincrement=False))
+    columns.append(Column('latest_version', Integer))
+    columns.append(Column('path', String(2048), default='', nullable=False))
+    nodes = Table('nodes', metadata, *columns, mysql_engine='InnoDB')
+    Index('idx_nodes_path', nodes.c.path, unique=True)
+    Index('idx_nodes_parent', nodes.c.parent)
+
+    # create policy table
+    columns = []
+    columns.append(Column('node', Integer,
+                          ForeignKey('nodes.node',
+                                     ondelete='CASCADE',
+                                     onupdate='CASCADE'),
+                          primary_key=True))
+    columns.append(Column('key', String(128), primary_key=True))
+    columns.append(Column('value', String(256)))
+    policy = Table('policy', metadata, *columns, mysql_engine='InnoDB')
+
+    # create statistics table
+    columns = []
+    columns.append(Column('node', Integer,
+                          ForeignKey('nodes.node',
+                                     ondelete='CASCADE',
+                                     onupdate='CASCADE'),
+                          primary_key=True))
+    columns.append(Column('population', Integer, nullable=False, default=0))
+    columns.append(Column('size', BigInteger, nullable=False, default=0))
+    columns.append(Column('mtime', DECIMAL(precision=16, scale=6)))
+    columns.append(Column('cluster', Integer, nullable=False, default=0,
+                          primary_key=True, autoincrement=False))
+    statistics = Table('statistics', metadata, *columns, mysql_engine='InnoDB')
+
+    # create versions table
+    columns = []
+    columns.append(Column('serial', Integer, primary_key=True))
+    columns.append(Column('node', Integer,
+                          ForeignKey('nodes.node',
+                                     ondelete='CASCADE',
+                                     onupdate='CASCADE')))
+    columns.append(Column('hash', String(256)))
+    columns.append(Column('size', BigInteger, nullable=False, default=0))
+    columns.append(Column('type', String(256), nullable=False, default=''))
+    columns.append(Column('source', Integer))
+    columns.append(Column('mtime', DECIMAL(precision=16, scale=6)))
+    columns.append(Column('muser', String(256), nullable=False, default=''))
+    columns.append(Column('uuid', String(64), nullable=False, default=''))
+    columns.append(Column('checksum', String(256), nullable=False, default=''))
+    columns.append(Column('cluster', Integer, nullable=False, default=0))
+    versions = Table('versions', metadata, *columns, mysql_engine='InnoDB')
+    Index('idx_versions_node_mtime', versions.c.node, versions.c.mtime)
+    Index('idx_versions_node_uuid', versions.c.uuid)
+
+    # create attributes table
+    columns = []
+    columns.append(Column('serial', Integer,
+                          ForeignKey('versions.serial',
+                                     ondelete='CASCADE',
+                                     onupdate='CASCADE'),
+                          primary_key=True))
+    columns.append(Column('domain', String(256), primary_key=True))
+    columns.append(Column('key', String(128), primary_key=True))
+    columns.append(Column('value', String(256)))
+    attributes = Table('attributes', metadata, *columns, mysql_engine='InnoDB')
+
+    metadata.create_all(engine)
+    return metadata.sorted_tables
+
 
 class Node(DBWorker):
     """Nodes store path organization and have multiple versions.
@@ -112,80 +196,20 @@ class Node(DBWorker):
     
     def __init__(self, **params):
         DBWorker.__init__(self, **params)
-        metadata = MetaData()
-        
-        #create nodes table
-        columns=[]
-        columns.append(Column('node', Integer, primary_key=True))
-        columns.append(Column('parent', Integer,
-                              ForeignKey('nodes.node',
-                                         ondelete='CASCADE',
-                                         onupdate='CASCADE'),
-                              autoincrement=False))
-        columns.append(Column('latest_version', Integer))
-        columns.append(Column('path', String(2048), default='', nullable=False))
-        self.nodes = Table('nodes', metadata, *columns, mysql_engine='InnoDB')
-        Index('idx_nodes_path', self.nodes.c.path, unique=True)
-        Index('idx_nodes_parent', self.nodes.c.parent)
-        
-        #create policy table
-        columns=[]
-        columns.append(Column('node', Integer,
-                              ForeignKey('nodes.node',
-                                         ondelete='CASCADE',
-                                         onupdate='CASCADE'),
-                              primary_key=True))
-        columns.append(Column('key', String(128), primary_key=True))
-        columns.append(Column('value', String(256)))
-        self.policies = Table('policy', metadata, *columns, mysql_engine='InnoDB')
-        
-        #create statistics table
-        columns=[]
-        columns.append(Column('node', Integer,
-                              ForeignKey('nodes.node',
-                                         ondelete='CASCADE',
-                                         onupdate='CASCADE'),
-                              primary_key=True))
-        columns.append(Column('population', Integer, nullable=False, default=0))
-        columns.append(Column('size', BigInteger, nullable=False, default=0))
-        columns.append(Column('mtime', DECIMAL(precision=16, scale=6)))
-        columns.append(Column('cluster', Integer, nullable=False, default=0,
-                              primary_key=True, autoincrement=False))
-        self.statistics = Table('statistics', metadata, *columns, mysql_engine='InnoDB')
-        
-        #create versions table
-        columns=[]
-        columns.append(Column('serial', Integer, primary_key=True))
-        columns.append(Column('node', Integer,
-                              ForeignKey('nodes.node',
-                                         ondelete='CASCADE',
-                                         onupdate='CASCADE')))
-        columns.append(Column('hash', String(256)))
-        columns.append(Column('size', BigInteger, nullable=False, default=0))
-        columns.append(Column('type', String(256), nullable=False, default=''))
-        columns.append(Column('source', Integer))
-        columns.append(Column('mtime', DECIMAL(precision=16, scale=6)))
-        columns.append(Column('muser', String(256), nullable=False, default=''))
-        columns.append(Column('uuid', String(64), nullable=False, default=''))
-        columns.append(Column('checksum', String(256), nullable=False, default=''))
-        columns.append(Column('cluster', Integer, nullable=False, default=0))
-        self.versions = Table('versions', metadata, *columns, mysql_engine='InnoDB')
-        Index('idx_versions_node_mtime', self.versions.c.node, self.versions.c.mtime)
-        Index('idx_versions_node_uuid', self.versions.c.uuid)
-        
-        #create attributes table
-        columns = []
-        columns.append(Column('serial', Integer,
-                              ForeignKey('versions.serial',
-                                         ondelete='CASCADE',
-                                         onupdate='CASCADE'),
-                              primary_key=True))
-        columns.append(Column('domain', String(256), primary_key=True))
-        columns.append(Column('key', String(128), primary_key=True))
-        columns.append(Column('value', String(256)))
-        self.attributes = Table('attributes', metadata, *columns, mysql_engine='InnoDB')
-        
-        metadata.create_all(self.engine)
+        # NOTE: reflected (autoload=True) tables do not carry the Python-side
+        # `default=` values declared in create_tables(); inserts should supply
+        # those NOT NULL columns explicitly (e.g. path='' for the root below).
+        metadata = MetaData(self.engine)
+        try:
+            self.nodes = Table('nodes', metadata, autoload=True)
+            self.policy = Table('policy', metadata, autoload=True)
+            self.statistics = Table('statistics', metadata, autoload=True)
+            self.versions = Table('versions', metadata, autoload=True)
+            self.attributes = Table('attributes', metadata, autoload=True)
+        except NoSuchTableError:
+            # create_all() skips existing tables, so only missing ones are made
+            for t in create_tables(self.engine):
+                setattr(self, t.name, t)
         
         s = self.nodes.select().where(and_(self.nodes.c.node == ROOTNODE,
                                            self.nodes.c.parent == ROOTNODE))
@@ -193,7 +217,7 @@ class Node(DBWorker):
         r = rp.fetchone()
         rp.close()
         if not r:
-            s = self.nodes.insert().values(node=ROOTNODE, parent=ROOTNODE)
+            s = self.nodes.insert().values(node=ROOTNODE, parent=ROOTNODE, path='')
             self.conn.execute(s)
     
     def node_create(self, parent, path):
@@ -407,8 +431,8 @@ class Node(DBWorker):
         return True
     
     def policy_get(self, node):
-        s = select([self.policies.c.key, self.policies.c.value],
-            self.policies.c.node==node)
+        s = select([self.policy.c.key, self.policy.c.value],
+            self.policy.c.node==node)
         r = self.conn.execute(s)
         d = dict(r.fetchall())
         r.close()
@@ -417,13 +441,13 @@ class Node(DBWorker):
     def policy_set(self, node, policy):
         #insert or replace
         for k, v in policy.iteritems():
-            s = self.policies.update().where(and_(self.policies.c.node == node,
-                                                  self.policies.c.key == k))
+            s = self.policy.update().where(and_(self.policy.c.node == node,
+                                                self.policy.c.key == k))
             s = s.values(value = v)
             rp = self.conn.execute(s)
             rp.close()
             if rp.rowcount == 0:
-                s = self.policies.insert()
+                s = self.policy.insert()
                 values = {'node':node, 'key':k, 'value':v}
                 r = self.conn.execute(s, values)
                 r.close()
index bb06282..a132980 100644 (file)
@@ -35,22 +35,41 @@ from dbworker import DBWorker
 from sqlalchemy import Table, Column, String, Integer, Boolean, MetaData
 from sqlalchemy.sql import and_, select
 from sqlalchemy.schema import Index
+from sqlalchemy.exc import NoSuchTableError
 
+
+def create_tables(engine):
+    """Create the `public` table on `engine` if it does not exist.
+
+    Returns the defined tables (``metadata.sorted_tables``).
+    """
+    metadata = MetaData()
+    columns = []
+    columns.append(Column('public_id', Integer, primary_key=True))
+    columns.append(Column('path', String(2048), nullable=False))
+    columns.append(Column('active', Boolean, nullable=False, default=True))
+    public = Table('public', metadata, *columns, mysql_engine='InnoDB',
+                   sqlite_autoincrement=True)
+    # place an index on path
+    Index('idx_public_path', public.c.path, unique=True)
+    metadata.create_all(engine)
+    return metadata.sorted_tables
+
 
 class Public(DBWorker):
     """Paths can be marked as public."""
     
     def __init__(self, **params):
         DBWorker.__init__(self, **params)
-        metadata = MetaData()
-        columns=[]
-        columns.append(Column('public_id', Integer, primary_key=True))
-        columns.append(Column('path', String(2048), nullable=False))
-        columns.append(Column('active', Boolean, nullable=False, default=True))
-        self.public = Table('public', metadata, *columns, mysql_engine='InnoDB', sqlite_autoincrement=True)
-        # place an index on path
-        Index('idx_public_path', self.public.c.path, unique=True)
-        metadata.create_all(self.engine)
+        # NOTE: a reflected (autoload=True) table does not carry the
+        # Python-side `default=True` declared for `active` in
+        # create_tables(); inserts should set `active` explicitly.
+        metadata = MetaData(self.engine)
+        try:
+            self.public = Table('public', metadata, autoload=True)
+        except NoSuchTableError:
+            for t in create_tables(self.engine):
+                setattr(self, t.name, t)
     
     def public_set(self, path):
         s = select([self.public.c.public_id])
index d4c45d8..9604416 100644 (file)
@@ -36,9 +36,39 @@ from sqlalchemy import Table, Column, String, Integer, MetaData, ForeignKey
 from sqlalchemy.sql import select, and_
 from sqlalchemy.schema import Index
 from sqlalchemy.sql.expression import desc
+from sqlalchemy.exc import NoSuchTableError
 
 from dbworker import DBWorker
 
+
+def create_tables(engine):
+    """Create the `xfeatures` and `xfeaturevals` tables if they do not exist.
+
+    Returns the defined tables in dependency order
+    (``metadata.sorted_tables``).
+    """
+    metadata = MetaData()
+    columns = []
+    columns.append(Column('feature_id', Integer, primary_key=True))
+    columns.append(Column('path', String(2048)))
+    xfeatures = Table('xfeatures', metadata, *columns, mysql_engine='InnoDB')
+    # place an index on path
+    Index('idx_features_path', xfeatures.c.path, unique=True)
+
+    columns = []
+    columns.append(Column('feature_id', Integer,
+                          ForeignKey('xfeatures.feature_id',
+                                     ondelete='CASCADE'),
+                          primary_key=True))
+    columns.append(Column('key', Integer, primary_key=True,
+                          autoincrement=False))
+    columns.append(Column('value', String(256), primary_key=True))
+    xfeaturevals = Table('xfeaturevals', metadata, *columns,
+                         mysql_engine='InnoDB')
+
+    metadata.create_all(engine)
+    return metadata.sorted_tables
+
 
 class XFeatures(DBWorker):
     """XFeatures are path properties that allow non-nested
@@ -47,25 +77,14 @@ class XFeatures(DBWorker):
     
     def __init__(self, **params):
         DBWorker.__init__(self, **params)
-        metadata = MetaData()
-        columns=[]
-        columns.append(Column('feature_id', Integer, primary_key=True))
-        columns.append(Column('path', String(2048)))
-        self.xfeatures = Table('xfeatures', metadata, *columns, mysql_engine='InnoDB')
-        # place an index on path
-        Index('idx_features_path', self.xfeatures.c.path, unique=True)
-        
-        columns=[]
-        columns.append(Column('feature_id', Integer,
-                              ForeignKey('xfeatures.feature_id',
-                                         ondelete='CASCADE'),
-                              primary_key=True))
-        columns.append(Column('key', Integer, primary_key=True,
-                              autoincrement=False))
-        columns.append(Column('value', String(256), primary_key=True))
-        self.xfeaturevals = Table('xfeaturevals', metadata, *columns, mysql_engine='InnoDB')
-        
-        metadata.create_all(self.engine)
+        metadata = MetaData(self.engine)
+        try:
+            self.xfeatures = Table('xfeatures', metadata, autoload=True)
+            self.xfeaturevals = Table('xfeaturevals', metadata, autoload=True)
+        except NoSuchTableError:
+            # create_all() skips existing tables, so only missing ones are made
+            for t in create_tables(self.engine):
+                setattr(self, t.name, t)
     
 #     def xfeature_inherit(self, path):
 #         """Return the (path, feature) inherited by the path, or None."""
index 392834f..29263c7 100644 (file)
@@ -46,19 +46,56 @@ e.g::
 import sys
 import os
 
-from alembic.config import main as alembic_main
+from alembic.config import main as alembic_main, Config
+from alembic import context, command
+
 from pithos.backends.lib import sqlalchemy as sqlalchemy_backend
+from pithos.backends.lib.sqlalchemy import node, groups, public, xfeatures
+
+import sqlalchemy as sa
 
 DEFAULT_ALEMBIC_INI_PATH = os.path.join(
         os.path.abspath(os.path.dirname(sqlalchemy_backend.__file__)),
         'alembic.ini')
 
+
+def initialize_db():
+    """Create the backend tables and stamp the Alembic version table.
+
+    The engine is built from the `sqlalchemy.*` options of the Alembic
+    configuration.  Tables that already exist are left untouched by
+    create_all(); the database is then stamped with the "head" revision,
+    without running any migration.
+    """
+    alembic_cfg = Config(DEFAULT_ALEMBIC_INI_PATH)
+    engine = sa.engine_from_config(
+        alembic_cfg.get_section(alembic_cfg.config_ini_section),
+        prefix='sqlalchemy.')
+    try:
+        node.create_tables(engine)
+        groups.create_tables(engine)
+        public.create_tables(engine)
+        xfeatures.create_tables(engine)
+    finally:
+        # release the pooled connections used for table creation
+        engine.dispose()
+
+    # generate the Alembic version table and "stamp" it with the most
+    # recent revision
+    command.stamp(alembic_cfg, "head")
+
+
 def main(argv=None, **kwargs):
     if not argv:
         argv = sys.argv
 
     # clean up args
     argv.pop(0)
+
+    # argv may be empty once the program name is dropped
+    if argv and argv[0] == 'initdb':
+        initialize_db()
+        return
 
     # default config arg, if not already set
     if not '-c' in argv: