-
Notifications
You must be signed in to change notification settings - Fork 6.1k
Fix import with Flax but without PyTorch #688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here SchedulerOutput
is conditionally defined to use either PyTorch tensors or jax arrays.
This is a little different to what we usually do: create an alternative class FlaxSchedulerOutput
, load one or the other, and the corresponding dummy versions. But this approach felt easier and I think it would touch less files. Happy to change it to the other method if this is not a good idea.
The documentation is not available anymore as the PR was closed or merged. |
src/diffusers/modeling_flax_utils.py
Outdated
if is_torch_available(): | ||
from .modeling_utils import load_state_dict | ||
else: | ||
raise EnvironmentError( | ||
f"Can't load the model in PyTorch format because PyTorch is not installed. " | ||
f"Please, install PyTorch or use native Flax weights." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot! Looks good to me, my only comment is that we should Flax
prefix to all those output classes for consistency and also to make it different from PT.
This reverts commit f653140.
@patil-suraj I applied your suggestions but then found additional issues after installing on a fresh environment.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks a lot!
* Don't use `load_state_dict` if torch is not installed. * Define `SchedulerOutput` to use torch or flax arrays. * Don't import LMSDiscreteScheduler without torch. * Create distinct FlaxSchedulerOutput. * Additional changes required for FlaxSchedulerMixin * Do not import torch pipelines in Flax. * Revert "Define `SchedulerOutput` to use torch or flax arrays." This reverts commit f653140. * Prefix Flax scheduler outputs for consistency. * make style * FlaxSchedulerOutput is now a dataclass. * Don't use f-string without placeholders. * Add blank line. * Style (docstrings)
* Don't use `load_state_dict` if torch is not installed. * Define `SchedulerOutput` to use torch or flax arrays. * Don't import LMSDiscreteScheduler without torch. * Create distinct FlaxSchedulerOutput. * Additional changes required for FlaxSchedulerMixin * Do not import torch pipelines in Flax. * Revert "Define `SchedulerOutput` to use torch or flax arrays." This reverts commit f653140. * Prefix Flax scheduler outputs for consistency. * make style * FlaxSchedulerOutput is now a dataclass. * Don't use f-string without placeholders. * Add blank line. * Style (docstrings)
Fixes #685.