Skip to content

Commit

Permalink
[JS/WebGPU] Fixed bugs in inputs validation of Resize (#21955)
Browse files Browse the repository at this point in the history
- 'scales' and 'sizes' may be empty tensor, make sure it's 1D tensor and
non-empty
- Make sure 'scales' and 'sizes' if present its length is non-zero
  • Loading branch information
Honry authored Oct 5, 2024
1 parent b5ef855 commit 39c8b37
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/resize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,12 @@ const validateInputs = (
throw new Error('Resize requires RoI input to be specified when coordinateTransformMode is tfCropAndResize');
}

if (scalesInputIndex > 0 && inputs.length > scalesInputIndex && inputs[scalesInputIndex].dims.length > 0) {
if (
scalesInputIndex > 0 &&
inputs.length > scalesInputIndex &&
inputs[scalesInputIndex].dims.length === 1 &&
inputs[scalesInputIndex].dims[0] > 0
) {
inputs[scalesInputIndex].getFloat32Array().forEach((value) => scales.push(value));
if (
scales.length !== 0 &&
Expand All @@ -127,18 +132,23 @@ const validateInputs = (
updateScales(scales, attributes.axes, rank).forEach((value, index) => (scales[index] = value));
}
}
if (sizesInputIndex > 0 && inputs.length > sizesInputIndex) {
if (
sizesInputIndex > 0 &&
inputs.length > sizesInputIndex &&
inputs[sizesInputIndex].dims.length === 1 &&
inputs[sizesInputIndex].dims[0] > 0
) {
inputs[sizesInputIndex].getBigInt64Array().forEach((value) => sizes.push(Number(value)));
if (sizes.length !== rank || (opsetVersion >= 18 && sizes.length === attributes.axes.length)) {
if (sizes.length !== 0 && sizes.length !== rank && opsetVersion >= 18 && sizes.length !== attributes.axes.length) {
throw new Error('Resize requires sizes input size to be same as input rank or axes size for opset 18 and up');
}
}

if (attributes.axes.length > 0) {
if (scales.length !== attributes.axes.length) {
if (scales.length !== 0 && scales.length !== attributes.axes.length) {
throw new Error('Resize requires "scales" input size to be of axes rank when axes attributes is specified');
}
if (sizes.length !== attributes.axes.length) {
if (sizes.length !== 0 && sizes.length !== attributes.axes.length) {
throw new Error('Resize requires "sizes" input size to be of rank axes rank when axes attributes is specified');
}
}
Expand Down

0 comments on commit 39c8b37

Please sign in to comment.