Skip to content

Commit 05935e0

Browse files
committed
+ _get_head_dtype
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent f864138 commit 05935e0

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

vllm/config/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -746,13 +746,14 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
746746

747747
self.pooler_config = self._init_pooler_config()
748748

749-
self.dtype = _get_and_verify_dtype(
749+
self.dtype: torch.dtype = _get_and_verify_dtype(
750750
self.model,
751751
self.hf_config,
752752
self.dtype,
753753
is_pooling_model=self.runner_type == "pooling",
754754
revision=self.revision,
755755
)
756+
self.head_dtype: torch.dtype = self._get_head_dtype()
756757

757758
# Interleaved attention is not supported by some backends in V0
758759
if (not self.disable_sliding_window
@@ -1778,8 +1779,10 @@ def get_and_verify_max_len(self, max_model_len: int):
17781779
logger.info("Using max model len %s", max_model_len)
17791780
return max_model_len
17801781

1781-
@property
1782-
def head_dtype(self) -> torch.dtype:
1782+
def _get_head_dtype(self) -> torch.dtype:
1783+
if torch.float32 not in current_platform.supported_dtypes:
1784+
return self.dtype
1785+
17831786
if envs.VLLM_USING_FP32_HEAD:
17841787
return torch.float32
17851788

0 commit comments

Comments (0)