# Copyright (C) 2012 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
"""Script to prepare a node for joining a cluster.

"""
33 from ganeti import cli
34 from ganeti import constants
35 from ganeti import errors
36 from ganeti import pathutils
37 from ganeti import utils
38 from ganeti import serializer
40 from ganeti import ssh
41 from ganeti import ssconf
#: Validator for a single SSH key entry: a 3-item sequence of key type,
#: private key and public key. The C{ht.Comment} labels only feed the
#: human-readable description of the check (used in error messages). They
#: must match how L{_UpdateKeyFiles} and L{UpdateSshRoot} unpack an entry:
#: item 1 is written to the private key file, item 2 to the public key file
#: and to C{authorized_keys}.
_SSH_KEY_LIST_ITEM = \
  ht.TAnd(ht.TIsLength(3),
          ht.TItems([
            ht.TElemOf(constants.SSHK_ALL),
            ht.Comment("private")(ht.TNonEmptyString),
            ht.Comment("public")(ht.TNonEmptyString),
            ]))

#: Validator for a list of SSH key entries
_SSH_KEY_LIST = ht.TListOf(_SSH_KEY_LIST_ITEM)

#: Validator for the complete input document read from stdin.
# NOTE(review): the two booleans are presumably (require_all=False,
# exclusive=True), i.e. every key is optional and unknown keys are rejected;
# confirm against ganeti.ht.TStrictDict.
_DATA_CHECK = ht.TStrictDict(False, True, {
  constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
  constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
  constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
  constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
  })
class JoinError(errors.GenericError):
  """Error raised when this node cannot be prepared for joining a cluster.

  Used for problems detected by this script itself, e.g. a cluster name or
  certificate that does not match the local node, or a failure to reload
  the SSH keys.

  """
def ParseOptions():
  """Parses the command line of this program.

  Only the debug, verbose and dry-run options are recognized; positional
  arguments are rejected by L{VerifyOptions}.

  @return: the options object returned by L{VerifyOptions}

  """
  parser = optparse.OptionParser(usage="%prog [--dry-run]",
                                 prog=os.path.basename(sys.argv[0]))
  for option in (cli.DEBUG_OPT, cli.VERBOSE_OPT, cli.DRY_RUN_OPT):
    parser.add_option(option)

  (opts, args) = parser.parse_args()
  return VerifyOptions(parser, opts, args)
def VerifyOptions(parser, opts, args):
  """Checks the parsed command line for consistency.

  This program takes no positional arguments. If any are present, the
  parser's C{error} method is called to report them.

  @param parser: the option parser that produced C{opts} and C{args}
  @param opts: parsed options
  @param args: positional arguments
  @return: C{opts}

  """
  if not args:
    return opts

  parser.error("No arguments are expected")
  return opts
def SetupLogging(opts):
  """Configures the logging module.

  A single stderr handler is attached to the root logger. Its threshold
  depends on the options:
    - everything is emitted when C{opts.debug} is set;
    - otherwise C{INFO} and above when C{opts.verbose} is set;
    - otherwise only C{WARNING} and above.

  @param opts: parsed command line options (C{debug} and C{verbose} are read)

  """
  if opts.debug:
    threshold = logging.NOTSET
  elif opts.verbose:
    threshold = logging.INFO
  else:
    threshold = logging.WARNING

  handler = logging.StreamHandler()
  handler.setFormatter(logging.Formatter("%(asctime)s: %(message)s"))
  handler.setLevel(threshold)

  # Filtering is left to the handler; the root logger passes everything on
  root = logging.getLogger("")
  root.setLevel(logging.NOTSET)
  root.addHandler(handler)
def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE):
  """Checks a PEM certificate against the local node daemon key.

  The data must contain a certificate only; a private key is rejected. If
  the local node daemon certificate file does not exist, no further check
  is made. Otherwise the private key read from that file must match the
  given certificate.

  @type cert: string
  @param cert: Certificate in PEM format (no key)
  @param _noded_cert_file: path of the local node daemon PEM file (for tests)
  @raise JoinError: if C{cert} contains a private key or does not match the
    local key
  @raise errors.X509CertError: if C{cert} or the local key cannot be loaded

  """
  crypto = OpenSSL.crypto

  # The input must not carry a private key: successfully parsing one is an
  # error, while a parse failure is the expected outcome.
  try:
    crypto.load_privatekey(crypto.FILETYPE_PEM, cert)
  except crypto.Error:
    pass
  else:
    raise JoinError("No private key may be given")

  try:
    x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
  except Exception as err:
    raise errors.X509CertError("(stdin)",
                               "Unable to load certificate: %s" % err)

  try:
    noded_pem = utils.ReadFile(_noded_cert_file)
  except EnvironmentError as err:
    if err.errno != errno.ENOENT:
      raise
    # No local certificate means there is nothing to compare against
    logging.debug("Local node certificate was not found (file %s)",
                  _noded_cert_file)
    return

  try:
    local_key = crypto.load_privatekey(crypto.FILETYPE_PEM, noded_pem)
  except Exception as err:
    raise errors.X509CertError(_noded_cert_file,
                               "Unable to load private key: %s" % err)

  # The SSL context is only used to check that key and certificate belong
  # together; no connection is made with it.
  ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
  ctx.use_privatekey(local_key)
  ctx.use_certificate(x509)
  try:
    ctx.check_privatekey()
  except OpenSSL.SSL.Error:
    raise JoinError("Given cluster certificate does not match local key")
def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
  """Checks the node daemon certificate contained in the input data.

  Nothing is done when the input carries no (or an empty) certificate.

  @type data: dict
  @param data: Input data
  @param _verify_fn: function called with the certificate (for tests)

  """
  cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
  if not cert:
    return

  _verify_fn(cert)
def _VerifyClusterName(name, _ss_cluster_name_file=None):
  """Compares a cluster name with the one stored locally, if any.

  @type name: string
  @param name: Cluster name
  @param _ss_cluster_name_file: path of the local cluster name file; defaults
    to the ssconf file for C{constants.SS_CLUSTER_NAME} (for tests)
  @raise JoinError: if a local cluster name exists and differs from C{name}

  """
  path = _ss_cluster_name_file
  if path is None:
    path = ssconf.SimpleStore().KeyToFilename(constants.SS_CLUSTER_NAME)

  try:
    local_name = utils.ReadOneLineFile(path)
  except EnvironmentError as err:
    if err.errno != errno.ENOENT:
      raise
    # No local name stored, so there is nothing to compare against
    logging.debug("Local cluster name was not found (file %s)", path)
    return

  if name != local_name:
    raise JoinError("Current cluster name is '%s'" % local_name)
def VerifyClusterName(data, _verify_fn=_VerifyClusterName):
  """Checks the cluster name contained in the input data.

  @type data: dict
  @param data: Input data
  @param _verify_fn: function called with the cluster name (for tests)
  @raise JoinError: if no (or an empty) cluster name is given

  """
  name = data.get(constants.SSHS_CLUSTER_NAME)
  if not name:
    raise JoinError("Cluster name must be specified")

  _verify_fn(name)
def _UpdateKeyFiles(keys, dry_run, keyfiles):
  """Updates SSH key files.

  @type keys: sequence of tuple; (string, string, string)
  @param keys: Keys to write; tuples consist of key type
    (L{constants.SSHK_ALL}), private key and public key, in this order
  @type dry_run: boolean
  @param dry_run: Whether to perform a dry run
  @type keyfiles: dict; (string as key, tuple with (string, string) as values)
  @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
    names; value tuples consist of private key filename and public key
    filename, in this order

  """
  # Internal invariant: callers must provide files for every key type
  assert set(keyfiles) == constants.SSHK_ALL

  for (kind, private_key, public_key) in keys:
    (private_file, public_file) = keyfiles[kind]

    # Private keys are written readable/writable by the owner only
    logging.debug("Writing %s ...", private_file)
    utils.WriteFile(private_file, data=private_key, mode=0o600,
                    backup=True, dry_run=dry_run)

    logging.debug("Writing %s ...", public_file)
    utils.WriteFile(public_file, data=public_key, mode=0o644,
                    backup=True, dry_run=dry_run)
def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
                    _keyfiles=None):
  """Installs the SSH host keys contained in the input data.

  Does nothing if the input carries no host keys. Otherwise the key files
  are written and, unless C{dry_run} is set, the daemon utility is asked to
  reload the SSH keys.

  @type data: dict
  @param data: Input data
  @type dry_run: boolean
  @param dry_run: Whether to perform a dry run
  @param _runcmd_fn: command runner (for tests)
  @param _keyfiles: key type to (private file, public file) mapping;
    defaults to L{constants.SSH_DAEMON_KEYFILES} (for tests)
  @raise JoinError: if reloading the SSH keys fails

  """
  host_keys = data.get(constants.SSHS_SSH_HOST_KEY)
  if not host_keys:
    return

  if _keyfiles is None:
    keyfiles = constants.SSH_DAEMON_KEYFILES
  else:
    keyfiles = _keyfiles

  logging.info("Updating SSH daemon key files")
  _UpdateKeyFiles(host_keys, dry_run, keyfiles)

  if dry_run:
    logging.info("This is a dry run, not restarting SSH daemon")
    return

  result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
                      interactive=True)
  if result.failed:
    raise JoinError("Could not reload SSH keys, command '%s'"
                    " had exitcode %s and error %s" %
                    (result.cmd, result.exit_code, result.output))
def UpdateSshRoot(data, dry_run, _homedir_fn=None):
  """Installs root's SSH keys contained in the input data.

  Does nothing if the input carries no root keys. Otherwise the DSA and RSA
  key files of L{constants.SSH_LOGIN_USER} are written and, unless
  C{dry_run} is set, every given public key is added to that user's
  C{authorized_keys} file.

  @type data: dict
  @param data: Input data
  @type dry_run: boolean
  @param dry_run: Whether to perform a dry run
  @param _homedir_fn: passed through to L{ssh.GetUserFiles} (for tests)

  """
  root_keys = data.get(constants.SSHS_SSH_ROOT_KEY)
  if not root_keys:
    return

  def _GetFiles(kind):
    # Returns (private key file, public key file, authorized_keys file)
    return ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
                            kind=kind, _homedir_fn=_homedir_fn)

  (dsa_private_file, dsa_public_file, auth_keys_file) = \
    _GetFiles(constants.SSHK_DSA)
  (rsa_private_file, rsa_public_file, _) = _GetFiles(constants.SSHK_RSA)

  keyfiles = {
    constants.SSHK_DSA: (dsa_private_file, dsa_public_file),
    constants.SSHK_RSA: (rsa_private_file, rsa_public_file),
    }
  _UpdateKeyFiles(root_keys, dry_run, keyfiles)

  if dry_run:
    logging.info("This is a dry run, not modifying %s", auth_keys_file)
    return

  for (_, _, public_key) in root_keys:
    utils.AddAuthorizedKey(auth_keys_file, public_key)
def LoadData(raw):
  """Deserializes the JSON input and validates its structure.

  @type raw: string
  @param raw: serialized input data
  @rtype: dict
  @return: the parsed input data
  @raise errors.ParseError: if C{raw} cannot be parsed or does not match
    L{_DATA_CHECK}

  """
  try:
    parsed = serializer.LoadJson(raw)
  except Exception as err:
    raise errors.ParseError("Can't parse input data: %s" % err)

  if _DATA_CHECK(parsed):
    return parsed

  raise errors.ParseError("Input data does not match expected format: %s" %
                          _DATA_CHECK)
def Main():
  """Main routine.

  Parses the command line, sets up logging, reads the input data from
  standard input, verifies it and installs the SSH keys it contains.

  @return: C{constants.EXIT_SUCCESS} on success, otherwise the exit code
    returned by C{cli.FormatError} for the error that occurred

  """
  opts = ParseOptions()
  SetupLogging(opts)

  try:
    data = LoadData(sys.stdin.read())

    # Refuse to continue if the input does not match this node's existing
    # cluster name or node daemon key
    for verify_fn in (VerifyClusterName, VerifyCertificate):
      verify_fn(data)

    # Install the SSH daemon's host keys first, then root's keys
    for update_fn in (UpdateSshDaemon, UpdateSshRoot):
      update_fn(data, opts.dry_run)

    logging.info("Setup finished successfully")
  except Exception as err: # pylint: disable=W0703
    # Full traceback only at debug level; a formatted summary as error
    logging.debug("Caught unhandled exception", exc_info=True)
    (retcode, message) = cli.FormatError(err)
    logging.error(message)
    return retcode

  return constants.EXIT_SUCCESS