From 8cf2091aee73f9a9d690a2b54428638be7ad5147 Mon Sep 17 00:00:00 2001
From: ranch
Date: Fri, 28 Jul 2023 14:46:13 +0800
Subject: [PATCH] add test for test_get_keys_to_not_convert

---
 tests/bnb/test_mixed_int8.py | 44 ++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/tests/bnb/test_mixed_int8.py b/tests/bnb/test_mixed_int8.py
index f905b26e3f71c2..8b23e576e2e4d2 100644
--- a/tests/bnb/test_mixed_int8.py
+++ b/tests/bnb/test_mixed_int8.py
@@ -124,6 +124,50 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
+    def test_get_keys_to_not_convert(self):
+        r"""
+        Test the `get_keys_to_not_convert` function.
+        """
+        from transformers import MptForCausalLM, Blip2ForConditionalGeneration, OPTForCausalLM, AutoModelForMaskedLM
+        from transformers.utils.bitsandbytes import get_keys_to_not_convert
+        from accelerate import init_empty_weights
+
+        config = AutoConfig.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True)
+        with init_empty_weights():
+            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+        self.assertEqual(get_keys_to_not_convert(model), ["transformer.wte"])
+
+        config = AutoConfig.from_pretrained("mosaicml/mpt-7b")
+        with init_empty_weights():
+            model = MptForCausalLM(config)
+        # The order of the keys does not matter, so we sort them before comparing, same for the other tests.
+        self.assertEqual(sorted(get_keys_to_not_convert(model)), sorted(["lm_head", "transformer.wte"]))
+
+        model_id = "Salesforce/blip2-opt-2.7b"
+        config = AutoConfig.from_pretrained(model_id)
+
+        with init_empty_weights():
+            model = Blip2ForConditionalGeneration(config)
+        self.assertEqual(
+            sorted(get_keys_to_not_convert(model)),
+            sorted(["language_model.lm_head", "language_model.model.decoder.embed_tokens"]),
+        )
+
+        model_id = "facebook/opt-350m"
+        config = AutoConfig.from_pretrained(model_id)
+        with init_empty_weights():
+            model = OPTForCausalLM(config)
+        self.assertEqual(sorted(get_keys_to_not_convert(model)), sorted(["lm_head", "model.decoder.embed_tokens"]))
+
+        model_id = "roberta-large"
+        config = AutoConfig.from_pretrained(model_id)
+        with init_empty_weights():
+            model = AutoModelForMaskedLM.from_config(config)
+        self.assertEqual(
+            sorted(get_keys_to_not_convert(model)),
+            sorted(["roberta.embeddings.word_embeddings", "lm_head", "lm_head.decoder"]),
+        )
+
     def test_quantization_config_json_serialization(self):
         r"""
         A simple test to check if the quantization config is correctly serialized and deserialized