From a3ef894c4ff132e85d18ffbd46a327f34518d35d Mon Sep 17 00:00:00 2001 From: Mikhail Urmich <32458509+urmichm@users.noreply.github.com> Date: Tue, 25 Nov 2025 19:42:55 +0100 Subject: [PATCH 1/3] Text similarity, batchPredict Text similarity is done in batches. Translator is updated to handle the batch process input/output Signed-off-by: Michael --- .../ml/common/settings/MLCommonsSettings.java | 9 +++ .../org/opensearch/ml/engine/MLEngine.java | 12 +++- .../ml/engine/algorithms/DLModel.java | 8 +++ .../TextSimilarityCrossEncoderModel.java | 33 ++++++--- .../TextSimilarityTranslator.java | 67 +++++++++++++++++++ .../ml/plugin/MachineLearningPlugin.java | 5 +- 6 files changed, 123 insertions(+), 11 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java b/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java index c139ea4b68..ee8844e5d5 100644 --- a/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java +++ b/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java @@ -76,6 +76,15 @@ private MLCommonsSettings() {} Setting.Property.NodeScope, Setting.Property.Dynamic ); + public static final Setting ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE = Setting + .intSetting( + ML_PLUGIN_SETTING_PREFIX + "text_similarity_batch_size", + 500, + 1, + 1000, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); public static final Setting ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE = Setting .intSetting( ML_PLUGIN_SETTING_PREFIX + "max_deploy_model_tasks_per_node", diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java index 05f97475de..a0ebec2456 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java @@ -12,6 +12,7 @@ import java.util.Locale; import java.util.Map; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; @@ -51,11 +52,20 @@ public class MLEngine { private Encryptor encryptor; - public MLEngine(Path opensearchDataFolder, Encryptor encryptor) { + @Getter + private ClusterService clusterService; + + public MLEngine(Path opensearchDataFolder, Encryptor encryptor, ClusterService clusterService) { this.mlCachePath = opensearchDataFolder.resolve("ml_cache"); this.mlModelsCachePath = mlCachePath.resolve("models_cache"); this.mlConfigPath = mlCachePath.resolve("config"); this.encryptor = encryptor; + this.clusterService = clusterService; + } + + // QUESTION: May we remove this constructor? + public MLEngine(Path opensearchDataFolder, Encryptor encryptor) { + this(opensearchDataFolder, encryptor, null); } public String getPrebuiltModelMetaListPath() { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java index 0aa755bd6b..ff8c0d9b4d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.io.FileUtils; +import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.exception.MLException; @@ -342,4 +343,11 @@ public ModelTensors parseModelTensorOutput(Output output, ModelResultFilter resu return tensorOutput; } + public Settings getClusterSettings() { + if (mlEngine.getClusterService() != null) { + // QUESTION: removing the constructor will make cluster settings non-null + return mlEngine.getClusterService().getSettings(); + } + return Settings.EMPTY; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java index d3049c851a..7335109f08 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.List; +import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; @@ -27,6 +28,7 @@ import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.settings.MLCommonsSettings; import org.opensearch.ml.engine.algorithms.DLModel; import org.opensearch.ml.engine.annotation.Function; @@ -43,16 +45,31 @@ public class TextSimilarityCrossEncoderModel extends DLModel { public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException { MLInputDataset inputDataSet = mlInput.getInputDataset(); List tensorOutputs = new ArrayList<>(); - Output output; TextSimilarityInputDataSet textSimInput = (TextSimilarityInputDataSet) inputDataSet; String queryText = textSimInput.getQueryText(); - for (String doc : textSimInput.getTextDocs()) { - Input input = new Input(); - input.add(queryText); - input.add(doc); - output = getPredictor().predict(input); - ModelTensors outputTensors = ModelTensors.fromBytes(output.getData().getAsBytes()); - tensorOutputs.add(outputTensors); + List textDocs = textSimInput.getTextDocs(); + + Settings clusterSettings = getClusterSettings(); + final int batchSize = MLCommonsSettings.ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE.get(clusterSettings); + + for (int i = 0; i < textDocs.size(); i += batchSize) { + int endIndex = Math.min(i + batchSize, textDocs.size()); + List batchDocs = textDocs.subList(i, endIndex); + List batchInputs = new ArrayList<>(batchDocs.size()); + + for (String doc : batchDocs) { + Input input = new Input(); + input.add(queryText); + input.add(doc); + batchInputs.add(input); + } + + List batchOutputs = getPredictor().batchPredict(batchInputs); + + for (Output output: batchOutputs) { + ModelTensors outputTensors = ModelTensors.fromBytes(output.getData().getAsBytes()); + tensorOutputs.add(outputTensors); + } } return new ModelTensorOutput(tensorOutputs); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java index 4967d2035b..3b71a9e763 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java @@ -35,6 +35,7 @@ import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.DataType; import ai.djl.translate.TranslatorContext; +import ai.djl.util.PairList; public class TextSimilarityTranslator extends SentenceTransformerTranslator { public final String SIMILARITY_NAME = "similarity"; @@ -95,4 +96,70 @@ public Output processOutput(TranslatorContext ctx, NDList list) { return output; } + @Override + public NDList batchProcessInput(TranslatorContext ctx, List inputs) { + NDManager manager = ctx.getNDManager(); + int batchSize = inputs.size(); + List sentences = new ArrayList<>(batchSize); + List contexts = new ArrayList<>(batchSize); + for (Input input : inputs) { + String sentence = input.getAsString(0); + String context = input.getAsString(1); + sentences.add(sentence); + contexts.add(context); + } + // Tokenize in batches + Encoding[] encodings = tokenizer.batchEncode(new PairList<>(sentences, contexts)); + int seqLen = encodings[0].getIds().length; + long[][] inputIds = new long[batchSize][seqLen]; + long[][] attentionMasks = new long[batchSize][seqLen]; + long[][] tokenTypeIds = new long[batchSize][seqLen]; + for (int i = 0; i < batchSize; i++) { + inputIds[i] = encodings[i].getIds(); + attentionMasks[i] = encodings[i].getAttentionMask(); + tokenTypeIds[i] = encodings[i].getTypeIds(); + } + NDArray inputIdsArray = manager.create(inputIds); + inputIdsArray.setName("input_ids"); + NDArray attentionMaskArray = manager.create(attentionMasks); + attentionMaskArray.setName("attention_mask"); + NDArray tokenTypeArray = manager.create(tokenTypeIds); + tokenTypeArray.setName("token_type_ids"); + NDList ndList = new NDList(); + ndList.add(inputIdsArray); + ndList.add(attentionMaskArray); + ndList.add(tokenTypeArray); + return ndList; + } + + @Override + public List batchProcessOutput(TranslatorContext ctx, NDList list) { + NDArray batchArray = list.get(0); + int batchSize = (int) batchArray.getShape().get(0); + List outputs = new ArrayList<>(batchSize); + for (int i = 0; i < batchSize; i++) { + NDArray itemArray = batchArray.get(i); + + Number[] itemData = itemArray.toArray(); + long[] itemShape = itemArray.getShape().getShape(); + DataType dataType = itemArray.getDataType(); + MLResultDataType mlResultDataType = MLResultDataType.valueOf(dataType.name()); + ByteBuffer itemBuffer = itemArray.toByteBuffer(); + + ModelTensor tensor = ModelTensor + .builder() + .name(SIMILARITY_NAME) + .data(itemData) + .shape(itemShape) + .dataType(mlResultDataType) + .byteBuffer(itemBuffer) + .build(); + + ModelTensors modelTensorOutput = new ModelTensors(List.of(tensor)); + Output output = new Output(200, "OK"); + output.add(modelTensorOutput.toBytes()); + outputs.add(output); + } + return outputs; + } } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 62de34961e..b84f71e3fa 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -682,7 +682,7 @@ public Collection createComponents( encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); - mlEngine = new MLEngine(dataPath, encryptor); + mlEngine = new MLEngine(dataPath, encryptor, clusterService); nodeHelper = new DiscoveryNodeHelper(clusterService, settings); modelCacheHelper = new MLModelCacheHelper(clusterService, settings); cmHandler = new OpenSearchConversationalMemoryHandler(client, clusterService); @@ -1364,7 +1364,8 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED, MLCommonsSettings.REMOTE_METADATA_GLOBAL_TENANT_ID, MLCommonsSettings.REMOTE_METADATA_GLOBAL_RESOURCE_CACHE_TTL, - MLCommonsSettings.ML_COMMONS_STREAM_ENABLED + MLCommonsSettings.ML_COMMONS_STREAM_ENABLED, + MLCommonsSettings.ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE ); return settings; } From 6f080a75763d1c5ed06f51039e4b8fb8c25fbfdd Mon Sep 17 00:00:00 2001 From: Mikhail Urmich <32458509+urmichm@users.noreply.github.com> Date: Sat, 6 Dec 2025 15:25:26 +0100 Subject: [PATCH 2/3] Unit Tests and spotless apply - TextSimilarityTranslator unit tests for batch process input and output. - `spotlessApply` applied Signed-off-by: Michael --- .../TextSimilarityCrossEncoderModel.java | 2 +- .../TextSimilarityTranslator.java | 5 +- .../TextSimilarityCrossEncoderModelTest.java | 76 +++++++++++++++++++ 3 files changed, 81 insertions(+), 2 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java index 7335109f08..26d02f8844 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java @@ -66,7 +66,7 @@ public ModelTensorOutput predict(String modelId, MLInput mlInput) throws Transla List batchOutputs = getPredictor().batchPredict(batchInputs); - for (Output output: batchOutputs) { + for (Output output : batchOutputs) { ModelTensors outputTensors = ModelTensors.fromBytes(output.getData().getAsBytes()); tensorOutputs.add(outputTensors); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java index 3b71a9e763..1ebabf8324 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java @@ -111,6 +111,9 @@ public NDList batchProcessInput(TranslatorContext ctx, List inputs) { // Tokenize in batches Encoding[] encodings = tokenizer.batchEncode(new PairList<>(sentences, contexts)); int seqLen = encodings[0].getIds().length; + for (Encoding enc : encodings) { + seqLen = Math.max(seqLen, enc.getIds().length); + } long[][] inputIds = new long[batchSize][seqLen]; long[][] attentionMasks = new long[batchSize][seqLen]; long[][] tokenTypeIds = new long[batchSize][seqLen]; @@ -134,7 +137,7 @@ public NDList batchProcessInput(TranslatorContext ctx, List inputs) { @Override public List batchProcessOutput(TranslatorContext ctx, NDList list) { - NDArray batchArray = list.get(0); + NDArray batchArray = list.getFirst(); int batchSize = (int) batchArray.getShape().get(0); List outputs = new ArrayList<>(batchSize); for (int i = 0; i < batchSize; i++) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java index c09f2a42c0..ddd733ecaf 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java @@ -29,8 +29,10 @@ import java.io.File; import java.io.IOException; import java.net.URISyntaxException; +import java.nio.ByteBuffer; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -173,6 +175,80 @@ public void test_TextSimilarity_Translator_ProcessOutput() throws URISyntaxExcep assertEquals(1, data.length); } + @Test + public void test_TextSimilarity_Translator_BatchProcessInput() throws URISyntaxException, IOException { + TextSimilarityTranslator textSimilarityTranslator = new TextSimilarityTranslator(); + TranslatorContext translatorContext = mock(TranslatorContext.class); + Model mlModel = mock(Model.class); + when(translatorContext.getModel()).thenReturn(mlModel); + when(mlModel.getModelPath()).thenReturn(Paths.get(getClass().getResource("../tokenize/tokenizer.json").toURI()).getParent()); + textSimilarityTranslator.prepare(translatorContext); + + NDManager manager = mock(NDManager.class); + when(translatorContext.getNDManager()).thenReturn(manager); + Input input = mock(Input.class); + String testSentence = "hello world"; + when(input.getAsString(0)).thenReturn(testSentence); + when(input.getAsString(1)).thenReturn(testSentence); + NDArray indiceNdArray = mock(NDArray.class); + when(indiceNdArray.toLongArray()).thenReturn(new long[] { 102l, 101l }); + when(manager.create((long[][]) any())).thenReturn(indiceNdArray); + doNothing().when(indiceNdArray).setName(any()); + List inputList = new ArrayList<>(1); + inputList.add(input); + NDList outputList = textSimilarityTranslator.batchProcessInput(translatorContext, inputList); + assertEquals(3, outputList.size()); + Iterator iterator = outputList.iterator(); + while (iterator.hasNext()) { + NDArray ndArray = iterator.next(); + long[] output = ndArray.toLongArray(); + assertEquals(2, output.length); + } + } + + @Test + public void test_TextSimilarity_Translator_BatchProcessOutput() throws URISyntaxException, IOException { + TextSimilarityTranslator textSimilarityTranslator = new TextSimilarityTranslator(); + TranslatorContext translatorContext = mock(TranslatorContext.class); + Model mlModel = mock(Model.class); + when(translatorContext.getModel()).thenReturn(mlModel); + when(mlModel.getModelPath()).thenReturn(Paths.get(getClass().getResource("../tokenize/tokenizer.json").toURI()).getParent()); + textSimilarityTranslator.prepare(translatorContext); + + NDArray batchArray = mock(NDArray.class); + Shape batchShape = mock(Shape.class); + when(batchArray.getShape()).thenReturn(batchShape); + when(batchShape.get(0)).thenReturn(2L); + + NDArray itemArray1 = mock(NDArray.class); + NDArray itemArray2 = mock(NDArray.class); + Shape itemShape = mock(Shape.class); + when(itemShape.getShape()).thenReturn(new long[] { 1 }); + when(itemArray1.getShape()).thenReturn(itemShape); + when(itemArray2.getShape()).thenReturn(itemShape); + when(itemArray1.toArray()).thenReturn(new Number[] { 1.0f }); + when(itemArray2.toArray()).thenReturn(new Number[] { 2.0f }); + when(itemArray1.getDataType()).thenReturn(DataType.FLOAT32); + when(itemArray2.getDataType()).thenReturn(DataType.FLOAT32); + when(itemArray1.toByteBuffer()).thenReturn(ByteBuffer.allocate(4)); + when(itemArray2.toByteBuffer()).thenReturn(ByteBuffer.allocate(4)); + when(batchArray.get(0)).thenReturn(itemArray1); + when(batchArray.get(1)).thenReturn(itemArray2); + + NDList ndList = new NDList(batchArray); + List outputs = textSimilarityTranslator.batchProcessOutput(translatorContext, ndList); + assertEquals(2, outputs.size()); + for (Output output : outputs) { + byte[] bytes = output.getData().getAsBytes(); + ModelTensors tensorOutput = ModelTensors.fromBytes(bytes); + List modelTensorsList = tensorOutput.getMlModelTensors(); + assertEquals(1, modelTensorsList.size()); + ModelTensor modelTensor = modelTensorsList.get(0); + assertEquals("similarity", modelTensor.getName()); + assertEquals(1, modelTensor.getData().length); + } + } + @Test public void initModel_predict_TorchScript_CrossEncoder() throws URISyntaxException { textSimilarityCrossEncoderModel.initModel(model, params, encryptor); From 216990b69f94ea254f3e663364cf90e806787fff Mon Sep 17 00:00:00 2001 From: Mikhail Urmich <32458509+urmichm@users.noreply.github.com> Date: Sat, 6 Dec 2025 19:52:09 +0100 Subject: [PATCH 3/3] Cosmetic Updates - old MLEngine constructor marked as deprecated - useless comment on DLModel removed Signed-off-by: Michael --- .../src/main/java/org/opensearch/ml/engine/MLEngine.java | 6 +++++- .../java/org/opensearch/ml/engine/algorithms/DLModel.java | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java index a0ebec2456..682c3329ed 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java @@ -63,7 +63,11 @@ public MLEngine(Path opensearchDataFolder, Encryptor encryptor, ClusterService c this.clusterService = clusterService; } - // QUESTION: May we remove this constructor? + /** + * @deprecated Retained for backward compatibility. Scheduled for removal.
+ * Use {@link #MLEngine(Path, Encryptor, ClusterService)} instead. + */ + @Deprecated(forRemoval = true) public MLEngine(Path opensearchDataFolder, Encryptor encryptor) { this(opensearchDataFolder, encryptor, null); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java index ff8c0d9b4d..d2e489d03e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java @@ -345,7 +345,6 @@ public ModelTensors parseModelTensorOutput(Output output, ModelResultFilter resu public Settings getClusterSettings() { if (mlEngine.getClusterService() != null) { - // QUESTION: removing the constructor will make cluster settings non-null return mlEngine.getClusterService().getSettings(); } return Settings.EMPTY;