diff --git a/paddlenlp/transformers/utils.py b/paddlenlp/transformers/utils.py index f785a5358af4..d4c319575aca 100644 --- a/paddlenlp/transformers/utils.py +++ b/paddlenlp/transformers/utils.py @@ -167,6 +167,18 @@ def adapt_stale_fwd_patch(self, name, value): "StaticFunction" ): return value + + # NOTE(changwenbin & zhoukangkang): + # When use model = paddle.incubate.jit.inference(model), it reportes errors, we fix it here. + # is_inference_mode API is only avaliable in PaddlePaddle develop,so we add a try except. + try: + from paddle.incubate.jit import is_inference_mode + + if is_inference_mode(value): + return value + except: + pass + if hasattr(inspect, "getfullargspec"): ( patch_spec_args,