import torch
from typing import Union, List, Optional, Dict, Any, Tuple

from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.utils import logging

# Module logger, used by unet_forward_XTI below.
logger = logging.get_logger(__name__)


def unet_forward_XTI(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
    r"""
    Args:
        sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
        encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

    Returns:
        [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
        [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.
    """
    # By default, samples have to be at least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # 0. center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # `time_proj` contains no weights and always returns f32 tensors, but
    # `time_embedding` might be running in fp16, so we need to cast here.
    # There might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=self.dtype)
    emb = self.time_embedding(t_emb)

    if self.config.num_class_embeds is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")
        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
        emb = emb + class_emb

    # 2. pre-process
    sample = self.conv_in(sample)

    # 3. down
    down_block_res_samples = (sample,)
    down_i = 0
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2],
            )
            down_i += 2
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

        down_block_res_samples += res_samples

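    # XTI: each cross-attention down block consumed two per-layer embeddings
    # (indices 0-5 for the three such blocks in the SD1.x layout); the mid block
    # below takes the single embedding at index 6, and the up blocks start at 7.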
    # 4. mid
    sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])

    # 5. up
    up_i = 7
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3],
                upsample_size=upsample_size,
            )
            up_i += 3
        else:
            sample = upsample_block(
                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
            )

    # 6. post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)

    return UNet2DConditionOutput(sample=sample)
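

# A sketch (this helper is hypothetical, not part of the original code): the
# forward above indexes `encoder_hidden_states` per layer, so for the SD1.x UNet
# layout (6 down + 1 mid + 9 up cross-attention slots) it expects a stack of 16
# contexts. A single prompt embedding can be replicated to satisfy it:
def _example_build_xti_context(cond: torch.Tensor, num_layers: int = 16) -> torch.Tensor:
    # Replicate one (batch, seq_len, dim) embedding into the
    # (num_layers, batch, seq_len, dim) stack that unet_forward_XTI slices from.
    return cond.unsqueeze(0).repeat(num_layers, 1, 1, 1)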


def downblock_forward_XTI(
    self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
):
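    # XTI: `encoder_hidden_states` is a sequence holding one text embedding per
    # attention module in this block; `i` steps through it so each attention
    # layer sees its own embedding.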
    output_states = ()
    i = 0

    for resnet, attn in zip(self.resnets, self.attentions):
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
            )[0]
        else:
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample

        output_states += (hidden_states,)
        i += 1

    if self.downsamplers is not None:
        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states)

        output_states += (hidden_states,)

    return hidden_states, output_states


def upblock_forward_XTI(
    self,
    hidden_states,
    res_hidden_states_tuple,
    temb=None,
    encoder_hidden_states=None,
    upsample_size=None,
):
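    # XTI: each attention in this block again consumes its own entry of
    # `encoder_hidden_states`; the skip connections are popped from the end of
    # `res_hidden_states_tuple`, i.e. in reverse of the order the down pass
    # pushed them.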
    i = 0
    for resnet, attn in zip(self.resnets, self.attentions):
        # pop res hidden states
        res_hidden_states = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
            )[0]
        else:
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample

        i += 1

    if self.upsamplers is not None:
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)

    return hidden_states
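

# A minimal sketch of how these forwards could be bound onto a loaded diffusers
# UNet2DConditionModel so the per-layer contexts reach every cross-attention
# block. The helper name `apply_XTI_hijack` is an assumption for illustration,
# not part of the original code.
def apply_XTI_hijack(unet):
    # Rebind the top-level forward and each cross-attention block's forward.
    unet.forward = unet_forward_XTI.__get__(unet)
    for block in unet.down_blocks:
        if hasattr(block, "has_cross_attention") and block.has_cross_attention:
            block.forward = downblock_forward_XTI.__get__(block)
    for block in unet.up_blocks:
        if hasattr(block, "has_cross_attention") and block.has_cross_attention:
            block.forward = upblock_forward_XTI.__get__(block)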