Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 132 additions & 19 deletions src/diffusers/schedulers/scheduling_amused.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union

import torch

Expand All @@ -9,13 +9,48 @@
from .scheduling_utils import SchedulerMixin


def gumbel_noise(t: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """
    Generate Gumbel noise for sampling.

    Args:
        t (`torch.Tensor`):
            Input tensor to match the shape and dtype of the output noise.
        generator (`torch.Generator`, *optional*):
            A random number generator for reproducible sampling.

    Returns:
        `torch.Tensor`:
            Gumbel-distributed noise with the same shape, dtype, and device as the input tensor.
    """
    # Draw the uniform samples on the generator's device (a generator can only fill tensors that live on its own
    # device), then move the result back to the device of `t`.
    device = generator.device if generator is not None else t.device
    noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
    # Both clamps keep the argument of each `log` strictly positive so the result is always finite.
    return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))


def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
def mask_by_random_topk(
mask_len: torch.Tensor,
probs: torch.Tensor,
temperature: float = 1.0,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
"""
Mask tokens by selecting the top-k lowest confidence scores with temperature-based randomness.

Args:
mask_len (`torch.Tensor`):
Number of tokens to mask per sample in the batch.
probs (`torch.Tensor`):
Probability scores for each token.
temperature (`float`, *optional*, defaults to 1.0):
Temperature parameter for controlling randomness in the masking process.
generator (`torch.Generator`, *optional*):
A random number generator for reproducible sampling.

Returns:
`torch.Tensor`:
Boolean mask indicating which tokens should be masked.
"""
confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
sorted_confidence = torch.sort(confidence, dim=-1).values
cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
Expand All @@ -29,38 +64,70 @@ class AmusedSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output.

Args:
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
prev_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`):
Computed sample `(x_{t-1})` of previous timestep with token IDs. `prev_sample` should be used as next model
input in the denoising loop.
pred_original_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`, *optional*):
The predicted fully denoised sample `(x_{0})` with token IDs based on the model output from the current
timestep. `pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.Tensor
pred_original_sample: torch.Tensor = None
pred_original_sample: Optional[torch.Tensor] = None


class AmusedScheduler(SchedulerMixin, ConfigMixin):
"""
A scheduler for masked token generation as used in [`AmusedPipeline`].

This scheduler iteratively unmasks tokens based on their confidence scores, following either a cosine or linear
schedule. Unlike traditional diffusion schedulers that work with continuous pixel values, this scheduler operates
on discrete token IDs, making it suitable for autoregressive and non-autoregressive masked token generation models.

This scheduler inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
generic methods the library implements for all schedulers such as loading and saving.

Args:
mask_token_id (`int`):
The token ID used to represent masked tokens in the sequence.
masking_schedule (`Literal["cosine", "linear"]`, *optional*, defaults to `"cosine"`):
The schedule type for determining the mask ratio at each timestep. Can be either `"cosine"` or `"linear"`.
"""

order = 1

temperatures: torch.Tensor
temperatures: Optional[torch.Tensor]
timesteps: Optional[torch.Tensor]

@register_to_config
def __init__(
    self,
    mask_token_id: int,
    masking_schedule: Literal["cosine", "linear"] = "cosine",
):
    # The constructor arguments are not stored here explicitly; presumably the `register_to_config` decorator
    # records them on the scheduler config (confirm against `ConfigMixin`).
    # Both tensors start unset; `timesteps` is filled in by `set_timesteps`, and `temperatures` is presumably
    # filled in there as well.
    self.temperatures = None
    self.timesteps = None

def set_timesteps(
self,
num_inference_steps: int,
temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
device: Union[str, torch.device] = None,
):
temperature: Union[float, Tuple[float, float], List[float]] = (2, 0),
device: Optional[Union[str, torch.device]] = None,
) -> None:
"""
Set the discrete timesteps used for the diffusion chain (to be run before inference).

Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
temperature (`Union[float, Tuple[float, float], List[float]]`, *optional*, defaults to `(2, 0)`):
Temperature parameter(s) for controlling the randomness of sampling. If a tuple or list is provided,
temperatures will be linearly interpolated between the first and second values across all timesteps. If
a single value is provided, temperatures will be linearly interpolated from that value to 0.01.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps and temperatures should be moved to. If `None`, the timesteps are not
moved.
"""
self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)

if isinstance(temperature, (tuple, list)):
Expand All @@ -71,12 +138,38 @@ def set_timesteps(
def step(
self,
model_output: torch.Tensor,
timestep: torch.long,
timestep: int,
sample: torch.LongTensor,
starting_mask_ratio: int = 1,
starting_mask_ratio: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[AmusedSchedulerOutput, Tuple]:
) -> Union[AmusedSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]:
"""
Predict the sample at the previous timestep by masking tokens based on confidence scores.

Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model. Typically of shape `(batch_size, num_tokens,
codebook_size)` or `(batch_size, codebook_size, height, width)` for 2D inputs.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.LongTensor`):
A current instance of a sample created by the diffusion process. Contains token IDs, with masked
positions indicated by `mask_token_id`.
starting_mask_ratio (`float`, *optional*, defaults to 1.0):
A multiplier applied to the mask ratio schedule. Values less than 1.0 will result in fewer tokens being
masked at each step.
generator (`torch.Generator`, *optional*):
A random number generator for reproducible sampling.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return an [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or a plain tuple.

Returns:
[`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or `tuple`:
If `return_dict` is `True`, [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] is returned,
otherwise a tuple is returned where the first element is the sample tensor (`prev_sample`) and the
second element is the predicted original sample tensor (`pred_original_sample`).
"""
two_dim_input = sample.ndim == 3 and model_output.ndim == 4

if two_dim_input:
Expand Down Expand Up @@ -137,7 +230,27 @@ def step(

return AmusedSchedulerOutput(prev_sample, pred_original_sample)

def add_noise(self, sample, timesteps, generator=None):
def add_noise(
self,
sample: torch.LongTensor,
timesteps: int,
generator: Optional[torch.Generator] = None,
) -> torch.LongTensor:
"""
Add noise to a sample by randomly masking tokens according to the masking schedule.

Args:
sample (`torch.LongTensor`):
The input sample containing token IDs to be partially masked.
timesteps (`int`):
The timestep that determines how much masking to apply. Higher timesteps result in more masking.
generator (`torch.Generator`, *optional*):
A random number generator for reproducible masking.

Returns:
`torch.LongTensor`:
The sample with some tokens replaced by `mask_token_id` according to the masking schedule.
"""
step_idx = (self.timesteps == timesteps).nonzero()
ratio = (step_idx + 1) / len(self.timesteps)

Expand Down