TwoTST (Dual-Stream Self-Supervised Pretraining Framework) 是一个用于fMRI数据分析的双流Transformer框架,融合了时序Transformer和连接Transformer的优点,通过自监督预训练和对比学习提升ASD分类性能。
TwoTST框架包含两个独立的Transformer分支:
- TST1 (Transformer-TS): 处理原始fMRI时间序列,使用ROI-level掩码策略进行预训练
- TST2 (Transformer-FC): 处理PCC(Pearson相关系数)上三角向量,使用元素级掩码策略进行预训练
- ✅ 双流Transformer架构:分别处理时序和连接特征
- ✅ 自监督预训练:两种不同的掩码策略适配不同数据类型
- ✅ 顺序预训练:先TST1,后TST2,逐步学习表征
- ✅ 可选对比学习:对齐两个分支的特征空间
- ✅ 多种融合策略:支持5种特征融合方法
- ✅ 5折交叉验证:稳健的模型评估
┌─────────────────────────────────────────────────────────────┐
│ 数据预处理 │
│ fmri.npy (N, T, R) → 清洗 → (N', T, R) │
│ ↓ │
│ 时间序列 (N', R, T) + PCC向量 (N', R*(R-1)/2) │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Phase 1: TST1 预训练 │
│ 时间序列 → ROI-level掩码 → Transformer-TS → 重建时序 │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Phase 2: TST2 预训练 │
│ PCC向量 → 元素级掩码 → Transformer-FC → 重建PCC │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Phase 3: 对比学习(可选) │
│ TST1特征 + TST2特征 → InfoNCE损失 → 特征对齐 │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Phase 4: 微调分类 │
│ 融合特征 → MLP分类器 → ASD/TC预测 │
└─────────────────────────────────────────────────────────────┘
TwoTST/
├── models/ # 模型定义
│ ├── transformer_ts.py # TST1: 时序Transformer
│ ├── transformer_fc.py # TST2: 连接Transformer
│ ├── fusion.py # 融合模块(5种策略)
│ ├── dual_stream.py # 双流模型
│ └── __init__.py # 模型导出
│
├── pretrain/ # 预训练模块
│ ├── mask_utils.py # 掩码策略工具
│ ├── pretrain_ts.py # TST1预训练脚本
│ ├── pretrain_fc.py # TST2预训练脚本
│ ├── contrastive.py # 对比学习模块
│ └── __init__.py
│
├── scripts/ # 训练脚本
│ ├── prepare_data.py # 数据预处理
│ ├── train_pretrain.py # 预训练入口
│ ├── train_finetune.py # 微调入口
│ └── start_tensorboard.sh # TensorBoard启动脚本
│
├── utils/ # 工具函数
│ ├── data_loader.py # 数据加载器
│ ├── metrics.py # 评估指标
│ └── __init__.py
│
├── configs/ # 配置文件
│ ├── default.yaml # 默认配置
│ └── base_template.yaml # 配置模板
│
├── requirements.txt # 依赖包
└── README.md # 本文档
注意: 训练过程中会生成以下目录(建议添加到 .gitignore):
checkpoints/- 模型检查点logs/- TensorBoard日志data/- 数据文件results/- 实验结果
系统要求:
- Python >= 3.7
- CUDA >= 10.2 (GPU推荐)
安装依赖:
# 克隆仓库
git clone https://github.com/Leezy-Ray/twoTST.git
cd twoTST
# 安装依赖
pip install -r requirements.txt主要依赖包:
- PyTorch >= 1.10.0
- NumPy >= 1.20.0
- scikit-learn >= 1.0.0
- TensorBoard >= 2.8.0
- tqdm, PyYAML
准备您的fMRI数据,格式为 (n_samples, time_points, n_rois) 的numpy数组。
# 运行数据预处理脚本
python scripts/prepare_data.py \
--data_path /path/to/your/fmri.npy \
--output_dir data/processed \
--n_rois 200 \
--time_points 100数据预处理功能:
- ✅ 自动清洗ROI全零样本
- ✅ 计算PCC(Pearson相关系数)上三角向量
- ✅ 可选滑动窗口数据增强
- ✅ 数据标准化和划分(训练/验证/测试)
输出数据格式:
timeseries:(n_samples, n_rois, time_points)- 时间序列pcc_vectors:(n_samples, n_rois*(n_rois-1)/2)- PCC上三角向量labels:(n_samples,)- 标签 (0=ASD, 1=TC)
使用统一入口脚本进行顺序预训练:
python scripts/train_pretrain.py \
--data_path data/processed/processed_data.pkl \
--pretrain_tst1 \
--pretrain_tst2 \
--tst1_epochs 100 \
--tst2_epochs 100 \
--batch_size 32 \
--lr 1e-4 \
--save_dir checkpoints \
--log_dir logsTST1预训练(时序Transformer):
python pretrain/pretrain_ts.py \
--data_path data/processed/processed_data.pkl \
--epochs 100 \
--batch_size 32 \
--lr 1e-4 \
--save_dir checkpoints/tst1 \
--log_dir logs/tst1TST1特点:
- 输入:
(batch, n_rois, time_points)- 时间序列 - 掩码策略: ROI-level掩码(随机掩码25%或50%的ROI整列)
- 预训练任务: 重建被掩码ROI的完整时间序列
- 模型参数: ~19M
TST2预训练(连接Transformer):
python pretrain/pretrain_fc.py \
--data_path data/processed/processed_data.pkl \
--epochs 100 \
--batch_size 32 \
--lr 1e-4 \
--mask_ratio 0.15 \
--save_dir checkpoints/tst2 \
--log_dir logs/tst2TST2特点:
- 输入:
(batch, pcc_dim)- PCC上三角向量 - 掩码策略: 元素级掩码(随机掩码15%的PCC值)
- 预训练任务: 重建被掩码的PCC值
- 模型参数: ~16M
加载预训练权重进行下游分类任务:
python scripts/train_finetune.py \
--data_path data/processed/processed_data.pkl \
--tst1_checkpoint checkpoints/tst1/tst1_best.pt \
--tst2_checkpoint checkpoints/tst2/tst2_best.pt \
--fusion_type cross_attention \
--epochs 100 \
--batch_size 32 \
--lr 5e-5 \
--n_folds 5 \
--use_contrastive # 可选:启用对比学习微调参数说明:
--fusion_type: 融合策略 (concat,gated,cross_attention,bilinear,attention_pooling)--n_folds: 交叉验证折数(默认5折)--use_contrastive: 是否在微调前进行对比学习对齐
训练过程中可以使用TensorBoard实时查看训练曲线:
# 启动TensorBoard服务
bash scripts/start_tensorboard.sh
# 或手动启动
tensorboard --logdir=logs --port=6006 --host=0.0.0.0然后在浏览器中访问 http://localhost:6006 查看训练曲线。
配置文件位于 configs/default.yaml,主要配置项:
# 数据配置
data:
n_rois: 200
time_points: 100
pcc_dim: 19900
# TST1配置
tst1:
emb_dim: 512
n_heads: 8
n_layers: 6
dim_feedforward: 2048
# TST2配置
tst2:
d_model: 256
n_heads: 8
n_layers: 2
dim_feedforward: 512
# 融合配置
fusion:
type: cross_attention # concat/gated/cross_attention/bilinear/attention_pooling
# 对比学习配置
contrastive:
enabled: false
temperature: 0.07
epochs: 50框架支持5种融合策略:
- ConcatFusion: 简单拼接
[h_ts; h_fc] - GatedFusion: 门控融合
gate * h_ts + (1-gate) * h_fc - CrossAttentionFusion: 交叉注意力融合(推荐)
- BilinearFusion: 双线性融合
- AttentionPoolingFusion: 注意力池化融合
微调脚本会自动计算以下指标:
- Accuracy: 准确率
- Precision: 精确率
- Recall: 召回率
- F1 Score: F1分数
- AUC: ROC曲线下面积
- Sensitivity/Specificity: 敏感度/特异度
输出示例:
Cross-Validation Results:
----------------------------------------
Accuracy : 0.7234 ± 0.0234
Precision : 0.7123 ± 0.0198
Recall : 0.7345 ± 0.0212
F1 : 0.7231 ± 0.0201
AUC : 0.7891 ± 0.0156
import torch
from models import create_dual_stream_model
from utils.data_loader import load_processed_data, TwoTSTDataset
# 加载数据
data = load_processed_data('data/processed/processed_data.pkl')
dataset = TwoTSTDataset(
data['timeseries'],
data['pcc_vectors'],
data['labels']
)
# 创建模型
model = create_dual_stream_model(
n_rois=200,
time_points=100,
pcc_dim=19900,
fusion_type='cross_attention'
)
# 加载预训练权重
model.load_pretrained_tst1('checkpoints/tst1/tst1_best.pt')
model.load_pretrained_tst2('checkpoints/tst2/tst2_best.pt')
# 前向传播
timeseries = torch.randn(8, 200, 100)
pcc_vector = torch.randn(8, 19900)
logits = model(timeseries, pcc_vector)A: 可以减小batch_size或使用梯度累积:
--batch_size 16 # 减小batch sizeA: 可以修改脚本,只加载并使用单个分支的预训练权重。
A: 在微调脚本中指定checkpoint路径:
--tst1_checkpoint checkpoints/tst1/tst1_best.pt
--tst2_checkpoint checkpoints/tst2/tst2_best.ptA: 在数据预处理时启用:
python scripts/prepare_data.py \
--use_sliding_window \
--window_size 50 \
--stride 25本项目参考了以下工作:
- ROI-level掩码预训练时序Transformer
- PCC上三角向量掩码预训练连接Transformer
本项目仅供研究使用。
欢迎提交Issue和Pull Request!
如果您有任何问题或建议,请:
- 提交 Issue
- 发起 Pull Request
如有问题,请提交Issue或联系项目维护者。
TwoTST - Dual-Stream Self-Supervised Pretraining Framework for fMRI Analysis