unit tests: Add tests for file mode handling in utils.WriteFile
[ganeti-local] / test / ganeti.rpc_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010, 2011 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for testing ganeti.rpc"""
23
24 import os
25 import sys
26 import unittest
27
28 from ganeti import constants
29 from ganeti import compat
30 from ganeti import rpc
31 from ganeti import http
32 from ganeti import errors
33 from ganeti import serializer
34 from ganeti import objects
35
36 import testutils
37
38
39 class _FakeRequestProcessor:
40   def __init__(self, response_fn):
41     self._response_fn = response_fn
42     self.reqcount = 0
43
44   def __call__(self, reqs, lock_monitor_cb=None):
45     assert lock_monitor_cb is None or callable(lock_monitor_cb)
46     for req in reqs:
47       self.reqcount += 1
48       self._response_fn(req)
49
50
def GetFakeSimpleStoreClass(fn):
  """Build a fake SimpleStore class for passing as C{ssc} to the resolver.

  @param fn: installed as the class attribute C{GetNodePrimaryIPList}; when it
    is a plain function it is bound as a method and thus receives the store
    instance as its only argument
  @return: a new class whose C{GetPrimaryIPFamily} method returns C{None}

  """
  class FakeSimpleStore:
    def GetPrimaryIPFamily(self):
      return None

    GetNodePrimaryIPList = fn

  return FakeSimpleStore
57
58
class TestRpcProcessor(unittest.TestCase):
  """Tests for L{rpc._RpcProcessor}.

  Instead of a real HTTP request processor, a L{_FakeRequestProcessor} is
  passed as C{_req_process_fn}; one of the C{_Get*Response} helpers below
  checks each request and fills in its result.

  """
  def _FakeAddressLookup(self, addr_map):
    """Return a function resolving a list of node names via C{addr_map}.

    Names missing from C{addr_map} resolve to C{None}.

    """
    return lambda node_list: [addr_map.get(node) for node in node_list]

  def _GetVersionResponse(self, req):
    """Answer a /version request on 127.0.0.1:24094 with payload 123."""
    self.assertEqual(req.host, "127.0.0.1")
    self.assertEqual(req.port, 24094)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, rpc._TMO_URGENT)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 123))

  def testVersionSuccess(self):
    resolver = rpc._StaticResolver(["127.0.0.1"])
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
    proc = rpc._RpcProcessor(resolver, 24094)
    # Pass the very constant the response helper checks for, rather than a
    # literal that merely happens to have the same value
    result = proc(["localhost"], "version", {"localhost": ""},
                  _req_process_fn=http_proc, read_timeout=rpc._TMO_URGENT)
    self.assertEqual(list(result.keys()), ["localhost"])
    lhresp = result["localhost"]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, "localhost")
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, 123)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)

  def _ReadTimeoutResponse(self, req):
    """Answer a /version request, checking the custom read timeout."""
    self.assertEqual(req.host, "192.0.2.13")
    self.assertEqual(req.port, 19176)
    self.assertEqual(req.path, "/version")
    self.assertEqual(req.read_timeout, 12356)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, -1))

  def testReadTimeout(self):
    resolver = rpc._StaticResolver(["192.0.2.13"])
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
    proc = rpc._RpcProcessor(resolver, 19176)
    host = "node31856"
    body = {host: ""}
    result = proc([host], "version", body, _req_process_fn=http_proc,
                  read_timeout=12356)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, -1)
    self.assertEqual(lhresp.call, "version")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)

  def testOfflineNode(self):
    # The resolver marks the node offline, so no request may be processed;
    # NotImplemented as response function would fail if it were called
    resolver = rpc._StaticResolver([rpc._OFFLINE])
    http_proc = _FakeRequestProcessor(NotImplemented)
    proc = rpc._RpcProcessor(resolver, 30668)
    host = "n17296"
    body = {host: ""}
    result = proc([host], "version", body, _req_process_fn=http_proc,
                  read_timeout=60)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertTrue(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertTrue(lhresp.fail_msg)
    self.assertFalse(lhresp.payload)
    self.assertEqual(lhresp.call, "version")

    # With a message
    self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")

    # No message
    self.assertRaises(errors.OpExecError, lhresp.Raise, None)

    self.assertEqual(http_proc.reqcount, 0)

  def _GetMultiVersionResponse(self, req):
    """Answer a /version request from any "node*" host with payload 987."""
    self.assertTrue(req.host.startswith("node"))
    self.assertEqual(req.port, 23245)
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, 987))

  def testMultiVersionSuccess(self):
    nodes = ["node%s" % i for i in range(50)]
    body = dict((n, "") for n in nodes)
    resolver = rpc._StaticResolver(nodes)
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
    proc = rpc._RpcProcessor(resolver, 23245)
    result = proc(nodes, "version", body, _req_process_fn=http_proc,
                  read_timeout=60)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertFalse(lhresp.fail_msg)
      self.assertEqual(lhresp.payload, 987)
      self.assertEqual(lhresp.call, "version")
      lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))

  def _GetVersionResponseFail(self, errinfo, req):
    """Answer a /version request with a failed result carrying C{errinfo}."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((False, errinfo))

  def testVersionFailure(self):
    host = "aef9ur4i.example.com"
    resolver = rpc._StaticResolver([host])
    proc = rpc._RpcProcessor(resolver, 5903)
    for errinfo in [None, "Unknown error"]:
      http_proc = \
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
                                             errinfo))
      body = {host: ""}
      result = proc([host], "version", body,
                    _req_process_fn=http_proc, read_timeout=60)
      self.assertEqual(list(result.keys()), [host])
      lhresp = result[host]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, host)
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)

  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
    """Answer /vg_list requests with per-node outcomes.

    Hosts in C{httperrnodes} fail at the HTTP layer, hosts in C{failnodes}
    receive a 404 reply, and all other hosts succeed with C{hash(req.host)}
    as payload.

    """
    self.assertEqual(req.path, "/vg_list")
    self.assertEqual(req.port, 15165)

    if req.host in httperrnodes:
      req.success = False
      req.error = "Node set up for HTTP errors"

    elif req.host in failnodes:
      req.success = True
      req.resp_status_code = 404
      req.resp_body = serializer.DumpJson({
        "code": 404,
        "message": "Method not found",
        "explain": "Explanation goes here",
        })
    else:
      req.success = True
      req.resp_status_code = http.HTTP_OK
      req.resp_body = serializer.DumpJson((True, hash(req.host)))

  def testHttpError(self):
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
    body = dict((n, "") for n in nodes)
    resolver = rpc._StaticResolver(nodes)

    httperrnodes = set(nodes[1::7])
    self.assertEqual(len(httperrnodes), 7)

    failnodes = set(nodes[2::3]) - httperrnodes
    self.assertEqual(len(failnodes), 14)

    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)

    proc = rpc._RpcProcessor(resolver, 15165)
    http_proc = \
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
                                           httperrnodes, failnodes))
    result = proc(nodes, "vg_list", body, _req_process_fn=http_proc,
                  read_timeout=rpc._TMO_URGENT)
    self.assertEqual(sorted(result.keys()), sorted(nodes))

    for name in nodes:
      lhresp = result[name]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, name)
      self.assertEqual(lhresp.call, "vg_list")

      if name in httperrnodes:
        self.assertTrue(lhresp.fail_msg)
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      elif name in failnodes:
        self.assertTrue(lhresp.fail_msg)
        # With prereq=True, Raise is expected to raise OpPrereqError
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
                          prereq=True, ecode=errors.ECODE_INVAL)
      else:
        self.assertFalse(lhresp.fail_msg)
        self.assertEqual(lhresp.payload, hash(name))
        lhresp.Raise("should not raise")

    self.assertEqual(http_proc.reqcount, len(nodes))

  def _GetInvalidResponseA(self, req):
    """Answer with a JSON list that is not a (status, payload) pair."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
                                         "response", "!", 1, 2, 3))

  def _GetInvalidResponseB(self, req):
    """Answer with a bare JSON string instead of a (status, payload) pair."""
    self.assertEqual(req.path, "/version")
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson("invalid response")

  def testInvalidResponse(self):
    resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
    proc = rpc._RpcProcessor(resolver, 19978)

    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
      http_proc = _FakeRequestProcessor(fn)
      host = "oqo7lanhly.example.com"
      body = {host: ""}
      result = proc([host], "version", body,
                    _req_process_fn=http_proc, read_timeout=60)
      self.assertEqual(list(result.keys()), [host])
      lhresp = result[host]
      self.assertFalse(lhresp.offline)
      self.assertEqual(lhresp.node, host)
      self.assertTrue(lhresp.fail_msg)
      self.assertFalse(lhresp.payload)
      self.assertEqual(lhresp.call, "version")
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
      self.assertEqual(http_proc.reqcount, 1)

  def _GetBodyTestResponse(self, test_data, req):
    """Check that the request body decodes to C{test_data}, then succeed."""
    self.assertEqual(req.host, "192.0.2.84")
    self.assertEqual(req.port, 18700)
    self.assertEqual(req.path, "/upload_file")
    self.assertEqual(serializer.LoadJson(req.post_data), test_data)
    req.success = True
    req.resp_status_code = http.HTTP_OK
    req.resp_body = serializer.DumpJson((True, None))

  def testResponseBody(self):
    test_data = {
      "Hello": "World",
      # An explicit list: a Python 3 range object is not JSON-serializable
      "xyz": list(range(10)),
      }
    resolver = rpc._StaticResolver(["192.0.2.84"])
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
                                                     test_data))
    proc = rpc._RpcProcessor(resolver, 18700)
    host = "node19759"
    body = {host: serializer.DumpJson(test_data)}
    result = proc([host], "upload_file", body, _req_process_fn=http_proc,
                  read_timeout=30)
    self.assertEqual(list(result.keys()), [host])
    lhresp = result[host]
    self.assertFalse(lhresp.offline)
    self.assertEqual(lhresp.node, host)
    self.assertFalse(lhresp.fail_msg)
    self.assertEqual(lhresp.payload, None)
    self.assertEqual(lhresp.call, "upload_file")
    lhresp.Raise("should not raise")
    self.assertEqual(http_proc.reqcount, 1)
321
322
class TestSsconfResolver(unittest.TestCase):
  """Tests for L{rpc._SsconfResolver}.

  Nodes listed by the fake ssconf store (as "name address" strings) resolve
  to the listed address; other nodes are resolved through C{nslookup_fn}.

  """
  def testSsconfLookup(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    # nslookup_fn is NotImplemented, so any name lookup would fail the test
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
    # zip() is lazy on Python 3; build the expected list explicitly
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testNsLookup(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Empty ssconf list: every node has to go through nslookup_fn
    ssc = GetFakeSimpleStoreClass(lambda _: [])
    node_addr_map = dict(zip(node_list, addr_list))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testBothLookups(self):
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
    # Floor division keeps the slice index an integer on Python 3 as well
    n = len(addr_list) // 2
    # Second half of the nodes comes from ssconf, first half from nslookup_fn
    node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
    self.assertEqual(result, list(zip(node_list, addr_list)))

  def testAddressLookupIPv6(self):
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
    node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
    node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
    self.assertEqual(result, list(zip(node_list, addr_list)))
359
360
class TestStaticResolver(unittest.TestCase):
  """Tests for L{rpc._StaticResolver}."""

  def test(self):
    addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
    nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
    res = rpc._StaticResolver(addresses)
    # zip() is lazy on Python 3; build the expected list explicitly
    self.assertEqual(res(nodes), list(zip(nodes, addresses)))

  def testWrongLength(self):
    # One host against zero configured addresses must trip an assertion
    res = rpc._StaticResolver([])
    self.assertRaises(AssertionError, res, ["abc"])
371
372
class TestNodeConfigResolver(unittest.TestCase):
  """Tests for L{rpc._NodeConfigResolver}.

  The resolver receives a per-name lookup function, a function returning a
  dictionary of all nodes, and the names to resolve.  Whichever lookup a test
  does not exercise is passed as C{NotImplemented}, so using it would fail.

  """
  @staticmethod
  def _GetSingleOnlineNode(name):
    assert name == "node90.example.com"
    return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")

  @staticmethod
  def _GetSingleOfflineNode(name):
    assert name == "node100.example.com"
    return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")

  def testSingleOnline(self):
    resolved = rpc._NodeConfigResolver(self._GetSingleOnlineNode,
                                       NotImplemented,
                                       ["node90.example.com"])
    self.assertEqual(resolved, [("node90.example.com", "192.0.2.90")])

  def testSingleOffline(self):
    resolved = rpc._NodeConfigResolver(self._GetSingleOfflineNode,
                                       NotImplemented,
                                       ["node100.example.com"])
    self.assertEqual(resolved, [("node100.example.com", rpc._OFFLINE)])

  def testUnknownSingleNode(self):
    # The lookup knows no node; the name itself is expected as the address
    resolved = rpc._NodeConfigResolver(lambda _: None, NotImplemented,
                                       ["node110.example.com"])
    self.assertEqual(resolved,
                     [("node110.example.com", "node110.example.com")])

  def testMultiEmpty(self):
    resolved = rpc._NodeConfigResolver(NotImplemented, lambda: {}, [])
    self.assertEqual(resolved, [])

  def testMultiSomeOffline(self):
    # Every node whose index is divisible by three is offline
    all_nodes = {}
    for idx in range(1, 255):
      node_name = "node%s.example.com" % idx
      all_nodes[node_name] = objects.Node(name=node_name,
                                          offline=(idx % 3 == 0),
                                          primary_ip="192.0.2.%s" % idx)

    def _GetAllNodes():
      return all_nodes

    # Resolve no names
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented, _GetAllNodes, []),
                     [])

    # Offline, online and unknown hosts
    wanted = [
      "node3.example.com",
      "node92.example.com",
      "node54.example.com",
      "unknown.example.com",
      ]
    expected = [
      ("node3.example.com", rpc._OFFLINE),
      ("node92.example.com", "192.0.2.92"),
      ("node54.example.com", rpc._OFFLINE),
      ("unknown.example.com", "unknown.example.com"),
      ]
    self.assertEqual(rpc._NodeConfigResolver(NotImplemented, _GetAllNodes,
                                             wanted),
                     expected)
432
433
if __name__ == "__main__":
  # Only when executed as a script: hand control to the Ganeti test runner
  testutils.GanetiTestProgram()