Statistics
| Branch: | Tag: | Revision:

root / lib / tools / prepare_node_join.py @ dffa96d6

History | View | Annotate | Download (7.6 kB)

1
#
2
#
3

    
4
# Copyright (C) 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21
"""Script to prepare a node for joining a cluster.
22

23
"""
24

    
25
import os
26
import os.path
27
import optparse
28
import sys
29
import logging
30
import errno
31
import OpenSSL
32

    
33
from ganeti import cli
34
from ganeti import constants
35
from ganeti import errors
36
from ganeti import pathutils
37
from ganeti import utils
38
from ganeti import serializer
39
from ganeti import ht
40
from ganeti import ssh
41
from ganeti import ssconf
42

    
43

    
44
_SSH_KEY_LIST_ITEM = \
45
  ht.TAnd(ht.TIsLength(3),
46
          ht.TItems([
47
            ht.TElemOf(constants.SSHK_ALL),
48
            ht.Comment("public")(ht.TNonEmptyString),
49
            ht.Comment("private")(ht.TNonEmptyString),
50
          ]))
51

    
52
_SSH_KEY_LIST = ht.TListOf(_SSH_KEY_LIST_ITEM)
53

    
54
_DATA_CHECK = ht.TStrictDict(False, True, {
55
  constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
56
  constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
57
  constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
58
  constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
59
  })
60

    
61

    
62
class JoinError(errors.GenericError):
63
  """Local class for reporting errors.
64

65
  """
66

    
67

    
68
def ParseOptions():
69
  """Parses the options passed to the program.
70

71
  @return: Options and arguments
72

73
  """
74
  program = os.path.basename(sys.argv[0])
75

    
76
  parser = optparse.OptionParser(usage="%prog [--dry-run]",
77
                                 prog=program)
78
  parser.add_option(cli.DEBUG_OPT)
79
  parser.add_option(cli.VERBOSE_OPT)
80
  parser.add_option(cli.DRY_RUN_OPT)
81

    
82
  (opts, args) = parser.parse_args()
83

    
84
  return VerifyOptions(parser, opts, args)
85

    
86

    
87
def VerifyOptions(parser, opts, args):
88
  """Verifies options and arguments for correctness.
89

90
  """
91
  if args:
92
    parser.error("No arguments are expected")
93

    
94
  return opts
95

    
96

    
97
def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE):
98
  """Verifies a certificate against the local node daemon certificate.
99

100
  @type cert: string
101
  @param cert: Certificate in PEM format (no key)
102

103
  """
104
  try:
105
    OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert)
106
  except OpenSSL.crypto.Error, err:
107
    pass
108
  else:
109
    raise JoinError("No private key may be given")
110

    
111
  try:
112
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
113
  except Exception, err:
114
    raise errors.X509CertError("(stdin)",
115
                               "Unable to load certificate: %s" % err)
116

    
117
  try:
118
    noded_pem = utils.ReadFile(_noded_cert_file)
119
  except EnvironmentError, err:
120
    if err.errno != errno.ENOENT:
121
      raise
122

    
123
    logging.debug("Local node certificate was not found (file %s)",
124
                  _noded_cert_file)
125
    return
126

    
127
  try:
128
    key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, noded_pem)
129
  except Exception, err:
130
    raise errors.X509CertError(_noded_cert_file,
131
                               "Unable to load private key: %s" % err)
132

    
133
  ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
134
  ctx.use_privatekey(key)
135
  ctx.use_certificate(cert)
136
  try:
137
    ctx.check_privatekey()
138
  except OpenSSL.SSL.Error:
139
    raise JoinError("Given cluster certificate does not match local key")
140

    
141

    
142
def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
143
  """Verifies cluster certificate.
144

145
  @type data: dict
146

147
  """
148
  cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
149
  if cert:
150
    _verify_fn(cert)
151

    
152

    
153
def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName):
154
  """Verifies cluster name.
155

156
  @type data: dict
157

158
  """
159
  name = data.get(constants.SSHS_CLUSTER_NAME)
160
  if name:
161
    _verify_fn(name)
162
  else:
163
    raise JoinError("Cluster name must be specified")
164

    
165

    
166
def _UpdateKeyFiles(keys, dry_run, keyfiles):
167
  """Updates SSH key files.
168

169
  @type keys: sequence of tuple; (string, string, string)
170
  @param keys: Keys to write, tuples consist of key type
171
    (L{constants.SSHK_ALL}), public and private key
172
  @type dry_run: boolean
173
  @param dry_run: Whether to perform a dry run
174
  @type keyfiles: dict; (string as key, tuple with (string, string) as values)
175
  @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
176
    names; value tuples consist of public key filename and private key filename
177

178
  """
179
  assert set(keyfiles) == constants.SSHK_ALL
180

    
181
  for (kind, private_key, public_key) in keys:
182
    (private_file, public_file) = keyfiles[kind]
183

    
184
    logging.debug("Writing %s ...", private_file)
185
    utils.WriteFile(private_file, data=private_key, mode=0600,
186
                    backup=True, dry_run=dry_run)
187

    
188
    logging.debug("Writing %s ...", public_file)
189
    utils.WriteFile(public_file, data=public_key, mode=0644,
190
                    backup=True, dry_run=dry_run)
191

    
192

    
193
def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
194
                    _keyfiles=None):
195
  """Updates SSH daemon's keys.
196

197
  Unless C{dry_run} is set, the daemon is restarted at the end.
198

199
  @type data: dict
200
  @param data: Input data
201
  @type dry_run: boolean
202
  @param dry_run: Whether to perform a dry run
203

204
  """
205
  keys = data.get(constants.SSHS_SSH_HOST_KEY)
206
  if not keys:
207
    return
208

    
209
  if _keyfiles is None:
210
    _keyfiles = constants.SSH_DAEMON_KEYFILES
211

    
212
  logging.info("Updating SSH daemon key files")
213
  _UpdateKeyFiles(keys, dry_run, _keyfiles)
214

    
215
  if dry_run:
216
    logging.info("This is a dry run, not restarting SSH daemon")
217
  else:
218
    result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
219
                        interactive=True)
220
    if result.failed:
221
      raise JoinError("Could not reload SSH keys, command '%s'"
222
                      " had exitcode %s and error %s" %
223
                       (result.cmd, result.exit_code, result.output))
224

    
225

    
226
def UpdateSshRoot(data, dry_run, _homedir_fn=None):
227
  """Updates root's SSH keys.
228

229
  Root's C{authorized_keys} file is also updated with new public keys.
230

231
  @type data: dict
232
  @param data: Input data
233
  @type dry_run: boolean
234
  @param dry_run: Whether to perform a dry run
235

236
  """
237
  keys = data.get(constants.SSHS_SSH_ROOT_KEY)
238
  if not keys:
239
    return
240

    
241
  (auth_keys_file, keyfiles) = \
242
    ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
243
                        _homedir_fn=_homedir_fn)
244

    
245
  _UpdateKeyFiles(keys, dry_run, keyfiles)
246

    
247
  if dry_run:
248
    logging.info("This is a dry run, not modifying %s", auth_keys_file)
249
  else:
250
    for (_, _, public_key) in keys:
251
      utils.AddAuthorizedKey(auth_keys_file, public_key)
252

    
253

    
254
def LoadData(raw):
255
  """Parses and verifies input data.
256

257
  @rtype: dict
258

259
  """
260
  return serializer.LoadAndVerifyJson(raw, _DATA_CHECK)
261

    
262

    
263
def Main():
264
  """Main routine.
265

266
  """
267
  opts = ParseOptions()
268

    
269
  utils.SetupToolLogging(opts.debug, opts.verbose)
270

    
271
  try:
272
    data = LoadData(sys.stdin.read())
273

    
274
    # Check if input data is correct
275
    VerifyClusterName(data)
276
    VerifyCertificate(data)
277

    
278
    # Update SSH files
279
    UpdateSshDaemon(data, opts.dry_run)
280
    UpdateSshRoot(data, opts.dry_run)
281

    
282
    logging.info("Setup finished successfully")
283
  except Exception, err: # pylint: disable=W0703
284
    logging.debug("Caught unhandled exception", exc_info=True)
285

    
286
    (retcode, message) = cli.FormatError(err)
287
    logging.error(message)
288

    
289
    return retcode
290
  else:
291
    return constants.EXIT_SUCCESS