Skip to content

Commit

Permalink
remote inference: add unit test for create connector request/response (
Browse files Browse the repository at this point in the history
…opensearch-project#1067)

* remote inference: add unit test for create connector request/response

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* fix failed UT

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* fix failed UT

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

---------

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn authored and zane-neo committed Sep 1, 2023
1 parent 33a095d commit ab7ad70
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable {
public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles";
public static final String OWNER_FIELD = "owner";
public static final String ACCESS_MODE_FIELD = "access_mode";
public static final String DRY_RUN_FIELD = "dry_run";

public static final String DRY_RUN_CONNECTOR_NAME = "dryRunConnector";

Expand All @@ -52,6 +53,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable {
private List<String> backendRoles;
private Boolean addAllBackendRoles;
private AccessMode access;
private boolean dryRun = false;

@Builder(toBuilder = true)
public MLCreateConnectorInput(String name,
Expand All @@ -63,8 +65,20 @@ public MLCreateConnectorInput(String name,
List<ConnectorAction> actions,
List<String> backendRoles,
Boolean addAllBackendRoles,
AccessMode access
AccessMode access,
boolean dryRun
) {
if (!dryRun) {
if (name == null) {
throw new IllegalArgumentException("Connector name is null");
}
if (version == null) {
throw new IllegalArgumentException("Connector version is null");
}
if (protocol == null) {
throw new IllegalArgumentException("Connector protocol is null");
}
}
this.name = name;
this.description = description;
this.version = version;
Expand All @@ -75,6 +89,7 @@ public MLCreateConnectorInput(String name,
this.backendRoles = backendRoles;
this.addAllBackendRoles = addAllBackendRoles;
this.access = access;
this.dryRun = dryRun;
}

public static MLCreateConnectorInput parse(XContentParser parser) throws IOException {
Expand All @@ -88,6 +103,7 @@ public static MLCreateConnectorInput parse(XContentParser parser) throws IOExcep
List<String> backendRoles = null;
Boolean addAllBackendRoles = null;
AccessMode access = null;
boolean dryRun = false;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -133,12 +149,15 @@ public static MLCreateConnectorInput parse(XContentParser parser) throws IOExcep
case ACCESS_MODE_FIELD:
access = AccessMode.from(parser.text());
break;
case DRY_RUN_FIELD:
dryRun = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access);
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun);
}

@Override
Expand Down Expand Up @@ -181,7 +200,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
@Override
public void writeTo(StreamOutput output) throws IOException {
output.writeString(name);
output.writeString(description);
output.writeOptionalString(description);
output.writeString(version);
output.writeString(protocol);
if (parameters != null) {
Expand Down Expand Up @@ -211,20 +230,19 @@ public void writeTo(StreamOutput output) throws IOException {
} else {
output.writeBoolean(false);
}
if (addAllBackendRoles != null) {
output.writeBoolean(addAllBackendRoles);
}
output.writeOptionalBoolean(addAllBackendRoles);
if (access != null) {
output.writeBoolean(true);
output.writeEnum(access);
} else {
output.writeBoolean(false);
}
output.writeBoolean(dryRun);
}

public MLCreateConnectorInput(StreamInput input) throws IOException {
name = input.readString();
description = input.readString();
description = input.readOptionalString();
version = input.readString();
protocol = input.readString();
if (input.readBoolean()) {
Expand All @@ -247,5 +265,6 @@ public MLCreateConnectorInput(StreamInput input) throws IOException {
if (input.readBoolean()) {
this.access = input.readEnum(AccessMode.class);
}
dryRun = input.readBoolean();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;

import java.io.IOException;
import java.io.UncheckedIOException;

public class MLCreateConnectorRequestTest {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

@Test
public void validate_nullInput() {
MLCreateConnectorRequest request = new MLCreateConnectorRequest((MLCreateConnectorInput)null);
ActionRequestValidationException exception = request.validate();
Assert.assertTrue(exception.getMessage().contains("ML Connector input can't be null"));
}

@Test
public void readFromStream() throws IOException {
MLCreateConnectorInput input = MLCreateConnectorInput.builder()
.name("test_connector")
.protocol("http")
.version("1")
.description("test")
.build();
MLCreateConnectorRequest request = new MLCreateConnectorRequest(input);
BytesStreamOutput output = new BytesStreamOutput();
request.writeTo(output);
MLCreateConnectorRequest request2 = new MLCreateConnectorRequest(output.bytes().streamInput());
Assert.assertEquals("test_connector", request2.getMlCreateConnectorInput().getName());
Assert.assertEquals("http", request2.getMlCreateConnectorInput().getProtocol());
Assert.assertEquals("1", request2.getMlCreateConnectorInput().getVersion());
Assert.assertEquals("test", request2.getMlCreateConnectorInput().getDescription());
}

@Test
public void fromActionRequest() {
MLCreateConnectorInput input = MLCreateConnectorInput.builder()
.name("test_connector")
.protocol("http")
.version("1")
.description("test")
.build();
ActionRequest request = new MLCreateConnectorRequest(input);
MLCreateConnectorRequest request2 = MLCreateConnectorRequest.fromActionRequest(request);
Assert.assertEquals("test_connector", request2.getMlCreateConnectorInput().getName());
Assert.assertEquals("http", request2.getMlCreateConnectorInput().getProtocol());
Assert.assertEquals("1", request2.getMlCreateConnectorInput().getVersion());
Assert.assertEquals("test", request2.getMlCreateConnectorInput().getDescription());
}

@Test
public void fromActionRequest_Exception() {
exceptionRule.expect(UncheckedIOException.class);
exceptionRule.expectMessage("Failed to parse ActionRequest into MLCreateConnectorRequest");
ActionRequest request = new MLConnectorGetRequest("test_id", true);
MLCreateConnectorRequest.fromActionRequest(request);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.connector;

import org.junit.Assert;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.TestHelper;

import java.io.IOException;

public class MLCreateConnectorResponseTest {

@Test
public void toXContent() throws IOException {
MLCreateConnectorResponse response = new MLCreateConnectorResponse("test_id");
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert.assertEquals("{\"connector_id\":\"test_id\"}", content);
}

@Test
public void readFromStream() throws IOException {
MLCreateConnectorResponse response = new MLCreateConnectorResponse("test_id");
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);

MLCreateConnectorResponse response2 = new MLCreateConnectorResponse(output.bytes().streamInput());
Assert.assertEquals("test_id", response2.getConnectorId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public TransportCreateConnectorAction(
protected void doExecute(Task task, ActionRequest request, ActionListener<MLCreateConnectorResponse> listener) {
MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.fromActionRequest(request);
MLCreateConnectorInput mlCreateConnectorInput = mlCreateConnectorRequest.getMlCreateConnectorInput();
if (MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME.equals(mlCreateConnectorInput.getName())) {
if (mlCreateConnectorInput.isDryRun()) {
MLCreateConnectorResponse response = new MLCreateConnectorResponse(MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME);
listener.onResponse(response);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener<
log.error(e.getMessage(), e);
listener.onFailure(e);
});
MLCreateConnectorRequest mlCreateConnectorRequest = createConnectorRequest();
MLCreateConnectorRequest mlCreateConnectorRequest = createDryRunConnectorRequest();
client.execute(MLCreateConnectorAction.INSTANCE, mlCreateConnectorRequest, dryRunResultListener);
}
} else {
Expand All @@ -207,8 +207,8 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis
}
}

private MLCreateConnectorRequest createConnectorRequest() {
MLCreateConnectorInput createConnectorInput = MLCreateConnectorInput.builder().name("dryRunConnector").build();
private MLCreateConnectorRequest createDryRunConnectorRequest() {
MLCreateConnectorInput createConnectorInput = MLCreateConnectorInput.builder().dryRun(true).build();
return new MLCreateConnectorRequest(createConnectorInput);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ public void setup() {
Map<String, String> credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret");
input = MLCreateConnectorInput
.builder()
.name("test_name")
.version("1")
.actions(actions)
.parameters(parameters)
.protocol(ConnectorProtocols.HTTP)
Expand Down Expand Up @@ -430,6 +432,7 @@ public void test_execute_dryRun_connector_creation() {

MLCreateConnectorInput mlCreateConnectorInput = mock(MLCreateConnectorInput.class);
when(mlCreateConnectorInput.getName()).thenReturn(MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME);
when(mlCreateConnectorInput.isDryRun()).thenReturn(true);
MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput);
action.doExecute(task, request, actionListener);
verify(actionListener).onResponse(any(MLCreateConnectorResponse.class));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_success() {
MLRegisterModelInput input = mock(MLRegisterModelInput.class);
when(request.getRegisterModelInput()).thenReturn(input);
when(input.getModelName()).thenReturn("Test Model");
when(input.getVersion()).thenReturn("1");
when(input.getModelGroupId()).thenReturn("modelGroupID");
when(input.getFunctionName()).thenReturn(FunctionName.REMOTE);
Connector connector = mock(Connector.class);
Expand Down

0 comments on commit ab7ad70

Please sign in to comment.