Skip to content

Commit cf8a538

Browse files
nathaliellenaa authored and sonianuj287 committed
Fix agent streaming with security enabled + error handling (opensearch-project#4256)
* Fix agent streaming with security enabled
  Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>
* Address comment
  Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>
* Apply spotless
  Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>
* Fix agent streaming
  Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>
* Clean up
  Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>
* Add more tests
  Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>
---------
Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>
Signed-off-by: Anuj Soni <sonianuj287@gmail.com>
1 parent 4e3f0eb commit cf8a538

File tree

5 files changed

+408
-75
lines changed

5 files changed

+408
-75
lines changed

plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteStreamTaskAction.java

Lines changed: 38 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -7,22 +7,27 @@
77

88
import static org.opensearch.ml.plugin.MachineLearningPlugin.STREAM_EXECUTE_THREAD_POOL;
99

10+
import java.io.IOException;
11+
1012
import org.opensearch.action.ActionRequest;
1113
import org.opensearch.action.support.ActionFilters;
1214
import org.opensearch.action.support.HandledTransportAction;
1315
import org.opensearch.common.Nullable;
1416
import org.opensearch.common.inject.Inject;
1517
import org.opensearch.core.action.ActionListener;
18+
import org.opensearch.core.common.io.stream.StreamInput;
1619
import org.opensearch.ml.common.FunctionName;
1720
import org.opensearch.ml.common.transport.execute.MLExecuteStreamTaskAction;
1821
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
1922
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
20-
import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener;
2123
import org.opensearch.ml.task.MLExecuteTaskRunner;
2224
import org.opensearch.ml.task.MLTaskRunner;
2325
import org.opensearch.tasks.Task;
26+
import org.opensearch.threadpool.ThreadPool;
2427
import org.opensearch.transport.StreamTransportService;
2528
import org.opensearch.transport.TransportChannel;
29+
import org.opensearch.transport.TransportException;
30+
import org.opensearch.transport.TransportResponseHandler;
2631
import org.opensearch.transport.TransportService;
2732

2833
import lombok.AccessLevel;
@@ -71,16 +76,42 @@ public static StreamTransportService getStreamTransportService() {
7176
}
7277

7378
public void messageReceived(MLExecuteTaskRequest request, TransportChannel channel, Task task) {
74-
StreamPredictActionListener<MLExecuteTaskResponse, MLExecuteTaskRequest> streamListener = new StreamPredictActionListener<>(
75-
channel
76-
);
77-
doExecute(task, request, streamListener, channel);
79+
request.setStreamingChannel(channel);
80+
transportService
81+
.sendRequest(
82+
transportService.getLocalNode(),
83+
MLExecuteStreamTaskAction.NAME,
84+
request,
85+
new TransportResponseHandler<MLExecuteTaskResponse>() {
86+
public MLExecuteTaskResponse read(StreamInput in) throws IOException {
87+
return new MLExecuteTaskResponse(in);
88+
}
89+
90+
public void handleResponse(MLExecuteTaskResponse response) {}
91+
92+
public void handleException(TransportException exp) {
93+
try {
94+
channel.sendResponse(exp);
95+
} catch (Exception e) {
96+
log.error("Failed to send error response", e);
97+
}
98+
}
99+
100+
public String executor() {
101+
return ThreadPool.Names.SAME;
102+
}
103+
}
104+
);
78105
}
79106

80107
@Override
81108
protected void doExecute(Task task, ActionRequest request, ActionListener<MLExecuteTaskResponse> listener) {
82-
// This should never be called for streaming action
83-
listener.onFailure(new UnsupportedOperationException("Use doExecute with TransportChannel for streaming requests"));
109+
TransportChannel channel = ((MLExecuteTaskRequest) request).getStreamingChannel();
110+
if (channel != null) {
111+
doExecute(task, request, listener, channel);
112+
} else {
113+
listener.onFailure(new UnsupportedOperationException("Use doExecute with TransportChannel for streaming requests"));
114+
}
84115
}
85116

86117
protected void doExecute(Task task, ActionRequest request, ActionListener<MLExecuteTaskResponse> listener, TransportChannel channel) {

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -969,7 +969,11 @@ public List<RestHandler> getRestHandlers(
969969
clusterService
970970
);
971971
RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction(mlFeatureEnabledSetting);
972-
RestMLExecuteStreamAction restMlExecuteStreamAction = new RestMLExecuteStreamAction(mlFeatureEnabledSetting, clusterService);
972+
RestMLExecuteStreamAction restMlExecuteStreamAction = new RestMLExecuteStreamAction(
973+
mlModelManager,
974+
mlFeatureEnabledSetting,
975+
clusterService
976+
);
973977
RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction(mlFeatureEnabledSetting);
974978
RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction(mlFeatureEnabledSetting);
975979
RestMLSearchModelAction restMLSearchModelAction = new RestMLSearchModelAction(mlFeatureEnabledSetting);

plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java

Lines changed: 144 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,10 @@
55

66
package org.opensearch.ml.rest;
77

8+
import static java.util.concurrent.TimeUnit.SECONDS;
9+
import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent;
810
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
11+
import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
912
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
1013
import static org.opensearch.ml.plugin.MachineLearningPlugin.STREAM_EXECUTE_THREAD_POOL;
1114
import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG;
@@ -23,12 +26,16 @@
2326
import java.util.Map;
2427
import java.util.concurrent.CompletableFuture;
2528

29+
import org.opensearch.OpenSearchStatusException;
2630
import org.opensearch.action.ActionRequestValidationException;
31+
import org.opensearch.action.get.GetRequest;
2732
import org.opensearch.cluster.service.ClusterService;
2833
import org.opensearch.common.lease.Releasable;
34+
import org.opensearch.common.util.concurrent.ThreadContext;
2935
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
3036
import org.opensearch.common.xcontent.XContentFactory;
3137
import org.opensearch.common.xcontent.support.XContentHttpChunk;
38+
import org.opensearch.core.action.ActionListener;
3239
import org.opensearch.core.common.bytes.BytesReference;
3340
import org.opensearch.core.common.io.stream.StreamInput;
3441
import org.opensearch.core.rest.RestStatus;
@@ -38,6 +45,8 @@
3845
import org.opensearch.http.HttpChunk;
3946
import org.opensearch.ml.action.execute.TransportExecuteStreamTaskAction;
4047
import org.opensearch.ml.common.FunctionName;
48+
import org.opensearch.ml.common.MLModel;
49+
import org.opensearch.ml.common.agent.MLAgent;
4150
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
4251
import org.opensearch.ml.common.exception.MLException;
4352
import org.opensearch.ml.common.input.Input;
@@ -50,6 +59,7 @@
5059
import org.opensearch.ml.common.transport.MLTaskResponse;
5160
import org.opensearch.ml.common.transport.execute.MLExecuteStreamTaskAction;
5261
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
62+
import org.opensearch.ml.model.MLModelManager;
5363
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
5464
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
5565
import org.opensearch.rest.BaseRestHandler;
@@ -74,11 +84,17 @@ public class RestMLExecuteStreamAction extends BaseRestHandler {
7484
private static final String ML_EXECUTE_STREAM_ACTION = "ml_execute_stream_action";
7585
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
7686
private ClusterService clusterService;
87+
private MLModelManager mlModelManager;
7788

7889
/**
7990
* Constructor
8091
*/
81-
public RestMLExecuteStreamAction(MLFeatureEnabledSetting mlFeatureEnabledSetting, ClusterService clusterService) {
92+
public RestMLExecuteStreamAction(
93+
MLModelManager mlModelManager,
94+
MLFeatureEnabledSetting mlFeatureEnabledSetting,
95+
ClusterService clusterService
96+
) {
97+
this.mlModelManager = mlModelManager;
8298
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
8399
this.clusterService = clusterService;
84100
}
@@ -122,6 +138,14 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
122138

123139
String agentId = request.param(PARAMETER_AGENT_ID);
124140

141+
// Validate agent and model synchronously before starting stream
142+
MLAgent agent = validateAndGetAgent(agentId, client);
143+
if (agent.getLlm() != null && agent.getLlm().getModelId() != null) {
144+
if (!isModelValid(agent.getLlm().getModelId(), request, client)) {
145+
throw new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND);
146+
}
147+
}
148+
125149
final StreamingRestChannelConsumer consumer = (channel) -> {
126150
Map<String, List<String>> headers = Map
127151
.of(
@@ -217,6 +241,59 @@ public MLTaskResponse read(StreamInput in) throws IOException {
217241
};
218242
}
219243

244+
@VisibleForTesting
245+
MLAgent validateAndGetAgent(String agentId, NodeClient client) {
246+
try {
247+
CompletableFuture<MLAgent> future = new CompletableFuture<>();
248+
249+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
250+
client.get(new GetRequest(ML_AGENT_INDEX, agentId), ActionListener.runBefore(ActionListener.wrap(response -> {
251+
if (response.isExists()) {
252+
try {
253+
XContentParser parser = jsonXContent
254+
.createParser(null, LoggingDeprecationHandler.INSTANCE, response.getSourceAsString());
255+
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
256+
future.complete(MLAgent.parse(parser));
257+
} catch (Exception e) {
258+
future.completeExceptionally(e);
259+
}
260+
} else {
261+
future.completeExceptionally(new OpenSearchStatusException("Agent not found", RestStatus.NOT_FOUND));
262+
}
263+
}, future::completeExceptionally), context::restore));
264+
}
265+
266+
// TODO: Make validation async
267+
return future.get(5, SECONDS);
268+
} catch (Exception e) {
269+
log.error("Failed to validate agent {}", agentId, e);
270+
throw new OpenSearchStatusException("Failed to find agent with the provided agent id: " + agentId, RestStatus.NOT_FOUND);
271+
}
272+
}
273+
274+
@VisibleForTesting
275+
boolean isModelValid(String modelId, RestRequest request, NodeClient client) throws IOException {
276+
try {
277+
CompletableFuture<MLModel> future = new CompletableFuture<>();
278+
279+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
280+
mlModelManager
281+
.getModel(
282+
modelId,
283+
getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request),
284+
ActionListener.runBefore(ActionListener.wrap(future::complete, future::completeExceptionally), context::restore)
285+
);
286+
}
287+
288+
// TODO: make model validation async
289+
future.get(5, SECONDS);
290+
return true;
291+
} catch (Exception e) {
292+
log.error("Failed to validate model {}", e.getMessage());
293+
return false;
294+
}
295+
}
296+
220297
/**
221298
* Creates a MLExecuteTaskRequest from a RestRequest
222299
*
@@ -248,77 +325,81 @@ MLExecuteTaskRequest getRequest(String agentId, RestRequest request, BytesRefere
248325
}
249326

250327
private HttpChunk convertToHttpChunk(MLTaskResponse response) throws IOException {
251-
String memoryId = "";
252-
String parentInteractionId = "";
253-
String content = "";
328+
String sseData;
254329
boolean isLast = false;
255330

256-
// TODO: refactor to handle other types of agents
257-
// Extract values from multiple tensors
258331
try {
259-
ModelTensorOutput output = (ModelTensorOutput) response.getOutput();
260-
if (output != null && !output.getMlModelOutputs().isEmpty()) {
261-
ModelTensors modelTensors = output.getMlModelOutputs().get(0);
262-
List<ModelTensor> tensors = modelTensors.getMlModelTensors();
263-
264-
for (ModelTensor tensor : tensors) {
265-
String name = tensor.getName();
266-
if ("memory_id".equals(name) && tensor.getResult() != null) {
267-
memoryId = tensor.getResult();
268-
} else if ("parent_interaction_id".equals(name) && tensor.getResult() != null) {
269-
parentInteractionId = tensor.getResult();
270-
} else if (("llm_response".equals(name) || "response".equals(name)) && tensor.getDataAsMap() != null) {
271-
Map<String, ?> dataMap = tensor.getDataAsMap();
272-
if (dataMap.containsKey("content")) {
273-
content = (String) dataMap.get("content");
274-
if (content == null)
275-
content = "";
276-
}
277-
if (dataMap.containsKey("is_last")) {
278-
isLast = Boolean.TRUE.equals(dataMap.get("is_last"));
279-
}
280-
}
281-
}
332+
Map<String, ?> dataMap = extractDataMap(response);
333+
334+
if (dataMap.containsKey("error")) {
335+
// Error response
336+
String errorMessage = (String) dataMap.get("error");
337+
sseData = String.format("data: {\"error\": \"%s\"}\n\n", errorMessage.replace("\"", "\\\"").replace("\n", "\\n"));
338+
isLast = true;
339+
} else {
340+
// TODO: refactor to handle other types of agents
341+
// Regular response - extract values and build proper structure
342+
String memoryId = extractTensorResult(response, "memory_id");
343+
String parentInteractionId = extractTensorResult(response, "parent_interaction_id");
344+
String content = dataMap.containsKey("content") ? (String) dataMap.get("content") : "";
345+
isLast = dataMap.containsKey("is_last") ? Boolean.TRUE.equals(dataMap.get("is_last")) : false;
346+
boolean finalIsLast = isLast;
347+
348+
List<ModelTensor> orderedTensors = List
349+
.of(
350+
ModelTensor.builder().name("memory_id").result(memoryId).build(),
351+
ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build(),
352+
ModelTensor.builder().name("response").dataAsMap(new LinkedHashMap<String, Object>() {
353+
{
354+
put("content", content);
355+
put("is_last", finalIsLast);
356+
}
357+
}).build()
358+
);
359+
360+
ModelTensors tensors = ModelTensors.builder().mlModelTensors(orderedTensors).build();
361+
ModelTensorOutput tensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build();
362+
363+
XContentBuilder builder = XContentFactory.jsonBuilder();
364+
tensorOutput.toXContent(builder, ToXContent.EMPTY_PARAMS);
365+
sseData = "data: " + builder.toString() + "\n\n";
282366
}
283367
} catch (Exception e) {
284-
log.error("Failed to extract values from response", e);
368+
log.error("Failed to process response", e);
369+
sseData = "data: {\"error\": \"Processing failed\"}\n\n";
370+
isLast = true;
285371
}
372+
return createHttpChunk(sseData, isLast);
373+
}
286374

287-
String finalContent = content;
288-
boolean finalIsLast = isLast;
289-
290-
log
291-
.info(
292-
"Converting to HttpChunk - memoryId: '{}', parentId: '{}', content: '{}', isLast: {}",
293-
memoryId,
294-
parentInteractionId,
295-
content,
296-
isLast
297-
);
375+
private String extractTensorResult(MLTaskResponse response, String tensorName) {
376+
ModelTensorOutput output = (ModelTensorOutput) response.getOutput();
377+
if (output != null && !output.getMlModelOutputs().isEmpty()) {
378+
ModelTensors tensors = output.getMlModelOutputs().get(0);
379+
for (ModelTensor tensor : tensors.getMlModelTensors()) {
380+
if (tensorName.equals(tensor.getName()) && tensor.getResult() != null) {
381+
return tensor.getResult();
382+
}
383+
}
384+
}
385+
return "";
386+
}
298387

299-
// Create ordered tensors
300-
List<ModelTensor> orderedTensors = List
301-
.of(
302-
ModelTensor.builder().name("memory_id").result(memoryId).build(),
303-
ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build(),
304-
ModelTensor.builder().name("response").dataAsMap(new LinkedHashMap<String, Object>() {
305-
{
306-
put("content", finalContent);
307-
put("is_last", finalIsLast);
388+
private Map<String, ?> extractDataMap(MLTaskResponse response) {
389+
ModelTensorOutput output = (ModelTensorOutput) response.getOutput();
390+
if (output != null && !output.getMlModelOutputs().isEmpty()) {
391+
ModelTensors tensors = output.getMlModelOutputs().get(0);
392+
for (ModelTensor tensor : tensors.getMlModelTensors()) {
393+
String name = tensor.getName();
394+
if ("error".equals(name) || "llm_response".equals(name) || "response".equals(name)) {
395+
Map<String, ?> dataMap = tensor.getDataAsMap();
396+
if (dataMap != null) {
397+
return dataMap;
308398
}
309-
}).build()
310-
);
311-
312-
ModelTensors tensors = ModelTensors.builder().mlModelTensors(orderedTensors).build();
313-
314-
ModelTensorOutput tensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build();
315-
316-
XContentBuilder builder = XContentFactory.jsonBuilder();
317-
tensorOutput.toXContent(builder, ToXContent.EMPTY_PARAMS);
318-
String jsonData = builder.toString();
319-
320-
String sseData = "data: " + jsonData + "\n\n";
321-
return createHttpChunk(sseData, isLast);
399+
}
400+
}
401+
}
402+
return Map.of();
322403
}
323404

324405
private HttpChunk createHttpChunk(String sseData, boolean isLast) {

0 commit comments

Comments
 (0)