4 # Copyright (C) 2006, 2007, 2008 Google Inc.
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 # General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
22 """Utilities for unit testing"""
import logging
import os
import stat
import sys
import tempfile
import unittest

from ganeti import utils
def GetSourceDir():
  """Returns the source directory used to locate test data.

  The directory is taken from the C{TOP_SRCDIR} environment variable,
  falling back to the current directory (C{"."}) when it is not set.

  @rtype: str
  @return: the source directory path

  """
  return os.environ.get("TOP_SRCDIR", ".")
def TestDataFilename(name):
  """Returns the filename of a given test data file.

  @type name: str
  @param name: the 'base' of the file name, as present in
      the test/data directory
  @rtype: str
  @return: the full path to the filename, such that it can
      be used in 'make distcheck' rules

  """
  return "%s/test/data/%s" % (GetSourceDir(), name)
def ReadTestData(name):
  """Returns the content of a test data file.

  This is just a very simple wrapper over utils.ReadFile with the
  proper test file name.

  @type name: str
  @param name: the 'base' of the file name, as present in the
      test/data directory
  @return: whatever C{utils.ReadFile} returns for that file

  """
  return utils.ReadFile(TestDataFilename(name))
def _SetupLogging(verbose):
  """Sets up the logging infrastructure for the test run.

  A single handler is attached to the root logger, and both the handler
  and the root logger are set to C{logging.NOTSET}, so every record
  reaches the handler.

  @type verbose: bool
  @param verbose: if true, log messages are written to standard error;
      otherwise they are written to C{os.devnull}, i.e. discarded

  """
  fmt = logging.Formatter("%(asctime)s: %(threadName)s"
                          " %(levelname)s %(message)s")

  if verbose:
    # StreamHandler() writes to sys.stderr by default
    handler = logging.StreamHandler()
  else:
    handler = logging.FileHandler(os.devnull, "a")

  handler.setLevel(logging.NOTSET)
  handler.setFormatter(fmt)

  root_logger = logging.getLogger("")
  root_logger.setLevel(logging.NOTSET)
  root_logger.addHandler(handler)
class GanetiTestProgram(unittest.TestProgram):
  """Test program that sets up logging and requires working assertions.

  """
  def runTests(self):
    """Runs all tests.

    Log output goes to standard error if C{LOGTOSTDERR} is present in the
    environment and is discarded otherwise.

    @raise Exception: if the interpreter does not evaluate C{assert}
        statements (e.g. when Python runs with C{-O})

    """
    _SetupLogging("LOGTOSTDERR" in os.environ)

    sys.stderr.write("Running %s\n" % self.progName)

    # Ensure assertions will be evaluated
    if not __debug__:
      raise Exception("Not running in debug mode, assertions would not be"
                      " evaluated")

    # Check again, this time with a real assertion
    try:
      assert False
    except AssertionError:
      pass
    else:
      raise Exception("Assertion not evaluated")

    return unittest.TestProgram.runTests(self)
class GanetiTestCase(unittest.TestCase):
  """Helper class for unittesting.

  This class defines a few utility functions that help in building
  unittests. Child classes must call the parent setup and cleanup.

  """
  def setUp(self):
    """Initializes the list of temporary files to clean up.

    """
    # Files created by _CreateTempFile; removed again in tearDown
    self._temp_files = []

  def tearDown(self):
    """Removes all temporary files created via L{_CreateTempFile}.

    Removal errors are ignored so that cleanup never masks the actual
    test result.

    """
    while self._temp_files:
      try:
        utils.RemoveFile(self._temp_files.pop())
      except EnvironmentError:
        # Cleanup is best-effort; ignore files that cannot be removed
        pass

  def assertFileContent(self, file_name, expected_content):
    """Checks that the content of a file is what we expect.

    @type file_name: str
    @param file_name: the file whose contents we should check
    @type expected_content: str
    @param expected_content: the content we expect

    """
    actual_content = utils.ReadFile(file_name)
    self.assertEqual(actual_content, expected_content)

  def assertFileMode(self, file_name, expected_mode):
    """Checks that the mode of a file is what we expect.

    Only the bits returned by C{stat.S_IMODE} (permissions plus the
    setuid/setgid/sticky bits) are compared; the file type is ignored.

    @type file_name: str
    @param file_name: the file whose mode we should check
    @type expected_mode: int
    @param expected_mode: the mode we expect

    """
    st = os.stat(file_name)
    actual_mode = stat.S_IMODE(st.st_mode)
    self.assertEqual(actual_mode, expected_mode)

  def assertFileUid(self, file_name, expected_uid):
    """Checks that the user id of a file is what we expect.

    @type file_name: str
    @param file_name: the file whose owner we should check
    @type expected_uid: int
    @param expected_uid: the user id we expect

    """
    st = os.stat(file_name)
    actual_uid = st.st_uid
    self.assertEqual(actual_uid, expected_uid)

  def assertFileGid(self, file_name, expected_gid):
    """Checks that the group id of a file is what we expect.

    @type file_name: str
    @param file_name: the file whose group we should check
    @type expected_gid: int
    @param expected_gid: the group id we expect

    """
    st = os.stat(file_name)
    actual_gid = st.st_gid
    self.assertEqual(actual_gid, expected_gid)

  def assertEqualValues(self, first, second, msg=None):
    """Compares two values whether they're equal.

    Tuples are automatically converted to lists before comparing.

    @param msg: optional failure message passed on to C{assertEqual}

    """
    return self.assertEqual(UnifyValueType(first),
                            UnifyValueType(second),
                            msg=msg)

  def _CreateTempFile(self):
    """Creates a temporary file and adds it to the internal cleanup list.

    This method simplifies the creation and cleanup of temporary files
    during tests; the file is removed again by L{tearDown}.

    @rtype: str
    @return: the path of the newly created, empty file

    """
    (fd, fname) = tempfile.mkstemp(prefix="ganeti-test", suffix=".tmp")
    # Register for cleanup first so the file is removed even if close fails
    self._temp_files.append(fname)
    # Callers only need the name; close the descriptor so it does not leak
    os.close(fd)
    return fname
199 def patch_object(*args, **kwargs):
200 """Unified patch_object for various versions of Python Mock.
202 Different Python Mock versions provide incompatible versions of patching an
203 object. More recent versions use _patch_object, older ones used patch_object.
204 This function unifies the different variations.
209 # pylint: disable=W0212
210 return mock._patch_object(*args, **kwargs)
211 except AttributeError:
212 # pylint: disable=E1101
213 return mock.patch_object(*args, **kwargs)
def UnifyValueType(data):
  """Converts all tuples into lists.

  This is useful for unittests where an external library doesn't keep types.
  Lists and tuples are converted recursively, as are dictionary values.
  Dictionary keys are kept unchanged, because a tuple key converted to a
  list would be unhashable. Any other value is returned as-is.

  @param data: the value to convert
  @return: C{data} with every tuple (at any nesting depth, except inside
      dictionary keys) replaced by a list

  """
  if isinstance(data, (tuple, list)):
    return [UnifyValueType(i) for i in data]

  elif isinstance(data, dict):
    # items() (rather than the Python-2-only iteritems()) works everywhere
    return dict((key, UnifyValueType(value))
                for (key, value) in data.items())

  return data
232 class CallCounter(object):
233 """Utility class to count number of calls to a function/method.
236 def __init__(self, fn):
237 """Initializes this class.
245 def __call__(self, *args, **kwargs):
246 """Calls wrapped function with given parameters.
250 return self._fn(*args, **kwargs)
253 """Returns number of calls.