
[Prototype] Generalize dynamic config classes #245


Draft

jlamypoirier wants to merge 39 commits into main from generalize_dynamic_classes
Commits (39)
4b606b0
Generalize config classes
jlamypoirier Apr 30, 2025
4a67660
cli
jlamypoirier Apr 30, 2025
531f67d
Merge branch 'main' into generalize_dynamic_classes
jlamypoirier May 2, 2025
1823407
misc
jlamypoirier May 5, 2025
fe7acd9
stuff
jlamypoirier May 7, 2025
bee7a4b
Merge remote-tracking branch 'origin/main' into generalize_dynamic_cl…
jlamypoirier May 7, 2025
d41be60
stuff
jlamypoirier May 7, 2025
ec35a50
fixes
jlamypoirier May 8, 2025
3005c8c
stuff
jlamypoirier May 9, 2025
5735d21
fix
jlamypoirier May 9, 2025
a7e7362
Merge remote-tracking branch 'origin/main' into generalize_dynamic_cl…
jlamypoirier May 9, 2025
6357365
stuff
jlamypoirier May 9, 2025
207aef0
stuff
jlamypoirier May 12, 2025
31579bd
Bring back default_factory
jlamypoirier May 13, 2025
f79ed27
fix
jlamypoirier May 13, 2025
0a37209
Revert "fix"
jlamypoirier May 13, 2025
897cc0f
Revert "Bring back default_factory"
jlamypoirier May 13, 2025
8a49e0f
stuff
jlamypoirier May 14, 2025
843a621
fix
jlamypoirier May 14, 2025
aa3bc0b
stuff
jlamypoirier May 14, 2025
28d321e
stuff
jlamypoirier May 14, 2025
1bbd7fb
stuff
jlamypoirier May 14, 2025
a426450
Merge branch 'misc' into generalize_dynamic_classes
jlamypoirier May 14, 2025
87e11f0
stuff
jlamypoirier May 14, 2025
60a656e
stuff
jlamypoirier May 14, 2025
3595949
Minimalistic dynamic configs
jlamypoirier May 14, 2025
39b1a04
stuff
jlamypoirier May 15, 2025
f29b3fc
Merge branch 'minimalistic_dynamic_classes' into generalize_dynamic_c…
jlamypoirier May 15, 2025
743edaa
Simplify cli
jlamypoirier May 15, 2025
038106f
Simplify cli
jlamypoirier May 15, 2025
d0e86cb
Merge remote-tracking branch 'origin/main' into misc
jlamypoirier May 21, 2025
3fa314c
Merge branch 'misc' into minimalistic_dynamic_classes
jlamypoirier May 21, 2025
7bb6ee9
Merge remote-tracking branch 'origin/main' into minimalistic_dynamic_…
jlamypoirier May 21, 2025
a98a2ae
Merge branch 'minimalistic_dynamic_classes' into simplify_cli
jlamypoirier May 27, 2025
85e02b8
Merge branch 'main' into simplify_cli
jlamypoirier May 27, 2025
0fd49e4
Merge remote-tracking branch 'origin/main' into simplify_cli
jlamypoirier May 27, 2025
0a010f3
Merge branch 'main' into generalize_dynamic_classes
jlamypoirier May 27, 2025
c415e32
Merge branch 'simplify_cli' into generalize_dynamic_classes
jlamypoirier May 27, 2025
e199d0a
fix
jlamypoirier May 27, 2025
35 changes: 35 additions & 0 deletions fast_llm/cli.py
@@ -0,0 +1,35 @@
import logging
import sys
import traceback

from fast_llm.config import ValidationError
from fast_llm.engine.config_utils.logging import configure_logging
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.engine.config_utils.runnable import RunnableConfig

# Import these submodules to ensure classes are added to the dynamic class registry.
import fast_llm.data.auto # isort: skip
import fast_llm.engine.checkpoint.convert # isort: skip
import fast_llm.models.auto # isort: skip

logger = logging.getLogger(__name__)


def fast_llm_main(args: list[str] | None = None):
    # TODO: Add hook to register model classes? (environment variable?)
    # (Pre-)configure logging
    configure_logging()
    try:
        RunnableConfig.parse_and_run(args)
    except Exception as e:
        if sys.gettrace():
            raise
        if isinstance(e, ValidationError):
            log_main_rank(traceback.format_exc(), log_fn=logger.error)
        else:
            logger.critical(traceback.format_exc())
        sys.exit(1)


if __name__ == "__main__":
    fast_llm_main()
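
For context, a minimal sketch of invoking this entry point programmatically; the "train gpt" subcommand and --config flag are assumptions, not shown in this diff:

from fast_llm.cli import fast_llm_main

# Same as running the module from a shell; parse_and_run resolves the runnable
# class from the dynamic class registry populated by the imports above.
fast_llm_main(["train", "gpt", "--config", "my_config.yaml"])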
47 changes: 25 additions & 22 deletions fast_llm/config.py
@@ -16,7 +16,6 @@

logger = logging.getLogger(__name__)


_AUTO_VALIDATE = True

MISSING = Tag("<MISSING>")
@@ -245,7 +244,7 @@ def _process_config_class(cls: type["Config"]):

def config_class[
T: Config
](registry: bool = False, dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]:
](dynamic_type: "dict[type[Config], str]|None" = None) -> typing.Callable[[type[T]], type[T]]:
"""
Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
"""
@@ -270,11 +269,8 @@ def __init__(self, **kwargs):

wrapped.__init__ = __init__

wrapped._registry = Registry[str, type[wrapped]](wrapped.__name__, {}) if registry else None

if dynamic_type is not None:
for cls_, name in dynamic_type.items():
print(cls_, name, wrapped)
cls_.register_subclass(name, wrapped)

return wrapped
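
For illustration, a sketch of how the reworked decorator is meant to be used, following the registration loop above; SamplingConfig, DefaultSamplingConfig, and the name "default" are hypothetical:

from fast_llm.config import Config, config_class


@config_class()
class SamplingConfig(Config):
    pass


# Equivalent to calling SamplingConfig.register_subclass("default", DefaultSamplingConfig)
# after class creation, so a {"type": "default"} config dict selects this class.
@config_class(dynamic_type={SamplingConfig: "default"})
class DefaultSamplingConfig(SamplingConfig):
    pass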
@@ -316,7 +312,7 @@ class Config(metaclass=ConfigMeta):
_setting_implicit_default: bool | None = Field(init=False)

# A registry for all the config classes.
_registry: typing.ClassVar[Registry[str, type[typing.Self]] | None] = None
_registry: typing.ClassVar[Registry[str, type[typing.Self]]] = Registry[str, "type[Config]"]("Config", {})

def __setattr__(self, key: str, value: typing.Any) -> None:
"""
@@ -371,17 +367,6 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T:
Validate a class and mark it as read-only
This should not be overridden in derived classes.
"""
# Should be handled in `from_dict`, but can fail if instantiating directly.
try:
expected_class = self.get_subclass(self.type)
except KeyError as e:
# Delayed instantiation error in `from_dict`.
raise ValidationError(*e.args)

if expected_class is not None:
# Should be handled in `from_dict`, but can fail if instantiating directly.
Assert.is_(self.__class__, expected_class)

if not self._validated:
try:
self._validate()
@@ -401,6 +386,17 @@ def _validate(self) -> None:
Can be extended to add custom post-processing (typically before the super() call)
and validation (typically after)
"""
# Should be handled in `from_dict`, but can fail if instantiating directly.
try:
expected_class = self.get_subclass(self.type)
except KeyError as e:
# Delayed instantiation error in `from_dict`.
raise ValidationError(*e.args)

if expected_class is not None:
# Should be handled in `from_dict`, but can fail if instantiating directly.
Assert.is_(self.__class__, expected_class)
A review comment on this hunk:

oleksost (Contributor), May 27, 2025:
@jlamypoirier shouldn't this be something like Assert.custom(issubclass, expected_class, self.__class__) instead?
I.e. the dynamically inferred class (expected_class) should be a subclass of self.__class__; it should not necessarily match the class exactly?

jlamypoirier (Collaborator, Author):
The value is overridden in from_dict. Anyway, this PR isn't too relevant anymore; I broke it into pieces.

if self._abstract:
raise ValidationError(f"{type(self).__name__} is abstract")
if not self.__class_validated__:
@@ -409,6 +405,8 @@ def _validate(self) -> None:
)
errors = []
with self._set_implicit_default(None):
# Set the type field, or override it to the provided type with the actual class for clarity and safety.
self.type = self.__class__.__name__
for name, field in self.fields():
if not field.init or field._field_type != dataclasses._FIELD: # noqa
continue
@@ -486,6 +484,7 @@ def _validate_element(cls, value, type_, name: str):
raise FieldTypeError(f"Not a type.")
elif issubclass(type_, Config):
cls._validate_element_type(value, type_, strict=False)

value.validate(_is_validating=True)
else:
value = cls._validate_simple(value, type_)
@@ -737,7 +736,7 @@ def from_dict(
for keys, value in update.items():
set_nested_dict_value(default, keys, value, update_type)

return cls._from_dict(default, strict)
return cls._from_dict(default, strict=strict)

@classmethod
def from_flat_dict(
@@ -899,8 +898,6 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ
@classmethod
def register_subclass(cls, name: str, cls_: type[typing.Self]) -> None:
Assert.custom(issubclass, cls_, cls)
if cls._registry is None:
raise NotImplementedError(f"Subclass `{cls.__name__}` doesn't have a registry..")
if name in cls._registry:
old_cls = cls._registry[name]
if old_cls.__name__ == cls_.__name__ and cls._registry[name].__module__ == cls_.__module__:
@@ -916,7 +913,7 @@ def get_subclass(cls, name: str | None):
return None
cls_ = None
for base_class in cls.__mro__:
if issubclass(base_class, Config) and base_class._registry is not None and name in base_class._registry:
if issubclass(base_class, Config) and name in base_class._registry:
if cls_ is None:
cls_ = base_class._registry[name]
if not issubclass(cls_, cls):
@@ -937,6 +934,12 @@ def __init_subclass__(cls):
We need to postpone validation until the class has been processed by the dataclass wrapper.
"""
Assert.eq(cls.__name__, cls.__qualname__)
cls._registry = Registry[str, type[cls]](cls.__name__, {})
if not cls._abstract:
Config.register_subclass(cls.__name__, cls)
short_name = cls.__name__.strip("Config")
if short_name != cls.__name__:
Config.register_subclass(short_name, cls)
for base_class in cls.__mro__:
if issubclass(base_class, Config) and base_class is not cls:
assert cls.__class_validated__, (
@@ -982,7 +985,7 @@ def __init_subclass__(cls):
cls.__annotations__[name] = base_class_field.type

# Type for the field. At the end of class definition to avoid shadowing builtin.
type: str | None = Field(
type: str = Field(
default=None,
desc="The config class name.",
hint=FieldHint.feature,
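Taken together, a sketch of the round trip this diff's __init_subclass__ logic enables: concrete subclasses self-register under their class name (plus a shortened alias), and the type field resolves through get_subclass. OptimizerConfig, AdamConfig, and the field below are hypothetical:

from fast_llm.config import Config, Field, FieldHint, config_class


@config_class()
class OptimizerConfig(Config):
    pass


@config_class()
class AdamConfig(OptimizerConfig):
    learning_rate: float = Field(default=1e-3, desc="Hypothetical field.", hint=FieldHint.feature)


# Lookup by class name works through the shared registry; construction from a
# dict with a "type" key resolves to the registered subclass.
assert OptimizerConfig.get_subclass("AdamConfig") is AdamConfig
config = OptimizerConfig.from_dict({"type": "AdamConfig", "learning_rate": 3e-4})
assert type(config) is AdamConfig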
16 changes: 4 additions & 12 deletions fast_llm/data/auto.py
@@ -1,13 +1,5 @@
from fast_llm.data.preparator.config import DatasetPreparatorConfig
from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig
from fast_llm.utils import Registry
"""
Import these submodules to ensure classes are added to the dynamic class registry.
"""

dataset_preparator_registry = Registry[str, DatasetPreparatorConfig](
"DatasetPreparator",
{
dataset_preparator.preparator_name: dataset_preparator
for dataset_preparator in [
GPTMemmapDatasetPreparatorConfig,
]
},
)
from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig # isort: skip
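
A sketch of the lookup that replaces the removed dataset_preparator_registry, assuming the automatic class-name registration added in __init_subclass__ above:

# Importing the auto module registers the classes as a side effect; afterwards
# the preparator config can be resolved from the shared registry by class name.
import fast_llm.data.auto  # noqa: F401

from fast_llm.data.preparator.config import DatasetPreparatorConfig

preparator_cls = DatasetPreparatorConfig.get_subclass("GPTMemmapDatasetPreparatorConfig")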
14 changes: 1 addition & 13 deletions fast_llm/data/data/gpt/data.py
@@ -32,29 +32,18 @@ class GPTBatch:
token_ids: torch.Tensor
loss_masking_spans: list[torch.Tensor] | None = None
sequence_lengths: list[torch.Tensor] | None = None
chosen_spans: list[torch.Tensor] | None = None
rejected_spans: list[torch.Tensor] | None = None


def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:
stacked_ids = np.stack([sample.token_ids for sample in batch])
stacked_spans = None
sequence_lengths = None
stacked_chosen_spans = None
stacked_rejected_spans = None
if sampling_parameters.use_loss_masking_spans:
stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]
if sampling_parameters.use_preference_loss_spans:
stacked_chosen_spans = [torch.from_numpy(sample.chosen_span) for sample in batch]
stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch]
if not sampling_parameters.cross_document_attention:
sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
return GPTBatch(
token_ids=torch.from_numpy(stacked_ids),
loss_masking_spans=stacked_spans,
sequence_lengths=sequence_lengths,
chosen_spans=stacked_chosen_spans,
rejected_spans=stacked_rejected_spans,
token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths
)


@@ -160,7 +149,6 @@ def get_iterator(
sampling_parameters = self._sampling_parameters[dataset_name]
Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length)
log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...")

return iter(
torch.utils.data.DataLoader(
self._datasets[dataset_name], # noqa
3 changes: 1 addition & 2 deletions fast_llm/data/dataset/gpt/config.py
@@ -73,7 +73,6 @@ class GPTSamplingParameters(SamplingParameters):
sequence_length: int
vocab_size: int
use_loss_masking_spans: bool = False
use_preference_loss_spans: bool = False
cross_document_attention: bool = True
# How many extra tokens to add to the sequence length.
# This is used to provide labels even for the last tokens in the sequence.
@@ -93,7 +92,7 @@ class GPTSamplingData(SamplingData):
truncate_documents: bool = True


@config_class(registry=True)
@config_class()
class GPTSampledDatasetConfig(SampledDatasetConfig):
pass
