Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.daemon_unittest.py @ 14933c17

History | View | Annotate | Download (24.3 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the daemon module"""
23

    
24
import unittest
25
import signal
26
import os
27
import socket
28
import time
29
import tempfile
30
import shutil
31

    
32
from ganeti import daemon
33
from ganeti import errors
34
from ganeti import constants
35
from ganeti import utils
36

    
37
import testutils
38

    
39

    
40
class TestMainloop(testutils.GanetiTestCase):
41
  """Test daemon.Mainloop"""
42

    
43
  def setUp(self):
44
    testutils.GanetiTestCase.setUp(self)
45
    self.mainloop = daemon.Mainloop()
46
    self.sendsig_events = []
47
    self.onsignal_events = []
48

    
49
  def _CancelEvent(self, handle):
50
    self.mainloop.scheduler.cancel(handle)
51

    
52
  def _SendSig(self, sig):
53
    self.sendsig_events.append(sig)
54
    os.kill(os.getpid(), sig)
55

    
56
  def OnSignal(self, signum):
57
    self.onsignal_events.append(signum)
58

    
59
  def testRunAndTermBySched(self):
60
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
61
    self.mainloop.Run() # terminates by _SendSig being scheduled
62
    self.assertEquals(self.sendsig_events, [signal.SIGTERM])
63

    
64
  def testTerminatingSignals(self):
65
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
66
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGINT])
67
    self.mainloop.Run()
68
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT])
69
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
70
    self.mainloop.Run()
71
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT,
72
                                            signal.SIGTERM])
73

    
74
  def testSchedulerCancel(self):
75
    handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
76
                                           [signal.SIGTERM])
77
    self.mainloop.scheduler.cancel(handle)
78
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
79
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
80
    self.mainloop.Run()
81
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
82

    
83
  def testRegisterSignal(self):
84
    self.mainloop.RegisterSignal(self)
85
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
86
    handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
87
                                           [signal.SIGTERM])
88
    self.mainloop.scheduler.cancel(handle)
89
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
90
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
91
    # ...not delievered because they are scheduled after TERM
92
    self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
93
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
94
    self.mainloop.Run()
95
    self.assertEquals(self.sendsig_events,
96
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
97
    self.assertEquals(self.onsignal_events, self.sendsig_events)
98

    
99
  def testDeferredCancel(self):
100
    self.mainloop.RegisterSignal(self)
101
    now = time.time()
102
    self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
103
                                     [signal.SIGCHLD])
104
    handle1 = self.mainloop.scheduler.enterabs(now + 0.3, 2, self._SendSig,
105
                                               [signal.SIGCHLD])
106
    handle2 = self.mainloop.scheduler.enterabs(now + 0.4, 2, self._SendSig,
107
                                               [signal.SIGCHLD])
108
    self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
109
                                     [handle1])
110
    self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
111
                                     [handle2])
112
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGTERM])
113
    self.mainloop.Run()
114
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
115
    self.assertEquals(self.onsignal_events, self.sendsig_events)
116

    
117
  def testReRun(self):
118
    self.mainloop.RegisterSignal(self)
119
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
120
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
121
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
122
    self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
123
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
124
    self.mainloop.Run()
125
    self.assertEquals(self.sendsig_events,
126
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
127
    self.assertEquals(self.onsignal_events, self.sendsig_events)
128
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
129
    self.mainloop.Run()
130
    self.assertEquals(self.sendsig_events,
131
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM,
132
                       signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
133
    self.assertEquals(self.onsignal_events, self.sendsig_events)
134

    
135
  def testPriority(self):
136
    # for events at the same time, the highest priority one executes first
137
    now = time.time()
138
    self.mainloop.scheduler.enterabs(now + 0.1, 2, self._SendSig,
139
                                     [signal.SIGCHLD])
140
    self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
141
                                     [signal.SIGTERM])
142
    self.mainloop.Run()
143
    self.assertEquals(self.sendsig_events, [signal.SIGTERM])
144
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGTERM])
145
    self.mainloop.Run()
146
    self.assertEquals(self.sendsig_events,
147
                      [signal.SIGTERM, signal.SIGCHLD, signal.SIGTERM])
148

    
149

    
150
class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
151

    
152
  def __init__(self, family):
153
    daemon.AsyncUDPSocket.__init__(self, family)
154
    self.received = []
155
    self.error_count = 0
156

    
157
  def handle_datagram(self, payload, ip, port):
158
    self.received.append((payload))
159
    if payload == "terminate":
160
      os.kill(os.getpid(), signal.SIGTERM)
161
    elif payload == "error":
162
      raise errors.GenericError("error")
163

    
164
  def handle_error(self):
165
    self.error_count += 1
166
    raise
167

    
168

    
169
class _BaseAsyncUDPSocketTest:
170
  """Base class for  AsyncUDPSocket tests"""
171

    
172
  family = None
173
  address = None
174

    
175
  def setUp(self):
176
    self.mainloop = daemon.Mainloop()
177
    self.server = _MyAsyncUDPSocket(self.family)
178
    self.client = _MyAsyncUDPSocket(self.family)
179
    self.server.bind((self.address, 0))
180
    self.port = self.server.getsockname()[1]
181
    # Save utils.IgnoreSignals so we can do evil things to it...
182
    self.saved_utils_ignoresignals = utils.IgnoreSignals
183

    
184
  def tearDown(self):
185
    self.server.close()
186
    self.client.close()
187
    # ...and restore it as well
188
    utils.IgnoreSignals = self.saved_utils_ignoresignals
189
    testutils.GanetiTestCase.tearDown(self)
190

    
191
  def testNoDoubleBind(self):
192
    self.assertRaises(socket.error, self.client.bind, (self.address, self.port))
193

    
194
  def testAsyncClientServer(self):
195
    self.client.enqueue_send(self.address, self.port, "p1")
196
    self.client.enqueue_send(self.address, self.port, "p2")
197
    self.client.enqueue_send(self.address, self.port, "terminate")
198
    self.mainloop.Run()
199
    self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
200

    
201
  def testSyncClientServer(self):
202
    self.client.handle_write()
203
    self.client.enqueue_send(self.address, self.port, "p1")
204
    self.client.enqueue_send(self.address, self.port, "p2")
205
    while self.client.writable():
206
      self.client.handle_write()
207
    self.server.process_next_packet()
208
    self.assertEquals(self.server.received, ["p1"])
209
    self.server.process_next_packet()
210
    self.assertEquals(self.server.received, ["p1", "p2"])
211
    self.client.enqueue_send(self.address, self.port, "p3")
212
    while self.client.writable():
213
      self.client.handle_write()
214
    self.server.process_next_packet()
215
    self.assertEquals(self.server.received, ["p1", "p2", "p3"])
216

    
217
  def testErrorHandling(self):
218
    self.client.enqueue_send(self.address, self.port, "p1")
219
    self.client.enqueue_send(self.address, self.port, "p2")
220
    self.client.enqueue_send(self.address, self.port, "error")
221
    self.client.enqueue_send(self.address, self.port, "p3")
222
    self.client.enqueue_send(self.address, self.port, "error")
223
    self.client.enqueue_send(self.address, self.port, "terminate")
224
    self.assertRaises(errors.GenericError, self.mainloop.Run)
225
    self.assertEquals(self.server.received,
226
                      ["p1", "p2", "error"])
227
    self.assertEquals(self.server.error_count, 1)
228
    self.assertRaises(errors.GenericError, self.mainloop.Run)
229
    self.assertEquals(self.server.received,
230
                      ["p1", "p2", "error", "p3", "error"])
231
    self.assertEquals(self.server.error_count, 2)
232
    self.mainloop.Run()
233
    self.assertEquals(self.server.received,
234
                      ["p1", "p2", "error", "p3", "error", "terminate"])
235
    self.assertEquals(self.server.error_count, 2)
236

    
237
  def testSignaledWhileReceiving(self):
238
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
239
    self.client.enqueue_send(self.address, self.port, "p1")
240
    self.client.enqueue_send(self.address, self.port, "p2")
241
    self.server.handle_read()
242
    self.assertEquals(self.server.received, [])
243
    self.client.enqueue_send(self.address, self.port, "terminate")
244
    utils.IgnoreSignals = self.saved_utils_ignoresignals
245
    self.mainloop.Run()
246
    self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
247

    
248
  def testOversizedDatagram(self):
249
    oversized_data = (constants.MAX_UDP_DATA_SIZE + 1) * "a"
250
    self.assertRaises(errors.UdpDataSizeError, self.client.enqueue_send,
251
                      self.address, self.port, oversized_data)
252

    
253

    
254
class TestAsyncIP4UDPSocket(testutils.GanetiTestCase, _BaseAsyncUDPSocketTest):
255
  """Test IP4 daemon.AsyncUDPSocket"""
256

    
257
  family = socket.AF_INET
258
  address = "127.0.0.1"
259

    
260
  def setUp(self):
261
    testutils.GanetiTestCase.setUp(self)
262
    _BaseAsyncUDPSocketTest.setUp(self)
263

    
264
  def tearDown(self):
265
    testutils.GanetiTestCase.tearDown(self)
266
    _BaseAsyncUDPSocketTest.tearDown(self)
267

    
268

    
269
class TestAsyncIP6UDPSocket(testutils.GanetiTestCase, _BaseAsyncUDPSocketTest):
270
  """Test IP6 daemon.AsyncUDPSocket"""
271

    
272
  family = socket.AF_INET6
273
  address = "::1"
274

    
275
  def setUp(self):
276
    testutils.GanetiTestCase.setUp(self)
277
    _BaseAsyncUDPSocketTest.setUp(self)
278

    
279
  def tearDown(self):
280
    testutils.GanetiTestCase.tearDown(self)
281
    _BaseAsyncUDPSocketTest.tearDown(self)
282

    
283

    
284
class _MyAsyncStreamServer(daemon.AsyncStreamServer):
285

    
286
  def __init__(self, family, address, handle_connection_fn):
287
    daemon.AsyncStreamServer.__init__(self, family, address)
288
    self.handle_connection_fn = handle_connection_fn
289
    self.error_count = 0
290
    self.expt_count = 0
291

    
292
  def handle_connection(self, connected_socket, client_address):
293
    self.handle_connection_fn(connected_socket, client_address)
294

    
295
  def handle_error(self):
296
    self.error_count += 1
297
    self.close()
298
    raise
299

    
300
  def handle_expt(self):
301
    self.expt_count += 1
302
    self.close()
303

    
304

    
305
class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
306

    
307
  def __init__(self, connected_socket, client_address, terminator, family,
308
               message_fn, client_id, unhandled_limit):
309
    daemon.AsyncTerminatedMessageStream.__init__(self, connected_socket,
310
                                                 client_address,
311
                                                 terminator, family,
312
                                                 unhandled_limit)
313
    self.message_fn = message_fn
314
    self.client_id = client_id
315
    self.error_count = 0
316

    
317
  def handle_message(self, message, message_id):
318
    self.message_fn(self, message, message_id)
319

    
320
  def handle_error(self):
321
    self.error_count += 1
322
    raise
323

    
324

    
325
class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
326
  """Test daemon.AsyncStreamServer with a TCP connection"""
327

    
328
  family = socket.AF_INET
329

    
330
  def setUp(self):
331
    testutils.GanetiTestCase.setUp(self)
332
    self.mainloop = daemon.Mainloop()
333
    self.address = self.getAddress()
334
    self.server = _MyAsyncStreamServer(self.family, self.address,
335
                                       self.handle_connection)
336
    self.client_handler = _MyMessageStreamHandler
337
    self.unhandled_limit = None
338
    self.terminator = "\3"
339
    self.address = self.server.getsockname()
340
    self.clients = []
341
    self.connections = []
342
    self.messages = {}
343
    self.connect_terminate_count = 0
344
    self.message_terminate_count = 0
345
    self.next_client_id = 0
346
    # Save utils.IgnoreSignals so we can do evil things to it...
347
    self.saved_utils_ignoresignals = utils.IgnoreSignals
348

    
349
  def tearDown(self):
350
    for c in self.clients:
351
      c.close()
352
    for c in self.connections:
353
      c.close()
354
    self.server.close()
355
    # ...and restore it as well
356
    utils.IgnoreSignals = self.saved_utils_ignoresignals
357
    testutils.GanetiTestCase.tearDown(self)
358

    
359
  def getAddress(self):
360
    return ("127.0.0.1", 0)
361

    
362
  def countTerminate(self, name):
363
    value = getattr(self, name)
364
    if value is not None:
365
      value -= 1
366
      setattr(self, name, value)
367
      if value <= 0:
368
        os.kill(os.getpid(), signal.SIGTERM)
369

    
370
  def handle_connection(self, connected_socket, client_address):
371
    client_id = self.next_client_id
372
    self.next_client_id += 1
373
    client_handler = self.client_handler(connected_socket, client_address,
374
                                         self.terminator, self.family,
375
                                         self.handle_message,
376
                                         client_id, self.unhandled_limit)
377
    self.connections.append(client_handler)
378
    self.countTerminate("connect_terminate_count")
379

    
380
  def handle_message(self, handler, message, message_id):
381
    self.messages.setdefault(handler.client_id, [])
382
    # We should just check that the message_ids are monotonically increasing.
383
    # If in the unit tests we never remove messages from the received queue,
384
    # though, we can just require that the queue length is the same as the
385
    # message id, before pushing the message to it. This forces a more
386
    # restrictive check, but we can live with this for now.
387
    self.assertEquals(len(self.messages[handler.client_id]), message_id)
388
    self.messages[handler.client_id].append(message)
389
    if message == "error":
390
      raise errors.GenericError("error")
391
    self.countTerminate("message_terminate_count")
392

    
393
  def getClient(self):
394
    client = socket.socket(self.family, socket.SOCK_STREAM)
395
    client.connect(self.address)
396
    self.clients.append(client)
397
    return client
398

    
399
  def tearDown(self):
400
    testutils.GanetiTestCase.tearDown(self)
401
    self.server.close()
402

    
403
  def testConnect(self):
404
    self.getClient()
405
    self.mainloop.Run()
406
    self.assertEquals(len(self.connections), 1)
407
    self.getClient()
408
    self.mainloop.Run()
409
    self.assertEquals(len(self.connections), 2)
410
    self.connect_terminate_count = 4
411
    self.getClient()
412
    self.getClient()
413
    self.getClient()
414
    self.getClient()
415
    self.mainloop.Run()
416
    self.assertEquals(len(self.connections), 6)
417

    
418
  def testBasicMessage(self):
419
    self.connect_terminate_count = None
420
    client = self.getClient()
421
    client.send("ciao\3")
422
    self.mainloop.Run()
423
    self.assertEquals(len(self.connections), 1)
424
    self.assertEquals(len(self.messages[0]), 1)
425
    self.assertEquals(self.messages[0][0], "ciao")
426

    
427
  def testDoubleMessage(self):
428
    self.connect_terminate_count = None
429
    client = self.getClient()
430
    client.send("ciao\3")
431
    self.mainloop.Run()
432
    client.send("foobar\3")
433
    self.mainloop.Run()
434
    self.assertEquals(len(self.connections), 1)
435
    self.assertEquals(len(self.messages[0]), 2)
436
    self.assertEquals(self.messages[0][1], "foobar")
437

    
438
  def testComposedMessage(self):
439
    self.connect_terminate_count = None
440
    self.message_terminate_count = 3
441
    client = self.getClient()
442
    client.send("one\3composed\3message\3")
443
    self.mainloop.Run()
444
    self.assertEquals(len(self.messages[0]), 3)
445
    self.assertEquals(self.messages[0], ["one", "composed", "message"])
446

    
447
  def testLongTerminator(self):
448
    self.terminator = "\0\1\2"
449
    self.connect_terminate_count = None
450
    self.message_terminate_count = 3
451
    client = self.getClient()
452
    client.send("one\0\1\2composed\0\1\2message\0\1\2")
453
    self.mainloop.Run()
454
    self.assertEquals(len(self.messages[0]), 3)
455
    self.assertEquals(self.messages[0], ["one", "composed", "message"])
456

    
457
  def testErrorHandling(self):
458
    self.connect_terminate_count = None
459
    self.message_terminate_count = None
460
    client = self.getClient()
461
    client.send("one\3two\3error\3three\3")
462
    self.assertRaises(errors.GenericError, self.mainloop.Run)
463
    self.assertEquals(self.connections[0].error_count, 1)
464
    self.assertEquals(self.messages[0], ["one", "two", "error"])
465
    client.send("error\3")
466
    self.assertRaises(errors.GenericError, self.mainloop.Run)
467
    self.assertEquals(self.connections[0].error_count, 2)
468
    self.assertEquals(self.messages[0], ["one", "two", "error", "three",
469
                                         "error"])
470

    
471
  def testDoubleClient(self):
472
    self.connect_terminate_count = None
473
    self.message_terminate_count = 2
474
    client1 = self.getClient()
475
    client2 = self.getClient()
476
    client1.send("c1m1\3")
477
    client2.send("c2m1\3")
478
    self.mainloop.Run()
479
    self.assertEquals(self.messages[0], ["c1m1"])
480
    self.assertEquals(self.messages[1], ["c2m1"])
481

    
482
  def testUnterminatedMessage(self):
483
    self.connect_terminate_count = None
484
    self.message_terminate_count = 3
485
    client1 = self.getClient()
486
    client2 = self.getClient()
487
    client1.send("message\3unterminated")
488
    client2.send("c2m1\3c2m2\3")
489
    self.mainloop.Run()
490
    self.assertEquals(self.messages[0], ["message"])
491
    self.assertEquals(self.messages[1], ["c2m1", "c2m2"])
492
    client1.send("message\3")
493
    self.mainloop.Run()
494
    self.assertEquals(self.messages[0], ["message", "unterminatedmessage"])
495

    
496
  def testSignaledWhileAccepting(self):
497
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
498
    client1 = self.getClient()
499
    self.server.handle_accept()
500
    # When interrupted while accepting we don't have a connection, but we
501
    # didn't crash either.
502
    self.assertEquals(len(self.connections), 0)
503
    utils.IgnoreSignals = self.saved_utils_ignoresignals
504
    self.mainloop.Run()
505
    self.assertEquals(len(self.connections), 1)
506

    
507
  def testSendMessage(self):
508
    self.connect_terminate_count = None
509
    self.message_terminate_count = 3
510
    client1 = self.getClient()
511
    client2 = self.getClient()
512
    client1.send("one\3composed\3message\3")
513
    self.mainloop.Run()
514
    self.assertEquals(self.messages[0], ["one", "composed", "message"])
515
    self.assertFalse(self.connections[0].writable())
516
    self.assertFalse(self.connections[1].writable())
517
    self.connections[0].send_message("r0")
518
    self.assert_(self.connections[0].writable())
519
    self.assertFalse(self.connections[1].writable())
520
    self.connections[0].send_message("r1")
521
    self.connections[0].send_message("r2")
522
    # We currently have no way to terminate the mainloop on write events, but
523
    # let's assume handle_write will be called if writable() is True.
524
    while self.connections[0].writable():
525
      self.connections[0].handle_write()
526
    client1.setblocking(0)
527
    client2.setblocking(0)
528
    self.assertEquals(client1.recv(4096), "r0\3r1\3r2\3")
529
    self.assertRaises(socket.error, client2.recv, 4096)
530

    
531
  def testLimitedUnhandledMessages(self):
532
    self.connect_terminate_count = None
533
    self.message_terminate_count = 3
534
    self.unhandled_limit = 2
535
    client1 = self.getClient()
536
    client2 = self.getClient()
537
    client1.send("one\3composed\3long\3message\3")
538
    client2.send("c2one\3")
539
    self.mainloop.Run()
540
    self.assertEquals(self.messages[0], ["one", "composed"])
541
    self.assertEquals(self.messages[1], ["c2one"])
542
    self.assertFalse(self.connections[0].readable())
543
    self.assert_(self.connections[1].readable())
544
    self.connections[0].send_message("r0")
545
    self.message_terminate_count = None
546
    client1.send("another\3")
547
    # when we write replies messages queued also get handled, but not the ones
548
    # in the socket.
549
    while self.connections[0].writable():
550
      self.connections[0].handle_write()
551
    self.assertFalse(self.connections[0].readable())
552
    self.assertEquals(self.messages[0], ["one", "composed", "long"])
553
    self.connections[0].send_message("r1")
554
    self.connections[0].send_message("r2")
555
    while self.connections[0].writable():
556
      self.connections[0].handle_write()
557
    self.assertEquals(self.messages[0], ["one", "composed", "long", "message"])
558
    self.assert_(self.connections[0].readable())
559

    
560
  def testLimitedUnhandledMessagesOne(self):
561
    self.connect_terminate_count = None
562
    self.message_terminate_count = 2
563
    self.unhandled_limit = 1
564
    client1 = self.getClient()
565
    client2 = self.getClient()
566
    client1.send("one\3composed\3message\3")
567
    client2.send("c2one\3")
568
    self.mainloop.Run()
569
    self.assertEquals(self.messages[0], ["one"])
570
    self.assertEquals(self.messages[1], ["c2one"])
571
    self.assertFalse(self.connections[0].readable())
572
    self.assertFalse(self.connections[1].readable())
573
    self.connections[0].send_message("r0")
574
    self.message_terminate_count = None
575
    while self.connections[0].writable():
576
      self.connections[0].handle_write()
577
    self.assertFalse(self.connections[0].readable())
578
    self.assertEquals(self.messages[0], ["one", "composed"])
579
    self.connections[0].send_message("r2")
580
    self.connections[0].send_message("r3")
581
    while self.connections[0].writable():
582
      self.connections[0].handle_write()
583
    self.assertEquals(self.messages[0], ["one", "composed", "message"])
584
    self.assert_(self.connections[0].readable())
585

    
586

    
587
class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
588
  """Test daemon.AsyncStreamServer with a Unix path connection"""
589

    
590
  family = socket.AF_UNIX
591

    
592
  def getAddress(self):
593
    self.tmpdir = tempfile.mkdtemp()
594
    return os.path.join(self.tmpdir, "server.sock")
595

    
596
  def tearDown(self):
597
    shutil.rmtree(self.tmpdir)
598
    TestAsyncStreamServerTCP.tearDown(self)
599

    
600

    
601
class TestAsyncStreamServerUnixAbstract(TestAsyncStreamServerTCP):
602
  """Test daemon.AsyncStreamServer with a Unix abstract connection"""
603

    
604
  family = socket.AF_UNIX
605

    
606
  def getAddress(self):
607
    return "\0myabstractsocketaddress"
608

    
609

    
610
class TestAsyncAwaker(testutils.GanetiTestCase):
611
  """Test daemon.AsyncAwaker"""
612

    
613
  family = socket.AF_INET
614

    
615
  def setUp(self):
616
    testutils.GanetiTestCase.setUp(self)
617
    self.mainloop = daemon.Mainloop()
618
    self.awaker = daemon.AsyncAwaker(signal_fn=self.handle_signal)
619
    self.signal_count = 0
620
    self.signal_terminate_count = 1
621

    
622
  def tearDown(self):
623
    self.awaker.close()
624

    
625
  def handle_signal(self):
626
    self.signal_count += 1
627
    self.signal_terminate_count -= 1
628
    if self.signal_terminate_count <= 0:
629
      os.kill(os.getpid(), signal.SIGTERM)
630

    
631
  def testBasicSignaling(self):
632
    self.awaker.signal()
633
    self.mainloop.Run()
634
    self.assertEquals(self.signal_count, 1)
635

    
636
  def testDoubleSignaling(self):
637
    self.awaker.signal()
638
    self.awaker.signal()
639
    self.mainloop.Run()
640
    # The second signal is never delivered
641
    self.assertEquals(self.signal_count, 1)
642

    
643
  def testReallyDoubleSignaling(self):
644
    self.assert_(self.awaker.readable())
645
    self.awaker.signal()
646
    # Let's suppose two threads overlap, and both find need_signal True
647
    self.awaker.need_signal = True
648
    self.awaker.signal()
649
    self.mainloop.Run()
650
    # We still get only one signaling
651
    self.assertEquals(self.signal_count, 1)
652

    
653
  def testNoSignalFnArgument(self):
654
    myawaker = daemon.AsyncAwaker()
655
    self.assertRaises(socket.error, myawaker.handle_read)
656
    myawaker.signal()
657
    myawaker.handle_read()
658
    self.assertRaises(socket.error, myawaker.handle_read)
659
    myawaker.signal()
660
    myawaker.signal()
661
    myawaker.handle_read()
662
    self.assertRaises(socket.error, myawaker.handle_read)
663
    myawaker.close()
664

    
665
  def testWrongSignalFnArgument(self):
666
    self.assertRaises(AssertionError, daemon.AsyncAwaker, 1)
667
    self.assertRaises(AssertionError, daemon.AsyncAwaker, "string")
668
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn=1)
669
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn="string")
670

    
671

    
672
if __name__ == "__main__":
673
  testutils.GanetiTestProgram()