Commit 878a7f6
Author: Martin Yuan
Parent: 90f0843

Refactor LLMEdgeManager's to_dtype

2 files changed: +21 −19 lines

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 18 deletions
@@ -588,25 +588,9 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     )
 
     # At this point, the model is loaded in the default fp32.
-
-    # Checkpoint dtype should be lower or equal precision to the dtype override.
+    # override dtype
     checkpoint_dtype = edge_manager.model.checkpoint_dtype
-    if not (
-        checkpoint_dtype == dtype_override.to_torch_dtype()
-        or (
-            checkpoint_dtype == torch.float16
-            and dtype_override.to_torch_dtype() == torch.float32
-        )
-        or (
-            checkpoint_dtype == torch.bfloat16
-            and dtype_override.to_torch_dtype() == torch.float32
-        )
-    ):
-        logging.warning(
-            f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
-        )
-
-    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    edge_manager.to_dtype(dtype_override)
 
     # We want to quantize (in the source transforms) the weights of the model
     # in the checkpoint dtype.
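
For readers skimming the diff: the cast that to_dtype() ultimately performs is a plain model.to(dtype=...). A minimal, self-contained sketch of that behavior, using a toy nn.Linear in place of the actual Llama model (illustrative only, not the ExecuTorch API):

import torch
import torch.nn as nn

# Toy module standing in for edge_manager.model; parameters load in fp32 by default.
model = nn.Linear(4, 4)
print(model.weight.dtype)              # torch.float32
model = model.to(dtype=torch.float16)  # the cast that to_dtype() delegates to
print(model.weight.dtype)              # torch.float16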

extension/llm/export/builder.py

Lines changed: 19 additions & 1 deletion
@@ -147,7 +147,25 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LLMEdgeManager":
         assert not dtype_override or isinstance(
             dtype_override, DType
         ), "Override dtype needs to be of type <DType>"
-        if dtype_override is not None and dtype_override != self.dtype:
+
+        # Checkpoint dtype should be lower or equal precision to the dtype override.
+        checkpoint_dtype = self.model.checkpoint_dtype
+        if not (
+            checkpoint_dtype == dtype_override.to_torch_dtype()
+            or (
+                checkpoint_dtype == torch.float16
+                and dtype_override.to_torch_dtype() == torch.float32
+            )
+            or (
+                checkpoint_dtype == torch.bfloat16
+                and dtype_override.to_torch_dtype() == torch.float32
+            )
+        ):
+            logging.warning(
+                f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
+            )
+
+        if dtype_override != self.dtype:
             torch_dtype = dtype_override.to_torch_dtype()
             logging.info(f"model.to {torch_dtype}")
             self.model = self.model.to(dtype=torch_dtype)
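
The compatibility rule the moved code enforces: the override may equal the checkpoint dtype, or upcast a half-precision checkpoint (fp16 or bf16) to fp32; any other mismatch logs a precision-loss warning. A standalone sketch of that rule, with warn_if_precision_loss as a hypothetical helper name rather than anything in the ExecuTorch API:

import logging
import torch

def warn_if_precision_loss(checkpoint_dtype: torch.dtype, override: torch.dtype) -> None:
    # Allowed: exact match, or an upcast from fp16/bf16 to fp32.
    ok = (
        checkpoint_dtype == override
        or (checkpoint_dtype == torch.float16 and override == torch.float32)
        or (checkpoint_dtype == torch.bfloat16 and override == torch.float32)
    )
    if not ok:
        logging.warning(
            f"Checkpoint dtype {checkpoint_dtype} precision is higher "
            f"than dtype override {override}."
        )

warn_if_precision_loss(torch.float16, torch.float32)  # silent: allowed upcast
warn_if_precision_loss(torch.float32, torch.float16)  # warns: precision loss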
