Change internal RPC client body values
[ganeti-local] / lib / rpc.py
index 2f65ad5..bf05cfb 100644 (file)
@@ -338,10 +338,19 @@ class _RpcProcessor:
   def _PrepareRequests(hosts, port, procedure, body, read_timeout):
     """Prepares requests by sorting offline hosts into separate list.
 
+    @type body: dict
+    @param body: a dictionary with per-host body data
+
     """
     results = {}
     requests = {}
 
+    assert isinstance(body, dict)
+    assert len(body) == len(hosts)
+    assert compat.all(isinstance(v, str) for v in body.values())
+    assert frozenset(map(compat.fst, hosts)) == frozenset(body.keys()), \
+        "%s != %s" % (hosts, body.keys())
+
     for (name, ip) in hosts:
       if ip is _OFFLINE:
         # Node is marked as offline
@@ -351,7 +360,7 @@ class _RpcProcessor:
           http.client.HttpClientRequest(str(ip), port,
                                         http.HTTP_PUT, str("/%s" % procedure),
                                         headers=_RPC_CLIENT_HEADERS,
-                                        post_data=body,
+                                        post_data=body[name],
                                         read_timeout=read_timeout,
                                         nicename="%s/%s" % (name, procedure),
                                         curl_config_fn=_ConfigRpcCurl)
@@ -390,8 +399,8 @@ class _RpcProcessor:
     @param hosts: Hostnames
     @type procedure: string
     @param procedure: Request path
-    @type body: string
-    @param body: Request body
+    @type body: dictionary
+    @param body: dictionary with request bodies per host
     @type read_timeout: int or None
     @param read_timeout: Read timeout for request
 
@@ -401,7 +410,7 @@ class _RpcProcessor:
 
     (results, requests) = \
       self._PrepareRequests(self._resolver(hosts), self._port, procedure,
-                            str(body), read_timeout)
+                            body, read_timeout)
 
     _req_process_fn(requests.values(), lock_monitor_cb=self._lock_monitor_cb)
 
@@ -434,17 +443,28 @@ class _RpcClientBase:
     """Entry point for automatically generated RPC wrappers.
 
     """
-    (procedure, _, timeout, argdefs, _, postproc_fn, _) = cdef
+    (procedure, _, timeout, argdefs, prep_fn, postproc_fn, _) = cdef
 
     if callable(timeout):
       read_timeout = timeout(args)
     else:
       read_timeout = timeout
 
-    body = serializer.DumpJson(map(self._encoder,
-                                   zip(map(compat.snd, argdefs), args)))
-
-    result = self._proc(node_list, procedure, body, read_timeout=read_timeout)
+    enc_args = map(self._encoder, zip(map(compat.snd, argdefs), args))
+    if prep_fn is None:
+      # for a no-op prep_fn, we serialise the body once, and then we
+      # reuse it in the dictionary values
+      body = serializer.DumpJson(enc_args)
+      pnbody = dict((n, body) for n in node_list)
+    else:
+      # for a custom prep_fn, we pass the encoded arguments and the
+      # node name to the prep_fn, and we serialise its return value
+      assert(callable(prep_fn))
+      pnbody = dict((n, serializer.DumpJson(prep_fn(n, enc_args)))
+                    for n in node_list)
+
+    result = self._proc(node_list, procedure, pnbody,
+                        read_timeout=read_timeout)
 
     if postproc_fn:
       return dict(map(lambda (key, value): (key, postproc_fn(value)),