Skip to content

Commit e0b305a

Browse files
gs-olivebowang007
authored andcommitted
fix: Out-Of-Bounds bug in Unsqueeze (#1820)
1 parent a8e693f commit e0b305a

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

core/util/trt_util.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,14 @@ nvinfer1::Dims unpadDims(const nvinfer1::Dims& d) {
161161
}
162162

163163
nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val, bool use_zeros) {
164-
// acceptable range for pos is [0, d.nbDims]
165-
TORCHTRT_ASSERT(pos >= 0 && pos <= d.nbDims, "ERROR: Index to unsqueeze is out of bounds.");
164+
// Acceptable range for pos is [-d.nbDims - 1, d.nbDims]
165+
TORCHTRT_ASSERT(
166+
pos >= (-d.nbDims - 1) && pos <= d.nbDims,
167+
"ERROR: Index to unsqueeze is out of bounds. "
168+
<< "Expected value in range [" << (-d.nbDims - 1) << ", " << d.nbDims << "], but got " << pos);
169+
170+
// Unsqueeze with negative dimensions creates a new dimension at that index
171+
pos = (pos < 0) ? (pos + d.nbDims + 1) : pos;
166172

167173
nvinfer1::Dims dims;
168174
for (int i = 0, j = 0; j <= d.nbDims; j++) {

0 commit comments

Comments
 (0)