Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support tensorrt in qat #5287

Merged
merged 10 commits into from
Jun 24, 2021
10 changes: 4 additions & 6 deletions oneflow/core/job_rewriter/quantization_aware_training.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ std::string QuantizationSchemeAttr4QatConfig(const QatConfig& qat_config) {
// TODO: refactor the following 4 methods by registration
std::string QuantizationFormulaAttr4QatConfig(const QatConfig& qat_config) {
const auto target_backend = qat_config.target_backend();
if (target_backend == "" || target_backend == "tensorrt7") {
if (target_backend == "" || target_backend == "tensorrt") {
return "google";
} else if (target_backend == "cambricon") {
return "cambricon";
Expand All @@ -183,18 +183,16 @@ OpTypeSet Int8List4QatConfig(const QatConfig& qat_config) {
const auto target_backend = qat_config.target_backend();
if (target_backend == "") {
return {"add_n", "matmul", "batch_matmul", "conv2d", "avg_pool_2d", "max_pool_2d"};
} else if (target_backend == "cambricon") {
} else if (target_backend == "cambricon" || target_backend == "tensorrt") {
return {"conv2d", "matmul"};
} else if (target_backend == "tensorrt7") {
return {"conv2d"};
} else {
UNIMPLEMENTED();
}
}

OpTypeSet TransparentList4QatConfig(const QatConfig& qat_config) {
const auto target_backend = qat_config.target_backend();
if (target_backend == "" || target_backend == "tensorrt7") {
if (target_backend == "" || target_backend == "tensorrt") {
return {"reshape"};
} else if (target_backend == "cambricon") {
return {};
Expand All @@ -205,7 +203,7 @@ OpTypeSet TransparentList4QatConfig(const QatConfig& qat_config) {

bool InsertQuantOpAfterInt8Ops4QatConfig(const QatConfig& qat_config) {
const auto target_backend = qat_config.target_backend();
if (target_backend == "" || target_backend == "tensorrt7") {
if (target_backend == "" || target_backend == "tensorrt") {
return true;
} else if (target_backend == "cambricon") {
return false;
Expand Down