Description
The function `torch_scatter.composite.scatter_softmax` is supposed to support `torch.float16`. Unfortunately, inside a `torch.cuda.amp.autocast()` context, `scatter_softmax` returns `torch.float32` because of the following line in the source code:

```python
recentered_scores_exp = recentered_scores.exp()
```

This is because `torch.exp()` always returns `torch.float32` inside a `torch.cuda.amp.autocast()` context (see "Ops that can autocast to float32" at https://pytorch.org/docs/stable/amp.html).
What about changing `recentered_scores_exp = recentered_scores.exp()` to `recentered_scores_exp = recentered_scores.exp_()`? The in-place `torch.exp_()` returns a tensor of the same dtype as its input.
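To illustrate the difference, here is a minimal sketch (assuming a CUDA device is available; the out-of-place behavior follows the autocast documentation linked above, and the in-place behavior is the one this report relies on):

```python
import torch

x = torch.randn(4, device='cuda', dtype=torch.float16)

with torch.cuda.amp.autocast():
    # torch.exp is on the "can autocast to float32" list, so the
    # out-of-place call upcasts its result:
    print(x.exp().dtype)           # torch.float32
    # The in-place variant writes into the existing float16 storage,
    # so the dtype is preserved:
    print(x.clone().exp_().dtype)  # torch.float16
```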
The reproduction is as follows:
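A minimal sketch, assuming `torch_scatter` is installed and a CUDA device is available (the exact tensors are illustrative):

```python
import torch
from torch_scatter.composite import scatter_softmax

src = torch.randn(6, device='cuda', dtype=torch.float16)
index = torch.tensor([0, 0, 1, 1, 2, 2], device='cuda')

with torch.cuda.amp.autocast():
    out = scatter_softmax(src, index, dim=0)

# Expected torch.float16, but the exp() upcast yields torch.float32:
print(out.dtype)
```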
In addition, I think there is no need for `eps` in `scatter_softmax`. Because each group is recentered by its own per-index maximum, the maximum element of every group contributes exp(0) = 1, so the per-index sum of `recentered_scores_exp` is always greater than or equal to 1 and the denominator can never be zero. This is a subtlety compared to a common softmax implementation.
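As a sanity check, here is a sketch that mirrors the recentering steps inside `scatter_softmax` using `scatter_max` and `scatter_add` from `torch_scatter`, and verifies that the denominator is bounded below by 1 (the tensors are illustrative):

```python
import torch
from torch_scatter import scatter_max, scatter_add

src = torch.randn(8)
index = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])

# Recenter each group by its per-index maximum, as scatter_softmax does:
max_per_index, _ = scatter_max(src, index, dim=0)
recentered = src - max_per_index.gather(0, index)

# Each group's maximum contributes exp(0) = 1, so every per-index sum
# is at least 1 and the denominator cannot vanish:
denom = scatter_add(recentered.exp(), index, dim=0)
print(denom >= 1)  # tensor([True, True, True])
```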