Revision 14402edc snf-cyclades-app/synnefo/api/ports.py

b/snf-cyclades-app/synnefo/api/ports.py
107 107
    req = api.utils.get_request_dict(request)
108 108
    log.info('create_port %s', req)
109 109

  
110
    port_dict = api.utils.get_attribute(req, "port")
111
    net_id = api.utils.get_attribute(port_dict, "network_id")
110
    port_dict = api.utils.get_attribute(req, "port", attr_type=dict)
111
    net_id = api.utils.get_attribute(port_dict, "network_id",
112
                                     attr_type=(basestring, int))
112 113

  
113
    device_id = api.utils.get_attribute(port_dict, "device_id", required=False)
114
    device_id = api.utils.get_attribute(port_dict, "device_id", required=False,
115
                                        attr_type=(basestring, int))
114 116
    vm = None
115 117
    if device_id is not None:
116 118
        vm = util.get_vm(device_id, user_id, for_update=True, non_deleted=True,
117 119
                         non_suspended=True)
118 120

  
119 121
    # Check if the request contains a valid IPv4 address
120
    fixed_ips = api.utils.get_attribute(port_dict, "fixed_ips", required=False)
122
    fixed_ips = api.utils.get_attribute(port_dict, "fixed_ips", required=False,
123
                                        attr_type=list)
121 124
    if fixed_ips is not None and len(fixed_ips) > 0:
122 125
        if len(fixed_ips) > 1:
123 126
            msg = "'fixed_ips' attribute must contain only one fixed IP."
124 127
            raise faults.BadRequest(msg)
125
        fixed_ip_address = fixed_ips[0].get("ip_address")
128
        fixed_ip = fixed_ips[0]
129
        if not isinstance(fixed_ip, dict):
130
            raise faults.BadRequest("Invalid 'fixed_ips' field.")
131
        fixed_ip_address = fixed_ip.get("ip_address")
126 132
        if fixed_ip_address is not None:
127 133
            try:
128 134
                ip = ipaddr.IPAddress(fixed_ip_address)
......
153 159
        ipaddress = ips.allocate_ip(network, user_id,
154 160
                                    address=fixed_ip_address)
155 161

  
156
    name = api.utils.get_attribute(port_dict, "name", required=False)
162
    name = api.utils.get_attribute(port_dict, "name", required=False,
163
                                   attr_type=basestring)
157 164
    if name is None:
158 165
        name = ""
159 166

  
160 167
    security_groups = api.utils.get_attribute(port_dict,
161 168
                                              "security_groups",
162
                                              required=False)
169
                                              required=False,
170
                                              attr_type=list)
163 171
    #validate security groups
164 172
    # like get security group from db
165 173
    sg_list = []
166 174
    if security_groups:
167 175
        for gid in security_groups:
168
            sg = util.get_security_group(int(gid))
176
            try:
177
                sg = util.get_security_group(int(gid))
178
            except (KeyError, ValueError):
179
                raise faults.BadRequest("Invalid 'security_groups' field.")
169 180
            sg_list.append(sg)
170 181

  
171 182
    new_port = servers.create_port(user_id, network, use_ipaddress=ipaddress,
......
191 202
    port = util.get_port(port_id, request.user_uniq, for_update=True)
192 203
    req = api.utils.get_request_dict(request)
193 204

  
194
    port_info = api.utils.get_attribute(req, "port", required=True)
195
    name = api.utils.get_attribute(port_info, "name", required=False)
205
    port_info = api.utils.get_attribute(req, "port", required=True,
206
                                        attr_type=dict)
207
    name = api.utils.get_attribute(port_info, "name", required=False,
208
                                   attr_type=basestring)
196 209

  
197 210
    if name:
198 211
        port.name = name
199 212

  
200 213
    security_groups = api.utils.get_attribute(port_info, "security_groups",
201
                                              required=False)
214
                                              required=False, attr_type=list)
215

  
202 216
    if security_groups:
203 217
        sg_list = []
204 218
        #validate security groups
205 219
        for gid in security_groups:
206
            sg = util.get_security_group(int(gid))
220
            try:
221
                sg = util.get_security_group(int(gid))
222
            except (KeyError, ValueError):
223
                raise faults.BadRequest("Invalid 'security_groups' field.")
207 224
            sg_list.append(sg)
208 225

  
209 226
        #clear the old security groups

Also available in: Unified diff