4 # Copyright (C) 2010 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
22 """Script for unittesting the daemon module"""
32 from ganeti import daemon
33 from ganeti import errors
34 from ganeti import utils
39 class TestMainloop(testutils.GanetiTestCase):
40 """Test daemon.Mainloop"""
# Each test schedules _SendSig callbacks on self.mainloop.scheduler. Each
# callback records a signal in sendsig_events and sends it to this very
# process. The assertions check which signals were sent and, when the test
# registers itself via RegisterSignal, which ones OnSignal saw.
# NOTE(review): this copy is incomplete. The embedded original line numbers
# skip (e.g. 41-42, 66, 69, 71-72, 100), so lines that appear to be the
# `def setUp(self):` header, the `self.mainloop.Run()` calls, `now = ...`
# and several argument continuation lines are missing. Restore them from
# upstream before running these tests.
43 testutils.GanetiTestCase.setUp(self)
# Fresh main loop and empty event logs for every test.
44 self.mainloop = daemon.Mainloop()
45 self.sendsig_events = []
46 self.onsignal_events = []
# Scheduler callback that cancels another scheduled event. testDeferredCancel
# uses it to cancel events from inside the running loop.
48 def _CancelEvent(self, handle):
49 self.mainloop.scheduler.cancel(handle)
# Scheduler callback: record `sig`, then deliver it to our own process.
51 def _SendSig(self, sig):
52 self.sendsig_events.append(sig)
53 os.kill(os.getpid(), sig)
# Listener hook used after self.mainloop.RegisterSignal(self). It just records
# the signal number it is given, so tests can compare it with sendsig_events.
55 def OnSignal(self, signum):
56 self.onsignal_events.append(signum)
# A single scheduled SIGTERM is sent and makes Run() return.
58 def testRunAndTermBySched(self):
59 self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
60 self.mainloop.Run() # terminates by _SendSig being scheduled
61 self.assertEquals(self.sendsig_events, [signal.SIGTERM])
# Assumes one Run() at each missing line (66 and 69). Then, per the
# assertions, SIGCHLD alone does not end the loop, SIGINT does, and a later
# run is ended by SIGTERM. The tail of the last assertion (lines 71-72) is
# also missing here.
63 def testTerminatingSignals(self):
64 self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
65 self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGINT])
67 self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT])
68 self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
70 self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT,
# An event cancelled before the loop runs never fires. Only the SIGCHLD and
# SIGTERM events are sent.
73 def testSchedulerCancel(self):
74 handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
76 self.mainloop.scheduler.cancel(handle)
77 self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
78 self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
80 self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
# With this test registered as a signal listener, OnSignal sees exactly the
# signals that were sent, up to and including the terminating SIGTERM.
82 def testRegisterSignal(self):
83 self.mainloop.RegisterSignal(self)
84 self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
85 handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
87 self.mainloop.scheduler.cancel(handle)
88 self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
89 self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
90 # ...not delivered because they are scheduled after TERM
91 self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
92 self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
94 self.assertEquals(self.sendsig_events,
95 [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
96 self.assertEquals(self.onsignal_events, self.sendsig_events)
# Cancels events from within the running loop. The _CancelEvent callbacks at
# now + 0.2 run before the now + 0.3 / now + 0.4 events. Their handle
# arguments are on missing lines, presumably handle1 and handle2. Per the
# assertion, only the first SIGCHLD and the final SIGTERM are sent.
98 def testDeferredCancel(self):
99 self.mainloop.RegisterSignal(self)
101 self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
103 handle1 = self.mainloop.scheduler.enterabs(now + 0.3, 2, self._SendSig,
105 handle2 = self.mainloop.scheduler.enterabs(now + 0.4, 2, self._SendSig,
107 self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
109 self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
111 self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGTERM])
113 self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
114 self.assertEquals(self.onsignal_events, self.sendsig_events)
# NOTE(review): the `def` line of this test (original line 116) is missing.
# Signals scheduled after SIGTERM are not sent during the first run. Per the
# second pair of assertions, they are sent by the next run, which ends at the
# newly scheduled SIGTERM.
117 self.mainloop.RegisterSignal(self)
118 self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
119 self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
120 self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
121 self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
122 self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
124 self.assertEquals(self.sendsig_events,
125 [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
126 self.assertEquals(self.onsignal_events, self.sendsig_events)
127 self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
129 self.assertEquals(self.sendsig_events,
130 [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM,
131 signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
132 self.assertEquals(self.onsignal_events, self.sendsig_events)
134 def testPriority(self):
135 # for events at the same time, the highest priority one executes first
# Per the assertions, the priority-1 event (the lower number) fires first and
# sends SIGTERM, ending the run. The priority-2 event sends SIGCHLD and is
# only delivered by the next run. Its signal arguments are on missing lines.
137 self.mainloop.scheduler.enterabs(now + 0.1, 2, self._SendSig,
139 self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
142 self.assertEquals(self.sendsig_events, [signal.SIGTERM])
143 self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGTERM])
145 self.assertEquals(self.sendsig_events,
146 [signal.SIGTERM, signal.SIGCHLD, signal.SIGTERM])
149 class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
# AsyncUDPSocket test double. It records every received payload in
# self.received and counts handle_error() calls in self.error_count.
# NOTE(review): the `def __init__(self):` header and, presumably, the
# initialisation of `received` / `error_count` (original lines 150-155) are
# missing from this copy.
152 daemon.AsyncUDPSocket.__init__(self)
# Control payloads: "terminate" sends SIGTERM to this process, which the
# tests use to stop Mainloop.Run(). "error" raises GenericError to exercise
# error handling. `(payload)` is just `payload`, not a 1-tuple; the tests
# compare `received` against plain strings.
156 def handle_datagram(self, payload, ip, port):
157 self.received.append((payload))
158 if payload == "terminate":
159 os.kill(os.getpid(), signal.SIGTERM)
160 elif payload == "error":
161 raise errors.GenericError("error")
# Count errors so the tests can assert how many were reported.
163 def handle_error(self):
164 self.error_count += 1
168 class TestAsyncUDPSocket(testutils.GanetiTestCase):
169 """Test daemon.AsyncUDPSocket"""
# NOTE(review): lines are missing from this copy: the `def setUp` /
# `def tearDown` headers, any socket clean-up in tearDown, and the
# `self.mainloop.Run()` calls. Restore them from upstream before running.
172 testutils.GanetiTestCase.setUp(self)
173 self.mainloop = daemon.Mainloop()
# Two recording sockets. Only the server is bound, to localhost port 0 (an
# OS-chosen port), and the actual port number is read back via getsockname().
174 self.server = _MyAsyncUDPSocket()
175 self.client = _MyAsyncUDPSocket()
176 self.server.bind(("127.0.0.1", 0))
177 self.port = self.server.getsockname()[1]
178 # Save utils.IgnoreSignals so we can do evil things to it...
179 self.saved_utils_ignoresignals = utils.IgnoreSignals
184 # ...and restore it as well
185 utils.IgnoreSignals = self.saved_utils_ignoresignals
186 testutils.GanetiTestCase.tearDown(self)
# Binding a second socket to the server's address must fail.
188 def testNoDoubleBind(self):
189 self.assertRaises(socket.error, self.client.bind, ("127.0.0.1", self.port))
# Queue `payload` and flush the client's send queue synchronously.
# NOTE(review): no test visible here calls this helper, and the print looks
# like leftover debug output. Confirm whether either is still needed.
191 def _ThreadedClient(self, payload):
192 self.client.enqueue_send("127.0.0.1", self.port, payload)
193 print "sending %s" % payload
194 while self.client.writable():
195 self.client.handle_write()
# Driven by the main loop: the "terminate" datagram makes the server send
# SIGTERM, so the run ends after all three payloads were received.
197 def testAsyncClientServer(self):
198 self.client.enqueue_send("127.0.0.1", self.port, "p1")
199 self.client.enqueue_send("127.0.0.1", self.port, "p2")
200 self.client.enqueue_send("127.0.0.1", self.port, "terminate")
202 self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
# Driven by hand, without the main loop: handle_write() flushes the client
# queue, and each process_next_packet() call handles exactly one datagram.
204 def testSyncClientServer(self):
205 self.client.enqueue_send("127.0.0.1", self.port, "p1")
206 self.client.enqueue_send("127.0.0.1", self.port, "p2")
207 while self.client.writable():
208 self.client.handle_write()
209 self.server.process_next_packet()
210 self.assertEquals(self.server.received, ["p1"])
211 self.server.process_next_packet()
212 self.assertEquals(self.server.received, ["p1", "p2"])
213 self.client.enqueue_send("127.0.0.1", self.port, "p3")
214 while self.client.writable():
215 self.client.handle_write()
216 self.server.process_next_packet()
217 self.assertEquals(self.server.received, ["p1", "p2", "p3"])
# Per the assertions, each "error" datagram both increments
# server.error_count and propagates GenericError out of Run(). The next Run()
# resumes with the datagrams still queued.
219 def testErrorHandling(self):
220 self.client.enqueue_send("127.0.0.1", self.port, "p1")
221 self.client.enqueue_send("127.0.0.1", self.port, "p2")
222 self.client.enqueue_send("127.0.0.1", self.port, "error")
223 self.client.enqueue_send("127.0.0.1", self.port, "p3")
224 self.client.enqueue_send("127.0.0.1", self.port, "error")
225 self.client.enqueue_send("127.0.0.1", self.port, "terminate")
226 self.assertRaises(errors.GenericError, self.mainloop.Run)
227 self.assertEquals(self.server.received,
228 ["p1", "p2", "error"])
229 self.assertEquals(self.server.error_count, 1)
230 self.assertRaises(errors.GenericError, self.mainloop.Run)
231 self.assertEquals(self.server.received,
232 ["p1", "p2", "error", "p3", "error"])
233 self.assertEquals(self.server.error_count, 2)
235 self.assertEquals(self.server.received,
236 ["p1", "p2", "error", "p3", "error", "terminate"])
237 self.assertEquals(self.server.error_count, 2)
# Stub out utils.IgnoreSignals so it returns None. handle_read() then records
# nothing, which (per the test name) simulates a read interrupted by a
# signal. After the original is restored, every queued datagram is still
# delivered.
239 def testSignaledWhileReceiving(self):
240 utils.IgnoreSignals = lambda fn, *args, **kwargs: None
241 self.client.enqueue_send("127.0.0.1", self.port, "p1")
242 self.client.enqueue_send("127.0.0.1", self.port, "p2")
243 self.server.handle_read()
244 self.assertEquals(self.server.received, [])
245 self.client.enqueue_send("127.0.0.1", self.port, "terminate")
246 utils.IgnoreSignals = self.saved_utils_ignoresignals
248 self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
251 class _MyAsyncStreamServer(daemon.AsyncStreamServer):
# AsyncStreamServer test double. It forwards each accepted connection to a
# callback supplied by the test and counts handle_error() calls.
# NOTE(review): several lines are missing from this copy (original 256-258,
# 264-266, and the body of handle_expt after line 267), presumably including
# the initialisation of error_count. Restore them from upstream.
# handle_connection_fn is called as fn(connected_socket, client_address) for
# every accepted connection.
253 def __init__(self, family, address, handle_connection_fn):
254 daemon.AsyncStreamServer.__init__(self, family, address)
255 self.handle_connection_fn = handle_connection_fn
259 def handle_connection(self, connected_socket, client_address):
260 self.handle_connection_fn(connected_socket, client_address)
# Count errors so the tests can assert how many were reported.
262 def handle_error(self):
263 self.error_count += 1
267 def handle_expt(self):
272 class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
# AsyncTerminatedMessageStream test double. It hands every complete message
# to message_fn(handler, message, message_id) and counts handle_error()
# calls. client_id is a tag the tests use to group received messages per
# client.
# NOTE(review): the rest of the base-class __init__ call (original lines
# 277-278) and, presumably, the error_count initialisation are missing from
# this copy.
274 def __init__(self, connected_socket, client_address, terminator, family,
275 message_fn, client_id):
276 daemon.AsyncTerminatedMessageStream.__init__(self, connected_socket,
279 self.message_fn = message_fn
280 self.client_id = client_id
# Pass the handler itself along so the callback can read its client_id.
283 def handle_message(self, message, message_id):
284 self.message_fn(self, message, message_id)
# Count errors so the tests can assert how many were reported.
286 def handle_error(self):
287 self.error_count += 1
291 class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
292 """Test daemon.AsyncStreamServer with a TCP connection"""
# NOTE(review): this copy is incomplete. The embedded original line numbers
# skip, so method headers (setUp, tearDown, getClient), the
# `self.mainloop.Run()` calls, loop bodies and some argument lines are
# missing. Restore them from upstream before running.
# Socket family used for both the server and the test clients. Subclasses
# override it together with getAddress().
294 family = socket.AF_INET
297 testutils.GanetiTestCase.setUp(self)
298 self.mainloop = daemon.Mainloop()
# getAddress() is the per-family hook. The address is then re-read from the
# bound server socket, so an OS-chosen port 0 becomes the real address.
299 self.address = self.getAddress()
300 self.server = _MyAsyncStreamServer(self.family, self.address,
301 self.handle_connection)
302 self.client_handler = _MyMessageStreamHandler
303 self.terminator = "\3"
304 self.address = self.server.getsockname()
306 self.connections = []
# Countdown counters consumed by countTerminate(). When one runs out, SIGTERM
# stops the loop. Tests set them to None to disable that trigger.
308 self.connect_terminate_count = 0
309 self.message_terminate_count = 0
310 self.next_client_id = 0
311 # Save utils.IgnoreSignals so we can do evil things to it...
312 self.saved_utils_ignoresignals = utils.IgnoreSignals
# tearDown (header missing): iterates over the client sockets and the
# accepted connections (loop bodies missing, presumably closing them), then
# restores utils.IgnoreSignals.
# NOTE(review): a second tearDown appears to be defined further down (the
# GanetiTestCase.tearDown call at original line 365). If so, Python keeps
# only the later definition and this clean-up never runs. Merge the two.
315 for c in self.clients:
317 for c in self.connections:
320 # ...and restore it as well
321 utils.IgnoreSignals = self.saved_utils_ignoresignals
322 testutils.GanetiTestCase.tearDown(self)
# Address to bind the server to: localhost with an OS-chosen port.
324 def getAddress(self):
325 return ("127.0.0.1", 0)
# Countdown helper. `name` is the attribute holding the counter, and a None
# counter is ignored. The missing lines 330/332 presumably decrement it and
# check for exhaustion before SIGTERM is sent to this process to stop the
# loop.
327 def countTerminate(self, name):
328 value = getattr(self, name)
329 if value is not None:
331 setattr(self, name, value)
333 os.kill(os.getpid(), signal.SIGTERM)
# Server callback: wrap each accepted socket in self.client_handler with the
# next sequential client_id, keep it in self.connections, and count it
# towards connect_terminate_count.
335 def handle_connection(self, connected_socket, client_address):
336 client_id = self.next_client_id
337 self.next_client_id += 1
338 client_handler = self.client_handler(connected_socket, client_address,
339 self.terminator, self.family,
342 self.connections.append(client_handler)
343 self.countTerminate("connect_terminate_count")
# Message callback: store messages per client_id. An "error" message is
# recorded first and then raises GenericError.
345 def handle_message(self, handler, message, message_id):
346 self.messages.setdefault(handler.client_id, [])
347 # We should just check that the message_ids are monotonically increasing.
348 # If in the unit tests we never remove messages from the received queue,
349 # though, we can just require that the queue length is the same as the
350 # message id, before pushing the message to it. This forces a more
351 # restrictive check, but we can live with this for now.
352 self.assertEquals(len(self.messages[handler.client_id]), message_id)
353 self.messages[handler.client_id].append(message)
354 if message == "error":
355 raise errors.GenericError("error")
356 self.countTerminate("message_terminate_count")
# getClient (header missing): open a client socket of self.family, connect it
# to the server and remember it for clean-up. The tests use its result, so it
# presumably returns the socket (line 362 missing).
359 client = socket.socket(self.family, socket.SOCK_STREAM)
360 client.connect(self.address)
361 self.clients.append(client)
# NOTE(review): this GanetiTestCase.tearDown call sits after getClient, i.e.
# in a second tearDown whose header and other lines are missing. That method
# would override the tearDown defined above.
365 testutils.GanetiTestCase.tearDown(self)
# Connections accepted by the main loop are collected in self.connections.
# The getClient() / Run() lines are missing from this copy.
368 def testConnect(self):
371 self.assertEquals(len(self.connections), 1)
374 self.assertEquals(len(self.connections), 2)
375 self.connect_terminate_count = 4
381 self.assertEquals(len(self.connections), 6)
# One terminated message arrives as one message, without the terminator.
383 def testBasicMessage(self):
384 self.connect_terminate_count = None
385 client = self.getClient()
386 client.send("ciao\3")
388 self.assertEquals(len(self.connections), 1)
389 self.assertEquals(len(self.messages[0]), 1)
390 self.assertEquals(self.messages[0][0], "ciao")
# Two messages sent separately over the same connection are both recorded.
392 def testDoubleMessage(self):
393 self.connect_terminate_count = None
394 client = self.getClient()
395 client.send("ciao\3")
397 client.send("foobar\3")
399 self.assertEquals(len(self.connections), 1)
400 self.assertEquals(len(self.messages[0]), 2)
401 self.assertEquals(self.messages[0][1], "foobar")
# Several messages in a single send() are split on the terminator.
403 def testComposedMessage(self):
404 self.connect_terminate_count = None
405 self.message_terminate_count = 3
406 client = self.getClient()
407 client.send("one\3composed\3message\3")
409 self.assertEquals(len(self.messages[0]), 3)
410 self.assertEquals(self.messages[0], ["one", "composed", "message"])
# Multi-byte terminators work too. self.terminator is read in
# handle_connection, so it is changed before the connection is accepted.
412 def testLongTerminator(self):
413 self.terminator = "\0\1\2"
414 self.connect_terminate_count = None
415 self.message_terminate_count = 3
416 client = self.getClient()
417 client.send("one\0\1\2composed\0\1\2message\0\1\2")
419 self.assertEquals(len(self.messages[0]), 3)
420 self.assertEquals(self.messages[0], ["one", "composed", "message"])
# An "error" message propagates GenericError out of Run() and is counted by
# the connection's handle_error. Messages after it are processed by the next
# Run(). The tail of the last assertion (lines 434-435) is missing here.
422 def testErrorHandling(self):
423 self.connect_terminate_count = None
424 self.message_terminate_count = None
425 client = self.getClient()
426 client.send("one\3two\3error\3three\3")
427 self.assertRaises(errors.GenericError, self.mainloop.Run)
428 self.assertEquals(self.connections[0].error_count, 1)
429 self.assertEquals(self.messages[0], ["one", "two", "error"])
430 client.send("error\3")
431 self.assertRaises(errors.GenericError, self.mainloop.Run)
432 self.assertEquals(self.connections[0].error_count, 2)
433 self.assertEquals(self.messages[0], ["one", "two", "error", "three",
# Messages are kept separate per client_id, assigned in connection order.
436 def testDoubleClient(self):
437 self.connect_terminate_count = None
438 self.message_terminate_count = 2
439 client1 = self.getClient()
440 client2 = self.getClient()
441 client1.send("c1m1\3")
442 client2.send("c2m1\3")
444 self.assertEquals(self.messages[0], ["c1m1"])
445 self.assertEquals(self.messages[1], ["c2m1"])
# Bytes after the last terminator are buffered per connection and prepended
# to the next message from the same client.
447 def testUnterminatedMessage(self):
448 self.connect_terminate_count = None
449 self.message_terminate_count = 3
450 client1 = self.getClient()
451 client2 = self.getClient()
452 client1.send("message\3unterminated")
453 client2.send("c2m1\3c2m2\3")
455 self.assertEquals(self.messages[0], ["message"])
456 self.assertEquals(self.messages[1], ["c2m1", "c2m2"])
457 client1.send("message\3")
459 self.assertEquals(self.messages[0], ["message", "unterminatedmessage"])
# Stub out utils.IgnoreSignals so handle_accept() yields no connection. After
# the original is restored, the pending connection is still accepted.
461 def testSignaledWhileAccepting(self):
462 utils.IgnoreSignals = lambda fn, *args, **kwargs: None
463 client1 = self.getClient()
464 self.server.handle_accept()
465 # When interrupted while accepting we don't have a connection, but we
466 # didn't crash either.
467 self.assertEquals(len(self.connections), 0)
468 utils.IgnoreSignals = self.saved_utils_ignoresignals
470 self.assertEquals(len(self.connections), 1)
473 class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
474 """Test daemon.AsyncStreamServer with a Unix path connection"""
# Re-runs all the TCP tests with a Unix-domain socket.
476 family = socket.AF_UNIX
# Called from the inherited setUp. The server socket file lives in a fresh
# temporary directory, which tearDown removes.
478 def getAddress(self):
479 self.tmpdir = tempfile.mkdtemp()
480 return os.path.join(self.tmpdir, "server.sock")
# tearDown (header missing from this copy): remove the temporary directory,
# then run the parent class clean-up.
483 shutil.rmtree(self.tmpdir)
484 TestAsyncStreamServerTCP.tearDown(self)
487 class TestAsyncAwaker(testutils.GanetiTestCase):
488 """Test daemon.AsyncAwaker"""
# NOTE(review): no code visible in this class uses `family`. It may be a
# leftover copied from TestAsyncStreamServerTCP.
490 family = socket.AF_INET
# NOTE(review): the setUp header and the signal-sending / Run() lines of the
# tests below are missing from this copy (the original line numbers skip).
493 testutils.GanetiTestCase.setUp(self)
494 self.mainloop = daemon.Mainloop()
495 self.awaker = daemon.AsyncAwaker(signal_fn=self.handle_signal)
496 self.signal_count = 0
497 self.signal_terminate_count = 1
# Passed to AsyncAwaker as signal_fn. It counts deliveries and, once
# signal_terminate_count reaches 0, sends SIGTERM to stop the main loop.
502 def handle_signal(self):
503 self.signal_count += 1
504 self.signal_terminate_count -= 1
505 if self.signal_terminate_count <= 0:
506 os.kill(os.getpid(), signal.SIGTERM)
# A single signal results in exactly one handle_signal call.
508 def testBasicSignaling(self):
511 self.assertEquals(self.signal_count, 1)
# Signalling twice (lines missing) still results in one handle_signal call.
513 def testDoubleSignaling(self):
517 # The second signal is never delivered
518 self.assertEquals(self.signal_count, 1)
# A fresh awaker is readable(). Even when need_signal is forced back to True
# to mimic a race between threads, only one signal reaches handle_signal.
520 def testReallyDoubleSignaling(self):
521 self.assert_(self.awaker.readable())
523 # Let's suppose two threads overlap, and both find need_signal True
524 self.awaker.need_signal = True
527 # We still get only one signaling
528 self.assertEquals(self.signal_count, 1)
# Without signal_fn, handle_read() raises socket.error when nothing is
# pending. The lines in between are missing; presumably they signal the
# awaker, after which exactly one handle_read() succeeds.
530 def testNoSignalFnArgument(self):
531 myawaker = daemon.AsyncAwaker()
532 self.assertRaises(socket.error, myawaker.handle_read)
534 myawaker.handle_read()
535 self.assertRaises(socket.error, myawaker.handle_read)
538 myawaker.handle_read()
539 self.assertRaises(socket.error, myawaker.handle_read)
# A non-callable signal_fn, passed positionally or by keyword, is rejected
# with AssertionError.
542 def testWrongSignalFnArgument(self):
543 self.assertRaises(AssertionError, daemon.AsyncAwaker, 1)
544 self.assertRaises(AssertionError, daemon.AsyncAwaker, "string")
545 self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn=1)
546 self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn="string")
# Script entry point: hand control to testutils.GanetiTestProgram, which
# presumably runs the test cases defined in this module.
549 if __name__ == "__main__":
550 testutils.GanetiTestProgram()