Skip to content

Commit 44952da

Browse files
committed
feat(node): add round-robin host shuffling capability
- add `round_robin_hosts` config option to `ConfigDict` typed dict - implement random shuffling of healthy nodes in `NodeManager.get_node()` - add test to verify shuffling behavior with round-robin enabled
1 parent b0d477b commit 44952da

File tree

3 files changed

+94
-6
lines changed

3 files changed

+94
-6
lines changed

src/typesense/configuration.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ class ConfigDict(typing.TypedDict):
8080
dictionaries or URLs that represent the read replica nodes.
8181
8282
connection_timeout_seconds (float): The connection timeout in seconds.
83+
84+
round_robin_hosts (bool): Whether or not to shuffle hosts between requests
8385
"""
8486

8587
nodes: typing.List[typing.Union[str, NodeConfigDict]]
@@ -96,6 +98,7 @@ class ConfigDict(typing.TypedDict):
9698
typing.List[typing.Union[str, NodeConfigDict]]
9799
] # deprecated
98100
connection_timeout_seconds: typing.NotRequired[float]
101+
round_robin_hosts: typing.NotRequired[bool]
99102

100103

101104
class Node:
@@ -184,6 +187,7 @@ class Configuration:
184187
retry_interval_seconds (float): The interval in seconds between retries.
185188
healthcheck_interval_seconds (int): The interval in seconds between health checks.
186189
verify (bool): Whether to verify the SSL certificate.
190+
round_robin_hosts (bool): Whether or not to shuffle hosts between requests
187191
"""
188192

189193
def __init__(
@@ -219,6 +223,7 @@ def __init__(
219223
60,
220224
)
221225
self.verify = config_dict.get("verify", True)
226+
self.round_robin_hosts = config_dict.get("round_robin_hosts", False)
222227
self.additional_headers = config_dict.get("additional_headers", {})
223228

224229
def _handle_nearest_node(

src/typesense/node_manager.py

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@
3131

3232
import copy
3333
import time
34+
import random
35+
import sys
36+
if sys.version_info >= (3, 11):
37+
import typing
38+
else:
39+
import typing_extensions as typing
3440

3541
from typesense.configuration import Configuration, Node
3642
from typesense.logger import logger
@@ -71,12 +77,68 @@ def get_node(self) -> Node:
7177
Returns:
7278
Node: The selected node for the next operation.
7379
"""
74-
if self.config.nearest_node:
75-
if self.config.nearest_node.healthy or self._is_due_for_health_check(
76-
self.config.nearest_node,
77-
):
78-
return self.config.nearest_node
80+
if self._should_use_nearest_node():
81+
return self.config.nearest_node
82+
83+
healthy_nodes = self._get_healthy_nodes()
84+
85+
if not healthy_nodes:
86+
logger.debug("No healthy nodes were found. Returning the next node.")
87+
return self.nodes[self.node_index]
88+
89+
if self.config.round_robin_hosts:
90+
return self._get_shuffled_node(healthy_nodes)
91+
92+
return self._get_next_round_robin_node()
93+
94+
def _should_use_nearest_node(self) -> bool:
95+
"""
96+
Check if we should use the nearest node.
97+
98+
Returns:
99+
bool: True if nearest node should be used, False otherwise.
100+
"""
101+
return bool(
102+
self.config.nearest_node
103+
and (
104+
self.config.nearest_node.healthy
105+
or self._is_due_for_health_check(self.config.nearest_node)
106+
)
107+
)
108+
109+
def _get_healthy_nodes(self) -> typing.List[Node]:
110+
"""
111+
Get a list of all healthy nodes.
112+
113+
Returns:
114+
List[Node]: List of healthy nodes.
115+
"""
116+
return [
117+
node for node in self.nodes
118+
if node.healthy or self._is_due_for_health_check(node)
119+
]
120+
121+
def _get_shuffled_node(self, healthy_nodes: typing.List[Node]) -> Node:
122+
"""
123+
Get a randomly shuffled node from the list of healthy nodes.
79124
125+
Args:
126+
healthy_nodes (List[Node]): List of healthy nodes to choose from.
127+
128+
Returns:
129+
Node: A randomly selected healthy node.
130+
"""
131+
random.shuffle(healthy_nodes)
132+
self.node_index = (self.node_index + 1) % len(self.nodes)
133+
return healthy_nodes[0]
134+
135+
def _get_next_round_robin_node(self) -> Node:
136+
"""
137+
Get the next node using standard round-robin selection.
138+
139+
Returns:
140+
Node: The next node in the round-robin sequence.
141+
"""
80142
node_index = 0
81143
while node_index < len(self.nodes):
82144
node_index += 1
@@ -85,7 +147,6 @@ def get_node(self) -> Node:
85147
if node.healthy or self._is_due_for_health_check(node):
86148
return node
87149

88-
logger.debug("No healthy nodes were found. Returning the next node.")
89150
return self.nodes[self.node_index]
90151

91152
def set_node_health(self, node: Node, is_healthy: bool) -> None:

tests/api_call_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,28 @@ def test_get_node_round_robin_selection(
8080
assert_match_object(node3, fake_api_call.config.nodes[2])
8181

8282

83+
def test_get_node_round_robin_shuffle(
84+
fake_api_call: ApiCall,
85+
mocker: MockerFixture,
86+
) -> None:
87+
"""Test that it shuffles healthy nodes when round_robin_hosts is true."""
88+
fake_api_call.config.nearest_node = None
89+
fake_api_call.config.round_robin_hosts = True
90+
mocker.patch("time.time", return_value=100)
91+
92+
shuffle_mock = mocker.patch("random.shuffle")
93+
94+
for _ in range(3):
95+
fake_api_call.node_manager.get_node()
96+
97+
assert shuffle_mock.call_count == 3
98+
99+
for call in shuffle_mock.call_args_list:
100+
args = call[0][0]
101+
assert isinstance(args, list)
102+
assert all(node.healthy for node in args)
103+
104+
83105
def test_get_exception() -> None:
84106
"""Test that it correctly returns the exception class for a given status code."""
85107
assert RequestHandler._get_exception(0) == exceptions.HTTPStatus0Error

0 commit comments

Comments
 (0)