prepare_node_join: Move daemon SSH files to constants
[ganeti-local] / lib / tools / prepare_node_join.py
1 #
2 #
3
4 # Copyright (C) 2012 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21 """Script to prepare a node for joining a cluster.
22
23 """
24
25 import os
26 import os.path
27 import optparse
28 import sys
29 import logging
30 import errno
31 import OpenSSL
32
33 from ganeti import cli
34 from ganeti import constants
35 from ganeti import errors
36 from ganeti import pathutils
37 from ganeti import utils
38 from ganeti import serializer
39 from ganeti import ht
40 from ganeti import ssh
41 from ganeti import ssconf
42
43
44 _SSH_KEY_LIST_ITEM = \
45   ht.TAnd(ht.TIsLength(3),
46           ht.TItems([
47             ht.TElemOf(constants.SSHK_ALL),
48             ht.Comment("public")(ht.TNonEmptyString),
49             ht.Comment("private")(ht.TNonEmptyString),
50           ]))
51
52 _SSH_KEY_LIST = ht.TListOf(_SSH_KEY_LIST_ITEM)
53
54 _DATA_CHECK = ht.TStrictDict(False, True, {
55   constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
56   constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
57   constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
58   constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
59   })
60
61
62 class JoinError(errors.GenericError):
63   """Local class for reporting errors.
64
65   """
66
67
68 def ParseOptions():
69   """Parses the options passed to the program.
70
71   @return: Options and arguments
72
73   """
74   program = os.path.basename(sys.argv[0])
75
76   parser = optparse.OptionParser(usage="%prog [--dry-run]",
77                                  prog=program)
78   parser.add_option(cli.DEBUG_OPT)
79   parser.add_option(cli.VERBOSE_OPT)
80   parser.add_option(cli.DRY_RUN_OPT)
81
82   (opts, args) = parser.parse_args()
83
84   return VerifyOptions(parser, opts, args)
85
86
87 def VerifyOptions(parser, opts, args):
88   """Verifies options and arguments for correctness.
89
90   """
91   if args:
92     parser.error("No arguments are expected")
93
94   return opts
95
96
97 def SetupLogging(opts):
98   """Configures the logging module.
99
100   """
101   formatter = logging.Formatter("%(asctime)s: %(message)s")
102
103   stderr_handler = logging.StreamHandler()
104   stderr_handler.setFormatter(formatter)
105   if opts.debug:
106     stderr_handler.setLevel(logging.NOTSET)
107   elif opts.verbose:
108     stderr_handler.setLevel(logging.INFO)
109   else:
110     stderr_handler.setLevel(logging.WARNING)
111
112   root_logger = logging.getLogger("")
113   root_logger.setLevel(logging.NOTSET)
114   root_logger.addHandler(stderr_handler)
115
116
117 def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE):
118   """Verifies a certificate against the local node daemon certificate.
119
120   @type cert: string
121   @param cert: Certificate in PEM format (no key)
122
123   """
124   try:
125     OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert)
126   except OpenSSL.crypto.Error, err:
127     pass
128   else:
129     raise JoinError("No private key may be given")
130
131   try:
132     cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
133   except Exception, err:
134     raise errors.X509CertError("(stdin)",
135                                "Unable to load certificate: %s" % err)
136
137   try:
138     noded_pem = utils.ReadFile(_noded_cert_file)
139   except EnvironmentError, err:
140     if err.errno != errno.ENOENT:
141       raise
142
143     logging.debug("Local node certificate was not found (file %s)",
144                   _noded_cert_file)
145     return
146
147   try:
148     key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, noded_pem)
149   except Exception, err:
150     raise errors.X509CertError(_noded_cert_file,
151                                "Unable to load private key: %s" % err)
152
153   ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
154   ctx.use_privatekey(key)
155   ctx.use_certificate(cert)
156   try:
157     ctx.check_privatekey()
158   except OpenSSL.SSL.Error:
159     raise JoinError("Given cluster certificate does not match local key")
160
161
162 def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
163   """Verifies cluster certificate.
164
165   @type data: dict
166
167   """
168   cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
169   if cert:
170     _verify_fn(cert)
171
172
173 def _VerifyClusterName(name, _ss_cluster_name_file=None):
174   """Verifies cluster name against a local cluster name.
175
176   @type name: string
177   @param name: Cluster name
178
179   """
180   if _ss_cluster_name_file is None:
181     _ss_cluster_name_file = \
182       ssconf.SimpleStore().KeyToFilename(constants.SS_CLUSTER_NAME)
183
184   try:
185     local_name = utils.ReadOneLineFile(_ss_cluster_name_file)
186   except EnvironmentError, err:
187     if err.errno != errno.ENOENT:
188       raise
189
190     logging.debug("Local cluster name was not found (file %s)",
191                   _ss_cluster_name_file)
192   else:
193     if name != local_name:
194       raise JoinError("Current cluster name is '%s'" % local_name)
195
196
197 def VerifyClusterName(data, _verify_fn=_VerifyClusterName):
198   """Verifies cluster name.
199
200   @type data: dict
201
202   """
203   name = data.get(constants.SSHS_CLUSTER_NAME)
204   if name:
205     _verify_fn(name)
206   else:
207     raise JoinError("Cluster name must be specified")
208
209
210 def _UpdateKeyFiles(keys, dry_run, keyfiles):
211   """Updates SSH key files.
212
213   @type keys: sequence of tuple; (string, string, string)
214   @param keys: Keys to write, tuples consist of key type
215     (L{constants.SSHK_ALL}), public and private key
216   @type dry_run: boolean
217   @param dry_run: Whether to perform a dry run
218   @type keyfiles: dict; (string as key, tuple with (string, string) as values)
219   @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
220     names; value tuples consist of public key filename and private key filename
221
222   """
223   assert set(keyfiles) == constants.SSHK_ALL
224
225   for (kind, private_key, public_key) in keys:
226     (private_file, public_file) = keyfiles[kind]
227
228     logging.debug("Writing %s ...", private_file)
229     utils.WriteFile(private_file, data=private_key, mode=0600,
230                     backup=True, dry_run=dry_run)
231
232     logging.debug("Writing %s ...", public_file)
233     utils.WriteFile(public_file, data=public_key, mode=0644,
234                     backup=True, dry_run=dry_run)
235
236
237 def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
238                     _keyfiles=None):
239   """Updates SSH daemon's keys.
240
241   Unless C{dry_run} is set, the daemon is restarted at the end.
242
243   @type data: dict
244   @param data: Input data
245   @type dry_run: boolean
246   @param dry_run: Whether to perform a dry run
247
248   """
249   keys = data.get(constants.SSHS_SSH_HOST_KEY)
250   if not keys:
251     return
252
253   if _keyfiles is None:
254     _keyfiles = constants.SSH_DAEMON_KEYFILES
255
256   logging.info("Updating SSH daemon key files")
257   _UpdateKeyFiles(keys, dry_run, _keyfiles)
258
259   if dry_run:
260     logging.info("This is a dry run, not restarting SSH daemon")
261   else:
262     result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
263                         interactive=True)
264     if result.failed:
265       raise JoinError("Could not reload SSH keys, command '%s'"
266                       " had exitcode %s and error %s" %
267                        (result.cmd, result.exit_code, result.output))
268
269
270 def UpdateSshRoot(data, dry_run, _homedir_fn=None):
271   """Updates root's SSH keys.
272
273   Root's C{authorized_keys} file is also updated with new public keys.
274
275   @type data: dict
276   @param data: Input data
277   @type dry_run: boolean
278   @param dry_run: Whether to perform a dry run
279
280   """
281   keys = data.get(constants.SSHS_SSH_ROOT_KEY)
282   if not keys:
283     return
284
285   (dsa_private_file, dsa_public_file, auth_keys_file) = \
286     ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
287                      kind=constants.SSHK_DSA, _homedir_fn=_homedir_fn)
288   (rsa_private_file, rsa_public_file, _) = \
289     ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
290                      kind=constants.SSHK_RSA, _homedir_fn=_homedir_fn)
291
292   _UpdateKeyFiles(keys, dry_run, {
293     constants.SSHK_RSA: (rsa_private_file, rsa_public_file),
294     constants.SSHK_DSA: (dsa_private_file, dsa_public_file),
295     })
296
297   if dry_run:
298     logging.info("This is a dry run, not modifying %s", auth_keys_file)
299   else:
300     for (_, _, public_key) in keys:
301       utils.AddAuthorizedKey(auth_keys_file, public_key)
302
303
304 def LoadData(raw):
305   """Parses and verifies input data.
306
307   @rtype: dict
308
309   """
310   try:
311     data = serializer.LoadJson(raw)
312   except Exception, err:
313     raise errors.ParseError("Can't parse input data: %s" % err)
314
315   if not _DATA_CHECK(data):
316     raise errors.ParseError("Input data does not match expected format: %s" %
317                             _DATA_CHECK)
318
319   return data
320
321
322 def Main():
323   """Main routine.
324
325   """
326   opts = ParseOptions()
327
328   SetupLogging(opts)
329
330   try:
331     data = LoadData(sys.stdin.read())
332
333     # Check if input data is correct
334     VerifyClusterName(data)
335     VerifyCertificate(data)
336
337     # Update SSH files
338     UpdateSshDaemon(data, opts.dry_run)
339     UpdateSshRoot(data, opts.dry_run)
340
341     logging.info("Setup finished successfully")
342   except Exception, err: # pylint: disable=W0703
343     logging.debug("Caught unhandled exception", exc_info=True)
344
345     (retcode, message) = cli.FormatError(err)
346     logging.error(message)
347
348     return retcode
349   else:
350     return constants.EXIT_SUCCESS