
Can't import diffusers with Flax but without PyTorch #685

Closed
@pcuenca

Description


Describe the bug

I discovered this when testing #659. It is currently possible to import diffusers with neither PyTorch nor Flax installed; in that case the dummy classes are loaded. But when Flax is installed and PyTorch isn't, the import fails.
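The fallback described above (the import succeeds, but using a missing backend fails) can be sketched with placeholder classes created on the fly. This is a minimal illustration using only the standard library; `is_backend_available`, `make_dummy` and `FlaxExampleModel` are hypothetical names for this sketch, not the actual diffusers helpers.

```python
import importlib.util


def is_backend_available(name: str) -> bool:
    """Check whether a package can be imported, without actually importing it."""
    return importlib.util.find_spec(name) is not None


def make_dummy(class_name: str, backend: str) -> type:
    """Build a hypothetical placeholder class that imports cleanly but fails on use."""

    def _fail(*args, **kwargs):
        raise ImportError(f"{class_name} requires {backend}, which is not installed.")

    # Both instantiation and the usual from_pretrained entry point raise.
    return type(class_name, (), {"__init__": _fail, "from_pretrained": classmethod(_fail)})


# Sketch only: when flax is installed, the package would import the real class here.
if not is_backend_available("flax"):
    FlaxExampleModel = make_dummy("FlaxExampleModel", "flax")
```

With this pattern `import` never touches the missing backend; the error only surfaces when someone tries to build or load a model.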

This happens for two reasons:

  • modeling_flax_utils.py imports load_state_dict so that weights can be converted from a PyTorch checkpoint.
  • Flax scheduler outputs are subclasses of SchedulerOutput, which declares the sample as a PyTorch tensor. I think we should create a FlaxSchedulerOutput instead.

Reproduction

Doesn't work

pip uninstall torch
pip install flax

>>> import diffusers

Works

pip uninstall torch
pip uninstall flax

>>> import diffusers
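To check which of the two cases an environment falls into, a small script can report the installed backends and whether `import diffusers` succeeds. A sketch using only the standard library; `backend_status` and `try_import` are illustrative names, and the script never imports torch or flax directly.

```python
import importlib
import importlib.util


def backend_status() -> dict:
    """Report which optional backends are installed, without importing them."""
    return {name: importlib.util.find_spec(name) is not None for name in ("torch", "flax")}


def try_import(module: str) -> str:
    """Attempt an import and return 'ok' or a short description of the failure."""
    try:
        importlib.import_module(module)
    except Exception as exc:  # import-time failures can surface as any exception type
        return f"failed: {type(exc).__name__}: {exc}"
    return "ok"


if __name__ == "__main__":
    status = backend_status()
    print(f"torch installed: {status['torch']}, flax installed: {status['flax']}")
    print(f"import diffusers: {try_import('diffusers')}")
```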

Logs

No response

System Info

Diffusers @ 877bec8

Labels

bug
