utils.IgnoreSignals
authorGuido Trotter <ultrotter@google.com>
Fri, 14 May 2010 13:44:23 +0000 (14:44 +0100)
committerGuido Trotter <ultrotter@google.com>
Fri, 14 May 2010 15:46:35 +0000 (16:46 +0100)
Remove duplicate code between a couple of asyncore related function by
having a function in charge of handling EINTR errors. Unittests included.

Signed-off-by: Guido Trotter <ultrotter@google.com>
Reviewed-by: Michael Hanselmann <hansmi@google.com>

lib/daemon.py
lib/utils.py
test/ganeti.utils_unittest.py

index 185f2e7..1f210b9 100644 (file)
@@ -94,15 +94,8 @@ class AsyncUDPSocket(asyncore.dispatcher):
 
   # this method is overriding an asyncore.dispatcher method
   def handle_read(self):
-    try:
-      payload, address = self.recvfrom(constants.MAX_UDP_DATA_SIZE)
-    except socket.error, err:
-      if err.errno == errno.EINTR:
-        # we got a signal while trying to read. no need to do anything,
-        # handle_read will be called again if there is data on the socket.
-        return
-      else:
-        raise
+    payload, address = utils.IgnoreSignals(self.recvfrom,
+                                           constants.MAX_UDP_DATA_SIZE)
     ip, port = address
     self.handle_datagram(payload, ip, port)
 
@@ -124,16 +117,7 @@ class AsyncUDPSocket(asyncore.dispatcher):
       logging.error("handle_write called with empty output queue")
       return
     (ip, port, payload) = self._out_queue[0]
-    try:
-      self.sendto(payload, 0, (ip, port))
-    except socket.error, err:
-      if err.errno == errno.EINTR:
-        # we got a signal while trying to write. no need to do anything,
-        # handle_write will be called again because we haven't emptied the
-        # _out_queue, and we'll try again
-        return
-      else:
-        raise
+    utils.IgnoreSignals(self.sendto, payload, 0, (ip, port))
     self._out_queue.pop(0)
 
   # this method is overriding an asyncore.dispatcher method
index c1d5c6f..3cbebe4 100644 (file)
@@ -2507,6 +2507,20 @@ def RunInSeparateProcess(fn, *args):
   return bool(exitcode)
 
 
+def IgnoreSignals(fn, *args, **kwargs):
+  """Tries to call a function ignoring failures due to EINTR.
+
+  """
+  try:
+    return fn(*args, **kwargs)
+  except (EnvironmentError, socket.error), err:
+    if err.errno != errno.EINTR:
+      raise
+  except select.error, err:
+    if not (err.args and err.args[0] == errno.EINTR):
+      raise
+
+
 def LockedMethod(fn):
   """Synchronized object access decorator.
 
index ad08f12..b8c00de 100755 (executable)
@@ -39,6 +39,7 @@ import warnings
 import distutils.version
 import glob
 import md5
+import errno
 
 import ganeti
 import testutils
@@ -1822,5 +1823,41 @@ class TestLineSplitter(unittest.TestCase):
                              "", "x"])
 
 
+class TestIgnoreSignals(unittest.TestCase):
+  """Test the IgnoreSignals decorator"""
+
+  @staticmethod
+  def _Raise(exception):
+    raise exception
+
+  @staticmethod
+  def _Return(rval):
+    return rval
+
+  def testIgnoreSignals(self):
+    sock_err_intr = socket.error(errno.EINTR, "Message")
+    sock_err_intr.errno = errno.EINTR
+    sock_err_inval = socket.error(errno.EINVAL, "Message")
+    sock_err_inval.errno = errno.EINVAL
+
+    env_err_intr = EnvironmentError(errno.EINTR, "Message")
+    env_err_inval = EnvironmentError(errno.EINVAL, "Message")
+
+    self.assertRaises(socket.error, self._Raise, sock_err_intr)
+    self.assertRaises(socket.error, self._Raise, sock_err_inval)
+    self.assertRaises(EnvironmentError, self._Raise, env_err_intr)
+    self.assertRaises(EnvironmentError, self._Raise, env_err_inval)
+
+    self.assertEquals(utils.IgnoreSignals(self._Raise, sock_err_intr), None)
+    self.assertEquals(utils.IgnoreSignals(self._Raise, env_err_intr), None)
+    self.assertRaises(socket.error, utils.IgnoreSignals, self._Raise,
+                      sock_err_inval)
+    self.assertRaises(EnvironmentError, utils.IgnoreSignals, self._Raise,
+                      env_err_inval)
+
+    self.assertEquals(utils.IgnoreSignals(self._Return, True), True)
+    self.assertEquals(utils.IgnoreSignals(self._Return, 33), 33)
+
+
 if __name__ == '__main__':
   testutils.GanetiTestProgram()