Add tokenizer and sparse encoding (#1301)
* add tokenizer and sparse encoding

Signed-off-by: xinyual <xinyual@amazon.com>

* add tokenizer and sparse encoding

Signed-off-by: xinyual <xinyual@amazon.com>

* add tokenizer and sparse encoding

Signed-off-by: xinyual <xinyual@amazon.com>

* add tokenizer and sparse encoding

Signed-off-by: xinyual <xinyual@amazon.com>

* add tokenizer and sparse encoding

Signed-off-by: xinyual <xinyual@amazon.com>

* remove special token

Signed-off-by: xinyual <xinyual@amazon.com>

* add filter

Signed-off-by: xinyual <xinyual@amazon.com>

* try empty model

Signed-off-by: xinyual <xinyual@amazon.com>

* remove warm up

Signed-off-by: xinyual <xinyual@amazon.com>

* try empty model

Signed-off-by: xinyual <xinyual@amazon.com>

* add block

Signed-off-by: xinyual <xinyual@amazon.com>

* add log

Signed-off-by: xinyual <xinyual@amazon.com>

* add log

Signed-off-by: xinyual <xinyual@amazon.com>

* add log

Signed-off-by: xinyual <xinyual@amazon.com>

* remove log

Signed-off-by: xinyual <xinyual@amazon.com>

* remove pt file detect

Signed-off-by: xinyual <xinyual@amazon.com>

* add log

Signed-off-by: xinyual <xinyual@amazon.com>

* add functionName pipeline

Signed-off-by: xinyual <xinyual@amazon.com>

* remove verify log

Signed-off-by: xinyual <xinyual@amazon.com>

* skip special token in sparse encoding

Signed-off-by: xinyual <xinyual@amazon.com>

* skip omit tokenize config

Signed-off-by: xinyual <xinyual@amazon.com>

* skip omit tokenize config-change warm up logic

Signed-off-by: xinyual <xinyual@amazon.com>

* reArch

Signed-off-by: xinyual <xinyual@amazon.com>

* deduplicate

Signed-off-by: xinyual <xinyual@amazon.com>

* omit ml config in sparse encoding

Signed-off-by: xinyual <xinyual@amazon.com>

* add null config in warm up

Signed-off-by: xinyual <xinyual@amazon.com>

* fix original test

Signed-off-by: xinyual <xinyual@amazon.com>

* add tokenize ut half

Signed-off-by: xinyual <xinyual@amazon.com>

* fix sparse encoding bug

Signed-off-by: xinyual <xinyual@amazon.com>

* add UT for sparse encoding and tokenize

Signed-off-by: xinyual <xinyual@amazon.com>

* remove useless framework type

Signed-off-by: xinyual <xinyual@amazon.com>

* common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java

Signed-off-by: xinyual <xinyual@amazon.com>

* change key for tokenize

Signed-off-by: xinyual <xinyual@amazon.com>

* reArch DLModel

Signed-off-by: xinyual <xinyual@amazon.com>

* reArch DLModel again

Signed-off-by: xinyual <xinyual@amazon.com>

* response format

Signed-off-by: xinyual <xinyual@amazon.com>

* tokenize only one output

Signed-off-by: xinyual <xinyual@amazon.com>

* clean sparse output

Signed-off-by: xinyual <xinyual@amazon.com>

* clean sparse output

Signed-off-by: xinyual <xinyual@amazon.com>

* change UT number

Signed-off-by: xinyual <xinyual@amazon.com>

* remove useless predict code

Signed-off-by: xinyual <xinyual@amazon.com>

* remove useless part

Signed-off-by: xinyual <xinyual@amazon.com>

* change tokenize way

Signed-off-by: xinyual <xinyual@amazon.com>

* reArch add textEmbedding model

Signed-off-by: xinyual <xinyual@amazon.com>

* add tokenize logic

Signed-off-by: xinyual <xinyual@amazon.com>

* add abstract

Signed-off-by: xinyual <xinyual@amazon.com>

* clear code

Signed-off-by: xinyual <xinyual@amazon.com>

* fix it class

Signed-off-by: xinyual <xinyual@amazon.com>

* fix it class

Signed-off-by: xinyual <xinyual@amazon.com>

* add IT file

Signed-off-by: xinyual <xinyual@amazon.com>

* reformulate

Signed-off-by: xinyual <xinyual@amazon.com>

* reformulate remote inference

Signed-off-by: xinyual <xinyual@amazon.com>

* reformulate remote inference

Signed-off-by: xinyual <xinyual@amazon.com>

* reformulate remote inference json and array

Signed-off-by: xinyual <xinyual@amazon.com>

* verify

Signed-off-by: xinyual <xinyual@amazon.com>

* undo string utils

Signed-off-by: xinyual <xinyual@amazon.com>

* skip dummy model

Signed-off-by: xinyual <xinyual@amazon.com>

* skip dummy model

Signed-off-by: xinyual <xinyual@amazon.com>

* skip dummy model

Signed-off-by: xinyual <xinyual@amazon.com>

* skip dummy model

Signed-off-by: xinyual <xinyual@amazon.com>

* skip dummy model

Signed-off-by: xinyual <xinyual@amazon.com>

* skip dummy model

Signed-off-by: xinyual <xinyual@amazon.com>

* add inner load Model

Signed-off-by: xinyual <xinyual@amazon.com>

* rename variable

Signed-off-by: xinyual <xinyual@amazon.com>

* add default for idf

Signed-off-by: xinyual <xinyual@amazon.com>

* add ut for sparse encoding and tokenizer

Signed-off-by: xinyual <xinyual@amazon.com>

* add close model

Signed-off-by: xinyual <xinyual@amazon.com>

* change mock class

Signed-off-by: xinyual <xinyual@amazon.com>

* remove buffer for sparse encoding output

Signed-off-by: xinyual <xinyual@amazon.com>

* change tokenize model ready logic

Signed-off-by: xinyual <xinyual@amazon.com>

* rewrite input functionName

Signed-off-by: xinyual <xinyual@amazon.com>

* deduplicate

Signed-off-by: xinyual <xinyual@amazon.com>

* change UT usage

Signed-off-by: xinyual <xinyual@amazon.com>

* fix downloadAndSplit test

Signed-off-by: xinyual <xinyual@amazon.com>

* fix Helper  test

Signed-off-by: xinyual <xinyual@amazon.com>

* remove meaningless change

Signed-off-by: xinyual <xinyual@amazon.com>

* remove compile change

Signed-off-by: xinyual <xinyual@amazon.com>

* rename

Signed-off-by: xinyual <xinyual@amazon.com>

* fix typo error and simplify wrap code

Signed-off-by: xinyual <xinyual@amazon.com>

* add comment

Signed-off-by: xinyual <xinyual@amazon.com>

* using gson and remove useless close logic

Signed-off-by: xinyual <xinyual@amazon.com>

* update comment and import problem

Signed-off-by: xinyual <xinyual@amazon.com>

* add static idf name

Signed-off-by: xinyual <xinyual@amazon.com>

* fix format problem

Signed-off-by: xinyual <xinyual@amazon.com>

* extract an abstract model for sparse and dense sentence transformer translator

Signed-off-by: xinyual <xinyual@amazon.com>

* fix typo error

Signed-off-by: xinyual <xinyual@amazon.com>

* remove duplicate tokenizer file, fix import problem and add comment for tokenizer model

Signed-off-by: xinyual <xinyual@amazon.com>

---------

Signed-off-by: xinyual <xinyual@amazon.com>
(cherry picked from commit 31a4e25)
xinyual authored and github-actions[bot] committed Sep 27, 2023
1 parent 9b6af2e commit 9ddd3e0
Showing 34 changed files with 1,101 additions and 233 deletions.

@@ -42,6 +42,7 @@ public class CommonValue {
     public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;
     public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
     public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2;
+    public static final String ML_MAP_RESPONSE_KEY = "response";
     public static final String USER_FIELD_MAPPING = " \""
         + CommonValue.USER
         + "\": {\n"

@@ -17,6 +17,8 @@ public enum FunctionName {
     RCF_SUMMARIZE,
     LOGISTIC_REGRESSION,
     TEXT_EMBEDDING,
+    SPARSE_ENCODING,
+    SPARSE_TOKENIZE,
     METRICS_CORRELATION,
     REMOTE;

@@ -33,7 +35,7 @@ public static FunctionName from(String value) {
      * @return true for deep learning model.
      */
     public static boolean isDLModel(FunctionName functionName) {
-        if (functionName == TEXT_EMBEDDING) {
+        if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING || functionName == SPARSE_TOKENIZE) {
             return true;
         }
         return false;
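
The `isDLModel` check is what callers use to gate deep-learning-only behavior, so both new sparse function names piggyback on the existing TEXT_EMBEDDING machinery. A minimal sketch of the resulting behavior; the wrapper class and `main` harness are illustrative only, not part of this commit:

```java
import org.opensearch.ml.common.FunctionName;

public class IsDLModelSketch {
    public static void main(String[] args) {
        // Both sparse function names now take the deep-learning code path,
        // alongside dense text embedding.
        assert FunctionName.isDLModel(FunctionName.TEXT_EMBEDDING);
        assert FunctionName.isDLModel(FunctionName.SPARSE_ENCODING);
        assert FunctionName.isDLModel(FunctionName.SPARSE_TOKENIZE);
        // Remote connectors and classical algorithms still do not.
        assert !FunctionName.isDLModel(FunctionName.REMOTE);
        assert !FunctionName.isDLModel(FunctionName.LOGISTIC_REGRESSION);
    }
}
```

Run with `java -ea` so the assertions are actually checked.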

@@ -23,6 +23,7 @@
 import java.util.Map;
 import java.util.Optional;

+import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY;
 import static org.opensearch.ml.common.utils.StringUtils.isJson;

 @Getter
@@ -101,7 +102,7 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
             return;
         }
         if (response instanceof String && isJson((String) response)) {
-            Map<String, Object> data = StringUtils.fromJson((String) response, "response");
+            Map<String, Object> data = StringUtils.fromJson((String) response, ML_MAP_RESPONSE_KEY);
             modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build());
         } else {
             Map<String, Object> map = new HashMap<>();
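
Functionally this hunk is a refactor: the JSON branch now keys off the shared `ML_MAP_RESPONSE_KEY` constant introduced in `CommonValue` above instead of a repeated `"response"` literal. A sketch of the branch in isolation; the response payload is hypothetical:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.utils.StringUtils;

import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY;

// ...
String response = "{\"embedding\": [0.1, 0.2, 0.3]}";  // hypothetical remote-model reply
List<ModelTensor> modelTensors = new ArrayList<>();
// fromJson parses the string into a Map, using the shared key constant where a
// wrapping key is needed; the tensor name itself stays the literal "response".
Map<String, Object> data = StringUtils.fromJson(response, ML_MAP_RESPONSE_KEY);
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build());
```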

@@ -239,7 +239,7 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
             }
         }
         MLInputDataset inputDataSet = null;
-        if (algorithm == FunctionName.TEXT_EMBEDDING) {
+        if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.SPARSE_ENCODING || algorithm == FunctionName.SPARSE_TOKENIZE) {
             ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
             inputDataSet = new TextDocsInputDataSet(textDocs, filter);
         }
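
With the branch widened, a sparse predict payload parses through the same text-docs path as dense embedding. A short sketch reusing the parser wiring from the tests further down; the two-document payload mirrors those tests:

```java
import java.util.Collections;

import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.search.SearchModule;

// ...
String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"]}";  // result_filter is optional, per the null-filter test below
XContentParser parser = XContentType.JSON.xContent().createParser(
        new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
        null, jsonStr);
parser.nextToken();
// SPARSE_ENCODING and SPARSE_TOKENIZE now yield a TextDocsInputDataSet, just like TEXT_EMBEDDING.
MLInput mlInput = MLInput.parse(parser, FunctionName.SPARSE_ENCODING.name());
```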

@@ -25,7 +25,7 @@
  * ML input class which supports a list fo text docs.
  * This class can be used for TEXT_EMBEDDING model.
  */
-@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING})
+@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.SPARSE_TOKENIZE})
 public class TextDocsMLInput extends MLInput {
     public static final String TEXT_DOCS_FIELD = "text_docs";
     public static final String RESULT_FILTER_FIELD = "result_filter";

@@ -15,7 +15,6 @@
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.ml.common.AccessMode;
 import org.opensearch.ml.common.FunctionName;
-import org.opensearch.ml.common.MLCommonsClassLoader;
 import org.opensearch.ml.common.connector.Connector;
 import org.opensearch.ml.common.model.MLModelConfig;
 import org.opensearch.ml.common.model.MLModelFormat;
@@ -104,7 +103,7 @@ public MLRegisterModelInput(FunctionName functionName,
         if (modelFormat == null) {
             throw new IllegalArgumentException("model format is null");
         }
-        if (url != null && modelConfig == null) {
+        if (url != null && modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model configuration. Currently, we only support one type of sparse model, which is pretrained, and it doesn't necessitate a model configuration.
             throw new IllegalArgumentException("model config is null");
         }
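
Read as a predicate, the relaxed guard (repeated in `MLRegisterModelMetaInput` below) says a model config is mandatory unless the function name is one of the two pretrained sparse types, which ship without one. A sketch; the helper name is illustrative, not part of the commit:

```java
// A model config may be omitted only for the pretrained sparse function names;
// every other URL-registered local model must still carry one.
private static boolean requiresModelConfig(FunctionName functionName) {
    return functionName != FunctionName.SPARSE_TOKENIZE
            && functionName != FunctionName.SPARSE_ENCODING;
}

// Equivalent guard above:
// if (url != null && modelConfig == null && requiresModelConfig(functionName)) { throw ...; }
```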

@@ -84,7 +84,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m
         if (modelContentHashValue == null) {
             throw new IllegalArgumentException("model content hash value is null");
         }
-        if (modelConfig == null) {
+        if (modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model configuration. Currently, we only support one type of sparse model, which is pretrained, and it doesn't necessitate a model configuration.
             throw new IllegalArgumentException("model config is null");
         }
         if (totalChunks == null) {

@@ -149,21 +149,27 @@ public void testClassLoader_ExecuteOutputMCorr() throws IOException {
         assertArrayEquals(new long[]{1, 2}, metrics);
     }

-    @Test
-    public void testClassLoader_MLInput() throws IOException {
-        assertTrue(MLCommonsClassLoader.canInitMLInput(FunctionName.TEXT_EMBEDDING));
+    private void testClassLoader_MLInput_DlModel(FunctionName functionName) throws IOException {
+        assertTrue(MLCommonsClassLoader.canInitMLInput(functionName));

         String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}";
         XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
             Collections.emptyList()).getNamedXContents()), null, jsonStr);
         parser.nextToken();

-        TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(FunctionName.TEXT_EMBEDDING, new Object[]{parser, FunctionName.TEXT_EMBEDDING}, XContentParser.class, FunctionName.class);
+        TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(functionName, new Object[]{parser, functionName}, XContentParser.class, FunctionName.class);
         assertNotNull(mlInput);
-        assertEquals(FunctionName.TEXT_EMBEDDING, mlInput.getFunctionName());
+        assertEquals(functionName, mlInput.getFunctionName());
         assertEquals(2, ((TextDocsInputDataSet)mlInput.getInputDataset()).getDocs().size());
     }

+    @Test
+    public void testClassLoader_MLInput() throws IOException {
+        testClassLoader_MLInput_DlModel(FunctionName.TEXT_EMBEDDING);
+        testClassLoader_MLInput_DlModel(FunctionName.SPARSE_TOKENIZE);
+        testClassLoader_MLInput_DlModel(FunctionName.SPARSE_ENCODING);
+    }
+
     public enum TestEnum {
         TEST
     }

@@ -10,11 +10,9 @@
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
-import org.opensearch.core.common.Strings;
 import org.opensearch.common.io.stream.BytesStreamOutput;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.common.settings.Settings;
-import org.opensearch.common.xcontent.XContentFactory;
 import org.opensearch.common.xcontent.XContentType;
 import org.opensearch.core.xcontent.MediaTypeRegistry;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
@@ -110,19 +108,19 @@ public void parse_LinearRegression() throws IOException {
         });
     }

-    @Test
-    public void parse_TextEmbedding() throws IOException {
+    private void parse_NLPModel(FunctionName functionName) throws IOException {
         String sentence = "test sentence";
         String column = "column1";
         Integer position = 1;
         ModelResultFilter resultFilter = ModelResultFilter.builder()
             .targetResponse(Arrays.asList(column))
             .targetResponsePositions(Arrays.asList(position))
             .build();
-        TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence))
-            .resultFilter(resultFilter).build();
-        String expectedInputStr = "{\"algorithm\":\"TEXT_EMBEDDING\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}";
-        testParse(FunctionName.TEXT_EMBEDDING, inputDataset, expectedInputStr, parsedInput -> {
+
+        TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).resultFilter(resultFilter).build();
+        String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}";
+        expectedInputStr = expectedInputStr.replace("functionName", functionName.toString());
+        testParse(functionName, inputDataset, expectedInputStr, parsedInput -> {
             assertNotNull(parsedInput.getInputDataset());
             TextDocsInputDataSet parsedInputDataSet = (TextDocsInputDataSet) parsedInput.getInputDataset();
             assertEquals(1, parsedInputDataSet.getDocs().size());
@@ -134,19 +132,33 @@ public void parse_TextEmbedding() throws IOException {
     }

     @Test
-    public void parse_TextEmbedding_NullResultFilter() throws IOException {
+    public void parse_NLP_Related() throws IOException {
+        parse_NLPModel(FunctionName.TEXT_EMBEDDING);
+        parse_NLPModel(FunctionName.SPARSE_TOKENIZE);
+        parse_NLPModel(FunctionName.SPARSE_ENCODING);
+    }
+
+    private void parse_NLPModel_NullResultFilter(FunctionName functionName) throws IOException {
         String sentence = "test sentence";
         TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).build();
-        String expectedInputStr = "{\"algorithm\":\"TEXT_EMBEDDING\",\"text_docs\":[\"test sentence\"]}";
-        testParse(FunctionName.TEXT_EMBEDDING, inputDataset, expectedInputStr, parsedInput -> {
+        String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"]}";
+        expectedInputStr = expectedInputStr.replace("functionName", functionName.toString());
+        testParse(functionName, inputDataset, expectedInputStr, parsedInput -> {
             assertNotNull(parsedInput.getInputDataset());
             assertEquals(1, ((TextDocsInputDataSet) parsedInput.getInputDataset()).getDocs().size());
             assertEquals(sentence, ((TextDocsInputDataSet) parsedInput.getInputDataset()).getDocs().get(0));
         });
     }

-    private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr,
-                           Consumer<MLInput> verify) throws IOException {
+    @Test
+    public void parse_NLPRelated_NullResultFilter() throws IOException {
+        parse_NLPModel_NullResultFilter(FunctionName.TEXT_EMBEDDING);
+        parse_NLPModel_NullResultFilter(FunctionName.SPARSE_TOKENIZE);
+        parse_NLPModel_NullResultFilter(FunctionName.SPARSE_ENCODING);
+    }
+
+    private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer<MLInput> verify) throws IOException {
         MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build();
         XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
         input.toXContent(builder, ToXContent.EMPTY_PARAMS);

@@ -239,7 +239,7 @@ GET /_plugins/_ml/profile/models/zwla5YUB1qmVrJFlwzXJ
     "models": {
         "zwla5YUB1qmVrJFlwzXJ": { # model id
             "model_state": "LOADED",
-            "predictor": "org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingModel@1a0b0793",
+            "predictor": "org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel@1a0b0793",
             "target_worker_nodes": [ # plan to deploy model to these nodes
                 "0TLL4hHxRv6_G3n6y1l0BQ"
             ],

@@ -191,7 +191,7 @@ public List downloadPrebuiltModelMetaList(String taskId, MLRegisterModelInput re
      * @param modelContentHash model content hash value
      * @param listener action listener
      */
-    public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, ActionListener<Map<String, Object>> listener) {
+    public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String modelName, String version, String url, String modelContentHash, FunctionName functionName, ActionListener<Map<String, Object>> listener) {
         try {
             AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                 Path registerModelPath = mlEngine.getRegisterModelPath(taskId, modelName, version);
@@ -200,7 +200,7 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo
                 File modelZipFile = new File(modelPath);
                 log.debug("download model to file {}", modelZipFile.getAbsolutePath());
                 DownloadUtils.download(url, modelPath, new ProgressBar());
-                verifyModelZipFile(modelFormat, modelPath, modelName);
+                verifyModelZipFile(modelFormat, modelPath, modelName, functionName);
                 String hash = calculateFileHash(modelZipFile);
                 if (hash.equals(modelContentHash)) {
                     List<String> chunkFiles = splitFileIntoChunks(modelZipFile, modelPartsPath, CHUNK_SIZE);
@@ -222,7 +222,7 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo
         }
     }

-    public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName) throws IOException {
+    public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePath, String modelName, FunctionName functionName) throws IOException {
         boolean hasPtFile = false;
         boolean hasOnnxFile = false;
         boolean hasTokenizerFile = false;
@@ -237,7 +237,7 @@ public void verifyModelZipFile(MLModelFormat modelFormat, String modelZipFilePat
             }
         }
-        if (!hasPtFile && !hasOnnxFile) {
+        if (!hasPtFile && !hasOnnxFile && functionName != FunctionName.SPARSE_TOKENIZE) { // sparse tokenizer model doesn't need model file.
             throw new IllegalArgumentException("Can't find model file");
         }
         if (!hasTokenizerFile) {
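
Threading `FunctionName` down to verification is what makes tokenizer-only archives registrable: a SPARSE_TOKENIZE zip may ship just a tokenizer file, while every other local model still needs a `.pt` or `.onnx` entry. A hedged sketch of the distinction; the paths, model names, and `TORCH_SCRIPT` format are illustrative:

```java
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelFormat;

// Passes: sparse tokenizer archives are exempt from the model-file check
// (a tokenizer file inside the zip is still required).
modelHelper.verifyModelZipFile(MLModelFormat.TORCH_SCRIPT,
        "/tmp/sparse_tokenizer.zip", "demo-tokenizer", FunctionName.SPARSE_TOKENIZE);

// Throws IllegalArgumentException("Can't find model file"): the same archive,
// with no .pt/.onnx inside, registered under any other function name.
modelHelper.verifyModelZipFile(MLModelFormat.TORCH_SCRIPT,
        "/tmp/sparse_tokenizer.zip", "demo-embedding", FunctionName.TEXT_EMBEDDING);
```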