Statistics
| Branch: | Tag: | Revision:

root / snf-pithos-backend / pithos / backends / lib / sqlalchemy / node.py @ cc412b78

History | View | Annotate | Download (42.9 kB)

1
# Copyright 2011-2012 GRNET S.A. All rights reserved.
2
#
3
# Redistribution and use in source and binary forms, with or
4
# without modification, are permitted provided that the following
5
# conditions are met:
6
#
7
#   1. Redistributions of source code must retain the above
8
#      copyright notice, this list of conditions and the following
9
#      disclaimer.
10
#
11
#   2. Redistributions in binary form must reproduce the above
12
#      copyright notice, this list of conditions and the following
13
#      disclaimer in the documentation and/or other materials
14
#      provided with the distribution.
15
#
16
# THIS SOFTWARE IS PROVIDED BY GRNET S.A. ``AS IS'' AND ANY EXPRESS
17
# OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL GRNET S.A OR
20
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
23
# USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
24
# AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
26
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27
# POSSIBILITY OF SUCH DAMAGE.
28
#
29
# The views and conclusions contained in the software and
30
# documentation are those of the authors and should not be
31
# interpreted as representing official policies, either expressed
32
# or implied, of GRNET S.A.
33

    
34
from time import time
35
from operator import itemgetter
36
from itertools import groupby
37

    
38
from sqlalchemy import Table, Integer, BigInteger, DECIMAL, Column, String, MetaData, ForeignKey
39
from sqlalchemy.types import Text
40
from sqlalchemy.schema import Index, Sequence
41
from sqlalchemy.sql import func, and_, or_, not_, null, select, bindparam, text, exists
42
from sqlalchemy.ext.compiler import compiles
43
from sqlalchemy.engine.reflection import Inspector
44
from sqlalchemy.exc import NoSuchTableError
45

    
46
from dbworker import DBWorker
47

    
48
from pithos.backends.filter import parse_filters
49

    
50

    
51
# Identifier of the root of the node hierarchy; account nodes hang off it.
ROOTNODE = 0

# Column indices into a version property tuple, matching the column order
# used by the version-select queries below.
(SERIAL, NODE, HASH, SIZE, TYPE, SOURCE, MTIME, MUSER, UUID, CHECKSUM,
 CLUSTER) = range(11)

# Path matching modes (prefix listing vs. exact lookup).
(MATCH_PREFIX, MATCH_EXACT) = range(2)

# Sentinel meaning "no upper bound" for the various `before` arguments.
inf = float('inf')
59

    
60

    
61
def strnextling(prefix):
    """Return the first unicode string
       greater than but not starting with given prefix.
       strnextling('hello') -> 'hellp'
    """
    if not prefix:
        ## Every string starts with the null string, so approximate
        ## strnextling('') with the highest code point we can count on.
        ## 0x10ffff for wide (32-bit unicode) python builds,
        ## 0x00ffff for narrow (16-bit unicode) python builds.
        ## We will not autodetect; 0xffff is safe enough.
        return unichr(0xffff)
    last = ord(prefix[-1])
    if last >= 0xffff:
        # Cannot increment past the supported code-point range.
        raise RuntimeError
    return prefix[:-1] + unichr(last + 1)
80

    
81

    
82
def strprevling(prefix):
    """Return an approximation of the last unicode string
       less than but not starting with given prefix.
       strprevling(u'hello') -> u'helln\\xffff'
    """
    if not prefix:
        ## There is no prevling for the null string.
        return prefix
    last = ord(prefix[-1])
    if last == 0:
        # Cannot decrement below NUL; fall back to the bare stem.
        return prefix[:-1]
    return prefix[:-1] + unichr(last - 1) + unichr(0xffff)
95

    
96
# Map property name -> index in a version property tuple; must stay in
# sync with the (SERIAL, NODE, ..., CLUSTER) index constants above.
_propnames = dict(
    serial=0, node=1, hash=2, size=3, type=4, source=5,
    mtime=6, muser=7, uuid=8, checksum=9, cluster=10)
109

    
110

    
111
def create_tables(engine):
    """Create the pithos node schema on the given engine.

       Builds the nodes, policy, statistics, versions and attributes
       tables (plus their indexes) and returns the tables sorted in
       dependency order.
    """
    metadata = MetaData()

    # nodes: the path hierarchy; each node references its parent.
    nodes = Table(
        'nodes', metadata,
        Column('node', Integer, primary_key=True),
        Column('parent', Integer,
               ForeignKey('nodes.node',
                          ondelete='CASCADE',
                          onupdate='CASCADE'),
               autoincrement=False),
        Column('latest_version', Integer),
        Column('path', String(2048), default='', nullable=False),
        mysql_engine='InnoDB')
    Index('idx_nodes_path', nodes.c.path, unique=True)
    Index('idx_nodes_parent', nodes.c.parent)

    # policy: per-node key/value policy entries.
    Table(
        'policy', metadata,
        Column('node', Integer,
               ForeignKey('nodes.node',
                          ondelete='CASCADE',
                          onupdate='CASCADE'),
               primary_key=True),
        Column('key', String(128), primary_key=True),
        Column('value', String(256)),
        mysql_engine='InnoDB')

    # statistics: aggregate population/size/mtime per (node, cluster).
    Table(
        'statistics', metadata,
        Column('node', Integer,
               ForeignKey('nodes.node',
                          ondelete='CASCADE',
                          onupdate='CASCADE'),
               primary_key=True),
        Column('population', Integer, nullable=False, default=0),
        Column('size', BigInteger, nullable=False, default=0),
        Column('mtime', DECIMAL(precision=16, scale=6)),
        Column('cluster', Integer, nullable=False, default=0,
               primary_key=True, autoincrement=False),
        mysql_engine='InnoDB')

    # versions: object history, one row per version of a node.
    versions = Table(
        'versions', metadata,
        Column('serial', Integer, primary_key=True),
        Column('node', Integer,
               ForeignKey('nodes.node',
                          ondelete='CASCADE',
                          onupdate='CASCADE')),
        Column('hash', String(256)),
        Column('size', BigInteger, nullable=False, default=0),
        Column('type', String(256), nullable=False, default=''),
        Column('source', Integer),
        Column('mtime', DECIMAL(precision=16, scale=6)),
        Column('muser', String(256), nullable=False, default=''),
        Column('uuid', String(64), nullable=False, default=''),
        Column('checksum', String(256), nullable=False, default=''),
        Column('cluster', Integer, nullable=False, default=0),
        mysql_engine='InnoDB')
    Index('idx_versions_node_mtime', versions.c.node, versions.c.mtime)
    Index('idx_versions_node_uuid', versions.c.uuid)

    # attributes: per-version metadata grouped by domain.
    attributes = Table(
        'attributes', metadata,
        Column('serial', Integer,
               ForeignKey('versions.serial',
                          ondelete='CASCADE',
                          onupdate='CASCADE'),
               primary_key=True),
        Column('domain', String(256), primary_key=True),
        Column('key', String(128), primary_key=True),
        Column('value', String(256)),
        mysql_engine='InnoDB')
    Index('idx_attributes_domain', attributes.c.domain)

    metadata.create_all(engine)
    return metadata.sorted_tables
188

    
189

    
190
class Node(DBWorker):
191
    """Nodes store path organization and have multiple versions.
192
       Versions store object history and have multiple attributes.
193
       Attributes store metadata.
194
    """
195

    
196
    # TODO: Provide an interface for included and excluded clusters.
197

    
198
    def __init__(self, **params):
        """Reflect (or create) the schema and ensure the root node exists.

           Accepts the same keyword parameters as DBWorker.
        """
        DBWorker.__init__(self, **params)
        try:
            # Reflect the existing tables from the database.
            metadata = MetaData(self.engine)
            self.nodes = Table('nodes', metadata, autoload=True)
            self.policy = Table('policy', metadata, autoload=True)
            self.statistics = Table('statistics', metadata, autoload=True)
            self.versions = Table('versions', metadata, autoload=True)
            self.attributes = Table('attributes', metadata, autoload=True)
        except NoSuchTableError:
            # First run: create the schema and expose each created table
            # as an attribute named after it (nodes, policy, ...).
            tables = create_tables(self.engine)
            map(lambda t: self.__setattr__(t.name, t), tables)

        # Bootstrap: insert the root node (node == parent == ROOTNODE)
        # if it is not already present.
        s = self.nodes.select().where(and_(self.nodes.c.node == ROOTNODE,
                                           self.nodes.c.parent == ROOTNODE))
        wrapper = self.wrapper
        wrapper.execute()
        try:
            rp = self.conn.execute(s)
            r = rp.fetchone()
            rp.close()
            if not r:
                s = self.nodes.insert(
                ).values(node=ROOTNODE, parent=ROOTNODE, path='')
                self.conn.execute(s)
        finally:
            # Always close the wrapper's transaction, even on failure.
            wrapper.commit()
225

    
226
    def node_create(self, parent, path):
        """Create a new node from the given properties.
           Return the node identifier of the new node.
        """
        #TODO catch IntegrityError?
        rp = self.conn.execute(
            self.nodes.insert().values(parent=parent, path=path))
        node = rp.inserted_primary_key[0]
        rp.close()
        return node
236

    
237
    def node_lookup(self, path):
        """Lookup the current node of the given path.
           Return None if the path is not found.
        """
        # LIKE (with escaping) sidesteps MySQL's trailing-space equality.
        stmt = select([self.nodes.c.node],
                      self.nodes.c.path.like(self.escape_like(path),
                                             escape='\\'))
        rp = self.conn.execute(stmt)
        row = rp.fetchone()
        rp.close()
        return row[0] if row else None
251

    
252
    def node_lookup_bulk(self, paths):
        """Lookup the current nodes for the given paths.
           Return () if the path is not found.
        """
        if not paths:
            return ()
        # Exact IN() matching is fine here; escaping is not needed.
        stmt = select([self.nodes.c.node], self.nodes.c.path.in_(paths))
        rp = self.conn.execute(stmt)
        found = [row[0] for row in rp.fetchall()]
        rp.close()
        return found
265

    
266
    def node_get_properties(self, node):
        """Return the node's (parent, path).
           Return None if the node is not found.
        """
        stmt = select([self.nodes.c.parent, self.nodes.c.path],
                      self.nodes.c.node == node)
        rp = self.conn.execute(stmt)
        props = rp.fetchone()
        rp.close()
        return props
277

    
278
    def node_get_versions(self, node, keys=(), propnames=_propnames):
        """Return the properties of all versions at node.
           If keys is empty, return all properties in the order
           (serial, node, hash, size, type, source, mtime, muser,
            uuid, checksum, cluster).
        """
        c = self.versions.c
        stmt = select([c.serial, c.node, c.hash, c.size, c.type, c.source,
                       c.mtime, c.muser, c.uuid, c.checksum, c.cluster],
                      c.node == node).order_by(c.serial)
        rp = self.conn.execute(stmt)
        rows = rp.fetchall()
        rp.close()
        if not rows or not keys:
            return rows
        # Project each row down to the requested (known) keys, in order.
        return [[row[propnames[k]] for k in keys if k in propnames]
                for row in rows]
306

    
307
    def node_count_children(self, node):
        """Return node's child count."""
        # The root is its own parent; exclude it from its own children.
        stmt = select([func.count(self.nodes.c.node)]).where(
            and_(self.nodes.c.parent == node,
                 self.nodes.c.node != ROOTNODE))
        rp = self.conn.execute(stmt)
        count = rp.fetchone()[0]
        rp.close()
        return count
317

    
318
    def node_purge_children(self, parent, before=inf, cluster=0):
        """Delete all versions with the specified
           parent and cluster, and return
           the hashes, the total size and the serials of versions deleted.
           Clears out nodes with no remaining versions.
        """
        #update statistics
        # All children of the given parent node.
        c1 = select([self.nodes.c.node],
                    self.nodes.c.parent == parent)
        where_clause = and_(self.versions.c.node.in_(c1),
                            self.versions.c.cluster == cluster)
        if before != inf:
            # Only purge versions modified up to the cutoff.
            where_clause = and_(where_clause,
                                self.versions.c.mtime <= before)
        s = select([func.count(self.versions.c.serial),
                    func.sum(self.versions.c.size)])
        s = s.where(where_clause)
        r = self.conn.execute(s)
        row = r.fetchone()
        r.close()
        if not row:
            # Defensive: aggregate queries normally always return a row.
            return (), 0, ()
        # SUM() is NULL when nothing matched; the conditional applies to
        # row[1] (size) only -- row[0] is the count.
        nr, size = row[0], row[1] if row[1] else 0
        mtime = time()
        # Subtract the purged totals from the parent and its ancestors.
        self.statistics_update(parent, -nr, -size, mtime, cluster)
        self.statistics_update_ancestors(parent, -nr, -size, mtime, cluster)

        # Collect the hashes and serials being removed for the caller.
        s = select([self.versions.c.hash, self.versions.c.serial])
        s = s.where(where_clause)
        r = self.conn.execute(s)
        hashes = []
        serials = []
        for row in r.fetchall():
            hashes += [row[0]]
            serials += [row[1]]
        r.close()

        #delete versions
        s = self.versions.delete().where(where_clause)
        r = self.conn.execute(s)
        r.close()

        #delete nodes
        # Remove child nodes that are left with no versions at all.
        s = select([self.nodes.c.node],
                   and_(self.nodes.c.parent == parent,
                        select([func.count(self.versions.c.serial)],
                               self.versions.c.node == self.nodes.c.node).as_scalar() == 0))
        rp = self.conn.execute(s)
        nodes = [r[0] for r in rp.fetchall()]
        rp.close()
        if nodes:
            s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
            self.conn.execute(s).close()

        return hashes, size, serials
373

    
374
    def node_purge(self, node, before=inf, cluster=0):
        """Delete all versions with the specified
           node and cluster, and return
           the hashes and size of versions deleted.
           Clears out the node if it has no remaining versions.
        """

        #update statistics
        s = select([func.count(self.versions.c.serial),
                    func.sum(self.versions.c.size)])
        where_clause = and_(self.versions.c.node == node,
                            self.versions.c.cluster == cluster)
        if before != inf:
            # Only purge versions modified up to the cutoff.
            where_clause = and_(where_clause,
                                self.versions.c.mtime <= before)
        s = s.where(where_clause)
        r = self.conn.execute(s)
        row = r.fetchone()
        nr, size = row[0], row[1]
        r.close()
        if not nr:
            # Nothing matched (size is NULL in that case) -- done.
            return (), 0, ()
        mtime = time()
        # Subtract the purged totals from the node's ancestors.
        self.statistics_update_ancestors(node, -nr, -size, mtime, cluster)

        # Collect the hashes and serials being removed for the caller.
        s = select([self.versions.c.hash, self.versions.c.serial])
        s = s.where(where_clause)
        r = self.conn.execute(s)
        hashes = []
        serials = []
        for row in r.fetchall():
            hashes += [row[0]]
            serials += [row[1]]
        r.close()

        #delete versions
        s = self.versions.delete().where(where_clause)
        r = self.conn.execute(s)
        r.close()

        #delete nodes
        # Remove the node itself if it is left with no versions.
        s = select([self.nodes.c.node],
                   and_(self.nodes.c.node == node,
                        select([func.count(self.versions.c.serial)],
                               self.versions.c.node == self.nodes.c.node).as_scalar() == 0))
        rp = self.conn.execute(s)
        nodes = [r[0] for r in rp.fetchall()]
        rp.close()
        if nodes:
            s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
            self.conn.execute(s).close()

        return hashes, size, serials
427

    
428
    def node_remove(self, node):
        """Remove the node specified.
           Return false if the node has children or is not found.
        """
        if self.node_count_children(node):
            return False

        mtime = time()
        # Roll the node's per-cluster totals out of its ancestors before
        # deleting it.
        stmt = select([func.count(self.versions.c.serial),
                       func.sum(self.versions.c.size),
                       self.versions.c.cluster],
                      self.versions.c.node == node)
        stmt = stmt.group_by(self.versions.c.cluster)
        rp = self.conn.execute(stmt)
        for population, size, cluster in rp.fetchall():
            self.statistics_update_ancestors(
                node, -population, -size, mtime, cluster)
        rp.close()

        self.conn.execute(
            self.nodes.delete().where(self.nodes.c.node == node)).close()
        return True
451

    
452
    def node_accounts(self, accounts=()):
        """Return (path, node) rows for all first-level (account) nodes,
           optionally restricted to the given account paths.
        """
        s = select([self.nodes.c.path, self.nodes.c.node])
        # Accounts are the direct children of the root node.  Use the
        # ROOTNODE constant instead of a literal 0 for consistency with
        # the rest of this module (same value, same behavior).
        s = s.where(and_(self.nodes.c.node != ROOTNODE,
                         self.nodes.c.parent == ROOTNODE))
        if accounts:
            s = s.where(self.nodes.c.path.in_(accounts))
        r = self.conn.execute(s)
        rows = r.fetchall()
        r.close()
        return rows
462

    
463
    def node_account_usage(self, account_node, cluster):
        """Return the summed version size in the given cluster for the
           account node's containers and their direct children.
        """
        # First level below the account: its containers.
        children = select([self.nodes.c.node],
                          self.nodes.c.parent == account_node)
        # The containers themselves plus everything directly under them.
        descendants = select([self.nodes.c.node],
                             or_(self.nodes.c.parent.in_(children),
                                 self.nodes.c.node.in_(children)))
        stmt = select([func.sum(self.versions.c.size)])
        stmt = stmt.group_by(self.versions.c.cluster)
        stmt = stmt.where(self.nodes.c.node == self.versions.c.node)
        stmt = stmt.where(self.nodes.c.node.in_(descendants))
        stmt = stmt.where(self.versions.c.cluster == cluster)
        rp = self.conn.execute(stmt)
        usage = rp.fetchone()[0]
        rp.close()
        return usage
478

    
479
    def policy_get(self, node):
        """Return the node's policy entries as a {key: value} dict."""
        stmt = select([self.policy.c.key, self.policy.c.value],
                      self.policy.c.node == node)
        rp = self.conn.execute(stmt)
        policy = dict(rp.fetchall())
        rp.close()
        return policy
486

    
487
    def policy_set(self, node, policy):
        """Insert or update the given {key: value} policy entries."""
        #insert or replace
        for key, value in policy.iteritems():
            # Try an UPDATE first; fall back to INSERT when no row matched.
            u = self.policy.update().where(
                and_(self.policy.c.node == node,
                     self.policy.c.key == key)).values(value=value)
            rp = self.conn.execute(u)
            rp.close()
            if rp.rowcount == 0:
                self.conn.execute(
                    self.policy.insert(),
                    {'node': node, 'key': key, 'value': value}).close()
500

    
501
    def statistics_get(self, node, cluster=0):
        """Return population, total size and last mtime
           for all versions under node that belong to the cluster.
        """
        stmt = select([self.statistics.c.population,
                       self.statistics.c.size,
                       self.statistics.c.mtime],
                      and_(self.statistics.c.node == node,
                           self.statistics.c.cluster == cluster))
        rp = self.conn.execute(stmt)
        row = rp.fetchone()
        rp.close()
        return row
515

    
516
    def statistics_update(self, node, population, size, mtime, cluster=0):
        """Update the statistics of the given node.
           Statistics keep track the population, total
           size of objects and mtime in the node's namespace.
           May be zero or positive or negative numbers.
        """
        # Read the current counters for this (node, cluster), if any.
        s = select([self.statistics.c.population, self.statistics.c.size],
                   and_(self.statistics.c.node == node,
                        self.statistics.c.cluster == cluster))
        rp = self.conn.execute(s)
        r = rp.fetchone()
        rp.close()
        if not r:
            prepopulation, presize = (0, 0)
        else:
            prepopulation, presize = r
        population += prepopulation
        # Population is clamped at zero; the size delta is applied as-is.
        population = max(population, 0)
        size += presize

        #insert or replace
        #TODO better upsert
        u = self.statistics.update().where(and_(self.statistics.c.node == node,
                                           self.statistics.c.cluster == cluster))
        u = u.values(population=population, size=size, mtime=mtime)
        rp = self.conn.execute(u)
        rp.close()
        # NOTE(review): rowcount is read after close() here (and in
        # policy_set) -- confirm this is reliable for the SQLAlchemy
        # version in use.  0 means the row did not exist yet.
        if rp.rowcount == 0:
            ins = self.statistics.insert()
            ins = ins.values(node=node, population=population, size=size,
                             mtime=mtime, cluster=cluster)
            self.conn.execute(ins).close()
548

    
549
    def statistics_update_ancestors(self, node, population, size, mtime, cluster=0):
        """Update the statistics of the given node's parent.
           Then recursively update all parents up to the root.
           Population is not recursive.
        """
        while node != ROOTNODE:
            props = self.node_get_properties(node)
            if props is None:
                break
            parent, _path = props
            self.statistics_update(parent, population, size, mtime, cluster)
            node = parent
            # Only the immediate parent receives the population delta.
            population = 0
565

    
566
    def statistics_latest(self, node, before=inf, except_cluster=0):
        """Return population, total size and last mtime
           for all latest versions under node that
           do not belong to the cluster.

           Returns None if the node or its latest version is not found.
        """

        # The node.
        props = self.node_get_properties(node)
        if props is None:
            return None
        parent, path = props

        # The latest version of the node itself.
        s = select([self.versions.c.serial,
                    self.versions.c.node,
                    self.versions.c.hash,
                    self.versions.c.size,
                    self.versions.c.type,
                    self.versions.c.source,
                    self.versions.c.mtime,
                    self.versions.c.muser,
                    self.versions.c.uuid,
                    self.versions.c.checksum,
                    self.versions.c.cluster])
        if before != inf:
            filtered = select([func.max(self.versions.c.serial)],
                              self.versions.c.node == node)
            filtered = filtered.where(self.versions.c.mtime < before)
        else:
            filtered = select([self.nodes.c.latest_version],
                              self.versions.c.node == node)
        s = s.where(and_(self.versions.c.cluster != except_cluster,
                         self.versions.c.serial == filtered))
        r = self.conn.execute(s)
        props = r.fetchone()
        r.close()
        if not props:
            return None
        mtime = props[MTIME]

        # First level, just under node (get population).
        v = self.versions.alias('v')
        s = select([func.count(v.c.serial),
                    func.sum(v.c.size),
                    func.max(v.c.mtime)])
        if before != inf:
            c1 = select([func.max(self.versions.c.serial)])
            c1 = c1.where(self.versions.c.mtime < before)
            # BUG FIX: where() is generative and returns a new select;
            # the original discarded its result here, leaving the
            # subquery uncorrelated with the outer alias.
            c1 = c1.where(self.versions.c.node == v.c.node)
        else:
            c1 = select([self.nodes.c.latest_version])
            # BUG FIX: same as above -- keep the correlation clause.
            c1 = c1.where(self.nodes.c.node == v.c.node)
        c2 = select([self.nodes.c.node], self.nodes.c.parent == node)
        s = s.where(and_(v.c.serial == c1,
                         v.c.cluster != except_cluster,
                         v.c.node.in_(c2)))
        rp = self.conn.execute(s)
        r = rp.fetchone()
        rp.close()
        if not r:
            return None
        count = r[0]
        mtime = max(mtime, r[2])
        if count == 0:
            return (0, 0, mtime)

        # All children (get size and mtime).
        # This is why the full path is stored.
        s = select([func.count(v.c.serial),
                    func.sum(v.c.size),
                    func.max(v.c.mtime)])
        if before != inf:
            c1 = select([func.max(self.versions.c.serial)],
                        self.versions.c.node == v.c.node)
            c1 = c1.where(self.versions.c.mtime < before)
        else:
            # BUG FIX: the nodes table has no `serial` column (see
            # create_tables); the latest-version pointer lives in
            # nodes.latest_version, as used by version_lookup.
            c1 = select([self.nodes.c.latest_version],
                        self.nodes.c.node == v.c.node)
        c2 = select([self.nodes.c.node], self.nodes.c.path.like(
            self.escape_like(path) + '%', escape='\\'))
        s = s.where(and_(v.c.serial == c1,
                         v.c.cluster != except_cluster,
                         v.c.node.in_(c2)))
        rp = self.conn.execute(s)
        r = rp.fetchone()
        rp.close()
        if not r:
            return None
        size = r[1] - props[SIZE]
        mtime = max(mtime, r[2])
        return (count, size, mtime)
657

    
658
    def nodes_set_latest_version(self, node, serial):
        """Point the node's latest_version at the given serial."""
        u = self.nodes.update().where(self.nodes.c.node == node).values(
            latest_version=serial)
        self.conn.execute(u).close()
662

    
663
    def version_create(self, node, hash, size, type, source, muser, uuid, checksum, cluster=0):
        """Create a new version from the given properties.
           Return the (serial, mtime) of the new version.
        """
        mtime = time()
        ins = self.versions.insert().values(
            node=node, hash=hash, size=size, type=type, source=source,
            mtime=mtime, muser=muser, uuid=uuid, checksum=checksum,
            cluster=cluster)
        serial = self.conn.execute(ins).inserted_primary_key[0]
        # Account the new version in the ancestors' statistics and make
        # it the node's current version.
        self.statistics_update_ancestors(node, 1, size, mtime, cluster)
        self.nodes_set_latest_version(node, serial)
        return serial, mtime
678

    
679
    def version_lookup(self, node, before=inf, cluster=0, all_props=True):
        """Lookup the current version of the given node.
           Return a list with its properties:
           (serial, node, hash, size, type, source, mtime,
            muser, uuid, checksum, cluster)
           or None if the current version is not found in the given cluster.
        """
        v = self.versions.alias('v')
        if all_props:
            cols = [v.c.serial, v.c.node, v.c.hash,
                    v.c.size, v.c.type, v.c.source,
                    v.c.mtime, v.c.muser, v.c.uuid,
                    v.c.checksum, v.c.cluster]
        else:
            cols = [v.c.serial]
        s = select(cols)
        if before != inf:
            # Highest serial among versions modified before the cutoff.
            c = select([func.max(self.versions.c.serial)],
                       self.versions.c.node == node)
            c = c.where(self.versions.c.mtime < before)
        else:
            # Fast path: the node row caches its latest version.
            c = select([self.nodes.c.latest_version],
                       self.nodes.c.node == node)
        s = s.where(and_(v.c.serial == c,
                         v.c.cluster == cluster))
        rp = self.conn.execute(s)
        props = rp.fetchone()
        rp.close()
        return props if props else None
710

    
711
    def version_lookup_bulk(self, nodes, before=inf, cluster=0, all_props=True):
        """Lookup the current versions of the given nodes.
           Return a list with their properties:
           (serial, node, hash, size, type, source, mtime, muser, uuid,
            checksum, cluster).
        """
        if not nodes:
            return ()
        v = self.versions.alias('v')
        if all_props:
            cols = [v.c.serial, v.c.node, v.c.hash,
                    v.c.size, v.c.type, v.c.source,
                    v.c.mtime, v.c.muser, v.c.uuid,
                    v.c.checksum, v.c.cluster]
        else:
            cols = [v.c.serial]
        s = select(cols)
        if before != inf:
            # Highest serial per node among versions before the cutoff.
            c = select([func.max(self.versions.c.serial)],
                       self.versions.c.node.in_(nodes))
            c = c.where(self.versions.c.mtime < before)
            c = c.group_by(self.versions.c.node)
        else:
            c = select([self.nodes.c.latest_version],
                       self.nodes.c.node.in_(nodes))
        s = s.where(and_(v.c.serial.in_(c),
                         v.c.cluster == cluster))
        s = s.order_by(v.c.node)
        rp = self.conn.execute(s)
        rows = rp.fetchall()
        rp.close()
        return (tuple(row.values()) for row in rows)
741

    
742
    def version_get_properties(self, serial, keys=(), propnames=_propnames):
        """Return a sequence of values for the properties of
           the version specified by serial and the keys, in the order given.
           If keys is empty, return all properties in the order
           (serial, node, hash, size, type, source, mtime, muser, uuid,
            checksum, cluster).
        """
        v = self.versions.alias()
        s = select([v.c.serial, v.c.node, v.c.hash,
                    v.c.size, v.c.type, v.c.source,
                    v.c.mtime, v.c.muser, v.c.uuid,
                    v.c.checksum, v.c.cluster], v.c.serial == serial)
        rp = self.conn.execute(s)
        row = rp.fetchone()
        rp.close()
        if row is None or not keys:
            return row
        # Project down to the requested (known) keys, in the given order.
        return [row[propnames[k]] for k in keys if k in propnames]
763

    
764
    def version_put_property(self, serial, key, value):
        """Set value for the property of version specified by key."""
        # Silently ignore unknown property names.
        if key not in _propnames:
            return
        u = (self.versions.update()
             .where(self.versions.c.serial == serial)
             .values(**{key: value}))
        self.conn.execute(u).close()
773

    
774
    def version_recluster(self, serial, cluster):
        """Move the version into another cluster."""
        props = self.version_get_properties(serial)
        if not props:
            return
        node = props[NODE]
        size = props[SIZE]
        oldcluster = props[CLUSTER]
        if cluster == oldcluster:
            return  # Nothing to move.

        # Account the version out of the old cluster and into the new one.
        now = time()
        self.statistics_update_ancestors(node, -1, -size, now, oldcluster)
        self.statistics_update_ancestors(node, 1, size, now, cluster)

        u = (self.versions.update()
             .where(self.versions.c.serial == serial)
             .values(cluster=cluster))
        self.conn.execute(u).close()
794

    
795
    def version_remove(self, serial):
        """Remove the version specified by serial.

        Update ancestor usage statistics, repoint the node's
        latest_version to the remaining newest version (if any),
        and return the (hash, size) of the removed version.
        Return None when the serial does not exist.
        """
        props = self.version_get_properties(serial)
        if not props:
            return
        node = props[NODE]
        hash = props[HASH]
        size = props[SIZE]
        cluster = props[CLUSTER]

        mtime = time()
        self.statistics_update_ancestors(node, -1, -size, mtime, cluster)

        s = self.versions.delete().where(self.versions.c.serial == serial)
        self.conn.execute(s).close()

        # Repoint latest_version at the node's remaining newest version.
        # BUGFIX: the previous code passed the just-deleted `serial` here,
        # leaving nodes.latest_version dangling on a removed row; use the
        # serial of the version found by the lookup instead.
        props = self.version_lookup(node, cluster=cluster, all_props=False)
        if props:
            self.nodes_set_latest_version(node, props[0])

        return hash, size
817

    
818
    def attribute_get(self, serial, domain, keys=()):
        """Return a list of (key, value) pairs of the version specified by serial.

        If keys is empty, return all attributes.
        Otherwise, return only those specified.
        """
        attrs = self.attributes.alias()
        s = select([attrs.c.key, attrs.c.value])
        s = s.where(attrs.c.serial == serial)
        s = s.where(attrs.c.domain == domain)
        if keys:
            # Restrict to the requested keys only.
            s = s.where(attrs.c.key.in_(keys))
        rp = self.conn.execute(s)
        pairs = rp.fetchall()
        rp.close()
        return pairs
839

    
840
    def attribute_set(self, serial, domain, items):
        """Set the attributes of the version specified by serial.

        Receive attributes as an iterable of (key, value) pairs.
        """
        # Insert or replace each pair.
        # TODO: use a real database-side upsert.
        for k, v in items:
            s = self.attributes.update()
            s = s.where(and_(self.attributes.c.serial == serial,
                             self.attributes.c.domain == domain,
                             self.attributes.c.key == k))
            s = s.values(value=v)
            rp = self.conn.execute(s)
            # BUGFIX: read rowcount BEFORE closing the result; per PEP 249
            # the cursor attribute may be unavailable after close().
            updated = rp.rowcount
            rp.close()
            if updated == 0:
                # No existing row was updated: insert a fresh attribute.
                s = self.attributes.insert()
                s = s.values(serial=serial, domain=domain, key=k, value=v)
                self.conn.execute(s).close()
858

    
859
    def attribute_del(self, serial, domain, keys=()):
        """Delete attributes of the version specified by serial.

        If keys is empty, delete all attributes.
        Otherwise delete those specified.
        """
        # Common restriction shared by both branches.
        base = and_(self.attributes.c.serial == serial,
                    self.attributes.c.domain == domain)
        if keys:
            # TODO: more efficient way to do this?
            for key in keys:
                d = self.attributes.delete().where(
                    and_(base, self.attributes.c.key == key))
                self.conn.execute(d).close()
        else:
            d = self.attributes.delete().where(base)
            self.conn.execute(d).close()
878

    
879
    def attribute_copy(self, source, dest):
        """Copy every attribute of version *source* onto version *dest*.

        Existing (domain, key) attributes on dest are overwritten;
        missing ones are inserted.
        """
        # Select dest as a literal column so each fetched row already
        # carries the target serial alongside (domain, key, value).
        s = select(
            [dest, self.attributes.c.domain,
                self.attributes.c.key, self.attributes.c.value],
            self.attributes.c.serial == source)
        rp = self.conn.execute(s)
        attributes = rp.fetchall()
        rp.close()
        for dest, domain, k, v in attributes:
            # Insert or replace (no database-side upsert used here).
            s = self.attributes.update().where(and_(
                self.attributes.c.serial == dest,
                self.attributes.c.domain == domain,
                self.attributes.c.key == k))
            rp = self.conn.execute(s, value=v)
            # BUGFIX: read rowcount BEFORE closing the result; per PEP 249
            # the cursor attribute may be unavailable after close().
            updated = rp.rowcount
            rp.close()
            if updated == 0:
                s = self.attributes.insert()
                values = {'serial': dest, 'domain': domain,
                          'key': k, 'value': v}
                self.conn.execute(s, values).close()
900

    
901
    def latest_attribute_keys(self, parent, domain, before=inf, except_cluster=0, pathq=None):
        """Return a list with all keys pairs defined
        for all latest versions under parent that
        do not belong to the cluster.
        """
        pathq = pathq or []

        # TODO: Use another table to store before=inf results.
        a = self.attributes.alias('a')
        v = self.versions.alias('v')
        n = self.nodes.alias('n')
        s = select([a.c.key]).distinct()

        # Correlated subquery selecting the "current" serial for each node:
        # newest version before the cutoff, or the latest_version pointer.
        if before != inf:
            latest = select([func.max(self.versions.c.serial)])
            latest = latest.where(self.versions.c.mtime < before)
            latest = latest.where(self.versions.c.node == v.c.node)
        else:
            latest = select([self.nodes.c.latest_version])
            latest = latest.where(self.nodes.c.node == v.c.node)

        s = s.where(v.c.serial == latest)
        s = s.where(v.c.cluster != except_cluster)
        s = s.where(v.c.node.in_(select([self.nodes.c.node],
                                        self.nodes.c.parent == parent)))
        s = s.where(a.c.serial == v.c.serial)
        s = s.where(a.c.domain == domain)
        s = s.where(n.c.node == v.c.node)

        # Optional path restrictions; prefix/exact matches are OR-ed.
        path_clauses = []
        for path, match in pathq:
            if match == MATCH_PREFIX:
                path_clauses.append(
                    n.c.path.like(self.escape_like(path) + '%', escape='\\'))
            elif match == MATCH_EXACT:
                path_clauses.append(n.c.path == path)
        if path_clauses:
            s = s.where(or_(*path_clauses))

        rp = self.conn.execute(s)
        rows = rp.fetchall()
        rp.close()
        return [row[0] for row in rows]
941

    
942
    def latest_version_list(self, parent, prefix='', delimiter=None,
                            start='', limit=10000, before=inf,
                            except_cluster=0, pathq=[], domain=None,
                            filterq=[], sizeq=None, all_props=False):
        """Return a (list of (path, serial) tuples, list of common prefixes)
           for the current versions of the paths with the given parent,
           matching the following criteria.

           The property tuple for a version is returned if all
           of these conditions are true:

                a. parent matches

                b. path > start

                c. path starts with prefix (and paths in pathq)

                d. version is the max up to before

                e. version is not in cluster

                f. the path does not have the delimiter occurring
                   after the prefix, or ends with the delimiter

                g. serial matches the attribute filter query.

                   A filter query is a comma-separated list of
                   terms in one of these three forms:

                   key
                       an attribute with this key must exist

                   !key
                       an attribute with this key must not exist

                   key ?op value
                       the attribute with this key satisfies the value
                       where ?op is one of ==, !=, <=, >=, <, >.

                h. the size is in the range set by sizeq

           The list of common prefixes includes the prefixes
           matching up to the first delimiter after prefix,
           and are reported only once, as "virtual directories".
           The delimiter is included in the prefixes.

           If arguments are None, then the corresponding matching rule
           will always match.

           Limit applies to the first list of tuples returned.

           If all_props is True, return all properties after path, not just serial.
        """
        # Compute the (start, nextling) bounds that confine results to
        # paths sharing the given prefix.
        if not start or start < prefix:
            start = strprevling(prefix)
        nextling = strnextling(prefix)

        v = self.versions.alias('v')
        n = self.nodes.alias('n')
        if not all_props:
            s = select([n.c.path, v.c.serial]).distinct()
        else:
            s = select([n.c.path,
                        v.c.serial, v.c.node, v.c.hash,
                        v.c.size, v.c.type, v.c.source,
                        v.c.mtime, v.c.muser, v.c.uuid,
                        v.c.checksum, v.c.cluster]).distinct()
        # Correlated subquery picking the "current" serial per node:
        # newest version before the cutoff, or the latest_version pointer.
        if before != inf:
            filtered = select([func.max(self.versions.c.serial)])
            filtered = filtered.where(self.versions.c.mtime < before)
        else:
            filtered = select([self.nodes.c.latest_version])
        s = s.where(
            v.c.serial == filtered.where(self.nodes.c.node == v.c.node))
        s = s.where(v.c.cluster != except_cluster)
        s = s.where(v.c.node.in_(select([self.nodes.c.node],
                                        self.nodes.c.parent == parent)))

        s = s.where(n.c.node == v.c.node)
        # 'start' stays a bind parameter so the statement can be re-executed
        # below with a new starting path during delimiter pagination.
        s = s.where(and_(n.c.path > bindparam('start'), n.c.path < nextling))
        conj = []
        for path, match in pathq:
            if match == MATCH_PREFIX:
                conj.append(
                    n.c.path.like(self.escape_like(path) + '%', escape='\\'))
            elif match == MATCH_EXACT:
                conj.append(n.c.path == path)
        if conj:
            # Any of the pathq restrictions may match (OR semantics).
            s = s.where(or_(*conj))

        # Optional size range filter: [sizeq[0], sizeq[1]).
        if sizeq and len(sizeq) == 2:
            if sizeq[0]:
                s = s.where(v.c.size >= sizeq[0])
            if sizeq[1]:
                s = s.where(v.c.size < sizeq[1])

        # Attribute filter query (rule g): each term becomes a correlated
        # EXISTS / NOT EXISTS subquery against the attributes table.
        if domain and filterq:
            a = self.attributes.alias('a')
            included, excluded, opers = parse_filters(filterq)
            if included:
                subs = select([1])
                subs = subs.where(a.c.serial == v.c.serial).correlate(v)
                subs = subs.where(a.c.domain == domain)
                subs = subs.where(or_(*[a.c.key.op('=')(x) for x in included]))
                s = s.where(exists(subs))
            if excluded:
                subs = select([1])
                subs = subs.where(a.c.serial == v.c.serial).correlate(v)
                subs = subs.where(a.c.domain == domain)
                subs = subs.where(or_(*[a.c.key.op('=')(x) for x in excluded]))
                s = s.where(not_(exists(subs)))
            if opers:
                for k, o, val in opers:
                    subs = select([1])
                    subs = subs.where(a.c.serial == v.c.serial).correlate(v)
                    subs = subs.where(a.c.domain == domain)
                    subs = subs.where(
                        and_(a.c.key.op('=')(k), a.c.value.op(o)(val)))
                    s = s.where(exists(subs))

        s = s.order_by(n.c.path)

        # Without a delimiter there is no virtual-directory folding:
        # just apply the limit and return the rows directly.
        if not delimiter:
            s = s.limit(limit)
            rp = self.conn.execute(s, start=start)
            r = rp.fetchall()
            rp.close()
            return r, ()

        pfz = len(prefix)
        dz = len(delimiter)
        count = 0
        prefixes = []
        pappend = prefixes.append
        matches = []
        mappend = matches.append

        rp = self.conn.execute(s, start=start)
        while True:
            props = rp.fetchone()
            if props is None:
                break
            path = props[0]
            serial = props[1]  # NOTE(review): unused below; kept as-is.
            idx = path.find(delimiter, pfz)

            # No delimiter after the prefix: a plain object match.
            if idx < 0:
                mappend(props)
                count += 1
                if count >= limit:
                    break
                continue

            # Path ends exactly with the delimiter: count it as a match
            # but keep scanning, in case a real path follows it.
            if idx + dz == len(path):
                mappend(props)
                count += 1
                continue  # Get one more, in case there is a path.
            # Delimiter inside the path: record the common prefix once and
            # skip past everything sharing it by restarting the query.
            pf = path[:idx + dz]
            pappend(pf)
            if count >= limit:
                break

            # NOTE(review): the previous result proxy is rebound without
            # being closed first — confirm whether this leaks a cursor.
            rp = self.conn.execute(s, start=strnextling(pf))  # New start.
        rp.close()

        return matches, prefixes
1109

    
1110
    def latest_uuid(self, uuid, cluster):
        """Return the latest version of the given uuid and cluster.

        Return a (path, serial) tuple.
        If cluster is None, all clusters are considered.
        """
        v = self.versions.alias('v')
        n = self.nodes.alias('n')

        # Scalar subquery: the greatest serial carrying this uuid
        # (optionally restricted to one cluster).
        latest = select([func.max(self.versions.c.serial)])
        latest = latest.where(self.versions.c.uuid == uuid)
        if cluster is not None:
            latest = latest.where(self.versions.c.cluster == cluster)

        s = select([n.c.path, v.c.serial])
        s = s.where(v.c.serial == latest)
        s = s.where(n.c.node == v.c.node)

        rp = self.conn.execute(s)
        row = rp.fetchone()
        rp.close()
        return row
1132

    
1133
    def domain_object_list(self, domain, cluster=None):
        """Return a list of (path, property list, attribute dictionary)
        for the objects in the specific domain and cluster.
        """
        v = self.versions.alias('v')
        n = self.nodes.alias('n')
        a = self.attributes.alias('a')

        s = select([n.c.path, v.c.serial, v.c.node, v.c.hash, v.c.size,
                    v.c.type, v.c.source, v.c.mtime, v.c.muser, v.c.uuid,
                    v.c.checksum, v.c.cluster, a.c.key, a.c.value])
        s = s.where(n.c.node == v.c.node)
        s = s.where(n.c.latest_version == v.c.serial)
        # NOTE(review): truthiness check, so cluster=0 (as well as None)
        # disables the filter — confirm this is intended.
        if cluster:
            s = s.where(v.c.cluster == cluster)
        s = s.where(v.c.serial == a.c.serial)
        s = s.where(a.c.domain == domain)

        rp = self.conn.execute(s)
        rows = rp.fetchall()
        rp.close()

        # One row per attribute; fold rows sharing the same version
        # (the first 12 columns) into one entry with an attribute dict.
        keyfunc = itemgetter(slice(12))
        rows.sort(key=keyfunc)
        result = []
        for props, group in groupby(rows, keyfunc):
            attrs = dict(row[12:] for row in group)
            result.append((props[0], props[1:], attrs))
        return result