Statistics
| Branch: | Tag: | Revision:

root / lib / rpc / client.py @ c4071978

History | View | Annotate | Download (6.8 kB)

1
#
2
#
3

    
4
# Copyright (C) 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Module for generic RPC clients.
23

24
"""
25

    
26
import logging
27

    
28
import ganeti.rpc.transport as t
29

    
30
from ganeti import constants
31
from ganeti import errors
32
from ganeti.rpc.errors import (ProtocolError, RequestError, LuxiError)
33
from ganeti import serializer
34

    
35
KEY_METHOD = constants.LUXI_KEY_METHOD
36
KEY_ARGS = constants.LUXI_KEY_ARGS
37
KEY_SUCCESS = constants.LUXI_KEY_SUCCESS
38
KEY_RESULT = constants.LUXI_KEY_RESULT
39
KEY_VERSION = constants.LUXI_KEY_VERSION
40

    
41

    
42
def ParseRequest(msg):
43
  """Parses a request message.
44

45
  """
46
  try:
47
    request = serializer.LoadJson(msg)
48
  except ValueError, err:
49
    raise ProtocolError("Invalid RPC request (parsing error): %s" % err)
50

    
51
  logging.debug("RPC request: %s", request)
52

    
53
  if not isinstance(request, dict):
54
    logging.error("RPC request not a dict: %r", msg)
55
    raise ProtocolError("Invalid RPC request (not a dict)")
56

    
57
  method = request.get(KEY_METHOD, None) # pylint: disable=E1103
58
  args = request.get(KEY_ARGS, None) # pylint: disable=E1103
59
  version = request.get(KEY_VERSION, None) # pylint: disable=E1103
60

    
61
  if method is None or args is None:
62
    logging.error("RPC request missing method or arguments: %r", msg)
63
    raise ProtocolError(("Invalid RPC request (no method or arguments"
64
                         " in request): %r") % msg)
65

    
66
  return (method, args, version)
67

    
68

    
69
def ParseResponse(msg):
70
  """Parses a response message.
71

72
  """
73
  # Parse the result
74
  try:
75
    data = serializer.LoadJson(msg)
76
  except KeyboardInterrupt:
77
    raise
78
  except Exception, err:
79
    raise ProtocolError("Error while deserializing response: %s" % str(err))
80

    
81
  # Validate response
82
  if not (isinstance(data, dict) and
83
          KEY_SUCCESS in data and
84
          KEY_RESULT in data):
85
    raise ProtocolError("Invalid response from server: %r" % data)
86

    
87
  return (data[KEY_SUCCESS], data[KEY_RESULT],
88
          data.get(KEY_VERSION, None)) # pylint: disable=E1103
89

    
90

    
91
def FormatResponse(success, result, version=None):
92
  """Formats a response message.
93

94
  """
95
  response = {
96
    KEY_SUCCESS: success,
97
    KEY_RESULT: result,
98
    }
99

    
100
  if version is not None:
101
    response[KEY_VERSION] = version
102

    
103
  logging.debug("RPC response: %s", response)
104

    
105
  return serializer.DumpJson(response)
106

    
107

    
108
def FormatRequest(method, args, version=None):
109
  """Formats a request message.
110

111
  """
112
  # Build request
113
  request = {
114
    KEY_METHOD: method,
115
    KEY_ARGS: args,
116
    }
117

    
118
  if version is not None:
119
    request[KEY_VERSION] = version
120

    
121
  # Serialize the request
122
  return serializer.DumpJson(request,
123
                             private_encoder=serializer.EncodeWithPrivateFields)
124

    
125

    
126
def CallRPCMethod(transport_cb, method, args, version=None):
127
  """Send a RPC request via a transport and return the response.
128

129
  """
130
  assert callable(transport_cb)
131

    
132
  request_msg = FormatRequest(method, args, version=version)
133

    
134
  # Send request and wait for response
135
  response_msg = transport_cb(request_msg)
136

    
137
  (success, result, resp_version) = ParseResponse(response_msg)
138

    
139
  # Verify version if there was one in the response
140
  if resp_version is not None and resp_version != version:
141
    raise LuxiError("RPC version mismatch, client %s, response %s" %
142
                    (version, resp_version))
143

    
144
  if success:
145
    return result
146

    
147
  errors.MaybeRaise(result)
148
  raise RequestError(result)
149

    
150

    
151
class AbstractClient(object):
152
  """High-level client abstraction.
153

154
  This uses a backing Transport-like class on top of which it
155
  implements data serialization/deserialization.
156

157
  """
158

    
159
  def __init__(self, timeouts=None, transport=t.Transport):
160
    """Constructor for the Client class.
161

162
    Arguments:
163
      - address: a valid address the the used transport class
164
      - timeout: a list of timeouts, to be used on connect and read/write
165
      - transport: a Transport-like class
166

167

168
    If timeout is not passed, the default timeouts of the transport
169
    class are used.
170

171
    """
172
    self.timeouts = timeouts
173
    self.transport_class = transport
174
    self.transport = None
175
    # The version used in RPC communication, by default unused:
176
    self.version = None
177

    
178
  def _GetAddress(self):
179
    """Returns the socket address
180

181
    """
182
    raise NotImplementedError
183

    
184
  def _InitTransport(self):
185
    """(Re)initialize the transport if needed.
186

187
    """
188
    if self.transport is None:
189
      self.transport = self.transport_class(self._GetAddress(),
190
                                            timeouts=self.timeouts)
191

    
192
  def _CloseTransport(self):
193
    """Close the transport, ignoring errors.
194

195
    """
196
    if self.transport is None:
197
      return
198
    try:
199
      old_transp = self.transport
200
      self.transport = None
201
      old_transp.Close()
202
    except Exception: # pylint: disable=W0703
203
      pass
204

    
205
  def _SendMethodCall(self, data):
206
    # Send request and wait for response
207
    try:
208
      self._InitTransport()
209
      return self.transport.Call(data)
210
    except Exception:
211
      self._CloseTransport()
212
      raise
213

    
214
  def Close(self):
215
    """Close the underlying connection.
216

217
    """
218
    self._CloseTransport()
219

    
220
  def close(self):
221
    """Same as L{Close}, to be used with contextlib.closing(...).
222

223
    """
224
    self.Close()
225

    
226
  def CallMethod(self, method, args):
227
    """Send a generic request and return the response.
228

229
    """
230
    if not isinstance(args, (list, tuple)):
231
      raise errors.ProgrammerError("Invalid parameter passed to CallMethod:"
232
                                   " expected list, got %s" % type(args))
233
    return CallRPCMethod(self._SendMethodCall, method, args,
234
                         version=self.version)
235

    
236

    
237
class AbstractStubClient(AbstractClient):
238
  """An abstract Client that connects a generated stub client to a L{Transport}.
239

240
  Subclasses should inherit from this class (first) as well and a designated
241
  stub (second).
242
  """
243

    
244
  def __init__(self, timeouts=None, transport=t.Transport):
245
    """Constructor for the class.
246

247
    Arguments are the same as for L{AbstractClient}. Checks that SOCKET_PATH
248
    attribute is defined (in the stub class).
249
    """
250

    
251
    super(AbstractStubClient, self).__init__(timeouts, transport)
252

    
253
  def _GenericInvoke(self, method, *args):
254
    return self.CallMethod(method, args)
255

    
256
  def _GetAddress(self):
257
    return self._GetSocketPath() # pylint: disable=E1101