Revision d9de612c

b/lib/rpc.py
338 338
  def _PrepareRequests(hosts, port, procedure, body, read_timeout):
339 339
    """Prepares requests by sorting offline hosts into separate list.
340 340

  
341
    @type body: dict
342
    @param body: a dictionary with per-host body data
343

  
341 344
    """
342 345
    results = {}
343 346
    requests = {}
344 347

  
348
    assert isinstance(body, dict)
349
    assert len(body) == len(hosts)
350
    assert compat.all(isinstance(v, str) for v in body.values())
351
    assert frozenset(map(compat.fst, hosts)) == frozenset(body.keys()), \
352
        "%s != %s" % (hosts, body.keys())
353

  
345 354
    for (name, ip) in hosts:
346 355
      if ip is _OFFLINE:
347 356
        # Node is marked as offline
......
351 360
          http.client.HttpClientRequest(str(ip), port,
352 361
                                        http.HTTP_PUT, str("/%s" % procedure),
353 362
                                        headers=_RPC_CLIENT_HEADERS,
354
                                        post_data=body,
363
                                        post_data=body[name],
355 364
                                        read_timeout=read_timeout,
356 365
                                        nicename="%s/%s" % (name, procedure),
357 366
                                        curl_config_fn=_ConfigRpcCurl)
......
390 399
    @param hosts: Hostnames
391 400
    @type procedure: string
392 401
    @param procedure: Request path
393
    @type body: string
394
    @param body: Request body
402
    @type body: dictionary
403
    @param body: dictionary with request bodies per host
395 404
    @type read_timeout: int or None
396 405
    @param read_timeout: Read timeout for request
397 406

  
......
401 410

  
402 411
    (results, requests) = \
403 412
      self._PrepareRequests(self._resolver(hosts), self._port, procedure,
404
                            str(body), read_timeout)
413
                            body, read_timeout)
405 414

  
406 415
    _req_process_fn(requests.values(), lock_monitor_cb=self._lock_monitor_cb)
407 416

  
......
434 443
    """Entry point for automatically generated RPC wrappers.
435 444

  
436 445
    """
437
    (procedure, _, timeout, argdefs, _, postproc_fn, _) = cdef
446
    (procedure, _, timeout, argdefs, prep_fn, postproc_fn, _) = cdef
438 447

  
439 448
    if callable(timeout):
440 449
      read_timeout = timeout(args)
441 450
    else:
442 451
      read_timeout = timeout
443 452

  
444
    body = serializer.DumpJson(map(self._encoder,
445
                                   zip(map(compat.snd, argdefs), args)))
446

  
447
    result = self._proc(node_list, procedure, body, read_timeout=read_timeout)
453
    enc_args = map(self._encoder, zip(map(compat.snd, argdefs), args))
454
    if prep_fn is None:
455
      # for a no-op prep_fn, we serialise the body once, and then we
456
      # reuse it in the dictionary values
457
      body = serializer.DumpJson(enc_args)
458
      pnbody = dict((n, body) for n in node_list)
459
    else:
460
      # for a custom prep_fn, we pass the encoded arguments and the
461
      # node name to the prep_fn, and we serialise its return value
462
      assert(callable(prep_fn))
463
      pnbody = dict((n, serializer.DumpJson(prep_fn(n, enc_args)))
464
                    for n in node_list)
465

  
466
    result = self._proc(node_list, procedure, pnbody,
467
                        read_timeout=read_timeout)
448 468

  
449 469
    if postproc_fn:
450 470
      return dict(map(lambda (key, value): (key, postproc_fn(value)),
b/test/ganeti.rpc_unittest.py
1 1
#!/usr/bin/python
2 2
#
3 3

  
4
# Copyright (C) 2010 Google Inc.
4
# Copyright (C) 2010, 2011 Google Inc.
5 5
#
6 6
# This program is free software; you can redistribute it and/or modify
7 7
# it under the terms of the GNU General Public License as published by
......
73 73
    resolver = rpc._StaticResolver(["127.0.0.1"])
74 74
    http_proc = _FakeRequestProcessor(self._GetVersionResponse)
75 75
    proc = rpc._RpcProcessor(resolver, 24094)
76
    result = proc(["localhost"], "version", None, _req_process_fn=http_proc,
77
                  read_timeout=60)
76
    result = proc(["localhost"], "version", {"localhost": ""},
77
                  _req_process_fn=http_proc, read_timeout=60)
78 78
    self.assertEqual(result.keys(), ["localhost"])
79 79
    lhresp = result["localhost"]
80 80
    self.assertFalse(lhresp.offline)
......
98 98
    resolver = rpc._StaticResolver(["192.0.2.13"])
99 99
    http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
100 100
    proc = rpc._RpcProcessor(resolver, 19176)
101
    result = proc(["node31856"], "version", None, _req_process_fn=http_proc,
101
    host = "node31856"
102
    body = {host: ""}
103
    result = proc([host], "version", body, _req_process_fn=http_proc,
102 104
                  read_timeout=12356)
103
    self.assertEqual(result.keys(), ["node31856"])
104
    lhresp = result["node31856"]
105
    self.assertEqual(result.keys(), [host])
106
    lhresp = result[host]
105 107
    self.assertFalse(lhresp.offline)
106
    self.assertEqual(lhresp.node, "node31856")
108
    self.assertEqual(lhresp.node, host)
107 109
    self.assertFalse(lhresp.fail_msg)
108 110
    self.assertEqual(lhresp.payload, -1)
109 111
    self.assertEqual(lhresp.call, "version")
......
114 116
    resolver = rpc._StaticResolver([rpc._OFFLINE])
115 117
    http_proc = _FakeRequestProcessor(NotImplemented)
116 118
    proc = rpc._RpcProcessor(resolver, 30668)
117
    result = proc(["n17296"], "version", None, _req_process_fn=http_proc,
119
    host = "n17296"
120
    body = {host: ""}
121
    result = proc([host], "version", body, _req_process_fn=http_proc,
118 122
                  read_timeout=60)
119
    self.assertEqual(result.keys(), ["n17296"])
120
    lhresp = result["n17296"]
123
    self.assertEqual(result.keys(), [host])
124
    lhresp = result[host]
121 125
    self.assertTrue(lhresp.offline)
122
    self.assertEqual(lhresp.node, "n17296")
126
    self.assertEqual(lhresp.node, host)
123 127
    self.assertTrue(lhresp.fail_msg)
124 128
    self.assertFalse(lhresp.payload)
125 129
    self.assertEqual(lhresp.call, "version")
......
142 146

  
143 147
  def testMultiVersionSuccess(self):
144 148
    nodes = ["node%s" % i for i in range(50)]
149
    body = dict((n, "") for n in nodes)
145 150
    resolver = rpc._StaticResolver(nodes)
146 151
    http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
147 152
    proc = rpc._RpcProcessor(resolver, 23245)
148
    result = proc(nodes, "version", None, _req_process_fn=http_proc,
153
    result = proc(nodes, "version", body, _req_process_fn=http_proc,
149 154
                  read_timeout=60)
150 155
    self.assertEqual(sorted(result.keys()), sorted(nodes))
151 156

  
......
173 178
      http_proc = \
174 179
        _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
175 180
                                             errinfo))
176
      result = proc(["aef9ur4i.example.com"], "version", None,
181
      host = "aef9ur4i.example.com"
182
      body = {host: ""}
183
      result = proc(body.keys(), "version", body,
177 184
                    _req_process_fn=http_proc, read_timeout=60)
178
      self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
179
      lhresp = result["aef9ur4i.example.com"]
185
      self.assertEqual(result.keys(), [host])
186
      lhresp = result[host]
180 187
      self.assertFalse(lhresp.offline)
181
      self.assertEqual(lhresp.node, "aef9ur4i.example.com")
188
      self.assertEqual(lhresp.node, host)
182 189
      self.assert_(lhresp.fail_msg)
183 190
      self.assertFalse(lhresp.payload)
184 191
      self.assertEqual(lhresp.call, "version")
......
208 215

  
209 216
  def testHttpError(self):
210 217
    nodes = ["uaf6pbbv%s" % i for i in range(50)]
218
    body = dict((n, "") for n in nodes)
211 219
    resolver = rpc._StaticResolver(nodes)
212 220

  
213 221
    httperrnodes = set(nodes[1::7])
......
222 230
    http_proc = \
223 231
      _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
224 232
                                           httperrnodes, failnodes))
225
    result = proc(nodes, "vg_list", None, _req_process_fn=http_proc,
233
    result = proc(nodes, "vg_list", body, _req_process_fn=http_proc,
226 234
                  read_timeout=rpc._TMO_URGENT)
227 235
    self.assertEqual(sorted(result.keys()), sorted(nodes))
228 236

  
......
265 273

  
266 274
    for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
267 275
      http_proc = _FakeRequestProcessor(fn)
268
      result = proc(["oqo7lanhly.example.com"], "version", None,
276
      host = "oqo7lanhly.example.com"
277
      body = {host: ""}
278
      result = proc([host], "version", body,
269 279
                    _req_process_fn=http_proc, read_timeout=60)
270
      self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
271
      lhresp = result["oqo7lanhly.example.com"]
280
      self.assertEqual(result.keys(), [host])
281
      lhresp = result[host]
272 282
      self.assertFalse(lhresp.offline)
273
      self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
283
      self.assertEqual(lhresp.node, host)
274 284
      self.assert_(lhresp.fail_msg)
275 285
      self.assertFalse(lhresp.payload)
276 286
      self.assertEqual(lhresp.call, "version")
......
295 305
    http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
296 306
                                                     test_data))
297 307
    proc = rpc._RpcProcessor(resolver, 18700)
298
    body = serializer.DumpJson(test_data)
299
    result = proc(["node19759"], "upload_file", body, _req_process_fn=http_proc,
308
    host = "node19759"
309
    body = {host: serializer.DumpJson(test_data)}
310
    result = proc([host], "upload_file", body, _req_process_fn=http_proc,
300 311
                  read_timeout=30)
301
    self.assertEqual(result.keys(), ["node19759"])
302
    lhresp = result["node19759"]
312
    self.assertEqual(result.keys(), [host])
313
    lhresp = result[host]
303 314
    self.assertFalse(lhresp.offline)
304
    self.assertEqual(lhresp.node, "node19759")
315
    self.assertEqual(lhresp.node, host)
305 316
    self.assertFalse(lhresp.fail_msg)
306 317
    self.assertEqual(lhresp.payload, None)
307 318
    self.assertEqual(lhresp.call, "upload_file")

Also available in: Unified diff