Statistics
| Branch: | Tag: | Revision:

root / snf-astakos-app / astakos / oa2 / backends / base.py @ b806a15a

History | View | Annotate | Download (23.6 kB)

1
# Copyright 2013 GRNET S.A. All rights reserved.
2
#
3
# Redistribution and use in source and binary forms, with or
4
# without modification, are permitted provided that the following
5
# conditions are met:
6
#
7
#   1. Redistributions of source code must retain the above
8
#      copyright notice, this list of conditions and the following
9
#      disclaimer.
10
#
11
#   2. Redistributions in binary form must reproduce the above
12
#      copyright notice, this list of conditions and the following
13
#      disclaimer in the documentation and/or other materials
14
#      provided with the distribution.
15
#
16
# THIS SOFTWARE IS PROVIDED BY GRNET S.A. ``AS IS'' AND ANY EXPRESS
17
# OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL GRNET S.A OR
20
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
23
# USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
24
# AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
26
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27
# POSSIBILITY OF SUCH DAMAGE.
28
#
29
# The views and conclusions contained in the software and
30
# documentation are those of the authors and should not be
31
# interpreted as representing official policies, either expressed
32
# or implied, of GRNET S.A.
33

    
34
import urllib
35
import urlparse
36
import uuid
37
import datetime
38
import json
39

    
40
from base64 import b64encode, b64decode
41
from hashlib import sha512
42

    
43
import logging
44
logger = logging.getLogger(__name__)
45

    
46

    
47
def handles_oa2_requests(func):
48
    def wrapper(self, *args, **kwargs):
49
        if not self._errors_to_http:
50
            return func(self, *args, **kwargs)
51
        try:
52
            return func(self, *args, **kwargs)
53
        except OA2Error, e:
54
            return self.build_response_from_error(e)
55
    return wrapper
56

    
57

    
58
class OA2Error(Exception):
59
    error = None
60

    
61

    
62
class InvalidClientID(OA2Error):
63
    pass
64

    
65

    
66
class NotAuthenticatedError(OA2Error):
67
    pass
68

    
69

    
70
class InvalidClientRedirectUrl(OA2Error):
71
    pass
72

    
73

    
74
class InvalidAuthorizationRequest(OA2Error):
75
    pass
76

    
77

    
78
class Response(object):
79

    
80
    def __init__(self, status, body='', headers=None,
81
                 content_type='plain/text'):
82
        if not body:
83
            body = ''
84
        if not headers:
85
            headers = {}
86

    
87
        self.status = status
88
        self.body = body
89
        self.headers = headers
90
        self.content_type = content_type
91

    
92
    def __repr__(self):
93
        return "%d RESPONSE (BODY: %r, HEADERS: %r)" % (self.status,
94
                                                        self.body,
95
                                                        self.headers)
96

    
97

    
98
class Request(object):
99

    
100
    def __init__(self, method, path, GET=None, POST=None, META=None,
101
                 secure=False, user=None):
102
        self.method = method
103
        self.path = path
104

    
105
        if not GET:
106
            GET = {}
107
        if not POST:
108
            POST = {}
109
        if not META:
110
            META = {}
111

    
112
        self.secure = secure
113
        self.GET = GET
114
        self.POST = POST
115
        self.META = META
116
        self.user = user
117

    
118
    def __repr__(self):
119
        prepend = ""
120
        if self.secure:
121
            prepend = "SECURE "
122
        return "%s%s REQUEST (POST: %r, GET:%r, HEADERS:%r, " % (prepend,
123
                                                                 self.method,
124
                                                                 self.POST,
125
                                                                 self.GET,
126
                                                                 self.META)
127

    
128

    
129
class ORMAbstractBase(type):
130

    
131
    def __new__(cls, name, bases, attrs):
132
        attrs['ENTRIES'] = {}
133
        return super(ORMAbstractBase, cls).__new__(cls, name, bases, attrs)
134

    
135

    
136
class ORMAbstract(object):
137

    
138
    ENTRIES = {}
139

    
140
    __metaclass__ = ORMAbstractBase
141

    
142
    def __init__(self, **kwargs):
143
        for key, value in kwargs.iteritems():
144
            setattr(self, key, value)
145

    
146
    @classmethod
147
    def create(cls, id, **params):
148
        params = cls.clean_params(params)
149
        params['id'] = id
150
        cls.ENTRIES[id] = cls(**params)
151
        return cls.get(id)
152

    
153
    @classmethod
154
    def get(cls, pk):
155
        return cls.ENTRIES.get(pk)
156

    
157
    @classmethod
158
    def clean_params(cls, params):
159
        return params
160

    
161

    
162
class Client(ORMAbstract):
163

    
164
    def get_id(self):
165
        return self.id
166

    
167
    def get_redirect_uris(self):
168
        return self.uris
169

    
170
    def get_default_redirect_uri(self):
171
        return self.uris[0]
172

    
173
    def redirect_uri_is_valid(self, redirect_uri):
174
        split = urlparse.urlsplit(redirect_uri)
175
        if split.scheme not in urlparse.uses_query:
176
            raise OA2Error("Invalid redirect url scheme")
177
        uris = self.get_redirect_uris()
178
        return redirect_uri in uris
179

    
180
    def requires_auth(self):
181
        if self.client_type == 'confidential':
182
            return True
183
        return 'secret' in dir(self)
184

    
185
    def check_credentials(self, username, secret):
186
        return username == self.id and secret == self.secret
187

    
188

    
189
class Token(ORMAbstract):
190

    
191
    def to_dict(self):
192
        params = {
193
            'access_token': self.token,
194
            'token_type': self.token_type,
195
            'expires_in': self.expires,
196
        }
197
        if self.refresh_token:
198
            params['refresh_token'] = self.refresh_token
199
        return params
200

    
201

    
202
class AuthorizationCode(ORMAbstract):
203
    pass
204

    
205

    
206
class User(ORMAbstract):
207
    pass
208

    
209

    
210
class BackendBase(type):
211

    
212
    def __new__(cls, name, bases, attrs):
213
        super_new = super(BackendBase, cls).__new__
214
        #parents = [b for b in bases if isinstance(b, BackendBase)]
215
        #meta = attrs.pop('Meta', None)
216
        return super_new(cls, name, bases, attrs)
217

    
218
    @classmethod
219
    def get_orm_options(cls, attrs):
220
        meta = attrs.pop('ORM', None)
221
        orm = {}
222
        if meta:
223
            for attr in dir(meta):
224
                orm[attr] = getattr(meta, attr)
225
        return orm
226

    
227

    
228
class SimpleBackend(object):
229

    
230
    __metaclass__ = BackendBase
231

    
232
    base_url = ''
233
    endpoints_prefix = 'oauth2/'
234

    
235
    token_endpoint = 'token/'
236
    token_length = 30
237
    token_expires = 20
238

    
239
    authorization_endpoint = 'auth/'
240
    authorization_code_length = 60
241
    authorization_response_types = ['code', 'token']
242

    
243
    grant_types = ['authorization_code']
244

    
245
    response_cls = Response
246
    request_cls = Request
247

    
248
    client_model = Client
249
    token_model = Token
250
    code_model = AuthorizationCode
251
    user_model = User
252

    
253
    def __init__(self, base_url='', endpoints_prefix='oauth2/', id='oauth2',
254
                 token_endpoint='token/', token_length=30,
255
                 token_expires=20, authorization_endpoint='auth/',
256
                 authorization_code_length=60,
257
                 redirect_uri_limit=5000, **kwargs):
258
        self.base_url = base_url
259
        self.endpoints_prefix = endpoints_prefix
260
        self.token_endpoint = token_endpoint
261
        self.token_length = token_length
262
        self.token_expires = token_expires
263
        self.authorization_endpoint = authorization_endpoint
264
        self.authorization_code_length = authorization_code_length
265
        self.id = id
266
        self._errors_to_http = kwargs.get('errors_to_http', True)
267
        self.redirect_uri_limit = redirect_uri_limit
268

    
269
    # Request/response builders
270
    def build_request(self, method, get, post, meta):
271
        return self.request_cls(method=method, GET=get, POST=post, META=meta)
272

    
273
    def build_response(self, status, headers=None, body=''):
274
        return self.response_cls(status=status, headers=headers, body=body)
275

    
276
    # ORM Methods
277
    def create_authorization_code(self, user, client, code, redirect_uri,
278
                                  scope, state, **kwargs):
279
        code_params = {
280
            'code': code,
281
            'redirect_uri': redirect_uri,
282
            'client': client,
283
            'scope': scope,
284
            'state': state,
285
            'user': user
286
        }
287
        code_params.update(kwargs)
288
        code_instance = self.code_model.create(**code_params)
289
        logger.info(u'%r created' % code_instance)
290
        return code_instance
291

    
292
    def _token_params(self, value, token_type, authorization, scope):
293
        created_at = datetime.datetime.now()
294
        expires = self.token_expires
295
        expires_at = created_at + datetime.timedelta(seconds=expires)
296
        token_params = {
297
            'code': value,
298
            'token_type': token_type,
299
            'created_at': created_at,
300
            'expires_at': expires_at,
301
            'user': authorization.user,
302
            'redirect_uri': authorization.redirect_uri,
303
            'client': authorization.client,
304
            'scope': authorization.scope,
305
        }
306
        return token_params
307

    
308
    def create_token(self, value, token_type, authorization, scope,
309
                     refresh=False):
310
        params = self._token_params(value, token_type, authorization, scope)
311
        if refresh:
312
            refresh_token = self.generate_token()
313
            params['refresh_token'] = refresh_token
314
            # TODO: refresh token expires ???
315
        token = self.token_model.create(**params)
316
        logger.info(u'%r created' % token)
317
        return token
318

    
319
#    def delete_authorization_code(self, code):
320
#        del self.code_model.ENTRIES[code]
321

    
322
    def get_client_by_id(self, client_id):
323
        return self.client_model.get(client_id)
324

    
325
    def get_client_by_credentials(self, username, password):
326
        return None
327

    
328
    def get_authorization_code(self, code):
329
        return self.code_model.get(code)
330

    
331
    def get_client_authorization_code(self, client, code):
332
        code_instance = self.get_authorization_code(code)
333
        if not code_instance:
334
            raise OA2Error("Invalid code")
335

    
336
        if client.get_id() != code_instance.client.get_id():
337
            raise OA2Error("Mismatching client with code client")
338
        return code_instance
339

    
340
    def client_id_exists(self, client_id):
341
        return bool(self.get_client_by_id(client_id))
342

    
343
    def build_site_url(self, prefix='', **params):
344
        params = urllib.urlencode(params)
345
        return "%s%s%s%s" % (self.base_url, self.endpoints_prefix, prefix,
346
                             params)
347

    
348
    def _get_uri_base(self, uri):
349
        split = urlparse.urlsplit(uri)
350
        return "%s://%s%s" % (split.scheme, split.netloc, split.path)
351

    
352
    def build_client_redirect_uri(self, client, uri, **params):
353
        if not client.redirect_uri_is_valid(uri):
354
            raise OA2Error("Invalid redirect uri")
355
        params = urllib.urlencode(params)
356
        uri = self._get_uri_base(uri)
357
        return "%s?%s" % (uri, params)
358

    
359
    def generate_authorization_code(self):
360
        dg64 = b64encode(sha512(str(uuid.uuid4())).hexdigest())
361
        return dg64[:self.authorization_code_length]
362

    
363
    def generate_token(self, *args, **kwargs):
364
        dg64 = b64encode(sha512(str(uuid.uuid4())).hexdigest())
365
        return dg64[:self.token_length]
366

    
367
    def add_authorization_code(self, user, client, redirect_uri, scope, state,
368
                               **kwargs):
369
        code = self.generate_authorization_code()
370
        self.create_authorization_code(user, client, code, redirect_uri, scope,
371
                                       state, **kwargs)
372
        return code
373

    
374
    def add_token_for_client(self, token_type, authorization, refresh=False):
375
        token = self.generate_token()
376
        self.create_token(token, token_type, authorization, refresh)
377
        return token
378

    
379
    #
380
    # Response helpers
381
    #
382

    
383
    def grant_accept_response(self, client, redirect_uri, scope, state):
384
        context = {'client': client.get_id(), 'redirect_uri': redirect_uri,
385
                   'scope': scope, 'state': state,
386
                   #'url': url,
387
                   }
388
        json_content = json.dumps(context)
389
        return self.response_cls(status=200, body=json_content)
390

    
391
    def grant_token_response(self, token, token_type):
392
        context = {'access_token': token, 'token_type': token_type,
393
                   'expires_in': self.token_expires}
394
        json_content = json.dumps(context)
395
        return self.response_cls(status=200, body=json_content)
396

    
397
    def redirect_to_login_response(self, request, params):
398
        parts = list(urlparse.urlsplit(request.path))
399
        parts[3] = urllib.urlencode(params)
400
        query = {'next': urlparse.urlunsplit(parts)}
401
        return Response(302,
402
                        headers={'Location': '%s?%s' %
403
                                 (self.get_login_uri(),
404
                                  urllib.urlencode(query))})
405

    
406
    def redirect_to_uri(self, redirect_uri, code, state=None):
407
        parts = list(urlparse.urlsplit(redirect_uri))
408
        params = dict(urlparse.parse_qsl(parts[3], keep_blank_values=True))
409
        params['code'] = code
410
        if state is not None:
411
            params['state'] = state
412
        parts[3] = urllib.urlencode(params)
413
        return Response(302,
414
                        headers={'Location': '%s' %
415
                                 urlparse.urlunsplit(parts)})
416

    
417
    def build_response_from_error(self, exception):
418
        response = Response(400)
419
        logger.exception(exception)
420
        error = 'generic_error'
421
        if exception.error:
422
            error = exception.error
423
        body = {
424
            'error': error,
425
            'exception': exception.message,
426
        }
427
        response.body = json.dumps(body)
428
        response.content_type = "application/json"
429
        return response
430

    
431
    #
432
    # Processor methods
433
    #
434

    
435
    def process_code_request(self, user, client, uri, scope, state):
436
        code = self.add_authorization_code(user, client, uri, scope, state)
437
        return self.redirect_to_uri(uri, code, state)
438

    
439
    #
440
    # Helpers
441
    #
442

    
443
    def grant_authorization_code(self, client, code_instance, redirect_uri,
444
                                 scope=None, token_type="Bearer"):
445
        if scope and code_instance.scope != scope:
446
            raise OA2Error("Invalid scope")
447
        if redirect_uri != code_instance.redirect_uri:
448
            raise OA2Error("The redirect uri does not match "
449
                           "the one used during authorization")
450
        token = self.add_token_for_client(token_type, code_instance)
451
        self.delete_authorization_code(code_instance)  # use only once
452
        return token, token_type
453

    
454
    def consume_token(self, token):
455
        token_instance = self.get_token(token)
456
        if datetime.datetime.now() > token_instance.expires_at:
457
            self.delete_token(token_instance)  # delete expired token
458
            raise OA2Error("Token has expired")
459
        # TODO: delete token?
460
        return token_instance
461

    
462
    def _get_credentials(self, params, headers):
463
        if 'HTTP_AUTHORIZATION' in headers:
464
            scheme, b64credentials = headers.get(
465
                'HTTP_AUTHORIZATION').split(" ")
466
            if scheme != 'Basic':
467
                # TODO: raise 401 + WWW-Authenticate
468
                raise OA2Error("Unsupported authorization scheme")
469
            credentials = b64decode(b64credentials).split(":")
470
            return scheme, credentials
471
        else:
472
            return None, None
473
        pass
474

    
475
    def _get_authorization(self, params, headers):
476
        scheme, client_credentials = self._get_credentials(params, headers)
477
        no_authorization = scheme is None and client_credentials is None
478
        if no_authorization:
479
            raise OA2Error("Missing authorization header")
480
        return client_credentials
481

    
482
    def get_redirect_uri_from_params(self, client, params, default=True):
483
        """
484
        Accepts a client instance and request parameters.
485
        """
486
        redirect_uri = params.get('redirect_uri', None)
487
        if not redirect_uri and default:
488
            redirect_uri = client.get_default_redirect_uri()
489
        else:
490
            # TODO: sanitize redirect_uri (self.clean_redirect_uri ???)
491
            # clean and validate
492
            if not client.redirect_uri_is_valid(redirect_uri):
493
                raise OA2Error("Invalid client redirect uri")
494
        return redirect_uri
495

    
496
    #
497
    # Request identifiers
498
    #
499

    
500
    def identify_authorize_request(self, params, headers):
501
        return params.get('response_type'), params
502

    
503
    def identify_token_request(self, headers, params):
504
        content_type = headers.get('CONTENT_TYPE')
505
        if content_type != 'application/x-www-form-urlencoded':
506
            raise OA2Error("Invalid Content-Type header")
507
        return params.get('grant_type')
508

    
509
    #
510
    # Parameters validation methods
511
    #
512

    
513
    def validate_client(self, params, meta, requires_auth=True,
514
                        client_id_required=True):
515
        client_id = params.get('client_id')
516
        if client_id is None and client_id_required:
517
            raise OA2Error("Client identification is required")
518

    
519
        client_credentials = None
520
        try:  # check authorization header
521
            client_credentials = self._get_authorization(params, meta)
522
            if client_credentials is not None:
523
                _client_id = client_credentials[0]
524
                if client_id is not None and client_id != _client_id:
525
                    raise OA2Error("Client identification conflicts "
526
                                   "with client authorization")
527
                client_id = _client_id
528
        except:
529
            pass
530

    
531
        if client_id is None:
532
            raise OA2Error("Missing client identification")
533

    
534
        client = self.get_client_by_id(client_id)
535

    
536
        if requires_auth and client.requires_auth():
537
            if client_credentials is None:
538
                raise OA2Error("Client authentication is required")
539

    
540
        if client_credentials is not None:
541
            self.check_credentials(client, *client_credentials)
542
        return client
543

    
544
    def validate_redirect_uri(self, client, params, headers,
545
                              allow_default=True, is_required=False,
546
                              expected_value=None):
547
        redirect_uri = params.get('redirect_uri')
548
        if is_required and redirect_uri is None:
549
            raise OA2Error("Missing redirect uri")
550
        if redirect_uri is not None:
551
            if not bool(urlparse.urlparse(redirect_uri).scheme):
552
                raise OA2Error("Redirect uri should be an absolute URI")
553
            if len(redirect_uri) > self.redirect_uri_limit:
554
                raise OA2Error("Redirect uri length limit exceeded")
555
            if not client.redirect_uri_is_valid(redirect_uri):
556
                raise OA2Error("Mismatching redirect uri")
557
            if expected_value is not None and redirect_uri != expected_value:
558
                raise OA2Error("Invalid redirect uri")
559
        else:
560
            try:
561
                redirect_uri = client.redirecturl_set.values_list('url',
562
                                                                  flat=True)[0]
563
            except IndexError:
564
                raise OA2Error("Unable to fallback to client redirect URI")
565
        return redirect_uri
566

    
567
    def validate_state(self, client, params, headers):
568
        return params.get('state')
569
        raise OA2Error("Invalid state")
570

    
571
    def validate_scope(self, client, params, headers):
572
        scope = params.get('scope')
573
        if scope is not None:
574
            scope = scope.split(' ')[0]  # keep only the first
575
        # TODO: check for invalid characters
576
        return scope
577

    
578
    def validate_code(self, client, params, headers):
579
        code = params.get('code')
580
        if code is None:
581
            raise OA2Error("Missing authorization code")
582
        return self.get_client_authorization_code(client, code)
583

    
584
    #
585
    # Requests validation methods
586
    #
587

    
588
    def validate_code_request(self, params, headers):
589
        client = self.validate_client(params, headers, requires_auth=False)
590
        redirect_uri = self.validate_redirect_uri(client, params, headers)
591
        scope = self.validate_scope(client, params, headers)
592
        scope = scope or redirect_uri  # set default
593
        state = self.validate_state(client, params, headers)
594
        return client, redirect_uri, scope, state
595

    
596
    def validate_token_request(self, params, headers, requires_auth=False):
597
        client = self.validate_client(params, headers)
598
        redirect_uri = self.validate_redirect_uri(client, params, headers)
599
        scope = self.validate_scope(client, params, headers)
600
        scope = scope or redirect_uri  # set default
601
        state = self.validate_state(client, params, headers)
602
        return client, redirect_uri, scope, state
603

    
604
    def validate_code_grant(self, params, headers):
605
        client = self.validate_client(params, headers,
606
                                      client_id_required=False)
607
        code_instance = self.validate_code(client, params, headers)
608
        redirect_uri = self.validate_redirect_uri(
609
            client, params, headers,
610
            expected_value=code_instance.redirect_uri)
611
        return client, redirect_uri, code_instance
612

    
613
    #
614
    # Endpoint methods
615
    #
616

    
617
    @handles_oa2_requests
618
    def authorize(self, request, **extra):
619
        """
620
        Used in the following cases
621
        """
622
        if not request.secure:
623
            raise OA2Error("Secure request required")
624

    
625
        # identify
626
        request_params = request.GET
627
        if request.method == "POST":
628
            request_params = request.POST
629

    
630
        auth_type, params = self.identify_authorize_request(request_params,
631
                                                            request.META)
632

    
633
        if auth_type is None:
634
            raise OA2Error("Missing authorization type")
635
        if auth_type == 'code':
636
            client, uri, scope, state = \
637
                self.validate_code_request(params, request.META)
638
        elif auth_type == 'token':
639
            raise OA2Error("Unsupported authorization type")
640
#            client, uri, scope, state = \
641
#                self.validate_token_request(params, request.META)
642
        else:
643
            #TODO: handle custom type
644
            raise OA2Error("Invalid authorization type")
645

    
646
        user = getattr(request, 'user', None)
647
        if not user:
648
            return self.redirect_to_login_response(request, params)
649

    
650
        if request.method == 'POST':
651
            if auth_type == 'code':
652
                return self.process_code_request(user, client, uri, scope,
653
                                                 state)
654
            elif auth_type == 'token':
655
                raise OA2Error("Unsupported response type")
656
#                return self.process_token_request(user, client, uri, scope,
657
#                                                 state)
658
            else:
659
                #TODO: handle custom type
660
                raise OA2Error("Invalid authorization type")
661
        else:
662
            if client.is_trusted:
663
                return self.process_code_request(user, client, uri, scope,
664
                                                 state)
665
            else:
666
                return self.grant_accept_response(client, uri, scope, state)
667

    
668
    @handles_oa2_requests
669
    def grant_token(self, request, **extra):
670
        """
671
        Used in the following cases
672
        """
673
        if not request.secure:
674
            raise OA2Error("Secure request required")
675

    
676
        grant_type = self.identify_token_request(request.META, request.POST)
677

    
678
        if grant_type is None:
679
            raise OA2Error("Missing grant type")
680
        elif grant_type == 'authorization_code':
681
            client, redirect_uri, code = \
682
                self.validate_code_grant(request.POST, request.META)
683
            token, token_type = \
684
                self.grant_authorization_code(client, code, redirect_uri)
685
            return self.grant_token_response(token, token_type)
686
        elif (grant_type in ['client_credentials', 'token'] or
687
              self.is_uri(grant_type)):
688
            raise OA2Error("Unsupported grant type")
689
        else:
690
            #TODO: handle custom type
691
            raise OA2Error("Invalid grant type")