Skip to content

Commit 0374222

Browse files
akolarkunnu, dhrubo-os, mingshl
authored
[Enhancement] Enhance validation for create connector API (#3579)
* [Enhancement] Enhance validation for create connector API This change will address the second part of validation "pre and post processing function validation". Partially resolves #2993 Signed-off-by: Abdul Muneer Kolarkunnu <muneer.kolarkunnu@netapp.com> * [Enhancement] Enhance validation for create connector API This change will address the second part of validation "pre and post processing function validation". Partially resolves #2993 Signed-off-by: Abdul Muneer Kolarkunnu <muneer.kolarkunnu@netapp.com> * [Enhancement] Enhance validation for create connector API This change will address the second part of validation "pre and post processing function validation". Partially resolves #2993 Signed-off-by: Abdul Muneer Kolarkunnu <muneer.kolarkunnu@netapp.com> * [Enhancement] Enhance validation for create connector API This change will address the second part of validation "pre and post processing function validation". Partially resolves #2993 Signed-off-by: Abdul Muneer Kolarkunnu <muneer.kolarkunnu@netapp.com> --------- Signed-off-by: Abdul Muneer Kolarkunnu <muneer.kolarkunnu@netapp.com> Co-authored-by: Dhrubo Saha <dhrubo@amazon.com> Co-authored-by: Mingshi Liu <mingshl@amazon.com>
1 parent a894ff1 commit 0374222

File tree

6 files changed

+560
-13
lines changed

6 files changed

+560
-13
lines changed

common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@
99

1010
import java.io.IOException;
1111
import java.util.HashSet;
12+
import java.util.List;
1213
import java.util.Locale;
1314
import java.util.Map;
1415
import java.util.Set;
1516

17+
import org.apache.commons.text.StringSubstitutor;
18+
import org.apache.logging.log4j.LogManager;
19+
import org.apache.logging.log4j.Logger;
1620
import org.opensearch.core.common.io.stream.StreamInput;
1721
import org.opensearch.core.common.io.stream.StreamOutput;
1822
import org.opensearch.core.common.io.stream.Writeable;
@@ -35,6 +39,17 @@ public class ConnectorAction implements ToXContentObject, Writeable {
3539
public static final String REQUEST_BODY_FIELD = "request_body";
3640
public static final String ACTION_PRE_PROCESS_FUNCTION = "pre_process_function";
3741
public static final String ACTION_POST_PROCESS_FUNCTION = "post_process_function";
42+
// Names of the remote LLM services; getRemoteServerFromURL looks for these as substrings of the connector url.
public static final String OPENAI = "openai";
43+
public static final String COHERE = "cohere";
44+
public static final String BEDROCK = "bedrock";
45+
public static final String SAGEMAKER = "sagemaker";
46+
// Text a built-in pre/post-process function name is checked for when the remote server is SageMaker.
public static final String SAGEMAKER_PRE_POST_FUNC_TEXT = "default";
47+
// Order matters: getRemoteServerFromURL returns the first entry of this list that the url contains.
public static final List<String> SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES = List.of(SAGEMAKER, OPENAI, BEDROCK, COHERE);
48+
49+
// Prefix tested by isInBuiltProcessFunction to recognise a built-in process function.
private static final String INBUILT_FUNC_PREFIX = "connector.";
50+
// Labels inserted into the warning text to say which of the two functions is mismatched.
private static final String PRE_PROCESS_FUNC = "PreProcessFunction";
51+
private static final String POST_PROCESS_FUNC = "PostProcessFunction";
52+
private static final Logger logger = LogManager.getLogger(ConnectorAction.class);
3853

3954
private ActionType actionType;
4055
private String method;
@@ -185,6 +200,81 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
185200
.build();
186201
}
187202

203+
/**
204+
* Checks the compatibility of pre and post-process functions with the selected LLM service.
205+
* Each LLM service (eg: Bedrock, OpenAI, SageMaker) has recommended pre and post-process functions
206+
* designed for optimal performance. While it's possible to use functions from other services,
207+
* it's strongly advised to use the corresponding functions for the best results.
208+
* This method logs a warning if non-corresponding functions are detected, but allows the
209+
* configuration to proceed. Users should be aware that using mismatched functions may lead
210+
* to unexpected behavior or reduced performance, though it won't necessarily cause failures.
211+
*
212+
* @param parameters - connector parameters
213+
*/
214+
public void validatePrePostProcessFunctions(Map<String, String> parameters) {
215+
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
216+
String endPoint = substitutor.replace(url);
217+
String remoteServer = getRemoteServerFromURL(endPoint);
218+
if (!remoteServer.isEmpty()) {
219+
validateProcessFunctions(remoteServer, preProcessFunction, PRE_PROCESS_FUNC);
220+
validateProcessFunctions(remoteServer, postProcessFunction, POST_PROCESS_FUNC);
221+
}
222+
}
223+
224+
/**
225+
* To get the remote server name from url
226+
*
227+
* @param url - remote server url
228+
* @return - returns the corresponding remote server name for url, if server is not in the pre-defined list,
229+
* it returns null
230+
*/
231+
public static String getRemoteServerFromURL(String url) {
232+
return SUPPORTED_REMOTE_SERVERS_FOR_DEFAULT_ACTION_TYPES.stream().filter(url::contains).findFirst().orElse("");
233+
}
234+
235+
private void validateProcessFunctions(String remoteServer, String processFunction, String funcNameForWarnText) {
236+
if (isInBuiltProcessFunction(processFunction)) {
237+
switch (remoteServer) {
238+
case OPENAI:
239+
if (!processFunction.contains(OPENAI)) {
240+
logWarningForInvalidProcessFunc(OPENAI, funcNameForWarnText);
241+
}
242+
break;
243+
case COHERE:
244+
if (!processFunction.contains(COHERE)) {
245+
logWarningForInvalidProcessFunc(COHERE, funcNameForWarnText);
246+
}
247+
break;
248+
case BEDROCK:
249+
if (!processFunction.contains(BEDROCK)) {
250+
logWarningForInvalidProcessFunc(BEDROCK, funcNameForWarnText);
251+
}
252+
break;
253+
case SAGEMAKER:
254+
if (!processFunction.contains(SAGEMAKER_PRE_POST_FUNC_TEXT)) {
255+
logWarningForInvalidProcessFunc(SAGEMAKER, funcNameForWarnText);
256+
}
257+
}
258+
}
259+
}
260+
261+
private boolean isInBuiltProcessFunction(String processFunction) {
262+
return (processFunction != null && processFunction.startsWith(INBUILT_FUNC_PREFIX));
263+
}
264+
265+
private void logWarningForInvalidProcessFunc(String remoteServer, String funcNameForWarnText) {
266+
logger
267+
.warn(
268+
"LLM service is "
269+
+ remoteServer
270+
+ ", so "
271+
+ funcNameForWarnText
272+
+ " should be corresponding to "
273+
+ remoteServer
274+
+ " for better results."
275+
);
276+
}
277+
188278
public enum ActionType {
189279
PREDICT,
190280
EXECUTE,

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ public HttpConnector(
7070
String tenantId
7171
) {
7272
validateProtocol(protocol);
73+
if (actions != null) {
74+
for (ConnectorAction action : actions) {
75+
action.validatePrePostProcessFunctions(parameters);
76+
}
77+
}
7378
this.name = name;
7479
this.description = description;
7580
this.version = version;

common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ public MLCreateConnectorInput(
109109
if (credential == null || credential.isEmpty()) {
110110
throw new IllegalArgumentException("Connector credential is null or empty list");
111111
}
112+
if (actions != null) {
113+
for (ConnectorAction action : actions) {
114+
action.validatePrePostProcessFunctions(parameters);
115+
}
116+
}
112117
}
113118
this.name = name;
114119
this.description = description;

0 commit comments

Comments
 (0)