15 changes: 12 additions & 3 deletions tf2onnx/onnx_opset/nn.py
@@ -1812,11 +1812,20 @@ def version_9(cls, ctx, node, **kwargs):
             onehot_indice = ctx.make_node("Cast", [label_name], attr={"to": TensorProto.INT64}).output[0]
         else:
             onehot_indice = label_name
-        label_node = ctx.make_node(op_type="OneHot",
-                                   inputs=[onehot_indice, depth_node, values_node])
+        if ctx.opset < 11:
+            label_node = ctx.make_node(op_type="OneHot",
+                                       inputs=[onehot_indice, depth_node, values_node])
+        else:
+            # OneHot is very slow but this workaround requires opset 11
+            index_unsq = GraphBuilder(ctx).make_unsqueeze({'data': onehot_indice, 'axes': [-1]})
+            depth_sq = GraphBuilder(ctx).make_squeeze({'data': depth_node, 'axes': [0]})
+            zero_const = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.int64)).output[0]
+            one_const = ctx.make_const(utils.make_name("const_one"), np.array(1, np.int64)).output[0]
+            dp_range = ctx.make_node("Range", [zero_const, depth_sq, one_const]).output[0]
+            label_node = ctx.make_node("Equal", [index_unsq, dp_range])
         # the above logic makes output dtype of label_node now always int64
         # make sure label has same dtype as logit
-        if logit_dtype != TensorProto.INT64:
+        if logit_dtype != ctx.get_dtype(label_node.output[0]):
             label_node = ctx.make_node("Cast", label_node.output, attr={"to": logit_dtype}, dtypes=[logit_dtype])

         _make_sparse_softmax_cross_entropy_with_logits(ctx, label_node, logit_node, node)
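
To see why the opset-11 branch reproduces OneHot, here is a minimal NumPy sketch (not part of the PR; the function name and values are illustrative): unsqueezing the label indices onto a trailing axis and comparing them against Range(0, depth, 1) broadcasts into the same one-hot mask that the OneHot op would produce, only as a bool tensor.

import numpy as np

def one_hot_via_range_equal(indices, depth):
    # Unsqueeze(axes=[-1]): add a trailing axis so broadcasting compares
    # every index against every class position in [0, depth)
    index_unsq = np.expand_dims(np.asarray(indices, dtype=np.int64), axis=-1)
    # Range(start=0, limit=depth, delta=1)
    dp_range = np.arange(0, depth, 1, dtype=np.int64)
    # Equal: bool mask, True exactly where the class index matches
    return index_unsq == dp_range

labels = np.array([2, 0, 1])
print(one_hot_via_range_equal(labels, depth=4).astype(np.float32))
# [[0. 0. 1. 0.]
#  [1. 0. 0. 0.]
#  [0. 1. 0. 0.]]

Because Equal yields a bool tensor rather than the int64 output of OneHot, the dtype check below the branch is changed from a hard-coded TensorProto.INT64 comparison to ctx.get_dtype(label_node.output[0]), so the existing Cast still converts the label to the logits' dtype in both branches.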