Statistics
| Branch: | Tag: | Revision:

root / lib / tools / prepare_node_join.py @ c87440f5

History | View | Annotate | Download (9.7 kB)

1
#
2
#
3

    
4
# Copyright (C) 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21
"""Script to prepare a node for joining a cluster.
22

23
"""
24

    
25
import os
26
import os.path
27
import optparse
28
import sys
29
import logging
30
import errno
31
import OpenSSL
32

    
33
from ganeti import cli
34
from ganeti import constants
35
from ganeti import errors
36
from ganeti import pathutils
37
from ganeti import utils
38
from ganeti import serializer
39
from ganeti import ht
40
from ganeti import ssh
41
from ganeti import ssconf
42

    
43

    
44
_SSH_KEY_LIST_ITEM = \
45
  ht.TAnd(ht.TIsLength(3),
46
          ht.TItems([
47
            ht.TElemOf(constants.SSHK_ALL),
48
            ht.Comment("public")(ht.TNonEmptyString),
49
            ht.Comment("private")(ht.TNonEmptyString),
50
          ]))
51

    
52
_SSH_KEY_LIST = ht.TListOf(_SSH_KEY_LIST_ITEM)
53

    
54
_DATA_CHECK = ht.TStrictDict(False, True, {
55
  constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
56
  constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
57
  constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
58
  constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
59
  })
60

    
61
_SSHK_TO_SSHAK = {
62
  constants.SSHK_RSA: constants.SSHAK_RSA,
63
  constants.SSHK_DSA: constants.SSHAK_DSS,
64
  }
65

    
66
_SSH_DAEMON_KEYFILES = {
67
  constants.SSHK_RSA:
68
    (pathutils.SSH_HOST_RSA_PUB, pathutils.SSH_HOST_RSA_PRIV),
69
  constants.SSHK_DSA:
70
    (pathutils.SSH_HOST_DSA_PUB, pathutils.SSH_HOST_DSA_PRIV),
71
  }
72

    
73
assert frozenset(_SSHK_TO_SSHAK.keys()) == constants.SSHK_ALL
74
assert frozenset(_SSHK_TO_SSHAK.values()) == constants.SSHAK_ALL
75

    
76

    
77
class JoinError(errors.GenericError):
78
  """Local class for reporting errors.
79

80
  """
81

    
82

    
83
def ParseOptions():
84
  """Parses the options passed to the program.
85

86
  @return: Options and arguments
87

88
  """
89
  program = os.path.basename(sys.argv[0])
90

    
91
  parser = optparse.OptionParser(usage="%prog [--dry-run]",
92
                                 prog=program)
93
  parser.add_option(cli.DEBUG_OPT)
94
  parser.add_option(cli.VERBOSE_OPT)
95
  parser.add_option(cli.DRY_RUN_OPT)
96

    
97
  (opts, args) = parser.parse_args()
98

    
99
  return VerifyOptions(parser, opts, args)
100

    
101

    
102
def VerifyOptions(parser, opts, args):
103
  """Verifies options and arguments for correctness.
104

105
  """
106
  if args:
107
    parser.error("No arguments are expected")
108

    
109
  return opts
110

    
111

    
112
def SetupLogging(opts):
113
  """Configures the logging module.
114

115
  """
116
  formatter = logging.Formatter("%(asctime)s: %(message)s")
117

    
118
  stderr_handler = logging.StreamHandler()
119
  stderr_handler.setFormatter(formatter)
120
  if opts.debug:
121
    stderr_handler.setLevel(logging.NOTSET)
122
  elif opts.verbose:
123
    stderr_handler.setLevel(logging.INFO)
124
  else:
125
    stderr_handler.setLevel(logging.WARNING)
126

    
127
  root_logger = logging.getLogger("")
128
  root_logger.setLevel(logging.NOTSET)
129
  root_logger.addHandler(stderr_handler)
130

    
131

    
132
def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE):
133
  """Verifies a certificate against the local node daemon certificate.
134

135
  @type cert: string
136
  @param cert: Certificate in PEM format (no key)
137

138
  """
139
  try:
140
    OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert)
141
  except OpenSSL.crypto.Error, err:
142
    pass
143
  else:
144
    raise JoinError("No private key may be given")
145

    
146
  try:
147
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
148
  except Exception, err:
149
    raise errors.X509CertError("(stdin)",
150
                               "Unable to load certificate: %s" % err)
151

    
152
  try:
153
    noded_pem = utils.ReadFile(_noded_cert_file)
154
  except EnvironmentError, err:
155
    if err.errno != errno.ENOENT:
156
      raise
157

    
158
    logging.debug("Local node certificate was not found (file %s)",
159
                  _noded_cert_file)
160
    return
161

    
162
  try:
163
    key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, noded_pem)
164
  except Exception, err:
165
    raise errors.X509CertError(_noded_cert_file,
166
                               "Unable to load private key: %s" % err)
167

    
168
  ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
169
  ctx.use_privatekey(key)
170
  ctx.use_certificate(cert)
171
  try:
172
    ctx.check_privatekey()
173
  except OpenSSL.SSL.Error:
174
    raise JoinError("Given cluster certificate does not match local key")
175

    
176

    
177
def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
178
  """Verifies cluster certificate.
179

180
  @type data: dict
181

182
  """
183
  cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
184
  if cert:
185
    _verify_fn(cert)
186

    
187

    
188
def _VerifyClusterName(name, _ss_cluster_name_file=None):
189
  """Verifies cluster name against a local cluster name.
190

191
  @type name: string
192
  @param name: Cluster name
193

194
  """
195
  if _ss_cluster_name_file is None:
196
    _ss_cluster_name_file = \
197
      ssconf.SimpleStore().KeyToFilename(constants.SS_CLUSTER_NAME)
198

    
199
  try:
200
    local_name = utils.ReadOneLineFile(_ss_cluster_name_file)
201
  except EnvironmentError, err:
202
    if err.errno != errno.ENOENT:
203
      raise
204

    
205
    logging.debug("Local cluster name was not found (file %s)",
206
                  _ss_cluster_name_file)
207
  else:
208
    if name != local_name:
209
      raise JoinError("Current cluster name is '%s'" % local_name)
210

    
211

    
212
def VerifyClusterName(data, _verify_fn=_VerifyClusterName):
213
  """Verifies cluster name.
214

215
  @type data: dict
216

217
  """
218
  name = data.get(constants.SSHS_CLUSTER_NAME)
219
  if name:
220
    _verify_fn(name)
221
  else:
222
    raise JoinError("Cluster name must be specified")
223

    
224

    
225
def _UpdateKeyFiles(keys, dry_run, keyfiles):
226
  """Updates SSH key files.
227

228
  @type keys: sequence of tuple; (string, string, string)
229
  @param keys: Keys to write, tuples consist of key type
230
    (L{constants.SSHK_ALL}), public and private key
231
  @type dry_run: boolean
232
  @param dry_run: Whether to perform a dry run
233
  @type keyfiles: dict; (string as key, tuple with (string, string) as values)
234
  @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
235
    names; value tuples consist of public key filename and private key filename
236

237
  """
238
  assert set(keyfiles) == constants.SSHK_ALL
239

    
240
  for (kind, public_key, private_key) in keys:
241
    (public_file, private_file) = keyfiles[kind]
242

    
243
    logging.debug("Writing %s ...", public_file)
244
    utils.WriteFile(public_file, data=public_key, mode=0644,
245
                    backup=True, dry_run=dry_run)
246

    
247
    logging.debug("Writing %s ...", private_file)
248
    utils.WriteFile(private_file, data=private_key, mode=0600,
249
                    backup=True, dry_run=dry_run)
250

    
251

    
252
def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
253
                    _keyfiles=None):
254
  """Updates SSH daemon's keys.
255

256
  Unless C{dry_run} is set, the daemon is restarted at the end.
257

258
  @type data: dict
259
  @param data: Input data
260
  @type dry_run: boolean
261
  @param dry_run: Whether to perform a dry run
262

263
  """
264
  keys = data.get(constants.SSHS_SSH_HOST_KEY)
265
  if not keys:
266
    return
267

    
268
  if _keyfiles is None:
269
    _keyfiles = _SSH_DAEMON_KEYFILES
270

    
271
  logging.info("Updating SSH daemon key files")
272
  _UpdateKeyFiles(keys, dry_run, _keyfiles)
273

    
274
  if dry_run:
275
    logging.info("This is a dry run, not restarting SSH daemon")
276
  else:
277
    result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
278
                        interactive=True)
279
    if result.failed:
280
      raise JoinError("Could not reload SSH keys, command '%s'"
281
                      " had exitcode %s and error %s" %
282
                       (result.cmd, result.exit_code, result.output))
283

    
284

    
285
def UpdateSshRoot(data, dry_run, _homedir_fn=None):
286
  """Updates root's SSH keys.
287

288
  Root's C{authorized_keys} file is also updated with new public keys.
289

290
  @type data: dict
291
  @param data: Input data
292
  @type dry_run: boolean
293
  @param dry_run: Whether to perform a dry run
294

295
  """
296
  keys = data.get(constants.SSHS_SSH_ROOT_KEY)
297
  if not keys:
298
    return
299

    
300
  (dsa_private_file, dsa_public_file, auth_keys_file) = \
301
    ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
302
                     kind=constants.SSHK_DSA, _homedir_fn=_homedir_fn)
303
  (rsa_private_file, rsa_public_file, _) = \
304
    ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
305
                     kind=constants.SSHK_RSA, _homedir_fn=_homedir_fn)
306

    
307
  _UpdateKeyFiles(keys, dry_run, {
308
    constants.SSHK_RSA: (rsa_public_file, rsa_private_file),
309
    constants.SSHK_DSA: (dsa_public_file, dsa_private_file),
310
    })
311

    
312
  if dry_run:
313
    logging.info("This is a dry run, not modifying %s", auth_keys_file)
314
  else:
315
    for (kind, public_key, _) in keys:
316
      line = "%s %s" % (_SSHK_TO_SSHAK[kind], public_key)
317
      utils.AddAuthorizedKey(auth_keys_file, line)
318

    
319

    
320
def LoadData(raw):
321
  """Parses and verifies input data.
322

323
  @rtype: dict
324

325
  """
326
  try:
327
    data = serializer.LoadJson(raw)
328
  except Exception, err:
329
    raise errors.ParseError("Can't parse input data: %s" % err)
330

    
331
  if not _DATA_CHECK(data):
332
    raise errors.ParseError("Input data does not match expected format: %s" %
333
                            _DATA_CHECK)
334

    
335
  return data
336

    
337

    
338
def Main():
339
  """Main routine.
340

341
  """
342
  opts = ParseOptions()
343

    
344
  SetupLogging(opts)
345

    
346
  try:
347
    data = LoadData(sys.stdin.read())
348

    
349
    # Check if input data is correct
350
    VerifyClusterName(data)
351
    VerifyCertificate(data)
352

    
353
    # Update SSH files
354
    UpdateSshDaemon(data, opts.dry_run)
355
    UpdateSshRoot(data, opts.dry_run)
356

    
357
    logging.info("Setup finished successfully")
358
  except Exception, err: # pylint: disable=W0703
359
    logging.debug("Caught unhandled exception", exc_info=True)
360

    
361
    (retcode, message) = cli.FormatError(err)
362
    logging.error(message)
363

    
364
    return retcode
365
  else:
366
    return constants.EXIT_SUCCESS