daemon.AsyncAwaker
[ganeti-local] / test / ganeti.daemon_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the daemon module"""
23
24 import unittest
25 import signal
26 import os
27 import socket
28 import time
29 import tempfile
30 import shutil
31
32 from ganeti import daemon
33 from ganeti import errors
34 from ganeti import utils
35
36 import testutils
37
38
class TestMainloop(testutils.GanetiTestCase):
  """Tests for daemon.Mainloop: scheduling, cancellation and signals."""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    # Signals sent by _SendSig, in sending order
    self.sendsig_events = []
    # Signals delivered to OnSignal (only after RegisterSignal(self))
    self.onsignal_events = []

  def _ScheduleSig(self, delay, sig):
    """Schedule _SendSig(sig) in delay seconds with priority 1.

    Returns the scheduler handle so the caller can cancel the event.

    """
    return self.mainloop.scheduler.enter(delay, 1, self._SendSig, [sig])

  def _CancelEvent(self, handle):
    scheduler = self.mainloop.scheduler
    scheduler.cancel(handle)

  def _SendSig(self, sig):
    self.sendsig_events.append(sig)
    own_pid = os.getpid()
    os.kill(own_pid, sig)

  def OnSignal(self, signum):
    self.onsignal_events.append(signum)

  def testRunAndTermBySched(self):
    self._ScheduleSig(0.1, signal.SIGTERM)
    # Run() returns only because the scheduled SIGTERM stops it
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGTERM])

  def testTerminatingSignals(self):
    for delay, sig in ((0.1, signal.SIGCHLD), (0.2, signal.SIGINT)):
      self._ScheduleSig(delay, sig)
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT])
    self._ScheduleSig(0.1, signal.SIGTERM)
    self.mainloop.Run()
    expected = [signal.SIGCHLD, signal.SIGINT, signal.SIGTERM]
    self.assertEquals(self.sendsig_events, expected)

  def testSchedulerCancel(self):
    cancelled = self._ScheduleSig(0.1, signal.SIGTERM)
    self._CancelEvent(cancelled)
    self._ScheduleSig(0.2, signal.SIGCHLD)
    self._ScheduleSig(0.3, signal.SIGTERM)
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])

  def testRegisterSignal(self):
    self.mainloop.RegisterSignal(self)
    self._ScheduleSig(0.1, signal.SIGCHLD)
    self._CancelEvent(self._ScheduleSig(0.1, signal.SIGTERM))
    self._ScheduleSig(0.2, signal.SIGCHLD)
    self._ScheduleSig(0.3, signal.SIGTERM)
    # ...not delivered because they are scheduled after TERM
    self._ScheduleSig(0.4, signal.SIGCHLD)
    self._ScheduleSig(0.5, signal.SIGCHLD)
    self.mainloop.Run()
    expected = [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM]
    self.assertEquals(self.sendsig_events, expected)
    self.assertEquals(self.onsignal_events, self.sendsig_events)

  def testDeferredCancel(self):
    self.mainloop.RegisterSignal(self)
    scheduler = self.mainloop.scheduler
    start = time.time()
    scheduler.enterabs(start + 0.1, 1, self._SendSig, [signal.SIGCHLD])
    late_handles = []
    for offset in (0.3, 0.4):
      late_handles.append(scheduler.enterabs(start + offset, 2, self._SendSig,
                                             [signal.SIGCHLD]))
    # Both late SIGCHLD events are cancelled from inside the running loop
    for handle in late_handles:
      scheduler.enterabs(start + 0.2, 1, self._CancelEvent, [handle])
    self._ScheduleSig(0.5, signal.SIGTERM)
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
    self.assertEquals(self.onsignal_events, self.sendsig_events)

  def testReRun(self):
    self.mainloop.RegisterSignal(self)
    plan = [
      (0.1, signal.SIGCHLD),
      (0.2, signal.SIGCHLD),
      (0.3, signal.SIGTERM),
      (0.4, signal.SIGCHLD),
      (0.5, signal.SIGCHLD),
      ]
    for delay, sig in plan:
      self._ScheduleSig(delay, sig)
    self.mainloop.Run()
    one_round = [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM]
    self.assertEquals(self.sendsig_events, one_round)
    self.assertEquals(self.onsignal_events, self.sendsig_events)
    # The two SIGCHLD events left over from the first run fire now
    self._ScheduleSig(0.3, signal.SIGTERM)
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, one_round * 2)
    self.assertEquals(self.onsignal_events, self.sendsig_events)

  def testPriority(self):
    # For events at the same time, the highest priority one (lowest priority
    # value) executes first
    scheduler = self.mainloop.scheduler
    when = time.time() + 0.1
    scheduler.enterabs(when, 2, self._SendSig, [signal.SIGCHLD])
    scheduler.enterabs(when, 1, self._SendSig, [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events, [signal.SIGTERM])
    self._ScheduleSig(0.2, signal.SIGTERM)
    self.mainloop.Run()
    self.assertEquals(self.sendsig_events,
                      [signal.SIGTERM, signal.SIGCHLD, signal.SIGTERM])
147
148
class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
  """UDP socket that records every datagram payload it receives.

  Two payloads are special: "terminate" sends SIGTERM to this process (used
  by the tests to stop Mainloop.Run), and "error" raises errors.GenericError.
  Both are recorded before acting on them.

  """

  def __init__(self):
    daemon.AsyncUDPSocket.__init__(self)
    self.received = []
    self.error_count = 0

  def handle_datagram(self, payload, ip, port):
    self.received.append(payload)
    if payload == "error":
      raise errors.GenericError("error")
    if payload == "terminate":
      os.kill(os.getpid(), signal.SIGTERM)

  def handle_error(self):
    # Count the failure, then re-raise the exception currently being handled
    self.error_count = self.error_count + 1
    raise
166
167
class TestAsyncUDPSocket(testutils.GanetiTestCase):
  """Test daemon.AsyncUDPSocket"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.server = _MyAsyncUDPSocket()
    self.client = _MyAsyncUDPSocket()
    # Bind to port 0 and read back the port actually assigned
    self.server.bind(("127.0.0.1", 0))
    self.port = self.server.getsockname()[1]
    # Save utils.IgnoreSignals so we can do evil things to it...
    self.saved_utils_ignoresignals = utils.IgnoreSignals

  def tearDown(self):
    try:
      self.server.close()
      self.client.close()
    finally:
      # ...and restore it as well, even if closing a socket failed
      utils.IgnoreSignals = self.saved_utils_ignoresignals
      testutils.GanetiTestCase.tearDown(self)

  def testNoDoubleBind(self):
    self.assertRaises(socket.error, self.client.bind, ("127.0.0.1", self.port))

  def _ThreadedClient(self, payload):
    """Queue payload for the server and flush the client's send queue.

    NOTE(review): no test in this file calls this helper.

    """
    self.client.enqueue_send("127.0.0.1", self.port, payload)
    while self.client.writable():
      self.client.handle_write()

  def testAsyncClientServer(self):
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
    self.mainloop.Run()
    self.assertEquals(self.server.received, ["p1", "p2", "terminate"])

  def testSyncClientServer(self):
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
    while self.client.writable():
      self.client.handle_write()
    self.server.process_next_packet()
    self.assertEquals(self.server.received, ["p1"])
    self.server.process_next_packet()
    self.assertEquals(self.server.received, ["p1", "p2"])
    self.client.enqueue_send("127.0.0.1", self.port, "p3")
    while self.client.writable():
      self.client.handle_write()
    self.server.process_next_packet()
    self.assertEquals(self.server.received, ["p1", "p2", "p3"])

  def testErrorHandling(self):
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
    self.client.enqueue_send("127.0.0.1", self.port, "error")
    self.client.enqueue_send("127.0.0.1", self.port, "p3")
    self.client.enqueue_send("127.0.0.1", self.port, "error")
    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
    # Each "error" payload aborts Run(); the next Run() resumes after it
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEquals(self.server.received,
                      ["p1", "p2", "error"])
    self.assertEquals(self.server.error_count, 1)
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEquals(self.server.received,
                      ["p1", "p2", "error", "p3", "error"])
    self.assertEquals(self.server.error_count, 2)
    self.mainloop.Run()
    self.assertEquals(self.server.received,
                      ["p1", "p2", "error", "p3", "error", "terminate"])
    self.assertEquals(self.server.error_count, 2)

  def testSignaledWhileReceiving(self):
    # Make IgnoreSignals return None, as if the receive had been interrupted
    # by a signal; the server must then record nothing
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
    self.server.handle_read()
    self.assertEquals(self.server.received, [])
    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    self.mainloop.Run()
    self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
249
250
class _MyAsyncStreamServer(daemon.AsyncStreamServer):
  """Stream server that hands every accepted connection to a callback.

  Errors and exceptional conditions are counted and close the server; errors
  are additionally re-raised to the caller.

  """

  def __init__(self, family, address, handle_connection_fn):
    daemon.AsyncStreamServer.__init__(self, family, address)
    self.handle_connection_fn = handle_connection_fn
    self.error_count = 0
    self.expt_count = 0

  def handle_connection(self, connected_socket, client_address):
    callback = self.handle_connection_fn
    callback(connected_socket, client_address)

  def handle_error(self):
    self.error_count = self.error_count + 1
    self.close()
    # Re-raise the exception currently being handled
    raise

  def handle_expt(self):
    self.expt_count = self.expt_count + 1
    self.close()
270
271
class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
  """Terminated-message stream that forwards each message to a callback.

  The callback is invoked as message_fn(handler, message, message_id), so it
  can tell connections apart through handler.client_id. Errors are counted
  and re-raised.

  """

  def __init__(self, connected_socket, client_address, terminator, family,
               message_fn, client_id):
    daemon.AsyncTerminatedMessageStream.__init__(self, connected_socket,
                                                 client_address,
                                                 terminator, family)
    self.message_fn = message_fn
    self.client_id = client_id
    self.error_count = 0

  def handle_message(self, message, message_id):
    forward = self.message_fn
    forward(self, message, message_id)

  def handle_error(self):
    self.error_count = self.error_count + 1
    # Re-raise the exception currently being handled
    raise
289
290
class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
  """Test daemon.AsyncStreamServer with a TCP connection"""

  family = socket.AF_INET

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.address = self.getAddress()
    self.server = _MyAsyncStreamServer(self.family, self.address,
                                       self.handle_connection)
    self.client_handler = _MyMessageStreamHandler
    self.terminator = "\3"
    # Replace the requested address with the one actually bound (for TCP the
    # requested port is 0)
    self.address = self.server.getsockname()
    self.clients = []
    self.connections = []
    self.messages = {}
    # Countdowns consumed by countTerminate; None disables a countdown
    self.connect_terminate_count = 0
    self.message_terminate_count = 0
    self.next_client_id = 0
    # Save utils.IgnoreSignals so we can do evil things to it...
    self.saved_utils_ignoresignals = utils.IgnoreSignals

  def tearDown(self):
    # This must be the only tearDown in the class: a second definition would
    # silently replace it, leaving client sockets and connection handlers
    # open across tests and utils.IgnoreSignals unrestored.
    try:
      for c in self.clients:
        c.close()
      for c in self.connections:
        c.close()
      self.server.close()
    finally:
      # ...and restore it as well
      utils.IgnoreSignals = self.saved_utils_ignoresignals
      testutils.GanetiTestCase.tearDown(self)

  def getAddress(self):
    """Return the address the server binds to (overridden by subclasses)."""
    return ("127.0.0.1", 0)

  def countTerminate(self, name):
    """Decrement counter attribute name; SIGTERM ourselves once it is <= 0.

    A counter set to None is left untouched and never triggers termination.

    """
    value = getattr(self, name)
    if value is not None:
      value -= 1
      setattr(self, name, value)
      if value <= 0:
        os.kill(os.getpid(), signal.SIGTERM)

  def handle_connection(self, connected_socket, client_address):
    client_id = self.next_client_id
    self.next_client_id += 1
    client_handler = self.client_handler(connected_socket, client_address,
                                         self.terminator, self.family,
                                         self.handle_message,
                                         client_id)
    self.connections.append(client_handler)
    self.countTerminate("connect_terminate_count")

  def handle_message(self, handler, message, message_id):
    self.messages.setdefault(handler.client_id, [])
    # We should just check that the message_ids are monotonically increasing.
    # If in the unit tests we never remove messages from the received queue,
    # though, we can just require that the queue length is the same as the
    # message id, before pushing the message to it. This forces a more
    # restrictive check, but we can live with this for now.
    self.assertEquals(len(self.messages[handler.client_id]), message_id)
    self.messages[handler.client_id].append(message)
    if message == "error":
      raise errors.GenericError("error")
    self.countTerminate("message_terminate_count")

  def getClient(self):
    """Connect a new client socket to the server; closed in tearDown."""
    client = socket.socket(self.family, socket.SOCK_STREAM)
    client.connect(self.address)
    self.clients.append(client)
    return client

  def testConnect(self):
    self.getClient()
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 1)
    self.getClient()
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 2)
    self.connect_terminate_count = 4
    self.getClient()
    self.getClient()
    self.getClient()
    self.getClient()
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 6)

  def testBasicMessage(self):
    self.connect_terminate_count = None
    client = self.getClient()
    client.send("ciao\3")
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 1)
    self.assertEquals(len(self.messages[0]), 1)
    self.assertEquals(self.messages[0][0], "ciao")

  def testDoubleMessage(self):
    self.connect_terminate_count = None
    client = self.getClient()
    client.send("ciao\3")
    self.mainloop.Run()
    client.send("foobar\3")
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 1)
    self.assertEquals(len(self.messages[0]), 2)
    self.assertEquals(self.messages[0][1], "foobar")

  def testComposedMessage(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client = self.getClient()
    client.send("one\3composed\3message\3")
    self.mainloop.Run()
    self.assertEquals(len(self.messages[0]), 3)
    self.assertEquals(self.messages[0], ["one", "composed", "message"])

  def testLongTerminator(self):
    self.terminator = "\0\1\2"
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client = self.getClient()
    client.send("one\0\1\2composed\0\1\2message\0\1\2")
    self.mainloop.Run()
    self.assertEquals(len(self.messages[0]), 3)
    self.assertEquals(self.messages[0], ["one", "composed", "message"])

  def testErrorHandling(self):
    self.connect_terminate_count = None
    self.message_terminate_count = None
    client = self.getClient()
    client.send("one\3two\3error\3three\3")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEquals(self.connections[0].error_count, 1)
    self.assertEquals(self.messages[0], ["one", "two", "error"])
    client.send("error\3")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEquals(self.connections[0].error_count, 2)
    self.assertEquals(self.messages[0], ["one", "two", "error", "three",
                                         "error"])

  def testDoubleClient(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 2
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("c1m1\3")
    client2.send("c2m1\3")
    self.mainloop.Run()
    self.assertEquals(self.messages[0], ["c1m1"])
    self.assertEquals(self.messages[1], ["c2m1"])

  def testUnterminatedMessage(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("message\3unterminated")
    client2.send("c2m1\3c2m2\3")
    self.mainloop.Run()
    self.assertEquals(self.messages[0], ["message"])
    self.assertEquals(self.messages[1], ["c2m1", "c2m2"])
    client1.send("message\3")
    self.mainloop.Run()
    self.assertEquals(self.messages[0], ["message", "unterminatedmessage"])

  def testSignaledWhileAccepting(self):
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
    self.getClient()
    self.server.handle_accept()
    # When interrupted while accepting we don't have a connection, but we
    # didn't crash either.
    self.assertEquals(len(self.connections), 0)
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    self.mainloop.Run()
    self.assertEquals(len(self.connections), 1)
471
472
class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
  """Test daemon.AsyncStreamServer with a Unix path connection"""

  family = socket.AF_UNIX

  def getAddress(self):
    """Create a private temporary directory and return a socket path in it."""
    self.tmpdir = tempfile.mkdtemp()
    return os.path.join(self.tmpdir, "server.sock")

  def tearDown(self):
    # Close all sockets and restore patched state first. Then remove the
    # directory holding the socket file, even if the parent tearDown failed.
    # A failing rmtree can no longer skip closing the sockets.
    try:
      TestAsyncStreamServerTCP.tearDown(self)
    finally:
      shutil.rmtree(self.tmpdir)
485
486
class TestAsyncAwaker(testutils.GanetiTestCase):
  """Test daemon.AsyncAwaker"""

  family = socket.AF_INET

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.awaker = daemon.AsyncAwaker(signal_fn=self.handle_signal)
    self.signal_count = 0
    # handle_signal sends SIGTERM once this many signals were handled
    self.signal_terminate_count = 1

  def tearDown(self):
    # Always chain to the base class, like the other test cases here do
    try:
      self.awaker.close()
    finally:
      testutils.GanetiTestCase.tearDown(self)

  def handle_signal(self):
    """signal_fn callback: count calls and SIGTERM ourselves when done."""
    self.signal_count += 1
    self.signal_terminate_count -= 1
    if self.signal_terminate_count <= 0:
      os.kill(os.getpid(), signal.SIGTERM)

  def testBasicSignaling(self):
    self.awaker.signal()
    self.mainloop.Run()
    self.assertEquals(self.signal_count, 1)

  def testDoubleSignaling(self):
    self.awaker.signal()
    self.awaker.signal()
    self.mainloop.Run()
    # The second signal is never delivered
    self.assertEquals(self.signal_count, 1)

  def testReallyDoubleSignaling(self):
    self.assert_(self.awaker.readable())
    self.awaker.signal()
    # Let's suppose two threads overlap, and both find need_signal True
    self.awaker.need_signal = True
    self.awaker.signal()
    self.mainloop.Run()
    # We still get only one signaling
    self.assertEquals(self.signal_count, 1)

  def testNoSignalFnArgument(self):
    myawaker = daemon.AsyncAwaker()
    try:
      # Nothing was signalled yet, so handle_read raises socket.error
      self.assertRaises(socket.error, myawaker.handle_read)
      myawaker.signal()
      myawaker.handle_read()
      self.assertRaises(socket.error, myawaker.handle_read)
      # Two signals in a row are consumed by a single handle_read
      myawaker.signal()
      myawaker.signal()
      myawaker.handle_read()
      self.assertRaises(socket.error, myawaker.handle_read)
    finally:
      # Don't leak the awaker if one of the assertions above fails
      myawaker.close()

  def testWrongSignalFnArgument(self):
    self.assertRaises(AssertionError, daemon.AsyncAwaker, 1)
    self.assertRaises(AssertionError, daemon.AsyncAwaker, "string")
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn=1)
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn="string")
547
548
def _RunTests():
  """Hand control to testutils.GanetiTestProgram to run this module."""
  testutils.GanetiTestProgram()


if __name__ == "__main__":
  _RunTests()