
Torch backend implementation #248

Open
@MUCDK


Description of feature

Once #228, #235, #239, and #240 are merged, we can start on the torch backend implementation. The plan:

  • We keep a single CellFlow class for both backends.
  • Backend-specific code lives in src/backends/jax and src/backends/torch.
  • CellFlow.prepare_model gets a backend: Literal["jax", "torch"] argument.
  • For now, we implement only OTFlowMatching for torch and leave GENOT out, to keep things simple.
  • The only open problem is that prepare_model takes backend-specific arguments, namely match_fn, optimizer, and vf_act_fn.
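A minimal sketch of the single-class, backend-argument design. Only CellFlow, prepare_model, backend, OTFlowMatching and GENOT come from this issue; the solver argument, its values "otfm" and "genot", and the SUPPORTED_SOLVERS table are hypothetical. In the real package each branch would import its solver from src/backends/jax or src/backends/torch; this sketch only validates the choice, so it runs without either framework installed.

```python
from __future__ import annotations

from typing import Literal, get_args

# Hypothetical sketch: SUPPORTED_SOLVERS, the solver argument and its values
# ("otfm" for OTFlowMatching, "genot" for GENOT) are illustrative only.
Backend = Literal["jax", "torch"]

# Solvers available per backend; torch starts with OTFlowMatching only.
SUPPORTED_SOLVERS: dict[str, frozenset[str]] = {
    "jax": frozenset({"otfm", "genot"}),
    "torch": frozenset({"otfm"}),
}


class CellFlow:
    """Single user-facing class shared by both backends."""

    def __init__(self) -> None:
        self.backend: Backend | None = None
        self.solver: str | None = None

    def prepare_model(self, solver: str = "otfm", backend: Backend = "jax") -> None:
        # Reject anything outside the Literal's allowed values.
        if backend not in get_args(Backend):
            raise ValueError(
                f"Unknown backend {backend!r}; expected one of {get_args(Backend)}."
            )
        # Reject solver/backend combinations that are not implemented yet.
        if solver not in SUPPORTED_SOLVERS[backend]:
            raise NotImplementedError(
                f"Solver {solver!r} is not available for the {backend!r} backend."
            )
        # A real implementation would build the backend-specific model here.
        self.backend = backend
        self.solver = solver
```

Keeping the solver/backend matrix in one table means enabling GENOT for torch later only touches that table.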

For the last point, I think the best solution is to accept both JAX and torch instances for these arguments, default them to None, document the defaults, and instantiate the defaults later, where needed, in the solver classes.
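A minimal sketch of the "default to None, resolve in the solver" idea, assuming hypothetical class names BaseSolver, JaxOTFlowMatching and TorchOTFlowMatching. Real defaults would be framework objects (for example an optax optimizer for JAX and a torch.optim optimizer for torch); placeholder strings stand in for them here so the example runs without either framework.

```python
from __future__ import annotations

from typing import Any, Callable, Optional

# Hypothetical sketch: class names and default values are illustrative only.


class BaseSolver:
    """Resolves None-valued arguments into backend-specific defaults."""

    # Each subclass maps an argument name to a zero-argument factory.
    default_factories: dict[str, Callable[[], Any]] = {}

    def __init__(
        self,
        match_fn: Optional[Any] = None,
        optimizer: Optional[Any] = None,
        vf_act_fn: Optional[Any] = None,
    ) -> None:
        user_args = {"match_fn": match_fn, "optimizer": optimizer, "vf_act_fn": vf_act_fn}
        for name, value in user_args.items():
            # Defaults are built here, inside the backend-specific solver,
            # and only for arguments the user left as None.
            if value is None:
                value = self.default_factories[name]()
            setattr(self, name, value)


class JaxOTFlowMatching(BaseSolver):
    # Placeholder strings stand in for JAX/optax objects.
    default_factories = {
        "match_fn": lambda: "jax default match_fn",
        "optimizer": lambda: "jax default optimizer",
        "vf_act_fn": lambda: "jax default vf_act_fn",
    }


class TorchOTFlowMatching(BaseSolver):
    # Placeholder strings stand in for torch objects.
    default_factories = {
        "match_fn": lambda: "torch default match_fn",
        "optimizer": lambda: "torch default optimizer",
        "vf_act_fn": lambda: "torch default vf_act_fn",
    }
```

For example, TorchOTFlowMatching(optimizer=my_optimizer) keeps the user's optimizer and fills in the torch defaults for match_fn and vf_act_fn. Because prepare_model never constructs these objects itself, its signature stays backend-agnostic; a type check against the chosen backend could be added in BaseSolver.__init__ if mismatched instances turn out to be a common mistake.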

Labels: enhancement (New feature or request)