Add AsyncUDPSocket tests
author Guido Trotter <ultrotter@google.com>
Wed, 19 May 2010 15:39:41 +0000 (16:39 +0100)
committer Guido Trotter <ultrotter@google.com>
Thu, 20 May 2010 14:52:53 +0000 (15:52 +0100)
Signed-off-by: Guido Trotter <ultrotter@google.com>
Reviewed-by: Michael Hanselmann <hansmi@google.com>

test/ganeti.daemon_unittest.py

index 374d1c3..e9f10bd 100755 (executable)
@@ -24,6 +24,7 @@
 import unittest
 import signal
 import os
+import socket
 
 from ganeti import daemon
+from ganeti import errors
 
@@ -94,5 +95,83 @@ class TestMainloop(testutils.GanetiTestCase):
     self.assertEquals(self.onsignal_events, self.sendsig_events)
 
 
+class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
+
+  def __init__(self):
+    daemon.AsyncUDPSocket.__init__(self)
+    self.received = []
+    self.error_count = 0
+
+  def handle_datagram(self, payload, ip, port):
+    self.received.append((payload))
+    if payload == "terminate":
+      os.kill(os.getpid(), signal.SIGTERM)
+    elif payload == "error":
+      raise errors.GenericError("error")
+
+  def handle_error(self):
+    self.error_count += 1
+
+
+class TestAsyncUDPSocket(testutils.GanetiTestCase):
+  """Test daemon.AsyncUDPSocket"""
+
+  def setUp(self):
+    testutils.GanetiTestCase.setUp(self)
+    self.mainloop = daemon.Mainloop()
+    self.server = _MyAsyncUDPSocket()
+    self.client = _MyAsyncUDPSocket()
+    self.server.bind(("127.0.0.1", 0))
+    self.port = self.server.getsockname()[1]
+
+  def tearDown(self):
+    self.server.close()
+    self.client.close()
+    testutils.GanetiTestCase.tearDown(self)
+
+  def testNoDoubleBind(self):
+    self.assertRaises(socket.error, self.client.bind, ("127.0.0.1", self.port))
+
+  def _ThreadedClient(self, payload):
+    self.client.enqueue_send("127.0.0.1", self.port, payload)
+    print "sending %s" % payload
+    while self.client.writable():
+      self.client.handle_write()
+
+  def testAsyncClientServer(self):
+    self.client.enqueue_send("127.0.0.1", self.port, "p1")
+    self.client.enqueue_send("127.0.0.1", self.port, "p2")
+    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
+    self.mainloop.Run()
+    self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
+
+  def testSyncClientServer(self):
+    self.client.enqueue_send("127.0.0.1", self.port, "p1")
+    self.client.enqueue_send("127.0.0.1", self.port, "p2")
+    while self.client.writable():
+      self.client.handle_write()
+    self.server.process_next_packet()
+    self.assertEquals(self.server.received, ["p1"])
+    self.server.process_next_packet()
+    self.assertEquals(self.server.received, ["p1", "p2"])
+    self.client.enqueue_send("127.0.0.1", self.port, "p3")
+    while self.client.writable():
+      self.client.handle_write()
+    self.server.process_next_packet()
+    self.assertEquals(self.server.received, ["p1", "p2", "p3"])
+
+  def testErrorHandling(self):
+    self.client.enqueue_send("127.0.0.1", self.port, "p1")
+    self.client.enqueue_send("127.0.0.1", self.port, "p2")
+    self.client.enqueue_send("127.0.0.1", self.port, "error")
+    self.client.enqueue_send("127.0.0.1", self.port, "p3")
+    self.client.enqueue_send("127.0.0.1", self.port, "error")
+    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
+    self.mainloop.Run()
+    self.assertEquals(self.server.received,
+                      ["p1", "p2", "error", "p3", "error", "terminate"])
+    self.assertEquals(self.server.error_count, 2)
+
+
+# Run this module's tests through testutils.GanetiTestProgram
 if __name__ == "__main__":
   testutils.GanetiTestProgram()