Skip to content

Commit 4496a40

Browse files
authored
Typing improvements (#114)
1 parent 07b1622 commit 4496a40

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

80 files changed

+1136
-953
lines changed

docs/developer_guide/conversion.md

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -230,21 +230,21 @@ Continuing our `AwesomeModel` handler example, we define:
230230

231231
```python
232232
def _create_weight_converters(self) -> list[WeightConverter]:
233-
converters = []
234-
# The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`.
235-
num_layers = self._model.base_model_config.transformer.num_layers
236-
237-
# A simple renaming example, for the word embeddings.
238-
converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))
239-
240-
# We usually want to loop dynamically over layers
241-
for i in range(num_layers):
242-
# A `SplitWeightConverter` example, splitting a weight in two.
243-
converters.append(SplitWeightConverter(
244-
f"layers.{i+1}.weight",
245-
(f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"),
246-
))
247-
return converters
233+
converters = []
234+
# The set of converters may depend on the base model configuration, which is accessible through `self._model.config.base_model`.
235+
num_layers = self._model.config.base_model.transformer.num_layers
236+
237+
# A simple renaming example, for the word embeddings.
238+
converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))
239+
240+
# We usually want to loop dynamically over layers
241+
for i in range(num_layers):
242+
# A `SplitWeightConverter` example, splitting a weight in two.
243+
converters.append(SplitWeightConverter(
244+
f"layers.{i + 1}.weight",
245+
(f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"),
246+
))
247+
return converters
248248
```
249249

250250
And that's it! We're ready to use the new checkpoint format in Fast-LLM.

examples/mistral-4-node-benchmark.yaml renamed to examples/mistral.yaml

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ training:
77
iterations: null
88
test_iters: 0
99
batch:
10-
sequence_length: 8192
11-
micro_batch_size: 1
12-
batch_size: 32
10+
sequence_length: 4096
11+
micro_batch_size: 2
12+
batch_size: 64
1313
data:
1414
format: random
1515
split: [1, 0, 0]
@@ -27,18 +27,18 @@ model:
2727
normalization:
2828
type: rms_norm
2929
epsilon: 1.0e-05
30+
rotary:
31+
type: default
32+
theta: 10000
3033
num_layers: 32
3134
hidden_size: 4096
3235
ffn_hidden_size: 14336
3336
num_attention_heads: 32
3437
head_groups: 8
3538
add_linear_biases: false
36-
use_rotary_embeddings: true
3739
gated: true
3840
activation_type: silu
39-
triton_rotary: true
4041
kv_channels: 128
41-
rotary_embedding_scale: -9.210340371976184
4242
window_size: 4096
4343
init_method_std: 0.009021
4444
attention_dropout: 0.0
@@ -49,7 +49,6 @@ model:
4949
zero_stage: 2
5050
distributed:
5151
training_dtype: bf16
52-
distributed_timeout: 3600
5352
seed: 984059
5453
run:
55-
experiment_dir: mistral_4_nodes_benchmark
54+
experiment_dir: mistral_example

fast_llm/config.py

Lines changed: 58 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -287,12 +287,11 @@ def __post_init__(self):
287287
In general this should not be overridden in derived classes,
288288
and all post-processing should be done in `_validate`
289289
"""
290-
self._check_abstract()
291290
self._validated = False
292291
if _AUTO_VALIDATE:
293292
self.validate()
294293

295-
def __setattr__(self, key, value):
294+
def __setattr__(self, key: str, value: typing.Any) -> None:
296295
"""
297296
Make the class read-only after validation.
298297
"""
@@ -308,7 +307,7 @@ def __setattr__(self, key, value):
308307
)
309308
super().__setattr__(key, value)
310309

311-
def __delattr__(self, key):
310+
def __delattr__(self, key: str) -> None:
312311
"""
313312
Make the class read-only after validation.
314313
"""
@@ -319,7 +318,7 @@ def __delattr__(self, key):
319318
)
320319
super().__delattr__(key)
321320

322-
def validate(self, *, _is_validating=False):
321+
def validate[T](self: T, *, _is_validating: bool = False) -> T:
323322
"""
324323
Validate a class and mark it as read-only
325324
This should not be overridden in derived classes.
@@ -335,14 +334,15 @@ def validate(self, *, _is_validating=False):
335334
self._validated = True
336335
return self
337336

338-
def _validate(self):
337+
def _validate(self) -> None:
339338
"""
340339
Verify that the type hints are respected,
341340
and fix some known entries compatible with the type hint (ex. `int -> float`, `str -> pathlib.Path`)
342341
343342
Can be extended to add custom post-processing (typically before the super() call)
344343
and validation (typically after)
345344
"""
345+
self._check_abstract()
346346
errors = []
347347
for name, field in self.fields():
348348
if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa
@@ -522,7 +522,7 @@ def fields(cls) -> typing.Iterable[tuple[str, Field]]:
522522
return cls.__dataclass_fields__.items() # noqa
523523

524524
@classmethod
525-
def get_field(cls, name) -> Field:
525+
def get_field(cls, name: str) -> Field:
526526
return cls.__dataclass_fields__[name] # noqa
527527

528528
def _to_dict(
@@ -531,7 +531,7 @@ def _to_dict(
531531
all_fields: bool = False,
532532
format_: _ConfigDictFormat = _ConfigDictFormat.nested,
533533
serializable: bool = False,
534-
):
534+
) -> dict[str, typing.Any]:
535535
"""
536536
Serialize the config to a dict that can (generally) be used to reconstruct an identical `Config`.
537537
When not flat, the dict includes a `__class__` entry which allows support for derived classes.
@@ -561,12 +561,12 @@ def _add_field_to_args(
561561
args: dict | list,
562562
name: str | None,
563563
field: Field | None,
564-
value,
564+
value: typing.Any,
565565
verbose: int | None = None,
566566
all_fields: bool = False,
567567
format_: _ConfigDictFormat = _ConfigDictFormat.nested,
568568
serializable: bool = False,
569-
):
569+
) -> None:
570570
if (
571571
field is not None
572572
and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR)
@@ -604,17 +604,12 @@ def _add_field_to_args(
604604
else:
605605
field_value = value
606606
if serializable:
607-
if hasattr(value, "__fast_llm_serialize__"):
608-
field_value = field_value.__fast_llm_serialize__()
609-
if isinstance(value, enum.Enum):
610-
field_value = field_value.value
611-
# Tag is not actually serializable, but needs to be kept as-is for config processing,
612-
# and should be absent for valid configs.
613-
elif not isinstance(value, int | float | bool | str | Tag | None):
614-
field_value = str(field_value)
607+
field_value = cls._serialize_value(value)
615608
if format_ == _ConfigDictFormat.tuple:
616609
field_value = {(): field_value}
617610

611+
if serializable:
612+
name = cls._serialize_value(name)
618613
if format_ == _ConfigDictFormat.tuple:
619614
args.update({(name,) + name_: value_ for name_, value_ in field_value.items()})
620615
elif format_ == _ConfigDictFormat.nested:
@@ -626,24 +621,37 @@ def _add_field_to_args(
626621
else:
627622
raise NotImplementedError(format_)
628623

629-
def to_copy(
630-
self,
631-
*updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]],
632-
strict: bool = True,
633-
):
624+
@classmethod
625+
def _serialize_value(cls, value: typing.Any) -> int | float | bool | str | None:
626+
value = value
627+
if hasattr(value, "__fast_llm_serialize__"):
628+
value = value.__fast_llm_serialize__()
629+
if isinstance(value, enum.Enum):
630+
value = value.value
631+
# Tag is not actually serializable, but needs to be kept as-is for config processing,
632+
# and should be absent for valid configs.
633+
elif not isinstance(value, int | float | bool | str | Tag | None):
634+
value = str(value)
635+
return value
636+
637+
def to_copy[
638+
T
639+
](self: T, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True,) -> T:
634640
return self.from_dict(self, *updates, strict=strict)
635641

636-
def to_serialized(self, verbose: int | None = FieldVerboseLevel.core):
642+
def to_serialized(self, verbose: int | None = FieldVerboseLevel.core) -> dict[str, typing.Any]:
637643
return self._to_dict(verbose=verbose, format_=_ConfigDictFormat.nested, serializable=True)
638644

639-
def to_logs(
645+
def to_logs[
646+
T
647+
](
640648
self,
641649
verbose: int | None = FieldVerboseLevel.core,
642-
log_fn=logger.info,
650+
log_fn: typing.Callable[[str], T] = logger.info,
643651
title: str | None = None,
644652
width: int = 80,
645653
fill_char: str = "-",
646-
):
654+
) -> T:
647655
arg_dict = self.to_serialized(verbose=verbose)
648656
if title is None:
649657
title = self._get_class_name()
@@ -654,7 +662,7 @@ def to_logs(
654662
)
655663

656664
@classmethod
657-
def _get_class_name(cls):
665+
def _get_class_name(cls) -> str:
658666
return get_type_name(cls)
659667

660668
@classmethod
@@ -663,7 +671,7 @@ def from_dict(
663671
default: typing.Union["Config", dict[str, typing.Any]],
664672
*updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]],
665673
strict: bool = True,
666-
):
674+
) -> typing.Self:
667675
if isinstance(default, Config):
668676
default = default._to_dict()
669677
for update in updates:
@@ -679,7 +687,7 @@ def from_flat_dict(
679687
cls,
680688
default: dict[str, typing.Any],
681689
strict: bool = True,
682-
):
690+
) -> typing.Self:
683691
# TODO v0.3: Remove flat format
684692
return cls._from_dict(default, strict, True)
685693

@@ -689,8 +697,7 @@ def _from_dict(
689697
default: dict[str, typing.Any],
690698
strict: bool = True,
691699
flat: bool = False,
692-
):
693-
cls._check_abstract()
700+
) -> typing.Self:
694701
# TODO v0.3: Remove flat format
695702
out_arg_dict = {}
696703

@@ -807,7 +814,7 @@ def _handle_renamed_field(
807814
old_name: str | tuple[str, ...],
808815
new_name: str | tuple[str, ...],
809816
fn: typing.Callable | None = None,
810-
):
817+
) -> None:
811818
if old_name in default:
812819
warnings.warn(f"Field `{old_name}` is deprecated in class {get_type_name(cls)}, use `{new_name}` instead.")
813820
value = pop_nested_dict_value(default, old_name)
@@ -839,11 +846,13 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ
839846
)
840847

841848
@classmethod
842-
def _check_abstract(cls):
849+
def _check_abstract(cls) -> None:
843850
if cls._abstract:
844-
raise RuntimeError(f"{cls.__name__} is abstract")
851+
raise ValidationError(f"{cls.__name__} is abstract")
845852
if not cls.__class_validated__:
846-
raise RuntimeError(f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator.")
853+
raise ValidationError(
854+
f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator."
855+
)
847856

848857
def __init_subclass__(cls):
849858
"""
@@ -893,3 +902,17 @@ def __init_subclass__(cls):
893902
else:
894903
# dataclasses expects an annotation, so we use the one from the base class.
895904
cls.__annotations__[name] = base_class_field.type
905+
906+
907+
class Configurable[ConfigType: Config]:
908+
config_class: typing.ClassVar[type[Config]] = Config
909+
910+
def __init__(self, config: ConfigType, *args, **kwargs):
911+
Assert.custom(isinstance, config, self.config_class)
912+
self._config = config
913+
# Handle multiple inheritance.
914+
super().__init__(*args, **kwargs)
915+
916+
@property
917+
def config(self) -> ConfigType:
918+
return self._config

0 commit comments

Comments
 (0)