@@ -287,12 +287,11 @@ def __post_init__(self):
287
287
In general this should not be overridden in derived classes,
288
288
and all post-processing should be done in `_validate`
289
289
"""
290
- self ._check_abstract ()
291
290
self ._validated = False
292
291
if _AUTO_VALIDATE :
293
292
self .validate ()
294
293
295
- def __setattr__ (self , key , value ) :
294
+ def __setattr__ (self , key : str , value : typing . Any ) -> None :
296
295
"""
297
296
Make the class read-only after validation.
298
297
"""
@@ -308,7 +307,7 @@ def __setattr__(self, key, value):
308
307
)
309
308
super ().__setattr__ (key , value )
310
309
311
- def __delattr__ (self , key ) :
310
+ def __delattr__ (self , key : str ) -> None :
312
311
"""
313
312
Make the class read-only after validation.
314
313
"""
@@ -319,7 +318,7 @@ def __delattr__(self, key):
319
318
)
320
319
super ().__delattr__ (key )
321
320
322
- def validate (self , * , _is_validating = False ):
321
+ def validate [ T ] (self : T , * , _is_validating : bool = False ) -> T :
323
322
"""
324
323
Validate a class and mark it as read-only
325
324
This should not be overridden in derived classes.
@@ -335,14 +334,15 @@ def validate(self, *, _is_validating=False):
335
334
self ._validated = True
336
335
return self
337
336
338
- def _validate (self ):
337
+ def _validate (self ) -> None :
339
338
"""
340
339
Verify that the type hints are respected,
341
340
and fix some know entries compatible with the type hint (ex. `int -> float`, `str -> pathlib.Path`)
342
341
343
342
Can be extended to add custom post-processing (typically before the super() call)
344
343
and validation (typically after)
345
344
"""
345
+ self ._check_abstract ()
346
346
errors = []
347
347
for name , field in self .fields ():
348
348
if not field .init or field ._field_type == dataclasses ._FIELD_CLASSVAR : # noqa
@@ -522,7 +522,7 @@ def fields(cls) -> typing.Iterable[tuple[str, Field]]:
522
522
return cls .__dataclass_fields__ .items () # noqa
523
523
524
524
@classmethod
525
- def get_field (cls , name ) -> Field :
525
+ def get_field (cls , name : str ) -> Field :
526
526
return cls .__dataclass_fields__ [name ] # noqa
527
527
528
528
def _to_dict (
@@ -531,7 +531,7 @@ def _to_dict(
531
531
all_fields : bool = False ,
532
532
format_ : _ConfigDictFormat = _ConfigDictFormat .nested ,
533
533
serializable : bool = False ,
534
- ):
534
+ ) -> dict [ str , typing . Any ] :
535
535
"""
536
536
Serialize the config to a dict that can (generally) be used to reconstruct an identical `Config`.
537
537
When not flat, the dict includes a `__class__` entry which allows support for derived classes.
@@ -561,12 +561,12 @@ def _add_field_to_args(
561
561
args : dict | list ,
562
562
name : str | None ,
563
563
field : Field | None ,
564
- value ,
564
+ value : typing . Any ,
565
565
verbose : int | None = None ,
566
566
all_fields : bool = False ,
567
567
format_ : _ConfigDictFormat = _ConfigDictFormat .nested ,
568
568
serializable : bool = False ,
569
- ):
569
+ ) -> None :
570
570
if (
571
571
field is not None
572
572
and (not field .init or field ._field_type == dataclasses ._FIELD_CLASSVAR )
@@ -604,17 +604,12 @@ def _add_field_to_args(
604
604
else :
605
605
field_value = value
606
606
if serializable :
607
- if hasattr (value , "__fast_llm_serialize__" ):
608
- field_value = field_value .__fast_llm_serialize__ ()
609
- if isinstance (value , enum .Enum ):
610
- field_value = field_value .value
611
- # Tag is not actually serializable, but needs to be kept as-is for config processing,
612
- # and should be absent for valid configs.
613
- elif not isinstance (value , int | float | bool | str | Tag | None ):
614
- field_value = str (field_value )
607
+ field_value = cls ._serialize_value (value )
615
608
if format_ == _ConfigDictFormat .tuple :
616
609
field_value = {(): field_value }
617
610
611
+ if serializable :
612
+ name = cls ._serialize_value (name )
618
613
if format_ == _ConfigDictFormat .tuple :
619
614
args .update ({(name ,) + name_ : value_ for name_ , value_ in field_value .items ()})
620
615
elif format_ == _ConfigDictFormat .nested :
@@ -626,24 +621,37 @@ def _add_field_to_args(
626
621
else :
627
622
raise NotImplementedError (format_ )
628
623
629
- def to_copy (
630
- self ,
631
- * updates : typing .Union ["Config" , dict [str | tuple [str , ...], typing .Any ]],
632
- strict : bool = True ,
633
- ):
624
+ @classmethod
625
+ def _serialize_value (cls , value : typing .Any ) -> int | float | bool | str | None :
626
+ value = value
627
+ if hasattr (value , "__fast_llm_serialize__" ):
628
+ value = value .__fast_llm_serialize__ ()
629
+ if isinstance (value , enum .Enum ):
630
+ value = value .value
631
+ # Tag is not actually serializable, but needs to be kept as-is for config processing,
632
+ # and should be absent for valid configs.
633
+ elif not isinstance (value , int | float | bool | str | Tag | None ):
634
+ value = str (value )
635
+ return value
636
+
637
+ def to_copy [
638
+ T
639
+ ](self : T , * updates : typing .Union ["Config" , dict [str | tuple [str , ...], typing .Any ]], strict : bool = True ,) -> T :
634
640
return self .from_dict (self , * updates , strict = strict )
635
641
636
- def to_serialized (self , verbose : int | None = FieldVerboseLevel .core ):
642
+ def to_serialized (self , verbose : int | None = FieldVerboseLevel .core ) -> dict [ str , typing . Any ] :
637
643
return self ._to_dict (verbose = verbose , format_ = _ConfigDictFormat .nested , serializable = True )
638
644
639
- def to_logs (
645
+ def to_logs [
646
+ T
647
+ ](
640
648
self ,
641
649
verbose : int | None = FieldVerboseLevel .core ,
642
- log_fn = logger .info ,
650
+ log_fn : typing . Callable [[ str ], T ] = logger .info ,
643
651
title : str | None = None ,
644
652
width : int = 80 ,
645
653
fill_char : str = "-" ,
646
- ):
654
+ ) -> T :
647
655
arg_dict = self .to_serialized (verbose = verbose )
648
656
if title is None :
649
657
title = self ._get_class_name ()
@@ -654,7 +662,7 @@ def to_logs(
654
662
)
655
663
656
664
@classmethod
657
- def _get_class_name (cls ):
665
+ def _get_class_name (cls ) -> str :
658
666
return get_type_name (cls )
659
667
660
668
@classmethod
@@ -663,7 +671,7 @@ def from_dict(
663
671
default : typing .Union ["Config" , dict [str , typing .Any ]],
664
672
* updates : typing .Union ["Config" , dict [str | tuple [str , ...], typing .Any ]],
665
673
strict : bool = True ,
666
- ):
674
+ ) -> typing . Self :
667
675
if isinstance (default , Config ):
668
676
default = default ._to_dict ()
669
677
for update in updates :
@@ -679,7 +687,7 @@ def from_flat_dict(
679
687
cls ,
680
688
default : dict [str , typing .Any ],
681
689
strict : bool = True ,
682
- ):
690
+ ) -> typing . Self :
683
691
# TODO v0.3: Remove flat format
684
692
return cls ._from_dict (default , strict , True )
685
693
@@ -689,8 +697,7 @@ def _from_dict(
689
697
default : dict [str , typing .Any ],
690
698
strict : bool = True ,
691
699
flat : bool = False ,
692
- ):
693
- cls ._check_abstract ()
700
+ ) -> typing .Self :
694
701
# TODO v0.3: Remove flat format
695
702
out_arg_dict = {}
696
703
@@ -807,7 +814,7 @@ def _handle_renamed_field(
807
814
old_name : str | tuple [str , ...],
808
815
new_name : str | tuple [str , ...],
809
816
fn : typing .Callable | None = None ,
810
- ):
817
+ ) -> None :
811
818
if old_name in default :
812
819
warnings .warn (f"Field `{ old_name } ` is deprecated in class { get_type_name (cls )} , use `{ new_name } ` instead." )
813
820
value = pop_nested_dict_value (default , old_name )
@@ -839,11 +846,13 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ
839
846
)
840
847
841
848
@classmethod
842
- def _check_abstract (cls ):
849
+ def _check_abstract (cls ) -> None :
843
850
if cls ._abstract :
844
- raise RuntimeError (f"{ cls .__name__ } is abstract" )
851
+ raise ValidationError (f"{ cls .__name__ } is abstract" )
845
852
if not cls .__class_validated__ :
846
- raise RuntimeError (f"{ cls .__name__ } hasn't been validated. Make sure to use the @config_class decorator." )
853
+ raise ValidationError (
854
+ f"{ cls .__name__ } hasn't been validated. Make sure to use the @config_class decorator."
855
+ )
847
856
848
857
def __init_subclass__ (cls ):
849
858
"""
@@ -893,3 +902,17 @@ def __init_subclass__(cls):
893
902
else :
894
903
# dataclasses expects an annotation, so we use the one from the base class.
895
904
cls .__annotations__ [name ] = base_class_field .type
905
+
906
+
907
+ class Configurable [ConfigType : Config ]:
908
+ config_class : typing .ClassVar [type [Config ]] = Config
909
+
910
+ def __init__ (self , config : ConfigType , * args , ** kwargs ):
911
+ Assert .custom (isinstance , config , self .config_class )
912
+ self ._config = config
913
+ # Handle multiple inheritance.
914
+ super ().__init__ (* args , ** kwargs )
915
+
916
+ @property
917
+ def config (self ) -> ConfigType :
918
+ return self ._config
0 commit comments