Statistics
| Branch: | Tag: | Revision:

root / snf-pithos-backend / pithos / backends / lib / sqlalchemy / node.py @ 83a3723e

History | View | Annotate | Download (44.5 kB)

1
# Copyright 2011-2012 GRNET S.A. All rights reserved.
2
#
3
# Redistribution and use in source and binary forms, with or
4
# without modification, are permitted provided that the following
5
# conditions are met:
6
#
7
#   1. Redistributions of source code must retain the above
8
#      copyright notice, this list of conditions and the following
9
#      disclaimer.
10
#
11
#   2. Redistributions in binary form must reproduce the above
12
#      copyright notice, this list of conditions and the following
13
#      disclaimer in the documentation and/or other materials
14
#      provided with the distribution.
15
#
16
# THIS SOFTWARE IS PROVIDED BY GRNET S.A. ``AS IS'' AND ANY EXPRESS
17
# OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL GRNET S.A OR
20
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
23
# USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
24
# AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
26
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27
# POSSIBILITY OF SUCH DAMAGE.
28
#
29
# The views and conclusions contained in the software and
30
# documentation are those of the authors and should not be
31
# interpreted as representing official policies, either expressed
32
# or implied, of GRNET S.A.
33

    
34
from time import time
35
from operator import itemgetter
36
from itertools import groupby
37

    
38
from sqlalchemy import Table, Integer, BigInteger, DECIMAL, Column, String, MetaData, ForeignKey
39
from sqlalchemy.types import Text
40
from sqlalchemy.schema import Index, Sequence
41
from sqlalchemy.sql import func, and_, or_, not_, null, select, bindparam, text, exists
42
from sqlalchemy.ext.compiler import compiles
43
from sqlalchemy.engine.reflection import Inspector
44
from sqlalchemy.exc import NoSuchTableError
45

    
46
from dbworker import DBWorker, ESCAPE_CHAR
47

    
48
from pithos.backends.filter import parse_filters
49

    
50

    
51
# Node id of the tree root; account nodes hang directly off it.
ROOTNODE = 0

# Index positions within a version property tuple, matching the
# column order of the 'versions' table and the order returned by
# node_get_versions()/version_get_properties().
(SERIAL, NODE, HASH, SIZE, TYPE, SOURCE, MTIME, MUSER, UUID, CHECKSUM,
 CLUSTER) = range(11)

# Path matching modes.
(MATCH_PREFIX, MATCH_EXACT) = range(2)

# Sentinel meaning "no upper bound" for 'before' timestamp arguments.
inf = float('inf')
59

    
60

    
61
def strnextling(prefix):
    """Return the first unicode string
       greater than but not starting with given prefix.
       strnextling('hello') -> 'hellp'
    """
    if not prefix:
        ## All strings start with the null string, so we approximate
        ## strnextling('') with the last unicode character supported
        ## by python: 0x10ffff on wide (32-bit unicode) builds,
        ## 0x00ffff on narrow (16-bit unicode) builds.
        ## We will not autodetect. 0xffff is safe enough.
        return unichr(0xffff)
    last = ord(prefix[-1])
    if last >= 0xffff:
        raise RuntimeError
    # Bump the final character by one code point.
    return prefix[:-1] + unichr(last + 1)
80

    
81

    
82
def strprevling(prefix):
    """Return an approximation of the last unicode string
       less than but not starting with given prefix.
       strprevling(u'hello') -> u'helln\\xffff'
    """
    if not prefix:
        ## There is no prevling for the null string
        return prefix
    last = ord(prefix[-1])
    stem = prefix[:-1]
    if last > 0:
        # Step the final character back by one and pad with the highest
        # safe code point so the result sorts just below the prefix.
        stem += unichr(last - 1) + unichr(0xffff)
    return stem
95

    
96
_propnames = {
97
    'serial': 0,
98
    'node': 1,
99
    'hash': 2,
100
    'size': 3,
101
    'type': 4,
102
    'source': 5,
103
    'mtime': 6,
104
    'muser': 7,
105
    'uuid': 8,
106
    'checksum': 9,
107
    'cluster': 10
108
}
109

    
110

    
111
def create_tables(engine):
    """Create all pithos node tables on the given engine (if missing)
    and return them sorted in foreign-key dependency order.
    """
    metadata = MetaData()

    # nodes: the path hierarchy
    nodes = Table(
        'nodes', metadata,
        Column('node', Integer, primary_key=True),
        Column('parent', Integer,
               ForeignKey('nodes.node',
                          ondelete='CASCADE',
                          onupdate='CASCADE'),
               autoincrement=False),
        Column('latest_version', Integer),
        Column('path', String(2048), default='', nullable=False),
        mysql_engine='InnoDB')
    Index('idx_nodes_path', nodes.c.path, unique=True)
    Index('idx_nodes_parent', nodes.c.parent)

    # policy: per-node key/value policy entries
    policy = Table(
        'policy', metadata,
        Column('node', Integer,
               ForeignKey('nodes.node',
                          ondelete='CASCADE',
                          onupdate='CASCADE'),
               primary_key=True),
        Column('key', String(128), primary_key=True),
        Column('value', String(256)),
        mysql_engine='InnoDB')

    # statistics: per-node, per-cluster aggregates
    statistics = Table(
        'statistics', metadata,
        Column('node', Integer,
               ForeignKey('nodes.node',
                          ondelete='CASCADE',
                          onupdate='CASCADE'),
               primary_key=True),
        Column('population', Integer, nullable=False, default=0),
        Column('size', BigInteger, nullable=False, default=0),
        Column('mtime', DECIMAL(precision=16, scale=6)),
        Column('cluster', Integer, nullable=False, default=0,
               primary_key=True, autoincrement=False),
        mysql_engine='InnoDB')

    # versions: object version history
    versions = Table(
        'versions', metadata,
        Column('serial', Integer, primary_key=True),
        Column('node', Integer,
               ForeignKey('nodes.node',
                          ondelete='CASCADE',
                          onupdate='CASCADE')),
        Column('hash', String(256)),
        Column('size', BigInteger, nullable=False, default=0),
        Column('type', String(256), nullable=False, default=''),
        Column('source', Integer),
        Column('mtime', DECIMAL(precision=16, scale=6)),
        Column('muser', String(256), nullable=False, default=''),
        Column('uuid', String(64), nullable=False, default=''),
        Column('checksum', String(256), nullable=False, default=''),
        Column('cluster', Integer, nullable=False, default=0),
        mysql_engine='InnoDB')
    Index('idx_versions_node_mtime', versions.c.node, versions.c.mtime)
    Index('idx_versions_node_uuid', versions.c.uuid)

    # attributes: per-version, per-domain metadata
    attributes = Table(
        'attributes', metadata,
        Column('serial', Integer,
               ForeignKey('versions.serial',
                          ondelete='CASCADE',
                          onupdate='CASCADE'),
               primary_key=True),
        Column('domain', String(256), primary_key=True),
        Column('key', String(128), primary_key=True),
        Column('value', String(256)),
        mysql_engine='InnoDB')
    Index('idx_attributes_domain', attributes.c.domain)

    metadata.create_all(engine)
    return metadata.sorted_tables
188

    
189

    
190
class Node(DBWorker):
191
    """Nodes store path organization and have multiple versions.
192
       Versions store object history and have multiple attributes.
193
       Attributes store metadata.
194
    """
195

    
196
    # TODO: Provide an interface for included and excluded clusters.
197

    
198
    def __init__(self, **params):
        """Reflect the backend tables and make sure the root node exists.

        If the tables are missing, create them via create_tables() and
        bind each one as an attribute named after the table.
        """
        DBWorker.__init__(self, **params)
        try:
            # Reflect existing tables from the database schema.
            metadata = MetaData(self.engine)
            self.nodes = Table('nodes', metadata, autoload=True)
            self.policy = Table('policy', metadata, autoload=True)
            self.statistics = Table('statistics', metadata, autoload=True)
            self.versions = Table('versions', metadata, autoload=True)
            self.attributes = Table('attributes', metadata, autoload=True)
        except NoSuchTableError:
            # First run: create the schema and bind tables by name.
            tables = create_tables(self.engine)
            map(lambda t: self.__setattr__(t.name, t), tables)

        # Ensure the root node (node == parent == ROOTNODE) exists,
        # inside a wrapper-managed transaction.
        s = self.nodes.select().where(and_(self.nodes.c.node == ROOTNODE,
                                           self.nodes.c.parent == ROOTNODE))
        wrapper = self.wrapper
        wrapper.execute()
        try:
            rp = self.conn.execute(s)
            r = rp.fetchone()
            rp.close()
            if not r:
                s = self.nodes.insert(
                ).values(node=ROOTNODE, parent=ROOTNODE, path='')
                self.conn.execute(s)
        finally:
            wrapper.commit()
225

    
226
    def node_create(self, parent, path):
        """Create a new node under parent with the given path.
        Return the identifier of the newly created node.
        """
        #TODO catch IntegrityError?
        statement = self.nodes.insert().values(parent=parent, path=path)
        result = self.conn.execute(statement)
        node = result.inserted_primary_key[0]
        result.close()
        return node
236

    
237
    def node_lookup(self, path, for_update=False):
        """Look up the current node of the given path.
        Return None if the path is not found.
        """
        # Compare with LIKE to sidestep MySQL's trailing-space
        # insensitivity on plain equality.
        pattern = self.escape_like(path)
        s = select([self.nodes.c.node],
                   self.nodes.c.path.like(pattern, escape=ESCAPE_CHAR),
                   for_update=for_update)
        rp = self.conn.execute(s)
        row = rp.fetchone()
        rp.close()
        return row[0] if row else None
251

    
252
    def node_lookup_bulk(self, paths):
        """Look up the current nodes of all given paths.
        Return () when no paths are given.
        """
        if not paths:
            return ()
        s = select([self.nodes.c.node], self.nodes.c.path.in_(paths))
        rp = self.conn.execute(s)
        rows = rp.fetchall()
        rp.close()
        return [row[0] for row in rows]
265

    
266
    def node_get_properties(self, node):
        """Return the node's (parent, path) pair,
        or None if the node is not found.
        """
        s = select([self.nodes.c.parent, self.nodes.c.path],
                   self.nodes.c.node == node)
        rp = self.conn.execute(s)
        props = rp.fetchone()
        rp.close()
        return props
277

    
278
    def node_get_versions(self, node, keys=(), propnames=_propnames):
        """Return the properties of all versions at node, ordered by serial.

        If keys is empty, return full tuples in the order
        (serial, node, hash, size, type, source, mtime, muser, uuid,
        checksum, cluster); otherwise project each row onto the
        requested keys.
        """
        v = self.versions.c
        s = select([v.serial, v.node, v.hash, v.size, v.type, v.source,
                    v.mtime, v.muser, v.uuid, v.checksum, v.cluster],
                   v.node == node).order_by(v.serial)
        rp = self.conn.execute(s)
        rows = rp.fetchall()
        rp.close()
        if not rows or not keys:
            return rows
        return [[row[propnames[key]] for key in keys if key in propnames]
                for row in rows]
306

    
307
    def node_count_children(self, node):
        """Return the number of children of the given node."""
        # ROOTNODE is its own parent; exclude it so the root does not
        # count itself as a child.
        s = select([func.count(self.nodes.c.node)]).where(
            and_(self.nodes.c.parent == node,
                 self.nodes.c.node != ROOTNODE))
        rp = self.conn.execute(s)
        count = rp.fetchone()[0]
        rp.close()
        return count
317

    
318
    def node_purge_children(self, parent, before=inf, cluster=0,
                            update_statistics_ancestors_depth=None):
        """Delete all versions with the specified
           parent and cluster, and return
           the hashes, the total size and the serials of versions deleted.
           Clears out nodes with no remaining versions.
        """
        #update statistics
        c1 = select([self.nodes.c.node],
                    self.nodes.c.parent == parent)
        where_clause = and_(self.versions.c.node.in_(c1),
                            self.versions.c.cluster == cluster)
        if before != inf:
            where_clause = and_(where_clause,
                                self.versions.c.mtime <= before)
        s = select([func.count(self.versions.c.serial),
                    func.sum(self.versions.c.size)])
        s = s.where(where_clause)
        r = self.conn.execute(s)
        row = r.fetchone()
        r.close()
        if not row:
            return (), 0, ()
        # SUM() yields NULL when no rows match; normalize to 0.
        nr, size = row[0], row[1] if row[1] else 0
        if nr:
            # Only touch statistics when something is actually purged;
            # otherwise (like node_purge) we would spuriously bump the
            # statistics mtime with a zero delta.
            mtime = time()
            self.statistics_update(parent, -nr, -size, mtime, cluster)
            self.statistics_update_ancestors(parent, -nr, -size, mtime,
                                             cluster,
                                             update_statistics_ancestors_depth)

        # Collect the hashes and serials of the versions to be deleted.
        s = select([self.versions.c.hash, self.versions.c.serial])
        s = s.where(where_clause)
        r = self.conn.execute(s)
        hashes = []
        serials = []
        for row in r.fetchall():
            hashes += [row[0]]
            serials += [row[1]]
        r.close()

        #delete versions
        s = self.versions.delete().where(where_clause)
        r = self.conn.execute(s)
        r.close()

        #delete child nodes that no longer hold any version
        s = select([self.nodes.c.node],
                   and_(self.nodes.c.parent == parent,
                        select([func.count(self.versions.c.serial)],
                               self.versions.c.node == self.nodes.c.node).as_scalar() == 0))
        rp = self.conn.execute(s)
        nodes = [r[0] for r in rp.fetchall()]
        rp.close()
        if nodes:
            s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
            self.conn.execute(s).close()

        return hashes, size, serials
375

    
376
    def node_purge(self, node, before=inf, cluster=0,
                   update_statistics_ancestors_depth=None):
        """Delete all versions with the specified node and cluster.
        Return the hashes, total size and serials of the versions
        deleted.  Clears out the node if no versions remain.
        """
        # Count what will be purged and roll it out of the statistics.
        where_clause = and_(self.versions.c.node == node,
                            self.versions.c.cluster == cluster)
        if before != inf:
            where_clause = and_(where_clause,
                                self.versions.c.mtime <= before)
        s = select([func.count(self.versions.c.serial),
                    func.sum(self.versions.c.size)]).where(where_clause)
        rp = self.conn.execute(s)
        nr, size = rp.fetchone()
        rp.close()
        if not nr:
            return (), 0, ()
        mtime = time()
        self.statistics_update_ancestors(node, -nr, -size, mtime, cluster,
                                         update_statistics_ancestors_depth)

        # Collect the hashes and serials of the doomed versions.
        s = select([self.versions.c.hash,
                    self.versions.c.serial]).where(where_clause)
        rp = self.conn.execute(s)
        hashes = []
        serials = []
        for row in rp.fetchall():
            hashes.append(row[0])
            serials.append(row[1])
        rp.close()

        # Delete the versions themselves.
        self.conn.execute(self.versions.delete().where(where_clause)).close()

        # Clear out the node if it holds no versions any more.
        s = select([self.nodes.c.node],
                   and_(self.nodes.c.node == node,
                        select([func.count(self.versions.c.serial)],
                               self.versions.c.node == self.nodes.c.node
                               ).as_scalar() == 0))
        rp = self.conn.execute(s)
        empty = [row[0] for row in rp.fetchall()]
        rp.close()
        if empty:
            s = self.nodes.delete().where(self.nodes.c.node.in_(empty))
            self.conn.execute(s).close()

        return hashes, size, serials
431

    
432
    def node_remove(self, node, update_statistics_ancestors_depth=None):
        """Remove the specified node.
        Return False if the node has children or is not found.
        """
        if self.node_count_children(node):
            return False

        # Roll the node's per-cluster totals out of ancestor statistics.
        mtime = time()
        s = select([func.count(self.versions.c.serial),
                    func.sum(self.versions.c.size),
                    self.versions.c.cluster],
                   self.versions.c.node == node)
        s = s.group_by(self.versions.c.cluster)
        rp = self.conn.execute(s)
        for population, size, cluster in rp.fetchall():
            self.statistics_update_ancestors(
                node, -population, -size, mtime, cluster,
                update_statistics_ancestors_depth)
        rp.close()

        self.conn.execute(
            self.nodes.delete().where(self.nodes.c.node == node)).close()
        return True
456

    
457
    def node_accounts(self, accounts=()):
        """Return (path, node) rows for all first-level (account) nodes,
        optionally restricted to the given account paths.
        """
        s = select([self.nodes.c.path, self.nodes.c.node],
                   and_(self.nodes.c.node != 0,
                        self.nodes.c.parent == 0))
        if accounts:
            s = s.where(self.nodes.c.path.in_(accounts))
        rp = self.conn.execute(s)
        rows = rp.fetchall()
        rp.close()
        return rows
467

    
468
    def node_account_quotas(self):
        """Return a dict mapping account path -> 'quota' policy value."""
        s = select([self.nodes.c.path, self.policy.c.value])
        s = s.where(and_(self.nodes.c.node != 0,
                         self.nodes.c.parent == 0,
                         self.nodes.c.node == self.policy.c.node,
                         self.policy.c.key == 'quota'))
        rp = self.conn.execute(s)
        quotas = dict(rp.fetchall())
        rp.close()
        return quotas
478

    
479
    def node_account_usage(self, account_node, cluster):
        """Return the total version size in the given cluster for all
        nodes up to two levels below the account node.
        """
        # Direct children of the account ...
        children = select([self.nodes.c.node],
                          self.nodes.c.parent == account_node)
        # ... together with their own children.
        descendants = select([self.nodes.c.node],
                             or_(self.nodes.c.parent.in_(children),
                                 self.nodes.c.node.in_(children)))
        s = select([func.sum(self.versions.c.size)],
                   and_(self.nodes.c.node == self.versions.c.node,
                        self.nodes.c.node.in_(descendants),
                        self.versions.c.cluster == cluster))
        rp = self.conn.execute(s)
        usage = rp.fetchone()[0]
        rp.close()
        return usage
493

    
494
    def policy_get(self, node):
        """Return the node's policy as a {key: value} dict."""
        s = select([self.policy.c.key, self.policy.c.value],
                   self.policy.c.node == node)
        rp = self.conn.execute(s)
        policy = dict(rp.fetchall())
        rp.close()
        return policy
501

    
502
    def policy_set(self, node, policy):
        """Insert or update the given {key: value} policy entries."""
        for key, value in policy.iteritems():
            # Try an update first; insert when nothing matched.
            u = self.policy.update().where(and_(self.policy.c.node == node,
                                                self.policy.c.key == key))
            u = u.values(value=value)
            rp = self.conn.execute(u)
            rp.close()
            if rp.rowcount == 0:
                ins = self.policy.insert()
                values = {'node': node, 'key': key, 'value': value}
                self.conn.execute(ins, values).close()
515

    
516
    def statistics_get(self, node, cluster=0):
        """Return (population, size, mtime) for all versions under node
        that belong to the cluster, or None if absent.
        """
        s = select([self.statistics.c.population,
                    self.statistics.c.size,
                    self.statistics.c.mtime],
                   and_(self.statistics.c.node == node,
                        self.statistics.c.cluster == cluster))
        rp = self.conn.execute(s)
        row = rp.fetchone()
        rp.close()
        return row
530

    
531
    def statistics_update(self, node, population, size, mtime, cluster=0):
        """Apply (possibly negative) population/size deltas to the
        node's statistics for the cluster and stamp the new mtime.
        """
        s = select([self.statistics.c.population, self.statistics.c.size],
                   and_(self.statistics.c.node == node,
                        self.statistics.c.cluster == cluster))
        rp = self.conn.execute(s)
        row = rp.fetchone()
        rp.close()
        prepopulation, presize = row if row else (0, 0)
        # Population is clamped at zero; size is trusted as-is.
        population = max(population + prepopulation, 0)
        size += presize

        # Upsert: try an update first, insert when nothing matched.
        #TODO better upsert
        u = self.statistics.update().where(and_(self.statistics.c.node == node,
                                           self.statistics.c.cluster == cluster))
        u = u.values(population=population, size=size, mtime=mtime)
        rp = self.conn.execute(u)
        rp.close()
        if rp.rowcount == 0:
            ins = self.statistics.insert()
            ins = ins.values(node=node, population=population, size=size,
                             mtime=mtime, cluster=cluster)
            self.conn.execute(ins).close()
563

    
564
    def statistics_update_ancestors(self, node, population, size, mtime,
                                    cluster=0, recursion_depth=None):
        """Apply the deltas to the parent of node and keep walking
        upwards, stopping at the root or after ``recursion_depth``
        levels (when given).  The population delta is applied to the
        immediate parent only.
        """
        level = 0
        while node != ROOTNODE:
            if recursion_depth and recursion_depth <= level:
                break
            props = self.node_get_properties(node)
            if props is None:
                break
            parent, _ = props
            self.statistics_update(parent, population, size, mtime, cluster)
            node = parent
            population = 0  # Population isn't recursive
            level += 1
586

    
587
    def statistics_latest(self, node, before=inf, except_cluster=0):
        """Return population, total size and last mtime
           for all latest versions under node that
           do not belong to the cluster.
        """

        # The node.
        props = self.node_get_properties(node)
        if props is None:
            return None
        parent, path = props

        # The latest version of the node itself.
        s = select([self.versions.c.serial,
                    self.versions.c.node,
                    self.versions.c.hash,
                    self.versions.c.size,
                    self.versions.c.type,
                    self.versions.c.source,
                    self.versions.c.mtime,
                    self.versions.c.muser,
                    self.versions.c.uuid,
                    self.versions.c.checksum,
                    self.versions.c.cluster])
        if before != inf:
            filtered = select([func.max(self.versions.c.serial)],
                              self.versions.c.node == node)
            filtered = filtered.where(self.versions.c.mtime < before)
        else:
            filtered = select([self.nodes.c.latest_version],
                              self.nodes.c.node == node)
        s = s.where(and_(self.versions.c.cluster != except_cluster,
                         self.versions.c.serial == filtered))
        r = self.conn.execute(s)
        props = r.fetchone()
        r.close()
        if not props:
            return None
        mtime = props[MTIME]

        # First level, just under node (get population).
        v = self.versions.alias('v')
        s = select([func.count(v.c.serial),
                    func.sum(v.c.size),
                    func.max(v.c.mtime)])
        if before != inf:
            c1 = select([func.max(self.versions.c.serial)])
            c1 = c1.where(self.versions.c.mtime < before)
            # BUGFIX: Select.where() is generative and returns a new
            # select -- the original discarded this call's result, so the
            # subquery lost its per-node correlation and max(serial)
            # ranged over ALL nodes.  Keep the returned select.
            c1 = c1.where(self.versions.c.node == v.c.node)
        else:
            c1 = select([self.nodes.c.latest_version])
            c1 = c1.where(self.nodes.c.node == v.c.node)
        c2 = select([self.nodes.c.node], self.nodes.c.parent == node)
        s = s.where(and_(v.c.serial == c1,
                         v.c.cluster != except_cluster,
                         v.c.node.in_(c2)))
        rp = self.conn.execute(s)
        r = rp.fetchone()
        rp.close()
        if not r:
            return None
        count = r[0]
        mtime = max(mtime, r[2])
        if count == 0:
            return (0, 0, mtime)

        # All children (get size and mtime).
        # This is why the full path is stored.
        s = select([func.count(v.c.serial),
                    func.sum(v.c.size),
                    func.max(v.c.mtime)])
        if before != inf:
            c1 = select([func.max(self.versions.c.serial)],
                        self.versions.c.node == v.c.node)
            c1 = c1.where(self.versions.c.mtime < before)
        else:
            c1 = select([self.nodes.c.latest_version],
                        self.nodes.c.node == v.c.node)
        c2 = select([self.nodes.c.node], self.nodes.c.path.like(
            self.escape_like(path) + '%', escape=ESCAPE_CHAR))
        s = s.where(and_(v.c.serial == c1,
                         v.c.cluster != except_cluster,
                         v.c.node.in_(c2)))
        rp = self.conn.execute(s)
        r = rp.fetchone()
        rp.close()
        if not r:
            return None
        # Exclude the node's own version size; report children only.
        size = r[1] - props[SIZE]
        mtime = max(mtime, r[2])
        return (count, size, mtime)
678

    
679
    def nodes_set_latest_version(self, node, serial):
        """Record serial as the node's current (latest) version."""
        u = self.nodes.update().where(self.nodes.c.node == node)
        self.conn.execute(u.values(latest_version=serial)).close()
683

    
684
    def version_create(self, node, hash, size, type, source, muser, uuid,
                       checksum, cluster=0,
                       update_statistics_ancestors_depth=None):
        """Create a new version from the given properties.
        Return the (serial, mtime) of the new version.
        """
        mtime = time()
        ins = self.versions.insert().values(
            node=node, hash=hash, size=size, type=type, source=source,
            mtime=mtime, muser=muser, uuid=uuid, checksum=checksum,
            cluster=cluster)
        serial = self.conn.execute(ins).inserted_primary_key[0]
        # Propagate the size/population delta upwards and make this
        # version the node's latest.
        self.statistics_update_ancestors(node, 1, size, mtime, cluster,
                                         update_statistics_ancestors_depth)
        self.nodes_set_latest_version(node, serial)
        return serial, mtime
702

    
703
    def version_lookup(self, node, before=inf, cluster=0, all_props=True):
        """Look up the current version of the given node.

        Return its properties (serial, node, hash, size, type, source,
        mtime, muser, uuid, checksum, cluster) -- or just the serial
        when all_props is false -- or None if no current version exists
        in the given cluster.
        """
        v = self.versions.alias('v')
        if all_props:
            cols = [v.c.serial, v.c.node, v.c.hash,
                    v.c.size, v.c.type, v.c.source,
                    v.c.mtime, v.c.muser, v.c.uuid,
                    v.c.checksum, v.c.cluster]
        else:
            cols = [v.c.serial]
        s = select(cols)
        if before != inf:
            # Latest serial strictly before the given timestamp.
            c = select([func.max(self.versions.c.serial)],
                       self.versions.c.node == node)
            c = c.where(self.versions.c.mtime < before)
        else:
            # Use the precomputed latest version pointer.
            c = select([self.nodes.c.latest_version],
                       self.nodes.c.node == node)
        s = s.where(and_(v.c.serial == c,
                         v.c.cluster == cluster))
        rp = self.conn.execute(s)
        props = rp.fetchone()
        rp.close()
        return props if props else None
734

    
735
    def version_lookup_bulk(self, nodes, before=inf, cluster=0, all_props=True):
        """Look up the current versions of all given nodes.

        Return an iterable of property tuples (serial, node, hash, size,
        type, source, mtime, muser, uuid, checksum, cluster), ordered by
        node; () when no nodes are given.
        """
        if not nodes:
            return ()
        v = self.versions.alias('v')
        if all_props:
            cols = [v.c.serial, v.c.node, v.c.hash,
                    v.c.size, v.c.type, v.c.source,
                    v.c.mtime, v.c.muser, v.c.uuid,
                    v.c.checksum, v.c.cluster]
        else:
            cols = [v.c.serial]
        s = select(cols)
        if before != inf:
            # Latest serial per node, strictly before the timestamp.
            c = select([func.max(self.versions.c.serial)],
                       self.versions.c.node.in_(nodes))
            c = c.where(self.versions.c.mtime < before)
            c = c.group_by(self.versions.c.node)
        else:
            c = select([self.nodes.c.latest_version],
                       self.nodes.c.node.in_(nodes))
        s = s.where(and_(v.c.serial.in_(c),
                         v.c.cluster == cluster))
        s = s.order_by(v.c.node)
        rp = self.conn.execute(s)
        rows = rp.fetchall()
        rp.close()
        return (tuple(row.values()) for row in rows)
765

    
766
    def version_get_properties(self, serial, keys=(), propnames=_propnames):
        """Return a sequence of values for the properties of
        the version specified by serial and the keys, in the order given.
        If keys is empty, return all properties in the order
        (serial, node, hash, size, type, source, mtime, muser, uuid,
        checksum, cluster).
        """
        v = self.versions.alias()
        stmt = select([v.c.serial, v.c.node, v.c.hash,
                       v.c.size, v.c.type, v.c.source,
                       v.c.mtime, v.c.muser, v.c.uuid,
                       v.c.checksum, v.c.cluster], v.c.serial == serial)
        result = self.conn.execute(stmt)
        row = result.fetchone()
        result.close()
        # Missing serial -> None; no keys requested -> full property row.
        if row is None or not keys:
            return row
        # Project only the requested, recognized property names.
        return [row[propnames[k]] for k in keys if k in propnames]

    def version_put_property(self, serial, key, value):
        """Set value for the property of version specified by key."""
        if key not in _propnames:
            # Unknown property names are silently ignored.
            return
        u = (self.versions.update()
             .where(self.versions.c.serial == serial)
             .values(**{key: value}))
        self.conn.execute(u).close()

    def version_recluster(self, serial, cluster,
                          update_statistics_ancestors_depth=None):
        """Move the version into another cluster."""
        props = self.version_get_properties(serial)
        if not props:
            return
        node = props[NODE]
        size = props[SIZE]
        oldcluster = props[CLUSTER]
        if cluster == oldcluster:
            # Already in the target cluster; nothing to move.
            return

        # Shift the accounted population/size from the old cluster to the
        # new one, stamping both updates with the same mtime.
        mtime = time()
        self.statistics_update_ancestors(node, -1, -size, mtime, oldcluster,
                                         update_statistics_ancestors_depth)
        self.statistics_update_ancestors(node, 1, size, mtime, cluster,
                                         update_statistics_ancestors_depth)

        u = (self.versions.update()
             .where(self.versions.c.serial == serial)
             .values(cluster=cluster))
        self.conn.execute(u).close()

    def version_remove(self, serial, update_statistics_ancestors_depth=None):
        """Remove the serial specified.

        Update ancestor statistics and, if another version of the node
        survives in the same cluster, promote it to latest.
        Return (hash, size) of the removed version, or None if the serial
        does not exist.
        """
        props = self.version_get_properties(serial)
        if not props:
            return
        node = props[NODE]
        hash = props[HASH]
        size = props[SIZE]
        cluster = props[CLUSTER]

        mtime = time()
        self.statistics_update_ancestors(node, -1, -size, mtime, cluster,
                                         update_statistics_ancestors_depth)

        s = self.versions.delete().where(self.versions.c.serial == serial)
        self.conn.execute(s).close()

        props = self.version_lookup(node, cluster=cluster, all_props=False)
        if props:
            # BUGFIX: promote the *surviving* serial (props[0]).  The
            # previous code passed the just-deleted `serial`, leaving
            # nodes.latest_version pointing at a nonexistent version.
            self.nodes_set_latest_version(node, props[0])

        return hash, size

    def attribute_get(self, serial, domain, keys=()):
        """Return a list of (key, value) pairs of the version specified
        by serial.
        If keys is empty, return all attributes.
        Otherwise, return only those specified.
        """
        attrs = self.attributes.alias()
        s = select([attrs.c.key, attrs.c.value])
        s = s.where(and_(attrs.c.serial == serial,
                         attrs.c.domain == domain))
        if keys:
            # Restrict the result to the requested attribute keys.
            s = s.where(attrs.c.key.in_(keys))
        r = self.conn.execute(s)
        pairs = r.fetchall()
        r.close()
        return pairs

    def attribute_set(self, serial, domain, items):
        """Set the attributes of the version specified by serial.
           Receive attributes as an iterable of (key, value) pairs.
        """
        # Emulate "insert or replace": try an UPDATE first and fall back
        # to INSERT when no existing row matched.
        # TODO better upsert (a real ON CONFLICT / MERGE once supported).
        for k, v in items:
            s = self.attributes.update()
            s = s.where(and_(self.attributes.c.serial == serial,
                             self.attributes.c.domain == domain,
                             self.attributes.c.key == k))
            s = s.values(value=v)
            rp = self.conn.execute(s)
            rp.close()
            # NOTE(review): rowcount is read *after* close(); this appears
            # to rely on SQLAlchemy caching the DML rowcount on the result
            # proxy — confirm against the SQLAlchemy version in use.
            if rp.rowcount == 0:
                # No row was updated: the attribute did not exist yet.
                s = self.attributes.insert()
                s = s.values(serial=serial, domain=domain, key=k, value=v)
                self.conn.execute(s).close()

    def attribute_del(self, serial, domain, keys=()):
        """Delete attributes of the version specified by serial.
        If keys is empty, delete all attributes.
        Otherwise delete those specified.
        """
        if not keys:
            # Wipe every attribute of this version in the given domain.
            d = self.attributes.delete()
            d = d.where(and_(self.attributes.c.serial == serial,
                             self.attributes.c.domain == domain))
            self.conn.execute(d).close()
            return
        # TODO more efficient way to do this?
        for key in keys:
            d = self.attributes.delete().where(and_(
                self.attributes.c.serial == serial,
                self.attributes.c.domain == domain,
                self.attributes.c.key == key))
            self.conn.execute(d).close()

    def attribute_copy(self, source, dest):
        """Copy all attributes of version `source` onto version `dest`.

        Existing (domain, key) attributes of dest are overwritten;
        missing ones are inserted.
        """
        # Select the literal `dest` serial alongside each source attribute,
        # so every fetched row is already shaped as (dest, domain, key, value).
        s = select(
            [dest, self.attributes.c.domain,
                self.attributes.c.key, self.attributes.c.value],
            self.attributes.c.serial == source)
        rp = self.conn.execute(s)
        attributes = rp.fetchall()
        rp.close()
        # NOTE(review): the loop target rebinds the `dest` parameter; the
        # value is the same constant on every row, so this is harmless but
        # confusing.
        for dest, domain, k, v in attributes:
            #insert or replace: UPDATE first, INSERT when nothing matched
            s = self.attributes.update().where(and_(
                self.attributes.c.serial == dest,
                self.attributes.c.domain == domain,
                self.attributes.c.key == k))
            rp = self.conn.execute(s, value=v)
            rp.close()
            # NOTE(review): rowcount read after close(); presumably cached
            # by SQLAlchemy for DML statements — confirm for the version
            # in use.
            if rp.rowcount == 0:
                s = self.attributes.insert()
                values = {'serial': dest, 'domain': domain,
                          'key': k, 'value': v}
                self.conn.execute(s, values).close()

    def latest_attribute_keys(self, parent, domain, before=inf,
                              except_cluster=0, pathq=None):
        """Return a list with all keys pairs defined
        for all latest versions under parent that
        do not belong to the cluster.
        """
        pathq = pathq or []

        # TODO: Use another table to store before=inf results.
        a = self.attributes.alias('a')
        v = self.versions.alias('v')
        n = self.nodes.alias('n')
        s = select([a.c.key]).distinct()
        # Pin each node to its latest serial (or the latest one whose
        # mtime is strictly before `before`).
        if before == inf:
            latest = select([self.nodes.c.latest_version])
            latest = latest.where(self.nodes.c.node == v.c.node)
        else:
            latest = select([func.max(self.versions.c.serial)])
            latest = latest.where(self.versions.c.mtime < before)
            latest = latest.where(self.versions.c.node == v.c.node)
        s = s.where(v.c.serial == latest)
        s = s.where(v.c.cluster != except_cluster)
        s = s.where(v.c.node.in_(select([self.nodes.c.node],
                                        self.nodes.c.parent == parent)))
        s = s.where(a.c.serial == v.c.serial)
        s = s.where(a.c.domain == domain)
        s = s.where(n.c.node == v.c.node)
        path_filters = []
        for path, match in pathq:
            if match == MATCH_PREFIX:
                pattern = self.escape_like(path) + '%'
                path_filters.append(
                    n.c.path.like(pattern, escape=ESCAPE_CHAR))
            elif match == MATCH_EXACT:
                path_filters.append(n.c.path == path)
        if path_filters:
            # Any one path condition suffices.
            s = s.where(or_(*path_filters))
        rp = self.conn.execute(s)
        rows = rp.fetchall()
        rp.close()
        return [row[0] for row in rows]

    def latest_version_list(self, parent, prefix='', delimiter=None,
                            start='', limit=10000, before=inf,
                            except_cluster=0, pathq=None, domain=None,
                            filterq=None, sizeq=None, all_props=False):
        """Return a (list of (path, serial) tuples, list of common prefixes)
           for the current versions of the paths with the given parent,
           matching the following criteria.

           The property tuple for a version is returned if all
           of these conditions are true:

                a. parent matches

                b. path > start

                c. path starts with prefix (and paths in pathq)

                d. version is the max up to before

                e. version is not in cluster

                f. the path does not have the delimiter occurring
                   after the prefix, or ends with the delimiter

                g. serial matches the attribute filter query.

                   A filter query is a comma-separated list of
                   terms in one of these three forms:

                   key
                       an attribute with this key must exist

                   !key
                       an attribute with this key must not exist

                   key ?op value
                       the attribute with this key satisfies the value
                       where ?op is one of ==, != <=, >=, <, >.

                h. the size is in the range set by sizeq

           The list of common prefixes includes the prefixes
           matching up to the first delimiter after prefix,
           and are reported only once, as "virtual directories".
           The delimiter is included in the prefixes.

           If arguments are None, then the corresponding matching rule
           will always match.

           Limit applies to the first list of tuples returned.

           If all_props is True, return all properties after path,
           not just serial.
        """
        # Avoid the shared-mutable-default pitfall; None means "no filter".
        pathq = pathq if pathq is not None else []
        filterq = filterq if filterq is not None else []

        if not start or start < prefix:
            start = strprevling(prefix)
        nextling = strnextling(prefix)

        v = self.versions.alias('v')
        n = self.nodes.alias('n')
        if not all_props:
            s = select([n.c.path, v.c.serial]).distinct()
        else:
            s = select([n.c.path,
                        v.c.serial, v.c.node, v.c.hash,
                        v.c.size, v.c.type, v.c.source,
                        v.c.mtime, v.c.muser, v.c.uuid,
                        v.c.checksum, v.c.cluster]).distinct()
        # (d) pin each node to its latest serial (bounded by `before`).
        if before != inf:
            filtered = select([func.max(self.versions.c.serial)])
            filtered = filtered.where(self.versions.c.mtime < before)
        else:
            filtered = select([self.nodes.c.latest_version])
        s = s.where(
            v.c.serial == filtered.where(self.nodes.c.node == v.c.node))
        s = s.where(v.c.cluster != except_cluster)           # (e)
        s = s.where(v.c.node.in_(select([self.nodes.c.node],
                                        self.nodes.c.parent == parent)))

        s = s.where(n.c.node == v.c.node)
        # (b)/(c): `start` is a bind parameter so the statement can be
        # re-executed with a new starting path during delimiter paging.
        s = s.where(and_(n.c.path > bindparam('start'), n.c.path < nextling))
        conj = []
        for path, match in pathq:
            if match == MATCH_PREFIX:
                conj.append(
                    n.c.path.like(
                        self.escape_like(path) + '%',
                        escape=ESCAPE_CHAR
                    )
                )
            elif match == MATCH_EXACT:
                conj.append(n.c.path == path)
        if conj:
            s = s.where(or_(*conj))

        # (h) size range filter.
        if sizeq and len(sizeq) == 2:
            if sizeq[0]:
                s = s.where(v.c.size >= sizeq[0])
            if sizeq[1]:
                s = s.where(v.c.size < sizeq[1])

        # (g) attribute filter query, as correlated EXISTS subqueries.
        if domain and filterq:
            a = self.attributes.alias('a')
            included, excluded, opers = parse_filters(filterq)
            if included:
                subs = select([1])
                subs = subs.where(a.c.serial == v.c.serial).correlate(v)
                subs = subs.where(a.c.domain == domain)
                subs = subs.where(or_(*[a.c.key.op('=')(x) for x in included]))
                s = s.where(exists(subs))
            if excluded:
                subs = select([1])
                subs = subs.where(a.c.serial == v.c.serial).correlate(v)
                subs = subs.where(a.c.domain == domain)
                subs = subs.where(or_(*[a.c.key.op('=')(x) for x in excluded]))
                s = s.where(not_(exists(subs)))
            if opers:
                for k, o, val in opers:
                    subs = select([1])
                    subs = subs.where(a.c.serial == v.c.serial).correlate(v)
                    subs = subs.where(a.c.domain == domain)
                    subs = subs.where(
                        and_(a.c.key.op('=')(k), a.c.value.op(o)(val)))
                    s = s.where(exists(subs))

        s = s.order_by(n.c.path)

        if not delimiter:
            # Flat listing: a single bounded query suffices.
            s = s.limit(limit)
            rp = self.conn.execute(s, start=start)
            r = rp.fetchall()
            rp.close()
            return r, ()

        # Delimiter listing: walk results, folding paths that contain the
        # delimiter after the prefix into "virtual directory" prefixes and
        # skipping ahead past each prefix with a fresh query.
        pfz = len(prefix)
        dz = len(delimiter)
        count = 0
        prefixes = []
        pappend = prefixes.append
        matches = []
        mappend = matches.append

        rp = self.conn.execute(s, start=start)
        while True:
            props = rp.fetchone()
            if props is None:
                break
            path = props[0]
            idx = path.find(delimiter, pfz)

            if idx < 0:
                # Plain object path: report it.
                mappend(props)
                count += 1
                if count >= limit:
                    break
                continue

            if idx + dz == len(path):
                # Path ends with the delimiter: also a reportable object.
                mappend(props)
                count += 1
                continue  # Get one more, in case there is a path.
            pf = path[:idx + dz]
            pappend(pf)
            if count >= limit:
                break

            # BUGFIX: close the current result set before opening the next
            # one; the previous code leaked one open cursor per common
            # prefix encountered.
            rp.close()
            rp = self.conn.execute(s, start=strnextling(pf))  # New start.
        rp.close()

        return matches, prefixes

    def latest_uuid(self, uuid, cluster):
        """Return the latest version of the given uuid and cluster.

        Return a (path, serial) tuple.
        If cluster is None, all clusters are considered.

        """
        v = self.versions.alias('v')
        n = self.nodes.alias('n')
        # Newest serial carrying this uuid, optionally cluster-restricted.
        latest = select([func.max(self.versions.c.serial)])
        latest = latest.where(self.versions.c.uuid == uuid)
        if cluster is not None:
            latest = latest.where(self.versions.c.cluster == cluster)
        s = select([n.c.path, v.c.serial])
        s = s.where(v.c.serial == latest)
        s = s.where(n.c.node == v.c.node)

        r = self.conn.execute(s)
        row = r.fetchone()
        r.close()
        return row

    def domain_object_list(self, domain, cluster=None):
        """Return a list of (path, property list, attribute dictionary)
        for the objects in the specific domain and cluster.
        """
        v = self.versions.alias('v')
        n = self.nodes.alias('n')
        a = self.attributes.alias('a')

        s = select([n.c.path, v.c.serial, v.c.node, v.c.hash, v.c.size,
                    v.c.type, v.c.source, v.c.mtime, v.c.muser, v.c.uuid,
                    v.c.checksum, v.c.cluster, a.c.key, a.c.value])
        s = s.where(n.c.node == v.c.node)
        s = s.where(n.c.latest_version == v.c.serial)
        if cluster:
            s = s.where(v.c.cluster == cluster)
        s = s.where(v.c.serial == a.c.serial)
        s = s.where(a.c.domain == domain)

        r = self.conn.execute(s)
        rows = r.fetchall()
        r.close()

        # Collate attribute rows per object: the first 12 columns identify
        # the object (path + version properties); the last two columns are
        # one (key, value) attribute pair each.  groupby requires the rows
        # sorted by the same key.
        keyfunc = itemgetter(slice(12))
        rows.sort(key=keyfunc)
        return [(props[0], props[1:], dict(row[12:] for row in data))
                for props, data in groupby(rows, keyfunc)]
