55
66package org .opensearch .ml .rest ;
77
8+ import static java .util .concurrent .TimeUnit .SECONDS ;
9+ import static org .opensearch .common .xcontent .json .JsonXContent .jsonXContent ;
810import static org .opensearch .core .xcontent .XContentParserUtils .ensureExpectedToken ;
11+ import static org .opensearch .ml .common .CommonValue .ML_AGENT_INDEX ;
912import static org .opensearch .ml .plugin .MachineLearningPlugin .ML_BASE_URI ;
1013import static org .opensearch .ml .plugin .MachineLearningPlugin .STREAM_EXECUTE_THREAD_POOL ;
1114import static org .opensearch .ml .utils .MLExceptionUtils .AGENT_FRAMEWORK_DISABLED_ERR_MSG ;
2326import java .util .Map ;
2427import java .util .concurrent .CompletableFuture ;
2528
29+ import org .opensearch .OpenSearchStatusException ;
2630import org .opensearch .action .ActionRequestValidationException ;
31+ import org .opensearch .action .get .GetRequest ;
2732import org .opensearch .cluster .service .ClusterService ;
2833import org .opensearch .common .lease .Releasable ;
34+ import org .opensearch .common .util .concurrent .ThreadContext ;
2935import org .opensearch .common .xcontent .LoggingDeprecationHandler ;
3036import org .opensearch .common .xcontent .XContentFactory ;
3137import org .opensearch .common .xcontent .support .XContentHttpChunk ;
38+ import org .opensearch .core .action .ActionListener ;
3239import org .opensearch .core .common .bytes .BytesReference ;
3340import org .opensearch .core .common .io .stream .StreamInput ;
3441import org .opensearch .core .rest .RestStatus ;
3845import org .opensearch .http .HttpChunk ;
3946import org .opensearch .ml .action .execute .TransportExecuteStreamTaskAction ;
4047import org .opensearch .ml .common .FunctionName ;
48+ import org .opensearch .ml .common .MLModel ;
49+ import org .opensearch .ml .common .agent .MLAgent ;
4150import org .opensearch .ml .common .dataset .remote .RemoteInferenceInputDataSet ;
4251import org .opensearch .ml .common .exception .MLException ;
4352import org .opensearch .ml .common .input .Input ;
5059import org .opensearch .ml .common .transport .MLTaskResponse ;
5160import org .opensearch .ml .common .transport .execute .MLExecuteStreamTaskAction ;
5261import org .opensearch .ml .common .transport .execute .MLExecuteTaskRequest ;
62+ import org .opensearch .ml .model .MLModelManager ;
5363import org .opensearch .ml .repackage .com .google .common .annotations .VisibleForTesting ;
5464import org .opensearch .ml .repackage .com .google .common .collect .ImmutableList ;
5565import org .opensearch .rest .BaseRestHandler ;
@@ -74,11 +84,17 @@ public class RestMLExecuteStreamAction extends BaseRestHandler {
7484 private static final String ML_EXECUTE_STREAM_ACTION = "ml_execute_stream_action" ;
7585 private final MLFeatureEnabledSetting mlFeatureEnabledSetting ;
7686 private ClusterService clusterService ;
87+ private MLModelManager mlModelManager ;
7788
7889 /**
7990 * Constructor
8091 */
81- public RestMLExecuteStreamAction (MLFeatureEnabledSetting mlFeatureEnabledSetting , ClusterService clusterService ) {
92+ public RestMLExecuteStreamAction (
93+ MLModelManager mlModelManager ,
94+ MLFeatureEnabledSetting mlFeatureEnabledSetting ,
95+ ClusterService clusterService
96+ ) {
97+ this .mlModelManager = mlModelManager ;
8298 this .mlFeatureEnabledSetting = mlFeatureEnabledSetting ;
8399 this .clusterService = clusterService ;
84100 }
@@ -122,6 +138,14 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
122138
123139 String agentId = request .param (PARAMETER_AGENT_ID );
124140
141+ // Validate agent and model synchronously before starting stream
142+ MLAgent agent = validateAndGetAgent (agentId , client );
143+ if (agent .getLlm () != null && agent .getLlm ().getModelId () != null ) {
144+ if (!isModelValid (agent .getLlm ().getModelId (), request , client )) {
145+ throw new OpenSearchStatusException ("Failed to find model" , RestStatus .NOT_FOUND );
146+ }
147+ }
148+
125149 final StreamingRestChannelConsumer consumer = (channel ) -> {
126150 Map <String , List <String >> headers = Map
127151 .of (
@@ -217,6 +241,59 @@ public MLTaskResponse read(StreamInput in) throws IOException {
217241 };
218242 }
219243
244+ @ VisibleForTesting
245+ MLAgent validateAndGetAgent (String agentId , NodeClient client ) {
246+ try {
247+ CompletableFuture <MLAgent > future = new CompletableFuture <>();
248+
249+ try (ThreadContext .StoredContext context = client .threadPool ().getThreadContext ().stashContext ()) {
250+ client .get (new GetRequest (ML_AGENT_INDEX , agentId ), ActionListener .runBefore (ActionListener .wrap (response -> {
251+ if (response .isExists ()) {
252+ try {
253+ XContentParser parser = jsonXContent
254+ .createParser (null , LoggingDeprecationHandler .INSTANCE , response .getSourceAsString ());
255+ ensureExpectedToken (XContentParser .Token .START_OBJECT , parser .nextToken (), parser );
256+ future .complete (MLAgent .parse (parser ));
257+ } catch (Exception e ) {
258+ future .completeExceptionally (e );
259+ }
260+ } else {
261+ future .completeExceptionally (new OpenSearchStatusException ("Agent not found" , RestStatus .NOT_FOUND ));
262+ }
263+ }, future ::completeExceptionally ), context ::restore ));
264+ }
265+
266+ // TODO: Make validation async
267+ return future .get (5 , SECONDS );
268+ } catch (Exception e ) {
269+ log .error ("Failed to validate agent {}" , agentId , e );
270+ throw new OpenSearchStatusException ("Failed to find agent with the provided agent id: " + agentId , RestStatus .NOT_FOUND );
271+ }
272+ }
273+
274+ @ VisibleForTesting
275+ boolean isModelValid (String modelId , RestRequest request , NodeClient client ) throws IOException {
276+ try {
277+ CompletableFuture <MLModel > future = new CompletableFuture <>();
278+
279+ try (ThreadContext .StoredContext context = client .threadPool ().getThreadContext ().stashContext ()) {
280+ mlModelManager
281+ .getModel (
282+ modelId ,
283+ getTenantID (mlFeatureEnabledSetting .isMultiTenancyEnabled (), request ),
284+ ActionListener .runBefore (ActionListener .wrap (future ::complete , future ::completeExceptionally ), context ::restore )
285+ );
286+ }
287+
288+ // TODO: make model validation async
289+ future .get (5 , SECONDS );
290+ return true ;
291+ } catch (Exception e ) {
292+ log .error ("Failed to validate model {}" , e .getMessage ());
293+ return false ;
294+ }
295+ }
296+
220297 /**
221298 * Creates a MLExecuteTaskRequest from a RestRequest
222299 *
@@ -248,77 +325,81 @@ MLExecuteTaskRequest getRequest(String agentId, RestRequest request, BytesRefere
248325 }
249326
250327 private HttpChunk convertToHttpChunk (MLTaskResponse response ) throws IOException {
251- String memoryId = "" ;
252- String parentInteractionId = "" ;
253- String content = "" ;
328+ String sseData ;
254329 boolean isLast = false ;
255330
256- // TODO: refactor to handle other types of agents
257- // Extract values from multiple tensors
258331 try {
259- ModelTensorOutput output = (ModelTensorOutput ) response .getOutput ();
260- if (output != null && !output .getMlModelOutputs ().isEmpty ()) {
261- ModelTensors modelTensors = output .getMlModelOutputs ().get (0 );
262- List <ModelTensor > tensors = modelTensors .getMlModelTensors ();
263-
264- for (ModelTensor tensor : tensors ) {
265- String name = tensor .getName ();
266- if ("memory_id" .equals (name ) && tensor .getResult () != null ) {
267- memoryId = tensor .getResult ();
268- } else if ("parent_interaction_id" .equals (name ) && tensor .getResult () != null ) {
269- parentInteractionId = tensor .getResult ();
270- } else if (("llm_response" .equals (name ) || "response" .equals (name )) && tensor .getDataAsMap () != null ) {
271- Map <String , ?> dataMap = tensor .getDataAsMap ();
272- if (dataMap .containsKey ("content" )) {
273- content = (String ) dataMap .get ("content" );
274- if (content == null )
275- content = "" ;
276- }
277- if (dataMap .containsKey ("is_last" )) {
278- isLast = Boolean .TRUE .equals (dataMap .get ("is_last" ));
279- }
280- }
281- }
332+ Map <String , ?> dataMap = extractDataMap (response );
333+
334+ if (dataMap .containsKey ("error" )) {
335+ // Error response
336+ String errorMessage = (String ) dataMap .get ("error" );
337+ sseData = String .format ("data: {\" error\" : \" %s\" }\n \n " , errorMessage .replace ("\" " , "\\ \" " ).replace ("\n " , "\\ n" ));
338+ isLast = true ;
339+ } else {
340+ // TODO: refactor to handle other types of agents
341+ // Regular response - extract values and build proper structure
342+ String memoryId = extractTensorResult (response , "memory_id" );
343+ String parentInteractionId = extractTensorResult (response , "parent_interaction_id" );
344+ String content = dataMap .containsKey ("content" ) ? (String ) dataMap .get ("content" ) : "" ;
345+ isLast = dataMap .containsKey ("is_last" ) ? Boolean .TRUE .equals (dataMap .get ("is_last" )) : false ;
346+ boolean finalIsLast = isLast ;
347+
348+ List <ModelTensor > orderedTensors = List
349+ .of (
350+ ModelTensor .builder ().name ("memory_id" ).result (memoryId ).build (),
351+ ModelTensor .builder ().name ("parent_interaction_id" ).result (parentInteractionId ).build (),
352+ ModelTensor .builder ().name ("response" ).dataAsMap (new LinkedHashMap <String , Object >() {
353+ {
354+ put ("content" , content );
355+ put ("is_last" , finalIsLast );
356+ }
357+ }).build ()
358+ );
359+
360+ ModelTensors tensors = ModelTensors .builder ().mlModelTensors (orderedTensors ).build ();
361+ ModelTensorOutput tensorOutput = ModelTensorOutput .builder ().mlModelOutputs (List .of (tensors )).build ();
362+
363+ XContentBuilder builder = XContentFactory .jsonBuilder ();
364+ tensorOutput .toXContent (builder , ToXContent .EMPTY_PARAMS );
365+ sseData = "data: " + builder .toString () + "\n \n " ;
282366 }
283367 } catch (Exception e ) {
284- log .error ("Failed to extract values from response" , e );
368+ log .error ("Failed to process response" , e );
369+ sseData = "data: {\" error\" : \" Processing failed\" }\n \n " ;
370+ isLast = true ;
285371 }
372+ return createHttpChunk (sseData , isLast );
373+ }
286374
287- String finalContent = content ;
288- boolean finalIsLast = isLast ;
289-
290- log
291- .info (
292- "Converting to HttpChunk - memoryId: '{}', parentId: '{}', content: '{}', isLast: {}" ,
293- memoryId ,
294- parentInteractionId ,
295- content ,
296- isLast
297- );
375+ private String extractTensorResult (MLTaskResponse response , String tensorName ) {
376+ ModelTensorOutput output = (ModelTensorOutput ) response .getOutput ();
377+ if (output != null && !output .getMlModelOutputs ().isEmpty ()) {
378+ ModelTensors tensors = output .getMlModelOutputs ().get (0 );
379+ for (ModelTensor tensor : tensors .getMlModelTensors ()) {
380+ if (tensorName .equals (tensor .getName ()) && tensor .getResult () != null ) {
381+ return tensor .getResult ();
382+ }
383+ }
384+ }
385+ return "" ;
386+ }
298387
299- // Create ordered tensors
300- List <ModelTensor > orderedTensors = List
301- .of (
302- ModelTensor .builder ().name ("memory_id" ).result (memoryId ).build (),
303- ModelTensor .builder ().name ("parent_interaction_id" ).result (parentInteractionId ).build (),
304- ModelTensor .builder ().name ("response" ).dataAsMap (new LinkedHashMap <String , Object >() {
305- {
306- put ("content" , finalContent );
307- put ("is_last" , finalIsLast );
388+ private Map <String , ?> extractDataMap (MLTaskResponse response ) {
389+ ModelTensorOutput output = (ModelTensorOutput ) response .getOutput ();
390+ if (output != null && !output .getMlModelOutputs ().isEmpty ()) {
391+ ModelTensors tensors = output .getMlModelOutputs ().get (0 );
392+ for (ModelTensor tensor : tensors .getMlModelTensors ()) {
393+ String name = tensor .getName ();
394+ if ("error" .equals (name ) || "llm_response" .equals (name ) || "response" .equals (name )) {
395+ Map <String , ?> dataMap = tensor .getDataAsMap ();
396+ if (dataMap != null ) {
397+ return dataMap ;
308398 }
309- }).build ()
310- );
311-
312- ModelTensors tensors = ModelTensors .builder ().mlModelTensors (orderedTensors ).build ();
313-
314- ModelTensorOutput tensorOutput = ModelTensorOutput .builder ().mlModelOutputs (List .of (tensors )).build ();
315-
316- XContentBuilder builder = XContentFactory .jsonBuilder ();
317- tensorOutput .toXContent (builder , ToXContent .EMPTY_PARAMS );
318- String jsonData = builder .toString ();
319-
320- String sseData = "data: " + jsonData + "\n \n " ;
321- return createHttpChunk (sseData , isLast );
399+ }
400+ }
401+ }
402+ return Map .of ();
322403 }
323404
324405 private HttpChunk createHttpChunk (String sseData , boolean isLast ) {
0 commit comments