Pallas Kernel Expected Output Shape Error Using Grids On TPU #25379

Open
dshalem opened this issue Dec 10, 2024 · 3 comments
dshalem commented Dec 10, 2024

Description

🐛 Bug

I am trying to write a custom Pallas kernel to use on TPU. I am using blocking to keep the kernel from going OOM. However, when I use grids, I get errors about the expected output and input shapes. It seems that the chunking/splitting of the input does not behave as expected: I checked that my code has the right shapes, grid, and index map, yet the kernel itself receives the wrong input.

I think it may be a bug in how TPU handles chunking in Pallas kernels, but I am not sure. Any help would be appreciated!

To Reproduce

I am attaching tests for replication. You can see that only the tests whose original input tensors are larger than the block size fail.

My Kernel Code

  import jax
  from jax.experimental import pallas as pl

  @jax.jit
  def round_down_and_up(x: jax.Array) -> (jax.Array, jax.Array):
      """
      Simplified wrapper for kernel execution, treating x as a vector.
      Handles explicit padding and alignment with TPU tiling constraints.
      """
      block_size = 128
      # padded_length = (original_length + block_size - 1) // block_size * block_size
      #
      # # Explicitly pad the input tensor
      # if original_length != padded_length:
      #     x = jnp.pad(x, (0, padded_length - original_length), mode="constant", constant_values=0)
  
      # Define block shape and grid
      block_shape = (128,)  # TPU requires blocks divisible by 128 for f32
      grid = (len(x) + block_size - 1) // block_shape[0]
  
      # Define BlockSpec
      block_spec = pl.BlockSpec(block_shape=block_shape, index_map=lambda i: (i,))
  
      # Define output shape
      out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  
      # # Debugging: Verify padded shape, block shape, and grid
      # jax.debug.print("Input Length: {input_length}, Padded Length: {padded_length}, "
      #                 "BlockShape: {block_shape}, Grid: {grid}",
      #                 input_length=original_length, padded_length=padded_length,
      #                 block_shape=block_shape, grid=grid)
  
      print(f"Out shape: {out_shape}, grid: {grid}, block_shape: {block_shape}")
      # Call the kernel
      x_low, x_high = pl.pallas_call(
          round_down_and_up_bfloat16_kernel,
          out_shape=(out_shape, out_shape),
          grid=(grid,),
          in_specs=(block_spec,),  # Input tiling specification
          out_specs=(block_spec, block_spec),  # Output tiling specification
      )(x)
  
      return x_low, x_high

Debug Output And Stack Trace

RuntimeError: Bad StatusOr access: INVALID_ARGUMENT: Mosaic failed to compile TPU kernel: Failed to verify layout for Mosaic kernel operand 1: XLA layout does not match MLIR layout for an operand of shape f32[192]: expected {0:T(128)}, got {0:T(256)}

But printing the values I pass to pallas_call shows:
Out shape: ShapeDtypeStruct(shape=(192,), dtype=float32), grid: 2, block_shape: (128,)

Expected behavior

The tests should not fail; when run on GPU, they all pass.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU
  • torch_xla version: 2.4

Additional context

The tests for easy replication:
test_pallas_tiling.py.zip

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.33
jaxlib: 0.4.33
numpy: 1.26.4
python: 3.10.12 (main, Nov 6 2024, 20:22:13) [GCC 11.4.0]
jax.devices (4 total, 4 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0) TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-aacdf60c-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')

justinjfu (Collaborator) commented

I think this is actually related to handling of 1D arrays in Pallas. Usually, arrays in our kernel are 2D.

I'm going to need some time to track down what exactly the problem is, but your code should run if you add a leading (1,) dimension to the block shape and input array shape.
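
A minimal sketch of that suggestion, assuming the 1D input is reshaped to (1, N) before the call (call_with_2d_blocks is a hypothetical helper name, not code from this thread):

import jax
from jax.experimental import pallas as pl

def call_with_2d_blocks(kernel, x, block_size=128):
    # Reshape the 1D vector to (1, N) and tile along the trailing dimension
    # only, so every block (and the array itself) is 2D.
    x2d = x.reshape(1, -1)
    block_spec = pl.BlockSpec(block_shape=(1, block_size), index_map=lambda i, j: (i, j))
    out_shape = jax.ShapeDtypeStruct(x2d.shape, x2d.dtype)
    grid = (1, (x2d.shape[1] + block_size - 1) // block_size)
    lo, hi = pl.pallas_call(
        kernel,
        out_shape=(out_shape, out_shape),
        grid=grid,
        in_specs=(block_spec,),
        out_specs=(block_spec, block_spec),
    )(x2d)
    return lo.reshape(x.shape), hi.reshape(x.shape)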

dshalem (Author) commented Dec 11, 2024

> I think this is actually related to handling of 1D arrays in Pallas. Usually, arrays in our kernel are 2D.
>
> I'm going to need some time to track down what exactly the problem is, but your code should run if you add a leading (1,) dimension to the block shape and input array shape.

OK, that resolved the problem!

However, I notice that training is much slower when using the kernels. The kernels perform bitwise operations. Could it be that these bitwise operations are being executed on the CPU, even though I wrapped them in a kernel and further wrapped that in torch_xla? Or do I need to change something in my function to make it better optimized for TPUs?

@jax.jit
def round_down_and_up(x: jax.Array) -> (jax.Array, jax.Array):
    """
    Simplified wrapper for kernel execution, treating x as a vector.
    Handles explicit padding and alignment with TPU tiling constraints.
    """
    block_size = 512
    # Define block shape and grid
    # adding leading 1, to block_shape to make it 2D. TPU requires blocks divisible by 128 for f32
    block_shape = (1, block_size)

    grid = (x.shape[1] + block_size - 1) // block_shape[1]
    grid = (1, grid)

    # Define BlockSpec
    index_map_for_2d = lambda i, j: (i, j)
    block_spec = pl.BlockSpec(block_shape=block_shape, index_map=index_map_for_2d)

    # Define output shape
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)

    # print(f"Input shape: {x.shape}")
    # print(f"Out shape: {out_shape}, grid: {grid}, block_shape: {block_shape}")
    #
    # print(f"Out shape: {out_shape}, grid: {grid}, block_shape: {block_shape}")
    # Call the kernel
    x_low, x_high = pl.pallas_call(
        round_down_and_up_bfloat16_kernel,
        out_shape=(out_shape, out_shape),
        # out_shape = (jax.ShapeDtypeStruct(x.shape, x.dtype), jax.ShapeDtypeStruct(x.shape, x.dtype)),
        # grid=(grid,),
        grid=grid,
        in_specs=(block_spec,),  # Input tiling specification
        out_specs=(block_spec, block_spec),  # Output tiling specification
    )(x)

    return x_low, x_high
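
For reference, a hypothetical call of the function above (assuming round_down_and_up_bfloat16_kernel is defined elsewhere and the input is the 1D vector reshaped to (1, N), with N a multiple of block_size):

import jax.numpy as jnp

x = jnp.ones((1, 2048), dtype=jnp.float32)  # (1, N) input; N divisible by 512
x_low, x_high = round_down_and_up(x)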

justinjfu (Collaborator) commented

Pallas lowers to Mosaic which doesn't have a CPU backend, so it's impossible that it runs on CPU unintentionally. The only way to run on CPU is to pass in interpret=True, but in that case you wouldn't see Mosaic error messages.

Without seeing the rest of your code, my guess is that it's related to fusion. When you run your code without custom kernels, XLA will automatically try to fuse neighboring ops into its own generated kernels so it can avoid memory copies between HBM <-> VMEM. However, when you use a pallas_call, XLA doesn't know how to fuse into a custom kernel, so you get stuck with redundant copies. Since your kernel is element-wise, it's probably memory-bound, meaning that most of the time in the kernel is spent waiting for memory copies to finish rather than actually doing computation. So fusion would help a lot here, if this is the case.

The solution to this problem is to fold the neighboring operations into the kernel. For example, if you have a rounding followed by a bf16 matmul, then you should do the rounding inside the matmul kernel rather than as a standalone kernel.
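
For illustration, a minimal sketch of such a fused kernel (not code from this thread: the names are hypothetical, the plain bf16 downcast stands in for the actual rounding logic, and the single-block pallas_call assumes both operands fit in VMEM):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def fused_round_matmul_kernel(x_ref, w_ref, o_ref):
    # Round/downcast inside the kernel so the rounded values never take an
    # extra round trip through HBM before the matmul consumes them.
    x_bf16 = x_ref[...].astype(jnp.bfloat16)
    w_bf16 = w_ref[...].astype(jnp.bfloat16)
    o_ref[...] = jnp.dot(x_bf16, w_bf16, preferred_element_type=jnp.float32)

@jax.jit
def fused_round_matmul(x, w):
    return pl.pallas_call(
        fused_round_matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((x.shape[0], w.shape[1]), jnp.float32),
    )(x, w)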
