@@ -2102,20 +2102,39 @@ class RewriteQuantizedAddOp : public OpRewritePattern<stablehlo::AddOp> {
2102
2102
}
2103
2103
};
2104
2104
2105
+ // Rewrites quantized `stablehlo.constant` to `tfl.pseudo_qconst`.
2106
+ class RewriteQuantizedConstantOp
2107
+ : public OpRewritePattern<stablehlo::ConstantOp> {
2108
+ public:
2109
+ using OpRewritePattern<stablehlo::ConstantOp>::OpRewritePattern;
2110
+
2111
+ LogicalResult match (stablehlo::ConstantOp op) const override {
2112
+ return success (IsQuantizedTensorType (op.getOutput ().getType ()));
2113
+ }
2114
+
2115
+ void rewrite (stablehlo::ConstantOp op,
2116
+ PatternRewriter& rewriter) const override {
2117
+ rewriter.replaceOpWithNewOp <TFL::QConstOp>(
2118
+ op, /* qtype=*/ TypeAttr::get (op.getOutput ().getType ()),
2119
+ /* value=*/ op.getValue ());
2120
+ }
2121
+ };
2122
+
2105
2123
void UniformQuantizedStableHloToTflPass::runOnOperation () {
2106
2124
func::FuncOp func_op = getOperation ();
2107
2125
MLIRContext& ctx = getContext ();
2108
2126
2109
2127
RewritePatternSet patterns (&ctx);
2110
2128
patterns.add <RewriteUniformDequantizeOp, RewriteUniformQuantizeOp,
2111
- RewriteQuantizedBroadcastInDimOp, RewriteQuantizedConcatenateOp,
2129
+ RewriteQuantizedAddOp, RewriteQuantizedBroadcastInDimOp,
2130
+ RewriteQuantizedConcatenateOp, RewriteQuantizedConstantOp,
2112
2131
RewriteQuantizedConvolutionOp,
2113
2132
RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp,
2114
2133
RewriteQuantizedDynamicReshapeOp, RewriteQuantizedDynamicSliceOp,
2115
2134
RewriteQuantizedGatherOp, RewriteQuantizedPadOp,
2116
2135
RewriteQuantizedReduceWindowOpWithMax, RewriteQuantizedReshapeOp,
2117
2136
RewriteQuantizedSelectOp, RewriteQuantizedSliceOp,
2118
- RewriteQuantizedTransposeOp, RewriteQuantizedAddOp >(&ctx);
2137
+ RewriteQuantizedTransposeOp>(&ctx);
2119
2138
2120
2139
if (failed (applyPatternsAndFoldGreedily (func_op, std::move (patterns)))) {
2121
2140
func_op.emitError () << " Failed to convert stablehlo ops with uniform "
0 commit comments