Skip to content

Commit 5173979

Browse files
Fix the instance norm test input size bug in layout infer test (#661)
Correct gamma and beta size in InstanceNorm.nhwc case Type: Bug Fix Issue: 37103 Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
1 parent 74e2740 commit 5173979

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

src/tim/transform/layout_inference_test.cc

+20-20
Original file line numberDiff line numberDiff line change
@@ -229,30 +229,30 @@ TEST(InstanceNorm, nhwc) {
229229
auto src_graph = ctx->CreateGraph();
230230

231231
tim::vx::ShapeType io_shape({2, 2, 2, 2}); //nhwc
232-
tim::vx::ShapeType param_shape({1});
233-
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
232+
tim::vx::ShapeType param_shape({2});
233+
tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32,
234234
io_shape, tim::vx::TensorAttribute::INPUT);
235-
tim::vx::TensorSpec param_spec(tim::vx::DataType::FLOAT32,
235+
tim::vx::TensorSpec param_spec(tim::vx::DataType::FLOAT32,
236236
param_shape, tim::vx::TensorAttribute::INPUT);
237-
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
237+
tim::vx::TensorSpec output_spec(tim::vx::DataType::FLOAT32,
238238
io_shape, tim::vx::TensorAttribute::OUTPUT);
239239

240-
auto input_tensor = src_graph->CreateTensor(input_spec);
241-
auto beta_tensor = src_graph->CreateTensor(param_spec);
242-
auto gamma_tensor = src_graph->CreateTensor(param_spec);
243-
auto output_tensor = src_graph->CreateTensor(output_spec);
244-
245-
std::vector<float> in_data = {
246-
0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 2.0f, 0.0f, 4.0f, 1.0f, -1.0f, -1.0f, 2.0f, -1.0f, -2.0f, 1.0f, 4.0f
247-
};
248-
std::vector<float> beta = {0};
249-
std::vector<float> gamma = {1.0f};
250-
std::vector<float> golden = {
251-
0.0f, -1.1470304f, 0.0f, -0.22940612f, 0.0f, -0.22940612f, 0.0f, 1.6058424f, 0.99995005f,
252-
-0.7337929f, -0.99995005f, 0.52413774f, -0.99995005f, -1.1531031f, 0.99995005f, 1.3627582f,
253-
};
254-
auto op = src_graph->CreateOperation<tim::vx::ops::InstanceNormalization>(1e-4f, tim::vx::DataLayout::CWHN);
255-
(*op).BindInputs({input_tensor, beta_tensor, gamma_tensor}).BindOutputs({output_tensor});
240+
auto input_tensor = src_graph->CreateTensor(input_spec);
241+
auto beta_tensor = src_graph->CreateTensor(param_spec);
242+
auto gamma_tensor = src_graph->CreateTensor(param_spec);
243+
auto output_tensor = src_graph->CreateTensor(output_spec);
244+
245+
std::vector<float> in_data = {
246+
0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 2.0f, 0.0f, 4.0f, 1.0f, -1.0f, -1.0f, 2.0f, -1.0f, -2.0f, 1.0f, 4.0f
247+
};
248+
std::vector<float> beta = {0,0};
249+
std::vector<float> gamma = {1.0f,1.0f};
250+
std::vector<float> golden = {
251+
0.0f, -1.1470304f, 0.0f, -0.22940612f, 0.0f, -0.22940612f, 0.0f, 1.6058424f, 0.99995005f,
252+
-0.7337929f, -0.99995005f, 0.52413774f, -0.99995005f, -1.1531031f, 0.99995005f, 1.3627582f,
253+
};
254+
auto op = src_graph->CreateOperation<tim::vx::ops::InstanceNormalization>(1e-4f, tim::vx::DataLayout::CWHN);
255+
(*op).BindInputs({input_tensor, beta_tensor, gamma_tensor}).BindOutputs({output_tensor});
256256
// Do layout inference
257257
auto transform = tim::transform::LayoutInference(src_graph, ctx);
258258
auto infer_graph = transform.first;

0 commit comments

Comments
 (0)