Revision 18215385
b/test/ganeti.daemon_unittest.py | ||
---|---|---|
26 | 26 |
import os |
27 | 27 |
import socket |
28 | 28 |
import time |
29 |
import tempfile |
|
30 |
import shutil |
|
29 | 31 |
|
30 | 32 |
from ganeti import daemon |
31 | 33 |
from ganeti import errors |
... | ... | |
246 | 248 |
self.assertEquals(self.server.received, ["p1", "p2", "terminate"]) |
247 | 249 |
|
248 | 250 |
|
251 |
class _MyAsyncStreamServer(daemon.AsyncStreamServer): |
|
252 |
|
|
253 |
def __init__(self, family, address, handle_connection_fn): |
|
254 |
daemon.AsyncStreamServer.__init__(self, family, address) |
|
255 |
self.handle_connection_fn = handle_connection_fn |
|
256 |
self.error_count = 0 |
|
257 |
self.expt_count = 0 |
|
258 |
|
|
259 |
def handle_connection(self, connected_socket, client_address): |
|
260 |
self.handle_connection_fn(connected_socket, client_address) |
|
261 |
|
|
262 |
def handle_error(self): |
|
263 |
self.error_count += 1 |
|
264 |
self.close() |
|
265 |
raise |
|
266 |
|
|
267 |
def handle_expt(self): |
|
268 |
self.expt_count += 1 |
|
269 |
self.close() |
|
270 |
|
|
271 |
|
|
272 |
class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream): |
|
273 |
|
|
274 |
def __init__(self, connected_socket, client_address, terminator, family, |
|
275 |
message_fn, client_id): |
|
276 |
daemon.AsyncTerminatedMessageStream.__init__(self, connected_socket, |
|
277 |
client_address, |
|
278 |
terminator, family) |
|
279 |
self.message_fn = message_fn |
|
280 |
self.client_id = client_id |
|
281 |
self.error_count = 0 |
|
282 |
|
|
283 |
def handle_message(self, message, message_id): |
|
284 |
self.message_fn(self, message, message_id) |
|
285 |
|
|
286 |
def handle_error(self): |
|
287 |
self.error_count += 1 |
|
288 |
raise |
|
289 |
|
|
290 |
|
|
291 |
class TestAsyncStreamServerTCP(testutils.GanetiTestCase): |
|
292 |
"""Test daemon.AsyncStreamServer with a TCP connection""" |
|
293 |
|
|
294 |
family = socket.AF_INET |
|
295 |
|
|
296 |
def setUp(self): |
|
297 |
testutils.GanetiTestCase.setUp(self) |
|
298 |
self.mainloop = daemon.Mainloop() |
|
299 |
self.address = self.getAddress() |
|
300 |
self.server = _MyAsyncStreamServer(self.family, self.address, |
|
301 |
self.handle_connection) |
|
302 |
self.client_handler = _MyMessageStreamHandler |
|
303 |
self.terminator = "\3" |
|
304 |
self.address = self.server.getsockname() |
|
305 |
self.clients = [] |
|
306 |
self.connections = [] |
|
307 |
self.messages = {} |
|
308 |
self.connect_terminate_count = 0 |
|
309 |
self.message_terminate_count = 0 |
|
310 |
self.next_client_id = 0 |
|
311 |
# Save utils.IgnoreSignals so we can do evil things to it... |
|
312 |
self.saved_utils_ignoresignals = utils.IgnoreSignals |
|
313 |
|
|
314 |
def tearDown(self): |
|
315 |
for c in self.clients: |
|
316 |
c.close() |
|
317 |
for c in self.connections: |
|
318 |
c.close() |
|
319 |
self.server.close() |
|
320 |
# ...and restore it as well |
|
321 |
utils.IgnoreSignals = self.saved_utils_ignoresignals |
|
322 |
testutils.GanetiTestCase.tearDown(self) |
|
323 |
|
|
324 |
def getAddress(self): |
|
325 |
return ("127.0.0.1", 0) |
|
326 |
|
|
327 |
def countTerminate(self, name): |
|
328 |
value = getattr(self, name) |
|
329 |
if value is not None: |
|
330 |
value -= 1 |
|
331 |
setattr(self, name, value) |
|
332 |
if value <= 0: |
|
333 |
os.kill(os.getpid(), signal.SIGTERM) |
|
334 |
|
|
335 |
def handle_connection(self, connected_socket, client_address): |
|
336 |
client_id = self.next_client_id |
|
337 |
self.next_client_id += 1 |
|
338 |
client_handler = self.client_handler(connected_socket, client_address, |
|
339 |
self.terminator, self.family, |
|
340 |
self.handle_message, |
|
341 |
client_id) |
|
342 |
self.connections.append(client_handler) |
|
343 |
self.countTerminate("connect_terminate_count") |
|
344 |
|
|
345 |
def handle_message(self, handler, message, message_id): |
|
346 |
self.messages.setdefault(handler.client_id, []) |
|
347 |
# We should just check that the message_ids are monotonically increasing. |
|
348 |
# If in the unit tests we never remove messages from the received queue, |
|
349 |
# though, we can just require that the queue length is the same as the |
|
350 |
# message id, before pushing the message to it. This forces a more |
|
351 |
# restrictive check, but we can live with this for now. |
|
352 |
self.assertEquals(len(self.messages[handler.client_id]), message_id) |
|
353 |
self.messages[handler.client_id].append(message) |
|
354 |
if message == "error": |
|
355 |
raise errors.GenericError("error") |
|
356 |
self.countTerminate("message_terminate_count") |
|
357 |
|
|
358 |
def getClient(self): |
|
359 |
client = socket.socket(self.family, socket.SOCK_STREAM) |
|
360 |
client.connect(self.address) |
|
361 |
self.clients.append(client) |
|
362 |
return client |
|
363 |
|
|
364 |
def tearDown(self): |
|
365 |
testutils.GanetiTestCase.tearDown(self) |
|
366 |
self.server.close() |
|
367 |
|
|
368 |
def testConnect(self): |
|
369 |
self.getClient() |
|
370 |
self.mainloop.Run() |
|
371 |
self.assertEquals(len(self.connections), 1) |
|
372 |
self.getClient() |
|
373 |
self.mainloop.Run() |
|
374 |
self.assertEquals(len(self.connections), 2) |
|
375 |
self.connect_terminate_count = 4 |
|
376 |
self.getClient() |
|
377 |
self.getClient() |
|
378 |
self.getClient() |
|
379 |
self.getClient() |
|
380 |
self.mainloop.Run() |
|
381 |
self.assertEquals(len(self.connections), 6) |
|
382 |
|
|
383 |
def testBasicMessage(self): |
|
384 |
self.connect_terminate_count = None |
|
385 |
client = self.getClient() |
|
386 |
client.send("ciao\3") |
|
387 |
self.mainloop.Run() |
|
388 |
self.assertEquals(len(self.connections), 1) |
|
389 |
self.assertEquals(len(self.messages[0]), 1) |
|
390 |
self.assertEquals(self.messages[0][0], "ciao") |
|
391 |
|
|
392 |
def testDoubleMessage(self): |
|
393 |
self.connect_terminate_count = None |
|
394 |
client = self.getClient() |
|
395 |
client.send("ciao\3") |
|
396 |
self.mainloop.Run() |
|
397 |
client.send("foobar\3") |
|
398 |
self.mainloop.Run() |
|
399 |
self.assertEquals(len(self.connections), 1) |
|
400 |
self.assertEquals(len(self.messages[0]), 2) |
|
401 |
self.assertEquals(self.messages[0][1], "foobar") |
|
402 |
|
|
403 |
def testComposedMessage(self): |
|
404 |
self.connect_terminate_count = None |
|
405 |
self.message_terminate_count = 3 |
|
406 |
client = self.getClient() |
|
407 |
client.send("one\3composed\3message\3") |
|
408 |
self.mainloop.Run() |
|
409 |
self.assertEquals(len(self.messages[0]), 3) |
|
410 |
self.assertEquals(self.messages[0], ["one", "composed", "message"]) |
|
411 |
|
|
412 |
def testLongTerminator(self): |
|
413 |
self.terminator = "\0\1\2" |
|
414 |
self.connect_terminate_count = None |
|
415 |
self.message_terminate_count = 3 |
|
416 |
client = self.getClient() |
|
417 |
client.send("one\0\1\2composed\0\1\2message\0\1\2") |
|
418 |
self.mainloop.Run() |
|
419 |
self.assertEquals(len(self.messages[0]), 3) |
|
420 |
self.assertEquals(self.messages[0], ["one", "composed", "message"]) |
|
421 |
|
|
422 |
def testErrorHandling(self): |
|
423 |
self.connect_terminate_count = None |
|
424 |
self.message_terminate_count = None |
|
425 |
client = self.getClient() |
|
426 |
client.send("one\3two\3error\3three\3") |
|
427 |
self.assertRaises(errors.GenericError, self.mainloop.Run) |
|
428 |
self.assertEquals(self.connections[0].error_count, 1) |
|
429 |
self.assertEquals(self.messages[0], ["one", "two", "error"]) |
|
430 |
client.send("error\3") |
|
431 |
self.assertRaises(errors.GenericError, self.mainloop.Run) |
|
432 |
self.assertEquals(self.connections[0].error_count, 2) |
|
433 |
self.assertEquals(self.messages[0], ["one", "two", "error", "three", |
|
434 |
"error"]) |
|
435 |
|
|
436 |
def testDoubleClient(self): |
|
437 |
self.connect_terminate_count = None |
|
438 |
self.message_terminate_count = 2 |
|
439 |
client1 = self.getClient() |
|
440 |
client2 = self.getClient() |
|
441 |
client1.send("c1m1\3") |
|
442 |
client2.send("c2m1\3") |
|
443 |
self.mainloop.Run() |
|
444 |
self.assertEquals(self.messages[0], ["c1m1"]) |
|
445 |
self.assertEquals(self.messages[1], ["c2m1"]) |
|
446 |
|
|
447 |
def testUnterminatedMessage(self): |
|
448 |
self.connect_terminate_count = None |
|
449 |
self.message_terminate_count = 3 |
|
450 |
client1 = self.getClient() |
|
451 |
client2 = self.getClient() |
|
452 |
client1.send("message\3unterminated") |
|
453 |
client2.send("c2m1\3c2m2\3") |
|
454 |
self.mainloop.Run() |
|
455 |
self.assertEquals(self.messages[0], ["message"]) |
|
456 |
self.assertEquals(self.messages[1], ["c2m1", "c2m2"]) |
|
457 |
client1.send("message\3") |
|
458 |
self.mainloop.Run() |
|
459 |
self.assertEquals(self.messages[0], ["message", "unterminatedmessage"]) |
|
460 |
|
|
461 |
def testSignaledWhileAccepting(self): |
|
462 |
utils.IgnoreSignals = lambda fn, *args, **kwargs: None |
|
463 |
client1 = self.getClient() |
|
464 |
self.server.handle_accept() |
|
465 |
# When interrupted while accepting we don't have a connection, but we |
|
466 |
# didn't crash either. |
|
467 |
self.assertEquals(len(self.connections), 0) |
|
468 |
utils.IgnoreSignals = self.saved_utils_ignoresignals |
|
469 |
self.mainloop.Run() |
|
470 |
self.assertEquals(len(self.connections), 1) |
|
471 |
|
|
472 |
|
|
473 |
class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP): |
|
474 |
"""Test daemon.AsyncStreamServer with a Unix path connection""" |
|
475 |
|
|
476 |
family = socket.AF_UNIX |
|
477 |
|
|
478 |
def getAddress(self): |
|
479 |
self.tmpdir = tempfile.mkdtemp() |
|
480 |
return os.path.join(self.tmpdir, "server.sock") |
|
481 |
|
|
482 |
def tearDown(self): |
|
483 |
shutil.rmtree(self.tmpdir) |
|
484 |
TestAsyncStreamServerTCP.tearDown(self) |
|
485 |
|
|
486 |
|
|
249 | 487 |
if __name__ == "__main__": |
250 | 488 |
testutils.GanetiTestProgram() |
Also available in: Unified diff