Finetune part of a whole embedding's parameters. #5231

Open
@CongHan0808

Description

I added 100 new tokens to the vocabulary, along with their corresponding embeddings. I want to finetune only these new embeddings and keep the original tokens' pretrained weights fixed. Following #4192, I use safe_get_full_grad and safe_set_full_fp32_param to get and modify the parameter's gradient. However, all of the parameter's weights get updated.
Here is my code:

import torch
from deepspeed.utils import (
    safe_get_full_grad,
    safe_get_full_optimizer_state,
    safe_set_full_fp32_param,
    safe_set_full_optimizer_state,
)

model_engine.backward(total_loss)

# Mask intended to be 1 for the new tokens' rows and 0 for the pretrained rows.
textembeds_masks = torch.zeros_like(model_engine.in_adaptor.text_embed.weight).to(device=model_engine.local_rank)
textembeds_masks[VOCAB_SIZE_SRC + 1, :] = 1

with torch.no_grad():
    for p_name, param in model_engine.named_parameters():
        if "in_adaptor.text_embed.weight" in p_name and param.grad is not None:
            # Gather the full gradient and optimizer states for this parameter.
            hp_grad = safe_get_full_grad(param)
            exp_avg = safe_get_full_optimizer_state(param, "exp_avg")
            exp_avg_sq = safe_get_full_optimizer_state(param, "exp_avg_sq")
            # hp_grad.copy_(hp_grad.data * textembeds_masks)

            # Write the masked gradient into the fp32 master weights and mask the optimizer states.
            safe_set_full_fp32_param(param, hp_grad.data * textembeds_masks)
            safe_set_full_optimizer_state(param, exp_avg.data * textembeds_masks, "exp_avg")
            safe_set_full_optimizer_state(param, exp_avg_sq.data * textembeds_masks, "exp_avg_sq")

model_engine.step()
scheduler(step)

After saving several checkpoints, the original tokens' rows of in_adaptor.text_embed.weight differ between checkpoints. How should I change my code so that the original tokens' weights stay unchanged and only the new tokens' embeddings are finetuned?
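
One thing I am considering (a minimal sketch, not verified): rather than overwriting the fp32 master weights with a masked gradient, save the pretrained rows once before training and copy them back into the fp32 master weights after every model_engine.step(). The slice [:VOCAB_SIZE_SRC] and the name pretrained_rows below are my own; I assume the 100 new tokens occupy the rows from VOCAB_SIZE_SRC onward.

import torch
from deepspeed.utils import safe_get_full_fp32_param, safe_set_full_fp32_param

# Once, before training: snapshot the pretrained (frozen) rows.
embed_param = dict(model_engine.named_parameters())["in_adaptor.text_embed.weight"]
pretrained_rows = safe_get_full_fp32_param(embed_param)[:VOCAB_SIZE_SRC].clone()

# Each iteration:
model_engine.backward(total_loss)
model_engine.step()
scheduler(step)

# After the step, restore the frozen rows of the fp32 master weights,
# so only the new tokens' rows can change between checkpoints.
with torch.no_grad():
    full_fp32 = safe_get_full_fp32_param(embed_param)
    full_fp32[:VOCAB_SIZE_SRC] = pretrained_rows
    safe_set_full_fp32_param(embed_param, full_fp32)

I am not sure whether the low-precision copy of the weights is refreshed from the fp32 master weights immediately after safe_set_full_fp32_param or only at the next step, so this may still need adjustment depending on the ZeRO stage.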
