#!/usr/bin/python
#
-# Copyright (C) 2010 Google Inc.
+# Copyright (C) 2010, 2011, 2012 Google Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
import os
import sys
import unittest
+import random
+import tempfile
from ganeti import constants
from ganeti import compat
from ganeti import rpc
+from ganeti import rpc_defs
from ganeti import http
from ganeti import errors
from ganeti import serializer
+from ganeti import objects
+from ganeti import backend
import testutils
+import mocks
-class TestTimeouts(unittest.TestCase):
- def test(self):
- names = [name[len("call_"):] for name in dir(rpc.RpcRunner)
- if name.startswith("call_")]
- self.assertEqual(len(names), len(rpc._TIMEOUTS))
- self.assertFalse([name for name in names
- if not (rpc._TIMEOUTS[name] is None or
- rpc._TIMEOUTS[name] > 0)])
-
-
-class FakeHttpPool:
+class _FakeRequestProcessor:
def __init__(self, response_fn):
self._response_fn = response_fn
self.reqcount = 0
- def ProcessRequests(self, reqs):
+ def __call__(self, reqs, lock_monitor_cb=None):
+ assert lock_monitor_cb is None or callable(lock_monitor_cb)
for req in reqs:
self.reqcount += 1
self._response_fn(req)
return FakeSimpleStore
-class TestClient(unittest.TestCase):
+def _RaiseNotImplemented():
+ """Simple wrapper to raise NotImplementedError.
+
+ """
+ raise NotImplementedError
+
+
+class TestRpcProcessor(unittest.TestCase):
def _FakeAddressLookup(self, map):
return lambda node_list: [map.get(node) for node in node_list]
def _GetVersionResponse(self, req):
- self.assertEqual(req.host, "localhost")
+ self.assertEqual(req.host, "127.0.0.1")
self.assertEqual(req.port, 24094)
self.assertEqual(req.path, "/version")
+ self.assertEqual(req.read_timeout, constants.RPC_TMO_URGENT)
req.success = True
req.resp_status_code = http.HTTP_OK
req.resp_body = serializer.DumpJson((True, 123))
def testVersionSuccess(self):
- fn = self._FakeAddressLookup({"localhost": "localhost"})
- client = rpc.Client("version", None, 24094, address_lookup_fn=fn)
- client.ConnectNode("localhost")
- pool = FakeHttpPool(self._GetVersionResponse)
- result = client.GetResults(http_pool=pool)
+ resolver = rpc._StaticResolver(["127.0.0.1"])
+ http_proc = _FakeRequestProcessor(self._GetVersionResponse)
+ proc = rpc._RpcProcessor(resolver, 24094)
+ result = proc(["localhost"], "version", {"localhost": ""}, 60,
+ NotImplemented, _req_process_fn=http_proc)
self.assertEqual(result.keys(), ["localhost"])
lhresp = result["localhost"]
self.assertFalse(lhresp.offline)
self.assertEqual(lhresp.payload, 123)
self.assertEqual(lhresp.call, "version")
lhresp.Raise("should not raise")
- self.assertEqual(pool.reqcount, 1)
+ self.assertEqual(http_proc.reqcount, 1)
+
+ def _ReadTimeoutResponse(self, req):
+ self.assertEqual(req.host, "192.0.2.13")
+ self.assertEqual(req.port, 19176)
+ self.assertEqual(req.path, "/version")
+ self.assertEqual(req.read_timeout, 12356)
+ req.success = True
+ req.resp_status_code = http.HTTP_OK
+ req.resp_body = serializer.DumpJson((True, -1))
+
+ def testReadTimeout(self):
+ resolver = rpc._StaticResolver(["192.0.2.13"])
+ http_proc = _FakeRequestProcessor(self._ReadTimeoutResponse)
+ proc = rpc._RpcProcessor(resolver, 19176)
+ host = "node31856"
+ body = {host: ""}
+ result = proc([host], "version", body, 12356, NotImplemented,
+ _req_process_fn=http_proc)
+ self.assertEqual(result.keys(), [host])
+ lhresp = result[host]
+ self.assertFalse(lhresp.offline)
+ self.assertEqual(lhresp.node, host)
+ self.assertFalse(lhresp.fail_msg)
+ self.assertEqual(lhresp.payload, -1)
+ self.assertEqual(lhresp.call, "version")
+ lhresp.Raise("should not raise")
+ self.assertEqual(http_proc.reqcount, 1)
+
+ def testOfflineNode(self):
+ resolver = rpc._StaticResolver([rpc._OFFLINE])
+ http_proc = _FakeRequestProcessor(NotImplemented)
+ proc = rpc._RpcProcessor(resolver, 30668)
+ host = "n17296"
+ body = {host: ""}
+ result = proc([host], "version", body, 60, NotImplemented,
+ _req_process_fn=http_proc)
+ self.assertEqual(result.keys(), [host])
+ lhresp = result[host]
+ self.assertTrue(lhresp.offline)
+ self.assertEqual(lhresp.node, host)
+ self.assertTrue(lhresp.fail_msg)
+ self.assertFalse(lhresp.payload)
+ self.assertEqual(lhresp.call, "version")
+
+ # With a message
+ self.assertRaises(errors.OpExecError, lhresp.Raise, "should raise")
+
+ # No message
+ self.assertRaises(errors.OpExecError, lhresp.Raise, None)
+
+ self.assertEqual(http_proc.reqcount, 0)
def _GetMultiVersionResponse(self, req):
self.assert_(req.host.startswith("node"))
def testMultiVersionSuccess(self):
nodes = ["node%s" % i for i in range(50)]
- fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
- client = rpc.Client("version", None, 23245, address_lookup_fn=fn)
- client.ConnectList(nodes)
-
- pool = FakeHttpPool(self._GetMultiVersionResponse)
- result = client.GetResults(http_pool=pool)
+ body = dict((n, "") for n in nodes)
+ resolver = rpc._StaticResolver(nodes)
+ http_proc = _FakeRequestProcessor(self._GetMultiVersionResponse)
+ proc = rpc._RpcProcessor(resolver, 23245)
+ result = proc(nodes, "version", body, 60, NotImplemented,
+ _req_process_fn=http_proc)
self.assertEqual(sorted(result.keys()), sorted(nodes))
for name in nodes:
self.assertEqual(lhresp.call, "version")
lhresp.Raise("should not raise")
- self.assertEqual(pool.reqcount, len(nodes))
+ self.assertEqual(http_proc.reqcount, len(nodes))
- def _GetVersionResponseFail(self, req):
+ def _GetVersionResponseFail(self, errinfo, req):
self.assertEqual(req.path, "/version")
req.success = True
req.resp_status_code = http.HTTP_OK
- req.resp_body = serializer.DumpJson((False, "Unknown error"))
+ req.resp_body = serializer.DumpJson((False, errinfo))
def testVersionFailure(self):
- lookup_map = {"aef9ur4i.example.com": "aef9ur4i.example.com"}
- fn = self._FakeAddressLookup(lookup_map)
- client = rpc.Client("version", None, 5903, address_lookup_fn=fn)
- client.ConnectNode("aef9ur4i.example.com")
- pool = FakeHttpPool(self._GetVersionResponseFail)
- result = client.GetResults(http_pool=pool)
- self.assertEqual(result.keys(), ["aef9ur4i.example.com"])
- lhresp = result["aef9ur4i.example.com"]
- self.assertFalse(lhresp.offline)
- self.assertEqual(lhresp.node, "aef9ur4i.example.com")
- self.assert_(lhresp.fail_msg)
- self.assertFalse(lhresp.payload)
- self.assertEqual(lhresp.call, "version")
- self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
- self.assertEqual(pool.reqcount, 1)
+ resolver = rpc._StaticResolver(["aef9ur4i.example.com"])
+ proc = rpc._RpcProcessor(resolver, 5903)
+ for errinfo in [None, "Unknown error"]:
+ http_proc = \
+ _FakeRequestProcessor(compat.partial(self._GetVersionResponseFail,
+ errinfo))
+ host = "aef9ur4i.example.com"
+ body = {host: ""}
+ result = proc(body.keys(), "version", body, 60, NotImplemented,
+ _req_process_fn=http_proc)
+ self.assertEqual(result.keys(), [host])
+ lhresp = result[host]
+ self.assertFalse(lhresp.offline)
+ self.assertEqual(lhresp.node, host)
+ self.assert_(lhresp.fail_msg)
+ self.assertFalse(lhresp.payload)
+ self.assertEqual(lhresp.call, "version")
+ self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
+ self.assertEqual(http_proc.reqcount, 1)
def _GetHttpErrorResponse(self, httperrnodes, failnodes, req):
self.assertEqual(req.path, "/vg_list")
def testHttpError(self):
nodes = ["uaf6pbbv%s" % i for i in range(50)]
- fn = self._FakeAddressLookup(dict(zip(nodes, nodes)))
+ body = dict((n, "") for n in nodes)
+ resolver = rpc._StaticResolver(nodes)
httperrnodes = set(nodes[1::7])
self.assertEqual(len(httperrnodes), 7)
self.assertEqual(len(set(nodes) - failnodes - httperrnodes), 29)
- client = rpc.Client("vg_list", None, 15165, address_lookup_fn=fn)
- client.ConnectList(nodes)
-
- pool = FakeHttpPool(compat.partial(self._GetHttpErrorResponse,
- httperrnodes, failnodes))
- result = client.GetResults(http_pool=pool)
+ proc = rpc._RpcProcessor(resolver, 15165)
+ http_proc = \
+ _FakeRequestProcessor(compat.partial(self._GetHttpErrorResponse,
+ httperrnodes, failnodes))
+ result = proc(nodes, "vg_list", body,
+ constants.RPC_TMO_URGENT, NotImplemented,
+ _req_process_fn=http_proc)
self.assertEqual(sorted(result.keys()), sorted(nodes))
for name in nodes:
self.assertEqual(lhresp.payload, hash(name))
lhresp.Raise("should not raise")
- self.assertEqual(pool.reqcount, len(nodes))
+ self.assertEqual(http_proc.reqcount, len(nodes))
def _GetInvalidResponseA(self, req):
self.assertEqual(req.path, "/version")
req.resp_body = serializer.DumpJson("invalid response")
def testInvalidResponse(self):
- lookup_map = {"oqo7lanhly.example.com": "oqo7lanhly.example.com"}
- fn = self._FakeAddressLookup(lookup_map)
- client = rpc.Client("version", None, 19978, address_lookup_fn=fn)
+ resolver = rpc._StaticResolver(["oqo7lanhly.example.com"])
+ proc = rpc._RpcProcessor(resolver, 19978)
+
for fn in [self._GetInvalidResponseA, self._GetInvalidResponseB]:
- client.ConnectNode("oqo7lanhly.example.com")
- pool = FakeHttpPool(fn)
- result = client.GetResults(http_pool=pool)
- self.assertEqual(result.keys(), ["oqo7lanhly.example.com"])
- lhresp = result["oqo7lanhly.example.com"]
+ http_proc = _FakeRequestProcessor(fn)
+ host = "oqo7lanhly.example.com"
+ body = {host: ""}
+ result = proc([host], "version", body, 60, NotImplemented,
+ _req_process_fn=http_proc)
+ self.assertEqual(result.keys(), [host])
+ lhresp = result[host]
self.assertFalse(lhresp.offline)
- self.assertEqual(lhresp.node, "oqo7lanhly.example.com")
+ self.assertEqual(lhresp.node, host)
self.assert_(lhresp.fail_msg)
self.assertFalse(lhresp.payload)
self.assertEqual(lhresp.call, "version")
self.assertRaises(errors.OpExecError, lhresp.Raise, "failed")
- self.assertEqual(pool.reqcount, 1)
+ self.assertEqual(http_proc.reqcount, 1)
+
+ def _GetBodyTestResponse(self, test_data, req):
+ self.assertEqual(req.host, "192.0.2.84")
+ self.assertEqual(req.port, 18700)
+ self.assertEqual(req.path, "/upload_file")
+ self.assertEqual(serializer.LoadJson(req.post_data), test_data)
+ req.success = True
+ req.resp_status_code = http.HTTP_OK
+ req.resp_body = serializer.DumpJson((True, None))
+
+ def testResponseBody(self):
+ test_data = {
+ "Hello": "World",
+ "xyz": range(10),
+ }
+ resolver = rpc._StaticResolver(["192.0.2.84"])
+ http_proc = _FakeRequestProcessor(compat.partial(self._GetBodyTestResponse,
+ test_data))
+ proc = rpc._RpcProcessor(resolver, 18700)
+ host = "node19759"
+ body = {host: serializer.DumpJson(test_data)}
+ result = proc([host], "upload_file", body, 30, NotImplemented,
+ _req_process_fn=http_proc)
+ self.assertEqual(result.keys(), [host])
+ lhresp = result[host]
+ self.assertFalse(lhresp.offline)
+ self.assertEqual(lhresp.node, host)
+ self.assertFalse(lhresp.fail_msg)
+ self.assertEqual(lhresp.payload, None)
+ self.assertEqual(lhresp.call, "upload_file")
+ lhresp.Raise("should not raise")
+ self.assertEqual(http_proc.reqcount, 1)
+
- def testAddressLookupSimpleStore(self):
+class TestSsconfResolver(unittest.TestCase):
+ def testSsconfLookup(self):
addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
- node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)]
+ node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
- result = rpc._AddressLookup(node_list, ssc=ssc)
- self.assertEqual(result, addr_list)
+ result = rpc._SsconfResolver(True, node_list, NotImplemented,
+ ssc=ssc, nslookup_fn=NotImplemented)
+ self.assertEqual(result, zip(node_list, addr_list))
- def testAddressLookupNSLookup(self):
+ def testNsLookup(self):
addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
ssc = GetFakeSimpleStoreClass(lambda _: [])
node_addr_map = dict(zip(node_list, addr_list))
nslookup_fn = lambda name, family=None: node_addr_map.get(name)
- result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
- self.assertEqual(result, addr_list)
+ result = rpc._SsconfResolver(True, node_list, NotImplemented,
+ ssc=ssc, nslookup_fn=nslookup_fn)
+ self.assertEqual(result, zip(node_list, addr_list))
+
+ def testDisabledSsconfIp(self):
+ addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
+ node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
+ ssc = GetFakeSimpleStoreClass(_RaiseNotImplemented)
+ node_addr_map = dict(zip(node_list, addr_list))
+ nslookup_fn = lambda name, family=None: node_addr_map.get(name)
+ result = rpc._SsconfResolver(False, node_list, NotImplemented,
+ ssc=ssc, nslookup_fn=nslookup_fn)
+ self.assertEqual(result, zip(node_list, addr_list))
- def testAddressLookupBoth(self):
+ def testBothLookups(self):
addr_list = ["192.0.2.%d" % n for n in range(0, 255, 13)]
node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
n = len(addr_list) / 2
- node_addr_list = [ " ".join(t) for t in zip(node_list[n:], addr_list[n:])]
+ node_addr_list = [" ".join(t) for t in zip(node_list[n:], addr_list[n:])]
ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
node_addr_map = dict(zip(node_list[:n], addr_list[:n]))
nslookup_fn = lambda name, family=None: node_addr_map.get(name)
- result = rpc._AddressLookup(node_list, ssc=ssc, nslookup_fn=nslookup_fn)
- self.assertEqual(result, addr_list)
+ result = rpc._SsconfResolver(True, node_list, NotImplemented,
+ ssc=ssc, nslookup_fn=nslookup_fn)
+ self.assertEqual(result, zip(node_list, addr_list))
def testAddressLookupIPv6(self):
- addr_list = ["2001:db8::%d" % n for n in range(0, 255, 13)]
- node_list = ["node%d.example.com" % n for n in range(0, 255, 13)]
- node_addr_list = [ " ".join(t) for t in zip(node_list, addr_list)]
+ addr_list = ["2001:db8::%d" % n for n in range(0, 255, 11)]
+ node_list = ["node%d.example.com" % n for n in range(0, 255, 11)]
+ node_addr_list = [" ".join(t) for t in zip(node_list, addr_list)]
ssc = GetFakeSimpleStoreClass(lambda _: node_addr_list)
- result = rpc._AddressLookup(node_list, ssc=ssc)
- self.assertEqual(result, addr_list)
+ result = rpc._SsconfResolver(True, node_list, NotImplemented,
+ ssc=ssc, nslookup_fn=NotImplemented)
+ self.assertEqual(result, zip(node_list, addr_list))
+
+
+class TestStaticResolver(unittest.TestCase):
+ def test(self):
+ addresses = ["192.0.2.%d" % n for n in range(0, 123, 7)]
+ nodes = ["node%s.example.com" % n for n in range(0, 123, 7)]
+ res = rpc._StaticResolver(addresses)
+ self.assertEqual(res(nodes, NotImplemented), zip(nodes, addresses))
+
+ def testWrongLength(self):
+ res = rpc._StaticResolver([])
+ self.assertRaises(AssertionError, res, ["abc"], NotImplemented)
+
+
+class TestNodeConfigResolver(unittest.TestCase):
+ @staticmethod
+ def _GetSingleOnlineNode(name):
+ assert name == "node90.example.com"
+ return objects.Node(name=name, offline=False, primary_ip="192.0.2.90")
+
+ @staticmethod
+ def _GetSingleOfflineNode(name):
+ assert name == "node100.example.com"
+ return objects.Node(name=name, offline=True, primary_ip="192.0.2.100")
+
+ def testSingleOnline(self):
+ self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOnlineNode,
+ NotImplemented,
+ ["node90.example.com"], None),
+ [("node90.example.com", "192.0.2.90")])
+
+ def testSingleOffline(self):
+ self.assertEqual(rpc._NodeConfigResolver(self._GetSingleOfflineNode,
+ NotImplemented,
+ ["node100.example.com"], None),
+ [("node100.example.com", rpc._OFFLINE)])
+
+ def testSingleOfflineWithAcceptOffline(self):
+ fn = self._GetSingleOfflineNode
+ assert fn("node100.example.com").offline
+ self.assertEqual(rpc._NodeConfigResolver(fn, NotImplemented,
+ ["node100.example.com"],
+ rpc_defs.ACCEPT_OFFLINE_NODE),
+ [("node100.example.com", "192.0.2.100")])
+ for i in [False, True, "", "Hello", 0, 1]:
+ self.assertRaises(AssertionError, rpc._NodeConfigResolver,
+ fn, NotImplemented, ["node100.example.com"], i)
+
+ def testUnknownSingleNode(self):
+ self.assertEqual(rpc._NodeConfigResolver(lambda _: None, NotImplemented,
+ ["node110.example.com"], None),
+ [("node110.example.com", "node110.example.com")])
+
+ def testMultiEmpty(self):
+ self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
+ lambda: {},
+ [], None),
+ [])
+
+ def testMultiSomeOffline(self):
+ nodes = dict(("node%s.example.com" % i,
+ objects.Node(name="node%s.example.com" % i,
+ offline=((i % 3) == 0),
+ primary_ip="192.0.2.%s" % i))
+ for i in range(1, 255))
+
+ # Resolve no names
+ self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
+ lambda: nodes,
+ [], None),
+ [])
+
+ # Offline, online and unknown hosts
+ self.assertEqual(rpc._NodeConfigResolver(NotImplemented,
+ lambda: nodes,
+ ["node3.example.com",
+ "node92.example.com",
+ "node54.example.com",
+ "unknown.example.com",],
+ None), [
+ ("node3.example.com", rpc._OFFLINE),
+ ("node92.example.com", "192.0.2.92"),
+ ("node54.example.com", rpc._OFFLINE),
+ ("unknown.example.com", "unknown.example.com"),
+ ])
+
+
+class TestCompress(unittest.TestCase):
+ def test(self):
+ for data in ["", "Hello", "Hello World!\nnew\nlines"]:
+ self.assertEqual(rpc._Compress(data),
+ (constants.RPC_ENCODING_NONE, data))
+
+ for data in [512 * " ", 5242 * "Hello World!\n"]:
+ compressed = rpc._Compress(data)
+ self.assertEqual(len(compressed), 2)
+ self.assertEqual(backend._Decompress(compressed), data)
+
+ def testDecompression(self):
+ self.assertRaises(AssertionError, backend._Decompress, "")
+ self.assertRaises(AssertionError, backend._Decompress, [""])
+ self.assertRaises(AssertionError, backend._Decompress,
+ ("unknown compression", "data"))
+ self.assertRaises(Exception, backend._Decompress,
+ (constants.RPC_ENCODING_ZLIB_BASE64, "invalid zlib data"))
+
+
+class TestRpcClientBase(unittest.TestCase):
+ def testNoHosts(self):
+ cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_SLOW, [],
+ None, None, NotImplemented)
+ http_proc = _FakeRequestProcessor(NotImplemented)
+ client = rpc._RpcClientBase(rpc._StaticResolver([]), NotImplemented,
+ _req_process_fn=http_proc)
+ self.assertEqual(client._Call(cdef, [], []), {})
+
+ # Test wrong number of arguments
+ self.assertRaises(errors.ProgrammerError, client._Call,
+ cdef, [], [0, 1, 2])
+
+ def testTimeout(self):
+ def _CalcTimeout((arg1, arg2)):
+ return arg1 + arg2
+
+ def _VerifyRequest(exp_timeout, req):
+ self.assertEqual(req.read_timeout, exp_timeout)
+
+ req.success = True
+ req.resp_status_code = http.HTTP_OK
+ req.resp_body = serializer.DumpJson((True, hex(req.read_timeout)))
+
+ resolver = rpc._StaticResolver([
+ "192.0.2.1",
+ "192.0.2.2",
+ ])
+
+ nodes = [
+ "node1.example.com",
+ "node2.example.com",
+ ]
+
+ tests = [(100, None, 100), (30, None, 30)]
+ tests.extend((_CalcTimeout, i, i + 300)
+ for i in [0, 5, 16485, 30516])
+
+ for timeout, arg1, exp_timeout in tests:
+ cdef = ("test_call", NotImplemented, None, timeout, [
+ ("arg1", None, NotImplemented),
+ ("arg2", None, NotImplemented),
+ ], None, None, NotImplemented)
+
+ http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest,
+ exp_timeout))
+ client = rpc._RpcClientBase(resolver, NotImplemented,
+ _req_process_fn=http_proc)
+ result = client._Call(cdef, nodes, [arg1, 300])
+ self.assertEqual(len(result), len(nodes))
+ self.assertTrue(compat.all(not res.fail_msg and
+ res.payload == hex(exp_timeout)
+ for res in result.values()))
+
+ def testArgumentEncoder(self):
+ (AT1, AT2) = range(1, 3)
+
+ resolver = rpc._StaticResolver([
+ "192.0.2.5",
+ "192.0.2.6",
+ ])
+
+ nodes = [
+ "node5.example.com",
+ "node6.example.com",
+ ]
+
+ encoders = {
+ AT1: hex,
+ AT2: hash,
+ }
+
+ cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
+ ("arg0", None, NotImplemented),
+ ("arg1", AT1, NotImplemented),
+ ("arg1", AT2, NotImplemented),
+ ], None, None, NotImplemented)
+
+ def _VerifyRequest(req):
+ req.success = True
+ req.resp_status_code = http.HTTP_OK
+ req.resp_body = serializer.DumpJson((True, req.post_data))
+
+ http_proc = _FakeRequestProcessor(_VerifyRequest)
+
+ for num in [0, 3796, 9032119]:
+ client = rpc._RpcClientBase(resolver, encoders.get,
+ _req_process_fn=http_proc)
+ result = client._Call(cdef, nodes, ["foo", num, "Hello%s" % num])
+ self.assertEqual(len(result), len(nodes))
+ for res in result.values():
+ self.assertFalse(res.fail_msg)
+ self.assertEqual(serializer.LoadJson(res.payload),
+ ["foo", hex(num), hash("Hello%s" % num)])
+
+ def testPostProc(self):
+ def _VerifyRequest(nums, req):
+ req.success = True
+ req.resp_status_code = http.HTTP_OK
+ req.resp_body = serializer.DumpJson((True, nums))
+
+ resolver = rpc._StaticResolver([
+ "192.0.2.90",
+ "192.0.2.95",
+ ])
+
+ nodes = [
+ "node90.example.com",
+ "node95.example.com",
+ ]
+
+ def _PostProc(res):
+ self.assertFalse(res.fail_msg)
+ res.payload = sum(res.payload)
+ return res
+
+ cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [],
+ None, _PostProc, NotImplemented)
+
+ # Seeded random generator
+ rnd = random.Random(20299)
+
+ for i in [0, 4, 74, 1391]:
+ nums = [rnd.randint(0, 1000) for _ in range(i)]
+ http_proc = _FakeRequestProcessor(compat.partial(_VerifyRequest, nums))
+ client = rpc._RpcClientBase(resolver, NotImplemented,
+ _req_process_fn=http_proc)
+ result = client._Call(cdef, nodes, [])
+ self.assertEqual(len(result), len(nodes))
+ for res in result.values():
+ self.assertFalse(res.fail_msg)
+ self.assertEqual(res.payload, sum(nums))
+
+ def testPreProc(self):
+ def _VerifyRequest(req):
+ req.success = True
+ req.resp_status_code = http.HTTP_OK
+ req.resp_body = serializer.DumpJson((True, req.post_data))
+
+ resolver = rpc._StaticResolver([
+ "192.0.2.30",
+ "192.0.2.35",
+ ])
+
+ nodes = [
+ "node30.example.com",
+ "node35.example.com",
+ ]
+
+ def _PreProc(node, data):
+ self.assertEqual(len(data), 1)
+ return data[0] + node
+
+ cdef = ("test_call", NotImplemented, None, constants.RPC_TMO_NORMAL, [
+ ("arg0", None, NotImplemented),
+ ], _PreProc, None, NotImplemented)
+
+ http_proc = _FakeRequestProcessor(_VerifyRequest)
+ client = rpc._RpcClientBase(resolver, NotImplemented,
+ _req_process_fn=http_proc)
+
+ for prefix in ["foo", "bar", "baz"]:
+ result = client._Call(cdef, nodes, [prefix])
+ self.assertEqual(len(result), len(nodes))
+ for (idx, (node, res)) in enumerate(result.items()):
+ self.assertFalse(res.fail_msg)
+ self.assertEqual(serializer.LoadJson(res.payload), prefix + node)
+
+ def testResolverOptions(self):
+ def _VerifyRequest(req):
+ req.success = True
+ req.resp_status_code = http.HTTP_OK
+ req.resp_body = serializer.DumpJson((True, req.post_data))
+
+ nodes = [
+ "node30.example.com",
+ "node35.example.com",
+ ]
+
+ def _Resolver(expected, hosts, options):
+ self.assertEqual(hosts, nodes)
+ self.assertEqual(options, expected)
+ return zip(hosts, nodes)
+
+ def _DynamicResolverOptions((arg0, )):
+ return sum(arg0)
+
+ tests = [
+ (None, None, None),
+ (rpc_defs.ACCEPT_OFFLINE_NODE, None, rpc_defs.ACCEPT_OFFLINE_NODE),
+ (False, None, False),
+ (True, None, True),
+ (0, None, 0),
+ (_DynamicResolverOptions, [1, 2, 3], 6),
+ (_DynamicResolverOptions, range(4, 19), 165),
+ ]
+
+ for (resolver_opts, arg0, expected) in tests:
+ cdef = ("test_call", NotImplemented, resolver_opts,
+ constants.RPC_TMO_NORMAL, [
+ ("arg0", None, NotImplemented),
+ ], None, None, NotImplemented)
+
+ http_proc = _FakeRequestProcessor(_VerifyRequest)
+
+ client = rpc._RpcClientBase(compat.partial(_Resolver, expected),
+ NotImplemented, _req_process_fn=http_proc)
+ result = client._Call(cdef, nodes, [arg0])
+ self.assertEqual(len(result), len(nodes))
+ for (idx, (node, res)) in enumerate(result.items()):
+ self.assertFalse(res.fail_msg)
+
+
+class _FakeConfigForRpcRunner:
+ GetAllNodesInfo = NotImplemented
+
+ def __init__(self, cluster=NotImplemented):
+ self._cluster = cluster
+
+ def GetNodeInfo(self, name):
+ return objects.Node(name=name)
+
+ def GetClusterInfo(self):
+ return self._cluster
+
+ def GetInstanceDiskParams(self, _):
+ return constants.DISK_DT_DEFAULTS
+
+
+class TestRpcRunner(unittest.TestCase):
+ def testUploadFile(self):
+ data = 1779 * "Hello World\n"
+
+ tmpfile = tempfile.NamedTemporaryFile()
+ tmpfile.write(data)
+ tmpfile.flush()
+ st = os.stat(tmpfile.name)
+
+ def _VerifyRequest(req):
+ (uldata, ) = serializer.LoadJson(req.post_data)
+ self.assertEqual(len(uldata), 7)
+ self.assertEqual(uldata[0], tmpfile.name)
+ self.assertEqual(list(uldata[1]), list(rpc._Compress(data)))
+ self.assertEqual(uldata[2], st.st_mode)
+ self.assertEqual(uldata[3], "user%s" % os.getuid())
+ self.assertEqual(uldata[4], "group%s" % os.getgid())
+ self.assertTrue(uldata[5] is not None)
+ self.assertEqual(uldata[6], st.st_mtime)
+
+ req.success = True
+ req.resp_status_code = http.HTTP_OK
+ req.resp_body = serializer.DumpJson((True, None))
+
+ http_proc = _FakeRequestProcessor(_VerifyRequest)
+
+ std_runner = rpc.RpcRunner(_FakeConfigForRpcRunner(), None,
+ _req_process_fn=http_proc,
+ _getents=mocks.FakeGetentResolver)
+
+ cfg_runner = rpc.ConfigRunner(None, ["192.0.2.13"],
+ _req_process_fn=http_proc,
+ _getents=mocks.FakeGetentResolver)
+
+ nodes = [
+ "node1.example.com",
+ ]
+
+ for runner in [std_runner, cfg_runner]:
+ result = runner.call_upload_file(nodes, tmpfile.name)
+ self.assertEqual(len(result), len(nodes))
+ for (idx, (node, res)) in enumerate(result.items()):
+ self.assertFalse(res.fail_msg)
+
+ def testEncodeInstance(self):
+ cluster = objects.Cluster(hvparams={
+ constants.HT_KVM: {
+ constants.HV_BLOCKDEV_PREFIX: "foo",
+ },
+ },
+ beparams={
+ constants.PP_DEFAULT: {
+ constants.BE_MAXMEM: 8192,
+ },
+ },
+ os_hvp={},
+ osparams={
+ "linux": {
+ "role": "unknown",
+ },
+ })
+ cluster.UpgradeConfig()
+
+ inst = objects.Instance(name="inst1.example.com",
+ hypervisor=constants.HT_FAKE,
+ os="linux",
+ hvparams={
+ constants.HT_KVM: {
+ constants.HV_BLOCKDEV_PREFIX: "bar",
+ constants.HV_ROOT_PATH: "/tmp",
+ },
+ },
+ beparams={
+ constants.BE_MINMEM: 128,
+ constants.BE_MAXMEM: 256,
+ },
+ nics=[
+ objects.NIC(nicparams={
+ constants.NIC_MODE: "mymode",
+ }),
+ ],
+ disk_template=constants.DT_DISKLESS,
+ disks=[])
+ inst.UpgradeConfig()
+
+ cfg = _FakeConfigForRpcRunner(cluster=cluster)
+ runner = rpc.RpcRunner(cfg, None,
+ _req_process_fn=NotImplemented,
+ _getents=mocks.FakeGetentResolver)
+
+ def _CheckBasics(result):
+ self.assertEqual(result["name"], "inst1.example.com")
+ self.assertEqual(result["os"], "linux")
+ self.assertEqual(result["beparams"][constants.BE_MINMEM], 128)
+ self.assertEqual(len(result["hvparams"]), 1)
+ self.assertEqual(len(result["nics"]), 1)
+ self.assertEqual(result["nics"][0]["nicparams"][constants.NIC_MODE],
+ "mymode")
+
+ # Generic object serialization
+ result = runner._encoder((rpc_defs.ED_OBJECT_DICT, inst))
+ _CheckBasics(result)
+
+ result = runner._encoder((rpc_defs.ED_OBJECT_DICT_LIST, 5 * [inst]))
+ map(_CheckBasics, result)
+
+ # Just an instance
+ result = runner._encoder((rpc_defs.ED_INST_DICT, inst))
+ _CheckBasics(result)
+ self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
+ self.assertEqual(result["hvparams"][constants.HT_KVM], {
+ constants.HV_BLOCKDEV_PREFIX: "bar",
+ constants.HV_ROOT_PATH: "/tmp",
+ })
+ self.assertEqual(result["osparams"], {
+ "role": "unknown",
+ })
+
+ # Instance with OS parameters
+ result = runner._encoder((rpc_defs.ED_INST_DICT_OSP_DP, (inst, {
+ "role": "webserver",
+ "other": "field",
+ })))
+ _CheckBasics(result)
+ self.assertEqual(result["beparams"][constants.BE_MAXMEM], 256)
+ self.assertEqual(result["hvparams"][constants.HT_KVM], {
+ constants.HV_BLOCKDEV_PREFIX: "bar",
+ constants.HV_ROOT_PATH: "/tmp",
+ })
+ self.assertEqual(result["osparams"], {
+ "role": "webserver",
+ "other": "field",
+ })
+
+ # Instance with hypervisor and backend parameters
+ result = runner._encoder((rpc_defs.ED_INST_DICT_HVP_BEP, (inst, {
+ constants.HT_KVM: {
+ constants.HV_BOOT_ORDER: "xyz",
+ },
+ }, {
+ constants.BE_VCPUS: 100,
+ constants.BE_MAXMEM: 4096,
+ })))
+ _CheckBasics(result)
+ self.assertEqual(result["beparams"][constants.BE_MAXMEM], 4096)
+ self.assertEqual(result["beparams"][constants.BE_VCPUS], 100)
+ self.assertEqual(result["hvparams"][constants.HT_KVM], {
+ constants.HV_BOOT_ORDER: "xyz",
+ })
if __name__ == "__main__":