Statistics
| Branch: | Tag: | Revision:

root / snf-pithos-backend / pithos / backends / lib / sqlalchemy / node.py @ 0f510652

History | View | Annotate | Download (44.6 kB)

1
# Copyright 2011-2012 GRNET S.A. All rights reserved.
2
#
3
# Redistribution and use in source and binary forms, with or
4
# without modification, are permitted provided that the following
5
# conditions are met:
6
#
7
#   1. Redistributions of source code must retain the above
8
#      copyright notice, this list of conditions and the following
9
#      disclaimer.
10
#
11
#   2. Redistributions in binary form must reproduce the above
12
#      copyright notice, this list of conditions and the following
13
#      disclaimer in the documentation and/or other materials
14
#      provided with the distribution.
15
#
16
# THIS SOFTWARE IS PROVIDED BY GRNET S.A. ``AS IS'' AND ANY EXPRESS
17
# OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL GRNET S.A OR
20
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
23
# USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
24
# AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
26
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27
# POSSIBILITY OF SUCH DAMAGE.
28
#
29
# The views and conclusions contained in the software and
30
# documentation are those of the authors and should not be
31
# interpreted as representing official policies, either expressed
32
# or implied, of GRNET S.A.
33

    
34
from time import time
35
from operator import itemgetter
36
from itertools import groupby
37

    
38
from sqlalchemy import Table, Integer, BigInteger, DECIMAL, Column, String, MetaData, ForeignKey
39
from sqlalchemy.types import Text
40
from sqlalchemy.schema import Index, Sequence
41
from sqlalchemy.sql import func, and_, or_, not_, null, select, bindparam, text, exists
42
from sqlalchemy.ext.compiler import compiles
43
from sqlalchemy.engine.reflection import Inspector
44
from sqlalchemy.exc import NoSuchTableError
45

    
46
from dbworker import DBWorker, ESCAPE_CHAR
47

    
48
from pithos.backends.filter import parse_filters
49

    
50

    
51
ROOTNODE = 0
52

    
53
(SERIAL, NODE, HASH, SIZE, TYPE, SOURCE, MTIME, MUSER, UUID, CHECKSUM,
54
 CLUSTER) = range(11)
55

    
56
(MATCH_PREFIX, MATCH_EXACT) = range(2)
57

    
58
inf = float('inf')
59

    
60

    
61
def strnextling(prefix):
    """Return the first unicode string
       greater than but not starting with given prefix.
       strnextling('hello') -> 'hellp'

       Raises RuntimeError if the last character of prefix is already
       at (or beyond) 0xffff and cannot be incremented.
    """
    if not prefix:
        ## all strings start with the null string,
        ## therefore we have to approximate strnextling('')
        ## with the last unicode character supported by python
        ## 0x10ffff for wide (32-bit unicode) python builds
        ## 0x00ffff for narrow (16-bit unicode) python builds
        ## We will not autodetect. 0xffff is safe enough.
        return u'\uffff'
    s = prefix[:-1]
    c = ord(prefix[-1])
    if c >= 0xffff:
        raise RuntimeError
    # Use u'%c' formatting instead of unichr() so this also runs on
    # Python 3 (where unichr was removed); on Python 2 the u'' literal
    # keeps the result unicode, exactly as unichr() did.
    s += u'%c' % (c + 1)
    return s
80

    
81

    
82
def strprevling(prefix):
    """Return an approximation of the last unicode string
       less than but not starting with given prefix.
       strprevling(u'hello') -> u'helln\\xffff'
    """
    if not prefix:
        ## There is no prevling for the null string
        return prefix
    s = prefix[:-1]
    c = ord(prefix[-1])
    if c > 0:
        # u'%c' formatting replaces unichr() for Python 3 compatibility;
        # the output is identical on Python 2. A NUL last character is
        # left alone: there is nothing smaller to decrement to.
        s += u'%c' % (c - 1) + u'\uffff'
    return s
95

    
96
_propnames = {
97
    'serial': 0,
98
    'node': 1,
99
    'hash': 2,
100
    'size': 3,
101
    'type': 4,
102
    'source': 5,
103
    'mtime': 6,
104
    'muser': 7,
105
    'uuid': 8,
106
    'checksum': 9,
107
    'cluster': 10
108
}
109

    
110

    
111
def create_tables(engine):
    """Create the pithos node schema (nodes, policy, statistics,
       versions, attributes) on the given engine and return the table
       objects sorted in foreign-key dependency order.
    """
    metadata = MetaData()

    #create nodes table
    nodes = Table(
        'nodes', metadata,
        Column('node', Integer, primary_key=True),
        Column('parent', Integer,
               ForeignKey('nodes.node',
                          ondelete='CASCADE',
                          onupdate='CASCADE'),
               autoincrement=False),
        Column('latest_version', Integer),
        Column('path', String(2048), default='', nullable=False),
        mysql_engine='InnoDB')
    Index('idx_nodes_path', nodes.c.path, unique=True)
    Index('idx_nodes_parent', nodes.c.parent)

    #create policy table
    Table('policy', metadata,
          Column('node', Integer,
                 ForeignKey('nodes.node',
                            ondelete='CASCADE',
                            onupdate='CASCADE'),
                 primary_key=True),
          Column('key', String(128), primary_key=True),
          Column('value', String(256)),
          mysql_engine='InnoDB')

    #create statistics table
    Table('statistics', metadata,
          Column('node', Integer,
                 ForeignKey('nodes.node',
                            ondelete='CASCADE',
                            onupdate='CASCADE'),
                 primary_key=True),
          Column('population', Integer, nullable=False, default=0),
          Column('size', BigInteger, nullable=False, default=0),
          Column('mtime', DECIMAL(precision=16, scale=6)),
          Column('cluster', Integer, nullable=False, default=0,
                 primary_key=True, autoincrement=False),
          mysql_engine='InnoDB')

    #create versions table
    versions = Table(
        'versions', metadata,
        Column('serial', Integer, primary_key=True),
        Column('node', Integer,
               ForeignKey('nodes.node',
                          ondelete='CASCADE',
                          onupdate='CASCADE')),
        Column('hash', String(256)),
        Column('size', BigInteger, nullable=False, default=0),
        Column('type', String(256), nullable=False, default=''),
        Column('source', Integer),
        Column('mtime', DECIMAL(precision=16, scale=6)),
        Column('muser', String(256), nullable=False, default=''),
        Column('uuid', String(64), nullable=False, default=''),
        Column('checksum', String(256), nullable=False, default=''),
        Column('cluster', Integer, nullable=False, default=0),
        mysql_engine='InnoDB')
    Index('idx_versions_node_mtime', versions.c.node, versions.c.mtime)
    Index('idx_versions_node_uuid', versions.c.uuid)

    #create attributes table
    attributes = Table(
        'attributes', metadata,
        Column('serial', Integer,
               ForeignKey('versions.serial',
                          ondelete='CASCADE',
                          onupdate='CASCADE'),
               primary_key=True),
        Column('domain', String(256), primary_key=True),
        Column('key', String(128), primary_key=True),
        Column('value', String(256)),
        mysql_engine='InnoDB')
    Index('idx_attributes_domain', attributes.c.domain)

    metadata.create_all(engine)
    return metadata.sorted_tables
188

    
189

    
190
class Node(DBWorker):
191
    """Nodes store path organization and have multiple versions.
192
       Versions store object history and have multiple attributes.
193
       Attributes store metadata.
194
    """
195

    
196
    # TODO: Provide an interface for included and excluded clusters.
197

    
198
    def __init__(self, **params):
        """Reflect (or, on first run, create) the backend tables and
           make sure the root node exists.
        """
        DBWorker.__init__(self, **params)
        try:
            # Reflect the existing schema from the database.
            metadata = MetaData(self.engine)
            self.nodes = Table('nodes', metadata, autoload=True)
            self.policy = Table('policy', metadata, autoload=True)
            self.statistics = Table('statistics', metadata, autoload=True)
            self.versions = Table('versions', metadata, autoload=True)
            self.attributes = Table('attributes', metadata, autoload=True)
        except NoSuchTableError:
            # First run: create the schema and bind every table object
            # as an attribute named after the table.
            tables = create_tables(self.engine)
            map(lambda t: self.__setattr__(t.name, t), tables)
            # NOTE(review): map() only has side effects on Python 2; on
            # Python 3 it is lazy and the setattr calls would never run.

        # Ensure the root node (its own parent, empty path) exists.
        s = self.nodes.select().where(and_(self.nodes.c.node == ROOTNODE,
                                           self.nodes.c.parent == ROOTNODE))
        wrapper = self.wrapper
        wrapper.execute()
        try:
            rp = self.conn.execute(s)
            r = rp.fetchone()
            rp.close()
            if not r:
                s = self.nodes.insert(
                ).values(node=ROOTNODE, parent=ROOTNODE, path='')
                self.conn.execute(s)
        finally:
            # Always release the wrapper's transaction, even on failure.
            wrapper.commit()
225

    
226
    def node_create(self, parent, path):
        """Create a new node from the given properties.
           Return the node identifier of the new node.
        """
        #TODO catch IntegrityError?
        r = self.conn.execute(
            self.nodes.insert().values(parent=parent, path=path))
        try:
            return r.inserted_primary_key[0]
        finally:
            r.close()
236

    
237
    def node_lookup(self, path, for_update=False):
        """Lookup the current node of the given path.
           Return None if the path is not found.
        """
        # Use LIKE for comparison to avoid MySQL problems with trailing spaces.
        pattern = self.nodes.c.path.like(self.escape_like(path),
                                         escape=ESCAPE_CHAR)
        r = self.conn.execute(
            select([self.nodes.c.node], pattern, for_update=for_update))
        row = r.fetchone()
        r.close()
        return row[0] if row else None
251

    
252
    def node_lookup_bulk(self, paths):
        """Lookup the current nodes for the given paths.
           Return () if the path is not found.
        """
        if not paths:
            return ()
        # Exact-match IN lookup (no LIKE escaping needed here).
        r = self.conn.execute(
            select([self.nodes.c.node], self.nodes.c.path.in_(paths)))
        try:
            return [row[0] for row in r.fetchall()]
        finally:
            r.close()
265

    
266
    def node_get_properties(self, node):
        """Return the node's (parent, path).
           Return None if the node is not found.
        """
        stmt = select([self.nodes.c.parent,
                       self.nodes.c.path]).where(self.nodes.c.node == node)
        r = self.conn.execute(stmt)
        props = r.fetchone()
        r.close()
        return props
277

    
278
    def node_get_versions(self, node, keys=(), propnames=_propnames):
        """Return the properties of all versions at node.
           If keys is empty, return all properties in the order
           (serial, node, hash, size, type, source, mtime, muser, uuid, checksum, cluster).
        """
        c = self.versions.c
        stmt = select([c.serial, c.node, c.hash, c.size, c.type, c.source,
                       c.mtime, c.muser, c.uuid, c.checksum, c.cluster],
                      c.node == node).order_by(c.serial)
        r = self.conn.execute(stmt)
        rows = r.fetchall()
        r.close()
        # No rows, or the caller wants every property: return as-is.
        if not rows or not keys:
            return rows
        # Project each row onto the requested keys, in the order given,
        # silently skipping unknown key names.
        return [[p[propnames[k]] for k in keys if k in propnames] for p in rows]
306

    
307
    def node_count_children(self, node):
        """Return node's child count."""
        stmt = select([func.count(self.nodes.c.node)]).where(
            and_(self.nodes.c.parent == node,
                 self.nodes.c.node != ROOTNODE))
        r = self.conn.execute(stmt)
        count = r.fetchone()[0]
        r.close()
        return count
317

    
318
    def node_purge_children(self, parent, before=inf, cluster=0,
                            update_statistics_ancestors_depth=None):
        """Delete all versions with the specified
           parent and cluster, and return
           the hashes, the total size and the serials of versions deleted.
           Clears out nodes with no remaining versions.
        """
        #update statistics
        # All direct children of parent...
        c1 = select([self.nodes.c.node],
                    self.nodes.c.parent == parent)
        # ...restricted to versions in the requested cluster (and,
        # optionally, older than 'before').
        where_clause = and_(self.versions.c.node.in_(c1),
                            self.versions.c.cluster == cluster)
        if before != inf:
            where_clause = and_(where_clause,
                                self.versions.c.mtime <= before)
        s = select([func.count(self.versions.c.serial),
                    func.sum(self.versions.c.size)])
        s = s.where(where_clause)
        r = self.conn.execute(s)
        row = r.fetchone()
        r.close()
        if not row:
            return (), 0, ()
        # SUM() is NULL when no version matched; fall back to 0.
        nr, size = row[0], row[1] if row[1] else 0
        mtime = time()
        # Subtract the deleted totals from the parent and its ancestors
        # before actually deleting the rows.
        self.statistics_update(parent, -nr, -size, mtime, cluster)
        self.statistics_update_ancestors(parent, -nr, -size, mtime, cluster,
                                         update_statistics_ancestors_depth)

        # Collect hashes and serials of the doomed versions so the
        # caller can release the corresponding data blocks/maps.
        s = select([self.versions.c.hash, self.versions.c.serial])
        s = s.where(where_clause)
        r = self.conn.execute(s)
        hashes = []
        serials = []
        for row in r.fetchall():
            hashes += [row[0]]
            serials += [row[1]]
        r.close()

        #delete versions
        s = self.versions.delete().where(where_clause)
        r = self.conn.execute(s)
        r.close()

        #delete nodes
        # Remove child nodes that are left with no version at all.
        s = select([self.nodes.c.node],
                   and_(self.nodes.c.parent == parent,
                        select([func.count(self.versions.c.serial)],
                               self.versions.c.node == self.nodes.c.node).as_scalar() == 0))
        rp = self.conn.execute(s)
        nodes = [r[0] for r in rp.fetchall()]
        rp.close()
        if nodes:
            s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
            self.conn.execute(s).close()

        return hashes, size, serials
375

    
376
    def node_purge(self, node, before=inf, cluster=0,
                   update_statistics_ancestors_depth=None):
        """Delete all versions with the specified
           node and cluster, and return
           the hashes and size of versions deleted.
           Clears out the node if it has no remaining versions.
        """

        #update statistics
        s = select([func.count(self.versions.c.serial),
                    func.sum(self.versions.c.size)])
        where_clause = and_(self.versions.c.node == node,
                            self.versions.c.cluster == cluster)
        if before != inf:
            where_clause = and_(where_clause,
                                self.versions.c.mtime <= before)
        s = s.where(where_clause)
        r = self.conn.execute(s)
        row = r.fetchone()
        # size (the SUM) can only be NULL when the count is 0, which the
        # early return below handles before size is ever used.
        nr, size = row[0], row[1]
        r.close()
        if not nr:
            return (), 0, ()
        mtime = time()
        # Subtract the deleted totals from the ancestors before deleting.
        self.statistics_update_ancestors(node, -nr, -size, mtime, cluster,
                                         update_statistics_ancestors_depth)

        # Collect hashes and serials of the doomed versions so the
        # caller can release the corresponding data blocks/maps.
        s = select([self.versions.c.hash, self.versions.c.serial])
        s = s.where(where_clause)
        r = self.conn.execute(s)
        hashes = []
        serials = []
        for row in r.fetchall():
            hashes += [row[0]]
            serials += [row[1]]
        r.close()

        #delete versions
        s = self.versions.delete().where(where_clause)
        r = self.conn.execute(s)
        r.close()

        #delete nodes
        # Drop the node itself if it no longer has any version.
        s = select([self.nodes.c.node],
                   and_(self.nodes.c.node == node,
                        select([func.count(self.versions.c.serial)],
                               self.versions.c.node == self.nodes.c.node).as_scalar() == 0))
        rp= self.conn.execute(s)
        nodes = [r[0] for r in rp.fetchall()]
        rp.close()
        if nodes:
            s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
            self.conn.execute(s).close()

        return hashes, size, serials
431

    
432
    def node_remove(self, node, update_statistics_ancestors_depth=None):
        """Remove the node specified.
           Return false if the node has children or is not found.
        """
        if self.node_count_children(node):
            return False

        mtime = time()
        # Subtract this node's per-cluster version totals from its
        # ancestors before the node row disappears.
        stmt = select([func.count(self.versions.c.serial),
                       func.sum(self.versions.c.size),
                       self.versions.c.cluster]).where(
            self.versions.c.node == node).group_by(self.versions.c.cluster)
        r = self.conn.execute(stmt)
        for population, size, cluster in r.fetchall():
            self.statistics_update_ancestors(
                node, -population, -size, mtime, cluster,
                update_statistics_ancestors_depth)
        r.close()

        self.conn.execute(
            self.nodes.delete().where(self.nodes.c.node == node)).close()
        return True
456

    
457
    def node_accounts(self, accounts=()):
        """Return (path, node) rows for all first-level nodes (accounts),
           optionally restricted to the given account paths.
        """
        stmt = select([self.nodes.c.path, self.nodes.c.node]).where(
            and_(self.nodes.c.node != 0, self.nodes.c.parent == 0))
        if accounts:
            stmt = stmt.where(self.nodes.c.path.in_(accounts))
        r = self.conn.execute(stmt)
        try:
            return r.fetchall()
        finally:
            r.close()
467

    
468
    def node_account_quotas(self):
        """Return a {account path: quota policy value} mapping for all
           first-level nodes that have a 'quota' policy entry.
        """
        stmt = select([self.nodes.c.path, self.policy.c.value]).where(
            and_(self.nodes.c.node != 0,
                 self.nodes.c.parent == 0,
                 self.nodes.c.node == self.policy.c.node,
                 self.policy.c.key == 'quota'))
        r = self.conn.execute(stmt)
        quotas = dict(r.fetchall())
        r.close()
        return quotas
478

    
479
    def node_account_usage(self, account_node, cluster):
        """Return the total version size in the given cluster for the
           account's children and grandchildren nodes.
        """
        children = select([self.nodes.c.node]).where(
            self.nodes.c.parent == account_node)
        # Children plus their own children (two levels below the account).
        descendants = select([self.nodes.c.node]).where(
            or_(self.nodes.c.parent.in_(children),
                self.nodes.c.node.in_(children)))
        stmt = select([func.sum(self.versions.c.size)]).group_by(
            self.versions.c.cluster).where(
            and_(self.nodes.c.node == self.versions.c.node,
                 self.nodes.c.node.in_(descendants),
                 self.versions.c.cluster == cluster))
        r = self.conn.execute(stmt)
        usage = r.fetchone()[0]
        r.close()
        return usage
494

    
495
    def policy_get(self, node):
        """Return the node's policy entries as a {key: value} dict."""
        stmt = select([self.policy.c.key, self.policy.c.value],
                      self.policy.c.node == node)
        r = self.conn.execute(stmt)
        try:
            return dict(r.fetchall())
        finally:
            r.close()
502

    
503
    def policy_set(self, node, policy):
        """Insert or update the given {key: value} policy entries for node."""
        #insert or replace
        for key, value in policy.iteritems():
            u = self.policy.update().where(
                and_(self.policy.c.node == node,
                     self.policy.c.key == key)).values(value=value)
            rp = self.conn.execute(u)
            rp.close()
            if rp.rowcount == 0:
                # No existing row for this key: insert a fresh one.
                r = self.conn.execute(self.policy.insert(),
                                      {'node': node, 'key': key,
                                       'value': value})
                r.close()
516

    
517
    def statistics_get(self, node, cluster=0):
        """Return population, total size and last mtime
           for all versions under node that belong to the cluster.
        """
        stmt = select([self.statistics.c.population,
                       self.statistics.c.size,
                       self.statistics.c.mtime]).where(
            and_(self.statistics.c.node == node,
                 self.statistics.c.cluster == cluster))
        r = self.conn.execute(stmt)
        stats = r.fetchone()
        r.close()
        return stats
531

    
532
    def statistics_update(self, node, population, size, mtime, cluster=0):
        """Update the statistics of the given node.
           Statistics keep track the population, total
           size of objects and mtime in the node's namespace.
           May be zero or positive or negative numbers.
        """
        # Read the current counters for (node, cluster), if any.
        s = select([self.statistics.c.population, self.statistics.c.size],
                   and_(self.statistics.c.node == node,
                        self.statistics.c.cluster == cluster))
        rp = self.conn.execute(s)
        r = rp.fetchone()
        rp.close()
        if not r:
            prepopulation, presize = (0, 0)
        else:
            prepopulation, presize = r
        # The arguments are deltas applied on top of the stored values;
        # population is clamped so it never goes negative.
        population += prepopulation
        population = max(population, 0)
        size += presize

        #insert or replace
        #TODO better upsert
        u = self.statistics.update().where(and_(self.statistics.c.node == node,
                                           self.statistics.c.cluster == cluster))
        u = u.values(population=population, size=size, mtime=mtime)
        rp = self.conn.execute(u)
        rp.close()
        # NOTE(review): rowcount is read after close(); this appears to
        # rely on SQLAlchemy keeping it available -- confirm against the
        # SQLAlchemy version in use.
        if rp.rowcount == 0:
            # No row updated: first entry for this (node, cluster) pair.
            ins = self.statistics.insert()
            ins = ins.values(node=node, population=population, size=size,
                             mtime=mtime, cluster=cluster)
            self.conn.execute(ins).close()
564

    
565
    def statistics_update_ancestors(self, node, population, size, mtime,
                                    cluster=0, recursion_depth=None):
        """Update the statistics of the given node's parent.
           Then recursively update all parents up to the root
           or up to the ``recursion_depth`` (if not None).
           Population is not recursive.
        """
        level = 0
        while node != ROOTNODE:
            # A falsy recursion_depth means "all the way to the root".
            if recursion_depth and level == recursion_depth:
                break
            props = self.node_get_properties(node)
            if props is None:
                break
            parent, _path = props
            self.statistics_update(parent, population, size, mtime, cluster)
            node = parent
            # Only the immediate parent receives the population delta.
            population = 0
            level += 1
587

    
588
    def statistics_latest(self, node, before=inf, except_cluster=0):
        """Return population, total size and last mtime
           for all latest versions under node that
           do not belong to the cluster.
        """

        # The node.
        props = self.node_get_properties(node)
        if props is None:
            return None
        parent, path = props

        # The latest version of the node itself.
        s = select([self.versions.c.serial,
                    self.versions.c.node,
                    self.versions.c.hash,
                    self.versions.c.size,
                    self.versions.c.type,
                    self.versions.c.source,
                    self.versions.c.mtime,
                    self.versions.c.muser,
                    self.versions.c.uuid,
                    self.versions.c.checksum,
                    self.versions.c.cluster])
        if before != inf:
            filtered = select([func.max(self.versions.c.serial)],
                              self.versions.c.node == node)
            filtered = filtered.where(self.versions.c.mtime < before)
        else:
            filtered = select([self.nodes.c.latest_version],
                              self.versions.c.node == node)
        s = s.where(and_(self.versions.c.cluster != except_cluster,
                         self.versions.c.serial == filtered))
        r = self.conn.execute(s)
        props = r.fetchone()
        r.close()
        if not props:
            return None
        mtime = props[MTIME]

        # First level, just under node (get population).
        v = self.versions.alias('v')
        s = select([func.count(v.c.serial),
                    func.sum(v.c.size),
                    func.max(v.c.mtime)])
        if before != inf:
            c1 = select([func.max(self.versions.c.serial)])
            c1 = c1.where(self.versions.c.mtime < before)
            # BUGFIX: .where() is generative and its result was being
            # discarded, so the subquery lost its correlation with the
            # outer version row; assign it back.
            c1 = c1.where(self.versions.c.node == v.c.node)
        else:
            c1 = select([self.nodes.c.latest_version])
            # BUGFIX: same dropped-.where() as above; keep the result.
            c1 = c1.where(self.nodes.c.node == v.c.node)
        c2 = select([self.nodes.c.node], self.nodes.c.parent == node)
        s = s.where(and_(v.c.serial == c1,
                         v.c.cluster != except_cluster,
                         v.c.node.in_(c2)))
        rp = self.conn.execute(s)
        r = rp.fetchone()
        rp.close()
        if not r:
            return None
        count = r[0]
        mtime = max(mtime, r[2])
        if count == 0:
            return (0, 0, mtime)

        # All children (get size and mtime).
        # This is why the full path is stored.
        s = select([func.count(v.c.serial),
                    func.sum(v.c.size),
                    func.max(v.c.mtime)])
        if before != inf:
            c1 = select([func.max(self.versions.c.serial)],
                        self.versions.c.node == v.c.node)
            c1 = c1.where(self.versions.c.mtime < before)
        else:
            # BUGFIX: the nodes table has no 'serial' column (see
            # create_tables); the latest serial is kept in
            # nodes.latest_version, as used everywhere else in this file.
            c1 = select([self.nodes.c.latest_version],
                        self.nodes.c.node == v.c.node)
        c2 = select([self.nodes.c.node], self.nodes.c.path.like(
            self.escape_like(path) + '%', escape=ESCAPE_CHAR))
        s = s.where(and_(v.c.serial == c1,
                         v.c.cluster != except_cluster,
                         v.c.node.in_(c2)))
        rp = self.conn.execute(s)
        r = rp.fetchone()
        rp.close()
        if not r:
            return None
        # Exclude the node's own latest version from the size total.
        size = r[1] - props[SIZE]
        mtime = max(mtime, r[2])
        return (count, size, mtime)
679

    
680
    def nodes_set_latest_version(self, node, serial):
        """Record serial as the latest version of the given node."""
        u = (self.nodes.update()
             .where(self.nodes.c.node == node)
             .values(latest_version=serial))
        self.conn.execute(u).close()
684

    
685
    def version_create(self, node, hash, size, type, source, muser, uuid,
                       checksum, cluster=0,
                       update_statistics_ancestors_depth=None):
        """Create a new version from the given properties.
           Return the (serial, mtime) of the new version.
        """
        mtime = time()
        ins = self.versions.insert().values(
            node=node, hash=hash, size=size, type=type, source=source,
            mtime=mtime, muser=muser, uuid=uuid, checksum=checksum,
            cluster=cluster)
        serial = self.conn.execute(ins).inserted_primary_key[0]
        # Propagate the +1 population / +size delta up the hierarchy.
        self.statistics_update_ancestors(node, 1, size, mtime, cluster,
                                         update_statistics_ancestors_depth)
        # The new version becomes the node's current one.
        self.nodes_set_latest_version(node, serial)
        return serial, mtime
703

    
704
    def version_lookup(self, node, before=inf, cluster=0, all_props=True):
        """Lookup the current version of the given node.
           Return a list with its properties:
           (serial, node, hash, size, type, source, mtime,
            muser, uuid, checksum, cluster)
           or None if the current version is not found in the given cluster.
        """
        v = self.versions.alias('v')
        if all_props:
            cols = [v.c.serial, v.c.node, v.c.hash,
                    v.c.size, v.c.type, v.c.source,
                    v.c.mtime, v.c.muser, v.c.uuid,
                    v.c.checksum, v.c.cluster]
        else:
            cols = [v.c.serial]
        s = select(cols)
        if before != inf:
            # Historic lookup: the newest serial strictly older than
            # 'before'.
            c = select([func.max(self.versions.c.serial)],
                       self.versions.c.node == node)
            c = c.where(self.versions.c.mtime < before)
        else:
            # Current lookup: use the denormalized latest_version pointer.
            c = select([self.nodes.c.latest_version],
                       self.nodes.c.node == node)
        s = s.where(and_(v.c.serial == c,
                         v.c.cluster == cluster))
        r = self.conn.execute(s)
        props = r.fetchone()
        r.close()
        return props if props else None
735

    
736
    def version_lookup_bulk(self, nodes, before=inf, cluster=0, all_props=True):
        """Lookup the current versions of the given nodes.
           Return a list with their properties:
           (serial, node, hash, size, type, source, mtime, muser, uuid, checksum, cluster).
        """
        if not nodes:
            return ()
        v = self.versions.alias('v')
        if all_props:
            cols = [v.c.serial, v.c.node, v.c.hash,
                    v.c.size, v.c.type, v.c.source,
                    v.c.mtime, v.c.muser, v.c.uuid,
                    v.c.checksum, v.c.cluster]
        else:
            cols = [v.c.serial]
        s = select(cols)
        if before != inf:
            # Per-node newest serial strictly older than 'before'.
            c = select([func.max(self.versions.c.serial)],
                       self.versions.c.node.in_(nodes))
            c = c.where(self.versions.c.mtime < before)
            c = c.group_by(self.versions.c.node)
        else:
            c = select([self.nodes.c.latest_version],
                       self.nodes.c.node.in_(nodes))
        s = s.where(and_(v.c.serial.in_(c),
                         v.c.cluster == cluster))
        s = s.order_by(v.c.node)
        r = self.conn.execute(s)
        rows = r.fetchall()
        r.close()
        return (tuple(row.values()) for row in rows)
766

    
767
    def version_get_properties(self, serial, keys=(), propnames=_propnames):
        """Return a sequence of values for the properties of
           the version specified by serial and the keys, in the order given.
           If keys is empty, return all properties in the order
           (serial, node, hash, size, type, source, mtime, muser, uuid,
            checksum, cluster).
        """

        vt = self.versions.alias()
        stmt = select([vt.c.serial, vt.c.node, vt.c.hash,
                       vt.c.size, vt.c.type, vt.c.source,
                       vt.c.mtime, vt.c.muser, vt.c.uuid,
                       vt.c.checksum, vt.c.cluster], vt.c.serial == serial)
        rp = self.conn.execute(stmt)
        row = rp.fetchone()
        rp.close()

        # Missing serial, or whole-row request: hand back as-is.
        if row is None or not keys:
            return row
        # Project only the recognized keys, preserving request order.
        return [row[propnames[k]] for k in keys if k in propnames]
    def version_put_property(self, serial, key, value):
        """Set value for the property of version specified by key."""

        # Silently ignore unknown property names.
        if key not in _propnames:
            return
        stmt = (self.versions.update()
                .where(self.versions.c.serial == serial)
                .values(**{key: value}))
        self.conn.execute(stmt).close()
    def version_recluster(self, serial, cluster,
                          update_statistics_ancestors_depth=None):
        """Move the version into another cluster."""

        props = self.version_get_properties(serial)
        if not props:
            return
        node = props[NODE]
        size = props[SIZE]
        oldcluster = props[CLUSTER]
        # No-op when the version already lives in the target cluster.
        if cluster == oldcluster:
            return

        # Shift the version's count and size from the old cluster's
        # ancestor statistics to the new cluster's.
        mtime = time()
        self.statistics_update_ancestors(node, -1, -size, mtime, oldcluster,
                                         update_statistics_ancestors_depth)
        self.statistics_update_ancestors(node, 1, size, mtime, cluster,
                                         update_statistics_ancestors_depth)

        stmt = (self.versions.update()
                .where(self.versions.c.serial == serial)
                .values(cluster=cluster))
        self.conn.execute(stmt).close()
    def version_remove(self, serial, update_statistics_ancestors_depth=None):
        """Remove the version specified by serial.

           Update ancestor statistics and, if another version of the node
           remains in the same cluster, repoint the node's latest_version
           at it.  Return (hash, size) of the removed version, or None if
           the serial does not exist.
        """

        props = self.version_get_properties(serial)
        if not props:
            return
        node = props[NODE]
        hash = props[HASH]
        size = props[SIZE]
        cluster = props[CLUSTER]

        mtime = time()
        self.statistics_update_ancestors(node, -1, -size, mtime, cluster,
                                         update_statistics_ancestors_depth)

        s = self.versions.delete().where(self.versions.c.serial == serial)
        self.conn.execute(s).close()

        props = self.version_lookup(node, cluster=cluster, all_props=False)
        if props:
            # Repoint the node at the version that is current after the
            # delete.  (Previously this re-used the just-deleted serial,
            # leaving latest_version dangling.)
            self.nodes_set_latest_version(node, props[0])

        return hash, size
    def attribute_get(self, serial, domain, keys=()):
        """Return a list of (key, value) pairs of the version specified
           by serial.  If keys is empty, return all attributes;
           otherwise, return only those specified.
        """

        attrs = self.attributes.alias()
        conditions = [attrs.c.serial == serial,
                      attrs.c.domain == domain]
        # Restrict to the requested keys only when some were given.
        if keys:
            conditions.append(attrs.c.key.in_(keys))
        stmt = select([attrs.c.key, attrs.c.value])
        stmt = stmt.where(and_(*conditions))
        rp = self.conn.execute(stmt)
        pairs = rp.fetchall()
        rp.close()
        return pairs
    def attribute_set(self, serial, domain, items):
        """Set the attributes of the version specified by serial.

           Receive attributes as an iterable of (key, value) pairs.
           Existing keys are updated in place; missing ones are inserted.
        """
        #insert or replace
        #TODO better upsert
        for k, v in items:
            s = self.attributes.update()
            s = s.where(and_(self.attributes.c.serial == serial,
                             self.attributes.c.domain == domain,
                             self.attributes.c.key == k))
            s = s.values(value=v)
            rp = self.conn.execute(s)
            # Read rowcount before closing: a closed ResultProxy may no
            # longer expose its underlying cursor.
            updated = rp.rowcount
            rp.close()
            if updated == 0:
                s = self.attributes.insert()
                s = s.values(serial=serial, domain=domain, key=k, value=v)
                self.conn.execute(s).close()
    def attribute_del(self, serial, domain, keys=()):
        """Delete attributes of the version specified by serial.
           If keys is empty, delete all attributes.
           Otherwise delete those specified.
        """

        if not keys:
            stmt = self.attributes.delete()
            stmt = stmt.where(and_(self.attributes.c.serial == serial,
                                   self.attributes.c.domain == domain))
            self.conn.execute(stmt).close()
            return
        #TODO more efficient way to do this?
        for key in keys:
            stmt = self.attributes.delete()
            stmt = stmt.where(and_(self.attributes.c.serial == serial,
                                   self.attributes.c.domain == domain,
                                   self.attributes.c.key == key))
            self.conn.execute(stmt).close()
    def attribute_copy(self, source, dest):
        """Copy all attributes of version 'source' onto version 'dest',
           updating keys that already exist there and inserting the rest.
        """
        s = select(
            [dest, self.attributes.c.domain,
                self.attributes.c.key, self.attributes.c.value],
            self.attributes.c.serial == source)
        rp = self.conn.execute(s)
        attributes = rp.fetchall()
        rp.close()
        # Loop variable renamed so it no longer shadows the 'dest'
        # parameter (each row's first column is the dest serial anyway).
        for dest_serial, domain, k, v in attributes:
            #insert or replace
            s = self.attributes.update().where(and_(
                self.attributes.c.serial == dest_serial,
                self.attributes.c.domain == domain,
                self.attributes.c.key == k))
            rp = self.conn.execute(s, value=v)
            # Read rowcount before closing: a closed ResultProxy may no
            # longer expose its underlying cursor.
            updated = rp.rowcount
            rp.close()
            if updated == 0:
                s = self.attributes.insert()
                values = {'serial': dest_serial, 'domain': domain,
                          'key': k, 'value': v}
                self.conn.execute(s, values).close()
    def latest_attribute_keys(self, parent, domain, before=inf, except_cluster=0, pathq=None):
        """Return a list with all distinct attribute keys defined
           for all latest versions under parent that
           do not belong to the cluster.
        """

        pathq = pathq or []

        # TODO: Use another table to store before=inf results.
        a = self.attributes.alias('a')
        v = self.versions.alias('v')
        n = self.nodes.alias('n')

        # The "current" serial per node: newest before the cutoff, or the
        # node's recorded latest version.
        if before != inf:
            latest = select([func.max(self.versions.c.serial)])
            latest = latest.where(self.versions.c.mtime < before)
            latest = latest.where(self.versions.c.node == v.c.node)
        else:
            latest = select([self.nodes.c.latest_version])
            latest = latest.where(self.nodes.c.node == v.c.node)

        s = select([a.c.key]).distinct()
        s = s.where(v.c.serial == latest)
        s = s.where(v.c.cluster != except_cluster)
        s = s.where(v.c.node.in_(select([self.nodes.c.node],
                                        self.nodes.c.parent == parent)))
        s = s.where(a.c.serial == v.c.serial)
        s = s.where(a.c.domain == domain)
        s = s.where(n.c.node == v.c.node)

        # Optional path filters: prefix matches via LIKE, exact matches
        # via equality, OR-ed together.
        conditions = []
        for path, match in pathq:
            if match == MATCH_PREFIX:
                conditions.append(
                    n.c.path.like(
                        self.escape_like(path) + '%',
                        escape=ESCAPE_CHAR
                    )
                )
            elif match == MATCH_EXACT:
                conditions.append(n.c.path == path)
        if conditions:
            s = s.where(or_(*conditions))

        rp = self.conn.execute(s)
        rows = rp.fetchall()
        rp.close()
        return [row[0] for row in rows]
    def latest_version_list(self, parent, prefix='', delimiter=None,
                            start='', limit=10000, before=inf,
                            except_cluster=0, pathq=None, domain=None,
                            filterq=None, sizeq=None, all_props=False):
        """Return a (list of (path, serial) tuples, list of common prefixes)
           for the current versions of the paths with the given parent,
           matching the following criteria.

           The property tuple for a version is returned if all
           of these conditions are true:

                a. parent matches

                b. path > start

                c. path starts with prefix (and paths in pathq)

                d. version is the max up to before

                e. version is not in cluster

                f. the path does not have the delimiter occuring
                   after the prefix, or ends with the delimiter

                g. serial matches the attribute filter query.

                   A filter query is a comma-separated list of
                   terms in one of these three forms:

                   key
                       an attribute with this key must exist

                   !key
                       an attribute with this key must not exist

                   key ?op value
                       the attribute with this key satisfies the value
                       where ?op is one of ==, != <=, >=, <, >.

                h. the size is in the range set by sizeq

           The list of common prefixes includes the prefixes
           matching up to the first delimiter after prefix,
           and are reported only once, as "virtual directories".
           The delimiter is included in the prefixes.

           If arguments are None, then the corresponding matching rule
           will always match.

           Limit applies to the first list of tuples returned.

           If all_props is True, return all properties after path, not just serial.
        """

        # None defaults replace the mutable [] defaults; semantics are
        # unchanged (None/empty means "always match").
        pathq = pathq or []
        filterq = filterq or []

        if not start or start < prefix:
            start = strprevling(prefix)
        nextling = strnextling(prefix)

        v = self.versions.alias('v')
        n = self.nodes.alias('n')
        if not all_props:
            s = select([n.c.path, v.c.serial]).distinct()
        else:
            s = select([n.c.path,
                        v.c.serial, v.c.node, v.c.hash,
                        v.c.size, v.c.type, v.c.source,
                        v.c.mtime, v.c.muser, v.c.uuid,
                        v.c.checksum, v.c.cluster]).distinct()
        # Restrict to the current version per node: newest before the
        # cutoff, or the node's recorded latest version.
        if before != inf:
            filtered = select([func.max(self.versions.c.serial)])
            filtered = filtered.where(self.versions.c.mtime < before)
        else:
            filtered = select([self.nodes.c.latest_version])
        s = s.where(
            v.c.serial == filtered.where(self.nodes.c.node == v.c.node))
        s = s.where(v.c.cluster != except_cluster)
        s = s.where(v.c.node.in_(select([self.nodes.c.node],
                                        self.nodes.c.parent == parent)))

        s = s.where(n.c.node == v.c.node)
        # 'start' stays a bind parameter so the same statement can be
        # re-executed with a new start while paginating below.
        s = s.where(and_(n.c.path > bindparam('start'), n.c.path < nextling))
        conj = []
        for path, match in pathq:
            if match == MATCH_PREFIX:
                conj.append(
                    n.c.path.like(
                        self.escape_like(path) + '%',
                        escape=ESCAPE_CHAR
                    )
                )
            elif match == MATCH_EXACT:
                conj.append(n.c.path == path)
        if conj:
            s = s.where(or_(*conj))

        if sizeq and len(sizeq) == 2:
            if sizeq[0]:
                s = s.where(v.c.size >= sizeq[0])
            if sizeq[1]:
                s = s.where(v.c.size < sizeq[1])

        # Attribute filters: EXISTS subqueries correlated on the version.
        if domain and filterq:
            a = self.attributes.alias('a')
            included, excluded, opers = parse_filters(filterq)
            if included:
                subs = select([1])
                subs = subs.where(a.c.serial == v.c.serial).correlate(v)
                subs = subs.where(a.c.domain == domain)
                subs = subs.where(or_(*[a.c.key.op('=')(x) for x in included]))
                s = s.where(exists(subs))
            if excluded:
                subs = select([1])
                subs = subs.where(a.c.serial == v.c.serial).correlate(v)
                subs = subs.where(a.c.domain == domain)
                subs = subs.where(or_(*[a.c.key.op('=')(x) for x in excluded]))
                s = s.where(not_(exists(subs)))
            if opers:
                for k, o, val in opers:
                    subs = select([1])
                    subs = subs.where(a.c.serial == v.c.serial).correlate(v)
                    subs = subs.where(a.c.domain == domain)
                    subs = subs.where(
                        and_(a.c.key.op('=')(k), a.c.value.op(o)(val)))
                    s = s.where(exists(subs))

        s = s.order_by(n.c.path)

        if not delimiter:
            s = s.limit(limit)
            rp = self.conn.execute(s, start=start)
            r = rp.fetchall()
            rp.close()
            return r, ()

        pfz = len(prefix)
        dz = len(delimiter)
        count = 0
        prefixes = []
        pappend = prefixes.append
        matches = []
        mappend = matches.append

        rp = self.conn.execute(s, start=start)
        while True:
            props = rp.fetchone()
            if props is None:
                break
            path = props[0]
            idx = path.find(delimiter, pfz)

            if idx < 0:
                # Plain object below the prefix: report it.
                mappend(props)
                count += 1
                if count >= limit:
                    break
                continue

            if idx + dz == len(path):
                # Path ends with the delimiter: report it.
                mappend(props)
                count += 1
                continue  # Get one more, in case there is a path.
            pf = path[:idx + dz]
            pappend(pf)
            if count >= limit:
                break

            # Close the superseded result set before re-executing (the
            # old code leaked one ResultProxy per common prefix).
            rp.close()
            rp = self.conn.execute(s, start=strnextling(pf))  # New start.
        rp.close()

        return matches, prefixes
    def latest_uuid(self, uuid, cluster):
        """Return the latest version of the given uuid and cluster.

        Return a (path, serial) tuple.
        If cluster is None, all clusters are considered.

        """

        v = self.versions.alias('v')
        n = self.nodes.alias('n')
        # Highest serial carrying this uuid, optionally limited to one
        # cluster.
        latest = select([func.max(self.versions.c.serial)])
        latest = latest.where(self.versions.c.uuid == uuid)
        if cluster is not None:
            latest = latest.where(self.versions.c.cluster == cluster)

        stmt = (select([n.c.path, v.c.serial])
                .where(v.c.serial == latest)
                .where(n.c.node == v.c.node))

        rp = self.conn.execute(stmt)
        row = rp.fetchone()
        rp.close()
        return row
    def domain_object_list(self, domain, cluster=None):
        """Return a list of (path, property list, attribute dictionary)
           for the objects in the specific domain and cluster.
        """

        v = self.versions.alias('v')
        n = self.nodes.alias('n')
        a = self.attributes.alias('a')

        stmt = select([n.c.path, v.c.serial, v.c.node, v.c.hash, v.c.size,
                       v.c.type, v.c.source, v.c.mtime, v.c.muser, v.c.uuid,
                       v.c.checksum, v.c.cluster, a.c.key, a.c.value])
        stmt = stmt.where(n.c.node == v.c.node)
        stmt = stmt.where(n.c.latest_version == v.c.serial)
        if cluster:
            stmt = stmt.where(v.c.cluster == cluster)
        stmt = stmt.where(v.c.serial == a.c.serial)
        stmt = stmt.where(a.c.domain == domain)

        rp = self.conn.execute(stmt)
        rows = rp.fetchall()
        rp.close()

        # The first 12 columns identify the object; the last two are one
        # (key, value) attribute pair per row.  Group the rows per object
        # (groupby needs the pre-sort) and fold the pairs into a dict.
        keyfunc = itemgetter(slice(12))
        rows.sort(key=keyfunc)
        return [(props[0], props[1:], dict(row[12:] for row in pairs))
                for props, pairs in groupby(rows, keyfunc)]