Skip to content

Commit d8ae4a9

Browse files
committed
Add tests
1 parent 9b9d7e5 commit d8ae4a9

File tree

9 files changed

+255
-53
lines changed

9 files changed

+255
-53
lines changed

keras_hub/src/models/backbone.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,20 +278,16 @@ def load_lora_weights(self, filepath):
278278
layer.lora_kernel_b.assign(lora_kernel_b)
279279
store.close()
280280

281-
def export_to_transformers(self, path, verbose=True):
281+
def export_to_transformers(self, path):
282282
"""Export the backbone model to HuggingFace Transformers format.
283-
284283
This saves the backbone's configuration and weights in a format
285284
compatible with HuggingFace Transformers. For unsupported model
286285
architectures, a ValueError is raised.
287-
288286
Args:
289287
path: str. Path to save the exported model.
290-
verbose: bool. If True, print success messages (default: True).
291-
292288
"""
293289
from keras_hub.src.utils.transformers.export.hf_exporter import (
294290
export_backbone,
295291
)
296292

297-
export_backbone(self, path, verbose=verbose)
293+
export_backbone(self, path)

keras_hub/src/models/backbone_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from keras_hub.src.models.backbone import Backbone
77
from keras_hub.src.models.bert.bert_backbone import BertBackbone
8+
from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
89
from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone
910
from keras_hub.src.tests.test_case import TestCase
1011
from keras_hub.src.utils.preset_utils import CONFIG_FILE
@@ -15,6 +16,18 @@
1516

1617

1718
class TestBackbone(TestCase):
19+
def setUp(self):
20+
# Common config for backbone instantiation in export tests
21+
self.backbone_config = {
22+
"vocabulary_size": 1000,
23+
"num_layers": 2,
24+
"num_query_heads": 4,
25+
"num_key_value_heads": 1,
26+
"hidden_dim": 512,
27+
"intermediate_dim": 1024,
28+
"head_dim": 128,
29+
}
30+
1831
def test_preset_accessors(self):
1932
bert_presets = set(BertBackbone.presets.keys())
2033
gpt2_presets = set(GPT2Backbone.presets.keys())
@@ -105,3 +118,21 @@ def test_save_to_preset(self):
105118
ref_out = backbone(data)
106119
new_out = restored_backbone(data)
107120
self.assertAllClose(ref_out, new_out)
121+
122+
def test_export_supported_model(self):
123+
backbone = GemmaBackbone(**self.backbone_config)
124+
export_path = os.path.join(self.get_temp_dir(), "export_backbone")
125+
backbone.export_to_transformers(export_path)
126+
# Basic check: config file exists
127+
self.assertTrue(
128+
os.path.exists(os.path.join(export_path, "config.json"))
129+
)
130+
131+
def test_export_unsupported_model(self):
132+
class UnsupportedBackbone(GemmaBackbone):
133+
pass
134+
135+
backbone = UnsupportedBackbone(**self.backbone_config)
136+
export_path = os.path.join(self.get_temp_dir(), "unsupported")
137+
with self.assertRaises(ValueError):
138+
backbone.export_to_transformers(export_path)

keras_hub/src/models/causal_lm.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -393,30 +393,27 @@ def postprocess(x):
393393

394394
return self._normalize_generate_outputs(outputs, input_is_scalar)
395395

396-
def export_to_transformers(self, path, verbose=True):
396+
def export_to_transformers(self, path):
397397
"""Export the full CausalLM model to HuggingFace Transformers format.
398-
399398
This exports the backbone, tokenizer, and configurations in a format
400399
compatible with HuggingFace Transformers. For unsupported
401400
model architectures, a ValueError is raised.
402-
401+
If the preprocessor is None, only the backbone is exported.
403402
Args:
404403
path: str. Path to save the exported model.
405-
verbose: bool. If True, print success messages (default: True).
406-
407404
"""
408-
missing = []
409-
if self.preprocessor is None:
410-
missing.append("preprocessor")
411-
elif self.preprocessor.tokenizer is None:
412-
missing.append("tokenizer")
413-
if missing:
414-
raise ValueError(
415-
"CausalLM must have a preprocessor and a tokenizer for export. "
416-
"Missing: " + " ".join(missing)
417-
)
418405
from keras_hub.src.utils.transformers.export.hf_exporter import (
419-
export_to_safetensors,
406+
export_backbone,
407+
)
408+
from keras_hub.src.utils.transformers.export.hf_exporter import (
409+
export_tokenizer,
420410
)
421411

422-
export_to_safetensors(self, path, verbose=verbose)
412+
export_backbone(self.backbone, path)
413+
if self.preprocessor is not None:
414+
if self.preprocessor.tokenizer is None:
415+
raise ValueError(
416+
"CausalLM preprocessor must have a tokenizer for"
417+
"export if attached."
418+
)
419+
export_tokenizer(self.preprocessor.tokenizer, path)

keras_hub/src/models/causal_lm_preprocessor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,16 @@ def sequence_length(self, value):
180180
self._sequence_length = value
181181
if self.packer is not None:
182182
self.packer.sequence_length = value
183+
184+
def export_to_transformers(self, path):
185+
"""Export the preprocessor(tokenizer) to HuggingFace format.
186+
Args:
187+
path: str. Path to save the exported preprocessor/tokenizer.
188+
"""
189+
if self.tokenizer is None:
190+
raise ValueError("Preprocessor must have a tokenizer for export.")
191+
from keras_hub.src.utils.transformers.export.hf_exporter import (
192+
export_tokenizer,
193+
)
194+
195+
export_tokenizer(self.tokenizer, path)

keras_hub/src/models/causal_lm_preprocessor_test.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1+
import os
2+
13
import pytest
4+
from sentencepiece import SentencePieceTrainer
25

36
from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
47
from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
8+
from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
9+
GemmaCausalLMPreprocessor,
10+
)
11+
from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
512
from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import (
613
GPT2CausalLMPreprocessor,
714
)
@@ -10,6 +17,32 @@
1017

1118

1219
class TestCausalLMPreprocessor(TestCase):
20+
def setUp(self):
21+
# Common setup for export tests
22+
train_sentences = [
23+
"The quick brown fox jumped.",
24+
"I like pizza.",
25+
"This is a test.",
26+
]
27+
self.proto_prefix = os.path.join(self.get_temp_dir(), "dummy_vocab")
28+
SentencePieceTrainer.train(
29+
sentence_iterator=iter(train_sentences),
30+
model_prefix=self.proto_prefix,
31+
vocab_size=290,
32+
model_type="unigram",
33+
pad_id=0,
34+
bos_id=2,
35+
eos_id=1,
36+
unk_id=3,
37+
byte_fallback=True,
38+
pad_piece="<pad>",
39+
bos_piece="<bos>",
40+
eos_piece="<eos>",
41+
unk_piece="<unk>",
42+
user_defined_symbols=["<start_of_turn>", "<end_of_turn>"],
43+
add_dummy_prefix=False,
44+
)
45+
1346
def test_preset_accessors(self):
1447
bert_presets = set(BertTokenizer.presets.keys())
1548
gpt2_presets = set(GPT2Preprocessor.presets.keys())
@@ -43,3 +76,21 @@ def test_from_preset_errors(self):
4376
with self.assertRaises(ValueError):
4477
# No loading on a non-keras model.
4578
GPT2CausalLMPreprocessor.from_preset("hf://spacy/en_core_web_sm")
79+
80+
def test_export_supported_preprocessor(self):
81+
tokenizer = GemmaTokenizer(proto=f"{self.proto_prefix}.model")
82+
preprocessor = GemmaCausalLMPreprocessor(tokenizer=tokenizer)
83+
export_path = os.path.join(self.get_temp_dir(), "export_preprocessor")
84+
preprocessor.export_to_transformers(export_path)
85+
# Basic check: tokenizer config exists
86+
self.assertTrue(
87+
os.path.exists(os.path.join(export_path, "tokenizer_config.json"))
88+
)
89+
90+
def test_export_missing_tokenizer(self):
91+
preprocessor = GemmaCausalLMPreprocessor(tokenizer=None)
92+
export_path = os.path.join(
93+
self.get_temp_dir(), "export_missing_tokenizer"
94+
)
95+
with self.assertRaises(ValueError):
96+
preprocessor.export_to_transformers(export_path)

keras_hub/src/models/task_test.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,16 @@
44
import keras
55
import numpy as np
66
import pytest
7+
from sentencepiece import SentencePieceTrainer
78

89
from keras_hub.src.models.bert.bert_text_classifier import BertTextClassifier
910
from keras_hub.src.models.causal_lm import CausalLM
11+
from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
12+
from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
13+
from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
14+
GemmaCausalLMPreprocessor,
15+
)
16+
from keras_hub.src.models.gemma.gemma_tokenizer import GemmaTokenizer
1017
from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
1118
from keras_hub.src.models.image_classifier import ImageClassifier
1219
from keras_hub.src.models.preprocessor import Preprocessor
@@ -44,6 +51,46 @@ def __init__(self, preprocessor=None, activation=None, **kwargs):
4451

4552

4653
class TestTask(TestCase):
54+
def setUp(self):
55+
# Common setup for export tests
56+
train_sentences = [
57+
"The quick brown fox jumped.",
58+
"I like pizza.",
59+
"This is a test.",
60+
]
61+
self.proto_prefix = os.path.join(self.get_temp_dir(), "dummy_vocab")
62+
SentencePieceTrainer.train(
63+
sentence_iterator=iter(train_sentences),
64+
model_prefix=self.proto_prefix,
65+
vocab_size=290,
66+
model_type="unigram",
67+
pad_id=0,
68+
bos_id=2,
69+
eos_id=1,
70+
unk_id=3,
71+
byte_fallback=True,
72+
pad_piece="<pad>",
73+
bos_piece="<bos>",
74+
eos_piece="<eos>",
75+
unk_piece="<unk>",
76+
user_defined_symbols=["<start_of_turn>", "<end_of_turn>"],
77+
add_dummy_prefix=False,
78+
)
79+
self.tokenizer = GemmaTokenizer(proto=f"{self.proto_prefix}.model")
80+
self.backbone = GemmaBackbone(
81+
vocabulary_size=self.tokenizer.vocabulary_size(),
82+
num_layers=2,
83+
num_query_heads=4,
84+
num_key_value_heads=1,
85+
hidden_dim=512,
86+
intermediate_dim=1024,
87+
head_dim=128,
88+
)
89+
self.preprocessor = GemmaCausalLMPreprocessor(tokenizer=self.tokenizer)
90+
self.causal_lm = GemmaCausalLM(
91+
backbone=self.backbone, preprocessor=self.preprocessor
92+
)
93+
4794
def test_preset_accessors(self):
4895
bert_presets = set(BertTextClassifier.presets.keys())
4996
gpt2_presets = set(GPT2CausalLM.presets.keys())
@@ -171,3 +218,50 @@ def test_save_to_preset_custom_backbone_and_preprocessor(self):
171218
restored_task = ImageClassifier.from_preset(save_dir)
172219
actual = restored_task.predict(batch)
173220
self.assertAllClose(expected, actual)
221+
222+
def test_export_attached(self):
223+
export_path = os.path.join(self.get_temp_dir(), "export_attached")
224+
self.causal_lm.export_to_transformers(export_path)
225+
# Basic check: config and tokenizer files exist
226+
self.assertTrue(
227+
os.path.exists(os.path.join(export_path, "config.json"))
228+
)
229+
self.assertTrue(
230+
os.path.exists(os.path.join(export_path, "tokenizer_config.json"))
231+
)
232+
233+
def test_export_detached(self):
234+
export_path_backbone = os.path.join(
235+
self.get_temp_dir(), "export_detached_backbone"
236+
)
237+
export_path_preprocessor = os.path.join(
238+
self.get_temp_dir(), "export_detached_preprocessor"
239+
)
240+
original_preprocessor = self.causal_lm.preprocessor
241+
self.causal_lm.preprocessor = None
242+
self.causal_lm.export_to_transformers(export_path_backbone)
243+
self.causal_lm.preprocessor = original_preprocessor
244+
self.preprocessor.export_to_transformers(export_path_preprocessor)
245+
# Basic check: backbone has config, no tokenizer; preprocessor has
246+
# tokenizer config
247+
self.assertTrue(
248+
os.path.exists(os.path.join(export_path_backbone, "config.json"))
249+
)
250+
self.assertFalse(
251+
os.path.exists(
252+
os.path.join(export_path_backbone, "tokenizer_config.json")
253+
)
254+
)
255+
self.assertTrue(
256+
os.path.exists(
257+
os.path.join(export_path_preprocessor, "tokenizer_config.json")
258+
)
259+
)
260+
261+
def test_export_missing_tokenizer(self):
262+
self.preprocessor.tokenizer = None
263+
export_path = os.path.join(
264+
self.get_temp_dir(), "export_missing_tokenizer"
265+
)
266+
with self.assertRaises(ValueError):
267+
self.causal_lm.export_to_transformers(export_path)

keras_hub/src/tokenizers/tokenizer.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,18 +262,15 @@ class like `keras_hub.models.Tokenizer.from_preset()`, or from
262262
cls = find_subclass(preset, cls, backbone_cls)
263263
return loader.load_tokenizer(cls, config_file, **kwargs)
264264

265-
def export_to_transformers(self, path, verbose=True):
265+
def export_to_transformers(self, path):
266266
"""Export the tokenizer to HuggingFace Transformers format.
267-
268267
This saves tokenizer assets in a format compatible with HuggingFace
269268
Transformers.
270-
271269
Args:
272270
path: str. Path to save the exported tokenizer.
273-
verbose: bool. If True, print success messages (default: True).
274271
"""
275272
from keras_hub.src.utils.transformers.export.hf_exporter import (
276273
export_tokenizer,
277274
)
278275

279-
export_tokenizer(self, path, verbose=verbose)
276+
export_tokenizer(self, path)

0 commit comments

Comments
 (0)