【Hackathon 10th Spring No.6】CrystalLLM Model Reproduction#266
Open
r-cloudforge wants to merge 1 commit into
Open
【Hackathon 10th Spring No.6】CrystalLLM Model Reproduction#266r-cloudforge wants to merge 1 commit into
r-cloudforge wants to merge 1 commit into
Conversation
|
Thanks for your contribution! |
1bf29fc to
d89c271
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
概述
实现 CrystalLLM 模型复现,基于论文 Crystal Structure Generation with Autoregressive Large Language Modeling (Antunes et al., Nature Communications, 2024),参考实现 lantunes/CrystaLLM(MIT License)。
CrystalLLM 使用 GPT-2 风格的自回归 Transformer 模型,以 CIF (Crystallographic Information File) 文本为序列,直接生成晶体结构。模型在 230 万结构(MP + OQMD + NOMAD)上训练,支持 Perov-5、MP-20、Carbon-24、MPTS-52 四个标准基准数据集。
新增内容
模型 (
ppmat/models/crystalllm/)CrystalLLM: GPT-2 风格 Transformer,支持 Small (8层/8头/512维) 和 Large (16层/16头/1024维) 配置CIFTokenizer: 自定义 CIF 分词器,vocab_size=371(89 原子 + 10 数字 + 31 关键词 + 13 符号 + 227 空间群 + 1 UNK)paddle.matmul(x, wte.weight, transpose_y=True)实现 LM head 与 token embedding 共享MCTS 采样器 (
ppmat/sampler/crystalllm_sampler.py)CrystalLLMSampler: 统一采样接口,支持标准自回归采样(sample())和 MCTS 引导采样(sample_mcts())MCTSSampler: 基于 PUCT/UCT 的蒙特卡洛树搜索,使用晶体结构有效性作为 rewardMCTSEvaluator: 基于 pymatgen 的晶体结构评估器(键长合理性 + 空间群一致性 + 化学式一致性)PUCTSelector、UCTSelector、GreedySelectorContextSensitiveTreeBuilder数据集 (
ppmat/datasets/cif_token_dataset.py)CIFTokenDataset: 加载预分词的 CIF 二进制数据(uint16 memmap)starts.pkl索引每个 CIF 的起始位置)评估指标 (
ppmat/metrics/crystal_metrics.py)is_valid(): 综合检查(CIF 可解析 + 键长合理 + 空间群一致 + 化学式一致 + 原子占位一致)bond_length_reasonableness_score(): 基于 pymatgen CrystalNN 的键长合理性评分is_space_group_consistent(),is_formula_consistent(),is_atom_site_multiplicity_consistent()get_unit_cell_volume(),remove_atom_props_block(): MCTS 采样辅助函数权重转换 (
structure_generation/convert_weights.py)torch.compile训练产生的_orig_mod.transformer.前缀配置文件 (
structure_generation/configs/crystalllm/)测试与评估
test/test_unit.py: 39 个单元测试(分词器、数据集、指标、模型前向)test/test_pipeline.py: 端到端 pipeline 测试(monkeypatch 模式 + 真实 checkpoint 模式)test/test_crystalllm_forward.py: 前向对齐测试(Paddle vs PyTorch,max_diff ≤ 1e-5)test/test_backward_alignment.py: 反向对齐测试(Paddle vs PyTorch,5 步训练 loss 对比)eval_v1_small.py: v1_small 提示式采样评估脚本eval_multi_dataset.py: 全部 4 个数据集的统一评估脚本(perov-5, carbon-24, mp-20, mpts-52)工具
tools/prepare_netdisk.sh: 一键下载 Zenodo 权重、转换 Paddle 格式、整理上传目录验收结果
1. 前向对齐(验收标准:diff ≤ 1e-6)
2. 反向对齐(验收标准:训练2轮以上,loss一致)
3. 采样指标(验收标准:误差 ≤ 5%)
v1_small 500 样本(论文 prompt 协议):
4. 数据集覆盖(验收标准:原论文所有数据集)
eval_multi_dataset.py可一键评估全部 4 个数据集5. 新任务类型文档
structure_generation/目录:基于文本的晶体结构生成任务6. 预训练模型 / 数据集
tools/prepare_netdisk.sh)使用方式
相关 issue
Closes part of #194 (CrystalLLM — 任务 #1)
RFC: PaddlePaddle/community#1256