472ceb70caa0347639a84967b4476e26d43fdb48
[ganeti-local] / test / ganeti.daemon_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the daemon module"""
23
24 import unittest
25 import signal
26 import os
27 import socket
28 import time
29 import tempfile
30 import shutil
31
32 from ganeti import daemon
33 from ganeti import errors
34 from ganeti import constants
35 from ganeti import utils
36
37 import testutils
38
39
class TestMainloop(testutils.GanetiTestCase):
  """Test daemon.Mainloop"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    # Signals sent through _SendSig, in sending order
    self.sendsig_events = []
    # Signals reported to OnSignal, in arrival order
    self.onsignal_events = []

  def _Schedule(self, delay, signum):
    """Schedule _SendSig(signum) in delay seconds, with priority 1.

    @return: the scheduler event, which can be passed to the scheduler's
        cancel method

    """
    return self.mainloop.scheduler.enter(delay, 1, self._SendSig, [signum])

  def _CancelEvent(self, handle):
    """Cancel a previously scheduled event (used as a scheduled action)."""
    self.mainloop.scheduler.cancel(handle)

  def _SendSig(self, sig):
    """Record sig and send it to the current process."""
    self.sendsig_events.append(sig)
    os.kill(os.getpid(), sig)

  def OnSignal(self, signum):
    """Record a signal reported by the mainloop (see RegisterSignal)."""
    self.onsignal_events.append(signum)

  def testRunAndTermBySched(self):
    self._Schedule(0.1, signal.SIGTERM)
    # Run() returns once the scheduled SIGTERM has been sent
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGTERM])

  def testTerminatingSignals(self):
    # SIGCHLD doesn't stop the loop, the SIGINT that follows it does
    self._Schedule(0.1, signal.SIGCHLD)
    self._Schedule(0.2, signal.SIGINT)
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT])
    # The same mainloop can be run again and stopped with SIGTERM
    self._Schedule(0.1, signal.SIGTERM)
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events,
                     [signal.SIGCHLD, signal.SIGINT, signal.SIGTERM])

  def testSchedulerCancel(self):
    cancelled = self._Schedule(0.1, signal.SIGTERM)
    self.mainloop.scheduler.cancel(cancelled)
    self._Schedule(0.2, signal.SIGCHLD)
    self._Schedule(0.3, signal.SIGTERM)
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])

  def testRegisterSignal(self):
    self.mainloop.RegisterSignal(self)
    self._Schedule(0.1, signal.SIGCHLD)
    cancelled = self._Schedule(0.1, signal.SIGTERM)
    self.mainloop.scheduler.cancel(cancelled)
    self._Schedule(0.2, signal.SIGCHLD)
    self._Schedule(0.3, signal.SIGTERM)
    # ...not delivered because they are scheduled after TERM
    for delay in (0.4, 0.5):
      self._Schedule(delay, signal.SIGCHLD)
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events,
                     [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
    self.assertEqual(self.onsignal_events, self.sendsig_events)

  def testDeferredCancel(self):
    self.mainloop.RegisterSignal(self)
    sched = self.mainloop.scheduler
    start = time.time()
    sched.enterabs(start + 0.1, 1, self._SendSig, [signal.SIGCHLD])
    late1 = sched.enterabs(start + 0.3, 2, self._SendSig, [signal.SIGCHLD])
    late2 = sched.enterabs(start + 0.4, 2, self._SendSig, [signal.SIGCHLD])
    # Both later events are cancelled from within the running loop at 0.2s
    for victim in (late1, late2):
      sched.enterabs(start + 0.2, 1, self._CancelEvent, [victim])
    self._Schedule(0.5, signal.SIGTERM)
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
    self.assertEqual(self.onsignal_events, self.sendsig_events)

  def testReRun(self):
    self.mainloop.RegisterSignal(self)
    for delay, signum in [(0.1, signal.SIGCHLD), (0.2, signal.SIGCHLD),
                          (0.3, signal.SIGTERM), (0.4, signal.SIGCHLD),
                          (0.5, signal.SIGCHLD)]:
      self._Schedule(delay, signum)
    self.mainloop.Run()
    one_round = [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM]
    self.assertEqual(self.sendsig_events, one_round)
    self.assertEqual(self.onsignal_events, self.sendsig_events)
    # The events left over after the first SIGTERM are sent during the second
    # run, before the newly scheduled SIGTERM stops it
    self._Schedule(0.3, signal.SIGTERM)
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, one_round * 2)
    self.assertEqual(self.onsignal_events, self.sendsig_events)

  def testPriority(self):
    # Of two events due at the same time, the one with the lower priority
    # value (here SIGTERM, priority 1) runs first
    start = time.time()
    self.mainloop.scheduler.enterabs(start + 0.1, 2, self._SendSig,
                                     [signal.SIGCHLD])
    self.mainloop.scheduler.enterabs(start + 0.1, 1, self._SendSig,
                                     [signal.SIGTERM])
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events, [signal.SIGTERM])
    self._Schedule(0.2, signal.SIGTERM)
    self.mainloop.Run()
    self.assertEqual(self.sendsig_events,
                     [signal.SIGTERM, signal.SIGCHLD, signal.SIGTERM])
148
149
class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
  """AsyncUDPSocket recording every datagram payload it receives.

  The payload "terminate" makes it send SIGTERM to the current process, and
  the payload "error" makes it raise errors.GenericError; both are recorded
  first.

  """
  def __init__(self):
    daemon.AsyncUDPSocket.__init__(self)
    self.received = []
    self.error_count = 0

  def handle_datagram(self, payload, ip, port):
    self.received.append(payload)
    if payload == "error":
      raise errors.GenericError("error")
    if payload == "terminate":
      os.kill(os.getpid(), signal.SIGTERM)

  def handle_error(self):
    # Count the failure, then re-raise the exception currently being handled
    self.error_count += 1
    raise
167
168
class TestAsyncUDPSocket(testutils.GanetiTestCase):
  """Test daemon.AsyncUDPSocket"""

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.server = _MyAsyncUDPSocket()
    self.client = _MyAsyncUDPSocket()
    self.server.bind(("127.0.0.1", 0))
    # Port 0 was requested above; remember the port actually bound
    self.port = self.server.getsockname()[1]
    # Keep the real utils.IgnoreSignals, some tests temporarily replace it
    self.saved_utils_ignoresignals = utils.IgnoreSignals

  def tearDown(self):
    self.server.close()
    self.client.close()
    # Put back the real utils.IgnoreSignals
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    testutils.GanetiTestCase.tearDown(self)

  def _SendToServer(self, *payloads):
    """Queue each payload on the client, addressed to the server."""
    for payload in payloads:
      self.client.enqueue_send("127.0.0.1", self.port, payload)

  def _FlushClient(self):
    """Synchronously write out everything queued on the client."""
    while self.client.writable():
      self.client.handle_write()

  def testNoDoubleBind(self):
    self.assertRaises(socket.error, self.client.bind, ("127.0.0.1", self.port))

  def testAsyncClientServer(self):
    self._SendToServer("p1", "p2", "terminate")
    self.mainloop.Run()
    self.assertEqual(self.server.received, ["p1", "p2", "terminate"])

  def testSyncClientServer(self):
    # handle_write with nothing queued should not fail
    self.client.handle_write()
    self._SendToServer("p1", "p2")
    self._FlushClient()
    expected = []
    for payload in ("p1", "p2"):
      self.server.process_next_packet()
      expected.append(payload)
      self.assertEqual(self.server.received, expected)
    self._SendToServer("p3")
    self._FlushClient()
    self.server.process_next_packet()
    self.assertEqual(self.server.received, ["p1", "p2", "p3"])

  def testErrorHandling(self):
    self._SendToServer("p1", "p2", "error", "p3", "error", "terminate")
    # Each "error" payload aborts Run(); the next Run() resumes after it
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEqual(self.server.received, ["p1", "p2", "error"])
    self.assertEqual(self.server.error_count, 1)
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEqual(self.server.received,
                     ["p1", "p2", "error", "p3", "error"])
    self.assertEqual(self.server.error_count, 2)
    self.mainloop.Run()
    self.assertEqual(self.server.received,
                     ["p1", "p2", "error", "p3", "error", "terminate"])
    self.assertEqual(self.server.error_count, 2)

  def testSignaledWhileReceiving(self):
    # Simulate the receive being interrupted: IgnoreSignals yields None
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
    self._SendToServer("p1", "p2")
    self.server.handle_read()
    self.assertEqual(self.server.received, [])
    self._SendToServer("terminate")
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    self.mainloop.Run()
    self.assertEqual(self.server.received, ["p1", "p2", "terminate"])

  def testOversizedDatagram(self):
    too_big = "a" * (constants.MAX_UDP_DATA_SIZE + 1)
    self.assertRaises(errors.UdpDataSizeError, self.client.enqueue_send,
                      "127.0.0.1", self.port, too_big)
250
251
class _MyAsyncStreamServer(daemon.AsyncStreamServer):
  """AsyncStreamServer handing every new connection to a callback.

  Errors and exceptional conditions are counted and close the server; errors
  are re-raised as well.

  """
  def __init__(self, family, address, handle_connection_fn):
    daemon.AsyncStreamServer.__init__(self, family, address)
    # Called as handle_connection_fn(connected_socket, client_address)
    self.handle_connection_fn = handle_connection_fn
    self.error_count = 0
    self.expt_count = 0

  def handle_connection(self, connected_socket, client_address):
    self.handle_connection_fn(connected_socket, client_address)

  def handle_error(self):
    self.error_count += 1
    self.close()
    # Propagate the exception currently being handled
    raise

  def handle_expt(self):
    self.expt_count += 1
    self.close()
271
272
class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
  """Message stream forwarding every complete message to a callback.

  Each instance carries a client_id so the callback can tell the streams
  apart; errors are counted and re-raised.

  """
  def __init__(self, connected_socket, client_address, terminator, family,
               message_fn, client_id):
    daemon.AsyncTerminatedMessageStream.__init__(self, connected_socket,
                                                 client_address,
                                                 terminator, family)
    # Called as message_fn(handler, message, message_id)
    self.message_fn = message_fn
    self.client_id = client_id
    self.error_count = 0

  def handle_message(self, message, message_id):
    self.message_fn(self, message, message_id)

  def handle_error(self):
    self.error_count += 1
    # Propagate the exception currently being handled
    raise
290
291
class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
  """Test daemon.AsyncStreamServer with a TCP connection"""

  family = socket.AF_INET

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.address = self.getAddress()
    self.server = _MyAsyncStreamServer(self.family, self.address,
                                       self.handle_connection)
    self.client_handler = _MyMessageStreamHandler
    self.terminator = "\3"
    # From here on use the address the server is actually bound to (for TCP
    # this resolves the port 0 returned by getAddress)
    self.address = self.server.getsockname()
    # Client sockets created by getClient
    self.clients = []
    # Server-side handlers, one per accepted connection
    self.connections = []
    # Received messages, as lists keyed by the handler's client_id
    self.messages = {}
    # Countdowns consumed by countTerminate; None disables a countdown
    self.connect_terminate_count = 0
    self.message_terminate_count = 0
    self.next_client_id = 0
    # Save utils.IgnoreSignals so we can do evil things to it...
    self.saved_utils_ignoresignals = utils.IgnoreSignals

  def tearDown(self):
    """Close all sockets and restore utils.IgnoreSignals.

    This must remain the only tearDown of the class; a second definition
    would silently replace it and skip this cleanup.

    """
    try:
      # ...and restore it as well, before anything below can fail
      utils.IgnoreSignals = self.saved_utils_ignoresignals
      for client in self.clients:
        client.close()
      for conn in self.connections:
        conn.close()
      self.server.close()
    finally:
      testutils.GanetiTestCase.tearDown(self)

  def getAddress(self):
    """Return the address the server is asked to bind to."""
    return ("127.0.0.1", 0)

  def countTerminate(self, name):
    """Count down the attribute called name, sending SIGTERM at zero.

    If the attribute is None nothing happens. Otherwise it is decremented
    and, once it is zero or less, SIGTERM is sent to the current process.

    """
    value = getattr(self, name)
    if value is None:
      return
    value -= 1
    setattr(self, name, value)
    if value <= 0:
      os.kill(os.getpid(), signal.SIGTERM)

  def handle_connection(self, connected_socket, client_address):
    """Wrap a newly accepted connection into a client_handler instance."""
    client_id = self.next_client_id
    self.next_client_id += 1
    handler = self.client_handler(connected_socket, client_address,
                                  self.terminator, self.family,
                                  self.handle_message, client_id)
    self.connections.append(handler)
    self.countTerminate("connect_terminate_count")

  def handle_message(self, handler, message, message_id):
    """Record a message received by handler.

    @raise errors.GenericError: if the message is "error" (after recording
        it)

    """
    received = self.messages.setdefault(handler.client_id, [])
    # We should just check that the message_ids are monotonically increasing.
    # If in the unit tests we never remove messages from the received queue,
    # though, we can just require that the queue length is the same as the
    # message id, before pushing the message to it. This forces a more
    # restrictive check, but we can live with this for now.
    self.assertEqual(len(received), message_id)
    received.append(message)
    if message == "error":
      raise errors.GenericError("error")
    self.countTerminate("message_terminate_count")

  def getClient(self):
    """Connect a new client socket to the server and return it."""
    client = socket.socket(self.family, socket.SOCK_STREAM)
    # Register before connecting, so tearDown closes it even if connect fails
    self.clients.append(client)
    client.connect(self.address)
    return client

  def testConnect(self):
    self.getClient()
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 1)
    self.getClient()
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 2)
    self.connect_terminate_count = 4
    for _ in range(4):
      self.getClient()
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 6)

  def testBasicMessage(self):
    self.connect_terminate_count = None
    client = self.getClient()
    client.send("ciao\3")
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 1)
    self.assertEqual(len(self.messages[0]), 1)
    self.assertEqual(self.messages[0][0], "ciao")

  def testDoubleMessage(self):
    self.connect_terminate_count = None
    client = self.getClient()
    client.send("ciao\3")
    self.mainloop.Run()
    client.send("foobar\3")
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 1)
    self.assertEqual(len(self.messages[0]), 2)
    self.assertEqual(self.messages[0][1], "foobar")

  def testComposedMessage(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client = self.getClient()
    client.send("one\3composed\3message\3")
    self.mainloop.Run()
    self.assertEqual(len(self.messages[0]), 3)
    self.assertEqual(self.messages[0], ["one", "composed", "message"])

  def testLongTerminator(self):
    self.terminator = "\0\1\2"
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client = self.getClient()
    client.send("one\0\1\2composed\0\1\2message\0\1\2")
    self.mainloop.Run()
    self.assertEqual(len(self.messages[0]), 3)
    self.assertEqual(self.messages[0], ["one", "composed", "message"])

  def testErrorHandling(self):
    self.connect_terminate_count = None
    self.message_terminate_count = None
    client = self.getClient()
    client.send("one\3two\3error\3three\3")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEqual(self.connections[0].error_count, 1)
    self.assertEqual(self.messages[0], ["one", "two", "error"])
    client.send("error\3")
    self.assertRaises(errors.GenericError, self.mainloop.Run)
    self.assertEqual(self.connections[0].error_count, 2)
    self.assertEqual(self.messages[0], ["one", "two", "error", "three",
                                        "error"])

  def testDoubleClient(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 2
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("c1m1\3")
    client2.send("c2m1\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["c1m1"])
    self.assertEqual(self.messages[1], ["c2m1"])

  def testUnterminatedMessage(self):
    self.connect_terminate_count = None
    self.message_terminate_count = 3
    client1 = self.getClient()
    client2 = self.getClient()
    client1.send("message\3unterminated")
    client2.send("c2m1\3c2m2\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["message"])
    self.assertEqual(self.messages[1], ["c2m1", "c2m2"])
    # The pending unterminated data gets prepended to the next message
    client1.send("message\3")
    self.mainloop.Run()
    self.assertEqual(self.messages[0], ["message", "unterminatedmessage"])

  def testSignaledWhileAccepting(self):
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
    self.getClient()
    self.server.handle_accept()
    # When interrupted while accepting we don't have a connection, but we
    # didn't crash either.
    self.assertEqual(len(self.connections), 0)
    utils.IgnoreSignals = self.saved_utils_ignoresignals
    self.mainloop.Run()
    self.assertEqual(len(self.connections), 1)
472
473
class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
  """Test daemon.AsyncStreamServer with a Unix path connection"""

  family = socket.AF_UNIX

  def getAddress(self):
    """Return a socket path inside a fresh temporary directory.

    The directory is remembered in self.tmpdir and removed by tearDown.

    """
    self.tmpdir = tempfile.mkdtemp()
    return os.path.join(self.tmpdir, "server.sock")

  def tearDown(self):
    # Close the sockets (including the server bound to the path) before
    # removing the directory holding the socket file, and remove the
    # directory even if the parent's cleanup raises
    try:
      TestAsyncStreamServerTCP.tearDown(self)
    finally:
      shutil.rmtree(self.tmpdir)
486
487
class TestAsyncAwaker(testutils.GanetiTestCase):
  """Test daemon.AsyncAwaker"""

  # NOTE(review): not referenced by any test in this class
  family = socket.AF_INET

  def setUp(self):
    testutils.GanetiTestCase.setUp(self)
    self.mainloop = daemon.Mainloop()
    self.awaker = daemon.AsyncAwaker(signal_fn=self.handle_signal)
    self.signal_count = 0
    # Number of handle_signal calls after which SIGTERM is sent
    self.signal_terminate_count = 1

  def tearDown(self):
    try:
      self.awaker.close()
    finally:
      # Let the base class run its own cleanup too
      testutils.GanetiTestCase.tearDown(self)

  def handle_signal(self):
    """Awaker callback: count calls, send SIGTERM when the countdown ends."""
    self.signal_count += 1
    self.signal_terminate_count -= 1
    if self.signal_terminate_count <= 0:
      os.kill(os.getpid(), signal.SIGTERM)

  def testBasicSignaling(self):
    self.awaker.signal()
    self.mainloop.Run()
    self.assertEqual(self.signal_count, 1)

  def testDoubleSignaling(self):
    self.awaker.signal()
    self.awaker.signal()
    self.mainloop.Run()
    # The second signal is never delivered
    self.assertEqual(self.signal_count, 1)

  def testReallyDoubleSignaling(self):
    self.assertTrue(self.awaker.readable())
    self.awaker.signal()
    # Let's suppose two threads overlap, and both find need_signal True
    self.awaker.need_signal = True
    self.awaker.signal()
    self.mainloop.Run()
    # We still get only one signaling
    self.assertEqual(self.signal_count, 1)

  def testNoSignalFnArgument(self):
    myawaker = daemon.AsyncAwaker()
    try:
      # handle_read fails with socket.error while no signal is pending
      self.assertRaises(socket.error, myawaker.handle_read)
      myawaker.signal()
      myawaker.handle_read()
      self.assertRaises(socket.error, myawaker.handle_read)
      # Two signals in a row are consumed by a single read
      myawaker.signal()
      myawaker.signal()
      myawaker.handle_read()
      self.assertRaises(socket.error, myawaker.handle_read)
    finally:
      myawaker.close()

  def testWrongSignalFnArgument(self):
    self.assertRaises(AssertionError, daemon.AsyncAwaker, 1)
    self.assertRaises(AssertionError, daemon.AsyncAwaker, "string")
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn=1)
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn="string")
548
549
if __name__ == "__main__":
  # Entry point when the file is executed directly
  testutils.GanetiTestProgram()