
Add Jax Inference Methods #503

@BradyPlanden

Description


Feature description

#481 adds the Jaxified IDAKLU solver as an experimental implementation, with automatic differentiation applied to the cost/likelihood functions. This issue aims to extend that functionality with Jax-based inference libraries such as:

  • NumPyro for MCMC sampling
  • Optax for frequentist/deterministic inference methods
  • GPJax for Gaussian Processes
  • BlackJax for sampling
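
To illustrate the kind of sampling that NumPyro or BlackJax would provide, the sketch below is a hand-written random-walk Metropolis sampler in plain JAX. It is not the NumPyro or BlackJax API. The one-parameter exponential-decay model, `log_likelihood` and `rw_metropolis` are hypothetical stand-ins for a PyBOP model and likelihood, and a flat prior on `k` is assumed.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model standing in for a PyBOP simulation: y(t) = exp(-k * t).
t = jnp.linspace(0.0, 5.0, 50)
k_true = 0.7
noise_sd = 0.01
y_obs = jnp.exp(-k_true * t) + noise_sd * jax.random.normal(jax.random.PRNGKey(0), t.shape)


def log_likelihood(k):
    """Gaussian log-likelihood (up to a constant) with known noise_sd."""
    residuals = y_obs - jnp.exp(-k * t)
    return -0.5 * jnp.sum(residuals**2) / noise_sd**2


@jax.jit
def rw_metropolis(keys, k0, step_size):
    """Random-walk Metropolis; assumes a flat prior, so the target is the likelihood."""

    def step(state, key):
        k, logp = state
        key_prop, key_acc = jax.random.split(key)
        # Gaussian random-walk proposal around the current value.
        k_prop = k + step_size * jax.random.normal(key_prop)
        logp_prop = log_likelihood(k_prop)
        # Accept with probability min(1, exp(logp_prop - logp)).
        accept = jnp.log(jax.random.uniform(key_acc)) < logp_prop - logp
        k = jnp.where(accept, k_prop, k)
        logp = jnp.where(accept, logp_prop, logp)
        return (k, logp), k

    # lax.scan compiles the whole chain into a single XLA loop (one step per key).
    _, samples = jax.lax.scan(step, (k0, log_likelihood(k0)), keys)
    return samples


keys = jax.random.split(jax.random.PRNGKey(1), 3000)
samples = rw_metropolis(keys, jnp.array(0.5, dtype=jnp.float32), 0.01)
posterior_mean = float(jnp.mean(samples[1000:]))  # discard the first 1000 as burn-in
```

Libraries such as NumPyro or BlackJax would replace this hand-rolled kernel with tuned samplers (e.g. gradient-based ones), but the structure of a jittable log-density passed to a compiled loop is the same idea.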

Motivation

Jax offers a compiled (XLA) interface for parameter optimisation that can be lowered to GPU and TPU as well as CPU. This could improve the performance of PyBOP's methods and remove the need to define the gradients of cost/likelihood functions manually.
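
A minimal sketch of the gradient point, assuming a hypothetical one-parameter exponential-decay model in place of a PyBOP model: `jax.value_and_grad` derives the cost gradient automatically and `jax.jit` compiles cost and gradient together, so no hand-written gradient is needed. The plain gradient-descent loop stands in for what an Optax optimiser would do; `sum_of_squares` and `fit` are illustrative names, not PyBOP API.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model standing in for a PyBOP forward simulation:
# y(t) = exp(-k * t), with decay rate k as the single fitted parameter.
t = jnp.linspace(0.0, 5.0, 50)
y_obs = jnp.exp(-0.7 * t)  # synthetic, noise-free data generated with k = 0.7


def sum_of_squares(k):
    """Sum-of-squared-errors cost between model output and data."""
    return jnp.sum((jnp.exp(-k * t) - y_obs) ** 2)


# jax.value_and_grad returns (cost, dcost/dk) without a manual derivative;
# jax.jit compiles both via XLA for the available backend (CPU, GPU or TPU).
cost_and_grad = jax.jit(jax.value_and_grad(sum_of_squares))


def fit(k0, learning_rate=0.02, n_iter=300):
    """Plain gradient descent; an Optax optimiser would replace this update rule."""
    k = jnp.asarray(k0, dtype=jnp.float32)
    for _ in range(n_iter):
        _, g = cost_and_grad(k)
        k = k - learning_rate * g
    return k


k_fit = fit(1.0)
```

The learning rate is deliberately small because plain gradient descent diverges if the step exceeds roughly 2 / (cost curvature); adaptive optimisers are one motivation for pulling in a library like Optax.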

Possible implementation

A design outline and further discussion are needed to ensure these methods integrate cleanly with PyBOP's existing design.

Additional context

No response

Metadata


Assignees

No one assigned

    Labels

    enhancement: New feature or request

    Type

    No type

    Projects

    Status

    Todo

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
