Skip to content

Commit 3ef1fda

Browse files
committed
feat: Connection pooling
Signed-off-by: Anush008 <anushshetty90@gmail.com>
1 parent af30bab commit 3ef1fda

File tree

2 files changed

+137
-13
lines changed

2 files changed

+137
-13
lines changed

src/main/java/io/qdrant/client/QdrantClient.java

Lines changed: 101 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,11 @@
116116
import io.qdrant.client.grpc.SnapshotsService.ListSnapshotsResponse;
117117
import io.qdrant.client.grpc.SnapshotsService.SnapshotDescription;
118118
import java.time.Duration;
119+
import java.util.ArrayList;
119120
import java.util.List;
120121
import java.util.Map;
121122
import java.util.concurrent.TimeUnit;
123+
import java.util.concurrent.atomic.AtomicInteger;
122124
import java.util.stream.Collectors;
123125
import javax.annotation.Nullable;
124126
import org.slf4j.Logger;
@@ -127,15 +129,83 @@
127129
/** Client for the Qdrant vector database. */
128130
public class QdrantClient implements AutoCloseable {
129131
private static final Logger logger = LoggerFactory.getLogger(QdrantClient.class);
130-
private final QdrantGrpcClient grpcClient;
132+
private final List<QdrantGrpcClient> grpcClients;
133+
private final AtomicInteger nextClientIndex = new AtomicInteger(0);
131134

132135
/**
133136
* Creates a new instance of {@link QdrantClient}
134137
*
135138
* @param grpcClient The low-level gRPC client to use.
136139
*/
137140
public QdrantClient(QdrantGrpcClient grpcClient) {
138-
this.grpcClient = grpcClient;
141+
this.grpcClients = new ArrayList<>(1);
142+
this.grpcClients.add(grpcClient);
143+
}
144+
145+
/**
146+
* Creates a new instance of {@link QdrantClient} with connection pooling. Creates multiple
147+
* independent gRPC connections with the same configuration.
148+
*
149+
* @param host The host to connect to.
150+
* @param port The port to connect to.
151+
* @param useTransportLayerSecurity Whether to use TLS.
152+
* @param poolSize The number of gRPC clients to create in the pool. Must be at least 1.
153+
* @param apiKey The API key for authentication.
154+
* @param timeout The default timeout for requests.
155+
*/
156+
public QdrantClient(
157+
String host,
158+
int port,
159+
boolean useTransportLayerSecurity,
160+
int poolSize,
161+
@Nullable String apiKey,
162+
@Nullable Duration timeout) {
163+
if (poolSize <= 0) {
164+
throw new IllegalArgumentException("Pool size must be at least 1");
165+
}
166+
167+
this.grpcClients = new ArrayList<>(poolSize);
168+
169+
// Create clients for the pool - each with its own independent connection
170+
for (int i = 0; i < poolSize; i++) {
171+
// For the first client, check compatibility. For others, skip to avoid redundant checks
172+
boolean checkCompatibility = (i == 0);
173+
QdrantGrpcClient.Builder builder =
174+
QdrantGrpcClient.newBuilder(host, port, useTransportLayerSecurity, checkCompatibility);
175+
176+
if (apiKey != null) {
177+
builder.withApiKey(apiKey);
178+
}
179+
if (timeout != null) {
180+
builder.withTimeout(timeout);
181+
}
182+
183+
this.grpcClients.add(builder.build());
184+
}
185+
}
186+
187+
/**
188+
* Creates a new instance of {@link QdrantClient} with connection pooling. Creates multiple
189+
* independent gRPC connections with the same configuration.
190+
*
191+
* @param host The host to connect to.
192+
* @param port The port to connect to.
193+
* @param useTransportLayerSecurity Whether to use TLS.
194+
* @param poolSize The number of gRPC clients to create in the pool. Must be at least 1.
195+
*/
196+
public QdrantClient(String host, int port, boolean useTransportLayerSecurity, int poolSize) {
197+
this(host, port, useTransportLayerSecurity, poolSize, null, null);
198+
}
199+
200+
/**
201+
* Creates a new instance of {@link QdrantClient} with default connection pooling (pool size = 3).
202+
*
203+
* @param host The host to connect to.
204+
* @param port The port to connect to.
205+
* @param useTransportLayerSecurity Whether to use TLS.
206+
*/
207+
public QdrantClient(String host, int port, boolean useTransportLayerSecurity) {
208+
this(host, port, useTransportLayerSecurity, 3);
139209
}
140210

141211
/**
@@ -147,10 +217,17 @@ public QdrantClient(QdrantGrpcClient grpcClient) {
147217
* where functionality may not yet be exposed by the higher level client.
148218
* </ul>
149219
*
150-
* @return The low-level gRPC client
220+
* @return The low-level gRPC client. If connection pooling is enabled, returns the next client in
221+
* round-robin fashion.
151222
*/
152223
public QdrantGrpcClient grpcClient() {
153-
return grpcClient;
224+
if (grpcClients.size() == 1) {
225+
return grpcClients.get(0);
226+
}
227+
228+
// Atomically increment and wrap around the counter for round-robin selection
229+
int index = nextClientIndex.getAndIncrement() % grpcClients.size();
230+
return grpcClients.get(index);
154231
}
155232

156233
/**
@@ -171,8 +248,10 @@ public ListenableFuture<HealthCheckReply> healthCheckAsync() {
171248
public ListenableFuture<HealthCheckReply> healthCheckAsync(@Nullable Duration timeout) {
172249
QdrantFutureStub qdrant =
173250
timeout != null
174-
? this.grpcClient.qdrant().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
175-
: this.grpcClient.qdrant();
251+
? this.grpcClient()
252+
.qdrant()
253+
.withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
254+
: this.grpcClient().qdrant();
176255
return qdrant.healthCheck(HealthCheckRequest.getDefaultInstance());
177256
}
178257

@@ -3083,7 +3162,14 @@ public ListenableFuture<DeleteSnapshotResponse> deleteFullSnapshotAsync(
30833162

30843163
@Override
30853164
public void close() {
3086-
grpcClient.close();
3165+
// Close all clients in the pool
3166+
for (QdrantGrpcClient client : grpcClients) {
3167+
try {
3168+
client.close();
3169+
} catch (Exception e) {
3170+
logger.warn("Failed to close gRPC client in pool", e);
3171+
}
3172+
}
30873173
}
30883174

30893175
private <V> void addLogFailureCallback(ListenableFuture<V> future, String message) {
@@ -3103,19 +3189,21 @@ public void onFailure(Throwable t) {
31033189

31043190
private CollectionsGrpc.CollectionsFutureStub getCollections(@Nullable Duration timeout) {
31053191
return timeout != null
3106-
? this.grpcClient.collections().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
3107-
: this.grpcClient.collections();
3192+
? this.grpcClient()
3193+
.collections()
3194+
.withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
3195+
: this.grpcClient().collections();
31083196
}
31093197

31103198
private PointsGrpc.PointsFutureStub getPoints(@Nullable Duration timeout) {
31113199
return timeout != null
3112-
? this.grpcClient.points().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
3113-
: this.grpcClient.points();
3200+
? this.grpcClient().points().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
3201+
: this.grpcClient().points();
31143202
}
31153203

31163204
private SnapshotsGrpc.SnapshotsFutureStub getSnapshots(@Nullable Duration timeout) {
31173205
return timeout != null
3118-
? this.grpcClient.snapshots().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
3119-
: this.grpcClient.snapshots();
3206+
? this.grpcClient().snapshots().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
3207+
: this.grpcClient().snapshots();
31203208
}
31213209
}

src/test/java/io/qdrant/client/QdrantClientTest.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,40 @@ public void teardown() {
3939
void canAccessChannelOnGrpcClient() {
4040
Assertions.assertTrue(client.grpcClient().channel().authority().startsWith("localhost"));
4141
}
42+
43+
@Test
44+
void connectionPoolingCreatesMultipleConnections() {
45+
String host = QDRANT_CONTAINER.getHost();
46+
int port = QDRANT_CONTAINER.getGrpcPort();
47+
48+
QdrantClient pooledClient = new QdrantClient(host, port, false, 3);
49+
50+
try {
51+
QdrantGrpcClient client1 = pooledClient.grpcClient();
52+
QdrantGrpcClient client2 = pooledClient.grpcClient();
53+
QdrantGrpcClient client3 = pooledClient.grpcClient();
54+
QdrantGrpcClient client4 = pooledClient.grpcClient(); // Should wrap around to first
55+
56+
Assertions.assertSame(client1, client4); // Should wrap around to first client
57+
58+
// Verify that different clients have different channels (true connection pooling)
59+
Assertions.assertNotSame(client1.channel(), client2.channel());
60+
Assertions.assertNotSame(client2.channel(), client3.channel());
61+
} finally {
62+
pooledClient.close();
63+
}
64+
}
65+
66+
@Test
67+
void defaultConnectionPoolingWorks() {
68+
String host = QDRANT_CONTAINER.getHost();
69+
int port = QDRANT_CONTAINER.getGrpcPort();
70+
QdrantClient defaultClient = new QdrantClient(host, port, false);
71+
72+
try {
73+
Assertions.assertNotNull(defaultClient.grpcClient());
74+
} finally {
75+
defaultClient.close();
76+
}
77+
}
4278
}

0 commit comments

Comments
 (0)