prims.where computes broadcast shape to avoid runtime-trace mismatch #2135
Conversation
```diff
diff --git a/thunder/core/prims.py b/thunder/core/prims.py
index dec77bf7..4fa12fcc 100644
--- a/thunder/core/prims.py
+++ b/thunder/core/prims.py
@@ -2846,7 +2846,10 @@ def _where_meta(pred: Number | TensorProxy, a: Number | TensorProxy, b: Number |
     resultdevice = devices.cpu
     devices_ = tuple(x.device for x in (pred, a, b) if isinstance(x, TensorProxy))
     if len(devices_) > 0:
-        resultdevice = devices_[0]
+        if len(devices_) == 1 or not isinstance(pred, TensorProxy) or pred.numel != 1 or pred.device != devices.cpu:
+            resultdevice = devices_[0]
+        else:
+            resultdevice = devices_[1]
     # Determines result dtype
     numbertype, tensordtype = utils.check_same_dtype(a, b)
```

^^^ one awkward implementation for the result device
I don't think broadcasting in prims is the right thing to do. Why would we not revise the clang op to insert the missing broadcast ops?
This PR does not add broadcasting inside the prim. It reuses a piece of clang's broadcast op to compute the output shape. Also, the broadcast in the clang definition seems to be confused by the case of a CPU `pred` and CUDA `a` and `b`.
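For intuition, the shape computation itself is just NumPy-style broadcasting over already-validated shape tuples. A minimal standalone sketch (not thunder's actual `compute_broadcast_shape` helper):

```python
def broadcast_shape(*shapes: tuple[int, ...]) -> tuple[int, ...]:
    """NumPy-style broadcast of shape tuples; scalar shapes broadcast away."""
    ndim = max((len(s) for s in shapes), default=0)
    # Left-pad every shape with 1s so they all have the same rank.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for dims in zip(*padded):
        non_one = {d for d in dims if d != 1}
        assert len(non_one) <= 1, f"shapes are not broadcastable: {shapes}"
        result.append(non_one.pop() if non_one else 1)
    return tuple(result)


print(broadcast_shape((), (4, 4), ()))  # (4, 4)
```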
thunder/core/prims.py (Outdated)
```diff
@@ -2832,7 +2858,7 @@ def _where_meta(pred: Number | TensorProxy, a: Number | TensorProxy, b: Number |
     # Determines output shape
     # NOTE Assumes at least one of pred, a, and b is a TensorProxy because of prior check for Number x Number x Number
     shapes = tuple(x.shape for x in (pred, a, b) if isinstance(x, TensorProxy))
-    resultshape = shapes[0]
+    resultshape = compute_broadcast_shape(*shapes)
```
We don't typically support broadcasting with the primitives. In fact, just above this there's a check that all the tensors have the same shape.
There is a check to ensure all tensors have the same shape, but it allows CPU scalar tensors (`treat_cpu_scalar_tensors_as_numbers` defaults to `True`).
Broadcasting behavior must happen here inside the primitive if we allow passing CPU scalar tensors. We allow this because executors like nvFuser can work with CPU scalar tensors as inputs to CUDA kernels.
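For example, PyTorch itself accepts a 0-dim CPU tensor as the condition alongside CUDA operands (a small repro, assuming a CUDA device is available):

```python
import torch

cond = torch.tensor(True)  # 0-dim CPU bool tensor: the "CPU scalar tensor" case
a = torch.randn(4, 4, device="cuda")
b = torch.randn(4, 4, device="cuda")

out = torch.where(cond, a, b)  # the CPU scalar condition is permitted
print(out.shape, out.device)   # torch.Size([4, 4]) cuda:0
```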
The broadcast helper should be returning the correct size here: `(4, 4)`. If it isn't, we should fix the helper. I think the issue is here: `thunder/core/prims.py`, line 2835 in d0f6474.

If we want to support `where` with numbers or CPU tensors that are treated as numbers, then we need to select the shape of the tensor that is not treated as a number (if any such tensor exists).
When the `condition` of `torch.where` is a CPU scalar tensor, the broadcast doesn't occur, as we can see in the following trace:

```python
def computation(cond, a, b):
  # cond: "cpu b8[]"
  # a: "cpu f32[4, 4]"
  # b: "cpu f32[]"

  # repro.py:11: return torch.where(cond, a, b)
  t0 = ltorch.where(cond, a, b)  # t0: "cpu f32[]"
    # t0 = prims.where(cond, a, b)  # t0: "cpu f32[]"
  return (t0,)
```

This causes a mismatch between runtime results and trace outputs. This trace is obtained by running the script below:

```python
import torch
import thunder


def f(cond, a, b):
    return torch.where(cond, a, b)


if __name__ == "__main__":
    jitted = thunder.jit(f)
    with torch.device("cpu"):
        cond = torch.tensor(True, dtype=torch.bool)
        a = torch.tensor([
            [ 6.5077,  7.6115,  8.6770,  6.2419],
            [ 7.1474, -4.7458, -8.1032,  8.7696],
            [ 7.9486,  5.3679, -7.8606, -5.3753],
            [-5.8697,  7.7565, -1.9403,  4.1512],
        ])
        b = torch.tensor(-2.6034)
        out = jitted(cond, a, b)
    print(thunder.last_traces(jitted)[0])
```

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Force-pushed from ffb925c to 5c5af10
I simplified this branch by just filtering out CPU scalar tensors.
```python
resultshape = shapes[0]
# It's possible that `pred` is a CPU bool scalar tensor and either or both of `a` and `b` are a CUDA tensor.
# In that case, `shapes[0]`, i.e., `pred.shape` should not be the result shape.
if (
```
This is still not correct. For example:

```python
cond = torch.tensor(True)
a = torch.tensor(1)
b = torch.randn((2, 2))
torch.where(cond, a, b)  # tensor([[1., 1.], [1., 1.]])
```

the shape would be the third shape, not the second.
The same thing happens on CUDA devices: the tensors broadcast, although on CUDA devices the tensors will already have been broadcast before reaching the primitive.
What you can do is take this line from the `check_same_shape` function:

```python
non_scalar_shapes = tuple(x.shape for x in args if isinstance(x, TensorProxy) and not is_cpu_scalar_tensor(x))
```

Then:

```python
if len(non_scalar_shapes) > 0:
    resultshape = non_scalar_shapes[0]
else:
    resultshape = shapes[0]
```
And I believe that will work. We should add a test for all the inputs being "CPU number tensors," and then for each of `cond`, `a`, or `b` being a tensor that isn't a CPU number, verifying the shape is the same.
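A sketch of what such a test could look like (hypothetical test name and parametrization, using `thunder.jit` as in the repro above):

```python
import pytest
import torch
import thunder


@pytest.mark.parametrize("non_scalar", [None, "cond", "a", "b"])
def test_where_cpu_scalar_tensor_shapes(non_scalar):
    # Start with every input as a 0-dim CPU tensor ("CPU number tensor").
    inputs = {
        "cond": torch.tensor(True),
        "a": torch.tensor(1.0),
        "b": torch.tensor(2.0),
    }
    # Promote one input (if any) to a tensor that isn't a CPU number.
    if non_scalar == "cond":
        inputs["cond"] = torch.ones(2, 2, dtype=torch.bool)
    elif non_scalar is not None:
        inputs[non_scalar] = torch.randn(2, 2)

    jitted = thunder.jit(lambda c, a, b: torch.where(c, a, b))
    actual = jitted(inputs["cond"], inputs["a"], inputs["b"])
    expected = torch.where(inputs["cond"], inputs["a"], inputs["b"])
    assert actual.shape == expected.shape
```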
I think my initial approach was nice.
We have three shape tuples, each with any number of dimensions and any dimension as input. Then we have a progressive number of checks or transformations to restrict what shapes are allowed right before computing the result shape:

- `thunder.torch.where` calls `thunder.clang.where`
- Inside `thunder.clang.where`: broadcast input shapes if possible, except for CPU scalar tensors
- `thunder.clang.where` calls `thunder.prims.where` (`where_meta`)
- Now our three tuples should all have an equal number of dimensions and equal dimensions, except for CPU scalar tensors
- Inside `where_meta`: there's a `check_same_shape` check, after which we certainly have shape tuples with an equal number of dimensions and equal dimensions, except for CPU scalar tensors
- Now here comes the bug: using `resultshape = shapes[0]` is incorrect because it may include a CPU scalar tensor shape.

Masaki's initial approach was: we had all shapes checked at this point, so let's reuse broadcasting behavior to compute the result shape, ignoring possible CPU scalar tensors. It's concise and correct.
Another possible approach is to filter out CPU scalar tensors in `thunder/core/prims.py`, line 2834 in bc3a5d0, replacing

```python
shapes = tuple(x.shape for x in (pred, a, b) if isinstance(x, TensorProxy))
```

with

```python
shapes = tuple(x.shape for x in (pred, a, b) if isinstance(x, TensorProxy) and not is_cpu_scalar_tensor(x))
```

and accounting for the case when `shapes` is empty: `resultshape = shapes[0] if shapes else shapes`. It's also concise and correct.
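To make the two options concrete, here is a toy model of both on the `(2, 2)` example from earlier (plain tuples stand in for `TensorProxy` shapes, and the flags stand in for `is_cpu_scalar_tensor`):

```python
shapes = ((), (), (2, 2))          # cond, a, b from the example above
is_cpu_scalar = (True, True, False)

# Approach 1 (initial): reuse broadcasting over the already-validated shapes.
ndim = max(len(s) for s in shapes)
padded = [(1,) * (ndim - len(s)) + s for s in shapes]
print(tuple(max(dims) for dims in zip(*padded)))  # (2, 2)

# Approach 2 (filtering): drop CPU scalar tensor shapes, then take the first.
non_scalar = tuple(s for s, f in zip(shapes, is_cpu_scalar) if not f)
print(non_scalar[0] if non_scalar else shapes[0])  # (2, 2)
```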
> We should add a test for all the inputs being "CPU number tensors," and then each of cond, a, or b being tensors that aren't CPU numbers, and verifying the shape is the same.
We don't have any tests for meta function consistency. The tests are being added in #2069.
In order to fix the shape mismatch between runtime and trace, this PR excludes `pred.shape` from the result shape candidates if `pred` is a CPU scalar tensor.

When the `condition` of `torch.where` is a CPU scalar tensor, the broadcast doesn't occur, as we can see in the trace shown earlier. This causes a mismatch between runtime results and trace outputs.
This trace is obtained by running the script shown earlier. If the tensors are on CUDA, the trace is different; the difference between the CPU and CUDA traces would be attributed to `thunder/clang/__init__.py`, lines 1410 to 1417 in d0f6474.
Rel: #2069
Next step: `resultdevice` (`thunder/core/prims.py`, lines 2820 to 2823 in d0f6474), for when `pred` is a CPU scalar tensor and either `a` or `b` is a CUDA tensor.