@DanFu09 Thanks for open-sourcing the code!
I see that in your previous fly repo (https://github.com/HazyResearch/fly) you used `cast_inputs=torch.float16` for `BlockdiagButterflyMultiply`, but changed it to `torch.bfloat16` here. Is there a specific reason (e.g., fp16 training not converging because of fp16's narrower dynamic range)? A sketch of the pattern I mean is below.
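For reference, here is a minimal, hypothetical sketch (not the repo's actual code) of the pattern in question: `torch.cuda.amp.custom_fwd(cast_inputs=...)` on a custom autograd `Function` controls which dtype the inputs are cast to under autocast.

```python
import torch
from torch.cuda.amp import custom_bwd, custom_fwd

class BlockdiagMultiplySketch(torch.autograd.Function):
    """Toy stand-in for BlockdiagButterflyMultiply; only the dtype
    handling is the point here, not the butterfly math itself."""

    # Inside an autocast region, custom_fwd casts all floating-point
    # tensor arguments to bfloat16 before forward runs (it's a no-op
    # outside autocast). The old fly repo used torch.float16 here; bf16
    # keeps fp32's exponent range, trading mantissa precision for
    # robustness against the overflow/underflow fp16 can hit.
    @staticmethod
    @custom_fwd(cast_inputs=torch.bfloat16)
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x @ w

    # custom_bwd runs backward under the same autocast state as forward
    # (disabled here, since cast_inputs was given).
    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        x, w = ctx.saved_tensors
        return dout @ w.transpose(-1, -2), x.transpose(-1, -2) @ dout
```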
Also, are there opportunities for fusing the two `bmm` operations into a single kernel? It seems hard to pin down the exact kernel torch is dispatching to, though.
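One way to at least see which kernels the two `torch.bmm` calls dispatch to is the PyTorch profiler. A rough sketch with made-up shapes, which also omits the reshape/transpose the real op performs between the two bmms:

```python
import torch
from torch.profiler import ProfilerActivity, profile

# Hypothetical sizes; requires a CUDA device.
nblocks, batch, p, q = 4, 256, 64, 64
x = torch.randn(nblocks, batch, p, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(nblocks, p, q, device="cuda", dtype=torch.bfloat16)
w2 = torch.randn(nblocks, q, p, device="cuda", dtype=torch.bfloat16)

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    out = torch.bmm(torch.bmm(x, w1), w2)  # the two bmms in question
    torch.cuda.synchronize()

# The table lists the actual CUDA kernels (typically cuBLAS/cuBLASLt
# batched GEMMs) that the bmm calls dispatched to.
print(prof.key_averages().table(sort_by="cuda_time_total"))
```

In principle `torch.compile` or a handwritten Triton kernel could fuse away the intermediate tensor, though whether that beats two back-to-back cuBLAS batched GEMMs is an empirical question.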