Statistics
| Branch: | Tag: | Revision:

root / lib / tools / prepare_node_join.py @ 796b5152

History | View | Annotate | Download (8.4 kB)

1
#
2
#
3

    
4
# Copyright (C) 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21
"""Script to prepare a node for joining a cluster.
22

23
"""
24

    
25
import os
26
import os.path
27
import optparse
28
import sys
29
import logging
30
import errno
31
import OpenSSL
32

    
33
from ganeti import cli
34
from ganeti import constants
35
from ganeti import errors
36
from ganeti import pathutils
37
from ganeti import utils
38
from ganeti import serializer
39
from ganeti import ht
40
from ganeti import ssh
41
from ganeti import ssconf
42

    
43

    
44
_SSH_KEY_LIST_ITEM = \
45
  ht.TAnd(ht.TIsLength(3),
46
          ht.TItems([
47
            ht.TElemOf(constants.SSHK_ALL),
48
            ht.Comment("public")(ht.TNonEmptyString),
49
            ht.Comment("private")(ht.TNonEmptyString),
50
          ]))
51

    
52
_SSH_KEY_LIST = ht.TListOf(_SSH_KEY_LIST_ITEM)
53

    
54
_DATA_CHECK = ht.TStrictDict(False, True, {
55
  constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
56
  constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
57
  constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
58
  constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
59
  })
60

    
61

    
62
class JoinError(errors.GenericError):
63
  """Local class for reporting errors.
64

65
  """
66

    
67

    
68
def ParseOptions():
69
  """Parses the options passed to the program.
70

71
  @return: Options and arguments
72

73
  """
74
  program = os.path.basename(sys.argv[0])
75

    
76
  parser = optparse.OptionParser(usage="%prog [--dry-run]",
77
                                 prog=program)
78
  parser.add_option(cli.DEBUG_OPT)
79
  parser.add_option(cli.VERBOSE_OPT)
80
  parser.add_option(cli.DRY_RUN_OPT)
81

    
82
  (opts, args) = parser.parse_args()
83

    
84
  return VerifyOptions(parser, opts, args)
85

    
86

    
87
def VerifyOptions(parser, opts, args):
88
  """Verifies options and arguments for correctness.
89

90
  """
91
  if args:
92
    parser.error("No arguments are expected")
93

    
94
  return opts
95

    
96

    
97
def _VerifyCertificate(cert, _noded_cert_file=pathutils.NODED_CERT_FILE):
98
  """Verifies a certificate against the local node daemon certificate.
99

100
  @type cert: string
101
  @param cert: Certificate in PEM format (no key)
102

103
  """
104
  try:
105
    OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert)
106
  except OpenSSL.crypto.Error, err:
107
    pass
108
  else:
109
    raise JoinError("No private key may be given")
110

    
111
  try:
112
    cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert)
113
  except Exception, err:
114
    raise errors.X509CertError("(stdin)",
115
                               "Unable to load certificate: %s" % err)
116

    
117
  try:
118
    noded_pem = utils.ReadFile(_noded_cert_file)
119
  except EnvironmentError, err:
120
    if err.errno != errno.ENOENT:
121
      raise
122

    
123
    logging.debug("Local node certificate was not found (file %s)",
124
                  _noded_cert_file)
125
    return
126

    
127
  try:
128
    key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, noded_pem)
129
  except Exception, err:
130
    raise errors.X509CertError(_noded_cert_file,
131
                               "Unable to load private key: %s" % err)
132

    
133
  ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
134
  ctx.use_privatekey(key)
135
  ctx.use_certificate(cert)
136
  try:
137
    ctx.check_privatekey()
138
  except OpenSSL.SSL.Error:
139
    raise JoinError("Given cluster certificate does not match local key")
140

    
141

    
142
def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
143
  """Verifies cluster certificate.
144

145
  @type data: dict
146

147
  """
148
  cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
149
  if cert:
150
    _verify_fn(cert)
151

    
152

    
153
def _VerifyClusterName(name, _ss_cluster_name_file=None):
154
  """Verifies cluster name against a local cluster name.
155

156
  @type name: string
157
  @param name: Cluster name
158

159
  """
160
  if _ss_cluster_name_file is None:
161
    _ss_cluster_name_file = \
162
      ssconf.SimpleStore().KeyToFilename(constants.SS_CLUSTER_NAME)
163

    
164
  try:
165
    local_name = utils.ReadOneLineFile(_ss_cluster_name_file)
166
  except EnvironmentError, err:
167
    if err.errno != errno.ENOENT:
168
      raise
169

    
170
    logging.debug("Local cluster name was not found (file %s)",
171
                  _ss_cluster_name_file)
172
  else:
173
    if name != local_name:
174
      raise JoinError("Current cluster name is '%s'" % local_name)
175

    
176

    
177
def VerifyClusterName(data, _verify_fn=_VerifyClusterName):
178
  """Verifies cluster name.
179

180
  @type data: dict
181

182
  """
183
  name = data.get(constants.SSHS_CLUSTER_NAME)
184
  if name:
185
    _verify_fn(name)
186
  else:
187
    raise JoinError("Cluster name must be specified")
188

    
189

    
190
def _UpdateKeyFiles(keys, dry_run, keyfiles):
191
  """Updates SSH key files.
192

193
  @type keys: sequence of tuple; (string, string, string)
194
  @param keys: Keys to write, tuples consist of key type
195
    (L{constants.SSHK_ALL}), public and private key
196
  @type dry_run: boolean
197
  @param dry_run: Whether to perform a dry run
198
  @type keyfiles: dict; (string as key, tuple with (string, string) as values)
199
  @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
200
    names; value tuples consist of public key filename and private key filename
201

202
  """
203
  assert set(keyfiles) == constants.SSHK_ALL
204

    
205
  for (kind, private_key, public_key) in keys:
206
    (private_file, public_file) = keyfiles[kind]
207

    
208
    logging.debug("Writing %s ...", private_file)
209
    utils.WriteFile(private_file, data=private_key, mode=0600,
210
                    backup=True, dry_run=dry_run)
211

    
212
    logging.debug("Writing %s ...", public_file)
213
    utils.WriteFile(public_file, data=public_key, mode=0644,
214
                    backup=True, dry_run=dry_run)
215

    
216

    
217
def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
218
                    _keyfiles=None):
219
  """Updates SSH daemon's keys.
220

221
  Unless C{dry_run} is set, the daemon is restarted at the end.
222

223
  @type data: dict
224
  @param data: Input data
225
  @type dry_run: boolean
226
  @param dry_run: Whether to perform a dry run
227

228
  """
229
  keys = data.get(constants.SSHS_SSH_HOST_KEY)
230
  if not keys:
231
    return
232

    
233
  if _keyfiles is None:
234
    _keyfiles = constants.SSH_DAEMON_KEYFILES
235

    
236
  logging.info("Updating SSH daemon key files")
237
  _UpdateKeyFiles(keys, dry_run, _keyfiles)
238

    
239
  if dry_run:
240
    logging.info("This is a dry run, not restarting SSH daemon")
241
  else:
242
    result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
243
                        interactive=True)
244
    if result.failed:
245
      raise JoinError("Could not reload SSH keys, command '%s'"
246
                      " had exitcode %s and error %s" %
247
                       (result.cmd, result.exit_code, result.output))
248

    
249

    
250
def UpdateSshRoot(data, dry_run, _homedir_fn=None):
251
  """Updates root's SSH keys.
252

253
  Root's C{authorized_keys} file is also updated with new public keys.
254

255
  @type data: dict
256
  @param data: Input data
257
  @type dry_run: boolean
258
  @param dry_run: Whether to perform a dry run
259

260
  """
261
  keys = data.get(constants.SSHS_SSH_ROOT_KEY)
262
  if not keys:
263
    return
264

    
265
  (auth_keys_file, keyfiles) = \
266
    ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
267
                        _homedir_fn=_homedir_fn)
268

    
269
  _UpdateKeyFiles(keys, dry_run, keyfiles)
270

    
271
  if dry_run:
272
    logging.info("This is a dry run, not modifying %s", auth_keys_file)
273
  else:
274
    for (_, _, public_key) in keys:
275
      utils.AddAuthorizedKey(auth_keys_file, public_key)
276

    
277

    
278
def LoadData(raw):
279
  """Parses and verifies input data.
280

281
  @rtype: dict
282

283
  """
284
  try:
285
    data = serializer.LoadJson(raw)
286
  except Exception, err:
287
    raise errors.ParseError("Can't parse input data: %s" % err)
288

    
289
  if not _DATA_CHECK(data):
290
    raise errors.ParseError("Input data does not match expected format: %s" %
291
                            _DATA_CHECK)
292

    
293
  return data
294

    
295

    
296
def Main():
297
  """Main routine.
298

299
  """
300
  opts = ParseOptions()
301

    
302
  utils.SetupToolLogging(opts.debug, opts.verbose)
303

    
304
  try:
305
    data = LoadData(sys.stdin.read())
306

    
307
    # Check if input data is correct
308
    VerifyClusterName(data)
309
    VerifyCertificate(data)
310

    
311
    # Update SSH files
312
    UpdateSshDaemon(data, opts.dry_run)
313
    UpdateSshRoot(data, opts.dry_run)
314

    
315
    logging.info("Setup finished successfully")
316
  except Exception, err: # pylint: disable=W0703
317
    logging.debug("Caught unhandled exception", exc_info=True)
318

    
319
    (retcode, message) = cli.FormatError(err)
320
    logging.error(message)
321

    
322
    return retcode
323
  else:
324
    return constants.EXIT_SUCCESS