Skip to content

Commit a536d2f

Browse files
[ML] Fix for Deberta tokenizer when input sequence exceeds 512 tokens (elastic#117595)
* Add test and fix
* Update docs/changelog/117595.yaml
* Remove test which wasn't working

(cherry picked from commit 433a00c)
1 parent 4975f1d commit a536d2f

File tree

4 files changed

+61
-2
lines changed

4 files changed

+61
-2
lines changed

docs/changelog/117595.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 117595
2+
summary: Fix for Deberta tokenizer when input sequence exceeds 512 tokens
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/NlpTokenizer.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,29 @@ public List<TokenizationResult.Tokens> tokenize(String seq1, String seq2, Tokeni
331331
tokenIdsSeq2 = tokenIdsSeq2.subList(0, maxSequenceLength() - extraTokens - tokenIdsSeq1.size());
332332
tokenPositionMapSeq2 = tokenPositionMapSeq2.subList(0, maxSequenceLength() - extraTokens - tokenIdsSeq1.size());
333333
}
334+
case BALANCED -> {
335+
isTruncated = true;
336+
int firstSequenceLength = 0;
337+
338+
if (tokenIdsSeq2.size() > (maxSequenceLength() - getNumExtraTokensForSeqPair()) / 2) {
339+
firstSequenceLength = min(tokenIdsSeq1.size(), (maxSequenceLength() - getNumExtraTokensForSeqPair()) / 2);
340+
} else {
341+
firstSequenceLength = min(
342+
tokenIdsSeq1.size(),
343+
maxSequenceLength() - tokenIdsSeq2.size() - getNumExtraTokensForSeqPair()
344+
);
345+
}
346+
int secondSequenceLength = min(
347+
tokenIdsSeq2.size(),
348+
maxSequenceLength() - firstSequenceLength - getNumExtraTokensForSeqPair()
349+
);
350+
351+
tokenIdsSeq1 = tokenIdsSeq1.subList(0, firstSequenceLength);
352+
tokenPositionMapSeq1 = tokenPositionMapSeq1.subList(0, firstSequenceLength);
353+
354+
tokenIdsSeq2 = tokenIdsSeq2.subList(0, secondSequenceLength);
355+
tokenPositionMapSeq2 = tokenPositionMapSeq2.subList(0, secondSequenceLength);
356+
}
334357
case NONE -> throw ExceptionsHelper.badRequestException(
335358
"Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]",
336359
numTokens,

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessorTests.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,22 @@
1010
import org.elasticsearch.test.ESTestCase;
1111
import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults;
1212
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
13+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2Tokenization;
1314
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig;
1415
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
1516
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
1617
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizationResult;
1718
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
19+
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2Tokenizer;
1820
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
1921
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
2022

2123
import java.io.IOException;
2224
import java.util.List;
2325

2426
import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizerTests.TEST_CASED_VOCAB;
27+
import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2TokenizerTests.TEST_CASE_SCORES;
28+
import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DebertaV2TokenizerTests.TEST_CASE_VOCAB;
2529
import static org.hamcrest.Matchers.closeTo;
2630
import static org.hamcrest.Matchers.equalTo;
2731
import static org.hamcrest.Matchers.is;
@@ -62,6 +66,33 @@ public void testProcessor() throws IOException {
6266
assertThat(result.predictedValue(), closeTo(42, 1e-6));
6367
}
6468

69+
public void testBalancedTruncationWithLongInput() throws IOException {
70+
String question = "Is Elasticsearch scalable?";
71+
StringBuilder longInputBuilder = new StringBuilder();
72+
for (int i = 0; i < 1000; i++) {
73+
longInputBuilder.append(TEST_CASE_VOCAB.get(randomIntBetween(0, TEST_CASE_VOCAB.size() - 1))).append(i).append(" ");
74+
}
75+
String longInput = longInputBuilder.toString().trim();
76+
77+
DebertaV2Tokenization tokenization = new DebertaV2Tokenization(false, true, null, Tokenization.Truncate.BALANCED, -1);
78+
DebertaV2Tokenizer tokenizer = DebertaV2Tokenizer.builder(TEST_CASE_VOCAB, TEST_CASE_SCORES, tokenization).build();
79+
TextSimilarityConfig textSimilarityConfig = new TextSimilarityConfig(
80+
question,
81+
new VocabularyConfig(""),
82+
tokenization,
83+
"result",
84+
TextSimilarityConfig.SpanScoreFunction.MAX
85+
);
86+
TextSimilarityProcessor processor = new TextSimilarityProcessor(tokenizer);
87+
TokenizationResult tokenizationResult = processor.getRequestBuilder(textSimilarityConfig)
88+
.buildRequest(List.of(longInput), "1", Tokenization.Truncate.BALANCED, -1, null)
89+
.tokenization();
90+
91+
// Assert that the tokenization result is as expected
92+
assertThat(tokenizationResult.anyTruncated(), is(true));
93+
assertThat(tokenizationResult.getTokenization(0).tokenIds().length, equalTo(512));
94+
}
95+
6596
public void testResultFunctions() {
6697
BertTokenization tokenization = new BertTokenization(false, true, 384, Tokenization.Truncate.NONE, 128);
6798
BertTokenizer tokenizer = BertTokenizer.builder(TEST_CASED_VOCAB, tokenization).build();

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DebertaV2TokenizerTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
public class DebertaV2TokenizerTests extends ESTestCase {
2525

26-
private static final List<String> TEST_CASE_VOCAB = List.of(
26+
public static final List<String> TEST_CASE_VOCAB = List.of(
2727
DebertaV2Tokenizer.CLASS_TOKEN,
2828
DebertaV2Tokenizer.PAD_TOKEN,
2929
DebertaV2Tokenizer.SEPARATOR_TOKEN,
@@ -48,7 +48,7 @@ public class DebertaV2TokenizerTests extends ESTestCase {
4848
"<0xAD>",
4949
"▁"
5050
);
51-
private static final List<Double> TEST_CASE_SCORES = List.of(
51+
public static final List<Double> TEST_CASE_SCORES = List.of(
5252
0.0,
5353
0.0,
5454
0.0,

0 commit comments

Comments (0)