Skip to content

Commit 935d477

Browse files
authored
Merge pull request #327 from jakaline-dev/main
P+: Extended Textual Conditioning in Text-to-Image Generation
2 parents b996f5a + 24e3d4b commit 935d477

File tree

4 files changed

+950
-22
lines changed

4 files changed

+950
-22
lines changed

XTI_hijack.py

+209
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
import torch
2+
from typing import Union, List, Optional, Dict, Any, Tuple
3+
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
4+
5+
def unet_forward_XTI(self,
6+
sample: torch.FloatTensor,
7+
timestep: Union[torch.Tensor, float, int],
8+
encoder_hidden_states: torch.Tensor,
9+
class_labels: Optional[torch.Tensor] = None,
10+
return_dict: bool = True,
11+
) -> Union[UNet2DConditionOutput, Tuple]:
12+
r"""
13+
Args:
14+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
15+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
16+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
17+
return_dict (`bool`, *optional*, defaults to `True`):
18+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
19+
20+
Returns:
21+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
22+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
23+
returning a tuple, the first element is the sample tensor.
24+
"""
25+
# By default samples have to be AT least a multiple of the overall upsampling factor.
26+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
27+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
28+
# on the fly if necessary.
29+
default_overall_up_factor = 2**self.num_upsamplers
30+
31+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
32+
forward_upsample_size = False
33+
upsample_size = None
34+
35+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
36+
logger.info("Forward upsample size to force interpolation output size.")
37+
forward_upsample_size = True
38+
39+
# 0. center input if necessary
40+
if self.config.center_input_sample:
41+
sample = 2 * sample - 1.0
42+
43+
# 1. time
44+
timesteps = timestep
45+
if not torch.is_tensor(timesteps):
46+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
47+
# This would be a good case for the `match` statement (Python 3.10+)
48+
is_mps = sample.device.type == "mps"
49+
if isinstance(timestep, float):
50+
dtype = torch.float32 if is_mps else torch.float64
51+
else:
52+
dtype = torch.int32 if is_mps else torch.int64
53+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
54+
elif len(timesteps.shape) == 0:
55+
timesteps = timesteps[None].to(sample.device)
56+
57+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
58+
timesteps = timesteps.expand(sample.shape[0])
59+
60+
t_emb = self.time_proj(timesteps)
61+
62+
# timesteps does not contain any weights and will always return f32 tensors
63+
# but time_embedding might actually be running in fp16. so we need to cast here.
64+
# there might be better ways to encapsulate this.
65+
t_emb = t_emb.to(dtype=self.dtype)
66+
emb = self.time_embedding(t_emb)
67+
68+
if self.config.num_class_embeds is not None:
69+
if class_labels is None:
70+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
71+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
72+
emb = emb + class_emb
73+
74+
# 2. pre-process
75+
sample = self.conv_in(sample)
76+
77+
# 3. down
78+
down_block_res_samples = (sample,)
79+
down_i = 0
80+
for downsample_block in self.down_blocks:
81+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
82+
sample, res_samples = downsample_block(
83+
hidden_states=sample,
84+
temb=emb,
85+
encoder_hidden_states=encoder_hidden_states[down_i:down_i+2],
86+
)
87+
down_i += 2
88+
else:
89+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
90+
91+
down_block_res_samples += res_samples
92+
93+
# 4. mid
94+
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
95+
96+
# 5. up
97+
up_i = 7
98+
for i, upsample_block in enumerate(self.up_blocks):
99+
is_final_block = i == len(self.up_blocks) - 1
100+
101+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
102+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
103+
104+
# if we have not reached the final block and need to forward the
105+
# upsample size, we do it here
106+
if not is_final_block and forward_upsample_size:
107+
upsample_size = down_block_res_samples[-1].shape[2:]
108+
109+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
110+
sample = upsample_block(
111+
hidden_states=sample,
112+
temb=emb,
113+
res_hidden_states_tuple=res_samples,
114+
encoder_hidden_states=encoder_hidden_states[up_i:up_i+3],
115+
upsample_size=upsample_size,
116+
)
117+
up_i += 3
118+
else:
119+
sample = upsample_block(
120+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
121+
)
122+
# 6. post-process
123+
sample = self.conv_norm_out(sample)
124+
sample = self.conv_act(sample)
125+
sample = self.conv_out(sample)
126+
127+
if not return_dict:
128+
return (sample,)
129+
130+
return UNet2DConditionOutput(sample=sample)
131+
132+
def downblock_forward_XTI(
133+
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
134+
):
135+
output_states = ()
136+
i = 0
137+
138+
for resnet, attn in zip(self.resnets, self.attentions):
139+
if self.training and self.gradient_checkpointing:
140+
141+
def create_custom_forward(module, return_dict=None):
142+
def custom_forward(*inputs):
143+
if return_dict is not None:
144+
return module(*inputs, return_dict=return_dict)
145+
else:
146+
return module(*inputs)
147+
148+
return custom_forward
149+
150+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
151+
hidden_states = torch.utils.checkpoint.checkpoint(
152+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
153+
)[0]
154+
else:
155+
hidden_states = resnet(hidden_states, temb)
156+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
157+
158+
output_states += (hidden_states,)
159+
i += 1
160+
161+
if self.downsamplers is not None:
162+
for downsampler in self.downsamplers:
163+
hidden_states = downsampler(hidden_states)
164+
165+
output_states += (hidden_states,)
166+
167+
return hidden_states, output_states
168+
169+
def upblock_forward_XTI(
170+
self,
171+
hidden_states,
172+
res_hidden_states_tuple,
173+
temb=None,
174+
encoder_hidden_states=None,
175+
upsample_size=None,
176+
):
177+
i = 0
178+
for resnet, attn in zip(self.resnets, self.attentions):
179+
# pop res hidden states
180+
res_hidden_states = res_hidden_states_tuple[-1]
181+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
182+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
183+
184+
if self.training and self.gradient_checkpointing:
185+
186+
def create_custom_forward(module, return_dict=None):
187+
def custom_forward(*inputs):
188+
if return_dict is not None:
189+
return module(*inputs, return_dict=return_dict)
190+
else:
191+
return module(*inputs)
192+
193+
return custom_forward
194+
195+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
196+
hidden_states = torch.utils.checkpoint.checkpoint(
197+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
198+
)[0]
199+
else:
200+
hidden_states = resnet(hidden_states, temb)
201+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
202+
203+
i += 1
204+
205+
if self.upsamplers is not None:
206+
for upsampler in self.upsamplers:
207+
hidden_states = upsampler(hidden_states, upsample_size)
208+
209+
return hidden_states

0 commit comments

Comments
 (0)