Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
Expand All @@ -26,6 +23,7 @@
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLOutputType;
import org.opensearch.secure_sm.AccessController;
import org.reflections.Reflections;

import com.fasterxml.jackson.core.JsonParseException;
Expand All @@ -43,14 +41,7 @@ public class MLCommonsClassLoader {
private static Map<String, Class<?>> connectorClassMap = new HashMap<>();

static {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
loadClassMapping();
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException("Can't load class mapping in ML commons", e);
}
AccessController.doPrivileged(MLCommonsClassLoader::loadClassMapping);
}

public static void loadClassMapping() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.time.Instant;
import java.util.List;
import java.util.Map;
Expand All @@ -35,6 +32,7 @@
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.secure_sm.AccessController;

/**
* Connector defines how to connect to a remote service.
Expand Down Expand Up @@ -129,15 +127,9 @@ static Connector createConnector(XContentBuilder builder, String connectorProtoc
}
}

@SuppressWarnings("removal")
static Connector createConnector(XContentParser parser) throws IOException {
Map<String, Object> connectorMap = parser.map();
String jsonStr;
try {
jsonStr = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(connectorMap));
} catch (PrivilegedActionException e) {
throw new IllegalArgumentException("wrong connector");
}
String jsonStr = AccessController.doPrivileged(() -> gson.toJson(connectorMap));
String connectorProtocol = (String) connectorMap.get("protocol");

return createConnector(jsonStr, connectorProtocol);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
Expand All @@ -38,6 +36,7 @@
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.secure_sm.AccessController;
import org.opensearch.transport.client.Client;

import lombok.Builder;
Expand Down Expand Up @@ -225,7 +224,7 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
Map<String, Object> queryBodyMap = Map.of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap)));
CountDownLatch latch = new CountDownLatch(1);
try {
queryBody = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(queryBodyMap));
queryBody = AccessController.doPrivileged(() -> gson.toJson(queryBodyMap));
SearchDataObjectRequest searchDataObjectRequest = buildSearchDataObjectRequest(indexName, queryBody);
var responseListener = new LatchedActionListener<>(ActionListener.<SearchResponse>wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value() == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import static org.opensearch.ml.common.utils.StringUtils.gson;

import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
Expand All @@ -38,6 +36,7 @@
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.secure_sm.AccessController;
import org.opensearch.transport.client.Client;

import lombok.Builder;
Expand Down Expand Up @@ -107,8 +106,7 @@ public Boolean validate(String in, Map<String, String> parameters) {
ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(predictionResponse -> {
ModelTensorOutput output = (ModelTensorOutput) predictionResponse.getOutput();
ModelTensor tensor = output.getMlModelOutputs().get(0).getMlModelTensors().get(0);
String guardrailResponse = AccessController
.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(tensor.getDataAsMap().get("response")));
String guardrailResponse = AccessController.doPrivileged(() -> gson.toJson(tensor.getDataAsMap().get("response")));
log.info("Guardrail response: {}", guardrailResponse);
if (!validateAcceptRegex(guardrailResponse)) {
isAccepted.set(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand All @@ -26,6 +23,7 @@
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.secure_sm.AccessController;

import lombok.Builder;
import lombok.Data;
Expand Down Expand Up @@ -278,14 +276,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(result);
if (dataAsMap != null) {
out.writeBoolean(true);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
out.writeString(gson.toJson(dataAsMap));
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
AccessController.doPrivilegedChecked(() -> { out.writeString(gson.toJson(dataAsMap)); });
} else {
out.writeBoolean(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
Expand All @@ -38,6 +35,7 @@
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.secure_sm.AccessController;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
Expand Down Expand Up @@ -248,79 +246,56 @@ public static Map<String, String> filteredParameterMap(Map<String, ?> parameterO
filteredKeys.retainAll(allowedList);
for (String key : filteredKeys) {
Object value = parameterObjs.get(key);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
if (value instanceof String) {
parameters.put(key, (String) value);
} else {
parameters.put(key, gson.toJson(value));
}
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
AccessController.doPrivileged(() -> {
if (value instanceof String) {
parameters.put(key, (String) value);
} else {
parameters.put(key, gson.toJson(value));
}
});
}
return parameters;
}

@SuppressWarnings("removal")
public static Map<String, String> getParameterMap(Map<String, ?> parameterObjs) {
Map<String, String> parameters = new HashMap<>();
if (parameterObjs == null)
return parameters;
for (String key : parameterObjs.keySet()) {
Object value = parameterObjs.get(key);
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
if (value instanceof String) {
parameters.put(key, (String) value);
} else {
parameters.put(key, gson.toJson(value));
}
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
}
return parameters;
}

@SuppressWarnings("removal")
public static String toJson(Object value) {
try {
return AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> {
AccessController.doPrivileged(() -> {
if (value instanceof String) {
return (String) value;
parameters.put(key, (String) value);
} else {
return gson.toJson(value);
parameters.put(key, gson.toJson(value));
}
});
} catch (PrivilegedActionException e) {
throw new RuntimeException(e);
}
return parameters;
}

public static String toJson(Object value) {
return AccessController.doPrivileged(() -> {
if (value instanceof String) {
return (String) value;
} else {
return gson.toJson(value);
}
});
}

@SuppressWarnings("removal")
public static Map<String, String> convertScriptStringToJsonString(Map<String, Object> processedInput) {
Map<String, String> parameterStringMap = new HashMap<>();
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
Map<String, Object> parametersMap = (Map<String, Object>) processedInput.getOrDefault("parameters", Map.of());
for (String key : parametersMap.keySet()) {
if (parametersMap.get(key) instanceof String) {
parameterStringMap.put(key, (String) parametersMap.get(key));
} else {
parameterStringMap.put(key, gson.toJson(parametersMap.get(key)));
}
AccessController.doPrivileged(() -> {
Map<String, Object> parametersMap = (Map<String, Object>) processedInput.getOrDefault("parameters", Map.of());
for (String key : parametersMap.keySet()) {
if (parametersMap.get(key) instanceof String) {
parameterStringMap.put(key, (String) parametersMap.get(key));
} else {
parameterStringMap.put(key, gson.toJson(parametersMap.get(key)));
}
return null;
});
} catch (PrivilegedActionException e) {
log.error("Error processing parameters", e);
throw new RuntimeException(e);
}
}
});
return parameterStringMap;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
package org.opensearch.ml.engine;

import java.lang.reflect.Constructor;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
Expand All @@ -23,6 +20,7 @@
import org.opensearch.ml.engine.annotation.Ingester;
import org.opensearch.ml.engine.annotation.Processor;
import org.opensearch.ml.engine.processor.MLProcessorType;
import org.opensearch.secure_sm.AccessController;
import org.reflections.Reflections;

@SuppressWarnings("removal")
Expand All @@ -43,16 +41,11 @@ public class MLEngineClassLoader {
private static Map<Enum<?>, Object> mlObjects = new HashMap<>();

static {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
loadClassMapping();
loadIngestClassMapping();
loadMLProcessorClassMapping();
return null;
});
} catch (PrivilegedActionException e) {
throw new RuntimeException("Can't load class mapping in ML engine", e);
}
AccessController.doPrivileged(() -> {
loadClassMapping();
loadIngestClassMapping();
loadMLProcessorClassMapping();
});
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
import java.io.FileReader;
import java.io.IOException;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
Expand All @@ -34,6 +31,7 @@
import org.opensearch.ml.common.model.QuestionAnsweringModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.secure_sm.AccessController;

import com.google.gson.stream.JsonReader;

Expand All @@ -58,7 +56,6 @@ public ModelHelper(MLEngine mlEngine) {
this.mlEngine = mlEngine;
}

@SuppressWarnings("removal")
public void downloadPrebuiltModelConfig(
String taskId,
MLRegisterModelInput registerModelInput,
Expand All @@ -74,7 +71,7 @@ public void downloadPrebuiltModelConfig(
String modelGroupId = registerModelInput.getModelGroupId();
MLDeploySetting mlDeploySetting = registerModelInput.getDeploySetting();
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
AccessController.doPrivilegedChecked(() -> {

Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);
String configCacheFilePath = registerModelPath.resolve("config.json").toString();
Expand Down Expand Up @@ -223,12 +220,11 @@ public boolean isModelAllowed(MLRegisterModelInput registerModelInput, List mode
return false;
}

@SuppressWarnings("removal")
public List downloadPrebuiltModelMetaList(String taskId, MLRegisterModelInput registerModelInput) throws PrivilegedActionException {
public List downloadPrebuiltModelMetaList(String taskId, MLRegisterModelInput registerModelInput) throws IOException {
String modelName = registerModelInput.getModelName();
String version = registerModelInput.getVersion();
try {
return AccessController.doPrivileged((PrivilegedExceptionAction<List>) () -> {
return AccessController.doPrivilegedChecked(() -> {

Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);
String cacheFilePath = registerModelPath.resolve("model_meta_list.json").toString();
Expand Down Expand Up @@ -257,7 +253,6 @@ public List downloadPrebuiltModelMetaList(String taskId, MLRegisterModelInput re
* @param modelContentHash model content hash value
* @param listener action listener
*/
@SuppressWarnings("removal")
public void downloadAndSplit(
MLModelFormat modelFormat,
String taskId,
Expand All @@ -269,7 +264,7 @@ public void downloadAndSplit(
ActionListener<Map<String, Object>> listener
) {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
AccessController.doPrivilegedChecked(() -> {
Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);
String modelPath = registerModelPath + ".zip";
Path modelPartsPath = registerModelPath.resolve("chunks");
Expand Down
Loading
Loading