diff --git a/src/main/java/io/qdrant/client/QdrantClient.java b/src/main/java/io/qdrant/client/QdrantClient.java index 7ed2de21..64791eed 100644 --- a/src/main/java/io/qdrant/client/QdrantClient.java +++ b/src/main/java/io/qdrant/client/QdrantClient.java @@ -116,9 +116,11 @@ import io.qdrant.client.grpc.SnapshotsService.ListSnapshotsResponse; import io.qdrant.client.grpc.SnapshotsService.SnapshotDescription; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import javax.annotation.Nullable; import org.slf4j.Logger; @@ -127,7 +129,8 @@ /** Client for the Qdrant vector database. */ public class QdrantClient implements AutoCloseable { private static final Logger logger = LoggerFactory.getLogger(QdrantClient.class); - private final QdrantGrpcClient grpcClient; + private final List grpcClients; + private final AtomicInteger nextClientIndex = new AtomicInteger(0); /** * Creates a new instance of {@link QdrantClient} @@ -135,7 +138,87 @@ public class QdrantClient implements AutoCloseable { * @param grpcClient The low-level gRPC client to use. */ public QdrantClient(QdrantGrpcClient grpcClient) { - this.grpcClient = grpcClient; + this.grpcClients = new ArrayList<>(1); + this.grpcClients.add(grpcClient); + } + + /** + * Creates a new instance of {@link QdrantClient} with connection pooling. Creates multiple + * independent gRPC connections with the same configuration. + * + * @param host The host to connect to. + * @param port The port to connect to. + * @param useTransportLayerSecurity Whether to use TLS. + * @param poolSize The number of gRPC clients to create in the pool. Must be at least 1. + * @param apiKey The API key for authentication. + * @param timeout The default timeout for requests. + */ + public QdrantClient( + String host, + int port, + boolean useTransportLayerSecurity, + int poolSize, + @Nullable String apiKey, + @Nullable Duration timeout) { + if (poolSize <= 0) { + throw new IllegalArgumentException("Pool size must be at least 1"); + } + + this.grpcClients = new ArrayList<>(poolSize); + + // Create clients for the pool - each with its own independent connection + for (int i = 0; i < poolSize; i++) { + // For the first client, check compatibility. For others, skip to avoid redundant checks + boolean checkCompatibility = (i == 0); + QdrantGrpcClient.Builder builder = + QdrantGrpcClient.newBuilder(host, port, useTransportLayerSecurity, checkCompatibility); + + if (apiKey != null) { + builder.withApiKey(apiKey); + } + if (timeout != null) { + builder.withTimeout(timeout); + } + + this.grpcClients.add(builder.build()); + } + } + + /** + * Creates a new instance of {@link QdrantClient} with connection pooling. Creates multiple + * independent gRPC connections with the same configuration. + * + * @param host The host to connect to. + * @param port The port to connect to. + * @param useTransportLayerSecurity Whether to use TLS. + * @param poolSize The number of gRPC clients to create in the pool. Must be at least 1. + */ + public QdrantClient(String host, int port, boolean useTransportLayerSecurity, int poolSize) { + this(host, port, useTransportLayerSecurity, poolSize, null, null); + } + + /** + * Creates a new instance of {@link QdrantClient} with default connection pooling (pool size = 3). + * + * @param host The host to connect to. + * @param port The port to connect to. + * @param useTransportLayerSecurity Whether to use TLS. + */ + public QdrantClient(String host, int port, boolean useTransportLayerSecurity) { + this(host, port, useTransportLayerSecurity, 3); + } + + /** + * Creates a new instance of {@link QdrantClient} with a custom list of gRPC clients for pooling. + * + * @param grpcClients The list of gRPC clients to use for pooling. Must not be null or empty. + */ + public QdrantClient(List grpcClients) { + if (grpcClients == null || grpcClients.isEmpty()) { + throw new IllegalArgumentException("gRPC clients list cannot be null or empty"); + } + + this.grpcClients = new ArrayList<>(grpcClients); } /** @@ -147,10 +230,16 @@ public QdrantClient(QdrantGrpcClient grpcClient) { * where functionality may not yet be exposed by the higher level client. * * - * @return The low-level gRPC client + * @return The low-level gRPC client in a round-robin fashion. */ public QdrantGrpcClient grpcClient() { - return grpcClient; + if (grpcClients.size() == 1) { + return grpcClients.get(0); + } + + // Atomically increment and wrap around the counter for round-robin selection + int index = nextClientIndex.getAndIncrement() % grpcClients.size(); + return grpcClients.get(index); } /** @@ -171,8 +260,10 @@ public ListenableFuture healthCheckAsync() { public ListenableFuture healthCheckAsync(@Nullable Duration timeout) { QdrantFutureStub qdrant = timeout != null - ? this.grpcClient.qdrant().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS) - : this.grpcClient.qdrant(); + ? this.grpcClient() + .qdrant() + .withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS) + : this.grpcClient().qdrant(); return qdrant.healthCheck(HealthCheckRequest.getDefaultInstance()); } @@ -3083,7 +3174,14 @@ public ListenableFuture deleteFullSnapshotAsync( @Override public void close() { - grpcClient.close(); + // Close all clients in the pool + for (QdrantGrpcClient client : grpcClients) { + try { + client.close(); + } catch (Exception e) { + logger.warn("Failed to close gRPC client in pool", e); + } + } } private void addLogFailureCallback(ListenableFuture future, String message) { @@ -3103,19 +3201,21 @@ public void onFailure(Throwable t) { private CollectionsGrpc.CollectionsFutureStub getCollections(@Nullable Duration timeout) { return timeout != null - ? this.grpcClient.collections().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS) - : this.grpcClient.collections(); + ? this.grpcClient() + .collections() + .withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS) + : this.grpcClient().collections(); } private PointsGrpc.PointsFutureStub getPoints(@Nullable Duration timeout) { return timeout != null - ? this.grpcClient.points().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS) - : this.grpcClient.points(); + ? this.grpcClient().points().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS) + : this.grpcClient().points(); } private SnapshotsGrpc.SnapshotsFutureStub getSnapshots(@Nullable Duration timeout) { return timeout != null - ? this.grpcClient.snapshots().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS) - : this.grpcClient.snapshots(); + ? this.grpcClient().snapshots().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS) + : this.grpcClient().snapshots(); } } diff --git a/src/test/java/io/qdrant/client/QdrantClientTest.java b/src/test/java/io/qdrant/client/QdrantClientTest.java index 624768f8..3da3b663 100644 --- a/src/test/java/io/qdrant/client/QdrantClientTest.java +++ b/src/test/java/io/qdrant/client/QdrantClientTest.java @@ -39,4 +39,40 @@ public void teardown() { void canAccessChannelOnGrpcClient() { Assertions.assertTrue(client.grpcClient().channel().authority().startsWith("localhost")); } + + @Test + void connectionPoolingCreatesMultipleConnections() { + String host = QDRANT_CONTAINER.getHost(); + int port = QDRANT_CONTAINER.getGrpcPort(); + + QdrantClient pooledClient = new QdrantClient(host, port, false, 3); + + try { + QdrantGrpcClient client1 = pooledClient.grpcClient(); + QdrantGrpcClient client2 = pooledClient.grpcClient(); + QdrantGrpcClient client3 = pooledClient.grpcClient(); + QdrantGrpcClient client4 = pooledClient.grpcClient(); // Should wrap around to first + + Assertions.assertSame(client1, client4); // Should wrap around to first client + + // Verify that different clients have different channels (true connection pooling) + Assertions.assertNotSame(client1.channel(), client2.channel()); + Assertions.assertNotSame(client2.channel(), client3.channel()); + } finally { + pooledClient.close(); + } + } + + @Test + void defaultConnectionPoolingWorks() { + String host = QDRANT_CONTAINER.getHost(); + int port = QDRANT_CONTAINER.getGrpcPort(); + QdrantClient defaultClient = new QdrantClient(host, port, false); + + try { + Assertions.assertNotNull(defaultClient.grpcClient()); + } finally { + defaultClient.close(); + } + } }