AsyncTerminatedMessageStream: limit message count
author Guido Trotter <ultrotter@google.com>
Fri, 25 Jun 2010 15:44:35 +0000 (17:44 +0200)
committer Guido Trotter <ultrotter@google.com>
Tue, 29 Jun 2010 11:30:30 +0000 (12:30 +0100)
Currently the message stream can process any number of messages in
parallel (if they get dispatched to different threads or processes).
To limit their number, we only handle messages and read from the
socket while the number of unanswered messages is below a set limit.

Signed-off-by: Guido Trotter <ultrotter@google.com>
Reviewed-by: Michael Hanselmann <hansmi@google.com>

lib/daemon.py
test/ganeti.daemon_unittest.py

index 98b9fce..2027a4b 100644 (file)
@@ -175,7 +175,8 @@ class AsyncTerminatedMessageStream(asynchat.async_chat):
   separator. For each complete message handle_message is called.
 
   """
-  def __init__(self, connected_socket, peer_address, terminator, family):
+  def __init__(self, connected_socket, peer_address, terminator, family,
+               unhandled_limit):
     """AsyncTerminatedMessageStream constructor.
 
     @type connected_socket: socket.socket
@@ -185,6 +186,8 @@ class AsyncTerminatedMessageStream(asynchat.async_chat):
     @param terminator: terminator separating messages in the stream
     @type family: integer
     @param family: socket family
+    @type unhandled_limit: integer or None
+    @param unhandled_limit: maximum unanswered messages
 
     """
     # python 2.4/2.5 uses conn=... while 2.6 has sock=... we have to cheat by
@@ -197,22 +200,36 @@ class AsyncTerminatedMessageStream(asynchat.async_chat):
     self.family = family
     self.peer_address = peer_address
     self.terminator = terminator
+    self.unhandled_limit = unhandled_limit
     self.set_terminator(terminator)
     self.ibuffer = []
-    self.next_incoming_message = 0
+    self.receive_count = 0
+    self.send_count = 0
     self.oqueue = collections.deque()
+    self.iqueue = collections.deque()
 
   # this method is overriding an asynchat.async_chat method
   def collect_incoming_data(self, data):
     self.ibuffer.append(data)
 
+  def _can_handle_message(self):
+    return (self.unhandled_limit is None or
+            (self.receive_count < self.send_count + self.unhandled_limit) and
+             not self.iqueue)
+
   # this method is overriding an asynchat.async_chat method
   def found_terminator(self):
     message = "".join(self.ibuffer)
     self.ibuffer = []
-    message_id = self.next_incoming_message
-    self.next_incoming_message += 1
-    self.handle_message(message, message_id)
+    message_id = self.receive_count
+    # We need to increase the receive_count after checking if the message can
+    # be handled, but before calling handle_message
+    can_handle = self._can_handle_message()
+    self.receive_count += 1
+    if can_handle:
+      self.handle_message(message, message_id)
+    else:
+      self.iqueue.append((message, message_id))
 
   def handle_message(self, message, message_id):
     """Handle a terminated message.
@@ -240,10 +257,17 @@ class AsyncTerminatedMessageStream(asynchat.async_chat):
     """
     # If we just append the message we received to the output queue, this
     # function can be safely called by multiple threads at the same time, and
-    # we don't need locking, since deques are thread safe.
+    # we don't need locking, since deques are thread safe. handle_write in the
+    # asyncore thread will handle the next input message if there are any
+    # enqueued.
     self.oqueue.append(message)
 
   # this method is overriding an asyncore.dispatcher method
+  def readable(self):
+    # read from the socket if we can handle the next requests
+    return self._can_handle_message() and asynchat.async_chat.readable(self)
+
+  # this method is overriding an asyncore.dispatcher method
   def writable(self):
     # the output queue may become full just after we called writable. This only
     # works if we know we'll have something else waking us up from the select,
@@ -253,8 +277,14 @@ class AsyncTerminatedMessageStream(asynchat.async_chat):
   # this method is overriding an asyncore.dispatcher method
   def handle_write(self):
     if self.oqueue:
+      # if we have data in the output queue, then send_message was called.
+      # this means we can process one more message from the input queue, if
+      # there are any.
       data = self.oqueue.popleft()
       self.push(data + self.terminator)
+      self.send_count += 1
+      if self.iqueue:
+        self.handle_message(*self.iqueue.popleft())
     self.initiate_send()
 
   def close_log(self):
index 1c9160e..1343130 100755 (executable)
@@ -273,10 +273,11 @@ class _MyAsyncStreamServer(daemon.AsyncStreamServer):
 class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
 
   def __init__(self, connected_socket, client_address, terminator, family,
-               message_fn, client_id):
+               message_fn, client_id, unhandled_limit):
     daemon.AsyncTerminatedMessageStream.__init__(self, connected_socket,
                                                  client_address,
-                                                 terminator, family)
+                                                 terminator, family,
+                                                 unhandled_limit)
     self.message_fn = message_fn
     self.client_id = client_id
     self.error_count = 0
@@ -301,6 +302,7 @@ class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
     self.server = _MyAsyncStreamServer(self.family, self.address,
                                        self.handle_connection)
     self.client_handler = _MyMessageStreamHandler
+    self.unhandled_limit = None
     self.terminator = "\3"
     self.address = self.server.getsockname()
     self.clients = []
@@ -339,7 +341,7 @@ class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
     client_handler = self.client_handler(connected_socket, client_address,
                                          self.terminator, self.family,
                                          self.handle_message,
-                                         client_id)
+                                         client_id, self.unhandled_limit)
     self.connections.append(client_handler)
     self.countTerminate("connect_terminate_count")
 
@@ -494,6 +496,61 @@ class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
     self.assertEquals(client1.recv(4096), "r0\3r1\3r2\3")
     self.assertRaises(socket.error, client2.recv, 4096)
 
+  def testLimitedUnhandledMessages(self):
+    self.connect_terminate_count = None
+    self.message_terminate_count = 3
+    self.unhandled_limit = 2
+    client1 = self.getClient()
+    client2 = self.getClient()
+    client1.send("one\3composed\3long\3message\3")
+    client2.send("c2one\3")
+    self.mainloop.Run()
+    self.assertEquals(self.messages[0], ["one", "composed"])
+    self.assertEquals(self.messages[1], ["c2one"])
+    self.assertFalse(self.connections[0].readable())
+    self.assert_(self.connections[1].readable())
+    self.connections[0].send_message("r0")
+    self.message_terminate_count = None
+    client1.send("another\3")
+    # When we write replies, queued messages also get handled, but not the
+    # ones still waiting in the socket.
+    while self.connections[0].writable():
+      self.connections[0].handle_write()
+    self.assertFalse(self.connections[0].readable())
+    self.assertEquals(self.messages[0], ["one", "composed", "long"])
+    self.connections[0].send_message("r1")
+    self.connections[0].send_message("r2")
+    while self.connections[0].writable():
+      self.connections[0].handle_write()
+    self.assertEquals(self.messages[0], ["one", "composed", "long", "message"])
+    self.assert_(self.connections[0].readable())
+
+  def testLimitedUnhandledMessagesOne(self):
+    self.connect_terminate_count = None
+    self.message_terminate_count = 2
+    self.unhandled_limit = 1
+    client1 = self.getClient()
+    client2 = self.getClient()
+    client1.send("one\3composed\3message\3")
+    client2.send("c2one\3")
+    self.mainloop.Run()
+    self.assertEquals(self.messages[0], ["one"])
+    self.assertEquals(self.messages[1], ["c2one"])
+    self.assertFalse(self.connections[0].readable())
+    self.assertFalse(self.connections[1].readable())
+    self.connections[0].send_message("r0")
+    self.message_terminate_count = None
+    while self.connections[0].writable():
+      self.connections[0].handle_write()
+    self.assertFalse(self.connections[0].readable())
+    self.assertEquals(self.messages[0], ["one", "composed"])
+    self.connections[0].send_message("r2")
+    self.connections[0].send_message("r3")
+    while self.connections[0].writable():
+      self.connections[0].handle_write()
+    self.assertEquals(self.messages[0], ["one", "composed", "message"])
+    self.assert_(self.connections[0].readable())
+
 
 class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
   """Test daemon.AsyncStreamServer with a Unix path connection"""