Statistics
| Branch: | Tag: | Revision:

root / test / py / ganeti.cli_unittest.py @ 726ae450

History | View | Annotate | Download (45 kB)

1
#!/usr/bin/python
2
#
3

    
4
# Copyright (C) 2008, 2011, 2012, 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Script for unittesting the cli module"""
23

    
24
import unittest
25
import time
26
from cStringIO import StringIO
27

    
28
import ganeti
29
import testutils
30

    
31
from ganeti import constants
32
from ganeti import cli
33
from ganeti import errors
34
from ganeti import utils
35
from ganeti import objects
36
from ganeti import qlang
37
from ganeti.errors import OpPrereqError, ParameterError
38

    
39

    
40
class TestParseTimespec(unittest.TestCase):
41
  """Testing case for ParseTimespec"""
42

    
43
  def testValidTimes(self):
44
    """Test valid timespecs"""
45
    test_data = [
46
      ("1s", 1),
47
      ("1", 1),
48
      ("1m", 60),
49
      ("1h", 60 * 60),
50
      ("1d", 60 * 60 * 24),
51
      ("1w", 60 * 60 * 24 * 7),
52
      ("4h", 4 * 60 * 60),
53
      ("61m", 61 * 60),
54
      ]
55
    for value, expected_result in test_data:
56
      self.failUnlessEqual(cli.ParseTimespec(value), expected_result)
57

    
58
  def testInvalidTime(self):
59
    """Test invalid timespecs"""
60
    test_data = [
61
      "1y",
62
      "",
63
      "aaa",
64
      "s",
65
      ]
66
    for value in test_data:
67
      self.failUnlessRaises(OpPrereqError, cli.ParseTimespec, value)
68

    
69

    
70
class TestSplitKeyVal(unittest.TestCase):
71
  """Testing case for cli._SplitKeyVal"""
72
  DATA = "a=b,c,no_d,-e"
73
  RESULT = {"a": "b", "c": True, "d": False, "e": None}
74
  RESULT_NOPREFIX = {"a": "b", "c": {}, "no_d": {}, "-e": {}}
75

    
76
  def testSplitKeyVal(self):
77
    """Test splitting"""
78
    self.failUnlessEqual(cli._SplitKeyVal("option", self.DATA, True),
79
                         self.RESULT)
80

    
81
  def testDuplicateParam(self):
82
    """Test duplicate parameters"""
83
    for data in ("a=1,a=2", "a,no_a"):
84
      self.failUnlessRaises(ParameterError, cli._SplitKeyVal,
85
                            "option", data, True)
86

    
87
  def testEmptyData(self):
88
    """Test how we handle splitting an empty string"""
89
    self.failUnlessEqual(cli._SplitKeyVal("option", "", True), {})
90

    
91

    
92
class TestIdentKeyVal(unittest.TestCase):
93
  """Testing case for cli.check_ident_key_val"""
94

    
95
  def testIdentKeyVal(self):
96
    """Test identkeyval"""
97
    def cikv(value):
98
      return cli.check_ident_key_val("option", "opt", value)
99

    
100
    self.assertEqual(cikv("foo:bar"), ("foo", {"bar": True}))
101
    self.assertEqual(cikv("foo:bar=baz"), ("foo", {"bar": "baz"}))
102
    self.assertEqual(cikv("bar:b=c,c=a"), ("bar", {"b": "c", "c": "a"}))
103
    self.assertEqual(cikv("no_bar"), ("bar", False))
104
    self.assertRaises(ParameterError, cikv, "no_bar:foo")
105
    self.assertRaises(ParameterError, cikv, "no_bar:foo=baz")
106
    self.assertRaises(ParameterError, cikv, "bar:foo=baz,foo=baz")
107
    self.assertEqual(cikv("-foo"), ("foo", None))
108
    self.assertRaises(ParameterError, cikv, "-foo:a=c")
109

    
110
    # Check negative numbers
111
    self.assertEqual(cikv("-1:remove"), ("-1", {
112
      "remove": True,
113
      }))
114
    self.assertEqual(cikv("-29447:add,size=4G"), ("-29447", {
115
      "add": True,
116
      "size": "4G",
117
      }))
118
    for i in ["-:", "-"]:
119
      self.assertEqual(cikv(i), ("", None))
120

    
121
  @staticmethod
122
  def _csikv(value):
123
    return cli._SplitIdentKeyVal("opt", value, False)
124

    
125
  def testIdentKeyValNoPrefix(self):
126
    """Test identkeyval without prefixes"""
127
    test_cases = [
128
      ("foo:bar", None),
129
      ("foo:no_bar", None),
130
      ("foo:bar=baz,bar=baz", None),
131
      ("foo",
132
       ("foo", {})),
133
      ("foo:bar=baz",
134
       ("foo", {"bar": "baz"})),
135
      ("no_foo:-1=baz,no_op=3",
136
       ("no_foo", {"-1": "baz", "no_op": "3"})),
137
      ]
138
    for (arg, res) in test_cases:
139
      if res is None:
140
        self.assertRaises(ParameterError, self._csikv, arg)
141
      else:
142
        self.assertEqual(self._csikv(arg), res)
143

    
144

    
145
class TestListIdentKeyVal(unittest.TestCase):
146
  """Test for cli.check_list_ident_key_val()"""
147

    
148
  @staticmethod
149
  def _clikv(value):
150
    return cli.check_list_ident_key_val("option", "opt", value)
151

    
152
  def testListIdentKeyVal(self):
153
    test_cases = [
154
      ("",
155
       None),
156
      ("foo",
157
       {"foo": {}}),
158
      ("foo:bar=baz",
159
       {"foo": {"bar": "baz"}}),
160
      ("foo:bar=baz/foo:bat=bad",
161
       None),
162
      ("foo:abc=42/bar:def=11",
163
       {"foo": {"abc": "42"},
164
        "bar": {"def": "11"}}),
165
      ("foo:abc=42/bar:def=11,ghi=07",
166
       {"foo": {"abc": "42"},
167
        "bar": {"def": "11", "ghi": "07"}}),
168
      ]
169
    for (arg, res) in test_cases:
170
      if res is None:
171
        self.assertRaises(ParameterError, self._clikv, arg)
172
      else:
173
        self.assertEqual(res, self._clikv(arg))
174

    
175

    
176
class TestToStream(unittest.TestCase):
177
  """Test the ToStream functions"""
178

    
179
  def testBasic(self):
180
    for data in ["foo",
181
                 "foo %s",
182
                 "foo %(test)s",
183
                 "foo %s %s",
184
                 "",
185
                 ]:
186
      buf = StringIO()
187
      cli._ToStream(buf, data)
188
      self.failUnlessEqual(buf.getvalue(), data + "\n")
189

    
190
  def testParams(self):
191
      buf = StringIO()
192
      cli._ToStream(buf, "foo %s", 1)
193
      self.failUnlessEqual(buf.getvalue(), "foo 1\n")
194
      buf = StringIO()
195
      cli._ToStream(buf, "foo %s", (15,16))
196
      self.failUnlessEqual(buf.getvalue(), "foo (15, 16)\n")
197
      buf = StringIO()
198
      cli._ToStream(buf, "foo %s %s", "a", "b")
199
      self.failUnlessEqual(buf.getvalue(), "foo a b\n")
200

    
201

    
202
class TestGenerateTable(unittest.TestCase):
203
  HEADERS = dict([("f%s" % i, "Field%s" % i) for i in range(5)])
204

    
205
  FIELDS1 = ["f1", "f2"]
206
  DATA1 = [
207
    ["abc", 1234],
208
    ["foobar", 56],
209
    ["b", -14],
210
    ]
211

    
212
  def _test(self, headers, fields, separator, data,
213
            numfields, unitfields, units, expected):
214
    table = cli.GenerateTable(headers, fields, separator, data,
215
                              numfields=numfields, unitfields=unitfields,
216
                              units=units)
217
    self.assertEqual(table, expected)
218

    
219
  def testPlain(self):
220
    exp = [
221
      "Field1 Field2",
222
      "abc    1234",
223
      "foobar 56",
224
      "b      -14",
225
      ]
226
    self._test(self.HEADERS, self.FIELDS1, None, self.DATA1,
227
               None, None, "m", exp)
228

    
229
  def testNoFields(self):
230
    self._test(self.HEADERS, [], None, [[], []],
231
               None, None, "m", ["", "", ""])
232
    self._test(None, [], None, [[], []],
233
               None, None, "m", ["", ""])
234

    
235
  def testSeparator(self):
236
    for sep in ["#", ":", ",", "^", "!", "%", "|", "###", "%%", "!!!", "||"]:
237
      exp = [
238
        "Field1%sField2" % sep,
239
        "abc%s1234" % sep,
240
        "foobar%s56" % sep,
241
        "b%s-14" % sep,
242
        ]
243
      self._test(self.HEADERS, self.FIELDS1, sep, self.DATA1,
244
                 None, None, "m", exp)
245

    
246
  def testNoHeader(self):
247
    exp = [
248
      "abc    1234",
249
      "foobar 56",
250
      "b      -14",
251
      ]
252
    self._test(None, self.FIELDS1, None, self.DATA1,
253
               None, None, "m", exp)
254

    
255
  def testUnknownField(self):
256
    headers = {
257
      "f1": "Field1",
258
      }
259
    exp = [
260
      "Field1 UNKNOWN",
261
      "abc    1234",
262
      "foobar 56",
263
      "b      -14",
264
      ]
265
    self._test(headers, ["f1", "UNKNOWN"], None, self.DATA1,
266
               None, None, "m", exp)
267

    
268
  def testNumfields(self):
269
    fields = ["f1", "f2", "f3"]
270
    data = [
271
      ["abc", 1234, 0],
272
      ["foobar", 56, 3],
273
      ["b", -14, "-"],
274
      ]
275
    exp = [
276
      "Field1 Field2 Field3",
277
      "abc      1234      0",
278
      "foobar     56      3",
279
      "b         -14      -",
280
      ]
281
    self._test(self.HEADERS, fields, None, data,
282
               ["f2", "f3"], None, "m", exp)
283

    
284
  def testUnitfields(self):
285
    expnosep = [
286
      "Field1 Field2 Field3",
287
      "abc      1234     0M",
288
      "foobar     56     3M",
289
      "b         -14      -",
290
      ]
291

    
292
    expsep = [
293
      "Field1:Field2:Field3",
294
      "abc:1234:0M",
295
      "foobar:56:3M",
296
      "b:-14:-",
297
      ]
298

    
299
    for sep, expected in [(None, expnosep), (":", expsep)]:
300
      fields = ["f1", "f2", "f3"]
301
      data = [
302
        ["abc", 1234, 0],
303
        ["foobar", 56, 3],
304
        ["b", -14, "-"],
305
        ]
306
      self._test(self.HEADERS, fields, sep, data,
307
                 ["f2", "f3"], ["f3"], "h", expected)
308

    
309
  def testUnusual(self):
310
    data = [
311
      ["%", "xyz"],
312
      ["%%", "abc"],
313
      ]
314
    exp = [
315
      "Field1 Field2",
316
      "%      xyz",
317
      "%%     abc",
318
      ]
319
    self._test(self.HEADERS, ["f1", "f2"], None, data,
320
               None, None, "m", exp)
321

    
322

    
323
class TestFormatQueryResult(unittest.TestCase):
324
  def test(self):
325
    fields = [
326
      objects.QueryFieldDefinition(name="name", title="Name",
327
                                   kind=constants.QFT_TEXT),
328
      objects.QueryFieldDefinition(name="size", title="Size",
329
                                   kind=constants.QFT_NUMBER),
330
      objects.QueryFieldDefinition(name="act", title="Active",
331
                                   kind=constants.QFT_BOOL),
332
      objects.QueryFieldDefinition(name="mem", title="Memory",
333
                                   kind=constants.QFT_UNIT),
334
      objects.QueryFieldDefinition(name="other", title="SomeList",
335
                                   kind=constants.QFT_OTHER),
336
      ]
337

    
338
    response = objects.QueryResponse(fields=fields, data=[
339
      [(constants.RS_NORMAL, "nodeA"), (constants.RS_NORMAL, 128),
340
       (constants.RS_NORMAL, False), (constants.RS_NORMAL, 1468006),
341
       (constants.RS_NORMAL, [])],
342
      [(constants.RS_NORMAL, "other"), (constants.RS_NORMAL, 512),
343
       (constants.RS_NORMAL, True), (constants.RS_NORMAL, 16),
344
       (constants.RS_NORMAL, [1, 2, 3])],
345
      [(constants.RS_NORMAL, "xyz"), (constants.RS_NORMAL, 1024),
346
       (constants.RS_NORMAL, True), (constants.RS_NORMAL, 4096),
347
       (constants.RS_NORMAL, [{}, {}])],
348
      ])
349

    
350
    self.assertEqual(cli.FormatQueryResult(response, unit="h", header=True),
351
      (cli.QR_NORMAL, [
352
      "Name  Size Active Memory SomeList",
353
      "nodeA  128 N        1.4T []",
354
      "other  512 Y         16M [1, 2, 3]",
355
      "xyz   1024 Y        4.0G [{}, {}]",
356
      ]))
357

    
358
  def testTimestampAndUnit(self):
359
    fields = [
360
      objects.QueryFieldDefinition(name="name", title="Name",
361
                                   kind=constants.QFT_TEXT),
362
      objects.QueryFieldDefinition(name="size", title="Size",
363
                                   kind=constants.QFT_UNIT),
364
      objects.QueryFieldDefinition(name="mtime", title="ModTime",
365
                                   kind=constants.QFT_TIMESTAMP),
366
      ]
367

    
368
    response = objects.QueryResponse(fields=fields, data=[
369
      [(constants.RS_NORMAL, "a"), (constants.RS_NORMAL, 1024),
370
       (constants.RS_NORMAL, 0)],
371
      [(constants.RS_NORMAL, "b"), (constants.RS_NORMAL, 144996),
372
       (constants.RS_NORMAL, 1291746295)],
373
      ])
374

    
375
    self.assertEqual(cli.FormatQueryResult(response, unit="m", header=True),
376
      (cli.QR_NORMAL, [
377
      "Name   Size ModTime",
378
      "a      1024 %s" % utils.FormatTime(0),
379
      "b    144996 %s" % utils.FormatTime(1291746295),
380
      ]))
381

    
382
  def testOverride(self):
383
    fields = [
384
      objects.QueryFieldDefinition(name="name", title="Name",
385
                                   kind=constants.QFT_TEXT),
386
      objects.QueryFieldDefinition(name="cust", title="Custom",
387
                                   kind=constants.QFT_OTHER),
388
      objects.QueryFieldDefinition(name="xt", title="XTime",
389
                                   kind=constants.QFT_TIMESTAMP),
390
      ]
391

    
392
    response = objects.QueryResponse(fields=fields, data=[
393
      [(constants.RS_NORMAL, "x"), (constants.RS_NORMAL, ["a", "b", "c"]),
394
       (constants.RS_NORMAL, 1234)],
395
      [(constants.RS_NORMAL, "y"), (constants.RS_NORMAL, range(10)),
396
       (constants.RS_NORMAL, 1291746295)],
397
      ])
398

    
399
    override = {
400
      "cust": (utils.CommaJoin, False),
401
      "xt": (hex, True),
402
      }
403

    
404
    self.assertEqual(cli.FormatQueryResult(response, unit="h", header=True,
405
                                           format_override=override),
406
      (cli.QR_NORMAL, [
407
      "Name Custom                            XTime",
408
      "x    a, b, c                           0x4d2",
409
      "y    0, 1, 2, 3, 4, 5, 6, 7, 8, 9 0x4cfe7bf7",
410
      ]))
411

    
412
  def testSeparator(self):
413
    fields = [
414
      objects.QueryFieldDefinition(name="name", title="Name",
415
                                   kind=constants.QFT_TEXT),
416
      objects.QueryFieldDefinition(name="count", title="Count",
417
                                   kind=constants.QFT_NUMBER),
418
      objects.QueryFieldDefinition(name="desc", title="Description",
419
                                   kind=constants.QFT_TEXT),
420
      ]
421

    
422
    response = objects.QueryResponse(fields=fields, data=[
423
      [(constants.RS_NORMAL, "instance1.example.com"),
424
       (constants.RS_NORMAL, 21125), (constants.RS_NORMAL, "Hello World!")],
425
      [(constants.RS_NORMAL, "mail.other.net"),
426
       (constants.RS_NORMAL, -9000), (constants.RS_NORMAL, "a,b,c")],
427
      ])
428

    
429
    for sep in [":", "|", "#", "|||", "###", "@@@", "@#@"]:
430
      for header in [None, "Name%sCount%sDescription" % (sep, sep)]:
431
        exp = []
432
        if header:
433
          exp.append(header)
434
        exp.extend([
435
          "instance1.example.com%s21125%sHello World!" % (sep, sep),
436
          "mail.other.net%s-9000%sa,b,c" % (sep, sep),
437
          ])
438

    
439
        self.assertEqual(cli.FormatQueryResult(response, separator=sep,
440
                                               header=bool(header)),
441
                         (cli.QR_NORMAL, exp))
442

    
443
  def testStatusWithUnknown(self):
444
    fields = [
445
      objects.QueryFieldDefinition(name="id", title="ID",
446
                                   kind=constants.QFT_NUMBER),
447
      objects.QueryFieldDefinition(name="unk", title="unk",
448
                                   kind=constants.QFT_UNKNOWN),
449
      objects.QueryFieldDefinition(name="unavail", title="Unavail",
450
                                   kind=constants.QFT_BOOL),
451
      objects.QueryFieldDefinition(name="nodata", title="NoData",
452
                                   kind=constants.QFT_TEXT),
453
      objects.QueryFieldDefinition(name="offline", title="OffLine",
454
                                   kind=constants.QFT_TEXT),
455
      ]
456

    
457
    response = objects.QueryResponse(fields=fields, data=[
458
      [(constants.RS_NORMAL, 1), (constants.RS_UNKNOWN, None),
459
       (constants.RS_NORMAL, False), (constants.RS_NORMAL, ""),
460
       (constants.RS_OFFLINE, None)],
461
      [(constants.RS_NORMAL, 2), (constants.RS_UNKNOWN, None),
462
       (constants.RS_NODATA, None), (constants.RS_NORMAL, "x"),
463
       (constants.RS_OFFLINE, None)],
464
      [(constants.RS_NORMAL, 3), (constants.RS_UNKNOWN, None),
465
       (constants.RS_NORMAL, False), (constants.RS_UNAVAIL, None),
466
       (constants.RS_OFFLINE, None)],
467
      ])
468

    
469
    self.assertEqual(cli.FormatQueryResult(response, header=True,
470
                                           separator="|", verbose=True),
471
      (cli.QR_UNKNOWN, [
472
      "ID|unk|Unavail|NoData|OffLine",
473
      "1|(unknown)|N||(offline)",
474
      "2|(unknown)|(nodata)|x|(offline)",
475
      "3|(unknown)|N|(unavail)|(offline)",
476
      ]))
477
    self.assertEqual(cli.FormatQueryResult(response, header=True,
478
                                           separator="|", verbose=False),
479
      (cli.QR_UNKNOWN, [
480
      "ID|unk|Unavail|NoData|OffLine",
481
      "1|??|N||*",
482
      "2|??|?|x|*",
483
      "3|??|N|-|*",
484
      ]))
485

    
486
  def testNoData(self):
487
    fields = [
488
      objects.QueryFieldDefinition(name="id", title="ID",
489
                                   kind=constants.QFT_NUMBER),
490
      objects.QueryFieldDefinition(name="name", title="Name",
491
                                   kind=constants.QFT_TEXT),
492
      ]
493

    
494
    response = objects.QueryResponse(fields=fields, data=[])
495

    
496
    self.assertEqual(cli.FormatQueryResult(response, header=True),
497
                     (cli.QR_NORMAL, ["ID Name"]))
498

    
499
  def testNoDataWithUnknown(self):
500
    fields = [
501
      objects.QueryFieldDefinition(name="id", title="ID",
502
                                   kind=constants.QFT_NUMBER),
503
      objects.QueryFieldDefinition(name="unk", title="unk",
504
                                   kind=constants.QFT_UNKNOWN),
505
      ]
506

    
507
    response = objects.QueryResponse(fields=fields, data=[])
508

    
509
    self.assertEqual(cli.FormatQueryResult(response, header=False),
510
                     (cli.QR_UNKNOWN, []))
511

    
512
  def testStatus(self):
513
    fields = [
514
      objects.QueryFieldDefinition(name="id", title="ID",
515
                                   kind=constants.QFT_NUMBER),
516
      objects.QueryFieldDefinition(name="unavail", title="Unavail",
517
                                   kind=constants.QFT_BOOL),
518
      objects.QueryFieldDefinition(name="nodata", title="NoData",
519
                                   kind=constants.QFT_TEXT),
520
      objects.QueryFieldDefinition(name="offline", title="OffLine",
521
                                   kind=constants.QFT_TEXT),
522
      ]
523

    
524
    response = objects.QueryResponse(fields=fields, data=[
525
      [(constants.RS_NORMAL, 1), (constants.RS_NORMAL, False),
526
       (constants.RS_NORMAL, ""), (constants.RS_OFFLINE, None)],
527
      [(constants.RS_NORMAL, 2), (constants.RS_NODATA, None),
528
       (constants.RS_NORMAL, "x"), (constants.RS_NORMAL, "abc")],
529
      [(constants.RS_NORMAL, 3), (constants.RS_NORMAL, False),
530
       (constants.RS_UNAVAIL, None), (constants.RS_OFFLINE, None)],
531
      ])
532

    
533
    self.assertEqual(cli.FormatQueryResult(response, header=False,
534
                                           separator="|", verbose=True),
535
      (cli.QR_INCOMPLETE, [
536
      "1|N||(offline)",
537
      "2|(nodata)|x|abc",
538
      "3|N|(unavail)|(offline)",
539
      ]))
540
    self.assertEqual(cli.FormatQueryResult(response, header=False,
541
                                           separator="|", verbose=False),
542
      (cli.QR_INCOMPLETE, [
543
      "1|N||*",
544
      "2|?|x|abc",
545
      "3|N|-|*",
546
      ]))
547

    
548
  def testInvalidFieldType(self):
549
    fields = [
550
      objects.QueryFieldDefinition(name="x", title="x",
551
                                   kind="#some#other#type"),
552
      ]
553

    
554
    response = objects.QueryResponse(fields=fields, data=[])
555

    
556
    self.assertRaises(NotImplementedError, cli.FormatQueryResult, response)
557

    
558
  def testInvalidFieldStatus(self):
559
    fields = [
560
      objects.QueryFieldDefinition(name="x", title="x",
561
                                   kind=constants.QFT_TEXT),
562
      ]
563

    
564
    response = objects.QueryResponse(fields=fields, data=[[(-1, None)]])
565
    self.assertRaises(NotImplementedError, cli.FormatQueryResult, response)
566

    
567
    response = objects.QueryResponse(fields=fields, data=[[(-1, "x")]])
568
    self.assertRaises(AssertionError, cli.FormatQueryResult, response)
569

    
570
  def testEmptyFieldTitle(self):
571
    fields = [
572
      objects.QueryFieldDefinition(name="x", title="",
573
                                   kind=constants.QFT_TEXT),
574
      ]
575

    
576
    response = objects.QueryResponse(fields=fields, data=[])
577
    self.assertRaises(AssertionError, cli.FormatQueryResult, response)
578

    
579

    
580
class _MockJobPollCb(cli.JobPollCbBase, cli.JobPollReportCbBase):
581
  def __init__(self, tc, job_id):
582
    self.tc = tc
583
    self.job_id = job_id
584
    self._wfjcr = []
585
    self._jobstatus = []
586
    self._expect_notchanged = False
587
    self._expect_log = []
588

    
589
  def CheckEmpty(self):
590
    self.tc.assertFalse(self._wfjcr)
591
    self.tc.assertFalse(self._jobstatus)
592
    self.tc.assertFalse(self._expect_notchanged)
593
    self.tc.assertFalse(self._expect_log)
594

    
595
  def AddWfjcResult(self, *args):
596
    self._wfjcr.append(args)
597

    
598
  def AddQueryJobsResult(self, *args):
599
    self._jobstatus.append(args)
600

    
601
  def WaitForJobChangeOnce(self, job_id, fields,
602
                           prev_job_info, prev_log_serial):
603
    self.tc.assertEqual(job_id, self.job_id)
604
    self.tc.assertEqualValues(fields, ["status"])
605
    self.tc.assertFalse(self._expect_notchanged)
606
    self.tc.assertFalse(self._expect_log)
607

    
608
    (exp_prev_job_info, exp_prev_log_serial, result) = self._wfjcr.pop(0)
609
    self.tc.assertEqualValues(prev_job_info, exp_prev_job_info)
610
    self.tc.assertEqual(prev_log_serial, exp_prev_log_serial)
611

    
612
    if result == constants.JOB_NOTCHANGED:
613
      self._expect_notchanged = True
614
    elif result:
615
      (_, logmsgs) = result
616
      if logmsgs:
617
        self._expect_log.extend(logmsgs)
618

    
619
    return result
620

    
621
  def QueryJobs(self, job_ids, fields):
622
    self.tc.assertEqual(job_ids, [self.job_id])
623
    self.tc.assertEqualValues(fields, ["status", "opstatus", "opresult"])
624
    self.tc.assertFalse(self._expect_notchanged)
625
    self.tc.assertFalse(self._expect_log)
626

    
627
    result = self._jobstatus.pop(0)
628
    self.tc.assertEqual(len(fields), len(result))
629
    return [result]
630

    
631
  def ReportLogMessage(self, job_id, serial, timestamp, log_type, log_msg):
632
    self.tc.assertEqual(job_id, self.job_id)
633
    self.tc.assertEqualValues((serial, timestamp, log_type, log_msg),
634
                              self._expect_log.pop(0))
635

    
636
  def ReportNotChanged(self, job_id, status):
637
    self.tc.assertEqual(job_id, self.job_id)
638
    self.tc.assert_(self._expect_notchanged)
639
    self._expect_notchanged = False
640

    
641

    
642
class TestGenericPollJob(testutils.GanetiTestCase):
643
  def testSuccessWithLog(self):
644
    job_id = 29609
645
    cbs = _MockJobPollCb(self, job_id)
646

    
647
    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
648

    
649
    cbs.AddWfjcResult(None, None,
650
                      ((constants.JOB_STATUS_QUEUED, ), None))
651

    
652
    cbs.AddWfjcResult((constants.JOB_STATUS_QUEUED, ), None,
653
                      constants.JOB_NOTCHANGED)
654

    
655
    cbs.AddWfjcResult((constants.JOB_STATUS_QUEUED, ), None,
656
                      ((constants.JOB_STATUS_RUNNING, ),
657
                       [(1, utils.SplitTime(1273491611.0),
658
                         constants.ELOG_MESSAGE, "Step 1"),
659
                        (2, utils.SplitTime(1273491615.9),
660
                         constants.ELOG_MESSAGE, "Step 2"),
661
                        (3, utils.SplitTime(1273491625.02),
662
                         constants.ELOG_MESSAGE, "Step 3"),
663
                        (4, utils.SplitTime(1273491635.05),
664
                         constants.ELOG_MESSAGE, "Step 4"),
665
                        (37, utils.SplitTime(1273491645.0),
666
                         constants.ELOG_MESSAGE, "Step 5"),
667
                        (203, utils.SplitTime(127349155.0),
668
                         constants.ELOG_MESSAGE, "Step 6")]))
669

    
670
    cbs.AddWfjcResult((constants.JOB_STATUS_RUNNING, ), 203,
671
                      ((constants.JOB_STATUS_RUNNING, ),
672
                       [(300, utils.SplitTime(1273491711.01),
673
                         constants.ELOG_MESSAGE, "Step X"),
674
                        (302, utils.SplitTime(1273491815.8),
675
                         constants.ELOG_MESSAGE, "Step Y"),
676
                        (303, utils.SplitTime(1273491925.32),
677
                         constants.ELOG_MESSAGE, "Step Z")]))
678

    
679
    cbs.AddWfjcResult((constants.JOB_STATUS_RUNNING, ), 303,
680
                      ((constants.JOB_STATUS_SUCCESS, ), None))
681

    
682
    cbs.AddQueryJobsResult(constants.JOB_STATUS_SUCCESS,
683
                           [constants.OP_STATUS_SUCCESS,
684
                            constants.OP_STATUS_SUCCESS],
685
                           ["Hello World", "Foo man bar"])
686

    
687
    self.assertEqual(["Hello World", "Foo man bar"],
688
                     cli.GenericPollJob(job_id, cbs, cbs))
689
    cbs.CheckEmpty()
690

    
691
  def testJobLost(self):
692
    job_id = 13746
693

    
694
    cbs = _MockJobPollCb(self, job_id)
695
    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
696
    cbs.AddWfjcResult(None, None, None)
697
    self.assertRaises(errors.JobLost, cli.GenericPollJob, job_id, cbs, cbs)
698
    cbs.CheckEmpty()
699

    
700
  def testError(self):
701
    job_id = 31088
702

    
703
    cbs = _MockJobPollCb(self, job_id)
704
    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
705
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
706
    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
707
                           [constants.OP_STATUS_SUCCESS,
708
                            constants.OP_STATUS_ERROR],
709
                           ["Hello World", "Error code 123"])
710
    self.assertRaises(errors.OpExecError, cli.GenericPollJob, job_id, cbs, cbs)
711
    cbs.CheckEmpty()
712

    
713
  def testError2(self):
714
    job_id = 22235
715

    
716
    cbs = _MockJobPollCb(self, job_id)
717
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
718
    encexc = errors.EncodeException(errors.LockError("problem"))
719
    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
720
                           [constants.OP_STATUS_ERROR], [encexc])
721
    self.assertRaises(errors.LockError, cli.GenericPollJob, job_id, cbs, cbs)
722
    cbs.CheckEmpty()
723

    
724
  def testWeirdError(self):
725
    job_id = 28847
726

    
727
    cbs = _MockJobPollCb(self, job_id)
728
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_ERROR, ), None))
729
    cbs.AddQueryJobsResult(constants.JOB_STATUS_ERROR,
730
                           [constants.OP_STATUS_RUNNING,
731
                            constants.OP_STATUS_RUNNING],
732
                           [None, None])
733
    self.assertRaises(errors.OpExecError, cli.GenericPollJob, job_id, cbs, cbs)
734
    cbs.CheckEmpty()
735

    
736
  def testCancel(self):
737
    job_id = 4275
738

    
739
    cbs = _MockJobPollCb(self, job_id)
740
    cbs.AddWfjcResult(None, None, constants.JOB_NOTCHANGED)
741
    cbs.AddWfjcResult(None, None, ((constants.JOB_STATUS_CANCELING, ), None))
742
    cbs.AddQueryJobsResult(constants.JOB_STATUS_CANCELING,
743
                           [constants.OP_STATUS_CANCELING,
744
                            constants.OP_STATUS_CANCELING],
745
                           [None, None])
746
    self.assertRaises(errors.OpExecError, cli.GenericPollJob, job_id, cbs, cbs)
747
    cbs.CheckEmpty()
748

    
749

    
750
class TestFormatLogMessage(unittest.TestCase):
751
  def test(self):
752
    self.assertEqual(cli.FormatLogMessage(constants.ELOG_MESSAGE,
753
                                          "Hello World"),
754
                     "Hello World")
755
    self.assertRaises(TypeError, cli.FormatLogMessage,
756
                      constants.ELOG_MESSAGE, [1, 2, 3])
757

    
758
    self.assert_(cli.FormatLogMessage("some other type", (1, 2, 3)))
759

    
760

    
761
class TestParseFields(unittest.TestCase):
762
  def test(self):
763
    self.assertEqual(cli.ParseFields(None, []), [])
764
    self.assertEqual(cli.ParseFields("name,foo,hello", []),
765
                     ["name", "foo", "hello"])
766
    self.assertEqual(cli.ParseFields(None, ["def", "ault", "fields", "here"]),
767
                     ["def", "ault", "fields", "here"])
768
    self.assertEqual(cli.ParseFields("name,foo", ["def", "ault"]),
769
                     ["name", "foo"])
770
    self.assertEqual(cli.ParseFields("+name,foo", ["def", "ault"]),
771
                     ["def", "ault", "name", "foo"])
772

    
773

    
774
class TestConstants(unittest.TestCase):
775
  def testPriority(self):
776
    self.assertEqual(set(cli._PRIONAME_TO_VALUE.values()),
777
                     set(constants.OP_PRIO_SUBMIT_VALID))
778
    self.assertEqual(list(value for _, value in cli._PRIORITY_NAMES),
779
                     sorted(constants.OP_PRIO_SUBMIT_VALID, reverse=True))
780

    
781

    
782
class TestParseNicOption(unittest.TestCase):
783
  def test(self):
784
    self.assertEqual(cli.ParseNicOption([("0", { "link": "eth0", })]),
785
                     [{ "link": "eth0", }])
786
    self.assertEqual(cli.ParseNicOption([("5", { "ip": "192.0.2.7", })]),
787
                     [{}, {}, {}, {}, {}, { "ip": "192.0.2.7", }])
788

    
789
  def testErrors(self):
790
    for i in [None, "", "abc", "zero", "Hello World", "\0", []]:
791
      self.assertRaises(errors.OpPrereqError, cli.ParseNicOption,
792
                        [(i, { "link": "eth0", })])
793
      self.assertRaises(errors.OpPrereqError, cli.ParseNicOption,
794
                        [("0", i)])
795

    
796
    self.assertRaises(errors.TypeEnforcementError, cli.ParseNicOption,
797
                      [(0, { True: False, })])
798

    
799
    self.assertRaises(errors.TypeEnforcementError, cli.ParseNicOption,
800
                      [(3, { "mode": [], })])
801

    
802

    
803
class TestFormatResultError(unittest.TestCase):
804
  def testNormal(self):
805
    for verbose in [False, True]:
806
      self.assertRaises(AssertionError, cli.FormatResultError,
807
                        constants.RS_NORMAL, verbose)
808

    
809
  def testUnknown(self):
810
    for verbose in [False, True]:
811
      self.assertRaises(NotImplementedError, cli.FormatResultError,
812
                        "#some!other!status#", verbose)
813

    
814
  def test(self):
815
    for status in constants.RS_ALL:
816
      if status == constants.RS_NORMAL:
817
        continue
818

    
819
      self.assertNotEqual(cli.FormatResultError(status, False),
820
                          cli.FormatResultError(status, True))
821

    
822
      result = cli.FormatResultError(status, True)
823
      self.assertTrue(result.startswith("("))
824
      self.assertTrue(result.endswith(")"))
825

    
826

    
827
class TestGetOnlineNodes(unittest.TestCase):
828
  class _FakeClient:
829
    def __init__(self):
830
      self._query = []
831

    
832
    def AddQueryResult(self, *args):
833
      self._query.append(args)
834

    
835
    def CountPending(self):
836
      return len(self._query)
837

    
838
    def Query(self, res, fields, qfilter):
839
      if res != constants.QR_NODE:
840
        raise Exception("Querying wrong resource")
841

    
842
      (exp_fields, check_filter, result) = self._query.pop(0)
843

    
844
      if exp_fields != fields:
845
        raise Exception("Expected fields %s, got %s" % (exp_fields, fields))
846

    
847
      if not (qfilter is None or check_filter(qfilter)):
848
        raise Exception("Filter doesn't match expectations")
849

    
850
      return objects.QueryResponse(fields=None, data=result)
851

    
852
  def testEmpty(self):
853
    cl = self._FakeClient()
854

    
855
    cl.AddQueryResult(["name", "offline", "sip"], None, [])
856
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl), [])
857
    self.assertEqual(cl.CountPending(), 0)
858

    
859
  def testNoSpecialFilter(self):
860
    cl = self._FakeClient()
861

    
862
    cl.AddQueryResult(["name", "offline", "sip"], None, [
863
      [(constants.RS_NORMAL, "master.example.com"),
864
       (constants.RS_NORMAL, False),
865
       (constants.RS_NORMAL, "192.0.2.1")],
866
      [(constants.RS_NORMAL, "node2.example.com"),
867
       (constants.RS_NORMAL, False),
868
       (constants.RS_NORMAL, "192.0.2.2")],
869
      ])
870
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl),
871
                     ["master.example.com", "node2.example.com"])
872
    self.assertEqual(cl.CountPending(), 0)
873

    
874
  def testNoMaster(self):
875
    cl = self._FakeClient()
876

    
877
    def _CheckFilter(qfilter):
878
      self.assertEqual(qfilter, [qlang.OP_NOT, [qlang.OP_TRUE, "master"]])
879
      return True
880

    
881
    cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
882
      [(constants.RS_NORMAL, "node2.example.com"),
883
       (constants.RS_NORMAL, False),
884
       (constants.RS_NORMAL, "192.0.2.2")],
885
      ])
886
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl, filter_master=True),
887
                     ["node2.example.com"])
888
    self.assertEqual(cl.CountPending(), 0)
889

    
890
  def testSecondaryIpAddress(self):
891
    cl = self._FakeClient()
892

    
893
    cl.AddQueryResult(["name", "offline", "sip"], None, [
894
      [(constants.RS_NORMAL, "master.example.com"),
895
       (constants.RS_NORMAL, False),
896
       (constants.RS_NORMAL, "192.0.2.1")],
897
      [(constants.RS_NORMAL, "node2.example.com"),
898
       (constants.RS_NORMAL, False),
899
       (constants.RS_NORMAL, "192.0.2.2")],
900
      ])
901
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl, secondary_ips=True),
902
                     ["192.0.2.1", "192.0.2.2"])
903
    self.assertEqual(cl.CountPending(), 0)
904

    
905
  def testNoMasterFilterNodeName(self):
906
    cl = self._FakeClient()
907

    
908
    def _CheckFilter(qfilter):
909
      self.assertEqual(qfilter,
910
        [qlang.OP_AND,
911
         [qlang.OP_OR] + [[qlang.OP_EQUAL, "name", name]
912
                          for name in ["node2", "node3"]],
913
         [qlang.OP_NOT, [qlang.OP_TRUE, "master"]]])
914
      return True
915

    
916
    cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
917
      [(constants.RS_NORMAL, "node2.example.com"),
918
       (constants.RS_NORMAL, False),
919
       (constants.RS_NORMAL, "192.0.2.12")],
920
      [(constants.RS_NORMAL, "node3.example.com"),
921
       (constants.RS_NORMAL, False),
922
       (constants.RS_NORMAL, "192.0.2.13")],
923
      ])
924
    self.assertEqual(cli.GetOnlineNodes(["node2", "node3"], cl=cl,
925
                                        secondary_ips=True, filter_master=True),
926
                     ["192.0.2.12", "192.0.2.13"])
927
    self.assertEqual(cl.CountPending(), 0)
928

    
929
  def testOfflineNodes(self):
930
    cl = self._FakeClient()
931

    
932
    cl.AddQueryResult(["name", "offline", "sip"], None, [
933
      [(constants.RS_NORMAL, "master.example.com"),
934
       (constants.RS_NORMAL, False),
935
       (constants.RS_NORMAL, "192.0.2.1")],
936
      [(constants.RS_NORMAL, "node2.example.com"),
937
       (constants.RS_NORMAL, True),
938
       (constants.RS_NORMAL, "192.0.2.2")],
939
      [(constants.RS_NORMAL, "node3.example.com"),
940
       (constants.RS_NORMAL, True),
941
       (constants.RS_NORMAL, "192.0.2.3")],
942
      ])
943
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl, nowarn=True),
944
                     ["master.example.com"])
945
    self.assertEqual(cl.CountPending(), 0)
946

    
947
  def testNodeGroup(self):
948
    cl = self._FakeClient()
949

    
950
    def _CheckFilter(qfilter):
951
      self.assertEqual(qfilter,
952
        [qlang.OP_OR, [qlang.OP_EQUAL, "group", "foobar"],
953
                      [qlang.OP_EQUAL, "group.uuid", "foobar"]])
954
      return True
955

    
956
    cl.AddQueryResult(["name", "offline", "sip"], _CheckFilter, [
957
      [(constants.RS_NORMAL, "master.example.com"),
958
       (constants.RS_NORMAL, False),
959
       (constants.RS_NORMAL, "192.0.2.1")],
960
      [(constants.RS_NORMAL, "node3.example.com"),
961
       (constants.RS_NORMAL, False),
962
       (constants.RS_NORMAL, "192.0.2.3")],
963
      ])
964
    self.assertEqual(cli.GetOnlineNodes(None, cl=cl, nodegroup="foobar"),
965
                     ["master.example.com", "node3.example.com"])
966
    self.assertEqual(cl.CountPending(), 0)
967

    
968

    
969
class TestFormatTimestamp(unittest.TestCase):
970
  def testGood(self):
971
    self.assertEqual(cli.FormatTimestamp((0, 1)),
972
                     time.strftime("%F %T", time.localtime(0)) + ".000001")
973
    self.assertEqual(cli.FormatTimestamp((1332944009, 17376)),
974
                     (time.strftime("%F %T", time.localtime(1332944009)) +
975
                      ".017376"))
976

    
977
  def testWrong(self):
978
    for i in [0, [], {}, "", [1]]:
979
      self.assertEqual(cli.FormatTimestamp(i), "?")
980

    
981

    
982
class TestFormatUsage(unittest.TestCase):
983
  def test(self):
984
    binary = "gnt-unittest"
985
    commands = {
986
      "cmdA":
987
        (NotImplemented, NotImplemented, NotImplemented, NotImplemented,
988
         "description of A"),
989
      "bbb":
990
        (NotImplemented, NotImplemented, NotImplemented, NotImplemented,
991
         "Hello World," * 10),
992
      "longname":
993
        (NotImplemented, NotImplemented, NotImplemented, NotImplemented,
994
         "Another description"),
995
      }
996

    
997
    self.assertEqual(list(cli._FormatUsage(binary, commands)), [
998
      "Usage: gnt-unittest {command} [options...] [argument...]",
999
      "gnt-unittest <command> --help to see details, or man gnt-unittest",
1000
      "",
1001
      "Commands:",
1002
      (" bbb      - Hello World,Hello World,Hello World,Hello World,Hello"
1003
       " World,Hello"),
1004
      "            World,Hello World,Hello World,Hello World,Hello World,",
1005
      " cmdA     - description of A",
1006
      " longname - Another description",
1007
      "",
1008
      ])
1009

    
1010

    
1011
class TestParseArgs(unittest.TestCase):
1012
  def testNoArguments(self):
1013
    for argv in [[], ["gnt-unittest"]]:
1014
      try:
1015
        cli._ParseArgs("gnt-unittest", argv, {}, {}, set())
1016
      except cli._ShowUsage, err:
1017
        self.assertTrue(err.exit_error)
1018
      else:
1019
        self.fail("Did not raise exception")
1020

    
1021
  def testVersion(self):
1022
    for argv in [["test", "--version"], ["test", "--version", "somethingelse"]]:
1023
      try:
1024
        cli._ParseArgs("test", argv, {}, {}, set())
1025
      except cli._ShowVersion:
1026
        pass
1027
      else:
1028
        self.fail("Did not raise exception")
1029

    
1030
  def testHelp(self):
1031
    for argv in [["test", "--help"], ["test", "--help", "somethingelse"]]:
1032
      try:
1033
        cli._ParseArgs("test", argv, {}, {}, set())
1034
      except cli._ShowUsage, err:
1035
        self.assertFalse(err.exit_error)
1036
      else:
1037
        self.fail("Did not raise exception")
1038

    
1039
  def testUnknownCommandOrAlias(self):
1040
    for argv in [["test", "list"], ["test", "somethingelse", "--help"]]:
1041
      try:
1042
        cli._ParseArgs("test", argv, {}, {}, set())
1043
      except cli._ShowUsage, err:
1044
        self.assertTrue(err.exit_error)
1045
      else:
1046
        self.fail("Did not raise exception")
1047

    
1048
  def testInvalidAliasList(self):
1049
    cmd = {
1050
      "list": NotImplemented,
1051
      "foo": NotImplemented,
1052
      }
1053
    aliases = {
1054
      "list": NotImplemented,
1055
      "foo": NotImplemented,
1056
      }
1057
    assert sorted(cmd.keys()) == sorted(aliases.keys())
1058
    self.assertRaises(AssertionError, cli._ParseArgs, "test",
1059
                      ["test", "list"], cmd, aliases, set())
1060

    
1061
  def testAliasForNonExistantCommand(self):
1062
    cmd = {}
1063
    aliases = {
1064
      "list": NotImplemented,
1065
      }
1066
    self.assertRaises(errors.ProgrammerError, cli._ParseArgs, "test",
1067
                      ["test", "list"], cmd, aliases, set())
1068

    
1069

    
1070
class TestQftNames(unittest.TestCase):
1071
  def testComplete(self):
1072
    self.assertEqual(frozenset(cli._QFT_NAMES), constants.QFT_ALL)
1073

    
1074
  def testUnique(self):
1075
    lcnames = map(lambda s: s.lower(), cli._QFT_NAMES.values())
1076
    self.assertFalse(utils.FindDuplicates(lcnames))
1077

    
1078
  def testUppercase(self):
1079
    for name in cli._QFT_NAMES.values():
1080
      self.assertEqual(name[0], name[0].upper())
1081

    
1082

    
1083
class TestFieldDescValues(unittest.TestCase):
1084
  def testKnownKind(self):
1085
    fdef = objects.QueryFieldDefinition(name="aname",
1086
                                        title="Atitle",
1087
                                        kind=constants.QFT_TEXT,
1088
                                        doc="aaa doc aaa")
1089
    self.assertEqual(cli._FieldDescValues(fdef),
1090
                     ["aname", "Text", "Atitle", "aaa doc aaa"])
1091

    
1092
  def testUnknownKind(self):
1093
    kind = "#foo#"
1094

    
1095
    self.assertFalse(kind in constants.QFT_ALL)
1096
    self.assertFalse(kind in cli._QFT_NAMES)
1097

    
1098
    fdef = objects.QueryFieldDefinition(name="zname", title="Ztitle",
1099
                                        kind=kind, doc="zzz doc zzz")
1100
    self.assertEqual(cli._FieldDescValues(fdef),
1101
                     ["zname", kind, "Ztitle", "zzz doc zzz"])
1102

    
1103

    
1104
class TestSerializeGenericInfo(unittest.TestCase):
1105
  """Test case for cli._SerializeGenericInfo"""
1106
  def _RunTest(self, data, expected):
1107
    buf = StringIO()
1108
    cli._SerializeGenericInfo(buf, data, 0)
1109
    self.assertEqual(buf.getvalue(), expected)
1110

    
1111
  def testSimple(self):
1112
    test_samples = [
1113
      ("abc", "abc\n"),
1114
      ([], "\n"),
1115
      ({}, "\n"),
1116
      (["1", "2", "3"], "- 1\n- 2\n- 3\n"),
1117
      ([("z", "26")], "z: 26\n"),
1118
      ({"z": "26"}, "z: 26\n"),
1119
      ([("z", "26"), ("a", "1")], "z: 26\na: 1\n"),
1120
      ({"z": "26", "a": "1"}, "a: 1\nz: 26\n"),
1121
      ]
1122
    for (data, expected) in test_samples:
1123
      self._RunTest(data, expected)
1124

    
1125
  def testLists(self):
1126
    adict = {
1127
      "aa": "11",
1128
      "bb": "22",
1129
      "cc": "33",
1130
      }
1131
    adict_exp = ("- aa: 11\n"
1132
                 "  bb: 22\n"
1133
                 "  cc: 33\n")
1134
    anobj = [
1135
      ("zz", "11"),
1136
      ("ww", "33"),
1137
      ("xx", "22"),
1138
      ]
1139
    anobj_exp = ("- zz: 11\n"
1140
                 "  ww: 33\n"
1141
                 "  xx: 22\n")
1142
    alist = ["aa", "cc", "bb"]
1143
    alist_exp = ("- - aa\n"
1144
                 "  - cc\n"
1145
                 "  - bb\n")
1146
    test_samples = [
1147
      (adict, adict_exp),
1148
      (anobj, anobj_exp),
1149
      (alist, alist_exp),
1150
      ]
1151
    for (base_data, base_expected) in test_samples:
1152
      for k in range(1, 4):
1153
        data = k * [base_data]
1154
        expected = k * base_expected
1155
        self._RunTest(data, expected)
1156

    
1157
  def testDictionaries(self):
1158
    data = [
1159
      ("aaa", ["x", "y"]),
1160
      ("bbb", {
1161
          "w": "1",
1162
          "z": "2",
1163
          }),
1164
      ("ccc", [
1165
          ("xyz", "123"),
1166
          ("efg", "456"),
1167
          ]),
1168
      ]
1169
    expected = ("aaa: \n"
1170
                "  - x\n"
1171
                "  - y\n"
1172
                "bbb: \n"
1173
                "  w: 1\n"
1174
                "  z: 2\n"
1175
                "ccc: \n"
1176
                "  xyz: 123\n"
1177
                "  efg: 456\n")
1178
    self._RunTest(data, expected)
1179
    self._RunTest(dict(data), expected)
1180

    
1181

    
1182
class TestCreateIPolicyFromOpts(unittest.TestCase):
1183
  """Test case for cli.CreateIPolicyFromOpts."""
1184
  def _RecursiveCheckMergedDicts(self, default_pol, diff_pol, merged_pol):
1185
    self.assertTrue(type(default_pol) is dict)
1186
    self.assertTrue(type(diff_pol) is dict)
1187
    self.assertTrue(type(merged_pol) is dict)
1188
    self.assertEqual(frozenset(default_pol.keys()),
1189
                     frozenset(merged_pol.keys()))
1190
    for (key, val) in merged_pol.items():
1191
      if key in diff_pol:
1192
        if type(val) is dict:
1193
          self._RecursiveCheckMergedDicts(default_pol[key], diff_pol[key], val)
1194
        else:
1195
          self.assertEqual(val, diff_pol[key])
1196
      else:
1197
        self.assertEqual(val, default_pol[key])
1198

    
1199
  def testClusterPolicy(self):
1200
    exp_pol0 = {
1201
      constants.ISPECS_MINMAX: {
1202
        constants.ISPECS_MIN: {},
1203
        constants.ISPECS_MAX: {},
1204
        },
1205
      constants.ISPECS_STD: {},
1206
      }
1207
    exp_pol1 = {
1208
      constants.ISPECS_MINMAX: {
1209
        constants.ISPECS_MIN: {
1210
          constants.ISPEC_CPU_COUNT: 2,
1211
          constants.ISPEC_DISK_COUNT: 1,
1212
          },
1213
        constants.ISPECS_MAX: {
1214
          constants.ISPEC_MEM_SIZE: 12*1024,
1215
          constants.ISPEC_DISK_COUNT: 2,
1216
          },
1217
        },
1218
      constants.ISPECS_STD: {
1219
        constants.ISPEC_CPU_COUNT: 2,
1220
        constants.ISPEC_DISK_COUNT: 2,
1221
        },
1222
      constants.IPOLICY_VCPU_RATIO: 3.1,
1223
      }
1224
    exp_pol2 = {
1225
      constants.ISPECS_MINMAX: {
1226
        constants.ISPECS_MIN: {
1227
          constants.ISPEC_DISK_SIZE: 512,
1228
          constants.ISPEC_NIC_COUNT: 2,
1229
          },
1230
        constants.ISPECS_MAX: {
1231
          constants.ISPEC_NIC_COUNT: 3,
1232
          },
1233
        },
1234
      constants.ISPECS_STD: {
1235
        constants.ISPEC_CPU_COUNT: 2,
1236
        constants.ISPEC_NIC_COUNT: 3,
1237
        },
1238
      constants.IPOLICY_SPINDLE_RATIO: 1.3,
1239
      constants.IPOLICY_DTS: ["templates"],
1240
      }
1241
    for fillall in [False, True]:
1242
      pol0 = cli.CreateIPolicyFromOpts(
1243
        ispecs_mem_size={},
1244
        ispecs_cpu_count={},
1245
        ispecs_disk_count={},
1246
        ispecs_disk_size={},
1247
        ispecs_nic_count={},
1248
        ipolicy_disk_templates=None,
1249
        ipolicy_vcpu_ratio=None,
1250
        ipolicy_spindle_ratio=None,
1251
        fill_all=fillall
1252
        )
1253
      if fillall:
1254
        self.assertEqual(pol0, constants.IPOLICY_DEFAULTS)
1255
      else:
1256
        self.assertEqual(pol0, exp_pol0)
1257
      pol1 = cli.CreateIPolicyFromOpts(
1258
        ispecs_mem_size={"max": "12g"},
1259
        ispecs_cpu_count={"min": 2, "std": 2},
1260
        ispecs_disk_count={"min": 1, "max": 2, "std": 2},
1261
        ispecs_disk_size={},
1262
        ispecs_nic_count={},
1263
        ipolicy_disk_templates=None,
1264
        ipolicy_vcpu_ratio=3.1,
1265
        ipolicy_spindle_ratio=None,
1266
        fill_all=fillall
1267
        )
1268
      if fillall:
1269
        self._RecursiveCheckMergedDicts(constants.IPOLICY_DEFAULTS,
1270
                                        exp_pol1, pol1)
1271
      else:
1272
        self.assertEqual(pol1, exp_pol1)
1273
      pol2 = cli.CreateIPolicyFromOpts(
1274
        ispecs_mem_size={},
1275
        ispecs_cpu_count={"std": 2},
1276
        ispecs_disk_count={},
1277
        ispecs_disk_size={"min": "0.5g"},
1278
        ispecs_nic_count={"min": 2, "max": 3, "std": 3},
1279
        ipolicy_disk_templates=["templates"],
1280
        ipolicy_vcpu_ratio=None,
1281
        ipolicy_spindle_ratio=1.3,
1282
        fill_all=fillall
1283
        )
1284
      if fillall:
1285
        self._RecursiveCheckMergedDicts(constants.IPOLICY_DEFAULTS,
1286
                                        exp_pol2, pol2)
1287
      else:
1288
        self.assertEqual(pol2, exp_pol2)
1289

    
1290
  def testInvalidPolicies(self):
1291
    self.assertRaises(errors.TypeEnforcementError, cli.CreateIPolicyFromOpts,
1292
                      ispecs_mem_size={}, ispecs_cpu_count={},
1293
                      ispecs_disk_count={}, ispecs_disk_size={"std": 1},
1294
                      ispecs_nic_count={}, ipolicy_disk_templates=None,
1295
                      ipolicy_vcpu_ratio=None, ipolicy_spindle_ratio=None,
1296
                      group_ipolicy=True)
1297
    self.assertRaises(errors.OpPrereqError, cli.CreateIPolicyFromOpts,
1298
                      ispecs_mem_size={"wrong": "x"}, ispecs_cpu_count={},
1299
                      ispecs_disk_count={}, ispecs_disk_size={},
1300
                      ispecs_nic_count={}, ipolicy_disk_templates=None,
1301
                      ipolicy_vcpu_ratio=None, ipolicy_spindle_ratio=None)
1302
    self.assertRaises(errors.TypeEnforcementError, cli.CreateIPolicyFromOpts,
1303
                      ispecs_mem_size={}, ispecs_cpu_count={"min": "default"},
1304
                      ispecs_disk_count={}, ispecs_disk_size={},
1305
                      ispecs_nic_count={}, ipolicy_disk_templates=None,
1306
                      ipolicy_vcpu_ratio=None, ipolicy_spindle_ratio=None)
1307

    
1308
  def testAllowedValues(self):
1309
    allowedv = "blah"
1310
    exp_pol1 = {
1311
      constants.ISPECS_MINMAX: {
1312
        constants.ISPECS_MIN: {
1313
          constants.ISPEC_CPU_COUNT: allowedv,
1314
          },
1315
        constants.ISPECS_MAX: {
1316
          },
1317
        },
1318
      constants.ISPECS_STD: {
1319
        },
1320
      constants.IPOLICY_DTS: allowedv,
1321
      constants.IPOLICY_VCPU_RATIO: allowedv,
1322
      constants.IPOLICY_SPINDLE_RATIO: allowedv,
1323
      }
1324
    pol1 = cli.CreateIPolicyFromOpts(ispecs_mem_size={},
1325
                                     ispecs_cpu_count={"min": allowedv},
1326
                                     ispecs_disk_count={},
1327
                                     ispecs_disk_size={},
1328
                                     ispecs_nic_count={},
1329
                                     ipolicy_disk_templates=allowedv,
1330
                                     ipolicy_vcpu_ratio=allowedv,
1331
                                     ipolicy_spindle_ratio=allowedv,
1332
                                     allowed_values=[allowedv])
1333
    self.assertEqual(pol1, exp_pol1)
1334

    
1335

    
1336
if __name__ == "__main__":
1337
  testutils.GanetiTestProgram()