
[CINN] TileBroadcastTactic NHWC layout broadcast support #71434


Closed
wants to merge 188 commits

Conversation

Contributor

@Enigmatisms Enigmatisms commented Mar 5, 2025

NOTE: the same PR is posted again as #71464, since the current PR has some problems with its git history. The current PR is therefore discarded.

PR Category

CINN

PR Types

Improvements

Description

Implement the NHWC layout support for TileBroadcastTactic introduced in #70092.

Introduction

The original TileBroadcastTactic only handles broadcasts of the following form:

[1, C, 1, 1] => [N, C, H, W]

For NHWC layout broadcasts, CINN falls back to TileFirstGeneralTactic since the last axis is not a broadcast axis. This causes performance issues, such as using a fairly static block size for tensors with different channel counts, which can produce excessive GMEM loads.

The extended TileBroadcastTactic also covers the following case:

[1, 1, 1, C] => [N, H, W, C]

as long as the last axis is a preserved axis (see #70092 for the term definition).

Performance Impact

This tactic extension avoids excessive loads in a simple way: find an appropriate block size K that satisfies all of the following (a selection sketch follows the list):

  • C % K == 0 (K evenly divides C)
  • K is a multiple of 32
  • K should be as close to C as possible.
  • K is within a certain range, for example: [128, 1024].
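
Below is a minimal host-side sketch of this selection, assuming a simple downward scan (an illustration only; the function name, bounds, and search order are assumptions, not the exact CINN implementation):

  #include <cstdint>

  // Sketch: return the largest multiple of 32 in [lo, hi] that evenly
  // divides C, or -1 if no such block size exists (the caller would then
  // fall back to another tactic).
  int64_t SelectBlockSize(int64_t C, int64_t lo = 128, int64_t hi = 1024) {
    for (int64_t K = hi; K >= lo; K -= 32) {  // multiples of 32, largest first
      if (C % K == 0) return K;  // C % K == 0 keeps the load index invariant
    }
    return -1;
  }

For the (64, 56, 56, 192) example below, this scan returns K = 192 == C, which is exactly the invariant-index case.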

The first requirement is the most important one: it eliminates excessive loads by making the load index invariant to the thread-coarsening loop index. Here is an example:

Without this extension, for an input tensor of shape (64, 56, 56, 192), the block size is 256. With thread coarsening, each thread then loads one value from GMEM per loop iteration, totaling 4 loads per tensor that needs broadcasting.

// example: (192,) -> (64, 56, 56, 192) tensor broadcast
for (int32_t thread_loop_i = 0; thread_loop_i < 4; thread_loop_i += 1) {
  // block size 256: the load index depends on thread_loop_i, so every
  // iteration issues a fresh GMEM load (4 loads in total)
  float var_0_local = var_0[((((thread_loop_i * 256) + (int)threadIdx.x) + ((int)blockIdx.x * 1024)) % 192)];
  /* ... */
  to_broad_cast[/* ... */] = some_func(var_0_local, /* ... */);
}

With this extension, we only need to load once, since the load index is invariant to the loop index, effectively reducing the number of loads required:

// example: (192,) -> (64, 56, 56, 192) tensor broadcast
for (int32_t thread_loop_i = 0; thread_loop_i < 4; thread_loop_i += 1) {
  // block size 192 (== C): the load index is invariant to thread_loop_i,
  // so the value is effectively loaded once and reused across iterations
  float var_0_local = var_0[threadIdx.x];
  /* ... */
  to_broad_cast[/* ... */] = some_func(var_0_local, /* ... */);
}

Limitations

  • For tensors with large C, the block size cannot be too large (otherwise occupancy suffers), so the load index will not be invariant and the loaded data will not be reused. This extension offers a simple (but imperfect) mitigation: disabling thread coarsening to reduce register pressure, which in turn increases occupancy (see the sketch after this list). A full solution must address the register requirement itself.
  • This extension supports broadcasts of any shape as long as the broadcast form is [B, P], for example (1, H, 1, C) -> (N, H, W, C). However, such cases have not been tested locally and may be erroneous.
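
As a rough illustration of the no-coarsening mitigation mentioned in the first bullet (a hypothetical sketch based on the description above, not the actual generated code), the per-thread body collapses to a single load and store, so no registers are needed to carry values across coarsening iterations:

  // hypothetical kernel body with thread coarsening disabled: one element
  // per thread; the GMEM value is consumed immediately, keeping register
  // pressure low and occupancy high for large C
  int32_t idx = ((int)blockIdx.x * (int)blockDim.x) + (int)threadIdx.x;
  float var_0_local = var_0[idx % C];  // C is the channel count
  to_broad_cast[/* ... */] = some_func(var_0_local, /* ... */);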

Experiment Results

Tested on the BatchNorm op (data_layout="NHWC"), forward broadcast kernel, on a V100:

shape            max bandwidth (%)   runtime (us)   bandwidth improvement (%)   runtime change (%)
256,64,64,192    92.26               1940           35.56                       -26.24
128,72,72,224    92.33               1430           64.43                       -39.14
256,48,48,224    92.31               1270           67.29                       -40.37
400,72,72,384    92.30               7680           62.10                       -38.13

Pcard-89620


paddle-bot bot commented Mar 5, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.


CLAassistant commented Mar 6, 2025

CLA assistant check
All committers have signed the CLA.

@Enigmatisms Enigmatisms closed this Mar 6, 2025