
Conversation

@urmichm

@urmichm urmichm commented Dec 6, 2025

Description

Text similarity now processes predictions in batches.

Related Issues

Resolves #4276

Check List

  • New functionality includes testing.
  • New functionality has been documented.
  • API changes companion pull request created.
  • Commits are signed per the DCO using --signoff.
  • Public documentation issue/PR created.

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on the Developer Certificate of Origin and signing off your commits, please check here.

Summary by CodeRabbit

  • New Features

    • Implemented batch processing for text similarity predictions, improving performance for large-scale similarity computations.
    • Added configurable ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE setting to control batch sizes (default: 500, range: 1–1000).
  • Tests

    • Added batch processing validation tests for text similarity operations.


Text similarity is now computed in batches. The translator is updated to handle the batch process input/output.

Signed-off-by: Michael <urmich.m@gmail.com>
- Added TextSimilarityTranslator unit tests for batch process input and output.
- Applied `spotlessApply`.

Signed-off-by: Michael <urmich.m@gmail.com>
- Marked the old MLEngine constructor as deprecated.
- Removed an obsolete comment in DLModel.

Signed-off-by: Michael <urmich.m@gmail.com>
@coderabbitai

coderabbitai bot commented Dec 6, 2025

Walkthrough

This PR implements batch processing for text similarity predictions to improve efficiency. It adds a configurable batch size setting, integrates ClusterService into MLEngine for cluster setting access, and refactors TextSimilarityCrossEncoderModel to process predictions in batches instead of one-by-one.

Changes

  • Configuration: common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java
    Added the ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE setting (default 500, range 1–1000, NodeScope/Dynamic); a sketch of this kind of registration follows this list.
  • Infrastructure & Dependency Injection: ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java, plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java
    Integrated ClusterService into MLEngine; added a new constructor MLEngine(Path, Encryptor, ClusterService) alongside a backward-compatible deprecated constructor; updated the plugin to instantiate MLEngine with ClusterService; added ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE to the settings registry.
  • Model Layer & Batch Processing: ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java, ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java, ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java
    Added a getClusterSettings() accessor in DLModel; refactored TextSimilarityCrossEncoderModel to use batchPredict() instead of per-document prediction loops; introduced batchProcessInput() and batchProcessOutput() in TextSimilarityTranslator for multi-document batch processing.
  • Tests: ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java
    Added test_TextSimilarity_Translator_BatchProcessInput() and test_TextSimilarity_Translator_BatchProcessOutput() to verify batch tokenization and output conversion.
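
For reference, registering an integer setting with these bounds uses the standard OpenSearch Setting API. The sketch below is illustrative: the key string is an assumption, since only the constant name, default, and range appear in this PR.

import org.opensearch.common.settings.Setting;

// Illustrative sketch; the key string is an assumption (only the constant
// name, default 500, and range 1-1000 come from this PR).
public static final Setting<Integer> ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE = Setting
    .intSetting(
        "plugins.ml_commons.text_similarity_batch_size",
        500,                        // default
        1,                          // minimum
        1000,                       // maximum
        Setting.Property.NodeScope,
        Setting.Property.Dynamic    // adjustable at runtime via a cluster settings update
    );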

Sequence Diagram

sequenceDiagram
    participant Client
    participant TextSimilarityCrossEncoderModel
    participant DLModel
    participant MLEngine
    participant ClusterService
    participant TextSimilarityTranslator
    participant Predictor

    Client->>TextSimilarityCrossEncoderModel: predict(modelId, mlInput)
    TextSimilarityCrossEncoderModel->>DLModel: getClusterSettings()
    DLModel->>MLEngine: getClusterService()
    MLEngine-->>DLModel: clusterService
    DLModel->>ClusterService: getSettings()
    ClusterService-->>DLModel: settings
    DLModel-->>TextSimilarityCrossEncoderModel: clusterSettings

    Note over TextSimilarityCrossEncoderModel: Extract batchSize<br/>from settings

    loop For each batch of textDocs
        TextSimilarityCrossEncoderModel->>TextSimilarityCrossEncoderModel: Build batch inputs
        TextSimilarityCrossEncoderModel->>TextSimilarityTranslator: batchProcessInput(batchInputs)
        TextSimilarityTranslator-->>TextSimilarityTranslator: Tokenize & encode batch
        TextSimilarityTranslator-->>TextSimilarityCrossEncoderModel: NDList (batched tokens)
        
        TextSimilarityCrossEncoderModel->>Predictor: batchPredict(ndList)
        Predictor-->>TextSimilarityCrossEncoderModel: NDArray (batch output)
        
        TextSimilarityCrossEncoderModel->>TextSimilarityTranslator: batchProcessOutput(ndArray)
        TextSimilarityTranslator-->>TextSimilarityTranslator: Convert to ModelTensors
        TextSimilarityTranslator-->>TextSimilarityCrossEncoderModel: ModelTensors
        
        TextSimilarityCrossEncoderModel->>TextSimilarityCrossEncoderModel: Accumulate outputs
    end

    TextSimilarityCrossEncoderModel-->>Client: ModelTensorOutput

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20–30 minutes

  • Areas requiring extra attention:
    • Batch processing logic in TextSimilarityCrossEncoderModel to ensure correct batch assembly, slicing, and output aggregation
    • ClusterService null-safety in DLModel.getClusterSettings() and batch size fallback behavior
    • Backward compatibility and proper delegation in deprecated MLEngine(Path, Encryptor) constructor
    • Test coverage for edge cases: empty batches, batch size boundary conditions, null settings scenarios

Poem

🐰 One by one we'd hop and bound,
But batches make us spin around!
Grouped together, fast and fleet,
Similarity predictions: sweet! 💨✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 13.33%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
  • Title check (❓ Inconclusive): The title 'Text Similarity' is vague and generic, lacking specifics about the actual change (batch processing). A more descriptive title such as 'Implement batch processing for Text Similarity predictions' would convey the main change clearly.
✅ Passed checks (3 passed)
  • Description check (✅ Passed): The PR description covers the main purpose and references the related issue, but the checklist lacks documentation and API specification updates.
  • Linked Issues check (✅ Passed): The PR implements batch processing for TextSimilarityCrossEncoderModel by adding a configurable batch size setting, introducing batchProcessInput/batchProcessOutput methods, and refactoring predictions to use batchPredict instead of per-item predict loops.
  • Out of Scope Changes check (✅ Passed): All changes relate directly to batch processing for text similarity: the new batch size setting, ClusterService integration for settings access, batch methods in the translator, and updated tests for batch functionality.


@urmichm
Author

urmichm commented Dec 6, 2025

@coderabbitai review

@coderabbitai

coderabbitai bot commented Dec 6, 2025

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@urmichm urmichm requested a deployment to ml-commons-cicd-env-require-approval December 6, 2025 19:27 — with GitHub Actions Waiting

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java (1)

79-87: Text similarity batch size setting looks consistent and well-scoped

ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE (default 500, range 1–1000, NodeScope/Dynamic) fits the intended use for controlling batch size and matches patterns of nearby ML Commons settings. Just ensure it’s documented in the public settings docs.

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java (1)

32-36: Batch translator tests cover the new input/output paths

The new test_TextSimilarity_Translator_BatchProcessInput and _BatchProcessOutput mirror the existing single-item tests and validate the NDList shape and serialized ModelTensor outputs for batched cases. This gives good coverage of the new translator methods.

Minor follow-up (non-blocking): setUp() still uses the deprecated MLEngine(Path, Encryptor) ctor; consider switching tests to the 3-arg ctor with a dummy/mocked ClusterService when convenient to keep tests aligned with the preferred API.

Also applies to: 178-250
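
A minimal sketch of that switch, assuming Mockito is available (variable names here are illustrative, not the actual test code):

ClusterService clusterService = mock(ClusterService.class);
when(clusterService.getSettings()).thenReturn(Settings.EMPTY);
// Preferred 3-arg constructor instead of the deprecated MLEngine(Path, Encryptor)
mlEngine = new MLEngine(mlCachePath, encryptor, clusterService);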

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR, between commits 6d33e91 and 216990b.

📒 Files selected for processing (7)
  • common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java (1 hunks)
  • ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java (2 hunks)
  • ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java (2 hunks)
  • ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java (2 hunks)
  • ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java (2 hunks)
  • ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java (2 hunks)
  • plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java (1)
common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java (1)
  • MLCommonsSettings (24-495)
ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java (1)
common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java (1)
  • MLCommonsSettings (24-495)
🔇 Additional comments (4)
ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java (1)

138-167: batchProcessOutput implementation aligns with per-item output handling

The batched output handling correctly slices the leading batch dimension, wraps each item as a single ModelTensor named "similarity", and serializes via ModelTensors. This matches the single-input processOutput contract and looks good.
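
In outline, the slicing pattern being described works like this (a sketch with assumed variable names, not the exact PR code):

NDArray batchOutput = list.singletonOrThrow();  // logits with a leading batch dimension
long batchSize = batchOutput.getShape().get(0);
for (long i = 0; i < batchSize; i++) {
    NDArray item = batchOutput.get(i);          // slice one document's score off the batch
    // wrap `item` as a ModelTensor named "similarity", then serialize via ModelTensors
}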

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java (1)

685-685: ClusterService wiring into MLEngine and settings exposure look correct

Using new MLEngine(dataPath, encryptor, clusterService) and adding ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE to getSettings() cleanly hook the new batch-size configuration into the plugin lifecycle; this should make DLModel.getClusterSettings() usable from algorithms without extra plumbing.

Also applies to: 1367-1368

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java (1)

55-73: MLEngine’s new ClusterService-aware constructor maintains BWC

Storing ClusterService via the new (Path, Encryptor, ClusterService) ctor and having the deprecated 2-arg ctor delegate with null cleanly supports both existing call sites and new cluster-aware consumers (e.g., DLModel.getClusterSettings()), without changing existing behavior.
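
The delegation being described follows a conventional shape, roughly (a sketch; unrelated field initialization is elided):

public MLEngine(Path opensearchDataFolder, Encryptor encryptor, ClusterService clusterService) {
    // ... existing data-path and encryptor initialization ...
    this.clusterService = clusterService;        // may be null for legacy callers
}

@Deprecated
public MLEngine(Path opensearchDataFolder, Encryptor encryptor) {
    this(opensearchDataFolder, encryptor, null); // BWC: old call sites keep working
}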

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java (1)

23-24: Batched predict implementation correctly chunks inputs by configurable batch size

The new predict implementation cleanly:

  • Reads ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE from cluster settings (with a sane default via Settings.EMPTY),
  • Chunks textDocs by that batch size,
  • Builds paired Input objects per (queryText, doc) and calls batchPredict,
  • Accumulates ModelTensors into the final ModelTensorOutput.

The batching logic and loop boundaries look correct and should preserve previous semantics while improving efficiency.

Also applies to: 31-32, 50-75
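
Schematically, that flow reduces to the following (a sketch under assumed names; helpers like parseModelTensorOutput and the accessor signatures are assumptions, so consult the diff for the real code):

int batchSize = MLCommonsSettings.ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE.get(getClusterSettings());
List<ModelTensors> tensorOutputs = new ArrayList<>();
for (int start = 0; start < textDocs.size(); start += batchSize) {
    int end = Math.min(start + batchSize, textDocs.size());
    List<Input> batchInputs = new ArrayList<>();
    for (String doc : textDocs.subList(start, end)) {
        Input input = new Input();
        input.add(queryText);   // each Input pairs the query with one document
        input.add(doc);
        batchInputs.add(input);
    }
    // DJL's Predictor.batchPredict returns one Output per Input
    for (Output output : getPredictor().batchPredict(batchInputs)) {
        tensorOutputs.add(parseModelTensorOutput(output, null));
    }
}
return new ModelTensorOutput(tensorOutputs);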

Comment on lines +346 to +351
public Settings getClusterSettings() {
if (mlEngine.getClusterService() != null) {
return mlEngine.getClusterService().getSettings();
}
return Settings.EMPTY;
}


⚠️ Potential issue | 🟡 Minor

Guard against potential null mlEngine in getClusterSettings()

As written, calling getClusterSettings() before initModel assigns mlEngine will throw an NPE on mlEngine.getClusterService(). Suggest a small defensive check:

-    public Settings getClusterSettings() {
-        if (mlEngine.getClusterService() != null) {
-            return mlEngine.getClusterService().getSettings();
-        }
-        return Settings.EMPTY;
-    }
+    public Settings getClusterSettings() {
+        if (mlEngine != null && mlEngine.getClusterService() != null) {
+            return mlEngine.getClusterService().getSettings();
+        }
+        return Settings.EMPTY;
+    }


🤖 Prompt for AI Agents
In ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java
around lines 346 to 351, the method getClusterSettings() can NPE if mlEngine is
null; add a defensive null check for mlEngine and for
mlEngine.getClusterService() before calling getSettings(), and return
Settings.EMPTY when either is null so the method never dereferences a null
pointer.

Comment on lines +99 to +136
@Override
public NDList batchProcessInput(TranslatorContext ctx, List<Input> inputs) {
NDManager manager = ctx.getNDManager();
int batchSize = inputs.size();
List<String> sentences = new ArrayList<>(batchSize);
List<String> contexts = new ArrayList<>(batchSize);
for (Input input : inputs) {
String sentence = input.getAsString(0);
String context = input.getAsString(1);
sentences.add(sentence);
contexts.add(context);
}
// Tokenize in batches
Encoding[] encodings = tokenizer.batchEncode(new PairList<>(sentences, contexts));
int seqLen = encodings[0].getIds().length;
for (Encoding enc : encodings) {
seqLen = Math.max(seqLen, enc.getIds().length);
}
long[][] inputIds = new long[batchSize][seqLen];
long[][] attentionMasks = new long[batchSize][seqLen];
long[][] tokenTypeIds = new long[batchSize][seqLen];
for (int i = 0; i < batchSize; i++) {
inputIds[i] = encodings[i].getIds();
attentionMasks[i] = encodings[i].getAttentionMask();
tokenTypeIds[i] = encodings[i].getTypeIds();
}
NDArray inputIdsArray = manager.create(inputIds);
inputIdsArray.setName("input_ids");
NDArray attentionMaskArray = manager.create(attentionMasks);
attentionMaskArray.setName("attention_mask");
NDArray tokenTypeArray = manager.create(tokenTypeIds);
tokenTypeArray.setName("token_type_ids");
NDList ndList = new NDList();
ndList.add(inputIdsArray);
ndList.add(attentionMaskArray);
ndList.add(tokenTypeArray);
return ndList;
}


⚠️ Potential issue | 🟠 Major

Fix 2D array construction in batchProcessInput

You compute seqLen and allocate [batchSize][seqLen], but then overwrite each row with encodings[i].getIds() / masks / types. That makes the initial allocation and seqLen effectively unused and can produce ragged arrays.

Recommend keeping the fixed row length and copying values instead:

-        int seqLen = encodings[0].getIds().length;
-        for (Encoding enc : encodings) {
-            seqLen = Math.max(seqLen, enc.getIds().length);
-        }
-        long[][] inputIds = new long[batchSize][seqLen];
-        long[][] attentionMasks = new long[batchSize][seqLen];
-        long[][] tokenTypeIds = new long[batchSize][seqLen];
-        for (int i = 0; i < batchSize; i++) {
-            inputIds[i] = encodings[i].getIds();
-            attentionMasks[i] = encodings[i].getAttentionMask();
-            tokenTypeIds[i] = encodings[i].getTypeIds();
-        }
+        int seqLen = encodings[0].getIds().length;
+        for (Encoding enc : encodings) {
+            seqLen = Math.max(seqLen, enc.getIds().length);
+        }
+
+        long[][] inputIds = new long[batchSize][seqLen];
+        long[][] attentionMasks = new long[batchSize][seqLen];
+        long[][] tokenTypeIds = new long[batchSize][seqLen];
+        for (int i = 0; i < batchSize; i++) {
+            long[] ids = encodings[i].getIds();
+            long[] masks = encodings[i].getAttentionMask();
+            long[] types = encodings[i].getTypeIds();
+            System.arraycopy(ids, 0, inputIds[i], 0, ids.length);
+            System.arraycopy(masks, 0, attentionMasks[i], 0, masks.length);
+            System.arraycopy(types, 0, tokenTypeIds[i], 0, types.length);
+        }
🤖 Prompt for AI Agents
In
ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java
around lines 99 to 136, the code allocates fixed-size 2D arrays
[batchSize][seqLen] then replaces entire rows with encodings[i].getIds() etc.,
which discards the fixed length and creates ragged arrays; instead, copy each
encoding into the preallocated row and pad remaining elements to the fixed
seqLen (e.g., with 0 for input_ids and token_type_ids, and 0 for attention_mask)
so every row has length seqLen, or throw if an encoding exceeds seqLen;
implement a loop that for each i copies enc.getIds(), enc.getAttentionMask(),
enc.getTypeIds() into inputIds[i][0..len-1] etc., leaving the rest as zeros
before creating NDArrays.

@dhrubo-os
Copy link
Collaborator

This PR is in Draft. Is this PR ready for review?

return tensorOutput;
}

public Settings getClusterSettings() {
Collaborator


We usually add all our settings here. Can you please check how other settings are propagated and follow accordingly?

Author


Yes, I will check it. Thank you

Author


I have a question: would you like batch processing to be enabled or disabled via a boolean feature flag in MLFeatureEnabledSetting?
The batch size itself, however, is passed from MLCommonsSettings.java.
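
For context, a boolean gate of the kind asked about would be just another registered setting; the snippet below is purely illustrative, since no such setting exists in this PR:

// Purely illustrative; no such setting exists in this PR.
public static final Setting<Boolean> ML_COMMONS_TEXT_SIMILARITY_BATCH_ENABLED = Setting
    .boolSetting("plugins.ml_commons.text_similarity_batch_enabled", true,
        Setting.Property.NodeScope, Setting.Property.Dynamic);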



Development

Successfully merging this pull request may close these issues.

[FEATURE] Why not Batch Predict?
