Statistics
| Branch: | Tag: | Revision:

root / lib / tools / prepare_node_join.py @ a8b3b09d

History | View | Annotate | Download (7.5 kB)

1
#
2
#
3

    
4
# Copyright (C) 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21
"""Script to prepare a node for joining a cluster.
22

23
"""
24

    
25
import os
26
import os.path
27
import optparse
28
import sys
29
import logging
30
import errno
31
import OpenSSL
32

    
33
from ganeti import cli
34
from ganeti import constants
35
from ganeti import errors
36
from ganeti import pathutils
37
from ganeti import utils
38
from ganeti import serializer
39
from ganeti import ht
40
from ganeti import ssh
41
from ganeti import ssconf
42

    
43

    
44
_SSH_KEY_LIST_ITEM = \
45
  ht.TAnd(ht.TIsLength(3),
46
          ht.TItems([
47
            ht.TElemOf(constants.SSHK_ALL),
48
            ht.Comment("public")(ht.TNonEmptyString),
49
            ht.Comment("private")(ht.TNonEmptyString),
50
          ]))
51

    
52
_SSH_KEY_LIST = ht.TListOf(_SSH_KEY_LIST_ITEM)
53

    
54
_DATA_CHECK = ht.TStrictDict(False, True, {
55
  constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
56
  constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
57
  constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
58
  constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
59
  })
60

    
61

    
62
class JoinError(errors.GenericError):
63
  """Local class for reporting errors.
64

65
  """
66

    
67

    
68
def ParseOptions():
69
  """Parses the options passed to the program.
70

71
  @return: Options and arguments
72

73
  """
74
  program = os.path.basename(sys.argv[0])
75

    
76
  parser = optparse.OptionParser(usage="%prog [--dry-run]",
77
                                 prog=program)
78
  parser.add_option(cli.DEBUG_OPT)
79
  parser.add_option(cli.VERBOSE_OPT)
80
  parser.add_option(cli.DRY_RUN_OPT)
81

    
82
  (opts, args) = parser.parse_args()
83

    
84
  return VerifyOptions(parser, opts, args)
85

    
86

    
87
def VerifyOptions(parser, opts, args):
88
  """Verifies options and arguments for correctness.
89

90
  """
91
  if args:
92
    parser.error("No arguments are expected")
93

    
94
  return opts
95

    
96

    
97
def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE):
98
  """Verifies a certificate against the local node daemon certificate.
99

100
  @type cert: string
101
  @param cert: Certificate in PEM format (no key)
102

103
  """
104
  try:
105
    OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert)
106
  except OpenSSL.crypto.Error, err:
107
    pass
108
  else:
109
    raise JoinError("No private key may be given")
110

    
111
  try:
112
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
113
  except Exception, err:
114
    raise errors.X509CertError("(stdin)",
115
                               "Unable to load certificate: %s" % err)
116

    
117
  try:
118
    noded_pem = utils.ReadFile(_noded_cert_file)
119
  except EnvironmentError, err:
120
    if err.errno != errno.ENOENT:
121
      raise
122

    
123
    logging.debug("Local node certificate was not found (file %s)",
124
                  _noded_cert_file)
125
    return
126

    
127
  try:
128
    key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, noded_pem)
129
  except Exception, err:
130
    raise errors.X509CertError(_noded_cert_file,
131
                               "Unable to load private key: %s" % err)
132

    
133
  check_fn = utils.PrepareX509CertKeyCheck(cert, key)
134
  try:
135
    check_fn()
136
  except OpenSSL.SSL.Error:
137
    raise JoinError("Given cluster certificate does not match local key")
138

    
139

    
140
def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
141
  """Verifies cluster certificate.
142

143
  @type data: dict
144

145
  """
146
  cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
147
  if cert:
148
    _verify_fn(cert)
149

    
150

    
151
def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName):
152
  """Verifies cluster name.
153

154
  @type data: dict
155

156
  """
157
  name = data.get(constants.SSHS_CLUSTER_NAME)
158
  if name:
159
    _verify_fn(name)
160
  else:
161
    raise JoinError("Cluster name must be specified")
162

    
163

    
164
def _UpdateKeyFiles(keys, dry_run, keyfiles):
165
  """Updates SSH key files.
166

167
  @type keys: sequence of tuple; (string, string, string)
168
  @param keys: Keys to write, tuples consist of key type
169
    (L{constants.SSHK_ALL}), public and private key
170
  @type dry_run: boolean
171
  @param dry_run: Whether to perform a dry run
172
  @type keyfiles: dict; (string as key, tuple with (string, string) as values)
173
  @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
174
    names; value tuples consist of public key filename and private key filename
175

176
  """
177
  assert set(keyfiles) == constants.SSHK_ALL
178

    
179
  for (kind, private_key, public_key) in keys:
180
    (private_file, public_file) = keyfiles[kind]
181

    
182
    logging.debug("Writing %s ...", private_file)
183
    utils.WriteFile(private_file, data=private_key, mode=0600,
184
                    backup=True, dry_run=dry_run)
185

    
186
    logging.debug("Writing %s ...", public_file)
187
    utils.WriteFile(public_file, data=public_key, mode=0644,
188
                    backup=True, dry_run=dry_run)
189

    
190

    
191
def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
192
                    _keyfiles=None):
193
  """Updates SSH daemon's keys.
194

195
  Unless C{dry_run} is set, the daemon is restarted at the end.
196

197
  @type data: dict
198
  @param data: Input data
199
  @type dry_run: boolean
200
  @param dry_run: Whether to perform a dry run
201

202
  """
203
  keys = data.get(constants.SSHS_SSH_HOST_KEY)
204
  if not keys:
205
    return
206

    
207
  if _keyfiles is None:
208
    _keyfiles = constants.SSH_DAEMON_KEYFILES
209

    
210
  logging.info("Updating SSH daemon key files")
211
  _UpdateKeyFiles(keys, dry_run, _keyfiles)
212

    
213
  if dry_run:
214
    logging.info("This is a dry run, not restarting SSH daemon")
215
  else:
216
    result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
217
                        interactive=True)
218
    if result.failed:
219
      raise JoinError("Could not reload SSH keys, command '%s'"
220
                      " had exitcode %s and error %s" %
221
                       (result.cmd, result.exit_code, result.output))
222

    
223

    
224
def UpdateSshRoot(data, dry_run, _homedir_fn=None):
225
  """Updates root's SSH keys.
226

227
  Root's C{authorized_keys} file is also updated with new public keys.
228

229
  @type data: dict
230
  @param data: Input data
231
  @type dry_run: boolean
232
  @param dry_run: Whether to perform a dry run
233

234
  """
235
  keys = data.get(constants.SSHS_SSH_ROOT_KEY)
236
  if not keys:
237
    return
238

    
239
  (auth_keys_file, keyfiles) = \
240
    ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
241
                        _homedir_fn=_homedir_fn)
242

    
243
  _UpdateKeyFiles(keys, dry_run, keyfiles)
244

    
245
  if dry_run:
246
    logging.info("This is a dry run, not modifying %s", auth_keys_file)
247
  else:
248
    for (_, _, public_key) in keys:
249
      utils.AddAuthorizedKey(auth_keys_file, public_key)
250

    
251

    
252
def LoadData(raw):
253
  """Parses and verifies input data.
254

255
  @rtype: dict
256

257
  """
258
  return serializer.LoadAndVerifyJson(raw, _DATA_CHECK)
259

    
260

    
261
def Main():
262
  """Main routine.
263

264
  """
265
  opts = ParseOptions()
266

    
267
  utils.SetupToolLogging(opts.debug, opts.verbose)
268

    
269
  try:
270
    data = LoadData(sys.stdin.read())
271

    
272
    # Check if input data is correct
273
    VerifyClusterName(data)
274
    VerifyCertificate(data)
275

    
276
    # Update SSH files
277
    UpdateSshDaemon(data, opts.dry_run)
278
    UpdateSshRoot(data, opts.dry_run)
279

    
280
    logging.info("Setup finished successfully")
281
  except Exception, err: # pylint: disable=W0703
282
    logging.debug("Caught unhandled exception", exc_info=True)
283

    
284
    (retcode, message) = cli.FormatError(err)
285
    logging.error(message)
286

    
287
    return retcode
288
  else:
289
    return constants.EXIT_SUCCESS