rpc._RpcClientBase: Add check for number of arguments
[ganeti-local] / lib / rpc.py
index c117f0c..f94c701 100644 (file)
@@ -240,7 +240,7 @@ class RpcResult(object):
     raise ec(*args) # pylint: disable=W0142
 
 
-def _SsconfResolver(node_list,
+def _SsconfResolver(node_list, _,
                     ssc=ssconf.SimpleStore,
                     nslookup_fn=netutils.Hostname.GetIP):
   """Return addresses for given node names.
@@ -277,7 +277,7 @@ class _StaticResolver:
     """
     self._addresses = addresses
 
-  def __call__(self, hosts):
+  def __call__(self, hosts, _):
     """Returns static addresses for hosts.
 
     """
@@ -285,7 +285,7 @@ class _StaticResolver:
     return zip(hosts, self._addresses)
 
 
-def _CheckConfigNode(name, node):
+def _CheckConfigNode(name, node, accept_offline_node):
   """Checks if a node is online.
 
   @type name: string
@@ -297,24 +297,29 @@ def _CheckConfigNode(name, node):
   if node is None:
     # Depend on DNS for name resolution
     ip = name
-  elif node.offline:
+  elif node.offline and not accept_offline_node:
     ip = _OFFLINE
   else:
     ip = node.primary_ip
   return (name, ip)
 
 
-def _NodeConfigResolver(single_node_fn, all_nodes_fn, hosts):
+def _NodeConfigResolver(single_node_fn, all_nodes_fn, hosts, opts):
   """Calculate node addresses using configuration.
 
   """
+  accept_offline_node = (opts is rpc_defs.ACCEPT_OFFLINE_NODE)
+
+  assert accept_offline_node or opts is None, "Unknown option"
+
   # Special case for single-host lookups
   if len(hosts) == 1:
     (name, ) = hosts
-    return [_CheckConfigNode(name, single_node_fn(name))]
+    return [_CheckConfigNode(name, single_node_fn(name), accept_offline_node)]
   else:
     all_nodes = all_nodes_fn()
-    return [_CheckConfigNode(name, all_nodes.get(name, None))
+    return [_CheckConfigNode(name, all_nodes.get(name, None),
+                             accept_offline_node)
             for name in hosts]
 
 
@@ -338,10 +343,19 @@ class _RpcProcessor:
   def _PrepareRequests(hosts, port, procedure, body, read_timeout):
     """Prepares requests by sorting offline hosts into separate list.
 
+    @type body: dict
+    @param body: a dictionary with per-host body data
+
     """
     results = {}
     requests = {}
 
+    assert isinstance(body, dict)
+    assert len(body) == len(hosts)
+    assert compat.all(isinstance(v, str) for v in body.values())
+    assert frozenset(map(compat.fst, hosts)) == frozenset(body.keys()), \
+        "%s != %s" % (hosts, body.keys())
+
     for (name, ip) in hosts:
       if ip is _OFFLINE:
         # Node is marked as offline
@@ -351,7 +365,7 @@ class _RpcProcessor:
           http.client.HttpClientRequest(str(ip), port,
                                         http.HTTP_PUT, str("/%s" % procedure),
                                         headers=_RPC_CLIENT_HEADERS,
-                                        post_data=body,
+                                        post_data=body[name],
                                         read_timeout=read_timeout,
                                         nicename="%s/%s" % (name, procedure),
                                         curl_config_fn=_ConfigRpcCurl)
@@ -382,7 +396,7 @@ class _RpcProcessor:
 
     return results
 
-  def __call__(self, hosts, procedure, body, read_timeout=None,
+  def __call__(self, hosts, procedure, body, read_timeout, resolver_opts,
                _req_process_fn=http.client.ProcessRequests):
     """Makes an RPC request to a number of nodes.
 
@@ -390,8 +404,8 @@ class _RpcProcessor:
     @param hosts: Hostnames
     @type procedure: string
     @param procedure: Request path
-    @type body: string
-    @param body: Request body
+    @type body: dictionary
+    @param body: dictionary with request bodies per host
     @type read_timeout: int or None
     @param read_timeout: Read timeout for request
 
@@ -400,8 +414,8 @@ class _RpcProcessor:
       "Missing RPC read timeout for procedure '%s'" % procedure
 
     (results, requests) = \
-      self._PrepareRequests(self._resolver(hosts), self._port, procedure,
-                            str(body), read_timeout)
+      self._PrepareRequests(self._resolver(hosts, resolver_opts), self._port,
+                            procedure, body, read_timeout)
 
     _req_process_fn(requests.values(), lock_monitor_cb=self._lock_monitor_cb)
 
@@ -430,16 +444,47 @@ class _RpcClientBase:
     else:
       return encoder_fn(argkind)(value)
 
-  def _Call(self, node_list, procedure, timeout, argdefs, args):
+  def _Call(self, cdef, node_list, args):
     """Entry point for automatically generated RPC wrappers.
 
     """
-    assert len(args) == len(argdefs), "Wrong number of arguments"
+    (procedure, _, resolver_opts, timeout, argdefs,
+     prep_fn, postproc_fn, _) = cdef
 
-    body = serializer.DumpJson(map(self._encoder, zip(argdefs, args)),
-                               indent=False)
+    if callable(timeout):
+      read_timeout = timeout(args)
+    else:
+      read_timeout = timeout
 
-    return self._proc(node_list, procedure, body, read_timeout=timeout)
+    if callable(resolver_opts):
+      req_resolver_opts = resolver_opts(args)
+    else:
+      req_resolver_opts = resolver_opts
+
+    if len(args) != len(argdefs):
+      raise errors.ProgrammerError("Number of passed arguments doesn't match")
+
+    enc_args = map(self._encoder, zip(map(compat.snd, argdefs), args))
+    if prep_fn is None:
+      # for a no-op prep_fn, we serialise the body once, and then we
+      # reuse it in the dictionary values
+      body = serializer.DumpJson(enc_args)
+      pnbody = dict((n, body) for n in node_list)
+    else:
+      # for a custom prep_fn, we pass the encoded arguments and the
+      # node name to the prep_fn, and we serialise its return value
+      assert callable(prep_fn)
+      pnbody = dict((n, serializer.DumpJson(prep_fn(n, enc_args)))
+                    for n in node_list)
+
+    result = self._proc(node_list, procedure, pnbody, read_timeout,
+                        req_resolver_opts)
+
+    if postproc_fn:
+      return dict(map(lambda (key, value): (key, postproc_fn(value)),
+                      result.items()))
+    else:
+      return result
 
 
 def _ObjectToDict(value):
@@ -614,79 +659,6 @@ class RpcRunner(_RpcClientBase,
     """
     return self._InstDict(instance, osp=osparams)
 
-  @staticmethod
-  def _MigrationStatusPostProc(result):
-    if not result.fail_msg and result.payload is not None:
-      result.payload = objects.MigrationStatus.FromDict(result.payload)
-    return result
-
-  @staticmethod
-  def _BlockdevFindPostProc(result):
-    if not result.fail_msg and result.payload is not None:
-      result.payload = objects.BlockDevStatus.FromDict(result.payload)
-    return result
-
-  @staticmethod
-  def _BlockdevGetMirrorStatusPostProc(result):
-    if not result.fail_msg:
-      result.payload = [objects.BlockDevStatus.FromDict(i)
-                        for i in result.payload]
-    return result
-
-  @staticmethod
-  def _BlockdevGetMirrorStatusMultiPostProc(result):
-    for nres in result.values():
-      if nres.fail_msg:
-        continue
-
-      for idx, (success, status) in enumerate(nres.payload):
-        if success:
-          nres.payload[idx] = (success, objects.BlockDevStatus.FromDict(status))
-
-    return result
-
-  @staticmethod
-  def _OsGetPostProc(result):
-    if not result.fail_msg and isinstance(result.payload, dict):
-      result.payload = objects.OS.FromDict(result.payload)
-    return result
-
-  @staticmethod
-  def _ImpExpStatusPostProc(result):
-    """Post-processor for import/export status.
-
-    @rtype: Payload containing list of L{objects.ImportExportStatus} instances
-    @return: Returns a list of the state of each named import/export or None if
-             a status couldn't be retrieved
-
-    """
-    if not result.fail_msg:
-      decoded = []
-
-      for i in result.payload:
-        if i is None:
-          decoded.append(None)
-          continue
-        decoded.append(objects.ImportExportStatus.FromDict(i))
-
-      result.payload = decoded
-
-    return result
-
-  #
-  # Begin RPC calls
-  #
-
-  def call_test_delay(self, node_list, duration): # pylint: disable=W0221
-    """Sleep for a fixed time on given node(s).
-
-    This is a multi-node call.
-
-    """
-    # TODO: Use callable timeout calculation
-    return _generated_rpc.RpcClientDefault.call_test_delay(self,
-      node_list, duration, read_timeout=int(duration + 5))
-
 
 class JobQueueRunner(_RpcClientBase, _generated_rpc.RpcClientJobQueue):
   """RPC wrappers for job queue.