diff --git a/src/transformers/models/jamba/configuration_jamba.py b/src/transformers/models/jamba/configuration_jamba.py index 58c8a685feab9b..b493db7ed456b3 100644 --- a/src/transformers/models/jamba/configuration_jamba.py +++ b/src/transformers/models/jamba/configuration_jamba.py @@ -193,6 +193,9 @@ def __init__( self.attn_layer_period = attn_layer_period self.attn_layer_offset = attn_layer_offset + self._check_supported_offset("attention", self.attn_layer_period, self.attn_layer_offset) + self._check_supported_offset("expert", self.expert_layer_period, self.expert_layer_offset) + self.use_mamba_kernels = use_mamba_kernels self.mamba_d_state = mamba_d_state self.mamba_d_conv = mamba_d_conv @@ -222,3 +225,9 @@ def layers_num_experts(self): self.num_experts if i % self.expert_layer_period == self.expert_layer_offset else 1 for i in range(self.num_hidden_layers) ] + + def _check_supported_offset(self, property_: str, period: int, offset: int): + if offset >= period: + raise ValueError( + f"{property_} layer offset ({offset}) must be smaller than {property_} layer period ({period})" + ) diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index 6cbfe62cfe172b..6e1a2cf2cf9c44 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -50,6 +50,48 @@ ) +class JambaConfigTester(ConfigTester): + def _create_attn_config(self, attn_layer_offset: int, attn_layer_period: int): + _input_dict = self.inputs_dict.copy() + _input_dict["attn_layer_offset"] = attn_layer_offset + _input_dict["attn_layer_period"] = attn_layer_period + return self.config_class(**_input_dict) + + def _create_expert_config(self, expert_layer_offset: int, expert_layer_period: int): + _input_dict = self.inputs_dict.copy() + _input_dict["expert_layer_offset"] = expert_layer_offset + _input_dict["expert_layer_period"] = expert_layer_period + return self.config_class(**_input_dict) + + def test_attn_offsets(self): + self._create_attn_config(attn_layer_offset=0, attn_layer_period=4) + self._create_attn_config(attn_layer_offset=1, attn_layer_period=4) + self._create_attn_config(attn_layer_offset=2, attn_layer_period=4) + self._create_attn_config(attn_layer_offset=3, attn_layer_period=4) + with self.parent.assertRaises(ValueError): + self._create_attn_config(attn_layer_offset=4, attn_layer_period=4) + with self.parent.assertRaises(ValueError): + self._create_attn_config(attn_layer_offset=5, attn_layer_period=4) + + def test_expert_offsets(self): + self._create_expert_config(expert_layer_offset=0, expert_layer_period=4) + self._create_expert_config(expert_layer_offset=1, expert_layer_period=4) + self._create_expert_config(expert_layer_offset=2, expert_layer_period=4) + self._create_expert_config(expert_layer_offset=3, expert_layer_period=4) + with self.parent.assertRaises(ValueError): + self._create_expert_config(expert_layer_offset=4, expert_layer_period=4) + with self.parent.assertRaises(ValueError): + self._create_expert_config(expert_layer_offset=5, expert_layer_period=4) + + def test_jamba_offset_properties(self): + self.test_attn_offsets() + self.test_expert_offsets() + + def run_common_tests(self): + self.test_jamba_offset_properties() + return super().run_common_tests() + + class JambaModelTester: def __init__( self, @@ -302,7 +344,7 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi def setUp(self): self.model_tester = JambaModelTester(self) - self.config_tester = ConfigTester(self, config_class=JambaConfig, hidden_size=37) + self.config_tester = JambaConfigTester(self, config_class=JambaConfig, hidden_size=37) def test_config(self): self.config_tester.run_common_tests()