Description
This is the follow-up issue for https://discuss.tvm.ai/t/rfc-functionality-of-alteroplayout-and-possible-refactoring/. The goal is to enhance the AlterOpLayout pass.
The refactoring is meant to address the following limitations of the current AlterOpLayout pass:
- The altered op must have the same number of arguments as the original one.
- Nested tuple arguments are not supported.
- It cannot properly handle the case where the altered operator has a different type, e.g., conv2d replaced by conv2d+add, which @anijain2305 recently needed for quantization.
It is extremely difficult (or even impossible) to address the above problems in the current design. The detailed reasons can be found in the "Expand me to read the functionality of AlterOpLayout, as well as the motivation of doing refactoring." part of the post on discuss.
I would like to propose 4 passes to replace the current AlterOpLayout pass:
- Layout inference pass
  This pass infers the layout of each layer.
  Example:
  https://github.com/yzhliu/tvm-1/blob/refactor_alter_layout/src/relay/op/nn/convolution.cc#L144
  https://github.com/yzhliu/tvm-1/blob/refactor_alter_layout/tests/python/relay/test_pass_infer_layout.py
- Rewrite operator pass
  This pass rewrites an operator into another operator (or set of operators), while the shape, dtype, and layout of the inputs and outputs must remain the same. API:
      @conv2d_rewrite_op.register("cpu")
      def _rewrite_conv2d(attrs, inputs, tinfo):
          ...
  This can be used to convert
      (NCHW) -> conv2d -> (NCHW)
  to
      (NCHW) -> LT(NCHW->NCHW16c) -> conv2d_NCHW16c -> LT(NCHW16c->NCHW) -> (NCHW)
  where LT denotes a layout transform operator.
- Propagate layout pass
  This pass converts other operators to use the layout of their preceding operator. It can be used to convert
      conv2d_NCHW16c -> LT(NCHW16c->NCHW) -> add
  to
      conv2d_NCHW16c -> LT(NCHW16c->NCHW) -> LT(NCHW->NCHW16c) -> add -> LT(NCHW16c->NCHW)
  The API looks like the following (the rules can be pre-defined):
      @add_propagate_layout()
      propagate_layout_add(origin=["NCHW", "CHW"], preferred=["NCHW16c", "CHW16c"])
- Peephole pass
  This pass removes unnecessary layout transform operators. It can be used to convert
      conv2d_NCHW16c -> LT(NCHW16c->NCHW) -> LT(NCHW->NCHW16c) -> add -> LT(NCHW16c->NCHW)
  to
      conv2d_NCHW16c -> add -> LT(NCHW16c->NCHW)
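
To make the interaction of the last three passes concrete, here is a minimal sketch on a toy IR. It is not Relay code: the tuple-based node representation and the helper names (`op`, `lt`, `rewrite_conv2d`, `propagate_layout`, `peephole`, `show`) are assumptions made only for illustration, and the model is restricted to a linear chain of operators. It also assumes that a packing transform followed by its inverse is an identity and can therefore be dropped.

```python
from typing import List, Tuple

# Toy IR (hypothetical, not Relay): a model is a linear chain of nodes.
# ("op", name) is a compute operator; ("LT", src, dst) is a layout transform.
Node = Tuple[str, ...]
PLAIN, PACKED = "NCHW", "NCHW16c"


def op(name: str) -> Node:
    return ("op", name)


def lt(src: str, dst: str) -> Node:
    return ("LT", src, dst)


def rewrite_conv2d(chain: List[Node]) -> List[Node]:
    """Rewrite operator pass: conv2d becomes LT -> conv2d_NCHW16c -> LT,
    so the layout seen by neighbouring nodes is unchanged."""
    out: List[Node] = []
    for node in chain:
        if node == op("conv2d"):
            out += [lt(PLAIN, PACKED), op("conv2d_" + PACKED), lt(PACKED, PLAIN)]
        else:
            out.append(node)
    return out


def propagate_layout(chain: List[Node], follower_ops=("add",)) -> List[Node]:
    """Propagate layout pass: an op that directly follows LT(packed->plain)
    is wrapped so that it runs in the packed layout of its predecessor."""
    out: List[Node] = []
    for node in chain:
        follows_unpack = bool(out) and out[-1] == lt(PACKED, PLAIN)
        if node[0] == "op" and node[1] in follower_ops and follows_unpack:
            out += [lt(PLAIN, PACKED), node, lt(PACKED, PLAIN)]
        else:
            out.append(node)
    return out


def peephole(chain: List[Node]) -> List[Node]:
    """Peephole pass: drop adjacent LT(a->b), LT(b->a) pairs."""
    out: List[Node] = []
    for node in chain:
        prev = out[-1] if out else None
        if node[0] == "LT" and prev is not None and prev == lt(node[2], node[1]):
            out.pop()  # the two transforms cancel each other
        else:
            out.append(node)
    return out


def show(chain: List[Node]) -> str:
    return " -> ".join(
        n[1] if n[0] == "op" else f"LT({n[1]}->{n[2]})" for n in chain
    )


result = peephole(propagate_layout(rewrite_conv2d([op("conv2d"), op("add")])))
print(show(result))
# LT(NCHW->NCHW16c) -> conv2d_NCHW16c -> add -> LT(NCHW16c->NCHW)
```

The point of the split is that each pass stays local: the rewrite and propagate steps are free to insert redundant layout transforms, because the peephole step cleans them up afterwards.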