Statistics
| Branch: | Tag: | Revision:

root / snf-astakos-app / astakos / oa2 / backends / base.py @ 68122bae

History | View | Annotate | Download (23.7 kB)

1
# Copyright 2013 GRNET S.A. All rights reserved.
2
#
3
# Redistribution and use in source and binary forms, with or
4
# without modification, are permitted provided that the following
5
# conditions are met:
6
#
7
#   1. Redistributions of source code must retain the above
8
#      copyright notice, this list of conditions and the following
9
#      disclaimer.
10
#
11
#   2. Redistributions in binary form must reproduce the above
12
#      copyright notice, this list of conditions and the following
13
#      disclaimer in the documentation and/or other materials
14
#      provided with the distribution.
15
#
16
# THIS SOFTWARE IS PROVIDED BY GRNET S.A. ``AS IS'' AND ANY EXPRESS
17
# OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL GRNET S.A OR
20
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
23
# USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
24
# AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
26
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27
# POSSIBILITY OF SUCH DAMAGE.
28
#
29
# The views and conclusions contained in the software and
30
# documentation are those of the authors and should not be
31
# interpreted as representing official policies, either expressed
32
# or implied, of GRNET S.A.
33

    
34
import urllib
35
import urlparse
36
import uuid
37
import datetime
38
import json
39

    
40
from base64 import b64encode, b64decode
41
from hashlib import sha512
42

    
43
import logging
44
logger = logging.getLogger(__name__)
45

    
46

    
47
def urlencode(params):
48
    if hasattr(params, 'urlencode') and callable(getattr(params, 'urlencode')):
49
        return params.urlencode()
50
    return urllib.urlencode(params)
51

    
52

    
53
def handles_oa2_requests(func):
54
    def wrapper(self, *args, **kwargs):
55
        if not self._errors_to_http:
56
            return func(self, *args, **kwargs)
57
        try:
58
            return func(self, *args, **kwargs)
59
        except OA2Error, e:
60
            return self.build_response_from_error(e)
61
    return wrapper
62

    
63

    
64
class OA2Error(Exception):
65
    error = None
66

    
67

    
68
class InvalidClientID(OA2Error):
69
    pass
70

    
71

    
72
class NotAuthenticatedError(OA2Error):
73
    pass
74

    
75

    
76
class InvalidClientRedirectUrl(OA2Error):
77
    pass
78

    
79

    
80
class InvalidAuthorizationRequest(OA2Error):
81
    pass
82

    
83

    
84
class Response(object):
85

    
86
    def __init__(self, status, body='', headers=None,
87
                 content_type='plain/text'):
88
        if not body:
89
            body = ''
90
        if not headers:
91
            headers = {}
92

    
93
        self.status = status
94
        self.body = body
95
        self.headers = headers
96
        self.content_type = content_type
97

    
98
    def __repr__(self):
99
        return "%d RESPONSE (BODY: %r, HEADERS: %r)" % (self.status,
100
                                                        self.body,
101
                                                        self.headers)
102

    
103

    
104
class Request(object):
105

    
106
    def __init__(self, method, path, GET=None, POST=None, META=None,
107
                 secure=False, user=None):
108
        self.method = method
109
        self.path = path
110

    
111
        if not GET:
112
            GET = {}
113
        if not POST:
114
            POST = {}
115
        if not META:
116
            META = {}
117

    
118
        self.secure = secure
119
        self.GET = GET
120
        self.POST = POST
121
        self.META = META
122
        self.user = user
123

    
124
    def __repr__(self):
125
        prepend = ""
126
        if self.secure:
127
            prepend = "SECURE "
128
        return "%s%s REQUEST (POST: %r, GET:%r, HEADERS:%r, " % (prepend,
129
                                                                 self.method,
130
                                                                 self.POST,
131
                                                                 self.GET,
132
                                                                 self.META)
133

    
134

    
135
class ORMAbstractBase(type):
136

    
137
    def __new__(cls, name, bases, attrs):
138
        attrs['ENTRIES'] = {}
139
        return super(ORMAbstractBase, cls).__new__(cls, name, bases, attrs)
140

    
141

    
142
class ORMAbstract(object):
143

    
144
    ENTRIES = {}
145

    
146
    __metaclass__ = ORMAbstractBase
147

    
148
    def __init__(self, **kwargs):
149
        for key, value in kwargs.iteritems():
150
            setattr(self, key, value)
151

    
152
    @classmethod
153
    def create(cls, id, **params):
154
        params = cls.clean_params(params)
155
        params['id'] = id
156
        cls.ENTRIES[id] = cls(**params)
157
        return cls.get(id)
158

    
159
    @classmethod
160
    def get(cls, pk):
161
        return cls.ENTRIES.get(pk)
162

    
163
    @classmethod
164
    def clean_params(cls, params):
165
        return params
166

    
167

    
168
class Client(ORMAbstract):
169

    
170
    def get_id(self):
171
        return self.id
172

    
173
    def get_redirect_uris(self):
174
        return self.uris
175

    
176
    def get_default_redirect_uri(self):
177
        return self.uris[0]
178

    
179
    def redirect_uri_is_valid(self, redirect_uri):
180
        split = urlparse.urlsplit(redirect_uri)
181
        if split.scheme not in urlparse.uses_query:
182
            raise OA2Error("Invalid redirect url scheme")
183
        uris = self.get_redirect_uris()
184
        return redirect_uri in uris
185

    
186
    def requires_auth(self):
187
        if self.client_type == 'confidential':
188
            return True
189
        return 'secret' in dir(self)
190

    
191
    def check_credentials(self, username, secret):
192
        return username == self.id and secret == self.secret
193

    
194

    
195
class Token(ORMAbstract):
196

    
197
    def to_dict(self):
198
        params = {
199
            'access_token': self.token,
200
            'token_type': self.token_type,
201
            'expires_in': self.expires,
202
        }
203
        if self.refresh_token:
204
            params['refresh_token'] = self.refresh_token
205
        return params
206

    
207

    
208
class AuthorizationCode(ORMAbstract):
209
    pass
210

    
211

    
212
class User(ORMAbstract):
213
    pass
214

    
215

    
216
class BackendBase(type):
217

    
218
    def __new__(cls, name, bases, attrs):
219
        super_new = super(BackendBase, cls).__new__
220
        #parents = [b for b in bases if isinstance(b, BackendBase)]
221
        #meta = attrs.pop('Meta', None)
222
        return super_new(cls, name, bases, attrs)
223

    
224
    @classmethod
225
    def get_orm_options(cls, attrs):
226
        meta = attrs.pop('ORM', None)
227
        orm = {}
228
        if meta:
229
            for attr in dir(meta):
230
                orm[attr] = getattr(meta, attr)
231
        return orm
232

    
233

    
234
class SimpleBackend(object):
235

    
236
    __metaclass__ = BackendBase
237

    
238
    base_url = ''
239
    endpoints_prefix = 'oauth2/'
240

    
241
    token_endpoint = 'token/'
242
    token_length = 30
243
    token_expires = 20
244

    
245
    authorization_endpoint = 'auth/'
246
    authorization_code_length = 60
247
    authorization_response_types = ['code', 'token']
248

    
249
    grant_types = ['authorization_code']
250

    
251
    response_cls = Response
252
    request_cls = Request
253

    
254
    client_model = Client
255
    token_model = Token
256
    code_model = AuthorizationCode
257
    user_model = User
258

    
259
    def __init__(self, base_url='', endpoints_prefix='oauth2/', id='oauth2',
260
                 token_endpoint='token/', token_length=30,
261
                 token_expires=20, authorization_endpoint='auth/',
262
                 authorization_code_length=60,
263
                 redirect_uri_limit=5000, **kwargs):
264
        self.base_url = base_url
265
        self.endpoints_prefix = endpoints_prefix
266
        self.token_endpoint = token_endpoint
267
        self.token_length = token_length
268
        self.token_expires = token_expires
269
        self.authorization_endpoint = authorization_endpoint
270
        self.authorization_code_length = authorization_code_length
271
        self.id = id
272
        self._errors_to_http = kwargs.get('errors_to_http', True)
273
        self.redirect_uri_limit = redirect_uri_limit
274

    
275
    # Request/response builders
276
    def build_request(self, method, get, post, meta):
277
        return self.request_cls(method=method, GET=get, POST=post, META=meta)
278

    
279
    def build_response(self, status, headers=None, body=''):
280
        return self.response_cls(status=status, headers=headers, body=body)
281

    
282
    # ORM Methods
283
    def create_authorization_code(self, user, client, code, redirect_uri,
284
                                  scope, state, **kwargs):
285
        code_params = {
286
            'code': code,
287
            'redirect_uri': redirect_uri,
288
            'client': client,
289
            'scope': scope,
290
            'state': state,
291
            'user': user
292
        }
293
        code_params.update(kwargs)
294
        code_instance = self.code_model.create(**code_params)
295
        logger.info(u'%r created' % code_instance)
296
        return code_instance
297

    
298
    def _token_params(self, value, token_type, authorization, scope):
299
        created_at = datetime.datetime.now()
300
        expires = self.token_expires
301
        expires_at = created_at + datetime.timedelta(seconds=expires)
302
        token_params = {
303
            'code': value,
304
            'token_type': token_type,
305
            'created_at': created_at,
306
            'expires_at': expires_at,
307
            'user': authorization.user,
308
            'redirect_uri': authorization.redirect_uri,
309
            'client': authorization.client,
310
            'scope': authorization.scope,
311
        }
312
        return token_params
313

    
314
    def create_token(self, value, token_type, authorization, scope,
315
                     refresh=False):
316
        params = self._token_params(value, token_type, authorization, scope)
317
        if refresh:
318
            refresh_token = self.generate_token()
319
            params['refresh_token'] = refresh_token
320
            # TODO: refresh token expires ???
321
        token = self.token_model.create(**params)
322
        logger.info(u'%r created' % token)
323
        return token
324

    
325
#    def delete_authorization_code(self, code):
326
#        del self.code_model.ENTRIES[code]
327

    
328
    def get_client_by_id(self, client_id):
329
        return self.client_model.get(client_id)
330

    
331
    def get_client_by_credentials(self, username, password):
332
        return None
333

    
334
    def get_authorization_code(self, code):
335
        return self.code_model.get(code)
336

    
337
    def get_client_authorization_code(self, client, code):
338
        code_instance = self.get_authorization_code(code)
339
        if not code_instance:
340
            raise OA2Error("Invalid code")
341

    
342
        if client.get_id() != code_instance.client.get_id():
343
            raise OA2Error("Mismatching client with code client")
344
        return code_instance
345

    
346
    def client_id_exists(self, client_id):
347
        return bool(self.get_client_by_id(client_id))
348

    
349
    def build_site_url(self, prefix='', **params):
350
        params = urlencode(params)
351
        return "%s%s%s%s" % (self.base_url, self.endpoints_prefix, prefix,
352
                             params)
353

    
354
    def _get_uri_base(self, uri):
355
        split = urlparse.urlsplit(uri)
356
        return "%s://%s%s" % (split.scheme, split.netloc, split.path)
357

    
358
    def build_client_redirect_uri(self, client, uri, **params):
359
        if not client.redirect_uri_is_valid(uri):
360
            raise OA2Error("Invalid redirect uri")
361
        params = urlencode(params)
362
        uri = self._get_uri_base(uri)
363
        return "%s?%s" % (uri, params)
364

    
365
    def generate_authorization_code(self):
366
        dg64 = b64encode(sha512(str(uuid.uuid4())).hexdigest())
367
        return dg64[:self.authorization_code_length]
368

    
369
    def generate_token(self, *args, **kwargs):
370
        dg64 = b64encode(sha512(str(uuid.uuid4())).hexdigest())
371
        return dg64[:self.token_length]
372

    
373
    def add_authorization_code(self, user, client, redirect_uri, scope, state,
374
                               **kwargs):
375
        code = self.generate_authorization_code()
376
        self.create_authorization_code(user, client, code, redirect_uri, scope,
377
                                       state, **kwargs)
378
        return code
379

    
380
    def add_token_for_client(self, token_type, authorization, refresh=False):
381
        token = self.generate_token()
382
        self.create_token(token, token_type, authorization, refresh)
383
        return token
384

    
385
    #
386
    # Response helpers
387
    #
388

    
389
    def grant_accept_response(self, client, redirect_uri, scope, state):
390
        context = {'client': client.get_id(), 'redirect_uri': redirect_uri,
391
                   'scope': scope, 'state': state,
392
                   #'url': url,
393
                   }
394
        json_content = json.dumps(context)
395
        return self.response_cls(status=200, body=json_content)
396

    
397
    def grant_token_response(self, token, token_type):
398
        context = {'access_token': token, 'token_type': token_type,
399
                   'expires_in': self.token_expires}
400
        json_content = json.dumps(context)
401
        return self.response_cls(status=200, body=json_content)
402

    
403
    def redirect_to_login_response(self, request, params):
404
        parts = list(urlparse.urlsplit(request.path))
405
        parts[3] = urlencode(params)
406
        query = {'next': urlparse.urlunsplit(parts)}
407
        return Response(302,
408
                        headers={'Location': '%s?%s' %
409
                                 (self.get_login_uri(),
410
                                  urlencode(query))})
411

    
412
    def redirect_to_uri(self, redirect_uri, code, state=None):
413
        parts = list(urlparse.urlsplit(redirect_uri))
414
        params = dict(urlparse.parse_qsl(parts[3], keep_blank_values=True))
415
        params['code'] = code
416
        if state is not None:
417
            params['state'] = state
418
        parts[3] = urlencode(params)
419
        return Response(302,
420
                        headers={'Location': '%s' %
421
                                 urlparse.urlunsplit(parts)})
422

    
423
    def build_response_from_error(self, exception):
424
        response = Response(400)
425
        logger.exception(exception)
426
        error = 'generic_error'
427
        if exception.error:
428
            error = exception.error
429
        body = {
430
            'error': error,
431
            'exception': exception.message,
432
        }
433
        response.body = json.dumps(body)
434
        response.content_type = "application/json"
435
        return response
436

    
437
    #
438
    # Processor methods
439
    #
440

    
441
    def process_code_request(self, user, client, uri, scope, state):
442
        code = self.add_authorization_code(user, client, uri, scope, state)
443
        return self.redirect_to_uri(uri, code, state)
444

    
445
    #
446
    # Helpers
447
    #
448

    
449
    def grant_authorization_code(self, client, code_instance, redirect_uri,
450
                                 scope=None, token_type="Bearer"):
451
        if scope and code_instance.scope != scope:
452
            raise OA2Error("Invalid scope")
453
        if redirect_uri != code_instance.redirect_uri:
454
            raise OA2Error("The redirect uri does not match "
455
                           "the one used during authorization")
456
        token = self.add_token_for_client(token_type, code_instance)
457
        self.delete_authorization_code(code_instance)  # use only once
458
        return token, token_type
459

    
460
    def consume_token(self, token):
461
        token_instance = self.get_token(token)
462
        if datetime.datetime.now() > token_instance.expires_at:
463
            self.delete_token(token_instance)  # delete expired token
464
            raise OA2Error("Token has expired")
465
        # TODO: delete token?
466
        return token_instance
467

    
468
    def _get_credentials(self, params, headers):
469
        if 'HTTP_AUTHORIZATION' in headers:
470
            scheme, b64credentials = headers.get(
471
                'HTTP_AUTHORIZATION').split(" ")
472
            if scheme != 'Basic':
473
                # TODO: raise 401 + WWW-Authenticate
474
                raise OA2Error("Unsupported authorization scheme")
475
            credentials = b64decode(b64credentials).split(":")
476
            return scheme, credentials
477
        else:
478
            return None, None
479
        pass
480

    
481
    def _get_authorization(self, params, headers):
482
        scheme, client_credentials = self._get_credentials(params, headers)
483
        no_authorization = scheme is None and client_credentials is None
484
        if no_authorization:
485
            raise OA2Error("Missing authorization header")
486
        return client_credentials
487

    
488
    def get_redirect_uri_from_params(self, client, params, default=True):
489
        """
490
        Accepts a client instance and request parameters.
491
        """
492
        redirect_uri = params.get('redirect_uri', None)
493
        if not redirect_uri and default:
494
            redirect_uri = client.get_default_redirect_uri()
495
        else:
496
            # TODO: sanitize redirect_uri (self.clean_redirect_uri ???)
497
            # clean and validate
498
            if not client.redirect_uri_is_valid(redirect_uri):
499
                raise OA2Error("Invalid client redirect uri")
500
        return redirect_uri
501

    
502
    #
503
    # Request identifiers
504
    #
505

    
506
    def identify_authorize_request(self, params, headers):
507
        return params.get('response_type'), params
508

    
509
    def identify_token_request(self, headers, params):
510
        content_type = headers.get('CONTENT_TYPE')
511
        if content_type != 'application/x-www-form-urlencoded':
512
            raise OA2Error("Invalid Content-Type header")
513
        return params.get('grant_type')
514

    
515
    #
516
    # Parameters validation methods
517
    #
518

    
519
    def validate_client(self, params, meta, requires_auth=True,
520
                        client_id_required=True):
521
        client_id = params.get('client_id')
522
        if client_id is None and client_id_required:
523
            raise OA2Error("Client identification is required")
524

    
525
        client_credentials = None
526
        try:  # check authorization header
527
            client_credentials = self._get_authorization(params, meta)
528
            if client_credentials is not None:
529
                _client_id = client_credentials[0]
530
                if client_id is not None and client_id != _client_id:
531
                    raise OA2Error("Client identification conflicts "
532
                                   "with client authorization")
533
                client_id = _client_id
534
        except:
535
            pass
536

    
537
        if client_id is None:
538
            raise OA2Error("Missing client identification")
539

    
540
        client = self.get_client_by_id(client_id)
541

    
542
        if requires_auth and client.requires_auth():
543
            if client_credentials is None:
544
                raise OA2Error("Client authentication is required")
545

    
546
        if client_credentials is not None:
547
            self.check_credentials(client, *client_credentials)
548
        return client
549

    
550
    def validate_redirect_uri(self, client, params, headers,
551
                              allow_default=True, is_required=False,
552
                              expected_value=None):
553
        redirect_uri = params.get('redirect_uri')
554
        if is_required and redirect_uri is None:
555
            raise OA2Error("Missing redirect uri")
556
        if redirect_uri is not None:
557
            if not bool(urlparse.urlparse(redirect_uri).scheme):
558
                raise OA2Error("Redirect uri should be an absolute URI")
559
            if len(redirect_uri) > self.redirect_uri_limit:
560
                raise OA2Error("Redirect uri length limit exceeded")
561
            if not client.redirect_uri_is_valid(redirect_uri):
562
                raise OA2Error("Mismatching redirect uri")
563
            if expected_value is not None and redirect_uri != expected_value:
564
                raise OA2Error("Invalid redirect uri")
565
        else:
566
            try:
567
                redirect_uri = client.redirecturl_set.values_list('url',
568
                                                                  flat=True)[0]
569
            except IndexError:
570
                raise OA2Error("Unable to fallback to client redirect URI")
571
        return redirect_uri
572

    
573
    def validate_state(self, client, params, headers):
574
        return params.get('state')
575
        raise OA2Error("Invalid state")
576

    
577
    def validate_scope(self, client, params, headers):
578
        scope = params.get('scope')
579
        if scope is not None:
580
            scope = scope.split(' ')[0]  # keep only the first
581
        # TODO: check for invalid characters
582
        return scope
583

    
584
    def validate_code(self, client, params, headers):
585
        code = params.get('code')
586
        if code is None:
587
            raise OA2Error("Missing authorization code")
588
        return self.get_client_authorization_code(client, code)
589

    
590
    #
591
    # Requests validation methods
592
    #
593

    
594
    def validate_code_request(self, params, headers):
595
        client = self.validate_client(params, headers, requires_auth=False)
596
        redirect_uri = self.validate_redirect_uri(client, params, headers)
597
        scope = self.validate_scope(client, params, headers)
598
        scope = scope or redirect_uri  # set default
599
        state = self.validate_state(client, params, headers)
600
        return client, redirect_uri, scope, state
601

    
602
    def validate_token_request(self, params, headers, requires_auth=False):
603
        client = self.validate_client(params, headers)
604
        redirect_uri = self.validate_redirect_uri(client, params, headers)
605
        scope = self.validate_scope(client, params, headers)
606
        scope = scope or redirect_uri  # set default
607
        state = self.validate_state(client, params, headers)
608
        return client, redirect_uri, scope, state
609

    
610
    def validate_code_grant(self, params, headers):
611
        client = self.validate_client(params, headers,
612
                                      client_id_required=False)
613
        code_instance = self.validate_code(client, params, headers)
614
        redirect_uri = self.validate_redirect_uri(
615
            client, params, headers,
616
            expected_value=code_instance.redirect_uri)
617
        return client, redirect_uri, code_instance
618

    
619
    #
620
    # Endpoint methods
621
    #
622

    
623
    @handles_oa2_requests
624
    def authorize(self, request, **extra):
625
        """
626
        Used in the following cases
627
        """
628
        if not request.secure:
629
            raise OA2Error("Secure request required")
630

    
631
        # identify
632
        request_params = request.GET
633
        if request.method == "POST":
634
            request_params = request.POST
635

    
636
        auth_type, params = self.identify_authorize_request(request_params,
637
                                                            request.META)
638

    
639
        if auth_type is None:
640
            raise OA2Error("Missing authorization type")
641
        if auth_type == 'code':
642
            client, uri, scope, state = \
643
                self.validate_code_request(params, request.META)
644
        elif auth_type == 'token':
645
            raise OA2Error("Unsupported authorization type")
646
#            client, uri, scope, state = \
647
#                self.validate_token_request(params, request.META)
648
        else:
649
            #TODO: handle custom type
650
            raise OA2Error("Invalid authorization type")
651

    
652
        user = getattr(request, 'user', None)
653
        if not user:
654
            return self.redirect_to_login_response(request, params)
655

    
656
        if request.method == 'POST':
657
            if auth_type == 'code':
658
                return self.process_code_request(user, client, uri, scope,
659
                                                 state)
660
            elif auth_type == 'token':
661
                raise OA2Error("Unsupported response type")
662
#                return self.process_token_request(user, client, uri, scope,
663
#                                                 state)
664
            else:
665
                #TODO: handle custom type
666
                raise OA2Error("Invalid authorization type")
667
        else:
668
            if client.is_trusted:
669
                return self.process_code_request(user, client, uri, scope,
670
                                                 state)
671
            else:
672
                return self.grant_accept_response(client, uri, scope, state)
673

    
674
    @handles_oa2_requests
675
    def grant_token(self, request, **extra):
676
        """
677
        Used in the following cases
678
        """
679
        if not request.secure:
680
            raise OA2Error("Secure request required")
681

    
682
        grant_type = self.identify_token_request(request.META, request.POST)
683

    
684
        if grant_type is None:
685
            raise OA2Error("Missing grant type")
686
        elif grant_type == 'authorization_code':
687
            client, redirect_uri, code = \
688
                self.validate_code_grant(request.POST, request.META)
689
            token, token_type = \
690
                self.grant_authorization_code(client, code, redirect_uri)
691
            return self.grant_token_response(token, token_type)
692
        elif (grant_type in ['client_credentials', 'token'] or
693
              self.is_uri(grant_type)):
694
            raise OA2Error("Unsupported grant type")
695
        else:
696
            #TODO: handle custom type
697
            raise OA2Error("Invalid grant type")