import unittest
import signal
import os
+import socket
from ganeti import daemon
self.assertEquals(self.onsignal_events, self.sendsig_events)
class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
  """AsyncUDPSocket test double that records payloads and counts errors.

  Two payloads get special treatment in handle_datagram:
    - "terminate": sends SIGTERM to the current process; the tests that
      call daemon.Mainloop.Run rely on this to make it return.
    - "error": raises an exception, which the tests expect to be routed to
      handle_error.

  """

  def __init__(self):
    daemon.AsyncUDPSocket.__init__(self)
    # Every payload passed to handle_datagram, in arrival order; control
    # payloads are included because they are recorded before being acted on.
    self.received = []
    # Number of times handle_error has been invoked.
    self.error_count = 0

  def handle_datagram(self, payload, ip, port):
    """Record the payload, then act on it if it is a control payload.

    @param payload: the datagram contents
    @param ip: sender address (unused)
    @param port: sender port (unused)

    """
    self.received.append(payload)
    if payload == "error":
      # NOTE(review): relies on `errors` (ganeti.errors) being imported at
      # module level, which is not visible in this chunk -- confirm. If it
      # is missing, a NameError is raised here instead, which would still
      # reach handle_error and hide the problem.
      raise errors.GenericError("error")
    if payload == "terminate":
      os.kill(os.getpid(), signal.SIGTERM)

  def handle_error(self):
    """Count each invocation; overrides the base class error handler."""
    self.error_count += 1
+
+
class TestAsyncUDPSocket(testutils.GanetiTestCase):
  """Test daemon.AsyncUDPSocket"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.server = _MyAsyncUDPSocket()
    self.client = _MyAsyncUDPSocket()
    # Bind to port 0 so the OS picks a free port, then read it back so the
    # client knows where to send.
    self.server.bind(("127.0.0.1", 0))
    self.port = self.server.getsockname()[1]

  def tearDown(self):
    # Nested try/finally: a failure closing one socket must neither leak the
    # other one nor skip the base class cleanup.
    try:
      try:
        self.server.close()
      finally:
        self.client.close()
    finally:
      testutils.GanetiTestCase.tearDown(self)

  def _Enqueue(self, *payloads):
    """Queue each payload, in order, for sending from client to server."""
    for payload in payloads:
      self.client.enqueue_send("127.0.0.1", self.port, payload)

  def _FlushClient(self):
    """Write out everything queued on the client without the mainloop."""
    while self.client.writable():
      self.client.handle_write()

  def _ThreadedClient(self, payload):
    """Send a single payload synchronously (no thread is involved)."""
    self._Enqueue(payload)
    self._FlushClient()

  def testNoDoubleBind(self):
    # The server already holds the port, so a second bind must fail.
    self.assertRaises(socket.error, self.client.bind,
                      ("127.0.0.1", self.port))

  def testAsyncClientServer(self):
    self._Enqueue("p1", "p2", "terminate")
    # The server SIGTERMs this process on "terminate", which is what lets
    # Run() return here.
    self.mainloop.Run()
    self.assertEqual(self.server.received, ["p1", "p2", "terminate"])

  def testSyncClientServer(self):
    self._Enqueue("p1", "p2")
    self._FlushClient()
    # Each process_next_packet call consumes exactly one pending datagram.
    self.server.process_next_packet()
    self.assertEqual(self.server.received, ["p1"])
    self.server.process_next_packet()
    self.assertEqual(self.server.received, ["p1", "p2"])
    self._Enqueue("p3")
    self._FlushClient()
    self.server.process_next_packet()
    self.assertEqual(self.server.received, ["p1", "p2", "p3"])

  def testErrorHandling(self):
    self._Enqueue("p1", "p2", "error", "p3", "error", "terminate")
    self.mainloop.Run()
    # An exception while handling one datagram must not stop processing of
    # the following ones, and each failure must reach handle_error once.
    self.assertEqual(self.server.received,
                     ["p1", "p2", "error", "p3", "error", "terminate"])
    self.assertEqual(self.server.error_count, 2)
+
+
if __name__ == "__main__":
  # Script entry point: hand control to Ganeti's test program wrapper
  # (presumably a unittest.main equivalent -- see testutils).
  testutils.GanetiTestProgram()