Statistics
| Branch: | Tag: | Revision:

root / lib / tools / prepare_node_join.py @ 910ef222

History | View | Annotate | Download (9.4 kB)

1
#
2
#
3

    
4
# Copyright (C) 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21
"""Script to prepare a node for joining a cluster.
22

23
"""
24

    
25
import os
26
import os.path
27
import optparse
28
import sys
29
import logging
30
import errno
31
import OpenSSL
32

    
33
from ganeti import cli
34
from ganeti import constants
35
from ganeti import errors
36
from ganeti import pathutils
37
from ganeti import utils
38
from ganeti import serializer
39
from ganeti import ht
40
from ganeti import ssh
41
from ganeti import ssconf
42

    
43

    
44
_SSH_KEY_LIST_ITEM = \
45
  ht.TAnd(ht.TIsLength(3),
46
          ht.TItems([
47
            ht.TElemOf(constants.SSHK_ALL),
48
            ht.Comment("public")(ht.TNonEmptyString),
49
            ht.Comment("private")(ht.TNonEmptyString),
50
          ]))
51

    
52
_SSH_KEY_LIST = ht.TListOf(_SSH_KEY_LIST_ITEM)
53

    
54
_DATA_CHECK = ht.TStrictDict(False, True, {
55
  constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
56
  constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
57
  constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
58
  constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
59
  })
60

    
61
_SSH_DAEMON_KEYFILES = {
62
  constants.SSHK_RSA:
63
    (pathutils.SSH_HOST_RSA_PUB, pathutils.SSH_HOST_RSA_PRIV),
64
  constants.SSHK_DSA:
65
    (pathutils.SSH_HOST_DSA_PUB, pathutils.SSH_HOST_DSA_PRIV),
66
  }
67

    
68

    
69
class JoinError(errors.GenericError):
70
  """Local class for reporting errors.
71

72
  """
73

    
74

    
75
def ParseOptions():
76
  """Parses the options passed to the program.
77

78
  @return: Options and arguments
79

80
  """
81
  program = os.path.basename(sys.argv[0])
82

    
83
  parser = optparse.OptionParser(usage="%prog [--dry-run]",
84
                                 prog=program)
85
  parser.add_option(cli.DEBUG_OPT)
86
  parser.add_option(cli.VERBOSE_OPT)
87
  parser.add_option(cli.DRY_RUN_OPT)
88

    
89
  (opts, args) = parser.parse_args()
90

    
91
  return VerifyOptions(parser, opts, args)
92

    
93

    
94
def VerifyOptions(parser, opts, args):
95
  """Verifies options and arguments for correctness.
96

97
  """
98
  if args:
99
    parser.error("No arguments are expected")
100

    
101
  return opts
102

    
103

    
104
def SetupLogging(opts):
105
  """Configures the logging module.
106

107
  """
108
  formatter = logging.Formatter("%(asctime)s: %(message)s")
109

    
110
  stderr_handler = logging.StreamHandler()
111
  stderr_handler.setFormatter(formatter)
112
  if opts.debug:
113
    stderr_handler.setLevel(logging.NOTSET)
114
  elif opts.verbose:
115
    stderr_handler.setLevel(logging.INFO)
116
  else:
117
    stderr_handler.setLevel(logging.WARNING)
118

    
119
  root_logger = logging.getLogger("")
120
  root_logger.setLevel(logging.NOTSET)
121
  root_logger.addHandler(stderr_handler)
122

    
123

    
124
def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE):
125
  """Verifies a certificate against the local node daemon certificate.
126

127
  @type cert: string
128
  @param cert: Certificate in PEM format (no key)
129

130
  """
131
  try:
132
    OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert)
133
  except OpenSSL.crypto.Error, err:
134
    pass
135
  else:
136
    raise JoinError("No private key may be given")
137

    
138
  try:
139
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
140
  except Exception, err:
141
    raise errors.X509CertError("(stdin)",
142
                               "Unable to load certificate: %s" % err)
143

    
144
  try:
145
    noded_pem = utils.ReadFile(_noded_cert_file)
146
  except EnvironmentError, err:
147
    if err.errno != errno.ENOENT:
148
      raise
149

    
150
    logging.debug("Local node certificate was not found (file %s)",
151
                  _noded_cert_file)
152
    return
153

    
154
  try:
155
    key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, noded_pem)
156
  except Exception, err:
157
    raise errors.X509CertError(_noded_cert_file,
158
                               "Unable to load private key: %s" % err)
159

    
160
  ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
161
  ctx.use_privatekey(key)
162
  ctx.use_certificate(cert)
163
  try:
164
    ctx.check_privatekey()
165
  except OpenSSL.SSL.Error:
166
    raise JoinError("Given cluster certificate does not match local key")
167

    
168

    
169
def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
170
  """Verifies cluster certificate.
171

172
  @type data: dict
173

174
  """
175
  cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
176
  if cert:
177
    _verify_fn(cert)
178

    
179

    
180
def _VerifyClusterName(name, _ss_cluster_name_file=None):
181
  """Verifies cluster name against a local cluster name.
182

183
  @type name: string
184
  @param name: Cluster name
185

186
  """
187
  if _ss_cluster_name_file is None:
188
    _ss_cluster_name_file = \
189
      ssconf.SimpleStore().KeyToFilename(constants.SS_CLUSTER_NAME)
190

    
191
  try:
192
    local_name = utils.ReadOneLineFile(_ss_cluster_name_file)
193
  except EnvironmentError, err:
194
    if err.errno != errno.ENOENT:
195
      raise
196

    
197
    logging.debug("Local cluster name was not found (file %s)",
198
                  _ss_cluster_name_file)
199
  else:
200
    if name != local_name:
201
      raise JoinError("Current cluster name is '%s'" % local_name)
202

    
203

    
204
def VerifyClusterName(data, _verify_fn=_VerifyClusterName):
205
  """Verifies cluster name.
206

207
  @type data: dict
208

209
  """
210
  name = data.get(constants.SSHS_CLUSTER_NAME)
211
  if name:
212
    _verify_fn(name)
213
  else:
214
    raise JoinError("Cluster name must be specified")
215

    
216

    
217
def _UpdateKeyFiles(keys, dry_run, keyfiles):
218
  """Updates SSH key files.
219

220
  @type keys: sequence of tuple; (string, string, string)
221
  @param keys: Keys to write, tuples consist of key type
222
    (L{constants.SSHK_ALL}), public and private key
223
  @type dry_run: boolean
224
  @param dry_run: Whether to perform a dry run
225
  @type keyfiles: dict; (string as key, tuple with (string, string) as values)
226
  @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
227
    names; value tuples consist of public key filename and private key filename
228

229
  """
230
  assert set(keyfiles) == constants.SSHK_ALL
231

    
232
  for (kind, public_key, private_key) in keys:
233
    (public_file, private_file) = keyfiles[kind]
234

    
235
    logging.debug("Writing %s ...", public_file)
236
    utils.WriteFile(public_file, data=public_key, mode=0644,
237
                    backup=True, dry_run=dry_run)
238

    
239
    logging.debug("Writing %s ...", private_file)
240
    utils.WriteFile(private_file, data=private_key, mode=0600,
241
                    backup=True, dry_run=dry_run)
242

    
243

    
244
def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
245
                    _keyfiles=None):
246
  """Updates SSH daemon's keys.
247

248
  Unless C{dry_run} is set, the daemon is restarted at the end.
249

250
  @type data: dict
251
  @param data: Input data
252
  @type dry_run: boolean
253
  @param dry_run: Whether to perform a dry run
254

255
  """
256
  keys = data.get(constants.SSHS_SSH_HOST_KEY)
257
  if not keys:
258
    return
259

    
260
  if _keyfiles is None:
261
    _keyfiles = _SSH_DAEMON_KEYFILES
262

    
263
  logging.info("Updating SSH daemon key files")
264
  _UpdateKeyFiles(keys, dry_run, _keyfiles)
265

    
266
  if dry_run:
267
    logging.info("This is a dry run, not restarting SSH daemon")
268
  else:
269
    result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
270
                        interactive=True)
271
    if result.failed:
272
      raise JoinError("Could not reload SSH keys, command '%s'"
273
                      " had exitcode %s and error %s" %
274
                       (result.cmd, result.exit_code, result.output))
275

    
276

    
277
def UpdateSshRoot(data, dry_run, _homedir_fn=None):
278
  """Updates root's SSH keys.
279

280
  Root's C{authorized_keys} file is also updated with new public keys.
281

282
  @type data: dict
283
  @param data: Input data
284
  @type dry_run: boolean
285
  @param dry_run: Whether to perform a dry run
286

287
  """
288
  keys = data.get(constants.SSHS_SSH_ROOT_KEY)
289
  if not keys:
290
    return
291

    
292
  (dsa_private_file, dsa_public_file, auth_keys_file) = \
293
    ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
294
                     kind=constants.SSHK_DSA, _homedir_fn=_homedir_fn)
295
  (rsa_private_file, rsa_public_file, _) = \
296
    ssh.GetUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
297
                     kind=constants.SSHK_RSA, _homedir_fn=_homedir_fn)
298

    
299
  _UpdateKeyFiles(keys, dry_run, {
300
    constants.SSHK_RSA: (rsa_public_file, rsa_private_file),
301
    constants.SSHK_DSA: (dsa_public_file, dsa_private_file),
302
    })
303

    
304
  if dry_run:
305
    logging.info("This is a dry run, not modifying %s", auth_keys_file)
306
  else:
307
    for (_, _, public_key) in keys:
308
      utils.AddAuthorizedKey(auth_keys_file, public_key)
309

    
310

    
311
def LoadData(raw):
312
  """Parses and verifies input data.
313

314
  @rtype: dict
315

316
  """
317
  try:
318
    data = serializer.LoadJson(raw)
319
  except Exception, err:
320
    raise errors.ParseError("Can't parse input data: %s" % err)
321

    
322
  if not _DATA_CHECK(data):
323
    raise errors.ParseError("Input data does not match expected format: %s" %
324
                            _DATA_CHECK)
325

    
326
  return data
327

    
328

    
329
def Main():
330
  """Main routine.
331

332
  """
333
  opts = ParseOptions()
334

    
335
  SetupLogging(opts)
336

    
337
  try:
338
    data = LoadData(sys.stdin.read())
339

    
340
    # Check if input data is correct
341
    VerifyClusterName(data)
342
    VerifyCertificate(data)
343

    
344
    # Update SSH files
345
    UpdateSshDaemon(data, opts.dry_run)
346
    UpdateSshRoot(data, opts.dry_run)
347

    
348
    logging.info("Setup finished successfully")
349
  except Exception, err: # pylint: disable=W0703
350
    logging.debug("Caught unhandled exception", exc_info=True)
351

    
352
    (retcode, message) = cli.FormatError(err)
353
    logging.error(message)
354

    
355
    return retcode
356
  else:
357
    return constants.EXIT_SUCCESS