@@ -26,23 +26,44 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
2626 }
2727 const auto & input_shape = ctx.getInputType (0 )->tensor_type ().shape ();
2828 const auto input_ndim = input_shape.dim_size ();
29-
29+ if (input_ndim == 1 ) {
30+ return ;
31+ }
3032 auto output_shape = ctx.getOutputType (0 )->mutable_tensor_type ()->mutable_shape ();
3133 // This operator only applies to the last dimension; thus -1
3234 for (int i = 0 ; i < input_ndim - 1 ; ++i) {
3335 *output_shape->add_dim () = input_shape.dim (i);
3436 }
35- // The length of second input is the length of the last dimension of the output
37+
38+ // value of the output's last dimension is the total amount of indices
39+ // set Unknown length for the last dimension if it cannot be calculated
40+ auto last_dim = output_shape->add_dim ();
3641 if (hasInputShape (ctx, 1 )) {
3742 const auto & indices_shape = getInputShape (ctx, 1 );
3843 if (indices_shape.dim_size () > 0 ) {
39- auto dim = indices_shape.dim (0 );
40- *output_shape->add_dim () = dim;
41- return ;
44+ int64_t num_indices = 1 ;
45+ std::string single_symbolic_dim;
46+ for (int i = 0 ; i < indices_shape.dim_size (); i++) {
47+ if (indices_shape.dim (i).has_dim_value ()) {
48+ num_indices *= indices_shape.dim (i).dim_value ();
49+ } else if (indices_shape.dim (i).has_dim_param ()) {
50+ if (single_symbolic_dim.empty ()) {
51+ // it is possible to set symbolic dimension param if the rest dim values are all value 1
52+ single_symbolic_dim = indices_shape.dim (i).dim_param ();
53+ } else {
54+ return ;
55+ }
56+ } else {
57+ return ;
58+ }
59+ }
60+ if (single_symbolic_dim.empty ()) {
61+ last_dim->set_dim_value (num_indices);
62+ } else if (num_indices == 1 ) {
63+ last_dim->set_dim_param (single_symbolic_dim);
64+ }
4265 }
4366 }
44- // Unknown length of the last dimension
45- output_shape->add_dim ();
4667 })
4768 .TypeConstraint(
4869 " T" ,
@@ -851,9 +872,9 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
851872 " Only one of the attributes 'base_values', 'base_values_as_tensor' should be specified." );
852873 }
853874
854- std::vector<std::string> label_strs ;
855- auto result = getRepeatedAttribute (ctx, " classlabels_strings" , label_strs );
856- bool using_strings = (result && !label_strs .empty ());
875+ std::vector<std::string> classlabels_strings ;
876+ auto result = getRepeatedAttribute (ctx, " classlabels_strings" , classlabels_strings );
877+ bool using_strings = (result && !classlabels_strings .empty ());
857878 if (using_strings) {
858879 updateOutputElemType (ctx, 0 , TensorProto::STRING);
859880 } else {
@@ -864,10 +885,16 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
864885 checkInputRank (ctx, 0 , 2 );
865886 Dim N, E;
866887 unifyInputDim (ctx, 0 , 0 , N);
867- std::vector<int64_t > class_ids;
868- auto has_ids = getRepeatedAttribute (ctx, " class_ids" , class_ids);
869- if (has_ids) {
870- unifyDim (E, class_ids.size ());
888+
889+ if (using_strings) {
890+ unifyDim (E, classlabels_strings.size ());
891+ } else {
892+ std::vector<int64_t > classlabels_int64s;
893+ result = getRepeatedAttribute (ctx, " classlabels_int64s" , classlabels_int64s);
894+ if (!result || classlabels_int64s.empty ()) {
895+ fail_shape_inference (" Non of classlabels_int64s or classlabels_strings is set." );
896+ }
897+ unifyDim (E, classlabels_int64s.size ());
871898 }
872899 updateOutputShape (ctx, 0 , {N});
873900 updateOutputShape (ctx, 1 , {N, E});
0 commit comments