Skip to content

Conversation

@shangzhizhou
Copy link
Member

PR types

Others

PR changes

Others

Describe

Add trt layer norm dynamic. cherry-pick #33293

* add dynamic layer_norm plugin

* fix bug

* fix numpy.allclose

* fix format

* fix code style

* remove shepe in dynamic shape

* code format

* remove layer norm fp16

* fix format
@paddle-bot-old
Copy link

paddle-bot-old bot commented Jun 9, 2021

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

cudaMalloc(&variance_gpu_half_d_, variance_shape_product * sizeof(half));
}

half *scale_cpu_half =
Copy link
Contributor

@Superjomn Superjomn Jun 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用 unique_ptr 吧,避免裸用 malloc, free,防止内存泄露

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,删除了这部分代码。layer_norm的fp16实现精度不满足用户需求,暂时只提供fp32的支持。

…PaddlePaddle#33535)

* 1, remove layernorm dynamic fp16; 2, let reshape out in dynamic shape

* remove useless code
Copy link
Contributor

@Superjomn Superjomn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Superjomn Superjomn merged commit e5bd7eb into PaddlePaddle:release/2.1 Jun 16, 2021
@shangzhizhou shangzhizhou deleted the add_trt_layer_norm_dynamic branch June 18, 2021 08:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants