Statistics
| Branch: | Tag: | Revision:

root / tools / setup-ssh @ d3b18b8e

History | View | Annotate | Download (13.8 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2010 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21
"""Tool to setup the SSH configuration on a remote node.
22

    
23
This is needed before we can join the node into the cluster.
24

    
25
"""
26

    
27
# pylint: disable-msg=C0103
28
# C0103: Invalid name setup-ssh
29

    
30
import getpass
31
import logging
32
import paramiko
33
import os.path
34
import optparse
35
import sys
36

    
37
from ganeti import cli
38
from ganeti import constants
39
from ganeti import errors
40
from ganeti import netutils
41
from ganeti import ssconf
42
from ganeti import ssh
43
from ganeti import utils
44

    
45

    
46
class RemoteCommandError(errors.GenericError):
47
  """Exception if remote command was not successful.
48

    
49
  """
50

    
51

    
52
class JoinCheckError(errors.GenericError):
53
  """Exception raised if join check fails.
54

    
55
  """
56

    
57

    
58
def _CheckJoin(transport):
59
  """Checks if a join is safe or dangerous.
60

    
61
  Note: This function relies on the fact, that all
62
  hosts have the same configuration at compile time of
63
  Ganeti. So that the constants do not mismatch.
64

    
65
  @param transport: The paramiko transport instance
66
  @return: True if the join is safe; False otherwise
67

    
68
  """
69
  sftp = transport.open_sftp_client()
70
  ss = ssconf.SimpleStore()
71
  ss_cluster_name_path = ss.KeyToFilename(constants.SS_CLUSTER_NAME)
72

    
73
  cluster_files = {
74
    ss_cluster_name_path: utils.ReadFile(ss_cluster_name_path),
75
    constants.NODED_CERT_FILE: utils.ReadFile(constants.NODED_CERT_FILE),
76
    }
77

    
78
  try:
79
    remote_noded_file = _ReadSftpFile(sftp, constants.NODED_CERT_FILE)
80
  except IOError:
81
    # We can just assume that the file doesn't exist as such error reporting
82
    # is lacking from paramiko
83
    #
84
    # We don't have the noded certificate. As without the cert, the
85
    # noded is not running, we are on the safe bet to say that this
86
    # node doesn't belong to a cluster
87
    return True
88

    
89
  try:
90
    remote_cluster_name = _ReadSftpFile(sftp, ss_cluster_name_path)
91
  except IOError:
92
    # This can indicate that a previous join was not successful
93
    # So if the noded cert was found and matches we are fine
94
    return cluster_files[constants.NODED_CERT_FILE] == remote_noded_file
95

    
96
  return (cluster_files[constants.NODED_CERT_FILE] == remote_noded_file and
97
          cluster_files[ss_cluster_name_path] == remote_cluster_name)
98

    
99

    
100
def _RunRemoteCommand(transport, command):
101
  """Invokes and wait for the command over SSH.
102

    
103
  @param transport: The paramiko transport instance
104
  @param command: The command to be executed
105

    
106
  """
107
  chan = transport.open_session()
108
  chan.set_combine_stderr(True)
109
  output_handler = chan.makefile("r")
110
  chan.exec_command(command)
111

    
112
  result = chan.recv_exit_status()
113
  msg = output_handler.read()
114

    
115
  out_msg = "'%s' exited with status code %s, output %r" % (command, result,
116
                                                            msg)
117

    
118
  # If result is -1 (no exit status provided) we assume it was not successful
119
  if result:
120
    raise RemoteCommandError(out_msg)
121

    
122
  if msg:
123
    logging.info(out_msg)
124

    
125

    
126
def _InvokeDaemonUtil(transport, command):
127
  """Invokes daemon-util on the remote side.
128

    
129
  @param transport: The paramiko transport instance
130
  @param command: The daemon-util command to be run
131

    
132
  """
133
  _RunRemoteCommand(transport, "%s %s" % (constants.DAEMON_UTIL, command))
134

    
135

    
136
def _ReadSftpFile(sftp, filename):
137
  """Reads a file over sftp.
138

    
139
  @param sftp: An open paramiko SFTP client
140
  @param filename: The filename of the file to read
141
  @return: The content of the file
142

    
143
  """
144
  remote_file = sftp.open(filename, "r")
145
  try:
146
    return remote_file.read()
147
  finally:
148
    remote_file.close()
149

    
150

    
151
def _WriteSftpFile(sftp, name, perm, data):
152
  """SFTPs data to a remote file.
153

    
154
  @param sftp: A open paramiko SFTP client
155
  @param name: The remote file name
156
  @param perm: The remote file permission
157
  @param data: The data to write
158

    
159
  """
160
  remote_file = sftp.open(name, "w")
161
  try:
162
    sftp.chmod(name, perm)
163
    remote_file.write(data)
164
  finally:
165
    remote_file.close()
166

    
167

    
168
def SetupSSH(transport):
169
  """Sets the SSH up on the other side.
170

    
171
  @param transport: The paramiko transport instance
172

    
173
  """
174
  priv_key, pub_key, auth_keys = ssh.GetUserFiles(constants.GANETI_RUNAS)
175
  keyfiles = [
176
    (constants.SSH_HOST_DSA_PRIV, 0600),
177
    (constants.SSH_HOST_DSA_PUB, 0644),
178
    (constants.SSH_HOST_RSA_PRIV, 0600),
179
    (constants.SSH_HOST_RSA_PUB, 0644),
180
    (priv_key, 0600),
181
    (pub_key, 0644),
182
    ]
183

    
184
  sftp = transport.open_sftp_client()
185

    
186
  filemap = dict((name, (utils.ReadFile(name), perm))
187
                 for (name, perm) in keyfiles)
188

    
189
  auth_path = os.path.dirname(auth_keys)
190

    
191
  try:
192
    sftp.mkdir(auth_path, 0700)
193
  except IOError:
194
    # Sadly paramiko doesn't provide errno or similiar
195
    # so we can just assume that the path already exists
196
    logging.info("Path %s seems already to exist on remote node. Ignoring.",
197
                 auth_path)
198

    
199
  for name, (data, perm) in filemap.iteritems():
200
    _WriteSftpFile(sftp, name, perm, data)
201

    
202
  authorized_keys = sftp.open(auth_keys, "a+")
203
  try:
204
    # Due to the way SFTPFile and BufferedFile are implemented,
205
    # opening in a+ mode and then issuing a read(), readline() or
206
    # iterating over the file (which uses read() internally) will see
207
    # an empty file, since the paramiko internal file position and the
208
    # OS-level file-position are desynchronized; therefore, we issue
209
    # an explicit seek to resynchronize these; writes should (note
210
    # should) still go to the right place
211
    authorized_keys.seek(0, 0)
212
    # We don't have to close, as the close happened already in AddAuthorizedKey
213
    utils.AddAuthorizedKey(authorized_keys, filemap[pub_key][0])
214
  finally:
215
    authorized_keys.close()
216

    
217
  _InvokeDaemonUtil(transport, "reload-ssh-keys")
218

    
219

    
220
def ParseOptions():
221
  """Parses options passed to program.
222

    
223
  """
224
  program = os.path.basename(sys.argv[0])
225

    
226
  parser = optparse.OptionParser(usage=("%prog [--debug|--verbose] [--force]"
227
                                        " <node> <node...>"), prog=program)
228
  parser.add_option(cli.DEBUG_OPT)
229
  parser.add_option(cli.VERBOSE_OPT)
230
  parser.add_option(cli.NOSSH_KEYCHECK_OPT)
231
  default_key = ssh.GetUserFiles(constants.GANETI_RUNAS)[0]
232
  parser.add_option(optparse.Option("-f", dest="private_key",
233
                                    default=default_key,
234
                                    help="The private key to (try to) use for"
235
                                    "authentication "))
236
  parser.add_option(optparse.Option("--key-type", dest="key_type",
237
                                    choices=("rsa", "dsa"), default="dsa",
238
                                    help="The private key type (rsa or dsa)"))
239
  parser.add_option(optparse.Option("-j", "--force-join", dest="force_join",
240
                                    action="store_true", default=False,
241
                                    help="Force the join of the host"))
242

    
243
  (options, args) = parser.parse_args()
244

    
245
  return (options, args)
246

    
247

    
248
def SetupLogging(options):
249
  """Sets up the logging.
250

    
251
  @param options: Parsed options
252

    
253
  """
254
  fmt = "%(asctime)s: %(threadName)s "
255
  if options.debug or options.verbose:
256
    fmt += "%(levelname)s "
257
  fmt += "%(message)s"
258

    
259
  formatter = logging.Formatter(fmt)
260

    
261
  file_handler = logging.FileHandler(constants.LOG_SETUP_SSH)
262
  stderr_handler = logging.StreamHandler()
263
  stderr_handler.setFormatter(formatter)
264
  file_handler.setFormatter(formatter)
265
  file_handler.setLevel(logging.INFO)
266

    
267
  if options.debug:
268
    stderr_handler.setLevel(logging.DEBUG)
269
  elif options.verbose:
270
    stderr_handler.setLevel(logging.INFO)
271
  else:
272
    stderr_handler.setLevel(logging.WARNING)
273

    
274
  root_logger = logging.getLogger("")
275
  root_logger.setLevel(logging.NOTSET)
276
  root_logger.addHandler(stderr_handler)
277
  root_logger.addHandler(file_handler)
278

    
279
  # This is the paramiko logger instance
280
  paramiko_logger = logging.getLogger("paramiko")
281
  paramiko_logger.addHandler(file_handler)
282
  # We don't want to debug Paramiko, so filter anything below warning
283
  paramiko_logger.setLevel(logging.WARNING)
284

    
285

    
286
def LoadPrivateKeys(options):
287
  """Load the list of available private keys.
288

    
289
  It loads the standard ssh key from disk and then tries to connect to
290
  the ssh agent too.
291

    
292
  @rtype: list
293
  @return: a list of C{paramiko.PKey}
294

    
295
  """
296
  if options.key_type == "rsa":
297
    pkclass = paramiko.RSAKey
298
  elif options.key_type == "dsa":
299
    pkclass = paramiko.DSSKey
300
  else:
301
    logging.critical("Unknown key type %s selected (choose either rsa or dsa)",
302
                     options.key_type)
303
    sys.exit(1)
304

    
305
  try:
306
    private_key = pkclass.from_private_key_file(options.private_key)
307
  except (paramiko.SSHException, EnvironmentError), err:
308
    logging.critical("Can't load private key %s: %s", options.private_key, err)
309
    sys.exit(1)
310

    
311
  try:
312
    agent = paramiko.Agent()
313
    agent_keys = agent.get_keys()
314
  except paramiko.SSHException, err:
315
    # this will only be seen when the agent is broken/uses invalid
316
    # protocol; for non-existing agent, get_keys() will just return an
317
    # empty tuple
318
    logging.warning("Can't connect to the ssh agent: %s; skipping its use",
319
                    err)
320
    agent_keys = []
321

    
322
  return [private_key] + list(agent_keys)
323

    
324

    
325
def LoginViaKeys(transport, username, keys):
326
  """Try to login on the given transport via a list of keys.
327

    
328
  @param transport: the transport to use
329
  @param username: the username to login as
330
  @type keys: list
331
  @param keys: list of C{paramiko.PKey} to use for authentication
332
  @rtype: boolean
333
  @return: True or False depending on whether the login was
334
      successfull or not
335

    
336
  """
337
  for private_key in keys:
338
    try:
339
      transport.auth_publickey(username, private_key)
340
      fpr = ":".join("%02x" % ord(i) for i in private_key.get_fingerprint())
341
      if isinstance(private_key, paramiko.AgentKey):
342
        logging.debug("Authentication via the ssh-agent key %s", fpr)
343
      else:
344
        logging.debug("Authenticated via public key %s", fpr)
345
      return True
346
    except paramiko.SSHException:
347
      continue
348
  else:
349
    # all keys exhausted
350
    return False
351

    
352

    
353
def LoadKnownHosts():
354
  """Load the known hosts.
355

    
356
  @return: paramiko.util.load_host_keys dict
357

    
358
  """
359
  homedir = utils.GetHomeDir(constants.GANETI_RUNAS)
360
  known_hosts = os.path.join(homedir, ".ssh", "known_hosts")
361

    
362
  try:
363
    return paramiko.util.load_host_keys(known_hosts)
364
  except EnvironmentError:
365
    # We didn't found the path, silently ignore and return an empty dict
366
    return {}
367

    
368

    
369
def main():
370
  """Main routine.
371

    
372
  """
373
  (options, args) = ParseOptions()
374

    
375
  SetupLogging(options)
376

    
377
  all_keys = LoadPrivateKeys(options)
378

    
379
  passwd = None
380
  username = constants.GANETI_RUNAS
381
  ssh_port = netutils.GetDaemonPort("ssh")
382
  host_keys = LoadKnownHosts()
383

    
384
  # Below, we need to join() the transport objects, as otherwise the
385
  # following happens:
386
  # - the main thread finishes
387
  # - the atexit functions run (in the main thread), and cause the
388
  #   logging file to be closed
389
  # - a tiny bit later, the transport thread is finally ending, and
390
  #   wants to log one more message, which fails as the file is closed
391
  #   now
392

    
393
  for host in args:
394
    transport = paramiko.Transport((host, ssh_port))
395
    transport.start_client()
396
    server_key = transport.get_remote_server_key()
397
    keytype = server_key.get_name()
398

    
399
    our_server_key = host_keys.get(host, {}).get(keytype, None)
400
    if options.ssh_key_check:
401
      if not our_server_key:
402
        hexified_key = ssh.FormatParamikoFingerprint(
403
            server_key.get_fingerprint())
404
        msg = ("Unable to verify hostkey of host %s: %s. Do you want to accept"
405
               " it?" % (host, hexified_key))
406

    
407
        if cli.AskUser(msg):
408
          our_server_key = server_key
409

    
410
      if our_server_key != server_key:
411
        logging.error("Unable to verify identity of host. Aborting")
412
        transport.close()
413
        transport.join()
414
        # TODO: Run over all hosts, fetch the keys and let them verify from the
415
        #       user beforehand then proceed with actual work later on
416
        raise paramiko.SSHException("Unable to verify identity of host")
417

    
418
    try:
419
      if LoginViaKeys(transport, username, all_keys):
420
        logging.info("Authenticated to %s via public key", host)
421
      else:
422
        logging.warning("Authentication to %s via public key failed, trying"
423
                        " password", host)
424
        if passwd is None:
425
          passwd = getpass.getpass(prompt="%s password:" % username)
426
        transport.auth_password(username=username, password=passwd)
427
        logging.info("Authenticated to %s via password", host)
428
    except paramiko.SSHException, err:
429
      logging.error("Connection or authentication failed to host %s: %s",
430
                    host, err)
431
      transport.close()
432
      # this is needed for compatibility with older Paramiko or Python
433
      # versions
434
      transport.join()
435
      continue
436
    try:
437
      try:
438
        if not _CheckJoin(transport):
439
          if options.force_join:
440
            logging.warning("Host %s failed join check, forced to continue",
441
                            host)
442
          else:
443
            raise JoinCheckError("Host %s failed join check" % host)
444
        SetupSSH(transport)
445
      except errors.GenericError, err:
446
        logging.error("While doing setup on host %s an error occured: %s",
447
                      host, err)
448
    finally:
449
      transport.close()
450
      # this is needed for compatibility with older Paramiko or Python
451
      # versions
452
      transport.join()
453

    
454

    
455
if __name__ == "__main__":
456
  main()