Skip to content

Commit 2ae1e44

Browse files
Initial commit for agent streaming
Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>
1 parent 5a7f4cf commit 2ae1e44

27 files changed

+1208
-82
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.execute;
7+
8+
import org.opensearch.action.ActionType;
9+
10+
public class MLExecuteStreamTaskAction extends ActionType<MLExecuteTaskResponse> {
11+
public static final MLExecuteStreamTaskAction INSTANCE = new MLExecuteStreamTaskAction();
12+
public static final String NAME = "cluster:admin/opensearch/ml/execute/stream";
13+
14+
private MLExecuteStreamTaskAction() {
15+
super(NAME, MLExecuteTaskResponse::new);
16+
}
17+
}

common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,27 @@
2222
import org.opensearch.ml.common.MLCommonsClassLoader;
2323
import org.opensearch.ml.common.input.Input;
2424
import org.opensearch.ml.common.transport.MLTaskRequest;
25+
import org.opensearch.transport.TransportChannel;
2526

2627
import lombok.AccessLevel;
2728
import lombok.Builder;
2829
import lombok.Getter;
2930
import lombok.NonNull;
31+
import lombok.Setter;
3032
import lombok.ToString;
3133
import lombok.experimental.FieldDefaults;
34+
import lombok.experimental.NonFinal;
3235

3336
@Getter
3437
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
3538
@ToString
3639
public class MLExecuteTaskRequest extends MLTaskRequest {
3740

41+
@Getter
42+
@Setter
43+
@NonFinal
44+
private transient TransportChannel streamingChannel;
45+
3846
FunctionName functionName;
3947
Input input;
4048

ml-algorithms/src/main/java/org/opensearch/ml/engine/Executable.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@
99
import org.opensearch.ml.common.exception.ExecuteException;
1010
import org.opensearch.ml.common.input.Input;
1111
import org.opensearch.ml.common.output.Output;
12+
import org.opensearch.transport.TransportChannel;
1213

1314
public interface Executable {
1415

1516
/**
1617
* Execute algorithm with given input data.
1718
* @param input input data
19+
* @param listener action listener
20+
* @param channel transport channel
1821
*/
19-
void execute(Input input, ActionListener<Output> listener) throws ExecuteException;
22+
void execute(Input input, ActionListener<Output> listener, TransportChannel channel) throws ExecuteException;
2023
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.opensearch.ml.common.output.MLOutput;
2727
import org.opensearch.ml.common.output.Output;
2828
import org.opensearch.ml.engine.encryptor.Encryptor;
29+
import org.opensearch.transport.TransportChannel;
2930

3031
import lombok.Getter;
3132
import lombok.extern.log4j.Log4j2;
@@ -186,20 +187,20 @@ public MLOutput trainAndPredict(Input input) {
186187
return trainAndPredictable.trainAndPredict(mlInput);
187188
}
188189

189-
public void execute(Input input, ActionListener<Output> listener) throws Exception {
190+
public void execute(Input input, ActionListener<Output> listener, TransportChannel channel) throws Exception {
190191
validateInput(input);
191192
if (input.getFunctionName() == FunctionName.METRICS_CORRELATION) {
192193
MLExecutable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
193194
if (executable == null) {
194195
throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName());
195196
}
196-
executable.execute(input, listener);
197+
executable.execute(input, listener, channel);
197198
} else {
198199
Executable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
199200
if (executable == null) {
200201
throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName());
201202
}
202-
executable.execute(input, listener);
203+
executable.execute(input, listener, channel);
203204
}
204205
}
205206

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.opensearch.ml.engine.MLExecutable;
2929
import org.opensearch.ml.engine.ModelHelper;
3030
import org.opensearch.ml.engine.utils.ZipUtils;
31+
import org.opensearch.transport.TransportChannel;
3132

3233
import ai.djl.Application;
3334
import ai.djl.Device;
@@ -52,7 +53,7 @@ public abstract class DLModelExecute implements MLExecutable {
5253
protected Device[] devices;
5354
protected AtomicInteger nextDevice = new AtomicInteger(0);
5455

55-
public abstract void execute(Input input, ActionListener<Output> listener);
56+
public abstract void execute(Input input, ActionListener<Output> listener, TransportChannel channel);
5657

5758
protected Predictor<float[][], ai.djl.modality.Output> getPredictor() {
5859
int currentDevice = nextDevice.getAndIncrement();

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
import org.opensearch.remote.metadata.client.SdkClient;
8181
import org.opensearch.remote.metadata.common.SdkClientUtils;
8282
import org.opensearch.search.fetch.subphase.FetchSourceContext;
83+
import org.opensearch.transport.TransportChannel;
8384
import org.opensearch.transport.client.Client;
8485

8586
import com.google.common.annotations.VisibleForTesting;
@@ -143,7 +144,7 @@ public void onMultiTenancyEnabledChanged(boolean isEnabled) {
143144
}
144145

145146
@Override
146-
public void execute(Input input, ActionListener<Output> listener) {
147+
public void execute(Input input, ActionListener<Output> listener, TransportChannel channel) {
147148
if (!(input instanceof AgentMLInput)) {
148149
throw new IllegalArgumentException("wrong input");
149150
}
@@ -271,7 +272,8 @@ public void execute(Input input, ActionListener<Output> listener) {
271272
isAsync,
272273
outputs,
273274
modelTensors,
274-
mlAgent
275+
mlAgent,
276+
channel
275277
);
276278
}, e -> {
277279
log.error("Failed to get existing interaction for regeneration", e);
@@ -287,7 +289,8 @@ public void execute(Input input, ActionListener<Output> listener) {
287289
isAsync,
288290
outputs,
289291
modelTensors,
290-
mlAgent
292+
mlAgent,
293+
channel
291294
);
292295
}
293296
}, ex -> {
@@ -318,7 +321,8 @@ public void execute(Input input, ActionListener<Output> listener) {
318321
outputs,
319322
modelTensors,
320323
listener,
321-
createdMemory
324+
createdMemory,
325+
channel
322326
),
323327
ex -> {
324328
log.error("Failed to find memory with memory_id: {}", memoryId, ex);
@@ -329,7 +333,6 @@ public void execute(Input input, ActionListener<Output> listener) {
329333
return;
330334
}
331335
}
332-
333336
executeAgent(
334337
inputDataSet,
335338
mlTask,
@@ -339,7 +342,8 @@ public void execute(Input input, ActionListener<Output> listener) {
339342
outputs,
340343
modelTensors,
341344
listener,
342-
null
345+
null,
346+
channel
343347
);
344348
}
345349
} catch (Exception e) {
@@ -382,7 +386,8 @@ private void saveRootInteractionAndExecute(
382386
boolean isAsync,
383387
List<ModelTensors> outputs,
384388
List<ModelTensor> modelTensors,
385-
MLAgent mlAgent
389+
MLAgent mlAgent,
390+
TransportChannel channel
386391
) {
387392
String appType = mlAgent.getAppType();
388393
String question = inputDataSet.getParameters().get(QUESTION);
@@ -416,7 +421,8 @@ private void saveRootInteractionAndExecute(
416421
outputs,
417422
modelTensors,
418423
listener,
419-
memory
424+
memory,
425+
channel
420426
),
421427
e -> {
422428
log.error("Failed to regenerate for interaction {}", regenerateInteractionId, e);
@@ -425,7 +431,18 @@ private void saveRootInteractionAndExecute(
425431
)
426432
);
427433
} else {
428-
executeAgent(inputDataSet, mlTask, isAsync, memory.getConversationId(), mlAgent, outputs, modelTensors, listener, memory);
434+
executeAgent(
435+
inputDataSet,
436+
mlTask,
437+
isAsync,
438+
memory.getConversationId(),
439+
mlAgent,
440+
outputs,
441+
modelTensors,
442+
listener,
443+
memory,
444+
channel
445+
);
429446
}
430447
}, ex -> {
431448
log.error("Failed to create parent interaction", ex);
@@ -442,7 +459,8 @@ private void executeAgent(
442459
List<ModelTensors> outputs,
443460
List<ModelTensor> modelTensors,
444461
ActionListener<Output> listener,
445-
ConversationIndexMemory memory
462+
ConversationIndexMemory memory,
463+
TransportChannel channel
446464
) {
447465
String mcpConnectorConfigJSON = (mlAgent.getParameters() != null) ? mlAgent.getParameters().get(MCP_CONNECTORS_FIELD) : null;
448466
if (mcpConnectorConfigJSON != null && !mlFeatureEnabledSetting.isMcpConnectorEnabled()) {
@@ -494,7 +512,7 @@ private void executeAgent(
494512
memory
495513
);
496514
inputDataSet.getParameters().put(TASK_ID_FIELD, taskId);
497-
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
515+
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener, channel);
498516
}, e -> {
499517
log.error("Failed to create task for agent async execution", e);
500518
listener.onFailure(e);
@@ -508,7 +526,7 @@ private void executeAgent(
508526
parentInteractionId,
509527
memory
510528
);
511-
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
529+
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener, channel);
512530
}
513531
}
514532

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentRunner.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.opensearch.core.action.ActionListener;
1111
import org.opensearch.ml.common.agent.MLAgent;
12+
import org.opensearch.transport.TransportChannel;
1213

1314
/**
1415
* Agent executor interface definition. Agent executor will be used by {@link MLAgentExecutor} to invoke agents.
@@ -20,6 +21,7 @@ public interface MLAgentRunner {
2021
* @param mlAgent
2122
* @param params
2223
* @param listener
24+
* @param channel
2325
*/
24-
void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener);
26+
void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener, TransportChannel channel);
2527
}

0 commit comments

Comments
 (0)