
Introduce arith.scaling_extf and arith.scaling_truncf #141965


Open · wants to merge 33 commits into main

Conversation

@umangyadav (Contributor) commented May 29, 2025:

This PR adds arith.scaling_truncf and arith.scaling_extf operations, which perform block quantization and dequantization following the OCP MXFP spec listed here: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

The OCP MXFP spec comes with a reference implementation here: https://github.com/microsoft/microxcaling/tree/main

The interesting piece of reference code is the _quantize_mx method: https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L173

Both arith.scaling_truncf and arith.scaling_extf are designed as elementwise operations. Please see their descriptions in the ArithOps.td file for more details.

A few things to note about arith.scaling_truncf (a scalar model is sketched after this list):

  1. The OCP spec flushes denorms to zero.
  2. It normalizes the shared scale exponent by emax (the exponent of the largest normal number in the resulting quantized type).
  3. It clamps the normalized shared exponent.
  4. NaNs are propagated.
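
To make these notes concrete, here is a minimal scalar model in C++ (illustrative only, not the actual lowering; the helper name and the clamp bounds are assumptions):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar model of the notes above. `scaleExp` is the biased f8E8M0FNU shared
// scale (bias 127); `emax` is the exponent of the largest normal number in
// the resulting quantized type. The final truncf into the quantized type is
// left out; only the scale handling is modeled.
float applyScalingTruncfModel(float input, uint8_t scaleExp, int emax) {
  if (std::isnan(input))
    return input;                                 // note 4: NaNs propagate
  if (std::fpclassify(input) == FP_SUBNORMAL)
    input = 0.0f;                                 // note 1: flush denorms
  int unbiased = static_cast<int>(scaleExp) - 127;
  int normalized = unbiased - emax;               // note 2: normalize by emax
  normalized = std::clamp(normalized, -127, 127); // note 3: clamp (assumed bounds)
  return std::ldexp(input, -normalized);          // divide by 2^normalized
}
```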

CC: @krzysz00 @dhernandez0 @bjacob @pashu123 @MaheshRavishankar @tgymnich

@krzysz00 (Contributor) left a comment:

Some notes

return rewriter.notifyMatchFailure(
op, "scaling truncf is not using scale operand of type f8E8M0FNU");
}
auto scaleTy = scaleOperand.getType();
Contributor:

Type

} else if (inputETy.getIntOrFloatBitWidth() > 32) {
inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand);
}
inputTy = inputOperand.getType();
Contributor:

We could update these to f32Type in the if statements above, but it doesn't matter

Value c127 = createConst(op->getLoc(), i32Ty, 127, rewriter);
Value cNeg127 = createConst(op->getLoc(), i32Ty, -127, rewriter);
Value scaleI8 = b.create<arith::BitcastOp>(i8Ty, scaleOperand);
Value scaleI32 = b.create<arith::ExtSIOp>(i32Ty, scaleI8);
Contributor:

This should be an extui. But also, there's no need to go to i32 here.

@umangyadav (Contributor Author):

I first need to calculate the unbiased scale value. I can do that while staying in i8.

But then I also need to subtract emax (the exponent of the largest normal number in the resulting quantized dtype). That subtraction could underflow or overflow, which needs to be checked and clamped later on; therefore I require i32.
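
For illustration, a small standalone C++ example of the range issue being described (the emax value is just an example):

```cpp
#include <cstdio>

// The biased E8M0 scale spans 0..255, so its unbiased value spans -127..127
// (255 encodes NaN). After subtracting emax the result can fall below -128,
// which no longer fits a signed 8-bit value; hence the widening to i32.
int main() {
  int biased = 0;                   // smallest E8M0 encoding
  int emax = 8;                     // e.g. emax of f8E4M3
  int unbiased = biased - 127;      // -127: still fits in i8
  int normalized = unbiased - emax; // -135: underflows i8, exact in i32
  std::printf("unbiased=%d normalized=%d\n", unbiased, normalized);
  return 0;
}
```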

@umangyadav (Contributor Author) commented May 29, 2025:

> This should be an extui.

Thanks. Good catch.

Contributor:

Ok, so, my bigger complaint is that you can simplify the generated code substantially if you just switch on what kind of type you're extending to.

That is, f32 requires nothing: that's already a ±127 situation.

Types shorter than f32 will need the subtraction.

... Also, I'm going to re-read the code, but I'm not convinced this should be subtracting the max normalized exponent. Are we sure it isn't "clamp to the exponent range of the type"?

Contributor:

... Ah, we're subtracting the max exponent of the result type, which can't lead to overflow.

This could be substantially simplified if we just use usub_sat (which we'd need an MLIR arith op for, but that's fairly trivial).
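
For reference, a scalar sketch of the saturating subtraction being proposed (no arith.usub_sat op exists at the time of this review; llvm.usub.sat is the LLVM-level analogue):

```cpp
#include <cstdint>

// Saturating unsigned subtraction: the result floors at zero instead of
// wrapping, which is why the exponent would not need to be unbiased first.
uint8_t usubSat(uint8_t a, uint8_t b) {
  return a > b ? static_cast<uint8_t>(a - b) : static_cast<uint8_t>(0);
}
```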

Contributor:

... But also, the code you linked is for quantization.

I think it's reasonable to assume that someone implementing quantization will already have done the scale-biasing thing, so we don't need to do it here.

Unless we have evidence that the hardware implementations perform the subtraction described here? (We'll probably want to go find the AMD behavior.)

Contributor:

... and if you're doing usub_sat, you don't need to unbias the exponent.

But also, I'd make sure this is something that other implementors of scaling_truncf implement, so we don't get conflicting lowerings.

const llvm::fltSemantics &resultFltSemantics =
llvm::cast<FloatType>(resultETy).getFloatSemantics();
int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics);
Value cMaxNormalExponent =
Contributor:

Skip all this if we're in f32 or higher?

@umangyadav (Contributor Author):

Rewrote using f32.

Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
inputExponentU8);
Value inputTyZero = createFloatConst(op.getLoc(), inputTy, 0, rewriter);
Value flushedInput =
Contributor:

This all seems overcomplicated?

This could just be extending the scale to f32?

@umangyadav (Contributor Author):

Rewrote using f32. It does simplify things a bit. Thanks.
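
For context, a small C++ sketch (the helper name is made up) of what extending an f8E8M0FNU scale to f32 yields, which is what the simplified lowering relies on:

```cpp
#include <cmath>
#include <cstdint>
#include <limits>

// f8E8M0FNU is a biased-exponent byte (bias 127) with no sign or mantissa
// bits; 0xFF encodes NaN. Extending it to f32 therefore yields 2^(e - 127),
// the power of two that the lowering divides the input by.
float decodeE8M0(uint8_t e) {
  if (e == 0xFF)
    return std::numeric_limits<float>::quiet_NaN();
  return std::ldexp(1.0f, static_cast<int>(e) - 127);
}
```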

Let's say the input originally has shape <dim1 x dim2 x dim3 ... x dimN>; then, given blockSize, it can be reshaped to <dim1 x dim2 x ... x (dimN/blockSize) x blockSize>.
Scales are calculated along the block axis, so the scale will have shape <dim1 x dim2 x dim3 ... x (dimN/blockSize) x 1>.
Before calling into `arith.scaling_extf`, scales must be broadcast appropriately to the same shape as the input, making `arith.scaling_extf` an elementwise op.
In the above example, scales should be broadcast to shape <dim1 x dim2 x dim3 x ... x (dimN/blockSize) x blockSize>.
Contributor:

As I understand from the description, it doesn't need to be broadcast; you could use a non-broadcasted tensor of shape <dim1 x dim2 x dim3 x ... x (dimN/blockSize) x blockSize>?

If that's the case, I don't think it's useful to explain all of these details; broadcasting is just a use-case. If I understood it correctly.

Contributor:

I think the description needs to be updated: this arith op is set up to do things elementwise because arith ops in general are elementwise, and the broadcast-scale thing is a special case that gets pattern-matched in a future ArithToAMDGPU pass.

@umangyadav (Contributor Author):

I tried to rewrite the documentation. Please check again and let me know if it is clearer now.
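
To illustrate the shapes described in the documentation excerpt above, here is a small C++ sketch (the helper is hypothetical, not part of the PR):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// For an input of shape <dim1 x ... x dimN> and a given blockSize, scales are
// produced at shape <dim1 x ... x (dimN/blockSize) x 1> (one per block) and
// must then be broadcast over the trailing axis to
// <dim1 x ... x (dimN/blockSize) x blockSize> so that arith.scaling_extf
// stays elementwise.
std::vector<int64_t> scaleShape(std::vector<int64_t> inputShape,
                                int64_t blockSize) {
  assert(!inputShape.empty() && inputShape.back() % blockSize == 0);
  inputShape.back() /= blockSize; // block the last axis
  inputShape.push_back(1);        // one scale per block
  return inputShape;
}
```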

op, "scaling extf is not using scale operand of type f8E8M0FNU");
}
Type resultTy = op.getType();
// extf on scale will essentially create f32 number that is 2^scale and will
Contributor:

why f32? can't resultTy be any float type?

Contributor:

should we check if resultTy >= Float8E8M0FNU and >= inputType?

Contributor:

In principle, scaled truncation from f32 to f32 is a really weird way to spell division, but we might want to verify it away.

@umangyadav (Contributor Author) commented May 30, 2025:

> why f32? can't resultTy be any float type?

Changed the comment to better reflect what it's doing.

> should we check if resultTy >= Float8E8M0FNU and >= inputType?

As part of verification, it checks that the output dtype has a larger width than the input:
https://github.com/umangyadav/llvm-project/blob/d1543414578abf95a495b4eb6fe9b6201de8e9f6/mlir/lib/Dialect/Arith/IR/ArithOps.cpp#L1460

Value result = b.create<arith::DivFOp>(flushedInput, scaleF32);
// propagate rounding mode and fast math attributes
Value resultCast = b.create<arith::TruncFOp>(
resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
Contributor:

There are other arith ops; shouldn't we propagate to those as well? Also for ScalingExtFOpConverter.

Contributor:

should we check resultTy <= f32?

@umangyadav (Contributor Author):

> should we check resultTy <= f32?

Verify() checks that the output width is smaller than the input width.

https://github.com/umangyadav/llvm-project/blob/d1543414578abf95a495b4eb6fe9b6201de8e9f6/mlir/lib/Dialect/Arith/IR/ArithOps.cpp#L1587

> There are other arith ops; shouldn't we propagate to those as well? Also for ScalingExtFOpConverter.

No, the other arith.truncf ops are mainly for the scale's dtype conversion, which operates only on the exponent and isn't really affected by rounding mode or fast math.

Contributor:

Yes, verify checks that the output width is smaller than the input width. But I understand the output of this function is always f32. Then, I wonder if somebody can do input, scale -> f128, result -> f64. Then it's true that output width < input width, and we are still trying to truncate "result", which is f32, into f64. Not sure if I misunderstood something?

@umangyadav (Contributor Author):

In practice, Float64/80/128 dtypes are not expected. I think it is safe to assume f32 is the largest dtype that can appear on the input.
The Verify() check is a strict check, so output_bit_width < input_bit_width.
So this would never really be truncating to a wider resultTy in practice.

> But I understand the output of this function is always f32

No, why do you think so? The output dtype will be whatever the user has specified.

@dhernandez0 (Contributor) commented Jun 2, 2025:

> No, why do you think so? The output dtype will be whatever the user has specified.

I mean the result of the function before truncation. result.dtype = f32, right?

> In practice, Float64/80/128 dtypes are not expected. I think it is safe to assume f32 is the largest dtype that can appear on the input.

I think the arith dialect is not supposed to be hardware-specific, so even though it's not expected for us, I'd prefer to enforce or check the assumption somehow. But it seems ok to me anyway; whatever you decide.

@pashu123 (Member) left a comment:

Some minor nits! LGTM. I'll wait for @krzysz00.

PatternRewriter &rewriter) {
auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
return rewriter.create<arith::ConstantOp>(
Member:

We can update the attr here: attr = DenseElementsAttr::get(shapedTy, attr). It will return the right thing. (Both are fine to me.)
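
A sketch of the suggested variant (assuming the surrounding helper and headers from this PR; the helper name mirrors the excerpt above):

```cpp
// Splat the scalar attribute across the shaped type up front so that a single
// arith::ConstantOp creation path covers both the scalar and shaped cases.
static Value createFloatConst(Location loc, Type type, double value,
                              PatternRewriter &rewriter) {
  Attribute attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
  if (auto shapedTy = dyn_cast<ShapedType>(type))
    attr = DenseElementsAttr::get(shapedTy, attr);
  return rewriter.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr));
}
```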

// emax is calculated as the exponent of the largest normal value in the quantized type.
scale.normalize = arith.divf(scale.extf, emax)
scale.clamped = clamp(scale.normalize) // clamp underflows
input.flushed = flush_denorms(input)
Contributor:

There are some type conversions for input and scale that are not explained here. Not sure if we want all those details here?

@umangyadav (Contributor Author):

IMO, that would be more detail than necessary.


@dhernandez0 (Contributor) left a comment:

LGTM

@krzysz00 (Contributor) commented Jun 3, 2025:

Note: I'd rather we not land this just yet, because I'm still waiting to find out whether potential hardware-specific lowerings of arith.scaling_truncf will perform the exponent subtraction that this code does.

I have a suspicion that the answer is "no": that adjustment is part of the scale computation process, not the scale application process, and so the semantics of scaling_truncf shouldn't include it.

@krzysz00 (Contributor) left a comment:

Hold for semantics questions, and @llvm/pr-subscribers-mlir-nvgpu for input on Nvidia semantics while I wait on AMD answers.
