@@ -128,37 +128,65 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift")
128
128
PrimExpr q = call->args [2 ];
129
129
PrimExpr s = call->args [3 ];
130
130
131
- // Only int32 types are supported (any number of lanes is allowed)
132
- ICHECK (y.dtype ().code () == DLDataTypeCode::kDLInt && y.dtype ().bits () == 32 );
133
- ICHECK (s.dtype ().code () == DLDataTypeCode::kDLInt && s.dtype ().bits () == 32 );
134
-
135
- DataType hp_dtype = DataType::Int (64 , x.dtype ().lanes ());
136
- DataType lp_dtype = DataType::Int (32 , x.dtype ().lanes ());
137
-
138
- // 1) Calculating the integer multiplier and integer shift
139
- PrimExpr zero = make_const (s.dtype (), 0 );
140
- PrimExpr left_shift = tir::Select (s > zero, s, zero);
141
- PrimExpr right_shift = tir::Select (s > zero, zero, -s);
142
-
143
- // 2) Cast and Multiply the integer multiplier
144
- PrimExpr one = make_const (hp_dtype, 1 );
145
- x = cast (hp_dtype, x);
146
- y = cast (hp_dtype, y);
147
- x = tir::Select (left_shift != zero, x << left_shift, x);
148
-
149
- // 3) Perform the multiplication in higher precision.
150
- x = x * y;
151
-
152
- // 4) Find the rounding scalar
153
- PrimExpr total_right_shift = right_shift + q;
154
- PrimExpr pos_rounding_value = (one << (total_right_shift - 1 ));
155
- x = x + pos_rounding_value;
156
-
157
- // 5) Simply right shift the result to get the final output.
158
- x = x >> total_right_shift;
159
-
160
- // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
161
- *rv = cast (lp_dtype, x);
131
+ // Lambda function to extract the int value from PrimExpr
132
+ auto get_int_value = [](const PrimExpr node) {
133
+ auto broadcast_node = node.as <BroadcastNode>();
134
+ CHECK (broadcast_node != nullptr );
135
+ auto int_node = broadcast_node->value .as <IntImmNode>();
136
+ CHECK (int_node != nullptr );
137
+ return int_node->value ;
138
+ };
139
+ // Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of 2,
140
+ // fixed point multiplier will represent a float value of 0.5. In fixed point, this is
141
+ // represented by 1 << 30.
142
+ if (get_int_value (y) == (1 << 30 )) {
143
+ PrimExpr exp = s - 1 ;
144
+ int exp_val = get_int_value (s) - 1 ;
145
+ if (exp_val > 0 ) {
146
+ // power of 2 is greater than 0, apply left shift.
147
+ *rv = x << exp;
148
+ } else {
149
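+ // power of 2 is less than or equal to 0, round and then apply right shift.
+ // NOTE(review): when it is exactly 0 (s == 1), exp is 0 and (exp - 1) below is -1 -- confirm.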
150
+ DataType lp_dtype = DataType::Int (32 , x.dtype ().lanes ());
151
+ PrimExpr one = make_const (lp_dtype, 1 );
152
+ exp = -exp;
153
+ PrimExpr rounding_factor = one << (exp - 1 );
154
+ PrimExpr rounded_t = x + rounding_factor;
155
+ *rv = rounded_t >> exp;
156
+ }
157
+ } else {
158
+ // Only int32 types are supported (any number of lanes is allowed)
159
+ ICHECK (y.dtype ().code () == DLDataTypeCode::kDLInt && y.dtype ().bits () == 32 );
160
+ ICHECK (s.dtype ().code () == DLDataTypeCode::kDLInt && s.dtype ().bits () == 32 );
161
+
162
+ DataType hp_dtype = DataType::Int (64 , x.dtype ().lanes ());
163
+ DataType lp_dtype = DataType::Int (32 , x.dtype ().lanes ());
164
+
165
+ // 1) Calculating the integer multiplier and integer shift
166
+ PrimExpr zero = make_const (s.dtype (), 0 );
167
+ PrimExpr left_shift = tir::Select (s > zero, s, zero);
168
+ PrimExpr right_shift = tir::Select (s > zero, zero, -s);
169
+
170
+ // 2) Cast and Multiply the integer multiplier
171
+ PrimExpr one = make_const (hp_dtype, 1 );
172
+ x = cast (hp_dtype, x);
173
+ y = cast (hp_dtype, y);
174
+ x = tir::Select (left_shift != zero, x << left_shift, x);
175
+
176
+ // 3) Perform the multiplication in higher precision.
177
+ x = x * y;
178
+
179
+ // 4) Find the rounding scalar
180
+ PrimExpr total_right_shift = right_shift + q;
181
+ PrimExpr pos_rounding_value = (one << (total_right_shift - 1 ));
182
+ x = x + pos_rounding_value;
183
+
184
+ // 5) Simply right shift the result to get the final output.
185
+ x = x >> total_right_shift;
186
+
187
+ // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
188
+ *rv = cast (lp_dtype, x);
189
+ }
162
190
});
163
191
164
192
} // namespace intrin
0 commit comments