Statistics
| Branch: | Tag: | Revision:

root / lib / tools / prepare_node_join.py @ f712208d

History | View | Annotate | Download (8.9 kB)

1
#
2
#
3

    
4
# Copyright (C) 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21
"""Script to prepare a node for joining a cluster.
22

23
"""
24

    
25
import os
26
import os.path
27
import optparse
28
import sys
29
import logging
30
import errno
31
import OpenSSL
32

    
33
from ganeti import cli
34
from ganeti import constants
35
from ganeti import errors
36
from ganeti import pathutils
37
from ganeti import utils
38
from ganeti import serializer
39
from ganeti import ht
40
from ganeti import ssh
41
from ganeti import ssconf
42

    
43

    
44
_SSH_KEY_LIST_ITEM = \
45
  ht.TAnd(ht.TIsLength(3),
46
          ht.TItems([
47
            ht.TElemOf(constants.SSHK_ALL),
48
            ht.Comment("public")(ht.TNonEmptyString),
49
            ht.Comment("private")(ht.TNonEmptyString),
50
          ]))
51

    
52
_SSH_KEY_LIST = ht.TListOf(_SSH_KEY_LIST_ITEM)
53

    
54
_DATA_CHECK = ht.TStrictDict(False, True, {
55
  constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
56
  constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
57
  constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
58
  constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
59
  })
60

    
61

    
62
class JoinError(errors.GenericError):
63
  """Local class for reporting errors.
64

65
  """
66

    
67

    
68
def ParseOptions():
69
  """Parses the options passed to the program.
70

71
  @return: Options and arguments
72

73
  """
74
  program = os.path.basename(sys.argv[0])
75

    
76
  parser = optparse.OptionParser(usage="%prog [--dry-run]",
77
                                 prog=program)
78
  parser.add_option(cli.DEBUG_OPT)
79
  parser.add_option(cli.VERBOSE_OPT)
80
  parser.add_option(cli.DRY_RUN_OPT)
81

    
82
  (opts, args) = parser.parse_args()
83

    
84
  return VerifyOptions(parser, opts, args)
85

    
86

    
87
def VerifyOptions(parser, opts, args):
88
  """Verifies options and arguments for correctness.
89

90
  """
91
  if args:
92
    parser.error("No arguments are expected")
93

    
94
  return opts
95

    
96

    
97
def SetupLogging(opts):
98
  """Configures the logging module.
99

100
  """
101
  formatter = logging.Formatter("%(asctime)s: %(message)s")
102

    
103
  stderr_handler = logging.StreamHandler()
104
  stderr_handler.setFormatter(formatter)
105
  if opts.debug:
106
    stderr_handler.setLevel(logging.NOTSET)
107
  elif opts.verbose:
108
    stderr_handler.setLevel(logging.INFO)
109
  else:
110
    stderr_handler.setLevel(logging.WARNING)
111

    
112
  root_logger = logging.getLogger("")
113
  root_logger.setLevel(logging.NOTSET)
114
  root_logger.addHandler(stderr_handler)
115

    
116

    
117
def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE):
118
  """Verifies a certificate against the local node daemon certificate.
119

120
  @type cert: string
121
  @param cert: Certificate in PEM format (no key)
122

123
  """
124
  try:
125
    OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert)
126
  except OpenSSL.crypto.Error, err:
127
    pass
128
  else:
129
    raise JoinError("No private key may be given")
130

    
131
  try:
132
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
133
  except Exception, err:
134
    raise errors.X509CertError("(stdin)",
135
                               "Unable to load certificate: %s" % err)
136

    
137
  try:
138
    noded_pem = utils.ReadFile(_noded_cert_file)
139
  except EnvironmentError, err:
140
    if err.errno != errno.ENOENT:
141
      raise
142

    
143
    logging.debug("Local node certificate was not found (file %s)",
144
                  _noded_cert_file)
145
    return
146

    
147
  try:
148
    key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, noded_pem)
149
  except Exception, err:
150
    raise errors.X509CertError(_noded_cert_file,
151
                               "Unable to load private key: %s" % err)
152

    
153
  ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
154
  ctx.use_privatekey(key)
155
  ctx.use_certificate(cert)
156
  try:
157
    ctx.check_privatekey()
158
  except OpenSSL.SSL.Error:
159
    raise JoinError("Given cluster certificate does not match local key")
160

    
161

    
162
def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
163
  """Verifies cluster certificate.
164

165
  @type data: dict
166

167
  """
168
  cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
169
  if cert:
170
    _verify_fn(cert)
171

    
172

    
173
def _VerifyClusterName(name, _ss_cluster_name_file=None):
174
  """Verifies cluster name against a local cluster name.
175

176
  @type name: string
177
  @param name: Cluster name
178

179
  """
180
  if _ss_cluster_name_file is None:
181
    _ss_cluster_name_file = \
182
      ssconf.SimpleStore().KeyToFilename(constants.SS_CLUSTER_NAME)
183

    
184
  try:
185
    local_name = utils.ReadOneLineFile(_ss_cluster_name_file)
186
  except EnvironmentError, err:
187
    if err.errno != errno.ENOENT:
188
      raise
189

    
190
    logging.debug("Local cluster name was not found (file %s)",
191
                  _ss_cluster_name_file)
192
  else:
193
    if name != local_name:
194
      raise JoinError("Current cluster name is '%s'" % local_name)
195

    
196

    
197
def VerifyClusterName(data, _verify_fn=_VerifyClusterName):
198
  """Verifies cluster name.
199

200
  @type data: dict
201

202
  """
203
  name = data.get(constants.SSHS_CLUSTER_NAME)
204
  if name:
205
    _verify_fn(name)
206
  else:
207
    raise JoinError("Cluster name must be specified")
208

    
209

    
210
def _UpdateKeyFiles(keys, dry_run, keyfiles):
211
  """Updates SSH key files.
212

213
  @type keys: sequence of tuple; (string, string, string)
214
  @param keys: Keys to write, tuples consist of key type
215
    (L{constants.SSHK_ALL}), public and private key
216
  @type dry_run: boolean
217
  @param dry_run: Whether to perform a dry run
218
  @type keyfiles: dict; (string as key, tuple with (string, string) as values)
219
  @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
220
    names; value tuples consist of public key filename and private key filename
221

222
  """
223
  assert set(keyfiles) == constants.SSHK_ALL
224

    
225
  for (kind, private_key, public_key) in keys:
226
    (private_file, public_file) = keyfiles[kind]
227

    
228
    logging.debug("Writing %s ...", private_file)
229
    utils.WriteFile(private_file, data=private_key, mode=0600,
230
                    backup=True, dry_run=dry_run)
231

    
232
    logging.debug("Writing %s ...", public_file)
233
    utils.WriteFile(public_file, data=public_key, mode=0644,
234
                    backup=True, dry_run=dry_run)
235

    
236

    
237
def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
238
                    _keyfiles=None):
239
  """Updates SSH daemon's keys.
240

241
  Unless C{dry_run} is set, the daemon is restarted at the end.
242

243
  @type data: dict
244
  @param data: Input data
245
  @type dry_run: boolean
246
  @param dry_run: Whether to perform a dry run
247

248
  """
249
  keys = data.get(constants.SSHS_SSH_HOST_KEY)
250
  if not keys:
251
    return
252

    
253
  if _keyfiles is None:
254
    _keyfiles = constants.SSH_DAEMON_KEYFILES
255

    
256
  logging.info("Updating SSH daemon key files")
257
  _UpdateKeyFiles(keys, dry_run, _keyfiles)
258

    
259
  if dry_run:
260
    logging.info("This is a dry run, not restarting SSH daemon")
261
  else:
262
    result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
263
                        interactive=True)
264
    if result.failed:
265
      raise JoinError("Could not reload SSH keys, command '%s'"
266
                      " had exitcode %s and error %s" %
267
                       (result.cmd, result.exit_code, result.output))
268

    
269

    
270
def UpdateSshRoot(data, dry_run, _homedir_fn=None):
271
  """Updates root's SSH keys.
272

273
  Root's C{authorized_keys} file is also updated with new public keys.
274

275
  @type data: dict
276
  @param data: Input data
277
  @type dry_run: boolean
278
  @param dry_run: Whether to perform a dry run
279

280
  """
281
  keys = data.get(constants.SSHS_SSH_ROOT_KEY)
282
  if not keys:
283
    return
284

    
285
  (auth_keys_file, keyfiles) = \
286
    ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
287
                        _homedir_fn=_homedir_fn)
288

    
289
  _UpdateKeyFiles(keys, dry_run, keyfiles)
290

    
291
  if dry_run:
292
    logging.info("This is a dry run, not modifying %s", auth_keys_file)
293
  else:
294
    for (_, _, public_key) in keys:
295
      utils.AddAuthorizedKey(auth_keys_file, public_key)
296

    
297

    
298
def LoadData(raw):
299
  """Parses and verifies input data.
300

301
  @rtype: dict
302

303
  """
304
  try:
305
    data = serializer.LoadJson(raw)
306
  except Exception, err:
307
    raise errors.ParseError("Can't parse input data: %s" % err)
308

    
309
  if not _DATA_CHECK(data):
310
    raise errors.ParseError("Input data does not match expected format: %s" %
311
                            _DATA_CHECK)
312

    
313
  return data
314

    
315

    
316
def Main():
317
  """Main routine.
318

319
  """
320
  opts = ParseOptions()
321

    
322
  SetupLogging(opts)
323

    
324
  try:
325
    data = LoadData(sys.stdin.read())
326

    
327
    # Check if input data is correct
328
    VerifyClusterName(data)
329
    VerifyCertificate(data)
330

    
331
    # Update SSH files
332
    UpdateSshDaemon(data, opts.dry_run)
333
    UpdateSshRoot(data, opts.dry_run)
334

    
335
    logging.info("Setup finished successfully")
336
  except Exception, err: # pylint: disable=W0703
337
    logging.debug("Caught unhandled exception", exc_info=True)
338

    
339
    (retcode, message) = cli.FormatError(err)
340
    logging.error(message)
341

    
342
    return retcode
343
  else:
344
    return constants.EXIT_SUCCESS