System Info
transformers==4.43
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
- Take a working fine-tuning pipeline that uses custom 4D attention masks in transformers 4.40 and fine-tunes a Llama 3 model
- Run that same pipeline with transformers 4.41 (or the most recent version, 4.43); a minimal sketch of such a pipeline is included below
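For concreteness, here is a minimal sketch of the kind of pipeline I mean. The checkpoint name, shapes, and the specific mask are placeholders rather than the exact code from my project; a plain lower-triangular mask just reproduces ordinary causal attention, but it is enough to hit the check in question.

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; the real pipeline fine-tunes a Llama 3 model.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")

batch_size, seq_len = 1, 8
input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len))

# Custom binary 4D mask of shape (batch, 1, query_len, key_len):
# 1 = attend, 0 = masked.
mask_4d = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

# Worked with 4.40; with 4.41+ this hits the ValueError described below.
outputs = model(input_ids=input_ids, attention_mask=mask_4d, labels=input_ids)
loss = outputs.loss
```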
Expected behavior
I expected the behavior of 4D attention masking to stay consistent from 4.40 to 4.43. That said, I understand that 4D masking was a new feature, and some changes may have been necessary to make it work with the rest of the framework.
First, thanks again for the implementation of 4D masking. This is really useful to my work and was critical for us in developing and releasing our recent work on TabuLa.
It seems that a breaking change was introduced to masking, specifically in this PR, where masks are no longer "negated" for the user. After this change, masks that previously worked (before the PR) now need to be "negated" in order to work; otherwise a ValueError is raised here when fine-tuning a Llama model.
However, it's not clear to me what "negation" actually means; it doesn't appear to be documented anywhere. Furthermore, it seems easy to construct an attention mask that would pass this block (i.e., have a max value of zero) but that might be incorrect in other ways. There does seem to be some negation logic here, but it won't work for a typical binary attention mask: doing 1.0 - attention_mask simply flips the mask, so any entries that were zero before become 1, triggering the same ValueError as above.
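To make this concrete, a tiny illustration (shapes and mask values are arbitrary):

```python
import torch

# A "standard" binary 4D mask: 1 = attend, 0 = masked.
binary_mask = torch.tril(torch.ones(1, 1, 4, 4))

flipped = 1.0 - binary_mask   # previously-masked (0) positions become 1.0
print(flipped.max())          # tensor(1.) -> still fails a max()-must-be-zero check
```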
So, in this issue I have the following question:
- What is a "negated" mask, and how can I get from a "standard" binary attention mask used elsewhere in the transformers library to a "negated" attention mask that works with the new 4D attention masking scheme?
And in this issue I also suggest the following changes:
- Document what a "negated" attention mask is (ideally in the docstring of _update_causal_mask(), but it could also live anywhere the maintainers decide is appropriate).
- Ideally, provide a function that negates a binary attention mask (as I mention above, this code might be a starting point, but it doesn't seem to work on standard binary attention masks, so more is likely needed; a rough sketch of what I have in mind is after this list).
- Improve the message of the ValueError triggered here to explicitly describe how a negated mask should be formed and what the values represent (if they are not 1/0, I expect this will not be obvious to users of transformers accustomed to such masks).
- Improve the check that raises that error to do more than simply check the max of the attention mask, and truly verify that the mask is properly negated (whatever that means).
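For the second suggestion, here is a rough sketch of the kind of helper I have in mind, assuming "negated" means an additive mask with 0.0 where attention is allowed and the dtype minimum where it is blocked; the function name and that interpretation are my guesses, not anything confirmed by the library:

```python
import torch

def negate_binary_attention_mask(
    mask_4d: torch.Tensor, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Hypothetical helper: convert a binary 4D mask (1 = attend, 0 = masked)
    into an additive/"negated" mask with 0.0 where attention is allowed and
    the most negative representable value of `dtype` where it is blocked,
    so that mask.max() == 0.
    """
    min_value = torch.finfo(dtype).min
    return (1.0 - mask_4d.to(dtype)) * min_value
```

If this matches the intended semantics, something like it could be documented and exposed; if not, that mismatch is exactly what I would like clarified.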
Happy to contribute to this if someone can provide answers to these questions -- again, this is a terrific capability to have in the library and I am super grateful to the team for the work on it!