Fix things, prepare for v0.2 #85

Merged Dec 6, 2024 · 3 commits

Changes from 1 commit

Fix things, prepare for v0.2
jlamypoirier committed Dec 4, 2024
commit 1726b8278aa86a4571d6a97b7f806fc0367b0c93
8 changes: 4 additions & 4 deletions fast_llm/config.py
@@ -36,7 +36,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):


class _ConfigDictFormat(str, enum.Enum):
-# TODO v0.2: delete class
+# TODO v0.3: delete class
flat = "flat"
nested = "nested"
tuple = "tuple"
@@ -681,7 +681,7 @@ def from_flat_dict(
default: dict[str, typing.Any],
strict: bool = True,
):
-# TODO v0.2: Remove flat format
+# TODO v0.3: Remove flat format
return cls._from_dict(default, strict, True)

@classmethod
@@ -692,10 +692,10 @@ def _from_dict(
flat: bool = False,
):
cls._check_abstract()
-# TODO v0.2: Remove flat format
+# TODO v0.3: Remove flat format
out_arg_dict = {}

-# TODO v0.2: Remove backward compatibility fix
+# TODO v0.3: Remove backward compatibility fix
if "__class__" in default:
del default["__class__"]

3 changes: 2 additions & 1 deletion fast_llm/core/distributed.py
@@ -40,13 +40,14 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name:
# A utility function to check for tensor-parallel (or other) mismatches.
all_tensors = tensor.new_empty((group.size(),) + tensor.shape)
all_gather_into_tensor(all_tensors, tensor, group)

mismatches = (all_tensors != tensor).any(dim=0)
num_mismatches = mismatches.sum().item()
if num_mismatches > 0:
num_nans = tensor.isnan().sum().item()
logger.error(
f"MISMATCH {name} {num_mismatches:,} / {tensor.numel():,}"
-+ ("" if num_nans > 0 else f" [{num_mismatches:,} nans detected locally]")
++ ("" if num_nans == 0 else f" [{num_nans:,} nans detected locally]")
)


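For reference, the corrected report now prints the nan note only when nans are actually present, and quotes the nan count instead of the mismatch count. A minimal standalone sketch of the fixed message assembly, with made-up values (not part of the diff):

```python
# Illustrative only: hypothetical values, mirroring the corrected logging expression.
name, num_mismatches, numel, num_nans = "layer 3 fw", 7, 4096, 2

message = f"MISMATCH {name} {num_mismatches:,} / {numel:,}" + (
    "" if num_nans == 0 else f" [{num_nans:,} nans detected locally]"
)
print(message)  # MISMATCH layer 3 fw 7 / 4,096 [2 nans detected locally]
```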
2 changes: 1 addition & 1 deletion fast_llm/data/gpt/data.py
@@ -151,7 +151,7 @@ def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int
seed=self._distributed_config.seed,
cache_directory=(
self._dataset_prefixes[name].parent
-if self._cache_directory is None
+if self._cache_directory is None and isinstance(self._dataset_prefixes[name], pathlib.Path)
else self._cache_directory
),
verbose=self._num_datasets <= 5,
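The added `isinstance` check means the per-dataset parent directory is used as a cache location only when the dataset prefix is an actual filesystem path; otherwise the configured cache directory (possibly None) is kept. A standalone sketch of the same rule, with a hypothetical helper name:

```python
import pathlib


def resolve_cache_directory(cache_directory: pathlib.Path | None, dataset_prefix) -> pathlib.Path | None:
    # Fall back to the dataset's parent directory only for real path prefixes.
    if cache_directory is None and isinstance(dataset_prefix, pathlib.Path):
        return dataset_prefix.parent
    return cache_directory


# A path-like prefix yields its parent as the cache location.
assert resolve_cache_directory(None, pathlib.Path("/data/wiki/shard_0")) == pathlib.Path("/data/wiki")
# A non-path prefix (e.g. a dataset name) leaves the cache directory unset.
assert resolve_cache_directory(None, "some-remote-dataset") is None
```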
24 changes: 1 addition & 23 deletions fast_llm/engine/checkpoint/config.py
@@ -4,7 +4,6 @@
import logging
import pathlib
import typing
-import warnings

import yaml

@@ -56,7 +55,7 @@ def __fast_llm_serialize__(cls):


class DistributedCheckpointFormat(CheckpointFormat):
-# TODO v0.2: Add `enforce_version_match`
+# TODO v0.3: Add `enforce_version_match`
name: typing.ClassVar[str] = "distributed"
enforce_architecture_match: typing.ClassVar[bool] = True

@@ -120,27 +119,6 @@ def setup(self, model_config: typing.Union["FastLLMModelConfig", type["FastLLMMo
else:
self.format = format

-@classmethod
-def _from_dict(
-cls,
-default: dict[str, typing.Any],
-strict: bool = True,
-flat: bool = False,
-):
-# TODO v0.2: Remove.
-cls._handle_renamed_field(default, "imported_type", "model_type")
-if "model_type" in default:
-warnings.warn(
-"`CheckpointConfigBase.model_type` is deprecated."
-" Instead, use the model name directly as the checkpoint format."
-)
-if default.get("format", None) in ("huggingface", "external"):
-default["format"] = default.get("model_type")
-if default["format"] is None:
-default["format"] = "auto"
-del default["model_type"]
-return super()._from_dict(default, strict, flat)


@config_class()
class CheckpointStateConfigBase(CheckpointConfigBase):
3 changes: 1 addition & 2 deletions fast_llm/engine/checkpoint/external.py
@@ -194,7 +194,7 @@ def _load_config(cls, directory: pathlib.Path | str) -> dict:

@classmethod
def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing.Any]:
-# TODO v0.2: not used in this class
+# TODO v0.3: not used in this class
exported_config = {}
for converter in cls._get_config_converters():
value = converter.export_param(
@@ -211,7 +211,6 @@ def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing
def _import_config(
cls, config: dict[str, typing.Any], architecture_only: bool = False
) -> BaseModelArchitectureConfig: # noqa
-# TODO v0.2: not used in this class
kwargs = {}
for converter in cls._get_config_converters():
value = converter.import_param(
16 changes: 3 additions & 13 deletions fast_llm/engine/checkpoint/state_dict.py
@@ -1,5 +1,4 @@
import abc
-import json
import logging
import typing

@@ -114,18 +113,9 @@ class FastLLMCheckpointHandler(StateDictCheckpointHandler):

@classmethod
def load_metadata(cls, config: CheckpointLoadMetadataConfig):
-# TODO v0.2: Remove backward compatibility.
-old_path = config.path / "state_dict.safetensors.index.json"
-if old_path.is_file():
-logger.warning(f"Loading metadata from {old_path} (deprecated format)")
-serialized_metadata = json.load((config.path / "state_dict.safetensors.index.json").open("r"))
-metadata = CheckpointMetadata.from_dict(serialized_metadata)["metadata"]
-metadata.metadata["index"] = serialized_metadata["weight_map"]
-else:
-path = config.path / f"metadata.yaml"
-logger.warning(f"Loading metadata from {path}")
-metadata = CheckpointMetadata.from_dict(yaml.safe_load(path.open("r")))
-return metadata
+path = config.path / f"metadata.yaml"
+logger.warning(f"Loading metadata from {path}")
+return CheckpointMetadata.from_dict(yaml.safe_load(path.open("r")))
Reviewer comment (Collaborator):
@charlesGE can you confirm that we don't need this anymore please?


def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, serialized_metadata: dict, index: dict):
path = config.path / f"metadata.yaml"
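With the deprecated `state_dict.safetensors.index.json` fallback gone, metadata comes from `metadata.yaml` only. A minimal sketch of the remaining load path as a standalone helper (hypothetical name, assuming the checkpoint layout shown above):

```python
import pathlib

import yaml


def load_checkpoint_metadata(checkpoint_path: pathlib.Path) -> dict:
    # Only the metadata.yaml layout is supported after this change.
    with (checkpoint_path / "metadata.yaml").open("r") as f:
        return yaml.safe_load(f)
```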
2 changes: 1 addition & 1 deletion fast_llm/engine/config_utils/run.py
@@ -23,7 +23,7 @@ class RunConfig(Config):
tensor_logs: TensorLogsConfig = Field(
default_factory=TensorLogsConfig, desc="Configuration for debug tensor logs.", hint=FieldHint.logging
)
-# TODO v0.2: Adjust (now only affects logging to file).
+# TODO v0.3: Adjust (now only affects logging to file).
structured_logs: bool = Field(
default=True, desc="Configure logging to the Fast-LLM format.", hint=FieldHint.logging
)
24 changes: 19 additions & 5 deletions fast_llm/engine/distributed/config.py
@@ -113,14 +113,15 @@ class DistributedDimNames:
pipeline = "pipeline"
sequence_data = "sequence_data"
batch_data = "batch_data"
+tensor_and_sequence_data = "tensor_and_sequence_data"


@config_class()
class DistributedConfig(Config):
"""
Configuration for the distributed setup.
Also include variables for global settings such as data types, random seeds, initialization parameters.
-TODO v0.2: Move these unrelated variables elsewhere.
+TODO v0.3: Move these unrelated variables elsewhere.
TODO: Avoid hard-coding distributed dims (use derived class?)
TODO: Separate distributed space from config?
"""
@@ -197,19 +198,19 @@ class DistributedConfig(Config):
valid=check_field(Assert.gt, 0),
)
seed: int = Field(default=1234, desc="A seed for training.", hint=FieldHint.optional)
-# TODO v0.2: Rename to compute_dtype (not just for training), move elsewhere
+# TODO v0.3: Rename to compute_dtype (not just for training), move elsewhere
training_dtype: DataType = Field(
default=DataType.float32,
desc="The data type used for the forward and backward passes.",
hint=FieldHint.core,
)
-# TODO v0.2: move elsewhere
+# TODO v0.3: move elsewhere
optimization_dtype: DataType = Field(
default=DataType.float32,
desc="The data type used for the optimizer.",
hint=FieldHint.expert,
)
-# TODO v0.2: move random state elsewhere
+# TODO v0.3: move random state elsewhere
# Extra seed parameters (can usually be left alone)
dp_seed_shift: int = Field(
default=_BIG_PRIMES[0], desc="Seed shift for extra randomness.", hint=FieldHint.optional
@@ -330,6 +331,19 @@ def _validate(self):
parent=DistributedDimNames.data,
)
)
+self.add_distributed_dim(
+DistributedDim(
+name=DistributedDimNames.tensor_and_sequence_data,
+size=self.sequence_data_parallel * self.tensor_parallel,
+rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
+id_=f"{self.batch_data_rank}_{self.pipeline_rank}",
+parent=(
+DistributedDimNames.tensor
+if self.sequence_data_parallel == 1
+else DistributedDimNames.sequence_data if self.tensor_parallel == 1 else DistributedDimNames.world
+),
+)
+)

super()._validate()

@@ -361,7 +375,7 @@ def _from_dict(
strict: bool = True,
flat: bool = False,
):
-# TODO v0.2: Remove backward compatibility fix
+# TODO v0.3: Remove backward compatibility fix
if "sequence_first" in default and strict:
del default["sequence_first"]
if "separate_init_generators" in default and strict:
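The new `tensor_and_sequence_data` dimension combines the tensor and sequence-data groups, with the tensor rank varying fastest inside each sequence-data slice. A small illustrative sketch of the rank formula from the diff (hypothetical helper, assumed example sizes):

```python
def tensor_and_sequence_data_rank(tensor_rank: int, sequence_data_rank: int, tensor_parallel: int) -> int:
    # Mirrors the diff: tensor ranks are contiguous within each sequence-data slice.
    return tensor_rank + sequence_data_rank * tensor_parallel


# With tensor_parallel=2 and sequence_data_parallel=3, the combined group has 6 ranks:
layout = {
    (t, s): tensor_and_sequence_data_rank(t, s, tensor_parallel=2)
    for s in range(3)
    for t in range(2)
}
assert layout == {(0, 0): 0, (1, 0): 1, (0, 1): 2, (1, 1): 3, (0, 2): 4, (1, 2): 5}
```

The `parent` selection in the diff follows the same idea: when one of the two factors is 1, the combined group reduces to the other group; otherwise it falls back to the world dimension.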
31 changes: 23 additions & 8 deletions fast_llm/engine/distributed/distributed.py
@@ -4,7 +4,13 @@
import torch
import torch.distributed

-from fast_llm.engine.distributed.config import MAX_SEED, DistributedConfig, DistributedDim, PhaseType
+from fast_llm.engine.distributed.config import (
+MAX_SEED,
+DistributedConfig,
+DistributedDim,
+DistributedDimNames,
+PhaseType,
+)
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)
@@ -49,12 +55,14 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
Assert.eq(distributed_dim.name, name)
self.add_group(distributed_dim)

-self.world_group = self._process_groups["world"]
-self.data_group = self._process_groups["data"]
-self.pipeline_group = self._process_groups["pipeline"]
-self.tensor_group = self._process_groups["tensor"]
-self.sequence_data_group = self._process_groups["sequence_data"]
-self.batch_data_group = self._process_groups["batch_data"]
+self.world_group = self._process_groups[DistributedDimNames.world]
+self.data_group = self._process_groups[DistributedDimNames.data]
+self.pipeline_group = self._process_groups[DistributedDimNames.pipeline]
+self.tensor_group = self._process_groups[DistributedDimNames.tensor]
+self.sequence_data_group = self._process_groups[DistributedDimNames.sequence_data]
+self.batch_data_group = self._process_groups[DistributedDimNames.batch_data]
+self.tensor_and_sequence_data_group = self._process_groups[DistributedDimNames.tensor_and_sequence_data]

self._config.log_first_rank(f"Setting random seeds...")

dp_shift = self._config.dp_seed_shift * self._config.data_rank
@@ -139,8 +147,15 @@ def add_group(self, distributed_dim: DistributedDim):
def set_step(self, step: int, phase: PhaseType):
"""
Reseed pytorch for a given training step.
-TODO v0.2: Move unrelated content elsewhere.
+TODO v0.3: Move unrelated content elsewhere.
"""
seed_shift = step * self._config.sample_seed_shift + self._phase_seeds_shifts[phase]
self.pp_generator.manual_seed((self._pp_seed + seed_shift) % MAX_SEED)
self.tp_generator.manual_seed((self._tp_seed + seed_shift) % MAX_SEED)

+def __del__(self):
+# Shutdown the process group backend explicitly to prevent a nccl warning.
+# We can't call `destroy_process_group` directly because pytorch doesn't know about it.
+for group in self._process_groups.values():
+if group is not None and hasattr(group, "_shutdown"):
+group._shutdown()  # noqa
2 changes: 1 addition & 1 deletion fast_llm/engine/huggingface/config.py
@@ -78,7 +78,7 @@ def _get_config_dict(
)

config_dict = {"fast_llm_config": fast_llm_config}
-# TODO v0.2: ???
+# TODO v0.3: ???
if "huggingface_config" in metadata:
assert "fast_llm_config" not in metadata["huggingface_config"]
config_dict.update(metadata.pop("huggingface_config"))
12 changes: 4 additions & 8 deletions fast_llm/engine/multi_stage/config.py
@@ -223,9 +223,6 @@ def get_checkpoint_format(cls, format: typing.Union[type[CheckpointFormat], str]
format_ = cls.get_checkpoint_format(format.name)
Assert.is_(format, format_)
return format_
-# TODO v0.2: Remove backward compatibility.
-if format == "state_dict":
-format = "fast_llm"
for format_ in cls.checkpoint_formats:
if format_.name == format:
return format_
@@ -247,7 +244,6 @@ def get_huggingface_model_class(cls) -> type["HuggingfacePreTrainedModel"]:

@classmethod
def get_base_model_config_class(cls) -> type[BaseModelConfig]:
-# TODO v0.2: Still needed?
return cls.get_field("base_model").type

@classmethod
@@ -270,8 +266,8 @@ def from_metadata(
updates: dict[str | tuple[str, ...], typing.Any] | None = None,
):
# TODO: Standardize to *updates?
-# TODO v0.2: Update, remove support for older checkpoints.
-if metadata.fast_llm_version.major != 0 or metadata.fast_llm_version.minor not in (0, 1):
+# TODO v0.3: Update, remove support for older checkpoints.
+if metadata.fast_llm_version.major != 0 or metadata.fast_llm_version.minor not in (0, 1, 2):
raise ValueError(f"Invalid checkpoint version: {metadata.fast_llm_version}")
pretrained_config = cls.from_dict(metadata.config)
if not pretrained.load_config.load_architecture:
@@ -424,7 +420,7 @@ def _from_dict(
strict: bool = True,
flat: bool = False,
):
-# TODO v0.2: Remove backward compatibility.
+# TODO v0.3: Remove backward compatibility.
cls._handle_renamed_field(default, "checkpoint_type", "format")
cls._handle_renamed_field(default, "checkpoint_version", "fast_llm_version")
cls._handle_renamed_field(default, "fast_llm_config", "config")
@@ -445,7 +441,7 @@
model_config_class = model_registry[model_config_class]
default["model"] = model_config_class

-# TODO v0.2: Remove backward compatibility.
+# TODO v0.3: Remove backward compatibility.
if "config" not in default:
default["config"] = {
"base_model": model_config_class.get_base_model_config_class().from_flat_dict(
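The relaxed check lets this release load checkpoints written by Fast-LLM 0.0, 0.1 and 0.2. A tiny sketch of the version gate in isolation (hypothetical helper, not part of the codebase):

```python
def is_supported_checkpoint_version(major: int, minor: int) -> bool:
    # Mirrors the updated from_metadata() check.
    return major == 0 and minor in (0, 1, 2)


assert is_supported_checkpoint_version(0, 2)
assert not is_supported_checkpoint_version(0, 3)
assert not is_supported_checkpoint_version(1, 0)
```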
2 changes: 1 addition & 1 deletion fast_llm/engine/multi_stage/fast_llm_model.py
@@ -56,7 +56,7 @@ def from_pretrained(
metadata = cls.config_class.load_metadata(pretrained_config)
config = cls.config_class.from_metadata(pretrained_config, metadata, default_config, config_updates)
if mode.support_training:
-# TODO v0.2: Make metadata.shards mandatory?
+# TODO v0.3: Make metadata.shards mandatory?
if metadata.shards:
if optimizer_state_names is None:
optimizer_state_names = metadata.shards[1:]
12 changes: 10 additions & 2 deletions fast_llm/engine/multi_stage/stage.py
@@ -201,7 +201,11 @@ def invalidate_buffer(self):
self._is_restored = False

def _log_layer_forward(self, output, kwargs, i):
-if self._config.debug_tensor_parallel and self._distributed.tensor_group is not None:
+if (
+self._config.debug_tensor_parallel
+and self._distributed.tensor_group is not None
+and not self._meta_outputs[i].is_tensor_parallel
+):
check_parallel_match(output, self._distributed.tensor_group, f"layer {self._layer_range[i]} fw")
if self._config.debug_layer_outputs:
name = f"layer {self._layer_range[i]} fw"
Expand All @@ -224,7 +228,11 @@ def _log_layer_forward(self, output, kwargs, i):
def _log_layer_backward(self, input_, kwargs, i):
if not input_.requires_grad:
return
-if self._config.debug_tensor_parallel and self._distributed.tensor_group is not None:
+if (
+self._config.debug_tensor_parallel
+and self._distributed.tensor_group is not None
+and not self._meta_inputs[i].is_tensor_parallel
+):
input_.register_hook(
lambda grad: check_parallel_match(
grad, self._distributed.tensor_group, f"layer {self._layer_range[i]} bw"
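Both debug paths now skip tensors that are themselves tensor-parallel: a cross-rank equality check only makes sense for tensors expected to be replicated across the tensor group, while sharded tensors legitimately differ between ranks. The guard in isolation (illustrative, attribute names taken from the diff):

```python
def should_check_tensor_parallel_match(debug_tensor_parallel: bool, tensor_group, meta) -> bool:
    # Only replicated tensors can be compared across the tensor-parallel group.
    return (
        debug_tensor_parallel
        and tensor_group is not None
        and not meta.is_tensor_parallel
    )
```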
5 changes: 2 additions & 3 deletions fast_llm/engine/training/config.py
@@ -203,16 +203,15 @@ def to_delete(self, iterations: list[int]):
class TrainingCheckpointConfig(TrainingCheckpointBaseConfig):
_abstract = False
save_name: typing.ClassVar[str] = "checkpoint"
-# TODO v0.2: Rename to `checkpoint` so we don't need this extra variable?
interval = FieldUpdate(
-desc="The number of training iterations between each checkpoint." " Setting to None will disable checkpoints."
+desc="The number of training iterations between each checkpoint. Setting to None will disable checkpoints."
)
offset = FieldUpdate(desc="Offset for the first checkpoint.")
callback: CallbackConfig = FieldUpdate(desc="Callback (shell script) to run after checkpoint.")
keep: int | None = FieldUpdate(default=5)

def get_save_directory(self, experiment_directory: pathlib.Path) -> pathlib.Path:
-# TODO v0.2: Remove backward compatibility.
+# TODO v0.3: Remove backward compatibility.
old_path = experiment_directory / "checkpoints"
new_path = experiment_directory / "checkpoint"
return old_path if old_path.is_dir() and not new_path.is_dir() else new_path