
scatter_softmax for torch.float16 #238

@Veason-silverbullet

Description

The function torch_scatter.composite.scatter_softmax is supposed to support torch.float16. Unfortunately, in the torch.cuda.amp.autocast() context, scatter_softmax returns torch.float32 due to the source line below:

recentered_scores_exp = recentered_scores.exp()

This is because torch.exp() always returns torch.float32 in the torch.cuda.amp.autocast() context (see "Ops that can autocast to float32" in https://pytorch.org/docs/stable/amp.html).

What about changing recentered_scores_exp = recentered_scores.exp() to recentered_scores_exp = recentered_scores.exp_()? The in-place torch.exp_() returns a tensor of the same dtype as its input.
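A quick check of this dtype behavior (a minimal sketch, assuming a CUDA device; the tensor here is illustrative):

    import torch

    x = torch.randn(4, dtype=torch.float16, device='cuda')
    with torch.cuda.amp.autocast():
        # exp is on the "CUDA Ops that can autocast to float32" list,
        # so the out-of-place call widens the result to float32
        print(x.exp().dtype)           # torch.float32
        # an in-place op cannot change the dtype of its tensor
        print(x.clone().exp_().dtype)  # torch.float16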

The reproduction is as below:
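A minimal sketch, assuming torch_scatter is installed and a CUDA device is available (sizes and values are illustrative):

    import torch
    from torch_scatter import scatter_softmax

    src = torch.randn(6, dtype=torch.float16, device='cuda')
    index = torch.tensor([0, 0, 1, 1, 2, 2], device='cuda')

    with torch.cuda.amp.autocast():
        out = scatter_softmax(src, index)
        print(out.dtype)  # torch.float32, although torch.float16 would be expected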


In addition, I think there is no need for eps in scatter_softmax. Since the scores are recentered by the per-index maximum, the per-index sum of recentered_scores_exp is always greater than or equal to 1 (the maximum entry of each segment contributes exp(0) = 1), so the denominator can never vanish. This is a subtlety compared to a common softmax implementation.
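A short sketch of why the denominator is bounded below by 1, mirroring the recenter-then-sum steps of the composite implementation (the scatter_max/scatter_sum usage here is illustrative):

    import torch
    from torch_scatter import scatter_max, scatter_sum

    src = torch.randn(6)
    index = torch.tensor([0, 0, 1, 1, 2, 2])

    # recenter by the per-index maximum, as scatter_softmax does
    max_per_index, _ = scatter_max(src, index, dim=-1)
    recentered = src - max_per_index.gather(-1, index)  # <= 0 everywhere

    # each segment contains its own maximum, which contributes exp(0) = 1,
    # so every per-index sum is >= 1 and eps is not needed to avoid division by zero
    denom = scatter_sum(recentered.exp(), index, dim=-1)
    print(denom >= 1)  # tensor([True, True, True])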
