Merge ~cgrabowski/maas:rpc_connection_pool_burst into maas:master

Proposed by Christian Grabowski
Status: Merged
Approved by: Christian Grabowski
Approved revision: 0c5564670e61e93261925cf91a46850824536a77
Merge reported by: MAAS Lander
Merged at revision: not available
Proposed branch: ~cgrabowski/maas:rpc_connection_pool_burst
Merge into: maas:master
Diff against target: 1518 lines (+741/-175)
15 files modified
src/provisioningserver/config.py (+17/-0)
src/provisioningserver/dhcp/tests/test_config.py (+2/-2)
src/provisioningserver/plugin.py (+7/-1)
src/provisioningserver/rackdservices/external.py (+3/-2)
src/provisioningserver/rackdservices/http.py (+3/-2)
src/provisioningserver/rackdservices/tests/test_external.py (+8/-5)
src/provisioningserver/rackdservices/tests/test_http.py (+5/-4)
src/provisioningserver/rpc/clusterservice.py (+56/-50)
src/provisioningserver/rpc/common.py (+22/-5)
src/provisioningserver/rpc/connectionpool.py (+163/-0)
src/provisioningserver/rpc/exceptions.py (+8/-0)
src/provisioningserver/rpc/testing/__init__.py (+2/-1)
src/provisioningserver/rpc/testing/doubles.py (+18/-0)
src/provisioningserver/rpc/tests/test_clusterservice.py (+147/-103)
src/provisioningserver/rpc/tests/test_connectionpool.py (+280/-0)
Reviewer Review Type Date Requested Status
Alexsander de Souza Approve
MAAS Lander Approve
Review via email: mp+428335@code.launchpad.net

Commit message

move connection lifecycle logic into ConnectionPool

allocate additional connections when busy

always open max_idle_rpc_connections connections per event loop

To post a comment you must log in.
Revision history for this message
MAAS Lander (maas-lander) wrote :

UNIT TESTS
-b rpc_connection_pool_burst lp:~cgrabowski/maas/+git/maas into -b master lp:~maas-committers/maas

STATUS: FAILED
LOG: http://maas-ci.internal:8080/job/maas-tester/318/consoleText
COMMIT: c19b8616c6acc5cde30fa33916ad40da3d3b04e5

review: Needs Fixing
Revision history for this message
Adam Collard (adam-collard) wrote :

jenkins: !test

Revision history for this message
MAAS Lander (maas-lander) wrote :

UNIT TESTS
-b rpc_connection_pool_burst lp:~cgrabowski/maas/+git/maas into -b master lp:~maas-committers/maas

STATUS: FAILED
LOG: http://maas-ci.internal:8080/job/maas-tester/321/consoleText
COMMIT: c19b8616c6acc5cde30fa33916ad40da3d3b04e5

review: Needs Fixing
Revision history for this message
MAAS Lander (maas-lander) wrote :

UNIT TESTS
-b rpc_connection_pool_burst lp:~cgrabowski/maas/+git/maas into -b master lp:~maas-committers/maas

STATUS: SUCCESS
COMMIT: 0c5564670e61e93261925cf91a46850824536a77

review: Approve
Revision history for this message
Alexsander de Souza (alexsander-souza) wrote :

LGTM

review: Approve

Preview Diff

[H/L] Next/Prev Comment, [J/K] Next/Prev File, [N/P] Next/Prev Hunk
diff --git a/src/provisioningserver/config.py b/src/provisioningserver/config.py
index 97b6e68..9511ad6 100644
--- a/src/provisioningserver/config.py
+++ b/src/provisioningserver/config.py
@@ -762,6 +762,23 @@ class ClusterConfiguration(Configuration, metaclass=ClusterConfigurationMeta):
762 ),762 ),
763 )763 )
764764
765 # RPC Connection Pool options
766 max_idle_rpc_connections = ConfigurationOption(
767 "max_idle_rpc_connections",
768 "The nominal number of connections to have per endpoint",
769 Number(min=1, max=1024, if_missing=1),
770 )
771 max_rpc_connections = ConfigurationOption(
772 "max_rpc_connections",
773 "The maximum number of connections to scale to when under load",
774 Number(min=1, max=1024, if_missing=4),
775 )
776 rpc_keepalive = ConfigurationOption(
777 "rpc_keepalive",
778 "The duration in miliseconds to keep added connections alive",
779 Number(min=1, max=5000, if_missing=1000),
780 )
781
765 # TFTP options.782 # TFTP options.
766 tftp_port = ConfigurationOption(783 tftp_port = ConfigurationOption(
767 "tftp_port",784 "tftp_port",
diff --git a/src/provisioningserver/dhcp/tests/test_config.py b/src/provisioningserver/dhcp/tests/test_config.py
index c53e906..8c3f2fe 100644
--- a/src/provisioningserver/dhcp/tests/test_config.py
+++ b/src/provisioningserver/dhcp/tests/test_config.py
@@ -176,7 +176,7 @@ def validate_dhcpd_configuration(test, configuration, ipv6):
176 ),176 ),
177 ),177 ),
178 )178 )
179 cmd = (179 cmd = [
180 "dhcpd",180 "dhcpd",
181 ("-6" if ipv6 else "-4"),181 ("-6" if ipv6 else "-4"),
182 "-t",182 "-t",
@@ -184,7 +184,7 @@ def validate_dhcpd_configuration(test, configuration, ipv6):
184 conffile.name,184 conffile.name,
185 "-lf",185 "-lf",
186 leasesfile.name,186 leasesfile.name,
187 )187 ]
188 if not running_in_docker():188 if not running_in_docker():
189 # Call `dhcpd` without AppArmor confinement, so that it can read189 # Call `dhcpd` without AppArmor confinement, so that it can read
190 # configurations file from /tmp. This is not needed when running190 # configurations file from /tmp. This is not needed when running
diff --git a/src/provisioningserver/plugin.py b/src/provisioningserver/plugin.py
index e987c73..00ff898 100644
--- a/src/provisioningserver/plugin.py
+++ b/src/provisioningserver/plugin.py
@@ -139,7 +139,13 @@ class ProvisioningServiceMaker:
139 def _makeRPCService(self):139 def _makeRPCService(self):
140 from provisioningserver.rpc.clusterservice import ClusterClientService140 from provisioningserver.rpc.clusterservice import ClusterClientService
141141
142 rpc_service = ClusterClientService(reactor)142 with ClusterConfiguration.open() as config:
143 rpc_service = ClusterClientService(
144 reactor,
145 config.max_idle_rpc_connections,
146 config.max_rpc_connections,
147 config.rpc_keepalive,
148 )
143 rpc_service.setName("rpc")149 rpc_service.setName("rpc")
144 return rpc_service150 return rpc_service
145151
diff --git a/src/provisioningserver/rackdservices/external.py b/src/provisioningserver/rackdservices/external.py
index ccabb74..5b8afe5 100644
--- a/src/provisioningserver/rackdservices/external.py
+++ b/src/provisioningserver/rackdservices/external.py
@@ -68,8 +68,9 @@ class RackOnlyExternalService(metaclass=ABCMeta):
6868
69 # Filter the connects by region.69 # Filter the connects by region.
70 conn_per_region = defaultdict(set)70 conn_per_region = defaultdict(set)
71 for eventloop, connection in connections.items():71 for eventloop, connection_set in connections.items():
72 conn_per_region[eventloop.split(":")[0]].add(connection)72 for connection in connection_set:
73 conn_per_region[eventloop.split(":")[0]].add(connection)
73 for eventloop, connections in conn_per_region.items():74 for eventloop, connections in conn_per_region.items():
74 # Sort the connections so the same IP is always picked per75 # Sort the connections so the same IP is always picked per
75 # region controller. This ensures that the HTTP configuration76 # region controller. This ensures that the HTTP configuration
diff --git a/src/provisioningserver/rackdservices/http.py b/src/provisioningserver/rackdservices/http.py
index 421e35f..bda9d23 100644
--- a/src/provisioningserver/rackdservices/http.py
+++ b/src/provisioningserver/rackdservices/http.py
@@ -101,8 +101,9 @@ class RackHTTPService(TimerService):
101 controller is connected to."""101 controller is connected to."""
102 # Filter the connects by region.102 # Filter the connects by region.
103 conn_per_region = defaultdict(set)103 conn_per_region = defaultdict(set)
104 for eventloop, connection in self._rpc_service.connections.items():104 for eventloop, connection_set in self._rpc_service.connections.items():
105 conn_per_region[eventloop.split(":")[0]].add(connection)105 for connection in connection_set:
106 conn_per_region[eventloop.split(":")[0]].add(connection)
106 for _, connections in conn_per_region.items():107 for _, connections in conn_per_region.items():
107 # Sort the connections so the same IP is always picked per108 # Sort the connections so the same IP is always picked per
108 # region controller. This ensures that the HTTP configuration109 # region controller. This ensures that the HTTP configuration
diff --git a/src/provisioningserver/rackdservices/tests/test_external.py b/src/provisioningserver/rackdservices/tests/test_external.py
index ad214a1..0cb8601 100644
--- a/src/provisioningserver/rackdservices/tests/test_external.py
+++ b/src/provisioningserver/rackdservices/tests/test_external.py
@@ -430,7 +430,8 @@ class TestRackDNS(MAASTestCase):
430 return frozenset(430 return frozenset(
431 {431 {
432 client.address[0]432 client.address[0]
433 for _, client in rpc_service.connections.items()433 for _, clients in rpc_service.connections.items()
434 for client in clients
434 }435 }
435 )436 )
436437
@@ -609,7 +610,7 @@ class TestRackDNS(MAASTestCase):
609 ip = factory.make_ip_address()610 ip = factory.make_ip_address()
610 mock_conn = Mock()611 mock_conn = Mock()
611 mock_conn.address = (ip, random.randint(5240, 5250))612 mock_conn.address = (ip, random.randint(5240, 5250))
612 mock_rpc.connections[eventloop] = mock_conn613 mock_rpc.connections[eventloop] = {mock_conn}
613614
614 dns = external.RackDNS()615 dns = external.RackDNS()
615 region_ips = list(dns._genRegionIps(mock_rpc.connections))616 region_ips = list(dns._genRegionIps(mock_rpc.connections))
@@ -626,7 +627,7 @@ class TestRackDNS(MAASTestCase):
626 ip = factory.make_ip_address()627 ip = factory.make_ip_address()
627 mock_conn = Mock()628 mock_conn = Mock()
628 mock_conn.address = (ip, random.randint(5240, 5250))629 mock_conn.address = (ip, random.randint(5240, 5250))
629 mock_rpc.connections[eventloop] = mock_conn630 mock_rpc.connections[eventloop] = {mock_conn}
630631
631 dns = external.RackDNS()632 dns = external.RackDNS()
632 region_ips = frozenset(dns._genRegionIps(mock_rpc.connections))633 region_ips = frozenset(dns._genRegionIps(mock_rpc.connections))
@@ -659,7 +660,8 @@ class TestRackProxy(MAASTestCase):
659 return frozenset(660 return frozenset(
660 {661 {
661 client.address[0]662 client.address[0]
662 for _, client in rpc_service.connections.items()663 for _, clients in rpc_service.connections.items()
664 for client in clients
663 }665 }
664 )666 )
665667
@@ -824,7 +826,8 @@ class TestRackSyslog(MAASTestCase):
824 return frozenset(826 return frozenset(
825 {827 {
826 (eventloop, client.address[0])828 (eventloop, client.address[0])
827 for eventloop, client in rpc_service.connections.items()829 for eventloop, clients in rpc_service.connections.items()
830 for client in clients
828 }831 }
829 )832 )
830833
diff --git a/src/provisioningserver/rackdservices/tests/test_http.py b/src/provisioningserver/rackdservices/tests/test_http.py
index bc43c66..43cb495 100644
--- a/src/provisioningserver/rackdservices/tests/test_http.py
+++ b/src/provisioningserver/rackdservices/tests/test_http.py
@@ -92,7 +92,8 @@ class TestRackHTTPService(MAASTestCase):
92 return frozenset(92 return frozenset(
93 {93 {
94 client.address[0]94 client.address[0]
95 for _, client in rpc_service.connections.items()95 for _, clients in rpc_service.connections.items()
96 for client in clients
96 }97 }
97 )98 )
9899
@@ -208,7 +209,7 @@ class TestRackHTTPService(MAASTestCase):
208 ip = factory.make_ip_address()209 ip = factory.make_ip_address()
209 mock_conn = Mock()210 mock_conn = Mock()
210 mock_conn.address = (ip, random.randint(5240, 5250))211 mock_conn.address = (ip, random.randint(5240, 5250))
211 mock_rpc.connections[eventloop] = mock_conn212 mock_rpc.connections[eventloop] = {mock_conn}
212213
213 service = http.RackHTTPService(self.make_dir(), mock_rpc, reactor)214 service = http.RackHTTPService(self.make_dir(), mock_rpc, reactor)
214 region_ips = list(service._genRegionIps())215 region_ips = list(service._genRegionIps())
@@ -225,7 +226,7 @@ class TestRackHTTPService(MAASTestCase):
225 ip = factory.make_ip_address()226 ip = factory.make_ip_address()
226 mock_conn = Mock()227 mock_conn = Mock()
227 mock_conn.address = (ip, random.randint(5240, 5250))228 mock_conn.address = (ip, random.randint(5240, 5250))
228 mock_rpc.connections[eventloop] = mock_conn229 mock_rpc.connections[eventloop] = {mock_conn}
229230
230 service = http.RackHTTPService(self.make_dir(), mock_rpc, reactor)231 service = http.RackHTTPService(self.make_dir(), mock_rpc, reactor)
231 region_ips = frozenset(service._genRegionIps())232 region_ips = frozenset(service._genRegionIps())
@@ -244,7 +245,7 @@ class TestRackHTTPService(MAASTestCase):
244 ip_addresses.add("[%s]" % ip)245 ip_addresses.add("[%s]" % ip)
245 mock_conn = Mock()246 mock_conn = Mock()
246 mock_conn.address = (ip, random.randint(5240, 5250))247 mock_conn.address = (ip, random.randint(5240, 5250))
247 mock_rpc.connections[eventloop] = mock_conn248 mock_rpc.connections[eventloop] = {mock_conn}
248249
249 service = http.RackHTTPService(self.make_dir(), mock_rpc, reactor)250 service = http.RackHTTPService(self.make_dir(), mock_rpc, reactor)
250 region_ips = set(service._genRegionIps())251 region_ips = set(service._genRegionIps())
diff --git a/src/provisioningserver/rpc/clusterservice.py b/src/provisioningserver/rpc/clusterservice.py
index c92d48a..a7205db 100644
--- a/src/provisioningserver/rpc/clusterservice.py
+++ b/src/provisioningserver/rpc/clusterservice.py
@@ -9,7 +9,6 @@ import json
9from operator import itemgetter9from operator import itemgetter
10import os10import os
11from os import urandom11from os import urandom
12import random
13from socket import AF_INET, AF_INET6, gethostname12from socket import AF_INET, AF_INET6, gethostname
14import sys13import sys
15from urllib.parse import urlparse14from urllib.parse import urlparse
@@ -24,7 +23,6 @@ from twisted.internet.defer import (
24 maybeDeferred,23 maybeDeferred,
25 returnValue,24 returnValue,
26)25)
27from twisted.internet.endpoints import connectProtocol, TCP6ClientEndpoint
28from twisted.internet.error import ConnectError, ConnectionClosed, ProcessDone26from twisted.internet.error import ConnectError, ConnectionClosed, ProcessDone
29from twisted.internet.threads import deferToThread27from twisted.internet.threads import deferToThread
30from twisted.protocols import amp28from twisted.protocols import amp
@@ -67,6 +65,7 @@ from provisioningserver.rpc.boot_images import (
67 list_boot_images,65 list_boot_images,
68)66)
69from provisioningserver.rpc.common import Ping, RPCProtocol67from provisioningserver.rpc.common import Ping, RPCProtocol
68from provisioningserver.rpc.connectionpool import ConnectionPool
70from provisioningserver.rpc.exceptions import CannotConfigureDHCP69from provisioningserver.rpc.exceptions import CannotConfigureDHCP
71from provisioningserver.rpc.interfaces import IConnectionToRegion70from provisioningserver.rpc.interfaces import IConnectionToRegion
72from provisioningserver.rpc.osystems import (71from provisioningserver.rpc.osystems import (
@@ -999,6 +998,7 @@ class ClusterClient(Cluster):
999 # Events for this protocol's life-cycle.998 # Events for this protocol's life-cycle.
1000 self.authenticated = DeferredValue()999 self.authenticated = DeferredValue()
1001 self.ready = DeferredValue()1000 self.ready = DeferredValue()
1001 self.in_use = False
1002 self.localIdent = None1002 self.localIdent = None
10031003
1004 @property1004 @property
@@ -1201,13 +1201,15 @@ class ClusterClientService(TimerService):
12011201
1202 time_started = None1202 time_started = None
12031203
1204 def __init__(self, reactor):1204 def __init__(self, reactor, max_idle_conns=1, max_conns=1, keepalive=1000):
1205 super().__init__(self._calculate_interval(None, None), self._tryUpdate)1205 super().__init__(self._calculate_interval(None, None), self._tryUpdate)
1206 self.connections = {}
1207 self.try_connections = {}
1208 self._previous_work = (None, None)1206 self._previous_work = (None, None)
1209 self.clock = reactor1207 self.clock = reactor
12101208
1209 self.connections = ConnectionPool(
1210 reactor, self, max_idle_conns, max_conns, keepalive
1211 )
1212
1211 # Stored the URL used to connect to the region controller. This will be1213 # Stored the URL used to connect to the region controller. This will be
1212 # the URL that was used to get the eventloops.1214 # the URL that was used to get the eventloops.
1213 self.maas_url = None1215 self.maas_url = None
@@ -1236,11 +1238,19 @@ class ClusterClientService(TimerService):
1236 :raises: :py:class:`~.exceptions.NoConnectionsAvailable` when1238 :raises: :py:class:`~.exceptions.NoConnectionsAvailable` when
1237 there are no open connections to a region controller.1239 there are no open connections to a region controller.
1238 """1240 """
1239 conns = list(self.connections.values())1241 if len(self.connections) == 0:
1240 if len(conns) == 0:
1241 raise exceptions.NoConnectionsAvailable()1242 raise exceptions.NoConnectionsAvailable()
1242 else:1243 else:
1243 return common.Client(random.choice(conns))1244 try:
1245 return common.Client(
1246 self.connections.get_random_free_connection()
1247 )
1248 except exceptions.AllConnectionsBusy as e:
1249 for endpoint_conns in self.connections.values():
1250 if len(endpoint_conns) < self.connections._max_connections:
1251 raise e
1252 # return a busy connection, assume it will free up or timeout
1253 return common.Client(self.connections.get_random_connection())
12441254
1245 @deferred1255 @deferred
1246 def getClientNow(self):1256 def getClientNow(self):
@@ -1259,10 +1269,17 @@ class ClusterClientService(TimerService):
1259 return self.getClient()1269 return self.getClient()
1260 except exceptions.NoConnectionsAvailable:1270 except exceptions.NoConnectionsAvailable:
1261 return self._tryUpdate().addCallback(call, self.getClient)1271 return self._tryUpdate().addCallback(call, self.getClient)
1272 except exceptions.AllConnectionsBusy:
1273 return self.connections.scale_up_connections().addCallback(
1274 call, self.getClient
1275 )
12621276
1263 def getAllClients(self):1277 def getAllClients(self):
1264 """Return a list of all connected :class:`common.Client`s."""1278 """Return a list of all connected :class:`common.Client`s."""
1265 return [common.Client(conn) for conn in self.connections.values()]1279 return [
1280 common.Client(conn)
1281 for conn in self.connections.get_all_connections()
1282 ]
12661283
1267 def _tryUpdate(self):1284 def _tryUpdate(self):
1268 """Attempt to refresh outgoing connections.1285 """Attempt to refresh outgoing connections.
@@ -1391,7 +1408,9 @@ class ClusterClientService(TimerService):
1391 """Update the saved RPC info state."""1408 """Update the saved RPC info state."""
1392 # Build a list of addresses based on the current connections.1409 # Build a list of addresses based on the current connections.
1393 connected_addr = {1410 connected_addr = {
1394 conn.address[0] for _, conn in self.connections.items()1411 conn.address[0]
1412 for _, conns in self.connections.items()
1413 for conn in conns
1395 }1414 }
1396 if (1415 if (
1397 self._rpc_info_state is None1416 self._rpc_info_state is None
@@ -1467,8 +1486,8 @@ class ClusterClientService(TimerService):
1467 # Gather the list of successful responses.1486 # Gather the list of successful responses.
1468 successful = []1487 successful = []
1469 errors = []1488 errors = []
1470 for sucess, result in results:1489 for success, result in results:
1471 if sucess:1490 if success:
1472 body, orig_url = result1491 body, orig_url = result
1473 eventloops = body.get("eventloops")1492 eventloops = body.get("eventloops")
1474 if eventloops is not None:1493 if eventloops is not None:
@@ -1656,12 +1675,15 @@ class ClusterClientService(TimerService):
1656 "Dropping connections to event-loops: %s"1675 "Dropping connections to event-loops: %s"
1657 % (", ".join(drop.keys()))1676 % (", ".join(drop.keys()))
1658 )1677 )
1678 drop_defers = []
1679 for eventloop, connections in drop.items():
1680 for connection in connections:
1681 drop_defers.append(
1682 maybeDeferred(self.connections.disconnect, connection)
1683 )
1684 self.connections.remove_connection(eventloop, connection)
1659 yield DeferredList(1685 yield DeferredList(
1660 [1686 drop_defers,
1661 maybeDeferred(self._drop_connection, connection)
1662 for eventloop, connections in drop.items()
1663 for connection in connections
1664 ],
1665 consumeErrors=True,1687 consumeErrors=True,
1666 )1688 )
16671689
@@ -1692,11 +1714,12 @@ class ClusterClientService(TimerService):
1692 # between consenting adults.1714 # between consenting adults.
1693 for eventloop, addresses in eventloops.items():1715 for eventloop, addresses in eventloops.items():
1694 if eventloop in self.connections:1716 if eventloop in self.connections:
1695 connection = self.connections[eventloop]1717 connection_list = self.connections[eventloop]
1696 if connection.address not in addresses:1718 for connection in connection_list:
1697 drop[eventloop] = [connection]1719 if connection.address not in addresses:
1698 if eventloop in self.try_connections:1720 drop[eventloop] = [connection]
1699 connection = self.try_connections[eventloop]1721 if self.connections.is_staged(eventloop):
1722 connection = self.connections.get_staged_connection(eventloop)
1700 if connection.address not in addresses:1723 if connection.address not in addresses:
1701 drop[eventloop] = [connection]1724 drop[eventloop] = [connection]
17021725
@@ -1705,7 +1728,7 @@ class ClusterClientService(TimerService):
1705 for eventloop, addresses in eventloops.items():1728 for eventloop, addresses in eventloops.items():
1706 if (1729 if (
1707 eventloop not in self.connections1730 eventloop not in self.connections
1708 and eventloop not in self.try_connections1731 and not self.connections.is_staged(eventloop)
1709 ) or eventloop in drop:1732 ) or eventloop in drop:
1710 connect[eventloop] = addresses1733 connect[eventloop] = addresses
17111734
@@ -1714,13 +1737,13 @@ class ClusterClientService(TimerService):
1714 # the process in which the event-loop is no longer running, but1737 # the process in which the event-loop is no longer running, but
1715 # it could be an indicator of a heavily loaded machine, or a1738 # it could be an indicator of a heavily loaded machine, or a
1716 # fault. In any case, it seems to make sense to disconnect.1739 # fault. In any case, it seems to make sense to disconnect.
1717 for eventloop in self.connections:1740 for eventloop in self.connections.keys():
1718 if eventloop not in eventloops:1741 if eventloop not in eventloops:
1719 connection = self.connections[eventloop]1742 connection_list = self.connections[eventloop]
1720 drop[eventloop] = [connection]1743 drop[eventloop] = connection_list
1721 for eventloop in self.try_connections:1744 for eventloop in self.connections.get_staged_connections():
1722 if eventloop not in eventloops:1745 if eventloop not in eventloops:
1723 connection = self.try_connections[eventloop]1746 connection = self.connections.get_staged_connection(eventloop)
1724 drop[eventloop] = [connection]1747 drop[eventloop] = [connection]
17251748
1726 return drop, connect1749 return drop, connect
@@ -1730,7 +1753,7 @@ class ClusterClientService(TimerService):
1730 """Connect to `eventloop` using all `addresses`."""1753 """Connect to `eventloop` using all `addresses`."""
1731 for address in addresses:1754 for address in addresses:
1732 try:1755 try:
1733 connection = yield self._make_connection(eventloop, address)1756 connection = yield self.connections.connect(eventloop, address)
1734 except ConnectError as error:1757 except ConnectError as error:
1735 host, port = address1758 host, port = address
1736 log.msg(1759 log.msg(
@@ -1747,29 +1770,17 @@ class ClusterClientService(TimerService):
1747 ),1770 ),
1748 )1771 )
1749 else:1772 else:
1750 self.try_connections[eventloop] = connection1773 self.connections.stage_connection(eventloop, connection)
1751 break1774 break
17521775
1753 def _make_connection(self, eventloop, address):1776 @inlineCallbacks
1754 """Connect to `eventloop` at `address`."""
1755 # Force everything to use AF_INET6 sockets.
1756 endpoint = TCP6ClientEndpoint(self.clock, *address)
1757 protocol = ClusterClient(address, eventloop, self)
1758 return connectProtocol(endpoint, protocol)
1759
1760 def _drop_connection(self, connection):
1761 """Drop the given `connection`."""
1762 return connection.transport.loseConnection()
1763
1764 def add_connection(self, eventloop, connection):1777 def add_connection(self, eventloop, connection):
1765 """Add the connection to the tracked connections.1778 """Add the connection to the tracked connections.
17661779
1767 Update the saved RPC info state information based on the new1780 Update the saved RPC info state information based on the new
1768 connection.1781 connection.
1769 """1782 """
1770 if eventloop in self.try_connections:1783 yield self.connections.add_connection(eventloop, connection)
1771 del self.try_connections[eventloop]
1772 self.connections[eventloop] = connection
1773 self._update_saved_rpc_info_state()1784 self._update_saved_rpc_info_state()
17741785
1775 def remove_connection(self, eventloop, connection):1786 def remove_connection(self, eventloop, connection):
@@ -1778,12 +1789,7 @@ class ClusterClientService(TimerService):
1778 If this is the last connection that was keeping rackd connected to1789 If this is the last connection that was keeping rackd connected to
1779 a regiond then dhcpd and dhcpd6 services will be turned off.1790 a regiond then dhcpd and dhcpd6 services will be turned off.
1780 """1791 """
1781 if eventloop in self.try_connections:1792 self.connections.remove_connection(eventloop, connection)
1782 if self.try_connections[eventloop] is connection:
1783 del self.try_connections[eventloop]
1784 if eventloop in self.connections:
1785 if self.connections[eventloop] is connection:
1786 del self.connections[eventloop]
1787 # Disable DHCP when no connections to a region controller.1793 # Disable DHCP when no connections to a region controller.
1788 if len(self.connections) == 0:1794 if len(self.connections) == 0:
1789 stopping_services = []1795 stopping_services = []
diff --git a/src/provisioningserver/rpc/common.py b/src/provisioningserver/rpc/common.py
index 5d67bba..40e091f 100644
--- a/src/provisioningserver/rpc/common.py
+++ b/src/provisioningserver/rpc/common.py
@@ -14,7 +14,11 @@ from twisted.python.failure import Failure
14from provisioningserver.logger import LegacyLogger14from provisioningserver.logger import LegacyLogger
15from provisioningserver.prometheus.metrics import PROMETHEUS_METRICS15from provisioningserver.prometheus.metrics import PROMETHEUS_METRICS
16from provisioningserver.rpc.interfaces import IConnection, IConnectionToRegion16from provisioningserver.rpc.interfaces import IConnection, IConnectionToRegion
17from provisioningserver.utils.twisted import asynchronous, deferWithTimeout17from provisioningserver.utils.twisted import (
18 asynchronous,
19 callOut,
20 deferWithTimeout,
21)
1822
19log = LegacyLogger()23log = LegacyLogger()
2024
@@ -156,6 +160,11 @@ class Client:
156 :return: A deferred result. Call its `wait` method (with a timeout160 :return: A deferred result. Call its `wait` method (with a timeout
157 in seconds) to block on the call's completion.161 in seconds) to block on the call's completion.
158 """162 """
163 self._conn.in_use = True
164
165 def _free_conn():
166 self._conn.in_use = False
167
159 if len(args) != 0:168 if len(args) != 0:
160 receiver_name = "{}.{}".format(169 receiver_name = "{}.{}".format(
161 self.__module__,170 self.__module__,
@@ -171,11 +180,19 @@ class Client:
171 if timeout is undefined:180 if timeout is undefined:
172 timeout = 120 # 2 minutes181 timeout = 120 # 2 minutes
173 if timeout is None or timeout <= 0:182 if timeout is None or timeout <= 0:
174 return self._conn.callRemote(cmd, **kwargs)183 d = self._conn.callRemote(cmd, **kwargs)
184 if isinstance(d, Deferred):
185 d.addBoth(lambda x: callOut(x, _free_conn))
186 else:
187 _free_conn()
188 return d
175 else:189 else:
176 return deferWithTimeout(190 d = deferWithTimeout(timeout, self._conn.callRemote, cmd, **kwargs)
177 timeout, self._conn.callRemote, cmd, **kwargs191 if isinstance(d, Deferred):
178 )192 d.addBoth(lambda x: callOut(x, _free_conn))
193 else:
194 _free_conn()
195 return d
179196
180 @asynchronous197 @asynchronous
181 def getHostCertificate(self):198 def getHostCertificate(self):
diff --git a/src/provisioningserver/rpc/connectionpool.py b/src/provisioningserver/rpc/connectionpool.py
182new file mode 100644199new file mode 100644
index 0000000..8023f80
--- /dev/null
+++ b/src/provisioningserver/rpc/connectionpool.py
@@ -0,0 +1,163 @@
1# Copyright 2022 Canonical Ltd. This software is licensed under the
2# GNU Affero General Public License version 3 (see the file LICENSE).
3
4""" RPC Connection Pooling and Lifecycle """
5
6import random
7
8from twisted.internet.defer import inlineCallbacks
9from twisted.internet.endpoints import connectProtocol, TCP6ClientEndpoint
10
11from provisioningserver.rpc import exceptions
12
13
class ConnectionPool:
    """Track RPC connections to region event-loops and their lifecycle.

    ``connections`` maps an event-loop identifier to the list of established
    connections for it; ``try_connections`` maps an event-loop to a single
    connection that is still being set up ("staged"). The pool also exposes
    a dict-like interface over ``connections``.

    :param reactor: Used to open connections and to schedule reaping of
        extra connections (``callLater``).
    :param service: Passed to every new ``ClusterClient`` protocol.
    :param max_idle_conns: Connections to always open per event-loop.
    :param max_conns: Upper bound of connections per event-loop under load.
    :param keepalive: Milliseconds to wait before checking whether an extra
        (scaled-up) connection is idle and can be dropped.
    """

    def __init__(
        self, reactor, service, max_idle_conns=1, max_conns=1, keepalive=1000
    ):
        # The number of connections to always allocate per eventloop.
        self._max_idle_connections = max_idle_conns
        # The maximum number of connections to allocate per eventloop while
        # under load.
        self._max_connections = max_conns
        # The duration in milliseconds to keep scaled-up connections alive.
        # The reactor's callLater() takes seconds; see _schedule_reap().
        self._keepalive = keepalive

        self.connections = {}
        self.try_connections = {}
        self.clock = reactor
        self._service = service

    def __setitem__(self, key, item):
        self.connections[key] = item

    def __getitem__(self, key):
        # Unlike a dict, a missing event-loop yields None, not KeyError.
        return self.connections.get(key)

    def __repr__(self):
        return repr(self.connections)

    def __len__(self):
        # Total number of connections across all event-loops, not the
        # number of event-loops.
        return sum(len(conns) for conns in self.connections.values())

    def __delitem__(self, key):
        del self.connections[key]

    def __contains__(self, item):
        return item in self.connections

    def __eq__(self, value):
        return self.connections.__eq__(value)

    def keys(self):
        return self.connections.keys()

    def values(self):
        return self.connections.values()

    def items(self):
        return self.connections.items()

    def _schedule_reap(self, eventloop, conn):
        """Check `conn` for reaping once the keepalive period has elapsed."""
        # `_keepalive` is in milliseconds but callLater() expects seconds.
        return self.clock.callLater(
            self._keepalive / 1000,
            self._reap_extra_connection,
            eventloop,
            conn,
        )

    def _reap_extra_connection(self, eventloop, conn):
        """Drop `conn` if it is idle, otherwise check again later."""
        if not conn.in_use:
            self.disconnect(conn)
            return self.remove_connection(eventloop, conn)
        return self._schedule_reap(eventloop, conn)

    def is_staged(self, eventloop):
        """Return whether a connection to `eventloop` is being set up."""
        return eventloop in self.try_connections

    def get_staged_connection(self, eventloop):
        """Return the staged connection for `eventloop`, or None."""
        return self.try_connections.get(eventloop)

    def get_staged_connections(self):
        """Return the (live) mapping of event-loop to staged connection."""
        return self.try_connections

    def stage_connection(self, eventloop, connection):
        """Record `connection` as the in-progress connection to `eventloop`."""
        self.try_connections[eventloop] = connection

    @inlineCallbacks
    def scale_up_connections(self):
        """Open one extra connection to the first event-loop with room.

        The new connection is reaped once it is found idle after the
        keepalive period.

        :raises MaxConnectionsOpen: if every event-loop already has
            `max_conns` connections.
        """
        for ev, ev_conns in self.connections.items():
            # Pick the first non-empty group with room for another conn.
            if ev_conns and len(ev_conns) < self._max_connections:
                conn_to_clone = random.choice(list(ev_conns))
                conn = yield self.connect(ev, conn_to_clone.address)
                if ev not in self.connections:
                    # The event-loop was dropped while we were connecting;
                    # don't resurrect it with an orphan connection.
                    self.disconnect(conn)
                    return
                self.connections[ev].append(conn)
                self._schedule_reap(ev, conn)
                return
        raise exceptions.MaxConnectionsOpen()

    def get_connection(self, eventloop):
        """Return a random connection to `eventloop`."""
        return random.choice(self.connections[eventloop])

    def get_random_connection(self):
        """Return a random connection, busy or not."""
        return random.choice(self.get_all_connections())

    def get_random_free_connection(self):
        """Return a random connection that is not in use.

        :raises AllConnectionsBusy: if every connection is in use; the
            caller should create a new connection.
        """
        free_conns = self.get_all_free_connections()
        if len(free_conns) == 0:
            raise exceptions.AllConnectionsBusy()
        return random.choice(free_conns)

    def get_all_connections(self):
        """Return a flat list of every established connection."""
        return [
            conn
            for conn_list in self.connections.values()
            for conn in conn_list
        ]

    def get_all_free_connections(self):
        """Return a flat list of established connections not in use."""
        return [
            conn
            for conn_list in self.connections.values()
            for conn in conn_list
            if not conn.in_use
        ]

    @inlineCallbacks
    def connect(self, eventloop, address):
        """Open a new ``ClusterClient`` connection to `eventloop`."""
        # Imported here to avoid a circular import with clusterservice.
        from provisioningserver.rpc.clusterservice import ClusterClient

        # Force everything to use AF_INET6 sockets.
        endpoint = TCP6ClientEndpoint(self.clock, *address)
        protocol = ClusterClient(address, eventloop, self._service)
        conn = yield connectProtocol(endpoint, protocol)
        return conn

    def disconnect(self, connection):
        """Close `connection`'s transport."""
        return connection.transport.loseConnection()

    @inlineCallbacks
    def add_connection(self, eventloop, connection):
        """Promote `connection` to an established connection of `eventloop`.

        Any staged connection for `eventloop` is forgotten, then
        ``max_idle_conns - 1`` further connections to the same address are
        opened so the event-loop has its nominal number of connections.
        """
        self.try_connections.pop(eventloop, None)
        self.connections.setdefault(eventloop, []).append(connection)

        # Clone the connection to reach the nominal number of idle
        # connections. NOTE(review): confirm that these clones do not
        # re-enter add_connection themselves once their handshake completes.
        for _ in range(self._max_idle_connections - 1):
            extra_conn = yield self.connect(eventloop, connection.address)
            if eventloop not in self.connections:
                # The event-loop was dropped while we were connecting;
                # don't resurrect it with an orphan connection.
                self.disconnect(extra_conn)
                return
            self.connections[eventloop].append(extra_conn)

    def remove_connection(self, eventloop, connection):
        """Forget `connection`, whether it is staged or established.

        The event-loop's entry is removed once its last connection is gone.
        """
        if (
            self.is_staged(eventloop)
            and self.try_connections[eventloop] is connection
        ):
            del self.try_connections[eventloop]
        conns = self.connections.get(eventloop, [])
        if connection in conns:
            conns.remove(connection)
            if len(conns) == 0:
                del self.connections[eventloop]
diff --git a/src/provisioningserver/rpc/exceptions.py b/src/provisioningserver/rpc/exceptions.py
index 7ee4f3f..136e471 100644
--- a/src/provisioningserver/rpc/exceptions.py
+++ b/src/provisioningserver/rpc/exceptions.py
@@ -12,6 +12,14 @@ class NoConnectionsAvailable(Exception):
12 self.uuid = uuid12 self.uuid = uuid
1313
1414
class AllConnectionsBusy(Exception):
    """No connection in the current pool is free; all are in use."""
17
18
class MaxConnectionsOpen(Exception):
    """The maximum number of connections is already open."""
21
22
15class NoSuchEventType(Exception):23class NoSuchEventType(Exception):
16 """The specified event type was not found."""24 """The specified event type was not found."""
1725
diff --git a/src/provisioningserver/rpc/testing/__init__.py b/src/provisioningserver/rpc/testing/__init__.py
index ee4a9e2..1b2f94f 100644
--- a/src/provisioningserver/rpc/testing/__init__.py
+++ b/src/provisioningserver/rpc/testing/__init__.py
@@ -262,7 +262,8 @@ class MockClusterToRegionRPCFixtureBase(fixtures.Fixture, metaclass=ABCMeta):
262 {262 {
263 "eventloops": {263 "eventloops": {
264 eventloop: [client.address]264 eventloop: [client.address]
265 for eventloop, client in connections265 for eventloop, clients in connections
266 for client in clients
266 }267 }
267 },268 },
268 orig_url,269 orig_url,
diff --git a/src/provisioningserver/rpc/testing/doubles.py b/src/provisioningserver/rpc/testing/doubles.py
index cb9f27f..0785859 100644
--- a/src/provisioningserver/rpc/testing/doubles.py
+++ b/src/provisioningserver/rpc/testing/doubles.py
@@ -30,6 +30,7 @@ class FakeConnection:
30 ident = attr.ib(default=sentinel.ident)30 ident = attr.ib(default=sentinel.ident)
31 hostCertificate = attr.ib(default=sentinel.hostCertificate)31 hostCertificate = attr.ib(default=sentinel.hostCertificate)
32 peerCertificate = attr.ib(default=sentinel.peerCertificate)32 peerCertificate = attr.ib(default=sentinel.peerCertificate)
33 in_use = attr.ib(default=False)
3334
34 def callRemote(self, cmd, **arguments):35 def callRemote(self, cmd, **arguments):
35 return succeed(sentinel.response)36 return succeed(sentinel.response)
@@ -48,6 +49,7 @@ class FakeConnectionToRegion:
48 address = attr.ib(default=(sentinel.host, sentinel.port))49 address = attr.ib(default=(sentinel.host, sentinel.port))
49 hostCertificate = attr.ib(default=sentinel.hostCertificate)50 hostCertificate = attr.ib(default=sentinel.hostCertificate)
50 peerCertificate = attr.ib(default=sentinel.peerCertificate)51 peerCertificate = attr.ib(default=sentinel.peerCertificate)
52 in_use = attr.ib(default=False)
5153
52 def callRemote(self, cmd, **arguments):54 def callRemote(self, cmd, **arguments):
53 return succeed(sentinel.response)55 return succeed(sentinel.response)
@@ -56,6 +58,22 @@ class FakeConnectionToRegion:
56verifyObject(IConnectionToRegion, FakeConnectionToRegion())58verifyObject(IConnectionToRegion, FakeConnectionToRegion())
5759
5860
@attr.s(eq=False, order=False)
@implementer(IConnectionToRegion)
class FakeBusyConnectionToRegion:
    """A fake `IConnectionToRegion` that appears busy.

    Mirrors `FakeConnectionToRegion`, except that `in_use` defaults to
    True. A connection pool that filters on `in_use` therefore never
    treats it as free.
    """

    ident = attr.ib(default=sentinel.ident)
    localIdent = attr.ib(default=sentinel.localIdent)
    address = attr.ib(default=(sentinel.host, sentinel.port))
    hostCertificate = attr.ib(default=sentinel.hostCertificate)
    peerCertificate = attr.ib(default=sentinel.peerCertificate)
    in_use = attr.ib(default=True)

    def callRemote(self, cmd, **arguments):
        return succeed(sentinel.response)


# Same interface conformance check as the other connection fakes.
verifyObject(IConnectionToRegion, FakeBusyConnectionToRegion())
75
76
59class StubOS(OperatingSystem):77class StubOS(OperatingSystem):
60 """An :py:class:`OperatingSystem` subclass that has canned answers.78 """An :py:class:`OperatingSystem` subclass that has canned answers.
6179
diff --git a/src/provisioningserver/rpc/tests/test_clusterservice.py b/src/provisioningserver/rpc/tests/test_clusterservice.py
index b50311d..6f3e4f9 100644
--- a/src/provisioningserver/rpc/tests/test_clusterservice.py
+++ b/src/provisioningserver/rpc/tests/test_clusterservice.py
@@ -23,7 +23,6 @@ from testtools.matchers import (
23 Is,23 Is,
24 IsInstance,24 IsInstance,
25 KeysEqual,25 KeysEqual,
26 MatchesAll,
27 MatchesDict,26 MatchesDict,
28 MatchesListwise,27 MatchesListwise,
29 MatchesStructure,28 MatchesStructure,
@@ -32,7 +31,6 @@ from twisted import web
32from twisted.application.internet import TimerService31from twisted.application.internet import TimerService
33from twisted.internet import error, reactor32from twisted.internet import error, reactor
34from twisted.internet.defer import Deferred, fail, inlineCallbacks, succeed33from twisted.internet.defer import Deferred, fail, inlineCallbacks, succeed
35from twisted.internet.endpoints import TCP6ClientEndpoint
36from twisted.internet.error import ConnectionClosed34from twisted.internet.error import ConnectionClosed
37from twisted.internet.task import Clock35from twisted.internet.task import Clock
38from twisted.internet.testing import StringTransportWithDisconnection36from twisted.internet.testing import StringTransportWithDisconnection
@@ -117,7 +115,11 @@ from provisioningserver.rpc.testing import (
117 call_responder,115 call_responder,
118 MockLiveClusterToRegionRPCFixture,116 MockLiveClusterToRegionRPCFixture,
119)117)
120from provisioningserver.rpc.testing.doubles import DummyConnection, StubOS118from provisioningserver.rpc.testing.doubles import (
119 FakeBusyConnectionToRegion,
120 FakeConnection,
121 StubOS,
122)
121from provisioningserver.security import set_shared_secret_on_filesystem123from provisioningserver.security import set_shared_secret_on_filesystem
122from provisioningserver.service_monitor import service_monitor124from provisioningserver.service_monitor import service_monitor
123from provisioningserver.testing.config import ClusterConfigurationFixture125from provisioningserver.testing.config import ClusterConfigurationFixture
@@ -444,8 +446,10 @@ class TestClusterProtocol_DescribePowerTypes(MAASTestCase):
444 )446 )
445447
446448
447def make_inert_client_service():449def make_inert_client_service(max_idle_conns=1, max_conns=1, keepalive=1):
448 service = ClusterClientService(Clock())450 service = ClusterClientService(
451 Clock(), max_idle_conns, max_conns, keepalive
452 )
449 # ClusterClientService's superclass, TimerService, creates a453 # ClusterClientService's superclass, TimerService, creates a
450 # LoopingCall with now=True. We neuter it here to allow454 # LoopingCall with now=True. We neuter it here to allow
451 # observation of the behaviour of _update_interval() for455 # observation of the behaviour of _update_interval() for
@@ -498,11 +502,11 @@ class TestClusterClientService(MAASTestCase):
498 )502 )
499503
500 # Fake some connections.504 # Fake some connections.
501 service.connections = {505 service.connections.connections = {
502 ipv4client.eventloop: ipv4client,506 ipv4client.eventloop: [ipv4client],
503 ipv6client.eventloop: ipv6client,507 ipv6client.eventloop: [ipv6client],
504 ipv6mapped.eventloop: ipv6mapped,508 ipv6mapped.eventloop: [ipv6mapped],
505 hostclient.eventloop: hostclient,509 hostclient.eventloop: [hostclient],
506 }510 }
507511
508 # Update the RPC state to the filesystem and info cache.512 # Update the RPC state to the filesystem and info cache.
@@ -515,7 +519,8 @@ class TestClusterClientService(MAASTestCase):
515 Equals(519 Equals(
516 {520 {
517 client.address[0]521 client.address[0]
518 for _, client in service.connections.items()522 for _, clients in service.connections.items()
523 for client in clients
519 }524 }
520 ),525 ),
521 )526 )
@@ -999,9 +1004,9 @@ class TestClusterClientService(MAASTestCase):
999 def test_update_connections_initially(self):1004 def test_update_connections_initially(self):
1000 service = ClusterClientService(Clock())1005 service = ClusterClientService(Clock())
1001 mock_client = Mock()1006 mock_client = Mock()
1002 _make_connection = self.patch(service, "_make_connection")1007 _make_connection = self.patch(service.connections, "connect")
1003 _make_connection.side_effect = lambda *args: succeed(mock_client)1008 _make_connection.side_effect = lambda *args: succeed(mock_client)
1004 _drop_connection = self.patch(service, "_drop_connection")1009 _drop_connection = self.patch(service.connections, "disconnect")
10051010
1006 info = json.loads(self.example_rpc_info_view_response.decode("ascii"))1011 info = json.loads(self.example_rpc_info_view_response.decode("ascii"))
1007 yield service._update_connections(info["eventloops"])1012 yield service._update_connections(info["eventloops"])
@@ -1020,7 +1025,7 @@ class TestClusterClientService(MAASTestCase):
1020 "host1:pid=2002": mock_client,1025 "host1:pid=2002": mock_client,
1021 "host2:pid=3003": mock_client,1026 "host2:pid=3003": mock_client,
1022 },1027 },
1023 service.try_connections,1028 service.connections.try_connections,
1024 )1029 )
10251030
1026 self.assertEqual([], _drop_connection.mock_calls)1031 self.assertEqual([], _drop_connection.mock_calls)
@@ -1038,7 +1043,7 @@ class TestClusterClientService(MAASTestCase):
1038 for address in addresses:1043 for address in addresses:
1039 client = Mock()1044 client = Mock()
1040 client.address = address1045 client.address = address
1041 service.connections[eventloop] = client1046 service.connections.connections[eventloop] = [client]
10421047
1043 logger = self.useFixture(TwistedLoggerFixture())1048 logger = self.useFixture(TwistedLoggerFixture())
10441049
@@ -1055,7 +1060,7 @@ class TestClusterClientService(MAASTestCase):
1055 @inlineCallbacks1060 @inlineCallbacks
1056 def test_update_connections_connect_error_is_logged_tersely(self):1061 def test_update_connections_connect_error_is_logged_tersely(self):
1057 service = ClusterClientService(Clock())1062 service = ClusterClientService(Clock())
1058 _make_connection = self.patch(service, "_make_connection")1063 _make_connection = self.patch(service.connections, "connect")
1059 _make_connection.side_effect = error.ConnectionRefusedError()1064 _make_connection.side_effect = error.ConnectionRefusedError()
10601065
1061 logger = self.useFixture(TwistedLoggerFixture())1066 logger = self.useFixture(TwistedLoggerFixture())
@@ -1079,7 +1084,7 @@ class TestClusterClientService(MAASTestCase):
1079 @inlineCallbacks1084 @inlineCallbacks
1080 def test_update_connections_unknown_error_is_logged_with_stack(self):1085 def test_update_connections_unknown_error_is_logged_with_stack(self):
1081 service = ClusterClientService(Clock())1086 service = ClusterClientService(Clock())
1082 _make_connection = self.patch(service, "_make_connection")1087 _make_connection = self.patch(service.connections, "connect")
1083 _make_connection.side_effect = RuntimeError("Something went wrong.")1088 _make_connection.side_effect = RuntimeError("Something went wrong.")
10841089
1085 logger = self.useFixture(TwistedLoggerFixture())1090 logger = self.useFixture(TwistedLoggerFixture())
@@ -1106,8 +1111,8 @@ class TestClusterClientService(MAASTestCase):
11061111
1107 def test_update_connections_when_there_are_existing_connections(self):1112 def test_update_connections_when_there_are_existing_connections(self):
1108 service = ClusterClientService(Clock())1113 service = ClusterClientService(Clock())
1109 _make_connection = self.patch(service, "_make_connection")1114 _connect = self.patch(service.connections, "connect")
1110 _drop_connection = self.patch(service, "_drop_connection")1115 _disconnect = self.patch(service.connections, "disconnect")
11111116
1112 host1client = ClusterClient(1117 host1client = ClusterClient(
1113 ("::ffff:1.1.1.1", 1111), "host1:pid=1", service1118 ("::ffff:1.1.1.1", 1111), "host1:pid=1", service
@@ -1120,9 +1125,9 @@ class TestClusterClientService(MAASTestCase):
1120 )1125 )
11211126
1122 # Fake some connections.1127 # Fake some connections.
1123 service.connections = {1128 service.connections.connections = {
1124 host1client.eventloop: host1client,1129 host1client.eventloop: [host1client],
1125 host2client.eventloop: host2client,1130 host2client.eventloop: [host2client],
1126 }1131 }
11271132
1128 # Request a new set of connections that overlaps with the1133 # Request a new set of connections that overlaps with the
@@ -1137,10 +1142,10 @@ class TestClusterClientService(MAASTestCase):
1137 # A connection is made for host3's event-loop, and the1142 # A connection is made for host3's event-loop, and the
1138 # connection to host2's event-loop is dropped.1143 # connection to host2's event-loop is dropped.
1139 self.assertThat(1144 self.assertThat(
1140 _make_connection,1145 _connect,
1141 MockCalledOnceWith(host3client.eventloop, host3client.address),1146 MockCalledOnceWith(host3client.eventloop, host3client.address),
1142 )1147 )
1143 self.assertThat(_drop_connection, MockCalledWith(host2client))1148 self.assertThat(_disconnect, MockCalledWith(host2client))
11441149
1145 @inlineCallbacks1150 @inlineCallbacks
1146 def test_update_only_updates_interval_when_eventloops_are_unknown(self):1151 def test_update_only_updates_interval_when_eventloops_are_unknown(self):
@@ -1175,57 +1180,15 @@ class TestClusterClientService(MAASTestCase):
1175 logger.dump(),1180 logger.dump(),
1176 )1181 )
11771182
1178 def test_make_connection(self):
1179 service = ClusterClientService(Clock())
1180 connectProtocol = self.patch(clusterservice, "connectProtocol")
1181 service._make_connection("an-event-loop", ("a.example.com", 1111))
1182 self.assertThat(connectProtocol.call_args_list, HasLength(1))
1183 self.assertThat(
1184 connectProtocol.call_args_list[0][0],
1185 MatchesListwise(
1186 (
1187 # First argument is an IPv4 TCP client endpoint
1188 # specification.
1189 MatchesAll(
1190 IsInstance(TCP6ClientEndpoint),
1191 MatchesStructure.byEquality(
1192 _reactor=service.clock,
1193 _host="a.example.com",
1194 _port=1111,
1195 ),
1196 ),
1197 # Second argument is a ClusterClient instance, the
1198 # protocol to use for the connection.
1199 MatchesAll(
1200 IsInstance(clusterservice.ClusterClient),
1201 MatchesStructure.byEquality(
1202 address=("a.example.com", 1111),
1203 eventloop="an-event-loop",
1204 service=service,
1205 ),
1206 ),
1207 )
1208 ),
1209 )
1210
1211 def test_drop_connection(self):
1212 connection = Mock()
1213 service = make_inert_client_service()
1214 service.startService()
1215 service._drop_connection(connection)
1216 self.assertThat(
1217 connection.transport.loseConnection, MockCalledOnceWith()
1218 )
1219
def test_add_connection_removes_from_try_connections(self):
    # A staged attempt is cleared once the connection is registered.
    service = make_inert_client_service()
    service.startService()
    eventloop = Mock()
    conn = Mock()
    conn.address = (":::ffff", 2222)
    service.connections.try_connections[eventloop] = conn
    service.add_connection(eventloop, conn)
    self.assertEqual(service.connections.try_connections, {})
12291192
1230 def test_add_connection_adds_to_connections(self):1193 def test_add_connection_adds_to_connections(self):
1231 service = make_inert_client_service()1194 service = make_inert_client_service()
@@ -1234,7 +1197,7 @@ class TestClusterClientService(MAASTestCase):
1234 connection = Mock()1197 connection = Mock()
1235 connection.address = (":::ffff", 2222)1198 connection.address = (":::ffff", 2222)
1236 service.add_connection(endpoint, connection)1199 service.add_connection(endpoint, connection)
1237 self.assertThat(service.connections, Equals({endpoint: connection}))1200 self.assertEqual(service.connections, {endpoint: [connection]})
12381201
1239 def test_add_connection_calls__update_saved_rpc_info_state(self):1202 def test_add_connection_calls__update_saved_rpc_info_state(self):
1240 service = make_inert_client_service()1203 service = make_inert_client_service()
@@ -1248,21 +1211,45 @@ class TestClusterClientService(MAASTestCase):
1248 service._update_saved_rpc_info_state, MockCalledOnceWith()1211 service._update_saved_rpc_info_state, MockCalledOnceWith()
1249 )1212 )
12501213
def test_add_connection_creates_max_idle_connections(self):
    # With max_idle_conns=2, registering one connection should open one
    # extra connection so the event-loop holds two in total.
    service = make_inert_client_service(max_idle_conns=2)
    service.startService()
    endpoint = Mock()
    connection = Mock()
    connection.address = (":::ffff", 2222)
    # The extra (cloned) connection returned by the patched connect().
    connection2 = Mock()
    connection2.address = (":::ffff", 2222)
    connect = self.patch(service.connections, "connect")
    connect.return_value = succeed(connection2)
    self.patch_autospec(service, "_update_saved_rpc_info_state")
    service.add_connection(endpoint, connection)
    self.assertEqual(
        len(
            [
                conn
                for conns in service.connections.values()
                for conn in conns
            ]
        ),
        service.connections._max_idle_connections,
    )
    # The clone targets the same event-loop and address as the original.
    self.assertThat(
        connect,
        MockCalledOnceWith(connection.eventloop, connection.address),
    )
1237
def test_remove_connection_removes_from_try_connections(self):
    # Removing a staged connection clears it from try_connections.
    service = make_inert_client_service()
    service.startService()
    eventloop = Mock()
    conn = Mock()
    service.connections.try_connections[eventloop] = conn
    service.remove_connection(eventloop, conn)
    self.assertEqual({}, service.connections.try_connections)
12591246
def test_remove_connection_removes_from_connections(self):
    # Removing the last connection for an event-loop drops the
    # event-loop from the pool entirely.
    service = make_inert_client_service()
    service.startService()
    endpoint = Mock()
    connection = Mock()
    # The pool keeps a list of connections per event-loop.
    service.connections[endpoint] = [connection]
    service.remove_connection(endpoint, connection)
    self.assertThat(service.connections, Equals({}))
12681255
@@ -1271,7 +1258,7 @@ class TestClusterClientService(MAASTestCase):
1271 service.startService()1258 service.startService()
1272 endpoint = Mock()1259 endpoint = Mock()
1273 connection = Mock()1260 connection = Mock()
1274 service.connections[endpoint] = connection1261 service.connections[endpoint] = {connection}
1275 service.remove_connection(endpoint, connection)1262 service.remove_connection(endpoint, connection)
1276 self.assertEqual(service.step, service.INTERVAL_LOW)1263 self.assertEqual(service.step, service.INTERVAL_LOW)
12771264
@@ -1280,7 +1267,7 @@ class TestClusterClientService(MAASTestCase):
1280 service.startService()1267 service.startService()
1281 endpoint = Mock()1268 endpoint = Mock()
1282 connection = Mock()1269 connection = Mock()
1283 service.connections[endpoint] = connection1270 service.connections[endpoint] = {connection}
12841271
1285 # Enable both dhcpd and dhcpd6.1272 # Enable both dhcpd and dhcpd6.
1286 service_monitor.getServiceByName("dhcpd").on()1273 service_monitor.getServiceByName("dhcpd").on()
@@ -1294,45 +1281,96 @@ class TestClusterClientService(MAASTestCase):
12941281
def test_getClient(self):
    # getClient() hands back a Client wrapping one of the pooled
    # connections.
    service = ClusterClientService(Clock())
    service.connections.connections = {
        loop: [FakeConnection()]
        for loop in (
            sentinel.eventloop01,
            sentinel.eventloop02,
            sentinel.eventloop03,
        )
    }
    client = service.getClient()
    expected = {
        common.Client(conn)
        for conns in service.connections.values()
        for conn in conns
    }
    self.assertIn(client, expected)
13061297
def test_getClient_when_there_are_no_connections(self):
    # An empty pool cannot supply a client.
    service = ClusterClientService(Clock())
    service.connections.connections = {}
    with self.assertRaises(exceptions.NoConnectionsAvailable):
        service.getClient()
13111302
@inlineCallbacks
def test_getClientNow_scales_connections_when_busy(self):
    # When every pooled connection is busy and the limit allows it, a new
    # connection is opened, added to the pool, and handed out.
    service = ClusterClientService(Clock(), max_conns=2)
    busy_loops = (
        sentinel.eventloop01,
        sentinel.eventloop02,
        sentinel.eventloop03,
    )
    service.connections.connections = {
        loop: [FakeBusyConnectionToRegion()] for loop in busy_loops
    }
    self.patch(service.connections, "connect").return_value = succeed(
        FakeConnection()
    )

    def pooled():
        return [
            conn for conns in service.connections.values() for conn in conns
        ]

    before = pooled()
    client = yield service.getClientNow()
    conn = client._conn
    self.assertIsNotNone(conn)
    self.assertNotIn(conn, before)
    self.assertIn(conn, pooled())
1325
@inlineCallbacks
def test_getClientNow_returns_an_existing_connection_when_max_are_open(
    self,
):
    # With the connection limit reached, getClientNow() must fall back to
    # one of the existing (busy) connections instead of opening another.
    service = ClusterClientService(Clock(), max_conns=1)
    service.connections.connections = {
        sentinel.eventloop01: [FakeBusyConnectionToRegion()],
        sentinel.eventloop02: [FakeBusyConnectionToRegion()],
        sentinel.eventloop03: [FakeBusyConnectionToRegion()],
    }
    # Connection creation now lives on the pool; patch it there so that
    # an unexpected connect never reaches the network.
    self.patch(service.connections, "connect").return_value = succeed(
        FakeConnection()
    )
    original_conns = [
        conn for conns in service.connections.values() for conn in conns
    ]
    new_client = yield service.getClientNow()
    new_conn = new_client._conn
    self.assertIsNotNone(new_conn)
    self.assertIn(new_conn, original_conns)
1346
@inlineCallbacks
def test_getClientNow_returns_current_connection(self):
    # With free connections available, getClientNow() wraps one of them.
    service = ClusterClientService(Clock())
    service.connections.connections = {
        loop: [FakeConnection()]
        for loop in (
            sentinel.eventloop01,
            sentinel.eventloop02,
            sentinel.eventloop03,
        )
    }
    client = yield service.getClientNow()
    candidates = [
        common.Client(conn)
        for conns in service.connections.values()
        for conn in conns
    ]
    self.assertIn(client, candidates)
13251364
1326 @inlineCallbacks1365 @inlineCallbacks
1327 def test_getClientNow_calls__tryUpdate_when_there_are_no_connections(self):1366 def test_getClientNow_calls__tryUpdate_when_there_are_no_connections(self):
1328 service = ClusterClientService(Clock())1367 service = ClusterClientService(Clock())
1329 service.connections = {}
13301368
1331 def addConnections():1369 def addConnections():
1332 service.connections = {1370 service.connections.connections = {
1333 sentinel.eventloop01: DummyConnection(),1371 sentinel.eventloop01: [FakeConnection()],
1334 sentinel.eventloop02: DummyConnection(),1372 sentinel.eventloop02: [FakeConnection()],
1335 sentinel.eventloop03: DummyConnection(),1373 sentinel.eventloop03: [FakeConnection()],
1336 }1374 }
1337 return succeed(None)1375 return succeed(None)
13381376
@@ -1340,12 +1378,15 @@ class TestClusterClientService(MAASTestCase):
1340 client = yield service.getClientNow()1378 client = yield service.getClientNow()
1341 self.assertIn(1379 self.assertIn(
1342 client,1380 client,
1343 {common.Client(conn) for conn in service.connections.values()},1381 {
1382 common.Client(conn)
1383 for conns in service.connections.values()
1384 for conn in conns
1385 },
1344 )1386 )
13451387
1346 def test_getClientNow_raises_exception_when_no_clients(self):1388 def test_getClientNow_raises_exception_when_no_clients(self):
1347 service = ClusterClientService(Clock())1389 service = ClusterClientService(Clock())
1348 service.connections = {}
13491390
1350 self.patch(service, "_tryUpdate").return_value = succeed(None)1391 self.patch(service, "_tryUpdate").return_value = succeed(None)
1351 d = service.getClientNow()1392 d = service.getClientNow()
@@ -1383,17 +1424,16 @@ class TestClusterClientService(MAASTestCase):
def test_getAllClients(self):
    # getAllClients() wraps every pooled connection, in pool order.
    service = ClusterClientService(Clock())
    uuid1 = factory.make_UUID()
    c1 = FakeConnection()
    # The pool keeps a list of connections per event-loop.
    service.connections[uuid1] = [c1]
    uuid2 = factory.make_UUID()
    c2 = FakeConnection()
    service.connections[uuid2] = [c2]
    clients = service.getAllClients()
    self.assertEqual(clients, [common.Client(c1), common.Client(c2)])
13931434
def test_getAllClients_when_there_are_no_connections(self):
    # A fresh service has no pooled connections, hence no clients.
    service = ClusterClientService(Clock())
    self.assertEqual([], service.getAllClients())
13981438
13991439
@@ -1546,7 +1586,7 @@ class TestClusterClient(MAASTestCase):
15461586
1547 def test_connecting(self):1587 def test_connecting(self):
1548 client = self.make_running_client()1588 client = self.make_running_client()
1549 client.service.try_connections[client.eventloop] = client1589 client.service.connections.try_connections[client.eventloop] = client
1550 self.patch_authenticate_for_success(client)1590 self.patch_authenticate_for_success(client)
1551 self.patch_register_for_success(client)1591 self.patch_register_for_success(client)
1552 self.assertEqual(client.service.connections, {})1592 self.assertEqual(client.service.connections, {})
@@ -1560,16 +1600,19 @@ class TestClusterClient(MAASTestCase):
1560 self.assertTrue(extract_result(wait_for_authenticated))1600 self.assertTrue(extract_result(wait_for_authenticated))
1561 # ready has been set with the name of the event-loop.1601 # ready has been set with the name of the event-loop.
1562 self.assertEqual(client.eventloop, extract_result(wait_for_ready))1602 self.assertEqual(client.eventloop, extract_result(wait_for_ready))
1563 self.assertEqual(client.service.try_connections, {})1603 self.assertEqual(len(client.service.connections.try_connections), 0)
1564 self.assertEqual(1604 self.assertEqual(
1565 client.service.connections, {client.eventloop: client}1605 client.service.connections.connections,
1606 {client.eventloop: [client]},
1566 )1607 )
15671608
1568 def test_disconnects_when_there_is_an_existing_connection(self):1609 def test_disconnects_when_there_is_an_existing_connection(self):
1569 client = self.make_running_client()1610 client = self.make_running_client()
15701611
1571 # Pretend that a connection already exists for this address.1612 # Pretend that a connection already exists for this address.
1572 client.service.connections[client.eventloop] = sentinel.connection1613 client.service.connections.connections[client.eventloop] = [
1614 sentinel.connection
1615 ]
15731616
1574 # Connect via an in-memory transport.1617 # Connect via an in-memory transport.
1575 transport = StringTransportWithDisconnection()1618 transport = StringTransportWithDisconnection()
@@ -1586,7 +1629,8 @@ class TestClusterClient(MAASTestCase):
1586 # The connections list is unchanged because the new connection1629 # The connections list is unchanged because the new connection
1587 # immediately disconnects.1630 # immediately disconnects.
1588 self.assertEqual(1631 self.assertEqual(
1589 client.service.connections, {client.eventloop: sentinel.connection}1632 client.service.connections,
1633 {client.eventloop: [sentinel.connection]},
1590 )1634 )
1591 self.assertFalse(client.connected)1635 self.assertFalse(client.connected)
1592 self.assertIsNone(client.transport)1636 self.assertIsNone(client.transport)
@@ -1631,7 +1675,7 @@ class TestClusterClient(MAASTestCase):
16311675
1632 # The connections list is unchanged because the new connection1676 # The connections list is unchanged because the new connection
1633 # immediately disconnects.1677 # immediately disconnects.
1634 self.assertEqual(client.service.connections, {})1678 self.assertEqual(client.service.connections.connections, {})
1635 self.assertFalse(client.connected)1679 self.assertFalse(client.connected)
16361680
1637 def test_disconnects_when_authentication_errors(self):1681 def test_disconnects_when_authentication_errors(self):
diff --git a/src/provisioningserver/rpc/tests/test_connectionpool.py b/src/provisioningserver/rpc/tests/test_connectionpool.py
1638new file mode 1006441682new file mode 100644
index 0000000..692d5e6
--- /dev/null
+++ b/src/provisioningserver/rpc/tests/test_connectionpool.py
@@ -0,0 +1,280 @@
1# Copyright 2022 Canonical Ltd. This software is licensed under the
2# GNU Affero General Public License version 3 (see the file LICENSE).
3
4from unittest.mock import Mock
5
6from twisted.internet.defer import inlineCallbacks, succeed
7from twisted.internet.endpoints import TCP6ClientEndpoint
8from twisted.internet.task import Clock
9
10from maastesting import get_testing_timeout
11from maastesting.testcase import MAASTestCase, MAASTwistedRunTest
12from maastesting.twisted import extract_result
13from provisioningserver.rpc import connectionpool as connectionpoolModule
14from provisioningserver.rpc import exceptions
15from provisioningserver.rpc.clusterservice import ClusterClient
16from provisioningserver.rpc.connectionpool import ConnectionPool
17
TIMEOUT = get_testing_timeout()


def _make_connection(in_use=False):
    """Return a mock connection whose `in_use` flag is set to `in_use`."""
    conn = Mock()
    conn.in_use = in_use
    return conn


class TestConnectionPool(MAASTestCase):
    """Unit tests for `ConnectionPool`."""

    run_tests_with = MAASTwistedRunTest.make_factory(timeout=TIMEOUT)

    def test_setitem_sets_item_in_connections(self):
        cp = ConnectionPool(Clock(), Mock())
        key = Mock()
        val = Mock()
        cp[key] = val
        self.assertEqual(cp.connections, {key: val})

    def test_getitem_gets_item_in_connections(self):
        cp = ConnectionPool(Clock(), Mock())
        key = Mock()
        val = Mock()
        cp[key] = val
        self.assertEqual(cp.connections[key], cp[key])

    def test_len_gets_length_of_connections(self):
        cp = ConnectionPool(Clock(), Mock())
        cp[Mock()] = [Mock()]
        # Pin a concrete value rather than comparing the pool against
        # one of its own helpers.
        self.assertEqual(len(cp), 1)

    def test_delitem_removes_item_from_connections(self):
        cp = ConnectionPool(Clock(), Mock())
        key = Mock()
        val = Mock()
        cp[key] = val
        self.assertEqual(cp.connections[key], val)
        del cp[key]
        self.assertIsNone(cp.connections.get(key))

    def test_contains_returns_if_key_in_connections(self):
        cp = ConnectionPool(Clock(), Mock())
        key = Mock()
        cp[key] = Mock()
        # Check both a present and an absent key so the test cannot pass
        # trivially.
        self.assertIn(key, cp)
        self.assertNotIn(Mock(), cp)

    def test_compare_ConnectionPool_equal_to_compare_connections(self):
        cp = ConnectionPool(Clock(), Mock())
        self.assertEqual(cp, cp.connections)
        self.assertEqual(cp, {})

    def test__reap_extra_connection_reaps_a_non_busy_connection(self):
        cp = ConnectionPool(Clock(), Mock())
        eventloop = Mock()
        connection = _make_connection(in_use=False)
        cp[eventloop] = [connection]
        disconnect = self.patch(cp, "disconnect")
        cp._reap_extra_connection(eventloop, connection)
        self.assertEqual(len(cp), 0)
        disconnect.assert_called_once_with(connection)

    def test__reap_extra_connection_waits_for_a_busy_connection(self):
        clock = Clock()
        cp = ConnectionPool(clock, Mock())
        eventloop = Mock()
        connection = _make_connection(in_use=True)
        cp[eventloop] = [connection]
        self.patch(cp, "disconnect")
        cp._reap_extra_connection(eventloop, connection)
        # A busy connection must not be dropped; instead a retry of the
        # reaper is scheduled on the clock after the keepalive interval.
        delayed = clock.calls[0]
        self.assertIn(eventloop, delayed.args)
        self.assertIn(connection, delayed.args)
        self.assertEqual("_reap_extra_connection", delayed.func.__name__)
        self.assertEqual(cp._keepalive, delayed.time)

    def test_is_staged(self):
        cp = ConnectionPool(Clock(), Mock())
        eventloop1 = Mock()
        eventloop2 = Mock()
        cp.try_connections[eventloop1] = Mock()
        self.assertTrue(cp.is_staged(eventloop1))
        self.assertFalse(cp.is_staged(eventloop2))

    def test_get_staged_connection(self):
        cp = ConnectionPool(Clock(), Mock())
        eventloop = Mock()
        connection = Mock()
        cp.try_connections[eventloop] = connection
        self.assertEqual(cp.get_staged_connection(eventloop), connection)

    def test_get_staged_connections(self):
        cp = ConnectionPool(Clock(), Mock())
        eventloop = Mock()
        connection = Mock()
        cp.try_connections[eventloop] = connection
        self.assertEqual(cp.get_staged_connections(), {eventloop: connection})

    def test_scale_up_connections_adds_a_connection(self):
        cp = ConnectionPool(Clock(), Mock(), max_conns=2)
        eventloop = Mock()
        connection1 = Mock()
        connection2 = Mock()
        connect = self.patch(cp, "connect")
        connect.return_value = succeed(connection2)
        cp[eventloop] = [connection1]
        cp.scale_up_connections()
        self.assertCountEqual(cp[eventloop], [connection1, connection2])

    def test_scale_up_connections_raises_MaxConnectionsOpen_when_cannot_create_another(
        self,
    ):
        cp = ConnectionPool(Clock(), Mock())
        eventloop = Mock()
        connection1 = Mock()
        connection2 = Mock()
        connect = self.patch(cp, "connect")
        connect.return_value = succeed(connection2)
        cp[eventloop] = [connection1]
        self.assertRaises(
            exceptions.MaxConnectionsOpen,
            extract_result,
            cp.scale_up_connections(),
        )

    def test_get_connection(self):
        cp = ConnectionPool(Clock(), Mock(), max_idle_conns=2, max_conns=2)
        eventloops = [Mock() for _ in range(3)]
        cp.connections = {
            eventloop: [Mock() for _ in range(2)] for eventloop in eventloops
        }
        self.assertIn(cp.get_connection(eventloops[0]), cp[eventloops[0]])

    def test_get_random_connection(self):
        cp = ConnectionPool(Clock(), Mock(), max_idle_conns=2, max_conns=2)
        eventloops = [Mock() for _ in range(3)]
        cp.connections = {
            eventloop: [Mock() for _ in range(2)] for eventloop in eventloops
        }
        # Previously this called get_connection(), so the method named by
        # the test was never exercised.
        self.assertIn(
            cp.get_random_connection(),
            [conn for conn_list in cp.values() for conn in conn_list],
        )

    def test_get_random_free_connection_returns_a_free_connection(self):
        cp = ConnectionPool(Clock(), Mock())
        eventloops = [Mock() for _ in range(3)]
        free_conn = _make_connection(in_use=False)
        cp.connections = {
            eventloops[0]: [_make_connection(in_use=True)],
            eventloops[1]: [free_conn],
            eventloops[2]: [_make_connection(in_use=True)],
        }
        self.assertIs(cp.get_random_free_connection(), free_conn)

    def test_get_random_free_connection_raises_AllConnectionsBusy_when_there_are_no_free_connections(
        self,
    ):
        cp = ConnectionPool(Clock(), Mock())
        eventloops = [Mock() for _ in range(3)]
        cp.connections = {
            eventloop: [_make_connection(in_use=True)]
            for eventloop in eventloops
        }
        self.assertRaises(
            exceptions.AllConnectionsBusy, cp.get_random_free_connection
        )

    def test_get_all_connections(self):
        cp = ConnectionPool(Clock(), Mock())
        eventloops = [Mock() for _ in range(3)]
        conns = [Mock() for _ in eventloops]
        cp.connections = {
            eventloop: [conn] for eventloop, conn in zip(eventloops, conns)
        }
        self.assertCountEqual(cp.get_all_connections(), conns)

    def test_get_all_free_connections(self):
        cp = ConnectionPool(Clock(), Mock(), max_conns=2)
        eventloops = [Mock() for _ in range(3)]
        free_conns = [_make_connection(in_use=False) for _ in range(2)]
        cp.connections = {
            eventloops[0]: [_make_connection(in_use=True), free_conns[0]],
            eventloops[1]: [_make_connection(in_use=True)],
            eventloops[2]: [free_conns[1]],
        }
        self.assertCountEqual(cp.get_all_free_connections(), free_conns)

    @inlineCallbacks
    def test_connect(self):
        clock = Clock()
        connection = Mock()
        service = Mock()
        cp = ConnectionPool(clock, service)
        connectProtocol = self.patch(connectionpoolModule, "connectProtocol")
        connectProtocol.return_value = connection
        result = yield cp.connect("an-event-loop", ("a.example.com", 1111))
        # `called_once_with` is not a Mock assertion. It silently returns a
        # child mock on older Pythons and raises AttributeError on 3.12+.
        # A freshly built endpoint/client would also only compare equal if
        # those classes implement value equality. So inspect the call
        # explicitly instead.
        # NOTE(review): assumes connect() passes (endpoint, protocol)
        # positionally to connectProtocol -- confirm against connectionpool.
        connectProtocol.assert_called_once()
        endpoint, protocol = connectProtocol.call_args.args
        self.assertIsInstance(endpoint, TCP6ClientEndpoint)
        self.assertIsInstance(protocol, ClusterClient)
        self.assertEqual(protocol.address, ("a.example.com", 1111))
        self.assertEqual(protocol.eventloop, "an-event-loop")
        self.assertIs(protocol.service, service)
        self.assertEqual(result, connection)

    def test_drop_connection(self):
        connection = Mock()
        cp = ConnectionPool(Clock(), Mock())
        cp.disconnect(connection)
        connection.transport.loseConnection.assert_called_once_with()

    @inlineCallbacks
    def test_add_connection_adds_the_staged_connection(self):
        eventloop = Mock()
        connection = Mock()
        cp = ConnectionPool(Clock(), Mock())
        cp.try_connections = {eventloop: connection}
        yield cp.add_connection(eventloop, connection)
        self.assertIn(connection, cp.get_all_connections())

    def test_remove_connection_removes_connection_from_pool(self):
        eventloop = Mock()
        connection = Mock()
        cp = ConnectionPool(Clock(), Mock())
        cp.connections[eventloop] = [connection]
        cp.remove_connection(eventloop, connection)
        self.assertEqual(cp.connections, {})

Subscribers

People subscribed via source and target branches