Skip to content

LeeCASC/flow_matching

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Flow Matching 无条件生成模型

这是一个使用 Flow Matching 方法实现的无条件生成模型,专门用于生成形状为 (batch_size, 256, 16, 48) 的特征张量。

项目概述

Flow Matching 是一种基于连续归一化流(Continuous Normalizing Flows)的生成模型方法,它通过学习从噪声分布到目标数据分布的连续变换来生成样本。与传统的扩散模型相比,Flow Matching 具有更快的采样速度和更好的理论性质。

主要特性

  • 无条件生成: 从随机噪声生成目标形状的特征张量
  • Flow Matching 算法: 使用最新的 Flow Matching 方法进行训练
  • 高效采样: 相比扩散模型,采样速度更快
  • 可视化工具: 包含完整的可视化和分析工具
  • 模块化设计: 易于扩展和修改

安装依赖

pip install -r requirements.txt

快速开始

1. 训练模型

python train.py

这将:

  • 创建一个合成数据集用于演示
  • 训练 Flow Matching 模型
  • 保存模型检查点到 checkpoints/ 目录
  • 生成训练曲线和样本可视化

2. 生成样本

python demo.py

这将:

  • 加载训练好的模型
  • 生成新的样本
  • 提供详细的统计分析和可视化

3. 直接使用模型

from flow_matching_model import create_model, FlowMatchingTrainer

# 创建模型
model = create_model()
trainer = FlowMatchingTrainer(model, device='cuda')

# 生成样本
samples = trainer.sample(batch_size=4, num_steps=100)
print(f"生成样本形状: {samples.shape}")

模型架构

FlowMatchingModel

主要的生成模型,包含:

  • 时间嵌入层: 将时间步信息嵌入到网络中
  • 输入投影层: 将输入特征投影到隐藏维度
  • ResNet 主干网络: 使用残差块处理特征
  • 输出投影层: 将隐藏特征投影回目标维度

网络参数

  • 输入/输出通道数: 256
  • 空间尺寸: 16 × 48
  • 隐藏维度: 512
  • 残差块数量: 4

训练过程

Flow Matching 算法

  1. 路径插值: 在噪声分布和数据分布之间创建线性插值路径
  2. 速度场预测: 网络学习预测从噪声到数据的速度场
  3. 损失计算: 使用 MSE 损失比较预测和真实速度场
  4. 采样: 通过积分速度场从噪声生成样本

训练参数

  • 学习率: 1e-4
  • 优化器: AdamW
  • 批量大小: 16
  • 训练轮次: 50

文件结构

flow_matching/
├── flow_matching_model.py  # 核心模型实现
├── train.py               # 训练脚本
├── demo.py                # 演示和分析脚本
├── requirements.txt       # 依赖包列表
├── README.md             # 项目说明
└── checkpoints/          # 模型检查点目录

理论背景

Flow Matching vs 扩散模型

Flow Matching 相比传统扩散模型的主要优势:

  1. 更快的采样: 不需要多步去噪过程
  2. 更好的理论性质: 基于连续归一化流的理论
  3. 更简单的训练: 直接学习速度场,不需要噪声调度

数学原理

Flow Matching 学习一个速度场 v(x, t),使得:

dx/dt = v(x, t)

其中 x(t) 是从噪声分布到数据分布的连续路径。训练目标是最小化:

L = E[||v_θ(x(t), t) - v_true(x(t), t)||²]

实验结果

训练完成后,模型能够:

  • 生成符合目标分布的特征张量
  • 保持空间和通道间的相关性
  • 产生多样化的样本

扩展和修改

修改目标形状

要生成不同形状的特征,修改 create_model() 函数中的参数:

model = FlowMatchingModel(
    in_channels=your_channels,
    height=your_height,
    width=your_width,
    hidden_dim=your_hidden_dim,
    num_blocks=your_blocks
)

使用真实数据

替换 SyntheticDataset 类,加载您的真实数据:

class YourDataset:
    def __init__(self, data_path):
        # 加载您的数据
        self.data = load_your_data(data_path)
    
    def __getitem__(self, idx):
        return self.data[idx]

故障排除

内存不足

如果遇到内存问题,可以:

  • 减小批量大小
  • 减少隐藏维度
  • 使用梯度累积

训练不收敛

如果训练不收敛,可以:

  • 调整学习率
  • 增加训练轮次
  • 检查数据分布

引用

如果您在研究中使用了这个实现,请引用相关的 Flow Matching 论文:

@article{flow_matching_2023,
  title={Flow Matching for Generative Modeling},
  author={...},
  journal={...},
  year={2023}
}

许可证

本项目采用 MIT 许可证。详见 LICENSE 文件。

贡献

欢迎提交 Issue 和 Pull Request 来改进这个项目!

About

easy generation code

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages