
Commit 9926a40

jenriver authored and tensorflower-gardener committed
Add quantization/legalization for stablehlo.add and respective pipeline changes.
* Added `enable_full_int_quantization` to `StaticRangePtqPreset` to control full int quantization. It defaults to `false`, meaning only compute-heavy ops are quantized unless it is explicitly enabled.
* Added tests for the above config change.
* Follow-up tests will include e2e Python tests.

PiperOrigin-RevId: 620067140
1 parent 2d48a6d commit 9926a40

19 files changed: +363 additions, -186 deletions
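As a quick orientation for the config change described above, here is a minimal usage sketch. The proto setters and `ExpandPresets()` are taken from this commit's config.cc/config_test.cc; the include path, namespace, and helper name are assumptions, not part of the commit.

// Sketch only: include path, namespace, and helper name are assumptions.
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h"

namespace stablehlo::quantization {

// Builds a static-range PTQ config that quantizes all quantizable ops,
// not only the compute-heavy conv/dot/gather composites.
QuantizationConfig BuildFullIntPtqConfig() {
  QuantizationConfig config{};
  config.mutable_static_range_ptq_preset()->set_enable_full_int_quantization(
      true);
  // With the flag set, preset expansion emits a default spec whose matcher
  // regex is ".*"; without it, the regex is "^.*(conv|dot|gather).*".
  return ExpandPresets(config);
}

}  // namespace stablehlo::quantization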

tensorflow/compiler/mlir/lite/stablehlo/tests/uniform-quantized-stablehlo-to-tfl.mlir

Lines changed: 103 additions & 69 deletions
Large diffs are not rendered by default.

tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc

Lines changed: 79 additions & 85 deletions
Large diffs are not rendered by default.

tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc

Lines changed: 5 additions & 3 deletions
@@ -98,10 +98,11 @@ void PopulateDefaultCalibrationOptions(QuantizationConfig& quant_config) {
 // {matcher {function_name {regex: ".*"}}
 //  {method {static_range_ptq {}}}
 // }
-QuantizationSpec GetDefaultStaticRangePtqSpec() {
+QuantizationSpec GetDefaultStaticRangePtqSpec(StaticRangePtqPreset preset) {
   QuantizationSpec spec{};
   // Default for all ops.
-  spec.mutable_matcher()->mutable_function_name()->set_regex(".*");
+  spec.mutable_matcher()->mutable_function_name()->set_regex(
+      preset.enable_full_int_quantization() ? ".*" : "^.*(conv|dot|gather).*");
   spec.mutable_method()->mutable_static_range_ptq();

   return spec;
@@ -161,7 +162,8 @@ void ExpandStaticRangePtqPreset(const StaticRangePtqPreset& preset,
   // expansion from `StaticRangePtqPreset` gets populated first and then
   // user-provided explicit `QuantizationSpec`s will be appended.
   QuantizationSpecs new_specs{};
-  *new_specs.add_specs() = GetDefaultStaticRangePtqSpec();
+  *new_specs.add_specs() =
+      GetDefaultStaticRangePtqSpec(/*preset=*/config.static_range_ptq_preset());
   *new_specs.add_specs() = GetStaticRangePtqSpecForConvolution();

   // Append user-provided specs to override existing specs.
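To make the regex change concrete, here is a small illustration (not from this commit) of which lifted function names each default matcher selects. std::regex stands in for the framework's own matcher, and the composite function names are representative examples, so treat this purely as a sketch.

// Illustration only: std::regex stands in for the framework's matcher, and
// the composite function names are representative examples.
#include <iostream>
#include <regex>
#include <string>
#include <vector>

int main() {
  const std::regex compute_heavy_only("^.*(conv|dot|gather).*");  // default preset
  const std::regex full_int(".*");  // enable_full_int_quantization = true
  const std::vector<std::string> names = {
      "composite_conv_fn_1", "composite_dot_general_fn_1",
      "composite_gather_fn_1", "composite_add_fn_1"};
  for (const std::string& name : names) {
    std::cout << name
              << "  compute-heavy preset: " << std::regex_match(name, compute_heavy_only)
              << "  full-int preset: " << std::regex_match(name, full_int) << "\n";
  }
  // composite_add_fn_1 is only matched when full int quantization is enabled.
  return 0;
}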

tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc

Lines changed: 20 additions & 2 deletions
@@ -147,10 +147,12 @@ TEST(ExpandPresetsTest, ExpandUnspecifiedPreset) {
   EXPECT_FALSE(new_config.has_pipeline_config());
 }

-TEST(ExpandPresetsTest, ExpandStaticRangePtqPreset) {
+TEST(ExpandPresetsTest, ExpandStaticRangePtqEnableFullIntquantization) {
   QuantizationConfig config{};
   RepresentativeDatasetConfig& preset_dataset_config =
       *config.mutable_static_range_ptq_preset()->add_representative_datasets();
+  config.mutable_static_range_ptq_preset()->set_enable_full_int_quantization(
+      true);
   preset_dataset_config.mutable_tf_record()->set_path("/test/path");

   const QuantizationConfig new_config = ExpandPresets(config);
@@ -185,6 +187,21 @@ TEST(ExpandPresetsTest, ExpandStaticRangePtqPreset) {
               StrEq("/test/path"));
 }

+TEST(ExpandPresetsTest, ExpandStaticRangePtqPresetDefault) {
+  QuantizationConfig config{};
+  RepresentativeDatasetConfig& preset_dataset_config =
+      *config.mutable_static_range_ptq_preset()->add_representative_datasets();
+  preset_dataset_config.mutable_tf_record()->set_path("/test/path");
+
+  const QuantizationConfig new_config = ExpandPresets(config);
+  ASSERT_THAT(new_config.specs().specs(), SizeIs(2));
+
+  const QuantizationSpec& spec = new_config.specs().specs(0);
+  EXPECT_THAT(spec.matcher().function_name().regex(),
+              StrEq("^.*(conv|dot|gather).*"));
+  EXPECT_TRUE(spec.method().has_static_range_ptq());
+}
+
 TEST(ExpandPresetsTest,
      ExpandStaticRangePtqPresetWithTopLevelRepresentativeDataset) {
   // Test the scenario where both
@@ -216,7 +233,8 @@ TEST(ExpandPresetsTest,

 TEST(ExpandPresetsTest, ExpandStaticRangePtqPresetThenAppendExplicitSpecs) {
   QuantizationConfig config{};
-  config.mutable_static_range_ptq_preset();
+  config.mutable_static_range_ptq_preset()->set_enable_full_int_quantization(
+      true);

   QuantizationSpec& user_provided_spec = *config.mutable_specs()->add_specs();
   user_provided_spec.mutable_matcher()->mutable_function_name()->set_regex(

tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc

Lines changed: 3 additions & 0 deletions
@@ -58,8 +58,11 @@ void AddPostCalibrationPasses(
     OpPassManager& pm, const PipelineConfig& pipeline_config,
     const StaticRangePtqPreset& static_range_ptq_preset) {
   QuantizeCompositeFunctionsPassOptions options;
+  // TODO: b/331120943 - Use QuantizationConfig instead of preset flags.
   options.enable_per_channel_quantized_weight_ =
       static_range_ptq_preset.enable_per_channel_quantized_weight();
+  options.enable_full_int_quantization_ =
+      static_range_ptq_preset.enable_full_int_quantization();
   // For debugging purposes.
   options.mlir_dump_file_name_ = "quantize_composite_functions";
   options.enable_weight_only_ = false;
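For context, a hedged sketch of how a caller might thread the preset's new flag into this pipeline; `AddPostCalibrationPasses` and its parameter types come from the hunk above, while the includes and namespace qualifiers are omitted here and treated as assumptions.

// Sketch only (fragment): includes and namespace qualifiers are omitted and
// would follow pass_pipeline.h and quantization_config.pb.h in this directory.
void BuildPostCalibrationPipeline(mlir::OpPassManager& pm) {
  PipelineConfig pipeline_config;
  StaticRangePtqPreset preset;
  // New in this commit: also quantize ops that are not compute-heavy,
  // e.g. stablehlo.add lifted as composite_add_fn.
  preset.set_enable_full_int_quantization(true);
  AddPostCalibrationPasses(pm, pipeline_config, preset);
}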

tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td

Lines changed: 8 additions & 0 deletions
@@ -67,3 +67,11 @@ def LiftGather : Pat<
       (NamedAttr<"slice_sizes"> $slice_sizes),
       (NamedAttr<"indices_are_sorted"> (DefaultOrNullAttr $indices_are_sorted)))),
   [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $operand)], [], (addBenefit 1)>;
+
+def LiftAdd : Pat<
+  (StableHLO_AddOp:$res
+      $lhs, $rhs),
+  (LiftAsTFXlaCallModule<"composite_add_fn">
+    (ArgumentList $lhs, $rhs),
+    (ResultList $res)),
+  [(IsNotInLiftedFunc $res), (IsNotInStableHloOpRegion $res)], [], (addBenefit 1)>;

tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td

Lines changed: 8 additions & 0 deletions
@@ -60,6 +60,10 @@ def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-function
            "enable-per-channel-quantized-weight",
            "bool", /*default=*/"true",
            "Whether to enable per-channel quantized weights.">,
+    Option<"enable_full_int_quantization_",
+           "enable-full-int-quantization",
+           "bool", /*default=*/"false",
+           "Whether to enable full int quantization, including non compute-heavy ops.">,
     Option<"mlir_dump_file_name_", "mlir-dump-file-name",
            "std::optional<std::string>", /*default=*/"std::nullopt",
            "MLIR dump file name.">,
@@ -102,6 +106,10 @@ def QuantizePass : Pass<"stablehlo-quantize", "mlir::ModuleOp"> {
            "enable-per-channel-quantized-weight",
            "bool", /*default=*/"true",
            "Whether to enable per-channel quantized weights.">,
+    Option<"enable_full_int_quantization_",
+           "enable-full-int-quantization",
+           "bool", /*default=*/"false",
+           "Whether to apply full int quantization, including non compute-heavy ops.">,
     Option<"enable_weight_only_",
            "enable-weight-only",
            "bool", /*default=*/"false",

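Because these tablegen `Option` entries generate standard MLIR pass options, the new flag can in principle also be toggled from a textual pass pipeline. A hedged example, assuming an opt-style driver that registers the `stablehlo-quantize` pass (the driver name below is illustrative, not from this commit):

  quant-opt --pass-pipeline='builtin.module(stablehlo-quantize{enable-full-int-quantization=true})' input.mlir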
tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc

Lines changed: 6 additions & 0 deletions
@@ -954,6 +954,12 @@ void PopulateComputeHeavyPatterns(
   patterns.add<QuantizeOpWithRegionPattern>(ctx);
 }

+void PopulateAllQuantizablePatterns(MLIRContext& ctx,
+                                    RewritePatternSet& patterns) {
+  patterns.add<XlaCallModuleOpToCallOp<QuantizeSingularOpPattern<AddOp>>>(
+      ctx, /*enable_per_channel_quantized_weight=*/false);
+}
+
 void PopulateQuantizeWeightOnlyPatterns(MLIRContext& ctx,
                                         RewritePatternSet& patterns) {
   patterns.add<

tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h

Lines changed: 5 additions & 0 deletions
@@ -254,6 +254,11 @@ class StableHloQuantizationPattern : public OpRewritePattern<RootOpT> {
 void PopulateComputeHeavyPatterns(MLIRContext& ctx, RewritePatternSet& patterns,
                                   bool enable_per_channel_quantized_weight);

+// Populates conversion patterns for all quantizable ops, including
+// ops that are not compute-heavy and data movement ops.
+void PopulateAllQuantizablePatterns(MLIRContext& ctx,
+                                    RewritePatternSet& patterns);
+
 // Populates pattern weight-only quantization.
 void PopulateQuantizeWeightOnlyPatterns(MLIRContext& ctx,
                                         RewritePatternSet& patterns);

tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc

Lines changed: 7 additions & 0 deletions
@@ -96,9 +96,11 @@ class QuantizePass : public impl::QuantizePassBase<QuantizePass> {
   using impl::QuantizePassBase<QuantizePass>::QuantizePassBase;

   explicit QuantizePass(const bool enable_per_channel_quantized_weight,
+                        const bool enable_full_int_quantization,
                         const bool enable_weight_only,
                         const QuantizationSpecs& quant_specs) {
     enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight;
+    enable_full_int_quantization_ = enable_full_int_quantization;
     enable_weight_only_ = enable_weight_only;
   }

@@ -120,6 +122,11 @@ void QuantizePass::runOnOperation() {
   PopulateComputeHeavyPatterns(ctx, patterns,
                                enable_per_channel_quantized_weight_);

+  // Quantize all quantizable ops, including ops that are not compute-heavy.
+  if (enable_full_int_quantization_) {
+    PopulateAllQuantizablePatterns(ctx, patterns);
+  }
+
   if (failed(applyPatternsAndFoldGreedily(module_op, std::move(patterns)))) {
     // There are cases where no rewrites happen even if a pattern matches,
     // causing this to result in a convergence failure. Consider this as a
