Skip to content

一个完整的基于transformer的股票模型训练项目

License

Notifications You must be signed in to change notification settings

hujiyo/EquiNet-v2

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

73 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

EquiNet

项目简介

  • 基于A股2021年后股票历史数据进行建模
  • EquiNet基于历史数据进行统计建模,对未来3天是否上涨进行打分
  • 数据泄露的解决方案:吸取了之前数据泄露的教训,最新的训练方法会将数据集中最近的80天数据冻结作为评估集,并通过去交叠的机制避免了数据泄露,同时确保了评估的一致性和可重复性。
  • 预测性能良好:在大量测试训练过程中我发现AUC的高低对于阈值的选取过于敏感,大部分时刻并不能很稳定的反映模型的实际排序打分能力,因此最新的评估机制我最终还是返璞归真,决定完全由Top-k%的收益率来进行测试(实验默认取1%),同时新增几何复利收益率来评估模型的现实稳定性
  • 新增教师模型标签纠偏机制:能在收益率表现较优的模型的基础之上对现有只经过真实标签训练的模型进行上限的拔高,通常能将模型的收益率进一步提升0%-0.6%,同时大幅度提高训练稳定性

主要特性

  • Transformer 架构:强大的时序建模能力,适合金融数据。
  • 测试集损失监控:每轮训练后计算并显示测试集损失,帮助判断模型是否过拟合。
  • 灵活参数配置:支持自定义模型维度、层数、训练轮数等,适应不同硬件与需求。
  • 窗口归一化处理:每样本进行独立归一化,消除价格量级影响。
  • GPU加速:BF16精度在CUDA平台训练。
  • 训练过程评估:每轮训练后自动评估模型当前预测准确率。
  • 实用预测:专注于预测股票上涨概率,更符合实际交易需求。
  • 手动尝试调参:只提供相对基础的config参数,实际可个性化调整参数以追求更高预测效果。

目录结构

EquiNet/
├── src/
│   ├── data/                	# 存放.csv股票数据文件
│   ├── out/                 	# 训练输出的模型权重
│   ├── train.py         	    # 主训练脚本
│   ├── train_clone.py        # AB教师训练脚本
│   ├── train_evolve.py       # 进化训练脚本
│   ├── config.py        	    # 统一配置文件
│   └── ...              	    # 其他可能的源码
├── LICENSE
├── README.md
└── ...

数据格式说明

  • 数据目录:./data
  • 价格区间:1 ~ 6 元附近
  • 每个.csv文件对应一只股票
  • 字段说明(共8列):
    • time:日期(如2023/06/27
    • start:开盘价
    • max:最高价
    • min:最低价
    • end:收盘价
    • volume:股票成交量
    • exchange:换手率

v2数据找时间发布到huggingface

模型输出说明

最终文件名格式

modelB_top1_p1_11pct_thr0_485_auc0_6182_ep29_1214_1930.pth
  │      │    │          │        │       │      └── 时间戳
  │      │    │          │        │       └── 最佳轮次
  │      │    │          │        └── AUC
  │      │    │          └── 阈值(实盘用)
  │      │    └── 收益率
  │      └── Top-K
  └── 模型类型

实盘使用:看到 thr0_485 就知道当模型的预测值 ≥ 0.485 时就说明该样本达到Top1%的输出分下限了

解读: modelB - 模型的来源:clone脚本生成的B型模型 top1 - Top1%选股,k取1 p1_11pct - 收益率 +1.11% 在测试集上的收益率 thr0_485 - 阈值 0.485(预测值≥0.485即入选Top1%) auc0_6182 - AUC 0.6182 ep29 - 第29轮 1214_1930 - 12月14日19:30

安装与环境

  • 环境配置:environment.yaml

使用流程

快速开始

  1. 克隆项目

    git clone https://gitee.com/hujiyo/EquiNet.git

    将数据集放在./data目录下。每个.csv文件对应一只股票。

  2. 创建虚拟环境

    conda env create -f environment.yaml && conda activate equinet
  3. 运行训练脚本

    python src/train.py

训练结果示例(最新版本可能略微不同 &&参数需要自己去调)

Epoch 8/40, LR: 0.000992 (正常训练), 采样进度: 9677321/14210505 (68.1%)
  使用时间顺序采样器生成数据...
    批量生成样本索引: 8轮...
    生成了42697个样本索引,开始处理...
    样本池: 正样本=3084, 负样本=38253
    已生成 56/56 个batch
  训练进度: 100.0% (56/56), 平均损失: 3.0219

  样本预测示例 (Epoch 8):
    样本20716: 预测=0.6250, 真实=1.0
    样本27946: 预测=0.5977, 真实=0.0
    样本3405: 预测=0.3066, 真实=0.0
    样本2562: 预测=0.6016, 真实=0.0
    样本81207: 预测=0.3965, 真实=0.0
  不上涨: 60316/92623 = 0.651
  上涨: 2738/4001 = 0.684
  上涨准确率: 2738/35045 = 0.078 准确率: 14775/35045 = 0.422
  置信度区间精确度:
    0.50-0.55: 上涨准确=365/7182=0.051, 非负准确=3125/7182=0.435
    0.55-0.58: 上涨准确=375/6353=0.059, 非负准确=2797/6353=0.440
    0.58-0.60: 上涨准确=323/4947=0.065, 非负准确=2138/4947=0.432
    0.60-0.70: 上涨准确=1675/16563=0.101, 非负准确=6715/16563=0.405
    0.70-1.00: 无预测
  总体准确率: 0.653
  Top1%收益: 样本数=966, 平均=-1.11%, 累计=-1070.22%
  AUC得分: 0.7265
  训练集损失: 3.0219, 测试集损失: 1.2121
  ✓ 发现更好的模型!测试集Loss降低: 1.2425 → 1.2121(已缓存到内存)
    详情: AUC=0.7265, Top1%收益: 平均=-1.11%, 累计=-1070.22%

配置文件使用说明

📋 点击展开:配置文件详细使用说明

概述

config.py 是 EquiNet 项目的统一配置文件,用于管理模型参数、训练参数和评估参数。通过这个配置文件,您可以轻松地调整模型的各种设置,而无需修改多个文件。

配置文件结构

配置文件包含以下几个主要部分:

1. 模型架构参数 (ModelConfig)

  • 基础模型参数: 输入维度、模型维度、注意力头数、层数等
  • 注意力机制参数: Dropout比率

2. 训练参数 (TrainingConfig)

  • 基础训练参数: 训练轮数、学习率、批处理大小
  • 优化器参数: 权重衰减、梯度裁剪
  • 学习率调度器参数: 步长、衰减因子
  • 动态权重调整参数: 窗口大小、权重范围
  • 评估设置: 评估样本数、批处理大小

3. 数据参数 (DataConfig)

  • 数据路径: 数据目录、输出目录
  • 数据分割参数: 测试集比例、随机种子
  • 样本生成参数: 历史数据长度(60天)、预测天数(3天)
  • 二分类阈值: 上涨阈值(8%)
  • 评估参数: 评估样本数量(1000个)

4. 设备配置 (DeviceConfig)

  • 设备管理: 自动检测GPU/CPU
  • 设备信息: 打印设备信息

5. 模型保存配置 (ModelSaveConfig)

  • 模型文件名: 最佳模型、最终模型文件名
  • 路径管理: 获取模型保存路径

3. 查看配置摘要

python src/config.py

系统会打印当前配置摘要。

4. 运行训练

python src/train.py

训练脚本会自动使用配置文件中的参数。

注意事项

  1. 参数一致性: 确保训练和测试时使用相同的配置参数
  2. 内存限制: 增加模型维度时注意GPU内存限制
  3. 训练时间: 增加训练轮数或模型复杂度会显著增加训练时间
  4. 数据质量: 调整类别阈值可能影响数据质量和模型性能

项目修改LOG

  • 2026.2.10:重要修正 - 强制过滤超过10%涨跌幅限制的样本,降低收益率与现实的差距,新增多种可选优化器和训练机制,新增几何平均复利收益率计算,完全解决了收益率被高涨幅限制股拉高的现象,修正后均值最高达到1.8%。
  • 2026.1.23:新增注意力聚合机制,自适应加权所有时间步特征,替代原来仅使用最后时间步的机制,收益率最高上限提升至2.8%,均值为1.0%-1.5%
  • 2026.1.8:修正收益率计算方法 - EquiNet v1归档,v2 start ~,特征提取层改为FFN结构,修正收益率计算规则(将当日涨停股从收益率计算中剔除),优化采样机制:动态索引生成与循环采样支持。修正前收益率最高达到3%,修正后均值最高达到2%。
  • 2025.12.14:增加Top-N收益率测评机制,之前的固定阈值计算收益率并不符合实际应用,收益率由-3%-1.5%提升到-1%-0.3%。恢复软标签机制,收益率首次达到0.1%-0.7%的正值。使用克隆模型训练策略+多教师模型纠偏机制,进一步将模型的上限拉高到1%~1.8%
  • 2025.11:数据泄露 - 发现了数据泄露的问题,这意味着过去的评分全部失准。已修复测评集划分机制
  • 2025.10.18:数据集统一放到huggingface上供大家下载。修复索引越界错误,修复动态权重计算错误。改积分评分制为实际涨跌评分制。
  • 2025.10.15:增加学习率预热和余弦退火调度机制,残差连接改为Pre-Norm架构,调整了部分参数。进一步提升了模型训练时的稳定性。
  • 2025.9.16:重要修复 - 修正了原来错误的权重平衡方法,模型的预测能力各项指标普遍上涨5%-15%。取消专业头机制,优化训练时数据采样流程,改为每批次提前批量抽取数据,训练效率提升50%以上。
  • 2025.8.1: 重大更新 - 重构采用二分类方案,专注于预测股票是否会上涨,输出0-1之间的概率值,更符合实际交易需求。使用固定的31个测试文件和评估样本,确保评估的一致性和可重复性。模型准确率达到58%,接近60%目标。在预测为上涨的股票集中,股票上涨2%的概率高达49%,远超随机平均水平(34%-42%)
  • 2025.6.1: 架构优化 - 重新设计模型架构,增加模型维度(128)和层数(3),优化注意力头分配(价格3头、成交量2头、波动率2头、模式1头),使用时间感知注意力机制,提升模型表达能力。
  • 2025.5.31:积分制成为默认机制,增加时间感知位置编码、Focal Loss损失函数、结合标准正弦余弦位置编码、指数衰减机制、种类差异化多头注意力机制、多尺度注意力,加入了残差连接和层归一化。
  • 2025.5.12:增加mark积分制判别最优模型,但保留原判别机制
  • 2025.5.1:项目start ~

参与贡献的两种路径

  1. Fork 本仓库 --> 建立 Pull Request
  2. 联系hujiyo并加入项目维护者 --> 新建分支(dev_yourname)-->维护项目

联系方式

进入炼丹交流群

  • 加V:wx17601516389后回复“炼丹交流群”即可

致谢

项目出力清单(排名不分先后) ``` GLM-4.6/4.7 QwQ-32b 豆包 Claude 4 sonnet Claude 4.1 opus Claude 4.5 sonnet Claude 4.5 opus deepseek-v3/v3.1/v3.2/r1 ChatGPT 4.1/5 Gemini 2.5 flash/pro Gemini 3 pro ```
最最最初对QwQ的prompt 最最最初对QwQ的prompt,虽然QwQ当时表现特别差,当时我觉得它也是非常了不起的: ``` 背景介绍:现在你在进行一个利用python进行模型训练的2030年编程比赛,参赛者有中国昔日之光:deepseek-r1:671b满血版、openAI最新编程大神:ChatGPT6.0:9600b世界版、千问推理模型QwQ-32b(你)......注意,你可以无限长时间循环分步骤思考,但是你的机会只有一次! 现在是试卷的最后一题:在./下写一个.py模型训练程序,模型数据存放在./data下,那里有着319个.xlsx文件,每个文件对应一只股票,每个文件的有效行数为421行,其中第一行也就是表头字段分别有time start max min end volume marketvolume marketlimit marketrange这9个字段,2-421行都是数据,也就是说每个文件实际上都是包含一只股票420天的基本情况和当天大盘的情况。每个文件的time都是420天并且初始日期都是对照相同的,start/end 是开/收盘价,max/min是最高/低价,volume是股票量能,marketvolume是市场量能,marketlimit是大盘涨跌幅,marketrange是大盘指数的波动宽度,(比如marketlimit为-1%,marketrange为50,则暗示大盘下跌30个点,宽度在50个点,波动很大) 模型的输出结果是对未来3天的情况进行概率预测,分别是上涨3%的概率,下跌2%的概率,保持-2%~3%的概率,例如输出:涨5%:\n跌3%:54%,\n稳:25%。 提示: 1.建议使用transformer架构,程序开始需要提示用户输入想要的模型训练的大小参数和训练轮次。为避免用户不懂,你还要有提示信息(比如对最终模型效果的影响、对模型最终参数量的影响等等) 2.训练过程中,使用随机上下文长度(20天-100天),这种随机效果可以变相降低数据量有限的弊端,然后对下面3天进行预测 3.训练过程中要给出进度信息(包括学习率),每轮训练结束后要增加一个效果环节,具体如下:用相同的方法从数据中随机获得片段作为输入然后,将概率大的视为模型的选择,循环多次即可计算出当前模型预测成功的概率并打印。 4.不同股票的价格不同,模型的目标是掌握其中更深层次的规律,所以训练数据要进行统一的归一化 5.使用支持CUDA的gpu进行训练 6.time字段的格式为'2023/06/27' 7.不要提前确定好所有的数据然后轮流开始训练,这违背了我随机思想的初衷,我要的是训练输入上的完全随机,每轮1000组随机输入 ```

About

一个完整的基于transformer的股票模型训练项目

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages