Text Similarity #4500
Conversation
- Text similarity is done in batches. Translator is updated to handle the batch process input/output. Signed-off-by: Michael <urmich.m@gmail.com>
- TextSimilarityTranslator unit tests for batch process input and output; `spotlessApply` applied. Signed-off-by: Michael <urmich.m@gmail.com>
- Old MLEngine constructor marked as deprecated; useless comment on DLModel removed. Signed-off-by: Michael <urmich.m@gmail.com>
Walkthrough

This PR implements batch processing for text similarity predictions to improve efficiency. It adds a configurable batch-size setting, integrates ClusterService into MLEngine for cluster-setting access, and refactors TextSimilarityCrossEncoderModel to process predictions in batches instead of one by one. A condensed sketch of the batching idea follows.
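As a rough illustration of the core idea (a minimal sketch, not the PR's actual code; all names here are illustrative):

```java
import java.util.ArrayList;
import java.util.List;

final class BatchingSketch {
    // Split the input documents into fixed-size chunks so each chunk can be
    // sent through the predictor in a single batched call instead of one
    // forward pass per document.
    static <T> List<List<T>> partition(List<T> docs, int batchSize) {
        List<List<T>> batches = new ArrayList<>();
        for (int start = 0; start < docs.size(); start += batchSize) {
            int end = Math.min(start + batchSize, docs.size());
            batches.add(docs.subList(start, end));
        }
        return batches;
    }
}
```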
Sequence Diagram

```mermaid
sequenceDiagram
participant Client
participant TextSimilarityCrossEncoderModel
participant DLModel
participant MLEngine
participant ClusterService
participant TextSimilarityTranslator
participant Predictor
Client->>TextSimilarityCrossEncoderModel: predict(modelId, mlInput)
TextSimilarityCrossEncoderModel->>DLModel: getClusterSettings()
DLModel->>MLEngine: getClusterService()
MLEngine-->>DLModel: clusterService
DLModel->>ClusterService: getSettings()
ClusterService-->>DLModel: settings
DLModel-->>TextSimilarityCrossEncoderModel: clusterSettings
Note over TextSimilarityCrossEncoderModel: Extract batchSize<br/>from settings
loop For each batch of textDocs
TextSimilarityCrossEncoderModel->>TextSimilarityCrossEncoderModel: Build batch inputs
TextSimilarityCrossEncoderModel->>TextSimilarityTranslator: batchProcessInput(batchInputs)
TextSimilarityTranslator-->>TextSimilarityTranslator: Tokenize & encode batch
TextSimilarityTranslator-->>TextSimilarityCrossEncoderModel: NDList (batched tokens)
TextSimilarityCrossEncoderModel->>Predictor: batchPredict(ndList)
Predictor-->>TextSimilarityCrossEncoderModel: NDArray (batch output)
TextSimilarityCrossEncoderModel->>TextSimilarityTranslator: batchProcessOutput(ndArray)
TextSimilarityTranslator-->>TextSimilarityTranslator: Convert to ModelTensors
TextSimilarityTranslator-->>TextSimilarityCrossEncoderModel: ModelTensors
TextSimilarityCrossEncoderModel->>TextSimilarityCrossEncoderModel: Accumulate outputs
end
TextSimilarityCrossEncoderModel-->>Client: ModelTensorOutput
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20–30 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (3 passed)
@coderabbitai review
✅ Actions performed: Review triggered.
Actionable comments posted: 2
🧹 Nitpick comments (2)
common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java (1)

79-87: Text similarity batch size setting looks consistent and well-scoped

`ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE` (default 500, range 1–1000, NodeScope/Dynamic) fits the intended use for controlling batch size and matches the patterns of nearby ML Commons settings. Just ensure it's documented in the public settings docs. A sketch of such a declaration appears below.
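For reference, a dynamic, node-scoped integer setting with this shape is typically declared like the following (a sketch only; the setting key string is an assumption, not taken from the PR):

```java
import org.opensearch.common.settings.Setting;

// Sketch of the batch-size setting described in the review: default 500,
// allowed range 1-1000, node-scoped and dynamically updatable. The key
// string below is assumed for illustration.
public static final Setting<Integer> ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE = Setting
    .intSetting(
        "plugins.ml_commons.text_similarity_batch_size",
        500,   // default
        1,     // minimum
        1000,  // maximum
        Setting.Property.NodeScope,
        Setting.Property.Dynamic
    );
```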
ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java (1)

32-36: Batch translator tests cover the new input/output paths

The new `test_TextSimilarity_Translator_BatchProcessInput` and `test_TextSimilarity_Translator_BatchProcessOutput` tests mirror the existing single-item tests and validate the NDList shape and serialized `ModelTensor` outputs for batched cases. This gives good coverage of the new translator methods.

Minor follow-up (non-blocking): `setUp()` still uses the deprecated `MLEngine(Path, Encryptor)` ctor; consider switching tests to the 3-arg ctor with a dummy/mocked `ClusterService` when convenient, to keep tests aligned with the preferred API (see the sketch below).

Also applies to: 178-250
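A minimal sketch of that suggested setup, to go inside `setUp()` (field names such as `mlCachePath` and `encryptor` are assumptions about the existing test, not copied from it):

```java
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;

// Prefer the 3-arg constructor with a mocked ClusterService over the
// deprecated MLEngine(Path, Encryptor) constructor.
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.getSettings()).thenReturn(Settings.EMPTY);
mlEngine = new MLEngine(mlCachePath, encryptor, clusterService);
```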
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java (1 hunks)
- ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java (2 hunks)
- ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java (2 hunks)
- ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java (2 hunks)
- ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java (2 hunks)
- ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModelTest.java (2 hunks)
- plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java (1)
common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java (1)
`MLCommonsSettings` (24-495)
ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java (1)
common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java (1)
`MLCommonsSettings` (24-495)
🔇 Additional comments (4)
ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java (1)

138-167: `batchProcessOutput` implementation aligns with per-item output handling

The batched output handling correctly slices the leading batch dimension, wraps each item as a single `ModelTensor` named `"similarity"`, and serializes via `ModelTensors`. This matches the single-input `processOutput` contract and looks good; a condensed sketch of the flow follows.
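For orientation, a condensed sketch of that batched output path (illustrative only; `toSimilarityTensors` is a hypothetical stand-in for the per-item `ModelTensor` construction in the real translator):

```java
// Sketch of batched output handling; not the PR's exact code.
@Override
public List<Output> batchProcessOutput(TranslatorContext ctx, NDList list) {
    NDArray batchLogits = list.get(0);
    long batch = batchLogits.getShape().get(0);
    List<Output> outputs = new ArrayList<>((int) batch);
    for (long i = 0; i < batch; i++) {
        NDArray row = batchLogits.get(i);                 // slice the leading batch dimension
        ModelTensors tensors = toSimilarityTensors(row);  // hypothetical: wraps row as a tensor named "similarity"
        Output output = new Output(200, "OK");
        output.add(tensors.toBytes());                    // serialize, matching single-input processOutput
        outputs.add(output);
    }
    return outputs;
}
```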
plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java (1)

685-685: ClusterService wiring into `MLEngine` and settings exposure look correct

Using `new MLEngine(dataPath, encryptor, clusterService)` and adding `ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE` to `getSettings()` cleanly hooks the new batch-size configuration into the plugin lifecycle; this should make `DLModel.getClusterSettings()` usable from algorithms without extra plumbing.

Also applies to: 1367-1368
ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java (1)

55-73: MLEngine's new ClusterService-aware constructor maintains BWC

Storing `ClusterService` via the new `(Path, Encryptor, ClusterService)` ctor and having the deprecated 2-arg ctor delegate with `null` cleanly supports both existing call sites and new cluster-aware consumers (e.g., `DLModel.getClusterSettings()`) without changing existing behavior. Condensed, the pattern looks like the sketch below.
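A condensed sketch of that constructor pair (the real constructor also performs path and encryptor setup omitted here; parameter names are assumptions):

```java
// Deprecated 2-arg constructor delegates to the ClusterService-aware one,
// preserving behavior for existing call sites.
@Deprecated
public MLEngine(Path opensearchDataFolder, Encryptor encryptor) {
    this(opensearchDataFolder, encryptor, null);
}

public MLEngine(Path opensearchDataFolder, Encryptor encryptor, ClusterService clusterService) {
    this.encryptor = encryptor;
    this.clusterService = clusterService;  // may be null for legacy callers
    // ... existing path initialization ...
}
```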
ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityCrossEncoderModel.java (1)

23-24: Batched predict implementation correctly chunks inputs by configurable batch size

The new `predict` implementation cleanly:

- Reads `ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE` from cluster settings (with a sane default via `Settings.EMPTY`),
- Chunks `textDocs` by that batch size,
- Builds paired `Input` objects per `(queryText, doc)` and calls `batchPredict`,
- Accumulates `ModelTensors` into the final `ModelTensorOutput`.

The batching logic and loop boundaries look correct and should preserve previous semantics while improving efficiency. A condensed sketch of this flow appears after this comment.

Also applies to: 31-32, 50-75
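A condensed sketch of the described flow (identifiers such as `getPredictor` and `parseModelTensors` are illustrative assumptions, not the PR's code):

```java
// Condensed sketch of the batched predict loop described above.
int batchSize = ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE.get(getClusterSettings());
List<ModelTensors> tensorOutputs = new ArrayList<>();
for (int start = 0; start < textDocs.size(); start += batchSize) {
    int end = Math.min(start + batchSize, textDocs.size());
    List<Input> batchInputs = new ArrayList<>();
    for (String doc : textDocs.subList(start, end)) {
        Input input = new Input();  // DJL Input holding the (query, doc) pair
        input.add(queryText);
        input.add(doc);
        batchInputs.add(input);
    }
    // One batched forward pass instead of one call per document.
    for (Output output : getPredictor().batchPredict(batchInputs)) {
        tensorOutputs.add(parseModelTensors(output));  // hypothetical deserialization helper
    }
}
return new ModelTensorOutput(tensorOutputs);
```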
```java
public Settings getClusterSettings() {
    if (mlEngine.getClusterService() != null) {
        return mlEngine.getClusterService().getSettings();
    }
    return Settings.EMPTY;
}
```
Guard against potential null mlEngine in getClusterSettings()

As written, calling `getClusterSettings()` before `initModel` assigns `mlEngine` will throw an NPE on `mlEngine.getClusterService()`. Suggest a small defensive check:

```diff
-    public Settings getClusterSettings() {
-        if (mlEngine.getClusterService() != null) {
-            return mlEngine.getClusterService().getSettings();
-        }
-        return Settings.EMPTY;
-    }
+    public Settings getClusterSettings() {
+        if (mlEngine != null && mlEngine.getClusterService() != null) {
+            return mlEngine.getClusterService().getSettings();
+        }
+        return Settings.EMPTY;
+    }
```
🤖 Prompt for AI Agents
In ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java
around lines 346 to 351, the method getClusterSettings() can NPE if mlEngine is
null; add a defensive null check for mlEngine and for
mlEngine.getClusterService() before calling getSettings(), and return
Settings.EMPTY when either is null so the method never dereferences a null
pointer.
```java
@Override
public NDList batchProcessInput(TranslatorContext ctx, List<Input> inputs) {
    NDManager manager = ctx.getNDManager();
    int batchSize = inputs.size();
    List<String> sentences = new ArrayList<>(batchSize);
    List<String> contexts = new ArrayList<>(batchSize);
    for (Input input : inputs) {
        String sentence = input.getAsString(0);
        String context = input.getAsString(1);
        sentences.add(sentence);
        contexts.add(context);
    }
    // Tokenize in batches
    Encoding[] encodings = tokenizer.batchEncode(new PairList<>(sentences, contexts));
    int seqLen = encodings[0].getIds().length;
    for (Encoding enc : encodings) {
        seqLen = Math.max(seqLen, enc.getIds().length);
    }
    long[][] inputIds = new long[batchSize][seqLen];
    long[][] attentionMasks = new long[batchSize][seqLen];
    long[][] tokenTypeIds = new long[batchSize][seqLen];
    for (int i = 0; i < batchSize; i++) {
        inputIds[i] = encodings[i].getIds();
        attentionMasks[i] = encodings[i].getAttentionMask();
        tokenTypeIds[i] = encodings[i].getTypeIds();
    }
    NDArray inputIdsArray = manager.create(inputIds);
    inputIdsArray.setName("input_ids");
    NDArray attentionMaskArray = manager.create(attentionMasks);
    attentionMaskArray.setName("attention_mask");
    NDArray tokenTypeArray = manager.create(tokenTypeIds);
    tokenTypeArray.setName("token_type_ids");
    NDList ndList = new NDList();
    ndList.add(inputIdsArray);
    ndList.add(attentionMaskArray);
    ndList.add(tokenTypeArray);
    return ndList;
}
```
Fix 2D array construction in batchProcessInput

You compute `seqLen` and allocate `[batchSize][seqLen]`, but then overwrite each row with `encodings[i].getIds()` / masks / types. That makes the initial allocation and `seqLen` effectively unused and can produce ragged arrays.

Recommend keeping the fixed row length and copying values instead:

```diff
-        int seqLen = encodings[0].getIds().length;
-        for (Encoding enc : encodings) {
-            seqLen = Math.max(seqLen, enc.getIds().length);
-        }
-        long[][] inputIds = new long[batchSize][seqLen];
-        long[][] attentionMasks = new long[batchSize][seqLen];
-        long[][] tokenTypeIds = new long[batchSize][seqLen];
-        for (int i = 0; i < batchSize; i++) {
-            inputIds[i] = encodings[i].getIds();
-            attentionMasks[i] = encodings[i].getAttentionMask();
-            tokenTypeIds[i] = encodings[i].getTypeIds();
-        }
+        int seqLen = encodings[0].getIds().length;
+        for (Encoding enc : encodings) {
+            seqLen = Math.max(seqLen, enc.getIds().length);
+        }
+
+        long[][] inputIds = new long[batchSize][seqLen];
+        long[][] attentionMasks = new long[batchSize][seqLen];
+        long[][] tokenTypeIds = new long[batchSize][seqLen];
+        for (int i = 0; i < batchSize; i++) {
+            long[] ids = encodings[i].getIds();
+            long[] masks = encodings[i].getAttentionMask();
+            long[] types = encodings[i].getTypeIds();
+            System.arraycopy(ids, 0, inputIds[i], 0, ids.length);
+            System.arraycopy(masks, 0, attentionMasks[i], 0, masks.length);
+            System.arraycopy(types, 0, tokenTypeIds[i], 0, types.length);
+        }
```
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```java
@Override
public NDList batchProcessInput(TranslatorContext ctx, List<Input> inputs) {
    NDManager manager = ctx.getNDManager();
    int batchSize = inputs.size();
    List<String> sentences = new ArrayList<>(batchSize);
    List<String> contexts = new ArrayList<>(batchSize);
    for (Input input : inputs) {
        String sentence = input.getAsString(0);
        String context = input.getAsString(1);
        sentences.add(sentence);
        contexts.add(context);
    }
    // Tokenize in batches
    Encoding[] encodings = tokenizer.batchEncode(new PairList<>(sentences, contexts));
    int seqLen = encodings[0].getIds().length;
    for (Encoding enc : encodings) {
        seqLen = Math.max(seqLen, enc.getIds().length);
    }
    long[][] inputIds = new long[batchSize][seqLen];
    long[][] attentionMasks = new long[batchSize][seqLen];
    long[][] tokenTypeIds = new long[batchSize][seqLen];
    for (int i = 0; i < batchSize; i++) {
        long[] ids = encodings[i].getIds();
        long[] masks = encodings[i].getAttentionMask();
        long[] types = encodings[i].getTypeIds();
        System.arraycopy(ids, 0, inputIds[i], 0, ids.length);
        System.arraycopy(masks, 0, attentionMasks[i], 0, masks.length);
        System.arraycopy(types, 0, tokenTypeIds[i], 0, types.length);
    }
    NDArray inputIdsArray = manager.create(inputIds);
    inputIdsArray.setName("input_ids");
    NDArray attentionMaskArray = manager.create(attentionMasks);
    attentionMaskArray.setName("attention_mask");
    NDArray tokenTypeArray = manager.create(tokenTypeIds);
    tokenTypeArray.setName("token_type_ids");
    NDList ndList = new NDList();
    ndList.add(inputIdsArray);
    ndList.add(attentionMaskArray);
    ndList.add(tokenTypeArray);
    return ndList;
}
```
🤖 Prompt for AI Agents
In
ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_similarity/TextSimilarityTranslator.java
around lines 99 to 136, the code allocates fixed-size 2D arrays
[batchSize][seqLen] then replaces entire rows with encodings[i].getIds() etc.,
which discards the fixed length and creates ragged arrays; instead, copy each
encoding into the preallocated row and pad remaining elements to the fixed
seqLen (e.g., with 0 for input_ids and token_type_ids, and 0 for attention_mask)
so every row has length seqLen, or throw if an encoding exceeds seqLen;
implement a loop that for each i copies enc.getIds(), enc.getAttentionMask(),
enc.getTypeIds() into inputIds[i][0..len-1] etc., leaving the rest as zeros
before creating NDArrays.
This PR is in Draft. Is this PR ready for review?
```java
    return tensorOutput;
}

public Settings getClusterSettings() {
```
ml-commons/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java, line 34 in 6d33e91:

```java
public class MLFeatureEnabledSetting {
```

We usually add all our settings here. Can you please check how other settings are propagated and follow accordingly?
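For context, the propagation pattern in that class generally caches the value and registers a settings-update consumer; a sketch under that assumption (field and parameter names are illustrative, not the actual code):

```java
// Sketch of the usual setting-propagation pattern: read the initial value,
// then register a consumer so dynamic updates are picked up at runtime.
private volatile int textSimilarityBatchSize;

public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
    textSimilarityBatchSize = ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE.get(settings);
    clusterService
        .getClusterSettings()
        .addSettingsUpdateConsumer(ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE, it -> textSimilarityBatchSize = it);
}
```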
Yes, I will check it. Thank you
I have a question: would you like batch processing to be enabled or disabled with a boolean feature flag from MLFeatureEnabledSetting? The batch size itself, however, is passed from MLCommonsSettings.java.
Description
Text Similarity processes in batches
Related Issues
Resolves #4276
Check List
- Commits are signed per the DCO using --signoff.

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check here.
Summary by CodeRabbit
New Features

- New `ML_COMMONS_TEXT_SIMILARITY_BATCH_SIZE` setting to control batch sizes (default: 500, range: 1–1000).

Tests

- Unit tests covering the translator's batch process input and output paths.