Simplify cli #269

Merged
merged 29 commits into from
Jun 12, 2025

Changes from all commits (29 commits)
aa3bc0b  stuff (jlamypoirier, May 14, 2025)
28d321e  stuff (jlamypoirier, May 14, 2025)
1bbd7fb  stuff (jlamypoirier, May 14, 2025)
3595949  Minimalistic dynamic configs (jlamypoirier, May 14, 2025)
39b1a04  stuff (jlamypoirier, May 15, 2025)
743edaa  Simplify cli (jlamypoirier, May 15, 2025)
038106f  Simplify cli (jlamypoirier, May 15, 2025)
d0e86cb  Merge remote-tracking branch 'origin/main' into misc (jlamypoirier, May 21, 2025)
3fa314c  Merge branch 'misc' into minimalistic_dynamic_classes (jlamypoirier, May 21, 2025)
7bb6ee9  Merge remote-tracking branch 'origin/main' into minimalistic_dynamic_… (jlamypoirier, May 21, 2025)
a98a2ae  Merge branch 'minimalistic_dynamic_classes' into simplify_cli (jlamypoirier, May 27, 2025)
85e02b8  Merge branch 'main' into simplify_cli (jlamypoirier, May 27, 2025)
0fd49e4  Merge remote-tracking branch 'origin/main' into simplify_cli (jlamypoirier, May 27, 2025)
8ce8674  Dynamic transformer (jlamypoirier, May 27, 2025)
5513e48  fixes (jlamypoirier, May 27, 2025)
9a9578a  Merge remote-tracking branch 'origin/main' into simplify_cli (jlamypoirier, Jun 4, 2025)
6a3e48f  fixes (jlamypoirier, Jun 4, 2025)
3971464  Merge remote-tracking branch 'origin/main' into dynamic_transformer (jlamypoirier, Jun 4, 2025)
757f115  Merge branch 'dynamic_transformer' into simplify_cli (jlamypoirier, Jun 4, 2025)
81e53f8  fixes (jlamypoirier, Jun 4, 2025)
9eb745c  Merge remote-tracking branch 'origin/main' into dynamic_transformer (jlamypoirier, Jun 12, 2025)
52c109a  Merge branch 'dynamic_transformer' into simplify_cli (jlamypoirier, Jun 12, 2025)
62abf27  fixes (jlamypoirier, Jun 12, 2025)
619a0da  Update fast_llm/layers/transformer/rotary/rotary.py (jlamypoirier, Jun 12, 2025)
912431c  fix (jlamypoirier, Jun 12, 2025)
3c8cf0b  Merge remote-tracking branch 'origin/main' into dynamic_transformer (jlamypoirier, Jun 12, 2025)
e3a7c55  Merge branch 'dynamic_transformer' into simplify_cli (jlamypoirier, Jun 12, 2025)
de1dc4e  fix (jlamypoirier, Jun 12, 2025)
ec53911  Merge remote-tracking branch 'origin/main' into simplify_cli (jlamypoirier, Jun 12, 2025)
35 changes: 35 additions & 0 deletions fast_llm/cli.py
@@ -0,0 +1,35 @@
+import logging
+import sys
+import traceback
+
+from fast_llm.config import ValidationError
+from fast_llm.engine.config_utils.logging import configure_logging
+from fast_llm.engine.config_utils.run import log_main_rank
+from fast_llm.engine.config_utils.runnable import RunnableConfig
+
+# Import these submodules to ensure classes are added to the dynamic class registry.
+import fast_llm.data.auto  # isort: skip
+import fast_llm.engine.checkpoint.convert  # isort: skip
+import fast_llm.models.auto  # isort: skip
+
+logger = logging.getLogger(__name__)
+
+
+def fast_llm_main(args: list[str] | None = None):
+    # TODO: Add hook to register model classes? (environment variable?)
+    # (Pre-)configure logging
+    configure_logging()
+    try:
+        RunnableConfig.parse_and_run(args)
+    except Exception as e:
+        if sys.gettrace():
+            raise
+        if isinstance(e, ValidationError):
+            log_main_rank(traceback.format_exc(), log_fn=logger.error)
+        else:
+            logger.critical(traceback.format_exc())
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    fast_llm_main()
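
For reference, a minimal usage sketch of the new entry point (not part of the diff). The chained positional form is resolved by RunnableConfig.parse_and_run further down; the override key shown is invented for illustration.

from fast_llm.cli import fast_llm_main

# Equivalent to a shell invocation such as `python -m fast_llm.cli train gpt ...`:
# "train" selects TrainerConfig, "gpt" selects the GPT trainer subclass, and any
# remaining `key=value` arguments are config overrides (the key below is made up).
fast_llm_main(["train", "gpt", "run.experiment_dir=/tmp/example_run"])
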
28 changes: 24 additions & 4 deletions fast_llm/config.py
@@ -250,7 +250,7 @@ def config_class[
     Fast-LLM replacement for the default dataclass wrapper. Performs additional verifications.
     """
 
-    def wrap(cls):
+    def wrap(cls: type[T]) -> type[T]:
         Assert.custom(issubclass, cls, Config)
         if hasattr(cls, "__post_init__"):
             raise TypeError(f"`__post_init__` should not be implemented for `Config` classes")
@@ -274,6 +274,8 @@ def __init__(self, **kwargs):
 
         if dynamic_type is not None:
             for cls_, name in dynamic_type.items():
+                if cls.dynamic_type_name is None:
+                    cls.dynamic_type_name = name
                 cls_.register_subclass(name, wrapped)
 
         return wrapped
@@ -317,6 +319,13 @@ class Config(metaclass=ConfigMeta):
     # A registry for all the config classes.
     _registry: typing.ClassVar[Registry[str, type[typing.Self]] | None] = None
 
+    # The main dynamic type name used to refer to this class.
+    # If a class is known under multiple names,
+    # the `type` field will be standardized to this value to prevent ambiguities.
+    # TODO: Force a unique value across all classes to prevent ambiguities when loading saved configs?
+    # Set through the `@config_class` decorator, as the first entry in `dynamic_type`. DO NOT SET EXPLICITLY.
+    dynamic_type_name: typing.ClassVar[str | None] = None
+
     def __setattr__(self, key: str, value: typing.Any) -> None:
         """
         Make the class read-only after validation.
@@ -370,16 +379,18 @@ def validate[T: Config](self: T, *, _is_validating: bool = False) -> T:
         Validate a class and mark it as read-only
         This should not be overridden in derived classes.
         """
         # Should be handled in `from_dict`, but can fail if instantiating directly.
         try:
             expected_class = self.get_subclass(self.type)
         except KeyError as e:
             # Delayed instantiation error in `from_dict`.
             raise ValidationError(*e.args)
 
         if expected_class is not None:
-            # Should be handled in `from_dict`, but can fail if instantiating directly.
+            # Handled in `from_dict`, checking again for extra safety.
             Assert.is_(self.__class__, expected_class)
+        if self.dynamic_type_name is not None:
+            # This also makes the type explicit.
+            # Done during validation so we don't accidentally use default subtypes as updates.
+            self.type = self.dynamic_type_name
 
         if not self._validated:
             try:
@@ -406,6 +417,7 @@ def _validate(self) -> None:
             raise ValidationError(
                 f"{type(self).__name__} hasn't been validated. Make sure to use the @config_class decorator."
             )
+
         errors = []
         with self._set_implicit_default(None):
             for name, field in self.fields():
@@ -730,6 +742,10 @@ def from_dict(
             default = copy.deepcopy(default)
         for update in updates:
             if isinstance(update, Config):
+                if update._validated:
+                    # Try to prevent issues where fields are set and made explicit during validation, ex. `type`.
+                    # If this is intentional (and safe), the serialized config can be used as an argument instead.
+                    raise ValueError(f"Validated configs should not be used as update.")
                 update = update.to_dict(serialized=False)
             else:
                 update = copy.deepcopy(update)
@@ -942,7 +958,11 @@ def __init_subclass__(cls):
                     f"Parent class {get_type_name(base_class)} of config class {get_type_name(cls)} has not been validated."
                     f" Make sure to use the @config_class decorator."
                 )
+
+        # Remove class vars set in parent class.
         cls.__class_validated__ = False
+        cls.dynamic_type_name = None
+
         for name in list(cls.__dict__):
             value = getattr(cls, name)
             if isinstance(value, FieldUpdate):
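
A small sketch (not from the diff) of what the new decorator arguments and `dynamic_type_name` do, using invented class and type names; it assumes `from_dict` resolves the registered subclass from the `type` field, as the validation code above implies.

from fast_llm.config import Config, config_class


@config_class(registry=True)
class ExampleBaseConfig(Config):
    # Hypothetical base class that owns a subclass registry.
    pass


@config_class(dynamic_type={ExampleBaseConfig: "foo"})
class FooConfig(ExampleBaseConfig):
    # Registered under "foo"; the first `dynamic_type` entry also becomes `dynamic_type_name`.
    pass


config = ExampleBaseConfig.from_dict({"type": "foo"})
config.validate()
assert isinstance(config, FooConfig)
assert config.type == "foo"  # standardized during validation
# Note: once validated, `config` may no longer be passed as an update to `from_dict`;
# pass `config.to_dict()` instead (see the new check in `from_dict` above).
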
16 changes: 4 additions & 12 deletions fast_llm/data/auto.py
@@ -1,13 +1,5 @@
-from fast_llm.data.preparator.config import DatasetPreparatorConfig
-from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig
-from fast_llm.utils import Registry
+"""
+Import these submodules to ensure classes are added to the dynamic class registry.
+"""
 
-dataset_preparator_registry = Registry[str, DatasetPreparatorConfig](
-    "DatasetPreparator",
-    {
-        dataset_preparator.preparator_name: dataset_preparator
-        for dataset_preparator in [
-            GPTMemmapDatasetPreparatorConfig,
-        ]
-    },
-)
+from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig  # isort: skip
2 changes: 1 addition & 1 deletion fast_llm/data/dataset/config.py
@@ -55,7 +55,7 @@ class SamplingData:
 
     def update_config(self, update: SamplingConfig):
         return dataclasses.replace(
-            self, config=self.config.from_dict(self.config, update, update_type=UpdateType.update)
+            self, config=self.config.from_dict(self.config, update.to_dict(), update_type=UpdateType.update)
         )
 
     def get_next_rank(self) -> int:
2 changes: 1 addition & 1 deletion fast_llm/data/preparator/config.py
@@ -5,7 +5,7 @@
 from fast_llm.engine.config_utils.runnable import RunnableConfig
 
 
-@config_class()
+@config_class(registry=True, dynamic_type={RunnableConfig: "prepare"})
 class DatasetPreparatorConfig(RunnableConfig):
     preparator_name: typing.ClassVar[str]
 
3 changes: 2 additions & 1 deletion fast_llm/data/preparator/gpt_memmap/config.py
@@ -6,6 +6,7 @@
 from fast_llm.data.config import TokenizerConfig
 from fast_llm.data.preparator.config import DatasetPreparatorConfig
 from fast_llm.engine.config_utils.data_type import DataType
+from fast_llm.engine.config_utils.runnable import RunnableConfig
 from fast_llm.utils import Assert
 
 if typing.TYPE_CHECKING:
@@ -116,7 +117,7 @@ def _validate(self) -> None:
         Assert.in_range(self.rank, 0, self.world_size)
 
 
-@config_class()
+@config_class(dynamic_type={RunnableConfig: "prepare_gpt_memmap", DatasetPreparatorConfig: "gpt_memmap"})
 class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
     preparator_name: typing.ClassVar[str] = "gpt_memmap"
 
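
With these registrations the preparator is reachable by name from the generic runner. A small lookup sketch (not from the diff), assuming `fast_llm.data.auto` has been imported so the classes are registered:

import fast_llm.data.auto  # side-effect import: registers the preparator classes
from fast_llm.data.preparator.config import DatasetPreparatorConfig
from fast_llm.engine.config_utils.runnable import RunnableConfig

# The chained CLI form `prepare gpt_memmap` and the flat form `prepare_gpt_memmap`
# are expected to resolve to the same config class:
assert RunnableConfig.get_subclass("prepare") is DatasetPreparatorConfig
assert RunnableConfig.get_subclass("prepare_gpt_memmap") is DatasetPreparatorConfig.get_subclass("gpt_memmap")
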
fast_llm/engine/checkpoint/convert.py
@@ -9,7 +9,6 @@
 from fast_llm.engine.config_utils.runnable import RunnableConfig
 from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode
 from fast_llm.functional.config import TritonConfig
-from fast_llm.models.auto import model_registry
 from fast_llm.utils import Assert
 
 if typing.TYPE_CHECKING:
@@ -18,7 +17,7 @@
 logger = logging.getLogger(__name__)
 
 
-@config_class()
+@config_class(dynamic_type={RunnableConfig: "convert"})
 class ConvertConfig(RunnableConfig):
     input: CheckpointLoadConfig = Field()
     output: CheckpointSaveConfig = Field()
@@ -29,22 +28,29 @@ class ConvertConfig(RunnableConfig):
 
     @classmethod
     def _get_parser(cls):
+        # TODO: Infer the model type from the loaded model instead?
         parser = super()._get_parser()
         parser.add_argument(
             "model_type",
-            choices=model_registry.keys(),
-            help="The Fast-LLM model type to use. Must be defined in the model registry in `fast_llm.models.auto`.",
+            help="The Fast-LLM model type to use.",
         )
         return parser
 
     @classmethod
     def _from_parsed_args(cls, parsed: argparse.Namespace, unparsed: list[str]):
         config = super()._from_parsed_args(parsed, unparsed)
-        config.model = model_registry[parsed.model_type]
+        config.model = parsed.model_type
         return config
 
+    @classmethod
+    def _first_arg_is_dynamic_type(cls, args: list[str]) -> bool:
+        # The first arg defines the model type, not the converter class.
+        return False
+
     def _validate(self):
         assert self.model is not None
+        if isinstance(self.model, str):
+            self.model = FastLLMModelConfig.get_subclass(self.model)
         self.input.setup(self.model)
         self.output.setup(self.model)
         super()._validate()
@@ -158,7 +164,3 @@ def run(self):
         # All good!
         (self.output.path / "ok").open("w")
         logger.info(f">>> All done!")
-
-
-if __name__ == "__main__":
-    ConvertConfig.parse_and_run()
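
A hypothetical invocation sketch (not from the diff) showing why `_first_arg_is_dynamic_type` is overridden here: the token after `convert` names the model, not a converter subclass. The override keys and paths are invented for illustration.

from fast_llm.cli import fast_llm_main

# "convert" dispatches to ConvertConfig; "gpt" is then parsed as the positional
# `model_type` argument (not as a dynamic type, since the override returns False).
# The `key=value` overrides below are illustrative only.
fast_llm_main(["convert", "gpt", "input.path=/checkpoints/in", "output.path=/checkpoints/out"])
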
13 changes: 11 additions & 2 deletions fast_llm/engine/config_utils/runnable.py
@@ -16,10 +16,15 @@
 logger = logging.getLogger(__name__)
 
 
-@config_class()
+@config_class(registry=True)
 class RunnableConfig(Config):
     @classmethod
-    def parse_and_run(cls, args=None) -> None:
+    def parse_and_run(cls, args: list[str] | None = None) -> None:
+        if args is None:
+            args = sys.argv[1:]
+        if cls._first_arg_is_dynamic_type(args):
+            # Allow chained dynamic type selection without the `type=`, ex. `train gpt`.
+            return cls.get_subclass(args[0]).parse_and_run(args[1:])
         parsed, unparsed = cls._get_parser().parse_known_args(args)
         with NoAutoValidate():
             config: "RunnableConfig" = cls._from_parsed_args(parsed, unparsed)
@@ -35,6 +40,10 @@ def parse_and_run(cls, args=None) -> None:
         config._show(parsed.verbose)
         runnable()
 
+    @classmethod
+    def _first_arg_is_dynamic_type(cls, args: list[str]) -> bool:
+        return len(args) >= 1 and "=" not in args[0] and not args[0].startswith("-")
+
     @classmethod
     def _get_parser(cls) -> argparse.ArgumentParser:
         parser = argparse.ArgumentParser()
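
A few illustrative cases (not from the PR's tests) for the heuristic above, showing which first arguments trigger chained dispatch:

from fast_llm.engine.config_utils.runnable import RunnableConfig

assert RunnableConfig._first_arg_is_dynamic_type(["train", "gpt"])        # dynamic type name
assert not RunnableConfig._first_arg_is_dynamic_type(["run.steps=10"])    # `key=value` override
assert not RunnableConfig._first_arg_is_dynamic_type(["--help"])          # flag
assert not RunnableConfig._first_arg_is_dynamic_type([])                  # nothing to dispatch on
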
8 changes: 2 additions & 6 deletions fast_llm/engine/multi_stage/config.py
@@ -208,7 +208,7 @@ def _validate(self) -> None:
 SHARD_PAD_TO_MULTIPLE = 32
 
 
-@config_class()
+@config_class(registry=True)
 class FastLLMModelConfig(Config):
     _abstract = True
     checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = (
@@ -377,13 +377,9 @@ def _from_dict(
         if "fast_llm_version" not in default:
             default["fast_llm_version"] = "0"
 
-        # Determine the model config class.
-        from fast_llm.models.auto import model_registry
-
         model_config_class = default["model"]
         if isinstance(model_config_class, str):
-            Assert.incl(model_config_class, model_registry)
-            model_config_class = model_registry[model_config_class]
+            model_config_class = FastLLMModelConfig.get_subclass(default["model"])
         default["model"] = model_config_class
 
         # TODO v0.3: Remove backward compatibility.
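
A minimal sketch (not from the diff) of the registry-free lookup that replaces `model_registry`, assuming the model submodules have been imported so their classes are registered; "gpt" is the dynamic type name under which GPTModelConfig is expected to be registered.

import fast_llm.models.auto  # side-effect import: registers the model config classes
from fast_llm.engine.multi_stage.config import FastLLMModelConfig

model_config_class = FastLLMModelConfig.get_subclass("gpt")
print(model_config_class.__name__)  # expected: GPTModelConfig
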
3 changes: 2 additions & 1 deletion fast_llm/engine/training/config.py
@@ -23,6 +23,7 @@
     DistributedCheckpointFormat,
 )
 from fast_llm.engine.config_utils.run import ExperimentConfig
+from fast_llm.engine.config_utils.runnable import RunnableConfig
 from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig
 from fast_llm.engine.optimizer.config import OptimizerConfig
 from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig
@@ -333,7 +334,7 @@ def _validate(self) -> None:
         self.wandb.alert.assert_sub_interval(self.logs)
 
 
-@config_class()
+@config_class(registry=True, dynamic_type={RunnableConfig: "train"})
 class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig):
     _abstract = True
     # TODO: Generalize data, schedule, logging, etc.
1 change: 0 additions & 1 deletion fast_llm/layers/common/config.py
@@ -63,7 +63,6 @@ def _from_dict(
 class NoNormalizationConfig(NormalizationConfig):
     _abstract = False
 
-    @abc.abstractmethod
     def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module":
         return torch.nn.Identity()
 
35 changes: 6 additions & 29 deletions fast_llm/models/auto.py
@@ -1,30 +1,7 @@
-from fast_llm.engine.multi_stage.config import FastLLMModelConfig
-from fast_llm.engine.training.config import TrainerConfig
-from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig
-from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig
-from fast_llm.models.ssm.config import HybridSSMModelConfig, HybridTrainerConfig
-from fast_llm.utils import Registry
+"""
+Import these submodules to ensure classes are added to the dynamic class registry.
+"""
 
-model_registry = Registry[str, FastLLMModelConfig](
-    "Model",
-    {
-        model.model_name: model
-        for model in [
-            GPTModelConfig,
-            CustomModelConfig,
-            HybridSSMModelConfig,
-        ]
-    },
-)
-
-trainer_registry = Registry[str, TrainerConfig](
-    "Model",
-    {
-        trainer.get_field("model").type.model_name: trainer
-        for trainer in [
-            GPTTrainerConfig,
-            CustomTrainerConfig,
-            HybridTrainerConfig,
-        ]
-    },
-)
+from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig  # isort: skip
+from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig  # isort: skip
+from fast_llm.models.ssm.config import HybridSSMModelConfig, HybridSSMTrainerConfig  # isort: skip
7 changes: 5 additions & 2 deletions fast_llm/models/custom/config.py
@@ -2,6 +2,9 @@
 
 from fast_llm.config import FieldUpdate, config_class
 from fast_llm.data.data.gpt.config import GPTDataConfig
+from fast_llm.engine.config_utils.runnable import RunnableConfig
+from fast_llm.engine.multi_stage.config import FastLLMModelConfig
+from fast_llm.engine.training.config import TrainerConfig
 from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig
 
 if typing.TYPE_CHECKING:
@@ -22,7 +25,7 @@ class CustomBaseModelConfig(GPTBaseModelConfig):
     pass
 
 
-@config_class()
+@config_class(dynamic_type={FastLLMModelConfig: "gpt_custom"})
 class CustomModelConfig(GPTModelConfig):
     # TODO: Add custom model config parameters, if any (typically none).
     model_name: typing.ClassVar[str] = "gpt_custom"
@@ -46,7 +49,7 @@ class PretrainedCustomModelConfig(PretrainedGPTModelConfig):
     model: CustomModelConfig = FieldUpdate()
 
 
-@config_class()
+@config_class(dynamic_type={RunnableConfig: "train_gpt_custom", TrainerConfig: "gpt_custom"})
 class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig):
     # TODO: Add custom trainer config parameters, if any (typically none).
     data: CustomDataConfig = FieldUpdate()
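
The same pattern is how a downstream project would expose its own model and trainer. A hypothetical sketch mirroring the Custom* classes above; all names below are invented and the class bodies are deliberately minimal.

from fast_llm.config import FieldUpdate, config_class
from fast_llm.engine.config_utils.runnable import RunnableConfig
from fast_llm.engine.multi_stage.config import FastLLMModelConfig
from fast_llm.engine.training.config import TrainerConfig
from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig


@config_class(dynamic_type={FastLLMModelConfig: "my_model"})
class MyModelConfig(GPTModelConfig):
    # Hypothetical model config, registered under "my_model".
    pass


@config_class(dynamic_type={RunnableConfig: "train_my_model", TrainerConfig: "my_model"})
class MyTrainerConfig(GPTTrainerConfig):
    # Hypothetical trainer config; points the `model` field at the new model config.
    model: MyModelConfig = FieldUpdate()


# As long as the defining module is imported before `RunnableConfig.parse_and_run`,
# `train my_model ...` (or the `train_my_model` shortcut) resolves to MyTrainerConfig.
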