Reason trail implementation for "shutdown"
[ganeti-local] / test / py / ganeti.cli_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2008, 2011, 2012, 2013 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the cli module"""
23
24 import copy
25 import testutils
26 import time
27 import unittest
28 from cStringIO import StringIO
29
30 from ganeti import constants
31 from ganeti import cli
32 from ganeti import errors
33 from ganeti import utils
34 from ganeti import objects
35 from ganeti import qlang
36 from ganeti.errors import OpPrereqError, ParameterError
37
38
class TestParseTimespec(unittest.TestCase):
  """Testing case for ParseTimespec"""

  def testValidTimes(self):
    """Test valid timespecs"""
    test_data = [
      ("1s", 1),
      ("1", 1),
      ("1m", 60),
      ("1h", 60 * 60),
      ("1d", 60 * 60 * 24),
      ("1w", 60 * 60 * 24 * 7),
      ("4h", 4 * 60 * 60),
      ("61m", 61 * 60),
      ]
    for value, expected_result in test_data:
      # Name the failing input so a broken case is identifiable directly
      # from the test output
      self.assertEqual(cli.ParseTimespec(value), expected_result,
                       msg="ParseTimespec(%r)" % (value, ))

  def testInvalidTime(self):
    """Test invalid timespecs"""
    test_data = [
      "1y",
      "",
      "aaa",
      "s",
      ]
    for value in test_data:
      self.assertRaises(OpPrereqError, cli.ParseTimespec, value)
67
68
class TestSplitKeyVal(unittest.TestCase):
  """Testing case for cli._SplitKeyVal"""
  DATA = "a=b,c,no_d,-e"
  # Expected result with prefix parsing enabled: "key=val" keeps the value,
  # a bare key maps to True, "no_" maps to False and "-" maps to None
  RESULT = {"a": "b", "c": True, "d": False, "e": None}
  # NOTE(review): not referenced by any test in this class; kept in case it
  # is used elsewhere
  RESULT_NOPREFIX = {"a": "b", "c": {}, "no_d": {}, "-e": {}}

  def testSplitKeyVal(self):
    """Test splitting"""
    self.assertEqual(cli._SplitKeyVal("option", self.DATA, True),
                     self.RESULT)

  def testDuplicateParam(self):
    """Test duplicate parameters"""
    # "no_a" refers to the same key as "a", so it counts as a duplicate too
    for data in ("a=1,a=2", "a,no_a"):
      self.assertRaises(ParameterError, cli._SplitKeyVal,
                        "option", data, True)

  def testEmptyData(self):
    """Test how we handle splitting an empty string"""
    self.assertEqual(cli._SplitKeyVal("option", "", True), {})
89
90
class TestIdentKeyVal(unittest.TestCase):
  """Testing case for cli.check_ident_key_val"""

  def testIdentKeyVal(self):
    """Test identkeyval"""
    def cikv(value):
      return cli.check_ident_key_val("option", "opt", value)

    # Each entry is (input, expected result); None means ParameterError
    cases = [
      ("foo:bar", ("foo", {"bar": True})),
      ("foo:bar=baz", ("foo", {"bar": "baz"})),
      ("bar:b=c,c=a", ("bar", {"b": "c", "c": "a"})),
      ("no_bar", ("bar", False)),
      ("no_bar:foo", None),
      ("no_bar:foo=baz", None),
      ("bar:foo=baz,foo=baz", None),
      ("-foo", ("foo", None)),
      ("-foo:a=c", None),
      # Negative numbers are kept as identifiers instead of being read as a
      # "-" prefix
      ("-1:remove", ("-1", {"remove": True})),
      ("-29447:add,size=4G", ("-29447", {"add": True, "size": "4G"})),
      ("-:", ("", None)),
      ("-", ("", None)),
      ]
    for (value, expected) in cases:
      if expected is None:
        self.assertRaises(ParameterError, cikv, value)
      else:
        self.assertEqual(cikv(value), expected)

  @staticmethod
  def _csikv(value):
    return cli._SplitIdentKeyVal("opt", value, False)

  def testIdentKeyValNoPrefix(self):
    """Test identkeyval without prefixes"""
    # Without prefix parsing, bare keys and duplicate keys are rejected
    rejected = [
      "foo:bar",
      "foo:no_bar",
      "foo:bar=baz,bar=baz",
      ]
    for arg in rejected:
      self.assertRaises(ParameterError, self._csikv, arg)

    # ... while "no_" and "-" are taken literally
    accepted = [
      ("foo", ("foo", {})),
      ("foo:bar=baz", ("foo", {"bar": "baz"})),
      ("no_foo:-1=baz,no_op=3", ("no_foo", {"-1": "baz", "no_op": "3"})),
      ]
    for (arg, expected) in accepted:
      self.assertEqual(self._csikv(arg), expected)
142
143
class TestListIdentKeyVal(unittest.TestCase):
  """Test for cli.check_list_ident_key_val()"""

  @staticmethod
  def _clikv(value):
    return cli.check_list_ident_key_val("option", "opt", value)

  def testListIdentKeyVal(self):
    """Test parsing "/"-separated lists of ident:key=val items"""
    # Each entry is (input, expected dict); None means ParameterError
    cases = [
      # Empty input is rejected
      ("", None),
      ("foo", {"foo": {}}),
      ("foo:bar=baz", {"foo": {"bar": "baz"}}),
      # Repeating the identifier "foo" is rejected
      ("foo:bar=baz/foo:bat=bad", None),
      ("foo:abc=42/bar:def=11",
       {"foo": {"abc": "42"}, "bar": {"def": "11"}}),
      ("foo:abc=42/bar:def=11,ghi=07",
       {"foo": {"abc": "42"}, "bar": {"def": "11", "ghi": "07"}}),
      ]
    for (arg, expected) in cases:
      if expected is None:
        self.assertRaises(ParameterError, self._clikv, arg)
        continue
      self.assertEqual(expected, self._clikv(arg))
173
174
class TestToStream(unittest.TestCase):
  """Test the ToStream functions"""

  def testBasic(self):
    """Test that a message without arguments is written verbatim"""
    for data in ["foo",
                 "foo %s",
                 "foo %(test)s",
                 "foo %s %s",
                 "",
                 ]:
      buf = StringIO()
      cli._ToStream(buf, data)
      # Without arguments, "%" sequences must not be interpolated
      self.assertEqual(buf.getvalue(), data + "\n")

  def testParams(self):
    """Test that arguments are interpolated into the message"""
    buf = StringIO()
    cli._ToStream(buf, "foo %s", 1)
    self.assertEqual(buf.getvalue(), "foo 1\n")

    # A tuple passed as the only argument is formatted as a single value
    buf = StringIO()
    cli._ToStream(buf, "foo %s", (15, 16))
    self.assertEqual(buf.getvalue(), "foo (15, 16)\n")

    buf = StringIO()
    cli._ToStream(buf, "foo %s %s", "a", "b")
    self.assertEqual(buf.getvalue(), "foo a b\n")
199
200
class TestGenerateTable(unittest.TestCase):
  """Tests for cli.GenerateTable"""

  # Maps field names "f0".."f4" to titles "Field0".."Field4"
  HEADERS = dict([("f%s" % i, "Field%s" % i) for i in range(5)])

  FIELDS1 = ["f1", "f2"]
  DATA1 = [
    ["abc", 1234],
    ["foobar", 56],
    ["b", -14],
    ]

  def _test(self, headers, fields, separator, data,
            numfields, unitfields, units, expected):
    """Render C{data} with cli.GenerateTable and compare the lines."""
    self.assertEqual(cli.GenerateTable(headers, fields, separator, data,
                                       numfields=numfields,
                                       unitfields=unitfields,
                                       units=units),
                     expected)

  def testPlain(self):
    """Columns are space-padded and a header line is emitted"""
    self._test(self.HEADERS, self.FIELDS1, None, self.DATA1,
               None, None, "m", [
                 "Field1 Field2",
                 "abc    1234",
                 "foobar 56",
                 "b      -14",
                 ])

  def testNoFields(self):
    """Rows without fields render as empty lines"""
    # With headers there is one more (empty) line for the header itself
    self._test(self.HEADERS, [], None, [[], []],
               None, None, "m", [""] * 3)
    self._test(None, [], None, [[], []],
               None, None, "m", [""] * 2)

  def testSeparator(self):
    """A separator is inserted verbatim, without padding"""
    cells = [
      ("Field1", "Field2"),
      ("abc", "1234"),
      ("foobar", "56"),
      ("b", "-14"),
      ]
    for sep in ["#", ":", ",", "^", "!", "%", "|", "###", "%%", "!!!", "||"]:
      self._test(self.HEADERS, self.FIELDS1, sep, self.DATA1,
                 None, None, "m", [sep.join(row) for row in cells])

  def testNoHeader(self):
    """Without headers only the data lines are produced"""
    self._test(None, self.FIELDS1, None, self.DATA1,
               None, None, "m", [
                 "abc    1234",
                 "foobar 56",
                 "b      -14",
                 ])

  def testUnknownField(self):
    """A field missing from the headers is titled with its own name"""
    self._test({"f1": "Field1"}, ["f1", "UNKNOWN"], None, self.DATA1,
               None, None, "m", [
                 "Field1 UNKNOWN",
                 "abc    1234",
                 "foobar 56",
                 "b      -14",
                 ])

  def testNumfields(self):
    """Numeric fields are right-aligned"""
    rows = [
      ["abc", 1234, 0],
      ["foobar", 56, 3],
      ["b", -14, "-"],
      ]
    self._test(self.HEADERS, ["f1", "f2", "f3"], None, rows,
               ["f2", "f3"], None, "m", [
                 "Field1 Field2 Field3",
                 "abc      1234      0",
                 "foobar     56      3",
                 "b         -14      -",
                 ])

  def testUnitfields(self):
    """Unit fields are formatted with the requested unit"""
    padded = [
      "Field1 Field2 Field3",
      "abc      1234     0M",
      "foobar     56     3M",
      "b         -14      -",
      ]
    separated = [
      "Field1:Field2:Field3",
      "abc:1234:0M",
      "foobar:56:3M",
      "b:-14:-",
      ]
    for (sep, expected) in [(None, padded), (":", separated)]:
      # NOTE(review): rows are rebuilt for each run; GenerateTable may modify
      # them in place, so confirm before hoisting out of the loop
      rows = [
        ["abc", 1234, 0],
        ["foobar", 56, 3],
        ["b", -14, "-"],
        ]
      self._test(self.HEADERS, ["f1", "f2", "f3"], sep, rows,
                 ["f2", "f3"], ["f3"], "h", expected)

  def testUnusual(self):
    """Values containing "%" are not treated as format strings"""
    rows = [
      ["%", "xyz"],
      ["%%", "abc"],
      ]
    self._test(self.HEADERS, ["f1", "f2"], None, rows,
               None, None, "m", [
                 "Field1 Field2",
                 "%      xyz",
                 "%%     abc",
                 ])
320
321
class TestFormatQueryResult(unittest.TestCase):
  """Tests for cli.FormatQueryResult"""

  @staticmethod
  def _Fields(*specs):
    """Build query field definitions from C{(name, title, kind)} tuples."""
    return [objects.QueryFieldDefinition(name=name, title=title, kind=kind)
            for (name, title, kind) in specs]

  @staticmethod
  def _Normal(*values):
    """Build one result row in which every value has status RS_NORMAL."""
    return [(constants.RS_NORMAL, value) for value in values]

  def test(self):
    fields = self._Fields(
      ("name", "Name", constants.QFT_TEXT),
      ("size", "Size", constants.QFT_NUMBER),
      ("act", "Active", constants.QFT_BOOL),
      ("mem", "Memory", constants.QFT_UNIT),
      ("other", "SomeList", constants.QFT_OTHER),
      )
    response = objects.QueryResponse(fields=fields, data=[
      self._Normal("nodeA", 128, False, 1468006, []),
      self._Normal("other", 512, True, 16, [1, 2, 3]),
      self._Normal("xyz", 1024, True, 4096, [{}, {}]),
      ])

    # Booleans are shown as Y/N and unit values are scaled with unit="h"
    expected = [
      "Name  Size Active Memory SomeList",
      "nodeA  128 N        1.4T []",
      "other  512 Y         16M [1, 2, 3]",
      "xyz   1024 Y        4.0G [{}, {}]",
      ]
    self.assertEqual(cli.FormatQueryResult(response, unit="h", header=True),
                     (cli.QR_NORMAL, expected))

  def testTimestampAndUnit(self):
    fields = self._Fields(
      ("name", "Name", constants.QFT_TEXT),
      ("size", "Size", constants.QFT_UNIT),
      ("mtime", "ModTime", constants.QFT_TIMESTAMP),
      )
    response = objects.QueryResponse(fields=fields, data=[
      self._Normal("a", 1024, 0),
      self._Normal("b", 144996, 1291746295),
      ])

    # With unit="m" sizes are printed unscaled; timestamps are rendered by
    # utils.FormatTime
    expected = [
      "Name   Size ModTime",
      "a      1024 %s" % utils.FormatTime(0),
      "b    144996 %s" % utils.FormatTime(1291746295),
      ]
    self.assertEqual(cli.FormatQueryResult(response, unit="m", header=True),
                     (cli.QR_NORMAL, expected))

  def testOverride(self):
    fields = self._Fields(
      ("name", "Name", constants.QFT_TEXT),
      ("cust", "Custom", constants.QFT_OTHER),
      ("xt", "XTime", constants.QFT_TIMESTAMP),
      )
    response = objects.QueryResponse(fields=fields, data=[
      self._Normal("x", ["a", "b", "c"], 1234),
      self._Normal("y", range(10), 1291746295),
      ])

    # Per-field (formatter, flag) overrides; judging by the expected output
    # the flag presumably selects right alignment
    override = {
      "cust": (utils.CommaJoin, False),
      "xt": (hex, True),
      }

    expected = [
      "Name Custom                            XTime",
      "x    a, b, c                           0x4d2",
      "y    0, 1, 2, 3, 4, 5, 6, 7, 8, 9 0x4cfe7bf7",
      ]
    self.assertEqual(cli.FormatQueryResult(response, unit="h", header=True,
                                           format_override=override),
                     (cli.QR_NORMAL, expected))

  def testSeparator(self):
    fields = self._Fields(
      ("name", "Name", constants.QFT_TEXT),
      ("count", "Count", constants.QFT_NUMBER),
      ("desc", "Description", constants.QFT_TEXT),
      )
    response = objects.QueryResponse(fields=fields, data=[
      self._Normal("instance1.example.com", 21125, "Hello World!"),
      self._Normal("mail.other.net", -9000, "a,b,c"),
      ])

    for sep in [":", "|", "#", "|||", "###", "@@@", "@#@"]:
      rows = [
        sep.join(["instance1.example.com", "21125", "Hello World!"]),
        sep.join(["mail.other.net", "-9000", "a,b,c"]),
        ]
      for with_header in (False, True):
        expected = list(rows)
        if with_header:
          expected.insert(0, sep.join(["Name", "Count", "Description"]))

        self.assertEqual(cli.FormatQueryResult(response, separator=sep,
                                               header=with_header),
                         (cli.QR_NORMAL, expected))

  def testStatusWithUnknown(self):
    fields = self._Fields(
      ("id", "ID", constants.QFT_NUMBER),
      ("unk", "unk", constants.QFT_UNKNOWN),
      ("unavail", "Unavail", constants.QFT_BOOL),
      ("nodata", "NoData", constants.QFT_TEXT),
      ("offline", "OffLine", constants.QFT_TEXT),
      )
    unknown = (constants.RS_UNKNOWN, None)
    offline = (constants.RS_OFFLINE, None)
    response = objects.QueryResponse(fields=fields, data=[
      [(constants.RS_NORMAL, 1), unknown, (constants.RS_NORMAL, False),
       (constants.RS_NORMAL, ""), offline],
      [(constants.RS_NORMAL, 2), unknown, (constants.RS_NODATA, None),
       (constants.RS_NORMAL, "x"), offline],
      [(constants.RS_NORMAL, 3), unknown, (constants.RS_NORMAL, False),
       (constants.RS_UNAVAIL, None), offline],
      ])

    # With an unknown field the result status is QR_UNKNOWN; verbose mode
    # spells out each non-normal status, terse mode uses short markers
    for (verbose, lines) in [
      (True, ["1|(unknown)|N||(offline)",
              "2|(unknown)|(nodata)|x|(offline)",
              "3|(unknown)|N|(unavail)|(offline)"]),
      (False, ["1|??|N||*",
               "2|??|?|x|*",
               "3|??|N|-|*"]),
      ]:
      self.assertEqual(cli.FormatQueryResult(response, header=True,
                                             separator="|", verbose=verbose),
                       (cli.QR_UNKNOWN,
                        ["ID|unk|Unavail|NoData|OffLine"] + lines))

  def testNoData(self):
    fields = self._Fields(
      ("id", "ID", constants.QFT_NUMBER),
      ("name", "Name", constants.QFT_TEXT),
      )
    response = objects.QueryResponse(fields=fields, data=[])

    # Only the header line is produced
    self.assertEqual(cli.FormatQueryResult(response, header=True),
                     (cli.QR_NORMAL, ["ID Name"]))

  def testNoDataWithUnknown(self):
    fields = self._Fields(
      ("id", "ID", constants.QFT_NUMBER),
      ("unk", "unk", constants.QFT_UNKNOWN),
      )
    response = objects.QueryResponse(fields=fields, data=[])

    # An unknown field is reported even when there are no rows
    self.assertEqual(cli.FormatQueryResult(response, header=False),
                     (cli.QR_UNKNOWN, []))

  def testStatus(self):
    fields = self._Fields(
      ("id", "ID", constants.QFT_NUMBER),
      ("unavail", "Unavail", constants.QFT_BOOL),
      ("nodata", "NoData", constants.QFT_TEXT),
      ("offline", "OffLine", constants.QFT_TEXT),
      )
    response = objects.QueryResponse(fields=fields, data=[
      [(constants.RS_NORMAL, 1), (constants.RS_NORMAL, False),
       (constants.RS_NORMAL, ""), (constants.RS_OFFLINE, None)],
      [(constants.RS_NORMAL, 2), (constants.RS_NODATA, None),
       (constants.RS_NORMAL, "x"), (constants.RS_NORMAL, "abc")],
      [(constants.RS_NORMAL, 3), (constants.RS_NORMAL, False),
       (constants.RS_UNAVAIL, None), (constants.RS_OFFLINE, None)],
      ])

    # Missing values make the result QR_INCOMPLETE
    for (verbose, lines) in [
      (True, ["1|N||(offline)",
              "2|(nodata)|x|abc",
              "3|N|(unavail)|(offline)"]),
      (False, ["1|N||*",
               "2|?|x|abc",
               "3|N|-|*"]),
      ]:
      self.assertEqual(cli.FormatQueryResult(response, header=False,
                                             separator="|", verbose=verbose),
                       (cli.QR_INCOMPLETE, lines))

  def testInvalidFieldType(self):
    fields = self._Fields(("x", "x", "#some#other#type"))
    response = objects.QueryResponse(fields=fields, data=[])

    self.assertRaises(NotImplementedError, cli.FormatQueryResult, response)

  def testInvalidFieldStatus(self):
    fields = self._Fields(("x", "x", constants.QFT_TEXT))

    # An unknown status code without a value is rejected ...
    response = objects.QueryResponse(fields=fields, data=[[(-1, None)]])
    self.assertRaises(NotImplementedError, cli.FormatQueryResult, response)

    # ... and with a value present it trips an assertion instead
    response = objects.QueryResponse(fields=fields, data=[[(-1, "x")]])
    self.assertRaises(AssertionError, cli.FormatQueryResult, response)

  def testEmptyFieldTitle(self):
    fields = self._Fields(("x", "", constants.QFT_TEXT))
    response = objects.QueryResponse(fields=fields, data=[])

    self.assertRaises(AssertionError, cli.FormatQueryResult, response)
577
578
class _MockJobPollCb(cli.JobPollCbBase, cli.JobPollReportCbBase):
  """Scripted callbacks for cli.GenericPollJob.

  Expected calls and their canned results are queued with L{AddWfjcResult}
  and L{AddQueryJobsResult}. Each callback checks its arguments through the
  owning test case C{tc}, which must provide C{assertEqualValues}.

  """
  def __init__(self, tc, job_id):
    self.tc = tc
    self.job_id = job_id
    # Queued (prev_job_info, prev_log_serial, result) for WaitForJobChangeOnce
    self._wfjcr = []
    # Queued field values returned by QueryJobs
    self._jobstatus = []
    # Set after returning JOB_NOTCHANGED; ReportNotChanged must consume it
    # before the next poll
    self._expect_notchanged = False
    # Log entries returned by the last poll that ReportLogMessage must
    # consume, in order, before the next poll
    self._expect_log = []

  def CheckEmpty(self):
    """Assert that every queued expectation has been consumed."""
    self.tc.assertFalse(self._wfjcr)
    self.tc.assertFalse(self._jobstatus)
    self.tc.assertFalse(self._expect_notchanged)
    self.tc.assertFalse(self._expect_log)

  def AddWfjcResult(self, *args):
    """Queue C{(prev_job_info, prev_log_serial, result)} for one poll."""
    self._wfjcr.append(args)

  def AddQueryJobsResult(self, *args):
    """Queue the field values returned by one QueryJobs call."""
    self._jobstatus.append(args)

  def WaitForJobChangeOnce(self, job_id, fields,
                           prev_job_info, prev_log_serial):
    """Check the poll arguments and return the next queued result."""
    self.tc.assertEqual(job_id, self.job_id)
    self.tc.assertEqualValues(fields, ["status"])
    # Everything returned by the previous poll must have been reported
    self.tc.assertFalse(self._expect_notchanged)
    self.tc.assertFalse(self._expect_log)

    (exp_prev_job_info, exp_prev_log_serial, result) = self._wfjcr.pop(0)
    self.tc.assertEqualValues(prev_job_info, exp_prev_job_info)
    self.tc.assertEqual(prev_log_serial, exp_prev_log_serial)

    if result == constants.JOB_NOTCHANGED:
      self._expect_notchanged = True
    elif result:
      (_, logmsgs) = result
      if logmsgs:
        self._expect_log.extend(logmsgs)

    return result

  def QueryJobs(self, job_ids, fields):
    """Check the query arguments and return the next queued job status."""
    self.tc.assertEqual(job_ids, [self.job_id])
    self.tc.assertEqualValues(fields, ["status", "opstatus", "opresult"])
    self.tc.assertFalse(self._expect_notchanged)
    self.tc.assertFalse(self._expect_log)

    result = self._jobstatus.pop(0)
    self.tc.assertEqual(len(fields), len(result))
    return [result]

  def ReportLogMessage(self, job_id, serial, timestamp, log_type, log_msg):
    """Check that the reported log entry is the next expected one."""
    self.tc.assertEqual(job_id, self.job_id)
    self.tc.assertEqualValues((serial, timestamp, log_type, log_msg),
                              self._expect_log.pop(0))

  def ReportNotChanged(self, job_id, status):
    """Check that a "not changed" report was expected and consume it."""
    self.tc.assertEqual(job_id, self.job_id)
    self.tc.assertTrue(self._expect_notchanged)
    self._expect_notchanged = False
639
640
class TestGenericPollJob(testutils.GanetiTestCase):
  """Tests for cli.GenericPollJob driven by scripted _MockJobPollCb callbacks"""

  @staticmethod
  def _LogMsg(serial, timestamp, text):
    """Return an ELOG_MESSAGE log entry with a split timestamp."""
    return (serial, utils.SplitTime(timestamp), constants.ELOG_MESSAGE, text)

  def _CheckPollRaises(self, cbs, job_id, exc_class):
    """Run the poller expecting C{exc_class}, then check nothing is left."""
    self.assertRaises(exc_class, cli.GenericPollJob, job_id, cbs, cbs)
    cbs.CheckEmpty()

  def testSuccessWithLog(self):
    job_id = 29609
    cbs = _MockJobPollCb(self, job_id)
    queued = (constants.JOB_STATUS_QUEUED, )
    running = (constants.JOB_STATUS_RUNNING, )

    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
    cbs.AddWfjcResult(None, None, (queued, None))
    cbs.AddWfjcResult(queued, None, constants.JOB_NOTCHANGED)

    # Log messages arrive in two batches; the last serial of a batch is
    # passed back as prev_log_serial on the following poll
    cbs.AddWfjcResult(queued, None, (running, [
      self._LogMsg(1, 1273491611.0, "Step 1"),
      self._LogMsg(2, 1273491615.9, "Step 2"),
      self._LogMsg(3, 1273491625.02, "Step 3"),
      self._LogMsg(4, 1273491635.05, "Step 4"),
      self._LogMsg(37, 1273491645.0, "Step 5"),
      # NOTE(review): this timestamp has one digit fewer than its neighbours,
      # possibly a typo; kept unchanged
      self._LogMsg(203, 127349155.0, "Step 6"),
      ]))
    cbs.AddWfjcResult(running, 203, (running, [
      self._LogMsg(300, 1273491711.01, "Step X"),
      self._LogMsg(302, 1273491815.8, "Step Y"),
      self._LogMsg(303, 1273491925.32, "Step Z"),
      ]))
    cbs.AddWfjcResult(running, 303, ((constants.JOB_STATUS_SUCCESS, ), None))

    cbs.AddQueryJobsResult(constants.JOB_STATUS_SUCCESS,
                           [constants.OP_STATUS_SUCCESS,
                            constants.OP_STATUS_SUCCESS],
                           ["Hello World", "Foo man bar"])

    # On success the per-opcode results are returned
    self.assertEqual(["Hello World", "Foo man bar"],
                     cli.GenericPollJob(job_id, cbs, cbs))
    cbs.CheckEmpty()

  def testJobLost(self):
    job_id = 13746
    cbs = _MockJobPollCb(self, job_id)

    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
    # A None poll result is expected to surface as JobLost
    cbs.AddWfjcResult(None, None, None)
    self._CheckPollRaises(cbs, job_id, errors.JobLost)

  def testError(self):
    job_id = 31088
    cbs = _MockJobPollCb(self, job_id)

    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
                           [constants.OP_STATUS_SUCCESS,
                            constants.OP_STATUS_ERROR],
                           ["Hello World", "Error code 123"])
    self._CheckPollRaises(cbs, job_id, errors.OpExecError)

  def testError2(self):
    job_id = 22235
    cbs = _MockJobPollCb(self, job_id)

    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
    # An encoded exception in the opcode result is re-raised as itself
    encoded = errors.EncodeException(errors.LockError("problem"))
    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
                           [constants.OP_STATUS_ERROR], [encoded])
    self._CheckPollRaises(cbs, job_id, errors.LockError)

  def testWeirdError(self):
    job_id = 28847
    cbs = _MockJobPollCb(self, job_id)

    # Job reports an error although no opcode has failed
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
                           [constants.OP_STATUS_RUNNING,
                            constants.OP_STATUS_RUNNING],
                           [None, None])
    self._CheckPollRaises(cbs, job_id, errors.OpExecError)

  def testCancel(self):
    job_id = 4275
    cbs = _MockJobPollCb(self, job_id)

    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_CANCELING, ), None))
    cbs.AddQueryJobsResult(constants.JOB_STATUS_CANCELING,
                           [constants.OP_STATUS_CANCELING,
                            constants.OP_STATUS_CANCELING],
                           [None, None])
    self._CheckPollRaises(cbs, job_id, errors.OpExecError)
747
748
class TestFormatLogMessage(unittest.TestCase):
  def test(self):
    """Test formatting of job log messages"""
    # Plain messages are returned unchanged
    self.assertEqual(cli.FormatLogMessage(constants.ELOG_MESSAGE,
                                          "Hello World"),
                     "Hello World")
    # A non-string ELOG_MESSAGE payload raises TypeError
    self.assertRaises(TypeError, cli.FormatLogMessage,
                      constants.ELOG_MESSAGE, [1, 2, 3])

    # Other log types still produce a non-empty (truthy) result
    self.assertTrue(cli.FormatLogMessage("some other type", (1, 2, 3)))
758
759
class TestParseFields(unittest.TestCase):
  def test(self):
    """Test selecting and extending field lists"""
    # Each entry is (selection, default fields, expected result)
    cases = [
      (None, [], []),
      ("name,foo,hello", [], ["name", "foo", "hello"]),
      # Without a selection the defaults are used as-is
      (None, ["def", "ault", "fields", "here"],
       ["def", "ault", "fields", "here"]),
      # An explicit selection replaces the defaults ...
      ("name,foo", ["def", "ault"], ["name", "foo"]),
      # ... unless it starts with "+", which appends to them
      ("+name,foo", ["def", "ault"], ["def", "ault", "name", "foo"]),
      ]
    for (selected, default, expected) in cases:
      self.assertEqual(cli.ParseFields(selected, default), expected)
771
772
class TestConstants(unittest.TestCase):
  def testPriority(self):
    """Check that CLI priority names match the valid submit priorities"""
    # The named priorities cover exactly the valid submit priorities
    self.assertEqual(set(cli._PRIONAME_TO_VALUE.values()),
                     set(constants.OP_PRIO_SUBMIT_VALID))
    # _PRIORITY_NAMES holds pairs whose second element is the priority
    # value, ordered from highest to lowest value
    self.assertEqual([value for (_, value) in cli._PRIORITY_NAMES],
                     sorted(constants.OP_PRIO_SUBMIT_VALID, reverse=True))
779
780
class TestParseNicOption(unittest.TestCase):
  def test(self):
    """Test that NIC parameters are placed at their index"""
    self.assertEqual(cli.ParseNicOption([("0", {"link": "eth0"})]),
                     [{"link": "eth0"}])
    # NICs below the highest given index get empty parameter dicts
    self.assertEqual(cli.ParseNicOption([("5", {"ip": "192.0.2.7"})]),
                     [{}] * 5 + [{"ip": "192.0.2.7"}])

  def testErrors(self):
    for invalid in [None, "", "abc", "zero", "Hello World", "\0", []]:
      # Invalid NIC index
      self.assertRaises(errors.OpPrereqError, cli.ParseNicOption,
                        [(invalid, {"link": "eth0"})])
      # Invalid NIC parameters
      self.assertRaises(errors.OpPrereqError, cli.ParseNicOption,
                        [("0", invalid)])

    # Parameter dicts whose keys or values fail type enforcement
    self.assertRaises(errors.TypeEnforcementError, cli.ParseNicOption,
                      [(0, {True: False})])

    self.assertRaises(errors.TypeEnforcementError, cli.ParseNicOption,
                      [(3, {"mode": []})])
800
801
class TestFormatResultError(unittest.TestCase):
  def testNormal(self):
    """RS_NORMAL is not an error status and is rejected"""
    for verbose in (False, True):
      self.assertRaises(AssertionError, cli.FormatResultError,
                        constants.RS_NORMAL, verbose)

  def testUnknown(self):
    """Unknown statuses are not implemented"""
    for verbose in (False, True):
      self.assertRaises(NotImplementedError, cli.FormatResultError,
                        "#some!other!status#", verbose)

  def test(self):
    """Terse and verbose texts differ; verbose text is parenthesized"""
    error_statuses = [status for status in constants.RS_ALL
                      if status != constants.RS_NORMAL]
    for status in error_statuses:
      terse = cli.FormatResultError(status, False)
      verbose = cli.FormatResultError(status, True)
      self.assertNotEqual(terse, verbose)
      self.assertTrue(verbose.startswith("("))
      self.assertTrue(verbose.endswith(")"))
824
825
class TestGetOnlineNodes(unittest.TestCase):
  """Tests for cli.GetOnlineNodes, driven by a fake query client."""

  class _FakeClient:
    """Fake client answering node queries from a FIFO of expectations.

    Each expectation is a tuple of (expected field list, filter check,
    result rows). A filter check of C{None} means the query must be sent
    without a filter. Otherwise the check is called with the received
    filter and must return C{True}.

    """
    def __init__(self):
      self._query = []

    def AddQueryResult(self, *args):
      """Queue an expectation; see the class docstring for its layout."""
      self._query.append(args)

    def CountPending(self):
      """Return the number of queued expectations not yet consumed."""
      return len(self._query)

    def Query(self, res, fields, qfilter):
      """Check a node query against the next expectation and answer it.

      @raise Exception: if the resource, fields or filter do not match the
        next queued expectation, or if no expectation is queued

      """
      if res != constants.QR_NODE:
        raise Exception("Querying wrong resource")

      if not self._query:
        raise Exception("Unexpected query, no results queued")

      (exp_fields, check_filter, result) = self._query.pop(0)

      if exp_fields != fields:
        raise Exception("Expected fields %s, got %s" % (exp_fields, fields))

      if check_filter is None:
        # No filter expected: report a stray filter explicitly instead of
        # trying to call the missing check function
        if qfilter is not None:
          raise Exception("Expected no filter, got %r" % (qfilter, ))
      elif qfilter is None or not check_filter(qfilter):
        # A filter is expected, so a missing one must not pass silently
        raise Exception("Filter doesn't match expectations")

      return objects.QueryResponse(fields=None, data=result)

  @staticmethod
  def _Row(name, offline, sip):
    """Build one result row (name, offline, sip), all in RS_NORMAL status."""
    return [(constants.RS_NORMAL, name),
            (constants.RS_NORMAL, offline),
            (constants.RS_NORMAL, sip)]

  def testEmpty(self):
    cl = self._FakeClient()

    cl.AddQueryResult(["name", "offline", "sip"], None, [])
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl), [])
    self.assertEqual(cl.CountPending(), 0)

  def testNoSpecialFilter(self):
    cl = self._FakeClient()

    cl.AddQueryResult(["name", "offline", "sip"], None, [
      self._Row("master.example.com", False, "192.0.2.1"),
      self._Row("node2.example.com", False, "192.0.2.2"),
      ])
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl),
                     ["master.example.com", "node2.example.com"])
    self.assertEqual(cl.CountPending(), 0)

  def testNoMaster(self):
    cl = self._FakeClient()

    def _CheckFilter(qfilter):
      self.assertEqual(qfilter, [qlang.OP_NOT, [qlang.OP_TRUE, "master"]])
      return True

    cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
      self._Row("node2.example.com", False, "192.0.2.2"),
      ])
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl, filter_master=True),
                     ["node2.example.com"])
    self.assertEqual(cl.CountPending(), 0)

  def testSecondaryIpAddress(self):
    cl = self._FakeClient()

    cl.AddQueryResult(["name", "offline", "sip"], None, [
      self._Row("master.example.com", False, "192.0.2.1"),
      self._Row("node2.example.com", False, "192.0.2.2"),
      ])
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl, secondary_ips=True),
                     ["192.0.2.1", "192.0.2.2"])
    self.assertEqual(cl.CountPending(), 0)

  def testNoMasterFilterNodeName(self):
    cl = self._FakeClient()

    def _CheckFilter(qfilter):
      self.assertEqual(qfilter,
        [qlang.OP_AND,
         [qlang.OP_OR] + [[qlang.OP_EQUAL, "name", name]
                          for name in ["node2", "node3"]],
         [qlang.OP_NOT, [qlang.OP_TRUE, "master"]]])
      return True

    cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
      self._Row("node2.example.com", False, "192.0.2.12"),
      self._Row("node3.example.com", False, "192.0.2.13"),
      ])
    self.assertEqual(cli.GetOnlineNodes(["node2", "node3"], cl=cl,
                                        secondary_ips=True, filter_master=True),
                     ["192.0.2.12", "192.0.2.13"])
    self.assertEqual(cl.CountPending(), 0)

  def testOfflineNodes(self):
    cl = self._FakeClient()

    cl.AddQueryResult(["name", "offline", "sip"], None, [
      self._Row("master.example.com", False, "192.0.2.1"),
      self._Row("node2.example.com", True, "192.0.2.2"),
      self._Row("node3.example.com", True, "192.0.2.3"),
      ])
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl, nowarn=True),
                     ["master.example.com"])
    self.assertEqual(cl.CountPending(), 0)

  def testNodeGroup(self):
    cl = self._FakeClient()

    def _CheckFilter(qfilter):
      self.assertEqual(qfilter,
        [qlang.OP_OR, [qlang.OP_EQUAL, "group", "foobar"],
                      [qlang.OP_EQUAL, "group.uuid", "foobar"]])
      return True

    cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
      self._Row("master.example.com", False, "192.0.2.1"),
      self._Row("node3.example.com", False, "192.0.2.3"),
      ])
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl, nodegroup="foobar"),
                     ["master.example.com", "node3.example.com"])
    self.assertEqual(cl.CountPending(), 0)
966
967
class TestFormatTimestamp(unittest.TestCase):
  """Tests for cli.FormatTimestamp."""

  @staticmethod
  def _LocalPrefix(seconds):
    """Return the expected local date/time part for C{seconds}."""
    return time.strftime("%F %T", time.localtime(seconds))

  def testGood(self):
    # The second tuple element is rendered as a six-digit fraction
    cases = [
      ((0, 1), self._LocalPrefix(0) + ".000001"),
      ((1332944009, 17376), self._LocalPrefix(1332944009) + ".017376"),
      ]
    for (timestamp, expected) in cases:
      self.assertEqual(cli.FormatTimestamp(timestamp), expected)

  def testWrong(self):
    # Malformed timestamps are rendered as a question mark
    for malformed in [0, [], {}, "", [1]]:
      self.assertEqual(cli.FormatTimestamp(malformed), "?")
979
980
class TestFormatUsage(unittest.TestCase):
  """Tests for cli._FormatUsage."""

  def test(self):
    def _Command(description):
      # The first four fields are unused placeholders in this test
      return (NotImplemented, ) * 4 + (description, )

    binary = "gnt-unittest"
    commands = {
      "cmdA": _Command("description of A"),
      "bbb": _Command("Hello World," * 10),
      "longname": _Command("Another description"),
      }

    # Commands are listed alphabetically; long descriptions are wrapped
    expected = [
      "Usage: gnt-unittest {command} [options...] [argument...]",
      "gnt-unittest <command> --help to see details, or man gnt-unittest",
      "",
      "Commands:",
      (" bbb      - Hello World,Hello World,Hello World,Hello World,Hello"
       " World,Hello"),
      "            World,Hello World,Hello World,Hello World,Hello World,",
      " cmdA     - description of A",
      " longname - Another description",
      "",
      ]
    self.assertEqual(list(cli._FormatUsage(binary, commands)), expected)
1008
1009
class TestParseArgs(unittest.TestCase):
  """Tests for cli._ParseArgs."""

  def _GetParseException(self, binary, argv, exc_class):
    """Parse C{argv} without commands or aliases and return the exception.

    The test fails if cli._ParseArgs does not raise C{exc_class}.

    """
    try:
      cli._ParseArgs(binary, argv, {}, {}, set())
    except exc_class as err:
      return err
    self.fail("Did not raise exception")

  def testNoArguments(self):
    for argv in [[], ["gnt-unittest"]]:
      err = self._GetParseException("gnt-unittest", argv, cli._ShowUsage)
      # Usage caused by a missing command is flagged as an error
      self.assertTrue(err.exit_error)

  def testVersion(self):
    for argv in [["test", "--version"], ["test", "--version", "somethingelse"]]:
      self._GetParseException("test", argv, cli._ShowVersion)

  def testHelp(self):
    for argv in [["test", "--help"], ["test", "--help", "somethingelse"]]:
      err = self._GetParseException("test", argv, cli._ShowUsage)
      # Explicitly requested help is not flagged as an error
      self.assertFalse(err.exit_error)

  def testUnknownCommandOrAlias(self):
    for argv in [["test", "list"], ["test", "somethingelse", "--help"]]:
      err = self._GetParseException("test", argv, cli._ShowUsage)
      self.assertTrue(err.exit_error)

  def testInvalidAliasList(self):
    # The same names are registered both as commands and as aliases
    names = ["list", "foo"]
    cmd = dict((name, NotImplemented) for name in names)
    aliases = dict((name, NotImplemented) for name in names)
    assert sorted(cmd.keys()) == sorted(aliases.keys())
    self.assertRaises(AssertionError, cli._ParseArgs, "test",
                      ["test", "list"], cmd, aliases, set())

  def testAliasForNonExistantCommand(self):
    aliases = {"list": NotImplemented}
    self.assertRaises(errors.ProgrammerError, cli._ParseArgs, "test",
                      ["test", "list"], {}, aliases, set())
1067
1068
class TestQftNames(unittest.TestCase):
  """Tests for the cli._QFT_NAMES field type name table."""

  def testComplete(self):
    # The table's keys must match the set of query field types exactly
    self.assertEqual(frozenset(cli._QFT_NAMES), constants.QFT_ALL)

  def testUnique(self):
    # Names must be unique even when compared case-insensitively
    lowered = [name.lower() for name in cli._QFT_NAMES.values()]
    self.assertFalse(utils.FindDuplicates(lowered))

  def testUppercase(self):
    for name in cli._QFT_NAMES.values():
      initial = name[0]
      self.assertEqual(initial, initial.upper())
1080
1081
class TestFieldDescValues(unittest.TestCase):
  """Tests for cli._FieldDescValues."""

  def testKnownKind(self):
    fdef = objects.QueryFieldDefinition(name="aname", title="Atitle",
                                        kind=constants.QFT_TEXT,
                                        doc="aaa doc aaa")
    # A known kind is shown by its descriptive name
    expected = ["aname", "Text", "Atitle", "aaa doc aaa"]
    self.assertEqual(cli._FieldDescValues(fdef), expected)

  def testUnknownKind(self):
    kind = "#foo#"

    # Make sure the kind really is unknown
    self.assertTrue(kind not in constants.QFT_ALL)
    self.assertTrue(kind not in cli._QFT_NAMES)

    fdef = objects.QueryFieldDefinition(name="zname", title="Ztitle",
                                        kind=kind, doc="zzz doc zzz")
    # An unknown kind is shown verbatim
    expected = ["zname", kind, "Ztitle", "zzz doc zzz"]
    self.assertEqual(cli._FieldDescValues(fdef), expected)
1101
1102
class TestSerializeGenericInfo(unittest.TestCase):
  """Tests for cli._SerializeGenericInfo."""

  def _RunTest(self, data, expected):
    """Serialize C{data} at indentation level 0 and compare the output."""
    out = StringIO()
    cli._SerializeGenericInfo(out, data, 0)
    self.assertEqual(out.getvalue(), expected)

  def testSimple(self):
    # Dictionaries come out with sorted keys; lists of pairs keep their order
    samples = [
      ("abc", "abc\n"),
      ([], "\n"),
      ({}, "\n"),
      (["1", "2", "3"], "- 1\n- 2\n- 3\n"),
      ([("z", "26")], "z: 26\n"),
      ({"z": "26"}, "z: 26\n"),
      ([("z", "26"), ("a", "1")], "z: 26\na: 1\n"),
      ({"z": "26", "a": "1"}, "a: 1\nz: 26\n"),
      ]
    for sample in samples:
      self._RunTest(*sample)

  def testLists(self):
    # Each item is serialized on its own; repeating an item in the list
    # repeats its text
    cases = [
      ({"aa": "11", "bb": "22", "cc": "33"},
       "- aa: 11\n  bb: 22\n  cc: 33\n"),
      ([("zz", "11"), ("ww", "33"), ("xx", "22")],
       "- zz: 11\n  ww: 33\n  xx: 22\n"),
      (["aa", "cc", "bb"],
       "- - aa\n  - cc\n  - bb\n"),
      ]
    for (item, item_text) in cases:
      for count in (1, 2, 3):
        self._RunTest([item] * count, item_text * count)

  def testDictionaries(self):
    pairs = [
      ("aaa", ["x", "y"]),
      ("bbb", {"w": "1", "z": "2"}),
      ("ccc", [("xyz", "123"), ("efg", "456")]),
      ]
    expected = "".join([
      "aaa: \n",
      "  - x\n",
      "  - y\n",
      "bbb: \n",
      "  w: 1\n",
      "  z: 2\n",
      "ccc: \n",
      "  xyz: 123\n",
      "  efg: 456\n",
      ])
    # Same output whether the top level is a list of pairs or a dictionary
    self._RunTest(pairs, expected)
    self._RunTest(dict(pairs), expected)
1179
1180
class TestCreateIPolicyFromOpts(unittest.TestCase):
  """Test case for cli.CreateIPolicyFromOpts."""
  def setUp(self):
    # Policies are big, and we want to see the difference in case of an error
    self.maxDiff = None

  def _RecursiveCheckMergedDicts(self, default_pol, diff_pol, merged_pol):
    """Check that C{merged_pol} is C{default_pol} updated with C{diff_pol}.

    Nested dictionaries are compared recursively. Keys missing from
    C{diff_pol} must carry the default value. Every key in C{diff_pol} must
    exist in the merged policy, so that misspelled expectations are not
    silently skipped.

    """
    self.assertTrue(type(default_pol) is dict)
    self.assertTrue(type(diff_pol) is dict)
    self.assertTrue(type(merged_pol) is dict)
    self.assertEqual(frozenset(default_pol.keys()),
                     frozenset(merged_pol.keys()))
    # Expected keys that are absent from the merged policy would otherwise
    # never be looked at by the loop below
    self.assertEqual(frozenset(diff_pol.keys()) - frozenset(merged_pol.keys()),
                     frozenset())
    for (key, val) in merged_pol.items():
      if key in diff_pol:
        if type(val) is dict:
          self._RecursiveCheckMergedDicts(default_pol[key], diff_pol[key], val)
        else:
          self.assertEqual(val, diff_pol[key])
      else:
        self.assertEqual(val, default_pol[key])

  def testClusterPolicy(self):
    pol0 = cli.CreateIPolicyFromOpts(
      ispecs_mem_size={},
      ispecs_cpu_count={},
      ispecs_disk_count={},
      ispecs_disk_size={},
      ispecs_nic_count={},
      ipolicy_disk_templates=None,
      ipolicy_vcpu_ratio=None,
      ipolicy_spindle_ratio=None,
      fill_all=True
      )
    self.assertEqual(pol0, constants.IPOLICY_DEFAULTS)

    exp_pol1 = {
      constants.ISPECS_MINMAX: {
        constants.ISPECS_MIN: {
          constants.ISPEC_CPU_COUNT: 2,
          constants.ISPEC_DISK_COUNT: 1,
          },
        constants.ISPECS_MAX: {
          constants.ISPEC_MEM_SIZE: 12*1024,
          constants.ISPEC_DISK_COUNT: 2,
          },
        },
      constants.ISPECS_STD: {
        constants.ISPEC_CPU_COUNT: 2,
        constants.ISPEC_DISK_COUNT: 2,
        },
      constants.IPOLICY_VCPU_RATIO: 3.1,
      }
    pol1 = cli.CreateIPolicyFromOpts(
      ispecs_mem_size={"max": "12g"},
      ispecs_cpu_count={"min": 2, "std": 2},
      ispecs_disk_count={"min": 1, "max": 2, "std": 2},
      ispecs_disk_size={},
      ispecs_nic_count={},
      ipolicy_disk_templates=None,
      ipolicy_vcpu_ratio=3.1,
      ipolicy_spindle_ratio=None,
      fill_all=True
      )
    self._RecursiveCheckMergedDicts(constants.IPOLICY_DEFAULTS,
                                    exp_pol1, pol1)

    exp_pol2 = {
      constants.ISPECS_MINMAX: {
        constants.ISPECS_MIN: {
          constants.ISPEC_DISK_SIZE: 512,
          constants.ISPEC_NIC_COUNT: 2,
          },
        constants.ISPECS_MAX: {
          constants.ISPEC_NIC_COUNT: 3,
          },
        },
      constants.ISPECS_STD: {
        constants.ISPEC_CPU_COUNT: 2,
        constants.ISPEC_NIC_COUNT: 3,
        },
      constants.IPOLICY_SPINDLE_RATIO: 1.3,
      constants.IPOLICY_DTS: ["templates"],
      }
    pol2 = cli.CreateIPolicyFromOpts(
      ispecs_mem_size={},
      ispecs_cpu_count={"std": 2},
      ispecs_disk_count={},
      ispecs_disk_size={"min": "0.5g"},
      ispecs_nic_count={"min": 2, "max": 3, "std": 3},
      ipolicy_disk_templates=["templates"],
      ipolicy_vcpu_ratio=None,
      ipolicy_spindle_ratio=1.3,
      fill_all=True
      )
    self._RecursiveCheckMergedDicts(constants.IPOLICY_DEFAULTS,
                                    exp_pol2, pol2)

    for fill_all in [False, True]:
      exp_pol3 = {
        constants.ISPECS_STD: {
          constants.ISPEC_CPU_COUNT: 2,
          constants.ISPEC_NIC_COUNT: 3,
          },
        }
      pol3 = cli.CreateIPolicyFromOpts(
        std_ispecs={
          constants.ISPEC_CPU_COUNT: "2",
          constants.ISPEC_NIC_COUNT: "3",
          },
        ipolicy_disk_templates=None,
        ipolicy_vcpu_ratio=None,
        ipolicy_spindle_ratio=None,
        fill_all=fill_all
        )
      if fill_all:
        self._RecursiveCheckMergedDicts(constants.IPOLICY_DEFAULTS,
                                        exp_pol3, pol3)
      else:
        self.assertEqual(pol3, exp_pol3)

  def testPartialPolicy(self):
    exp_pol0 = objects.MakeEmptyIPolicy()
    pol0 = cli.CreateIPolicyFromOpts(
      minmax_ispecs=None,
      std_ispecs=None,
      ipolicy_disk_templates=None,
      ipolicy_vcpu_ratio=None,
      ipolicy_spindle_ratio=None,
      fill_all=False
      )
    self.assertEqual(pol0, exp_pol0)

    exp_pol1 = {
      constants.IPOLICY_VCPU_RATIO: 3.1,
      }
    pol1 = cli.CreateIPolicyFromOpts(
      minmax_ispecs=None,
      std_ispecs=None,
      ipolicy_disk_templates=None,
      ipolicy_vcpu_ratio=3.1,
      ipolicy_spindle_ratio=None,
      fill_all=False
      )
    self.assertEqual(pol1, exp_pol1)

    exp_pol2 = {
      constants.IPOLICY_SPINDLE_RATIO: 1.3,
      constants.IPOLICY_DTS: ["templates"],
      }
    pol2 = cli.CreateIPolicyFromOpts(
      minmax_ispecs=None,
      std_ispecs=None,
      ipolicy_disk_templates=["templates"],
      ipolicy_vcpu_ratio=None,
      ipolicy_spindle_ratio=1.3,
      fill_all=False
      )
    self.assertEqual(pol2, exp_pol2)

  def _TestInvalidISpecs(self, minmax_ispecs, std_ispecs, fail=True):
    """Run CreateIPolicyFromOpts with the given specs and both fill_all values.

    @param fail: if C{True}, the call must raise one of the accepted
      validation errors; otherwise it must succeed

    """
    for fill_all in [False, True]:
      if fail:
        self.assertRaises((errors.OpPrereqError,
                           errors.UnitParseError,
                           errors.TypeEnforcementError),
                          cli.CreateIPolicyFromOpts,
                          minmax_ispecs=minmax_ispecs,
                          std_ispecs=std_ispecs,
                          fill_all=fill_all)
      else:
        cli.CreateIPolicyFromOpts(minmax_ispecs=minmax_ispecs,
                                  std_ispecs=std_ispecs,
                                  fill_all=fill_all)

  def testInvalidPolicies(self):
    self.assertRaises(AssertionError, cli.CreateIPolicyFromOpts,
                      std_ispecs={constants.ISPEC_MEM_SIZE: 1024},
                      ipolicy_disk_templates=None, ipolicy_vcpu_ratio=None,
                      ipolicy_spindle_ratio=None, group_ipolicy=True)
    self.assertRaises(errors.OpPrereqError, cli.CreateIPolicyFromOpts,
                      ispecs_mem_size={"wrong": "x"}, ispecs_cpu_count={},
                      ispecs_disk_count={}, ispecs_disk_size={},
                      ispecs_nic_count={}, ipolicy_disk_templates=None,
                      ipolicy_vcpu_ratio=None, ipolicy_spindle_ratio=None,
                      fill_all=True)
    self.assertRaises(errors.TypeEnforcementError, cli.CreateIPolicyFromOpts,
                      ispecs_mem_size={}, ispecs_cpu_count={"min": "default"},
                      ispecs_disk_count={}, ispecs_disk_size={},
                      ispecs_nic_count={}, ipolicy_disk_templates=None,
                      ipolicy_vcpu_ratio=None, ipolicy_spindle_ratio=None,
                      fill_all=True)

    good_mmspecs = constants.ISPECS_MINMAX_DEFAULTS
    self._TestInvalidISpecs(good_mmspecs, None, fail=False)
    broken_mmspecs = copy.deepcopy(good_mmspecs)
    for key in constants.ISPECS_MINMAX_KEYS:
      for par in constants.ISPECS_PARAMETERS:
        old = broken_mmspecs[key][par]
        del broken_mmspecs[key][par]
        self._TestInvalidISpecs(broken_mmspecs, None)
        broken_mmspecs[key][par] = "invalid"
        self._TestInvalidISpecs(broken_mmspecs, None)
        broken_mmspecs[key][par] = old
      broken_mmspecs[key]["invalid_key"] = None
      self._TestInvalidISpecs(broken_mmspecs, None)
      del broken_mmspecs[key]["invalid_key"]
    assert broken_mmspecs == good_mmspecs

    good_stdspecs = constants.IPOLICY_DEFAULTS[constants.ISPECS_STD]
    self._TestInvalidISpecs(None, good_stdspecs, fail=False)
    broken_stdspecs = copy.deepcopy(good_stdspecs)
    for par in constants.ISPECS_PARAMETERS:
      old = broken_stdspecs[par]
      broken_stdspecs[par] = "invalid"
      self._TestInvalidISpecs(None, broken_stdspecs)
      broken_stdspecs[par] = old
    broken_stdspecs["invalid_key"] = None
    self._TestInvalidISpecs(None, broken_stdspecs)
    del broken_stdspecs["invalid_key"]
    assert broken_stdspecs == good_stdspecs

  def testAllowedValues(self):
    allowedv = "blah"
    exp_pol1 = {
      constants.ISPECS_MINMAX: allowedv,
      constants.IPOLICY_DTS: allowedv,
      constants.IPOLICY_VCPU_RATIO: allowedv,
      constants.IPOLICY_SPINDLE_RATIO: allowedv,
      }
    pol1 = cli.CreateIPolicyFromOpts(minmax_ispecs={allowedv: {}},
                                     std_ispecs=None,
                                     ipolicy_disk_templates=allowedv,
                                     ipolicy_vcpu_ratio=allowedv,
                                     ipolicy_spindle_ratio=allowedv,
                                     allowed_values=[allowedv])
    self.assertEqual(pol1, exp_pol1)

  @staticmethod
  def _ConvertSpecToStrings(spec):
    """Return a copy of C{spec} with every value converted via C{str}."""
    ret = {}
    for (par, val) in spec.items():
      ret[par] = str(val)
    return ret

  def _CheckNewStyleSpecsCall(self, exp_ipolicy, minmax_ispecs, std_ispecs,
                              group_ipolicy, fill_all):
    """Call CreateIPolicyFromOpts with new-style specs and compare."""
    ipolicy = cli.CreateIPolicyFromOpts(minmax_ispecs=minmax_ispecs,
                                        std_ispecs=std_ispecs,
                                        group_ipolicy=group_ipolicy,
                                        fill_all=fill_all)
    self.assertEqual(ipolicy, exp_ipolicy)

  def _TestFullISpecsInner(self, skel_exp_ipol, exp_minmax, exp_std,
                           group_ipolicy, fill_all):
    """Check full specs passed as strings, then again with "m" unit suffixes.

    @param skel_exp_ipol: expected policy before applying the given specs
    @param exp_minmax: expected min/max specs, or C{None} to not pass any
    @param exp_std: expected std specs, or C{None} to not pass any

    """
    exp_ipol = skel_exp_ipol.copy()
    if exp_minmax is not None:
      minmax_ispecs = {}
      for (key, spec) in exp_minmax.items():
        minmax_ispecs[key] = self._ConvertSpecToStrings(spec)
      exp_ipol[constants.ISPECS_MINMAX] = exp_minmax
    else:
      minmax_ispecs = None
    if exp_std is not None:
      std_ispecs = self._ConvertSpecToStrings(exp_std)
      exp_ipol[constants.ISPECS_STD] = exp_std
    else:
      std_ispecs = None

    self._CheckNewStyleSpecsCall(exp_ipol, minmax_ispecs, std_ispecs,
                                 group_ipolicy, fill_all)
    # Size values given with an explicit "m" suffix must yield the same policy
    if minmax_ispecs:
      for (key, spec) in minmax_ispecs.items():
        for par in [constants.ISPEC_MEM_SIZE, constants.ISPEC_DISK_SIZE]:
          if par in spec:
            spec[par] += "m"
            self._CheckNewStyleSpecsCall(exp_ipol, minmax_ispecs, std_ispecs,
                                         group_ipolicy, fill_all)
    if std_ispecs:
      for par in [constants.ISPEC_MEM_SIZE, constants.ISPEC_DISK_SIZE]:
        if par in std_ispecs:
          std_ispecs[par] += "m"
          self._CheckNewStyleSpecsCall(exp_ipol, minmax_ispecs, std_ispecs,
                                       group_ipolicy, fill_all)

  def testFullISpecs(self):
    exp_minmax1 = {
      constants.ISPECS_MIN: {
        constants.ISPEC_MEM_SIZE: 512,
        constants.ISPEC_CPU_COUNT: 2,
        constants.ISPEC_DISK_COUNT: 2,
        constants.ISPEC_DISK_SIZE: 512,
        constants.ISPEC_NIC_COUNT: 2,
        constants.ISPEC_SPINDLE_USE: 2,
        },
      constants.ISPECS_MAX: {
        constants.ISPEC_MEM_SIZE: 768*1024,
        constants.ISPEC_CPU_COUNT: 7,
        constants.ISPEC_DISK_COUNT: 6,
        constants.ISPEC_DISK_SIZE: 2048*1024,
        constants.ISPEC_NIC_COUNT: 3,
        constants.ISPEC_SPINDLE_USE: 1,
        },
      }
    exp_std1 = {
      constants.ISPEC_MEM_SIZE: 768*1024,
      constants.ISPEC_CPU_COUNT: 7,
      constants.ISPEC_DISK_COUNT: 6,
      constants.ISPEC_DISK_SIZE: 2048*1024,
      constants.ISPEC_NIC_COUNT: 3,
      constants.ISPEC_SPINDLE_USE: 1,
      }
    for fill_all in [False, True]:
      if fill_all:
        skel_ipolicy = constants.IPOLICY_DEFAULTS
      else:
        skel_ipolicy = {}
      self._TestFullISpecsInner(skel_ipolicy, exp_minmax1, exp_std1,
                                False, fill_all)
      self._TestFullISpecsInner(skel_ipolicy, None, exp_std1,
                                False, fill_all)
      self._TestFullISpecsInner(skel_ipolicy, exp_minmax1, None,
                                False, fill_all)
1503
1504
class TestPrintIPolicyCommand(unittest.TestCase):
  """Tests for cli.PrintIPolicyCommand."""
  _SPECS1 = {
    "par1": 42,
    "par2": "xyz",
    }
  _SPECS1_STR = "par1=42,par2=xyz"
  _SPECS2 = {
    "param": 10,
    "another_param": 101,
    }
  _SPECS2_STR = "another_param=101,param=10"

  def _CheckPrintIPolicyCommand(self, ipolicy, isgroup, expected):
    """Print C{ipolicy} into a buffer and compare against C{expected}."""
    output = StringIO()
    cli.PrintIPolicyCommand(output, ipolicy, isgroup)
    self.assertEqual(output.getvalue(), expected)

  def testIgnoreStdForGroup(self):
    # Standard specs produce no output when printing a group policy
    self._CheckPrintIPolicyCommand({"std": self._SPECS1}, True, "")

  def testIgnoreEmpty(self):
    # Empty specs, or bounds missing one side, produce no output
    empty_policies = [
      {},
      {"std": {}},
      {"minmax": {}},
      {"minmax": {"min": {}, "max": {}}},
      {"minmax": {"min": self._SPECS1, "max": {}}},
      ]
    for policy in empty_policies:
      self._CheckPrintIPolicyCommand(policy, False, "")

  def testFullPolicies(self):
    std_policy = {"std": self._SPECS1}
    std_expected = " %s %s" % (cli.IPOLICY_STD_SPECS_STR, self._SPECS1_STR)

    bounds_policy = {"minmax": {"min": self._SPECS1, "max": self._SPECS2}}
    bounds_expected = " %s min:%s/max:%s" % (cli.IPOLICY_BOUNDS_SPECS_STR,
                                             self._SPECS1_STR,
                                             self._SPECS2_STR)

    for (policy, expected) in [(std_policy, std_expected),
                               (bounds_policy, bounds_expected)]:
      self._CheckPrintIPolicyCommand(policy, False, expected)
1556
1557
if __name__ == "__main__":
  # When executed as a script, run this module's tests through
  # testutils.GanetiTestProgram
  testutils.GanetiTestProgram()