ConfigWriter: handle the drained node flag
[ganeti-local] / lib / cli.py
index 0534bc7..ce1277d 100644 (file)
@@ -50,10 +50,8 @@ __all__ = ["DEBUG_OPT", "NOHDR_OPT", "SEP_OPT", "GenericMain",
            "ListTags", "AddTags", "RemoveTags", "TAG_SRC_OPT",
            "FormatError", "SplitNodeOption", "SubmitOrSend",
            "JobSubmittedException", "FormatTimestamp", "ParseTimespec",
-           "ValidateBeParams",
-           "ToStderr", "ToStdout",
-           "UsesRPC",
-           "GetOnlineNodes",
+           "ValidateBeParams", "ToStderr", "ToStdout", "UsesRPC",
+           "GetOnlineNodes", "JobExecutor", "SYNC_OPT",
            ]
 
 
@@ -192,6 +190,11 @@ SUBMIT_OPT = make_option("--submit", dest="submit_only",
                          help="Submit the job and return the job ID, but"
                          " don't wait for the job to finish")
 
+SYNC_OPT = make_option("--sync", dest="do_locking",
+                       default=False, action="store_true",
+                       help="Grab locks while doing the queries"
+                       " in order to ensure more consistent results")
+
 
 def ARGS_FIXED(val):
   """Macro-like function denoting a fixed number of arguments"""
@@ -556,7 +559,8 @@ def PollJob(job_id, cl=None, feedback_fn=None):
         if callable(feedback_fn):
           feedback_fn(log_entry[1:])
         else:
-          print "%s %s" % (time.ctime(utils.MergeTime(timestamp)), message)
+          encoded = utils.SafeEncode(message)
+          print "%s %s" % (time.ctime(utils.MergeTime(timestamp)), encoded)
         prev_logmsg_serial = max(prev_logmsg_serial, serial)
 
     # TODO: Handle canceled and archived jobs
@@ -940,9 +944,8 @@ def GetOnlineNodes(nodes, cl=None, nowarn=False):
   if cl is None:
     cl = GetClient()
 
-  op = opcodes.OpQueryNodes(output_fields=["name", "offline"],
-                            names=nodes)
-  result = SubmitOpCode(op, cl=cl)
+  result = cl.QueryNodes(names=nodes, fields=["name", "offline"],
+                         use_locking=False)
   offline = [row[0] for row in result if row[1]]
   if offline and not nowarn:
     ToStderr("Note: skipping offline node(s): %s" % ", ".join(offline))
@@ -989,3 +992,67 @@ def ToStderr(txt, *args):
 
   """
   _ToStream(sys.stderr, txt, *args)
+
+
+class JobExecutor(object):
+  """Class which manages the submission and execution of multiple jobs.
+
+  Note that instances of this class should not be reused between
+  GetResults() calls.
+
+  """
+  def __init__(self, cl=None, verbose=True):
+    self.queue = []
+    if cl is None:
+      cl = GetClient()
+    self.cl = cl
+    self.verbose = verbose
+
+  def QueueJob(self, name, *ops):
+    """Submit a job for execution.
+
+    @type name: string
+    @param name: a description of the job, will be used in WaitJobSet
+    """
+    job_id = SendJob(ops, cl=self.cl)
+    self.queue.append((job_id, name))
+
+  def GetResults(self):
+    """Wait for and return the results of all jobs.
+
+    @rtype: list
+    @return: list of tuples (success, job results), in the same order
+        as the submitted jobs; if a job has failed, instead of the result
+        there will be the error message
+
+    """
+    results = []
+    if self.verbose:
+      ToStdout("Submitted jobs %s", ", ".join(row[0] for row in self.queue))
+    for jid, name in self.queue:
+      if self.verbose:
+        ToStdout("Waiting for job %s for %s...", jid, name)
+      try:
+        job_result = PollJob(jid, cl=self.cl)
+        success = True
+      except (errors.GenericError, luxi.ProtocolError), err:
+        _, job_result = FormatError(err)
+        success = False
+        # the error message will always be shown, verbose or not
+        ToStderr("Job %s for %s has failed: %s", jid, name, job_result)
+
+      results.append((success, job_result))
+    return results
+
+  def WaitOrShow(self, wait):
+    """Wait for job results or only print the job IDs.
+
+    @type wait: boolean
+    @param wait: whether to wait or not
+
+    """
+    if wait:
+      return self.GetResults()
+    else:
+      for jid, name in self.queue:
+        ToStdout("%s: %s", jid, name)