@@ -91,7 +91,9 @@ void CheckOutput(const std::vector<OperationExpression>& expressions,
9191                 const  std::vector<int > output_ids_of_subgraph, int  i) {
9292  std::vector<float > var (cpu_tensors.size ());
9393  for  (auto  id : input_ids_of_subgraph) {
94-     var[id] = cpu_tensors[id].data <float >()[i];
94+     if  (id >= 0 ) {
95+       var[id] = cpu_tensors[id].data <float >()[i];
96+     }
9597  }
9698
9799  for  (auto  expression : expressions) {
@@ -182,10 +184,8 @@ void TestMainImpl(std::string func_name, std::string code_str,
182184          gpu_tensors[id].mutable_data <float >(cpu_tensors[id].dims (), place);
183185      fusion_group::SetupRandomCPUTensor<float >(&cpu_tensors[id]);
184186      TensorCopySync (cpu_tensors[id], place, &gpu_tensors[id]);
185-     } else  {
186-       gpu_ptrs[id] = nullptr ;
187+       args.push_back (&gpu_ptrs[id]);
187188    }
188-     args.push_back (&gpu_ptrs[id]);
189189  }
190190
191191  for  (auto  id : output_ids) {
@@ -283,7 +283,7 @@ TEST(code_generator, elementwise_grad) {
283283  //  t3 = relu(t2)
284284  //  t2' = relu_grad(t2, t3, t3')
285285  //  t0', t1' = elementwise_mul_grad(t0, t1, t2, t2')
286-   fusion_group::OperationExpression exp1 (" relu_grad" 2 , 3 , 7 }, {6 });
286+   fusion_group::OperationExpression exp1 (" relu_grad" 2 , - 1 , 7 }, {6 });
287287  fusion_group::OperationExpression exp2 (" elementwise_mul_grad" 0 , 1 , 2 , 6 },
288288                                         {4 , 5 });
289289  std::vector<fusion_group::OperationExpression> expressions = {exp1, exp2};
@@ -300,7 +300,7 @@ TEST(code_generator, elementwise_grad) {
300300  //   Op(relu_grad), inputs:{2,3,7}, outputs:{6}
301301  //   Op(elementwise_mul_grad), inputs:{0,1,2,6}, outputs:{4,5}
302302  int  n = cpu_tensors[0 ].numel ();
303-   std::vector<int > input_ids = {0 , 1 , 2 , 3 , 7 };
303+   std::vector<int > input_ids = {0 , 1 , 2 , - 1 , 7 };
304304  std::vector<int > output_ids = {4 , 5 , 6 };
305305  TestMain (" elementwise_grad_kernel_0" 
306306           output_ids);
0 commit comments