Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.daemon_unittest.py @ 495ba852

History | View | Annotate | Download (19.7 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the daemon module"""
23

    
24
import unittest
25
import signal
26
import os
27
import socket
28
import time
29
import tempfile
30
import shutil
31

    
32
from ganeti import daemon
33
from ganeti import errors
34
from ganeti import utils
35

    
36
import testutils
37

    
38

    
39
class TestMainloop(testutils.GanetiTestCase):
40
  """Test daemon.Mainloop"""
41

    
42
  def setUp(self):
43
    testutils.GanetiTestCase.setUp(self)
44
    self.mainloop = daemon.Mainloop()
45
    self.sendsig_events = []
46
    self.onsignal_events = []
47

    
48
  def _CancelEvent(self, handle):
49
    self.mainloop.scheduler.cancel(handle)
50

    
51
  def _SendSig(self, sig):
52
    self.sendsig_events.append(sig)
53
    os.kill(os.getpid(), sig)
54

    
55
  def OnSignal(self, signum):
56
    self.onsignal_events.append(signum)
57

    
58
  def testRunAndTermBySched(self):
59
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
60
    self.mainloop.Run() # terminates by _SendSig being scheduled
61
    self.assertEquals(self.sendsig_events, [signal.SIGTERM])
62

    
63
  def testTerminatingSignals(self):
64
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
65
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGINT])
66
    self.mainloop.Run()
67
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT])
68
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
69
    self.mainloop.Run()
70
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT,
71
                                            signal.SIGTERM])
72

    
73
  def testSchedulerCancel(self):
74
    handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
75
                                           [signal.SIGTERM])
76
    self.mainloop.scheduler.cancel(handle)
77
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
78
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
79
    self.mainloop.Run()
80
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
81

    
82
  def testRegisterSignal(self):
83
    self.mainloop.RegisterSignal(self)
84
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
85
    handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
86
                                           [signal.SIGTERM])
87
    self.mainloop.scheduler.cancel(handle)
88
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
89
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
90
    # ...not delievered because they are scheduled after TERM
91
    self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
92
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
93
    self.mainloop.Run()
94
    self.assertEquals(self.sendsig_events,
95
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
96
    self.assertEquals(self.onsignal_events, self.sendsig_events)
97

    
98
  def testDeferredCancel(self):
99
    self.mainloop.RegisterSignal(self)
100
    now = time.time()
101
    self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
102
                                     [signal.SIGCHLD])
103
    handle1 = self.mainloop.scheduler.enterabs(now + 0.3, 2, self._SendSig,
104
                                               [signal.SIGCHLD])
105
    handle2 = self.mainloop.scheduler.enterabs(now + 0.4, 2, self._SendSig,
106
                                               [signal.SIGCHLD])
107
    self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
108
                                     [handle1])
109
    self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
110
                                     [handle2])
111
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGTERM])
112
    self.mainloop.Run()
113
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
114
    self.assertEquals(self.onsignal_events, self.sendsig_events)
115

    
116
  def testReRun(self):
117
    self.mainloop.RegisterSignal(self)
118
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
119
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
120
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
121
    self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
122
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
123
    self.mainloop.Run()
124
    self.assertEquals(self.sendsig_events,
125
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
126
    self.assertEquals(self.onsignal_events, self.sendsig_events)
127
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
128
    self.mainloop.Run()
129
    self.assertEquals(self.sendsig_events,
130
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM,
131
                       signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
132
    self.assertEquals(self.onsignal_events, self.sendsig_events)
133

    
134
  def testPriority(self):
135
    # for events at the same time, the highest priority one executes first
136
    now = time.time()
137
    self.mainloop.scheduler.enterabs(now + 0.1, 2, self._SendSig,
138
                                     [signal.SIGCHLD])
139
    self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
140
                                     [signal.SIGTERM])
141
    self.mainloop.Run()
142
    self.assertEquals(self.sendsig_events, [signal.SIGTERM])
143
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGTERM])
144
    self.mainloop.Run()
145
    self.assertEquals(self.sendsig_events,
146
                      [signal.SIGTERM, signal.SIGCHLD, signal.SIGTERM])
147

    
148

    
149
class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
150

    
151
  def __init__(self):
152
    daemon.AsyncUDPSocket.__init__(self)
153
    self.received = []
154
    self.error_count = 0
155

    
156
  def handle_datagram(self, payload, ip, port):
157
    self.received.append((payload))
158
    if payload == "terminate":
159
      os.kill(os.getpid(), signal.SIGTERM)
160
    elif payload == "error":
161
      raise errors.GenericError("error")
162

    
163
  def handle_error(self):
164
    self.error_count += 1
165
    raise
166

    
167

    
168
class TestAsyncUDPSocket(testutils.GanetiTestCase):
169
  """Test daemon.AsyncUDPSocket"""
170

    
171
  def setUp(self):
172
    testutils.GanetiTestCase.setUp(self)
173
    self.mainloop = daemon.Mainloop()
174
    self.server = _MyAsyncUDPSocket()
175
    self.client = _MyAsyncUDPSocket()
176
    self.server.bind(("127.0.0.1", 0))
177
    self.port = self.server.getsockname()[1]
178
    # Save utils.IgnoreSignals so we can do evil things to it...
179
    self.saved_utils_ignoresignals = utils.IgnoreSignals
180

    
181
  def tearDown(self):
182
    self.server.close()
183
    self.client.close()
184
    # ...and restore it as well
185
    utils.IgnoreSignals = self.saved_utils_ignoresignals
186
    testutils.GanetiTestCase.tearDown(self)
187

    
188
  def testNoDoubleBind(self):
189
    self.assertRaises(socket.error, self.client.bind, ("127.0.0.1", self.port))
190

    
191
  def _ThreadedClient(self, payload):
192
    self.client.enqueue_send("127.0.0.1", self.port, payload)
193
    print "sending %s" % payload
194
    while self.client.writable():
195
      self.client.handle_write()
196

    
197
  def testAsyncClientServer(self):
198
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
199
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
200
    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
201
    self.mainloop.Run()
202
    self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
203

    
204
  def testSyncClientServer(self):
205
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
206
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
207
    while self.client.writable():
208
      self.client.handle_write()
209
    self.server.process_next_packet()
210
    self.assertEquals(self.server.received, ["p1"])
211
    self.server.process_next_packet()
212
    self.assertEquals(self.server.received, ["p1", "p2"])
213
    self.client.enqueue_send("127.0.0.1", self.port, "p3")
214
    while self.client.writable():
215
      self.client.handle_write()
216
    self.server.process_next_packet()
217
    self.assertEquals(self.server.received, ["p1", "p2", "p3"])
218

    
219
  def testErrorHandling(self):
220
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
221
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
222
    self.client.enqueue_send("127.0.0.1", self.port, "error")
223
    self.client.enqueue_send("127.0.0.1", self.port, "p3")
224
    self.client.enqueue_send("127.0.0.1", self.port, "error")
225
    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
226
    self.assertRaises(errors.GenericError, self.mainloop.Run)
227
    self.assertEquals(self.server.received,
228
                      ["p1", "p2", "error"])
229
    self.assertEquals(self.server.error_count, 1)
230
    self.assertRaises(errors.GenericError, self.mainloop.Run)
231
    self.assertEquals(self.server.received,
232
                      ["p1", "p2", "error", "p3", "error"])
233
    self.assertEquals(self.server.error_count, 2)
234
    self.mainloop.Run()
235
    self.assertEquals(self.server.received,
236
                      ["p1", "p2", "error", "p3", "error", "terminate"])
237
    self.assertEquals(self.server.error_count, 2)
238

    
239
  def testSignaledWhileReceiving(self):
240
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
241
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
242
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
243
    self.server.handle_read()
244
    self.assertEquals(self.server.received, [])
245
    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
246
    utils.IgnoreSignals = self.saved_utils_ignoresignals
247
    self.mainloop.Run()
248
    self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
249

    
250

    
251
class _MyAsyncStreamServer(daemon.AsyncStreamServer):
252

    
253
  def __init__(self, family, address, handle_connection_fn):
254
    daemon.AsyncStreamServer.__init__(self, family, address)
255
    self.handle_connection_fn = handle_connection_fn
256
    self.error_count = 0
257
    self.expt_count = 0
258

    
259
  def handle_connection(self, connected_socket, client_address):
260
    self.handle_connection_fn(connected_socket, client_address)
261

    
262
  def handle_error(self):
263
    self.error_count += 1
264
    self.close()
265
    raise
266

    
267
  def handle_expt(self):
268
    self.expt_count += 1
269
    self.close()
270

    
271

    
272
class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
273

    
274
  def __init__(self, connected_socket, client_address, terminator, family,
275
               message_fn, client_id):
276
    daemon.AsyncTerminatedMessageStream.__init__(self, connected_socket,
277
                                                 client_address,
278
                                                 terminator, family)
279
    self.message_fn = message_fn
280
    self.client_id = client_id
281
    self.error_count = 0
282

    
283
  def handle_message(self, message, message_id):
284
    self.message_fn(self, message, message_id)
285

    
286
  def handle_error(self):
287
    self.error_count += 1
288
    raise
289

    
290

    
291
class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
292
  """Test daemon.AsyncStreamServer with a TCP connection"""
293

    
294
  family = socket.AF_INET
295

    
296
  def setUp(self):
297
    testutils.GanetiTestCase.setUp(self)
298
    self.mainloop = daemon.Mainloop()
299
    self.address = self.getAddress()
300
    self.server = _MyAsyncStreamServer(self.family, self.address,
301
                                       self.handle_connection)
302
    self.client_handler = _MyMessageStreamHandler
303
    self.terminator = "\3"
304
    self.address = self.server.getsockname()
305
    self.clients = []
306
    self.connections = []
307
    self.messages = {}
308
    self.connect_terminate_count = 0
309
    self.message_terminate_count = 0
310
    self.next_client_id = 0
311
    # Save utils.IgnoreSignals so we can do evil things to it...
312
    self.saved_utils_ignoresignals = utils.IgnoreSignals
313

    
314
  def tearDown(self):
315
    for c in self.clients:
316
      c.close()
317
    for c in self.connections:
318
      c.close()
319
    self.server.close()
320
    # ...and restore it as well
321
    utils.IgnoreSignals = self.saved_utils_ignoresignals
322
    testutils.GanetiTestCase.tearDown(self)
323

    
324
  def getAddress(self):
325
    return ("127.0.0.1", 0)
326

    
327
  def countTerminate(self, name):
328
    value = getattr(self, name)
329
    if value is not None:
330
      value -= 1
331
      setattr(self, name, value)
332
      if value <= 0:
333
        os.kill(os.getpid(), signal.SIGTERM)
334

    
335
  def handle_connection(self, connected_socket, client_address):
336
    client_id = self.next_client_id
337
    self.next_client_id += 1
338
    client_handler = self.client_handler(connected_socket, client_address,
339
                                         self.terminator, self.family,
340
                                         self.handle_message,
341
                                         client_id)
342
    self.connections.append(client_handler)
343
    self.countTerminate("connect_terminate_count")
344

    
345
  def handle_message(self, handler, message, message_id):
346
    self.messages.setdefault(handler.client_id, [])
347
    # We should just check that the message_ids are monotonically increasing.
348
    # If in the unit tests we never remove messages from the received queue,
349
    # though, we can just require that the queue length is the same as the
350
    # message id, before pushing the message to it. This forces a more
351
    # restrictive check, but we can live with this for now.
352
    self.assertEquals(len(self.messages[handler.client_id]), message_id)
353
    self.messages[handler.client_id].append(message)
354
    if message == "error":
355
      raise errors.GenericError("error")
356
    self.countTerminate("message_terminate_count")
357

    
358
  def getClient(self):
359
    client = socket.socket(self.family, socket.SOCK_STREAM)
360
    client.connect(self.address)
361
    self.clients.append(client)
362
    return client
363

    
364
  def tearDown(self):
365
    testutils.GanetiTestCase.tearDown(self)
366
    self.server.close()
367

    
368
  def testConnect(self):
369
    self.getClient()
370
    self.mainloop.Run()
371
    self.assertEquals(len(self.connections), 1)
372
    self.getClient()
373
    self.mainloop.Run()
374
    self.assertEquals(len(self.connections), 2)
375
    self.connect_terminate_count = 4
376
    self.getClient()
377
    self.getClient()
378
    self.getClient()
379
    self.getClient()
380
    self.mainloop.Run()
381
    self.assertEquals(len(self.connections), 6)
382

    
383
  def testBasicMessage(self):
384
    self.connect_terminate_count = None
385
    client = self.getClient()
386
    client.send("ciao\3")
387
    self.mainloop.Run()
388
    self.assertEquals(len(self.connections), 1)
389
    self.assertEquals(len(self.messages[0]), 1)
390
    self.assertEquals(self.messages[0][0], "ciao")
391

    
392
  def testDoubleMessage(self):
393
    self.connect_terminate_count = None
394
    client = self.getClient()
395
    client.send("ciao\3")
396
    self.mainloop.Run()
397
    client.send("foobar\3")
398
    self.mainloop.Run()
399
    self.assertEquals(len(self.connections), 1)
400
    self.assertEquals(len(self.messages[0]), 2)
401
    self.assertEquals(self.messages[0][1], "foobar")
402

    
403
  def testComposedMessage(self):
404
    self.connect_terminate_count = None
405
    self.message_terminate_count = 3
406
    client = self.getClient()
407
    client.send("one\3composed\3message\3")
408
    self.mainloop.Run()
409
    self.assertEquals(len(self.messages[0]), 3)
410
    self.assertEquals(self.messages[0], ["one", "composed", "message"])
411

    
412
  def testLongTerminator(self):
413
    self.terminator = "\0\1\2"
414
    self.connect_terminate_count = None
415
    self.message_terminate_count = 3
416
    client = self.getClient()
417
    client.send("one\0\1\2composed\0\1\2message\0\1\2")
418
    self.mainloop.Run()
419
    self.assertEquals(len(self.messages[0]), 3)
420
    self.assertEquals(self.messages[0], ["one", "composed", "message"])
421

    
422
  def testErrorHandling(self):
423
    self.connect_terminate_count = None
424
    self.message_terminate_count = None
425
    client = self.getClient()
426
    client.send("one\3two\3error\3three\3")
427
    self.assertRaises(errors.GenericError, self.mainloop.Run)
428
    self.assertEquals(self.connections[0].error_count, 1)
429
    self.assertEquals(self.messages[0], ["one", "two", "error"])
430
    client.send("error\3")
431
    self.assertRaises(errors.GenericError, self.mainloop.Run)
432
    self.assertEquals(self.connections[0].error_count, 2)
433
    self.assertEquals(self.messages[0], ["one", "two", "error", "three",
434
                                         "error"])
435

    
436
  def testDoubleClient(self):
437
    self.connect_terminate_count = None
438
    self.message_terminate_count = 2
439
    client1 = self.getClient()
440
    client2 = self.getClient()
441
    client1.send("c1m1\3")
442
    client2.send("c2m1\3")
443
    self.mainloop.Run()
444
    self.assertEquals(self.messages[0], ["c1m1"])
445
    self.assertEquals(self.messages[1], ["c2m1"])
446

    
447
  def testUnterminatedMessage(self):
448
    self.connect_terminate_count = None
449
    self.message_terminate_count = 3
450
    client1 = self.getClient()
451
    client2 = self.getClient()
452
    client1.send("message\3unterminated")
453
    client2.send("c2m1\3c2m2\3")
454
    self.mainloop.Run()
455
    self.assertEquals(self.messages[0], ["message"])
456
    self.assertEquals(self.messages[1], ["c2m1", "c2m2"])
457
    client1.send("message\3")
458
    self.mainloop.Run()
459
    self.assertEquals(self.messages[0], ["message", "unterminatedmessage"])
460

    
461
  def testSignaledWhileAccepting(self):
462
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
463
    client1 = self.getClient()
464
    self.server.handle_accept()
465
    # When interrupted while accepting we don't have a connection, but we
466
    # didn't crash either.
467
    self.assertEquals(len(self.connections), 0)
468
    utils.IgnoreSignals = self.saved_utils_ignoresignals
469
    self.mainloop.Run()
470
    self.assertEquals(len(self.connections), 1)
471

    
472

    
473
class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
474
  """Test daemon.AsyncStreamServer with a Unix path connection"""
475

    
476
  family = socket.AF_UNIX
477

    
478
  def getAddress(self):
479
    self.tmpdir = tempfile.mkdtemp()
480
    return os.path.join(self.tmpdir, "server.sock")
481

    
482
  def tearDown(self):
483
    shutil.rmtree(self.tmpdir)
484
    TestAsyncStreamServerTCP.tearDown(self)
485

    
486

    
487
class TestAsyncAwaker(testutils.GanetiTestCase):
488
  """Test daemon.AsyncAwaker"""
489

    
490
  family = socket.AF_INET
491

    
492
  def setUp(self):
493
    testutils.GanetiTestCase.setUp(self)
494
    self.mainloop = daemon.Mainloop()
495
    self.awaker = daemon.AsyncAwaker(signal_fn=self.handle_signal)
496
    self.signal_count = 0
497
    self.signal_terminate_count = 1
498

    
499
  def tearDown(self):
500
    self.awaker.close()
501

    
502
  def handle_signal(self):
503
    self.signal_count += 1
504
    self.signal_terminate_count -= 1
505
    if self.signal_terminate_count <= 0:
506
      os.kill(os.getpid(), signal.SIGTERM)
507

    
508
  def testBasicSignaling(self):
509
    self.awaker.signal()
510
    self.mainloop.Run()
511
    self.assertEquals(self.signal_count, 1)
512

    
513
  def testDoubleSignaling(self):
514
    self.awaker.signal()
515
    self.awaker.signal()
516
    self.mainloop.Run()
517
    # The second signal is never delivered
518
    self.assertEquals(self.signal_count, 1)
519

    
520
  def testReallyDoubleSignaling(self):
521
    self.assert_(self.awaker.readable())
522
    self.awaker.signal()
523
    # Let's suppose two threads overlap, and both find need_signal True
524
    self.awaker.need_signal = True
525
    self.awaker.signal()
526
    self.mainloop.Run()
527
    # We still get only one signaling
528
    self.assertEquals(self.signal_count, 1)
529

    
530
  def testNoSignalFnArgument(self):
531
    myawaker = daemon.AsyncAwaker()
532
    self.assertRaises(socket.error, myawaker.handle_read)
533
    myawaker.signal()
534
    myawaker.handle_read()
535
    self.assertRaises(socket.error, myawaker.handle_read)
536
    myawaker.signal()
537
    myawaker.signal()
538
    myawaker.handle_read()
539
    self.assertRaises(socket.error, myawaker.handle_read)
540
    myawaker.close()
541

    
542
  def testWrongSignalFnArgument(self):
543
    self.assertRaises(AssertionError, daemon.AsyncAwaker, 1)
544
    self.assertRaises(AssertionError, daemon.AsyncAwaker, "string")
545
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn=1)
546
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn="string")
547

    
548

    
549
if __name__ == "__main__":
550
  testutils.GanetiTestProgram()