RPC: Add a new client type for DNS only
[ganeti-local] / test / ganeti.http_unittest.py
index 7bffe6f..9cbbf13 100755 (executable)
@@ -377,6 +377,22 @@ class TestClientRequest(unittest.TestCase):
     cr = http.client.HttpClientRequest("localhost", 1234, "GET", "/version")
     self.assertEqual(cr.post_data, "")
 
+  def testCompletionCallback(self):  # checks completion_cb and curl_config_fn
+    for argname in ["completion_cb", "curl_config_fn"]:
+      kwargs = {
+        argname: NotImplementedError,  # a class, thus callable
+        }
+      cr = http.client.HttpClientRequest("localhost", 14038, "GET", "/version",
+                                         **kwargs)
+      self.assertEqual(getattr(cr, argname), NotImplementedError)
+
+      for fn in [NotImplemented, {}, 1]:  # none of these is callable
+        kwargs = {
+          argname: fn,
+          }
+        self.assertRaises(AssertionError, http.client.HttpClientRequest,
+                          "localhost", 23150, "GET", "/version", **kwargs)
+
 
 class _FakeCurl:
   def __init__(self):
@@ -619,14 +635,24 @@ class TestProcessRequests(unittest.TestCase):
     def cfg_fn(port, curl):
       curl.opts["__port__"] = port
 
-    def _LockCheckReset(monitor, curl):
+    def _LockCheckReset(monitor, req):  # used as completion_cb
       self.assertTrue(monitor._lock.is_owned(shared=0),
                       msg="Lock must be owned in exclusive mode")
-      curl.opts["__lockcheck__"] = True
+      assert not hasattr(req, "lockcheck__")  # at most one call per request
+      setattr(req, "lockcheck__", True)  # checked once all handles are done
+
+    def _BuildNiceName(port, default=None):  # nicename only for multiples of 5
+      if port % 5 == 0:
+        return "nicename%s" % port
+      else:
+        # No nicename: return default (None keeps the standard name)
+        return default
 
     requests = \
       [http.client.HttpClientRequest("localhost", i, "POST", "/version%s" % i,
-                                     curl_config_fn=compat.partial(cfg_fn, i))
+                                     curl_config_fn=compat.partial(cfg_fn, i),
+                                     completion_cb=NotImplementedError,
+                                     nicename=_BuildNiceName(i))
        for i in range(15176, 15501)]
     requests_count = len(requests)
 
@@ -641,14 +667,27 @@ class TestProcessRequests(unittest.TestCase):
       self.assertTrue(compat.all(isinstance(curl, _FakeCurl)
                                  for curl in handles))
 
+      # Prepare for lock check
+      for req in requests:
+        assert req.completion_cb is NotImplementedError
+        if use_monitor:
+          req.completion_cb = \
+            compat.partial(_LockCheckReset, lock_monitor_cb.GetMonitor())
+
       for idx, curl in enumerate(handles):
-        port = curl.opts["__port__"]
+        try:
+          port = curl.opts["__port__"]
+        except KeyError:
+          self.fail("Per-request config function was not called")
 
         if use_monitor:
           # Check if lock information is correct
           lock_info = lock_monitor_cb.GetMonitor().GetLockInfo(None)
           expected = \
-            [("rpc/localhost/version%s" % handle.opts["__port__"], None,
+            [("rpc/%s" % (_BuildNiceName(handle.opts["__port__"],
+                                         default=("localhost/version%s" %
+                                                  handle.opts["__port__"]))),
+              None,
               [threading.currentThread().getName()], None)
              for handle in handles[idx:]]
           self.assertEqual(sorted(lock_info), sorted(expected))
@@ -664,21 +703,17 @@ class TestProcessRequests(unittest.TestCase):
           pycurl.RESPONSE_CODE: response_code,
           }
 
-        # Unset options which will be reset
-        assert not hasattr(curl, "reset")
-        if use_monitor:
-          setattr(curl, "reset",
-                  compat.partial(_LockCheckReset, lock_monitor_cb.GetMonitor(),
-                                 curl))
-        else:
-          self.assertFalse(curl.opts.pop(pycurl.POSTFIELDS))
-          self.assertTrue(callable(curl.opts.pop(pycurl.WRITEFUNCTION)))
+        # Prepare for reset
+        self.assertFalse(curl.opts.pop(pycurl.POSTFIELDS))
+        self.assertTrue(callable(curl.opts.pop(pycurl.WRITEFUNCTION)))
 
         yield (curl, msg)
 
       if use_monitor:
-        self.assertTrue(compat.all(curl.opts["__lockcheck__"]
-                                   for curl in handles))
+        self.assertTrue(compat.all(req.lockcheck__ for req in requests))
+
+    if use_monitor:
+      self.assertEqual(lock_monitor_cb.GetMonitor(), None)
 
     http.client.ProcessRequests(requests, lock_monitor_cb=lock_monitor_cb,
                                 _curl=_FakeCurl,