
Fix OOM in slim pruning #1123

Merged · 1 commit merged into PaddlePaddle:develop on Jul 6, 2021
Conversation

@wanghaoshuang (Contributor) commented Jul 5, 2021

Problem

Running the following script on a V100 with 16 GB of GPU memory raises an OOM error:

python slim/prune.py \
    --config configs/attention_unet/attention_unet_cityscapes_1024x512_80k.yml \
    --batch_size 1 \
    --retraining_iters 10 --pruning_ratio 0.5 --save_dir prune_model/attention_unet/ 

Fix

In the slim/prune.py file of the PaddleSeg repo, change

sample_shape = [1] + list(val_dataset[0][0].shape)

to:

sample_shape = [1] + list(train_dataset[0][0].shape)

https://github.com/PaddlePaddle/PaddleSeg/blob/develop/slim/prune.py#L152
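For context, here is a minimal sketch of how the sample shape feeds the pruner in slim/prune.py; the import path and the exact constructor call are paraphrased assumptions, not copied from the current file:

from paddleslim.dygraph import L1NormFilterPruner  # assumed import path

# Build a single-sample input shape from the *training* dataset, whose
# Cityscapes crops (3, 512, 1024) are much smaller than the val images.
sample_shape = [1] + list(train_dataset[0][0].shape)  # e.g. [1, 3, 512, 1024]

# Constructing the pruner runs one dygraph forward pass with this shape
# to trace the model into a static graph for structure analysis.
# (Whether the shape is wrapped in a list may differ in the actual file.)
pruner = L1NormFilterPruner(model, [sample_shape])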

Verified that after this fix, GPU memory usage during model training stays under 16 GB.

Cause:

When an L1NormFilterPruner instance is constructed, one forward pass of the dynamic graph is executed with sample_shape in order to convert the dynamic graph into a static graph; the GPU memory consumed by this pass grows with sample_shape. The dynamic-to-static conversion is only used to analyze the model structure, and the choice of sample_shape does not affect the accuracy of that analysis, so the smallest possible sample_shape should be used.

Taking Cityscapes as an example, train_dataset samples have shape (3, 512, 1024), while val_dataset samples have shape (3, 1024, 2048).
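As a quick back-of-the-envelope check (a standalone sketch, not code from the repo), the validation shape contains four times as many pixels as the training shape, so the activations traced during pruner construction grow roughly fourfold:

train_elems = 3 * 512 * 1024   # Cityscapes training crop
val_elems = 3 * 1024 * 2048    # full-resolution validation image
print(val_elems / train_elems) # 4.0 -- roughly 4x activation memory during the trace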

@wanghaoshuang reopened this Jul 6, 2021
@nepeplwu merged commit dcbc00c into PaddlePaddle:develop on Jul 6, 2021