Pallas Kernel Expected Output Shape Error Using Grids On TPU #25379
Comments
I think this is actually related to the handling of 1D arrays in Pallas. Usually, arrays in our kernels are 2D. I'm going to need some time to track down what exactly the problem is, but your code should run if you add a leading (1,) dimension to the block shape and input array shape.
OK, that resolved the problem! However, I notice that training is much slower when using the kernels. The kernels perform bitwise operations. Could it be that these bitwise operations are being executed on the CPU, even though I wrapped them in a kernel and further wrapped that in torch_xla? Or do I need to change something in my function for it to be better optimized for TPUs?
Pallas lowers to Mosaic, which doesn't have a CPU backend, so it's impossible for it to run on CPU unintentionally. The only way to run on CPU is to pass interpret=True, but in that case you wouldn't see Mosaic error messages.

Without seeing the rest of your code, my guess is that it's related to fusion. When you run your code without custom kernels, XLA will automatically try to fuse neighboring ops together into its own generated kernels so it can avoid memory copies between HBM <-> VMEM. However, when you use a pallas_call, XLA doesn't know how to fuse into a custom kernel, so you get stuck with redundant copies. Since your kernel is element-wise, it's probably memory-bound, meaning that most of the time spent in the kernel is waiting for memory copies to finish rather than doing computation. So fusion would actually help a lot here if this is the case.

The solution to this problem is to fold the neighboring operations into the kernel. For example, if you're rounding followed by a bf16 matmul, then you should do the rounding inside the matmul kernel and not as a standalone kernel.
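To illustrate the fusion advice, here is a sketch of a kernel that folds element-wise rounding into a bf16 matmul instead of running the rounding as a standalone kernel (hypothetical shapes and names, not the poster's actual code; `interpret=True` so it runs without a TPU):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def round_matmul_kernel(x_ref, y_ref, o_ref):
    # Fold the element-wise rounding into the matmul kernel so the
    # rounded values never make a separate round trip through HBM.
    x = jnp.round(x_ref[...]).astype(jnp.bfloat16)
    y = y_ref[...].astype(jnp.bfloat16)
    o_ref[...] = jnp.dot(x, y, preferred_element_type=jnp.float32)

def round_then_matmul(x, y):
    m = x.shape[0]
    n = y.shape[1]
    return pl.pallas_call(
        round_matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
        interpret=True,  # remove on a real TPU
    )(x, y)

x = jnp.linspace(-2.0, 2.0, 8 * 16).reshape(8, 16)
y = jnp.ones((16, 8), jnp.float32)
out = round_then_matmul(x, y)
```

For simplicity this sketch processes the whole arrays as a single block; a real kernel would add a grid and BlockSpecs as in the rest of this thread.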
Description
🐛 Bug
I am trying to write a custom Pallas kernel for TPU. I am using blocking to keep my kernel from going OOM. However, when I use grids, the expected output and input shapes don't match: the chunking / splitting of the input does not behave as expected. I checked that my code has the right shapes, grid, and indexing method, yet the kernel itself receives the wrong input.
I think it may be a bug in how the TPU handles chunking in Pallas kernels, but I am not sure. Any help would be appreciated!
To Reproduce
I am attaching tests for replication. You can see that only the tests with original input tensors larger than the block size fail.
My Kernel Code
Debug Output And Stack Trace
RuntimeError: Bad StatusOr access: INVALID_ARGUMENT: Mosaic failed to compile TPU kernel: Failed to verify layout for Mosaic kernel operand 1: XLA layout does not match MLIR layout for an operand of shape f32[192]: expected {0:T(128)}, got {0:T(256)}
But printing the values I provide to the pallas_call
Out shape: ShapeDtypeStruct(shape=(192,), dtype=float32), grid: 2, block_shape: (128,)
Expected behavior
The tests should not fail. When run on GPU they all pass.
Environment
Additional context
The tests for easy replication:
test_pallas_tiling.py.zip
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.33
jaxlib: 0.4.33
numpy: 1.26.4
python: 3.10.12 (main, Nov 6 2024, 20:22:13) [GCC 11.4.0]
jax.devices (4 total, 4 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0) TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-aacdf60c-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')