4 # Copyright (C) 2012 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
21 """Script to prepare a node for joining a cluster.
33 from ganeti import cli
34 from ganeti import constants
35 from ganeti import errors
36 from ganeti import pathutils
37 from ganeti import utils
38 from ganeti import serializer
40 from ganeti import ssh
41 from ganeti import ssconf
#: Check for a single SSH key entry: a (key type, private key, public key)
#: triple. The element order (and thus the descriptive labels below) must match
#: the way L{_UpdateKeyFiles} unpacks each entry.
_SSH_KEY_LIST_ITEM = \
  ht.TAnd(ht.TIsLength(3),
          ht.TItems([
            ht.TElemOf(constants.SSHK_ALL),
            ht.Comment("private")(ht.TNonEmptyString),
            ht.Comment("public")(ht.TNonEmptyString),
            ]))

#: Check for a list of SSH key entries
_SSH_KEY_LIST = ht.TListOf(_SSH_KEY_LIST_ITEM)

#: Check for the whole input document read by L{LoadData}; presumably the two
#: booleans are ht.TStrictDict's "require all keys" (off) and "reject unknown
#: keys" (on) flags -- NOTE(review): confirm against ganeti.ht
_DATA_CHECK = ht.TStrictDict(False, True, {
  constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
  constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
  constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
  constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
  })


class JoinError(errors.GenericError):
  """Error raised while preparing this node for joining a cluster.

  """


def ParseOptions():
  """Parses the options passed to the program.

  Positional arguments are not accepted; see L{VerifyOptions}.

  @return: Parsed options (as returned by L{VerifyOptions})

  """
  program = os.path.basename(sys.argv[0])

  parser = optparse.OptionParser(usage="%prog [--dry-run]",
                                 prog=program)
  parser.add_option(cli.DEBUG_OPT)
  parser.add_option(cli.VERBOSE_OPT)
  parser.add_option(cli.DRY_RUN_OPT)

  (opts, args) = parser.parse_args()

  return VerifyOptions(parser, opts, args)
def VerifyOptions(parser, opts, args):
  """Checks the parsed command line for correctness.

  This program takes no positional arguments; any that were given are
  reported through C{parser.error}.

  @param parser: Option parser which produced C{opts} and C{args}
  @param opts: Parsed options
  @param args: Leftover positional arguments
  @return: C{opts}, unmodified

  """
  has_args = bool(args)
  if has_args:
    # With a stock optparse parser, error() prints the usage and exits
    parser.error("No arguments are expected")

  return opts


def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE):
  """Checks a cluster certificate against the local node daemon certificate.

  The input must be a bare certificate; input which parses as a private key
  is refused. If the local node daemon certificate file exists, the private
  key read from it must match the given certificate; if the file does not
  exist, the comparison is skipped.

  @type cert: string
  @param cert: Certificate in PEM format (no key)
  @param _noded_cert_file: Path of the local node daemon certificate file
  @raise JoinError: If a private key was given or the certificate does not
    match the local key
  @raise errors.X509CertError: If the certificate or the local private key
    cannot be loaded

  """
  pem_type = OpenSSL.crypto.FILETYPE_PEM

  # If the input can be loaded as a private key, key material was supplied
  try:
    OpenSSL.crypto.load_privatekey(pem_type, cert)
    is_private_key = True
  except OpenSSL.crypto.Error:
    is_private_key = False
  if is_private_key:
    raise JoinError("No private key may be given")

  try:
    x509_cert = OpenSSL.crypto.load_certificate(pem_type, cert)
  except Exception as err:
    raise errors.X509CertError("(stdin)",
                               "Unable to load certificate: %s" % err)

  try:
    local_pem = utils.ReadFile(_noded_cert_file)
  except EnvironmentError as err:
    if err.errno == errno.ENOENT:
      # Nothing to compare against
      logging.debug("Local node certificate was not found (file %s)",
                    _noded_cert_file)
      return
    raise

  try:
    local_key = OpenSSL.crypto.load_privatekey(pem_type, local_pem)
  except Exception as err:
    raise errors.X509CertError(_noded_cert_file,
                               "Unable to load private key: %s" % err)

  verify_pair = utils.PrepareX509CertKeyCheck(x509_cert, local_key)
  try:
    verify_pair()
  except OpenSSL.SSL.Error:
    raise JoinError("Given cluster certificate does not match local key")
def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
  """Checks the node daemon certificate contained in the input, if any.

  @type data: dict
  @param data: Input data, as returned by L{LoadData}
  @param _verify_fn: Function called with the certificate

  """
  pem = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
  # The certificate is optional; only a non-empty value is checked
  if not pem:
    return
  _verify_fn(pem)


def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName):
  """Checks the cluster name contained in the input.

  Unlike the certificate, the cluster name is mandatory.

  @type data: dict
  @param data: Input data, as returned by L{LoadData}
  @param _verify_fn: Function called with the cluster name
  @raise JoinError: If no (or an empty) cluster name was given

  """
  cluster_name = data.get(constants.SSHS_CLUSTER_NAME)
  if not cluster_name:
    raise JoinError("Cluster name must be specified")
  _verify_fn(cluster_name)
def _UpdateKeyFiles(keys, dry_run, keyfiles):
  """Updates SSH key files.

  @type keys: sequence of tuple; (string, string, string)
  @param keys: Keys to write, tuples consist of key type
    (L{constants.SSHK_ALL}), private and public key
  @type dry_run: boolean
  @param dry_run: Whether to perform a dry run
  @type keyfiles: dict; (string as key, tuple with (string, string) as values)
  @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
    names; value tuples consist of private key filename and public key filename

  """
  # Internal invariant: callers provide file names for every key type
  assert set(keyfiles) == constants.SSHK_ALL

  for (kind, private_key, public_key) in keys:
    (private_file, public_file) = keyfiles[kind]

    # The private key is written owner-only (0600), the public key 0644;
    # backup=True presumably keeps a copy of any existing file -- see
    # utils.WriteFile
    logging.debug("Writing %s ...", private_file)
    utils.WriteFile(private_file, data=private_key, mode=0o600,
                    backup=True, dry_run=dry_run)

    logging.debug("Writing %s ...", public_file)
    utils.WriteFile(public_file, data=public_key, mode=0o644,
                    backup=True, dry_run=dry_run)
def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
                    _keyfiles=None):
  """Writes the SSH daemon's host keys and has them reloaded.

  Nothing is done if the input contains no host keys. Unless C{dry_run} is
  set, the C{reload-ssh-keys} command of L{pathutils.DAEMON_UTIL} is run
  once the key files have been written.

  @type data: dict
  @param data: Input data
  @type dry_run: boolean
  @param dry_run: Whether to perform a dry run
  @param _runcmd_fn: Function used to run the reload command
  @param _keyfiles: Mapping from key type to (private, public) file names;
    defaults to L{constants.SSH_DAEMON_KEYFILES}
  @raise JoinError: If the reload command fails

  """
  host_keys = data.get(constants.SSHS_SSH_HOST_KEY)
  if not host_keys:
    return

  keyfiles = (constants.SSH_DAEMON_KEYFILES if _keyfiles is None
              else _keyfiles)

  logging.info("Updating SSH daemon key files")
  _UpdateKeyFiles(host_keys, dry_run, keyfiles)

  if dry_run:
    logging.info("This is a dry run, not restarting SSH daemon")
    return

  reload_cmd = [pathutils.DAEMON_UTIL, "reload-ssh-keys"]
  result = _runcmd_fn(reload_cmd, interactive=True)
  if result.failed:
    raise JoinError("Could not reload SSH keys, command '%s'"
                    " had exitcode %s and error %s" %
                    (result.cmd, result.exit_code, result.output))
def UpdateSshRoot(data, dry_run, _homedir_fn=None):
  """Writes the SSH login user's key files and authorizes the public keys.

  Nothing is done if the input contains no root keys. Unless C{dry_run} is
  set, every public key is also added to the user's C{authorized_keys} file.

  @type data: dict
  @param data: Input data
  @type dry_run: boolean
  @param dry_run: Whether to perform a dry run
  @param _homedir_fn: Passed through to L{ssh.GetAllUserFiles}

  """
  root_keys = data.get(constants.SSHS_SSH_ROOT_KEY)
  if not root_keys:
    return

  # mkdir=True presumably creates the user's SSH directory when missing --
  # see ssh.GetAllUserFiles
  (auth_keys_file, keyfiles) = ssh.GetAllUserFiles(constants.SSH_LOGIN_USER,
                                                   mkdir=True,
                                                   _homedir_fn=_homedir_fn)
  _UpdateKeyFiles(root_keys, dry_run, keyfiles)

  if dry_run:
    logging.info("This is a dry run, not modifying %s", auth_keys_file)
    return

  for (_, _, pub_key) in root_keys:
    utils.AddAuthorizedKey(auth_keys_file, pub_key)


def LoadData(raw):
  """Decodes the input and checks it against L{_DATA_CHECK}.

  @type raw: string
  @param raw: Serialized input data
  @rtype: dict
  @return: Decoded and verified input data

  """
  verified = serializer.LoadAndVerifyJson(raw, _DATA_CHECK)
  return verified


def Main():
  """Reads the join data from standard input and prepares this node.

  @return: Exit code for the process

  """
  opts = ParseOptions()
  utils.SetupToolLogging(opts.debug, opts.verbose)

  try:
    join_data = LoadData(sys.stdin.read())

    # Validate the input before any file is written
    for check_fn in (VerifyClusterName, VerifyCertificate):
      check_fn(join_data)

    # Install the SSH keys
    for update_fn in (UpdateSshDaemon, UpdateSshRoot):
      update_fn(join_data, opts.dry_run)

    logging.info("Setup finished successfully")
  except Exception as err: # pylint: disable=W0703
    logging.debug("Caught unhandled exception", exc_info=True)
    (exit_code, message) = cli.FormatError(err)
    logging.error(message)
    return exit_code

  return constants.EXIT_SUCCESS