Bias tensors #1259

Merged 5 commits on Oct 9, 2024
25 changes: 17 additions & 8 deletions torchchat/cli/convert_hf_checkpoint.py
@@ -81,10 +81,17 @@ def convert_hf_checkpoint(
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
"model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias",
"model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias",
"model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias",
"model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias",
"model.layers.{}.self_attn.rotary_emb.inv_freq": None,
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.mlp.gate_proj.bias": "layers.{}.feed_forward.w1.bias",
"model.layers.{}.mlp.up_proj.bias": "layers.{}.feed_forward.w3.bias",
"model.layers.{}.mlp.down_proj.bias": "layers.{}.feed_forward.w2.bias",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.norm.weight": "norm.weight",
@@ -93,11 +100,10 @@ def convert_hf_checkpoint(
bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}

def permute(w, n_heads):
- dim = config.dim
return (
- w.view(n_heads, 2, config.head_dim // 2, dim)
+ w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:])
.transpose(1, 2)
- .reshape(config.head_dim * n_heads, dim)
+ .reshape(w.shape)
)

merged_result = {}
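
For illustration (not part of the diff): the point of "*w.shape[1:]" and ".reshape(w.shape)" is that the same rotary permutation now works for both a 2-D projection weight and a 1-D bias, since any trailing dimensions pass through untouched. A minimal sketch with made-up sizes:

import torch

n_heads, head_dim, dim = 4, 8, 32

def permute(w, n_heads):
    return (
        w.view(n_heads, 2, head_dim // 2, *w.shape[1:])
        .transpose(1, 2)
        .reshape(w.shape)
    )

weight = torch.randn(n_heads * head_dim, dim)  # 2-D: trailing dim carried by *w.shape[1:]
bias = torch.randn(n_heads * head_dim)         # 1-D: *w.shape[1:] is empty

assert permute(weight, n_heads).shape == weight.shape
assert permute(bias, n_heads).shape == bias.shape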
@@ -130,6 +136,7 @@ def load_safetensors():
continue
assert state_dict is not None, f"Unable to load tensors from {file}"
merged_result.update(state_dict)

final_result = {}
for key, value in merged_result.items():
if "layers" in key:
@@ -145,16 +152,18 @@ def load_safetensors():
final_result[new_key] = value

for key in tuple(final_result.keys()):
if "wq" in key:
if "wq.weight" in key or "wq.bias" in key:
wk_key = key.replace("wq", "wk")
wv_key = key.replace("wq", "wv")
q = final_result[key]
k = final_result[key.replace("wq", "wk")]
v = final_result[key.replace("wq", "wv")]
k = final_result[wk_key]
v = final_result[wv_key]
q = permute(q, config.n_heads)
k = permute(k, config.n_local_heads)
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
del final_result[wk_key]
del final_result[wv_key]
print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
torch.save(final_result, model_dir / "model.pth")
print("Done.")
43 changes: 24 additions & 19 deletions torchchat/model.py
@@ -34,7 +34,7 @@
try:
# TODO: remove this after we figure out where in torchtune an `evaluate` module
# is being imported, which is being confused with huggingface's `evaluate``.
- import lm_eval # noqa
+ import lm_eval  # noqa
except Exception:
pass

@@ -278,6 +278,9 @@ class TransformerArgs:
# For pipeline parallel
n_stages: int = 1
stage_idx: int = 0
+ # Optional biases
+ attention_bias: bool = False
+ feed_forward_bias: bool = False

def __post_init__(self):
if self.n_local_heads == -1:
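
For illustration (not part of the diff): both flags default to False, so existing bias-free configs are unchanged, while a checkpoint that ships projection biases (Qwen2-style models are one example) can opt in. A hypothetical sketch; the numeric values below are made up, only the two flag names come from the diff:

args = TransformerArgs(
    dim=896,
    n_heads=14,
    n_local_heads=2,
    attention_bias=True,      # adds biases to wq / wk / wv / wo
    feed_forward_bias=False,  # would add biases to w1 / w2 / w3
)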
@@ -394,7 +397,7 @@ def from_name(cls, name: str):
config = [
config
for config in known_model_params
- if config in str(name).upper() or config in str(name)
+ if config.upper() in str(name).upper() or config in str(name)
]

# We may have two or more configs matched (e.g., "7B" and
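
For illustration (not part of the diff): the old check upper-cased only the requested name, so a mixed-case key in known_model_params could never match it; the new check upper-cases both sides. A tiny sketch with a made-up key:

known_model_params = ["Meta-Llama-3.1-8B"]
name = "meta-llama-3.1-8b"

old = [c for c in known_model_params if c in name.upper() or c in name]          # [] -> no match
new = [c for c in known_model_params if c.upper() in name.upper() or c in name]  # ["Meta-Llama-3.1-8B"]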
@@ -471,7 +474,7 @@ def build_model(self) -> nn.Module:
modules[name] = module_class(TransformerArgs.from_params(config_args))
else:
modules[name] = module_class(**config_args)

# Temporary add extra params to the DeepFusionModel.
# TODO: Remove it once we can make fusion model configurable in model_param.
if recipe.fusion_class == DeepFusionModel:
@@ -730,16 +733,16 @@ def __init__(self, config: TransformerArgs):

# key, query, value projections for all heads, but in a batch
# total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
- # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
- self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
+ # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.attention_bias)
+ self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=config.attention_bias)
self.wk = nn.Linear(
- config.dim, config.n_local_heads * config.head_dim, bias=False
+ config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
)
self.wv = nn.Linear(
- config.dim, config.n_local_heads * config.head_dim, bias=False
+ config.dim, config.n_local_heads * config.head_dim, bias=config.attention_bias
)

- self.wo = nn.Linear(config.dim, config.dim, bias=False)
+ self.wo = nn.Linear(config.dim, config.dim, bias=config.attention_bias)
self.kv_cache = None

self.n_heads = config.n_heads
@@ -766,14 +769,16 @@ def load_hook(self, state_dict, prefix, *args):
# wv = state_dict.pop(prefix + "wv.weight")
# state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

if prefix + "wqkv.weight" in state_dict:
wqkv = state_dict.pop(prefix + "wqkv.weight")
q_size = self.n_heads * self.head_dim
kv_size = self.n_local_heads * self.head_dim
wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
state_dict[prefix + "wq.weight"] = wq
state_dict[prefix + "wk.weight"] = wk
state_dict[prefix + "wv.weight"] = wv
for tensor_suffix in ["weight", "bias"]:
wqkv_key = f"{prefix}wqkv.{tensor_suffix}"
if wqkv_key in state_dict:
wqkv = state_dict.pop(wqkv_key)
q_size = self.n_heads * self.head_dim
kv_size = self.n_local_heads * self.head_dim
wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0)
state_dict[f"{prefix}wq.{tensor_suffix}"] = wq
state_dict[f"{prefix}wk.{tensor_suffix}"] = wk
state_dict[f"{prefix}wv.{tensor_suffix}"] = wv

return

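For illustration (not part of the diff): looping over the "weight" and "bias" suffixes works because torch.split slices along dim 0 for both the 2-D fused weight and the 1-D fused bias. A minimal round-trip sketch with made-up sizes:

import torch

dim, q_size, kv_size = 32, 32, 16

state_dict = {
    "wqkv.weight": torch.randn(q_size + 2 * kv_size, dim),
    "wqkv.bias": torch.randn(q_size + 2 * kv_size),
}

for suffix in ["weight", "bias"]:
    wq, wk, wv = torch.split(state_dict.pop(f"wqkv.{suffix}"), (q_size, kv_size, kv_size), dim=0)
    state_dict[f"wq.{suffix}"], state_dict[f"wk.{suffix}"], state_dict[f"wv.{suffix}"] = wq, wk, wv

print(state_dict["wq.bias"].shape)    # torch.Size([32])
print(state_dict["wk.weight"].shape)  # torch.Size([16, 32])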
@@ -852,9 +857,9 @@ def forward(
class FeedForward(nn.Module):
def __init__(self, config: TransformerArgs) -> None:
super().__init__()
- self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
- self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
- self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+ self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)
+ self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=config.feed_forward_bias)
+ self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=config.feed_forward_bias)

def distribute(self, device_mesh: DeviceMesh):
parallelize_module(self.w1, device_mesh, ColwiseParallel())
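
For illustration (not part of the diff): feed_forward_bias simply toggles the bias term of each nn.Linear, so every projection computes x @ W.T + b instead of x @ W.T. Assuming the usual SwiGLU-style forward used by this kind of module, a sketch of where the biases enter:

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, hidden_dim = 32, 64
w1 = nn.Linear(dim, hidden_dim, bias=True)   # bias=config.feed_forward_bias
w3 = nn.Linear(dim, hidden_dim, bias=True)
w2 = nn.Linear(hidden_dim, dim, bias=True)

x = torch.randn(2, dim)
out = w2(F.silu(w1(x)) * w3(x))  # assumed forward: w2(silu(w1(x)) * w3(x))
print(out.shape)  # torch.Size([2, 32])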