-
Notifications
You must be signed in to change notification settings - Fork 4.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Flatting Packing / maybe fix #5443 and #5426 #5458
Conversation
if total_length >= cutoff_len: | ||
break | ||
|
||
source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里导致 Inst 数据被异常截断 #5426, 也许考虑引入一个新的参数来保证是否可以被截断?我的样本是2轮次的 tool 调用,但是如果截断就只会学习到输出 tool_calls 没有最后的答案了。 而且这里现在截断的实现方式将会导致 user 和 assistant 的内容被截断。如在 mistral 模板中, 会产生 [INST] xxxxxxx
的结果,而xxxxx[/INST]
就不见了,这显然是不正确的。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我觉得不是这里的问题?non-packing 也会有同样的行为
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不过我确实觉得需要加一个参数控制一下,因为有些情况下不允许一个样本被中间截断
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不截 prompt 的话 assistant 放在哪里呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接跳过,drop掉这个样本
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加了参数控制是否可以截断,默认不能截断
packed_input_ids.append(batch_input_ids[index]) | ||
packed_labels.append(batch_labels[index]) | ||
packed_images.append(batch_images[index]) | ||
packed_videos.append(batch_videos[index]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
延迟处理,此时先不返回 position ids,在 collator 中整合并返回 position ids
data_args.flatting_packing and | ||
(getattr(model.config, "_attn_implementation", None) != "flash_attention_2") | ||
): | ||
logger.warning("The `flatting_packing` only support `flash_attention_2`! Maybe cause Out of memory!") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
也许应该强制开启 fa2,但是这个时候已经晚了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
flat packing 应该不是和 fa2 强制绑定的,本质上就是 4d attention mask
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
应该是绑定的,packing-with-FA2,他是通过 flash-attention 直接计算的,不需要 4d attention mask 了,虽然本质上是这样的,但是 fa2 不能输入 4d attention mask,细节可以看这个 transformers pull request
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我知道,他的实现是绑定的,原理上 sdpa 和 eager 照样能用
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
那可能也行
想问下这个flatting_packing和neat_packing的区别是什么呢,单看选项说明(Enable sequence packing with flattening)仍然不太理解 |
实现了这个 packing-with-FA2,经测试,该方案练吞吐量比 neat_packing 更高 |
mistral 的 function call 我还在修改,晚会提交 |
could you open another pr for function call updates? |
好的,那我重新整理一下代码? |
2. fix knapsack, may cause hiyouga#5443 3. avoid supervised examples wrongly truncation
现在应该是一个干净的提交,工具调用的 PR 在 #5473 |
我在相同数据集上相同训练配置尝试了一下neat_packing 和 flatting_packing 发现flatting_packing 初始loss显著高于neat_packing(2.1 vs 0.9) 模型参数YI-9B lr=1e-5 |
找到flatten_packing初始loss高的原因了,transformers版本需要升级到最新4.45.0,accelerate==0.34.2 |
Any updates for this PR? |
好心人做完实验了吗,效果对比怎么样哇 |
@hiyouga 目前的实现有什么问题吗? |
目前是一个什么状态了,neat_packing + fa2 是否达到了同等的训练loss,目前测试下来实际效果挺差的,要么就是无限循环输出,要么就是输出一些怪怪的其他文字,如韩文,法文等。明显是数据concat时候带进去的 |
neat_packing 的 concat 似乎是有问题的,我这里处理了一下,但是不知道什么原因一直没合并 |
What does this PR do?
Before submitting