Statistics
| Branch: | Tag: | Revision:

root / lib / tools / prepare_node_join.py @ fb62843c

History | View | Annotate | Download (6.9 kB)

1
#
2
#
3

    
4
# Copyright (C) 2012 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21
"""Script to prepare a node for joining a cluster.
22

23
"""
24

    
25
import os
26
import os.path
27
import optparse
28
import sys
29
import logging
30
import OpenSSL
31

    
32
from ganeti import cli
33
from ganeti import constants
34
from ganeti import errors
35
from ganeti import pathutils
36
from ganeti import utils
37
from ganeti import serializer
38
from ganeti import ht
39
from ganeti import ssh
40
from ganeti import ssconf
41

    
42

    
43
_SSH_KEY_LIST_ITEM = \
44
  ht.TAnd(ht.TIsLength(3),
45
          ht.TItems([
46
            ht.TElemOf(constants.SSHK_ALL),
47
            ht.Comment("public")(ht.TNonEmptyString),
48
            ht.Comment("private")(ht.TNonEmptyString),
49
          ]))
50

    
51
_SSH_KEY_LIST = ht.TListOf(_SSH_KEY_LIST_ITEM)
52

    
53
_DATA_CHECK = ht.TStrictDict(False, True, {
54
  constants.SSHS_CLUSTER_NAME: ht.TNonEmptyString,
55
  constants.SSHS_NODE_DAEMON_CERTIFICATE: ht.TNonEmptyString,
56
  constants.SSHS_SSH_HOST_KEY: _SSH_KEY_LIST,
57
  constants.SSHS_SSH_ROOT_KEY: _SSH_KEY_LIST,
58
  })
59

    
60

    
61
class JoinError(errors.GenericError):
62
  """Local class for reporting errors.
63

64
  """
65

    
66

    
67
def ParseOptions():
68
  """Parses the options passed to the program.
69

70
  @return: Options and arguments
71

72
  """
73
  program = os.path.basename(sys.argv[0])
74

    
75
  parser = optparse.OptionParser(usage="%prog [--dry-run]",
76
                                 prog=program)
77
  parser.add_option(cli.DEBUG_OPT)
78
  parser.add_option(cli.VERBOSE_OPT)
79
  parser.add_option(cli.DRY_RUN_OPT)
80

    
81
  (opts, args) = parser.parse_args()
82

    
83
  return VerifyOptions(parser, opts, args)
84

    
85

    
86
def VerifyOptions(parser, opts, args):
87
  """Verifies options and arguments for correctness.
88

89
  """
90
  if args:
91
    parser.error("No arguments are expected")
92

    
93
  return opts
94

    
95

    
96
def _VerifyCertificate(cert_pem, _check_fn=utils.CheckNodeCertificate):
97
  """Verifies a certificate against the local node daemon certificate.
98

99
  @type cert_pem: string
100
  @param cert_pem: Certificate in PEM format (no key)
101

102
  """
103
  try:
104
    OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
105
  except OpenSSL.crypto.Error, err:
106
    pass
107
  else:
108
    raise JoinError("No private key may be given")
109

    
110
  try:
111
    cert = \
112
      OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_pem)
113
  except Exception, err:
114
    raise errors.X509CertError("(stdin)",
115
                               "Unable to load certificate: %s" % err)
116

    
117
  _check_fn(cert)
118

    
119

    
120
def VerifyCertificate(data, _verify_fn=_VerifyCertificate):
121
  """Verifies cluster certificate.
122

123
  @type data: dict
124

125
  """
126
  cert = data.get(constants.SSHS_NODE_DAEMON_CERTIFICATE)
127
  if cert:
128
    _verify_fn(cert)
129

    
130

    
131
def VerifyClusterName(data, _verify_fn=ssconf.VerifyClusterName):
132
  """Verifies cluster name.
133

134
  @type data: dict
135

136
  """
137
  name = data.get(constants.SSHS_CLUSTER_NAME)
138
  if name:
139
    _verify_fn(name)
140
  else:
141
    raise JoinError("Cluster name must be specified")
142

    
143

    
144
def _UpdateKeyFiles(keys, dry_run, keyfiles):
145
  """Updates SSH key files.
146

147
  @type keys: sequence of tuple; (string, string, string)
148
  @param keys: Keys to write, tuples consist of key type
149
    (L{constants.SSHK_ALL}), public and private key
150
  @type dry_run: boolean
151
  @param dry_run: Whether to perform a dry run
152
  @type keyfiles: dict; (string as key, tuple with (string, string) as values)
153
  @param keyfiles: Mapping from key types (L{constants.SSHK_ALL}) to file
154
    names; value tuples consist of public key filename and private key filename
155

156
  """
157
  assert set(keyfiles) == constants.SSHK_ALL
158

    
159
  for (kind, private_key, public_key) in keys:
160
    (private_file, public_file) = keyfiles[kind]
161

    
162
    logging.debug("Writing %s ...", private_file)
163
    utils.WriteFile(private_file, data=private_key, mode=0600,
164
                    backup=True, dry_run=dry_run)
165

    
166
    logging.debug("Writing %s ...", public_file)
167
    utils.WriteFile(public_file, data=public_key, mode=0644,
168
                    backup=True, dry_run=dry_run)
169

    
170

    
171
def UpdateSshDaemon(data, dry_run, _runcmd_fn=utils.RunCmd,
172
                    _keyfiles=None):
173
  """Updates SSH daemon's keys.
174

175
  Unless C{dry_run} is set, the daemon is restarted at the end.
176

177
  @type data: dict
178
  @param data: Input data
179
  @type dry_run: boolean
180
  @param dry_run: Whether to perform a dry run
181

182
  """
183
  keys = data.get(constants.SSHS_SSH_HOST_KEY)
184
  if not keys:
185
    return
186

    
187
  if _keyfiles is None:
188
    _keyfiles = constants.SSH_DAEMON_KEYFILES
189

    
190
  logging.info("Updating SSH daemon key files")
191
  _UpdateKeyFiles(keys, dry_run, _keyfiles)
192

    
193
  if dry_run:
194
    logging.info("This is a dry run, not restarting SSH daemon")
195
  else:
196
    result = _runcmd_fn([pathutils.DAEMON_UTIL, "reload-ssh-keys"],
197
                        interactive=True)
198
    if result.failed:
199
      raise JoinError("Could not reload SSH keys, command '%s'"
200
                      " had exitcode %s and error %s" %
201
                       (result.cmd, result.exit_code, result.output))
202

    
203

    
204
def UpdateSshRoot(data, dry_run, _homedir_fn=None):
205
  """Updates root's SSH keys.
206

207
  Root's C{authorized_keys} file is also updated with new public keys.
208

209
  @type data: dict
210
  @param data: Input data
211
  @type dry_run: boolean
212
  @param dry_run: Whether to perform a dry run
213

214
  """
215
  keys = data.get(constants.SSHS_SSH_ROOT_KEY)
216
  if not keys:
217
    return
218

    
219
  (auth_keys_file, keyfiles) = \
220
    ssh.GetAllUserFiles(constants.SSH_LOGIN_USER, mkdir=True,
221
                        _homedir_fn=_homedir_fn)
222

    
223
  _UpdateKeyFiles(keys, dry_run, keyfiles)
224

    
225
  if dry_run:
226
    logging.info("This is a dry run, not modifying %s", auth_keys_file)
227
  else:
228
    for (_, _, public_key) in keys:
229
      utils.AddAuthorizedKey(auth_keys_file, public_key)
230

    
231

    
232
def LoadData(raw):
233
  """Parses and verifies input data.
234

235
  @rtype: dict
236

237
  """
238
  return serializer.LoadAndVerifyJson(raw, _DATA_CHECK)
239

    
240

    
241
def Main():
242
  """Main routine.
243

244
  """
245
  opts = ParseOptions()
246

    
247
  utils.SetupToolLogging(opts.debug, opts.verbose)
248

    
249
  try:
250
    data = LoadData(sys.stdin.read())
251

    
252
    # Check if input data is correct
253
    VerifyClusterName(data)
254
    VerifyCertificate(data)
255

    
256
    # Update SSH files
257
    UpdateSshDaemon(data, opts.dry_run)
258
    UpdateSshRoot(data, opts.dry_run)
259

    
260
    logging.info("Setup finished successfully")
261
  except Exception, err: # pylint: disable=W0703
262
    logging.debug("Caught unhandled exception", exc_info=True)
263

    
264
    (retcode, message) = cli.FormatError(err)
265
    logging.error(message)
266

    
267
    return retcode
268
  else:
269
    return constants.EXIT_SUCCESS