make all 3 gemms in Float8Linear support configurability, not user facing #315
Conversation
Summary: not ready for review yet Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
inpt_tensor: torch.Tensor,
linear_mm_config: LinearMMConfig,
reduce_amax: bool = False,
gemm_input_role: GemmInputRole = GemmInputRole.X,
maybe no default value for gemm role?
we can clean up in a separate PR; there is extra complexity because we'd need to change the argument order
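For context, a minimal sketch of why dropping the default touches the argument order. The names below are illustrative (a stand-in `GemmInputRole` enum using the x/w/dL_dY naming from this PR), not the library's exact API:

```python
import enum

import torch


class GemmInputRole(enum.Enum):
    # hypothetical stand-in for the roles discussed in this PR (x / w / dL_dY)
    X = "x"
    W = "w"
    DL_DY = "dL_dY"


# Current shape of the signature: the default keeps existing call sites working.
def to_float8_with_default(
    inpt_tensor: torch.Tensor,
    linear_mm_config,  # LinearMMConfig in the real code
    reduce_amax: bool = False,
    gemm_input_role: GemmInputRole = GemmInputRole.X,
):
    ...


# Without a default, gemm_input_role can no longer sit after reduce_amax
# positionally (a non-default parameter cannot follow default ones), so either
# the argument order changes or callers must pass it by keyword.
def to_float8_without_default(
    inpt_tensor: torch.Tensor,
    linear_mm_config,
    gemm_input_role: GemmInputRole,
    reduce_amax: bool = False,
):
    ...
```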
self.backward_config = ScaledMMConfig(
    emulate, False, False, config.pad_inner_dim
# TODO(future): user level configuration of gemms
self.linear_mm_config = LinearMMConfig(
[I think this might be another stylistic thing, so no need to change]:
I would actually make this a func and then super-document it. It's not very clear from reading this what everything does, so I would clearly explain in that func the exact recipe we choose by default.
this isn't user facing, so we can clean up at any time
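A sketch of the suggested refactor: wrap the construction in a small, documented helper so the default per-gemm recipe is spelled out in one place. The import path, the `ScaledMMConfig` field meanings, and the specific default values below are assumptions for illustration, not what this PR ships:

```python
# Assumed import path; treat the exact module location as an assumption.
from float8_experimental.float8_tensor import LinearMMConfig, ScaledMMConfig


def make_default_linear_mm_config(emulate: bool, pad_inner_dim: bool) -> LinearMMConfig:
    """Build the default (non user facing) per-gemm recipe.

    Illustrative choices:
    - output gemm: fast accumulation enabled when not emulating
    - grad_input / grad_weight gemms: fast accumulation disabled
    - all three gemms: high precision output, shared pad_inner_dim setting
    """
    return LinearMMConfig(
        # ScaledMMConfig fields assumed: (emulate, use_fast_accum, fp8_output, pad_inner_dim)
        output=ScaledMMConfig(emulate, not emulate, False, pad_inner_dim),
        grad_input=ScaledMMConfig(emulate, False, False, pad_inner_dim),
        grad_weight=ScaledMMConfig(emulate, False, False, pad_inner_dim),
    )
```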
@@ -76,6 +84,7 @@ def float8_cat(aten_op, args, kwargs=None):
scale = chunked_tensors[0]._scale
Future PR:
We should share code between unflatten/flatten and here to just splat out the extra metadata that lives on the tensor
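A sketch of that code-sharing idea, assuming the extra metadata lives in attributes like `_scale`, `_orig_dtype`, `_linear_mm_config`, and `_gemm_input_role` (only `_scale` appears in the diff above; the other attribute names and the helper itself are hypothetical):

```python
def get_extra_tensor_metadata(t) -> dict:
    # Hypothetical helper: gather the non-data fields carried on the tensor
    # subclass in one place so __tensor_flatten__/__tensor_unflatten__ and
    # ops like float8_cat can "splat" them out from shared code.
    return {
        "_scale": t._scale,
        "_orig_dtype": t._orig_dtype,
        "_linear_mm_config": t._linear_mm_config,
        "_gemm_input_role": t._gemm_input_role,
    }
```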
Summary: not ready for review yet Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
…not user facing"

Summary: This PR adds some plumbing for how to eventually make all 3 gemms in a linear fwd/bwd configurable:
1. add `LinearMMConfig` to `Float8Tensor` to tie together the three `ScaledMMConfig` objects, one per gemm
2. add `GemmInputRole` to `Float8Tensor` to specify how to pick the right config
3. plumb all of these throughout the codebase

Note that none of this is user facing, and there is no logic change. Planned follow-ups:
* a future PR will make the per-gemm behavior configurable in a user facing way, which will hook up to the objects introduced in this PR
* a future PR will update the naming from x/w/dL_dY to input/weight/grad_output throughout the codebase

Test Plan:
```
./test/test_everything.sh
```

Reviewers:
Subscribers:
Tasks:
Tags:

[ghstack-poisoned]
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
This pull request has been merged in c58fb5d.
Stack from ghstack (oldest at bottom):
Summary:

This PR adds some plumbing for how to eventually make all 3 gemms in a linear fwd/bwd configurable:
1. add `LinearMMConfig` to `Float8Tensor` to tie together the three `ScaledMMConfig` objects, one per gemm
2. add `GemmInputRole` to `Float8Tensor` to specify how to pick the right config
3. plumb all of these throughout the codebase

Note that none of this is user facing, and there is no logic change. Planned follow-ups:
* a future PR will make the per-gemm behavior configurable in a user facing way, which will hook up to the objects introduced in this PR
* a future PR will update the naming from x/w/dL_dY to input/weight/grad_output throughout the codebase
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: D59973551
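To make the intended plumbing concrete, here is a self-contained sketch of the idea described above: three `ScaledMMConfig` objects tied together in a `LinearMMConfig`, one per gemm in a linear fwd/bwd, and a `GemmInputRole` on each operand used to pick the right one. The field layout and the selection logic are illustrative assumptions, not this PR's exact code:

```python
import enum
from collections import namedtuple

# Assumed field layout for a single scaled-mm call.
ScaledMMConfig = namedtuple(
    "ScaledMMConfig", ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"]
)

# One config per gemm in a linear fwd/bwd:
#   output      = x @ w_t
#   grad_input  = dL_dY @ w
#   grad_weight = dL_dY_t @ x
LinearMMConfig = namedtuple("LinearMMConfig", ["output", "grad_input", "grad_weight"])


class GemmInputRole(enum.Enum):
    X = "x"
    W = "w"
    DL_DY = "dL_dY"


def pick_scaled_mm_config(role_a, role_b, mm_config):
    # The pair of operand roles determines which of the three gemms we are in,
    # and therefore which ScaledMMConfig to use.
    roles = {role_a, role_b}
    if roles == {GemmInputRole.X, GemmInputRole.W}:
        return mm_config.output
    if roles == {GemmInputRole.DL_DY, GemmInputRole.W}:
        return mm_config.grad_input
    if roles == {GemmInputRole.DL_DY, GemmInputRole.X}:
        return mm_config.grad_weight
    raise AssertionError(f"unexpected gemm input roles: {role_a}, {role_b}")
```

A future user-facing API would then only need to build a different `LinearMMConfig` per linear module; the tensor-level plumbing introduced here would stay the same.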