rpc: Adapt the callbacks to the new encoder type
[ganeti-local] / lib / impexpd / __init__.py
1 #
2 #
3
4 # Copyright (C) 2010 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Classes and functions for import/export daemon.
23
24 """
25
26 import os
27 import re
28 import socket
29 import logging
30 import signal
31 import errno
32 import time
33 from cStringIO import StringIO
34
35 from ganeti import constants
36 from ganeti import errors
37 from ganeti import utils
38 from ganeti import netutils
39
40
41 #: Used to recognize point at which socat(1) starts to listen on its socket.
42 #: The local address is required for the remote peer to connect (in particular
43 #: the port number).
44 LISTENING_RE = re.compile(r"^listening on\s+"
45                           r"AF=(?P<family>\d+)\s+"
46                           r"(?P<address>.+):(?P<port>\d+)$", re.I)
47
48 #: Used to recognize point at which socat(1) is sending data over the wire
49 TRANSFER_LOOP_RE = re.compile(r"^starting data transfer loop with FDs\s+.*$",
50                               re.I)
51
52 SOCAT_LOG_DEBUG = "D"
53 SOCAT_LOG_INFO = "I"
54 SOCAT_LOG_NOTICE = "N"
55 SOCAT_LOG_WARNING = "W"
56 SOCAT_LOG_ERROR = "E"
57 SOCAT_LOG_FATAL = "F"
58
59 SOCAT_LOG_IGNORE = frozenset([
60   SOCAT_LOG_DEBUG,
61   SOCAT_LOG_INFO,
62   SOCAT_LOG_NOTICE,
63   ])
64
65 #: Used to parse GNU dd(1) statistics
66 DD_INFO_RE = re.compile(r"^(?P<bytes>\d+)\s*byte(?:|s)\s.*\scopied,\s*"
67                         r"(?P<seconds>[\d.]+)\s*s(?:|econds),.*$", re.I)
68
69 #: Used to ignore "N+N records in/out" on dd(1)'s stderr
70 DD_STDERR_IGNORE = re.compile(r"^\d+\+\d+\s*records\s+(?:in|out)$", re.I)
71
72 #: Signal upon which dd(1) will print statistics (on some platforms, SIGINFO is
73 #: unavailable and SIGUSR1 is used instead)
74 DD_INFO_SIGNAL = getattr(signal, "SIGINFO", signal.SIGUSR1)
75
76 #: Buffer size: at most this many bytes are transferred at once
77 BUFSIZE = 1024 * 1024
78
79 # Common options for socat
80 SOCAT_TCP_OPTS = ["keepalive", "keepidle=60", "keepintvl=10", "keepcnt=5"]
81 SOCAT_OPENSSL_OPTS = ["verify=1", "method=TLSv1",
82                       "cipher=%s" % constants.OPENSSL_CIPHERS]
83
84 if constants.SOCAT_USE_COMPRESS:
85   # Disables all compression in by OpenSSL. Only supported in patched versions
86   # of socat (as of November 2010). See INSTALL for more information.
87   SOCAT_OPENSSL_OPTS.append("compress=none")
88
89 SOCAT_OPTION_MAXLEN = 400
90
91 (PROG_OTHER,
92  PROG_SOCAT,
93  PROG_DD,
94  PROG_DD_PID,
95  PROG_EXP_SIZE) = range(1, 6)
96 PROG_ALL = frozenset([
97   PROG_OTHER,
98   PROG_SOCAT,
99   PROG_DD,
100   PROG_DD_PID,
101   PROG_EXP_SIZE,
102   ])
103
104
105 class CommandBuilder(object):
106   def __init__(self, mode, opts, socat_stderr_fd, dd_stderr_fd, dd_pid_fd):
107     """Initializes this class.
108
109     @param mode: Daemon mode (import or export)
110     @param opts: Options object
111     @type socat_stderr_fd: int
112     @param socat_stderr_fd: File descriptor socat should write its stderr to
113     @type dd_stderr_fd: int
114     @param dd_stderr_fd: File descriptor dd should write its stderr to
115     @type dd_pid_fd: int
116     @param dd_pid_fd: File descriptor the child should write dd's PID to
117
118     """
119     self._opts = opts
120     self._mode = mode
121     self._socat_stderr_fd = socat_stderr_fd
122     self._dd_stderr_fd = dd_stderr_fd
123     self._dd_pid_fd = dd_pid_fd
124
125     assert (self._opts.magic is None or
126             constants.IE_MAGIC_RE.match(self._opts.magic))
127
128   @staticmethod
129   def GetBashCommand(cmd):
130     """Prepares a command to be run in Bash.
131
132     """
133     return ["bash", "-o", "errexit", "-o", "pipefail", "-c", cmd]
134
135   def _GetSocatCommand(self):
136     """Returns the socat command.
137
138     """
139     common_addr_opts = SOCAT_TCP_OPTS + SOCAT_OPENSSL_OPTS + [
140       "key=%s" % self._opts.key,
141       "cert=%s" % self._opts.cert,
142       "cafile=%s" % self._opts.ca,
143       ]
144
145     if self._opts.bind is not None:
146       common_addr_opts.append("bind=%s" % self._opts.bind)
147
148     assert not (self._opts.ipv4 and self._opts.ipv6)
149
150     if self._opts.ipv4:
151       common_addr_opts.append("pf=ipv4")
152     elif self._opts.ipv6:
153       common_addr_opts.append("pf=ipv6")
154
155     if self._mode == constants.IEM_IMPORT:
156       if self._opts.port is None:
157         port = 0
158       else:
159         port = self._opts.port
160
161       addr1 = [
162         "OPENSSL-LISTEN:%s" % port,
163         "reuseaddr",
164
165         # Retry to listen if connection wasn't established successfully, up to
166         # 100 times a second. Note that this still leaves room for DoS attacks.
167         "forever",
168         "intervall=0.01",
169         ] + common_addr_opts
170       addr2 = ["stdout"]
171
172     elif self._mode == constants.IEM_EXPORT:
173       if self._opts.host and netutils.IP6Address.IsValid(self._opts.host):
174         host = "[%s]" % self._opts.host
175       else:
176         host = self._opts.host
177
178       addr1 = ["stdin"]
179       addr2 = [
180         "OPENSSL:%s:%s" % (host, self._opts.port),
181
182         # How long to wait per connection attempt
183         "connect-timeout=%s" % self._opts.connect_timeout,
184
185         # Retry a few times before giving up to connect (once per second)
186         "retry=%s" % self._opts.connect_retries,
187         "intervall=1",
188         ] + common_addr_opts
189
190     else:
191       raise errors.GenericError("Invalid mode '%s'" % self._mode)
192
193     for i in [addr1, addr2]:
194       for value in i:
195         if len(value) > SOCAT_OPTION_MAXLEN:
196           raise errors.GenericError("Socat option longer than %s"
197                                     " characters: %r" %
198                                     (SOCAT_OPTION_MAXLEN, value))
199         if "," in value:
200           raise errors.GenericError("Comma not allowed in socat option"
201                                     " value: %r" % value)
202
203     return [
204       constants.SOCAT_PATH,
205
206       # Log to stderr
207       "-ls",
208
209       # Log level
210       "-d", "-d",
211
212       # Buffer size
213       "-b%s" % BUFSIZE,
214
215       # Unidirectional mode, the first address is only used for reading, and the
216       # second address is only used for writing
217       "-u",
218
219       ",".join(addr1), ",".join(addr2)
220       ]
221
222   def _GetMagicCommand(self):
223     """Returns the command to read/write the magic value.
224
225     """
226     if not self._opts.magic:
227       return None
228
229     # Prefix to ensure magic isn't interpreted as option to "echo"
230     magic = "M=%s" % self._opts.magic
231
232     cmd = StringIO()
233
234     if self._mode == constants.IEM_IMPORT:
235       cmd.write("{ ")
236       cmd.write(utils.ShellQuoteArgs(["read", "-n", str(len(magic)), "magic"]))
237       cmd.write(" && ")
238       cmd.write("if test \"$magic\" != %s; then" % utils.ShellQuote(magic))
239       cmd.write(" echo %s >&2;" % utils.ShellQuote("Magic value mismatch"))
240       cmd.write(" exit 1;")
241       cmd.write("fi;")
242       cmd.write(" }")
243
244     elif self._mode == constants.IEM_EXPORT:
245       cmd.write(utils.ShellQuoteArgs(["echo", "-E", "-n", magic]))
246
247     else:
248       raise errors.GenericError("Invalid mode '%s'" % self._mode)
249
250     return cmd.getvalue()
251
252   def _GetDdCommand(self):
253     """Returns the command for measuring throughput.
254
255     """
256     dd_cmd = StringIO()
257
258     magic_cmd = self._GetMagicCommand()
259     if magic_cmd:
260       dd_cmd.write("{ ")
261       dd_cmd.write(magic_cmd)
262       dd_cmd.write(" && ")
263
264     dd_cmd.write("{ ")
265     # Setting LC_ALL since we want to parse the output and explicitely
266     # redirecting stdin, as the background process (dd) would have /dev/null as
267     # stdin otherwise
268     dd_cmd.write("LC_ALL=C dd bs=%s <&0 2>&%d & pid=${!};" %
269                  (BUFSIZE, self._dd_stderr_fd))
270     # Send PID to daemon
271     dd_cmd.write(" echo $pid >&%d;" % self._dd_pid_fd)
272     # And wait for dd
273     dd_cmd.write(" wait $pid;")
274     dd_cmd.write(" }")
275
276     if magic_cmd:
277       dd_cmd.write(" }")
278
279     return dd_cmd.getvalue()
280
281   def _GetTransportCommand(self):
282     """Returns the command for the transport part of the daemon.
283
284     """
285     socat_cmd = ("%s 2>&%d" %
286                  (utils.ShellQuoteArgs(self._GetSocatCommand()),
287                   self._socat_stderr_fd))
288     dd_cmd = self._GetDdCommand()
289
290     compr = self._opts.compress
291
292     assert compr in constants.IEC_ALL
293
294     parts = []
295
296     if self._mode == constants.IEM_IMPORT:
297       parts.append(socat_cmd)
298
299       if compr == constants.IEC_GZIP:
300         parts.append("gunzip -c")
301
302       parts.append(dd_cmd)
303
304     elif self._mode == constants.IEM_EXPORT:
305       parts.append(dd_cmd)
306
307       if compr == constants.IEC_GZIP:
308         parts.append("gzip -c")
309
310       parts.append(socat_cmd)
311
312     else:
313       raise errors.GenericError("Invalid mode '%s'" % self._mode)
314
315     # TODO: Run transport as separate user
316     # The transport uses its own shell to simplify running it as a separate user
317     # in the future.
318     return self.GetBashCommand(" | ".join(parts))
319
320   def GetCommand(self):
321     """Returns the complete child process command.
322
323     """
324     transport_cmd = self._GetTransportCommand()
325
326     buf = StringIO()
327
328     if self._opts.cmd_prefix:
329       buf.write(self._opts.cmd_prefix)
330       buf.write(" ")
331
332     buf.write(utils.ShellQuoteArgs(transport_cmd))
333
334     if self._opts.cmd_suffix:
335       buf.write(" ")
336       buf.write(self._opts.cmd_suffix)
337
338     return self.GetBashCommand(buf.getvalue())
339
340
341 def _VerifyListening(family, address, port):
342   """Verify address given as listening address by socat.
343
344   """
345   if family not in (socket.AF_INET, socket.AF_INET6):
346     raise errors.GenericError("Address family %r not supported" % family)
347
348   if (family == socket.AF_INET6 and address.startswith("[") and
349       address.endswith("]")):
350     address = address.lstrip("[").rstrip("]")
351
352   try:
353     packed_address = socket.inet_pton(family, address)
354   except socket.error:
355     raise errors.GenericError("Invalid address %r for family %s" %
356                               (address, family))
357
358   return (socket.inet_ntop(family, packed_address), port)
359
360
361 class ChildIOProcessor(object):
362   def __init__(self, debug, status_file, logger, throughput_samples, exp_size):
363     """Initializes this class.
364
365     """
366     self._debug = debug
367     self._status_file = status_file
368     self._logger = logger
369
370     self._splitter = dict([(prog, utils.LineSplitter(self._ProcessOutput, prog))
371                            for prog in PROG_ALL])
372
373     self._dd_pid = None
374     self._dd_ready = False
375     self._dd_tp_samples = throughput_samples
376     self._dd_progress = []
377
378     # Expected size of transferred data
379     self._exp_size = exp_size
380
381   def GetLineSplitter(self, prog):
382     """Returns the line splitter for a program.
383
384     """
385     return self._splitter[prog]
386
387   def FlushAll(self):
388     """Flushes all line splitters.
389
390     """
391     for ls in self._splitter.itervalues():
392       ls.flush()
393
394   def CloseAll(self):
395     """Closes all line splitters.
396
397     """
398     for ls in self._splitter.itervalues():
399       ls.close()
400     self._splitter.clear()
401
402   def NotifyDd(self):
403     """Tells dd(1) to write statistics.
404
405     """
406     if self._dd_pid is None:
407       # Can't notify
408       return False
409
410     if not self._dd_ready:
411       # There's a race condition between starting the program and sending
412       # signals.  The signal handler is only registered after some time, so we
413       # have to check whether the program is ready. If it isn't, sending a
414       # signal will invoke the default handler (and usually abort the program).
415       if not utils.IsProcessHandlingSignal(self._dd_pid, DD_INFO_SIGNAL):
416         logging.debug("dd is not yet ready for signal %s", DD_INFO_SIGNAL)
417         return False
418
419       logging.debug("dd is now handling signal %s", DD_INFO_SIGNAL)
420       self._dd_ready = True
421
422     logging.debug("Sending signal %s to PID %s", DD_INFO_SIGNAL, self._dd_pid)
423     try:
424       os.kill(self._dd_pid, DD_INFO_SIGNAL)
425     except EnvironmentError, err:
426       if err.errno != errno.ESRCH:
427         raise
428
429       # Process no longer exists
430       logging.debug("dd exited")
431       self._dd_pid = None
432
433     return True
434
435   def _ProcessOutput(self, line, prog):
436     """Takes care of child process output.
437
438     @type line: string
439     @param line: Child output line
440     @type prog: number
441     @param prog: Program from which the line originates
442
443     """
444     force_update = False
445     forward_line = line
446
447     if prog == PROG_SOCAT:
448       level = None
449       parts = line.split(None, 4)
450
451       if len(parts) == 5:
452         (_, _, _, level, msg) = parts
453
454         force_update = self._ProcessSocatOutput(self._status_file, level, msg)
455
456         if self._debug or (level and level not in SOCAT_LOG_IGNORE):
457           forward_line = "socat: %s %s" % (level, msg)
458         else:
459           forward_line = None
460       else:
461         forward_line = "socat: %s" % line
462
463     elif prog == PROG_DD:
464       (should_forward, force_update) = self._ProcessDdOutput(line)
465
466       if should_forward or self._debug:
467         forward_line = "dd: %s" % line
468       else:
469         forward_line = None
470
471     elif prog == PROG_DD_PID:
472       if self._dd_pid:
473         raise RuntimeError("dd PID reported more than once")
474       logging.debug("Received dd PID %r", line)
475       self._dd_pid = int(line)
476       forward_line = None
477
478     elif prog == PROG_EXP_SIZE:
479       logging.debug("Received predicted size %r", line)
480       forward_line = None
481
482       if line:
483         try:
484           exp_size = utils.BytesToMebibyte(int(line))
485         except (ValueError, TypeError), err:
486           logging.error("Failed to convert predicted size %r to number: %s",
487                         line, err)
488           exp_size = None
489       else:
490         exp_size = None
491
492       self._exp_size = exp_size
493
494     if forward_line:
495       self._logger.info(forward_line)
496       self._status_file.AddRecentOutput(forward_line)
497
498     self._status_file.Update(force_update)
499
500   @staticmethod
501   def _ProcessSocatOutput(status_file, level, msg):
502     """Interprets socat log output.
503
504     """
505     if level == SOCAT_LOG_NOTICE:
506       if status_file.GetListenPort() is None:
507         # TODO: Maybe implement timeout to not listen forever
508         m = LISTENING_RE.match(msg)
509         if m:
510           (_, port) = _VerifyListening(int(m.group("family")),
511                                        m.group("address"),
512                                        int(m.group("port")))
513
514           status_file.SetListenPort(port)
515           return True
516
517       if not status_file.GetConnected():
518         m = TRANSFER_LOOP_RE.match(msg)
519         if m:
520           logging.debug("Connection established")
521           status_file.SetConnected()
522           return True
523
524     return False
525
526   def _ProcessDdOutput(self, line):
527     """Interprets a line of dd(1)'s output.
528
529     """
530     m = DD_INFO_RE.match(line)
531     if m:
532       seconds = float(m.group("seconds"))
533       mbytes = utils.BytesToMebibyte(int(m.group("bytes")))
534       self._UpdateDdProgress(seconds, mbytes)
535       return (False, True)
536
537     m = DD_STDERR_IGNORE.match(line)
538     if m:
539       # Ignore
540       return (False, False)
541
542     # Forward line
543     return (True, False)
544
545   def _UpdateDdProgress(self, seconds, mbytes):
546     """Updates the internal status variables for dd(1) progress.
547
548     @type seconds: float
549     @param seconds: Timestamp of this update
550     @type mbytes: float
551     @param mbytes: Total number of MiB transferred so far
552
553     """
554     # Add latest sample
555     self._dd_progress.append((seconds, mbytes))
556
557     # Remove old samples
558     del self._dd_progress[:-self._dd_tp_samples]
559
560     # Calculate throughput
561     throughput = _CalcThroughput(self._dd_progress)
562
563     # Calculate percent and ETA
564     percent = None
565     eta = None
566
567     if self._exp_size is not None:
568       if self._exp_size != 0:
569         percent = max(0, min(100, (100.0 * mbytes) / self._exp_size))
570
571       if throughput:
572         eta = max(0, float(self._exp_size - mbytes) / throughput)
573
574     self._status_file.SetProgress(mbytes, throughput, percent, eta)
575
576
577 def _CalcThroughput(samples):
578   """Calculates the throughput in MiB/second.
579
580   @type samples: sequence
581   @param samples: List of samples, each consisting of a (timestamp, mbytes)
582                   tuple
583   @rtype: float or None
584   @return: Throughput in MiB/second
585
586   """
587   if len(samples) < 2:
588     # Can't calculate throughput
589     return None
590
591   (start_time, start_mbytes) = samples[0]
592   (end_time, end_mbytes) = samples[-1]
593
594   return (float(end_mbytes) - start_mbytes) / (float(end_time) - start_time)