Skip to content

Typing improvements #114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions docs/developer_guide/conversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,21 +230,21 @@ Continuing our `AwesomeModel` handler example, we define:

```python
def _create_weight_converters(self) -> list[WeightConverter]:
converters = []
# The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`.
num_layers = self._model.base_model_config.transformer.num_layers

# A simple renaming example, for the word embeddings.
converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))

# We usually want to loop dynamically over layers
for i in range(num_layers):
# A `SplitWeightConverter` example, splitting a weight in two.
converters.append(SplitWeightConverter(
f"layers.{i+1}.weight",
(f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"),
))
return converters
converters = []
# The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`.
num_layers = self._model.config.base_model.transformer.num_layers

# A simple renaming example, for the word embeddings.
converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))

# We usually want to loop dynamically over layers
for i in range(num_layers):
# A `SplitWeightConverter` example, splitting a weight in two.
converters.append(SplitWeightConverter(
f"layers.{i + 1}.weight",
(f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"),
))
return converters
```

And that's it! We're ready to use the new checkpoint format in Fast-LLM.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ training:
iterations: null
test_iters: 0
batch:
sequence_length: 8192
micro_batch_size: 1
batch_size: 32
sequence_length: 4096
micro_batch_size: 2
batch_size: 64
data:
format: random
split: [1, 0, 0]
Expand All @@ -27,18 +27,18 @@ model:
normalization:
type: rms_norm
epsilon: 1.0e-05
rotary:
type: default
theta: 10000
num_layers: 32
hidden_size: 4096
ffn_hidden_size: 14336
num_attention_heads: 32
head_groups: 8
add_linear_biases: false
use_rotary_embeddings: true
gated: true
activation_type: silu
triton_rotary: true
kv_channels: 128
rotary_embedding_scale: -9.210340371976184
window_size: 4096
init_method_std: 0.009021
attention_dropout: 0.0
Expand All @@ -49,7 +49,6 @@ model:
zero_stage: 2
distributed:
training_dtype: bf16
distributed_timeout: 3600
seed: 984059
run:
experiment_dir: mistral_4_nodes_benchmark
experiment_dir: mistral_example
93 changes: 58 additions & 35 deletions fast_llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,12 +287,11 @@ def __post_init__(self):
In general this should not be overridden in derived classes,
and all post-processing should be done in `_validate`
"""
self._check_abstract()
self._validated = False
if _AUTO_VALIDATE:
self.validate()

def __setattr__(self, key, value):
def __setattr__(self, key: str, value: typing.Any) -> None:
"""
Make the class read-only after validation.
"""
Expand All @@ -308,7 +307,7 @@ def __setattr__(self, key, value):
)
super().__setattr__(key, value)

def __delattr__(self, key):
def __delattr__(self, key: str) -> None:
"""
Make the class read-only after validation.
"""
Expand All @@ -319,7 +318,7 @@ def __delattr__(self, key):
)
super().__delattr__(key)

def validate(self, *, _is_validating=False):
def validate[T](self: T, *, _is_validating: bool = False) -> T:
"""
Validate a class and mark it as read-only
This should not be overridden in derived classes.
Expand All @@ -335,14 +334,15 @@ def validate(self, *, _is_validating=False):
self._validated = True
return self

def _validate(self):
def _validate(self) -> None:
"""
Verify that the type hints are respected,
and fix some know entries compatible with the type hint (ex. `int -> float`, `str -> pathlib.Path`)

Can be extended to add custom post-processing (typically before the super() call)
and validation (typically after)
"""
self._check_abstract()
errors = []
for name, field in self.fields():
if not field.init or field._field_type == dataclasses._FIELD_CLASSVAR: # noqa
Expand Down Expand Up @@ -522,7 +522,7 @@ def fields(cls) -> typing.Iterable[tuple[str, Field]]:
return cls.__dataclass_fields__.items() # noqa

@classmethod
def get_field(cls, name) -> Field:
def get_field(cls, name: str) -> Field:
    """Return the `Field` descriptor registered under `name` (raises `KeyError` if absent)."""
    return cls.__dataclass_fields__[name] # noqa

def _to_dict(
Expand All @@ -531,7 +531,7 @@ def _to_dict(
all_fields: bool = False,
format_: _ConfigDictFormat = _ConfigDictFormat.nested,
serializable: bool = False,
):
) -> dict[str, typing.Any]:
"""
Serialize the config to a dict that can (generally) be used to reconstruct an identical `Config`.
When not flat, the dict includes a `__class__` entry which allows support for derived classes.
Expand Down Expand Up @@ -561,12 +561,12 @@ def _add_field_to_args(
args: dict | list,
name: str | None,
field: Field | None,
value,
value: typing.Any,
verbose: int | None = None,
all_fields: bool = False,
format_: _ConfigDictFormat = _ConfigDictFormat.nested,
serializable: bool = False,
):
) -> None:
if (
field is not None
and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR)
Expand Down Expand Up @@ -604,17 +604,12 @@ def _add_field_to_args(
else:
field_value = value
if serializable:
if hasattr(value, "__fast_llm_serialize__"):
field_value = field_value.__fast_llm_serialize__()
if isinstance(value, enum.Enum):
field_value = field_value.value
# Tag is not actually serializable, but needs to be kept as-is for config processing,
# and should be absent for valid configs.
elif not isinstance(value, int | float | bool | str | Tag | None):
field_value = str(field_value)
field_value = cls._serialize_value(value)
if format_ == _ConfigDictFormat.tuple:
field_value = {(): field_value}

if serializable:
name = cls._serialize_value(name)
if format_ == _ConfigDictFormat.tuple:
args.update({(name,) + name_: value_ for name_, value_ in field_value.items()})
elif format_ == _ConfigDictFormat.nested:
Expand All @@ -626,24 +621,37 @@ def _add_field_to_args(
else:
raise NotImplementedError(format_)

def to_copy(
self,
*updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]],
strict: bool = True,
):
@classmethod
def _serialize_value(cls, value: typing.Any) -> int | float | bool | str | None:
    """Convert `value` to a plain scalar suitable for a serialized config dict.

    Resolution order:
      1. Objects exposing `__fast_llm_serialize__` are asked to serialize themselves.
      2. Enum members are replaced by their `.value`.
      3. Anything that is not already a plain scalar (or a `Tag`) falls back to `str()`.

    Note: `Tag` is not actually serializable, but needs to be kept as-is for config
    processing, and should be absent for valid configs.
    """
    # (Removed a redundant `value = value` no-op that was present here.)
    if hasattr(value, "__fast_llm_serialize__"):
        value = value.__fast_llm_serialize__()
    if isinstance(value, enum.Enum):
        value = value.value
    elif not isinstance(value, int | float | bool | str | Tag | None):
        value = str(value)
    return value

def to_copy[
T
](self: T, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True,) -> T:
return self.from_dict(self, *updates, strict=strict)

def to_serialized(self, verbose: int | None = FieldVerboseLevel.core):
def to_serialized(self, verbose: int | None = FieldVerboseLevel.core) -> dict[str, typing.Any]:
    """Serialize the config to a nested, serializable dict; `verbose` selects which fields are included."""
    return self._to_dict(verbose=verbose, format_=_ConfigDictFormat.nested, serializable=True)

def to_logs(
def to_logs[
T
](
self,
verbose: int | None = FieldVerboseLevel.core,
log_fn=logger.info,
log_fn: typing.Callable[[str], T] = logger.info,
title: str | None = None,
width: int = 80,
fill_char: str = "-",
):
) -> T:
arg_dict = self.to_serialized(verbose=verbose)
if title is None:
title = self._get_class_name()
Expand All @@ -654,7 +662,7 @@ def to_logs(
)

@classmethod
def _get_class_name(cls):
def _get_class_name(cls) -> str:
    """Return a human-readable name for this class (delegates to `get_type_name`)."""
    return get_type_name(cls)

@classmethod
Expand All @@ -663,7 +671,7 @@ def from_dict(
default: typing.Union["Config", dict[str, typing.Any]],
*updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]],
strict: bool = True,
):
) -> typing.Self:
if isinstance(default, Config):
default = default._to_dict()
for update in updates:
Expand All @@ -679,7 +687,7 @@ def from_flat_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
):
) -> typing.Self:
# TODO v0.3: Remove flat format
return cls._from_dict(default, strict, True)

Expand All @@ -689,8 +697,7 @@ def _from_dict(
default: dict[str, typing.Any],
strict: bool = True,
flat: bool = False,
):
cls._check_abstract()
) -> typing.Self:
# TODO v0.3: Remove flat format
out_arg_dict = {}

Expand Down Expand Up @@ -807,7 +814,7 @@ def _handle_renamed_field(
old_name: str | tuple[str, ...],
new_name: str | tuple[str, ...],
fn: typing.Callable | None = None,
):
) -> None:
if old_name in default:
warnings.warn(f"Field `{old_name}` is deprecated in class {get_type_name(cls)}, use `{new_name}` instead.")
value = pop_nested_dict_value(default, old_name)
Expand Down Expand Up @@ -839,11 +846,13 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ
)

@classmethod
def _check_abstract(cls):
def _check_abstract(cls) -> None:
if cls._abstract:
raise RuntimeError(f"{cls.__name__} is abstract")
raise ValidationError(f"{cls.__name__} is abstract")
if not cls.__class_validated__:
raise RuntimeError(f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator.")
raise ValidationError(
f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator."
)

def __init_subclass__(cls):
"""
Expand Down Expand Up @@ -893,3 +902,17 @@ def __init_subclass__(cls):
else:
# dataclasses expects an annotation, so we use the one from the base class.
cls.__annotations__[name] = base_class_field.type


class Configurable[ConfigType: Config]:
    """Mixin base for objects that are driven by a `Config` instance.

    Stores the config on the instance and exposes it read-only via the
    `config` property. Generic over the concrete config type so subclasses
    get precise typing for `self.config`.
    """

    # Subclasses narrow this to their concrete config class; it is the type
    # checked against the `config` argument in `__init__`.
    config_class: typing.ClassVar[type[Config]] = Config

    def __init__(self, config: ConfigType, *args, **kwargs):
        # Presumably raises if `config` is not an instance of `config_class`
        # (exact failure mode depends on `Assert.custom` — TODO confirm).
        Assert.custom(isinstance, config, self.config_class)
        self._config = config
        # Handle multiple inheritance.
        super().__init__(*args, **kwargs)

    @property
    def config(self) -> ConfigType:
        """The config this object was constructed with."""
        return self._config
Loading
Loading