Skip to content

Commit ae43501

Browse files
authored
Merge pull request #2 from xinyual/addTokenizerAndSparseEncoding
Add tokenizer and sparse encoding
2 parents ab398ee + 880a4a0 commit ae43501

File tree

33 files changed

+1109
-192
lines changed

33 files changed

+1109
-192
lines changed

common/build.gradle

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ dependencies {
1919
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'
2020

2121
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
22-
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
23-
compileOnly group: 'org.json', name: 'json', version: '20230227'
22+
implementation group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
23+
implementation group: 'org.json', name: 'json', version: '20230227'
2424
}
2525

2626
lombok {

common/src/main/java/org/opensearch/ml/common/FunctionName.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ public enum FunctionName {
1717
RCF_SUMMARIZE,
1818
LOGISTIC_REGRESSION,
1919
TEXT_EMBEDDING,
20+
SPARSE_ENCODING,
21+
TOKENIZE,
2022
METRICS_CORRELATION,
2123
REMOTE;
2224

@@ -33,7 +35,7 @@ public static FunctionName from(String value) {
3335
* @return true for deep learning model.
3436
*/
3537
public static boolean isDLModel(FunctionName functionName) {
36-
if (functionName == TEXT_EMBEDDING) {
38+
if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING || functionName == TOKENIZE) {
3739
return true;
3840
}
3941
return false;

common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
104104
Map<String, Object> data = StringUtils.fromJson((String) response, "response");
105105
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build());
106106
} else {
107-
Map<String, Object> map = new HashMap<>();
108-
map.put("response", response);
109-
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(map).build());
107+
Map<String, Object> map = new HashMap<>();
108+
map.put("response", response);
109+
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(map).build());
110110
}
111111
}
112112

common/src/main/java/org/opensearch/ml/common/input/MLInput.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
239239
}
240240
}
241241
MLInputDataset inputDataSet = null;
242-
if (algorithm == FunctionName.TEXT_EMBEDDING) {
242+
if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.SPARSE_ENCODING || algorithm == FunctionName.TOKENIZE) {
243243
ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
244244
inputDataSet = new TextDocsInputDataSet(textDocs, filter);
245245
}

common/src/main/java/org/opensearch/ml/common/input/nlp/TextDocsMLInput.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
* ML input class which supports a list fo text docs.
2626
* This class can be used for TEXT_EMBEDDING model.
2727
*/
28-
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING})
28+
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.TOKENIZE})
2929
public class TextDocsMLInput extends MLInput {
3030
public static final String TEXT_DOCS_FIELD = "text_docs";
3131
public static final String RESULT_FILTER_FIELD = "result_filter";

common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ public MLRegisterModelInput(FunctionName functionName,
104104
if (modelFormat == null) {
105105
throw new IllegalArgumentException("model format is null");
106106
}
107-
if (url != null && modelConfig == null) {
107+
if (url != null && modelConfig == null && functionName != FunctionName.TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) {
108108
throw new IllegalArgumentException("model config is null");
109109
}
110110
}

common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
8484
if (modelContentHashValue == null) {
8585
throw new IllegalArgumentException("model content hash value is null");
8686
}
87-
if (modelConfig == null) {
87+
if (modelConfig == null && functionName != FunctionName.TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) {
8888
throw new IllegalArgumentException("model config is null");
8989
}
9090
if (totalChunks == null) {

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import java.security.AccessController;
1818
import java.security.PrivilegedActionException;
1919
import java.security.PrivilegedExceptionAction;
20+
import java.util.ArrayList;
2021
import java.util.HashMap;
2122
import java.util.List;
2223
import java.util.Map;

common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,21 +149,27 @@ public void testClassLoader_ExecuteOutputMCorr() throws IOException {
149149
assertArrayEquals(new long[]{1, 2}, metrics);
150150
}
151151

152-
@Test
153-
public void testClassLoader_MLInput() throws IOException {
154-
assertTrue(MLCommonsClassLoader.canInitMLInput(FunctionName.TEXT_EMBEDDING));
152+
private void testClassLoader_MLInput_DlModel(FunctionName functionName) throws IOException {
153+
assertTrue(MLCommonsClassLoader.canInitMLInput(functionName));
155154

156155
String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}";
157156
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
158157
Collections.emptyList()).getNamedXContents()), null, jsonStr);
159158
parser.nextToken();
160159

161-
TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(FunctionName.TEXT_EMBEDDING, new Object[]{parser, FunctionName.TEXT_EMBEDDING}, XContentParser.class, FunctionName.class);
160+
TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(functionName, new Object[]{parser, functionName}, XContentParser.class, FunctionName.class);
162161
assertNotNull(mlInput);
163-
assertEquals(FunctionName.TEXT_EMBEDDING, mlInput.getFunctionName());
162+
assertEquals(functionName, mlInput.getFunctionName());
164163
assertEquals(2, ((TextDocsInputDataSet)mlInput.getInputDataset()).getDocs().size());
165164
}
166165

166+
@Test
167+
public void testClassLoader_MLInput() throws IOException {
168+
testClassLoader_MLInput_DlModel(FunctionName.TEXT_EMBEDDING);
169+
testClassLoader_MLInput_DlModel(FunctionName.TOKENIZE);
170+
testClassLoader_MLInput_DlModel(FunctionName.SPARSE_ENCODING);
171+
}
172+
167173
public enum TestEnum {
168174
TEST
169175
}

common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -110,19 +110,19 @@ public void parse_LinearRegression() throws IOException {
110110
});
111111
}
112112

113-
@Test
114-
public void parse_TextEmbedding() throws IOException {
113+
private void parse_NLPModel(FunctionName functionName) throws IOException {
115114
String sentence = "test sentence";
116115
String column = "column1";
117116
Integer position = 1;
118117
ModelResultFilter resultFilter = ModelResultFilter.builder()
119118
.targetResponse(Arrays.asList(column))
120119
.targetResponsePositions(Arrays.asList(position))
121120
.build();
122-
TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence))
123-
.resultFilter(resultFilter).build();
124-
String expectedInputStr = "{\"algorithm\":\"TEXT_EMBEDDING\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}";
125-
testParse(FunctionName.TEXT_EMBEDDING, inputDataset, expectedInputStr, parsedInput -> {
121+
122+
TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).resultFilter(resultFilter).build();
123+
String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}";
124+
expectedInputStr = expectedInputStr.replace("functionName", functionName.toString());
125+
testParse(functionName, inputDataset, expectedInputStr, parsedInput -> {
126126
assertNotNull(parsedInput.getInputDataset());
127127
TextDocsInputDataSet parsedInputDataSet = (TextDocsInputDataSet) parsedInput.getInputDataset();
128128
assertEquals(1, parsedInputDataSet.getDocs().size());
@@ -134,19 +134,33 @@ public void parse_TextEmbedding() throws IOException {
134134
}
135135

136136
@Test
137-
public void parse_TextEmbedding_NullResultFilter() throws IOException {
137+
public void parse_NLP_Related() throws IOException {
138+
parse_NLPModel(FunctionName.TEXT_EMBEDDING);
139+
parse_NLPModel(FunctionName.TOKENIZE);
140+
parse_NLPModel(FunctionName.SPARSE_ENCODING);
141+
}
142+
143+
private void parse_NLPModel_NullResultFilter(FunctionName functionName) throws IOException {
138144
String sentence = "test sentence";
139145
TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).build();
140-
String expectedInputStr = "{\"algorithm\":\"TEXT_EMBEDDING\",\"text_docs\":[\"test sentence\"]}";
141-
testParse(FunctionName.TEXT_EMBEDDING, inputDataset, expectedInputStr, parsedInput -> {
146+
String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"]}";
147+
expectedInputStr = expectedInputStr.replace("functionName", functionName.toString());
148+
testParse(functionName, inputDataset, expectedInputStr, parsedInput -> {
142149
assertNotNull(parsedInput.getInputDataset());
143150
assertEquals(1, ((TextDocsInputDataSet) parsedInput.getInputDataset()).getDocs().size());
144151
assertEquals(sentence, ((TextDocsInputDataSet) parsedInput.getInputDataset()).getDocs().get(0));
145152
});
146153
}
147154

148-
private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr,
149-
Consumer<MLInput> verify) throws IOException {
155+
156+
@Test
157+
public void parse_NLPRelated_NullResultFilter() throws IOException {
158+
parse_NLPModel_NullResultFilter(FunctionName.TEXT_EMBEDDING);
159+
parse_NLPModel_NullResultFilter(FunctionName.TOKENIZE);
160+
parse_NLPModel_NullResultFilter(FunctionName.SPARSE_ENCODING);
161+
}
162+
163+
private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
150164
MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
151165
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
152166
input.toXContent(builder, ToXContent.EMPTY_PARAMS);

0 commit comments

Comments
 (0)