Skip to content

Saunak626/SwanLab_Accelerator

Repository files navigation

SwanLab Accelerator

一个基于 Accelerate 的深度学习训练框架,专为图像分类任务设计,提供统一的 GPU 配置、实验管理和模块化设计。

主要特性

  • 统一的 GPU 配置: 支持单卡和多卡训练的统一参数设置
  • 实验管理: 集成 SwanLab 进行实验跟踪和可视化
  • 模块化设计: 支持自定义数据集、模型、损失函数和优化器
  • 模型支持: 内置 ResNet、VGG、DenseNet 等经典模型
  • 灵活配置: 基于 YAML 的配置文件系统
  • 断点续训: 支持训练中断后的自动恢复
  • 早停机制: 防止过拟合的智能训练停止
  • 推理可视化: 提供模型推理和结果可视化工具
  • 详细日志: 完整的训练过程记录和监控

核心设计理念

本框架采用模块化设计,将深度学习训练的各个组件解耦,便于扩展和定制:

  • 配置驱动: 所有训练参数通过 YAML 配置文件管理
  • 工厂模式: 通过工厂函数动态创建数据集、模型、损失函数等组件
  • 分布式友好: 基于 Accelerate 实现无缝的单卡/多卡切换
  • 实验追踪: 集成 SwanLab 实现训练过程的可视化监控

项目结构

SwanLab_Accelerator/
├── configs/                    # 配置文件目录
│   └── image_classification.yaml  # 图像分类配置文件
├── data/                       # 数据集目录
├── datasets/                   # 数据集模块
│   ├── __init__.py
│   ├── base_dataset.py        # 数据集基类
│   ├── cifar.py               # CIFAR数据集
│   └── imagenet.py            # ImageNet数据集
├── models/                     # 模型模块
│   ├── __init__.py
│   ├── resnet.py              # ResNet模型
│   ├── vgg.py                 # VGG模型
│   └── densenet.py            # DenseNet模型
├── trainer/                    # 训练器模块
│   ├── __init__.py
│   ├── base_trainer.py        # 基础训练器(支持断点恢复和早停)
│   └── train_image_classification.py  # 图像分类训练脚本
├── utils/                      # 工具模块
│   ├── __init__.py
│   ├── config.py              # 配置加载工具
│   ├── factory.py             # 组件工厂(支持自定义组件)
│   ├── logger.py              # 日志工具
│   └── metrics.py             # 评估指标
├── schedulers/                 # 学习率调度器
│   ├── __init__.py
│   └── factory.py             # 调度器工厂
├── losses/                     # 损失函数模块
│   ├── __init__.py
│   └── custom_losses.py       # 自定义损失函数
├── optimizers/                 # 优化器模块
│   ├── __init__.py
│   └── custom_optimizers.py   # 自定义优化器
├── result/                     # 训练结果目录
├── swanlog/                    # SwanLab日志目录
├── inference_visualization.py  # 推理和可视化脚本
└── README.md                   # 项目说明文档

快速开始

环境准备

# 安装依赖
pip install torch torchvision accelerate swanlab pyyaml matplotlib seaborn scikit-learn

# 初始化 accelerate(首次使用)
accelerate config

训练模式

1. 单卡训练

# 使用默认GPU
python trainer/train_image_classification.py --config configs/image_classification.yaml

# 指定GPU
python trainer/train_image_classification.py --config configs/image_classification.yaml --gpu 0

2. 多卡训练

# 使用accelerate启动多卡训练
accelerate launch trainer/train_image_classification.py --config configs/image_classification.yaml

# 指定使用的GPU
accelerate launch trainer/train_image_classification.py --config configs/image_classification.yaml --gpu 0,1,2,3

3. 断点恢复训练

# 自动查找最新检查点恢复
python trainer/train_image_classification.py --config configs/image_classification.yaml --resume auto

# 从指定检查点恢复
python trainer/train_image_classification.py --config configs/image_classification.yaml --resume ./result/checkpoints/epoch_10

推理和可视化

# 基本推理(使用实际的检查点路径)
python inference_visualization.py --model_path ./result/ImageClassification_CIFAR10_20250720_201039/epoch_10_checkpoint --data_path ./data

# 带可视化的推理
python inference_visualization.py --model_path ./result/ImageClassification_CIFAR10_20250720_201039/epoch_10_checkpoint --data_path ./data --visualize

# 完整分析(推理+可视化+错误分析)
python inference_visualization.py --model_path ./result/ImageClassification_CIFAR10_20250720_201039/epoch_10_checkpoint --data_path ./data --visualize --error_analysis

# 如果配置文件无法自动找到,可以手动指定
python inference_visualization.py --model_path ./result/ImageClassification_CIFAR10_20250720_201039/epoch_10_checkpoint --data_path ./data --config configs/image_classification.yaml --visualize

配置文件详解

配置文件 configs/image_classification.yaml 控制所有训练参数:

# 数据集配置
dataset:
  name: "CIFAR10"              # 数据集名称
  root: "./data"               # 数据根目录
  download: true               # 是否自动下载
  num_workers: 4               # 数据加载进程数
  pin_memory: true             # 是否使用内存锁定

# 模型配置
model:
  name: "resnet18"             # 模型名称
  num_classes: 10              # 分类数量
  pretrained: false            # 是否使用预训练权重

# 训练配置
training:
  num_epochs: 100              # 训练轮数
  batch_size: 128              # 批次大小
  
# 早停配置
early_stopping:
  patience: 10                 # 早停耐心值
  metric: "val_accuracy"       # 监控指标
  mode: "max"                  # 优化模式(max/min)

# 损失函数配置
loss:
  name: "CrossEntropyLoss"     # 损失函数名称

# 优化器配置
optimizer:
  name: "Adam"                 # 优化器名称
  lr: 0.001                    # 学习率
  weight_decay: 1e-4           # 权重衰减

模块化扩展指南

1. 添加新数据集

datasets/ 目录下创建新的数据集文件:

# datasets/custom_dataset.py
from torch.utils.data import Dataset
from .base_dataset import BaseDataset

class CustomDataset(BaseDataset):
    def __init__(self, root, train=True, transform=None, download=False):
        super().__init__()
        # 实现数据集逻辑
        pass
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        # 返回 (image, label)
        pass

然后在 utils/factory.py 中注册:

def create_dataset(config):
    dataset_name = config.get('name', 'CIFAR10')
    
    if dataset_name == 'CustomDataset':
        from datasets.custom_dataset import CustomDataset
        # 创建数据集实例

2. 添加新模型

models/ 目录下创建新的模型文件:

# models/custom_model.py
import torch.nn as nn

class CustomModel(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # 定义模型结构
        pass
    
    def forward(self, x):
        # 前向传播
        pass

utils/factory.py 中注册:

def create_model(config):
    model_name = config.get('name', 'resnet18')
    
    if model_name == 'custom_model':
        from models.custom_model import CustomModel
        return CustomModel(num_classes=config.get('num_classes', 10))

3. 添加自定义损失函数

losses/custom_losses.py 中定义:

import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    
    def forward(self, inputs, targets):
        # 实现Focal Loss逻辑
        pass

utils/factory.py 中注册:

def create_loss_function(config):
    loss_name = config.get('name', 'CrossEntropyLoss')
    
    if loss_name == 'FocalLoss':
        from losses.custom_losses import FocalLoss
        return FocalLoss(
            alpha=config.get('alpha', 1),
            gamma=config.get('gamma', 2)
        )

4. 添加自定义优化器

optimizers/custom_optimizers.py 中定义:

import torch.optim as optim

class CustomOptimizer(optim.Optimizer):
    def __init__(self, params, lr=1e-3):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)
    
    def step(self, closure=None):
        # 实现优化步骤
        pass

utils/factory.py 中注册:

def create_optimizer(model, config):
    optimizer_name = config.get('name', 'Adam')
    
    if optimizer_name == 'CustomOptimizer':
        from optimizers.custom_optimizers import CustomOptimizer
        return CustomOptimizer(
            model.parameters(),
            lr=config.get('lr', 1e-3)
        )

训练流程说明

1. 训练启动流程

  1. 参数解析: 解析命令行参数(配置文件、GPU设置等)
  2. GPU配置: 根据 --gpu 参数设置 CUDA_VISIBLE_DEVICES
  3. 配置加载: 从YAML文件加载训练配置
  4. 组件创建: 通过工厂函数创建数据集、模型、损失函数、优化器
  5. 训练执行: 调用 BaseTrainer.train() 开始训练

2. 单卡 vs 多卡训练

  • 单卡训练: 直接使用 python 命令,通过 --gpu 参数指定GPU
  • 多卡训练: 使用 accelerate launch 命令,Accelerate自动处理分布式设置

3. 断点恢复机制

  • 自动恢复: --resume auto 自动查找最新检查点
  • 指定恢复: --resume path/to/checkpoint 从指定检查点恢复
  • 状态恢复: 恢复模型权重、优化器状态、学习率调度器状态、随机数种子

4. 早停机制

  • 监控指标: 可配置监控的验证指标(如 val_accuracyval_loss
  • 优化模式: 支持最大化(max)和最小化(min)模式
  • 耐心值: 配置连续多少个epoch无改善后停止训练

实验管理

SwanLab集成

框架自动集成SwanLab进行实验跟踪:

  • 自动记录: 训练损失、验证准确率、学习率等指标
  • 实验命名: 基于数据集、模型和时间戳的自动命名
  • 可视化: 实时查看训练曲线和指标变化

检查点管理

  • 定期保存: 根据 --save_freq 参数定期保存检查点
  • 最佳模型: 自动保存验证集上表现最好的模型
  • 完整状态: 保存模型、优化器、调度器的完整状态

常见问题

Q: 如何切换数据集?

A: 修改配置文件中的 dataset.name 字段,确保对应的数据集类已在工厂函数中注册。

Q: 如何使用预训练模型?

A: 在配置文件中设置 model.pretrained: true,或在模型配置中指定预训练权重路径。

Q: 多卡训练时如何指定GPU?

A: 使用 accelerate launch 时,通过 --gpu 参数指定,如 --gpu 0,1,2,3

Q: 如何调整早停策略?

A: 修改配置文件中的 early_stopping 部分,调整 patiencemetricmode 参数。

Q: 推理脚本找不到配置文件怎么办?

A: 推理脚本会自动查找配置文件,如果找不到可以通过 --config 参数手动指定。

贡献指南

欢迎提交Issue和Pull Request来改进这个项目。在贡献代码时,请确保:

  1. 遵循现有的代码风格和结构
  2. 添加适当的文档和注释
  3. 确保新功能与现有模块兼容
  4. 提供相应的测试用例

许可证

本项目采用MIT许可证,详见LICENSE文件。

About

Use accelerator to train and record the training demo in Swanlab

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages