+class TestParseFields(unittest.TestCase):
+ def test(self):
+ self.assertEqual(cli.ParseFields(None, []), [])
+ self.assertEqual(cli.ParseFields("name,foo,hello", []),
+ ["name", "foo", "hello"])
+ self.assertEqual(cli.ParseFields(None, ["def", "ault", "fields", "here"]),
+ ["def", "ault", "fields", "here"])
+ self.assertEqual(cli.ParseFields("name,foo", ["def", "ault"]),
+ ["name", "foo"])
+ self.assertEqual(cli.ParseFields("+name,foo", ["def", "ault"]),
+ ["def", "ault", "name", "foo"])
+
+
+class TestConstants(unittest.TestCase):
+ def testPriority(self):
+ self.assertEqual(set(cli._PRIONAME_TO_VALUE.values()),
+ set(constants.OP_PRIO_SUBMIT_VALID))
+ self.assertEqual(list(value for _, value in cli._PRIORITY_NAMES),
+ sorted(constants.OP_PRIO_SUBMIT_VALID, reverse=True))
+
+
+class TestParseNicOption(unittest.TestCase):
+ def test(self):
+ self.assertEqual(cli.ParseNicOption([("0", { "link": "eth0", })]),
+ [{ "link": "eth0", }])
+ self.assertEqual(cli.ParseNicOption([("5", { "ip": "192.0.2.7", })]),
+ [{}, {}, {}, {}, {}, { "ip": "192.0.2.7", }])
+
+ def testErrors(self):
+ for i in [None, "", "abc", "zero", "Hello World", "\0", []]:
+ self.assertRaises(errors.OpPrereqError, cli.ParseNicOption,
+ [(i, { "link": "eth0", })])
+ self.assertRaises(errors.OpPrereqError, cli.ParseNicOption,
+ [("0", i)])
+
+ self.assertRaises(errors.TypeEnforcementError, cli.ParseNicOption,
+ [(0, { True: False, })])
+
+ self.assertRaises(errors.TypeEnforcementError, cli.ParseNicOption,
+ [(3, { "mode": [], })])
+
+
+class TestFormatResultError(unittest.TestCase):
+ def testNormal(self):
+ for verbose in [False, True]:
+ self.assertRaises(AssertionError, cli.FormatResultError,
+ constants.RS_NORMAL, verbose)
+
+ def testUnknown(self):
+ for verbose in [False, True]:
+ self.assertRaises(NotImplementedError, cli.FormatResultError,
+ "#some!other!status#", verbose)
+
+ def test(self):
+ for status in constants.RS_ALL:
+ if status == constants.RS_NORMAL:
+ continue
+
+ self.assertNotEqual(cli.FormatResultError(status, False),
+ cli.FormatResultError(status, True))
+
+ result = cli.FormatResultError(status, True)
+ self.assertTrue(result.startswith("("))
+ self.assertTrue(result.endswith(")"))
+
+
+class TestGetOnlineNodes(unittest.TestCase):
+ class _FakeClient:
+ def __init__(self):
+ self._query = []
+
+ def AddQueryResult(self, *args):
+ self._query.append(args)
+
+ def CountPending(self):
+ return len(self._query)
+
+ def Query(self, res, fields, qfilter):
+ if res != constants.QR_NODE:
+ raise Exception("Querying wrong resource")
+
+ (exp_fields, check_filter, result) = self._query.pop(0)
+
+ if exp_fields != fields:
+ raise Exception("Expected fields %s, got %s" % (exp_fields, fields))
+
+ if not (qfilter is None or check_filter(qfilter)):
+ raise Exception("Filter doesn't match expectations")
+
+ return objects.QueryResponse(fields=None, data=result)
+
+ def testEmpty(self):
+ cl = self._FakeClient()
+
+ cl.AddQueryResult(["name", "offline", "sip"], None, [])
+ self.assertEqual(cli.GetOnlineNodes(None, cl=cl), [])
+ self.assertEqual(cl.CountPending(), 0)
+
+ def testNoSpecialFilter(self):
+ cl = self._FakeClient()
+
+ cl.AddQueryResult(["name", "offline", "sip"], None, [
+ [(constants.RS_NORMAL, "master.example.com"),
+ (constants.RS_NORMAL, False),
+ (constants.RS_NORMAL, "192.0.2.1")],
+ [(constants.RS_NORMAL, "node2.example.com"),
+ (constants.RS_NORMAL, False),
+ (constants.RS_NORMAL, "192.0.2.2")],
+ ])
+ self.assertEqual(cli.GetOnlineNodes(None, cl=cl),
+ ["master.example.com", "node2.example.com"])
+ self.assertEqual(cl.CountPending(), 0)
+
+ def testNoMaster(self):
+ cl = self._FakeClient()
+
+ def _CheckFilter(qfilter):
+ self.assertEqual(qfilter, [qlang.OP_NOT, [qlang.OP_TRUE, "master"]])
+ return True
+
+ cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
+ [(constants.RS_NORMAL, "node2.example.com"),
+ (constants.RS_NORMAL, False),
+ (constants.RS_NORMAL, "192.0.2.2")],
+ ])
+ self.assertEqual(cli.GetOnlineNodes(None, cl=cl, filter_master=True),
+ ["node2.example.com"])
+ self.assertEqual(cl.CountPending(), 0)
+
+ def testSecondaryIpAddress(self):
+ cl = self._FakeClient()
+
+ cl.AddQueryResult(["name", "offline", "sip"], None, [
+ [(constants.RS_NORMAL, "master.example.com"),
+ (constants.RS_NORMAL, False),
+ (constants.RS_NORMAL, "192.0.2.1")],
+ [(constants.RS_NORMAL, "node2.example.com"),
+ (constants.RS_NORMAL, False),
+ (constants.RS_NORMAL, "192.0.2.2")],
+ ])
+ self.assertEqual(cli.GetOnlineNodes(None, cl=cl, secondary_ips=True),
+ ["192.0.2.1", "192.0.2.2"])
+ self.assertEqual(cl.CountPending(), 0)
+
+ def testNoMasterFilterNodeName(self):
+ cl = self._FakeClient()
+
+ def _CheckFilter(qfilter):
+ self.assertEqual(qfilter,
+ [qlang.OP_AND,
+ [qlang.OP_OR] + [[qlang.OP_EQUAL, "name", name]
+ for name in ["node2", "node3"]],
+ [qlang.OP_NOT, [qlang.OP_TRUE, "master"]]])
+ return True
+
+ cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
+ [(constants.RS_NORMAL, "node2.example.com"),
+ (constants.RS_NORMAL, False),
+ (constants.RS_NORMAL, "192.0.2.12")],
+ [(constants.RS_NORMAL, "node3.example.com"),
+ (constants.RS_NORMAL, False),
+ (constants.RS_NORMAL, "192.0.2.13")],
+ ])
+ self.assertEqual(cli.GetOnlineNodes(["node2", "node3"], cl=cl,
+ secondary_ips=True, filter_master=True),
+ ["192.0.2.12", "192.0.2.13"])
+ self.assertEqual(cl.CountPending(), 0)
+
+ def testOfflineNodes(self):
+ cl = self._FakeClient()
+
+ cl.AddQueryResult(["name", "offline", "sip"], None, [
+ [(constants.RS_NORMAL, "master.example.com"),
+ (constants.RS_NORMAL, False),
+ (constants.RS_NORMAL, "192.0.2.1")],
+ [(constants.RS_NORMAL, "node2.example.com"),
+ (constants.RS_NORMAL, True),
+ (constants.RS_NORMAL, "192.0.2.2")],
+ [(constants.RS_NORMAL, "node3.example.com"),
+ (constants.RS_NORMAL, True),
+ (constants.RS_NORMAL, "192.0.2.3")],
+ ])
+ self.assertEqual(cli.GetOnlineNodes(None, cl=cl, nowarn=True),
+ ["master.example.com"])
+ self.assertEqual(cl.CountPending(), 0)
+
+ def testNodeGroup(self):
+ cl = self._FakeClient()
+
+ def _CheckFilter(qfilter):
+ self.assertEqual(qfilter,
+ [qlang.OP_OR, [qlang.OP_EQUAL, "group", "foobar"],
+ [qlang.OP_EQUAL, "group.uuid", "foobar"]])
+ return True
+
+ cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
+ [(constants.RS_NORMAL, "master.example.com"),
+ (constants.RS_NORMAL, False),
+ (constants.RS_NORMAL, "192.0.2.1")],
+ [(constants.RS_NORMAL, "node3.example.com"),
+ (constants.RS_NORMAL, False),
+ (constants.RS_NORMAL, "192.0.2.3")],
+ ])
+ self.assertEqual(cli.GetOnlineNodes(None, cl=cl, nodegroup="foobar"),
+ ["master.example.com", "node3.example.com"])
+ self.assertEqual(cl.CountPending(), 0)
+
+
+class TestFormatTimestamp(unittest.TestCase):
+ def testGood(self):
+ self.assertEqual(cli.FormatTimestamp((0, 1)),
+ time.strftime("%F %T", time.localtime(0)) + ".000001")
+ self.assertEqual(cli.FormatTimestamp((1332944009, 17376)),
+ (time.strftime("%F %T", time.localtime(1332944009)) +
+ ".017376"))
+
+ def testWrong(self):
+ for i in [0, [], {}, "", [1]]:
+ self.assertEqual(cli.FormatTimestamp(i), "?")
+
+
+class TestFormatUsage(unittest.TestCase):
+ def test(self):
+ binary = "gnt-unittest"
+ commands = {
+ "cmdA":
+ (NotImplemented, NotImplemented, NotImplemented, NotImplemented,
+ "description of A"),
+ "bbb":
+ (NotImplemented, NotImplemented, NotImplemented, NotImplemented,
+ "Hello World," * 10),
+ "longname":
+ (NotImplemented, NotImplemented, NotImplemented, NotImplemented,
+ "Another description"),
+ }
+
+ self.assertEqual(list(cli._FormatUsage(binary, commands)), [
+ "Usage: gnt-unittest {command} [options...] [argument...]",
+ "gnt-unittest <command> --help to see details, or man gnt-unittest",
+ "",
+ "Commands:",
+ (" bbb - Hello World,Hello World,Hello World,Hello World,Hello"
+ " World,Hello"),
+ " World,Hello World,Hello World,Hello World,Hello World,",
+ " cmdA - description of A",
+ " longname - Another description",
+ "",
+ ])
+
+
+class TestParseArgs(unittest.TestCase):
+ def testNoArguments(self):
+ for argv in [[], ["gnt-unittest"]]:
+ try:
+ cli._ParseArgs("gnt-unittest", argv, {}, {}, set())
+ except cli._ShowUsage, err:
+ self.assertTrue(err.exit_error)
+ else:
+ self.fail("Did not raise exception")
+
+ def testVersion(self):
+ for argv in [["test", "--version"], ["test", "--version", "somethingelse"]]:
+ try:
+ cli._ParseArgs("test", argv, {}, {}, set())
+ except cli._ShowVersion:
+ pass
+ else:
+ self.fail("Did not raise exception")
+
+ def testHelp(self):
+ for argv in [["test", "--help"], ["test", "--help", "somethingelse"]]:
+ try:
+ cli._ParseArgs("test", argv, {}, {}, set())
+ except cli._ShowUsage, err:
+ self.assertFalse(err.exit_error)
+ else:
+ self.fail("Did not raise exception")
+
+ def testUnknownCommandOrAlias(self):
+ for argv in [["test", "list"], ["test", "somethingelse", "--help"]]:
+ try:
+ cli._ParseArgs("test", argv, {}, {}, set())
+ except cli._ShowUsage, err:
+ self.assertTrue(err.exit_error)
+ else:
+ self.fail("Did not raise exception")
+
+ def testInvalidAliasList(self):
+ cmd = {
+ "list": NotImplemented,
+ "foo": NotImplemented,
+ }
+ aliases = {
+ "list": NotImplemented,
+ "foo": NotImplemented,
+ }
+ assert sorted(cmd.keys()) == sorted(aliases.keys())
+ self.assertRaises(AssertionError, cli._ParseArgs, "test",
+ ["test", "list"], cmd, aliases, set())
+
+ def testAliasForNonExistantCommand(self):
+ cmd = {}
+ aliases = {
+ "list": NotImplemented,
+ }
+ self.assertRaises(errors.ProgrammerError, cli._ParseArgs, "test",
+ ["test", "list"], cmd, aliases, set())
+
+
+if __name__ == "__main__":