Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for GPTNeoX models #32

Merged
merged 17 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[fix] some of the bugs preventing fine-tune run
+ There's still bugs in the attention dimensions mismatch
  • Loading branch information
naubull2 committed Sep 28, 2023
commit 9c9d0a2bdbb3c9430f0675d530f8ceb8f4049805
5 changes: 3 additions & 2 deletions fine-tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="EleutherAI/gpt-neox-20b")
model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
model_type: Optional[str] = field(default="gpt-neox")

@dataclass
Expand Down Expand Up @@ -123,14 +123,15 @@ def train():
model_args.model_name_or_path,
config=config,
cache_dir=training_args.cache_dir,
torch_dtype=torch.bfloat16,
)

tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
use_fast=True,
)

special_tokens_dict = dict()
Expand Down
10 changes: 5 additions & 5 deletions gptneox_attn_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ def forward_attention(
present = (key, value) if use_cache else None

# NOTE: apply shift
group_size = int(q_len * group_size_ratio)
if q_len % group_size > 0:
raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size))
num_group = q_len // group_size
if self.training and not use_full:
def shift(qkv, num_heads, head_dim):
# qkv = [bsz, nh, q_len, d]
group_size = int(q_len * group_size_ratio)
if q_len % group_size > 0:
raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size))
num_group = q_len // group_size
qkv = qkv.transpose(1, 2)
# qkv = [bsz, q_len, nh, d]
qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1)
Expand All @@ -139,7 +139,7 @@ def shift(qkv, num_heads, head_dim):
if self.training and not use_full:
attn_output = attn_output.transpose(1, 2)
# [bsz, q_len, nh, hd]
attn_output[:, :, num_heads//2:] = attn_output[:, :, num_heads//2:].roll(group_size//2, dims=1)
attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1)
attn_output = attn_output.transpose(1, 2)

# Reshape outputs
Expand Down