Skip to content

Commit

Permalink
fix to jamba config, asserting attention and expert offset (#33316)
Browse files Browse the repository at this point in the history
* fix to jamba config, asserting attention and expert offset

* fix foramtting

* fix foramtting

* fix foramtting

* changed to error raise instead of assertion, added unittests

* fix

* changed t_ to property_

* changed t_ to property_

* quickfix

* ran code styler
  • Loading branch information
ErezSC42 authored Sep 17, 2024
1 parent 3476c19 commit 46c2757
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 1 deletion.
9 changes: 9 additions & 0 deletions src/transformers/models/jamba/configuration_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ def __init__(
self.attn_layer_period = attn_layer_period
self.attn_layer_offset = attn_layer_offset

self._check_supported_offset("attention", self.attn_layer_period, self.attn_layer_offset)
self._check_supported_offset("expert", self.expert_layer_period, self.expert_layer_offset)

self.use_mamba_kernels = use_mamba_kernels
self.mamba_d_state = mamba_d_state
self.mamba_d_conv = mamba_d_conv
Expand Down Expand Up @@ -222,3 +225,9 @@ def layers_num_experts(self):
self.num_experts if i % self.expert_layer_period == self.expert_layer_offset else 1
for i in range(self.num_hidden_layers)
]

def _check_supported_offset(self, property_: str, period: int, offset: int):
if offset >= period:
raise ValueError(
f"{property_} layer offset ({offset}) must be smaller than {property_} layer period ({period})"
)
44 changes: 43 additions & 1 deletion tests/models/jamba/test_modeling_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,48 @@
)


class JambaConfigTester(ConfigTester):
def _create_attn_config(self, attn_layer_offset: int, attn_layer_period: int):
_input_dict = self.inputs_dict.copy()
_input_dict["attn_layer_offset"] = attn_layer_offset
_input_dict["attn_layer_period"] = attn_layer_period
return self.config_class(**_input_dict)

def _create_expert_config(self, expert_layer_offset: int, expert_layer_period: int):
_input_dict = self.inputs_dict.copy()
_input_dict["expert_layer_offset"] = expert_layer_offset
_input_dict["expert_layer_period"] = expert_layer_period
return self.config_class(**_input_dict)

def test_attn_offsets(self):
self._create_attn_config(attn_layer_offset=0, attn_layer_period=4)
self._create_attn_config(attn_layer_offset=1, attn_layer_period=4)
self._create_attn_config(attn_layer_offset=2, attn_layer_period=4)
self._create_attn_config(attn_layer_offset=3, attn_layer_period=4)
with self.parent.assertRaises(ValueError):
self._create_attn_config(attn_layer_offset=4, attn_layer_period=4)
with self.parent.assertRaises(ValueError):
self._create_attn_config(attn_layer_offset=5, attn_layer_period=4)

def test_expert_offsets(self):
self._create_expert_config(expert_layer_offset=0, expert_layer_period=4)
self._create_expert_config(expert_layer_offset=1, expert_layer_period=4)
self._create_expert_config(expert_layer_offset=2, expert_layer_period=4)
self._create_expert_config(expert_layer_offset=3, expert_layer_period=4)
with self.parent.assertRaises(ValueError):
self._create_expert_config(expert_layer_offset=4, expert_layer_period=4)
with self.parent.assertRaises(ValueError):
self._create_expert_config(expert_layer_offset=5, expert_layer_period=4)

def test_jamba_offset_properties(self):
self.test_attn_offsets()
self.test_expert_offsets()

def run_common_tests(self):
self.test_jamba_offset_properties()
return super().run_common_tests()


class JambaModelTester:
def __init__(
self,
Expand Down Expand Up @@ -302,7 +344,7 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi

def setUp(self):
self.model_tester = JambaModelTester(self)
self.config_tester = ConfigTester(self, config_class=JambaConfig, hidden_size=37)
self.config_tester = JambaConfigTester(self, config_class=JambaConfig, hidden_size=37)

def test_config(self):
self.config_tester.run_common_tests()
Expand Down

0 comments on commit 46c2757

Please sign in to comment.