Merge ~cgrabowski/maas:rpc_connection_pool_burst into maas:master

Proposed by Christian Grabowski
Status: Merged
Approved by: Christian Grabowski
Approved revision: 0c5564670e61e93261925cf91a46850824536a77
Merge reported by: MAAS Lander
Merged at revision: not available
Proposed branch: ~cgrabowski/maas:rpc_connection_pool_burst
Merge into: maas:master
Diff against target: 1518 lines (+741/-175)
15 files modified
src/provisioningserver/config.py (+17/-0)
src/provisioningserver/dhcp/tests/test_config.py (+2/-2)
src/provisioningserver/plugin.py (+7/-1)
src/provisioningserver/rackdservices/external.py (+3/-2)
src/provisioningserver/rackdservices/http.py (+3/-2)
src/provisioningserver/rackdservices/tests/test_external.py (+8/-5)
src/provisioningserver/rackdservices/tests/test_http.py (+5/-4)
src/provisioningserver/rpc/clusterservice.py (+56/-50)
src/provisioningserver/rpc/common.py (+22/-5)
src/provisioningserver/rpc/connectionpool.py (+163/-0)
src/provisioningserver/rpc/exceptions.py (+8/-0)
src/provisioningserver/rpc/testing/__init__.py (+2/-1)
src/provisioningserver/rpc/testing/doubles.py (+18/-0)
src/provisioningserver/rpc/tests/test_clusterservice.py (+147/-103)
src/provisioningserver/rpc/tests/test_connectionpool.py (+280/-0)
Reviewer Review Type Date Requested Status
Alexsander de Souza Approve
MAAS Lander Approve
Review via email: mp+428335@code.launchpad.net

Commit message

move connection lifecycle logic into ConnectionPool

allocate additional connections when busy

always open max_idle_rpc_connections connections per endpoint

To post a comment you must log in.
Revision history for this message
MAAS Lander (maas-lander) wrote :

UNIT TESTS
-b rpc_connection_pool_burst lp:~cgrabowski/maas/+git/maas into -b master lp:~maas-committers/maas

STATUS: FAILED
LOG: http://maas-ci.internal:8080/job/maas-tester/318/consoleText
COMMIT: c19b8616c6acc5cde30fa33916ad40da3d3b04e5

review: Needs Fixing
Revision history for this message
Adam Collard (adam-collard) wrote :

jenkins: !test

Revision history for this message
MAAS Lander (maas-lander) wrote :

UNIT TESTS
-b rpc_connection_pool_burst lp:~cgrabowski/maas/+git/maas into -b master lp:~maas-committers/maas

STATUS: FAILED
LOG: http://maas-ci.internal:8080/job/maas-tester/321/consoleText
COMMIT: c19b8616c6acc5cde30fa33916ad40da3d3b04e5

review: Needs Fixing
Revision history for this message
MAAS Lander (maas-lander) wrote :

UNIT TESTS
-b rpc_connection_pool_burst lp:~cgrabowski/maas/+git/maas into -b master lp:~maas-committers/maas

STATUS: SUCCESS
COMMIT: 0c5564670e61e93261925cf91a46850824536a77

review: Approve
Revision history for this message
Alexsander de Souza (alexsander-souza) wrote :

LGTM

review: Approve

Preview Diff

[H/L] Next/Prev Comment, [J/K] Next/Prev File, [N/P] Next/Prev Hunk
1diff --git a/src/provisioningserver/config.py b/src/provisioningserver/config.py
2index 97b6e68..9511ad6 100644
3--- a/src/provisioningserver/config.py
4+++ b/src/provisioningserver/config.py
5@@ -762,6 +762,23 @@ class ClusterConfiguration(Configuration, metaclass=ClusterConfigurationMeta):
6 ),
7 )
8
9+ # RPC Connection Pool options
10+ max_idle_rpc_connections = ConfigurationOption(
11+ "max_idle_rpc_connections",
12+ "The nominal number of connections to have per endpoint",
13+ Number(min=1, max=1024, if_missing=1),
14+ )
15+ max_rpc_connections = ConfigurationOption(
16+ "max_rpc_connections",
17+ "The maximum number of connections to scale to when under load",
18+ Number(min=1, max=1024, if_missing=4),
19+ )
20+ rpc_keepalive = ConfigurationOption(
21+ "rpc_keepalive",
22+ "The duration in miliseconds to keep added connections alive",
23+ Number(min=1, max=5000, if_missing=1000),
24+ )
25+
26 # TFTP options.
27 tftp_port = ConfigurationOption(
28 "tftp_port",
29diff --git a/src/provisioningserver/dhcp/tests/test_config.py b/src/provisioningserver/dhcp/tests/test_config.py
30index c53e906..8c3f2fe 100644
31--- a/src/provisioningserver/dhcp/tests/test_config.py
32+++ b/src/provisioningserver/dhcp/tests/test_config.py
33@@ -176,7 +176,7 @@ def validate_dhcpd_configuration(test, configuration, ipv6):
34 ),
35 ),
36 )
37- cmd = (
38+ cmd = [
39 "dhcpd",
40 ("-6" if ipv6 else "-4"),
41 "-t",
42@@ -184,7 +184,7 @@ def validate_dhcpd_configuration(test, configuration, ipv6):
43 conffile.name,
44 "-lf",
45 leasesfile.name,
46- )
47+ ]
48 if not running_in_docker():
49 # Call `dhcpd` without AppArmor confinement, so that it can read
50 # configurations file from /tmp. This is not needed when running
51diff --git a/src/provisioningserver/plugin.py b/src/provisioningserver/plugin.py
52index e987c73..00ff898 100644
53--- a/src/provisioningserver/plugin.py
54+++ b/src/provisioningserver/plugin.py
55@@ -139,7 +139,13 @@ class ProvisioningServiceMaker:
56 def _makeRPCService(self):
57 from provisioningserver.rpc.clusterservice import ClusterClientService
58
59- rpc_service = ClusterClientService(reactor)
60+ with ClusterConfiguration.open() as config:
61+ rpc_service = ClusterClientService(
62+ reactor,
63+ config.max_idle_rpc_connections,
64+ config.max_rpc_connections,
65+ config.rpc_keepalive,
66+ )
67 rpc_service.setName("rpc")
68 return rpc_service
69
70diff --git a/src/provisioningserver/rackdservices/external.py b/src/provisioningserver/rackdservices/external.py
71index ccabb74..5b8afe5 100644
72--- a/src/provisioningserver/rackdservices/external.py
73+++ b/src/provisioningserver/rackdservices/external.py
74@@ -68,8 +68,9 @@ class RackOnlyExternalService(metaclass=ABCMeta):
75
76 # Filter the connects by region.
77 conn_per_region = defaultdict(set)
78- for eventloop, connection in connections.items():
79- conn_per_region[eventloop.split(":")[0]].add(connection)
80+ for eventloop, connection_set in connections.items():
81+ for connection in connection_set:
82+ conn_per_region[eventloop.split(":")[0]].add(connection)
83 for eventloop, connections in conn_per_region.items():
84 # Sort the connections so the same IP is always picked per
85 # region controller. This ensures that the HTTP configuration
86diff --git a/src/provisioningserver/rackdservices/http.py b/src/provisioningserver/rackdservices/http.py
87index 421e35f..bda9d23 100644
88--- a/src/provisioningserver/rackdservices/http.py
89+++ b/src/provisioningserver/rackdservices/http.py
90@@ -101,8 +101,9 @@ class RackHTTPService(TimerService):
91 controller is connected to."""
92 # Filter the connects by region.
93 conn_per_region = defaultdict(set)
94- for eventloop, connection in self._rpc_service.connections.items():
95- conn_per_region[eventloop.split(":")[0]].add(connection)
96+ for eventloop, connection_set in self._rpc_service.connections.items():
97+ for connection in connection_set:
98+ conn_per_region[eventloop.split(":")[0]].add(connection)
99 for _, connections in conn_per_region.items():
100 # Sort the connections so the same IP is always picked per
101 # region controller. This ensures that the HTTP configuration
102diff --git a/src/provisioningserver/rackdservices/tests/test_external.py b/src/provisioningserver/rackdservices/tests/test_external.py
103index ad214a1..0cb8601 100644
104--- a/src/provisioningserver/rackdservices/tests/test_external.py
105+++ b/src/provisioningserver/rackdservices/tests/test_external.py
106@@ -430,7 +430,8 @@ class TestRackDNS(MAASTestCase):
107 return frozenset(
108 {
109 client.address[0]
110- for _, client in rpc_service.connections.items()
111+ for _, clients in rpc_service.connections.items()
112+ for client in clients
113 }
114 )
115
116@@ -609,7 +610,7 @@ class TestRackDNS(MAASTestCase):
117 ip = factory.make_ip_address()
118 mock_conn = Mock()
119 mock_conn.address = (ip, random.randint(5240, 5250))
120- mock_rpc.connections[eventloop] = mock_conn
121+ mock_rpc.connections[eventloop] = {mock_conn}
122
123 dns = external.RackDNS()
124 region_ips = list(dns._genRegionIps(mock_rpc.connections))
125@@ -626,7 +627,7 @@ class TestRackDNS(MAASTestCase):
126 ip = factory.make_ip_address()
127 mock_conn = Mock()
128 mock_conn.address = (ip, random.randint(5240, 5250))
129- mock_rpc.connections[eventloop] = mock_conn
130+ mock_rpc.connections[eventloop] = {mock_conn}
131
132 dns = external.RackDNS()
133 region_ips = frozenset(dns._genRegionIps(mock_rpc.connections))
134@@ -659,7 +660,8 @@ class TestRackProxy(MAASTestCase):
135 return frozenset(
136 {
137 client.address[0]
138- for _, client in rpc_service.connections.items()
139+ for _, clients in rpc_service.connections.items()
140+ for client in clients
141 }
142 )
143
144@@ -824,7 +826,8 @@ class TestRackSyslog(MAASTestCase):
145 return frozenset(
146 {
147 (eventloop, client.address[0])
148- for eventloop, client in rpc_service.connections.items()
149+ for eventloop, clients in rpc_service.connections.items()
150+ for client in clients
151 }
152 )
153
154diff --git a/src/provisioningserver/rackdservices/tests/test_http.py b/src/provisioningserver/rackdservices/tests/test_http.py
155index bc43c66..43cb495 100644
156--- a/src/provisioningserver/rackdservices/tests/test_http.py
157+++ b/src/provisioningserver/rackdservices/tests/test_http.py
158@@ -92,7 +92,8 @@ class TestRackHTTPService(MAASTestCase):
159 return frozenset(
160 {
161 client.address[0]
162- for _, client in rpc_service.connections.items()
163+ for _, clients in rpc_service.connections.items()
164+ for client in clients
165 }
166 )
167
168@@ -208,7 +209,7 @@ class TestRackHTTPService(MAASTestCase):
169 ip = factory.make_ip_address()
170 mock_conn = Mock()
171 mock_conn.address = (ip, random.randint(5240, 5250))
172- mock_rpc.connections[eventloop] = mock_conn
173+ mock_rpc.connections[eventloop] = {mock_conn}
174
175 service = http.RackHTTPService(self.make_dir(), mock_rpc, reactor)
176 region_ips = list(service._genRegionIps())
177@@ -225,7 +226,7 @@ class TestRackHTTPService(MAASTestCase):
178 ip = factory.make_ip_address()
179 mock_conn = Mock()
180 mock_conn.address = (ip, random.randint(5240, 5250))
181- mock_rpc.connections[eventloop] = mock_conn
182+ mock_rpc.connections[eventloop] = {mock_conn}
183
184 service = http.RackHTTPService(self.make_dir(), mock_rpc, reactor)
185 region_ips = frozenset(service._genRegionIps())
186@@ -244,7 +245,7 @@ class TestRackHTTPService(MAASTestCase):
187 ip_addresses.add("[%s]" % ip)
188 mock_conn = Mock()
189 mock_conn.address = (ip, random.randint(5240, 5250))
190- mock_rpc.connections[eventloop] = mock_conn
191+ mock_rpc.connections[eventloop] = {mock_conn}
192
193 service = http.RackHTTPService(self.make_dir(), mock_rpc, reactor)
194 region_ips = set(service._genRegionIps())
195diff --git a/src/provisioningserver/rpc/clusterservice.py b/src/provisioningserver/rpc/clusterservice.py
196index c92d48a..a7205db 100644
197--- a/src/provisioningserver/rpc/clusterservice.py
198+++ b/src/provisioningserver/rpc/clusterservice.py
199@@ -9,7 +9,6 @@ import json
200 from operator import itemgetter
201 import os
202 from os import urandom
203-import random
204 from socket import AF_INET, AF_INET6, gethostname
205 import sys
206 from urllib.parse import urlparse
207@@ -24,7 +23,6 @@ from twisted.internet.defer import (
208 maybeDeferred,
209 returnValue,
210 )
211-from twisted.internet.endpoints import connectProtocol, TCP6ClientEndpoint
212 from twisted.internet.error import ConnectError, ConnectionClosed, ProcessDone
213 from twisted.internet.threads import deferToThread
214 from twisted.protocols import amp
215@@ -67,6 +65,7 @@ from provisioningserver.rpc.boot_images import (
216 list_boot_images,
217 )
218 from provisioningserver.rpc.common import Ping, RPCProtocol
219+from provisioningserver.rpc.connectionpool import ConnectionPool
220 from provisioningserver.rpc.exceptions import CannotConfigureDHCP
221 from provisioningserver.rpc.interfaces import IConnectionToRegion
222 from provisioningserver.rpc.osystems import (
223@@ -999,6 +998,7 @@ class ClusterClient(Cluster):
224 # Events for this protocol's life-cycle.
225 self.authenticated = DeferredValue()
226 self.ready = DeferredValue()
227+ self.in_use = False
228 self.localIdent = None
229
230 @property
231@@ -1201,13 +1201,15 @@ class ClusterClientService(TimerService):
232
233 time_started = None
234
235- def __init__(self, reactor):
236+ def __init__(self, reactor, max_idle_conns=1, max_conns=1, keepalive=1000):
237 super().__init__(self._calculate_interval(None, None), self._tryUpdate)
238- self.connections = {}
239- self.try_connections = {}
240 self._previous_work = (None, None)
241 self.clock = reactor
242
243+ self.connections = ConnectionPool(
244+ reactor, self, max_idle_conns, max_conns, keepalive
245+ )
246+
247 # Stored the URL used to connect to the region controller. This will be
248 # the URL that was used to get the eventloops.
249 self.maas_url = None
250@@ -1236,11 +1238,19 @@ class ClusterClientService(TimerService):
251 :raises: :py:class:`~.exceptions.NoConnectionsAvailable` when
252 there are no open connections to a region controller.
253 """
254- conns = list(self.connections.values())
255- if len(conns) == 0:
256+ if len(self.connections) == 0:
257 raise exceptions.NoConnectionsAvailable()
258 else:
259- return common.Client(random.choice(conns))
260+ try:
261+ return common.Client(
262+ self.connections.get_random_free_connection()
263+ )
264+ except exceptions.AllConnectionsBusy as e:
265+ for endpoint_conns in self.connections.values():
266+ if len(endpoint_conns) < self.connections._max_connections:
267+ raise e
268+ # return a busy connection, assume it will free up or timeout
269+ return common.Client(self.connections.get_random_connection())
270
271 @deferred
272 def getClientNow(self):
273@@ -1259,10 +1269,17 @@ class ClusterClientService(TimerService):
274 return self.getClient()
275 except exceptions.NoConnectionsAvailable:
276 return self._tryUpdate().addCallback(call, self.getClient)
277+ except exceptions.AllConnectionsBusy:
278+ return self.connections.scale_up_connections().addCallback(
279+ call, self.getClient
280+ )
281
282 def getAllClients(self):
283 """Return a list of all connected :class:`common.Client`s."""
284- return [common.Client(conn) for conn in self.connections.values()]
285+ return [
286+ common.Client(conn)
287+ for conn in self.connections.get_all_connections()
288+ ]
289
290 def _tryUpdate(self):
291 """Attempt to refresh outgoing connections.
292@@ -1391,7 +1408,9 @@ class ClusterClientService(TimerService):
293 """Update the saved RPC info state."""
294 # Build a list of addresses based on the current connections.
295 connected_addr = {
296- conn.address[0] for _, conn in self.connections.items()
297+ conn.address[0]
298+ for _, conns in self.connections.items()
299+ for conn in conns
300 }
301 if (
302 self._rpc_info_state is None
303@@ -1467,8 +1486,8 @@ class ClusterClientService(TimerService):
304 # Gather the list of successful responses.
305 successful = []
306 errors = []
307- for sucess, result in results:
308- if sucess:
309+ for success, result in results:
310+ if success:
311 body, orig_url = result
312 eventloops = body.get("eventloops")
313 if eventloops is not None:
314@@ -1656,12 +1675,15 @@ class ClusterClientService(TimerService):
315 "Dropping connections to event-loops: %s"
316 % (", ".join(drop.keys()))
317 )
318+ drop_defers = []
319+ for eventloop, connections in drop.items():
320+ for connection in connections:
321+ drop_defers.append(
322+ maybeDeferred(self.connections.disconnect, connection)
323+ )
324+ self.connections.remove_connection(eventloop, connection)
325 yield DeferredList(
326- [
327- maybeDeferred(self._drop_connection, connection)
328- for eventloop, connections in drop.items()
329- for connection in connections
330- ],
331+ drop_defers,
332 consumeErrors=True,
333 )
334
335@@ -1692,11 +1714,12 @@ class ClusterClientService(TimerService):
336 # between consenting adults.
337 for eventloop, addresses in eventloops.items():
338 if eventloop in self.connections:
339- connection = self.connections[eventloop]
340- if connection.address not in addresses:
341- drop[eventloop] = [connection]
342- if eventloop in self.try_connections:
343- connection = self.try_connections[eventloop]
344+ connection_list = self.connections[eventloop]
345+ for connection in connection_list:
346+ if connection.address not in addresses:
347+ drop[eventloop] = [connection]
348+ if self.connections.is_staged(eventloop):
349+ connection = self.connections.get_staged_connection(eventloop)
350 if connection.address not in addresses:
351 drop[eventloop] = [connection]
352
353@@ -1705,7 +1728,7 @@ class ClusterClientService(TimerService):
354 for eventloop, addresses in eventloops.items():
355 if (
356 eventloop not in self.connections
357- and eventloop not in self.try_connections
358+ and not self.connections.is_staged(eventloop)
359 ) or eventloop in drop:
360 connect[eventloop] = addresses
361
362@@ -1714,13 +1737,13 @@ class ClusterClientService(TimerService):
363 # the process in which the event-loop is no longer running, but
364 # it could be an indicator of a heavily loaded machine, or a
365 # fault. In any case, it seems to make sense to disconnect.
366- for eventloop in self.connections:
367+ for eventloop in self.connections.keys():
368 if eventloop not in eventloops:
369- connection = self.connections[eventloop]
370- drop[eventloop] = [connection]
371- for eventloop in self.try_connections:
372+ connection_list = self.connections[eventloop]
373+ drop[eventloop] = connection_list
374+ for eventloop in self.connections.get_staged_connections():
375 if eventloop not in eventloops:
376- connection = self.try_connections[eventloop]
377+ connection = self.connections.get_staged_connection(eventloop)
378 drop[eventloop] = [connection]
379
380 return drop, connect
381@@ -1730,7 +1753,7 @@ class ClusterClientService(TimerService):
382 """Connect to `eventloop` using all `addresses`."""
383 for address in addresses:
384 try:
385- connection = yield self._make_connection(eventloop, address)
386+ connection = yield self.connections.connect(eventloop, address)
387 except ConnectError as error:
388 host, port = address
389 log.msg(
390@@ -1747,29 +1770,17 @@ class ClusterClientService(TimerService):
391 ),
392 )
393 else:
394- self.try_connections[eventloop] = connection
395+ self.connections.stage_connection(eventloop, connection)
396 break
397
398- def _make_connection(self, eventloop, address):
399- """Connect to `eventloop` at `address`."""
400- # Force everything to use AF_INET6 sockets.
401- endpoint = TCP6ClientEndpoint(self.clock, *address)
402- protocol = ClusterClient(address, eventloop, self)
403- return connectProtocol(endpoint, protocol)
404-
405- def _drop_connection(self, connection):
406- """Drop the given `connection`."""
407- return connection.transport.loseConnection()
408-
409+ @inlineCallbacks
410 def add_connection(self, eventloop, connection):
411 """Add the connection to the tracked connections.
412
413 Update the saved RPC info state information based on the new
414 connection.
415 """
416- if eventloop in self.try_connections:
417- del self.try_connections[eventloop]
418- self.connections[eventloop] = connection
419+ yield self.connections.add_connection(eventloop, connection)
420 self._update_saved_rpc_info_state()
421
422 def remove_connection(self, eventloop, connection):
423@@ -1778,12 +1789,7 @@ class ClusterClientService(TimerService):
424 If this is the last connection that was keeping rackd connected to
425 a regiond then dhcpd and dhcpd6 services will be turned off.
426 """
427- if eventloop in self.try_connections:
428- if self.try_connections[eventloop] is connection:
429- del self.try_connections[eventloop]
430- if eventloop in self.connections:
431- if self.connections[eventloop] is connection:
432- del self.connections[eventloop]
433+ self.connections.remove_connection(eventloop, connection)
434 # Disable DHCP when no connections to a region controller.
435 if len(self.connections) == 0:
436 stopping_services = []
437diff --git a/src/provisioningserver/rpc/common.py b/src/provisioningserver/rpc/common.py
438index 5d67bba..40e091f 100644
439--- a/src/provisioningserver/rpc/common.py
440+++ b/src/provisioningserver/rpc/common.py
441@@ -14,7 +14,11 @@ from twisted.python.failure import Failure
442 from provisioningserver.logger import LegacyLogger
443 from provisioningserver.prometheus.metrics import PROMETHEUS_METRICS
444 from provisioningserver.rpc.interfaces import IConnection, IConnectionToRegion
445-from provisioningserver.utils.twisted import asynchronous, deferWithTimeout
446+from provisioningserver.utils.twisted import (
447+ asynchronous,
448+ callOut,
449+ deferWithTimeout,
450+)
451
452 log = LegacyLogger()
453
454@@ -156,6 +160,11 @@ class Client:
455 :return: A deferred result. Call its `wait` method (with a timeout
456 in seconds) to block on the call's completion.
457 """
458+ self._conn.in_use = True
459+
460+ def _free_conn():
461+ self._conn.in_use = False
462+
463 if len(args) != 0:
464 receiver_name = "{}.{}".format(
465 self.__module__,
466@@ -171,11 +180,19 @@ class Client:
467 if timeout is undefined:
468 timeout = 120 # 2 minutes
469 if timeout is None or timeout <= 0:
470- return self._conn.callRemote(cmd, **kwargs)
471+ d = self._conn.callRemote(cmd, **kwargs)
472+ if isinstance(d, Deferred):
473+ d.addBoth(lambda x: callOut(x, _free_conn))
474+ else:
475+ _free_conn()
476+ return d
477 else:
478- return deferWithTimeout(
479- timeout, self._conn.callRemote, cmd, **kwargs
480- )
481+ d = deferWithTimeout(timeout, self._conn.callRemote, cmd, **kwargs)
482+ if isinstance(d, Deferred):
483+ d.addBoth(lambda x: callOut(x, _free_conn))
484+ else:
485+ _free_conn()
486+ return d
487
488 @asynchronous
489 def getHostCertificate(self):
490diff --git a/src/provisioningserver/rpc/connectionpool.py b/src/provisioningserver/rpc/connectionpool.py
491new file mode 100644
492index 0000000..8023f80
493--- /dev/null
494+++ b/src/provisioningserver/rpc/connectionpool.py
495@@ -0,0 +1,163 @@
496+# Copyright 2022 Canonical Ltd. This software is licensed under the
497+# GNU Affero General Public License version 3 (see the file LICENSE).
498+
499+""" RPC Connection Pooling and Lifecycle """
500+
501+import random
502+
503+from twisted.internet.defer import inlineCallbacks
504+from twisted.internet.endpoints import connectProtocol, TCP6ClientEndpoint
505+
506+from provisioningserver.rpc import exceptions
507+
508+
509+class ConnectionPool:
510+ def __init__(
511+ self, reactor, service, max_idle_conns=1, max_conns=1, keepalive=1000
512+ ):
513+ # The maximum number of connections to allways allocate per eventloop
514+ self._max_idle_connections = max_idle_conns
515+ # The maximum number of connections to allocate while under load per eventloop
516+ self._max_connections = max_conns
517+ # The duration in milliseconds to keep scaled up connections alive
518+ self._keepalive = keepalive
519+
520+ self.connections = {}
521+ self.try_connections = {}
522+ self.clock = reactor
523+ self._service = service
524+
525+ def __setitem__(self, key, item):
526+ self.connections[key] = item
527+
528+ def __getitem__(self, key):
529+ return self.connections.get(key)
530+
531+ def __repr__(self):
532+ return repr(self.connections)
533+
534+ def __len__(self):
535+ return len(self.get_all_connections())
536+
537+ def __delitem__(self, key):
538+ del self.connections[key]
539+
540+ def __contains__(self, item):
541+ return item in self.connections
542+
543+ def __cmp__(self, value):
544+ return self.connections.__cmp__(value)
545+
546+ def __eq__(self, value):
547+ return self.connections.__eq__(value)
548+
549+ def keys(self):
550+ return self.connections.keys()
551+
552+ def values(self):
553+ return self.connections.values()
554+
555+ def items(self):
556+ return self.connections.items()
557+
558+ def _reap_extra_connection(self, eventloop, conn):
559+ if not conn.in_use:
560+ self.disconnect(conn)
561+ return self.remove_connection(eventloop, conn)
562+ return self.clock.callLater(
563+ self._keepalive, self._reap_extra_connection, eventloop, conn
564+ )
565+
566+ def is_staged(self, eventloop):
567+ return eventloop in self.try_connections
568+
569+ def get_staged_connection(self, eventloop):
570+ return self.try_connections.get(eventloop)
571+
572+ def get_staged_connections(self):
573+ return self.try_connections
574+
575+ def stage_connection(self, eventloop, connection):
576+ self.try_connections[eventloop] = connection
577+
578+ @inlineCallbacks
579+ def scale_up_connections(self):
580+ for ev, ev_conns in self.connections.items():
581+ # pick first group with room for additional conns
582+ if len(ev_conns) < self._max_connections:
583+ # spawn an extra connection
584+ conn_to_clone = random.choice(list(ev_conns))
585+ conn = yield self.connect(ev, conn_to_clone.address)
586+ self.connections[ev].append(conn)
587+ self.clock.callLater(
588+ self._keepalive, self._reap_extra_connection, ev, conn
589+ )
590+ return
591+ raise exceptions.MaxConnectionsOpen()
592+
593+ def get_connection(self, eventloop):
594+ return random.choice(self.connections[eventloop])
595+
596+ def get_random_connection(self):
597+ return random.choice(self.get_all_connections())
598+
599+ def get_random_free_connection(self):
600+ free_conns = self.get_all_free_connections()
601+ if len(free_conns) == 0:
602+ # caller should create a new connection
603+ raise exceptions.AllConnectionsBusy()
604+ return random.choice(free_conns)
605+
606+ def get_all_connections(self):
607+ return [
608+ conn
609+ for conn_list in self.connections.values()
610+ for conn in conn_list
611+ ]
612+
613+ def get_all_free_connections(self):
614+ return [
615+ conn
616+ for conn_list in self.connections.values()
617+ for conn in conn_list
618+ if not conn.in_use
619+ ]
620+
621+ @inlineCallbacks
622+ def connect(self, eventloop, address):
623+ from provisioningserver.rpc.clusterservice import ClusterClient
624+
625+ # Force everything to use AF_INET6 sockets.
626+ endpoint = TCP6ClientEndpoint(self.clock, *address)
627+ protocol = ClusterClient(address, eventloop, self._service)
628+ conn = yield connectProtocol(endpoint, protocol)
629+ return conn
630+
631+ def disconnect(self, connection):
632+ return connection.transport.loseConnection()
633+
634+ @inlineCallbacks
635+ def add_connection(self, eventloop, connection):
636+ if self.is_staged(eventloop):
637+ del self.try_connections[eventloop]
638+ if eventloop not in self.connections:
639+ self.connections[eventloop] = []
640+
641+ self.connections[eventloop].append(connection)
642+
643+ # clone connection to equal num idle connections
644+ if self._max_idle_connections - 1 > 0:
645+ for _ in range(self._max_idle_connections - 1):
646+ extra_conn = yield self.connect(
647+ connection.eventloop, connection.address
648+ )
649+ self.connections[eventloop].append(extra_conn)
650+
651+ def remove_connection(self, eventloop, connection):
652+ if self.is_staged(eventloop):
653+ if self.try_connections[eventloop] is connection:
654+ del self.try_connections[eventloop]
655+ if connection in self.connections.get(eventloop, []):
656+ self.connections[eventloop].remove(connection)
657+ if len(self.connections[eventloop]) == 0:
658+ del self.connections[eventloop]
659diff --git a/src/provisioningserver/rpc/exceptions.py b/src/provisioningserver/rpc/exceptions.py
660index 7ee4f3f..136e471 100644
661--- a/src/provisioningserver/rpc/exceptions.py
662+++ b/src/provisioningserver/rpc/exceptions.py
663@@ -12,6 +12,14 @@ class NoConnectionsAvailable(Exception):
664 self.uuid = uuid
665
666
667+class AllConnectionsBusy(Exception):
668+ """The current connection pool is busy"""
669+
670+
671+class MaxConnectionsOpen(Exception):
672+ """The maxmimum number of connections are currently open"""
673+
674+
675 class NoSuchEventType(Exception):
676 """The specified event type was not found."""
677
678diff --git a/src/provisioningserver/rpc/testing/__init__.py b/src/provisioningserver/rpc/testing/__init__.py
679index ee4a9e2..1b2f94f 100644
680--- a/src/provisioningserver/rpc/testing/__init__.py
681+++ b/src/provisioningserver/rpc/testing/__init__.py
682@@ -262,7 +262,8 @@ class MockClusterToRegionRPCFixtureBase(fixtures.Fixture, metaclass=ABCMeta):
683 {
684 "eventloops": {
685 eventloop: [client.address]
686- for eventloop, client in connections
687+ for eventloop, clients in connections
688+ for client in clients
689 }
690 },
691 orig_url,
692diff --git a/src/provisioningserver/rpc/testing/doubles.py b/src/provisioningserver/rpc/testing/doubles.py
693index cb9f27f..0785859 100644
694--- a/src/provisioningserver/rpc/testing/doubles.py
695+++ b/src/provisioningserver/rpc/testing/doubles.py
696@@ -30,6 +30,7 @@ class FakeConnection:
697 ident = attr.ib(default=sentinel.ident)
698 hostCertificate = attr.ib(default=sentinel.hostCertificate)
699 peerCertificate = attr.ib(default=sentinel.peerCertificate)
700+ in_use = attr.ib(default=False)
701
702 def callRemote(self, cmd, **arguments):
703 return succeed(sentinel.response)
704@@ -48,6 +49,7 @@ class FakeConnectionToRegion:
705 address = attr.ib(default=(sentinel.host, sentinel.port))
706 hostCertificate = attr.ib(default=sentinel.hostCertificate)
707 peerCertificate = attr.ib(default=sentinel.peerCertificate)
708+ in_use = attr.ib(default=False)
709
710 def callRemote(self, cmd, **arguments):
711 return succeed(sentinel.response)
712@@ -56,6 +58,22 @@ class FakeConnectionToRegion:
713 verifyObject(IConnectionToRegion, FakeConnectionToRegion())
714
715
716+@attr.s(eq=False, order=False)
717+@implementer(IConnectionToRegion)
718+class FakeBusyConnectionToRegion:
719+ "A fake `IConnectionToRegion` that appears busy." ""
720+
721+ ident = attr.ib(default=sentinel.ident)
722+ localIdent = attr.ib(default=sentinel.localIdent)
723+ address = attr.ib(default=(sentinel.host, sentinel.port))
724+ hostCertificate = attr.ib(default=sentinel.hostCertificate)
725+ peerCertificate = attr.ib(default=sentinel.peerCertificate)
726+ in_use = attr.ib(default=True)
727+
728+ def callRemote(self, cmd, **arguments):
729+ return succeed(sentinel.response)
730+
731+
732 class StubOS(OperatingSystem):
733 """An :py:class:`OperatingSystem` subclass that has canned answers.
734
735diff --git a/src/provisioningserver/rpc/tests/test_clusterservice.py b/src/provisioningserver/rpc/tests/test_clusterservice.py
736index b50311d..6f3e4f9 100644
737--- a/src/provisioningserver/rpc/tests/test_clusterservice.py
738+++ b/src/provisioningserver/rpc/tests/test_clusterservice.py
739@@ -23,7 +23,6 @@ from testtools.matchers import (
740 Is,
741 IsInstance,
742 KeysEqual,
743- MatchesAll,
744 MatchesDict,
745 MatchesListwise,
746 MatchesStructure,
747@@ -32,7 +31,6 @@ from twisted import web
748 from twisted.application.internet import TimerService
749 from twisted.internet import error, reactor
750 from twisted.internet.defer import Deferred, fail, inlineCallbacks, succeed
751-from twisted.internet.endpoints import TCP6ClientEndpoint
752 from twisted.internet.error import ConnectionClosed
753 from twisted.internet.task import Clock
754 from twisted.internet.testing import StringTransportWithDisconnection
755@@ -117,7 +115,11 @@ from provisioningserver.rpc.testing import (
756 call_responder,
757 MockLiveClusterToRegionRPCFixture,
758 )
759-from provisioningserver.rpc.testing.doubles import DummyConnection, StubOS
760+from provisioningserver.rpc.testing.doubles import (
761+ FakeBusyConnectionToRegion,
762+ FakeConnection,
763+ StubOS,
764+)
765 from provisioningserver.security import set_shared_secret_on_filesystem
766 from provisioningserver.service_monitor import service_monitor
767 from provisioningserver.testing.config import ClusterConfigurationFixture
768@@ -444,8 +446,10 @@ class TestClusterProtocol_DescribePowerTypes(MAASTestCase):
769 )
770
771
772-def make_inert_client_service():
773- service = ClusterClientService(Clock())
774+def make_inert_client_service(max_idle_conns=1, max_conns=1, keepalive=1):
775+ service = ClusterClientService(
776+ Clock(), max_idle_conns, max_conns, keepalive
777+ )
778 # ClusterClientService's superclass, TimerService, creates a
779 # LoopingCall with now=True. We neuter it here to allow
780 # observation of the behaviour of _update_interval() for
781@@ -498,11 +502,11 @@ class TestClusterClientService(MAASTestCase):
782 )
783
784 # Fake some connections.
785- service.connections = {
786- ipv4client.eventloop: ipv4client,
787- ipv6client.eventloop: ipv6client,
788- ipv6mapped.eventloop: ipv6mapped,
789- hostclient.eventloop: hostclient,
790+ service.connections.connections = {
791+ ipv4client.eventloop: [ipv4client],
792+ ipv6client.eventloop: [ipv6client],
793+ ipv6mapped.eventloop: [ipv6mapped],
794+ hostclient.eventloop: [hostclient],
795 }
796
797 # Update the RPC state to the filesystem and info cache.
798@@ -515,7 +519,8 @@ class TestClusterClientService(MAASTestCase):
799 Equals(
800 {
801 client.address[0]
802- for _, client in service.connections.items()
803+ for _, clients in service.connections.items()
804+ for client in clients
805 }
806 ),
807 )
808@@ -999,9 +1004,9 @@ class TestClusterClientService(MAASTestCase):
809 def test_update_connections_initially(self):
810 service = ClusterClientService(Clock())
811 mock_client = Mock()
812- _make_connection = self.patch(service, "_make_connection")
813+ _make_connection = self.patch(service.connections, "connect")
814 _make_connection.side_effect = lambda *args: succeed(mock_client)
815- _drop_connection = self.patch(service, "_drop_connection")
816+ _drop_connection = self.patch(service.connections, "disconnect")
817
818 info = json.loads(self.example_rpc_info_view_response.decode("ascii"))
819 yield service._update_connections(info["eventloops"])
820@@ -1020,7 +1025,7 @@ class TestClusterClientService(MAASTestCase):
821 "host1:pid=2002": mock_client,
822 "host2:pid=3003": mock_client,
823 },
824- service.try_connections,
825+ service.connections.try_connections,
826 )
827
828 self.assertEqual([], _drop_connection.mock_calls)
829@@ -1038,7 +1043,7 @@ class TestClusterClientService(MAASTestCase):
830 for address in addresses:
831 client = Mock()
832 client.address = address
833- service.connections[eventloop] = client
834+ service.connections.connections[eventloop] = [client]
835
836 logger = self.useFixture(TwistedLoggerFixture())
837
838@@ -1055,7 +1060,7 @@ class TestClusterClientService(MAASTestCase):
839 @inlineCallbacks
840 def test_update_connections_connect_error_is_logged_tersely(self):
841 service = ClusterClientService(Clock())
842- _make_connection = self.patch(service, "_make_connection")
843+ _make_connection = self.patch(service.connections, "connect")
844 _make_connection.side_effect = error.ConnectionRefusedError()
845
846 logger = self.useFixture(TwistedLoggerFixture())
847@@ -1079,7 +1084,7 @@ class TestClusterClientService(MAASTestCase):
848 @inlineCallbacks
849 def test_update_connections_unknown_error_is_logged_with_stack(self):
850 service = ClusterClientService(Clock())
851- _make_connection = self.patch(service, "_make_connection")
852+ _make_connection = self.patch(service.connections, "connect")
853 _make_connection.side_effect = RuntimeError("Something went wrong.")
854
855 logger = self.useFixture(TwistedLoggerFixture())
856@@ -1106,8 +1111,8 @@ class TestClusterClientService(MAASTestCase):
857
858 def test_update_connections_when_there_are_existing_connections(self):
859 service = ClusterClientService(Clock())
860- _make_connection = self.patch(service, "_make_connection")
861- _drop_connection = self.patch(service, "_drop_connection")
862+ _connect = self.patch(service.connections, "connect")
863+ _disconnect = self.patch(service.connections, "disconnect")
864
865 host1client = ClusterClient(
866 ("::ffff:1.1.1.1", 1111), "host1:pid=1", service
867@@ -1120,9 +1125,9 @@ class TestClusterClientService(MAASTestCase):
868 )
869
870 # Fake some connections.
871- service.connections = {
872- host1client.eventloop: host1client,
873- host2client.eventloop: host2client,
874+ service.connections.connections = {
875+ host1client.eventloop: [host1client],
876+ host2client.eventloop: [host2client],
877 }
878
879 # Request a new set of connections that overlaps with the
880@@ -1137,10 +1142,10 @@ class TestClusterClientService(MAASTestCase):
881 # A connection is made for host3's event-loop, and the
882 # connection to host2's event-loop is dropped.
883 self.assertThat(
884- _make_connection,
885+ _connect,
886 MockCalledOnceWith(host3client.eventloop, host3client.address),
887 )
888- self.assertThat(_drop_connection, MockCalledWith(host2client))
889+ self.assertThat(_disconnect, MockCalledWith(host2client))
890
891 @inlineCallbacks
892 def test_update_only_updates_interval_when_eventloops_are_unknown(self):
893@@ -1175,57 +1180,15 @@ class TestClusterClientService(MAASTestCase):
894 logger.dump(),
895 )
896
897- def test_make_connection(self):
898- service = ClusterClientService(Clock())
899- connectProtocol = self.patch(clusterservice, "connectProtocol")
900- service._make_connection("an-event-loop", ("a.example.com", 1111))
901- self.assertThat(connectProtocol.call_args_list, HasLength(1))
902- self.assertThat(
903- connectProtocol.call_args_list[0][0],
904- MatchesListwise(
905- (
906- # First argument is an IPv4 TCP client endpoint
907- # specification.
908- MatchesAll(
909- IsInstance(TCP6ClientEndpoint),
910- MatchesStructure.byEquality(
911- _reactor=service.clock,
912- _host="a.example.com",
913- _port=1111,
914- ),
915- ),
916- # Second argument is a ClusterClient instance, the
917- # protocol to use for the connection.
918- MatchesAll(
919- IsInstance(clusterservice.ClusterClient),
920- MatchesStructure.byEquality(
921- address=("a.example.com", 1111),
922- eventloop="an-event-loop",
923- service=service,
924- ),
925- ),
926- )
927- ),
928- )
929-
930- def test_drop_connection(self):
931- connection = Mock()
932- service = make_inert_client_service()
933- service.startService()
934- service._drop_connection(connection)
935- self.assertThat(
936- connection.transport.loseConnection, MockCalledOnceWith()
937- )
938-
939 def test_add_connection_removes_from_try_connections(self):
940 service = make_inert_client_service()
941 service.startService()
942 endpoint = Mock()
943 connection = Mock()
944 connection.address = (":::ffff", 2222)
945- service.try_connections[endpoint] = connection
946+ service.connections.try_connections[endpoint] = connection  # pending (staged) connections now live on the ConnectionPool
947 service.add_connection(endpoint, connection)
948- self.assertThat(service.try_connections, Equals({}))
949+ self.assertThat(service.connections.try_connections, Equals({}))  # add_connection must clear the staged entry
950
951 def test_add_connection_adds_to_connections(self):
952 service = make_inert_client_service()
953@@ -1234,7 +1197,7 @@ class TestClusterClientService(MAASTestCase):
954 connection = Mock()
955 connection.address = (":::ffff", 2222)
956 service.add_connection(endpoint, connection)
957- self.assertThat(service.connections, Equals({endpoint: connection}))
958+ self.assertEqual(service.connections, {endpoint: [connection]})
959
960 def test_add_connection_calls__update_saved_rpc_info_state(self):
961 service = make_inert_client_service()
962@@ -1248,21 +1211,45 @@ class TestClusterClientService(MAASTestCase):
963 service._update_saved_rpc_info_state, MockCalledOnceWith()
964 )
965
966+ def test_add_connection_creates_max_idle_connections(self):
967+ service = make_inert_client_service(max_idle_conns=2)
968+ service.startService()
969+ endpoint = Mock()
970+ connection = Mock()
971+ connection.address = (":::ffff", 2222)
972+ connection2 = Mock()
973+ connection2.address = (":::ffff", 2222)  # the pool-created idle connection needs its own address
974+ self.patch(service.connections, "connect").return_value = succeed(  # stub the extra idle connection
975+ connection2
976+ )
977+ self.patch_autospec(service, "_update_saved_rpc_info_state")
978+ service.add_connection(endpoint, connection)
979+ self.assertEqual(
980+ len(
981+ [
982+ conn
983+ for conns in service.connections.values()
984+ for conn in conns
985+ ]
986+ ),
987+ service.connections._max_idle_connections,
988+ )
989+
990 def test_remove_connection_removes_from_try_connections(self):
991 service = make_inert_client_service()
992 service.startService()
993 endpoint = Mock()
994 connection = Mock()
995- service.try_connections[endpoint] = connection
996+ service.connections.try_connections[endpoint] = connection  # stage the connection on the ConnectionPool
997 service.remove_connection(endpoint, connection)
998- self.assertThat(service.try_connections, Equals({}))
999+ self.assertEqual(service.connections.try_connections, {})  # remove_connection must also drop the staged entry
1000
1001 def test_remove_connection_removes_from_connections(self):
1002 service = make_inert_client_service()
1003 service.startService()
1004 endpoint = Mock()
1005 connection = Mock()
1006- service.connections[endpoint] = connection
1007+ service.connections[endpoint] = [connection]  # the pool stores a list of connections per event-loop
1008 service.remove_connection(endpoint, connection)
1009 self.assertThat(service.connections, Equals({}))
1010
1011@@ -1271,7 +1258,7 @@ class TestClusterClientService(MAASTestCase):
1012 service.startService()
1013 endpoint = Mock()
1014 connection = Mock()
1015- service.connections[endpoint] = connection
1016+ service.connections[endpoint] = {connection}
1017 service.remove_connection(endpoint, connection)
1018 self.assertEqual(service.step, service.INTERVAL_LOW)
1019
1020@@ -1280,7 +1267,7 @@ class TestClusterClientService(MAASTestCase):
1021 service.startService()
1022 endpoint = Mock()
1023 connection = Mock()
1024- service.connections[endpoint] = connection
1025+ service.connections[endpoint] = {connection}
1026
1027 # Enable both dhcpd and dhcpd6.
1028 service_monitor.getServiceByName("dhcpd").on()
1029@@ -1294,45 +1281,96 @@ class TestClusterClientService(MAASTestCase):
1030
1031 def test_getClient(self):
1032 service = ClusterClientService(Clock())
1033- service.connections = {
1034- sentinel.eventloop01: DummyConnection(),
1035- sentinel.eventloop02: DummyConnection(),
1036- sentinel.eventloop03: DummyConnection(),
1037+ service.connections.connections = {  # event-loop -> list of connections
1038+ sentinel.eventloop01: [FakeConnection()],
1039+ sentinel.eventloop02: [FakeConnection()],
1040+ sentinel.eventloop03: [FakeConnection()],
1041 }
1042 self.assertIn(
1043 service.getClient(),
1044- {common.Client(conn) for conn in service.connections.values()},
1045+ {
1046+ common.Client(conn)
1047+ for conns in service.connections.values()  # flatten the per-event-loop lists
1048+ for conn in conns
1049+ },
1050 )
1051
1052 def test_getClient_when_there_are_no_connections(self):
1053 service = ClusterClientService(Clock())
1054- service.connections = {}
1055+ service.connections.connections = {}  # empty the pool's underlying event-loop mapping
1056 self.assertRaises(exceptions.NoConnectionsAvailable, service.getClient)
1057
1058 @inlineCallbacks
1059+ def test_getClientNow_scales_connections_when_busy(self):
1060+ service = ClusterClientService(Clock(), max_conns=2)  # max_conns=2 leaves room to open one more connection
1061+ service.connections.connections = {  # FakeBusyConnectionToRegion is presumably always in use -- confirm in testing.doubles
1062+ sentinel.eventloop01: [FakeBusyConnectionToRegion()],
1063+ sentinel.eventloop02: [FakeBusyConnectionToRegion()],
1064+ sentinel.eventloop03: [FakeBusyConnectionToRegion()],
1065+ }
1066+ self.patch(service.connections, "connect").return_value = succeed(  # stub the new connection instead of dialling out
1067+ FakeConnection()
1068+ )
1069+ original_conns = [  # snapshot before getClientNow() may grow the pool
1070+ conn for conns in service.connections.values() for conn in conns
1071+ ]
1072+ new_client = yield service.getClientNow()
1073+ new_conn = new_client._conn
1074+ self.assertIsNotNone(new_conn)
1075+ self.assertNotIn(new_conn, original_conns)
1076+ self.assertIn(  # the freshly created connection must be registered in the pool
1077+ new_conn,
1078+ [conn for conns in service.connections.values() for conn in conns],
1079+ )
1080+
1081+ @inlineCallbacks
1082+ def test_getClientNow_returns_an_existing_connection_when_max_are_open(
1083+ self,
1084+ ):
1085+ service = ClusterClientService(Clock(), max_conns=1)
1086+ service.connections.connections = {
1087+ sentinel.eventloop01: [FakeBusyConnectionToRegion()],
1088+ sentinel.eventloop02: [FakeBusyConnectionToRegion()],
1089+ sentinel.eventloop03: [FakeBusyConnectionToRegion()],
1090+ }
1091+ self.patch(service.connections, "connect").return_value = succeed(  # must go unused: the pool is already at max_conns
1092+ FakeConnection()
1093+ )
1094+ original_conns = [
1095+ conn for conns in service.connections.values() for conn in conns
1096+ ]
1097+ new_client = yield service.getClientNow()
1098+ new_conn = new_client._conn
1099+ self.assertIsNotNone(new_conn)
1100+ self.assertIn(new_conn, original_conns)
1101+
1102+ @inlineCallbacks
1103 def test_getClientNow_returns_current_connection(self):
1104 service = ClusterClientService(Clock())
1105- service.connections = {
1106- sentinel.eventloop01: DummyConnection(),
1107- sentinel.eventloop02: DummyConnection(),
1108- sentinel.eventloop03: DummyConnection(),
1109+ service.connections.connections = {  # one connection per event-loop
1110+ sentinel.eventloop01: [FakeConnection()],
1111+ sentinel.eventloop02: [FakeConnection()],
1112+ sentinel.eventloop03: [FakeConnection()],
1113 }
1114 client = yield service.getClientNow()
1115 self.assertIn(
1116 client,
1117- {common.Client(conn) for conn in service.connections.values()},
1118+ [
1119+ common.Client(conn)
1120+ for conns in service.connections.values()  # flatten the per-event-loop lists
1121+ for conn in conns
1122+ ],
1123 )
1124
1125 @inlineCallbacks
1126 def test_getClientNow_calls__tryUpdate_when_there_are_no_connections(self):
1127 service = ClusterClientService(Clock())
1128- service.connections = {}
1129
1130 def addConnections():
1131- service.connections = {
1132- sentinel.eventloop01: DummyConnection(),
1133- sentinel.eventloop02: DummyConnection(),
1134- sentinel.eventloop03: DummyConnection(),
1135+ service.connections.connections = {
1136+ sentinel.eventloop01: [FakeConnection()],
1137+ sentinel.eventloop02: [FakeConnection()],
1138+ sentinel.eventloop03: [FakeConnection()],
1139 }
1140 return succeed(None)
1141
1142@@ -1340,12 +1378,15 @@ class TestClusterClientService(MAASTestCase):
1143 client = yield service.getClientNow()
1144 self.assertIn(
1145 client,
1146- {common.Client(conn) for conn in service.connections.values()},
1147+ {
1148+ common.Client(conn)
1149+ for conns in service.connections.values()
1150+ for conn in conns
1151+ },
1152 )
1153
1154 def test_getClientNow_raises_exception_when_no_clients(self):
1155 service = ClusterClientService(Clock())
1156- service.connections = {}
1157
1158 self.patch(service, "_tryUpdate").return_value = succeed(None)
1159 d = service.getClientNow()
1160@@ -1383,17 +1424,16 @@ class TestClusterClientService(MAASTestCase):
1161 def test_getAllClients(self):
1162 service = ClusterClientService(Clock())
1163 uuid1 = factory.make_UUID()
1164- c1 = DummyConnection()
1165- service.connections[uuid1] = c1
1166+ c1 = FakeConnection()
1167+ service.connections[uuid1] = [c1]  # the pool stores a list of connections per event-loop
1168 uuid2 = factory.make_UUID()
1169- c2 = DummyConnection()
1170- service.connections[uuid2] = c2
1171+ c2 = FakeConnection()
1172+ service.connections[uuid2] = [c2]
1173 clients = service.getAllClients()
1174 self.assertEqual(clients, [common.Client(c1), common.Client(c2)])
1175
1176 def test_getAllClients_when_there_are_no_connections(self):
1177 service = ClusterClientService(Clock())
1178- service.connections = {}
1179 self.assertThat(service.getAllClients(), Equals([]))
1180
1181
1182@@ -1546,7 +1586,7 @@ class TestClusterClient(MAASTestCase):
1183
1184 def test_connecting(self):
1185 client = self.make_running_client()
1186- client.service.try_connections[client.eventloop] = client
1187+ client.service.connections.try_connections[client.eventloop] = client
1188 self.patch_authenticate_for_success(client)
1189 self.patch_register_for_success(client)
1190 self.assertEqual(client.service.connections, {})
1191@@ -1560,16 +1600,19 @@ class TestClusterClient(MAASTestCase):
1192 self.assertTrue(extract_result(wait_for_authenticated))
1193 # ready has been set with the name of the event-loop.
1194 self.assertEqual(client.eventloop, extract_result(wait_for_ready))
1195- self.assertEqual(client.service.try_connections, {})
1196+ self.assertEqual(len(client.service.connections.try_connections), 0)
1197 self.assertEqual(
1198- client.service.connections, {client.eventloop: client}
1199+ client.service.connections.connections,
1200+ {client.eventloop: [client]},
1201 )
1202
1203 def test_disconnects_when_there_is_an_existing_connection(self):
1204 client = self.make_running_client()
1205
1206 # Pretend that a connection already exists for this address.
1207- client.service.connections[client.eventloop] = sentinel.connection
1208+ client.service.connections.connections[client.eventloop] = [
1209+ sentinel.connection
1210+ ]
1211
1212 # Connect via an in-memory transport.
1213 transport = StringTransportWithDisconnection()
1214@@ -1586,7 +1629,8 @@ class TestClusterClient(MAASTestCase):
1215 # The connections list is unchanged because the new connection
1216 # immediately disconnects.
1217 self.assertEqual(
1218- client.service.connections, {client.eventloop: sentinel.connection}
1219+ client.service.connections,
1220+ {client.eventloop: [sentinel.connection]},
1221 )
1222 self.assertFalse(client.connected)
1223 self.assertIsNone(client.transport)
1224@@ -1631,7 +1675,7 @@ class TestClusterClient(MAASTestCase):
1225
1226 # The connections list is unchanged because the new connection
1227 # immediately disconnects.
1228- self.assertEqual(client.service.connections, {})
1229+ self.assertEqual(client.service.connections.connections, {})
1230 self.assertFalse(client.connected)
1231
1232 def test_disconnects_when_authentication_errors(self):
1233diff --git a/src/provisioningserver/rpc/tests/test_connectionpool.py b/src/provisioningserver/rpc/tests/test_connectionpool.py
1234new file mode 100644
1235index 0000000..692d5e6
1236--- /dev/null
1237+++ b/src/provisioningserver/rpc/tests/test_connectionpool.py
1238@@ -0,0 +1,280 @@
1239+# Copyright 2022 Canonical Ltd. This software is licensed under the
1240+# GNU Affero General Public License version 3 (see the file LICENSE).
1241+
1242+from unittest.mock import Mock
1243+
1244+from twisted.internet.defer import inlineCallbacks, succeed
1245+from twisted.internet.endpoints import TCP6ClientEndpoint
1246+from twisted.internet.task import Clock
1247+
1248+from maastesting import get_testing_timeout
1249+from maastesting.testcase import MAASTestCase, MAASTwistedRunTest
1250+from maastesting.twisted import extract_result
1251+from provisioningserver.rpc import connectionpool as connectionpoolModule
1252+from provisioningserver.rpc import exceptions
1253+from provisioningserver.rpc.clusterservice import ClusterClient
1254+from provisioningserver.rpc.connectionpool import ConnectionPool
1255+
1256+TIMEOUT = get_testing_timeout()
1257+
1258+
1259+class TestConnectionPool(MAASTestCase):
1260+
1261+ run_tests_with = MAASTwistedRunTest.make_factory(timeout=TIMEOUT)
1262+
1263+ def test_setitem_sets_item_in_connections(self):
1264+ cp = ConnectionPool(Clock(), Mock())
1265+ key = Mock()
1266+ val = Mock()
1267+ cp[key] = val
1268+ self.assertEqual(cp.connections, {key: val})
1269+
1270+ def test_getitem_gets_item_in_connections(self):
1271+ cp = ConnectionPool(Clock(), Mock())
1272+ key = Mock()
1273+ val = Mock()
1274+ cp[key] = val
1275+ self.assertEqual(cp.connections[key], cp[key])
1276+
1277+ def test_len_gets_length_of_connections(self):
1278+ cp = ConnectionPool(Clock(), Mock())
1279+ key = Mock()
1280+ val = Mock()
1281+ cp[key] = [val]
1282+ self.assertEqual(len(cp), len(cp.get_all_connections()))
1283+
1284+ def test_delitem_removes_item_from_connections(self):
1285+ cp = ConnectionPool(Clock(), Mock())
1286+ key = Mock()
1287+ val = Mock()
1288+ cp[key] = val
1289+ self.assertEqual(cp.connections[key], val)
1290+ del cp[key]
1291+ self.assertIsNone(cp.connections.get(key))
1292+
1293+ def test_contains_returns_if_key_in_connections(self):
1294+ cp = ConnectionPool(Clock(), Mock())
1295+ key = Mock()
1296+ val = Mock()
1297+ cp[key] = val
1298+ self.assertEqual(key in cp, key in cp.connections)
1299+
1300+ def test_compare_ConnectionPool_equal_to_compare_connections(self):
1301+ cp = ConnectionPool(Clock(), Mock())
1302+ self.assertEqual(cp, cp.connections)
1303+ self.assertEqual(cp, {})
1304+
1305+ def test__reap_extra_connection_reaps_a_non_busy_connection(self):
1306+ cp = ConnectionPool(Clock(), Mock())
1307+ eventloop = Mock()
1308+ connection = Mock()
1309+ connection.in_use = False
1310+ cp[eventloop] = [connection]
1311+ disconnect = self.patch(cp, "disconnect")
1312+ cp._reap_extra_connection(eventloop, connection)
1313+ self.assertEqual(len(cp), 0)
1314+ disconnect.assert_called_once_with(connection)
1315+
1316+ def test__reap_extra_connection_waits_for_a_busy_connection(self):
1317+ clock = Clock()
1318+ cp = ConnectionPool(clock, Mock())
1319+ eventloop = Mock()
1320+ connection = Mock()
1321+ connection.in_use = True
1322+ cp[eventloop] = [connection]
1323+ self.patch(cp, "disconnect")
1324+ cp._reap_extra_connection(eventloop, connection)
1325+ self.assertIn(eventloop, clock.calls[0].args)
1326+ self.assertIn(connection, clock.calls[0].args)
1327+ self.assertEqual(
1328+ "_reap_extra_connection", clock.calls[0].func.__name__
1329+ )
1330+ self.assertEqual(cp._keepalive, clock.calls[0].time)
1331+
1332+ def test_is_staged(self):
1333+ cp = ConnectionPool(Clock(), Mock())
1334+ eventloop1 = Mock()
1335+ eventloop2 = Mock()
1336+ cp.try_connections[eventloop1] = Mock()
1337+ self.assertTrue(cp.is_staged(eventloop1))
1338+ self.assertFalse(cp.is_staged(eventloop2))
1339+
1340+ def test_get_staged_connection(self):
1341+ cp = ConnectionPool(Clock(), Mock())
1342+ eventloop = Mock()
1343+ connection = Mock()
1344+ cp.try_connections[eventloop] = connection
1345+ self.assertEqual(cp.get_staged_connection(eventloop), connection)
1346+
1347+ def test_get_staged_connections(self):
1348+ cp = ConnectionPool(Clock(), Mock())
1349+ eventloop = Mock()
1350+ connection = Mock()
1351+ cp.try_connections[eventloop] = connection
1352+ self.assertEqual(cp.get_staged_connections(), {eventloop: connection})
1353+
1354+ def test_scale_up_connections_adds_a_connection(self):
1355+ cp = ConnectionPool(Clock(), Mock(), max_conns=2)
1356+ eventloop = Mock()
1357+ connection1 = Mock()
1358+ connection2 = Mock()
1359+ connect = self.patch(cp, "connect")
1360+ connect.return_value = succeed(connection2)
1361+ cp[eventloop] = [connection1]
1362+ cp.scale_up_connections()
1363+ self.assertCountEqual(cp[eventloop], [connection1, connection2])
1364+
1365+ def test_scale_up_connections_raises_MaxConnectionsOpen_when_cannot_create_another(
1366+ self,
1367+ ):
1368+ cp = ConnectionPool(Clock(), Mock())
1369+ eventloop = Mock()
1370+ connection1 = Mock()
1371+ connection2 = Mock()
1372+ connect = self.patch(cp, "connect")
1373+ connect.return_value = succeed(connection2)
1374+ cp[eventloop] = [connection1]
1375+ self.assertRaises(
1376+ exceptions.MaxConnectionsOpen,
1377+ extract_result,
1378+ cp.scale_up_connections(),
1379+ )
1380+
1381+ def test_get_connection(self):
1382+ cp = ConnectionPool(Clock(), Mock(), max_idle_conns=2, max_conns=2)
1383+ eventloops = [Mock() for _ in range(3)]
1384+ cp.connections = {
1385+ eventloop: [Mock() for _ in range(2)] for eventloop in eventloops
1386+ }
1387+ self.assertIn(cp.get_connection(eventloops[0]), cp[eventloops[0]])
1388+
1389+ def test_get_random_connection(self):
1390+ cp = ConnectionPool(Clock(), Mock(), max_idle_conns=2, max_conns=2)
1391+ eventloops = [Mock() for _ in range(3)]
1392+ cp.connections = {
1393+ eventloop: [Mock() for _ in range(2)] for eventloop in eventloops
1394+ }
1395+ self.assertIn(
1396+ cp.get_random_connection(),  # exercise the random picker, not get_connection()
1397+ [conn for conn_list in cp.values() for conn in conn_list],
1398+ )
1399+
1400+ def test_get_random_free_connection_returns_a_free_connection(self):
1401+ cp = ConnectionPool(Clock(), Mock())
1402+ eventloops = [Mock() for _ in range(3)]
1403+
1404+ def _create_conn(in_use):
1405+ conn = Mock()
1406+ conn.in_use = in_use
1407+ return conn
1408+
1409+ cp.connections = {
1410+ eventloops[0]: [_create_conn(True)],
1411+ eventloops[1]: [_create_conn(False)],
1412+ eventloops[2]: [_create_conn(True)],
1413+ }
1414+ conn = cp.get_random_free_connection()
1415+ self.assertIn(conn, cp[eventloops[1]])
1416+
1417+ def test_get_random_free_connection_raises_AllConnectionsBusy_when_there_are_no_free_connections(
1418+ self,
1419+ ):
1420+ cp = ConnectionPool(Clock(), Mock())
1421+ eventloops = [Mock() for _ in range(3)]
1422+
1423+ def _create_conn(in_use):
1424+ conn = Mock()
1425+ conn.in_use = in_use
1426+ return conn
1427+
1428+ cp.connections = {
1429+ eventloops[0]: [_create_conn(True)],
1430+ eventloops[1]: [_create_conn(True)],
1431+ eventloops[2]: [_create_conn(True)],
1432+ }
1433+
1434+ self.assertRaises(
1435+ exceptions.AllConnectionsBusy, cp.get_random_free_connection
1436+ )
1437+
1438+ def test_get_all_connections(self):
1439+ cp = ConnectionPool(Clock(), Mock())
1440+ eventloops = [Mock() for _ in range(3)]
1441+ cp.connections = {
1442+ eventloops[0]: [Mock()],
1443+ eventloops[1]: [Mock()],
1444+ eventloops[2]: [Mock()],
1445+ }
1446+
1447+ self.assertCountEqual(
1448+ cp.get_all_connections(),
1449+ [conn for conn_list in cp.values() for conn in conn_list],
1450+ )
1451+
1452+ def test_get_all_free_connections(self):
1453+ cp = ConnectionPool(Clock(), Mock(), max_conns=2)
1454+ eventloops = [Mock() for _ in range(3)]
1455+
1456+ def _create_conn(in_use):
1457+ conn = Mock()
1458+ conn.in_use = in_use
1459+ return conn
1460+
1461+ cp.connections = {
1462+ eventloops[0]: [_create_conn(True), _create_conn(False)],
1463+ eventloops[1]: [_create_conn(True)],
1464+ eventloops[2]: [_create_conn(False)],
1465+ }
1466+
1467+ self.assertCountEqual(
1468+ cp.get_all_free_connections(),
1469+ [
1470+ conn
1471+ for conn_list in cp.values()
1472+ for conn in conn_list
1473+ if not conn.in_use
1474+ ],
1475+ )
1476+
1477+ @inlineCallbacks
1478+ def test_connect(self):
1479+ clock = Clock()
1480+ connection = Mock()
1481+ service = Mock()
1482+ cp = ConnectionPool(clock, service)
1483+ connectProtocol = self.patch(connectionpoolModule, "connectProtocol")
1484+ connectProtocol.return_value = connection
1485+ result = yield cp.connect("an-event-loop", ("a.example.com", 1111))
1486+ self.assertEqual(len(connectProtocol.call_args_list), 1)
1487+ endpoint, protocol = connectProtocol.call_args[0]  # inspect the real call; Mock.called_once_with is not an assertion
1488+ self.assertIsInstance(endpoint, TCP6ClientEndpoint)
1489+ self.assertIs(endpoint._reactor, clock)
1490+ self.assertEqual((endpoint._host, endpoint._port), ("a.example.com", 1111))
1491+ self.assertIsInstance(protocol, ClusterClient)
1492+ self.assertEqual(protocol.address, ("a.example.com", 1111))
1493+ self.assertEqual(protocol.eventloop, "an-event-loop")
1494+ self.assertIs(protocol.service, service)
1495+ self.assertEqual(result, connection)
1496+
1497+ def test_drop_connection(self):
1498+ connection = Mock()
1499+ cp = ConnectionPool(Clock(), Mock())
1500+ cp.disconnect(connection)
1501+ connection.transport.loseConnection.assert_called_once_with()
1502+
1503+ @inlineCallbacks
1504+ def test_add_connection_adds_the_staged_connection(self):
1505+ eventloop = Mock()
1506+ connection = Mock()
1507+ cp = ConnectionPool(Clock(), Mock())
1508+ cp.try_connections = {eventloop: connection}
1509+ yield cp.add_connection(eventloop, connection)
1510+ self.assertIn(connection, cp.get_all_connections())
1511+
1512+ def test_remove_connection_removes_connection_from_pool(self):
1513+ eventloop = Mock()
1514+ connection = Mock()
1515+ cp = ConnectionPool(Clock(), Mock())
1516+ cp.connections[eventloop] = [connection]
1517+ cp.remove_connection(eventloop, connection)
1518+ self.assertEqual(cp.connections, {})

Subscribers

People subscribed via source and target branches