cfg auto update: match ipolicy with enabled disk templates
[ganeti-local] / test / py / ganeti.cli_unittest.py
1 #!/usr/bin/python
2 #
3
4 # Copyright (C) 2008, 2011, 2012, 2013 Google Inc.
5 #
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 2 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program; if not, write to the Free Software
18 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19 # 02110-1301, USA.
20
21
22 """Script for unittesting the cli module"""
23
24 import copy
25 import testutils
26 import time
27 import unittest
28 import yaml
29 from cStringIO import StringIO
30
31 from ganeti import constants
32 from ganeti import cli
33 from ganeti import errors
34 from ganeti import utils
35 from ganeti import objects
36 from ganeti import qlang
37 from ganeti.errors import OpPrereqError, ParameterError
38
39
class TestParseTimespec(unittest.TestCase):
  """Testing case for ParseTimespec"""

  def testValidTimes(self):
    """Test valid timespecs, with and without a unit suffix"""
    test_data = [
      ("1s", 1),
      ("1", 1),
      ("1m", 60),
      ("1h", 60 * 60),
      ("1d", 60 * 60 * 24),
      ("1w", 60 * 60 * 24 * 7),
      ("4h", 4 * 60 * 60),
      ("61m", 61 * 60),
      ]
    for value, expected_result in test_data:
      self.assertEqual(cli.ParseTimespec(value), expected_result)

  def testInvalidTime(self):
    """Test that malformed timespecs raise OpPrereqError"""
    test_data = [
      "1y",
      "",
      "aaa",
      "s",
      ]
    for value in test_data:
      self.assertRaises(OpPrereqError, cli.ParseTimespec, value)
68
69
class TestSplitKeyVal(unittest.TestCase):
  """Testing case for cli._SplitKeyVal"""
  # Covers a key=value pair, a bare key, a "no_" prefixed key and a "-"
  # prefixed key
  DATA = "a=b,c,no_d,-e"
  # Expected result of splitting DATA with prefix parsing enabled
  RESULT = {"a": "b", "c": True, "d": False, "e": None}
  # NOTE(review): not referenced by any test in this class; kept so the
  # class attributes stay unchanged -- confirm whether a no-prefix test was
  # intended
  RESULT_NOPREFIX = {"a": "b", "c": {}, "no_d": {}, "-e": {}}

  def testSplitKeyVal(self):
    """Test splitting"""
    self.assertEqual(cli._SplitKeyVal("option", self.DATA, True),
                     self.RESULT)

  def testDuplicateParam(self):
    """Test duplicate parameters"""
    # "a" and "no_a" collide once the "no_" prefix is stripped
    for data in ("a=1,a=2", "a,no_a"):
      self.assertRaises(ParameterError, cli._SplitKeyVal,
                        "option", data, True)

  def testEmptyData(self):
    """Test how we handle splitting an empty string"""
    self.assertEqual(cli._SplitKeyVal("option", "", True), {})
90
91
class TestIdentKeyVal(unittest.TestCase):
  """Tests for cli.check_ident_key_val and cli._SplitIdentKeyVal"""

  def testIdentKeyVal(self):
    """Check identkeyval parsing with prefix handling enabled"""
    # Marker for inputs that must be rejected with ParameterError
    rejected = object()

    checks = [
      ("foo:bar", ("foo", {"bar": True})),
      ("foo:bar=baz", ("foo", {"bar": "baz"})),
      ("bar:b=c,c=a", ("bar", {"b": "c", "c": "a"})),
      ("no_bar", ("bar", False)),
      ("no_bar:foo", rejected),
      ("no_bar:foo=baz", rejected),
      ("bar:foo=baz,foo=baz", rejected),
      ("-foo", ("foo", None)),
      ("-foo:a=c", rejected),
      # Negative numbers are kept as identifiers
      ("-1:remove", ("-1", {"remove": True})),
      ("-29447:add,size=4G", ("-29447", {"add": True, "size": "4G"})),
      ("-:", ("", None)),
      ("-", ("", None)),
      ]

    for (text, expected) in checks:
      if expected is rejected:
        self.assertRaises(ParameterError, cli.check_ident_key_val,
                          "option", "opt", text)
      else:
        self.assertEqual(cli.check_ident_key_val("option", "opt", text),
                         expected)

  @staticmethod
  def _csikv(value):
    return cli._SplitIdentKeyVal("opt", value, False)

  def testIdentKeyValNoPrefix(self):
    """Check identkeyval parsing with prefix handling disabled"""
    for text in ["foo:bar", "foo:no_bar", "foo:bar=baz,bar=baz"]:
      self.assertRaises(ParameterError, self._csikv, text)

    accepted = [
      ("foo", ("foo", {})),
      ("foo:bar=baz", ("foo", {"bar": "baz"})),
      ("no_foo:-1=baz,no_op=3", ("no_foo", {"-1": "baz", "no_op": "3"})),
      ]
    for (text, expected) in accepted:
      self.assertEqual(self._csikv(text), expected)
143
144
class TestMultilistIdentKeyVal(unittest.TestCase):
  """Test for cli.check_multilist_ident_key_val()"""

  @staticmethod
  def _cmikv(value):
    return cli.check_multilist_ident_key_val("option", "opt", value)

  def testListIdentKeyVal(self):
    """Check "//"-separated lists of "/"-separated ident:key=val groups"""
    # Each entry is (input, expected); None means ParameterError is expected
    cases = [
      ("", None),
      ("foo", [{"foo": {}}]),
      ("foo:bar=baz", [{"foo": {"bar": "baz"}}]),
      ("foo:bar=baz/foo:bat=bad", None),
      ("foo:abc=42/bar:def=11",
       [{"foo": {"abc": "42"}, "bar": {"def": "11"}}]),
      ("foo:abc=42/bar:def=11,ghi=07",
       [{"foo": {"abc": "42"}, "bar": {"def": "11", "ghi": "07"}}]),
      ("foo:abc=42/bar:def=11//", None),
      ("foo:abc=42/bar:def=11,ghi=07//foobar",
       [{"foo": {"abc": "42"}, "bar": {"def": "11", "ghi": "07"}},
        {"foobar": {}}]),
      ("foo:abc=42/bar:def=11,ghi=07//foobar:xyz=88",
       [{"foo": {"abc": "42"}, "bar": {"def": "11", "ghi": "07"}},
        {"foobar": {"xyz": "88"}}]),
      ("foo:abc=42/bar:def=11,ghi=07//foobar:xyz=88/foo:uvw=314",
       [{"foo": {"abc": "42"}, "bar": {"def": "11", "ghi": "07"}},
        {"foobar": {"xyz": "88"}, "foo": {"uvw": "314"}}]),
      ]

    for (text, expected) in cases:
      if expected is not None:
        self.assertEqual(expected, self._cmikv(text))
      else:
        self.assertRaises(ParameterError, self._cmikv, text)
196
197
class TestToStream(unittest.TestCase):
  """Test the ToStream functions"""

  def testBasic(self):
    """Without arguments the message is written verbatim plus a newline"""
    for data in ["foo",
                 "foo %s",
                 "foo %(test)s",
                 "foo %s %s",
                 "",
                 ]:
      buf = StringIO()
      cli._ToStream(buf, data)
      self.assertEqual(buf.getvalue(), data + "\n")

  def testParams(self):
    """Extra arguments are formatted into the message"""
    buf = StringIO()
    cli._ToStream(buf, "foo %s", 1)
    self.assertEqual(buf.getvalue(), "foo 1\n")

    # A tuple passed as a single argument is rendered as one value
    buf = StringIO()
    cli._ToStream(buf, "foo %s", (15, 16))
    self.assertEqual(buf.getvalue(), "foo (15, 16)\n")

    buf = StringIO()
    cli._ToStream(buf, "foo %s %s", "a", "b")
    self.assertEqual(buf.getvalue(), "foo a b\n")
222
223
class TestGenerateTable(unittest.TestCase):
  """Tests for cli.GenerateTable"""
  HEADERS = dict(("f%s" % i, "Field%s" % i) for i in range(5))

  FIELDS1 = ["f1", "f2"]
  DATA1 = [
    ["abc", 1234],
    ["foobar", 56],
    ["b", -14],
    ]

  def _test(self, headers, fields, separator, data,
            numfields, unitfields, units, expected):
    """Render C{data} with cli.GenerateTable and compare to C{expected}"""
    self.assertEqual(cli.GenerateTable(headers, fields, separator, data,
                                       numfields=numfields,
                                       unitfields=unitfields,
                                       units=units),
                     expected)

  def testPlain(self):
    self._test(self.HEADERS, self.FIELDS1, None, self.DATA1,
               numfields=None, unitfields=None, units="m",
               expected=["Field1 Field2",
                         "abc    1234",
                         "foobar 56",
                         "b      -14"])

  def testNoFields(self):
    """With no fields, every row (and the header, if any) is empty"""
    for (headers, expected) in [(self.HEADERS, ["", "", ""]),
                                (None, ["", ""])]:
      self._test(headers, [], None, [[], []], None, None, "m", expected)

  def testSeparator(self):
    cells = [("Field1", "Field2"),
             ("abc", "1234"),
             ("foobar", "56"),
             ("b", "-14")]
    for sep in "# : , ^ ! % | ### %% !!! ||".split():
      self._test(self.HEADERS, self.FIELDS1, sep, self.DATA1,
                 None, None, "m", [sep.join(row) for row in cells])

  def testNoHeader(self):
    self._test(None, self.FIELDS1, None, self.DATA1,
               None, None, "m", ["abc    1234", "foobar 56", "b      -14"])

  def testUnknownField(self):
    """A field missing from the headers uses its own name as title"""
    self._test({"f1": "Field1"}, ["f1", "UNKNOWN"], None, self.DATA1,
               None, None, "m",
               ["Field1 UNKNOWN", "abc    1234", "foobar 56", "b      -14"])

  def testNumfields(self):
    """Numeric fields are right-aligned"""
    rows = [
      ["abc", 1234, 0],
      ["foobar", 56, 3],
      ["b", -14, "-"],
      ]
    self._test(self.HEADERS, ["f1", "f2", "f3"], None, rows,
               ["f2", "f3"], None, "m", [
                 "Field1 Field2 Field3",
                 "abc      1234      0",
                 "foobar     56      3",
                 "b         -14      -",
                 ])

  def testUnitfields(self):
    """Unit fields are rendered in the requested unit"""
    expected_by_sep = [
      (None, ["Field1 Field2 Field3",
              "abc      1234     0M",
              "foobar     56     3M",
              "b         -14      -"]),
      (":", ["Field1:Field2:Field3",
             "abc:1234:0M",
             "foobar:56:3M",
             "b:-14:-"]),
      ]
    for (sep, expected) in expected_by_sep:
      # Fresh input on every pass, in case GenerateTable modifies it
      rows = [["abc", 1234, 0], ["foobar", 56, 3], ["b", -14, "-"]]
      self._test(self.HEADERS, ["f1", "f2", "f3"], sep, rows,
                 ["f2", "f3"], ["f3"], "h", expected)

  def testUnusual(self):
    """Percent signs in cell values are printed literally"""
    rows = [["%", "xyz"], ["%%", "abc"]]
    self._test(self.HEADERS, ["f1", "f2"], None, rows,
               None, None, "m",
               ["Field1 Field2", "%      xyz", "%%     abc"])
343
344
class TestFormatQueryResult(unittest.TestCase):
  """Tests for cli.FormatQueryResult"""

  @staticmethod
  def _Fields(*specs):
    """Build query field definitions from (name, title, kind) tuples"""
    return [objects.QueryFieldDefinition(name=name, title=title, kind=kind)
            for (name, title, kind) in specs]

  @staticmethod
  def _Normal(*values):
    """Build a result row where every value has status RS_NORMAL"""
    return [(constants.RS_NORMAL, value) for value in values]

  def test(self):
    """Basic field kinds rendered with unit="h" and a header"""
    fields = self._Fields(
      ("name", "Name", constants.QFT_TEXT),
      ("size", "Size", constants.QFT_NUMBER),
      ("act", "Active", constants.QFT_BOOL),
      ("mem", "Memory", constants.QFT_UNIT),
      ("other", "SomeList", constants.QFT_OTHER),
      )
    response = objects.QueryResponse(fields=fields, data=[
      self._Normal("nodeA", 128, False, 1468006, []),
      self._Normal("other", 512, True, 16, [1, 2, 3]),
      self._Normal("xyz", 1024, True, 4096, [{}, {}]),
      ])

    expected = [
      "Name  Size Active Memory SomeList",
      "nodeA  128 N        1.4T []",
      "other  512 Y         16M [1, 2, 3]",
      "xyz   1024 Y        4.0G [{}, {}]",
      ]
    self.assertEqual(cli.FormatQueryResult(response, unit="h", header=True),
                     (cli.QR_NORMAL, expected))

  def testTimestampAndUnit(self):
    """Timestamps are rendered via utils.FormatTime"""
    fields = self._Fields(
      ("name", "Name", constants.QFT_TEXT),
      ("size", "Size", constants.QFT_UNIT),
      ("mtime", "ModTime", constants.QFT_TIMESTAMP),
      )
    response = objects.QueryResponse(fields=fields, data=[
      self._Normal("a", 1024, 0),
      self._Normal("b", 144996, 1291746295),
      ])

    self.assertEqual(cli.FormatQueryResult(response, unit="m", header=True),
                     (cli.QR_NORMAL, [
                       "Name   Size ModTime",
                       "a      1024 %s" % utils.FormatTime(0),
                       "b    144996 %s" % utils.FormatTime(1291746295),
                       ]))

  def testOverride(self):
    """format_override replaces the formatter of selected fields"""
    fields = self._Fields(
      ("name", "Name", constants.QFT_TEXT),
      ("cust", "Custom", constants.QFT_OTHER),
      ("xt", "XTime", constants.QFT_TIMESTAMP),
      )
    response = objects.QueryResponse(fields=fields, data=[
      self._Normal("x", ["a", "b", "c"], 1234),
      self._Normal("y", range(10), 1291746295),
      ])

    # NOTE(review): the second tuple element presumably selects right
    # alignment (compare the XTime column below) -- confirm in cli
    override = {
      "cust": (utils.CommaJoin, False),
      "xt": (hex, True),
      }

    result = cli.FormatQueryResult(response, unit="h", header=True,
                                   format_override=override)
    self.assertEqual(result, (cli.QR_NORMAL, [
      "Name Custom                            XTime",
      "x    a, b, c                           0x4d2",
      "y    0, 1, 2, 3, 4, 5, 6, 7, 8, 9 0x4cfe7bf7",
      ]))

  def testSeparator(self):
    """With a separator, columns are joined without padding"""
    fields = self._Fields(
      ("name", "Name", constants.QFT_TEXT),
      ("count", "Count", constants.QFT_NUMBER),
      ("desc", "Description", constants.QFT_TEXT),
      )
    response = objects.QueryResponse(fields=fields, data=[
      self._Normal("instance1.example.com", 21125, "Hello World!"),
      self._Normal("mail.other.net", -9000, "a,b,c"),
      ])

    for sep in [":", "|", "#", "|||", "###", "@@@", "@#@"]:
      body = [
        sep.join(["instance1.example.com", "21125", "Hello World!"]),
        sep.join(["mail.other.net", "-9000", "a,b,c"]),
        ]
      for header in [None, sep.join(["Name", "Count", "Description"])]:
        if header:
          expected = [header] + body
        else:
          expected = body
        self.assertEqual(cli.FormatQueryResult(response, separator=sep,
                                               header=bool(header)),
                         (cli.QR_NORMAL, expected))

  def testStatusWithUnknown(self):
    """An unknown field yields QR_UNKNOWN; status markers depend on verbose"""
    fields = self._Fields(
      ("id", "ID", constants.QFT_NUMBER),
      ("unk", "unk", constants.QFT_UNKNOWN),
      ("unavail", "Unavail", constants.QFT_BOOL),
      ("nodata", "NoData", constants.QFT_TEXT),
      ("offline", "OffLine", constants.QFT_TEXT),
      )
    response = objects.QueryResponse(fields=fields, data=[
      [(constants.RS_NORMAL, 1), (constants.RS_UNKNOWN, None),
       (constants.RS_NORMAL, False), (constants.RS_NORMAL, ""),
       (constants.RS_OFFLINE, None)],
      [(constants.RS_NORMAL, 2), (constants.RS_UNKNOWN, None),
       (constants.RS_NODATA, None), (constants.RS_NORMAL, "x"),
       (constants.RS_OFFLINE, None)],
      [(constants.RS_NORMAL, 3), (constants.RS_UNKNOWN, None),
       (constants.RS_NORMAL, False), (constants.RS_UNAVAIL, None),
       (constants.RS_OFFLINE, None)],
      ])

    header = "ID|unk|Unavail|NoData|OffLine"
    for (verbose, rows) in [
      (True, ["1|(unknown)|N||(offline)",
              "2|(unknown)|(nodata)|x|(offline)",
              "3|(unknown)|N|(unavail)|(offline)"]),
      (False, ["1|??|N||*",
               "2|??|?|x|*",
               "3|??|N|-|*"]),
      ]:
      self.assertEqual(cli.FormatQueryResult(response, header=True,
                                             separator="|", verbose=verbose),
                       (cli.QR_UNKNOWN, [header] + rows))

  def testNoData(self):
    """Without rows only the header is printed"""
    fields = self._Fields(
      ("id", "ID", constants.QFT_NUMBER),
      ("name", "Name", constants.QFT_TEXT),
      )
    response = objects.QueryResponse(fields=fields, data=[])

    self.assertEqual(cli.FormatQueryResult(response, header=True),
                     (cli.QR_NORMAL, ["ID Name"]))

  def testNoDataWithUnknown(self):
    """An unknown field yields QR_UNKNOWN even without rows"""
    fields = self._Fields(
      ("id", "ID", constants.QFT_NUMBER),
      ("unk", "unk", constants.QFT_UNKNOWN),
      )
    response = objects.QueryResponse(fields=fields, data=[])

    self.assertEqual(cli.FormatQueryResult(response, header=False),
                     (cli.QR_UNKNOWN, []))

  def testStatus(self):
    """Non-normal statuses yield QR_INCOMPLETE"""
    fields = self._Fields(
      ("id", "ID", constants.QFT_NUMBER),
      ("unavail", "Unavail", constants.QFT_BOOL),
      ("nodata", "NoData", constants.QFT_TEXT),
      ("offline", "OffLine", constants.QFT_TEXT),
      )
    response = objects.QueryResponse(fields=fields, data=[
      [(constants.RS_NORMAL, 1), (constants.RS_NORMAL, False),
       (constants.RS_NORMAL, ""), (constants.RS_OFFLINE, None)],
      [(constants.RS_NORMAL, 2), (constants.RS_NODATA, None),
       (constants.RS_NORMAL, "x"), (constants.RS_NORMAL, "abc")],
      [(constants.RS_NORMAL, 3), (constants.RS_NORMAL, False),
       (constants.RS_UNAVAIL, None), (constants.RS_OFFLINE, None)],
      ])

    for (verbose, rows) in [
      (True, ["1|N||(offline)",
              "2|(nodata)|x|abc",
              "3|N|(unavail)|(offline)"]),
      (False, ["1|N||*",
               "2|?|x|abc",
               "3|N|-|*"]),
      ]:
      self.assertEqual(cli.FormatQueryResult(response, header=False,
                                             separator="|", verbose=verbose),
                       (cli.QR_INCOMPLETE, rows))

  def testInvalidFieldType(self):
    fields = self._Fields(("x", "x", "#some#other#type"))
    response = objects.QueryResponse(fields=fields, data=[])

    self.assertRaises(NotImplementedError, cli.FormatQueryResult, response)

  def testInvalidFieldStatus(self):
    fields = self._Fields(("x", "x", constants.QFT_TEXT))

    response = objects.QueryResponse(fields=fields, data=[[(-1, None)]])
    self.assertRaises(NotImplementedError, cli.FormatQueryResult, response)

    response = objects.QueryResponse(fields=fields, data=[[(-1, "x")]])
    self.assertRaises(AssertionError, cli.FormatQueryResult, response)

  def testEmptyFieldTitle(self):
    fields = self._Fields(("x", "", constants.QFT_TEXT))
    response = objects.QueryResponse(fields=fields, data=[])

    self.assertRaises(AssertionError, cli.FormatQueryResult, response)
600
601
class _MockJobPollCb(cli.JobPollCbBase, cli.JobPollReportCbBase):
  """Scripted job polling and reporting callbacks.

  Tests queue results in advance with L{AddWfjcResult} and
  L{AddQueryJobsResult}. Each callback checks its arguments through the
  owning test case and hands out the next queued result.

  """
  def __init__(self, tc, job_id):
    """Initialize the mock.

    @param tc: test case used for assertions (must provide
      C{assertEqualValues})
    @param job_id: the only job ID the callbacks accept

    """
    self.tc = tc
    self.job_id = job_id
    # Queued (prev_job_info, prev_log_serial, result) for WaitForJobChangeOnce
    self._wfjcr = []
    # Queued rows returned by QueryJobs
    self._jobstatus = []
    # Set after returning JOB_NOTCHANGED; cleared by ReportNotChanged
    self._expect_notchanged = False
    # Log entries handed out that still have to be reported, in order
    self._expect_log = []

  def CheckEmpty(self):
    """Assert that every queued expectation has been consumed."""
    self.tc.assertFalse(self._wfjcr)
    self.tc.assertFalse(self._jobstatus)
    self.tc.assertFalse(self._expect_notchanged)
    self.tc.assertFalse(self._expect_log)

  def AddWfjcResult(self, *args):
    """Queue (prev_job_info, prev_log_serial, result) for one poll."""
    self._wfjcr.append(args)

  def AddQueryJobsResult(self, *args):
    """Queue one row of field values for L{QueryJobs}."""
    self._jobstatus.append(args)

  def WaitForJobChangeOnce(self, job_id, fields,
                           prev_job_info, prev_log_serial):
    """Check the poll arguments and return the next queued result."""
    self.tc.assertEqual(job_id, self.job_id)
    self.tc.assertEqualValues(fields, ["status"])
    # Earlier "not changed" results and log messages must have been
    # reported before polling again
    self.tc.assertFalse(self._expect_notchanged)
    self.tc.assertFalse(self._expect_log)
    self.tc.assertTrue(self._wfjcr,
                       msg="Unexpected call to WaitForJobChangeOnce")

    (exp_prev_job_info, exp_prev_log_serial, result) = self._wfjcr.pop(0)
    self.tc.assertEqualValues(prev_job_info, exp_prev_job_info)
    self.tc.assertEqual(prev_log_serial, exp_prev_log_serial)

    if result == constants.JOB_NOTCHANGED:
      self._expect_notchanged = True
    elif result:
      (_, logmsgs) = result
      if logmsgs:
        self._expect_log.extend(logmsgs)

    return result

  def QueryJobs(self, job_ids, fields):
    """Check the query arguments and return the next queued row."""
    self.tc.assertEqual(job_ids, [self.job_id])
    self.tc.assertEqualValues(fields, ["status", "opstatus", "opresult"])
    self.tc.assertFalse(self._expect_notchanged)
    self.tc.assertFalse(self._expect_log)
    self.tc.assertTrue(self._jobstatus, msg="Unexpected call to QueryJobs")

    result = self._jobstatus.pop(0)
    self.tc.assertEqual(len(fields), len(result))
    return [result]

  def ReportLogMessage(self, job_id, serial, timestamp, log_type, log_msg):
    """Check that the reported entry is the next expected log entry."""
    self.tc.assertEqual(job_id, self.job_id)
    self.tc.assertTrue(self._expect_log, msg="Unexpected log message")
    self.tc.assertEqualValues((serial, timestamp, log_type, log_msg),
                              self._expect_log.pop(0))

  def ReportNotChanged(self, job_id, status):
    """Check that a "not changed" report was expected."""
    self.tc.assertEqual(job_id, self.job_id)
    self.tc.assertTrue(self._expect_notchanged)
    self._expect_notchanged = False
662
663
class TestGenericPollJob(testutils.GanetiTestCase):
  """Tests for cli.GenericPollJob driven by scripted callbacks"""

  def testSuccessWithLog(self):
    """A successful job returns its opcode results after all log messages"""
    job_id = 29609
    cbs = _MockJobPollCb(self, job_id)

    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)

    cbs.AddWfjcResult(None, None,
                      ((constants.JOB_STATUS_QUEUED, ), None))

    cbs.AddWfjcResult((constants.JOB_STATUS_QUEUED, ), None,
                      constants.JOB_NOTCHANGED)

    # Log entries are (serial, timestamp, type, message); serials may skip
    cbs.AddWfjcResult((constants.JOB_STATUS_QUEUED, ), None,
                      ((constants.JOB_STATUS_RUNNING, ),
                       [(1, utils.SplitTime(1273491611.0),
                         constants.ELOG_MESSAGE, "Step 1"),
                        (2, utils.SplitTime(1273491615.9),
                         constants.ELOG_MESSAGE, "Step 2"),
                        (3, utils.SplitTime(1273491625.02),
                         constants.ELOG_MESSAGE, "Step 3"),
                        (4, utils.SplitTime(1273491635.05),
                         constants.ELOG_MESSAGE, "Step 4"),
                        (37, utils.SplitTime(1273491645.0),
                         constants.ELOG_MESSAGE, "Step 5"),
                        (203, utils.SplitTime(1273491655.0),
                         constants.ELOG_MESSAGE, "Step 6")]))

    # The next poll must pass the highest serial seen so far (203)
    cbs.AddWfjcResult((constants.JOB_STATUS_RUNNING, ), 203,
                      ((constants.JOB_STATUS_RUNNING, ),
                       [(300, utils.SplitTime(1273491711.01),
                         constants.ELOG_MESSAGE, "Step X"),
                        (302, utils.SplitTime(1273491815.8),
                         constants.ELOG_MESSAGE, "Step Y"),
                        (303, utils.SplitTime(1273491925.32),
                         constants.ELOG_MESSAGE, "Step Z")]))

    cbs.AddWfjcResult((constants.JOB_STATUS_RUNNING, ), 303,
                      ((constants.JOB_STATUS_SUCCESS, ), None))

    cbs.AddQueryJobsResult(constants.JOB_STATUS_SUCCESS,
                           [constants.OP_STATUS_SUCCESS,
                            constants.OP_STATUS_SUCCESS],
                           ["Hello World", "Foo man bar"])

    self.assertEqual(["Hello World", "Foo man bar"],
                     cli.GenericPollJob(job_id, cbs, cbs))
    cbs.CheckEmpty()

  def testJobLost(self):
    """An empty poll result raises JobLost"""
    job_id = 13746

    cbs = _MockJobPollCb(self, job_id)
    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
    cbs.AddWfjcResult(None, None, None)
    self.assertRaises(errors.JobLost, cli.GenericPollJob, job_id, cbs, cbs)
    cbs.CheckEmpty()

  def testError(self):
    """A job in error status raises OpExecError"""
    job_id = 31088

    cbs = _MockJobPollCb(self, job_id)
    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
                           [constants.OP_STATUS_SUCCESS,
                            constants.OP_STATUS_ERROR],
                           ["Hello World", "Error code 123"])
    self.assertRaises(errors.OpExecError, cli.GenericPollJob, job_id, cbs, cbs)
    cbs.CheckEmpty()

  def testError2(self):
    """An encoded exception in the opcode result is raised again"""
    job_id = 22235

    cbs = _MockJobPollCb(self, job_id)
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
    encexc = errors.EncodeException(errors.LockError("problem"))
    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
                           [constants.OP_STATUS_ERROR], [encexc])
    self.assertRaises(errors.LockError, cli.GenericPollJob, job_id, cbs, cbs)
    cbs.CheckEmpty()

  def testWeirdError(self):
    """A failed job whose opcodes still report running raises OpExecError"""
    job_id = 28847

    cbs = _MockJobPollCb(self, job_id)
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
                           [constants.OP_STATUS_RUNNING,
                            constants.OP_STATUS_RUNNING],
                           [None, None])
    self.assertRaises(errors.OpExecError, cli.GenericPollJob, job_id, cbs, cbs)
    cbs.CheckEmpty()

  def testCancel(self):
    """A job being cancelled raises OpExecError"""
    job_id = 4275

    cbs = _MockJobPollCb(self, job_id)
    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_CANCELING, ), None))
    cbs.AddQueryJobsResult(constants.JOB_STATUS_CANCELING,
                           [constants.OP_STATUS_CANCELING,
                            constants.OP_STATUS_CANCELING],
                           [None, None])
    self.assertRaises(errors.OpExecError, cli.GenericPollJob, job_id, cbs, cbs)
    cbs.CheckEmpty()
770
771
class TestFormatLogMessage(unittest.TestCase):
  """Tests for cli.FormatLogMessage"""

  def test(self):
    # A plain string message is returned unchanged
    self.assertEqual(cli.FormatLogMessage(constants.ELOG_MESSAGE,
                                          "Hello World"),
                     "Hello World")
    # A list payload for ELOG_MESSAGE is rejected
    self.assertRaises(TypeError, cli.FormatLogMessage,
                      constants.ELOG_MESSAGE, [1, 2, 3])

    # Other log types still produce a non-empty (truthy) result
    self.assertTrue(cli.FormatLogMessage("some other type", (1, 2, 3)))
781
782
class TestParseFields(unittest.TestCase):
  """Tests for cli.ParseFields"""

  def test(self):
    # (selected, default, expected); a leading "+" appends to the defaults
    cases = [
      (None, [], []),
      ("name,foo,hello", [], ["name", "foo", "hello"]),
      (None, ["def", "ault", "fields", "here"],
       ["def", "ault", "fields", "here"]),
      ("name,foo", ["def", "ault"], ["name", "foo"]),
      ("+name,foo", ["def", "ault"], ["def", "ault", "name", "foo"]),
      ]
    for (selected, default, expected) in cases:
      self.assertEqual(cli.ParseFields(selected, default), expected)
794
795
class TestConstants(unittest.TestCase):
  """Consistency checks between cli priority tables and constants"""

  def testPriority(self):
    valid = constants.OP_PRIO_SUBMIT_VALID
    self.assertEqual(set(cli._PRIONAME_TO_VALUE.values()), set(valid))

    # Priority names must be listed from highest to lowest value
    listed = [value for (_, value) in cli._PRIORITY_NAMES]
    self.assertEqual(listed, sorted(valid, reverse=True))
802
803
class TestParseNicOption(unittest.TestCase):
  """Tests for cli.ParseNicOption"""

  def test(self):
    """NIC specs are placed at their index, earlier slots left empty"""
    self.assertEqual(cli.ParseNicOption([("0", {"link": "eth0"})]),
                     [{"link": "eth0"}])
    self.assertEqual(cli.ParseNicOption([("5", {"ip": "192.0.2.7"})]),
                     [{}] * 5 + [{"ip": "192.0.2.7"}])

  def testErrors(self):
    for bad in [None, "", "abc", "zero", "Hello World", "\0", []]:
      # Bad value in the index position
      self.assertRaises(errors.OpPrereqError, cli.ParseNicOption,
                        [(bad, {"link": "eth0"})])
      # Bad value in the parameters position
      self.assertRaises(errors.OpPrereqError, cli.ParseNicOption,
                        [("0", bad)])

    for spec in [(0, {True: False}), (3, {"mode": []})]:
      self.assertRaises(errors.TypeEnforcementError, cli.ParseNicOption,
                        [spec])
823
824
class TestFormatResultError(unittest.TestCase):
  """Tests for cli.FormatResultError"""

  def testNormal(self):
    """RS_NORMAL is not an error status and is rejected"""
    for verbose in (False, True):
      self.assertRaises(AssertionError, cli.FormatResultError,
                        constants.RS_NORMAL, verbose)

  def testUnknown(self):
    """Unknown statuses are not implemented"""
    for verbose in (False, True):
      self.assertRaises(NotImplementedError, cli.FormatResultError,
                        "#some!other!status#", verbose)

  def test(self):
    """Terse and verbose forms differ; the verbose one is parenthesized"""
    error_statuses = (s for s in constants.RS_ALL
                      if s != constants.RS_NORMAL)
    for status in error_statuses:
      self.assertNotEqual(cli.FormatResultError(status, False),
                          cli.FormatResultError(status, True))

      text = cli.FormatResultError(status, True)
      self.assertTrue(text.startswith("("))
      self.assertTrue(text.endswith(")"))
847
848
849 class TestGetOnlineNodes(unittest.TestCase):
850   class _FakeClient:
851     def __init__(self):
852       self._query = []
853
854     def AddQueryResult(self, *args):
855       self._query.append(args)
856
857     def CountPending(self):
858       return len(self._query)
859
860     def Query(self, res, fields, qfilter):
861       if res != constants.QR_NODE:
862         raise Exception("Querying wrong resource")
863
864       (exp_fields, check_filter, result) = self._query.pop(0)
865
866       if exp_fields != fields:
867         raise Exception("Expected fields %s, got %s" % (exp_fields, fields))
868
869       if not (qfilter is None or check_filter(qfilter)):
870         raise Exception("Filter doesn't match expectations")
871
872       return objects.QueryResponse(fields=None, data=result)
873
874   def testEmpty(self):
875     cl = self._FakeClient()
876
877     cl.AddQueryResult(["name", "offline", "sip"], None, [])
878     self.assertEqual(cli.GetOnlineNodes(None, cl=cl), [])
879     self.assertEqual(cl.CountPending(), 0)
880
881   def testNoSpecialFilter(self):
882     cl = self._FakeClient()
883
884     cl.AddQueryResult(["name", "offline", "sip"], None, [
885       [(constants.RS_NORMAL, "master.example.com"),
886        (constants.RS_NORMAL, False),
887        (constants.RS_NORMAL, "192.0.2.1")],
888       [(constants.RS_NORMAL, "node2.example.com"),
889        (constants.RS_NORMAL, False),
890        (constants.RS_NORMAL, "192.0.2.2")],
891       ])
892     self.assertEqual(cli.GetOnlineNodes(None, cl=cl),
893                      ["master.example.com", "node2.example.com"])
894     self.assertEqual(cl.CountPending(), 0)
895
896   def testNoMaster(self):
897     cl = self._FakeClient()
898
899     def _CheckFilter(qfilter):
900       self.assertEqual(qfilter, [qlang.OP_NOT, [qlang.OP_TRUE, "master"]])
901       return True
902
903     cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
904       [(constants.RS_NORMAL, "node2.example.com"),
905        (constants.RS_NORMAL, False),
906        (constants.RS_NORMAL, "192.0.2.2")],
907       ])
908     self.assertEqual(cli.GetOnlineNodes(None, cl=cl, filter_master=True),
909                      ["node2.example.com"])
910     self.assertEqual(cl.CountPending(), 0)
911
912   def testSecondaryIpAddress(self):
913     cl = self._FakeClient()
914
915     cl.AddQueryResult(["name", "offline", "sip"], None, [
916       [(constants.RS_NORMAL, "master.example.com"),
917        (constants.RS_NORMAL, False),
918        (constants.RS_NORMAL, "192.0.2.1")],
919       [(constants.RS_NORMAL, "node2.example.com"),
920        (constants.RS_NORMAL, False),
921        (constants.RS_NORMAL, "192.0.2.2")],
922       ])
923     self.assertEqual(cli.GetOnlineNodes(None, cl=cl, secondary_ips=True),
924                      ["192.0.2.1", "192.0.2.2"])
925     self.assertEqual(cl.CountPending(), 0)
926
927   def testNoMasterFilterNodeName(self):
928     cl = self._FakeClient()
929
930     def _CheckFilter(qfilter):
931       self.assertEqual(qfilter,
932         [qlang.OP_AND,
933          [qlang.OP_OR] + [[qlang.OP_EQUAL, "name", name]
934                           for name in ["node2", "node3"]],
935          [qlang.OP_NOT, [qlang.OP_TRUE, "master"]]])
936       return True
937
938     cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
939       [(constants.RS_NORMAL, "node2.example.com"),
940        (constants.RS_NORMAL, False),
941        (constants.RS_NORMAL, "192.0.2.12")],
942       [(constants.RS_NORMAL, "node3.example.com"),
943        (constants.RS_NORMAL, False),
944        (constants.RS_NORMAL, "192.0.2.13")],
945       ])
946     self.assertEqual(cli.GetOnlineNodes(["node2", "node3"], cl=cl,
947                                         secondary_ips=True, filter_master=True),
948                      ["192.0.2.12", "192.0.2.13"])
949     self.assertEqual(cl.CountPending(), 0)
950
951   def testOfflineNodes(self):
952     cl = self._FakeClient()
953
954     cl.AddQueryResult(["name", "offline", "sip"], None, [
955       [(constants.RS_NORMAL, "master.example.com"),
956        (constants.RS_NORMAL, False),
957        (constants.RS_NORMAL, "192.0.2.1")],
958       [(constants.RS_NORMAL, "node2.example.com"),
959        (constants.RS_NORMAL, True),
960        (constants.RS_NORMAL, "192.0.2.2")],
961       [(constants.RS_NORMAL, "node3.example.com"),
962        (constants.RS_NORMAL, True),
963        (constants.RS_NORMAL, "192.0.2.3")],
964       ])
965     self.assertEqual(cli.GetOnlineNodes(None, cl=cl, nowarn=True),
966                      ["master.example.com"])
967     self.assertEqual(cl.CountPending(), 0)
968
969   def testNodeGroup(self):
970     cl = self._FakeClient()
971
972     def _CheckFilter(qfilter):
973       self.assertEqual(qfilter,
974         [qlang.OP_OR, [qlang.OP_EQUAL, "group", "foobar"],
975                       [qlang.OP_EQUAL, "group.uuid", "foobar"]])
976       return True
977
978     cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
979       [(constants.RS_NORMAL, "master.example.com"),
980        (constants.RS_NORMAL, False),
981        (constants.RS_NORMAL, "192.0.2.1")],
982       [(constants.RS_NORMAL, "node3.example.com"),
983        (constants.RS_NORMAL, False),
984        (constants.RS_NORMAL, "192.0.2.3")],
985       ])
986     self.assertEqual(cli.GetOnlineNodes(None, cl=cl, nodegroup="foobar"),
987                      ["master.example.com", "node3.example.com"])
988     self.assertEqual(cl.CountPending(), 0)
989
990
class TestFormatTimestamp(unittest.TestCase):
  """Tests for cli.FormatTimestamp."""

  def testGood(self):
    # (seconds, fraction) pairs and the six-digit suffix each one yields
    samples = [
      (0, 1, ".000001"),
      (1332944009, 17376, ".017376"),
      ]
    for (secs, fraction, suffix) in samples:
      expected = time.strftime("%F %T", time.localtime(secs)) + suffix
      self.assertEqual(cli.FormatTimestamp((secs, fraction)), expected)

  def testWrong(self):
    # Values that are not two-element timestamps are rendered as "?"
    invalid = [0, [], {}, "", [1]]
    for value in invalid:
      self.assertEqual(cli.FormatTimestamp(value), "?")
1002
1003
class TestFormatUsage(unittest.TestCase):
  """Tests for cli._FormatUsage."""

  def test(self):
    binary = "gnt-unittest"
    # Command definitions are 5-tuples; the first four entries are filled
    # with NotImplemented and only the trailing description shows up below
    placeholders = (NotImplemented, ) * 4
    descriptions = {
      "cmdA": "description of A",
      "bbb": "Hello World," * 10,
      "longname": "Another description",
      }
    commands = dict((name, placeholders + (text, ))
                    for (name, text) in descriptions.items())

    # Commands are listed alphabetically; the long "bbb" description wraps
    expected = [
      "Usage: gnt-unittest {command} [options...] [argument...]",
      "gnt-unittest <command> --help to see details, or man gnt-unittest",
      "",
      "Commands:",
      (" bbb      - Hello World,Hello World,Hello World,Hello World,Hello"
       " World,Hello"),
      "            World,Hello World,Hello World,Hello World,Hello World,",
      " cmdA     - description of A",
      " longname - Another description",
      "",
      ]
    self.assertEqual(list(cli._FormatUsage(binary, commands)), expected)
1031
1032
class TestParseArgs(unittest.TestCase):
  """Tests for cli._ParseArgs."""

  def _CheckShowUsage(self, binary, argv, exit_error):
    """Assert that parsing C{argv} raises L{cli._ShowUsage}.

    Parsing is done with no commands and no aliases defined.

    @param binary: program name passed to cli._ParseArgs
    @param argv: argument vector to parse
    @param exit_error: whether the exception's C{exit_error} must be true

    """
    try:
      cli._ParseArgs(binary, argv, {}, {}, set())
    except cli._ShowUsage as err:
      if exit_error:
        self.assertTrue(err.exit_error)
      else:
        self.assertFalse(err.exit_error)
    else:
      self.fail("Did not raise exception")

  def testNoArguments(self):
    for argv in [[], ["gnt-unittest"]]:
      self._CheckShowUsage("gnt-unittest", argv, True)

  def testVersion(self):
    for argv in [["test", "--version"], ["test", "--version", "somethingelse"]]:
      self.assertRaises(cli._ShowVersion, cli._ParseArgs,
                        "test", argv, {}, {}, set())

  def testHelp(self):
    # An explicit --help shows usage without signalling an error
    for argv in [["test", "--help"], ["test", "--help", "somethingelse"]]:
      self._CheckShowUsage("test", argv, False)

  def testUnknownCommandOrAlias(self):
    for argv in [["test", "list"], ["test", "somethingelse", "--help"]]:
      self._CheckShowUsage("test", argv, True)

  def testInvalidAliasList(self):
    cmd = {
      "list": NotImplemented,
      "foo": NotImplemented,
      }
    aliases = {
      "list": NotImplemented,
      "foo": NotImplemented,
      }
    # Precondition on the test data: every alias name is also a command name.
    # A bare assert would be skipped under "python -O".
    self.assertEqual(sorted(cmd.keys()), sorted(aliases.keys()))
    self.assertRaises(AssertionError, cli._ParseArgs, "test",
                      ["test", "list"], cmd, aliases, set())

  def testAliasForNonExistantCommand(self):
    cmd = {}
    aliases = {
      "list": NotImplemented,
      }
    self.assertRaises(errors.ProgrammerError, cli._ParseArgs, "test",
                      ["test", "list"], cmd, aliases, set())
1090
1091
class TestQftNames(unittest.TestCase):
  """Checks for the query field type names table in cli."""

  def testComplete(self):
    # The table must have an entry for every field type, and no others
    self.assertEqual(frozenset(cli._QFT_NAMES.keys()), constants.QFT_ALL)

  def testUnique(self):
    # Names must stay distinct even when compared case-insensitively
    lowered = [name.lower() for name in cli._QFT_NAMES.values()]
    self.assertFalse(utils.FindDuplicates(lowered))

  def testUppercase(self):
    for name in cli._QFT_NAMES.values():
      initial = name[0]
      self.assertEqual(initial, initial.upper())
1103
1104
class TestFieldDescValues(unittest.TestCase):
  """Tests for cli._FieldDescValues."""

  def testKnownKind(self):
    fdef = objects.QueryFieldDefinition(name="aname", title="Atitle",
                                        kind=constants.QFT_TEXT,
                                        doc="aaa doc aaa")
    # A known kind is shown by its name from the table ("Text")
    expected = ["aname", "Text", "Atitle", "aaa doc aaa"]
    self.assertEqual(cli._FieldDescValues(fdef), expected)

  def testUnknownKind(self):
    kind = "#foo#"

    # Make sure the kind really is unknown before relying on it
    self.assertNotIn(kind, constants.QFT_ALL)
    self.assertNotIn(kind, cli._QFT_NAMES)

    fdef = objects.QueryFieldDefinition(name="zname", title="Ztitle",
                                        kind=kind, doc="zzz doc zzz")
    # An unknown kind is passed through unchanged
    expected = ["zname", kind, "Ztitle", "zzz doc zzz"]
    self.assertEqual(cli._FieldDescValues(fdef), expected)
1124
1125
class TestSerializeGenericInfo(unittest.TestCase):
  """Test case for cli._SerializeGenericInfo"""

  def _RunTest(self, data, expected):
    """Serialize C{data} into a fresh buffer and compare with C{expected}."""
    out = StringIO()
    cli._SerializeGenericInfo(out, data, 0)
    self.assertEqual(out.getvalue(), expected)

  def testSimple(self):
    # Lists of pairs keep their order, while dict keys come out sorted
    cases = [
      ("abc", "abc\n"),
      ([], "\n"),
      ({}, "\n"),
      (["1", "2", "3"], "- 1\n- 2\n- 3\n"),
      ([("z", "26")], "z: 26\n"),
      ({"z": "26"}, "z: 26\n"),
      ([("z", "26"), ("a", "1")], "z: 26\na: 1\n"),
      ({"z": "26", "a": "1"}, "a: 1\nz: 26\n"),
      ]
    for case in cases:
      self._RunTest(*case)

  def testLists(self):
    adict = {
      "aa": "11",
      "bb": "22",
      "cc": "33",
      }
    anobj = [
      ("zz", "11"),
      ("ww", "33"),
      ("xx", "22"),
      ]
    alist = ["aa", "cc", "bb"]
    # Expected rendering of each value as one element of a list
    samples = [
      (adict, ("- aa: 11\n"
               "  bb: 22\n"
               "  cc: 33\n")),
      (anobj, ("- zz: 11\n"
               "  ww: 33\n"
               "  xx: 22\n")),
      (alist, ("- - aa\n"
               "  - cc\n"
               "  - bb\n")),
      ]
    for (item, item_expected) in samples:
      # Repeating the item N times must repeat its rendering N times
      for count in range(1, 4):
        self._RunTest([item] * count, item_expected * count)

  def testDictionaries(self):
    pairs = [
      ("aaa", ["x", "y"]),
      ("bbb", {
          "w": "1",
          "z": "2",
          }),
      ("ccc", [
          ("xyz", "123"),
          ("efg", "456"),
          ]),
      ]
    expected = ("aaa: \n"
                "  - x\n"
                "  - y\n"
                "bbb: \n"
                "  w: 1\n"
                "  z: 2\n"
                "ccc: \n"
                "  xyz: 123\n"
                "  efg: 456\n")
    # The list-of-pairs form and its dict equivalent must render the same
    for data in (pairs, dict(pairs)):
      self._RunTest(data, expected)
1202
1203
1204 class TestFormatPolicyInfo(unittest.TestCase):
1205   """Test case for cli.FormatPolicyInfo.
1206
1207   These tests rely on cli._SerializeGenericInfo (tested elsewhere).
1208
1209   """
1210   def setUp(self):
1211     # Policies are big, and we want to see the difference in case of an error
1212     self.maxDiff = None
1213
1214   def _RenameDictItem(self, parsed, old, new):
1215     self.assertTrue(old in parsed)
1216     self.assertTrue(new not in parsed)
1217     parsed[new] = parsed[old]
1218     del parsed[old]
1219
1220   def _TranslateParsedNames(self, parsed):
1221     for (pretty, raw) in [
1222       ("bounds specs", constants.ISPECS_MINMAX),
1223       ("allowed disk templates", constants.IPOLICY_DTS)
1224       ]:
1225       self._RenameDictItem(parsed, pretty, raw)
1226     for minmax in parsed[constants.ISPECS_MINMAX]:
1227       for key in minmax:
1228         keyparts = key.split("/", 1)
1229         if len(keyparts) > 1:
1230           self._RenameDictItem(minmax, key, keyparts[0])
1231     self.assertTrue(constants.IPOLICY_DTS in parsed)
1232     parsed[constants.IPOLICY_DTS] = yaml.load("[%s]" %
1233                                               parsed[constants.IPOLICY_DTS])
1234
1235   @staticmethod
1236   def _PrintAndParsePolicy(custom, effective, iscluster):
1237     formatted = cli.FormatPolicyInfo(custom, effective, iscluster)
1238     buf = StringIO()
1239     cli._SerializeGenericInfo(buf, formatted, 0)
1240     return yaml.load(buf.getvalue())
1241
1242   def _PrintAndCheckParsed(self, policy):
1243     parsed = self._PrintAndParsePolicy(policy, NotImplemented, True)
1244     self._TranslateParsedNames(parsed)
1245     self.assertEqual(parsed, policy)
1246
1247   def _CompareClusterGroupItems(self, cluster, group, skip=None):
1248     if isinstance(group, dict):
1249       self.assertTrue(isinstance(cluster, dict))
1250       if skip is None:
1251         skip = frozenset()
1252       self.assertEqual(frozenset(cluster.keys()).difference(skip),
1253                        frozenset(group.keys()))
1254       for key in group:
1255         self._CompareClusterGroupItems(cluster[key], group[key])
1256     elif isinstance(group, list):
1257       self.assertTrue(isinstance(cluster, list))
1258       self.assertEqual(len(cluster), len(group))
1259       for (cval, gval) in zip(cluster, group):
1260         self._CompareClusterGroupItems(cval, gval)
1261     else:
1262       self.assertTrue(isinstance(group, basestring))
1263       self.assertEqual("default (%s)" % cluster, group)
1264
1265   def _TestClusterVsGroup(self, policy):
1266     cluster = self._PrintAndParsePolicy(policy, NotImplemented, True)
1267     group = self._PrintAndParsePolicy({}, policy, False)
1268     self._CompareClusterGroupItems(cluster, group, ["std"])
1269
1270   def testWithDefaults(self):
1271     self._PrintAndCheckParsed(constants.IPOLICY_DEFAULTS)
1272     self._TestClusterVsGroup(constants.IPOLICY_DEFAULTS)
1273
1274
1275 class TestCreateIPolicyFromOpts(unittest.TestCase):
1276   """Test case for cli.CreateIPolicyFromOpts."""
1277   def setUp(self):
1278     # Policies are big, and we want to see the difference in case of an error
1279     self.maxDiff = None
1280
1281   def _RecursiveCheckMergedDicts(self, default_pol, diff_pol, merged_pol,
1282                                  merge_minmax=False):
1283     self.assertTrue(type(default_pol) is dict)
1284     self.assertTrue(type(diff_pol) is dict)
1285     self.assertTrue(type(merged_pol) is dict)
1286     self.assertEqual(frozenset(default_pol.keys()),
1287                      frozenset(merged_pol.keys()))
1288     for (key, val) in merged_pol.items():
1289       if key in diff_pol:
1290         if type(val) is dict:
1291           self._RecursiveCheckMergedDicts(default_pol[key], diff_pol[key], val)
1292         elif (merge_minmax and key == "minmax" and type(val) is list and
1293               len(val) == 1):
1294           self.assertEqual(len(default_pol[key]), 1)
1295           self.assertEqual(len(diff_pol[key]), 1)
1296           self._RecursiveCheckMergedDicts(default_pol[key][0],
1297                                           diff_pol[key][0], val[0])
1298         else:
1299           self.assertEqual(val, diff_pol[key])
1300       else:
1301         self.assertEqual(val, default_pol[key])
1302
1303   def testClusterPolicy(self):
1304     pol0 = cli.CreateIPolicyFromOpts(
1305       ispecs_mem_size={},
1306       ispecs_cpu_count={},
1307       ispecs_disk_count={},
1308       ispecs_disk_size={},
1309       ispecs_nic_count={},
1310       ipolicy_disk_templates=None,
1311       ipolicy_vcpu_ratio=None,
1312       ipolicy_spindle_ratio=None,
1313       fill_all=True
1314       )
1315     self.assertEqual(pol0, constants.IPOLICY_DEFAULTS)
1316
1317     exp_pol1 = {
1318       constants.ISPECS_MINMAX: [
1319         {
1320           constants.ISPECS_MIN: {
1321             constants.ISPEC_CPU_COUNT: 2,
1322             constants.ISPEC_DISK_COUNT: 1,
1323             },
1324           constants.ISPECS_MAX: {
1325             constants.ISPEC_MEM_SIZE: 12*1024,
1326             constants.ISPEC_DISK_COUNT: 2,
1327             },
1328           },
1329         ],
1330       constants.ISPECS_STD: {
1331         constants.ISPEC_CPU_COUNT: 2,
1332         constants.ISPEC_DISK_COUNT: 2,
1333         },
1334       constants.IPOLICY_VCPU_RATIO: 3.1,
1335       }
1336     pol1 = cli.CreateIPolicyFromOpts(
1337       ispecs_mem_size={"max": "12g"},
1338       ispecs_cpu_count={"min": 2, "std": 2},
1339       ispecs_disk_count={"min": 1, "max": 2, "std": 2},
1340       ispecs_disk_size={},
1341       ispecs_nic_count={},
1342       ipolicy_disk_templates=None,
1343       ipolicy_vcpu_ratio=3.1,
1344       ipolicy_spindle_ratio=None,
1345       fill_all=True
1346       )
1347     self._RecursiveCheckMergedDicts(constants.IPOLICY_DEFAULTS,
1348                                     exp_pol1, pol1, merge_minmax=True)
1349
1350     exp_pol2 = {
1351       constants.ISPECS_MINMAX: [
1352         {
1353           constants.ISPECS_MIN: {
1354             constants.ISPEC_DISK_SIZE: 512,
1355             constants.ISPEC_NIC_COUNT: 2,
1356             },
1357           constants.ISPECS_MAX: {
1358             constants.ISPEC_NIC_COUNT: 3,
1359             },
1360           },
1361         ],
1362       constants.ISPECS_STD: {
1363         constants.ISPEC_CPU_COUNT: 2,
1364         constants.ISPEC_NIC_COUNT: 3,
1365         },
1366       constants.IPOLICY_SPINDLE_RATIO: 1.3,
1367       constants.IPOLICY_DTS: ["templates"],
1368       }
1369     pol2 = cli.CreateIPolicyFromOpts(
1370       ispecs_mem_size={},
1371       ispecs_cpu_count={"std": 2},
1372       ispecs_disk_count={},
1373       ispecs_disk_size={"min": "0.5g"},
1374       ispecs_nic_count={"min": 2, "max": 3, "std": 3},
1375       ipolicy_disk_templates=["templates"],
1376       ipolicy_vcpu_ratio=None,
1377       ipolicy_spindle_ratio=1.3,
1378       fill_all=True
1379       )
1380     self._RecursiveCheckMergedDicts(constants.IPOLICY_DEFAULTS,
1381                                       exp_pol2, pol2, merge_minmax=True)
1382
1383     for fill_all in [False, True]:
1384       exp_pol3 = {
1385         constants.ISPECS_STD: {
1386           constants.ISPEC_CPU_COUNT: 2,
1387           constants.ISPEC_NIC_COUNT: 3,
1388           },
1389         }
1390       pol3 = cli.CreateIPolicyFromOpts(
1391         std_ispecs={
1392           constants.ISPEC_CPU_COUNT: "2",
1393           constants.ISPEC_NIC_COUNT: "3",
1394           },
1395         ipolicy_disk_templates=None,
1396         ipolicy_vcpu_ratio=None,
1397         ipolicy_spindle_ratio=None,
1398         fill_all=fill_all
1399         )
1400       if fill_all:
1401         self._RecursiveCheckMergedDicts(constants.IPOLICY_DEFAULTS,
1402                                         exp_pol3, pol3, merge_minmax=True)
1403       else:
1404         self.assertEqual(pol3, exp_pol3)
1405
1406   def testPartialPolicy(self):
1407     exp_pol0 = objects.MakeEmptyIPolicy()
1408     pol0 = cli.CreateIPolicyFromOpts(
1409       minmax_ispecs=None,
1410       std_ispecs=None,
1411       ipolicy_disk_templates=None,
1412       ipolicy_vcpu_ratio=None,
1413       ipolicy_spindle_ratio=None,
1414       fill_all=False
1415       )
1416     self.assertEqual(pol0, exp_pol0)
1417
1418     exp_pol1 = {
1419       constants.IPOLICY_VCPU_RATIO: 3.1,
1420       }
1421     pol1 = cli.CreateIPolicyFromOpts(
1422       minmax_ispecs=None,
1423       std_ispecs=None,
1424       ipolicy_disk_templates=None,
1425       ipolicy_vcpu_ratio=3.1,
1426       ipolicy_spindle_ratio=None,
1427       fill_all=False
1428       )
1429     self.assertEqual(pol1, exp_pol1)
1430
1431     exp_pol2 = {
1432       constants.IPOLICY_SPINDLE_RATIO: 1.3,
1433       constants.IPOLICY_DTS: ["templates"],
1434       }
1435     pol2 = cli.CreateIPolicyFromOpts(
1436       minmax_ispecs=None,
1437       std_ispecs=None,
1438       ipolicy_disk_templates=["templates"],
1439       ipolicy_vcpu_ratio=None,
1440       ipolicy_spindle_ratio=1.3,
1441       fill_all=False
1442       )
1443     self.assertEqual(pol2, exp_pol2)
1444
1445   def _TestInvalidISpecs(self, minmax_ispecs, std_ispecs, fail=True):
1446     for fill_all in [False, True]:
1447       if fail:
1448         self.assertRaises((errors.OpPrereqError,
1449                            errors.UnitParseError,
1450                            errors.TypeEnforcementError),
1451                           cli.CreateIPolicyFromOpts,
1452                           minmax_ispecs=minmax_ispecs,
1453                           std_ispecs=std_ispecs,
1454                           fill_all=fill_all)
1455       else:
1456         cli.CreateIPolicyFromOpts(minmax_ispecs=minmax_ispecs,
1457                                   std_ispecs=std_ispecs,
1458                                   fill_all=fill_all)
1459
1460   def testInvalidPolicies(self):
1461     self.assertRaises(AssertionError, cli.CreateIPolicyFromOpts,
1462                       std_ispecs={constants.ISPEC_MEM_SIZE: 1024},
1463                       ipolicy_disk_templates=None, ipolicy_vcpu_ratio=None,
1464                       ipolicy_spindle_ratio=None, group_ipolicy=True)
1465     self.assertRaises(errors.OpPrereqError, cli.CreateIPolicyFromOpts,
1466                       ispecs_mem_size={"wrong": "x"}, ispecs_cpu_count={},
1467                       ispecs_disk_count={}, ispecs_disk_size={},
1468                       ispecs_nic_count={}, ipolicy_disk_templates=None,
1469                       ipolicy_vcpu_ratio=None, ipolicy_spindle_ratio=None,
1470                       fill_all=True)
1471     self.assertRaises(errors.TypeEnforcementError, cli.CreateIPolicyFromOpts,
1472                       ispecs_mem_size={}, ispecs_cpu_count={"min": "default"},
1473                       ispecs_disk_count={}, ispecs_disk_size={},
1474                       ispecs_nic_count={}, ipolicy_disk_templates=None,
1475                       ipolicy_vcpu_ratio=None, ipolicy_spindle_ratio=None,
1476                       fill_all=True)
1477
1478     good_mmspecs = [
1479       constants.ISPECS_MINMAX_DEFAULTS,
1480       constants.ISPECS_MINMAX_DEFAULTS,
1481       ]
1482     self._TestInvalidISpecs(good_mmspecs, None, fail=False)
1483     broken_mmspecs = copy.deepcopy(good_mmspecs)
1484     for minmaxpair in broken_mmspecs:
1485       for key in constants.ISPECS_MINMAX_KEYS:
1486         for par in constants.ISPECS_PARAMETERS:
1487           old = minmaxpair[key][par]
1488           del minmaxpair[key][par]
1489           self._TestInvalidISpecs(broken_mmspecs, None)
1490           minmaxpair[key][par] = "invalid"
1491           self._TestInvalidISpecs(broken_mmspecs, None)
1492           minmaxpair[key][par] = old
1493         minmaxpair[key]["invalid_key"] = None
1494         self._TestInvalidISpecs(broken_mmspecs, None)
1495         del minmaxpair[key]["invalid_key"]
1496       minmaxpair["invalid_key"] = None
1497       self._TestInvalidISpecs(broken_mmspecs, None)
1498       del minmaxpair["invalid_key"]
1499       assert broken_mmspecs == good_mmspecs
1500
1501     good_stdspecs = constants.IPOLICY_DEFAULTS[constants.ISPECS_STD]
1502     self._TestInvalidISpecs(None, good_stdspecs, fail=False)
1503     broken_stdspecs = copy.deepcopy(good_stdspecs)
1504     for par in constants.ISPECS_PARAMETERS:
1505       old = broken_stdspecs[par]
1506       broken_stdspecs[par] = "invalid"
1507       self._TestInvalidISpecs(None, broken_stdspecs)
1508       broken_stdspecs[par] = old
1509     broken_stdspecs["invalid_key"] = None
1510     self._TestInvalidISpecs(None, broken_stdspecs)
1511     del broken_stdspecs["invalid_key"]
1512     assert broken_stdspecs == good_stdspecs
1513
1514   def testAllowedValues(self):
1515     allowedv = "blah"
1516     exp_pol1 = {
1517       constants.ISPECS_MINMAX: allowedv,
1518       constants.IPOLICY_DTS: allowedv,
1519       constants.IPOLICY_VCPU_RATIO: allowedv,
1520       constants.IPOLICY_SPINDLE_RATIO: allowedv,
1521       }
1522     pol1 = cli.CreateIPolicyFromOpts(minmax_ispecs=[{allowedv: {}}],
1523                                      std_ispecs=None,
1524                                      ipolicy_disk_templates=allowedv,
1525                                      ipolicy_vcpu_ratio=allowedv,
1526                                      ipolicy_spindle_ratio=allowedv,
1527                                      allowed_values=[allowedv])
1528     self.assertEqual(pol1, exp_pol1)
1529
1530   @staticmethod
1531   def _ConvertSpecToStrings(spec):
1532     ret = {}
1533     for (par, val) in spec.items():
1534       ret[par] = str(val)
1535     return ret
1536
1537   def _CheckNewStyleSpecsCall(self, exp_ipolicy, minmax_ispecs, std_ispecs,
1538                               group_ipolicy, fill_all):
1539     ipolicy = cli.CreateIPolicyFromOpts(minmax_ispecs=minmax_ispecs,
1540                                         std_ispecs=std_ispecs,
1541                                         group_ipolicy=group_ipolicy,
1542                                         fill_all=fill_all)
1543     self.assertEqual(ipolicy, exp_ipolicy)
1544
1545   def _TestFullISpecsInner(self, skel_exp_ipol, exp_minmax, exp_std,
1546                            group_ipolicy, fill_all):
1547     exp_ipol = skel_exp_ipol.copy()
1548     if exp_minmax is not None:
1549       minmax_ispecs = []
1550       for exp_mm_pair in exp_minmax:
1551         mmpair = {}
1552         for (key, spec) in exp_mm_pair.items():
1553           mmpair[key] = self._ConvertSpecToStrings(spec)
1554         minmax_ispecs.append(mmpair)
1555       exp_ipol[constants.ISPECS_MINMAX] = exp_minmax
1556     else:
1557       minmax_ispecs = None
1558     if exp_std is not None:
1559       std_ispecs = self._ConvertSpecToStrings(exp_std)
1560       exp_ipol[constants.ISPECS_STD] = exp_std
1561     else:
1562       std_ispecs = None
1563
1564     self._CheckNewStyleSpecsCall(exp_ipol, minmax_ispecs, std_ispecs,
1565                                  group_ipolicy, fill_all)
1566     if minmax_ispecs:
1567       for mmpair in minmax_ispecs:
1568         for (key, spec) in mmpair.items():
1569           for par in [constants.ISPEC_MEM_SIZE, constants.ISPEC_DISK_SIZE]:
1570             if par in spec:
1571               spec[par] += "m"
1572               self._CheckNewStyleSpecsCall(exp_ipol, minmax_ispecs, std_ispecs,
1573                                            group_ipolicy, fill_all)
1574     if std_ispecs:
1575       for par in [constants.ISPEC_MEM_SIZE, constants.ISPEC_DISK_SIZE]:
1576         if par in std_ispecs:
1577           std_ispecs[par] += "m"
1578           self._CheckNewStyleSpecsCall(exp_ipol, minmax_ispecs, std_ispecs,
1579                                        group_ipolicy, fill_all)
1580
1581   def testFullISpecs(self):
1582     exp_minmax1 = [
1583       {
1584         constants.ISPECS_MIN: {
1585           constants.ISPEC_MEM_SIZE: 512,
1586           constants.ISPEC_CPU_COUNT: 2,
1587           constants.ISPEC_DISK_COUNT: 2,
1588           constants.ISPEC_DISK_SIZE: 512,
1589           constants.ISPEC_NIC_COUNT: 2,
1590           constants.ISPEC_SPINDLE_USE: 2,
1591           },
1592         constants.ISPECS_MAX: {
1593           constants.ISPEC_MEM_SIZE: 768*1024,
1594           constants.ISPEC_CPU_COUNT: 7,
1595           constants.ISPEC_DISK_COUNT: 6,
1596           constants.ISPEC_DISK_SIZE: 2048*1024,
1597           constants.ISPEC_NIC_COUNT: 3,
1598           constants.ISPEC_SPINDLE_USE: 3,
1599           },
1600         },
1601       ]
1602     exp_minmax2 = [
1603       {
1604         constants.ISPECS_MIN: {
1605           constants.ISPEC_MEM_SIZE: 512,
1606           constants.ISPEC_CPU_COUNT: 2,
1607           constants.ISPEC_DISK_COUNT: 2,
1608           constants.ISPEC_DISK_SIZE: 512,
1609           constants.ISPEC_NIC_COUNT: 2,
1610           constants.ISPEC_SPINDLE_USE: 2,
1611           },
1612         constants.ISPECS_MAX: {
1613           constants.ISPEC_MEM_SIZE: 768*1024,
1614           constants.ISPEC_CPU_COUNT: 7,
1615           constants.ISPEC_DISK_COUNT: 6,
1616           constants.ISPEC_DISK_SIZE: 2048*1024,
1617           constants.ISPEC_NIC_COUNT: 3,
1618           constants.ISPEC_SPINDLE_USE: 3,
1619           },
1620         },
1621       {
1622         constants.ISPECS_MIN: {
1623           constants.ISPEC_MEM_SIZE: 1024*1024,
1624           constants.ISPEC_CPU_COUNT: 3,
1625           constants.ISPEC_DISK_COUNT: 3,
1626           constants.ISPEC_DISK_SIZE: 256,
1627           constants.ISPEC_NIC_COUNT: 4,
1628           constants.ISPEC_SPINDLE_USE: 5,
1629           },
1630         constants.ISPECS_MAX: {
1631           constants.ISPEC_MEM_SIZE: 2048*1024,
1632           constants.ISPEC_CPU_COUNT: 5,
1633           constants.ISPEC_DISK_COUNT: 5,
1634           constants.ISPEC_DISK_SIZE: 1024*1024,
1635           constants.ISPEC_NIC_COUNT: 5,
1636           constants.ISPEC_SPINDLE_USE: 7,
1637           },
1638         },
1639       ]
1640     exp_std1 = {
1641       constants.ISPEC_MEM_SIZE: 768*1024,
1642       constants.ISPEC_CPU_COUNT: 7,
1643       constants.ISPEC_DISK_COUNT: 6,
1644       constants.ISPEC_DISK_SIZE: 2048*1024,
1645       constants.ISPEC_NIC_COUNT: 3,
1646       constants.ISPEC_SPINDLE_USE: 1,
1647       }
1648     for fill_all in [False, True]:
1649       if fill_all:
1650         skel_ipolicy = constants.IPOLICY_DEFAULTS
1651       else:
1652         skel_ipolicy = {}
1653       self._TestFullISpecsInner(skel_ipolicy, None, exp_std1,
1654                                 False, fill_all)
1655       for exp_minmax in [exp_minmax1, exp_minmax2]:
1656         self._TestFullISpecsInner(skel_ipolicy, exp_minmax, exp_std1,
1657                                   False, fill_all)
1658         self._TestFullISpecsInner(skel_ipolicy, exp_minmax, None,
1659                                   False, fill_all)
1660
1661
class TestPrintIPolicyCommand(unittest.TestCase):
  """Test case for cli.PrintIPolicyCommand.

  Every check renders an instance-policy dict into an in-memory buffer and
  compares the produced text with the expected command-line fragment.

  """
  # Spec fixtures paired with their expected "key=value,..." renderings
  # (keys appear in sorted order in the expected strings).
  _SPECS1 = dict(par1=42, par2="xyz")
  _SPECS1_STR = "par1=42,par2=xyz"
  _SPECS2 = dict(param=10, another_param=101)
  _SPECS2_STR = "another_param=101,param=10"
  _SPECS3 = dict(par1=1024, param="abc")
  _SPECS3_STR = "par1=1024,param=abc"

  def _CheckPrintIPolicyCommand(self, ipolicy, isgroup, expected):
    """Run cli.PrintIPolicyCommand on C{ipolicy} and check what it wrote.

    @param ipolicy: instance policy dict to render
    @param isgroup: forwarded unchanged as the function's group flag
    @param expected: exact text the function must write to the buffer

    """
    out = StringIO()
    cli.PrintIPolicyCommand(out, ipolicy, isgroup)
    actual = out.getvalue()
    self.assertEqual(actual, expected)

  def testIgnoreStdForGroup(self):
    """With the group flag set, "std" specs alone produce no output."""
    policy = {"std": self._SPECS1}
    self._CheckPrintIPolicyCommand(policy, True, "")

  def testIgnoreEmpty(self):
    """Empty or incomplete policies produce no output."""
    blank_pair = {"min": {}, "max": {}}
    half_pair = {"min": self._SPECS1, "max": {}}
    policies = [
      {},
      {"std": {}},
      {"minmax": []},
      {"minmax": [{}]},
      {"minmax": [blank_pair]},
      {"minmax": [half_pair]},
      ]
    for policy in policies:
      self._CheckPrintIPolicyCommand(policy, False, "")

  def testFullPolicies(self):
    """Populated "std" and "minmax" entries are rendered as options."""
    def _Bounds(*pairs):
      # Each (min, max) pair becomes "min:X/max:Y"; pairs are joined by "//"
      ranges = "//".join("min:%s/max:%s" % pair for pair in pairs)
      return " %s %s" % (cli.IPOLICY_BOUNDS_SPECS_STR, ranges)

    s1, s2, s3 = self._SPECS1, self._SPECS2, self._SPECS3
    s1_str, s2_str, s3_str = (self._SPECS1_STR, self._SPECS2_STR,
                              self._SPECS3_STR)
    cases = [
      ({"std": s1},
       " %s %s" % (cli.IPOLICY_STD_SPECS_STR, s1_str)),
      ({"minmax": [{"min": s1, "max": s2}]},
       _Bounds((s1_str, s2_str))),
      ({"minmax": [{"min": s1, "max": s2}, {"min": s2, "max": s3}]},
       _Bounds((s1_str, s2_str), (s2_str, s3_str))),
      ]
    for policy, expected in cases:
      self._CheckPrintIPolicyCommand(policy, False, expected)
1732
1733
if __name__ == "__main__":
  # When this file is executed directly, hand control to
  # testutils.GanetiTestProgram (presumably a unittest.main-style runner
  # for this module's test cases -- confirm in testutils).
  testutils.GanetiTestProgram()