[RFC] AlterOpLayout Pass Refactoring #3670

@yzhliu

This is the follow-up issue for https://discuss.tvm.ai/t/rfc-functionality-of-alteroplayout-and-possible-refactoring/ on enhancing the AlterOpLayout pass.

The refactoring is meant to address the following limitations of the current AlterOpLayout pass:

  • The altered op must have the same number of arguments as the original one.
  • Nested tuple arguments are not supported.
  • It cannot properly handle the case where the altered operator has a different type, e.g., conv2d replaced by conv2d+add, which @anijain2305 recently needed for quantization.

It is extremely difficult (or even impossible) to address the above problems in the current design. The detailed reasons can be found in the "Expand me to read the functionality of AlterOpLayout, as well as the motivation of doing refactoring." part of the post on discuss.

I would like to propose four passes to replace the current AlterOpLayout pass:

  • Layout inference pass
    To infer the layout of each layer.
    example:
    https://github.com/yzhliu/tvm-1/blob/refactor_alter_layout/src/relay/op/nn/convolution.cc#L144
    https://github.com/yzhliu/tvm-1/blob/refactor_alter_layout/tests/python/relay/test_pass_infer_layout.py

  • Rewrite operator pass
    This pass rewrites an operator into another operator (or set of operators), while the shape/dtype/layout of the input and output must remain the same. The API looks like:

    @conv2d_rewrite_op.register("cpu")
    def _rewrite_conv2d(attrs, inputs, tinfo):

    This can be used to convert
    (NCHW) -> conv2d -> (NCHW)
    to
    (NCHW) -> LT(NCHW->NCHW16c) -> conv2d_NCHW16c -> LT(NCHW16c -> NCHW) -> (NCHW)

  • Propagate layout pass
    This pass converts other operators to use the layout of their preceding operator. It can be used to convert
    conv2d_NCHW16c -> LT(NCHW16c->NCHW) -> add
    to
    conv2d_NCHW16c -> LT(NCHW16c->NCHW) -> LT(NCHW->NCHW16c) -> add -> LT(NCHW16c->NCHW)
    The API looks like the following (these can be pre-defined rules):

    @add_propagate_layout()
    propagate_layout_add(origin=["NCHW", "CHW"], preferred=["NCHW16c", "CHW16c"])

  • Peephole pass
    This pass removes unnecessary layout transform operators. It can be used to convert
    conv2d_NCHW16c -> LT(NCHW16c->NCHW) -> LT(NCHW->NCHW16c) -> add -> LT(NCHW16c->NCHW)
    to
    conv2d_NCHW16c -> add -> LT(NCHW16c->NCHW)
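To make the interaction of the last three passes concrete, here is a minimal toy sketch in plain Python. It is not TVM code: an operator chain is modeled as a list, layout transforms as `("LT", src, dst)` tuples, and the names `LT`, `rewrite_ops`, `propagate_layout`, `peephole` and the `rules` dictionary are all hypothetical, chosen only to mirror the examples above.

```python
# Toy model only: none of these helpers exist in TVM.
# A chain is a list of operator names (str) and layout transforms (tuples).

def LT(src, dst):
    """Represent a layout_transform from layout `src` to layout `dst`."""
    return ("LT", src, dst)

def rewrite_ops(chain):
    """Rewrite operator pass: replace conv2d by a blocked-layout conv2d
    wrapped in layout transforms, so its external layout stays NCHW."""
    out = []
    for op in chain:
        if op == "conv2d":
            out += [LT("NCHW", "NCHW16c"), "conv2d_NCHW16c", LT("NCHW16c", "NCHW")]
        else:
            out.append(op)
    return out

def propagate_layout(chain, rules):
    """Propagate layout pass: if an op with an (origin, preferred) rule directly
    follows LT(preferred -> origin), run the op in the preferred layout."""
    out = []
    for op in chain:
        prev = out[-1] if out else None
        rule = rules.get(op) if isinstance(op, str) else None
        if rule is not None and prev == LT(rule[1], rule[0]):
            origin, preferred = rule
            out += [LT(origin, preferred), op, LT(preferred, origin)]
        else:
            out.append(op)
    return out

def peephole(chain):
    """Peephole pass: cancel adjacent LT(a -> b), LT(b -> a) pairs."""
    out = []
    for op in chain:
        is_inverse = (
            bool(out)
            and isinstance(op, tuple)
            and isinstance(out[-1], tuple)
            and out[-1][1:] == (op[2], op[1])
        )
        if is_inverse:
            out.pop()  # drop the previous transform together with this one
        else:
            out.append(op)
    return out

# Hypothetical pre-defined rule for add: origin layout -> preferred layout.
rules = {"add": ("NCHW", "NCHW16c")}

chain = ["conv2d", "add"]
rewritten = rewrite_ops(chain)
propagated = propagate_layout(rewritten, rules)
optimized = peephole(propagated)
print(optimized)
```

Keeping the three steps separate means each pass only needs a local rule: the rewrite pass never looks at neighbouring operators, the propagation pass only inspects the immediately preceding transform, and the peephole pass only has to recognize inverse pairs. A real implementation would operate on Relay expressions rather than flat lists and would have to handle multiple inputs and tuples.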

@tqchen @merrymercy @anijain2305
