Skip to content

Commit

Permalink
Adding connector http timeout in the connector level (#1835)
Browse files Browse the repository at this point in the history
* working on connector http timeout

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* adding more test + fixing integration test issue

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* updating default values

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* fixing unit tests

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* input format changed

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* removed unused code

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* fixing test

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* refactored more and add more tests

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* applying spotlessApply

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* addressing comments

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* removing spaces

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* addressed comments + updated open api model name and endpoints (deprecated)

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* updating test

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* adding fields

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

* addresseing comments to rename client config

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>

---------

Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
  • Loading branch information
dhrubo-os authored Mar 6, 2024
1 parent a159498 commit 25f2122
Show file tree
Hide file tree
Showing 23 changed files with 544 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ public class CommonValue {
+ AbstractConnector.CREDENTIAL_FIELD
+ "\" : {\"type\": \"flat_object\"},\n"
+ " \""
+ AbstractConnector.CLIENT_CONFIG_FIELD
+ "\" : {\"type\": \"flat_object\"},\n"
+ " \""
+ AbstractConnector.ACTIONS_FIELD
+ "\" : {\"type\": \"flat_object\"}\n";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ public abstract class AbstractConnector implements Connector {
public static final String BACKEND_ROLES_FIELD = "backend_roles";
public static final String OWNER_FIELD = "owner";
public static final String ACCESS_FIELD = "access";
public static final String CLIENT_CONFIG_FIELD = "client_config";


protected String name;
protected String description;
Expand All @@ -65,6 +67,8 @@ public abstract class AbstractConnector implements Connector {
protected AccessMode access;
protected Instant createdTime;
protected Instant lastUpdateTime;
@Setter
protected ConnectorClientConfig connectorClientConfig;

protected Map<String, String> createPredictDecryptedHeaders(Map<String, String> headers) {
if (headers == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ public class AwsConnector extends HttpConnector {
@Builder(builderMethodName = "awsConnectorBuilder")
public AwsConnector(String name, String description, String version, String protocol,
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
List<String> backendRoles, AccessMode accessMode, User owner) {
super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode, owner);
List<String> backendRoles, AccessMode accessMode, User owner,
ConnectorClientConfig connectorClientConfig) {
super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode,
owner, connectorClientConfig);
validate();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ public interface Connector extends ToXContentObject, Writeable {
Map<String, String> getParameters();

List<ConnectorAction> getActions();

ConnectorClientConfig getConnectorClientConfig();

String getPredictEndpoint(Map<String, String> parameters);

String getPredictHttpMethod();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

@Getter
@EqualsAndHashCode
public class ConnectorClientConfig implements ToXContentObject, Writeable {

public static final String MAX_CONNECTION_FIELD = "max_connection";
public static final String CONNECTION_TIMEOUT_FIELD = "connection_timeout";
public static final String READ_TIMEOUT_FIELD = "read_timeout";

public static final Integer MAX_CONNECTION_DEFAULT_VALUE = Integer.valueOf(30);
public static final Integer CONNECTION_TIMEOUT_DEFAULT_VALUE = Integer.valueOf(30000);
public static final Integer READ_TIMEOUT_DEFAULT_VALUE = Integer.valueOf(30000);

private Integer maxConnections;
private Integer connectionTimeout;
private Integer readTimeout;

@Builder(toBuilder = true)
public ConnectorClientConfig(
Integer maxConnections,
Integer connectionTimeout,
Integer readTimeout
) {
this.maxConnections = maxConnections;
this.connectionTimeout = connectionTimeout;
this.readTimeout = readTimeout;

}

public ConnectorClientConfig(StreamInput input) throws IOException {
this.maxConnections = input.readOptionalInt();
this.connectionTimeout = input.readOptionalInt();
this.readTimeout = input.readOptionalInt();
}

public ConnectorClientConfig() {
this.maxConnections = MAX_CONNECTION_DEFAULT_VALUE;
this.connectionTimeout = CONNECTION_TIMEOUT_DEFAULT_VALUE;
this.readTimeout = READ_TIMEOUT_DEFAULT_VALUE;
}

@Override
public void writeTo(StreamOutput out) throws IOException {

out.writeOptionalInt(maxConnections);
out.writeOptionalInt(connectionTimeout);
out.writeOptionalInt(readTimeout);
}

@Override
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
XContentBuilder builder = xContentBuilder.startObject();
if (maxConnections != null) {
builder.field(MAX_CONNECTION_FIELD, maxConnections);
}
if (connectionTimeout != null) {
builder.field(CONNECTION_TIMEOUT_FIELD, connectionTimeout);
}
if (readTimeout != null) {
builder.field(READ_TIMEOUT_FIELD, readTimeout);
}
return builder.endObject();
}

public static ConnectorClientConfig fromStream(StreamInput in) throws IOException {
ConnectorClientConfig connectorClientConfig = new ConnectorClientConfig(in);
return connectorClientConfig;
}

public static ConnectorClientConfig parse(XContentParser parser) throws IOException {
Integer maxConnections = null;
Integer connectionTimeout = null;
Integer readTimeout = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case MAX_CONNECTION_FIELD:
maxConnections = parser.intValue();
break;
case CONNECTION_TIMEOUT_FIELD:
connectionTimeout = parser.intValue();
break;
case READ_TIMEOUT_FIELD:
readTimeout = parser.intValue();
break;
default:
parser.skipChildren();
break;
}
}
return ConnectorClientConfig.builder()
.maxConnections(maxConnections)
.connectionTimeout(connectionTimeout)
.readTimeout(readTimeout)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ public class HttpConnector extends AbstractConnector {
@Builder
public HttpConnector(String name, String description, String version, String protocol,
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
List<String> backendRoles, AccessMode accessMode, User owner) {
List<String> backendRoles, AccessMode accessMode, User owner,
ConnectorClientConfig connectorClientConfig) {
validateProtocol(protocol);
this.name = name;
this.description = description;
Expand All @@ -64,6 +65,8 @@ public HttpConnector(String name, String description, String version, String pro
this.backendRoles = backendRoles;
this.access = accessMode;
this.owner = owner;
this.connectorClientConfig = connectorClientConfig;

}

public HttpConnector(String protocol, XContentParser parser) throws IOException {
Expand Down Expand Up @@ -121,6 +124,9 @@ public HttpConnector(String protocol, XContentParser parser) throws IOException
case LAST_UPDATED_TIME_FIELD:
lastUpdateTime = Instant.ofEpochMilli(parser.longValue());
break;
case CLIENT_CONFIG_FIELD:
connectorClientConfig = ConnectorClientConfig.parse(parser);
break;
default:
parser.skipChildren();
break;
Expand Down Expand Up @@ -167,6 +173,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (lastUpdateTime != null) {
builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli());
}
if (connectorClientConfig != null) {
builder.field(CLIENT_CONFIG_FIELD, connectorClientConfig);
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -205,6 +214,11 @@ private void parseFromStream(StreamInput input) throws IOException {
if (input.readBoolean()) {
this.owner = new User(input);
}
this.createdTime = input.readOptionalInstant();
this.lastUpdateTime = input.readOptionalInstant();
if (input.readBoolean()) {
this.connectorClientConfig = new ConnectorClientConfig(input);
}
}

@Override
Expand Down Expand Up @@ -247,6 +261,14 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalInstant(createdTime);
out.writeOptionalInstant(lastUpdateTime);
if (connectorClientConfig != null) {
out.writeBoolean(true);
connectorClientConfig.writeTo(out);
} else {
out.writeBoolean(false);
}
}

@Override
Expand Down Expand Up @@ -279,6 +301,9 @@ public void update(MLCreateConnectorInput updateContent, Function<String, String
if (updateContent.getAccess() != null) {
this.access = updateContent.getAccess();
}
if (updateContent.getConnectorClientConfig() != null) {
this.connectorClientConfig = updateContent.getConnectorClientConfig();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.connector.AbstractConnector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.ConnectorClientConfig;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -43,6 +45,8 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable {
public static final String ACCESS_MODE_FIELD = "access_mode";
public static final String DRY_RUN_FIELD = "dry_run";



public static final String DRY_RUN_CONNECTOR_NAME = "dryRunConnector";

private String name;
Expand All @@ -55,8 +59,10 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable {
private List<String> backendRoles;
private Boolean addAllBackendRoles;
private AccessMode access;
private boolean dryRun = false;
private boolean updateConnector = false;
private boolean dryRun;
private boolean updateConnector;
private ConnectorClientConfig connectorClientConfig;


@Builder(toBuilder = true)
public MLCreateConnectorInput(String name,
Expand All @@ -70,7 +76,9 @@ public MLCreateConnectorInput(String name,
Boolean addAllBackendRoles,
AccessMode access,
boolean dryRun,
boolean updateConnector
boolean updateConnector,
ConnectorClientConfig connectorClientConfig

) {
if (!dryRun && !updateConnector) {
if (name == null) {
Expand All @@ -95,6 +103,8 @@ public MLCreateConnectorInput(String name,
this.access = access;
this.dryRun = dryRun;
this.updateConnector = updateConnector;
this.connectorClientConfig = connectorClientConfig;

}

public static MLCreateConnectorInput parse(XContentParser parser) throws IOException {
Expand All @@ -113,6 +123,7 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update
Boolean addAllBackendRoles = null;
AccessMode access = null;
boolean dryRun = false;
ConnectorClientConfig connectorClientConfig = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -161,12 +172,16 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update
case DRY_RUN_FIELD:
dryRun = parser.booleanValue();
break;
case AbstractConnector.CLIENT_CONFIG_FIELD:
connectorClientConfig = ConnectorClientConfig.parse(parser);
break;
default:
parser.skipChildren();
break;
}
}
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun, updateConnector);
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions,
backendRoles, addAllBackendRoles, access, dryRun, updateConnector, connectorClientConfig);
}

@Override
Expand Down Expand Up @@ -202,6 +217,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (access != null) {
builder.field(ACCESS_MODE_FIELD, access);
}
if (connectorClientConfig != null) {
builder.field(AbstractConnector.CLIENT_CONFIG_FIELD, connectorClientConfig);
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -248,6 +266,13 @@ public void writeTo(StreamOutput output) throws IOException {
}
output.writeBoolean(dryRun);
output.writeBoolean(updateConnector);
if (connectorClientConfig != null) {
output.writeBoolean(true);
connectorClientConfig.writeTo(output);
} else {
output.writeBoolean(false);
}

}

public MLCreateConnectorInput(StreamInput input) throws IOException {
Expand Down Expand Up @@ -277,5 +302,8 @@ public MLCreateConnectorInput(StreamInput input) throws IOException {
}
dryRun = input.readBoolean();
updateConnector = input.readBoolean();
if (input.readBoolean()) {
this.connectorClientConfig = new ConnectorClientConfig(input);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,19 @@ public void toXContent_InternalConnector() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mlModel.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\",\"algorithm\":\"REMOTE\"," +
"\"model_version\":\"1.0.0\",\"description\":\"test model\",\"connector\":{\"name\":\"test_connector_name\"," +
"\"version\":\"1\",\"description\":\"this is a test connector\",\"protocol\":\"http\"," +
assertEquals("{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\"," +
"\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\"," +
"\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," +
"\"description\":\"this is a test connector\",\"protocol\":\"http\"," +
"\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," +
"\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," +
"\"headers\":{\"api_key\":\"${credential.key}\"}," +
"\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," +
"\"pre_process_function\":\"connector.pre_process.openai.embedding\"," +
"\"post_process_function\":\"connector.post_process.openai.embedding\"}]," +
"\"backend_roles\":[\"role1\",\"role2\"]," +
"\"access\":\"public\"}}", mlModelContent);
"\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," +
"\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000}}}",
mlModelContent);
}

@Test
Expand Down
Loading

0 comments on commit 25f2122

Please sign in to comment.