Add tokenizer and sparse encoding #1301

Merged
Changes from 1 commit
Commits
87 commits
e37956e
add tokenizer and sparse encoding
xinyual Aug 25, 2023
27d8fe0
add tokenizer and sparse encoding
xinyual Aug 25, 2023
0f8cd9c
add tokenizer and sparse encoding
xinyual Aug 25, 2023
af9ba42
add tokenizer and sparse encoding
xinyual Aug 25, 2023
6f49c22
add tokenizer and sparse encoding
xinyual Aug 25, 2023
2b5ea1b
remove special token
xinyual Aug 28, 2023
f832494
add filter
xinyual Aug 28, 2023
f49c471
try empty model
xinyual Aug 28, 2023
2be43a7
remove warm up
xinyual Aug 28, 2023
4b96521
try empty model
xinyual Aug 28, 2023
7b9f97e
add block
xinyual Aug 28, 2023
084c56f
add log
xinyual Aug 28, 2023
c2951cf
add log
xinyual Aug 28, 2023
61720c3
add log
xinyual Aug 28, 2023
99eda1a
remove log
xinyual Aug 28, 2023
5aa5698
remove pt file detect
xinyual Aug 28, 2023
1d4ddba
add log
xinyual Aug 28, 2023
b6614ca
add functionName pipeline
xinyual Aug 28, 2023
e3ca040
remove verify log
xinyual Aug 28, 2023
5eaf588
skip special token in sparse encoding
xinyual Aug 28, 2023
37a878f
skip omit tokenize config
xinyual Aug 29, 2023
4c4ac28
skip omit tokenize config-change warm up logic
xinyual Aug 29, 2023
03ce4af
reArch
xinyual Aug 29, 2023
eac759d
deduplicate
xinyual Aug 29, 2023
c0513b4
omit ml config in sparse encoding
xinyual Aug 29, 2023
0428cf4
add null config in warm up
xinyual Aug 29, 2023
51eef93
fix original test
xinyual Aug 29, 2023
17b755f
add tokenize ut half
xinyual Aug 29, 2023
418aa30
fix sparse encoding bug
xinyual Aug 30, 2023
1fe20e2
add UT for sparse encoding and tokenize
xinyual Aug 30, 2023
df1ca31
remove useless framwork type
xinyual Aug 30, 2023
16b5c89
common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java
xinyual Aug 31, 2023
2a15287
change key for tokenize
xinyual Aug 31, 2023
5ddbb24
reArch DLModel
xinyual Aug 31, 2023
401a7a6
reArch DLModel again
xinyual Aug 31, 2023
2c85cea
response format
xinyual Aug 31, 2023
64a6e6b
tokenize only one output
xinyual Aug 31, 2023
eb7ea9c
clean sparse output
xinyual Aug 31, 2023
65a72ca
clean sparse output
xinyual Aug 31, 2023
f898d6d
change UT number
xinyual Aug 31, 2023
d50d64a
remove useless predict code
xinyual Sep 4, 2023
f6428ee
remove useless part
xinyual Sep 4, 2023
46b0f00
change tokenize way
xinyual Sep 5, 2023
1b57385
reArch add textEmbedding model
xinyual Sep 5, 2023
21bff22
add tokenize logic
xinyual Sep 5, 2023
d5581b6
add abstract
xinyual Sep 5, 2023
ae39b41
clear code
xinyual Sep 5, 2023
2b8529a
fix it class
xinyual Sep 6, 2023
d32e273
fix it class
xinyual Sep 6, 2023
61ac300
add IT file
xinyual Sep 6, 2023
c957d12
reformulate
xinyual Sep 7, 2023
7924cc5
reformulate remote inference
xinyual Sep 7, 2023
71338fa
reformulate remote inference
xinyual Sep 7, 2023
09a6acd
reformulate remote inference json and array
xinyual Sep 7, 2023
84b7006
verify
xinyual Sep 8, 2023
dc30251
undo string utils
xinyual Sep 8, 2023
2ab086e
skip dummy model
xinyual Sep 11, 2023
1e3a2f3
skip dummy model
xinyual Sep 11, 2023
67db622
skip dummy model
xinyual Sep 11, 2023
2748b97
skip dummy model
xinyual Sep 11, 2023
44a5bbd
skip dummy model
xinyual Sep 11, 2023
1e3def0
skip dummy model
xinyual Sep 11, 2023
086f6b0
add inner load Model
xinyual Sep 11, 2023
abc4064
rename variable
xinyual Sep 11, 2023
976d04f
add default for idf
xinyual Sep 12, 2023
89fe98f
add ut for sparse encoding and tokenizer
xinyual Sep 12, 2023
7173590
add close model
xinyual Sep 12, 2023
98a69b2
change mock class
xinyual Sep 12, 2023
2ab6ccf
remove buffer for sparse encoding output
xinyual Sep 12, 2023
0bea7f1
change tokenize model ready logic
xinyual Sep 13, 2023
65a5ed0
rewrite input functionName
xinyual Sep 14, 2023
55e60d9
deduplicate
xinyual Sep 14, 2023
7e9f015
change UT usage
xinyual Sep 14, 2023
86cc578
fix downloadAndSplit test
xinyual Sep 14, 2023
a9cb526
fix Helper test
xinyual Sep 14, 2023
e1c9359
remove meaningless change
xinyual Sep 18, 2023
31222d0
remove complie change
xinyual Sep 19, 2023
1214437
rename
xinyual Sep 21, 2023
fdaba84
fix typo error and simplify wrap code
xinyual Sep 21, 2023
4b96a09
add comment
xinyual Sep 25, 2023
185e95b
using gson and remove useless close logic
xinyual Sep 25, 2023
4f8847c
update comment and import problem
xinyual Sep 26, 2023
7f04f4c
add static idf name
xinyual Sep 26, 2023
4e3dc78
fix format problem
xinyual Sep 26, 2023
a837080
extract an abstract model for sparse and dense sentence transformer t…
xinyual Sep 26, 2023
dca433a
fix typo error
xinyual Sep 26, 2023
73b3f02
remove duplicate tokenizer file, fix import problem and add comment f…
xinyual Sep 27, 2023
add tokenize ut half
Signed-off-by: xinyual <xinyual@amazon.com>
xinyual committed Sep 26, 2023
commit 17b755fb0bf0999e181f73398689716da01810a3
@@ -0,0 +1,239 @@
package org.opensearch.ml.engine.algorithms.tokenize;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.engine.utils.FileUtils;

import java.io.File;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.util.*;

import static org.junit.Assert.assertEquals;
import static org.opensearch.ml.engine.algorithms.DLModel.*;
import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingModel.SENTENCE_EMBEDDING;
public class TokenizeModelTest {
    @Rule
    public ExpectedException exceptionRule = ExpectedException.none();

    private File modelZipFile;
    private String modelId;
    private String modelName;
    private FunctionName functionName;
    private String version;
    private MLModel model;
    private ModelHelper modelHelper;
    private Map<String, Object> params;
    private TokenizerModel tokenizerModel;
    private Path mlCachePath;
    private Path mlConfigPath;
    private TextDocsInputDataSet inputDataSet;
    private int dimension = 384;
    private MLEngine mlEngine;
    private Encryptor encryptor;

    @Before
    public void setUp() throws URISyntaxException {
        mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID());
        encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
        mlEngine = new MLEngine(mlCachePath, encryptor);
        modelId = "test_model_id";
        modelName = "test_model_name";
        functionName = FunctionName.TEXT_EMBEDDING;
        version = "1";
        model = MLModel.builder()
            .modelFormat(MLModelFormat.TORCH_SCRIPT)
            .name("test_model_name")
            .modelId("test_model_id")
            .algorithm(FunctionName.TOKENIZE)
            .version("1.0.0")
            .modelState(MLModelState.TRAINED)
            .build();
        modelHelper = new ModelHelper(mlEngine);
        params = new HashMap<>();
        modelZipFile = new File(getClass().getResource("all-MiniLM-L6-v2_torchscript_sentence-transformer.zip").toURI());
        params.put(MODEL_ZIP_FILE, modelZipFile);
        params.put(MODEL_HELPER, modelHelper);
        params.put(ML_ENGINE, mlEngine);
        tokenizerModel = new TokenizerModel();

        inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("today is sunny", "That is a happy dog")).build();
    }

    @Test
    public void initModel_predict_Tokenize_SmallModel() throws URISyntaxException {
        Map<String, Object> params = new HashMap<>();
        params.put(MODEL_HELPER, modelHelper);
        params.put(MODEL_ZIP_FILE, new File(getClass().getResource("demo_tokenize.zip").toURI()));
        params.put(ML_ENGINE, mlEngine);
        MLModel smallModel = model.toBuilder().build();
        tokenizerModel.initModel(smallModel, params, encryptor);
        MLInput mlInput = MLInput.builder().algorithm(FunctionName.TOKENIZE).inputDataset(inputDataSet).build();
        ModelTensorOutput output = (ModelTensorOutput) tokenizerModel.predict(mlInput);
        List<ModelTensors> mlModelOutputs = output.getMlModelOutputs();
        assertEquals(2, mlModelOutputs.size());
        for (int i = 0; i < mlModelOutputs.size(); i++) {
            ModelTensors tensors = mlModelOutputs.get(i);
            List<ModelTensor> mlModelTensors = tensors.getMlModelTensors();
            assertEquals(2, mlModelTensors.size());
            ModelTensor tensor = mlModelTensors.get(0);
            Map<String, ?> resultMap = tensor.getDataAsMap();
            assertEquals(4, resultMap.size());
        }
        tokenizerModel.close();
    }

    @Test
    public void initModel_predict_Tokenize_SmallModel_ResultFilter() {
        tokenizerModel.initModel(model, params, encryptor);
        ModelResultFilter resultFilter = ModelResultFilter.builder()
            .returnNumber(true)
            .targetResponse(Arrays.asList(SENTENCE_EMBEDDING))
            .build();
        TextDocsInputDataSet textDocsInputDataSet = inputDataSet.toBuilder().resultFilter(resultFilter).build();
        MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
        ModelTensorOutput output = (ModelTensorOutput) tokenizerModel.predict(mlInput);
        List<ModelTensors> mlModelOutputs = output.getMlModelOutputs();
        assertEquals(2, mlModelOutputs.size());
        for (int i = 0; i < mlModelOutputs.size(); i++) {
            ModelTensors tensors = mlModelOutputs.get(i);
            List<ModelTensor> mlModelTensors = tensors.getMlModelTensors();
            assertEquals(1, mlModelTensors.size());
        }
        tokenizerModel.close();
    }

    @Test
    public void initModel_NullModelHelper() throws URISyntaxException {
        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage("model helper is null");
        Map<String, Object> params = new HashMap<>();
        params.put(MODEL_ZIP_FILE, new File(getClass().getResource("demo_tokenize.zip").toURI()));
        tokenizerModel.initModel(model, params, encryptor);
    }

    @Test
    public void initModel_NullMLEngine() throws URISyntaxException {
        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage("ML engine is null");
        Map<String, Object> params = new HashMap<>();
        params.put(MODEL_ZIP_FILE, new File(getClass().getResource("demo_tokenize.zip").toURI()));
        params.put(MODEL_HELPER, modelHelper);
        tokenizerModel.initModel(model, params, encryptor);
    }

    @Test
    public void initModel_NullModelId() {
        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage("model id is null");
        model.setModelId(null);
        tokenizerModel.initModel(model, params, encryptor);
    }

    @Test
    public void initModel_WrongModelFile() throws URISyntaxException {
        try {
            Map<String, Object> params = new HashMap<>();
            params.put(MODEL_HELPER, modelHelper);
            params.put(MODEL_ZIP_FILE, new File(getClass().getResource("wrong_zip_with_2_pt_file.zip").toURI()));
            params.put(ML_ENGINE, mlEngine);
            tokenizerModel.initModel(model, params, encryptor);
        } catch (Exception e) {
            assertEquals(MLException.class, e.getClass());
            Throwable rootCause = ExceptionUtils.getRootCause(e);
            assertEquals(IllegalArgumentException.class, rootCause.getClass());
            assertEquals("found multiple models", rootCause.getMessage());
        }
    }

    @Test
    public void initModel_WrongFunctionName() {
        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage("wrong function name");
        MLModel mlModel = model.toBuilder().algorithm(FunctionName.KMEANS).build();
        tokenizerModel.initModel(mlModel, params, encryptor);
    }

    @Test
    public void predict_NullModelHelper() {
        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage("model not deployed");
        tokenizerModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
    }

    @Test
    public void predict_NullModelId() {
        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage("model not deployed");
        model.setModelId(null);
        try {
            tokenizerModel.initModel(model, params, encryptor);
        } catch (Exception e) {
            assertEquals("model id is null", e.getMessage());
        }
        tokenizerModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
    }

    @Test
    public void predict_AfterModelClosed() {
        exceptionRule.expect(MLException.class);
        exceptionRule.expectMessage("Failed to inference TOKENIZE");
        tokenizerModel.initModel(model, params, encryptor);
        tokenizerModel.close();
        tokenizerModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
    }

    @Test
    public void parseModelTensorOutput_NullOutput() {
        exceptionRule.expect(MLException.class);
        exceptionRule.expectMessage("No output generated");
        tokenizerModel.parseModelTensorOutput(null, null);
    }

    @Test
    public void predict_BeforeInitingModel() {
        exceptionRule.expect(IllegalArgumentException.class);
        exceptionRule.expectMessage("model not deployed");
        tokenizerModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(), model);
    }

    @After
    public void tearDown() {
        FileUtils.deleteFileQuietly(mlCachePath);
    }

    private int findSentenceEmbeddingPosition(ModelTensors modelTensors) {
        List<ModelTensor> mlModelTensors = modelTensors.getMlModelTensors();
        for (int i = 0; i < mlModelTensors.size(); i++) {
            ModelTensor mlModelTensor = mlModelTensors.get(i);
            if (SENTENCE_EMBEDDING.equals(mlModelTensor.getName())) {
                return i;
            }
        }
        throw new ResourceNotFoundException("no sentence embedding found");
    }
}
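The test above asserts that each tokenized document comes back as a token-to-value map (via `getDataAsMap`), which is the general shape of a tokenizer or sparse-encoding result. As a rough illustration of that output shape only, here is a minimal stdlib sketch; `TokenizerSketch` and its whitespace-splitting logic are hypothetical and are not the `TokenizerModel` implementation under test, which uses a real model artifact.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch, NOT the OpenSearch TokenizerModel: produces a
// token -> count map, the rough shape of a sparse encoding result.
public class TokenizerSketch {
    public static Map<String, Integer> tokenize(String text) {
        Map<String, Integer> counts = new HashMap<>();
        // Naive whitespace tokenization; real models use trained tokenizers.
        for (String token : text.toLowerCase().split("\\s+")) {
            if (!token.isEmpty()) {
                counts.merge(token, 1, Integer::sum);
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        // One map per input document, mirroring the per-document tensors in the test.
        System.out.println(TokenizerSketch.tokenize("today is sunny"));
        System.out.println(TokenizerSketch.tokenize("That is a happy dog"));
    }
}
```

In the real test, the analogous check is `assertEquals(4, resultMap.size())` on the map returned by `ModelTensor.getDataAsMap()`.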