Thread the scaling type argument throughout fp8 #301
base: gh/drisspg/1/base
Conversation
-    amax_buffer: Optional[torch.Tensor] = None,
-    mm_config: Optional[ScaledMMConfig] = None,
+    float8_dtype: torch.dtype,
+    amax_buffer: Optional[torch.Tensor],
I removed the default args since this is always called from the inner func with the default args passed explicitly.
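A minimal sketch of this pattern, using hypothetical names (`cast_to_float8`, `_to_fp8_inner`) rather than the PR's actual functions: the public wrapper owns the defaults and always forwards every argument explicitly, so the helper can declare them as required.

```python
from typing import Optional

import torch


def _to_fp8_inner(
    x: torch.Tensor,
    x_scale: torch.Tensor,
    float8_dtype: torch.dtype,
    amax_buffer: Optional[torch.Tensor],
    mm_config,  # e.g. a ScaledMMConfig; left untyped and unused in this sketch
) -> torch.Tensor:
    # Every argument is required here; the defaults live on the wrapper below.
    if amax_buffer is not None:
        amax_buffer.fill_(x.abs().max())
    # Assumes the scale follows the scale = fp8_max / amax convention.
    return (x * x_scale).to(float8_dtype)


def cast_to_float8(
    x: torch.Tensor,
    x_scale: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
    amax_buffer: Optional[torch.Tensor] = None,
    mm_config=None,
) -> torch.Tensor:
    # The public entry point owns the defaults and always forwards every
    # argument explicitly, so the inner helper never needs its own defaults.
    return _to_fp8_inner(x, x_scale, float8_dtype, amax_buffer, mm_config)
```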
float8_experimental/float8_tensor.py
@@ -31,6 +28,20 @@
)


class ScalingStrategy(Enum):
Thoughts about using `Granularity`, which is more specific than `Strategy`?
Yeah, that's a better word; this needed some bikeshedding.
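For illustration, a sketch of what the renamed enum could look like; the member names (`TENSORWISE`, `AXISWISE`) are illustrative, not taken from the PR:

```python
from enum import Enum, auto


class ScalingGranularity(Enum):
    """How the fp8 scale is computed relative to the tensor's shape."""

    # One scale for the whole tensor (the only mode wired up so far).
    TENSORWISE = auto()
    # One scale per row/column, e.g. for axis-wise scaling.
    AXISWISE = auto()
```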
-        return Float8Tensor(bits_fp8, x_scale, x.dtype, mm_config=mm_config)
+        return Float8Tensor(
+            bits_fp8,
+            x_scale,
Just curious, since we decided to not add `scaling_strategy` to `torch._scaled_mm`, why do we need it here?
We could make this a property of Float8Tensor, e.g. infer it from the existing scales... hmm.
Actually, I might like this more...
We still need the enum, though, since we want modules to specify their granularity.
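A hypothetical sketch of the "infer it from the existing scales" idea, reusing the illustrative enum from above; `infer_granularity` and its shape checks are assumptions for illustration, not the PR's code:

```python
from enum import Enum, auto

import torch


class ScalingGranularity(Enum):  # same illustrative members as the sketch above
    TENSORWISE = auto()
    AXISWISE = auto()


def infer_granularity(scale: torch.Tensor, data: torch.Tensor) -> ScalingGranularity:
    """Guess the scaling granularity from the shape of an existing scale."""
    if scale.numel() == 1:
        # A single scalar scale means one scale for the whole tensor.
        return ScalingGranularity.TENSORWISE
    if scale.ndim == data.ndim and scale.shape[-1] == 1:
        # e.g. a per-row scale of shape (M, 1) for data of shape (M, K).
        return ScalingGranularity.AXISWISE
    raise ValueError(f"cannot infer granularity from scale shape {tuple(scale.shape)}")
```

Modules would still pass the enum to request a scheme when casting; the inference above only recovers the granularity of an already-constructed Float8Tensor.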
Stack from ghstack (oldest at bottom):

# Summary

This PR adds a ScalingGranularity enum and threads it through the stack to all the places we call `tensor_to_amax` and `tensor_to_scale`.

- Currently hardcodes TensorWise scaling in Float8Linear, Float8DynamicLinear, and Float8InferenceLinear; asserts that the granularity is TensorWise for now.
- Added this as a property of WeightWithDynamicFloat8CastTensor, since we need to know a priori how to do the scaling for fp8 comms.

### Testing

```Shell
============================================================================= test session starts =============================================================================
platform linux -- Python 3.12.4, pytest-7.4.0, pluggy-1.5.0
rootdir: /home/drisspg/meta/float8_experimental
plugins: hypothesis-6.104.1
collected 9 items

test/test_fsdp2/test_fsdp2_eager.py .........                                                                   [100%]

============================================================================= 9 passed in 30.77s ==============================================================================
all tests successful
```

[ghstack-poisoned]
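A sketch of the "hardcode TensorWise and assert" behavior the summary describes; `cast_weight_to_float8` and the scale convention (scale = fp8_max / amax) are assumptions for illustration, not the PR's actual implementation.

```python
from enum import Enum, auto

import torch


class ScalingGranularity(Enum):
    TENSORWISE = auto()
    AXISWISE = auto()


def cast_weight_to_float8(
    weight: torch.Tensor,
    granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
) -> torch.Tensor:
    # Only tensor-wise scaling is wired up so far; fail loudly on anything else.
    assert (
        granularity is ScalingGranularity.TENSORWISE
    ), f"only TensorWise scaling is supported, got {granularity}"
    # Tensor-wise: one amax and one scale for the whole tensor.
    amax = weight.abs().max().to(torch.float32)
    scale = torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=1e-12)
    return (weight.float() * scale).to(torch.float8_e4m3fn)
```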