Skip to content

Commit 05695b6

Browse files
lucylqfacebook-github-bot
authored andcommitted
Add checks for compute_slice (#15647)
Summary: Add safety checks to compute_slice, to ensure that we: 1. Do not read outside of the src tensor bounds 2. Do not write outside of the output tensor bounds Also pass in KernelRuntimeContext to use ET_KERNEL_CHECK_MSG and make errors non-fatal. Differential Revision: D86433966
1 parent d361573 commit 05695b6

File tree

6 files changed

+31
-5
lines changed

6 files changed

+31
-5
lines changed

backends/cadence/fusion_g3/operators/op_slice_copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ Tensor& slice_copy_Tensor_out(
123123
InvalidArgument,
124124
out);
125125

126-
torch::executor::compute_slice(in, dim, start, length, step, out);
126+
torch::executor::compute_slice(ctx, in, dim, start, length, step, out);
127127
}
128128

129129
return out;

backends/cadence/hifi/operators/op_slice_copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Tensor& slice_copy_Tensor_out(
6464
InvalidArgument,
6565
out);
6666

67-
compute_slice(in, dim, start, length, step, out);
67+
compute_slice(ctx, in, dim, start, length, step, out);
6868

6969
return out;
7070
}

kernels/portable/cpu/op_narrow_copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Tensor& narrow_copy_out(
4646
out);
4747

4848
if (length != 0) {
49-
compute_slice(in, dim, start, length, 1, out);
49+
compute_slice(ctx, in, dim, start, length, 1, out);
5050
}
5151

5252
return out;

kernels/portable/cpu/op_slice_copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ Tensor& slice_copy_Tensor_out(
5555
InvalidArgument,
5656
out);
5757

58-
compute_slice(in, dim, start, length, step, out);
58+
compute_slice(ctx, in, dim, start, length, step, out);
5959

6060
return out;
6161
}

kernels/portable/cpu/util/slice_util.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,33 @@ int64_t adjust_slice_indices(
150150
}
151151

152152
void compute_slice(
153+
KernelRuntimeContext& ctx,
153154
const Tensor& in,
154155
int64_t dim,
155156
int64_t start,
156157
int64_t length,
157158
int64_t step,
158159
Tensor& out) {
160+
ET_KERNEL_CHECK_MSG(
161+
ctx,
162+
dim < in.dim(),
163+
InvalidArgument,
164+
/* void */,
165+
"Requested dim is larger than input tensor dim; dim = %" PRId64,
166+
dim);
159167
size_t dim_length = in.size(dim);
160-
168+
ET_KERNEL_CHECK_MSG(
169+
ctx,
170+
start < dim_length,
171+
InvalidArgument,
172+
/* void */,
173+
"Requested start is larger than the dim length.");
174+
ET_KERNEL_CHECK_MSG(
175+
ctx,
176+
length * step < dim_length,
177+
InvalidArgument,
178+
/* void */,
179+
"Requested length * step is larger than the dim size.");
161180
size_t leading_dims = getLeadingDims(in, dim);
162181
size_t trailing_dims = getTrailingDims(in, dim);
163182

@@ -170,6 +189,12 @@ void compute_slice(
170189
const char* input_data = in.const_data_ptr<char>();
171190
char* dest = out.mutable_data_ptr<char>();
172191

192+
ET_KERNEL_CHECK_MSG(
193+
ctx,
194+
out.nbytes() >= (length * leading_dims * length_per_step),
195+
InvalidArgument,
196+
/* void */,
197+
"out.nbytes() is smaller than the expected slice size.");
173198
for (const auto i : c10::irange(leading_dims)) {
174199
const char* src = input_data + (i * dim_length + start) * length_per_step;
175200
for ([[maybe_unused]] const auto j : c10::irange(length)) {

kernels/portable/cpu/util/slice_util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ int64_t adjust_slice_indices(
5555
int64_t step);
5656

5757
void compute_slice(
58+
KernelRuntimeContext& ctx,
5859
const Tensor& in,
5960
int64_t dim,
6061
int64_t start,

0 commit comments

Comments
 (0)