From 45468158833dc688681f59e2a3dbf6569b9fe1a0 Mon Sep 17 00:00:00 2001 From: floveqq <108666823+floveqq@users.noreply.github.com> Date: Sat, 6 May 2023 13:46:44 +0800 Subject: [PATCH 1/7] Update README_cn.md --- README_cn.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README_cn.md b/README_cn.md index 20544565..bb5619ac 100644 --- a/README_cn.md +++ b/README_cn.md @@ -1,3 +1,11 @@ + +## 飞桨黑客松第四期 +赛题155. YOLOv8复现 + + +###################################################################### + + 简体中文 | [English](README_en.md) ## 简介 From 0002ef5b9a8c2891ef5be4c198e78e91c0074702 Mon Sep 17 00:00:00 2001 From: floveqq <940925733@qq.com> Date: Sun, 14 May 2023 09:22:02 +0000 Subject: [PATCH 2/7] yolov8 modify --- README.md | 449 +++++++++++++++++- configs/yolov8/_base_/optimizer_500e.yml | 1 + configs/yolov8/_base_/optimizer_500e_high.yml | 1 + configs/yolov8/_base_/yolov8_cspdarknet.yml | 2 +- configs/yolov8/_base_/yolov8_reader.yml | 2 +- .../yolov8/_base_/yolov8_reader_high_aug.yml | 4 +- ppdet/engine/trainer.py | 3 +- ppdet/modeling/heads/yolov8_head.py | 266 ++++++----- 8 files changed, 602 insertions(+), 126 deletions(-) mode change 120000 => 100644 README.md diff --git a/README.md b/README.md deleted file mode 120000 index 4015683c..00000000 --- a/README.md +++ /dev/null @@ -1 +0,0 @@ -README_cn.md \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 00000000..bb5619ac --- /dev/null +++ b/README.md @@ -0,0 +1,448 @@ + +## 飞桨黑客松第四期 +赛题155. YOLOv8复现 + + +###################################################################### + + +简体中文 | [English](README_en.md) + +## 简介 + +**PaddleYOLO**是基于[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection)的YOLO系列模型库,**只包含YOLO系列模型的相关代码**,支持`YOLOv3`、`PP-YOLO`、`PP-YOLOv2`、`PP-YOLOE`、`PP-YOLOE+`、`YOLOX`、`YOLOv5`、`YOLOv6`、`YOLOv7`、`YOLOv8`、`YOLOv5u`、`YOLOv7u`、`RTMDet`等模型,COCO数据集模型库请参照 [ModelZoo](docs/MODEL_ZOO_cn.md) 和 [configs](configs/)。 + +
+ + +
+ +**注意:** + + - **PaddleYOLO** 代码库协议为 **[GPL 3.0](LICENSE)**,[YOLOv5](configs/yolov5)、[YOLOv6](configs/yolov6)、[YOLOv7](configs/yolov7)和[YOLOv8](configs/yolov8)这几类模型代码不合入[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection),其余YOLO模型推荐在[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection)中使用,**会最先发布PP-YOLO系列特色检测模型的最新进展**; + - **PaddleYOLO**代码库**推荐使用paddlepaddle-2.3.2以上的版本**,请参考[官网](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html)下载对应适合版本,**Windows平台请安装paddle develop版本**; + - **PaddleYOLO 的[Roadmap](https://github.com/PaddlePaddle/PaddleYOLO/issues/44)** issue用于收集用户的需求,欢迎提出您的建议和需求; + +## 教程 + +
+安装 + +Clone 代码库和安装 [requirements.txt](./requirements.txt),环境需要在一个 +[**Python>=3.7.0**](https://www.python.org/) 下的环境,且需要安装 +[**PaddlePaddle>=2.3.2**](https://www.paddlepaddle.org.cn/install/)。 + +```bash +git clone https://github.com/PaddlePaddle/PaddleYOLO # clone +cd PaddleYOLO +pip install -r requirements.txt # install +``` + +
+ +
+训练/验证/预测/ +将以下命令写在一个脚本文件里如```run.sh```,一键运行命令为:```sh run.sh```,也可命令行一句句去运行。 + +```bash +model_name=ppyoloe # 可修改,如 yolov7 +job_name=ppyoloe_plus_crn_s_80e_coco # 可修改,如 yolov7_tiny_300e_coco + +config=configs/${model_name}/${job_name}.yml +log_dir=log_dir/${job_name} +# weights=https://bj.bcebos.com/v1/paddledet/models/${job_name}.pdparams +weights=output/${job_name}/model_final.pdparams + +# 1.训练(单卡/多卡),加 --eval 表示边训边评估,加 --amp 表示混合精度训练 +# CUDA_VISIBLE_DEVICES=0 python tools/train.py -c ${config} --eval --amp +python -m paddle.distributed.launch --log_dir=${log_dir} --gpus 0,1,2,3,4,5,6,7 tools/train.py -c ${config} --eval --amp + +# 2.评估,加 --classwise 表示输出每一类mAP +CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c ${config} -o weights=${weights} --classwise + +# 3.预测 (单张图/图片文件夹) +CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c ${config} -o weights=${weights} --infer_img=demo/000000014439_640x640.jpg --draw_threshold=0.5 +# CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c ${config} -o weights=${weights} --infer_dir=demo/ --draw_threshold=0.5 +``` + +
+ +
+部署/测速 + +将以下命令写在一个脚本文件里如```run.sh```,一键运行命令为:```sh run.sh```,也可命令行一句句去运行。 + +```bash +model_name=ppyoloe # 可修改,如 yolov7 +job_name=ppyoloe_plus_crn_s_80e_coco # 可修改,如 yolov7_tiny_300e_coco + +config=configs/${model_name}/${job_name}.yml +log_dir=log_dir/${job_name} +# weights=https://bj.bcebos.com/v1/paddledet/models/${job_name}.pdparams +weights=output/${job_name}/model_final.pdparams + +# 4.导出模型,以下3种模式选一种 +## 普通导出,加trt表示用于trt加速,对NMS和silu激活函数提速明显 +CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c ${config} -o weights=${weights} # trt=True + +## exclude_post_process去除后处理导出,返回和YOLOv5导出ONNX时相同格式的concat后的1个Tensor,是未缩放回原图的坐标+分类置信度 +# CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c ${config} -o weights=${weights} exclude_post_process=True # trt=True + +## exclude_nms去除NMS导出,返回2个Tensor,是缩放回原图后的坐标和分类置信度 +# CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c ${config} -o weights=${weights} exclude_nms=True # trt=True + +# 5.部署预测,注意不能使用 去除后处理 或 去除NMS 导出后的模型去预测 +CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/${job_name} --image_file=demo/000000014439_640x640.jpg --device=GPU + +# 6.部署测速,加 “--run_mode=trt_fp16” 表示在TensorRT FP16模式下测速,注意如需用到 trt_fp16 则必须为加 trt=True 导出的模型 +CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/${job_name} --image_file=demo/000000014439_640x640.jpg --device=GPU --run_benchmark=True # --run_mode=trt_fp16 + +# 7.onnx导出,一般结合 exclude_post_process去除后处理导出的模型 +paddle2onnx --model_dir output_inference/${job_name} --model_filename model.pdmodel --params_filename model.pdiparams --opset_version 12 --save_file ${job_name}.onnx + +# 8.onnx trt测速 +/usr/local/TensorRT-8.0.3.4/bin/trtexec --onnx=${job_name}.onnx --workspace=4096 --avgRuns=10 --shapes=input:1x3x640x640 --fp16 +/usr/local/TensorRT-8.0.3.4/bin/trtexec --onnx=${job_name}.onnx --workspace=4096 --avgRuns=10 --shapes=input:1x3x640x640 --fp32 +``` + +- 如果想切换模型,只要修改开头两行即可,如: + ``` + model_name=yolov7 + 
job_name=yolov7_tiny_300e_coco + ``` +- 导出**onnx**,首先安装[Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX),`pip install paddle2onnx`; +- **统计FLOPs(G)和Params(M)**,首先安装[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim),`pip install paddleslim`,然后设置[runtime.yml](configs/runtime.yml)里`print_flops: True`和`print_params: True`,并且注意确保是**单尺度**下如640x640,**打印的是MACs,FLOPs=2*MACs**。 + +
+ + +
+ [训练自定义数据集](https://github.com/PaddlePaddle/PaddleYOLO/issues/43) + +- 请参照[文档](docs/MODEL_ZOO_cn.md#自定义数据集)和[issue](https://github.com/PaddlePaddle/PaddleYOLO/issues/43); +- PaddleDetection团队提供了**基于PP-YOLOE的各种垂类检测模型**的配置文件和权重,用户也可以作为参考去使用自定义数据集。请参考 [PP-YOLOE application](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/ppyoloe/application)、[pphuman](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/pphuman)、[ppvehicle](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/ppvehicle)、[visdrone](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/visdrone) 和 [smalldet](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.6/configs/smalldet)。 +- PaddleDetection团队也提供了**VOC数据集的各种YOLO模型**的配置文件和权重,用户也可以作为参考去使用自定义数据集。请参考 [voc](configs/voc)。 +- 训练自定义数据集之前请先**确保加载了对应COCO权重作为预训练**,将配置文件中的`pretrain_weights: `设置为对应COCO模型训好的权重,一般会提示head分类层卷积的通道数没对应上,属于正常现象,是由于自定义数据集一般和COCO数据集种类数不一致; +- YOLO检测模型建议**总`batch_size`至少大于`64`**去训练,如果资源不够请**换小模型**或**减小模型的输入尺度**,为了保障较高检测精度,**尽量不要尝试单卡训和总`batch_size`小于`64`训**; + +
+ + +## 更新日志 + +* 【2023/03/13】支持[YOLOv5u](configs/yolov5/yolov5u)和[YOLOv7u](configs/yolov7/yolov7u)预测和部署; +* 【2023/01/10】支持[YOLOv8](configs/yolov8)预测和部署; +* 【2022/09/29】支持[RTMDet](configs/rtmdet)预测和部署; +* 【2022/09/26】发布[PaddleYOLO](https://github.com/PaddlePaddle/PaddleYOLO)模型套件,请参照[ModelZoo](docs/MODEL_ZOO_cn.md); +* 【2022/09/19】支持[YOLOv6](configs/yolov6)新版,包括n/t/s/m/l模型; +* 【2022/08/23】发布`YOLOSeries`代码库: 支持`YOLOv3`,`PP-YOLOE`,`PP-YOLOE+`,`YOLOX`,`YOLOv5`,`YOLOv6`,`YOLOv7`等YOLO模型,支持`ConvNeXt`骨干网络高精度版`PP-YOLOE`,`YOLOX`和`YOLOv5`等模型,支持PaddleSlim无损加速量化训练`PP-YOLOE`,`YOLOv5`,`YOLOv6`和`YOLOv7`等模型,详情可阅读[此文章](https://mp.weixin.qq.com/s/Hki01Zs2lQgvLSLWS0btrA); + + +## 产品动态 + +- 🔥 **2023.3.14:PaddleYOLO发布[release/2.6版本](https://github.com/PaddlePaddle/PaddleYOLO/tree/release/2.6)** + - 💡 模型套件: + - 支持`YOLOv8`,`YOLOv5u`,`YOLOv7u`等YOLO模型预测和部署; + - 支持`Swin-Transformer`、`ViT`、`FocalNet`骨干网络高精度版`PP-YOLOE+`等模型; + - 支持`YOLOv8`在[FastDeploy](https://github.com/PaddlePaddle/FastDeploy/tree/develop/examples/vision/detection/paddledetection)中多硬件快速部署; + +- 🔥 **2022.9.26:PaddleYOLO发布[release/2.5版本](https://github.com/PaddlePaddle/PaddleYOLO/tree/release/2.5)** + - 💡 模型套件: + - 发布[PaddleYOLO](https://github.com/PaddlePaddle/PaddleYOLO)模型套件: 支持`YOLOv3`,`PP-YOLOE`,`PP-YOLOE+`,`YOLOX`,`YOLOv5`,`YOLOv6`,`YOLOv7`等YOLO模型,支持`ConvNeXt`骨干网络高精度版`PP-YOLOE`,`YOLOX`和`YOLOv5`等模型,支持PaddleSlim无损加速量化训练`PP-YOLOE`,`YOLOv5`,`YOLOv6`和`YOLOv7`等模型; + +- 🔥 **2022.8.26:PaddleDetection发布[release/2.5版本](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5)** + - 🗳 特色模型: + - 发布[PP-YOLOE+](configs/ppyoloe),最高精度提升2.4% mAP,达到54.9% mAP,模型训练收敛速度提升3.75倍,端到端预测速度最高提升2.3倍;多个下游任务泛化性提升 + - 发布[PicoDet-NPU](configs/picodet)模型,支持模型全量化部署;新增[PicoDet](configs/picodet)版面分析模型 + - 发布[PP-TinyPose升级版](./configs/keypoint/tiny_pose/)增强版,在健身、舞蹈等场景精度提升9.1% AP,支持侧身、卧躺、跳跃、高抬腿等非常规动作 + - 🔮 场景能力: + - 发布行人分析工具[PP-Human v2](./deploy/pipeline),新增打架、打电话、抽烟、闯入四大行为识别,底层算法性能升级,覆盖行人检测、跟踪、属性三类核心算法能力,提供保姆级全流程开发及模型优化策略,支持在线视频流输入 + - 
首次发布[PP-Vehicle](./deploy/pipeline),提供车牌识别、车辆属性分析(颜色、车型)、车流量统计以及违章检测四大功能,兼容图片、在线视频流、视频输入,提供完善的二次开发文档教程 + - 💡 前沿算法: + - 全面覆盖的[YOLO家族](https://github.com/PaddlePaddle/PaddleYOLO)经典与最新模型: 包括YOLOv3,百度飞桨自研的实时高精度目标检测检测模型PP-YOLOE,以及前沿检测算法YOLOv4、YOLOv5、YOLOX,YOLOv6及YOLOv7 + - 新增基于[ViT](configs/vitdet)骨干网络高精度检测模型,COCO数据集精度达到55.7% mAP;新增[OC-SORT](configs/mot/ocsort)多目标跟踪模型;新增[ConvNeXt](configs/convnext)骨干网络 + - 📋 产业范例:新增[智能健身](https://aistudio.baidu.com/aistudio/projectdetail/4385813)、[打架识别](https://aistudio.baidu.com/aistudio/projectdetail/4086987?channelType=0&channel=0)、[来客分析](https://aistudio.baidu.com/aistudio/projectdetail/4230123?channelType=0&channel=0)、车辆结构化范例 + +- 2022.3.24:PaddleDetection发布[release/2.4版本](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4) + - 发布高精度云边一体SOTA目标检测模型[PP-YOLOE](configs/ppyoloe),提供s/m/l/x版本,l版本COCO test2017数据集精度51.6%,V100预测速度78.1 FPS,支持混合精度训练,训练较PP-YOLOv2加速33%,全系列多尺度模型,满足不同硬件算力需求,可适配服务器、边缘端GPU及其他服务器端AI加速卡。 + - 发布边缘端和CPU端超轻量SOTA目标检测模型[PP-PicoDet增强版](configs/picodet),精度提升2%左右,CPU预测速度提升63%,新增参数量0.7M的PicoDet-XS模型,提供模型稀疏化和量化功能,便于模型加速,各类硬件无需单独开发后处理模块,降低部署门槛。 + - 发布实时行人分析工具[PP-Human](deploy/pipeline),支持行人跟踪、人流量统计、人体属性识别与摔倒检测四大能力,基于真实场景数据特殊优化,精准识别各类摔倒姿势,适应不同环境背景、光线及摄像角度。 + - 新增[YOLOX](configs/yolox)目标检测模型,支持nano/tiny/s/m/l/x版本,x版本COCO val2017数据集精度51.8%。 + +- [更多版本发布](https://github.com/PaddlePaddle/PaddleDetection/releases) + +## 简介 + +**PaddleDetection**为基于飞桨PaddlePaddle的端到端目标检测套件,内置**30+模型算法**及**250+预训练模型**,覆盖**目标检测、实例分割、跟踪、关键点检测**等方向,其中包括**服务器端和移动端高精度、轻量级**产业级SOTA模型、冠军方案和学术前沿算法,并提供配置化的网络模块组件、十余种数据增强策略和损失函数等高阶优化支持和多种部署方案,在打通数据处理、模型开发、训练、压缩、部署全流程的基础上,提供丰富的案例及教程,加速算法产业落地应用。 + +
+ +
+ +## 特性 + +- **模型丰富**: 包含**目标检测**、**实例分割**、**人脸检测**、**关键点检测**、**多目标跟踪**等**250+个预训练模型**,涵盖多种**全球竞赛冠军**方案。 +- **使用简洁**:模块化设计,解耦各个网络组件,开发者轻松搭建、试用各种检测模型及优化策略,快速得到高性能、定制化的算法。 +- **端到端打通**: 从数据增强、组网、训练、压缩、部署端到端打通,并完备支持**云端**/**边缘端**多架构、多设备部署。 +- **高性能**: 基于飞桨的高性能内核,模型训练速度及显存占用优势明显。支持FP16训练, 支持多机训练。 + +
+ +
+ +## 技术交流 + +- 如果你发现任何PaddleDetection存在的问题或者是建议, 欢迎通过[GitHub Issues](https://github.com/PaddlePaddle/PaddleDetection/issues)给我们提issues。 + +- **欢迎加入PaddleDetection 微信用户群(扫码填写问卷即可入群)** + - **入群福利 💎:获取PaddleDetection团队整理的重磅学习大礼包🎁** + - 📊 福利一:获取飞桨联合业界企业整理的开源数据集 + - 👨‍🏫 福利二:获取PaddleDetection历次发版直播视频与最新直播咨询 + - 🗳 福利三:获取垂类场景预训练模型集合,包括工业、安防、交通等5+行业场景 + - 🗂 福利四:获取10+全流程产业实操范例,覆盖火灾烟雾检测、人流量计数等产业高频场景 +
+ +
+ +## 套件结构概览 + + + + + + + + + + + + + + + + + + + +
+ Architectures + + Backbones + + Components + + Data Augmentation +
+
    +
    Object Detection +
      +
    • YOLOv3
    • +
    • YOLOv5
    • +
    • YOLOv6
    • +
    • YOLOv7
    • +
    • YOLOv8
    • +
    • PP-YOLOv1/v2
    • +
    • PP-YOLO-Tiny
    • +
    • PP-YOLOE
    • +
    • PP-YOLOE+
    • +
    • YOLOX
    • +
    • RTMDet
    • +
    +
+
+
Details +
    +
  • ResNet(&vd)
  • +
  • CSPResNet
  • +
  • DarkNet
  • +
  • CSPDarkNet
  • +
  • ConvNeXt
  • +
  • EfficientRep
  • +
  • CSPBepBackbone
  • +
  • ELANNet
  • +
  • CSPNeXt
  • +
+
+
Common +
    +
  • Sync-BN
  • +
  • Group Norm
  • +
  • DCNv2
  • +
  • EMA
  • +
+ +
FPN +
    +
  • YOLOv3FPN
  • +
  • PPYOLOFPN
  • +
  • PPYOLOTinyFPN
  • +
  • PPYOLOPAN
  • +
  • YOLOCSPPAN
  • +
  • Custom-PAN
  • +
  • RepPAN
  • +
  • CSPRepPAN
  • +
  • ELANFPN
  • +
  • ELANFPNP6
  • +
  • CSPNeXtPAFPN
  • +
+ +
Loss +
    +
  • Smooth-L1
  • +
  • GIoU/DIoU/CIoU
  • +
  • IoUAware
  • +
  • Focal Loss
  • +
  • VariFocal Loss
  • +
+ +
Post-processing +
    +
  • SoftNMS
  • +
  • MatrixNMS
  • +
+ +
Speed +
    +
  • FP16 training
  • +
  • Multi-machine training
  • +
+ +
+
Details +
    +
  • Resize
  • +
  • Lighting
  • +
  • Flipping
  • +
  • Expand
  • +
  • Crop
  • +
  • Color Distort
  • +
  • Random Erasing
  • +
  • Mixup
  • +
  • AugmentHSV
  • +
  • Mosaic
  • +
  • Cutmix
  • +
  • Grid Mask
  • +
  • Auto Augment
  • +
  • Random Perspective
  • +
+
+ +## 模型性能概览 + +
+ 云端模型性能对比 + +各模型结构和骨干网络的代表模型在COCO数据集上精度mAP和单卡Tesla V100上预测速度(FPS)对比图。 + +
+ +
+ +**说明:** + +- `PP-YOLOE`是对`PP-YOLO v2`模型的进一步优化,在COCO数据集精度51.6%,Tesla V100预测速度78.1FPS +- `PP-YOLOE+`是对`PP-YOLOE`模型的进一步优化,在COCO数据集精度53.3%,Tesla V100预测速度78.1FPS +- 图中模型均可在[模型库](#模型库)中获取 + +
+ +
+ 移动端模型性能对比 + +各移动端模型在COCO数据集上精度mAP和高通骁龙865处理器上预测速度(FPS)对比图。 + +
+ +
+ +**说明:** + +- 测试数据均使用高通骁龙865(4\*A77 + 4\*A55)处理器batch size为1, 开启4线程测试,测试使用NCNN预测库,测试脚本见[MobileDetBenchmark](https://github.com/JiweiMaster/MobileDetBenchmark) +- [PP-PicoDet](configs/picodet)及[PP-YOLO-Tiny](configs/ppyolo)为PaddleDetection自研模型,其余模型PaddleDetection暂未提供 + +
+ +## 模型库 + +
+ 1. 通用检测 + +#### [PP-YOLOE+](./configs/ppyoloe)系列 推荐场景:Nvidia V100, T4等云端GPU和Jetson系列等边缘端设备 + +| 模型名称 | COCO精度(mAP) | V100 TensorRT FP16速度(FPS) | 配置文件 | 模型下载 | +|:---------- |:-----------:|:-------------------------:|:-----------------------------------------------------:|:------------------------------------------------------------------------------------:| +| PP-YOLOE+_s | 43.9 | 333.3 | [链接](configs/ppyoloe/ppyoloe_plus_crn_s_80e_coco.yml) | [下载地址](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_s_80e_coco.pdparams) | +| PP-YOLOE+_m | 50.0 | 208.3 | [链接](configs/ppyoloe/ppyoloe_plus_crn_m_80e_coco.yml) | [下载地址](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_m_80e_coco.pdparams) | +| PP-YOLOE+_l | 53.3 | 149.2 | [链接](configs/ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml) | [下载地址](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco.pdparams) | +| PP-YOLOE+_x | 54.9 | 95.2 | [链接](configs/ppyoloe/ppyoloe_plus_crn_x_80e_coco.yml) | [下载地址](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_x_80e_coco.pdparams) | + +#### 前沿检测算法 + +| 模型名称 | COCO精度(mAP) | V100 TensorRT FP16速度(FPS) | 配置文件 | 模型下载 | +|:------------------------------------------------------------------ |:-----------:|:-------------------------:|:------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------:| +| [YOLOX-l](configs/yolox) | 50.1 | 107.5 | [链接](configs/yolox/yolox_l_300e_coco.yml) | [下载地址](https://paddledet.bj.bcebos.com/models/yolox_l_300e_coco.pdparams) | +| [YOLOv5-l](configs/yolov5) | 48.6 | 136.0 | [链接](configs/yolov5/yolov5_l_300e_coco.yml) | [下载地址](https://paddledet.bj.bcebos.com/models/yolov5_l_300e_coco.pdparams) | +| [YOLOv7-l](configs/yolov7) | 51.0 | 135.0 | [链接](configs/yolov7/yolov7_l_300e_coco.yml) | [下载地址](https://paddledet.bj.bcebos.com/models/yolov7_l_300e_coco.pdparams) | + +
+ + +## 文档教程 + +### 入门教程 + +- [安装说明](docs/tutorials/INSTALL_cn.md) +- [快速体验](docs/tutorials/QUICK_STARTED_cn.md) +- [数据准备](docs/tutorials/data/README.md) +- [PaddleDetection全流程使用](docs/tutorials/GETTING_STARTED_cn.md) +- [FAQ/常见问题汇总](docs/tutorials/FAQ) + +### 进阶教程 + +- 参数配置 + + - [PP-YOLO参数说明](docs/tutorials/config_annotation/ppyolo_r50vd_dcn_1x_coco_annotation.md) + +- 模型压缩(基于[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim)) + + - [剪裁/量化/蒸馏教程](configs/slim) + +- [推理部署](deploy/README.md) + + - [模型导出教程](deploy/EXPORT_MODEL.md) + - [Paddle Inference部署](deploy/README.md) + - [Python端推理部署](deploy/python) + - [C++端推理部署](deploy/cpp) + - [Paddle-Lite部署](deploy/lite) + - [Paddle Serving部署](deploy/serving) + - [ONNX模型导出](deploy/EXPORT_ONNX_MODEL.md) + - [推理benchmark](deploy/BENCHMARK_INFER.md) + +- 进阶开发 + + - [数据处理模块](docs/advanced_tutorials/READER.md) + - [新增检测模型](docs/advanced_tutorials/MODEL_TECHNICAL.md) + - 二次开发教程 + - [目标检测](docs/advanced_tutorials/customization/detection.md) + + +## 版本更新 + +版本更新内容请参考[版本更新文档](docs/CHANGELOG.md) + + +## 许可证书 + +本项目的发布受[GPL-3.0 license](LICENSE)许可认证。 + + +## 引用 + +``` +@misc{ppdet2019, +title={PaddleDetection, Object detection and instance segmentation toolkit based on PaddlePaddle.}, +author={PaddlePaddle Authors}, +howpublished = {\url{https://github.com/PaddlePaddle/PaddleDetection}}, +year={2019} +} +``` diff --git a/configs/yolov8/_base_/optimizer_500e.yml b/configs/yolov8/_base_/optimizer_500e.yml index f57dd8b6..3251d70f 100644 --- a/configs/yolov8/_base_/optimizer_500e.yml +++ b/configs/yolov8/_base_/optimizer_500e.yml @@ -17,3 +17,4 @@ OptimizerBuilder: regularizer: factor: 0.0005 type: L2 + clip_grad_by_value: 10. 
diff --git a/configs/yolov8/_base_/optimizer_500e_high.yml b/configs/yolov8/_base_/optimizer_500e_high.yml index 7379e751..9dd72709 100644 --- a/configs/yolov8/_base_/optimizer_500e_high.yml +++ b/configs/yolov8/_base_/optimizer_500e_high.yml @@ -17,3 +17,4 @@ OptimizerBuilder: regularizer: factor: 0.0005 type: L2 + clip_grad_by_value: 10. diff --git a/configs/yolov8/_base_/yolov8_cspdarknet.yml b/configs/yolov8/_base_/yolov8_cspdarknet.yml index dee6ec59..fff154f7 100644 --- a/configs/yolov8/_base_/yolov8_cspdarknet.yml +++ b/configs/yolov8/_base_/yolov8_cspdarknet.yml @@ -1,5 +1,5 @@ architecture: YOLOv8 -norm_type: sync_bn +#norm_type: sync_bn use_ema: True ema_decay: 0.9999 ema_decay_type: "exponential" diff --git a/configs/yolov8/_base_/yolov8_reader.yml b/configs/yolov8/_base_/yolov8_reader.yml index 1e2f1da8..748cccee 100644 --- a/configs/yolov8/_base_/yolov8_reader.yml +++ b/configs/yolov8/_base_/yolov8_reader.yml @@ -29,7 +29,7 @@ EvalReader: - Pad: {size: *input_size, fill_value: [114., 114., 114.]} - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} - Permute: {} - batch_size: 1 + batch_size: 8 TestReader: diff --git a/configs/yolov8/_base_/yolov8_reader_high_aug.yml b/configs/yolov8/_base_/yolov8_reader_high_aug.yml index 1ca74206..a7e802e7 100644 --- a/configs/yolov8/_base_/yolov8_reader_high_aug.yml +++ b/configs/yolov8/_base_/yolov8_reader_high_aug.yml @@ -7,7 +7,7 @@ worker_num: 4 TrainReader: sample_transforms: - Decode: {} - - MosaicPerspective: {mosaic_prob: 1.0, target_size: *input_size, scale: 0.9, mixup_prob: 0.15} + - MosaicPerspective: {mosaic_prob: 1.0, target_size: *input_size, scale: 0.9, mixup_prob: 0.1, copy_paste_prob: 0.1} - RandomHSV: {hgain: 0.015, sgain: 0.7, vgain: 0.4} - RandomFlip: {} batch_transforms: @@ -29,7 +29,7 @@ EvalReader: - Pad: {size: *input_size, fill_value: [114., 114., 114.]} - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} - Permute: {} - batch_size: 1 + batch_size: 
8 TestReader: diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index bf440f12..de5bb9ee 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -69,8 +69,7 @@ def __init__(self, cfg, mode='train'): self.custom_white_list = self.cfg.get('custom_white_list', None) self.custom_black_list = self.cfg.get('custom_black_list', None) - if self.cfg.architecture in ['RTMDet', 'YOLOv6', 'YOLOv8' - ] and self.mode == 'train': + if self.cfg.architecture in ['RTMDet', 'YOLOv6'] and self.mode == 'train': raise NotImplementedError('{} training not supported yet.'.format( self.cfg.architecture)) diff --git a/ppdet/modeling/heads/yolov8_head.py b/ppdet/modeling/heads/yolov8_head.py index 1fcc98b9..da432218 100644 --- a/ppdet/modeling/heads/yolov8_head.py +++ b/ppdet/modeling/heads/yolov8_head.py @@ -52,9 +52,9 @@ def __init__(self, nms='MultiClassNMS', eval_size=None, loss_weight={ - 'class': 1.0, - 'iou': 2.5, - 'dfl': 0.5, + 'class': 0.5, + 'iou': 7.5, + 'dfl': 1.5, }, trt=False, exclude_nms=False, @@ -68,6 +68,7 @@ def __init__(self, self.fpn_strides = fpn_strides self.grid_cell_scale = grid_cell_scale self.grid_cell_offset = grid_cell_offset + self.reg_max = reg_max if reg_range: self.reg_range = reg_range else: @@ -86,11 +87,10 @@ def __init__(self, self.use_shared_conv = use_shared_conv # cls loss - self.bce = nn.BCEWithLogitsLoss( - pos_weight=paddle.to_tensor([1.0]), reduction="mean") + self.bce = nn.BCEWithLogitsLoss(reduction='none') # pred head - c2 = max((16, in_channels[0] // 4, self.reg_channels * 4)) + c2 = max((16, in_channels[0] // 4, self.reg_max * 4)) c3 = max(in_channels[0], self.num_classes) self.conv_reg = nn.LayerList() self.conv_cls = nn.LayerList() @@ -103,7 +103,7 @@ def __init__(self, c2, c2, 3, 1, act=act), nn.Conv2D( c2, - self.reg_channels * 4, + self.reg_max * 4, 1, bias_attr=ParamAttr(regularizer=L2Decay(0.0))), ])) @@ -119,14 +119,7 @@ def __init__(self, 1, bias_attr=ParamAttr(regularizer=L2Decay(0.0))), ])) - # projection 
conv - self.dfl_conv = nn.Conv2D(self.reg_channels, 1, 1, bias_attr=False) - self.dfl_conv.skip_quant = True - self.proj = paddle.linspace(0, self.reg_channels - 1, self.reg_channels) - self.dfl_conv.weight.set_value( - self.proj.reshape([1, self.reg_channels, 1, 1])) - self.dfl_conv.weight.stop_gradient = True - + self.proj = paddle.arange(self.reg_max).astype('float32') # self._init_bias() @classmethod @@ -151,53 +144,43 @@ def forward_train(self, feats, targets): feats, self.fpn_strides, self.grid_cell_scale, self.grid_cell_offset) - cls_score_list, reg_distri_list = [], [] + cls_logits_list, bbox_preds_list, bbox_dist_preds_list = [], [], [] for i, feat in enumerate(feats): - reg_distri = self.conv_reg[i](feat) + _, _, h, w = feat.shape + l = h * w + bbox_dist_preds = self.conv_reg[i](feat) cls_logit = self.conv_cls[i](feat) - # cls and reg - cls_score = F.sigmoid(cls_logit) - cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1])) - reg_distri_list.append(reg_distri.flatten(2).transpose([0, 2, 1])) - cls_score_list = paddle.concat(cls_score_list, axis=1) - reg_distri_list = paddle.concat(reg_distri_list, axis=1) + bbox_dist_preds = bbox_dist_preds.reshape([-1, 4, self.reg_max, l]).transpose([0, 3, 1, 2]) + # [8, 6400, 4, 16] + bbox_preds = F.softmax(bbox_dist_preds, axis=3).matmul(self.proj.reshape([-1, 1])).squeeze(-1) # [8, 6400, 4] + + cls_logits_list.append(cls_logit) + bbox_preds_list.append(bbox_preds.transpose([0, 2, 1]).reshape([-1, 4, h, w])) + bbox_dist_preds_list.append(bbox_dist_preds) return self.get_loss([ - cls_score_list, reg_distri_list, anchors, anchor_points, + cls_logits_list, bbox_preds_list, bbox_dist_preds_list, anchors, anchor_points, num_anchors_list, stride_tensor ], targets) def forward_eval(self, feats): anchor_points, stride_tensor = self._generate_anchors(feats) - cls_score_list, reg_dist_list = [], [] + cls_logits_list, bbox_preds_list = [], [] for i, feat in enumerate(feats): _, _, h, w = feat.shape l = h * w - reg_dist 
= self.conv_reg[i](feat) + bbox_dist_preds = self.conv_reg[i](feat) cls_logit = self.conv_cls[i](feat) - reg_dist = reg_dist.reshape( - [-1, 4, self.reg_channels, l]).transpose( - [0, 2, 3, 1]) # Note diff - if self.use_shared_conv: - reg_dist = self.dfl_conv(F.softmax(reg_dist, axis=1)).squeeze(1) - # [bs, l, 4] - else: - reg_dist = F.softmax(reg_dist, axis=1) - # cls and reg - cls_score = F.sigmoid(cls_logit) - cls_score_list.append(cls_score.reshape([-1, self.num_classes, l])) - reg_dist_list.append(reg_dist) - - cls_score_list = paddle.concat(cls_score_list, axis=-1) - if self.use_shared_conv: - reg_dist_list = paddle.concat(reg_dist_list, axis=1) - else: - reg_dist_list = paddle.concat(reg_dist_list, axis=2) - reg_dist_list = self.dfl_conv(reg_dist_list).squeeze(1) + bbox_dist_preds = bbox_dist_preds.reshape( + [-1, 4, self.reg_max, l]).transpose([0, 3, 1, 2]) + # [8, 6400, 4, 16] + bbox_preds = F.softmax(bbox_dist_preds, axis=3).matmul(self.proj.reshape([-1, 1])).squeeze(-1) # [8, 6400, 4] + cls_logits_list.append(cls_logit) + bbox_preds_list.append(bbox_preds.transpose([0, 2, 1]).reshape([-1, 4, h, w])) - return cls_score_list, reg_dist_list, anchor_points, stride_tensor + return cls_logits_list, bbox_preds_list, anchor_points, stride_tensor def _generate_anchors(self, feats=None, dtype='float32'): # just use in eval time @@ -228,18 +211,17 @@ def _varifocal_loss(pred_score, gt_score, label, alpha=0.75, gamma=2.0): pred_score, gt_score, weight=weight, reduction='sum') return loss - def _bbox_decode(self, anchor_points, pred_dist): - _, l, _ = get_static_shape(pred_dist) - pred_dist = F.softmax(pred_dist.reshape([-1, l, 4, self.reg_channels])) - pred_dist = self.dfl_conv(pred_dist.transpose([0, 3, 1, 2])).squeeze(1) - return batch_distance2bbox(anchor_points, pred_dist) + # def _bbox_decode(self, anchor_points, pred_dist): + # _, l, _ = get_static_shape(pred_dist) + # pred_dist = F.softmax(pred_dist.reshape([-1, l, 4, self.reg_channels])) + # pred_dist = 
self.dfl_conv(pred_dist.transpose([0, 3, 1, 2])).squeeze(1) + # return batch_distance2bbox(anchor_points, pred_dist) - def _bbox2distance(self, points, bbox): + def _bbox2distance(self, points, bbox, reg_max=15, eps=0.01): x1y1, x2y2 = paddle.split(bbox, 2, -1) lt = points - x1y1 rb = x2y2 - points - return paddle.concat([lt, rb], -1).clip(self.reg_range[0], - self.reg_range[1] - 1 - 0.01) + return paddle.concat([lt, rb], -1).clip(0, reg_max - eps) def _df_loss(self, pred_dist, target, lower_bound=0): target_left = paddle.cast(target.floor(), 'int64') @@ -254,100 +236,133 @@ def _df_loss(self, pred_dist, target, lower_bound=0): reduction='none') * weight_right return (loss_left + loss_right).mean(-1, keepdim=True) - def _bbox_loss(self, pred_dist, pred_bboxes, anchor_points, assigned_labels, - assigned_bboxes, assigned_scores, assigned_scores_sum): - # select positive samples mask - mask_positive = (assigned_labels != self.num_classes) - num_pos = mask_positive.sum() - # pos/neg loss - if num_pos > 0: - # l1 + iou - bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4]) - pred_bboxes_pos = paddle.masked_select(pred_bboxes, - bbox_mask).reshape([-1, 4]) - assigned_bboxes_pos = paddle.masked_select( - assigned_bboxes, bbox_mask).reshape([-1, 4]) - bbox_weight = paddle.masked_select( - assigned_scores.sum(-1), mask_positive).unsqueeze(-1) + def get_loss(self, head_outs, gt_meta): + cls_scores, bbox_preds, bbox_dist_preds, anchors,\ + anchor_points, num_anchors_list, stride_tensor = head_outs - # loss_l1 just see if train well - if self.print_l1_loss: - loss_l1 = F.l1_loss(pred_bboxes_pos, assigned_bboxes_pos) - else: - loss_l1 = paddle.zeros([1]) - # ciou loss - iou = bbox_iou( - pred_bboxes_pos, assigned_bboxes_pos, x1y1x2y2=False, ciou=True) - loss_iou = ((1.0 - iou) * bbox_weight).sum() / assigned_scores_sum + bs = cls_scores[0].shape[0] + flatten_cls_preds = [ + cls_pred.transpose([0, 2, 3, 1]).reshape([bs, -1, self.num_classes]) + for cls_pred in cls_scores + ] 
+ flatten_pred_bboxes = [ + bbox_pred.transpose([0, 2, 3, 1]).reshape([bs, -1, 4]) + for bbox_pred in bbox_preds + ] + flatten_pred_dists = [ + bbox_pred_org.reshape([bs, -1, self.reg_max * 4]) + for bbox_pred_org in bbox_dist_preds + ] - dist_mask = mask_positive.unsqueeze(-1).tile( - [1, 1, self.reg_channels * 4]) - pred_dist_pos = paddle.masked_select( - pred_dist, dist_mask).reshape([-1, 4, self.reg_channels]) - assigned_ltrb = self._bbox2distance(anchor_points, assigned_bboxes) - assigned_ltrb_pos = paddle.masked_select( - assigned_ltrb, bbox_mask).reshape([-1, 4]) - loss_dfl = self._df_loss(pred_dist_pos, assigned_ltrb_pos, - self.reg_range[0]) * bbox_weight - loss_dfl = loss_dfl.sum() / assigned_scores_sum - else: - loss_l1 = paddle.zeros([1]) - loss_iou = paddle.zeros([1]) - loss_dfl = pred_dist.sum() * 0. - return loss_l1, loss_iou, loss_dfl + flatten_dist_preds = paddle.concat(flatten_pred_dists, 1) # [8, 8400, 64] + pred_scores = paddle.concat(flatten_cls_preds, 1) # [8, 8400, 80] + pred_distri = paddle.concat(flatten_pred_bboxes, 1) # [8, 8400, 4] - def get_loss(self, head_outs, gt_meta): - pred_scores, pred_distri, anchors,\ - anchor_points, num_anchors_list, stride_tensor = head_outs anchor_points_s = anchor_points / stride_tensor - pred_bboxes = self._bbox_decode(anchor_points_s, pred_distri) + # pred_bboxes = self._bbox_decode(anchor_points_s, pred_distri) + pred_bboxes = batch_distance2bbox(anchor_points_s, pred_distri) # xyxy # 9345410. [8, 8400, 4] + pred_bboxes = pred_bboxes * stride_tensor # must *stride - gt_labels = gt_meta['gt_class'] - gt_bboxes = gt_meta['gt_bbox'] + gt_labels = gt_meta['gt_class'] # [16, 51, 1] + gt_bboxes = gt_meta['gt_bbox'] # xyxy max=640 # [16, 51, 4] pad_gt_mask = gt_meta['pad_gt_mask'] + # pad_gt_mask = paddle.cast((gt_bboxes.sum(-1, keepdim=True) > 0), 'float32') + # label assignment + # [8, 8400, 4] 80842408. 8781976. 
assigned_scores 131.79727173 assigned_labels, assigned_bboxes, assigned_scores = \ self.assigner( - pred_scores.detach(), - pred_bboxes.detach() * stride_tensor, + F.sigmoid(pred_scores.detach()), + pred_bboxes.detach(), # * stride_tensor, anchor_points, num_anchors_list, gt_labels, - gt_bboxes, + gt_bboxes, # xyxy pad_gt_mask, bg_index=self.num_classes) # rescale bbox + # print('assigned_scores ', assigned_scores.max(), assigned_scores.sum()) assigned_bboxes /= stride_tensor + pred_bboxes /= stride_tensor ### + # cls loss - if self.use_varifocal_loss: + if 0: #self.use_varifocal_loss: one_hot_label = F.one_hot(assigned_labels, self.num_classes + 1)[..., :-1] loss_cls = self._varifocal_loss(pred_scores, assigned_scores, one_hot_label) else: - loss_cls = self.bce(pred_scores, assigned_scores) + loss_cls = self.bce(pred_scores, assigned_scores).sum() # [16, 8400, 80] assigned_scores_sum = assigned_scores.sum() if paddle.distributed.get_world_size() > 1: paddle.distributed.all_reduce(assigned_scores_sum) assigned_scores_sum /= paddle.distributed.get_world_size() assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.) 
- # loss_cls /= assigned_scores_sum - - loss_l1, loss_iou, loss_dfl = \ - self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s, - assigned_labels, assigned_bboxes, assigned_scores, - assigned_scores_sum) - loss = self.loss_weight['class'] * loss_cls + \ - self.loss_weight['iou'] * loss_iou + \ - self.loss_weight['dfl'] * loss_dfl + loss_cls /= assigned_scores_sum + + # select positive samples mask + mask_positive = (assigned_labels != self.num_classes) + num_pos = mask_positive.sum() + # pos/neg loss + if num_pos > 0: + # ciou loss + bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4]) + pred_bboxes_pos = paddle.masked_select( + pred_bboxes, bbox_mask).reshape([-1, 4]) + assigned_bboxes_pos = paddle.masked_select( + assigned_bboxes, bbox_mask).reshape([-1, 4]) + bbox_weight = paddle.masked_select( + assigned_scores.sum(-1), mask_positive).unsqueeze(-1) + iou = bbox_iou( + pred_bboxes_pos.split(4, axis=-1), + assigned_bboxes_pos.split(4, axis=-1), + x1y1x2y2=True, # xyxy + ciou=True, + eps=1e-7) + loss_iou = ((1.0 - iou) * bbox_weight).sum() / assigned_scores_sum + + if 1: #self.print_l1_loss: + loss_l1 = F.l1_loss(pred_bboxes_pos, assigned_bboxes_pos) + else: + loss_l1 = paddle.zeros([1]) + + # dfl loss + dist_mask = mask_positive.unsqueeze(-1).tile( + [1, 1, self.reg_max * 4]) + pred_dist_pos = paddle.masked_select( + flatten_dist_preds, dist_mask).reshape([-1, 4, self.reg_max]) + assigned_ltrb = self._bbox2distance( + anchor_points_s, + assigned_bboxes, + reg_max=self.reg_max - 1, + eps=0.01) + assigned_ltrb_pos = paddle.masked_select( + assigned_ltrb, bbox_mask).reshape([-1, 4]) + + loss_dfl = self._df_loss(pred_dist_pos, + assigned_ltrb_pos) * bbox_weight + loss_dfl = loss_dfl.sum() / assigned_scores_sum + else: + loss_iou = flatten_dist_preds.sum() * 0. + loss_dfl = flatten_dist_preds.sum() * 0. + loss_l1 = flatten_dist_preds.sum() * 0. 
+ + loss_cls *= self.loss_weight['class'] + loss_iou *= self.loss_weight['iou'] + loss_dfl *= self.loss_weight['dfl'] + loss_total = loss_cls + loss_iou + loss_dfl + # bs = head_outs[0].shape[0] + num_gpus = gt_meta.get('num_gpus', 8) + total_bs = bs * num_gpus + out_dict = { - 'loss': loss, - 'loss_cls': loss_cls, - 'loss_iou': loss_iou, - 'loss_dfl': loss_dfl, + 'loss': loss_total * total_bs, + 'loss_cls': loss_cls * total_bs, + 'loss_iou': loss_iou * total_bs, + 'loss_dfl': loss_dfl * total_bs, } if self.print_l1_loss: # just see convergence @@ -355,14 +370,27 @@ def get_loss(self, head_outs, gt_meta): return out_dict def post_process(self, head_outs, im_shape, scale_factor): - pred_scores, pred_dist, anchor_points, stride_tensor = head_outs - pred_bboxes = batch_distance2bbox(anchor_points, pred_dist) + pred_scores_list, pred_dist_list, anchor_points, stride_tensor = head_outs + bs = pred_scores_list[0].shape[0] + pred_scores = [ + cls_score.transpose([0, 2, 3, 1]).reshape([bs, -1, self.num_classes]) + for cls_score in pred_scores_list + ] + pred_dists = [ + bbox_pred.transpose([0, 2, 3, 1]).reshape([bs, -1, 4]) + for bbox_pred in pred_dist_list + ] + pred_scores = F.sigmoid(paddle.concat(pred_scores, 1)) # [8, 8400, 80] + pred_bboxes = paddle.concat(pred_dists, 1) # [8, 8400, 4] + + pred_bboxes = batch_distance2bbox(anchor_points, pred_bboxes) pred_bboxes *= stride_tensor if self.exclude_post_process: return paddle.concat( [pred_bboxes, pred_scores.transpose([0, 2, 1])], axis=-1) else: + pred_scores = pred_scores.transpose([0, 2, 1]) # scale bbox to origin scale_factor = scale_factor.flip(-1).tile([1, 2]).unsqueeze(1) pred_bboxes /= scale_factor From 53a91dcd0c5240ac0ccf14e821ccf7c985ed34c1 Mon Sep 17 00:00:00 2001 From: floveqq <940925733@qq.com> Date: Tue, 16 May 2023 04:11:02 +0000 Subject: [PATCH 3/7] Update config --- configs/yolov8/_base_/yolov8_cspdarknet.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/configs/yolov8/_base_/yolov8_cspdarknet.yml b/configs/yolov8/_base_/yolov8_cspdarknet.yml index fff154f7..ebd28800 100644 --- a/configs/yolov8/_base_/yolov8_cspdarknet.yml +++ b/configs/yolov8/_base_/yolov8_cspdarknet.yml @@ -1,5 +1,5 @@ architecture: YOLOv8 -#norm_type: sync_bn +norm_type: sync_bn use_ema: True ema_decay: 0.9999 ema_decay_type: "exponential" @@ -34,7 +34,7 @@ YOLOv8Head: beta: 6.0 nms: name: MultiClassNMS - nms_top_k: 1000 + nms_top_k: 3000 keep_top_k: 300 score_threshold: 0.001 nms_threshold: 0.7 From 93f874b2a5aba8cebff1d573c9b9c190b78d831f Mon Sep 17 00:00:00 2001 From: floveqq <940925733@qq.com> Date: Wed, 17 May 2023 08:29:22 +0000 Subject: [PATCH 4/7] clean --- ppdet/modeling/heads/yolov8_head.py | 66 +++++++++-------------------- 1 file changed, 20 insertions(+), 46 deletions(-) diff --git a/ppdet/modeling/heads/yolov8_head.py b/ppdet/modeling/heads/yolov8_head.py index da432218..75486494 100644 --- a/ppdet/modeling/heads/yolov8_head.py +++ b/ppdet/modeling/heads/yolov8_head.py @@ -24,7 +24,6 @@ from ..bbox_utils import bbox_iou from ..assigners.utils import generate_anchors_for_grid_cell from ppdet.modeling.backbones.csp_darknet import BaseConv -from ppdet.modeling.ops import get_static_shape from ppdet.modeling.layers import MultiClassNMS __all__ = ['YOLOv8Head'] @@ -120,7 +119,7 @@ def __init__(self, bias_attr=ParamAttr(regularizer=L2Decay(0.0))), ])) self.proj = paddle.arange(self.reg_max).astype('float32') - # self._init_bias() + #self._init_bias() ? 
@classmethod def from_config(cls, cfg, input_shape): @@ -151,8 +150,7 @@ def forward_train(self, feats, targets): bbox_dist_preds = self.conv_reg[i](feat) cls_logit = self.conv_cls[i](feat) bbox_dist_preds = bbox_dist_preds.reshape([-1, 4, self.reg_max, l]).transpose([0, 3, 1, 2]) - # [8, 6400, 4, 16] - bbox_preds = F.softmax(bbox_dist_preds, axis=3).matmul(self.proj.reshape([-1, 1])).squeeze(-1) # [8, 6400, 4] + bbox_preds = F.softmax(bbox_dist_preds, axis=3).matmul(self.proj.reshape([-1, 1])).squeeze(-1) cls_logits_list.append(cls_logit) bbox_preds_list.append(bbox_preds.transpose([0, 2, 1]).reshape([-1, 4, h, w])) @@ -175,8 +173,7 @@ def forward_eval(self, feats): bbox_dist_preds = bbox_dist_preds.reshape( [-1, 4, self.reg_max, l]).transpose([0, 3, 1, 2]) - # [8, 6400, 4, 16] - bbox_preds = F.softmax(bbox_dist_preds, axis=3).matmul(self.proj.reshape([-1, 1])).squeeze(-1) # [8, 6400, 4] + bbox_preds = F.softmax(bbox_dist_preds, axis=3).matmul(self.proj.reshape([-1, 1])).squeeze(-1) cls_logits_list.append(cls_logit) bbox_preds_list.append(bbox_preds.transpose([0, 2, 1]).reshape([-1, 4, h, w])) @@ -204,19 +201,6 @@ def _generate_anchors(self, feats=None, dtype='float32'): stride_tensor = paddle.concat(stride_tensor) return anchor_points, stride_tensor - @staticmethod - def _varifocal_loss(pred_score, gt_score, label, alpha=0.75, gamma=2.0): - weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label - loss = F.binary_cross_entropy( - pred_score, gt_score, weight=weight, reduction='sum') - return loss - - # def _bbox_decode(self, anchor_points, pred_dist): - # _, l, _ = get_static_shape(pred_dist) - # pred_dist = F.softmax(pred_dist.reshape([-1, l, 4, self.reg_channels])) - # pred_dist = self.dfl_conv(pred_dist.transpose([0, 3, 1, 2])).squeeze(1) - # return batch_distance2bbox(anchor_points, pred_dist) - def _bbox2distance(self, points, bbox, reg_max=15, eps=0.01): x1y1, x2y2 = paddle.split(bbox, 2, -1) lt = points - x1y1 @@ -255,27 +239,23 @@ def 
get_loss(self, head_outs, gt_meta): for bbox_pred_org in bbox_dist_preds ] - flatten_dist_preds = paddle.concat(flatten_pred_dists, 1) # [8, 8400, 64] - pred_scores = paddle.concat(flatten_cls_preds, 1) # [8, 8400, 80] - pred_distri = paddle.concat(flatten_pred_bboxes, 1) # [8, 8400, 4] + flatten_dist_preds = paddle.concat(flatten_pred_dists, 1) + pred_scores = paddle.concat(flatten_cls_preds, 1) + pred_distri = paddle.concat(flatten_pred_bboxes, 1) anchor_points_s = anchor_points / stride_tensor - # pred_bboxes = self._bbox_decode(anchor_points_s, pred_distri) - pred_bboxes = batch_distance2bbox(anchor_points_s, pred_distri) # xyxy # 9345410. [8, 8400, 4] - pred_bboxes = pred_bboxes * stride_tensor # must *stride + pred_bboxes = batch_distance2bbox(anchor_points_s, pred_distri) # xyxy + pred_bboxes = pred_bboxes * stride_tensor - gt_labels = gt_meta['gt_class'] # [16, 51, 1] - gt_bboxes = gt_meta['gt_bbox'] # xyxy max=640 # [16, 51, 4] + gt_labels = gt_meta['gt_class'] + gt_bboxes = gt_meta['gt_bbox'] # xyxy pad_gt_mask = gt_meta['pad_gt_mask'] - # pad_gt_mask = paddle.cast((gt_bboxes.sum(-1, keepdim=True) > 0), 'float32') - # label assignment - # [8, 8400, 4] 80842408. 8781976. 
assigned_scores 131.79727173 assigned_labels, assigned_bboxes, assigned_scores = \ self.assigner( F.sigmoid(pred_scores.detach()), - pred_bboxes.detach(), # * stride_tensor, + pred_bboxes.detach(), anchor_points, num_anchors_list, gt_labels, @@ -283,18 +263,11 @@ def get_loss(self, head_outs, gt_meta): pad_gt_mask, bg_index=self.num_classes) # rescale bbox - # print('assigned_scores ', assigned_scores.max(), assigned_scores.sum()) assigned_bboxes /= stride_tensor - pred_bboxes /= stride_tensor ### + pred_bboxes /= stride_tensor # cls loss - if 0: #self.use_varifocal_loss: - one_hot_label = F.one_hot(assigned_labels, - self.num_classes + 1)[..., :-1] - loss_cls = self._varifocal_loss(pred_scores, assigned_scores, - one_hot_label) - else: - loss_cls = self.bce(pred_scores, assigned_scores).sum() # [16, 8400, 80] + loss_cls = self.bce(pred_scores, assigned_scores).sum() assigned_scores_sum = assigned_scores.sum() if paddle.distributed.get_world_size() > 1: @@ -324,7 +297,7 @@ def get_loss(self, head_outs, gt_meta): eps=1e-7) loss_iou = ((1.0 - iou) * bbox_weight).sum() / assigned_scores_sum - if 1: #self.print_l1_loss: + if self.print_l1_loss: loss_l1 = F.l1_loss(pred_bboxes_pos, assigned_bboxes_pos) else: loss_l1 = paddle.zeros([1]) @@ -354,7 +327,7 @@ def get_loss(self, head_outs, gt_meta): loss_iou *= self.loss_weight['iou'] loss_dfl *= self.loss_weight['dfl'] loss_total = loss_cls + loss_iou + loss_dfl - # bs = head_outs[0].shape[0] + num_gpus = gt_meta.get('num_gpus', 8) total_bs = bs * num_gpus @@ -373,15 +346,15 @@ def post_process(self, head_outs, im_shape, scale_factor): pred_scores_list, pred_dist_list, anchor_points, stride_tensor = head_outs bs = pred_scores_list[0].shape[0] pred_scores = [ - cls_score.transpose([0, 2, 3, 1]).reshape([bs, -1, self.num_classes]) + cls_score.transpose([0, 2, 3, 1]).reshape([-1, int(cls_score.shape[2] * cls_score.shape[3]), self.num_classes]) for cls_score in pred_scores_list ] pred_dists = [ - bbox_pred.transpose([0, 2, 3, 
1]).reshape([bs, -1, 4]) + bbox_pred.transpose([0, 2, 3, 1]).reshape([-1, int(bbox_pred.shape[2] * bbox_pred.shape[3]), 4]) for bbox_pred in pred_dist_list ] - pred_scores = F.sigmoid(paddle.concat(pred_scores, 1)) # [8, 8400, 80] - pred_bboxes = paddle.concat(pred_dists, 1) # [8, 8400, 4] + pred_scores = F.sigmoid(paddle.concat(pred_scores, 1)) + pred_bboxes = paddle.concat(pred_dists, 1) pred_bboxes = batch_distance2bbox(anchor_points, pred_bboxes) pred_bboxes *= stride_tensor @@ -400,3 +373,4 @@ def post_process(self, head_outs, im_shape, scale_factor): else: bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores) return bbox_pred, bbox_num + From c50d12748a3d7bc1cd49c7480916e748b00b6c4b Mon Sep 17 00:00:00 2001 From: floveqq <940925733@qq.com> Date: Wed, 17 May 2023 13:36:46 +0000 Subject: [PATCH 5/7] export model ok --- ppdet/modeling/heads/yolov8_head.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/ppdet/modeling/heads/yolov8_head.py b/ppdet/modeling/heads/yolov8_head.py index 75486494..0fc9974e 100644 --- a/ppdet/modeling/heads/yolov8_head.py +++ b/ppdet/modeling/heads/yolov8_head.py @@ -165,6 +165,7 @@ def forward_eval(self, feats): anchor_points, stride_tensor = self._generate_anchors(feats) cls_logits_list, bbox_preds_list = [], [] + feats_shapes = [] for i, feat in enumerate(feats): _, _, h, w = feat.shape l = h * w @@ -176,8 +177,20 @@ def forward_eval(self, feats): bbox_preds = F.softmax(bbox_dist_preds, axis=3).matmul(self.proj.reshape([-1, 1])).squeeze(-1) cls_logits_list.append(cls_logit) bbox_preds_list.append(bbox_preds.transpose([0, 2, 1]).reshape([-1, 4, h, w])) + feats_shapes.append(l) - return cls_logits_list, bbox_preds_list, anchor_points, stride_tensor + pred_scores = [ + cls_score.transpose([0, 2, 3, 1]).reshape([-1, size, self.num_classes]) + for size, cls_score in zip(feats_shapes, cls_logits_list) + ] + pred_dists = [ + bbox_pred.transpose([0, 2, 3, 1]).reshape([-1, size, 4]) + for 
size, bbox_pred in zip(feats_shapes, bbox_preds_list) + ] + pred_scores = F.sigmoid(paddle.concat(pred_scores, 1)) + pred_bboxes = paddle.concat(pred_dists, 1) + + return pred_scores, pred_bboxes, anchor_points, stride_tensor def _generate_anchors(self, feats=None, dtype='float32'): # just use in eval time @@ -343,18 +356,7 @@ def get_loss(self, head_outs, gt_meta): return out_dict def post_process(self, head_outs, im_shape, scale_factor): - pred_scores_list, pred_dist_list, anchor_points, stride_tensor = head_outs - bs = pred_scores_list[0].shape[0] - pred_scores = [ - cls_score.transpose([0, 2, 3, 1]).reshape([-1, int(cls_score.shape[2] * cls_score.shape[3]), self.num_classes]) - for cls_score in pred_scores_list - ] - pred_dists = [ - bbox_pred.transpose([0, 2, 3, 1]).reshape([-1, int(bbox_pred.shape[2] * bbox_pred.shape[3]), 4]) - for bbox_pred in pred_dist_list - ] - pred_scores = F.sigmoid(paddle.concat(pred_scores, 1)) - pred_bboxes = paddle.concat(pred_dists, 1) + pred_scores, pred_bboxes, anchor_points, stride_tensor = head_outs pred_bboxes = batch_distance2bbox(anchor_points, pred_bboxes) pred_bboxes *= stride_tensor From d510f5535faaced8b5d5d283dd3a6b7c79d8101e Mon Sep 17 00:00:00 2001 From: QJL <108666823+floveqq@users.noreply.github.com> Date: Wed, 17 May 2023 21:39:33 +0800 Subject: [PATCH 6/7] Update README_cn.md --- README_cn.md | 8 -------- 1 file changed, 8 deletions(-) diff --git a/README_cn.md b/README_cn.md index bb5619ac..20544565 100644 --- a/README_cn.md +++ b/README_cn.md @@ -1,11 +1,3 @@ - -## 飞桨黑客松第四期 -赛题155. 
YOLOv8复现 - - -###################################################################### - - 简体中文 | [English](README_en.md) ## 简介 From 048bc3f23a10f1169ff98ef7b3cc69c31a91e1c8 Mon Sep 17 00:00:00 2001 From: QJL <108666823+floveqq@users.noreply.github.com> Date: Wed, 17 May 2023 21:40:07 +0800 Subject: [PATCH 7/7] Update README.md --- README.md | 8 -------- 1 file changed, 8 deletions(-) diff --git a/README.md b/README.md index bb5619ac..20544565 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,3 @@ - -## 飞桨黑客松第四期 -赛题155. YOLOv8复现 - - -###################################################################### - - 简体中文 | [English](README_en.md) ## 简介