Currently, the default configs are set like this:
```python
INTEGRATE_DEFAULT_CONFIG = {"method": "rk45", ..., "tolerance": 1e-3}
...
def __init__(self, ..., integrate_kwargs=None):
    self.integrate_kwargs = integrate_kwargs or INTEGRATE_DEFAULT_CONFIG
```
This presents several issues:
- There is no way to overwrite only part of the default config. Passing any dictionary for `integrate_kwargs` replaces the defaults entirely, so every key that is not passed is lost.
- No validation is performed on the passed values. Many of them are only used much later, during training or at inference time, so a bad value surfaces late and there is no convenient way to fix it at that point.
- Some values, such as `method`, are opaque about which options are valid, so users have to dig through other parts of the docs to find out.
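To make the first two issues concrete, here is a minimal, hypothetical reproduction of the current pattern. `FlowMatchingSketch` is an illustrative stand-in for the real class, and the default dict is reduced to the two keys named in this issue (the others are elided in the snippet above).

```python
# Hypothetical stand-in for the elided default dict: only the two keys
# named in this issue are included.
INTEGRATE_DEFAULT_CONFIG = {"method": "rk45", "tolerance": 1e-3}


class FlowMatchingSketch:
    """Hypothetical minimal class reproducing the current pattern."""

    def __init__(self, integrate_kwargs=None):
        # Current behaviour: any non-empty dict replaces the defaults wholesale.
        self.integrate_kwargs = integrate_kwargs or INTEGRATE_DEFAULT_CONFIG


# First issue: overriding only the tolerance silently drops "method".
model = FlowMatchingSketch(integrate_kwargs={"tolerance": 1e-4})
print(model.integrate_kwargs)  # {'tolerance': 0.0001}

# Second issue: invalid values are accepted at construction time and would
# only surface later, when integration is actually run.
model = FlowMatchingSketch(integrate_kwargs={"method": "rk4", "tolerance": -1.0})
print(model.integrate_kwargs)  # {'method': 'rk4', 'tolerance': -1.0}
```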
To fix this, we would need to:
1. Partially update the default configs with the values the user passes
2. Validate the type and value of each passed parameter
3. Improve our error messages, or provide a lot of redundant documentation
We could fix 1 and 2 with manual dictionary updates and parameter validation. However, Pydantic's `BaseModel` class provides both out of the box, and also addresses 3 through informative error messages and type hints that static type checkers can highlight.
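For comparison, the manual route might look roughly like the following sketch. `build_integrate_config` and `VALID_METHODS` are hypothetical names, the defaults are reduced to the two keys from this issue, and rejecting unknown keys is an extra design choice rather than something the issue asks for. Every new field needs its own hand-written check, which is the maintenance cost the Pydantic approach would absorb.

```python
# Hypothetical defaults and allowed values; the real config has more keys.
INTEGRATE_DEFAULT_CONFIG = {"method": "rk45", "tolerance": 1e-3}
VALID_METHODS = ("rk45", "euler")


def build_integrate_config(integrate_kwargs=None):
    """Merge user overrides into the defaults and validate the result."""
    overrides = dict(integrate_kwargs or {})

    # Reject unknown keys so typos such as "tol" do not pass silently.
    unknown = set(overrides) - set(INTEGRATE_DEFAULT_CONFIG)
    if unknown:
        raise ValueError(f"Unknown integrate_kwargs keys: {sorted(unknown)}")

    # Partial update: user values win, every other key keeps its default.
    config = {**INTEGRATE_DEFAULT_CONFIG, **overrides}

    # Hand-written checks, one per field.
    if config["method"] not in VALID_METHODS:
        raise ValueError(
            f"method must be one of {VALID_METHODS}, got {config['method']!r}"
        )
    tol = config["tolerance"]
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(tol, bool) or not isinstance(tol, (int, float)) or tol <= 0:
        raise ValueError(f"tolerance must be a positive number, got {tol!r}")
    return config


print(build_integrate_config({"tolerance": 1e-4}))
# {'method': 'rk45', 'tolerance': 0.0001}
```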
With Pydantic, the current code would change to this:
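```python
from typing import Literal

from pydantic import BaseModel, PositiveFloat


class FlowMatching:  # base classes and other members elided
    class IntegrateConfig(BaseModel):
        method: Literal["rk45", "euler"] = "rk45"
        ...
        tolerance: PositiveFloat = 1e-3

    ...
    def __init__(self, ..., integrate_kwargs=None):
        self.integrate_kwargs = FlowMatching.IntegrateConfig(**(integrate_kwargs or {}))
```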
We could also rename it to `integrate_config` (and `self.integrate_config`) for consistency.
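A minimal sketch of the proposed behaviour, assuming Pydantic is installed. The model keeps only the `method` and `tolerance` fields from the proposal and is defined at module level rather than nested in `FlowMatching`, purely to keep the example self-contained.

```python
from typing import Literal

from pydantic import BaseModel, PositiveFloat, ValidationError


class IntegrateConfig(BaseModel):
    # Only the two fields named in this issue; the others are elided.
    method: Literal["rk45", "euler"] = "rk45"
    tolerance: PositiveFloat = 1e-3


# Partial update: unspecified fields keep their defaults.
config = IntegrateConfig(**{"tolerance": 1e-4})
print(config.method, config.tolerance)  # rk45 0.0001

# Invalid values are rejected immediately, at construction time.
try:
    IntegrateConfig(**{"method": "rk4", "tolerance": -1.0})
except ValidationError as err:
    # The message names each offending field and what it expected.
    print(err)
```

The exact error text differs between Pydantic v1 and v2, but in both versions a single `ValidationError` reports every invalid field at once.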
Disadvantages:
- Increased maintenance for the configs
- Slightly higher barrier to entry for contributors, since correct typing requires some knowledge of Pydantic