fused_mha default to false for nets with relu^2 activation #2378
borg323 wants to merge 1 commit into LeelaChessZero:master
I suspect that other networks will have NaN problems too. All networks may be fp16-unsafe, because they were trained with fp32.
However, the cutlass code has been robust so far. It may be that updating the code to the latest upstream version is to blame. Even so, ReLU^2 is a likely source of overflows in fp16, so avoiding unsafe code still makes sense to me.
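To illustrate the overflow concern, here is a minimal sketch in Python. It is not lc0 or cutlass code; it only simulates fp16 rounding with the standard `struct` module, and `to_fp16` and `relu_squared_fp16` are hypothetical helper names. The largest finite fp16 value is 65504, so squaring any activation above roughly 256 overflows to infinity, and an operation such as inf - inf (one way this could surface downstream) then yields NaN.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 binary16 value


def to_fp16(x: float) -> float:
    """Round x to the nearest fp16 value; overflow becomes +/-inf.

    Hypothetical helper: struct's "e" format raises OverflowError for
    values that do not fit, so that case is mapped to infinity here to
    mimic what fp16 arithmetic would produce.
    """
    if math.isnan(x) or math.isinf(x):
        return x
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def relu_squared_fp16(x: float) -> float:
    """relu(x)**2 with both the input and the result rounded to fp16."""
    r = max(to_fp16(x), 0.0)
    return to_fp16(r * r)


# Inputs above sqrt(FP16_MAX), roughly 256, overflow once squared.
threshold = math.sqrt(FP16_MAX)

safe = relu_squared_fp16(200.0)      # 40000 still fits in fp16
overflow = relu_squared_fp16(300.0)  # 90000 does not fit, becomes inf
poisoned = overflow - overflow       # inf - inf is NaN

print(threshold, safe, overflow, poisoned)
```

The same arithmetic in fp32 stays finite for these inputs, which is consistent with the idea that a net trained in fp32 can carry activations that are harmless there but overflow in an fp16 path.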
It makes sense to disable it for networks that have clear problems. I'm worried that BT4 might have a rare NaN problem. I was also thinking about how to make a generic fix for overflows, and I raised it here in case you might come up with a simple generic solution. I think this should be merged if no one comes up with a simple solution soon.
No description provided.