@@ -116,13 +116,14 @@ Array<Integer> GetExcludeAxes(size_t indim, const Array<Integer>& inaxis) {
116
116
}
117
117
118
118
// Return the modified layout for AlterOpLayout pass.
119
+ template <typename T>
119
120
InferCorrectLayoutOutput ReduceInferCorrectLayout (const Attrs& attrs,
120
121
const Array<Layout>& new_in_layouts,
121
122
const Array<Layout>& old_in_layouts,
122
123
const Array<tvm::relay::Type>& old_in_types) {
123
- const auto * attrs_ptr = attrs.as <ReduceAttrs >();
124
+ const auto * attrs_ptr = attrs.as <T >();
124
125
ICHECK (attrs_ptr);
125
- ObjectPtr<ReduceAttrs > params = make_object<ReduceAttrs >(*attrs_ptr);
126
+ ObjectPtr<T > params = make_object<T >(*attrs_ptr);
126
127
127
128
// Get the reduce axes.
128
129
Array<Array<IndexExpr>> old_in_shapes;
@@ -389,6 +390,7 @@ values over a given axis.
389
390
.set_support_level(4 )
390
391
.add_type_rel(" ArgReduce" , GenericReduceRel<ArgReduceAttrs>)
391
392
.set_attr<FTVMCompute>(" FTVMCompute" , ArgMaxCompute)
393
+ .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , ReduceInferCorrectLayout<ArgReduceAttrs>)
392
394
.set_attr<TOpPattern>(" TOpPattern" , kCommReduce );
393
395
394
396
Array<te::Tensor> ArgMinCompute (const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -405,6 +407,7 @@ values over a given axis.
405
407
.set_support_level(4 )
406
408
.add_type_rel(" ArgReduce" , GenericReduceRel<ArgReduceAttrs>)
407
409
.set_attr<FTVMCompute>(" FTVMCompute" , ArgMinCompute)
410
+ .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , ReduceInferCorrectLayout<ArgReduceAttrs>)
408
411
.set_attr<TOpPattern>(" TOpPattern" , kCommReduce );
409
412
410
413
Array<te::Tensor> SumCompute (const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -433,7 +436,7 @@ Example::
433
436
.set_attrs_type<ReduceAttrs>()
434
437
.set_support_level(4 )
435
438
.add_type_rel(" Reduce" , ReduceRel)
436
- .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , ReduceInferCorrectLayout)
439
+ .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , ReduceInferCorrectLayout<ReduceAttrs> )
437
440
.set_attr<FTVMCompute>(" FTVMCompute" , SumCompute)
438
441
.set_attr<TOpPattern>(" TOpPattern" , kCommReduce );
439
442
@@ -468,6 +471,7 @@ Example::
468
471
.set_support_level(4 )
469
472
.add_type_rel(" Reduce" , ReduceRel)
470
473
.set_attr<FTVMCompute>(" FTVMCompute" , AllCompute)
474
+ .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , ReduceInferCorrectLayout<ReduceAttrs>)
471
475
.set_attr<TOpPattern>(" TOpPattern" , kCommReduce );
472
476
473
477
Array<te::Tensor> AnyCompute (const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -516,6 +520,7 @@ RELAY_REGISTER_REDUCE_OP("max")
516
520
.set_support_level(4 )
517
521
.add_type_rel(" Reduce" , ReduceRel)
518
522
.set_attr<FTVMCompute>(" FTVMCompute" , MaxCompute)
523
+ .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , ReduceInferCorrectLayout<ReduceAttrs>)
519
524
.set_attr<TOpPattern>(" TOpPattern" , kCommReduce );
520
525
521
526
Array<te::Tensor> MinCompute (const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -531,6 +536,7 @@ RELAY_REGISTER_REDUCE_OP("min")
531
536
.set_support_level(4 )
532
537
.add_type_rel(" Reduce" , ReduceRel)
533
538
.set_attr<FTVMCompute>(" FTVMCompute" , MinCompute)
539
+ .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , ReduceInferCorrectLayout<ReduceAttrs>)
534
540
.set_attr<TOpPattern>(" TOpPattern" , kCommReduce );
535
541
536
542
Array<te::Tensor> ProdCompute (const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -551,17 +557,18 @@ Example::
551
557
[[1,4],[4,3],[5,2]],
552
558
[[7,1],[7,2],[7,3]]]
553
559
554
- mean (data, axis=1)
560
+ prod (data, axis=1)
555
561
[35562240]
556
562
557
- mean (data, axis=[1,2])
563
+ prod (data, axis=[1,2])
558
564
[ 36 480 2058]
559
565
560
566
)code" TVM_ADD_FILELINE)
561
567
.set_attrs_type<ReduceAttrs>()
562
568
.set_support_level(4 )
563
569
.add_type_rel(" Reduce" , ReduceRel)
564
570
.set_attr<FTVMCompute>(" FTVMCompute" , ProdCompute)
571
+ .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , ReduceInferCorrectLayout<ReduceAttrs>)
565
572
.set_attr<TOpPattern>(" TOpPattern" , kCommReduce );
566
573
567
574
Array<te::Tensor> MeanCompute (const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -600,6 +607,7 @@ Example::
600
607
.set_support_level(4 )
601
608
.add_type_rel(" Reduce" , ReduceRel)
602
609
.set_attr<FTVMCompute>(" FTVMCompute" , MeanCompute)
610
+ .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , ReduceInferCorrectLayout<ReduceAttrs>)
603
611
.set_attr<TOpPattern>(" TOpPattern" , kCommReduce );
604
612
605
613
bool VarianceRel (const Array<Type>& types, int num_inputs, const Attrs& attrs,
@@ -675,6 +683,7 @@ RELAY_REGISTER_OP("variance")
675
683
.add_argument(" mean" , " Tensor" , " The mean tensor." )
676
684
.add_type_rel(" Variance" , VarianceRel)
677
685
.set_attr<FTVMCompute>(" FTVMCompute" , VarianceCompute)
686
+ .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , ReduceInferCorrectLayout<ReduceAttrs>)
678
687
.set_attr<TOpPattern>(" TOpPattern" , kCommReduce );
679
688
680
689
} // namespace relay
0 commit comments