Statistics
| Branch: | Revision:

root / taggit / managers.py @ 04733cdb

History | View | Annotate | Download (8.2 kB)

1
from django.contrib.contenttypes.generic import GenericRelation
2
from django.contrib.contenttypes.models import ContentType
3
from django.db import models
4
from django.db.models.fields.related import ManyToManyRel, RelatedField, add_lazy_relation
5
from django.db.models.related import RelatedObject
6
from django.utils.text import capfirst
7
from django.utils.translation import ugettext_lazy as _
8

    
9
from taggit.forms import TagField
10
from taggit.models import TaggedItem, GenericTaggedItemBase
11
from taggit.utils import require_instance_manager
12

    
13

    
14
try:
15
    all
16
except NameError:
17
    # 2.4 compat
18
    try:
19
        from django.utils.itercompat import all
20
    except ImportError:
21
        # 1.1.X compat
22
        def all(iterable):
23
            for item in iterable:
24
                if not item:
25
                    return False
26
            return True
27

    
28

    
29
class TaggableRel(ManyToManyRel):
30
    def __init__(self):
31
        self.related_name = None
32
        self.limit_choices_to = {}
33
        self.symmetrical = True
34
        self.multiple = True
35
        self.through = None
36

    
37

    
38
class TaggableManager(RelatedField):
39
    def __init__(self, verbose_name=_("Tags"),
40
        help_text=_("A comma-separated list of tags."), through=None, blank=False):
41
        self.through = through or TaggedItem
42
        self.rel = TaggableRel()
43
        self.verbose_name = verbose_name
44
        self.help_text = help_text
45
        self.blank = blank
46
        self.editable = True
47
        self.unique = False
48
        self.creates_table = False
49
        self.db_column = None
50
        self.choices = None
51
        self.serialize = False
52
        self.null = True
53
        self.creation_counter = models.Field.creation_counter
54
        models.Field.creation_counter += 1
55

    
56
    def __get__(self, instance, model):
57
        if instance is not None and instance.pk is None:
58
            raise ValueError("%s objects need to have a primary key value "
59
                "before you can access their tags." % model.__name__)
60
        manager = _TaggableManager(
61
            through=self.through, model=model, instance=instance
62
        )
63
        return manager
64

    
65
    def contribute_to_class(self, cls, name):
66
        self.name = self.column = name
67
        self.model = cls
68
        cls._meta.add_field(self)
69
        setattr(cls, name, self)
70
        if not cls._meta.abstract:
71
            if isinstance(self.through, basestring):
72
                def resolve_related_class(field, model, cls):
73
                    self.through = model
74
                    self.post_through_setup(cls)
75
                add_lazy_relation(
76
                    cls, self, self.through, resolve_related_class
77
                )
78
            else:
79
                self.post_through_setup(cls)
80

    
81
    def post_through_setup(self, cls):
82
        self.use_gfk = (
83
            self.through is None or issubclass(self.through, GenericTaggedItemBase)
84
        )
85
        self.rel.to = self.through._meta.get_field("tag").rel.to
86
        if self.use_gfk:
87
            tagged_items = GenericRelation(self.through)
88
            tagged_items.contribute_to_class(cls, "tagged_items")
89

    
90
    def save_form_data(self, instance, value):
91
        getattr(instance, self.name).set(*value)
92

    
93
    def formfield(self, form_class=TagField, **kwargs):
94
        defaults = {
95
            "label": capfirst(self.verbose_name),
96
            "help_text": self.help_text,
97
            "required": not self.blank
98
        }
99
        defaults.update(kwargs)
100
        return form_class(**defaults)
101

    
102
    def value_from_object(self, instance):
103
        if instance.pk:
104
            return self.through.objects.filter(**self.through.lookup_kwargs(instance))
105
        return self.through.objects.none()
106

    
107
    def related_query_name(self):
108
        return self.model._meta.module_name
109

    
110
    def m2m_reverse_name(self):
111
        return self.through._meta.get_field_by_name("tag")[0].column
112

    
113
    def m2m_target_field_name(self):
114
        return self.model._meta.pk.name
115

    
116
    def m2m_reverse_target_field_name(self):
117
        return self.rel.to._meta.pk.name
118

    
119
    def m2m_column_name(self):
120
        if self.use_gfk:
121
            return self.through._meta.virtual_fields[0].fk_field
122
        return self.through._meta.get_field('content_object').column
123

    
124
    def db_type(self, connection=None):
125
        return None
126

    
127
    def m2m_db_table(self):
128
        return self.through._meta.db_table
129

    
130
    def extra_filters(self, pieces, pos, negate):
131
        if negate or not self.use_gfk:
132
            return []
133
        prefix = "__".join(["tagged_items"] + pieces[:pos-2])
134
        cts = map(ContentType.objects.get_for_model, _get_subclasses(self.model))
135
        if len(cts) == 1:
136
            return [("%s__content_type" % prefix, cts[0])]
137
        return [("%s__content_type__in" % prefix, cts)]
138

    
139
    def bulk_related_objects(self, new_objs, using):
140
        return []
141

    
142

    
143
class _TaggableManager(models.Manager):
144
    def __init__(self, through, model, instance):
145
        self.through = through
146
        self.model = model
147
        self.instance = instance
148

    
149
    def get_query_set(self):
150
        return self.through.tags_for(self.model, self.instance)
151

    
152
    def _lookup_kwargs(self):
153
        return self.through.lookup_kwargs(self.instance)
154

    
155
    @require_instance_manager
156
    def add(self, *tags):
157
        str_tags = set([
158
            t
159
            for t in tags
160
            if not isinstance(t, self.through.tag_model())
161
        ])
162
        tag_objs = set(tags) - str_tags
163
        # If str_tags has 0 elements Django actually optimizes that to not do a
164
        # query.  Malcolm is very smart.
165
        existing = self.through.tag_model().objects.filter(
166
            name__in=str_tags
167
        )
168
        tag_objs.update(existing)
169

    
170
        for new_tag in str_tags - set(t.name for t in existing):
171
            tag_objs.add(self.through.tag_model().objects.create(name=new_tag))
172

    
173
        for tag in tag_objs:
174
            self.through.objects.get_or_create(tag=tag, **self._lookup_kwargs())
175

    
176
    @require_instance_manager
177
    def set(self, *tags):
178
        self.clear()
179
        self.add(*tags)
180

    
181
    @require_instance_manager
182
    def remove(self, *tags):
183
        self.through.objects.filter(**self._lookup_kwargs()).filter(
184
            tag__name__in=tags).delete()
185

    
186
    @require_instance_manager
187
    def clear(self):
188
        self.through.objects.filter(**self._lookup_kwargs()).delete()
189

    
190
    def most_common(self):
191
        return self.get_query_set().annotate(
192
            num_times=models.Count(self.through.tag_relname())
193
        ).order_by('-num_times')
194

    
195
    @require_instance_manager
196
    def similar_objects(self):
197
        lookup_kwargs = self._lookup_kwargs()
198
        lookup_keys = sorted(lookup_kwargs)
199
        qs = self.through.objects.values(*lookup_kwargs.keys())
200
        qs = qs.annotate(n=models.Count('pk'))
201
        qs = qs.exclude(**lookup_kwargs)
202
        qs = qs.filter(tag__in=self.all())
203
        qs = qs.order_by('-n')
204

    
205
        # TODO: This all feels like a bit of a hack.
206
        items = {}
207
        if len(lookup_keys) == 1:
208
            # Can we do this without a second query by using a select_related()
209
            # somehow?
210
            f = self.through._meta.get_field_by_name(lookup_keys[0])[0]
211
            objs = f.rel.to._default_manager.filter(**{
212
                "%s__in" % f.rel.field_name: [r["content_object"] for r in qs]
213
            })
214
            for obj in objs:
215
                items[(getattr(obj, f.rel.field_name),)] = obj
216
        else:
217
            preload = {}
218
            for result in qs:
219
                preload.setdefault(result['content_type'], set())
220
                preload[result["content_type"]].add(result["object_id"])
221

    
222
            for ct, obj_ids in preload.iteritems():
223
                ct = ContentType.objects.get_for_id(ct)
224
                for obj in ct.model_class()._default_manager.filter(pk__in=obj_ids):
225
                    items[(ct.pk, obj.pk)] = obj
226

    
227
        results = []
228
        for result in qs:
229
            obj = items[
230
                tuple(result[k] for k in lookup_keys)
231
            ]
232
            obj.similar_tags = result["n"]
233
            results.append(obj)
234
        return results
235

    
236

    
237
def _get_subclasses(model):
238
    subclasses = [model]
239
    for f in model._meta.get_all_field_names():
240
        field = model._meta.get_field_by_name(f)[0]
241
        if (isinstance(field, RelatedObject) and
242
            getattr(field.field.rel, "parent_link", None)):
243
            subclasses.extend(_get_subclasses(field.model))
244
    return subclasses