@@ -62,6 +62,13 @@ enum TypeFlag {
6262 kUint64 = 10 ,
6363};
6464
65+ enum IndicatorRuleFlag {
66+ kGT0 = 0 ,
67+ kLT0 = 1 ,
68+ kMax = 2 ,
69+ kMin = 3 ,
70+ };
71+
6572#define DMLC_DECLARE_DTYPE_FIELD (name ) \
6673 DMLC_DECLARE_FIELD (name) \
6774 .add_enum(" float16" , kFloat16 ) \
@@ -84,6 +91,28 @@ struct CastParam : public dmlc::Parameter<CastParam> {
8491 }
8592};
8693
94+ struct IndicatorParam : public dmlc ::Parameter<IndicatorParam> {
95+ TShape axis;
96+ bool exclude;
97+ DMLC_DECLARE_PARAMETER (IndicatorParam) {
98+ DMLC_DECLARE_FIELD (axis).set_default (TShape ())
99+ .describe (R"code( The axis or axes along which to perform the indicator rule.
100+
101+ The default, `axis=()`, will compute over all elements into a
102+ scalar array with shape `(1,)`.
103+
104+ If `axis` is int, rule is applied on a particular axis.
105+
106+ If `axis` is a tuple of ints, rule is applied on all the axes
107+ specified in the tuple.
108+
109+ If `exclude` is true, rule will be applied on the axes that are
110+ NOT in axis instead.)code" );
111+ DMLC_DECLARE_FIELD (exclude).set_default (false )
112+ .describe (" Whether to apply rule on axis that are NOT in axis instead." );
113+ }
114+ };
115+
87116struct ReshapeParam : public dmlc ::Parameter<ReshapeParam> {
88117 Tuple<int64_t > shape;
89118
@@ -97,8 +126,7 @@ struct SqueezeParam : public dmlc::Parameter<SqueezeParam> {
97126
98127 DMLC_DECLARE_PARAMETER (SqueezeParam) {
99128 DMLC_DECLARE_FIELD (axis).set_default (TShape ())
100- .describe (" The axis to squeeze in the input tensor."
101- " If set to None, all size=1 axes will be squeezed" );
129+ .describe (" The axis to squeeze in the input tensor." );
102130 }
103131};
104132
@@ -110,6 +138,15 @@ struct ScalarParam : public dmlc::Parameter<ScalarParam> {
110138 }
111139};
112140
141+ struct FillValueParam : public dmlc ::Parameter<FillValueParam> {
142+ double fill_value;
143+
144+ DMLC_DECLARE_PARAMETER (FillValueParam) {
145+ DMLC_DECLARE_FIELD (fill_value)
146+ .describe (" Scalar value to be filled" );
147+ }
148+ };
149+
113150struct TransposeParam : public dmlc ::Parameter<TransposeParam> {
114151 TShape axes;
115152
@@ -158,16 +195,49 @@ struct ReduceParam : public dmlc::Parameter<ReduceParam> {
158195 }
159196};
160197
198+ struct InitOpWithScalarParam : public dmlc ::Parameter<InitOpWithScalarParam> {
199+ TShape shape;
200+ int dtype;
201+ double fill_value;
202+
203+ DMLC_DECLARE_PARAMETER (InitOpWithScalarParam) {
204+ DMLC_DECLARE_FIELD (shape).set_default (TShape ());
205+ DMLC_DECLARE_DTYPE_FIELD (dtype).set_default (kFloat32 )
206+ .describe (" Target data type." );
207+ DMLC_DECLARE_FIELD (fill_value).describe (" Scalar value to fill" );
208+ }
209+ };
210+
161211struct InitOpParam : public dmlc ::Parameter<InitOpParam> {
162212 TShape shape;
163213 int dtype;
164- double value;
165214
166215 DMLC_DECLARE_PARAMETER (InitOpParam) {
167216 DMLC_DECLARE_FIELD (shape).set_default (TShape ());
168217 DMLC_DECLARE_DTYPE_FIELD (dtype).set_default (kFloat32 )
169218 .describe (" Target data type." );
170- DMLC_DECLARE_FIELD (value).describe (" Value to fill" );
219+ }
220+ };
221+
222+ struct ElementWiseReduceParam : public dmlc ::Parameter<ElementWiseReduceParam> {
223+ int num_args;
224+ DMLC_DECLARE_PARAMETER (ElementWiseReduceParam) {
225+ DMLC_DECLARE_FIELD (num_args).set_lower_bound (1 )
226+ .describe (" Number of inputs to be reduced." );
227+ }
228+ };
229+
230+ struct MatMulParam : public dmlc ::Parameter<MatMulParam> {
231+ bool transpose_a;
232+ bool transpose_b;
233+
234+ DMLC_DECLARE_PARAMETER (MatMulParam) {
235+ DMLC_DECLARE_FIELD (transpose_a)
236+ .describe (" If true then transpose the first input before dot." )
237+ .set_default (false );
238+ DMLC_DECLARE_FIELD (transpose_b)
239+ .describe (" If true then transpose the second input before dot." )
240+ .set_default (false );
171241 }
172242};
173243
0 commit comments