Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.rpc_unittest.py @ b43dcc5a

History | View | Annotate | Download (10 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for testing ganeti.rpc"""
23

    
24
import os
25
import sys
26
import unittest
27

    
28
from ganeti import constants
29
from ganeti import compat
30
from ganeti import rpc
31
from ganeti import http
32
from ganeti import errors
33
from ganeti import serializer
34

    
35
import testutils
36

    
37

    
38
class TestTimeouts(unittest.TestCase):
39
  def test(self):
40
    names = [name[len("call_"):] for name in dir(rpc.RpcRunner)
41
             if name.startswith("call_")]
42
    self.assertEqual(len(names), len(rpc._TIMEOUTS))
43
    self.assertFalse([name for name in names
44
                      if not (rpc._TIMEOUTS[name] is None or
45
                              rpc._TIMEOUTS[name] > 0)])
46

    
47

    
48
class FakeHttpPool:
49
  def __init__(self, response_fn):
50
    self._response_fn = response_fn
51
    self.reqcount = 0
52

    
53
  def ProcessRequests(self, reqs):
54
    for req in reqs:
55
      self.reqcount += 1
56
      self._response_fn(req)
57

    
58

    
59
def GetFakeSimpleStoreClass(fn):
60
  class FakeSimpleStore:
61
    GetNodePrimaryIPList = fn
62
    GetPrimaryIPFamily = lambda _: None
63

    
64
  return FakeSimpleStore
65

    
66

    
67
class TestClient(unittest.TestCase):
68
  def _FakeAddressLookup(self, map):
69
    return lambda node_list: [map.get(node) for node in node_list]
70

    
71
  def _GetVersionResponse(self, req):
72
    self.assertEqual(req.host, "localhost")
73
    self.assertEqual(req.port, 24094)
74
    self.assertEqual(req.path, "/version")
75
    req.success = True
76
    req.resp_status_code = http.HTTP_OK
77
    req.resp_body = serializer.DumpJson((True, 123))
78

    
79
  def testVersionSuccess(self):
80
    fn = self._FakeAddressLookup({"localhost": "localhost"})
81
    client = rpc.Client("version", None, 24094, address_lookup_fn=fn)
82
    client.ConnectNode("localhost")
83
    pool = FakeHttpPool(self._GetVersionResponse)
84
    result = client.GetResults(http_pool=pool)
85
    self.assertEqual(result.keys(), ["localhost"])
86
    lhresp = result["localhost"]
87
    self.assertFalse(lhresp.offline)
88
    self.assertEqual(lhresp.node, "localhost")
89
    self.assertFalse(lhresp.fail_msg)
90
    self.assertEqual(lhresp.payload, 123)
91
    self.assertEqual(lhresp.call, "version")
92
    lhresp.Raise("should not raise")
93
    self.assertEqual(pool.reqcount, 1)
94

    
95
  def _GetMultiVersionResponse(self, req):
96
    self.assert_(req.host.startswith("node"))
97
    self.assertEqual(req.port, 23245)
98
    self.assertEqual(req.path, "/version")
99
    req.success = True
100
    req.resp_status_code = http.HTTP_OK
101
    req.resp_body = serializer.DumpJson((True, 987))
102

    
103
  def testMultiVersionSuccess(self):
104
    nodes = ["node%s" % i for i in range(50)]
105
    fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
106
    client = rpc.Client("version", None, 23245, address_lookup_fn=fn)
107
    client.ConnectList(nodes)
108

    
109
    pool = FakeHttpPool(self._GetMultiVersionResponse)
110
    result = client.GetResults(http_pool=pool)
111
    self.assertEqual(sorted(result.keys()), sorted(nodes))
112

    
113
    for name in nodes:
114
      lhresp = result[name]
115
      self.assertFalse(lhresp.offline)
116
      self.assertEqual(lhresp.node, name)
117
      self.assertFalse(lhresp.fail_msg)
118
      self.assertEqual(lhresp.payload, 987)
119
      self.assertEqual(lhresp.call, "version")
120
      lhresp.Raise("should not raise")
121

    
122
    self.assertEqual(pool.reqcount, len(nodes))
123

    
124
  def _GetVersionResponseFail(self, req):
125
    self.assertEqual(req.path, "/version")
126
    req.success = True
127
    req.resp_status_code = http.HTTP_OK
128
    req.resp_body = serializer.DumpJson((False, "Unknown error"))
129

    
130
  def testVersionFailure(self):
131
    lookup_map = {"aef9ur4i.example.com": "aef9ur4i.example.com"}
132
    fn = self._FakeAddressLookup(lookup_map)
133
    client = rpc.Client("version", None, 5903, address_lookup_fn=fn)
134
    client.ConnectNode("aef9ur4i.example.com")
135
    pool = FakeHttpPool(self._GetVersionResponseFail)
136
    result = client.GetResults(http_pool=pool)
137
    self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
138
    lhresp = result["aef9ur4i.example.com"]
139
    self.assertFalse(lhresp.offline)
140
    self.assertEqual(lhresp.node, "aef9ur4i.example.com")
141
    self.assert_(lhresp.fail_msg)
142
    self.assertFalse(lhresp.payload)
143
    self.assertEqual(lhresp.call, "version")
144
    self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
145
    self.assertEqual(pool.reqcount, 1)
146

    
147
  def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
148
    self.assertEqual(req.path, "/vg_list")
149
    self.assertEqual(req.port, 15165)
150

    
151
    if req.host in httperrnodes:
152
      req.success = False
153
      req.error = "Node set up for HTTP errors"
154

    
155
    elif req.host in failnodes:
156
      req.success = True
157
      req.resp_status_code = 404
158
      req.resp_body = serializer.DumpJson({
159
        "code": 404,
160
        "message": "Method not found",
161
        "explain": "Explanation goes here",
162
        })
163
    else:
164
      req.success = True
165
      req.resp_status_code = http.HTTP_OK
166
      req.resp_body = serializer.DumpJson((True, hash(req.host)))
167

    
168
  def testHttpError(self):
169
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
170
    fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
171

    
172
    httperrnodes = set(nodes[1::7])
173
    self.assertEqual(len(httperrnodes), 7)
174

    
175
    failnodes = set(nodes[2::3]) - httperrnodes
176
    self.assertEqual(len(failnodes), 14)
177

    
178
    self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
179

    
180
    client = rpc.Client("vg_list", None, 15165, address_lookup_fn=fn)
181
    client.ConnectList(nodes)
182

    
183
    pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
184
                                       httperrnodes, failnodes))
185
    result = client.GetResults(http_pool=pool)
186
    self.assertEqual(sorted(result.keys()), sorted(nodes))
187

    
188
    for name in nodes:
189
      lhresp = result[name]
190
      self.assertFalse(lhresp.offline)
191
      self.assertEqual(lhresp.node, name)
192
      self.assertEqual(lhresp.call, "vg_list")
193

    
194
      if name in httperrnodes:
195
        self.assert_(lhresp.fail_msg)
196
        self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
197
      elif name in failnodes:
198
        self.assert_(lhresp.fail_msg)
199
        self.assertRaises(errors.OpPrereqError, lhresp.Raise, "failed",
200
                          prereq=True, ecode=errors.ECODE_INVAL)
201
      else:
202
        self.assertFalse(lhresp.fail_msg)
203
        self.assertEqual(lhresp.payload, hash(name))
204
        lhresp.Raise("should not raise")
205

    
206
    self.assertEqual(pool.reqcount, len(nodes))
207

    
208
  def _GetInvalidResponseA(self, req):
209
    self.assertEqual(req.path, "/version")
210
    req.success = True
211
    req.resp_status_code = http.HTTP_OK
212
    req.resp_body = serializer.DumpJson(("This", "is", "an", "invalid",
213
                                         "response", "!", 1, 2, 3))
214

    
215
  def _GetInvalidResponseB(self, req):
216
    self.assertEqual(req.path, "/version")
217
    req.success = True
218
    req.resp_status_code = http.HTTP_OK
219
    req.resp_body = serializer.DumpJson("invalid response")
220

    
221
  def testInvalidResponse(self):
222
    lookup_map = {"oqo7lanhly.example.com": "oqo7lanhly.example.com"}
223
    fn = self._FakeAddressLookup(lookup_map)
224
    client = rpc.Client("version", None, 19978, address_lookup_fn=fn)
225
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
226
      client.ConnectNode("oqo7lanhly.example.com")
227
      pool = FakeHttpPool(fn)
228
      result = client.GetResults(http_pool=pool)
229
      self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
230
      lhresp = result["oqo7lanhly.example.com"]
231
      self.assertFalse(lhresp.offline)
232
      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
233
      self.assert_(lhresp.fail_msg)
234
      self.assertFalse(lhresp.payload)
235
      self.assertEqual(lhresp.call, "version")
236
      self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
237
      self.assertEqual(pool.reqcount, 1)
238

    
239
  def testAddressLookupSimpleStore(self):
240
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
241
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
242
    node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)]
243
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
244
    result = rpc._AddressLookup(node_list, ssc=ssc)
245
    self.assertEqual(result, addr_list)
246

    
247
  def testAddressLookupNSLookup(self):
248
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
249
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
250
    ssc = GetFakeSimpleStoreClass(lambda _: [])
251
    node_addr_map = dict(zip(node_list, addr_list))
252
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
253
    result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
254
    self.assertEqual(result, addr_list)
255

    
256
  def testAddressLookupBoth(self):
257
    addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
258
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
259
    n = len(addr_list) / 2
260
    node_addr_list = [ " ".join(t) for t in zip(node_list[n:], addr_list[n:])]
261
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
262
    node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
263
    nslookup_fn = lambda name, family=None: node_addr_map.get(name)
264
    result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
265
    self.assertEqual(result, addr_list)
266

    
267
  def testAddressLookupIPv6(self):
268
    addr_list = ["2001:db8::%d" % n for n in range(0, 255, 13)]
269
    node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
270
    node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)]
271
    ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
272
    result = rpc._AddressLookup(node_list, ssc=ssc)
273
    self.assertEqual(result, addr_list)
274

    
275

    
276
if __name__ == "__main__":
277
  testutils.GanetiTestProgram()