Skip to content

Commit

Permalink
create model group automatically with first model version (opensearch…
Browse files Browse the repository at this point in the history
…-project#1063)

* create model group automatically with first model version

Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
  • Loading branch information
rbhavna authored and zane-neo committed Sep 1, 2023
1 parent 0de3fba commit 38226aa
Show file tree
Hide file tree
Showing 31 changed files with 1,289 additions and 521 deletions.
2 changes: 1 addition & 1 deletion common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jacocoTestCoverageVerification {
}
limit {
counter = 'BRANCH'
minimum = 0.6 //TODO: add more test to meet the coverage bar 0.9
minimum = 0.5 //TODO: add more test to meet the coverage bar 0.9
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class MLRegisterModelGroupInput implements ToXContentObject, Writeable {
public static final String NAME_FIELD = "name"; //mandatory
public static final String DESCRIPTION_FIELD = "description"; //optional
public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional
public static final String MODEL_ACCESS_MODE = "model_access_mode"; //optional
public static final String MODEL_ACCESS_MODE = "access_mode"; //optional
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional

private String name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class MLUpdateModelGroupInput implements ToXContentObject, Writeable {
public static final String NAME_FIELD = "name"; //optional
public static final String DESCRIPTION_FIELD = "description"; //optional
public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional
public static final String MODEL_ACCESS_MODE = "model_access_mode"; //optional
public static final String MODEL_ACCESS_MODE = "access_mode"; //optional
public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles"; //optional


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,6 @@ public MLRegisterModelInput(FunctionName functionName,
if (modelName == null) {
throw new IllegalArgumentException("model name is null");
}
if (modelGroupId == null) {
throw new IllegalArgumentException("model group id is null");
}
if (functionName != FunctionName.REMOTE) {
if (modelFormat == null) {
throw new IllegalArgumentException("model format is null");
Expand Down Expand Up @@ -132,7 +129,7 @@ public MLRegisterModelInput(FunctionName functionName,
public MLRegisterModelInput(StreamInput in) throws IOException {
this.functionName = in.readEnum(FunctionName.class);
this.modelName = in.readString();
this.modelGroupId = in.readString();
this.modelGroupId = in.readOptionalString();
this.version = in.readOptionalString();
this.description = in.readOptionalString();
this.url = in.readOptionalString();
Expand Down Expand Up @@ -162,7 +159,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(functionName);
out.writeString(modelName);
out.writeString(modelGroupId);
out.writeOptionalString(modelGroupId);
out.writeOptionalString(version);
out.writeOptionalString(description);
out.writeOptionalString(url);
Expand Down Expand Up @@ -208,8 +205,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
builder.field(FUNCTION_NAME_FIELD, functionName);
builder.field(NAME_FIELD, modelName);
builder.field(VERSION_FIELD, version);
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
if (version != null) {
builder.field(VERSION_FIELD, version);
}
if (modelGroupId != null) {
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
}
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

Expand All @@ -29,20 +33,26 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{

public static final String FUNCTION_NAME_FIELD = "function_name";
public static final String MODEL_NAME_FIELD = "name"; //mandatory
public static final String DESCRIPTION_FIELD = "description";
public static final String DESCRIPTION_FIELD = "description"; //optional

public static final String VERSION_FIELD = "version";
public static final String MODEL_FORMAT_FIELD = "model_format"; //mandatory
public static final String MODEL_STATE_FIELD = "model_state";
public static final String MODEL_CONTENT_SIZE_IN_BYTES_FIELD = "model_content_size_in_bytes";
public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value"; //mandatory
public static final String MODEL_CONFIG_FIELD = "model_config"; //mandatory
public static final String TOTAL_CHUNKS_FIELD = "total_chunks"; //mandatory
public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; //mandatory
public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; //optional
public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional
public static final String ACCESS_MODE = "access_mode"; //optional
public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional

private FunctionName functionName;
private String name;

private String modelGroupId;
private String description;
private String version;

private MLModelFormat modelFormat;

Expand All @@ -52,9 +62,14 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{
private String modelContentHashValue;
private MLModelConfig modelConfig;
private Integer totalChunks;
private List<String> backendRoles;
private AccessMode accessMode;
private Boolean isAddAllBackendRoles;

@Builder(toBuilder = true)
public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks) {
public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List<String> backendRoles,
AccessMode accessMode,
Boolean isAddAllBackendRoles) {
if (name == null) {
throw new IllegalArgumentException("model name is null");
}
Expand All @@ -63,9 +78,6 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
} else {
this.functionName = functionName;
}
if (modelGroupId == null) {
throw new IllegalArgumentException("model group id is null");
}
if (modelFormat == null) {
throw new IllegalArgumentException("model format is null");
}
Expand All @@ -80,19 +92,24 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
}
this.name = name;
this.modelGroupId = modelGroupId;
this.version = version;
this.description = description;
this.modelFormat = modelFormat;
this.modelState = modelState;
this.modelContentSizeInBytes = modelContentSizeInBytes;
this.modelContentHashValue = modelContentHashValue;
this.modelConfig = modelConfig;
this.totalChunks = totalChunks;
this.backendRoles = backendRoles;
this.accessMode = accessMode;
this.isAddAllBackendRoles = isAddAllBackendRoles;
}

public MLRegisterModelMetaInput(StreamInput in) throws IOException{
this.name = in.readString();
this.functionName = in.readEnum(FunctionName.class);
this.modelGroupId = in.readString();
this.modelGroupId = in.readOptionalString();
this.version = in.readOptionalString();
this.description = in.readOptionalString();
if (in.readBoolean()) {
modelFormat = in.readEnum(MLModelFormat.class);
Expand All @@ -106,13 +123,19 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{
modelConfig = new TextEmbeddingModelConfig(in);
}
this.totalChunks = in.readInt();
this.backendRoles = in.readOptionalStringList();
if (in.readBoolean()) {
accessMode = in.readEnum(AccessMode.class);
}
this.isAddAllBackendRoles = in.readOptionalBoolean();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(name);
out.writeEnum(functionName);
out.writeString(modelGroupId);
out.writeOptionalString(modelGroupId);
out.writeOptionalString(version);
out.writeOptionalString(description);
if (modelFormat != null) {
out.writeBoolean(true);
Expand All @@ -135,14 +158,32 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeInt(totalChunks);
if (backendRoles != null) {
out.writeBoolean(true);
out.writeStringCollection(backendRoles);
} else {
out.writeBoolean(false);
}
if (accessMode != null) {
out.writeBoolean(true);
out.writeEnum(accessMode);
} else {
out.writeBoolean(false);
}
out.writeOptionalBoolean(isAddAllBackendRoles);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(MODEL_NAME_FIELD, name);
builder.field(FUNCTION_NAME_FIELD, functionName);
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
if (modelGroupId != null) {
builder.field(MODEL_GROUP_ID_FIELD, modelGroupId);
}
if (version != null) {
builder.field(VERSION_FIELD, version);
}
if (description != null) {
builder.field(DESCRIPTION_FIELD, description);
}
Expand All @@ -156,13 +197,23 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
builder.field(MODEL_CONTENT_HASH_VALUE_FIELD, modelContentHashValue);
builder.field(MODEL_CONFIG_FIELD, modelConfig);
builder.field(TOTAL_CHUNKS_FIELD, totalChunks);
if (backendRoles != null && backendRoles.size() > 0) {
builder.field(BACKEND_ROLES_FIELD, backendRoles);
}
if (accessMode != null) {
builder.field(ACCESS_MODE, accessMode);
}
if (isAddAllBackendRoles != null) {
builder.field(ADD_ALL_BACKEND_ROLES, isAddAllBackendRoles);
}
builder.endObject();
return builder;
}

public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOException {
String name = null;
FunctionName functionName = null;
String modelGroupId = null;
String version = null;
String description = null;
MLModelFormat modelFormat = null;
Expand All @@ -171,6 +222,9 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
String modelContentHashValue = null;
MLModelConfig modelConfig = null;
Integer totalChunks = null;
List<String> backendRoles = null;
AccessMode accessMode = null;
Boolean isAddAllBackendRoles = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -184,6 +238,9 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
functionName = FunctionName.from(parser.text());
break;
case MODEL_GROUP_ID_FIELD:
modelGroupId = parser.text();
break;
case VERSION_FIELD:
version = parser.text();
break;
case DESCRIPTION_FIELD:
Expand All @@ -207,12 +264,25 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc
case TOTAL_CHUNKS_FIELD:
totalChunks = parser.intValue(false);
break;
case BACKEND_ROLES_FIELD:
backendRoles = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
backendRoles.add(parser.text());
}
break;
case ACCESS_MODE:
accessMode = AccessMode.from(parser.text().toLowerCase(Locale.ROOT));
break;
case ADD_ALL_BACKEND_ROLES:
isAddAllBackendRoles = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLRegisterModelMetaInput(name, functionName, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks);
return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,6 @@ public void constructor_NullModelName() {
.build();
}

@Test
public void constructor_NullModelGroupId() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("model group id is null");
MLRegisterModelInput.builder()
.functionName(functionName)
.modelName(modelName)
.modelGroupId(null)
.build();
}

@Test
public void constructor_NullModelFormat() {
exceptionRule.expect(IllegalArgumentException.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ public class MLRegisterModelMetaInputTest {
public void setup() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512);
mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2);
mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "model_group_id", "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null);
}

@Test
Expand Down Expand Up @@ -75,14 +75,14 @@ public void testToXContent() throws IOException {{
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
"\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}";
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"model_group_id\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
"\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}";
assertEquals(expected, mlModelContent);
}
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"model_group_id\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
"\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}";
assertEquals(expected, mlModelContent);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ public class MLRegisterModelMetaRequestTest {
public void setUp() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512);
mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2);
mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id", "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null);
}

@Test
Expand Down
Loading

0 comments on commit 38226aa

Please sign in to comment.