Skip to content

Commit

Permalink
Fix the issue of mismatch of params for FlashDeformAttnFunction.
Browse files Browse the repository at this point in the history
  • Loading branch information
HELLORPG committed May 17, 2024
1 parent 97dcd91 commit b2b6028
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions DCNv4_op/DCNv4/modules/flash_deform_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,17 @@ def forward(
raise ValueError(
"Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1])
)

# Cat sampling_offsets and attention_weights, generate sampling_loc_attn:
sampling_locations = sampling_locations.flatten(-3).half()
attention_weights = attention_weights.flatten(-2)
sampling_loc_attn = torch.cat([sampling_locations, attention_weights], dim=-1)

output = FlashDeformAttnFunction.apply(
value,
input_spatial_shapes,
input_level_start_index,
sampling_locations,
attention_weights,
sampling_loc_attn,
self.im2col_step,
self.n_points
)
Expand Down

0 comments on commit b2b6028

Please sign in to comment.