From f639366408666e289305a7acf503383d9040e9f5 Mon Sep 17 00:00:00 2001 From: Vivicai1005 Date: Tue, 6 Feb 2024 17:47:35 +0800 Subject: [PATCH] edit version check --- infer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/infer.py b/infer.py index f6112d8..877e277 100644 --- a/infer.py +++ b/infer.py @@ -4,6 +4,7 @@ import fire import torch import transformers +from packaging import version from utils.modeling_hack import get_model from utils.streaming import generate_stream @@ -24,7 +25,7 @@ def main( rope_factor: float = 8.0, streaming: bool = True # streaming is always enabled now ): - assert transformers.__version__.startswith('4.34') + assert version.parse(transformers.__version__) >= version.parse(4.34) assert model_type.lower() in ['chat', 'base'], f"model_type must be one of ['chat', 'base'], got {model_type}" assert rope_scaling in [None, 'yarn', 'dynamic'], f"rope_scaling must be one of [None, 'yarn', 'dynamic'], got {rope_scaling}"