Skip to content

Commit

Permalink
add test for test_get_keys_to_not_convert
Browse files Browse the repository at this point in the history
  • Loading branch information
ranchlai committed Jul 28, 2023
1 parent 2ef72c1 commit 8cf2091
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions tests/bnb/test_mixed_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,50 @@ def tearDown(self):
gc.collect()
torch.cuda.empty_cache()

def test_get_keys_to_not_convert(self):
    r"""
    Test the `get_keys_to_not_convert` function, which returns the module keys
    (e.g. lm_head / tied embeddings) that must be kept in full precision and
    excluded from int8 conversion, across several architectures.
    """
    from transformers import MptForCausalLM, Blip2ForConditionalGeneration, OPTForCausalLM, AutoModelForMaskedLM
    from transformers.utils.bitsandbytes import get_keys_to_not_convert
    from accelerate import init_empty_weights

    # Remote-code MPT path: only the word-token embedding is excluded.
    config = AutoConfig.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
    self.assertEqual(get_keys_to_not_convert(model), ["transformer.wte"])

    config = AutoConfig.from_pretrained("mosaicml/mpt-7b")
    with init_empty_weights():
        model = MptForCausalLM(config)
    # The order of the keys does not matter, so we compare sorted copies,
    # same for the other tests below.
    # NOTE: `list.sort()` sorts in place and returns None, so the previous
    # `a.sort() == b.sort()` pattern compared None == None and could never
    # fail; `sorted()` returns the sorted list and makes the assertion real.
    self.assertEqual(sorted(get_keys_to_not_convert(model)), sorted(["lm_head", "transformer.wte"]))

    model_id = "Salesforce/blip2-opt-2.7b"
    config = AutoConfig.from_pretrained(model_id)

    with init_empty_weights():
        model = Blip2ForConditionalGeneration(config)
    self.assertEqual(
        sorted(get_keys_to_not_convert(model)),
        sorted(["language_model.lm_head", "language_model.model.decoder.embed_tokens"]),
    )

    model_id = "facebook/opt-350m"
    config = AutoConfig.from_pretrained(model_id)
    with init_empty_weights():
        model = OPTForCausalLM(config)
    self.assertEqual(
        sorted(get_keys_to_not_convert(model)),
        sorted(["lm_head", "model.decoder.embed_tokens"]),
    )

    model_id = "roberta-large"
    config = AutoConfig.from_pretrained(model_id)
    with init_empty_weights():
        model = AutoModelForMaskedLM.from_config(config)
    # NOTE: the original expected value was a single garbled string
    # "'roberta.embeddings.word_embeddings', 'lm_head', 'lm_head.decoder"
    # inside a one-element list; it is three separate module keys.
    self.assertEqual(
        sorted(get_keys_to_not_convert(model)),
        sorted(["roberta.embeddings.word_embeddings", "lm_head", "lm_head.decoder"]),
    )

def test_quantization_config_json_serialization(self):
r"""
A simple test to check if the quantization config is correctly serialized and deserialized
Expand Down

0 comments on commit 8cf2091

Please sign in to comment.