Checkpoint metadata (ServiceNow#28)
jlamypoirier authored Oct 31, 2024
1 parent 120c89c commit 519e9cb
Showing 19 changed files with 287 additions and 199 deletions.
1 change: 1 addition & 0 deletions fast_llm/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.1.0"
2 changes: 1 addition & 1 deletion fast_llm/engine/base_model/base_model.py
@@ -89,7 +89,7 @@ def __init__(

     @classmethod
     def architecture_cls(cls):
-        return cls.config_cls.architecture_cls
+        return cls.config_cls.architecture_class

     @abc.abstractmethod
     def get_layers(self):
4 changes: 2 additions & 2 deletions fast_llm/engine/base_model/config.py
@@ -36,7 +36,7 @@ class BaseModelConfig(BaseModelArchitectureConfig):
     # TODO: Find better name?
     """

-    architecture_cls: typing.ClassVar[type[BaseModelArchitectureConfig]]
+    architecture_class: typing.ClassVar[type[BaseModelArchitectureConfig]]

     def get_architecture(self):
-        return self.architecture_cls.from_dict(self, strict=False)
+        return self.architecture_class.from_dict(self, strict=False)
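
The rename above is part of a pattern: the full model config exposes its architecture-only subset through the architecture_class class variable and re-parses itself through it with strict=False. A minimal, self-contained sketch of that pattern, using toy dataclasses rather than the real fast_llm Config machinery (all names below are illustrative):

import dataclasses
import typing


@dataclasses.dataclass
class ToyArchitectureConfig:
    hidden_size: int = 1024
    num_layers: int = 12

    @classmethod
    def from_dict(cls, data: dict, strict: bool = True) -> "ToyArchitectureConfig":
        # Keep only the fields this class knows about; with strict=False, extras are dropped.
        known = {field.name for field in dataclasses.fields(cls)}
        unknown = set(data) - known
        if strict and unknown:
            raise ValueError(f"Unknown fields: {unknown}")
        return cls(**{key: value for key, value in data.items() if key in known})


@dataclasses.dataclass
class ToyModelConfig(ToyArchitectureConfig):
    # Mirrors the renamed `architecture_class` class variable from the diff above.
    architecture_class: typing.ClassVar[type[ToyArchitectureConfig]] = ToyArchitectureConfig
    dropout: float = 0.1  # training-only knob, not part of the architecture

    def get_architecture(self) -> ToyArchitectureConfig:
        # The real Config.from_dict accepts a config object; here we pass a plain dict.
        return self.architecture_class.from_dict(dataclasses.asdict(self), strict=False)


print(ToyModelConfig(hidden_size=2048).get_architecture())
# ToyArchitectureConfig(hidden_size=2048, num_layers=12)
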
77 changes: 51 additions & 26 deletions fast_llm/engine/checkpoint/config.py
@@ -8,20 +8,16 @@

 import yaml

-from fast_llm.config import Config, Field, FieldHint, check_field, config_class
+from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class
 from fast_llm.engine.config_utils.data_type import DataType
 from fast_llm.utils import Assert

 if typing.TYPE_CHECKING:
-    from fast_llm.engine.multi_stage.config import FastLLMModelConfig
+    from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig
     from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel

 logger = logging.getLogger(__name__)

-# TODO: Use packaging.version? (Safer but extra requirement)
-CHECKPOINT_VERSION = "0.1"
-KNOWN_CHECKPOINT_VERSIONS = ("0", "0.1")
-

 def export_safetensors_metadata(metadata):
     """
@@ -100,16 +96,6 @@ def load_fast_llm(self):
         return self == ModelConfigType.fast_llm


-@config_class()
-class CheckpointPathConfigBase(Config):
-    _abstract = True
-    path: pathlib.Path | None = Field(
-        default=None,
-        desc="Location of the checkpoint.",
-        hint=FieldHint.core,
-    )
-
-
 @config_class()
 class CheckpointConfigBase(Config):
     _abstract = True
@@ -148,10 +134,11 @@ def _from_dict(


 @config_class()
-class CheckpointStateConfigBase(Config):
+class CheckpointStateConfigBase(CheckpointConfigBase):
     _abstract = True
-    model_weights: bool = Field(default=True, desc="Save/load the model weights.", hint=FieldHint.feature)
-    optimizer_state: bool = Field(default=False, desc="Save/load the optimizer state.", hint=FieldHint.feature)
+    # Defaults and descriptions are set in derived classes.
+    model_weights: bool = Field(default=True, hint=FieldHint.feature)
+    optimizer_state: bool = Field(default=None, hint=FieldHint.feature)

     @classmethod
     def _from_dict(
@@ -166,7 +153,7 @@ def _from_dict(


 @config_class()
-class CheckpointSaveConfigBase(Config):
+class CheckpointSaveConfigBase(CheckpointConfigBase):
     _abstract = True
     parameters_per_file: int = Field(
         default=2**32,
@@ -182,17 +169,41 @@ class CheckpointSaveConfigBase(Config):


 @config_class()
-class CheckpointSaveMetadataConfig(CheckpointPathConfigBase, CheckpointConfigBase):
+class CheckpointStateSaveConfigBase(CheckpointSaveConfigBase, CheckpointStateConfigBase):
+    model_weights: bool = FieldUpdate(desc="Save the model weights.")
+    optimizer_state: bool = FieldUpdate(desc="Save the optimizer state. Default: save if supported by the `format`.")
+
+    def _validate(self):
+        if self.optimizer_state is None:
+            # TODO: Make sure it's a type
+            self.optimizer_state = self.format.support_optimizer
+        super()._validate()
+        if self.optimizer_state:
+            assert self.format.support_optimizer
+
+
+@config_class()
+class CheckpointPathConfigBase(CheckpointConfigBase):
+    _abstract = True
+    path: pathlib.Path | None = Field(
+        default=None,
+        desc="Location of the checkpoint.",
+        hint=FieldHint.core,
+    )
+
+
+@config_class()
+class CheckpointSaveMetadataConfig(CheckpointPathConfigBase):
     _abstract = False


 @config_class()
-class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateConfigBase, CheckpointSaveConfigBase):
+class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConfigBase):
     _abstract = False


 @config_class()
-class CheckpointLoadMetadataConfig(CheckpointPathConfigBase, CheckpointConfigBase):
+class CheckpointLoadMetadataConfig(CheckpointPathConfigBase):
     _abstract = False

     load_config: ModelConfigType = Field(
@@ -215,6 +226,14 @@ def compare_log_fn(self):
 class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase):
     _abstract = False

+    model_weights: bool = FieldUpdate(desc="Load the model weights.")
+    optimizer_state: bool = FieldUpdate(default=False, desc="Load the optimizer state.")
+
+    def _validate(self):
+        super()._validate()
+        if self.optimizer_state:
+            assert self.format.support_optimizer
+

 class CheckpointHandler(abc.ABC):
     format: typing.ClassVar[type[CheckpointFormat]]
@@ -226,13 +245,19 @@ def __init__(self, model: "FastLLMModel"):

     @classmethod
     @abc.abstractmethod
-    def load_metadata(cls, config: CheckpointLoadMetadataConfig):
+    def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata":
         pass

     @abc.abstractmethod
-    def save(self, config: CheckpointSaveConfig, metadata: dict):
+    def save(self, config: CheckpointSaveConfig, metadata: "CheckpointMetadata"):
         pass

     @abc.abstractmethod
-    def load(self, config: CheckpointLoadConfig, metadata: dict):
+    def load(self, config: CheckpointLoadConfig, metadata: "CheckpointMetadata"):
         pass
+
+    def get_num_shards(self, config: CheckpointStateConfigBase):
+        return len(self._model.state_shard_names) if config.optimizer_state else 1
+
+    def get_shard_names(self, config: CheckpointStateConfigBase):
+        return self._model.state_shard_names if config.optimizer_state else self._model.state_shard_names[:1]
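
The new get_num_shards and get_shard_names helpers centralize the choice between "weights only" and "weights plus optimizer state" that save and load previously duplicated. A rough stand-alone sketch of the behaviour, with a stand-in model object and hypothetical shard names rather than the real FastLLMModel:

from dataclasses import dataclass


@dataclass
class ToyModel:
    # Hypothetical shard names; the weights shard is assumed to come first,
    # followed by the optimizer-state shards.
    state_shard_names: tuple = ("weights", "exp_avg", "exp_avg_sq")


@dataclass
class ToyStateConfig:
    optimizer_state: bool = False


def get_num_shards(model: ToyModel, config: ToyStateConfig) -> int:
    # All shards when the optimizer state is included, otherwise just the weights shard.
    return len(model.state_shard_names) if config.optimizer_state else 1


def get_shard_names(model: ToyModel, config: ToyStateConfig) -> tuple:
    return model.state_shard_names if config.optimizer_state else model.state_shard_names[:1]


print(get_shard_names(ToyModel(), ToyStateConfig(optimizer_state=True)))   # ('weights', 'exp_avg', 'exp_avg_sq')
print(get_shard_names(ToyModel(), ToyStateConfig(optimizer_state=False)))  # ('weights',)
print(get_num_shards(ToyModel(), ToyStateConfig(optimizer_state=False)))   # 1
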
22 changes: 12 additions & 10 deletions fast_llm/engine/checkpoint/distributed.py
@@ -18,6 +18,7 @@
 )
 from fast_llm.engine.checkpoint.safe_load import SafeLoad
 from fast_llm.engine.config_utils.run import log_main_rank
+from fast_llm.engine.multi_stage.config import CheckpointMetadata
 from fast_llm.utils import Assert

 logger = logging.getLogger(__name__)
@@ -28,24 +29,25 @@ class DistributedCheckpointHandler(CheckpointHandler):

     @classmethod
     def load_metadata(cls, config: CheckpointLoadMetadataConfig):
-        return yaml.safe_load((config.path / "metadata.yaml").open("r"))
+        return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r")))

-    def save(self, config: CheckpointSaveConfig, metadata: dict):
+    def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata):
+        serialized_metadata = metadata.to_serialized()
         if self._model.distributed_config.rank == 0:
-            yaml.safe_dump(metadata, (config.path / "metadata.yaml").open("w"))
-        num_shards = len(self._model.state_shard_names) if config.optimizer_state else 1
+            yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w"))
         safetensors.torch.save_file(
-            tensors={"state_shard": self._model.state_shard[:num_shards]},
+            tensors={"state_shard": self._model.state_shard[: self.get_num_shards(config)]},
             filename=config.path / f"rank_{self._model.distributed_config.rank}.safetensors",
-            metadata=export_safetensors_metadata(metadata),
+            metadata=export_safetensors_metadata(serialized_metadata),
         )

-    def load(self, config: CheckpointLoadConfig, metadata: dict):
+    def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata):
         # TODO: More safety checks
         loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm})
         loaded_config = self._model.config_class.from_metadata(loaded_config_dict, metadata)
-        num_shards = self._model.num_state_shards if config.optimizer_state else 1
-        Assert.eq(metadata["state_shard_names"][:num_shards], list(self._model.state_shard_names[:num_shards]))
+        num_shards = self.get_num_shards(config)
+        shard_names = self.get_shard_names(config)
+        Assert.eq(metadata.shards[:num_shards], list(shard_names))

         same_format = (
             loaded_config.to_serialized(verbose=None) == self._model.fast_llm_config.to_serialized(verbose=None)
@@ -72,7 +74,7 @@ def load(self, config: CheckpointLoadConfig, metadata: dict):
            for rank in range(loaded_config.distributed.world_size):
                loaded_model = self._model.__class__(
                    loaded_config.to_copy({("distributed", "rank"): rank}),
-                    optimizer_state_names=self._model.state_shard_names[1:num_shards],
+                    optimizer_state_names=shard_names[1:],
                    verbose=False,
                )
                path = config.path / f"rank_{rank}.safetensors"
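
The handler now writes a serialized CheckpointMetadata object to metadata.yaml on save and rebuilds it with CheckpointMetadata.from_dict on load, instead of passing raw dicts around. A minimal sketch of that round-trip, with a plain dict and made-up keys standing in for the real to_serialized()/from_dict() schema:

import pathlib
import tempfile

import yaml

# Hypothetical serialized form; the real schema lives in CheckpointMetadata.
serialized_metadata = {"format": "distributed", "shards": ["weights", "exp_avg", "exp_avg_sq"]}

with tempfile.TemporaryDirectory() as checkpoint_dir:
    path = pathlib.Path(checkpoint_dir) / "metadata.yaml"
    # save(): rank 0 dumps the serialized metadata next to the safetensors shards.
    with path.open("w") as stream:
        yaml.safe_dump(serialized_metadata, stream)
    # load_metadata(): read it back and rebuild the metadata object (here, just the dict).
    with path.open("r") as stream:
        loaded = yaml.safe_load(stream)
    assert loaded == serialized_metadata
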