Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] add register action request/response #1780

Merged
merged 1 commit into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -19,7 +20,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;


@EqualsAndHashCode
@Getter
public class LLMSpec implements ToXContentObject {
public static final String MODEL_ID_FIELD = "model_id";
Expand Down
46 changes: 27 additions & 19 deletions common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -26,7 +27,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;


@EqualsAndHashCode
@Getter
public class MLAgent implements ToXContentObject, Writeable {
public static final String AGENT_NAME_FIELD = "name";
Expand Down Expand Up @@ -64,9 +65,6 @@ public MLAgent(String name,
Instant createdTime,
Instant lastUpdateTime,
String appType) {
if (name == null) {
throw new IllegalArgumentException("agent name is null");
}
this.name = name;
this.type = type;
this.description = description;
Expand All @@ -77,6 +75,24 @@ public MLAgent(String name,
this.createdTime = createdTime;
this.lastUpdateTime = lastUpdateTime;
this.appType = appType;
validate();
}

private void validate() {
if (name == null) {
throw new IllegalArgumentException("agent name is null");
}
Set<String> toolNames = new HashSet<>();
if (tools != null) {
for (MLToolSpec toolSpec : tools) {
String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType());
if (toolNames.contains(toolName)) {
throw new IllegalArgumentException("Duplicate tool defined: " + toolName);
} else {
toolNames.add(toolName);
}
}
}
}

public MLAgent(StreamInput input) throws IOException{
Expand All @@ -99,18 +115,10 @@ public MLAgent(StreamInput input) throws IOException{
if (input.readBoolean()) {
memory = new MLMemorySpec(input);
}
createdTime = input.readInstant();
lastUpdateTime = input.readInstant();
appType = input.readString();
if (!"flow".equals(type)) {
Set<String> toolNames = new HashSet<>();
for (MLToolSpec toolSpec : tools) {
String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType());
if (toolNames.contains(toolName)) {
throw new IllegalArgumentException("Tool has duplicate name or alias: " + toolName);
}
}
}
createdTime = input.readOptionalInstant();
lastUpdateTime = input.readOptionalInstant();
appType = input.readOptionalString();
validate();
}

public void writeTo(StreamOutput out) throws IOException {
Expand Down Expand Up @@ -144,9 +152,9 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeInstant(createdTime);
out.writeInstant(lastUpdateTime);
out.writeString(appType);
out.writeOptionalInstant(createdTime);
out.writeOptionalInstant(lastUpdateTime);
out.writeOptionalString(appType);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.core.common.io.stream.StreamInput;
Expand All @@ -18,7 +19,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;


@EqualsAndHashCode
@Getter
public class MLMemorySpec implements ToXContentObject {
public static final String MEMORY_TYPE_FIELD = "type";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.agent;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -19,7 +20,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;


@EqualsAndHashCode
@Getter
public class MLToolSpec implements ToXContentObject {
public static final String TOOL_TYPE_FIELD = "type";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.transport.agent;

import lombok.Builder;
import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
Expand All @@ -20,6 +21,7 @@
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLAgentGetResponse extends ActionResponse implements ToXContentObject {
MLAgent mlAgent;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.agent;

import org.opensearch.action.ActionType;

public class MLRegisterAgentAction extends ActionType<MLRegisterAgentResponse> {
public static MLRegisterAgentAction INSTANCE = new MLRegisterAgentAction();
public static final String NAME = "cluster:admin/opensearch/ml/agents/register";

private MLRegisterAgentAction() {
super(NAME, MLRegisterAgentResponse::new);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.agent;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.agent.MLAgent;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLRegisterAgentRequest extends ActionRequest {

MLAgent mlAgent;

@Builder
public MLRegisterAgentRequest(MLAgent mlAgent) {
this.mlAgent = mlAgent;
}

public MLRegisterAgentRequest(StreamInput in) throws IOException {
super(in);
this.mlAgent = new MLAgent(in);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (mlAgent == null) {
exception = addValidationError("ML agent can't be null", exception);
}

return exception;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
this.mlAgent.writeTo(out);
}

public static MLRegisterAgentRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLRegisterAgentRequest) {
return (MLRegisterAgentRequest) actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLRegisterAgentRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("Failed to parse ActionRequest into MLRegisterAgentRequest", e);
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.agent;

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLRegisterAgentResponse extends ActionResponse implements ToXContentObject {
public static final String AGENT_ID_FIELD = "agent_id";

private String agentId;

public MLRegisterAgentResponse(StreamInput in) throws IOException {
super(in);
this.agentId = in.readString();
}

public MLRegisterAgentResponse(String agentId) {
this.agentId= agentId;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(agentId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(AGENT_ID_FIELD, agentId);
builder.endObject();
return builder;
}

public static MLRegisterAgentResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLRegisterAgentResponse) {
return (MLRegisterAgentResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLRegisterAgentResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("Failed to parse ActionResponse into MLRegisterAgentResponse", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@

import lombok.Getter;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
public class MLUndeployModelsResponse extends ActionResponse implements ToXContentObject {
Expand Down Expand Up @@ -49,4 +54,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
return builder;
}

public static MLUndeployModelsResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLUndeployModelsResponse) {
return (MLUndeployModelsResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLUndeployModelsResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("Failed to parse ActionResponse into MLUndeployModelsResponse", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ public void constructor_NullName() {
MLAgent agent = new MLAgent(null, "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), null, null, Instant.EPOCH, Instant.EPOCH, "test");
}

@Test
public void constructor_DuplicateTool() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Duplicate tool defined: test_tool_name");
MLToolSpec mlToolSpec = new MLToolSpec("test_tool_type", "test_tool_name", "test", Collections.EMPTY_MAP, false);
MLAgent agent = new MLAgent("test_name", "test_type", "test_description", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(mlToolSpec, mlToolSpec), null, null, Instant.EPOCH, Instant.EPOCH, "test");
}

@Test
public void writeTo() throws IOException {
MLAgent agent = new MLAgent("test", "test", "test", new LLMSpec("test_model", Map.of("test_key", "test_value")), List.of(new MLToolSpec("test", "test", "test", Collections.EMPTY_MAP, false)), Map.of("test", "test"), new MLMemorySpec("test", "123", 0), Instant.EPOCH, Instant.EPOCH, "test");
Expand Down
Loading
Loading