Modify Disk.GetNodes() to support LD_FILE
[ganeti-local] / qa / qa_utils.py
index 8433060..2005634 100644 (file)
@@ -1,3 +1,6 @@
+#
+#
+
 # Copyright (C) 2007 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
 # Copyright (C) 2007 Google Inc.
 #
 # This program is free software; you can redistribute it and/or modify
@@ -36,12 +39,20 @@ _ERROR_SEQ = None
 _RESET_SEQ = None
 
 
 _RESET_SEQ = None
 
 
+# List of all hooks
+_hooks = []
+
+
 def _SetupColours():
   """Initializes the colour constants.
 
   """
   global _INFO_SEQ, _WARNING_SEQ, _ERROR_SEQ, _RESET_SEQ
 
 def _SetupColours():
   """Initializes the colour constants.
 
   """
   global _INFO_SEQ, _WARNING_SEQ, _ERROR_SEQ, _RESET_SEQ
 
+  # Don't use colours if stdout isn't a terminal
+  if not sys.stdout.isatty():
+    return
+
   try:
     import curses
   except ImportError:
   try:
     import curses
   except ImportError:
@@ -61,17 +72,30 @@ def _SetupColours():
 _SetupColours()
 
 
 _SetupColours()
 
 
-def AssertEqual(first, second, msg=None):
+def AssertEqual(first, second):
   """Raises an error when values aren't equal.
 
   """
   if not first == second:
   """Raises an error when values aren't equal.
 
   """
   if not first == second:
-    raise qa_error.Error(msg or '%r == %r' % (first, second))
+    raise qa_error.Error('%r == %r' % (first, second))
+
+
+def AssertNotEqual(first, second):
+  """Raises an error when values are equal.
+
+  """
+  if not first != second:
+    raise qa_error.Error('%r != %r' % (first, second))
 
 
 def GetSSHCommand(node, cmd, strict=True):
   """Builds SSH command to be executed.
 
 
 
 def GetSSHCommand(node, cmd, strict=True):
   """Builds SSH command to be executed.
 
+  Args:
+  - node: Node the command should run on
+  - cmd: Command to be executed as a list with all parameters
+  - strict: Whether to enable strict host key checking
+
   """
   args = [ 'ssh', '-oEscapeChar=none', '-oBatchMode=yes', '-l', 'root' ]
 
   """
   args = [ 'ssh', '-oEscapeChar=none', '-oBatchMode=yes', '-l', 'root' ]
 
@@ -158,7 +182,7 @@ def ResolveInstanceName(instance):
   """Gets the full name of an instance.
 
   """
   """Gets the full name of an instance.
 
   """
-  return _ResolveName(['gnt-instance', 'info', instance['info']],
+  return _ResolveName(['gnt-instance', 'info', instance['name']],
                       'Instance name')
 
 
                       'Instance name')
 
 
@@ -175,7 +199,6 @@ def GetNodeInstances(node, secondaries=False):
 
   """
   master = qa_config.GetMasterNode()
 
   """
   master = qa_config.GetMasterNode()
-
   node_name = ResolveNodeName(node)
 
   # Get list of all instances
   node_name = ResolveNodeName(node)
 
   # Get list of all instances
@@ -193,29 +216,96 @@ def GetNodeInstances(node, secondaries=False):
   return instances
 
 
   return instances
 
 
-def _PrintWithColor(text, seq):
-  f = sys.stdout
+def _FormatWithColor(text, seq):
+  if not seq:
+    return text
+  return "%s%s%s" % (seq, text, _RESET_SEQ)
 
 
-  if not f.isatty():
-    seq = None
 
 
-  if seq:
-    f.write(seq)
+FormatWarning = lambda text: _FormatWithColor(text, _WARNING_SEQ)
+FormatError = lambda text: _FormatWithColor(text, _ERROR_SEQ)
+FormatInfo = lambda text: _FormatWithColor(text, _INFO_SEQ)
 
 
-  f.write(text)
-  f.write("\n")
 
 
-  if seq:
-    f.write(_RESET_SEQ)
+def LoadHooks():
+  """Load all QA hooks.
+
+  """
+  hooks_dir = qa_config.get('options', {}).get('hooks-dir', None)
+  if not hooks_dir:
+    return
+  if hooks_dir not in sys.path:
+    sys.path.insert(0, hooks_dir)
+  for name in utils.ListVisibleFiles(hooks_dir):
+    if name.endswith('.py'):
+      # Load and instanciate hook
+      print "Loading hook %s" % name
+      _hooks.append(__import__(name[:-3], None, None, ['']).hook())
 
 
 
 
-def PrintWarning(text):
-  return _PrintWithColor(text, _WARNING_SEQ)
+class QaHookContext:
+  """Definition of context passed to hooks.
 
 
+  """
+  name = None
+  phase = None
+  success = None
+  args = None
+  kwargs = None
 
 
-def PrintError(f, text):
-  return _PrintWithColor(text, _ERROR_SEQ)
 
 
+def _CallHooks(ctx):
+  """Calls all hooks with the given context.
 
 
-def PrintInfo(f, text):
-  return _PrintWithColor(text, _INFO_SEQ)
+  """
+  if not _hooks:
+    return
+
+  name = "%s-%s" % (ctx.phase, ctx.name)
+  if ctx.success is not None:
+    msg = "%s (success=%s)" % (name, ctx.success)
+  else:
+    msg = name
+  print FormatInfo("Begin %s" % msg)
+  for hook in _hooks:
+    hook.run(ctx)
+  print FormatInfo("End %s" % name)
+
+
+def DefineHook(name):
+  """Wraps a function with calls to hooks.
+
+  Usage: prefix function with @qa_utils.DefineHook(...)
+
+  This is based on PEP 318, "Decorators for Functions and Methods".
+
+  """
+  def wrapper(fn):
+    def new_f(*args, **kwargs):
+      # Create context
+      ctx = QaHookContext()
+      ctx.name = name
+      ctx.phase = 'pre'
+      ctx.args = args
+      ctx.kwargs = kwargs
+
+      _CallHooks(ctx)
+      try:
+        ctx.phase = 'post'
+        ctx.success = True
+        try:
+          # Call real function
+          return fn(*args, **kwargs)
+        except:
+          ctx.success = False
+          raise
+      finally:
+        _CallHooks(ctx)
+
+    # Override function metadata
+    new_f.func_name = fn.func_name
+    new_f.func_doc = fn.func_doc
+
+    return new_f
+
+  return wrapper