Skip to content

Commit c5ad5c3

Browse files
committed
Merge agent-revamp into feature/memory as a single commit
Signed-off-by: Pavan Yekbote <pybot@amazon.com>
1 parent 14b2206 commit c5ad5c3

File tree

23 files changed

+2109
-4
lines changed

23 files changed

+2109
-4
lines changed

common/src/main/java/org/opensearch/ml/common/agent/AgentInput.java

Lines changed: 555 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.agent;
7+
8+
import java.util.List;
9+
10+
import lombok.extern.log4j.Log4j2;
11+
12+
/**
13+
* Utility class for validating standardized agent input formats.
14+
* The AgentInput itself is already standardized - this validator just validates it
15+
* and ensures it's ready to be passed to ModelProviders for conversion to their
16+
* specific request body formats.
17+
*/
18+
@Log4j2
19+
public class AgentInputProcessor {
20+
21+
// Private constructor to prevent instantiation
22+
private AgentInputProcessor() {
23+
throw new UnsupportedOperationException("Utility class cannot be instantiated");
24+
}
25+
26+
/**
27+
* Validates the standardized AgentInput.
28+
* The AgentInput is passed through after validation - ModelProviders will
29+
* handle the conversion to their specific request body parameters.
30+
*
31+
* @param agentInput the standardized agent input
32+
* @throws IllegalArgumentException if input is invalid
33+
*/
34+
public static void validateInput(AgentInput agentInput) {
35+
if (agentInput == null || agentInput.getInput() == null) {
36+
throw new IllegalArgumentException("AgentInput and its input field cannot be null");
37+
}
38+
39+
InputType type = agentInput.getInputType();
40+
41+
switch (type) {
42+
case TEXT:
43+
validateTextInput((String) agentInput.getInput());
44+
break;
45+
case CONTENT_BLOCKS:
46+
@SuppressWarnings("unchecked")
47+
List<ContentBlock> blocks = (List<ContentBlock>) agentInput.getInput();
48+
validateContentBlocks(blocks);
49+
break;
50+
case MESSAGES:
51+
@SuppressWarnings("unchecked")
52+
List<Message> messages = (List<Message>) agentInput.getInput();
53+
validateMessages(messages);
54+
break;
55+
default:
56+
throw new IllegalArgumentException("Unsupported input type: " + type);
57+
}
58+
}
59+
60+
/**
61+
* Validates simple text input.
62+
*
63+
* @param text the text input
64+
* @throws IllegalArgumentException if text is invalid
65+
*/
66+
private static void validateTextInput(String text) {
67+
if (text == null || text.trim().isEmpty()) {
68+
throw new IllegalArgumentException("Text input cannot be null or empty");
69+
}
70+
}
71+
72+
/**
73+
* Validates multi-modal content blocks.
74+
*
75+
* @param blocks the list of content blocks
76+
* @throws IllegalArgumentException if content blocks are invalid
77+
*/
78+
private static void validateContentBlocks(List<ContentBlock> blocks) {
79+
if (blocks == null || blocks.isEmpty()) {
80+
throw new IllegalArgumentException("Content blocks cannot be null or empty");
81+
}
82+
83+
for (ContentBlock block : blocks) {
84+
if (block.getType() == null) {
85+
throw new IllegalArgumentException("Content block type cannot be null");
86+
}
87+
88+
switch (block.getType()) {
89+
case TEXT:
90+
if (block.getText() == null || block.getText().trim().isEmpty()) {
91+
throw new IllegalArgumentException("Text content block cannot have null or empty text");
92+
}
93+
break;
94+
case IMAGE:
95+
if (block.getImage() == null) {
96+
throw new IllegalArgumentException("Image content block must have image data");
97+
}
98+
break;
99+
case DOCUMENT:
100+
if (block.getDocument() == null) {
101+
throw new IllegalArgumentException("Document content block must have document data");
102+
}
103+
break;
104+
case VIDEO:
105+
if (block.getVideo() == null) {
106+
throw new IllegalArgumentException("Video content block must have video data");
107+
}
108+
break;
109+
default:
110+
throw new IllegalArgumentException("Unsupported content block type: " + block.getType());
111+
}
112+
}
113+
}
114+
115+
/**
116+
* Validates message-based conversation input.
117+
*
118+
* @param messages the list of messages
119+
* @throws IllegalArgumentException if messages are invalid
120+
*/
121+
private static void validateMessages(List<Message> messages) {
122+
if (messages == null || messages.isEmpty()) {
123+
throw new IllegalArgumentException("Messages cannot be null or empty");
124+
}
125+
126+
for (Message message : messages) {
127+
if (message.getRole() == null || message.getRole().trim().isEmpty()) {
128+
throw new IllegalArgumentException("Message role cannot be null or empty");
129+
}
130+
131+
if (message.getContent() == null || message.getContent().isEmpty()) {
132+
throw new IllegalArgumentException("Message content cannot be null or empty");
133+
}
134+
135+
// Validate each content block in the message
136+
validateContentBlocks(message.getContent());
137+
}
138+
}
139+
140+
/**
141+
* Extracts question text from AgentInput for prompt template usage.
142+
* This provides the text that will be used in prompt templates that reference $parameters.question.
143+
*/
144+
public static String extractQuestionText(AgentInput agentInput) {
145+
validateInput(agentInput);
146+
return switch (agentInput.getInputType()) {
147+
case TEXT -> (String) agentInput.getInput();
148+
case CONTENT_BLOCKS -> {
149+
// For content blocks, extract and combine text content
150+
@SuppressWarnings("unchecked")
151+
List<ContentBlock> blocks = (List<ContentBlock>) agentInput.getInput();
152+
yield extractTextFromContentBlocks(blocks);
153+
}
154+
case MESSAGES -> {
155+
// For messages, extract the last user message text
156+
@SuppressWarnings("unchecked")
157+
List<Message> messages = (List<Message>) agentInput.getInput();
158+
yield extractTextFromMessages(messages);
159+
}
160+
default -> throw new IllegalArgumentException("Unsupported input type: " + agentInput.getInputType());
161+
};
162+
}
163+
164+
/**
165+
* Extracts text content from content blocks for human-readable display.
166+
* Ignores non text blocks
167+
* @throws IllegalArgumentException if content blocks are invalid[
168+
*/
169+
private static String extractTextFromContentBlocks(List<ContentBlock> blocks) {
170+
if (blocks == null || blocks.isEmpty()) {
171+
throw new IllegalArgumentException("Content blocks cannot be null or empty");
172+
}
173+
174+
StringBuilder textBuilder = new StringBuilder();
175+
for (ContentBlock block : blocks) {
176+
if (block.getType() == ContentType.TEXT) {
177+
String text = block.getText();
178+
if (text != null && !text.trim().isEmpty()) {
179+
textBuilder.append(text.trim());
180+
textBuilder.append("\n");
181+
}
182+
}
183+
}
184+
185+
return textBuilder.toString();
186+
}
187+
188+
/**
189+
* Extracts text content from last message.
190+
*/
191+
private static String extractTextFromMessages(List<Message> messages) {
192+
if (messages == null || messages.isEmpty()) {
193+
throw new IllegalArgumentException("Messages cannot be null or empty");
194+
}
195+
196+
Message message = messages.getLast();
197+
return extractTextFromContentBlocks(message.getContent());
198+
}
199+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.agent;
7+
8+
import org.opensearch.ml.common.connector.Connector;
9+
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
10+
11+
import lombok.extern.log4j.Log4j2;
12+
13+
/**
14+
* Service class for handling model creation during agent registration
15+
*/
16+
@Log4j2
17+
public class AgentModelService {
18+
19+
/**
20+
* Creates a model input from the agent model specification
21+
* @param modelSpec the model specification from agent registration
22+
* @return MLRegisterModelInput ready for model registration
23+
* @throws IllegalArgumentException if model provider is not supported
24+
*/
25+
public static MLRegisterModelInput createModelFromSpec(MLAgentModelSpec modelSpec) {
26+
validateModelSpec(modelSpec);
27+
ModelProvider provider = ModelProviderFactory.getProvider(modelSpec.getModelProvider());
28+
29+
Connector connector = provider.createConnector(modelSpec.getModelId(), modelSpec.getCredential(), modelSpec.getModelParameters());
30+
31+
return provider.createModelInput(modelSpec.getModelId(), connector, modelSpec.getModelParameters());
32+
}
33+
34+
/**
35+
* Infers the LLM interface from model provider for function calling
36+
* @param modelProvider the model provider string
37+
* @return the corresponding LLM interface string, or null if not supported
38+
*/
39+
public static String inferLLMInterface(String modelProvider) {
40+
if (modelProvider == null) {
41+
return null;
42+
}
43+
44+
try {
45+
ModelProvider provider = ModelProviderFactory.getProvider(modelProvider);
46+
return provider.getLLMInterface();
47+
} catch (Exception e) {
48+
log.error("Failed to infer LLM interface", e);
49+
return null;
50+
}
51+
}
52+
53+
/**
54+
* Validates that the model specification is complete and valid
55+
* @param modelSpec the model specification to validate
56+
* @throws IllegalArgumentException if validation fails
57+
*/
58+
private static void validateModelSpec(MLAgentModelSpec modelSpec) {
59+
if (modelSpec == null) {
60+
throw new IllegalArgumentException("Model specification not found");
61+
}
62+
63+
if (modelSpec.getModelId() == null || modelSpec.getModelId().trim().isEmpty()) {
64+
throw new IllegalArgumentException("model_id cannot be null or empty");
65+
}
66+
67+
if (modelSpec.getModelProvider() == null || modelSpec.getModelProvider().trim().isEmpty()) {
68+
throw new IllegalArgumentException("model_provider cannot be null or empty");
69+
}
70+
71+
// Validate that the provider type is supported
72+
try {
73+
ModelProviderType.from(modelSpec.getModelProvider());
74+
} catch (IllegalArgumentException e) {
75+
throw new IllegalArgumentException("Unsupported model provider: " + modelSpec.getModelProvider());
76+
}
77+
}
78+
}

0 commit comments

Comments
 (0)