Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.cli_unittest.py @ 41044e04

History | View | Annotate | Download (53 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2008, 2011, 2012, 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the cli module"""
23

    
24
import copy
25
import testutils
26
import time
27
import unittest
28
from cStringIO import StringIO
29

    
30
from ganeti import constants
31
from ganeti import cli
32
from ganeti import errors
33
from ganeti import utils
34
from ganeti import objects
35
from ganeti import qlang
36
from ganeti.errors import OpPrereqError, ParameterError
37

    
38

    
39
class TestParseTimespec(unittest.TestCase):
40
  """Testing case for ParseTimespec"""
41

    
42
  def testValidTimes(self):
43
    """Test valid timespecs"""
44
    test_data = [
45
      ("1s", 1),
46
      ("1", 1),
47
      ("1m", 60),
48
      ("1h", 60 * 60),
49
      ("1d", 60 * 60 * 24),
50
      ("1w", 60 * 60 * 24 * 7),
51
      ("4h", 4 * 60 * 60),
52
      ("61m", 61 * 60),
53
      ]
54
    for value, expected_result in test_data:
55
      self.failUnlessEqual(cli.ParseTimespec(value), expected_result)
56

    
57
  def testInvalidTime(self):
58
    """Test invalid timespecs"""
59
    test_data = [
60
      "1y",
61
      "",
62
      "aaa",
63
      "s",
64
      ]
65
    for value in test_data:
66
      self.failUnlessRaises(OpPrereqError, cli.ParseTimespec, value)
67

    
68

    
69
class TestSplitKeyVal(unittest.TestCase):
70
  """Testing case for cli._SplitKeyVal"""
71
  DATA = "a=b,c,no_d,-e"
72
  RESULT = {"a": "b", "c": True, "d": False, "e": None}
73
  RESULT_NOPREFIX = {"a": "b", "c": {}, "no_d": {}, "-e": {}}
74

    
75
  def testSplitKeyVal(self):
76
    """Test splitting"""
77
    self.failUnlessEqual(cli._SplitKeyVal("option", self.DATA, True),
78
                         self.RESULT)
79

    
80
  def testDuplicateParam(self):
81
    """Test duplicate parameters"""
82
    for data in ("a=1,a=2", "a,no_a"):
83
      self.failUnlessRaises(ParameterError, cli._SplitKeyVal,
84
                            "option", data, True)
85

    
86
  def testEmptyData(self):
87
    """Test how we handle splitting an empty string"""
88
    self.failUnlessEqual(cli._SplitKeyVal("option", "", True), {})
89

    
90

    
91
class TestIdentKeyVal(unittest.TestCase):
92
  """Testing case for cli.check_ident_key_val"""
93

    
94
  def testIdentKeyVal(self):
95
    """Test identkeyval"""
96
    def cikv(value):
97
      return cli.check_ident_key_val("option", "opt", value)
98

    
99
    self.assertEqual(cikv("foo:bar"), ("foo", {"bar": True}))
100
    self.assertEqual(cikv("foo:bar=baz"), ("foo", {"bar": "baz"}))
101
    self.assertEqual(cikv("bar:b=c,c=a"), ("bar", {"b": "c", "c": "a"}))
102
    self.assertEqual(cikv("no_bar"), ("bar", False))
103
    self.assertRaises(ParameterError, cikv, "no_bar:foo")
104
    self.assertRaises(ParameterError, cikv, "no_bar:foo=baz")
105
    self.assertRaises(ParameterError, cikv, "bar:foo=baz,foo=baz")
106
    self.assertEqual(cikv("-foo"), ("foo", None))
107
    self.assertRaises(ParameterError, cikv, "-foo:a=c")
108

    
109
    # Check negative numbers
110
    self.assertEqual(cikv("-1:remove"), ("-1", {
111
      "remove": True,
112
      }))
113
    self.assertEqual(cikv("-29447:add,size=4G"), ("-29447", {
114
      "add": True,
115
      "size": "4G",
116
      }))
117
    for i in ["-:", "-"]:
118
      self.assertEqual(cikv(i), ("", None))
119

    
120
  @staticmethod
121
  def _csikv(value):
122
    return cli._SplitIdentKeyVal("opt", value, False)
123

    
124
  def testIdentKeyValNoPrefix(self):
125
    """Test identkeyval without prefixes"""
126
    test_cases = [
127
      ("foo:bar", None),
128
      ("foo:no_bar", None),
129
      ("foo:bar=baz,bar=baz", None),
130
      ("foo",
131
       ("foo", {})),
132
      ("foo:bar=baz",
133
       ("foo", {"bar": "baz"})),
134
      ("no_foo:-1=baz,no_op=3",
135
       ("no_foo", {"-1": "baz", "no_op": "3"})),
136
      ]
137
    for (arg, res) in test_cases:
138
      if res is None:
139
        self.assertRaises(ParameterError, self._csikv, arg)
140
      else:
141
        self.assertEqual(self._csikv(arg), res)
142

    
143

    
144
class TestListIdentKeyVal(unittest.TestCase):
145
  """Test for cli.check_list_ident_key_val()"""
146

    
147
  @staticmethod
148
  def _clikv(value):
149
    return cli.check_list_ident_key_val("option", "opt", value)
150

    
151
  def testListIdentKeyVal(self):
152
    test_cases = [
153
      ("",
154
       None),
155
      ("foo",
156
       {"foo": {}}),
157
      ("foo:bar=baz",
158
       {"foo": {"bar": "baz"}}),
159
      ("foo:bar=baz/foo:bat=bad",
160
       None),
161
      ("foo:abc=42/bar:def=11",
162
       {"foo": {"abc": "42"},
163
        "bar": {"def": "11"}}),
164
      ("foo:abc=42/bar:def=11,ghi=07",
165
       {"foo": {"abc": "42"},
166
        "bar": {"def": "11", "ghi": "07"}}),
167
      ]
168
    for (arg, res) in test_cases:
169
      if res is None:
170
        self.assertRaises(ParameterError, self._clikv, arg)
171
      else:
172
        self.assertEqual(res, self._clikv(arg))
173

    
174

    
175
class TestToStream(unittest.TestCase):
176
  """Test the ToStream functions"""
177

    
178
  def testBasic(self):
179
    for data in ["foo",
180
                 "foo %s",
181
                 "foo %(test)s",
182
                 "foo %s %s",
183
                 "",
184
                 ]:
185
      buf = StringIO()
186
      cli._ToStream(buf, data)
187
      self.failUnlessEqual(buf.getvalue(), data + "\n")
188

    
189
  def testParams(self):
190
      buf = StringIO()
191
      cli._ToStream(buf, "foo %s", 1)
192
      self.failUnlessEqual(buf.getvalue(), "foo 1\n")
193
      buf = StringIO()
194
      cli._ToStream(buf, "foo %s", (15,16))
195
      self.failUnlessEqual(buf.getvalue(), "foo (15, 16)\n")
196
      buf = StringIO()
197
      cli._ToStream(buf, "foo %s %s", "a", "b")
198
      self.failUnlessEqual(buf.getvalue(), "foo a b\n")
199

    
200

    
201
class TestGenerateTable(unittest.TestCase):
202
  HEADERS = dict([("f%s" % i, "Field%s" % i) for i in range(5)])
203

    
204
  FIELDS1 = ["f1", "f2"]
205
  DATA1 = [
206
    ["abc", 1234],
207
    ["foobar", 56],
208
    ["b", -14],
209
    ]
210

    
211
  def _test(self, headers, fields, separator, data,
212
            numfields, unitfields, units, expected):
213
    table = cli.GenerateTable(headers, fields, separator, data,
214
                              numfields=numfields, unitfields=unitfields,
215
                              units=units)
216
    self.assertEqual(table, expected)
217

    
218
  def testPlain(self):
219
    exp = [
220
      "Field1 Field2",
221
      "abc    1234",
222
      "foobar 56",
223
      "b      -14",
224
      ]
225
    self._test(self.HEADERS, self.FIELDS1, None, self.DATA1,
226
               None, None, "m", exp)
227

    
228
  def testNoFields(self):
229
    self._test(self.HEADERS, [], None, [[], []],
230
               None, None, "m", ["", "", ""])
231
    self._test(None, [], None, [[], []],
232
               None, None, "m", ["", ""])
233

    
234
  def testSeparator(self):
235
    for sep in ["#", ":", ",", "^", "!", "%", "|", "###", "%%", "!!!", "||"]:
236
      exp = [
237
        "Field1%sField2" % sep,
238
        "abc%s1234" % sep,
239
        "foobar%s56" % sep,
240
        "b%s-14" % sep,
241
        ]
242
      self._test(self.HEADERS, self.FIELDS1, sep, self.DATA1,
243
                 None, None, "m", exp)
244

    
245
  def testNoHeader(self):
246
    exp = [
247
      "abc    1234",
248
      "foobar 56",
249
      "b      -14",
250
      ]
251
    self._test(None, self.FIELDS1, None, self.DATA1,
252
               None, None, "m", exp)
253

    
254
  def testUnknownField(self):
255
    headers = {
256
      "f1": "Field1",
257
      }
258
    exp = [
259
      "Field1 UNKNOWN",
260
      "abc    1234",
261
      "foobar 56",
262
      "b      -14",
263
      ]
264
    self._test(headers, ["f1", "UNKNOWN"], None, self.DATA1,
265
               None, None, "m", exp)
266

    
267
  def testNumfields(self):
268
    fields = ["f1", "f2", "f3"]
269
    data = [
270
      ["abc", 1234, 0],
271
      ["foobar", 56, 3],
272
      ["b", -14, "-"],
273
      ]
274
    exp = [
275
      "Field1 Field2 Field3",
276
      "abc      1234      0",
277
      "foobar     56      3",
278
      "b         -14      -",
279
      ]
280
    self._test(self.HEADERS, fields, None, data,
281
               ["f2", "f3"], None, "m", exp)
282

    
283
  def testUnitfields(self):
284
    expnosep = [
285
      "Field1 Field2 Field3",
286
      "abc      1234     0M",
287
      "foobar     56     3M",
288
      "b         -14      -",
289
      ]
290

    
291
    expsep = [
292
      "Field1:Field2:Field3",
293
      "abc:1234:0M",
294
      "foobar:56:3M",
295
      "b:-14:-",
296
      ]
297

    
298
    for sep, expected in [(None, expnosep), (":", expsep)]:
299
      fields = ["f1", "f2", "f3"]
300
      data = [
301
        ["abc", 1234, 0],
302
        ["foobar", 56, 3],
303
        ["b", -14, "-"],
304
        ]
305
      self._test(self.HEADERS, fields, sep, data,
306
                 ["f2", "f3"], ["f3"], "h", expected)
307

    
308
  def testUnusual(self):
309
    data = [
310
      ["%", "xyz"],
311
      ["%%", "abc"],
312
      ]
313
    exp = [
314
      "Field1 Field2",
315
      "%      xyz",
316
      "%%     abc",
317
      ]
318
    self._test(self.HEADERS, ["f1", "f2"], None, data,
319
               None, None, "m", exp)
320

    
321

    
322
class TestFormatQueryResult(unittest.TestCase):
323
  def test(self):
324
    fields = [
325
      objects.QueryFieldDefinition(name="name", title="Name",
326
                                   kind=constants.QFT_TEXT),
327
      objects.QueryFieldDefinition(name="size", title="Size",
328
                                   kind=constants.QFT_NUMBER),
329
      objects.QueryFieldDefinition(name="act", title="Active",
330
                                   kind=constants.QFT_BOOL),
331
      objects.QueryFieldDefinition(name="mem", title="Memory",
332
                                   kind=constants.QFT_UNIT),
333
      objects.QueryFieldDefinition(name="other", title="SomeList",
334
                                   kind=constants.QFT_OTHER),
335
      ]
336

    
337
    response = objects.QueryResponse(fields=fields, data=[
338
      [(constants.RS_NORMAL, "nodeA"), (constants.RS_NORMAL, 128),
339
       (constants.RS_NORMAL, False), (constants.RS_NORMAL, 1468006),
340
       (constants.RS_NORMAL, [])],
341
      [(constants.RS_NORMAL, "other"), (constants.RS_NORMAL, 512),
342
       (constants.RS_NORMAL, True), (constants.RS_NORMAL, 16),
343
       (constants.RS_NORMAL, [1, 2, 3])],
344
      [(constants.RS_NORMAL, "xyz"), (constants.RS_NORMAL, 1024),
345
       (constants.RS_NORMAL, True), (constants.RS_NORMAL, 4096),
346
       (constants.RS_NORMAL, [{}, {}])],
347
      ])
348

    
349
    self.assertEqual(cli.FormatQueryResult(response, unit="h", header=True),
350
      (cli.QR_NORMAL, [
351
      "Name  Size Active Memory SomeList",
352
      "nodeA  128 N        1.4T []",
353
      "other  512 Y         16M [1, 2, 3]",
354
      "xyz   1024 Y        4.0G [{}, {}]",
355
      ]))
356

    
357
  def testTimestampAndUnit(self):
358
    fields = [
359
      objects.QueryFieldDefinition(name="name", title="Name",
360
                                   kind=constants.QFT_TEXT),
361
      objects.QueryFieldDefinition(name="size", title="Size",
362
                                   kind=constants.QFT_UNIT),
363
      objects.QueryFieldDefinition(name="mtime", title="ModTime",
364
                                   kind=constants.QFT_TIMESTAMP),
365
      ]
366

    
367
    response = objects.QueryResponse(fields=fields, data=[
368
      [(constants.RS_NORMAL, "a"), (constants.RS_NORMAL, 1024),
369
       (constants.RS_NORMAL, 0)],
370
      [(constants.RS_NORMAL, "b"), (constants.RS_NORMAL, 144996),
371
       (constants.RS_NORMAL, 1291746295)],
372
      ])
373

    
374
    self.assertEqual(cli.FormatQueryResult(response, unit="m", header=True),
375
      (cli.QR_NORMAL, [
376
      "Name   Size ModTime",
377
      "a      1024 %s" % utils.FormatTime(0),
378
      "b    144996 %s" % utils.FormatTime(1291746295),
379
      ]))
380

    
381
  def testOverride(self):
382
    fields = [
383
      objects.QueryFieldDefinition(name="name", title="Name",
384
                                   kind=constants.QFT_TEXT),
385
      objects.QueryFieldDefinition(name="cust", title="Custom",
386
                                   kind=constants.QFT_OTHER),
387
      objects.QueryFieldDefinition(name="xt", title="XTime",
388
                                   kind=constants.QFT_TIMESTAMP),
389
      ]
390

    
391
    response = objects.QueryResponse(fields=fields, data=[
392
      [(constants.RS_NORMAL, "x"), (constants.RS_NORMAL, ["a", "b", "c"]),
393
       (constants.RS_NORMAL, 1234)],
394
      [(constants.RS_NORMAL, "y"), (constants.RS_NORMAL, range(10)),
395
       (constants.RS_NORMAL, 1291746295)],
396
      ])
397

    
398
    override = {
399
      "cust": (utils.CommaJoin, False),
400
      "xt": (hex, True),
401
      }
402

    
403
    self.assertEqual(cli.FormatQueryResult(response, unit="h", header=True,
404
                                           format_override=override),
405
      (cli.QR_NORMAL, [
406
      "Name Custom                            XTime",
407
      "x    a, b, c                           0x4d2",
408
      "y    0, 1, 2, 3, 4, 5, 6, 7, 8, 9 0x4cfe7bf7",
409
      ]))
410

    
411
  def testSeparator(self):
412
    fields = [
413
      objects.QueryFieldDefinition(name="name", title="Name",
414
                                   kind=constants.QFT_TEXT),
415
      objects.QueryFieldDefinition(name="count", title="Count",
416
                                   kind=constants.QFT_NUMBER),
417
      objects.QueryFieldDefinition(name="desc", title="Description",
418
                                   kind=constants.QFT_TEXT),
419
      ]
420

    
421
    response = objects.QueryResponse(fields=fields, data=[
422
      [(constants.RS_NORMAL, "instance1.example.com"),
423
       (constants.RS_NORMAL, 21125), (constants.RS_NORMAL, "Hello World!")],
424
      [(constants.RS_NORMAL, "mail.other.net"),
425
       (constants.RS_NORMAL, -9000), (constants.RS_NORMAL, "a,b,c")],
426
      ])
427

    
428
    for sep in [":", "|", "#", "|||", "###", "@@@", "@#@"]:
429
      for header in [None, "Name%sCount%sDescription" % (sep, sep)]:
430
        exp = []
431
        if header:
432
          exp.append(header)
433
        exp.extend([
434
          "instance1.example.com%s21125%sHello World!" % (sep, sep),
435
          "mail.other.net%s-9000%sa,b,c" % (sep, sep),
436
          ])
437

    
438
        self.assertEqual(cli.FormatQueryResult(response, separator=sep,
439
                                               header=bool(header)),
440
                         (cli.QR_NORMAL, exp))
441

    
442
  def testStatusWithUnknown(self):
443
    fields = [
444
      objects.QueryFieldDefinition(name="id", title="ID",
445
                                   kind=constants.QFT_NUMBER),
446
      objects.QueryFieldDefinition(name="unk", title="unk",
447
                                   kind=constants.QFT_UNKNOWN),
448
      objects.QueryFieldDefinition(name="unavail", title="Unavail",
449
                                   kind=constants.QFT_BOOL),
450
      objects.QueryFieldDefinition(name="nodata", title="NoData",
451
                                   kind=constants.QFT_TEXT),
452
      objects.QueryFieldDefinition(name="offline", title="OffLine",
453
                                   kind=constants.QFT_TEXT),
454
      ]
455

    
456
    response = objects.QueryResponse(fields=fields, data=[
457
      [(constants.RS_NORMAL, 1), (constants.RS_UNKNOWN, None),
458
       (constants.RS_NORMAL, False), (constants.RS_NORMAL, ""),
459
       (constants.RS_OFFLINE, None)],
460
      [(constants.RS_NORMAL, 2), (constants.RS_UNKNOWN, None),
461
       (constants.RS_NODATA, None), (constants.RS_NORMAL, "x"),
462
       (constants.RS_OFFLINE, None)],
463
      [(constants.RS_NORMAL, 3), (constants.RS_UNKNOWN, None),
464
       (constants.RS_NORMAL, False), (constants.RS_UNAVAIL, None),
465
       (constants.RS_OFFLINE, None)],
466
      ])
467

    
468
    self.assertEqual(cli.FormatQueryResult(response, header=True,
469
                                           separator="|", verbose=True),
470
      (cli.QR_UNKNOWN, [
471
      "ID|unk|Unavail|NoData|OffLine",
472
      "1|(unknown)|N||(offline)",
473
      "2|(unknown)|(nodata)|x|(offline)",
474
      "3|(unknown)|N|(unavail)|(offline)",
475
      ]))
476
    self.assertEqual(cli.FormatQueryResult(response, header=True,
477
                                           separator="|", verbose=False),
478
      (cli.QR_UNKNOWN, [
479
      "ID|unk|Unavail|NoData|OffLine",
480
      "1|??|N||*",
481
      "2|??|?|x|*",
482
      "3|??|N|-|*",
483
      ]))
484

    
485
  def testNoData(self):
486
    fields = [
487
      objects.QueryFieldDefinition(name="id", title="ID",
488
                                   kind=constants.QFT_NUMBER),
489
      objects.QueryFieldDefinition(name="name", title="Name",
490
                                   kind=constants.QFT_TEXT),
491
      ]
492

    
493
    response = objects.QueryResponse(fields=fields, data=[])
494

    
495
    self.assertEqual(cli.FormatQueryResult(response, header=True),
496
                     (cli.QR_NORMAL, ["ID Name"]))
497

    
498
  def testNoDataWithUnknown(self):
499
    fields = [
500
      objects.QueryFieldDefinition(name="id", title="ID",
501
                                   kind=constants.QFT_NUMBER),
502
      objects.QueryFieldDefinition(name="unk", title="unk",
503
                                   kind=constants.QFT_UNKNOWN),
504
      ]
505

    
506
    response = objects.QueryResponse(fields=fields, data=[])
507

    
508
    self.assertEqual(cli.FormatQueryResult(response, header=False),
509
                     (cli.QR_UNKNOWN, []))
510

    
511
  def testStatus(self):
512
    fields = [
513
      objects.QueryFieldDefinition(name="id", title="ID",
514
                                   kind=constants.QFT_NUMBER),
515
      objects.QueryFieldDefinition(name="unavail", title="Unavail",
516
                                   kind=constants.QFT_BOOL),
517
      objects.QueryFieldDefinition(name="nodata", title="NoData",
518
                                   kind=constants.QFT_TEXT),
519
      objects.QueryFieldDefinition(name="offline", title="OffLine",
520
                                   kind=constants.QFT_TEXT),
521
      ]
522

    
523
    response = objects.QueryResponse(fields=fields, data=[
524
      [(constants.RS_NORMAL, 1), (constants.RS_NORMAL, False),
525
       (constants.RS_NORMAL, ""), (constants.RS_OFFLINE, None)],
526
      [(constants.RS_NORMAL, 2), (constants.RS_NODATA, None),
527
       (constants.RS_NORMAL, "x"), (constants.RS_NORMAL, "abc")],
528
      [(constants.RS_NORMAL, 3), (constants.RS_NORMAL, False),
529
       (constants.RS_UNAVAIL, None), (constants.RS_OFFLINE, None)],
530
      ])
531

    
532
    self.assertEqual(cli.FormatQueryResult(response, header=False,
533
                                           separator="|", verbose=True),
534
      (cli.QR_INCOMPLETE, [
535
      "1|N||(offline)",
536
      "2|(nodata)|x|abc",
537
      "3|N|(unavail)|(offline)",
538
      ]))
539
    self.assertEqual(cli.FormatQueryResult(response, header=False,
540
                                           separator="|", verbose=False),
541
      (cli.QR_INCOMPLETE, [
542
      "1|N||*",
543
      "2|?|x|abc",
544
      "3|N|-|*",
545
      ]))
546

    
547
  def testInvalidFieldType(self):
548
    fields = [
549
      objects.QueryFieldDefinition(name="x", title="x",
550
                                   kind="#some#other#type"),
551
      ]
552

    
553
    response = objects.QueryResponse(fields=fields, data=[])
554

    
555
    self.assertRaises(NotImplementedError, cli.FormatQueryResult, response)
556

    
557
  def testInvalidFieldStatus(self):
558
    fields = [
559
      objects.QueryFieldDefinition(name="x", title="x",
560
                                   kind=constants.QFT_TEXT),
561
      ]
562

    
563
    response = objects.QueryResponse(fields=fields, data=[[(-1, None)]])
564
    self.assertRaises(NotImplementedError, cli.FormatQueryResult, response)
565

    
566
    response = objects.QueryResponse(fields=fields, data=[[(-1, "x")]])
567
    self.assertRaises(AssertionError, cli.FormatQueryResult, response)
568

    
569
  def testEmptyFieldTitle(self):
570
    fields = [
571
      objects.QueryFieldDefinition(name="x", title="",
572
                                   kind=constants.QFT_TEXT),
573
      ]
574

    
575
    response = objects.QueryResponse(fields=fields, data=[])
576
    self.assertRaises(AssertionError, cli.FormatQueryResult, response)
577

    
578

    
579
class _MockJobPollCb(cli.JobPollCbBase, cli.JobPollReportCbBase):
580
  def __init__(self, tc, job_id):
581
    self.tc = tc
582
    self.job_id = job_id
583
    self._wfjcr = []
584
    self._jobstatus = []
585
    self._expect_notchanged = False
586
    self._expect_log = []
587

    
588
  def CheckEmpty(self):
589
    self.tc.assertFalse(self._wfjcr)
590
    self.tc.assertFalse(self._jobstatus)
591
    self.tc.assertFalse(self._expect_notchanged)
592
    self.tc.assertFalse(self._expect_log)
593

    
594
  def AddWfjcResult(self, *args):
595
    self._wfjcr.append(args)
596

    
597
  def AddQueryJobsResult(self, *args):
598
    self._jobstatus.append(args)
599

    
600
  def WaitForJobChangeOnce(self, job_id, fields,
601
                           prev_job_info, prev_log_serial):
602
    self.tc.assertEqual(job_id, self.job_id)
603
    self.tc.assertEqualValues(fields, ["status"])
604
    self.tc.assertFalse(self._expect_notchanged)
605
    self.tc.assertFalse(self._expect_log)
606

    
607
    (exp_prev_job_info, exp_prev_log_serial, result) = self._wfjcr.pop(0)
608
    self.tc.assertEqualValues(prev_job_info, exp_prev_job_info)
609
    self.tc.assertEqual(prev_log_serial, exp_prev_log_serial)
610

    
611
    if result == constants.JOB_NOTCHANGED:
612
      self._expect_notchanged = True
613
    elif result:
614
      (_, logmsgs) = result
615
      if logmsgs:
616
        self._expect_log.extend(logmsgs)
617

    
618
    return result
619

    
620
  def QueryJobs(self, job_ids, fields):
621
    self.tc.assertEqual(job_ids, [self.job_id])
622
    self.tc.assertEqualValues(fields, ["status", "opstatus", "opresult"])
623
    self.tc.assertFalse(self._expect_notchanged)
624
    self.tc.assertFalse(self._expect_log)
625

    
626
    result = self._jobstatus.pop(0)
627
    self.tc.assertEqual(len(fields), len(result))
628
    return [result]
629

    
630
  def ReportLogMessage(self, job_id, serial, timestamp, log_type, log_msg):
631
    self.tc.assertEqual(job_id, self.job_id)
632
    self.tc.assertEqualValues((serial, timestamp, log_type, log_msg),
633
                              self._expect_log.pop(0))
634

    
635
  def ReportNotChanged(self, job_id, status):
636
    self.tc.assertEqual(job_id, self.job_id)
637
    self.tc.assert_(self._expect_notchanged)
638
    self._expect_notchanged = False
639

    
640

    
641
class TestGenericPollJob(testutils.GanetiTestCase):
642
  def testSuccessWithLog(self):
643
    job_id = 29609
644
    cbs = _MockJobPollCb(self, job_id)
645

    
646
    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
647

    
648
    cbs.AddWfjcResult(None, None,
649
                      ((constants.JOB_STATUS_QUEUED, ), None))
650

    
651
    cbs.AddWfjcResult((constants.JOB_STATUS_QUEUED, ), None,
652
                      constants.JOB_NOTCHANGED)
653

    
654
    cbs.AddWfjcResult((constants.JOB_STATUS_QUEUED, ), None,
655
                      ((constants.JOB_STATUS_RUNNING, ),
656
                       [(1, utils.SplitTime(1273491611.0),
657
                         constants.ELOG_MESSAGE, "Step 1"),
658
                        (2, utils.SplitTime(1273491615.9),
659
                         constants.ELOG_MESSAGE, "Step 2"),
660
                        (3, utils.SplitTime(1273491625.02),
661
                         constants.ELOG_MESSAGE, "Step 3"),
662
                        (4, utils.SplitTime(1273491635.05),
663
                         constants.ELOG_MESSAGE, "Step 4"),
664
                        (37, utils.SplitTime(1273491645.0),
665
                         constants.ELOG_MESSAGE, "Step 5"),
666
                        (203, utils.SplitTime(127349155.0),
667
                         constants.ELOG_MESSAGE, "Step 6")]))
668

    
669
    cbs.AddWfjcResult((constants.JOB_STATUS_RUNNING, ), 203,
670
                      ((constants.JOB_STATUS_RUNNING, ),
671
                       [(300, utils.SplitTime(1273491711.01),
672
                         constants.ELOG_MESSAGE, "Step X"),
673
                        (302, utils.SplitTime(1273491815.8),
674
                         constants.ELOG_MESSAGE, "Step Y"),
675
                        (303, utils.SplitTime(1273491925.32),
676
                         constants.ELOG_MESSAGE, "Step Z")]))
677

    
678
    cbs.AddWfjcResult((constants.JOB_STATUS_RUNNING, ), 303,
679
                      ((constants.JOB_STATUS_SUCCESS, ), None))
680

    
681
    cbs.AddQueryJobsResult(constants.JOB_STATUS_SUCCESS,
682
                           [constants.OP_STATUS_SUCCESS,
683
                            constants.OP_STATUS_SUCCESS],
684
                           ["Hello World", "Foo man bar"])
685

    
686
    self.assertEqual(["Hello World", "Foo man bar"],
687
                     cli.GenericPollJob(job_id, cbs, cbs))
688
    cbs.CheckEmpty()
689

    
690
  def testJobLost(self):
691
    job_id = 13746
692

    
693
    cbs = _MockJobPollCb(self, job_id)
694
    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
695
    cbs.AddWfjcResult(None, None, None)
696
    self.assertRaises(errors.JobLost, cli.GenericPollJob, job_id, cbs, cbs)
697
    cbs.CheckEmpty()
698

    
699
  def testError(self):
700
    job_id = 31088
701

    
702
    cbs = _MockJobPollCb(self, job_id)
703
    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
704
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
705
    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
706
                           [constants.OP_STATUS_SUCCESS,
707
                            constants.OP_STATUS_ERROR],
708
                           ["Hello World", "Error code 123"])
709
    self.assertRaises(errors.OpExecError, cli.GenericPollJob, job_id, cbs, cbs)
710
    cbs.CheckEmpty()
711

    
712
  def testError2(self):
713
    job_id = 22235
714

    
715
    cbs = _MockJobPollCb(self, job_id)
716
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
717
    encexc = errors.EncodeException(errors.LockError("problem"))
718
    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
719
                           [constants.OP_STATUS_ERROR], [encexc])
720
    self.assertRaises(errors.LockError, cli.GenericPollJob, job_id, cbs, cbs)
721
    cbs.CheckEmpty()
722

    
723
  def testWeirdError(self):
724
    job_id = 28847
725

    
726
    cbs = _MockJobPollCb(self, job_id)
727
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
728
    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
729
                           [constants.OP_STATUS_RUNNING,
730
                            constants.OP_STATUS_RUNNING],
731
                           [None, None])
732
    self.assertRaises(errors.OpExecError, cli.GenericPollJob, job_id, cbs, cbs)
733
    cbs.CheckEmpty()
734

    
735
  def testCancel(self):
736
    job_id = 4275
737

    
738
    cbs = _MockJobPollCb(self, job_id)
739
    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
740
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_CANCELING, ), None))
741
    cbs.AddQueryJobsResult(constants.JOB_STATUS_CANCELING,
742
                           [constants.OP_STATUS_CANCELING,
743
                            constants.OP_STATUS_CANCELING],
744
                           [None, None])
745
    self.assertRaises(errors.OpExecError, cli.GenericPollJob, job_id, cbs, cbs)
746
    cbs.CheckEmpty()
747

    
748

    
749
class TestFormatLogMessage(unittest.TestCase):
750
  def test(self):
751
    self.assertEqual(cli.FormatLogMessage(constants.ELOG_MESSAGE,
752
                                          "Hello World"),
753
                     "Hello World")
754
    self.assertRaises(TypeError, cli.FormatLogMessage,
755
                      constants.ELOG_MESSAGE, [1, 2, 3])
756

    
757
    self.assert_(cli.FormatLogMessage("some other type", (1, 2, 3)))
758

    
759

    
760
class TestParseFields(unittest.TestCase):
761
  def test(self):
762
    self.assertEqual(cli.ParseFields(None, []), [])
763
    self.assertEqual(cli.ParseFields("name,foo,hello", []),
764
                     ["name", "foo", "hello"])
765
    self.assertEqual(cli.ParseFields(None, ["def", "ault", "fields", "here"]),
766
                     ["def", "ault", "fields", "here"])
767
    self.assertEqual(cli.ParseFields("name,foo", ["def", "ault"]),
768
                     ["name", "foo"])
769
    self.assertEqual(cli.ParseFields("+name,foo", ["def", "ault"]),
770
                     ["def", "ault", "name", "foo"])
771

    
772

    
773
class TestConstants(unittest.TestCase):
774
  def testPriority(self):
775
    self.assertEqual(set(cli._PRIONAME_TO_VALUE.values()),
776
                     set(constants.OP_PRIO_SUBMIT_VALID))
777
    self.assertEqual(list(value for _, value in cli._PRIORITY_NAMES),
778
                     sorted(constants.OP_PRIO_SUBMIT_VALID, reverse=True))
779

    
780

    
781
class TestParseNicOption(unittest.TestCase):
782
  def test(self):
783
    self.assertEqual(cli.ParseNicOption([("0", { "link": "eth0", })]),
784
                     [{ "link": "eth0", }])
785
    self.assertEqual(cli.ParseNicOption([("5", { "ip": "192.0.2.7", })]),
786
                     [{}, {}, {}, {}, {}, { "ip": "192.0.2.7", }])
787

    
788
  def testErrors(self):
789
    for i in [None, "", "abc", "zero", "Hello World", "\0", []]:
790
      self.assertRaises(errors.OpPrereqError, cli.ParseNicOption,
791
                        [(i, { "link": "eth0", })])
792
      self.assertRaises(errors.OpPrereqError, cli.ParseNicOption,
793
                        [("0", i)])
794

    
795
    self.assertRaises(errors.TypeEnforcementError, cli.ParseNicOption,
796
                      [(0, { True: False, })])
797

    
798
    self.assertRaises(errors.TypeEnforcementError, cli.ParseNicOption,
799
                      [(3, { "mode": [], })])
800

    
801

    
802
class TestFormatResultError(unittest.TestCase):
803
  def testNormal(self):
804
    for verbose in [False, True]:
805
      self.assertRaises(AssertionError, cli.FormatResultError,
806
                        constants.RS_NORMAL, verbose)
807

    
808
  def testUnknown(self):
809
    for verbose in [False, True]:
810
      self.assertRaises(NotImplementedError, cli.FormatResultError,
811
                        "#some!other!status#", verbose)
812

    
813
  def test(self):
814
    for status in constants.RS_ALL:
815
      if status == constants.RS_NORMAL:
816
        continue
817

    
818
      self.assertNotEqual(cli.FormatResultError(status, False),
819
                          cli.FormatResultError(status, True))
820

    
821
      result = cli.FormatResultError(status, True)
822
      self.assertTrue(result.startswith("("))
823
      self.assertTrue(result.endswith(")"))
824

    
825

    
826
class TestGetOnlineNodes(unittest.TestCase):
827
  class _FakeClient:
828
    def __init__(self):
829
      self._query = []
830

    
831
    def AddQueryResult(self, *args):
832
      self._query.append(args)
833

    
834
    def CountPending(self):
835
      return len(self._query)
836

    
837
    def Query(self, res, fields, qfilter):
838
      if res != constants.QR_NODE:
839
        raise Exception("Querying wrong resource")
840

    
841
      (exp_fields, check_filter, result) = self._query.pop(0)
842

    
843
      if exp_fields != fields:
844
        raise Exception("Expected fields %s, got %s" % (exp_fields, fields))
845

    
846
      if not (qfilter is None or check_filter(qfilter)):
847
        raise Exception("Filter doesn't match expectations")
848

    
849
      return objects.QueryResponse(fields=None, data=result)
850

    
851
  def testEmpty(self):
852
    cl = self._FakeClient()
853

    
854
    cl.AddQueryResult(["name", "offline", "sip"], None, [])
855
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl), [])
856
    self.assertEqual(cl.CountPending(), 0)
857

    
858
  def testNoSpecialFilter(self):
859
    cl = self._FakeClient()
860

    
861
    cl.AddQueryResult(["name", "offline", "sip"], None, [
862
      [(constants.RS_NORMAL, "master.example.com"),
863
       (constants.RS_NORMAL, False),
864
       (constants.RS_NORMAL, "192.0.2.1")],
865
      [(constants.RS_NORMAL, "node2.example.com"),
866
       (constants.RS_NORMAL, False),
867
       (constants.RS_NORMAL, "192.0.2.2")],
868
      ])
869
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl),
870
                     ["master.example.com", "node2.example.com"])
871
    self.assertEqual(cl.CountPending(), 0)
872

    
873
  def testNoMaster(self):
874
    cl = self._FakeClient()
875

    
876
    def _CheckFilter(qfilter):
877
      self.assertEqual(qfilter, [qlang.OP_NOT, [qlang.OP_TRUE, "master"]])
878
      return True
879

    
880
    cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
881
      [(constants.RS_NORMAL, "node2.example.com"),
882
       (constants.RS_NORMAL, False),
883
       (constants.RS_NORMAL, "192.0.2.2")],
884
      ])
885
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl, filter_master=True),
886
                     ["node2.example.com"])
887
    self.assertEqual(cl.CountPending(), 0)
888

    
889
  def testSecondaryIpAddress(self):
890
    cl = self._FakeClient()
891

    
892
    cl.AddQueryResult(["name", "offline", "sip"], None, [
893
      [(constants.RS_NORMAL, "master.example.com"),
894
       (constants.RS_NORMAL, False),
895
       (constants.RS_NORMAL, "192.0.2.1")],
896
      [(constants.RS_NORMAL, "node2.example.com"),
897
       (constants.RS_NORMAL, False),
898
       (constants.RS_NORMAL, "192.0.2.2")],
899
      ])
900
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl, secondary_ips=True),
901
                     ["192.0.2.1", "192.0.2.2"])
902
    self.assertEqual(cl.CountPending(), 0)
903

    
904
  def testNoMasterFilterNodeName(self):
905
    cl = self._FakeClient()
906

    
907
    def _CheckFilter(qfilter):
908
      self.assertEqual(qfilter,
909
        [qlang.OP_AND,
910
         [qlang.OP_OR] + [[qlang.OP_EQUAL, "name", name]
911
                          for name in ["node2", "node3"]],
912
         [qlang.OP_NOT, [qlang.OP_TRUE, "master"]]])
913
      return True
914

    
915
    cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
916
      [(constants.RS_NORMAL, "node2.example.com"),
917
       (constants.RS_NORMAL, False),
918
       (constants.RS_NORMAL, "192.0.2.12")],
919
      [(constants.RS_NORMAL, "node3.example.com"),
920
       (constants.RS_NORMAL, False),
921
       (constants.RS_NORMAL, "192.0.2.13")],
922
      ])
923
    self.assertEqual(cli.GetOnlineNodes(["node2", "node3"], cl=cl,
924
                                        secondary_ips=True, filter_master=True),
925
                     ["192.0.2.12", "192.0.2.13"])
926
    self.assertEqual(cl.CountPending(), 0)
927

    
928
  def testOfflineNodes(self):
929
    cl = self._FakeClient()
930

    
931
    cl.AddQueryResult(["name", "offline", "sip"], None, [
932
      [(constants.RS_NORMAL, "master.example.com"),
933
       (constants.RS_NORMAL, False),
934
       (constants.RS_NORMAL, "192.0.2.1")],
935
      [(constants.RS_NORMAL, "node2.example.com"),
936
       (constants.RS_NORMAL, True),
937
       (constants.RS_NORMAL, "192.0.2.2")],
938
      [(constants.RS_NORMAL, "node3.example.com"),
939
       (constants.RS_NORMAL, True),
940
       (constants.RS_NORMAL, "192.0.2.3")],
941
      ])
942
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl, nowarn=True),
943
                     ["master.example.com"])
944
    self.assertEqual(cl.CountPending(), 0)
945

    
946
  def testNodeGroup(self):
947
    cl = self._FakeClient()
948

    
949
    def _CheckFilter(qfilter):
950
      self.assertEqual(qfilter,
951
        [qlang.OP_OR, [qlang.OP_EQUAL, "group", "foobar"],
952
                      [qlang.OP_EQUAL, "group.uuid", "foobar"]])
953
      return True
954

    
955
    cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
956
      [(constants.RS_NORMAL, "master.example.com"),
957
       (constants.RS_NORMAL, False),
958
       (constants.RS_NORMAL, "192.0.2.1")],
959
      [(constants.RS_NORMAL, "node3.example.com"),
960
       (constants.RS_NORMAL, False),
961
       (constants.RS_NORMAL, "192.0.2.3")],
962
      ])
963
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl, nodegroup="foobar"),
964
                     ["master.example.com", "node3.example.com"])
965
    self.assertEqual(cl.CountPending(), 0)
966

    
967

    
968
class TestFormatTimestamp(unittest.TestCase):
969
  def testGood(self):
970
    self.assertEqual(cli.FormatTimestamp((0, 1)),
971
                     time.strftime("%F %T", time.localtime(0)) + ".000001")
972
    self.assertEqual(cli.FormatTimestamp((1332944009, 17376)),
973
                     (time.strftime("%F %T", time.localtime(1332944009)) +
974
                      ".017376"))
975

    
976
  def testWrong(self):
977
    for i in [0, [], {}, "", [1]]:
978
      self.assertEqual(cli.FormatTimestamp(i), "?")
979

    
980

    
981
class TestFormatUsage(unittest.TestCase):
982
  def test(self):
983
    binary = "gnt-unittest"
984
    commands = {
985
      "cmdA":
986
        (NotImplemented, NotImplemented, NotImplemented, NotImplemented,
987
         "description of A"),
988
      "bbb":
989
        (NotImplemented, NotImplemented, NotImplemented, NotImplemented,
990
         "Hello World," * 10),
991
      "longname":
992
        (NotImplemented, NotImplemented, NotImplemented, NotImplemented,
993
         "Another description"),
994
      }
995

    
996
    self.assertEqual(list(cli._FormatUsage(binary, commands)), [
997
      "Usage: gnt-unittest {command} [options...] [argument...]",
998
      "gnt-unittest <command> --help to see details, or man gnt-unittest",
999
      "",
1000
      "Commands:",
1001
      (" bbb      - Hello World,Hello World,Hello World,Hello World,Hello"
1002
       " World,Hello"),
1003
      "            World,Hello World,Hello World,Hello World,Hello World,",
1004
      " cmdA     - description of A",
1005
      " longname - Another description",
1006
      "",
1007
      ])
1008

    
1009

    
1010
class TestParseArgs(unittest.TestCase):
1011
  def testNoArguments(self):
1012
    for argv in [[], ["gnt-unittest"]]:
1013
      try:
1014
        cli._ParseArgs("gnt-unittest", argv, {}, {}, set())
1015
      except cli._ShowUsage, err:
1016
        self.assertTrue(err.exit_error)
1017
      else:
1018
        self.fail("Did not raise exception")
1019

    
1020
  def testVersion(self):
1021
    for argv in [["test", "--version"], ["test", "--version", "somethingelse"]]:
1022
      try:
1023
        cli._ParseArgs("test", argv, {}, {}, set())
1024
      except cli._ShowVersion:
1025
        pass
1026
      else:
1027
        self.fail("Did not raise exception")
1028

    
1029
  def testHelp(self):
1030
    for argv in [["test", "--help"], ["test", "--help", "somethingelse"]]:
1031
      try:
1032
        cli._ParseArgs("test", argv, {}, {}, set())
1033
      except cli._ShowUsage, err:
1034
        self.assertFalse(err.exit_error)
1035
      else:
1036
        self.fail("Did not raise exception")
1037

    
1038
  def testUnknownCommandOrAlias(self):
1039
    for argv in [["test", "list"], ["test", "somethingelse", "--help"]]:
1040
      try:
1041
        cli._ParseArgs("test", argv, {}, {}, set())
1042
      except cli._ShowUsage, err:
1043
        self.assertTrue(err.exit_error)
1044
      else:
1045
        self.fail("Did not raise exception")
1046

    
1047
  def testInvalidAliasList(self):
1048
    cmd = {
1049
      "list": NotImplemented,
1050
      "foo": NotImplemented,
1051
      }
1052
    aliases = {
1053
      "list": NotImplemented,
1054
      "foo": NotImplemented,
1055
      }
1056
    assert sorted(cmd.keys()) == sorted(aliases.keys())
1057
    self.assertRaises(AssertionError, cli._ParseArgs, "test",
1058
                      ["test", "list"], cmd, aliases, set())
1059

    
1060
  def testAliasForNonExistantCommand(self):
1061
    cmd = {}
1062
    aliases = {
1063
      "list": NotImplemented,
1064
      }
1065
    self.assertRaises(errors.ProgrammerError, cli._ParseArgs, "test",
1066
                      ["test", "list"], cmd, aliases, set())
1067

    
1068

    
1069
class TestQftNames(unittest.TestCase):
1070
  def testComplete(self):
1071
    self.assertEqual(frozenset(cli._QFT_NAMES), constants.QFT_ALL)
1072

    
1073
  def testUnique(self):
1074
    lcnames = map(lambda s: s.lower(), cli._QFT_NAMES.values())
1075
    self.assertFalse(utils.FindDuplicates(lcnames))
1076

    
1077
  def testUppercase(self):
1078
    for name in cli._QFT_NAMES.values():
1079
      self.assertEqual(name[0], name[0].upper())
1080

    
1081

    
1082
class TestFieldDescValues(unittest.TestCase):
1083
  def testKnownKind(self):
1084
    fdef = objects.QueryFieldDefinition(name="aname",
1085
                                        title="Atitle",
1086
                                        kind=constants.QFT_TEXT,
1087
                                        doc="aaa doc aaa")
1088
    self.assertEqual(cli._FieldDescValues(fdef),
1089
                     ["aname", "Text", "Atitle", "aaa doc aaa"])
1090

    
1091
  def testUnknownKind(self):
1092
    kind = "#foo#"
1093

    
1094
    self.assertFalse(kind in constants.QFT_ALL)
1095
    self.assertFalse(kind in cli._QFT_NAMES)
1096

    
1097
    fdef = objects.QueryFieldDefinition(name="zname", title="Ztitle",
1098
                                        kind=kind, doc="zzz doc zzz")
1099
    self.assertEqual(cli._FieldDescValues(fdef),
1100
                     ["zname", kind, "Ztitle", "zzz doc zzz"])
1101

    
1102

    
1103
class TestSerializeGenericInfo(unittest.TestCase):
1104
  """Test case for cli._SerializeGenericInfo"""
1105
  def _RunTest(self, data, expected):
1106
    buf = StringIO()
1107
    cli._SerializeGenericInfo(buf, data, 0)
1108
    self.assertEqual(buf.getvalue(), expected)
1109

    
1110
  def testSimple(self):
1111
    test_samples = [
1112
      ("abc", "abc\n"),
1113
      ([], "\n"),
1114
      ({}, "\n"),
1115
      (["1", "2", "3"], "- 1\n- 2\n- 3\n"),
1116
      ([("z", "26")], "z: 26\n"),
1117
      ({"z": "26"}, "z: 26\n"),
1118
      ([("z", "26"), ("a", "1")], "z: 26\na: 1\n"),
1119
      ({"z": "26", "a": "1"}, "a: 1\nz: 26\n"),
1120
      ]
1121
    for (data, expected) in test_samples:
1122
      self._RunTest(data, expected)
1123

    
1124
  def testLists(self):
1125
    adict = {
1126
      "aa": "11",
1127
      "bb": "22",
1128
      "cc": "33",
1129
      }
1130
    adict_exp = ("- aa: 11\n"
1131
                 "  bb: 22\n"
1132
                 "  cc: 33\n")
1133
    anobj = [
1134
      ("zz", "11"),
1135
      ("ww", "33"),
1136
      ("xx", "22"),
1137
      ]
1138
    anobj_exp = ("- zz: 11\n"
1139
                 "  ww: 33\n"
1140
                 "  xx: 22\n")
1141
    alist = ["aa", "cc", "bb"]
1142
    alist_exp = ("- - aa\n"
1143
                 "  - cc\n"
1144
                 "  - bb\n")
1145
    test_samples = [
1146
      (adict, adict_exp),
1147
      (anobj, anobj_exp),
1148
      (alist, alist_exp),
1149
      ]
1150
    for (base_data, base_expected) in test_samples:
1151
      for k in range(1, 4):
1152
        data = k * [base_data]
1153
        expected = k * base_expected
1154
        self._RunTest(data, expected)
1155

    
1156
  def testDictionaries(self):
1157
    data = [
1158
      ("aaa", ["x", "y"]),
1159
      ("bbb", {
1160
          "w": "1",
1161
          "z": "2",
1162
          }),
1163
      ("ccc", [
1164
          ("xyz", "123"),
1165
          ("efg", "456"),
1166
          ]),
1167
      ]
1168
    expected = ("aaa: \n"
1169
                "  - x\n"
1170
                "  - y\n"
1171
                "bbb: \n"
1172
                "  w: 1\n"
1173
                "  z: 2\n"
1174
                "ccc: \n"
1175
                "  xyz: 123\n"
1176
                "  efg: 456\n")
1177
    self._RunTest(data, expected)
1178
    self._RunTest(dict(data), expected)
1179

    
1180

    
1181
class TestCreateIPolicyFromOpts(unittest.TestCase):
1182
  """Test case for cli.CreateIPolicyFromOpts."""
1183
  def setUp(self):
1184
    # Policies are big, and we want to see the difference in case of an error
1185
    self.maxDiff = None
1186

    
1187
  def _RecursiveCheckMergedDicts(self, default_pol, diff_pol, merged_pol,
1188
                                 merge_minmax=False):
1189
    self.assertTrue(type(default_pol) is dict)
1190
    self.assertTrue(type(diff_pol) is dict)
1191
    self.assertTrue(type(merged_pol) is dict)
1192
    self.assertEqual(frozenset(default_pol.keys()),
1193
                     frozenset(merged_pol.keys()))
1194
    for (key, val) in merged_pol.items():
1195
      if key in diff_pol:
1196
        if type(val) is dict:
1197
          self._RecursiveCheckMergedDicts(default_pol[key], diff_pol[key], val)
1198
        elif (merge_minmax and key == "minmax" and type(val) is list and
1199
              len(val) == 1):
1200
          self.assertEqual(len(default_pol[key]), 1)
1201
          self.assertEqual(len(diff_pol[key]), 1)
1202
          self._RecursiveCheckMergedDicts(default_pol[key][0],
1203
                                          diff_pol[key][0], val[0])
1204
        else:
1205
          self.assertEqual(val, diff_pol[key])
1206
      else:
1207
        self.assertEqual(val, default_pol[key])
1208

    
1209
  def testClusterPolicy(self):
1210
    pol0 = cli.CreateIPolicyFromOpts(
1211
      ispecs_mem_size={},
1212
      ispecs_cpu_count={},
1213
      ispecs_disk_count={},
1214
      ispecs_disk_size={},
1215
      ispecs_nic_count={},
1216
      ipolicy_disk_templates=None,
1217
      ipolicy_vcpu_ratio=None,
1218
      ipolicy_spindle_ratio=None,
1219
      fill_all=True
1220
      )
1221
    self.assertEqual(pol0, constants.IPOLICY_DEFAULTS)
1222

    
1223
    exp_pol1 = {
1224
      constants.ISPECS_MINMAX: [
1225
        {
1226
          constants.ISPECS_MIN: {
1227
            constants.ISPEC_CPU_COUNT: 2,
1228
            constants.ISPEC_DISK_COUNT: 1,
1229
            },
1230
          constants.ISPECS_MAX: {
1231
            constants.ISPEC_MEM_SIZE: 12*1024,
1232
            constants.ISPEC_DISK_COUNT: 2,
1233
            },
1234
          },
1235
        ],
1236
      constants.ISPECS_STD: {
1237
        constants.ISPEC_CPU_COUNT: 2,
1238
        constants.ISPEC_DISK_COUNT: 2,
1239
        },
1240
      constants.IPOLICY_VCPU_RATIO: 3.1,
1241
      }
1242
    pol1 = cli.CreateIPolicyFromOpts(
1243
      ispecs_mem_size={"max": "12g"},
1244
      ispecs_cpu_count={"min": 2, "std": 2},
1245
      ispecs_disk_count={"min": 1, "max": 2, "std": 2},
1246
      ispecs_disk_size={},
1247
      ispecs_nic_count={},
1248
      ipolicy_disk_templates=None,
1249
      ipolicy_vcpu_ratio=3.1,
1250
      ipolicy_spindle_ratio=None,
1251
      fill_all=True
1252
      )
1253
    self._RecursiveCheckMergedDicts(constants.IPOLICY_DEFAULTS,
1254
                                    exp_pol1, pol1, merge_minmax=True)
1255

    
1256
    exp_pol2 = {
1257
      constants.ISPECS_MINMAX: [
1258
        {
1259
          constants.ISPECS_MIN: {
1260
            constants.ISPEC_DISK_SIZE: 512,
1261
            constants.ISPEC_NIC_COUNT: 2,
1262
            },
1263
          constants.ISPECS_MAX: {
1264
            constants.ISPEC_NIC_COUNT: 3,
1265
            },
1266
          },
1267
        ],
1268
      constants.ISPECS_STD: {
1269
        constants.ISPEC_CPU_COUNT: 2,
1270
        constants.ISPEC_NIC_COUNT: 3,
1271
        },
1272
      constants.IPOLICY_SPINDLE_RATIO: 1.3,
1273
      constants.IPOLICY_DTS: ["templates"],
1274
      }
1275
    pol2 = cli.CreateIPolicyFromOpts(
1276
      ispecs_mem_size={},
1277
      ispecs_cpu_count={"std": 2},
1278
      ispecs_disk_count={},
1279
      ispecs_disk_size={"min": "0.5g"},
1280
      ispecs_nic_count={"min": 2, "max": 3, "std": 3},
1281
      ipolicy_disk_templates=["templates"],
1282
      ipolicy_vcpu_ratio=None,
1283
      ipolicy_spindle_ratio=1.3,
1284
      fill_all=True
1285
      )
1286
    self._RecursiveCheckMergedDicts(constants.IPOLICY_DEFAULTS,
1287
                                      exp_pol2, pol2, merge_minmax=True)
1288

    
1289
    for fill_all in [False, True]:
1290
      exp_pol3 = {
1291
        constants.ISPECS_STD: {
1292
          constants.ISPEC_CPU_COUNT: 2,
1293
          constants.ISPEC_NIC_COUNT: 3,
1294
          },
1295
        }
1296
      pol3 = cli.CreateIPolicyFromOpts(
1297
        std_ispecs={
1298
          constants.ISPEC_CPU_COUNT: "2",
1299
          constants.ISPEC_NIC_COUNT: "3",
1300
          },
1301
        ipolicy_disk_templates=None,
1302
        ipolicy_vcpu_ratio=None,
1303
        ipolicy_spindle_ratio=None,
1304
        fill_all=fill_all
1305
        )
1306
      if fill_all:
1307
        self._RecursiveCheckMergedDicts(constants.IPOLICY_DEFAULTS,
1308
                                        exp_pol3, pol3, merge_minmax=True)
1309
      else:
1310
        self.assertEqual(pol3, exp_pol3)
1311

    
1312
  def testPartialPolicy(self):
1313
    exp_pol0 = objects.MakeEmptyIPolicy()
1314
    pol0 = cli.CreateIPolicyFromOpts(
1315
      minmax_ispecs=None,
1316
      std_ispecs=None,
1317
      ipolicy_disk_templates=None,
1318
      ipolicy_vcpu_ratio=None,
1319
      ipolicy_spindle_ratio=None,
1320
      fill_all=False
1321
      )
1322
    self.assertEqual(pol0, exp_pol0)
1323

    
1324
    exp_pol1 = {
1325
      constants.IPOLICY_VCPU_RATIO: 3.1,
1326
      }
1327
    pol1 = cli.CreateIPolicyFromOpts(
1328
      minmax_ispecs=None,
1329
      std_ispecs=None,
1330
      ipolicy_disk_templates=None,
1331
      ipolicy_vcpu_ratio=3.1,
1332
      ipolicy_spindle_ratio=None,
1333
      fill_all=False
1334
      )
1335
    self.assertEqual(pol1, exp_pol1)
1336

    
1337
    exp_pol2 = {
1338
      constants.IPOLICY_SPINDLE_RATIO: 1.3,
1339
      constants.IPOLICY_DTS: ["templates"],
1340
      }
1341
    pol2 = cli.CreateIPolicyFromOpts(
1342
      minmax_ispecs=None,
1343
      std_ispecs=None,
1344
      ipolicy_disk_templates=["templates"],
1345
      ipolicy_vcpu_ratio=None,
1346
      ipolicy_spindle_ratio=1.3,
1347
      fill_all=False
1348
      )
1349
    self.assertEqual(pol2, exp_pol2)
1350

    
1351
  def _TestInvalidISpecs(self, minmax_ispecs, std_ispecs, fail=True):
1352
    for fill_all in [False, True]:
1353
      if fail:
1354
        self.assertRaises((errors.OpPrereqError,
1355
                           errors.UnitParseError,
1356
                           errors.TypeEnforcementError),
1357
                          cli.CreateIPolicyFromOpts,
1358
                          minmax_ispecs=minmax_ispecs,
1359
                          std_ispecs=std_ispecs,
1360
                          fill_all=fill_all)
1361
      else:
1362
        cli.CreateIPolicyFromOpts(minmax_ispecs=minmax_ispecs,
1363
                                  std_ispecs=std_ispecs,
1364
                                  fill_all=fill_all)
1365

    
1366
  def testInvalidPolicies(self):
1367
    self.assertRaises(AssertionError, cli.CreateIPolicyFromOpts,
1368
                      std_ispecs={constants.ISPEC_MEM_SIZE: 1024},
1369
                      ipolicy_disk_templates=None, ipolicy_vcpu_ratio=None,
1370
                      ipolicy_spindle_ratio=None, group_ipolicy=True)
1371
    self.assertRaises(errors.OpPrereqError, cli.CreateIPolicyFromOpts,
1372
                      ispecs_mem_size={"wrong": "x"}, ispecs_cpu_count={},
1373
                      ispecs_disk_count={}, ispecs_disk_size={},
1374
                      ispecs_nic_count={}, ipolicy_disk_templates=None,
1375
                      ipolicy_vcpu_ratio=None, ipolicy_spindle_ratio=None,
1376
                      fill_all=True)
1377
    self.assertRaises(errors.TypeEnforcementError, cli.CreateIPolicyFromOpts,
1378
                      ispecs_mem_size={}, ispecs_cpu_count={"min": "default"},
1379
                      ispecs_disk_count={}, ispecs_disk_size={},
1380
                      ispecs_nic_count={}, ipolicy_disk_templates=None,
1381
                      ipolicy_vcpu_ratio=None, ipolicy_spindle_ratio=None,
1382
                      fill_all=True)
1383

    
1384
    good_mmspecs = constants.ISPECS_MINMAX_DEFAULTS
1385
    self._TestInvalidISpecs(good_mmspecs, None, fail=False)
1386
    broken_mmspecs = copy.deepcopy(good_mmspecs)
1387
    for key in constants.ISPECS_MINMAX_KEYS:
1388
      for par in constants.ISPECS_PARAMETERS:
1389
        old = broken_mmspecs[key][par]
1390
        del broken_mmspecs[key][par]
1391
        self._TestInvalidISpecs(broken_mmspecs, None)
1392
        broken_mmspecs[key][par] = "invalid"
1393
        self._TestInvalidISpecs(broken_mmspecs, None)
1394
        broken_mmspecs[key][par] = old
1395
      broken_mmspecs[key]["invalid_key"] = None
1396
      self._TestInvalidISpecs(broken_mmspecs, None)
1397
      del broken_mmspecs[key]["invalid_key"]
1398
    broken_mmspecs["invalid_key"] = None
1399
    self._TestInvalidISpecs(broken_mmspecs, None)
1400
    del broken_mmspecs["invalid_key"]
1401
    assert broken_mmspecs == good_mmspecs
1402

    
1403
    good_stdspecs = constants.IPOLICY_DEFAULTS[constants.ISPECS_STD]
1404
    self._TestInvalidISpecs(None, good_stdspecs, fail=False)
1405
    broken_stdspecs = copy.deepcopy(good_stdspecs)
1406
    for par in constants.ISPECS_PARAMETERS:
1407
      old = broken_stdspecs[par]
1408
      broken_stdspecs[par] = "invalid"
1409
      self._TestInvalidISpecs(None, broken_stdspecs)
1410
      broken_stdspecs[par] = old
1411
    broken_stdspecs["invalid_key"] = None
1412
    self._TestInvalidISpecs(None, broken_stdspecs)
1413
    del broken_stdspecs["invalid_key"]
1414
    assert broken_stdspecs == good_stdspecs
1415

    
1416
  def testAllowedValues(self):
1417
    allowedv = "blah"
1418
    exp_pol1 = {
1419
      constants.ISPECS_MINMAX: allowedv,
1420
      constants.IPOLICY_DTS: allowedv,
1421
      constants.IPOLICY_VCPU_RATIO: allowedv,
1422
      constants.IPOLICY_SPINDLE_RATIO: allowedv,
1423
      }
1424
    pol1 = cli.CreateIPolicyFromOpts(minmax_ispecs={allowedv: {}},
1425
                                     std_ispecs=None,
1426
                                     ipolicy_disk_templates=allowedv,
1427
                                     ipolicy_vcpu_ratio=allowedv,
1428
                                     ipolicy_spindle_ratio=allowedv,
1429
                                     allowed_values=[allowedv])
1430
    self.assertEqual(pol1, exp_pol1)
1431

    
1432
  @staticmethod
1433
  def _ConvertSpecToStrings(spec):
1434
    ret = {}
1435
    for (par, val) in spec.items():
1436
      ret[par] = str(val)
1437
    return ret
1438

    
1439
  def _CheckNewStyleSpecsCall(self, exp_ipolicy, minmax_ispecs, std_ispecs,
1440
                              group_ipolicy, fill_all):
1441
    ipolicy = cli.CreateIPolicyFromOpts(minmax_ispecs=minmax_ispecs,
1442
                                        std_ispecs=std_ispecs,
1443
                                        group_ipolicy=group_ipolicy,
1444
                                        fill_all=fill_all)
1445
    self.assertEqual(ipolicy, exp_ipolicy)
1446

    
1447
  def _TestFullISpecsInner(self, skel_exp_ipol, exp_minmax, exp_std,
1448
                           group_ipolicy, fill_all):
1449
    exp_ipol = skel_exp_ipol.copy()
1450
    if exp_minmax is not None:
1451
      minmax_ispecs = {}
1452
      for (key, spec) in exp_minmax.items():
1453
        minmax_ispecs[key] = self._ConvertSpecToStrings(spec)
1454
      exp_ipol[constants.ISPECS_MINMAX] = [exp_minmax]
1455
    else:
1456
      minmax_ispecs = None
1457
    if exp_std is not None:
1458
      std_ispecs = self._ConvertSpecToStrings(exp_std)
1459
      exp_ipol[constants.ISPECS_STD] = exp_std
1460
    else:
1461
      std_ispecs = None
1462

    
1463
    self._CheckNewStyleSpecsCall(exp_ipol, minmax_ispecs, std_ispecs,
1464
                                 group_ipolicy, fill_all)
1465
    if minmax_ispecs:
1466
      for (key, spec) in minmax_ispecs.items():
1467
        for par in [constants.ISPEC_MEM_SIZE, constants.ISPEC_DISK_SIZE]:
1468
          if par in spec:
1469
            spec[par] += "m"
1470
            self._CheckNewStyleSpecsCall(exp_ipol, minmax_ispecs, std_ispecs,
1471
                                         group_ipolicy, fill_all)
1472
    if std_ispecs:
1473
      for par in [constants.ISPEC_MEM_SIZE, constants.ISPEC_DISK_SIZE]:
1474
        if par in std_ispecs:
1475
          std_ispecs[par] += "m"
1476
          self._CheckNewStyleSpecsCall(exp_ipol, minmax_ispecs, std_ispecs,
1477
                                       group_ipolicy, fill_all)
1478

    
1479
  def testFullISpecs(self):
1480
    exp_minmax1 = {
1481
      constants.ISPECS_MIN: {
1482
        constants.ISPEC_MEM_SIZE: 512,
1483
        constants.ISPEC_CPU_COUNT: 2,
1484
        constants.ISPEC_DISK_COUNT: 2,
1485
        constants.ISPEC_DISK_SIZE: 512,
1486
        constants.ISPEC_NIC_COUNT: 2,
1487
        constants.ISPEC_SPINDLE_USE: 2,
1488
        },
1489
      constants.ISPECS_MAX: {
1490
        constants.ISPEC_MEM_SIZE: 768*1024,
1491
        constants.ISPEC_CPU_COUNT: 7,
1492
        constants.ISPEC_DISK_COUNT: 6,
1493
        constants.ISPEC_DISK_SIZE: 2048*1024,
1494
        constants.ISPEC_NIC_COUNT: 3,
1495
        constants.ISPEC_SPINDLE_USE: 1,
1496
        },
1497
      }
1498
    exp_std1 = {
1499
      constants.ISPEC_MEM_SIZE: 768*1024,
1500
      constants.ISPEC_CPU_COUNT: 7,
1501
      constants.ISPEC_DISK_COUNT: 6,
1502
      constants.ISPEC_DISK_SIZE: 2048*1024,
1503
      constants.ISPEC_NIC_COUNT: 3,
1504
      constants.ISPEC_SPINDLE_USE: 1,
1505
      }
1506
    for fill_all in [False, True]:
1507
      if fill_all:
1508
        skel_ipolicy = constants.IPOLICY_DEFAULTS
1509
      else:
1510
        skel_ipolicy = {}
1511
      self._TestFullISpecsInner(skel_ipolicy, exp_minmax1, exp_std1,
1512
                                False, fill_all)
1513
      self._TestFullISpecsInner(skel_ipolicy, None, exp_std1,
1514
                                False, fill_all)
1515
      self._TestFullISpecsInner(skel_ipolicy, exp_minmax1, None,
1516
                                False, fill_all)
1517

    
1518

    
1519
class TestPrintIPolicyCommand(unittest.TestCase):
1520
  """Test case for cli.PrintIPolicyCommand"""
1521
  _SPECS1 = {
1522
    "par1": 42,
1523
    "par2": "xyz",
1524
    }
1525
  _SPECS1_STR = "par1=42,par2=xyz"
1526
  _SPECS2 = {
1527
    "param": 10,
1528
    "another_param": 101,
1529
    }
1530
  _SPECS2_STR = "another_param=101,param=10"
1531

    
1532
  def _CheckPrintIPolicyCommand(self, ipolicy, isgroup, expected):
1533
    buf = StringIO()
1534
    cli.PrintIPolicyCommand(buf, ipolicy, isgroup)
1535
    self.assertEqual(buf.getvalue(), expected)
1536

    
1537
  def testIgnoreStdForGroup(self):
1538
    self._CheckPrintIPolicyCommand({"std": self._SPECS1}, True, "")
1539

    
1540
  def testIgnoreEmpty(self):
1541
    policies = [
1542
      {},
1543
      {"std": {}},
1544
      {"minmax": []},
1545
      {"minmax": [{}]},
1546
      {"minmax": [{
1547
        "min": {},
1548
        "max": {},
1549
        }]},
1550
      {"minmax": [{
1551
        "min": self._SPECS1,
1552
        "max": {},
1553
        }]},
1554
      ]
1555
    for pol in policies:
1556
      self._CheckPrintIPolicyCommand(pol, False, "")
1557

    
1558
  def testFullPolicies(self):
1559
    cases = [
1560
      ({"std": self._SPECS1},
1561
       " %s %s" % (cli.IPOLICY_STD_SPECS_STR, self._SPECS1_STR)),
1562
      ({"minmax": [{
1563
        "min": self._SPECS1,
1564
        "max": self._SPECS2,
1565
        }]},
1566
       " %s min:%s/max:%s" % (cli.IPOLICY_BOUNDS_SPECS_STR,
1567
                              self._SPECS1_STR, self._SPECS2_STR)),
1568
      ]
1569
    for (pol, exp) in cases:
1570
      self._CheckPrintIPolicyCommand(pol, False, exp)
1571

    
1572

    
1573
if __name__ == "__main__":
1574
  testutils.GanetiTestProgram()