Add simple unittest for utils.CommaJoin
[ganeti-local] / test / ganeti.rpc_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.rpc"""
23
24 import os
25 import sys
26 import unittest
27
28 from ganeti import constants
29 from ganeti import compat
30 from ganeti import rpc
31 from ganeti import http
32 from ganeti import errors
33 from ganeti import serializer
34
35 import testutils
36
37
class TestTimeouts(unittest.TestCase):
  """Checks the per-procedure timeout table L{rpc._TIMEOUTS}."""

  def test(self):
    """Every C{call_*} method needs an entry that is C{None} or positive."""
    names = [name[len("call_"):] for name in dir(rpc.RpcRunner)
             if name.startswith("call_")]
    # Compare the actual names rather than only the counts: with a
    # count-only check, a renamed procedure would surface as a KeyError in
    # the lookup below instead of an assertion failure showing the mismatch.
    self.assertEqual(sorted(names), sorted(rpc._TIMEOUTS.keys()))
    self.assertFalse([name for name in names
                      if not (rpc._TIMEOUTS[name] is None or
                              rpc._TIMEOUTS[name] > 0)])
46
47
class FakeHttpPool:
  """Network-free stand-in for an HTTP client pool.

  Instead of sending requests, every request given to L{ProcessRequests}
  is handed to a caller-supplied callback. The callbacks used in this file
  fill in the response attributes on the request object in place.

  """
  def __init__(self, response_fn):
    """Creates the fake pool.

    @param response_fn: callable invoked once per request, with the request
        object as its only argument

    """
    self.reqcount = 0
    self._response_fn = response_fn

  def ProcessRequests(self, reqs):
    """Dispatches each request to the callback and counts it.

    @param reqs: iterable of request objects

    """
    for request in reqs:
      # The counter is bumped before dispatch, so a request whose callback
      # raises is still counted.
      self.reqcount += 1
      self._response_fn(request)
57
58
class TestClient(unittest.TestCase):
  """Tests for L{rpc.Client} driven through a L{FakeHttpPool}."""

  def _GetVersionResponse(self, req):
    """Checks a "version" request to localhost and answers it successfully."""
    self.assertEqual(req.host, "localhost")
    self.assertEqual(req.port, 24094)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 123))

  def testVersionSuccess(self):
    """A successful single-node call yields the decoded payload."""
    client = rpc.Client("version", None, 24094)
    client.ConnectNode("localhost")
    pool = FakeHttpPool(self._GetVersionResponse)
    result = client.GetResults(http_pool=pool)
    # sorted() turns keys() into a list, so the comparison also holds where
    # keys() returns a view rather than a list
    self.assertEqual(sorted(result.keys()), ["localhost"])
    lhresp = result["localhost"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "localhost")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, 123)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(pool.reqcount, 1)

  def _GetMultiVersionResponse(self, req):
    """Checks a "version" request to a "node*" host and answers with 987."""
    self.assertTrue(req.host.startswith("node"))
    self.assertEqual(req.port, 23245)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 987))

  def testMultiVersionSuccess(self):
    """Each of 50 nodes gets its own successful result."""
    nodes = ["node%s" % i for i in range(50)]
    client = rpc.Client("version", None, 23245)
    client.ConnectList(nodes)

    pool = FakeHttpPool(self._GetMultiVersionResponse)
    result = client.GetResults(http_pool=pool)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, 987)
      self.assertEqual(lhresp.call, "version")
      lhresp.Raise("should not raise")

    self.assertEqual(pool.reqcount, len(nodes))

  def _GetVersionResponseFail(self, req):
    """Answers a "version" request with an application-level failure.

    The HTTP exchange itself succeeds (status 200); the failure is carried
    in the C{(False, message)} response body.

    """
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((False, "Unknown error"))

  def testVersionFailure(self):
    """A C{(False, msg)} reply sets fail_msg and makes Raise fail."""
    client = rpc.Client("version", None, 5903)
    client.ConnectNode("aef9ur4i.example.com")
    pool = FakeHttpPool(self._GetVersionResponseFail)
    result = client.GetResults(http_pool=pool)
    self.assertEqual(sorted(result.keys()), ["aef9ur4i.example.com"])
    lhresp = result["aef9ur4i.example.com"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "aef9ur4i.example.com")
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")
    self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
    self.assertEqual(pool.reqcount, 1)

  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
    """Simulates three kinds of nodes for a "vg_list" call.

    @param httperrnodes: hosts whose request is marked unsuccessful, with
        an error message and no HTTP response
    @param failnodes: hosts that reply with an HTTP 404 error document
    @param req: the request to fill in; all other hosts succeed and return
        C{hash(req.host)} as payload

    """
    self.assertEqual(req.path, "/vg_list")
    self.assertEqual(req.port, 15165)

    if req.host in httperrnodes:
      req.success = False
      req.error = "Node set up for HTTP errors"

    elif req.host in failnodes:
      req.success = True
      req.resp_status_code = 404
      req.resp_body = serializer.DumpJson({
        "code": 404,
        "message": "Method not found",
        "explain": "Explanation goes here",
        })
    else:
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hash(req.host)))

  def testHttpError(self):
    """Mixes request errors, HTTP 404 replies and successes in one call."""
    nodes = ["uaf6pbbv%s" % i for i in range(50)]

    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    # Two of the every-third nodes overlap with httperrnodes and are removed
    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    client = rpc.Client("vg_list", None, 15165)
    client.ConnectList(nodes)

    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
                                       httperrnodes, failnodes))
    result = client.GetResults(http_pool=pool)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(pool.reqcount, len(nodes))

  def _GetInvalidResponseA(self, req):
    """Answers with a JSON list that is not a (status, payload) pair."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                         "response", "!", 1, 2, 3))

  def _GetInvalidResponseB(self, req):
    """Answers with a bare JSON string instead of a (status, payload) pair."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson("invalid response")

  def testInvalidResponse(self):
    """Malformed response bodies are reported as failures."""
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      # Use a fresh client per case so each iteration stands on its own,
      # instead of relying on ConnectNode replacing the previous request
      client = rpc.Client("version", None, 19978)
      client.ConnectNode("oqo7lanhly.example.com")
      pool = FakeHttpPool(fn)
      result = client.GetResults(http_pool=pool)
      self.assertEqual(sorted(result.keys()), ["oqo7lanhly.example.com"])
      lhresp = result["oqo7lanhly.example.com"]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(pool.reqcount, 1)
220
221
if __name__ == "__main__":
  # Script entry point: delegates to testutils.GanetiTestProgram
  # (presumably a wrapper around unittest.main -- confirm in testutils)
  testutils.GanetiTestProgram()