Merged

Commits (69)
89b1c4a
add tokenizer and sparse encoding
xinyual Aug 25, 2023
f1268cd
add tokenizer and sparse encoding
xinyual Aug 25, 2023
3d489e3
add tokenizer and sparse encoding
xinyual Aug 25, 2023
599b43a
add tokenizer and sparse encoding
xinyual Aug 25, 2023
4e7e540
add tokenizer and sparse encoding
xinyual Aug 25, 2023
6c5ec1d
remove special token
xinyual Aug 28, 2023
9a42042
add filter
xinyual Aug 28, 2023
8db87c5
try empty model
xinyual Aug 28, 2023
8c50a4c
remove warm up
xinyual Aug 28, 2023
63f6e48
try empty model
xinyual Aug 28, 2023
fff587d
add block
xinyual Aug 28, 2023
ed64c1a
add log
xinyual Aug 28, 2023
2384bd6
add log
xinyual Aug 28, 2023
8379a89
add log
xinyual Aug 28, 2023
d7c306d
remove log
xinyual Aug 28, 2023
91e5e00
remove pt file detect
xinyual Aug 28, 2023
47aa4f7
add log
xinyual Aug 28, 2023
acaccbf
add functionName pipeline
xinyual Aug 28, 2023
98cba58
remove verify log
xinyual Aug 28, 2023
6e79550
skip special token in sparse encoding
xinyual Aug 28, 2023
c00eb57
skip omit tokenize config
xinyual Aug 29, 2023
73de04c
skip omit tokenize config-change warm up logic
xinyual Aug 29, 2023
95a41b5
reArch
xinyual Aug 29, 2023
8ffd217
deduplicate
xinyual Aug 29, 2023
83205e6
omit ml config in sparse encoding
xinyual Aug 29, 2023
4ef9c66
add null config in warm up
xinyual Aug 29, 2023
a1493cd
fix original test
xinyual Aug 29, 2023
b126bb0
add tokenize ut half
xinyual Aug 29, 2023
e038ed6
fix sparse encoding bug
xinyual Aug 30, 2023
a03ce36
add UT for sparse encoding and tokenize
xinyual Aug 30, 2023
828bd04
remove useless framwork type
xinyual Aug 30, 2023
88a4a2d
common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java
xinyual Aug 31, 2023
4571e4a
change key for tokenize
xinyual Aug 31, 2023
969a407
reArch DLModel
xinyual Aug 31, 2023
8397f29
reArch DLModel again
xinyual Aug 31, 2023
b5a2471
response format
xinyual Aug 31, 2023
d14038d
tokenize only one output
xinyual Aug 31, 2023
0f0bb5e
clean sparse output
xinyual Aug 31, 2023
92063c6
clean sparse output
xinyual Aug 31, 2023
a07e4be
change UT number
xinyual Aug 31, 2023
c23ee2f
remove useless predict code
xinyual Sep 4, 2023
18c9124
remove useless part
xinyual Sep 4, 2023
ce420db
change tokenize way
xinyual Sep 5, 2023
90b1fa2
reArch add textEmbedding model
xinyual Sep 5, 2023
3890e23
add tokenize logic
xinyual Sep 5, 2023
458994f
add abstract
xinyual Sep 5, 2023
eb82069
clear code
xinyual Sep 5, 2023
376853f
fix it class
xinyual Sep 6, 2023
83a1f34
fix it class
xinyual Sep 6, 2023
445e731
add IT file
xinyual Sep 6, 2023
1e280e8
reformulate
xinyual Sep 7, 2023
7367081
reformulate remote inference
xinyual Sep 7, 2023
0ab9a6c
reformulate remote inference
xinyual Sep 7, 2023
419e98a
reformulate remote inference json and array
xinyual Sep 7, 2023
713ed5b
verify
xinyual Sep 8, 2023
f7deeeb
undo string utils
xinyual Sep 8, 2023
3d6f3f2
skip dummy model
xinyual Sep 11, 2023
18c3821
skip dummy model
xinyual Sep 11, 2023
b5b414f
skip dummy model
xinyual Sep 11, 2023
606313d
skip dummy model
xinyual Sep 11, 2023
17b426d
skip dummy model
xinyual Sep 11, 2023
c8bbd9f
skip dummy model
xinyual Sep 11, 2023
3a06dd5
add inner load Model
xinyual Sep 11, 2023
a247844
rename variable
xinyual Sep 11, 2023
635ef61
add default for idf
xinyual Sep 12, 2023
91d7209
add ut for sparse encoding and tokenizer
xinyual Sep 12, 2023
d8c8d40
add close model
xinyual Sep 12, 2023
9f99e07
change mock class
xinyual Sep 12, 2023
880a4a0
remove buffer for sparse encoding output
xinyual Sep 12, 2023
4 changes: 2 additions & 2 deletions common/build.gradle
@@ -19,8 +19,8 @@ dependencies {
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'

compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
compileOnly group: 'org.json', name: 'json', version: '20230227'
implementation group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
implementation group: 'org.json', name: 'json', version: '20230227'
}

lombok {
@@ -17,6 +17,8 @@ public enum FunctionName {
RCF_SUMMARIZE,
LOGISTIC_REGRESSION,
TEXT_EMBEDDING,
SPARSE_ENCODING,
TOKENIZE,
METRICS_CORRELATION,
REMOTE;

@@ -33,7 +35,7 @@ public static FunctionName from(String value) {
* @return true for deep learning model.
*/
public static boolean isDLModel(FunctionName functionName) {
if (functionName == TEXT_EMBEDDING) {
if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING || functionName == TOKENIZE) {
return true;
}
return false;
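A minimal sketch (not part of this diff; the import path is assumed from the test files further down) of how the widened isDLModel check behaves once the two new enum values exist:

import org.opensearch.ml.common.FunctionName;

class IsDLModelSketch {
    public static void main(String[] args) {
        // The two new NLP function names are routed through the DL-model code path,
        // just like TEXT_EMBEDDING.
        System.out.println(FunctionName.isDLModel(FunctionName.TEXT_EMBEDDING));      // true
        System.out.println(FunctionName.isDLModel(FunctionName.SPARSE_ENCODING));     // true
        System.out.println(FunctionName.isDLModel(FunctionName.TOKENIZE));            // true
        // Non-DL algorithms are unaffected.
        System.out.println(FunctionName.isDLModel(FunctionName.LOGISTIC_REGRESSION)); // false
    }
}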
@@ -104,9 +104,9 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
Map<String, Object> data = StringUtils.fromJson((String) response, "response");
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build());
} else {
Map<String, Object> map = new HashMap<>();
map.put("response", response);
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(map).build());
Map<String, Object> map = new HashMap<>();
map.put("response", response);
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(map).build());
}
}

@@ -239,7 +239,7 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
}
}
MLInputDataset inputDataSet = null;
if (algorithm == FunctionName.TEXT_EMBEDDING) {
if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.SPARSE_ENCODING || algorithm == FunctionName.TOKENIZE) {
ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
inputDataSet = new TextDocsInputDataSet(textDocs, filter);
}
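For reference, a small sketch (not part of the change set) of building an input for one of the new function names, using only the builders that appear in the tests below; the import paths are assumptions:

import java.util.Arrays;

import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;

class SparseEncodingInputSketch {
    public static void main(String[] args) {
        // The same text-docs dataset that TEXT_EMBEDDING already consumes...
        TextDocsInputDataSet dataset = TextDocsInputDataSet.builder()
                .docs(Arrays.asList("test sentence"))
                .build();
        // ...is now also accepted when the algorithm is SPARSE_ENCODING or TOKENIZE.
        MLInput input = MLInput.builder()
                .algorithm(FunctionName.SPARSE_ENCODING)
                .inputDataset(dataset)
                .build();
    }
}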
@@ -25,7 +25,7 @@
 * ML input class which supports a list of text docs.
* This class can be used for TEXT_EMBEDDING model.
*/
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING})
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.TOKENIZE})
public class TextDocsMLInput extends MLInput {
public static final String TEXT_DOCS_FIELD = "text_docs";
public static final String RESULT_FILTER_FIELD = "result_filter";
@@ -104,7 +104,7 @@ public MLRegisterModelInput(FunctionName functionName,
if (modelFormat == null) {
throw new IllegalArgumentException("model format is null");
}
if (url != null && modelConfig == null) {
if (url != null && modelConfig == null && functionName != FunctionName.TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) {
throw new IllegalArgumentException("model config is null");
}
}
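A standalone sketch (hypothetical helper, not the ml-commons constructor itself) restating the relaxed check above: a register request with a URL but no model config is now rejected only for function names that actually need one.

import org.opensearch.ml.common.FunctionName;

class RegisterValidationSketch {
    // Hypothetical helper mirroring the constructor check above.
    static void validateModelConfig(String url, Object modelConfig, FunctionName functionName) {
        // Tokenizers and sparse encoders ship without an embedding-style model config,
        // so the null check only applies to the remaining DL function names.
        if (url != null && modelConfig == null
                && functionName != FunctionName.TOKENIZE
                && functionName != FunctionName.SPARSE_ENCODING) {
            throw new IllegalArgumentException("model config is null");
        }
    }

    public static void main(String[] args) {
        // Accepted: no model config needed for a sparse encoder.
        validateModelConfig("https://example.com/model.zip", null, FunctionName.SPARSE_ENCODING);
        try {
            // Still rejected: dense text embedding requires a model config.
            validateModelConfig("https://example.com/model.zip", null, FunctionName.TEXT_EMBEDDING);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // model config is null
        }
    }
}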
@@ -84,7 +84,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
if (modelContentHashValue == null) {
throw new IllegalArgumentException("model content hash value is null");
}
if (modelConfig == null) {
if (modelConfig == null && functionName != FunctionName.TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) {
throw new IllegalArgumentException("model config is null");
}
if (totalChunks == null) {
@@ -17,6 +17,7 @@
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -149,21 +149,27 @@ public void testClassLoader_ExecuteOutputMCorr() throws IOException {
assertArrayEquals(new long[]{1, 2}, metrics);
}

@Test
public void testClassLoader_MLInput() throws IOException {
assertTrue(MLCommonsClassLoader.canInitMLInput(FunctionName.TEXT_EMBEDDING));
private void testClassLoader_MLInput_DlModel(FunctionName functionName) throws IOException {
assertTrue(MLCommonsClassLoader.canInitMLInput(functionName));

String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}";
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();

TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(FunctionName.TEXT_EMBEDDING, new Object[]{parser, FunctionName.TEXT_EMBEDDING}, XContentParser.class, FunctionName.class);
TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(functionName, new Object[]{parser, functionName}, XContentParser.class, FunctionName.class);
assertNotNull(mlInput);
assertEquals(FunctionName.TEXT_EMBEDDING, mlInput.getFunctionName());
assertEquals(functionName, mlInput.getFunctionName());
assertEquals(2, ((TextDocsInputDataSet)mlInput.getInputDataset()).getDocs().size());
}

@Test
public void testClassLoader_MLInput() throws IOException {
testClassLoader_MLInput_DlModel(FunctionName.TEXT_EMBEDDING);
testClassLoader_MLInput_DlModel(FunctionName.TOKENIZE);
testClassLoader_MLInput_DlModel(FunctionName.SPARSE_ENCODING);
}

public enum TestEnum {
TEST
}
@@ -110,19 +110,19 @@ public void parse_LinearRegression() throws IOException {
});
}

@Test
public void parse_TextEmbedding() throws IOException {
private void parse_NLPModel(FunctionName functionName) throws IOException {
String sentence = "test sentence";
String column = "column1";
Integer position = 1;
ModelResultFilter resultFilter = ModelResultFilter.builder()
.targetResponse(Arrays.asList(column))
.targetResponsePositions(Arrays.asList(position))
.build();
TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence))
.resultFilter(resultFilter).build();
String expectedInputStr = "{\"algorithm\":\"TEXT_EMBEDDING\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}";
testParse(FunctionName.TEXT_EMBEDDING, inputDataset, expectedInputStr, parsedInput -> {

TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).resultFilter(resultFilter).build();
String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}";
expectedInputStr = expectedInputStr.replace("functionName", functionName.toString());
testParse(functionName, inputDataset, expectedInputStr, parsedInput -> {
assertNotNull(parsedInput.getInputDataset());
TextDocsInputDataSet parsedInputDataSet = (TextDocsInputDataSet) parsedInput.getInputDataset();
assertEquals(1, parsedInputDataSet.getDocs().size());
@@ -134,19 +134,33 @@ public void parse_TextEmbedding() throws IOException {
}

@Test
public void parse_TextEmbedding_NullResultFilter() throws IOException {
public void parse_NLP_Related() throws IOException {
parse_NLPModel(FunctionName.TEXT_EMBEDDING);
parse_NLPModel(FunctionName.TOKENIZE);
parse_NLPModel(FunctionName.SPARSE_ENCODING);
}

private void parse_NLPModel_NullResultFilter(FunctionName functionName) throws IOException {
String sentence = "test sentence";
TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).build();
String expectedInputStr = "{\"algorithm\":\"TEXT_EMBEDDING\",\"text_docs\":[\"test sentence\"]}";
testParse(FunctionName.TEXT_EMBEDDING, inputDataset, expectedInputStr, parsedInput -> {
String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"]}";
expectedInputStr = expectedInputStr.replace("functionName", functionName.toString());
testParse(functionName, inputDataset, expectedInputStr, parsedInput -> {
assertNotNull(parsedInput.getInputDataset());
assertEquals(1, ((TextDocsInputDataSet) parsedInput.getInputDataset()).getDocs().size());
assertEquals(sentence, ((TextDocsInputDataSet) parsedInput.getInputDataset()).getDocs().get(0));
});
}

private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr,
Consumer<MLInput> verify) throws IOException {

@Test
public void parse_NLPRelated_NullResultFilter() throws IOException {
parse_NLPModel_NullResultFilter(FunctionName.TEXT_EMBEDDING);
parse_NLPModel_NullResultFilter(FunctionName.TOKENIZE);
parse_NLPModel_NullResultFilter(FunctionName.SPARSE_ENCODING);
}

private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
input.toXContent(builder, ToXContent.EMPTY_PARAMS);
@@ -239,7 +239,7 @@ GET /_plugins/_ml/profile/models/zwla5YUB1qmVrJFlwzXJ
"models": {
"zwla5YUB1qmVrJFlwzXJ": { # model id
"model_state": "LOADED",
"predictor": "org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingModel@1a0b0793",
"predictor": "org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel@1a0b0793",
"target_worker_nodes": [ # plan to deploy model to these nodes
"0TLL4hHxRv6_G3n6y1l0BQ"
],
@@ -193,7 +193,7 @@ public List downloadPrebuiltModelMetaList(String taskId, MLRegisterModelInput re
* @param modelContentHash model content hash value
* @param listener action listener
*/
public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, ActionListener<Map<String, Object>> listener) {
public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, FunctionName functionName, ActionListener<Map<String, Object>> listener) {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);
@@ -202,7 +202,7 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo
File modelZipFile = new File(modelPath);
log.debug("download model to file {}", modelZipFile.getAbsolutePath());
DownloadUtils.download(url, modelPath, new ProgressBar());
verifyModelZipFile(modelFormat, modelPath, modelName);
verifyModelZipFile(modelFormat, modelPath, modelName, functionName);
String hash = calculateFileHash(modelZipFile);
if (hash.equals(modelContentHash)) {
List<String> chunkFiles = splitFileIntoChunks(modelZipFile, modelPartsPath, CHUNK_SIZE);
@@ -224,7 +224,7 @@ }
}
}

public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName) throws IOException {
public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName, FunctionName functionName) throws IOException {
boolean hasPtFile = false;
boolean hasOnnxFile = false;
boolean hasTokenizerFile = false;
@@ -239,7 +239,7 @@ public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePat
}
}
}
if (!hasPtFile && !hasOnnxFile) {
if (!hasPtFile && !hasOnnxFile && functionName != FunctionName.TOKENIZE) {
throw new IllegalArgumentException("Can't find model file");
}
if (!hasTokenizerFile) {
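A standalone sketch (hypothetical helper, not the verifyModelZipFile method shown above) of the relaxed zip verification: a TOKENIZE artifact may contain only tokenizer files, while every other DL function still needs a .pt or .onnx model file.

import org.opensearch.ml.common.FunctionName;

class ZipCheckSketch {
    // Hypothetical helper mirroring the model-file check in verifyModelZipFile above.
    static void checkModelFiles(boolean hasPtFile, boolean hasOnnxFile, FunctionName functionName) {
        // Only TOKENIZE may omit both the TorchScript (.pt) and ONNX model files;
        // sparse encoders and dense embedding models still need one of them.
        if (!hasPtFile && !hasOnnxFile && functionName != FunctionName.TOKENIZE) {
            throw new IllegalArgumentException("Can't find model file");
        }
    }

    public static void main(String[] args) {
        checkModelFiles(false, false, FunctionName.TOKENIZE);       // passes: tokenizer-only zip
        checkModelFiles(true, false, FunctionName.SPARSE_ENCODING); // passes: .pt file present
    }
}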