Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.daemon_unittest.py @ 5d831182

History | View | Annotate | Download (19.8 kB)

1 1118ec44 Guido Trotter
#!/usr/bin/python
2 1118ec44 Guido Trotter
#
3 1118ec44 Guido Trotter
4 1118ec44 Guido Trotter
# Copyright (C) 2010 Google Inc.
5 1118ec44 Guido Trotter
#
6 1118ec44 Guido Trotter
# This program is free software; you can redistribute it and/or modify
7 1118ec44 Guido Trotter
# it under the terms of the GNU General Public License as published by
8 1118ec44 Guido Trotter
# the Free Software Foundation; either version 2 of the License, or
9 1118ec44 Guido Trotter
# (at your option) any later version.
10 1118ec44 Guido Trotter
#
11 1118ec44 Guido Trotter
# This program is distributed in the hope that it will be useful, but
12 1118ec44 Guido Trotter
# WITHOUT ANY WARRANTY; without even the implied warranty of
13 1118ec44 Guido Trotter
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 1118ec44 Guido Trotter
# General Public License for more details.
15 1118ec44 Guido Trotter
#
16 1118ec44 Guido Trotter
# You should have received a copy of the GNU General Public License
17 1118ec44 Guido Trotter
# along with this program; if not, write to the Free Software
18 1118ec44 Guido Trotter
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 1118ec44 Guido Trotter
# 02110-1301, USA.
20 1118ec44 Guido Trotter
21 1118ec44 Guido Trotter
22 1118ec44 Guido Trotter
"""Script for unittesting the daemon module"""
23 1118ec44 Guido Trotter
24 1118ec44 Guido Trotter
import unittest
25 1118ec44 Guido Trotter
import signal
26 1118ec44 Guido Trotter
import os
27 4db33137 Guido Trotter
import socket
28 19ad29d2 Guido Trotter
import time
29 18215385 Guido Trotter
import tempfile
30 18215385 Guido Trotter
import shutil
31 1118ec44 Guido Trotter
32 1118ec44 Guido Trotter
from ganeti import daemon
33 e3cc4c69 Guido Trotter
from ganeti import errors
34 e9de7da4 Guido Trotter
from ganeti import constants
35 6e7e58b4 Guido Trotter
from ganeti import utils
36 1118ec44 Guido Trotter
37 1118ec44 Guido Trotter
import testutils
38 1118ec44 Guido Trotter
39 1118ec44 Guido Trotter
40 1118ec44 Guido Trotter
class TestMainloop(testutils.GanetiTestCase):
41 1118ec44 Guido Trotter
  """Test daemon.Mainloop"""
42 1118ec44 Guido Trotter
43 1118ec44 Guido Trotter
  def setUp(self):
44 1118ec44 Guido Trotter
    testutils.GanetiTestCase.setUp(self)
45 1118ec44 Guido Trotter
    self.mainloop = daemon.Mainloop()
46 1118ec44 Guido Trotter
    self.sendsig_events = []
47 1118ec44 Guido Trotter
    self.onsignal_events = []
48 1118ec44 Guido Trotter
49 1118ec44 Guido Trotter
  def _CancelEvent(self, handle):
50 1118ec44 Guido Trotter
    self.mainloop.scheduler.cancel(handle)
51 1118ec44 Guido Trotter
52 1118ec44 Guido Trotter
  def _SendSig(self, sig):
53 1118ec44 Guido Trotter
    self.sendsig_events.append(sig)
54 1118ec44 Guido Trotter
    os.kill(os.getpid(), sig)
55 1118ec44 Guido Trotter
56 1118ec44 Guido Trotter
  def OnSignal(self, signum):
57 1118ec44 Guido Trotter
    self.onsignal_events.append(signum)
58 1118ec44 Guido Trotter
59 1118ec44 Guido Trotter
  def testRunAndTermBySched(self):
60 1118ec44 Guido Trotter
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
61 1118ec44 Guido Trotter
    self.mainloop.Run() # terminates by _SendSig being scheduled
62 1118ec44 Guido Trotter
    self.assertEquals(self.sendsig_events, [signal.SIGTERM])
63 1118ec44 Guido Trotter
64 f59dce3e Guido Trotter
  def testTerminatingSignals(self):
65 f59dce3e Guido Trotter
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
66 f59dce3e Guido Trotter
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGINT])
67 f59dce3e Guido Trotter
    self.mainloop.Run()
68 f59dce3e Guido Trotter
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT])
69 f59dce3e Guido Trotter
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGTERM])
70 f59dce3e Guido Trotter
    self.mainloop.Run()
71 f59dce3e Guido Trotter
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGINT,
72 f59dce3e Guido Trotter
                                            signal.SIGTERM])
73 f59dce3e Guido Trotter
74 1118ec44 Guido Trotter
  def testSchedulerCancel(self):
75 1118ec44 Guido Trotter
    handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
76 1118ec44 Guido Trotter
                                           [signal.SIGTERM])
77 1118ec44 Guido Trotter
    self.mainloop.scheduler.cancel(handle)
78 1118ec44 Guido Trotter
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
79 1118ec44 Guido Trotter
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
80 1118ec44 Guido Trotter
    self.mainloop.Run()
81 1118ec44 Guido Trotter
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
82 1118ec44 Guido Trotter
83 1118ec44 Guido Trotter
  def testRegisterSignal(self):
84 1118ec44 Guido Trotter
    self.mainloop.RegisterSignal(self)
85 1118ec44 Guido Trotter
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
86 1118ec44 Guido Trotter
    handle = self.mainloop.scheduler.enter(0.1, 1, self._SendSig,
87 1118ec44 Guido Trotter
                                           [signal.SIGTERM])
88 1118ec44 Guido Trotter
    self.mainloop.scheduler.cancel(handle)
89 1118ec44 Guido Trotter
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
90 1118ec44 Guido Trotter
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
91 1118ec44 Guido Trotter
    # ...not delievered because they are scheduled after TERM
92 1118ec44 Guido Trotter
    self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
93 1118ec44 Guido Trotter
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
94 1118ec44 Guido Trotter
    self.mainloop.Run()
95 1118ec44 Guido Trotter
    self.assertEquals(self.sendsig_events,
96 1118ec44 Guido Trotter
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
97 1118ec44 Guido Trotter
    self.assertEquals(self.onsignal_events, self.sendsig_events)
98 1118ec44 Guido Trotter
99 1118ec44 Guido Trotter
  def testDeferredCancel(self):
100 1118ec44 Guido Trotter
    self.mainloop.RegisterSignal(self)
101 19ad29d2 Guido Trotter
    now = time.time()
102 19ad29d2 Guido Trotter
    self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
103 19ad29d2 Guido Trotter
                                     [signal.SIGCHLD])
104 19ad29d2 Guido Trotter
    handle1 = self.mainloop.scheduler.enterabs(now + 0.3, 2, self._SendSig,
105 19ad29d2 Guido Trotter
                                               [signal.SIGCHLD])
106 19ad29d2 Guido Trotter
    handle2 = self.mainloop.scheduler.enterabs(now + 0.4, 2, self._SendSig,
107 19ad29d2 Guido Trotter
                                               [signal.SIGCHLD])
108 19ad29d2 Guido Trotter
    self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
109 19ad29d2 Guido Trotter
                                     [handle1])
110 19ad29d2 Guido Trotter
    self.mainloop.scheduler.enterabs(now + 0.2, 1, self._CancelEvent,
111 19ad29d2 Guido Trotter
                                     [handle2])
112 1118ec44 Guido Trotter
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGTERM])
113 1118ec44 Guido Trotter
    self.mainloop.Run()
114 1118ec44 Guido Trotter
    self.assertEquals(self.sendsig_events, [signal.SIGCHLD, signal.SIGTERM])
115 1118ec44 Guido Trotter
    self.assertEquals(self.onsignal_events, self.sendsig_events)
116 1118ec44 Guido Trotter
117 c6987b16 Guido Trotter
  def testReRun(self):
118 c6987b16 Guido Trotter
    self.mainloop.RegisterSignal(self)
119 c6987b16 Guido Trotter
    self.mainloop.scheduler.enter(0.1, 1, self._SendSig, [signal.SIGCHLD])
120 c6987b16 Guido Trotter
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGCHLD])
121 c6987b16 Guido Trotter
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
122 c6987b16 Guido Trotter
    self.mainloop.scheduler.enter(0.4, 1, self._SendSig, [signal.SIGCHLD])
123 c6987b16 Guido Trotter
    self.mainloop.scheduler.enter(0.5, 1, self._SendSig, [signal.SIGCHLD])
124 c6987b16 Guido Trotter
    self.mainloop.Run()
125 c6987b16 Guido Trotter
    self.assertEquals(self.sendsig_events,
126 c6987b16 Guido Trotter
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
127 c6987b16 Guido Trotter
    self.assertEquals(self.onsignal_events, self.sendsig_events)
128 c6987b16 Guido Trotter
    self.mainloop.scheduler.enter(0.3, 1, self._SendSig, [signal.SIGTERM])
129 c6987b16 Guido Trotter
    self.mainloop.Run()
130 c6987b16 Guido Trotter
    self.assertEquals(self.sendsig_events,
131 c6987b16 Guido Trotter
                      [signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM,
132 c6987b16 Guido Trotter
                       signal.SIGCHLD, signal.SIGCHLD, signal.SIGTERM])
133 c6987b16 Guido Trotter
    self.assertEquals(self.onsignal_events, self.sendsig_events)
134 c6987b16 Guido Trotter
135 85dbfd78 Guido Trotter
  def testPriority(self):
136 85dbfd78 Guido Trotter
    # for events at the same time, the highest priority one executes first
137 85dbfd78 Guido Trotter
    now = time.time()
138 85dbfd78 Guido Trotter
    self.mainloop.scheduler.enterabs(now + 0.1, 2, self._SendSig,
139 85dbfd78 Guido Trotter
                                     [signal.SIGCHLD])
140 85dbfd78 Guido Trotter
    self.mainloop.scheduler.enterabs(now + 0.1, 1, self._SendSig,
141 85dbfd78 Guido Trotter
                                     [signal.SIGTERM])
142 85dbfd78 Guido Trotter
    self.mainloop.Run()
143 85dbfd78 Guido Trotter
    self.assertEquals(self.sendsig_events, [signal.SIGTERM])
144 85dbfd78 Guido Trotter
    self.mainloop.scheduler.enter(0.2, 1, self._SendSig, [signal.SIGTERM])
145 85dbfd78 Guido Trotter
    self.mainloop.Run()
146 85dbfd78 Guido Trotter
    self.assertEquals(self.sendsig_events,
147 85dbfd78 Guido Trotter
                      [signal.SIGTERM, signal.SIGCHLD, signal.SIGTERM])
148 85dbfd78 Guido Trotter
149 1118ec44 Guido Trotter
150 4db33137 Guido Trotter
class _MyAsyncUDPSocket(daemon.AsyncUDPSocket):
151 4db33137 Guido Trotter
152 4db33137 Guido Trotter
  def __init__(self):
153 4db33137 Guido Trotter
    daemon.AsyncUDPSocket.__init__(self)
154 4db33137 Guido Trotter
    self.received = []
155 4db33137 Guido Trotter
    self.error_count = 0
156 4db33137 Guido Trotter
157 4db33137 Guido Trotter
  def handle_datagram(self, payload, ip, port):
158 4db33137 Guido Trotter
    self.received.append((payload))
159 4db33137 Guido Trotter
    if payload == "terminate":
160 4db33137 Guido Trotter
      os.kill(os.getpid(), signal.SIGTERM)
161 4db33137 Guido Trotter
    elif payload == "error":
162 4db33137 Guido Trotter
      raise errors.GenericError("error")
163 4db33137 Guido Trotter
164 4db33137 Guido Trotter
  def handle_error(self):
165 4db33137 Guido Trotter
    self.error_count += 1
166 e3cc4c69 Guido Trotter
    raise
167 4db33137 Guido Trotter
168 4db33137 Guido Trotter
169 4db33137 Guido Trotter
class TestAsyncUDPSocket(testutils.GanetiTestCase):
170 4db33137 Guido Trotter
  """Test daemon.AsyncUDPSocket"""
171 4db33137 Guido Trotter
172 4db33137 Guido Trotter
  def setUp(self):
173 4db33137 Guido Trotter
    testutils.GanetiTestCase.setUp(self)
174 4db33137 Guido Trotter
    self.mainloop = daemon.Mainloop()
175 4db33137 Guido Trotter
    self.server = _MyAsyncUDPSocket()
176 4db33137 Guido Trotter
    self.client = _MyAsyncUDPSocket()
177 4db33137 Guido Trotter
    self.server.bind(("127.0.0.1", 0))
178 4db33137 Guido Trotter
    self.port = self.server.getsockname()[1]
179 6e7e58b4 Guido Trotter
    # Save utils.IgnoreSignals so we can do evil things to it...
180 6e7e58b4 Guido Trotter
    self.saved_utils_ignoresignals = utils.IgnoreSignals
181 4db33137 Guido Trotter
182 4db33137 Guido Trotter
  def tearDown(self):
183 4db33137 Guido Trotter
    self.server.close()
184 4db33137 Guido Trotter
    self.client.close()
185 6e7e58b4 Guido Trotter
    # ...and restore it as well
186 6e7e58b4 Guido Trotter
    utils.IgnoreSignals = self.saved_utils_ignoresignals
187 4db33137 Guido Trotter
    testutils.GanetiTestCase.tearDown(self)
188 4db33137 Guido Trotter
189 4db33137 Guido Trotter
  def testNoDoubleBind(self):
190 4db33137 Guido Trotter
    self.assertRaises(socket.error, self.client.bind, ("127.0.0.1", self.port))
191 4db33137 Guido Trotter
192 4db33137 Guido Trotter
  def testAsyncClientServer(self):
193 4db33137 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
194 4db33137 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
195 4db33137 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
196 4db33137 Guido Trotter
    self.mainloop.Run()
197 4db33137 Guido Trotter
    self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
198 4db33137 Guido Trotter
199 4db33137 Guido Trotter
  def testSyncClientServer(self):
200 95ab227e Guido Trotter
    self.client.handle_write()
201 4db33137 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
202 4db33137 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
203 4db33137 Guido Trotter
    while self.client.writable():
204 4db33137 Guido Trotter
      self.client.handle_write()
205 4db33137 Guido Trotter
    self.server.process_next_packet()
206 4db33137 Guido Trotter
    self.assertEquals(self.server.received, ["p1"])
207 4db33137 Guido Trotter
    self.server.process_next_packet()
208 4db33137 Guido Trotter
    self.assertEquals(self.server.received, ["p1", "p2"])
209 4db33137 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "p3")
210 4db33137 Guido Trotter
    while self.client.writable():
211 4db33137 Guido Trotter
      self.client.handle_write()
212 4db33137 Guido Trotter
    self.server.process_next_packet()
213 4db33137 Guido Trotter
    self.assertEquals(self.server.received, ["p1", "p2", "p3"])
214 4db33137 Guido Trotter
215 4db33137 Guido Trotter
  def testErrorHandling(self):
216 4db33137 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
217 4db33137 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
218 4db33137 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "error")
219 4db33137 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "p3")
220 4db33137 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "error")
221 4db33137 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
222 e3cc4c69 Guido Trotter
    self.assertRaises(errors.GenericError, self.mainloop.Run)
223 e3cc4c69 Guido Trotter
    self.assertEquals(self.server.received,
224 e3cc4c69 Guido Trotter
                      ["p1", "p2", "error"])
225 e3cc4c69 Guido Trotter
    self.assertEquals(self.server.error_count, 1)
226 e3cc4c69 Guido Trotter
    self.assertRaises(errors.GenericError, self.mainloop.Run)
227 e3cc4c69 Guido Trotter
    self.assertEquals(self.server.received,
228 e3cc4c69 Guido Trotter
                      ["p1", "p2", "error", "p3", "error"])
229 e3cc4c69 Guido Trotter
    self.assertEquals(self.server.error_count, 2)
230 4db33137 Guido Trotter
    self.mainloop.Run()
231 4db33137 Guido Trotter
    self.assertEquals(self.server.received,
232 4db33137 Guido Trotter
                      ["p1", "p2", "error", "p3", "error", "terminate"])
233 4db33137 Guido Trotter
    self.assertEquals(self.server.error_count, 2)
234 4db33137 Guido Trotter
235 6e7e58b4 Guido Trotter
  def testSignaledWhileReceiving(self):
236 6e7e58b4 Guido Trotter
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
237 6e7e58b4 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "p1")
238 6e7e58b4 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "p2")
239 6e7e58b4 Guido Trotter
    self.server.handle_read()
240 6e7e58b4 Guido Trotter
    self.assertEquals(self.server.received, [])
241 6e7e58b4 Guido Trotter
    self.client.enqueue_send("127.0.0.1", self.port, "terminate")
242 6e7e58b4 Guido Trotter
    utils.IgnoreSignals = self.saved_utils_ignoresignals
243 6e7e58b4 Guido Trotter
    self.mainloop.Run()
244 6e7e58b4 Guido Trotter
    self.assertEquals(self.server.received, ["p1", "p2", "terminate"])
245 6e7e58b4 Guido Trotter
246 e9de7da4 Guido Trotter
  def testOversizedDatagram(self):
247 e9de7da4 Guido Trotter
    oversized_data = (constants.MAX_UDP_DATA_SIZE + 1) * "a"
248 e9de7da4 Guido Trotter
    self.assertRaises(errors.UdpDataSizeError, self.client.enqueue_send,
249 e9de7da4 Guido Trotter
                      "127.0.0.1", self.port, oversized_data)
250 e9de7da4 Guido Trotter
251 4db33137 Guido Trotter
252 18215385 Guido Trotter
class _MyAsyncStreamServer(daemon.AsyncStreamServer):
253 18215385 Guido Trotter
254 18215385 Guido Trotter
  def __init__(self, family, address, handle_connection_fn):
255 18215385 Guido Trotter
    daemon.AsyncStreamServer.__init__(self, family, address)
256 18215385 Guido Trotter
    self.handle_connection_fn = handle_connection_fn
257 18215385 Guido Trotter
    self.error_count = 0
258 18215385 Guido Trotter
    self.expt_count = 0
259 18215385 Guido Trotter
260 18215385 Guido Trotter
  def handle_connection(self, connected_socket, client_address):
261 18215385 Guido Trotter
    self.handle_connection_fn(connected_socket, client_address)
262 18215385 Guido Trotter
263 18215385 Guido Trotter
  def handle_error(self):
264 18215385 Guido Trotter
    self.error_count += 1
265 18215385 Guido Trotter
    self.close()
266 18215385 Guido Trotter
    raise
267 18215385 Guido Trotter
268 18215385 Guido Trotter
  def handle_expt(self):
269 18215385 Guido Trotter
    self.expt_count += 1
270 18215385 Guido Trotter
    self.close()
271 18215385 Guido Trotter
272 18215385 Guido Trotter
273 18215385 Guido Trotter
class _MyMessageStreamHandler(daemon.AsyncTerminatedMessageStream):
274 18215385 Guido Trotter
275 18215385 Guido Trotter
  def __init__(self, connected_socket, client_address, terminator, family,
276 18215385 Guido Trotter
               message_fn, client_id):
277 18215385 Guido Trotter
    daemon.AsyncTerminatedMessageStream.__init__(self, connected_socket,
278 18215385 Guido Trotter
                                                 client_address,
279 18215385 Guido Trotter
                                                 terminator, family)
280 18215385 Guido Trotter
    self.message_fn = message_fn
281 18215385 Guido Trotter
    self.client_id = client_id
282 18215385 Guido Trotter
    self.error_count = 0
283 18215385 Guido Trotter
284 18215385 Guido Trotter
  def handle_message(self, message, message_id):
285 18215385 Guido Trotter
    self.message_fn(self, message, message_id)
286 18215385 Guido Trotter
287 18215385 Guido Trotter
  def handle_error(self):
288 18215385 Guido Trotter
    self.error_count += 1
289 18215385 Guido Trotter
    raise
290 18215385 Guido Trotter
291 18215385 Guido Trotter
292 18215385 Guido Trotter
class TestAsyncStreamServerTCP(testutils.GanetiTestCase):
293 18215385 Guido Trotter
  """Test daemon.AsyncStreamServer with a TCP connection"""
294 18215385 Guido Trotter
295 18215385 Guido Trotter
  family = socket.AF_INET
296 18215385 Guido Trotter
297 18215385 Guido Trotter
  def setUp(self):
298 18215385 Guido Trotter
    testutils.GanetiTestCase.setUp(self)
299 18215385 Guido Trotter
    self.mainloop = daemon.Mainloop()
300 18215385 Guido Trotter
    self.address = self.getAddress()
301 18215385 Guido Trotter
    self.server = _MyAsyncStreamServer(self.family, self.address,
302 18215385 Guido Trotter
                                       self.handle_connection)
303 18215385 Guido Trotter
    self.client_handler = _MyMessageStreamHandler
304 18215385 Guido Trotter
    self.terminator = "\3"
305 18215385 Guido Trotter
    self.address = self.server.getsockname()
306 18215385 Guido Trotter
    self.clients = []
307 18215385 Guido Trotter
    self.connections = []
308 18215385 Guido Trotter
    self.messages = {}
309 18215385 Guido Trotter
    self.connect_terminate_count = 0
310 18215385 Guido Trotter
    self.message_terminate_count = 0
311 18215385 Guido Trotter
    self.next_client_id = 0
312 18215385 Guido Trotter
    # Save utils.IgnoreSignals so we can do evil things to it...
313 18215385 Guido Trotter
    self.saved_utils_ignoresignals = utils.IgnoreSignals
314 18215385 Guido Trotter
315 18215385 Guido Trotter
  def tearDown(self):
316 18215385 Guido Trotter
    for c in self.clients:
317 18215385 Guido Trotter
      c.close()
318 18215385 Guido Trotter
    for c in self.connections:
319 18215385 Guido Trotter
      c.close()
320 18215385 Guido Trotter
    self.server.close()
321 18215385 Guido Trotter
    # ...and restore it as well
322 18215385 Guido Trotter
    utils.IgnoreSignals = self.saved_utils_ignoresignals
323 18215385 Guido Trotter
    testutils.GanetiTestCase.tearDown(self)
324 18215385 Guido Trotter
325 18215385 Guido Trotter
  def getAddress(self):
326 18215385 Guido Trotter
    return ("127.0.0.1", 0)
327 18215385 Guido Trotter
328 18215385 Guido Trotter
  def countTerminate(self, name):
329 18215385 Guido Trotter
    value = getattr(self, name)
330 18215385 Guido Trotter
    if value is not None:
331 18215385 Guido Trotter
      value -= 1
332 18215385 Guido Trotter
      setattr(self, name, value)
333 18215385 Guido Trotter
      if value <= 0:
334 18215385 Guido Trotter
        os.kill(os.getpid(), signal.SIGTERM)
335 18215385 Guido Trotter
336 18215385 Guido Trotter
  def handle_connection(self, connected_socket, client_address):
337 18215385 Guido Trotter
    client_id = self.next_client_id
338 18215385 Guido Trotter
    self.next_client_id += 1
339 18215385 Guido Trotter
    client_handler = self.client_handler(connected_socket, client_address,
340 18215385 Guido Trotter
                                         self.terminator, self.family,
341 18215385 Guido Trotter
                                         self.handle_message,
342 18215385 Guido Trotter
                                         client_id)
343 18215385 Guido Trotter
    self.connections.append(client_handler)
344 18215385 Guido Trotter
    self.countTerminate("connect_terminate_count")
345 18215385 Guido Trotter
346 18215385 Guido Trotter
  def handle_message(self, handler, message, message_id):
347 18215385 Guido Trotter
    self.messages.setdefault(handler.client_id, [])
348 18215385 Guido Trotter
    # We should just check that the message_ids are monotonically increasing.
349 18215385 Guido Trotter
    # If in the unit tests we never remove messages from the received queue,
350 18215385 Guido Trotter
    # though, we can just require that the queue length is the same as the
351 18215385 Guido Trotter
    # message id, before pushing the message to it. This forces a more
352 18215385 Guido Trotter
    # restrictive check, but we can live with this for now.
353 18215385 Guido Trotter
    self.assertEquals(len(self.messages[handler.client_id]), message_id)
354 18215385 Guido Trotter
    self.messages[handler.client_id].append(message)
355 18215385 Guido Trotter
    if message == "error":
356 18215385 Guido Trotter
      raise errors.GenericError("error")
357 18215385 Guido Trotter
    self.countTerminate("message_terminate_count")
358 18215385 Guido Trotter
359 18215385 Guido Trotter
  def getClient(self):
360 18215385 Guido Trotter
    client = socket.socket(self.family, socket.SOCK_STREAM)
361 18215385 Guido Trotter
    client.connect(self.address)
362 18215385 Guido Trotter
    self.clients.append(client)
363 18215385 Guido Trotter
    return client
364 18215385 Guido Trotter
365 18215385 Guido Trotter
  def tearDown(self):
366 18215385 Guido Trotter
    testutils.GanetiTestCase.tearDown(self)
367 18215385 Guido Trotter
    self.server.close()
368 18215385 Guido Trotter
369 18215385 Guido Trotter
  def testConnect(self):
370 18215385 Guido Trotter
    self.getClient()
371 18215385 Guido Trotter
    self.mainloop.Run()
372 18215385 Guido Trotter
    self.assertEquals(len(self.connections), 1)
373 18215385 Guido Trotter
    self.getClient()
374 18215385 Guido Trotter
    self.mainloop.Run()
375 18215385 Guido Trotter
    self.assertEquals(len(self.connections), 2)
376 18215385 Guido Trotter
    self.connect_terminate_count = 4
377 18215385 Guido Trotter
    self.getClient()
378 18215385 Guido Trotter
    self.getClient()
379 18215385 Guido Trotter
    self.getClient()
380 18215385 Guido Trotter
    self.getClient()
381 18215385 Guido Trotter
    self.mainloop.Run()
382 18215385 Guido Trotter
    self.assertEquals(len(self.connections), 6)
383 18215385 Guido Trotter
384 18215385 Guido Trotter
  def testBasicMessage(self):
385 18215385 Guido Trotter
    self.connect_terminate_count = None
386 18215385 Guido Trotter
    client = self.getClient()
387 18215385 Guido Trotter
    client.send("ciao\3")
388 18215385 Guido Trotter
    self.mainloop.Run()
389 18215385 Guido Trotter
    self.assertEquals(len(self.connections), 1)
390 18215385 Guido Trotter
    self.assertEquals(len(self.messages[0]), 1)
391 18215385 Guido Trotter
    self.assertEquals(self.messages[0][0], "ciao")
392 18215385 Guido Trotter
393 18215385 Guido Trotter
  def testDoubleMessage(self):
394 18215385 Guido Trotter
    self.connect_terminate_count = None
395 18215385 Guido Trotter
    client = self.getClient()
396 18215385 Guido Trotter
    client.send("ciao\3")
397 18215385 Guido Trotter
    self.mainloop.Run()
398 18215385 Guido Trotter
    client.send("foobar\3")
399 18215385 Guido Trotter
    self.mainloop.Run()
400 18215385 Guido Trotter
    self.assertEquals(len(self.connections), 1)
401 18215385 Guido Trotter
    self.assertEquals(len(self.messages[0]), 2)
402 18215385 Guido Trotter
    self.assertEquals(self.messages[0][1], "foobar")
403 18215385 Guido Trotter
404 18215385 Guido Trotter
  def testComposedMessage(self):
405 18215385 Guido Trotter
    self.connect_terminate_count = None
406 18215385 Guido Trotter
    self.message_terminate_count = 3
407 18215385 Guido Trotter
    client = self.getClient()
408 18215385 Guido Trotter
    client.send("one\3composed\3message\3")
409 18215385 Guido Trotter
    self.mainloop.Run()
410 18215385 Guido Trotter
    self.assertEquals(len(self.messages[0]), 3)
411 18215385 Guido Trotter
    self.assertEquals(self.messages[0], ["one", "composed", "message"])
412 18215385 Guido Trotter
413 18215385 Guido Trotter
  def testLongTerminator(self):
414 18215385 Guido Trotter
    self.terminator = "\0\1\2"
415 18215385 Guido Trotter
    self.connect_terminate_count = None
416 18215385 Guido Trotter
    self.message_terminate_count = 3
417 18215385 Guido Trotter
    client = self.getClient()
418 18215385 Guido Trotter
    client.send("one\0\1\2composed\0\1\2message\0\1\2")
419 18215385 Guido Trotter
    self.mainloop.Run()
420 18215385 Guido Trotter
    self.assertEquals(len(self.messages[0]), 3)
421 18215385 Guido Trotter
    self.assertEquals(self.messages[0], ["one", "composed", "message"])
422 18215385 Guido Trotter
423 18215385 Guido Trotter
  def testErrorHandling(self):
424 18215385 Guido Trotter
    self.connect_terminate_count = None
425 18215385 Guido Trotter
    self.message_terminate_count = None
426 18215385 Guido Trotter
    client = self.getClient()
427 18215385 Guido Trotter
    client.send("one\3two\3error\3three\3")
428 18215385 Guido Trotter
    self.assertRaises(errors.GenericError, self.mainloop.Run)
429 18215385 Guido Trotter
    self.assertEquals(self.connections[0].error_count, 1)
430 18215385 Guido Trotter
    self.assertEquals(self.messages[0], ["one", "two", "error"])
431 18215385 Guido Trotter
    client.send("error\3")
432 18215385 Guido Trotter
    self.assertRaises(errors.GenericError, self.mainloop.Run)
433 18215385 Guido Trotter
    self.assertEquals(self.connections[0].error_count, 2)
434 18215385 Guido Trotter
    self.assertEquals(self.messages[0], ["one", "two", "error", "three",
435 18215385 Guido Trotter
                                         "error"])
436 18215385 Guido Trotter
437 18215385 Guido Trotter
  def testDoubleClient(self):
438 18215385 Guido Trotter
    self.connect_terminate_count = None
439 18215385 Guido Trotter
    self.message_terminate_count = 2
440 18215385 Guido Trotter
    client1 = self.getClient()
441 18215385 Guido Trotter
    client2 = self.getClient()
442 18215385 Guido Trotter
    client1.send("c1m1\3")
443 18215385 Guido Trotter
    client2.send("c2m1\3")
444 18215385 Guido Trotter
    self.mainloop.Run()
445 18215385 Guido Trotter
    self.assertEquals(self.messages[0], ["c1m1"])
446 18215385 Guido Trotter
    self.assertEquals(self.messages[1], ["c2m1"])
447 18215385 Guido Trotter
448 18215385 Guido Trotter
  def testUnterminatedMessage(self):
449 18215385 Guido Trotter
    self.connect_terminate_count = None
450 18215385 Guido Trotter
    self.message_terminate_count = 3
451 18215385 Guido Trotter
    client1 = self.getClient()
452 18215385 Guido Trotter
    client2 = self.getClient()
453 18215385 Guido Trotter
    client1.send("message\3unterminated")
454 18215385 Guido Trotter
    client2.send("c2m1\3c2m2\3")
455 18215385 Guido Trotter
    self.mainloop.Run()
456 18215385 Guido Trotter
    self.assertEquals(self.messages[0], ["message"])
457 18215385 Guido Trotter
    self.assertEquals(self.messages[1], ["c2m1", "c2m2"])
458 18215385 Guido Trotter
    client1.send("message\3")
459 18215385 Guido Trotter
    self.mainloop.Run()
460 18215385 Guido Trotter
    self.assertEquals(self.messages[0], ["message", "unterminatedmessage"])
461 18215385 Guido Trotter
462 18215385 Guido Trotter
  def testSignaledWhileAccepting(self):
463 18215385 Guido Trotter
    utils.IgnoreSignals = lambda fn, *args, **kwargs: None
464 18215385 Guido Trotter
    client1 = self.getClient()
465 18215385 Guido Trotter
    self.server.handle_accept()
466 18215385 Guido Trotter
    # When interrupted while accepting we don't have a connection, but we
467 18215385 Guido Trotter
    # didn't crash either.
468 18215385 Guido Trotter
    self.assertEquals(len(self.connections), 0)
469 18215385 Guido Trotter
    utils.IgnoreSignals = self.saved_utils_ignoresignals
470 18215385 Guido Trotter
    self.mainloop.Run()
471 18215385 Guido Trotter
    self.assertEquals(len(self.connections), 1)
472 18215385 Guido Trotter
473 18215385 Guido Trotter
474 18215385 Guido Trotter
class TestAsyncStreamServerUnixPath(TestAsyncStreamServerTCP):
475 18215385 Guido Trotter
  """Test daemon.AsyncStreamServer with a Unix path connection"""
476 18215385 Guido Trotter
477 18215385 Guido Trotter
  family = socket.AF_UNIX
478 18215385 Guido Trotter
479 18215385 Guido Trotter
  def getAddress(self):
480 18215385 Guido Trotter
    self.tmpdir = tempfile.mkdtemp()
481 18215385 Guido Trotter
    return os.path.join(self.tmpdir, "server.sock")
482 18215385 Guido Trotter
483 18215385 Guido Trotter
  def tearDown(self):
484 18215385 Guido Trotter
    shutil.rmtree(self.tmpdir)
485 18215385 Guido Trotter
    TestAsyncStreamServerTCP.tearDown(self)
486 18215385 Guido Trotter
487 18215385 Guido Trotter
488 495ba852 Guido Trotter
class TestAsyncAwaker(testutils.GanetiTestCase):
489 495ba852 Guido Trotter
  """Test daemon.AsyncAwaker"""
490 495ba852 Guido Trotter
491 495ba852 Guido Trotter
  family = socket.AF_INET
492 495ba852 Guido Trotter
493 495ba852 Guido Trotter
  def setUp(self):
494 495ba852 Guido Trotter
    testutils.GanetiTestCase.setUp(self)
495 495ba852 Guido Trotter
    self.mainloop = daemon.Mainloop()
496 495ba852 Guido Trotter
    self.awaker = daemon.AsyncAwaker(signal_fn=self.handle_signal)
497 495ba852 Guido Trotter
    self.signal_count = 0
498 495ba852 Guido Trotter
    self.signal_terminate_count = 1
499 495ba852 Guido Trotter
500 495ba852 Guido Trotter
  def tearDown(self):
501 495ba852 Guido Trotter
    self.awaker.close()
502 495ba852 Guido Trotter
503 495ba852 Guido Trotter
  def handle_signal(self):
504 495ba852 Guido Trotter
    self.signal_count += 1
505 495ba852 Guido Trotter
    self.signal_terminate_count -= 1
506 495ba852 Guido Trotter
    if self.signal_terminate_count <= 0:
507 495ba852 Guido Trotter
      os.kill(os.getpid(), signal.SIGTERM)
508 495ba852 Guido Trotter
509 495ba852 Guido Trotter
  def testBasicSignaling(self):
510 495ba852 Guido Trotter
    self.awaker.signal()
511 495ba852 Guido Trotter
    self.mainloop.Run()
512 495ba852 Guido Trotter
    self.assertEquals(self.signal_count, 1)
513 495ba852 Guido Trotter
514 495ba852 Guido Trotter
  def testDoubleSignaling(self):
515 495ba852 Guido Trotter
    self.awaker.signal()
516 495ba852 Guido Trotter
    self.awaker.signal()
517 495ba852 Guido Trotter
    self.mainloop.Run()
518 495ba852 Guido Trotter
    # The second signal is never delivered
519 495ba852 Guido Trotter
    self.assertEquals(self.signal_count, 1)
520 495ba852 Guido Trotter
521 495ba852 Guido Trotter
  def testReallyDoubleSignaling(self):
522 495ba852 Guido Trotter
    self.assert_(self.awaker.readable())
523 495ba852 Guido Trotter
    self.awaker.signal()
524 495ba852 Guido Trotter
    # Let's suppose two threads overlap, and both find need_signal True
525 495ba852 Guido Trotter
    self.awaker.need_signal = True
526 495ba852 Guido Trotter
    self.awaker.signal()
527 495ba852 Guido Trotter
    self.mainloop.Run()
528 495ba852 Guido Trotter
    # We still get only one signaling
529 495ba852 Guido Trotter
    self.assertEquals(self.signal_count, 1)
530 495ba852 Guido Trotter
531 495ba852 Guido Trotter
  def testNoSignalFnArgument(self):
532 495ba852 Guido Trotter
    myawaker = daemon.AsyncAwaker()
533 495ba852 Guido Trotter
    self.assertRaises(socket.error, myawaker.handle_read)
534 495ba852 Guido Trotter
    myawaker.signal()
535 495ba852 Guido Trotter
    myawaker.handle_read()
536 495ba852 Guido Trotter
    self.assertRaises(socket.error, myawaker.handle_read)
537 495ba852 Guido Trotter
    myawaker.signal()
538 495ba852 Guido Trotter
    myawaker.signal()
539 495ba852 Guido Trotter
    myawaker.handle_read()
540 495ba852 Guido Trotter
    self.assertRaises(socket.error, myawaker.handle_read)
541 495ba852 Guido Trotter
    myawaker.close()
542 495ba852 Guido Trotter
543 495ba852 Guido Trotter
  def testWrongSignalFnArgument(self):
544 495ba852 Guido Trotter
    self.assertRaises(AssertionError, daemon.AsyncAwaker, 1)
545 495ba852 Guido Trotter
    self.assertRaises(AssertionError, daemon.AsyncAwaker, "string")
546 495ba852 Guido Trotter
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn=1)
547 495ba852 Guido Trotter
    self.assertRaises(AssertionError, daemon.AsyncAwaker, signal_fn="string")
548 495ba852 Guido Trotter
549 495ba852 Guido Trotter
550 1118ec44 Guido Trotter
if __name__ == "__main__":
551 1118ec44 Guido Trotter
  testutils.GanetiTestProgram()