@@ -910,7 +910,7 @@ bool CrossEntropyRel(const Array<Type>& types,
910910 return true ;
911911}
912912
913- // Positional relay function to create batch_matmul operator used by frontend FFI.
913+ // Positional relay function to create cross_entropy operator used by frontend FFI.
914914Expr MakeCrossEntropy (Expr predictions, Expr targets) {
915915 static const Op& op = Op::Get (" nn.cross_entropy" );
916916 return CallNode::make (op, {predictions, targets}, Attrs (), {});
@@ -933,5 +933,28 @@ Do log on the data - do not accept logits.
933933.add_type_rel(" CrossEntropy" , CrossEntropyRel);
934934
935935
936+ // Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
937+ Expr MakeCrossEntropyWithLogits (Expr predictions, Expr targets) {
938+ static const Op& op = Op::Get (" nn.cross_entropy_with_logits" );
939+ return CallNode::make (op, {predictions, targets}, Attrs (), {});
940+ }
941+
942+
943+ TVM_REGISTER_API (" relay.op.nn._make.cross_entropy_with_logits" )
944+ .set_body_typed(MakeCrossEntropyWithLogits);
945+
946+
947+ RELAY_REGISTER_OP (" nn.cross_entropy_with_logits" )
948+ .describe(R"code(
949+ Computes cross entropy given predictions and targets.
950+ Accept logits.
951+ )code" TVM_ADD_FILELINE)
952+ .set_num_inputs(2 )
953+ .add_argument(" x" , " 1D Tensor" , " Predictions." )
954+ .add_argument(" y" , " 1D Tensor" , " Targets." )
955+ .set_support_level(10 )
956+ .add_type_rel(" CrossEntropy" , CrossEntropyRel);
957+
958+
936959} // namespace relay
937960} // namespace tvm
0 commit comments