@@ -103,6 +103,50 @@ RELAY_REGISTER_OP("nn.bias_add")
103103// relay.nn.dense
104104TVM_REGISTER_NODE_TYPE (DenseAttrs);
105105
106+
107+ bool DenseRel (const Array<Type>& types,
108+ int num_inputs,
109+ const Attrs& attrs,
110+ const TypeReporter& reporter) {
111+ CHECK_EQ (types.size (), 3 );
112+ const auto * data = types[0 ].as <TensorTypeNode>();
113+ const auto * weight = types[1 ].as <TensorTypeNode>();
114+ if (data == nullptr ) return false ;
115+
116+ const DenseAttrs* param = attrs.as <DenseAttrs>();
117+ CHECK (param != nullptr );
118+
119+ CHECK (static_cast <int >(data->shape .size ()) != 0 );
120+
121+ Array<tvm::Expr> oshape = data->shape ;
122+ if (param->units .defined ()) {
123+ Array<tvm::Expr> dshape = data->shape ;
124+ // validate the weight shape is proper if defined
125+ // Assign weight type
126+ Array<IndexExpr> wshape ({param->units , dshape[dshape.size () - 1 ]});
127+ reporter->Assign (types[1 ], TensorTypeNode::make (wshape, data->dtype ));
128+ oshape.Set ((oshape.size () - 1 ), param->units );
129+ } else {
130+ if (weight == nullptr ) return false ;
131+ Array<tvm::Expr> wshape = weight->shape ;
132+ CHECK (static_cast <int >(weight->shape .size ()) == 2 );
133+ CHECK (reporter->AssertEQ (data->shape [data->shape .size () - 1 ], weight->shape [1 ]))
134+ << " DenseRel: input dimension doesn't match,"
135+ << " data shape=" << data->shape
136+ << " , weight shape=" << weight->shape ;
137+ oshape.Set ((oshape.size () - 1 ), wshape[0 ]);
138+ }
139+
140+ DataType out_dtype = param->out_dtype ;
141+ if (out_dtype.bits () == 0 ) {
142+ out_dtype = data->dtype ;
143+ }
144+ // assign output type
145+ reporter->Assign (types[2 ], TensorTypeNode::make (oshape, out_dtype));
146+ return true ;
147+ }
148+
149+
106150// Positional relay function to create dense operator used by frontend FFI.
107151Expr MakeDense (Expr data,
108152 Expr weight,
@@ -698,11 +742,11 @@ bool BatchMatmulRel(const Array<Type>& types,
698742 if (x == nullptr || y == nullptr ) return false ;
699743 CHECK (x->shape .size () == 3 && y->shape .size () == 3 );
700744 CHECK (reporter->AssertEQ (x->shape [0 ], y->shape [0 ]))
701- << " BatchDot: batch dimension doesn't match, "
702- << " x shape=" << x->shape
703- << " , y shape=" << y->shape ;
745+ << " BatchDot: batch dimension doesn't match,"
746+ << " x shape=" << x->shape
747+ << " , y shape=" << y->shape ;
704748 CHECK (reporter->AssertEQ (x->shape [2 ], y->shape [2 ]))
705- << " BatchDot: shapes of x and y is inconsistent, "
749+ << " BatchDot: shapes of x and y is inconsistent,"
706750 << " x shape=" << x->shape
707751 << " , y shape=" << y->shape ;
708752
@@ -746,6 +790,51 @@ are data in batch.
746790.set_support_level(10 )
747791.add_type_rel(" BatchMatmul" , BatchMatmulRel);
748792
793+ // relay.nn.cross_entropy
794+ bool CrossEntropyRel (const Array<Type>& types,
795+ int num_inputs,
796+ const Attrs& attrs,
797+ const TypeReporter& reporter) {
798+ CHECK_EQ (types.size (), 3 );
799+ const auto * x = types[0 ].as <TensorTypeNode>();
800+ const auto * y = types[1 ].as <TensorTypeNode>();
801+ if (x == nullptr || y == nullptr ) return false ;
802+ CHECK (x->shape .size () == 2 && y->shape .size () == 2 )
803+ << " CrossEntropy: shapes of x and y is inconsistent,"
804+ << " x shape=" << x->shape
805+ << " , y shape=" << y->shape ;
806+ CHECK (reporter->AssertEQ (x->shape [0 ], y->shape [0 ]))
807+ << " CrossEntropy: shapes of x and y is inconsistent,"
808+ << " x shape=" << x->shape
809+ << " , y shape=" << y->shape ;
810+ CHECK (reporter->AssertEQ (x->shape [1 ], y->shape [1 ]))
811+ << " CrossEntropy: shapes of x and y is inconsistent,"
812+ << " x shape=" << x->shape
813+ << " , y shape=" << y->shape ;
814+ // assign output type
815+ reporter->Assign (types[2 ], TensorTypeNode::make ({}, x->dtype ));
816+ return true ;
817+ }
818+
819+ // Positional relay function to create batch_matmul operator used by frontend FFI.
820+ Expr MakeCrossEntropy (Expr predictions, Expr targets) {
821+ static const Op& op = Op::Get (" nn.cross_entropy" );
822+ return CallNode::make (op, {predictions, targets}, Attrs (), {});
823+ }
824+
825+
826+ TVM_REGISTER_API (" relay.op.nn._make.cross_entropy" )
827+ .set_body_typed(MakeCrossEntropy);
828+
829+
830+ RELAY_REGISTER_OP (" nn.cross_entropy" )
831+ .describe(R"code( Computes cross entropy given preditions and targets.)code" TVM_ADD_FILELINE)
832+ .set_num_inputs(2 )
833+ .add_argument(" x" , " 1D Tensor" , " Predictions." )
834+ .add_argument(" y" , " 1D Tensor" , " Targets." )
835+ .set_support_level(10 )
836+ .add_type_rel(" CrossEntropy" , CrossEntropyRel);
837+
749838
750839} // namespace relay
751840} // namespace tvm
0 commit comments