4 # Copyright (C) 2010 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
22 """Script for unittesting the daemon module"""
32 from ganeti import daemon
33 from ganeti import errors
34 from ganeti import constants
35 from ganeti import utils
class TestMainloop(testutils.GanetiTestCase):
  """Test daemon.Mainloop"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    # Signals sent by _SendSig, in the order they were sent
    self.sendsig_events = []
    # Signals reported to OnSignal (only after RegisterSignal(self))
    self.onsignal_events = []

  def _CancelEvent(self, handle):
    """Scheduler callback: cancel the scheduler event C{handle}."""
    self.mainloop.scheduler.cancel(handle)

  def _SendSig(self, sig):
    """Scheduler callback: record C{sig} and send it to this process."""
    self.sendsig_events.append(sig)
    os.kill(os.getpid(), sig)

  def OnSignal(self, signum):
    """Record a signal reported by the mainloop to this registered object."""
    self.onsignal_events.append(signum)

  def testRunAndTermBySched(self):
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run() # terminates by _SendSig being scheduled
    self.assertEqual(self.sendsig_events, [signal.SIGTERM])

  def testTerminatingSignals(self):
    # SIGCHLD does not stop the mainloop, SIGINT and SIGTERM do
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGINT])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT])
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT,
                                           signal.SIGTERM])

  def testSchedulerCancel(self):
    handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
                                           [signal.SIGTERM])
    self.mainloop.scheduler.cancel(handle)
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])

  def testRegisterSignal(self):
    self.mainloop.RegisterSignal(self)
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
    handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
                                           [signal.SIGTERM])
    self.mainloop.scheduler.cancel(handle)
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    # ...not delivered because they are scheduled after TERM
    self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events,
                     [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
    self.assertEqual(self.onsignal_events, self.sendsig_events)

  def testDeferredCancel(self):
    # Events can be cancelled by other scheduled events before they are due
    self.mainloop.RegisterSignal(self)
    now = time.time()
    self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
                                     [signal.SIGCHLD])
    handle1 = self.mainloop.scheduler.enterabs(now + 0.3, 2, self._SendSig,
                                               [signal.SIGCHLD])
    handle2 = self.mainloop.scheduler.enterabs(now + 0.4, 2, self._SendSig,
                                               [signal.SIGCHLD])
    self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
                                     [handle1])
    self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
                                     [handle2])
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
    self.assertEqual(self.onsignal_events, self.sendsig_events)

  def testReRun(self):
    # Events still queued when SIGTERM stops the loop fire on the next Run()
    self.mainloop.RegisterSignal(self)
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events,
                     [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
    self.assertEqual(self.onsignal_events, self.sendsig_events)
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events,
                     [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM,
                      signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
    self.assertEqual(self.onsignal_events, self.sendsig_events)

  def testPriority(self):
    # For events due at the same time the lower priority number runs first:
    # SIGTERM (priority 1) fires and stops the loop before SIGCHLD
    # (priority 2), which is then delivered by the next Run()
    now = time.time()
    self.mainloop.scheduler.enterabs(now + 0.1, 2, self._SendSig,
                                     [signal.SIGCHLD])
    self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
                                     [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGTERM])
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events,
                     [signal.SIGTERM, signal.SIGCHLD, signal.SIGTERM])
class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
  """AsyncUDPSocket recording every datagram and counting handled errors.

  The payload "terminate" sends SIGTERM to this process. The payload "error"
  raises errors.GenericError from the datagram handler.

  """

  def __init__(self, family):
    daemon.AsyncUDPSocket.__init__(self, family)
    # Payloads passed to handle_datagram, in arrival order
    self.received = []
    # How many times handle_error has been invoked
    self.error_count = 0

  def handle_datagram(self, payload, ip, port):
    self.received.append(payload)
    # The two special payloads are mutually exclusive, so order is irrelevant
    if payload == "error":
      raise errors.GenericError("error")
    if payload == "terminate":
      os.kill(os.getpid(), signal.SIGTERM)

  def handle_error(self):
    self.error_count += 1
    # Let the exception escape so tests can see it coming out of Run()
    raise
class _BaseAsyncUDPSocketTest:
  """Base class for AsyncUDPSocket tests

  Concrete subclasses also inherit from testutils.GanetiTestCase, set the
  C{family} and C{address} class attributes, and call both base classes'
  setUp and tearDown themselves.

  """

  # Address family and bind address, set by the concrete subclasses
  family = None
  address = None

  def setUp(self):
    self.mainloop = daemon.Mainloop()
    self.server = _MyAsyncUDPSocket(self.family)
    self.client = _MyAsyncUDPSocket(self.family)
    self.server.bind((self.address, 0))
    # Port picked by the kernel for the server socket
    self.port = self.server.getsockname()[1]
    # Save utils.IgnoreSignals so we can do evil things to it...
    self.saved_utils_ignoresignals = utils.IgnoreSignals

  def tearDown(self):
    self.server.close()
    self.client.close()
    # ...and restore it as well
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    # testutils.GanetiTestCase.tearDown is deliberately not called here.
    # As with setUp, the concrete subclass calls it; calling it here as well
    # made it run twice per test.

  def testNoDoubleBind(self):
    self.assertRaises(socket.error, self.client.bind, (self.address, self.port))

  def testAsyncClientServer(self):
    self.client.enqueue_send(self.address, self.port, "p1")
    self.client.enqueue_send(self.address, self.port, "p2")
    self.client.enqueue_send(self.address, self.port, "terminate")
    self.mainloop.Run()
    self.assertEqual(self.server.received, ["p1", "p2", "terminate"])

  def testSyncClientServer(self):
    # Writing with an empty queue must be harmless
    self.client.handle_write()
    self.client.enqueue_send(self.address, self.port, "p1")
    self.client.enqueue_send(self.address, self.port, "p2")
    while self.client.writable():
      self.client.handle_write()
    self.server.process_next_packet()
    self.assertEqual(self.server.received, ["p1"])
    self.server.process_next_packet()
    self.assertEqual(self.server.received, ["p1", "p2"])
    self.client.enqueue_send(self.address, self.port, "p3")
    while self.client.writable():
      self.client.handle_write()
    self.server.process_next_packet()
    self.assertEqual(self.server.received, ["p1", "p2", "p3"])

  def testErrorHandling(self):
    self.client.enqueue_send(self.address, self.port, "p1")
    self.client.enqueue_send(self.address, self.port, "p2")
    self.client.enqueue_send(self.address, self.port, "error")
    self.client.enqueue_send(self.address, self.port, "p3")
    self.client.enqueue_send(self.address, self.port, "error")
    self.client.enqueue_send(self.address, self.port, "terminate")
    # Each "error" datagram aborts Run(); the next Run() resumes processing
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEqual(self.server.received,
                     ["p1", "p2", "error"])
    self.assertEqual(self.server.error_count, 1)
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEqual(self.server.received,
                     ["p1", "p2", "error", "p3", "error"])
    self.assertEqual(self.server.error_count, 2)
    self.mainloop.Run()
    self.assertEqual(self.server.received,
                     ["p1", "p2", "error", "p3", "error", "terminate"])
    self.assertEqual(self.server.error_count, 2)

  def testSignaledWhileReceiving(self):
    # Stub utils.IgnoreSignals to always return None, as if the receive had
    # been interrupted by a signal: nothing may be recorded
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
    self.client.enqueue_send(self.address, self.port, "p1")
    self.client.enqueue_send(self.address, self.port, "p2")
    self.server.handle_read()
    self.assertEqual(self.server.received, [])
    self.client.enqueue_send(self.address, self.port, "terminate")
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    self.mainloop.Run()
    self.assertEqual(self.server.received, ["p1", "p2", "terminate"])

  def testOversizedDatagram(self):
    oversized_data = (constants.MAX_UDP_DATA_SIZE + 1) * "a"
    self.assertRaises(errors.UdpDataSizeError, self.client.enqueue_send,
                      self.address, self.port, oversized_data)
class TestAsyncIP4UDPSocket(testutils.GanetiTestCase, _BaseAsyncUDPSocketTest):
  """Test IP4 daemon.AsyncUDPSocket"""

  family = socket.AF_INET
  address = "127.0.0.1"

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    _BaseAsyncUDPSocketTest.setUp(self)

  def tearDown(self):
    # Undo setUp in reverse order. The finally clause makes sure the
    # GanetiTestCase cleanup still runs if the socket cleanup fails.
    try:
      _BaseAsyncUDPSocketTest.tearDown(self)
    finally:
      testutils.GanetiTestCase.tearDown(self)
class TestAsyncIP6UDPSocket(testutils.GanetiTestCase, _BaseAsyncUDPSocketTest):
  """Test IP6 daemon.AsyncUDPSocket"""

  family = socket.AF_INET6
  # IPv6 loopback address
  address = "::1"

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    _BaseAsyncUDPSocketTest.setUp(self)

  def tearDown(self):
    # Undo setUp in reverse order. The finally clause makes sure the
    # GanetiTestCase cleanup still runs if the socket cleanup fails.
    try:
      _BaseAsyncUDPSocketTest.tearDown(self)
    finally:
      testutils.GanetiTestCase.tearDown(self)
class _MyAsyncStreamServer(daemon.AsyncStreamServer):
  """AsyncStreamServer handing every accepted connection to a callback.

  It also counts calls to handle_error and handle_expt.

  """

  def __init__(self, family, address, handle_connection_fn):
    daemon.AsyncStreamServer.__init__(self, family, address)
    # Called as handle_connection_fn(connected_socket, client_address)
    self.handle_connection_fn = handle_connection_fn
    self.error_count = 0
    self.expt_count = 0

  def handle_connection(self, connected_socket, client_address):
    fn = self.handle_connection_fn
    fn(connected_socket, client_address)

  def handle_error(self):
    self.error_count += 1
    self.close()
    # Propagate the exception currently being handled
    raise

  def handle_expt(self):
    self.expt_count += 1
    self.close()
class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
  """Terminated message stream forwarding every message to a callback.

  C{message_fn} is called as message_fn(handler, message, message_id).
  C{client_id} lets the test tell connections apart.

  """

  def __init__(self, connected_socket, client_address, terminator, family,
               message_fn, client_id, unhandled_limit):
    daemon.AsyncTerminatedMessageStream.__init__(self, connected_socket,
                                                 client_address,
                                                 terminator, family,
                                                 unhandled_limit)
    self.message_fn = message_fn
    self.client_id = client_id
    # How many times handle_error has been invoked
    self.error_count = 0

  def handle_message(self, message, message_id):
    callback = self.message_fn
    callback(self, message, message_id)

  def handle_error(self):
    self.error_count += 1
    # Let the exception escape so tests can see it coming out of Run()
    raise
class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
  """Test daemon.AsyncStreamServer with a TCP connection"""

  family = socket.AF_INET

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.address = self.getAddress()
    self.server = _MyAsyncStreamServer(self.family, self.address,
                                       self.handle_connection)
    self.client_handler = _MyMessageStreamHandler
    self.unhandled_limit = None
    self.terminator = "\3"
    # Use the address actually bound (e.g. the real port when getAddress
    # asked for port 0)
    self.address = self.server.getsockname()
    # Client sockets opened by getClient
    self.clients = []
    # Server-side handlers created by handle_connection
    self.connections = []
    # client_id -> list of messages received on that connection
    self.messages = {}
    # Counters consumed by countTerminate; None disables termination
    self.connect_terminate_count = 0
    self.message_terminate_count = 0
    self.next_client_id = 0
    # Save utils.IgnoreSignals so we can do evil things to it...
    self.saved_utils_ignoresignals = utils.IgnoreSignals

  def tearDown(self):
    # Keep this the only tearDown in the class: a second definition later in
    # the class body would silently replace it, leaving sockets open and
    # utils.IgnoreSignals patched.
    try:
      for c in self.clients:
        c.close()
      for c in self.connections:
        c.close()
      self.server.close()
    finally:
      # ...and restore it as well
      utils.IgnoreSignals = self.saved_utils_ignoresignals
      testutils.GanetiTestCase.tearDown(self)

  def getAddress(self):
    return ("127.0.0.1", 0)

  def countTerminate(self, name):
    """Decrement counter C{name}; send SIGTERM once it is zero or below.

    A counter set to None never triggers termination.

    """
    value = getattr(self, name)
    if value is not None:
      value -= 1
      setattr(self, name, value)
      if value <= 0:
        os.kill(os.getpid(), signal.SIGTERM)

  def handle_connection(self, connected_socket, client_address):
    client_id = self.next_client_id
    self.next_client_id += 1
    client_handler = self.client_handler(connected_socket, client_address,
                                         self.terminator, self.family,
                                         self.handle_message,
                                         client_id, self.unhandled_limit)
    self.connections.append(client_handler)
    self.countTerminate("connect_terminate_count")

  def handle_message(self, handler, message, message_id):
    self.messages.setdefault(handler.client_id, [])
    # We should just check that the message_ids are monotonically increasing.
    # If in the unit tests we never remove messages from the received queue,
    # though, we can just require that the queue length is the same as the
    # message id, before pushing the message to it. This forces a more
    # restrictive check, but we can live with this for now.
    self.assertEqual(len(self.messages[handler.client_id]), message_id)
    self.messages[handler.client_id].append(message)
    if message == "error":
      raise errors.GenericError("error")
    self.countTerminate("message_terminate_count")

  def getClient(self):
    """Open a client connection to the server and return its socket."""
    client = socket.socket(self.family, socket.SOCK_STREAM)
    client.connect(self.address)
    self.clients.append(client)
    return client

  def testConnect(self):
    self.getClient()
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 1)
    self.getClient()
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 2)
    self.connect_terminate_count = 4
    self.getClient()
    self.getClient()
    self.getClient()
    self.getClient()
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 6)

  def testBasicMessage(self):
    self.connect_terminate_count = None
    client = self.getClient()
    client.send("ciao\3")
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 1)
    self.assertEqual(len(self.messages[0]), 1)
    self.assertEqual(self.messages[0][0], "ciao")

  def testDoubleMessage(self):
    self.connect_terminate_count = None
    client = self.getClient()
    client.send("ciao\3")
    self.mainloop.Run()
    client.send("foobar\3")
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 1)
    self.assertEqual(len(self.messages[0]), 2)
    self.assertEqual(self.messages[0][1], "foobar")

  def testComposedMessage(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client = self.getClient()
    client.send("one\3composed\3message\3")
    self.mainloop.Run()
    self.assertEqual(len(self.messages[0]), 3)
    self.assertEqual(self.messages[0], ["one", "composed", "message"])

  def testLongTerminator(self):
    self.terminator = "\0\1\2"
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client = self.getClient()
    client.send("one\0\1\2composed\0\1\2message\0\1\2")
    self.mainloop.Run()
    self.assertEqual(len(self.messages[0]), 3)
    self.assertEqual(self.messages[0], ["one", "composed", "message"])

  def testErrorHandling(self):
    self.connect_terminate_count = None
    self.message_terminate_count = None
    client = self.getClient()
    client.send("one\3two\3error\3three\3")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEqual(self.connections[0].error_count, 1)
    self.assertEqual(self.messages[0], ["one", "two", "error"])
    client.send("error\3")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEqual(self.connections[0].error_count, 2)
    self.assertEqual(self.messages[0], ["one", "two", "error", "three",
                                        "error"])

  def testDoubleClient(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 2
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("c1m1\3")
    client2.send("c2m1\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["c1m1"])
    self.assertEqual(self.messages[1], ["c2m1"])

  def testUnterminatedMessage(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("message\3unterminated")
    client2.send("c2m1\3c2m2\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["message"])
    self.assertEqual(self.messages[1], ["c2m1", "c2m2"])
    client1.send("message\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["message", "unterminatedmessage"])

  def testSignaledWhileAccepting(self):
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
    self.getClient()
    self.server.handle_accept()
    # When interrupted while accepting we don't have a connection, but we
    # didn't crash either.
    self.assertEqual(len(self.connections), 0)
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 1)

  def testSendMessage(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("one\3composed\3message\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["one", "composed", "message"])
    self.assertFalse(self.connections[0].writable())
    self.assertFalse(self.connections[1].writable())
    self.connections[0].send_message("r0")
    self.assertTrue(self.connections[0].writable())
    self.assertFalse(self.connections[1].writable())
    self.connections[0].send_message("r1")
    self.connections[0].send_message("r2")
    # We currently have no way to terminate the mainloop on write events, but
    # let's assume handle_write will be called if writable() is True.
    while self.connections[0].writable():
      self.connections[0].handle_write()
    client1.setblocking(0)
    client2.setblocking(0)
    self.assertEqual(client1.recv(4096), "r0\3r1\3r2\3")
    # Nothing was sent to client2, so a non-blocking read fails
    self.assertRaises(socket.error, client2.recv, 4096)

  def testLimitedUnhandledMessages(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    self.unhandled_limit = 2
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("one\3composed\3long\3message\3")
    client2.send("c2one\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["one", "composed"])
    self.assertEqual(self.messages[1], ["c2one"])
    self.assertFalse(self.connections[0].readable())
    self.assertTrue(self.connections[1].readable())
    self.connections[0].send_message("r0")
    self.message_terminate_count = None
    client1.send("another\3")
    # When we write replies, already queued messages also get handled, but
    # not the ones still waiting in the socket.
    while self.connections[0].writable():
      self.connections[0].handle_write()
    self.assertFalse(self.connections[0].readable())
    self.assertEqual(self.messages[0], ["one", "composed", "long"])
    self.connections[0].send_message("r1")
    self.connections[0].send_message("r2")
    while self.connections[0].writable():
      self.connections[0].handle_write()
    self.assertEqual(self.messages[0], ["one", "composed", "long", "message"])
    self.assertTrue(self.connections[0].readable())

  def testLimitedUnhandledMessagesOne(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 2
    self.unhandled_limit = 1
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("one\3composed\3message\3")
    client2.send("c2one\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["one"])
    self.assertEqual(self.messages[1], ["c2one"])
    self.assertFalse(self.connections[0].readable())
    self.assertFalse(self.connections[1].readable())
    self.connections[0].send_message("r0")
    self.message_terminate_count = None
    while self.connections[0].writable():
      self.connections[0].handle_write()
    self.assertFalse(self.connections[0].readable())
    self.assertEqual(self.messages[0], ["one", "composed"])
    self.connections[0].send_message("r2")
    self.connections[0].send_message("r3")
    while self.connections[0].writable():
      self.connections[0].handle_write()
    self.assertEqual(self.messages[0], ["one", "composed", "message"])
    self.assertTrue(self.connections[0].readable())
class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
  """Test daemon.AsyncStreamServer with a Unix path connection"""

  family = socket.AF_UNIX

  def getAddress(self):
    # Each test gets a private directory holding the server socket file
    self.tmpdir = tempfile.mkdtemp()
    return os.path.join(self.tmpdir, "server.sock")

  def tearDown(self):
    # Close the sockets before removing the directory holding the socket
    # file. The finally clause removes the directory even if the inherited
    # cleanup fails, and a failing rmtree can no longer skip that cleanup.
    try:
      TestAsyncStreamServerTCP.tearDown(self)
    finally:
      shutil.rmtree(self.tmpdir)
class TestAsyncAwaker(testutils.GanetiTestCase):
  """Test daemon.AsyncAwaker"""

  # NOTE(review): not referenced by any test in this class
  family = socket.AF_INET

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.awaker = daemon.AsyncAwaker(signal_fn=self.handle_signal)
    self.signal_count = 0
    # Number of awaker notifications after which the mainloop is stopped
    self.signal_terminate_count = 1

  def tearDown(self):
    # Always close the awaker, and always chain to the base cleanup
    try:
      self.awaker.close()
    finally:
      testutils.GanetiTestCase.tearDown(self)

  def handle_signal(self):
    """Awaker callback: count the call and stop the mainloop when due."""
    self.signal_count += 1
    self.signal_terminate_count -= 1
    if self.signal_terminate_count <= 0:
      os.kill(os.getpid(), signal.SIGTERM)

  def testBasicSignaling(self):
    self.awaker.signal()
    self.mainloop.Run()
    self.assertEqual(self.signal_count, 1)

  def testDoubleSignaling(self):
    self.awaker.signal()
    self.awaker.signal()
    self.mainloop.Run()
    # The second signal is never delivered
    self.assertEqual(self.signal_count, 1)

  def testReallyDoubleSignaling(self):
    self.assertTrue(self.awaker.readable())
    self.awaker.signal()
    # Let's suppose two threads overlap, and both find need_signal True
    self.awaker.need_signal = True
    self.awaker.signal()
    self.mainloop.Run()
    # We still get only one signaling
    self.assertEqual(self.signal_count, 1)

  def testNoSignalFnArgument(self):
    myawaker = daemon.AsyncAwaker()
    # Close the awaker even if one of the assertions fails
    try:
      # Reading with nothing signaled fails with socket.error
      self.assertRaises(socket.error, myawaker.handle_read)
      myawaker.signal()
      myawaker.handle_read()
      self.assertRaises(socket.error, myawaker.handle_read)
      myawaker.signal()
      myawaker.signal()
      myawaker.handle_read()
      self.assertRaises(socket.error, myawaker.handle_read)
    finally:
      myawaker.close()

  def testWrongSignalFnArgument(self):
    # NOTE(review): presumably checked with assert statements in
    # daemon.AsyncAwaker, in which case this only holds without python -O
    self.assertRaises(AssertionError, daemon.AsyncAwaker, 1)
    self.assertRaises(AssertionError, daemon.AsyncAwaker, "string")
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn=1)
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn="string")
if __name__ == "__main__":
  # Entry point when this file is executed directly: hand over to Ganeti's
  # test program wrapper
  testutils.GanetiTestProgram()