Revision eb202c13
b/lib/rpc.py | ||
---|---|---|
1 | 1 |
# |
2 | 2 |
# |
3 | 3 |
|
4 |
# Copyright (C) 2006, 2007 Google Inc. |
|
4 |
# Copyright (C) 2006, 2007, 2010 Google Inc.
|
|
5 | 5 |
# |
6 | 6 |
# This program is free software; you can redistribute it and/or modify |
7 | 7 |
# it under the terms of the GNU General Public License as published by |
... | ... | |
44 | 44 |
from ganeti import constants |
45 | 45 |
from ganeti import errors |
46 | 46 |
from ganeti import netutils |
47 |
from ganeti import ssconf |
|
47 | 48 |
|
48 | 49 |
# pylint has a bug here, doesn't see this import |
49 | 50 |
import ganeti.http.client # pylint: disable-msg=W0611 |
... | ... | |
256 | 257 |
raise ec(*args) # pylint: disable-msg=W0142 |
257 | 258 |
|
258 | 259 |
|
260 |
def _AddressLookup(node_list, |
|
261 |
ssc=ssconf.SimpleStore, |
|
262 |
nslookup_fn=netutils.HostInfo.LookupHostname): |
|
263 |
"""Return addresses for given node names. |
|
264 |
|
|
265 |
@type node_list: list |
|
266 |
@param node_list: List of node names |
|
267 |
@type ssc: class |
|
268 |
@param ssc: SimpleStore class that is used to obtain node->ip mappings |
|
269 |
@type lookup_fn: callable |
|
270 |
@param lookup_fn: function use to do NS lookup |
|
271 |
@rtype: list of addresses and/or None's |
|
272 |
@returns: List of corresponding addresses, if found |
|
273 |
|
|
274 |
""" |
|
275 |
def _NSLookup(name): |
|
276 |
_, _, addrs = nslookup_fn(name) |
|
277 |
return addrs[0] |
|
278 |
|
|
279 |
addresses = [] |
|
280 |
try: |
|
281 |
iplist = ssc().GetNodePrimaryIPList() |
|
282 |
ipmap = dict(entry.split() for entry in iplist) |
|
283 |
for node in node_list: |
|
284 |
address = ipmap.get(node) |
|
285 |
if address is None: |
|
286 |
address = _NSLookup(node) |
|
287 |
addresses.append(address) |
|
288 |
except errors.ConfigurationError: |
|
289 |
# Address not found in so we do a NS lookup |
|
290 |
addresses = [_NSLookup(node) for node in node_list] |
|
291 |
|
|
292 |
return addresses |
|
293 |
|
|
294 |
|
|
259 | 295 |
class Client: |
260 | 296 |
"""RPC Client class. |
261 | 297 |
|
... | ... | |
268 | 304 |
cause bugs. |
269 | 305 |
|
270 | 306 |
""" |
271 |
def __init__(self, procedure, body, port): |
|
307 |
def __init__(self, procedure, body, port, address_lookup_fn=_AddressLookup):
|
|
272 | 308 |
assert procedure in _TIMEOUTS, ("New RPC call not declared in the" |
273 | 309 |
" timeouts table") |
274 | 310 |
self.procedure = procedure |
275 | 311 |
self.body = body |
276 | 312 |
self.port = port |
277 | 313 |
self._request = {} |
314 |
self._address_lookup_fn = address_lookup_fn |
|
278 | 315 |
|
279 | 316 |
def ConnectList(self, node_list, address_list=None, read_timeout=None): |
280 | 317 |
"""Add a list of nodes to the target nodes. |
... | ... | |
285 | 322 |
@keyword address_list: either None or a list with node addresses, |
286 | 323 |
which must have the same length as the node list |
287 | 324 |
@type read_timeout: int |
288 |
@param read_timeout: overwrites the default read timeout for the |
|
289 |
given operation |
|
325 |
@param read_timeout: overwrites default timeout for operation |
|
290 | 326 |
|
291 | 327 |
""" |
292 | 328 |
if address_list is None: |
293 |
address_list = [None for _ in node_list] |
|
294 |
else: |
|
295 |
assert len(node_list) == len(address_list), \ |
|
296 |
"Name and address lists should have the same length" |
|
329 |
# Always use IP address instead of node name |
|
330 |
address_list = self._address_lookup_fn(node_list) |
|
331 |
|
|
332 |
assert len(node_list) == len(address_list), \ |
|
333 |
"Name and address lists must have the same length" |
|
334 |
|
|
297 | 335 |
for node, address in zip(node_list, address_list): |
298 | 336 |
self.ConnectNode(node, address, read_timeout=read_timeout) |
299 | 337 |
|
... | ... | |
303 | 341 |
@type name: str |
304 | 342 |
@param name: the node name |
305 | 343 |
@type address: str |
306 |
@keyword address: the node address, if known |
|
344 |
@param address: the node address, if known |
|
345 |
@type read_timeout: int |
|
346 |
@param read_timeout: overwrites default timeout for operation |
|
307 | 347 |
|
308 | 348 |
""" |
309 | 349 |
if address is None: |
310 |
address = name |
|
350 |
# Always use IP address instead of node name |
|
351 |
address = self._address_lookup_fn([name])[0] |
|
352 |
|
|
353 |
assert(address is not None) |
|
311 | 354 |
|
312 | 355 |
if read_timeout is None: |
313 | 356 |
read_timeout = _TIMEOUTS[self.procedure] |
b/test/ganeti.rpc_unittest.py | ||
---|---|---|
56 | 56 |
self._response_fn(req) |
57 | 57 |
|
58 | 58 |
|
59 |
def GetFakeSimpleStoreClass(fn): |
|
60 |
class FakeSimpleStore: |
|
61 |
GetNodePrimaryIPList = fn |
|
62 |
|
|
63 |
return FakeSimpleStore |
|
64 |
|
|
65 |
|
|
59 | 66 |
class TestClient(unittest.TestCase): |
67 |
def _FakeAddressLookup(self, map): |
|
68 |
return lambda node_list: [map.get(node) for node in node_list] |
|
69 |
|
|
60 | 70 |
def _GetVersionResponse(self, req): |
61 | 71 |
self.assertEqual(req.host, "localhost") |
62 | 72 |
self.assertEqual(req.port, 24094) |
... | ... | |
66 | 76 |
req.resp_body = serializer.DumpJson((True, 123)) |
67 | 77 |
|
68 | 78 |
def testVersionSuccess(self): |
69 |
client = rpc.Client("version", None, 24094) |
|
79 |
fn = self._FakeAddressLookup({"localhost": "localhost"}) |
|
80 |
client = rpc.Client("version", None, 24094, address_lookup_fn=fn) |
|
70 | 81 |
client.ConnectNode("localhost") |
71 | 82 |
pool = FakeHttpPool(self._GetVersionResponse) |
72 | 83 |
result = client.GetResults(http_pool=pool) |
... | ... | |
90 | 101 |
|
91 | 102 |
def testMultiVersionSuccess(self): |
92 | 103 |
nodes = ["node%s" % i for i in range(50)] |
93 |
client = rpc.Client("version", None, 23245) |
|
104 |
fn = self._FakeAddressLookup(dict(zip(nodes, nodes))) |
|
105 |
client = rpc.Client("version", None, 23245, address_lookup_fn=fn) |
|
94 | 106 |
client.ConnectList(nodes) |
95 | 107 |
|
96 | 108 |
pool = FakeHttpPool(self._GetMultiVersionResponse) |
... | ... | |
115 | 127 |
req.resp_body = serializer.DumpJson((False, "Unknown error")) |
116 | 128 |
|
117 | 129 |
def testVersionFailure(self): |
118 |
client = rpc.Client("version", None, 5903) |
|
130 |
lookup_map = {"aef9ur4i.example.com": "aef9ur4i.example.com"} |
|
131 |
fn = self._FakeAddressLookup(lookup_map) |
|
132 |
client = rpc.Client("version", None, 5903, address_lookup_fn=fn) |
|
119 | 133 |
client.ConnectNode("aef9ur4i.example.com") |
120 | 134 |
pool = FakeHttpPool(self._GetVersionResponseFail) |
121 | 135 |
result = client.GetResults(http_pool=pool) |
... | ... | |
152 | 166 |
|
153 | 167 |
def testHttpError(self): |
154 | 168 |
nodes = ["uaf6pbbv%s" % i for i in range(50)] |
169 |
fn = self._FakeAddressLookup(dict(zip(nodes, nodes))) |
|
155 | 170 |
|
156 | 171 |
httperrnodes = set(nodes[1::7]) |
157 | 172 |
self.assertEqual(len(httperrnodes), 7) |
... | ... | |
161 | 176 |
|
162 | 177 |
self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29) |
163 | 178 |
|
164 |
client = rpc.Client("vg_list", None, 15165) |
|
179 |
client = rpc.Client("vg_list", None, 15165, address_lookup_fn=fn)
|
|
165 | 180 |
client.ConnectList(nodes) |
166 | 181 |
|
167 | 182 |
pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse, |
... | ... | |
203 | 218 |
req.resp_body = serializer.DumpJson("invalid response") |
204 | 219 |
|
205 | 220 |
def testInvalidResponse(self): |
206 |
client = rpc.Client("version", None, 19978) |
|
221 |
lookup_map = {"oqo7lanhly.example.com": "oqo7lanhly.example.com"} |
|
222 |
fn = self._FakeAddressLookup(lookup_map) |
|
223 |
client = rpc.Client("version", None, 19978, address_lookup_fn=fn) |
|
207 | 224 |
for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]: |
208 | 225 |
client.ConnectNode("oqo7lanhly.example.com") |
209 | 226 |
pool = FakeHttpPool(fn) |
... | ... | |
218 | 235 |
self.assertRaises(errors.OpExecError, lhresp.Raise, "failed") |
219 | 236 |
self.assertEqual(pool.reqcount, 1) |
220 | 237 |
|
238 |
def testAddressLookupSimpleStore(self): |
|
239 |
addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)] |
|
240 |
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)] |
|
241 |
node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)] |
|
242 |
ssc = GetFakeSimpleStoreClass(lambda s: node_addr_list) |
|
243 |
result = rpc._AddressLookup(node_list, ssc=ssc) |
|
244 |
self.assertEqual(result, addr_list) |
|
245 |
|
|
246 |
def testAddressLookupNSLookup(self): |
|
247 |
addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)] |
|
248 |
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)] |
|
249 |
ssc = GetFakeSimpleStoreClass(lambda s: []) |
|
250 |
node_addr_map = dict(zip(node_list, addr_list)) |
|
251 |
nslookup_fn = lambda name: (None, None, [node_addr_map.get(name)]) |
|
252 |
result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn) |
|
253 |
self.assertEqual(result, addr_list) |
|
254 |
|
|
255 |
def testAddressLookupBoth(self): |
|
256 |
addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)] |
|
257 |
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)] |
|
258 |
n = len(addr_list) / 2 |
|
259 |
node_addr_list = [ " ".join(t) for t in zip(node_list[n:], addr_list[n:])] |
|
260 |
ssc = GetFakeSimpleStoreClass(lambda s: node_addr_list) |
|
261 |
node_addr_map = dict(zip(node_list[:n], addr_list[:n])) |
|
262 |
nslookup_fn = lambda name: (None, None, [node_addr_map.get(name)]) |
|
263 |
result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn) |
|
264 |
self.assertEqual(result, addr_list) |
|
265 |
|
|
221 | 266 |
|
222 | 267 |
if __name__ == "__main__": |
223 | 268 |
testutils.GanetiTestProgram() |
Also available in: Unified diff