Description of feature
Once we have merged #228, #235, #239, and #240, we can start with the torch backend implementation. Therefore:
- we will keep one `CellFlow` class for both backends
- we will have a `src/backends/jax` and a `src/backends/torch` directory
- `CellFlow.prepare_model` will have an argument `backend: Literal["jax", "torch"]`
- for now, let's not implement GENOT for torch; let's just go with `OTFlowMatching` to keep things simple
- the only problem that arises is that `prepare_model` takes backend-specific arguments, namely `match_fn`, `optimizer`, and `vf_act_fn`.
For the last point (the backend-specific arguments of `prepare_model`), I see the following as the best solution: allow passing both jax and torch instances, set them to `None` by default, describe the defaults in the docs, and instantiate them later in the solver classes.
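A minimal sketch of the `None`-default idea, assuming each solver class owns its own defaults: `prepare_model` would forward `match_fn`, `optimizer` and `vf_act_fn` unchanged (possibly `None`), and the solver replaces any `None` with its backend default. The class names, the `DEFAULTS` mapping and the placeholder string defaults are hypothetical; real defaults would be jax or torch objects.

```python
from typing import Any, Callable, ClassVar, Dict, Optional


def _resolve(value: Optional[Any], factory: Callable[[], Any]) -> Any:
    """Return the user-supplied object, or build the backend default if None."""
    return factory() if value is None else value


class _Solver:
    # Subclasses map each backend-specific argument to a zero-argument factory.
    DEFAULTS: ClassVar[Dict[str, Callable[[], Any]]] = {}

    def __init__(
        self,
        match_fn: Optional[Any] = None,
        optimizer: Optional[Any] = None,
        vf_act_fn: Optional[Any] = None,
    ) -> None:
        # Defaults are only instantiated here, inside the backend's solver.
        self.match_fn = _resolve(match_fn, self.DEFAULTS["match_fn"])
        self.optimizer = _resolve(optimizer, self.DEFAULTS["optimizer"])
        self.vf_act_fn = _resolve(vf_act_fn, self.DEFAULTS["vf_act_fn"])


class JaxOTFlowMatching(_Solver):
    # Placeholder defaults (hypothetical); a real solver would build jax objects.
    DEFAULTS = {
        "match_fn": lambda: "jax default match_fn",
        "optimizer": lambda: "jax default optimizer",
        "vf_act_fn": lambda: "jax default vf_act_fn",
    }


class TorchOTFlowMatching(_Solver):
    # Placeholder defaults (hypothetical); a real solver would build torch objects.
    DEFAULTS = {
        "match_fn": lambda: "torch default match_fn",
        "optimizer": lambda: "torch default optimizer",
        "vf_act_fn": lambda: "torch default vf_act_fn",
    }
```

Deferring construction to the solver also fits torch in particular: a `torch.optim` optimizer is constructed from the model's parameters, so it can only be built once the network exists.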