Skip to content

Commit 3d0af15

Browse files
MarisaKirisamekevinthesun
authored andcommitted
[Relay] crossentropy_with_logits and its gradient (apache#4075)
* save * lint
1 parent 19f105f commit 3d0af15

File tree

6 files changed

+73
-4
lines changed

6 files changed

+73
-4
lines changed

python/tvm/relay/op/_reduce.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@ def _schedule_reduce(_, outs, target):
3737
_reg.register_schedule("mean", _schedule_reduce)
3838
_reg.register_schedule("variance", _schedule_reduce)
3939
_reg.register_schedule("nn.cross_entropy", _schedule_reduce)
40+
_reg.register_schedule("nn.cross_entropy_with_logits", _schedule_reduce)

python/tvm/relay/op/_tensor_grad.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,3 +449,12 @@ def cross_entropy_grad(orig, grad):
449449
batch_size = take(shape, const(0, dtype='int32'), axis=0)
450450
grad = grad / batch_size.astype('float32')
451451
return [-grad * y / x, -grad * log(x)]
452+
453+
454+
@register_gradient("nn.cross_entropy_with_logits")
455+
def cross_entropy_with_logits_grad(orig, grad):
456+
x, y = orig.args
457+
shape = shape_of(x)
458+
batch_size = take(shape, const(0, dtype='int32'), axis=0)
459+
grad = grad / batch_size.astype('float32')
460+
return [-grad * y, -grad * x]

python/tvm/relay/op/nn/_nn.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,3 +770,12 @@ def schedule_bitserial_dense(attrs, outputs, target):
770770
def compute_cross_entropy(attrs, inputs, out_dtype, target):
771771
x, y = inputs
772772
return [-topi.sum(topi.log(x) * y) / x.shape[0]]
773+
774+
775+
reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)
776+
777+
778+
@reg.register_compute("nn.cross_entropy_with_logits")
779+
def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target):
780+
x, y = inputs
781+
return [-topi.sum(x * y) / x.shape[0]]

python/tvm/relay/op/nn/nn.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1807,3 +1807,22 @@ def cross_entropy(predictions, targets):
18071807
The computed result.
18081808
"""
18091809
return _make.cross_entropy(predictions, targets)
1810+
1811+
1812+
def cross_entropy_with_logits(predictions, targets):
1813+
"""CrossEntropy with logits.
1814+
1815+
Parameters
1816+
----------
1817+
predictions : tvm.relay.Expr
1818+
The predictions.
1819+
1820+
targets : tvm.relay.Expr
1821+
The targets.
1822+
1823+
Returns
1824+
-------
1825+
result : tvm.relay.Expr
1826+
The computed result.
1827+
"""
1828+
return _make.cross_entropy_with_logits(predictions, targets)

src/relay/op/nn/nn.cc

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,7 @@ bool CrossEntropyRel(const Array<Type>& types,
910910
return true;
911911
}
912912

913-
// Positional relay function to create batch_matmul operator used by frontend FFI.
913+
// Positional relay function to create cross_entropy operator used by frontend FFI.
914914
Expr MakeCrossEntropy(Expr predictions, Expr targets) {
915915
static const Op& op = Op::Get("nn.cross_entropy");
916916
return CallNode::make(op, {predictions, targets}, Attrs(), {});
@@ -933,5 +933,28 @@ Do log on the data - do not accept logits.
933933
.add_type_rel("CrossEntropy", CrossEntropyRel);
934934

935935

936+
// Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
937+
Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
938+
static const Op& op = Op::Get("nn.cross_entropy_with_logits");
939+
return CallNode::make(op, {predictions, targets}, Attrs(), {});
940+
}
941+
942+
943+
TVM_REGISTER_API("relay.op.nn._make.cross_entropy_with_logits")
944+
.set_body_typed(MakeCrossEntropyWithLogits);
945+
946+
947+
RELAY_REGISTER_OP("nn.cross_entropy_with_logits")
948+
.describe(R"code(
949+
Computes cross entropy given predictions and targets.
950+
Accept logits.
951+
)code" TVM_ADD_FILELINE)
952+
.set_num_inputs(2)
953+
.add_argument("x", "1D Tensor", "Predictions.")
954+
.add_argument("y", "1D Tensor", "Targets.")
955+
.set_support_level(10)
956+
.add_type_rel("CrossEntropy", CrossEntropyRel);
957+
958+
936959
} // namespace relay
937960
} // namespace tvm

tests/python/relay/test_op_grad_level10.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,23 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
import pytest
18+
1719
from tvm import relay
1820
from tvm.relay.testing import check_grad
1921

2022

2123
def test_cross_entropy_grad():
22-
x = relay.var("x", shape=(1, 5))
23-
y = relay.var("y", shape=(1, 5))
24+
x = relay.var("x", shape=(2, 5))
25+
y = relay.var("y", shape=(2, 5))
2426
check_grad(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)), eps=0.01, scale=0.1, mean=1)
2527

2628

29+
def test_cross_entropy_with_logits_grad():
30+
x = relay.var("x", shape=(2, 5))
31+
y = relay.var("y", shape=(2, 5))
32+
check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)
33+
34+
2735
if __name__ == "__main__":
28-
test_cross_entropy_grad()
36+
pytest.main([__file__])

0 commit comments

Comments
 (0)