Revision 8a3c9e8a test/ganeti.ssh_unittest.py

b/test/ganeti.ssh_unittest.py
24 24
import os
25 25
import tempfile
26 26
import unittest
27
import shutil
27 28

  
28 29
import testutils
29 30
import mocks
......
31 32
from ganeti import constants
32 33
from ganeti import utils
33 34
from ganeti import ssh
35
from ganeti import errors
34 36

  
35 37

  
36 38
class TestKnownHosts(testutils.GanetiTestCase):
......
54 56
    self.assertRaises(AssertionError, ssh.FormatParamikoFingerprint, "C0Ffe")
55 57

  
56 58

  
57
if __name__ == '__main__':
59
class TestGetUserFiles(unittest.TestCase):
60
  def setUp(self):
61
    self.tmpdir = tempfile.mkdtemp()
62

  
63
  def tearDown(self):
64
    shutil.rmtree(self.tmpdir)
65

  
66
  @staticmethod
67
  def _GetNoHomedir(_):
68
    return None
69

  
70
  def _GetTempHomedir(self, _):
71
    return self.tmpdir
72

  
73
  def testNonExistantUser(self):
74
    for kind in constants.SSHK_ALL:
75
      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example",
76
                        kind=kind, _homedir_fn=self._GetNoHomedir)
77

  
78
  def testUnknownKind(self):
79
    kind = "something-else"
80
    assert kind not in constants.SSHK_ALL
81
    self.assertRaises(errors.ProgrammerError, ssh.GetUserFiles, "example4645",
82
                      kind=kind, _homedir_fn=self._GetTempHomedir)
83

  
84
    self.assertEqual(os.listdir(self.tmpdir), [])
85

  
86
  def testNoSshDirectory(self):
87
    for kind in constants.SSHK_ALL:
88
      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example29694",
89
                        kind=kind, _homedir_fn=self._GetTempHomedir)
90
      self.assertEqual(os.listdir(self.tmpdir), [])
91

  
92
  def testSshIsFile(self):
93
    utils.WriteFile(os.path.join(self.tmpdir, ".ssh"), data="")
94
    for kind in constants.SSHK_ALL:
95
      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example26237",
96
                        kind=kind, _homedir_fn=self._GetTempHomedir)
97
      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
98

  
99
  def testMakeSshDirectory(self):
100
    sshdir = os.path.join(self.tmpdir, ".ssh")
101

  
102
    self.assertEqual(os.listdir(self.tmpdir), [])
103

  
104
    for kind in constants.SSHK_ALL:
105
      ssh.GetUserFiles("example20745", mkdir=True, kind=kind,
106
                       _homedir_fn=self._GetTempHomedir)
107
      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
108
      self.assertEqual(os.stat(sshdir).st_mode & 0777, 0700)
109

  
110
  def testFilenames(self):
111
    sshdir = os.path.join(self.tmpdir, ".ssh")
112

  
113
    os.mkdir(sshdir)
114

  
115
    for kind in constants.SSHK_ALL:
116
      result = ssh.GetUserFiles("example15103", mkdir=False, kind=kind,
117
                                _homedir_fn=self._GetTempHomedir)
118
      self.assertEqual(result, [
119
        os.path.join(self.tmpdir, ".ssh", "id_%s" % kind),
120
        os.path.join(self.tmpdir, ".ssh", "id_%s.pub" % kind),
121
        os.path.join(self.tmpdir, ".ssh", "authorized_keys"),
122
        ])
123

  
124
      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
125
      self.assertEqual(os.listdir(sshdir), [])
126

  
127

  
128
if __name__ == "__main__":
58 129
  testutils.GanetiTestProgram()

Also available in: Unified diff