Skip to content

Commit 05ce995

Browse files
committed
Add T5XXLPreprocessor and remove T5XXLTokenizer
1 parent dcf3ec6 commit 05ce995

File tree

4 files changed

+164
-104
lines changed

4 files changed

+164
-104
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import keras
15+
16+
from keras_nlp.src.layers.preprocessing.start_end_packer import StartEndPacker
17+
from keras_nlp.src.models.preprocessor import Preprocessor
18+
from keras_nlp.src.models.t5.t5_tokenizer import T5Tokenizer
19+
from keras_nlp.src.utils.keras_utils import (
20+
convert_inputs_to_list_of_tensor_segments,
21+
)
22+
23+
24+
class T5XXLPreprocessor(Preprocessor):
    """Preprocessing layer that tokenizes and packs a single text segment.

    Wraps a `T5Tokenizer` and a `StartEndPacker` to turn raw strings into
    `"token_ids"` / `"padding_mask"` dictionaries of a fixed length.

    Args:
        tokenizer: a `T5Tokenizer` instance used to convert text to ids.
        sequence_length: int, the packed output length. Defaults to 256.
        add_start_token: bool, whether to prepend the start token.
            Defaults to `False`.
        add_end_token: bool, whether to append the end token.
            Defaults to `True`.
    """

    tokenizer_cls = T5Tokenizer

    def __init__(
        self,
        tokenizer,
        sequence_length=256,
        add_start_token=False,
        add_end_token=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length
        self.add_start_token = add_start_token
        self.add_end_token = add_end_token

    def build(self, input_shape):
        # Packer creation is deferred to `build()` so that tokenizer assets
        # are guaranteed to be loaded when restoring a saved model.
        tokenizer = self.tokenizer
        self.packer = StartEndPacker(
            start_value=tokenizer.start_token_id,
            end_value=tokenizer.end_token_id,
            pad_value=tokenizer.pad_token_id,
            sequence_length=self.sequence_length,
            return_padding_mask=True,
        )
        self.built = True

    def call(
        self,
        x,
        y=None,
        sample_weight=None,
        sequence_length=None,
    ):
        # Normalize the input into a list of tensor segments; this model
        # only supports single-segment inputs.
        segments = convert_inputs_to_list_of_tensor_segments(x)
        if len(segments) != 1:
            raise ValueError(
                "T5XXL requires each input feature to contain only "
                f"one segment, but received {len(segments)}. If you are "
                "using T5XXL for a multi-segment classification task, "
                "please refer to classification models like BERT or RoBERTa."
            )
        # A per-call `sequence_length` overrides the layer default.
        token_ids, padding_mask = self.packer(
            self.tokenizer(segments[0]),
            sequence_length=sequence_length or self.sequence_length,
            add_start_value=self.add_start_token,
            add_end_value=self.add_end_token,
        )
        features = {
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }
        return keras.utils.pack_x_y_sample_weight(features, y, sample_weight)

    def get_config(self):
        # Serialize the constructor arguments on top of the base config so
        # the layer round-trips through `from_config()`.
        config = super().get_config()
        config["sequence_length"] = self.sequence_length
        config["add_start_token"] = self.add_start_token
        config["add_end_token"] = self.add_end_token
        return config
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2024 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
import pytest
17+
18+
from keras_nlp.src.models.stable_diffusion_v3.t5_xxl_preprocessor import (
19+
T5XXLPreprocessor,
20+
)
21+
from keras_nlp.src.models.t5.t5_tokenizer import T5Tokenizer
22+
from keras_nlp.src.tests.test_case import TestCase
23+
24+
25+
class T5XXLPreprocessorTest(TestCase):
    """Tests for `T5XXLPreprocessor`.

    NOTE(review): the original class was named `GemmaPreprocessorTest`, a
    copy-paste leftover from the Gemma preprocessor tests; renamed to match
    the class under test.
    """

    def setUp(self):
        # A tiny SentencePiece vocab checked into the test data directory.
        self.tokenizer = T5Tokenizer(
            proto=os.path.join(self.get_test_data_dir(), "t5_test_vocab.spm")
        )
        self.init_kwargs = {
            "tokenizer": self.tokenizer,
            "sequence_length": 10,
        }
        self.input_data = ["the quick brown fox"]

    def test_preprocessor_basics(self):
        # End token (id 1) is appended by default; output is padded to
        # `sequence_length` with matching padding mask.
        self.run_preprocessing_layer_test(
            cls=T5XXLPreprocessor,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output={
                "token_ids": [[4, 9, 5, 7, 1, 0, 0, 0, 0, 0]],
                "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0]],
            },
        )

    def test_no_start_end_token(self):
        # With both special tokens disabled, only raw token ids remain.
        input_data = ["the quick brown fox"] * 4
        preprocessor = T5XXLPreprocessor(
            tokenizer=self.tokenizer,
            sequence_length=8,
            add_start_token=False,
            add_end_token=False,
        )
        x = preprocessor(input_data)
        self.assertAllEqual(x["token_ids"], [[4, 9, 5, 7, 0, 0, 0, 0]] * 4)
        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)

    def test_sequence_length_override(self):
        # A call-time `sequence_length` truncates; the end token still fits.
        input_data = "the quick brown fox"
        preprocessor = T5XXLPreprocessor(**self.init_kwargs)
        x = preprocessor(input_data, sequence_length=4)
        self.assertAllEqual(x["token_ids"], [4, 9, 5, 1])

    @pytest.mark.kaggle_key_required
    @pytest.mark.extra_large
    def test_all_presets(self):
        self.skipTest("TODO")
        for preset in T5XXLPreprocessor.presets:
            self.run_preset_test(
                cls=T5XXLPreprocessor,
                preset=preset,
                input_data=self.input_data,
            )

keras_nlp/src/models/stable_diffusion_v3/t5_xxl_tokenizer.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

keras_nlp/src/models/stable_diffusion_v3/t5_xxl_tokenizer_test.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

0 commit comments

Comments
 (0)