Code and docstring style fixes
[ganeti-local] / lib / luxi.py
index 541e059..2a3adaa 100644 (file)
 
 """Module for the unix socket protocol
 
 
 """Module for the unix socket protocol
 
-This module implements the local unix socket protocl. You only need
+This module implements the local unix socket protocol. You only need
 this module and the opcodes module in the client program in order to
 communicate with the master.
 
 this module and the opcodes module in the client program in order to
 communicate with the master.
 
-The module is also be used by the master daemon.
+The module is also used by the master daemon.
 
 """
 
 
 """
 
@@ -36,6 +36,7 @@ import errno
 
 from ganeti import serializer
 from ganeti import constants
 
 from ganeti import serializer
 from ganeti import constants
+from ganeti import errors
 
 
 KEY_METHOD = 'method'
 
 
 KEY_METHOD = 'method'
@@ -44,6 +45,7 @@ KEY_SUCCESS = "success"
 KEY_RESULT = "result"
 
 REQ_SUBMIT_JOB = "SubmitJob"
 KEY_RESULT = "result"
 
 REQ_SUBMIT_JOB = "SubmitJob"
+REQ_SUBMIT_MANY_JOBS = "SubmitManyJobs"
 REQ_WAIT_FOR_JOB_CHANGE = "WaitForJobChange"
 REQ_CANCEL_JOB = "CancelJob"
 REQ_ARCHIVE_JOB = "ArchiveJob"
 REQ_WAIT_FOR_JOB_CHANGE = "WaitForJobChange"
 REQ_CANCEL_JOB = "CancelJob"
 REQ_ARCHIVE_JOB = "ArchiveJob"
@@ -53,6 +55,9 @@ REQ_QUERY_INSTANCES = "QueryInstances"
 REQ_QUERY_NODES = "QueryNodes"
 REQ_QUERY_EXPORTS = "QueryExports"
 REQ_QUERY_CONFIG_VALUES = "QueryConfigValues"
 REQ_QUERY_NODES = "QueryNodes"
 REQ_QUERY_EXPORTS = "QueryExports"
 REQ_QUERY_CONFIG_VALUES = "QueryConfigValues"
+REQ_QUERY_CLUSTER_INFO = "QueryClusterInfo"
+REQ_QUEUE_SET_DRAIN_FLAG = "SetDrainFlag"
+REQ_SET_WATCHER_PAUSE = "SetWatcherPause"
 
 DEF_CTMO = 10
 DEF_RWTO = 60
 
 DEF_CTMO = 10
 DEF_RWTO = 60
@@ -183,12 +188,13 @@ class Transport:
       raise EncodingError("Message terminator found in payload")
     self._CheckSocket()
     try:
       raise EncodingError("Message terminator found in payload")
     self._CheckSocket()
     try:
+      # TODO: sendall is not guaranteed to send everything
       self.socket.sendall(msg + self.eom)
     except socket.timeout, err:
       raise TimeoutError("Sending timeout: %s" % str(err))
 
   def Recv(self):
       self.socket.sendall(msg + self.eom)
     except socket.timeout, err:
       raise TimeoutError("Sending timeout: %s" % str(err))
 
   def Recv(self):
-    """Try to receive a messae from the socket.
+    """Try to receive a message from the socket.
 
     In case we already have messages queued, we just return from the
     queue. Otherwise, we try to read data with a _rwtimeout network
 
     In case we already have messages queued, we just return from the
     queue. Otherwise, we try to read data with a _rwtimeout network
@@ -201,10 +207,16 @@ class Transport:
     while not self._msgs:
       if time.time() > etime:
         raise TimeoutError("Extended receive timeout")
     while not self._msgs:
       if time.time() > etime:
         raise TimeoutError("Extended receive timeout")
-      try:
-        data = self.socket.recv(4096)
-      except socket.timeout, err:
-        raise TimeoutError("Receive timeout: %s" % str(err))
+      while True:
+        try:
+          data = self.socket.recv(4096)
+        except socket.error, err:
+          if err.args and err.args[0] == errno.EAGAIN:
+            continue
+          raise
+        except socket.timeout, err:
+          raise TimeoutError("Receive timeout: %s" % str(err))
+        break
       if not data:
         raise ConnectionClosedError("Connection closed while reading")
       new_msgs = (self._buffer + data).split(self.eom)
       if not data:
         raise ConnectionClosedError("Connection closed while reading")
       new_msgs = (self._buffer + data).split(self.eom)
@@ -250,7 +262,32 @@ class Client(object):
     """
     if address is None:
       address = constants.MASTER_SOCKET
     """
     if address is None:
       address = constants.MASTER_SOCKET
-    self.transport = transport(address, timeouts=timeouts)
+    self.address = address
+    self.timeouts = timeouts
+    self.transport_class = transport
+    self.transport = None
+    self._InitTransport()
+
+  def _InitTransport(self):
+    """(Re)initialize the transport if needed.
+
+    """
+    if self.transport is None:
+      self.transport = self.transport_class(self.address,
+                                            timeouts=self.timeouts)
+
+  def _CloseTransport(self):
+    """Close the transport, ignoring errors.
+
+    """
+    if self.transport is None:
+      return
+    try:
+      old_transp = self.transport
+      self.transport = None
+      old_transp.Close()
+    except Exception:
+      pass
 
   def CallMethod(self, method, args):
     """Send a generic request and return the response.
 
   def CallMethod(self, method, args):
     """Send a generic request and return the response.
@@ -262,8 +299,18 @@ class Client(object):
       KEY_ARGS: args,
       }
 
       KEY_ARGS: args,
       }
 
+    # Serialize the request
+    send_data = serializer.DumpJson(request, indent=False)
+
     # Send request and wait for response
     # Send request and wait for response
-    result = self.transport.Call(serializer.DumpJson(request, indent=False))
+    try:
+      self._InitTransport()
+      result = self.transport.Call(send_data)
+    except Exception:
+      self._CloseTransport()
+      raise
+
+    # Parse the result
     try:
       data = serializer.LoadJson(result)
     except Exception, err:
     try:
       data = serializer.LoadJson(result)
     except Exception, err:
@@ -275,16 +322,30 @@ class Client(object):
         KEY_RESULT not in data):
       raise DecodingError("Invalid response from server: %s" % str(data))
 
         KEY_RESULT not in data):
       raise DecodingError("Invalid response from server: %s" % str(data))
 
+    result = data[KEY_RESULT]
+
     if not data[KEY_SUCCESS]:
     if not data[KEY_SUCCESS]:
-      # TODO: decide on a standard exception
-      raise RequestError(data[KEY_RESULT])
+      errors.MaybeRaise(result)
+      raise RequestError(result)
+
+    return result
 
 
-    return data[KEY_RESULT]
+  def SetQueueDrainFlag(self, drain_flag):
+    return self.CallMethod(REQ_QUEUE_SET_DRAIN_FLAG, drain_flag)
+
+  def SetWatcherPause(self, until):
+    return self.CallMethod(REQ_SET_WATCHER_PAUSE, [until])
 
   def SubmitJob(self, ops):
     ops_state = map(lambda op: op.__getstate__(), ops)
     return self.CallMethod(REQ_SUBMIT_JOB, ops_state)
 
 
   def SubmitJob(self, ops):
     ops_state = map(lambda op: op.__getstate__(), ops)
     return self.CallMethod(REQ_SUBMIT_JOB, ops_state)
 
+  def SubmitManyJobs(self, jobs):
+    jobs_state = []
+    for ops in jobs:
+      jobs_state.append([op.__getstate__() for op in ops])
+    return self.CallMethod(REQ_SUBMIT_MANY_JOBS, jobs_state)
+
   def CancelJob(self, job_id):
     return self.CallMethod(REQ_CANCEL_JOB, job_id)
 
   def CancelJob(self, job_id):
     return self.CallMethod(REQ_CANCEL_JOB, job_id)
 
@@ -292,7 +353,8 @@ class Client(object):
     return self.CallMethod(REQ_ARCHIVE_JOB, job_id)
 
   def AutoArchiveJobs(self, age):
     return self.CallMethod(REQ_ARCHIVE_JOB, job_id)
 
   def AutoArchiveJobs(self, age):
-    return self.CallMethod(REQ_AUTOARCHIVE_JOBS, age)
+    timeout = (DEF_RWTO - 1) / 2
+    return self.CallMethod(REQ_AUTOARCHIVE_JOBS, (age, timeout))
 
   def WaitForJobChange(self, job_id, fields, prev_job_info, prev_log_serial):
     timeout = (DEF_RWTO - 1) / 2
 
   def WaitForJobChange(self, job_id, fields, prev_job_info, prev_log_serial):
     timeout = (DEF_RWTO - 1) / 2
@@ -307,16 +369,20 @@ class Client(object):
   def QueryJobs(self, job_ids, fields):
     return self.CallMethod(REQ_QUERY_JOBS, (job_ids, fields))
 
   def QueryJobs(self, job_ids, fields):
     return self.CallMethod(REQ_QUERY_JOBS, (job_ids, fields))
 
-  def QueryInstances(self, names, fields):
-    return self.CallMethod(REQ_QUERY_INSTANCES, (names, fields))
+  def QueryInstances(self, names, fields, use_locking):
+    return self.CallMethod(REQ_QUERY_INSTANCES, (names, fields, use_locking))
 
 
-  def QueryNodes(self, names, fields):
-    return self.CallMethod(REQ_QUERY_NODES, (names, fields))
+  def QueryNodes(self, names, fields, use_locking):
+    return self.CallMethod(REQ_QUERY_NODES, (names, fields, use_locking))
 
 
-  def QueryExports(self, nodes):
-    return self.CallMethod(REQ_QUERY_EXPORTS, nodes)
+  def QueryExports(self, nodes, use_locking):
+    return self.CallMethod(REQ_QUERY_EXPORTS, (nodes, use_locking))
+
+  def QueryClusterInfo(self):
+    return self.CallMethod(REQ_QUERY_CLUSTER_INFO, ())
 
   def QueryConfigValues(self, fields):
     return self.CallMethod(REQ_QUERY_CONFIG_VALUES, fields)
 
 
   def QueryConfigValues(self, fields):
     return self.CallMethod(REQ_QUERY_CONFIG_VALUES, fields)
 
+
 # TODO: class Server(object)
 # TODO: class Server(object)