Statistics
| Branch: | Tag: | Revision:

root / lib / tools / prepare_node_join.py @ d12b9f66

History | View | Annotate | Download (9.8 kB)

1
#
2
#
3

    
4
# Copyright (C) 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21
"""Script to prepare a node for joining a cluster.
22

23
"""
24

    
25
import os
26
import os.path
27
import optparse
28
import sys
29
import logging
30
import errno
31
import OpenSSL
32

    
33
from ganeti import cli
34
from ganeti import constants
35
from ganeti import errors
36
from ganeti import pathutils
37
from ganeti import utils
38
from ganeti import serializer
39
from ganeti import ht
40
from ganeti import ssh
41
from ganeti import ssconf
42

    
43

    
44
_SSH_KEY_LIST = \
45
  ht.TListOf(ht.TAnd(ht.TIsLength(3),
46
                     ht.TItems([
47
                       ht.TElemOf(constants.SSHK_ALL),
48
                       ht.Comment("public")(ht.TNonEmptyString),
49
                       ht.Comment("private")(ht.TNonEmptyString),
50
                       ])))
51

    
52
_DATA_CHECK = ht.TStrictDict(False, True, {
53
  constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
54
  constants.SSHS_FORCE: ht.TBool,
55
  constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
56
  constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
57
  constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
58
  })
59

    
60
_SSHK_TO_SSHAK = {
61
  constants.SSHK_RSA: constants.SSHAK_RSA,
62
  constants.SSHK_DSA: constants.SSHAK_DSS,
63
  }
64

    
65
_SSH_DAEMON_KEYFILES = {
66
  constants.SSHK_RSA:
67
    (pathutils.SSH_HOST_RSA_PUB, pathutils.SSH_HOST_RSA_PRIV),
68
  constants.SSHK_DSA:
69
    (pathutils.SSH_HOST_DSA_PUB, pathutils.SSH_HOST_DSA_PRIV),
70
    }
71

    
72
assert frozenset(_SSHK_TO_SSHAK.keys()) == constants.SSHK_ALL
73
assert frozenset(_SSHK_TO_SSHAK.values()) == constants.SSHAK_ALL
74

    
75

    
76
class JoinError(errors.GenericError):
77
  """Local class for reporting errors.
78

79
  """
80

    
81

    
82
def ParseOptions():
83
  """Parses the options passed to the program.
84

85
  @return: Options and arguments
86

87
  """
88
  program = os.path.basename(sys.argv[0])
89

    
90
  parser = optparse.OptionParser(usage="%prog [--dry-run]",
91
                                 prog=program)
92
  parser.add_option(cli.DEBUG_OPT)
93
  parser.add_option(cli.VERBOSE_OPT)
94
  parser.add_option(cli.DRY_RUN_OPT)
95

    
96
  (opts, args) = parser.parse_args()
97

    
98
  return VerifyOptions(parser, opts, args)
99

    
100

    
101
def VerifyOptions(parser, opts, args):
102
  """Verifies options and arguments for correctness.
103

104
  """
105
  if args:
106
    parser.error("No arguments are expected")
107

    
108
  return opts
109

    
110

    
111
def SetupLogging(opts):
112
  """Configures the logging module.
113

114
  """
115
  formatter = logging.Formatter("%(asctime)s: %(message)s")
116

    
117
  stderr_handler = logging.StreamHandler()
118
  stderr_handler.setFormatter(formatter)
119
  if opts.debug:
120
    stderr_handler.setLevel(logging.NOTSET)
121
  elif opts.verbose:
122
    stderr_handler.setLevel(logging.INFO)
123
  else:
124
    stderr_handler.setLevel(logging.WARNING)
125

    
126
  root_logger = logging.getLogger("")
127
  root_logger.setLevel(logging.NOTSET)
128
  root_logger.addHandler(stderr_handler)
129

    
130

    
131
def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE):
132
  """Verifies a certificate against the local node daemon certificate.
133

134
  @type cert: string
135
  @param cert: Certificate in PEM format (no key)
136

137
  """
138
  try:
139
    OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert)
140
  except OpenSSL.crypto.Error, err:
141
    pass
142
  else:
143
    raise JoinError("No private key may be given")
144

    
145
  try:
146
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
147
  except Exception, err:
148
    raise errors.X509CertError("(stdin)",
149
                               "Unable to load certificate: %s" % err)
150

    
151
  try:
152
    noded_pem = utils.ReadFile(_noded_cert_file)
153
  except EnvironmentError, err:
154
    if err.errno != errno.ENOENT:
155
      raise
156

    
157
    logging.debug("Local node certificate was not found (file %s)",
158
                  _noded_cert_file)
159
    return
160

    
161
  try:
162
    key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, noded_pem)
163
  except Exception, err:
164
    raise errors.X509CertError(_noded_cert_file,
165
                               "Unable to load private key: %s" % err)
166

    
167
  ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
168
  ctx.use_privatekey(key)
169
  ctx.use_certificate(cert)
170
  try:
171
    ctx.check_privatekey()
172
  except OpenSSL.SSL.Error:
173
    raise JoinError("Given cluster certificate does not match local key")
174

    
175

    
176
def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
177
  """Verifies cluster certificate.
178

179
  @type data: dict
180

181
  """
182
  cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
183
  if cert:
184
    _verify_fn(cert)
185

    
186

    
187
def _VerifyClusterName(name, _ss_cluster_name_file=None):
188
  """Verifies cluster name against a local cluster name.
189

190
  @type name: string
191
  @param name: Cluster name
192

193
  """
194
  if _ss_cluster_name_file is None:
195
    _ss_cluster_name_file = \
196
      ssconf.SimpleStore().KeyToFilename(constants.SS_CLUSTER_NAME)
197

    
198
  try:
199
    local_name = utils.ReadOneLineFile(_ss_cluster_name_file)
200
  except EnvironmentError, err:
201
    if err.errno != errno.ENOENT:
202
      raise
203

    
204
    logging.debug("Local cluster name was not found (file %s)",
205
                  _ss_cluster_name_file)
206
  else:
207
    if name != local_name:
208
      raise JoinError("Current cluster name is '%s'" % local_name)
209

    
210

    
211
def VerifyClusterName(data, _verify_fn=_VerifyClusterName):
212
  """Verifies cluster name.
213

214
  @type data: dict
215

216
  """
217
  name = data.get(constants.SSHS_CLUSTER_NAME)
218
  if name:
219
    _verify_fn(name)
220
  else:
221
    raise JoinError("Cluster name must be specified")
222

    
223

    
224
def _UpdateKeyFiles(keys, dry_run, keyfiles):
225
  """Updates SSH key files.
226

227
  @type keys: sequence of tuple; (string, string, string)
228
  @param keys: Keys to write, tuples consist of key type
229
    (L{constants.SSHK_ALL}), public and private key
230
  @type dry_run: boolean
231
  @param dry_run: Whether to perform a dry run
232
  @type keyfiles: dict; (string as key, tuple with (string, string) as values)
233
  @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
234
    names; value tuples consist of public key filename and private key filename
235

236
  """
237
  assert set(keyfiles) == constants.SSHK_ALL
238

    
239
  for (kind, public_key, private_key) in keys:
240
    (public_file, private_file) = keyfiles[kind]
241

    
242
    logging.debug("Writing %s ...", public_file)
243
    utils.WriteFile(public_file, data=public_key, mode=0644,
244
                    backup=True, dry_run=dry_run)
245

    
246
    logging.debug("Writing %s ...", private_file)
247
    utils.WriteFile(private_file, data=private_key, mode=0600,
248
                    backup=True, dry_run=dry_run)
249

    
250

    
251
def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
252
                    _keyfiles=None):
253
  """Updates SSH daemon's keys.
254

255
  Unless C{dry_run} is set, the daemon is restarted at the end.
256

257
  @type data: dict
258
  @param data: Input data
259
  @type dry_run: boolean
260
  @param dry_run: Whether to perform a dry run
261

262
  """
263
  keys = data.get(constants.SSHS_SSH_HOST_KEY)
264
  if not keys:
265
    return
266

    
267
  if _keyfiles is None:
268
    _keyfiles = _SSH_DAEMON_KEYFILES
269

    
270
  logging.info("Updating SSH daemon key files")
271
  _UpdateKeyFiles(keys, dry_run, _keyfiles)
272

    
273
  if dry_run:
274
    logging.info("This is a dry run, not restarting SSH daemon")
275
  else:
276
    result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
277
                        interactive=True)
278
    if result.failed:
279
      raise JoinError("Could not reload SSH keys, command '%s'"
280
                      " had exitcode %s and error %s" %
281
                       (result.cmd, result.exit_code, result.output))
282

    
283

    
284
def UpdateSshRoot(data, dry_run, _homedir_fn=None):
285
  """Updates root's SSH keys.
286

287
  Root's C{authorized_keys} file is also updated with new public keys.
288

289
  @type data: dict
290
  @param data: Input data
291
  @type dry_run: boolean
292
  @param dry_run: Whether to perform a dry run
293

294
  """
295
  keys = data.get(constants.SSHS_SSH_ROOT_KEY)
296
  if not keys:
297
    return
298

    
299
  (dsa_private_file, dsa_public_file, auth_keys_file) = \
300
    ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
301
                     kind=constants.SSHK_DSA, _homedir_fn=_homedir_fn)
302
  (rsa_private_file, rsa_public_file, _) = \
303
    ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
304
                     kind=constants.SSHK_RSA, _homedir_fn=_homedir_fn)
305

    
306
  _UpdateKeyFiles(keys, dry_run, {
307
    constants.SSHK_RSA: (rsa_public_file, rsa_private_file),
308
    constants.SSHK_DSA: (dsa_public_file, dsa_private_file),
309
    })
310

    
311
  if dry_run:
312
    logging.info("This is a dry run, not modifying %s", auth_keys_file)
313
  else:
314
    for (kind, public_key, _) in keys:
315
      line = "%s %s" % (_SSHK_TO_SSHAK[kind], public_key)
316
      utils.AddAuthorizedKey(auth_keys_file, line)
317

    
318

    
319
def LoadData(raw):
320
  """Parses and verifies input data.
321

322
  @rtype: dict
323

324
  """
325
  try:
326
    data = serializer.LoadJson(raw)
327
  except Exception, err:
328
    raise errors.ParseError("Can't parse input data: %s" % err)
329

    
330
  if not _DATA_CHECK(data):
331
    raise errors.ParseError("Input data does not match expected format: %s" %
332
                            _DATA_CHECK)
333

    
334
  return data
335

    
336

    
337
def Main():
338
  """Main routine.
339

340
  """
341
  opts = ParseOptions()
342

    
343
  SetupLogging(opts)
344

    
345
  try:
346
    data = LoadData(sys.stdin.read())
347

    
348
    # Check if input data is correct
349
    VerifyClusterName(data)
350
    VerifyCertificate(data)
351

    
352
    # Update SSH files
353
    UpdateSshDaemon(data, opts.dry_run)
354
    UpdateSshRoot(data, opts.dry_run)
355

    
356
    logging.info("Setup finished successfully")
357
  except Exception, err: # pylint: disable=W0703
358
    logging.debug("Caught unhandled exception", exc_info=True)
359

    
360
    (retcode, message) = cli.FormatError(err)
361
    logging.error(message)
362

    
363
    return retcode
364
  else:
365
    return constants.EXIT_SUCCESS