Escape special characters for LIKE in node lookup.
[pithos] / pithos / backends / lib / sqlalchemy / node.py
index ba42f1e..4a0fc55 100644 (file)
@@ -37,7 +37,6 @@ from sqlalchemy.types import Text
 from sqlalchemy.schema import Index, Sequence
 from sqlalchemy.sql import func, and_, or_, null, select, bindparam, text
 from sqlalchemy.ext.compiler import compiles
-#from sqlalchemy.dialects.mysql import VARBINARY
 from sqlalchemy.engine.reflection import Inspector
 
 from dbworker import DBWorker
@@ -80,8 +79,7 @@ def strprevling(prefix):
     s = prefix[:-1]
     c = ord(prefix[-1])
     if c > 0:
-        #s += unichr(c-1) + unichr(0xffff)
-        s += unichr(c-1)
+        s += unichr(c-1) + unichr(0xffff)
     return s
 
 
@@ -118,11 +116,8 @@ class Node(DBWorker):
                                          onupdate='CASCADE'),
                               autoincrement=False))
         path_length = 2048
-        path_length_in_bytes = path_length * 4
-        columns.append(Column('path', Text(path_length_in_bytes), default='', nullable=False))
+        columns.append(Column('path', String(path_length), default='', nullable=False))
         self.nodes = Table('nodes', metadata, *columns, mysql_engine='InnoDB')
-        # place an index on path
-        #Index('idx_nodes_path', self.nodes.c.path)
         
         #create policy table
         columns=[]
@@ -185,7 +180,8 @@ class Node(DBWorker):
         insp = Inspector.from_engine(self.engine)
         indexes = [elem['name'] for elem in insp.get_indexes('nodes')]
         if 'idx_nodes_path' not in indexes:
-            s = text('CREATE INDEX idx_nodes_path ON nodes (path(%s))' %path_length_in_bytes)
+            explicit_length = '(%s)' %path_length if self.engine.name == 'mysql' else ''
+            s = text('CREATE INDEX idx_nodes_path ON nodes (path%s)' %explicit_length)
             self.conn.execute(s).close()
         
         s = self.nodes.select().where(and_(self.nodes.c.node == ROOTNODE,
@@ -213,6 +209,9 @@ class Node(DBWorker):
            Return None if the path is not found.
         """
         
+        # Use LIKE for comparison to avoid MySQL problems with trailing spaces.
+        path = path.replace('%', '\%')
+        path = path.replace('_', '\_')
         s = select([self.nodes.c.node], self.nodes.c.path.like(path))
         r = self.conn.execute(s)
         row = r.fetchone()
@@ -273,7 +272,7 @@ class Node(DBWorker):
     def node_purge_children(self, parent, before=inf, cluster=0):
         """Delete all versions with the specified
            parent and cluster, and return
-           the serials of versions deleted.
+           the hashes of versions deleted.
            Clears out nodes with no remaining versions.
         """
         #update statistics
@@ -296,10 +295,10 @@ class Node(DBWorker):
         self.statistics_update(parent, -nr, size, mtime, cluster)
         self.statistics_update_ancestors(parent, -nr, size, mtime, cluster)
         
-        s = select([self.versions.c.serial])
+        s = select([self.versions.c.hash])
         s = s.where(where_clause)
         r = self.conn.execute(s)
-        serials = [row[SERIAL] for row in r.fetchall()]
+        hashes = [row[0] for row in r.fetchall()]
         r.close()
         
         #delete versions
@@ -318,12 +317,12 @@ class Node(DBWorker):
         s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
         self.conn.execute(s).close()
         
-        return serials
+        return hashes
     
     def node_purge(self, node, before=inf, cluster=0):
         """Delete all versions with the specified
            node and cluster, and return
-           the serials of versions deleted.
+           the hashes of versions deleted.
            Clears out the node if it has no remaining versions.
         """
         
@@ -344,10 +343,10 @@ class Node(DBWorker):
         mtime = time()
         self.statistics_update_ancestors(node, -nr, -size, mtime, cluster)
         
-        s = select([self.versions.c.serial])
+        s = select([self.versions.c.hash])
         s = s.where(where_clause)
         r = self.conn.execute(s)
-        serials = [r[SERIAL] for r in r.fetchall()]
+        hashes = [r[0] for r in r.fetchall()]
         r.close()
         
         #delete versions
@@ -366,7 +365,7 @@ class Node(DBWorker):
         s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
         self.conn.execute(s).close()
         
-        return serials
+        return hashes
     
     def node_remove(self, node):
         """Remove the node specified.
@@ -434,7 +433,6 @@ class Node(DBWorker):
            size of objects and mtime in the node's namespace.
            May be zero or positive or negative numbers.
         """
-        
         s = select([self.statistics.c.population, self.statistics.c.size],
             and_(self.statistics.c.node == node,
                  self.statistics.c.cluster == cluster))
@@ -635,10 +633,11 @@ class Node(DBWorker):
     def version_remove(self, serial):
         """Remove the serial specified."""
         
-        props = self.node_get_properties(serial)
+        props = self.version_get_properties(serial)
         if not props:
             return
         node = props[NODE]
+        hash = props[HASH]
         size = props[SIZE]
         cluster = props[CLUSTER]
         
@@ -647,7 +646,7 @@ class Node(DBWorker):
         
         s = self.versions.delete().where(self.versions.c.serial == serial)
         self.conn.execute(s).close()
-        return True
+        return hash
     
     def attribute_get(self, serial, keys=()):
         """Return a list of (key, value) pairs of the version specified by serial.