-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
[Misc] Fix skipped max-model-len validation when deriving max model length from tokenizer config #19660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
[Misc] Fix skipped max-model-len validation when deriving max model length from tokenizer config #19660
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1429,25 +1429,19 @@ def matryoshka_dimensions(self): | |
return getattr(self.hf_config, "matryoshka_dimensions", None) | ||
|
||
def get_and_verify_max_len(self, max_model_len: int): | ||
tokenizer_config = try_get_tokenizer_config( | ||
self.tokenizer, | ||
trust_remote_code=self.trust_remote_code, | ||
revision=self.tokenizer_revision) | ||
Comment on lines
+1433
to
+1435
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider adding a comment explaining why the tokenizer config is fetched here, for example: # Fetch tokenizer config to determine model max length.
tokenizer_config = try_get_tokenizer_config(
self.tokenizer,
trust_remote_code=self.trust_remote_code,
revision=self.tokenizer_revision) |
||
max_model_len = _get_and_verify_max_len( | ||
hf_config=self.hf_text_config, | ||
tokenizer_config=tokenizer_config, | ||
max_model_len=max_model_len, | ||
disable_sliding_window=self.disable_sliding_window, | ||
sliding_window_len=self.get_hf_config_sliding_window(), | ||
spec_target_max_model_len=self.spec_target_max_model_len, | ||
encoder_config=self.encoder_config) | ||
|
||
tokenizer_config = try_get_tokenizer_config( | ||
self.tokenizer, | ||
trust_remote_code=self.trust_remote_code, | ||
revision=self.tokenizer_revision) | ||
|
||
if tokenizer_config is None: | ||
return max_model_len | ||
|
||
model_max_length = tokenizer_config.get("model_max_length", | ||
max_model_len) | ||
max_model_len = min(max_model_len, model_max_length) | ||
logger.info("Using max model len %s", max_model_len) | ||
return max_model_len | ||
|
||
|
||
|
@@ -3283,6 +3277,7 @@ def _get_and_verify_dtype( | |
|
||
def _get_and_verify_max_len( | ||
hf_config: PretrainedConfig, | ||
tokenizer_config: Optional[dict], | ||
max_model_len: Optional[int], | ||
disable_sliding_window: bool, | ||
sliding_window_len: Optional[Union[int, list[Optional[int]]]], | ||
|
@@ -3309,7 +3304,7 @@ def _get_and_verify_max_len( | |
"max_seq_length", | ||
"seq_len", | ||
] | ||
# Choose the smallest "max_length" from the possible keys. | ||
# Choose the smallest "max_length" from the possible keys | ||
max_len_key = None | ||
for key in possible_keys: | ||
max_len = getattr(hf_config, key, None) | ||
|
@@ -3332,6 +3327,13 @@ def _get_and_verify_max_len( | |
derived_max_model_len = min(derived_max_model_len, | ||
sliding_window_len_min) | ||
|
||
# Consider model_max_length in tokenizer_config | ||
if tokenizer_config: | ||
tokenizer_model_max_length = tokenizer_config.get( | ||
"model_max_length", derived_max_model_len) | ||
derived_max_model_len = min(derived_max_model_len, | ||
tokenizer_model_max_length) | ||
Comment on lines
+3333
to
+3335
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider adding a log message here to indicate when the derived max model length is limited by the tokenizer config, for example: tokenizer_model_max_length = tokenizer_config.get(
"model_max_length", derived_max_model_len)
if tokenizer_model_max_length < derived_max_model_len:
logger.info(f"Limiting max model length to {tokenizer_model_max_length} based on tokenizer config.")
derived_max_model_len = min(derived_max_model_len,
tokenizer_model_max_length) |
||
|
||
# If none of the keys were found in the config, use a default and | ||
# log a warning. | ||
if derived_max_model_len == float("inf"): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding a test case where max_model_len is a string (e.g., '1k') to ensure the parsing logic handles human-readable formats correctly.