Support for resolving hostnames to IPv6 addresses
[ganeti-local] / test / ganeti.rpc_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.rpc"""
23
24 import os
25 import sys
26 import unittest
27
28 from ganeti import constants
29 from ganeti import compat
30 from ganeti import rpc
31 from ganeti import http
32 from ganeti import errors
33 from ganeti import serializer
34
35 import testutils
36
37
class TestTimeouts(unittest.TestCase):
  """Cross-checks rpc._TIMEOUTS against the call_* names on rpc.RpcRunner."""

  def test(self):
    """Each call_* name must have a timeout that is None or greater than 0."""
    prefix = "call_"

    # Collect the call names with the "call_" prefix stripped
    calls = []
    for attr in dir(rpc.RpcRunner):
      if attr.startswith(prefix):
        calls.append(attr[len(prefix):])

    # The timeout table must have exactly as many entries as there are calls
    self.assertEqual(len(calls), len(rpc._TIMEOUTS))

    # Gather every call whose timeout is neither None nor positive
    invalid = []
    for call in calls:
      timeout = rpc._TIMEOUTS[call]
      if timeout is not None and not timeout > 0:
        invalid.append(call)

    self.assertFalse(invalid)
46
47
class FakeHttpPool:
  """Fake HTTP pool which hands every request to a callback.

  The number of requests seen so far is available as "reqcount".

  """
  def __init__(self, response_fn):
    """Stores the callback and resets the request counter.

    @param response_fn: callable invoked with each request object

    """
    self.reqcount = 0
    self._response_fn = response_fn

  def ProcessRequests(self, reqs):
    """Calls the stored callback once per request, in iteration order.

    @param reqs: iterable of request objects

    """
    for request in reqs:
      # The counter is bumped before the callback runs, so a request whose
      # callback raises has still been counted
      self.reqcount = self.reqcount + 1
      self._response_fn(request)
57
58
def GetFakeSimpleStoreClass(fn):
  """Builds a fresh class exposing "fn" as its GetNodePrimaryIPList method.

  @param fn: callable installed as the method; it receives the instance as
    its first argument when called on an instance
  @return: the newly created class (not an instance)

  """
  class FakeSimpleStore:
    pass

  # Attaching the function to the class makes it behave like a regular method
  FakeSimpleStore.GetNodePrimaryIPList = fn

  return FakeSimpleStore
64
65
class TestClient(unittest.TestCase):
  """Tests for rpc.Client and rpc._AddressLookup.

  The client tests pass a FakeHttpPool to GetResults; the pool's callback
  checks each request and fills in the response attributes itself.

  """
  def _FakeAddressLookup(self, map):
    """Returns an address lookup function backed by a dictionary.

    @param map: dictionary mapping node names to addresses
    @return: function taking a list of node names and returning the list of
      mapped addresses (None for names missing from the dictionary)

    """
    return lambda node_list: [map.get(node) for node in node_list]

  def _GetVersionResponse(self, req):
    """Verifies a "version" request to localhost and answers (True, 123)."""
    self.assertEqual(req.host, "localhost")
    self.assertEqual(req.port, 24094)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 123))

  def testVersionSuccess(self):
    """A single successful call yields the payload and no failure message."""
    fn = self._FakeAddressLookup({"localhost": "localhost"})
    client = rpc.Client("version", None, 24094, address_lookup_fn=fn)
    client.ConnectNode("localhost")
    pool = FakeHttpPool(self._GetVersionResponse)
    result = client.GetResults(http_pool=pool)
    # list() keeps the comparison valid when keys() does not return a list
    self.assertEqual(list(result.keys()), ["localhost"])
    lhresp = result["localhost"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "localhost")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, 123)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(pool.reqcount, 1)

  def _GetMultiVersionResponse(self, req):
    """Verifies a "version" request to a node* host and answers (True, 987)."""
    self.assertTrue(req.host.startswith("node"))
    self.assertEqual(req.port, 23245)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 987))

  def testMultiVersionSuccess(self):
    """A successful call to 50 nodes yields one good result per node."""
    nodes = ["node%s" % i for i in range(50)]
    fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
    client = rpc.Client("version", None, 23245, address_lookup_fn=fn)
    client.ConnectList(nodes)

    pool = FakeHttpPool(self._GetMultiVersionResponse)
    result = client.GetResults(http_pool=pool)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, 987)
      self.assertEqual(lhresp.call, "version")
      lhresp.Raise("should not raise")

    self.assertEqual(pool.reqcount, len(nodes))

  def _GetVersionResponseFail(self, req):
    """Answers a "version" request with (False, "Unknown error")."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((False, "Unknown error"))

  def testVersionFailure(self):
    """A (False, message) response sets fail_msg and makes Raise() fail."""
    lookup_map = {"aef9ur4i.example.com": "aef9ur4i.example.com"}
    fn = self._FakeAddressLookup(lookup_map)
    client = rpc.Client("version", None, 5903, address_lookup_fn=fn)
    client.ConnectNode("aef9ur4i.example.com")
    pool = FakeHttpPool(self._GetVersionResponseFail)
    result = client.GetResults(http_pool=pool)
    self.assertEqual(list(result.keys()), ["aef9ur4i.example.com"])
    lhresp = result["aef9ur4i.example.com"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "aef9ur4i.example.com")
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")
    self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
    self.assertEqual(pool.reqcount, 1)

  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
    """Answers a "vg_list" request depending on the group the host is in.

    @param httperrnodes: hosts for which the request itself is marked failed
    @param failnodes: hosts which get an HTTP 404 response
    @param req: request object to verify and fill in; hosts in neither group
      get a successful response carrying hash(req.host) as payload

    """
    self.assertEqual(req.path, "/vg_list")
    self.assertEqual(req.port, 15165)

    if req.host in httperrnodes:
      req.success = False
      req.error = "Node set up for HTTP errors"

    elif req.host in failnodes:
      req.success = True
      req.resp_status_code = 404
      req.resp_body = serializer.DumpJson({
        "code": 404,
        "message": "Method not found",
        "explain": "Explanation goes here",
        })
    else:
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hash(req.host)))

  def testHttpError(self):
    """Mixes failed requests, HTTP 404 responses and successes in one call."""
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
    fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))

    # Sanity-check the partitioning of the 50 nodes into the three groups
    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    client = rpc.Client("vg_list", None, 15165, address_lookup_fn=fn)
    client.ConnectList(nodes)

    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
                                       httperrnodes, failnodes))
    result = client.GetResults(http_pool=pool)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(pool.reqcount, len(nodes))

  def _GetInvalidResponseA(self, req):
    """Answers a "version" request with a nine-element sequence."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                         "response", "!", 1, 2, 3))

  def _GetInvalidResponseB(self, req):
    """Answers a "version" request with a plain string."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson("invalid response")

  def testInvalidResponse(self):
    """Malformed response bodies must be reported as failures."""
    lookup_map = {"oqo7lanhly.example.com": "oqo7lanhly.example.com"}
    fn = self._FakeAddressLookup(lookup_map)
    client = rpc.Client("version", None, 19978, address_lookup_fn=fn)
    # The same client is reused for both malformed responses; the loop
    # variable has its own name so the lookup function is not rebound
    for response_fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      client.ConnectNode("oqo7lanhly.example.com")
      pool = FakeHttpPool(response_fn)
      result = client.GetResults(http_pool=pool)
      self.assertEqual(list(result.keys()), ["oqo7lanhly.example.com"])
      lhresp = result["oqo7lanhly.example.com"]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(pool.reqcount, 1)

  def testAddressLookupSimpleStore(self):
    """All addresses come from the "name address" list of the fake store."""
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda s: node_addr_list)
    result = rpc._AddressLookup(node_list, ssc=ssc)
    self.assertEqual(result, addr_list)

  def testAddressLookupNSLookup(self):
    """The fake store is empty, so all addresses come from nslookup_fn."""
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    ssc = GetFakeSimpleStoreClass(lambda s: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name: node_addr_map.get(name)
    result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, addr_list)

  def testAddressLookupBoth(self):
    """The fake store knows the second half of the nodes, nslookup_fn the first."""
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Floor division: the value is used as a slice index and must be an int
    n = len(addr_list) // 2
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda s: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name: node_addr_map.get(name)
    result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, addr_list)
265
266
if __name__ == "__main__":
  # Start the test program only when this file is executed as a script,
  # not when it is imported as a module
  testutils.GanetiTestProgram()