Fix and improve loading of distributed checkpoints #314

Status: Draft. Wants to merge 88 commits into base branch main.
Commits (88)
cb86f45
Test all models
jlamypoirier Jun 5, 2025
f8850e4
Parametrized dependencies
jlamypoirier Jun 6, 2025
478ac05
fixes
jlamypoirier Jun 6, 2025
d3b18a1
stuff
jlamypoirier Jun 9, 2025
8c64f03
fix
jlamypoirier Jun 9, 2025
c0f648c
fixes
jlamypoirier Jun 10, 2025
e92c311
stuff
jlamypoirier Jun 11, 2025
b877fb2
stuff
jlamypoirier Jun 11, 2025
907aef0
attempt
jlamypoirier Jun 11, 2025
1340903
attempt
jlamypoirier Jun 11, 2025
8aed0a3
Cleanup tests
jlamypoirier Jun 11, 2025
830a380
fixes
jlamypoirier Jun 11, 2025
13e1da5
fix
jlamypoirier Jun 12, 2025
aa0e821
Merge remote-tracking branch 'origin/main' into update_base_image
jlamypoirier Jun 12, 2025
45bb0ff
Merge remote-tracking branch 'origin/main' into update_base_image
jlamypoirier Jun 12, 2025
c467b63
Merge branch 'update_base_image' into test_all_models
jlamypoirier Jun 12, 2025
0dffe5c
fixes
jlamypoirier Jun 12, 2025
dcc5064
fixes
jlamypoirier Jun 12, 2025
9d415bc
fixes
jlamypoirier Jun 12, 2025
a6cce17
Merge branch 'update_base_image' into test_all_models
jlamypoirier Jun 12, 2025
68251c2
fixes
jlamypoirier Jun 12, 2025
68333ef
Merge remote-tracking branch 'origin/main' into test_all_models
jlamypoirier Jun 12, 2025
639d6c2
doc
jlamypoirier Jun 12, 2025
7465428
stuff
jlamypoirier Jun 12, 2025
ced34e0
stuff
jlamypoirier Jun 12, 2025
b328f07
stuff
jlamypoirier Jun 12, 2025
7ed804b
stuff
jlamypoirier Jun 12, 2025
890ad75
Merge branch 'improve_testing' into test_all_models
jlamypoirier Jun 12, 2025
6f00035
stuff
jlamypoirier Jun 12, 2025
e45ff6a
stuff
jlamypoirier Jun 12, 2025
8b16be2
Merge branch 'improve_testing' into test_all_models
jlamypoirier Jun 12, 2025
68db703
Merge branch 'main' into improve_testing
jlamypoirier Jun 12, 2025
e8615c2
Merge branch 'improve_testing' into test_all_models
jlamypoirier Jun 12, 2025
67d3c92
fix
jlamypoirier Jun 12, 2025
c2ae03d
fix
jlamypoirier Jun 13, 2025
31da2a8
misc
jlamypoirier Jun 13, 2025
c2ee8fe
stuff
jlamypoirier Jun 13, 2025
6c775e4
stuff
jlamypoirier Jun 13, 2025
4ba584b
Merge branch 'model_testing_configs' into test_all_models
jlamypoirier Jun 13, 2025
d41e0d5
misc
jlamypoirier Jun 13, 2025
59582c3
misc
jlamypoirier Jun 13, 2025
c0ca0b9
Merge branch 'improve_testing' into model_testing_configs
jlamypoirier Jun 13, 2025
8ecf81e
fix
jlamypoirier Jun 13, 2025
2c009a8
Merge branch 'model_testing_configs' into test_all_models
jlamypoirier Jun 13, 2025
c5b29e2
Revert "misc"
jlamypoirier Jun 13, 2025
4071b70
Merge branch 'improve_testing' into model_testing_configs
jlamypoirier Jun 13, 2025
bfa8d00
Merge branch 'model_testing_configs' into test_all_models
jlamypoirier Jun 13, 2025
edced8c
Cleanup tests
jlamypoirier Jun 13, 2025
9b904ad
Merge branch 'cleanup_tests' into test_all_models
jlamypoirier Jun 13, 2025
58677d2
fix
jlamypoirier Jun 13, 2025
8d48d1f
Merge branch 'improve_testing' into model_testing_configs
jlamypoirier Jun 13, 2025
c0b5e8e
Merge branch 'model_testing_configs' into cleanup_tests
jlamypoirier Jun 13, 2025
4171f27
Merge branch 'cleanup_tests' into test_all_models
jlamypoirier Jun 13, 2025
e125fa9
move to directory
jlamypoirier Jun 13, 2025
d61445a
Merge remote-tracking branch 'origin/main' into improve_testing
jlamypoirier Jun 16, 2025
9c5883e
Merge branch 'improve_testing' into model_testing_configs
jlamypoirier Jun 16, 2025
5a928f0
Merge branch 'model_testing_configs' into cleanup_tests
jlamypoirier Jun 16, 2025
7dc7f53
Merge branch 'cleanup_tests' into test_all_models
jlamypoirier Jun 16, 2025
d164f25
fixes
jlamypoirier Jun 16, 2025
0889d2f
Merge branch 'main' into improve_testing
jlamypoirier Jun 16, 2025
006e1ff
Merge branch 'improve_testing' into model_testing_configs
jlamypoirier Jun 16, 2025
8dc3abe
Merge remote-tracking branch 'origin/main' into model_testing_configs
jlamypoirier Jun 16, 2025
7a04c6a
Merge branch 'model_testing_configs' into cleanup_tests
jlamypoirier Jun 16, 2025
7eb4c5d
Merge remote-tracking branch 'origin/main' into cleanup_tests
jlamypoirier Jun 16, 2025
645eeb1
Merge branch 'cleanup_tests' into test_all_models
jlamypoirier Jun 16, 2025
9179127
fix
jlamypoirier Jun 16, 2025
d97e4c1
fix
jlamypoirier Jun 16, 2025
c95e8eb
Fix dropless mlp
jlamypoirier Jun 17, 2025
c4a34f0
Merge branch 'main' into cleanup_tests
jlamypoirier Jun 17, 2025
bdf37ca
Merge branch 'cleanup_tests' into test_all_models
jlamypoirier Jun 17, 2025
8667b9d
Merge branch 'test_all_models' into fix_dropless_mlp
jlamypoirier Jun 17, 2025
468ed7e
fix
jlamypoirier Jun 17, 2025
eb734bd
fix
jlamypoirier Jun 17, 2025
141ab00
Merge branch 'test_all_models' into fix_dropless_mlp
jlamypoirier Jun 17, 2025
1e16c15
stuff
jlamypoirier Jun 19, 2025
4f74237
Merge branch 'main' into test_all_models
jlamypoirier Jun 19, 2025
58d4275
Merge branch 'test_all_models' into fix_dropless_mlp
jlamypoirier Jun 19, 2025
c338d44
fixes
jlamypoirier Jun 19, 2025
cc806ef
Merge branch 'main' into cleanup_tests
jlamypoirier Jun 19, 2025
9cba39b
Merge branch 'cleanup_tests' into test_all_models
jlamypoirier Jun 19, 2025
33d5595
Merge branch 'test_all_models' into fix_dropless_mlp
jlamypoirier Jun 19, 2025
ef7ab29
Merge remote-tracking branch 'origin/main' into fix_dropless_mlp
jlamypoirier Jun 19, 2025
a26297d
Merge branch 'fix_dropless_mlp' into distributed_load
jlamypoirier Jun 19, 2025
f4a86bf
Merge remote-tracking branch 'origin/main' into distributed_load
jlamypoirier Jun 19, 2025
452397c
stuff
jlamypoirier Jun 20, 2025
0329424
fix
jlamypoirier Jun 20, 2025
ec33c6f
fixes
jlamypoirier Jun 20, 2025
2a08b14
Parallel tests
jlamypoirier Jun 24, 2025
2 changes: 1 addition & 1 deletion Dockerfile
@@ -30,7 +30,7 @@ ENV PIP_CONSTRAINT=""
# We need to compile from the repo because of https://github.com/state-spaces/mamba/issues/720 (same for causal-conv1d)
# We set the number of workers to avoid OOM when compiling on laptop. (TODO: Can we make it configurable?)
RUN MAX_JOBS=4 pip install --no-build-isolation "causal-conv1d@git+https://github.com/Dao-AILab/causal-conv1d@2a288a1"
RUN MAX_JOBS=4 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@74729d0"
RUN MAX_JOBS=4 pip install --no-build-isolation "mamba_ssm[causal-conv1d]@git+https://github.com/state-spaces/mamba@4a8a2a2"
# Copy dependency files with universal write permissions for all users.
COPY --chmod=777 setup.py setup.cfg pyproject.toml ./
COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
58 changes: 39 additions & 19 deletions fast_llm/engine/checkpoint/distributed.py
@@ -53,7 +53,7 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
loaded_metadata = self._model.config.load_metadata(config.to_copy({"load_config": ModelConfigType.fast_llm}))
shard_names = self.get_shard_names(config)
# Make sure all shards to load are in the checkpoint.
Assert.leq(set(self.get_shard_names(config)), set(loaded_metadata.shards))
Assert.leq(set(shard_names), set(loaded_metadata.shards))
Assert.eq(loaded_metadata.shards[: len(shard_names)], list(shard_names))

# Using `log_fn=bool` sets the output to true if the error list is non-empty.
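The two assertions in this hunk encode a simple contract; a minimal stand-alone illustration, using plain `assert` in place of Fast-LLM's `Assert.leq`/`Assert.eq` helpers (the shard names below are made up):

```python
# Shards this run wants to load, and shards present in the checkpoint metadata.
shard_names = ["weights", "optimizer_state"]
loaded_shards = ["weights", "optimizer_state", "grads"]

# Every requested shard must exist in the checkpoint...
assert set(shard_names) <= set(loaded_shards)
# ...and the checkpoint's shard list must start with the requested ones, in order.
assert loaded_shards[: len(shard_names)] == shard_names
```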
@@ -95,7 +95,13 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
)
path = config.path / f"rank_{rank}.safetensors"
log_main_rank(f"Loading from {path}", log_fn=logger.info)
# TODO: skip shards without overlap.

# First do a dry run to check if there is any overlap.
if not self._has_shard_overlaps(loaded_model):
# No overlap found, skip this file.
continue

# TODO: Lazy loading?
with safetensors.safe_open(path, framework="pt", device=str(self._model.distributed.device)) as f:
# TODO: Use self_shard
if "state_shard" in f.keys():
@@ -111,22 +117,36 @@ def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
shard_name: f.get_tensor(f"{shard_name}_shard") for shard_name in shard_names
}

for shard_name, loaded_shard in loaded_shards.items():
loaded_model.get_shard_meta(shard_name).validate(loaded_shard)

self_shards = {shard_name: self._model.get_shard(shard_name) for shard_name in shard_names}

counter = torch.zeros(1, dtype=torch.int64, device=self._model.distributed.device)
for _, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards):
for _, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards):
self_fsdp.copy_shard_overlaps(
loaded_fsdp,
self_fsdp_shards,
loaded_fsdp_shards,
counter,
self._model.distributed.device,
)

context.mark_as_loaded(counter.item())
self._copy_shard_overlaps(loaded_model, loaded_shards, context)

return loaded_metadata.metadata

def _has_shard_overlaps(self, loaded_model) -> bool:
for _, loaded_fsdp, _ in loaded_model.split_shards_by_fsdp({}):
for _, self_fsdp, _ in self._model.split_shards_by_fsdp({}):
counter = self_fsdp.copy_shard_overlaps(
loaded_fsdp,
None,
None,
self._model.distributed.device,
)
if counter:
return True
return False

def _copy_shard_overlaps(self, loaded_model, loaded_shards, context):
for shard_name, loaded_shard in loaded_shards.items():
loaded_model.get_shard_meta(shard_name).validate(loaded_shard)

self_shards = {shard_name: self._model.get_shard(shard_name) for shard_name in loaded_shards}

for _, loaded_fsdp, loaded_fsdp_shards in loaded_model.split_shards_by_fsdp(loaded_shards):
for _, self_fsdp, self_fsdp_shards in self._model.split_shards_by_fsdp(self_shards):
counter = self_fsdp.copy_shard_overlaps(
loaded_fsdp,
self_fsdp_shards,
loaded_fsdp_shards,
self._model.distributed.device,
)
for parameter, count in counter.items():
context.mark_as_loaded(count, parameter, True)
104 changes: 54 additions & 50 deletions fast_llm/engine/checkpoint/safe_load.py
@@ -5,9 +5,9 @@
from torch.distributed import all_reduce

from fast_llm.core.distributed import add_ephemeral_timeout
from fast_llm.engine.multi_stage.config import ShardName
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.functional.triton.pointwise import triton_fill
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)

@@ -48,14 +48,17 @@ def __exit__(self, exc_type, exc_val, exc_tb):
if not exc_type:
self._validate()

def mark_as_loaded(self, count: int, parameter: tuple[str, str] | None = None) -> None:
def mark_as_loaded(self, count: int, parameter: tuple[str, str] | None = None, partial: bool = False) -> None:
self._loaded += count
if parameter is not None:
parameter_name, shard_name = parameter
if shard_name not in self._loaded_parameters:
self._loaded_parameters[shard_name] = {}
Assert.not_incl(parameter_name, self._loaded_parameters[shard_name])
self._loaded_parameters[shard_name][parameter_name] = count
if not partial and parameter_name in self._loaded_parameters[shard_name]:
raise ValueError(f"Duplicate loaded parameter ({parameter_name}, {shard_name})")
self._loaded_parameters[shard_name][parameter_name] = (
self._loaded_parameters[shard_name].get(parameter_name, 0) + count
)

def _validate(self) -> None:
errors = []
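The new `partial` flag changes the accounting semantics: a parameter may now be reported once per checkpoint file, with counts accumulating, while a repeated report without `partial` remains an error. A hypothetical stand-alone version of that bookkeeping (not the actual `SafeLoad` class):

```python
loaded: dict[str, dict[str, int]] = {}

def mark_as_loaded(shard_name: str, parameter_name: str, count: int, partial: bool = False) -> None:
    shard_counts = loaded.setdefault(shard_name, {})
    # A non-partial duplicate means the same parameter was loaded twice: an error.
    if not partial and parameter_name in shard_counts:
        raise ValueError(f"Duplicate loaded parameter ({parameter_name}, {shard_name})")
    # Partial loads accumulate, e.g. one contribution per checkpoint file.
    shard_counts[parameter_name] = shard_counts.get(parameter_name, 0) + count

# Two files each contribute part of the same parameter.
mark_as_loaded("weights", "layer_0.weight", 512, partial=True)
mark_as_loaded("weights", "layer_0.weight", 256, partial=True)
# loaded["weights"]["layer_0.weight"] == 768
```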
Expand Down Expand Up @@ -105,7 +108,7 @@ def _check_missing(self, errors: list[str]) -> None:
f"{missing_for_param:,} values missing out of {parameter.numel():,} for parameter {parameter_name} in stage {stage.index}, shard {shard_name}"
f" (locally {local_missing_for_param:,} out of {local_values.numel():,})"
)
missing_for_pad = buffer[-fsdp._global_pad :].isnan().sum().item()
missing_for_pad = buffer[-fsdp._global_pad :].isnan().sum().item() if fsdp._global_pad > 0 else 0
if missing_for_pad > 0:
global_total += missing_for_pad
local_missing_for_pad = (
@@ -127,52 +130,53 @@
)

def _check_parameters(self, errors: list[str]) -> None:
loaded_shard_names = set(self._loaded_parameters)
shard_names = set(self._self_shards)
if loaded_shard_names != shard_names:
errors.append(f"Incorrect loaded shards: {loaded_shard_names}!={shard_names}")
for shard_name in shard_names & loaded_shard_names:
counter_per_parameter = {
parameter_name: self._loaded_parameters[shard_name].pop(parameter_name, None)
for parameter_name in self._model.parameter_names
}
for parameter_name, count in self._loaded_parameters[shard_name].items():
errors.append(f'Loaded unknown parameter "{parameter_name}" for shard "{shard_name}" (count={count})')
for parameter_name, counter in counter_per_parameter.items():
if self._model.is_parameter_on_device(parameter_name):
if counter is None:
errors.append(f'Missing parameter "{parameter_name}" for shard "{shard_name}"')
elif counter is not None and counter > 0:
errors.append(f'Loaded off-device parameter : "{parameter_name}" for shard "{shard_name}"')
if self._distributed.world_group is not None:
counter_list = []
for parameter_name, counter in counter_per_parameter.items():
parameter_stage = self._model.get_parameter_stage(parameter_name)
parameter_meta = parameter_stage.get_parameter_meta(parameter_name)
if (
counter is None
or (not parameter_meta.is_tensor_parallel and self._distributed.config.tensor_rank != 0)
or parameter_stage.is_tied_weight_copy
):
# Ignore the counter from missing or duplicate tensors.
counter = 0
counter_list.append(counter)

counter_tensor = torch.tensor(counter_list, dtype=torch.int64).to(self._distributed.device)

add_ephemeral_timeout(self._distributed.world_group, self._timeout)
all_reduce(counter_tensor, group=self._distributed.world_group)
counter_per_parameter = {
parameter_name: counter
for parameter_name, counter in zip(counter_per_parameter, counter_tensor.tolist())
}
for parameter_name, counter in counter_per_parameter.items():
parameter_size = (
self._model.get_parameter_stage(parameter_name)
.get_parameter_meta(parameter_name)
.global_shape.numel()
if set(self._loaded_parameters) != set(self._self_shards):
errors.append(f"Incorrect loaded shards: {tuple(self._loaded_parameters)}!={tuple(self._self_shards)}")

counters = []
# Compare local counts against expected values.
for stage, fsdp, parameter_name, parameter_meta in self._model.stages_fsdp_parameters:
for shard_name in self._self_shards if fsdp.requires_grad else [ShardName.weights]:
counter = self._loaded_parameters[shard_name].pop(parameter_meta.tensor_name, 0)
local_size = (
fsdp.get_parameter_size_in_shard(parameter_name, shard_name)
if self._model.is_parameter_on_device(parameter_name)
else 0
)
if counter != local_size:
errors.append(
f'Local counter mismatch for parameter "{parameter_name}"'
f' and shard "{shard_name}": loaded {counter}, expected {local_size}'
)
# Accumulate in a list for global counter check.
if (
not parameter_meta.is_tensor_parallel and self._distributed.config.tensor_rank != 0
) or stage.is_tied_weight_copy:
# Ignore the counter from duplicate tensors.
counter = 0
counters.append(counter)

# Check for unexpected parameters.
for shard_name, loaded in self._loaded_parameters.items():
for parameter_name, count in loaded.items():
errors.append(f'Loaded unknown parameter "{parameter_name}" for shard "{shard_name}" (count={count})')

# All-reduce to get global counts.
if self._distributed.world_group is not None:
counter_tensor = torch.tensor(counters, dtype=torch.int64).to(self._distributed.device)
# This may be the first distributed barrier after loading, so we need to wait for everyone to finish.
add_ephemeral_timeout(self._distributed.world_group, self._timeout)
all_reduce(counter_tensor, group=self._distributed.world_group)
counters = counter_tensor.tolist()

# Compare global counts against expected values.
for stage, fsdp, parameter_name, parameter_meta in self._model.stages_fsdp_parameters:
for shard_name in self._self_shards if fsdp.requires_grad else [ShardName.weights]:
counter = counters.pop(0)
parameter_size = parameter_meta.global_shape.numel()
if counter != parameter_size:
errors.append(
f'Global counter mismatch for parameter "{parameter_name}" and shard "{shard_name}": {counter} != {parameter_size}'
f'Global counter mismatch for parameter "{parameter_name}"'
f' and shard "{shard_name}": loaded {counter}, expected {parameter_size}'
)
assert not counters
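The global counter check above sums each parameter's per-rank load counts and compares the total against the parameter's global size. A sketch of that logic with `torch.distributed.all_reduce` replaced by an explicit element-wise sum over simulated ranks (names are illustrative, not Fast-LLM's):

```python
def check_global_counts(per_rank_counters: list[list[int]], expected_sizes: list[int]) -> list[str]:
    # "All-reduce" step: sum each parameter's counter across all ranks.
    totals = [sum(rank[i] for rank in per_rank_counters) for i in range(len(expected_sizes))]
    errors = []
    for index, (counter, expected) in enumerate(zip(totals, expected_sizes)):
        if counter != expected:
            errors.append(f"parameter {index}: loaded {counter}, expected {expected}")
    return errors

# Two ranks each load half of a 1024-element parameter: the check passes.
assert check_global_counts([[512], [512]], [1024]) == []
# One rank fails to load its half: the mismatch is reported.
assert len(check_global_counts([[512], [0]], [1024])) == 1
```

Zeroing the counters of tensor-parallel duplicates and tied-weight copies before the reduction, as the diff does, keeps duplicated tensors from being double-counted in these totals.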
5 changes: 4 additions & 1 deletion fast_llm/engine/config_utils/tensor_space.py
@@ -147,7 +147,10 @@ def add_tensor_dim(self, dim: TensorDim) -> None:
else:
if dim.parallel_dim is not None:
assert dim.parallel_dim.name in self._distributed_config.distributed_dims, dim.parallel_dim.name
Assert.eq(dim.parallel_dim, self._distributed_config.distributed_dims[dim.parallel_dim.name])
Assert.eq(
dim.parallel_dim.__dict__,
self._distributed_config.distributed_dims[dim.parallel_dim.name].__dict__,
)
self._tensor_dims[dim.name] = dim

def get_tensor_dim(self, name: str) -> TensorDim:
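The change above compares `__dict__` rather than the objects themselves: two separately constructed dim descriptors can match field-by-field yet fail an identity-based equality check. A minimal demonstration of the failure mode, with `DistributedDim` as a hypothetical stand-in for the real class:

```python
class DistributedDim:
    # No __eq__ defined, so `==` falls back to object identity.
    def __init__(self, name: str, size: int, rank: int):
        self.name, self.size, self.rank = name, size, rank

a = DistributedDim("tensor", 4, 0)
b = DistributedDim("tensor", 4, 0)

assert a != b                      # identity comparison: distinct objects
assert a.__dict__ == b.__dict__    # value comparison: same fields
```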