Skip to content

Commit b743c13

Browse files
committed
Add input size check and unit test
1 parent 578358b commit b743c13

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

extension/module/module.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,12 @@ runtime::Result<std::vector<runtime::EValue>> Module::execute(
240240
auto& method = methods_.at(method_name).method;
241241
auto& inputs = methods_.at(method_name).inputs;
242242

243+
ET_CHECK_OR_RETURN_ERROR(
244+
input_values.size() <= inputs.size(),
245+
InvalidArgument,
246+
"input size: %zu does not match method input size: %zu",
247+
input_values.size(),
248+
inputs.size());
243249
for (size_t i = 0; i < input_values.size(); ++i) {
244250
if (!input_values[i].isNone()) {
245251
inputs[i] = input_values[i];

extension/module/test/module_test.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,16 @@ TEST_F(ModuleTest, TestExecuteOnCurrupted) {
216216
EXPECT_NE(result.error(), Error::Ok);
217217
}
218218

219+
TEST_F(ModuleTest, TestExecuteWithTooManyInputs) {
220+
Module module(model_path_);
221+
222+
auto tensor = make_tensor_ptr({2, 2}, {1.f, 2.f, 3.f, 4.f});
223+
224+
const auto result = module.execute("forward", {tensor, tensor, 1.0, 1.0});
225+
226+
EXPECT_NE(result.error(), Error::Ok);
227+
}
228+
219229
TEST_F(ModuleTest, TestGet) {
220230
Module module(model_path_);
221231

0 commit comments

Comments
 (0)