Introduce arith.scaling_extf and arith.scaling_truncf #141965
Conversation
Some notes
```cpp
    return rewriter.notifyMatchFailure(
        op, "scaling truncf is not using scale operand of type f8E8M0FNU");
  }
  auto scaleTy = scaleOperand.getType();
```
Nit: spell this as `Type` rather than `auto`.
```cpp
  } else if (inputETy.getIntOrFloatBitWidth() > 32) {
    inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand);
  }
  inputTy = inputOperand.getType();
```
We could update these to f32Type in the if statements above, but it doesn't matter
```cpp
  Value c127 = createConst(op->getLoc(), i32Ty, 127, rewriter);
  Value cNeg127 = createConst(op->getLoc(), i32Ty, -127, rewriter);
  Value scaleI8 = b.create<arith::BitcastOp>(i8Ty, scaleOperand);
  Value scaleI32 = b.create<arith::ExtSIOp>(i32Ty, scaleI8);
```
This should be an `extui`. But also, there's no need to go up to i32 here.
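For readers following along: the f8E8M0FNU scale byte is a biased exponent in the range [0, 255], so it must be widened as an unsigned value. A plain C++ illustration of why sign extension is wrong here (function names are mine, for illustration only):

```cpp
#include <cstdint>

// Widening an E8M0 biased-exponent byte: sign extension misreads any
// biased exponent >= 128, while zero extension preserves it.
int32_t widenSigned(int8_t raw) { return raw; }     // 0xFF -> -1   (extsi, wrong)
int32_t widenUnsigned(uint8_t raw) { return raw; }  // 0xFF -> 255  (extui, right)
```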
I first need to calculate the unbiased scale value, which I can do while staying in i8. But then I also need to subtract emax (the max exponent of the largest normal number in the resulting quantized dtype). That subtraction could underflow or overflow, which needs to be checked and clamped later on; therefore I require i32.
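A minimal sketch of the sequence being described, assuming the builder `b` and the constants (`c127`, `cNeg127`, `cMaxNormalExponent`) from the quoted patch; the exact clamp bounds are my assumption:

```cpp
Value scaleI8  = b.create<arith::BitcastOp>(i8Ty, scaleOperand);
Value scaleI32 = b.create<arith::ExtUIOp>(i32Ty, scaleI8);              // biased exponent, 0..255
Value unbiased = b.create<arith::SubIOp>(scaleI32, c127);               // remove the E8M0 bias
Value adjusted = b.create<arith::SubIOp>(unbiased, cMaxNormalExponent); // subtract emax
// The subtraction can leave the representable exponent range, hence the
// widening to i32 followed by a clamp:
Value clampedLo = b.create<arith::MaxSIOp>(adjusted, cNeg127);
Value clamped   = b.create<arith::MinSIOp>(clampedLo, c127);
```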
> This should be an extui.

Thanks. Good catch.
OK, so, my bigger complaint is that you can simplify the generated code substantially if you just switch on what kind of type you're extending to. That is, f32 requires nothing; that's already a ±127 situation. Types shorter than f32 will need the subtraction.

... Also, I'm going to re-read the code, but I'm not convinced this should be subtracting the max normalized exponent. Are we sure it isn't "clamp to the exponent range of the type"?
... Ah, we're subtracting the max exponent of the result type, which can't lead to overflow. This could be substantially simplified if we just used usub_sat (which we'd need an MLIR arith op for, but that's fairly trivial).
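For reference, usub_sat here means saturating unsigned subtraction, as in LLVM's llvm.usub.sat intrinsic; a plain C++ statement of its semantics:

```cpp
#include <cstdint>

// Saturating unsigned subtraction: clamps at zero instead of wrapping.
uint8_t usub_sat(uint8_t a, uint8_t b) {
  return a > b ? static_cast<uint8_t>(a - b) : 0;
}
// usub_sat(130, 127) == 3; usub_sat(120, 127) == 0 (no wrap-around to 249)
```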
... But also, the code you linked is for quantization
I think it's reasonable to assume that someone implementing quantization will already have done the scale-biasing thing and so we don't need to do it here
Unless we have evidence that the hardware implementations perform the subtraction described here? (We'll probably want to go find the AMD behavior)
... and if you're doing usub_sat, you don't need to unbias the exponent.
But also, I'd make sure this is something that other implementors of scaling_truncf implement so we don't get conflicting lowerings
```cpp
  const llvm::fltSemantics &resultFltSemantics =
      llvm::cast<FloatType>(resultETy).getFloatSemantics();
  int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics);
  Value cMaxNormalExponent =
```
Skip all this if we're in f32 or higher?
Rewrote using f32.
```cpp
  Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
                                          inputExponentU8);
  Value inputTyZero = createFloatConst(op.getLoc(), inputTy, 0, rewriter);
  Value flushedInput =
```
This all seems overcomplicated?
This could just be extending the scale to f32?
Rewrote using f32. It does simplify things a bit. Thanks
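For context, the simplified f32-based lowering being converged on, as far as it can be reconstructed from the code quoted later in this thread (a sketch, not the literal patch; `b`, `flushedInput`, and the type values follow the quoted names):

```cpp
// extf of an f8E8M0FNU scale directly materializes 2^(unbiased exponent)
// as an f32, so no integer exponent manipulation is needed.
Value scaleF32 = b.create<arith::ExtFOp>(f32Ty, scaleOperand);
// Divide the (denorm-flushed) input by the scale, then truncate,
// propagating the rounding mode and fastmath flags.
Value result = b.create<arith::DivFOp>(flushedInput, scaleF32);
Value resultCast = b.create<arith::TruncFOp>(
    resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
```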
Co-authored-by: Prashant Kumar <pk5561@gmail.com>
Co-authored-by: Krzysztof Drewniak <Krzysztof.Drewniak@amd.com>
Let's say the input originally has shape <dim1 x dim2 x dim3 ... x dimN>; then, given blockSize, it can be reshaped to <dim1 x dim2 x ... x (dimN/blockSize) x blockSize>. Scales are calculated on the block axis, so the scale will have shape <dim1 x dim2 x dim3 ... x (dimN/blockSize) x 1>. Before calling into `arith.scaling_extf`, scales must be broadcast appropriately to the same shape as the input, making `arith.scaling_extf` an elementwise op. In the above example, scales should be broadcast to shape <dim1 x dim2 x dim3 x ... x (dimN/blockSize) x blockSize>.
From the description, I understand it doesn't need to be broadcast; you could use a non-broadcasted tensor of shape <dim1 x dim2 x dim3 x ... (dimN/blockSize) x blockSize>? If that's the case, I don't think it's useful to explain all of these details; broadcasting is just a use case. If I understood it correctly.
I think the description needs to be updated - this arith op is set up to do things elementwise because arith ops in general are elementwise and the broadcast scale thing is a special case that gets pattern-matched in a future ArithToAMDGPU
I tried to rewrite the documentation. Please check again and let me know if it is clearer now.
op, "scaling extf is not using scale operand of type f8E8M0FNU"); | ||
} | ||
Type resultTy = op.getType(); | ||
// extf on scale will essentially create f32 number that is 2^scale and will |
why f32? can't resultTy be any float type?
should we check if resultTy >= Float8E8M0FNU and >= inputType
In principle, scaled truncation from f32 to f32 is a really weird way to spell division, but we might want to verify it away.
> why f32? can't resultTy be any float type?

Changed the comment to better reflect what it's doing.

> should we check if resultTy >= Float8E8M0FNU and >= inputType

As part of verification, it checks that the output dtype has a larger width than the input:
https://github.com/umangyadav/llvm-project/blob/d1543414578abf95a495b4eb6fe9b6201de8e9f6/mlir/lib/Dialect/Arith/IR/ArithOps.cpp#L1460
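A rough sketch of what such a width check could look like (hypothetical shape, with accessor names assumed; the actual code lives at the link above):

```cpp
// Hypothetical sketch: require the result element type to be strictly
// wider than the input element type for scaling_extf.
LogicalResult ScalingExtFOp::verify() {
  Type inETy = getElementTypeOrSelf(getIn().getType());   // getIn() is assumed
  Type outETy = getElementTypeOrSelf(getType());
  if (outETy.getIntOrFloatBitWidth() <= inETy.getIntOrFloatBitWidth())
    return emitOpError("result type must be wider than the input type");
  return success();
}
```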
```cpp
  Value result = b.create<arith::DivFOp>(flushedInput, scaleF32);
  // propagate rounding mode and fast math attributes
  Value resultCast = b.create<arith::TruncFOp>(
      resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
```
There are other arith ops; shouldn't we propagate to those as well? Also for ScalingExtFOpConverter.
should we check resultTy <= f32?
> should we check resultTy <= f32?

Verify() checks that the output width is smaller than the input width.

> there are other arith ops, shouldn't we propagate to those as well? also for ScalingExtFOpConverter

No, the other arith.truncf ops are mainly for scale dtype conversion, which just operates on the exponent and isn't really affected by rounding mode or fast math.
Yes, Verify() checks that the output width is smaller than the input width. But as I understand it, the result of this function before the final truncation is always f32. So I wonder if somebody can do input, scale -> f128, result -> f64: then it's true that output width < input width, and yet we'd be trying to truncate "result", which is f32, into f64. Not sure if I misunderstood something?
In practice, Float64/80/128 dtypes are not expected. I think it is safe to assume f32 is the largest dtype that can appear on the input. The Verify() check is strict, i.e. output_bit_width < input_bit_width, so in practice this would never really be truncating into an f32 resultTy.

> But I understand the output of this function is always f32

No, why do you think so? The output dtype will be whatever the user has specified.
> No, why do you think so? The output dtype will be whatever the user has specified.

I mean the result of the function before truncation: result.dtype == f32, right?

> In practice, Float64/80/128 dtypes are not expected. I think it is safe to assume f32 is the largest dtype that can appear on the input.

I think the arith dialect is not supposed to be hardware-specific, so even though it's not expected for us, I'd prefer to enforce or check the assumption somehow. But it seems OK to me either way; whatever you decide.
Some minor nits! LGTM. I'll wait for @krzysz00.
```cpp
                              PatternRewriter &rewriter) {
  auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
    return rewriter.create<arith::ConstantOp>(
```
We can update the attr here: `attr = DenseElementsAttr::get(shapedTy, attr);`. It will return the right thing. (Both are fine to me.)
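Putting the quoted helper and this suggestion together, the result would look roughly like this (a sketch; the parameter types are assumed from the call sites quoted above):

```cpp
static Value createFloatConst(Location loc, Type type, double value,
                              PatternRewriter &rewriter) {
  TypedAttr attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
  // Reviewer's variant: splat the scalar attribute over shaped types, then
  // build a single arith.constant either way.
  if (auto shapedTy = dyn_cast<ShapedType>(type))
    attr = DenseElementsAttr::get(shapedTy, attr);
  return rewriter.create<arith::ConstantOp>(loc, attr);
}
```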
Co-authored-by: Prashant Kumar <pk5561@gmail.com>
```
// emax is calculated as the exponent of the largest normal value in the quantized type.
scale.normalized = arith.divf(scale.extf, emax)
scale.clamped = clamp(scale.normalized)  // clamp underflows
input.flushed = flush_denorms(input)
```
There are some type conversions for input and scale that are not explained here. Not sure if we want all those details here?
IMO, that would be more detail than necessary.
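To make the basic scale application concrete (setting aside the emax normalization debated elsewhere in this thread, and assuming the standard E8M0 bias of 127), a worked instance:

$$\mathrm{scale} = 2^{\,b-127}, \qquad \mathrm{scaling\_truncf}(x, b) = \mathrm{truncf}\!\left(x / 2^{\,b-127}\right)$$

For example, a scale byte $b = 130$ encodes $2^{3} = 8$, so an input $x = 20.0$ is truncated as $\mathrm{truncf}(20.0 / 8) = \mathrm{truncf}(2.5)$ in the result type.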
LGTM
Note: I'd rather we not land this just yet, because I'm still waiting to find out whether potential hardware-specific lowerings perform the emax adjustment discussed above. I have a suspicion that the answer is "no": that that adjustment is part of the scale computation process, not the scale application process, and so the semantics of scaling_truncf shouldn't include it.
Hold for semantics questions, and @llvm/pr-subscribers-mlir-nvgpu for input on Nvidia semantics while I wait on AMD answers
This PR adds `arith.scaling_truncf` and `arith.scaling_extf` operations, which perform block quantization following the OCP MXFP spec listed here: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

The OCP MXFP spec comes with a reference implementation here: https://github.com/microsoft/microxcaling/tree/main

The interesting piece of reference code is the `_quantize_mx` method: https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L173

Both `arith.scaling_truncf` and `arith.scaling_extf` are designed to be elementwise operations. Please see the description of them in the `ArithOps.td` file for more details.

A few things to note about `arith.scaling_truncf`:

CC: @krzysz00 @dhernandez0 @bjacob @pashu123 @MaheshRavishankar @tgymnich