Statistics
| Branch: | Tag: | Revision:

root / snf-astakos-app / astakos / oa2 / backends / base.py @ 96b58530

History | View | Annotate | Download (23.9 kB)

1
# Copyright 2013 GRNET S.A. All rights reserved.
2
#
3
# Redistribution and use in source and binary forms, with or
4
# without modification, are permitted provided that the following
5
# conditions are met:
6
#
7
#   1. Redistributions of source code must retain the above
8
#      copyright notice, this list of conditions and the following
9
#      disclaimer.
10
#
11
#   2. Redistributions in binary form must reproduce the above
12
#      copyright notice, this list of conditions and the following
13
#      disclaimer in the documentation and/or other materials
14
#      provided with the distribution.
15
#
16
# THIS SOFTWARE IS PROVIDED BY GRNET S.A. ``AS IS'' AND ANY EXPRESS
17
# OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL GRNET S.A OR
20
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
23
# USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
24
# AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
26
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27
# POSSIBILITY OF SUCH DAMAGE.
28
#
29
# The views and conclusions contained in the software and
30
# documentation are those of the authors and should not be
31
# interpreted as representing official policies, either expressed
32
# or implied, of GRNET S.A.
33

    
34
import urllib
35
import urlparse
36
import uuid
37
import datetime
38
import json
39

    
40
from base64 import b64encode, b64decode
41
from hashlib import sha512
42

    
43
import logging
44
logger = logging.getLogger(__name__)
45

    
46

    
47
def urlencode(params):
48
    if hasattr(params, 'urlencode') and callable(getattr(params, 'urlencode')):
49
        return params.urlencode()
50
    return urllib.urlencode(params)
51

    
52

    
53
def handles_oa2_requests(func):
54
    def wrapper(self, *args, **kwargs):
55
        if not self._errors_to_http:
56
            return func(self, *args, **kwargs)
57
        try:
58
            return func(self, *args, **kwargs)
59
        except OA2Error, e:
60
            return self.build_response_from_error(e)
61
    return wrapper
62

    
63

    
64
class OA2Error(Exception):
65
    error = None
66

    
67

    
68
class InvalidClientID(OA2Error):
69
    pass
70

    
71

    
72
class NotAuthenticatedError(OA2Error):
73
    pass
74

    
75

    
76
class InvalidClientRedirectUrl(OA2Error):
77
    pass
78

    
79

    
80
class InvalidAuthorizationRequest(OA2Error):
81
    pass
82

    
83

    
84
class Response(object):
85

    
86
    def __init__(self, status, body='', headers=None,
87
                 content_type='plain/text'):
88
        if not body:
89
            body = ''
90
        if not headers:
91
            headers = {}
92

    
93
        self.status = status
94
        self.body = body
95
        self.headers = headers
96
        self.content_type = content_type
97

    
98
    def __repr__(self):
99
        return "%d RESPONSE (BODY: %r, HEADERS: %r)" % (self.status,
100
                                                        self.body,
101
                                                        self.headers)
102

    
103

    
104
class Request(object):
105

    
106
    def __init__(self, method, path, GET=None, POST=None, META=None,
107
                 secure=False, user=None):
108
        self.method = method
109
        self.path = path
110

    
111
        if not GET:
112
            GET = {}
113
        if not POST:
114
            POST = {}
115
        if not META:
116
            META = {}
117

    
118
        self.secure = secure
119
        self.GET = GET
120
        self.POST = POST
121
        self.META = META
122
        self.user = user
123

    
124
    def __repr__(self):
125
        prepend = ""
126
        if self.secure:
127
            prepend = "SECURE "
128
        return "%s%s REQUEST (POST: %r, GET:%r, HEADERS:%r, " % (prepend,
129
                                                                 self.method,
130
                                                                 self.POST,
131
                                                                 self.GET,
132
                                                                 self.META)
133

    
134

    
135
class ORMAbstractBase(type):
136

    
137
    def __new__(cls, name, bases, attrs):
138
        attrs['ENTRIES'] = {}
139
        return super(ORMAbstractBase, cls).__new__(cls, name, bases, attrs)
140

    
141

    
142
class ORMAbstract(object):
143

    
144
    ENTRIES = {}
145

    
146
    __metaclass__ = ORMAbstractBase
147

    
148
    def __init__(self, **kwargs):
149
        for key, value in kwargs.iteritems():
150
            setattr(self, key, value)
151

    
152
    @classmethod
153
    def create(cls, id, **params):
154
        params = cls.clean_params(params)
155
        params['id'] = id
156
        cls.ENTRIES[id] = cls(**params)
157
        return cls.get(id)
158

    
159
    @classmethod
160
    def get(cls, pk):
161
        return cls.ENTRIES.get(pk)
162

    
163
    @classmethod
164
    def clean_params(cls, params):
165
        return params
166

    
167

    
168
class Client(ORMAbstract):
169

    
170
    def get_id(self):
171
        return self.id
172

    
173
    def get_redirect_uris(self):
174
        return self.uris
175

    
176
    def get_default_redirect_uri(self):
177
        return self.uris[0]
178

    
179
    def redirect_uri_is_valid(self, redirect_uri):
180
        split = urlparse.urlsplit(redirect_uri)
181
        if split.scheme not in urlparse.uses_query:
182
            raise OA2Error("Invalid redirect url scheme")
183
        uris = self.get_redirect_uris()
184
        return redirect_uri in uris
185

    
186
    def requires_auth(self):
187
        if self.client_type == 'confidential':
188
            return True
189
        return 'secret' in dir(self)
190

    
191
    def check_credentials(self, username, secret):
192
        return username == self.id and secret == self.secret
193

    
194

    
195
class Token(ORMAbstract):
196

    
197
    def to_dict(self):
198
        params = {
199
            'access_token': self.token,
200
            'token_type': self.token_type,
201
            'expires_in': self.expires,
202
        }
203
        if self.refresh_token:
204
            params['refresh_token'] = self.refresh_token
205
        return params
206

    
207

    
208
class AuthorizationCode(ORMAbstract):
209
    pass
210

    
211

    
212
class User(ORMAbstract):
213
    pass
214

    
215

    
216
class BackendBase(type):
217

    
218
    def __new__(cls, name, bases, attrs):
219
        super_new = super(BackendBase, cls).__new__
220
        #parents = [b for b in bases if isinstance(b, BackendBase)]
221
        #meta = attrs.pop('Meta', None)
222
        return super_new(cls, name, bases, attrs)
223

    
224
    @classmethod
225
    def get_orm_options(cls, attrs):
226
        meta = attrs.pop('ORM', None)
227
        orm = {}
228
        if meta:
229
            for attr in dir(meta):
230
                orm[attr] = getattr(meta, attr)
231
        return orm
232

    
233

    
234
class SimpleBackend(object):
235

    
236
    __metaclass__ = BackendBase
237

    
238
    base_url = ''
239
    endpoints_prefix = 'oauth2/'
240

    
241
    token_endpoint = 'token/'
242
    token_length = 30
243
    token_expires = 20
244

    
245
    authorization_endpoint = 'auth/'
246
    authorization_code_length = 60
247
    authorization_response_types = ['code', 'token']
248

    
249
    grant_types = ['authorization_code']
250

    
251
    response_cls = Response
252
    request_cls = Request
253

    
254
    client_model = Client
255
    token_model = Token
256
    code_model = AuthorizationCode
257
    user_model = User
258

    
259
    def __init__(self, base_url='', endpoints_prefix='oauth2/', id='oauth2',
260
                 token_endpoint='token/', token_length=30,
261
                 token_expires=20, authorization_endpoint='auth/',
262
                 authorization_code_length=60,
263
                 redirect_uri_limit=5000, **kwargs):
264
        self.base_url = base_url
265
        self.endpoints_prefix = endpoints_prefix
266
        self.token_endpoint = token_endpoint
267
        self.token_length = token_length
268
        self.token_expires = token_expires
269
        self.authorization_endpoint = authorization_endpoint
270
        self.authorization_code_length = authorization_code_length
271
        self.id = id
272
        self._errors_to_http = kwargs.get('errors_to_http', True)
273
        self.redirect_uri_limit = redirect_uri_limit
274

    
275
    # Request/response builders
276
    def build_request(self, method, get, post, meta):
277
        return self.request_cls(method=method, GET=get, POST=post, META=meta)
278

    
279
    def build_response(self, status, headers=None, body=''):
280
        return self.response_cls(status=status, headers=headers, body=body)
281

    
282
    # ORM Methods
283
    def create_authorization_code(self, user, client, code, redirect_uri,
284
                                  scope, state, **kwargs):
285
        code_params = {
286
            'code': code,
287
            'redirect_uri': redirect_uri,
288
            'client': client,
289
            'scope': scope,
290
            'state': state,
291
            'user': user
292
        }
293
        code_params.update(kwargs)
294
        code_instance = self.code_model.create(**code_params)
295
        logger.info(u'%r created' % code_instance)
296
        return code_instance
297

    
298
    def _token_params(self, value, token_type, authorization, scope):
299
        created_at = datetime.datetime.now()
300
        expires = self.token_expires
301
        expires_at = created_at + datetime.timedelta(seconds=expires)
302
        token_params = {
303
            'code': value,
304
            'token_type': token_type,
305
            'created_at': created_at,
306
            'expires_at': expires_at,
307
            'user': authorization.user,
308
            'redirect_uri': authorization.redirect_uri,
309
            'client': authorization.client,
310
            'scope': authorization.scope,
311
        }
312
        return token_params
313

    
314
    def create_token(self, value, token_type, authorization, scope,
315
                     refresh=False):
316
        params = self._token_params(value, token_type, authorization, scope)
317
        if refresh:
318
            refresh_token = self.generate_token()
319
            params['refresh_token'] = refresh_token
320
            # TODO: refresh token expires ???
321
        token = self.token_model.create(**params)
322
        logger.info(u'%r created' % token)
323
        return token
324

    
325
#    def delete_authorization_code(self, code):
326
#        del self.code_model.ENTRIES[code]
327

    
328
    def get_client_by_id(self, client_id):
329
        return self.client_model.get(client_id)
330

    
331
    def get_client_by_credentials(self, username, password):
332
        return None
333

    
334
    def get_authorization_code(self, code):
335
        return self.code_model.get(code)
336

    
337
    def get_client_authorization_code(self, client, code):
338
        code_instance = self.get_authorization_code(code)
339
        if not code_instance:
340
            raise OA2Error("Invalid code")
341

    
342
        if client.get_id() != code_instance.client.get_id():
343
            raise OA2Error("Mismatching client with code client")
344
        return code_instance
345

    
346
    def client_id_exists(self, client_id):
347
        return bool(self.get_client_by_id(client_id))
348

    
349
    def build_site_url(self, prefix='', **params):
350
        params = urlencode(params)
351
        return "%s%s%s%s" % (self.base_url, self.endpoints_prefix, prefix,
352
                             params)
353

    
354
    def _get_uri_base(self, uri):
355
        split = urlparse.urlsplit(uri)
356
        return "%s://%s%s" % (split.scheme, split.netloc, split.path)
357

    
358
    def build_client_redirect_uri(self, client, uri, **params):
359
        if not client.redirect_uri_is_valid(uri):
360
            raise OA2Error("Invalid redirect uri")
361
        params = urlencode(params)
362
        uri = self._get_uri_base(uri)
363
        return "%s?%s" % (uri, params)
364

    
365
    def generate_authorization_code(self):
366
        dg64 = b64encode(sha512(str(uuid.uuid4())).hexdigest())
367
        return dg64[:self.authorization_code_length]
368

    
369
    def generate_token(self, *args, **kwargs):
370
        dg64 = b64encode(sha512(str(uuid.uuid4())).hexdigest())
371
        return dg64[:self.token_length]
372

    
373
    def add_authorization_code(self, user, client, redirect_uri, scope, state,
374
                               **kwargs):
375
        code = self.generate_authorization_code()
376
        self.create_authorization_code(user, client, code, redirect_uri, scope,
377
                                       state, **kwargs)
378
        return code
379

    
380
    def add_token_for_client(self, token_type, authorization, refresh=False):
381
        token = self.generate_token()
382
        self.create_token(token, token_type, authorization, refresh)
383
        return token
384

    
385
    #
386
    # Response helpers
387
    #
388

    
389
    def grant_accept_response(self, client, redirect_uri, scope, state):
390
        context = {'client': client.get_id(), 'redirect_uri': redirect_uri,
391
                   'scope': scope, 'state': state,
392
                   #'url': url,
393
                   }
394
        json_content = json.dumps(context)
395
        return self.response_cls(status=200, body=json_content)
396

    
397
    def grant_token_response(self, token, token_type):
398
        context = {'access_token': token, 'token_type': token_type,
399
                   'expires_in': self.token_expires}
400
        json_content = json.dumps(context)
401
        return self.response_cls(status=200, body=json_content)
402

    
403
    def redirect_to_login_response(self, request, params):
404
        parts = list(urlparse.urlsplit(request.path))
405
        parts[3] = urlencode(params)
406
        query = {'next': urlparse.urlunsplit(parts)}
407
        return Response(302,
408
                        headers={'Location': '%s?%s' %
409
                                 (self.get_login_uri(),
410
                                  urlencode(query))})
411

    
412
    def redirect_to_uri(self, redirect_uri, code, state=None):
413
        parts = list(urlparse.urlsplit(redirect_uri))
414
        params = dict(urlparse.parse_qsl(parts[3], keep_blank_values=True))
415
        params['code'] = code
416
        if state is not None:
417
            params['state'] = state
418
        parts[3] = urlencode(params)
419
        return Response(302,
420
                        headers={'Location': '%s' %
421
                                 urlparse.urlunsplit(parts)})
422

    
423
    def build_response_from_error(self, exception):
424
        response = Response(400)
425
        logger.exception(exception)
426
        error = 'generic_error'
427
        if exception.error:
428
            error = exception.error
429
        body = {
430
            'error': error,
431
            'exception': exception.message,
432
        }
433
        response.body = json.dumps(body)
434
        response.content_type = "application/json"
435
        return response
436

    
437
    #
438
    # Processor methods
439
    #
440

    
441
    def process_code_request(self, user, client, uri, scope, state):
442
        code = self.add_authorization_code(user, client, uri, scope, state)
443
        return self.redirect_to_uri(uri, code, state)
444

    
445
    #
446
    # Helpers
447
    #
448

    
449
    def grant_authorization_code(self, client, code_instance, redirect_uri,
450
                                 scope=None, token_type="Bearer"):
451
        if scope and code_instance.scope != scope:
452
            raise OA2Error("Invalid scope")
453
        if redirect_uri != code_instance.redirect_uri:
454
            raise OA2Error("The redirect uri does not match "
455
                           "the one used during authorization")
456
        token = self.add_token_for_client(token_type, code_instance)
457
        self.delete_authorization_code(code_instance)  # use only once
458
        return token, token_type
459

    
460
    def consume_token(self, token):
461
        token_instance = self.get_token(token)
462
        if datetime.datetime.now() > token_instance.expires_at:
463
            self.delete_token(token_instance)  # delete expired token
464
            raise OA2Error("Token has expired")
465
        # TODO: delete token?
466
        return token_instance
467

    
468
    def _get_credentials(self, params, headers):
469
        if 'HTTP_AUTHORIZATION' in headers:
470
            scheme, b64credentials = headers.get(
471
                'HTTP_AUTHORIZATION').split(" ")
472
            if scheme != 'Basic':
473
                # TODO: raise 401 + WWW-Authenticate
474
                raise OA2Error("Unsupported authorization scheme")
475
            credentials = b64decode(b64credentials).split(":")
476
            return scheme, credentials
477
        else:
478
            return None, None
479
        pass
480

    
481
    def _get_authorization(self, params, headers, authorization_required=True):
482
        scheme, client_credentials = self._get_credentials(params, headers)
483
        no_authorization = scheme is None and client_credentials is None
484
        if authorization_required and no_authorization:
485
            raise OA2Error("Missing authorization header")
486
        return client_credentials
487

    
488
    def get_redirect_uri_from_params(self, client, params, default=True):
489
        """
490
        Accepts a client instance and request parameters.
491
        """
492
        redirect_uri = params.get('redirect_uri', None)
493
        if not redirect_uri and default:
494
            redirect_uri = client.get_default_redirect_uri()
495
        else:
496
            # TODO: sanitize redirect_uri (self.clean_redirect_uri ???)
497
            # clean and validate
498
            if not client.redirect_uri_is_valid(redirect_uri):
499
                raise OA2Error("Invalid client redirect uri")
500
        return redirect_uri
501

    
502
    #
503
    # Request identifiers
504
    #
505

    
506
    def identify_authorize_request(self, params, headers):
507
        return params.get('response_type'), params
508

    
509
    def identify_token_request(self, headers, params):
510
        content_type = headers.get('CONTENT_TYPE')
511
        if content_type != 'application/x-www-form-urlencoded':
512
            raise OA2Error("Invalid Content-Type header")
513
        return params.get('grant_type')
514

    
515
    #
516
    # Parameters validation methods
517
    #
518

    
519
    def validate_client(self, params, meta, requires_auth=True,
520
                        client_id_required=True):
521
        client_id = params.get('client_id')
522
        if client_id is None and client_id_required:
523
            raise OA2Error("Client identification is required")
524

    
525
        client_credentials = None
526
        try:  # check authorization header
527
            client_credentials = self._get_authorization(params, meta,
528
                                                         authorization_required=False)
529
        except:
530
            pass
531
        else:
532
            if client_credentials is not None:
533
                _client_id = client_credentials[0]
534
                if client_id is not None and client_id != _client_id:
535
                    raise OA2Error("Client identification conflicts "
536
                                   "with client authorization")
537
                client_id = _client_id
538

    
539
        if client_id is None:
540
            raise OA2Error("Missing client identification")
541

    
542
        client = self.get_client_by_id(client_id)
543

    
544
        if requires_auth and client.requires_auth():
545
            if client_credentials is None:
546
                raise OA2Error("Client authentication is required")
547

    
548
        if client_credentials is not None:
549
            self.check_credentials(client, *client_credentials)
550
        return client
551

    
552
    def validate_redirect_uri(self, client, params, headers,
553
                              allow_default=True, is_required=False,
554
                              expected_value=None):
555
        redirect_uri = params.get('redirect_uri')
556
        if is_required and redirect_uri is None:
557
            raise OA2Error("Missing redirect uri")
558
        if redirect_uri is not None:
559
            if not bool(urlparse.urlparse(redirect_uri).scheme):
560
                raise OA2Error("Redirect uri should be an absolute URI")
561
            if len(redirect_uri) > self.redirect_uri_limit:
562
                raise OA2Error("Redirect uri length limit exceeded")
563
            if not client.redirect_uri_is_valid(redirect_uri):
564
                raise OA2Error("Mismatching redirect uri")
565
            if expected_value is not None and redirect_uri != expected_value:
566
                raise OA2Error("Invalid redirect uri")
567
        else:
568
            try:
569
                redirect_uri = client.redirecturl_set.values_list('url',
570
                                                                  flat=True)[0]
571
            except IndexError:
572
                raise OA2Error("Unable to fallback to client redirect URI")
573
        return redirect_uri
574

    
575
    def validate_state(self, client, params, headers):
576
        return params.get('state')
577
        raise OA2Error("Invalid state")
578

    
579
    def validate_scope(self, client, params, headers):
580
        scope = params.get('scope')
581
        if scope is not None:
582
            scope = scope.split(' ')[0]  # keep only the first
583
        # TODO: check for invalid characters
584
        return scope
585

    
586
    def validate_code(self, client, params, headers):
587
        code = params.get('code')
588
        if code is None:
589
            raise OA2Error("Missing authorization code")
590
        return self.get_client_authorization_code(client, code)
591

    
592
    #
593
    # Requests validation methods
594
    #
595

    
596
    def validate_code_request(self, params, headers):
597
        client = self.validate_client(params, headers, requires_auth=False)
598
        redirect_uri = self.validate_redirect_uri(client, params, headers)
599
        scope = self.validate_scope(client, params, headers)
600
        scope = scope or redirect_uri  # set default
601
        state = self.validate_state(client, params, headers)
602
        return client, redirect_uri, scope, state
603

    
604
    def validate_token_request(self, params, headers, requires_auth=False):
605
        client = self.validate_client(params, headers)
606
        redirect_uri = self.validate_redirect_uri(client, params, headers)
607
        scope = self.validate_scope(client, params, headers)
608
        scope = scope or redirect_uri  # set default
609
        state = self.validate_state(client, params, headers)
610
        return client, redirect_uri, scope, state
611

    
612
    def validate_code_grant(self, params, headers):
613
        client = self.validate_client(params, headers,
614
                                      client_id_required=False)
615
        code_instance = self.validate_code(client, params, headers)
616
        redirect_uri = self.validate_redirect_uri(
617
            client, params, headers,
618
            expected_value=code_instance.redirect_uri)
619
        return client, redirect_uri, code_instance
620

    
621
    #
622
    # Endpoint methods
623
    #
624

    
625
    @handles_oa2_requests
626
    def authorize(self, request, **extra):
627
        """
628
        Used in the following cases
629
        """
630
        if not request.secure:
631
            raise OA2Error("Secure request required")
632

    
633
        # identify
634
        request_params = request.GET
635
        if request.method == "POST":
636
            request_params = request.POST
637

    
638
        auth_type, params = self.identify_authorize_request(request_params,
639
                                                            request.META)
640

    
641
        if auth_type is None:
642
            raise OA2Error("Missing authorization type")
643
        if auth_type == 'code':
644
            client, uri, scope, state = \
645
                self.validate_code_request(params, request.META)
646
        elif auth_type == 'token':
647
            raise OA2Error("Unsupported authorization type")
648
#            client, uri, scope, state = \
649
#                self.validate_token_request(params, request.META)
650
        else:
651
            #TODO: handle custom type
652
            raise OA2Error("Invalid authorization type")
653

    
654
        user = getattr(request, 'user', None)
655
        if not user:
656
            return self.redirect_to_login_response(request, params)
657

    
658
        if request.method == 'POST':
659
            if auth_type == 'code':
660
                return self.process_code_request(user, client, uri, scope,
661
                                                 state)
662
            elif auth_type == 'token':
663
                raise OA2Error("Unsupported response type")
664
#                return self.process_token_request(user, client, uri, scope,
665
#                                                 state)
666
            else:
667
                #TODO: handle custom type
668
                raise OA2Error("Invalid authorization type")
669
        else:
670
            if client.is_trusted:
671
                return self.process_code_request(user, client, uri, scope,
672
                                                 state)
673
            else:
674
                return self.grant_accept_response(client, uri, scope, state)
675

    
676
    @handles_oa2_requests
677
    def grant_token(self, request, **extra):
678
        """
679
        Used in the following cases
680
        """
681
        if not request.secure:
682
            raise OA2Error("Secure request required")
683

    
684
        grant_type = self.identify_token_request(request.META, request.POST)
685

    
686
        if grant_type is None:
687
            raise OA2Error("Missing grant type")
688
        elif grant_type == 'authorization_code':
689
            client, redirect_uri, code = \
690
                self.validate_code_grant(request.POST, request.META)
691
            token, token_type = \
692
                self.grant_authorization_code(client, code, redirect_uri)
693
            return self.grant_token_response(token, token_type)
694
        elif (grant_type in ['client_credentials', 'token'] or
695
              self.is_uri(grant_type)):
696
            raise OA2Error("Unsupported grant type")
697
        else:
698
            #TODO: handle custom type
699
            raise OA2Error("Invalid grant type")