Commit

Merge pull request #2 from botev/t1
Adding tests that work.
botev authored Apr 3, 2024
2 parents 6085aa5 + 3887c14 commit 66b3ca6
Showing 3 changed files with 444 additions and 129 deletions.
@@ -51,7 +51,7 @@ class RecurrentGemmaConfig(PretrainedConfig):
             Dimension of the hidden representations.
         intermediate_size (`int`, *optional*, defaults to 24576):
             Dimension of the MLP representations.
-        num_heads (`int`, *optional*, defaults to 10):
+        num_attention_heads (`int`, *optional*, defaults to 10):
             The number of heads for the attention block and the number of
             heads/blocks for the block-diagonal layers used in the RG-LRU gates.
             This number must divide `hidden_size` and `lru_width`.
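
For readers skimming the diff, the divisibility constraint above is what keeps the per-head dimension an integer. A minimal illustrative sketch, using only the default values shown in this diff (not code from the PR):

```python
# Illustrative sketch only; values come from the defaults in the diff above.
hidden_size = 2560
num_attention_heads = 10

# `num_attention_heads` must divide `hidden_size` (and `lru_width`),
# otherwise the per-head width would not be a whole number.
assert hidden_size % num_attention_heads == 0
head_dim = hidden_size // num_attention_heads  # 2560 // 10 == 256
```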
@@ -104,7 +104,7 @@ def __init__(
         vocab_size=256000,
         hidden_size=2560,
         intermediate_size=3 * 2560,
-        num_heads=10,
+        num_attention_heads=10,
         lru_width=None,
         embeddings_scale_by_sqrt_dim=True,
         attention_window_size=2048,
@@ -124,8 +124,7 @@ def __init__(
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
-        self.num_heads = num_heads
-        self.head_dim = self.hidden_size // self.num_heads
+        self.num_attention_heads = num_attention_heads
         self.lru_width = lru_width if lru_width is not None else hidden_size
         self.embeddings_scale_by_sqrt_dim = embeddings_scale_by_sqrt_dim
         self.attention_window_size = attention_window_size
@@ -134,7 +133,9 @@ def __init__(
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
         self.rope_theta = rope_theta
-        self._block_types = block_types
+        self._block_types = list(block_types)
+
+        self.head_dim = self.hidden_size // self.num_attention_heads
 
         super().__init__(
             pad_token_id=pad_token_id,
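
Taken together, the `__init__` hunks rename the stored attribute to `num_attention_heads`, defer the `head_dim` computation until after that attribute is assigned, and store `block_types` as a list. A hedged usage sketch, assuming the class is importable as `RecurrentGemmaConfig` from `transformers` once this port lands (the import path is not part of this commit):

```python
# Sketch under the assumption that RecurrentGemmaConfig ships in transformers
# with the argument names shown in the diff above.
from transformers import RecurrentGemmaConfig

config = RecurrentGemmaConfig(
    hidden_size=2560,
    intermediate_size=3 * 2560,
    num_attention_heads=10,  # renamed from `num_heads` in this diff
)

# `head_dim` is derived only after `num_attention_heads` is set,
# and `block_types` is stored internally as a list.
assert config.head_dim == config.hidden_size // config.num_attention_heads
assert isinstance(config._block_types, list)
```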