Expose all nutpie compile backends through pm.sample #7497

Open
@jessegrabowski

Description

Nutpie currently has two compile backends, numba and JAX, with a third PyTorch backend on the way. It would be nice if we could easily select these via pm.sample.

Proposal 1: Allow nutpie.compile_pymc_model kwargs in nuts_sampler_kwargs

  • Pros: It's easy, since there are only two such arguments: backend and gradient_backend. We just check for and pop them before forwarding all other arguments to nutpie.sample.
  • Cons: It might be seen as "unexpected" behavior, since some keywords go to one function and some to another. Also, the nuts_sampler_kwargs argument isn't very beautiful in the first place.
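
A minimal sketch of the check-and-pop step, assuming the two compile-time keys are spelled `backend` and `gradient_backend`. The helper name `split_nutpie_kwargs` is hypothetical and not part of PyMC or nutpie, and `target_accept` is used only as an illustrative sampling keyword:

```python
# Hypothetical helper: not an existing PyMC or nutpie function.
COMPILE_KEYS = ("backend", "gradient_backend")  # assumed compile-time keywords


def split_nutpie_kwargs(nuts_sampler_kwargs=None):
    """Split one kwargs dict into (compile_kwargs, sample_kwargs)."""
    # Copy first so the caller's dict is not mutated by the pops below.
    sample_kwargs = dict(nuts_sampler_kwargs or {})
    compile_kwargs = {
        key: sample_kwargs.pop(key) for key in COMPILE_KEYS if key in sample_kwargs
    }
    return compile_kwargs, sample_kwargs


user_kwargs = {"backend": "jax", "gradient_backend": "jax", "target_accept": 0.9}
compile_kwargs, sample_kwargs = split_nutpie_kwargs(user_kwargs)
# compile_kwargs would go to the compile step, sample_kwargs to nutpie.sample.
```

The copy keeps the user's dictionary intact, which avoids surprises if the same dict is reused across several calls.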

Proposal 2: pip-style optional arguments, like nuts_sampler="nutpie[jax]" and nuts_sampler="nutpie[numba]"

  • Pros: It's quite pretty!
  • Cons: Technically you can pick the forward and backward (gradient) compile modes independently, so a user who wanted different backends for each would still have to import nutpie and do it manually. Maybe that's enough of a corner case that it's ok? It's also a different API from the other samplers (although blackjax could benefit from something similar to select among the many options over there, but that's beyond the scope here).
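
One way the pip-style string could be parsed, as a rough sketch. The function `parse_sampler_spec` is hypothetical, and the grammar (a name followed by at most one bracketed option) is an assumption rather than anything the proposal specifies:

```python
import re

# Assumed grammar: a sampler name, optionally followed by one "[option]".
_SPEC_PATTERN = re.compile(r"(?P<name>\w+)(?:\[(?P<option>\w+)\])?")


def parse_sampler_spec(spec):
    """Return (sampler_name, option); option is None if no brackets are given."""
    match = _SPEC_PATTERN.fullmatch(spec)
    if match is None:
        raise ValueError(f"Unrecognized nuts_sampler specification: {spec!r}")
    return match["name"], match["option"]


with_option = parse_sampler_spec("nutpie[jax]")
without_option = parse_sampler_spec("nutpie")
```

A single bracketed option cannot express separate forward and gradient backends, which is the corner case noted in the cons above.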

Proposal 3: Add a new compile_kwargs argument to pm.sample

  • Pros: It's very clear. It could be used to forward kwargs to pytensor as well, which is a nice side bonus.
  • Cons: It's another argument to an already bloated pm.sample function.
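
A sketch of how a separate compile_kwargs argument might be routed. The wrapper `sample_with_compile_kwargs` and the stand-in functions are hypothetical; in a real implementation the compile and sample callables would presumably be the nutpie (or pytensor) entry points, which are replaced by fakes here so the example runs without those libraries:

```python
# Hypothetical wrapper: shows the routing only, not actual pm.sample internals.
def sample_with_compile_kwargs(
    model, compile_fn, sample_fn, nuts_sampler_kwargs=None, compile_kwargs=None
):
    """Send compile_kwargs to the compile step and the rest to the sampler."""
    compiled = compile_fn(model, **(compile_kwargs or {}))
    return sample_fn(compiled, **(nuts_sampler_kwargs or {}))


# Stand-ins for the real compile and sample functions (assumptions for the demo).
def fake_compile(model, **kwargs):
    return {"model": model, "compile_kwargs": kwargs}


def fake_sample(compiled, **kwargs):
    return {"compiled": compiled, "sample_kwargs": kwargs}


result = sample_with_compile_kwargs(
    "my_model",
    fake_compile,
    fake_sample,
    nuts_sampler_kwargs={"target_accept": 0.9},
    compile_kwargs={"backend": "jax"},
)
```

Each dictionary has exactly one destination, which is the clarity this proposal trades for the extra argument.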
