rpc: Pass resolver options to actual resolver
author Michael Hanselmann <hansmi@google.com>
Wed, 4 Jan 2012 19:33:55 +0000 (20:33 +0100)
committer Michael Hanselmann <hansmi@google.com>
Thu, 5 Jan 2012 15:30:34 +0000 (16:30 +0100)
Signed-off-by: Michael Hanselmann <hansmi@google.com>
Reviewed-by: Iustin Pop <iustin@google.com>

lib/rpc.py
test/ganeti.rpc_unittest.py

index d1a21e2..cb0ca78 100644 (file)
@@ -240,7 +240,7 @@ class RpcResult(object):
     raise ec(*args) # pylint: disable=W0142
 
 
-def _SsconfResolver(node_list,
+def _SsconfResolver(node_list, _,
                     ssc=ssconf.SimpleStore,
                     nslookup_fn=netutils.Hostname.GetIP):
   """Return addresses for given node names.
@@ -277,7 +277,7 @@ class _StaticResolver:
     """
     self._addresses = addresses
 
-  def __call__(self, hosts):
+  def __call__(self, hosts, _):
     """Returns static addresses for hosts.
 
     """
@@ -304,7 +304,7 @@ def _CheckConfigNode(name, node):
   return (name, ip)
 
 
-def _NodeConfigResolver(single_node_fn, all_nodes_fn, hosts):
+def _NodeConfigResolver(single_node_fn, all_nodes_fn, hosts, _):
   """Calculate node addresses using configuration.
 
   """
@@ -391,7 +391,7 @@ class _RpcProcessor:
 
     return results
 
-  def __call__(self, hosts, procedure, body, read_timeout,
+  def __call__(self, hosts, procedure, body, read_timeout, resolver_opts,
                _req_process_fn=http.client.ProcessRequests):
     """Makes an RPC request to a number of nodes.
 
@@ -409,8 +409,8 @@ class _RpcProcessor:
       "Missing RPC read timeout for procedure '%s'" % procedure
 
     (results, requests) = \
-      self._PrepareRequests(self._resolver(hosts), self._port, procedure,
-                            body, read_timeout)
+      self._PrepareRequests(self._resolver(hosts, resolver_opts), self._port,
+                            procedure, body, read_timeout)
 
     _req_process_fn(requests.values(), lock_monitor_cb=self._lock_monitor_cb)
 
@@ -469,7 +469,8 @@ class _RpcClientBase:
       pnbody = dict((n, serializer.DumpJson(prep_fn(n, enc_args)))
                     for n in node_list)
 
-    result = self._proc(node_list, procedure, pnbody, read_timeout)
+    result = self._proc(node_list, procedure, pnbody, read_timeout,
+                        req_resolver_opts)
 
     if postproc_fn:
       return dict(map(lambda (key, value): (key, postproc_fn(value)),
index 7eb529b..633d89a 100755 (executable)
@@ -74,7 +74,7 @@ class TestRpcProcessor(unittest.TestCase):
     http_proc = _FakeRequestProcessor(self._GetVersionResponse)
     proc = rpc._RpcProcessor(resolver, 24094)
     result = proc(["localhost"], "version", {"localhost": ""}, 60,
-                  _req_process_fn=http_proc)
+                  NotImplemented, _req_process_fn=http_proc)
     self.assertEqual(result.keys(), ["localhost"])
     lhresp = result["localhost"]
     self.assertFalse(lhresp.offline)
@@ -100,7 +100,8 @@ class TestRpcProcessor(unittest.TestCase):
     proc = rpc._RpcProcessor(resolver, 19176)
     host = "node31856"
     body = {host: ""}
-    result = proc([host], "version", body, 12356, _req_process_fn=http_proc)
+    result = proc([host], "version", body, 12356, NotImplemented,
+                  _req_process_fn=http_proc)
     self.assertEqual(result.keys(), [host])
     lhresp = result[host]
     self.assertFalse(lhresp.offline)
@@ -117,7 +118,8 @@ class TestRpcProcessor(unittest.TestCase):
     proc = rpc._RpcProcessor(resolver, 30668)
     host = "n17296"
     body = {host: ""}
-    result = proc([host], "version", body, 60, _req_process_fn=http_proc)
+    result = proc([host], "version", body, 60, NotImplemented,
+                  _req_process_fn=http_proc)
     self.assertEqual(result.keys(), [host])
     lhresp = result[host]
     self.assertTrue(lhresp.offline)
@@ -148,7 +150,8 @@ class TestRpcProcessor(unittest.TestCase):
     resolver = rpc._StaticResolver(nodes)
     http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
     proc = rpc._RpcProcessor(resolver, 23245)
-    result = proc(nodes, "version", body, 60, _req_process_fn=http_proc,)
+    result = proc(nodes, "version", body, 60, NotImplemented,
+                  _req_process_fn=http_proc)
     self.assertEqual(sorted(result.keys()), sorted(nodes))
 
     for name in nodes:
@@ -177,7 +180,7 @@ class TestRpcProcessor(unittest.TestCase):
                                              errinfo))
       host = "aef9ur4i.example.com"
       body = {host: ""}
-      result = proc(body.keys(), "version", body, 60,
+      result = proc(body.keys(), "version", body, 60, NotImplemented,
                     _req_process_fn=http_proc)
       self.assertEqual(result.keys(), [host])
       lhresp = result[host]
@@ -227,7 +230,7 @@ class TestRpcProcessor(unittest.TestCase):
     http_proc = \
       _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
                                            httperrnodes, failnodes))
-    result = proc(nodes, "vg_list", body, rpc._TMO_URGENT,
+    result = proc(nodes, "vg_list", body, rpc._TMO_URGENT, NotImplemented,
                   _req_process_fn=http_proc)
     self.assertEqual(sorted(result.keys()), sorted(nodes))
 
@@ -272,7 +275,7 @@ class TestRpcProcessor(unittest.TestCase):
       http_proc = _FakeRequestProcessor(fn)
       host = "oqo7lanhly.example.com"
       body = {host: ""}
-      result = proc([host], "version", body, 60,
+      result = proc([host], "version", body, 60, NotImplemented,
                     _req_process_fn=http_proc)
       self.assertEqual(result.keys(), [host])
       lhresp = result[host]
@@ -304,7 +307,8 @@ class TestRpcProcessor(unittest.TestCase):
     proc = rpc._RpcProcessor(resolver, 18700)
     host = "node19759"
     body = {host: serializer.DumpJson(test_data)}
-    result = proc([host], "upload_file", body, 30, _req_process_fn=http_proc)
+    result = proc([host], "upload_file", body, 30, NotImplemented,
+                  _req_process_fn=http_proc)
     self.assertEqual(result.keys(), [host])
     lhresp = result[host]
     self.assertFalse(lhresp.offline)
@@ -322,7 +326,8 @@ class TestSsconfResolver(unittest.TestCase):
     node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
     node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
     ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
-    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
+    result = rpc._SsconfResolver(node_list, NotImplemented,
+                                 ssc=ssc, nslookup_fn=NotImplemented)
     self.assertEqual(result, zip(node_list, addr_list))
 
   def testNsLookup(self):
@@ -331,7 +336,8 @@ class TestSsconfResolver(unittest.TestCase):
     ssc = GetFakeSimpleStoreClass(lambda _: [])
     node_addr_map = dict(zip(node_list, addr_list))
     nslookup_fn = lambda name, family=None: node_addr_map.get(name)
-    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
+    result = rpc._SsconfResolver(node_list, NotImplemented,
+                                 ssc=ssc, nslookup_fn=nslookup_fn)
     self.assertEqual(result, zip(node_list, addr_list))
 
   def testBothLookups(self):
@@ -342,7 +348,8 @@ class TestSsconfResolver(unittest.TestCase):
     ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
     node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
     nslookup_fn = lambda name, family=None: node_addr_map.get(name)
-    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
+    result = rpc._SsconfResolver(node_list, NotImplemented,
+                                 ssc=ssc, nslookup_fn=nslookup_fn)
     self.assertEqual(result, zip(node_list, addr_list))
 
   def testAddressLookupIPv6(self):
@@ -350,7 +357,8 @@ class TestSsconfResolver(unittest.TestCase):
     node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
     node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
     ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
-    result = rpc._SsconfResolver(node_list, ssc=ssc, nslookup_fn=NotImplemented)
+    result = rpc._SsconfResolver(node_list, NotImplemented,
+                                 ssc=ssc, nslookup_fn=NotImplemented)
     self.assertEqual(result, zip(node_list, addr_list))
 
 
@@ -359,11 +367,11 @@ class TestStaticResolver(unittest.TestCase):
     addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
     nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
     res = rpc._StaticResolver(addresses)
-    self.assertEqual(res(nodes), zip(nodes, addresses))
+    self.assertEqual(res(nodes, NotImplemented), zip(nodes, addresses))
 
   def testWrongLength(self):
     res = rpc._StaticResolver([])
-    self.assertRaises(AssertionError, res, ["abc"])
+    self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
 
 
 class TestNodeConfigResolver(unittest.TestCase):
@@ -380,24 +388,24 @@ class TestNodeConfigResolver(unittest.TestCase):
   def testSingleOnline(self):
     self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
                                              NotImplemented,
-                                             ["node90.example.com"]),
+                                             ["node90.example.com"], None),
                      [("node90.example.com", "192.0.2.90")])
 
   def testSingleOffline(self):
     self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
                                              NotImplemented,
-                                             ["node100.example.com"]),
+                                             ["node100.example.com"], None),
                      [("node100.example.com", rpc._OFFLINE)])
 
   def testUnknownSingleNode(self):
     self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
-                                             ["node110.example.com"]),
+                                             ["node110.example.com"], None),
                      [("node110.example.com", "node110.example.com")])
 
   def testMultiEmpty(self):
     self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
                                              lambda: {},
-                                             []),
+                                             [], None),
                      [])
 
   def testMultiSomeOffline(self):
@@ -410,7 +418,7 @@ class TestNodeConfigResolver(unittest.TestCase):
     # Resolve no names
     self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
                                              lambda: nodes,
-                                             []),
+                                             [], None),
                      [])
 
     # Offline, online and unknown hosts
@@ -419,7 +427,8 @@ class TestNodeConfigResolver(unittest.TestCase):
                                              ["node3.example.com",
                                               "node92.example.com",
                                               "node54.example.com",
-                                              "unknown.example.com",]), [
+                                              "unknown.example.com",],
+                                             None), [
       ("node3.example.com", rpc._OFFLINE),
       ("node92.example.com", "192.0.2.92"),
       ("node54.example.com", rpc._OFFLINE),