Revision 5484cda5

b/lib/ssh.py
34 34
from ganeti import netutils
35 35
from ganeti import pathutils
36 36
from ganeti import vcluster
37
from ganeti import compat
37 38

  
38 39

  
39 40
def FormatParamikoFingerprint(fingerprint):
......
95 96
                       "authorized_keys"]]
96 97

  
97 98

  
99
def GetAllUserFiles(user, mkdir=False, dircheck=True, _homedir_fn=None):
100
  """Wrapper over L{GetUserFiles} to retrieve files for all SSH key types.
101

  
102
  See L{GetUserFiles} for details.
103

  
104
  @rtype: tuple; (string, dict with string as key, tuple of (string, string) as
105
    value)
106

  
107
  """
108
  helper = compat.partial(GetUserFiles, user, mkdir=mkdir, dircheck=dircheck,
109
                          _homedir_fn=_homedir_fn)
110
  result = [(kind, helper(kind=kind)) for kind in constants.SSHK_ALL]
111

  
112
  authorized_keys = [i for (_, (_, _, i)) in result]
113

  
114
  assert len(frozenset(authorized_keys)) == 1, \
115
    "Different paths for authorized_keys were returned"
116

  
117
  return (authorized_keys[0],
118
          dict((kind, (privkey, pubkey))
119
               for (kind, (privkey, pubkey, _)) in result))
120

  
121

  
98 122
class SshRunner:
99 123
  """Wrapper for SSH commands.
100 124

  
b/test/ganeti.ssh_unittest.py
132 132
                       _homedir_fn=self._GetTempHomedir)
133 133
      self.assertEqual(os.listdir(self.tmpdir), [])
134 134

  
135
  def testGetAllUserFiles(self):
136
    result = ssh.GetAllUserFiles("example7475", mkdir=False, dircheck=False,
137
                                 _homedir_fn=self._GetTempHomedir)
138
    self.assertEqual(result,
139
      (os.path.join(self.tmpdir, ".ssh", "authorized_keys"), {
140
        constants.SSHK_RSA:
141
          (os.path.join(self.tmpdir, ".ssh", "id_rsa"),
142
           os.path.join(self.tmpdir, ".ssh", "id_rsa.pub")),
143
        constants.SSHK_DSA:
144
          (os.path.join(self.tmpdir, ".ssh", "id_dsa"),
145
           os.path.join(self.tmpdir, ".ssh", "id_dsa.pub")),
146
      }))
147
    self.assertEqual(os.listdir(self.tmpdir), [])
148

  
149
  def testGetAllUserFilesNoDirectoryNoMkdir(self):
150
    self.assertRaises(errors.OpExecError, ssh.GetAllUserFiles,
151
                      "example17270", mkdir=False, dircheck=True,
152
                      _homedir_fn=self._GetTempHomedir)
153
    self.assertEqual(os.listdir(self.tmpdir), [])
154

  
135 155

  
136 156
if __name__ == "__main__":
137 157
  testutils.GanetiTestProgram()

Also available in: Unified diff