Move _TimeoutExpired to utils
[ganeti-local] / test / ganeti.rpc_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.rpc"""
23
24 import os
25 import sys
26 import unittest
27
28 from ganeti import constants
29 from ganeti import compat
30 from ganeti import rpc
31 from ganeti import http
32 from ganeti import errors
33 from ganeti import serializer
34
35 import testutils
36
37
class TestTimeouts(unittest.TestCase):
  """Check that rpc._TIMEOUTS matches the call_* methods of rpc.RpcRunner."""

  def test(self):
    prefix = "call_"
    call_names = [attr[len(prefix):] for attr in dir(rpc.RpcRunner)
                  if attr.startswith(prefix)]

    # There must be exactly as many timeout entries as call_* methods.
    self.assertEqual(len(call_names), len(rpc._TIMEOUTS))

    # Every timeout is either None or strictly positive; collect offenders
    # so a failure lists all of them at once.
    invalid = []
    for call in call_names:
      timeout = rpc._TIMEOUTS[call]
      if timeout is not None and not (timeout > 0):
        invalid.append(call)
    self.assertFalse(invalid)
46
47
class FakeHttpPool:
  """Stand-in for an HTTP client pool that performs no network I/O.

  Each request given to ProcessRequests is counted and then passed to the
  response callback. The callbacks in this module fill in the response
  fields of the request object in place.
  """

  def __init__(self, response_fn):
    # Callback invoked once per request with the request object.
    self._response_fn = response_fn
    # Total number of requests handled so far, inspected by the tests.
    self.reqcount = 0

  def ProcessRequests(self, reqs):
    for request in reqs:
      # The counter is bumped before the callback runs.
      self.reqcount = self.reqcount + 1
      self._response_fn(request)
57
58
def GetFakeSimpleStoreClass(fn):
  """Build a fake simple-store class for passing as rpc._AddressLookup's ssc.

  @param fn: function installed as GetNodePrimaryIPList; since it is stored
    as a class attribute it is bound as a method, so it receives the
    instance as its only argument
  @return: a new class (not an instance); GetPrimaryIPFamily on its
    instances always returns None

  """
  class FakeSimpleStore:
    GetNodePrimaryIPList = fn

    def GetPrimaryIPFamily(self):
      return None

  return FakeSimpleStore
65
66
class TestClient(unittest.TestCase):
  """Tests for rpc.Client and rpc._AddressLookup.

  No network traffic is generated: every test hands rpc.Client a
  FakeHttpPool whose callback fills in the response fields of each request.

  """

  def _FakeAddressLookup(self, addr_map):
    """Return an address lookup function backed by a dictionary.

    @param addr_map: mapping of node name to address
    @return: function taking a list of node names and returning the list of
      their addresses (None for names missing from addr_map)

    """
    return lambda node_list: [addr_map.get(node) for node in node_list]

  def _GetVersionResponse(self, req):
    """Answer a /version request to localhost:24094 with payload 123."""
    self.assertEqual(req.host, "localhost")
    self.assertEqual(req.port, 24094)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 123))

  def testVersionSuccess(self):
    fn = self._FakeAddressLookup({"localhost": "localhost"})
    client = rpc.Client("version", None, 24094, address_lookup_fn=fn)
    client.ConnectNode("localhost")
    pool = FakeHttpPool(self._GetVersionResponse)
    result = client.GetResults(http_pool=pool)
    # list() keeps the comparison valid where keys() returns a view
    self.assertEqual(list(result.keys()), ["localhost"])
    lhresp = result["localhost"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "localhost")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, 123)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(pool.reqcount, 1)

  def _GetMultiVersionResponse(self, req):
    """Answer a /version request to any "node*" host with payload 987."""
    self.assertTrue(req.host.startswith("node"))
    self.assertEqual(req.port, 23245)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 987))

  def testMultiVersionSuccess(self):
    nodes = ["node%s" % i for i in range(50)]
    fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
    client = rpc.Client("version", None, 23245, address_lookup_fn=fn)
    client.ConnectList(nodes)

    pool = FakeHttpPool(self._GetMultiVersionResponse)
    result = client.GetResults(http_pool=pool)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, 987)
      self.assertEqual(lhresp.call, "version")
      lhresp.Raise("should not raise")

    # Exactly one request per node
    self.assertEqual(pool.reqcount, len(nodes))

  def _GetVersionResponseFail(self, req):
    """Answer a /version request with an HTTP 200 but a (False, msg) body."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((False, "Unknown error"))

  def testVersionFailure(self):
    lookup_map = {"aef9ur4i.example.com": "aef9ur4i.example.com"}
    fn = self._FakeAddressLookup(lookup_map)
    client = rpc.Client("version", None, 5903, address_lookup_fn=fn)
    client.ConnectNode("aef9ur4i.example.com")
    pool = FakeHttpPool(self._GetVersionResponseFail)
    result = client.GetResults(http_pool=pool)
    self.assertEqual(list(result.keys()), ["aef9ur4i.example.com"])
    lhresp = result["aef9ur4i.example.com"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "aef9ur4i.example.com")
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")
    self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
    self.assertEqual(pool.reqcount, 1)

  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
    """Answer a /vg_list request according to the node's configured outcome.

    Hosts in httperrnodes get a failed request (req.success is False), hosts
    in failnodes get an HTTP 404 with a JSON error body, and all other hosts
    get a successful reply whose payload is hash(req.host).

    """
    self.assertEqual(req.path, "/vg_list")
    self.assertEqual(req.port, 15165)

    if req.host in httperrnodes:
      req.success = False
      req.error = "Node set up for HTTP errors"

    elif req.host in failnodes:
      req.success = True
      req.resp_status_code = 404
      req.resp_body = serializer.DumpJson({
        "code": 404,
        "message": "Method not found",
        "explain": "Explanation goes here",
        })
    else:
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hash(req.host)))

  def testHttpError(self):
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
    fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))

    # Every 7th node starting at index 1 fails at the request level
    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    # Every 3rd node starting at index 2 (minus overlaps) returns HTTP 404
    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    client = rpc.Client("vg_list", None, 15165, address_lookup_fn=fn)
    client.ConnectList(nodes)

    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
                                       httperrnodes, failnodes))
    result = client.GetResults(http_pool=pool)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(pool.reqcount, len(nodes))

  def _GetInvalidResponseA(self, req):
    """Answer with a JSON list that is not a (status, payload) pair."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                         "response", "!", 1, 2, 3))

  def _GetInvalidResponseB(self, req):
    """Answer with a bare JSON string instead of a (status, payload) pair."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson("invalid response")

  def testInvalidResponse(self):
    lookup_map = {"oqo7lanhly.example.com": "oqo7lanhly.example.com"}
    fn = self._FakeAddressLookup(lookup_map)
    client = rpc.Client("version", None, 19978, address_lookup_fn=fn)
    # Separate loop variable so the address lookup function is not rebound
    for response_fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      client.ConnectNode("oqo7lanhly.example.com")
      pool = FakeHttpPool(response_fn)
      result = client.GetResults(http_pool=pool)
      self.assertEqual(list(result.keys()), ["oqo7lanhly.example.com"])
      lhresp = result["oqo7lanhly.example.com"]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(pool.reqcount, 1)

  def testAddressLookupSimpleStore(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Simple store entries are "<node name> <address>" strings
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._AddressLookup(node_list, ssc=ssc)
    self.assertEqual(result, addr_list)

  def testAddressLookupNSLookup(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Empty simple store, so every address must come from nslookup_fn
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, addr_list)

  def testAddressLookupBoth(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Floor division keeps n an integer slice index on Python 3 as well
    n = len(addr_list) // 2
    # Second half is known to the simple store, first half only to nslookup
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, addr_list)

  def testAddressLookupIPv6(self):
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._AddressLookup(node_list, ssc=ssc)
    self.assertEqual(result, addr_list)
274
275
if __name__ == "__main__":
  # Delegate running this module's tests to testutils.GanetiTestProgram.
  testutils.GanetiTestProgram()