diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java
index a84e3c945c..4aa65190b9 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java
@@ -75,6 +75,8 @@
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.index.IndexNotFoundException;
+import org.opensearch.ml.common.FunctionName;
+import org.opensearch.ml.common.agent.LLMSpec;
 import org.opensearch.ml.common.agent.MLAgent;
 import org.opensearch.ml.common.agent.MLMemorySpec;
 import org.opensearch.ml.common.agent.MLToolSpec;
@@ -82,9 +84,15 @@
 import org.opensearch.ml.common.connector.HttpConnector;
 import org.opensearch.ml.common.connector.McpConnector;
 import org.opensearch.ml.common.connector.McpStreamableHttpConnector;
+import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
+import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;
 import org.opensearch.ml.common.output.model.ModelTensor;
 import org.opensearch.ml.common.output.model.ModelTensorOutput;
+import org.opensearch.ml.common.output.model.ModelTensors;
 import org.opensearch.ml.common.spi.tools.Tool;
+import org.opensearch.ml.common.transport.MLTaskResponse;
+import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
+import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
 import org.opensearch.ml.common.utils.StringUtils;
 import org.opensearch.ml.engine.MLEngineClassLoader;
 import org.opensearch.ml.engine.algorithms.remote.McpConnectorExecutor;
@@ -1178,4 +1186,96 @@ private static Map<String, String> parseStringMapParameter(String rawValue, Stri
             return null;
         }
     }
+
+    public static String extractSummaryFromResponse(MLTaskResponse response, Map<String, String> parameters) {
+        ModelTensorOutput output = (ModelTensorOutput) response.getOutput();
+        if (output == null || output.getMlModelOutputs() == null || output.getMlModelOutputs().isEmpty()) {
+            throw new IllegalStateException("No model output available in response");
+        }
+
+        ModelTensors tensors = output.getMlModelOutputs().getFirst();
+        if (tensors == null || tensors.getMlModelTensors() == null || tensors.getMlModelTensors().isEmpty()) {
+            throw new IllegalStateException("No model tensors available in response");
+        }
+
+        ModelTensor tensor = tensors.getMlModelTensors().getFirst();
+        if (tensor.getResult() != null) {
+            return tensor.getResult().trim();
+        }
+
+        if (tensor.getDataAsMap() == null) {
+            throw new IllegalStateException("No data map available in tensor");
+        }
+
+        Map<String, ?> dataMap = tensor.getDataAsMap();
+        if (dataMap.containsKey(RESPONSE_FIELD)) {
+            return String.valueOf(dataMap.get(RESPONSE_FIELD)).trim();
+        }
+
+        if (dataMap.containsKey("output")) {
+            try {
+                Object outputObj = JsonPath.read(dataMap, parameters.get(LLM_RESPONSE_FILTER));
+                if (outputObj != null) {
+                    return String.valueOf(outputObj).trim();
+                }
+            } catch (PathNotFoundException e) {
+                throw new IllegalStateException("Failed to extract output using filter path", e);
+            }
+        }
+
+        throw new IllegalStateException("No result/response field found. Available fields: " + dataMap.keySet());
+    }
+
+    public static void generateMaxStepSummary(
+        Client client,
+        LLMSpec llmSpec,
+        String promptContent,
+        String systemPrompt,
+        Map<String, String> allParams,
+        String tenantId,
+        ActionListener<String> listener
+    ) {
+        if (promptContent == null || promptContent.isEmpty()) {
+            listener.onFailure(new IllegalArgumentException("Prompt content cannot be null or empty"));
+            return;
+        }
+
+        try {
+            Map<String, String> summaryParams = new HashMap<>();
+            if (llmSpec.getParameters() != null) {
+                summaryParams.putAll(llmSpec.getParameters());
+            }
+            summaryParams.putAll(allParams);
+            summaryParams.put(MLPlanExecuteAndReflectAgentRunner.PROMPT_FIELD, promptContent);
+            summaryParams.put(MLPlanExecuteAndReflectAgentRunner.SYSTEM_PROMPT_FIELD, systemPrompt);
+
+            MLPredictionTaskRequest request = new MLPredictionTaskRequest(
+                llmSpec.getModelId(),
+                RemoteInferenceMLInput
+                    .builder()
+                    .algorithm(FunctionName.REMOTE)
+                    .inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build())
+                    .build(),
+                null,
+                tenantId
+            );
+
+            client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> {
+                try {
+                    String summary = extractSummaryFromResponse(response, summaryParams);
+                    if (summary == null || summary.isEmpty()) {
+                        listener.onFailure(new RuntimeException("Empty LLM summary response"));
+                        return;
+                    }
+                    listener.onResponse(summary);
+                } catch (IllegalStateException e) {
+                    log.error("Failed to extract summary, using fallback", e);
+                    listener.onFailure(e);
+                }
+            }, listener::onFailure));
+        } catch (IllegalStateException e) {
+            log.error("Failed to generate summary", e);
+            listener.onFailure(e);
+        }
+    }
 }
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java
index 38be5cf5f0..8dfa36e469 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java
@@ -19,7 +19,6 @@
 import static org.opensearch.ml.common.utils.ToolUtils.parseResponse;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.INTERACTIONS_PREFIX;
-import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_CHAT_HISTORY_PREFIX;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX;
 import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX;
@@ -67,16 +66,13 @@
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.core.common.Strings;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
-import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.MLMemoryType;
 import org.opensearch.ml.common.agent.LLMSpec;
 import org.opensearch.ml.common.agent.MLAgent;
 import org.opensearch.ml.common.agent.MLToolSpec;
 import org.opensearch.ml.common.contextmanager.ContextManagerContext;
 import org.opensearch.ml.common.conversation.Interaction;
-import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
 import org.opensearch.ml.common.hooks.HookRegistry;
-import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;
 import org.opensearch.ml.common.memory.Memory;
 import org.opensearch.ml.common.memory.Message;
 import org.opensearch.ml.common.output.model.ModelTensor;
@@ -84,8 +80,6 @@
 import org.opensearch.ml.common.output.model.ModelTensors;
 import org.opensearch.ml.common.spi.tools.Tool;
 import org.opensearch.ml.common.transport.MLTaskResponse;
-import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
-import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
 import org.opensearch.ml.common.utils.StringUtils;
 import org.opensearch.ml.engine.agents.AgentContextUtil;
 import org.opensearch.ml.engine.encryptor.Encryptor;
@@ -100,7 +94,6 @@
 import com.google.common.annotations.VisibleForTesting;
 import com.google.gson.reflect.TypeToken;
 
-import com.jayway.jsonpath.JsonPath;
 
 import lombok.Data;
 import lombok.NoArgsConstructor;
@@ -502,7 +495,11 @@ private void runReAct(
             );
             toolParams.put(TENANT_ID_FIELD, tenantId);
             lastToolParams.clear();
-            lastToolParams.putAll(toolParams);
+            toolParams
+                .entrySet()
+                .stream()
+                .filter(e -> e.getValue() != null)
+                .forEach(e -> lastToolParams.put(e.getKey(), e.getValue()));
             runTool(
                 tools,
                 toolSpecMap,
@@ -1128,12 +1125,6 @@ void generateLLMSummary(
         }
 
         try {
-            Map<String, String> summaryParams = new HashMap<>();
-            if (llmSpec.getParameters() != null) {
-                summaryParams.putAll(llmSpec.getParameters());
-            }
-            summaryParams.putAll(parameter);
-
             // Convert ModelTensors to strings before joining, skip session/interaction IDs
             List<String> stepStrings = new ArrayList<>();
             for (ModelTensors tensor : stepsSummary) {
@@ -1152,74 +1143,24 @@
                     }
                 }
             }
-            String steps = String.format(Locale.ROOT, "Question: %s\n\nCompleted Steps:\n%s", question, String.join("\n", stepStrings));
-            summaryParams.put(PROMPT, steps);
-            summaryParams.put(SYSTEM_PROMPT_FIELD, MAX_STEP_SUMMARY_CHAT_AGENT_SYSTEM_PROMPT);
-
-            ActionRequest request = new MLPredictionTaskRequest(
-                llmSpec.getModelId(),
-                RemoteInferenceMLInput
-                    .builder()
-                    .algorithm(FunctionName.REMOTE)
-                    .inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build())
-                    .build(),
-                null,
-                tenantId
-            );
-            client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> {
-                String summary = extractSummaryFromResponse(response, summaryParams);
-                if (summary == null) {
-                    listener.onFailure(new RuntimeException("Empty or invalid LLM summary response"));
-                    return;
-                }
-                listener.onResponse(summary);
-            }, listener::onFailure));
+            String promptContent = String
+                .format(Locale.ROOT, "Question: %s\n\nCompleted Steps:\n%s", question, String.join("\n", stepStrings));
+
+            AgentUtils
+                .generateMaxStepSummary(
+                    client,
+                    llmSpec,
+                    promptContent,
+                    MAX_STEP_SUMMARY_CHAT_AGENT_SYSTEM_PROMPT,
+                    parameter,
+                    tenantId,
+                    listener
+                );
         } catch (Exception e) {
             listener.onFailure(e);
         }
     }
 
-    public String extractSummaryFromResponse(MLTaskResponse response, Map<String, String> params) {
-        try {
-            ModelTensorOutput output = (ModelTensorOutput) response.getOutput();
-            if (output == null || output.getMlModelOutputs() == null || output.getMlModelOutputs().isEmpty()) {
-                return null;
-            }
-
-            ModelTensors tensors = output.getMlModelOutputs().getFirst();
-            if (tensors == null || tensors.getMlModelTensors() == null || tensors.getMlModelTensors().isEmpty()) {
-                return null;
-            }
-
-            ModelTensor tensor = tensors.getMlModelTensors().getFirst();
-            if (tensor.getResult() != null) {
-                return tensor.getResult().trim();
-            }
-
-            if (tensor.getDataAsMap() == null) {
-                return null;
-            }
-
-            Map<String, ?> dataMap = tensor.getDataAsMap();
-            if (dataMap.containsKey("response")) {
-                return String.valueOf(dataMap.get("response")).trim();
-            }
-
-            if (dataMap.containsKey("output")) {
-                Object outputObj = JsonPath.read(dataMap, params.get(LLM_RESPONSE_FILTER));
-                if (outputObj != null) {
-                    return String.valueOf(outputObj).trim();
-                }
-            }
-
-            log.error("Summary generate error. No result/response field found. Available fields: {}", dataMap.keySet());
-            return null;
-        } catch (Exception e) {
-            log.error("Failed to extract summary from response", e);
-            throw new RuntimeException("Failed to extract summary from response", e);
-        }
-    }
-
     private void saveMessage(
         Memory memory,
         String question,
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java
index 8a1f7a6e7f..a29127e504 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java
@@ -853,6 +853,7 @@ private void handleMaxStepsReached(
             );
         }, finalListener::onFailure);
 
+        String fallbackResult = generateFallbackResult(maxSteps, completedSteps);
         generateSummary(llm, completedSteps, allParams, ActionListener.wrap(summary -> {
             log.info("Summary generated successfully");
             responseListener
@@ -861,17 +862,6 @@
                 );
         }, e -> {
             log.error("Summary generation failed, using fallback", e);
-            String fallbackResult = completedSteps.isEmpty() || completedSteps.size() < 2
-                ? String.format("Max Steps Limit (%d) Reached. Use memory_id with same task to restart.", maxSteps)
-                : String
-                    .format(
-                        "Max Steps Limit (%d) Reached. Use memory_id with same task to restart. \n "
-                            + "Last executed step: %s, \n "
-                            + "Last executed step result: %s",
-                        maxSteps,
-                        completedSteps.get(completedSteps.size() - 2),
-                        completedSteps.getLast()
-                    );
             responseListener.onResponse(fallbackResult);
         }));
     }
@@ -888,83 +878,39 @@
         }
 
         try {
-            Map<String, String> summaryParams = new HashMap<>();
-            if (llmSpec.getParameters() != null) {
-                summaryParams.putAll(llmSpec.getParameters());
-            }
-            // Add allParams to ensure LLM_RESPONSE_FILTER is available
-            summaryParams.putAll(allParams);
-
             String userObjective = allParams.get(USER_PROMPT_FIELD);
             String steps = String.format(Locale.ROOT, String.join("\n", completedSteps));
             String promptWithObjective = String
                 .format("Objective: %s\n\nCompleted Steps:\n%s", userObjective != null ? userObjective : "", steps);
 
-            summaryParams.put(PROMPT_FIELD, promptWithObjective);
-            summaryParams.put(SYSTEM_PROMPT_FIELD, MAX_STEP_SUMMARY_PER_SYSTEM_PROMPT);
-
-            MLPredictionTaskRequest request = new MLPredictionTaskRequest(
-                llmSpec.getModelId(),
-                RemoteInferenceMLInput
-                    .builder()
-                    .algorithm(FunctionName.REMOTE)
-                    .inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build())
-                    .build(),
-                null,
-                allParams.get(TENANT_ID_FIELD)
-            );
-            client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> {
-                String summary = extractSummaryFromResponse(response, summaryParams);
-                if (summary == null || summary.trim().isEmpty()) {
-                    log.error("Extracted summary is empty");
-                    listener.onFailure(new RuntimeException("Empty or invalid LLM summary response"));
-                    return;
-                }
-                listener.onResponse(summary);
-            }, listener::onFailure));
+            AgentUtils
+                .generateMaxStepSummary(
+                    client,
+                    llmSpec,
+                    promptWithObjective,
+                    MAX_STEP_SUMMARY_PER_SYSTEM_PROMPT,
+                    allParams,
+                    allParams.get(TENANT_ID_FIELD),
+                    listener
+                );
         } catch (Exception e) {
             listener.onFailure(e);
         }
     }
 
-    private String extractSummaryFromResponse(MLTaskResponse response, Map<String, String> parameters) {
-        try {
-            ModelTensorOutput output = (ModelTensorOutput) response.getOutput();
-            if (output == null || output.getMlModelOutputs() == null || output.getMlModelOutputs().isEmpty()) {
-                return null;
-            }
-
-            ModelTensors tensors = output.getMlModelOutputs().getFirst();
-            if (tensors == null || tensors.getMlModelTensors() == null || tensors.getMlModelTensors().isEmpty()) {
-                return null;
-            }
-
-            ModelTensor tensor = tensors.getMlModelTensors().getFirst();
-            if (tensor.getResult() != null) {
-                return tensor.getResult().trim();
-            }
-
-            if (tensor.getDataAsMap() == null) {
-                return null;
-            }
-
-            Map<String, ?> dataMap = tensor.getDataAsMap();
-            if (dataMap.containsKey(RESPONSE_FIELD)) {
-                return String.valueOf(dataMap.get(RESPONSE_FIELD)).trim();
-            }
-
-            if (dataMap.containsKey("output")) {
-                Object outputObj = JsonPath.read(dataMap, parameters.get(LLM_RESPONSE_FILTER));
-                if (outputObj != null) {
-                    return String.valueOf(outputObj).trim();
-                }
-            }
-
-            log.error("Summary generate error. No result/response field found. Available fields: {}", dataMap.keySet());
-            return null;
-        } catch (Exception e) {
-            log.error("Summary extraction failed", e);
-            throw new RuntimeException("Failed to extract summary from response", e);
-        }
+    @VisibleForTesting
+    String generateFallbackResult(int maxSteps, List<String> completedSteps) {
+        return completedSteps.isEmpty() || completedSteps.size() < 2
+            ? String.format("Max Steps Limit (%d) Reached. Use memory_id with same task to restart.", maxSteps)
+            : String
+                .format(
+                    "Max Steps Limit (%d) Reached. Use memory_id with same task to restart. \n "
+                        + "Last executed step: %s, \n "
+                        + "Last executed step result: %s",
+                    maxSteps,
+                    completedSteps.get(completedSteps.size() - 2),
+                    completedSteps.getLast()
+                );
     }
+
 }
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java
index f479772c1c..178c15274f 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java
@@ -1305,7 +1305,7 @@ public void testExtractSummaryFromResponse() {
         Map<String, String> params = new HashMap<>();
         params.put("llm_response_filter", "$.output.message.content[0].text");
 
-        String result = mlChatAgentRunner.extractSummaryFromResponse(response, params);
+        String result = AgentUtils.extractSummaryFromResponse(response, params);
         assertEquals("Valid summary text", result);
     }
 
@@ -1330,7 +1330,7 @@ public void testExtractSummaryFromResponse_WithResponseField() {
         Map<String, String> params = new HashMap<>();
         params.put("llm_response_filter", "$.output.message.content[0].text");
 
-        String result = mlChatAgentRunner.extractSummaryFromResponse(response, params);
+        String result = AgentUtils.extractSummaryFromResponse(response, params);
         assertEquals("Summary from response field", result);
     }
 
@@ -1343,8 +1343,12 @@ public void testExtractSummaryFromResponse_WithNullDataMap() {
         Map<String, String> params = new HashMap<>();
         params.put("llm_response_filter", "$.output.message.content[0].text");
 
-        String result = mlChatAgentRunner.extractSummaryFromResponse(response, params);
-        assertEquals(null, result);
+        try {
+            AgentUtils.extractSummaryFromResponse(response, params);
+            Assert.fail("Expected IllegalStateException");
+        } catch (IllegalStateException e) {
+            Assert.assertTrue(e.getMessage().contains("No data map available in tensor"));
+        }
     }
 
     @Test
@@ -1358,8 +1362,12 @@ public void testExtractSummaryFromResponse_WithEmptyDataMap() {
         Map<String, String> params = new HashMap<>();
         params.put("llm_response_filter", "$.output.message.content[0].text");
 
-        String result = mlChatAgentRunner.extractSummaryFromResponse(response, params);
-        assertEquals(null, result);
+        try {
+            AgentUtils.extractSummaryFromResponse(response, params);
+            Assert.fail("Expected IllegalStateException");
+        } catch (IllegalStateException e) {
+            Assert.assertTrue(e.getMessage().contains("No result/response field found"));
+        }
     }
 
     @Test