Closed
Describe the bug
I discovered this while testing #659. It is currently possible to import diffusers without either PyTorch or Flax installed; in that case the dummy classes are loaded. But when Flax is installed and PyTorch isn't, importing fails.
This happens for two reasons:
- modeling_flax_utils.py imports load_state_dict, which requires PyTorch. The import is there so that weights can be converted from a PyTorch checkpoint.
- Flax scheduler outputs are subclasses of SchedulerOutput, which declares the sample as a PyTorch tensor. I think we should create a FlaxSchedulerOutput instead.
Reproduction
Doesn't work (Flax installed, PyTorch not installed):
pip uninstall torch
pip install flax
>>> import diffusers
Works (neither PyTorch nor Flax installed):
pip uninstall torch
pip uninstall flax
>>> import diffusers
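To check this without uninstalling anything, the missing package can be simulated. A `None` entry in `sys.modules` makes `import torch` raise `ImportError`. The helper below is a sketch, and its name is hypothetical. It only blocks imports, so code that probes for PyTorch some other way (for example through installed-package metadata) may still report it as available, and modules already cached outside the target package are reused. Treat a result as a hint rather than proof.

```python
import importlib
import sys
from typing import Iterable


def can_import_without(target: str, blocked: Iterable[str]) -> bool:
    """Return True if `target` imports while every module in `blocked` is unavailable."""
    before = dict(sys.modules)  # snapshot so interpreter state can be restored afterwards
    try:
        # Drop cached copies of the target package so it is really re-imported.
        for name in list(sys.modules):
            if name == target or name.startswith(target + "."):
                del sys.modules[name]
        # A None entry in sys.modules makes `import name` raise ImportError.
        for name in blocked:
            sys.modules[name] = None
        importlib.import_module(target)
        return True
    except ImportError:
        return False
    finally:
        # Remove everything added during the attempt, then put the original entries back.
        for name in list(sys.modules):
            if name not in before:
                del sys.modules[name]
        sys.modules.update(before)
```

For this issue, the interesting call is `can_import_without("diffusers", ["torch"])` in an environment where Flax is installed. Per the report above, that import fails at 877bec8.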
Logs
No response
System Info
Diffusers @ 877bec8