Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.text.StringEscapeUtils;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.Version;
import org.opensearch.common.io.stream.BytesStreamOutput;
Expand All @@ -36,6 +38,9 @@
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;

import com.google.gson.JsonObject;
import com.google.gson.JsonParser;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
Expand Down Expand Up @@ -351,12 +356,41 @@ public <T> T createPayload(String action, Map<String, String> parameters) {

if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
} else if (neededStreamParameterInPayload(parameters)) {
JsonObject jsonObject = JsonParser.parseString(payload).getAsJsonObject();
jsonObject.addProperty("stream", true);
payload = jsonObject.toString();
}
return (T) payload;
}
return (T) parameters.get("http_body");
}

/**
 * Decides whether a {@code "stream": true} property has to be injected into the request payload.
 *
 * <p>The answer is {@code true} only when the {@code stream} parameter key is present (its value is
 * not inspected) and the {@code _llm_interface} parameter, after trimming, lower-casing and Java
 * unescaping, equals {@code openai/v1/chat/completions}.
 *
 * @param parameters request parameters; may be {@code null}
 * @return {@code true} if the stream flag must be added to the payload, {@code false} otherwise
 */
private boolean neededStreamParameterInPayload(Map<String, String> parameters) {
    if (parameters == null || !parameters.containsKey("stream")) {
        return false;
    }

    // Map#get returns null when "_llm_interface" is absent; check for null before isBlank()
    // so that a streaming request without an interface hint does not throw a NullPointerException.
    String llmInterface = parameters.get("_llm_interface");
    if (llmInterface == null || llmInterface.isBlank()) {
        return false;
    }

    llmInterface = llmInterface.trim().toLowerCase(Locale.ROOT);
    llmInterface = StringEscapeUtils.unescapeJava(llmInterface);
    switch (llmInterface) {
        case "openai/v1/chat/completions":
            return true;
        default:
            return false;
    }
}

protected String fillNullParameters(Map<String, String> parameters, String payload) {
List<String> bodyParams = findStringParametersWithNullDefaultValue(payload);
String newPayload = payload;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,4 +485,8 @@ private MLCommonsSettings() {}
Setting.Property.NodeScope,
Setting.Property.Final
);

/**
 * Feature flag for the streaming feature. Boolean setting keyed by
 * {@code ML_PLUGIN_SETTING_PREFIX + "stream_enabled"}; the default is {@code false} and the setting is
 * declared with the {@code NodeScope} and {@code Dynamic} properties.
 */
public static final Setting<Boolean> ML_COMMONS_STREAM_ENABLED = Setting
.boolSetting(ML_PLUGIN_SETTING_PREFIX + "stream_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED;
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_STREAM_ENABLED;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -64,6 +65,8 @@ public class MLFeatureEnabledSetting {

// Current value of ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED: read from the settings in the constructor
// and reassigned by the settings update consumer registered there.
private volatile Boolean isIndexInsightEnabled;

// Current value of ML_COMMONS_STREAM_ENABLED: read from the settings in the constructor
// and reassigned by the settings update consumer registered there.
private volatile Boolean isStreamEnabled;

private final List<SettingsChangeListener> listeners = new ArrayList<>();

public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
Expand All @@ -84,6 +87,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
isMcpConnectorEnabled = ML_COMMONS_MCP_CONNECTOR_ENABLED.get(settings);
isAgenticMemoryEnabled = ML_COMMONS_AGENTIC_MEMORY_ENABLED.get(settings);
isIndexInsightEnabled = ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED.get(settings);
isStreamEnabled = ML_COMMONS_STREAM_ENABLED.get(settings);

clusterService
.getClusterSettings()
Expand All @@ -110,6 +114,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_SEARCH_ENABLED, it -> isAgenticSearchEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MCP_CONNECTOR_ENABLED, it -> isMcpConnectorEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_MEMORY_ENABLED, it -> isAgenticMemoryEnabled = it);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_STREAM_ENABLED, it -> isStreamEnabled = it);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED, it -> isIndexInsightEnabled = it);
Expand Down Expand Up @@ -243,4 +248,11 @@ public boolean isMcpConnectorEnabled() {
/**
 * Reports the current value of the index insight feature flag held by this object.
 *
 * @return whether the index insight feature is enabled.
 */
public boolean isIndexInsightEnabled() {
    return isIndexInsightEnabled.booleanValue();
}

/**
 * Reports whether the streaming feature is enabled. When it is disabled, APIs in ml-commons will block
 * streaming.
 *
 * @return whether streaming is enabled.
 */
public boolean isStreamEnabled() {
    return isStreamEnabled.booleanValue();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.prediction;

import org.opensearch.action.ActionType;
import org.opensearch.ml.common.transport.MLTaskResponse;

/**
 * Action type for the streaming prediction transport action. Responses are read as
 * {@link MLTaskResponse}.
 */
public class MLPredictionStreamTaskAction extends ActionType<MLTaskResponse> {
    /** Name under which this action is registered. */
    public static final String NAME = "cluster:admin/opensearch/ml/predict/stream";

    /** Shared instance; the constructor is private, so instances are only created inside this class. */
    public static final MLPredictionStreamTaskAction INSTANCE = new MLPredictionStreamTaskAction();

    private MLPredictionStreamTaskAction() {
        super(NAME, in -> new MLTaskResponse(in));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.transport.MLTaskRequest;
import org.opensearch.transport.TransportChannel;

import lombok.AccessLevel;
import lombok.Builder;
Expand All @@ -36,6 +37,10 @@
@ToString
public class MLPredictionTaskRequest extends MLTaskRequest {

// Transport channel carried alongside this request; presumably used to push streamed
// responses back to the caller (inferred from the name — confirm against the transport action).
// Declared transient; the accessor and mutator are generated by Lombok.
@Getter
@Setter
private transient TransportChannel streamingChannel;

String modelId;
MLInput mlInput;
String tenantId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,52 @@ public void createPayload_MissingParamsInvalidJson() {
connector.validatePayload(predictPayload);
}

/**
 * With the {@code stream} parameter present and the OpenAI chat-completions interface selected,
 * the created payload is expected to carry {@code "stream":true}.
 */
@Test
public void createPayload_WithStreamParameter_OpenAI() {
    HttpConnector httpConnector = createHttpConnectorWithRequestBody(
        "{\"model\": \"gpt-3.5-turbo\", \"messages\": [{\"role\": \"user\", \"content\": \"${parameters.input}\"}]}"
    );

    Map<String, String> params = new HashMap<>();
    params.put("input", "Hello world");
    params.put("stream", "true");
    params.put("_llm_interface", "openai/v1/chat/completions");

    String expected = "{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello world\"}],\"stream\":true}";
    String actual = httpConnector.createPayload(PREDICT.name(), params);

    Assert.assertEquals(expected, actual);
}

/**
 * Without a {@code stream} parameter the payload is expected to be the request body with only the
 * placeholder substituted, even when the OpenAI chat-completions interface is selected.
 */
@Test
public void createPayload_WithoutStreamParameter() {
    HttpConnector httpConnector = createHttpConnectorWithRequestBody(
        "{\"model\": \"gpt-3.5-turbo\", \"messages\": [{\"role\": \"user\", \"content\": \"${parameters.input}\"}]}"
    );

    Map<String, String> params = new HashMap<>();
    params.put("input", "Hello world");
    params.put("_llm_interface", "openai/v1/chat/completions");

    String expected = "{\"model\": \"gpt-3.5-turbo\", \"messages\": [{\"role\": \"user\", \"content\": \"Hello world\"}]}";
    String actual = httpConnector.createPayload(PREDICT.name(), params);

    Assert.assertEquals(expected, actual);
}

/**
 * With the {@code stream} parameter present but an interface other than OpenAI chat-completions,
 * the payload is expected to be left without a stream flag.
 */
@Test
public void createPayload_WithStreamParameter_UnsupportedInterface() {
    HttpConnector httpConnector = createHttpConnectorWithRequestBody("{\"input\": \"${parameters.input}\"}");

    Map<String, String> params = new HashMap<>();
    params.put("input", "Hello world");
    params.put("stream", "true");
    params.put("_llm_interface", "invalid/interface");

    String expected = "{\"input\": \"Hello world\"}";
    String actual = httpConnector.createPayload(PREDICT.name(), params);

    Assert.assertEquals(expected, actual);
}

@Test
public void parseResponse_modelTensorJson() throws IOException {
HttpConnector connector = createHttpConnector();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,9 @@ public void testAgenticMemoryDisabledMessage() {
"The Agentic Memory APIs are not enabled. To enable, please update the setting plugins.ml_commons.agentic_memory_enabled";
assertEquals(expectedMessage, MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE);
}

/** The streaming feature flag is expected to default to {@code false}. */
@Test
public void testStreamDisabledByDefault() {
    boolean streamEnabledByDefault = MLCommonsSettings.ML_COMMONS_STREAM_ENABLED.getDefault(null);
    assertFalse(streamEnabledByDefault);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ public void setUp() {
MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED,
MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED,
MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED,
MLCommonsSettings.ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED
MLCommonsSettings.ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_STREAM_ENABLED
)
);
when(mockClusterService.getClusterSettings()).thenReturn(mockClusterSettings);
Expand All @@ -73,6 +74,7 @@ public void testDefaults_allFeaturesEnabled() {
.put("plugins.ml_commons.mcp_connector_enabled", true)
.put("plugins.ml_commons.agentic_search_enabled", true)
.put("plugins.ml_commons.agentic_memory_enabled", true)
.put("plugins.ml_commons.stream_enabled", true)
.build();

MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
Expand All @@ -92,6 +94,7 @@ public void testDefaults_allFeaturesEnabled() {
assertTrue(setting.isMcpConnectorEnabled());
assertTrue(setting.isAgenticSearchEnabled());
assertTrue(setting.isAgenticMemoryEnabled());
assertTrue(setting.isStreamEnabled());
}

@Test
Expand All @@ -113,6 +116,7 @@ public void testDefaults_someFeaturesDisabled() {
.put("plugins.ml_commons.mcp_connector_enabled", false)
.put("plugins.ml_commons.agentic_search_enabled", false)
.put("plugins.ml_commons.agentic_memory_enabled", false)
.put("plugins.ml_commons.stream_enabled", false)
.build();

MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
Expand All @@ -132,6 +136,7 @@ public void testDefaults_someFeaturesDisabled() {
assertFalse(setting.isMcpConnectorEnabled());
assertFalse(setting.isAgenticSearchEnabled());
assertFalse(setting.isAgenticMemoryEnabled());
assertFalse(setting.isStreamEnabled());
}

@Test
Expand Down
5 changes: 5 additions & 0 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ dependencies {
implementation platform('software.amazon.awssdk:bom:2.30.18')
api 'software.amazon.awssdk:auth:2.30.18'
implementation 'software.amazon.awssdk:apache-client'
implementation ('software.amazon.awssdk:bedrockruntime') {
exclude group: 'io.netty'
}
implementation ('com.amazonaws:aws-encryption-sdk-java:2.4.1') {
exclude group: 'org.bouncycastle', module: 'bcprov-ext-jdk18on'
}
Expand All @@ -90,6 +93,8 @@ dependencies {
testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}")
testImplementation group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0'
api group: 'com.squareup.okhttp3', name: 'okhttp', version: '4.12.0'
implementation group: 'com.squareup.okhttp3', name: 'okhttp-sse', version: '4.12.0'
}

lombok {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.transport.TransportChannel;

/**
* This is machine learning algorithms predict interface.
Expand Down Expand Up @@ -41,6 +42,10 @@ default MLOutput predict(MLInput mlInput) {
}

/**
 * Asynchronous prediction without a transport channel. Forwards to the three-argument overload,
 * passing {@code null} as the channel.
 *
 * @param mlInput prediction input
 * @param actionListener listener that receives the outcome
 */
default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
    this.asyncPredict(mlInput, actionListener, (TransportChannel) null);
}

/**
 * Asynchronous prediction with an optional transport channel. This default implementation performs no
 * prediction: it reports an {@link IllegalStateException} carrying
 * {@code METHOD_NOT_IMPLEMENTED_ERROR_MSG} to the listener.
 *
 * @param mlInput prediction input (unused by this default)
 * @param actionListener listener notified of the failure
 * @param channel transport channel (unused by this default)
 */
default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener, TransportChannel channel) {
    final IllegalStateException notImplemented = new IllegalStateException(METHOD_NOT_IMPLEMENTED_ERROR_MSG);
    actionListener.onFailure(notImplemented);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,17 @@

package org.opensearch.ml.engine.algorithms.remote;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;

import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorClientConfig;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.MLTaskResponse;

import lombok.Getter;
import lombok.Setter;
Expand All @@ -23,4 +32,23 @@ public void initialize(Connector connector) {
connectorClientConfig = new ConnectorClientConfig();
}
}

/**
 * Wraps one chunk of streamed content in an {@link MLTaskResponse} and hands it to the listener.
 *
 * <p>The response holds a single {@link ModelTensor} named {@code response} whose data map has the
 * keys {@code content} and {@code is_last}; it is delivered through
 * {@code actionListener.onStreamResponse(response, isLast)}.
 *
 * @param content text of this chunk; {@code null} is sent as an empty string
 * @param isLast whether this is the final chunk of the stream
 * @param actionListener stream listener that receives the response
 */
public void sendContentResponse(String content, boolean isLast, StreamPredictActionListener<MLTaskResponse, ?> actionListener) {
    // Map.of rejects null values with a NullPointerException, so normalise a missing chunk
    // to the empty string (the same content the completion response sends).
    String safeContent = content == null ? "" : content;

    List<ModelTensor> modelTensors = new ArrayList<>();
    Map<String, Object> dataMap = Map.of("content", safeContent, "is_last", isLast);

    modelTensors.add(ModelTensor.builder().name("response").dataAsMap(dataMap).build());
    ModelTensorOutput output = ModelTensorOutput
        .builder()
        .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(modelTensors).build()))
        .build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();
    actionListener.onStreamResponse(response, isLast);
}

/**
 * Sends the final, empty chunk of a stream at most once per {@code isStreamClosed} flag.
 *
 * <p>The flag is flipped from {@code false} to {@code true} atomically; only the caller that wins
 * that transition sends the closing response, every other call returns without sending.
 *
 * @param isStreamClosed flag recording whether the closing response has already been sent
 * @param actionListener stream listener that receives the closing response
 */
public void sendCompletionResponse(AtomicBoolean isStreamClosed, StreamPredictActionListener<MLTaskResponse, ?> actionListener) {
    boolean closedByThisCall = isStreamClosed.compareAndSet(false, true);
    if (!closedByThisCall) {
        return;
    }
    sendContentResponse("", true, actionListener);
}
}
Loading
Loading