Factorize SSL context setup for certificate check
[ganeti-local] / lib / tools / prepare_node_join.py
1 #
2 #
3
4 # Copyright (C) 2012 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21 """Script to prepare a node for joining a cluster.
22
23 """
24
25 import os
26 import os.path
27 import optparse
28 import sys
29 import logging
30 import errno
31 import OpenSSL
32
33 from ganeti import cli
34 from ganeti import constants
35 from ganeti import errors
36 from ganeti import pathutils
37 from ganeti import utils
38 from ganeti import serializer
39 from ganeti import ht
40 from ganeti import ssh
41 from ganeti import ssconf
42
43
44 _SSH_KEY_LIST_ITEM = \
45   ht.TAnd(ht.TIsLength(3),
46           ht.TItems([
47             ht.TElemOf(constants.SSHK_ALL),
48             ht.Comment("public")(ht.TNonEmptyString),
49             ht.Comment("private")(ht.TNonEmptyString),
50           ]))
51
52 _SSH_KEY_LIST = ht.TListOf(_SSH_KEY_LIST_ITEM)
53
54 _DATA_CHECK = ht.TStrictDict(False, True, {
55   constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
56   constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
57   constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
58   constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
59   })
60
61
62 class JoinError(errors.GenericError):
63   """Local class for reporting errors.
64
65   """
66
67
68 def ParseOptions():
69   """Parses the options passed to the program.
70
71   @return: Options and arguments
72
73   """
74   program = os.path.basename(sys.argv[0])
75
76   parser = optparse.OptionParser(usage="%prog [--dry-run]",
77                                  prog=program)
78   parser.add_option(cli.DEBUG_OPT)
79   parser.add_option(cli.VERBOSE_OPT)
80   parser.add_option(cli.DRY_RUN_OPT)
81
82   (opts, args) = parser.parse_args()
83
84   return VerifyOptions(parser, opts, args)
85
86
87 def VerifyOptions(parser, opts, args):
88   """Verifies options and arguments for correctness.
89
90   """
91   if args:
92     parser.error("No arguments are expected")
93
94   return opts
95
96
97 def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE):
98   """Verifies a certificate against the local node daemon certificate.
99
100   @type cert: string
101   @param cert: Certificate in PEM format (no key)
102
103   """
104   try:
105     OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert)
106   except OpenSSL.crypto.Error, err:
107     pass
108   else:
109     raise JoinError("No private key may be given")
110
111   try:
112     cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
113   except Exception, err:
114     raise errors.X509CertError("(stdin)",
115                                "Unable to load certificate: %s" % err)
116
117   try:
118     noded_pem = utils.ReadFile(_noded_cert_file)
119   except EnvironmentError, err:
120     if err.errno != errno.ENOENT:
121       raise
122
123     logging.debug("Local node certificate was not found (file %s)",
124                   _noded_cert_file)
125     return
126
127   try:
128     key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, noded_pem)
129   except Exception, err:
130     raise errors.X509CertError(_noded_cert_file,
131                                "Unable to load private key: %s" % err)
132
133   check_fn = utils.PrepareX509CertKeyCheck(cert, key)
134   try:
135     check_fn()
136   except OpenSSL.SSL.Error:
137     raise JoinError("Given cluster certificate does not match local key")
138
139
140 def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
141   """Verifies cluster certificate.
142
143   @type data: dict
144
145   """
146   cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
147   if cert:
148     _verify_fn(cert)
149
150
151 def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName):
152   """Verifies cluster name.
153
154   @type data: dict
155
156   """
157   name = data.get(constants.SSHS_CLUSTER_NAME)
158   if name:
159     _verify_fn(name)
160   else:
161     raise JoinError("Cluster name must be specified")
162
163
164 def _UpdateKeyFiles(keys, dry_run, keyfiles):
165   """Updates SSH key files.
166
167   @type keys: sequence of tuple; (string, string, string)
168   @param keys: Keys to write, tuples consist of key type
169     (L{constants.SSHK_ALL}), public and private key
170   @type dry_run: boolean
171   @param dry_run: Whether to perform a dry run
172   @type keyfiles: dict; (string as key, tuple with (string, string) as values)
173   @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
174     names; value tuples consist of public key filename and private key filename
175
176   """
177   assert set(keyfiles) == constants.SSHK_ALL
178
179   for (kind, private_key, public_key) in keys:
180     (private_file, public_file) = keyfiles[kind]
181
182     logging.debug("Writing %s ...", private_file)
183     utils.WriteFile(private_file, data=private_key, mode=0600,
184                     backup=True, dry_run=dry_run)
185
186     logging.debug("Writing %s ...", public_file)
187     utils.WriteFile(public_file, data=public_key, mode=0644,
188                     backup=True, dry_run=dry_run)
189
190
191 def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
192                     _keyfiles=None):
193   """Updates SSH daemon's keys.
194
195   Unless C{dry_run} is set, the daemon is restarted at the end.
196
197   @type data: dict
198   @param data: Input data
199   @type dry_run: boolean
200   @param dry_run: Whether to perform a dry run
201
202   """
203   keys = data.get(constants.SSHS_SSH_HOST_KEY)
204   if not keys:
205     return
206
207   if _keyfiles is None:
208     _keyfiles = constants.SSH_DAEMON_KEYFILES
209
210   logging.info("Updating SSH daemon key files")
211   _UpdateKeyFiles(keys, dry_run, _keyfiles)
212
213   if dry_run:
214     logging.info("This is a dry run, not restarting SSH daemon")
215   else:
216     result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
217                         interactive=True)
218     if result.failed:
219       raise JoinError("Could not reload SSH keys, command '%s'"
220                       " had exitcode %s and error %s" %
221                        (result.cmd, result.exit_code, result.output))
222
223
224 def UpdateSshRoot(data, dry_run, _homedir_fn=None):
225   """Updates root's SSH keys.
226
227   Root's C{authorized_keys} file is also updated with new public keys.
228
229   @type data: dict
230   @param data: Input data
231   @type dry_run: boolean
232   @param dry_run: Whether to perform a dry run
233
234   """
235   keys = data.get(constants.SSHS_SSH_ROOT_KEY)
236   if not keys:
237     return
238
239   (auth_keys_file, keyfiles) = \
240     ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
241                         _homedir_fn=_homedir_fn)
242
243   _UpdateKeyFiles(keys, dry_run, keyfiles)
244
245   if dry_run:
246     logging.info("This is a dry run, not modifying %s", auth_keys_file)
247   else:
248     for (_, _, public_key) in keys:
249       utils.AddAuthorizedKey(auth_keys_file, public_key)
250
251
252 def LoadData(raw):
253   """Parses and verifies input data.
254
255   @rtype: dict
256
257   """
258   return serializer.LoadAndVerifyJson(raw, _DATA_CHECK)
259
260
261 def Main():
262   """Main routine.
263
264   """
265   opts = ParseOptions()
266
267   utils.SetupToolLogging(opts.debug, opts.verbose)
268
269   try:
270     data = LoadData(sys.stdin.read())
271
272     # Check if input data is correct
273     VerifyClusterName(data)
274     VerifyCertificate(data)
275
276     # Update SSH files
277     UpdateSshDaemon(data, opts.dry_run)
278     UpdateSshRoot(data, opts.dry_run)
279
280     logging.info("Setup finished successfully")
281   except Exception, err: # pylint: disable=W0703
282     logging.debug("Caught unhandled exception", exc_info=True)
283
284     (retcode, message) = cli.FormatError(err)
285     logging.error(message)
286
287     return retcode
288   else:
289     return constants.EXIT_SUCCESS