Statistics
| Branch: | Tag: | Revision:

root / snf-cyclades-app / synnefo / api / util.py @ bd40abfa

History | View | Annotate | Download (19.4 kB)

1
# Copyright 2011-2012 GRNET S.A. All rights reserved.
2
#
3
# Redistribution and use in source and binary forms, with or
4
# without modification, are permitted provided that the following
5
# conditions are met:
6
#
7
#   1. Redistributions of source code must retain the above
8
#      copyright notice, this list of conditions and the following
9
#      disclaimer.
10
#
11
#   2. Redistributions in binary form must reproduce the above
12
#      copyright notice, this list of conditions and the following
13
#      disclaimer in the documentation and/or other materials
14
#      provided with the distribution.
15
#
16
# THIS SOFTWARE IS PROVIDED BY GRNET S.A. ``AS IS'' AND ANY EXPRESS
17
# OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL GRNET S.A OR
20
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
23
# USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
24
# AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
26
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27
# POSSIBILITY OF SUCH DAMAGE.
28
#
29
# The views and conclusions contained in the software and
30
# documentation are those of the authors and should not be
31
# interpreted as representing official policies, either expressed
32
# or implied, of GRNET S.A.
33

    
34
import datetime
35
import ipaddr
36

    
37
from base64 import b64encode, b64decode
38
from datetime import timedelta, tzinfo
39
from functools import wraps
40
from hashlib import sha256
41
from logging import getLogger
42
from random import choice
43
from string import digits, lowercase, uppercase
44
from time import time
45
from traceback import format_exc
46
from wsgiref.handlers import format_date_time
47

    
48
import dateutil.parser
49

    
50
from Crypto.Cipher import AES
51

    
52
from django.conf import settings
53
from django.http import HttpResponse
54
from django.template.loader import render_to_string
55
from django.utils import simplejson as json
56
from django.utils.cache import add_never_cache_headers
57
from django.db.models import Q
58

    
59
from snf_django.lib.api import faults
60
from synnefo.db.models import (Flavor, VirtualMachine, VirtualMachineMetadata,
61
                               Network, BackendNetwork, NetworkInterface,
62
                               BridgePoolTable, MacPrefixPoolTable)
63
from synnefo.db.pools import EmptyPool
64

    
65
from synnefo.lib.astakos import get_user
66
from synnefo.plankton.backend import ImageBackend, NotAllowedError
67
from synnefo.settings import MAX_CIDR_BLOCK
68

    
69

    
70
log = getLogger('synnefo.api')
71

    
72

    
73
class UTC(tzinfo):
74
    def utcoffset(self, dt):
75
        return timedelta(0)
76

    
77
    def tzname(self, dt):
78
        return 'UTC'
79

    
80
    def dst(self, dt):
81
        return timedelta(0)
82

    
83

    
84
def isoformat(d):
85
    """Return an ISO8601 date string that includes a timezone."""
86

    
87
    return d.replace(tzinfo=UTC()).isoformat()
88

    
89

    
90
def isoparse(s):
91
    """Parse an ISO8601 date string into a datetime object."""
92

    
93
    if not s:
94
        return None
95

    
96
    try:
97
        since = dateutil.parser.parse(s)
98
        utc_since = since.astimezone(UTC()).replace(tzinfo=None)
99
    except ValueError:
100
        raise faults.BadRequest('Invalid changes-since parameter.')
101

    
102
    now = datetime.datetime.now()
103
    if utc_since > now:
104
        raise faults.BadRequest('changes-since value set in the future.')
105

    
106
    if now - utc_since > timedelta(seconds=settings.POLL_LIMIT):
107
        raise faults.BadRequest('Too old changes-since value.')
108

    
109
    return utc_since
110

    
111

    
112
def random_password():
113
    """Generates a random password
114

115
    We generate a windows compliant password: it must contain at least
116
    one charachter from each of the groups: upper case, lower case, digits.
117
    """
118

    
119
    pool = lowercase + uppercase + digits
120
    lowerset = set(lowercase)
121
    upperset = set(uppercase)
122
    digitset = set(digits)
123
    length = 10
124

    
125
    password = ''.join(choice(pool) for i in range(length - 2))
126

    
127
    # Make sure the password is compliant
128
    chars = set(password)
129
    if not chars & lowerset:
130
        password += choice(lowercase)
131
    if not chars & upperset:
132
        password += choice(uppercase)
133
    if not chars & digitset:
134
        password += choice(digits)
135

    
136
    # Pad if necessary to reach required length
137
    password += ''.join(choice(pool) for i in range(length - len(password)))
138

    
139
    return password
140

    
141

    
142
def zeropad(s):
143
    """Add zeros at the end of a string in order to make its length
144
       a multiple of 16."""
145

    
146
    npad = 16 - len(s) % 16
147
    return s + '\x00' * npad
148

    
149

    
150
def encrypt(plaintext):
151
    # Make sure key is 32 bytes long
152
    key = sha256(settings.SECRET_KEY).digest()
153

    
154
    aes = AES.new(key)
155
    enc = aes.encrypt(zeropad(plaintext))
156
    return b64encode(enc)
157

    
158

    
159
def get_vm(server_id, user_id, for_update=False, non_deleted=False,
160
           non_suspended=False):
161
    """Find a VirtualMachine instance based on ID and owner."""
162

    
163
    try:
164
        server_id = int(server_id)
165
        servers = VirtualMachine.objects
166
        if for_update:
167
            servers = servers.select_for_update()
168
        vm = servers.get(id=server_id, userid=user_id)
169
        if non_deleted and vm.deleted:
170
            raise VirtualMachine.DeletedError
171
        if non_suspended and vm.suspended:
172
            raise faults.Forbidden("Administratively Suspended VM")
173
        return vm
174
    except ValueError:
175
        raise faults.BadRequest('Invalid server ID.')
176
    except VirtualMachine.DoesNotExist:
177
        raise faults.ItemNotFound('Server not found.')
178

    
179

    
180
def get_vm_meta(vm, key):
181
    """Return a VirtualMachineMetadata instance or raise ItemNotFound."""
182

    
183
    try:
184
        return VirtualMachineMetadata.objects.get(meta_key=key, vm=vm)
185
    except VirtualMachineMetadata.DoesNotExist:
186
        raise faults.ItemNotFound('Metadata key not found.')
187

    
188

    
189
def get_image(image_id, user_id):
190
    """Return an Image instance or raise ItemNotFound."""
191

    
192
    backend = ImageBackend(user_id)
193
    try:
194
        image = backend.get_image(image_id)
195
        if not image:
196
            raise faults.ItemNotFound('Image not found.')
197
        return image
198
    finally:
199
        backend.close()
200

    
201

    
202
def get_image_dict(image_id, user_id):
203
    image = {}
204
    img = get_image(image_id, user_id)
205
    properties = img.get('properties', {})
206
    image['backend_id'] = img['location']
207
    image['format'] = img['disk_format']
208
    image['metadata'] = dict((key.upper(), val)
209
                             for key, val in properties.items())
210
    image['checksum'] = img['checksum']
211

    
212
    return image
213

    
214

    
215
def get_flavor(flavor_id, include_deleted=False):
216
    """Return a Flavor instance or raise ItemNotFound."""
217

    
218
    try:
219
        flavor_id = int(flavor_id)
220
        if include_deleted:
221
            return Flavor.objects.get(id=flavor_id)
222
        else:
223
            return Flavor.objects.get(id=flavor_id, deleted=include_deleted)
224
    except (ValueError, Flavor.DoesNotExist):
225
        raise faults.ItemNotFound('Flavor not found.')
226

    
227

    
228
def get_network(network_id, user_id, for_update=False):
229
    """Return a Network instance or raise ItemNotFound."""
230

    
231
    try:
232
        network_id = int(network_id)
233
        objects = Network.objects
234
        if for_update:
235
            objects = objects.select_for_update()
236
        return objects.get(Q(userid=user_id) | Q(public=True), id=network_id)
237
    except (ValueError, Network.DoesNotExist):
238
        raise faults.ItemNotFound('Network not found.')
239

    
240

    
241
def validate_network_params(subnet, gateway=None, subnet6=None, gateway6=None):
242
    try:
243
        # Use strict option to not all subnets with host bits set
244
        network = ipaddr.IPv4Network(subnet, strict=True)
245
    except ValueError:
246
        raise faults.BadRequest("Invalid network IPv4 subnet")
247

    
248
    # Check that network size is allowed!
249
    if not validate_network_size(network.prefixlen):
250
        raise faults.OverLimit(message="Unsupported network size",
251
                        details="Network mask must be in range (%s, 29]" %
252
                                MAX_CIDR_BLOCK)
253

    
254
    # Check that gateway belongs to network
255
    if gateway:
256
        try:
257
            gateway = ipaddr.IPv4Address(gateway)
258
        except ValueError:
259
            raise faults.BadRequest("Invalid network IPv4 gateway")
260
        if not gateway in network:
261
            raise faults.BadRequest("Invalid network IPv4 gateway")
262

    
263
    if subnet6:
264
        try:
265
            # Use strict option to not all subnets with host bits set
266
            network6 = ipaddr.IPv6Network(subnet6, strict=True)
267
        except ValueError:
268
            raise faults.BadRequest("Invalid network IPv6 subnet")
269
        if gateway6:
270
            try:
271
                gateway6 = ipaddr.IPv6Address(gateway6)
272
            except ValueError:
273
                raise faults.BadRequest("Invalid network IPv6 gateway")
274
            if not gateway6 in network6:
275
                raise faults.BadRequest("Invalid network IPv6 gateway")
276

    
277

    
278
def validate_network_size(cidr_block):
279
    """Return True if network size is allowed."""
280
    return cidr_block <= 29 and cidr_block > MAX_CIDR_BLOCK
281

    
282

    
283
def allocate_public_address(backend):
284
    """Allocate a public IP for a vm."""
285
    for network in backend_public_networks(backend):
286
        try:
287
            address = get_network_free_address(network)
288
            return (network, address)
289
        except EmptyPool:
290
            pass
291
    return (None, None)
292

    
293

    
294
def get_public_ip(backend):
295
    """Reserve an IP from a public network.
296

297
    This method should run inside a transaction.
298

299
    """
300
    address = None
301
    if settings.PUBLIC_USE_POOL:
302
        (network, address) = allocate_public_address(backend)
303
    else:
304
        for net in list(backend_public_networks(backend)):
305
            pool = net.get_pool()
306
            if not pool.empty():
307
                address = 'pool'
308
                network = net
309
                break
310
    if address is None:
311
        log.error("Public networks of backend %s are full", backend)
312
        raise faults.OverLimit("Can not allocate IP for new machine."
313
                        " Public networks are full.")
314
    return (network, address)
315

    
316

    
317
def backend_public_networks(backend):
318
    """Return available public networks of the backend.
319

320
    Iterator for non-deleted public networks that are available
321
    to the specified backend.
322

323
    """
324
    for network in Network.objects.filter(public=True, deleted=False):
325
        if BackendNetwork.objects.filter(network=network,
326
                                         backend=backend).exists():
327
            yield network
328

    
329

    
330
def get_network_free_address(network):
331
    """Reserve an IP address from the IP Pool of the network.
332

333
    Raises EmptyPool
334

335
    """
336

    
337
    pool = network.get_pool()
338
    address = pool.get()
339
    pool.save()
340
    return address
341

    
342

    
343
def get_nic(machine, network):
344
    try:
345
        return NetworkInterface.objects.get(machine=machine, network=network)
346
    except NetworkInterface.DoesNotExist:
347
        raise faults.ItemNotFound('Server not connected to this network.')
348

    
349

    
350
def get_nic_from_index(vm, nic_index):
351
    """Returns the nic_index-th nic of a vm
352
       Error Response Codes: itemNotFound (404), badMediaType (415)
353
    """
354
    matching_nics = vm.nics.filter(index=nic_index)
355
    matching_nics_len = len(matching_nics)
356
    if matching_nics_len < 1:
357
        raise faults.ItemNotFound('NIC not found on VM')
358
    elif matching_nics_len > 1:
359
        raise faults.BadMediaType('NIC index conflict on VM')
360
    nic = matching_nics[0]
361
    return nic
362

    
363

    
364
def get_request_dict(request):
365
    """Returns data sent by the client as a python dict."""
366

    
367
    data = request.raw_post_data
368
    if request.META.get('CONTENT_TYPE').startswith('application/json'):
369
        try:
370
            return json.loads(data)
371
        except ValueError:
372
            raise faults.BadRequest('Invalid JSON data.')
373
    else:
374
        raise faults.BadRequest('Unsupported Content-Type.')
375

    
376

    
377
def update_response_headers(request, response):
378
    if request.serialization == 'xml':
379
        response['Content-Type'] = 'application/xml'
380
    elif request.serialization == 'atom':
381
        response['Content-Type'] = 'application/atom+xml'
382
    else:
383
        response['Content-Type'] = 'application/json'
384

    
385
    if settings.TEST:
386
        response['Date'] = format_date_time(time())
387

    
388
    add_never_cache_headers(response)
389

    
390

    
391
def render_metadata(request, metadata, use_values=False, status=200):
392
    if request.serialization == 'xml':
393
        data = render_to_string('metadata.xml', {'metadata': metadata})
394
    else:
395
        if use_values:
396
            d = {'metadata': {'values': metadata}}
397
        else:
398
            d = {'metadata': metadata}
399
        data = json.dumps(d)
400
    return HttpResponse(data, status=status)
401

    
402

    
403
def render_meta(request, meta, status=200):
404
    if request.serialization == 'xml':
405
        data = render_to_string('meta.xml', dict(key=key, val=val))
406
    else:
407
        data = json.dumps(dict(meta=meta))
408
    return HttpResponse(data, status=status)
409

    
410

    
411
def render_fault(request, fault):
412
    if settings.DEBUG or settings.TEST:
413
        fault.details = format_exc(fault)
414

    
415
    if request.serialization == 'xml':
416
        data = render_to_string('fault.xml', {'fault': fault})
417
    else:
418
        d = {fault.name: {'code': fault.code,
419
                          'message': fault.message,
420
                          'details': fault.details}}
421
        data = json.dumps(d)
422

    
423
    resp = HttpResponse(data, status=fault.code)
424
    update_response_headers(request, resp)
425
    return resp
426

    
427

    
428
def request_serialization(request, atom_allowed=False):
429
    """Return the serialization format requested.
430

431
    Valid formats are 'json', 'xml' and 'atom' if `atom_allowed` is True.
432
    """
433

    
434
    path = request.path
435

    
436
    if path.endswith('.json'):
437
        return 'json'
438
    elif path.endswith('.xml'):
439
        return 'xml'
440
    elif atom_allowed and path.endswith('.atom'):
441
        return 'atom'
442

    
443
    for item in request.META.get('HTTP_ACCEPT', '').split(','):
444
        accept, sep, rest = item.strip().partition(';')
445
        if accept == 'application/json':
446
            return 'json'
447
        elif accept == 'application/xml':
448
            return 'xml'
449
        elif atom_allowed and accept == 'application/atom+xml':
450
            return 'atom'
451

    
452
    return 'json'
453

    
454

    
455
def api_method(http_method=None, atom_allowed=False):
456
    """Decorator function for views that implement an API method."""
457

    
458
    def decorator(func):
459
        @wraps(func)
460
        def wrapper(request, *args, **kwargs):
461
            try:
462
                request.serialization = request_serialization(request,
463
                                                              atom_allowed)
464
                get_user(request, settings.ASTAKOS_URL)
465
                if not request.user_uniq:
466
                    raise faults.Unauthorized('No user found.')
467
                if http_method and request.method != http_method:
468
                    raise faults.BadRequest('Method not allowed.')
469

    
470
                resp = func(request, *args, **kwargs)
471
                update_response_headers(request, resp)
472
                return resp
473
            except VirtualMachine.DeletedError:
474
                fault = faults.BadRequest('Server has been deleted.')
475
                return render_fault(request, fault)
476
            except Network.DeletedError:
477
                fault = faults.BadRequest('Network has been deleted.')
478
                return render_fault(request, fault)
479
            except VirtualMachine.BuildingError:
480
                fault = faults.BuildInProgress('Server is being built.')
481
                return render_fault(request, fault)
482
            except NotAllowedError:
483
                # Image Backend Unathorized
484
                fault = faults.Forbidden('Request not allowed.')
485
                return render_fault(request, fault)
486
            except faults.Fault, fault:
487
                if fault.code >= 500:
488
                    log.exception('API fault')
489
                return render_fault(request, fault)
490
            except BaseException:
491
                log.exception('Unexpected error')
492
                fault = faults.ServiceUnavailable('Unexpected error.')
493
                return render_fault(request, fault)
494
        return wrapper
495
    return decorator
496

    
497

    
498
def construct_nic_id(nic):
499
    return "-".join(["nic", unicode(nic.machine.id), unicode(nic.index)])
500

    
501

    
502
def verify_personality(personality):
503
    """Verify that a a list of personalities is well formed"""
504
    if len(personality) > settings.MAX_PERSONALITY:
505
        raise faults.OverLimit("Maximum number of personalities"
506
                        " exceeded")
507
    for p in personality:
508
        # Verify that personalities are well-formed
509
        try:
510
            assert isinstance(p, dict)
511
            keys = set(p.keys())
512
            allowed = set(['contents', 'group', 'mode', 'owner', 'path'])
513
            assert keys.issubset(allowed)
514
            contents = p['contents']
515
            if len(contents) > settings.MAX_PERSONALITY_SIZE:
516
                # No need to decode if contents already exceed limit
517
                raise faults.OverLimit("Maximum size of personality exceeded")
518
            if len(b64decode(contents)) > settings.MAX_PERSONALITY_SIZE:
519
                raise faults.OverLimit("Maximum size of personality exceeded")
520
        except AssertionError:
521
            raise faults.BadRequest("Malformed personality in request")
522

    
523

    
524
def get_flavor_provider(flavor):
525
    """Extract provider from disk template.
526

527
    Provider for `ext` disk_template is encoded in the disk template
528
    name, which is formed `ext_<provider_name>`. Provider is None
529
    for all other disk templates.
530

531
    """
532
    disk_template = flavor.disk_template
533
    provider = None
534
    if disk_template.startswith("ext"):
535
        disk_template, provider = disk_template.split("_", 1)
536
    return disk_template, provider
537

    
538

    
539
def values_from_flavor(flavor):
540
    """Get Ganeti connectivity info from flavor type.
541

542
    If link or mac_prefix equals to "pool", then the resources
543
    are allocated from the corresponding Pools.
544

545
    """
546
    try:
547
        flavor = Network.FLAVORS[flavor]
548
    except KeyError:
549
        raise faults.BadRequest("Unknown network flavor")
550

    
551
    mode = flavor.get("mode")
552

    
553
    link = flavor.get("link")
554
    if link == "pool":
555
        link = allocate_resource("bridge")
556

    
557
    mac_prefix = flavor.get("mac_prefix")
558
    if mac_prefix == "pool":
559
        mac_prefix = allocate_resource("mac_prefix")
560

    
561
    tags = flavor.get("tags")
562

    
563
    return mode, link, mac_prefix, tags
564

    
565

    
566
def allocate_resource(res_type):
567
    table = get_pool_table(res_type)
568
    pool = table.get_pool()
569
    value = pool.get()
570
    pool.save()
571
    return value
572

    
573

    
574
def release_resource(res_type, value):
575
    table = get_pool_table(res_type)
576
    pool = table.get_pool()
577
    pool.put(value)
578
    pool.save()
579

    
580

    
581
def get_pool_table(res_type):
582
    if res_type == "bridge":
583
        return BridgePoolTable
584
    elif res_type == "mac_prefix":
585
        return MacPrefixPoolTable
586
    else:
587
        raise Exception("Unknown resource type")
588

    
589

    
590
def get_existing_users():
591
    """
592
    Retrieve user ids stored in cyclades user agnostic models.
593
    """
594
    # also check PublicKeys a user with no servers/networks exist
595
    from synnefo.ui.userdata.models import PublicKeyPair
596
    from synnefo.db.models import VirtualMachine, Network
597

    
598
    keypairusernames = PublicKeyPair.objects.filter().values_list('user',
599
                                                                  flat=True)
600
    serverusernames = VirtualMachine.objects.filter().values_list('userid',
601
                                                                  flat=True)
602
    networkusernames = Network.objects.filter().values_list('userid',
603
                                                            flat=True)
604

    
605
    return set(list(keypairusernames) + list(serverusernames) +
606
               list(networkusernames))