Keep trash history.
[pithos] / pithos / backends / simple.py
index 8a57d7f..0aecb68 100644 (file)
@@ -58,36 +58,58 @@ class SimpleBackend(BaseBackend):
         self.hash_algorithm = 'sha1'
         self.block_size = 128 * 1024 # 128KB
         
+        self.default_policy = {'quota': 0, 'versioning': 'auto'}
+        
         basepath = os.path.split(db)[0]
         if basepath and not os.path.exists(basepath):
             os.makedirs(basepath)
         
         self.con = sqlite3.connect(db, check_same_thread=False)
+        
+        sql = '''pragma foreign_keys = on'''
+        self.con.execute(sql)
+        
         sql = '''create table if not exists versions (
                     version_id integer primary key,
                     name text,
                     user text,
-                    tstamp datetime default current_timestamp,
+                    tstamp integer not null,
                     size integer default 0,
-                    hide integer default 0)'''
+                    trash integer default 0,
+                    until integer default null)'''
         self.con.execute(sql)
         sql = '''create table if not exists metadata (
-                    version_id integer, key text, value text, primary key (version_id, key))'''
+                    version_id integer,
+                    key text,
+                    value text,
+                    primary key (version_id, key)
+                    foreign key (version_id) references versions(version_id)
+                    on delete cascade)'''
+        self.con.execute(sql)
+        sql = '''create table if not exists hashmaps (
+                    version_id integer,
+                    pos integer,
+                    block_id text,
+                    primary key (version_id, pos)
+                    foreign key (version_id) references versions(version_id)
+                    on delete cascade)'''
         self.con.execute(sql)
         sql = '''create table if not exists blocks (
                     block_id text, data blob, primary key (block_id))'''
         self.con.execute(sql)
-        sql = '''create table if not exists hashmaps (
-                    version_id integer, pos integer, block_id text, primary key (version_id, pos))'''
+        
+        sql = '''create table if not exists policy (
+                    name text, key text, value text, primary key (name, key))'''
         self.con.execute(sql)
+        
         sql = '''create table if not exists groups (
                     account text, name text, users text, primary key (account, name))'''
         self.con.execute(sql)
         sql = '''create table if not exists permissions (
                     name text, read text, write text, primary key (name))'''
         self.con.execute(sql)
-        sql = '''create table if not exists policy (
-                    name text, key text, value text, primary key (name, key))'''
+        sql = '''create table if not exists public (
+                    name text, primary key (name))'''
         self.con.execute(sql)
         self.con.commit()
     
@@ -165,16 +187,34 @@ class SimpleBackend(BaseBackend):
                 self.con.execute(sql, (account, k, ','.join(v)))
         self.con.commit()
     
+    def put_account(self, user, account):
+        """Create a new account with the given name."""
+        
+        logger.debug("put_account: %s", account)
+        if user != account:
+            raise NotAllowedError
+        try:
+            version_id, mtime = self._get_accountinfo(account)
+        except NameError:
+            pass
+        else:
+            raise NameError('Account already exists')
+        version_id = self._put_version(account, user)
+        self.con.commit()
+    
     def delete_account(self, user, account):
         """Delete the account with the given name."""
         
         logger.debug("delete_account: %s", account)
         if user != account:
             raise NotAllowedError
-        count, bytes, tstamp = self._get_pathstats(account)
-        if count > 0:
+        if self._get_pathcount(account) > 0:
             raise IndexError('Account is not empty')
-        self._del_path(account) # Point of no return.
+        sql = 'delete from versions where name = ?'
+        self.con.execute(sql, (path,))
+        sql = 'delete from groups where name = ?'
+        self.con.execute(sql, (account,))
+        self.con.commit()
     
     def list_containers(self, user, account, marker=None, limit=10000, until=None):
         """Return a list of containers existing under an account."""
@@ -182,7 +222,7 @@ class SimpleBackend(BaseBackend):
         logger.debug("list_containers: %s %s %s %s", account, marker, limit, until)
         if user != account:
             raise NotAllowedError
-        return self._list_objects(account, '', '/', marker, limit, False, [], until)
+        return self._list_objects(account, '', '/', marker, limit, False, [], False, until)
     
     def get_container_meta(self, user, account, container, until=None):
         """Return a dictionary with the container metadata."""
@@ -190,6 +230,8 @@ class SimpleBackend(BaseBackend):
         logger.debug("get_container_meta: %s %s %s", account, container, until)
         if user != account:
             raise NotAllowedError
+        
+        # TODO: Container meta for trash.
         path, version_id, mtime = self._get_containerinfo(account, container, until)
         count, bytes, tstamp = self._get_pathstats(path, until)
         if mtime > tstamp:
@@ -220,13 +262,27 @@ class SimpleBackend(BaseBackend):
         """Return a dictionary with the container policy."""
         
         logger.debug("get_container_policy: %s %s", account, container)
-        return {}
+        if user != account:
+            raise NotAllowedError
+        path = self._get_containerinfo(account, container)[0]
+        return self._get_policy(path)
     
     def update_container_policy(self, user, account, container, policy, replace=False):
         """Update the policy associated with the account."""
         
         logger.debug("update_container_policy: %s %s %s %s", account, container, policy, replace)
-        return
+        if user != account:
+            raise NotAllowedError
+        path = self._get_containerinfo(account, container)[0]
+        self._check_policy(policy)
+        if replace:
+            for k, v in self.default_policy.iteritems():
+                if k not in policy:
+                    policy[k] = v
+        for k, v in policy.iteritems():
+            sql = 'insert or replace into policy (name, key, value) values (?, ?, ?)'
+            self.con.execute(sql, (path, k, v))
+        self.con.commit()
     
     def put_container(self, user, account, container, policy=None):
         """Create a new container with the given name."""
@@ -237,34 +293,59 @@ class SimpleBackend(BaseBackend):
         try:
             path, version_id, mtime = self._get_containerinfo(account, container)
         except NameError:
-            path = os.path.join(account, container)
-            version_id = self._put_version(path, user)
+            pass
         else:
             raise NameError('Container already exists')
+        if policy:
+            self._check_policy(policy)
+        path = os.path.join(account, container)
+        version_id = self._put_version(path, user)
+        for k, v in self.default_policy.iteritems():
+            if k not in policy:
+                policy[k] = v
+        for k, v in policy.iteritems():
+            sql = 'insert or replace into policy (name, key, value) values (?, ?, ?)'
+            self.con.execute(sql, (path, k, v))
+        self.con.commit()
     
-    def delete_container(self, user, account, container):
+    def delete_container(self, user, account, container, until=None):
         """Delete the container with the given name."""
         
-        logger.debug("delete_container: %s %s", account, container)
+        logger.debug("delete_container: %s %s %s", account, container, until)
         if user != account:
             raise NotAllowedError
         path, version_id, mtime = self._get_containerinfo(account, container)
-        count, bytes, tstamp = self._get_pathstats(path)
-        if count > 0:
+        
+        if until is not None:
+            sql = '''select version_id from versions where name like ? and tstamp <= ?'''
+            c = self.con.execute(sql, (path + '/%', until))
+            versions = [x[0] for x in c.fetchall()]
+            for v in versions:
+                sql = 'delete from hashmaps where version_id = ?'
+                self.con.execute(sql, (v,))
+                sql = 'delete from versions where version_id = ?'
+                self.con.execute(sql, (v,))
+            self.con.commit()
+            return
+        
+        if self._get_pathcount(path) > 0:
             raise IndexError('Container is not empty')
-        self._del_path(path) # Point of no return.
-        self._copy_version(user, account, account, True, True) # New account version.
+        sql = 'delete from versions where name like ?' # May contain hidden trash items.
+        self.con.execute(sql, (path + '/%',))
+        sql = 'delete from policy where name = ?'
+        self.con.execute(sql, (path,))
+        self._copy_version(user, account, account, True, True) # New account version (for timestamp update).
     
-    def list_objects(self, user, account, container, prefix='', delimiter=None, marker=None, limit=10000, virtual=True, keys=[], until=None):
+    def list_objects(self, user, account, container, prefix='', delimiter=None, marker=None, limit=10000, virtual=True, keys=[], trash=False, until=None):
         """Return a list of objects existing under a container."""
         
-        logger.debug("list_objects: %s %s %s %s %s %s %s", account, container, prefix, delimiter, marker, limit, until)
+        logger.debug("list_objects: %s %s %s %s %s %s %s %s %s %s", account, container, prefix, delimiter, marker, limit, virtual, keys, trash, until)
         if user != account:
             raise NotAllowedError
         path, version_id, mtime = self._get_containerinfo(account, container, until)
-        return self._list_objects(path, prefix, delimiter, marker, limit, virtual, keys, until)
+        return self._list_objects(path, prefix, delimiter, marker, limit, virtual, keys, trash, until)
     
-    def list_object_meta(self, user, account, container, until=None):
+    def list_object_meta(self, user, account, container, trash=False, until=None):
         """Return a list with all the container's object meta keys."""
         
         logger.debug("list_object_meta: %s %s %s", account, container, until)
@@ -273,7 +354,7 @@ class SimpleBackend(BaseBackend):
         path, version_id, mtime = self._get_containerinfo(account, container, until)
         sql = '''select distinct m.key from (%s) o, metadata m
                     where m.version_id = o.version_id and o.name like ?'''
-        sql = sql % self._sql_until(until)
+        sql = sql % self._sql_until(until, trash)
         c = self.con.execute(sql, (path + '/%',))
         return [x[0] for x in c.fetchall()]
     
@@ -325,13 +406,19 @@ class SimpleBackend(BaseBackend):
         """Return the public URL of the object if applicable."""
         
         logger.debug("get_object_public: %s %s %s", account, container, name)
+        self._can_read(user, account, container, name)
+        path = self._get_objectinfo(account, container, name)[0]
+        if self._get_public(path):
+            return '/public/' + path
         return None
     
     def update_object_public(self, user, account, container, name, public):
         """Update the public status of the object."""
         
         logger.debug("update_object_public: %s %s %s %s", account, container, name, public)
-        return
+        self._can_write(user, account, container, name)
+        path = self._get_objectinfo(account, container, name)[0]
+        self._put_public(path, public)
     
     def get_object_hashmap(self, user, account, container, name, version=None):
         """Return the object's size and a list with partial hashes."""
@@ -413,16 +500,62 @@ class SimpleBackend(BaseBackend):
         self.copy_object(user, account, src_container, src_name, dest_container, dest_name, dest_meta, replace_meta, permissions, None)
         self.delete_object(user, account, src_container, src_name)
     
-    def delete_object(self, user, account, container, name):
+    def delete_object(self, user, account, container, name, until=None):
         """Delete an object."""
         
-        logger.debug("delete_object: %s %s %s", account, container, name)
+        logger.debug("delete_object: %s %s %s %s", account, container, name, until)
         if user != account:
             raise NotAllowedError
-        path = self._get_objectinfo(account, container, name)[0]
-        self._put_version(path, user, 0, 1)
-        sql = 'delete from permissions where name = ?'
-        self.con.execute(sql, (path,))
+        if until is None:
+            path = self._get_objectinfo(account, container, name)[0]
+            sql = 'select version_id from versions where name = ?'
+            c = self.con.execute(sql, (path,))
+        else:
+            path = os.path.join(account, container, name)
+            sql = '''select version_id from versions where name = ? and tstamp <= ?'''
+            c = self.con.execute(sql, (path, until))
+        versions = [x[0] for x in c.fetchall()]
+        for v in versions:
+            sql = 'delete from hashmaps where version_id = ?'
+            self.con.execute(sql, (v,))
+            sql = 'delete from versions where version_id = ?'
+            self.con.execute(sql, (v,))
+        
+        # If no more normal versions exist, delete permissions/public.
+        sql = 'select version_id from versions where name = ? and trash = 0'
+        row = self.con.execute(sql, (path,)).fetchone()
+        if row is None:
+            self._del_sharing(path)
+        self.con.commit()
+    
+    def trash_object(self, user, account, container, name):
+        """Trash an object."""
+        
+        logger.debug("trash_object: %s %s %s", account, container, name)
+        if user != account:
+            raise NotAllowedError
+        path, version_id, muser, mtime, size = self._get_objectinfo(account, container, name)
+        src_version_id, dest_version_id = self._copy_version(user, path, path, True, True, version_id)
+        sql = 'update versions set trash = 1 where version_id = ?'
+        self.con.execute(sql, (dest_version_id,))
+        self._del_sharing(path)
+        self.con.commit()
+    
+    def untrash_object(self, user, account, container, name, version):
+        """Untrash an object."""
+        
+        logger.debug("untrash_object: %s %s %s %s", account, container, name, version)
+        if user != account:
+            raise NotAllowedError
+        
+        path = os.path.join(account, container, name)
+        sql = '''select version_id from versions where name = ? and version_id = ? and trash = 1'''
+        c = self.con.execute(sql, (path, version))
+        row = c.fetchone()
+        if not row or not int(row[1]):
+            raise NameError('Object not in trash')
+        sql = 'update versions set until = ? where version_id = ?'
+        self.con.execute(sql, (int(time.time()), version))
         self.con.commit()
     
     def list_versions(self, user, account, container, name):
@@ -430,9 +563,8 @@ class SimpleBackend(BaseBackend):
         
         logger.debug("list_versions: %s %s %s", account, container, name)
         self._can_read(user, account, container, name)
-        # This will even show deleted versions.
         path = os.path.join(account, container, name)
-        sql = '''select distinct version_id, strftime('%s', tstamp) from versions where name = ? and hide = 0'''
+        sql = '''select distinct version_id, tstamp from versions where name = ? and trash = 0'''
         c = self.con.execute(sql, (path,))
         return [(int(x[0]), int(x[1])) for x in c.fetchall()]
     
@@ -472,25 +604,42 @@ class SimpleBackend(BaseBackend):
         dest_data = src_data[:offset] + data + src_data[offset + len(data):]
         return self.put_block(dest_data)
     
-    def _sql_until(self, until=None):
+    def _sql_until(self, until=None, trash=False):
         """Return the sql to get the latest versions until the timestamp given."""
+        
         if until is None:
             until = int(time.time())
-        sql = '''select version_id, name, strftime('%s', tstamp) as tstamp, size from versions v
-                    where version_id = (select max(version_id) from versions
-                                        where v.name = name and tstamp <= datetime(%s, 'unixepoch'))
-                    and hide = 0'''
-        return sql % ('%s', until)
+        if not trash:
+            sql = '''select version_id, name, tstamp, size from versions v
+                        where version_id = (select max(version_id) from versions
+                                            where v.name = name and tstamp <= ?)
+                        and trash = 0'''
+            return sql % (until,)
+        else:
+            sql = '''select version_id, name, tstamp, size from versions v
+                        where trash = 1 and tstamp <= ? and (until is null or until > ?)'''
+            return sql % (until, until)
     
     def _get_pathstats(self, path, until=None):
-        """Return count and sum of size of everything under path and latest timestamp."""
+        """Return count, sum of size and latest timestamp of everything under path (latest versions/no trash)."""
         
-        sql = 'select count(version_id), total(size), max(tstamp) from (%s) where name like ?'
+        sql = 'select count(version_id), total(size) from (%s) where name like ?'
         sql = sql % self._sql_until(until)
         c = self.con.execute(sql, (path + '/%',))
+        total_count, total_size = c.fetchone()
+        sql = 'select max(tstamp) from versions where name like ? and tstamp <= ?' # Include trash actions.
+        c = self.con.execute(sql, (path + '/%', until))
+        row = c.fetchone()
+        tstamp = row[0] if row[0] is not None else 0
+        return int(total_count), int(total_size), int(tstamp)
+    
+    def _get_pathcount(self, path):
+        """Return count of everything under path (including versions/trash)."""
+        
+        sql = 'select count(version_id) from versions where name like ? and until is null'
+        c = self.con.execute(sql, (path + '/%',))
         row = c.fetchone()
-        tstamp = row[2] if row[2] is not None else 0
-        return int(row[0]), int(row[1]), int(tstamp)
+        return int(row[0])
     
     def _get_version(self, path, version=None):
         if version is None:
@@ -510,9 +659,10 @@ class SimpleBackend(BaseBackend):
                 raise IndexError('Version does not exist')
         return str(row[0]), str(row[1]), int(row[2]), int(row[3])
     
-    def _put_version(self, path, user, size=0, hide=0):
-        sql = 'insert into versions (name, user, size, hide) values (?, ?, ?, ?)'
-        id = self.con.execute(sql, (path, user, size, hide)).lastrowid
+    def _put_version(self, path, user, size=0):
+        tstamp = int(time.time())
+        sql = 'insert into versions (name, user, tstamp, size) values (?, ?, ?, ?)'
+        id = self.con.execute(sql, (path, user, tstamp, size)).lastrowid
         self.con.commit()
         return str(id)
     
@@ -599,22 +749,43 @@ class SimpleBackend(BaseBackend):
         c = self.con.execute(sql, (account,))
         return dict([(x[0], x[1].split(',')) for x in c.fetchall()])
     
+    def _check_policy(self, policy):
+        for k in policy.keys():
+            if policy[k] == '':
+                policy[k] = self.default_policy.get(k)
+        for k, v in policy.iteritems():
+            if k == 'quota':
+                q = int(v) # May raise ValueError.
+                if q < 0:
+                    raise ValueError
+            elif k == 'versioning':
+                if v not in ['auto', 'manual', 'none']:
+                    raise ValueError
+            else:
+                raise ValueError
+    
+    def _get_policy(self, path):
+        sql = 'select key, value from policy where name = ?'
+        c = self.con.execute(sql, (path,))
+        return dict(c.fetchall())
+    
     def _is_allowed(self, user, account, container, name, op='read'):
         if user == account:
             return True
         path = os.path.join(account, container, name)
+        if op == 'read' and self._get_public(path):
+            return True
         perm_path, perms = self._get_permissions(path)
         
         # Expand groups.
         for x in ('read', 'write'):
             g_perms = []
             for y in perms.get(x, []):
-                if ':' in y:
-                    g_account, g_name = y.split(':', 1)
-                    groups = self._get_groups(g_account)
-                    if g_name in groups:
-                        g_perms += groups[g_name]
-                else:
+                groups = self._get_groups(account)
+                if y in groups: #it's a group
+                    for g_name in groups[y]:
+                        g_perms.append(g_name)
+                else: #it's a user
                     g_perms.append(y)
             perms[x] = g_perms
         
@@ -637,9 +808,11 @@ class SimpleBackend(BaseBackend):
         sql = '''select name from permissions
                     where name != ? and (name like ? or ? like name || ?)'''
         c = self.con.execute(sql, (path, path + '%', path, '%'))
-        rows = c.fetchall()
-        if rows:
-            raise AttributeError('Permissions already set')
+        row = c.fetchone()
+        if row:
+            ae = AttributeError()
+            ae.data = row[0]
+            raise ae
         
         # Format given permissions.
         if len(permissions) == 0:
@@ -677,16 +850,40 @@ class SimpleBackend(BaseBackend):
             self.con.execute(sql, (path, r, w))
         self.con.commit()
     
-    def _list_objects(self, path, prefix='', delimiter=None, marker=None, limit=10000, virtual=True, keys=[], until=None):
+    def _get_public(self, path):
+        sql = 'select name from public where name = ?'
+        c = self.con.execute(sql, (path,))
+        row = c.fetchone()
+        if not row:
+            return False
+        return True
+    
+    def _put_public(self, path, public):
+        if not public:
+            sql = 'delete from public where name = ?'
+        else:
+            sql = 'insert or replace into public (name) values (?)'
+        self.con.execute(sql, (path,))
+        self.con.commit()
+    
+    def _del_sharing(self, path):
+        sql = 'delete from permissions where name = ?'
+        self.con.execute(sql, (path,))
+        sql = 'delete from public where name = ?'
+        self.con.execute(sql, (path,))
+        self.con.commit()
+    
+    def _list_objects(self, path, prefix='', delimiter=None, marker=None, limit=10000, virtual=True, keys=[], trash=False, until=None):
         cont_prefix = path + '/'
+        
         if keys and len(keys) > 0:
             sql = '''select distinct o.name, o.version_id from (%s) o, metadata m where o.name like ? and
                         m.version_id = o.version_id and m.key in (%s) order by o.name'''
-            sql = sql % (self._sql_until(until), ', '.join('?' * len(keys)))
+            sql = sql % (self._sql_until(until, trash), ', '.join('?' * len(keys)))
             param = (cont_prefix + prefix + '%',) + tuple(keys)
         else:
             sql = 'select name, version_id from (%s) where name like ? order by name'
-            sql = sql % self._sql_until(until)
+            sql = sql % (self._sql_until(until, trash),)
             param = (cont_prefix + prefix + '%',)
         c = self.con.execute(sql, param)
         objects = [(x[0][len(cont_prefix):], x[1]) for x in c.fetchall()]
@@ -720,16 +917,3 @@ class SimpleBackend(BaseBackend):
         if not limit or limit > 10000:
             limit = 10000
         return objects[start:start + limit]
-    
-    def _del_path(self, path):
-        sql = '''delete from hashmaps where version_id in
-                    (select version_id from versions where name = ?)'''
-        self.con.execute(sql, (path,))
-        sql = '''delete from metadata where version_id in
-                    (select version_id from versions where name = ?)'''
-        self.con.execute(sql, (path,))
-        sql = '''delete from versions where name = ?'''
-        self.con.execute(sql, (path,))
-        sql = '''delete from permissions where name like ?'''
-        self.con.execute(sql, (path + '%',)) # Redundant.
-        self.con.commit()