Skip to content

Commit 5cd4c96

Browse files
[FIX] allow partial updates to llm and memory fields in MLAgentUpdateInput (#4040)
* allow partial updates to llm and memory fields in MLAgentUpdateInput Signed-off-by: Jiaping Zeng <jpz@amazon.com> * Trigger build Signed-off-by: Jiaping Zeng <jpz@amazon.com> * add more tests to improve coverage Signed-off-by: Jiaping Zeng <jpz@amazon.com> --------- Signed-off-by: Jiaping Zeng <jpz@amazon.com> Co-authored-by: Dhrubo Saha <dhrubo@amazon.com>
1 parent fb71867 commit 5cd4c96

File tree

3 files changed

+748
-69
lines changed

3 files changed

+748
-69
lines changed

common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java

Lines changed: 144 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import java.io.IOException;
1313
import java.time.Instant;
1414
import java.util.ArrayList;
15+
import java.util.HashMap;
1516
import java.util.HashSet;
1617
import java.util.List;
1718
import java.util.Map;
@@ -41,20 +42,28 @@ public class MLAgentUpdateInput implements ToXContentObject, Writeable {
4142
public static final String AGENT_NAME_FIELD = "name";
4243
public static final String DESCRIPTION_FIELD = "description";
4344
public static final String LLM_FIELD = "llm";
45+
public static final String LLM_MODEL_ID_FIELD = "model_id";
46+
public static final String LLM_PARAMETERS_FIELD = "parameters";
4447
public static final String TOOLS_FIELD = "tools";
4548
public static final String PARAMETERS_FIELD = "parameters";
4649
public static final String MEMORY_FIELD = "memory";
50+
public static final String MEMORY_TYPE_FIELD = "type";
51+
public static final String MEMORY_SESSION_ID_FIELD = "session_id";
52+
public static final String MEMORY_WINDOW_SIZE_FIELD = "window_size";
4753
public static final String APP_TYPE_FIELD = "app_type";
4854
public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
4955

5056
@Getter
5157
private String agentId;
5258
private String name;
5359
private String description;
54-
private LLMSpec llm;
60+
private String llmModelId;
61+
private Map<String, String> llmParameters;
5562
private List<MLToolSpec> tools;
5663
private Map<String, String> parameters;
57-
private MLMemorySpec memory;
64+
private String memoryType;
65+
private String memorySessionId;
66+
private Integer memoryWindowSize;
5867
private String appType;
5968
private Instant lastUpdateTime;
6069
private String tenantId;
@@ -64,21 +73,27 @@ public MLAgentUpdateInput(
6473
String agentId,
6574
String name,
6675
String description,
67-
LLMSpec llm,
76+
String llmModelId,
77+
Map<String, String> llmParameters,
6878
List<MLToolSpec> tools,
6979
Map<String, String> parameters,
70-
MLMemorySpec memory,
80+
String memoryType,
81+
String memorySessionId,
82+
Integer memoryWindowSize,
7183
String appType,
7284
Instant lastUpdateTime,
7385
String tenantId
7486
) {
7587
this.agentId = agentId;
7688
this.name = name;
7789
this.description = description;
78-
this.llm = llm;
90+
this.llmModelId = llmModelId;
91+
this.llmParameters = llmParameters;
7992
this.tools = tools;
8093
this.parameters = parameters;
81-
this.memory = memory;
94+
this.memoryType = memoryType;
95+
this.memorySessionId = memorySessionId;
96+
this.memoryWindowSize = memoryWindowSize;
8297
this.appType = appType;
8398
this.lastUpdateTime = lastUpdateTime;
8499
this.tenantId = tenantId;
@@ -90,8 +105,9 @@ public MLAgentUpdateInput(StreamInput in) throws IOException {
90105
agentId = in.readString();
91106
name = in.readOptionalString();
92107
description = in.readOptionalString();
108+
llmModelId = in.readOptionalString();
93109
if (in.readBoolean()) {
94-
llm = new LLMSpec(in);
110+
llmParameters = in.readMap(StreamInput::readString, StreamInput::readOptionalString);
95111
}
96112
if (in.readBoolean()) {
97113
tools = new ArrayList<>();
@@ -103,9 +119,9 @@ public MLAgentUpdateInput(StreamInput in) throws IOException {
103119
if (in.readBoolean()) {
104120
parameters = in.readMap(StreamInput::readString, StreamInput::readOptionalString);
105121
}
106-
if (in.readBoolean()) {
107-
memory = new MLMemorySpec(in);
108-
}
122+
memoryType = in.readOptionalString();
123+
memorySessionId = in.readOptionalString();
124+
memoryWindowSize = in.readOptionalInt();
109125
lastUpdateTime = in.readOptionalInstant();
110126
appType = in.readOptionalString();
111127
tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? in.readOptionalString() : null;
@@ -121,17 +137,34 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
121137
if (description != null) {
122138
builder.field(DESCRIPTION_FIELD, description);
123139
}
124-
if (llm != null) {
125-
builder.field(LLM_FIELD, llm);
140+
if (llmModelId != null || (llmParameters != null && !llmParameters.isEmpty())) {
141+
builder.startObject(LLM_FIELD);
142+
if (llmModelId != null) {
143+
builder.field(LLM_MODEL_ID_FIELD, llmModelId);
144+
}
145+
if (llmParameters != null && !llmParameters.isEmpty()) {
146+
builder.field(LLM_PARAMETERS_FIELD, llmParameters);
147+
}
148+
builder.endObject();
126149
}
127150
if (tools != null && !tools.isEmpty()) {
128151
builder.field(TOOLS_FIELD, tools);
129152
}
130153
if (parameters != null && !parameters.isEmpty()) {
131154
builder.field(PARAMETERS_FIELD, parameters);
132155
}
133-
if (memory != null) {
134-
builder.field(MEMORY_FIELD, memory);
156+
if (memoryType != null || memorySessionId != null || memoryWindowSize != null) {
157+
builder.startObject(MEMORY_FIELD);
158+
if (memoryType != null) {
159+
builder.field(MEMORY_TYPE_FIELD, memoryType);
160+
}
161+
if (memorySessionId != null) {
162+
builder.field(MEMORY_SESSION_ID_FIELD, memorySessionId);
163+
}
164+
if (memoryWindowSize != null) {
165+
builder.field(MEMORY_WINDOW_SIZE_FIELD, memoryWindowSize);
166+
}
167+
builder.endObject();
135168
}
136169
if (appType != null) {
137170
builder.field(APP_TYPE_FIELD, appType);
@@ -152,9 +185,10 @@ public void writeTo(StreamOutput out) throws IOException {
152185
out.writeString(agentId);
153186
out.writeOptionalString(name);
154187
out.writeOptionalString(description);
155-
if (llm != null) {
188+
out.writeOptionalString(llmModelId);
189+
if (llmParameters != null && !llmParameters.isEmpty()) {
156190
out.writeBoolean(true);
157-
llm.writeTo(out);
191+
out.writeMap(llmParameters, StreamOutput::writeString, StreamOutput::writeOptionalString);
158192
} else {
159193
out.writeBoolean(false);
160194
}
@@ -173,12 +207,9 @@ public void writeTo(StreamOutput out) throws IOException {
173207
} else {
174208
out.writeBoolean(false);
175209
}
176-
if (memory != null) {
177-
out.writeBoolean(true);
178-
memory.writeTo(out);
179-
} else {
180-
out.writeBoolean(false);
181-
}
210+
out.writeOptionalString(memoryType);
211+
out.writeOptionalString(memorySessionId);
212+
out.writeOptionalInt(memoryWindowSize);
182213
out.writeOptionalInstant(lastUpdateTime);
183214
out.writeOptionalString(appType);
184215
if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) {
@@ -190,10 +221,13 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException
190221
String agentId = null;
191222
String name = null;
192223
String description = null;
193-
LLMSpec llm = null;
224+
String llmModelId = null;
225+
Map<String, String> llmParameters = null;
194226
List<MLToolSpec> tools = null;
195227
Map<String, String> parameters = null;
196-
MLMemorySpec memory = null;
228+
String memoryType = null;
229+
String memorySessionId = null;
230+
Integer memoryWindowSize = null;
197231
String appType = null;
198232
Instant lastUpdateTime = null;
199233
String tenantId = null;
@@ -213,7 +247,22 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException
213247
description = parser.text();
214248
break;
215249
case LLM_FIELD:
216-
llm = LLMSpec.parse(parser);
250+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
251+
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
252+
String llmFieldName = parser.currentName();
253+
parser.nextToken();
254+
switch (llmFieldName) {
255+
case LLM_MODEL_ID_FIELD:
256+
llmModelId = parser.text();
257+
break;
258+
case LLM_PARAMETERS_FIELD:
259+
llmParameters = parser.mapStrings();
260+
break;
261+
default:
262+
parser.skipChildren();
263+
break;
264+
}
265+
}
217266
break;
218267
case TOOLS_FIELD:
219268
tools = new ArrayList<>();
@@ -226,7 +275,25 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException
226275
parameters = parser.mapStrings();
227276
break;
228277
case MEMORY_FIELD:
229-
memory = MLMemorySpec.parse(parser);
278+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
279+
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
280+
String memoryFieldName = parser.currentName();
281+
parser.nextToken();
282+
switch (memoryFieldName) {
283+
case MEMORY_TYPE_FIELD:
284+
memoryType = parser.text();
285+
break;
286+
case MEMORY_SESSION_ID_FIELD:
287+
memorySessionId = parser.text();
288+
break;
289+
case MEMORY_WINDOW_SIZE_FIELD:
290+
memoryWindowSize = parser.intValue();
291+
break;
292+
default:
293+
parser.skipChildren();
294+
break;
295+
}
296+
}
230297
break;
231298
case APP_TYPE_FIELD:
232299
appType = parser.text();
@@ -243,21 +310,67 @@ public static MLAgentUpdateInput parse(XContentParser parser) throws IOException
243310
}
244311
}
245312

246-
return new MLAgentUpdateInput(agentId, name, description, llm, tools, parameters, memory, appType, lastUpdateTime, tenantId);
313+
return new MLAgentUpdateInput(
314+
agentId,
315+
name,
316+
description,
317+
llmModelId,
318+
llmParameters,
319+
tools,
320+
parameters,
321+
memoryType,
322+
memorySessionId,
323+
memoryWindowSize,
324+
appType,
325+
lastUpdateTime,
326+
tenantId
327+
);
247328
}
248329

249330
public MLAgent toMLAgent(MLAgent originalAgent) {
331+
LLMSpec finalLlm;
332+
if (llmModelId == null && (llmParameters == null || llmParameters.isEmpty())) {
333+
finalLlm = originalAgent.getLlm();
334+
} else {
335+
LLMSpec originalLlm = originalAgent.getLlm();
336+
337+
String finalModelId = llmModelId != null ? llmModelId : originalLlm.getModelId();
338+
339+
Map<String, String> finalParameters = new HashMap<>();
340+
if (originalLlm != null && originalLlm.getParameters() != null) {
341+
finalParameters.putAll(originalLlm.getParameters());
342+
}
343+
if (llmParameters != null) {
344+
finalParameters.putAll(llmParameters);
345+
}
346+
347+
finalLlm = LLMSpec.builder().modelId(finalModelId).parameters(finalParameters).build();
348+
}
349+
350+
MLMemorySpec finalMemory;
351+
if (memoryType == null && memorySessionId == null && memoryWindowSize == null) {
352+
finalMemory = originalAgent.getMemory();
353+
} else {
354+
MLMemorySpec originalMemory = originalAgent.getMemory();
355+
356+
String finalMemoryType = memoryType != null ? memoryType : originalMemory.getType();
357+
String finalSessionId = memorySessionId != null ? memorySessionId : originalMemory.getSessionId();
358+
Integer finalWindowSize = memoryWindowSize != null ? memoryWindowSize : originalMemory.getWindowSize();
359+
360+
finalMemory = MLMemorySpec.builder().type(finalMemoryType).sessionId(finalSessionId).windowSize(finalWindowSize).build();
361+
}
362+
250363
return MLAgent
251364
.builder()
252365
.type(originalAgent.getType())
253366
.createdTime(originalAgent.getCreatedTime())
254367
.isHidden(originalAgent.getIsHidden())
255368
.name(name == null ? originalAgent.getName() : name)
256369
.description(description == null ? originalAgent.getDescription() : description)
257-
.llm(llm == null ? originalAgent.getLlm() : llm)
370+
.llm(finalLlm)
258371
.tools(tools == null ? originalAgent.getTools() : tools)
259372
.parameters(parameters == null ? originalAgent.getParameters() : parameters)
260-
.memory(memory == null ? originalAgent.getMemory() : memory)
373+
.memory(finalMemory)
261374
.lastUpdateTime(lastUpdateTime)
262375
.appType(appType)
263376
.tenantId(tenantId)
@@ -270,8 +383,8 @@ private void validate() {
270383
String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH)
271384
);
272385
}
273-
if (memory != null && !memory.getType().equals("conversation_index")) {
274-
throw new IllegalArgumentException(String.format("Invalid memory type: %s", memory.getType()));
386+
if (memoryType != null && !memoryType.equals("conversation_index")) {
387+
throw new IllegalArgumentException(String.format("Invalid memory type: %s", memoryType));
275388
}
276389
if (tools != null) {
277390
Set<String> toolNames = new HashSet<>();

0 commit comments

Comments
 (0)