
Commit 73f91c5

multiple tokenizers with different filenames can now be saved

1 parent 307c523 · commit 73f91c5

File tree

1 file changed (+14, -5 lines)


src/transformers/processing_utils.py

Lines changed: 14 additions & 5 deletions
@@ -794,10 +794,12 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
             if hasattr(attribute, "_set_processor_class"):
                 attribute._set_processor_class(self.__class__.__name__)

-            # Save the tokenizer in its own vocab file. The other attributes are saved as part of `processor_config.json`
-            if attribute_name == "tokenizer":
-                # Propagate save_jinja_files to tokenizer to ensure we don't get conflicts
-                attribute.save_pretrained(save_directory, save_jinja_files=save_jinja_files)
+            # If the attribute can save itself, save it in its own sub-directory to avoid overwriting files
+            if hasattr(attribute, "save_pretrained"):
+                # Use the attribute name as the sub-directory so each attribute gets a unique location
+                attribute_save_dir = os.path.join(save_directory, attribute_name)
+                os.makedirs(attribute_save_dir, exist_ok=True)
+                attribute.save_pretrained(attribute_save_dir, save_jinja_files=save_jinja_files)
             elif attribute._auto_class is not None:
                 custom_object_save(attribute, save_directory, config=attribute)
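For illustration only (not part of the commit): a minimal sketch of the directory layout the new save path produces, assuming a processor whose attributes are named tokenizer and tokenizer_2 (hypothetical names). Each attribute that exposes save_pretrained now writes into its own sub-folder, so two tokenizers no longer overwrite each other's tokenizer_config.json and vocabulary files.

import os

save_directory = "my_processor"            # assumed output directory name
attributes = ["tokenizer", "tokenizer_2"]  # assumed attribute names on the processor

for attribute_name in attributes:
    # mirrors the new loop body: one sub-directory per attribute
    attribute_save_dir = os.path.join(save_directory, attribute_name)
    os.makedirs(attribute_save_dir, exist_ok=True)
    # each attribute would call attribute.save_pretrained(attribute_save_dir) here,
    # producing my_processor/tokenizer/... and my_processor/tokenizer_2/...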

@@ -1450,7 +1452,14 @@ def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
             else:
                 attribute_class = cls.get_possibly_dynamic_module(class_name)

-            args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
+            # Updated loading path to handle multiple tokenizers saved in per-attribute sub-directories
+            attribute_path = os.path.join(pretrained_model_name_or_path, attribute_name)
+            if os.path.isdir(attribute_path):
+                # Load from the attribute-specific folder
+                args.append(attribute_class.from_pretrained(attribute_path, **kwargs))
+            else:
+                # Fall back to the original top-level path
+                args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))

         return args
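For illustration only (not part of the commit): a rough sketch of the resulting lookup order on the load side, with resolve_attribute_path as a hypothetical helper name. A checkpoint saved with the new layout resolves to its per-attribute sub-directory, while a checkpoint saved before this commit falls back to the top-level path, so older layouts stay loadable.

import os

def resolve_attribute_path(pretrained_model_name_or_path: str, attribute_name: str) -> str:
    # prefer the per-attribute sub-directory if it exists on disk
    attribute_path = os.path.join(pretrained_model_name_or_path, attribute_name)
    if os.path.isdir(attribute_path):
        return attribute_path
    # otherwise fall back to the top-level directory (pre-commit layout)
    return pretrained_model_name_or_path

# e.g. resolve_attribute_path("my_processor", "tokenizer_2") -> "my_processor/tokenizer_2"
# while an old checkpoint without sub-folders resolves back to "my_processor"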
