@@ -229,30 +229,30 @@ TEST(InstanceNorm, nhwc) {
229
229
auto src_graph = ctx->CreateGraph ();
230
230
231
231
tim::vx::ShapeType io_shape ({2 , 2 , 2 , 2 }); // nhwc
232
- tim::vx::ShapeType param_shape ({1 });
233
- tim::vx::TensorSpec input_spec (tim::vx::DataType::FLOAT32,
232
+ tim::vx::ShapeType param_shape ({2 });
233
+ tim::vx::TensorSpec input_spec (tim::vx::DataType::FLOAT32,
234
234
io_shape, tim::vx::TensorAttribute::INPUT);
235
- tim::vx::TensorSpec param_spec (tim::vx::DataType::FLOAT32,
235
+ tim::vx::TensorSpec param_spec (tim::vx::DataType::FLOAT32,
236
236
param_shape, tim::vx::TensorAttribute::INPUT);
237
- tim::vx::TensorSpec output_spec (tim::vx::DataType::FLOAT32,
237
+ tim::vx::TensorSpec output_spec (tim::vx::DataType::FLOAT32,
238
238
io_shape, tim::vx::TensorAttribute::OUTPUT);
239
239
240
- auto input_tensor = src_graph->CreateTensor (input_spec);
241
- auto beta_tensor = src_graph->CreateTensor (param_spec);
242
- auto gamma_tensor = src_graph->CreateTensor (param_spec);
243
- auto output_tensor = src_graph->CreateTensor (output_spec);
244
-
245
- std::vector<float > in_data = {
246
- 0 .0f , 1 .0f , 0 .0f , 2 .0f , 0 .0f , 2 .0f , 0 .0f , 4 .0f , 1 .0f , -1 .0f , -1 .0f , 2 .0f , -1 .0f , -2 .0f , 1 .0f , 4 .0f
247
- };
248
- std::vector<float > beta = {0 };
249
- std::vector<float > gamma = {1 .0f };
250
- std::vector<float > golden = {
251
- 0 .0f , -1 .1470304f , 0 .0f , -0 .22940612f , 0 .0f , -0 .22940612f , 0 .0f , 1 .6058424f , 0 .99995005f ,
252
- -0 .7337929f , -0 .99995005f , 0 .52413774f , -0 .99995005f , -1 .1531031f , 0 .99995005f , 1 .3627582f ,
253
- };
254
- auto op = src_graph->CreateOperation <tim::vx::ops::InstanceNormalization>(1e-4f , tim::vx::DataLayout::CWHN);
255
- (*op).BindInputs ({input_tensor, beta_tensor, gamma_tensor}).BindOutputs ({output_tensor});
240
+ auto input_tensor = src_graph->CreateTensor (input_spec);
241
+ auto beta_tensor = src_graph->CreateTensor (param_spec);
242
+ auto gamma_tensor = src_graph->CreateTensor (param_spec);
243
+ auto output_tensor = src_graph->CreateTensor (output_spec);
244
+
245
+ std::vector<float > in_data = {
246
+ 0 .0f , 1 .0f , 0 .0f , 2 .0f , 0 .0f , 2 .0f , 0 .0f , 4 .0f , 1 .0f , -1 .0f , -1 .0f , 2 .0f , -1 .0f , -2 .0f , 1 .0f , 4 .0f
247
+ };
248
+ std::vector<float > beta = {0 , 0 };
249
+ std::vector<float > gamma = {1 . 0f , 1 .0f };
250
+ std::vector<float > golden = {
251
+ 0 .0f , -1 .1470304f , 0 .0f , -0 .22940612f , 0 .0f , -0 .22940612f , 0 .0f , 1 .6058424f , 0 .99995005f ,
252
+ -0 .7337929f , -0 .99995005f , 0 .52413774f , -0 .99995005f , -1 .1531031f , 0 .99995005f , 1 .3627582f ,
253
+ };
254
+ auto op = src_graph->CreateOperation <tim::vx::ops::InstanceNormalization>(1e-4f , tim::vx::DataLayout::CWHN);
255
+ (*op).BindInputs ({input_tensor, beta_tensor, gamma_tensor}).BindOutputs ({output_tensor});
256
256
// Do layout inference
257
257
auto transform = tim::transform::LayoutInference (src_graph, ctx);
258
258
auto infer_graph = transform.first ;
0 commit comments