Checkpoint submodule (#21)
jlamypoirier authored Oct 23, 2024
1 parent a528154 commit 8f53e4d
Showing 15 changed files with 98 additions and 95 deletions.
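For orientation (a summary sketch inferred from the import changes below, not text from the commit itself): checkpoint handling moves into a new fast_llm.engine.checkpoint submodule, and the converter classes drop the ModelConverter names in favour of StateDictConverter names. In downstream code the change looks roughly like this:

# Old import paths (taken from the removed lines in this diff):
from fast_llm.engine.config_utils.checkpoint import CheckpointFormat, CheckpointLoadConfig
from fast_llm.engine.multi_stage.conversion import ExternalModelConverter

# New import paths (taken from the added lines in this diff):
from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadConfig
from fast_llm.engine.checkpoint.external import ExternalStateDictConverter
from fast_llm.engine.checkpoint.state_dict import StateDictConverter, StateDictSaver, TrivialConverter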
4 changes: 2 additions & 2 deletions fast_llm/engine/base_model/config.py
@@ -3,8 +3,8 @@
from fast_llm.config import Config, config_class

if typing.TYPE_CHECKING:
from fast_llm.engine.checkpoint.external import ExternalStateDictConverter
from fast_llm.engine.config_utils.tensor_space import TensorSpace
from fast_llm.engine.multi_stage.conversion import ExternalModelConverter


@config_class()
@@ -30,7 +30,7 @@ def compare_architecture(
return self.get_architecture().compare(model_config.get_architecture(), log_fn)

@classmethod
def get_converter_class(cls, model_type: str | None = None) -> type["ExternalModelConverter"]:
def get_converter_class(cls, model_type: str | None = None) -> type["ExternalStateDictConverter"]:
raise NotImplementedError()


Empty file.
File renamed without changes.
fast_llm/engine/checkpoint/external.py
@@ -7,9 +7,9 @@

import safetensors
import torch
import yaml

from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig
from fast_llm.engine.checkpoint.state_dict import StateDictConverter
from fast_llm.tensor import SafeTensorSlice
from fast_llm.utils import Assert

@@ -138,75 +138,7 @@ def import_weight(
return (torch.cat([weight_[:] for weight_ in weight]),)


class ModelConverter(abc.ABC):
base_file_name: typing.ClassVar[str]

@classmethod
@abc.abstractmethod
def get_key(cls, parameter_name: str, shard_name: str) -> str:
pass

@abc.abstractmethod
def convert_state_dict(
self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool
) -> dict[str, torch.Tensor | SafeTensorSlice]:
pass

@abc.abstractmethod
def load_weights(
self,
directory: pathlib.Path | str,
device,
shard_names: list[str],
) -> typing.Iterator[tuple[str, str, torch.Tensor | SafeTensorSlice]]:
pass


def _import_safetensors_metadata(metadata):
return {key: yaml.safe_load(value) for key, value in metadata.items()}


class TrivialConverter(ModelConverter):
base_file_name = "state_dict"

@classmethod
def get_key(cls, parameter_name: str, shard_name: str) -> str:
return f"{parameter_name}/{shard_name}"

def convert_state_dict(
self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool
) -> dict[str, torch.Tensor | SafeTensorSlice]:
out_state_dict = state_dict.copy()
state_dict.clear()
return out_state_dict

def load_weights(
self,
directory: pathlib.Path | str,
device,
shard_names: list[str],
) -> typing.Iterator[tuple[str, str, torch.Tensor | SafeTensorSlice]]:
index_path = directory / f"state_dict.safetensors.index.json"
logger.info(f"Loading index from {index_path}")
file_names = set(json.load(index_path.open("r"))["weight_map"].values())
for file_name in file_names:
logger.info(f"Loading from {directory / file_name}")
with safetensors.safe_open(
directory / file_name,
framework="pt",
device=str(device),
) as f:
metadata = _import_safetensors_metadata(f.metadata())
Assert.eq(metadata["state_shard_names"][: len(shard_names)], list(shard_names))
for key in f.keys():
parameter_name, shard_name = key.split("/", 1)
if shard_name in shard_names:
yield parameter_name, shard_name, f.get_slice(key)

# return metadata["metadata"]


class ExternalModelConverter(ModelConverter):
class ExternalStateDictConverter(StateDictConverter):
base_file_name = "model"
_base_model_cls: type[BaseModelConfig]
_config_converters: list[ParamConverter]
@@ -326,8 +258,8 @@ def _get_fast_llm_attribute(config: BaseModelArchitectureConfig, name: str | tup
return val


class AutoModelConverter(ExternalModelConverter, abc.ABC):
converter_map: dict[str, type[ExternalModelConverter]]
class AutoStateDictConverter(ExternalStateDictConverter, abc.ABC):
converter_map: dict[str, type[ExternalStateDictConverter]]

@classmethod
def import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False):
@@ -338,7 +270,7 @@ def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = Fa
return cls.converter_map[config["model_type"]].from_config(config, architecture_only)
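
As a usage illustration (hypothetical: the config dict is made up, only the "model_type" key drives the dispatch shown above, and it assumes the HuggingfaceModelType values are plain model_type strings such as "llama"):

# Hypothetical Hugging Face-style config; only "model_type" matters for the dispatch.
hf_config = {"model_type": "llama", "hidden_size": 4096, "num_hidden_layers": 32}

# AutoGPTConverter (see fast_llm/models/gpt/conversion.py further down) fills in
# converter_map, so this resolves to the Llama converter's own from_config().
converter = AutoGPTConverter.from_config(hf_config, architecture_only=True)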


class HuggingfaceModelConverter(ExternalModelConverter, abc.ABC):
class HuggingfaceStateDictConverter(ExternalStateDictConverter, abc.ABC):
model_type: str | None = None

@classmethod
fast_llm/engine/checkpoint/state_dict.py
@@ -1,17 +1,27 @@
import abc
import json
import logging
import pathlib
import typing

import safetensors
import safetensors.torch
import torch
import yaml

from fast_llm.core.distributed import safe_barrier
from fast_llm.engine.config_utils.checkpoint import CheckpointSaveConfig
from fast_llm.engine.checkpoint.config import CheckpointSaveConfig
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.tensor import SafeTensorSlice
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)


def _import_safetensors_metadata(metadata):
return {key: yaml.safe_load(value) for key, value in metadata.items()}


def _export_safetensors_metadata(metadata):
"""
Safetensor only accepts string entries, so we convert to string explicitly.
@@ -26,6 +36,68 @@ def _export_safetensors_metadata(metadata):
}
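
A quick illustration of the two helpers above (a sketch; the export body is only partially visible in this hunk, and the assumption that values are serialized to strings that yaml.safe_load can parse back is inferred from the import side):

# Hypothetical metadata dict; safetensors metadata must be str -> str, so export
# is assumed to stringify each value and import parses it back with YAML.
metadata = {"state_shard_names": ["weights", "optimizer_state"]}  # illustrative shard names
exported = _export_safetensors_metadata(metadata)   # all values become strings
assert _import_safetensors_metadata(exported) == metadata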


class StateDictConverter(abc.ABC):
base_file_name: typing.ClassVar[str]

@classmethod
@abc.abstractmethod
def get_key(cls, parameter_name: str, shard_name: str) -> str:
pass

@abc.abstractmethod
def convert_state_dict(
self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool
) -> dict[str, torch.Tensor | SafeTensorSlice]:
pass

@abc.abstractmethod
def load_weights(
self,
directory: pathlib.Path | str,
device,
shard_names: list[str],
) -> typing.Iterator[tuple[str, str, torch.Tensor | SafeTensorSlice]]:
pass


class TrivialConverter(StateDictConverter):
base_file_name = "state_dict"

@classmethod
def get_key(cls, parameter_name: str, shard_name: str) -> str:
return f"{parameter_name}/{shard_name}"

def convert_state_dict(
self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool
) -> dict[str, torch.Tensor | SafeTensorSlice]:
out_state_dict = state_dict.copy()
state_dict.clear()
return out_state_dict

def load_weights(
self,
directory: pathlib.Path | str,
device,
shard_names: list[str],
) -> typing.Iterator[tuple[str, str, torch.Tensor | SafeTensorSlice]]:
index_path = directory / f"state_dict.safetensors.index.json"
logger.info(f"Loading index from {index_path}")
file_names = set(json.load(index_path.open("r"))["weight_map"].values())
for file_name in file_names:
logger.info(f"Loading from {directory / file_name}")
with safetensors.safe_open(
directory / file_name,
framework="pt",
device=str(device),
) as f:
metadata = _import_safetensors_metadata(f.metadata())
Assert.eq(metadata["state_shard_names"][: len(shard_names)], list(shard_names))
for key in f.keys():
parameter_name, shard_name = key.split("/", 1)
if shard_name in shard_names:
yield parameter_name, shard_name, f.get_slice(key)
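
As an illustration of the converter interface (a hypothetical caller; the checkpoint path and shard name are made up, and in the actual flow this iterator is consumed by FastLLMModel._load_state_dict, shown further down):

# Hypothetical direct use of the trivial (native state-dict) converter.
converter = TrivialConverter()
checkpoint_dir = pathlib.Path("/path/to/checkpoint")  # illustrative path
for parameter_name, shard_name, weight in converter.load_weights(
    checkpoint_dir, device="cpu", shard_names=["weights"]
):
    # Keys on disk are "{parameter_name}/{shard_name}" (see get_key above);
    # each yielded weight is a tensor or a lazy SafeTensorSlice.
    print(parameter_name, shard_name, weight[:].shape)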


class StateDictSaver:
def __init__(
self,
2 changes: 1 addition & 1 deletion fast_llm/engine/huggingface/config.py
@@ -5,7 +5,7 @@

import transformers

from fast_llm.engine.config_utils.checkpoint import CheckpointFormat, CheckpointLoadMetadataConfig
from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadMetadataConfig
from fast_llm.engine.multi_stage.config import FastLLMModelConfig

logger = logging.getLogger(__name__)
2 changes: 1 addition & 1 deletion fast_llm/engine/huggingface/model.py
@@ -5,7 +5,7 @@
import transformers.modeling_outputs

from fast_llm.config import NoAutoValidate
from fast_llm.engine.config_utils.checkpoint import CheckpointFormat, CheckpointLoadConfig
from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadConfig
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.engine.huggingface.config import HuggingfaceModelConfig
from fast_llm.engine.multi_stage.config import StageMode
2 changes: 1 addition & 1 deletion fast_llm/engine/multi_stage/config.py
@@ -5,7 +5,7 @@

from fast_llm.config import Config, Field, FieldHint, NoAutoValidate, check_field, config_class, skip_valid_if_none
from fast_llm.engine.base_model.config import BaseModelConfig
from fast_llm.engine.config_utils.checkpoint import (
from fast_llm.engine.checkpoint.config import (
CHECKPOINT_VERSION,
KNOWN_CHECKPOINT_VERSIONS,
CheckpointFormat,
9 changes: 4 additions & 5 deletions fast_llm/engine/multi_stage/fast_llm_model.py
@@ -8,17 +8,16 @@

from fast_llm.core.distributed import all_reduce, broadcast
from fast_llm.engine.base_model.base_model import BaseModel
from fast_llm.engine.config_utils.checkpoint import (
from fast_llm.engine.checkpoint.config import (
CHECKPOINT_VERSION,
CheckpointFormat,
CheckpointLoadConfig,
CheckpointSaveConfig,
ModelConfigType,
)
from fast_llm.engine.checkpoint.state_dict import StateDictConverter, StateDictSaver, TrivialConverter
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.multi_stage.checkpoint import StateDictSaver
from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode
from fast_llm.engine.multi_stage.conversion import ModelConverter, TrivialConverter
from fast_llm.engine.multi_stage.multi_stage import MultiStageModel
from fast_llm.functional.triton.pointwise import triton_fill
from fast_llm.utils import Assert
@@ -120,7 +119,7 @@ def save_checkpoint(
else:
raise NotImplementedError(config.format)

def _save_state_dict(self, config: CheckpointSaveConfig, converter: ModelConverter, metadata: dict):
def _save_state_dict(self, config: CheckpointSaveConfig, converter: StateDictConverter, metadata: dict):
with StateDictSaver(
config,
distributed=self._distributed,
@@ -168,7 +167,7 @@ def load_checkpoint(self, config: CheckpointLoadConfig):
raise NotImplementedError(config.format)
return metadata.get("metadata")

def _load_state_dict(self, config: CheckpointLoadConfig, converter: ModelConverter):
def _load_state_dict(self, config: CheckpointLoadConfig, converter: StateDictConverter):
num_shards = len(self._state_shard_names) if config.optimizer_state else 1
with self._SafeLoadContext(self, num_shards=num_shards) as context:
state_dict = {}
2 changes: 1 addition & 1 deletion fast_llm/engine/training/config.py
@@ -7,7 +7,7 @@

from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none
from fast_llm.data.config import AbstractDataConfig
from fast_llm.engine.config_utils.checkpoint import (
from fast_llm.engine.checkpoint.config import (
CheckpointConfigBase,
CheckpointFormat,
CheckpointLoadConfig,
4 changes: 2 additions & 2 deletions fast_llm/models/gpt/config.py
@@ -8,7 +8,7 @@
from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds

if typing.TYPE_CHECKING:
from fast_llm.engine.multi_stage.conversion import ExternalModelConverter
from fast_llm.engine.checkpoint.external import ExternalStateDictConverter


@config_class()
@@ -28,7 +28,7 @@ def _from_dict(
return super()._from_dict(default, strict, flat)

@classmethod
def get_converter_class(cls, model_type: str | None = None) -> type["ExternalModelConverter"]:
def get_converter_class(cls, model_type: str | None = None) -> type["ExternalStateDictConverter"]:
from fast_llm.models.gpt.conversion import AutoGPTConverter

return AutoGPTConverter if model_type is None else AutoGPTConverter.converter_map[model_type]
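
For illustration (a hypothetical call; "GPTModelConfig" stands in for the config class defined in this file, whose name is outside the visible hunk, and "llama" assumes the HuggingfaceModelType values are plain model_type strings):

# Hypothetical lookup of a concrete converter class from the GPT model config.
converter_cls = GPTModelConfig.get_converter_class("llama")  # LlamaHuggingfaceConverter, per converter_map
auto_cls = GPTModelConfig.get_converter_class()              # AutoGPTConverter, which dispatches on model_type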
10 changes: 5 additions & 5 deletions fast_llm/models/gpt/conversion.py
@@ -4,11 +4,11 @@

import torch

from fast_llm.engine.multi_stage.conversion import (
AutoModelConverter,
from fast_llm.engine.checkpoint.external import (
AutoStateDictConverter,
ConstantExportParamConverter,
ConstantImportParamConverter,
HuggingfaceModelConverter,
HuggingfaceStateDictConverter,
IgnoreImportParamConverter,
IgnoreWeightConverter,
MappedConfigParamConverter,
@@ -89,7 +89,7 @@ def import_weight(
return (merged_weight.t().contiguous(),)


class CommonHuggingfaceConverter(HuggingfaceModelConverter):
class CommonHuggingfaceConverter(HuggingfaceStateDictConverter):
config: GPTArchitectureConfig
_base_model_cls = GPTBaseModelConfig
"""
@@ -324,7 +324,7 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str):
]


class AutoGPTConverter(AutoModelConverter, HuggingfaceModelConverter, abc.ABC):
class AutoGPTConverter(AutoStateDictConverter, HuggingfaceStateDictConverter, abc.ABC):
converter_map = {
HuggingfaceModelType.starcoder2: Starcoder2HuggingfaceConverter,
HuggingfaceModelType.llama: LlamaHuggingfaceConverter,
2 changes: 1 addition & 1 deletion fast_llm/tools/convert.py
@@ -7,7 +7,7 @@
import typing

from fast_llm.config import Field, config_class
from fast_llm.engine.config_utils.checkpoint import CheckpointFormat, CheckpointLoadConfig, CheckpointSaveConfig
from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadConfig, CheckpointSaveConfig
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.engine.config_utils.runnable import RunnableConfig
from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode
2 changes: 1 addition & 1 deletion tests/test_checkpoint.py
@@ -7,7 +7,7 @@
import transformers
import yaml

from fast_llm.engine.config_utils.checkpoint import CheckpointFormat, CheckpointLoadConfig, ModelConfigType
from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadConfig, ModelConfigType
from fast_llm.engine.multi_stage.config import StageMode
from fast_llm.models.auto import model_registry
from fast_llm.tools.convert import ConversionConfig
2 changes: 1 addition & 1 deletion tools/push_model.py
@@ -7,7 +7,7 @@
import subprocess

from fast_llm.config import Field, config_class
from fast_llm.engine.config_utils.checkpoint import CheckpointFormat
from fast_llm.engine.checkpoint.config import CheckpointFormat
from fast_llm.engine.config_utils.runnable import RunnableConfig

try:
