Skip to content

Commit

Permalink
Switch D to BF16; this also needs to be gated for macOS 14.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Oct 31, 2024
1 parent fe99c48 commit de59d3e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
6 changes: 3 additions & 3 deletions lib/nnc/mfa/v2/AttentionDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createMemoryPrecisi
// unrolled (head dimension vastly exceeds head block dimension).
if (lowPrecisionIntermediates) {
memoryPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP16;
memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32; // GEMMOperandPrecision::BF16;
memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::BF16;
} else {
memoryPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP32;
memoryPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
Expand Down Expand Up @@ -356,7 +356,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPreci
// The register precision of L/D only counts for backward key-value.
if (lowPrecisionIntermediates) {
registerPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP16;
registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::D] = hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
} else {
registerPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
Expand All @@ -383,7 +383,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPreci
registerPrecisions[AttentionOperand::S] = lowPrecisionInputs ? GEMMOperandPrecision::FP16 : GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::P] = GEMMOperandPrecision::FP16;
registerPrecisions[AttentionOperand::dP] = GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::dS] = GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::dS] = hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
} else {
registerPrecisions[AttentionOperand::S] = GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::P] = GEMMOperandPrecision::FP32;
Expand Down
19 changes: 18 additions & 1 deletion lib/nnc/mfa/v2/AttentionKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,24 @@ unsigned short AttentionKernel::createThreadgroupMemoryAllocation() const noexce
std::string AttentionKernel::createSource() const noexcept {
CodeWriter source;

bool injectBF16Methods = (memoryPrecisions[AttentionOperand::Q] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::K] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::S] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::P] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::V] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::O] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::L] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::D] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dO] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dV] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dP] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dS] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dK] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dQ] == GEMMOperandPrecision::BF16);
bool injectBF16Methods = false;
switch (type.value) {
case AttentionKernelType::forward:
if ((memoryPrecisions[AttentionOperand::Q] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::K] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::S] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::P] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::V] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::O] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::L] == GEMMOperandPrecision::BF16)) {
injectBF16Methods = true;
}
break;
case AttentionKernelType::backwardQuery:
if ((memoryPrecisions[AttentionOperand::Q] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::K] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::S] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::P] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::V] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::O] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::L] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::D] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dO] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dP] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dS] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dQ] == GEMMOperandPrecision::BF16)) {
injectBF16Methods = true;
}
break;
case AttentionKernelType::backwardKeyValue:
if ((memoryPrecisions[AttentionOperand::Q] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::K] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::S] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::P] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::V] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::O] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::L] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::D] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dO] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dV] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dP] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dS] == GEMMOperandPrecision::BF16) || (memoryPrecisions[AttentionOperand::dK] == GEMMOperandPrecision::BF16)) {
injectBF16Methods = true;
}
break;
}

// Inject the contents of the headers.
source += createMetalSimdgroupEvent() + "\n";
Expand Down

0 comments on commit de59d3e

Please sign in to comment.