|
10 | 10 | import org.elasticsearch.test.ESTestCase;
|
11 | 11 | import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults;
|
12 | 12 | import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
|
| 13 | +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2Tokenization; |
13 | 14 | import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig;
|
14 | 15 | import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
|
15 | 16 | import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
|
16 | 17 | import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizationResult;
|
17 | 18 | import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
|
| 19 | +import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2Tokenizer; |
18 | 20 | import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
|
19 | 21 | import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
|
20 | 22 |
|
21 | 23 | import java.io.IOException;
|
22 | 24 | import java.util.List;
|
23 | 25 |
|
24 | 26 | import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizerTests.TEST_CASED_VOCAB;
|
| 27 | +import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2TokenizerTests.TEST_CASE_SCORES; |
| 28 | +import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2TokenizerTests.TEST_CASE_VOCAB; |
25 | 29 | import static org.hamcrest.Matchers.closeTo;
|
26 | 30 | import static org.hamcrest.Matchers.equalTo;
|
27 | 31 | import static org.hamcrest.Matchers.is;
|
@@ -62,6 +66,33 @@ public void testProcessor() throws IOException {
|
62 | 66 | assertThat(result.predictedValue(), closeTo(42, 1e-6));
|
63 | 67 | }
|
64 | 68 |
|
| 69 | + public void testBalancedTruncationWithLongInput() throws IOException { |
| 70 | + String question = "Is Elasticsearch scalable?"; |
| 71 | + StringBuilder longInputBuilder = new StringBuilder(); |
| 72 | + for (int i = 0; i < 1000; i++) { |
| 73 | + longInputBuilder.append(TEST_CASE_VOCAB.get(randomIntBetween(0, TEST_CASE_VOCAB.size() - 1))).append(i).append(" "); |
| 74 | + } |
| 75 | + String longInput = longInputBuilder.toString().trim(); |
| 76 | + |
| 77 | + DebertaV2Tokenization tokenization = new DebertaV2Tokenization(false, true, null, Tokenization.Truncate.BALANCED, -1); |
| 78 | + DebertaV2Tokenizer tokenizer = DebertaV2Tokenizer.builder(TEST_CASE_VOCAB, TEST_CASE_SCORES, tokenization).build(); |
| 79 | + TextSimilarityConfig textSimilarityConfig = new TextSimilarityConfig( |
| 80 | + question, |
| 81 | + new VocabularyConfig(""), |
| 82 | + tokenization, |
| 83 | + "result", |
| 84 | + TextSimilarityConfig.SpanScoreFunction.MAX |
| 85 | + ); |
| 86 | + TextSimilarityProcessor processor = new TextSimilarityProcessor(tokenizer); |
| 87 | + TokenizationResult tokenizationResult = processor.getRequestBuilder(textSimilarityConfig) |
| 88 | + .buildRequest(List.of(longInput), "1", Tokenization.Truncate.BALANCED, -1, null) |
| 89 | + .tokenization(); |
| 90 | + |
| 91 | + // Assert that the tokenization result is as expected |
| 92 | + assertThat(tokenizationResult.anyTruncated(), is(true)); |
| 93 | + assertThat(tokenizationResult.getTokenization(0).tokenIds().length, equalTo(512)); |
| 94 | + } |
| 95 | + |
65 | 96 | public void testResultFunctions() {
|
66 | 97 | BertTokenization tokenization = new BertTokenization(false, true, 384, Tokenization.Truncate.NONE, 128);
|
67 | 98 | BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, tokenization).build();
|
|
0 commit comments