From b43d44ad4a79495f084c466712cab0154299e05e Mon Sep 17 00:00:00 2001
From: Sofia Papagiannaki
Date: Wed, 31 Aug 2011 12:13:08 +0300
Subject: [PATCH] backend components in SQLAlchemy: Progress III

---
 pithos/backends/lib_alchemy/node.py | 448 +++++++++++++++++++----------------
 1 file changed, 241 insertions(+), 207 deletions(-)

diff --git a/pithos/backends/lib_alchemy/node.py b/pithos/backends/lib_alchemy/node.py
index 359d237..6a3d949 100644
--- a/pithos/backends/lib_alchemy/node.py
+++ b/pithos/backends/lib_alchemy/node.py
@@ -34,12 +34,10 @@
 from time import time
 from sqlalchemy import Table, Integer, Column, String, MetaData, ForeignKey
 from sqlalchemy.schema import Index, Sequence
-from sqlalchemy.sql import func, and_, null
-from sqlalchemy.sql import select
+from sqlalchemy.sql import func, and_, or_, null, select, bindparam
 
 from dbworker import DBWorker
 
-
 ROOTNODE = 1
 
 ( SERIAL, NODE, SIZE, SOURCE, MTIME, MUSER, CLUSTER ) = range(7)
@@ -241,12 +239,13 @@
         the serials of versions deleted.
         Clears out nodes with no remaining versions.
 
         """
-
-        scalar = select([self.nodes.c.node],
-            self.nodes.c.parent == parent).as_scalar()
-        where_clause = and_(self.versions.c.node.in_(scalar),
+        #update statistics
+        #TODO handle before=inf
+        c1 = select([self.nodes.c.node],
+            self.nodes.c.parent == parent)
+        where_clause = and_(self.versions.c.node.in_(c1),
             self.versions.c.cluster == cluster,
-            "versions.mtime <= %f" %before)
+            self.versions.c.mtime <= before)
         s = select([func.count(self.versions.c.serial),
                     func.sum(self.versions.c.size)])
         s = s.where(where_clause)
@@ -257,7 +256,6 @@
             return ()
         nr, size = row[0], -row[1] if row[1] else 0
         mtime = time()
-        print '#', parent, -nr, size, mtime, cluster
         self.statistics_update(parent, -nr, size, mtime, cluster)
         self.statistics_update_ancestors(parent, -nr, size, mtime, cluster)
@@ -267,21 +265,22 @@
         serials = [row[SERIAL] for row in r.fetchall()]
         r.close()
 
-        #delete versiosn
+        #delete versions
         s = self.versions.delete().where(where_clause)
         r = self.conn.execute(s)
         r.close()
 
         #delete nodes
-        a = self.nodes.alias()
-        no_versions = select([func.count(self.versions.c.serial)],
-            self.versions.c.node == a.c.node).as_scalar() == 0
-        n = select([self.nodes.c.node],
-            and_(no_versions, self.nodes.c.parent == parent))
-        s = s.where(self.nodes.c.node.in_(n))
-        s = self.nodes.delete().where(self.nodes.c.node == s)
-        print '#', s
+        s = select([self.nodes.c.node],
+            and_(self.nodes.c.parent == parent,
+                 select([func.count(self.versions.c.serial)],
+                        self.versions.c.node == self.nodes.c.node).as_scalar() == 0))
+        r = self.conn.execute(s)
+        nodes = [row[0] for row in r.fetchall()]
+        r.close()
+        s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
         self.conn.execute(s).close()
+
         return serials
 
     def node_purge(self, node, before=inf, cluster=0):
@@ -291,37 +290,43 @@
         Clears out the node if it has no remaining versions.
 
         """
-        execute = self.execute
-        q = ("select count(serial), sum(size) from versions "
-             "where node = ? "
" - "and mtime <= ?") - args = (node, cluster, before) - execute(q, args) - nr, size = self.fetchone() + #update statistics + s = select([func.count(self.versions.c.serial), + func.sum(self.versions.c.size)]) + where_clause = and_(self.versions.c.node == node, + self.versions.c.cluster == cluster, + self.versions.c.mtime <= before) + s = s.where(where_clause) + r = self.conn.execute(s) + row = r.fetchone() + nr, size = row[0], row[1] + r.close() if not nr: return () mtime = time() self.statistics_update_ancestors(node, -nr, -size, mtime, cluster) - q = ("select serial from versions " - "where node = ? " - "and cluster = ? " - "and mtime <= ?") - execute(q, args) - serials = [r[SERIAL] for r in self.fetchall()] - q = ("delete from versions " - "where node = ? " - "and cluster = ? " - "and mtime <= ?") - execute(q, args) - q = ("delete from nodes " - "where node in (select node from nodes n " - "where (select count(serial) " - "from versions " - "where node = n.node) = 0 " - "and node = ?)") - execute(q, (node,)) + s = select([self.versions.c.serial]) + s = s.where(where_clause) + r = self.conn.execute(s) + serials = [r[SERIAL] for r in r.fetchall()] + + #delete versions + s = self.versions.delete().where(where_clause) + r = self.conn.execute(s) + r.close() + + #delete nodes + s = select([self.nodes.c.node], + and_(self.nodes.c.node == node, + select([func.count(self.versions.c.serial)], + self.versions.c.node == self.nodes.c.node).as_scalar() == 0)) + r = self.conn.execute(s) + nodes = r.fetchall() + r.close() + s = self.nodes.delete().where(self.nodes.c.node.in_(nodes)) + self.conn.execute(s).close() + return serials def node_remove(self, node): @@ -333,16 +338,18 @@ class Node(DBWorker): return False mtime = time() - q = ("select count(serial), sum(size), cluster " - "from versions " - "where node = ? " - "group by cluster") - self.execute(q, (node,)) - for population, size, cluster in self.fetchall(): + s = select([func.count(self.versions.c.serial), + func.sum(self.versions.c.size), + self.versions.c.cluster]) + s = s.where(self.versions.c.node == node) + s = s.group_by(self.versions.c.cluster) + r = self.conn.execute(s) + for population, size, cluster in r.fetchall(): self.statistics_update_ancestors(node, -population, -size, mtime, cluster) + r.close() - q = "delete from nodes where node = ?" - self.execute(q, (node,)) + s = self.nodes.delete().where(self.nodes.c.node == node) + self.conn.execute(s).close() return True def statistics_get(self, node, cluster=0): @@ -350,10 +357,15 @@ class Node(DBWorker): for all versions under node that belong to the cluster. """ - q = ("select population, size, mtime from statistics " - "where node = ? and cluster = ?") - self.execute(q, (node, cluster)) - return self.fetchone() + s = select([self.statistics.c.population, + self.statistics.c.size, + self.statistics.c.mtime]) + s = s.where(and_(self.statistics.c.node == node, + self.statistics.c.cluster == cluster)) + r = self.conn.execute(s) + row = r.fetchone() + r.close() + return row def statistics_update(self, node, population, size, mtime, cluster=0): """Update the statistics of the given node. 
@@ -365,9 +377,9 @@
         s = select([self.statistics.c.population,
                     self.statistics.c.size],
             and_(self.statistics.c.node == node,
                  self.statistics.c.cluster == cluster))
-        res = self.conn.execute(s)
-        r = res.fetchone()
-        res.close()
+        rp = self.conn.execute(s)
+        r = rp.fetchone()
+        rp.close()
         if not r:
             prepopulation, presize = (0, 0)
         else:
@@ -375,9 +387,8 @@
         population += prepopulation
         size += presize
 
-        self.statistics.insert().values(node=node, population=population,
-                                        size=size, mtime=mtime, cluster=cluster)
-        self.conn.execute(s).close()
+        ins = self.statistics.insert().values(node=node, population=population, size=size, mtime=mtime, cluster=cluster)
+        self.conn.execute(ins).close()
 
     def statistics_update_ancestors(self, node, population, size, mtime, cluster=0):
         """Update the statistics of the given node's parent.
@@ -412,31 +423,40 @@
         parent, path = props
 
         # The latest version.
-        q = ("select serial, node, size, source, mtime, muser, cluster "
-             "from versions "
-             "where serial = (select max(serial) "
-                             "from versions "
-                             "where node = ? and mtime < ?) "
-             "and cluster != ?")
-        execute(q, (node, before, except_cluster))
-        props = fetchone()
-        if props is None:
+        s = select([self.versions.c.serial,
+                    self.versions.c.node,
+                    self.versions.c.size, self.versions.c.source,
+                    self.versions.c.mtime,
+                    self.versions.c.muser,
+                    self.versions.c.cluster])
+        s = s.where(and_(self.versions.c.cluster != except_cluster,
+                         self.versions.c.serial == select(
+                            [func.max(self.versions.c.serial)],
+                            and_(self.versions.c.node == node,
+                                 self.versions.c.mtime < before))))
+        r = self.conn.execute(s)
+        props = r.fetchone()
+        r.close()
+        if not props:
             return None
         mtime = props[MTIME]
 
         # First level, just under node (get population).
-        q = ("select count(serial), sum(size), max(mtime) "
-             "from versions v "
-             "where serial = (select max(serial) "
-                             "from versions "
-                             "where node = v.node and mtime < ?) "
-             "and cluster != ? "
-             "and node in (select node "
-                          "from nodes "
-                          "where parent = ?)")
-        execute(q, (before, except_cluster, node))
-        r = fetchone()
-        if r is None:
+        v = self.versions.alias('v')
+        s = select([func.count(v.c.serial),
+                    func.sum(v.c.size),
+                    func.max(v.c.mtime)])
+        c1 = select([func.max(self.versions.c.serial)],
+            and_(self.versions.c.node == v.c.node,
+                 self.versions.c.mtime < before))
+        c2 = select([self.nodes.c.node], self.nodes.c.parent == node)
+        s = s.where(and_(v.c.serial == c1,
+                         v.c.cluster != except_cluster,
+                         v.c.node.in_(c2)))
+        rp = self.conn.execute(s)
+        r = rp.fetchone()
+        rp.close()
+        if not r:
            return None
         count = r[0]
         mtime = max(mtime, r[2])
@@ -445,18 +465,20 @@
 
         # All children (get size and mtime).
         # XXX: This is why the full path is stored.
-        q = ("select count(serial), sum(size), max(mtime) "
-             "from versions v "
-             "where serial = (select max(serial) "
-                             "from versions "
-                             "where node = v.node and mtime < ?) "
" - "and node in (select node " - "from nodes " - "where path like ?)") - execute(q, (before, except_cluster, path + '%')) - r = fetchone() - if r is None: + s = select([func.count(v.c.serial), + func.sum(v.c.size), + func.max(v.c.mtime)]) + c1 = select([func.max(self.versions.c.serial)], + and_(self.versions.c.node == v.c.node, + self.versions.c.mtime < before)) + c2 = select([self.nodes.c.node], self.nodes.c.path.like(path + '%')) + s = s.where(and_(v.c.serial == c1, + v.c.cluster != except_cluster, + v.c.node.in_(c2))) + rp = self.conn.execute(s) + r = rp.fetchone() + rp.close() + if not r: return None size = r[1] - props[SIZE] mtime = max(mtime, r[2]) @@ -467,11 +489,12 @@ class Node(DBWorker): Return the (serial, mtime) of the new version. """ - q = ("insert into versions (node, size, source, mtime, muser, cluster) " - "values (?, ?, ?, ?, ?, ?)") mtime = time() props = (node, size, source, mtime, muser, cluster) - serial = self.execute(q, props).lastrowid + props = locals() + props.pop('self') + s = self.versions.insert().values(**props) + serial = self.conn.execute(s).inserted_primary_key[0] self.statistics_update_ancestors(node, 1, size, mtime, cluster) return serial, mtime @@ -482,15 +505,18 @@ class Node(DBWorker): or None if the current version is not found in the given cluster. """ - q = ("select serial, node, size, source, mtime, muser, cluster " - "from versions " - "where serial = (select max(serial) " - "from versions " - "where node = ? and mtime < ?) " - "and cluster = ?") - self.execute(q, (node, before, cluster)) - props = self.fetchone() - if props is not None: + v = self.versions.alias('v') + s = select([v.c.serial, v.c.node, v.c.size, v.c.source, v.c.mtime, + v.c.muser, v.c.cluster]) + c = select([func.max(self.versions.c.serial)], + and_(self.versions.c.node == node, + self.versions.c.mtime < before)) + s = s.where(and_(v.c.serial == c, + v.c.cluster == cluster)) + r = self.conn.execute(s) + props = r.fetchone() + r.close() + if not props: return props return None @@ -501,11 +527,12 @@ class Node(DBWorker): (serial, node, size, source, mtime, muser, cluster). """ - q = ("select serial, node, size, source, mtime, muser, cluster " - "from versions " - "where serial = ?") - self.execute(q, (serial,)) - r = self.fetchone() + v = self.versions.alias() + s = select([v.c.serial, v.c.node, v.c.size, v.c.source, v.c.mtime, + v.c.muser, v.c.cluster], v.c.serial == serial) + rp = self.conn.execute(s) + r = rp.fetchone() + rp.close() if r is None: return r @@ -529,8 +556,10 @@ class Node(DBWorker): self.statistics_update_ancestors(node, -1, -size, mtime, oldcluster) self.statistics_update_ancestors(node, 1, size, mtime, cluster) - q = "update versions set cluster = ? where serial = ?" - self.execute(q, (cluster, serial)) + s = self.versions.update() + s = s.where(self.versions.c.serial == serial) + s = s.values(cluster = cluster) + self.conn.execute(s).close() def version_remove(self, serial): """Remove the serial specified.""" @@ -545,8 +574,8 @@ class Node(DBWorker): mtime = time() self.statistics_update_ancestors(node, -1, -size, mtime, cluster) - q = "delete from versions where serial = ?" - self.execute(q, (serial,)) + s = self.versions.delete().where(self.versions.c.serial == serial) + self.conn.execute(s).close() return True def attribute_get(self, serial, keys=()): @@ -557,23 +586,26 @@ class Node(DBWorker): execute = self.execute if keys: - marks = ','.join('?' for k in keys) - q = ("select key, value from attributes " - "where key in (%s) and serial = ?" 
-                 "where key in (%s) and serial = ?" % (marks,))
-            execute(q, keys + (serial,))
+            attrs = self.attributes.alias()
+            s = select([attrs.c.key, attrs.c.value])
+            s = s.where(and_(attrs.c.key.in_(keys),
+                             attrs.c.serial == serial))
         else:
-            q = "select key, value from attributes where serial = ?"
-            execute(q, (serial,))
-        return self.fetchall()
+            attrs = self.attributes.alias()
+            s = select([attrs.c.key, attrs.c.value])
+            s = s.where(attrs.c.serial == serial)
+        r = self.conn.execute(s)
+        l = r.fetchall()
+        r.close()
+        return l
 
     def attribute_set(self, serial, items):
         """Set the attributes of the version specified by serial.
         Receive attributes as an iterable of (key, value) pairs.
 
         """
-        q = ("insert or replace into attributes (serial, key, value) "
-             "values (?, ?, ?)")
-        self.executemany(q, ((serial, k, v) for k, v in items))
+        values = [{'serial':serial, 'key':k, 'value':v} for k, v in items]
+        self.conn.execute(self.attributes.insert(), values).close()
 
     def attribute_del(self, serial, keys=()):
         """Delete attributes of the version specified by serial.
@@ -582,39 +614,37 @@
 
         """
         if keys:
-            q = "delete from attributes where serial = ? and key = ?"
-            self.executemany(q, ((serial, key) for key in keys))
+            #TODO more efficient way to do this?
+            for key in keys:
+                s = self.attributes.delete()
+                s = s.where(and_(self.attributes.c.serial == serial,
+                                 self.attributes.c.key == key))
+                self.conn.execute(s).close()
         else:
-            q = "delete from attributes where serial = ?"
-            self.execute(q, (serial,))
+            s = self.attributes.delete()
+            s = s.where(self.attributes.c.serial == serial)
+            self.conn.execute(s).close()
 
     def attribute_copy(self, source, dest):
-        q = ("insert or replace into attributes "
-             "select ?, key, value from attributes "
-             "where serial = ?")
-        self.execute(q, (dest, source))
-
-    def _construct_filters(self, filterq):
-        if not filterq:
-            return None, None
-
-        args = filterq.split(',')
-        subq = " and a.key in ("
-        subq += ','.join(('?' for x in args))
-        subq += ")"
-
-        return subq, args
-
-    def _construct_paths(self, pathq):
-        if not pathq:
-            return None, None
-
-        subq = " and ("
-        subq += ' or '.join(('n.path like ?' for x in pathq))
-        subq += ")"
-        args = tuple([x + '%' for x in pathq])
-
-        return subq, args
+        from sqlalchemy.ext.compiler import compiles
+        from sqlalchemy.sql.expression import UpdateBase
+
+        class InsertFromSelect(UpdateBase):
+            def __init__(self, table, select):
+                self.table = table
+                self.select = select
+
+        @compiles(InsertFromSelect)
+        def visit_insert_from_select(element, compiler, **kw):
+            return "INSERT INTO %s (%s)" % (
+                compiler.process(element.table, asfrom=True),
+                compiler.process(element.select)
+            )
+
+        s = select([dest, self.attributes.c.key, self.attributes.c.value],
+            self.attributes.c.serial == source)
+        ins = InsertFromSelect(self.attributes, s)
+        self.conn.execute(ins).close()
 
     def latest_attribute_keys(self, parent, before=inf, except_cluster=0, pathq=[]):
         """Return a list with all keys pairs defined
@@ -623,23 +653,26 @@
         for all latest versions under parent that are not in cluster.
 
         """
         # TODO: Use another table to store before=inf results.
-        q = ("select distinct a.key "
-             "from attributes a, versions v, nodes n "
-             "where v.serial = (select max(serial) "
-                               "from versions "
-                               "where node = v.node and mtime < ?) "
-             "and v.cluster != ? "
-             "and v.node in (select node "
-                            "from nodes "
" - "and a.serial = v.serial " - "and n.node = v.node") - args = (before, except_cluster, parent) - subq, subargs = self._construct_paths(pathq) - if subq is not None: - q += subq - args += subargs - self.execute(q, args) + a = self.attributes.alias('a') + v = self.versions.alias('v') + n = self.nodes.alias('n') + s = select([a.c.key]).distinct() + s = s.where(v.c.serial == select([func.max(self.versions.c.serial)], + and_(self.versions.c.node == v.c.node, + self.versions.c.mtime < before))) + s = s.where(v.c.cluster != except_cluster) + s = s.where(v.c.node.in_(select([self.nodes.c.node], + self.nodes.c.parent == parent))) + s = s.where(a.c.serial == v.c.serial) + s = s.where(n.c.node == v.c.node) + conj = [] + for x in pathq: + conj.append(n.c.path.like(x + '%')) + if conj: + s = s.where(or_(*conj)) + rp = self.conn.execute(s) + r = rp.fetchall() + rp.close() return [r[0] for r in self.fetchall()] def latest_version_list(self, parent, prefix='', delimiter=None, @@ -697,38 +730,40 @@ class Node(DBWorker): start = strprevling(prefix) nextling = strnextling(prefix) - q = ("select distinct n.path, v.serial " - "from attributes a, versions v, nodes n " - "where v.serial = (select max(serial) " - "from versions " - "where node = v.node and mtime < ?) " - "and v.cluster != ? " - "and v.node in (select node " - "from nodes " - "where parent = ?) " - "and a.serial = v.serial " - "and n.node = v.node " - "and n.path > ? and n.path < ?") - args = [before, except_cluster, parent, start, nextling] - - subq, subargs = self._construct_paths(pathq) - if subq is not None: - q += subq - args += subargs - subq, subargs = self._construct_filters(filterq) - if subq is not None: - q += subq - args += subargs - else: - q = q.replace("attributes a, ", "") - q = q.replace("and a.serial = v.serial ", "") - q += " order by n.path" + a = self.attributes.alias('a') + v = self.versions.alias('v') + n = self.nodes.alias('n') + s = select([n.c.path, v.c.serial]).distinct() + s = s.where(v.c.serial == select([func.max(self.versions.c.serial)], + and_(self.versions.c.node == v.c.node, + self.versions.c.mtime < before))) + s = s.where(v.c.cluster != except_cluster) + s = s.where(v.c.node.in_(select([self.nodes.c.node], + self.nodes.c.parent == parent))) + if filterq: + s = s.where(a.c.serial == v.c.serial) + + s = s.where(n.c.node == v.c.node) + s = s.where(and_(n.c.path > bindparam('start'), n.c.path < nextling)) + conj = [] + for x in pathq: + print '#', x + conj.append(n.c.path.like(x + '%')) + + if conj: + s = s.where(or_(*conj)) + + if filterq: + s = s.where(a.c.key.in_(filterq.split(','))) + + s = s.order_by(n.c.path) if not delimiter: - q += " limit ?" - args.append(limit) - execute(q, args) - return self.fetchall(), () + s = s.limit(limit) + rp = self.conn.execute(s, start=start) + r = rp.fetchall() + rp.close() + return r, () pfz = len(prefix) dz = len(delimiter) @@ -739,9 +774,9 @@ class Node(DBWorker): matches = [] mappend = matches.append - execute(q, args) + rp = self.conn.execute(s, start=start) while True: - props = fetchone() + props = rp.fetchone() if props is None: break path, serial = props @@ -762,7 +797,6 @@ class Node(DBWorker): if count >= limit: break - args[3] = strnextling(pf) # New start. - execute(q, args) + rp = self.conn.execute(s, start=strnextling(pf)) # New start. return matches, prefixes -- 1.7.10.4