
prims.where computes broadcast shape to avoid runtime-trace mismatch #2135

Open

crcrpar wants to merge 5 commits into main from fix-where-with-cpu-scalar-tensor-cond

Conversation

@crcrpar (Collaborator) commented May 25, 2025

In order to fix the shape mismatch between runtime and trace, this PR excludes `pred.shape` from the result shape candidates if `pred` is a CPU scalar tensor.

When the condition of torch.where is a CPU scalar tensor, the broadcast doesn't occur as we can see in the following trace:

```python
def computation(cond, a, b):
  # cond: "cpu b8[]"
  # a: "cpu f32[4, 4]"
  # b: "cpu f32[]"

  # repro.py:11:             return torch.where(cond, a, b)
  t0 = ltorch.where(cond, a, b)  # t0: "cpu f32[]"
    # t0 = prims.where(cond, a, b)  # t0: "cpu f32[]"
  return (t0,)
```

This causes a mismatch between runtime results and trace outputs.

This trace is obtained by running the script below:

```python
import torch
import thunder

def f(cond, a, b):
    return torch.where(cond, a, b)

if __name__ == "__main__":
    jitted = thunder.jit(f)
    with torch.device("cpu"):
        cond = torch.tensor(True, dtype=torch.bool)
        a = torch.tensor([
            [ 6.5077,  7.6115,  8.6770,  6.2419],
            [ 7.1474, -4.7458, -8.1032,  8.7696],
            [ 7.9486,  5.3679, -7.8606, -5.3753],
            [-5.8697,  7.7565, -1.9403,  4.1512],
        ])
        b = torch.tensor(-2.6034)
    out = jitted(cond, a, b)

    print(thunder.last_traces(jitted)[0])
```
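For comparison, eager PyTorch broadcasts in this case, so the runtime result has shape (4, 4) even though the trace above records `f32[]`. A quick standalone check (not part of the original repro):

```python
import torch

cond = torch.tensor(True)      # 0-dim CPU bool tensor
a = torch.randn(4, 4)
b = torch.tensor(-2.6034)      # 0-dim CPU float tensor

# Eager broadcasting ignores that cond and b are 0-dim: the result follows a's shape.
print(torch.where(cond, a, b).shape)  # torch.Size([4, 4])
```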

If the tensors are on CUDA, then the trace is:

```python
def computation(cond, a, b):
  # cond: "cuda:0 b8[]"
  # a: "cuda:0 f32[4, 4]"
  # b: "cuda:0 f32[]"

  # repro.py:11:             return torch.where(cond, a, b)
  t2 = ltorch.where(cond, a, b)  # t2: "cuda:0 f32[4, 4]"
    # t0 = prims.broadcast_in_dim(cond, (4, 4), ())  # t0: "cuda:0 b8[4, 4]"
    # t1 = prims.broadcast_in_dim(b, (4, 4), ())  # t1: "cuda:0 f32[4, 4]"
    # t2 = prims.where(t0, a, t1)  # t2: "cuda:0 f32[4, 4]"
  return (t2,)
```

The difference between the CPU and CUDA traces can be attributed to the `_maybe_broadcast` helper used by `thunder.clang`'s broadcasting, which skips CPU scalar tensors:

```python
def _maybe_broadcast(x, shape):
    if treat_cpu_scalar_tensors_as_numbers and utils.is_cpu_scalar_tensor(x):
        return x
    if hasattr(x, "shape"):
        if not utils.same_shape(x.shape, common_shape):
            return expand(x, common_shape)
    return x
```
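For context, the `utils.is_cpu_scalar_tensor` check referred to above boils down to something like the following. This is a hypothetical sketch inferred from the attributes used later in this thread (`numel`, `device != devices.cpu`); the import paths and exact definition are assumptions, not the actual thunder helper:

```python
from thunder.core import devices
from thunder.core.proxies import TensorProxy


def is_cpu_scalar_tensor(x) -> bool:
    # A single-element tensor proxy that lives on the CPU, e.g. the proxies created for
    # torch.tensor(True) or torch.tensor(-2.6034) in the repro above.
    return isinstance(x, TensorProxy) and x.device == devices.cpu and x.numel == 1
```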

Rel: #2069

next step:

  • Fix `resultdevice` when `pred` is a CPU scalar tensor and either `a` or `b` is a CUDA tensor:

```python
resultdevice = devices.cpu
devices_ = tuple(x.device for x in (pred, a, b) if isinstance(x, TensorProxy))
if len(devices_) > 0:
    resultdevice = devices_[0]
```

@crcrpar crcrpar requested review from mruberry, lantiga and t-vi as code owners May 25, 2025 12:09
@crcrpar crcrpar changed the title [where] Force broadcast of cpu scalar tensors Force broadcast cpu scalar condition in where to avoid runtime-trace mismatch May 25, 2025
@crcrpar crcrpar changed the title Force broadcast cpu scalar condition in where to avoid runtime-trace mismatch prims.where computes broadcast shape to avoid runtime-trace mismatch May 25, 2025
@crcrpar (Collaborator, Author) commented May 26, 2025

```diff
diff --git a/thunder/core/prims.py b/thunder/core/prims.py
index dec77bf7..4fa12fcc 100644
--- a/thunder/core/prims.py
+++ b/thunder/core/prims.py
@@ -2846,7 +2846,10 @@ def _where_meta(pred: Number | TensorProxy, a: Number | TensorProxy, b: Number |
     resultdevice = devices.cpu
     devices_ = tuple(x.device for x in (pred, a, b) if isinstance(x, TensorProxy))
     if len(devices_) > 0:
-        resultdevice = devices_[0]
+        if len(devices_) == 1 or any(not isinstance(pred, TensorProxy) or pred.numel != 1 or pred.device != devices.cpu):
+            resultdevice = devices_[0]
+        else:
+            result_device = devices_[1]

     # Determines result dtype
     numbertype, tensordtype = utils.check_same_dtype(a, b)
```

^^^ one awkward implementation for result device
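For comparison, a tidier sketch of the same device-selection idea, written as a hypothetical helper rather than the code in this PR (it assumes the `TensorProxy`, `utils.is_cpu_scalar_tensor`, and `devices.cpu` names from the surrounding prims.py context):

```python
def _where_result_device(pred, a, b):
    # Hypothetical helper (not in thunder): prefer the device of the first tensor operand
    # that is not a CPU scalar tensor, then fall back to the first tensor device, then CPU.
    tensor_args = tuple(x for x in (pred, a, b) if isinstance(x, TensorProxy))
    non_scalar = tuple(x for x in tensor_args if not utils.is_cpu_scalar_tensor(x))
    if non_scalar:
        return non_scalar[0].device
    if tensor_args:
        return tensor_args[0].device
    return devices.cpu
```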

@t-vi (Collaborator) commented May 26, 2025

> In order to fix the shape mismatch between runtime and trace, this PR clones compute_broadcast_shape of thunder.clang in thunder.core.prims and calls prims.compute_broadcast_shape in the existing compute_broadcast_shape.

I don't think broadcasting in prims is the right thing to do. Why would we not revise the clang op to insert the missing broadcast ops?

@crcrpar (Collaborator, Author) commented May 26, 2025

This PR does not add broadcasting inside the prim. It uses a piece of clang's broadcast op to compute the output shape. Also, the broadcast in the clang definition seems to get confused by the case of a CPU `pred` and CUDA `a` and `b`.

```diff
@@ -2832,7 +2858,7 @@ def _where_meta(pred: Number | TensorProxy, a: Number | TensorProxy, b: Number |
     # Determines output shape
     # NOTE Assumes at least one of pred, a, and b is a TensorProxy because of prior check for Number x Number x Number
     shapes = tuple(x.shape for x in (pred, a, b) if isinstance(x, TensorProxy))
-    resultshape = shapes[0]
+    resultshape = compute_broadcast_shape(*shapes)
```
A collaborator commented on the diff above:

We don't typically support broadcasting with the primitives. In fact, just above this there's a check that all the tensors have the same shape.

Another collaborator replied:

There is a check to ensure all tensors have the same shape, but it allows CPU scalar tensors (with True as the default value for treat_cpu_scalar_tensors_as_numbers).

Broadcasting behavior must happen here inside the primitive if we allow passing CPU scalar tensors. We allow this because executors like nvFuser can work with CPU scalar tensors as inputs to CUDA kernels.

@mruberry (Collaborator) commented:

The broadcast helper should be returning the correct size here: (4, 4). If it isn't, we should fix the helper.

I think the issue is here:

```python
resultshape = shapes[0]
```

If we want to support where with numbers or CPU tensors that are treated as numbers, then we need to select the shape of the tensor that is not treated as a number (if any such tensor exists).

crcrpar added 4 commits May 30, 2025 00:19
@crcrpar crcrpar force-pushed the fix-where-with-cpu-scalar-tensor-cond branch from ffb925c to 5c5af10 on May 29, 2025 15:19
@crcrpar (Collaborator, Author) commented May 29, 2025

I simplified this branch by just filtering out pred if it's a CPU bool scalar tensor.

```python
resultshape = shapes[0]
# It's possible that `pred` is a CPU bool scalar tensor and either or both of `a` and `b` are a CUDA tensor.
# In that case, `shapes[0]`, i.e., `pred.shape` should not be the result shape.
if (
```
A collaborator commented on the diff above:

This is still not correct. For example:

```python
cond = torch.tensor(True)
a = torch.tensor(1)
b = torch.randn((2, 2))
torch.where(cond, a, b)  # tensor([[1., 1.], [1., 1.]])
```

The shape would be the third shape, not the second.

The same thing happens on CUDA devices: the tensors broadcast. Although on CUDA devices the tensors will already have been broadcast before reaching the primitive.

What you can do is take this line from the check_same_shape function:

```python
non_scalar_shapes = tuple(x.shape for x in args if isinstance(x, TensorProxy) and not is_cpu_scalar_tensor(x))
```

Then:

```python
if len(non_scalar_shapes) > 0:
    resultshape = non_scalar_shapes[0]
else:
    resultshape = shapes[0]
```

And I believe that will work. We should add a test for all the inputs being "CPU number tensors," and then each of cond, a, or b being tensors that aren't CPU numbers, and verifying the shape is the same.
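A sketch of that kind of test, in pytest style with eager PyTorch as the reference; the test name and parametrization here are illustrative and are not the test added by this PR or by #2069:

```python
import pytest
import torch
import thunder


def f(cond, a, b):
    return torch.where(cond, a, b)


@pytest.mark.parametrize("non_scalar", [None, "cond", "a", "b"])
def test_where_shape_with_cpu_scalar_tensors(non_scalar):
    # Every input starts as a "CPU number tensor" (a 0-dim CPU tensor)...
    inputs = {
        "cond": torch.tensor(True),
        "a": torch.tensor(1.0),
        "b": torch.tensor(-2.0),
    }
    # ...then at most one of them is swapped for a tensor that is not a CPU scalar.
    if non_scalar == "cond":
        inputs["cond"] = torch.ones(4, 4, dtype=torch.bool)
    elif non_scalar in ("a", "b"):
        inputs[non_scalar] = torch.randn(4, 4)

    expected = f(inputs["cond"], inputs["a"], inputs["b"])
    actual = thunder.jit(f)(inputs["cond"], inputs["a"], inputs["b"])

    # The jitted result should match eager PyTorch in both shape and values.
    assert actual.shape == expected.shape
    torch.testing.assert_close(actual, expected)
```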

@crcrpar (Collaborator, Author) replied:

I think my initial approach was nice.

Another collaborator replied:

We have three shape tuples as input, each with any number of dimensions and arbitrary dimension sizes. Then a progressive series of checks and transformations restricts what shapes are allowed right before the result shape is computed:

  1. thunder.torch.where calls thunder.clang.where
  2. Inside thunder.clang.where: input shapes are broadcast where possible, except for CPU scalar tensors
  3. thunder.clang.where calls thunder.prims.where (where_meta)
  4. Now our three tuples should all have an equal number of dimensions and equal dimensions, except for CPU scalar tensors
  5. Inside where_meta: there's a check_same_shape check, after which we certainly have shape tuples with an equal number of dimensions and equal dimensions, except for CPU scalar tensors
  6. Now here comes the bug: using `resultshape = shapes[0]` is incorrect because `shapes[0]` may be a CPU scalar tensor's shape.
    Masaki's initial approach was: all shapes have been checked at this point, so let's reuse the broadcasting behavior to compute the result shape while ignoring possible CPU scalar tensors. It's concise and correct.

Another possible approach is to filter out CPU scalar tensors in

```python
shapes = tuple(x.shape for x in (pred, a, b) if isinstance(x, TensorProxy))
```

by replacing it with `shapes = tuple(x.shape for x in (pred, a, b) if isinstance(x, TensorProxy) and not is_cpu_scalar_tensor(x))` and accounting for the case when `shapes` is empty: `resultshape = shapes[0] if shapes else shapes`. It's also concise and correct.
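Spelled out, that alternative would look roughly like this (a sketch reusing the `is_cpu_scalar_tensor` and `TensorProxy` names from the snippets above):

```python
# Candidate shapes, ignoring CPU scalar tensors.
shapes = tuple(
    x.shape for x in (pred, a, b) if isinstance(x, TensorProxy) and not is_cpu_scalar_tensor(x)
)
# If every tensor operand is a CPU scalar tensor, the result is itself a scalar tensor,
# so the empty shape () is the right answer.
resultshape = shapes[0] if shapes else ()
```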

> We should add a test for all the inputs being "CPU number tensors," and then each of cond, a, or b being tensors that aren't CPU numbers, and verifying the shape is the same.

We don't have any tests for meta function consistency. The tests are being added in #2069.
