diff --git a/llm/predictor.py b/llm/predictor.py
index bbbaf1ceace4..5353d7f627ec 100644
--- a/llm/predictor.py
+++ b/llm/predictor.py
@@ -871,7 +871,7 @@ def create_predictor(
                 )
                 model.eval()
             else:
-                raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]")
+                raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt, qwen]")
             predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer)
         elif predictor_args.mode == "static":
             config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
@@ -915,8 +915,16 @@ def create_predictor(
                 cache_kvs_shape = GPTForCausalLMInferenceModel.get_cache_kvs_shape(
                     config, predictor_args.batch_size, predictor_args.total_max_length
                 )
+            elif "qwen" in config.architectures[0].lower():
+                from paddlenlp.experimental.transformers import (
+                    QWenForCausalLMInferenceModel,
+                )
+
+                cache_kvs_shape = QWenForCausalLMInferenceModel.get_cache_kvs_shape(
+                    config, predictor_args.batch_size, predictor_args.total_max_length
+                )
             else:
-                raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]")
+                raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt, qwen]")
             predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
         else:
             raise ValueError("the `mode` should be one of [dynamic, static]")
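
For reference, a minimal sketch of what the new static-mode branch does in isolation, assuming the patch above is applied and a Qwen checkpoint is available locally (the checkpoint path below is hypothetical):

# Sketch only: exercises the branch added above for a Qwen checkpoint.
# The checkpoint path is hypothetical; substitute a real Qwen model directory.
from paddlenlp.transformers import AutoConfig
from paddlenlp.experimental.transformers import QWenForCausalLMInferenceModel

config = AutoConfig.from_pretrained("./checkpoints/qwen-7b")  # hypothetical path
if "qwen" in config.architectures[0].lower():
    # Same call the patch adds: compute the KV-cache shapes that the
    # StaticInferencePredictor pre-allocates before decoding.
    cache_kvs_shape = QWenForCausalLMInferenceModel.get_cache_kvs_shape(
        config, 1, 1024  # batch_size, total_max_length
    )
    print(cache_kvs_shape)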