Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/lib/validate-input.js
Original file line number Diff line number Diff line change
Expand Up @@ -831,3 +831,33 @@ export function validateTileParams(input, repetitions) {
`Invalid repetitions ${repetitions} - it should be an Array of positive integers.`);
}
}

export function validatePadParams(input, beginningPadding, endingPadding, mode) {
const inputRank = input.rank;
if (inputRank === 0) {
throw new Error(`The input's rank should be greater than 0.`);
}
Comment on lines +837 to +839
Copy link

@fdwr fdwr May 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the spec, pad currently has this odd outlier in the algorithm steps "If input’s rank is 0, then throw a TypeError", but I don't see any other operator with that explicitly, and also, the table for pad says that any rank is allowed...

operand allowed data types allowed ranks
input any N
output same as input same as input

...which is inconsistent. Logically, pad should be able to accept a scalar, as the beginningPadding.length (==0) and endingPadding.length (==0) will still equal rank (==0), and it's simply a nop, but I don't know if the spec should support 0D, because I fear that many backends will not (ORT with the DML EP works fine with this case, but then ORT's CPU EP breaks, and CoreML/TF probably will too).

Anyway, resolve this, as the code is consistent with the current spec.

if (beginningPadding.length !== inputRank) {
throw new Error(`Invalid beginningPadding, beginningPadding's size ${beginningPadding.length}` +
` is not equal to input's rank ${inputRank}.`);
}
if (endingPadding.length !== inputRank) {
throw new Error(`Invalid endingPadding, endingPadding's size ${beginningPadding.length} is ` +
`not equal to input's rank ${inputRank}.`);
}
if (mode === 'reflection') {
const inputShape = input.shape;
for (let index = 0; index < inputRank; ++index) {
if (beginningPadding[index] >= inputShape[index]) {
throw new Error(`Invalid beginningPadding on reflection mode, beginningPadding[index] ` +
`${beginningPadding[index]} is greater than or equal to inputShape[index] ` +
`${inputShape[index]}.`);
}
if (endingPadding[index] >= inputShape[index]) {
throw new Error(`Invalid endingPadding on reflection mode, endingPadding[index] ` +
`${endingPadding[index]} is greater than or equal to inputShape[index] ` +
`${inputShape[index]}.`);
}
}
}
}
12 changes: 6 additions & 6 deletions src/pad.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
'use strict';

import {Tensor} from './lib/tensor.js';
import {validatePadParams} from './lib/validate-input.js';

/**
* Get mapped location from source tensor.
Expand All @@ -24,16 +25,14 @@ function getMappedLocation(location, inputShape, beginningPadding, mode) {
}
}
} else {
// reflection mode or symmetric mode
const offset = mode === 'symmetric' ? 1 : 0;
// reflection mode
for (let i = 0; i < rank; i++) {
if (mappedLocation[i] < beginningPadding[i]) {
mappedLocation[i] = beginningPadding[i] + (beginningPadding[i] - mappedLocation[i]) -
beginningPadding[i] - offset;
beginningPadding[i];
} else if (mappedLocation[i] >= beginningPadding[i] + inputShape[i]) {
mappedLocation[i] = beginningPadding[i] + inputShape[i] - 1 -
(mappedLocation[i] - (beginningPadding[i] + inputShape[i] -1)) -
beginningPadding[i] + offset;
(mappedLocation[i] - (beginningPadding[i] + inputShape[i] -1)) - beginningPadding[i];
} else {
mappedLocation[i] -= beginningPadding[i];
}
Expand Down Expand Up @@ -66,7 +65,7 @@ function updateOutputElement(index, source, destination, beginningPadding, mode,
if (needPadding) {
if (mode === 'constant') {
result = value;
} else if (mode === 'edge' || mode === 'reflection' || mode === 'symmetric') {
} else if (mode === 'edge' || mode === 'reflection') {
const targetLocation = getMappedLocation(location, sourceShape, beginningPadding, mode);
result = source.getValueByLocation(targetLocation);
} else {
Expand Down Expand Up @@ -95,6 +94,7 @@ export function pad(
mode='constant',
value=0,
} = {}) {
validatePadParams(input, beginningPadding, endingPadding, mode);
const outputShape = input.shape.map((v, i) => v + beginningPadding[i] + endingPadding[i]);
const output = new Tensor(outputShape);
for (let i = 0; i < output.size; ++i) {
Expand Down
87 changes: 0 additions & 87 deletions test/pad_test.js
Original file line number Diff line number Diff line change
Expand Up @@ -167,91 +167,4 @@ describe('test pad', function() {
],
});
});

it('pad symmetric mode 2D', function() {
testPad(
{
shape: [2, 3],
values: [1, 2, 3, 4, 5, 6],
},
[1, 2],
[1, 2],
{
mode: 'symmetric',
},
{
shape: [4, 7],
values: [
2., 1., 1., 2., 3., 3., 2.,
2., 1., 1., 2., 3., 3., 2.,
5., 4., 4., 5., 6., 6., 5.,
5., 4., 4., 5., 6., 6., 5.,
],
});
});

it('pad symmetric mode 4D', function() {
testPad(
{
shape: [2, 2, 3, 3],
values: [
0, 1, 2,
3, 4, 5,
6, 7, 8,

9, 10, 11,
12, 13, 14,
15, 16, 17,

18, 19, 20,
21, 22, 23,
24, 25, 26,

27, 28, 29,
30, 31, 32,
33, 34, 35,
],
},
[0, 0, 2, 2],
[0, 0, 2, 2],
{
mode: 'symmetric',
},
{
shape: [2, 2, 7, 7],
values: [
4, 3, 3, 4, 5, 5, 4,
1, 0, 0, 1, 2, 2, 1,
1, 0, 0, 1, 2, 2, 1,
4, 3, 3, 4, 5, 5, 4,
7, 6, 6, 7, 8, 8, 7,
7, 6, 6, 7, 8, 8, 7,
4, 3, 3, 4, 5, 5, 4,

13, 12, 12, 13, 14, 14, 13,
10, 9, 9, 10, 11, 11, 10,
10, 9, 9, 10, 11, 11, 10,
13, 12, 12, 13, 14, 14, 13,
16, 15, 15, 16, 17, 17, 16,
16, 15, 15, 16, 17, 17, 16,
13, 12, 12, 13, 14, 14, 13,

22, 21, 21, 22, 23, 23, 22,
19, 18, 18, 19, 20, 20, 19,
19, 18, 18, 19, 20, 20, 19,
22, 21, 21, 22, 23, 23, 22,
25, 24, 24, 25, 26, 26, 25,
25, 24, 24, 25, 26, 26, 25,
22, 21, 21, 22, 23, 23, 22,

31, 30, 30, 31, 32, 32, 31,
28, 27, 27, 28, 29, 29, 28,
28, 27, 27, 28, 29, 29, 28,
31, 30, 30, 31, 32, 32, 31,
34, 33, 33, 34, 35, 35, 34,
34, 33, 33, 34, 35, 35, 34,
31, 30, 30, 31, 32, 32, 31,
],
});
});
});