Statistics
| Branch: | Tag: | Revision:

root / snf-pithos-backend / pithos / backends / lib / sqlalchemy / node.py @ 0a92ff85

History | View | Annotate | Download (40.7 kB)

1
# Copyright 2011-2012 GRNET S.A. All rights reserved.
2
#
3
# Redistribution and use in source and binary forms, with or
4
# without modification, are permitted provided that the following
5
# conditions are met:
6
#
7
#   1. Redistributions of source code must retain the above
8
#      copyright notice, this list of conditions and the following
9
#      disclaimer.
10
#
11
#   2. Redistributions in binary form must reproduce the above
12
#      copyright notice, this list of conditions and the following
13
#      disclaimer in the documentation and/or other materials
14
#      provided with the distribution.
15
#
16
# THIS SOFTWARE IS PROVIDED BY GRNET S.A. ``AS IS'' AND ANY EXPRESS
17
# OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL GRNET S.A OR
20
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
23
# USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
24
# AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
26
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27
# POSSIBILITY OF SUCH DAMAGE.
28
#
29
# The views and conclusions contained in the software and
30
# documentation are those of the authors and should not be
31
# interpreted as representing official policies, either expressed
32
# or implied, of GRNET S.A.
33

    
34
from time import time
35
from sqlalchemy import Table, Integer, BigInteger, DECIMAL, Column, String, MetaData, ForeignKey
36
from sqlalchemy.types import Text
37
from sqlalchemy.schema import Index, Sequence
38
from sqlalchemy.sql import func, and_, or_, not_, null, select, bindparam, text, exists
39
from sqlalchemy.ext.compiler import compiles
40
from sqlalchemy.engine.reflection import Inspector
41
from sqlalchemy.exc import NoSuchTableError
42

    
43
from dbworker import DBWorker
44

    
45
from pithos.backends.filter import parse_filters
46

    
47

    
48
ROOTNODE = 0
49

    
50
(SERIAL, NODE, HASH, SIZE, TYPE, SOURCE, MTIME, MUSER, UUID, CHECKSUM,
51
 CLUSTER) = range(11)
52

    
53
(MATCH_PREFIX, MATCH_EXACT) = range(2)
54

    
55
inf = float('inf')
56

    
57

    
58
def strnextling(prefix):
    """Return the smallest unicode string that sorts after every
       string starting with the given prefix.
       strnextling('hello') -> 'hellp'
    """
    if not prefix:
        # Every string starts with the null string, so approximate the
        # successor of '' with the highest safe code point.
        # 0x10ffff needs a wide (32-bit unicode) python build and
        # 0x00ffff a narrow (16-bit) one; we do not autodetect, and
        # 0xffff is safe on both.
        return unichr(0xffff)
    last = ord(prefix[-1])
    if last >= 0xffff:
        raise RuntimeError
    # Bump the final character up by one code point.
    return prefix[:-1] + unichr(last + 1)
77

    
78

    
79
def strprevling(prefix):
    """Return an approximation of the last unicode string
       less than but not starting with given prefix.
       strprevling(u'hello') -> u'helln\\xffff'
    """
    if not prefix:
        # The null string has no prevling; return it unchanged.
        return prefix
    head = prefix[:-1]
    last = ord(prefix[-1])
    if last > 0:
        # Step the final character down one code point and pad with the
        # highest safe code point to stay just below the prefix range.
        head += unichr(last - 1) + unichr(0xffff)
    return head
92

    
93
# Map each version property name to its tuple position; must stay in
# sync with the (SERIAL, ..., CLUSTER) index constants above.
_propnames = dict(
    (name, index) for index, name in enumerate((
        'serial', 'node', 'hash', 'size', 'type', 'source',
        'mtime', 'muser', 'uuid', 'checksum', 'cluster')))
106

    
107

    
108
def create_tables(engine):
    """Create the backend schema (nodes, policy, statistics, versions,
       attributes) on the given engine and return the tables sorted in
       foreign-key dependency order (metadata.sorted_tables).
    """
    metadata = MetaData()

    #create nodes table
    # Hierarchy of paths; every node references its parent, and 'path'
    # stores the full path so subtree queries can use a LIKE prefix match.
    columns = []
    columns.append(Column('node', Integer, primary_key=True))
    columns.append(Column('parent', Integer,
                          ForeignKey('nodes.node',
                                     ondelete='CASCADE',
                                     onupdate='CASCADE'),
                          autoincrement=False))
    # Cache of the node's current version serial (see versions below).
    columns.append(Column('latest_version', Integer))
    columns.append(Column('path', String(2048), default='', nullable=False))
    nodes = Table('nodes', metadata, *columns, mysql_engine='InnoDB')
    Index('idx_nodes_path', nodes.c.path, unique=True)
    Index('idx_nodes_parent', nodes.c.parent)

    #create policy table
    # Per-node key/value policy pairs (composite primary key).
    columns = []
    columns.append(Column('node', Integer,
                          ForeignKey('nodes.node',
                                     ondelete='CASCADE',
                                     onupdate='CASCADE'),
                          primary_key=True))
    columns.append(Column('key', String(128), primary_key=True))
    columns.append(Column('value', String(256)))
    policy = Table('policy', metadata, *columns, mysql_engine='InnoDB')

    #create statistics table
    # Aggregated population/size/mtime per (node, cluster).
    columns = []
    columns.append(Column('node', Integer,
                          ForeignKey('nodes.node',
                                     ondelete='CASCADE',
                                     onupdate='CASCADE'),
                          primary_key=True))
    columns.append(Column('population', Integer, nullable=False, default=0))
    columns.append(Column('size', BigInteger, nullable=False, default=0))
    columns.append(Column('mtime', DECIMAL(precision=16, scale=6)))
    columns.append(Column('cluster', Integer, nullable=False, default=0,
                          primary_key=True, autoincrement=False))
    statistics = Table('statistics', metadata, *columns, mysql_engine='InnoDB')

    #create versions table
    # Object history: one row per version of a node's content.
    columns = []
    columns.append(Column('serial', Integer, primary_key=True))
    columns.append(Column('node', Integer,
                          ForeignKey('nodes.node',
                                     ondelete='CASCADE',
                                     onupdate='CASCADE')))
    columns.append(Column('hash', String(256)))
    columns.append(Column('size', BigInteger, nullable=False, default=0))
    columns.append(Column('type', String(256), nullable=False, default=''))
    columns.append(Column('source', Integer))
    columns.append(Column('mtime', DECIMAL(precision=16, scale=6)))
    columns.append(Column('muser', String(256), nullable=False, default=''))
    columns.append(Column('uuid', String(64), nullable=False, default=''))
    columns.append(Column('checksum', String(256), nullable=False, default=''))
    columns.append(Column('cluster', Integer, nullable=False, default=0))
    versions = Table('versions', metadata, *columns, mysql_engine='InnoDB')
    Index('idx_versions_node_mtime', versions.c.node, versions.c.mtime)
    # NOTE(review): despite its name this index covers only the uuid
    # column — confirm whether node was meant to be included.
    Index('idx_versions_node_uuid', versions.c.uuid)

    #create attributes table
    # Per-version metadata, namespaced by domain.
    columns = []
    columns.append(Column('serial', Integer,
                          ForeignKey('versions.serial',
                                     ondelete='CASCADE',
                                     onupdate='CASCADE'),
                          primary_key=True))
    columns.append(Column('domain', String(256), primary_key=True))
    columns.append(Column('key', String(128), primary_key=True))
    columns.append(Column('value', String(256)))
    attributes = Table('attributes', metadata, *columns, mysql_engine='InnoDB')

    metadata.create_all(engine)
    return metadata.sorted_tables
184

    
185

    
186
class Node(DBWorker):
187
    """Nodes store path organization and have multiple versions.
188
       Versions store object history and have multiple attributes.
189
       Attributes store metadata.
190
    """
191

    
192
    # TODO: Provide an interface for included and excluded clusters.
193

    
194
    def __init__(self, **params):
        """Reflect the backend tables (creating them if absent), attach
           them on the instance by name, and ensure the root node exists.
        """
        DBWorker.__init__(self, **params)
        try:
            metadata = MetaData(self.engine)
            self.nodes = Table('nodes', metadata, autoload=True)
            self.policy = Table('policy', metadata, autoload=True)
            self.statistics = Table('statistics', metadata, autoload=True)
            self.versions = Table('versions', metadata, autoload=True)
            self.attributes = Table('attributes', metadata, autoload=True)
        except NoSuchTableError:
            tables = create_tables(self.engine)
            # Plain loop instead of map(): map used for side effects is
            # fragile (it is lazy on Python 3 and would never run).
            for t in tables:
                setattr(self, t.name, t)

        # Make sure the self-parenting root node exists.
        s = self.nodes.select().where(and_(self.nodes.c.node == ROOTNODE,
                                           self.nodes.c.parent == ROOTNODE))
        rp = self.conn.execute(s)
        r = rp.fetchone()
        rp.close()
        if not r:
            s = self.nodes.insert(
            ).values(node=ROOTNODE, parent=ROOTNODE, path='')
            self.conn.execute(s)
216

    
217
    def node_create(self, parent, path):
        """Insert a new node with the given parent and path.
           Return the node identifier of the new node.
        """
        #TODO catch IntegrityError?
        rp = self.conn.execute(
            self.nodes.insert().values(parent=parent, path=path))
        node = rp.inserted_primary_key[0]
        rp.close()
        return node
227

    
228
    def node_lookup(self, path):
        """Return the node identifier of the given path,
           or None when the path is not found.
        """
        # LIKE (with escaping) sidesteps MySQL's trailing-space handling
        # in plain = comparisons.
        stmt = select([self.nodes.c.node],
                      self.nodes.c.path.like(self.escape_like(path),
                                             escape='\\'))
        rp = self.conn.execute(stmt)
        row = rp.fetchone()
        rp.close()
        return row[0] if row else None
242

    
243
    def node_lookup_bulk(self, paths):
        """Return the node identifiers of the given paths.
           Unknown paths are simply absent; an empty input yields ().
        """
        if not paths:
            return ()
        # IN gives exact bulk matching here (no LIKE needed).
        rp = self.conn.execute(
            select([self.nodes.c.node], self.nodes.c.path.in_(paths)))
        nodes = [row[0] for row in rp.fetchall()]
        rp.close()
        return nodes
256

    
257
    def node_get_properties(self, node):
        """Return the node's (parent, path) pair,
           or None when the node is not found.
        """
        stmt = select([self.nodes.c.parent, self.nodes.c.path]).where(
            self.nodes.c.node == node)
        rp = self.conn.execute(stmt)
        props = rp.fetchone()
        rp.close()
        return props
268

    
269
    def node_get_versions(self, node, keys=(), propnames=_propnames):
        """Return the properties of all versions at node.
           If keys is empty, return all properties in the order
           (serial, node, hash, size, type, source, mtime, muser,
            uuid, checksum, cluster).
        """
        v = self.versions.c
        stmt = select([v.serial, v.node, v.hash, v.size, v.type,
                       v.source, v.mtime, v.muser, v.uuid,
                       v.checksum, v.cluster], v.node == node)
        stmt = stmt.order_by(v.serial)
        rp = self.conn.execute(stmt)
        rows = rp.fetchall()
        rp.close()
        if not rows or not keys:
            # Nothing found, or full rows requested.
            return rows
        # Project each row down to the requested, known property names.
        return [[row[propnames[k]] for k in keys if k in propnames]
                for row in rows]
297

    
298
    def node_count_children(self, node):
        """Return node's child count (the root node never counts
           as a child of itself)."""
        stmt = select([func.count(self.nodes.c.node)]).where(
            and_(self.nodes.c.parent == node,
                 self.nodes.c.node != ROOTNODE))
        rp = self.conn.execute(stmt)
        count = rp.fetchone()[0]
        rp.close()
        return count
308

    
309
    def node_purge_children(self, parent, before=inf, cluster=0):
        """Delete all versions with the specified
           parent and cluster, and return
           the hashes, the total size and the serials of versions deleted.
           Clears out nodes with no remaining versions.
        """
        #update statistics
        # All direct children of parent...
        c1 = select([self.nodes.c.node],
                    self.nodes.c.parent == parent)
        # ...and their versions in the requested cluster.
        where_clause = and_(self.versions.c.node.in_(c1),
                            self.versions.c.cluster == cluster)
        if before != inf:
            where_clause = and_(where_clause,
                                self.versions.c.mtime <= before)
        # Count and total size of the versions about to be removed.
        s = select([func.count(self.versions.c.serial),
                    func.sum(self.versions.c.size)])
        s = s.where(where_clause)
        r = self.conn.execute(s)
        row = r.fetchone()
        r.close()
        # NOTE(review): an aggregate query always yields one row, so this
        # guard never fires; a zero count still falls through and bumps
        # statistics mtime — confirm whether that is intended (compare
        # node_purge, which returns early on a zero count).
        if not row:
            return (), 0, ()
        # SUM() is NULL when no rows match; coalesce to 0.
        nr, size = row[0], row[1] if row[1] else 0
        mtime = time()
        # Decrement parent's own stats, then every ancestor's.
        self.statistics_update(parent, -nr, -size, mtime, cluster)
        self.statistics_update_ancestors(parent, -nr, -size, mtime, cluster)

        # Capture hashes/serials before the rows are deleted below.
        s = select([self.versions.c.hash, self.versions.c.serial])
        s = s.where(where_clause)
        r = self.conn.execute(s)
        hashes = []
        serials = []
        for row in r.fetchall():
            hashes += [row[0]]
            serials += [row[1]]
        r.close()

        #delete versions
        s = self.versions.delete().where(where_clause)
        r = self.conn.execute(s)
        r.close()

        #delete nodes
        # Remove children of parent that are left with zero versions.
        s = select([self.nodes.c.node],
                   and_(self.nodes.c.parent == parent,
                        select([func.count(self.versions.c.serial)],
                               self.versions.c.node == self.nodes.c.node).as_scalar() == 0))
        rp = self.conn.execute(s)
        nodes = [r[0] for r in rp.fetchall()]
        rp.close()
        if nodes:
            s = self.nodes.delete().where(self.nodes.c.node.in_(nodes))
            self.conn.execute(s).close()

        return hashes, size, serials
364

    
365
    def node_purge(self, node, before=inf, cluster=0):
        """Delete all versions with the specified
           node and cluster, and return
           the hashes and size of versions deleted.
           Clears out the node if it has no remaining versions.
        """
        # Count the matching versions and their total size so ancestor
        # statistics can be decremented before anything is deleted.
        where_clause = and_(self.versions.c.node == node,
                            self.versions.c.cluster == cluster)
        if before != inf:
            where_clause = and_(where_clause,
                                self.versions.c.mtime <= before)
        stmt = select([func.count(self.versions.c.serial),
                       func.sum(self.versions.c.size)])
        stmt = stmt.where(where_clause)
        rp = self.conn.execute(stmt)
        nr, size = rp.fetchone()
        rp.close()
        if not nr:
            return (), 0, ()
        self.statistics_update_ancestors(node, -nr, -size, time(), cluster)

        # Remember the hashes and serials before deleting the rows.
        stmt = select([self.versions.c.hash, self.versions.c.serial])
        stmt = stmt.where(where_clause)
        rp = self.conn.execute(stmt)
        hashes = []
        serials = []
        for h, serial in rp.fetchall():
            hashes.append(h)
            serials.append(serial)
        rp.close()

        # Drop the versions themselves.
        self.conn.execute(self.versions.delete().where(where_clause)).close()

        # Clear out the node if it no longer has any versions at all.
        stmt = select([self.nodes.c.node],
                      and_(self.nodes.c.node == node,
                           select([func.count(self.versions.c.serial)],
                                  self.versions.c.node == self.nodes.c.node
                                  ).as_scalar() == 0))
        rp = self.conn.execute(stmt)
        empty = [row[0] for row in rp.fetchall()]
        rp.close()
        if empty:
            self.conn.execute(
                self.nodes.delete().where(
                    self.nodes.c.node.in_(empty))).close()

        return hashes, size, serials
418

    
419
    def node_remove(self, node):
        """Remove the node specified.
           Return False if the node has children or is not found.
        """
        if self.node_count_children(node):
            return False

        # Roll the node's per-cluster totals out of the ancestor
        # statistics before deleting it (versions cascade away).
        mtime = time()
        stmt = select([func.count(self.versions.c.serial),
                       func.sum(self.versions.c.size),
                       self.versions.c.cluster])
        stmt = stmt.where(self.versions.c.node == node)
        stmt = stmt.group_by(self.versions.c.cluster)
        rp = self.conn.execute(stmt)
        for population, size, cluster in rp.fetchall():
            self.statistics_update_ancestors(
                node, -population, -size, mtime, cluster)
        rp.close()

        self.conn.execute(
            self.nodes.delete().where(self.nodes.c.node == node)).close()
        return True
442

    
443
    def node_accounts(self):
        """Return the sorted paths of all first-level (account) nodes."""
        stmt = select([self.nodes.c.path]).where(
            and_(self.nodes.c.node != 0, self.nodes.c.parent == 0))
        rows = self.conn.execute(stmt).fetchall()
        return sorted(row[0] for row in rows)
448

    
449
    def policy_get(self, node):
        """Return the node's policy as a {key: value} dict."""
        stmt = select([self.policy.c.key, self.policy.c.value],
                      self.policy.c.node == node)
        rp = self.conn.execute(stmt)
        policy = dict(rp.fetchall())
        rp.close()
        return policy
456

    
457
    def policy_set(self, node, policy):
        """Insert or update the given {key: value} policy pairs
           for the node (per-key upsert)."""
        #insert or replace
        # items() instead of iteritems(): equivalent on Python 2 and
        # forward-compatible with Python 3.
        for k, v in policy.items():
            s = self.policy.update().where(and_(self.policy.c.node == node,
                                                self.policy.c.key == k))
            s = s.values(value=v)
            rp = self.conn.execute(s)
            # Read rowcount before closing the result proxy; reading it
            # after close() relies on driver-specific behavior.
            updated = rp.rowcount
            rp.close()
            if updated == 0:
                # No existing row for this key; insert a fresh one.
                s = self.policy.insert()
                values = {'node': node, 'key': k, 'value': v}
                r = self.conn.execute(s, values)
                r.close()
470

    
471
    def statistics_get(self, node, cluster=0):
        """Return population, total size and last mtime
           for all versions under node that belong to the cluster.
        """
        stmt = select([self.statistics.c.population,
                       self.statistics.c.size,
                       self.statistics.c.mtime]).where(
            and_(self.statistics.c.node == node,
                 self.statistics.c.cluster == cluster))
        rp = self.conn.execute(stmt)
        stats = rp.fetchone()
        rp.close()
        return stats
485

    
486
    def statistics_update(self, node, population, size, mtime, cluster=0):
        """Update the statistics of the given node.
           Statistics keep track the population, total
           size of objects and mtime in the node's namespace.
           May be zero or positive or negative numbers.
        """
        # Read the current counters for this (node, cluster) pair.
        s = select([self.statistics.c.population, self.statistics.c.size],
                   and_(self.statistics.c.node == node,
                        self.statistics.c.cluster == cluster))
        rp = self.conn.execute(s)
        r = rp.fetchone()
        rp.close()
        if not r:
            prepopulation, presize = (0, 0)
        else:
            prepopulation, presize = r
        population += prepopulation
        population = max(population, 0)  # population never goes negative
        size += presize

        #insert or replace
        #TODO better upsert
        u = self.statistics.update().where(and_(self.statistics.c.node == node,
                                           self.statistics.c.cluster == cluster))
        u = u.values(population=population, size=size, mtime=mtime)
        rp = self.conn.execute(u)
        # Read rowcount before closing the result proxy; reading it after
        # close() relies on driver-specific behavior.
        updated = rp.rowcount
        rp.close()
        if updated == 0:
            ins = self.statistics.insert()
            ins = ins.values(node=node, population=population, size=size,
                             mtime=mtime, cluster=cluster)
            self.conn.execute(ins).close()
518

    
519
    def statistics_update_ancestors(self, node, population, size, mtime,
                                    cluster=0):
        """Update the statistics of the given node's parent.
           Then recursively update all parents up to the root.
           Population is not recursive.
        """
        while node != ROOTNODE:
            props = self.node_get_properties(node)
            if props is None:
                break
            parent, _path = props
            self.statistics_update(parent, population, size, mtime, cluster)
            node = parent
            # Only the immediate parent absorbs the population delta.
            population = 0
535

    
536
    def statistics_latest(self, node, before=inf, except_cluster=0):
        """Return population, total size and last mtime
           for all latest versions under node that
           do not belong to the cluster.
        """

        # The node.
        props = self.node_get_properties(node)
        if props is None:
            return None
        parent, path = props

        # The latest version of the node itself.
        s = select([self.versions.c.serial,
                    self.versions.c.node,
                    self.versions.c.hash,
                    self.versions.c.size,
                    self.versions.c.type,
                    self.versions.c.source,
                    self.versions.c.mtime,
                    self.versions.c.muser,
                    self.versions.c.uuid,
                    self.versions.c.checksum,
                    self.versions.c.cluster])
        if before != inf:
            filtered = select([func.max(self.versions.c.serial)],
                              self.versions.c.node == node)
            filtered = filtered.where(self.versions.c.mtime < before)
        else:
            # Fixed: correlate on nodes.node, not versions.node; the old
            # predicate made the scalar subquery return latest_version
            # for arbitrary nodes (compare version_lookup's same branch).
            filtered = select([self.nodes.c.latest_version],
                              self.nodes.c.node == node)
        s = s.where(and_(self.versions.c.cluster != except_cluster,
                         self.versions.c.serial == filtered))
        r = self.conn.execute(s)
        props = r.fetchone()
        r.close()
        if not props:
            return None
        mtime = props[MTIME]

        # First level, just under node (get population).
        v = self.versions.alias('v')
        s = select([func.count(v.c.serial),
                    func.sum(v.c.size),
                    func.max(v.c.mtime)])
        if before != inf:
            c1 = select([func.max(self.versions.c.serial)])
            c1 = c1.where(self.versions.c.mtime < before)
            # Fixed: where() is generative and returns a new select; the
            # old code discarded the result, losing the correlation and
            # leaving an uncorrelated multi-row scalar subquery.
            c1 = c1.where(self.versions.c.node == v.c.node)
        else:
            c1 = select([self.nodes.c.latest_version])
            # Fixed: same discarded-where bug as above.
            c1 = c1.where(self.nodes.c.node == v.c.node)
        c2 = select([self.nodes.c.node], self.nodes.c.parent == node)
        s = s.where(and_(v.c.serial == c1,
                         v.c.cluster != except_cluster,
                         v.c.node.in_(c2)))
        rp = self.conn.execute(s)
        r = rp.fetchone()
        rp.close()
        if not r:
            return None
        count = r[0]
        mtime = max(mtime, r[2])
        if count == 0:
            return (0, 0, mtime)

        # All children (get size and mtime).
        # This is why the full path is stored.
        s = select([func.count(v.c.serial),
                    func.sum(v.c.size),
                    func.max(v.c.mtime)])
        if before != inf:
            c1 = select([func.max(self.versions.c.serial)],
                        self.versions.c.node == v.c.node)
            c1 = c1.where(self.versions.c.mtime < before)
        else:
            # Fixed: the nodes table has no 'serial' column; the current
            # version serial is tracked in latest_version.
            c1 = select([self.nodes.c.latest_version],
                        self.nodes.c.node == v.c.node)
        c2 = select([self.nodes.c.node], self.nodes.c.path.like(
            self.escape_like(path) + '%', escape='\\'))
        s = s.where(and_(v.c.serial == c1,
                         v.c.cluster != except_cluster,
                         v.c.node.in_(c2)))
        rp = self.conn.execute(s)
        r = rp.fetchone()
        rp.close()
        if not r:
            return None
        # Exclude the node's own latest version from the subtree size.
        size = r[1] - props[SIZE]
        mtime = max(mtime, r[2])
        return (count, size, mtime)
627

    
628
    def nodes_set_latest_version(self, node, serial):
        """Record serial as the node's current latest version."""
        stmt = self.nodes.update().where(
            self.nodes.c.node == node).values(latest_version=serial)
        self.conn.execute(stmt).close()
632

    
633
    def version_create(self, node, hash, size, type, source, muser, uuid,
                       checksum, cluster=0):
        """Create a new version from the given properties.
           Return the (serial, mtime) of the new version.
        """
        mtime = time()
        ins = self.versions.insert().values(
            node=node, hash=hash, size=size, type=type, source=source,
            mtime=mtime, muser=muser, uuid=uuid, checksum=checksum,
            cluster=cluster)
        serial = self.conn.execute(ins).inserted_primary_key[0]
        # Account the new version in every ancestor's statistics and
        # make it the node's current version.
        self.statistics_update_ancestors(node, 1, size, mtime, cluster)
        self.nodes_set_latest_version(node, serial)
        return serial, mtime
648

    
649
    def version_lookup(self, node, before=inf, cluster=0, all_props=True):
        """Lookup the current version of the given node.
           Return a list with its properties:
           (serial, node, hash, size, type, source, mtime,
            muser, uuid, checksum, cluster)
           or None if the current version is not found in the given cluster.
        """
        v = self.versions.alias('v')
        if all_props:
            s = select([v.c.serial, v.c.node, v.c.hash,
                        v.c.size, v.c.type, v.c.source,
                        v.c.mtime, v.c.muser, v.c.uuid,
                        v.c.checksum, v.c.cluster])
        else:
            s = select([v.c.serial])
        if before != inf:
            # Latest serial strictly before the given timestamp.
            c = select([func.max(self.versions.c.serial)],
                       self.versions.c.node == node)
            c = c.where(self.versions.c.mtime < before)
        else:
            # The current version is cached on the node row itself.
            c = select([self.nodes.c.latest_version],
                       self.nodes.c.node == node)
        s = s.where(and_(v.c.serial == c,
                         v.c.cluster == cluster))
        rp = self.conn.execute(s)
        props = rp.fetchone()
        rp.close()
        return props if props else None
680

    
681
    def version_lookup_bulk(self, nodes, before=inf, cluster=0,
                            all_props=True):
        """Lookup the current versions of the given nodes.
           Return a list with their properties:
           (serial, node, hash, size, type, source, mtime, muser, uuid,
            checksum, cluster).
        """
        if not nodes:
            return ()
        v = self.versions.alias('v')
        if all_props:
            s = select([v.c.serial, v.c.node, v.c.hash,
                        v.c.size, v.c.type, v.c.source,
                        v.c.mtime, v.c.muser, v.c.uuid,
                        v.c.checksum, v.c.cluster])
        else:
            s = select([v.c.serial])
        if before != inf:
            # Latest serial per node, strictly before the timestamp.
            c = select([func.max(self.versions.c.serial)],
                       self.versions.c.node.in_(nodes))
            c = c.where(self.versions.c.mtime < before)
            c = c.group_by(self.versions.c.node)
        else:
            # Current versions are cached on the node rows.
            c = select([self.nodes.c.latest_version],
                       self.nodes.c.node.in_(nodes))
        s = s.where(and_(v.c.serial.in_(c),
                         v.c.cluster == cluster))
        s = s.order_by(v.c.node)
        rp = self.conn.execute(s)
        rows = rp.fetchall()
        rp.close()
        return (tuple(row.values()) for row in rows)
711

    
712
    def version_get_properties(self, serial, keys=(), propnames=_propnames):
        """Return a sequence of values for the properties of
           the version specified by serial and the keys, in the order given.
           If keys is empty, return all properties in the order
           (serial, node, hash, size, type, source, mtime, muser, uuid,
            checksum, cluster).
        """
        v = self.versions.alias()
        stmt = select([v.c.serial, v.c.node, v.c.hash,
                       v.c.size, v.c.type, v.c.source,
                       v.c.mtime, v.c.muser, v.c.uuid,
                       v.c.checksum, v.c.cluster], v.c.serial == serial)
        rp = self.conn.execute(stmt)
        row = rp.fetchone()
        rp.close()
        if row is None or not keys:
            # Unknown serial (None) or full row requested.
            return row
        return [row[propnames[k]] for k in keys if k in propnames]
733

    
734
    def version_put_property(self, serial, key, value):
        """Set value for the property of version specified by key."""
        if key not in _propnames:
            # Silently ignore unknown property names.
            return
        stmt = self.versions.update().where(
            self.versions.c.serial == serial).values(**{key: value})
        self.conn.execute(stmt).close()
743

    
744
    def version_recluster(self, serial, cluster):
        """Move the version into another cluster."""
        props = self.version_get_properties(serial)
        if not props:
            return
        node = props[NODE]
        size = props[SIZE]
        oldcluster = props[CLUSTER]
        if cluster == oldcluster:
            # Already there; nothing to move.
            return

        # Shift the version's accounting from the old cluster to the new.
        now = time()
        self.statistics_update_ancestors(node, -1, -size, now, oldcluster)
        self.statistics_update_ancestors(node, 1, size, now, cluster)

        stmt = self.versions.update().where(
            self.versions.c.serial == serial).values(cluster=cluster)
        self.conn.execute(stmt).close()
764

    
765
    def version_remove(self, serial):
        """Remove the serial specified.

           Update ancestor statistics and, if other versions of the node
           survive in the same cluster, repoint the node's latest_version
           at the most recent surviving one.

           Return (hash, size) of the removed version, or None if the
           serial does not exist.
        """

        props = self.version_get_properties(serial)
        if not props:
            return
        node = props[NODE]
        hash = props[HASH]
        size = props[SIZE]
        cluster = props[CLUSTER]

        mtime = time()
        self.statistics_update_ancestors(node, -1, -size, mtime, cluster)

        s = self.versions.delete().where(self.versions.c.serial == serial)
        self.conn.execute(s).close()

        props = self.version_lookup(node, cluster=cluster, all_props=False)
        if props:
            # BUG FIX: point latest_version at the surviving latest serial
            # (props[0]), not at the serial that was just deleted.
            self.nodes_set_latest_version(node, props[0])

        return hash, size
    def attribute_get(self, serial, domain, keys=()):
        """Return a list of (key, value) pairs of the version specified by serial.
           If keys is empty, return all attributes.
           Otherwise, return only those specified.
        """

        attrs = self.attributes.alias()
        stmt = select([attrs.c.key, attrs.c.value])
        conditions = [attrs.c.serial == serial,
                      attrs.c.domain == domain]
        if keys:
            # Restrict the result to the requested attribute names.
            conditions.insert(0, attrs.c.key.in_(keys))
        stmt = stmt.where(and_(*conditions))
        rp = self.conn.execute(stmt)
        pairs = rp.fetchall()
        rp.close()
        return pairs
    def attribute_set(self, serial, domain, items):
        """Set the attributes of the version specified by serial.
           Receive attributes as an iterable of (key, value) pairs.
        """
        # Emulate an upsert: try UPDATE first, INSERT when no row matched.
        # TODO better upsert
        for key, value in items:
            upd = (self.attributes.update()
                   .where(and_(self.attributes.c.serial == serial,
                               self.attributes.c.domain == domain,
                               self.attributes.c.key == key))
                   .values(value=value))
            rp = self.conn.execute(upd)
            rp.close()
            if rp.rowcount == 0:
                ins = (self.attributes.insert()
                       .values(serial=serial, domain=domain,
                               key=key, value=value))
                self.conn.execute(ins).close()
    def attribute_del(self, serial, domain, keys=()):
        """Delete attributes of the version specified by serial.
           If keys is empty, delete all attributes.
           Otherwise delete those specified.
        """

        s = self.attributes.delete()
        s = s.where(and_(self.attributes.c.serial == serial,
                         self.attributes.c.domain == domain))
        if keys:
            # Delete all requested keys with a single statement instead
            # of issuing one DELETE per key (resolves the old TODO).
            s = s.where(self.attributes.c.key.in_(keys))
        self.conn.execute(s).close()
    def attribute_copy(self, source, dest):
        """Copy every attribute of version ``source`` onto version ``dest``."""
        # Select the source attributes with the destination serial
        # prepended as a literal column, ready for re-insertion.
        sel = select(
            [dest, self.attributes.c.domain,
                self.attributes.c.key, self.attributes.c.value],
            self.attributes.c.serial == source)
        rp = self.conn.execute(sel)
        attributes = rp.fetchall()
        rp.close()
        for serial, domain, key, value in attributes:
            # Insert or replace on the destination version.
            upd = self.attributes.update().where(and_(
                self.attributes.c.serial == serial,
                self.attributes.c.domain == domain,
                self.attributes.c.key == key))
            rp = self.conn.execute(upd, value=value)
            rp.close()
            if rp.rowcount == 0:
                ins = self.attributes.insert()
                row = {'serial': serial, 'domain': domain,
                       'key': key, 'value': value}
                self.conn.execute(ins, row).close()
    def latest_attribute_keys(self, parent, domain, before=inf, except_cluster=0, pathq=[]):
        """Return a list with all keys pairs defined
           for all latest versions under parent that
           do not belong to the cluster.
        """

        # TODO: Use another table to store before=inf results.
        a = self.attributes.alias('a')
        v = self.versions.alias('v')
        n = self.nodes.alias('n')
        s = select([a.c.key]).distinct()
        # Correlated subquery selecting the serial that counts as
        # "latest" for each node, with respect to the before bound.
        if before != inf:
            filtered = select([func.max(self.versions.c.serial)])
            filtered = filtered.where(self.versions.c.mtime < before)
            filtered = filtered.where(self.versions.c.node == v.c.node)
        else:
            filtered = select([self.nodes.c.latest_version])
            filtered = filtered.where(self.nodes.c.node == v.c.node)
        s = s.where(v.c.serial == filtered)
        s = s.where(v.c.cluster != except_cluster)
        # Only nodes directly under the given parent.
        s = s.where(v.c.node.in_(select([self.nodes.c.node],
                                        self.nodes.c.parent == parent)))
        s = s.where(a.c.serial == v.c.serial)
        s = s.where(a.c.domain == domain)
        s = s.where(n.c.node == v.c.node)
        # Optional path restrictions (prefix or exact), OR-ed together.
        path_filters = []
        for path, match in pathq:
            if match == MATCH_PREFIX:
                path_filters.append(
                    n.c.path.like(self.escape_like(path) + '%', escape='\\'))
            elif match == MATCH_EXACT:
                path_filters.append(n.c.path == path)
        if path_filters:
            s = s.where(or_(*path_filters))
        rp = self.conn.execute(s)
        rows = rp.fetchall()
        rp.close()
        return [row[0] for row in rows]
    def latest_version_list(self, parent, prefix='', delimiter=None,
                            start='', limit=10000, before=inf,
                            except_cluster=0, pathq=[], domain=None,
                            filterq=[], sizeq=None, all_props=False):
        """Return a (list of (path, serial) tuples, list of common prefixes)
           for the current versions of the paths with the given parent,
           matching the following criteria.

           The property tuple for a version is returned if all
           of these conditions are true:

                a. parent matches

                b. path > start

                c. path starts with prefix (and paths in pathq)

                d. version is the max up to before

                e. version is not in cluster

                f. the path does not have the delimiter occuring
                   after the prefix, or ends with the delimiter

                g. serial matches the attribute filter query.

                   A filter query is a comma-separated list of
                   terms in one of these three forms:

                   key
                       an attribute with this key must exist

                   !key
                       an attribute with this key must not exist

                   key ?op value
                       the attribute with this key satisfies the value
                       where ?op is one of ==, != <=, >=, <, >.

                h. the size is in the range set by sizeq

           The list of common prefixes includes the prefixes
           matching up to the first delimiter after prefix,
           and are reported only once, as "virtual directories".
           The delimiter is included in the prefixes.

           If arguments are None, then the corresponding matching rule
           will always match.

           Limit applies to the first list of tuples returned.

           If all_props is True, return all properties after path, not just serial.
        """

        # Paths are scanned in the open interval (start, nextling(prefix)):
        # every path with the given prefix sorts inside this window.
        if not start or start < prefix:
            start = strprevling(prefix)
        nextling = strnextling(prefix)

        v = self.versions.alias('v')
        n = self.nodes.alias('n')
        if not all_props:
            s = select([n.c.path, v.c.serial]).distinct()
        else:
            s = select([n.c.path,
                        v.c.serial, v.c.node, v.c.hash,
                        v.c.size, v.c.type, v.c.source,
                        v.c.mtime, v.c.muser, v.c.uuid,
                        v.c.checksum, v.c.cluster]).distinct()
        # Correlated subquery picking, per node, the serial that counts
        # as "latest" with respect to the before bound (rule d).
        if before != inf:
            filtered = select([func.max(self.versions.c.serial)])
            filtered = filtered.where(self.versions.c.mtime < before)
        else:
            filtered = select([self.nodes.c.latest_version])
        s = s.where(
            v.c.serial == filtered.where(self.nodes.c.node == v.c.node))
        s = s.where(v.c.cluster != except_cluster)  # rule e
        s = s.where(v.c.node.in_(select([self.nodes.c.node],
                                        self.nodes.c.parent == parent)))  # rule a

        s = s.where(n.c.node == v.c.node)
        # start is a bind parameter so the same compiled statement can be
        # re-executed below with a moving start for delimiter paging.
        s = s.where(and_(n.c.path > bindparam('start'), n.c.path < nextling))
        conj = []
        for path, match in pathq:
            if match == MATCH_PREFIX:
                conj.append(
                    n.c.path.like(self.escape_like(path) + '%', escape='\\'))
            elif match == MATCH_EXACT:
                conj.append(n.c.path == path)
        if conj:
            s = s.where(or_(*conj))

        # Rule h: half-open size range [sizeq[0], sizeq[1]).
        if sizeq and len(sizeq) == 2:
            if sizeq[0]:
                s = s.where(v.c.size >= sizeq[0])
            if sizeq[1]:
                s = s.where(v.c.size < sizeq[1])

        # Rule g: attribute filters expressed as EXISTS / NOT EXISTS
        # subqueries correlated on the version's serial.
        if domain and filterq:
            a = self.attributes.alias('a')
            included, excluded, opers = parse_filters(filterq)
            if included:
                subs = select([1])
                subs = subs.where(a.c.serial == v.c.serial).correlate(v)
                subs = subs.where(a.c.domain == domain)
                subs = subs.where(or_(*[a.c.key.op('=')(x) for x in included]))
                s = s.where(exists(subs))
            if excluded:
                subs = select([1])
                subs = subs.where(a.c.serial == v.c.serial).correlate(v)
                subs = subs.where(a.c.domain == domain)
                subs = subs.where(or_(*[a.c.key.op('=')(x) for x in excluded]))
                s = s.where(not_(exists(subs)))
            if opers:
                for k, o, val in opers:
                    subs = select([1])
                    subs = subs.where(a.c.serial == v.c.serial).correlate(v)
                    subs = subs.where(a.c.domain == domain)
                    subs = subs.where(
                        and_(a.c.key.op('=')(k), a.c.value.op(o)(val)))
                    s = s.where(exists(subs))

        s = s.order_by(n.c.path)

        # Without a delimiter there is no virtual-directory folding:
        # a single limited query suffices.
        if not delimiter:
            s = s.limit(limit)
            rp = self.conn.execute(s, start=start)
            r = rp.fetchall()
            rp.close()
            return r, ()

        pfz = len(prefix)
        dz = len(delimiter)
        count = 0
        prefixes = []
        pappend = prefixes.append
        matches = []
        mappend = matches.append

        rp = self.conn.execute(s, start=start)
        while True:
            props = rp.fetchone()
            if props is None:
                break
            path = props[0]
            serial = props[1]
            # First delimiter occurrence after the prefix decides whether
            # this row is a plain match or folds into a virtual directory.
            idx = path.find(delimiter, pfz)

            if idx < 0:
                # No delimiter after prefix: ordinary match (rule f).
                mappend(props)
                count += 1
                if count >= limit:
                    break
                continue

            if idx + dz == len(path):
                # Path ends with the delimiter: still a match (rule f).
                mappend(props)
                count += 1
                continue  # Get one more, in case there is a path.
            # Delimiter in the middle: report the common prefix once and
            # skip the rest of its subtree by restarting past it.
            pf = path[:idx + dz]
            pappend(pf)
            if count >= limit:
                break

            # NOTE(review): the previous result proxy is replaced without an
            # explicit close here; only the final one is closed below.
            rp = self.conn.execute(s, start=strnextling(pf))  # New start.
        rp.close()

        return matches, prefixes
    def latest_uuid(self, uuid, cluster):
        """Return the latest version of the given uuid and cluster.

        Return a (path, serial) tuple.
        If cluster is None, all clusters are considered.

        """

        v = self.versions.alias('v')
        n = self.nodes.alias('n')
        # Subquery: the greatest serial carrying this uuid
        # (optionally restricted to one cluster).
        filtered = select([func.max(self.versions.c.serial)])
        filtered = filtered.where(self.versions.c.uuid == uuid)
        if cluster is not None:
            filtered = filtered.where(self.versions.c.cluster == cluster)
        stmt = select([n.c.path, v.c.serial])
        stmt = stmt.where(and_(v.c.serial == filtered,
                               n.c.node == v.c.node))

        rp = self.conn.execute(stmt)
        row = rp.fetchone()
        rp.close()
        return row