Statistics
| Branch: | Tag: | Revision:

root / lib / rpc / client.py @ a28216b0

History | View | Annotate | Download (6.2 kB)

1
#
2
#
3

    
4
# Copyright (C) 2013 Google Inc.
5
#
6
# This program is free software; you can redistribute it and/or modify
7
# it under the terms of the GNU General Public License as published by
8
# the Free Software Foundation; either version 2 of the License, or
9
# (at your option) any later version.
10
#
11
# This program is distributed in the hope that it will be useful, but
12
# WITHOUT ANY WARRANTY; without even the implied warranty of
13
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14
# General Public License for more details.
15
#
16
# You should have received a copy of the GNU General Public License
17
# along with this program; if not, write to the Free Software
18
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19
# 02110-1301, USA.
20

    
21

    
22
"""Module for generic RPC clients.
23

24
"""
25

    
26
import logging
27

    
28
from ganeti import pathutils
29
import ganeti.rpc.transport as t
30

    
31
from ganeti import constants
32
from ganeti import errors
33
from ganeti.rpc.errors import (ProtocolError, RequestError, LuxiError)
34
from ganeti import serializer
35

    
36
KEY_METHOD = constants.LUXI_KEY_METHOD
37
KEY_ARGS = constants.LUXI_KEY_ARGS
38
KEY_SUCCESS = constants.LUXI_KEY_SUCCESS
39
KEY_RESULT = constants.LUXI_KEY_RESULT
40
KEY_VERSION = constants.LUXI_KEY_VERSION
41

    
42

    
43
def ParseRequest(msg):
44
  """Parses a request message.
45

46
  """
47
  try:
48
    request = serializer.LoadJson(msg)
49
  except ValueError, err:
50
    raise ProtocolError("Invalid RPC request (parsing error): %s" % err)
51

    
52
  logging.debug("RPC request: %s", request)
53

    
54
  if not isinstance(request, dict):
55
    logging.error("RPC request not a dict: %r", msg)
56
    raise ProtocolError("Invalid RPC request (not a dict)")
57

    
58
  method = request.get(KEY_METHOD, None) # pylint: disable=E1103
59
  args = request.get(KEY_ARGS, None) # pylint: disable=E1103
60
  version = request.get(KEY_VERSION, None) # pylint: disable=E1103
61

    
62
  if method is None or args is None:
63
    logging.error("RPC request missing method or arguments: %r", msg)
64
    raise ProtocolError(("Invalid RPC request (no method or arguments"
65
                         " in request): %r") % msg)
66

    
67
  return (method, args, version)
68

    
69

    
70
def ParseResponse(msg):
71
  """Parses a response message.
72

73
  """
74
  # Parse the result
75
  try:
76
    data = serializer.LoadJson(msg)
77
  except KeyboardInterrupt:
78
    raise
79
  except Exception, err:
80
    raise ProtocolError("Error while deserializing response: %s" % str(err))
81

    
82
  # Validate response
83
  if not (isinstance(data, dict) and
84
          KEY_SUCCESS in data and
85
          KEY_RESULT in data):
86
    raise ProtocolError("Invalid response from server: %r" % data)
87

    
88
  return (data[KEY_SUCCESS], data[KEY_RESULT],
89
          data.get(KEY_VERSION, None)) # pylint: disable=E1103
90

    
91

    
92
def FormatResponse(success, result, version=None):
93
  """Formats a response message.
94

95
  """
96
  response = {
97
    KEY_SUCCESS: success,
98
    KEY_RESULT: result,
99
    }
100

    
101
  if version is not None:
102
    response[KEY_VERSION] = version
103

    
104
  logging.debug("RPC response: %s", response)
105

    
106
  return serializer.DumpJson(response)
107

    
108

    
109
def FormatRequest(method, args, version=None):
110
  """Formats a request message.
111

112
  """
113
  # Build request
114
  request = {
115
    KEY_METHOD: method,
116
    KEY_ARGS: args,
117
    }
118

    
119
  if version is not None:
120
    request[KEY_VERSION] = version
121

    
122
  # Serialize the request
123
  return serializer.DumpJson(request,
124
                             private_encoder=serializer.EncodeWithPrivateFields)
125

    
126

    
127
def CallRPCMethod(transport_cb, method, args, version=None):
128
  """Send a RPC request via a transport and return the response.
129

130
  """
131
  assert callable(transport_cb)
132

    
133
  request_msg = FormatRequest(method, args, version=version)
134

    
135
  # Send request and wait for response
136
  response_msg = transport_cb(request_msg)
137

    
138
  (success, result, resp_version) = ParseResponse(response_msg)
139

    
140
  # Verify version if there was one in the response
141
  if resp_version is not None and resp_version != version:
142
    raise LuxiError("RPC version mismatch, client %s, response %s" %
143
                    (version, resp_version))
144

    
145
  if success:
146
    return result
147

    
148
  errors.MaybeRaise(result)
149
  raise RequestError(result)
150

    
151

    
152
class AbstractClient(object):
153
  """High-level client abstraction.
154

155
  This uses a backing Transport-like class on top of which it
156
  implements data serialization/deserialization.
157

158
  """
159

    
160
  def __init__(self, address=None, timeouts=None,
161
               transport=t.Transport):
162
    """Constructor for the Client class.
163

164
    Arguments:
165
      - address: a valid address the the used transport class
166
      - timeout: a list of timeouts, to be used on connect and read/write
167
      - transport: a Transport-like class
168

169

170
    If timeout is not passed, the default timeouts of the transport
171
    class are used.
172

173
    """
174
    if address is None:
175
      address = pathutils.QUERY_SOCKET
176
    self.address = address
177
    self.timeouts = timeouts
178
    self.transport_class = transport
179
    self.transport = None
180
    self._InitTransport()
181
    # The version used in RPC communication, by default unused:
182
    self.version = None
183

    
184
  def _InitTransport(self):
185
    """(Re)initialize the transport if needed.
186

187
    """
188
    if self.transport is None:
189
      self.transport = self.transport_class(self.address,
190
                                            timeouts=self.timeouts)
191

    
192
  def _CloseTransport(self):
193
    """Close the transport, ignoring errors.
194

195
    """
196
    if self.transport is None:
197
      return
198
    try:
199
      old_transp = self.transport
200
      self.transport = None
201
      old_transp.Close()
202
    except Exception: # pylint: disable=W0703
203
      pass
204

    
205
  def _SendMethodCall(self, data):
206
    # Send request and wait for response
207
    try:
208
      self._InitTransport()
209
      return self.transport.Call(data)
210
    except Exception:
211
      self._CloseTransport()
212
      raise
213

    
214
  def Close(self):
215
    """Close the underlying connection.
216

217
    """
218
    self._CloseTransport()
219

    
220
  def close(self):
221
    """Same as L{Close}, to be used with contextlib.closing(...).
222

223
    """
224
    self.Close()
225

    
226
  def CallMethod(self, method, args):
227
    """Send a generic request and return the response.
228

229
    """
230
    if not isinstance(args, (list, tuple)):
231
      raise errors.ProgrammerError("Invalid parameter passed to CallMethod:"
232
                                   " expected list, got %s" % type(args))
233
    return CallRPCMethod(self._SendMethodCall, method, args,
234
                         version=self.version)