Remove a redundant .softmax() in FlashDeformAttn.
HELLORPG committed May 17, 2024
1 parent b2b6028 commit 398607a
Showing 1 changed file with 0 additions and 2 deletions.
DCNv4_op/DCNv4/modules/flash_deform_attn.py: 0 additions, 2 deletions
@@ -111,7 +111,6 @@ def forward(
         value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
         sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
         attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
-        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
         # N, Len_q, n_heads, n_levels, n_points, 2
         if reference_points.shape[-1] == 2:
             offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
@@ -131,7 +130,6 @@

         # Cat sampling_offsets and attention_weights, generate sampling_loc_attn:
         sampling_locations = sampling_locations.flatten(-3).half()
-        attention_weights = attention_weights.flatten(-2)
         sampling_loc_attn = torch.cat([sampling_locations, attention_weights], dim=-1)

         output = FlashDeformAttnFunction.apply(
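Note (not part of the commit itself): the explicit F.softmax is described as redundant, presumably because FlashDeformAttnFunction normalizes the n_levels * n_points weights per query inside the CUDA kernel. With the softmax (and its .view back to (..., n_levels, n_points)) removed, the weights already stay in shape (N, Len_q, n_heads, n_levels * n_points), so the later .flatten(-2) is dropped as well. Below is a minimal, self-contained shape sketch of the resulting tensor plumbing, using hypothetical sizes; the kernel-side softmax is an assumption, not something this page confirms.

import torch

# Hypothetical sizes for illustration only.
N, Len_q, n_heads, n_levels, n_points = 2, 4, 8, 4, 8
K = n_levels * n_points  # 32 weights per query and head

# Stand-ins for the outputs of the sampling_offsets / attention_weights linear layers.
sampling_locations = torch.rand(N, Len_q, n_heads, n_levels, n_points, 2)  # (x, y) per sample point
attention_weights = torch.rand(N, Len_q, n_heads, K)                       # left unnormalized after this commit

# Before this commit the weights were softmax-ed and reshaped here, then flattened back:
#   attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, n_heads, n_levels, n_points)
#   ...
#   attention_weights = attention_weights.flatten(-2)
# After it they are concatenated with the flattened locations directly,
# assuming the softmax happens inside the FlashDeformAttn kernel.
sampling_loc_attn = torch.cat([sampling_locations.flatten(-3), attention_weights], dim=-1)
print(sampling_loc_attn.shape)  # torch.Size([2, 4, 8, 96]) == (N, Len_q, n_heads, 3 * K)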
