4 # Copyright (C) 2010 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
22 """Script for unittesting the daemon module"""
32 from ganeti import daemon
33 from ganeti import errors
34 from ganeti import constants
35 from ganeti import utils
40 class TestMainloop(testutils.GanetiTestCase):
41 """Test daemon.Mainloop"""
44 testutils.GanetiTestCase.setUp(self)
45 self.mainloop = daemon.Mainloop()
46 self.sendsig_events = []
47 self.onsignal_events = []
def _CancelEvent(self, handle):
    """Cancel the scheduler event identified by ``handle``.

    ``handle`` is whatever the mainloop scheduler returned when the event
    was entered; it is passed unchanged to the scheduler's ``cancel``.
    """
    scheduler = self.mainloop.scheduler
    scheduler.cancel(handle)
def _SendSig(self, sig):
    """Record ``sig`` in ``sendsig_events`` and send it to this process."""
    # Log first, so the event is recorded before the signal is delivered.
    self.sendsig_events.append(sig)
    own_pid = os.getpid()
    os.kill(own_pid, sig)
def OnSignal(self, signum):
    """Append ``signum`` to ``onsignal_events``.

    NOTE(review): presumably called by the mainloop for objects passed to
    ``RegisterSignal`` -- confirm against the daemon module.
    """
    events = self.onsignal_events
    events.append(signum)
def testRunAndTermBySched(self):
    """Check that Run() returns after a scheduled event sends SIGTERM."""
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
    # Run() terminates because the scheduled _SendSig delivers SIGTERM.
    self.mainloop.Run()
    # assertEqual is the canonical unittest name; the "assertEquals" alias
    # is deprecated.
    self.assertEqual(self.sendsig_events, [signal.SIGTERM])
64 def testTerminatingSignals(self):
65 self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
66 self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGINT])
68 self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT])
69 self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
71 self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT,
74 def testSchedulerCancel(self):
75 handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
77 self.mainloop.scheduler.cancel(handle)
78 self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
79 self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
81 self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
83 def testRegisterSignal(self):
84 self.mainloop.RegisterSignal(self)
85 self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
86 handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
88 self.mainloop.scheduler.cancel(handle)
89 self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
90 self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
# ...not delivered because they are scheduled after TERM
92 self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
93 self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
95 self.assertEquals(self.sendsig_events,
96 [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
97 self.assertEquals(self.onsignal_events, self.sendsig_events)
99 def testDeferredCancel(self):
100 self.mainloop.RegisterSignal(self)
102 self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
104 handle1 = self.mainloop.scheduler.enterabs(now + 0.3, 2, self._SendSig,
106 handle2 = self.mainloop.scheduler.enterabs(now + 0.4, 2, self._SendSig,
108 self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
110 self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
112 self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGTERM])
114 self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
115 self.assertEquals(self.onsignal_events, self.sendsig_events)
118 self.mainloop.RegisterSignal(self)
119 self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
120 self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
121 self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
122 self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
123 self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
125 self.assertEquals(self.sendsig_events,
126 [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
127 self.assertEquals(self.onsignal_events, self.sendsig_events)
128 self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
130 self.assertEquals(self.sendsig_events,
131 [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM,
132 signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
133 self.assertEquals(self.onsignal_events, self.sendsig_events)
135 def testPriority(self):
136 # for events at the same time, the highest priority one executes first
138 self.mainloop.scheduler.enterabs(now + 0.1, 2, self._SendSig,
140 self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
143 self.assertEquals(self.sendsig_events, [signal.SIGTERM])
144 self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGTERM])
146 self.assertEquals(self.sendsig_events,
147 [signal.SIGTERM, signal.SIGCHLD, signal.SIGTERM])
150 class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
153 daemon.AsyncUDPSocket.__init__(self)
def handle_datagram(self, payload, ip, port):
    """Record a received payload and react to the two special ones.

    Every payload is appended to ``self.received``. The payload
    "terminate" then sends SIGTERM to the current process, and "error"
    raises ``errors.GenericError``. ``ip`` and ``port`` are accepted but
    not used.
    """
    self.received.append(payload)
    if payload == "terminate":
        os.kill(os.getpid(), signal.SIGTERM)
        return
    if payload == "error":
        raise errors.GenericError("error")
164 def handle_error(self):
165 self.error_count += 1
169 class TestAsyncUDPSocket(testutils.GanetiTestCase):
170 """Test daemon.AsyncUDPSocket"""
173 testutils.GanetiTestCase.setUp(self)
174 self.mainloop = daemon.Mainloop()
175 self.server = _MyAsyncUDPSocket()
176 self.client = _MyAsyncUDPSocket()
177 self.server.bind(("127.0.0.1", 0))
178 self.port = self.server.getsockname()[1]
179 # Save utils.IgnoreSignals so we can do evil things to it...
180 self.saved_utils_ignoresignals = utils.IgnoreSignals
185 # ...and restore it as well
186 utils.IgnoreSignals = self.saved_utils_ignoresignals
187 testutils.GanetiTestCase.tearDown(self)
def testNoDoubleBind(self):
    """Binding the client to the port the server is bound to must fail."""
    server_address = ("127.0.0.1", self.port)
    self.assertRaises(socket.error, self.client.bind, server_address)
192 def testAsyncClientServer(self):
193 self.client.enqueue_send("127.0.0.1", self.port, "p1")
194 self.client.enqueue_send("127.0.0.1", self.port, "p2")
195 self.client.enqueue_send("127.0.0.1", self.port, "terminate")
197 self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
def testSyncClientServer(self):
    """Drive client writes and server reads by hand, without the mainloop.

    After each flush of the client's queue the server is asked to process
    packets one at a time, and ``received`` is expected to grow by exactly
    one payload per ``process_next_packet()`` call.
    """
    def flush_client():
        # Keep writing until the client reports nothing left to send.
        while self.client.writable():
            self.client.handle_write()

    # Nothing has been enqueued in this test yet when this first write runs.
    self.client.handle_write()
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
    flush_client()
    self.server.process_next_packet()
    self.assertEqual(self.server.received, ["p1"])
    self.server.process_next_packet()
    self.assertEqual(self.server.received, ["p1", "p2"])
    self.client.enqueue_send("127.0.0.1", self.port, "p3")
    flush_client()
    self.server.process_next_packet()
    self.assertEqual(self.server.received, ["p1", "p2", "p3"])
215 def testErrorHandling(self):
216 self.client.enqueue_send("127.0.0.1", self.port, "p1")
217 self.client.enqueue_send("127.0.0.1", self.port, "p2")
218 self.client.enqueue_send("127.0.0.1", self.port, "error")
219 self.client.enqueue_send("127.0.0.1", self.port, "p3")
220 self.client.enqueue_send("127.0.0.1", self.port, "error")
221 self.client.enqueue_send("127.0.0.1", self.port, "terminate")
222 self.assertRaises(errors.GenericError, self.mainloop.Run)
223 self.assertEquals(self.server.received,
224 ["p1", "p2", "error"])
225 self.assertEquals(self.server.error_count, 1)
226 self.assertRaises(errors.GenericError, self.mainloop.Run)
227 self.assertEquals(self.server.received,
228 ["p1", "p2", "error", "p3", "error"])
229 self.assertEquals(self.server.error_count, 2)
231 self.assertEquals(self.server.received,
232 ["p1", "p2", "error", "p3", "error", "terminate"])
233 self.assertEquals(self.server.error_count, 2)
235 def testSignaledWhileReceiving(self):
236 utils.IgnoreSignals = lambda fn, *args, **kwargs: None
237 self.client.enqueue_send("127.0.0.1", self.port, "p1")
238 self.client.enqueue_send("127.0.0.1", self.port, "p2")
239 self.server.handle_read()
240 self.assertEquals(self.server.received, [])
241 self.client.enqueue_send("127.0.0.1", self.port, "terminate")
242 utils.IgnoreSignals = self.saved_utils_ignoresignals
244 self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
def testOversizedDatagram(self):
    """Enqueueing one byte more than MAX_UDP_DATA_SIZE must be refused."""
    too_big = "a" * (constants.MAX_UDP_DATA_SIZE + 1)
    self.assertRaises(errors.UdpDataSizeError, self.client.enqueue_send,
                      "127.0.0.1", self.port, too_big)
252 class _MyAsyncStreamServer(daemon.AsyncStreamServer):
254 def __init__(self, family, address, handle_connection_fn):
255 daemon.AsyncStreamServer.__init__(self, family, address)
256 self.handle_connection_fn = handle_connection_fn
def handle_connection(self, connected_socket, client_address):
    """Forward a connection to the callback stored by the constructor."""
    callback = self.handle_connection_fn
    callback(connected_socket, client_address)
263 def handle_error(self):
264 self.error_count += 1
268 def handle_expt(self):
273 class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
275 def __init__(self, connected_socket, client_address, terminator, family,
276 message_fn, client_id):
277 daemon.AsyncTerminatedMessageStream.__init__(self, connected_socket,
280 self.message_fn = message_fn
281 self.client_id = client_id
def handle_message(self, message, message_id):
    """Pass a message on to ``message_fn``, identifying this handler.

    The callback receives this handler object first, followed by the
    message and its id.
    """
    callback = self.message_fn
    callback(self, message, message_id)
287 def handle_error(self):
288 self.error_count += 1
292 class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
293 """Test daemon.AsyncStreamServer with a TCP connection"""
295 family = socket.AF_INET
298 testutils.GanetiTestCase.setUp(self)
299 self.mainloop = daemon.Mainloop()
300 self.address = self.getAddress()
301 self.server = _MyAsyncStreamServer(self.family, self.address,
302 self.handle_connection)
303 self.client_handler = _MyMessageStreamHandler
304 self.terminator = "\3"
305 self.address = self.server.getsockname()
307 self.connections = []
309 self.connect_terminate_count = 0
310 self.message_terminate_count = 0
311 self.next_client_id = 0
312 # Save utils.IgnoreSignals so we can do evil things to it...
313 self.saved_utils_ignoresignals = utils.IgnoreSignals
316 for c in self.clients:
318 for c in self.connections:
321 # ...and restore it as well
322 utils.IgnoreSignals = self.saved_utils_ignoresignals
323 testutils.GanetiTestCase.tearDown(self)
def getAddress(self):
    """Return the address the test server is created with.

    The port is 0; setUp re-reads the actual address from the server
    socket afterwards (presumably because the OS picks a free port --
    confirm).
    """
    host, port = "127.0.0.1", 0
    return (host, port)
328 def countTerminate(self, name):
329 value = getattr(self, name)
330 if value is not None:
332 setattr(self, name, value)
334 os.kill(os.getpid(), signal.SIGTERM)
336 def handle_connection(self, connected_socket, client_address):
337 client_id = self.next_client_id
338 self.next_client_id += 1
339 client_handler = self.client_handler(connected_socket, client_address,
340 self.terminator, self.family,
343 self.connections.append(client_handler)
344 self.countTerminate("connect_terminate_count")
def handle_message(self, handler, message, message_id):
    """Record a message received by one of the connection handlers.

    Messages are stored per client in ``self.messages``, keyed by
    ``handler.client_id``. The message "error" raises
    ``errors.GenericError`` after being recorded; any other message is
    counted through ``countTerminate("message_terminate_count")``.
    """
    received = self.messages.setdefault(handler.client_id, [])
    # Strictly we only need message ids to be monotonically increasing.
    # Since these tests never remove messages from the received queue, we
    # can require that the queue length equals the message id before the
    # message is pushed. This is a more restrictive check than needed, but
    # we can live with it for now.
    # (assertEqual replaces the deprecated "assertEquals" alias.)
    self.assertEqual(len(received), message_id)
    received.append(message)
    if message == "error":
        raise errors.GenericError("error")
    self.countTerminate("message_terminate_count")
360 client = socket.socket(self.family, socket.SOCK_STREAM)
361 client.connect(self.address)
362 self.clients.append(client)
366 testutils.GanetiTestCase.tearDown(self)
369 def testConnect(self):
372 self.assertEquals(len(self.connections), 1)
375 self.assertEquals(len(self.connections), 2)
376 self.connect_terminate_count = 4
382 self.assertEquals(len(self.connections), 6)
384 def testBasicMessage(self):
385 self.connect_terminate_count = None
386 client = self.getClient()
387 client.send("ciao\3")
389 self.assertEquals(len(self.connections), 1)
390 self.assertEquals(len(self.messages[0]), 1)
391 self.assertEquals(self.messages[0][0], "ciao")
393 def testDoubleMessage(self):
394 self.connect_terminate_count = None
395 client = self.getClient()
396 client.send("ciao\3")
398 client.send("foobar\3")
400 self.assertEquals(len(self.connections), 1)
401 self.assertEquals(len(self.messages[0]), 2)
402 self.assertEquals(self.messages[0][1], "foobar")
404 def testComposedMessage(self):
405 self.connect_terminate_count = None
406 self.message_terminate_count = 3
407 client = self.getClient()
408 client.send("one\3composed\3message\3")
410 self.assertEquals(len(self.messages[0]), 3)
411 self.assertEquals(self.messages[0], ["one", "composed", "message"])
413 def testLongTerminator(self):
414 self.terminator = "\0\1\2"
415 self.connect_terminate_count = None
416 self.message_terminate_count = 3
417 client = self.getClient()
418 client.send("one\0\1\2composed\0\1\2message\0\1\2")
420 self.assertEquals(len(self.messages[0]), 3)
421 self.assertEquals(self.messages[0], ["one", "composed", "message"])
423 def testErrorHandling(self):
424 self.connect_terminate_count = None
425 self.message_terminate_count = None
426 client = self.getClient()
427 client.send("one\3two\3error\3three\3")
428 self.assertRaises(errors.GenericError, self.mainloop.Run)
429 self.assertEquals(self.connections[0].error_count, 1)
430 self.assertEquals(self.messages[0], ["one", "two", "error"])
431 client.send("error\3")
432 self.assertRaises(errors.GenericError, self.mainloop.Run)
433 self.assertEquals(self.connections[0].error_count, 2)
434 self.assertEquals(self.messages[0], ["one", "two", "error", "three",
437 def testDoubleClient(self):
438 self.connect_terminate_count = None
439 self.message_terminate_count = 2
440 client1 = self.getClient()
441 client2 = self.getClient()
442 client1.send("c1m1\3")
443 client2.send("c2m1\3")
445 self.assertEquals(self.messages[0], ["c1m1"])
446 self.assertEquals(self.messages[1], ["c2m1"])
448 def testUnterminatedMessage(self):
449 self.connect_terminate_count = None
450 self.message_terminate_count = 3
451 client1 = self.getClient()
452 client2 = self.getClient()
453 client1.send("message\3unterminated")
454 client2.send("c2m1\3c2m2\3")
456 self.assertEquals(self.messages[0], ["message"])
457 self.assertEquals(self.messages[1], ["c2m1", "c2m2"])
458 client1.send("message\3")
460 self.assertEquals(self.messages[0], ["message", "unterminatedmessage"])
462 def testSignaledWhileAccepting(self):
463 utils.IgnoreSignals = lambda fn, *args, **kwargs: None
464 client1 = self.getClient()
465 self.server.handle_accept()
466 # When interrupted while accepting we don't have a connection, but we
467 # didn't crash either.
468 self.assertEquals(len(self.connections), 0)
469 utils.IgnoreSignals = self.saved_utils_ignoresignals
471 self.assertEquals(len(self.connections), 1)
474 class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
475 """Test daemon.AsyncStreamServer with a Unix path connection"""
477 family = socket.AF_UNIX
def getAddress(self):
    """Create a temporary directory and return a socket path inside it.

    The directory is remembered in ``self.tmpdir``; the returned path is
    the file name "server.sock" inside that directory.
    """
    tmpdir = tempfile.mkdtemp()
    self.tmpdir = tmpdir
    return os.path.join(tmpdir, "server.sock")
484 shutil.rmtree(self.tmpdir)
485 TestAsyncStreamServerTCP.tearDown(self)
488 class TestAsyncAwaker(testutils.GanetiTestCase):
489 """Test daemon.AsyncAwaker"""
491 family = socket.AF_INET
494 testutils.GanetiTestCase.setUp(self)
495 self.mainloop = daemon.Mainloop()
496 self.awaker = daemon.AsyncAwaker(signal_fn=self.handle_signal)
497 self.signal_count = 0
498 self.signal_terminate_count = 1
def handle_signal(self):
    """Count one awaker callback and terminate when the budget is used up.

    Increments ``signal_count`` and decrements ``signal_terminate_count``;
    once the latter is zero or below, SIGTERM is sent to this process.
    """
    self.signal_count += 1
    self.signal_terminate_count -= 1
    if self.signal_terminate_count > 0:
        return
    os.kill(os.getpid(), signal.SIGTERM)
509 def testBasicSignaling(self):
512 self.assertEquals(self.signal_count, 1)
514 def testDoubleSignaling(self):
518 # The second signal is never delivered
519 self.assertEquals(self.signal_count, 1)
521 def testReallyDoubleSignaling(self):
522 self.assert_(self.awaker.readable())
524 # Let's suppose two threads overlap, and both find need_signal True
525 self.awaker.need_signal = True
528 # We still get only one signaling
529 self.assertEquals(self.signal_count, 1)
531 def testNoSignalFnArgument(self):
532 myawaker = daemon.AsyncAwaker()
533 self.assertRaises(socket.error, myawaker.handle_read)
535 myawaker.handle_read()
536 self.assertRaises(socket.error, myawaker.handle_read)
539 myawaker.handle_read()
540 self.assertRaises(socket.error, myawaker.handle_read)
def testWrongSignalFnArgument(self):
    """AsyncAwaker must reject 1 and "string" as its signal function.

    Each value is tried first as a positional argument and then as the
    ``signal_fn`` keyword; AssertionError is expected every time.
    """
    bad_values = (1, "string")
    for bad_fn in bad_values:
        self.assertRaises(AssertionError, daemon.AsyncAwaker, bad_fn)
    for bad_fn in bad_values:
        self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn=bad_fn)
# Run the test program only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    testutils.GanetiTestProgram()