189 changes: 156 additions & 33 deletions diff_latent_attack.py
@@ -1,3 +1,4 @@
from typing import Optional
import numpy as np
import torch
from PIL import Image
@@ -86,35 +87,96 @@ def ddim_reverse_sample(image, prompt, model, num_inference_steps: int = 20, gui

def register_attention_control(model, controller):
def ca_forward(self, place_in_unet):

- def forward(x, context=None):
-     q = self.to_q(x)
-     is_cross = context is not None
-     context = context if is_cross else x
-     k = self.to_k(context)
-     v = self.to_v(context)
-     q = self.reshape_heads_to_batch_dim(q)
-     k = self.reshape_heads_to_batch_dim(k)
-     v = self.reshape_heads_to_batch_dim(v)
-
-     sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
def forward(
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
):
if self.spatial_norm is not None:
hidden_states = self.spatial_norm(hidden_states, temb)

batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = self.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
# reshape the mask to (batch, heads, source_length, target_length); note that it is
# prepared here but never applied to `sim` below, matching the original implementation
attention_mask = attention_mask.view(
batch_size, self.heads, -1, attention_mask.shape[-1]
) # type: ignore

if self.group_norm is not None:
hidden_states = self.group_norm(
hidden_states.transpose(1, 2)
).transpose(1, 2)

query = self.to_q(hidden_states)
is_cross = encoder_hidden_states is not None
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif self.norm_cross:
encoder_hidden_states = self.norm_encoder_hidden_states(
encoder_hidden_states
)

key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)

def reshape_heads_to_batch_dim(tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
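# shape example (illustrative numbers only): (batch=2, seq=77, dim=640) with 8 heads -> (2*8, 77, 80)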
tensor = tensor.reshape(
batch_size, seq_len, head_size, dim // head_size
)
tensor = tensor.permute(0, 2, 1, 3).reshape(
batch_size * head_size, seq_len, dim // head_size
)
return tensor

query = reshape_heads_to_batch_dim(query)
key = reshape_heads_to_batch_dim(key)
value = reshape_heads_to_batch_dim(value)

sim = torch.einsum("b i d, b j d -> b i j", query, key) * scale

attn = sim.softmax(dim=-1)
attn = controller(attn, is_cross, place_in_unet)
out = torch.einsum("b i j, b j d -> b i d", attn, v)
out = self.reshape_batch_dim_to_heads(out)

out = torch.einsum("b i j, b j d -> b i d", attn, value)

def reshape_batch_dim_to_heads(tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
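# shape example (illustrative numbers only): (2*8, 4096, 80) -> (2, 4096, 640), folding heads back into the channel dim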
tensor = tensor.reshape(
batch_size // head_size, head_size, seq_len, dim
)
tensor = tensor.permute(0, 2, 1, 3).reshape(
batch_size // head_size, seq_len, dim * head_size
)
return tensor

out = reshape_batch_dim_to_heads(out)
out = self.to_out[0](out)
out = self.to_out[1](out)

out = out / self.rescale_output_factor

return out

return forward

def register_recr(net_, count, place_in_unet):
- if net_.__class__.__name__ == 'CrossAttention':
+ if net_.__class__.__name__ == "Attention":
net_.forward = ca_forward(net_, place_in_unet)
return count + 1
- elif hasattr(net_, 'children'):
+ elif hasattr(net_, "children"):
for net__ in net_.children():
count = register_recr(net__, count, place_in_unet)
return count
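For reference, the patched forward hands every attention map to `controller(attn, is_cross, place_in_unet)`. The controller itself is not part of this diff; below is a minimal sketch of a compatible one, assuming only the call signature used above and prompt-to-prompt-style section names ("down"/"mid"/"up") for place_in_unet. The class name and storage layout are hypothetical.

class AttentionStore:
    """Hypothetical controller: records attention maps per UNet section."""

    def __init__(self):
        self.step_store = {
            "down_cross": [], "mid_cross": [], "up_cross": [],
            "down_self": [], "mid_self": [], "up_self": [],
        }

    def __call__(self, attn, is_cross, place_in_unet):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        self.step_store[key].append(attn.detach())
        return attn  # must return the attention map so the patched forward can keep using it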
@@ -133,32 +195,93 @@ def register_recr(net_, count, place_in_unet)

def reset_attention_control(model):
def ca_forward(self):
- def forward(x, context=None):
-     q = self.to_q(x)
-     is_cross = context is not None
-     context = context if is_cross else x
-     k = self.to_k(context)
-     v = self.to_v(context)
-     q = self.reshape_heads_to_batch_dim(q)
-     k = self.reshape_heads_to_batch_dim(k)
-     v = self.reshape_heads_to_batch_dim(v)
-
-     sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
def forward(
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
):
if self.spatial_norm is not None:
hidden_states = self.spatial_norm(hidden_states, temb)

batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = self.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
# reshape the mask to (batch, heads, source_length, target_length); note that it is
# prepared here but never applied to `sim` below, matching the original implementation
attention_mask = attention_mask.view(
batch_size, self.heads, -1, attention_mask.shape[-1]
) # type: ignore

if self.group_norm is not None:
hidden_states = self.group_norm(
hidden_states.transpose(1, 2)
).transpose(1, 2)

query = self.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif self.norm_cross:
encoder_hidden_states = self.norm_encoder_hidden_states(
encoder_hidden_states
)

key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)

def reshape_heads_to_batch_dim(tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(
batch_size, seq_len, head_size, dim // head_size
)
tensor = tensor.permute(0, 2, 1, 3).reshape(
batch_size * head_size, seq_len, dim // head_size
)
return tensor

query = reshape_heads_to_batch_dim(query)
key = reshape_heads_to_batch_dim(key)
value = reshape_heads_to_batch_dim(value)

sim = torch.einsum("b i d, b j d -> b i j", query, key) * scale

attn = sim.softmax(dim=-1)
out = torch.einsum("b i j, b j d -> b i d", attn, v)
out = self.reshape_batch_dim_to_heads(out)

out = torch.einsum("b i j, b j d -> b i d", attn, value)

def reshape_batch_dim_to_heads(tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(
batch_size // head_size, head_size, seq_len, dim
)
tensor = tensor.permute(0, 2, 1, 3).reshape(
batch_size // head_size, seq_len, dim * head_size
)
return tensor

out = reshape_batch_dim_to_heads(out)
out = self.to_out[0](out)
out = self.to_out[1](out)

out = out / self.rescale_output_factor

return out

return forward

def register_recr(net_):
- if net_.__class__.__name__ == 'CrossAttention':
+ if net_.__class__.__name__ == "Attention":
net_.forward = ca_forward(net_)
- elif hasattr(net_, 'children'):
+ elif hasattr(net_, "children"):
for net__ in net_.children():
register_recr(net__)

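A minimal usage sketch for the two helpers, assuming a stock diffusers StableDiffusionPipeline and the hypothetical AttentionStore sketched earlier. Whether the functions expect the whole pipeline or just its UNet depends on the registration loop collapsed out of this diff; the checkpoint name below is also an assumption.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

controller = AttentionStore()                 # hypothetical controller sketched above
register_attention_control(pipe, controller)  # patch every Attention module's forward
# ... run DDIM inversion / the latent attack; attention maps accumulate in controller.step_store ...
reset_attention_control(pipe)                 # restore a plain forward on every Attention module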