System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow.js): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Windows
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: N/A
- TensorFlow.js installed from (npm or script link): npm
- TensorFlow.js version: @tensorflow/tfjs-backend-webgpu ^4.21.0
- Browser version: Google Chrome 128.0.6613.114 (Official Build) (64-bit)
- TensorFlow.js Converter version: Not specified
Describe the current behavior
When multiplying two batched tensors with tf.matMul(r, s), where r and s both have shape [1493284, 3, 3], the WebGPU backend throws: "Dispatch size exceeds WebGPU limits in Y or Z dimension."
The error is raised in the backend's reshapeDispatch function while it handles the dispatch shape [1, 1, 1493284] produced for this matMul.
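The reported dispatch shape can be reproduced with a simplified model in which each dispatch axis counts the output elements mapped to it and divides by the work one workgroup covers along that axis. This is a sketch, not the backend's code: computeDispatchModel is a hypothetical name, and the layout (columns to X, rows to Y, batch to Z), the workgroup size [8, 8, 1] and the elements-per-thread [4, 4, 1] are assumptions chosen to match the reported shape. It shows that for 3x3 matrices the X and Y counts collapse to 1 while the entire batch lands on Z.

```javascript
// Simplified model of deriving a compute dispatch from an output shape.
// ASSUMPTION: the layout and the workgroup / elements-per-thread sizes are
// illustrative values, not read from the WebGPU backend.
function computeDispatchModel(outputShape, layout, workgroupSize, elementsPerThread) {
  // Number of output elements mapped to one dispatch axis (1 if the axis is unused).
  const axisExtent = (dims) =>
    dims === undefined ? 1 : dims.reduce((acc, d) => acc * outputShape[d], 1);
  // Workgroups per axis = extent / (threads per workgroup * elements per thread), rounded up.
  return ['x', 'y', 'z'].map((axis, i) =>
    Math.ceil(axisExtent(layout[axis]) / (workgroupSize[i] * elementsPerThread[i])));
}

// WebGPU's default value for the maxComputeWorkgroupsPerDimension limit.
const DEFAULT_MAX_WORKGROUPS_PER_DIMENSION = 65535;

const outputShape = [1493284, 3, 3]; // [batch, rows, cols]
const layout = { x: [2], y: [1], z: [0] }; // hypothetical: cols -> X, rows -> Y, batch -> Z
const dispatch = computeDispatchModel(outputShape, layout, [8, 8, 1], [4, 4, 1]);

console.log(dispatch); // [ 1, 1, 1493284 ]
console.log(dispatch.map((d) => d <= DEFAULT_MAX_WORKGROUPS_PER_DIMENSION)); // [ true, true, false ]
```

Under these assumptions only the Z entry exceeds the per-dimension limit, which matches the shape named in the error path.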
Describe the expected behavior
tf.matMul(r, s) should complete and return a tensor of shape [1493284, 3, 3] instead of throwing an error about dispatch size limits.
Standalone code to reproduce the issue
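```javascript
import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgpu'; // registers the 'webgpu' backend
await tf.setBackend('webgpu'); // requires an ES module (top-level await)
const r = tf.zeros([1493284, 3, 3]);
const s = tf.zeros([1493284, 3, 3]);
const l = tf.matMul(r, s); // throws: Dispatch size exceeds WebGPU limits in Y or Z dimension.
```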
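Until the backend handles this shape, one possible workaround is to run the matMul on slices of the batch that each stay under the limit and concatenate the results. The sketch below assumes, as the reported dispatch [1, 1, 1493284] suggests, that each batch item costs one workgroup along Z, and it uses WebGPU's default limit of 65535; batchRanges is a hypothetical helper. The TensorFlow.js side (tf.slice, tf.matMul, tf.concat) is kept in comments so the helper runs on its own.

```javascript
// Hypothetical helper: split `batch` items into contiguous [start, size] ranges
// of at most `maxPerChunk` items each.
// ASSUMPTION: one batch item needs one workgroup along Z, and the device uses
// WebGPU's default limit of 65535 workgroups per dimension.
function batchRanges(batch, maxPerChunk = 65535) {
  const ranges = [];
  for (let start = 0; start < batch; start += maxPerChunk) {
    ranges.push([start, Math.min(maxPerChunk, batch - start)]);
  }
  return ranges;
}

// Sketch of the TensorFlow.js side (not executed here; r and s as in the repro):
//   const parts = batchRanges(r.shape[0]).map(([start, size]) =>
//     tf.tidy(() => tf.matMul(
//       tf.slice(r, [start, 0, 0], [size, -1, -1]),
//       tf.slice(s, [start, 0, 0], [size, -1, -1]))));
//   const l = tf.concat(parts, 0);
//   tf.dispose(parts);

const ranges = batchRanges(1493284);
console.log(ranges.length); // 23
console.log(ranges[ranges.length - 1]); // [ 1441770, 51514 ]
```

Slicing adds extra copies of r and s, so this trades memory and some speed for staying inside the per-dimension dispatch limit.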
Other info / logs
The error is triggered by this assertion in the WebGPU backend code:

```javascript
tf.util.assert(
  dispatch[0] > MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE &&
    layout.y === undefined && layout.z === undefined,
  function () { return 'Dispatch size exceeds WebGPU limits in Y or Z dimension.'; });
```
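This assertion is in the reshapeDispatch function, which receives the dispatch shape [1, 1, 1493284] generated for the batched matMul. Its Z dimension (1493284) exceeds WebGPU's default per-dimension limit of 65535 workgroups (maxComputeWorkgroupsPerDimension). As written, the assertion only passes when the oversized dimension is X (dispatch[0]) and the dispatch layout has no Y or Z axes. Here dispatch[0] is 1 and the oversized value is in Z, so the assertion fails.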
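The guard can be modeled in plain JavaScript to show why this dispatch is rejected while an X-only dispatch would get past it. checkDispatch and its return strings are hypothetical; the early "fits" branch and the fixed 65535 limit are assumptions (the backend may read the limit from the GPU device), and only the middle condition mirrors the quoted assertion.

```javascript
// ASSUMPTION: WebGPU's default maxComputeWorkgroupsPerDimension; the real
// backend may use a device-reported value instead.
const MAX_PER_DIMENSION = 65535;

// Simplified model of the reshapeDispatch guard (hypothetical helper).
function checkDispatch(dispatch, layout, maxPerDim = MAX_PER_DIMENSION) {
  // Assumed early exit: a dispatch that already fits needs no reshaping.
  if (dispatch.every((d) => d <= maxPerDim)) {
    return 'fits';
  }
  // Mirrors the quoted assertion: only an oversized X with an X-only layout passes.
  if (dispatch[0] > maxPerDim && layout.y === undefined && layout.z === undefined) {
    return 'reshapable';
  }
  return 'Dispatch size exceeds WebGPU limits in Y or Z dimension.';
}

// Batched matMul from this report: the batch lands in Z, so the guard fails.
console.log(checkDispatch([1, 1, 1493284], { x: [2], y: [1], z: [0] }));
// An X-only dispatch of the same size would pass the guard.
console.log(checkDispatch([1493284, 1, 1], { x: [0, 1, 2] })); // reshapable
```

The first call prints the same message as the reported error, since dispatch[0] is 1 and the first condition of the assertion is already false.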
The full implementation of reshapeDispatch is in the bundled build at https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-backend-webgpu/dist/tf-backend-webgpu.js (around line 1252).