Skip to content

Commit d077f31

Browse files
Sq-eng and Xia authored
Parameter Passing for Predict via Remote Connector (#4121)
* poc Signed-off-by: Shiqi Xia <sqxia@amazon.com> * version that run correctly Signed-off-by: Shiqi Xia <sqxia@amazon.com> * passed all UT Signed-off-by: Shiqi Xia <sqxia@amazon.com> * passed all UT Signed-off-by: Shiqi Xia <sqxia@amazon.com> * Remove the removeMissParameterFields Signed-off-by: Shiqi Xia <sqxia@amazon.com> --------- Signed-off-by: Shiqi Xia <sqxia@amazon.com> Co-authored-by: Xia <sqxia@7cf34de0a95f.ant.amazon.com>
1 parent 8019998 commit d077f31

File tree

4 files changed

+238
-1
lines changed

4 files changed

+238
-1
lines changed

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public class StringUtils {
7878
}
7979
public static final String TO_STRING_FUNCTION_NAME = ".toString()";
8080

81-
private static final ObjectMapper MAPPER = new ObjectMapper();
81+
public static final ObjectMapper MAPPER = new ObjectMapper();
8282

8383
public static boolean isValidJsonString(String json) {
8484
if (json == null || json.isBlank()) {

common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,41 @@ public void createPayload() {
193193
Assert.assertEquals("{\"input\": \"test input value\"}", predictPayload);
194194
}
195195

196+
@Test
197+
public void createPayload_ExtraParams() {
198+
199+
String requestBody =
200+
"{\"input\": \"${parameters.input}\", \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": \"${parameters.content_type}\" }}";
201+
String expected =
202+
"{\"input\": \"test value\", \"parameters\": {\"sparseEmbeddingFormat\": \"WORD\", \"content_type\": \"query\" }}";
203+
204+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
205+
Map<String, String> parameters = new HashMap<>();
206+
parameters.put("input", "test value");
207+
parameters.put("sparseEmbeddingFormat", "WORD");
208+
parameters.put("content_type", "query");
209+
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
210+
connector.validatePayload(predictPayload);
211+
Assert.assertEquals(expected, predictPayload);
212+
}
213+
214+
@Test
215+
public void createPayload_MissingParamsInvalidJson() {
216+
exceptionRule.expect(IllegalArgumentException.class);
217+
exceptionRule
218+
.expectMessage(
219+
"Invalid payload: {\"input\": \"test value\", \"parameters\": {\"sparseEmbeddingFormat\": \"WORD\", \"content_type\": ${parameters.content_type} }}"
220+
);
221+
String requestBody =
222+
"{\"input\": \"${parameters.input}\", \"parameters\": {\"sparseEmbeddingFormat\": \"${parameters.sparseEmbeddingFormat}\", \"content_type\": ${parameters.content_type} }}";
223+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
224+
Map<String, String> parameters = new HashMap<>();
225+
parameters.put("input", "test value");
226+
parameters.put("sparseEmbeddingFormat", "WORD");
227+
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
228+
connector.validatePayload(predictPayload);
229+
}
230+
196231
@Test
197232
public void parseResponse_modelTensorJson() throws IOException {
198233
HttpConnector connector = createHttpConnector();

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55

66
package org.opensearch.ml.engine.algorithms.remote;
77

8+
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
89
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS;
910
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
1011
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;
1112

13+
import java.io.IOException;
1214
import java.util.Arrays;
1315
import java.util.Collection;
1416
import java.util.HashMap;
@@ -28,11 +30,14 @@
2830
import org.opensearch.common.collect.Tuple;
2931
import org.opensearch.common.unit.TimeValue;
3032
import org.opensearch.common.util.TokenBucket;
33+
import org.opensearch.common.xcontent.XContentFactory;
3134
import org.opensearch.commons.ConfigConstants;
3235
import org.opensearch.commons.authuser.User;
3336
import org.opensearch.core.action.ActionListener;
3437
import org.opensearch.core.rest.RestStatus;
3538
import org.opensearch.core.xcontent.NamedXContentRegistry;
39+
import org.opensearch.core.xcontent.ToXContent;
40+
import org.opensearch.core.xcontent.XContentBuilder;
3641
import org.opensearch.ml.common.FunctionName;
3742
import org.opensearch.ml.common.connector.Connector;
3843
import org.opensearch.ml.common.connector.ConnectorAction;
@@ -42,10 +47,12 @@
4247
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
4348
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
4449
import org.opensearch.ml.common.input.MLInput;
50+
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
4551
import org.opensearch.ml.common.model.MLGuard;
4652
import org.opensearch.ml.common.output.model.ModelTensorOutput;
4753
import org.opensearch.ml.common.output.model.ModelTensors;
4854
import org.opensearch.ml.common.transport.MLTaskResponse;
55+
import org.opensearch.ml.common.utils.StringUtils;
4956
import org.opensearch.script.ScriptService;
5057
import org.opensearch.threadpool.ThreadPool;
5158
import org.opensearch.transport.client.Client;
@@ -83,6 +90,7 @@ default void executeAction(String action, MLInput mlInput, ActionListener<MLTask
8390
MLInput
8491
.builder()
8592
.algorithm(FunctionName.TEXT_EMBEDDING)
93+
.parameters(mlInput.getParameters())
8694
.inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
8795
.build(),
8896
new ExecutionContext(sequence++),
@@ -187,6 +195,18 @@ default void preparePayloadAndInvoke(
187195
inputParameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters());
188196
}
189197
parameters.putAll(inputParameters);
198+
199+
MLAlgoParams algoParams = mlInput.getParameters();
200+
if (algoParams != null) {
201+
try {
202+
Map<String, String> parametersMap = getParams(mlInput);
203+
parameters.putAll(parametersMap);
204+
} catch (IOException e) {
205+
actionListener.onFailure(e);
206+
return;
207+
}
208+
}
209+
190210
RemoteInferenceInputDataSet inputData = processInput(action, mlInput, connector, parameters, getScriptService());
191211
if (inputData.getParameters() != null) {
192212
parameters.putAll(inputData.getParameters());
@@ -227,6 +247,15 @@ && getUserRateLimiterMap().get(user.getName()) != null
227247
}
228248
}
229249

250+
static Map<String, String> getParams(MLInput mlInput) throws IOException {
251+
XContentBuilder builder = XContentFactory.jsonBuilder();
252+
mlInput.getParameters().toXContent(builder, ToXContent.EMPTY_PARAMS);
253+
builder.flush();
254+
String json = builder.toString();
255+
Map<String, Object> tempMap = StringUtils.MAPPER.readValue(json, Map.class);
256+
return getParameterMap(tempMap);
257+
}
258+
230259
default BackoffPolicy getRetryBackoffPolicy(ConnectorClientConfig connectorClientConfig) {
231260
switch (connectorClientConfig.getRetryBackoffPolicy()) {
232261
case EXPONENTIAL_EQUAL_JITTER:

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77

88
import static org.mockito.ArgumentMatchers.any;
99
import static org.mockito.Mockito.argThat;
10+
import static org.mockito.Mockito.doThrow;
1011
import static org.mockito.Mockito.spy;
1112
import static org.mockito.Mockito.times;
13+
import static org.mockito.Mockito.verify;
1214
import static org.mockito.Mockito.when;
1315
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
1416
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
@@ -17,7 +19,9 @@
1719
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;
1820
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS;
1921

22+
import java.io.IOException;
2023
import java.util.Arrays;
24+
import java.util.HashMap;
2125
import java.util.Map;
2226

2327
import org.junit.Assert;
@@ -30,6 +34,7 @@
3034
import org.opensearch.common.settings.Settings;
3135
import org.opensearch.common.util.concurrent.ThreadContext;
3236
import org.opensearch.core.action.ActionListener;
37+
import org.opensearch.core.xcontent.XContentBuilder;
3338
import org.opensearch.ingest.TestTemplateService;
3439
import org.opensearch.ml.common.FunctionName;
3540
import org.opensearch.ml.common.connector.AwsConnector;
@@ -39,6 +44,10 @@
3944
import org.opensearch.ml.common.connector.RetryBackoffPolicy;
4045
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
4146
import org.opensearch.ml.common.input.MLInput;
47+
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
48+
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
49+
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
50+
import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat;
4251
import org.opensearch.ml.common.output.model.ModelTensors;
4352
import org.opensearch.ml.engine.encryptor.Encryptor;
4453
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
@@ -64,6 +73,9 @@ public class RemoteConnectorExecutorTest {
6473
@Mock
6574
ActionListener<Tuple<Integer, ModelTensors>> actionListener;
6675

76+
@Mock
77+
private MLAlgoParams mlInputParams;
78+
6779
@Before
6880
public void setUp() {
6981
MockitoAnnotations.openMocks(this);
@@ -169,4 +181,165 @@ public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDefault()
169181
);
170182
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role");
171183
}
184+
185+
@Test
186+
public void executePreparePayloadAndInvoke_PassingParameter() {
187+
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
188+
Connector connector = getConnector(parameters);
189+
AwsConnectorExecutor executor = getExecutor(connector);
190+
191+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
192+
.builder()
193+
.parameters(Map.of("input", "You are a ${parameters.role}"))
194+
.actionType(PREDICT)
195+
.build();
196+
String actionType = inputDataSet.getActionType().toString();
197+
AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters
198+
.builder()
199+
.sparseEmbeddingFormat(SparseEmbeddingFormat.WORD)
200+
.embeddingContentType(null)
201+
.build();
202+
MLInput mlInput = MLInput
203+
.builder()
204+
.algorithm(FunctionName.TEXT_EMBEDDING)
205+
.parameters(inputParams)
206+
.inputDataset(inputDataSet)
207+
.build();
208+
209+
Exception exception = Assert
210+
.assertThrows(
211+
IllegalArgumentException.class,
212+
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener)
213+
);
214+
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role");
215+
}
216+
217+
@Test
218+
public void executePreparePayloadAndInvoke_GetParamsIOException() throws Exception {
219+
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
220+
Connector connector = getConnector(parameters);
221+
AwsConnectorExecutor executor = getExecutor(connector);
222+
223+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
224+
.builder()
225+
.parameters(Map.of("input", "test input"))
226+
.actionType(PREDICT)
227+
.build();
228+
String actionType = inputDataSet.getActionType().toString();
229+
doThrow(new IOException("UT test IOException")).when(mlInputParams).toXContent(any(XContentBuilder.class), any());
230+
MLInput mlInput = MLInput
231+
.builder()
232+
.algorithm(FunctionName.TEXT_EMBEDDING)
233+
.parameters(mlInputParams)
234+
.inputDataset(inputDataSet)
235+
.build();
236+
237+
executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener);
238+
verify(actionListener).onFailure(argThat(e -> e instanceof IOException && e.getMessage().contains("UT test IOException")));
239+
}
240+
241+
@Test
242+
public void executeGetParams_MissingParameter() {
243+
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
244+
Connector connector = getConnector(parameters);
245+
AwsConnectorExecutor executor = getExecutor(connector);
246+
247+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
248+
.builder()
249+
.parameters(Map.of("input", "${parameters.input}"))
250+
.actionType(PREDICT)
251+
.build();
252+
String actionType = inputDataSet.getActionType().toString();
253+
AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters
254+
.builder()
255+
.sparseEmbeddingFormat(SparseEmbeddingFormat.WORD)
256+
.embeddingContentType(null)
257+
.build();
258+
MLInput mlInput = MLInput
259+
.builder()
260+
.algorithm(FunctionName.TEXT_EMBEDDING)
261+
.parameters(inputParams)
262+
.inputDataset(inputDataSet)
263+
.build();
264+
265+
try {
266+
Map<String, String> paramsMap = RemoteConnectorExecutor.getParams(mlInput);
267+
Map<String, String> expectedMap = new HashMap<>();
268+
expectedMap.put("sparse_embedding_format", "WORD");
269+
Assert.assertEquals(expectedMap, paramsMap);
270+
} catch (IOException e) {
271+
e.printStackTrace();
272+
}
273+
}
274+
275+
@Test
276+
public void executeGetParams_PassingParameter() {
277+
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
278+
Connector connector = getConnector(parameters);
279+
AwsConnectorExecutor executor = getExecutor(connector);
280+
281+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
282+
.builder()
283+
.parameters(Map.of("input", "${parameters.input}"))
284+
.actionType(PREDICT)
285+
.build();
286+
String actionType = inputDataSet.getActionType().toString();
287+
AsymmetricTextEmbeddingParameters inputParams = AsymmetricTextEmbeddingParameters
288+
.builder()
289+
.sparseEmbeddingFormat(SparseEmbeddingFormat.WORD)
290+
.embeddingContentType(AsymmetricTextEmbeddingParameters.EmbeddingContentType.PASSAGE)
291+
.build();
292+
MLInput mlInput = MLInput
293+
.builder()
294+
.algorithm(FunctionName.TEXT_EMBEDDING)
295+
.parameters(inputParams)
296+
.inputDataset(inputDataSet)
297+
.build();
298+
299+
try {
300+
Map<String, String> paramsMap = RemoteConnectorExecutor.getParams(mlInput);
301+
Map<String, String> expectedMap = new HashMap<>();
302+
expectedMap.put("sparse_embedding_format", "WORD");
303+
expectedMap.put("content_type", "PASSAGE");
304+
Assert.assertEquals(expectedMap, paramsMap);
305+
} catch (IOException e) {
306+
e.printStackTrace();
307+
}
308+
}
309+
310+
@Test
311+
public void executeGetParams_ConvertToString() {
312+
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
313+
Connector connector = getConnector(parameters);
314+
AwsConnectorExecutor executor = getExecutor(connector);
315+
316+
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
317+
.builder()
318+
.parameters(Map.of("input", "${parameters.input}"))
319+
.actionType(PREDICT)
320+
.build();
321+
KMeansParams inputParams = KMeansParams
322+
.builder()
323+
.centroids(5)
324+
.iterations(100)
325+
.distanceType(KMeansParams.DistanceType.EUCLIDEAN)
326+
.build();
327+
MLInput mlInput = MLInput
328+
.builder()
329+
.algorithm(FunctionName.TEXT_EMBEDDING)
330+
.parameters(inputParams)
331+
.inputDataset(inputDataSet)
332+
.build();
333+
334+
try {
335+
Map<String, String> paramsMap = RemoteConnectorExecutor.getParams(mlInput);
336+
Map<String, String> expectedMap = new HashMap<>();
337+
expectedMap.put("centroids", "5");
338+
expectedMap.put("iterations", "100");
339+
expectedMap.put("distance_type", "EUCLIDEAN");
340+
Assert.assertEquals(expectedMap, paramsMap);
341+
} catch (IOException e) {
342+
e.printStackTrace();
343+
}
344+
}
172345
}

0 commit comments

Comments
 (0)