Statistics
| Branch: | Tag: | Revision:

root / snf-django-lib / snf_django / lib / api / __init__.py @ bda47e03

History | View | Annotate | Download (11.5 kB)

1
# Copyright 2012, 2013 GRNET S.A. All rights reserved.
2
#
3
# Redistribution and use in source and binary forms, with or
4
# without modification, are permitted provided that the following
5
# conditions are met:
6
#
7
#   1. Redistributions of source code must retain the above
8
#      copyright notice, this list of conditions and the following
9
#      disclaimer.
10
#
11
#   2. Redistributions in binary form must reproduce the above
12
#      copyright notice, this list of conditions and the following
13
#      disclaimer in the documentation and/or other materials
14
#      provided with the distribution.
15
#
16
# THIS SOFTWARE IS PROVIDED BY GRNET S.A. ``AS IS'' AND ANY EXPRESS
17
# OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL GRNET S.A OR
20
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
23
# USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
24
# AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
26
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27
# POSSIBILITY OF SUCH DAMAGE.
28
#
29
# The views and conclusions contained in the software and
30
# documentation are those of the authors and should not be
31
# interpreted as representing official policies, either expressed
32
# or implied, of GRNET S.A.
33

    
34
from functools import wraps
35
from traceback import format_exc
36
from time import time
37
from logging import getLogger
38
from wsgiref.handlers import format_date_time
39

    
40
from django.http import HttpResponse
41
from django.utils import cache
42
from django.utils import simplejson as json
43
from django.template.loader import render_to_string
44
from django.views.decorators import csrf
45

    
46
from astakosclient import AstakosClient
47
from astakosclient.errors import AstakosClientException
48
from django.conf import settings
49
from snf_django.lib.api import faults
50

    
51
import itertools
52

    
53
log = getLogger(__name__)
54

    
55

    
56
def get_token(request):
57
    """Get the Authentication Token of a request."""
58
    token = request.GET.get("X-Auth-Token", None)
59
    if not token:
60
        token = request.META.get("HTTP_X_AUTH_TOKEN", None)
61
    return token
62

    
63

    
64
def api_method(http_method=None, token_required=True, user_required=True,
65
               logger=None, format_allowed=True, astakos_auth_url=None,
66
               serializations=None, strict_serlization=False):
67
    """Decorator function for views that implement an API method."""
68
    if not logger:
69
        logger = log
70

    
71
    serializations = serializations or ['json', 'xml']
72

    
73
    def decorator(func):
74
        @wraps(func)
75
        def wrapper(request, *args, **kwargs):
76
            try:
77
                # Get the requested serialization format
78
                serialization = get_serialization(
79
                    request, format_allowed, serializations[0])
80

    
81
                # If guessed serialization is not supported, fallback to
82
                # the default serialization or return an API error in case
83
                # strict serialization flag is set.
84
                if not serialization in serializations:
85
                    if strict_serlization:
86
                        raise faults.BadRequest(("%s serialization not "
87
                                                "supported") % serialization)
88
                    serialization = serializations[0]
89
                request.serialization = serialization
90

    
91
                # Check HTTP method
92
                if http_method and request.method != http_method:
93
                    raise faults.BadRequest("Method not allowed")
94

    
95
                # Get authentication token
96
                request.x_auth_token = None
97
                if token_required or user_required:
98
                    token = get_token(request)
99
                    if not token:
100
                        msg = "Access denied. No authentication token"
101
                        raise faults.Unauthorized(msg)
102
                    request.x_auth_token = token
103

    
104
                # Authenticate
105
                if user_required:
106
                    assert(token_required), "Can not get user without token"
107
                    astakos = astakos_auth_url or settings.ASTAKOS_AUTH_URL
108
                    astakos = AstakosClient(token, astakos,
109
                                            use_pool=True,
110
                                            retry=2,
111
                                            logger=logger)
112
                    user_info = astakos.authenticate()
113
                    request.user_uniq = user_info["access"]["user"]["id"]
114
                    request.user = user_info
115

    
116
                # Get the response object
117
                response = func(request, *args, **kwargs)
118

    
119
                # Fill in response variables
120
                update_response_headers(request, response)
121
                return response
122
            except faults.Fault, fault:
123
                if fault.code >= 500:
124
                    logger.exception("API ERROR")
125
                return render_fault(request, fault)
126
            except AstakosClientException as err:
127
                fault = faults.Fault(message=err.message,
128
                                     details=err.details,
129
                                     code=err.status)
130
                if fault.code >= 500:
131
                    logger.exception("Astakos ERROR")
132
                return render_fault(request, fault)
133
            except:
134
                logger.exception("Unexpected ERROR")
135
                fault = faults.InternalServerError("Unexpected error")
136
                return render_fault(request, fault)
137
        return csrf.csrf_exempt(wrapper)
138
    return decorator
139

    
140

    
141
def get_serialization(request, format_allowed=True,
142
                      default_serialization="json"):
143
    """Return the serialization format requested.
144

145
    Valid formats are 'json' and 'xml' and 'text'
146
    """
147

    
148
    if not format_allowed:
149
        return "text"
150

    
151
    # Try to get serialization from 'format' parameter
152
    _format = request.GET.get("format")
153
    if _format:
154
        if _format == "json":
155
            return "json"
156
        elif _format == "xml":
157
            return "xml"
158

    
159
    # Try to get serialization from path
160
    path = request.path
161
    if path.endswith(".json"):
162
        return "json"
163
    elif path.endswith(".xml"):
164
        return "xml"
165

    
166
    for item in request.META.get("HTTP_ACCEPT", "").split(","):
167
        accept, sep, rest = item.strip().partition(";")
168
        if accept == "application/json":
169
            return "json"
170
        elif accept == "application/xml":
171
            return "xml"
172

    
173
    return default_serialization
174

    
175

    
176
def update_response_headers(request, response):
177
    if not getattr(response, "override_serialization", False):
178
        serialization = request.serialization
179
        if serialization == "xml":
180
            response["Content-Type"] = "application/xml; charset=UTF-8"
181
        elif serialization == "json":
182
            response["Content-Type"] = "application/json; charset=UTF-8"
183
        elif serialization == "text":
184
            response["Content-Type"] = "text/plain; charset=UTF-8"
185
        else:
186
            raise ValueError("Unknown serialization format '%s'" %
187
                             serialization)
188

    
189
    if settings.DEBUG or getattr(settings, "TEST", False):
190
        response["Date"] = format_date_time(time())
191

    
192
    if not response.has_header("Content-Length"):
193
        _base_content_is_iter = getattr(response, '_base_content_is_iter',
194
                                        None)
195
        if (_base_content_is_iter is not None and not _base_content_is_iter):
196
            response["Content-Length"] = len(response.content)
197
        else:
198
            if not (response.has_header('Content-Type') and
199
                    response['Content-Type'].startswith(
200
                        'multipart/byteranges')):
201
                # save response content from been consumed if it is an iterator
202
                response._container, data = itertools.tee(response._container)
203
                response["Content-Length"] = len(str(data))
204

    
205
    cache.add_never_cache_headers(response)
206
    # Fix Vary and Cache-Control Headers. Issue: #3448
207
    cache.patch_vary_headers(response, ('X-Auth-Token',))
208
    cache.patch_cache_control(response, no_cache=True, no_store=True,
209
                              must_revalidate=True)
210

    
211

    
212
def render_fault(request, fault):
213
    """Render an API fault to an HTTP response."""
214
    # If running in debug mode add exception information to fault details
215
    if settings.DEBUG or getattr(settings, "TEST", False):
216
        fault.details = format_exc()
217

    
218
    try:
219
        serialization = request.serialization
220
    except AttributeError:
221
        request.serialization = "json"
222
        serialization = "json"
223

    
224
    # Serialize the fault data to xml or json
225
    if serialization == "xml":
226
        data = render_to_string("fault.xml", {"fault": fault})
227
    else:
228
        d = {fault.name: {"code": fault.code,
229
                          "message": fault.message,
230
                          "details": fault.details}}
231
        data = json.dumps(d)
232

    
233
    response = HttpResponse(data, status=fault.code)
234
    update_response_headers(request, response)
235
    return response
236

    
237

    
238
@api_method(token_required=False, user_required=False)
239
def api_endpoint_not_found(request):
240
    raise faults.BadRequest("API endpoint not found")
241

    
242

    
243
@api_method(token_required=False, user_required=False)
244
def api_method_not_allowed(request):
245
    raise faults.BadRequest('Method not allowed')
246

    
247

    
248
def allow_jsonp(key='callback'):
249
    """
250
    Wrapper to enable jsonp responses.
251
    """
252
    def wrapper(func):
253
        @wraps(func)
254
        def view_wrapper(request, *args, **kwargs):
255
            response = func(request, *args, **kwargs)
256
            if 'content-type' in response._headers and \
257
               response._headers['content-type'][1] == 'application/json':
258
                callback_name = request.GET.get(key, None)
259
                if callback_name:
260
                    response.content = "%s(%s)" % (callback_name,
261
                                                   response.content)
262
                    response._headers['content-type'] = ('Content-Type',
263
                                                         'text/javascript')
264
            return response
265
        return view_wrapper
266
    return wrapper
267

    
268

    
269
def user_in_groups(permitted_groups, logger=None):
270
    """Check that the request user belongs to one of permitted groups.
271

272
    Django view wrapper to check that the already identified request user
273
    belongs to one of the allowed groups.
274

275
    """
276
    if not logger:
277
        logger = log
278

    
279
    def decorator(func):
280
        @wraps(func)
281
        def wrapper(request, *args, **kwargs):
282
            if hasattr(request, "user") and request.user is not None:
283
                groups = request.user["access"]["user"]["roles"]
284
                groups = [g["name"] for g in groups]
285
            else:
286
                raise faults.Forbidden
287

    
288
            common_groups = set(groups) & set(permitted_groups)
289

    
290
            if not common_groups:
291
                msg = ("Not allowing access to '%s' by user '%s'. User does"
292
                       " not belong to a valid group. User groups: %s,"
293
                       " Required groups %s"
294
                       % (request.path, request.user, groups,
295
                          permitted_groups))
296
                logger.error(msg)
297
                raise faults.Forbidden
298

    
299
            logger.info("User '%s' in groups '%s' accessed view '%s'",
300
                        request.user_uniq, groups, request.path)
301

    
302
            return func(request, *args, **kwargs)
303
        return wrapper
304
    return decorator