Fix permission errors for split users
[ganeti-local] / lib / tools / prepare_node_join.py
1 #
2 #
3
4 # Copyright (C) 2012 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21 """Script to prepare a node for joining a cluster.
22
23 """
24
25 import os
26 import os.path
27 import optparse
28 import sys
29 import logging
30 import OpenSSL
31
32 from ganeti import cli
33 from ganeti import constants
34 from ganeti import errors
35 from ganeti import pathutils
36 from ganeti import utils
37 from ganeti import serializer
38 from ganeti import ht
39 from ganeti import ssh
40 from ganeti import ssconf
41
42
43 _SSH_KEY_LIST_ITEM = \
44   ht.TAnd(ht.TIsLength(3),
45           ht.TItems([
46             ht.TElemOf(constants.SSHK_ALL),
47             ht.Comment("public")(ht.TNonEmptyString),
48             ht.Comment("private")(ht.TNonEmptyString),
49           ]))
50
51 _SSH_KEY_LIST = ht.TListOf(_SSH_KEY_LIST_ITEM)
52
53 _DATA_CHECK = ht.TStrictDict(False, True, {
54   constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
55   constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
56   constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
57   constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
58   })
59
60
61 class JoinError(errors.GenericError):
62   """Local class for reporting errors.
63
64   """
65
66
67 def ParseOptions():
68   """Parses the options passed to the program.
69
70   @return: Options and arguments
71
72   """
73   program = os.path.basename(sys.argv[0])
74
75   parser = optparse.OptionParser(usage="%prog [--dry-run]",
76                                  prog=program)
77   parser.add_option(cli.DEBUG_OPT)
78   parser.add_option(cli.VERBOSE_OPT)
79   parser.add_option(cli.DRY_RUN_OPT)
80
81   (opts, args) = parser.parse_args()
82
83   return VerifyOptions(parser, opts, args)
84
85
86 def VerifyOptions(parser, opts, args):
87   """Verifies options and arguments for correctness.
88
89   """
90   if args:
91     parser.error("No arguments are expected")
92
93   return opts
94
95
96 def _VerifyCertificate(cert_pem, _check_fn=utils.CheckNodeCertificate):
97   """Verifies a certificate against the local node daemon certificate.
98
99   @type cert_pem: string
100   @param cert_pem: Certificate in PEM format (no key)
101
102   """
103   try:
104     OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
105   except OpenSSL.crypto.Error, err:
106     pass
107   else:
108     raise JoinError("No private key may be given")
109
110   try:
111     cert = \
112       OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
113   except Exception, err:
114     raise errors.X509CertError("(stdin)",
115                                "Unable to load certificate: %s" % err)
116
117   _check_fn(cert)
118
119
120 def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
121   """Verifies cluster certificate.
122
123   @type data: dict
124
125   """
126   cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
127   if cert:
128     _verify_fn(cert)
129
130
131 def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName):
132   """Verifies cluster name.
133
134   @type data: dict
135
136   """
137   name = data.get(constants.SSHS_CLUSTER_NAME)
138   if name:
139     _verify_fn(name)
140   else:
141     raise JoinError("Cluster name must be specified")
142
143
144 def _UpdateKeyFiles(keys, dry_run, keyfiles):
145   """Updates SSH key files.
146
147   @type keys: sequence of tuple; (string, string, string)
148   @param keys: Keys to write, tuples consist of key type
149     (L{constants.SSHK_ALL}), public and private key
150   @type dry_run: boolean
151   @param dry_run: Whether to perform a dry run
152   @type keyfiles: dict; (string as key, tuple with (string, string) as values)
153   @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
154     names; value tuples consist of public key filename and private key filename
155
156   """
157   assert set(keyfiles) == constants.SSHK_ALL
158
159   for (kind, private_key, public_key) in keys:
160     (private_file, public_file) = keyfiles[kind]
161
162     logging.debug("Writing %s ...", private_file)
163     utils.WriteFile(private_file, data=private_key, mode=0600,
164                     backup=True, dry_run=dry_run)
165
166     logging.debug("Writing %s ...", public_file)
167     utils.WriteFile(public_file, data=public_key, mode=0644,
168                     backup=True, dry_run=dry_run)
169
170
171 def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
172                     _keyfiles=None):
173   """Updates SSH daemon's keys.
174
175   Unless C{dry_run} is set, the daemon is restarted at the end.
176
177   @type data: dict
178   @param data: Input data
179   @type dry_run: boolean
180   @param dry_run: Whether to perform a dry run
181
182   """
183   keys = data.get(constants.SSHS_SSH_HOST_KEY)
184   if not keys:
185     return
186
187   if _keyfiles is None:
188     _keyfiles = constants.SSH_DAEMON_KEYFILES
189
190   logging.info("Updating SSH daemon key files")
191   _UpdateKeyFiles(keys, dry_run, _keyfiles)
192
193   if dry_run:
194     logging.info("This is a dry run, not restarting SSH daemon")
195   else:
196     result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
197                         interactive=True)
198     if result.failed:
199       raise JoinError("Could not reload SSH keys, command '%s'"
200                       " had exitcode %s and error %s" %
201                        (result.cmd, result.exit_code, result.output))
202
203
204 def UpdateSshRoot(data, dry_run, _homedir_fn=None):
205   """Updates root's SSH keys.
206
207   Root's C{authorized_keys} file is also updated with new public keys.
208
209   @type data: dict
210   @param data: Input data
211   @type dry_run: boolean
212   @param dry_run: Whether to perform a dry run
213
214   """
215   keys = data.get(constants.SSHS_SSH_ROOT_KEY)
216   if not keys:
217     return
218
219   (auth_keys_file, keyfiles) = \
220     ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
221                         _homedir_fn=_homedir_fn)
222
223   _UpdateKeyFiles(keys, dry_run, keyfiles)
224
225   if dry_run:
226     logging.info("This is a dry run, not modifying %s", auth_keys_file)
227   else:
228     for (_, _, public_key) in keys:
229       utils.AddAuthorizedKey(auth_keys_file, public_key)
230
231
232 def LoadData(raw):
233   """Parses and verifies input data.
234
235   @rtype: dict
236
237   """
238   return serializer.LoadAndVerifyJson(raw, _DATA_CHECK)
239
240
241 def Main():
242   """Main routine.
243
244   """
245   opts = ParseOptions()
246
247   utils.SetupToolLogging(opts.debug, opts.verbose)
248
249   try:
250     data = LoadData(sys.stdin.read())
251
252     # Check if input data is correct
253     VerifyClusterName(data)
254     VerifyCertificate(data)
255
256     # Update SSH files
257     UpdateSshDaemon(data, opts.dry_run)
258     UpdateSshRoot(data, opts.dry_run)
259
260     logging.info("Setup finished successfully")
261   except Exception, err: # pylint: disable=W0703
262     logging.debug("Caught unhandled exception", exc_info=True)
263
264     (retcode, message) = cli.FormatError(err)
265     logging.error(message)
266
267     return retcode
268   else:
269     return constants.EXIT_SUCCESS