
torch.bmm kernel fusion #28

Open

Description

@Edenzzzz

@DanFu09 Thanks for open-sourcing the code!
I see that in your previous fly repo (https://github.com/HazyResearch/fly) you used cast_inputs=torch.float16 for BlockdiagButterflyMultiply, but changed it to bf16 here. I wonder if there is a specific reason (e.g., fp16 training not converging due to range issues)?
Also, are there opportunities to fuse the two bmm operations into a single kernel? It seems hard to pin down the exact kernel torch dispatches to, though.
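
For reference, the two bmm calls under discussion follow a block-diagonal pattern along these lines. This is a simplified hedged sketch with illustrative shapes and names; the actual BlockdiagButterflyMultiply also permutes/reshapes between the two factors, which is omitted here for brevity:

```python
import torch

def blockdiag_butterfly_sketch(x, w1, w2):
    # x:  (batch, k * p)   input, viewed as k contiguous blocks of size p
    # w1: (k, q, p)        first block-diagonal factor
    # w2: (k, r, q)        second block-diagonal factor
    batch = x.shape[0]
    k, q, p = w1.shape
    x_blocks = x.reshape(batch, k, p).transpose(0, 1)   # (k, batch, p)
    out1 = torch.bmm(x_blocks, w1.transpose(-1, -2))    # first bmm: (k, batch, q)
    # NOTE: the real kernel permutes between the factors here; omitted.
    out2 = torch.bmm(out1, w2.transpose(-1, -2))        # second bmm: (k, batch, r)
    return out2.transpose(0, 1).reshape(batch, -1)      # (batch, k * r)

x = torch.randn(8, 4 * 16)
w1 = torch.randn(4, 32, 16)
w2 = torch.randn(4, 16, 32)
y = blockdiag_butterfly_sketch(x, w1, w2)
print(y.shape)  # torch.Size([8, 64])
```

Since both bmms are batched over the same leading dimension k, fusing them would mean keeping the intermediate (k, batch, q) tile in shared memory between the two GEMMs, e.g. via a hand-written Triton/CUDA kernel, rather than relying on the separate cuBLAS batched-GEMM calls torch.bmm dispatches to.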

Metadata

Assignees

No one assigned

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
