Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ private MLCommonsSettings() {}
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
/**
 * Maximum number of (query, document) pairs sent to the text-similarity
 * cross-encoder in a single batch inference call.
 * Dynamic, node-scoped setting; default 500, bounded to the range [1, 1000].
 */
public static final Setting<Integer> ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE = Setting
.intSetting(
ML_PLUGIN_SETTING_PREFIX + "text_similarity_batch_size",
500,
1,
1000,
Setting.Property.NodeScope,
Setting.Property.Dynamic
);
public static final Setting<Integer> ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE = Setting
.intSetting(
ML_PLUGIN_SETTING_PREFIX + "max_deploy_model_tasks_per_node",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.util.Locale;
import java.util.Map;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
Expand Down Expand Up @@ -51,11 +52,24 @@ public class MLEngine {

private Encryptor encryptor;

@Getter
private ClusterService clusterService;

/**
 * Creates the engine rooted at the node's data folder.
 *
 * @param opensearchDataFolder node data path; the ml_cache, models_cache and
 *                             config directories are resolved under it
 * @param encryptor            encryptor used for model credentials
 * @param clusterService       cluster service used to read dynamic settings;
 *                             may be {@code null} when constructed via the
 *                             deprecated two-argument constructor
 */
public MLEngine(Path opensearchDataFolder, Encryptor encryptor, ClusterService clusterService) {
    this.mlCachePath = opensearchDataFolder.resolve("ml_cache");
    this.mlModelsCachePath = mlCachePath.resolve("models_cache");
    this.mlConfigPath = mlCachePath.resolve("config");
    this.encryptor = encryptor;
    this.clusterService = clusterService;
}

/**
 * @deprecated Retained for backward compatibility. Scheduled for removal. <br/>
 *             Use {@link #MLEngine(Path, Encryptor, ClusterService)} instead.
 */
@Deprecated(forRemoval = true)
public MLEngine(Path opensearchDataFolder, Encryptor encryptor) {
    // Delegate with a null ClusterService; DLModel#getClusterSettings falls back
    // to Settings.EMPTY in that case.
    this(opensearchDataFolder, encryptor, null);
}

public String getPrebuiltModelMetaListPath() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.commons.io.FileUtils;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLException;
Expand Down Expand Up @@ -342,4 +343,10 @@ public ModelTensors parseModelTensorOutput(Output output, ModelResultFilter resu
return tensorOutput;
}

public Settings getClusterSettings() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We usually add all our settings here. Can you please check how other settings are propagated and follow accordingly?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will check it. Thank you

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a question. Would you like the Batch Processing to be enabled or disabled with a boolean feature from MLFeatureEnabledSetting?
The Batch Size, however, is passed from the MLCommonsSettings.java

if (mlEngine.getClusterService() != null) {
return mlEngine.getClusterService().getSettings();
}
return Settings.EMPTY;
}
Comment on lines +346 to +351
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Guard against potential null mlEngine in getClusterSettings()

As written, calling getClusterSettings() before initModel assigns mlEngine will throw an NPE on mlEngine.getClusterService(). Suggest a small defensive check:

-    public Settings getClusterSettings() {
-        if (mlEngine.getClusterService() != null) {
-            return mlEngine.getClusterService().getSettings();
-        }
-        return Settings.EMPTY;
-    }
+    public Settings getClusterSettings() {
+        if (mlEngine != null && mlEngine.getClusterService() != null) {
+            return mlEngine.getClusterService().getSettings();
+        }
+        return Settings.EMPTY;
+    }

[scratchpad_start]
[done]
DLModel.java reviewed.
[scratchpad_end] -->

🤖 Prompt for AI Agents
In ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java
around lines 346 to 351, the method getClusterSettings() can NPE if mlEngine is
null; add a defensive null check for mlEngine and for
mlEngine.getClusterService() before calling getSettings(), and return
Settings.EMPTY when either is null so the method never dereferences a null
pointer.

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
import java.util.ArrayList;
import java.util.List;

import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.engine.algorithms.DLModel;
import org.opensearch.ml.engine.annotation.Function;

Expand All @@ -43,16 +45,31 @@ public class TextSimilarityCrossEncoderModel extends DLModel {
public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
MLInputDataset inputDataSet = mlInput.getInputDataset();
List<ModelTensors> tensorOutputs = new ArrayList<>();
Output output;
TextSimilarityInputDataSet textSimInput = (TextSimilarityInputDataSet) inputDataSet;
String queryText = textSimInput.getQueryText();
for (String doc : textSimInput.getTextDocs()) {
Input input = new Input();
input.add(queryText);
input.add(doc);
output = getPredictor().predict(input);
ModelTensors outputTensors = ModelTensors.fromBytes(output.getData().getAsBytes());
tensorOutputs.add(outputTensors);
List<String> textDocs = textSimInput.getTextDocs();

Settings clusterSettings = getClusterSettings();
final int batchSize = MLCommonsSettings.ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE.get(clusterSettings);

for (int i = 0; i < textDocs.size(); i += batchSize) {
int endIndex = Math.min(i + batchSize, textDocs.size());
List<String> batchDocs = textDocs.subList(i, endIndex);
List<Input> batchInputs = new ArrayList<>(batchDocs.size());

for (String doc : batchDocs) {
Input input = new Input();
input.add(queryText);
input.add(doc);
batchInputs.add(input);
}

List<Output> batchOutputs = getPredictor().batchPredict(batchInputs);

for (Output output : batchOutputs) {
ModelTensors outputTensors = ModelTensors.fromBytes(output.getData().getAsBytes());
tensorOutputs.add(outputTensors);
}
}
return new ModelTensorOutput(tensorOutputs);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.PairList;

public class TextSimilarityTranslator extends SentenceTransformerTranslator {
public final String SIMILARITY_NAME = "similarity";
Expand Down Expand Up @@ -95,4 +96,73 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
return output;
}

/**
 * Tokenizes a batch of (query, document) pairs into the three rectangular
 * model-input tensors: input_ids, attention_mask and token_type_ids.
 * Rows shorter than the longest sequence in the batch are zero-padded so the
 * 2-D arrays are never ragged.
 */
@Override
public NDList batchProcessInput(TranslatorContext ctx, List<Input> inputs) {
    NDManager manager = ctx.getNDManager();
    int batchSize = inputs.size();
    List<String> sentences = new ArrayList<>(batchSize);
    List<String> contexts = new ArrayList<>(batchSize);
    for (Input input : inputs) {
        // Each Input carries the query at index 0 and the document at index 1.
        sentences.add(input.getAsString(0));
        contexts.add(input.getAsString(1));
    }
    // Tokenize all pairs in one call.
    Encoding[] encodings = tokenizer.batchEncode(new PairList<>(sentences, contexts));
    int seqLen = encodings[0].getIds().length;
    for (Encoding enc : encodings) {
        seqLen = Math.max(seqLen, enc.getIds().length);
    }
    long[][] inputIds = new long[batchSize][seqLen];
    long[][] attentionMasks = new long[batchSize][seqLen];
    long[][] tokenTypeIds = new long[batchSize][seqLen];
    for (int i = 0; i < batchSize; i++) {
        long[] ids = encodings[i].getIds();
        long[] masks = encodings[i].getAttentionMask();
        long[] types = encodings[i].getTypeIds();
        // Copy into the preallocated fixed-length row instead of replacing it,
        // so trailing slots stay 0 (pad id / mask 0) and rows are never ragged.
        System.arraycopy(ids, 0, inputIds[i], 0, ids.length);
        System.arraycopy(masks, 0, attentionMasks[i], 0, masks.length);
        System.arraycopy(types, 0, tokenTypeIds[i], 0, types.length);
    }
    NDArray inputIdsArray = manager.create(inputIds);
    inputIdsArray.setName("input_ids");
    NDArray attentionMaskArray = manager.create(attentionMasks);
    attentionMaskArray.setName("attention_mask");
    NDArray tokenTypeArray = manager.create(tokenTypeIds);
    tokenTypeArray.setName("token_type_ids");
    NDList ndList = new NDList();
    ndList.add(inputIdsArray);
    ndList.add(attentionMaskArray);
    ndList.add(tokenTypeArray);
    return ndList;
}
Comment on lines +99 to +136
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fix 2D array construction in batchProcessInput

You compute seqLen and allocate [batchSize][seqLen], but then overwrite each row with encodings[i].getIds() / masks / types. That makes the initial allocation and seqLen effectively unused and can produce ragged arrays.

Recommend keeping the fixed row length and copying values instead:

-        int seqLen = encodings[0].getIds().length;
-        for (Encoding enc : encodings) {
-            seqLen = Math.max(seqLen, enc.getIds().length);
-        }
-        long[][] inputIds = new long[batchSize][seqLen];
-        long[][] attentionMasks = new long[batchSize][seqLen];
-        long[][] tokenTypeIds = new long[batchSize][seqLen];
-        for (int i = 0; i < batchSize; i++) {
-            inputIds[i] = encodings[i].getIds();
-            attentionMasks[i] = encodings[i].getAttentionMask();
-            tokenTypeIds[i] = encodings[i].getTypeIds();
-        }
+        int seqLen = encodings[0].getIds().length;
+        for (Encoding enc : encodings) {
+            seqLen = Math.max(seqLen, enc.getIds().length);
+        }
+
+        long[][] inputIds = new long[batchSize][seqLen];
+        long[][] attentionMasks = new long[batchSize][seqLen];
+        long[][] tokenTypeIds = new long[batchSize][seqLen];
+        for (int i = 0; i < batchSize; i++) {
+            long[] ids = encodings[i].getIds();
+            long[] masks = encodings[i].getAttentionMask();
+            long[] types = encodings[i].getTypeIds();
+            System.arraycopy(ids, 0, inputIds[i], 0, ids.length);
+            System.arraycopy(masks, 0, attentionMasks[i], 0, masks.length);
+            System.arraycopy(types, 0, tokenTypeIds[i], 0, types.length);
+        }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@Override
public NDList batchProcessInput(TranslatorContext ctx, List<Input> inputs) {
NDManager manager = ctx.getNDManager();
int batchSize = inputs.size();
List<String> sentences = new ArrayList<>(batchSize);
List<String> contexts = new ArrayList<>(batchSize);
for (Input input : inputs) {
String sentence = input.getAsString(0);
String context = input.getAsString(1);
sentences.add(sentence);
contexts.add(context);
}
// Tokenize in batches
Encoding[] encodings = tokenizer.batchEncode(new PairList<>(sentences, contexts));
int seqLen = encodings[0].getIds().length;
for (Encoding enc : encodings) {
seqLen = Math.max(seqLen, enc.getIds().length);
}
long[][] inputIds = new long[batchSize][seqLen];
long[][] attentionMasks = new long[batchSize][seqLen];
long[][] tokenTypeIds = new long[batchSize][seqLen];
for (int i = 0; i < batchSize; i++) {
inputIds[i] = encodings[i].getIds();
attentionMasks[i] = encodings[i].getAttentionMask();
tokenTypeIds[i] = encodings[i].getTypeIds();
}
NDArray inputIdsArray = manager.create(inputIds);
inputIdsArray.setName("input_ids");
NDArray attentionMaskArray = manager.create(attentionMasks);
attentionMaskArray.setName("attention_mask");
NDArray tokenTypeArray = manager.create(tokenTypeIds);
tokenTypeArray.setName("token_type_ids");
NDList ndList = new NDList();
ndList.add(inputIdsArray);
ndList.add(attentionMaskArray);
ndList.add(tokenTypeArray);
return ndList;
}
@Override
public NDList batchProcessInput(TranslatorContext ctx, List<Input> inputs) {
NDManager manager = ctx.getNDManager();
int batchSize = inputs.size();
List<String> sentences = new ArrayList<>(batchSize);
List<String> contexts = new ArrayList<>(batchSize);
for (Input input : inputs) {
String sentence = input.getAsString(0);
String context = input.getAsString(1);
sentences.add(sentence);
contexts.add(context);
}
// Tokenize in batches
Encoding[] encodings = tokenizer.batchEncode(new PairList<>(sentences, contexts));
int seqLen = encodings[0].getIds().length;
for (Encoding enc : encodings) {
seqLen = Math.max(seqLen, enc.getIds().length);
}
long[][] inputIds = new long[batchSize][seqLen];
long[][] attentionMasks = new long[batchSize][seqLen];
long[][] tokenTypeIds = new long[batchSize][seqLen];
for (int i = 0; i < batchSize; i++) {
long[] ids = encodings[i].getIds();
long[] masks = encodings[i].getAttentionMask();
long[] types = encodings[i].getTypeIds();
System.arraycopy(ids, 0, inputIds[i], 0, ids.length);
System.arraycopy(masks, 0, attentionMasks[i], 0, masks.length);
System.arraycopy(types, 0, tokenTypeIds[i], 0, types.length);
}
NDArray inputIdsArray = manager.create(inputIds);
inputIdsArray.setName("input_ids");
NDArray attentionMaskArray = manager.create(attentionMasks);
attentionMaskArray.setName("attention_mask");
NDArray tokenTypeArray = manager.create(tokenTypeIds);
tokenTypeArray.setName("token_type_ids");
NDList ndList = new NDList();
ndList.add(inputIdsArray);
ndList.add(attentionMaskArray);
ndList.add(tokenTypeArray);
return ndList;
}
🤖 Prompt for AI Agents
In
ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java
around lines 99 to 136, the code allocates fixed-size 2D arrays
[batchSize][seqLen] then replaces entire rows with encodings[i].getIds() etc.,
which discards the fixed length and creates ragged arrays; instead, copy each
encoding into the preallocated row and pad remaining elements to the fixed
seqLen (e.g., with 0 for input_ids and token_type_ids, and 0 for attention_mask)
so every row has length seqLen, or throw if an encoding exceeds seqLen;
implement a loop that for each i copies enc.getIds(), enc.getAttentionMask(),
enc.getTypeIds() into inputIds[i][0..len-1] etc., leaving the rest as zeros
before creating NDArrays.


/**
 * Splits the batched score tensor (first NDArray in {@code list}, leading
 * dimension = batch size) back into one {@link Output} per input pair, each
 * wrapping a single "similarity" ModelTensor serialized as bytes.
 */
@Override
public List<Output> batchProcessOutput(TranslatorContext ctx, NDList list) {
    NDArray batched = list.getFirst();
    int count = (int) batched.getShape().get(0);
    List<Output> results = new ArrayList<>(count);
    for (int idx = 0; idx < count; idx++) {
        NDArray item = batched.get(idx);

        ModelTensor tensor = ModelTensor
            .builder()
            .name(SIMILARITY_NAME)
            .data(item.toArray())
            .shape(item.getShape().getShape())
            .dataType(MLResultDataType.valueOf(item.getDataType().name()))
            .byteBuffer(item.toByteBuffer())
            .build();

        Output out = new Output(200, "OK");
        out.add(new ModelTensors(List.of(tensor)).toBytes());
        results.add(out);
    }
    return results;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -173,6 +175,80 @@ public void test_TextSimilarity_Translator_ProcessOutput() throws URISyntaxExcep
assertEquals(1, data.length);
}

// Verifies batchProcessInput produces three NDArrays (input_ids, attention_mask,
// token_type_ids) for a single (query, doc) pair routed through the real tokenizer.
@Test
public void test_TextSimilarity_Translator_BatchProcessInput() throws URISyntaxException, IOException {
TextSimilarityTranslator textSimilarityTranslator = new TextSimilarityTranslator();
TranslatorContext translatorContext = mock(TranslatorContext.class);
Model mlModel = mock(Model.class);
when(translatorContext.getModel()).thenReturn(mlModel);
// Real tokenizer.json fixture on the classpath; prepare() loads the tokenizer from the model path.
when(mlModel.getModelPath()).thenReturn(Paths.get(getClass().getResource("../tokenize/tokenizer.json").toURI()).getParent());
textSimilarityTranslator.prepare(translatorContext);

NDManager manager = mock(NDManager.class);
when(translatorContext.getNDManager()).thenReturn(manager);
Input input = mock(Input.class);
String testSentence = "hello world";
when(input.getAsString(0)).thenReturn(testSentence);
when(input.getAsString(1)).thenReturn(testSentence);
NDArray indiceNdArray = mock(NDArray.class);
// The stubbed NDManager.create returns this same mock for all three tensors,
// so every toLongArray() below yields the fixed 2-element payload.
when(indiceNdArray.toLongArray()).thenReturn(new long[] { 102l, 101l });
when(manager.create((long[][]) any())).thenReturn(indiceNdArray);
doNothing().when(indiceNdArray).setName(any());
List<Input> inputList = new ArrayList<>(1);
inputList.add(input);
NDList outputList = textSimilarityTranslator.batchProcessInput(translatorContext, inputList);
// One NDArray per tensor name: input_ids, attention_mask, token_type_ids.
assertEquals(3, outputList.size());
Iterator<NDArray> iterator = outputList.iterator();
while (iterator.hasNext()) {
NDArray ndArray = iterator.next();
long[] output = ndArray.toLongArray();
assertEquals(2, output.length);
}
}

// Verifies batchProcessOutput splits a batched score tensor with leading
// dimension 2 into two Outputs, each holding a single "similarity" ModelTensor.
@Test
public void test_TextSimilarity_Translator_BatchProcessOutput() throws URISyntaxException, IOException {
TextSimilarityTranslator textSimilarityTranslator = new TextSimilarityTranslator();
TranslatorContext translatorContext = mock(TranslatorContext.class);
Model mlModel = mock(Model.class);
when(translatorContext.getModel()).thenReturn(mlModel);
// Real tokenizer.json fixture; prepare() loads the tokenizer from the model path.
when(mlModel.getModelPath()).thenReturn(Paths.get(getClass().getResource("../tokenize/tokenizer.json").toURI()).getParent());
textSimilarityTranslator.prepare(translatorContext);

// Batch tensor mock: shape (2, ...), so batchProcessOutput should emit 2 Outputs.
NDArray batchArray = mock(NDArray.class);
Shape batchShape = mock(Shape.class);
when(batchArray.getShape()).thenReturn(batchShape);
when(batchShape.get(0)).thenReturn(2L);

// Per-item slices returned by batchArray.get(i), each a single FLOAT32 score.
NDArray itemArray1 = mock(NDArray.class);
NDArray itemArray2 = mock(NDArray.class);
Shape itemShape = mock(Shape.class);
when(itemShape.getShape()).thenReturn(new long[] { 1 });
when(itemArray1.getShape()).thenReturn(itemShape);
when(itemArray2.getShape()).thenReturn(itemShape);
when(itemArray1.toArray()).thenReturn(new Number[] { 1.0f });
when(itemArray2.toArray()).thenReturn(new Number[] { 2.0f });
when(itemArray1.getDataType()).thenReturn(DataType.FLOAT32);
when(itemArray2.getDataType()).thenReturn(DataType.FLOAT32);
when(itemArray1.toByteBuffer()).thenReturn(ByteBuffer.allocate(4));
when(itemArray2.toByteBuffer()).thenReturn(ByteBuffer.allocate(4));
when(batchArray.get(0)).thenReturn(itemArray1);
when(batchArray.get(1)).thenReturn(itemArray2);

NDList ndList = new NDList(batchArray);
List<Output> outputs = textSimilarityTranslator.batchProcessOutput(translatorContext, ndList);
assertEquals(2, outputs.size());
// Each Output round-trips through ModelTensors bytes and carries one named tensor.
for (Output output : outputs) {
byte[] bytes = output.getData().getAsBytes();
ModelTensors tensorOutput = ModelTensors.fromBytes(bytes);
List<ModelTensor> modelTensorsList = tensorOutput.getMlModelTensors();
assertEquals(1, modelTensorsList.size());
ModelTensor modelTensor = modelTensorsList.get(0);
assertEquals("similarity", modelTensor.getName());
assertEquals(1, modelTensor.getData().length);
}
}

@Test
public void initModel_predict_TorchScript_CrossEncoder() throws URISyntaxException {
textSimilarityCrossEncoderModel.initModel(model, params, encryptor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ public Collection<Object> createComponents(

encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler);

mlEngine = new MLEngine(dataPath, encryptor);
mlEngine = new MLEngine(dataPath, encryptor, clusterService);
nodeHelper = new DiscoveryNodeHelper(clusterService, settings);
modelCacheHelper = new MLModelCacheHelper(clusterService, settings);
cmHandler = new OpenSearchConversationalMemoryHandler(client, clusterService);
Expand Down Expand Up @@ -1364,7 +1364,8 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED,
MLCommonsSettings.REMOTE_METADATA_GLOBAL_TENANT_ID,
MLCommonsSettings.REMOTE_METADATA_GLOBAL_RESOURCE_CACHE_TTL,
MLCommonsSettings.ML_COMMONS_STREAM_ENABLED
MLCommonsSettings.ML_COMMONS_STREAM_ENABLED,
MLCommonsSettings.ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE
);
return settings;
}
Expand Down
Loading