Potential masking issue in get_masked_and_orig_text function #184

Open
aldopareja opened this issue Aug 27, 2024 · 0 comments
File: src/instructlab/training/data_process.py

In the get_masked_and_orig_text function, there's a potential issue with the masking mechanism:

def get_masked_and_orig_text(sample):
    labels = sample["labels"]
    input_ids = sample["input_ids"]
    # pad_tk, pad_str, and tokenizer are presumably module-level in data_process.py
    label = [pad_tk[0] if tk == -100 else tk for tk in labels]
    text = tokenizer.decode(label).replace(pad_str, "<mask>")
    orig_text = tokenizer.decode(input_ids)
    return text, orig_text

Issue Description

The function uses pad_str to identify and replace masked tokens with "<mask>". However, if pad_str happens to match another special token in the vocabulary, or otherwise appears in the decoded text, the string-level replace will also rewrite those occurrences, making it appear that some tokens are masked when they actually aren't.
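
A minimal sketch of the failure mode, assuming a GPT-2 tokenizer (pad_str and pad_tk here stand in for the module-level values in data_process.py):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "<pad>"})
pad_str = tokenizer.pad_token       # "<pad>"
pad_tk = tokenizer.encode(pad_str)  # [id of "<pad>"]

# The *unmasked* text legitimately contains the pad string.
input_ids = tokenizer.encode("a literal <pad> in the data")
labels = list(input_ids)  # no -100 anywhere: nothing is masked

# Same logic as get_masked_and_orig_text:
label = [pad_tk[0] if tk == -100 else tk for tk in labels]
text = tokenizer.decode(label).replace(pad_str, "<mask>")
print(text)  # "a literal <mask> in the data" -- looks masked, but isn't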

Recommendation

To avoid this potential confusion, consider one of the following approaches:

  1. Use a unique string for masking that's guaranteed not to appear in the tokenizer's vocabulary.
  2. Utilize a dedicated special token for masking (e.g., "<MASK>") and add it to the tokenizer's special tokens.
  3. Implement the masking logic directly on the token IDs before decoding, ensuring only the intended tokens are masked (a sketch follows this list).
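
For option 3, one possible sketch, assuming the same module-level tokenizer. Note that decoding token-by-token can differ slightly from whole-sequence decoding for BPE tokenizers, so this is illustrative only:

def get_masked_and_orig_text(sample):
    labels = sample["labels"]
    input_ids = sample["input_ids"]
    # Substitute the marker position-by-position on the IDs, so no
    # string-level replace is needed and collisions cannot occur.
    pieces = [
        "<mask>" if tk == -100 else tokenizer.decode([inp])
        for tk, inp in zip(labels, input_ids)
    ]
    text = "".join(pieces)
    orig_text = tokenizer.decode(input_ids)
    return text, orig_text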

The snippet below could also fix it, but the <MASK> token should be registered once, where <|pretrain|> and the other special token are added, rather than on every call.

def get_masked_and_orig_text(sample):
    labels = sample["labels"]
    input_ids = sample["input_ids"]
    # NOTE: this registers <MASK> on every call; as suggested above, it should
    # instead be added once, alongside <|pretrain|> and the other special token.
    tokenizer.add_special_tokens({"additional_special_tokens": ["<MASK>"]})
    mask_token_id = tokenizer.convert_tokens_to_ids("<MASK>")
    label = [mask_token_id if tk == -100 else tk for tk in labels]
    text = tokenizer.decode(label)
    orig_text = tokenizer.decode(input_ids)
    return text, orig_text
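
Hypothetical usage of the proposed function, again assuming the module-level tokenizer from data_process.py:

ids = tokenizer.encode("Hello world")  # e.g. two tokens for GPT-2
sample = {"input_ids": ids, "labels": [-100] + ids[1:]}  # mask the first token
masked, original = get_masked_and_orig_text(sample)
print(masked)    # roughly "<MASK> world"; exact spacing depends on the tokenizer
print(original)  # "Hello world"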