diff --git a/.gitignore b/.gitignore index 1c58ed192d..687822e71b 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ plugin/build/ .DS_Store */bin/ **/*.factorypath +**/out diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java index 4d617ce896..44d7456eb3 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java @@ -35,17 +35,17 @@ public class ConnectorClientConfig implements ToXContentObject, Writeable { public static final String MAX_RETRY_TIMES_FIELD = "max_retry_times"; public static final String RETRY_BACKOFF_POLICY_FIELD = "retry_backoff_policy"; - public static final Integer MAX_CONNECTION_DEFAULT_VALUE = Integer.valueOf(30); - public static final Integer CONNECTION_TIMEOUT_DEFAULT_VALUE = Integer.valueOf(30000); - public static final Integer READ_TIMEOUT_DEFAULT_VALUE = Integer.valueOf(30000); - public static final Integer RETRY_BACKOFF_MILLIS_DEFAULT_VALUE = 200; - public static final Integer RETRY_TIMEOUT_SECONDS_DEFAULT_VALUE = 30; - public static final Integer MAX_RETRY_TIMES_DEFAULT_VALUE = 0; + public static final int MAX_CONNECTION_DEFAULT_VALUE = 30; + public static final int CONNECTION_TIMEOUT_DEFAULT_VALUE = 1000; + public static final int READ_TIMEOUT_DEFAULT_VALUE = 10; + public static final int RETRY_BACKOFF_MILLIS_DEFAULT_VALUE = 200; + public static final int RETRY_TIMEOUT_SECONDS_DEFAULT_VALUE = 30; + public static final int MAX_RETRY_TIMES_DEFAULT_VALUE = 0; public static final RetryBackoffPolicy RETRY_BACKOFF_POLICY_DEFAULT_VALUE = RetryBackoffPolicy.CONSTANT; public static final Version MINIMAL_SUPPORTED_VERSION_FOR_RETRY = Version.V_2_15_0; private Integer maxConnections; - private Integer connectionTimeout; - private Integer readTimeout; + private Integer connectionTimeoutMillis; + private Integer readTimeoutSeconds; private Integer retryBackoffMillis; private Integer retryTimeoutSeconds; private Integer maxRetryTimes; @@ -54,16 +54,16 @@ public class ConnectorClientConfig implements ToXContentObject, Writeable { @Builder(toBuilder = true) public ConnectorClientConfig( Integer maxConnections, - Integer connectionTimeout, - Integer readTimeout, + Integer connectionTimeoutMillis, + Integer readTimeoutSeconds, Integer retryBackoffMillis, Integer retryTimeoutSeconds, Integer maxRetryTimes, RetryBackoffPolicy retryBackoffPolicy ) { this.maxConnections = maxConnections; - this.connectionTimeout = connectionTimeout; - this.readTimeout = readTimeout; + this.connectionTimeoutMillis = connectionTimeoutMillis; + this.readTimeoutSeconds = readTimeoutSeconds; this.retryBackoffMillis = retryBackoffMillis; this.retryTimeoutSeconds = retryTimeoutSeconds; this.maxRetryTimes = maxRetryTimes; @@ -73,8 +73,8 @@ public ConnectorClientConfig( public ConnectorClientConfig(StreamInput input) throws IOException { Version streamInputVersion = input.getVersion(); this.maxConnections = input.readOptionalInt(); - this.connectionTimeout = input.readOptionalInt(); - this.readTimeout = input.readOptionalInt(); + this.connectionTimeoutMillis = input.readOptionalInt(); + this.readTimeoutSeconds = input.readOptionalInt(); if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_RETRY)) { this.retryBackoffMillis = input.readOptionalInt(); this.retryTimeoutSeconds = input.readOptionalInt(); @@ -87,8 +87,8 @@ public ConnectorClientConfig(StreamInput input) throws IOException { public ConnectorClientConfig() { this.maxConnections = MAX_CONNECTION_DEFAULT_VALUE; - this.connectionTimeout = CONNECTION_TIMEOUT_DEFAULT_VALUE; - this.readTimeout = READ_TIMEOUT_DEFAULT_VALUE; + this.connectionTimeoutMillis = CONNECTION_TIMEOUT_DEFAULT_VALUE; + this.readTimeoutSeconds = READ_TIMEOUT_DEFAULT_VALUE; this.retryBackoffMillis = RETRY_BACKOFF_MILLIS_DEFAULT_VALUE; this.retryTimeoutSeconds = RETRY_TIMEOUT_SECONDS_DEFAULT_VALUE; this.maxRetryTimes = MAX_RETRY_TIMES_DEFAULT_VALUE; @@ -99,8 +99,8 @@ public ConnectorClientConfig() { public void writeTo(StreamOutput out) throws IOException { Version streamOutputVersion = out.getVersion(); out.writeOptionalInt(maxConnections); - out.writeOptionalInt(connectionTimeout); - out.writeOptionalInt(readTimeout); + out.writeOptionalInt(connectionTimeoutMillis); + out.writeOptionalInt(readTimeoutSeconds); if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_RETRY)) { out.writeOptionalInt(retryBackoffMillis); out.writeOptionalInt(retryTimeoutSeconds); @@ -120,11 +120,11 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params if (maxConnections != null) { builder.field(MAX_CONNECTION_FIELD, maxConnections); } - if (connectionTimeout != null) { - builder.field(CONNECTION_TIMEOUT_FIELD, connectionTimeout); + if (connectionTimeoutMillis != null) { + builder.field(CONNECTION_TIMEOUT_FIELD, connectionTimeoutMillis); } - if (readTimeout != null) { - builder.field(READ_TIMEOUT_FIELD, readTimeout); + if (readTimeoutSeconds != null) { + builder.field(READ_TIMEOUT_FIELD, readTimeoutSeconds); } if (retryBackoffMillis != null) { builder.field(RETRY_BACKOFF_MILLIS_FIELD, retryBackoffMillis); @@ -190,8 +190,8 @@ public static ConnectorClientConfig parse(XContentParser parser) throws IOExcept return ConnectorClientConfig .builder() .maxConnections(maxConnections) - .connectionTimeout(connectionTimeout) - .readTimeout(readTimeout) + .connectionTimeoutMillis(connectionTimeout) + .readTimeoutSeconds(readTimeout) .retryBackoffMillis(retryBackoffMillis) .retryTimeoutSeconds(retryTimeoutSeconds) .maxRetryTimes(maxRetryTimes) diff --git a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorClientConfigTest.java b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorClientConfigTest.java index 4cf2a08f60..cea1d30bab 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorClientConfigTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorClientConfigTest.java @@ -24,8 +24,8 @@ public void writeTo_ReadFromStream() throws IOException { ConnectorClientConfig config = ConnectorClientConfig .builder() .maxConnections(10) - .connectionTimeout(5000) - .readTimeout(3000) + .connectionTimeoutMillis(1000) + .readTimeoutSeconds(3) .retryBackoffMillis(123) .retryTimeoutSeconds(456) .maxRetryTimes(789) @@ -55,8 +55,8 @@ public void writeTo_ReadFromStream_diffVersionThenNotProcessRetryOptions() throw ConnectorClientConfig config = ConnectorClientConfig .builder() .maxConnections(10) - .connectionTimeout(5000) - .readTimeout(3000) + .connectionTimeoutMillis(1000) + .readTimeoutSeconds(3) .retryBackoffMillis(123) .retryTimeoutSeconds(456) .maxRetryTimes(789) @@ -71,8 +71,8 @@ public void writeTo_ReadFromStream_diffVersionThenNotProcessRetryOptions() throw ConnectorClientConfig readConfig = ConnectorClientConfig.fromStream(input); Assert.assertEquals(Integer.valueOf(10), readConfig.getMaxConnections()); - Assert.assertEquals(Integer.valueOf(5000), readConfig.getConnectionTimeout()); - Assert.assertEquals(Integer.valueOf(3000), readConfig.getReadTimeout()); + Assert.assertEquals(Integer.valueOf(1000), readConfig.getConnectionTimeoutMillis()); + Assert.assertEquals(Integer.valueOf(3), readConfig.getReadTimeoutSeconds()); Assert.assertNull(readConfig.getRetryBackoffMillis()); Assert.assertNull(readConfig.getRetryTimeoutSeconds()); Assert.assertNull(readConfig.getMaxRetryTimes()); @@ -84,8 +84,8 @@ public void toXContent() throws IOException { ConnectorClientConfig config = ConnectorClientConfig .builder() .maxConnections(10) - .connectionTimeout(5000) - .readTimeout(3000) + .connectionTimeoutMillis(1000) + .readTimeoutSeconds(3) .retryBackoffMillis(123) .retryTimeoutSeconds(456) .maxRetryTimes(789) @@ -96,14 +96,14 @@ public void toXContent() throws IOException { config.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - String expectedJson = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000," + String expectedJson = "{\"max_connection\":10,\"connection_timeout\":1000,\"read_timeout\":3," + "\"retry_backoff_millis\":123,\"retry_timeout_seconds\":456,\"max_retry_times\":789,\"retry_backoff_policy\":\"constant\"}"; Assert.assertEquals(expectedJson, content); } @Test public void parse() throws IOException { - String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000," + String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3," + "\"retry_backoff_millis\":123,\"retry_timeout_seconds\":456,\"max_retry_times\":789,\"retry_backoff_policy\":\"constant\"}"; XContentParser parser = XContentType.JSON .xContent() @@ -117,8 +117,8 @@ public void parse() throws IOException { ConnectorClientConfig config = ConnectorClientConfig.parse(parser); Assert.assertEquals(Integer.valueOf(10), config.getMaxConnections()); - Assert.assertEquals(Integer.valueOf(5000), config.getConnectionTimeout()); - Assert.assertEquals(Integer.valueOf(3000), config.getReadTimeout()); + Assert.assertEquals(Integer.valueOf(5000), config.getConnectionTimeoutMillis()); + Assert.assertEquals(Integer.valueOf(3), config.getReadTimeoutSeconds()); Assert.assertEquals(Integer.valueOf(123), config.getRetryBackoffMillis()); Assert.assertEquals(Integer.valueOf(456), config.getRetryTimeoutSeconds()); Assert.assertEquals(Integer.valueOf(789), config.getMaxRetryTimes()); @@ -127,7 +127,7 @@ public void parse() throws IOException { @Test public void parse_whenMalformedBackoffPolicy_thenFail() throws IOException { - String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000," + String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3," + "\"retry_backoff_millis\":123,\"retry_timeout_seconds\":456,\"max_retry_times\":789,\"retry_backoff_policy\":\"test\"}"; XContentParser parser = XContentType.JSON .xContent() @@ -147,8 +147,8 @@ public void testDefaultValues() { ConnectorClientConfig config = ConnectorClientConfig.builder().build(); Assert.assertNull(config.getMaxConnections()); - Assert.assertNull(config.getConnectionTimeout()); - Assert.assertNull(config.getReadTimeout()); + Assert.assertNull(config.getConnectionTimeoutMillis()); + Assert.assertNull(config.getReadTimeoutSeconds()); Assert.assertNull(config.getRetryBackoffMillis()); Assert.assertNull(config.getRetryTimeoutSeconds()); Assert.assertNull(config.getMaxRetryTimes()); @@ -160,8 +160,8 @@ public void testDefaultValuesInitByNewInstance() { ConnectorClientConfig config = new ConnectorClientConfig(); Assert.assertEquals(Integer.valueOf(30), config.getMaxConnections()); - Assert.assertEquals(Integer.valueOf(30000), config.getConnectionTimeout()); - Assert.assertEquals(Integer.valueOf(30000), config.getReadTimeout()); + Assert.assertEquals(Integer.valueOf(1000), config.getConnectionTimeoutMillis()); + Assert.assertEquals(Integer.valueOf(10), config.getReadTimeoutSeconds()); Assert.assertEquals(Integer.valueOf(200), config.getRetryBackoffMillis()); Assert.assertEquals(Integer.valueOf(30), config.getRetryTimeoutSeconds()); Assert.assertEquals(Integer.valueOf(0), config.getMaxRetryTimes()); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index a7df00618a..d1ca734302 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -324,8 +324,8 @@ public void testParse() throws Exception { testParseFromJsonString(expectedInputStr, parsedInput -> { assertEquals(TEST_CONNECTOR_NAME, parsedInput.getName()); assertEquals(20, parsedInput.getConnectorClientConfig().getMaxConnections().intValue()); - assertEquals(10000, parsedInput.getConnectorClientConfig().getReadTimeout().intValue()); - assertEquals(10000, parsedInput.getConnectorClientConfig().getConnectionTimeout().intValue()); + assertEquals(10000, parsedInput.getConnectorClientConfig().getReadTimeoutSeconds().intValue()); + assertEquals(10000, parsedInput.getConnectorClientConfig().getConnectionTimeoutMillis().intValue()); }); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutor.java index 46c653776d..cb6510b751 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutor.java @@ -5,22 +5,73 @@ package org.opensearch.ml.engine.algorithms.remote; +import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorClientConfig; +import org.opensearch.ml.common.httpclient.MLHttpClientFactory; import lombok.Getter; import lombok.Setter; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; @Setter @Getter public abstract class AbstractConnectorExecutor implements RemoteConnectorExecutor { - private ConnectorClientConfig connectorClientConfig; + + @Setter + private volatile boolean connectorPrivateIpEnabled; + + private volatile AtomicReference httpClientRef = new AtomicReference<>(); + + private ConnectorClientConfig connectorClientConfig = new ConnectorClientConfig(); public void initialize(Connector connector) { if (connector.getConnectorClientConfig() != null) { connectorClientConfig = connector.getConnectorClientConfig(); - } else { - connectorClientConfig = new ConnectorClientConfig(); + } + } + + protected SdkAsyncHttpClient getHttpClient() { + // This block for high performance retrieval after http client is created. + SdkAsyncHttpClient existingClient = httpClientRef.get(); + if (existingClient != null) { + return existingClient; + } + // This block handles concurrent http client creation. + synchronized (this) { + existingClient = httpClientRef.get(); + if (existingClient != null) { + return existingClient; + } + Duration connectionTimeout = Duration + .ofMillis( + Optional + .ofNullable(connectorClientConfig.getConnectionTimeoutMillis()) + .orElse(ConnectorClientConfig.CONNECTION_TIMEOUT_DEFAULT_VALUE) + ); + Duration readTimeout = Duration + .ofSeconds( + Optional + .ofNullable(connectorClientConfig.getReadTimeoutSeconds()) + .orElse(ConnectorClientConfig.READ_TIMEOUT_DEFAULT_VALUE) + ); + int maxConnection = Optional + .ofNullable(connectorClientConfig.getMaxConnections()) + .orElse(ConnectorClientConfig.MAX_CONNECTION_DEFAULT_VALUE); + SdkAsyncHttpClient newClient = MLHttpClientFactory + .getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled); + httpClientRef.set(newClient); + return newClient; + } + } + + public void close() { + SdkAsyncHttpClient httpClient = httpClientRef.getAndSet(null); + if (httpClient != null) { + httpClient.close(); } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 3b53935aaf..32c8657cfe 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -14,11 +14,9 @@ import java.security.AccessController; import java.security.PrivilegedExceptionAction; -import java.time.Duration; import java.util.Locale; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicReference; import org.apache.commons.text.StringEscapeUtils; import org.apache.logging.log4j.Logger; @@ -28,7 +26,6 @@ import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.exception.MLException; -import org.opensearch.ml.common.httpclient.MLHttpClientFactory; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensors; @@ -41,15 +38,12 @@ import org.opensearch.transport.StreamTransportService; import org.opensearch.transport.client.Client; -import com.google.common.annotations.VisibleForTesting; - import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.async.AsyncExecuteRequest; -import software.amazon.awssdk.http.async.SdkAsyncHttpClient; @Log4j2 @ConnectorExecutor(AWS_SIGV4) @@ -73,15 +67,10 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor { @Getter private MLGuard mlGuard; - private final AtomicReference httpClientRef = new AtomicReference<>(); - @Setter @Getter private StreamTransportService streamTransportService; - @Setter - private boolean connectorPrivateIpEnabled; - public AwsConnectorExecutor(Connector connector) { super.initialize(connector); this.connector = (AwsConnector) connector; @@ -183,19 +172,4 @@ private void validateLLMInterface(String llmInterface) { throw new IllegalArgumentException(String.format("Unsupported llm interface: %s", llmInterface)); } } - - @VisibleForTesting - protected SdkAsyncHttpClient getHttpClient() { - if (httpClientRef.get() == null) { - Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); - Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); - Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); - this.httpClientRef - .compareAndSet( - null, - MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled) - ); - } - return httpClientRef.get(); - } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 7804770258..5c77cabddd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -13,11 +13,9 @@ import java.security.AccessController; import java.security.PrivilegedExceptionAction; -import java.time.Duration; import java.util.Locale; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicReference; import org.apache.commons.text.StringEscapeUtils; import org.apache.logging.log4j.Logger; @@ -27,7 +25,6 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.exception.MLException; -import org.opensearch.ml.common.httpclient.MLHttpClientFactory; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensors; @@ -40,15 +37,12 @@ import org.opensearch.transport.StreamTransportService; import org.opensearch.transport.client.Client; -import com.google.common.annotations.VisibleForTesting; - import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.async.AsyncExecuteRequest; -import software.amazon.awssdk.http.async.SdkAsyncHttpClient; @Log4j2 @ConnectorExecutor(HTTP) @@ -72,10 +66,6 @@ public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor { @Setter @Getter private MLGuard mlGuard; - @Setter - private volatile boolean connectorPrivateIpEnabled; - - private final AtomicReference httpClientRef = new AtomicReference<>(); @Setter @Getter @@ -173,19 +163,4 @@ private void validateLLMInterface(String llmInterface) { throw new IllegalArgumentException(String.format("Unsupported llm interface: %s", llmInterface)); } } - - @VisibleForTesting - protected SdkAsyncHttpClient getHttpClient() { - if (httpClientRef.get() == null) { - Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); - Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); - Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); - this.httpClientRef - .compareAndSet( - null, - MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled) - ); - } - return httpClientRef.get(); - } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutor.java index 181296e335..954dd473be 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutor.java @@ -71,8 +71,8 @@ public List getMcpToolSpecs() { : MCP_DEFAULT_SSE_ENDPOINT; List mcpToolSpecs = new ArrayList<>(); try { - Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); - Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); + Duration connectionTimeout = Duration.ofMillis(super.getConnectorClientConfig().getConnectionTimeoutMillis()); + Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeoutSeconds()); Consumer headerConfig = builder -> { if (connector.getDecryptedHeaders() != null) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpStreamableHttpConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpStreamableHttpConnectorExecutor.java index d73e294785..cd89884aa7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpStreamableHttpConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpStreamableHttpConnectorExecutor.java @@ -74,8 +74,8 @@ public List getMcpToolSpecs() { .orElse(MCP_DEFAULT_STREAMABLE_HTTP_ENDPOINT); List mcpToolSpecs = new ArrayList<>(); try { - Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); - Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); + Duration connectionTimeout = Duration.ofMillis(super.getConnectorClientConfig().getConnectionTimeoutMillis()); + Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeoutSeconds()); Consumer headerConfig = builder -> { if (connector.getDecryptedHeaders() != null) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 5181e8d087..4bb0c98c21 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -61,8 +61,14 @@ import org.opensearch.transport.client.Client; import lombok.Builder; +import software.amazon.awssdk.utils.SdkAutoCloseable; -public interface RemoteConnectorExecutor { +/** + * This class is responsible for executing the remote connector actions like prediction, it internally encapsulates a SdkAsyncHttpclient + * which is used to send the request to the remote model service, when the executor is being closed, we need to close the internal + * HttpClient as well. + */ +public interface RemoteConnectorExecutor extends SdkAutoCloseable { public String RETRY_EXECUTOR = "opensearch_ml_predict_remote"; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index a000989ee6..a29271ba8a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -97,7 +97,10 @@ public void asyncPredict(MLInput mlInput, ActionListener actionL @Override public void close() { - this.connectorExecutor = null; + if (this.connectorExecutor != null) { + this.connectorExecutor.close(); + this.connectorExecutor = null; + } } @Override diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java index 15078dfccb..5f45873161 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java @@ -49,8 +49,8 @@ public HttpStreamingHandler(String llmInterface, Connector connector, ConnectorC this.llmInterface = llmInterface; // Get connector client configuration - Duration connectionTimeout = Duration.ofSeconds(connectorClientConfig.getConnectionTimeout()); - Duration readTimeout = Duration.ofSeconds(connectorClientConfig.getReadTimeout()); + Duration connectionTimeout = Duration.ofMillis(connectorClientConfig.getConnectionTimeoutMillis()); + Duration readTimeout = Duration.ofSeconds(connectorClientConfig.getReadTimeoutSeconds()); // Initialize OkHttp client for SSE try { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java index 56ad15cbd4..9ebe9ac751 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java @@ -29,17 +29,17 @@ public void setUp() { public void testValidateWithNullConfig() { when(mockConnector.getConnectorClientConfig()).thenReturn(null); executor.initialize(mockConnector); - assertEquals(ConnectorClientConfig.MAX_CONNECTION_DEFAULT_VALUE, executor.getConnectorClientConfig().getMaxConnections()); - assertEquals(ConnectorClientConfig.CONNECTION_TIMEOUT_DEFAULT_VALUE, executor.getConnectorClientConfig().getConnectionTimeout()); - assertEquals(ConnectorClientConfig.READ_TIMEOUT_DEFAULT_VALUE, executor.getConnectorClientConfig().getReadTimeout()); + assertEquals(Integer.valueOf(ConnectorClientConfig.MAX_CONNECTION_DEFAULT_VALUE), executor.getConnectorClientConfig().getMaxConnections()); + assertEquals(Integer.valueOf(ConnectorClientConfig.CONNECTION_TIMEOUT_DEFAULT_VALUE), executor.getConnectorClientConfig().getConnectionTimeoutMillis()); + assertEquals(Integer.valueOf(ConnectorClientConfig.READ_TIMEOUT_DEFAULT_VALUE), executor.getConnectorClientConfig().getReadTimeoutSeconds()); } @Test public void testValidateWithNonNullConfigButNullValues() { when(mockConnector.getConnectorClientConfig()).thenReturn(connectorClientConfig); executor.initialize(mockConnector); - assertEquals(ConnectorClientConfig.MAX_CONNECTION_DEFAULT_VALUE, executor.getConnectorClientConfig().getMaxConnections()); - assertEquals(ConnectorClientConfig.CONNECTION_TIMEOUT_DEFAULT_VALUE, executor.getConnectorClientConfig().getConnectionTimeout()); - assertEquals(ConnectorClientConfig.READ_TIMEOUT_DEFAULT_VALUE, executor.getConnectorClientConfig().getReadTimeout()); + assertEquals(Integer.valueOf(ConnectorClientConfig.MAX_CONNECTION_DEFAULT_VALUE), executor.getConnectorClientConfig().getMaxConnections()); + assertEquals(Integer.valueOf(ConnectorClientConfig.CONNECTION_TIMEOUT_DEFAULT_VALUE), executor.getConnectorClientConfig().getConnectionTimeoutMillis()); + assertEquals(Integer.valueOf(ConnectorClientConfig.READ_TIMEOUT_DEFAULT_VALUE), executor.getConnectorClientConfig().getReadTimeoutSeconds()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchRequestProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchRequestProcessorIT.java index 98fe2b0bf7..40c15530c0 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchRequestProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchRequestProcessorIT.java @@ -51,6 +51,9 @@ public class RestMLInferenceSearchRequestProcessorIT extends MLCommonsRestTestCa + OPENAI_KEY + "\"\n" + " },\n" + + "\"client_config\": {\n" + + " \"read_timeout\": 60\n" + + " },\n" + " \"actions\": [\n" + " {\n" + " \"action_type\": \"predict\",\n" diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java index 2be0cd495f..4e0b27cc50 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java @@ -334,6 +334,9 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " \"parameters\": {\n" + " \"model\": \"command-a-03-2025\"\n" + " },\n" + + " \"client_config\": {\n" + + " \"read_timeout\": 60\n" + + " },\n" + " \"actions\": [\n" + " {\n" + " \"action_type\": \"predict\",\n"