Skip to content

Commit 4864f66

Browse files
[FEATURE] Predict Stream (opensearch-project#4187)
* Initial commit for predict stream Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com> * Address comments, add some UTs Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com> * Address comments, add more UTs Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com> * Fix failing test Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com> * Fix failing tests Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com> * Fix tests Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com> * Increase test coverage Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com> --------- Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>
1 parent f994fe8 commit 4864f66

File tree

45 files changed

+2753
-61
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+2753
-61
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919
import java.util.ArrayList;
2020
import java.util.HashMap;
2121
import java.util.List;
22+
import java.util.Locale;
2223
import java.util.Map;
2324
import java.util.Optional;
2425
import java.util.function.BiFunction;
2526
import java.util.regex.Matcher;
2627
import java.util.regex.Pattern;
2728

29+
import org.apache.commons.text.StringEscapeUtils;
2830
import org.apache.commons.text.StringSubstitutor;
2931
import org.opensearch.Version;
3032
import org.opensearch.common.io.stream.BytesStreamOutput;
@@ -36,6 +38,9 @@
3638
import org.opensearch.ml.common.AccessMode;
3739
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
3840

41+
import com.google.gson.JsonObject;
42+
import com.google.gson.JsonParser;
43+
3944
import lombok.Builder;
4045
import lombok.EqualsAndHashCode;
4146
import lombok.NoArgsConstructor;
@@ -351,12 +356,41 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
351356

352357
if (!isJson(payload)) {
353358
throw new IllegalArgumentException("Invalid payload: " + payload);
359+
} else if (neededStreamParameterInPayload(parameters)) {
360+
JsonObject jsonObject = JsonParser.parseString(payload).getAsJsonObject();
361+
jsonObject.addProperty("stream", true);
362+
payload = jsonObject.toString();
354363
}
355364
return (T) payload;
356365
}
357366
return (T) parameters.get("http_body");
358367
}
359368

369+
private boolean neededStreamParameterInPayload(Map<String, String> parameters) {
370+
if (parameters == null) {
371+
return false;
372+
}
373+
374+
boolean isStream = parameters.containsKey("stream");
375+
if (!isStream) {
376+
return false;
377+
}
378+
379+
String llmInterface = parameters.get("_llm_interface");
380+
if (llmInterface.isBlank()) {
381+
return false;
382+
}
383+
384+
llmInterface = llmInterface.trim().toLowerCase(Locale.ROOT);
385+
llmInterface = StringEscapeUtils.unescapeJava(llmInterface);
386+
switch (llmInterface) {
387+
case "openai/v1/chat/completions":
388+
return true;
389+
default:
390+
return false;
391+
}
392+
}
393+
360394
protected String fillNullParameters(Map<String, String> parameters, String payload) {
361395
List<String> bodyParams = findStringParametersWithNullDefaultValue(payload);
362396
String newPayload = payload;

common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,4 +485,8 @@ private MLCommonsSettings() {}
485485
Setting.Property.NodeScope,
486486
Setting.Property.Final
487487
);
488+
489+
// Feature flag for streaming feature
490+
public static final Setting<Boolean> ML_COMMONS_STREAM_ENABLED = Setting
491+
.boolSetting(ML_PLUGIN_SETTING_PREFIX + "stream_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
488492
}

common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED;
2323
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;
2424
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED;
25+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STREAM_ENABLED;
2526

2627
import java.util.ArrayList;
2728
import java.util.List;
@@ -64,6 +65,8 @@ public class MLFeatureEnabledSetting {
6465

6566
private volatile Boolean isIndexInsightEnabled;
6667

68+
private volatile Boolean isStreamEnabled;
69+
6770
private final List<SettingsChangeListener> listeners = new ArrayList<>();
6871

6972
public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
@@ -84,6 +87,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
8487
isMcpConnectorEnabled = ML_COMMONS_MCP_CONNECTOR_ENABLED.get(settings);
8588
isAgenticMemoryEnabled = ML_COMMONS_AGENTIC_MEMORY_ENABLED.get(settings);
8689
isIndexInsightEnabled = ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED.get(settings);
90+
isStreamEnabled = ML_COMMONS_STREAM_ENABLED.get(settings);
8791

8892
clusterService
8993
.getClusterSettings()
@@ -110,6 +114,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
110114
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_SEARCH_ENABLED, it -> isAgenticSearchEnabled = it);
111115
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MCP_CONNECTOR_ENABLED, it -> isMcpConnectorEnabled = it);
112116
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_MEMORY_ENABLED, it -> isAgenticMemoryEnabled = it);
117+
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_STREAM_ENABLED, it -> isStreamEnabled = it);
113118
clusterService
114119
.getClusterSettings()
115120
.addSettingsUpdateConsumer(ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED, it -> isIndexInsightEnabled = it);
@@ -243,4 +248,11 @@ public boolean isMcpConnectorEnabled() {
243248
public boolean isIndexInsightEnabled() {
244249
return isIndexInsightEnabled;
245250
}
251+
252+
/** Whether the streaming feature is enabled. If disabled, APIs in ml-commons will block stream.
253+
* @return whether the streaming is enabled.
254+
*/
255+
public boolean isStreamEnabled() {
256+
return isStreamEnabled;
257+
}
246258
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.prediction;
7+
8+
import org.opensearch.action.ActionType;
9+
import org.opensearch.ml.common.transport.MLTaskResponse;
10+
11+
public class MLPredictionStreamTaskAction extends ActionType<MLTaskResponse> {
12+
public static final MLPredictionStreamTaskAction INSTANCE = new MLPredictionStreamTaskAction();
13+
public static final String NAME = "cluster:admin/opensearch/ml/predict/stream";
14+
15+
private MLPredictionStreamTaskAction() {
16+
super(NAME, MLTaskResponse::new);
17+
}
18+
}

common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.opensearch.core.common.io.stream.StreamOutput;
2424
import org.opensearch.ml.common.input.MLInput;
2525
import org.opensearch.ml.common.transport.MLTaskRequest;
26+
import org.opensearch.transport.TransportChannel;
2627

2728
import lombok.AccessLevel;
2829
import lombok.Builder;
@@ -36,6 +37,10 @@
3637
@ToString
3738
public class MLPredictionTaskRequest extends MLTaskRequest {
3839

40+
@Getter
41+
@Setter
42+
private transient TransportChannel streamingChannel;
43+
3944
String modelId;
4045
MLInput mlInput;
4146
String tenantId;

common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,52 @@ public void createPayload_MissingParamsInvalidJson() {
228228
connector.validatePayload(predictPayload);
229229
}
230230

231+
@Test
232+
public void createPayload_WithStreamParameter_OpenAI() {
233+
String requestBody = "{\"model\": \"gpt-3.5-turbo\", \"messages\": [{\"role\": \"user\", \"content\": \"${parameters.input}\"}]}";
234+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
235+
236+
Map<String, String> parameters = new HashMap<>();
237+
parameters.put("input", "Hello world");
238+
parameters.put("stream", "true");
239+
parameters.put("_llm_interface", "openai/v1/chat/completions");
240+
241+
String payload = connector.createPayload(PREDICT.name(), parameters);
242+
Assert
243+
.assertEquals(
244+
"{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello world\"}],\"stream\":true}",
245+
payload
246+
);
247+
}
248+
249+
@Test
250+
public void createPayload_WithoutStreamParameter() {
251+
String requestBody = "{\"model\": \"gpt-3.5-turbo\", \"messages\": [{\"role\": \"user\", \"content\": \"${parameters.input}\"}]}";
252+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
253+
254+
Map<String, String> parameters = new HashMap<>();
255+
parameters.put("input", "Hello world");
256+
parameters.put("_llm_interface", "openai/v1/chat/completions");
257+
258+
String payload = connector.createPayload(PREDICT.name(), parameters);
259+
Assert.assertEquals("{\"model\": \"gpt-3.5-turbo\", \"messages\": [{\"role\": \"user\", \"content\": \"Hello world\"}]}", payload);
260+
}
261+
262+
@Test
263+
public void createPayload_WithStreamParameter_UnsupportedInterface() {
264+
String requestBody = "{\"input\": \"${parameters.input}\"}";
265+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
266+
267+
Map<String, String> parameters = new HashMap<>();
268+
parameters.put("input", "Hello world");
269+
parameters.put("stream", "true");
270+
parameters.put("_llm_interface", "invalid/interface");
271+
272+
String payload = connector.createPayload(PREDICT.name(), parameters);
273+
274+
Assert.assertEquals("{\"input\": \"Hello world\"}", payload);
275+
}
276+
231277
@Test
232278
public void parseResponse_modelTensorJson() throws IOException {
233279
HttpConnector connector = createHttpConnector();

common/src/test/java/org/opensearch/ml/common/settings/MLCommonsSettingsTests.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,9 @@ public void testAgenticMemoryDisabledMessage() {
102102
"The Agentic Memory APIs are not enabled. To enable, please update the setting plugins.ml_commons.agentic_memory_enabled";
103103
assertEquals(expectedMessage, MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE);
104104
}
105+
106+
@Test
107+
public void testStreamDisabledByDefault() {
108+
assertFalse(MLCommonsSettings.ML_COMMONS_STREAM_ENABLED.getDefault(null));
109+
}
105110
}

common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ public void setUp() {
4848
MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED,
4949
MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED,
5050
MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED,
51-
MLCommonsSettings.ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED
51+
MLCommonsSettings.ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED,
52+
MLCommonsSettings.ML_COMMONS_STREAM_ENABLED
5253
)
5354
);
5455
when(mockClusterService.getClusterSettings()).thenReturn(mockClusterSettings);
@@ -73,6 +74,7 @@ public void testDefaults_allFeaturesEnabled() {
7374
.put("plugins.ml_commons.mcp_connector_enabled", true)
7475
.put("plugins.ml_commons.agentic_search_enabled", true)
7576
.put("plugins.ml_commons.agentic_memory_enabled", true)
77+
.put("plugins.ml_commons.stream_enabled", true)
7678
.build();
7779

7880
MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
@@ -92,6 +94,7 @@ public void testDefaults_allFeaturesEnabled() {
9294
assertTrue(setting.isMcpConnectorEnabled());
9395
assertTrue(setting.isAgenticSearchEnabled());
9496
assertTrue(setting.isAgenticMemoryEnabled());
97+
assertTrue(setting.isStreamEnabled());
9598
}
9699

97100
@Test
@@ -113,6 +116,7 @@ public void testDefaults_someFeaturesDisabled() {
113116
.put("plugins.ml_commons.mcp_connector_enabled", false)
114117
.put("plugins.ml_commons.agentic_search_enabled", false)
115118
.put("plugins.ml_commons.agentic_memory_enabled", false)
119+
.put("plugins.ml_commons.stream_enabled", false)
116120
.build();
117121

118122
MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
@@ -132,6 +136,7 @@ public void testDefaults_someFeaturesDisabled() {
132136
assertFalse(setting.isMcpConnectorEnabled());
133137
assertFalse(setting.isAgenticSearchEnabled());
134138
assertFalse(setting.isAgenticMemoryEnabled());
139+
assertFalse(setting.isStreamEnabled());
135140
}
136141

137142
@Test

ml-algorithms/build.gradle

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ dependencies {
7171
implementation platform('software.amazon.awssdk:bom:2.30.18')
7272
api 'software.amazon.awssdk:auth:2.30.18'
7373
implementation 'software.amazon.awssdk:apache-client'
74+
implementation ('software.amazon.awssdk:bedrockruntime') {
75+
exclude group: 'io.netty'
76+
}
7477
implementation ('com.amazonaws:aws-encryption-sdk-java:2.4.1') {
7578
exclude group: 'org.bouncycastle', module: 'bcprov-ext-jdk18on'
7679
}
@@ -90,6 +93,8 @@ dependencies {
9093
testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
9194
testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}")
9295
testImplementation group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0'
96+
api group: 'com.squareup.okhttp3', name: 'okhttp', version: '4.12.0'
97+
implementation group: 'com.squareup.okhttp3', name: 'okhttp-sse', version: '4.12.0'
9398
}
9499

95100
lombok {

ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.opensearch.ml.common.output.MLOutput;
1515
import org.opensearch.ml.common.transport.MLTaskResponse;
1616
import org.opensearch.ml.engine.encryptor.Encryptor;
17+
import org.opensearch.transport.TransportChannel;
1718

1819
/**
1920
* This is machine learning algorithms predict interface.
@@ -41,6 +42,10 @@ default MLOutput predict(MLInput mlInput) {
4142
}
4243

4344
default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
45+
asyncPredict(mlInput, actionListener, null);
46+
}
47+
48+
default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener, TransportChannel channel) {
4449
actionListener.onFailure(new IllegalStateException(METHOD_NOT_IMPLEMENTED_ERROR_MSG));
4550
}
4651

0 commit comments

Comments
 (0)