Commit: Update lm eval model

mzio committed Sep 20, 2024
1 parent 1a97061 commit e557509
Showing 2 changed files with 44 additions and 5 deletions.
30 changes: 25 additions & 5 deletions lm_eval_harness/eval_lm_harness_big.py
@@ -207,6 +207,20 @@ def count_params(module) -> int:
    return sum(p.numel() for p in module.parameters())


def check_state_dict_keys(_keys, layer_idx, rank=0):
    """Report any unexpected keys returned by load_state_dict(strict=False)."""
    try:
        assert len(_keys.unexpected_keys) == 0
        if rank == 0:
            print_header(f'*** All expected keys matched successfully {layer_idx} ***')
    except Exception as e:
        if rank == 0:
            print(e)
            print_header('*** Error: unexpected keys in checkpoint ***')
            print(f'Unexpected keys at {layer_idx}:')
            for k in _keys.unexpected_keys:
                print(k)


def main():
    sys.path.append(LM_EVALUATION_HARNESS_PATH)
    from lm_eval import evaluator
@@ -344,7 +358,8 @@ def main():
                                 peft_gradient_checkpointing=not args.no_peft_grad_ckpt,
                                 train_attention=False)
    if True: # rank == 0:
        # if distill_config.trainer.name is not None or args.attn_mlp_checkpoint_path is not None:
        if distill_config.trainer.name is not None and args.attn_mlp_checkpoint_path is not None:
            # if args.replicate == 64:
            #     distill_config.model_name = distill_config.model_name.replace(f'-se={args.seed}', '-se=0').replace(f'-s={args.seed}', '-s=0')
            # else:
@@ -366,10 +381,15 @@ def main():
                                 merge_loras=False,
                                 peft_gradient_checkpointing=not args.no_peft_grad_ckpt)
    if True: # rank == 0:
        if '.pt' in args.finetune_checkpoint_path:
            with torch.no_grad():
                _keys = model.load_state_dict(torch.load(args.finetune_checkpoint_path), strict=False)
                check_state_dict_keys(_keys, 0)
        else:
            model = load_sharded_model_single_gpu(model, model_path=args.finetune_checkpoint_path,  # None,
                                                  cfg=finetune_config, rank=rank)

    if True: # if rank == 0:
        print_header('** Sanity check model weights **')
        for n, p in model.named_parameters():
            # if ('layers.0.' in n and ('feature_map' in n or 'lora' in n)):
@@ -421,4 +441,4 @@ def main():


if __name__ == '__main__':
    main()
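
Note on the new loading branch: single-file '.pt' checkpoints are now loaded directly with torch.load + load_state_dict, while all other checkpoints still go through load_sharded_model_single_gpu. For context on the _keys value that check_state_dict_keys inspects: PyTorch's Module.load_state_dict(..., strict=False) returns a named tuple with missing_keys and unexpected_keys fields instead of raising on a mismatch. A minimal, self-contained illustration (toy module and state dict, not from this repo):

import torch
import torch.nn as nn

# Toy module whose state dict has exactly 'weight' and 'bias'.
model = nn.Linear(4, 4)

# A checkpoint with one extra entry the module does not expect.
state = {'weight': torch.zeros(4, 4), 'bias': torch.zeros(4), 'extra': torch.zeros(1)}

# strict=False returns _IncompatibleKeys(missing_keys, unexpected_keys)
# rather than raising a RuntimeError.
_keys = model.load_state_dict(state, strict=False)
print(_keys.unexpected_keys)  # ['extra'] -- what check_state_dict_keys would report
print(_keys.missing_keys)     # []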
19 changes: 19 additions & 0 deletions lm_eval_harness/models.py
@@ -9,6 +9,7 @@
from src.model.modeling_mistral import LooooolcatsMistralForCausalLM as LOOOOOLCATS_MISTRAL_MODEL_CLASS

from src.model.modeling_llama_sharded import ShardedLolcatsLlamaForCausalLM as SHARDED_LOLCATS_LLAMA_MODEL_CLASS
from src.model.modeling_llama_sharded_roll import ShardedRollLolcatsLlamaForCausalLM as SHARDED_ROLL_LOLCATS_LLAMA_MODEL_CLASS


class LolcatsLlamaForCausalLM(AutoCausalLM):
@@ -63,6 +64,24 @@ def add_special_tokens(self) -> bool:
            return self._add_special_tokens
        else:
            return False


class ShardedRollLolcatsLlamaForCausalLM(AutoCausalLM):
    """
    Wrapper for Llama or Mistral-like autoregressive language model
    """
    AUTO_MODEL_CLASS = SHARDED_ROLL_LOLCATS_LLAMA_MODEL_CLASS

    @property
    def add_special_tokens(self) -> bool:
        """Whether to include special tokens in encoded text. This should be
        determined by whether or not the model was trained with special tokens.
        TODO: Remove these conditionals once HuggingFace supports a way to
        check whether or not an arbitrary model was trained with special tokens.
        """
        if self._add_special_tokens is not None:
            return self._add_special_tokens
        else:
            return False
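
The new class mirrors the other wrappers in this file: subclass AutoCausalLM, point the AUTO_MODEL_CLASS class attribute at the desired model implementation, and default add_special_tokens to False. A condensed sketch of the pattern (hypothetical names; the transformers class is a stand-in for the src.model classes, and the import path is assumed to match the pre-0.4 lm-evaluation-harness this file appears to target):

from transformers import AutoModelForCausalLM
from lm_eval.models.huggingface import AutoCausalLM


class MyShardedCausalLM(AutoCausalLM):
    """Hypothetical wrapper: only the underlying model class changes."""
    AUTO_MODEL_CLASS = AutoModelForCausalLM  # stand-in for a src.model class

    @property
    def add_special_tokens(self) -> bool:
        # Mirror the existing wrappers: honor an explicit setting,
        # otherwise encode without special tokens.
        if self._add_special_tokens is not None:
            return self._add_special_tokens
        return False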


class LooooolcatsLlamaForCausalLM(AutoCausalLM):
