Statistics
| Branch: | Tag: | Revision:

root / qa / qa_utils.py @ 93029a5b

History | View | Annotate | Download (25.6 kB)

1
#
2
#
3

    
4
# Copyright (C) 2007, 2011, 2012, 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Utilities for QA tests.
23

24
"""
25

    
26
import copy
27
import operator
28
import os
29
import random
30
import re
31
import socket
32
import subprocess
33
import sys
34
import tempfile
35
import yaml
36

    
37
try:
38
  import functools
39
except ImportError, err:
40
  raise ImportError("Python 2.5 or higher is required: %s" % err)
41

    
42
from ganeti import utils
43
from ganeti import compat
44
from ganeti import constants
45
from ganeti import ht
46
from ganeti import pathutils
47
from ganeti import vcluster
48

    
49
import colors
50
import qa_config
51
import qa_error
52

    
53

    
54
_INFO_SEQ = None
55
_WARNING_SEQ = None
56
_ERROR_SEQ = None
57
_RESET_SEQ = None
58

    
59
_MULTIPLEXERS = {}
60

    
61
#: Unique ID per QA run
62
_RUN_UUID = utils.NewUUID()
63

    
64
#: Path to the QA query output log file
65
_QA_OUTPUT = pathutils.GetLogFilename("qa-output")
66

    
67

    
68
(INST_DOWN,
69
 INST_UP) = range(500, 502)
70

    
71
(FIRST_ARG,
72
 RETURN_VALUE) = range(1000, 1002)
73

    
74

    
75
def _SetupColours():
76
  """Initializes the colour constants.
77

78
  """
79
  # pylint: disable=W0603
80
  # due to global usage
81
  global _INFO_SEQ, _WARNING_SEQ, _ERROR_SEQ, _RESET_SEQ
82

    
83
  # Don't use colours if stdout isn't a terminal
84
  if not sys.stdout.isatty():
85
    return
86

    
87
  try:
88
    import curses
89
  except ImportError:
90
    # Don't use colours if curses module can't be imported
91
    return
92

    
93
  try:
94
    curses.setupterm()
95
  except curses.error:
96
    # Probably a non-standard terminal, don't use colours then
97
    return
98

    
99
  _RESET_SEQ = curses.tigetstr("op")
100

    
101
  setaf = curses.tigetstr("setaf")
102
  _INFO_SEQ = curses.tparm(setaf, curses.COLOR_GREEN)
103
  _WARNING_SEQ = curses.tparm(setaf, curses.COLOR_YELLOW)
104
  _ERROR_SEQ = curses.tparm(setaf, curses.COLOR_RED)
105

    
106

    
107
_SetupColours()
108

    
109

    
110
def AssertIn(item, sequence, msg=""):
111
  """Raises an error when item is not in sequence.
112

113
  """
114
  if item not in sequence:
115
    if msg:
116
      raise qa_error.Error("%s: %r not in %r" % (msg, item, sequence))
117
    else:
118
      raise qa_error.Error("%r not in %r" % (item, sequence))
119

    
120

    
121
def AssertNotIn(item, sequence):
122
  """Raises an error when item is in sequence.
123

124
  """
125
  if item in sequence:
126
    raise qa_error.Error("%r in %r" % (item, sequence))
127

    
128

    
129
def AssertEqual(first, second, msg=""):
130
  """Raises an error when values aren't equal.
131

132
  """
133
  if not first == second:
134
    if msg:
135
      raise qa_error.Error("%s: %r == %r" % (msg, first, second))
136
    else:
137
      raise qa_error.Error("%r == %r" % (first, second))
138

    
139

    
140
def AssertMatch(string, pattern):
141
  """Raises an error when string doesn't match regexp pattern.
142

143
  """
144
  if not re.match(pattern, string):
145
    raise qa_error.Error("%r doesn't match /%r/" % (string, pattern))
146

    
147

    
148
def _GetName(entity, fn):
149
  """Tries to get name of an entity.
150

151
  @type entity: string or dict
152
  @param fn: Function retrieving name from entity
153

154
  """
155
  if isinstance(entity, basestring):
156
    result = entity
157
  else:
158
    result = fn(entity)
159

    
160
  if not ht.TNonEmptyString(result):
161
    raise Exception("Invalid name '%s'" % result)
162

    
163
  return result
164

    
165

    
166
def _AssertRetCode(rcode, fail, cmdstr, nodename):
167
  """Check the return value from a command and possibly raise an exception.
168

169
  """
170
  if fail and rcode == 0:
171
    raise qa_error.Error("Command '%s' on node %s was expected to fail but"
172
                         " didn't" % (cmdstr, nodename))
173
  elif not fail and rcode != 0:
174
    raise qa_error.Error("Command '%s' on node %s failed, exit code %s" %
175
                         (cmdstr, nodename, rcode))
176

    
177

    
178
def AssertCommand(cmd, fail=False, node=None, log_cmd=True):
179
  """Checks that a remote command succeeds.
180

181
  @param cmd: either a string (the command to execute) or a list (to
182
      be converted using L{utils.ShellQuoteArgs} into a string)
183
  @type fail: boolean
184
  @param fail: if the command is expected to fail instead of succeeding
185
  @param node: if passed, it should be the node on which the command
186
      should be executed, instead of the master node (can be either a
187
      dict or a string)
188
  @param log_cmd: if False, the command won't be logged (simply passed to
189
      StartSSH)
190
  @return: the return code of the command
191
  @raise qa_error.Error: if the command fails when it shouldn't or vice versa
192

193
  """
194
  if node is None:
195
    node = qa_config.GetMasterNode()
196

    
197
  nodename = _GetName(node, operator.attrgetter("primary"))
198

    
199
  if isinstance(cmd, basestring):
200
    cmdstr = cmd
201
  else:
202
    cmdstr = utils.ShellQuoteArgs(cmd)
203

    
204
  rcode = StartSSH(nodename, cmdstr, log_cmd=log_cmd).wait()
205
  _AssertRetCode(rcode, fail, cmdstr, nodename)
206

    
207
  return rcode
208

    
209

    
210
def AssertRedirectedCommand(cmd, fail=False, node=None, log_cmd=True):
211
  """Executes a command with redirected output.
212

213
  The log will go to the qa-output log file in the ganeti log
214
  directory on the node where the command is executed. The fail and
215
  node parameters are passed unchanged to AssertCommand.
216

217
  @param cmd: the command to be executed, as a list; a string is not
218
      supported
219

220
  """
221
  if not isinstance(cmd, list):
222
    raise qa_error.Error("Non-list passed to AssertRedirectedCommand")
223
  ofile = utils.ShellQuote(_QA_OUTPUT)
224
  cmdstr = utils.ShellQuoteArgs(cmd)
225
  AssertCommand("echo ---- $(date) %s ---- >> %s" % (cmdstr, ofile),
226
                fail=False, node=node, log_cmd=False)
227
  return AssertCommand(cmdstr + " >> %s" % ofile,
228
                       fail=fail, node=node, log_cmd=log_cmd)
229

    
230

    
231
def GetSSHCommand(node, cmd, strict=True, opts=None, tty=None):
232
  """Builds SSH command to be executed.
233

234
  @type node: string
235
  @param node: node the command should run on
236
  @type cmd: string
237
  @param cmd: command to be executed in the node; if None or empty
238
      string, no command will be executed
239
  @type strict: boolean
240
  @param strict: whether to enable strict host key checking
241
  @type opts: list
242
  @param opts: list of additional options
243
  @type tty: boolean or None
244
  @param tty: if we should use tty; if None, will be auto-detected
245

246
  """
247
  args = ["ssh", "-oEscapeChar=none", "-oBatchMode=yes", "-lroot"]
248

    
249
  if tty is None:
250
    tty = sys.stdout.isatty()
251

    
252
  if tty:
253
    args.append("-t")
254

    
255
  if strict:
256
    tmp = "yes"
257
  else:
258
    tmp = "no"
259
  args.append("-oStrictHostKeyChecking=%s" % tmp)
260
  args.append("-oClearAllForwardings=yes")
261
  args.append("-oForwardAgent=yes")
262
  if opts:
263
    args.extend(opts)
264
  if node in _MULTIPLEXERS:
265
    spath = _MULTIPLEXERS[node][0]
266
    args.append("-oControlPath=%s" % spath)
267
    args.append("-oControlMaster=no")
268

    
269
  (vcluster_master, vcluster_basedir) = \
270
    qa_config.GetVclusterSettings()
271

    
272
  if vcluster_master:
273
    args.append(vcluster_master)
274
    args.append("%s/%s/cmd" % (vcluster_basedir, node))
275

    
276
    if cmd:
277
      # For virtual clusters the whole command must be wrapped using the "cmd"
278
      # script, as that script sets a number of environment variables. If the
279
      # command contains shell meta characters the whole command needs to be
280
      # quoted.
281
      args.append(utils.ShellQuote(cmd))
282
  else:
283
    args.append(node)
284

    
285
    if cmd:
286
      args.append(cmd)
287

    
288
  return args
289

    
290

    
291
def StartLocalCommand(cmd, _nolog_opts=False, log_cmd=True, **kwargs):
292
  """Starts a local command.
293

294
  """
295
  if log_cmd:
296
    if _nolog_opts:
297
      pcmd = [i for i in cmd if not i.startswith("-")]
298
    else:
299
      pcmd = cmd
300
    print "%s %s" % (colors.colorize("Command:", colors.CYAN),
301
                     utils.ShellQuoteArgs(pcmd))
302
  return subprocess.Popen(cmd, shell=False, **kwargs)
303

    
304

    
305
def StartSSH(node, cmd, strict=True, log_cmd=True):
306
  """Starts SSH.
307

308
  """
309
  return StartLocalCommand(GetSSHCommand(node, cmd, strict=strict),
310
                           _nolog_opts=True, log_cmd=log_cmd)
311

    
312

    
313
def StartMultiplexer(node):
314
  """Starts a multiplexer command.
315

316
  @param node: the node for which to open the multiplexer
317

318
  """
319
  if node in _MULTIPLEXERS:
320
    return
321

    
322
  # Note: yes, we only need mktemp, since we'll remove the file anyway
323
  sname = tempfile.mktemp(prefix="ganeti-qa-multiplexer.")
324
  utils.RemoveFile(sname)
325
  opts = ["-N", "-oControlPath=%s" % sname, "-oControlMaster=yes"]
326
  print "Created socket at %s" % sname
327
  child = StartLocalCommand(GetSSHCommand(node, None, opts=opts))
328
  _MULTIPLEXERS[node] = (sname, child)
329

    
330

    
331
def CloseMultiplexers():
332
  """Closes all current multiplexers and cleans up.
333

334
  """
335
  for node in _MULTIPLEXERS.keys():
336
    (sname, child) = _MULTIPLEXERS.pop(node)
337
    utils.KillProcess(child.pid, timeout=10, waitpid=True)
338
    utils.RemoveFile(sname)
339

    
340

    
341
def GetCommandOutput(node, cmd, tty=None, fail=False):
342
  """Returns the output of a command executed on the given node.
343

344
  @type node: string
345
  @param node: node the command should run on
346
  @type cmd: string
347
  @param cmd: command to be executed in the node (cannot be empty or None)
348
  @type tty: bool or None
349
  @param tty: if we should use tty; if None, it will be auto-detected
350
  @type fail: bool
351
  @param fail: whether the command is expected to fail
352
  """
353
  assert cmd
354
  p = StartLocalCommand(GetSSHCommand(node, cmd, tty=tty),
355
                        stdout=subprocess.PIPE)
356
  rcode = p.wait()
357
  _AssertRetCode(rcode, fail, cmd, node)
358
  return p.stdout.read()
359

    
360

    
361
def GetObjectInfo(infocmd):
362
  """Get and parse information about a Ganeti object.
363

364
  @type infocmd: list of strings
365
  @param infocmd: command to be executed, e.g. ["gnt-cluster", "info"]
366
  @return: the information parsed, appropriately stored in dictionaries,
367
      lists...
368

369
  """
370
  master = qa_config.GetMasterNode()
371
  cmdline = utils.ShellQuoteArgs(infocmd)
372
  info_out = GetCommandOutput(master.primary, cmdline)
373
  return yaml.load(info_out)
374

    
375

    
376
def UploadFile(node, src):
377
  """Uploads a file to a node and returns the filename.
378

379
  Caller needs to remove the returned file on the node when it's not needed
380
  anymore.
381

382
  """
383
  # Make sure nobody else has access to it while preserving local permissions
384
  mode = os.stat(src).st_mode & 0700
385

    
386
  cmd = ('tmp=$(mktemp --tmpdir gnt.XXXXXX) && '
387
         'chmod %o "${tmp}" && '
388
         '[[ -f "${tmp}" ]] && '
389
         'cat > "${tmp}" && '
390
         'echo "${tmp}"') % mode
391

    
392
  f = open(src, "r")
393
  try:
394
    p = subprocess.Popen(GetSSHCommand(node, cmd), shell=False, stdin=f,
395
                         stdout=subprocess.PIPE)
396
    AssertEqual(p.wait(), 0)
397

    
398
    # Return temporary filename
399
    return p.stdout.read().strip()
400
  finally:
401
    f.close()
402

    
403

    
404
def UploadData(node, data, mode=0600, filename=None):
405
  """Uploads data to a node and returns the filename.
406

407
  Caller needs to remove the returned file on the node when it's not needed
408
  anymore.
409

410
  """
411
  if filename:
412
    tmp = "tmp=%s" % utils.ShellQuote(filename)
413
  else:
414
    tmp = ('tmp=$(mktemp --tmpdir gnt.XXXXXX) && '
415
           'chmod %o "${tmp}"') % mode
416
  cmd = ("%s && "
417
         "[[ -f \"${tmp}\" ]] && "
418
         "cat > \"${tmp}\" && "
419
         "echo \"${tmp}\"") % tmp
420

    
421
  p = subprocess.Popen(GetSSHCommand(node, cmd), shell=False,
422
                       stdin=subprocess.PIPE, stdout=subprocess.PIPE)
423
  p.stdin.write(data)
424
  p.stdin.close()
425
  AssertEqual(p.wait(), 0)
426

    
427
  # Return temporary filename
428
  return p.stdout.read().strip()
429

    
430

    
431
def BackupFile(node, path):
432
  """Creates a backup of a file on the node and returns the filename.
433

434
  Caller needs to remove the returned file on the node when it's not needed
435
  anymore.
436

437
  """
438
  vpath = MakeNodePath(node, path)
439

    
440
  cmd = ("tmp=$(mktemp .gnt.XXXXXX --tmpdir=$(dirname %s)) && "
441
         "[[ -f \"$tmp\" ]] && "
442
         "cp %s $tmp && "
443
         "echo $tmp") % (utils.ShellQuote(vpath), utils.ShellQuote(vpath))
444

    
445
  # Return temporary filename
446
  result = GetCommandOutput(node, cmd).strip()
447

    
448
  print "Backup filename: %s" % result
449

    
450
  return result
451

    
452

    
453
def ResolveInstanceName(instance):
454
  """Gets the full name of an instance.
455

456
  @type instance: string
457
  @param instance: Instance name
458

459
  """
460
  info = GetObjectInfo(["gnt-instance", "info", instance])
461
  return info[0]["Instance name"]
462

    
463

    
464
def ResolveNodeName(node):
465
  """Gets the full name of a node.
466

467
  """
468
  info = GetObjectInfo(["gnt-node", "info", node.primary])
469
  return info[0]["Node name"]
470

    
471

    
472
def GetNodeInstances(node, secondaries=False):
473
  """Gets a list of instances on a node.
474

475
  """
476
  master = qa_config.GetMasterNode()
477
  node_name = ResolveNodeName(node)
478

    
479
  # Get list of all instances
480
  cmd = ["gnt-instance", "list", "--separator=:", "--no-headers",
481
         "--output=name,pnode,snodes"]
482
  output = GetCommandOutput(master.primary, utils.ShellQuoteArgs(cmd))
483

    
484
  instances = []
485
  for line in output.splitlines():
486
    (name, pnode, snodes) = line.split(":", 2)
487
    if ((not secondaries and pnode == node_name) or
488
        (secondaries and node_name in snodes.split(","))):
489
      instances.append(name)
490

    
491
  return instances
492

    
493

    
494
def _SelectQueryFields(rnd, fields):
495
  """Generates a list of fields for query tests.
496

497
  """
498
  # Create copy for shuffling
499
  fields = list(fields)
500
  rnd.shuffle(fields)
501

    
502
  # Check all fields
503
  yield fields
504
  yield sorted(fields)
505

    
506
  # Duplicate fields
507
  yield fields + fields
508

    
509
  # Check small groups of fields
510
  while fields:
511
    yield [fields.pop() for _ in range(rnd.randint(2, 10)) if fields]
512

    
513

    
514
def _List(listcmd, fields, names):
515
  """Runs a list command.
516

517
  """
518
  master = qa_config.GetMasterNode()
519

    
520
  cmd = [listcmd, "list", "--separator=|", "--no-headers",
521
         "--output", ",".join(fields)]
522

    
523
  if names:
524
    cmd.extend(names)
525

    
526
  return GetCommandOutput(master.primary,
527
                          utils.ShellQuoteArgs(cmd)).splitlines()
528

    
529

    
530
def GenericQueryTest(cmd, fields, namefield="name", test_unknown=True):
531
  """Runs a number of tests on query commands.
532

533
  @param cmd: Command name
534
  @param fields: List of field names
535

536
  """
537
  rnd = random.Random(hash(cmd))
538

    
539
  fields = list(fields)
540
  rnd.shuffle(fields)
541

    
542
  # Test a number of field combinations
543
  for testfields in _SelectQueryFields(rnd, fields):
544
    AssertRedirectedCommand([cmd, "list", "--output", ",".join(testfields)])
545

    
546
  if namefield is not None:
547
    namelist_fn = compat.partial(_List, cmd, [namefield])
548

    
549
    # When no names were requested, the list must be sorted
550
    names = namelist_fn(None)
551
    AssertEqual(names, utils.NiceSort(names))
552

    
553
    # When requesting specific names, the order must be kept
554
    revnames = list(reversed(names))
555
    AssertEqual(namelist_fn(revnames), revnames)
556

    
557
    randnames = list(names)
558
    rnd.shuffle(randnames)
559
    AssertEqual(namelist_fn(randnames), randnames)
560

    
561
  if test_unknown:
562
    # Listing unknown items must fail
563
    AssertCommand([cmd, "list", "this.name.certainly.does.not.exist"],
564
                  fail=True)
565

    
566
  # Check exit code for listing unknown field
567
  AssertEqual(AssertRedirectedCommand([cmd, "list",
568
                                       "--output=field/does/not/exist"],
569
                                      fail=True),
570
              constants.EXIT_UNKNOWN_FIELD)
571

    
572

    
573
def GenericQueryFieldsTest(cmd, fields):
574
  master = qa_config.GetMasterNode()
575

    
576
  # Listing fields
577
  AssertRedirectedCommand([cmd, "list-fields"])
578
  AssertRedirectedCommand([cmd, "list-fields"] + fields)
579

    
580
  # Check listed fields (all, must be sorted)
581
  realcmd = [cmd, "list-fields", "--separator=|", "--no-headers"]
582
  output = GetCommandOutput(master.primary,
583
                            utils.ShellQuoteArgs(realcmd)).splitlines()
584
  AssertEqual([line.split("|", 1)[0] for line in output],
585
              utils.NiceSort(fields))
586

    
587
  # Check exit code for listing unknown field
588
  AssertEqual(AssertCommand([cmd, "list-fields", "field/does/not/exist"],
589
                            fail=True),
590
              constants.EXIT_UNKNOWN_FIELD)
591

    
592

    
593
def _FormatWithColor(text, seq):
594
  if not seq:
595
    return text
596
  return "%s%s%s" % (seq, text, _RESET_SEQ)
597

    
598

    
599
FormatWarning = lambda text: _FormatWithColor(text, _WARNING_SEQ)
600
FormatError = lambda text: _FormatWithColor(text, _ERROR_SEQ)
601
FormatInfo = lambda text: _FormatWithColor(text, _INFO_SEQ)
602

    
603

    
604
def AddToEtcHosts(hostnames):
605
  """Adds hostnames to /etc/hosts.
606

607
  @param hostnames: List of hostnames first used A records, all other CNAMEs
608

609
  """
610
  master = qa_config.GetMasterNode()
611
  tmp_hosts = UploadData(master.primary, "", mode=0644)
612

    
613
  data = []
614
  for localhost in ("::1", "127.0.0.1"):
615
    data.append("%s %s" % (localhost, " ".join(hostnames)))
616

    
617
  try:
618
    AssertCommand("{ cat %s && echo -e '%s'; } > %s && mv %s %s" %
619
                  (utils.ShellQuote(pathutils.ETC_HOSTS),
620
                   "\\n".join(data),
621
                   utils.ShellQuote(tmp_hosts),
622
                   utils.ShellQuote(tmp_hosts),
623
                   utils.ShellQuote(pathutils.ETC_HOSTS)))
624
  except Exception:
625
    AssertCommand(["rm", "-f", tmp_hosts])
626
    raise
627

    
628

    
629
def RemoveFromEtcHosts(hostnames):
630
  """Remove hostnames from /etc/hosts.
631

632
  @param hostnames: List of hostnames first used A records, all other CNAMEs
633

634
  """
635
  master = qa_config.GetMasterNode()
636
  tmp_hosts = UploadData(master.primary, "", mode=0644)
637
  quoted_tmp_hosts = utils.ShellQuote(tmp_hosts)
638

    
639
  sed_data = " ".join(hostnames)
640
  try:
641
    AssertCommand((r"sed -e '/^\(::1\|127\.0\.0\.1\)\s\+%s/d' %s > %s"
642
                   r" && mv %s %s") %
643
                   (sed_data, utils.ShellQuote(pathutils.ETC_HOSTS),
644
                    quoted_tmp_hosts, quoted_tmp_hosts,
645
                    utils.ShellQuote(pathutils.ETC_HOSTS)))
646
  except Exception:
647
    AssertCommand(["rm", "-f", tmp_hosts])
648
    raise
649

    
650

    
651
def RunInstanceCheck(instance, running):
652
  """Check if instance is running or not.
653

654
  """
655
  instance_name = _GetName(instance, operator.attrgetter("name"))
656

    
657
  script = qa_config.GetInstanceCheckScript()
658
  if not script:
659
    return
660

    
661
  master_node = qa_config.GetMasterNode()
662

    
663
  # Build command to connect to master node
664
  master_ssh = GetSSHCommand(master_node.primary, "--")
665

    
666
  if running:
667
    running_shellval = "1"
668
    running_text = ""
669
  else:
670
    running_shellval = ""
671
    running_text = "not "
672

    
673
  print FormatInfo("Checking if instance '%s' is %srunning" %
674
                   (instance_name, running_text))
675

    
676
  args = [script, instance_name]
677
  env = {
678
    "PATH": constants.HOOKS_PATH,
679
    "RUN_UUID": _RUN_UUID,
680
    "MASTER_SSH": utils.ShellQuoteArgs(master_ssh),
681
    "INSTANCE_NAME": instance_name,
682
    "INSTANCE_RUNNING": running_shellval,
683
    }
684

    
685
  result = os.spawnve(os.P_WAIT, script, args, env)
686
  if result != 0:
687
    raise qa_error.Error("Instance check failed with result %s" % result)
688

    
689

    
690
def _InstanceCheckInner(expected, instarg, args, result):
691
  """Helper function used by L{InstanceCheck}.
692

693
  """
694
  if instarg == FIRST_ARG:
695
    instance = args[0]
696
  elif instarg == RETURN_VALUE:
697
    instance = result
698
  else:
699
    raise Exception("Invalid value '%s' for instance argument" % instarg)
700

    
701
  if expected in (INST_DOWN, INST_UP):
702
    RunInstanceCheck(instance, (expected == INST_UP))
703
  elif expected is not None:
704
    raise Exception("Invalid value '%s'" % expected)
705

    
706

    
707
def InstanceCheck(before, after, instarg):
708
  """Decorator to check instance status before and after test.
709

710
  @param before: L{INST_DOWN} if instance must be stopped before test,
711
    L{INST_UP} if instance must be running before test, L{None} to not check.
712
  @param after: L{INST_DOWN} if instance must be stopped after test,
713
    L{INST_UP} if instance must be running after test, L{None} to not check.
714
  @param instarg: L{FIRST_ARG} to use first argument to test as instance (a
715
    dictionary), L{RETURN_VALUE} to use return value (disallows pre-checks)
716

717
  """
718
  def decorator(fn):
719
    @functools.wraps(fn)
720
    def wrapper(*args, **kwargs):
721
      _InstanceCheckInner(before, instarg, args, NotImplemented)
722

    
723
      result = fn(*args, **kwargs)
724

    
725
      _InstanceCheckInner(after, instarg, args, result)
726

    
727
      return result
728
    return wrapper
729
  return decorator
730

    
731

    
732
def GetNonexistentGroups(count):
733
  """Gets group names which shouldn't exist on the cluster.
734

735
  @param count: Number of groups to get
736
  @rtype: integer
737

738
  """
739
  return GetNonexistentEntityNames(count, "groups", "group")
740

    
741

    
742
def GetNonexistentEntityNames(count, name_config, name_prefix):
743
  """Gets entity names which shouldn't exist on the cluster.
744

745
  The actualy names can refer to arbitrary entities (for example
746
  groups, networks).
747

748
  @param count: Number of names to get
749
  @rtype: integer
750
  @param name_config: name of the leaf in the config containing
751
    this entity's configuration, including a 'inexistent-'
752
    element
753
  @rtype: string
754
  @param name_prefix: prefix of the entity's names, used to compose
755
    the default values; for example for groups, the prefix is
756
    'group' and the generated names are then group1, group2, ...
757
  @rtype: string
758

759
  """
760
  entities = qa_config.get(name_config, {})
761

    
762
  default = [name_prefix + str(i) for i in range(count)]
763
  assert count <= len(default)
764

    
765
  name_config_inexistent = "inexistent-" + name_config
766
  candidates = entities.get(name_config_inexistent, default)[:count]
767

    
768
  if len(candidates) < count:
769
    raise Exception("At least %s non-existent %s are needed" %
770
                    (count, name_config))
771

    
772
  return candidates
773

    
774

    
775
def MakeNodePath(node, path):
776
  """Builds an absolute path for a virtual node.
777

778
  @type node: string or L{qa_config._QaNode}
779
  @param node: Node
780
  @type path: string
781
  @param path: Path without node-specific prefix
782

783
  """
784
  (_, basedir) = qa_config.GetVclusterSettings()
785

    
786
  if isinstance(node, basestring):
787
    name = node
788
  else:
789
    name = node.primary
790

    
791
  if basedir:
792
    assert path.startswith("/")
793
    return "%s%s" % (vcluster.MakeNodeRoot(basedir, name), path)
794
  else:
795
    return path
796

    
797

    
798
def _GetParameterOptions(specs):
799
  """Helper to build policy options."""
800
  values = ["%s=%s" % (par, val)
801
            for (par, val) in specs.items()]
802
  return ",".join(values)
803

    
804

    
805
def TestSetISpecs(new_specs=None, diff_specs=None, get_policy_fn=None,
806
                  build_cmd_fn=None, fail=False, old_values=None):
807
  """Change instance specs for an object.
808

809
  At most one of new_specs or diff_specs can be specified.
810

811
  @type new_specs: dict
812
  @param new_specs: new complete specs, in the same format returned by
813
      L{ParseIPolicy}.
814
  @type diff_specs: dict
815
  @param diff_specs: partial specs, it can be an incomplete specifications, but
816
      if min/max specs are specified, their number must match the number of the
817
      existing specs
818
  @type get_policy_fn: function
819
  @param get_policy_fn: function that returns the current policy as in
820
      L{ParseIPolicy}
821
  @type build_cmd_fn: function
822
  @param build_cmd_fn: function that return the full command line from the
823
      options alone
824
  @type fail: bool
825
  @param fail: if the change is expected to fail
826
  @type old_values: tuple
827
  @param old_values: (old_policy, old_specs), as returned by
828
     L{ParseIPolicy}
829
  @return: same as L{ParseIPolicy}
830

831
  """
832
  assert get_policy_fn is not None
833
  assert build_cmd_fn is not None
834
  assert new_specs is None or diff_specs is None
835

    
836
  if old_values:
837
    (old_policy, old_specs) = old_values
838
  else:
839
    (old_policy, old_specs) = get_policy_fn()
840

    
841
  if diff_specs:
842
    new_specs = copy.deepcopy(old_specs)
843
    if constants.ISPECS_MINMAX in diff_specs:
844
      AssertEqual(len(new_specs[constants.ISPECS_MINMAX]),
845
                  len(diff_specs[constants.ISPECS_MINMAX]))
846
      for (new_minmax, diff_minmax) in zip(new_specs[constants.ISPECS_MINMAX],
847
                                           diff_specs[constants.ISPECS_MINMAX]):
848
        for (key, parvals) in diff_minmax.items():
849
          for (par, val) in parvals.items():
850
            new_minmax[key][par] = val
851
    for (par, val) in diff_specs.get(constants.ISPECS_STD, {}).items():
852
      new_specs[constants.ISPECS_STD][par] = val
853

    
854
  if new_specs:
855
    cmd = []
856
    if (diff_specs is None or constants.ISPECS_MINMAX in diff_specs):
857
      minmax_opt_items = []
858
      for minmax in new_specs[constants.ISPECS_MINMAX]:
859
        minmax_opts = []
860
        for key in ["min", "max"]:
861
          keyopt = _GetParameterOptions(minmax[key])
862
          minmax_opts.append("%s:%s" % (key, keyopt))
863
        minmax_opt_items.append("/".join(minmax_opts))
864
      cmd.extend([
865
        "--ipolicy-bounds-specs",
866
        "//".join(minmax_opt_items)
867
        ])
868
    if diff_specs is None:
869
      std_source = new_specs
870
    else:
871
      std_source = diff_specs
872
    std_opt = _GetParameterOptions(std_source.get("std", {}))
873
    if std_opt:
874
      cmd.extend(["--ipolicy-std-specs", std_opt])
875
    AssertCommand(build_cmd_fn(cmd), fail=fail)
876

    
877
    # Check the new state
878
    (eff_policy, eff_specs) = get_policy_fn()
879
    AssertEqual(eff_policy, old_policy)
880
    if fail:
881
      AssertEqual(eff_specs, old_specs)
882
    else:
883
      AssertEqual(eff_specs, new_specs)
884

    
885
  else:
886
    (eff_policy, eff_specs) = (old_policy, old_specs)
887

    
888
  return (eff_policy, eff_specs)
889

    
890

    
891
def ParseIPolicy(policy):
892
  """Parse and split instance an instance policy.
893

894
  @type policy: dict
895
  @param policy: policy, as returned by L{GetObjectInfo}
896
  @rtype: tuple
897
  @return: (policy, specs), where:
898
      - policy is a dictionary of the policy values, instance specs excluded
899
      - specs is a dictionary containing only the specs, using the internal
900
        format (see L{constants.IPOLICY_DEFAULTS} for an example)
901

902
  """
903
  ret_specs = {}
904
  ret_policy = {}
905
  for (key, val) in policy.items():
906
    if key == "bounds specs":
907
      ret_specs[constants.ISPECS_MINMAX] = []
908
      for minmax in val:
909
        ret_minmax = {}
910
        for key in minmax:
911
          keyparts = key.split("/", 1)
912
          assert len(keyparts) > 1
913
          ret_minmax[keyparts[0]] = minmax[key]
914
        ret_specs[constants.ISPECS_MINMAX].append(ret_minmax)
915
    elif key == constants.ISPECS_STD:
916
      ret_specs[key] = val
917
    else:
918
      ret_policy[key] = val
919
  return (ret_policy, ret_specs)
920

    
921

    
922
def UsesIPv6Connection(host, port):
923
  """Returns True if the connection to a given host/port could go through IPv6.
924

925
  """
926
  return any(t[0] == socket.AF_INET6 for t in socket.getaddrinfo(host, port))