Skip to content

Fix import with Flax but without PyTorch #688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Oct 3, 2022
Merged

Fix import with Flax but without PyTorch #688

merged 13 commits into from
Oct 3, 2022

Conversation

pcuenca
Copy link
Member

@pcuenca pcuenca commented Sep 30, 2022

Fixes #685.

Copy link
Member Author

@pcuenca pcuenca left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here SchedulerOutput is conditionally defined to use either PyTorch tensors or jax arrays.

This is a little different to what we usually do: create an alternative class FlaxSchedulerOutput, load one or the other, and the corresponding dummy versions. But this approach felt easier and I think it would touch less files. Happy to change it to the other method if this is not a good idea.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Sep 30, 2022

The documentation is not available anymore as the PR was closed or merged.

Comment on lines 394 to 400
if is_torch_available():
from .modeling_utils import load_state_dict
else:
raise EnvironmentError(
f"Can't load the model in PyTorch format because PyTorch is not installed. "
f"Please, install PyTorch or use native Flax weights."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice!

Copy link
Contributor

@patil-suraj patil-suraj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot! Looks good to me, my only comment is that we should Flax prefix to all those output classes for consistency and also to make it different from PT.

@pcuenca
Copy link
Member Author

pcuenca commented Sep 30, 2022

@patil-suraj I applied your suggestions but then found additional issues after installing on a fresh environment.

  • If transformers is not installed import failed.
  • The torch pipelines were imported when attempting to use Flax Stable Diffusion, so it didn't work.

@pcuenca pcuenca requested a review from patil-suraj September 30, 2022 16:24
Copy link
Contributor

@patil-suraj patil-suraj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks a lot!

@pcuenca pcuenca merged commit 688031c into main Oct 3, 2022
@pcuenca pcuenca deleted the fix-flax-no-torch branch October 3, 2022 14:23
prathikr pushed a commit to prathikr/diffusers that referenced this pull request Oct 26, 2022
* Don't use `load_state_dict` if torch is not installed.

* Define `SchedulerOutput` to use torch or flax arrays.

* Don't import LMSDiscreteScheduler without torch.

* Create distinct FlaxSchedulerOutput.

* Additional changes required for FlaxSchedulerMixin

* Do not import torch pipelines in Flax.

* Revert "Define `SchedulerOutput` to use torch or flax arrays."

This reverts commit f653140.

* Prefix Flax scheduler outputs for consistency.

* make style

* FlaxSchedulerOutput is now a dataclass.

* Don't use f-string without placeholders.

* Add blank line.

* Style (docstrings)
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Don't use `load_state_dict` if torch is not installed.

* Define `SchedulerOutput` to use torch or flax arrays.

* Don't import LMSDiscreteScheduler without torch.

* Create distinct FlaxSchedulerOutput.

* Additional changes required for FlaxSchedulerMixin

* Do not import torch pipelines in Flax.

* Revert "Define `SchedulerOutput` to use torch or flax arrays."

This reverts commit f653140.

* Prefix Flax scheduler outputs for consistency.

* make style

* FlaxSchedulerOutput is now a dataclass.

* Don't use f-string without placeholders.

* Add blank line.

* Style (docstrings)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Can't import diffusers with Flax but without PyTorch
4 participants