Merge pull request #52 from jzhang38/continue_pretrain
Continue pretrain
jzhang38 authored Oct 3, 2023
2 parents bd7a16c + 6e706ab commit 9627da5
Showing 11 changed files with 802 additions and 21 deletions.
19 changes: 19 additions & 0 deletions lit_gpt/config.py
@@ -281,6 +281,25 @@ def norm_class(self) -> Type:
         intermediate_size=2048,
         n_query_groups=1,
     ),
+    dict(
+        org="StatNLP-research",
+        name="code_tiny_LLaMA_1b",
+        block_size=8192,
+        vocab_size=32000,
+        padding_multiple=64,
+        n_layer=22,
+        n_head=32,
+        n_embd=2048,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        _norm_class="FusedRMSNorm",
+        norm_eps=1e-5,  # Llama 2 uses 1e-5; Llama 1 uses 1e-6
+        _mlp_class="LLaMAMLP",
+        intermediate_size=5632,
+        n_query_groups=4,
+        condense_ratio=4
+    ),
 ]
 configs.extend(tiny_LLaMA)

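The new `code_tiny_LLaMA_1b` entry raises `block_size` to 8192 while keeping the 1.1B architecture (`n_layer=22`, `n_head=32`, `n_embd=2048`) and sets `condense_ratio=4`. In lit-gpt-style RoPE caches, a condense ratio divides the position indices before computing cos/sin, so a model pretrained at a shorter context can be continued at a longer one (position interpolation). The sketch below illustrates that idea only; the helper name, the half-sized cache layout, and `n_elem=64` (head size for 2048/32) are assumptions, not the repository's actual `build_rope_cache`.

```python
# Hedged sketch: how a condense_ratio-style RoPE cache is typically built.
import torch

def build_rope_cache_sketch(seq_len: int, n_elem: int, base: int = 10000, condense_ratio: int = 1):
    """Return (cos, sin), each of shape (seq_len, n_elem // 2)."""
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    # Dividing positions by condense_ratio squeezes the 8192 new positions back
    # into the index range the base model saw during pretraining.
    seq_idx = torch.arange(seq_len).float() / condense_ratio
    idx_theta = torch.outer(seq_idx, theta)
    return torch.cos(idx_theta), torch.sin(idx_theta)

cos, sin = build_rope_cache_sketch(seq_len=8192, n_elem=64, condense_ratio=4)
print(cos.shape)  # torch.Size([8192, 32])
```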
33 changes: 23 additions & 10 deletions lit_gpt/model.py
@@ -89,6 +89,7 @@ def forward(
 
         cos, sin = self.rope_cache
         if use_kv_cache:
+
             cos = cos.index_select(0, input_pos)
             sin = sin.index_select(0, input_pos)
             mask = self.mask_cache.index_select(2, input_pos)
@@ -100,12 +101,12 @@ def forward(
 
         # forward the model itself
         x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
 
         if not use_kv_cache:
             for block in self.transformer.h:
                 x, *_ = block(x, (cos, sin), max_seq_length)
         else:
-            self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1))
+            self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1) * 2)
             for i, block in enumerate(self.transformer.h):
                 x, self.kv_caches[i] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i])
@@ -132,14 +133,15 @@ def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor:
 
     def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope_cache_length: int) -> List[KVCache]:
         B = idx.size(0)
-        heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
+        heads = 1 if self.config.n_query_groups == 1 else self.config.n_query_groups
+
         k_cache_shape = (
             B,
-            heads,
             max_seq_length,
+            heads,
             rope_cache_length + self.config.head_size - int(self.config.rotary_percentage * self.config.head_size),
         )
-        v_cache_shape = (B, heads, max_seq_length, self.config.head_size)
+        v_cache_shape = (B, max_seq_length, heads, self.config.head_size)
         device = idx.device
         return [
             (torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device))
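Two things change in `build_kv_caches`: the cache now stores only the `n_query_groups` key/value heads instead of all `n_head` query heads, and the sequence axis moves to dim 1, giving a `(B, seq, heads, head_size)` layout that matches the `(batch, seqlen, nheads, headdim)` layout `flash_attn_func` consumes. A shape sketch using the `code_tiny_LLaMA_1b` numbers; `head_size=64` and the full-width k-cache are assumptions for illustration:

```python
# Hedged shape sketch of the new cache layout.
import torch

B, max_seq_length = 1, 8192
n_head, n_query_groups, head_size = 32, 4, 64

heads = 1 if n_query_groups == 1 else n_query_groups  # cache the 4 KV heads, not the 32 query heads
k_cache = torch.zeros(B, max_seq_length, heads, head_size)  # (B, T, kv_heads, hs)
v_cache = torch.zeros(B, max_seq_length, heads, head_size)
print(k_cache.shape, v_cache.shape)  # torch.Size([1, 8192, 4, 64]) twice
```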
@@ -165,6 +167,7 @@ def forward(
         input_pos: Optional[torch.Tensor] = None,
         kv_cache: Optional[KVCache] = None,
     ) -> Tuple[torch.Tensor, Optional[KVCache]]:
+
         n_1 = self.norm_1(x)
         h, new_kv_cache = self.attn(n_1, rope, max_seq_length, mask, input_pos, kv_cache)
         if self.config.parallel_residual:
@@ -248,10 +251,11 @@ def forward(
             if input_pos[-1] >= max_seq_length:
                 input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
                 # shift 1 position to the left
-                cache_k = torch.roll(cache_k, -1, dims=2)
-                cache_v = torch.roll(cache_v, -1, dims=2)
-            k = cache_k.index_copy_(2, input_pos, k)
-            v = cache_v.index_copy_(2, input_pos, v)
+                cache_k = torch.roll(cache_k, -1, dims=1)
+                cache_v = torch.roll(cache_v, -1, dims=1)
+
+            k = cache_k.index_copy_(1, input_pos, k)
+            v = cache_v.index_copy_(1, input_pos, v)
             kv_cache = k, v
 
         y = self.scaled_dot_product_attention(q, k, v, mask=mask)
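Because the sequence axis of the cache is now dim 1 rather than dim 2, the eviction roll and the `index_copy_` that writes the new key/value move to dim 1 as well. A self-contained sketch of that sliding-window update on the `(B, T, heads, hs)` layout; the tensor sizes below are made up for illustration:

```python
# Hedged sketch of the sliding-window KV update on a (B, T, heads, hs) cache.
import torch

B, T, heads, hs = 1, 8, 4, 64
cache_k = torch.zeros(B, T, heads, hs)
new_k = torch.randn(B, 1, heads, hs)   # key for one newly generated token
input_pos = torch.tensor([T])          # position just past the end of the cache

if input_pos[-1] >= T:
    input_pos = torch.tensor([T - 1])
    cache_k = torch.roll(cache_k, -1, dims=1)        # drop the oldest position (seq axis = dim 1)
cache_k = cache_k.index_copy_(1, input_pos, new_k)   # write the new token into the last slot
print(cache_k.shape)  # torch.Size([1, 8, 4, 64])
```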
@@ -267,6 +271,7 @@ def scaled_dot_product_attention(
         self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
     ):
         scale = 1.0 / math.sqrt(self.config.head_size)
+
         if (
             FlashAttention2Available
             and mask is None
@@ -276,7 +281,15 @@ def scaled_dot_product_attention(
             from flash_attn import flash_attn_func
 
             return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=scale, causal=True)
-        assert False
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        if q.size() != k.size():
+            k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
+            v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
+        y = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
+        )
+        return y.transpose(1, 2)
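This fallback replaces the old `assert False`: when FlashAttention-2 cannot be used (for example when a mask is supplied), q/k/v are transposed from the `(B, T, heads, hs)` layout into the `(B, heads, T, hs)` layout that `torch.nn.functional.scaled_dot_product_attention` expects, and the grouped key/value heads are duplicated with `repeat_interleave` to match the query head count. A standalone sketch of that path; the head counts come from the new config and the random tensors are purely illustrative:

```python
# Hedged sketch of the GQA fallback: expand 4 KV heads to 32 query heads,
# then run PyTorch SDPA on the (B, heads, T, hs) layout.
import math
import torch
import torch.nn.functional as F

B, T, n_head, n_kv, hs = 1, 16, 32, 4, 64
q = torch.randn(B, T, n_head, hs)   # flash_attn-style (B, T, heads, hs)
k = torch.randn(B, T, n_kv, hs)
v = torch.randn(B, T, n_kv, hs)

q, k, v = (t.transpose(1, 2) for t in (q, k, v))              # -> (B, heads, T, hs)
if q.size() != k.size():
    k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)  # 4 KV heads -> 32
    v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)

scale = 1.0 / math.sqrt(hs)
y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, scale=scale, is_causal=True)
y = y.transpose(1, 2)  # back to (B, T, heads, hs)
print(y.shape)         # torch.Size([1, 16, 32, 64])
```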


