use alembic to initialize the backend database
[pithos] / snf-pithos-backend / pithos / backends / lib / sqlalchemy / public.py
index 25bf0f3..a132980 100644 (file)
@@ -35,22 +35,31 @@ from dbworker import DBWorker
 from sqlalchemy import Table, Column, String, Integer, Boolean, MetaData
 from sqlalchemy.sql import and_, select
 from sqlalchemy.schema import Index
+from sqlalchemy.exc import NoSuchTableError
 
+def create_tables(engine):
+    metadata = MetaData()
+    columns=[]
+    columns.append(Column('public_id', Integer, primary_key=True))
+    columns.append(Column('path', String(2048), nullable=False))
+    columns.append(Column('active', Boolean, nullable=False, default=True))
+    public = Table('public', metadata, *columns, mysql_engine='InnoDB', sqlite_autoincrement=True)
+    # place an index on path
+    Index('idx_public_path', public.c.path, unique=True)
+    metadata.create_all(engine)
+    return metadata.sorted_tables
 
 class Public(DBWorker):
     """Paths can be marked as public."""
     
     def __init__(self, **params):
         DBWorker.__init__(self, **params)
-        metadata = MetaData()
-        columns=[]
-        columns.append(Column('public_id', Integer, primary_key=True))
-        columns.append(Column('path', String(2048), nullable=False))
-        columns.append(Column('active', Boolean, nullable=False, default=True))
-        self.public = Table('public', metadata, *columns, mysql_engine='InnoDB', sqlite_autoincrement=True)
-        # place an index on path
-        Index('idx_public_path', self.public.c.path, unique=True)
-        metadata.create_all(self.engine)
+        try:
+            metadata = MetaData(self.engine)
+            self.public = Table('public', metadata, autoload=True)
+        except NoSuchTableError:
+            tables = create_tables(self.engine)
+            map(lambda t: self.__setattr__(t.name, t), tables)
     
     def public_set(self, path):
         s = select([self.public.c.public_id])
@@ -74,6 +83,13 @@ class Public(DBWorker):
         r = self.conn.execute(s)
         r.close()
     
+    def public_unset_bulk(self, paths):
+        s = self.public.update()
+        s = s.where(self.public.c.path.in_(paths))
+        s = s.values(active=False)
+        r = self.conn.execute(s)
+        r.close()
+    
     def public_get(self, path):
         s = select([self.public.c.public_id])
         s = s.where(and_(self.public.c.path == path,