Revision 8a3c9e8a

b/lib/constants.py
2044 2044
FAKE_OP_MASTER_TURNUP = "OP_CLUSTER_IP_TURNUP"
2045 2045
FAKE_OP_MASTER_TURNDOWN = "OP_CLUSTER_IP_TURNDOWN"
2046 2046

  
2047
# SSH key types
2048
SSHK_RSA = "rsa"
2049
SSHK_DSA = "dsa"
2050
SSHK_ALL = frozenset([SSHK_RSA, SSHK_DSA])
2051

  
2047 2052
# Do not re-export imported modules
2048 2053
del re, _vcsversion, _autoconf, socket, pathutils
b/lib/ssh.py
48 48
  return ":".join(re.findall(r"..", fingerprint.lower()))
49 49

  
50 50

  
51
def GetUserFiles(user, mkdir=False):
52
  """Return the paths of a user's ssh files.
53

  
54
  The function will return a triplet (priv_key_path, pub_key_path,
55
  auth_key_path) that are used for ssh authentication. Currently, the
56
  keys used are DSA keys, so this function will return:
57
  (~user/.ssh/id_dsa, ~user/.ssh/id_dsa.pub,
58
  ~user/.ssh/authorized_keys).
59

  
60
  If the optional parameter mkdir is True, the ssh directory will be
61
  created if it doesn't exist.
62

  
63
  Regardless of the mkdir parameters, the script will raise an error
64
  if ~user/.ssh is not a directory.
51
def GetUserFiles(user, mkdir=False, kind=constants.SSHK_DSA,
52
                 _homedir_fn=utils.GetHomeDir):
53
  """Return the paths of a user's SSH files.
54

  
55
  @type user: string
56
  @param user: Username
57
  @type mkdir: bool
58
  @param mkdir: Whether to create ".ssh" directory if it doesn't exist
59
  @type kind: string
60
  @param kind: One of L{constants.SSHK_ALL}
61
  @rtype: tuple; (string, string, string)
62
  @return: Tuple containing three file system paths; the private SSH key file,
63
    the public SSH key file and the user's C{authorized_keys} file
64
  @raise errors.OpExecError: When home directory of the user can not be
65
    determined
66
  @raise errors.OpExecError: Regardless of the C{mkdir} parameters, this
67
    exception is raised if C{~$user/.ssh} is not a directory
65 68

  
66 69
  """
67
  user_dir = utils.GetHomeDir(user)
70
  user_dir = _homedir_fn(user)
68 71
  if not user_dir:
69
    raise errors.OpExecError("Cannot resolve home of user %s" % user)
72
    raise errors.OpExecError("Cannot resolve home of user '%s'" % user)
73

  
74
  if kind == constants.SSHK_DSA:
75
    suffix = "dsa"
76
  elif kind == constants.SSHK_RSA:
77
    suffix = "rsa"
78
  else:
79
    raise errors.ProgrammerError("Unknown SSH key kind '%s'" % kind)
70 80

  
71 81
  ssh_dir = utils.PathJoin(user_dir, ".ssh")
72 82
  if mkdir:
......
75 85
    raise errors.OpExecError("Path %s is not a directory" % ssh_dir)
76 86

  
77 87
  return [utils.PathJoin(ssh_dir, base)
78
          for base in ["id_dsa", "id_dsa.pub", "authorized_keys"]]
88
          for base in ["id_%s" % suffix, "id_%s.pub" % suffix,
89
                       "authorized_keys"]]
79 90

  
80 91

  
81 92
class SshRunner:
b/test/ganeti.ssh_unittest.py
24 24
import os
25 25
import tempfile
26 26
import unittest
27
import shutil
27 28

  
28 29
import testutils
29 30
import mocks
......
31 32
from ganeti import constants
32 33
from ganeti import utils
33 34
from ganeti import ssh
35
from ganeti import errors
34 36

  
35 37

  
36 38
class TestKnownHosts(testutils.GanetiTestCase):
......
54 56
    self.assertRaises(AssertionError, ssh.FormatParamikoFingerprint, "C0Ffe")
55 57

  
56 58

  
57
if __name__ == '__main__':
59
class TestGetUserFiles(unittest.TestCase):
60
  def setUp(self):
61
    self.tmpdir = tempfile.mkdtemp()
62

  
63
  def tearDown(self):
64
    shutil.rmtree(self.tmpdir)
65

  
66
  @staticmethod
67
  def _GetNoHomedir(_):
68
    return None
69

  
70
  def _GetTempHomedir(self, _):
71
    return self.tmpdir
72

  
73
  def testNonExistantUser(self):
74
    for kind in constants.SSHK_ALL:
75
      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example",
76
                        kind=kind, _homedir_fn=self._GetNoHomedir)
77

  
78
  def testUnknownKind(self):
79
    kind = "something-else"
80
    assert kind not in constants.SSHK_ALL
81
    self.assertRaises(errors.ProgrammerError, ssh.GetUserFiles, "example4645",
82
                      kind=kind, _homedir_fn=self._GetTempHomedir)
83

  
84
    self.assertEqual(os.listdir(self.tmpdir), [])
85

  
86
  def testNoSshDirectory(self):
87
    for kind in constants.SSHK_ALL:
88
      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example29694",
89
                        kind=kind, _homedir_fn=self._GetTempHomedir)
90
      self.assertEqual(os.listdir(self.tmpdir), [])
91

  
92
  def testSshIsFile(self):
93
    utils.WriteFile(os.path.join(self.tmpdir, ".ssh"), data="")
94
    for kind in constants.SSHK_ALL:
95
      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example26237",
96
                        kind=kind, _homedir_fn=self._GetTempHomedir)
97
      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
98

  
99
  def testMakeSshDirectory(self):
100
    sshdir = os.path.join(self.tmpdir, ".ssh")
101

  
102
    self.assertEqual(os.listdir(self.tmpdir), [])
103

  
104
    for kind in constants.SSHK_ALL:
105
      ssh.GetUserFiles("example20745", mkdir=True, kind=kind,
106
                       _homedir_fn=self._GetTempHomedir)
107
      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
108
      self.assertEqual(os.stat(sshdir).st_mode & 0777, 0700)
109

  
110
  def testFilenames(self):
111
    sshdir = os.path.join(self.tmpdir, ".ssh")
112

  
113
    os.mkdir(sshdir)
114

  
115
    for kind in constants.SSHK_ALL:
116
      result = ssh.GetUserFiles("example15103", mkdir=False, kind=kind,
117
                                _homedir_fn=self._GetTempHomedir)
118
      self.assertEqual(result, [
119
        os.path.join(self.tmpdir, ".ssh", "id_%s" % kind),
120
        os.path.join(self.tmpdir, ".ssh", "id_%s.pub" % kind),
121
        os.path.join(self.tmpdir, ".ssh", "authorized_keys"),
122
        ])
123

  
124
      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
125
      self.assertEqual(os.listdir(sshdir), [])
126

  
127

  
128
if __name__ == "__main__":
58 129
  testutils.GanetiTestProgram()

Also available in: Unified diff