4 # Copyright (C) 2010 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
22 """Script for testing ganeti.rpc"""
28 from ganeti import constants
29 from ganeti import compat
30 from ganeti import rpc
31 from ganeti import http
32 from ganeti import errors
33 from ganeti import serializer
class TestTimeouts(unittest.TestCase):
  """Checks the RPC timeout table against the RpcRunner call methods."""

  def test(self):
    """Each C{call_*} method of L{rpc.RpcRunner} has a sane timeout.

    The timeout table must contain exactly one entry per call name (the
    method name without its C{call_} prefix). Each entry must be either
    C{None} or a positive number.

    """
    names = [name[len("call_"):] for name in dir(rpc.RpcRunner)
             if name.startswith("call_")]
    # Comparing the sorted names rather than just the counts makes a
    # failure show which call lacks (or has a stray) timeout entry.
    self.assertEqual(sorted(names), sorted(rpc._TIMEOUTS))
    self.assertFalse([name for name in names
                      if not (rpc._TIMEOUTS[name] is None or
                              rpc._TIMEOUTS[name] > 0)])
class FakeHttpPool:
  """Stand-in for the HTTP request pool passed to C{rpc.Client.GetResults}.

  Nothing is sent over the network. Each request handed to
  L{ProcessRequests} is passed to C{response_fn}, which fills in the
  response fields on the request object in place.

  """
  def __init__(self, response_fn):
    """Initializes the fake pool.

    @type response_fn: callable
    @param response_fn: called once per request with the request object

    """
    self._response_fn = response_fn
    # Number of requests seen so far; tests compare this with the number of
    # nodes contacted.
    self.reqcount = 0

  def ProcessRequests(self, reqs):
    """Counts each request and hands it to the response function.

    @type reqs: iterable
    @param reqs: request objects to "process"

    """
    for req in reqs:
      self.reqcount += 1
      self._response_fn(req)
def GetFakeSimpleStoreClass(fn):
  """Builds a minimal replacement for the SimpleStore class.

  @type fn: callable
  @param fn: installed as the class attribute C{GetNodePrimaryIPList}. A
    plain function stored this way is bound as a method, so it is called
    with the store instance as its only argument
  @return: a class named C{FakeSimpleStore} whose instances return C{None}
    from C{GetPrimaryIPFamily}

  """
  class FakeSimpleStore:
    GetNodePrimaryIPList = fn

    def GetPrimaryIPFamily(self):
      # No address family is configured on the fake store.
      return None

  return FakeSimpleStore
class TestClient(unittest.TestCase):
  """Tests for L{rpc.Client}, using L{FakeHttpPool} instead of real HTTP.

  The C{_Get*Response} helpers serve as the fake pool's response function.
  Each one checks the outgoing request, then fills in C{success}, C{error},
  C{resp_status_code} and/or C{resp_body} on it. (Assumption: these are the
  request attributes that C{rpc.Client.GetResults} reads back; confirm
  against ganeti.http.client.)

  """

  def _FakeAddressLookup(self, addr_map):
    """Builds an address lookup function backed by a dictionary.

    @type addr_map: dict
    @param addr_map: node name to address mapping
    @return: a function that maps a list of node names to a list of
      addresses, with C{None} for names missing from C{addr_map}

    """
    return lambda node_list: [addr_map.get(node) for node in node_list]

  def _GetVersionResponse(self, req):
    """Answers a /version request to localhost with payload 123."""
    self.assertEqual(req.host, "localhost")
    self.assertEqual(req.port, 24094)
    self.assertEqual(req.path, "/version")

    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 123))

  def testVersionSuccess(self):
    fn = self._FakeAddressLookup({"localhost": "localhost"})
    client = rpc.Client("version", None, 24094, address_lookup_fn=fn)
    client.ConnectNode("localhost")
    pool = FakeHttpPool(self._GetVersionResponse)
    result = client.GetResults(http_pool=pool)
    self.assertEqual(list(result.keys()), ["localhost"])
    lhresp = result["localhost"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "localhost")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, 123)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(pool.reqcount, 1)

  def _GetMultiVersionResponse(self, req):
    """Answers /version on port 23245 for any "node*" host with 987."""
    self.assertTrue(req.host.startswith("node"))
    self.assertEqual(req.port, 23245)
    self.assertEqual(req.path, "/version")

    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 987))

  def testMultiVersionSuccess(self):
    nodes = ["node%s" % i for i in range(50)]
    fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
    client = rpc.Client("version", None, 23245, address_lookup_fn=fn)
    client.ConnectList(nodes)

    pool = FakeHttpPool(self._GetMultiVersionResponse)
    result = client.GetResults(http_pool=pool)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, 987)
      self.assertEqual(lhresp.call, "version")
      lhresp.Raise("should not raise")

    self.assertEqual(pool.reqcount, len(nodes))

  def _GetVersionResponseFail(self, req):
    """Answers /version with a (False, message) failure payload."""
    self.assertEqual(req.path, "/version")

    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((False, "Unknown error"))

  def testVersionFailure(self):
    lookup_map = {"aef9ur4i.example.com": "aef9ur4i.example.com"}
    fn = self._FakeAddressLookup(lookup_map)
    client = rpc.Client("version", None, 5903, address_lookup_fn=fn)
    client.ConnectNode("aef9ur4i.example.com")
    pool = FakeHttpPool(self._GetVersionResponseFail)
    result = client.GetResults(http_pool=pool)
    self.assertEqual(list(result.keys()), ["aef9ur4i.example.com"])
    lhresp = result["aef9ur4i.example.com"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "aef9ur4i.example.com")
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")
    self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
    self.assertEqual(pool.reqcount, 1)

  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
    """Answers /vg_list with a per-node outcome.

    @param httperrnodes: nodes whose request only gets an error message (no
      status code or body)
    @param failnodes: nodes that answer with HTTP 404 and a JSON error body
    @param req: the request being processed; all other nodes succeed with
      C{hash(req.host)} as payload

    """
    self.assertEqual(req.path, "/vg_list")
    self.assertEqual(req.port, 15165)

    if req.host in httperrnodes:
      req.success = False
      req.error = "Node set up for HTTP errors"

    elif req.host in failnodes:
      req.success = True
      req.resp_status_code = 404
      req.resp_body = serializer.DumpJson({
        "code": 404,
        "message": "Method not found",
        "explain": "Explanation goes here",
        })

    else:
      # Must be a separate branch: otherwise this would overwrite the 404
      # response set up for failnodes above.
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hash(req.host)))

  def testHttpError(self):
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
    fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))

    # Every 7th node, starting at index 1, fails without an HTTP response
    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    # Every 3rd node, starting at index 2, answers 404 unless it is already
    # an HTTP-error node
    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    client = rpc.Client("vg_list", None, 15165, address_lookup_fn=fn)
    client.ConnectList(nodes)

    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
                                       httperrnodes, failnodes))
    result = client.GetResults(http_pool=pool)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(pool.reqcount, len(nodes))

  def _GetInvalidResponseA(self, req):
    """Answers /version with a tuple that is not a (status, payload) pair."""
    self.assertEqual(req.path, "/version")

    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                         "response", "!", 1, 2, 3))

  def _GetInvalidResponseB(self, req):
    """Answers /version with a bare string instead of a result pair."""
    self.assertEqual(req.path, "/version")

    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson("invalid response")

  def testInvalidResponse(self):
    lookup_map = {"oqo7lanhly.example.com": "oqo7lanhly.example.com"}
    fn = self._FakeAddressLookup(lookup_map)
    client = rpc.Client("version", None, 19978, address_lookup_fn=fn)
    # Distinct name so the address lookup function above is not shadowed
    for response_fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      client.ConnectNode("oqo7lanhly.example.com")
      pool = FakeHttpPool(response_fn)
      result = client.GetResults(http_pool=pool)
      self.assertEqual(list(result.keys()), ["oqo7lanhly.example.com"])
      lhresp = result["oqo7lanhly.example.com"]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(pool.reqcount, 1)

  def testAddressLookupSimpleStore(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._AddressLookup(node_list, ssc=ssc)
    self.assertEqual(result, addr_list)

  def testAddressLookupNSLookup(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, addr_list)

  def testAddressLookupBoth(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Floor division: the split point is used as a slice index and must be
    # an int under true division as well
    n = len(addr_list) // 2
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, addr_list)

  def testAddressLookupIPv6(self):
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._AddressLookup(node_list, ssc=ssc)
    self.assertEqual(result, addr_list)
if __name__ == "__main__":
  # Executed as a script: run this module's tests through the Ganeti runner.
  testutils.GanetiTestProgram()