Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ plugin/build/
.DS_Store
*/bin/
**/*.factorypath
**/out
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ public class ConnectorClientConfig implements ToXContentObject, Writeable {
public static final String MAX_RETRY_TIMES_FIELD = "max_retry_times";
public static final String RETRY_BACKOFF_POLICY_FIELD = "retry_backoff_policy";

public static final Integer MAX_CONNECTION_DEFAULT_VALUE = Integer.valueOf(30);
public static final Integer CONNECTION_TIMEOUT_DEFAULT_VALUE = Integer.valueOf(30000);
public static final Integer READ_TIMEOUT_DEFAULT_VALUE = Integer.valueOf(30000);
public static final Integer RETRY_BACKOFF_MILLIS_DEFAULT_VALUE = 200;
public static final Integer RETRY_TIMEOUT_SECONDS_DEFAULT_VALUE = 30;
public static final Integer MAX_RETRY_TIMES_DEFAULT_VALUE = 0;
public static final int MAX_CONNECTION_DEFAULT_VALUE = 30;
public static final int CONNECTION_TIMEOUT_DEFAULT_VALUE = 1000;
public static final int READ_TIMEOUT_DEFAULT_VALUE = 10;
public static final int RETRY_BACKOFF_MILLIS_DEFAULT_VALUE = 200;
public static final int RETRY_TIMEOUT_SECONDS_DEFAULT_VALUE = 30;
public static final int MAX_RETRY_TIMES_DEFAULT_VALUE = 0;
public static final RetryBackoffPolicy RETRY_BACKOFF_POLICY_DEFAULT_VALUE = RetryBackoffPolicy.CONSTANT;
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_RETRY = Version.V_2_15_0;
private Integer maxConnections;
private Integer connectionTimeout;
private Integer readTimeout;
private Integer connectionTimeoutMillis;
private Integer readTimeoutSeconds;
private Integer retryBackoffMillis;
Comment on lines +47 to 49
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we use milliseconds for both?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, sounds very confusing to use different way to configure these timeouts

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In documentation, we specified that connection timeout is in millisecond and read timeout in second: https://docs.opensearch.org/latest/ml-commons-plugin/remote-models/blueprints/#configuration-parameters, changing this could break existing customer configurations.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

opensearch-project/documentation-website#9460

based on this issue, the documentation was changed. Let's make the values consistent. I don't have preference about seconds vs miliseconds. But keep this consistent please. Thanks.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I can change this to milliseconds but this could be a breaking change for existing customer connector configurations, e.g. user configuring to 3s now become 3 milliseconds which basically means all remote model calls fails, @ylwu-amzn , please also confirm on this.

private Integer retryTimeoutSeconds;
private Integer maxRetryTimes;
Expand All @@ -54,16 +54,16 @@ public class ConnectorClientConfig implements ToXContentObject, Writeable {
@Builder(toBuilder = true)
public ConnectorClientConfig(
Integer maxConnections,
Integer connectionTimeout,
Integer readTimeout,
Integer connectionTimeoutMillis,
Integer readTimeoutSeconds,
Integer retryBackoffMillis,
Integer retryTimeoutSeconds,
Integer maxRetryTimes,
RetryBackoffPolicy retryBackoffPolicy
) {
this.maxConnections = maxConnections;
this.connectionTimeout = connectionTimeout;
this.readTimeout = readTimeout;
this.connectionTimeoutMillis = connectionTimeoutMillis;
this.readTimeoutSeconds = readTimeoutSeconds;
this.retryBackoffMillis = retryBackoffMillis;
this.retryTimeoutSeconds = retryTimeoutSeconds;
this.maxRetryTimes = maxRetryTimes;
Expand All @@ -73,8 +73,8 @@ public ConnectorClientConfig(
public ConnectorClientConfig(StreamInput input) throws IOException {
Version streamInputVersion = input.getVersion();
this.maxConnections = input.readOptionalInt();
this.connectionTimeout = input.readOptionalInt();
this.readTimeout = input.readOptionalInt();
this.connectionTimeoutMillis = input.readOptionalInt();
this.readTimeoutSeconds = input.readOptionalInt();
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_RETRY)) {
this.retryBackoffMillis = input.readOptionalInt();
this.retryTimeoutSeconds = input.readOptionalInt();
Expand All @@ -87,8 +87,8 @@ public ConnectorClientConfig(StreamInput input) throws IOException {

public ConnectorClientConfig() {
this.maxConnections = MAX_CONNECTION_DEFAULT_VALUE;
this.connectionTimeout = CONNECTION_TIMEOUT_DEFAULT_VALUE;
this.readTimeout = READ_TIMEOUT_DEFAULT_VALUE;
this.connectionTimeoutMillis = CONNECTION_TIMEOUT_DEFAULT_VALUE;
this.readTimeoutSeconds = READ_TIMEOUT_DEFAULT_VALUE;
this.retryBackoffMillis = RETRY_BACKOFF_MILLIS_DEFAULT_VALUE;
this.retryTimeoutSeconds = RETRY_TIMEOUT_SECONDS_DEFAULT_VALUE;
this.maxRetryTimes = MAX_RETRY_TIMES_DEFAULT_VALUE;
Expand All @@ -99,8 +99,8 @@ public ConnectorClientConfig() {
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeOptionalInt(maxConnections);
out.writeOptionalInt(connectionTimeout);
out.writeOptionalInt(readTimeout);
out.writeOptionalInt(connectionTimeoutMillis);
out.writeOptionalInt(readTimeoutSeconds);
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_RETRY)) {
out.writeOptionalInt(retryBackoffMillis);
out.writeOptionalInt(retryTimeoutSeconds);
Expand All @@ -120,11 +120,11 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
if (maxConnections != null) {
builder.field(MAX_CONNECTION_FIELD, maxConnections);
}
if (connectionTimeout != null) {
builder.field(CONNECTION_TIMEOUT_FIELD, connectionTimeout);
if (connectionTimeoutMillis != null) {
builder.field(CONNECTION_TIMEOUT_FIELD, connectionTimeoutMillis);
}
if (readTimeout != null) {
builder.field(READ_TIMEOUT_FIELD, readTimeout);
if (readTimeoutSeconds != null) {
builder.field(READ_TIMEOUT_FIELD, readTimeoutSeconds);
}
if (retryBackoffMillis != null) {
builder.field(RETRY_BACKOFF_MILLIS_FIELD, retryBackoffMillis);
Expand Down Expand Up @@ -190,8 +190,8 @@ public static ConnectorClientConfig parse(XContentParser parser) throws IOExcept
return ConnectorClientConfig
.builder()
.maxConnections(maxConnections)
.connectionTimeout(connectionTimeout)
.readTimeout(readTimeout)
.connectionTimeoutMillis(connectionTimeout)
.readTimeoutSeconds(readTimeout)
.retryBackoffMillis(retryBackoffMillis)
.retryTimeoutSeconds(retryTimeoutSeconds)
.maxRetryTimes(maxRetryTimes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ public void writeTo_ReadFromStream() throws IOException {
ConnectorClientConfig config = ConnectorClientConfig
.builder()
.maxConnections(10)
.connectionTimeout(5000)
.readTimeout(3000)
.connectionTimeoutMillis(1000)
.readTimeoutSeconds(3)
.retryBackoffMillis(123)
.retryTimeoutSeconds(456)
.maxRetryTimes(789)
Expand Down Expand Up @@ -55,8 +55,8 @@ public void writeTo_ReadFromStream_diffVersionThenNotProcessRetryOptions() throw
ConnectorClientConfig config = ConnectorClientConfig
.builder()
.maxConnections(10)
.connectionTimeout(5000)
.readTimeout(3000)
.connectionTimeoutMillis(1000)
.readTimeoutSeconds(3)
.retryBackoffMillis(123)
.retryTimeoutSeconds(456)
.maxRetryTimes(789)
Expand All @@ -71,8 +71,8 @@ public void writeTo_ReadFromStream_diffVersionThenNotProcessRetryOptions() throw
ConnectorClientConfig readConfig = ConnectorClientConfig.fromStream(input);

Assert.assertEquals(Integer.valueOf(10), readConfig.getMaxConnections());
Assert.assertEquals(Integer.valueOf(5000), readConfig.getConnectionTimeout());
Assert.assertEquals(Integer.valueOf(3000), readConfig.getReadTimeout());
Assert.assertEquals(Integer.valueOf(1000), readConfig.getConnectionTimeoutMillis());
Assert.assertEquals(Integer.valueOf(3), readConfig.getReadTimeoutSeconds());
Assert.assertNull(readConfig.getRetryBackoffMillis());
Assert.assertNull(readConfig.getRetryTimeoutSeconds());
Assert.assertNull(readConfig.getMaxRetryTimes());
Expand All @@ -84,8 +84,8 @@ public void toXContent() throws IOException {
ConnectorClientConfig config = ConnectorClientConfig
.builder()
.maxConnections(10)
.connectionTimeout(5000)
.readTimeout(3000)
.connectionTimeoutMillis(1000)
.readTimeoutSeconds(3)
.retryBackoffMillis(123)
.retryTimeoutSeconds(456)
.maxRetryTimes(789)
Expand All @@ -96,14 +96,14 @@ public void toXContent() throws IOException {
config.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);

String expectedJson = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000,"
String expectedJson = "{\"max_connection\":10,\"connection_timeout\":1000,\"read_timeout\":3,"
+ "\"retry_backoff_millis\":123,\"retry_timeout_seconds\":456,\"max_retry_times\":789,\"retry_backoff_policy\":\"constant\"}";
Assert.assertEquals(expectedJson, content);
}

@Test
public void parse() throws IOException {
String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000,"
String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3,"
+ "\"retry_backoff_millis\":123,\"retry_timeout_seconds\":456,\"max_retry_times\":789,\"retry_backoff_policy\":\"constant\"}";
XContentParser parser = XContentType.JSON
.xContent()
Expand All @@ -117,8 +117,8 @@ public void parse() throws IOException {
ConnectorClientConfig config = ConnectorClientConfig.parse(parser);

Assert.assertEquals(Integer.valueOf(10), config.getMaxConnections());
Assert.assertEquals(Integer.valueOf(5000), config.getConnectionTimeout());
Assert.assertEquals(Integer.valueOf(3000), config.getReadTimeout());
Assert.assertEquals(Integer.valueOf(5000), config.getConnectionTimeoutMillis());
Assert.assertEquals(Integer.valueOf(3), config.getReadTimeoutSeconds());
Assert.assertEquals(Integer.valueOf(123), config.getRetryBackoffMillis());
Assert.assertEquals(Integer.valueOf(456), config.getRetryTimeoutSeconds());
Assert.assertEquals(Integer.valueOf(789), config.getMaxRetryTimes());
Expand All @@ -127,7 +127,7 @@ public void parse() throws IOException {

@Test
public void parse_whenMalformedBackoffPolicy_thenFail() throws IOException {
String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000,"
String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3,"
+ "\"retry_backoff_millis\":123,\"retry_timeout_seconds\":456,\"max_retry_times\":789,\"retry_backoff_policy\":\"test\"}";
XContentParser parser = XContentType.JSON
.xContent()
Expand All @@ -147,8 +147,8 @@ public void testDefaultValues() {
ConnectorClientConfig config = ConnectorClientConfig.builder().build();

Assert.assertNull(config.getMaxConnections());
Assert.assertNull(config.getConnectionTimeout());
Assert.assertNull(config.getReadTimeout());
Assert.assertNull(config.getConnectionTimeoutMillis());
Assert.assertNull(config.getReadTimeoutSeconds());
Assert.assertNull(config.getRetryBackoffMillis());
Assert.assertNull(config.getRetryTimeoutSeconds());
Assert.assertNull(config.getMaxRetryTimes());
Expand All @@ -160,8 +160,8 @@ public void testDefaultValuesInitByNewInstance() {
ConnectorClientConfig config = new ConnectorClientConfig();

Assert.assertEquals(Integer.valueOf(30), config.getMaxConnections());
Assert.assertEquals(Integer.valueOf(30000), config.getConnectionTimeout());
Assert.assertEquals(Integer.valueOf(30000), config.getReadTimeout());
Assert.assertEquals(Integer.valueOf(1000), config.getConnectionTimeoutMillis());
Assert.assertEquals(Integer.valueOf(10), config.getReadTimeoutSeconds());
Assert.assertEquals(Integer.valueOf(200), config.getRetryBackoffMillis());
Assert.assertEquals(Integer.valueOf(30), config.getRetryTimeoutSeconds());
Assert.assertEquals(Integer.valueOf(0), config.getMaxRetryTimes());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,8 @@ public void testParse() throws Exception {
testParseFromJsonString(expectedInputStr, parsedInput -> {
assertEquals(TEST_CONNECTOR_NAME, parsedInput.getName());
assertEquals(20, parsedInput.getConnectorClientConfig().getMaxConnections().intValue());
assertEquals(10000, parsedInput.getConnectorClientConfig().getReadTimeout().intValue());
assertEquals(10000, parsedInput.getConnectorClientConfig().getConnectionTimeout().intValue());
assertEquals(10000, parsedInput.getConnectorClientConfig().getReadTimeoutSeconds().intValue());
assertEquals(10000, parsedInput.getConnectorClientConfig().getConnectionTimeoutMillis().intValue());
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,73 @@

package org.opensearch.ml.engine.algorithms.remote;

import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;

import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorClientConfig;
import org.opensearch.ml.common.httpclient.MLHttpClientFactory;

import lombok.Getter;
import lombok.Setter;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;

@Setter
@Getter
public abstract class AbstractConnectorExecutor implements RemoteConnectorExecutor {
private ConnectorClientConfig connectorClientConfig;

@Setter
private volatile boolean connectorPrivateIpEnabled;

private volatile AtomicReference<SdkAsyncHttpClient> httpClientRef = new AtomicReference<>();

private ConnectorClientConfig connectorClientConfig = new ConnectorClientConfig();

public void initialize(Connector connector) {
if (connector.getConnectorClientConfig() != null) {
connectorClientConfig = connector.getConnectorClientConfig();
} else {
connectorClientConfig = new ConnectorClientConfig();
}
}

protected SdkAsyncHttpClient getHttpClient() {
// This block for high performance retrieval after http client is created.
SdkAsyncHttpClient existingClient = httpClientRef.get();
if (existingClient != null) {
return existingClient;
}
// This block handles concurrent http client creation.
synchronized (this) {
existingClient = httpClientRef.get();
if (existingClient != null) {
return existingClient;
}
Duration connectionTimeout = Duration
.ofMillis(
Optional
.ofNullable(connectorClientConfig.getConnectionTimeoutMillis())
.orElse(ConnectorClientConfig.CONNECTION_TIMEOUT_DEFAULT_VALUE)
);
Duration readTimeout = Duration
.ofSeconds(
Optional
.ofNullable(connectorClientConfig.getReadTimeoutSeconds())
.orElse(ConnectorClientConfig.READ_TIMEOUT_DEFAULT_VALUE)
);
int maxConnection = Optional
.ofNullable(connectorClientConfig.getMaxConnections())
.orElse(ConnectorClientConfig.MAX_CONNECTION_DEFAULT_VALUE);
SdkAsyncHttpClient newClient = MLHttpClientFactory
.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled);
httpClientRef.set(newClient);
return newClient;
}
}

public void close() {
SdkAsyncHttpClient httpClient = httpClientRef.getAndSet(null);
if (httpClient != null) {
httpClient.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.time.Duration;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;

import org.apache.commons.text.StringEscapeUtils;
import org.apache.logging.log4j.Logger;
Expand All @@ -28,7 +26,6 @@
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.httpclient.MLHttpClientFactory;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensors;
Expand All @@ -41,15 +38,12 @@
import org.opensearch.transport.StreamTransportService;
import org.opensearch.transport.client.Client;

import com.google.common.annotations.VisibleForTesting;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.async.AsyncExecuteRequest;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;

@Log4j2
@ConnectorExecutor(AWS_SIGV4)
Expand All @@ -73,15 +67,10 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor {
@Getter
private MLGuard mlGuard;

private final AtomicReference<SdkAsyncHttpClient> httpClientRef = new AtomicReference<>();

@Setter
@Getter
private StreamTransportService streamTransportService;

@Setter
private boolean connectorPrivateIpEnabled;

public AwsConnectorExecutor(Connector connector) {
super.initialize(connector);
this.connector = (AwsConnector) connector;
Expand Down Expand Up @@ -183,19 +172,4 @@ private void validateLLMInterface(String llmInterface) {
throw new IllegalArgumentException(String.format("Unsupported llm interface: %s", llmInterface));
}
}

@VisibleForTesting
protected SdkAsyncHttpClient getHttpClient() {
if (httpClientRef.get() == null) {
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
this.httpClientRef
.compareAndSet(
null,
MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled)
);
}
return httpClientRef.get();
}
}
Loading
Loading