add fake_quantization op conversion in onnx #4512
Conversation
def set_moving_max_min_value():
    max_key, min_key = "", ""
Set the initial values to None here (None is guaranteed to be an invalid value), then check at line 31 that both keys are not None.
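The sentinel pattern being suggested could look like the sketch below; the key names and the assignment step are illustrative, not the actual converter code:

```python
def set_moving_max_min_value():
    # None is the sentinel: unlike an empty string, None can never be
    # mistaken for a legitimate key, so a missed assignment is caught.
    max_key, min_key = None, None
    # ... converter logic that is expected to assign both keys ...
    max_key, min_key = "moving_max", "moving_min"  # hypothetical assignment
    assert max_key is not None and min_key is not None
    return max_key, min_key
```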
    node.output_tensor_names[0],
    name=id_util.UniqueStr(node.name),
)
if opset == 10:
Change this to opset < 13, and change line 173 to else.
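The suggested guard could be sketched like this (illustrative only; opset 13 is when ONNX's QuantizeLinear gained a per-axis `axis` attribute, which is why it is the natural boundary):

```python
def choose_quantize_path(opset):
    # One branch covers every pre-13 opset (10, 11, 12) instead of
    # testing a single version with `opset == 10`.
    if opset < 13:
        return "per-tensor"
    else:
        return "per-axis-capable"
```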
def test_fake_quantization_affine_gpu_moving_average(test_case):
    generate_fake_quantization_test_moving_average(
        formula="google", scheme="affine", device_type="gpu"
formula can only be "google", so this parameter can be removed.
Also, take a look at how other unit tests use GenArgDict to generate parameter combinations automatically.
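GenArgDict (a OneFlow test helper) expands an OrderedDict of value lists into every argument combination; the idea can be sketched with itertools (the argument values below are illustrative):

```python
import itertools
from collections import OrderedDict

def gen_arg_dict(arg_lists):
    # Yield one dict per combination of the given argument lists,
    # mimicking what OneFlow's GenArgDict helper does for unit tests.
    keys = list(arg_lists.keys())
    for values in itertools.product(*arg_lists.values()):
        yield dict(zip(keys, values))

# "formula" is dropped entirely since "google" is its only legal value.
arg_dict = OrderedDict(
    scheme=["symmetric", "affine"],
    device_type=["cpu", "gpu"],
)
combos = list(gen_arg_dict(arg_dict))  # 2 schemes x 2 devices = 4 combos
```

One loop over `combos` then replaces the hand-written per-combination test functions such as `test_fake_quantization_affine_gpu_moving_average`.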
@@ -123,6 +124,7 @@ std::string OpTypeName4OpNode(const OpNode* node) {

using OpConfMap = HashMap<std::string, OperatorConf>;

template<DataType VALUE_TYPE = DataType::kFloat>
No need for uppercase here; just call it data_type.
    return GenLogicalBlobName(train_step_var.name(),
                              train_step_var.variable_conf().out());
  }
}())
Please rewrite this as ordinary sequential code; defining a lambda and immediately invoking it like this is not very readable. Even if you gave the lambda a name before calling it: first, this code only ever runs once; second, it has a side effect (it creates a variable op), so it must not be executed more than once anyway. So just change it to sequential code.
@@ -255,22 +255,19 @@ user_op::UserOpConfWrapper MovingMinMaxObserver(const std::string& name, const s
      Get1DZeroVariableOpConf(moving_max_name, scope_symbol_id, 1, inserted_ops);
  const auto moving_min_var =
      Get1DZeroVariableOpConf(moving_min_name, scope_symbol_id, 1, inserted_ops);
  std::string train_step_value = train_step_lbn;
Use a clearer name here, e.g. observer_current_train_step.
* add min_max_observer and moving_average_min_max_observer conversion in onnx
* fix op's name
* fix moving_average_min_max_observer
* update quantization ops conversion
* update quantization ops conversion and tests
* update ops version
* delete auto-imported package
* update min_max_observer op
* format
* fix test_quantization_aware_training
* update test_quantize_op
* add fake_quantization conversion in onnx
* format quantize.py
* fix fake_quantization
* update fake_quantization conversion
* update fake_quantization conversion and its test
* update quantization_aware_training
* update quantization_aware_training

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Former-commit-id: 1fef03a
No description provided.