Escape special characters for LIKE in node lookup.
[pithos] / pithos / backends / lib / sqlalchemy / node.py
index 76341c8..4a0fc55 100644 (file)
@@ -79,8 +79,7 @@ def strprevling(prefix):
     s = prefix[:-1]
     c = ord(prefix[-1])
     if c > 0:
-        #s += unichr(c-1) + unichr(0xffff)
-        s += unichr(c-1)
+        s += unichr(c-1) + unichr(0xffff)
     return s
 
 
@@ -210,6 +209,9 @@ class Node(DBWorker):
            Return None if the path is not found.
         """
         
+        # Use LIKE for comparison to avoid MySQL problems with trailing spaces.
+        path = path.replace('%', '\%')
+        path = path.replace('_', '\_')
         s = select([self.nodes.c.node], self.nodes.c.path.like(path))
         r = self.conn.execute(s)
         row = r.fetchone()
@@ -270,7 +272,7 @@ class Node(DBWorker):
     def node_purge_children(self, parent, before=inf, cluster=0):
         """Delete all versions with the specified
            parent and cluster, and return
-           the serials of versions deleted.
+           the hashes of versions deleted.
            Clears out nodes with no remaining versions.
         """
         #update statistics
@@ -293,10 +295,10 @@ class Node(DBWorker):
         self.statistics_update(parent, -nr, size, mtime, cluster)
         self.statistics_update_ancestors(parent, -nr, size, mtime, cluster)
         
-        s = select([self.versions.c.serial])
+        s = select([self.versions.c.hash])
         s = s.where(where_clause)
         r = self.conn.execute(s)
-        serials = [row[SERIAL] for row in r.fetchall()]
+        hashes = [row[0] for row in r.fetchall()]
         r.close()
         
         #delete versions
@@ -315,12 +317,12 @@ class Node(DBWorker):
         s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
         self.conn.execute(s).close()
         
-        return serials
+        return hashes
     
     def node_purge(self, node, before=inf, cluster=0):
         """Delete all versions with the specified
            node and cluster, and return
-           the serials of versions deleted.
+           the hashes of versions deleted.
            Clears out the node if it has no remaining versions.
         """
         
@@ -341,10 +343,10 @@ class Node(DBWorker):
         mtime = time()
         self.statistics_update_ancestors(node, -nr, -size, mtime, cluster)
         
-        s = select([self.versions.c.serial])
+        s = select([self.versions.c.hash])
         s = s.where(where_clause)
         r = self.conn.execute(s)
-        serials = [r[SERIAL] for r in r.fetchall()]
+        hashes = [r[0] for r in r.fetchall()]
         r.close()
         
         #delete versions
@@ -363,7 +365,7 @@ class Node(DBWorker):
         s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
         self.conn.execute(s).close()
         
-        return serials
+        return hashes
     
     def node_remove(self, node):
         """Remove the node specified.
@@ -631,10 +633,11 @@ class Node(DBWorker):
     def version_remove(self, serial):
         """Remove the serial specified."""
         
-        props = self.node_get_properties(serial)
+        props = self.version_get_properties(serial)
         if not props:
             return
         node = props[NODE]
+        hash = props[HASH]
         size = props[SIZE]
         cluster = props[CLUSTER]
         
@@ -643,7 +646,7 @@ class Node(DBWorker):
         
         s = self.versions.delete().where(self.versions.c.serial == serial)
         self.conn.execute(s).close()
-        return True
+        return hash
     
     def attribute_get(self, serial, keys=()):
         """Return a list of (key, value) pairs of the version specified by serial.