Revision 18215385

b/test/ganeti.daemon_unittest.py
26 26
import os
27 27
import socket
28 28
import time
29
import tempfile
30
import shutil
29 31

  
30 32
from ganeti import daemon
31 33
from ganeti import errors
......
246 248
    self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
247 249

  
248 250

  
251
class _MyAsyncStreamServer(daemon.AsyncStreamServer):
252

  
253
  def __init__(self, family, address, handle_connection_fn):
254
    daemon.AsyncStreamServer.__init__(self, family, address)
255
    self.handle_connection_fn = handle_connection_fn
256
    self.error_count = 0
257
    self.expt_count = 0
258

  
259
  def handle_connection(self, connected_socket, client_address):
260
    self.handle_connection_fn(connected_socket, client_address)
261

  
262
  def handle_error(self):
263
    self.error_count += 1
264
    self.close()
265
    raise
266

  
267
  def handle_expt(self):
268
    self.expt_count += 1
269
    self.close()
270

  
271

  
272
class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
273

  
274
  def __init__(self, connected_socket, client_address, terminator, family,
275
               message_fn, client_id):
276
    daemon.AsyncTerminatedMessageStream.__init__(self, connected_socket,
277
                                                 client_address,
278
                                                 terminator, family)
279
    self.message_fn = message_fn
280
    self.client_id = client_id
281
    self.error_count = 0
282

  
283
  def handle_message(self, message, message_id):
284
    self.message_fn(self, message, message_id)
285

  
286
  def handle_error(self):
287
    self.error_count += 1
288
    raise
289

  
290

  
291
class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
292
  """Test daemon.AsyncStreamServer with a TCP connection"""
293

  
294
  family = socket.AF_INET
295

  
296
  def setUp(self):
297
    testutils.GanetiTestCase.setUp(self)
298
    self.mainloop = daemon.Mainloop()
299
    self.address = self.getAddress()
300
    self.server = _MyAsyncStreamServer(self.family, self.address,
301
                                       self.handle_connection)
302
    self.client_handler = _MyMessageStreamHandler
303
    self.terminator = "\3"
304
    self.address = self.server.getsockname()
305
    self.clients = []
306
    self.connections = []
307
    self.messages = {}
308
    self.connect_terminate_count = 0
309
    self.message_terminate_count = 0
310
    self.next_client_id = 0
311
    # Save utils.IgnoreSignals so we can do evil things to it...
312
    self.saved_utils_ignoresignals = utils.IgnoreSignals
313

  
314
  def tearDown(self):
315
    for c in self.clients:
316
      c.close()
317
    for c in self.connections:
318
      c.close()
319
    self.server.close()
320
    # ...and restore it as well
321
    utils.IgnoreSignals = self.saved_utils_ignoresignals
322
    testutils.GanetiTestCase.tearDown(self)
323

  
324
  def getAddress(self):
325
    return ("127.0.0.1", 0)
326

  
327
  def countTerminate(self, name):
328
    value = getattr(self, name)
329
    if value is not None:
330
      value -= 1
331
      setattr(self, name, value)
332
      if value <= 0:
333
        os.kill(os.getpid(), signal.SIGTERM)
334

  
335
  def handle_connection(self, connected_socket, client_address):
336
    client_id = self.next_client_id
337
    self.next_client_id += 1
338
    client_handler = self.client_handler(connected_socket, client_address,
339
                                         self.terminator, self.family,
340
                                         self.handle_message,
341
                                         client_id)
342
    self.connections.append(client_handler)
343
    self.countTerminate("connect_terminate_count")
344

  
345
  def handle_message(self, handler, message, message_id):
346
    self.messages.setdefault(handler.client_id, [])
347
    # We should just check that the message_ids are monotonically increasing.
348
    # If in the unit tests we never remove messages from the received queue,
349
    # though, we can just require that the queue length is the same as the
350
    # message id, before pushing the message to it. This forces a more
351
    # restrictive check, but we can live with this for now.
352
    self.assertEquals(len(self.messages[handler.client_id]), message_id)
353
    self.messages[handler.client_id].append(message)
354
    if message == "error":
355
      raise errors.GenericError("error")
356
    self.countTerminate("message_terminate_count")
357

  
358
  def getClient(self):
359
    client = socket.socket(self.family, socket.SOCK_STREAM)
360
    client.connect(self.address)
361
    self.clients.append(client)
362
    return client
363

  
364
  def tearDown(self):
365
    testutils.GanetiTestCase.tearDown(self)
366
    self.server.close()
367

  
368
  def testConnect(self):
369
    self.getClient()
370
    self.mainloop.Run()
371
    self.assertEquals(len(self.connections), 1)
372
    self.getClient()
373
    self.mainloop.Run()
374
    self.assertEquals(len(self.connections), 2)
375
    self.connect_terminate_count = 4
376
    self.getClient()
377
    self.getClient()
378
    self.getClient()
379
    self.getClient()
380
    self.mainloop.Run()
381
    self.assertEquals(len(self.connections), 6)
382

  
383
  def testBasicMessage(self):
384
    self.connect_terminate_count = None
385
    client = self.getClient()
386
    client.send("ciao\3")
387
    self.mainloop.Run()
388
    self.assertEquals(len(self.connections), 1)
389
    self.assertEquals(len(self.messages[0]), 1)
390
    self.assertEquals(self.messages[0][0], "ciao")
391

  
392
  def testDoubleMessage(self):
393
    self.connect_terminate_count = None
394
    client = self.getClient()
395
    client.send("ciao\3")
396
    self.mainloop.Run()
397
    client.send("foobar\3")
398
    self.mainloop.Run()
399
    self.assertEquals(len(self.connections), 1)
400
    self.assertEquals(len(self.messages[0]), 2)
401
    self.assertEquals(self.messages[0][1], "foobar")
402

  
403
  def testComposedMessage(self):
404
    self.connect_terminate_count = None
405
    self.message_terminate_count = 3
406
    client = self.getClient()
407
    client.send("one\3composed\3message\3")
408
    self.mainloop.Run()
409
    self.assertEquals(len(self.messages[0]), 3)
410
    self.assertEquals(self.messages[0], ["one", "composed", "message"])
411

  
412
  def testLongTerminator(self):
413
    self.terminator = "\0\1\2"
414
    self.connect_terminate_count = None
415
    self.message_terminate_count = 3
416
    client = self.getClient()
417
    client.send("one\0\1\2composed\0\1\2message\0\1\2")
418
    self.mainloop.Run()
419
    self.assertEquals(len(self.messages[0]), 3)
420
    self.assertEquals(self.messages[0], ["one", "composed", "message"])
421

  
422
  def testErrorHandling(self):
423
    self.connect_terminate_count = None
424
    self.message_terminate_count = None
425
    client = self.getClient()
426
    client.send("one\3two\3error\3three\3")
427
    self.assertRaises(errors.GenericError, self.mainloop.Run)
428
    self.assertEquals(self.connections[0].error_count, 1)
429
    self.assertEquals(self.messages[0], ["one", "two", "error"])
430
    client.send("error\3")
431
    self.assertRaises(errors.GenericError, self.mainloop.Run)
432
    self.assertEquals(self.connections[0].error_count, 2)
433
    self.assertEquals(self.messages[0], ["one", "two", "error", "three",
434
                                         "error"])
435

  
436
  def testDoubleClient(self):
437
    self.connect_terminate_count = None
438
    self.message_terminate_count = 2
439
    client1 = self.getClient()
440
    client2 = self.getClient()
441
    client1.send("c1m1\3")
442
    client2.send("c2m1\3")
443
    self.mainloop.Run()
444
    self.assertEquals(self.messages[0], ["c1m1"])
445
    self.assertEquals(self.messages[1], ["c2m1"])
446

  
447
  def testUnterminatedMessage(self):
448
    self.connect_terminate_count = None
449
    self.message_terminate_count = 3
450
    client1 = self.getClient()
451
    client2 = self.getClient()
452
    client1.send("message\3unterminated")
453
    client2.send("c2m1\3c2m2\3")
454
    self.mainloop.Run()
455
    self.assertEquals(self.messages[0], ["message"])
456
    self.assertEquals(self.messages[1], ["c2m1", "c2m2"])
457
    client1.send("message\3")
458
    self.mainloop.Run()
459
    self.assertEquals(self.messages[0], ["message", "unterminatedmessage"])
460

  
461
  def testSignaledWhileAccepting(self):
462
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
463
    client1 = self.getClient()
464
    self.server.handle_accept()
465
    # When interrupted while accepting we don't have a connection, but we
466
    # didn't crash either.
467
    self.assertEquals(len(self.connections), 0)
468
    utils.IgnoreSignals = self.saved_utils_ignoresignals
469
    self.mainloop.Run()
470
    self.assertEquals(len(self.connections), 1)
471

  
472

  
473
class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
474
  """Test daemon.AsyncStreamServer with a Unix path connection"""
475

  
476
  family = socket.AF_UNIX
477

  
478
  def getAddress(self):
479
    self.tmpdir = tempfile.mkdtemp()
480
    return os.path.join(self.tmpdir, "server.sock")
481

  
482
  def tearDown(self):
483
    shutil.rmtree(self.tmpdir)
484
    TestAsyncStreamServerTCP.tearDown(self)
485

  
486

  
249 487
if __name__ == "__main__":
250 488
  testutils.GanetiTestProgram()

Also available in: Unified diff