
[Bug] model_type argument as str for checkpoints classes #1946

Merged 1 commit on Nov 5, 2024.

2 changes: 1 addition & 1 deletion docs/source/deep_dives/checkpointer.rst
@@ -443,7 +443,7 @@ For this section we'll use the Llama2 13B model in HF format.
checkpoint_dir=checkpoint_dir,
checkpoint_files=pytorch_files,
output_dir=checkpoint_dir,
- model_type=ModelType.LLAMA2
+ model_type="LLAMA2"
)
torchtune_sd = checkpointer.load_checkpoint()

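For reference, a minimal sketch of how a caller constructs a checkpointer after this change, passing model_type as a plain string. The class choice, import path, directories, and file names below are illustrative assumptions, not taken from this PR; only model_type="LLAMA2" reflects the actual change.

```python
# Import path assumed from torchtune's package layout (torchtune/training/checkpointing/).
from torchtune.training import FullModelHFCheckpointer

# Hypothetical checkpoint location and file names.
checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Llama-2-13b-hf",
    checkpoint_files=["pytorch_model-00001-of-00003.bin"],
    output_dir="/tmp/Llama-2-13b-hf",
    model_type="LLAMA2",  # plain string instead of ModelType.LLAMA2
)
torchtune_sd = checkpointer.load_checkpoint()
```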
14 changes: 7 additions & 7 deletions torchtune/training/checkpointing/_checkpointer.py
@@ -111,7 +111,7 @@ class FullModelTorchTuneCheckpointer(_CheckpointerInterface):
checkpoint_dir (str): Directory containing the checkpoint files
checkpoint_files (List[str]): List of checkpoint files to load. Since the checkpointer takes care
of sorting by file ID, the order in this list does not matter
- model_type (ModelType): Model type of the model for which the checkpointer is being loaded
+ model_type (str): Model type of the model for which the checkpointer is being loaded
output_dir (str): Directory to save the checkpoint files
adapter_checkpoint (Optional[str]): Path to the adapter weights. Default is None
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None
@@ -130,7 +130,7 @@ def __init__(
self,
checkpoint_dir: str,
checkpoint_files: List[str],
- model_type: ModelType,
+ model_type: str,
output_dir: str,
adapter_checkpoint: Optional[str] = None,
recipe_checkpoint: Optional[str] = None,
@@ -159,7 +159,7 @@ def __init__(
)

self._resume_from_checkpoint = resume_from_checkpoint
- self._model_type = model_type
+ self._model_type = ModelType[model_type]
self._output_dir = Path(output_dir)

# recipe_checkpoint contains the recipe state. This should be available if
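The only behavioral change is this name-based enum lookup: the incoming string is resolved to the corresponding ModelType member. A standalone sketch of that lookup semantics follows; the enum below is a stand-in with illustrative members, not torchtune's actual ModelType definition.

```python
from enum import Enum

# Stand-in for torchtune's ModelType; member names and values here are illustrative.
class ModelType(Enum):
    LLAMA2 = "llama2"
    LLAMA3 = "llama3"

print(ModelType["LLAMA2"])  # ModelType.LLAMA2 -- lookup by member *name*, as in ModelType[model_type]
# For a plain Enum this lookup is case-sensitive: ModelType["llama2"] raises KeyError,
# so the string must match the member name exactly (e.g. "LLAMA2").
```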
@@ -322,7 +322,7 @@ class FullModelHFCheckpointer(_CheckpointerInterface):
checkpoint_dir (str): Directory containing the checkpoint files
checkpoint_files (Union[List[str], Dict[str, str]]): List of checkpoint files to load. Since the checkpointer takes care
of sorting by file ID, the order in this list does not matter. TODO: update this
- model_type (ModelType): Model type of the model for which the checkpointer is being loaded
+ model_type (str): Model type of the model for which the checkpointer is being loaded
output_dir (str): Directory to save the checkpoint files
adapter_checkpoint (Optional[str]): Path to the adapter weights. Default is None
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None
@@ -338,7 +338,7 @@ def __init__(
self,
checkpoint_dir: str,
checkpoint_files: Union[List[str], Dict[str, str]],
- model_type: ModelType,
+ model_type: str,
output_dir: str,
adapter_checkpoint: Optional[str] = None,
recipe_checkpoint: Optional[str] = None,
@@ -723,7 +723,7 @@ class FullModelMetaCheckpointer(_CheckpointerInterface):
checkpoint_dir (str): Directory containing the checkpoint files
checkpoint_files (List[str]): List of checkpoint files to load. Currently this checkpointer only
supports loading a single checkpoint file.
- model_type (ModelType): Model type of the model for which the checkpointer is being loaded
+ model_type (str): Model type of the model for which the checkpointer is being loaded
output_dir (str): Directory to save the checkpoint files
adapter_checkpoint (Optional[str]): Path to the adapter weights. Default is None
recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None
@@ -739,7 +739,7 @@ def __init__(
self,
checkpoint_dir: str,
checkpoint_files: List[str],
- model_type: ModelType,
+ model_type: str,
output_dir: str,
adapter_checkpoint: Optional[str] = None,
recipe_checkpoint: Optional[str] = None,