
Conversation

@Aznix07
Contributor

@Aznix07 Aznix07 commented Dec 17, 2025

What does this PR do?

This PR fixes a regression in v5.0.0rc1 where the _decode method in TokenizersBackend was not respecting the clean_up_tokenization_spaces parameter, causing unwanted spaces to appear before punctuation in decoded output.

Reproduction:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('mlx-community/Llama-3.2-1B-Instruct-4bit')
text = tokenizer.decode([128000, 64, 1174, 65])
print(text)

Behavior:

  • v4.57.3: <|begin_of_text|>a,b ✅
  • v5.0.0rc1 (before fix): <|begin_of_text|>a ,b ❌ (extra space before comma)
  • v5.0.0rc1 (after fix): <|begin_of_text|>a,b ✅

Solution

Added the missing clean_up_tokenization_spaces logic to the TokenizersBackend._decode() method. When enabled (the default), it removes extra spaces before punctuation using regex pattern matching.
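
For illustration, here is a minimal sketch of what such cleanup can look like (the function name, regex, and replacement set are illustrative and may differ from the actual diff):

import re

def clean_up_tokenization(out_string: str) -> str:
    # Remove the stray space left before punctuation, e.g. "a ,b" -> "a,b".
    out_string = re.sub(r" ([.,!?;:])", r"\1", out_string)
    # Reattach common English contractions, e.g. "do n't" -> "don't".
    for fragment in (" n't", " 'm", " 's", " 've", " 're"):
        out_string = out_string.replace(fragment, fragment[1:])
    return out_string

print(clean_up_tokenization("<|begin_of_text|>a ,b"))  # <|begin_of_text|>a,b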

Fixes #42913

Who can review?

@ArthurZucker @itazap

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42916&sha=99c932

Collaborator

@ArthurZucker ArthurZucker left a comment


Hey! Is there a motivation to have that? We removed it because it's unintuitive, and it can be done by the user themselves outside the tokenizer. Do you have a specific use case in mind?
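
(For reference, the caller-side cleanup suggested here could be as simple as the following sketch; the regex is illustrative and reuses the ids from the reproduction above:)

import re
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('mlx-community/Llama-3.2-1B-Instruct-4bit')
raw = tokenizer.decode([128000, 64, 1174, 65])
# Clean up spacing on the caller side instead of inside the tokenizer.
print(re.sub(r" ([.,!?;:])", r"\1", raw))  # <|begin_of_text|>a,b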

@Aznix07
Contributor Author

Aznix07 commented Dec 17, 2025

Hi @ArthurZucker!

Thanks for the response! I realize I should have clarified this before submitting the PR; apologies for that.

What I observed:

  • The clean_up_tokenization_spaces param exists in decode() and defaults to True
  • But in TokenizersBackend._decode(), it's not being applied
  • This causes decode([128000, 64, 1174, 65]) to return 'a ,b' instead of 'a,b'

My question:
Was removing this cleanup intentional for v5.0.0?

If yes, I can close this PR. If it was unintended, I'm happy to adjust the implementation based on your guidance.

Thanks!

@apaniukov

Hey! Is there a motivation to have that? We removed it because it's unintuitive, and it can be done by the user themselves outside the tokenizer. Do you have a specific use case in mind?

@ArthurZucker the issue is that in v4 the clean_up_tokenization_spaces flag worked for all tokenizers, whereas in v5 it only works when the particular tokenizer implementation defines a clean_up_tokenization method.

  1. It is inconsistent with the PythonBackend tokenizer, where there is a default implementation in the _decode method.
  2. The clean_up_tokenization_spaces flag in the decode method now works for some tokenizers and stops working for others.
  3. (Some) tokenizers that have self.clean_up_tokenization_spaces=True change behavior in v5, like the Llama tokenizer from the issue.

Removing the behavior while keeping the clean_up_tokenization_spaces flag in decode is unintuitive, because we need to check whether it really works for each tokenizer we use.

I created the issue about it earlier: #42898
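
One way to see the inconsistency described above is to compare decode() output with the flag on and off; per the report, the two are identical on v5.0.0rc1 for this tokenizer, i.e. the flag is silently ignored (a quick check, reusing the ids from the reproduction):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('mlx-community/Llama-3.2-1B-Instruct-4bit')
ids = [128000, 64, 1174, 65]
with_cleanup = tok.decode(ids, clean_up_tokenization_spaces=True)
without_cleanup = tok.decode(ids, clean_up_tokenization_spaces=False)
# On v5.0.0rc1 this prints True, showing the flag has no effect for this tokenizer.
print(with_cleanup == without_cleanup)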



Development

Successfully merging this pull request may close these issues.

Unexpected tokenizer behavior difference from v4 to v5
