Ignore keys on validate_rope #33753

Merged (4 commits) on Oct 4, 2024

40 changes: 25 additions & 15 deletions src/transformers/modeling_rope_utils.py
@@ -360,13 +360,23 @@ def _compute_llama3_parameters(
}


def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
def _check_received_keys(
    rope_type: str,
    received_keys: set,
    required_keys: set,
    optional_keys: Optional[set] = None,
    ignore_keys: Optional[set] = None,
):
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
if "type" in received_keys:
received_keys -= {"type"}
required_keys.add("rope_type")

    # Some models need to store model-specific keys, and we don't want to throw a warning for them
    if ignore_keys is not None:
        received_keys -= ignore_keys

    missing_keys = required_keys - received_keys
    if missing_keys:
        raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
@@ -379,47 +389,47 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")


def _validate_default_rope_parameters(config: PretrainedConfig):
def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys)
    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)


def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "factor"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys)
    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)

    factor = rope_scaling["factor"]
    if factor is None or not isinstance(factor, float) or factor < 1.0:
        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "factor"}
    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
    optional_keys = {"original_max_position_embeddings"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

    factor = rope_scaling["factor"]
    if factor is None or not isinstance(factor, float) or factor < 1.0:
        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


def _validate_yarn_parameters(config: PretrainedConfig):
def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "factor"}
    optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

    factor = rope_scaling["factor"]
    if factor is None or not isinstance(factor, float) or factor < 1.0:
@@ -444,14 +454,14 @@ def _validate_yarn_parameters(config: PretrainedConfig):
        )


def _validate_longrope_parameters(config: PretrainedConfig):
def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "short_factor", "long_factor"}
    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
    optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@@ -494,12 +504,12 @@ def _validate_longrope_parameters(config: PretrainedConfig):
        )


def _validate_llama3_parameters(config: PretrainedConfig):
def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys)
    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)

    factor = rope_scaling["factor"]
    if factor is None or not isinstance(factor, float) or factor < 1.0:
@@ -541,7 +551,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
}


def rope_config_validation(config: PretrainedConfig):
def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
"""
Validate the RoPE config arguments, given a `PretrainedConfig` object
"""
@@ -553,7 +563,7 @@ def rope_config_validation(config: PretrainedConfig):
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
    validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
    if validation_fn is not None:
        validation_fn(config)
        validation_fn(config, ignore_keys=ignore_keys)
    else:
        logger.warning(
            f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
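
For context, a minimal usage sketch of the new argument (not part of the diff; the key "my_extra_key" is made up for illustration): without `ignore_keys`, validation logs an "Unrecognized keys" warning for a model-specific entry in `rope_scaling`; with it, that key is silently accepted.

from transformers import LlamaConfig
from transformers.modeling_rope_utils import rope_config_validation

config = LlamaConfig()
# "my_extra_key" stands in for any model-specific entry a config might carry
config.rope_scaling = {"rope_type": "linear", "factor": 2.0, "my_extra_key": True}

rope_config_validation(config)                                # warns about "my_extra_key"
rope_config_validation(config, ignore_keys={"my_extra_key"})  # no warning
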
6 changes: 4 additions & 2 deletions src/transformers/models/qwen2_vl/configuration_qwen2_vl.py
@@ -235,11 +235,13 @@ def __init__(

        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        # and change type from 'mrope' to 'default'
        # and change type from 'mrope' to 'default' because `mrope` uses the default RoPE calculations
        # one can set it to "linear"/"dynamic" etc. to have scaled RoPE
        # TODO: @raushan update config in the hub
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            if self.rope_scaling["type"] == "mrope":
                self.rope_scaling["type"] = "default"
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)
        rope_config_validation(self, ignore_keys={"mrope_section"})

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
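
To illustrate the effect of the block above (not part of the diff; the `mrope_section` value is illustrative, not taken from a real checkpoint): a legacy config that still carries 'type': 'mrope' is remapped to the default rope type, and the model-specific `mrope_section` key no longer triggers an "unrecognized key" warning during validation.

from transformers import Qwen2VLConfig

config = Qwen2VLConfig(rope_scaling={"type": "mrope", "mrope_section": [16, 24, 24]})
print(config.rope_scaling)
# {'type': 'default', 'mrope_section': [16, 24, 24], 'rope_type': 'default'}
# No `rope_scaling` warning is logged, because "mrope_section" is passed via `ignore_keys`.
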
13 changes: 13 additions & 0 deletions tests/utils/test_modeling_rope_utils.py
@@ -65,6 +65,19 @@ def test_rope_validation(self):
                with self.assertRaises(KeyError):
                    rope_config_validation(config)

        # Any other parameters passed to RoPE will trigger a warning that a particular key is not used
        # But sometimes we can have model-specific RoPE kwargs and bypass the warning with `ignore_keys`
        model_specific_kwarg = "mrope_sections"  # e.g. in Qwen2-VL

        for rope_type in all_rope_types:
            if rope_type == "default":
                config.rope_scaling = {"rope_type": rope_type, model_specific_kwarg: True}
                rope_config_validation(config, ignore_keys={model_specific_kwarg})
                with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
                    rope_config_validation(config)
                    self.assertEqual(len(logs.output), 1)
                    self.assertIn(model_specific_kwarg, logs.output[0])

    def test_default_rope_function_bc(self):
        config = LlamaConfig()
        device = torch_device