AsyncAwaker: use shutdown on the socketpair
[ganeti-local] / lib / daemon.py
1 #
2 #
3
4 # Copyright (C) 2006, 2007, 2008 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Module with helper classes and functions for daemons"""
23
24
25 import asyncore
26 import asynchat
27 import grp
28 import os
29 import pwd
30 import signal
31 import logging
32 import sched
33 import time
34 import socket
35 import select
36 import sys
37
38 from ganeti import utils
39 from ganeti import constants
40 from ganeti import errors
41
42
43 _DEFAULT_RUN_USER = "root"
44 _DEFAULT_RUN_GROUP = "root"
45
46
47 class SchedulerBreakout(Exception):
48   """Exception used to get out of the scheduler loop
49
50   """
51
52
53 def AsyncoreDelayFunction(timeout):
54   """Asyncore-compatible scheduler delay function.
55
56   This is a delay function for sched that, rather than actually sleeping,
57   executes asyncore events happening in the meantime.
58
59   After an event has occurred, rather than returning, it raises a
60   SchedulerBreakout exception, which will force the current scheduler.run()
61   invocation to terminate, so that we can also check for signals. The main loop
62   will then call the scheduler run again, which will allow it to actually
63   process any due events.
64
65   This is needed because scheduler.run() doesn't support a count=..., as
66   asyncore loop, and the scheduler module documents throwing exceptions from
67   inside the delay function as an allowed usage model.
68
69   """
70   asyncore.loop(timeout=timeout, count=1, use_poll=True)
71   raise SchedulerBreakout()
72
73
74 class AsyncoreScheduler(sched.scheduler):
75   """Event scheduler integrated with asyncore
76
77   """
78   def __init__(self, timefunc):
79     sched.scheduler.__init__(self, timefunc, AsyncoreDelayFunction)
80
81
82 class GanetiBaseAsyncoreDispatcher(asyncore.dispatcher):
83   """Base Ganeti Asyncore Dispacher
84
85   """
86   # this method is overriding an asyncore.dispatcher method
87   def handle_error(self):
88     """Log an error in handling any request, and proceed.
89
90     """
91     logging.exception("Error while handling asyncore request")
92
93   # this method is overriding an asyncore.dispatcher method
94   def writable(self):
95     """Most of the time we don't want to check for writability.
96
97     """
98     return False
99
100
101 def FormatAddress(family, address):
102   """Format a client's address
103
104   @type family: integer
105   @param family: socket family (one of socket.AF_*)
106   @type address: family specific (usually tuple)
107   @param address: address, as reported by this class
108
109   """
110   if family == socket.AF_INET and len(address) == 2:
111     return "%s:%d" % address
112   elif family == socket.AF_UNIX and len(address) == 3:
113     return "pid=%s, uid=%s, gid=%s" % address
114   else:
115     return str(address)
116
117
118 class AsyncStreamServer(GanetiBaseAsyncoreDispatcher):
119   """A stream server to use with asyncore.
120
121   Each request is accepted, and then dispatched to a separate asyncore
122   dispatcher to handle.
123
124   """
125
126   _REQUEST_QUEUE_SIZE = 5
127
128   def __init__(self, family, address):
129     """Constructor for AsyncUnixStreamSocket
130
131     @type family: integer
132     @param family: socket family (one of socket.AF_*)
133     @type address: address family dependent
134     @param address: address to bind the socket to
135
136     """
137     GanetiBaseAsyncoreDispatcher.__init__(self)
138     self.family = family
139     self.create_socket(self.family, socket.SOCK_STREAM)
140     self.set_reuse_addr()
141     self.bind(address)
142     self.listen(self._REQUEST_QUEUE_SIZE)
143
144   # this method is overriding an asyncore.dispatcher method
145   def handle_accept(self):
146     """Accept a new client connection.
147
148     Creates a new instance of the handler class, which will use asyncore to
149     serve the client.
150
151     """
152     accept_result = utils.IgnoreSignals(self.accept)
153     if accept_result is not None:
154       connected_socket, client_address = accept_result
155       if self.family == socket.AF_UNIX:
156         # override the client address, as for unix sockets nothing meaningful
157         # is passed in from accept anyway
158         client_address = utils.GetSocketCredentials(connected_socket)
159       logging.info("Accepted connection from %s",
160                    FormatAddress(self.family, client_address))
161       self.handle_connection(connected_socket, client_address)
162
163   def handle_connection(self, connected_socket, client_address):
164     """Handle an already accepted connection.
165
166     """
167     raise NotImplementedError
168
169
170 class AsyncTerminatedMessageStream(asynchat.async_chat):
171   """A terminator separated message stream asyncore module.
172
173   Handles a stream connection receiving messages terminated by a defined
174   separator. For each complete message handle_message is called.
175
176   """
177   def __init__(self, connected_socket, peer_address, terminator, family):
178     """AsyncTerminatedMessageStream constructor.
179
180     @type connected_socket: socket.socket
181     @param connected_socket: connected stream socket to receive messages from
182     @param peer_address: family-specific peer address
183     @type terminator: string
184     @param terminator: terminator separating messages in the stream
185     @type family: integer
186     @param family: socket family
187
188     """
189     # python 2.4/2.5 uses conn=... while 2.6 has sock=... we have to cheat by
190     # using a positional argument rather than a keyword one.
191     asynchat.async_chat.__init__(self, connected_socket)
192     self.connected_socket = connected_socket
193     # on python 2.4 there is no "family" attribute for the socket class
194     # FIXME: when we move to python 2.5 or above remove the family parameter
195     #self.family = self.connected_socket.family
196     self.family = family
197     self.peer_address = peer_address
198     self.terminator = terminator
199     self.set_terminator(terminator)
200     self.ibuffer = []
201     self.next_incoming_message = 0
202
203   # this method is overriding an asynchat.async_chat method
204   def collect_incoming_data(self, data):
205     self.ibuffer.append(data)
206
207   # this method is overriding an asynchat.async_chat method
208   def found_terminator(self):
209     message = "".join(self.ibuffer)
210     self.ibuffer = []
211     message_id = self.next_incoming_message
212     self.next_incoming_message += 1
213     self.handle_message(message, message_id)
214
215   def handle_message(self, message, message_id):
216     """Handle a terminated message.
217
218     @type message: string
219     @param message: message to handle
220     @type message_id: integer
221     @param message_id: stream's message sequence number
222
223     """
224     pass
225     # TODO: move this method to raise NotImplementedError
226     # raise NotImplementedError
227
228   def close_log(self):
229     logging.info("Closing connection from %s",
230                  FormatAddress(self.family, self.peer_address))
231     self.close()
232
233   # this method is overriding an asyncore.dispatcher method
234   def handle_expt(self):
235     self.close_log()
236
237   # this method is overriding an asyncore.dispatcher method
238   def handle_error(self):
239     """Log an error in handling any request, and proceed.
240
241     """
242     logging.exception("Error while handling asyncore request")
243     self.close_log()
244
245
246 class AsyncUDPSocket(GanetiBaseAsyncoreDispatcher):
247   """An improved asyncore udp socket.
248
249   """
250   def __init__(self):
251     """Constructor for AsyncUDPSocket
252
253     """
254     GanetiBaseAsyncoreDispatcher.__init__(self)
255     self._out_queue = []
256     self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
257
258   # this method is overriding an asyncore.dispatcher method
259   def handle_connect(self):
260     # Python thinks that the first udp message from a source qualifies as a
261     # "connect" and further ones are part of the same connection. We beg to
262     # differ and treat all messages equally.
263     pass
264
265   # this method is overriding an asyncore.dispatcher method
266   def handle_read(self):
267     recv_result = utils.IgnoreSignals(self.recvfrom,
268                                       constants.MAX_UDP_DATA_SIZE)
269     if recv_result is not None:
270       payload, address = recv_result
271       ip, port = address
272       self.handle_datagram(payload, ip, port)
273
274   def handle_datagram(self, payload, ip, port):
275     """Handle an already read udp datagram
276
277     """
278     raise NotImplementedError
279
280   # this method is overriding an asyncore.dispatcher method
281   def writable(self):
282     # We should check whether we can write to the socket only if we have
283     # something scheduled to be written
284     return bool(self._out_queue)
285
286   # this method is overriding an asyncore.dispatcher method
287   def handle_write(self):
288     if not self._out_queue:
289       logging.error("handle_write called with empty output queue")
290       return
291     (ip, port, payload) = self._out_queue[0]
292     utils.IgnoreSignals(self.sendto, payload, 0, (ip, port))
293     self._out_queue.pop(0)
294
295   def enqueue_send(self, ip, port, payload):
296     """Enqueue a datagram to be sent when possible
297
298     """
299     if len(payload) > constants.MAX_UDP_DATA_SIZE:
300       raise errors.UdpDataSizeError('Packet too big: %s > %s' % (len(payload),
301                                     constants.MAX_UDP_DATA_SIZE))
302     self._out_queue.append((ip, port, payload))
303
304   def process_next_packet(self, timeout=0):
305     """Process the next datagram, waiting for it if necessary.
306
307     @type timeout: float
308     @param timeout: how long to wait for data
309     @rtype: boolean
310     @return: True if some data has been handled, False otherwise
311
312     """
313     result = utils.WaitForFdCondition(self, select.POLLIN, timeout)
314     if result is not None and result & select.POLLIN:
315       self.handle_read()
316       return True
317     else:
318       return False
319
320
321 class AsyncAwaker(GanetiBaseAsyncoreDispatcher):
322   """A way to notify the asyncore loop that something is going on.
323
324   If an asyncore daemon is multithreaded when a thread tries to push some data
325   to a socket, the main loop handling asynchronous requests might be sleeping
326   waiting on a select(). To avoid this it can create an instance of the
327   AsyncAwaker, which other threads can use to wake it up.
328
329   """
330   def __init__(self, signal_fn=None):
331     """Constructor for AsyncAwaker
332
333     @type signal_fn: function
334     @param signal_fn: function to call when awaken
335
336     """
337     GanetiBaseAsyncoreDispatcher.__init__(self)
338     assert signal_fn == None or callable(signal_fn)
339     (self.in_socket, self.out_socket) = socket.socketpair(socket.AF_UNIX,
340                                                           socket.SOCK_STREAM)
341     self.in_socket.setblocking(0)
342     self.in_socket.shutdown(socket.SHUT_WR)
343     self.out_socket.shutdown(socket.SHUT_RD)
344     self.set_socket(self.in_socket)
345     self.need_signal = True
346     self.signal_fn = signal_fn
347     self.connected = True
348
349   # this method is overriding an asyncore.dispatcher method
350   def handle_read(self):
351     utils.IgnoreSignals(self.recv, 4096)
352     if self.signal_fn:
353       self.signal_fn()
354     self.need_signal = True
355
356   # this method is overriding an asyncore.dispatcher method
357   def close(self):
358     asyncore.dispatcher.close(self)
359     self.out_socket.close()
360
361   def signal(self):
362     """Signal the asyncore main loop.
363
364     Any data we send here will be ignored, but it will cause the select() call
365     to return.
366
367     """
368     # Yes, there is a race condition here. No, we don't care, at worst we're
369     # sending more than one wakeup token, which doesn't harm at all.
370     if self.need_signal:
371       self.need_signal = False
372       self.out_socket.send("\0")
373
374
375 class Mainloop(object):
376   """Generic mainloop for daemons
377
378   @ivar scheduler: A sched.scheduler object, which can be used to register
379     timed events
380
381   """
382   def __init__(self):
383     """Constructs a new Mainloop instance.
384
385     """
386     self._signal_wait = []
387     self.scheduler = AsyncoreScheduler(time.time)
388
389   @utils.SignalHandled([signal.SIGCHLD])
390   @utils.SignalHandled([signal.SIGTERM])
391   @utils.SignalHandled([signal.SIGINT])
392   def Run(self, signal_handlers=None):
393     """Runs the mainloop.
394
395     @type signal_handlers: dict
396     @param signal_handlers: signal->L{utils.SignalHandler} passed by decorator
397
398     """
399     assert isinstance(signal_handlers, dict) and \
400            len(signal_handlers) > 0, \
401            "Broken SignalHandled decorator"
402     running = True
403     # Start actual main loop
404     while running:
405       if not self.scheduler.empty():
406         try:
407           self.scheduler.run()
408         except SchedulerBreakout:
409           pass
410       else:
411         asyncore.loop(count=1, use_poll=True)
412
413       # Check whether a signal was raised
414       for sig in signal_handlers:
415         handler = signal_handlers[sig]
416         if handler.called:
417           self._CallSignalWaiters(sig)
418           running = sig not in (signal.SIGTERM, signal.SIGINT)
419           handler.Clear()
420
421   def _CallSignalWaiters(self, signum):
422     """Calls all signal waiters for a certain signal.
423
424     @type signum: int
425     @param signum: Signal number
426
427     """
428     for owner in self._signal_wait:
429       owner.OnSignal(signum)
430
431   def RegisterSignal(self, owner):
432     """Registers a receiver for signal notifications
433
434     The receiver must support a "OnSignal(self, signum)" function.
435
436     @type owner: instance
437     @param owner: Receiver
438
439     """
440     self._signal_wait.append(owner)
441
442
443 def GenericMain(daemon_name, optionparser, dirs, check_fn, exec_fn,
444                 multithreaded=False, console_logging=False,
445                 default_ssl_cert=None, default_ssl_key=None,
446                 user=_DEFAULT_RUN_USER, group=_DEFAULT_RUN_GROUP):
447   """Shared main function for daemons.
448
449   @type daemon_name: string
450   @param daemon_name: daemon name
451   @type optionparser: optparse.OptionParser
452   @param optionparser: initialized optionparser with daemon-specific options
453                        (common -f -d options will be handled by this module)
454   @type dirs: list of (string, integer)
455   @param dirs: list of directories that must be created if they don't exist,
456                and the permissions to be used to create them
457   @type check_fn: function which accepts (options, args)
458   @param check_fn: function that checks start conditions and exits if they're
459                    not met
460   @type exec_fn: function which accepts (options, args)
461   @param exec_fn: function that's executed with the daemon's pid file held, and
462                   runs the daemon itself.
463   @type multithreaded: bool
464   @param multithreaded: Whether the daemon uses threads
465   @type console_logging: boolean
466   @param console_logging: if True, the daemon will fall back to the system
467                           console if logging fails
468   @type default_ssl_cert: string
469   @param default_ssl_cert: Default SSL certificate path
470   @type default_ssl_key: string
471   @param default_ssl_key: Default SSL key path
472   @param user: Default user to run as
473   @type user: string
474   @param group: Default group to run as
475   @type group: string
476
477   """
478   optionparser.add_option("-f", "--foreground", dest="fork",
479                           help="Don't detach from the current terminal",
480                           default=True, action="store_false")
481   optionparser.add_option("-d", "--debug", dest="debug",
482                           help="Enable some debug messages",
483                           default=False, action="store_true")
484   optionparser.add_option("--syslog", dest="syslog",
485                           help="Enable logging to syslog (except debug"
486                           " messages); one of 'no', 'yes' or 'only' [%s]" %
487                           constants.SYSLOG_USAGE,
488                           default=constants.SYSLOG_USAGE,
489                           choices=["no", "yes", "only"])
490
491   if daemon_name in constants.DAEMONS_PORTS:
492     default_bind_address = "0.0.0.0"
493     default_port = utils.GetDaemonPort(daemon_name)
494
495     # For networked daemons we allow choosing the port and bind address
496     optionparser.add_option("-p", "--port", dest="port",
497                             help="Network port (default: %s)" % default_port,
498                             default=default_port, type="int")
499     optionparser.add_option("-b", "--bind", dest="bind_address",
500                             help=("Bind address (default: %s)" %
501                                   default_bind_address),
502                             default=default_bind_address, metavar="ADDRESS")
503
504   if default_ssl_key is not None and default_ssl_cert is not None:
505     optionparser.add_option("--no-ssl", dest="ssl",
506                             help="Do not secure HTTP protocol with SSL",
507                             default=True, action="store_false")
508     optionparser.add_option("-K", "--ssl-key", dest="ssl_key",
509                             help=("SSL key path (default: %s)" %
510                                   default_ssl_key),
511                             default=default_ssl_key, type="string",
512                             metavar="SSL_KEY_PATH")
513     optionparser.add_option("-C", "--ssl-cert", dest="ssl_cert",
514                             help=("SSL certificate path (default: %s)" %
515                                   default_ssl_cert),
516                             default=default_ssl_cert, type="string",
517                             metavar="SSL_CERT_PATH")
518
519   # Disable the use of fork(2) if the daemon uses threads
520   utils.no_fork = multithreaded
521
522   options, args = optionparser.parse_args()
523
524   if getattr(options, "ssl", False):
525     ssl_paths = {
526       "certificate": options.ssl_cert,
527       "key": options.ssl_key,
528       }
529
530     for name, path in ssl_paths.iteritems():
531       if not os.path.isfile(path):
532         print >> sys.stderr, "SSL %s file '%s' was not found" % (name, path)
533         sys.exit(constants.EXIT_FAILURE)
534
535     # TODO: By initiating http.HttpSslParams here we would only read the files
536     # once and have a proper validation (isfile returns False on directories)
537     # at the same time.
538
539   if check_fn is not None:
540     check_fn(options, args)
541
542   utils.EnsureDirs(dirs)
543
544   if options.fork:
545     try:
546       uid = pwd.getpwnam(user).pw_uid
547       gid = grp.getgrnam(group).gr_gid
548     except KeyError:
549       raise errors.ConfigurationError("User or group not existing on system:"
550                                       " %s:%s" % (user, group))
551     utils.CloseFDs()
552     utils.Daemonize(constants.DAEMONS_LOGFILES[daemon_name], uid, gid)
553
554   utils.WritePidFile(daemon_name)
555   try:
556     utils.SetupLogging(logfile=constants.DAEMONS_LOGFILES[daemon_name],
557                        debug=options.debug,
558                        stderr_logging=not options.fork,
559                        multithreaded=multithreaded,
560                        program=daemon_name,
561                        syslog=options.syslog,
562                        console_logging=console_logging)
563     logging.info("%s daemon startup", daemon_name)
564     exec_fn(options, args)
565   finally:
566     utils.RemovePidFile(daemon_name)