Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 113 additions & 13 deletions src/main/java/io/qdrant/client/QdrantClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,11 @@
import io.qdrant.client.grpc.SnapshotsService.ListSnapshotsResponse;
import io.qdrant.client.grpc.SnapshotsService.SnapshotDescription;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.slf4j.Logger;
Expand All @@ -127,15 +129,96 @@
/** Client for the Qdrant vector database. */
public class QdrantClient implements AutoCloseable {
private static final Logger logger = LoggerFactory.getLogger(QdrantClient.class);
private final QdrantGrpcClient grpcClient;
private final List<QdrantGrpcClient> grpcClients;
private final AtomicInteger nextClientIndex = new AtomicInteger(0);

/**
* Creates a new instance of {@link QdrantClient}
*
* @param grpcClient The low-level gRPC client to use.
*/
public QdrantClient(QdrantGrpcClient grpcClient) {
this.grpcClient = grpcClient;
this.grpcClients = new ArrayList<>(1);
this.grpcClients.add(grpcClient);
}

/**
* Creates a new instance of {@link QdrantClient} with connection pooling. Creates multiple
* independent gRPC connections with the same configuration.
*
* @param host The host to connect to.
* @param port The port to connect to.
* @param useTransportLayerSecurity Whether to use TLS.
* @param poolSize The number of gRPC clients to create in the pool. Must be at least 1.
* @param apiKey The API key for authentication.
* @param timeout The default timeout for requests.
*/
public QdrantClient(
String host,
int port,
boolean useTransportLayerSecurity,
int poolSize,
@Nullable String apiKey,
@Nullable Duration timeout) {
if (poolSize <= 0) {
throw new IllegalArgumentException("Pool size must be at least 1");
}

this.grpcClients = new ArrayList<>(poolSize);

// Create clients for the pool - each with its own independent connection
for (int i = 0; i < poolSize; i++) {
// For the first client, check compatibility. For others, skip to avoid redundant checks
boolean checkCompatibility = (i == 0);
QdrantGrpcClient.Builder builder =
QdrantGrpcClient.newBuilder(host, port, useTransportLayerSecurity, checkCompatibility);

if (apiKey != null) {
builder.withApiKey(apiKey);
}
if (timeout != null) {
builder.withTimeout(timeout);
}

this.grpcClients.add(builder.build());
}
}

/**
* Creates a new instance of {@link QdrantClient} with connection pooling. Creates multiple
* independent gRPC connections with the same configuration.
*
* @param host The host to connect to.
* @param port The port to connect to.
* @param useTransportLayerSecurity Whether to use TLS.
* @param poolSize The number of gRPC clients to create in the pool. Must be at least 1.
*/
public QdrantClient(String host, int port, boolean useTransportLayerSecurity, int poolSize) {
this(host, port, useTransportLayerSecurity, poolSize, null, null);
}

/**
* Creates a new instance of {@link QdrantClient} with default connection pooling (pool size = 3).
*
* @param host The host to connect to.
* @param port The port to connect to.
* @param useTransportLayerSecurity Whether to use TLS.
*/
public QdrantClient(String host, int port, boolean useTransportLayerSecurity) {
this(host, port, useTransportLayerSecurity, 3);
}

/**
* Creates a new instance of {@link QdrantClient} with a custom list of gRPC clients for pooling.
*
* @param grpcClients The list of gRPC clients to use for pooling. Must not be null or empty.
*/
public QdrantClient(List<QdrantGrpcClient> grpcClients) {
if (grpcClients == null || grpcClients.isEmpty()) {
throw new IllegalArgumentException("gRPC clients list cannot be null or empty");
}

this.grpcClients = new ArrayList<>(grpcClients);
}

/**
Expand All @@ -147,10 +230,16 @@ public QdrantClient(QdrantGrpcClient grpcClient) {
* where functionality may not yet be exposed by the higher level client.
* </ul>
*
* @return The low-level gRPC client
* @return The low-level gRPC client in a round-robin fashion.
*/
public QdrantGrpcClient grpcClient() {
return grpcClient;
if (grpcClients.size() == 1) {
return grpcClients.get(0);
}

// Atomically increment and wrap around the counter for round-robin selection
int index = nextClientIndex.getAndIncrement() % grpcClients.size();
return grpcClients.get(index);
}

/**
Expand All @@ -171,8 +260,10 @@ public ListenableFuture<HealthCheckReply> healthCheckAsync() {
public ListenableFuture<HealthCheckReply> healthCheckAsync(@Nullable Duration timeout) {
QdrantFutureStub qdrant =
timeout != null
? this.grpcClient.qdrant().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
: this.grpcClient.qdrant();
? this.grpcClient()
.qdrant()
.withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
: this.grpcClient().qdrant();
return qdrant.healthCheck(HealthCheckRequest.getDefaultInstance());
}

Expand Down Expand Up @@ -3083,7 +3174,14 @@ public ListenableFuture<DeleteSnapshotResponse> deleteFullSnapshotAsync(

@Override
public void close() {
grpcClient.close();
// Close all clients in the pool
for (QdrantGrpcClient client : grpcClients) {
try {
client.close();
} catch (Exception e) {
logger.warn("Failed to close gRPC client in pool", e);
}
}
}

private <V> void addLogFailureCallback(ListenableFuture<V> future, String message) {
Expand All @@ -3103,19 +3201,21 @@ public void onFailure(Throwable t) {

private CollectionsGrpc.CollectionsFutureStub getCollections(@Nullable Duration timeout) {
return timeout != null
? this.grpcClient.collections().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
: this.grpcClient.collections();
? this.grpcClient()
.collections()
.withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
: this.grpcClient().collections();
}

private PointsGrpc.PointsFutureStub getPoints(@Nullable Duration timeout) {
return timeout != null
? this.grpcClient.points().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
: this.grpcClient.points();
? this.grpcClient().points().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
: this.grpcClient().points();
}

private SnapshotsGrpc.SnapshotsFutureStub getSnapshots(@Nullable Duration timeout) {
return timeout != null
? this.grpcClient.snapshots().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
: this.grpcClient.snapshots();
? this.grpcClient().snapshots().withDeadlineAfter(timeout.toMillis(), TimeUnit.MILLISECONDS)
: this.grpcClient().snapshots();
}
}
36 changes: 36 additions & 0 deletions src/test/java/io/qdrant/client/QdrantClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,40 @@ public void teardown() {
void canAccessChannelOnGrpcClient() {
Assertions.assertTrue(client.grpcClient().channel().authority().startsWith("localhost"));
}

@Test
void connectionPoolingCreatesMultipleConnections() {
String host = QDRANT_CONTAINER.getHost();
int port = QDRANT_CONTAINER.getGrpcPort();

QdrantClient pooledClient = new QdrantClient(host, port, false, 3);

try {
QdrantGrpcClient client1 = pooledClient.grpcClient();
QdrantGrpcClient client2 = pooledClient.grpcClient();
QdrantGrpcClient client3 = pooledClient.grpcClient();
QdrantGrpcClient client4 = pooledClient.grpcClient(); // Should wrap around to first

Assertions.assertSame(client1, client4); // Should wrap around to first client

// Verify that different clients have different channels (true connection pooling)
Assertions.assertNotSame(client1.channel(), client2.channel());
Assertions.assertNotSame(client2.channel(), client3.channel());
} finally {
pooledClient.close();
}
}

@Test
void defaultConnectionPoolingWorks() {
String host = QDRANT_CONTAINER.getHost();
int port = QDRANT_CONTAINER.getGrpcPort();
QdrantClient defaultClient = new QdrantClient(host, port, false);

try {
Assertions.assertNotNull(defaultClient.grpcClient());
} finally {
defaultClient.close();
}
}
}
Loading