Fix import with Flax but without PyTorch #688

Merged · 13 commits · Oct 3, 2022
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -73,6 +73,7 @@
FlaxKarrasVeScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
+ FlaxSchedulerMixin,
FlaxScoreSdeVeScheduler,
)
else:
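The new export makes the Flax scheduler base class reachable from the top-level namespace without PyTorch installed. A hypothetical smoke test, assuming an environment with flax but no torch:

```python
# Hypothetical smoke test for the new export: with `flax` installed and
# `torch` absent, the top-level import should now succeed.
from diffusers import FlaxPNDMScheduler, FlaxSchedulerMixin

# The Flax schedulers all inherit the shared mixin (see the scheduler
# diffs below).
assert issubclass(FlaxPNDMScheduler, FlaxSchedulerMixin)
```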
10 changes: 9 additions & 1 deletion src/diffusers/modeling_flax_utils.py
@@ -27,8 +27,8 @@
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError

+ from . import is_torch_available
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
- from .modeling_utils import load_state_dict
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
@@ -391,6 +391,14 @@ def from_pretrained(
)

if from_pt:
+     if is_torch_available():
+         from .modeling_utils import load_state_dict
+     else:
+         raise EnvironmentError(
+             "Can't load the model in PyTorch format because PyTorch is not installed. "
+             "Please, install PyTorch or use native Flax weights."
+         )
+
# Step 1: Get the pytorch file
pytorch_model_file = load_state_dict(model_file)

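This is the standard deferred-import pattern: the torch-dependent `load_state_dict` is imported inside `from_pretrained` only when `from_pt=True`, so importing the module itself no longer pulls in PyTorch. A stripped-down sketch of the same pattern (the wrapper name `load_pt_state_dict` is invented for illustration):

```python
from diffusers.utils import is_torch_available


def load_pt_state_dict(model_file):
    # Defer the torch-dependent import to call time; importing this module
    # stays torch-free, and Flax-only users hit a clear error only if they
    # actually request PyTorch weights.
    if not is_torch_available():
        raise EnvironmentError(
            "Can't load the model in PyTorch format because PyTorch is not installed. "
            "Please, install PyTorch or use native Flax weights."
        )
    from diffusers.modeling_utils import load_state_dict

    return load_state_dict(model_file)
```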
6 changes: 3 additions & 3 deletions src/diffusers/pipeline_flax_utils.py
@@ -30,7 +30,7 @@

from .configuration_utils import ConfigMixin
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
- from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
+ from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging


@@ -46,7 +46,7 @@
LOADABLE_CLASSES = {
"diffusers": {
"FlaxModelMixin": ["save_pretrained", "from_pretrained"],
"SchedulerMixin": ["save_config", "from_config"],
"FlaxSchedulerMixin": ["save_config", "from_config"],
"FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
},
"transformers": {
@@ -436,7 +436,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
params[name] = loaded_params
- elif issubclass(class_obj, SchedulerMixin):
+ elif issubclass(class_obj, FlaxSchedulerMixin):
loaded_sub_model, scheduler_state = load_method(loadable_folder)
params[name] = scheduler_state
else:
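`LOADABLE_CLASSES` maps base classes to their save/load method names, and `from_pretrained` walks it with `issubclass` to pick a loader for each pipeline component; the bug was that Flax schedulers inherit `FlaxSchedulerMixin`, not the PyTorch `SchedulerMixin`, so the old check never matched. A simplified sketch of that resolution step (`resolve_load_method` is a hypothetical name; the real code also handles the `transformers` entry):

```python
import importlib


def resolve_load_method(class_obj, library="diffusers"):
    # Mirror of the dispatch in FlaxDiffusionPipeline.from_pretrained:
    # find the first registered base class the component inherits from
    # and return its load method.
    loadable = {
        "FlaxModelMixin": ["save_pretrained", "from_pretrained"],
        "FlaxSchedulerMixin": ["save_config", "from_config"],
        "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
    }
    module = importlib.import_module(library)
    for base_name, (_save, load) in loadable.items():
        base = getattr(module, base_name, None)
        if base is not None and issubclass(class_obj, base):
            return getattr(class_obj, load)
    raise ValueError(f"{class_obj.__name__} matches no loadable base class")
```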
16 changes: 10 additions & 6 deletions src/diffusers/pipelines/__init__.py
@@ -1,12 +1,16 @@
from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
- from .ddim import DDIMPipeline
- from .ddpm import DDPMPipeline
- from .latent_diffusion_uncond import LDMPipeline
- from .pndm import PNDMPipeline
- from .score_sde_ve import ScoreSdeVePipeline
- from .stochastic_karras_ve import KarrasVePipeline


+ if is_torch_available():
+     from .ddim import DDIMPipeline
+     from .ddpm import DDPMPipeline
+     from .latent_diffusion_uncond import LDMPipeline
+     from .pndm import PNDMPipeline
+     from .score_sde_ve import ScoreSdeVePipeline
+     from .stochastic_karras_ve import KarrasVePipeline
+ else:
+     from ..utils.dummy_pt_objects import *  # noqa F403

if is_torch_available() and is_transformers_available():
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import (
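The `else` branch relies on diffusers' dummy-object convention: when a backend is missing, the package still exposes the class names, but any use raises an informative error instead of an `ImportError` at import time. Roughly like this (simplified; the real dummies are code-generated and use a `DummyObject` metaclass with `requires_backends`):

```python
# Simplified stand-in for utils/dummy_pt_objects.py: the same name stays
# importable, but instantiation fails with an actionable message.
class DDPMPipeline:
    def __init__(self, *args, **kwargs):
        raise ImportError(
            "DDPMPipeline requires the PyTorch library, which was not found "
            "in your environment. Install torch or use the Flax pipelines."
        )
```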
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -6,7 +6,7 @@
import PIL
from PIL import Image

- from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available
+ from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available


@dataclass
@@ -27,7 +27,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
nsfw_content_detected: List[bool]


- if is_transformers_available():
+ if is_transformers_available() and is_torch_available():
from .pipeline_stable_diffusion import StableDiffusionPipeline
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
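With this guard in place, the end-to-end goal of the PR should hold: in a flax + transformers environment without torch, the PyTorch pipelines resolve to dummies while the Flax pipeline imports normally. For example (assuming such an environment and a diffusers version of this era, which already shipped the Flax Stable Diffusion pipeline):

```python
# Should import cleanly without PyTorch after this change:
from diffusers import FlaxStableDiffusionPipeline
```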
4 changes: 3 additions & 1 deletion src/diffusers/schedulers/__init__.py
@@ -34,10 +34,12 @@
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
from .scheduling_pndm_flax import FlaxPNDMScheduler
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
+ from .scheduling_utils_flax import FlaxSchedulerMixin
else:
from ..utils.dummy_flax_objects import * # noqa F403

- if is_scipy_available():
+
+ if is_scipy_available() and is_torch_available():
from .scheduling_lms_discrete import LMSDiscreteScheduler
else:
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
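Both guards bottom out in diffusers' availability helpers, which are thin importability checks along these lines (a simplified sketch of `diffusers.utils.import_utils`, which in reality also verifies installed versions via `importlib_metadata`):

```python
import importlib.util


def is_torch_available() -> bool:
    # True when `torch` can be imported in this environment.
    return importlib.util.find_spec("torch") is not None


def is_scipy_available() -> bool:
    return importlib.util.find_spec("scipy") is not None
```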
16 changes: 8 additions & 8 deletions src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -23,7 +23,7 @@
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
- from .scheduling_utils import SchedulerMixin, SchedulerOutput
+ from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -68,11 +68,11 @@ def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):


@dataclass
- class FlaxSchedulerOutput(SchedulerOutput):
+ class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
state: DDIMSchedulerState


- class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
+ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
@@ -183,7 +183,7 @@ def step(
timestep: int,
sample: jnp.ndarray,
return_dict: bool = True,
- ) -> Union[FlaxSchedulerOutput, Tuple]:
+ ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -197,11 +197,11 @@ def step(
key (`random.KeyArray`): a PRNG key.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

Returns:
- [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
- When returning a tuple, the first element is the sample tensor.
+ [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.

"""
if state.num_inference_steps is None:
@@ -252,7 +252,7 @@ def step(
if not return_dict:
return (prev_sample, state)

- return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+ return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)

def add_noise(
self,
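Previously every Flax scheduler module defined its own class literally named `FlaxSchedulerOutput`, so importing two schedulers' outputs together collided; the rename gives each scheduler a distinct subclass of the shared base. The pattern for a hypothetical new scheduler would look like this (sketch; `MySchedulerState` and `FlaxMySchedulerOutput` are placeholder names):

```python
from dataclasses import dataclass

import flax
import jax.numpy as jnp

from diffusers.schedulers.scheduling_utils_flax import FlaxSchedulerOutput


@flax.struct.dataclass
class MySchedulerState:
    # Immutable per-run state, threaded through `step` functionally.
    timesteps: jnp.ndarray


@dataclass
class FlaxMySchedulerOutput(FlaxSchedulerOutput):
    # `prev_sample` comes from the base class; each scheduler adds its state.
    state: MySchedulerState
```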
16 changes: 8 additions & 8 deletions src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -23,7 +23,7 @@
from jax import random

from ..configuration_utils import ConfigMixin, register_to_config
- from .scheduling_utils import SchedulerMixin, SchedulerOutput
+ from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -67,11 +67,11 @@ def create(cls, num_train_timesteps: int):


@dataclass
- class FlaxSchedulerOutput(SchedulerOutput):
+ class FlaxDDPMSchedulerOutput(FlaxSchedulerOutput):
state: DDPMSchedulerState


- class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
+ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
Langevin dynamics sampling.
@@ -191,7 +191,7 @@ def step(
key: random.KeyArray,
predict_epsilon: bool = True,
return_dict: bool = True,
- ) -> Union[FlaxSchedulerOutput, Tuple]:
+ ) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -205,11 +205,11 @@ def step(
key (`random.KeyArray`): a PRNG key.
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class

Returns:
- [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
- When returning a tuple, the first element is the sample tensor.
+ [`FlaxDDPMSchedulerOutput`] or `tuple`: [`FlaxDDPMSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.

"""
t = timestep
@@ -257,7 +257,7 @@ def step(
if not return_dict:
return (pred_prev_sample, state)

- return FlaxSchedulerOutput(prev_sample=pred_prev_sample, state=state)
+ return FlaxDDPMSchedulerOutput(prev_sample=pred_prev_sample, state=state)

def add_noise(
self,
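Callers consume the renamed output exactly as before; only the class name changed. A hand-constructed sketch of the two return shapes, assuming the module paths and `create` signature shown in this diff (array shape is arbitrary):

```python
import jax.numpy as jnp

from diffusers.schedulers.scheduling_ddpm_flax import (
    DDPMSchedulerState,
    FlaxDDPMSchedulerOutput,
)

state = DDPMSchedulerState.create(num_train_timesteps=1000)
out = FlaxDDPMSchedulerOutput(prev_sample=jnp.zeros((1, 3, 8, 8)), state=state)

# return_dict=True yields the dataclass; return_dict=False would yield the
# plain tuple (prev_sample, state) with the same contents.
prev_sample, new_state = out.prev_sample, out.state
```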
8 changes: 4 additions & 4 deletions src/diffusers/schedulers/scheduling_karras_ve_flax.py
@@ -22,7 +22,7 @@

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
- from .scheduling_utils import SchedulerMixin
+ from .scheduling_utils_flax import FlaxSchedulerMixin


@flax.struct.dataclass
@@ -56,7 +56,7 @@ class FlaxKarrasVeOutput(BaseOutput):
state: KarrasVeSchedulerState


- class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
+ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
the VE column of Table 1 from [1] for reference.
Expand Down Expand Up @@ -170,7 +170,7 @@ def step(
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

Returns:
[`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
@@ -209,7 +209,7 @@ def step_correct(
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

Returns:
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
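For reference, `step` and `step_correct` implement the Euler predictor plus second-order corrector of Algorithm 2 in Karras et al.; stripped of state bookkeeping, the math is roughly the following sketch (simplified from the scheduler, not the verbatim code):

```python
def karras_ve_step(model_output, sigma_hat, sigma_prev, sample_hat):
    # Euler predictor: estimate the denoised sample, step along the derivative.
    pred_original = sample_hat + sigma_hat * model_output
    derivative = (sample_hat - pred_original) / sigma_hat
    sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
    return sample_prev, derivative


def karras_ve_correct(model_output, sigma_hat, sigma_prev, sample_hat,
                      sample_prev, derivative):
    # Second-order correction, applied when sigma_prev != 0.
    pred_original = sample_prev + sigma_prev * model_output
    derivative_corr = (sample_prev - pred_original) / sigma_prev
    return sample_hat + (sigma_prev - sigma_hat) * 0.5 * (derivative + derivative_corr)
```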
16 changes: 8 additions & 8 deletions src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -20,7 +20,7 @@
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
- from .scheduling_utils import SchedulerMixin, SchedulerOutput
+ from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput


@flax.struct.dataclass
@@ -37,11 +37,11 @@ def create(cls, num_train_timesteps: int, sigmas: jnp.ndarray):


@dataclass
- class FlaxSchedulerOutput(SchedulerOutput):
+ class FlaxLMSSchedulerOutput(FlaxSchedulerOutput):
state: LMSDiscreteSchedulerState


- class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
Katherine Crowson:
@@ -145,7 +145,7 @@ def step(
sample: jnp.ndarray,
order: int = 4,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> Union[FlaxLMSSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -157,11 +157,11 @@ def step(
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
- return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+ return_dict (`bool`): option for returning tuple rather than FlaxLMSSchedulerOutput class

Returns:
- [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
- When returning a tuple, the first element is the sample tensor.
+ [`FlaxLMSSchedulerOutput`] or `tuple`: [`FlaxLMSSchedulerOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.

"""
sigma = state.sigmas[timestep]
@@ -187,7 +187,7 @@ def step(
if not return_dict:
return (prev_sample, state)

- return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+ return FlaxLMSSchedulerOutput(prev_sample=prev_sample, state=state)

def add_noise(
self,
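The scipy dependency that motivated the `is_scipy_available() and is_torch_available()` guard above comes from the linear-multistep coefficients, which integrate Lagrange basis polynomials over each sigma interval. A sketch close to the scheduler's helper (adapted for illustration; not the verbatim diffusers code):

```python
from scipy import integrate


def lms_coefficient(order, t, current_order, sigmas):
    # Integrate the Lagrange basis polynomial for `current_order` over
    # [sigma_t, sigma_{t+1}]; the result weights the stored derivatives
    # in the multistep update.
    def lms_derivative(tau):
        prod = 1.0
        for k in range(order):
            if current_order == k:
                continue
            prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
        return prod

    return integrate.quad(lms_derivative, sigmas[t], sigmas[t + 1], epsrel=1e-4)[0]
```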