From 7d8a2072d26e8f78454df0229e948f147bd9635e Mon Sep 17 00:00:00 2001 From: Saeid Ghafouri Date: Tue, 4 Apr 2023 04:16:30 -0400 Subject: [PATCH] =?UTF-8?q?expose=20the=20capability=20of=20choosing=20DL?= =?UTF-8?q?=20framework=20of=20the=20HF=20pipelines=20mo=E2=80=A6=20(#1066?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- runtimes/huggingface/mlserver_huggingface/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/runtimes/huggingface/mlserver_huggingface/common.py b/runtimes/huggingface/mlserver_huggingface/common.py index 932b9dca4..a3d0eb5b7 100644 --- a/runtimes/huggingface/mlserver_huggingface/common.py +++ b/runtimes/huggingface/mlserver_huggingface/common.py @@ -53,6 +53,7 @@ class Config: task_suffix: str = "" pretrained_model: Optional[str] = None pretrained_tokenizer: Optional[str] = None + framework: Optional[str] = None optimum_model: bool = False device: int = -1 @@ -140,6 +141,7 @@ def load_pipeline_from_settings( tokenizer=tokenizer, device=device, batch_size=batch_size, + framework=hf_settings.framework, ) # If max_batch_size > 0 we need to ensure tokens are padded