@@ -41,16 +41,15 @@ void BatchNormOpConverter::operator()(const framework::proto::OpDesc &op,
4141
4242  auto  output = op_desc.Output (" Y" front ();
4343  auto  op_name = op_desc.Type () + " :" Output (" Y" front ();
44-   engine_->AddOp (op_name, " Scale" " X" 
45-   engine_->AddOpAttr (op_name, " bias_term" true );
46-   engine_->AddOpAttr (op_name, " axis" 1 );
47-   engine_->AddOpAttr (op_name, " num_axes" 1 );
48- 
4944  bool  is_test = boost::get<bool >(op_desc.GetAttr (" is_test" 
50-   PADDLE_ENFORCE (is_test);
51-   float  epsilon = boost::get<float >(op_desc.GetAttr (" epsilon" 
52-   engine_->AddOpAttr (op_name, " epsilon" 
45+   auto  epsilon = boost::get<float >(op_desc.GetAttr (" epsilon" 
46+ 
47+   auto  bn_op_name = op_name + " :bn" 
48+   auto  bn_output = bn_op_name + " _output" 
49+   engine_->AddOp (bn_op_name, " BatchNorm" " X" 
50+   engine_->AddOpAttr (bn_op_name, " epsilon" 
5351
52+   auto  scale_op_name = op_name + " :scale" 
5453  auto  get_lod_tensor = [this , &scope, &op_name](const  std::string &var_name,
5554                                                 framework::LoDTensor *tensor) {
5655    auto  *v = scope.FindVar (var_name);
@@ -69,50 +68,54 @@ void BatchNormOpConverter::operator()(const framework::proto::OpDesc &op,
6968  get_lod_tensor (inputs[" Scale" scale_t );
7069  get_lod_tensor (inputs[" Variance" variance_t );
7170
72-   auto  *bias = bias_t .mutable_data <float >(platform::CPUPlace ());
73-   auto  *mean = mean_t .mutable_data <float >(platform::CPUPlace ());
74-   auto  *scale = scale_t .mutable_data <float >(platform::CPUPlace ());
75-   auto  *variance = variance_t .mutable_data <float >(platform::CPUPlace ());
76- 
77-   framework::LoDTensor combile_scale_t ;
78-   framework::LoDTensor combile_bias_t ;
79-   combile_scale_t .Resize (scale_t .dims ());
80-   combile_bias_t .Resize (bias_t .dims ());
81- 
82-   auto  *combile_scale =
83-       combile_scale_t .mutable_data <float >(platform::CPUPlace ());
84-   auto  *combile_bias = combile_bias_t .mutable_data <float >(platform::CPUPlace ());
85- 
86-   size_t  elem_num = combile_scale_t .memory_size () / sizeof (float );
87-   for  (size_t  i = 0 ; i < elem_num; i++) {
88-     combile_scale[i] = scale[i] / sqrtf (variance[i] + epsilon);
89-     combile_bias[i] = bias[i] - mean[i] * combile_scale[i];
90-   }
91- 
92-   auto  fill_shape = [](size_t  n, std::vector<int > *shape) {
93-     shape->insert (shape->begin (), 1 );
94-     if  (shape->size () < n) {
95-       shape->insert (shape->end (), n - shape->size (), 1 );
71+   auto  fill_shape = [](size_t  n, std::vector<int > shape) {
72+     shape.insert (shape.begin (), 1 );
73+     if  (shape.size () < n) {
74+       shape.insert (shape.end (), n - shape.size (), 1 );
9675    }
76+     return  shape;
9777  };
98-   auto  scale_shape = framework::vectorize2int (combile_scale_t .dims ());
99-   auto  bias_shape = framework::vectorize2int (combile_bias_t .dims ());
100-   fill_shape (4 , &scale_shape);
101-   fill_shape (4 , &bias_shape);
102-   Shape weight1_shape (scale_shape);
103-   Shape weight2_shape (bias_shape);
78+   Shape shape1 (fill_shape (4 , framework::vectorize2int (mean_t .dims ())));
79+   Shape shape2 (fill_shape (4 , framework::vectorize2int (variance_t .dims ())));
10480  auto  *weight1 =
105-       GraphGlobalMem<NV>::Global ().template  new_block <AK_FLOAT>(weight1_shape);
106-   auto  *scale_data = static_cast <float  *>(weight1->h_tensor ().mutable_data ());
107-   std::copy_n (combile_scale_t .data <float >(), combile_scale_t .numel (),
108-               scale_data);
109-   engine_->AddOpAttr (op_name, " weight_1" 
81+       GraphGlobalMem<NV>::Global ().template  new_block <AK_FLOAT>(shape1);
82+   auto  *mean_data = static_cast <float  *>(weight1->h_tensor ().mutable_data ());
83+   std::copy_n (mean_t .data <float >(), mean_t .numel (), mean_data);
84+   engine_->AddOpAttr (bn_op_name, " weight_1" 
11085
11186  auto  *weight2 =
112-       GraphGlobalMem<NV>::Global ().template  new_block <AK_FLOAT>(weight2_shape);
113-   auto  *bias_data = static_cast <float  *>(weight2->h_tensor ().mutable_data ());
114-   std::copy_n (combile_bias_t .data <float >(), combile_bias_t .numel (), bias_data);
115-   engine_->AddOpAttr (op_name, " weight_2" 
87+       GraphGlobalMem<NV>::Global ().template  new_block <AK_FLOAT>(shape2);
88+   auto  *variance_data =
89+       static_cast <float  *>(weight2->h_tensor ().mutable_data ());
90+   std::copy_n (variance_t .data <float >(), variance_t .numel (), variance_data);
91+   engine_->AddOpAttr (bn_op_name, " weight_2" 
92+ 
93+   Shape shape3 (std::vector<int >({1 , 1 , 1 , 1 }));
94+   auto  *weight3 =
95+       GraphGlobalMem<NV>::Global ().template  new_block <AK_FLOAT>(shape3);
96+   auto  *alpha_data = static_cast <float  *>(weight3->h_tensor ().mutable_data ());
97+   float  weight3_data[] = {1 };
98+   std::copy (std::begin (weight3_data), std::end (weight3_data), alpha_data);
99+   engine_->AddOpAttr (bn_op_name, " weight_3" 
100+ 
101+   Shape scale_shape (fill_shape (4 , framework::vectorize2int (scale_t .dims ())));
102+   auto  *scale =
103+       GraphGlobalMem<NV>::Global ().template  new_block <AK_FLOAT>(scale_shape);
104+   auto  *scale_data = static_cast <float  *>(scale->h_tensor ().mutable_data ());
105+   std::copy_n (scale_t .data <float >(), scale_t .numel (), scale_data);
106+ 
107+   Shape bias_shape (fill_shape (4 , framework::vectorize2int (bias_t .dims ())));
108+   auto  *bias =
109+       GraphGlobalMem<NV>::Global ().template  new_block <AK_FLOAT>(bias_shape);
110+   auto  *bias_data = static_cast <float  *>(bias->h_tensor ().mutable_data ());
111+   std::copy_n (bias_t .data <float >(), bias_t .numel (), bias_data);
112+ 
113+   engine_->AddOp (scale_op_name, " Scale" 
114+   engine_->AddOpAttr (scale_op_name, " axis" 1 );
115+   engine_->AddOpAttr (scale_op_name, " num_axes" 1 );
116+   engine_->AddOpAttr (scale_op_name, " bias_term" true );
117+   engine_->AddOpAttr (scale_op_name, " weight_1" 
118+   engine_->AddOpAttr (scale_op_name, " weight_2" 
116119}
117120
118121}  //  namespace anakin
0 commit comments