Statistics
| Branch: | Tag: | Revision:

root / test / ganeti.ssh_unittest.py @ 7bd70e6b

History | View | Annotate | Download (4.4 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2006, 2007, 2008 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the ssh module"""
23

    
24
import os
25
import tempfile
26
import unittest
27
import shutil
28

    
29
import testutils
30
import mocks
31

    
32
from ganeti import constants
33
from ganeti import utils
34
from ganeti import ssh
35
from ganeti import errors
36

    
37

    
38
class TestKnownHosts(testutils.GanetiTestCase):
39
  """Test case for function writing the known_hosts file"""
40

    
41
  def setUp(self):
42
    testutils.GanetiTestCase.setUp(self)
43
    self.tmpfile = self._CreateTempFile()
44

    
45
  def test(self):
46
    cfg = mocks.FakeConfig()
47
    ssh.WriteKnownHostsFile(cfg, self.tmpfile)
48
    self.assertFileContent(self.tmpfile,
49
        "%s ssh-rsa %s\n" % (cfg.GetClusterName(),
50
                             mocks.FAKE_CLUSTER_KEY))
51

    
52
  def testFormatParamikoFingerprintCorrect(self):
53
    self.assertEqual(ssh.FormatParamikoFingerprint("C0Ffee"), "c0:ff:ee")
54

    
55
  def testFormatParamikoFingerprintNotDividableByTwo(self):
56
    self.assertRaises(AssertionError, ssh.FormatParamikoFingerprint, "C0Ffe")
57

    
58

    
59
class TestGetUserFiles(unittest.TestCase):
60
  def setUp(self):
61
    self.tmpdir = tempfile.mkdtemp()
62

    
63
  def tearDown(self):
64
    shutil.rmtree(self.tmpdir)
65

    
66
  @staticmethod
67
  def _GetNoHomedir(_):
68
    return None
69

    
70
  def _GetTempHomedir(self, _):
71
    return self.tmpdir
72

    
73
  def testNonExistantUser(self):
74
    for kind in constants.SSHK_ALL:
75
      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example",
76
                        kind=kind, _homedir_fn=self._GetNoHomedir)
77

    
78
  def testUnknownKind(self):
79
    kind = "something-else"
80
    assert kind not in constants.SSHK_ALL
81
    self.assertRaises(errors.ProgrammerError, ssh.GetUserFiles, "example4645",
82
                      kind=kind, _homedir_fn=self._GetTempHomedir)
83

    
84
    self.assertEqual(os.listdir(self.tmpdir), [])
85

    
86
  def testNoSshDirectory(self):
87
    for kind in constants.SSHK_ALL:
88
      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example29694",
89
                        kind=kind, _homedir_fn=self._GetTempHomedir)
90
      self.assertEqual(os.listdir(self.tmpdir), [])
91

    
92
  def testSshIsFile(self):
93
    utils.WriteFile(os.path.join(self.tmpdir, ".ssh"), data="")
94
    for kind in constants.SSHK_ALL:
95
      self.assertRaises(errors.OpExecError, ssh.GetUserFiles, "example26237",
96
                        kind=kind, _homedir_fn=self._GetTempHomedir)
97
      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
98

    
99
  def testMakeSshDirectory(self):
100
    sshdir = os.path.join(self.tmpdir, ".ssh")
101

    
102
    self.assertEqual(os.listdir(self.tmpdir), [])
103

    
104
    for kind in constants.SSHK_ALL:
105
      ssh.GetUserFiles("example20745", mkdir=True, kind=kind,
106
                       _homedir_fn=self._GetTempHomedir)
107
      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
108
      self.assertEqual(os.stat(sshdir).st_mode & 0777, 0700)
109

    
110
  def testFilenames(self):
111
    sshdir = os.path.join(self.tmpdir, ".ssh")
112

    
113
    os.mkdir(sshdir)
114

    
115
    for kind in constants.SSHK_ALL:
116
      result = ssh.GetUserFiles("example15103", mkdir=False, kind=kind,
117
                                _homedir_fn=self._GetTempHomedir)
118
      self.assertEqual(result, [
119
        os.path.join(self.tmpdir, ".ssh", "id_%s" % kind),
120
        os.path.join(self.tmpdir, ".ssh", "id_%s.pub" % kind),
121
        os.path.join(self.tmpdir, ".ssh", "authorized_keys"),
122
        ])
123

    
124
      self.assertEqual(os.listdir(self.tmpdir), [".ssh"])
125
      self.assertEqual(os.listdir(sshdir), [])
126

    
127
  def testNoDirCheck(self):
128
    self.assertEqual(os.listdir(self.tmpdir), [])
129

    
130
    for kind in constants.SSHK_ALL:
131
      ssh.GetUserFiles("example14528", mkdir=False, dircheck=False, kind=kind,
132
                       _homedir_fn=self._GetTempHomedir)
133
      self.assertEqual(os.listdir(self.tmpdir), [])
134

    
135

    
136
if __name__ == "__main__":
137
  testutils.GanetiTestProgram()