Modular checkpointing (#22)
jlamypoirier authored Oct 25, 2024
1 parent f9880e2 commit 8fee762
Showing 17 changed files with 750 additions and 608 deletions.
5 changes: 0 additions & 5 deletions fast_llm/engine/base_model/config.py
@@ -3,7 +3,6 @@
from fast_llm.config import Config, config_class

if typing.TYPE_CHECKING:
from fast_llm.engine.checkpoint.external import ExternalStateDictConverter
from fast_llm.engine.config_utils.tensor_space import TensorSpace


@@ -29,10 +28,6 @@ def compare_architecture(
):
return self.get_architecture().compare(model_config.get_architecture(), log_fn)

@classmethod
def get_converter_class(cls, model_type: str | None = None) -> type["ExternalStateDictConverter"]:
raise NotImplementedError()


@config_class()
class BaseModelConfig(BaseModelArchitectureConfig):
74 changes: 57 additions & 17 deletions fast_llm/engine/checkpoint/config.py
@@ -1,28 +1,50 @@
# TODO: Use packaging.version? (Safer but extra requirement)
import abc
import enum
import logging
import pathlib
import typing
import warnings

import yaml

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel

logger = logging.getLogger(__name__)

# TODO: Use packaging.version? (Safer but extra requirement)
CHECKPOINT_VERSION = "0.1"
KNOWN_CHECKPOINT_VERSIONS = ("0", "0.1")


class CheckpointFormat(str, enum.Enum):
def export_safetensors_metadata(metadata):
"""
Safetensor only accepts string entries, so we convert to string explicitly.
We use yaml rather than json because json requires explicit quotation marks on strings, which breaks things.
(ex. "format": "pt" becomes '"pt"' which breaks huggingface models.)
We avoid using safe_dump for scalars because it adds junk ("\n...\n") at the end of the string
(decoding is unaffected.)
"""
return {
key: str(value) if isinstance(value, (str, int, float, bool)) else yaml.safe_dump(value)
for key, value in metadata.items()
}


def import_safetensors_metadata(metadata):
return {key: yaml.safe_load(value) for key, value in metadata.items()}
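
A quick sketch of the round trip these two helpers provide, using made-up metadata values (only the helper definitions come from this file). Scalars are stored via str() and decoded by YAML rules on import, so numeric-looking strings come back as numbers:

    import yaml

    def export_safetensors_metadata(metadata):
        return {
            key: str(value) if isinstance(value, (str, int, float, bool)) else yaml.safe_dump(value)
            for key, value in metadata.items()
        }

    def import_safetensors_metadata(metadata):
        return {key: yaml.safe_load(value) for key, value in metadata.items()}

    metadata = {
        "format": "pt",                                # plain string, stored as-is
        "state_shard_names": ["weights", "momentum"],  # list, YAML-encoded as "- weights\n- momentum\n"
        "optimizer_state": True,                       # bool, stored as "True" and parsed back by YAML
    }
    exported = export_safetensors_metadata(metadata)
    assert all(isinstance(value, str) for value in exported.values())
    assert import_safetensors_metadata(exported) == metadata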


class CheckpointFormat(str):
# Distributed checkpoint for fast checkpointing and resuming.
distributed = "distributed"
# Model state dict, for safe long-term storage in Fast-LLM format.
state_dict = "state_dict"
# A checkpoint format external to Fast-LLM.
external = "external"


class ModelConfigType(str, enum.Enum):
@@ -57,16 +79,11 @@ class CheckpointPathConfigBase(Config):
@config_class()
class CheckpointConfigBase(Config):
_abstract = True
format: CheckpointFormat = Field(
format: str = Field(
default=CheckpointFormat.distributed,
desc="Format of the checkpoint.",
hint=FieldHint.core,
)
model_type: str | None = Field(
default=None,
desc="Model type for external models (ex. Huggingace model name).",
hint=FieldHint.feature,
)

@classmethod
def _from_dict(
@@ -76,10 +93,17 @@ def _from_dict(
flat: bool = False,
):
# TODO v0.2: Remove.
if default.get("format", None) == "huggingface":
warnings.warn(f"`huggingface` checkpoint format has been renamed to `external`.")
default["format"] = CheckpointFormat.external.value
cls._handle_renamed_field(default, "imported_type", "model_type")
if "model_type" in default:
warnings.warn(
"`CheckpointConfigBase.model_type` is deprecated."
" Instead, use the model name directly as the checkpoint format."
)
if default.get("format", None) in ("huggingface", "external"):
default["format"] = default.get("model_type")
if default["format"] is None:
default["format"] = "auto"
del default["model_type"]
return super()._from_dict(default, strict, flat)
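
To make the compatibility path concrete, here is a standalone copy of the shim above applied to a few old-style configs (illustration only, not part of the diff; "llama" stands in for whatever model names the converter map actually registers):

    import warnings

    def migrate_checkpoint_config(default: dict) -> dict:
        # Same logic as the deprecation branch in `_from_dict` above.
        if "model_type" in default:
            warnings.warn(
                "`CheckpointConfigBase.model_type` is deprecated."
                " Instead, use the model name directly as the checkpoint format."
            )
            if default.get("format", None) in ("huggingface", "external"):
                default["format"] = default.get("model_type")
                if default["format"] is None:
                    default["format"] = "auto"
            del default["model_type"]
        return default

    # Old form: `format` says "external"/"huggingface", `model_type` names the model.
    assert migrate_checkpoint_config({"format": "external", "model_type": "llama"}) == {"format": "llama"}
    # No model type given: fall back to the "auto" converter.
    assert migrate_checkpoint_config({"format": "huggingface", "model_type": None}) == {"format": "auto"}
    # Native formats pass through untouched.
    assert migrate_checkpoint_config({"format": "distributed"}) == {"format": "distributed"}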


@@ -151,8 +175,24 @@ def compare_log_fn(self):
class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase):
_abstract = False

def _validate(self):
super()._validate()
if self.format == CheckpointFormat.external:
# TODO: Support optimizer?
assert not self.optimizer_state

class Converter(abc.ABC):
# TODO: Rename? (Checkpointer? Saver?)

def __init__(self, model: "FastLLMModel"):
self._model = model

# TODO: save_metadata?

@classmethod
@abc.abstractmethod
def load_metadata(cls, config: CheckpointLoadMetadataConfig):
pass

@abc.abstractmethod
def save(self, config: CheckpointSaveConfig, metadata: dict):
pass

@abc.abstractmethod
def load(self, config: CheckpointLoadConfig, metadata: dict):
pass
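
For orientation, a minimal sketch of what implementing this interface looks like (editor's illustration, not part of the commit): a toy converter that only handles the metadata file. A real implementation, DistributedConverter, follows in the next file.

    import yaml

    class MetadataOnlyConverter(Converter):
        """Toy converter showing the shape of the interface; stores metadata only."""

        @classmethod
        def load_metadata(cls, config: CheckpointLoadMetadataConfig):
            return yaml.safe_load((config.path / "metadata.yaml").open("r"))

        def save(self, config: CheckpointSaveConfig, metadata: dict):
            config.path.mkdir(parents=True, exist_ok=True)
            yaml.safe_dump(metadata, (config.path / "metadata.yaml").open("w"))

        def load(self, config: CheckpointLoadConfig, metadata: dict):
            # A real converter would copy tensor data into `self._model` here.
            raise NotImplementedError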
93 changes: 93 additions & 0 deletions fast_llm/engine/checkpoint/distributed.py
@@ -0,0 +1,93 @@
import logging

import safetensors.torch
import torch
import yaml

from fast_llm.engine.checkpoint.config import (
CheckpointLoadConfig,
CheckpointLoadMetadataConfig,
CheckpointSaveConfig,
Converter,
ModelConfigType,
export_safetensors_metadata,
)
from fast_llm.engine.checkpoint.safe_load import SafeLoad
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)


class DistributedConverter(Converter):

@classmethod
def load_metadata(cls, config: CheckpointLoadMetadataConfig):
return yaml.safe_load((config.path / "metadata.yaml").open("r"))

def save(self, config: CheckpointSaveConfig, metadata: dict):
if self._model.distributed_config.rank == 0:
yaml.safe_dump(metadata, (config.path / "metadata.yaml").open("w"))
num_shards = len(self._model.state_shard_names) if config.optimizer_state else 1
safetensors.torch.save_file(
tensors={"state_shard": self._model.state_shard[:num_shards]},
filename=config.path / f"rank_{self._model.distributed_config.rank}.safetensors",
metadata=export_safetensors_metadata(metadata),
)

def load(self, config: CheckpointLoadConfig, metadata: dict):
# TODO: More safety checks
loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm})
loaded_config = self._model.config_class.from_metadata(loaded_config_dict, metadata)
num_shards = self._model.num_state_shards if config.optimizer_state else 1
Assert.eq(metadata["state_shard_names"][:num_shards], list(self._model.state_shard_names[:num_shards]))

if (
loaded_config.to_serialized(verbose=None) == self._model.fast_llm_config.to_serialized(verbose=None)
and config.optimizer_state
):
logger.info("Checkpoint format matches, using fast load")
# TODO: Add version without optimizer state?
with safetensors.safe_open(
config.path / f"rank_{self._model.distributed_config.rank}.safetensors",
framework="pt",
device=str(self._model.distributed.device),
) as f:
# TODO: Does this copy twice?
self._model.state_shard[:num_shards].copy_(f.get_slice("state_shard")[:num_shards])
else:
logger.info("Checkpoint format doesn't match, using safe load")
self._model.base_model_config.compare_architecture(loaded_config.base_model, config.compare_log_fn)
with SafeLoad(self._model, num_shards=num_shards) as context:
for rank in range(loaded_config.distributed.world_size):
loaded_model = self._model.__class__(
loaded_config.to_copy({("distributed", "rank"): rank}),
optimizer_state_names=self._model.state_shard_names[1:num_shards],
verbose=False,
)
path = config.path / f"rank_{rank}.safetensors"
logger.info(f"Loading from {path}")
# TODO: skip shards without overlap.
with safetensors.safe_open(path, framework="pt", device=str(self._model.distributed.device)) as f:
# TODO: Use self_shard
loaded_shard = f.get_slice("state_shard")[:num_shards]
loaded_model.state_shard_meta.validate(loaded_shard)

# TODO: Improve num shard selection.
self_shard_split = self._model.state_shard[: loaded_shard.size(0)].split(
self._model.stage_shard_sizes, 1
)
loaded_shard_split = loaded_shard.split(loaded_model.stage_shard_sizes, 1)

counter = torch.zeros(1, dtype=torch.int64, device=self._model.distributed.device)
for loaded_shard_index, loaded_stage in enumerate(loaded_model.stages_on_device.values()):
loaded_shards = (
loaded_shard_split[loaded_shard_index].to(self._model.distributed.device).unbind(0)
)
for self_shard_index, self_stage in enumerate(self._model.stages_on_device.values()):
self_stage._copy_shard_overlaps( # noqa
loaded_stage,
self_shard_split[self_shard_index].unbind(0),
loaded_shards,
counter,
)
context.mark_as_loaded(counter.item())
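
A hedged sketch of the on-disk layout this converter appears to produce and how to peek at it (editor's illustration; the checkpoint path is made up): rank 0 writes metadata.yaml, and every rank writes a rank_{rank}.safetensors file holding a single "state_shard" tensor whose first shard is the weights and whose remaining shards, if saved, are optimizer state.

    import pathlib

    import yaml
    from safetensors import safe_open

    checkpoint_dir = pathlib.Path("/checkpoints/step_1000")  # assumed location

    # Written once by rank 0.
    metadata = yaml.safe_load((checkpoint_dir / "metadata.yaml").open("r"))
    print(metadata.get("state_shard_names"))

    # One file per rank; shard 0 is the weight shard, later shards the optimizer state.
    with safe_open(checkpoint_dir / "rank_0.safetensors", framework="pt", device="cpu") as f:
        shard = f.get_slice("state_shard")
        print(shard.get_shape())  # expected: [num_shards, shard_size]
        print(f.metadata())       # stringified copy of the metadata (see export_safetensors_metadata)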
87 changes: 58 additions & 29 deletions fast_llm/engine/checkpoint/external.py
@@ -9,7 +9,14 @@
import torch

from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig
from fast_llm.engine.checkpoint.config import (
CHECKPOINT_VERSION,
CheckpointLoadConfig,
CheckpointLoadMetadataConfig,
CheckpointSaveConfig,
)
from fast_llm.engine.checkpoint.state_dict import StateDictConverter
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.tensor import SafeTensorSlice
from fast_llm.utils import Assert

@@ -139,13 +146,12 @@ def import_weight(


class ExternalStateDictConverter(StateDictConverter):
base_file_name = "model"
_base_model_cls: type[BaseModelConfig]
_config_converters: list[ParamConverter]

def __init__(self, config: BaseModelArchitectureConfig):
self.config = config
Assert.custom(isinstance, config, self._base_model_cls.architecture_cls)
def __init__(self, model: "FastLLMModel"):
super().__init__(model)
Assert.custom(isinstance, self._model.base_model_config, self._base_model_cls.architecture_cls)
weight_converters = self._create_weight_converters()
self._export_converters = {
weight_converter.fast_llm_name[0]: weight_converter
@@ -166,17 +172,7 @@ def _create_weight_converters(self) -> list[WeightConverter]:
pass

@classmethod
@abc.abstractmethod
def load_config(cls, directory: pathlib.Path | str) -> dict[str, typing.Any]:
pass

@classmethod
@abc.abstractmethod
def save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]):
pass

@classmethod
def export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing.Any]:
def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing.Any]:
exported_config = {}
for converter in cls._get_config_converters():
value = converter.export_param(
@@ -190,7 +186,7 @@ def export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing.
return exported_config # Noqa

@classmethod
def import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): # noqa
def _import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): # noqa
kwargs = {}
for converter in cls._get_config_converters():
value = converter.import_param(
@@ -204,11 +200,7 @@ def import_config(cls, config: dict[str, typing.Any], architecture_only: bool =
config_class = cls._base_model_cls.architecture_cls if architecture_only else cls._base_model_cls
return config_class.from_dict({}, kwargs)

@classmethod
def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = False):
return cls(cls.import_config(config, architecture_only=architecture_only))

def convert_state_dict(
def _convert_state_dict(
self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool
) -> dict[str, torch.Tensor | SafeTensorSlice]:
out_state_dict = {}
@@ -262,19 +254,56 @@ class AutoStateDictConverter(ExternalStateDictConverter, abc.ABC):
converter_map: dict[str, type[ExternalStateDictConverter]]

@classmethod
def import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False):
return cls.converter_map[config["model_type"]].import_config(config, architecture_only)
def get_converter_class(cls, format: str):
if format in cls.converter_map:
return cls.converter_map[format]
elif format == "auto":
return cls
else:
raise NotImplementedError(format)

# TODO: load_metadata???

@classmethod
def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = False):
return cls.converter_map[config["model_type"]].from_config(config, architecture_only)
def _import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False):
# TODO: ???
return cls.converter_map[config["model_type"]]._import_config(config, architecture_only)


class HuggingfaceStateDictConverter(ExternalStateDictConverter, abc.ABC):
model_type: str | None = None
base_file_name = "model"

@classmethod
def load_metadata(cls, config: CheckpointLoadMetadataConfig):
imported_model_config = cls._import_config(cls._load_config(config.path), True)
return {
# TODO: Avoid `to_serialized`?
"fast_llm_config": {"base_model": imported_model_config.to_serialized()},
# TODO: Handle "auto"?
"checkpoint_type": config.format,
"checkpoint_version": CHECKPOINT_VERSION,
}

def save(self, config: CheckpointSaveConfig, metadata: dict):
huggingface_config = self._export_config(self._model.base_model_config)
self._save_config(config.path, huggingface_config)
metadata = {
"fast_llm_metadata": metadata,
"model_config": huggingface_config,
"format": "pt",
}
super().save(config, metadata)

def load(self, config: CheckpointLoadConfig, metadata: dict):
assert not config.optimizer_state
self._model.base_model_config.compare_architecture(
self._base_model_cls.from_dict(metadata["fast_llm_config"]["base_model"]), config.compare_log_fn
)
super().load(config, metadata)

@classmethod
def get_key(cls, parameter_name: str, shard_name: str) -> str:
def _get_key(cls, parameter_name: str, shard_name: str) -> str:
Assert.eq(shard_name, "weights")
return parameter_name

@@ -284,7 +313,7 @@ def _create_config_converters(cls) -> list[ParamConverter]:
return [ConstantExportParamConverter(None, "model_type", cls.model_type)]

@classmethod
def load_config(cls, directory: pathlib.Path | str):
def _load_config(cls, directory: pathlib.Path | str):
import transformers

config = transformers.AutoConfig.from_pretrained(directory).to_dict()
@@ -293,12 +322,12 @@ def load_config(cls, directory: pathlib.Path | str):
return config

@classmethod
def save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]):
def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]):
import transformers

transformers.CONFIG_MAPPING[config["model_type"]].from_dict(config).save_pretrained(directory)

def load_weights(
def _load_weights(
self,
directory: pathlib.Path | str,
device,