Fix a couple of epydoc warnings
[ganeti-local] / lib / luxi.py
index 993907d..11ea61d 100644 (file)
 
 """Module for the unix socket protocol
 
 
 """Module for the unix socket protocol
 
-This module implements the local unix socket protocl. You only need
+This module implements the local unix socket protocol. You only need
 this module and the opcodes module in the client program in order to
 communicate with the master.
 
 this module and the opcodes module in the client program in order to
 communicate with the master.
 
-The module is also be used by the master daemon.
+The module is also used by the master daemon.
 
 """
 
 import socket
 import collections
 
 """
 
 import socket
 import collections
-import simplejson
 import time
 import errno
 
 import time
 import errno
 
-from ganeti import opcodes
+from ganeti import serializer
 from ganeti import constants
 from ganeti import constants
-
-
-KEY_REQUEST = 'request'
-KEY_DATA = 'data'
-REQ_SUBMIT = 'submit'
-REQ_ABORT = 'abort'
-REQ_QUERY = 'query'
+from ganeti import errors
+
+
+KEY_METHOD = 'method'
+KEY_ARGS = 'args'
+KEY_SUCCESS = "success"
+KEY_RESULT = "result"
+
+REQ_SUBMIT_JOB = "SubmitJob"
+REQ_SUBMIT_MANY_JOBS = "SubmitManyJobs"
+REQ_WAIT_FOR_JOB_CHANGE = "WaitForJobChange"
+REQ_CANCEL_JOB = "CancelJob"
+REQ_ARCHIVE_JOB = "ArchiveJob"
+REQ_AUTOARCHIVE_JOBS = "AutoArchiveJobs"
+REQ_QUERY_JOBS = "QueryJobs"
+REQ_QUERY_INSTANCES = "QueryInstances"
+REQ_QUERY_NODES = "QueryNodes"
+REQ_QUERY_EXPORTS = "QueryExports"
+REQ_QUERY_CONFIG_VALUES = "QueryConfigValues"
+REQ_QUERY_CLUSTER_INFO = "QueryClusterInfo"
+REQ_QUEUE_SET_DRAIN_FLAG = "SetDrainFlag"
 
 DEF_CTMO = 10
 DEF_RWTO = 60
 
 DEF_CTMO = 10
 DEF_RWTO = 60
@@ -81,6 +94,7 @@ class RequestError(ProtocolError):
 
   """
 
 
   """
 
+
 class NoMasterError(ProtocolError):
   """The master cannot be reached
 
 class NoMasterError(ProtocolError):
   """The master cannot be reached
 
@@ -90,24 +104,6 @@ class NoMasterError(ProtocolError):
   """
 
 
   """
 
 
-def SerializeJob(job):
-  """Convert a job description to a string format.
-
-  """
-  return simplejson.dumps(job.__getstate__())
-
-
-def UnserializeJob(data):
-  """Load a job from a string format"""
-  try:
-    new_data = simplejson.loads(data)
-  except Exception, err:
-    raise DecodingError("Error while unserializing: %s" % str(err))
-  job = opcodes.Job()
-  job.__setstate__(new_data)
-  return job
-
-
 class Transport:
   """Low-level transport class.
 
 class Transport:
   """Low-level transport class.
 
@@ -164,7 +160,7 @@ class Transport:
       except socket.timeout, err:
         raise TimeoutError("Connect timed out: %s" % str(err))
       except socket.error, err:
       except socket.timeout, err:
         raise TimeoutError("Connect timed out: %s" % str(err))
       except socket.error, err:
-        if err.args[0] == errno.ENOENT:
+        if err.args[0] in (errno.ENOENT, errno.ECONNREFUSED):
           raise NoMasterError((address,))
         raise
       self.socket.settimeout(self._rwtimeout)
           raise NoMasterError((address,))
         raise
       self.socket.settimeout(self._rwtimeout)
@@ -191,12 +187,13 @@ class Transport:
       raise EncodingError("Message terminator found in payload")
     self._CheckSocket()
     try:
       raise EncodingError("Message terminator found in payload")
     self._CheckSocket()
     try:
+      # TODO: sendall is not guaranteed to send everything
       self.socket.sendall(msg + self.eom)
     except socket.timeout, err:
       raise TimeoutError("Sending timeout: %s" % str(err))
 
   def Recv(self):
       self.socket.sendall(msg + self.eom)
     except socket.timeout, err:
       raise TimeoutError("Sending timeout: %s" % str(err))
 
   def Recv(self):
-    """Try to receive a messae from the socket.
+    """Try to receive a message from the socket.
 
     In case we already have messages queued, we just return from the
     queue. Otherwise, we try to read data with a _rwtimeout network
 
     In case we already have messages queued, we just return from the
     queue. Otherwise, we try to read data with a _rwtimeout network
@@ -209,10 +206,16 @@ class Transport:
     while not self._msgs:
       if time.time() > etime:
         raise TimeoutError("Extended receive timeout")
     while not self._msgs:
       if time.time() > etime:
         raise TimeoutError("Extended receive timeout")
-      try:
-        data = self.socket.recv(4096)
-      except socket.timeout, err:
-        raise TimeoutError("Receive timeout: %s" % str(err))
+      while True:
+        try:
+          data = self.socket.recv(4096)
+        except socket.error, err:
+          if err.args and err.args[0] == errno.EAGAIN:
+            continue
+          raise
+        except socket.timeout, err:
+          raise TimeoutError("Receive timeout: %s" % str(err))
+        break
       if not data:
         raise ConnectionClosedError("Connection closed while reading")
       new_msgs = (self._buffer + data).split(self.eom)
       if not data:
         raise ConnectionClosedError("Connection closed while reading")
       new_msgs = (self._buffer + data).split(self.eom)
@@ -258,41 +261,131 @@ class Client(object):
     """
     if address is None:
       address = constants.MASTER_SOCKET
     """
     if address is None:
       address = constants.MASTER_SOCKET
-    self.transport = transport(address, timeouts=timeouts)
+    self.address = address
+    self.timeouts = timeouts
+    self.transport_class = transport
+    self.transport = None
+    self._InitTransport()
+
+  def _InitTransport(self):
+    """(Re)initialize the transport if needed.
+
+    """
+    if self.transport is None:
+      self.transport = self.transport_class(self.address,
+                                            timeouts=self.timeouts)
 
 
-  def SendRequest(self, request, data):
+  def _CloseTransport(self):
+    """Close the transport, ignoring errors.
+
+    """
+    if self.transport is None:
+      return
+    try:
+      old_transp = self.transport
+      self.transport = None
+      old_transp.Close()
+    except Exception:
+      pass
+
+  def CallMethod(self, method, args):
     """Send a generic request and return the response.
 
     """
     """Send a generic request and return the response.
 
     """
-    msg = {KEY_REQUEST: request, KEY_DATA: data}
-    result = self.transport.Call(simplejson.dumps(msg))
+    # Build request
+    request = {
+      KEY_METHOD: method,
+      KEY_ARGS: args,
+      }
+
+    # Serialize the request
+    send_data = serializer.DumpJson(request, indent=False)
+
+    # Send request and wait for response
     try:
     try:
-      data = simplejson.loads(result)
+      self._InitTransport()
+      result = self.transport.Call(send_data)
+    except Exception:
+      self._CloseTransport()
+      raise
+
+    # Parse the result
+    try:
+      data = serializer.LoadJson(result)
     except Exception, err:
       raise ProtocolError("Error while deserializing response: %s" % str(err))
     except Exception, err:
       raise ProtocolError("Error while deserializing response: %s" % str(err))
+
+    # Validate response
     if (not isinstance(data, dict) or
     if (not isinstance(data, dict) or
-        'success' not in data or
-        'result' not in data):
+        KEY_SUCCESS not in data or
+        KEY_RESULT not in data):
       raise DecodingError("Invalid response from server: %s" % str(data))
       raise DecodingError("Invalid response from server: %s" % str(data))
-    return data
-
-  def SubmitJob(self, job):
-    """Submit a job"""
-    result = self.SendRequest(REQ_SUBMIT, SerializeJob(job))
-    if not result['success']:
-      raise RequestError(result['result'])
-    return result['result']
-
-  def Query(self, data):
-    """Make a query"""
-    result = self.SendRequest(REQ_QUERY, data)
-    if not result['success']:
-      raise RequestError(result[result])
-    result = result['result']
-    if data["object"] == "jobs":
-      # custom job processing of query values
-      for row in result:
-        for idx, field in enumerate(data["fields"]):
-          if field == "op_list":
-            row[idx] = [opcodes.OpCode.LoadOpCode(i) for i in row[idx]]
+
+    result = data[KEY_RESULT]
+
+    if not data[KEY_SUCCESS]:
+      # TODO: decide on a standard exception
+      if (isinstance(result, (tuple, list)) and len(result) == 2 and
+          isinstance(result[1], (tuple, list))):
+        # custom ganeti errors
+        err_class = errors.GetErrorClass(result[0])
+        if err_class is not None:
+          raise err_class, tuple(result[1])
+
+      raise RequestError(result)
+
     return result
     return result
+
+  def SetQueueDrainFlag(self, drain_flag):
+    return self.CallMethod(REQ_QUEUE_SET_DRAIN_FLAG, drain_flag)
+
+  def SubmitJob(self, ops):
+    ops_state = map(lambda op: op.__getstate__(), ops)
+    return self.CallMethod(REQ_SUBMIT_JOB, ops_state)
+
+  def SubmitManyJobs(self, jobs):
+    jobs_state = []
+    for ops in jobs:
+      jobs_state.append([op.__getstate__() for op in ops])
+    return self.CallMethod(REQ_SUBMIT_MANY_JOBS, jobs_state)
+
+  def CancelJob(self, job_id):
+    return self.CallMethod(REQ_CANCEL_JOB, job_id)
+
+  def ArchiveJob(self, job_id):
+    return self.CallMethod(REQ_ARCHIVE_JOB, job_id)
+
+  def AutoArchiveJobs(self, age):
+    timeout = (DEF_RWTO - 1) / 2
+    return self.CallMethod(REQ_AUTOARCHIVE_JOBS, (age, timeout))
+
+  def WaitForJobChange(self, job_id, fields, prev_job_info, prev_log_serial):
+    timeout = (DEF_RWTO - 1) / 2
+    while True:
+      result = self.CallMethod(REQ_WAIT_FOR_JOB_CHANGE,
+                               (job_id, fields, prev_job_info,
+                                prev_log_serial, timeout))
+      if result != constants.JOB_NOTCHANGED:
+        break
+    return result
+
+  def QueryJobs(self, job_ids, fields):
+    return self.CallMethod(REQ_QUERY_JOBS, (job_ids, fields))
+
+  def QueryInstances(self, names, fields, use_locking):
+    return self.CallMethod(REQ_QUERY_INSTANCES, (names, fields, use_locking))
+
+  def QueryNodes(self, names, fields, use_locking):
+    return self.CallMethod(REQ_QUERY_NODES, (names, fields, use_locking))
+
+  def QueryExports(self, nodes, use_locking):
+    return self.CallMethod(REQ_QUERY_EXPORTS, (nodes, use_locking))
+
+  def QueryClusterInfo(self):
+    return self.CallMethod(REQ_QUERY_CLUSTER_INFO, ())
+
+  def QueryConfigValues(self, fields):
+    return self.CallMethod(REQ_QUERY_CONFIG_VALUES, fields)
+
+
+# TODO: class Server(object)