Escape catch-all characters in LIKE queries.
authorAntony Chazapis <chazapis@gmail.com>
Wed, 14 Dec 2011 12:04:54 +0000 (14:04 +0200)
committerAntony Chazapis <chazapis@gmail.com>
Wed, 14 Dec 2011 12:04:54 +0000 (14:04 +0200)
Refs #1768

pithos/backends/lib/sqlalchemy/dbworker.py
pithos/backends/lib/sqlalchemy/node.py
pithos/backends/lib/sqlalchemy/permissions.py
pithos/backends/lib/sqlalchemy/xfeatures.py
pithos/backends/lib/sqlite/dbworker.py
pithos/backends/lib/sqlite/dbwrapper.py
pithos/backends/lib/sqlite/node.py
pithos/backends/lib/sqlite/permissions.py
pithos/backends/lib/sqlite/xfeatures.py

index beb2a8a..2537e57 100644 (file)
@@ -39,3 +39,6 @@ class DBWorker(object):
         self.params = params
         self.conn = params['wrapper'].conn
         self.engine = params['wrapper'].engine
+    
+    def escape_like(self, s):
+        return s.replace('\\', '\\\\').replace('%', '\%').replace('_', '\_')
index 8aa066a..ddf9308 100644 (file)
@@ -210,9 +210,7 @@ class Node(DBWorker):
         """
         
         # Use LIKE for comparison to avoid MySQL problems with trailing spaces.
-        path = path.replace('%', '\%')
-        path = path.replace('_', '\_')
-        s = select([self.nodes.c.node], self.nodes.c.path.like(path, escape='\\'))
+        s = select([self.nodes.c.node], self.nodes.c.path.like(self.escape_like(path), escape='\\'))
         r = self.conn.execute(s)
         row = r.fetchone()
         r.close()
@@ -541,7 +539,7 @@ class Node(DBWorker):
             self.versions.c.node == v.c.node)
         if before != inf:
             c1 = c1.where(self.versions.c.mtime < before)
-        c2 = select([self.nodes.c.node], self.nodes.c.path.like(path + '%'))
+        c2 = select([self.nodes.c.node], self.nodes.c.path.like(self.escape_like(path) + '%', escape='\\'))
         s = s.where(and_(v.c.serial == c1,
                          v.c.cluster != except_cluster,
                          v.c.node.in_(c2)))
@@ -744,7 +742,7 @@ class Node(DBWorker):
         s = s.where(n.c.node == v.c.node)
         conj = []
         for x in pathq:
-            conj.append(n.c.path.like(x + '%'))
+            conj.append(n.c.path.like(self.escape_like(x) + '%', escape='\\'))
         if conj:
             s = s.where(or_(*conj))
         rp = self.conn.execute(s)
@@ -823,7 +821,7 @@ class Node(DBWorker):
         s = s.where(and_(n.c.path > bindparam('start'), n.c.path < nextling))
         conj = []
         for x in pathq:
-            conj.append(n.c.path.like(x + '%'))
+            conj.append(n.c.path.like(self.escape_like(x) + '%', escape='\\'))
         
         if conj:
             s = s.where(or_(*conj))
index e346fbf..0497353 100644 (file)
@@ -132,7 +132,7 @@ class Permissions(XFeatures, Groups, Public):
                     self.xfeaturevals.c.value == u.c.value)
         s = select([self.xfeatures.c.path], from_obj=[inner_join]).distinct()
         if prefix:
-            s = s.where(self.xfeatures.c.path.like(prefix + '%'))
+            s = s.where(self.xfeatures.c.path.like(self.escape_like(prefix) + '%', escape='\\'))
         r = self.conn.execute(s)
         l = [row[0] for row in r.fetchall()]
         r.close()
@@ -142,7 +142,7 @@ class Permissions(XFeatures, Groups, Public):
         """Return the list of shared paths."""
         
         s = select([self.xfeatures.c.path],
-            self.xfeatures.c.path.like(prefix + '%')).order_by(self.xfeatures.c.path.asc())
+            self.xfeatures.c.path.like(self.escape_like(prefix) + '%', escape='\\')).order_by(self.xfeatures.c.path.asc())
         r = self.conn.execute(s)
         l = [row[0] for row in r.fetchall()]
         r.close()
index 7cdcf1e..7bfacb9 100644 (file)
@@ -92,7 +92,7 @@ class XFeatures(DBWorker):
             return [inherited]
         
         s = select([self.xfeatures.c.path, self.xfeatures.c.feature_id])
-        s = s.where(and_(self.xfeatures.c.path.like(path + '%'),
+        s = s.where(and_(self.xfeatures.c.path.like(self.escape_like(path) + '%', escape='\\'),
                      self.xfeatures.c.path != path))
         s = s.order_by(self.xfeatures.c.path)
         r = self.conn.execute(s)
index cce4c60..3c7efb2 100644 (file)
@@ -45,3 +45,6 @@ class DBWorker(object):
         self.fetchall = cur.fetchall
         self.cur = cur
         self.conn = conn
+    
+    def escape_like(self, s):
+        return s.replace('\\', '\\\\').replace('%', '\%').replace('_', '\_')
index 3b5ed4c..17e089d 100644 (file)
@@ -38,6 +38,7 @@ class DBWrapper(object):
     
     def __init__(self, db):
         self.conn = sqlite3.connect(db, check_same_thread=False)
+        self.conn.execute(""" pragma case_sensitive_like = on """)
     
     def close(self):
         self.conn.close()
index a367f75..e0485b5 100644 (file)
@@ -447,8 +447,8 @@ class Node(DBWorker):
              "and cluster != ? "
              "and node in (select node "
                           "from nodes "
-                          "where path like ?)")
-        execute(q, (before, except_cluster, path + '%'))
+                          "where path like ? escape '\\')")
+        execute(q, (before, except_cluster, self.escape_like(path) + '%'))
         r = fetchone()
         if r is None:
             return None
@@ -605,9 +605,9 @@ class Node(DBWorker):
             return None, None
         
         subq = " and ("
-        subq += ' or '.join(('n.path like ?' for x in pathq))
+        subq += ' or '.join(("n.path like ? escape '\\'" for x in pathq))
         subq += ")"
-        args = tuple([x + '%' for x in pathq])
+        args = tuple([self.escape_like(x) + '%' for x in pathq])
         
         return subq, args
     
index a6b2314..966612d 100644 (file)
@@ -123,14 +123,14 @@ class Permissions(XFeatures, Groups, Public):
              "using (feature_id)")
         p = (member, member)
         if prefix:
-            q += " where path like ?"
-            p += (prefix + '%',)
+            q += " where path like ? escape '\\'"
+            p += (self.escape_like(prefix) + '%',)
         self.execute(q, p)
         return [r[0] for r in self.fetchall()]
     
     def access_list_shared(self, prefix=''):
         """Return the list of shared paths."""
         
-        q = "select path from xfeatures where path like ?"
-        self.execute(q, (prefix + '%',))
+        q = "select path from xfeatures where path like ? escape '\\'"
+        self.execute(q, (self.escape_like(prefix) + '%',))
         return [r[0] for r in self.fetchall()]
index d1e5df6..0f013da 100644 (file)
@@ -84,8 +84,8 @@ class XFeatures(DBWorker):
             return [inherited]
         
         q = ("select path, feature_id from xfeatures "
-             "where path like ? and path != ? order by path")
-        self.execute(q, (path + '%', path,))
+             "where path like ? escape '\\' and path != ? order by path")
+        self.execute(q, (self.escape_like(path) + '%', path,))
         return self.fetchall()
     
     def xfeature_create(self, path):