Move SSH functions into a class
[ganeti-local] / lib / ssh.py
1 #
2 #
3
4 # Copyright (C) 2006, 2007 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Module encapsulating ssh functionality.
23
24 """
25
26
27 import os
28
29 from ganeti import logger
30 from ganeti import utils
31 from ganeti import errors
32 from ganeti import constants
33
34
35 KNOWN_HOSTS_OPTS = [
36   "-oGlobalKnownHostsFile=%s" % constants.SSH_KNOWN_HOSTS_FILE,
37   "-oUserKnownHostsFile=/dev/null",
38   ]
39
40 # Note: BATCH_MODE conflicts with ASK_KEY
41 BATCH_MODE_OPTS = [
42   "-oEscapeChar=none",
43   "-oBatchMode=yes",
44   "-oStrictHostKeyChecking=yes",
45   ]
46
47 ASK_KEY_OPTS = [
48   "-oStrictHostKeyChecking=ask",
49   "-oEscapeChar=none",
50   "-oHashKnownHosts=no",
51   ]
52
53
54 def GetUserFiles(user, mkdir=False):
55   """Return the paths of a user's ssh files.
56
57   The function will return a triplet (priv_key_path, pub_key_path,
58   auth_key_path) that are used for ssh authentication. Currently, the
59   keys used are DSA keys, so this function will return:
60   (~user/.ssh/id_dsa, ~user/.ssh/id_dsa.pub,
61   ~user/.ssh/authorized_keys).
62
63   If the optional parameter mkdir is True, the ssh directory will be
64   created if it doesn't exist.
65
66   Regardless of the mkdir parameters, the script will raise an error
67   if ~user/.ssh is not a directory.
68
69   """
70   user_dir = utils.GetHomeDir(user)
71   if not user_dir:
72     raise errors.OpExecError("Cannot resolve home of user %s" % user)
73
74   ssh_dir = os.path.join(user_dir, ".ssh")
75   if not os.path.lexists(ssh_dir):
76     if mkdir:
77       try:
78         os.mkdir(ssh_dir, 0700)
79       except EnvironmentError, err:
80         raise errors.OpExecError("Can't create .ssh dir for user %s: %s" %
81                                  (user, str(err)))
82   elif not os.path.isdir(ssh_dir):
83     raise errors.OpExecError("path ~%s/.ssh is not a directory" % user)
84
85   return [os.path.join(ssh_dir, base)
86           for base in ["id_dsa", "id_dsa.pub", "authorized_keys"]]
87
88
89 class SshRunner:
90   """Wrapper for SSH commands.
91
92   """
93   def BuildCmd(self, hostname, user, command, batch=True, ask_key=False):
94     """Build an ssh command to execute a command on a remote node.
95
96     Args:
97       hostname: the target host, string
98       user: user to auth as
99       command: the command
100       batch: if true, ssh will run in batch mode with no prompting
101       ask_key: if true, ssh will run with StrictHostKeyChecking=ask, so that
102                we can connect to an unknown host (not valid in batch mode)
103
104     Returns:
105       The ssh call to run 'command' on the remote host.
106
107     """
108     argv = ["ssh", "-q"]
109     argv.extend(KNOWN_HOSTS_OPTS)
110     if batch:
111       # if we are in batch mode, we can't ask the key
112       if ask_key:
113         raise errors.ProgrammerError("SSH call requested conflicting options")
114       argv.extend(BATCH_MODE_OPTS)
115     elif ask_key:
116       argv.extend(ASK_KEY_OPTS)
117     argv.extend(["%s@%s" % (user, hostname), command])
118     return argv
119
120   def Run(self, hostname, user, command, batch=True, ask_key=False):
121     """Runs a command on a remote node.
122
123     This method has the same return value as `utils.RunCmd()`, which it
124     uses to launch ssh.
125
126     Args:
127       hostname: the target host, string
128       user: user to auth as
129       command: the command
130       batch: if true, ssh will run in batch mode with no prompting
131       ask_key: if true, ssh will run with StrictHostKeyChecking=ask, so that
132                we can connect to an unknown host (not valid in batch mode)
133
134     Returns:
135       `utils.RunResult` like `utils.RunCmd()`
136
137     """
138     return utils.RunCmd(self.BuildCmd(hostname, user, command, batch=batch,
139                                       ask_key=ask_key))
140
141   def CopyFileToNode(self, node, filename):
142     """Copy a file to another node with scp.
143
144     Args:
145       node: node in the cluster
146       filename: absolute pathname of a local file
147
148     Returns:
149       success: True/False
150
151     """
152     if not os.path.isfile(filename):
153       logger.Error("file %s does not exist" % (filename))
154       return False
155
156     if not os.path.isabs(filename):
157       logger.Error("file %s must be an absolute path" % (filename))
158       return False
159
160     command = ["scp", "-q", "-p"]
161     command.extend(KNOWN_HOSTS_OPTS)
162     command.extend(BATCH_MODE_OPTS)
163     command.append(filename)
164     command.append("%s:%s" % (node, filename))
165
166     result = utils.RunCmd(command)
167
168     if result.failed:
169       logger.Error("copy to node %s failed (%s) error %s,"
170                    " command was %s" %
171                    (node, result.fail_reason, result.output, result.cmd))
172
173     return not result.failed
174
175   def VerifyNodeHostname(self, node):
176     """Verify hostname consistency via SSH.
177
178     This functions connects via ssh to a node and compares the hostname
179     reported by the node to the name with have (the one that we
180     connected to).
181
182     This is used to detect problems in ssh known_hosts files
183     (conflicting known hosts) and incosistencies between dns/hosts
184     entries and local machine names
185
186     Args:
187       node: nodename of a host to check. can be short or full qualified hostname
188
189     Returns:
190       (success, detail)
191       where
192         success: True/False
193         detail: String with details
194
195     """
196     retval = self.Run(node, 'root', 'hostname')
197
198     if retval.failed:
199       msg = "ssh problem"
200       output = retval.output
201       if output:
202         msg += ": %s" % output
203       return False, msg
204
205     remotehostname = retval.stdout.strip()
206
207     if not remotehostname or remotehostname != node:
208       return False, "hostname mismatch, got %s" % remotehostname
209
210     return True, "host matches"
211
212
213 def WriteKnownHostsFile(cfg, sstore, file_name):
214   """Writes the cluster-wide equally known_hosts file.
215
216   """
217   utils.WriteFile(file_name, mode=0700,
218                   data="%s ssh-rsa %s\n" % (sstore.GetClusterName(),
219                                             cfg.GetHostKey()))