一个基于 Accelerate 的深度学习训练框架,专为图像分类任务设计,提供统一的 GPU 配置、实验管理和模块化设计。
- 统一的 GPU 配置: 支持单卡和多卡训练的统一参数设置
- 实验管理: 集成 SwanLab 进行实验跟踪和可视化
- 模块化设计: 支持自定义数据集、模型、损失函数和优化器
- 模型支持: 内置 ResNet、VGG、DenseNet 等经典模型
- 灵活配置: 基于 YAML 的配置文件系统
- 断点续训: 支持训练中断后的自动恢复
- 早停机制: 防止过拟合的智能训练停止
- 推理可视化: 提供模型推理和结果可视化工具
- 详细日志: 完整的训练过程记录和监控
本框架采用模块化设计,将深度学习训练的各个组件解耦,便于扩展和定制:
- 配置驱动: 所有训练参数通过 YAML 配置文件管理
- 工厂模式: 通过工厂函数动态创建数据集、模型、损失函数等组件
- 分布式友好: 基于 Accelerate 实现无缝的单卡/多卡切换
- 实验追踪: 集成 SwanLab 实现训练过程的可视化监控
SwanLab_Accelerator/
├── configs/ # 配置文件目录
│ └── image_classification.yaml # 图像分类配置文件
├── data/ # 数据集目录
├── datasets/ # 数据集模块
│ ├── __init__.py
│ ├── base_dataset.py # 数据集基类
│ ├── cifar.py # CIFAR数据集
│ └── imagenet.py # ImageNet数据集
├── models/ # 模型模块
│ ├── __init__.py
│ ├── resnet.py # ResNet模型
│ ├── vgg.py # VGG模型
│ └── densenet.py # DenseNet模型
├── trainer/ # 训练器模块
│ ├── __init__.py
│ ├── base_trainer.py # 基础训练器(支持断点恢复和早停)
│ └── train_image_classification.py # 图像分类训练脚本
├── utils/ # 工具模块
│ ├── __init__.py
│ ├── config.py # 配置加载工具
│ ├── factory.py # 组件工厂(支持自定义组件)
│ ├── logger.py # 日志工具
│ └── metrics.py # 评估指标
├── schedulers/ # 学习率调度器
│ ├── __init__.py
│ └── factory.py # 调度器工厂
├── losses/ # 损失函数模块
│ ├── __init__.py
│ └── custom_losses.py # 自定义损失函数
├── optimizers/ # 优化器模块
│ ├── __init__.py
│ └── custom_optimizers.py # 自定义优化器
├── result/ # 训练结果目录
├── swanlog/ # SwanLab日志目录
├── inference_visualization.py # 推理和可视化脚本
└── README.md # 项目说明文档
# 安装依赖
pip install torch torchvision accelerate swanlab pyyaml matplotlib seaborn scikit-learn
# 初始化 accelerate(首次使用)
accelerate config# 使用默认GPU
python trainer/train_image_classification.py --config configs/image_classification.yaml
# 指定GPU
python trainer/train_image_classification.py --config configs/image_classification.yaml --gpu 0# 使用accelerate启动多卡训练
accelerate launch trainer/train_image_classification.py --config configs/image_classification.yaml
# 指定使用的GPU
accelerate launch trainer/train_image_classification.py --config configs/image_classification.yaml --gpu 0,1,2,3# 自动查找最新检查点恢复
python trainer/train_image_classification.py --config configs/image_classification.yaml --resume auto
# 从指定检查点恢复
python trainer/train_image_classification.py --config configs/image_classification.yaml --resume ./result/checkpoints/epoch_10# 基本推理(使用实际的检查点路径)
python inference_visualization.py --model_path ./result/ImageClassification_CIFAR10_20250720_201039/epoch_10_checkpoint --data_path ./data
# 带可视化的推理
python inference_visualization.py --model_path ./result/ImageClassification_CIFAR10_20250720_201039/epoch_10_checkpoint --data_path ./data --visualize
# 完整分析(推理+可视化+错误分析)
python inference_visualization.py --model_path ./result/ImageClassification_CIFAR10_20250720_201039/epoch_10_checkpoint --data_path ./data --visualize --error_analysis
# 如果配置文件无法自动找到,可以手动指定
python inference_visualization.py --model_path ./result/ImageClassification_CIFAR10_20250720_201039/epoch_10_checkpoint --data_path ./data --config configs/image_classification.yaml --visualize配置文件 configs/image_classification.yaml 控制所有训练参数:
# 数据集配置
dataset:
name: "CIFAR10" # 数据集名称
root: "./data" # 数据根目录
download: true # 是否自动下载
num_workers: 4 # 数据加载进程数
pin_memory: true # 是否使用内存锁定
# 模型配置
model:
name: "resnet18" # 模型名称
num_classes: 10 # 分类数量
pretrained: false # 是否使用预训练权重
# 训练配置
training:
num_epochs: 100 # 训练轮数
batch_size: 128 # 批次大小
# 早停配置
early_stopping:
patience: 10 # 早停耐心值
metric: "val_accuracy" # 监控指标
mode: "max" # 优化模式(max/min)
# 损失函数配置
loss:
name: "CrossEntropyLoss" # 损失函数名称
# 优化器配置
optimizer:
name: "Adam" # 优化器名称
lr: 0.001 # 学习率
weight_decay: 1e-4 # 权重衰减在 datasets/ 目录下创建新的数据集文件:
# datasets/custom_dataset.py
from torch.utils.data import Dataset
from .base_dataset import BaseDataset
class CustomDataset(BaseDataset):
def __init__(self, root, train=True, transform=None, download=False):
super().__init__()
# 实现数据集逻辑
pass
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
# 返回 (image, label)
pass然后在 utils/factory.py 中注册:
def create_dataset(config):
dataset_name = config.get('name', 'CIFAR10')
if dataset_name == 'CustomDataset':
from datasets.custom_dataset import CustomDataset
# 创建数据集实例在 models/ 目录下创建新的模型文件:
# models/custom_model.py
import torch.nn as nn
class CustomModel(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
# 定义模型结构
pass
def forward(self, x):
# 前向传播
pass在 utils/factory.py 中注册:
def create_model(config):
model_name = config.get('name', 'resnet18')
if model_name == 'custom_model':
from models.custom_model import CustomModel
return CustomModel(num_classes=config.get('num_classes', 10))在 losses/custom_losses.py 中定义:
import torch.nn as nn
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2):
super().__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, inputs, targets):
# 实现Focal Loss逻辑
pass在 utils/factory.py 中注册:
def create_loss_function(config):
loss_name = config.get('name', 'CrossEntropyLoss')
if loss_name == 'FocalLoss':
from losses.custom_losses import FocalLoss
return FocalLoss(
alpha=config.get('alpha', 1),
gamma=config.get('gamma', 2)
)在 optimizers/custom_optimizers.py 中定义:
import torch.optim as optim
class CustomOptimizer(optim.Optimizer):
def __init__(self, params, lr=1e-3):
defaults = dict(lr=lr)
super().__init__(params, defaults)
def step(self, closure=None):
# 实现优化步骤
pass在 utils/factory.py 中注册:
def create_optimizer(model, config):
optimizer_name = config.get('name', 'Adam')
if optimizer_name == 'CustomOptimizer':
from optimizers.custom_optimizers import CustomOptimizer
return CustomOptimizer(
model.parameters(),
lr=config.get('lr', 1e-3)
)- 参数解析: 解析命令行参数(配置文件、GPU设置等)
- GPU配置: 根据
--gpu参数设置CUDA_VISIBLE_DEVICES - 配置加载: 从YAML文件加载训练配置
- 组件创建: 通过工厂函数创建数据集、模型、损失函数、优化器
- 训练执行: 调用
BaseTrainer.train()开始训练
- 单卡训练: 直接使用
python命令,通过--gpu参数指定GPU - 多卡训练: 使用
accelerate launch命令,Accelerate自动处理分布式设置
- 自动恢复:
--resume auto自动查找最新检查点 - 指定恢复:
--resume path/to/checkpoint从指定检查点恢复 - 状态恢复: 恢复模型权重、优化器状态、学习率调度器状态、随机数种子
- 监控指标: 可配置监控的验证指标(如
val_accuracy、val_loss) - 优化模式: 支持最大化(
max)和最小化(min)模式 - 耐心值: 配置连续多少个epoch无改善后停止训练
框架自动集成SwanLab进行实验跟踪:
- 自动记录: 训练损失、验证准确率、学习率等指标
- 实验命名: 基于数据集、模型和时间戳的自动命名
- 可视化: 实时查看训练曲线和指标变化
- 定期保存: 根据
--save_freq参数定期保存检查点 - 最佳模型: 自动保存验证集上表现最好的模型
- 完整状态: 保存模型、优化器、调度器的完整状态
A: 修改配置文件中的 dataset.name 字段,确保对应的数据集类已在工厂函数中注册。
A: 在配置文件中设置 model.pretrained: true,或在模型配置中指定预训练权重路径。
A: 使用 accelerate launch 时,通过 --gpu 参数指定,如 --gpu 0,1,2,3。
A: 修改配置文件中的 early_stopping 部分,调整 patience、metric 和 mode 参数。
A: 推理脚本会自动查找配置文件,如果找不到可以通过 --config 参数手动指定。
欢迎提交Issue和Pull Request来改进这个项目。在贡献代码时,请确保:
- 遵循现有的代码风格和结构
- 添加适当的文档和注释
- 确保新功能与现有模块兼容
- 提供相应的测试用例
本项目采用MIT许可证,详见LICENSE文件。