diff --git a/common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java b/common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java index 59831f4ae5..4a6fe53520 100644 --- a/common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java +++ b/common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java @@ -10,8 +10,10 @@ import java.time.Duration; import lombok.extern.log4j.Log4j2; +import software.amazon.awssdk.http.SdkHttpConfigurationOption; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.utils.AttributeMap; @Log4j2 public class MLHttpClientFactory { @@ -20,22 +22,33 @@ public static SdkAsyncHttpClient getAsyncHttpClient( Duration connectionTimeout, Duration readTimeout, int maxConnections, - boolean connectorPrivateIpEnabled + boolean connectorPrivateIpEnabled, + boolean skipSslVerification ) { return doPrivileged(() -> { + if (skipSslVerification) { + log + .warn( + "SSL certificate verification is DISABLED. This connection is vulnerable to man-in-the-middle" + + " attacks. Only use this setting in trusted environments." + ); + } log .debug( - "Creating MLHttpClient with connectionTimeout: {}, readTimeout: {}, maxConnections: {}", + "Creating MLHttpClient with connectionTimeout: {}, readTimeout: {}, maxConnections: {}," + " skipSslVerification: {}", connectionTimeout, readTimeout, - maxConnections + maxConnections, + skipSslVerification ); SdkAsyncHttpClient delegate = NettyNioAsyncHttpClient .builder() .connectionTimeout(connectionTimeout) .readTimeout(readTimeout) .maxConcurrency(maxConnections) - .build(); + .buildWithDefaults( + AttributeMap.builder().put(SdkHttpConfigurationOption.TRUST_ALL_CERTIFICATES, skipSslVerification).build() + ); return new MLValidatableAsyncHttpClient(delegate, connectorPrivateIpEnabled); }); } diff --git a/common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java b/common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java index d0664cca3a..75c3e9846d 100644 --- a/common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java +++ b/common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java @@ -17,7 +17,10 @@ public class MLHttpClientFactoryTests { @Test public void test_getSdkAsyncHttpClient_success() { - SdkAsyncHttpClient client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100, false); + SdkAsyncHttpClient client = MLHttpClientFactory + .getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100, false, false); + assertNotNull(client); + client = MLHttpClientFactory.getAsyncHttpClient(Duration.ofSeconds(100), Duration.ofSeconds(100), 100, false, true); assertNotNull(client); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 3b53935aaf..11911f18a2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -190,10 +190,15 @@ protected SdkAsyncHttpClient getHttpClient() { Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); + boolean skipSslVerification = false; + if (connector.getParameters() != null && connector.getParameters().containsKey(SKIP_SSL_VERIFICATION)) { + skipSslVerification = Boolean.parseBoolean(connector.getParameters().get(SKIP_SSL_VERIFICATION)); + } this.httpClientRef .compareAndSet( null, - MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled) + MLHttpClientFactory + .getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled, skipSslVerification) ); } return httpClientRef.get(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 7804770258..7e796afc70 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -180,10 +180,15 @@ protected SdkAsyncHttpClient getHttpClient() { Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); + boolean skipSslVerification = false; + if (connector.getParameters() != null && connector.getParameters().containsKey(SKIP_SSL_VERIFICATION)) { + skipSslVerification = Boolean.parseBoolean(connector.getParameters().get(SKIP_SSL_VERIFICATION)); + } this.httpClientRef .compareAndSet( null, - MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled) + MLHttpClientFactory + .getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled, skipSslVerification) ); } return httpClientRef.get(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 5181e8d087..6fb1fb62ad 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -65,6 +65,7 @@ public interface RemoteConnectorExecutor { public String RETRY_EXECUTOR = "opensearch_ml_predict_remote"; + String SKIP_SSL_VERIFICATION = "skip_ssl_verification"; default void executeAction(String action, MLInput mlInput, ActionListener actionListener) { executeAction(action, mlInput, actionListener, null);